summaryrefslogtreecommitdiffstats
path: root/models/db/context.go
blob: 670f6272aa9a923b24584d4953e28b906f083138 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package db

import (
	"context"
	"database/sql"

	"xorm.io/builder"
	"xorm.io/xorm"
	"xorm.io/xorm/schemas"
)

// DefaultContext is the default context to run xorm queries in.
// It will be overwritten by Init with HammerContext.
var DefaultContext context.Context

// contextKey is a value for use with context.WithValue.
// Keys are compared as values of this dedicated type, so they cannot be confused
// with keys that other packages define using plain strings.
type contextKey struct {
	// name is a label identifying the key.
	name string
}

// enginedContextKey is a context key. It is used with context.Value() to get the current Engined for the context.
// The blank identifier declaration is a compile-time assertion that *Context satisfies Engined.
var (
	enginedContextKey         = &contextKey{"engined"}
	_                 Engined = &Context{}
)

// Context represents a db context.
// It embeds a parent context.Context and additionally carries the Engine that
// database calls should use (returned by Engine) and a flag telling whether
// that engine is running inside a transaction (returned by InTransaction).
type Context struct {
	context.Context
	e           Engine
	transaction bool
}

// newContext builds a db Context that wraps ctx as its parent and carries
// the engine e together with the transaction flag.
func newContext(ctx context.Context, e Engine, transaction bool) *Context {
	wrapped := new(Context)
	wrapped.Context = ctx
	wrapped.e = e
	wrapped.transaction = transaction
	return wrapped
}

// InTransaction reports whether this context was created for a transaction,
// i.e. it returns the transaction flag the Context was constructed with.
func (ctx *Context) InTransaction() bool {
	return ctx.transaction
}

// Engine returns the db engine stored in this context.
// Together with the Engined interface this lets GetEngine pick up the engine
// directly from a *Context.
func (ctx *Context) Engine() Engine {
	return ctx.e
}

// Value overrides the embedded context.Context's Value: a lookup with
// enginedContextKey yields this Context itself (which is an Engined), every
// other key is forwarded to the parent context.
func (ctx *Context) Value(key interface{}) interface{} {
	if key != enginedContextKey {
		// not our key, let the wrapped context answer
		return ctx.Context.Value(key)
	}
	return ctx
}

// WithContext returns a new Context whose engine is this context's engine
// bound to other. The receiver itself stays the parent of the new Context
// and the transaction flag is carried over unchanged.
func (ctx *Context) WithContext(other context.Context) *Context {
	bound := ctx.e.Context(other)
	return newContext(ctx, bound, ctx.transaction)
}

// Engined is implemented by types that provide an Engine.
// getEngine uses it to extract the engine from a context, either from the
// context value itself or from the value stored under enginedContextKey.
type Engined interface {
	Engine() Engine
}

// GetEngine returns the db Engine carried by ctx. When ctx carries none,
// it returns the global engine restricted to ctx instead.
func GetEngine(ctx context.Context) Engine {
	e := getEngine(ctx)
	if e == nil {
		// no engine attached to the context, fall back to the global one
		return x.Context(ctx)
	}
	return e
}

// getEngine returns the db Engine carried by ctx, or nil if there is none.
// The context is first checked for being an Engined itself; otherwise the
// value stored under enginedContextKey is consulted.
func getEngine(ctx context.Context) Engine {
	if direct, ok := ctx.(Engined); ok {
		return direct.Engine()
	}
	if stored := ctx.Value(enginedContextKey); stored != nil {
		return stored.(Engined).Engine()
	}
	return nil
}

// Committer is the interface used to finish a Context: Commit it or Close it.
type Committer interface {
	Commit() error
	Close() error
}

// halfCommitter wraps another Committer so that an already running
// transaction can be reused: closing early is passed on to the wrapped
// committer, committing early is not.
type halfCommitter struct {
	committer Committer
	committed bool
}

// Commit only remembers that a commit was requested and always returns nil;
// the wrapped committer is left alone so that its owner commits later.
func (c *halfCommitter) Commit() error {
	c.committed = true
	return nil
}

// Close returns nil without touching the wrapped committer when Commit was
// called before ("commit and close"). Otherwise it is "rollback and close":
// the wrapped committer is closed right away and its result is returned.
func (c *halfCommitter) Close() error {
	if !c.committed {
		return c.committer.Close()
	}
	return nil
}

// TxContext returns a transaction Context together with the Committer that
// finishes it.
//
// If parentCtx already carries a session with an open transaction, that
// session is reused and the returned Committer is a halfCommitter: its Commit
// is left to the owner of the transaction, while Close without a prior Commit
// closes the shared session immediately.
//
// Otherwise a new session is created and a transaction is begun on it; the
// session itself is returned as the Committer. If Begin fails, the session is
// closed and the Begin error is returned with nil results.
func TxContext(parentCtx context.Context) (*Context, Committer, error) {
	if sess, ok := inTransaction(parentCtx); ok {
		return newContext(parentCtx, sess, true), &halfCommitter{committer: sess}, nil
	}

	sess := x.NewSession()
	if err := sess.Begin(); err != nil {
		// The transaction never started; the Begin error is the one worth
		// reporting, so the result of Close is deliberately dropped.
		_ = sess.Close()
		return nil, nil, err
	}

	// Derive the new Context from parentCtx, as the reuse branch above and
	// txWithNoCheck do, so the caller's values and cancellation are kept
	// instead of being replaced by DefaultContext.
	return newContext(parentCtx, sess, true), sess, nil
}

// WithTx runs f with a transaction context. When parentCtx is already inside
// a transaction, f runs on that same session; otherwise a fresh transaction
// is created for f and finished afterwards (see txWithNoCheck).
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
	sess, ok := inTransaction(parentCtx)
	if !ok {
		return txWithNoCheck(parentCtx, f)
	}

	err := f(newContext(parentCtx, sess, true))
	if err != nil {
		// rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
		_ = sess.Close()
	}
	return err
}

// txWithNoCheck always opens a new session, begins a transaction on it and
// runs f with a Context wrapping parentCtx and that session. The transaction
// is committed when f returns nil; the session is closed on every path.
func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error {
	sess := x.NewSession()
	defer sess.Close()

	err := sess.Begin()
	if err != nil {
		return err
	}

	txCtx := newContext(parentCtx, sess, true)
	if err = f(txCtx); err != nil {
		return err
	}
	return sess.Commit()
}

// Insert inserts the given beans into the database using the engine of ctx.
// Only the error is reported; the first result of the engine's Insert is dropped.
func Insert(ctx context.Context, beans ...interface{}) error {
	engine := GetEngine(ctx)
	_, err := engine.Insert(beans...)
	return err
}

// Exec executes a sql statement with its args on the engine of ctx.
func Exec(ctx context.Context, sqlAndArgs ...interface{}) (sql.Result, error) {
	engine := GetEngine(ctx)
	return engine.Exec(sqlAndArgs...)
}

// GetByBean queries the database using the non-empty fields of bean as
// conditions and fills the remaining fields of bean from the result.
// It forwards the (bool, error) result of the engine's Get.
func GetByBean(ctx context.Context, bean interface{}) (bool, error) {
	engine := GetEngine(ctx)
	return engine.Get(bean)
}

// DeleteByBean deletes all records matching the non-empty fields of bean.
// It forwards the (int64, error) result of the engine's Delete.
func DeleteByBean(ctx context.Context, bean interface{}) (int64, error) {
	engine := GetEngine(ctx)
	return engine.Delete(bean)
}

// DeleteByID deletes the record of bean's type that has the given ID.
// The delete is issued with NoAutoTime applied.
func DeleteByID(ctx context.Context, id int64, bean interface{}) (int64, error) {
	byID := GetEngine(ctx).ID(id).NoAutoTime()
	return byID.Delete(bean)
}

// FindIDs returns the values of column idCol from table tableName for all rows satisfying cond.
// Passing something other than "id" as idCol allows querying foreign IDs, e.g. the repo IDs of the matching rows.
// On a query error the returned slice is nil.
func FindIDs(ctx context.Context, tableName, idCol string, cond builder.Cond) ([]int64, error) {
	ids := make([]int64, 0, 10)

	query := GetEngine(ctx).Table(tableName).Cols(idCol).Where(cond)
	err := query.Find(&ids)
	if err != nil {
		return nil, err
	}
	return ids, nil
}

// DecrByIDs decrements column decrCol by one for every entity of bean's type whose id is in ids.
// The update runs with NoAutoCondition and NoAutoTime, so the timestamps of the entities won't be updated.
func DecrByIDs(ctx context.Context, ids []int64, decrCol string, bean interface{}) error {
	selected := GetEngine(ctx).Decr(decrCol).In("id", ids)
	_, err := selected.NoAutoCondition().NoAutoTime().Update(bean)
	return err
}

// DeleteBeans deletes the given beans one after another; each bean must contain its delete conditions.
// It stops at the first failing delete and returns that error.
func DeleteBeans(ctx context.Context, beans ...interface{}) (err error) {
	engine := GetEngine(ctx)
	for _, bean := range beans {
		_, err = engine.Delete(bean)
		if err != nil {
			return err
		}
	}
	return nil
}

// TruncateBeans calls the engine's Truncate for each of the given beans in order.
// It stops at the first failing call and returns that error.
func TruncateBeans(ctx context.Context, beans ...interface{}) (err error) {
	engine := GetEngine(ctx)
	for _, bean := range beans {
		_, err = engine.Truncate(bean)
		if err != nil {
			return err
		}
	}
	return nil
}

// CountByBean counts the database records matching the non-empty fields of bean.
// It forwards the (int64, error) result of the engine's Count.
func CountByBean(ctx context.Context, bean interface{}) (int64, error) {
	engine := GetEngine(ctx)
	return engine.Count(bean)
}

// TableName returns the table name for a bean object, as resolved by the global engine.
func TableName(bean interface{}) string {
	name := x.TableName(bean)
	return name
}

// EstimateCount returns an estimate of the total number of rows in the table of bean.
//
// For MySQL, PostgreSQL and MSSQL the number is read with a dialect specific
// statistics query; for every other dialect the engine's Count is used.
// All queries are issued on the engine of ctx, bound to ctx.
func EstimateCount(ctx context.Context, bean interface{}) (int64, error) {
	e := GetEngine(ctx)

	var rows int64
	var err error
	tablename := TableName(bean)
	switch x.Dialect().URI().DBType {
	case schemas.MYSQL:
		_, err = e.Context(ctx).SQL("SELECT table_rows FROM information_schema.tables WHERE tables.table_name = ? AND tables.table_schema = ?;", tablename, x.Dialect().URI().DBName).Get(&rows)
	case schemas.POSTGRES:
		// the table can live in multiple schemas of a postgres database
		// See https://wiki.postgresql.org/wiki/Count_estimate
		tablename = x.TableName(bean, true)
		_, err = e.Context(ctx).SQL("SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = ?::regclass;", tablename).Get(&rows)
	case schemas.MSSQL:
		_, err = e.Context(ctx).SQL("sp_spaceused ?;", tablename).Get(&rows)
	default:
		// NOTE(review): Count receives the table name string here rather than
		// the bean itself — confirm that this is what the engine expects.
		return e.Context(ctx).Count(tablename)
	}
	return rows, err
}

// InTransaction reports whether the engine carried by ctx is a session that
// is currently inside a transaction.
func InTransaction(ctx context.Context) bool {
	_, inTx := inTransaction(ctx)
	return inTx
}

// inTransaction returns the session carried by ctx together with true when
// that session is inside a transaction. In every other case (no engine in
// ctx, the engine is not a *xorm.Session, or the session has no open
// transaction) it returns nil and false.
func inTransaction(ctx context.Context) (*xorm.Session, bool) {
	// A nil engine or any non-session engine fails the assertion.
	sess, isSession := getEngine(ctx).(*xorm.Session)
	if !isSession || !sess.IsInTx() {
		return nil, false
	}
	return sess, true
}