summaryrefslogtreecommitdiffstats
path: root/vendor/xorm.io/xorm/session_update.go
diff options
context:
space:
mode:
authorLunny Xiao <xiaolunwen@gmail.com>2019-10-17 17:26:49 +0800
committerAntoine GIRARD <sapk@users.noreply.github.com>2019-10-17 11:26:49 +0200
commitd151503d3428d61b5b3cb27ddbe849d3a6f288eb (patch)
treef5c1346d6ddb4f3584dc089188a557cd75a07dc6 /vendor/xorm.io/xorm/session_update.go
parentae132632a9847c3d304b3bb7b8481a1d0320ab20 (diff)
downloadgitea-d151503d3428d61b5b3cb27ddbe849d3a6f288eb.tar.gz
gitea-d151503d3428d61b5b3cb27ddbe849d3a6f288eb.zip
Upgrade xorm to v0.8.0 (#8536)
Diffstat (limited to 'vendor/xorm.io/xorm/session_update.go')
-rw-r--r--vendor/xorm.io/xorm/session_update.go525
1 files changed, 525 insertions, 0 deletions
diff --git a/vendor/xorm.io/xorm/session_update.go b/vendor/xorm.io/xorm/session_update.go
new file mode 100644
index 0000000000..c5c65a452a
--- /dev/null
+++ b/vendor/xorm.io/xorm/session_update.go
@@ -0,0 +1,525 @@
+// Copyright 2016 The Xorm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xorm
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "strconv"
+ "strings"
+
+ "xorm.io/builder"
+ "xorm.io/core"
+)
+
+func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error {
+ if table == nil ||
+ session.tx != nil {
+ return ErrCacheFailed
+ }
+
+ oldhead, newsql := session.statement.convertUpdateSQL(sqlStr)
+ if newsql == "" {
+ return ErrCacheFailed
+ }
+ for _, filter := range session.engine.dialect.Filters() {
+ newsql = filter.Do(newsql, session.engine.dialect, table)
+ }
+ session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
+
+ var nStart int
+ if len(args) > 0 {
+ if strings.Index(sqlStr, "?") > -1 {
+ nStart = strings.Count(oldhead, "?")
+ } else {
+ // only for pq, TODO: if any other databse?
+ nStart = strings.Count(oldhead, "$")
+ }
+ }
+
+ cacher := session.engine.getCacher(tableName)
+ session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
+ ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
+ if err != nil {
+ rows, err := session.NoCache().queryRows(newsql, args[nStart:]...)
+ if err != nil {
+ return err
+ }
+ defer rows.Close()
+
+ ids = make([]core.PK, 0)
+ for rows.Next() {
+ var res = make([]string, len(table.PrimaryKeys))
+ err = rows.ScanSlice(&res)
+ if err != nil {
+ return err
+ }
+ var pk core.PK = make([]interface{}, len(table.PrimaryKeys))
+ for i, col := range table.PKColumns() {
+ if col.SQLType.IsNumeric() {
+ n, err := strconv.ParseInt(res[i], 10, 64)
+ if err != nil {
+ return err
+ }
+ pk[i] = n
+ } else if col.SQLType.IsText() {
+ pk[i] = res[i]
+ } else {
+ return errors.New("not supported")
+ }
+ }
+
+ ids = append(ids, pk)
+ }
+ session.engine.logger.Debug("[cacheUpdate] find updated id", ids)
+ } /*else {
+ session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)
+ cacher.DelIds(tableName, genSqlKey(newsql, args))
+ }*/
+
+ for _, id := range ids {
+ sid, err := id.ToString()
+ if err != nil {
+ return err
+ }
+ if bean := cacher.GetBean(tableName, sid); bean != nil {
+ sqls := splitNNoCase(sqlStr, "where", 2)
+ if len(sqls) == 0 || len(sqls) > 2 {
+ return ErrCacheFailed
+ }
+
+ sqls = splitNNoCase(sqls[0], "set", 2)
+ if len(sqls) != 2 {
+ return ErrCacheFailed
+ }
+ kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
+
+ for idx, kv := range kvs {
+ sps := strings.SplitN(kv, "=", 2)
+ sps2 := strings.Split(sps[0], ".")
+ colName := sps2[len(sps2)-1]
+ // treat quote prefix, suffix and '`' as quotes
+ quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
+ if strings.ContainsAny(colName, strings.Join(quotes, "")) {
+ colName = strings.TrimSpace(eraseAny(colName, quotes...))
+ } else {
+ session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
+ return ErrCacheFailed
+ }
+
+ if col := table.GetColumn(colName); col != nil {
+ fieldValue, err := col.ValueOf(bean)
+ if err != nil {
+ session.engine.logger.Error(err)
+ } else {
+ session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
+ if col.IsVersion && session.statement.checkVersion {
+ session.incrVersionFieldValue(fieldValue)
+ } else {
+ fieldValue.Set(reflect.ValueOf(args[idx]))
+ }
+ }
+ } else {
+ session.engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's",
+ colName, table.Name)
+ }
+ }
+
+ session.engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean)
+ cacher.PutBean(tableName, sid, bean)
+ }
+ }
+ session.engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName)
+ cacher.ClearIds(tableName)
+ return nil
+}
+
+// Update records, bean's non-empty fields are updated contents,
+// condiBean' non-empty filds are conditions
+// CAUTION:
+// 1.bool will defaultly be updated content nor conditions
+// You should call UseBool if you have bool to use.
+// 2.float32 & float64 may be not inexact as conditions
+func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
+ if session.isAutoClose {
+ defer session.Close()
+ }
+
+ if session.statement.lastError != nil {
+ return 0, session.statement.lastError
+ }
+
+ v := rValue(bean)
+ t := v.Type()
+
+ var colNames []string
+ var args []interface{}
+
+ // handle before update processors
+ for _, closure := range session.beforeClosures {
+ closure(bean)
+ }
+ cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
+ if processor, ok := interface{}(bean).(BeforeUpdateProcessor); ok {
+ processor.BeforeUpdate()
+ }
+ // --
+
+ var err error
+ var isMap = t.Kind() == reflect.Map
+ var isStruct = t.Kind() == reflect.Struct
+ if isStruct {
+ if err := session.statement.setRefBean(bean); err != nil {
+ return 0, err
+ }
+
+ if len(session.statement.TableName()) <= 0 {
+ return 0, ErrTableNotFound
+ }
+
+ if session.statement.ColumnStr == "" {
+ colNames, args = session.statement.buildUpdates(bean, false, false,
+ false, false, true)
+ } else {
+ colNames, args, err = session.genUpdateColumns(bean)
+ if err != nil {
+ return 0, err
+ }
+ }
+ } else if isMap {
+ colNames = make([]string, 0)
+ args = make([]interface{}, 0)
+ bValue := reflect.Indirect(reflect.ValueOf(bean))
+
+ for _, v := range bValue.MapKeys() {
+ colNames = append(colNames, session.engine.Quote(v.String())+" = ?")
+ args = append(args, bValue.MapIndex(v).Interface())
+ }
+ } else {
+ return 0, ErrParamsType
+ }
+
+ table := session.statement.RefTable
+
+ if session.statement.UseAutoTime && table != nil && table.Updated != "" {
+ if !session.statement.columnMap.contain(table.Updated) &&
+ !session.statement.omitColumnMap.contain(table.Updated) {
+ colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?")
+ col := table.UpdatedColumn()
+ val, t := session.engine.nowTime(col)
+ args = append(args, val)
+
+ var colName = col.Name
+ if isStruct {
+ session.afterClosures = append(session.afterClosures, func(bean interface{}) {
+ col := table.GetColumn(colName)
+ setColumnTime(bean, col, t)
+ })
+ }
+ }
+ }
+
+ // for update action to like "column = column + ?"
+ incColumns := session.statement.incrColumns
+ for i, colName := range incColumns.colNames {
+ colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?")
+ args = append(args, incColumns.args[i])
+ }
+ // for update action to like "column = column - ?"
+ decColumns := session.statement.decrColumns
+ for i, colName := range decColumns.colNames {
+ colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?")
+ args = append(args, decColumns.args[i])
+ }
+ // for update action to like "column = expression"
+ exprColumns := session.statement.exprColumns
+ for i, colName := range exprColumns.colNames {
+ switch tp := exprColumns.args[i].(type) {
+ case string:
+ colNames = append(colNames, session.engine.Quote(colName)+" = "+tp)
+ case *builder.Builder:
+ subQuery, subArgs, err := builder.ToSQL(tp)
+ if err != nil {
+ return 0, err
+ }
+ colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")")
+ args = append(args, subArgs...)
+ }
+ }
+
+ if err = session.statement.processIDParam(); err != nil {
+ return 0, err
+ }
+
+ var autoCond builder.Cond
+ if !session.statement.noAutoCondition {
+ condBeanIsStruct := false
+ if len(condiBean) > 0 {
+ if c, ok := condiBean[0].(map[string]interface{}); ok {
+ autoCond = builder.Eq(c)
+ } else {
+ ct := reflect.TypeOf(condiBean[0])
+ k := ct.Kind()
+ if k == reflect.Ptr {
+ k = ct.Elem().Kind()
+ }
+ if k == reflect.Struct {
+ var err error
+ autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
+ if err != nil {
+ return 0, err
+ }
+ condBeanIsStruct = true
+ } else {
+ return 0, ErrConditionType
+ }
+ }
+ }
+
+ if !condBeanIsStruct && table != nil {
+ if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled
+ autoCond1 := session.engine.CondDeleted(session.engine.Quote(col.Name))
+
+ if autoCond == nil {
+ autoCond = autoCond1
+ } else {
+ autoCond = autoCond.And(autoCond1)
+ }
+ }
+ }
+ }
+
+ st := &session.statement
+
+ var sqlStr string
+ var condArgs []interface{}
+ var condSQL string
+ cond := session.statement.cond.And(autoCond)
+
+ var doIncVer = (table != nil && table.Version != "" && session.statement.checkVersion)
+ var verValue *reflect.Value
+ if doIncVer {
+ verValue, err = table.VersionColumn().ValueOf(bean)
+ if err != nil {
+ return 0, err
+ }
+
+ cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
+ colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
+ }
+
+ condSQL, condArgs, err = builder.ToSQL(cond)
+ if err != nil {
+ return 0, err
+ }
+
+ if len(condSQL) > 0 {
+ condSQL = "WHERE " + condSQL
+ }
+
+ if st.OrderStr != "" {
+ condSQL = condSQL + fmt.Sprintf(" ORDER BY %v", st.OrderStr)
+ }
+
+ var tableName = session.statement.TableName()
+ // TODO: Oracle support needed
+ var top string
+ if st.LimitN > 0 {
+ if st.Engine.dialect.DBType() == core.MYSQL {
+ condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+ } else if st.Engine.dialect.DBType() == core.SQLITE {
+ tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+ cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
+ session.engine.Quote(tableName), tempCondSQL), condArgs...))
+ condSQL, condArgs, err = builder.ToSQL(cond)
+ if err != nil {
+ return 0, err
+ }
+ if len(condSQL) > 0 {
+ condSQL = "WHERE " + condSQL
+ }
+ } else if st.Engine.dialect.DBType() == core.POSTGRES {
+ tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+ cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
+ session.engine.Quote(tableName), tempCondSQL), condArgs...))
+ condSQL, condArgs, err = builder.ToSQL(cond)
+ if err != nil {
+ return 0, err
+ }
+
+ if len(condSQL) > 0 {
+ condSQL = "WHERE " + condSQL
+ }
+ } else if st.Engine.dialect.DBType() == core.MSSQL {
+ if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL &&
+ table != nil && len(table.PrimaryKeys) == 1 {
+ cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
+ table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
+ session.engine.Quote(tableName), condSQL), condArgs...)
+
+ condSQL, condArgs, err = builder.ToSQL(cond)
+ if err != nil {
+ return 0, err
+ }
+ if len(condSQL) > 0 {
+ condSQL = "WHERE " + condSQL
+ }
+ } else {
+ top = fmt.Sprintf("TOP (%d) ", st.LimitN)
+ }
+ }
+ }
+
+ if len(colNames) <= 0 {
+ return 0, errors.New("No content found to be updated")
+ }
+
+ sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v",
+ top,
+ session.engine.Quote(tableName),
+ strings.Join(colNames, ", "),
+ condSQL)
+
+ res, err := session.exec(sqlStr, append(args, condArgs...)...)
+ if err != nil {
+ return 0, err
+ } else if doIncVer {
+ if verValue != nil && verValue.IsValid() && verValue.CanSet() {
+ session.incrVersionFieldValue(verValue)
+ }
+ }
+
+ if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
+ // session.cacheUpdate(table, tableName, sqlStr, args...)
+ session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
+ cacher.ClearIds(tableName)
+ cacher.ClearBeans(tableName)
+ }
+
+ // handle after update processors
+ if session.isAutoCommit {
+ for _, closure := range session.afterClosures {
+ closure(bean)
+ }
+ if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok {
+ session.engine.logger.Debug("[event]", tableName, " has after update processor")
+ processor.AfterUpdate()
+ }
+ } else {
+ lenAfterClosures := len(session.afterClosures)
+ if lenAfterClosures > 0 {
+ if value, has := session.afterUpdateBeans[bean]; has && value != nil {
+ *value = append(*value, session.afterClosures...)
+ } else {
+ afterClosures := make([]func(interface{}), lenAfterClosures)
+ copy(afterClosures, session.afterClosures)
+ // FIXME: if bean is a map type, it will panic because map cannot be as map key
+ session.afterUpdateBeans[bean] = &afterClosures
+ }
+
+ } else {
+ if _, ok := interface{}(bean).(AfterUpdateProcessor); ok {
+ session.afterUpdateBeans[bean] = nil
+ }
+ }
+ }
+ cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
+ // --
+
+ return res.RowsAffected()
+}
+
+func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interface{}, error) {
+ table := session.statement.RefTable
+ colNames := make([]string, 0, len(table.ColumnsSeq()))
+ args := make([]interface{}, 0, len(table.ColumnsSeq()))
+
+ for _, col := range table.Columns() {
+ if !col.IsVersion && !col.IsCreated && !col.IsUpdated {
+ if session.statement.omitColumnMap.contain(col.Name) {
+ continue
+ }
+ }
+ if col.MapType == core.ONLYFROMDB {
+ continue
+ }
+
+ fieldValuePtr, err := col.ValueOf(bean)
+ if err != nil {
+ return nil, nil, err
+ }
+ fieldValue := *fieldValuePtr
+
+ if col.IsAutoIncrement {
+ switch fieldValue.Type().Kind() {
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
+ if fieldValue.Int() == 0 {
+ continue
+ }
+ case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
+ if fieldValue.Uint() == 0 {
+ continue
+ }
+ case reflect.String:
+ if len(fieldValue.String()) == 0 {
+ continue
+ }
+ case reflect.Ptr:
+ if fieldValue.Pointer() == 0 {
+ continue
+ }
+ }
+ }
+
+ if (col.IsDeleted && !session.statement.unscoped) || col.IsCreated {
+ continue
+ }
+
+ // if only update specify columns
+ if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
+ continue
+ }
+
+ if session.statement.incrColumns.isColExist(col.Name) {
+ continue
+ } else if session.statement.decrColumns.isColExist(col.Name) {
+ continue
+ } else if session.statement.exprColumns.isColExist(col.Name) {
+ continue
+ }
+
+ // !evalphobia! set fieldValue as nil when column is nullable and zero-value
+ if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
+ if col.Nullable && isZero(fieldValue.Interface()) {
+ var nilValue *int
+ fieldValue = reflect.ValueOf(nilValue)
+ }
+ }
+
+ if col.IsUpdated && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
+ // if time is non-empty, then set to auto time
+ val, t := session.engine.nowTime(col)
+ args = append(args, val)
+
+ var colName = col.Name
+ session.afterClosures = append(session.afterClosures, func(bean interface{}) {
+ col := table.GetColumn(colName)
+ setColumnTime(bean, col, t)
+ })
+ } else if col.IsVersion && session.statement.checkVersion {
+ args = append(args, 1)
+ } else {
+ arg, err := session.value2Interface(col, fieldValue)
+ if err != nil {
+ return colNames, args, err
+ }
+ args = append(args, arg)
+ }
+
+ colNames = append(colNames, session.engine.Quote(col.Name)+" = ?")
+ }
+ return colNames, args, nil
+}