diff options
Diffstat (limited to 'models/db/engine.go')
-rwxr-xr-x | models/db/engine.go | 297 |
1 files changed, 297 insertions, 0 deletions
diff --git a/models/db/engine.go b/models/db/engine.go new file mode 100755 index 0000000000..e71b6fdc59 --- /dev/null +++ b/models/db/engine.go @@ -0,0 +1,297 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Copyright 2018 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package db + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io" + "reflect" + "strings" + + "code.gitea.io/gitea/modules/setting" + + // Needed for the MySQL driver + _ "github.com/go-sql-driver/mysql" + "xorm.io/xorm" + "xorm.io/xorm/names" + "xorm.io/xorm/schemas" + + // Needed for the Postgresql driver + _ "github.com/lib/pq" + + // Needed for the MSSQL driver + _ "github.com/denisenkom/go-mssqldb" +) + +var ( + x *xorm.Engine + tables []interface{} + initFuncs []func() error + + // HasEngine specifies if we have a xorm.Engine + HasEngine bool +) + +// Engine represents a xorm engine or session. +type Engine interface { + Table(tableNameOrBean interface{}) *xorm.Session + Count(...interface{}) (int64, error) + Decr(column string, arg ...interface{}) *xorm.Session + Delete(...interface{}) (int64, error) + Exec(...interface{}) (sql.Result, error) + Find(interface{}, ...interface{}) error + Get(interface{}) (bool, error) + ID(interface{}) *xorm.Session + In(string, ...interface{}) *xorm.Session + Incr(column string, arg ...interface{}) *xorm.Session + Insert(...interface{}) (int64, error) + InsertOne(interface{}) (int64, error) + Iterate(interface{}, xorm.IterFunc) error + Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *xorm.Session + SQL(interface{}, ...interface{}) *xorm.Session + Where(interface{}, ...interface{}) *xorm.Session + Asc(colNames ...string) *xorm.Session + Desc(colNames ...string) *xorm.Session + Limit(limit int, start ...int) *xorm.Session + SumInt(bean interface{}, columnName string) (res int64, err error) + Sync2(...interface{}) error + Select(string) *xorm.Session + NotIn(string, ...interface{}) *xorm.Session + OrderBy(string) *xorm.Session + Exist(...interface{}) (bool, error) + Distinct(...string) *xorm.Session + Query(...interface{}) ([]map[string][]byte, error) + Cols(...string) *xorm.Session +} + +// TableInfo returns table's information via an object +func TableInfo(v interface{}) (*schemas.Table, error) { + return x.TableInfo(v) +} + +// DumpTables dump tables information +func DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { + return x.DumpTables(tables, w, tp...) +} + +// RegisterModel registers model, if initfunc provided, it will be invoked after data model sync +func RegisterModel(bean interface{}, initFunc ...func() error) { + tables = append(tables, bean) + if len(initFuncs) > 0 && initFunc[0] != nil { + initFuncs = append(initFuncs, initFunc[0]) + } +} + +func init() { + gonicNames := []string{"SSL", "UID"} + for _, name := range gonicNames { + names.LintGonicMapper[name] = true + } +} + +// GetNewEngine returns a new xorm engine from the configuration +func GetNewEngine() (*xorm.Engine, error) { + connStr, err := setting.DBConnStr() + if err != nil { + return nil, err + } + + var engine *xorm.Engine + + if setting.Database.UsePostgreSQL && len(setting.Database.Schema) > 0 { + // OK whilst we sort out our schema issues - create a schema aware postgres + registerPostgresSchemaDriver() + engine, err = xorm.NewEngine("postgresschema", connStr) + } else { + engine, err = xorm.NewEngine(setting.Database.Type, connStr) + } + + if err != nil { + return nil, err + } + if setting.Database.Type == "mysql" { + engine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"}) + } else if setting.Database.Type == "mssql" { + engine.Dialect().SetParams(map[string]string{"DEFAULT_VARCHAR": "nvarchar"}) + } + engine.SetSchema(setting.Database.Schema) + return engine, nil +} + +func syncTables() error { + return x.StoreEngine("InnoDB").Sync2(tables...) +} + +// NewTestEngine sets a new test xorm.Engine +func NewTestEngine() (err error) { + x, err = GetNewEngine() + if err != nil { + return fmt.Errorf("Connect to database: %v", err) + } + + x.SetMapper(names.GonicMapper{}) + x.SetLogger(NewXORMLogger(!setting.IsProd())) + x.ShowSQL(!setting.IsProd()) + return syncTables() +} + +// SetEngine sets the xorm.Engine +func SetEngine() (err error) { + x, err = GetNewEngine() + if err != nil { + return fmt.Errorf("Failed to connect to database: %v", err) + } + + x.SetMapper(names.GonicMapper{}) + // WARNING: for serv command, MUST remove the output to os.stdout, + // so use log file to instead print to stdout. + x.SetLogger(NewXORMLogger(setting.Database.LogSQL)) + x.ShowSQL(setting.Database.LogSQL) + x.SetMaxOpenConns(setting.Database.MaxOpenConns) + x.SetMaxIdleConns(setting.Database.MaxIdleConns) + x.SetConnMaxLifetime(setting.Database.ConnMaxLifetime) + return nil +} + +// NewEngine initializes a new xorm.Engine +// This function must never call .Sync2() if the provided migration function fails. +// When called from the "doctor" command, the migration function is a version check +// that prevents the doctor from fixing anything in the database if the migration level +// is different from the expected value. +func NewEngine(ctx context.Context, migrateFunc func(*xorm.Engine) error) (err error) { + if err = SetEngine(); err != nil { + return err + } + + x.SetDefaultContext(ctx) + + if err = x.Ping(); err != nil { + return err + } + + if err = migrateFunc(x); err != nil { + return fmt.Errorf("migrate: %v", err) + } + + if err = syncTables(); err != nil { + return fmt.Errorf("sync database struct error: %v", err) + } + + for _, initFunc := range initFuncs { + if err := initFunc(); err != nil { + return fmt.Errorf("initFunc failed: %v", err) + } + } + + return nil +} + +// NamesToBean return a list of beans or an error +func NamesToBean(names ...string) ([]interface{}, error) { + beans := []interface{}{} + if len(names) == 0 { + beans = append(beans, tables...) + return beans, nil + } + // Need to map provided names to beans... + beanMap := make(map[string]interface{}) + for _, bean := range tables { + + beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean + beanMap[strings.ToLower(x.TableName(bean))] = bean + beanMap[strings.ToLower(x.TableName(bean, true))] = bean + } + + gotBean := make(map[interface{}]bool) + for _, name := range names { + bean, ok := beanMap[strings.ToLower(strings.TrimSpace(name))] + if !ok { + return nil, fmt.Errorf("No table found that matches: %s", name) + } + if !gotBean[bean] { + beans = append(beans, bean) + gotBean[bean] = true + } + } + return beans, nil +} + +// Ping tests if database is alive +func Ping() error { + if x != nil { + return x.Ping() + } + return errors.New("database not configured") +} + +// DumpDatabase dumps all data from database according the special database SQL syntax to file system. +func DumpDatabase(filePath, dbType string) error { + var tbs []*schemas.Table + for _, t := range tables { + t, err := x.TableInfo(t) + if err != nil { + return err + } + tbs = append(tbs, t) + } + + type Version struct { + ID int64 `xorm:"pk autoincr"` + Version int64 + } + t, err := x.TableInfo(&Version{}) + if err != nil { + return err + } + tbs = append(tbs, t) + + if len(dbType) > 0 { + return x.DumpTablesToFile(tbs, filePath, schemas.DBType(dbType)) + } + return x.DumpTablesToFile(tbs, filePath) +} + +// MaxBatchInsertSize returns the table's max batch insert size +func MaxBatchInsertSize(bean interface{}) int { + t, err := x.TableInfo(bean) + if err != nil { + return 50 + } + return 999 / len(t.ColumnsSeq()) +} + +// Count returns records number according struct's fields as database query conditions +func Count(bean interface{}) (int64, error) { + return x.Count(bean) +} + +// IsTableNotEmpty returns true if table has at least one record +func IsTableNotEmpty(tableName string) (bool, error) { + return x.Table(tableName).Exist() +} + +// DeleteAllRecords will delete all the records of this table +func DeleteAllRecords(tableName string) error { + _, err := x.Exec(fmt.Sprintf("DELETE FROM %s", tableName)) + return err +} + +// GetMaxID will return max id of the table +func GetMaxID(beanOrTableName interface{}) (maxID int64, err error) { + _, err = x.Select("MAX(id)").Table(beanOrTableName).Get(&maxID) + return +} + +// FindByMaxID filled results as the condition from database +func FindByMaxID(maxID int64, limit int, results interface{}) error { + return x.Where("id <= ?", maxID). + OrderBy("id DESC"). + Limit(limit). + Find(results) +} |