diff options
author | KN4CK3R <admin@oldschoolhack.me> | 2022-11-19 09:12:33 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-19 16:12:33 +0800 |
commit | 044c754ea53f5b81f451451df53aea366f6f700a (patch) | |
tree | 45688c28a84f87f71ec3f99eb0e8456eb7d19c42 /models | |
parent | fefdb7ffd11bbfbff66dae8e88681ec840dedfde (diff) | |
download | gitea-044c754ea53f5b81f451451df53aea366f6f700a.tar.gz gitea-044c754ea53f5b81f451451df53aea366f6f700a.zip |
Add `context.Context` to more methods (#21546)
This PR adds a context parameter to a bunch of methods. Some helper
`xxxCtx()` methods got replaced with the normal name now.
Co-authored-by: delvh <dev.lh@web.de>
Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
Diffstat (limited to 'models')
39 files changed, 348 insertions, 442 deletions
diff --git a/models/activities/action.go b/models/activities/action.go index 5c3419c5ec..bbb6073265 100644 --- a/models/activities/action.go +++ b/models/activities/action.go @@ -461,7 +461,8 @@ func DeleteOldActions(olderThan time.Duration) (err error) { return err } -func notifyWatchers(ctx context.Context, actions ...*Action) error { +// NotifyWatchers creates batch of actions for every watcher. +func NotifyWatchers(ctx context.Context, actions ...*Action) error { var watchers []*repo_model.Watch var repo *repo_model.Repository var err error @@ -565,11 +566,6 @@ func notifyWatchers(ctx context.Context, actions ...*Action) error { return nil } -// NotifyWatchers creates batch of actions for every watcher. -func NotifyWatchers(actions ...*Action) error { - return notifyWatchers(db.DefaultContext, actions...) -} - // NotifyWatchersActions creates batch of actions for every watcher. func NotifyWatchersActions(acts []*Action) error { ctx, committer, err := db.TxContext(db.DefaultContext) @@ -578,7 +574,7 @@ func NotifyWatchersActions(acts []*Action) error { } defer committer.Close() for _, act := range acts { - if err := notifyWatchers(ctx, act); err != nil { + if err := NotifyWatchers(ctx, act); err != nil { return err } } @@ -603,17 +599,17 @@ func DeleteIssueActions(ctx context.Context, repoID, issueID int64) error { } // CountActionCreatedUnixString count actions where created_unix is an empty string -func CountActionCreatedUnixString() (int64, error) { +func CountActionCreatedUnixString(ctx context.Context) (int64, error) { if setting.Database.UseSQLite3 { - return db.GetEngine(db.DefaultContext).Where(`created_unix = ""`).Count(new(Action)) + return db.GetEngine(ctx).Where(`created_unix = ""`).Count(new(Action)) } return 0, nil } // FixActionCreatedUnixString set created_unix to zero if it is an empty string -func FixActionCreatedUnixString() (int64, error) { +func FixActionCreatedUnixString(ctx context.Context) (int64, error) { if setting.Database.UseSQLite3 { - res, err := db.GetEngine(db.DefaultContext).Exec(`UPDATE action SET created_unix = 0 WHERE created_unix = ""`) + res, err := db.GetEngine(ctx).Exec(`UPDATE action SET created_unix = 0 WHERE created_unix = ""`) if err != nil { return 0, err } diff --git a/models/activities/action_test.go b/models/activities/action_test.go index ac2a3043a6..b79eb0d08d 100644 --- a/models/activities/action_test.go +++ b/models/activities/action_test.go @@ -188,7 +188,7 @@ func TestNotifyWatchers(t *testing.T) { RepoID: 1, OpType: activities_model.ActionStarRepo, } - assert.NoError(t, activities_model.NotifyWatchers(action)) + assert.NoError(t, activities_model.NotifyWatchers(db.DefaultContext, action)) // One watchers are inactive, thus action is only created for user 8, 1, 4, 11 unittest.AssertExistsAndLoadBean(t, &activities_model.Action{ @@ -256,17 +256,17 @@ func TestConsistencyUpdateAction(t *testing.T) { // // Get rid of incorrectly set created_unix // - count, err := activities_model.CountActionCreatedUnixString() + count, err := activities_model.CountActionCreatedUnixString(db.DefaultContext) assert.NoError(t, err) assert.EqualValues(t, 1, count) - count, err = activities_model.FixActionCreatedUnixString() + count, err = activities_model.FixActionCreatedUnixString(db.DefaultContext) assert.NoError(t, err) assert.EqualValues(t, 1, count) - count, err = activities_model.CountActionCreatedUnixString() + count, err = activities_model.CountActionCreatedUnixString(db.DefaultContext) assert.NoError(t, err) assert.EqualValues(t, 0, count) - count, err = activities_model.FixActionCreatedUnixString() + count, err = activities_model.FixActionCreatedUnixString(db.DefaultContext) assert.NoError(t, err) assert.EqualValues(t, 0, count) diff --git a/models/activities/notification.go b/models/activities/notification.go index 28adc8cc4e..10b3a76713 100644 --- a/models/activities/notification.go +++ b/models/activities/notification.go @@ -136,49 +136,41 @@ func GetNotifications(ctx context.Context, options *FindNotificationOptions) (nl } // CountNotifications count all notifications that fit to the given options and ignore pagination. -func CountNotifications(opts *FindNotificationOptions) (int64, error) { - return db.GetEngine(db.DefaultContext).Where(opts.ToCond()).Count(&Notification{}) +func CountNotifications(ctx context.Context, opts *FindNotificationOptions) (int64, error) { + return db.GetEngine(ctx).Where(opts.ToCond()).Count(&Notification{}) } // CreateRepoTransferNotification creates notification for the user a repository was transferred to -func CreateRepoTransferNotification(doer, newOwner *user_model.User, repo *repo_model.Repository) error { - ctx, committer, err := db.TxContext(db.DefaultContext) - if err != nil { - return err - } - defer committer.Close() - - var notify []*Notification +func CreateRepoTransferNotification(ctx context.Context, doer, newOwner *user_model.User, repo *repo_model.Repository) error { + return db.AutoTx(ctx, func(ctx context.Context) error { + var notify []*Notification - if newOwner.IsOrganization() { - users, err := organization.GetUsersWhoCanCreateOrgRepo(ctx, newOwner.ID) - if err != nil || len(users) == 0 { - return err - } - for i := range users { - notify = append(notify, &Notification{ - UserID: users[i].ID, + if newOwner.IsOrganization() { + users, err := organization.GetUsersWhoCanCreateOrgRepo(ctx, newOwner.ID) + if err != nil || len(users) == 0 { + return err + } + for i := range users { + notify = append(notify, &Notification{ + UserID: users[i].ID, + RepoID: repo.ID, + Status: NotificationStatusUnread, + UpdatedBy: doer.ID, + Source: NotificationSourceRepository, + }) + } + } else { + notify = []*Notification{{ + UserID: newOwner.ID, RepoID: repo.ID, Status: NotificationStatusUnread, UpdatedBy: doer.ID, Source: NotificationSourceRepository, - }) + }} } - } else { - notify = []*Notification{{ - UserID: newOwner.ID, - RepoID: repo.ID, - Status: NotificationStatusUnread, - UpdatedBy: doer.ID, - Source: NotificationSourceRepository, - }} - } - - if err := db.Insert(ctx, notify); err != nil { - return err - } - return committer.Commit() + return db.Insert(ctx, notify) + }) } // CreateOrUpdateIssueNotifications creates an issue notification @@ -379,11 +371,7 @@ func CountUnread(ctx context.Context, userID int64) int64 { } // LoadAttributes load Repo Issue User and Comment if not loaded -func (n *Notification) LoadAttributes() (err error) { - return n.loadAttributes(db.DefaultContext) -} - -func (n *Notification) loadAttributes(ctx context.Context) (err error) { +func (n *Notification) LoadAttributes(ctx context.Context) (err error) { if err = n.loadRepo(ctx); err != nil { return } @@ -481,10 +469,10 @@ func (n *Notification) APIURL() string { type NotificationList []*Notification // LoadAttributes load Repo Issue User and Comment if not loaded -func (nl NotificationList) LoadAttributes() error { +func (nl NotificationList) LoadAttributes(ctx context.Context) error { var err error for i := 0; i < len(nl); i++ { - err = nl[i].LoadAttributes() + err = nl[i].LoadAttributes(ctx) if err != nil && !issues_model.IsErrCommentNotExist(err) { return err } @@ -504,7 +492,7 @@ func (nl NotificationList) getPendingRepoIDs() []int64 { } // LoadRepos loads repositories from database -func (nl NotificationList) LoadRepos() (repo_model.RepositoryList, []int, error) { +func (nl NotificationList) LoadRepos(ctx context.Context) (repo_model.RepositoryList, []int, error) { if len(nl) == 0 { return repo_model.RepositoryList{}, []int{}, nil } @@ -517,7 +505,7 @@ func (nl NotificationList) LoadRepos() (repo_model.RepositoryList, []int, error) if left < limit { limit = left } - rows, err := db.GetEngine(db.DefaultContext). + rows, err := db.GetEngine(ctx). In("id", repoIDs[:limit]). Rows(new(repo_model.Repository)) if err != nil { @@ -578,7 +566,7 @@ func (nl NotificationList) getPendingIssueIDs() []int64 { } // LoadIssues loads issues from database -func (nl NotificationList) LoadIssues() ([]int, error) { +func (nl NotificationList) LoadIssues(ctx context.Context) ([]int, error) { if len(nl) == 0 { return []int{}, nil } @@ -591,7 +579,7 @@ func (nl NotificationList) LoadIssues() ([]int, error) { if left < limit { limit = left } - rows, err := db.GetEngine(db.DefaultContext). + rows, err := db.GetEngine(ctx). In("id", issueIDs[:limit]). Rows(new(issues_model.Issue)) if err != nil { @@ -662,7 +650,7 @@ func (nl NotificationList) getPendingCommentIDs() []int64 { } // LoadComments loads comments from database -func (nl NotificationList) LoadComments() ([]int, error) { +func (nl NotificationList) LoadComments(ctx context.Context) ([]int, error) { if len(nl) == 0 { return []int{}, nil } @@ -675,7 +663,7 @@ func (nl NotificationList) LoadComments() ([]int, error) { if left < limit { limit = left } - rows, err := db.GetEngine(db.DefaultContext). + rows, err := db.GetEngine(ctx). In("id", commentIDs[:limit]). Rows(new(issues_model.Comment)) if err != nil { @@ -775,8 +763,8 @@ func SetRepoReadBy(ctx context.Context, userID, repoID int64) error { } // SetNotificationStatus change the notification status -func SetNotificationStatus(notificationID int64, user *user_model.User, status NotificationStatus) (*Notification, error) { - notification, err := getNotificationByID(db.DefaultContext, notificationID) +func SetNotificationStatus(ctx context.Context, notificationID int64, user *user_model.User, status NotificationStatus) (*Notification, error) { + notification, err := GetNotificationByID(ctx, notificationID) if err != nil { return notification, err } @@ -787,16 +775,12 @@ func SetNotificationStatus(notificationID int64, user *user_model.User, status N notification.Status = status - _, err = db.GetEngine(db.DefaultContext).ID(notificationID).Update(notification) + _, err = db.GetEngine(ctx).ID(notificationID).Update(notification) return notification, err } // GetNotificationByID return notification by ID -func GetNotificationByID(notificationID int64) (*Notification, error) { - return getNotificationByID(db.DefaultContext, notificationID) -} - -func getNotificationByID(ctx context.Context, notificationID int64) (*Notification, error) { +func GetNotificationByID(ctx context.Context, notificationID int64) (*Notification, error) { notification := new(Notification) ok, err := db.GetEngine(ctx). Where("id = ?", notificationID). @@ -813,9 +797,9 @@ func getNotificationByID(ctx context.Context, notificationID int64) (*Notificati } // UpdateNotificationStatuses updates the statuses of all of a user's notifications that are of the currentStatus type to the desiredStatus -func UpdateNotificationStatuses(user *user_model.User, currentStatus, desiredStatus NotificationStatus) error { +func UpdateNotificationStatuses(ctx context.Context, user *user_model.User, currentStatus, desiredStatus NotificationStatus) error { n := &Notification{Status: desiredStatus, UpdatedBy: user.ID} - _, err := db.GetEngine(db.DefaultContext). + _, err := db.GetEngine(ctx). Where("user_id = ? AND status = ?", user.ID, currentStatus). Cols("status", "updated_by", "updated_unix"). Update(n) diff --git a/models/activities/notification_test.go b/models/activities/notification_test.go index 4ee16af076..d871891001 100644 --- a/models/activities/notification_test.go +++ b/models/activities/notification_test.go @@ -82,14 +82,14 @@ func TestSetNotificationStatus(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) notf := unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusRead}) - _, err := activities_model.SetNotificationStatus(notf.ID, user, activities_model.NotificationStatusPinned) + _, err := activities_model.SetNotificationStatus(db.DefaultContext, notf.ID, user, activities_model.NotificationStatusPinned) assert.NoError(t, err) unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{ID: notf.ID, Status: activities_model.NotificationStatusPinned}) - _, err = activities_model.SetNotificationStatus(1, user, activities_model.NotificationStatusRead) + _, err = activities_model.SetNotificationStatus(db.DefaultContext, 1, user, activities_model.NotificationStatusRead) assert.Error(t, err) - _, err = activities_model.SetNotificationStatus(unittest.NonexistentID, user, activities_model.NotificationStatusRead) + _, err = activities_model.SetNotificationStatus(db.DefaultContext, unittest.NonexistentID, user, activities_model.NotificationStatusRead) assert.Error(t, err) } @@ -102,7 +102,7 @@ func TestUpdateNotificationStatuses(t *testing.T) { &activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusRead}) notfPinned := unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusPinned}) - assert.NoError(t, activities_model.UpdateNotificationStatuses(user, activities_model.NotificationStatusUnread, activities_model.NotificationStatusRead)) + assert.NoError(t, activities_model.UpdateNotificationStatuses(db.DefaultContext, user, activities_model.NotificationStatusUnread, activities_model.NotificationStatusRead)) unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{ID: notfUnread.ID, Status: activities_model.NotificationStatusRead}) unittest.AssertExistsAndLoadBean(t, diff --git a/models/db/consistency.go b/models/db/consistency.go index 7addb174c4..5a7878c74d 100644 --- a/models/db/consistency.go +++ b/models/db/consistency.go @@ -4,11 +4,16 @@ package db -import "xorm.io/builder" +import ( + "context" + + "xorm.io/builder" +) // CountOrphanedObjects count subjects with have no existing refobject anymore -func CountOrphanedObjects(subject, refobject, joinCond string) (int64, error) { - return GetEngine(DefaultContext).Table("`"+subject+"`"). +func CountOrphanedObjects(ctx context.Context, subject, refobject, joinCond string) (int64, error) { + return GetEngine(ctx). + Table("`"+subject+"`"). Join("LEFT", "`"+refobject+"`", joinCond). Where(builder.IsNull{"`" + refobject + "`.id"}). Select("COUNT(`" + subject + "`.`id`)"). @@ -16,12 +21,12 @@ func CountOrphanedObjects(subject, refobject, joinCond string) (int64, error) { } // DeleteOrphanedObjects delete subjects with have no existing refobject anymore -func DeleteOrphanedObjects(subject, refobject, joinCond string) error { +func DeleteOrphanedObjects(ctx context.Context, subject, refobject, joinCond string) error { subQuery := builder.Select("`"+subject+"`.id"). From("`"+subject+"`"). Join("LEFT", "`"+refobject+"`", joinCond). Where(builder.IsNull{"`" + refobject + "`.id"}) b := builder.Delete(builder.In("id", subQuery)).From("`" + subject + "`") - _, err := GetEngine(DefaultContext).Exec(b) + _, err := GetEngine(ctx).Exec(b) return err } diff --git a/models/db/engine_test.go b/models/db/engine_test.go index c26d94c340..c2ba9614aa 100644 --- a/models/db/engine_test.go +++ b/models/db/engine_test.go @@ -41,11 +41,11 @@ func TestDeleteOrphanedObjects(t *testing.T) { _, err = db.GetEngine(db.DefaultContext).Insert(&issues_model.PullRequest{IssueID: 1000}, &issues_model.PullRequest{IssueID: 1001}, &issues_model.PullRequest{IssueID: 1003}) assert.NoError(t, err) - orphaned, err := db.CountOrphanedObjects("pull_request", "issue", "pull_request.issue_id=issue.id") + orphaned, err := db.CountOrphanedObjects(db.DefaultContext, "pull_request", "issue", "pull_request.issue_id=issue.id") assert.NoError(t, err) assert.EqualValues(t, 3, orphaned) - err = db.DeleteOrphanedObjects("pull_request", "issue", "pull_request.issue_id=issue.id") + err = db.DeleteOrphanedObjects(db.DefaultContext, "pull_request", "issue", "pull_request.issue_id=issue.id") assert.NoError(t, err) countAfter, err := db.GetEngine(db.DefaultContext).Count(&issues_model.PullRequest{}) diff --git a/models/db/sequence.go b/models/db/sequence.go index 48e4a8f1ac..0daacee70c 100644 --- a/models/db/sequence.go +++ b/models/db/sequence.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "regexp" @@ -12,7 +13,7 @@ import ( ) // CountBadSequences looks for broken sequences from recreate-table mistakes -func CountBadSequences() (int64, error) { +func CountBadSequences(_ context.Context) (int64, error) { if !setting.Database.UsePostgreSQL { return 0, nil } @@ -33,7 +34,7 @@ func CountBadSequences() (int64, error) { } // FixBadSequences fixes for broken sequences from recreate-table mistakes -func FixBadSequences() error { +func FixBadSequences(_ context.Context) error { if !setting.Database.UsePostgreSQL { return nil } diff --git a/models/fixture_generation.go b/models/fixture_generation.go index f4644859eb..50b983fa82 100644 --- a/models/fixture_generation.go +++ b/models/fixture_generation.go @@ -22,7 +22,7 @@ func GetYamlFixturesAccess() (string, error) { } for _, repo := range repos { - repo.MustOwner() + repo.MustOwner(db.DefaultContext) if err := access_model.RecalculateAccesses(db.DefaultContext, repo); err != nil { return "", err } diff --git a/models/git/protected_tag.go b/models/git/protected_tag.go index 7c3881643d..4640a77b20 100644 --- a/models/git/protected_tag.go +++ b/models/git/protected_tag.go @@ -5,6 +5,7 @@ package git import ( + "context" "regexp" "strings" @@ -69,13 +70,13 @@ func UpdateProtectedTag(pt *ProtectedTag) error { } // DeleteProtectedTag deletes a protected tag by ID -func DeleteProtectedTag(pt *ProtectedTag) error { - _, err := db.GetEngine(db.DefaultContext).ID(pt.ID).Delete(&ProtectedTag{}) +func DeleteProtectedTag(ctx context.Context, pt *ProtectedTag) error { + _, err := db.GetEngine(ctx).ID(pt.ID).Delete(&ProtectedTag{}) return err } // IsUserAllowedModifyTag returns true if the user is allowed to modify the tag -func IsUserAllowedModifyTag(pt *ProtectedTag, userID int64) (bool, error) { +func IsUserAllowedModifyTag(ctx context.Context, pt *ProtectedTag, userID int64) (bool, error) { if base.Int64sContains(pt.AllowlistUserIDs, userID) { return true, nil } @@ -84,7 +85,7 @@ func IsUserAllowedModifyTag(pt *ProtectedTag, userID int64) (bool, error) { return false, nil } - in, err := organization.IsUserInTeams(db.DefaultContext, userID, pt.AllowlistTeamIDs) + in, err := organization.IsUserInTeams(ctx, userID, pt.AllowlistTeamIDs) if err != nil { return false, err } @@ -92,9 +93,9 @@ func IsUserAllowedModifyTag(pt *ProtectedTag, userID int64) (bool, error) { } // GetProtectedTags gets all protected tags of the repository -func GetProtectedTags(repoID int64) ([]*ProtectedTag, error) { +func GetProtectedTags(ctx context.Context, repoID int64) ([]*ProtectedTag, error) { tags := make([]*ProtectedTag, 0) - return tags, db.GetEngine(db.DefaultContext).Find(&tags, &ProtectedTag{RepoID: repoID}) + return tags, db.GetEngine(ctx).Find(&tags, &ProtectedTag{RepoID: repoID}) } // GetProtectedTagByID gets the protected tag with the specific id @@ -112,7 +113,7 @@ func GetProtectedTagByID(id int64) (*ProtectedTag, error) { // IsUserAllowedToControlTag checks if a user can control the specific tag. // It returns true if the tag name is not protected or the user is allowed to control it. -func IsUserAllowedToControlTag(tags []*ProtectedTag, tagName string, userID int64) (bool, error) { +func IsUserAllowedToControlTag(ctx context.Context, tags []*ProtectedTag, tagName string, userID int64) (bool, error) { isAllowed := true for _, tag := range tags { err := tag.EnsureCompiledPattern() @@ -124,7 +125,7 @@ func IsUserAllowedToControlTag(tags []*ProtectedTag, tagName string, userID int6 continue } - isAllowed, err = IsUserAllowedModifyTag(tag, userID) + isAllowed, err = IsUserAllowedModifyTag(ctx, tag, userID) if err != nil { return false, err } diff --git a/models/git/protected_tag_test.go b/models/git/protected_tag_test.go index b496688b25..352eed0060 100644 --- a/models/git/protected_tag_test.go +++ b/models/git/protected_tag_test.go @@ -7,6 +7,7 @@ package git_test import ( "testing" + "code.gitea.io/gitea/models/db" git_model "code.gitea.io/gitea/models/git" "code.gitea.io/gitea/models/unittest" @@ -17,29 +18,29 @@ func TestIsUserAllowed(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) pt := &git_model.ProtectedTag{} - allowed, err := git_model.IsUserAllowedModifyTag(pt, 1) + allowed, err := git_model.IsUserAllowedModifyTag(db.DefaultContext, pt, 1) assert.NoError(t, err) assert.False(t, allowed) pt = &git_model.ProtectedTag{ AllowlistUserIDs: []int64{1}, } - allowed, err = git_model.IsUserAllowedModifyTag(pt, 1) + allowed, err = git_model.IsUserAllowedModifyTag(db.DefaultContext, pt, 1) assert.NoError(t, err) assert.True(t, allowed) - allowed, err = git_model.IsUserAllowedModifyTag(pt, 2) + allowed, err = git_model.IsUserAllowedModifyTag(db.DefaultContext, pt, 2) assert.NoError(t, err) assert.False(t, allowed) pt = &git_model.ProtectedTag{ AllowlistTeamIDs: []int64{1}, } - allowed, err = git_model.IsUserAllowedModifyTag(pt, 1) + allowed, err = git_model.IsUserAllowedModifyTag(db.DefaultContext, pt, 1) assert.NoError(t, err) assert.False(t, allowed) - allowed, err = git_model.IsUserAllowedModifyTag(pt, 2) + allowed, err = git_model.IsUserAllowedModifyTag(db.DefaultContext, pt, 2) assert.NoError(t, err) assert.True(t, allowed) @@ -47,11 +48,11 @@ func TestIsUserAllowed(t *testing.T) { AllowlistUserIDs: []int64{1}, AllowlistTeamIDs: []int64{1}, } - allowed, err = git_model.IsUserAllowedModifyTag(pt, 1) + allowed, err = git_model.IsUserAllowedModifyTag(db.DefaultContext, pt, 1) assert.NoError(t, err) assert.True(t, allowed) - allowed, err = git_model.IsUserAllowedModifyTag(pt, 2) + allowed, err = git_model.IsUserAllowedModifyTag(db.DefaultContext, pt, 2) assert.NoError(t, err) assert.True(t, allowed) } @@ -135,7 +136,7 @@ func TestIsUserAllowedToControlTag(t *testing.T) { } for n, c := range cases { - isAllowed, err := git_model.IsUserAllowedToControlTag(protectedTags, c.name, c.userid) + isAllowed, err := git_model.IsUserAllowedToControlTag(db.DefaultContext, protectedTags, c.name, c.userid) assert.NoError(t, err) assert.Equal(t, c.allowed, isAllowed, "case %d: error should match", n) } @@ -157,7 +158,7 @@ func TestIsUserAllowedToControlTag(t *testing.T) { } for n, c := range cases { - isAllowed, err := git_model.IsUserAllowedToControlTag(protectedTags, c.name, c.userid) + isAllowed, err := git_model.IsUserAllowedToControlTag(db.DefaultContext, protectedTags, c.name, c.userid) assert.NoError(t, err) assert.Equal(t, c.allowed, isAllowed, "case %d: error should match", n) } diff --git a/models/issues/assignees.go b/models/issues/assignees.go index ce497b116d..19480fa1e1 100644 --- a/models/issues/assignees.go +++ b/models/issues/assignees.go @@ -48,9 +48,10 @@ func (issue *Issue) LoadAssignees(ctx context.Context) (err error) { // GetAssigneeIDsByIssue returns the IDs of users assigned to an issue // but skips joining with `user` for performance reasons. // User permissions must be verified elsewhere if required. -func GetAssigneeIDsByIssue(issueID int64) ([]int64, error) { +func GetAssigneeIDsByIssue(ctx context.Context, issueID int64) ([]int64, error) { userIDs := make([]int64, 0, 5) - return userIDs, db.GetEngine(db.DefaultContext).Table("issue_assignees"). + return userIDs, db.GetEngine(ctx). + Table("issue_assignees"). Cols("assignee_id"). Where("issue_id = ?", issueID). Distinct("assignee_id"). @@ -151,7 +152,7 @@ func toggleUserAssignee(ctx context.Context, issue *Issue, assigneeID int64) (re } // MakeIDsFromAPIAssigneesToAdd returns an array with all assignee IDs -func MakeIDsFromAPIAssigneesToAdd(oneAssignee string, multipleAssignees []string) (assigneeIDs []int64, err error) { +func MakeIDsFromAPIAssigneesToAdd(ctx context.Context, oneAssignee string, multipleAssignees []string) (assigneeIDs []int64, err error) { var requestAssignees []string // Keeping the old assigning method for compatibility reasons @@ -165,7 +166,7 @@ func MakeIDsFromAPIAssigneesToAdd(oneAssignee string, multipleAssignees []string } // Get the IDs of all assignees - assigneeIDs, err = user_model.GetUserIDsByNames(requestAssignees, false) + assigneeIDs, err = user_model.GetUserIDsByNames(ctx, requestAssignees, false) return assigneeIDs, err } diff --git a/models/issues/assignees_test.go b/models/issues/assignees_test.go index 291bb673da..4286bdd7ee 100644 --- a/models/issues/assignees_test.go +++ b/models/issues/assignees_test.go @@ -71,22 +71,22 @@ func TestMakeIDsFromAPIAssigneesToAdd(t *testing.T) { _ = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) _ = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) - IDs, err := issues_model.MakeIDsFromAPIAssigneesToAdd("", []string{""}) + IDs, err := issues_model.MakeIDsFromAPIAssigneesToAdd(db.DefaultContext, "", []string{""}) assert.NoError(t, err) assert.Equal(t, []int64{}, IDs) - _, err = issues_model.MakeIDsFromAPIAssigneesToAdd("", []string{"none_existing_user"}) + _, err = issues_model.MakeIDsFromAPIAssigneesToAdd(db.DefaultContext, "", []string{"none_existing_user"}) assert.Error(t, err) - IDs, err = issues_model.MakeIDsFromAPIAssigneesToAdd("user1", []string{"user1"}) + IDs, err = issues_model.MakeIDsFromAPIAssigneesToAdd(db.DefaultContext, "user1", []string{"user1"}) assert.NoError(t, err) assert.Equal(t, []int64{1}, IDs) - IDs, err = issues_model.MakeIDsFromAPIAssigneesToAdd("user2", []string{""}) + IDs, err = issues_model.MakeIDsFromAPIAssigneesToAdd(db.DefaultContext, "user2", []string{""}) assert.NoError(t, err) assert.Equal(t, []int64{2}, IDs) - IDs, err = issues_model.MakeIDsFromAPIAssigneesToAdd("", []string{"user1", "user2"}) + IDs, err = issues_model.MakeIDsFromAPIAssigneesToAdd(db.DefaultContext, "", []string{"user1", "user2"}) assert.NoError(t, err) assert.Equal(t, []int64{1, 2}, IDs) } diff --git a/models/issues/comment.go b/models/issues/comment.go index d71c675d23..9483814a19 100644 --- a/models/issues/comment.go +++ b/models/issues/comment.go @@ -309,13 +309,8 @@ type PushActionContent struct { CommitIDs []string `json:"commit_ids"` } -// LoadIssue loads issue from database -func (c *Comment) LoadIssue() (err error) { - return c.LoadIssueCtx(db.DefaultContext) -} - -// LoadIssueCtx loads issue from database -func (c *Comment) LoadIssueCtx(ctx context.Context) (err error) { +// LoadIssue loads the issue reference for the comment +func (c *Comment) LoadIssue(ctx context.Context) (err error) { if c.Issue != nil { return nil } @@ -350,7 +345,8 @@ func (c *Comment) AfterLoad(session *xorm.Session) { } } -func (c *Comment) loadPoster(ctx context.Context) (err error) { +// LoadPoster loads comment poster +func (c *Comment) LoadPoster(ctx context.Context) (err error) { if c.PosterID <= 0 || c.Poster != nil { return nil } @@ -381,7 +377,7 @@ func (c *Comment) AfterDelete() { // HTMLURL formats a URL-string to the issue-comment func (c *Comment) HTMLURL() string { - err := c.LoadIssue() + err := c.LoadIssue(db.DefaultContext) if err != nil { // Silently dropping errors :unamused: log.Error("LoadIssue(%d): %v", c.IssueID, err) return "" @@ -410,7 +406,7 @@ func (c *Comment) HTMLURL() string { // APIURL formats a API-string to the issue-comment func (c *Comment) APIURL() string { - err := c.LoadIssue() + err := c.LoadIssue(db.DefaultContext) if err != nil { // Silently dropping errors :unamused: log.Error("LoadIssue(%d): %v", c.IssueID, err) return "" @@ -426,7 +422,7 @@ func (c *Comment) APIURL() string { // IssueURL formats a URL-string to the issue func (c *Comment) IssueURL() string { - err := c.LoadIssue() + err := c.LoadIssue(db.DefaultContext) if err != nil { // Silently dropping errors :unamused: log.Error("LoadIssue(%d): %v", c.IssueID, err) return "" @@ -446,7 +442,7 @@ func (c *Comment) IssueURL() string { // PRURL formats a URL-string to the pull-request func (c *Comment) PRURL() string { - err := c.LoadIssue() + err := c.LoadIssue(db.DefaultContext) if err != nil { // Silently dropping errors :unamused: log.Error("LoadIssue(%d): %v", c.IssueID, err) return "" @@ -521,10 +517,10 @@ func (c *Comment) LoadProject() error { } // LoadMilestone if comment.Type is CommentTypeMilestone, then load milestone -func (c *Comment) LoadMilestone() error { +func (c *Comment) LoadMilestone(ctx context.Context) error { if c.OldMilestoneID > 0 { var oldMilestone Milestone - has, err := db.GetEngine(db.DefaultContext).ID(c.OldMilestoneID).Get(&oldMilestone) + has, err := db.GetEngine(ctx).ID(c.OldMilestoneID).Get(&oldMilestone) if err != nil { return err } else if has { @@ -534,7 +530,7 @@ func (c *Comment) LoadMilestone() error { if c.MilestoneID > 0 { var milestone Milestone - has, err := db.GetEngine(db.DefaultContext).ID(c.MilestoneID).Get(&milestone) + has, err := db.GetEngine(ctx).ID(c.MilestoneID).Get(&milestone) if err != nil { return err } else if has { @@ -544,19 +540,14 @@ func (c *Comment) LoadMilestone() error { return nil } -// LoadPoster loads comment poster -func (c *Comment) LoadPoster() error { - return c.loadPoster(db.DefaultContext) -} - // LoadAttachments loads attachments (it never returns error, the error during `GetAttachmentsByCommentIDCtx` is ignored) -func (c *Comment) LoadAttachments() error { +func (c *Comment) LoadAttachments(ctx context.Context) error { if len(c.Attachments) > 0 { return nil } var err error - c.Attachments, err = repo_model.GetAttachmentsByCommentID(db.DefaultContext, c.ID) + c.Attachments, err = repo_model.GetAttachmentsByCommentID(ctx, c.ID) if err != nil { log.Error("getAttachmentsByCommentID[%d]: %v", c.ID, err) } @@ -598,7 +589,7 @@ func (c *Comment) LoadAssigneeUserAndTeam() error { c.Assignee = user_model.NewGhostUser() } } else if c.AssigneeTeamID > 0 && c.AssigneeTeam == nil { - if err = c.LoadIssue(); err != nil { + if err = c.LoadIssue(db.DefaultContext); err != nil { return err } @@ -740,7 +731,7 @@ func (c *Comment) UnsignedLine() uint64 { // CodeCommentURL returns the url to a comment in code func (c *Comment) CodeCommentURL() string { - err := c.LoadIssue() + err := c.LoadIssue(db.DefaultContext) if err != nil { // Silently dropping errors :unamused: log.Error("LoadIssue(%d): %v", c.IssueID, err) return "" @@ -1145,7 +1136,7 @@ func UpdateComment(c *Comment, doer *user_model.User) error { if _, err := sess.ID(c.ID).AllCols().Update(c); err != nil { return err } - if err := c.LoadIssueCtx(ctx); err != nil { + if err := c.LoadIssue(ctx); err != nil { return err } if err := c.AddCrossReferences(ctx, doer, true); err != nil { @@ -1245,7 +1236,7 @@ func findCodeComments(ctx context.Context, opts FindCommentsOptions, issue *Issu return nil, err } - if err := CommentList(comments).loadPosters(ctx); err != nil { + if err := CommentList(comments).LoadPosters(ctx); err != nil { return nil, err } @@ -1363,11 +1354,11 @@ func CreateAutoMergeComment(ctx context.Context, typ CommentType, pr *PullReques if typ != CommentTypePRScheduledToAutoMerge && typ != CommentTypePRUnScheduledToAutoMerge { return nil, fmt.Errorf("comment type %d cannot be used to create an auto merge comment", typ) } - if err = pr.LoadIssueCtx(ctx); err != nil { + if err = pr.LoadIssue(ctx); err != nil { return } - if err = pr.LoadBaseRepoCtx(ctx); err != nil { + if err = pr.LoadBaseRepo(ctx); err != nil { return } @@ -1512,18 +1503,18 @@ func (c *Comment) GetExternalName() string { return c.OriginalAuthor } func (c *Comment) GetExternalID() int64 { return c.OriginalAuthorID } // CountCommentTypeLabelWithEmptyLabel count label comments with empty label -func CountCommentTypeLabelWithEmptyLabel() (int64, error) { - return db.GetEngine(db.DefaultContext).Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Count(new(Comment)) +func CountCommentTypeLabelWithEmptyLabel(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Count(new(Comment)) } // FixCommentTypeLabelWithEmptyLabel count label comments with empty label -func FixCommentTypeLabelWithEmptyLabel() (int64, error) { - return db.GetEngine(db.DefaultContext).Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Delete(new(Comment)) +func FixCommentTypeLabelWithEmptyLabel(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Delete(new(Comment)) } // CountCommentTypeLabelWithOutsideLabels count label comments with outside label -func CountCommentTypeLabelWithOutsideLabels() (int64, error) { - return db.GetEngine(db.DefaultContext).Where("comment.type = ? AND ((label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id))", CommentTypeLabel). +func CountCommentTypeLabelWithOutsideLabels(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where("comment.type = ? AND ((label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id))", CommentTypeLabel). Table("comment"). Join("inner", "label", "label.id = comment.label_id"). Join("inner", "issue", "issue.id = comment.issue_id "). @@ -1532,8 +1523,8 @@ func CountCommentTypeLabelWithOutsideLabels() (int64, error) { } // FixCommentTypeLabelWithOutsideLabels count label comments with outside label -func FixCommentTypeLabelWithOutsideLabels() (int64, error) { - res, err := db.GetEngine(db.DefaultContext).Exec(`DELETE FROM comment WHERE comment.id IN ( +func FixCommentTypeLabelWithOutsideLabels(ctx context.Context) (int64, error) { + res, err := db.GetEngine(ctx).Exec(`DELETE FROM comment WHERE comment.id IN ( SELECT il_too.id FROM ( SELECT com.id FROM comment AS com diff --git a/models/issues/comment_list.go b/models/issues/comment_list.go index 70105d7ff0..e42b8605f9 100644 --- a/models/issues/comment_list.go +++ b/models/issues/comment_list.go @@ -24,7 +24,8 @@ func (comments CommentList) getPosterIDs() []int64 { return posterIDs.Values() } -func (comments CommentList) loadPosters(ctx context.Context) error { +// LoadPosters loads posters +func (comments CommentList) LoadPosters(ctx context.Context) error { if len(comments) == 0 { return nil } @@ -277,7 +278,8 @@ func (comments CommentList) Issues() IssueList { return issueList } -func (comments CommentList) loadIssues(ctx context.Context) error { +// LoadIssues loads issues of comments +func (comments CommentList) LoadIssues(ctx context.Context) error { if len(comments) == 0 { return nil } @@ -382,7 +384,8 @@ func (comments CommentList) loadDependentIssues(ctx context.Context) error { return nil } -func (comments CommentList) loadAttachments(ctx context.Context) (err error) { +// LoadAttachments loads attachments +func (comments CommentList) LoadAttachments(ctx context.Context) (err error) { if len(comments) == 0 { return nil } @@ -476,7 +479,7 @@ func (comments CommentList) loadReviews(ctx context.Context) error { //nolint // loadAttributes loads all attributes func (comments CommentList) loadAttributes(ctx context.Context) (err error) { - if err = comments.loadPosters(ctx); err != nil { + if err = comments.LoadPosters(ctx); err != nil { return } @@ -496,7 +499,7 @@ func (comments CommentList) loadAttributes(ctx context.Context) (err error) { return } - if err = comments.loadAttachments(ctx); err != nil { + if err = comments.LoadAttachments(ctx); err != nil { return } @@ -504,7 +507,7 @@ func (comments CommentList) loadAttributes(ctx context.Context) (err error) { return } - if err = comments.loadIssues(ctx); err != nil { + if err = comments.LoadIssues(ctx); err != nil { return } @@ -520,18 +523,3 @@ func (comments CommentList) loadAttributes(ctx context.Context) (err error) { func (comments CommentList) LoadAttributes() error { return comments.loadAttributes(db.DefaultContext) } - -// LoadAttachments loads attachments -func (comments CommentList) LoadAttachments() error { - return comments.loadAttachments(db.DefaultContext) -} - -// LoadPosters loads posters -func (comments CommentList) LoadPosters() error { - return comments.loadPosters(db.DefaultContext) -} - -// LoadIssues loads issues of comments -func (comments CommentList) LoadIssues() error { - return comments.loadIssues(db.DefaultContext) -} diff --git a/models/issues/issue.go b/models/issues/issue.go index c2f7cb6578..69d6657d46 100644 --- a/models/issues/issue.go +++ b/models/issues/issue.go @@ -241,11 +241,7 @@ func (issue *Issue) LoadLabels(ctx context.Context) (err error) { } // LoadPoster loads poster -func (issue *Issue) LoadPoster() error { - return issue.loadPoster(db.DefaultContext) -} - -func (issue *Issue) loadPoster(ctx context.Context) (err error) { +func (issue *Issue) LoadPoster(ctx context.Context) (err error) { if issue.Poster == nil { issue.Poster, err = user_model.GetUserByIDCtx(ctx, issue.PosterID) if err != nil { @@ -261,7 +257,8 @@ func (issue *Issue) loadPoster(ctx context.Context) (err error) { return err } -func (issue *Issue) loadPullRequest(ctx context.Context) (err error) { +// LoadPullRequest loads pull request info +func (issue *Issue) LoadPullRequest(ctx context.Context) (err error) { if issue.IsPull && issue.PullRequest == nil { issue.PullRequest, err = GetPullRequestByIssueID(ctx, issue.ID) if err != nil { @@ -275,18 +272,13 @@ func (issue *Issue) loadPullRequest(ctx context.Context) (err error) { return nil } -// LoadPullRequest loads pull request info -func (issue *Issue) LoadPullRequest() error { - return issue.loadPullRequest(db.DefaultContext) -} - func (issue *Issue) loadComments(ctx context.Context) (err error) { return issue.loadCommentsByType(ctx, CommentTypeUnknown) } // LoadDiscussComments loads discuss comments -func (issue *Issue) LoadDiscussComments() error { - return issue.loadCommentsByType(db.DefaultContext, CommentTypeComment) +func (issue *Issue) LoadDiscussComments(ctx context.Context) error { + return issue.loadCommentsByType(ctx, CommentTypeComment) } func (issue *Issue) loadCommentsByType(ctx context.Context, tp CommentType) (err error) { @@ -357,7 +349,8 @@ func (issue *Issue) loadForeignReference(ctx context.Context) (err error) { return nil } -func (issue *Issue) loadMilestone(ctx context.Context) (err error) { +// LoadMilestone load milestone of this issue. +func (issue *Issue) LoadMilestone(ctx context.Context) (err error) { if (issue.Milestone == nil || issue.Milestone.ID != issue.MilestoneID) && issue.MilestoneID > 0 { issue.Milestone, err = GetMilestoneByRepoID(ctx, issue.RepoID, issue.MilestoneID) if err != nil && !IsErrMilestoneNotExist(err) { @@ -373,7 +366,7 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) { return } - if err = issue.loadPoster(ctx); err != nil { + if err = issue.LoadPoster(ctx); err != nil { return } @@ -381,7 +374,7 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) { return } - if err = issue.loadMilestone(ctx); err != nil { + if err = issue.LoadMilestone(ctx); err != nil { return } @@ -393,7 +386,7 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) { return } - if err = issue.loadPullRequest(ctx); err != nil && !IsErrPullRequestNotExist(err) { + if err = issue.LoadPullRequest(ctx); err != nil && !IsErrPullRequestNotExist(err) { // It is possible pull request is not yet created. return err } @@ -425,11 +418,6 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) { return issue.loadReactions(ctx) } -// LoadMilestone load milestone of this issue. -func (issue *Issue) LoadMilestone() error { - return issue.loadMilestone(db.DefaultContext) -} - // GetIsRead load the `IsRead` field of the issue func (issue *Issue) GetIsRead(userID int64) error { issueUser := &IssueUser{IssueID: issue.ID, UID: userID} @@ -548,7 +536,7 @@ func ClearIssueLabels(issue *Issue, doer *user_model.User) (err error) { if err := issue.LoadRepo(ctx); err != nil { return err - } else if err = issue.loadPullRequest(ctx); err != nil { + } else if err = issue.LoadPullRequest(ctx); err != nil { return err } @@ -751,7 +739,7 @@ func ChangeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.User, if err := issue.LoadRepo(ctx); err != nil { return nil, err } - if err := issue.loadPoster(ctx); err != nil { + if err := issue.LoadPoster(ctx); err != nil { return nil, err } @@ -1027,7 +1015,7 @@ func NewIssueWithIndex(ctx context.Context, doer *user_model.User, opts NewIssue return fmt.Errorf("find all labels [label_ids: %v]: %w", opts.LabelIDs, err) } - if err = opts.Issue.loadPoster(ctx); err != nil { + if err = opts.Issue.LoadPoster(ctx); err != nil { return err } @@ -1505,10 +1493,9 @@ func applySubscribedCondition(sess *xorm.Session, subscriberID int64) *xorm.Sess } // CountIssuesByRepo map from repoID to number of issues matching the options -func CountIssuesByRepo(opts *IssuesOptions) (map[int64]int64, error) { - e := db.GetEngine(db.DefaultContext) - - sess := e.Join("INNER", "repository", "`issue`.repo_id = `repository`.id") +func CountIssuesByRepo(ctx context.Context, opts *IssuesOptions) (map[int64]int64, error) { + sess := db.GetEngine(ctx). + Join("INNER", "repository", "`issue`.repo_id = `repository`.id") opts.setupSessionNoLimit(sess) @@ -1551,10 +1538,9 @@ func GetRepoIDsForIssuesOptions(opts *IssuesOptions, user *user_model.User) ([]i } // Issues returns a list of issues by given conditions. -func Issues(opts *IssuesOptions) ([]*Issue, error) { - e := db.GetEngine(db.DefaultContext) - - sess := e.Join("INNER", "repository", "`issue`.repo_id = `repository`.id") +func Issues(ctx context.Context, opts *IssuesOptions) ([]*Issue, error) { + sess := db.GetEngine(ctx). + Join("INNER", "repository", "`issue`.repo_id = `repository`.id") opts.setupSessionWithLimit(sess) sortIssuesSession(sess, opts.SortType, opts.PriorityRepoID) @@ -1572,11 +1558,11 @@ func Issues(opts *IssuesOptions) ([]*Issue, error) { } // CountIssues number return of issues by given conditions. -func CountIssues(opts *IssuesOptions) (int64, error) { - e := db.GetEngine(db.DefaultContext) - - sess := e.Select("COUNT(issue.id) AS count").Table("issue") - sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id") +func CountIssues(ctx context.Context, opts *IssuesOptions) (int64, error) { + sess := db.GetEngine(ctx). + Select("COUNT(issue.id) AS count"). + Table("issue"). + Join("INNER", "repository", "`issue`.repo_id = `repository`.id") opts.setupSessionNoLimit(sess) return sess.Count() @@ -1585,9 +1571,10 @@ func CountIssues(opts *IssuesOptions) (int64, error) { // GetParticipantsIDsByIssueID returns the IDs of all users who participated in comments of an issue, // but skips joining with `user` for performance reasons. // User permissions must be verified elsewhere if required. -func GetParticipantsIDsByIssueID(issueID int64) ([]int64, error) { +func GetParticipantsIDsByIssueID(ctx context.Context, issueID int64) ([]int64, error) { userIDs := make([]int64, 0, 5) - return userIDs, db.GetEngine(db.DefaultContext).Table("comment"). + return userIDs, db.GetEngine(ctx). + Table("comment"). Cols("poster_id"). Where("issue_id = ?", issueID). And("type in (?,?,?)", CommentTypeComment, CommentTypeCode, CommentTypeReview). @@ -2426,8 +2413,9 @@ func (issue *Issue) GetExternalName() string { return issue.OriginalAuthor } func (issue *Issue) GetExternalID() int64 { return issue.OriginalAuthorID } // CountOrphanedIssues count issues without a repo -func CountOrphanedIssues() (int64, error) { - return db.GetEngine(db.DefaultContext).Table("issue"). +func CountOrphanedIssues(ctx context.Context) (int64, error) { + return db.GetEngine(ctx). + Table("issue"). Join("LEFT", "repository", "issue.repo_id=repository.id"). Where(builder.IsNull{"repository.id"}). Select("COUNT(`issue`.`id`)"). @@ -2435,35 +2423,31 @@ func CountOrphanedIssues() (int64, error) { } // DeleteOrphanedIssues delete issues without a repo -func DeleteOrphanedIssues() error { - ctx, committer, err := db.TxContext(db.DefaultContext) - if err != nil { - return err - } - defer committer.Close() - - var ids []int64 - - if err := db.GetEngine(ctx).Table("issue").Distinct("issue.repo_id"). - Join("LEFT", "repository", "issue.repo_id=repository.id"). - Where(builder.IsNull{"repository.id"}).GroupBy("issue.repo_id"). - Find(&ids); err != nil { - return err - } - +func DeleteOrphanedIssues(ctx context.Context) error { var attachmentPaths []string - for i := range ids { - paths, err := DeleteIssuesByRepoID(ctx, ids[i]) - if err != nil { + err := db.AutoTx(ctx, func(ctx context.Context) error { + var ids []int64 + + if err := db.GetEngine(ctx).Table("issue").Distinct("issue.repo_id"). + Join("LEFT", "repository", "issue.repo_id=repository.id"). + Where(builder.IsNull{"repository.id"}).GroupBy("issue.repo_id"). + Find(&ids); err != nil { return err } - attachmentPaths = append(attachmentPaths, paths...) - } - if err := committer.Commit(); err != nil { + for i := range ids { + paths, err := DeleteIssuesByRepoID(ctx, ids[i]) + if err != nil { + return err + } + attachmentPaths = append(attachmentPaths, paths...) + } + + return nil + }) + if err != nil { return err } - committer.Close() // Remove issue attachment files. for i := range attachmentPaths { diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go index bbe2292dd1..d9dff4cb4d 100644 --- a/models/issues/issue_list.go +++ b/models/issues/issue_list.go @@ -34,7 +34,8 @@ func (issues IssueList) getRepoIDs() []int64 { return repoIDs.Values() } -func (issues IssueList) loadRepositories(ctx context.Context) ([]*repo_model.Repository, error) { +// LoadRepositories loads issues' all repositories +func (issues IssueList) LoadRepositories(ctx context.Context) ([]*repo_model.Repository, error) { if len(issues) == 0 { return nil, nil } @@ -73,11 +74,6 @@ func (issues IssueList) loadRepositories(ctx context.Context) ([]*repo_model.Rep return repo_model.ValuesRepository(repoMaps), nil } -// LoadRepositories loads issues' all repositories -func (issues IssueList) LoadRepositories() ([]*repo_model.Repository, error) { - return issues.loadRepositories(db.DefaultContext) -} - func (issues IssueList) getPosterIDs() []int64 { posterIDs := make(container.Set[int64], len(issues)) for _, issue := range issues { @@ -317,7 +313,8 @@ func (issues IssueList) getPullIssueIDs() []int64 { return ids } -func (issues IssueList) loadPullRequests(ctx context.Context) error { +// LoadPullRequests loads pull requests +func (issues IssueList) LoadPullRequests(ctx context.Context) error { issuesIDs := issues.getPullIssueIDs() if len(issuesIDs) == 0 { return nil @@ -361,7 +358,8 @@ func (issues IssueList) loadPullRequests(ctx context.Context) error { return nil } -func (issues IssueList) loadAttachments(ctx context.Context) (err error) { +// LoadAttachments loads attachments +func (issues IssueList) LoadAttachments(ctx context.Context) (err error) { if len(issues) == 0 { return nil } @@ -513,8 +511,8 @@ func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) { // loadAttributes loads all attributes, expect for attachments and comments func (issues IssueList) loadAttributes(ctx context.Context) error { - if _, err := issues.loadRepositories(ctx); err != nil { - return fmt.Errorf("issue.loadAttributes: loadRepositories: %w", err) + if _, err := issues.LoadRepositories(ctx); err != nil { + return fmt.Errorf("issue.loadAttributes: LoadRepositories: %w", err) } if err := issues.loadPosters(ctx); err != nil { @@ -537,7 +535,7 @@ func (issues IssueList) loadAttributes(ctx context.Context) error { return fmt.Errorf("issue.loadAttributes: loadAssignees: %w", err) } - if err := issues.loadPullRequests(ctx); err != nil { + if err := issues.LoadPullRequests(ctx); err != nil { return fmt.Errorf("issue.loadAttributes: loadPullRequests: %w", err) } @@ -554,24 +552,14 @@ func (issues IssueList) LoadAttributes() error { return issues.loadAttributes(db.DefaultContext) } -// LoadAttachments loads attachments -func (issues IssueList) LoadAttachments() error { - return issues.loadAttachments(db.DefaultContext) -} - // LoadComments loads comments -func (issues IssueList) LoadComments() error { - return issues.loadComments(db.DefaultContext, builder.NewCond()) +func (issues IssueList) LoadComments(ctx context.Context) error { + return issues.loadComments(ctx, builder.NewCond()) } // LoadDiscussComments loads discuss comments -func (issues IssueList) LoadDiscussComments() error { - return issues.loadComments(db.DefaultContext, builder.Eq{"comment.type": CommentTypeComment}) -} - -// LoadPullRequests loads pull requests -func (issues IssueList) LoadPullRequests() error { - return issues.loadPullRequests(db.DefaultContext) +func (issues IssueList) LoadDiscussComments(ctx context.Context) error { + return issues.loadComments(ctx, builder.Eq{"comment.type": CommentTypeComment}) } // GetApprovalCounts returns a map of issue ID to slice of approval counts diff --git a/models/issues/issue_list_test.go b/models/issues/issue_list_test.go index f2cfca9bc0..c38a405e02 100644 --- a/models/issues/issue_list_test.go +++ b/models/issues/issue_list_test.go @@ -7,6 +7,7 @@ package issues_test import ( "testing" + "code.gitea.io/gitea/models/db" issues_model "code.gitea.io/gitea/models/issues" "code.gitea.io/gitea/models/unittest" "code.gitea.io/gitea/modules/setting" @@ -23,7 +24,7 @@ func TestIssueList_LoadRepositories(t *testing.T) { unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 4}), } - repos, err := issueList.LoadRepositories() + repos, err := issueList.LoadRepositories(db.DefaultContext) assert.NoError(t, err) assert.Len(t, repos, 2) for _, issue := range issueList { diff --git a/models/issues/issue_project.go b/models/issues/issue_project.go index 0c59f4e82b..39a27abd3e 100644 --- a/models/issues/issue_project.go +++ b/models/issues/issue_project.go @@ -61,11 +61,11 @@ func (issue *Issue) projectBoardID(ctx context.Context) int64 { } // LoadIssuesFromBoard load issues assigned to this board -func LoadIssuesFromBoard(b *project_model.Board) (IssueList, error) { +func LoadIssuesFromBoard(ctx context.Context, b *project_model.Board) (IssueList, error) { issueList := make([]*Issue, 0, 10) if b.ID != 0 { - issues, err := Issues(&IssuesOptions{ + issues, err := Issues(ctx, &IssuesOptions{ ProjectBoardID: b.ID, ProjectID: b.ProjectID, SortType: "project-column-sorting", @@ -77,7 +77,7 @@ func LoadIssuesFromBoard(b *project_model.Board) (IssueList, error) { } if b.Default { - issues, err := Issues(&IssuesOptions{ + issues, err := Issues(ctx, &IssuesOptions{ ProjectBoardID: -1, // Issues without ProjectBoardID ProjectID: b.ProjectID, SortType: "project-column-sorting", @@ -88,7 +88,7 @@ func LoadIssuesFromBoard(b *project_model.Board) (IssueList, error) { issueList = append(issueList, issues...) } - if err := IssueList(issueList).LoadComments(); err != nil { + if err := IssueList(issueList).LoadComments(ctx); err != nil { return nil, err } @@ -96,10 +96,10 @@ func LoadIssuesFromBoard(b *project_model.Board) (IssueList, error) { } // LoadIssuesFromBoardList load issues assigned to the boards -func LoadIssuesFromBoardList(bs project_model.BoardList) (map[int64]IssueList, error) { +func LoadIssuesFromBoardList(ctx context.Context, bs project_model.BoardList) (map[int64]IssueList, error) { issuesMap := make(map[int64]IssueList, len(bs)) for i := range bs { - il, err := LoadIssuesFromBoard(bs[i]) + il, err := LoadIssuesFromBoard(ctx, bs[i]) if err != nil { return nil, err } diff --git a/models/issues/issue_test.go b/models/issues/issue_test.go index bef5d03e8a..2c8728e71a 100644 --- a/models/issues/issue_test.go +++ b/models/issues/issue_test.go @@ -189,7 +189,7 @@ func TestIssues(t *testing.T) { []int64{}, // issues with **both** label 1 and 2, none of these issues matches, TODO: add more tests }, } { - issues, err := issues_model.Issues(&test.Opts) + issues, err := issues_model.Issues(db.DefaultContext, &test.Opts) assert.NoError(t, err) if assert.Len(t, issues, len(test.ExpectedIssueIDs)) { for i, issue := range issues { @@ -556,7 +556,7 @@ func TestLoadTotalTrackedTime(t *testing.T) { func TestCountIssues(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - count, err := issues_model.CountIssues(&issues_model.IssuesOptions{}) + count, err := issues_model.CountIssues(db.DefaultContext, &issues_model.IssuesOptions{}) assert.NoError(t, err) assert.EqualValues(t, 17, count) } diff --git a/models/issues/issue_xref.go b/models/issues/issue_xref.go index e389f63d72..4c6601a0a2 100644 --- a/models/issues/issue_xref.go +++ b/models/issues/issue_xref.go @@ -235,7 +235,7 @@ func (c *Comment) AddCrossReferences(stdCtx context.Context, doer *user_model.Us if c.Type != CommentTypeCode && c.Type != CommentTypeComment { return nil } - if err := c.LoadIssueCtx(stdCtx); err != nil { + if err := c.LoadIssue(stdCtx); err != nil { return err } ctx := &crossReferencesContext{ diff --git a/models/issues/label.go b/models/issues/label.go index 0b0d1419b1..dc7058d643 100644 --- a/models/issues/label.go +++ b/models/issues/label.go @@ -116,8 +116,8 @@ func (label *Label) CalOpenIssues() { } // CalOpenOrgIssues calculates the open issues of a label for a specific repo -func (label *Label) CalOpenOrgIssues(repoID, labelID int64) { - counts, _ := CountIssuesByRepo(&IssuesOptions{ +func (label *Label) CalOpenOrgIssues(ctx context.Context, repoID, labelID int64) { + counts, _ := CountIssuesByRepo(ctx, &IssuesOptions{ RepoID: repoID, LabelIDs: []int64{labelID}, IsClosed: util.OptionalBoolFalse, @@ -395,9 +395,9 @@ func BuildLabelNamesIssueIDsCondition(labelNames []string) *builder.Builder { // GetLabelsInRepoByIDs returns a list of labels by IDs in given repository, // it silently ignores label IDs that do not belong to the repository. -func GetLabelsInRepoByIDs(repoID int64, labelIDs []int64) ([]*Label, error) { +func GetLabelsInRepoByIDs(ctx context.Context, repoID int64, labelIDs []int64) ([]*Label, error) { labels := make([]*Label, 0, len(labelIDs)) - return labels, db.GetEngine(db.DefaultContext). + return labels, db.GetEngine(ctx). Where("repo_id = ?", repoID). In("id", labelIDs). Asc("name"). @@ -498,9 +498,9 @@ func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error) // GetLabelsInOrgByIDs returns a list of labels by IDs in given organization, // it silently ignores label IDs that do not belong to the organization. -func GetLabelsInOrgByIDs(orgID int64, labelIDs []int64) ([]*Label, error) { +func GetLabelsInOrgByIDs(ctx context.Context, orgID int64, labelIDs []int64) ([]*Label, error) { labels := make([]*Label, 0, len(labelIDs)) - return labels, db.GetEngine(db.DefaultContext). + return labels, db.GetEngine(ctx). Where("org_id = ?", orgID). In("id", labelIDs). Asc("name"). @@ -746,13 +746,13 @@ func DeleteLabelsByRepoID(ctx context.Context, repoID int64) error { } // CountOrphanedLabels return count of labels witch are broken and not accessible via ui anymore -func CountOrphanedLabels() (int64, error) { - noref, err := db.GetEngine(db.DefaultContext).Table("label").Where("repo_id=? AND org_id=?", 0, 0).Count() +func CountOrphanedLabels(ctx context.Context) (int64, error) { + noref, err := db.GetEngine(ctx).Table("label").Where("repo_id=? AND org_id=?", 0, 0).Count() if err != nil { return 0, err } - norepo, err := db.GetEngine(db.DefaultContext).Table("label"). + norepo, err := db.GetEngine(ctx).Table("label"). Where(builder.And( builder.Gt{"repo_id": 0}, builder.NotIn("repo_id", builder.Select("id").From("repository")), @@ -762,7 +762,7 @@ func CountOrphanedLabels() (int64, error) { return 0, err } - noorg, err := db.GetEngine(db.DefaultContext).Table("label"). + noorg, err := db.GetEngine(ctx).Table("label"). Where(builder.And( builder.Gt{"org_id": 0}, builder.NotIn("org_id", builder.Select("id").From("user")), @@ -776,14 +776,14 @@ func CountOrphanedLabels() (int64, error) { } // DeleteOrphanedLabels delete labels witch are broken and not accessible via ui anymore -func DeleteOrphanedLabels() error { +func DeleteOrphanedLabels(ctx context.Context) error { // delete labels with no reference - if _, err := db.GetEngine(db.DefaultContext).Table("label").Where("repo_id=? AND org_id=?", 0, 0).Delete(new(Label)); err != nil { + if _, err := db.GetEngine(ctx).Table("label").Where("repo_id=? AND org_id=?", 0, 0).Delete(new(Label)); err != nil { return err } // delete labels with none existing repos - if _, err := db.GetEngine(db.DefaultContext). + if _, err := db.GetEngine(ctx). Where(builder.And( builder.Gt{"repo_id": 0}, builder.NotIn("repo_id", builder.Select("id").From("repository")), @@ -793,7 +793,7 @@ func DeleteOrphanedLabels() error { } // delete labels with none existing orgs - if _, err := db.GetEngine(db.DefaultContext). + if _, err := db.GetEngine(ctx). Where(builder.And( builder.Gt{"org_id": 0}, builder.NotIn("org_id", builder.Select("id").From("user")), @@ -806,23 +806,23 @@ func DeleteOrphanedLabels() error { } // CountOrphanedIssueLabels return count of IssueLabels witch have no label behind anymore -func CountOrphanedIssueLabels() (int64, error) { - return db.GetEngine(db.DefaultContext).Table("issue_label"). +func CountOrphanedIssueLabels(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Table("issue_label"). NotIn("label_id", builder.Select("id").From("label")). Count() } // DeleteOrphanedIssueLabels delete IssueLabels witch have no label behind anymore -func DeleteOrphanedIssueLabels() error { - _, err := db.GetEngine(db.DefaultContext). +func DeleteOrphanedIssueLabels(ctx context.Context) error { + _, err := db.GetEngine(ctx). NotIn("label_id", builder.Select("id").From("label")). Delete(IssueLabel{}) return err } // CountIssueLabelWithOutsideLabels count label comments with outside label -func CountIssueLabelWithOutsideLabels() (int64, error) { - return db.GetEngine(db.DefaultContext).Where(builder.Expr("(label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id)")). +func CountIssueLabelWithOutsideLabels(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where(builder.Expr("(label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id)")). Table("issue_label"). Join("inner", "label", "issue_label.label_id = label.id "). Join("inner", "issue", "issue.id = issue_label.issue_id "). @@ -831,8 +831,8 @@ func CountIssueLabelWithOutsideLabels() (int64, error) { } // FixIssueLabelWithOutsideLabels fix label comments with outside label -func FixIssueLabelWithOutsideLabels() (int64, error) { - res, err := db.GetEngine(db.DefaultContext).Exec(`DELETE FROM issue_label WHERE issue_label.id IN ( +func FixIssueLabelWithOutsideLabels(ctx context.Context) (int64, error) { + res, err := db.GetEngine(ctx).Exec(`DELETE FROM issue_label WHERE issue_label.id IN ( SELECT il_too.id FROM ( SELECT il_too_too.id FROM issue_label AS il_too_too diff --git a/models/issues/label_test.go b/models/issues/label_test.go index 5e6cc9a2a0..077e0eeb67 100644 --- a/models/issues/label_test.go +++ b/models/issues/label_test.go @@ -121,7 +121,7 @@ func TestGetLabelInRepoByID(t *testing.T) { func TestGetLabelsInRepoByIDs(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - labels, err := issues_model.GetLabelsInRepoByIDs(1, []int64{1, 2, unittest.NonexistentID}) + labels, err := issues_model.GetLabelsInRepoByIDs(db.DefaultContext, 1, []int64{1, 2, unittest.NonexistentID}) assert.NoError(t, err) if assert.Len(t, labels, 2) { assert.EqualValues(t, 1, labels[0].ID) @@ -212,7 +212,7 @@ func TestGetLabelInOrgByID(t *testing.T) { func TestGetLabelsInOrgByIDs(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - labels, err := issues_model.GetLabelsInOrgByIDs(3, []int64{3, 4, unittest.NonexistentID}) + labels, err := issues_model.GetLabelsInOrgByIDs(db.DefaultContext, 3, []int64{3, 4, unittest.NonexistentID}) assert.NoError(t, err) if assert.Len(t, labels, 2) { assert.EqualValues(t, 3, labels[0].ID) diff --git a/models/issues/pull.go b/models/issues/pull.go index e906407d31..993a1ba8bd 100644 --- a/models/issues/pull.go +++ b/models/issues/pull.go @@ -205,8 +205,8 @@ func DeletePullsByBaseRepoID(ctx context.Context, repoID int64) error { } // MustHeadUserName returns the HeadRepo's username if failed return blank -func (pr *PullRequest) MustHeadUserName() string { - if err := pr.LoadHeadRepo(); err != nil { +func (pr *PullRequest) MustHeadUserName(ctx context.Context) string { + if err := pr.LoadHeadRepo(ctx); err != nil { if !repo_model.IsErrRepoNotExist(err) { log.Error("LoadHeadRepo: %v", err) } else { @@ -220,8 +220,9 @@ func (pr *PullRequest) MustHeadUserName() string { return pr.HeadRepo.OwnerName } +// LoadAttributes loads pull request attributes from database // Note: don't try to get Issue because will end up recursive querying. -func (pr *PullRequest) loadAttributes(ctx context.Context) (err error) { +func (pr *PullRequest) LoadAttributes(ctx context.Context) (err error) { if pr.HasMerged && pr.Merger == nil { pr.Merger, err = user_model.GetUserByIDCtx(ctx, pr.MergerID) if user_model.IsErrUserNotExist(err) { @@ -235,13 +236,8 @@ func (pr *PullRequest) loadAttributes(ctx context.Context) (err error) { return nil } -// LoadAttributes loads pull request attributes from database -func (pr *PullRequest) LoadAttributes() error { - return pr.loadAttributes(db.DefaultContext) -} - -// LoadHeadRepoCtx loads the head repository -func (pr *PullRequest) LoadHeadRepoCtx(ctx context.Context) (err error) { +// LoadHeadRepo loads the head repository +func (pr *PullRequest) LoadHeadRepo(ctx context.Context) (err error) { if !pr.isHeadRepoLoaded && pr.HeadRepo == nil && pr.HeadRepoID > 0 { if pr.HeadRepoID == pr.BaseRepoID { if pr.BaseRepo != nil { @@ -262,18 +258,8 @@ func (pr *PullRequest) LoadHeadRepoCtx(ctx context.Context) (err error) { return nil } -// LoadHeadRepo loads the head repository -func (pr *PullRequest) LoadHeadRepo() error { - return pr.LoadHeadRepoCtx(db.DefaultContext) -} - // LoadBaseRepo loads the target repository -func (pr *PullRequest) LoadBaseRepo() error { - return pr.LoadBaseRepoCtx(db.DefaultContext) -} - -// LoadBaseRepoCtx loads the target repository -func (pr *PullRequest) LoadBaseRepoCtx(ctx context.Context) (err error) { +func (pr *PullRequest) LoadBaseRepo(ctx context.Context) (err error) { if pr.BaseRepo != nil { return nil } @@ -296,12 +282,7 @@ func (pr *PullRequest) LoadBaseRepoCtx(ctx context.Context) (err error) { } // LoadIssue loads issue information from database -func (pr *PullRequest) LoadIssue() (err error) { - return pr.LoadIssueCtx(db.DefaultContext) -} - -// LoadIssueCtx loads issue information from database -func (pr *PullRequest) LoadIssueCtx(ctx context.Context) (err error) { +func (pr *PullRequest) LoadIssue(ctx context.Context) (err error) { if pr.Issue != nil { return nil } @@ -392,7 +373,7 @@ func (pr *PullRequest) getReviewedByLines(writer io.Writer) error { break } - if err := review.loadReviewer(ctx); err != nil && !user_model.IsErrUserNotExist(err) { + if err := review.LoadReviewer(ctx); err != nil && !user_model.IsErrUserNotExist(err) { log.Error("Unable to LoadReviewer[%d] for PR ID %d : %v", review.ReviewerID, pr.ID, err) return err } else if review.Reviewer == nil { @@ -458,7 +439,7 @@ func (pr *PullRequest) SetMerged(ctx context.Context) (bool, error) { } pr.Issue = nil - if err := pr.LoadIssueCtx(ctx); err != nil { + if err := pr.LoadIssue(ctx); err != nil { return false, err } @@ -541,9 +522,9 @@ func NewPullRequest(outerCtx context.Context, repo *repo_model.Repository, issue // GetUnmergedPullRequest returns a pull request that is open and has not been merged // by given head/base and repo/branch. -func GetUnmergedPullRequest(headRepoID, baseRepoID int64, headBranch, baseBranch string, flow PullRequestFlow) (*PullRequest, error) { +func GetUnmergedPullRequest(ctx context.Context, headRepoID, baseRepoID int64, headBranch, baseBranch string, flow PullRequestFlow) (*PullRequest, error) { pr := new(PullRequest) - has, err := db.GetEngine(db.DefaultContext). + has, err := db.GetEngine(ctx). Where("head_repo_id=? AND head_branch=? AND base_repo_id=? AND base_branch=? AND has_merged=? AND flow = ? AND issue.is_closed=?", headRepoID, headBranch, baseRepoID, baseBranch, false, flow, false). Join("INNER", "issue", "issue.id=pull_request.issue_id"). @@ -588,10 +569,10 @@ func GetPullRequestByIndex(ctx context.Context, repoID, index int64) (*PullReque return nil, ErrPullRequestNotExist{0, 0, 0, repoID, "", ""} } - if err = pr.loadAttributes(ctx); err != nil { + if err = pr.LoadAttributes(ctx); err != nil { return nil, err } - if err = pr.LoadIssueCtx(ctx); err != nil { + if err = pr.LoadIssue(ctx); err != nil { return nil, err } @@ -607,7 +588,7 @@ func GetPullRequestByID(ctx context.Context, id int64) (*PullRequest, error) { } else if !has { return nil, ErrPullRequestNotExist{id, 0, 0, 0, "", ""} } - return pr, pr.loadAttributes(ctx) + return pr, pr.LoadAttributes(ctx) } // GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID. @@ -634,7 +615,7 @@ func GetPullRequestByIssueID(ctx context.Context, issueID int64) (*PullRequest, } else if !has { return nil, ErrPullRequestNotExist{0, issueID, 0, 0, "", ""} } - return pr, pr.loadAttributes(ctx) + return pr, pr.LoadAttributes(ctx) } // GetAllUnmergedAgitPullRequestByPoster get all unmerged agit flow pull request @@ -664,14 +645,15 @@ func (pr *PullRequest) UpdateCols(cols ...string) error { } // UpdateColsIfNotMerged updates specific fields of a pull request if it has not been merged -func (pr *PullRequest) UpdateColsIfNotMerged(cols ...string) error { - _, err := db.GetEngine(db.DefaultContext).Where("id = ? AND has_merged = ?", pr.ID, false).Cols(cols...).Update(pr) +func (pr *PullRequest) UpdateColsIfNotMerged(ctx context.Context, cols ...string) error { + _, err := db.GetEngine(ctx).Where("id = ? AND has_merged = ?", pr.ID, false).Cols(cols...).Update(pr) return err } // IsWorkInProgress determine if the Pull Request is a Work In Progress by its title +// Issue must be set before this method can be called. func (pr *PullRequest) IsWorkInProgress() bool { - if err := pr.LoadIssue(); err != nil { + if err := pr.LoadIssue(db.DefaultContext); err != nil { log.Error("LoadIssue: %v", err) return false } @@ -695,8 +677,8 @@ func (pr *PullRequest) IsFilesConflicted() bool { // GetWorkInProgressPrefix returns the prefix used to mark the pull request as a work in progress. // It returns an empty string when none were found -func (pr *PullRequest) GetWorkInProgressPrefix() string { - if err := pr.LoadIssue(); err != nil { +func (pr *PullRequest) GetWorkInProgressPrefix(ctx context.Context) string { + if err := pr.LoadIssue(ctx); err != nil { log.Error("LoadIssue: %v", err) return "" } @@ -739,7 +721,7 @@ func GetPullRequestsByHeadBranch(ctx context.Context, headBranch string, headRep // GetBaseBranchHTMLURL returns the HTML URL of the base branch func (pr *PullRequest) GetBaseBranchHTMLURL() string { - if err := pr.LoadBaseRepo(); err != nil { + if err := pr.LoadBaseRepo(db.DefaultContext); err != nil { log.Error("LoadBaseRepo: %v", err) return "" } @@ -755,7 +737,7 @@ func (pr *PullRequest) GetHeadBranchHTMLURL() string { return "" } - if err := pr.LoadHeadRepo(); err != nil { + if err := pr.LoadHeadRepo(db.DefaultContext); err != nil { log.Error("LoadHeadRepo: %v", err) return "" } diff --git a/models/issues/pull_list.go b/models/issues/pull_list.go index c69f18492b..6110ba77fa 100644 --- a/models/issues/pull_list.go +++ b/models/issues/pull_list.go @@ -79,7 +79,7 @@ func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user * for _, pr := range prs { if pr.AllowMaintainerEdit { - err = pr.LoadBaseRepo() + err = pr.LoadBaseRepo(db.DefaultContext) if err != nil { continue } diff --git a/models/issues/pull_test.go b/models/issues/pull_test.go index fb46e3071e..d88f9d4f54 100644 --- a/models/issues/pull_test.go +++ b/models/issues/pull_test.go @@ -17,7 +17,7 @@ import ( func TestPullRequest_LoadAttributes(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1}) - assert.NoError(t, pr.LoadAttributes()) + assert.NoError(t, pr.LoadAttributes(db.DefaultContext)) assert.NotNil(t, pr.Merger) assert.Equal(t, pr.MergerID, pr.Merger.ID) } @@ -25,10 +25,10 @@ func TestPullRequest_LoadAttributes(t *testing.T) { func TestPullRequest_LoadIssue(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1}) - assert.NoError(t, pr.LoadIssue()) + assert.NoError(t, pr.LoadIssue(db.DefaultContext)) assert.NotNil(t, pr.Issue) assert.Equal(t, int64(2), pr.Issue.ID) - assert.NoError(t, pr.LoadIssue()) + assert.NoError(t, pr.LoadIssue(db.DefaultContext)) assert.NotNil(t, pr.Issue) assert.Equal(t, int64(2), pr.Issue.ID) } @@ -36,10 +36,10 @@ func TestPullRequest_LoadIssue(t *testing.T) { func TestPullRequest_LoadBaseRepo(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1}) - assert.NoError(t, pr.LoadBaseRepo()) + assert.NoError(t, pr.LoadBaseRepo(db.DefaultContext)) assert.NotNil(t, pr.BaseRepo) assert.Equal(t, pr.BaseRepoID, pr.BaseRepo.ID) - assert.NoError(t, pr.LoadBaseRepo()) + assert.NoError(t, pr.LoadBaseRepo(db.DefaultContext)) assert.NotNil(t, pr.BaseRepo) assert.Equal(t, pr.BaseRepoID, pr.BaseRepo.ID) } @@ -47,7 +47,7 @@ func TestPullRequest_LoadBaseRepo(t *testing.T) { func TestPullRequest_LoadHeadRepo(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1}) - assert.NoError(t, pr.LoadHeadRepo()) + assert.NoError(t, pr.LoadHeadRepo(db.DefaultContext)) assert.NotNil(t, pr.HeadRepo) assert.Equal(t, pr.HeadRepoID, pr.HeadRepo.ID) } @@ -96,11 +96,11 @@ func TestPullRequestsOldest(t *testing.T) { func TestGetUnmergedPullRequest(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - pr, err := issues_model.GetUnmergedPullRequest(1, 1, "branch2", "master", issues_model.PullRequestFlowGithub) + pr, err := issues_model.GetUnmergedPullRequest(db.DefaultContext, 1, 1, "branch2", "master", issues_model.PullRequestFlowGithub) assert.NoError(t, err) assert.Equal(t, int64(2), pr.ID) - _, err = issues_model.GetUnmergedPullRequest(1, 9223372036854775807, "branch1", "master", issues_model.PullRequestFlowGithub) + _, err = issues_model.GetUnmergedPullRequest(db.DefaultContext, 1, 9223372036854775807, "branch1", "master", issues_model.PullRequestFlowGithub) assert.Error(t, err) assert.True(t, issues_model.IsErrPullRequestNotExist(err)) } @@ -228,7 +228,7 @@ func TestPullRequest_IsWorkInProgress(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2}) - pr.LoadIssue() + pr.LoadIssue(db.DefaultContext) assert.False(t, pr.IsWorkInProgress()) @@ -243,16 +243,16 @@ func TestPullRequest_GetWorkInProgressPrefixWorkInProgress(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) pr := unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2}) - pr.LoadIssue() + pr.LoadIssue(db.DefaultContext) - assert.Empty(t, pr.GetWorkInProgressPrefix()) + assert.Empty(t, pr.GetWorkInProgressPrefix(db.DefaultContext)) original := pr.Issue.Title pr.Issue.Title = "WIP: " + original - assert.Equal(t, "WIP:", pr.GetWorkInProgressPrefix()) + assert.Equal(t, "WIP:", pr.GetWorkInProgressPrefix(db.DefaultContext)) pr.Issue.Title = "[wip] " + original - assert.Equal(t, "[wip]", pr.GetWorkInProgressPrefix()) + assert.Equal(t, "[wip]", pr.GetWorkInProgressPrefix(db.DefaultContext)) } func TestDeleteOrphanedObjects(t *testing.T) { @@ -264,11 +264,11 @@ func TestDeleteOrphanedObjects(t *testing.T) { _, err = db.GetEngine(db.DefaultContext).Insert(&issues_model.PullRequest{IssueID: 1000}, &issues_model.PullRequest{IssueID: 1001}, &issues_model.PullRequest{IssueID: 1003}) assert.NoError(t, err) - orphaned, err := db.CountOrphanedObjects("pull_request", "issue", "pull_request.issue_id=issue.id") + orphaned, err := db.CountOrphanedObjects(db.DefaultContext, "pull_request", "issue", "pull_request.issue_id=issue.id") assert.NoError(t, err) assert.EqualValues(t, 3, orphaned) - err = db.DeleteOrphanedObjects("pull_request", "issue", "pull_request.issue_id=issue.id") + err = db.DeleteOrphanedObjects(db.DefaultContext, "pull_request", "issue", "pull_request.issue_id=issue.id") assert.NoError(t, err) countAfter, err := db.GetEngine(db.DefaultContext).Count(&issues_model.PullRequest{}) diff --git a/models/issues/review.go b/models/issues/review.go index f66c70c1fc..5cf7d4c3da 100644 --- a/models/issues/review.go +++ b/models/issues/review.go @@ -154,7 +154,8 @@ func (r *Review) loadIssue(ctx context.Context) (err error) { return err } -func (r *Review) loadReviewer(ctx context.Context) (err error) { +// LoadReviewer loads reviewer +func (r *Review) LoadReviewer(ctx context.Context) (err error) { if r.ReviewerID == 0 || r.Reviewer != nil { return } @@ -162,7 +163,8 @@ func (r *Review) loadReviewer(ctx context.Context) (err error) { return err } -func (r *Review) loadReviewerTeam(ctx context.Context) (err error) { +// LoadReviewerTeam loads reviewer team +func (r *Review) LoadReviewerTeam(ctx context.Context) (err error) { if r.ReviewerTeamID == 0 || r.ReviewerTeam != nil { return } @@ -171,16 +173,6 @@ func (r *Review) loadReviewerTeam(ctx context.Context) (err error) { return err } -// LoadReviewer loads reviewer -func (r *Review) LoadReviewer() error { - return r.loadReviewer(db.DefaultContext) -} - -// LoadReviewerTeam loads reviewer team -func (r *Review) LoadReviewerTeam() error { - return r.loadReviewerTeam(db.DefaultContext) -} - // LoadAttributes loads all attributes except CodeComments func (r *Review) LoadAttributes(ctx context.Context) (err error) { if err = r.loadIssue(ctx); err != nil { @@ -189,10 +181,10 @@ func (r *Review) LoadAttributes(ctx context.Context) (err error) { if err = r.LoadCodeComments(ctx); err != nil { return } - if err = r.loadReviewer(ctx); err != nil { + if err = r.LoadReviewer(ctx); err != nil { return } - if err = r.loadReviewerTeam(ctx); err != nil { + if err = r.LoadReviewerTeam(ctx); err != nil { return } return err diff --git a/models/issues/review_test.go b/models/issues/review_test.go index 46d1cc777b..39ad14c65f 100644 --- a/models/issues/review_test.go +++ b/models/issues/review_test.go @@ -135,7 +135,7 @@ func TestGetReviewersByIssueID(t *testing.T) { allReviews, err := issues_model.GetReviewersByIssueID(issue.ID) for _, reviewer := range allReviews { - assert.NoError(t, reviewer.LoadReviewer()) + assert.NoError(t, reviewer.LoadReviewer(db.DefaultContext)) } assert.NoError(t, err) if assert.Len(t, allReviews, 3) { diff --git a/models/org_team_test.go b/models/org_team_test.go index a600d07c0c..3b1fabf1c3 100644 --- a/models/org_team_test.go +++ b/models/org_team_test.go @@ -143,7 +143,7 @@ func TestDeleteTeam(t *testing.T) { // check that team members don't have "leftover" access to repos user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 4}) repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 3}) - accessMode, err := access_model.AccessLevel(user, repo) + accessMode, err := access_model.AccessLevel(db.DefaultContext, user, repo) assert.NoError(t, err) assert.True(t, accessMode < perm.AccessModeWrite) } diff --git a/models/perm/access/access_test.go b/models/perm/access/access_test.go index 7f58be4f39..dc707b971b 100644 --- a/models/perm/access/access_test.go +++ b/models/perm/access/access_test.go @@ -36,34 +36,34 @@ func TestAccessLevel(t *testing.T) { // org. owned private repo repo24 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 24}) - level, err := access_model.AccessLevel(user2, repo1) + level, err := access_model.AccessLevel(db.DefaultContext, user2, repo1) assert.NoError(t, err) assert.Equal(t, perm_model.AccessModeOwner, level) - level, err = access_model.AccessLevel(user2, repo3) + level, err = access_model.AccessLevel(db.DefaultContext, user2, repo3) assert.NoError(t, err) assert.Equal(t, perm_model.AccessModeOwner, level) - level, err = access_model.AccessLevel(user5, repo1) + level, err = access_model.AccessLevel(db.DefaultContext, user5, repo1) assert.NoError(t, err) assert.Equal(t, perm_model.AccessModeRead, level) - level, err = access_model.AccessLevel(user5, repo3) + level, err = access_model.AccessLevel(db.DefaultContext, user5, repo3) assert.NoError(t, err) assert.Equal(t, perm_model.AccessModeNone, level) // restricted user has no access to a public repo - level, err = access_model.AccessLevel(user29, repo1) + level, err = access_model.AccessLevel(db.DefaultContext, user29, repo1) assert.NoError(t, err) assert.Equal(t, perm_model.AccessModeNone, level) // ... unless he's a collaborator - level, err = access_model.AccessLevel(user29, repo4) + level, err = access_model.AccessLevel(db.DefaultContext, user29, repo4) assert.NoError(t, err) assert.Equal(t, perm_model.AccessModeWrite, level) // ... or a team member - level, err = access_model.AccessLevel(user29, repo24) + level, err = access_model.AccessLevel(db.DefaultContext, user29, repo24) assert.NoError(t, err) assert.Equal(t, perm_model.AccessModeRead, level) } diff --git a/models/perm/access/repo_permission.go b/models/perm/access/repo_permission.go index 93e3bdd6d8..3b709a3e85 100644 --- a/models/perm/access/repo_permission.go +++ b/models/perm/access/repo_permission.go @@ -326,17 +326,13 @@ func IsUserRepoAdmin(ctx context.Context, repo *repo_model.Repository, user *use // AccessLevel returns the Access a user has to a repository. Will return NoneAccess if the // user does not have access. -func AccessLevel(user *user_model.User, repo *repo_model.Repository) (perm_model.AccessMode, error) { //nolint - return AccessLevelUnit(user, repo, unit.TypeCode) +func AccessLevel(ctx context.Context, user *user_model.User, repo *repo_model.Repository) (perm_model.AccessMode, error) { //nolint + return AccessLevelUnit(ctx, user, repo, unit.TypeCode) } // AccessLevelUnit returns the Access a user has to a repository's. Will return NoneAccess if the // user does not have access. -func AccessLevelUnit(user *user_model.User, repo *repo_model.Repository, unitType unit.Type) (perm_model.AccessMode, error) { //nolint - return accessLevelUnit(db.DefaultContext, user, repo, unitType) -} - -func accessLevelUnit(ctx context.Context, user *user_model.User, repo *repo_model.Repository, unitType unit.Type) (perm_model.AccessMode, error) { +func AccessLevelUnit(ctx context.Context, user *user_model.User, repo *repo_model.Repository, unitType unit.Type) (perm_model.AccessMode, error) { //nolint perm, err := GetUserRepoPermission(ctx, repo, user) if err != nil { return perm_model.AccessModeNone, err @@ -346,7 +342,7 @@ func accessLevelUnit(ctx context.Context, user *user_model.User, repo *repo_mode // HasAccessUnit returns true if user has testMode to the unit of the repository func HasAccessUnit(ctx context.Context, user *user_model.User, repo *repo_model.Repository, unitType unit.Type, testMode perm_model.AccessMode) (bool, error) { - mode, err := accessLevelUnit(ctx, user, repo, unitType) + mode, err := AccessLevelUnit(ctx, user, repo, unitType) return testMode <= mode, err } diff --git a/models/repo/attachment.go b/models/repo/attachment.go index df7528df09..428f370a0b 100644 --- a/models/repo/attachment.go +++ b/models/repo/attachment.go @@ -226,20 +226,20 @@ func UpdateAttachment(ctx context.Context, atta *Attachment) error { } // DeleteAttachmentsByRelease deletes all attachments associated with the given release. -func DeleteAttachmentsByRelease(releaseID int64) error { - _, err := db.GetEngine(db.DefaultContext).Where("release_id = ?", releaseID).Delete(&Attachment{}) +func DeleteAttachmentsByRelease(ctx context.Context, releaseID int64) error { + _, err := db.GetEngine(ctx).Where("release_id = ?", releaseID).Delete(&Attachment{}) return err } // CountOrphanedAttachments returns the number of bad attachments -func CountOrphanedAttachments() (int64, error) { - return db.GetEngine(db.DefaultContext).Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))"). +func CountOrphanedAttachments(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))"). Count(new(Attachment)) } // DeleteOrphanedAttachments delete all bad attachments -func DeleteOrphanedAttachments() error { - _, err := db.GetEngine(db.DefaultContext).Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))"). +func DeleteOrphanedAttachments(ctx context.Context) error { + _, err := db.GetEngine(ctx).Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))"). Delete(new(Attachment)) return err } diff --git a/models/repo/pushmirror.go b/models/repo/pushmirror.go index 38d6e72019..fa876ee560 100644 --- a/models/repo/pushmirror.go +++ b/models/repo/pushmirror.go @@ -120,9 +120,9 @@ func GetPushMirrorsByRepoID(ctx context.Context, repoID int64, listOptions db.Li } // GetPushMirrorsSyncedOnCommit returns push-mirrors for this repo that should be updated by new commits -func GetPushMirrorsSyncedOnCommit(repoID int64) ([]*PushMirror, error) { +func GetPushMirrorsSyncedOnCommit(ctx context.Context, repoID int64) ([]*PushMirror, error) { mirrors := make([]*PushMirror, 0, 10) - return mirrors, db.GetEngine(db.DefaultContext). + return mirrors, db.GetEngine(ctx). Where("repo_id=? AND sync_on_commit=?", repoID, true). Find(&mirrors) } diff --git a/models/repo/release.go b/models/repo/release.go index 14428f15f7..a92e4bb6e5 100644 --- a/models/repo/release.go +++ b/models/repo/release.go @@ -90,7 +90,8 @@ func init() { db.RegisterModel(new(Release)) } -func (r *Release) loadAttributes(ctx context.Context) error { +// LoadAttributes load repo and publisher attributes for a release +func (r *Release) LoadAttributes(ctx context.Context) error { var err error if r.Repo == nil { r.Repo, err = GetRepositoryByIDCtx(ctx, r.RepoID) @@ -111,11 +112,6 @@ func (r *Release) loadAttributes(ctx context.Context) error { return GetReleaseAttachments(ctx, r) } -// LoadAttributes load repo and publisher attributes for a release -func (r *Release) LoadAttributes() error { - return r.loadAttributes(db.DefaultContext) -} - // APIURL the api url for a release. release must have attributes loaded func (r *Release) APIURL() string { return r.Repo.APIURL() + "/releases/" + strconv.FormatInt(r.ID, 10) @@ -241,8 +237,8 @@ func (opts *FindReleasesOptions) toConds(repoID int64) builder.Cond { } // GetReleasesByRepoID returns a list of releases of repository. -func GetReleasesByRepoID(repoID int64, opts FindReleasesOptions) ([]*Release, error) { - sess := db.GetEngine(db.DefaultContext). +func GetReleasesByRepoID(ctx context.Context, repoID int64, opts FindReleasesOptions) ([]*Release, error) { + sess := db.GetEngine(ctx). Desc("created_unix", "id"). Where(opts.toConds(repoID)) @@ -381,8 +377,8 @@ func SortReleases(rels []*Release) { } // DeleteReleaseByID deletes a release from database by given ID. -func DeleteReleaseByID(id int64) error { - _, err := db.GetEngine(db.DefaultContext).ID(id).Delete(new(Release)) +func DeleteReleaseByID(ctx context.Context, id int64) error { + _, err := db.GetEngine(ctx).ID(id).Delete(new(Release)) return err } diff --git a/models/repo/repo.go b/models/repo/repo.go index 77e0367a5a..a3dac8383f 100644 --- a/models/repo/repo.go +++ b/models/repo/repo.go @@ -236,14 +236,6 @@ func (repo *Repository) AfterLoad() { repo.NumOpenProjects = repo.NumProjects - repo.NumClosedProjects } -// MustOwner always returns a valid *user_model.User object to avoid -// conceptually impossible error handling. -// It creates a fake object that contains error details -// when error occurs. -func (repo *Repository) MustOwner() *user_model.User { - return repo.mustOwner(db.DefaultContext) -} - // LoadAttributes loads attributes of the repository. func (repo *Repository) LoadAttributes(ctx context.Context) error { // Load owner @@ -403,7 +395,11 @@ func (repo *Repository) GetOwner(ctx context.Context) (err error) { return err } -func (repo *Repository) mustOwner(ctx context.Context) *user_model.User { +// MustOwner always returns a valid *user_model.User object to avoid +// conceptually impossible error handling. +// It creates a fake object that contains error details +// when error occurs. +func (repo *Repository) MustOwner(ctx context.Context) *user_model.User { if err := repo.GetOwner(ctx); err != nil { return &user_model.User{ Name: "error", @@ -438,7 +434,7 @@ func (repo *Repository) ComposeMetas() map[string]string { } } - repo.MustOwner() + repo.MustOwner(db.DefaultContext) if repo.Owner.IsOrganization() { teams := make([]string, 0, 5) _ = db.GetEngine(db.DefaultContext).Table("team_repo"). @@ -792,13 +788,13 @@ func UpdateRepoIssueNumbers(ctx context.Context, repoID int64, isPull, isClosed } // CountNullArchivedRepository counts the number of repositories with is_archived is null -func CountNullArchivedRepository() (int64, error) { - return db.GetEngine(db.DefaultContext).Where(builder.IsNull{"is_archived"}).Count(new(Repository)) +func CountNullArchivedRepository(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where(builder.IsNull{"is_archived"}).Count(new(Repository)) } // FixNullArchivedRepository sets is_archived to false where it is null -func FixNullArchivedRepository() (int64, error) { - return db.GetEngine(db.DefaultContext).Where(builder.IsNull{"is_archived"}).Cols("is_archived").NoAutoTime().Update(&Repository{ +func FixNullArchivedRepository(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Where(builder.IsNull{"is_archived"}).Cols("is_archived").NoAutoTime().Update(&Repository{ IsArchived: false, }) } diff --git a/models/repo/repo_list.go b/models/repo/repo_list.go index 191970d275..abfa73abb9 100644 --- a/models/repo/repo_list.go +++ b/models/repo/repo_list.go @@ -518,14 +518,13 @@ func SearchRepositoryCondition(opts *SearchRepoOptions) builder.Cond { // SearchRepository returns repositories based on search options, // it returns results in given range and number of total results. -func SearchRepository(opts *SearchRepoOptions) (RepositoryList, int64, error) { +func SearchRepository(ctx context.Context, opts *SearchRepoOptions) (RepositoryList, int64, error) { cond := SearchRepositoryCondition(opts) - return SearchRepositoryByCondition(opts, cond, true) + return SearchRepositoryByCondition(ctx, opts, cond, true) } // SearchRepositoryByCondition search repositories by condition -func SearchRepositoryByCondition(opts *SearchRepoOptions, cond builder.Cond, loadAttributes bool) (RepositoryList, int64, error) { - ctx := db.DefaultContext +func SearchRepositoryByCondition(ctx context.Context, opts *SearchRepoOptions, cond builder.Cond, loadAttributes bool) (RepositoryList, int64, error) { sess, count, err := searchRepositoryByCondition(ctx, opts, cond) if err != nil { return nil, 0, err @@ -652,9 +651,9 @@ func AccessibleRepositoryCondition(user *user_model.User, unitType unit.Type) bu // SearchRepositoryByName takes keyword and part of repository name to search, // it returns results in given range and number of total results. -func SearchRepositoryByName(opts *SearchRepoOptions) (RepositoryList, int64, error) { +func SearchRepositoryByName(ctx context.Context, opts *SearchRepoOptions) (RepositoryList, int64, error) { opts.IncludeDescription = false - return SearchRepository(opts) + return SearchRepository(ctx, opts) } // SearchRepositoryIDs takes keyword and part of repository name to search, diff --git a/models/repo/repo_list_test.go b/models/repo/repo_list_test.go index f9c84a0f3f..926ed07e9e 100644 --- a/models/repo/repo_list_test.go +++ b/models/repo/repo_list_test.go @@ -20,7 +20,7 @@ func TestSearchRepository(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) // test search public repository on explore page - repos, count, err := repo_model.SearchRepositoryByName(&repo_model.SearchRepoOptions{ + repos, count, err := repo_model.SearchRepositoryByName(db.DefaultContext, &repo_model.SearchRepoOptions{ ListOptions: db.ListOptions{ Page: 1, PageSize: 10, @@ -35,7 +35,7 @@ func TestSearchRepository(t *testing.T) { } assert.Equal(t, int64(1), count) - repos, count, err = repo_model.SearchRepositoryByName(&repo_model.SearchRepoOptions{ + repos, count, err = repo_model.SearchRepositoryByName(db.DefaultContext, &repo_model.SearchRepoOptions{ ListOptions: db.ListOptions{ Page: 1, PageSize: 10, @@ -49,7 +49,7 @@ func TestSearchRepository(t *testing.T) { assert.Len(t, repos, 2) // test search private repository on explore page - repos, count, err = repo_model.SearchRepositoryByName(&repo_model.SearchRepoOptions{ + repos, count, err = repo_model.SearchRepositoryByName(db.DefaultContext, &repo_model.SearchRepoOptions{ ListOptions: db.ListOptions{ Page: 1, PageSize: 10, @@ -65,7 +65,7 @@ func TestSearchRepository(t *testing.T) { } assert.Equal(t, int64(1), count) - repos, count, err = repo_model.SearchRepositoryByName(&repo_model.SearchRepoOptions{ + repos, count, err = repo_model.SearchRepositoryByName(db.DefaultContext, &repo_model.SearchRepoOptions{ ListOptions: db.ListOptions{ Page: 1, PageSize: 10, @@ -80,14 +80,14 @@ func TestSearchRepository(t *testing.T) { assert.Len(t, repos, 3) // Test non existing owner - repos, count, err = repo_model.SearchRepositoryByName(&repo_model.SearchRepoOptions{OwnerID: unittest.NonexistentID}) + repos, count, err = repo_model.SearchRepositoryByName(db.DefaultContext, &repo_model.SearchRepoOptions{OwnerID: unittest.NonexistentID}) assert.NoError(t, err) assert.Empty(t, repos) assert.Equal(t, int64(0), count) // Test search within description - repos, count, err = repo_model.SearchRepository(&repo_model.SearchRepoOptions{ + repos, count, err = repo_model.SearchRepository(db.DefaultContext, &repo_model.SearchRepoOptions{ ListOptions: db.ListOptions{ Page: 1, PageSize: 10, @@ -104,7 +104,7 @@ func TestSearchRepository(t *testing.T) { assert.Equal(t, int64(1), count) // Test NOT search within description - repos, count, err = repo_model.SearchRepository(&repo_model.SearchRepoOptions{ + repos, count, err = repo_model.SearchRepository(db.DefaultContext, &repo_model.SearchRepoOptions{ ListOptions: db.ListOptions{ Page: 1, PageSize: 10, @@ -277,7 +277,7 @@ func TestSearchRepository(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - repos, count, err := repo_model.SearchRepositoryByName(testCase.opts) + repos, count, err := repo_model.SearchRepositoryByName(db.DefaultContext, testCase.opts) assert.NoError(t, err) assert.Equal(t, int64(testCase.count), count) @@ -377,7 +377,7 @@ func TestSearchRepositoryByTopicName(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - _, count, err := repo_model.SearchRepositoryByName(testCase.opts) + _, count, err := repo_model.SearchRepositoryByName(db.DefaultContext, testCase.opts) assert.NoError(t, err) assert.Equal(t, int64(testCase.count), count) }) diff --git a/models/repo/user_repo.go b/models/repo/user_repo.go index e7125f70f8..9ca367f556 100644 --- a/models/repo/user_repo.go +++ b/models/repo/user_repo.go @@ -17,8 +17,9 @@ import ( ) // GetStarredRepos returns the repos starred by a particular user -func GetStarredRepos(userID int64, private bool, listOptions db.ListOptions) ([]*Repository, error) { - sess := db.GetEngine(db.DefaultContext).Where("star.uid=?", userID). +func GetStarredRepos(ctx context.Context, userID int64, private bool, listOptions db.ListOptions) ([]*Repository, error) { + sess := db.GetEngine(ctx). + Where("star.uid=?", userID). Join("LEFT", "star", "`repository`.id=`star`.repo_id") if !private { sess = sess.And("is_private=?", false) @@ -36,8 +37,9 @@ func GetStarredRepos(userID int64, private bool, listOptions db.ListOptions) ([] } // GetWatchedRepos returns the repos watched by a particular user -func GetWatchedRepos(userID int64, private bool, listOptions db.ListOptions) ([]*Repository, int64, error) { - sess := db.GetEngine(db.DefaultContext).Where("watch.user_id=?", userID). +func GetWatchedRepos(ctx context.Context, userID int64, private bool, listOptions db.ListOptions) ([]*Repository, int64, error) { + sess := db.GetEngine(ctx). + Where("watch.user_id=?", userID). And("`watch`.mode<>?", WatchModeDont). Join("LEFT", "watch", "`repository`.id=`watch`.repo_id") if !private { diff --git a/models/user/user.go b/models/user/user.go index c36fc21c77..1a71acb0b7 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -1042,14 +1042,15 @@ func GetUserEmailsByNames(ctx context.Context, names []string) []string { } // GetMaileableUsersByIDs gets users from ids, but only if they can receive mails -func GetMaileableUsersByIDs(ids []int64, isMention bool) ([]*User, error) { +func GetMaileableUsersByIDs(ctx context.Context, ids []int64, isMention bool) ([]*User, error) { if len(ids) == 0 { return nil, nil } ous := make([]*User, 0, len(ids)) if isMention { - return ous, db.GetEngine(db.DefaultContext).In("id", ids). + return ous, db.GetEngine(ctx). + In("id", ids). Where("`type` = ?", UserTypeIndividual). And("`prohibit_login` = ?", false). And("`is_active` = ?", true). @@ -1057,7 +1058,8 @@ func GetMaileableUsersByIDs(ids []int64, isMention bool) ([]*User, error) { Find(&ous) } - return ous, db.GetEngine(db.DefaultContext).In("id", ids). + return ous, db.GetEngine(ctx). + In("id", ids). Where("`type` = ?", UserTypeIndividual). And("`prohibit_login` = ?", false). And("`is_active` = ?", true). @@ -1090,10 +1092,10 @@ func GetUserNameByID(ctx context.Context, id int64) (string, error) { } // GetUserIDsByNames returns a slice of ids corresponds to names. -func GetUserIDsByNames(names []string, ignoreNonExistent bool) ([]int64, error) { +func GetUserIDsByNames(ctx context.Context, names []string, ignoreNonExistent bool) ([]int64, error) { ids := make([]int64, 0, len(names)) for _, name := range names { - u, err := GetUserByName(db.DefaultContext, name) + u, err := GetUserByName(ctx, name) if err != nil { if ignoreNonExistent { continue diff --git a/models/user/user_test.go b/models/user/user_test.go index 5f2ac0a60c..1cdfb5978c 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -257,12 +257,12 @@ func TestGetUserIDsByNames(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) // ignore non existing - IDs, err := user_model.GetUserIDsByNames([]string{"user1", "user2", "none_existing_user"}, true) + IDs, err := user_model.GetUserIDsByNames(db.DefaultContext, []string{"user1", "user2", "none_existing_user"}, true) assert.NoError(t, err) assert.Equal(t, []int64{1, 2}, IDs) // ignore non existing - IDs, err = user_model.GetUserIDsByNames([]string{"user1", "do_not_exist"}, false) + IDs, err = user_model.GetUserIDsByNames(db.DefaultContext, []string{"user1", "do_not_exist"}, false) assert.Error(t, err) assert.Equal(t, []int64(nil), IDs) } @@ -270,14 +270,14 @@ func TestGetUserIDsByNames(t *testing.T) { func TestGetMaileableUsersByIDs(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - results, err := user_model.GetMaileableUsersByIDs([]int64{1, 4}, false) + results, err := user_model.GetMaileableUsersByIDs(db.DefaultContext, []int64{1, 4}, false) assert.NoError(t, err) assert.Len(t, results, 1) if len(results) > 1 { assert.Equal(t, results[0].ID, 1) } - results, err = user_model.GetMaileableUsersByIDs([]int64{1, 4}, true) + results, err = user_model.GetMaileableUsersByIDs(db.DefaultContext, []int64{1, 4}, true) assert.NoError(t, err) assert.Len(t, results, 2) if len(results) > 2 { |