diff options
Diffstat (limited to 'vendor/xorm.io/xorm/dialects/postgres.go')
-rw-r--r-- | vendor/xorm.io/xorm/dialects/postgres.go | 194 |
1 files changed, 159 insertions, 35 deletions
diff --git a/vendor/xorm.io/xorm/dialects/postgres.go b/vendor/xorm.io/xorm/dialects/postgres.go index 9acf763ab4..96ebfc850d 100644 --- a/vendor/xorm.io/xorm/dialects/postgres.go +++ b/vendor/xorm.io/xorm/dialects/postgres.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "net/url" @@ -777,12 +778,24 @@ var ( var ( // DefaultPostgresSchema default postgres schema DefaultPostgresSchema = "public" + postgresColAliases = map[string]string{ + "numeric": "decimal", + } ) type postgres struct { Base } +// Alias returns a alias of column +func (db *postgres) Alias(col string) string { + v, ok := postgresColAliases[strings.ToLower(col)] + if ok { + return v + } + return col +} + func (db *postgres) Init(uri *URI) error { db.quoter = postgresQuoter return db.Base.Init(db, uri) @@ -797,7 +810,10 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas var version string if !rows.Next() { - return nil, errors.New("Unknow version") + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -860,21 +876,16 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { } } -// FormatBytes formats bytes -func (db *postgres) FormatBytes(bs []byte) string { - return fmt.Sprintf("E'\\x%x'", bs) -} - func (db *postgres) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case schemas.TinyInt: + case schemas.TinyInt, schemas.UnsignedTinyInt: res = schemas.SmallInt return res case schemas.Bit: res = schemas.Boolean return res - case schemas.MediumInt, schemas.Int, schemas.Integer: + case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt: if c.IsAutoIncrement { return schemas.Serial } @@ -930,6 +941,21 @@ func (db *postgres) SQLType(c *schemas.Column) string { return res } +func (db *postgres) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATETIME", "TIMESTAMP": + return schemas.TIME_TYPE + case "VARCHAR", "TEXT": + return schemas.TEXT_TYPE + case "BIGINT", "BIGSERIAL", "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL", "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION": + return schemas.NUMERIC_TYPE + case "BOOL": + return schemas.BOOL_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *postgres) IsReserved(name string) bool { _, ok := postgresReservedWords[strings.ToUpper(name)] return ok @@ -1039,7 +1065,10 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab } defer rows.Close() - return rows.Next(), nil + if rows.Next() { + return true, nil + } + return false, rows.Err() } func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { @@ -1169,7 +1198,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A } } if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { - return nil, nil, fmt.Errorf("Unknown colType: %s - %s", dataType, col.SQLType.Name) + return nil, nil, fmt.Errorf("unknown colType: %s - %s", dataType, col.SQLType.Name) } col.Length = maxLen @@ -1177,19 +1206,22 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A if !col.DefaultIsEmpty { if col.SQLType.IsText() { if strings.HasSuffix(col.Default, "::character varying") { - col.Default = strings.TrimRight(col.Default, "::character varying") + col.Default = strings.TrimSuffix(col.Default, "::character varying") } else if !strings.HasPrefix(col.Default, "'") { col.Default = "'" + col.Default + "'" } } else if col.SQLType.IsTime() { if strings.HasSuffix(col.Default, "::timestamp without time zone") { - col.Default = strings.TrimRight(col.Default, "::timestamp without time zone") + col.Default = strings.TrimSuffix(col.Default, "::timestamp without time zone") } } } cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -1220,6 +1252,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch table.Name = name tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -1236,7 +1271,7 @@ func getIndexColName(indexdef string) []string { func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} - s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") + s := "SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1" if len(db.getSchema()) != 0 { args = append(args, db.getSchema()) s = s + " AND schemaname=$2" @@ -1248,7 +1283,7 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { var indexType int var indexName, indexdef string @@ -1290,6 +1325,9 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -1298,18 +1336,11 @@ func (db *postgres) Filters() []Filter { } type pqDriver struct { + baseDriver } type values map[string]string -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - func parseURL(connstr string) (string, error) { u, err := url.Parse(connstr) if err != nil { @@ -1329,30 +1360,94 @@ func parseURL(connstr string) (string, error) { return "", nil } -func parseOpts(name string, o values) error { - if len(name) == 0 { - return fmt.Errorf("invalid options: %s", name) +func parseOpts(urlStr string, o values) error { + if len(urlStr) == 0 { + return fmt.Errorf("invalid options: %s", urlStr) } - name = strings.TrimSpace(name) + urlStr = strings.TrimSpace(urlStr) + + var ( + inQuote bool + state int // 0 key, 1 space, 2 value, 3 equal + start int + key string + ) + for i, c := range urlStr { + switch c { + case ' ': + if !inQuote { + if state == 2 { + state = 1 + v := urlStr[start:i] + if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") { + v = v[1 : len(v)-1] + } else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") { + return fmt.Errorf("wrong single quote in %d of %s", i, urlStr) + } + o[key] = v + } else if state != 1 { + return fmt.Errorf("wrong format: %v", urlStr) + } + } + case '\'': + if state == 3 { + state = 2 + start = i + } else if state != 2 { + return fmt.Errorf("wrong format: %v", urlStr) + } + inQuote = !inQuote + case '=': + if !inQuote { + if state != 0 { + return fmt.Errorf("wrong format: %v", urlStr) + } + key = urlStr[start:i] + state = 3 + } + default: + if state == 3 { + state = 2 + start = i + } else if state == 1 { + state = 0 + start = i + } + } - ps := strings.Split(name, " ") - for _, p := range ps { - kv := strings.Split(p, "=") - if len(kv) < 2 { - return fmt.Errorf("invalid option: %q", p) + if i == len(urlStr)-1 { + if state != 2 { + return errors.New("no value matched key") + } + v := urlStr[start : i+1] + if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") { + v = v[1 : len(v)-1] + } else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") { + return fmt.Errorf("wrong single quote in %d of %s", i, urlStr) + } + o[key] = v } - o.Set(kv[0], kv[1]) } return nil } +func (p *pqDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: false, + } +} + func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.POSTGRES} + var err error + if strings.Contains(dataSourceName, "://") { + if !strings.HasPrefix(dataSourceName, "postgresql://") && !strings.HasPrefix(dataSourceName, "postgres://") { + return nil, fmt.Errorf("unsupported protocol %v", dataSourceName) + } - if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { db.DBName, err = parseURL(dataSourceName) if err != nil { return nil, err @@ -1364,7 +1459,7 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { return nil, err } - db.DBName = o.Get("dbname") + db.DBName = o["dbname"] } if db.DBName == "" { @@ -1374,6 +1469,32 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { return db, nil } +func (p *pqDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT": + var s sql.NullString + return &s, nil + case "BIGINT", "BIGSERIAL": + var s sql.NullInt64 + return &s, nil + case "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL": + var s sql.NullInt32 + return &s, nil + case "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION": + var s sql.NullFloat64 + return &s, nil + case "DATETIME", "TIMESTAMP": + var s sql.NullTime + return &s, nil + case "BOOL": + var s sql.NullBool + return &s, nil + default: + var r sql.RawBytes + return &r, nil + } +} + type pqDriverPgx struct { pqDriver } @@ -1401,6 +1522,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri parts := strings.Split(defaultSchema, ",") return strings.TrimSpace(parts[len(parts)-1]), nil } + if rows.Err() != nil { + return "", rows.Err() + } - return "", errors.New("No default schema") + return "", errors.New("no default schema") } |