diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/activities/statistic.go | 10 | ||||
-rw-r--r-- | models/asymkey/gpg_key.go | 10 | ||||
-rw-r--r-- | models/asymkey/gpg_key_add.go | 2 | ||||
-rw-r--r-- | models/asymkey/gpg_key_commit_verification.go | 21 | ||||
-rw-r--r-- | models/asymkey/ssh_key_commit_verification.go | 5 | ||||
-rw-r--r-- | models/asymkey/ssh_key_principals.go | 5 | ||||
-rw-r--r-- | models/org_team.go | 44 | ||||
-rw-r--r-- | models/org_team_test.go | 22 | ||||
-rw-r--r-- | models/organization/org.go | 14 | ||||
-rw-r--r-- | models/organization/org_test.go | 2 | ||||
-rw-r--r-- | models/organization/org_user_test.go | 4 | ||||
-rw-r--r-- | models/repo/fork.go | 12 | ||||
-rw-r--r-- | models/repo/git.go | 6 | ||||
-rw-r--r-- | models/user/email_address.go | 52 | ||||
-rw-r--r-- | models/user/email_address_test.go | 42 | ||||
-rw-r--r-- | models/user/list.go | 10 | ||||
-rw-r--r-- | models/user/search.go | 11 | ||||
-rw-r--r-- | models/user/user.go | 62 | ||||
-rw-r--r-- | models/user/user_test.go | 24 |
19 files changed, 181 insertions, 177 deletions
diff --git a/models/activities/statistic.go b/models/activities/statistic.go index 9d379cd0c4..5479db20eb 100644 --- a/models/activities/statistic.go +++ b/models/activities/statistic.go @@ -4,6 +4,8 @@ package activities import ( + "context" + asymkey_model "code.gitea.io/gitea/models/asymkey" "code.gitea.io/gitea/models/auth" "code.gitea.io/gitea/models/db" @@ -47,12 +49,12 @@ type IssueByRepositoryCount struct { } // GetStatistic returns the database statistics -func GetStatistic() (stats Statistic) { - e := db.GetEngine(db.DefaultContext) - stats.Counter.User = user_model.CountUsers(nil) +func GetStatistic(ctx context.Context) (stats Statistic) { + e := db.GetEngine(ctx) + stats.Counter.User = user_model.CountUsers(ctx, nil) stats.Counter.Org, _ = organization.CountOrgs(organization.FindOrgOptions{IncludePrivate: true}) stats.Counter.PublicKey, _ = e.Count(new(asymkey_model.PublicKey)) - stats.Counter.Repo, _ = repo_model.CountRepositories(db.DefaultContext, repo_model.CountRepositoryOptions{}) + stats.Counter.Repo, _ = repo_model.CountRepositories(ctx, repo_model.CountRepositoryOptions{}) stats.Counter.Watch, _ = e.Count(new(repo_model.Watch)) stats.Counter.Star, _ = e.Count(new(repo_model.Star)) stats.Counter.Access, _ = e.Count(new(access_model.Access)) diff --git a/models/asymkey/gpg_key.go b/models/asymkey/gpg_key.go index be019184eb..e5e5fdb2f3 100644 --- a/models/asymkey/gpg_key.go +++ b/models/asymkey/gpg_key.go @@ -144,7 +144,7 @@ func parseSubGPGKey(ownerID int64, primaryID string, pubkey *packet.PublicKey, e } // parseGPGKey parse a PrimaryKey entity (primary key + subs keys + self-signature) -func parseGPGKey(ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, error) { +func parseGPGKey(ctx context.Context, ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, error) { pubkey := e.PrimaryKey expiry := getExpiryTime(e) @@ -159,7 +159,7 @@ func parseGPGKey(ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, erro } // Check emails - userEmails, err := user_model.GetEmailAddresses(ownerID) + userEmails, err := user_model.GetEmailAddresses(ctx, ownerID) if err != nil { return nil, err } @@ -251,7 +251,7 @@ func DeleteGPGKey(doer *user_model.User, id int64) (err error) { return committer.Commit() } -func checkKeyEmails(email string, keys ...*GPGKey) (bool, string) { +func checkKeyEmails(ctx context.Context, email string, keys ...*GPGKey) (bool, string) { uid := int64(0) var userEmails []*user_model.EmailAddress var user *user_model.User @@ -263,10 +263,10 @@ func checkKeyEmails(email string, keys ...*GPGKey) (bool, string) { } if key.Verified && key.OwnerID != 0 { if uid != key.OwnerID { - userEmails, _ = user_model.GetEmailAddresses(key.OwnerID) + userEmails, _ = user_model.GetEmailAddresses(ctx, key.OwnerID) uid = key.OwnerID user = &user_model.User{ID: uid} - _, _ = user_model.GetUser(user) + _, _ = user_model.GetUser(ctx, user) } for _, e := range userEmails { if e.IsActivated && (email == "" || strings.EqualFold(e.Email, email)) { diff --git a/models/asymkey/gpg_key_add.go b/models/asymkey/gpg_key_add.go index eb4027b3a4..6926fd2143 100644 --- a/models/asymkey/gpg_key_add.go +++ b/models/asymkey/gpg_key_add.go @@ -153,7 +153,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro // Get DB session - key, err := parseGPGKey(ownerID, ekey, verified) + key, err := parseGPGKey(ctx, ownerID, ekey, verified) if err != nil { return nil, err } diff --git a/models/asymkey/gpg_key_commit_verification.go b/models/asymkey/gpg_key_commit_verification.go index 65af0bc945..bf0fdd9a9a 100644 --- a/models/asymkey/gpg_key_commit_verification.go +++ b/models/asymkey/gpg_key_commit_verification.go @@ -125,7 +125,7 @@ func ParseCommitWithSignature(ctx context.Context, c *git.Commit) *CommitVerific // If this a SSH signature handle it differently if strings.HasPrefix(c.Signature.Signature, "-----BEGIN SSH SIGNATURE-----") { - return ParseCommitWithSSHSignature(c, committer) + return ParseCommitWithSSHSignature(ctx, c, committer) } // Parsing signature @@ -150,6 +150,7 @@ func ParseCommitWithSignature(ctx context.Context, c *git.Commit) *CommitVerific // First check if the sig has a keyID and if so just look at that if commitVerification := hashAndVerifyForKeyID( + ctx, sig, c.Signature.Payload, committer, @@ -165,7 +166,7 @@ func ParseCommitWithSignature(ctx context.Context, c *git.Commit) *CommitVerific // Now try to associate the signature with the committer, if present if committer.ID != 0 { - keys, err := ListGPGKeys(db.DefaultContext, committer.ID, db.ListOptions{}) + keys, err := ListGPGKeys(ctx, committer.ID, db.ListOptions{}) if err != nil { // Skipping failed to get gpg keys of user log.Error("ListGPGKeys: %v", err) return &CommitVerification{ @@ -175,7 +176,7 @@ func ParseCommitWithSignature(ctx context.Context, c *git.Commit) *CommitVerific } } - committerEmailAddresses, _ := user_model.GetEmailAddresses(committer.ID) + committerEmailAddresses, _ := user_model.GetEmailAddresses(ctx, committer.ID) activated := false for _, e := range committerEmailAddresses { if e.IsActivated && strings.EqualFold(e.Email, c.Committer.Email) { @@ -222,7 +223,7 @@ func ParseCommitWithSignature(ctx context.Context, c *git.Commit) *CommitVerific } if err := gpgSettings.LoadPublicKeyContent(); err != nil { log.Error("Error getting default signing key: %s %v", gpgSettings.KeyID, err) - } else if commitVerification := verifyWithGPGSettings(&gpgSettings, sig, c.Signature.Payload, committer, keyID); commitVerification != nil { + } else if commitVerification := verifyWithGPGSettings(ctx, &gpgSettings, sig, c.Signature.Payload, committer, keyID); commitVerification != nil { if commitVerification.Reason == BadSignature { defaultReason = BadSignature } else { @@ -237,7 +238,7 @@ func ParseCommitWithSignature(ctx context.Context, c *git.Commit) *CommitVerific } else if defaultGPGSettings == nil { log.Warn("Unable to get defaultGPGSettings for unattached commit: %s", c.ID.String()) } else if defaultGPGSettings.Sign { - if commitVerification := verifyWithGPGSettings(defaultGPGSettings, sig, c.Signature.Payload, committer, keyID); commitVerification != nil { + if commitVerification := verifyWithGPGSettings(ctx, defaultGPGSettings, sig, c.Signature.Payload, committer, keyID); commitVerification != nil { if commitVerification.Reason == BadSignature { defaultReason = BadSignature } else { @@ -257,9 +258,9 @@ func ParseCommitWithSignature(ctx context.Context, c *git.Commit) *CommitVerific } } -func verifyWithGPGSettings(gpgSettings *git.GPGSettings, sig *packet.Signature, payload string, committer *user_model.User, keyID string) *CommitVerification { +func verifyWithGPGSettings(ctx context.Context, gpgSettings *git.GPGSettings, sig *packet.Signature, payload string, committer *user_model.User, keyID string) *CommitVerification { // First try to find the key in the db - if commitVerification := hashAndVerifyForKeyID(sig, payload, committer, gpgSettings.KeyID, gpgSettings.Name, gpgSettings.Email); commitVerification != nil { + if commitVerification := hashAndVerifyForKeyID(ctx, sig, payload, committer, gpgSettings.KeyID, gpgSettings.Name, gpgSettings.Email); commitVerification != nil { return commitVerification } @@ -387,7 +388,7 @@ func hashAndVerifyWithSubKeysCommitVerification(sig *packet.Signature, payload s return nil } -func hashAndVerifyForKeyID(sig *packet.Signature, payload string, committer *user_model.User, keyID, name, email string) *CommitVerification { +func hashAndVerifyForKeyID(ctx context.Context, sig *packet.Signature, payload string, committer *user_model.User, keyID, name, email string) *CommitVerification { if keyID == "" { return nil } @@ -417,7 +418,7 @@ func hashAndVerifyForKeyID(sig *packet.Signature, payload string, committer *use } } - activated, email := checkKeyEmails(email, append([]*GPGKey{key}, primaryKeys...)...) + activated, email := checkKeyEmails(ctx, email, append([]*GPGKey{key}, primaryKeys...)...) if !activated { continue } @@ -427,7 +428,7 @@ func hashAndVerifyForKeyID(sig *packet.Signature, payload string, committer *use Email: email, } if key.OwnerID != 0 { - owner, err := user_model.GetUserByID(db.DefaultContext, key.OwnerID) + owner, err := user_model.GetUserByID(ctx, key.OwnerID) if err == nil { signer = owner } else if !user_model.IsErrUserNotExist(err) { diff --git a/models/asymkey/ssh_key_commit_verification.go b/models/asymkey/ssh_key_commit_verification.go index af73637c4a..80931c9af4 100644 --- a/models/asymkey/ssh_key_commit_verification.go +++ b/models/asymkey/ssh_key_commit_verification.go @@ -5,6 +5,7 @@ package asymkey import ( "bytes" + "context" "fmt" "strings" @@ -17,7 +18,7 @@ import ( ) // ParseCommitWithSSHSignature check if signature is good against keystore. -func ParseCommitWithSSHSignature(c *git.Commit, committer *user_model.User) *CommitVerification { +func ParseCommitWithSSHSignature(ctx context.Context, c *git.Commit, committer *user_model.User) *CommitVerification { // Now try to associate the signature with the committer, if present if committer.ID != 0 { keys, err := ListPublicKeys(committer.ID, db.ListOptions{}) @@ -30,7 +31,7 @@ func ParseCommitWithSSHSignature(c *git.Commit, committer *user_model.User) *Com } } - committerEmailAddresses, err := user_model.GetEmailAddresses(committer.ID) + committerEmailAddresses, err := user_model.GetEmailAddresses(ctx, committer.ID) if err != nil { log.Error("GetEmailAddresses: %v", err) } diff --git a/models/asymkey/ssh_key_principals.go b/models/asymkey/ssh_key_principals.go index 6d43437ec1..150b77c7b2 100644 --- a/models/asymkey/ssh_key_principals.go +++ b/models/asymkey/ssh_key_principals.go @@ -4,6 +4,7 @@ package asymkey import ( + "context" "fmt" "strings" @@ -63,7 +64,7 @@ func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*Public } // CheckPrincipalKeyString strips spaces and returns an error if the given principal contains newlines -func CheckPrincipalKeyString(user *user_model.User, content string) (_ string, err error) { +func CheckPrincipalKeyString(ctx context.Context, user *user_model.User, content string) (_ string, err error) { if setting.SSH.Disabled { return "", db.ErrSSHDisabled{} } @@ -80,7 +81,7 @@ func CheckPrincipalKeyString(user *user_model.User, content string) (_ string, e case "anything": return content, nil case "email": - emails, err := user_model.GetEmailAddresses(user.ID) + emails, err := user_model.GetEmailAddresses(ctx, user.ID) if err != nil { return "", err } diff --git a/models/org_team.go b/models/org_team.go index cf3680990d..7ddf986ce9 100644 --- a/models/org_team.go +++ b/models/org_team.go @@ -73,8 +73,8 @@ func addAllRepositories(ctx context.Context, t *organization.Team) error { } // AddAllRepositories adds all repositories to the team -func AddAllRepositories(t *organization.Team) (err error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func AddAllRepositories(ctx context.Context, t *organization.Team) (err error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -88,12 +88,12 @@ func AddAllRepositories(t *organization.Team) (err error) { } // RemoveAllRepositories removes all repositories from team and recalculates access -func RemoveAllRepositories(t *organization.Team) (err error) { +func RemoveAllRepositories(ctx context.Context, t *organization.Team) (err error) { if t.IncludesAllRepositories { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -153,7 +153,7 @@ func removeAllRepositories(ctx context.Context, t *organization.Team) (err error // NewTeam creates a record of new team. // It's caller's responsibility to assign organization ID. -func NewTeam(t *organization.Team) (err error) { +func NewTeam(ctx context.Context, t *organization.Team) (err error) { if len(t.Name) == 0 { return util.NewInvalidArgumentErrorf("empty team name") } @@ -162,7 +162,7 @@ func NewTeam(t *organization.Team) (err error) { return err } - has, err := db.GetEngine(db.DefaultContext).ID(t.OrgID).Get(new(user_model.User)) + has, err := db.GetEngine(ctx).ID(t.OrgID).Get(new(user_model.User)) if err != nil { return err } @@ -171,7 +171,7 @@ func NewTeam(t *organization.Team) (err error) { } t.LowerName = strings.ToLower(t.Name) - has, err = db.GetEngine(db.DefaultContext). + has, err = db.GetEngine(ctx). Where("org_id=?", t.OrgID). And("lower_name=?", t.LowerName). Get(new(organization.Team)) @@ -182,7 +182,7 @@ func NewTeam(t *organization.Team) (err error) { return organization.ErrTeamAlreadyExist{OrgID: t.OrgID, Name: t.LowerName} } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -218,7 +218,7 @@ func NewTeam(t *organization.Team) (err error) { } // UpdateTeam updates information of team. -func UpdateTeam(t *organization.Team, authChanged, includeAllChanged bool) (err error) { +func UpdateTeam(ctx context.Context, t *organization.Team, authChanged, includeAllChanged bool) (err error) { if len(t.Name) == 0 { return util.NewInvalidArgumentErrorf("empty team name") } @@ -227,7 +227,7 @@ func UpdateTeam(t *organization.Team, authChanged, includeAllChanged bool) (err t.Description = t.Description[:255] } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -293,8 +293,8 @@ func UpdateTeam(t *organization.Team, authChanged, includeAllChanged bool) (err // DeleteTeam deletes given team. // It's caller's responsibility to assign organization ID. -func DeleteTeam(t *organization.Team) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func DeleteTeam(ctx context.Context, t *organization.Team) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -356,8 +356,8 @@ func DeleteTeam(t *organization.Team) error { // AddTeamMember adds new membership of given team to given organization, // the user will have membership to given organization automatically when needed. -func AddTeamMember(team *organization.Team, userID int64) error { - isAlreadyMember, err := organization.IsTeamMember(db.DefaultContext, team.OrgID, team.ID, userID) +func AddTeamMember(ctx context.Context, team *organization.Team, userID int64) error { + isAlreadyMember, err := organization.IsTeamMember(ctx, team.OrgID, team.ID, userID) if err != nil || isAlreadyMember { return err } @@ -366,7 +366,7 @@ func AddTeamMember(team *organization.Team, userID int64) error { return err } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -423,18 +423,14 @@ func AddTeamMember(team *organization.Team, userID int64) error { } } - if err := committer.Commit(); err != nil { - return err - } - committer.Close() - // this behaviour may spend much time so run it in a goroutine // FIXME: Update watch repos batchly if setting.Service.AutoWatchNewRepos { // Get team and its repositories. - if err := team.LoadRepositories(db.DefaultContext); err != nil { + if err := team.LoadRepositories(ctx); err != nil { log.Error("getRepositories failed: %v", err) } + // FIXME: in the goroutine, it can't access the "ctx", it could only use db.DefaultContext at the moment go func(repos []*repo_model.Repository) { for _, repo := range repos { if err = repo_model.WatchRepo(db.DefaultContext, userID, repo.ID, true); err != nil { @@ -444,7 +440,7 @@ func AddTeamMember(team *organization.Team, userID int64) error { }(team.Repos) } - return nil + return committer.Commit() } func removeTeamMember(ctx context.Context, team *organization.Team, userID int64) error { @@ -512,8 +508,8 @@ func removeInvalidOrgUser(ctx context.Context, userID, orgID int64) error { } // RemoveTeamMember removes member from given team of given organization. -func RemoveTeamMember(team *organization.Team, userID int64) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func RemoveTeamMember(ctx context.Context, team *organization.Team, userID int64) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/org_team_test.go b/models/org_team_test.go index 4978f8ef99..e4b7b917e8 100644 --- a/models/org_team_test.go +++ b/models/org_team_test.go @@ -23,7 +23,7 @@ func TestTeam_AddMember(t *testing.T) { test := func(teamID, userID int64) { team := unittest.AssertExistsAndLoadBean(t, &organization.Team{ID: teamID}) - assert.NoError(t, AddTeamMember(team, userID)) + assert.NoError(t, AddTeamMember(db.DefaultContext, team, userID)) unittest.AssertExistsAndLoadBean(t, &organization.TeamUser{UID: userID, TeamID: teamID}) unittest.CheckConsistencyFor(t, &organization.Team{ID: teamID}, &user_model.User{ID: team.OrgID}) } @@ -37,7 +37,7 @@ func TestTeam_RemoveMember(t *testing.T) { testSuccess := func(teamID, userID int64) { team := unittest.AssertExistsAndLoadBean(t, &organization.Team{ID: teamID}) - assert.NoError(t, RemoveTeamMember(team, userID)) + assert.NoError(t, RemoveTeamMember(db.DefaultContext, team, userID)) unittest.AssertNotExistsBean(t, &organization.TeamUser{UID: userID, TeamID: teamID}) unittest.CheckConsistencyFor(t, &organization.Team{ID: teamID}) } @@ -47,7 +47,7 @@ func TestTeam_RemoveMember(t *testing.T) { testSuccess(3, unittest.NonexistentID) team := unittest.AssertExistsAndLoadBean(t, &organization.Team{ID: 1}) - err := RemoveTeamMember(team, 2) + err := RemoveTeamMember(db.DefaultContext, team, 2) assert.True(t, organization.IsErrLastOrgOwner(err)) } @@ -61,7 +61,7 @@ func TestNewTeam(t *testing.T) { const teamName = "newTeamName" team := &organization.Team{Name: teamName, OrgID: 3} - assert.NoError(t, NewTeam(team)) + assert.NoError(t, NewTeam(db.DefaultContext, team)) unittest.AssertExistsAndLoadBean(t, &organization.Team{Name: teamName}) unittest.CheckConsistencyFor(t, &organization.Team{}, &user_model.User{ID: team.OrgID}) } @@ -75,7 +75,7 @@ func TestUpdateTeam(t *testing.T) { team.Name = "newName" team.Description = strings.Repeat("A long description!", 100) team.AccessMode = perm.AccessModeAdmin - assert.NoError(t, UpdateTeam(team, true, false)) + assert.NoError(t, UpdateTeam(db.DefaultContext, team, true, false)) team = unittest.AssertExistsAndLoadBean(t, &organization.Team{Name: "newName"}) assert.True(t, strings.HasPrefix(team.Description, "A long description!")) @@ -94,7 +94,7 @@ func TestUpdateTeam2(t *testing.T) { team.LowerName = "owners" team.Name = "Owners" team.Description = strings.Repeat("A long description!", 100) - err := UpdateTeam(team, true, false) + err := UpdateTeam(db.DefaultContext, team, true, false) assert.True(t, organization.IsErrTeamAlreadyExist(err)) unittest.CheckConsistencyFor(t, &organization.Team{ID: team.ID}) @@ -104,7 +104,7 @@ func TestDeleteTeam(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) team := unittest.AssertExistsAndLoadBean(t, &organization.Team{ID: 2}) - assert.NoError(t, DeleteTeam(team)) + assert.NoError(t, DeleteTeam(db.DefaultContext, team)) unittest.AssertNotExistsBean(t, &organization.Team{ID: team.ID}) unittest.AssertNotExistsBean(t, &organization.TeamRepo{TeamID: team.ID}) unittest.AssertNotExistsBean(t, &organization.TeamUser{TeamID: team.ID}) @@ -122,7 +122,7 @@ func TestAddTeamMember(t *testing.T) { test := func(teamID, userID int64) { team := unittest.AssertExistsAndLoadBean(t, &organization.Team{ID: teamID}) - assert.NoError(t, AddTeamMember(team, userID)) + assert.NoError(t, AddTeamMember(db.DefaultContext, team, userID)) unittest.AssertExistsAndLoadBean(t, &organization.TeamUser{UID: userID, TeamID: teamID}) unittest.CheckConsistencyFor(t, &organization.Team{ID: teamID}, &user_model.User{ID: team.OrgID}) } @@ -136,7 +136,7 @@ func TestRemoveTeamMember(t *testing.T) { testSuccess := func(teamID, userID int64) { team := unittest.AssertExistsAndLoadBean(t, &organization.Team{ID: teamID}) - assert.NoError(t, RemoveTeamMember(team, userID)) + assert.NoError(t, RemoveTeamMember(db.DefaultContext, team, userID)) unittest.AssertNotExistsBean(t, &organization.TeamUser{UID: userID, TeamID: teamID}) unittest.CheckConsistencyFor(t, &organization.Team{ID: teamID}) } @@ -146,7 +146,7 @@ func TestRemoveTeamMember(t *testing.T) { testSuccess(3, unittest.NonexistentID) team := unittest.AssertExistsAndLoadBean(t, &organization.Team{ID: 1}) - err := RemoveTeamMember(team, 2) + err := RemoveTeamMember(db.DefaultContext, team, 2) assert.True(t, organization.IsErrLastOrgOwner(err)) } @@ -161,7 +161,7 @@ func TestRepository_RecalculateAccesses3(t *testing.T) { // adding user29 to team5 should add an explicit access row for repo 23 // even though repo 23 is public - assert.NoError(t, AddTeamMember(team5, user29.ID)) + assert.NoError(t, AddTeamMember(db.DefaultContext, team5, user29.ID)) has, err = db.GetEngine(db.DefaultContext).Get(&access_model.Access{UserID: 29, RepoID: 23}) assert.NoError(t, err) diff --git a/models/organization/org.go b/models/organization/org.go index 8fd4ad076b..260571b4b2 100644 --- a/models/organization/org.go +++ b/models/organization/org.go @@ -140,8 +140,8 @@ func (org *Organization) LoadTeams() ([]*Team, error) { } // GetMembers returns all members of organization. -func (org *Organization) GetMembers() (user_model.UserList, map[int64]bool, error) { - return FindOrgMembers(&FindOrgMembersOpts{ +func (org *Organization) GetMembers(ctx context.Context) (user_model.UserList, map[int64]bool, error) { + return FindOrgMembers(ctx, &FindOrgMembersOpts{ OrgID: org.ID, }) } @@ -208,8 +208,8 @@ func CountOrgMembers(opts *FindOrgMembersOpts) (int64, error) { } // FindOrgMembers loads organization members according conditions -func FindOrgMembers(opts *FindOrgMembersOpts) (user_model.UserList, map[int64]bool, error) { - ous, err := GetOrgUsersByOrgID(db.DefaultContext, opts) +func FindOrgMembers(ctx context.Context, opts *FindOrgMembersOpts) (user_model.UserList, map[int64]bool, error) { + ous, err := GetOrgUsersByOrgID(ctx, opts) if err != nil { return nil, nil, err } @@ -221,7 +221,7 @@ func FindOrgMembers(opts *FindOrgMembersOpts) (user_model.UserList, map[int64]bo idsIsPublic[ou.UID] = ou.IsPublic } - users, err := user_model.GetUsersByIDs(ids) + users, err := user_model.GetUsersByIDs(ctx, ids) if err != nil { return nil, nil, err } @@ -520,10 +520,10 @@ func HasOrgsVisible(orgs []*Organization, user *user_model.User) bool { // GetOrgsCanCreateRepoByUserID returns a list of organizations where given user ID // are allowed to create repos. -func GetOrgsCanCreateRepoByUserID(userID int64) ([]*Organization, error) { +func GetOrgsCanCreateRepoByUserID(ctx context.Context, userID int64) ([]*Organization, error) { orgs := make([]*Organization, 0, 10) - return orgs, db.GetEngine(db.DefaultContext).Where(builder.In("id", builder.Select("`user`.id").From("`user`"). + return orgs, db.GetEngine(ctx).Where(builder.In("id", builder.Select("`user`.id").From("`user`"). Join("INNER", "`team_user`", "`team_user`.org_id = `user`.id"). Join("INNER", "`team`", "`team`.id = `team_user`.team_id"). Where(builder.Eq{"`team_user`.uid": userID}). diff --git a/models/organization/org_test.go b/models/organization/org_test.go index d36736b5c2..b5dff9ec01 100644 --- a/models/organization/org_test.go +++ b/models/organization/org_test.go @@ -103,7 +103,7 @@ func TestUser_GetTeams(t *testing.T) { func TestUser_GetMembers(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) org := unittest.AssertExistsAndLoadBean(t, &organization.Organization{ID: 3}) - members, _, err := org.GetMembers() + members, _, err := org.GetMembers(db.DefaultContext) assert.NoError(t, err) if assert.Len(t, members, 3) { assert.Equal(t, int64(2), members[0].ID) diff --git a/models/organization/org_user_test.go b/models/organization/org_user_test.go index b6477f859c..1c3cf2798d 100644 --- a/models/organization/org_user_test.go +++ b/models/organization/org_user_test.go @@ -94,7 +94,7 @@ func TestUserListIsPublicMember(t *testing.T) { func testUserListIsPublicMember(t *testing.T, orgID int64, expected map[int64]bool) { org, err := organization.GetOrgByID(db.DefaultContext, orgID) assert.NoError(t, err) - _, membersIsPublic, err := org.GetMembers() + _, membersIsPublic, err := org.GetMembers(db.DefaultContext) assert.NoError(t, err) assert.Equal(t, expected, membersIsPublic) } @@ -121,7 +121,7 @@ func TestUserListIsUserOrgOwner(t *testing.T) { func testUserListIsUserOrgOwner(t *testing.T, orgID int64, expected map[int64]bool) { org, err := organization.GetOrgByID(db.DefaultContext, orgID) assert.NoError(t, err) - members, _, err := org.GetMembers() + members, _, err := org.GetMembers(db.DefaultContext) assert.NoError(t, err) assert.Equal(t, expected, organization.IsUserOrgOwner(members, orgID)) } diff --git a/models/repo/fork.go b/models/repo/fork.go index eafbab0fb1..6be6ebc3f5 100644 --- a/models/repo/fork.go +++ b/models/repo/fork.go @@ -21,9 +21,9 @@ func GetRepositoriesByForkID(ctx context.Context, forkID int64) ([]*Repository, } // GetForkedRepo checks if given user has already forked a repository with given ID. -func GetForkedRepo(ownerID, repoID int64) *Repository { +func GetForkedRepo(ctx context.Context, ownerID, repoID int64) *Repository { repo := new(Repository) - has, _ := db.GetEngine(db.DefaultContext). + has, _ := db.GetEngine(ctx). Where("owner_id=? AND fork_id=?", ownerID, repoID). Get(repo) if has { @@ -33,8 +33,8 @@ func GetForkedRepo(ownerID, repoID int64) *Repository { } // HasForkedRepo checks if given user has already forked a repository with given ID. -func HasForkedRepo(ownerID, repoID int64) bool { - has, _ := db.GetEngine(db.DefaultContext). +func HasForkedRepo(ctx context.Context, ownerID, repoID int64) bool { + has, _ := db.GetEngine(ctx). Table("repository"). Where("owner_id=? AND fork_id=?", ownerID, repoID). Exist() @@ -55,10 +55,10 @@ func GetUserFork(ctx context.Context, repoID, userID int64) (*Repository, error) } // GetForks returns all the forks of the repository -func GetForks(repo *Repository, listOptions db.ListOptions) ([]*Repository, error) { +func GetForks(ctx context.Context, repo *Repository, listOptions db.ListOptions) ([]*Repository, error) { if listOptions.Page == 0 { forks := make([]*Repository, 0, repo.NumForks) - return forks, db.GetEngine(db.DefaultContext).Find(&forks, &Repository{ForkID: repo.ID}) + return forks, db.GetEngine(ctx).Find(&forks, &Repository{ForkID: repo.ID}) } sess := db.GetPaginatedSession(&listOptions) diff --git a/models/repo/git.go b/models/repo/git.go index 2f71128b5a..610c554296 100644 --- a/models/repo/git.go +++ b/models/repo/git.go @@ -4,6 +4,8 @@ package repo import ( + "context" + "code.gitea.io/gitea/models/db" ) @@ -26,7 +28,7 @@ const ( ) // UpdateDefaultBranch updates the default branch -func UpdateDefaultBranch(repo *Repository) error { - _, err := db.GetEngine(db.DefaultContext).ID(repo.ID).Cols("default_branch").Update(repo) +func UpdateDefaultBranch(ctx context.Context, repo *Repository) error { + _, err := db.GetEngine(ctx).ID(repo.ID).Cols("default_branch").Update(repo) return err } diff --git a/models/user/email_address.go b/models/user/email_address.go index e916249e30..f1ed6692cf 100644 --- a/models/user/email_address.go +++ b/models/user/email_address.go @@ -178,9 +178,9 @@ func ValidateEmail(email string) error { } // GetEmailAddresses returns all email addresses belongs to given user. -func GetEmailAddresses(uid int64) ([]*EmailAddress, error) { +func GetEmailAddresses(ctx context.Context, uid int64) ([]*EmailAddress, error) { emails := make([]*EmailAddress, 0, 5) - if err := db.GetEngine(db.DefaultContext). + if err := db.GetEngine(ctx). Where("uid=?", uid). Asc("id"). Find(&emails); err != nil { @@ -190,10 +190,10 @@ func GetEmailAddresses(uid int64) ([]*EmailAddress, error) { } // GetEmailAddressByID gets a user's email address by ID -func GetEmailAddressByID(uid, id int64) (*EmailAddress, error) { +func GetEmailAddressByID(ctx context.Context, uid, id int64) (*EmailAddress, error) { // User ID is required for security reasons email := &EmailAddress{UID: uid} - if has, err := db.GetEngine(db.DefaultContext).ID(id).Get(email); err != nil { + if has, err := db.GetEngine(ctx).ID(id).Get(email); err != nil { return nil, err } else if !has { return nil, nil @@ -253,7 +253,7 @@ func AddEmailAddress(ctx context.Context, email *EmailAddress) error { } // AddEmailAddresses adds an email address to given user. -func AddEmailAddresses(emails []*EmailAddress) error { +func AddEmailAddresses(ctx context.Context, emails []*EmailAddress) error { if len(emails) == 0 { return nil } @@ -261,7 +261,7 @@ func AddEmailAddresses(emails []*EmailAddress) error { // Check if any of them has been used for i := range emails { emails[i].Email = strings.TrimSpace(emails[i].Email) - used, err := IsEmailUsed(db.DefaultContext, emails[i].Email) + used, err := IsEmailUsed(ctx, emails[i].Email) if err != nil { return err } else if used { @@ -272,7 +272,7 @@ func AddEmailAddresses(emails []*EmailAddress) error { } } - if err := db.Insert(db.DefaultContext, emails); err != nil { + if err := db.Insert(ctx, emails); err != nil { return fmt.Errorf("Insert: %w", err) } @@ -280,7 +280,7 @@ func AddEmailAddresses(emails []*EmailAddress) error { } // DeleteEmailAddress deletes an email address of given user. -func DeleteEmailAddress(email *EmailAddress) (err error) { +func DeleteEmailAddress(ctx context.Context, email *EmailAddress) (err error) { if email.IsPrimary { return ErrPrimaryEmailCannotDelete{Email: email.Email} } @@ -291,12 +291,12 @@ func DeleteEmailAddress(email *EmailAddress) (err error) { UID: email.UID, } if email.ID > 0 { - deleted, err = db.GetEngine(db.DefaultContext).ID(email.ID).Delete(&address) + deleted, err = db.GetEngine(ctx).ID(email.ID).Delete(&address) } else { if email.Email != "" && email.LowerEmail == "" { email.LowerEmail = strings.ToLower(email.Email) } - deleted, err = db.GetEngine(db.DefaultContext). + deleted, err = db.GetEngine(ctx). Where("lower_email=?", email.LowerEmail). Delete(&address) } @@ -310,9 +310,9 @@ func DeleteEmailAddress(email *EmailAddress) (err error) { } // DeleteEmailAddresses deletes multiple email addresses -func DeleteEmailAddresses(emails []*EmailAddress) (err error) { +func DeleteEmailAddresses(ctx context.Context, emails []*EmailAddress) (err error) { for i := range emails { - if err = DeleteEmailAddress(emails[i]); err != nil { + if err = DeleteEmailAddress(ctx, emails[i]); err != nil { return err } } @@ -329,8 +329,8 @@ func DeleteInactiveEmailAddresses(ctx context.Context) error { } // ActivateEmail activates the email address to given user. -func ActivateEmail(email *EmailAddress) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func ActivateEmail(ctx context.Context, email *EmailAddress) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -357,8 +357,8 @@ func updateActivation(ctx context.Context, email *EmailAddress, activate bool) e } // MakeEmailPrimary sets primary email address of given user. -func MakeEmailPrimary(email *EmailAddress) error { - has, err := db.GetEngine(db.DefaultContext).Get(email) +func MakeEmailPrimary(ctx context.Context, email *EmailAddress) error { + has, err := db.GetEngine(ctx).Get(email) if err != nil { return err } else if !has { @@ -370,7 +370,7 @@ func MakeEmailPrimary(email *EmailAddress) error { } user := &User{} - has, err = db.GetEngine(db.DefaultContext).ID(email.UID).Get(user) + has, err = db.GetEngine(ctx).ID(email.UID).Get(user) if err != nil { return err } else if !has { @@ -381,7 +381,7 @@ func MakeEmailPrimary(email *EmailAddress) error { } } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -411,17 +411,17 @@ func MakeEmailPrimary(email *EmailAddress) error { } // VerifyActiveEmailCode verifies active email code when active account -func VerifyActiveEmailCode(code, email string) *EmailAddress { +func VerifyActiveEmailCode(ctx context.Context, code, email string) *EmailAddress { minutes := setting.Service.ActiveCodeLives - if user := GetVerifyUser(code); user != nil { + if user := GetVerifyUser(ctx, code); user != nil { // time limit code prefix := code[:base.TimeLimitCodeLength] data := fmt.Sprintf("%d%s%s%s%s", user.ID, email, user.LowerName, user.Passwd, user.Rands) if base.VerifyTimeLimitCode(data, minutes, prefix) { emailAddress := &EmailAddress{UID: user.ID, Email: email} - if has, _ := db.GetEngine(db.DefaultContext).Get(emailAddress); has { + if has, _ := db.GetEngine(ctx).Get(emailAddress); has { return emailAddress } } @@ -466,7 +466,7 @@ type SearchEmailResult struct { // SearchEmails takes options i.e. keyword and part of email name to search, // it returns results in given range and number of total results. -func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) { +func SearchEmails(ctx context.Context, opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) { var cond builder.Cond = builder.Eq{"`user`.`type`": UserTypeIndividual} if len(opts.Keyword) > 0 { likeStr := "%" + strings.ToLower(opts.Keyword) + "%" @@ -491,7 +491,7 @@ func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) cond = cond.And(builder.Eq{"email_address.is_activated": false}) } - count, err := db.GetEngine(db.DefaultContext).Join("INNER", "`user`", "`user`.ID = email_address.uid"). + count, err := db.GetEngine(ctx).Join("INNER", "`user`", "`user`.ID = email_address.uid"). Where(cond).Count(new(EmailAddress)) if err != nil { return nil, 0, fmt.Errorf("Count: %w", err) @@ -505,7 +505,7 @@ func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) opts.SetDefaultValues() emails := make([]*SearchEmailResult, 0, opts.PageSize) - err = db.GetEngine(db.DefaultContext).Table("email_address"). + err = db.GetEngine(ctx).Table("email_address"). Select("email_address.*, `user`.name, `user`.full_name"). Join("INNER", "`user`", "`user`.ID = email_address.uid"). Where(cond). @@ -518,8 +518,8 @@ func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) // ActivateUserEmail will change the activated state of an email address, // either primary or secondary (all in the email_address table) -func ActivateUserEmail(userID int64, email string, activate bool) (err error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func ActivateUserEmail(ctx context.Context, userID int64, email string, activate bool) (err error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/user/email_address_test.go b/models/user/email_address_test.go index f2b383fe4b..7f3ca75cfd 100644 --- a/models/user/email_address_test.go +++ b/models/user/email_address_test.go @@ -17,14 +17,14 @@ import ( func TestGetEmailAddresses(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - emails, _ := user_model.GetEmailAddresses(int64(1)) + emails, _ := user_model.GetEmailAddresses(db.DefaultContext, int64(1)) if assert.Len(t, emails, 3) { assert.True(t, emails[0].IsPrimary) assert.True(t, emails[2].IsActivated) assert.False(t, emails[2].IsPrimary) } - emails, _ = user_model.GetEmailAddresses(int64(2)) + emails, _ = user_model.GetEmailAddresses(db.DefaultContext, int64(2)) if assert.Len(t, emails, 2) { assert.True(t, emails[0].IsPrimary) assert.True(t, emails[0].IsActivated) @@ -76,10 +76,10 @@ func TestAddEmailAddresses(t *testing.T) { LowerEmail: "user5678@example.com", IsActivated: true, } - assert.NoError(t, user_model.AddEmailAddresses(emails)) + assert.NoError(t, user_model.AddEmailAddresses(db.DefaultContext, emails)) // ErrEmailAlreadyUsed - err := user_model.AddEmailAddresses(emails) + err := user_model.AddEmailAddresses(db.DefaultContext, emails) assert.Error(t, err) assert.True(t, user_model.IsErrEmailAlreadyUsed(err)) } @@ -87,21 +87,21 @@ func TestAddEmailAddresses(t *testing.T) { func TestDeleteEmailAddress(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.NoError(t, user_model.DeleteEmailAddress(&user_model.EmailAddress{ + assert.NoError(t, user_model.DeleteEmailAddress(db.DefaultContext, &user_model.EmailAddress{ UID: int64(1), ID: int64(33), Email: "user1-2@example.com", LowerEmail: "user1-2@example.com", })) - assert.NoError(t, user_model.DeleteEmailAddress(&user_model.EmailAddress{ + assert.NoError(t, user_model.DeleteEmailAddress(db.DefaultContext, &user_model.EmailAddress{ UID: int64(1), Email: "user1-3@example.com", LowerEmail: "user1-3@example.com", })) // Email address does not exist - err := user_model.DeleteEmailAddress(&user_model.EmailAddress{ + err := user_model.DeleteEmailAddress(db.DefaultContext, &user_model.EmailAddress{ UID: int64(1), Email: "user1234567890@example.com", LowerEmail: "user1234567890@example.com", @@ -125,10 +125,10 @@ func TestDeleteEmailAddresses(t *testing.T) { Email: "user2-2@example.com", LowerEmail: "user2-2@example.com", } - assert.NoError(t, user_model.DeleteEmailAddresses(emails)) + assert.NoError(t, user_model.DeleteEmailAddresses(db.DefaultContext, emails)) // ErrEmailAlreadyUsed - err := user_model.DeleteEmailAddresses(emails) + err := user_model.DeleteEmailAddresses(db.DefaultContext, emails) assert.Error(t, err) } @@ -138,28 +138,28 @@ func TestMakeEmailPrimary(t *testing.T) { email := &user_model.EmailAddress{ Email: "user567890@example.com", } - err := user_model.MakeEmailPrimary(email) + err := user_model.MakeEmailPrimary(db.DefaultContext, email) assert.Error(t, err) assert.EqualError(t, err, user_model.ErrEmailAddressNotExist{Email: email.Email}.Error()) email = &user_model.EmailAddress{ Email: "user11@example.com", } - err = user_model.MakeEmailPrimary(email) + err = user_model.MakeEmailPrimary(db.DefaultContext, email) assert.Error(t, err) assert.EqualError(t, err, user_model.ErrEmailNotActivated.Error()) email = &user_model.EmailAddress{ Email: "user9999999@example.com", } - err = user_model.MakeEmailPrimary(email) + err = user_model.MakeEmailPrimary(db.DefaultContext, email) assert.Error(t, err) assert.True(t, user_model.IsErrUserNotExist(err)) email = &user_model.EmailAddress{ Email: "user101@example.com", } - err = user_model.MakeEmailPrimary(email) + err = user_model.MakeEmailPrimary(db.DefaultContext, email) assert.NoError(t, err) user, _ := user_model.GetUserByID(db.DefaultContext, int64(10)) @@ -174,9 +174,9 @@ func TestActivate(t *testing.T) { UID: int64(1), Email: "user11@example.com", } - assert.NoError(t, user_model.ActivateEmail(email)) + assert.NoError(t, user_model.ActivateEmail(db.DefaultContext, email)) - emails, _ := user_model.GetEmailAddresses(int64(1)) + emails, _ := user_model.GetEmailAddresses(db.DefaultContext, int64(1)) assert.Len(t, emails, 3) assert.True(t, emails[0].IsActivated) assert.True(t, emails[0].IsPrimary) @@ -194,7 +194,7 @@ func TestListEmails(t *testing.T) { PageSize: 10000, }, } - emails, count, err := user_model.SearchEmails(opts) + emails, count, err := user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.NotEqual(t, int64(0), count) assert.True(t, count > 5) @@ -214,13 +214,13 @@ func TestListEmails(t *testing.T) { // Must find no records opts = &user_model.SearchEmailOptions{Keyword: "NOTFOUND"} - emails, count, err = user_model.SearchEmails(opts) + emails, count, err = user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.Equal(t, int64(0), count) // Must find users 'user2', 'user28', etc. opts = &user_model.SearchEmailOptions{Keyword: "user2"} - emails, count, err = user_model.SearchEmails(opts) + emails, count, err = user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.NotEqual(t, int64(0), count) assert.True(t, contains(func(s *user_model.SearchEmailResult) bool { return s.UID == 2 })) @@ -228,14 +228,14 @@ func TestListEmails(t *testing.T) { // Must find only primary addresses (i.e. from the `user` table) opts = &user_model.SearchEmailOptions{IsPrimary: util.OptionalBoolTrue} - emails, _, err = user_model.SearchEmails(opts) + emails, _, err = user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.True(t, contains(func(s *user_model.SearchEmailResult) bool { return s.IsPrimary })) assert.False(t, contains(func(s *user_model.SearchEmailResult) bool { return !s.IsPrimary })) // Must find only inactive addresses (i.e. not validated) opts = &user_model.SearchEmailOptions{IsActivated: util.OptionalBoolFalse} - emails, _, err = user_model.SearchEmails(opts) + emails, _, err = user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.True(t, contains(func(s *user_model.SearchEmailResult) bool { return !s.IsActivated })) assert.False(t, contains(func(s *user_model.SearchEmailResult) bool { return s.IsActivated })) @@ -247,7 +247,7 @@ func TestListEmails(t *testing.T) { Page: 1, }, } - emails, count, err = user_model.SearchEmails(opts) + emails, count, err = user_model.SearchEmails(db.DefaultContext, opts) assert.NoError(t, err) assert.Len(t, emails, 5) assert.Greater(t, count, int64(len(emails))) diff --git a/models/user/list.go b/models/user/list.go index 6b3b7bea9a..ca589d1e02 100644 --- a/models/user/list.go +++ b/models/user/list.go @@ -25,19 +25,19 @@ func (users UserList) GetUserIDs() []int64 { } // GetTwoFaStatus return state of 2FA enrollement -func (users UserList) GetTwoFaStatus() map[int64]bool { +func (users UserList) GetTwoFaStatus(ctx context.Context) map[int64]bool { results := make(map[int64]bool, len(users)) for _, user := range users { results[user.ID] = false // Set default to false } - if tokenMaps, err := users.loadTwoFactorStatus(db.DefaultContext); err == nil { + if tokenMaps, err := users.loadTwoFactorStatus(ctx); err == nil { for _, token := range tokenMaps { results[token.UID] = true } } - if ids, err := users.userIDsWithWebAuthn(db.DefaultContext); err == nil { + if ids, err := users.userIDsWithWebAuthn(ctx); err == nil { for _, id := range ids { results[id] = true } @@ -71,12 +71,12 @@ func (users UserList) userIDsWithWebAuthn(ctx context.Context) ([]int64, error) } // GetUsersByIDs returns all resolved users from a list of Ids. -func GetUsersByIDs(ids []int64) (UserList, error) { +func GetUsersByIDs(ctx context.Context, ids []int64) (UserList, error) { ous := make([]*User, 0, len(ids)) if len(ids) == 0 { return ous, nil } - err := db.GetEngine(db.DefaultContext).In("id", ids). + err := db.GetEngine(ctx).In("id", ids). Asc("name"). Find(&ous) return ous, err diff --git a/models/user/search.go b/models/user/search.go index 446556f89b..0fa278c257 100644 --- a/models/user/search.go +++ b/models/user/search.go @@ -4,6 +4,7 @@ package user import ( + "context" "fmt" "strings" @@ -39,7 +40,7 @@ type SearchUserOptions struct { ExtraParamStrings map[string]string } -func (opts *SearchUserOptions) toSearchQueryBase() *xorm.Session { +func (opts *SearchUserOptions) toSearchQueryBase(ctx context.Context) *xorm.Session { var cond builder.Cond cond = builder.Eq{"type": opts.Type} if opts.IncludeReserved { @@ -101,7 +102,7 @@ func (opts *SearchUserOptions) toSearchQueryBase() *xorm.Session { cond = cond.And(builder.Eq{"prohibit_login": opts.IsProhibitLogin.IsTrue()}) } - e := db.GetEngine(db.DefaultContext) + e := db.GetEngine(ctx) if opts.IsTwoFactorEnabled.IsNone() { return e.Where(cond) } @@ -122,8 +123,8 @@ func (opts *SearchUserOptions) toSearchQueryBase() *xorm.Session { // SearchUsers takes options i.e. keyword and part of user name to search, // it returns results in given range and number of total results. -func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { - sessCount := opts.toSearchQueryBase() +func SearchUsers(ctx context.Context, opts *SearchUserOptions) (users []*User, _ int64, _ error) { + sessCount := opts.toSearchQueryBase(ctx) defer sessCount.Close() count, err := sessCount.Count(new(User)) if err != nil { @@ -134,7 +135,7 @@ func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { opts.OrderBy = db.SearchOrderByAlphabetically } - sessQuery := opts.toSearchQueryBase().OrderBy(opts.OrderBy.String()) + sessQuery := opts.toSearchQueryBase(ctx).OrderBy(opts.OrderBy.String()) defer sessQuery.Close() if opts.Page != 0 { sessQuery = db.SetSessionPagination(sessQuery, opts) diff --git a/models/user/user.go b/models/user/user.go index 86cf2ad280..b3956da1cb 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -192,15 +192,15 @@ func (u *User) SetLastLogin() { } // UpdateUserDiffViewStyle updates the users diff view style -func UpdateUserDiffViewStyle(u *User, style string) error { +func UpdateUserDiffViewStyle(ctx context.Context, u *User, style string) error { u.DiffViewStyle = style - return UpdateUserCols(db.DefaultContext, u, "diff_view_style") + return UpdateUserCols(ctx, u, "diff_view_style") } // UpdateUserTheme updates a users' theme irrespective of the site wide theme -func UpdateUserTheme(u *User, themeName string) error { +func UpdateUserTheme(ctx context.Context, u *User, themeName string) error { u.Theme = themeName - return UpdateUserCols(db.DefaultContext, u, "theme") + return UpdateUserCols(ctx, u, "theme") } // GetPlaceholderEmail returns an noreply email @@ -218,9 +218,9 @@ func (u *User) GetEmail() string { } // GetAllUsers returns a slice of all individual users found in DB. -func GetAllUsers() ([]*User, error) { +func GetAllUsers(ctx context.Context) ([]*User, error) { users := make([]*User, 0) - return users, db.GetEngine(db.DefaultContext).OrderBy("id").Where("type = ?", UserTypeIndividual).Find(&users) + return users, db.GetEngine(ctx).OrderBy("id").Where("type = ?", UserTypeIndividual).Find(&users) } // IsLocal returns true if user login type is LoginPlain. @@ -478,9 +478,9 @@ func (u *User) EmailNotifications() string { } // SetEmailNotifications sets the user's email notification preference -func SetEmailNotifications(u *User, set string) error { +func SetEmailNotifications(ctx context.Context, u *User, set string) error { u.EmailNotificationsPreference = set - if err := UpdateUserCols(db.DefaultContext, u, "email_notifications_preference"); err != nil { + if err := UpdateUserCols(ctx, u, "email_notifications_preference"); err != nil { log.Error("SetEmailNotifications: %v", err) return err } @@ -582,7 +582,7 @@ type CreateUserOverwriteOptions struct { } // CreateUser creates record of a new user. -func CreateUser(u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err error) { +func CreateUser(ctx context.Context, u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err error) { if err = IsUsableUsername(u.Name); err != nil { return err } @@ -640,7 +640,7 @@ func CreateUser(u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err e return err } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -711,8 +711,8 @@ type CountUserFilter struct { } // CountUsers returns number of users. -func CountUsers(opts *CountUserFilter) int64 { - return countUsers(db.DefaultContext, opts) +func CountUsers(ctx context.Context, opts *CountUserFilter) int64 { + return countUsers(ctx, opts) } func countUsers(ctx context.Context, opts *CountUserFilter) int64 { @@ -727,7 +727,7 @@ func countUsers(ctx context.Context, opts *CountUserFilter) int64 { } // GetVerifyUser get user by verify code -func GetVerifyUser(code string) (user *User) { +func GetVerifyUser(ctx context.Context, code string) (user *User) { if len(code) <= base.TimeLimitCodeLength { return nil } @@ -735,7 +735,7 @@ func GetVerifyUser(code string) (user *User) { // use tail hex username query user hexStr := code[base.TimeLimitCodeLength:] if b, err := hex.DecodeString(hexStr); err == nil { - if user, err = GetUserByName(db.DefaultContext, string(b)); user != nil { + if user, err = GetUserByName(ctx, string(b)); user != nil { return user } log.Error("user.getVerifyUser: %v", err) @@ -745,10 +745,10 @@ func GetVerifyUser(code string) (user *User) { } // VerifyUserActiveCode verifies active code when active account -func VerifyUserActiveCode(code string) (user *User) { +func VerifyUserActiveCode(ctx context.Context, code string) (user *User) { minutes := setting.Service.ActiveCodeLives - if user = GetVerifyUser(code); user != nil { + if user = GetVerifyUser(ctx, code); user != nil { // time limit code prefix := code[:base.TimeLimitCodeLength] data := fmt.Sprintf("%d%s%s%s%s", user.ID, user.Email, user.LowerName, user.Passwd, user.Rands) @@ -872,8 +872,8 @@ func UpdateUserCols(ctx context.Context, u *User, cols ...string) error { } // UpdateUserSetting updates user's settings. -func UpdateUserSetting(u *User) (err error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func UpdateUserSetting(ctx context.Context, u *User) (err error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -1021,9 +1021,9 @@ func GetMaileableUsersByIDs(ctx context.Context, ids []int64, isMention bool) ([ } // GetUserNamesByIDs returns usernames for all resolved users from a list of Ids. -func GetUserNamesByIDs(ids []int64) ([]string, error) { +func GetUserNamesByIDs(ctx context.Context, ids []int64) ([]string, error) { unames := make([]string, 0, len(ids)) - err := db.GetEngine(db.DefaultContext).In("id", ids). + err := db.GetEngine(ctx).In("id", ids). Table("user"). Asc("name"). Cols("name"). @@ -1062,9 +1062,9 @@ func GetUserIDsByNames(ctx context.Context, names []string, ignoreNonExistent bo } // GetUsersBySource returns a list of Users for a login source -func GetUsersBySource(s *auth.Source) ([]*User, error) { +func GetUsersBySource(ctx context.Context, s *auth.Source) ([]*User, error) { var users []*User - err := db.GetEngine(db.DefaultContext).Where("login_type = ? AND login_source = ?", s.Type, s.ID).Find(&users) + err := db.GetEngine(ctx).Where("login_type = ? AND login_source = ?", s.Type, s.ID).Find(&users) return users, err } @@ -1145,12 +1145,12 @@ func GetUserByEmail(ctx context.Context, email string) (*User, error) { } // GetUser checks if a user already exists -func GetUser(user *User) (bool, error) { - return db.GetEngine(db.DefaultContext).Get(user) +func GetUser(ctx context.Context, user *User) (bool, error) { + return db.GetEngine(ctx).Get(user) } // GetUserByOpenID returns the user object by given OpenID if exists. -func GetUserByOpenID(uri string) (*User, error) { +func GetUserByOpenID(ctx context.Context, uri string) (*User, error) { if len(uri) == 0 { return nil, ErrUserNotExist{0, uri, 0} } @@ -1164,12 +1164,12 @@ func GetUserByOpenID(uri string) (*User, error) { // Otherwise, check in openid table oid := &UserOpenID{} - has, err := db.GetEngine(db.DefaultContext).Where("uri=?", uri).Get(oid) + has, err := db.GetEngine(ctx).Where("uri=?", uri).Get(oid) if err != nil { return nil, err } if has { - return GetUserByID(db.DefaultContext, oid.UID) + return GetUserByID(ctx, oid.UID) } return nil, ErrUserNotExist{0, uri, 0} @@ -1279,13 +1279,13 @@ func IsUserVisibleToViewer(ctx context.Context, u, viewer *User) bool { } // CountWrongUserType count OrgUser who have wrong type -func CountWrongUserType() (int64, error) { - return db.GetEngine(db.DefaultContext).Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Count(new(User)) +func CountWrongUserType(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Count(new(User)) } // FixWrongUserType fix OrgUser who have wrong type -func FixWrongUserType() (int64, error) { - return db.GetEngine(db.DefaultContext).Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Cols("type").NoAutoTime().Update(&User{Type: 1}) +func FixWrongUserType(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Cols("type").NoAutoTime().Update(&User{Type: 1}) } func GetOrderByName() string { diff --git a/models/user/user_test.go b/models/user/user_test.go index 032dcba676..b15f0cbc59 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -63,7 +63,7 @@ func TestCanCreateOrganization(t *testing.T) { func TestSearchUsers(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(opts *user_model.SearchUserOptions, expectedUserOrOrgIDs []int64) { - users, _, err := user_model.SearchUsers(opts) + users, _, err := user_model.SearchUsers(db.DefaultContext, opts) assert.NoError(t, err) cassText := fmt.Sprintf("ids: %v, opts: %v", expectedUserOrOrgIDs, opts) if assert.Len(t, users, len(expectedUserOrOrgIDs), "case: %s", cassText) { @@ -150,16 +150,16 @@ func TestEmailNotificationPreferences(t *testing.T) { assert.Equal(t, test.expected, user.EmailNotifications()) // Try all possible settings - assert.NoError(t, user_model.SetEmailNotifications(user, user_model.EmailNotificationsEnabled)) + assert.NoError(t, user_model.SetEmailNotifications(db.DefaultContext, user, user_model.EmailNotificationsEnabled)) assert.Equal(t, user_model.EmailNotificationsEnabled, user.EmailNotifications()) - assert.NoError(t, user_model.SetEmailNotifications(user, user_model.EmailNotificationsOnMention)) + assert.NoError(t, user_model.SetEmailNotifications(db.DefaultContext, user, user_model.EmailNotificationsOnMention)) assert.Equal(t, user_model.EmailNotificationsOnMention, user.EmailNotifications()) - assert.NoError(t, user_model.SetEmailNotifications(user, user_model.EmailNotificationsDisabled)) + assert.NoError(t, user_model.SetEmailNotifications(db.DefaultContext, user, user_model.EmailNotificationsDisabled)) assert.Equal(t, user_model.EmailNotificationsDisabled, user.EmailNotifications()) - assert.NoError(t, user_model.SetEmailNotifications(user, user_model.EmailNotificationsAndYourOwn)) + assert.NoError(t, user_model.SetEmailNotifications(db.DefaultContext, user, user_model.EmailNotificationsAndYourOwn)) assert.Equal(t, user_model.EmailNotificationsAndYourOwn, user.EmailNotifications()) } } @@ -239,7 +239,7 @@ func TestCreateUserInvalidEmail(t *testing.T) { MustChangePassword: false, } - err := user_model.CreateUser(user) + err := user_model.CreateUser(db.DefaultContext, user) assert.Error(t, err) assert.True(t, user_model.IsErrEmailCharIsNotSupported(err)) } @@ -253,7 +253,7 @@ func TestCreateUserEmailAlreadyUsed(t *testing.T) { user.Name = "testuser" user.LowerName = strings.ToLower(user.Name) user.ID = 0 - err := user_model.CreateUser(user) + err := user_model.CreateUser(db.DefaultContext, user) assert.Error(t, err) assert.True(t, user_model.IsErrEmailAlreadyUsed(err)) } @@ -270,7 +270,7 @@ func TestCreateUserCustomTimestamps(t *testing.T) { user.ID = 0 user.Email = "unique@example.com" user.CreatedUnix = creationTimestamp - err := user_model.CreateUser(user) + err := user_model.CreateUser(db.DefaultContext, user) assert.NoError(t, err) fetched, err := user_model.GetUserByID(context.Background(), user.ID) @@ -295,7 +295,7 @@ func TestCreateUserWithoutCustomTimestamps(t *testing.T) { user.Email = "unique@example.com" user.CreatedUnix = 0 user.UpdatedUnix = 0 - err := user_model.CreateUser(user) + err := user_model.CreateUser(db.DefaultContext, user) assert.NoError(t, err) timestampEnd := time.Now().Unix() @@ -429,17 +429,17 @@ func TestNewUserRedirect3(t *testing.T) { func TestGetUserByOpenID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - _, err := user_model.GetUserByOpenID("https://unknown") + _, err := user_model.GetUserByOpenID(db.DefaultContext, "https://unknown") if assert.Error(t, err) { assert.True(t, user_model.IsErrUserNotExist(err)) } - user, err := user_model.GetUserByOpenID("https://user1.domain1.tld") + user, err := user_model.GetUserByOpenID(db.DefaultContext, "https://user1.domain1.tld") if assert.NoError(t, err) { assert.Equal(t, int64(1), user.ID) } - user, err = user_model.GetUserByOpenID("https://domain1.tld/user2/") + user, err = user_model.GetUserByOpenID(db.DefaultContext, "https://domain1.tld/user2/") if assert.NoError(t, err) { assert.Equal(t, int64(2), user.ID) } |