aboutsummaryrefslogtreecommitdiffstats
path: root/models/issues
diff options
context:
space:
mode:
authorLunny Xiao <xiaolunwen@gmail.com>2023-07-22 22:14:27 +0800
committerGitHub <noreply@github.com>2023-07-22 22:14:27 +0800
commitb167f35113e643ccdb17a2dde55bdec5960b284b (patch)
tree99a6e53bf1a9d4c9199c19113650cc48a8c1fd0e /models/issues
parentc42b71877edb4830b9573101d20853222d66fb3c (diff)
downloadgitea-b167f35113e643ccdb17a2dde55bdec5960b284b.tar.gz
gitea-b167f35113e643ccdb17a2dde55bdec5960b284b.zip
Add context parameter to some database functions (#26055)
To avoid deadlock problem, almost database related functions should be have ctx as the first parameter. This PR do a refactor for some of these functions.
Diffstat (limited to 'models/issues')
-rw-r--r--models/issues/comment_list.go11
-rw-r--r--models/issues/issue.go14
-rw-r--r--models/issues/issue_list.go8
-rw-r--r--models/issues/issue_list_test.go2
-rw-r--r--models/issues/issue_search.go2
-rw-r--r--models/issues/pull_list.go23
-rw-r--r--models/issues/pull_test.go8
-rw-r--r--models/issues/tracked_time.go20
-rw-r--r--models/issues/tracked_time_test.go2
9 files changed, 35 insertions, 55 deletions
diff --git a/models/issues/comment_list.go b/models/issues/comment_list.go
index e9c8406c3a..6f1d350eb4 100644
--- a/models/issues/comment_list.go
+++ b/models/issues/comment_list.go
@@ -465,8 +465,9 @@ func (comments CommentList) loadReviews(ctx context.Context) error {
return nil
}
-// loadAttributes loads all attributes
-func (comments CommentList) loadAttributes(ctx context.Context) (err error) {
+// LoadAttributes loads attributes of the comments, except for attachments and
+// comments
+func (comments CommentList) LoadAttributes(ctx context.Context) (err error) {
if err = comments.LoadPosters(ctx); err != nil {
return err
}
@@ -501,9 +502,3 @@ func (comments CommentList) loadAttributes(ctx context.Context) (err error) {
return comments.loadDependentIssues(ctx)
}
-
-// LoadAttributes loads attributes of the comments, except for attachments and
-// comments
-func (comments CommentList) LoadAttributes() error {
- return comments.loadAttributes(db.DefaultContext)
-}
diff --git a/models/issues/issue.go b/models/issues/issue.go
index c1a802c792..9d60d011ed 100644
--- a/models/issues/issue.go
+++ b/models/issues/issue.go
@@ -354,7 +354,7 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) {
return err
}
- if err = issue.Comments.loadAttributes(ctx); err != nil {
+ if err = issue.Comments.LoadAttributes(ctx); err != nil {
return err
}
if issue.IsTimetrackerEnabled(ctx) {
@@ -502,7 +502,7 @@ func (issue *Issue) GetLastEventLabelFake() string {
}
// GetIssueByIndex returns raw issue without loading attributes by index in a repository.
-func GetIssueByIndex(repoID, index int64) (*Issue, error) {
+func GetIssueByIndex(ctx context.Context, repoID, index int64) (*Issue, error) {
if index < 1 {
return nil, ErrIssueNotExist{}
}
@@ -510,7 +510,7 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) {
RepoID: repoID,
Index: index,
}
- has, err := db.GetEngine(db.DefaultContext).Get(issue)
+ has, err := db.GetEngine(ctx).Get(issue)
if err != nil {
return nil, err
} else if !has {
@@ -520,12 +520,12 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) {
}
// GetIssueWithAttrsByIndex returns issue by index in a repository.
-func GetIssueWithAttrsByIndex(repoID, index int64) (*Issue, error) {
- issue, err := GetIssueByIndex(repoID, index)
+func GetIssueWithAttrsByIndex(ctx context.Context, repoID, index int64) (*Issue, error) {
+ issue, err := GetIssueByIndex(ctx, repoID, index)
if err != nil {
return nil, err
}
- return issue, issue.LoadAttributes(db.DefaultContext)
+ return issue, issue.LoadAttributes(ctx)
}
// GetIssueByID returns an issue by given ID.
@@ -846,7 +846,7 @@ func GetPinnedIssues(ctx context.Context, repoID int64, isPull bool) ([]*Issue,
return nil, err
}
- err = IssueList(issues).LoadAttributes()
+ err = IssueList(issues).LoadAttributes(ctx)
if err != nil {
return nil, err
}
diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go
index 9cc41ec6ab..a932ac2554 100644
--- a/models/issues/issue_list.go
+++ b/models/issues/issue_list.go
@@ -526,7 +526,7 @@ func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) {
}
// loadAttributes loads all attributes, expect for attachments and comments
-func (issues IssueList) loadAttributes(ctx context.Context) error {
+func (issues IssueList) LoadAttributes(ctx context.Context) error {
if _, err := issues.LoadRepositories(ctx); err != nil {
return fmt.Errorf("issue.loadAttributes: LoadRepositories: %w", err)
}
@@ -562,12 +562,6 @@ func (issues IssueList) loadAttributes(ctx context.Context) error {
return nil
}
-// LoadAttributes loads attributes of the issues, except for attachments and
-// comments
-func (issues IssueList) LoadAttributes() error {
- return issues.loadAttributes(db.DefaultContext)
-}
-
// LoadComments loads comments
func (issues IssueList) LoadComments(ctx context.Context) error {
return issues.loadComments(ctx, builder.NewCond())
diff --git a/models/issues/issue_list_test.go b/models/issues/issue_list_test.go
index 696c3b765d..9069e1012d 100644
--- a/models/issues/issue_list_test.go
+++ b/models/issues/issue_list_test.go
@@ -39,7 +39,7 @@ func TestIssueList_LoadAttributes(t *testing.T) {
unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 4}),
}
- assert.NoError(t, issueList.LoadAttributes())
+ assert.NoError(t, issueList.LoadAttributes(db.DefaultContext))
for _, issue := range issueList {
assert.EqualValues(t, issue.RepoID, issue.Repo.ID)
for _, label := range issue.Labels {
diff --git a/models/issues/issue_search.go b/models/issues/issue_search.go
index 9fd13f0995..6540ce02c0 100644
--- a/models/issues/issue_search.go
+++ b/models/issues/issue_search.go
@@ -440,7 +440,7 @@ func Issues(ctx context.Context, opts *IssuesOptions) ([]*Issue, error) {
return nil, fmt.Errorf("unable to query Issues: %w", err)
}
- if err := issues.LoadAttributes(); err != nil {
+ if err := issues.LoadAttributes(ctx); err != nil {
return nil, fmt.Errorf("unable to LoadAttributes for Issues: %w", err)
}
diff --git a/models/issues/pull_list.go b/models/issues/pull_list.go
index c443928344..3b2416900b 100644
--- a/models/issues/pull_list.go
+++ b/models/issues/pull_list.go
@@ -51,16 +51,16 @@ func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xor
}
// GetUnmergedPullRequestsByHeadInfo returns all pull requests that are open and has not been merged
-func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) {
+func GetUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) {
prs := make([]*PullRequest, 0, 2)
- sess := db.GetEngine(db.DefaultContext).
+ sess := db.GetEngine(ctx).
Join("INNER", "issue", "issue.id = pull_request.issue_id").
Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ? AND flow = ?", repoID, branch, false, false, PullRequestFlowGithub)
return prs, sess.Find(&prs)
}
// CanMaintainerWriteToBranch check whether user is a maintainer and could write to the branch
-func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *user_model.User) bool {
+func CanMaintainerWriteToBranch(ctx context.Context, p access_model.Permission, branch string, user *user_model.User) bool {
if p.CanWrite(unit.TypeCode) {
return true
}
@@ -69,18 +69,18 @@ func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *
return false
}
- prs, err := GetUnmergedPullRequestsByHeadInfo(p.Units[0].RepoID, branch)
+ prs, err := GetUnmergedPullRequestsByHeadInfo(ctx, p.Units[0].RepoID, branch)
if err != nil {
return false
}
for _, pr := range prs {
if pr.AllowMaintainerEdit {
- err = pr.LoadBaseRepo(db.DefaultContext)
+ err = pr.LoadBaseRepo(ctx)
if err != nil {
continue
}
- prPerm, err := access_model.GetUserRepoPermission(db.DefaultContext, pr.BaseRepo, user)
+ prPerm, err := access_model.GetUserRepoPermission(ctx, pr.BaseRepo, user)
if err != nil {
continue
}
@@ -104,9 +104,9 @@ func HasUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch
// GetUnmergedPullRequestsByBaseInfo returns all pull requests that are open and has not been merged
// by given base information (repo and branch).
-func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) {
+func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) {
prs := make([]*PullRequest, 0, 2)
- return prs, db.GetEngine(db.DefaultContext).
+ return prs, db.GetEngine(ctx).
Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?",
repoID, branch, false, false).
OrderBy("issue.updated_unix DESC").
@@ -154,7 +154,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest,
// PullRequestList defines a list of pull requests
type PullRequestList []*PullRequest
-func (prs PullRequestList) loadAttributes(ctx context.Context) error {
+func (prs PullRequestList) LoadAttributes(ctx context.Context) error {
if len(prs) == 0 {
return nil
}
@@ -199,8 +199,3 @@ func (prs PullRequestList) GetIssueIDs() []int64 {
}
return issueIDs
}
-
-// LoadAttributes load all the prs attributes
-func (prs PullRequestList) LoadAttributes() error {
- return prs.loadAttributes(db.DefaultContext)
-}
diff --git a/models/issues/pull_test.go b/models/issues/pull_test.go
index 5856b5dc58..0990a3b870 100644
--- a/models/issues/pull_test.go
+++ b/models/issues/pull_test.go
@@ -148,7 +148,7 @@ func TestHasUnmergedPullRequestsByHeadInfo(t *testing.T) {
func TestGetUnmergedPullRequestsByHeadInfo(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
- prs, err := issues_model.GetUnmergedPullRequestsByHeadInfo(1, "branch2")
+ prs, err := issues_model.GetUnmergedPullRequestsByHeadInfo(db.DefaultContext, 1, "branch2")
assert.NoError(t, err)
assert.Len(t, prs, 1)
for _, pr := range prs {
@@ -159,7 +159,7 @@ func TestGetUnmergedPullRequestsByHeadInfo(t *testing.T) {
func TestGetUnmergedPullRequestsByBaseInfo(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
- prs, err := issues_model.GetUnmergedPullRequestsByBaseInfo(1, "master")
+ prs, err := issues_model.GetUnmergedPullRequestsByBaseInfo(db.DefaultContext, 1, "master")
assert.NoError(t, err)
assert.Len(t, prs, 1)
pr := prs[0]
@@ -242,13 +242,13 @@ func TestPullRequestList_LoadAttributes(t *testing.T) {
unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1}),
unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2}),
}
- assert.NoError(t, issues_model.PullRequestList(prs).LoadAttributes())
+ assert.NoError(t, issues_model.PullRequestList(prs).LoadAttributes(db.DefaultContext))
for _, pr := range prs {
assert.NotNil(t, pr.Issue)
assert.Equal(t, pr.IssueID, pr.Issue.ID)
}
- assert.NoError(t, issues_model.PullRequestList([]*issues_model.PullRequest{}).LoadAttributes())
+ assert.NoError(t, issues_model.PullRequestList([]*issues_model.PullRequest{}).LoadAttributes(db.DefaultContext))
}
// TODO TestAddTestPullRequestTask
diff --git a/models/issues/tracked_time.go b/models/issues/tracked_time.go
index d117b74bc0..58c6b775f0 100644
--- a/models/issues/tracked_time.go
+++ b/models/issues/tracked_time.go
@@ -43,11 +43,7 @@ func (t *TrackedTime) AfterLoad() {
}
// LoadAttributes load Issue, User
-func (t *TrackedTime) LoadAttributes() (err error) {
- return t.loadAttributes(db.DefaultContext)
-}
-
-func (t *TrackedTime) loadAttributes(ctx context.Context) (err error) {
+func (t *TrackedTime) LoadAttributes(ctx context.Context) (err error) {
// Load the issue
if t.Issue == nil {
t.Issue, err = GetIssueByID(ctx, t.IssueID)
@@ -76,9 +72,9 @@ func (t *TrackedTime) loadAttributes(ctx context.Context) (err error) {
}
// LoadAttributes load Issue, User
-func (tl TrackedTimeList) LoadAttributes() error {
+func (tl TrackedTimeList) LoadAttributes(ctx context.Context) error {
for _, t := range tl {
- if err := t.LoadAttributes(); err != nil {
+ if err := t.LoadAttributes(ctx); err != nil {
return err
}
}
@@ -143,8 +139,8 @@ func GetTrackedTimes(ctx context.Context, options *FindTrackedTimesOptions) (tra
}
// CountTrackedTimes returns count of tracked times that fit to the given options.
-func CountTrackedTimes(opts *FindTrackedTimesOptions) (int64, error) {
- sess := db.GetEngine(db.DefaultContext).Where(opts.toCond())
+func CountTrackedTimes(ctx context.Context, opts *FindTrackedTimesOptions) (int64, error) {
+ sess := db.GetEngine(ctx).Where(opts.toCond())
if opts.RepositoryID > 0 || opts.MilestoneID > 0 {
sess = sess.Join("INNER", "issue", "issue.id = tracked_time.issue_id")
}
@@ -157,8 +153,8 @@ func GetTrackedSeconds(ctx context.Context, opts FindTrackedTimesOptions) (track
}
// AddTime will add the given time (in seconds) to the issue
-func AddTime(user *user_model.User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) {
- ctx, committer, err := db.TxContext(db.DefaultContext)
+func AddTime(ctx context.Context, user *user_model.User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) {
+ ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
@@ -276,7 +272,7 @@ func DeleteTime(t *TrackedTime) error {
}
defer committer.Close()
- if err := t.loadAttributes(ctx); err != nil {
+ if err := t.LoadAttributes(ctx); err != nil {
return err
}
diff --git a/models/issues/tracked_time_test.go b/models/issues/tracked_time_test.go
index 37ba1cfdc4..1d88109183 100644
--- a/models/issues/tracked_time_test.go
+++ b/models/issues/tracked_time_test.go
@@ -25,7 +25,7 @@ func TestAddTime(t *testing.T) {
assert.NoError(t, err)
// 3661 = 1h 1min 1s
- trackedTime, err := issues_model.AddTime(user3, issue1, 3661, time.Now())
+ trackedTime, err := issues_model.AddTime(db.DefaultContext, user3, issue1, 3661, time.Now())
assert.NoError(t, err)
assert.Equal(t, int64(3), trackedTime.UserID)
assert.Equal(t, int64(1), trackedTime.IssueID)