summaryrefslogtreecommitdiffstats
path: root/vendor/xorm.io/xorm/session_update.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/xorm.io/xorm/session_update.go')
-rw-r--r--vendor/xorm.io/xorm/session_update.go220
1 files changed, 102 insertions, 118 deletions
diff --git a/vendor/xorm.io/xorm/session_update.go b/vendor/xorm.io/xorm/session_update.go
index 47ced66d19..62116c473c 100644
--- a/vendor/xorm.io/xorm/session_update.go
+++ b/vendor/xorm.io/xorm/session_update.go
@@ -12,23 +12,25 @@ import (
"strings"
"xorm.io/builder"
- "xorm.io/core"
+ "xorm.io/xorm/caches"
+ "xorm.io/xorm/internal/utils"
+ "xorm.io/xorm/schemas"
)
-func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error {
+func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error {
if table == nil ||
session.tx != nil {
return ErrCacheFailed
}
- oldhead, newsql := session.statement.convertUpdateSQL(sqlStr)
+ oldhead, newsql := session.statement.ConvertUpdateSQL(sqlStr)
if newsql == "" {
return ErrCacheFailed
}
for _, filter := range session.engine.dialect.Filters() {
- newsql = filter.Do(newsql, session.engine.dialect, table)
+ newsql = filter.Do(newsql)
}
- session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
+ session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql)
var nStart int
if len(args) > 0 {
@@ -40,9 +42,9 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
}
}
- cacher := session.engine.getCacher(tableName)
- session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
- ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
+ cacher := session.engine.GetCacher(tableName)
+ session.engine.logger.Debugf("[cache] get cache sql: %v, %v", newsql, args[nStart:])
+ ids, err := caches.GetCacheSql(cacher, tableName, newsql, args[nStart:])
if err != nil {
rows, err := session.NoCache().queryRows(newsql, args[nStart:]...)
if err != nil {
@@ -50,14 +52,14 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
}
defer rows.Close()
- ids = make([]core.PK, 0)
+ ids = make([]schemas.PK, 0)
for rows.Next() {
var res = make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res)
if err != nil {
return err
}
- var pk core.PK = make([]interface{}, len(table.PrimaryKeys))
+ var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
if col.SQLType.IsNumeric() {
n, err := strconv.ParseInt(res[i], 10, 64)
@@ -74,7 +76,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
ids = append(ids, pk)
}
- session.engine.logger.Debug("[cacheUpdate] find updated id", ids)
+ session.engine.logger.Debugf("[cache] find updated id: %v", ids)
} /*else {
session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)
cacher.DelIds(tableName, genSqlKey(newsql, args))
@@ -86,12 +88,12 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
return err
}
if bean := cacher.GetBean(tableName, sid); bean != nil {
- sqls := splitNNoCase(sqlStr, "where", 2)
+ sqls := utils.SplitNNoCase(sqlStr, "where", 2)
if len(sqls) == 0 || len(sqls) > 2 {
return ErrCacheFailed
}
- sqls = splitNNoCase(sqls[0], "set", 2)
+ sqls = utils.SplitNNoCase(sqls[0], "set", 2)
if len(sqls) != 2 {
return ErrCacheFailed
}
@@ -101,38 +103,32 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".")
colName := sps2[len(sps2)-1]
- // treat quote prefix, suffix and '`' as quotes
- quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
- if strings.ContainsAny(colName, strings.Join(quotes, "")) {
- colName = strings.TrimSpace(eraseAny(colName, quotes...))
- } else {
- session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
- return ErrCacheFailed
- }
+ colName = session.engine.dialect.Quoter().Trim(colName)
+ colName = schemas.CommonQuoter.Trim(colName)
if col := table.GetColumn(colName); col != nil {
fieldValue, err := col.ValueOf(bean)
if err != nil {
- session.engine.logger.Error(err)
+ session.engine.logger.Errorf("%v", err)
} else {
- session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
- if col.IsVersion && session.statement.checkVersion {
+ session.engine.logger.Debugf("[cache] set bean field: %v, %v, %v", bean, colName, fieldValue.Interface())
+ if col.IsVersion && session.statement.CheckVersion {
session.incrVersionFieldValue(fieldValue)
} else {
fieldValue.Set(reflect.ValueOf(args[idx]))
}
}
} else {
- session.engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's",
+ session.engine.logger.Errorf("[cache] ERROR: column %v is not table %v's",
colName, table.Name)
}
}
- session.engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean)
+ session.engine.logger.Debugf("[cache] update cache: %v, %v, %v", tableName, id, bean)
cacher.PutBean(tableName, sid, bean)
}
}
- session.engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName)
+ session.engine.logger.Debugf("[cache] clear cached table sql: %v", tableName)
cacher.ClearIds(tableName)
return nil
}
@@ -148,11 +144,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
defer session.Close()
}
- if session.statement.lastError != nil {
- return 0, session.statement.lastError
+ if session.statement.LastError != nil {
+ return 0, session.statement.LastError
}
- v := rValue(bean)
+ v := utils.ReflectValue(bean)
t := v.Type()
var colNames []string
@@ -172,7 +168,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct
if isStruct {
- if err := session.statement.setRefBean(bean); err != nil {
+ if err := session.statement.SetRefBean(bean); err != nil {
return 0, err
}
@@ -180,14 +176,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, ErrTableNotFound
}
- if session.statement.ColumnStr == "" {
- colNames, args = session.statement.buildUpdates(bean, false, false,
+ if session.statement.ColumnStr() == "" {
+ colNames, args, err = session.statement.BuildUpdates(v, false, false,
false, false, true)
} else {
colNames, args, err = session.genUpdateColumns(bean)
- if err != nil {
- return 0, err
- }
+ }
+ if err != nil {
+ return 0, err
}
} else if isMap {
colNames = make([]string, 0)
@@ -205,8 +201,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
table := session.statement.RefTable
if session.statement.UseAutoTime && table != nil && table.Updated != "" {
- if !session.statement.columnMap.contain(table.Updated) &&
- !session.statement.omitColumnMap.contain(table.Updated) {
+ if !session.statement.ColumnMap.Contain(table.Updated) &&
+ !session.statement.OmitColumnMap.Contain(table.Updated) {
colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?")
col := table.UpdatedColumn()
val, t := session.engine.nowTime(col)
@@ -223,28 +219,28 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
// for update action to like "column = column + ?"
- incColumns := session.statement.incrColumns
- for i, colName := range incColumns.colNames {
+ incColumns := session.statement.IncrColumns
+ for i, colName := range incColumns.ColNames {
colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?")
- args = append(args, incColumns.args[i])
+ args = append(args, incColumns.Args[i])
}
// for update action to like "column = column - ?"
- decColumns := session.statement.decrColumns
- for i, colName := range decColumns.colNames {
+ decColumns := session.statement.DecrColumns
+ for i, colName := range decColumns.ColNames {
colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?")
- args = append(args, decColumns.args[i])
+ args = append(args, decColumns.Args[i])
}
// for update action to like "column = expression"
- exprColumns := session.statement.exprColumns
- for i, colName := range exprColumns.colNames {
- switch tp := exprColumns.args[i].(type) {
+ exprColumns := session.statement.ExprColumns
+ for i, colName := range exprColumns.ColNames {
+ switch tp := exprColumns.Args[i].(type) {
case string:
if len(tp) == 0 {
tp = "''"
}
colNames = append(colNames, session.engine.Quote(colName)+"="+tp)
case *builder.Builder:
- subQuery, subArgs, err := builder.ToSQL(tp)
+ subQuery, subArgs, err := session.statement.GenCondSQL(tp)
if err != nil {
return 0, err
}
@@ -252,16 +248,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
args = append(args, subArgs...)
default:
colNames = append(colNames, session.engine.Quote(colName)+"=?")
- args = append(args, exprColumns.args[i])
+ args = append(args, exprColumns.Args[i])
}
}
- if err = session.statement.processIDParam(); err != nil {
+ if err = session.statement.ProcessIDParam(); err != nil {
return 0, err
}
var autoCond builder.Cond
- if !session.statement.noAutoCondition {
+ if !session.statement.NoAutoCondition {
condBeanIsStruct := false
if len(condiBean) > 0 {
if c, ok := condiBean[0].(map[string]interface{}); ok {
@@ -274,7 +270,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
if k == reflect.Struct {
var err error
- autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
+ autoCond, err = session.statement.BuildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
if err != nil {
return 0, err
}
@@ -286,8 +282,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
if !condBeanIsStruct && table != nil {
- if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled
- autoCond1 := session.engine.CondDeleted(session.engine.Quote(col.Name))
+ if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
+ autoCond1 := session.statement.CondDeleted(col)
if autoCond == nil {
autoCond = autoCond1
@@ -298,26 +294,34 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}
- st := &session.statement
+ st := session.statement
- var sqlStr string
- var condArgs []interface{}
- var condSQL string
- cond := session.statement.cond.And(autoCond)
+ var (
+ sqlStr string
+ condArgs []interface{}
+ condSQL string
+ cond = session.statement.Conds().And(autoCond)
- var doIncVer = (table != nil && table.Version != "" && session.statement.checkVersion)
- var verValue *reflect.Value
+ doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion)
+ verValue *reflect.Value
+ )
if doIncVer {
verValue, err = table.VersionColumn().ValueOf(bean)
if err != nil {
return 0, err
}
- cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
- colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
+ if verValue != nil {
+ cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
+ colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
+ }
}
- condSQL, condArgs, err = builder.ToSQL(cond)
+ if len(colNames) <= 0 {
+ return 0, errors.New("No content found to be updated")
+ }
+
+ condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil {
return 0, err
}
@@ -333,25 +337,27 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var tableName = session.statement.TableName()
// TODO: Oracle support needed
var top string
- if st.LimitN > 0 {
- if st.Engine.dialect.DBType() == core.MYSQL {
- condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
- } else if st.Engine.dialect.DBType() == core.SQLITE {
- tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+ if st.LimitN != nil {
+ limitValue := *st.LimitN
+ switch session.engine.dialect.URI().DBType {
+ case schemas.MYSQL:
+ condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
+ case schemas.SQLITE:
+ tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
- condSQL, condArgs, err = builder.ToSQL(cond)
+ condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil {
return 0, err
}
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
- } else if st.Engine.dialect.DBType() == core.POSTGRES {
- tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+ case schemas.POSTGRES:
+ tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
- condSQL, condArgs, err = builder.ToSQL(cond)
+ condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil {
return 0, err
}
@@ -359,14 +365,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL
}
- } else if st.Engine.dialect.DBType() == core.MSSQL {
- if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL &&
- table != nil && len(table.PrimaryKeys) == 1 {
+ case schemas.MSSQL:
+ if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
- table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
+ table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
session.engine.Quote(tableName), condSQL), condArgs...)
- condSQL, condArgs, err = builder.ToSQL(cond)
+ condSQL, condArgs, err = session.statement.GenCondSQL(cond)
if err != nil {
return 0, err
}
@@ -374,20 +379,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condSQL = "WHERE " + condSQL
}
} else {
- top = fmt.Sprintf("TOP (%d) ", st.LimitN)
+ top = fmt.Sprintf("TOP (%d) ", limitValue)
}
}
}
- if len(colNames) <= 0 {
- return 0, errors.New("No content found to be updated")
- }
-
var tableAlias = session.engine.Quote(tableName)
var fromSQL string
if session.statement.TableAlias != "" {
- switch session.engine.dialect.DBType() {
- case core.MSSQL:
+ switch session.engine.dialect.URI().DBType {
+ case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias)
tableAlias = session.statement.TableAlias
default:
@@ -411,9 +412,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}
- if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
+ if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache {
// session.cacheUpdate(table, tableName, sqlStr, args...)
- session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
+ session.engine.logger.Debugf("[cache] clear table: %v", tableName)
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
}
@@ -424,7 +425,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
closure(bean)
}
if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok {
- session.engine.logger.Debug("[event]", tableName, " has after update processor")
+ session.engine.logger.Debugf("[event] %v has after update processor", tableName)
processor.AfterUpdate()
}
} else {
@@ -458,11 +459,11 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
for _, col := range table.Columns() {
if !col.IsVersion && !col.IsCreated && !col.IsUpdated {
- if session.statement.omitColumnMap.contain(col.Name) {
+ if session.statement.OmitColumnMap.Contain(col.Name) {
continue
}
}
- if col.MapType == core.ONLYFROMDB {
+ if col.MapType == schemas.ONLYFROMDB {
continue
}
@@ -472,47 +473,30 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
}
fieldValue := *fieldValuePtr
- if col.IsAutoIncrement {
- switch fieldValue.Type().Kind() {
- case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
- if fieldValue.Int() == 0 {
- continue
- }
- case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
- if fieldValue.Uint() == 0 {
- continue
- }
- case reflect.String:
- if len(fieldValue.String()) == 0 {
- continue
- }
- case reflect.Ptr:
- if fieldValue.Pointer() == 0 {
- continue
- }
- }
+ if col.IsAutoIncrement && utils.IsValueZero(fieldValue) {
+ continue
}
- if (col.IsDeleted && !session.statement.unscoped) || col.IsCreated {
+ if (col.IsDeleted && !session.statement.GetUnscoped()) || col.IsCreated {
continue
}
// if only update specify columns
- if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
+ if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
continue
}
- if session.statement.incrColumns.isColExist(col.Name) {
+ if session.statement.IncrColumns.IsColExist(col.Name) {
continue
- } else if session.statement.decrColumns.isColExist(col.Name) {
+ } else if session.statement.DecrColumns.IsColExist(col.Name) {
continue
- } else if session.statement.exprColumns.isColExist(col.Name) {
+ } else if session.statement.ExprColumns.IsColExist(col.Name) {
continue
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
- if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
- if col.Nullable && isZero(fieldValue.Interface()) {
+ if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok {
+ if col.Nullable && utils.IsValueZero(fieldValue) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
@@ -528,10 +512,10 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
- } else if col.IsVersion && session.statement.checkVersion {
+ } else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1)
} else {
- arg, err := session.value2Interface(col, fieldValue)
+ arg, err := session.statement.Value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}