diff options
author | Lunny Xiao <xiaolunwen@gmail.com> | 2020-03-22 23:12:55 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-22 11:12:55 -0400 |
commit | c61b902538e8e4b2bb83136f190e044a6bbcdd9b (patch) | |
tree | bc08142f5ced69721ab14793b8dc62aa83314268 /vendor/xorm.io/xorm/session_insert.go | |
parent | dcaa5643d70cd1be72e53e82a30d2bc6e5f8de92 (diff) | |
download | gitea-c61b902538e8e4b2bb83136f190e044a6bbcdd9b.tar.gz gitea-c61b902538e8e4b2bb83136f190e044a6bbcdd9b.zip |
Upgrade xorm to v1.0.0 (#10646)
* Upgrade xorm to v1.0.0
* small nit
* Fix tests
* Update xorm
* Update xorm
* fix go.sum
* fix test
* Fix bug when dump
* Fix bug
* update xorm to latest
* Fix migration test
* update xorm to latest
* Fix import order
* Use xorm tag
Diffstat (limited to 'vendor/xorm.io/xorm/session_insert.go')
-rw-r--r-- | vendor/xorm.io/xorm/session_insert.go | 441 |
1 files changed, 149 insertions, 292 deletions
diff --git a/vendor/xorm.io/xorm/session_insert.go b/vendor/xorm.io/xorm/session_insert.go index 5f8f7e1ee8..3a0a706655 100644 --- a/vendor/xorm.io/xorm/session_insert.go +++ b/vendor/xorm.io/xorm/session_insert.go @@ -13,9 +13,13 @@ import ( "strings" "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) +// ErrNoElementsOnSlice represents an error there is no element when insert +var ErrNoElementsOnSlice = errors.New("No element on slice when insert") + // Insert insert one or more beans func (session *Session) Insert(beans ...interface{}) (int64, error) { var affected int64 @@ -67,23 +71,15 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { sliceValue := reflect.Indirect(reflect.ValueOf(bean)) if sliceValue.Kind() == reflect.Slice { size := sliceValue.Len() - if size > 0 { - if session.engine.SupportInsertMany() { - cnt, err := session.innerInsertMulti(bean) - if err != nil { - return affected, err - } - affected += cnt - } else { - for i := 0; i < size; i++ { - cnt, err := session.innerInsert(sliceValue.Index(i).Interface()) - if err != nil { - return affected, err - } - affected += cnt - } - } + if size <= 0 { + return 0, ErrNoElementsOnSlice + } + + cnt, err := session.innerInsertMulti(bean) + if err != nil { + return affected, err } + affected += cnt } else { cnt, err := session.innerInsert(bean) if err != nil { @@ -107,7 +103,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, errors.New("could not insert a empty slice") } - if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil { + if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil { return 0, err } @@ -122,11 +118,17 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error var colNames []string var colMultiPlaces []string var args []interface{} - var cols []*core.Column + var cols []*schemas.Column for i := 0; i < size; i++ { v := sliceValue.Index(i) - vv := reflect.Indirect(v) + var vv reflect.Value + switch v.Kind() { + case reflect.Interface: + vv = reflect.Indirect(v.Elem()) + default: + vv = reflect.Indirect(v) + } elemValue := v.Interface() var colPlaces []string @@ -141,123 +143,77 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } // -- - if i == 0 { - for _, col := range table.Columns() { - ptrFieldValue, err := col.ValueOfV(&vv) + for _, col := range table.Columns() { + ptrFieldValue, err := col.ValueOfV(&vv) + if err != nil { + return 0, err + } + fieldValue := *ptrFieldValue + if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) { + continue + } + if col.MapType == schemas.ONLYFROMDB { + continue + } + if col.IsDeleted { + continue + } + if session.statement.OmitColumnMap.Contain(col.Name) { + continue + } + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { + continue + } + if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { + val, t := session.engine.nowTime(col) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } else if col.IsVersion && session.statement.CheckVersion { + args = append(args, 1) + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnInt(bean, col, 1) + }) + } else { + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return 0, err } - fieldValue := *ptrFieldValue - if col.IsAutoIncrement && isZero(fieldValue.Interface()) { - continue - } - if col.MapType == core.ONLYFROMDB { - continue - } - if col.IsDeleted { - continue - } - if session.statement.omitColumnMap.contain(col.Name) { - continue - } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { - continue - } - if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { - val, t := session.engine.nowTime(col) - args = append(args, val) - - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } else if col.IsVersion && session.statement.checkVersion { - args = append(args, 1) - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnInt(bean, col, 1) - }) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return 0, err - } - args = append(args, arg) - } + args = append(args, arg) + } + if i == 0 { colNames = append(colNames, col.Name) cols = append(cols, col) - colPlaces = append(colPlaces, "?") - } - } else { - for _, col := range cols { - ptrFieldValue, err := col.ValueOfV(&vv) - if err != nil { - return 0, err - } - fieldValue := *ptrFieldValue - - if col.IsAutoIncrement && isZero(fieldValue.Interface()) { - continue - } - if col.MapType == core.ONLYFROMDB { - continue - } - if col.IsDeleted { - continue - } - if session.statement.omitColumnMap.contain(col.Name) { - continue - } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { - continue - } - if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { - val, t := session.engine.nowTime(col) - args = append(args, val) - - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } else if col.IsVersion && session.statement.checkVersion { - args = append(args, 1) - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnInt(bean, col, 1) - }) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return 0, err - } - args = append(args, arg) - } - - colPlaces = append(colPlaces, "?") } + colPlaces = append(colPlaces, "?") } + colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", ")) } cleanupProcessorsClosures(&session.beforeClosures) + quoter := session.engine.dialect.Quoter() var sql string - if session.engine.dialect.DBType() == core.ORACLE { + colStr := quoter.Join(colNames, ",") + if session.engine.dialect.URI().DBType == schemas.ORACLE { temp := fmt.Sprintf(") INTO %s (%v) VALUES (", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ",")) + quoter.Quote(tableName), + colStr) sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), + quoter.Quote(tableName), + colStr, strings.Join(colMultiPlaces, temp)) } else { sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), + quoter.Quote(tableName), + colStr, strings.Join(colMultiPlaces, "),(")) } res, err := session.exec(sql, args...) @@ -321,15 +277,13 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { } func (session *Session) innerInsert(bean interface{}) (int64, error) { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } if len(session.statement.TableName()) <= 0 { return 0, ErrTableNotFound } - table := session.statement.RefTable - // handle BeforeInsertProcessor for _, closure := range session.beforeClosures { closure(bean) @@ -340,100 +294,19 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { processor.BeforeInsert() } + var tableName = session.statement.TableName() + table := session.statement.RefTable + colNames, args, err := session.genInsertColumns(bean) if err != nil { return 0, err } - exprs := session.statement.exprColumns - colPlaces := strings.Repeat("?, ", len(colNames)) - if exprs.Len() <= 0 && len(colPlaces) > 0 { - colPlaces = colPlaces[0 : len(colPlaces)-2] - } - - var tableName = session.statement.TableName() - var output string - if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { - output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) - } - - var buf = builder.NewWriter() - if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil { + sqlStr, args, err := session.statement.GenInsertSQL(colNames, args) + if err != nil { return 0, err } - if len(colPlaces) <= 0 { - if session.engine.dialect.DBType() == core.MYSQL { - if _, err := buf.WriteString(" VALUES ()"); err != nil { - return 0, err - } - } else { - if _, err := buf.WriteString(fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil { - return 0, err - } - } - } else { - if _, err := buf.WriteString(" ("); err != nil { - return 0, err - } - - if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil { - return 0, err - } - - if session.statement.cond.IsValid() { - if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil { - return 0, err - } - - if err := session.statement.writeArgs(buf, args); err != nil { - return 0, err - } - - if len(exprs.args) > 0 { - if _, err := buf.WriteString(","); err != nil { - return 0, err - } - } - if err := exprs.writeArgs(buf); err != nil { - return 0, err - } - - if _, err := buf.WriteString(fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := session.statement.cond.WriteTo(buf); err != nil { - return 0, err - } - } else { - buf.Append(args...) - - if _, err := buf.WriteString(fmt.Sprintf(")%s VALUES (%v", - output, - colPlaces)); err != nil { - return 0, err - } - - if err := exprs.writeArgs(buf); err != nil { - return 0, err - } - - if _, err := buf.WriteString(")"); err != nil { - return 0, err - } - } - } - - if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES { - if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil { - return 0, err - } - } - - sqlStr := buf.String() - args = buf.Args() - handleAfterInsertProcessorFunc := func(bean interface{}) { if session.isAutoCommit { for _, closure := range session.afterClosures { @@ -464,7 +337,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. - if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 { + if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 { res, err := session.queryBytes("select seq_atable.currval from dual", args...) if err != nil { return 0, err @@ -474,10 +347,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(tableName) - if table.Version != "" && session.statement.checkVersion { + if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else if verValue.IsValid() && verValue.CanSet() { session.incrVersionFieldValue(verValue) } @@ -495,7 +368,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { @@ -505,7 +378,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, nil - } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == core.POSTGRES || session.engine.dialect.DBType() == core.MSSQL) { + } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || + session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) if err != nil { @@ -515,10 +389,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(tableName) - if table.Version != "" && session.statement.checkVersion { + if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else if verValue.IsValid() && verValue.CanSet() { session.incrVersionFieldValue(verValue) } @@ -536,7 +410,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { @@ -546,48 +420,48 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, nil - } else { - res, err := session.exec(sqlStr, args...) - if err != nil { - return 0, err - } - - defer handleAfterInsertProcessorFunc(bean) - - session.cacheInsert(tableName) + } - if table.Version != "" && session.statement.checkVersion { - verValue, err := table.VersionColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Error(err) - } else if verValue.IsValid() && verValue.CanSet() { - session.incrVersionFieldValue(verValue) - } - } + res, err := session.exec(sqlStr, args...) + if err != nil { + return 0, err + } - if table.AutoIncrement == "" { - return res.RowsAffected() - } + defer handleAfterInsertProcessorFunc(bean) - var id int64 - id, err = res.LastInsertId() - if err != nil || id <= 0 { - return res.RowsAffected() - } + session.cacheInsert(tableName) - aiValue, err := table.AutoIncrColumn().ValueOf(bean) + if table.Version != "" && session.statement.CheckVersion { + verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) + } else if verValue.IsValid() && verValue.CanSet() { + session.incrVersionFieldValue(verValue) } + } - if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { - return res.RowsAffected() - } + if table.AutoIncrement == "" { + return res.RowsAffected() + } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) + var id int64 + id, err = res.LastInsertId() + if err != nil || id <= 0 { + return res.RowsAffected() + } + aiValue, err := table.AutoIncrColumn().ValueOf(bean) + if err != nil { + session.engine.logger.Errorf("%v", err) + } + + if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { return res.RowsAffected() } + + aiValue.Set(int64ToIntValue(id, aiValue.Type())) + + return res.RowsAffected() } // InsertOne insert only one struct into database as a record. @@ -605,11 +479,11 @@ func (session *Session) cacheInsert(table string) error { if !session.statement.UseCache { return nil } - cacher := session.engine.getCacher(table) + cacher := session.engine.cacherMgr.GetCacher(table) if cacher == nil { return nil } - session.engine.logger.Debug("[cache] clear sql:", table) + session.engine.logger.Debugf("[cache] clear sql: %v", table) cacher.ClearIds(table) return nil } @@ -621,7 +495,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac args := make([]interface{}, 0, len(table.ColumnsSeq())) for _, col := range table.Columns() { - if col.MapType == core.ONLYFROMDB { + if col.MapType == schemas.ONLYFROMDB { continue } @@ -629,19 +503,19 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac continue } - if session.statement.omitColumnMap.contain(col.Name) { + if session.statement.OmitColumnMap.Contain(col.Name) { continue } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } - if session.statement.incrColumns.isColExist(col.Name) { + if session.statement.IncrColumns.IsColExist(col.Name) { continue - } else if session.statement.decrColumns.isColExist(col.Name) { + } else if session.statement.DecrColumns.IsColExist(col.Name) { continue - } else if session.statement.exprColumns.isColExist(col.Name) { + } else if session.statement.ExprColumns.IsColExist(col.Name) { continue } @@ -651,30 +525,13 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } fieldValue := *fieldValuePtr - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } + if col.IsAutoIncrement && utils.IsValueZero(fieldValue) { + continue } // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { - if col.Nullable && isZero(fieldValue.Interface()) { + if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok { + if col.Nullable && utils.IsValueZero(fieldValue) { var nilValue *int fieldValue = reflect.ValueOf(nilValue) } @@ -690,10 +547,10 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.statement.checkVersion { + } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) } else { - arg, err := session.value2Interface(col, fieldValue) + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return colNames, args, err } @@ -716,9 +573,9 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } var columns = make([]string, 0, len(m)) - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns for k := range m { - if !exprs.isColExist(k) { + if !exprs.IsColExist(k) { columns = append(columns, k) } } @@ -743,9 +600,9 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { } var columns = make([]string, 0, len(m)) - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns for k := range m { - if !exprs.isColExist(k) { + if !exprs.IsColExist(k) { columns = append(columns, k) } } @@ -766,15 +623,15 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, return 0, ErrTableNotFound } - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns w := builder.NewWriter() // if insert where - if session.statement.cond.IsValid() { + if session.statement.Conds().IsValid() { if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { return 0, err } - if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { + if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil { return 0, err } @@ -782,15 +639,15 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, return 0, err } - if err := session.statement.writeArgs(w, args); err != nil { + if err := session.statement.WriteArgs(w, args); err != nil { return 0, err } - if len(exprs.args) > 0 { + if len(exprs.Args) > 0 { if _, err := w.WriteString(","); err != nil { return 0, err } - if err := exprs.writeArgs(w); err != nil { + if err := exprs.WriteArgs(w); err != nil { return 0, err } } @@ -799,7 +656,7 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, return 0, err } - if err := session.statement.cond.WriteTo(w); err != nil { + if err := session.statement.Conds().WriteTo(w); err != nil { return 0, err } } else { @@ -810,7 +667,7 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, return 0, err } - if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { + if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil { return 0, err } if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil { @@ -818,11 +675,11 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, } w.Append(args...) - if len(exprs.args) > 0 { + if len(exprs.Args) > 0 { if _, err := w.WriteString(","); err != nil { return 0, err } - if err := exprs.writeArgs(w); err != nil { + if err := exprs.WriteArgs(w); err != nil { return 0, err } } |