diff options
Diffstat (limited to 'vendor/xorm.io/xorm/session_update.go')
-rw-r--r-- | vendor/xorm.io/xorm/session_update.go | 220 |
1 files changed, 102 insertions, 118 deletions
diff --git a/vendor/xorm.io/xorm/session_update.go b/vendor/xorm.io/xorm/session_update.go index 47ced66d19..62116c473c 100644 --- a/vendor/xorm.io/xorm/session_update.go +++ b/vendor/xorm.io/xorm/session_update.go @@ -12,23 +12,25 @@ import ( "strings" "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/caches" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) -func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error { +func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { if table == nil || session.tx != nil { return ErrCacheFailed } - oldhead, newsql := session.statement.convertUpdateSQL(sqlStr) + oldhead, newsql := session.statement.ConvertUpdateSQL(sqlStr) if newsql == "" { return ErrCacheFailed } for _, filter := range session.engine.dialect.Filters() { - newsql = filter.Do(newsql, session.engine.dialect, table) + newsql = filter.Do(newsql) } - session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql) + session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql) var nStart int if len(args) > 0 { @@ -40,9 +42,9 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, } } - cacher := session.engine.getCacher(tableName) - session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) - ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) + cacher := session.engine.GetCacher(tableName) + session.engine.logger.Debugf("[cache] get cache sql: %v, %v", newsql, args[nStart:]) + ids, err := caches.GetCacheSql(cacher, tableName, newsql, args[nStart:]) if err != nil { rows, err := session.NoCache().queryRows(newsql, args[nStart:]...) if err != nil { @@ -50,14 +52,14 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, } defer rows.Close() - ids = make([]core.PK, 0) + ids = make([]schemas.PK, 0) for rows.Next() { var res = make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { return err } - var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) + var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { if col.SQLType.IsNumeric() { n, err := strconv.ParseInt(res[i], 10, 64) @@ -74,7 +76,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, ids = append(ids, pk) } - session.engine.logger.Debug("[cacheUpdate] find updated id", ids) + session.engine.logger.Debugf("[cache] find updated id: %v", ids) } /*else { session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) cacher.DelIds(tableName, genSqlKey(newsql, args)) @@ -86,12 +88,12 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, return err } if bean := cacher.GetBean(tableName, sid); bean != nil { - sqls := splitNNoCase(sqlStr, "where", 2) + sqls := utils.SplitNNoCase(sqlStr, "where", 2) if len(sqls) == 0 || len(sqls) > 2 { return ErrCacheFailed } - sqls = splitNNoCase(sqls[0], "set", 2) + sqls = utils.SplitNNoCase(sqls[0], "set", 2) if len(sqls) != 2 { return ErrCacheFailed } @@ -101,38 +103,32 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, sps := strings.SplitN(kv, "=", 2) sps2 := strings.Split(sps[0], ".") colName := sps2[len(sps2)-1] - // treat quote prefix, suffix and '`' as quotes - quotes := append(strings.Split(session.engine.Quote(""), ""), "`") - if strings.ContainsAny(colName, strings.Join(quotes, "")) { - colName = strings.TrimSpace(eraseAny(colName, quotes...)) - } else { - session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) - return ErrCacheFailed - } + colName = session.engine.dialect.Quoter().Trim(colName) + colName = schemas.CommonQuoter.Trim(colName) if col := table.GetColumn(colName); col != nil { fieldValue, err := col.ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else { - session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) - if col.IsVersion && session.statement.checkVersion { + session.engine.logger.Debugf("[cache] set bean field: %v, %v, %v", bean, colName, fieldValue.Interface()) + if col.IsVersion && session.statement.CheckVersion { session.incrVersionFieldValue(fieldValue) } else { fieldValue.Set(reflect.ValueOf(args[idx])) } } } else { - session.engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's", + session.engine.logger.Errorf("[cache] ERROR: column %v is not table %v's", colName, table.Name) } } - session.engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean) + session.engine.logger.Debugf("[cache] update cache: %v, %v, %v", tableName, id, bean) cacher.PutBean(tableName, sid, bean) } } - session.engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName) + session.engine.logger.Debugf("[cache] clear cached table sql: %v", tableName) cacher.ClearIds(tableName) return nil } @@ -148,11 +144,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 defer session.Close() } - if session.statement.lastError != nil { - return 0, session.statement.lastError + if session.statement.LastError != nil { + return 0, session.statement.LastError } - v := rValue(bean) + v := utils.ReflectValue(bean) t := v.Type() var colNames []string @@ -172,7 +168,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var isMap = t.Kind() == reflect.Map var isStruct = t.Kind() == reflect.Struct if isStruct { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } @@ -180,14 +176,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, ErrTableNotFound } - if session.statement.ColumnStr == "" { - colNames, args = session.statement.buildUpdates(bean, false, false, + if session.statement.ColumnStr() == "" { + colNames, args, err = session.statement.BuildUpdates(v, false, false, false, false, true) } else { colNames, args, err = session.genUpdateColumns(bean) - if err != nil { - return 0, err - } + } + if err != nil { + return 0, err } } else if isMap { colNames = make([]string, 0) @@ -205,8 +201,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 table := session.statement.RefTable if session.statement.UseAutoTime && table != nil && table.Updated != "" { - if !session.statement.columnMap.contain(table.Updated) && - !session.statement.omitColumnMap.contain(table.Updated) { + if !session.statement.ColumnMap.Contain(table.Updated) && + !session.statement.OmitColumnMap.Contain(table.Updated) { colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") col := table.UpdatedColumn() val, t := session.engine.nowTime(col) @@ -223,28 +219,28 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } // for update action to like "column = column + ?" - incColumns := session.statement.incrColumns - for i, colName := range incColumns.colNames { + incColumns := session.statement.IncrColumns + for i, colName := range incColumns.ColNames { colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") - args = append(args, incColumns.args[i]) + args = append(args, incColumns.Args[i]) } // for update action to like "column = column - ?" - decColumns := session.statement.decrColumns - for i, colName := range decColumns.colNames { + decColumns := session.statement.DecrColumns + for i, colName := range decColumns.ColNames { colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") - args = append(args, decColumns.args[i]) + args = append(args, decColumns.Args[i]) } // for update action to like "column = expression" - exprColumns := session.statement.exprColumns - for i, colName := range exprColumns.colNames { - switch tp := exprColumns.args[i].(type) { + exprColumns := session.statement.ExprColumns + for i, colName := range exprColumns.ColNames { + switch tp := exprColumns.Args[i].(type) { case string: if len(tp) == 0 { tp = "''" } colNames = append(colNames, session.engine.Quote(colName)+"="+tp) case *builder.Builder: - subQuery, subArgs, err := builder.ToSQL(tp) + subQuery, subArgs, err := session.statement.GenCondSQL(tp) if err != nil { return 0, err } @@ -252,16 +248,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 args = append(args, subArgs...) default: colNames = append(colNames, session.engine.Quote(colName)+"=?") - args = append(args, exprColumns.args[i]) + args = append(args, exprColumns.Args[i]) } } - if err = session.statement.processIDParam(); err != nil { + if err = session.statement.ProcessIDParam(); err != nil { return 0, err } var autoCond builder.Cond - if !session.statement.noAutoCondition { + if !session.statement.NoAutoCondition { condBeanIsStruct := false if len(condiBean) > 0 { if c, ok := condiBean[0].(map[string]interface{}); ok { @@ -274,7 +270,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if k == reflect.Struct { var err error - autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) + autoCond, err = session.statement.BuildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) if err != nil { return 0, err } @@ -286,8 +282,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if !condBeanIsStruct && table != nil { - if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled - autoCond1 := session.engine.CondDeleted(session.engine.Quote(col.Name)) + if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + autoCond1 := session.statement.CondDeleted(col) if autoCond == nil { autoCond = autoCond1 @@ -298,26 +294,34 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - st := &session.statement + st := session.statement - var sqlStr string - var condArgs []interface{} - var condSQL string - cond := session.statement.cond.And(autoCond) + var ( + sqlStr string + condArgs []interface{} + condSQL string + cond = session.statement.Conds().And(autoCond) - var doIncVer = (table != nil && table.Version != "" && session.statement.checkVersion) - var verValue *reflect.Value + doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) + verValue *reflect.Value + ) if doIncVer { verValue, err = table.VersionColumn().ValueOf(bean) if err != nil { return 0, err } - cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) - colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1") + if verValue != nil { + cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) + colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1") + } } - condSQL, condArgs, err = builder.ToSQL(cond) + if len(colNames) <= 0 { + return 0, errors.New("No content found to be updated") + } + + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -333,25 +337,27 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var tableName = session.statement.TableName() // TODO: Oracle support needed var top string - if st.LimitN > 0 { - if st.Engine.dialect.DBType() == core.MYSQL { - condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) - } else if st.Engine.dialect.DBType() == core.SQLITE { - tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + if st.LimitN != nil { + limitValue := *st.LimitN + switch session.engine.dialect.URI().DBType { + case schemas.MYSQL: + condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) + case schemas.SQLITE: + tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if st.Engine.dialect.DBType() == core.POSTGRES { - tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + case schemas.POSTGRES: + tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -359,14 +365,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if st.Engine.dialect.DBType() == core.MSSQL { - if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL && - table != nil && len(table.PrimaryKeys) == 1 { + case schemas.MSSQL: + if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", - table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0], + table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], session.engine.Quote(tableName), condSQL), condArgs...) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -374,20 +379,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condSQL = "WHERE " + condSQL } } else { - top = fmt.Sprintf("TOP (%d) ", st.LimitN) + top = fmt.Sprintf("TOP (%d) ", limitValue) } } } - if len(colNames) <= 0 { - return 0, errors.New("No content found to be updated") - } - var tableAlias = session.engine.Quote(tableName) var fromSQL string if session.statement.TableAlias != "" { - switch session.engine.dialect.DBType() { - case core.MSSQL: + switch session.engine.dialect.URI().DBType { + case schemas.MSSQL: fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias) tableAlias = session.statement.TableAlias default: @@ -411,9 +412,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { + if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache { // session.cacheUpdate(table, tableName, sqlStr, args...) - session.engine.logger.Debug("[cacheUpdate] clear table ", tableName) + session.engine.logger.Debugf("[cache] clear table: %v", tableName) cacher.ClearIds(tableName) cacher.ClearBeans(tableName) } @@ -424,7 +425,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 closure(bean) } if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { - session.engine.logger.Debug("[event]", tableName, " has after update processor") + session.engine.logger.Debugf("[event] %v has after update processor", tableName) processor.AfterUpdate() } } else { @@ -458,11 +459,11 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac for _, col := range table.Columns() { if !col.IsVersion && !col.IsCreated && !col.IsUpdated { - if session.statement.omitColumnMap.contain(col.Name) { + if session.statement.OmitColumnMap.Contain(col.Name) { continue } } - if col.MapType == core.ONLYFROMDB { + if col.MapType == schemas.ONLYFROMDB { continue } @@ -472,47 +473,30 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac } fieldValue := *fieldValuePtr - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } + if col.IsAutoIncrement && utils.IsValueZero(fieldValue) { + continue } - if (col.IsDeleted && !session.statement.unscoped) || col.IsCreated { + if (col.IsDeleted && !session.statement.GetUnscoped()) || col.IsCreated { continue } // if only update specify columns - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } - if session.statement.incrColumns.isColExist(col.Name) { + if session.statement.IncrColumns.IsColExist(col.Name) { continue - } else if session.statement.decrColumns.isColExist(col.Name) { + } else if session.statement.DecrColumns.IsColExist(col.Name) { continue - } else if session.statement.exprColumns.isColExist(col.Name) { + } else if session.statement.ExprColumns.IsColExist(col.Name) { continue } // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { - if col.Nullable && isZero(fieldValue.Interface()) { + if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok { + if col.Nullable && utils.IsValueZero(fieldValue) { var nilValue *int fieldValue = reflect.ValueOf(nilValue) } @@ -528,10 +512,10 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.statement.checkVersion { + } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) } else { - arg, err := session.value2Interface(col, fieldValue) + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return colNames, args, err } |