diff options
author | zeripath <art27@cantab.net> | 2021-09-23 16:45:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-23 23:45:36 +0800 |
commit | 9302eba971611601c3ebf6024e22a11c63f4e151 (patch) | |
tree | a3e5583986161ef62e7affc694098279ecf2217d /models/user.go | |
parent | b22be7f594401d7bd81196750456ce52185bd391 (diff) | |
download | gitea-9302eba971611601c3ebf6024e22a11c63f4e151.tar.gz gitea-9302eba971611601c3ebf6024e22a11c63f4e151.zip |
DBContext is just a Context (#17100)
* DBContext is just a Context
This PR removes some of the specialness from the DBContext and makes it context
This allows us to simplify the GetEngine code to wrap around any context in future
and means that we can change our loadRepo(e Engine) functions to simply take contexts.
Signed-off-by: Andrew Thornton <art27@cantab.net>
* fix unit tests
Signed-off-by: Andrew Thornton <art27@cantab.net>
* another place that needs to set the initial context
Signed-off-by: Andrew Thornton <art27@cantab.net>
* avoid race
Signed-off-by: Andrew Thornton <art27@cantab.net>
* change attachment error
Signed-off-by: Andrew Thornton <art27@cantab.net>
Diffstat (limited to 'models/user.go')
-rw-r--r-- | models/user.go | 86 |
1 files changed, 43 insertions, 43 deletions
diff --git a/models/user.go b/models/user.go index a4a3d83166..fc5d417d36 100644 --- a/models/user.go +++ b/models/user.go @@ -236,7 +236,7 @@ func (u *User) GetEmail() string { // GetAllUsers returns a slice of all individual users found in DB. func GetAllUsers() ([]*User, error) { users := make([]*User, 0) - return users, db.DefaultContext().Engine().OrderBy("id").Where("type = ?", UserTypeIndividual).Find(&users) + return users, db.GetEngine(db.DefaultContext).OrderBy("id").Where("type = ?", UserTypeIndividual).Find(&users) } // IsLocal returns true if user login type is LoginPlain. @@ -332,7 +332,7 @@ func (u *User) GenerateEmailActivateCode(email string) string { // GetFollowers returns range of user's followers. func (u *User) GetFollowers(listOptions ListOptions) ([]*User, error) { - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Where("follow.follow_id=?", u.ID). Join("LEFT", "follow", "`user`.id=follow.user_id") @@ -354,7 +354,7 @@ func (u *User) IsFollowing(followID int64) bool { // GetFollowing returns range of user's following. func (u *User) GetFollowing(listOptions ListOptions) ([]*User, error) { - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Where("follow.user_id=?", u.ID). Join("LEFT", "follow", "`user`.id=follow.follow_id") @@ -437,7 +437,7 @@ func (u *User) IsPasswordSet() bool { // IsVisibleToUser check if viewer is able to see user profile func (u *User) IsVisibleToUser(viewer *User) bool { - return u.isVisibleToUser(db.DefaultContext().Engine(), viewer) + return u.isVisibleToUser(db.GetEngine(db.DefaultContext), viewer) } func (u *User) isVisibleToUser(e db.Engine, viewer *User) bool { @@ -465,7 +465,7 @@ func (u *User) isVisibleToUser(e db.Engine, viewer *User) bool { } // Now we need to check if they in some organization together - count, err := db.DefaultContext().Engine().Table("team_user"). + count, err := db.GetEngine(db.DefaultContext).Table("team_user"). Where( builder.And( builder.Eq{"uid": viewer.ID}, @@ -508,7 +508,7 @@ func (u *User) IsUserOrgOwner(orgID int64) bool { // HasMemberWithUserID returns true if user with userID is part of the u organisation. func (u *User) HasMemberWithUserID(userID int64) bool { - return u.hasMemberWithUserID(db.DefaultContext().Engine(), userID) + return u.hasMemberWithUserID(db.GetEngine(db.DefaultContext), userID) } func (u *User) hasMemberWithUserID(e db.Engine, userID int64) bool { @@ -538,7 +538,7 @@ func (u *User) getOrganizationCount(e db.Engine) (int64, error) { // GetOrganizationCount returns count of membership of organization of user. func (u *User) GetOrganizationCount() (int64, error) { - return u.getOrganizationCount(db.DefaultContext().Engine()) + return u.getOrganizationCount(db.GetEngine(db.DefaultContext)) } // GetRepositories returns repositories that user owns, including private repositories. @@ -552,7 +552,7 @@ func (u *User) GetRepositories(listOpts ListOptions, names ...string) (err error func (u *User) GetRepositoryIDs(units ...UnitType) ([]int64, error) { var ids []int64 - sess := db.DefaultContext().Engine().Table("repository").Cols("repository.id") + sess := db.GetEngine(db.DefaultContext).Table("repository").Cols("repository.id") if len(units) > 0 { sess = sess.Join("INNER", "repo_unit", "repository.id = repo_unit.repo_id") @@ -567,7 +567,7 @@ func (u *User) GetRepositoryIDs(units ...UnitType) ([]int64, error) { func (u *User) GetActiveRepositoryIDs(units ...UnitType) ([]int64, error) { var ids []int64 - sess := db.DefaultContext().Engine().Table("repository").Cols("repository.id") + sess := db.GetEngine(db.DefaultContext).Table("repository").Cols("repository.id") if len(units) > 0 { sess = sess.Join("INNER", "repo_unit", "repository.id = repo_unit.repo_id") @@ -584,7 +584,7 @@ func (u *User) GetActiveRepositoryIDs(units ...UnitType) ([]int64, error) { func (u *User) GetOrgRepositoryIDs(units ...UnitType) ([]int64, error) { var ids []int64 - if err := db.DefaultContext().Engine().Table("repository"). + if err := db.GetEngine(db.DefaultContext).Table("repository"). Cols("repository.id"). Join("INNER", "team_user", "repository.owner_id = team_user.org_id"). Join("INNER", "team_repo", "(? != ? and repository.is_private != ?) OR (team_user.team_id = team_repo.team_id AND repository.id = team_repo.repo_id)", true, u.IsRestricted, true). @@ -605,7 +605,7 @@ func (u *User) GetOrgRepositoryIDs(units ...UnitType) ([]int64, error) { func (u *User) GetActiveOrgRepositoryIDs(units ...UnitType) ([]int64, error) { var ids []int64 - if err := db.DefaultContext().Engine().Table("repository"). + if err := db.GetEngine(db.DefaultContext).Table("repository"). Cols("repository.id"). Join("INNER", "team_user", "repository.owner_id = team_user.org_id"). Join("INNER", "team_repo", "(? != ? and repository.is_private != ?) OR (team_user.team_id = team_repo.team_id AND repository.id = team_repo.repo_id)", true, u.IsRestricted, true). @@ -743,7 +743,7 @@ func isUserExist(e db.Engine, uid int64, name string) (bool, error) { // If uid is presented, then check will rule out that one, // it is used when update a user name in settings page. func IsUserExist(uid int64, name string) (bool, error) { - return isUserExist(db.DefaultContext().Engine(), uid, name) + return isUserExist(db.GetEngine(db.DefaultContext), uid, name) } // GetUserSalt returns a random user salt token. @@ -879,7 +879,7 @@ func CreateUser(u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err e u.Visibility = overwriteDefault[0].Visibility } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -949,7 +949,7 @@ func countUsers(e db.Engine) int64 { // CountUsers returns number of users. func CountUsers() int64 { - return countUsers(db.DefaultContext().Engine()) + return countUsers(db.GetEngine(db.DefaultContext)) } // get user by verify code @@ -997,7 +997,7 @@ func VerifyActiveEmailCode(code, email string) *EmailAddress { if base.VerifyTimeLimitCode(data, minutes, prefix) { emailAddress := &EmailAddress{UID: user.ID, Email: email} - if has, _ := db.DefaultContext().Engine().Get(emailAddress); has { + if has, _ := db.GetEngine(db.DefaultContext).Get(emailAddress); has { return emailAddress } } @@ -1012,7 +1012,7 @@ func ChangeUserName(u *User, newUserName string) (err error) { return err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -1086,12 +1086,12 @@ func updateUser(e db.Engine, u *User) error { // UpdateUser updates user's information. func UpdateUser(u *User) error { - return updateUser(db.DefaultContext().Engine(), u) + return updateUser(db.GetEngine(db.DefaultContext), u) } // UpdateUserCols update user according special columns func UpdateUserCols(u *User, cols ...string) error { - return updateUserCols(db.DefaultContext().Engine(), u, cols...) + return updateUserCols(db.GetEngine(db.DefaultContext), u, cols...) } func updateUserCols(e db.Engine, u *User, cols ...string) error { @@ -1105,7 +1105,7 @@ func updateUserCols(e db.Engine, u *User, cols ...string) error { // UpdateUserSetting updates user's settings. func UpdateUserSetting(u *User) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -1311,7 +1311,7 @@ func DeleteUser(u *User) (err error) { return fmt.Errorf("%s is an organization not a user", u.Name) } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -1329,13 +1329,13 @@ func DeleteUser(u *User) (err error) { func DeleteInactiveUsers(ctx context.Context, olderThan time.Duration) (err error) { users := make([]*User, 0, 10) if olderThan > 0 { - if err = db.DefaultContext().Engine(). + if err = db.GetEngine(db.DefaultContext). Where("is_active = ? and created_unix < ?", false, time.Now().Add(-olderThan).Unix()). Find(&users); err != nil { return fmt.Errorf("get all inactive users: %v", err) } } else { - if err = db.DefaultContext().Engine(). + if err = db.GetEngine(db.DefaultContext). Where("is_active = ?", false). Find(&users); err != nil { return fmt.Errorf("get all inactive users: %v", err) @@ -1357,7 +1357,7 @@ func DeleteInactiveUsers(ctx context.Context, olderThan time.Duration) (err erro } } - _, err = db.DefaultContext().Engine(). + _, err = db.GetEngine(db.DefaultContext). Where("is_activated = ?", false). Delete(new(EmailAddress)) return err @@ -1381,12 +1381,12 @@ func getUserByID(e db.Engine, id int64) (*User, error) { // GetUserByID returns the user object by given ID if exists. func GetUserByID(id int64) (*User, error) { - return getUserByID(db.DefaultContext().Engine(), id) + return getUserByID(db.GetEngine(db.DefaultContext), id) } // GetUserByName returns user by given name. func GetUserByName(name string) (*User, error) { - return getUserByName(db.DefaultContext().Engine(), name) + return getUserByName(db.GetEngine(db.DefaultContext), name) } func getUserByName(e db.Engine, name string) (*User, error) { @@ -1406,7 +1406,7 @@ func getUserByName(e db.Engine, name string) (*User, error) { // GetUserEmailsByNames returns a list of e-mails corresponds to names of users // that have their email notifications set to enabled or onmention. func GetUserEmailsByNames(names []string) []string { - return getUserEmailsByNames(db.DefaultContext().Engine(), names) + return getUserEmailsByNames(db.GetEngine(db.DefaultContext), names) } func getUserEmailsByNames(e db.Engine, names []string) []string { @@ -1431,7 +1431,7 @@ func GetMaileableUsersByIDs(ids []int64, isMention bool) ([]*User, error) { ous := make([]*User, 0, len(ids)) if isMention { - return ous, db.DefaultContext().Engine().In("id", ids). + return ous, db.GetEngine(db.DefaultContext).In("id", ids). Where("`type` = ?", UserTypeIndividual). And("`prohibit_login` = ?", false). And("`is_active` = ?", true). @@ -1439,7 +1439,7 @@ func GetMaileableUsersByIDs(ids []int64, isMention bool) ([]*User, error) { Find(&ous) } - return ous, db.DefaultContext().Engine().In("id", ids). + return ous, db.GetEngine(db.DefaultContext).In("id", ids). Where("`type` = ?", UserTypeIndividual). And("`prohibit_login` = ?", false). And("`is_active` = ?", true). @@ -1450,7 +1450,7 @@ func GetMaileableUsersByIDs(ids []int64, isMention bool) ([]*User, error) { // GetUserNamesByIDs returns usernames for all resolved users from a list of Ids. func GetUserNamesByIDs(ids []int64) ([]string, error) { unames := make([]string, 0, len(ids)) - err := db.DefaultContext().Engine().In("id", ids). + err := db.GetEngine(db.DefaultContext).In("id", ids). Table("user"). Asc("name"). Cols("name"). @@ -1464,7 +1464,7 @@ func GetUsersByIDs(ids []int64) (UserList, error) { if len(ids) == 0 { return ous, nil } - err := db.DefaultContext().Engine().In("id", ids). + err := db.GetEngine(db.DefaultContext).In("id", ids). Asc("name"). Find(&ous) return ous, err @@ -1490,7 +1490,7 @@ func GetUserIDsByNames(names []string, ignoreNonExistent bool) ([]int64, error) // GetUsersBySource returns a list of Users for a login source func GetUsersBySource(s *LoginSource) ([]*User, error) { var users []*User - err := db.DefaultContext().Engine().Where("login_type = ? AND login_source = ?", s.Type, s.ID).Find(&users) + err := db.GetEngine(db.DefaultContext).Where("login_type = ? AND login_source = ?", s.Type, s.ID).Find(&users) return users, err } @@ -1539,11 +1539,11 @@ func ValidateCommitsWithEmails(oldCommits []*git.Commit) []*UserCommit { // GetUserByEmail returns the user object by given e-mail if exists. func GetUserByEmail(email string) (*User, error) { - return GetUserByEmailContext(db.DefaultContext(), email) + return GetUserByEmailContext(db.DefaultContext, email) } // GetUserByEmailContext returns the user object by given e-mail if exists with db context -func GetUserByEmailContext(ctx *db.Context, email string) (*User, error) { +func GetUserByEmailContext(ctx context.Context, email string) (*User, error) { if len(email) == 0 { return nil, ErrUserNotExist{0, email, 0} } @@ -1551,7 +1551,7 @@ func GetUserByEmailContext(ctx *db.Context, email string) (*User, error) { email = strings.ToLower(email) // First try to find the user by primary email user := &User{Email: email} - has, err := ctx.Engine().Get(user) + has, err := db.GetEngine(ctx).Get(user) if err != nil { return nil, err } @@ -1561,19 +1561,19 @@ func GetUserByEmailContext(ctx *db.Context, email string) (*User, error) { // Otherwise, check in alternative list for activated email addresses emailAddress := &EmailAddress{Email: email, IsActivated: true} - has, err = ctx.Engine().Get(emailAddress) + has, err = db.GetEngine(ctx).Get(emailAddress) if err != nil { return nil, err } if has { - return getUserByID(ctx.Engine(), emailAddress.UID) + return getUserByID(db.GetEngine(ctx), emailAddress.UID) } // Finally, if email address is the protected email address: if strings.HasSuffix(email, fmt.Sprintf("@%s", setting.Service.NoReplyAddress)) { username := strings.TrimSuffix(email, fmt.Sprintf("@%s", setting.Service.NoReplyAddress)) user := &User{} - has, err := ctx.Engine().Where("lower_name=?", username).Get(user) + has, err := db.GetEngine(ctx).Where("lower_name=?", username).Get(user) if err != nil { return nil, err } @@ -1587,7 +1587,7 @@ func GetUserByEmailContext(ctx *db.Context, email string) (*User, error) { // GetUser checks if a user already exists func GetUser(user *User) (bool, error) { - return db.DefaultContext().Engine().Get(user) + return db.GetEngine(db.DefaultContext).Get(user) } // SearchUserOptions contains the options for searching @@ -1664,7 +1664,7 @@ func (opts *SearchUserOptions) toConds() builder.Cond { // it returns results in given range and number of total results. func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { cond := opts.toConds() - count, err := db.DefaultContext().Engine().Where(cond).Count(new(User)) + count, err := db.GetEngine(db.DefaultContext).Where(cond).Count(new(User)) if err != nil { return nil, 0, fmt.Errorf("Count: %v", err) } @@ -1673,7 +1673,7 @@ func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { opts.OrderBy = SearchOrderByAlphabetically } - sess := db.DefaultContext().Engine().Where(cond).OrderBy(opts.OrderBy.String()) + sess := db.GetEngine(db.DefaultContext).Where(cond).OrderBy(opts.OrderBy.String()) if opts.Page != 0 { sess = setSessionPagination(sess, opts) } @@ -1684,7 +1684,7 @@ func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { // GetStarredRepos returns the repos starred by a particular user func GetStarredRepos(userID int64, private bool, listOptions ListOptions) ([]*Repository, error) { - sess := db.DefaultContext().Engine().Where("star.uid=?", userID). + sess := db.GetEngine(db.DefaultContext).Where("star.uid=?", userID). Join("LEFT", "star", "`repository`.id=`star`.repo_id") if !private { sess = sess.And("is_private=?", false) @@ -1703,7 +1703,7 @@ func GetStarredRepos(userID int64, private bool, listOptions ListOptions) ([]*Re // GetWatchedRepos returns the repos watched by a particular user func GetWatchedRepos(userID int64, private bool, listOptions ListOptions) ([]*Repository, int64, error) { - sess := db.DefaultContext().Engine().Where("watch.user_id=?", userID). + sess := db.GetEngine(db.DefaultContext).Where("watch.user_id=?", userID). And("`watch`.mode<>?", RepoWatchModeDont). Join("LEFT", "watch", "`repository`.id=`watch`.repo_id") if !private { @@ -1729,7 +1729,7 @@ func IterateUser(f func(user *User) error) error { batchSize := setting.Database.IterateBufferSize for { users := make([]*User, 0, batchSize) - if err := db.DefaultContext().Engine().Limit(batchSize, start).Find(&users); err != nil { + if err := db.GetEngine(db.DefaultContext).Limit(batchSize, start).Find(&users); err != nil { return err } if len(users) == 0 { |