diff options
author | Lunny Xiao <xiaolunwen@gmail.com> | 2023-07-22 22:14:27 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-22 22:14:27 +0800 |
commit | b167f35113e643ccdb17a2dde55bdec5960b284b (patch) | |
tree | 99a6e53bf1a9d4c9199c19113650cc48a8c1fd0e /models/issues | |
parent | c42b71877edb4830b9573101d20853222d66fb3c (diff) | |
download | gitea-b167f35113e643ccdb17a2dde55bdec5960b284b.tar.gz gitea-b167f35113e643ccdb17a2dde55bdec5960b284b.zip |
Add context parameter to some database functions (#26055)
To avoid deadlock problem, almost database related functions should be
have ctx as the first parameter.
This PR do a refactor for some of these functions.
Diffstat (limited to 'models/issues')
-rw-r--r-- | models/issues/comment_list.go | 11 | ||||
-rw-r--r-- | models/issues/issue.go | 14 | ||||
-rw-r--r-- | models/issues/issue_list.go | 8 | ||||
-rw-r--r-- | models/issues/issue_list_test.go | 2 | ||||
-rw-r--r-- | models/issues/issue_search.go | 2 | ||||
-rw-r--r-- | models/issues/pull_list.go | 23 | ||||
-rw-r--r-- | models/issues/pull_test.go | 8 | ||||
-rw-r--r-- | models/issues/tracked_time.go | 20 | ||||
-rw-r--r-- | models/issues/tracked_time_test.go | 2 |
9 files changed, 35 insertions, 55 deletions
diff --git a/models/issues/comment_list.go b/models/issues/comment_list.go index e9c8406c3a..6f1d350eb4 100644 --- a/models/issues/comment_list.go +++ b/models/issues/comment_list.go @@ -465,8 +465,9 @@ func (comments CommentList) loadReviews(ctx context.Context) error { return nil } -// loadAttributes loads all attributes -func (comments CommentList) loadAttributes(ctx context.Context) (err error) { +// LoadAttributes loads attributes of the comments, except for attachments and +// comments +func (comments CommentList) LoadAttributes(ctx context.Context) (err error) { if err = comments.LoadPosters(ctx); err != nil { return err } @@ -501,9 +502,3 @@ func (comments CommentList) loadAttributes(ctx context.Context) (err error) { return comments.loadDependentIssues(ctx) } - -// LoadAttributes loads attributes of the comments, except for attachments and -// comments -func (comments CommentList) LoadAttributes() error { - return comments.loadAttributes(db.DefaultContext) -} diff --git a/models/issues/issue.go b/models/issues/issue.go index c1a802c792..9d60d011ed 100644 --- a/models/issues/issue.go +++ b/models/issues/issue.go @@ -354,7 +354,7 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) { return err } - if err = issue.Comments.loadAttributes(ctx); err != nil { + if err = issue.Comments.LoadAttributes(ctx); err != nil { return err } if issue.IsTimetrackerEnabled(ctx) { @@ -502,7 +502,7 @@ func (issue *Issue) GetLastEventLabelFake() string { } // GetIssueByIndex returns raw issue without loading attributes by index in a repository. -func GetIssueByIndex(repoID, index int64) (*Issue, error) { +func GetIssueByIndex(ctx context.Context, repoID, index int64) (*Issue, error) { if index < 1 { return nil, ErrIssueNotExist{} } @@ -510,7 +510,7 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) { RepoID: repoID, Index: index, } - has, err := db.GetEngine(db.DefaultContext).Get(issue) + has, err := db.GetEngine(ctx).Get(issue) if err != nil { return nil, err } else if !has { @@ -520,12 +520,12 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) { } // GetIssueWithAttrsByIndex returns issue by index in a repository. -func GetIssueWithAttrsByIndex(repoID, index int64) (*Issue, error) { - issue, err := GetIssueByIndex(repoID, index) +func GetIssueWithAttrsByIndex(ctx context.Context, repoID, index int64) (*Issue, error) { + issue, err := GetIssueByIndex(ctx, repoID, index) if err != nil { return nil, err } - return issue, issue.LoadAttributes(db.DefaultContext) + return issue, issue.LoadAttributes(ctx) } // GetIssueByID returns an issue by given ID. @@ -846,7 +846,7 @@ func GetPinnedIssues(ctx context.Context, repoID int64, isPull bool) ([]*Issue, return nil, err } - err = IssueList(issues).LoadAttributes() + err = IssueList(issues).LoadAttributes(ctx) if err != nil { return nil, err } diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go index 9cc41ec6ab..a932ac2554 100644 --- a/models/issues/issue_list.go +++ b/models/issues/issue_list.go @@ -526,7 +526,7 @@ func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) { } // loadAttributes loads all attributes, expect for attachments and comments -func (issues IssueList) loadAttributes(ctx context.Context) error { +func (issues IssueList) LoadAttributes(ctx context.Context) error { if _, err := issues.LoadRepositories(ctx); err != nil { return fmt.Errorf("issue.loadAttributes: LoadRepositories: %w", err) } @@ -562,12 +562,6 @@ func (issues IssueList) loadAttributes(ctx context.Context) error { return nil } -// LoadAttributes loads attributes of the issues, except for attachments and -// comments -func (issues IssueList) LoadAttributes() error { - return issues.loadAttributes(db.DefaultContext) -} - // LoadComments loads comments func (issues IssueList) LoadComments(ctx context.Context) error { return issues.loadComments(ctx, builder.NewCond()) diff --git a/models/issues/issue_list_test.go b/models/issues/issue_list_test.go index 696c3b765d..9069e1012d 100644 --- a/models/issues/issue_list_test.go +++ b/models/issues/issue_list_test.go @@ -39,7 +39,7 @@ func TestIssueList_LoadAttributes(t *testing.T) { unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 4}), } - assert.NoError(t, issueList.LoadAttributes()) + assert.NoError(t, issueList.LoadAttributes(db.DefaultContext)) for _, issue := range issueList { assert.EqualValues(t, issue.RepoID, issue.Repo.ID) for _, label := range issue.Labels { diff --git a/models/issues/issue_search.go b/models/issues/issue_search.go index 9fd13f0995..6540ce02c0 100644 --- a/models/issues/issue_search.go +++ b/models/issues/issue_search.go @@ -440,7 +440,7 @@ func Issues(ctx context.Context, opts *IssuesOptions) ([]*Issue, error) { return nil, fmt.Errorf("unable to query Issues: %w", err) } - if err := issues.LoadAttributes(); err != nil { + if err := issues.LoadAttributes(ctx); err != nil { return nil, fmt.Errorf("unable to LoadAttributes for Issues: %w", err) } diff --git a/models/issues/pull_list.go b/models/issues/pull_list.go index c443928344..3b2416900b 100644 --- a/models/issues/pull_list.go +++ b/models/issues/pull_list.go @@ -51,16 +51,16 @@ func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xor } // GetUnmergedPullRequestsByHeadInfo returns all pull requests that are open and has not been merged -func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) { +func GetUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) { prs := make([]*PullRequest, 0, 2) - sess := db.GetEngine(db.DefaultContext). + sess := db.GetEngine(ctx). Join("INNER", "issue", "issue.id = pull_request.issue_id"). Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ? AND flow = ?", repoID, branch, false, false, PullRequestFlowGithub) return prs, sess.Find(&prs) } // CanMaintainerWriteToBranch check whether user is a maintainer and could write to the branch -func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *user_model.User) bool { +func CanMaintainerWriteToBranch(ctx context.Context, p access_model.Permission, branch string, user *user_model.User) bool { if p.CanWrite(unit.TypeCode) { return true } @@ -69,18 +69,18 @@ func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user * return false } - prs, err := GetUnmergedPullRequestsByHeadInfo(p.Units[0].RepoID, branch) + prs, err := GetUnmergedPullRequestsByHeadInfo(ctx, p.Units[0].RepoID, branch) if err != nil { return false } for _, pr := range prs { if pr.AllowMaintainerEdit { - err = pr.LoadBaseRepo(db.DefaultContext) + err = pr.LoadBaseRepo(ctx) if err != nil { continue } - prPerm, err := access_model.GetUserRepoPermission(db.DefaultContext, pr.BaseRepo, user) + prPerm, err := access_model.GetUserRepoPermission(ctx, pr.BaseRepo, user) if err != nil { continue } @@ -104,9 +104,9 @@ func HasUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch // GetUnmergedPullRequestsByBaseInfo returns all pull requests that are open and has not been merged // by given base information (repo and branch). -func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) { +func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) { prs := make([]*PullRequest, 0, 2) - return prs, db.GetEngine(db.DefaultContext). + return prs, db.GetEngine(ctx). Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?", repoID, branch, false, false). OrderBy("issue.updated_unix DESC"). @@ -154,7 +154,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, // PullRequestList defines a list of pull requests type PullRequestList []*PullRequest -func (prs PullRequestList) loadAttributes(ctx context.Context) error { +func (prs PullRequestList) LoadAttributes(ctx context.Context) error { if len(prs) == 0 { return nil } @@ -199,8 +199,3 @@ func (prs PullRequestList) GetIssueIDs() []int64 { } return issueIDs } - -// LoadAttributes load all the prs attributes -func (prs PullRequestList) LoadAttributes() error { - return prs.loadAttributes(db.DefaultContext) -} diff --git a/models/issues/pull_test.go b/models/issues/pull_test.go index 5856b5dc58..0990a3b870 100644 --- a/models/issues/pull_test.go +++ b/models/issues/pull_test.go @@ -148,7 +148,7 @@ func TestHasUnmergedPullRequestsByHeadInfo(t *testing.T) { func TestGetUnmergedPullRequestsByHeadInfo(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - prs, err := issues_model.GetUnmergedPullRequestsByHeadInfo(1, "branch2") + prs, err := issues_model.GetUnmergedPullRequestsByHeadInfo(db.DefaultContext, 1, "branch2") assert.NoError(t, err) assert.Len(t, prs, 1) for _, pr := range prs { @@ -159,7 +159,7 @@ func TestGetUnmergedPullRequestsByHeadInfo(t *testing.T) { func TestGetUnmergedPullRequestsByBaseInfo(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - prs, err := issues_model.GetUnmergedPullRequestsByBaseInfo(1, "master") + prs, err := issues_model.GetUnmergedPullRequestsByBaseInfo(db.DefaultContext, 1, "master") assert.NoError(t, err) assert.Len(t, prs, 1) pr := prs[0] @@ -242,13 +242,13 @@ func TestPullRequestList_LoadAttributes(t *testing.T) { unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1}), unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2}), } - assert.NoError(t, issues_model.PullRequestList(prs).LoadAttributes()) + assert.NoError(t, issues_model.PullRequestList(prs).LoadAttributes(db.DefaultContext)) for _, pr := range prs { assert.NotNil(t, pr.Issue) assert.Equal(t, pr.IssueID, pr.Issue.ID) } - assert.NoError(t, issues_model.PullRequestList([]*issues_model.PullRequest{}).LoadAttributes()) + assert.NoError(t, issues_model.PullRequestList([]*issues_model.PullRequest{}).LoadAttributes(db.DefaultContext)) } // TODO TestAddTestPullRequestTask diff --git a/models/issues/tracked_time.go b/models/issues/tracked_time.go index d117b74bc0..58c6b775f0 100644 --- a/models/issues/tracked_time.go +++ b/models/issues/tracked_time.go @@ -43,11 +43,7 @@ func (t *TrackedTime) AfterLoad() { } // LoadAttributes load Issue, User -func (t *TrackedTime) LoadAttributes() (err error) { - return t.loadAttributes(db.DefaultContext) -} - -func (t *TrackedTime) loadAttributes(ctx context.Context) (err error) { +func (t *TrackedTime) LoadAttributes(ctx context.Context) (err error) { // Load the issue if t.Issue == nil { t.Issue, err = GetIssueByID(ctx, t.IssueID) @@ -76,9 +72,9 @@ func (t *TrackedTime) loadAttributes(ctx context.Context) (err error) { } // LoadAttributes load Issue, User -func (tl TrackedTimeList) LoadAttributes() error { +func (tl TrackedTimeList) LoadAttributes(ctx context.Context) error { for _, t := range tl { - if err := t.LoadAttributes(); err != nil { + if err := t.LoadAttributes(ctx); err != nil { return err } } @@ -143,8 +139,8 @@ func GetTrackedTimes(ctx context.Context, options *FindTrackedTimesOptions) (tra } // CountTrackedTimes returns count of tracked times that fit to the given options. -func CountTrackedTimes(opts *FindTrackedTimesOptions) (int64, error) { - sess := db.GetEngine(db.DefaultContext).Where(opts.toCond()) +func CountTrackedTimes(ctx context.Context, opts *FindTrackedTimesOptions) (int64, error) { + sess := db.GetEngine(ctx).Where(opts.toCond()) if opts.RepositoryID > 0 || opts.MilestoneID > 0 { sess = sess.Join("INNER", "issue", "issue.id = tracked_time.issue_id") } @@ -157,8 +153,8 @@ func GetTrackedSeconds(ctx context.Context, opts FindTrackedTimesOptions) (track } // AddTime will add the given time (in seconds) to the issue -func AddTime(user *user_model.User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func AddTime(ctx context.Context, user *user_model.User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return nil, err } @@ -276,7 +272,7 @@ func DeleteTime(t *TrackedTime) error { } defer committer.Close() - if err := t.loadAttributes(ctx); err != nil { + if err := t.LoadAttributes(ctx); err != nil { return err } diff --git a/models/issues/tracked_time_test.go b/models/issues/tracked_time_test.go index 37ba1cfdc4..1d88109183 100644 --- a/models/issues/tracked_time_test.go +++ b/models/issues/tracked_time_test.go @@ -25,7 +25,7 @@ func TestAddTime(t *testing.T) { assert.NoError(t, err) // 3661 = 1h 1min 1s - trackedTime, err := issues_model.AddTime(user3, issue1, 3661, time.Now()) + trackedTime, err := issues_model.AddTime(db.DefaultContext, user3, issue1, 3661, time.Now()) assert.NoError(t, err) assert.Equal(t, int64(3), trackedTime.UserID) assert.Equal(t, int64(1), trackedTime.IssueID) |