diff options
author | JakobDev <jakobdev@gmx.de> | 2023-09-14 19:09:32 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-14 17:09:32 +0000 |
commit | 76659b1114153603050de810006e04a938e9dcb7 (patch) | |
tree | ed180213ca459bca51c44dbcc55fd3f9bc2d0460 /models/user | |
parent | 0de09d3afcb5394cbd97e4a1c5609eb8b2acb6cf (diff) | |
download | gitea-76659b1114153603050de810006e04a938e9dcb7.tar.gz gitea-76659b1114153603050de810006e04a938e9dcb7.zip |
Reduce usage of `db.DefaultContext` (#27073)
Part of #27065
This reduces the usage of `db.DefaultContext`. I think I've got enough
files for the first PR. When this is merged, I will continue working on
this.
Considering how many files this PR affect, I hope it won't take to long
to merge, so I don't end up in the merge conflict hell.
---------
Co-authored-by: wxiaoguang <wxiaoguang@gmail.com>
Diffstat (limited to 'models/user')
-rw-r--r-- | models/user/email_address.go | 52 | ||||
-rw-r--r-- | models/user/email_address_test.go | 42 | ||||
-rw-r--r-- | models/user/list.go | 10 | ||||
-rw-r--r-- | models/user/search.go | 11 | ||||
-rw-r--r-- | models/user/user.go | 62 | ||||
-rw-r--r-- | models/user/user_test.go | 24 |
6 files changed, 101 insertions, 100 deletions
diff --git a/models/user/email_address.go b/models/user/email_address.go index e916249e30..f1ed6692cf 100644 --- a/models/user/email_address.go +++ b/models/user/email_address.go @@ -178,9 +178,9 @@ func ValidateEmail(email string) error { } // GetEmailAddresses returns all email addresses belongs to given user. -func GetEmailAddresses(uid int64) ([]*EmailAddress, error) { +func GetEmailAddresses(ctx context.Context, uid int64) ([]*EmailAddress, error) { emails := make([]*EmailAddress, 0, 5) - if err := db.GetEngine(db.DefaultContext). + if err := db.GetEngine(ctx). Where("uid=?", uid). Asc("id"). Find(&emails); err != nil { @@ -190,10 +190,10 @@ func GetEmailAddresses(uid int64) ([]*EmailAddress, error) { } // GetEmailAddressByID gets a user's email address by ID -func GetEmailAddressByID(uid, id int64) (*EmailAddress, error) { +func GetEmailAddressByID(ctx context.Context, uid, id int64) (*EmailAddress, error) { // User ID is required for security reasons email := &EmailAddress{UID: uid} - if has, err := db.GetEngine(db.DefaultContext).ID(id).Get(email); err != nil { + if has, err := db.GetEngine(ctx).ID(id).Get(email); err != nil { return nil, err } else if !has { return nil, nil @@ -253,7 +253,7 @@ func AddEmailAddress(ctx context.Context, email *EmailAddress) error { } // AddEmailAddresses adds an email address to given user. -func AddEmailAddresses(emails []*EmailAddress) error { +func AddEmailAddresses(ctx context.Context, emails []*EmailAddress) error { if len(emails) == 0 { return nil } @@ -261,7 +261,7 @@ func AddEmailAddresses(emails []*EmailAddress) error { // Check if any of them has been used for i := range emails { emails[i].Email = strings.TrimSpace(emails[i].Email) - used, err := IsEmailUsed(db.DefaultContext, emails[i].Email) + used, err := IsEmailUsed(ctx, emails[i].Email) if err != nil { return err } else if used { @@ -272,7 +272,7 @@ func AddEmailAddresses(emails []*EmailAddress) error { } } - if err := db.Insert(db.DefaultContext, emails); err != nil { + if err := db.Insert(ctx, emails); err != nil { return fmt.Errorf("Insert: %w", err) } @@ -280,7 +280,7 @@ func AddEmailAddresses(emails []*EmailAddress) error { } // DeleteEmailAddress deletes an email address of given user. -func DeleteEmailAddress(email *EmailAddress) (err error) { +func DeleteEmailAddress(ctx context.Context, email *EmailAddress) (err error) { if email.IsPrimary { return ErrPrimaryEmailCannotDelete{Email: email.Email} } @@ -291,12 +291,12 @@ func DeleteEmailAddress(email *EmailAddress) (err error) { UID: email.UID, } if email.ID > 0 { - deleted, err = db.GetEngine(db.DefaultContext).ID(email.ID).Delete(&address) + deleted, err = db.GetEngine(ctx).ID(email.ID).Delete(&address) } else { if email.Email != "" && email.LowerEmail == "" { email.LowerEmail = strings.ToLower(email.Email) } - deleted, err = db.GetEngine(db.DefaultContext). + deleted, err = db.GetEngine(ctx). Where("lower_email=?", email.LowerEmail). Delete(&address) } @@ -310,9 +310,9 @@ func DeleteEmailAddress(email *EmailAddress) (err error) { } // DeleteEmailAddresses deletes multiple email addresses -func DeleteEmailAddresses(emails []*EmailAddress) (err error) { +func DeleteEmailAddresses(ctx context.Context, emails []*EmailAddress) (err error) { for i := range emails { - if err = DeleteEmailAddress(emails[i]); err != nil { + if err = DeleteEmailAddress(ctx, emails[i]); err != nil { return err } } @@ -329,8 +329,8 @@ func DeleteInactiveEmailAddresses(ctx context.Context) error { } // ActivateEmail activates the email address to given user. -func ActivateEmail(email *EmailAddress) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func ActivateEmail(ctx context.Context, email *EmailAddress) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -357,8 +357,8 @@ func updateActivation(ctx context.Context, email *EmailAddress, activate bool) e } // MakeEmailPrimary sets primary email address of given user. -func MakeEmailPrimary(email *EmailAddress) error { - has, err := db.GetEngine(db.DefaultContext).Get(email) +func MakeEmailPrimary(ctx context.Context, email *EmailAddress) error { + has, err := db.GetEngine(ctx).Get(email) if err != nil { return err } else if !has { @@ -370,7 +370,7 @@ func MakeEmailPrimary(email *EmailAddress) error { } user := &User{} - has, err = db.GetEngine(db.DefaultContext).ID(email.UID).Get(user) + has, err = db.GetEngine(ctx).ID(email.UID).Get(user) if err != nil { return err } else if !has { @@ -381,7 +381,7 @@ func MakeEmailPrimary(email *EmailAddress) error { } } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -411,17 +411,17 @@ func MakeEmailPrimary(email *EmailAddress) error { } // VerifyActiveEmailCode verifies active email code when active account -func VerifyActiveEmailCode(code, email string) *EmailAddress { +func VerifyActiveEmailCode(ctx context.Context, code, email string) *EmailAddress { minutes := setting.Service.ActiveCodeLives - if user := GetVerifyUser(code); user != nil { + if user := GetVerifyUser(ctx, code); user != nil { // time limit code prefix := code[:base.TimeLimitCodeLength] data := fmt.Sprintf("%d%s%s%s%s", user.ID, email, user.LowerName, user.Passwd, user.Rands) if base.VerifyTimeLimitCode(data, minutes, prefix) { emailAddress := &EmailAddress{UID: user.ID, Email: email} - if has, _ := db.GetEngine(db.DefaultContext).Get(emailAddress); has { + if has, _ := db.GetEngine(ctx).Get(emailAddress); has { return emailAddress } } @@ -466,7 +466,7 @@ type SearchEmailResult struct { // SearchEmails takes options i.e. keyword and part of email name to search, // it returns results in given range and number of total results. -func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) { +func SearchEmails(ctx context.Context, opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) { var cond builder.Cond = builder.Eq{"`user`.`type`": UserTypeIndividual} if len(opts.Keyword) > 0 { likeStr := "%" + strings.ToLower(opts.Keyword) + "%" @@ -491,7 +491,7 @@ func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) cond = cond.And(builder.Eq{"email_address.is_activated": false}) } - count, err := db.GetEngine(db.DefaultContext).Join("INNER", "`user`", "`user`.ID = email_address.uid"). + count, err := db.GetEngine(ctx).Join("INNER", "`user`", "`user`.ID = email_address.uid"). Where(cond).Count(new(EmailAddress)) if err != nil { return nil, 0, fmt.Errorf("Count: %w", err) @@ -505,7 +505,7 @@ func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) opts.SetDefaultValues() emails := make([]*SearchEmailResult, 0, opts.PageSize) - err = db.GetEngine(db.DefaultContext).Table("email_address"). + err = db.GetEngine(ctx).Table("email_address"). Select("email_address.*, `user`.name, `user`.full_name"). Join("INNER", "`user`", "`user`.ID = email_address.uid"). Where(cond). @@ -518,8 +518,8 @@ func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) // ActivateUserEmail will change the activated state of an email address, // either primary or secondary (all in the email_address table) -func ActivateUserEmail(userID int64, email string, activate bool) (err error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func ActivateUserEmail(ctx context.Context, userID int64, email string, activate bool) (err error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/user/email_address_test.go b/models/user/email_address_test.go index f2b383fe4b..7f3ca75cfd 100644 --- a/models/user/email_address_test.go +++ b/models/user/email_address_test.go @@ -17,14 +17,14 @@ import ( func TestGetEmailAddresses(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - emails, _ := user_model.GetEmailAddresses(int64(1)) + emails, _ := user_model.GetEmailAddresses(db.DefaultContext, int64(1)) if assert.Len(t, emails, 3) { assert.True(t, emails[0].IsPrimary) assert.True(t, emails[2].IsActivated) assert.False(t, emails[2].IsPrimary) } - emails, _ = user_model.GetEmailAddresses(int64(2)) + emails, _ = user_model.GetEmailAddresses(db.DefaultContext, int64(2)) if assert.Len(t, emails, 2) { assert.True(t, emails[0].IsPrimary) assert.True(t, emails[0].IsActivated) @@ -76,10 +76,10 @@ func TestAddEmailAddresses(t *testing.T) { LowerEmail: "user5678@example.com", IsActivated: true, } - assert.NoError(t, user_model.AddEmailAddresses(emails)) + assert.NoError(t, user_model.AddEmailAddresses(db.DefaultContext, emails)) // ErrEmailAlreadyUsed - err := user_model.AddEmailAddresses(emails) + err := user_model.AddEmailAddresses(db.DefaultContext, emails) assert.Error(t, err) assert.True(t, user_model.IsErrEmailAlreadyUsed(err)) } @@ -87,21 +87,21 @@ func TestAddEmailAddresses(t *testing.T) { func TestDeleteEmailAddress(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.NoError(t, user_model.DeleteEmailAddress(&user_model.EmailAddress{ + assert.NoError(t, user_model.DeleteEmailAddress(db.DefaultContext, &user_model.EmailAddress{ UID: int64(1), ID: int64(33), Email: "user1-2@example.com", LowerEmail: "user1-2@example.com", })) - assert.NoError(t, user_model.DeleteEmailAddress(&user_model.EmailAddress{ + assert.NoError(t, user_model.DeleteEmailAddress(db.DefaultContext, &user_model.EmailAddress{ UID: int64(1), Email: "user1-3@example.com", LowerEmail: "user1-3@example.com", })) // Email address does not exist - err := user_model.DeleteEmailAddress(&user_model.EmailAddress{ + err := user_model.DeleteEmailAddress(db.DefaultContext, &user_model.EmailAddress{ UID: int64(1), Email: "user1234567890@example.com", LowerEmail: "user1234567890@example.com", @@ -125,10 +125,10 @@ func TestDeleteEmailAddresses(t *testing.T) { Email: "user2-2@example.com", LowerEmail: "user2-2@example.com", } - assert.NoError(t, user_model.DeleteEmailAddresses(emails)) + assert.NoError(t, user_model.DeleteEmailAddresses(db.DefaultContext, emails)) // ErrEmailAlreadyUsed - err := user_model.DeleteEmailAddresses(emails) + err := user_model.DeleteEmailAddresses(db.DefaultContext, emails) assert.Error(t, err) } @@ -138,28 +138,28 @@ func TestMakeEmailPrimary(t *testing.T) { email := &user_model.EmailAddress{ Email: "user567890@example.com", } - err := user_model.MakeEmailPrimary(email) + err := user_model.MakeEmailPrimary(db.DefaultContext, email) assert.Error(t, err) assert.EqualError(t, err, user_model.ErrEmailAddressNotExist{Email: email.Email}.Error()) email = &user_model.EmailAddress{ Email: "user11@example.com", } - err = user_model.MakeEmailPrimary(email) + err = user_model.MakeEmailPrimary(db.DefaultContext, email) assert.Error(t, err) assert.EqualError(t, err, user_model.ErrEmailNotActivated.Error()) email = &user_model.EmailAddress{ Email: "user9999999@example.com", } - err = user_model.MakeEmailPrimary(email) + err = user_model.MakeEmailPrimary(db.DefaultContext, email) assert.Error(t, err) assert.True(t, user_model.IsErrUserNotExist(err)) email = &user_model.EmailAddress{ Email: "user101@example.com", } - err = user_model.MakeEmailPrimary(email) + err = user_model.MakeEmailPrimary(db.DefaultContext, email) assert.NoError(t, err) user, _ := user_model.GetUserByID(db.DefaultContext, int64(10)) @@ -174,9 +174,9 @@ func TestActivate(t *testing.T) { UID: int64(1), Email: "user11@example.com", } - assert.NoError(t, user_model.ActivateEmail(email)) + assert.NoError(t, user_model.ActivateEmail(db.DefaultContext, email)) - emails, _ := user_model.GetEmailAddresses(int64(1)) + emails, _ := user_model.GetEmailAddresses(db.DefaultContext, int64(1)) assert.Len(t, emails, 3) assert.True(t, emails[0].IsActivated) assert.True(t, emails[0].IsPrimary) @@ -194,7 +194,7 @@ func TestListEmails(t *testing.T) { PageSize: 10000, }, } - emails, count, err := user_model.SearchEmails(opts) + emails, count, err := user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.NotEqual(t, int64(0), count) assert.True(t, count > 5) @@ -214,13 +214,13 @@ func TestListEmails(t *testing.T) { // Must find no records opts = &user_model.SearchEmailOptions{Keyword: "NOTFOUND"} - emails, count, err = user_model.SearchEmails(opts) + emails, count, err = user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.Equal(t, int64(0), count) // Must find users 'user2', 'user28', etc. opts = &user_model.SearchEmailOptions{Keyword: "user2"} - emails, count, err = user_model.SearchEmails(opts) + emails, count, err = user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.NotEqual(t, int64(0), count) assert.True(t, contains(func(s *user_model.SearchEmailResult) bool { return s.UID == 2 })) @@ -228,14 +228,14 @@ func TestListEmails(t *testing.T) { // Must find only primary addresses (i.e. from the `user` table) opts = &user_model.SearchEmailOptions{IsPrimary: util.OptionalBoolTrue} - emails, _, err = user_model.SearchEmails(opts) + emails, _, err = user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.True(t, contains(func(s *user_model.SearchEmailResult) bool { return s.IsPrimary })) assert.False(t, contains(func(s *user_model.SearchEmailResult) bool { return !s.IsPrimary })) // Must find only inactive addresses (i.e. not validated) opts = &user_model.SearchEmailOptions{IsActivated: util.OptionalBoolFalse} - emails, _, err = user_model.SearchEmails(opts) + emails, _, err = user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.True(t, contains(func(s *user_model.SearchEmailResult) bool { return !s.IsActivated })) assert.False(t, contains(func(s *user_model.SearchEmailResult) bool { return s.IsActivated })) @@ -247,7 +247,7 @@ func TestListEmails(t *testing.T) { Page: 1, }, } - emails, count, err = user_model.SearchEmails(opts) + emails, count, err = user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.Len(t, emails, 5) assert.Greater(t, count, int64(len(emails))) diff --git a/models/user/list.go b/models/user/list.go index 6b3b7bea9a..ca589d1e02 100644 --- a/models/user/list.go +++ b/models/user/list.go @@ -25,19 +25,19 @@ func (users UserList) GetUserIDs() []int64 { } // GetTwoFaStatus return state of 2FA enrollement -func (users UserList) GetTwoFaStatus() map[int64]bool { +func (users UserList) GetTwoFaStatus(ctx context.Context) map[int64]bool { results := make(map[int64]bool, len(users)) for _, user := range users { results[user.ID] = false // Set default to false } - if tokenMaps, err := users.loadTwoFactorStatus(db.DefaultContext); err == nil { + if tokenMaps, err := users.loadTwoFactorStatus(ctx); err == nil { for _, token := range tokenMaps { results[token.UID] = true } } - if ids, err := users.userIDsWithWebAuthn(db.DefaultContext); err == nil { + if ids, err := users.userIDsWithWebAuthn(ctx); err == nil { for _, id := range ids { results[id] = true } @@ -71,12 +71,12 @@ func (users UserList) userIDsWithWebAuthn(ctx context.Context) ([]int64, error) } // GetUsersByIDs returns all resolved users from a list of Ids. -func GetUsersByIDs(ids []int64) (UserList, error) { +func GetUsersByIDs(ctx context.Context, ids []int64) (UserList, error) { ous := make([]*User, 0, len(ids)) if len(ids) == 0 { return ous, nil } - err := db.GetEngine(db.DefaultContext).In("id", ids). + err := db.GetEngine(ctx).In("id", ids). Asc("name"). Find(&ous) return ous, err diff --git a/models/user/search.go b/models/user/search.go index 446556f89b..0fa278c257 100644 --- a/models/user/search.go +++ b/models/user/search.go @@ -4,6 +4,7 @@ package user import ( + "context" "fmt" "strings" @@ -39,7 +40,7 @@ type SearchUserOptions struct { ExtraParamStrings map[string]string } -func (opts *SearchUserOptions) toSearchQueryBase() *xorm.Session { +func (opts *SearchUserOptions) toSearchQueryBase(ctx context.Context) *xorm.Session { var cond builder.Cond cond = builder.Eq{"type": opts.Type} if opts.IncludeReserved { @@ -101,7 +102,7 @@ func (opts *SearchUserOptions) toSearchQueryBase() *xorm.Session { cond = cond.And(builder.Eq{"prohibit_login": opts.IsProhibitLogin.IsTrue()}) } - e := db.GetEngine(db.DefaultContext) + e := db.GetEngine(ctx) if opts.IsTwoFactorEnabled.IsNone() { return e.Where(cond) } @@ -122,8 +123,8 @@ func (opts *SearchUserOptions) toSearchQueryBase() *xorm.Session { // SearchUsers takes options i.e. keyword and part of user name to search, // it returns results in given range and number of total results. -func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { - sessCount := opts.toSearchQueryBase() +func SearchUsers(ctx context.Context, opts *SearchUserOptions) (users []*User, _ int64, _ error) { + sessCount := opts.toSearchQueryBase(ctx) defer sessCount.Close() count, err := sessCount.Count(new(User)) if err != nil { @@ -134,7 +135,7 @@ func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { opts.OrderBy = db.SearchOrderByAlphabetically } - sessQuery := opts.toSearchQueryBase().OrderBy(opts.OrderBy.String()) + sessQuery := opts.toSearchQueryBase(ctx).OrderBy(opts.OrderBy.String()) defer sessQuery.Close() if opts.Page != 0 { sessQuery = db.SetSessionPagination(sessQuery, opts) diff --git a/models/user/user.go b/models/user/user.go index 86cf2ad280..b3956da1cb 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -192,15 +192,15 @@ func (u *User) SetLastLogin() { } // UpdateUserDiffViewStyle updates the users diff view style -func UpdateUserDiffViewStyle(u *User, style string) error { +func UpdateUserDiffViewStyle(ctx context.Context, u *User, style string) error { u.DiffViewStyle = style - return UpdateUserCols(db.DefaultContext, u, "diff_view_style") + return UpdateUserCols(ctx, u, "diff_view_style") } // UpdateUserTheme updates a users' theme irrespective of the site wide theme -func UpdateUserTheme(u *User, themeName string) error { +func UpdateUserTheme(ctx context.Context, u *User, themeName string) error { u.Theme = themeName - return UpdateUserCols(db.DefaultContext, u, "theme") + return UpdateUserCols(ctx, u, "theme") } // GetPlaceholderEmail returns an noreply email @@ -218,9 +218,9 @@ func (u *User) GetEmail() string { } // GetAllUsers returns a slice of all individual users found in DB. -func GetAllUsers() ([]*User, error) { +func GetAllUsers(ctx context.Context) ([]*User, error) { users := make([]*User, 0) - return users, db.GetEngine(db.DefaultContext).OrderBy("id").Where("type = ?", UserTypeIndividual).Find(&users) + return users, db.GetEngine(ctx).OrderBy("id").Where("type = ?", UserTypeIndividual).Find(&users) } // IsLocal returns true if user login type is LoginPlain. @@ -478,9 +478,9 @@ func (u *User) EmailNotifications() string { } // SetEmailNotifications sets the user's email notification preference -func SetEmailNotifications(u *User, set string) error { +func SetEmailNotifications(ctx context.Context, u *User, set string) error { u.EmailNotificationsPreference = set - if err := UpdateUserCols(db.DefaultContext, u, "email_notifications_preference"); err != nil { + if err := UpdateUserCols(ctx, u, "email_notifications_preference"); err != nil { log.Error("SetEmailNotifications: %v", err) return err } @@ -582,7 +582,7 @@ type CreateUserOverwriteOptions struct { } // CreateUser creates record of a new user. -func CreateUser(u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err error) { +func CreateUser(ctx context.Context, u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err error) { if err = IsUsableUsername(u.Name); err != nil { return err } @@ -640,7 +640,7 @@ func CreateUser(u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err e return err } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -711,8 +711,8 @@ type CountUserFilter struct { } // CountUsers returns number of users. -func CountUsers(opts *CountUserFilter) int64 { - return countUsers(db.DefaultContext, opts) +func CountUsers(ctx context.Context, opts *CountUserFilter) int64 { + return countUsers(ctx, opts) } func countUsers(ctx context.Context, opts *CountUserFilter) int64 { @@ -727,7 +727,7 @@ func countUsers(ctx context.Context, opts *CountUserFilter) int64 { } // GetVerifyUser get user by verify code -func GetVerifyUser(code string) (user *User) { +func GetVerifyUser(ctx context.Context, code string) (user *User) { if len(code) <= base.TimeLimitCodeLength { return nil } @@ -735,7 +735,7 @@ func GetVerifyUser(code string) (user *User) { // use tail hex username query user hexStr := code[base.TimeLimitCodeLength:] if b, err := hex.DecodeString(hexStr); err == nil { - if user, err = GetUserByName(db.DefaultContext, string(b)); user != nil { + if user, err = GetUserByName(ctx, string(b)); user != nil { return user } log.Error("user.getVerifyUser: %v", err) @@ -745,10 +745,10 @@ func GetVerifyUser(code string) (user *User) { } // VerifyUserActiveCode verifies active code when active account -func VerifyUserActiveCode(code string) (user *User) { +func VerifyUserActiveCode(ctx context.Context, code string) (user *User) { minutes := setting.Service.ActiveCodeLives - if user = GetVerifyUser(code); user != nil { + if user = GetVerifyUser(ctx, code); user != nil { // time limit code prefix := code[:base.TimeLimitCodeLength] data := fmt.Sprintf("%d%s%s%s%s", user.ID, user.Email, user.LowerName, user.Passwd, user.Rands) @@ -872,8 +872,8 @@ func UpdateUserCols(ctx context.Context, u *User, cols ...string) error { } // UpdateUserSetting updates user's settings. -func UpdateUserSetting(u *User) (err error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func UpdateUserSetting(ctx context.Context, u *User) (err error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -1021,9 +1021,9 @@ func GetMaileableUsersByIDs(ctx context.Context, ids []int64, isMention bool) ([ } // GetUserNamesByIDs returns usernames for all resolved users from a list of Ids. -func GetUserNamesByIDs(ids []int64) ([]string, error) { +func GetUserNamesByIDs(ctx context.Context, ids []int64) ([]string, error) { unames := make([]string, 0, len(ids)) - err := db.GetEngine(db.DefaultContext).In("id", ids). + err := db.GetEngine(ctx).In("id", ids). Table("user"). Asc("name"). Cols("name"). @@ -1062,9 +1062,9 @@ func GetUserIDsByNames(ctx context.Context, names []string, ignoreNonExistent bo } // GetUsersBySource returns a list of Users for a login source -func GetUsersBySource(s *auth.Source) ([]*User, error) { +func GetUsersBySource(ctx context.Context, s *auth.Source) ([]*User, error) { var users []*User - err := db.GetEngine(db.DefaultContext).Where("login_type = ? AND login_source = ?", s.Type, s.ID).Find(&users) + err := db.GetEngine(ctx).Where("login_type = ? AND login_source = ?", s.Type, s.ID).Find(&users) return users, err } @@ -1145,12 +1145,12 @@ func GetUserByEmail(ctx context.Context, email string) (*User, error) { } // GetUser checks if a user already exists -func GetUser(user *User) (bool, error) { - return db.GetEngine(db.DefaultContext).Get(user) +func GetUser(ctx context.Context, user *User) (bool, error) { + return db.GetEngine(ctx).Get(user) } // GetUserByOpenID returns the user object by given OpenID if exists. -func GetUserByOpenID(uri string) (*User, error) { +func GetUserByOpenID(ctx context.Context, uri string) (*User, error) { if len(uri) == 0 { return nil, ErrUserNotExist{0, uri, 0} } @@ -1164,12 +1164,12 @@ func GetUserByOpenID(uri string) (*User, error) { // Otherwise, check in openid table oid := &UserOpenID{} - has, err := db.GetEngine(db.DefaultContext).Where("uri=?", uri).Get(oid) + has, err := db.GetEngine(ctx).Where("uri=?", uri).Get(oid) if err != nil { return nil, err } if has { - return GetUserByID(db.DefaultContext, oid.UID) + return GetUserByID(ctx, oid.UID) } return nil, ErrUserNotExist{0, uri, 0} @@ -1279,13 +1279,13 @@ func IsUserVisibleToViewer(ctx context.Context, u, viewer *User) bool { } // CountWrongUserType count OrgUser who have wrong type -func CountWrongUserType() (int64, error) { - return db.GetEngine(db.DefaultContext).Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Count(new(User)) +func CountWrongUserType(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Count(new(User)) } // FixWrongUserType fix OrgUser who have wrong type -func FixWrongUserType() (int64, error) { - return db.GetEngine(db.DefaultContext).Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Cols("type").NoAutoTime().Update(&User{Type: 1}) +func FixWrongUserType(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Cols("type").NoAutoTime().Update(&User{Type: 1}) } func GetOrderByName() string { diff --git a/models/user/user_test.go b/models/user/user_test.go index 032dcba676..b15f0cbc59 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -63,7 +63,7 @@ func TestCanCreateOrganization(t *testing.T) { func TestSearchUsers(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(opts *user_model.SearchUserOptions, expectedUserOrOrgIDs []int64) { - users, _, err := user_model.SearchUsers(opts) + users, _, err := user_model.SearchUsers(db.DefaultContext, opts) assert.NoError(t, err) cassText := fmt.Sprintf("ids: %v, opts: %v", expectedUserOrOrgIDs, opts) if assert.Len(t, users, len(expectedUserOrOrgIDs), "case: %s", cassText) { @@ -150,16 +150,16 @@ func TestEmailNotificationPreferences(t *testing.T) { assert.Equal(t, test.expected, user.EmailNotifications()) // Try all possible settings - assert.NoError(t, user_model.SetEmailNotifications(user, user_model.EmailNotificationsEnabled)) + assert.NoError(t, user_model.SetEmailNotifications(db.DefaultContext, user, user_model.EmailNotificationsEnabled)) assert.Equal(t, user_model.EmailNotificationsEnabled, user.EmailNotifications()) - assert.NoError(t, user_model.SetEmailNotifications(user, user_model.EmailNotificationsOnMention)) + assert.NoError(t, user_model.SetEmailNotifications(db.DefaultContext, user, user_model.EmailNotificationsOnMention)) assert.Equal(t, user_model.EmailNotificationsOnMention, user.EmailNotifications()) - assert.NoError(t, user_model.SetEmailNotifications(user, user_model.EmailNotificationsDisabled)) + assert.NoError(t, user_model.SetEmailNotifications(db.DefaultContext, user, user_model.EmailNotificationsDisabled)) assert.Equal(t, user_model.EmailNotificationsDisabled, user.EmailNotifications()) - assert.NoError(t, user_model.SetEmailNotifications(user, user_model.EmailNotificationsAndYourOwn)) + assert.NoError(t, user_model.SetEmailNotifications(db.DefaultContext, user, user_model.EmailNotificationsAndYourOwn)) assert.Equal(t, user_model.EmailNotificationsAndYourOwn, user.EmailNotifications()) } } @@ -239,7 +239,7 @@ func TestCreateUserInvalidEmail(t *testing.T) { MustChangePassword: false, } - err := user_model.CreateUser(user) + err := user_model.CreateUser(db.DefaultContext, user) assert.Error(t, err) assert.True(t, user_model.IsErrEmailCharIsNotSupported(err)) } @@ -253,7 +253,7 @@ func TestCreateUserEmailAlreadyUsed(t *testing.T) { user.Name = "testuser" user.LowerName = strings.ToLower(user.Name) user.ID = 0 - err := user_model.CreateUser(user) + err := user_model.CreateUser(db.DefaultContext, user) assert.Error(t, err) assert.True(t, user_model.IsErrEmailAlreadyUsed(err)) } @@ -270,7 +270,7 @@ func TestCreateUserCustomTimestamps(t *testing.T) { user.ID = 0 user.Email = "unique@example.com" user.CreatedUnix = creationTimestamp - err := user_model.CreateUser(user) + err := user_model.CreateUser(db.DefaultContext, user) assert.NoError(t, err) fetched, err := user_model.GetUserByID(context.Background(), user.ID) @@ -295,7 +295,7 @@ func TestCreateUserWithoutCustomTimestamps(t *testing.T) { user.Email = "unique@example.com" user.CreatedUnix = 0 user.UpdatedUnix = 0 - err := user_model.CreateUser(user) + err := user_model.CreateUser(db.DefaultContext, user) assert.NoError(t, err) timestampEnd := time.Now().Unix() @@ -429,17 +429,17 @@ func TestNewUserRedirect3(t *testing.T) { func TestGetUserByOpenID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - _, err := user_model.GetUserByOpenID("https://unknown") + _, err := user_model.GetUserByOpenID(db.DefaultContext, "https://unknown") if assert.Error(t, err) { assert.True(t, user_model.IsErrUserNotExist(err)) } - user, err := user_model.GetUserByOpenID("https://user1.domain1.tld") + user, err := user_model.GetUserByOpenID(db.DefaultContext, "https://user1.domain1.tld") if assert.NoError(t, err) { assert.Equal(t, int64(1), user.ID) } - user, err = user_model.GetUserByOpenID("https://domain1.tld/user2/") + user, err = user_model.GetUserByOpenID(db.DefaultContext, "https://domain1.tld/user2/") if assert.NoError(t, err) { assert.Equal(t, int64(2), user.ID) } |