* fix forgot removed records when deleting user * fix migration * fix rewritekey lock on sqlite * remove unused codestags/v1.7.0-dev
@@ -390,12 +390,12 @@ | |||
revision = "d523deb1b23d913de5bdada721a6071e71283618" | |||
[[projects]] | |||
digest = "1:1397763fd29d5667bcfbacde8a37542ee145416b3acb7e1b149c98fef2567930" | |||
digest = "1:06d21295033f211588d0ad7ff391cc1b27e72b60cb6d4b7db0d70cffae4cf228" | |||
name = "github.com/go-xorm/builder" | |||
packages = ["."] | |||
pruneopts = "NUT" | |||
revision = "dc8bf48f58fab2b4da338ffd25191905fd741b8f" | |||
version = "v0.3.0" | |||
revision = "03eb88feccce3e477c318ce7f6f1b386544ab20b" | |||
version = "v0.3.3" | |||
[[projects]] | |||
digest = "1:c910feae32bcc3cbf068c7263424d9f198da931c0cad909179621835e6f87cb8" |
@@ -40,6 +40,10 @@ ignored = ["google.golang.org/appengine*"] | |||
name = "github.com/go-xorm/xorm" | |||
revision = "401f4ee8ff8cbc40a4754cb12192fbe4f02f3979" | |||
[[override]] | |||
name = "github.com/go-xorm/builder" | |||
version = "0.3.3" | |||
[[override]] | |||
name = "github.com/go-sql-driver/mysql" | |||
revision = "d523deb1b23d913de5bdada721a6071e71283618" |
@@ -202,6 +202,8 @@ var migrations = []Migration{ | |||
NewMigration("add must_change_password column for users table", addMustChangePassword), | |||
// v74 -> v75 | |||
NewMigration("add approval whitelists to protected branches", addApprovalWhitelistsToProtectedBranches), | |||
// v75 -> v76 | |||
NewMigration("clear nonused data which not deleted when user was deleted", clearNonusedData), | |||
} | |||
// Migrate database to current version |
@@ -0,0 +1,33 @@ | |||
// Copyright 2018 The Gitea Authors. All rights reserved. | |||
// Use of this source code is governed by a MIT-style | |||
// license that can be found in the LICENSE file. | |||
package migrations | |||
import ( | |||
"github.com/go-xorm/builder" | |||
"github.com/go-xorm/xorm" | |||
) | |||
func clearNonusedData(x *xorm.Engine) error { | |||
condDelete := func(colName string) builder.Cond { | |||
return builder.NotIn(colName, builder.Select("id").From("user")) | |||
} | |||
if _, err := x.Exec(builder.Delete(condDelete("uid")).From("team_user")); err != nil { | |||
return err | |||
} | |||
if _, err := x.Exec(builder.Delete(condDelete("user_id")).From("collaboration")); err != nil { | |||
return err | |||
} | |||
if _, err := x.Exec(builder.Delete(condDelete("user_id")).From("stop_watch")); err != nil { | |||
return err | |||
} | |||
if _, err := x.Exec(builder.Delete(condDelete("owner_id")).From("gpg_key")); err != nil { | |||
return err | |||
} | |||
return nil | |||
} |
@@ -549,6 +549,7 @@ func DeletePublicKey(doer *User, id int64) (err error) { | |||
if err = sess.Commit(); err != nil { | |||
return err | |||
} | |||
sess.Close() | |||
return RewriteAllPublicKeys() | |||
} | |||
@@ -557,6 +558,10 @@ func DeletePublicKey(doer *User, id int64) (err error) { | |||
// Note: x.Iterate does not get latest data after insert/delete, so we have to call this function | |||
// outside any session scope independently. | |||
func RewriteAllPublicKeys() error { | |||
return rewriteAllPublicKeys(x) | |||
} | |||
func rewriteAllPublicKeys(e Engine) error { | |||
//Don't rewrite key if internal server | |||
if setting.SSH.StartBuiltinServer || !setting.SSH.CreateAuthorizedKeysFile { | |||
return nil | |||
@@ -583,7 +588,7 @@ func RewriteAllPublicKeys() error { | |||
} | |||
} | |||
err = x.Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) { | |||
err = e.Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) { | |||
_, err = t.WriteString((bean.(*PublicKey)).AuthorizedString()) | |||
return err | |||
}) |
@@ -1015,25 +1015,26 @@ func deleteUser(e *xorm.Session, u *User) error { | |||
&EmailAddress{UID: u.ID}, | |||
&UserOpenID{UID: u.ID}, | |||
&Reaction{UserID: u.ID}, | |||
&TeamUser{UID: u.ID}, | |||
&Collaboration{UserID: u.ID}, | |||
&Stopwatch{UserID: u.ID}, | |||
); err != nil { | |||
return fmt.Errorf("deleteBeans: %v", err) | |||
} | |||
// ***** START: PublicKey ***** | |||
keys := make([]*PublicKey, 0, 10) | |||
if err = e.Find(&keys, &PublicKey{OwnerID: u.ID}); err != nil { | |||
return fmt.Errorf("get all public keys: %v", err) | |||
} | |||
keyIDs := make([]int64, len(keys)) | |||
for i := range keys { | |||
keyIDs[i] = keys[i].ID | |||
} | |||
if err = deletePublicKeys(e, keyIDs...); err != nil { | |||
if _, err = e.Delete(&PublicKey{OwnerID: u.ID}); err != nil { | |||
return fmt.Errorf("deletePublicKeys: %v", err) | |||
} | |||
rewriteAllPublicKeys(e) | |||
// ***** END: PublicKey ***** | |||
// ***** START: GPGPublicKey ***** | |||
if _, err = e.Delete(&GPGKey{OwnerID: u.ID}); err != nil { | |||
return fmt.Errorf("deleteGPGKeys: %v", err) | |||
} | |||
// ***** END: GPGPublicKey ***** | |||
// Clear assignee. | |||
if err = clearAssigneeByUserID(e, u.ID); err != nil { | |||
return fmt.Errorf("clear assignee: %v", err) | |||
@@ -1084,11 +1085,7 @@ func DeleteUser(u *User) (err error) { | |||
return err | |||
} | |||
if err = sess.Commit(); err != nil { | |||
return err | |||
} | |||
return RewriteAllPublicKeys() | |||
return sess.Commit() | |||
} | |||
// DeleteInactivateUsers deletes all inactivate users and email addresses. |
@@ -5,7 +5,9 @@ | |||
package builder | |||
import ( | |||
sql2 "database/sql" | |||
"fmt" | |||
"sort" | |||
) | |||
type optype byte | |||
@@ -16,6 +18,15 @@ const ( | |||
insertType // insert | |||
updateType // update | |||
deleteType // delete | |||
unionType // union | |||
) | |||
const ( | |||
POSTGRES = "postgres" | |||
SQLITE = "sqlite3" | |||
MYSQL = "mysql" | |||
MSSQL = "mssql" | |||
ORACLE = "oracle" | |||
) | |||
type join struct { | |||
@@ -24,68 +35,115 @@ type join struct { | |||
joinCond Cond | |||
} | |||
type union struct { | |||
unionType string | |||
builder *Builder | |||
} | |||
type limit struct { | |||
limitN int | |||
offset int | |||
} | |||
// Builder describes a SQL statement | |||
type Builder struct { | |||
optype | |||
tableName string | |||
cond Cond | |||
selects []string | |||
joins []join | |||
inserts Eq | |||
updates []Eq | |||
orderBy string | |||
groupBy string | |||
having string | |||
dialect string | |||
isNested bool | |||
into string | |||
from string | |||
subQuery *Builder | |||
cond Cond | |||
selects []string | |||
joins []join | |||
unions []union | |||
limitation *limit | |||
insertCols []string | |||
insertVals []interface{} | |||
updates []Eq | |||
orderBy string | |||
groupBy string | |||
having string | |||
} | |||
// Dialect sets the db dialect of Builder. | |||
func Dialect(dialect string) *Builder { | |||
builder := &Builder{cond: NewCond(), dialect: dialect} | |||
return builder | |||
} | |||
// Select creates a select Builder | |||
func Select(cols ...string) *Builder { | |||
builder := &Builder{cond: NewCond()} | |||
return builder.Select(cols...) | |||
// MySQL is shortcut of Dialect(MySQL) | |||
func MySQL() *Builder { | |||
return Dialect(MYSQL) | |||
} | |||
// Insert creates an insert Builder | |||
func Insert(eq Eq) *Builder { | |||
builder := &Builder{cond: NewCond()} | |||
return builder.Insert(eq) | |||
// MsSQL is shortcut of Dialect(MsSQL) | |||
func MsSQL() *Builder { | |||
return Dialect(MSSQL) | |||
} | |||
// Update creates an update Builder | |||
func Update(updates ...Eq) *Builder { | |||
builder := &Builder{cond: NewCond()} | |||
return builder.Update(updates...) | |||
// Oracle is shortcut of Dialect(Oracle) | |||
func Oracle() *Builder { | |||
return Dialect(ORACLE) | |||
} | |||
// Delete creates a delete Builder | |||
func Delete(conds ...Cond) *Builder { | |||
builder := &Builder{cond: NewCond()} | |||
return builder.Delete(conds...) | |||
// Postgres is shortcut of Dialect(Postgres) | |||
func Postgres() *Builder { | |||
return Dialect(POSTGRES) | |||
} | |||
// SQLite is shortcut of Dialect(SQLITE) | |||
func SQLite() *Builder { | |||
return Dialect(SQLITE) | |||
} | |||
// Where sets where SQL | |||
func (b *Builder) Where(cond Cond) *Builder { | |||
b.cond = b.cond.And(cond) | |||
if b.cond.IsValid() { | |||
b.cond = b.cond.And(cond) | |||
} else { | |||
b.cond = cond | |||
} | |||
return b | |||
} | |||
// From sets the table name | |||
func (b *Builder) From(tableName string) *Builder { | |||
b.tableName = tableName | |||
// From sets from subject(can be a table name in string or a builder pointer) and its alias | |||
func (b *Builder) From(subject interface{}, alias ...string) *Builder { | |||
switch subject.(type) { | |||
case *Builder: | |||
b.subQuery = subject.(*Builder) | |||
if len(alias) > 0 { | |||
b.from = alias[0] | |||
} else { | |||
b.isNested = true | |||
} | |||
case string: | |||
b.from = subject.(string) | |||
if len(alias) > 0 { | |||
b.from = b.from + " " + alias[0] | |||
} | |||
} | |||
return b | |||
} | |||
// TableName returns the table name | |||
func (b *Builder) TableName() string { | |||
return b.tableName | |||
if b.optype == insertType { | |||
return b.into | |||
} | |||
return b.from | |||
} | |||
// Into sets insert table name | |||
func (b *Builder) Into(tableName string) *Builder { | |||
b.tableName = tableName | |||
b.into = tableName | |||
return b | |||
} | |||
// Join sets join table and contions | |||
// Join sets join table and conditions | |||
func (b *Builder) Join(joinType, joinTable string, joinCond interface{}) *Builder { | |||
switch joinCond.(type) { | |||
case Cond: | |||
@@ -97,6 +155,50 @@ func (b *Builder) Join(joinType, joinTable string, joinCond interface{}) *Builde | |||
return b | |||
} | |||
// Union sets union conditions | |||
func (b *Builder) Union(unionTp string, unionCond *Builder) *Builder { | |||
var builder *Builder | |||
if b.optype != unionType { | |||
builder = &Builder{cond: NewCond()} | |||
builder.optype = unionType | |||
builder.dialect = b.dialect | |||
builder.selects = b.selects | |||
currentUnions := b.unions | |||
// erase sub unions (actually append to new Builder.unions) | |||
b.unions = nil | |||
for e := range currentUnions { | |||
currentUnions[e].builder.dialect = b.dialect | |||
} | |||
builder.unions = append(append(builder.unions, union{"", b}), currentUnions...) | |||
} else { | |||
builder = b | |||
} | |||
if unionCond != nil { | |||
if unionCond.dialect == "" && builder.dialect != "" { | |||
unionCond.dialect = builder.dialect | |||
} | |||
builder.unions = append(builder.unions, union{unionTp, unionCond}) | |||
} | |||
return builder | |||
} | |||
// Limit sets limitN condition | |||
func (b *Builder) Limit(limitN int, offset ...int) *Builder { | |||
b.limitation = &limit{limitN: limitN} | |||
if len(offset) > 0 { | |||
b.limitation.offset = offset[0] | |||
} | |||
return b | |||
} | |||
// InnerJoin sets inner join | |||
func (b *Builder) InnerJoin(joinTable string, joinCond interface{}) *Builder { | |||
return b.Join("INNER", joinTable, joinCond) | |||
@@ -125,7 +227,9 @@ func (b *Builder) FullJoin(joinTable string, joinCond interface{}) *Builder { | |||
// Select sets select SQL | |||
func (b *Builder) Select(cols ...string) *Builder { | |||
b.selects = cols | |||
b.optype = selectType | |||
if b.optype == condType { | |||
b.optype = selectType | |||
} | |||
return b | |||
} | |||
@@ -141,16 +245,70 @@ func (b *Builder) Or(cond Cond) *Builder { | |||
return b | |||
} | |||
type insertColsSorter struct { | |||
cols []string | |||
vals []interface{} | |||
} | |||
func (s insertColsSorter) Len() int { | |||
return len(s.cols) | |||
} | |||
func (s insertColsSorter) Swap(i, j int) { | |||
s.cols[i], s.cols[j] = s.cols[j], s.cols[i] | |||
s.vals[i], s.vals[j] = s.vals[j], s.vals[i] | |||
} | |||
func (s insertColsSorter) Less(i, j int) bool { | |||
return s.cols[i] < s.cols[j] | |||
} | |||
// Insert sets insert SQL | |||
func (b *Builder) Insert(eq Eq) *Builder { | |||
b.inserts = eq | |||
func (b *Builder) Insert(eq ...interface{}) *Builder { | |||
if len(eq) > 0 { | |||
var paramType = -1 | |||
for _, e := range eq { | |||
switch t := e.(type) { | |||
case Eq: | |||
if paramType == -1 { | |||
paramType = 0 | |||
} | |||
if paramType != 0 { | |||
break | |||
} | |||
for k, v := range t { | |||
b.insertCols = append(b.insertCols, k) | |||
b.insertVals = append(b.insertVals, v) | |||
} | |||
case string: | |||
if paramType == -1 { | |||
paramType = 1 | |||
} | |||
if paramType != 1 { | |||
break | |||
} | |||
b.insertCols = append(b.insertCols, t) | |||
} | |||
} | |||
} | |||
if len(b.insertCols) == len(b.insertVals) { | |||
sort.Sort(insertColsSorter{ | |||
cols: b.insertCols, | |||
vals: b.insertVals, | |||
}) | |||
} | |||
b.optype = insertType | |||
return b | |||
} | |||
// Update sets update SQL | |||
func (b *Builder) Update(updates ...Eq) *Builder { | |||
b.updates = updates | |||
b.updates = make([]Eq, 0, len(updates)) | |||
for _, update := range updates { | |||
if update.IsValid() { | |||
b.updates = append(b.updates, update) | |||
} | |||
} | |||
b.optype = updateType | |||
return b | |||
} | |||
@@ -165,8 +323,8 @@ func (b *Builder) Delete(conds ...Cond) *Builder { | |||
// WriteTo implements Writer interface | |||
func (b *Builder) WriteTo(w Writer) error { | |||
switch b.optype { | |||
case condType: | |||
return b.cond.WriteTo(w) | |||
/*case condType: | |||
return b.cond.WriteTo(w)*/ | |||
case selectType: | |||
return b.selectWriteTo(w) | |||
case insertType: | |||
@@ -175,6 +333,8 @@ func (b *Builder) WriteTo(w Writer) error { | |||
return b.updateWriteTo(w) | |||
case deleteType: | |||
return b.deleteWriteTo(w) | |||
case unionType: | |||
return b.unionWriteTo(w) | |||
} | |||
return ErrNotSupportType | |||
@@ -187,43 +347,48 @@ func (b *Builder) ToSQL() (string, []interface{}, error) { | |||
return "", nil, err | |||
} | |||
return w.writer.String(), w.args, nil | |||
} | |||
// in case of sql.NamedArg in args | |||
for e := range w.args { | |||
if namedArg, ok := w.args[e].(sql2.NamedArg); ok { | |||
w.args[e] = namedArg.Value | |||
} | |||
} | |||
// ConvertPlaceholder replaces ? to $1, $2 ... or :1, :2 ... according prefix | |||
func ConvertPlaceholder(sql, prefix string) (string, error) { | |||
buf := StringBuilder{} | |||
var j, start = 0, 0 | |||
for i := 0; i < len(sql); i++ { | |||
if sql[i] == '?' { | |||
_, err := buf.WriteString(sql[start:i]) | |||
if err != nil { | |||
return "", err | |||
} | |||
start = i + 1 | |||
var sql = w.writer.String() | |||
var err error | |||
_, err = buf.WriteString(prefix) | |||
if err != nil { | |||
return "", err | |||
} | |||
switch b.dialect { | |||
case ORACLE, MSSQL: | |||
// This is for compatibility with different sql drivers | |||
for e := range w.args { | |||
w.args[e] = sql2.Named(fmt.Sprintf("p%d", e+1), w.args[e]) | |||
} | |||
j = j + 1 | |||
_, err = buf.WriteString(fmt.Sprintf("%d", j)) | |||
if err != nil { | |||
return "", err | |||
} | |||
var prefix string | |||
if b.dialect == ORACLE { | |||
prefix = ":p" | |||
} else { | |||
prefix = "@p" | |||
} | |||
if sql, err = ConvertPlaceholder(sql, prefix); err != nil { | |||
return "", nil, err | |||
} | |||
case POSTGRES: | |||
if sql, err = ConvertPlaceholder(sql, "$"); err != nil { | |||
return "", nil, err | |||
} | |||
} | |||
return buf.String(), nil | |||
return sql, w.args, nil | |||
} | |||
// ToSQL convert a builder or condtions to SQL and args | |||
func ToSQL(cond interface{}) (string, []interface{}, error) { | |||
switch cond.(type) { | |||
case Cond: | |||
return condToSQL(cond.(Cond)) | |||
case *Builder: | |||
return cond.(*Builder).ToSQL() | |||
// ToBoundSQL | |||
func (b *Builder) ToBoundSQL() (string, error) { | |||
w := NewWriter() | |||
if err := b.WriteTo(w); err != nil { | |||
return "", err | |||
} | |||
return "", nil, ErrNotSupportType | |||
return ConvertToBoundSQL(w.writer.String(), w.args) | |||
} |
@@ -5,16 +5,21 @@ | |||
package builder | |||
import ( | |||
"errors" | |||
"fmt" | |||
) | |||
// Delete creates a delete Builder | |||
func Delete(conds ...Cond) *Builder { | |||
builder := &Builder{cond: NewCond()} | |||
return builder.Delete(conds...) | |||
} | |||
func (b *Builder) deleteWriteTo(w Writer) error { | |||
if len(b.tableName) <= 0 { | |||
return errors.New("no table indicated") | |||
if len(b.from) <= 0 { | |||
return ErrNoTableName | |||
} | |||
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.tableName); err != nil { | |||
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.from); err != nil { | |||
return err | |||
} | |||
@@ -6,39 +6,63 @@ package builder | |||
import ( | |||
"bytes" | |||
"errors" | |||
"fmt" | |||
) | |||
// Insert creates an insert Builder | |||
func Insert(eq ...interface{}) *Builder { | |||
builder := &Builder{cond: NewCond()} | |||
return builder.Insert(eq...) | |||
} | |||
func (b *Builder) insertSelectWriteTo(w Writer) error { | |||
if _, err := fmt.Fprintf(w, "INSERT INTO %s ", b.into); err != nil { | |||
return err | |||
} | |||
if len(b.insertCols) > 0 { | |||
fmt.Fprintf(w, "(") | |||
for _, col := range b.insertCols { | |||
fmt.Fprintf(w, col) | |||
} | |||
fmt.Fprintf(w, ") ") | |||
} | |||
return b.selectWriteTo(w) | |||
} | |||
func (b *Builder) insertWriteTo(w Writer) error { | |||
if len(b.tableName) <= 0 { | |||
return errors.New("no table indicated") | |||
if len(b.into) <= 0 { | |||
return ErrNoTableName | |||
} | |||
if len(b.insertCols) <= 0 && b.from == "" { | |||
return ErrNoColumnToInsert | |||
} | |||
if len(b.inserts) <= 0 { | |||
return errors.New("no column to be insert") | |||
if b.into != "" && b.from != "" { | |||
return b.insertSelectWriteTo(w) | |||
} | |||
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil { | |||
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.into); err != nil { | |||
return err | |||
} | |||
var args = make([]interface{}, 0) | |||
var bs []byte | |||
var valBuffer = bytes.NewBuffer(bs) | |||
var i = 0 | |||
for _, col := range b.inserts.sortedKeys() { | |||
value := b.inserts[col] | |||
for i, col := range b.insertCols { | |||
value := b.insertVals[i] | |||
fmt.Fprint(w, col) | |||
if e, ok := value.(expr); ok { | |||
fmt.Fprint(valBuffer, e.sql) | |||
fmt.Fprintf(valBuffer, "(%s)", e.sql) | |||
args = append(args, e.args...) | |||
} else { | |||
fmt.Fprint(valBuffer, "?") | |||
args = append(args, value) | |||
} | |||
if i != len(b.inserts)-1 { | |||
if i != len(b.insertCols)-1 { | |||
if _, err := fmt.Fprint(w, ","); err != nil { | |||
return err | |||
} | |||
@@ -46,7 +70,6 @@ func (b *Builder) insertWriteTo(w Writer) error { | |||
return err | |||
} | |||
} | |||
i = i + 1 | |||
} | |||
if _, err := fmt.Fprint(w, ") Values ("); err != nil { |
@@ -0,0 +1,100 @@ | |||
// Copyright 2018 The Xorm Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package builder | |||
import ( | |||
"fmt" | |||
"strings" | |||
) | |||
func (b *Builder) limitWriteTo(w Writer) error { | |||
if strings.TrimSpace(b.dialect) == "" { | |||
return ErrDialectNotSetUp | |||
} | |||
if b.limitation != nil { | |||
limit := b.limitation | |||
if limit.offset < 0 || limit.limitN <= 0 { | |||
return ErrInvalidLimitation | |||
} | |||
// erase limit condition | |||
b.limitation = nil | |||
ow := w.(*BytesWriter) | |||
switch strings.ToLower(strings.TrimSpace(b.dialect)) { | |||
case ORACLE: | |||
if len(b.selects) == 0 { | |||
b.selects = append(b.selects, "*") | |||
} | |||
var final *Builder | |||
selects := b.selects | |||
b.selects = append(selects, "ROWNUM RN") | |||
var wb *Builder | |||
if b.optype == unionType { | |||
wb = Dialect(b.dialect).Select("at.*", "ROWNUM RN"). | |||
From(b, "at") | |||
} else { | |||
wb = b | |||
} | |||
if limit.offset == 0 { | |||
final = Dialect(b.dialect).Select(selects...).From(wb, "at"). | |||
Where(Lte{"at.RN": limit.limitN}) | |||
} else { | |||
sub := Dialect(b.dialect).Select("*"). | |||
From(b, "at").Where(Lte{"at.RN": limit.offset + limit.limitN}) | |||
final = Dialect(b.dialect).Select(selects...).From(sub, "att"). | |||
Where(Gt{"att.RN": limit.offset}) | |||
} | |||
return final.WriteTo(ow) | |||
case SQLITE, MYSQL, POSTGRES: | |||
// if type UNION, we need to write previous content back to current writer | |||
if b.optype == unionType { | |||
if err := b.WriteTo(ow); err != nil { | |||
return err | |||
} | |||
} | |||
if limit.offset == 0 { | |||
fmt.Fprint(ow, " LIMIT ", limit.limitN) | |||
} else { | |||
fmt.Fprintf(ow, " LIMIT %v OFFSET %v", limit.limitN, limit.offset) | |||
} | |||
case MSSQL: | |||
if len(b.selects) == 0 { | |||
b.selects = append(b.selects, "*") | |||
} | |||
var final *Builder | |||
selects := b.selects | |||
b.selects = append(append([]string{fmt.Sprintf("TOP %d %v", limit.limitN+limit.offset, b.selects[0])}, | |||
b.selects[1:]...), "ROW_NUMBER() OVER (ORDER BY (SELECT 1)) AS RN") | |||
var wb *Builder | |||
if b.optype == unionType { | |||
wb = Dialect(b.dialect).Select("*", "ROW_NUMBER() OVER (ORDER BY (SELECT 1)) AS RN"). | |||
From(b, "at") | |||
} else { | |||
wb = b | |||
} | |||
if limit.offset == 0 { | |||
final = Dialect(b.dialect).Select(selects...).From(wb, "at") | |||
} else { | |||
final = Dialect(b.dialect).Select(selects...).From(wb, "at").Where(Gt{"at.RN": limit.offset}) | |||
} | |||
return final.WriteTo(ow) | |||
default: | |||
return ErrNotSupportType | |||
} | |||
} | |||
return nil | |||
} |
@@ -5,13 +5,24 @@ | |||
package builder | |||
import ( | |||
"errors" | |||
"fmt" | |||
) | |||
// Select creates a select Builder | |||
func Select(cols ...string) *Builder { | |||
builder := &Builder{cond: NewCond()} | |||
return builder.Select(cols...) | |||
} | |||
func (b *Builder) selectWriteTo(w Writer) error { | |||
if len(b.tableName) <= 0 { | |||
return errors.New("no table indicated") | |||
if len(b.from) <= 0 && !b.isNested { | |||
return ErrNoTableName | |||
} | |||
// perform limit before writing to writer when b.dialect between ORACLE and MSSQL | |||
// this avoid a duplicate writing problem in simple limit query | |||
if b.limitation != nil && (b.dialect == ORACLE || b.dialect == MSSQL) { | |||
return b.limitWriteTo(w) | |||
} | |||
if _, err := fmt.Fprint(w, "SELECT "); err != nil { | |||
@@ -34,8 +45,38 @@ func (b *Builder) selectWriteTo(w Writer) error { | |||
} | |||
} | |||
if _, err := fmt.Fprint(w, " FROM ", b.tableName); err != nil { | |||
return err | |||
if b.subQuery == nil { | |||
if _, err := fmt.Fprint(w, " FROM ", b.from); err != nil { | |||
return err | |||
} | |||
} else { | |||
if b.cond.IsValid() && len(b.from) <= 0 { | |||
return ErrUnnamedDerivedTable | |||
} | |||
if b.subQuery.dialect != "" && b.dialect != b.subQuery.dialect { | |||
return ErrInconsistentDialect | |||
} | |||
// dialect of sub-query will inherit from the main one (if not set up) | |||
if b.dialect != "" && b.subQuery.dialect == "" { | |||
b.subQuery.dialect = b.dialect | |||
} | |||
switch b.subQuery.optype { | |||
case selectType, unionType: | |||
fmt.Fprint(w, " FROM (") | |||
if err := b.subQuery.WriteTo(w); err != nil { | |||
return err | |||
} | |||
if len(b.from) == 0 { | |||
fmt.Fprintf(w, ")") | |||
} else { | |||
fmt.Fprintf(w, ") %v", b.from) | |||
} | |||
default: | |||
return ErrUnexpectedSubQuery | |||
} | |||
} | |||
for _, v := range b.joins { | |||
@@ -76,6 +117,12 @@ func (b *Builder) selectWriteTo(w Writer) error { | |||
} | |||
} | |||
if b.limitation != nil { | |||
if err := b.limitWriteTo(w); err != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
} | |||
@@ -0,0 +1,47 @@ | |||
// Copyright 2018 The Xorm Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package builder | |||
import ( | |||
"fmt" | |||
"strings" | |||
) | |||
func (b *Builder) unionWriteTo(w Writer) error { | |||
if b.limitation != nil || b.cond.IsValid() || | |||
b.orderBy != "" || b.having != "" || b.groupBy != "" { | |||
return ErrNotUnexpectedUnionConditions | |||
} | |||
for idx, u := range b.unions { | |||
current := u.builder | |||
if current.optype != selectType { | |||
return ErrUnsupportedUnionMembers | |||
} | |||
if len(b.unions) == 1 { | |||
if err := current.selectWriteTo(w); err != nil { | |||
return err | |||
} | |||
} else { | |||
if b.dialect != "" && b.dialect != current.dialect { | |||
return ErrInconsistentDialect | |||
} | |||
if idx != 0 { | |||
fmt.Fprint(w, fmt.Sprintf(" UNION %v ", strings.ToUpper(u.unionType))) | |||
} | |||
fmt.Fprint(w, "(") | |||
if err := current.selectWriteTo(w); err != nil { | |||
return err | |||
} | |||
fmt.Fprint(w, ")") | |||
} | |||
} | |||
return nil | |||
} |
@@ -5,19 +5,24 @@ | |||
package builder | |||
import ( | |||
"errors" | |||
"fmt" | |||
) | |||
// Update creates an update Builder | |||
func Update(updates ...Eq) *Builder { | |||
builder := &Builder{cond: NewCond()} | |||
return builder.Update(updates...) | |||
} | |||
func (b *Builder) updateWriteTo(w Writer) error { | |||
if len(b.tableName) <= 0 { | |||
return errors.New("no table indicated") | |||
if len(b.from) <= 0 { | |||
return ErrNoTableName | |||
} | |||
if len(b.updates) <= 0 { | |||
return errors.New("no column to be update") | |||
return ErrNoColumnToUpdate | |||
} | |||
if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.tableName); err != nil { | |||
if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.from); err != nil { | |||
return err | |||
} | |||
@@ -72,15 +72,3 @@ func (condEmpty) Or(conds ...Cond) Cond { | |||
func (condEmpty) IsValid() bool { | |||
return false | |||
} | |||
func condToSQL(cond Cond) (string, []interface{}, error) { | |||
if cond == nil || !cond.IsValid() { | |||
return "", nil, nil | |||
} | |||
w := NewWriter() | |||
if err := cond.WriteTo(w); err != nil { | |||
return "", nil, err | |||
} | |||
return w.writer.String(), w.args, nil | |||
} |
@@ -17,10 +17,35 @@ var _ Cond = Between{} | |||
// WriteTo write data to Writer | |||
func (between Between) WriteTo(w Writer) error { | |||
if _, err := fmt.Fprintf(w, "%s BETWEEN ? AND ?", between.Col); err != nil { | |||
if _, err := fmt.Fprintf(w, "%s BETWEEN ", between.Col); err != nil { | |||
return err | |||
} | |||
w.Append(between.LessVal, between.MoreVal) | |||
if lv, ok := between.LessVal.(expr); ok { | |||
if err := lv.WriteTo(w); err != nil { | |||
return err | |||
} | |||
} else { | |||
if _, err := fmt.Fprint(w, "?"); err != nil { | |||
return err | |||
} | |||
w.Append(between.LessVal) | |||
} | |||
if _, err := fmt.Fprint(w, " AND "); err != nil { | |||
return err | |||
} | |||
if mv, ok := between.MoreVal.(expr); ok { | |||
if err := mv.WriteTo(w); err != nil { | |||
return err | |||
} | |||
} else { | |||
if _, err := fmt.Fprint(w, "?"); err != nil { | |||
return err | |||
} | |||
w.Append(between.MoreVal) | |||
} | |||
return nil | |||
} | |||
@@ -27,10 +27,12 @@ func (o condOr) WriteTo(w Writer) error { | |||
for i, cond := range o { | |||
var needQuote bool | |||
switch cond.(type) { | |||
case condAnd: | |||
case condAnd, expr: | |||
needQuote = true | |||
case Eq: | |||
needQuote = (len(cond.(Eq)) > 1) | |||
case Neq: | |||
needQuote = (len(cond.(Neq)) > 1) | |||
} | |||
if needQuote { |
@@ -8,9 +8,33 @@ import "errors" | |||
var ( | |||
// ErrNotSupportType not supported SQL type error | |||
ErrNotSupportType = errors.New("not supported SQL type") | |||
ErrNotSupportType = errors.New("Not supported SQL type") | |||
// ErrNoNotInConditions no NOT IN params error | |||
ErrNoNotInConditions = errors.New("No NOT IN conditions") | |||
// ErrNoInConditions no IN params error | |||
ErrNoInConditions = errors.New("No IN conditions") | |||
// ErrNeedMoreArguments need more arguments | |||
ErrNeedMoreArguments = errors.New("Need more sql arguments") | |||
// ErrNoTableName no table name | |||
ErrNoTableName = errors.New("No table indicated") | |||
// ErrNoColumnToInsert no column to update | |||
ErrNoColumnToUpdate = errors.New("No column(s) to update") | |||
// ErrNoColumnToInsert no column to update | |||
ErrNoColumnToInsert = errors.New("No column(s) to insert") | |||
// ErrNotSupportDialectType not supported dialect type error | |||
ErrNotSupportDialectType = errors.New("Not supported dialect type") | |||
// ErrNotUnexpectedUnionConditions using union in a wrong way | |||
ErrNotUnexpectedUnionConditions = errors.New("Unexpected conditional fields in UNION query") | |||
// ErrUnsupportedUnionMembers unexpected members in UNION query | |||
ErrUnsupportedUnionMembers = errors.New("Unexpected members in UNION query") | |||
// ErrUnexpectedSubQuery Unexpected sub-query in SELECT query | |||
ErrUnexpectedSubQuery = errors.New("Unexpected sub-query in SELECT query") | |||
// ErrDialectNotSetUp dialect is not setup yet | |||
ErrDialectNotSetUp = errors.New("Dialect is not setup yet, try to use `Dialect(dbType)` at first") | |||
// ErrInvalidLimitation offset or limit is not correct | |||
ErrInvalidLimitation = errors.New("Offset or limit is not correct") | |||
// ErrUnnamedDerivedTable Every derived table must have its own alias | |||
ErrUnnamedDerivedTable = errors.New("Every derived table must have its own alias") | |||
// ErrInconsistentDialect Inconsistent dialect in same builder | |||
ErrInconsistentDialect = errors.New("Inconsistent dialect in same builder") | |||
) |
@@ -0,0 +1,156 @@ | |||
// Copyright 2018 The Xorm Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package builder | |||
import ( | |||
sql2 "database/sql" | |||
"fmt" | |||
"reflect" | |||
"time" | |||
) | |||
func condToSQL(cond Cond) (string, []interface{}, error) { | |||
if cond == nil || !cond.IsValid() { | |||
return "", nil, nil | |||
} | |||
w := NewWriter() | |||
if err := cond.WriteTo(w); err != nil { | |||
return "", nil, err | |||
} | |||
return w.writer.String(), w.args, nil | |||
} | |||
func condToBoundSQL(cond Cond) (string, error) { | |||
if cond == nil || !cond.IsValid() { | |||
return "", nil | |||
} | |||
w := NewWriter() | |||
if err := cond.WriteTo(w); err != nil { | |||
return "", err | |||
} | |||
return ConvertToBoundSQL(w.writer.String(), w.args) | |||
} | |||
// ToSQL convert a builder or conditions to SQL and args | |||
func ToSQL(cond interface{}) (string, []interface{}, error) { | |||
switch cond.(type) { | |||
case Cond: | |||
return condToSQL(cond.(Cond)) | |||
case *Builder: | |||
return cond.(*Builder).ToSQL() | |||
} | |||
return "", nil, ErrNotSupportType | |||
} | |||
// ToBoundSQL convert a builder or conditions to parameters bound SQL | |||
func ToBoundSQL(cond interface{}) (string, error) { | |||
switch cond.(type) { | |||
case Cond: | |||
return condToBoundSQL(cond.(Cond)) | |||
case *Builder: | |||
return cond.(*Builder).ToBoundSQL() | |||
} | |||
return "", ErrNotSupportType | |||
} | |||
func noSQLQuoteNeeded(a interface{}) bool { | |||
switch a.(type) { | |||
case int, int8, int16, int32, int64: | |||
return true | |||
case uint, uint8, uint16, uint32, uint64: | |||
return true | |||
case float32, float64: | |||
return true | |||
case bool: | |||
return true | |||
case string: | |||
return false | |||
case time.Time, *time.Time: | |||
return false | |||
} | |||
t := reflect.TypeOf(a) | |||
switch t.Kind() { | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
return true | |||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||
return true | |||
case reflect.Float32, reflect.Float64: | |||
return true | |||
case reflect.Bool: | |||
return true | |||
case reflect.String: | |||
return false | |||
} | |||
return false | |||
} | |||
// ConvertToBoundSQL will convert SQL and args to a bound SQL | |||
func ConvertToBoundSQL(sql string, args []interface{}) (string, error) { | |||
buf := StringBuilder{} | |||
var i, j, start int | |||
for ; i < len(sql); i++ { | |||
if sql[i] == '?' { | |||
_, err := buf.WriteString(sql[start:i]) | |||
if err != nil { | |||
return "", err | |||
} | |||
start = i + 1 | |||
if len(args) == j { | |||
return "", ErrNeedMoreArguments | |||
} | |||
arg := args[j] | |||
if namedArg, ok := arg.(sql2.NamedArg); ok { | |||
arg = namedArg.Value | |||
} | |||
if noSQLQuoteNeeded(arg) { | |||
_, err = fmt.Fprint(&buf, arg) | |||
} else { | |||
_, err = fmt.Fprintf(&buf, "'%v'", arg) | |||
} | |||
if err != nil { | |||
return "", err | |||
} | |||
j = j + 1 | |||
} | |||
} | |||
_, err := buf.WriteString(sql[start:]) | |||
if err != nil { | |||
return "", err | |||
} | |||
return buf.String(), nil | |||
} | |||
// ConvertPlaceholder replaces ? to $1, $2 ... or :1, :2 ... according prefix | |||
func ConvertPlaceholder(sql, prefix string) (string, error) { | |||
buf := StringBuilder{} | |||
var i, j, start int | |||
for ; i < len(sql); i++ { | |||
if sql[i] == '?' { | |||
if _, err := buf.WriteString(sql[start:i]); err != nil { | |||
return "", err | |||
} | |||
start = i + 1 | |||
j = j + 1 | |||
if _, err := buf.WriteString(fmt.Sprintf("%v%d", prefix, j)); err != nil { | |||
return "", err | |||
} | |||
} | |||
} | |||
if _, err := buf.WriteString(sql[start:]); err != nil { | |||
return "", err | |||
} | |||
return buf.String(), nil | |||
} |