diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/activities/action.go | 8 | ||||
-rw-r--r-- | models/activities/repo_activity.go | 54 | ||||
-rw-r--r-- | models/issues/comment_list.go | 11 | ||||
-rw-r--r-- | models/issues/issue.go | 14 | ||||
-rw-r--r-- | models/issues/issue_list.go | 8 | ||||
-rw-r--r-- | models/issues/issue_list_test.go | 2 | ||||
-rw-r--r-- | models/issues/issue_search.go | 2 | ||||
-rw-r--r-- | models/issues/pull_list.go | 23 | ||||
-rw-r--r-- | models/issues/pull_test.go | 8 | ||||
-rw-r--r-- | models/issues/tracked_time.go | 20 | ||||
-rw-r--r-- | models/issues/tracked_time_test.go | 2 | ||||
-rw-r--r-- | models/migrate.go | 4 | ||||
-rw-r--r-- | models/migrate_test.go | 2 |
13 files changed, 69 insertions, 89 deletions
diff --git a/models/activities/action.go b/models/activities/action.go index 57f579372f..7f22605d0d 100644 --- a/models/activities/action.go +++ b/models/activities/action.go @@ -391,10 +391,10 @@ func (a *Action) GetIssueInfos() []string { } // GetIssueTitle returns the title of first issue associated -// with the action. +// with the action. This function will be invoked in template so keep db.DefaultContext here func (a *Action) GetIssueTitle() string { index, _ := strconv.ParseInt(a.GetIssueInfos()[0], 10, 64) - issue, err := issues_model.GetIssueByIndex(a.RepoID, index) + issue, err := issues_model.GetIssueByIndex(db.DefaultContext, a.RepoID, index) if err != nil { log.Error("GetIssueByIndex: %v", err) return "500 when get issue" @@ -404,9 +404,9 @@ func (a *Action) GetIssueTitle() string { // GetIssueContent returns the content of first issue associated with // this action. -func (a *Action) GetIssueContent() string { +func (a *Action) GetIssueContent(ctx context.Context) string { index, _ := strconv.ParseInt(a.GetIssueInfos()[0], 10, 64) - issue, err := issues_model.GetIssueByIndex(a.RepoID, index) + issue, err := issues_model.GetIssueByIndex(ctx, a.RepoID, index) if err != nil { log.Error("GetIssueByIndex: %v", err) return "500 when get issue" diff --git a/models/activities/repo_activity.go b/models/activities/repo_activity.go index 72b6be3122..509f9caaf3 100644 --- a/models/activities/repo_activity.go +++ b/models/activities/repo_activity.go @@ -47,21 +47,21 @@ type ActivityStats struct { func GetActivityStats(ctx context.Context, repo *repo_model.Repository, timeFrom time.Time, releases, issues, prs, code bool) (*ActivityStats, error) { stats := &ActivityStats{Code: &git.CodeActivityStats{}} if releases { - if err := stats.FillReleases(repo.ID, timeFrom); err != nil { + if err := stats.FillReleases(ctx, repo.ID, timeFrom); err != nil { return nil, fmt.Errorf("FillReleases: %w", err) } } if prs { - if err := stats.FillPullRequests(repo.ID, timeFrom); err != nil { + if err := stats.FillPullRequests(ctx, repo.ID, timeFrom); err != nil { return nil, fmt.Errorf("FillPullRequests: %w", err) } } if issues { - if err := stats.FillIssues(repo.ID, timeFrom); err != nil { + if err := stats.FillIssues(ctx, repo.ID, timeFrom); err != nil { return nil, fmt.Errorf("FillIssues: %w", err) } } - if err := stats.FillUnresolvedIssues(repo.ID, timeFrom, issues, prs); err != nil { + if err := stats.FillUnresolvedIssues(ctx, repo.ID, timeFrom, issues, prs); err != nil { return nil, fmt.Errorf("FillUnresolvedIssues: %w", err) } if code { @@ -205,41 +205,41 @@ func (stats *ActivityStats) PublishedReleaseCount() int { } // FillPullRequests returns pull request information for activity page -func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) error { +func (stats *ActivityStats) FillPullRequests(ctx context.Context, repoID int64, fromTime time.Time) error { var err error var count int64 // Merged pull requests - sess := pullRequestsForActivityStatement(repoID, fromTime, true) + sess := pullRequestsForActivityStatement(ctx, repoID, fromTime, true) sess.OrderBy("pull_request.merged_unix DESC") stats.MergedPRs = make(issues_model.PullRequestList, 0) if err = sess.Find(&stats.MergedPRs); err != nil { return err } - if err = stats.MergedPRs.LoadAttributes(); err != nil { + if err = stats.MergedPRs.LoadAttributes(ctx); err != nil { return err } // Merged pull request authors - sess = pullRequestsForActivityStatement(repoID, fromTime, true) + sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, true) if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil { return err } stats.MergedPRAuthorCount = count // Opened pull requests - sess = pullRequestsForActivityStatement(repoID, fromTime, false) + sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false) sess.OrderBy("issue.created_unix ASC") stats.OpenedPRs = make(issues_model.PullRequestList, 0) if err = sess.Find(&stats.OpenedPRs); err != nil { return err } - if err = stats.OpenedPRs.LoadAttributes(); err != nil { + if err = stats.OpenedPRs.LoadAttributes(ctx); err != nil { return err } // Opened pull request authors - sess = pullRequestsForActivityStatement(repoID, fromTime, false) + sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false) if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil { return err } @@ -248,8 +248,8 @@ func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) e return nil } -func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged bool) *xorm.Session { - sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", repoID). +func pullRequestsForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, merged bool) *xorm.Session { + sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", repoID). Join("INNER", "issue", "pull_request.issue_id = issue.id") if merged { @@ -264,12 +264,12 @@ func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged b } // FillIssues returns issue information for activity page -func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error { +func (stats *ActivityStats) FillIssues(ctx context.Context, repoID int64, fromTime time.Time) error { var err error var count int64 // Closed issues - sess := issuesForActivityStatement(repoID, fromTime, true, false) + sess := issuesForActivityStatement(ctx, repoID, fromTime, true, false) sess.OrderBy("issue.closed_unix DESC") stats.ClosedIssues = make(issues_model.IssueList, 0) if err = sess.Find(&stats.ClosedIssues); err != nil { @@ -277,14 +277,14 @@ func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error { } // Closed issue authors - sess = issuesForActivityStatement(repoID, fromTime, true, false) + sess = issuesForActivityStatement(ctx, repoID, fromTime, true, false) if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil { return err } stats.ClosedIssueAuthorCount = count // New issues - sess = issuesForActivityStatement(repoID, fromTime, false, false) + sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false) sess.OrderBy("issue.created_unix ASC") stats.OpenedIssues = make(issues_model.IssueList, 0) if err = sess.Find(&stats.OpenedIssues); err != nil { @@ -292,7 +292,7 @@ func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error { } // Opened issue authors - sess = issuesForActivityStatement(repoID, fromTime, false, false) + sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false) if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil { return err } @@ -302,12 +302,12 @@ func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error { } // FillUnresolvedIssues returns unresolved issue and pull request information for activity page -func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Time, issues, prs bool) error { +func (stats *ActivityStats) FillUnresolvedIssues(ctx context.Context, repoID int64, fromTime time.Time, issues, prs bool) error { // Check if we need to select anything if !issues && !prs { return nil } - sess := issuesForActivityStatement(repoID, fromTime, false, true) + sess := issuesForActivityStatement(ctx, repoID, fromTime, false, true) if !issues || !prs { sess.And("issue.is_pull = ?", prs) } @@ -316,8 +316,8 @@ func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Tim return sess.Find(&stats.UnresolvedIssues) } -func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session { - sess := db.GetEngine(db.DefaultContext).Where("issue.repo_id = ?", repoID). +func issuesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session { + sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID). And("issue.is_closed = ?", closed) if !unresolved { @@ -336,12 +336,12 @@ func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unreso } // FillReleases returns release information for activity page -func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error { +func (stats *ActivityStats) FillReleases(ctx context.Context, repoID int64, fromTime time.Time) error { var err error var count int64 // Published releases list - sess := releasesForActivityStatement(repoID, fromTime) + sess := releasesForActivityStatement(ctx, repoID, fromTime) sess.OrderBy("release.created_unix DESC") stats.PublishedReleases = make([]*repo_model.Release, 0) if err = sess.Find(&stats.PublishedReleases); err != nil { @@ -349,7 +349,7 @@ func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error } // Published releases authors - sess = releasesForActivityStatement(repoID, fromTime) + sess = releasesForActivityStatement(ctx, repoID, fromTime) if _, err = sess.Select("count(distinct release.publisher_id) as `count`").Table("release").Get(&count); err != nil { return err } @@ -358,8 +358,8 @@ func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error return nil } -func releasesForActivityStatement(repoID int64, fromTime time.Time) *xorm.Session { - return db.GetEngine(db.DefaultContext).Where("release.repo_id = ?", repoID). +func releasesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session { + return db.GetEngine(ctx).Where("release.repo_id = ?", repoID). And("release.is_draft = ?", false). And("release.created_unix >= ?", fromTime.Unix()) } diff --git a/models/issues/comment_list.go b/models/issues/comment_list.go index e9c8406c3a..6f1d350eb4 100644 --- a/models/issues/comment_list.go +++ b/models/issues/comment_list.go @@ -465,8 +465,9 @@ func (comments CommentList) loadReviews(ctx context.Context) error { return nil } -// loadAttributes loads all attributes -func (comments CommentList) loadAttributes(ctx context.Context) (err error) { +// LoadAttributes loads attributes of the comments, except for attachments and +// comments +func (comments CommentList) LoadAttributes(ctx context.Context) (err error) { if err = comments.LoadPosters(ctx); err != nil { return err } @@ -501,9 +502,3 @@ func (comments CommentList) loadAttributes(ctx context.Context) (err error) { return comments.loadDependentIssues(ctx) } - -// LoadAttributes loads attributes of the comments, except for attachments and -// comments -func (comments CommentList) LoadAttributes() error { - return comments.loadAttributes(db.DefaultContext) -} diff --git a/models/issues/issue.go b/models/issues/issue.go index c1a802c792..9d60d011ed 100644 --- a/models/issues/issue.go +++ b/models/issues/issue.go @@ -354,7 +354,7 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) { return err } - if err = issue.Comments.loadAttributes(ctx); err != nil { + if err = issue.Comments.LoadAttributes(ctx); err != nil { return err } if issue.IsTimetrackerEnabled(ctx) { @@ -502,7 +502,7 @@ func (issue *Issue) GetLastEventLabelFake() string { } // GetIssueByIndex returns raw issue without loading attributes by index in a repository. -func GetIssueByIndex(repoID, index int64) (*Issue, error) { +func GetIssueByIndex(ctx context.Context, repoID, index int64) (*Issue, error) { if index < 1 { return nil, ErrIssueNotExist{} } @@ -510,7 +510,7 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) { RepoID: repoID, Index: index, } - has, err := db.GetEngine(db.DefaultContext).Get(issue) + has, err := db.GetEngine(ctx).Get(issue) if err != nil { return nil, err } else if !has { @@ -520,12 +520,12 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) { } // GetIssueWithAttrsByIndex returns issue by index in a repository. -func GetIssueWithAttrsByIndex(repoID, index int64) (*Issue, error) { - issue, err := GetIssueByIndex(repoID, index) +func GetIssueWithAttrsByIndex(ctx context.Context, repoID, index int64) (*Issue, error) { + issue, err := GetIssueByIndex(ctx, repoID, index) if err != nil { return nil, err } - return issue, issue.LoadAttributes(db.DefaultContext) + return issue, issue.LoadAttributes(ctx) } // GetIssueByID returns an issue by given ID. @@ -846,7 +846,7 @@ func GetPinnedIssues(ctx context.Context, repoID int64, isPull bool) ([]*Issue, return nil, err } - err = IssueList(issues).LoadAttributes() + err = IssueList(issues).LoadAttributes(ctx) if err != nil { return nil, err } diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go index 9cc41ec6ab..a932ac2554 100644 --- a/models/issues/issue_list.go +++ b/models/issues/issue_list.go @@ -526,7 +526,7 @@ func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) { } // loadAttributes loads all attributes, expect for attachments and comments -func (issues IssueList) loadAttributes(ctx context.Context) error { +func (issues IssueList) LoadAttributes(ctx context.Context) error { if _, err := issues.LoadRepositories(ctx); err != nil { return fmt.Errorf("issue.loadAttributes: LoadRepositories: %w", err) } @@ -562,12 +562,6 @@ func (issues IssueList) loadAttributes(ctx context.Context) error { return nil } -// LoadAttributes loads attributes of the issues, except for attachments and -// comments -func (issues IssueList) LoadAttributes() error { - return issues.loadAttributes(db.DefaultContext) -} - // LoadComments loads comments func (issues IssueList) LoadComments(ctx context.Context) error { return issues.loadComments(ctx, builder.NewCond()) diff --git a/models/issues/issue_list_test.go b/models/issues/issue_list_test.go index 696c3b765d..9069e1012d 100644 --- a/models/issues/issue_list_test.go +++ b/models/issues/issue_list_test.go @@ -39,7 +39,7 @@ func TestIssueList_LoadAttributes(t *testing.T) { unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 4}), } - assert.NoError(t, issueList.LoadAttributes()) + assert.NoError(t, issueList.LoadAttributes(db.DefaultContext)) for _, issue := range issueList { assert.EqualValues(t, issue.RepoID, issue.Repo.ID) for _, label := range issue.Labels { diff --git a/models/issues/issue_search.go b/models/issues/issue_search.go index 9fd13f0995..6540ce02c0 100644 --- a/models/issues/issue_search.go +++ b/models/issues/issue_search.go @@ -440,7 +440,7 @@ func Issues(ctx context.Context, opts *IssuesOptions) ([]*Issue, error) { return nil, fmt.Errorf("unable to query Issues: %w", err) } - if err := issues.LoadAttributes(); err != nil { + if err := issues.LoadAttributes(ctx); err != nil { return nil, fmt.Errorf("unable to LoadAttributes for Issues: %w", err) } diff --git a/models/issues/pull_list.go b/models/issues/pull_list.go index c443928344..3b2416900b 100644 --- a/models/issues/pull_list.go +++ b/models/issues/pull_list.go @@ -51,16 +51,16 @@ func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xor } // GetUnmergedPullRequestsByHeadInfo returns all pull requests that are open and has not been merged -func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) { +func GetUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) { prs := make([]*PullRequest, 0, 2) - sess := db.GetEngine(db.DefaultContext). + sess := db.GetEngine(ctx). Join("INNER", "issue", "issue.id = pull_request.issue_id"). Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ? AND flow = ?", repoID, branch, false, false, PullRequestFlowGithub) return prs, sess.Find(&prs) } // CanMaintainerWriteToBranch check whether user is a maintainer and could write to the branch -func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *user_model.User) bool { +func CanMaintainerWriteToBranch(ctx context.Context, p access_model.Permission, branch string, user *user_model.User) bool { if p.CanWrite(unit.TypeCode) { return true } @@ -69,18 +69,18 @@ func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user * return false } - prs, err := GetUnmergedPullRequestsByHeadInfo(p.Units[0].RepoID, branch) + prs, err := GetUnmergedPullRequestsByHeadInfo(ctx, p.Units[0].RepoID, branch) if err != nil { return false } for _, pr := range prs { if pr.AllowMaintainerEdit { - err = pr.LoadBaseRepo(db.DefaultContext) + err = pr.LoadBaseRepo(ctx) if err != nil { continue } - prPerm, err := access_model.GetUserRepoPermission(db.DefaultContext, pr.BaseRepo, user) + prPerm, err := access_model.GetUserRepoPermission(ctx, pr.BaseRepo, user) if err != nil { continue } @@ -104,9 +104,9 @@ func HasUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch // GetUnmergedPullRequestsByBaseInfo returns all pull requests that are open and has not been merged // by given base information (repo and branch). -func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) { +func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) { prs := make([]*PullRequest, 0, 2) - return prs, db.GetEngine(db.DefaultContext). + return prs, db.GetEngine(ctx). Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?", repoID, branch, false, false). OrderBy("issue.updated_unix DESC"). @@ -154,7 +154,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, // PullRequestList defines a list of pull requests type PullRequestList []*PullRequest -func (prs PullRequestList) loadAttributes(ctx context.Context) error { +func (prs PullRequestList) LoadAttributes(ctx context.Context) error { if len(prs) == 0 { return nil } @@ -199,8 +199,3 @@ func (prs PullRequestList) GetIssueIDs() []int64 { } return issueIDs } - -// LoadAttributes load all the prs attributes -func (prs PullRequestList) LoadAttributes() error { - return prs.loadAttributes(db.DefaultContext) -} diff --git a/models/issues/pull_test.go b/models/issues/pull_test.go index 5856b5dc58..0990a3b870 100644 --- a/models/issues/pull_test.go +++ b/models/issues/pull_test.go @@ -148,7 +148,7 @@ func TestHasUnmergedPullRequestsByHeadInfo(t *testing.T) { func TestGetUnmergedPullRequestsByHeadInfo(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - prs, err := issues_model.GetUnmergedPullRequestsByHeadInfo(1, "branch2") + prs, err := issues_model.GetUnmergedPullRequestsByHeadInfo(db.DefaultContext, 1, "branch2") assert.NoError(t, err) assert.Len(t, prs, 1) for _, pr := range prs { @@ -159,7 +159,7 @@ func TestGetUnmergedPullRequestsByHeadInfo(t *testing.T) { func TestGetUnmergedPullRequestsByBaseInfo(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - prs, err := issues_model.GetUnmergedPullRequestsByBaseInfo(1, "master") + prs, err := issues_model.GetUnmergedPullRequestsByBaseInfo(db.DefaultContext, 1, "master") assert.NoError(t, err) assert.Len(t, prs, 1) pr := prs[0] @@ -242,13 +242,13 @@ func TestPullRequestList_LoadAttributes(t *testing.T) { unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 1}), unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{ID: 2}), } - assert.NoError(t, issues_model.PullRequestList(prs).LoadAttributes()) + assert.NoError(t, issues_model.PullRequestList(prs).LoadAttributes(db.DefaultContext)) for _, pr := range prs { assert.NotNil(t, pr.Issue) assert.Equal(t, pr.IssueID, pr.Issue.ID) } - assert.NoError(t, issues_model.PullRequestList([]*issues_model.PullRequest{}).LoadAttributes()) + assert.NoError(t, issues_model.PullRequestList([]*issues_model.PullRequest{}).LoadAttributes(db.DefaultContext)) } // TODO TestAddTestPullRequestTask diff --git a/models/issues/tracked_time.go b/models/issues/tracked_time.go index d117b74bc0..58c6b775f0 100644 --- a/models/issues/tracked_time.go +++ b/models/issues/tracked_time.go @@ -43,11 +43,7 @@ func (t *TrackedTime) AfterLoad() { } // LoadAttributes load Issue, User -func (t *TrackedTime) LoadAttributes() (err error) { - return t.loadAttributes(db.DefaultContext) -} - -func (t *TrackedTime) loadAttributes(ctx context.Context) (err error) { +func (t *TrackedTime) LoadAttributes(ctx context.Context) (err error) { // Load the issue if t.Issue == nil { t.Issue, err = GetIssueByID(ctx, t.IssueID) @@ -76,9 +72,9 @@ func (t *TrackedTime) loadAttributes(ctx context.Context) (err error) { } // LoadAttributes load Issue, User -func (tl TrackedTimeList) LoadAttributes() error { +func (tl TrackedTimeList) LoadAttributes(ctx context.Context) error { for _, t := range tl { - if err := t.LoadAttributes(); err != nil { + if err := t.LoadAttributes(ctx); err != nil { return err } } @@ -143,8 +139,8 @@ func GetTrackedTimes(ctx context.Context, options *FindTrackedTimesOptions) (tra } // CountTrackedTimes returns count of tracked times that fit to the given options. -func CountTrackedTimes(opts *FindTrackedTimesOptions) (int64, error) { - sess := db.GetEngine(db.DefaultContext).Where(opts.toCond()) +func CountTrackedTimes(ctx context.Context, opts *FindTrackedTimesOptions) (int64, error) { + sess := db.GetEngine(ctx).Where(opts.toCond()) if opts.RepositoryID > 0 || opts.MilestoneID > 0 { sess = sess.Join("INNER", "issue", "issue.id = tracked_time.issue_id") } @@ -157,8 +153,8 @@ func GetTrackedSeconds(ctx context.Context, opts FindTrackedTimesOptions) (track } // AddTime will add the given time (in seconds) to the issue -func AddTime(user *user_model.User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func AddTime(ctx context.Context, user *user_model.User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return nil, err } @@ -276,7 +272,7 @@ func DeleteTime(t *TrackedTime) error { } defer committer.Close() - if err := t.loadAttributes(ctx); err != nil { + if err := t.LoadAttributes(ctx); err != nil { return err } diff --git a/models/issues/tracked_time_test.go b/models/issues/tracked_time_test.go index 37ba1cfdc4..1d88109183 100644 --- a/models/issues/tracked_time_test.go +++ b/models/issues/tracked_time_test.go @@ -25,7 +25,7 @@ func TestAddTime(t *testing.T) { assert.NoError(t, err) // 3661 = 1h 1min 1s - trackedTime, err := issues_model.AddTime(user3, issue1, 3661, time.Now()) + trackedTime, err := issues_model.AddTime(db.DefaultContext, user3, issue1, 3661, time.Now()) assert.NoError(t, err) assert.Equal(t, int64(3), trackedTime.UserID) assert.Equal(t, int64(1), trackedTime.IssueID) diff --git a/models/migrate.go b/models/migrate.go index 82cacd4a75..9705d0ad04 100644 --- a/models/migrate.go +++ b/models/migrate.go @@ -128,8 +128,8 @@ func InsertIssueComments(comments []*issues_model.Comment) error { } // InsertPullRequests inserted pull requests -func InsertPullRequests(prs ...*issues_model.PullRequest) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func InsertPullRequests(ctx context.Context, prs ...*issues_model.PullRequest) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/migrate_test.go b/models/migrate_test.go index 42102f9a7d..74736a2849 100644 --- a/models/migrate_test.go +++ b/models/migrate_test.go @@ -122,7 +122,7 @@ func TestMigrate_InsertPullRequests(t *testing.T) { Issue: i, } - err := InsertPullRequests(p) + err := InsertPullRequests(db.DefaultContext, p) assert.NoError(t, err) _ = unittest.AssertExistsAndLoadBean(t, &issues_model.PullRequest{IssueID: i.ID}) |