aboutsummaryrefslogtreecommitdiffstats
path: root/models/db
diff options
context:
space:
mode:
authorLunny Xiao <xiaolunwen@gmail.com>2023-12-07 15:27:36 +0800
committerGitHub <noreply@github.com>2023-12-07 15:27:36 +0800
commitdd30d9d5c0f577cb6e084aae6de2752ad43474d8 (patch)
tree1e3799a672a23424484b849827ba39eae447856a /models/db
parentbeb71f5ef6e8074dc744ac995c15f7b5947a3f2e (diff)
downloadgitea-dd30d9d5c0f577cb6e084aae6de2752ad43474d8.tar.gz
gitea-dd30d9d5c0f577cb6e084aae6de2752ad43474d8.zip
Remove GetByBean method because sometimes it's danger when query condition parameter is zero and also introduce new generic methods (#28220)
The function `GetByBean` has an obvious defect that when the fields are empty values, it will be ignored. Then users will get a wrong result which is possibly used to make a security problem. To avoid the possibility, this PR removed function `GetByBean` and all references. And some new generic functions have been introduced to be used. The recommand usage like below. ```go // if query an object according id obj, err := db.GetByID[Object](ctx, id) // query with other conditions obj, err := db.Get[Object](ctx, builder.Eq{"a": a, "b":b}) ```
Diffstat (limited to 'models/db')
-rw-r--r--models/db/context.go46
-rw-r--r--models/db/error.go18
-rw-r--r--models/db/iterate_test.go6
3 files changed, 59 insertions, 11 deletions
diff --git a/models/db/context.go b/models/db/context.go
index 45765ef7d3..7b739f7e9f 100644
--- a/models/db/context.go
+++ b/models/db/context.go
@@ -173,9 +173,44 @@ func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) {
return GetEngine(ctx).Exec(sqlAndArgs...)
}
-// GetByBean filled empty fields of the bean according non-empty fields to query in database.
-func GetByBean(ctx context.Context, bean any) (bool, error) {
- return GetEngine(ctx).Get(bean)
+func Get[T any](ctx context.Context, cond builder.Cond) (object *T, exist bool, err error) {
+ if !cond.IsValid() {
+ return nil, false, ErrConditionRequired{}
+ }
+
+ var bean T
+ has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean)
+ if err != nil {
+ return nil, false, err
+ } else if !has {
+ return nil, false, nil
+ }
+ return &bean, true, nil
+}
+
+func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err error) {
+ var bean T
+ has, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(&bean)
+ if err != nil {
+ return nil, false, err
+ } else if !has {
+ return nil, false, nil
+ }
+ return &bean, true, nil
+}
+
+func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) {
+ if !cond.IsValid() {
+ return false, ErrConditionRequired{}
+ }
+
+ var bean T
+ return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean)
+}
+
+func ExistByID[T any](ctx context.Context, id int64) (bool, error) {
+ var bean T
+ return GetEngine(ctx).ID(id).NoAutoCondition().Exist(&bean)
}
// DeleteByBean deletes all records according non-empty fields of the bean as conditions.
@@ -264,8 +299,3 @@ func inTransaction(ctx context.Context) (*xorm.Session, bool) {
return nil, false
}
}
-
-func Exists[T any](ctx context.Context, opts FindOptions) (bool, error) {
- var bean T
- return GetEngine(ctx).Where(opts.ToConds()).Exist(&bean)
-}
diff --git a/models/db/error.go b/models/db/error.go
index 665e970e17..f601a15c01 100644
--- a/models/db/error.go
+++ b/models/db/error.go
@@ -72,3 +72,21 @@ func (err ErrNotExist) Error() string {
func (err ErrNotExist) Unwrap() error {
return util.ErrNotExist
}
+
+// ErrConditionRequired represents an error which require condition.
+type ErrConditionRequired struct{}
+
+// IsErrConditionRequired checks if an error is an ErrConditionRequired
+func IsErrConditionRequired(err error) bool {
+ _, ok := err.(ErrConditionRequired)
+ return ok
+}
+
+func (err ErrConditionRequired) Error() string {
+ return "condition is required"
+}
+
+// Unwrap unwraps this as a ErrNotExist err
+func (err ErrConditionRequired) Unwrap() error {
+ return util.ErrInvalidArgument
+}
diff --git a/models/db/iterate_test.go b/models/db/iterate_test.go
index 5362f34075..0f6ba2cc94 100644
--- a/models/db/iterate_test.go
+++ b/models/db/iterate_test.go
@@ -31,11 +31,11 @@ func TestIterate(t *testing.T) {
assert.EqualValues(t, cnt, repoUnitCnt)
err = db.Iterate(db.DefaultContext, nil, func(ctx context.Context, repoUnit *repo_model.RepoUnit) error {
- reopUnit2 := repo_model.RepoUnit{ID: repoUnit.ID}
- has, err := db.GetByBean(ctx, &reopUnit2)
+ has, err := db.ExistByID[repo_model.RepoUnit](ctx, repoUnit.ID)
if err != nil {
return err
- } else if !has {
+ }
+ if !has {
return db.ErrNotExist{Resource: "repo_unit", ID: repoUnit.ID}
}
assert.EqualValues(t, repoUnit.RepoID, repoUnit.RepoID)