diff options
author | Lunny Xiao <xiaolunwen@gmail.com> | 2023-12-07 15:27:36 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-07 15:27:36 +0800 |
commit | dd30d9d5c0f577cb6e084aae6de2752ad43474d8 (patch) | |
tree | 1e3799a672a23424484b849827ba39eae447856a /models/db | |
parent | beb71f5ef6e8074dc744ac995c15f7b5947a3f2e (diff) | |
download | gitea-dd30d9d5c0f577cb6e084aae6de2752ad43474d8.tar.gz gitea-dd30d9d5c0f577cb6e084aae6de2752ad43474d8.zip |
Remove GetByBean method because sometimes it's danger when query condition parameter is zero and also introduce new generic methods (#28220)
The function `GetByBean` has an obvious defect that when the fields are
empty values, it will be ignored. Then users will get a wrong result
which is possibly used to make a security problem.
To avoid the possibility, this PR removed function `GetByBean` and all
references.
And some new generic functions have been introduced to be used.
The recommand usage like below.
```go
// if query an object according id
obj, err := db.GetByID[Object](ctx, id)
// query with other conditions
obj, err := db.Get[Object](ctx, builder.Eq{"a": a, "b":b})
```
Diffstat (limited to 'models/db')
-rw-r--r-- | models/db/context.go | 46 | ||||
-rw-r--r-- | models/db/error.go | 18 | ||||
-rw-r--r-- | models/db/iterate_test.go | 6 |
3 files changed, 59 insertions, 11 deletions
diff --git a/models/db/context.go b/models/db/context.go index 45765ef7d3..7b739f7e9f 100644 --- a/models/db/context.go +++ b/models/db/context.go @@ -173,9 +173,44 @@ func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) { return GetEngine(ctx).Exec(sqlAndArgs...) } -// GetByBean filled empty fields of the bean according non-empty fields to query in database. -func GetByBean(ctx context.Context, bean any) (bool, error) { - return GetEngine(ctx).Get(bean) +func Get[T any](ctx context.Context, cond builder.Cond) (object *T, exist bool, err error) { + if !cond.IsValid() { + return nil, false, ErrConditionRequired{} + } + + var bean T + has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean) + if err != nil { + return nil, false, err + } else if !has { + return nil, false, nil + } + return &bean, true, nil +} + +func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err error) { + var bean T + has, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(&bean) + if err != nil { + return nil, false, err + } else if !has { + return nil, false, nil + } + return &bean, true, nil +} + +func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) { + if !cond.IsValid() { + return false, ErrConditionRequired{} + } + + var bean T + return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean) +} + +func ExistByID[T any](ctx context.Context, id int64) (bool, error) { + var bean T + return GetEngine(ctx).ID(id).NoAutoCondition().Exist(&bean) } // DeleteByBean deletes all records according non-empty fields of the bean as conditions. @@ -264,8 +299,3 @@ func inTransaction(ctx context.Context) (*xorm.Session, bool) { return nil, false } } - -func Exists[T any](ctx context.Context, opts FindOptions) (bool, error) { - var bean T - return GetEngine(ctx).Where(opts.ToConds()).Exist(&bean) -} diff --git a/models/db/error.go b/models/db/error.go index 665e970e17..f601a15c01 100644 --- a/models/db/error.go +++ b/models/db/error.go @@ -72,3 +72,21 @@ func (err ErrNotExist) Error() string { func (err ErrNotExist) Unwrap() error { return util.ErrNotExist } + +// ErrConditionRequired represents an error which require condition. +type ErrConditionRequired struct{} + +// IsErrConditionRequired checks if an error is an ErrConditionRequired +func IsErrConditionRequired(err error) bool { + _, ok := err.(ErrConditionRequired) + return ok +} + +func (err ErrConditionRequired) Error() string { + return "condition is required" +} + +// Unwrap unwraps this as a ErrNotExist err +func (err ErrConditionRequired) Unwrap() error { + return util.ErrInvalidArgument +} diff --git a/models/db/iterate_test.go b/models/db/iterate_test.go index 5362f34075..0f6ba2cc94 100644 --- a/models/db/iterate_test.go +++ b/models/db/iterate_test.go @@ -31,11 +31,11 @@ func TestIterate(t *testing.T) { assert.EqualValues(t, cnt, repoUnitCnt) err = db.Iterate(db.DefaultContext, nil, func(ctx context.Context, repoUnit *repo_model.RepoUnit) error { - reopUnit2 := repo_model.RepoUnit{ID: repoUnit.ID} - has, err := db.GetByBean(ctx, &reopUnit2) + has, err := db.ExistByID[repo_model.RepoUnit](ctx, repoUnit.ID) if err != nil { return err - } else if !has { + } + if !has { return db.ErrNotExist{Resource: "repo_unit", ID: repoUnit.ID} } assert.EqualValues(t, repoUnit.RepoID, repoUnit.RepoID) |