@@ -288,17 +288,19 @@ | |||
[[projects]] | |||
name = "github.com/go-xorm/builder" | |||
packages = ["."] | |||
revision = "488224409dd8aa2ce7a5baf8d10d55764a913738" | |||
revision = "dc8bf48f58fab2b4da338ffd25191905fd741b8f" | |||
version = "v0.3.0" | |||
[[projects]] | |||
name = "github.com/go-xorm/core" | |||
packages = ["."] | |||
revision = "cb1d0ca71f42d3ee1bf4aba7daa16099bc31a7e9" | |||
revision = "c10e21e7e1cec20e09398f2dfae385e58c8df555" | |||
version = "v0.6.0" | |||
[[projects]] | |||
name = "github.com/go-xorm/xorm" | |||
packages = ["."] | |||
revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" | |||
revision = "ad69f7d8f0861a29438154bb0a20b60501298480" | |||
[[projects]] | |||
branch = "master" | |||
@@ -701,6 +703,6 @@ | |||
[solve-meta] | |||
analyzer-name = "dep" | |||
analyzer-version = 1 | |||
inputs-digest = "59451a3ad1d449f75c5e9035daf542a377c5c4a397e219bebec0aa0007ab9c39" | |||
inputs-digest = "5ae18d543bbb8186589c003422b333097d67bb5fed8b4c294be70c012ccffc94" | |||
solver-name = "gps-cdcl" | |||
solver-version = 1 |
@@ -33,7 +33,7 @@ ignored = ["google.golang.org/appengine*"] | |||
[[override]] | |||
name = "github.com/go-xorm/xorm" | |||
#version = "0.6.5" | |||
revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" | |||
revision = "ad69f7d8f0861a29438154bb0a20b60501298480" | |||
[[override]] | |||
name = "github.com/go-sql-driver/mysql" |
@@ -1297,7 +1297,7 @@ func getParticipantsByIssueID(e Engine, issueID int64) ([]*User, error) { | |||
And("`comment`.type = ?", CommentTypeComment). | |||
And("`user`.is_active = ?", true). | |||
And("`user`.prohibit_login = ?", false). | |||
Join("INNER", "user", "`user`.id = `comment`.poster_id"). | |||
Join("INNER", "`user`", "`user`.id = `comment`.poster_id"). | |||
Distinct("poster_id"). | |||
Find(&userIDs); err != nil { | |||
return nil, fmt.Errorf("get poster IDs: %v", err) |
@@ -166,7 +166,7 @@ func (issues IssueList) loadAssignees(e Engine) error { | |||
var assignees = make(map[int64][]*User, len(issues)) | |||
rows, err := e.Table("issue_assignees"). | |||
Join("INNER", "user", "`user`.id = `issue_assignees`.assignee_id"). | |||
Join("INNER", "`user`", "`user`.id = `issue_assignees`.assignee_id"). | |||
In("`issue_assignees`.issue_id", issues.getIssueIDs()). | |||
Rows(new(AssigneeIssue)) | |||
if err != nil { |
@@ -67,7 +67,7 @@ func getIssueWatchers(e Engine, issueID int64) (watches []*IssueWatch, err error | |||
Where("`issue_watch`.issue_id = ?", issueID). | |||
And("`user`.is_active = ?", true). | |||
And("`user`.prohibit_login = ?", false). | |||
Join("INNER", "user", "`user`.id = `issue_watch`.user_id"). | |||
Join("INNER", "`user`", "`user`.id = `issue_watch`.user_id"). | |||
Find(&watches) | |||
return | |||
} |
@@ -383,7 +383,7 @@ func GetOwnedOrgsByUserIDDesc(userID int64, desc string) ([]*User, error) { | |||
func GetOrgUsersByUserID(uid int64, all bool) ([]*OrgUser, error) { | |||
ous := make([]*OrgUser, 0, 10) | |||
sess := x. | |||
Join("LEFT", "user", "`org_user`.org_id=`user`.id"). | |||
Join("LEFT", "`user`", "`org_user`.org_id=`user`.id"). | |||
Where("`org_user`.uid=?", uid) | |||
if !all { | |||
// Only show public organizations | |||
@@ -575,7 +575,7 @@ func (org *User) getUserTeams(e Engine, userID int64, cols ...string) ([]*Team, | |||
return teams, e. | |||
Where("`team_user`.org_id = ?", org.ID). | |||
Join("INNER", "team_user", "`team_user`.team_id = team.id"). | |||
Join("INNER", "user", "`user`.id=team_user.uid"). | |||
Join("INNER", "`user`", "`user`.id=team_user.uid"). | |||
And("`team_user`.uid = ?", userID). | |||
Asc("`user`.name"). | |||
Cols(cols...). |
@@ -1958,7 +1958,7 @@ func DeleteRepository(doer *User, uid, repoID int64) error { | |||
func GetRepositoryByOwnerAndName(ownerName, repoName string) (*Repository, error) { | |||
var repo Repository | |||
has, err := x.Select("repository.*"). | |||
Join("INNER", "user", "`user`.id = repository.owner_id"). | |||
Join("INNER", "`user`", "`user`.id = repository.owner_id"). | |||
Where("repository.lower_name = ?", strings.ToLower(repoName)). | |||
And("`user`.lower_name = ?", strings.ToLower(ownerName)). | |||
Get(&repo) |
@@ -54,7 +54,7 @@ func getWatchers(e Engine, repoID int64) ([]*Watch, error) { | |||
return watches, e.Where("`watch`.repo_id=?", repoID). | |||
And("`user`.is_active=?", true). | |||
And("`user`.prohibit_login=?", false). | |||
Join("INNER", "user", "`user`.id = `watch`.user_id"). | |||
Join("INNER", "`user`", "`user`.id = `watch`.user_id"). | |||
Find(&watches) | |||
} | |||
@@ -374,9 +374,9 @@ func (u *User) GetFollowers(page int) ([]*User, error) { | |||
Limit(ItemsPerPage, (page-1)*ItemsPerPage). | |||
Where("follow.follow_id=?", u.ID) | |||
if setting.UsePostgreSQL { | |||
sess = sess.Join("LEFT", "follow", `"user".id=follow.user_id`) | |||
sess = sess.Join("LEFT", "follow", "`user`.id=follow.user_id") | |||
} else { | |||
sess = sess.Join("LEFT", "follow", "user.id=follow.user_id") | |||
sess = sess.Join("LEFT", "follow", "`user`.id=follow.user_id") | |||
} | |||
return users, sess.Find(&users) | |||
} | |||
@@ -393,9 +393,9 @@ func (u *User) GetFollowing(page int) ([]*User, error) { | |||
Limit(ItemsPerPage, (page-1)*ItemsPerPage). | |||
Where("follow.user_id=?", u.ID) | |||
if setting.UsePostgreSQL { | |||
sess = sess.Join("LEFT", "follow", `"user".id=follow.follow_id`) | |||
sess = sess.Join("LEFT", "follow", "`user`.id=follow.follow_id") | |||
} else { | |||
sess = sess.Join("LEFT", "follow", "user.id=follow.follow_id") | |||
sess = sess.Join("LEFT", "follow", "`user`.id=follow.follow_id") | |||
} | |||
return users, sess.Find(&users) | |||
} |
@@ -4,6 +4,10 @@ | |||
package builder | |||
import ( | |||
"fmt" | |||
) | |||
type optype byte | |||
const ( | |||
@@ -29,6 +33,9 @@ type Builder struct { | |||
joins []join | |||
inserts Eq | |||
updates []Eq | |||
orderBy string | |||
groupBy string | |||
having string | |||
} | |||
// Select creates a select Builder | |||
@@ -67,6 +74,11 @@ func (b *Builder) From(tableName string) *Builder { | |||
return b | |||
} | |||
// TableName returns the table name | |||
func (b *Builder) TableName() string { | |||
return b.tableName | |||
} | |||
// Into sets insert table name | |||
func (b *Builder) Into(tableName string) *Builder { | |||
b.tableName = tableName | |||
@@ -178,6 +190,33 @@ func (b *Builder) ToSQL() (string, []interface{}, error) { | |||
return w.writer.String(), w.args, nil | |||
} | |||
// ConvertPlaceholder replaces ? to $1, $2 ... or :1, :2 ... according prefix | |||
func ConvertPlaceholder(sql, prefix string) (string, error) { | |||
buf := StringBuilder{} | |||
var j, start = 0, 0 | |||
for i := 0; i < len(sql); i++ { | |||
if sql[i] == '?' { | |||
_, err := buf.WriteString(sql[start:i]) | |||
if err != nil { | |||
return "", err | |||
} | |||
start = i + 1 | |||
_, err = buf.WriteString(prefix) | |||
if err != nil { | |||
return "", err | |||
} | |||
j = j + 1 | |||
_, err = buf.WriteString(fmt.Sprintf("%d", j)) | |||
if err != nil { | |||
return "", err | |||
} | |||
} | |||
} | |||
return buf.String(), nil | |||
} | |||
// ToSQL convert a builder or condtions to SQL and args | |||
func ToSQL(cond interface{}) (string, []interface{}, error) { | |||
switch cond.(type) { |
@@ -15,7 +15,7 @@ func (b *Builder) insertWriteTo(w Writer) error { | |||
return errors.New("no table indicated") | |||
} | |||
if len(b.inserts) <= 0 { | |||
return errors.New("no column to be update") | |||
return errors.New("no column to be insert") | |||
} | |||
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil { | |||
@@ -26,7 +26,9 @@ func (b *Builder) insertWriteTo(w Writer) error { | |||
var bs []byte | |||
var valBuffer = bytes.NewBuffer(bs) | |||
var i = 0 | |||
for col, value := range b.inserts { | |||
for _, col := range b.inserts.sortedKeys() { | |||
value := b.inserts[col] | |||
fmt.Fprint(w, col) | |||
if e, ok := value.(expr); ok { | |||
fmt.Fprint(valBuffer, e.sql) |
@@ -34,24 +34,65 @@ func (b *Builder) selectWriteTo(w Writer) error { | |||
} | |||
} | |||
if _, err := fmt.Fprintf(w, " FROM %s", b.tableName); err != nil { | |||
if _, err := fmt.Fprint(w, " FROM ", b.tableName); err != nil { | |||
return err | |||
} | |||
for _, v := range b.joins { | |||
fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable) | |||
if _, err := fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable); err != nil { | |||
return err | |||
} | |||
if err := v.joinCond.WriteTo(w); err != nil { | |||
return err | |||
} | |||
} | |||
if !b.cond.IsValid() { | |||
return nil | |||
if b.cond.IsValid() { | |||
if _, err := fmt.Fprint(w, " WHERE "); err != nil { | |||
return err | |||
} | |||
if err := b.cond.WriteTo(w); err != nil { | |||
return err | |||
} | |||
} | |||
if _, err := fmt.Fprint(w, " WHERE "); err != nil { | |||
return err | |||
if len(b.groupBy) > 0 { | |||
if _, err := fmt.Fprint(w, " GROUP BY ", b.groupBy); err != nil { | |||
return err | |||
} | |||
} | |||
return b.cond.WriteTo(w) | |||
if len(b.having) > 0 { | |||
if _, err := fmt.Fprint(w, " HAVING ", b.having); err != nil { | |||
return err | |||
} | |||
} | |||
if len(b.orderBy) > 0 { | |||
if _, err := fmt.Fprint(w, " ORDER BY ", b.orderBy); err != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
} | |||
// OrderBy orderBy SQL | |||
func (b *Builder) OrderBy(orderBy string) *Builder { | |||
b.orderBy = orderBy | |||
return b | |||
} | |||
// GroupBy groupby SQL | |||
func (b *Builder) GroupBy(groupby string) *Builder { | |||
b.groupBy = groupby | |||
return b | |||
} | |||
// Having having SQL | |||
func (b *Builder) Having(having string) *Builder { | |||
b.having = having | |||
return b | |||
} |
@@ -5,7 +5,6 @@ | |||
package builder | |||
import ( | |||
"bytes" | |||
"io" | |||
) | |||
@@ -19,15 +18,15 @@ var _ Writer = NewWriter() | |||
// BytesWriter implments Writer and save SQL in bytes.Buffer | |||
type BytesWriter struct { | |||
writer *bytes.Buffer | |||
buffer []byte | |||
writer *StringBuilder | |||
args []interface{} | |||
} | |||
// NewWriter creates a new string writer | |||
func NewWriter() *BytesWriter { | |||
w := &BytesWriter{} | |||
w.writer = bytes.NewBuffer(w.buffer) | |||
w := &BytesWriter{ | |||
writer: &StringBuilder{}, | |||
} | |||
return w | |||
} | |||
@@ -10,7 +10,13 @@ import "fmt" | |||
func WriteMap(w Writer, data map[string]interface{}, op string) error { | |||
var args = make([]interface{}, 0, len(data)) | |||
var i = 0 | |||
for k, v := range data { | |||
keys := make([]string, 0, len(data)) | |||
for k := range data { | |||
keys = append(keys, k) | |||
} | |||
for _, k := range keys { | |||
v := data[k] | |||
switch v.(type) { | |||
case expr: | |||
if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil { |
@@ -4,7 +4,10 @@ | |||
package builder | |||
import "fmt" | |||
import ( | |||
"fmt" | |||
"sort" | |||
) | |||
// Incr implements a type used by Eq | |||
type Incr int | |||
@@ -19,7 +22,8 @@ var _ Cond = Eq{} | |||
func (eq Eq) opWriteTo(op string, w Writer) error { | |||
var i = 0 | |||
for k, v := range eq { | |||
for _, k := range eq.sortedKeys() { | |||
v := eq[k] | |||
switch v.(type) { | |||
case []int, []int64, []string, []int32, []int16, []int8, []uint, []uint64, []uint32, []uint16, []interface{}: | |||
if err := In(k, v).WriteTo(w); err != nil { | |||
@@ -94,3 +98,15 @@ func (eq Eq) Or(conds ...Cond) Cond { | |||
func (eq Eq) IsValid() bool { | |||
return len(eq) > 0 | |||
} | |||
// sortedKeys returns all keys of this Eq sorted with sort.Strings. | |||
// It is used internally for consistent ordering when generating | |||
// SQL, see https://github.com/go-xorm/builder/issues/10 | |||
func (eq Eq) sortedKeys() []string { | |||
keys := make([]string, 0, len(eq)) | |||
for key := range eq { | |||
keys = append(keys, key) | |||
} | |||
sort.Strings(keys) | |||
return keys | |||
} |
@@ -16,7 +16,7 @@ func (like Like) WriteTo(w Writer) error { | |||
if _, err := fmt.Fprintf(w, "%s LIKE ?", like[0]); err != nil { | |||
return err | |||
} | |||
// FIXME: if use other regular express, this will be failed. but for compitable, keep this | |||
// FIXME: if use other regular express, this will be failed. but for compatible, keep this | |||
if like[1][0] == '%' || like[1][len(like[1])-1] == '%' { | |||
w.Append(like[1]) | |||
} else { |
@@ -4,7 +4,10 @@ | |||
package builder | |||
import "fmt" | |||
import ( | |||
"fmt" | |||
"sort" | |||
) | |||
// Neq defines not equal conditions | |||
type Neq map[string]interface{} | |||
@@ -15,7 +18,8 @@ var _ Cond = Neq{} | |||
func (neq Neq) WriteTo(w Writer) error { | |||
var args = make([]interface{}, 0, len(neq)) | |||
var i = 0 | |||
for k, v := range neq { | |||
for _, k := range neq.sortedKeys() { | |||
v := neq[k] | |||
switch v.(type) { | |||
case []int, []int64, []string, []int32, []int16, []int8: | |||
if err := NotIn(k, v).WriteTo(w); err != nil { | |||
@@ -76,3 +80,15 @@ func (neq Neq) Or(conds ...Cond) Cond { | |||
func (neq Neq) IsValid() bool { | |||
return len(neq) > 0 | |||
} | |||
// sortedKeys returns all keys of this Neq sorted with sort.Strings. | |||
// It is used internally for consistent ordering when generating | |||
// SQL, see https://github.com/go-xorm/builder/issues/10 | |||
func (neq Neq) sortedKeys() []string { | |||
keys := make([]string, 0, len(neq)) | |||
for key := range neq { | |||
keys = append(keys, key) | |||
} | |||
sort.Strings(keys) | |||
return keys | |||
} |
@@ -21,6 +21,18 @@ func (not Not) WriteTo(w Writer) error { | |||
if _, err := fmt.Fprint(w, "("); err != nil { | |||
return err | |||
} | |||
case Eq: | |||
if len(not[0].(Eq)) > 1 { | |||
if _, err := fmt.Fprint(w, "("); err != nil { | |||
return err | |||
} | |||
} | |||
case Neq: | |||
if len(not[0].(Neq)) > 1 { | |||
if _, err := fmt.Fprint(w, "("); err != nil { | |||
return err | |||
} | |||
} | |||
} | |||
if err := not[0].WriteTo(w); err != nil { | |||
@@ -32,6 +44,18 @@ func (not Not) WriteTo(w Writer) error { | |||
if _, err := fmt.Fprint(w, ")"); err != nil { | |||
return err | |||
} | |||
case Eq: | |||
if len(not[0].(Eq)) > 1 { | |||
if _, err := fmt.Fprint(w, ")"); err != nil { | |||
return err | |||
} | |||
} | |||
case Neq: | |||
if len(not[0].(Neq)) > 1 { | |||
if _, err := fmt.Fprint(w, ")"); err != nil { | |||
return err | |||
} | |||
} | |||
} | |||
return nil |
@@ -0,0 +1,119 @@ | |||
// Copyright 2017 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package builder | |||
import ( | |||
"unicode/utf8" | |||
"unsafe" | |||
) | |||
// A StringBuilder is used to efficiently build a string using Write methods. | |||
// It minimizes memory copying. The zero value is ready to use. | |||
// Do not copy a non-zero Builder. | |||
type StringBuilder struct { | |||
addr *StringBuilder // of receiver, to detect copies by value | |||
buf []byte | |||
} | |||
// noescape hides a pointer from escape analysis. noescape is | |||
// the identity function but escape analysis doesn't think the | |||
// output depends on the input. noescape is inlined and currently | |||
// compiles down to zero instructions. | |||
// USE CAREFULLY! | |||
// This was copied from the runtime; see issues 23382 and 7921. | |||
//go:nosplit | |||
func noescape(p unsafe.Pointer) unsafe.Pointer { | |||
x := uintptr(p) | |||
return unsafe.Pointer(x ^ 0) | |||
} | |||
func (b *StringBuilder) copyCheck() { | |||
if b.addr == nil { | |||
// This hack works around a failing of Go's escape analysis | |||
// that was causing b to escape and be heap allocated. | |||
// See issue 23382. | |||
// TODO: once issue 7921 is fixed, this should be reverted to | |||
// just "b.addr = b". | |||
b.addr = (*StringBuilder)(noescape(unsafe.Pointer(b))) | |||
} else if b.addr != b { | |||
panic("strings: illegal use of non-zero Builder copied by value") | |||
} | |||
} | |||
// String returns the accumulated string. | |||
func (b *StringBuilder) String() string { | |||
return *(*string)(unsafe.Pointer(&b.buf)) | |||
} | |||
// Len returns the number of accumulated bytes; b.Len() == len(b.String()). | |||
func (b *StringBuilder) Len() int { return len(b.buf) } | |||
// Reset resets the Builder to be empty. | |||
func (b *StringBuilder) Reset() { | |||
b.addr = nil | |||
b.buf = nil | |||
} | |||
// grow copies the buffer to a new, larger buffer so that there are at least n | |||
// bytes of capacity beyond len(b.buf). | |||
func (b *StringBuilder) grow(n int) { | |||
buf := make([]byte, len(b.buf), 2*cap(b.buf)+n) | |||
copy(buf, b.buf) | |||
b.buf = buf | |||
} | |||
// Grow grows b's capacity, if necessary, to guarantee space for | |||
// another n bytes. After Grow(n), at least n bytes can be written to b | |||
// without another allocation. If n is negative, Grow panics. | |||
func (b *StringBuilder) Grow(n int) { | |||
b.copyCheck() | |||
if n < 0 { | |||
panic("strings.Builder.Grow: negative count") | |||
} | |||
if cap(b.buf)-len(b.buf) < n { | |||
b.grow(n) | |||
} | |||
} | |||
// Write appends the contents of p to b's buffer. | |||
// Write always returns len(p), nil. | |||
func (b *StringBuilder) Write(p []byte) (int, error) { | |||
b.copyCheck() | |||
b.buf = append(b.buf, p...) | |||
return len(p), nil | |||
} | |||
// WriteByte appends the byte c to b's buffer. | |||
// The returned error is always nil. | |||
func (b *StringBuilder) WriteByte(c byte) error { | |||
b.copyCheck() | |||
b.buf = append(b.buf, c) | |||
return nil | |||
} | |||
// WriteRune appends the UTF-8 encoding of Unicode code point r to b's buffer. | |||
// It returns the length of r and a nil error. | |||
func (b *StringBuilder) WriteRune(r rune) (int, error) { | |||
b.copyCheck() | |||
if r < utf8.RuneSelf { | |||
b.buf = append(b.buf, byte(r)) | |||
return 1, nil | |||
} | |||
l := len(b.buf) | |||
if cap(b.buf)-l < utf8.UTFMax { | |||
b.grow(utf8.UTFMax) | |||
} | |||
n := utf8.EncodeRune(b.buf[l:l+utf8.UTFMax], r) | |||
b.buf = b.buf[:l+n] | |||
return n, nil | |||
} | |||
// WriteString appends the contents of s to b's buffer. | |||
// It returns the length of s and a nil error. | |||
func (b *StringBuilder) WriteString(s string) (int, error) { | |||
b.copyCheck() | |||
b.buf = append(b.buf, s...) | |||
return len(s), nil | |||
} |
@@ -147,12 +147,12 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { | |||
} | |||
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) | |||
} else { | |||
return nil, fmt.Errorf("field %v is not valid", col.FieldName) | |||
return nil, fmt.Errorf("field %v is not valid", col.FieldName) | |||
} | |||
} | |||
if !fieldValue.IsValid() { | |||
return nil, fmt.Errorf("field %v is not valid", col.FieldName) | |||
return nil, fmt.Errorf("field %v is not valid", col.FieldName) | |||
} | |||
return &fieldValue, nil |
@@ -7,6 +7,11 @@ import ( | |||
"fmt" | |||
"reflect" | |||
"regexp" | |||
"sync" | |||
) | |||
var ( | |||
DefaultCacheSize = 200 | |||
) | |||
func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { | |||
@@ -58,9 +63,16 @@ func StructToSlice(query string, st interface{}) (string, []interface{}, error) | |||
return query, args, nil | |||
} | |||
type cacheStruct struct { | |||
value reflect.Value | |||
idx int | |||
} | |||
type DB struct { | |||
*sql.DB | |||
Mapper IMapper | |||
Mapper IMapper | |||
reflectCache map[reflect.Type]*cacheStruct | |||
reflectCacheMutex sync.RWMutex | |||
} | |||
func Open(driverName, dataSourceName string) (*DB, error) { | |||
@@ -68,11 +80,32 @@ func Open(driverName, dataSourceName string) (*DB, error) { | |||
if err != nil { | |||
return nil, err | |||
} | |||
return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil | |||
return &DB{ | |||
DB: db, | |||
Mapper: NewCacheMapper(&SnakeMapper{}), | |||
reflectCache: make(map[reflect.Type]*cacheStruct), | |||
}, nil | |||
} | |||
func FromDB(db *sql.DB) *DB { | |||
return &DB{db, NewCacheMapper(&SnakeMapper{})} | |||
return &DB{ | |||
DB: db, | |||
Mapper: NewCacheMapper(&SnakeMapper{}), | |||
reflectCache: make(map[reflect.Type]*cacheStruct), | |||
} | |||
} | |||
func (db *DB) reflectNew(typ reflect.Type) reflect.Value { | |||
db.reflectCacheMutex.Lock() | |||
defer db.reflectCacheMutex.Unlock() | |||
cs, ok := db.reflectCache[typ] | |||
if !ok || cs.idx+1 > DefaultCacheSize-1 { | |||
cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0} | |||
db.reflectCache[typ] = cs | |||
} else { | |||
cs.idx = cs.idx + 1 | |||
} | |||
return cs.value.Index(cs.idx).Addr() | |||
} | |||
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { | |||
@@ -83,7 +116,7 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { | |||
} | |||
return nil, err | |||
} | |||
return &Rows{rows, db.Mapper}, nil | |||
return &Rows{rows, db}, nil | |||
} | |||
func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { | |||
@@ -128,8 +161,8 @@ func (db *DB) QueryRowStruct(query string, st interface{}) *Row { | |||
type Stmt struct { | |||
*sql.Stmt | |||
Mapper IMapper | |||
names map[string]int | |||
db *DB | |||
names map[string]int | |||
} | |||
func (db *DB) Prepare(query string) (*Stmt, error) { | |||
@@ -145,7 +178,7 @@ func (db *DB) Prepare(query string) (*Stmt, error) { | |||
if err != nil { | |||
return nil, err | |||
} | |||
return &Stmt{stmt, db.Mapper, names}, nil | |||
return &Stmt{stmt, db, names}, nil | |||
} | |||
func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { | |||
@@ -179,7 +212,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { | |||
if err != nil { | |||
return nil, err | |||
} | |||
return &Rows{rows, s.Mapper}, nil | |||
return &Rows{rows, s.db}, nil | |||
} | |||
func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { | |||
@@ -274,7 +307,7 @@ func (EmptyScanner) Scan(src interface{}) error { | |||
type Tx struct { | |||
*sql.Tx | |||
Mapper IMapper | |||
db *DB | |||
} | |||
func (db *DB) Begin() (*Tx, error) { | |||
@@ -282,7 +315,7 @@ func (db *DB) Begin() (*Tx, error) { | |||
if err != nil { | |||
return nil, err | |||
} | |||
return &Tx{tx, db.Mapper}, nil | |||
return &Tx{tx, db}, nil | |||
} | |||
func (tx *Tx) Prepare(query string) (*Stmt, error) { | |||
@@ -298,7 +331,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { | |||
if err != nil { | |||
return nil, err | |||
} | |||
return &Stmt{stmt, tx.Mapper, names}, nil | |||
return &Stmt{stmt, tx.db, names}, nil | |||
} | |||
func (tx *Tx) Stmt(stmt *Stmt) *Stmt { | |||
@@ -327,7 +360,7 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { | |||
if err != nil { | |||
return nil, err | |||
} | |||
return &Rows{rows, tx.Mapper}, nil | |||
return &Rows{rows, tx.db}, nil | |||
} | |||
func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { |
@@ -74,6 +74,7 @@ type Dialect interface { | |||
GetIndexes(tableName string) (map[string]*Index, error) | |||
Filters() []Filter | |||
SetParams(params map[string]string) | |||
} | |||
func OpenDialect(dialect Dialect) (*DB, error) { | |||
@@ -148,7 +149,8 @@ func (db *Base) SupportDropIfExists() bool { | |||
} | |||
func (db *Base) DropTableSql(tableName string) string { | |||
return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName) | |||
quote := db.dialect.Quote | |||
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)) | |||
} | |||
func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) { | |||
@@ -289,6 +291,9 @@ func (b *Base) LogSQL(sql string, args []interface{}) { | |||
} | |||
} | |||
func (b *Base) SetParams(params map[string]string) { | |||
} | |||
var ( | |||
dialects = map[string]func() Dialect{} | |||
) |
@@ -37,9 +37,9 @@ func (q *Quoter) Quote(content string) string { | |||
func (i *IdFilter) Do(sql string, dialect Dialect, table *Table) string { | |||
quoter := NewQuoter(dialect) | |||
if table != nil && len(table.PrimaryKeys) == 1 { | |||
sql = strings.Replace(sql, "`(id)`", quoter.Quote(table.PrimaryKeys[0]), -1) | |||
sql = strings.Replace(sql, quoter.Quote("(id)"), quoter.Quote(table.PrimaryKeys[0]), -1) | |||
return strings.Replace(sql, "(id)", quoter.Quote(table.PrimaryKeys[0]), -1) | |||
sql = strings.Replace(sql, " `(id)` ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) | |||
sql = strings.Replace(sql, " "+quoter.Quote("(id)")+" ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) | |||
return strings.Replace(sql, " (id) ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) | |||
} | |||
return sql | |||
} |
@@ -22,6 +22,8 @@ type Index struct { | |||
func (index *Index) XName(tableName string) string { | |||
if !strings.HasPrefix(index.Name, "UQE_") && | |||
!strings.HasPrefix(index.Name, "IDX_") { | |||
tableName = strings.Replace(tableName, `"`, "", -1) | |||
tableName = strings.Replace(tableName, `.`, "_", -1) | |||
if index.Type == UniqueType { | |||
return fmt.Sprintf("UQE_%v_%v", tableName, index.Name) | |||
} |
@@ -9,7 +9,7 @@ import ( | |||
type Rows struct { | |||
*sql.Rows | |||
Mapper IMapper | |||
db *DB | |||
} | |||
func (rs *Rows) ToMapString() ([]map[string]string, error) { | |||
@@ -105,7 +105,7 @@ func (rs *Rows) ScanStructByName(dest interface{}) error { | |||
newDest := make([]interface{}, len(cols)) | |||
var v EmptyScanner | |||
for j, name := range cols { | |||
f := fieldByName(vv.Elem(), rs.Mapper.Table2Obj(name)) | |||
f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name)) | |||
if f.IsValid() { | |||
newDest[j] = f.Addr().Interface() | |||
} else { | |||
@@ -116,36 +116,6 @@ func (rs *Rows) ScanStructByName(dest interface{}) error { | |||
return rs.Rows.Scan(newDest...) | |||
} | |||
type cacheStruct struct { | |||
value reflect.Value | |||
idx int | |||
} | |||
var ( | |||
reflectCache = make(map[reflect.Type]*cacheStruct) | |||
reflectCacheMutex sync.RWMutex | |||
) | |||
func ReflectNew(typ reflect.Type) reflect.Value { | |||
reflectCacheMutex.RLock() | |||
cs, ok := reflectCache[typ] | |||
reflectCacheMutex.RUnlock() | |||
const newSize = 200 | |||
if !ok || cs.idx+1 > newSize-1 { | |||
cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), newSize, newSize), 0} | |||
reflectCacheMutex.Lock() | |||
reflectCache[typ] = cs | |||
reflectCacheMutex.Unlock() | |||
} else { | |||
reflectCacheMutex.Lock() | |||
cs.idx = cs.idx + 1 | |||
reflectCacheMutex.Unlock() | |||
} | |||
return cs.value.Index(cs.idx).Addr() | |||
} | |||
// scan data to a slice's pointer, slice's length should equal to columns' number | |||
func (rs *Rows) ScanSlice(dest interface{}) error { | |||
vv := reflect.ValueOf(dest) | |||
@@ -197,9 +167,7 @@ func (rs *Rows) ScanMap(dest interface{}) error { | |||
vvv := vv.Elem() | |||
for i, _ := range cols { | |||
newDest[i] = ReflectNew(vvv.Type().Elem()).Interface() | |||
//v := reflect.New(vvv.Type().Elem()) | |||
//newDest[i] = v.Interface() | |||
newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface() | |||
} | |||
err = rs.Rows.Scan(newDest...) | |||
@@ -215,32 +183,6 @@ func (rs *Rows) ScanMap(dest interface{}) error { | |||
return nil | |||
} | |||
/*func (rs *Rows) ScanMap(dest interface{}) error { | |||
vv := reflect.ValueOf(dest) | |||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { | |||
return errors.New("dest should be a map's pointer") | |||
} | |||
cols, err := rs.Columns() | |||
if err != nil { | |||
return err | |||
} | |||
newDest := make([]interface{}, len(cols)) | |||
err = rs.ScanSlice(newDest) | |||
if err != nil { | |||
return err | |||
} | |||
vvv := vv.Elem() | |||
for i, name := range cols { | |||
vname := reflect.ValueOf(name) | |||
vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem()) | |||
} | |||
return nil | |||
}*/ | |||
type Row struct { | |||
rows *Rows | |||
// One of these two will be non-nil: |
@@ -49,7 +49,6 @@ func NewTable(name string, t reflect.Type) *Table { | |||
} | |||
func (table *Table) columnsByName(name string) []*Column { | |||
n := len(name) | |||
for k := range table.columnsMap { | |||
@@ -75,7 +74,6 @@ func (table *Table) GetColumn(name string) *Column { | |||
} | |||
func (table *Table) GetColumnIdx(name string, idx int) *Column { | |||
cols := table.columnsByName(name) | |||
if cols != nil && idx < len(cols) { |
@@ -69,15 +69,18 @@ var ( | |||
Enum = "ENUM" | |||
Set = "SET" | |||
Char = "CHAR" | |||
Varchar = "VARCHAR" | |||
NVarchar = "NVARCHAR" | |||
TinyText = "TINYTEXT" | |||
Text = "TEXT" | |||
Clob = "CLOB" | |||
MediumText = "MEDIUMTEXT" | |||
LongText = "LONGTEXT" | |||
Uuid = "UUID" | |||
Char = "CHAR" | |||
Varchar = "VARCHAR" | |||
NVarchar = "NVARCHAR" | |||
TinyText = "TINYTEXT" | |||
Text = "TEXT" | |||
NText = "NTEXT" | |||
Clob = "CLOB" | |||
MediumText = "MEDIUMTEXT" | |||
LongText = "LONGTEXT" | |||
Uuid = "UUID" | |||
UniqueIdentifier = "UNIQUEIDENTIFIER" | |||
SysName = "SYSNAME" | |||
Date = "DATE" | |||
DateTime = "DATETIME" | |||
@@ -128,10 +131,12 @@ var ( | |||
NVarchar: TEXT_TYPE, | |||
TinyText: TEXT_TYPE, | |||
Text: TEXT_TYPE, | |||
NText: TEXT_TYPE, | |||
MediumText: TEXT_TYPE, | |||
LongText: TEXT_TYPE, | |||
Uuid: TEXT_TYPE, | |||
Clob: TEXT_TYPE, | |||
SysName: TEXT_TYPE, | |||
Date: TIME_TYPE, | |||
DateTime: TIME_TYPE, | |||
@@ -148,11 +153,12 @@ var ( | |||
Binary: BLOB_TYPE, | |||
VarBinary: BLOB_TYPE, | |||
TinyBlob: BLOB_TYPE, | |||
Blob: BLOB_TYPE, | |||
MediumBlob: BLOB_TYPE, | |||
LongBlob: BLOB_TYPE, | |||
Bytea: BLOB_TYPE, | |||
TinyBlob: BLOB_TYPE, | |||
Blob: BLOB_TYPE, | |||
MediumBlob: BLOB_TYPE, | |||
LongBlob: BLOB_TYPE, | |||
Bytea: BLOB_TYPE, | |||
UniqueIdentifier: BLOB_TYPE, | |||
Bool: NUMERIC_TYPE, | |||
@@ -289,9 +295,9 @@ func SQLType2Type(st SQLType) reflect.Type { | |||
return reflect.TypeOf(float32(1)) | |||
case Double: | |||
return reflect.TypeOf(float64(1)) | |||
case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid, Clob: | |||
case Char, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName: | |||
return reflect.TypeOf("") | |||
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary: | |||
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier: | |||
return reflect.TypeOf([]byte{}) | |||
case Bool: | |||
return reflect.TypeOf(true) |
@@ -172,12 +172,33 @@ type mysql struct { | |||
allowAllFiles bool | |||
allowOldPasswords bool | |||
clientFoundRows bool | |||
rowFormat string | |||
} | |||
func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { | |||
return db.Base.Init(d, db, uri, drivername, dataSourceName) | |||
} | |||
func (db *mysql) SetParams(params map[string]string) { | |||
rowFormat, ok := params["rowFormat"] | |||
if ok { | |||
var t = strings.ToUpper(rowFormat) | |||
switch t { | |||
case "COMPACT": | |||
fallthrough | |||
case "REDUNDANT": | |||
fallthrough | |||
case "DYNAMIC": | |||
fallthrough | |||
case "COMPRESSED": | |||
db.rowFormat = t | |||
break | |||
default: | |||
break | |||
} | |||
} | |||
} | |||
func (db *mysql) SqlType(c *core.Column) string { | |||
var res string | |||
switch t := c.SQLType.Name; t { | |||
@@ -487,6 +508,59 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { | |||
return indexes, nil | |||
} | |||
func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { | |||
var sql string | |||
sql = "CREATE TABLE IF NOT EXISTS " | |||
if tableName == "" { | |||
tableName = table.Name | |||
} | |||
sql += db.Quote(tableName) | |||
sql += " (" | |||
if len(table.ColumnsSeq()) > 0 { | |||
pkList := table.PrimaryKeys | |||
for _, colName := range table.ColumnsSeq() { | |||
col := table.GetColumn(colName) | |||
if col.IsPrimaryKey && len(pkList) == 1 { | |||
sql += col.String(db) | |||
} else { | |||
sql += col.StringNoPk(db) | |||
} | |||
sql = strings.TrimSpace(sql) | |||
if len(col.Comment) > 0 { | |||
sql += " COMMENT '" + col.Comment + "'" | |||
} | |||
sql += ", " | |||
} | |||
if len(pkList) > 1 { | |||
sql += "PRIMARY KEY ( " | |||
sql += db.Quote(strings.Join(pkList, db.Quote(","))) | |||
sql += " ), " | |||
} | |||
sql = sql[:len(sql)-2] | |||
} | |||
sql += ")" | |||
if storeEngine != "" { | |||
sql += " ENGINE=" + storeEngine | |||
} | |||
if len(charset) == 0 { | |||
charset = db.URI().Charset | |||
} else if len(charset) > 0 { | |||
sql += " DEFAULT CHARSET " + charset | |||
} | |||
if db.rowFormat != "" { | |||
sql += " ROW_FORMAT=" + db.rowFormat | |||
} | |||
return sql | |||
} | |||
func (db *mysql) Filters() []core.Filter { | |||
return []core.Filter{&core.IdFilter{}} | |||
} |
@@ -769,14 +769,21 @@ var ( | |||
DefaultPostgresSchema = "public" | |||
) | |||
const postgresPublicSchema = "public" | |||
type postgres struct { | |||
core.Base | |||
schema string | |||
} | |||
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { | |||
db.schema = DefaultPostgresSchema | |||
return db.Base.Init(d, db, uri, drivername, dataSourceName) | |||
err := db.Base.Init(d, db, uri, drivername, dataSourceName) | |||
if err != nil { | |||
return err | |||
} | |||
if db.Schema == "" { | |||
db.Schema = DefaultPostgresSchema | |||
} | |||
return nil | |||
} | |||
func (db *postgres) SqlType(c *core.Column) string { | |||
@@ -873,32 +880,42 @@ func (db *postgres) IndexOnTable() bool { | |||
} | |||
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { | |||
args := []interface{}{tableName, idxName} | |||
if len(db.Schema) == 0 { | |||
args := []interface{}{tableName, idxName} | |||
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args | |||
} | |||
args := []interface{}{db.Schema, tableName, idxName} | |||
return `SELECT indexname FROM pg_indexes ` + | |||
`WHERE tablename = ? AND indexname = ?`, args | |||
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args | |||
} | |||
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { | |||
args := []interface{}{tableName} | |||
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args | |||
} | |||
if len(db.Schema) == 0 { | |||
args := []interface{}{tableName} | |||
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args | |||
} | |||
/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { | |||
args := []interface{}{tableName, colName} | |||
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + | |||
" AND column_name = ?", args | |||
}*/ | |||
args := []interface{}{db.Schema, tableName} | |||
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args | |||
} | |||
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { | |||
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", | |||
tableName, col.Name, db.SqlType(col)) | |||
if len(db.Schema) == 0 { | |||
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", | |||
tableName, col.Name, db.SqlType(col)) | |||
} | |||
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", | |||
db.Schema, tableName, col.Name, db.SqlType(col)) | |||
} | |||
func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { | |||
//var unique string | |||
quote := db.Quote | |||
idxName := index.Name | |||
tableName = strings.Replace(tableName, `"`, "", -1) | |||
tableName = strings.Replace(tableName, `.`, "_", -1) | |||
if !strings.HasPrefix(idxName, "UQE_") && | |||
!strings.HasPrefix(idxName, "IDX_") { | |||
if index.Type == core.UniqueType { | |||
@@ -907,13 +924,21 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { | |||
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) | |||
} | |||
} | |||
if db.Uri.Schema != "" { | |||
idxName = db.Uri.Schema + "." + idxName | |||
} | |||
return fmt.Sprintf("DROP INDEX %v", quote(idxName)) | |||
} | |||
func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { | |||
args := []interface{}{tableName, colName} | |||
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + | |||
" AND column_name = $2" | |||
args := []interface{}{db.Schema, tableName, colName} | |||
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + | |||
" AND column_name = $3" | |||
if len(db.Schema) == 0 { | |||
args = []interface{}{tableName, colName} | |||
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + | |||
" AND column_name = $2" | |||
} | |||
db.LogSQL(query, args) | |||
rows, err := db.DB().Query(query, args...) | |||
@@ -926,8 +951,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { | |||
} | |||
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { | |||
// FIXME: the schema should be replaced by user custom's | |||
args := []interface{}{tableName, db.schema} | |||
args := []interface{}{tableName} | |||
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , | |||
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, | |||
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey | |||
@@ -938,7 +962,15 @@ FROM pg_attribute f | |||
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) | |||
LEFT JOIN pg_class AS g ON p.confrelid = g.oid | |||
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name | |||
WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;` | |||
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` | |||
var f string | |||
if len(db.Schema) != 0 { | |||
args = append(args, db.Schema) | |||
f = " AND s.table_schema = $2" | |||
} | |||
s = fmt.Sprintf(s, f) | |||
db.LogSQL(s, args) | |||
rows, err := db.DB().Query(s, args...) | |||
@@ -1028,8 +1060,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att | |||
} | |||
func (db *postgres) GetTables() ([]*core.Table, error) { | |||
args := []interface{}{db.schema} | |||
s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1") | |||
args := []interface{}{} | |||
s := "SELECT tablename FROM pg_tables" | |||
if len(db.Schema) != 0 { | |||
args = append(args, db.Schema) | |||
s = s + " WHERE schemaname = $1" | |||
} | |||
db.LogSQL(s, args) | |||
rows, err := db.DB().Query(s, args...) | |||
@@ -1053,8 +1090,12 @@ func (db *postgres) GetTables() ([]*core.Table, error) { | |||
} | |||
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { | |||
args := []interface{}{db.schema, tableName} | |||
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2") | |||
args := []interface{}{tableName} | |||
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") | |||
if len(db.Schema) != 0 { | |||
args = append(args, db.Schema) | |||
s = s + " AND schemaname=$2" | |||
} | |||
db.LogSQL(s, args) | |||
rows, err := db.DB().Query(s, args...) | |||
@@ -1182,3 +1223,15 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { | |||
return db, nil | |||
} | |||
type pqDriverPgx struct { | |||
pqDriver | |||
} | |||
func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) { | |||
// Remove the leading characters for driver to work | |||
if len(dataSourceName) >= 9 && dataSourceName[0] == 0 { | |||
dataSourceName = dataSourceName[9:] | |||
} | |||
return pgx.pqDriver.Parse(driverName, dataSourceName) | |||
} |
@@ -49,6 +49,35 @@ type Engine struct { | |||
tagHandlers map[string]tagHandler | |||
engineGroup *EngineGroup | |||
cachers map[string]core.Cacher | |||
cacherLock sync.RWMutex | |||
} | |||
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { | |||
engine.cacherLock.Lock() | |||
engine.cachers[tableName] = cacher | |||
engine.cacherLock.Unlock() | |||
} | |||
func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) { | |||
engine.setCacher(tableName, cacher) | |||
} | |||
func (engine *Engine) getCacher(tableName string) core.Cacher { | |||
var cacher core.Cacher | |||
var ok bool | |||
engine.cacherLock.RLock() | |||
cacher, ok = engine.cachers[tableName] | |||
engine.cacherLock.RUnlock() | |||
if !ok && !engine.disableGlobalCache { | |||
cacher = engine.Cacher | |||
} | |||
return cacher | |||
} | |||
func (engine *Engine) GetCacher(tableName string) core.Cacher { | |||
return engine.getCacher(tableName) | |||
} | |||
// BufferSize sets buffer size for iterate | |||
@@ -165,7 +194,7 @@ func (engine *Engine) Quote(value string) string { | |||
} | |||
// QuoteTo quotes string and writes into the buffer | |||
func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) { | |||
func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) { | |||
if buf == nil { | |||
return | |||
} | |||
@@ -245,13 +274,7 @@ func (engine *Engine) NoCascade() *Session { | |||
// MapCacher Set a table use a special cacher | |||
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error { | |||
v := rValue(bean) | |||
tb, err := engine.autoMapType(v) | |||
if err != nil { | |||
return err | |||
} | |||
tb.Cacher = cacher | |||
engine.setCacher(engine.TableName(bean, true), cacher) | |||
return nil | |||
} | |||
@@ -536,33 +559,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D | |||
return nil | |||
} | |||
func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) { | |||
v := rValue(beanOrTableName) | |||
if v.Type().Kind() == reflect.String { | |||
return beanOrTableName.(string), nil | |||
} else if v.Type().Kind() == reflect.Struct { | |||
return engine.tbName(v), nil | |||
} | |||
return "", errors.New("bean should be a struct or struct's point") | |||
} | |||
func (engine *Engine) tbName(v reflect.Value) string { | |||
if tb, ok := v.Interface().(TableName); ok { | |||
return tb.TableName() | |||
} | |||
if v.Type().Kind() == reflect.Ptr { | |||
if tb, ok := reflect.Indirect(v).Interface().(TableName); ok { | |||
return tb.TableName() | |||
} | |||
} else if v.CanAddr() { | |||
if tb, ok := v.Addr().Interface().(TableName); ok { | |||
return tb.TableName() | |||
} | |||
} | |||
return engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name()) | |||
} | |||
// Cascade use cascade or not | |||
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { | |||
session := engine.NewSession() | |||
@@ -846,7 +842,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table { | |||
if err != nil { | |||
engine.logger.Error(err) | |||
} | |||
return &Table{tb, engine.tbName(v)} | |||
return &Table{tb, engine.TableName(bean)} | |||
} | |||
func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { | |||
@@ -861,15 +857,6 @@ func addIndex(indexName string, table *core.Table, col *core.Column, indexType i | |||
} | |||
} | |||
func (engine *Engine) newTable() *core.Table { | |||
table := core.NewEmptyTable() | |||
if !engine.disableGlobalCache { | |||
table.Cacher = engine.Cacher | |||
} | |||
return table | |||
} | |||
// TableName table name interface to define customerize table name | |||
type TableName interface { | |||
TableName() string | |||
@@ -881,21 +868,9 @@ var ( | |||
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { | |||
t := v.Type() | |||
table := engine.newTable() | |||
if tb, ok := v.Interface().(TableName); ok { | |||
table.Name = tb.TableName() | |||
} else { | |||
if v.CanAddr() { | |||
if tb, ok = v.Addr().Interface().(TableName); ok { | |||
table.Name = tb.TableName() | |||
} | |||
} | |||
if table.Name == "" { | |||
table.Name = engine.TableMapper.Obj2Table(t.Name()) | |||
} | |||
} | |||
table := core.NewEmptyTable() | |||
table.Type = t | |||
table.Name = engine.tbNameForMap(v) | |||
var idFieldColName string | |||
var hasCacheTag, hasNoCacheTag bool | |||
@@ -1049,15 +1024,15 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { | |||
if hasCacheTag { | |||
if engine.Cacher != nil { // !nash! use engine's cacher if provided | |||
engine.logger.Info("enable cache on table:", table.Name) | |||
table.Cacher = engine.Cacher | |||
engine.setCacher(table.Name, engine.Cacher) | |||
} else { | |||
engine.logger.Info("enable LRU cache on table:", table.Name) | |||
table.Cacher = NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) // !nashtsai! HACK use LRU cacher for now | |||
engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)) | |||
} | |||
} | |||
if hasNoCacheTag { | |||
engine.logger.Info("no cache on table:", table.Name) | |||
table.Cacher = nil | |||
engine.logger.Info("disable cache on table:", table.Name) | |||
engine.setCacher(table.Name, nil) | |||
} | |||
return table, nil | |||
@@ -1116,7 +1091,25 @@ func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) { | |||
pk := make([]interface{}, len(table.PrimaryKeys)) | |||
for i, col := range table.PKColumns() { | |||
var err error | |||
pkField := v.FieldByName(col.FieldName) | |||
fieldName := col.FieldName | |||
for { | |||
parts := strings.SplitN(fieldName, ".", 2) | |||
if len(parts) == 1 { | |||
break | |||
} | |||
v = v.FieldByName(parts[0]) | |||
if v.Kind() == reflect.Ptr { | |||
v = v.Elem() | |||
} | |||
if v.Kind() != reflect.Struct { | |||
return nil, ErrUnSupportedType | |||
} | |||
fieldName = parts[1] | |||
} | |||
pkField := v.FieldByName(fieldName) | |||
switch pkField.Kind() { | |||
case reflect.String: | |||
pk[i], err = engine.idTypeAssertion(col, pkField.String()) | |||
@@ -1162,26 +1155,10 @@ func (engine *Engine) CreateUniques(bean interface{}) error { | |||
return session.CreateUniques(bean) | |||
} | |||
func (engine *Engine) getCacher2(table *core.Table) core.Cacher { | |||
return table.Cacher | |||
} | |||
// ClearCacheBean if enabled cache, clear the cache bean | |||
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { | |||
v := rValue(bean) | |||
t := v.Type() | |||
if t.Kind() != reflect.Struct { | |||
return errors.New("error params") | |||
} | |||
tableName := engine.tbName(v) | |||
table, err := engine.autoMapType(v) | |||
if err != nil { | |||
return err | |||
} | |||
cacher := table.Cacher | |||
if cacher == nil { | |||
cacher = engine.Cacher | |||
} | |||
tableName := engine.TableName(bean) | |||
cacher := engine.getCacher(tableName) | |||
if cacher != nil { | |||
cacher.ClearIds(tableName) | |||
cacher.DelBean(tableName, id) | |||
@@ -1192,21 +1169,8 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { | |||
// ClearCache if enabled cache, clear some tables' cache | |||
func (engine *Engine) ClearCache(beans ...interface{}) error { | |||
for _, bean := range beans { | |||
v := rValue(bean) | |||
t := v.Type() | |||
if t.Kind() != reflect.Struct { | |||
return errors.New("error params") | |||
} | |||
tableName := engine.tbName(v) | |||
table, err := engine.autoMapType(v) | |||
if err != nil { | |||
return err | |||
} | |||
cacher := table.Cacher | |||
if cacher == nil { | |||
cacher = engine.Cacher | |||
} | |||
tableName := engine.TableName(bean) | |||
cacher := engine.getCacher(tableName) | |||
if cacher != nil { | |||
cacher.ClearIds(tableName) | |||
cacher.ClearBeans(tableName) | |||
@@ -1224,13 +1188,13 @@ func (engine *Engine) Sync(beans ...interface{}) error { | |||
for _, bean := range beans { | |||
v := rValue(bean) | |||
tableName := engine.tbName(v) | |||
tableNameNoSchema := engine.TableName(bean) | |||
table, err := engine.autoMapType(v) | |||
if err != nil { | |||
return err | |||
} | |||
isExist, err := session.Table(bean).isTableExist(tableName) | |||
isExist, err := session.Table(bean).isTableExist(tableNameNoSchema) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -1256,12 +1220,12 @@ func (engine *Engine) Sync(beans ...interface{}) error { | |||
} | |||
} else { | |||
for _, col := range table.Columns() { | |||
isExist, err := engine.dialect.IsColumnExist(tableName, col.Name) | |||
isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name) | |||
if err != nil { | |||
return err | |||
} | |||
if !isExist { | |||
if err := session.statement.setRefValue(v); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return err | |||
} | |||
err = session.addColumn(col.Name) | |||
@@ -1272,35 +1236,35 @@ func (engine *Engine) Sync(beans ...interface{}) error { | |||
} | |||
for name, index := range table.Indexes { | |||
if err := session.statement.setRefValue(v); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return err | |||
} | |||
if index.Type == core.UniqueType { | |||
isExist, err := session.isIndexExist2(tableName, index.Cols, true) | |||
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true) | |||
if err != nil { | |||
return err | |||
} | |||
if !isExist { | |||
if err := session.statement.setRefValue(v); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return err | |||
} | |||
err = session.addUnique(tableName, name) | |||
err = session.addUnique(tableNameNoSchema, name) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
} else if index.Type == core.IndexType { | |||
isExist, err := session.isIndexExist2(tableName, index.Cols, false) | |||
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false) | |||
if err != nil { | |||
return err | |||
} | |||
if !isExist { | |||
if err := session.statement.setRefValue(v); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return err | |||
} | |||
err = session.addIndex(tableName, name) | |||
err = session.addIndex(tableNameNoSchema, name) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -1453,6 +1417,13 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error { | |||
return session.Find(beans, condiBeans...) | |||
} | |||
// FindAndCount find the results and also return the counts | |||
func (engine *Engine) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) { | |||
session := engine.NewSession() | |||
defer session.Close() | |||
return session.FindAndCount(rowsSlicePtr, condiBean...) | |||
} | |||
// Iterate record by record handle records from table, bean's non-empty fields | |||
// are conditions. | |||
func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { | |||
@@ -1629,6 +1600,11 @@ func (engine *Engine) SetTZDatabase(tz *time.Location) { | |||
engine.DatabaseTZ = tz | |||
} | |||
// SetSchema sets the schema of database | |||
func (engine *Engine) SetSchema(schema string) { | |||
engine.dialect.URI().Schema = schema | |||
} | |||
// Unscoped always disable struct tag "deleted" | |||
func (engine *Engine) Unscoped() *Session { | |||
session := engine.NewSession() |
@@ -9,6 +9,7 @@ import ( | |||
"encoding/json" | |||
"fmt" | |||
"reflect" | |||
"strings" | |||
"time" | |||
"github.com/go-xorm/builder" | |||
@@ -51,7 +52,9 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{}, | |||
fieldValuePtr, err := col.ValueOf(bean) | |||
if err != nil { | |||
engine.logger.Error(err) | |||
if !strings.Contains(err.Error(), "is not valid") { | |||
engine.logger.Warn(err) | |||
} | |||
continue | |||
} | |||
@@ -0,0 +1,113 @@ | |||
// Copyright 2018 The Xorm Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package xorm | |||
import ( | |||
"fmt" | |||
"reflect" | |||
"strings" | |||
"github.com/go-xorm/core" | |||
) | |||
// TableNameWithSchema will automatically add schema prefix on table name | |||
func (engine *Engine) tbNameWithSchema(v string) string { | |||
// Add schema name as prefix of table name. | |||
// Only for postgres database. | |||
if engine.dialect.DBType() == core.POSTGRES && | |||
engine.dialect.URI().Schema != "" && | |||
engine.dialect.URI().Schema != postgresPublicSchema && | |||
strings.Index(v, ".") == -1 { | |||
return engine.dialect.URI().Schema + "." + v | |||
} | |||
return v | |||
} | |||
// TableName returns table name with schema prefix if has | |||
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { | |||
tbName := engine.tbNameNoSchema(bean) | |||
if len(includeSchema) > 0 && includeSchema[0] { | |||
tbName = engine.tbNameWithSchema(tbName) | |||
} | |||
return tbName | |||
} | |||
// tbName get some table's table name | |||
func (session *Session) tbNameNoSchema(table *core.Table) string { | |||
if len(session.statement.AltTableName) > 0 { | |||
return session.statement.AltTableName | |||
} | |||
return table.Name | |||
} | |||
func (engine *Engine) tbNameForMap(v reflect.Value) string { | |||
if v.Type().Implements(tpTableName) { | |||
return v.Interface().(TableName).TableName() | |||
} | |||
if v.Kind() == reflect.Ptr { | |||
v = v.Elem() | |||
if v.Type().Implements(tpTableName) { | |||
return v.Interface().(TableName).TableName() | |||
} | |||
} | |||
return engine.TableMapper.Obj2Table(v.Type().Name()) | |||
} | |||
func (engine *Engine) tbNameNoSchema(tablename interface{}) string { | |||
switch tablename.(type) { | |||
case []string: | |||
t := tablename.([]string) | |||
if len(t) > 1 { | |||
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) | |||
} else if len(t) == 1 { | |||
return engine.Quote(t[0]) | |||
} | |||
case []interface{}: | |||
t := tablename.([]interface{}) | |||
l := len(t) | |||
var table string | |||
if l > 0 { | |||
f := t[0] | |||
switch f.(type) { | |||
case string: | |||
table = f.(string) | |||
case TableName: | |||
table = f.(TableName).TableName() | |||
default: | |||
v := rValue(f) | |||
t := v.Type() | |||
if t.Kind() == reflect.Struct { | |||
table = engine.tbNameForMap(v) | |||
} else { | |||
table = engine.Quote(fmt.Sprintf("%v", f)) | |||
} | |||
} | |||
} | |||
if l > 1 { | |||
return fmt.Sprintf("%v AS %v", engine.Quote(table), | |||
engine.Quote(fmt.Sprintf("%v", t[1]))) | |||
} else if l == 1 { | |||
return engine.Quote(table) | |||
} | |||
case TableName: | |||
return tablename.(TableName).TableName() | |||
case string: | |||
return tablename.(string) | |||
case reflect.Value: | |||
v := tablename.(reflect.Value) | |||
return engine.tbNameForMap(v) | |||
default: | |||
v := rValue(tablename) | |||
t := v.Type() | |||
if t.Kind() == reflect.Struct { | |||
return engine.tbNameForMap(v) | |||
} | |||
return engine.Quote(fmt.Sprintf("%v", tablename)) | |||
} | |||
return "" | |||
} |
@@ -6,23 +6,44 @@ package xorm | |||
import ( | |||
"errors" | |||
"fmt" | |||
) | |||
var ( | |||
// ErrParamsType params error | |||
ErrParamsType = errors.New("Params type error") | |||
// ErrTableNotFound table not found error | |||
ErrTableNotFound = errors.New("Not found table") | |||
ErrTableNotFound = errors.New("Table not found") | |||
// ErrUnSupportedType unsupported error | |||
ErrUnSupportedType = errors.New("Unsupported type error") | |||
// ErrNotExist record is not exist error | |||
ErrNotExist = errors.New("Not exist error") | |||
// ErrNotExist record does not exist error | |||
ErrNotExist = errors.New("Record does not exist") | |||
// ErrCacheFailed cache failed error | |||
ErrCacheFailed = errors.New("Cache failed") | |||
// ErrNeedDeletedCond delete needs less one condition error | |||
ErrNeedDeletedCond = errors.New("Delete need at least one condition") | |||
ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") | |||
// ErrNotImplemented not implemented | |||
ErrNotImplemented = errors.New("Not implemented") | |||
// ErrConditionType condition type unsupported | |||
ErrConditionType = errors.New("Unsupported conditon type") | |||
ErrConditionType = errors.New("Unsupported condition type") | |||
) | |||
// ErrFieldIsNotExist columns does not exist | |||
type ErrFieldIsNotExist struct { | |||
FieldName string | |||
TableName string | |||
} | |||
func (e ErrFieldIsNotExist) Error() string { | |||
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) | |||
} | |||
// ErrFieldIsNotValid is not valid | |||
type ErrFieldIsNotValid struct { | |||
FieldName string | |||
TableName string | |||
} | |||
func (e ErrFieldIsNotValid) Error() string { | |||
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) | |||
} |
@@ -11,7 +11,6 @@ import ( | |||
"sort" | |||
"strconv" | |||
"strings" | |||
"time" | |||
"github.com/go-xorm/core" | |||
) | |||
@@ -293,19 +292,6 @@ func structName(v reflect.Type) string { | |||
return v.Name() | |||
} | |||
func col2NewCols(columns ...string) []string { | |||
newColumns := make([]string, 0, len(columns)) | |||
for _, col := range columns { | |||
col = strings.Replace(col, "`", "", -1) | |||
col = strings.Replace(col, `"`, "", -1) | |||
ccols := strings.Split(col, ",") | |||
for _, c := range ccols { | |||
newColumns = append(newColumns, strings.TrimSpace(c)) | |||
} | |||
} | |||
return newColumns | |||
} | |||
func sliceEq(left, right []string) bool { | |||
if len(left) != len(right) { | |||
return false | |||
@@ -320,154 +306,6 @@ func sliceEq(left, right []string) bool { | |||
return true | |||
} | |||
func setColumnInt(bean interface{}, col *core.Column, t int64) { | |||
v, err := col.ValueOf(bean) | |||
if err != nil { | |||
return | |||
} | |||
if v.CanSet() { | |||
switch v.Type().Kind() { | |||
case reflect.Int, reflect.Int64, reflect.Int32: | |||
v.SetInt(t) | |||
case reflect.Uint, reflect.Uint64, reflect.Uint32: | |||
v.SetUint(uint64(t)) | |||
} | |||
} | |||
} | |||
func setColumnTime(bean interface{}, col *core.Column, t time.Time) { | |||
v, err := col.ValueOf(bean) | |||
if err != nil { | |||
return | |||
} | |||
if v.CanSet() { | |||
switch v.Type().Kind() { | |||
case reflect.Struct: | |||
v.Set(reflect.ValueOf(t).Convert(v.Type())) | |||
case reflect.Int, reflect.Int64, reflect.Int32: | |||
v.SetInt(t.Unix()) | |||
case reflect.Uint, reflect.Uint64, reflect.Uint32: | |||
v.SetUint(uint64(t.Unix())) | |||
} | |||
} | |||
} | |||
func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { | |||
colNames := make([]string, 0, len(table.ColumnsSeq())) | |||
args := make([]interface{}, 0, len(table.ColumnsSeq())) | |||
for _, col := range table.Columns() { | |||
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { | |||
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { | |||
continue | |||
} | |||
} | |||
if col.MapType == core.ONLYFROMDB { | |||
continue | |||
} | |||
fieldValuePtr, err := col.ValueOf(bean) | |||
if err != nil { | |||
return nil, nil, err | |||
} | |||
fieldValue := *fieldValuePtr | |||
if col.IsAutoIncrement { | |||
switch fieldValue.Type().Kind() { | |||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: | |||
if fieldValue.Int() == 0 { | |||
continue | |||
} | |||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: | |||
if fieldValue.Uint() == 0 { | |||
continue | |||
} | |||
case reflect.String: | |||
if len(fieldValue.String()) == 0 { | |||
continue | |||
} | |||
case reflect.Ptr: | |||
if fieldValue.Pointer() == 0 { | |||
continue | |||
} | |||
} | |||
} | |||
if col.IsDeleted { | |||
continue | |||
} | |||
if session.statement.ColumnStr != "" { | |||
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { | |||
continue | |||
} else if _, ok := session.statement.incrColumns[col.Name]; ok { | |||
continue | |||
} else if _, ok := session.statement.decrColumns[col.Name]; ok { | |||
continue | |||
} | |||
} | |||
if session.statement.OmitStr != "" { | |||
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { | |||
continue | |||
} | |||
} | |||
// !evalphobia! set fieldValue as nil when column is nullable and zero-value | |||
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { | |||
if col.Nullable && isZero(fieldValue.Interface()) { | |||
var nilValue *int | |||
fieldValue = reflect.ValueOf(nilValue) | |||
} | |||
} | |||
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { | |||
// if time is non-empty, then set to auto time | |||
val, t := session.engine.nowTime(col) | |||
args = append(args, val) | |||
var colName = col.Name | |||
session.afterClosures = append(session.afterClosures, func(bean interface{}) { | |||
col := table.GetColumn(colName) | |||
setColumnTime(bean, col, t) | |||
}) | |||
} else if col.IsVersion && session.statement.checkVersion { | |||
args = append(args, 1) | |||
} else { | |||
arg, err := session.value2Interface(col, fieldValue) | |||
if err != nil { | |||
return colNames, args, err | |||
} | |||
args = append(args, arg) | |||
} | |||
if includeQuote { | |||
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?") | |||
} else { | |||
colNames = append(colNames, col.Name) | |||
} | |||
} | |||
return colNames, args, nil | |||
} | |||
func indexName(tableName, idxName string) string { | |||
return fmt.Sprintf("IDX_%v_%v", tableName, idxName) | |||
} | |||
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { | |||
if len(m) == 0 { | |||
return false, false | |||
} | |||
n := len(col.Name) | |||
for mk := range m { | |||
if len(mk) != n { | |||
continue | |||
} | |||
if strings.EqualFold(mk, col.Name) { | |||
return m[mk], true | |||
} | |||
} | |||
return false, false | |||
} |
@@ -30,6 +30,7 @@ type Interface interface { | |||
Exec(string, ...interface{}) (sql.Result, error) | |||
Exist(bean ...interface{}) (bool, error) | |||
Find(interface{}, ...interface{}) error | |||
FindAndCount(interface{}, ...interface{}) (int64, error) | |||
Get(interface{}) (bool, error) | |||
GroupBy(keys string) *Session | |||
ID(interface{}) *Session | |||
@@ -41,6 +42,7 @@ type Interface interface { | |||
IsTableExist(beanOrTableName interface{}) (bool, error) | |||
Iterate(interface{}, IterFunc) error | |||
Limit(int, ...int) *Session | |||
MustCols(columns ...string) *Session | |||
NoAutoCondition(...bool) *Session | |||
NotIn(string, ...interface{}) *Session | |||
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session | |||
@@ -75,6 +77,7 @@ type EngineInterface interface { | |||
Dialect() core.Dialect | |||
DropTables(...interface{}) error | |||
DumpAllToFile(fp string, tp ...core.DbType) error | |||
GetCacher(string) core.Cacher | |||
GetColumnMapper() core.IMapper | |||
GetDefaultCacher() core.Cacher | |||
GetTableMapper() core.IMapper | |||
@@ -83,9 +86,11 @@ type EngineInterface interface { | |||
NewSession() *Session | |||
NoAutoTime() *Session | |||
Quote(string) string | |||
SetCacher(string, core.Cacher) | |||
SetDefaultCacher(core.Cacher) | |||
SetLogLevel(core.LogLevel) | |||
SetMapper(core.IMapper) | |||
SetSchema(string) | |||
SetTZDatabase(tz *time.Location) | |||
SetTZLocation(tz *time.Location) | |||
ShowSQL(show ...bool) | |||
@@ -93,6 +98,7 @@ type EngineInterface interface { | |||
Sync2(...interface{}) error | |||
StoreEngine(storeEngine string) *Session | |||
TableInfo(bean interface{}) *Table | |||
TableName(interface{}, ...bool) string | |||
UnMapType(reflect.Type) | |||
} | |||
@@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { | |||
var args []interface{} | |||
var err error | |||
if err = rows.session.statement.setRefValue(rValue(bean)); err != nil { | |||
if err = rows.session.statement.setRefBean(bean); err != nil { | |||
return nil, err | |||
} | |||
@@ -94,8 +94,7 @@ func (rows *Rows) Scan(bean interface{}) error { | |||
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) | |||
} | |||
dataStruct := rValue(bean) | |||
if err := rows.session.statement.setRefValue(dataStruct); err != nil { | |||
if err := rows.session.statement.setRefBean(bean); err != nil { | |||
return err | |||
} | |||
@@ -104,6 +103,7 @@ func (rows *Rows) Scan(bean interface{}) error { | |||
return err | |||
} | |||
dataStruct := rValue(bean) | |||
_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) | |||
if err != nil { | |||
return err |
@@ -278,24 +278,22 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, | |||
return | |||
} | |||
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value { | |||
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) { | |||
var col *core.Column | |||
if col = table.GetColumnIdx(key, idx); col == nil { | |||
//session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq()) | |||
return nil | |||
return nil, ErrFieldIsNotExist{key, table.Name} | |||
} | |||
fieldValue, err := col.ValueOfV(dataStruct) | |||
if err != nil { | |||
session.engine.logger.Error(err) | |||
return nil | |||
return nil, err | |||
} | |||
if !fieldValue.IsValid() || !fieldValue.CanSet() { | |||
session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key) | |||
return nil | |||
return nil, ErrFieldIsNotValid{key, table.Name} | |||
} | |||
return fieldValue | |||
return fieldValue, nil | |||
} | |||
// Cell cell is a result of one column field | |||
@@ -407,409 +405,417 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b | |||
} | |||
tempMap[lKey] = idx | |||
if fieldValue := session.getField(dataStruct, key, table, idx); fieldValue != nil { | |||
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) | |||
// if row is null then ignore | |||
if rawValue.Interface() == nil { | |||
continue | |||
fieldValue, err := session.getField(dataStruct, key, table, idx) | |||
if err != nil { | |||
if !strings.Contains(err.Error(), "is not valid") { | |||
session.engine.logger.Warn(err) | |||
} | |||
continue | |||
} | |||
if fieldValue == nil { | |||
continue | |||
} | |||
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) | |||
if fieldValue.CanAddr() { | |||
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { | |||
if data, err := value2Bytes(&rawValue); err == nil { | |||
if err := structConvert.FromDB(data); err != nil { | |||
return nil, err | |||
} | |||
} else { | |||
return nil, err | |||
} | |||
continue | |||
} | |||
} | |||
// if row is null then ignore | |||
if rawValue.Interface() == nil { | |||
continue | |||
} | |||
if _, ok := fieldValue.Interface().(core.Conversion); ok { | |||
if fieldValue.CanAddr() { | |||
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { | |||
if data, err := value2Bytes(&rawValue); err == nil { | |||
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { | |||
fieldValue.Set(reflect.New(fieldValue.Type().Elem())) | |||
if err := structConvert.FromDB(data); err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Interface().(core.Conversion).FromDB(data) | |||
} else { | |||
return nil, err | |||
} | |||
continue | |||
} | |||
} | |||
rawValueType := reflect.TypeOf(rawValue.Interface()) | |||
vv := reflect.ValueOf(rawValue.Interface()) | |||
col := table.GetColumnIdx(key, idx) | |||
if col.IsPrimaryKey { | |||
pk = append(pk, rawValue.Interface()) | |||
if _, ok := fieldValue.Interface().(core.Conversion); ok { | |||
if data, err := value2Bytes(&rawValue); err == nil { | |||
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { | |||
fieldValue.Set(reflect.New(fieldValue.Type().Elem())) | |||
} | |||
fieldValue.Interface().(core.Conversion).FromDB(data) | |||
} else { | |||
return nil, err | |||
} | |||
fieldType := fieldValue.Type() | |||
hasAssigned := false | |||
continue | |||
} | |||
if col.SQLType.IsJson() { | |||
var bs []byte | |||
if rawValueType.Kind() == reflect.String { | |||
bs = []byte(vv.String()) | |||
} else if rawValueType.ConvertibleTo(core.BytesType) { | |||
bs = vv.Bytes() | |||
} else { | |||
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) | |||
} | |||
rawValueType := reflect.TypeOf(rawValue.Interface()) | |||
vv := reflect.ValueOf(rawValue.Interface()) | |||
col := table.GetColumnIdx(key, idx) | |||
if col.IsPrimaryKey { | |||
pk = append(pk, rawValue.Interface()) | |||
} | |||
fieldType := fieldValue.Type() | |||
hasAssigned := false | |||
if col.SQLType.IsJson() { | |||
var bs []byte | |||
if rawValueType.Kind() == reflect.String { | |||
bs = []byte(vv.String()) | |||
} else if rawValueType.ConvertibleTo(core.BytesType) { | |||
bs = vv.Bytes() | |||
} else { | |||
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) | |||
} | |||
hasAssigned = true | |||
hasAssigned = true | |||
if len(bs) > 0 { | |||
if fieldType.Kind() == reflect.String { | |||
fieldValue.SetString(string(bs)) | |||
continue | |||
if len(bs) > 0 { | |||
if fieldType.Kind() == reflect.String { | |||
fieldValue.SetString(string(bs)) | |||
continue | |||
} | |||
if fieldValue.CanAddr() { | |||
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if fieldValue.CanAddr() { | |||
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
} else { | |||
x := reflect.New(fieldType) | |||
err := json.Unmarshal(bs, x.Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(x.Elem()) | |||
} else { | |||
x := reflect.New(fieldType) | |||
err := json.Unmarshal(bs, x.Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(x.Elem()) | |||
} | |||
continue | |||
} | |||
switch fieldType.Kind() { | |||
case reflect.Complex64, reflect.Complex128: | |||
// TODO: reimplement this | |||
var bs []byte | |||
if rawValueType.Kind() == reflect.String { | |||
bs = []byte(vv.String()) | |||
} else if rawValueType.ConvertibleTo(core.BytesType) { | |||
bs = vv.Bytes() | |||
} | |||
continue | |||
} | |||
hasAssigned = true | |||
if len(bs) > 0 { | |||
if fieldValue.CanAddr() { | |||
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
} else { | |||
x := reflect.New(fieldType) | |||
err := json.Unmarshal(bs, x.Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(x.Elem()) | |||
switch fieldType.Kind() { | |||
case reflect.Complex64, reflect.Complex128: | |||
// TODO: reimplement this | |||
var bs []byte | |||
if rawValueType.Kind() == reflect.String { | |||
bs = []byte(vv.String()) | |||
} else if rawValueType.ConvertibleTo(core.BytesType) { | |||
bs = vv.Bytes() | |||
} | |||
hasAssigned = true | |||
if len(bs) > 0 { | |||
if fieldValue.CanAddr() { | |||
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
} else { | |||
x := reflect.New(fieldType) | |||
err := json.Unmarshal(bs, x.Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(x.Elem()) | |||
} | |||
} | |||
case reflect.Slice, reflect.Array: | |||
switch rawValueType.Kind() { | |||
case reflect.Slice, reflect.Array: | |||
switch rawValueType.Kind() { | |||
case reflect.Slice, reflect.Array: | |||
switch rawValueType.Elem().Kind() { | |||
case reflect.Uint8: | |||
if fieldType.Elem().Kind() == reflect.Uint8 { | |||
hasAssigned = true | |||
if col.SQLType.IsText() { | |||
x := reflect.New(fieldType) | |||
err := json.Unmarshal(vv.Bytes(), x.Interface()) | |||
if err != nil { | |||
return nil, err | |||
switch rawValueType.Elem().Kind() { | |||
case reflect.Uint8: | |||
if fieldType.Elem().Kind() == reflect.Uint8 { | |||
hasAssigned = true | |||
if col.SQLType.IsText() { | |||
x := reflect.New(fieldType) | |||
err := json.Unmarshal(vv.Bytes(), x.Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(x.Elem()) | |||
} else { | |||
if fieldValue.Len() > 0 { | |||
for i := 0; i < fieldValue.Len(); i++ { | |||
if i < vv.Len() { | |||
fieldValue.Index(i).Set(vv.Index(i)) | |||
} | |||
} | |||
fieldValue.Set(x.Elem()) | |||
} else { | |||
if fieldValue.Len() > 0 { | |||
for i := 0; i < fieldValue.Len(); i++ { | |||
if i < vv.Len() { | |||
fieldValue.Index(i).Set(vv.Index(i)) | |||
} | |||
} | |||
} else { | |||
for i := 0; i < vv.Len(); i++ { | |||
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) | |||
} | |||
for i := 0; i < vv.Len(); i++ { | |||
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) | |||
} | |||
} | |||
} | |||
} | |||
} | |||
case reflect.String: | |||
if rawValueType.Kind() == reflect.String { | |||
hasAssigned = true | |||
fieldValue.SetString(vv.String()) | |||
} | |||
case reflect.Bool: | |||
if rawValueType.Kind() == reflect.Bool { | |||
hasAssigned = true | |||
fieldValue.SetBool(vv.Bool()) | |||
} | |||
} | |||
case reflect.String: | |||
if rawValueType.Kind() == reflect.String { | |||
hasAssigned = true | |||
fieldValue.SetString(vv.String()) | |||
} | |||
case reflect.Bool: | |||
if rawValueType.Kind() == reflect.Bool { | |||
hasAssigned = true | |||
fieldValue.SetBool(vv.Bool()) | |||
} | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
switch rawValueType.Kind() { | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
switch rawValueType.Kind() { | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
hasAssigned = true | |||
fieldValue.SetInt(vv.Int()) | |||
} | |||
hasAssigned = true | |||
fieldValue.SetInt(vv.Int()) | |||
} | |||
case reflect.Float32, reflect.Float64: | |||
switch rawValueType.Kind() { | |||
case reflect.Float32, reflect.Float64: | |||
switch rawValueType.Kind() { | |||
case reflect.Float32, reflect.Float64: | |||
hasAssigned = true | |||
fieldValue.SetFloat(vv.Float()) | |||
} | |||
hasAssigned = true | |||
fieldValue.SetFloat(vv.Float()) | |||
} | |||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: | |||
switch rawValueType.Kind() { | |||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: | |||
switch rawValueType.Kind() { | |||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: | |||
hasAssigned = true | |||
fieldValue.SetUint(vv.Uint()) | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
hasAssigned = true | |||
fieldValue.SetUint(uint64(vv.Int())) | |||
hasAssigned = true | |||
fieldValue.SetUint(vv.Uint()) | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
hasAssigned = true | |||
fieldValue.SetUint(uint64(vv.Int())) | |||
} | |||
case reflect.Struct: | |||
if fieldType.ConvertibleTo(core.TimeType) { | |||
dbTZ := session.engine.DatabaseTZ | |||
if col.TimeZone != nil { | |||
dbTZ = col.TimeZone | |||
} | |||
case reflect.Struct: | |||
if fieldType.ConvertibleTo(core.TimeType) { | |||
dbTZ := session.engine.DatabaseTZ | |||
if col.TimeZone != nil { | |||
dbTZ = col.TimeZone | |||
} | |||
if rawValueType == core.TimeType { | |||
hasAssigned = true | |||
t := vv.Convert(core.TimeType).Interface().(time.Time) | |||
z, _ := t.Zone() | |||
// set new location if database don't save timezone or give an incorrect timezone | |||
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location | |||
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) | |||
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), | |||
t.Minute(), t.Second(), t.Nanosecond(), dbTZ) | |||
} | |||
if rawValueType == core.TimeType { | |||
hasAssigned = true | |||
t = t.In(session.engine.TZLocation) | |||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||
} else if rawValueType == core.IntType || rawValueType == core.Int64Type || | |||
rawValueType == core.Int32Type { | |||
hasAssigned = true | |||
t := vv.Convert(core.TimeType).Interface().(time.Time) | |||
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) | |||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||
} else { | |||
if d, ok := vv.Interface().([]uint8); ok { | |||
hasAssigned = true | |||
t, err := session.byte2Time(col, d) | |||
if err != nil { | |||
session.engine.logger.Error("byte2Time error:", err.Error()) | |||
hasAssigned = false | |||
} else { | |||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||
} | |||
} else if d, ok := vv.Interface().(string); ok { | |||
hasAssigned = true | |||
t, err := session.str2Time(col, d) | |||
if err != nil { | |||
session.engine.logger.Error("byte2Time error:", err.Error()) | |||
hasAssigned = false | |||
} else { | |||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||
} | |||
} else { | |||
return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) | |||
} | |||
} | |||
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { | |||
// !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString | |||
hasAssigned = true | |||
if err := nulVal.Scan(vv.Interface()); err != nil { | |||
session.engine.logger.Error("sql.Sanner error:", err.Error()) | |||
hasAssigned = false | |||
} | |||
} else if col.SQLType.IsJson() { | |||
if rawValueType.Kind() == reflect.String { | |||
hasAssigned = true | |||
x := reflect.New(fieldType) | |||
if len([]byte(vv.String())) > 0 { | |||
err := json.Unmarshal([]byte(vv.String()), x.Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(x.Elem()) | |||
} | |||
} else if rawValueType.Kind() == reflect.Slice { | |||
hasAssigned = true | |||
x := reflect.New(fieldType) | |||
if len(vv.Bytes()) > 0 { | |||
err := json.Unmarshal(vv.Bytes(), x.Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(x.Elem()) | |||
} | |||
} | |||
} else if session.statement.UseCascade { | |||
table, err := session.engine.autoMapType(*fieldValue) | |||
if err != nil { | |||
return nil, err | |||
z, _ := t.Zone() | |||
// set new location if database don't save timezone or give an incorrect timezone | |||
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location | |||
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) | |||
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), | |||
t.Minute(), t.Second(), t.Nanosecond(), dbTZ) | |||
} | |||
t = t.In(session.engine.TZLocation) | |||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||
} else if rawValueType == core.IntType || rawValueType == core.Int64Type || | |||
rawValueType == core.Int32Type { | |||
hasAssigned = true | |||
if len(table.PrimaryKeys) != 1 { | |||
return nil, errors.New("unsupported non or composited primary key cascade") | |||
} | |||
var pk = make(core.PK, len(table.PrimaryKeys)) | |||
pk[0], err = asKind(vv, rawValueType) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if !isPKZero(pk) { | |||
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch | |||
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne | |||
// property to be fetched lazily | |||
structInter := reflect.New(fieldValue.Type()) | |||
has, err := session.ID(pk).NoCascade().get(structInter.Interface()) | |||
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) | |||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||
} else { | |||
if d, ok := vv.Interface().([]uint8); ok { | |||
hasAssigned = true | |||
t, err := session.byte2Time(col, d) | |||
if err != nil { | |||
return nil, err | |||
session.engine.logger.Error("byte2Time error:", err.Error()) | |||
hasAssigned = false | |||
} else { | |||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||
} | |||
if has { | |||
fieldValue.Set(structInter.Elem()) | |||
} else if d, ok := vv.Interface().(string); ok { | |||
hasAssigned = true | |||
t, err := session.str2Time(col, d) | |||
if err != nil { | |||
session.engine.logger.Error("byte2Time error:", err.Error()) | |||
hasAssigned = false | |||
} else { | |||
return nil, errors.New("cascade obj is not exist") | |||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||
} | |||
} else { | |||
return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) | |||
} | |||
} | |||
case reflect.Ptr: | |||
// !nashtsai! TODO merge duplicated codes above | |||
switch fieldType { | |||
// following types case matching ptr's native type, therefore assign ptr directly | |||
case core.PtrStringType: | |||
if rawValueType.Kind() == reflect.String { | |||
x := vv.String() | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrBoolType: | |||
if rawValueType.Kind() == reflect.Bool { | |||
x := vv.Bool() | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrTimeType: | |||
if rawValueType == core.PtrTimeType { | |||
hasAssigned = true | |||
var x = rawValue.Interface().(time.Time) | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrFloat64Type: | |||
if rawValueType.Kind() == reflect.Float64 { | |||
x := vv.Float() | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrUint64Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = uint64(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrInt64Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
x := vv.Int() | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrFloat32Type: | |||
if rawValueType.Kind() == reflect.Float64 { | |||
var x = float32(vv.Float()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrIntType: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = int(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrInt32Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = int32(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrInt8Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = int8(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrInt16Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = int16(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrUintType: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = uint(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrUint32Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = uint32(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.Uint8Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = uint8(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.Uint16Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = uint16(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.Complex64Type: | |||
var x complex64 | |||
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { | |||
// !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString | |||
hasAssigned = true | |||
if err := nulVal.Scan(vv.Interface()); err != nil { | |||
session.engine.logger.Error("sql.Sanner error:", err.Error()) | |||
hasAssigned = false | |||
} | |||
} else if col.SQLType.IsJson() { | |||
if rawValueType.Kind() == reflect.String { | |||
hasAssigned = true | |||
x := reflect.New(fieldType) | |||
if len([]byte(vv.String())) > 0 { | |||
err := json.Unmarshal([]byte(vv.String()), &x) | |||
err := json.Unmarshal([]byte(vv.String()), x.Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
fieldValue.Set(x.Elem()) | |||
} | |||
} else if rawValueType.Kind() == reflect.Slice { | |||
hasAssigned = true | |||
case core.Complex128Type: | |||
var x complex128 | |||
if len([]byte(vv.String())) > 0 { | |||
err := json.Unmarshal([]byte(vv.String()), &x) | |||
x := reflect.New(fieldType) | |||
if len(vv.Bytes()) > 0 { | |||
err := json.Unmarshal(vv.Bytes(), x.Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
fieldValue.Set(x.Elem()) | |||
} | |||
hasAssigned = true | |||
} // switch fieldType | |||
} // switch fieldType.Kind() | |||
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value | |||
if !hasAssigned { | |||
data, err := value2Bytes(&rawValue) | |||
} | |||
} else if session.statement.UseCascade { | |||
table, err := session.engine.autoMapType(*fieldValue) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if err = session.bytes2Value(col, fieldValue, data); err != nil { | |||
hasAssigned = true | |||
if len(table.PrimaryKeys) != 1 { | |||
return nil, errors.New("unsupported non or composited primary key cascade") | |||
} | |||
var pk = make(core.PK, len(table.PrimaryKeys)) | |||
pk[0], err = asKind(vv, rawValueType) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if !isPKZero(pk) { | |||
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch | |||
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne | |||
// property to be fetched lazily | |||
structInter := reflect.New(fieldValue.Type()) | |||
has, err := session.ID(pk).NoCascade().get(structInter.Interface()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if has { | |||
fieldValue.Set(structInter.Elem()) | |||
} else { | |||
return nil, errors.New("cascade obj is not exist") | |||
} | |||
} | |||
} | |||
case reflect.Ptr: | |||
// !nashtsai! TODO merge duplicated codes above | |||
switch fieldType { | |||
// following types case matching ptr's native type, therefore assign ptr directly | |||
case core.PtrStringType: | |||
if rawValueType.Kind() == reflect.String { | |||
x := vv.String() | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrBoolType: | |||
if rawValueType.Kind() == reflect.Bool { | |||
x := vv.Bool() | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrTimeType: | |||
if rawValueType == core.PtrTimeType { | |||
hasAssigned = true | |||
var x = rawValue.Interface().(time.Time) | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrFloat64Type: | |||
if rawValueType.Kind() == reflect.Float64 { | |||
x := vv.Float() | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrUint64Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = uint64(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrInt64Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
x := vv.Int() | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrFloat32Type: | |||
if rawValueType.Kind() == reflect.Float64 { | |||
var x = float32(vv.Float()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrIntType: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = int(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrInt32Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = int32(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrInt8Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = int8(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrInt16Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = int16(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrUintType: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = uint(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.PtrUint32Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = uint32(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.Uint8Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = uint8(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.Uint16Type: | |||
if rawValueType.Kind() == reflect.Int64 { | |||
var x = uint16(vv.Int()) | |||
hasAssigned = true | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
case core.Complex64Type: | |||
var x complex64 | |||
if len([]byte(vv.String())) > 0 { | |||
err := json.Unmarshal([]byte(vv.String()), &x) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
hasAssigned = true | |||
case core.Complex128Type: | |||
var x complex128 | |||
if len([]byte(vv.String())) > 0 { | |||
err := json.Unmarshal([]byte(vv.String()), &x) | |||
if err != nil { | |||
return nil, err | |||
} | |||
fieldValue.Set(reflect.ValueOf(&x)) | |||
} | |||
hasAssigned = true | |||
} // switch fieldType | |||
} // switch fieldType.Kind() | |||
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value | |||
if !hasAssigned { | |||
data, err := value2Bytes(&rawValue) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if err = session.bytes2Value(col, fieldValue, data); err != nil { | |||
return nil, err | |||
} | |||
} | |||
} | |||
@@ -828,15 +834,6 @@ func (session *Session) LastSQL() (string, []interface{}) { | |||
return session.lastSQL, session.lastSQLArgs | |||
} | |||
// tbName get some table's table name | |||
func (session *Session) tbNameNoSchema(table *core.Table) string { | |||
if len(session.statement.AltTableName) > 0 { | |||
return session.statement.AltTableName | |||
} | |||
return table.Name | |||
} | |||
// Unscoped always disable struct tag "deleted" | |||
func (session *Session) Unscoped() *Session { | |||
session.statement.Unscoped() |
@@ -4,6 +4,121 @@ | |||
package xorm | |||
import ( | |||
"reflect" | |||
"strings" | |||
"time" | |||
"github.com/go-xorm/core" | |||
) | |||
type incrParam struct { | |||
colName string | |||
arg interface{} | |||
} | |||
type decrParam struct { | |||
colName string | |||
arg interface{} | |||
} | |||
type exprParam struct { | |||
colName string | |||
expr string | |||
} | |||
type columnMap []string | |||
func (m columnMap) contain(colName string) bool { | |||
if len(m) == 0 { | |||
return false | |||
} | |||
n := len(colName) | |||
for _, mk := range m { | |||
if len(mk) != n { | |||
continue | |||
} | |||
if strings.EqualFold(mk, colName) { | |||
return true | |||
} | |||
} | |||
return false | |||
} | |||
func (m *columnMap) add(colName string) bool { | |||
if m.contain(colName) { | |||
return false | |||
} | |||
*m = append(*m, colName) | |||
return true | |||
} | |||
func setColumnInt(bean interface{}, col *core.Column, t int64) { | |||
v, err := col.ValueOf(bean) | |||
if err != nil { | |||
return | |||
} | |||
if v.CanSet() { | |||
switch v.Type().Kind() { | |||
case reflect.Int, reflect.Int64, reflect.Int32: | |||
v.SetInt(t) | |||
case reflect.Uint, reflect.Uint64, reflect.Uint32: | |||
v.SetUint(uint64(t)) | |||
} | |||
} | |||
} | |||
func setColumnTime(bean interface{}, col *core.Column, t time.Time) { | |||
v, err := col.ValueOf(bean) | |||
if err != nil { | |||
return | |||
} | |||
if v.CanSet() { | |||
switch v.Type().Kind() { | |||
case reflect.Struct: | |||
v.Set(reflect.ValueOf(t).Convert(v.Type())) | |||
case reflect.Int, reflect.Int64, reflect.Int32: | |||
v.SetInt(t.Unix()) | |||
case reflect.Uint, reflect.Uint64, reflect.Uint32: | |||
v.SetUint(uint64(t.Unix())) | |||
} | |||
} | |||
} | |||
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { | |||
if len(m) == 0 { | |||
return false, false | |||
} | |||
n := len(col.Name) | |||
for mk := range m { | |||
if len(mk) != n { | |||
continue | |||
} | |||
if strings.EqualFold(mk, col.Name) { | |||
return m[mk], true | |||
} | |||
} | |||
return false, false | |||
} | |||
func col2NewCols(columns ...string) []string { | |||
newColumns := make([]string, 0, len(columns)) | |||
for _, col := range columns { | |||
col = strings.Replace(col, "`", "", -1) | |||
col = strings.Replace(col, `"`, "", -1) | |||
ccols := strings.Split(col, ",") | |||
for _, c := range ccols { | |||
newColumns = append(newColumns, strings.TrimSpace(c)) | |||
} | |||
} | |||
return newColumns | |||
} | |||
// Incr provides a query string like "count = count + 1" | |||
func (session *Session) Incr(column string, arg ...interface{}) *Session { | |||
session.statement.Incr(column, arg...) |
@@ -27,7 +27,7 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, | |||
return ErrCacheFailed | |||
} | |||
cacher := session.engine.getCacher2(table) | |||
cacher := session.engine.getCacher(tableName) | |||
pkColumns := table.PKColumns() | |||
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) | |||
if err != nil { | |||
@@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { | |||
defer session.Close() | |||
} | |||
if err := session.statement.setRefValue(rValue(bean)); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return 0, err | |||
} | |||
@@ -199,7 +199,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { | |||
}) | |||
} | |||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { | |||
session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) | |||
} | |||
@@ -57,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { | |||
} | |||
if beanValue.Elem().Kind() == reflect.Struct { | |||
if err := session.statement.setRefValue(beanValue.Elem()); err != nil { | |||
if err := session.statement.setRefBean(bean[0]); err != nil { | |||
return false, err | |||
} | |||
} |
@@ -29,6 +29,39 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||
return session.find(rowsSlicePtr, condiBean...) | |||
} | |||
// FindAndCount find the results and also return the counts | |||
func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) { | |||
if session.isAutoClose { | |||
defer session.Close() | |||
} | |||
session.autoResetStatement = false | |||
err := session.find(rowsSlicePtr, condiBean...) | |||
if err != nil { | |||
return 0, err | |||
} | |||
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) | |||
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { | |||
return 0, errors.New("needs a pointer to a slice or a map") | |||
} | |||
sliceElementType := sliceValue.Type().Elem() | |||
if sliceElementType.Kind() == reflect.Ptr { | |||
sliceElementType = sliceElementType.Elem() | |||
} | |||
session.autoResetStatement = true | |||
if session.statement.selectStr != "" { | |||
session.statement.selectStr = "" | |||
} | |||
if session.statement.OrderStr != "" { | |||
session.statement.OrderStr = "" | |||
} | |||
return session.Count(reflect.New(sliceElementType).Interface()) | |||
} | |||
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { | |||
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) | |||
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { | |||
@@ -42,7 +75,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||
if sliceElementType.Kind() == reflect.Ptr { | |||
if sliceElementType.Elem().Kind() == reflect.Struct { | |||
pv := reflect.New(sliceElementType.Elem()) | |||
if err := session.statement.setRefValue(pv.Elem()); err != nil { | |||
if err := session.statement.setRefValue(pv); err != nil { | |||
return err | |||
} | |||
} else { | |||
@@ -50,7 +83,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||
} | |||
} else if sliceElementType.Kind() == reflect.Struct { | |||
pv := reflect.New(sliceElementType) | |||
if err := session.statement.setRefValue(pv.Elem()); err != nil { | |||
if err := session.statement.setRefValue(pv); err != nil { | |||
return err | |||
} | |||
} else { | |||
@@ -128,7 +161,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||
} | |||
args = append(session.statement.joinArgs, condArgs...) | |||
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL) | |||
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -143,7 +176,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||
} | |||
if session.canCache() { | |||
if cacher := session.engine.getCacher2(table); cacher != nil && | |||
if cacher := session.engine.getCacher(table.Name); cacher != nil && | |||
!session.statement.IsDistinct && | |||
!session.statement.unscoped { | |||
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) | |||
@@ -288,6 +321,12 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in | |||
return ErrCacheFailed | |||
} | |||
tableName := session.statement.TableName() | |||
cacher := session.engine.getCacher(tableName) | |||
if cacher == nil { | |||
return nil | |||
} | |||
for _, filter := range session.engine.dialect.Filters() { | |||
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) | |||
} | |||
@@ -297,9 +336,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in | |||
return ErrCacheFailed | |||
} | |||
tableName := session.statement.TableName() | |||
table := session.statement.RefTable | |||
cacher := session.engine.getCacher2(table) | |||
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) | |||
if err != nil { | |||
rows, err := session.queryRows(newsql, args...) |
@@ -31,7 +31,7 @@ func (session *Session) get(bean interface{}) (bool, error) { | |||
} | |||
if beanValue.Elem().Kind() == reflect.Struct { | |||
if err := session.statement.setRefValue(beanValue.Elem()); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return false, err | |||
} | |||
} | |||
@@ -57,7 +57,7 @@ func (session *Session) get(bean interface{}) (bool, error) { | |||
table := session.statement.RefTable | |||
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { | |||
if cacher := session.engine.getCacher2(table); cacher != nil && | |||
if cacher := session.engine.getCacher(table.Name); cacher != nil && | |||
!session.statement.unscoped { | |||
has, err := session.cacheGet(bean, sqlStr, args...) | |||
if err != ErrCacheFailed { | |||
@@ -134,8 +134,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf | |||
return false, ErrCacheFailed | |||
} | |||
cacher := session.engine.getCacher2(session.statement.RefTable) | |||
tableName := session.statement.TableName() | |||
cacher := session.engine.getCacher(tableName) | |||
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) | |||
table := session.statement.RefTable | |||
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) |
@@ -66,11 +66,12 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||
return 0, errors.New("could not insert a empty slice") | |||
} | |||
if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil { | |||
if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil { | |||
return 0, err | |||
} | |||
if len(session.statement.TableName()) <= 0 { | |||
tableName := session.statement.TableName() | |||
if len(tableName) <= 0 { | |||
return 0, ErrTableNotFound | |||
} | |||
@@ -115,15 +116,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||
if col.IsDeleted { | |||
continue | |||
} | |||
if session.statement.ColumnStr != "" { | |||
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { | |||
continue | |||
} | |||
if session.statement.omitColumnMap.contain(col.Name) { | |||
continue | |||
} | |||
if session.statement.OmitStr != "" { | |||
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { | |||
continue | |||
} | |||
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { | |||
continue | |||
} | |||
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { | |||
val, t := session.engine.nowTime(col) | |||
@@ -170,15 +167,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||
if col.IsDeleted { | |||
continue | |||
} | |||
if session.statement.ColumnStr != "" { | |||
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { | |||
continue | |||
} | |||
if session.statement.omitColumnMap.contain(col.Name) { | |||
continue | |||
} | |||
if session.statement.OmitStr != "" { | |||
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { | |||
continue | |||
} | |||
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { | |||
continue | |||
} | |||
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { | |||
val, t := session.engine.nowTime(col) | |||
@@ -211,38 +204,33 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||
} | |||
cleanupProcessorsClosures(&session.beforeClosures) | |||
var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)" | |||
var statement string | |||
var tableName = session.statement.TableName() | |||
var sql string | |||
if session.engine.dialect.DBType() == core.ORACLE { | |||
sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL" | |||
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", | |||
session.engine.Quote(tableName), | |||
session.engine.QuoteStr(), | |||
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), | |||
session.engine.QuoteStr()) | |||
statement = fmt.Sprintf(sql, | |||
sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL", | |||
session.engine.Quote(tableName), | |||
session.engine.QuoteStr(), | |||
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), | |||
session.engine.QuoteStr(), | |||
strings.Join(colMultiPlaces, temp)) | |||
} else { | |||
statement = fmt.Sprintf(sql, | |||
sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", | |||
session.engine.Quote(tableName), | |||
session.engine.QuoteStr(), | |||
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), | |||
session.engine.QuoteStr(), | |||
strings.Join(colMultiPlaces, "),(")) | |||
} | |||
res, err := session.exec(statement, args...) | |||
res, err := session.exec(sql, args...) | |||
if err != nil { | |||
return 0, err | |||
} | |||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||
session.cacheInsert(table, tableName) | |||
} | |||
session.cacheInsert(tableName) | |||
lenAfterClosures := len(session.afterClosures) | |||
for i := 0; i < size; i++ { | |||
@@ -298,7 +286,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { | |||
} | |||
func (session *Session) innerInsert(bean interface{}) (int64, error) { | |||
if err := session.statement.setRefValue(rValue(bean)); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return 0, err | |||
} | |||
if len(session.statement.TableName()) <= 0 { | |||
@@ -316,8 +304,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { | |||
if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok { | |||
processor.BeforeInsert() | |||
} | |||
// -- | |||
colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false) | |||
colNames, args, err := session.genInsertColumns(bean) | |||
if err != nil { | |||
return 0, err | |||
} | |||
@@ -402,9 +390,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { | |||
defer handleAfterInsertProcessorFunc(bean) | |||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||
session.cacheInsert(table, tableName) | |||
} | |||
session.cacheInsert(tableName) | |||
if table.Version != "" && session.statement.checkVersion { | |||
verValue, err := table.VersionColumn().ValueOf(bean) | |||
@@ -447,9 +433,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { | |||
} | |||
defer handleAfterInsertProcessorFunc(bean) | |||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||
session.cacheInsert(table, tableName) | |||
} | |||
session.cacheInsert(tableName) | |||
if table.Version != "" && session.statement.checkVersion { | |||
verValue, err := table.VersionColumn().ValueOf(bean) | |||
@@ -490,9 +474,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { | |||
defer handleAfterInsertProcessorFunc(bean) | |||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||
session.cacheInsert(table, tableName) | |||
} | |||
session.cacheInsert(tableName) | |||
if table.Version != "" && session.statement.checkVersion { | |||
verValue, err := table.VersionColumn().ValueOf(bean) | |||
@@ -539,16 +521,104 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { | |||
return session.innerInsert(bean) | |||
} | |||
func (session *Session) cacheInsert(table *core.Table, tables ...string) error { | |||
if table == nil { | |||
return ErrCacheFailed | |||
func (session *Session) cacheInsert(table string) error { | |||
if !session.statement.UseCache { | |||
return nil | |||
} | |||
cacher := session.engine.getCacher2(table) | |||
for _, t := range tables { | |||
session.engine.logger.Debug("[cache] clear sql:", t) | |||
cacher.ClearIds(t) | |||
cacher := session.engine.getCacher(table) | |||
if cacher == nil { | |||
return nil | |||
} | |||
session.engine.logger.Debug("[cache] clear sql:", table) | |||
cacher.ClearIds(table) | |||
return nil | |||
} | |||
// genInsertColumns generates insert needed columns | |||
func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) { | |||
table := session.statement.RefTable | |||
colNames := make([]string, 0, len(table.ColumnsSeq())) | |||
args := make([]interface{}, 0, len(table.ColumnsSeq())) | |||
for _, col := range table.Columns() { | |||
if col.MapType == core.ONLYFROMDB { | |||
continue | |||
} | |||
if col.IsDeleted { | |||
continue | |||
} | |||
if session.statement.omitColumnMap.contain(col.Name) { | |||
continue | |||
} | |||
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { | |||
continue | |||
} | |||
if _, ok := session.statement.incrColumns[col.Name]; ok { | |||
continue | |||
} else if _, ok := session.statement.decrColumns[col.Name]; ok { | |||
continue | |||
} | |||
fieldValuePtr, err := col.ValueOf(bean) | |||
if err != nil { | |||
return nil, nil, err | |||
} | |||
fieldValue := *fieldValuePtr | |||
if col.IsAutoIncrement { | |||
switch fieldValue.Type().Kind() { | |||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: | |||
if fieldValue.Int() == 0 { | |||
continue | |||
} | |||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: | |||
if fieldValue.Uint() == 0 { | |||
continue | |||
} | |||
case reflect.String: | |||
if len(fieldValue.String()) == 0 { | |||
continue | |||
} | |||
case reflect.Ptr: | |||
if fieldValue.Pointer() == 0 { | |||
continue | |||
} | |||
} | |||
} | |||
// !evalphobia! set fieldValue as nil when column is nullable and zero-value | |||
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { | |||
if col.Nullable && isZero(fieldValue.Interface()) { | |||
var nilValue *int | |||
fieldValue = reflect.ValueOf(nilValue) | |||
} | |||
} | |||
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { | |||
// if time is non-empty, then set to auto time | |||
val, t := session.engine.nowTime(col) | |||
args = append(args, val) | |||
var colName = col.Name | |||
session.afterClosures = append(session.afterClosures, func(bean interface{}) { | |||
col := table.GetColumn(colName) | |||
setColumnTime(bean, col, t) | |||
}) | |||
} else if col.IsVersion && session.statement.checkVersion { | |||
args = append(args, 1) | |||
} else { | |||
arg, err := session.value2Interface(col, fieldValue) | |||
if err != nil { | |||
return colNames, args, err | |||
} | |||
args = append(args, arg) | |||
} | |||
colNames = append(colNames, col.Name) | |||
} | |||
return colNames, args, nil | |||
} |
@@ -64,13 +64,17 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa | |||
} | |||
} | |||
if err := session.statement.processIDParam(); err != nil { | |||
return "", nil, err | |||
} | |||
condSQL, condArgs, err := builder.ToSQL(session.statement.cond) | |||
if err != nil { | |||
return "", nil, err | |||
} | |||
args := append(session.statement.joinArgs, condArgs...) | |||
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL) | |||
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true) | |||
if err != nil { | |||
return "", nil, err | |||
} |
@@ -6,9 +6,7 @@ package xorm | |||
import ( | |||
"database/sql" | |||
"errors" | |||
"fmt" | |||
"reflect" | |||
"strings" | |||
"github.com/go-xorm/core" | |||
@@ -34,8 +32,7 @@ func (session *Session) CreateTable(bean interface{}) error { | |||
} | |||
func (session *Session) createTable(bean interface{}) error { | |||
v := rValue(bean) | |||
if err := session.statement.setRefValue(v); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return err | |||
} | |||
@@ -54,8 +51,7 @@ func (session *Session) CreateIndexes(bean interface{}) error { | |||
} | |||
func (session *Session) createIndexes(bean interface{}) error { | |||
v := rValue(bean) | |||
if err := session.statement.setRefValue(v); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return err | |||
} | |||
@@ -78,8 +74,7 @@ func (session *Session) CreateUniques(bean interface{}) error { | |||
} | |||
func (session *Session) createUniques(bean interface{}) error { | |||
v := rValue(bean) | |||
if err := session.statement.setRefValue(v); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return err | |||
} | |||
@@ -103,8 +98,7 @@ func (session *Session) DropIndexes(bean interface{}) error { | |||
} | |||
func (session *Session) dropIndexes(bean interface{}) error { | |||
v := rValue(bean) | |||
if err := session.statement.setRefValue(v); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return err | |||
} | |||
@@ -128,11 +122,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { | |||
} | |||
func (session *Session) dropTable(beanOrTableName interface{}) error { | |||
tableName, err := session.engine.tableName(beanOrTableName) | |||
if err != nil { | |||
return err | |||
} | |||
tableName := session.engine.TableName(beanOrTableName) | |||
var needDrop = true | |||
if !session.engine.dialect.SupportDropIfExists() { | |||
sqlStr, args := session.engine.dialect.TableCheckSql(tableName) | |||
@@ -144,8 +134,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { | |||
} | |||
if needDrop { | |||
sqlStr := session.engine.Dialect().DropTableSql(tableName) | |||
_, err = session.exec(sqlStr) | |||
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true)) | |||
_, err := session.exec(sqlStr) | |||
return err | |||
} | |||
return nil | |||
@@ -157,10 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) | |||
defer session.Close() | |||
} | |||
tableName, err := session.engine.tableName(beanOrTableName) | |||
if err != nil { | |||
return false, err | |||
} | |||
tableName := session.engine.TableName(beanOrTableName) | |||
return session.isTableExist(tableName) | |||
} | |||
@@ -173,24 +160,15 @@ func (session *Session) isTableExist(tableName string) (bool, error) { | |||
// IsTableEmpty if table have any records | |||
func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { | |||
v := rValue(bean) | |||
t := v.Type() | |||
if t.Kind() == reflect.String { | |||
if session.isAutoClose { | |||
defer session.Close() | |||
} | |||
return session.isTableEmpty(bean.(string)) | |||
} else if t.Kind() == reflect.Struct { | |||
rows, err := session.Count(bean) | |||
return rows == 0, err | |||
if session.isAutoClose { | |||
defer session.Close() | |||
} | |||
return false, errors.New("bean should be a struct or struct's point") | |||
return session.isTableEmpty(session.engine.TableName(bean)) | |||
} | |||
func (session *Session) isTableEmpty(tableName string) (bool, error) { | |||
var total int64 | |||
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName)) | |||
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true))) | |||
err := session.queryRow(sqlStr).Scan(&total) | |||
if err != nil { | |||
if err == sql.ErrNoRows { | |||
@@ -255,6 +233,12 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||
return err | |||
} | |||
session.autoResetStatement = false | |||
defer func() { | |||
session.autoResetStatement = true | |||
session.resetStatement() | |||
}() | |||
var structTables []*core.Table | |||
for _, bean := range beans { | |||
@@ -264,7 +248,8 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||
return err | |||
} | |||
structTables = append(structTables, table) | |||
var tbName = session.tbNameNoSchema(table) | |||
tbName := engine.TableName(bean) | |||
tbNameWithSchema := engine.TableName(tbName, true) | |||
var oriTable *core.Table | |||
for _, tb := range tables { | |||
@@ -309,32 +294,32 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||
if engine.dialect.DBType() == core.MYSQL || | |||
engine.dialect.DBType() == core.POSTGRES { | |||
engine.logger.Infof("Table %s column %s change type from %s to %s\n", | |||
tbName, col.Name, curType, expectedType) | |||
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) | |||
tbNameWithSchema, col.Name, curType, expectedType) | |||
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) | |||
} else { | |||
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", | |||
tbName, col.Name, curType, expectedType) | |||
tbNameWithSchema, col.Name, curType, expectedType) | |||
} | |||
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { | |||
if engine.dialect.DBType() == core.MYSQL { | |||
if oriCol.Length < col.Length { | |||
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", | |||
tbName, col.Name, oriCol.Length, col.Length) | |||
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) | |||
tbNameWithSchema, col.Name, oriCol.Length, col.Length) | |||
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) | |||
} | |||
} | |||
} else { | |||
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { | |||
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", | |||
tbName, col.Name, curType, expectedType) | |||
tbNameWithSchema, col.Name, curType, expectedType) | |||
} | |||
} | |||
} else if expectedType == core.Varchar { | |||
if engine.dialect.DBType() == core.MYSQL { | |||
if oriCol.Length < col.Length { | |||
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", | |||
tbName, col.Name, oriCol.Length, col.Length) | |||
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) | |||
tbNameWithSchema, col.Name, oriCol.Length, col.Length) | |||
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) | |||
} | |||
} | |||
} | |||
@@ -348,7 +333,7 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||
} | |||
} else { | |||
session.statement.RefTable = table | |||
session.statement.tableName = tbName | |||
session.statement.tableName = tbNameWithSchema | |||
err = session.addColumn(col.Name) | |||
} | |||
if err != nil { | |||
@@ -371,7 +356,7 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||
if oriIndex != nil { | |||
if oriIndex.Type != index.Type { | |||
sql := engine.dialect.DropIndexSql(tbName, oriIndex) | |||
sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex) | |||
_, err = session.exec(sql) | |||
if err != nil { | |||
return err | |||
@@ -387,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||
for name2, index2 := range oriTable.Indexes { | |||
if _, ok := foundIndexNames[name2]; !ok { | |||
sql := engine.dialect.DropIndexSql(tbName, index2) | |||
sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2) | |||
_, err = session.exec(sql) | |||
if err != nil { | |||
return err | |||
@@ -398,12 +383,12 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||
for name, index := range addedNames { | |||
if index.Type == core.UniqueType { | |||
session.statement.RefTable = table | |||
session.statement.tableName = tbName | |||
err = session.addUnique(tbName, name) | |||
session.statement.tableName = tbNameWithSchema | |||
err = session.addUnique(tbNameWithSchema, name) | |||
} else if index.Type == core.IndexType { | |||
session.statement.RefTable = table | |||
session.statement.tableName = tbName | |||
err = session.addIndex(tbName, name) | |||
session.statement.tableName = tbNameWithSchema | |||
err = session.addIndex(tbNameWithSchema, name) | |||
} | |||
if err != nil { | |||
return err | |||
@@ -428,7 +413,7 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||
for _, colName := range table.ColumnsSeq() { | |||
if oriTable.GetColumn(colName) == nil { | |||
engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName) | |||
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName) | |||
} | |||
} | |||
} |
@@ -24,6 +24,7 @@ func (session *Session) Rollback() error { | |||
if !session.isAutoCommit && !session.isCommitedOrRollbacked { | |||
session.saveLastSQL(session.engine.dialect.RollBackStr()) | |||
session.isCommitedOrRollbacked = true | |||
session.isAutoCommit = true | |||
return session.tx.Rollback() | |||
} | |||
return nil | |||
@@ -34,6 +35,7 @@ func (session *Session) Commit() error { | |||
if !session.isAutoCommit && !session.isCommitedOrRollbacked { | |||
session.saveLastSQL("COMMIT") | |||
session.isCommitedOrRollbacked = true | |||
session.isAutoCommit = true | |||
var err error | |||
if err = session.tx.Commit(); err == nil { | |||
// handle processors after tx committed |
@@ -40,7 +40,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, | |||
} | |||
} | |||
cacher := session.engine.getCacher2(table) | |||
cacher := session.engine.getCacher(tableName) | |||
session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) | |||
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) | |||
if err != nil { | |||
@@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 | |||
var isMap = t.Kind() == reflect.Map | |||
var isStruct = t.Kind() == reflect.Struct | |||
if isStruct { | |||
if err := session.statement.setRefValue(v); err != nil { | |||
if err := session.statement.setRefBean(bean); err != nil { | |||
return 0, err | |||
} | |||
@@ -176,12 +176,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 | |||
} | |||
if session.statement.ColumnStr == "" { | |||
colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false, | |||
false, false, session.statement.allUseBool, session.statement.useAllCols, | |||
session.statement.mustColumnMap, session.statement.nullableMap, | |||
session.statement.columnMap, true, session.statement.unscoped) | |||
colNames, args = session.statement.buildUpdates(bean, false, false, | |||
false, false, true) | |||
} else { | |||
colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true) | |||
colNames, args, err = session.genUpdateColumns(bean) | |||
if err != nil { | |||
return 0, err | |||
} | |||
@@ -202,7 +200,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 | |||
table := session.statement.RefTable | |||
if session.statement.UseAutoTime && table != nil && table.Updated != "" { | |||
if _, ok := session.statement.columnMap[strings.ToLower(table.Updated)]; !ok { | |||
if !session.statement.columnMap.contain(table.Updated) && | |||
!session.statement.omitColumnMap.contain(table.Updated) { | |||
colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") | |||
col := table.UpdatedColumn() | |||
val, t := session.engine.nowTime(col) | |||
@@ -362,12 +361,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 | |||
} | |||
} | |||
if table != nil { | |||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||
//session.cacheUpdate(table, tableName, sqlStr, args...) | |||
cacher.ClearIds(tableName) | |||
cacher.ClearBeans(tableName) | |||
} | |||
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { | |||
//session.cacheUpdate(table, tableName, sqlStr, args...) | |||
session.engine.logger.Debug("[cacheUpdate] clear table ", tableName) | |||
cacher.ClearIds(tableName) | |||
cacher.ClearBeans(tableName) | |||
} | |||
// handle after update processors | |||
@@ -402,3 +400,92 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 | |||
return res.RowsAffected() | |||
} | |||
func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interface{}, error) { | |||
table := session.statement.RefTable | |||
colNames := make([]string, 0, len(table.ColumnsSeq())) | |||
args := make([]interface{}, 0, len(table.ColumnsSeq())) | |||
for _, col := range table.Columns() { | |||
if !col.IsVersion && !col.IsCreated && !col.IsUpdated { | |||
if session.statement.omitColumnMap.contain(col.Name) { | |||
continue | |||
} | |||
} | |||
if col.MapType == core.ONLYFROMDB { | |||
continue | |||
} | |||
fieldValuePtr, err := col.ValueOf(bean) | |||
if err != nil { | |||
return nil, nil, err | |||
} | |||
fieldValue := *fieldValuePtr | |||
if col.IsAutoIncrement { | |||
switch fieldValue.Type().Kind() { | |||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: | |||
if fieldValue.Int() == 0 { | |||
continue | |||
} | |||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: | |||
if fieldValue.Uint() == 0 { | |||
continue | |||
} | |||
case reflect.String: | |||
if len(fieldValue.String()) == 0 { | |||
continue | |||
} | |||
case reflect.Ptr: | |||
if fieldValue.Pointer() == 0 { | |||
continue | |||
} | |||
} | |||
} | |||
if col.IsDeleted || col.IsCreated { | |||
continue | |||
} | |||
if len(session.statement.columnMap) > 0 { | |||
if !session.statement.columnMap.contain(col.Name) { | |||
continue | |||
} else if _, ok := session.statement.incrColumns[col.Name]; ok { | |||
continue | |||
} else if _, ok := session.statement.decrColumns[col.Name]; ok { | |||
continue | |||
} | |||
} | |||
// !evalphobia! set fieldValue as nil when column is nullable and zero-value | |||
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { | |||
if col.Nullable && isZero(fieldValue.Interface()) { | |||
var nilValue *int | |||
fieldValue = reflect.ValueOf(nilValue) | |||
} | |||
} | |||
if col.IsUpdated && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { | |||
// if time is non-empty, then set to auto time | |||
val, t := session.engine.nowTime(col) | |||
args = append(args, val) | |||
var colName = col.Name | |||
session.afterClosures = append(session.afterClosures, func(bean interface{}) { | |||
col := table.GetColumn(colName) | |||
setColumnTime(bean, col, t) | |||
}) | |||
} else if col.IsVersion && session.statement.checkVersion { | |||
args = append(args, 1) | |||
} else { | |||
arg, err := session.value2Interface(col, fieldValue) | |||
if err != nil { | |||
return colNames, args, err | |||
} | |||
args = append(args, arg) | |||
} | |||
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?") | |||
} | |||
return colNames, args, nil | |||
} |
@@ -5,7 +5,6 @@ | |||
package xorm | |||
import ( | |||
"bytes" | |||
"database/sql/driver" | |||
"encoding/json" | |||
"errors" | |||
@@ -18,21 +17,6 @@ import ( | |||
"github.com/go-xorm/core" | |||
) | |||
type incrParam struct { | |||
colName string | |||
arg interface{} | |||
} | |||
type decrParam struct { | |||
colName string | |||
arg interface{} | |||
} | |||
type exprParam struct { | |||
colName string | |||
expr string | |||
} | |||
// Statement save all the sql info for executing SQL | |||
type Statement struct { | |||
RefTable *core.Table | |||
@@ -47,7 +31,6 @@ type Statement struct { | |||
HavingStr string | |||
ColumnStr string | |||
selectStr string | |||
columnMap map[string]bool | |||
useAllCols bool | |||
OmitStr string | |||
AltTableName string | |||
@@ -67,6 +50,8 @@ type Statement struct { | |||
allUseBool bool | |||
checkVersion bool | |||
unscoped bool | |||
columnMap columnMap | |||
omitColumnMap columnMap | |||
mustColumnMap map[string]bool | |||
nullableMap map[string]bool | |||
incrColumns map[string]incrParam | |||
@@ -89,7 +74,8 @@ func (statement *Statement) Init() { | |||
statement.HavingStr = "" | |||
statement.ColumnStr = "" | |||
statement.OmitStr = "" | |||
statement.columnMap = make(map[string]bool) | |||
statement.columnMap = columnMap{} | |||
statement.omitColumnMap = columnMap{} | |||
statement.AltTableName = "" | |||
statement.tableName = "" | |||
statement.idParam = nil | |||
@@ -221,34 +207,33 @@ func (statement *Statement) setRefValue(v reflect.Value) error { | |||
if err != nil { | |||
return err | |||
} | |||
statement.tableName = statement.Engine.tbName(v) | |||
statement.tableName = statement.Engine.TableName(v, true) | |||
return nil | |||
} | |||
// Table tempororily set table name, the parameter could be a string or a pointer of struct | |||
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { | |||
v := rValue(tableNameOrBean) | |||
t := v.Type() | |||
if t.Kind() == reflect.String { | |||
statement.AltTableName = tableNameOrBean.(string) | |||
} else if t.Kind() == reflect.Struct { | |||
var err error | |||
statement.RefTable, err = statement.Engine.autoMapType(v) | |||
if err != nil { | |||
statement.Engine.logger.Error(err) | |||
return statement | |||
} | |||
statement.AltTableName = statement.Engine.tbName(v) | |||
func (statement *Statement) setRefBean(bean interface{}) error { | |||
var err error | |||
statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) | |||
if err != nil { | |||
return err | |||
} | |||
return statement | |||
statement.tableName = statement.Engine.TableName(bean, true) | |||
return nil | |||
} | |||
// Auto generating update columnes and values according a struct | |||
func buildUpdates(engine *Engine, table *core.Table, bean interface{}, | |||
includeVersion bool, includeUpdated bool, includeNil bool, | |||
includeAutoIncr bool, allUseBool bool, useAllCols bool, | |||
mustColumnMap map[string]bool, nullableMap map[string]bool, | |||
columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) { | |||
func (statement *Statement) buildUpdates(bean interface{}, | |||
includeVersion, includeUpdated, includeNil, | |||
includeAutoIncr, update bool) ([]string, []interface{}) { | |||
engine := statement.Engine | |||
table := statement.RefTable | |||
allUseBool := statement.allUseBool | |||
useAllCols := statement.useAllCols | |||
mustColumnMap := statement.mustColumnMap | |||
nullableMap := statement.nullableMap | |||
columnMap := statement.columnMap | |||
omitColumnMap := statement.omitColumnMap | |||
unscoped := statement.unscoped | |||
var colNames = make([]string, 0) | |||
var args = make([]interface{}, 0) | |||
@@ -268,7 +253,14 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, | |||
if col.IsDeleted && !unscoped { | |||
continue | |||
} | |||
if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use { | |||
if omitColumnMap.contain(col.Name) { | |||
continue | |||
} | |||
if len(columnMap) > 0 && !columnMap.contain(col.Name) { | |||
continue | |||
} | |||
if col.MapType == core.ONLYFROMDB { | |||
continue | |||
} | |||
@@ -604,17 +596,10 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { | |||
} | |||
func (statement *Statement) colmap2NewColsWithQuote() []string { | |||
newColumns := make([]string, 0, len(statement.columnMap)) | |||
for col := range statement.columnMap { | |||
fields := strings.Split(strings.TrimSpace(col), ".") | |||
if len(fields) == 1 { | |||
newColumns = append(newColumns, statement.Engine.quote(fields[0])) | |||
} else if len(fields) == 2 { | |||
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+ | |||
statement.Engine.quote(fields[1])) | |||
} else { | |||
panic(errors.New("unwanted colnames")) | |||
} | |||
newColumns := make([]string, len(statement.columnMap), len(statement.columnMap)) | |||
copy(newColumns, statement.columnMap) | |||
for i := 0; i < len(statement.columnMap); i++ { | |||
newColumns[i] = statement.Engine.Quote(newColumns[i]) | |||
} | |||
return newColumns | |||
} | |||
@@ -642,10 +627,11 @@ func (statement *Statement) Select(str string) *Statement { | |||
func (statement *Statement) Cols(columns ...string) *Statement { | |||
cols := col2NewCols(columns...) | |||
for _, nc := range cols { | |||
statement.columnMap[strings.ToLower(nc)] = true | |||
statement.columnMap.add(nc) | |||
} | |||
newColumns := statement.colmap2NewColsWithQuote() | |||
statement.ColumnStr = strings.Join(newColumns, ", ") | |||
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) | |||
return statement | |||
@@ -680,7 +666,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement { | |||
func (statement *Statement) Omit(columns ...string) { | |||
newColumns := col2NewCols(columns...) | |||
for _, nc := range newColumns { | |||
statement.columnMap[strings.ToLower(nc)] = false | |||
statement.omitColumnMap = append(statement.omitColumnMap, nc) | |||
} | |||
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) | |||
} | |||
@@ -719,10 +705,9 @@ func (statement *Statement) OrderBy(order string) *Statement { | |||
// Desc generate `ORDER BY xx DESC` | |||
func (statement *Statement) Desc(colNames ...string) *Statement { | |||
var buf bytes.Buffer | |||
fmt.Fprintf(&buf, statement.OrderStr) | |||
var buf builder.StringBuilder | |||
if len(statement.OrderStr) > 0 { | |||
fmt.Fprint(&buf, ", ") | |||
fmt.Fprint(&buf, statement.OrderStr, ", ") | |||
} | |||
newColNames := statement.col2NewColsWithQuote(colNames...) | |||
fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) | |||
@@ -732,10 +717,9 @@ func (statement *Statement) Desc(colNames ...string) *Statement { | |||
// Asc provide asc order by query condition, the input parameters are columns. | |||
func (statement *Statement) Asc(colNames ...string) *Statement { | |||
var buf bytes.Buffer | |||
fmt.Fprintf(&buf, statement.OrderStr) | |||
var buf builder.StringBuilder | |||
if len(statement.OrderStr) > 0 { | |||
fmt.Fprint(&buf, ", ") | |||
fmt.Fprint(&buf, statement.OrderStr, ", ") | |||
} | |||
newColNames := statement.col2NewColsWithQuote(colNames...) | |||
fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) | |||
@@ -743,48 +727,35 @@ func (statement *Statement) Asc(colNames ...string) *Statement { | |||
return statement | |||
} | |||
// Table tempororily set table name, the parameter could be a string or a pointer of struct | |||
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { | |||
v := rValue(tableNameOrBean) | |||
t := v.Type() | |||
if t.Kind() == reflect.Struct { | |||
var err error | |||
statement.RefTable, err = statement.Engine.autoMapType(v) | |||
if err != nil { | |||
statement.Engine.logger.Error(err) | |||
return statement | |||
} | |||
} | |||
statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true) | |||
return statement | |||
} | |||
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN | |||
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { | |||
var buf bytes.Buffer | |||
var buf builder.StringBuilder | |||
if len(statement.JoinStr) > 0 { | |||
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) | |||
} else { | |||
fmt.Fprintf(&buf, "%v JOIN ", joinOP) | |||
} | |||
switch tablename.(type) { | |||
case []string: | |||
t := tablename.([]string) | |||
if len(t) > 1 { | |||
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1])) | |||
} else if len(t) == 1 { | |||
fmt.Fprintf(&buf, statement.Engine.Quote(t[0])) | |||
} | |||
case []interface{}: | |||
t := tablename.([]interface{}) | |||
l := len(t) | |||
var table string | |||
if l > 0 { | |||
f := t[0] | |||
v := rValue(f) | |||
t := v.Type() | |||
if t.Kind() == reflect.String { | |||
table = f.(string) | |||
} else if t.Kind() == reflect.Struct { | |||
table = statement.Engine.tbName(v) | |||
} | |||
} | |||
if l > 1 { | |||
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table), | |||
statement.Engine.Quote(fmt.Sprintf("%v", t[1]))) | |||
} else if l == 1 { | |||
fmt.Fprintf(&buf, statement.Engine.Quote(table)) | |||
} | |||
default: | |||
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) | |||
} | |||
tbName := statement.Engine.TableName(tablename, true) | |||
fmt.Fprintf(&buf, " ON %v", condition) | |||
fmt.Fprintf(&buf, "%s ON %v", tbName, condition) | |||
statement.JoinStr = buf.String() | |||
statement.joinArgs = append(statement.joinArgs, args...) | |||
return statement | |||
@@ -809,18 +780,20 @@ func (statement *Statement) Unscoped() *Statement { | |||
} | |||
func (statement *Statement) genColumnStr() string { | |||
var buf bytes.Buffer | |||
if statement.RefTable == nil { | |||
return "" | |||
} | |||
var buf builder.StringBuilder | |||
columns := statement.RefTable.Columns() | |||
for _, col := range columns { | |||
if statement.OmitStr != "" { | |||
if _, ok := getFlagForColumn(statement.columnMap, col); ok { | |||
continue | |||
} | |||
if statement.omitColumnMap.contain(col.Name) { | |||
continue | |||
} | |||
if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) { | |||
continue | |||
} | |||
if col.MapType == core.ONLYTODB { | |||
@@ -831,10 +804,6 @@ func (statement *Statement) genColumnStr() string { | |||
buf.WriteString(", ") | |||
} | |||
if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { | |||
buf.WriteString("id() AS ") | |||
} | |||
if statement.JoinStr != "" { | |||
if statement.TableAlias != "" { | |||
buf.WriteString(statement.TableAlias) | |||
@@ -859,11 +828,13 @@ func (statement *Statement) genCreateTableSQL() string { | |||
func (statement *Statement) genIndexSQL() []string { | |||
var sqls []string | |||
tbName := statement.TableName() | |||
quote := statement.Engine.Quote | |||
for idxName, index := range statement.RefTable.Indexes { | |||
for _, index := range statement.RefTable.Indexes { | |||
if index.Type == core.IndexType { | |||
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), | |||
quote(tbName), quote(strings.Join(index.Cols, quote(",")))) | |||
sql := statement.Engine.dialect.CreateIndexSql(tbName, index) | |||
/*idxTBName := strings.Replace(tbName, ".", "_", -1) | |||
idxTBName = strings.Replace(idxTBName, `"`, "", -1) | |||
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)), | |||
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/ | |||
sqls = append(sqls, sql) | |||
} | |||
} | |||
@@ -889,16 +860,18 @@ func (statement *Statement) genUniqueSQL() []string { | |||
func (statement *Statement) genDelIndexSQL() []string { | |||
var sqls []string | |||
tbName := statement.TableName() | |||
idxPrefixName := strings.Replace(tbName, `"`, "", -1) | |||
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1) | |||
for idxName, index := range statement.RefTable.Indexes { | |||
var rIdxName string | |||
if index.Type == core.UniqueType { | |||
rIdxName = uniqueName(tbName, idxName) | |||
rIdxName = uniqueName(idxPrefixName, idxName) | |||
} else if index.Type == core.IndexType { | |||
rIdxName = indexName(tbName, idxName) | |||
rIdxName = indexName(idxPrefixName, idxName) | |||
} | |||
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName)) | |||
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) | |||
if statement.Engine.dialect.IndexOnTable() { | |||
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName())) | |||
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) | |||
} | |||
sqls = append(sqls, sql) | |||
} | |||
@@ -949,7 +922,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, | |||
v := rValue(bean) | |||
isStruct := v.Kind() == reflect.Struct | |||
if isStruct { | |||
statement.setRefValue(v) | |||
statement.setRefBean(bean) | |||
} | |||
var columnStr = statement.ColumnStr | |||
@@ -982,13 +955,17 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, | |||
if err := statement.mergeConds(bean); err != nil { | |||
return "", nil, err | |||
} | |||
} else { | |||
if err := statement.processIDParam(); err != nil { | |||
return "", nil, err | |||
} | |||
} | |||
condSQL, condArgs, err := builder.ToSQL(statement.cond) | |||
if err != nil { | |||
return "", nil, err | |||
} | |||
sqlStr, err := statement.genSelectSQL(columnStr, condSQL) | |||
sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true) | |||
if err != nil { | |||
return "", nil, err | |||
} | |||
@@ -1001,7 +978,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa | |||
var condArgs []interface{} | |||
var err error | |||
if len(beans) > 0 { | |||
statement.setRefValue(rValue(beans[0])) | |||
statement.setRefBean(beans[0]) | |||
condSQL, condArgs, err = statement.genConds(beans[0]) | |||
} else { | |||
condSQL, condArgs, err = builder.ToSQL(statement.cond) | |||
@@ -1018,7 +995,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa | |||
selectSQL = "count(*)" | |||
} | |||
} | |||
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL) | |||
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false) | |||
if err != nil { | |||
return "", nil, err | |||
} | |||
@@ -1027,7 +1004,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa | |||
} | |||
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { | |||
statement.setRefValue(rValue(bean)) | |||
statement.setRefBean(bean) | |||
var sumStrs = make([]string, 0, len(columns)) | |||
for _, colName := range columns { | |||
@@ -1043,7 +1020,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri | |||
return "", nil, err | |||
} | |||
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL) | |||
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true) | |||
if err != nil { | |||
return "", nil, err | |||
} | |||
@@ -1051,27 +1028,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri | |||
return sqlStr, append(statement.joinArgs, condArgs...), nil | |||
} | |||
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) { | |||
var distinct string | |||
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { | |||
var ( | |||
distinct string | |||
dialect = statement.Engine.Dialect() | |||
quote = statement.Engine.Quote | |||
fromStr = " FROM " | |||
top, mssqlCondi, whereStr string | |||
) | |||
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { | |||
distinct = "DISTINCT " | |||
} | |||
var dialect = statement.Engine.Dialect() | |||
var quote = statement.Engine.Quote | |||
var top string | |||
var mssqlCondi string | |||
if err := statement.processIDParam(); err != nil { | |||
return "", err | |||
} | |||
var buf bytes.Buffer | |||
if len(condSQL) > 0 { | |||
fmt.Fprintf(&buf, " WHERE %v", condSQL) | |||
whereStr = " WHERE " + condSQL | |||
} | |||
var whereStr = buf.String() | |||
var fromStr = " FROM " | |||
if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") { | |||
fromStr += statement.TableName() | |||
@@ -1118,9 +1088,10 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e | |||
} | |||
var orderStr string | |||
if len(statement.OrderStr) > 0 { | |||
if needOrderBy && len(statement.OrderStr) > 0 { | |||
orderStr = " ORDER BY " + statement.OrderStr | |||
} | |||
var groupStr string | |||
if len(statement.GroupByStr) > 0 { | |||
groupStr = " GROUP BY " + statement.GroupByStr | |||
@@ -1130,45 +1101,50 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e | |||
} | |||
} | |||
// !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern | |||
a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) | |||
var buf builder.StringBuilder | |||
fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) | |||
if len(mssqlCondi) > 0 { | |||
if len(whereStr) > 0 { | |||
a += " AND " + mssqlCondi | |||
fmt.Fprint(&buf, " AND ", mssqlCondi) | |||
} else { | |||
a += " WHERE " + mssqlCondi | |||
fmt.Fprint(&buf, " WHERE ", mssqlCondi) | |||
} | |||
} | |||
if statement.GroupByStr != "" { | |||
a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr) | |||
fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) | |||
} | |||
if statement.HavingStr != "" { | |||
a = fmt.Sprintf("%v %v", a, statement.HavingStr) | |||
fmt.Fprint(&buf, " ", statement.HavingStr) | |||
} | |||
if statement.OrderStr != "" { | |||
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) | |||
if needOrderBy && statement.OrderStr != "" { | |||
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) | |||
} | |||
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { | |||
if statement.Start > 0 { | |||
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) | |||
} else if statement.LimitN > 0 { | |||
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) | |||
} | |||
} else if dialect.DBType() == core.ORACLE { | |||
if statement.Start != 0 || statement.LimitN != 0 { | |||
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) | |||
if needLimit { | |||
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { | |||
if statement.Start > 0 { | |||
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start) | |||
} else if statement.LimitN > 0 { | |||
fmt.Fprint(&buf, " LIMIT ", statement.LimitN) | |||
} | |||
} else if dialect.DBType() == core.ORACLE { | |||
if statement.Start != 0 || statement.LimitN != 0 { | |||
oldString := buf.String() | |||
buf.Reset() | |||
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", | |||
columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start) | |||
} | |||
} | |||
} | |||
if statement.IsForUpdate { | |||
a = dialect.ForUpdateSql(a) | |||
return dialect.ForUpdateSql(buf.String()), nil | |||
} | |||
return | |||
return buf.String(), nil | |||
} | |||
func (statement *Statement) processIDParam() error { | |||
if statement.idParam == nil { | |||
if statement.idParam == nil || statement.RefTable == nil { | |||
return nil | |||
} | |||
@@ -17,7 +17,7 @@ import ( | |||
const ( | |||
// Version show the xorm's version | |||
Version string = "0.6.4.0910" | |||
Version string = "0.7.0.0504" | |||
) | |||
func regDrvsNDialects() bool { | |||
@@ -31,7 +31,7 @@ func regDrvsNDialects() bool { | |||
"mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }}, | |||
"mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }}, | |||
"postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, | |||
"pgx": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, | |||
"pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }}, | |||
"sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }}, | |||
"oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }}, | |||
"goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }}, | |||
@@ -90,6 +90,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { | |||
TagIdentifier: "xorm", | |||
TZLocation: time.Local, | |||
tagHandlers: defaultTagHandlers, | |||
cachers: make(map[string]core.Cacher), | |||
} | |||
if uri.DbType == core.SQLITE { | |||
@@ -108,6 +109,13 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { | |||
return engine, nil | |||
} | |||
// NewEngineWithParams new a db manager with params. The params will be passed to dialect. | |||
func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) { | |||
engine, err := NewEngine(driverName, dataSourceName) | |||
engine.dialect.SetParams(params) | |||
return engine, err | |||
} | |||
// Clone clone an engine | |||
func (engine *Engine) Clone() (*Engine, error) { | |||
return NewEngine(engine.DriverName(), engine.DataSourceName()) |