summaryrefslogtreecommitdiffstats
path: root/vendor/xorm.io/xorm/engine.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/xorm.io/xorm/engine.go')
-rw-r--r--vendor/xorm.io/xorm/engine.go161
1 files changed, 129 insertions, 32 deletions
diff --git a/vendor/xorm.io/xorm/engine.go b/vendor/xorm.io/xorm/engine.go
index 873fcdc1b9..d49eea9adc 100644
--- a/vendor/xorm.io/xorm/engine.go
+++ b/vendor/xorm.io/xorm/engine.go
@@ -21,6 +21,7 @@ import (
"xorm.io/xorm/contexts"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
+ "xorm.io/xorm/internal/json"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/log"
"xorm.io/xorm/names"
@@ -105,6 +106,15 @@ func NewEngineWithParams(driverName string, dataSourceName string, params map[st
return engine, err
}
+// NewEngineWithDB new a db manager with db. The params will be passed to db.
+func NewEngineWithDB(driverName string, dataSourceName string, db *core.DB) (*Engine, error) {
+ dialect, err := dialects.OpenDialect(driverName, dataSourceName)
+ if err != nil {
+ return nil, err
+ }
+ return newEngine(driverName, dataSourceName, dialect, db)
+}
+
// NewEngineWithDialectAndDB new a db manager according to the parameter.
// If you do not want to use your own dialect or db, please use NewEngine.
// For creating dialect, you can call dialects.OpenDialect. And, for creating db,
@@ -159,6 +169,8 @@ func (engine *Engine) SetLogger(logger interface{}) {
realLogger = t
case log.Logger:
realLogger = log.NewLoggerAdapter(t)
+ default:
+ panic("logger should implement either log.ContextLogger or log.Logger")
}
engine.logger = realLogger
engine.DB().Logger = realLogger
@@ -200,6 +212,11 @@ func (engine *Engine) SetColumnMapper(mapper names.Mapper) {
engine.tagParser.SetColumnMapper(mapper)
}
+// SetTagIdentifier set the tag identifier
+func (engine *Engine) SetTagIdentifier(tagIdentifier string) {
+ engine.tagParser.SetIdentifier(tagIdentifier)
+}
+
// Quote Use QuoteStr quote the string sql
func (engine *Engine) Quote(value string) string {
value = strings.TrimSpace(value)
@@ -441,9 +458,26 @@ func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.
}
if col.SQLType.IsText() {
- var v = fmt.Sprintf("%s", d)
+ var v string
+ switch reflect.TypeOf(d).Kind() {
+ case reflect.Struct, reflect.Array, reflect.Slice, reflect.Map:
+ bytes, err := json.DefaultJSONHandler.Marshal(d)
+ if err != nil {
+ v = fmt.Sprintf("%s", d)
+ } else {
+ v = string(bytes)
+ }
+ default:
+ v = fmt.Sprintf("%s", d)
+ }
+
return "'" + strings.Replace(v, "'", "''", -1) + "'"
} else if col.SQLType.IsTime() {
+ if dstDialect.URI().DBType == schemas.MSSQL && col.SQLType.Name == schemas.DateTime {
+ if t, ok := d.(time.Time); ok {
+ return "'" + t.UTC().Format("2006-01-02 15:04:05") + "'"
+ }
+ }
var v = fmt.Sprintf("%s", d)
if strings.HasSuffix(v, " +0000 UTC") {
return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 UTC")])
@@ -475,7 +509,7 @@ func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.
}
return fmt.Sprintf("%v", strconv.FormatBool(v))
}
- return fmt.Sprintf("%v", d)
+ return fmt.Sprintf("%d", d)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if col.SQLType.Name == schemas.Bool {
v := reflect.ValueOf(d).Uint() > 0
@@ -487,7 +521,7 @@ func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.
}
return fmt.Sprintf("%v", strconv.FormatBool(v))
}
- return fmt.Sprintf("%v", d)
+ return fmt.Sprintf("%d", d)
default:
return fmt.Sprintf("%v", d)
}
@@ -521,6 +555,8 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
}
dstDialect.Init(&destURI)
}
+ cacherMgr := caches.NewManager()
+ dstTableCache := tags.NewParser("xorm", dstDialect, engine.GetTableMapper(), engine.GetColumnMapper(), cacherMgr)
_, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n",
time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, dstDialect.URI().DBType))
@@ -529,9 +565,18 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
}
for i, table := range tables {
- tableName := table.Name
+ dstTable := table
+ if table.Type != nil {
+ dstTable, err = dstTableCache.Parse(reflect.New(table.Type).Elem())
+ if err != nil {
+ engine.logger.Errorf("Unable to infer table for %s in new dialect. Error: %v", table.Name)
+ dstTable = table
+ }
+ }
+
+ dstTableName := dstTable.Name
if dstDialect.URI().Schema != "" {
- tableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, table.Name)
+ dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name)
}
originalTableName := table.Name
if engine.dialect.URI().Schema != "" {
@@ -543,27 +588,30 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
return err
}
}
- sqls, _ := dstDialect.CreateTableSQL(table, tableName)
+
+ sqls, _ := dstDialect.CreateTableSQL(dstTable, dstTableName)
for _, s := range sqls {
_, err = io.WriteString(w, s+";\n")
if err != nil {
return err
}
}
- if len(table.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL {
- fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name)
+ if len(dstTable.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL {
+ fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", dstTable.Name)
}
- for _, index := range table.Indexes {
- _, err = io.WriteString(w, dstDialect.CreateIndexSQL(table.Name, index)+";\n")
+ for _, index := range dstTable.Indexes {
+ _, err = io.WriteString(w, dstDialect.CreateIndexSQL(dstTable.Name, index)+";\n")
if err != nil {
return err
}
}
cols := table.ColumnsSeq()
+ dstCols := dstTable.ColumnsSeq()
+
colNames := engine.dialect.Quoter().Join(cols, ", ")
- destColNames := dstDialect.Quoter().Join(cols, ", ")
+ destColNames := dstDialect.Quoter().Join(dstCols, ", ")
rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(originalTableName))
if err != nil {
@@ -571,35 +619,83 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
}
defer rows.Close()
- for rows.Next() {
- dest := make([]interface{}, len(cols))
- err = rows.ScanSlice(&dest)
- if err != nil {
- return err
- }
+ if table.Type != nil {
+ sess := engine.NewSession()
+ defer sess.Close()
+ for rows.Next() {
+ beanValue := reflect.New(table.Type)
+ bean := beanValue.Interface()
+ fields, err := rows.Columns()
+ if err != nil {
+ return err
+ }
+ scanResults, err := sess.row2Slice(rows, fields, bean)
+ if err != nil {
+ return err
+ }
- _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(tableName)+" ("+destColNames+") VALUES (")
- if err != nil {
- return err
- }
+ dataStruct := utils.ReflectValue(bean)
+ _, err = sess.slice2Bean(scanResults, fields, bean, &dataStruct, table)
+ if err != nil {
+ return err
+ }
+
+ _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
+ if err != nil {
+ return err
+ }
+
+ var temp string
+ for _, d := range dstCols {
+ col := table.GetColumn(d)
+ if col == nil {
+ return errors.New("unknown column error")
+ }
- var temp string
- for i, d := range dest {
- col := table.GetColumn(cols[i])
- if col == nil {
- return errors.New("unknow column error")
+ fields := strings.Split(col.FieldName, ".")
+ field := dataStruct
+ for _, fieldName := range fields {
+ field = field.FieldByName(fieldName)
+ }
+ temp += "," + formatColumnValue(dstDialect, field.Interface(), col)
+ }
+ _, err = io.WriteString(w, temp[1:]+");\n")
+ if err != nil {
+ return err
}
- temp += "," + formatColumnValue(dstDialect, d, col)
}
- _, err = io.WriteString(w, temp[1:]+");\n")
- if err != nil {
- return err
+ } else {
+ for rows.Next() {
+ dest := make([]interface{}, len(cols))
+ err = rows.ScanSlice(&dest)
+ if err != nil {
+ return err
+ }
+
+ _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (")
+ if err != nil {
+ return err
+ }
+
+ var temp string
+ for i, d := range dest {
+ col := table.GetColumn(cols[i])
+ if col == nil {
+ return errors.New("unknow column error")
+ }
+
+ temp += "," + formatColumnValue(dstDialect, d, col)
+ }
+ _, err = io.WriteString(w, temp[1:]+");\n")
+ if err != nil {
+ return err
+ }
}
}
// FIXME: Hack for postgres
if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
- _, err = io.WriteString(w, "SELECT setval('"+tableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(tableName)+"), 1), false);\n")
+ _, err = io.WriteString(w, "SELECT setval('"+dstTableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(dstTableName)+"), 1), false);\n")
if err != nil {
return err
}
@@ -1262,6 +1358,7 @@ func (engine *Engine) SetSchema(schema string) {
engine.dialect.URI().SetSchema(schema)
}
+// AddHook adds a context Hook
func (engine *Engine) AddHook(hook contexts.Hook) {
engine.db.AddHook(hook)
}
@@ -1277,7 +1374,7 @@ func (engine *Engine) tbNameWithSchema(v string) string {
return dialects.TableNameWithSchema(engine.dialect, v)
}
-// ContextHook creates a session with the context
+// Context creates a session with the context
func (engine *Engine) Context(ctx context.Context) *Session {
session := engine.NewSession()
session.isAutoClose = true