summaryrefslogtreecommitdiffstats
path: root/vendor/xorm.io/xorm/dialects/postgres.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/xorm.io/xorm/dialects/postgres.go')
-rw-r--r--vendor/xorm.io/xorm/dialects/postgres.go194
1 files changed, 159 insertions, 35 deletions
diff --git a/vendor/xorm.io/xorm/dialects/postgres.go b/vendor/xorm.io/xorm/dialects/postgres.go
index 9acf763ab4..96ebfc850d 100644
--- a/vendor/xorm.io/xorm/dialects/postgres.go
+++ b/vendor/xorm.io/xorm/dialects/postgres.go
@@ -6,6 +6,7 @@ package dialects
import (
"context"
+ "database/sql"
"errors"
"fmt"
"net/url"
@@ -777,12 +778,24 @@ var (
var (
// DefaultPostgresSchema default postgres schema
DefaultPostgresSchema = "public"
+ postgresColAliases = map[string]string{
+ "numeric": "decimal",
+ }
)
type postgres struct {
Base
}
+// Alias returns a alias of column
+func (db *postgres) Alias(col string) string {
+ v, ok := postgresColAliases[strings.ToLower(col)]
+ if ok {
+ return v
+ }
+ return col
+}
+
func (db *postgres) Init(uri *URI) error {
db.quoter = postgresQuoter
return db.Base.Init(db, uri)
@@ -797,7 +810,10 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas
var version string
if !rows.Next() {
- return nil, errors.New("Unknow version")
+ if rows.Err() != nil {
+ return nil, rows.Err()
+ }
+ return nil, errors.New("unknow version")
}
if err := rows.Scan(&version); err != nil {
@@ -860,21 +876,16 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) {
}
}
-// FormatBytes formats bytes
-func (db *postgres) FormatBytes(bs []byte) string {
- return fmt.Sprintf("E'\\x%x'", bs)
-}
-
func (db *postgres) SQLType(c *schemas.Column) string {
var res string
switch t := c.SQLType.Name; t {
- case schemas.TinyInt:
+ case schemas.TinyInt, schemas.UnsignedTinyInt:
res = schemas.SmallInt
return res
case schemas.Bit:
res = schemas.Boolean
return res
- case schemas.MediumInt, schemas.Int, schemas.Integer:
+ case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt:
if c.IsAutoIncrement {
return schemas.Serial
}
@@ -930,6 +941,21 @@ func (db *postgres) SQLType(c *schemas.Column) string {
return res
}
+func (db *postgres) ColumnTypeKind(t string) int {
+ switch strings.ToUpper(t) {
+ case "DATETIME", "TIMESTAMP":
+ return schemas.TIME_TYPE
+ case "VARCHAR", "TEXT":
+ return schemas.TEXT_TYPE
+ case "BIGINT", "BIGSERIAL", "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL", "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION":
+ return schemas.NUMERIC_TYPE
+ case "BOOL":
+ return schemas.BOOL_TYPE
+ default:
+ return schemas.UNKNOW_TYPE
+ }
+}
+
func (db *postgres) IsReserved(name string) bool {
_, ok := postgresReservedWords[strings.ToUpper(name)]
return ok
@@ -1039,7 +1065,10 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab
}
defer rows.Close()
- return rows.Next(), nil
+ if rows.Next() {
+ return true, nil
+ }
+ return false, rows.Err()
}
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
@@ -1169,7 +1198,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
}
}
if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok {
- return nil, nil, fmt.Errorf("Unknown colType: %s - %s", dataType, col.SQLType.Name)
+ return nil, nil, fmt.Errorf("unknown colType: %s - %s", dataType, col.SQLType.Name)
}
col.Length = maxLen
@@ -1177,19 +1206,22 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
if !col.DefaultIsEmpty {
if col.SQLType.IsText() {
if strings.HasSuffix(col.Default, "::character varying") {
- col.Default = strings.TrimRight(col.Default, "::character varying")
+ col.Default = strings.TrimSuffix(col.Default, "::character varying")
} else if !strings.HasPrefix(col.Default, "'") {
col.Default = "'" + col.Default + "'"
}
} else if col.SQLType.IsTime() {
if strings.HasSuffix(col.Default, "::timestamp without time zone") {
- col.Default = strings.TrimRight(col.Default, "::timestamp without time zone")
+ col.Default = strings.TrimSuffix(col.Default, "::timestamp without time zone")
}
}
}
cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
+ if rows.Err() != nil {
+ return nil, nil, rows.Err()
+ }
return colSeq, cols, nil
}
@@ -1220,6 +1252,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch
table.Name = name
tables = append(tables, table)
}
+ if rows.Err() != nil {
+ return nil, rows.Err()
+ }
return tables, nil
}
@@ -1236,7 +1271,7 @@ func getIndexColName(indexdef string) []string {
func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) {
args := []interface{}{tableName}
- s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
+ s := "SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1"
if len(db.getSchema()) != 0 {
args = append(args, db.getSchema())
s = s + " AND schemaname=$2"
@@ -1248,7 +1283,7 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN
}
defer rows.Close()
- indexes := make(map[string]*schemas.Index, 0)
+ indexes := make(map[string]*schemas.Index)
for rows.Next() {
var indexType int
var indexName, indexdef string
@@ -1290,6 +1325,9 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN
index.IsRegular = isRegular
indexes[index.Name] = index
}
+ if rows.Err() != nil {
+ return nil, rows.Err()
+ }
return indexes, nil
}
@@ -1298,18 +1336,11 @@ func (db *postgres) Filters() []Filter {
}
type pqDriver struct {
+ baseDriver
}
type values map[string]string
-func (vs values) Set(k, v string) {
- vs[k] = v
-}
-
-func (vs values) Get(k string) (v string) {
- return vs[k]
-}
-
func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
@@ -1329,30 +1360,94 @@ func parseURL(connstr string) (string, error) {
return "", nil
}
-func parseOpts(name string, o values) error {
- if len(name) == 0 {
- return fmt.Errorf("invalid options: %s", name)
+func parseOpts(urlStr string, o values) error {
+ if len(urlStr) == 0 {
+ return fmt.Errorf("invalid options: %s", urlStr)
}
- name = strings.TrimSpace(name)
+ urlStr = strings.TrimSpace(urlStr)
+
+ var (
+ inQuote bool
+ state int // 0 key, 1 space, 2 value, 3 equal
+ start int
+ key string
+ )
+ for i, c := range urlStr {
+ switch c {
+ case ' ':
+ if !inQuote {
+ if state == 2 {
+ state = 1
+ v := urlStr[start:i]
+ if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
+ v = v[1 : len(v)-1]
+ } else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") {
+ return fmt.Errorf("wrong single quote in %d of %s", i, urlStr)
+ }
+ o[key] = v
+ } else if state != 1 {
+ return fmt.Errorf("wrong format: %v", urlStr)
+ }
+ }
+ case '\'':
+ if state == 3 {
+ state = 2
+ start = i
+ } else if state != 2 {
+ return fmt.Errorf("wrong format: %v", urlStr)
+ }
+ inQuote = !inQuote
+ case '=':
+ if !inQuote {
+ if state != 0 {
+ return fmt.Errorf("wrong format: %v", urlStr)
+ }
+ key = urlStr[start:i]
+ state = 3
+ }
+ default:
+ if state == 3 {
+ state = 2
+ start = i
+ } else if state == 1 {
+ state = 0
+ start = i
+ }
+ }
- ps := strings.Split(name, " ")
- for _, p := range ps {
- kv := strings.Split(p, "=")
- if len(kv) < 2 {
- return fmt.Errorf("invalid option: %q", p)
+ if i == len(urlStr)-1 {
+ if state != 2 {
+ return errors.New("no value matched key")
+ }
+ v := urlStr[start : i+1]
+ if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") {
+ v = v[1 : len(v)-1]
+ } else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") {
+ return fmt.Errorf("wrong single quote in %d of %s", i, urlStr)
+ }
+ o[key] = v
}
- o.Set(kv[0], kv[1])
}
return nil
}
+func (p *pqDriver) Features() *DriverFeatures {
+ return &DriverFeatures{
+ SupportReturnInsertedID: false,
+ }
+}
+
func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
db := &URI{DBType: schemas.POSTGRES}
+
var err error
+ if strings.Contains(dataSourceName, "://") {
+ if !strings.HasPrefix(dataSourceName, "postgresql://") && !strings.HasPrefix(dataSourceName, "postgres://") {
+ return nil, fmt.Errorf("unsupported protocol %v", dataSourceName)
+ }
- if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
db.DBName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
@@ -1364,7 +1459,7 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return nil, err
}
- db.DBName = o.Get("dbname")
+ db.DBName = o["dbname"]
}
if db.DBName == "" {
@@ -1374,6 +1469,32 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
return db, nil
}
+func (p *pqDriver) GenScanResult(colType string) (interface{}, error) {
+ switch colType {
+ case "VARCHAR", "TEXT":
+ var s sql.NullString
+ return &s, nil
+ case "BIGINT", "BIGSERIAL":
+ var s sql.NullInt64
+ return &s, nil
+ case "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL":
+ var s sql.NullInt32
+ return &s, nil
+ case "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION":
+ var s sql.NullFloat64
+ return &s, nil
+ case "DATETIME", "TIMESTAMP":
+ var s sql.NullTime
+ return &s, nil
+ case "BOOL":
+ var s sql.NullBool
+ return &s, nil
+ default:
+ var r sql.RawBytes
+ return &r, nil
+ }
+}
+
type pqDriverPgx struct {
pqDriver
}
@@ -1401,6 +1522,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri
parts := strings.Split(defaultSchema, ",")
return strings.TrimSpace(parts[len(parts)-1]), nil
}
+ if rows.Err() != nil {
+ return "", rows.Err()
+ }
- return "", errors.New("No default schema")
+ return "", errors.New("no default schema")
}