diff options
Diffstat (limited to 'vendor/github.com/go-xorm/xorm/statement.go')
-rw-r--r-- | vendor/github.com/go-xorm/xorm/statement.go | 343 |
1 files changed, 94 insertions, 249 deletions
diff --git a/vendor/github.com/go-xorm/xorm/statement.go b/vendor/github.com/go-xorm/xorm/statement.go index b6f0baf206..50008d1dd8 100644 --- a/vendor/github.com/go-xorm/xorm/statement.go +++ b/vendor/github.com/go-xorm/xorm/statement.go @@ -272,6 +272,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, fieldValue := *fieldValuePtr fieldType := reflect.TypeOf(fieldValue.Interface()) + if fieldType == nil { + continue + } requiredField := useAllCols includeNil := useAllCols @@ -376,7 +379,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { continue } - val = engine.FormatTime(col.SQLType.Name, t) + val = engine.formatColTime(col, t) } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { val, _ = nulType.Value() } else { @@ -490,224 +493,6 @@ func (statement *Statement) colName(col *core.Column, tableName string) string { return statement.Engine.Quote(col.Name) } -func buildConds(engine *Engine, table *core.Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, - includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { - var conds []builder.Cond - for _, col := range table.Columns() { - if !includeVersion && col.IsVersion { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - if !includeAutoIncr && col.IsAutoIncrement { - continue - } - - if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) { - continue - } - if col.SQLType.IsJson() { - continue - } - - var colName string - if addedTableName { - var nm = tableName - if len(aliasName) > 0 { - nm = aliasName - } - colName = engine.Quote(nm) + "." + engine.Quote(col.Name) - } else { - colName = engine.Quote(col.Name) - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - engine.logger.Error(err) - continue - } - - if col.IsDeleted && !unscoped { // tag "deleted" is enabled - if engine.dialect.DBType() == core.MSSQL { - conds = append(conds, builder.IsNull{colName}) - } else { - conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"})) - } - } - - fieldValue := *fieldValuePtr - if fieldValue.Interface() == nil { - continue - } - - fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := useAllCols - - if b, ok := getFlagForColumn(mustColumnMap, col); ok { - if b { - requiredField = true - } else { - continue - } - } - - if fieldType.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - if includeNil { - conds = append(conds, builder.Eq{colName: nil}) - } - continue - } else if !fieldValue.IsValid() { - continue - } else { - // dereference ptr type to instance type - fieldValue = fieldValue.Elem() - fieldType = reflect.TypeOf(fieldValue.Interface()) - requiredField = true - } - } - - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - t := int64(fieldValue.Uint()) - val = reflect.ValueOf(&t).Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = engine.FormatTime(col.SQLType.Name, t) - } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { - continue - } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = valNul.Value() - if val == nil { - continue - } - } else { - if col.SQLType.IsJson() { - if col.SQLType.IsText() { - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - bytes, err = json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !isZero(pkField.Interface()) { - val = pkField.Interface() - } else { - continue - } - } else { - //TODO: how to handler? - panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys)) - } - } else { - val = fieldValue.Interface() - } - } - } - case reflect.Array: - continue - case reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - - if col.SQLType.IsText() { - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - - conds = append(conds, builder.Eq{colName: val}) - } - - return builder.And(conds...), nil -} - // TableName return current tableName func (statement *Statement) TableName() string { if statement.AltTableName != "" { @@ -810,6 +595,22 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { return newColumns } +func (statement *Statement) colmap2NewColsWithQuote() []string { + newColumns := make([]string, 0, len(statement.columnMap)) + for col := range statement.columnMap { + fields := strings.Split(strings.TrimSpace(col), ".") + if len(fields) == 1 { + newColumns = append(newColumns, statement.Engine.quote(fields[0])) + } else if len(fields) == 2 { + newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+ + statement.Engine.quote(fields[1])) + } else { + panic(errors.New("unwanted colnames")) + } + } + return newColumns +} + // Distinct generates "DISTINCT col1, col2 " statement func (statement *Statement) Distinct(columns ...string) *Statement { statement.IsDistinct = true @@ -836,7 +637,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { statement.columnMap[strings.ToLower(nc)] = true } - newColumns := statement.col2NewColsWithQuote(columns...) + newColumns := statement.colmap2NewColsWithQuote() statement.ColumnStr = strings.Join(newColumns, ", ") statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) return statement @@ -1104,26 +905,35 @@ func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interfa } func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { - return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, + return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) } -func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { +func (statement *Statement) mergeConds(bean interface{}) error { if !statement.noAutoCondition { var addedTableName = (len(statement.JoinStr) > 0) autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName) if err != nil { - return "", nil, err + return err } statement.cond = statement.cond.And(autoCond) } - statement.processIDParam() + if err := statement.processIDParam(); err != nil { + return err + } + return nil +} + +func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } return builder.ToSQL(statement.cond) } -func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) { +func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) { v := rValue(bean) isStruct := v.Kind() == reflect.Struct if isStruct { @@ -1156,21 +966,37 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) columnStr = "*" } - var condSQL string - var condArgs []interface{} if isStruct { - condSQL, condArgs, _ = statement.genConds(bean) - } else { - condSQL, condArgs, _ = builder.ToSQL(statement.cond) + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } + } + condSQL, condArgs, err := builder.ToSQL(statement.cond) + if err != nil { + return "", nil, err } - return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...) -} + sqlStr, err := statement.genSelectSQL(columnStr, condSQL) + if err != nil { + return "", nil, err + } -func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) { - statement.setRefValue(rValue(bean)) + return sqlStr, append(statement.joinArgs, condArgs...), nil +} - condSQL, condArgs, _ := statement.genConds(bean) +func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) { + var condSQL string + var condArgs []interface{} + var err error + if len(beans) > 0 { + statement.setRefValue(rValue(beans[0])) + condSQL, condArgs, err = statement.genConds(beans[0]) + } else { + condSQL, condArgs, err = builder.ToSQL(statement.cond) + } + if err != nil { + return "", nil, err + } var selectSQL = statement.selectStr if len(selectSQL) <= 0 { @@ -1180,10 +1006,15 @@ func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{} selectSQL = "count(*)" } } - return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...) + sqlStr, err := statement.genSelectSQL(selectSQL, condSQL) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) { +func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { statement.setRefValue(rValue(bean)) var sumStrs = make([]string, 0, len(columns)) @@ -1195,12 +1026,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri } sumSelect := strings.Join(sumStrs, ", ") - condSQL, condArgs, _ := statement.genConds(bean) + condSQL, condArgs, err := statement.genConds(bean) + if err != nil { + return "", nil, err + } + + sqlStr, err := statement.genSelectSQL(sumSelect, condSQL) + if err != nil { + return "", nil, err + } - return statement.genSelectSQL(sumSelect, condSQL), append(statement.joinArgs, condArgs...) + return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { +func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) { var distinct string if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " @@ -1211,7 +1050,9 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { var top string var mssqlCondi string - statement.processIDParam() + if err := statement.processIDParam(); err != nil { + return "", err + } var buf bytes.Buffer if len(condSQL) > 0 { @@ -1278,7 +1119,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { } // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern - a = fmt.Sprintf("SELECT %v%v%v%v%v", top, distinct, columnStr, fromStr, whereStr) + a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) if len(mssqlCondi) > 0 { if len(whereStr) > 0 { a += " AND " + mssqlCondi @@ -1314,19 +1155,23 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { return } -func (statement *Statement) processIDParam() { +func (statement *Statement) processIDParam() error { if statement.idParam == nil { - return + return nil + } + + if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) { + return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d", + len(statement.RefTable.PrimaryKeys), + len(*statement.idParam), + ) } for i, col := range statement.RefTable.PKColumns() { var colName = statement.colName(col, statement.TableName()) - if i < len(*(statement.idParam)) { - statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) - } else { - statement.cond = statement.cond.And(builder.Eq{colName: ""}) - } + statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) } + return nil } func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string { |