diff options
author | Lunny Xiao <xiaolunwen@gmail.com> | 2017-02-20 19:33:10 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-02-20 19:33:10 +0800 |
commit | c5f8b96ddaefe83eb87a71a44c46cc13cc7130fa (patch) | |
tree | aa7872c353fa6b6e1b5b18191f929fede553b5d6 /vendor/github.com | |
parent | 04fdeb9d8d4dc3cf296d8354ee29f1d053154a54 (diff) | |
download | gitea-c5f8b96ddaefe83eb87a71a44c46cc13cc7130fa.tar.gz gitea-c5f8b96ddaefe83eb87a71a44c46cc13cc7130fa.zip |
update xorm for fixing bug on processor BeforeSet and AfterSet when Find a map (#987)
Diffstat (limited to 'vendor/github.com')
-rw-r--r-- | vendor/github.com/go-xorm/xorm/convert.go | 249 | ||||
-rw-r--r-- | vendor/github.com/go-xorm/xorm/helpers.go | 37 | ||||
-rw-r--r-- | vendor/github.com/go-xorm/xorm/rows.go | 3 | ||||
-rw-r--r-- | vendor/github.com/go-xorm/xorm/session.go | 91 | ||||
-rw-r--r-- | vendor/github.com/go-xorm/xorm/session_find.go | 128 | ||||
-rw-r--r-- | vendor/github.com/go-xorm/xorm/session_get.go | 2 |
6 files changed, 360 insertions, 150 deletions
diff --git a/vendor/github.com/go-xorm/xorm/convert.go b/vendor/github.com/go-xorm/xorm/convert.go new file mode 100644 index 0000000000..87f0d3f1ec --- /dev/null +++ b/vendor/github.com/go-xorm/xorm/convert.go @@ -0,0 +1,249 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "time" +) + +var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } else { + c := make([]byte, len(b)) + copy(c, b) + return c + } +} + +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(buf, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(buf, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(buf, rv.Bool()), true + case reflect.String: + s := rv.String() + return append(buf, s...), true + } + return +} + +// convertAssign copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. +func convertAssign(dest, src interface{}) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = string(s) + return nil + case *interface{}: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + } + + case time.Time: + switch d := dest.(type) { + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + } + case nil: + switch d := dest.(type) { + case *interface{}: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = nil + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = asString(src) + return nil + } + case *[]byte: + sv = reflect.ValueOf(src) + if b, ok := asBytes(nil, sv); ok { + *d = b + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *interface{}: + *d = src + return nil + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Ptr { + return errors.New("destination not a pointer") + } + if dpv.IsNil() { + return errNilPtr + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(cloneBytes(b))) + default: + dv.Set(sv) + } + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + switch dv.Kind() { + case reflect.Ptr: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } else { + dv.Set(reflect.New(dv.Type().Elem())) + return convertAssign(dv.Interface(), src) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + dv.SetString(asString(src)) + return nil + } + + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) +} diff --git a/vendor/github.com/go-xorm/xorm/helpers.go b/vendor/github.com/go-xorm/xorm/helpers.go index a015ca50b9..4e26e84143 100644 --- a/vendor/github.com/go-xorm/xorm/helpers.go +++ b/vendor/github.com/go-xorm/xorm/helpers.go @@ -17,74 +17,83 @@ import ( ) // str2PK convert string value to primary key value according to tp -func str2PK(s string, tp reflect.Type) (interface{}, error) { +func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { var err error var result interface{} + var defReturn = reflect.Zero(tp) + switch tp.Kind() { case reflect.Int: result, err = strconv.Atoi(s) if err != nil { - return nil, errors.New("convert " + s + " as int: " + err.Error()) + return defReturn, fmt.Errorf("convert %s as int: %s", s, err.Error()) } case reflect.Int8: x, err := strconv.Atoi(s) if err != nil { - return nil, errors.New("convert " + s + " as int16: " + err.Error()) + return defReturn, fmt.Errorf("convert %s as int8: %s", s, err.Error()) } result = int8(x) case reflect.Int16: x, err := strconv.Atoi(s) if err != nil { - return nil, errors.New("convert " + s + " as int16: " + err.Error()) + return defReturn, fmt.Errorf("convert %s as int16: %s", s, err.Error()) } result = int16(x) case reflect.Int32: x, err := strconv.Atoi(s) if err != nil { - return nil, errors.New("convert " + s + " as int32: " + err.Error()) + return defReturn, fmt.Errorf("convert %s as int32: %s", s, err.Error()) } result = int32(x) case reflect.Int64: result, err = strconv.ParseInt(s, 10, 64) if err != nil { - return nil, errors.New("convert " + s + " as int64: " + err.Error()) + return defReturn, fmt.Errorf("convert %s as int64: %s", s, err.Error()) } case reflect.Uint: x, err := strconv.ParseUint(s, 10, 64) if err != nil { - return nil, errors.New("convert " + s + " as uint: " + err.Error()) + return defReturn, fmt.Errorf("convert %s as uint: %s", s, err.Error()) } result = uint(x) case reflect.Uint8: x, err := strconv.ParseUint(s, 10, 64) if err != nil { - return nil, errors.New("convert " + s + " as uint8: " + err.Error()) + return defReturn, fmt.Errorf("convert %s as uint8: %s", s, err.Error()) } result = uint8(x) case reflect.Uint16: x, err := strconv.ParseUint(s, 10, 64) if err != nil { - return nil, errors.New("convert " + s + " as uint16: " + err.Error()) + return defReturn, fmt.Errorf("convert %s as uint16: %s", s, err.Error()) } result = uint16(x) case reflect.Uint32: x, err := strconv.ParseUint(s, 10, 64) if err != nil { - return nil, errors.New("convert " + s + " as uint32: " + err.Error()) + return defReturn, fmt.Errorf("convert %s as uint32: %s", s, err.Error()) } result = uint32(x) case reflect.Uint64: result, err = strconv.ParseUint(s, 10, 64) if err != nil { - return nil, errors.New("convert " + s + " as uint64: " + err.Error()) + return defReturn, fmt.Errorf("convert %s as uint64: %s", s, err.Error()) } case reflect.String: result = s default: - panic("unsupported convert type") + return defReturn, errors.New("unsupported convert type") } - result = reflect.ValueOf(result).Convert(tp).Interface() - return result, nil + return reflect.ValueOf(result).Convert(tp), nil +} + +func str2PK(s string, tp reflect.Type) (interface{}, error) { + v, err := str2PKValue(s, tp) + if err != nil { + return nil, err + } + return v.Interface(), nil } func splitTag(tag string) (tags []string) { diff --git a/vendor/github.com/go-xorm/xorm/rows.go b/vendor/github.com/go-xorm/xorm/rows.go index d35040cdf2..e9cf8597cc 100644 --- a/vendor/github.com/go-xorm/xorm/rows.go +++ b/vendor/github.com/go-xorm/xorm/rows.go @@ -114,7 +114,8 @@ func (rows *Rows) Scan(bean interface{}) error { return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) } - return rows.session.row2Bean(rows.rows, rows.fields, rows.fieldsCount, bean) + _, err := rows.session.row2Bean(rows.rows, rows.fields, rows.fieldsCount, bean) + return err } // Close session if session.IsAutoClose is true, and claimed any opened resources diff --git a/vendor/github.com/go-xorm/xorm/session.go b/vendor/github.com/go-xorm/xorm/session.go index 6e1b02afb0..2efc74b285 100644 --- a/vendor/github.com/go-xorm/xorm/session.go +++ b/vendor/github.com/go-xorm/xorm/session.go @@ -386,52 +386,6 @@ func cleanupProcessorsClosures(slices *[]func(interface{})) { } } -func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { - dataStruct := rValue(obj) - if dataStruct.Kind() != reflect.Struct { - return errors.New("Expected a pointer to a struct") - } - - var col *core.Column - session.Statement.setRefValue(dataStruct) - table := session.Statement.RefTable - tableName := session.Statement.tableName - - for key, data := range objMap { - if col = table.GetColumn(key); col == nil { - session.Engine.logger.Warnf("struct %v's has not field %v. %v", - table.Type.Name(), key, table.ColumnsSeq()) - continue - } - - fieldName := col.FieldName - fieldPath := strings.Split(fieldName, ".") - var fieldValue reflect.Value - if len(fieldPath) > 2 { - session.Engine.logger.Error("Unsupported mutliderive", fieldName) - continue - } else if len(fieldPath) == 2 { - parentField := dataStruct.FieldByName(fieldPath[0]) - if parentField.IsValid() { - fieldValue = parentField.FieldByName(fieldPath[1]) - } - } else { - fieldValue = dataStruct.FieldByName(fieldName) - } - if !fieldValue.IsValid() || !fieldValue.CanSet() { - session.Engine.logger.Warnf("table %v's column %v is not valid or cannot set", tableName, key) - continue - } - - err := session.bytes2Value(col, &fieldValue, data) - if err != nil { - return err - } - } - - return nil -} - func (session *Session) canCache() bool { if session.Statement.RefTable == nil || session.Statement.JoinStr != "" || @@ -485,24 +439,28 @@ type Cell *interface{} func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount int, table *core.Table, newElemFunc func() reflect.Value, - sliceValueSetFunc func(*reflect.Value)) error { + sliceValueSetFunc func(*reflect.Value, core.PK) error) error { for rows.Next() { var newValue = newElemFunc() bean := newValue.Interface() dataStruct := rValue(bean) - err := session._row2Bean(rows, fields, fieldsCount, bean, &dataStruct, table) + pk, err := session._row2Bean(rows, fields, fieldsCount, bean, &dataStruct, table) + if err != nil { + return err + } + + err = sliceValueSetFunc(&newValue, pk) if err != nil { return err } - sliceValueSetFunc(&newValue) } return nil } -func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}) error { +func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}) (core.PK, error) { dataStruct := rValue(bean) if dataStruct.Kind() != reflect.Struct { - return errors.New("Expected a pointer to a struct") + return nil, errors.New("Expected a pointer to a struct") } session.Statement.setRefValue(dataStruct) @@ -510,14 +468,14 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i return session._row2Bean(rows, fields, fieldsCount, bean, &dataStruct, session.Statement.RefTable) } -func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) error { +func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) { scanResults := make([]interface{}, fieldsCount) for i := 0; i < len(fields); i++ { var cell interface{} scanResults[i] = &cell } if err := rows.Scan(scanResults...); err != nil { - return err + return nil, err } if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet { @@ -535,6 +493,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount }() var tempMap = make(map[string]int) + var pk core.PK for ii, key := range fields { var idx int var ok bool @@ -579,10 +538,12 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount rawValueType := reflect.TypeOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface()) - + col := table.GetColumnIdx(key, idx) + if col.IsPrimaryKey { + pk = append(pk, rawValue.Interface()) + } fieldType := fieldValue.Type() hasAssigned := false - col := table.GetColumnIdx(key, idx) if col.SQLType.IsJson() { var bs []byte @@ -591,7 +552,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount } else if rawValueType.ConvertibleTo(core.BytesType) { bs = vv.Bytes() } else { - return fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) + return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) } hasAssigned = true @@ -601,14 +562,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount err := json.Unmarshal(bs, fieldValue.Addr().Interface()) if err != nil { session.Engine.logger.Error(key, err) - return err + return nil, err } } else { x := reflect.New(fieldType) err := json.Unmarshal(bs, x.Interface()) if err != nil { session.Engine.logger.Error(key, err) - return err + return nil, err } fieldValue.Set(x.Elem()) } @@ -633,14 +594,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount err := json.Unmarshal(bs, fieldValue.Addr().Interface()) if err != nil { session.Engine.logger.Error(err) - return err + return nil, err } } else { x := reflect.New(fieldType) err := json.Unmarshal(bs, x.Interface()) if err != nil { session.Engine.logger.Error(err) - return err + return nil, err } fieldValue.Set(x.Elem()) } @@ -772,7 +733,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount err := json.Unmarshal([]byte(vv.String()), x.Interface()) if err != nil { session.Engine.logger.Error(err) - return err + return nil, err } fieldValue.Set(x.Elem()) } @@ -783,7 +744,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount err := json.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { session.Engine.logger.Error(err) - return err + return nil, err } fieldValue.Set(x.Elem()) } @@ -835,14 +796,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount defer newsession.Close() has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) if err != nil { - return err + return nil, err } if has { //v := structInter.Elem().Interface() //fieldValue.Set(reflect.ValueOf(v)) fieldValue.Set(structInter.Elem()) } else { - return errors.New("cascade obj is not exist") + return nil, errors.New("cascade obj is not exist") } } } else { @@ -982,7 +943,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount } } } - return nil + return pk, nil } func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { diff --git a/vendor/github.com/go-xorm/xorm/session_find.go b/vendor/github.com/go-xorm/xorm/session_find.go index 2e52fff3c7..ff79033b91 100644 --- a/vendor/github.com/go-xorm/xorm/session_find.go +++ b/vendor/github.com/go-xorm/xorm/session_find.go @@ -43,14 +43,12 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) pv := reflect.New(sliceElementType.Elem()) session.Statement.setRefValue(pv.Elem()) } else { - //return errors.New("slice type") tp = tpNonStruct } } else if sliceElementType.Kind() == reflect.Struct { pv := reflect.New(sliceElementType) session.Statement.setRefValue(pv.Elem()) } else { - //return errors.New("slice type") tp = tpNonStruct } } @@ -148,62 +146,10 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } } - if sliceValue.Kind() != reflect.Map { - return session.noCacheFind(sliceValue, sqlStr, args...) - } - - resultsSlice, err := session.query(sqlStr, args...) - if err != nil { - return err - } - - keyType := sliceValue.Type().Key() - - for _, results := range resultsSlice { - var newValue reflect.Value - if sliceElementType.Kind() == reflect.Ptr { - newValue = reflect.New(sliceElementType.Elem()) - } else { - newValue = reflect.New(sliceElementType) - } - err := session.scanMapIntoStruct(newValue.Interface(), results) - if err != nil { - return err - } - var key interface{} - // if there is only one pk, we can put the id as map key. - if len(table.PrimaryKeys) == 1 { - key, err = str2PK(string(results[table.PrimaryKeys[0]]), keyType) - if err != nil { - return err - } - } else { - if keyType.Kind() != reflect.Slice { - panic("don't support multiple primary key's map has non-slice key type") - } else { - var keys core.PK = make([]interface{}, 0, len(table.PrimaryKeys)) - for _, pk := range table.PrimaryKeys { - skey, err := str2PK(string(results[pk]), keyType) - if err != nil { - return err - } - keys = append(keys, skey) - } - key = keys - } - } - - if sliceElementType.Kind() == reflect.Ptr { - sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValue.Interface())) - } else { - sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(newValue.Interface()))) - } - } - - return nil + return session.noCacheFind(table, sliceValue, sqlStr, args...) } -func (session *Session) noCacheFind(sliceValue reflect.Value, sqlStr string, args ...interface{}) error { +func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { var rawRows *core.Rows var err error @@ -224,27 +170,59 @@ func (session *Session) noCacheFind(sliceValue reflect.Value, sqlStr string, arg } var newElemFunc func() reflect.Value - sliceElementType := sliceValue.Type().Elem() - if sliceElementType.Kind() == reflect.Ptr { + elemType := containerValue.Type().Elem() + if elemType.Kind() == reflect.Ptr { newElemFunc = func() reflect.Value { - return reflect.New(sliceElementType.Elem()) + return reflect.New(elemType.Elem()) } } else { newElemFunc = func() reflect.Value { - return reflect.New(sliceElementType) + return reflect.New(elemType) } } - var sliceValueSetFunc func(*reflect.Value) + var containerValueSetFunc func(*reflect.Value, core.PK) error - if sliceValue.Kind() == reflect.Slice { - if sliceElementType.Kind() == reflect.Ptr { - sliceValueSetFunc = func(newValue *reflect.Value) { - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + if containerValue.Kind() == reflect.Slice { + if elemType.Kind() == reflect.Ptr { + containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { + containerValue.Set(reflect.Append(containerValue, reflect.ValueOf(newValue.Interface()))) + return nil } } else { - sliceValueSetFunc = func(newValue *reflect.Value) { - sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { + containerValue.Set(reflect.Append(containerValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + return nil + } + } + } else { + keyType := containerValue.Type().Key() + if len(table.PrimaryKeys) == 0 { + return errors.New("don't support multiple primary key's map has non-slice key type") + } + if len(table.PrimaryKeys) > 1 && keyType.Kind() != reflect.Slice { + return errors.New("don't support multiple primary key's map has non-slice key type") + } + + if elemType.Kind() == reflect.Ptr { + containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { + keyValue := reflect.New(keyType) + err := convertPKToValue(table, keyValue.Interface(), pk) + if err != nil { + return err + } + containerValue.SetMapIndex(keyValue.Elem(), reflect.ValueOf(newValue.Interface())) + return nil + } + } else { + containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { + keyValue := reflect.New(keyType) + err := convertPKToValue(table, keyValue.Interface(), pk) + if err != nil { + return err + } + containerValue.SetMapIndex(keyValue.Elem(), reflect.Indirect(reflect.ValueOf(newValue.Interface()))) + return nil } } } @@ -252,7 +230,7 @@ func (session *Session) noCacheFind(sliceValue reflect.Value, sqlStr string, arg var newValue = newElemFunc() dataStruct := rValue(newValue.Interface()) if dataStruct.Kind() == reflect.Struct { - return session.rows2Beans(rawRows, fields, len(fields), session.Engine.autoMapType(dataStruct), newElemFunc, sliceValueSetFunc) + return session.rows2Beans(rawRows, fields, len(fields), session.Engine.autoMapType(dataStruct), newElemFunc, containerValueSetFunc) } for rawRows.Next() { @@ -263,8 +241,20 @@ func (session *Session) noCacheFind(sliceValue reflect.Value, sqlStr string, arg return err } - sliceValueSetFunc(&newValue) + if err := containerValueSetFunc(&newValue, nil); err != nil { + return err + } + } + return nil +} + +func convertPKToValue(table *core.Table, dst interface{}, pk core.PK) error { + cols := table.PKColumns() + if len(cols) == 1 { + return convertAssign(dst, pk[0]) } + + dst = pk return nil } diff --git a/vendor/github.com/go-xorm/xorm/session_get.go b/vendor/github.com/go-xorm/xorm/session_get.go index f32bf4810f..ac0c5ebbf7 100644 --- a/vendor/github.com/go-xorm/xorm/session_get.go +++ b/vendor/github.com/go-xorm/xorm/session_get.go @@ -67,7 +67,7 @@ func (session *Session) nocacheGet(bean interface{}, sqlStr string, args ...inte if rawRows.Next() { fields, err := rawRows.Columns() if err == nil { - err = session.row2Bean(rawRows, fields, len(fields), bean) + _, err = session.row2Bean(rawRows, fields, len(fields), bean) } return true, err } |