diff options
Diffstat (limited to 'vendor/xorm.io/xorm/session.go')
-rw-r--r-- | vendor/xorm.io/xorm/session.go | 191 |
1 files changed, 112 insertions, 79 deletions
diff --git a/vendor/xorm.io/xorm/session.go b/vendor/xorm.io/xorm/session.go index 8307193550..4842883b6b 100644 --- a/vendor/xorm.io/xorm/session.go +++ b/vendor/xorm.io/xorm/session.go @@ -14,9 +14,34 @@ import ( "strings" "time" - "xorm.io/core" + "xorm.io/xorm/contexts" + "xorm.io/xorm/convert" + "xorm.io/xorm/core" + "xorm.io/xorm/internal/json" + "xorm.io/xorm/internal/statements" + "xorm.io/xorm/schemas" ) +// ErrFieldIsNotExist columns does not exist +type ErrFieldIsNotExist struct { + FieldName string + TableName string +} + +func (e ErrFieldIsNotExist) Error() string { + return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) +} + +// ErrFieldIsNotValid is not valid +type ErrFieldIsNotValid struct { + FieldName string + TableName string +} + +func (e ErrFieldIsNotValid) Error() string { + return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) +} + type sessionType int const ( @@ -30,7 +55,7 @@ type Session struct { db *core.DB engine *Engine tx *core.Tx - statement Statement + statement *statements.Statement isAutoCommit bool isCommitedOrRollbacked bool isAutoClose bool @@ -53,8 +78,6 @@ type Session struct { prepareStmt bool stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) - // !evalphobia! stored the last executed query on this session - //beforeSQLExec func(string, ...interface{}) lastSQL string lastSQLArgs []interface{} showSQL bool @@ -71,9 +94,12 @@ func (session *Session) Clone() *Session { // Init reset the session as the init status. func (session *Session) Init() { - session.statement.Init() - session.statement.Engine = session.engine - session.showSQL = session.engine.showSQL + session.statement = statements.NewStatement( + session.engine.dialect, + session.engine.tagParser, + session.engine.DatabaseTZ, + ) + session.isAutoCommit = true session.isCommitedOrRollbacked = false session.isAutoClose = false @@ -115,8 +141,8 @@ func (session *Session) Close() { } // ContextCache enable context cache or not -func (session *Session) ContextCache(context ContextCache) *Session { - session.statement.context = context +func (session *Session) ContextCache(context contexts.ContextCache) *Session { + session.statement.SetContextCache(context) return session } @@ -127,7 +153,7 @@ func (session *Session) IsClosed() bool { func (session *Session) resetStatement() { if session.autoResetStatement { - session.statement.Init() + session.statement.Reset() } } @@ -155,7 +181,9 @@ func (session *Session) After(closures func(interface{})) *Session { // Table can input a string or pointer to struct for special a table to operate. func (session *Session) Table(tableNameOrBean interface{}) *Session { - session.statement.Table(tableNameOrBean) + if err := session.statement.SetTable(tableNameOrBean); err != nil { + session.statement.LastError = err + } return session } @@ -179,7 +207,7 @@ func (session *Session) ForUpdate() *Session { // NoAutoCondition disable generate SQL condition from beans func (session *Session) NoAutoCondition(no ...bool) *Session { - session.statement.NoAutoCondition(no...) + session.statement.SetNoAutoCondition(no...) return session } @@ -230,11 +258,11 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session { // MustLogSQL means record SQL or not and don't follow engine's setting func (session *Session) MustLogSQL(log ...bool) *Session { + var showSQL = true if len(log) > 0 { - session.showSQL = log[0] - } else { - session.showSQL = true + showSQL = log[0] } + session.ctx = context.WithValue(session.ctx, "__xorm_show_sql", showSQL) return session } @@ -266,7 +294,7 @@ func (session *Session) Having(conditions string) *Session { // DB db return the wrapper of sql.DB func (session *Session) DB() *core.DB { if session.db == nil { - session.db = session.engine.db + session.db = session.engine.DB() session.stmtCache = make(map[uint32]*core.Stmt, 0) } return session.db @@ -285,7 +313,7 @@ func (session *Session) canCache() bool { !session.statement.UseCache || session.statement.IsForUpdate || session.tx != nil || - len(session.statement.selectStr) > 0 { + len(session.statement.SelectStr) > 0 { return false } return true @@ -306,8 +334,8 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } -func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) { - var col *core.Column +func (session *Session) getField(dataStruct *reflect.Value, key string, table *schemas.Table, idx int) (*reflect.Value, error) { + var col *schemas.Column if col = table.GetColumnIdx(key, idx); col == nil { return nil, ErrFieldIsNotExist{key, table.Name} } @@ -328,8 +356,8 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *c type Cell *interface{} func (session *Session) rows2Beans(rows *core.Rows, fields []string, - table *core.Table, newElemFunc func([]string) reflect.Value, - sliceValueSetFunc func(*reflect.Value, core.PK) error) error { + table *schemas.Table, newElemFunc func([]string) reflect.Value, + sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { var newValue = newElemFunc(fields) bean := newValue.Interface() @@ -377,7 +405,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa return scanResults, nil } -func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) { +func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { defer func() { if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { for ii, key := range fields { @@ -421,7 +449,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } var tempMap = make(map[string]int) - var pk core.PK + var pk schemas.PK for ii, key := range fields { var idx int var ok bool @@ -436,7 +464,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b fieldValue, err := session.getField(dataStruct, key, table, idx) if err != nil { if !strings.Contains(err.Error(), "is not valid") { - session.engine.logger.Warn(err) + session.engine.logger.Warnf("%v", err) } continue } @@ -451,7 +479,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } if fieldValue.CanAddr() { - if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { if data, err := value2Bytes(&rawValue); err == nil { if err := structConvert.FromDB(data); err != nil { return nil, err @@ -463,12 +491,12 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } - if _, ok := fieldValue.Interface().(core.Conversion); ok { + if _, ok := fieldValue.Interface().(convert.Conversion); ok { if data, err := value2Bytes(&rawValue); err == nil { if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { fieldValue.Set(reflect.New(fieldValue.Type().Elem())) } - fieldValue.Interface().(core.Conversion).FromDB(data) + fieldValue.Interface().(convert.Conversion).FromDB(data) } else { return nil, err } @@ -488,7 +516,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b var bs []byte if rawValueType.Kind() == reflect.String { bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(core.BytesType) { + } else if rawValueType.ConvertibleTo(schemas.BytesType) { bs = vv.Bytes() } else { return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) @@ -502,13 +530,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b continue } if fieldValue.CanAddr() { - err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) if err != nil { return nil, err } } else { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(bs, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) if err != nil { return nil, err } @@ -525,20 +553,20 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b var bs []byte if rawValueType.Kind() == reflect.String { bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(core.BytesType) { + } else if rawValueType.ConvertibleTo(schemas.BytesType) { bs = vv.Bytes() } hasAssigned = true if len(bs) > 0 { if fieldValue.CanAddr() { - err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) if err != nil { return nil, err } } else { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(bs, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) if err != nil { return nil, err } @@ -554,7 +582,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true if col.SQLType.IsText() { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { return nil, err } @@ -607,16 +635,16 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b fieldValue.SetUint(uint64(vv.Int())) } case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { + if fieldType.ConvertibleTo(schemas.TimeType) { dbTZ := session.engine.DatabaseTZ if col.TimeZone != nil { dbTZ = col.TimeZone } - if rawValueType == core.TimeType { + if rawValueType == schemas.TimeType { hasAssigned = true - t := vv.Convert(core.TimeType).Interface().(time.Time) + t := vv.Convert(schemas.TimeType).Interface().(time.Time) z, _ := t.Zone() // set new location if database don't save timezone or give an incorrect timezone @@ -628,8 +656,8 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b t = t.In(session.engine.TZLocation) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else if rawValueType == core.IntType || rawValueType == core.Int64Type || - rawValueType == core.Int32Type { + } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || + rawValueType == schemas.Int32Type { hasAssigned = true t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) @@ -639,7 +667,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true t, err := session.byte2Time(col, d) if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) + session.engine.logger.Errorf("byte2Time error: %v", err) hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) @@ -648,7 +676,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true t, err := session.str2Time(col, d) if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) + session.engine.logger.Errorf("byte2Time error: %v", err) hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) @@ -661,7 +689,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b // !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString hasAssigned = true if err := nulVal.Scan(vv.Interface()); err != nil { - session.engine.logger.Error("sql.Sanner error:", err.Error()) + session.engine.logger.Errorf("sql.Sanner error: %v", err) hasAssigned = false } } else if col.SQLType.IsJson() { @@ -669,7 +697,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true x := reflect.New(fieldType) if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) if err != nil { return nil, err } @@ -679,7 +707,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true x := reflect.New(fieldType) if len(vv.Bytes()) > 0 { - err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { return nil, err } @@ -687,7 +715,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { return nil, err } @@ -696,13 +724,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b if len(table.PrimaryKeys) != 1 { return nil, errors.New("unsupported non or composited primary key cascade") } - var pk = make(core.PK, len(table.PrimaryKeys)) + var pk = make(schemas.PK, len(table.PrimaryKeys)) pk[0], err = asKind(vv, rawValueType) if err != nil { return nil, err } - if !isPKZero(pk) { + if !pk.IsZero() { // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily @@ -722,110 +750,110 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b // !nashtsai! TODO merge duplicated codes above switch fieldType { // following types case matching ptr's native type, therefore assign ptr directly - case core.PtrStringType: + case schemas.PtrStringType: if rawValueType.Kind() == reflect.String { x := vv.String() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrBoolType: + case schemas.PtrBoolType: if rawValueType.Kind() == reflect.Bool { x := vv.Bool() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrTimeType: - if rawValueType == core.PtrTimeType { + case schemas.PtrTimeType: + if rawValueType == schemas.PtrTimeType { hasAssigned = true var x = rawValue.Interface().(time.Time) fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrFloat64Type: + case schemas.PtrFloat64Type: if rawValueType.Kind() == reflect.Float64 { x := vv.Float() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrUint64Type: + case schemas.PtrUint64Type: if rawValueType.Kind() == reflect.Int64 { var x = uint64(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt64Type: + case schemas.PtrInt64Type: if rawValueType.Kind() == reflect.Int64 { x := vv.Int() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrFloat32Type: + case schemas.PtrFloat32Type: if rawValueType.Kind() == reflect.Float64 { var x = float32(vv.Float()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrIntType: + case schemas.PtrIntType: if rawValueType.Kind() == reflect.Int64 { var x = int(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt32Type: + case schemas.PtrInt32Type: if rawValueType.Kind() == reflect.Int64 { var x = int32(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt8Type: + case schemas.PtrInt8Type: if rawValueType.Kind() == reflect.Int64 { var x = int8(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt16Type: + case schemas.PtrInt16Type: if rawValueType.Kind() == reflect.Int64 { var x = int16(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrUintType: + case schemas.PtrUintType: if rawValueType.Kind() == reflect.Int64 { var x = uint(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrUint32Type: + case schemas.PtrUint32Type: if rawValueType.Kind() == reflect.Int64 { var x = uint32(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.Uint8Type: + case schemas.Uint8Type: if rawValueType.Kind() == reflect.Int64 { var x = uint8(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.Uint16Type: + case schemas.Uint16Type: if rawValueType.Kind() == reflect.Int64 { var x = uint16(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.Complex64Type: + case schemas.Complex64Type: var x complex64 if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) if err != nil { return nil, err } fieldValue.Set(reflect.ValueOf(&x)) } hasAssigned = true - case core.Complex128Type: + case schemas.Complex128Type: var x complex128 if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) if err != nil { return nil, err } @@ -854,17 +882,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b func (session *Session) saveLastSQL(sql string, args ...interface{}) { session.lastSQL = sql session.lastSQLArgs = args - session.logSQL(sql, args...) -} - -func (session *Session) logSQL(sqlStr string, sqlArgs ...interface{}) { - if session.showSQL && !session.engine.showExecTime { - if len(sqlArgs) > 0 { - session.engine.logger.Infof("[SQL] %v %#v", sqlStr, sqlArgs) - } else { - session.engine.logger.Infof("[SQL] %v", sqlStr) - } - } } // LastSQL returns last query information @@ -874,7 +891,7 @@ func (session *Session) LastSQL() (string, []interface{}) { // Unscoped always disable struct tag "deleted" func (session *Session) Unscoped() *Session { - session.statement.Unscoped() + session.statement.SetUnscoped() return session } @@ -886,3 +903,19 @@ func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) { fieldValue.SetUint(fieldValue.Uint() + 1) } } + +// Context sets the context on this session +func (session *Session) Context(ctx context.Context) *Session { + session.ctx = ctx + return session +} + +// PingContext test if database is ok +func (session *Session) PingContext(ctx context.Context) error { + if session.isAutoClose { + defer session.Close() + } + + session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) + return session.DB().PingContext(ctx) +} |