diff options
Diffstat (limited to 'models/user')
-rw-r--r-- | models/user/avatar.go | 9 | ||||
-rw-r--r-- | models/user/email_address.go | 27 | ||||
-rw-r--r-- | models/user/email_address_test.go | 4 | ||||
-rw-r--r-- | models/user/list.go | 13 | ||||
-rw-r--r-- | models/user/openid.go | 18 | ||||
-rw-r--r-- | models/user/user.go | 100 | ||||
-rw-r--r-- | models/user/user_test.go | 12 |
7 files changed, 66 insertions, 117 deletions
diff --git a/models/user/avatar.go b/models/user/avatar.go index c881642b56..6a44a3bcb3 100644 --- a/models/user/avatar.go +++ b/models/user/avatar.go @@ -26,12 +26,7 @@ func (u *User) CustomAvatarRelativePath() string { } // GenerateRandomAvatar generates a random avatar for user. -func GenerateRandomAvatar(u *User) error { - return GenerateRandomAvatarCtx(db.DefaultContext, u) -} - -// GenerateRandomAvatarCtx generates a random avatar for user. -func GenerateRandomAvatarCtx(ctx context.Context, u *User) error { +func GenerateRandomAvatar(ctx context.Context, u *User) error { seed := u.Email if len(seed) == 0 { seed = u.Name @@ -82,7 +77,7 @@ func (u *User) AvatarLinkWithSize(size int) string { if useLocalAvatar { if u.Avatar == "" && autoGenerateAvatar { - if err := GenerateRandomAvatar(u); err != nil { + if err := GenerateRandomAvatar(db.DefaultContext, u); err != nil { log.Error("GenerateRandomAvatar: %v", err) } } diff --git a/models/user/email_address.go b/models/user/email_address.go index 564d018dac..c931db9c16 100644 --- a/models/user/email_address.go +++ b/models/user/email_address.go @@ -207,7 +207,8 @@ func IsEmailUsed(ctx context.Context, email string) (bool, error) { return db.GetEngine(ctx).Where("lower_email=?", strings.ToLower(email)).Get(&EmailAddress{}) } -func addEmailAddress(ctx context.Context, email *EmailAddress) error { +// AddEmailAddress adds an email address to given user. +func AddEmailAddress(ctx context.Context, email *EmailAddress) error { email.Email = strings.TrimSpace(email.Email) used, err := IsEmailUsed(ctx, email.Email) if err != nil { @@ -223,11 +224,6 @@ func addEmailAddress(ctx context.Context, email *EmailAddress) error { return db.Insert(ctx, email) } -// AddEmailAddress adds an email address to given user. -func AddEmailAddress(email *EmailAddress) error { - return addEmailAddress(db.DefaultContext, email) -} - // AddEmailAddresses adds an email address to given user. func AddEmailAddresses(emails []*EmailAddress) error { if len(emails) == 0 { @@ -311,14 +307,14 @@ func ActivateEmail(email *EmailAddress) error { return err } defer committer.Close() - if err := updateActivation(db.GetEngine(ctx), email, true); err != nil { + if err := updateActivation(ctx, email, true); err != nil { return err } return committer.Commit() } -func updateActivation(e db.Engine, email *EmailAddress, activate bool) error { - user, err := GetUserByIDEngine(e, email.UID) +func updateActivation(ctx context.Context, email *EmailAddress, activate bool) error { + user, err := GetUserByIDCtx(ctx, email.UID) if err != nil { return err } @@ -326,10 +322,10 @@ func updateActivation(e db.Engine, email *EmailAddress, activate bool) error { return err } email.IsActivated = activate - if _, err := e.ID(email.ID).Cols("is_activated").Update(email); err != nil { + if _, err := db.GetEngine(ctx).ID(email.ID).Cols("is_activated").Update(email); err != nil { return err } - return UpdateUserColsEngine(e, user, "rands") + return UpdateUserCols(ctx, user, "rands") } // MakeEmailPrimary sets primary email address of given user. @@ -500,12 +496,11 @@ func ActivateUserEmail(userID int64, email string, activate bool) (err error) { return err } defer committer.Close() - sess := db.GetEngine(ctx) // Activate/deactivate a user's secondary email address // First check if there's another user active with the same address addr := EmailAddress{UID: userID, LowerEmail: strings.ToLower(email)} - if has, err := sess.Get(&addr); err != nil { + if has, err := db.GetByBean(ctx, &addr); err != nil { return err } else if !has { return fmt.Errorf("no such email: %d (%s)", userID, email) @@ -521,14 +516,14 @@ func ActivateUserEmail(userID int64, email string, activate bool) (err error) { return ErrEmailAlreadyUsed{Email: email} } } - if err = updateActivation(sess, &addr, activate); err != nil { + if err = updateActivation(ctx, &addr, activate); err != nil { return fmt.Errorf("unable to updateActivation() for %d:%s: %w", addr.ID, addr.Email, err) } // Activate/deactivate a user's primary email address and account if addr.IsPrimary { user := User{ID: userID, Email: email} - if has, err := sess.Get(&user); err != nil { + if has, err := db.GetByBean(ctx, &user); err != nil { return err } else if !has { return fmt.Errorf("no user with ID: %d and Email: %s", userID, email) @@ -539,7 +534,7 @@ func ActivateUserEmail(userID int64, email string, activate bool) (err error) { if user.Rands, err = GetUserSalt(); err != nil { return fmt.Errorf("unable to generate salt: %v", err) } - if err = UpdateUserColsEngine(sess, &user, "is_active", "rands"); err != nil { + if err = UpdateUserCols(ctx, &user, "is_active", "rands"); err != nil { return fmt.Errorf("unable to updateUserCols() for user ID: %d: %v", userID, err) } } diff --git a/models/user/email_address_test.go b/models/user/email_address_test.go index 7eeb469b26..79de4c0b48 100644 --- a/models/user/email_address_test.go +++ b/models/user/email_address_test.go @@ -45,7 +45,7 @@ func TestIsEmailUsed(t *testing.T) { func TestAddEmailAddress(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.NoError(t, AddEmailAddress(&EmailAddress{ + assert.NoError(t, AddEmailAddress(db.DefaultContext, &EmailAddress{ Email: "user1234567890@example.com", LowerEmail: "user1234567890@example.com", IsPrimary: true, @@ -53,7 +53,7 @@ func TestAddEmailAddress(t *testing.T) { })) // ErrEmailAlreadyUsed - err := AddEmailAddress(&EmailAddress{ + err := AddEmailAddress(db.DefaultContext, &EmailAddress{ Email: "user1234567890@example.com", LowerEmail: "user1234567890@example.com", }) diff --git a/models/user/list.go b/models/user/list.go index 5cdc92ba4a..68e62ca15d 100644 --- a/models/user/list.go +++ b/models/user/list.go @@ -5,6 +5,7 @@ package user import ( + "context" "fmt" "code.gitea.io/gitea/models/auth" @@ -31,13 +32,13 @@ func (users UserList) GetTwoFaStatus() map[int64]bool { results[user.ID] = false // Set default to false } - if tokenMaps, err := users.loadTwoFactorStatus(db.GetEngine(db.DefaultContext)); err == nil { + if tokenMaps, err := users.loadTwoFactorStatus(db.DefaultContext); err == nil { for _, token := range tokenMaps { results[token.UID] = true } } - if ids, err := users.userIDsWithWebAuthn(db.GetEngine(db.DefaultContext)); err == nil { + if ids, err := users.userIDsWithWebAuthn(db.DefaultContext); err == nil { for _, id := range ids { results[id] = true } @@ -46,25 +47,25 @@ func (users UserList) GetTwoFaStatus() map[int64]bool { return results } -func (users UserList) loadTwoFactorStatus(e db.Engine) (map[int64]*auth.TwoFactor, error) { +func (users UserList) loadTwoFactorStatus(ctx context.Context) (map[int64]*auth.TwoFactor, error) { if len(users) == 0 { return nil, nil } userIDs := users.GetUserIDs() tokenMaps := make(map[int64]*auth.TwoFactor, len(userIDs)) - if err := e.In("uid", userIDs).Find(&tokenMaps); err != nil { + if err := db.GetEngine(ctx).In("uid", userIDs).Find(&tokenMaps); err != nil { return nil, fmt.Errorf("find two factor: %v", err) } return tokenMaps, nil } -func (users UserList) userIDsWithWebAuthn(e db.Engine) ([]int64, error) { +func (users UserList) userIDsWithWebAuthn(ctx context.Context) ([]int64, error) { if len(users) == 0 { return nil, nil } ids := make([]int64, 0, len(users)) - if err := e.Table(new(auth.WebAuthnCredential)).In("user_id", users.GetUserIDs()).Select("user_id").Distinct("user_id").Find(&ids); err != nil { + if err := db.GetEngine(ctx).Table(new(auth.WebAuthnCredential)).In("user_id", users.GetUserIDs()).Select("user_id").Distinct("user_id").Find(&ids); err != nil { return nil, fmt.Errorf("find two factor: %v", err) } return ids, nil diff --git a/models/user/openid.go b/models/user/openid.go index 8ca3c7f2c8..8ef0ce5ed7 100644 --- a/models/user/openid.go +++ b/models/user/openid.go @@ -5,6 +5,7 @@ package user import ( + "context" "errors" "fmt" @@ -41,12 +42,12 @@ func GetUserOpenIDs(uid int64) ([]*UserOpenID, error) { } // isOpenIDUsed returns true if the openid has been used. -func isOpenIDUsed(e db.Engine, uri string) (bool, error) { +func isOpenIDUsed(ctx context.Context, uri string) (bool, error) { if len(uri) == 0 { return true, nil } - return e.Get(&UserOpenID{URI: uri}) + return db.GetEngine(ctx).Get(&UserOpenID{URI: uri}) } // ErrOpenIDAlreadyUsed represents a "OpenIDAlreadyUsed" kind of error. @@ -64,22 +65,17 @@ func (err ErrOpenIDAlreadyUsed) Error() string { return fmt.Sprintf("OpenID already in use [oid: %s]", err.OpenID) } +// AddUserOpenID adds an pre-verified/normalized OpenID URI to given user. // NOTE: make sure openid.URI is normalized already -func addUserOpenID(e db.Engine, openid *UserOpenID) error { - used, err := isOpenIDUsed(e, openid.URI) +func AddUserOpenID(ctx context.Context, openid *UserOpenID) error { + used, err := isOpenIDUsed(ctx, openid.URI) if err != nil { return err } else if used { return ErrOpenIDAlreadyUsed{openid.URI} } - _, err = e.Insert(openid) - return err -} - -// AddUserOpenID adds an pre-verified/normalized OpenID URI to given user. -func AddUserOpenID(openid *UserOpenID) error { - return addUserOpenID(db.GetEngine(db.DefaultContext), openid) + return db.Insert(ctx, openid) } // DeleteUserOpenID deletes an openid address of given user. diff --git a/models/user/user.go b/models/user/user.go index 6aa63a0a56..f7d457b91b 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -509,23 +509,19 @@ func SetEmailNotifications(u *User, set string) error { return nil } -func isUserExist(e db.Engine, uid int64, name string) (bool, error) { +// IsUserExist checks if given user name exist, +// the user name should be noncased unique. +// If uid is presented, then check will rule out that one, +// it is used when update a user name in settings page. +func IsUserExist(ctx context.Context, uid int64, name string) (bool, error) { if len(name) == 0 { return false, nil } - return e. + return db.GetEngine(ctx). Where("id!=?", uid). Get(&User{LowerName: strings.ToLower(name)}) } -// IsUserExist checks if given user name exist, -// the user name should be noncased unique. -// If uid is presented, then check will rule out that one, -// it is used when update a user name in settings page. -func IsUserExist(uid int64, name string) (bool, error) { - return isUserExist(db.GetEngine(db.DefaultContext), uid, name) -} - // Note: As of the beginning of 2022, it is recommended to use at least // 64 bits of salt, but NIST is already recommending to use to 128 bits. // (16 bytes = 16 * 8 = 128 bits) @@ -691,9 +687,7 @@ func CreateUser(u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err e } defer committer.Close() - sess := db.GetEngine(ctx) - - isExist, err := isUserExist(sess, 0, u.Name) + isExist, err := IsUserExist(ctx, 0, u.Name) if err != nil { return err } else if isExist { @@ -774,7 +768,7 @@ func GetVerifyUser(code string) (user *User) { // use tail hex username query user hexStr := code[base.TimeLimitCodeLength:] if b, err := hex.DecodeString(hexStr); err == nil { - if user, err = GetUserByName(string(b)); user != nil { + if user, err = GetUserByName(db.DefaultContext, string(b)); user != nil { return user } log.Error("user.getVerifyUser: %v", err) @@ -811,16 +805,15 @@ func ChangeUserName(u *User, newUserName string) (err error) { return err } defer committer.Close() - sess := db.GetEngine(ctx) - isExist, err := isUserExist(sess, 0, newUserName) + isExist, err := IsUserExist(ctx, 0, newUserName) if err != nil { return err } else if isExist { return ErrUserAlreadyExist{newUserName} } - if _, err = sess.Exec("UPDATE `repository` SET owner_name=? WHERE owner_name=?", newUserName, oldUserName); err != nil { + if _, err = db.GetEngine(ctx).Exec("UPDATE `repository` SET owner_name=? WHERE owner_name=?", newUserName, oldUserName); err != nil { return fmt.Errorf("Change repo owner name: %v", err) } @@ -845,9 +838,9 @@ func ChangeUserName(u *User, newUserName string) (err error) { } // checkDupEmail checks whether there are the same email with the user -func checkDupEmail(e db.Engine, u *User) error { +func checkDupEmail(ctx context.Context, u *User) error { u.Email = strings.ToLower(u.Email) - has, err := e. + has, err := db.GetEngine(ctx). Where("id!=?", u.ID). And("type=?", u.Type). And("email=?", u.Email). @@ -872,7 +865,8 @@ func validateUser(u *User) error { return ValidateEmail(u.Email) } -func updateUser(ctx context.Context, u *User, changePrimaryEmail bool, cols ...string) error { +// UpdateUser updates user's information. +func UpdateUser(ctx context.Context, u *User, changePrimaryEmail bool, cols ...string) error { err := validateUser(u) if err != nil { return err @@ -932,27 +926,13 @@ func updateUser(ctx context.Context, u *User, changePrimaryEmail bool, cols ...s return err } -// UpdateUser updates user's information. -func UpdateUser(u *User, emailChanged bool, cols ...string) error { - return updateUser(db.DefaultContext, u, emailChanged, cols...) -} - // UpdateUserCols update user according special columns func UpdateUserCols(ctx context.Context, u *User, cols ...string) error { - return updateUserCols(db.GetEngine(ctx), u, cols...) -} - -// UpdateUserColsEngine update user according special columns -func UpdateUserColsEngine(e db.Engine, u *User, cols ...string) error { - return updateUserCols(e, u, cols...) -} - -func updateUserCols(e db.Engine, u *User, cols ...string) error { if err := validateUser(u); err != nil { return err } - _, err := e.ID(u.ID).Cols(cols...).Update(u) + _, err := db.GetEngine(ctx).ID(u.ID).Cols(cols...).Update(u) return err } @@ -965,11 +945,11 @@ func UpdateUserSetting(u *User) (err error) { defer committer.Close() if !u.IsOrganization() { - if err = checkDupEmail(db.GetEngine(ctx), u); err != nil { + if err = checkDupEmail(ctx, u); err != nil { return err } } - if err = updateUser(ctx, u, false); err != nil { + if err = UpdateUser(ctx, u, false); err != nil { return err } return committer.Commit() @@ -994,18 +974,6 @@ func UserPath(userName string) string { //revive:disable-line:exported return filepath.Join(setting.RepoRootPath, strings.ToLower(userName)) } -// GetUserByIDEngine returns the user object by given ID if exists. -func GetUserByIDEngine(e db.Engine, id int64) (*User, error) { - u := new(User) - has, err := e.ID(id).Get(u) - if err != nil { - return nil, err - } else if !has { - return nil, ErrUserNotExist{id, "", 0} - } - return u, nil -} - // GetUserByID returns the user object by given ID if exists. func GetUserByID(id int64) (*User, error) { return GetUserByIDCtx(db.DefaultContext, id) @@ -1013,16 +981,18 @@ func GetUserByID(id int64) (*User, error) { // GetUserByIDCtx returns the user object by given ID if exists. func GetUserByIDCtx(ctx context.Context, id int64) (*User, error) { - return GetUserByIDEngine(db.GetEngine(ctx), id) -} - -// GetUserByName returns user by given name. -func GetUserByName(name string) (*User, error) { - return GetUserByNameCtx(db.DefaultContext, name) + u := new(User) + has, err := db.GetEngine(ctx).ID(id).Get(u) + if err != nil { + return nil, err + } else if !has { + return nil, ErrUserNotExist{id, "", 0} + } + return u, nil } // GetUserByNameCtx returns user by given name. -func GetUserByNameCtx(ctx context.Context, name string) (*User, error) { +func GetUserByName(ctx context.Context, name string) (*User, error) { if len(name) == 0 { return nil, ErrUserNotExist{0, name, 0} } @@ -1038,14 +1008,10 @@ func GetUserByNameCtx(ctx context.Context, name string) (*User, error) { // GetUserEmailsByNames returns a list of e-mails corresponds to names of users // that have their email notifications set to enabled or onmention. -func GetUserEmailsByNames(names []string) []string { - return getUserEmailsByNames(db.DefaultContext, names) -} - -func getUserEmailsByNames(ctx context.Context, names []string) []string { +func GetUserEmailsByNames(ctx context.Context, names []string) []string { mails := make([]string, 0, len(names)) for _, name := range names { - u, err := GetUserByNameCtx(ctx, name) + u, err := GetUserByName(ctx, name) if err != nil { continue } @@ -1108,7 +1074,7 @@ func GetUserNameByID(ctx context.Context, id int64) (string, error) { func GetUserIDsByNames(names []string, ignoreNonExistent bool) ([]int64, error) { ids := make([]int64, 0, len(names)) for _, name := range names { - u, err := GetUserByName(name) + u, err := GetUserByName(db.DefaultContext, name) if err != nil { if ignoreNonExistent { continue @@ -1254,11 +1220,7 @@ func GetAdminUser() (*User, error) { } // IsUserVisibleToViewer check if viewer is able to see user profile -func IsUserVisibleToViewer(u, viewer *User) bool { - return isUserVisibleToViewer(db.GetEngine(db.DefaultContext), u, viewer) -} - -func isUserVisibleToViewer(e db.Engine, u, viewer *User) bool { +func IsUserVisibleToViewer(ctx context.Context, u, viewer *User) bool { if viewer != nil && viewer.IsAdmin { return true } @@ -1283,7 +1245,7 @@ func isUserVisibleToViewer(e db.Engine, u, viewer *User) bool { } // Now we need to check if they in some organization together - count, err := e.Table("team_user"). + count, err := db.GetEngine(ctx).Table("team_user"). Where( builder.And( builder.Eq{"uid": viewer.ID}, diff --git a/models/user/user_test.go b/models/user/user_test.go index 335537aa13..0dbf2fc205 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -31,10 +31,10 @@ func TestGetUserEmailsByNames(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) // ignore none active user email - assert.Equal(t, []string{"user8@example.com"}, GetUserEmailsByNames([]string{"user8", "user9"})) - assert.Equal(t, []string{"user8@example.com", "user5@example.com"}, GetUserEmailsByNames([]string{"user8", "user5"})) + assert.Equal(t, []string{"user8@example.com"}, GetUserEmailsByNames(db.DefaultContext, []string{"user8", "user9"})) + assert.Equal(t, []string{"user8@example.com", "user5@example.com"}, GetUserEmailsByNames(db.DefaultContext, []string{"user8", "user5"})) - assert.Equal(t, []string{"user8@example.com"}, GetUserEmailsByNames([]string{"user8", "user7"})) + assert.Equal(t, []string{"user8@example.com"}, GetUserEmailsByNames(db.DefaultContext, []string{"user8", "user7"})) } func TestCanCreateOrganization(t *testing.T) { @@ -287,19 +287,19 @@ func TestUpdateUser(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User) user.KeepActivityPrivate = true - assert.NoError(t, UpdateUser(user, false)) + assert.NoError(t, UpdateUser(db.DefaultContext, user, false)) user = unittest.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User) assert.True(t, user.KeepActivityPrivate) setting.Service.AllowedUserVisibilityModesSlice = []bool{true, false, false} user.KeepActivityPrivate = false user.Visibility = structs.VisibleTypePrivate - assert.Error(t, UpdateUser(user, false)) + assert.Error(t, UpdateUser(db.DefaultContext, user, false)) user = unittest.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User) assert.True(t, user.KeepActivityPrivate) user.Email = "no mail@mail.org" - assert.Error(t, UpdateUser(user, true)) + assert.Error(t, UpdateUser(db.DefaultContext, user, true)) } func TestNewUserRedirect(t *testing.T) { |