diff options
author | JakobDev <jakobdev@gmx.de> | 2023-09-16 16:39:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-16 14:39:12 +0000 |
commit | f91dbbba98c841f11d99be998ed5dd98122a457c (patch) | |
tree | 9c6c935ccf745c5a1716f1330922354809cd39e0 /models | |
parent | a1b2a118123e0abd1d4737f4a6c0cf56d15eff57 (diff) | |
download | gitea-f91dbbba98c841f11d99be998ed5dd98122a457c.tar.gz gitea-f91dbbba98c841f11d99be998ed5dd98122a457c.zip |
Next round of `db.DefaultContext` refactor (#27089)
Part of #27065
Diffstat (limited to 'models')
27 files changed, 236 insertions, 273 deletions
diff --git a/models/actions/schedule.go b/models/actions/schedule.go index b0bc40dadc..34d23f1c01 100644 --- a/models/actions/schedule.go +++ b/models/actions/schedule.go @@ -41,15 +41,15 @@ func init() { } // GetSchedulesMapByIDs returns the schedules by given id slice. -func GetSchedulesMapByIDs(ids []int64) (map[int64]*ActionSchedule, error) { +func GetSchedulesMapByIDs(ctx context.Context, ids []int64) (map[int64]*ActionSchedule, error) { schedules := make(map[int64]*ActionSchedule, len(ids)) - return schedules, db.GetEngine(db.DefaultContext).In("id", ids).Find(&schedules) + return schedules, db.GetEngine(ctx).In("id", ids).Find(&schedules) } // GetReposMapByIDs returns the repos by given id slice. -func GetReposMapByIDs(ids []int64) (map[int64]*repo_model.Repository, error) { +func GetReposMapByIDs(ctx context.Context, ids []int64) (map[int64]*repo_model.Repository, error) { repos := make(map[int64]*repo_model.Repository, len(ids)) - return repos, db.GetEngine(db.DefaultContext).In("id", ids).Find(&repos) + return repos, db.GetEngine(ctx).In("id", ids).Find(&repos) } var cronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor) diff --git a/models/actions/schedule_spec_list.go b/models/actions/schedule_spec_list.go index d379490b4e..2c017fdabc 100644 --- a/models/actions/schedule_spec_list.go +++ b/models/actions/schedule_spec_list.go @@ -23,9 +23,9 @@ func (specs SpecList) GetScheduleIDs() []int64 { return ids.Values() } -func (specs SpecList) LoadSchedules() error { +func (specs SpecList) LoadSchedules(ctx context.Context) error { scheduleIDs := specs.GetScheduleIDs() - schedules, err := GetSchedulesMapByIDs(scheduleIDs) + schedules, err := GetSchedulesMapByIDs(ctx, scheduleIDs) if err != nil { return err } @@ -34,7 +34,7 @@ func (specs SpecList) LoadSchedules() error { } repoIDs := specs.GetRepoIDs() - repos, err := GetReposMapByIDs(repoIDs) + repos, err := GetReposMapByIDs(ctx, repoIDs) if err != nil { return err } @@ -95,7 +95,7 @@ func FindSpecs(ctx context.Context, opts FindSpecOptions) (SpecList, int64, erro return nil, 0, err } - if err := specs.LoadSchedules(); err != nil { + if err := specs.LoadSchedules(ctx); err != nil { return nil, 0, err } return specs, total, nil diff --git a/models/admin/task.go b/models/admin/task.go index 8aa397ad35..c8bc95f981 100644 --- a/models/admin/task.go +++ b/models/admin/task.go @@ -48,11 +48,7 @@ type TranslatableMessage struct { } // LoadRepo loads repository of the task -func (task *Task) LoadRepo() error { - return task.loadRepo(db.DefaultContext) -} - -func (task *Task) loadRepo(ctx context.Context) error { +func (task *Task) LoadRepo(ctx context.Context) error { if task.Repo != nil { return nil } @@ -70,13 +66,13 @@ func (task *Task) loadRepo(ctx context.Context) error { } // LoadDoer loads do user -func (task *Task) LoadDoer() error { +func (task *Task) LoadDoer(ctx context.Context) error { if task.Doer != nil { return nil } var doer user_model.User - has, err := db.GetEngine(db.DefaultContext).ID(task.DoerID).Get(&doer) + has, err := db.GetEngine(ctx).ID(task.DoerID).Get(&doer) if err != nil { return err } else if !has { @@ -90,13 +86,13 @@ func (task *Task) LoadDoer() error { } // LoadOwner loads owner user -func (task *Task) LoadOwner() error { +func (task *Task) LoadOwner(ctx context.Context) error { if task.Owner != nil { return nil } var owner user_model.User - has, err := db.GetEngine(db.DefaultContext).ID(task.OwnerID).Get(&owner) + has, err := db.GetEngine(ctx).ID(task.OwnerID).Get(&owner) if err != nil { return err } else if !has { @@ -110,8 +106,8 @@ func (task *Task) LoadOwner() error { } // UpdateCols updates some columns -func (task *Task) UpdateCols(cols ...string) error { - _, err := db.GetEngine(db.DefaultContext).ID(task.ID).Cols(cols...).Update(task) +func (task *Task) UpdateCols(ctx context.Context, cols ...string) error { + _, err := db.GetEngine(ctx).ID(task.ID).Cols(cols...).Update(task) return err } @@ -169,12 +165,12 @@ func (err ErrTaskDoesNotExist) Unwrap() error { } // GetMigratingTask returns the migrating task by repo's id -func GetMigratingTask(repoID int64) (*Task, error) { +func GetMigratingTask(ctx context.Context, repoID int64) (*Task, error) { task := Task{ RepoID: repoID, Type: structs.TaskTypeMigrateRepo, } - has, err := db.GetEngine(db.DefaultContext).Get(&task) + has, err := db.GetEngine(ctx).Get(&task) if err != nil { return nil, err } else if !has { @@ -184,13 +180,13 @@ func GetMigratingTask(repoID int64) (*Task, error) { } // GetMigratingTaskByID returns the migrating task by repo's id -func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, error) { +func GetMigratingTaskByID(ctx context.Context, id, doerID int64) (*Task, *migration.MigrateOptions, error) { task := Task{ ID: id, DoerID: doerID, Type: structs.TaskTypeMigrateRepo, } - has, err := db.GetEngine(db.DefaultContext).Get(&task) + has, err := db.GetEngine(ctx).Get(&task) if err != nil { return nil, nil, err } else if !has { @@ -205,12 +201,12 @@ func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, e } // CreateTask creates a task on database -func CreateTask(task *Task) error { - return db.Insert(db.DefaultContext, task) +func CreateTask(ctx context.Context, task *Task) error { + return db.Insert(ctx, task) } // FinishMigrateTask updates database when migrate task finished -func FinishMigrateTask(task *Task) error { +func FinishMigrateTask(ctx context.Context, task *Task) error { task.Status = structs.TaskStatusFinished task.EndTime = timeutil.TimeStampNow() @@ -231,6 +227,6 @@ func FinishMigrateTask(task *Task) error { } task.PayloadContent = string(confBytes) - _, err = db.GetEngine(db.DefaultContext).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task) + _, err = db.GetEngine(ctx).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task) return err } diff --git a/models/auth/session.go b/models/auth/session.go index b60e6a903b..28f25170ee 100644 --- a/models/auth/session.go +++ b/models/auth/session.go @@ -4,6 +4,7 @@ package auth import ( + "context" "fmt" "code.gitea.io/gitea/models/db" @@ -22,8 +23,8 @@ func init() { } // UpdateSession updates the session with provided id -func UpdateSession(key string, data []byte) error { - _, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{ +func UpdateSession(ctx context.Context, key string, data []byte) error { + _, err := db.GetEngine(ctx).ID(key).Update(&Session{ Data: data, Expiry: timeutil.TimeStampNow(), }) @@ -31,12 +32,12 @@ func UpdateSession(key string, data []byte) error { } // ReadSession reads the data for the provided session -func ReadSession(key string) (*Session, error) { +func ReadSession(ctx context.Context, key string) (*Session, error) { session := Session{ Key: key, } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return nil, err } @@ -55,24 +56,24 @@ func ReadSession(key string) (*Session, error) { } // ExistSession checks if a session exists -func ExistSession(key string) (bool, error) { +func ExistSession(ctx context.Context, key string) (bool, error) { session := Session{ Key: key, } - return db.GetEngine(db.DefaultContext).Get(&session) + return db.GetEngine(ctx).Get(&session) } // DestroySession destroys a session -func DestroySession(key string) error { - _, err := db.GetEngine(db.DefaultContext).Delete(&Session{ +func DestroySession(ctx context.Context, key string) error { + _, err := db.GetEngine(ctx).Delete(&Session{ Key: key, }) return err } // RegenerateSession regenerates a session from the old id -func RegenerateSession(oldKey, newKey string) (*Session, error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return nil, err } @@ -114,12 +115,12 @@ func RegenerateSession(oldKey, newKey string) (*Session, error) { } // CountSessions returns the number of sessions -func CountSessions() (int64, error) { - return db.GetEngine(db.DefaultContext).Count(&Session{}) +func CountSessions(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Count(&Session{}) } // CleanupSessions cleans up expired sessions -func CleanupSessions(maxLifetime int64) error { - _, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{}) +func CleanupSessions(ctx context.Context, maxLifetime int64) error { + _, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{}) return err } diff --git a/models/auth/webauthn.go b/models/auth/webauthn.go index db5dd7eea5..d12713bd37 100644 --- a/models/auth/webauthn.go +++ b/models/auth/webauthn.go @@ -67,11 +67,7 @@ func (cred WebAuthnCredential) TableName() string { } // UpdateSignCount will update the database value of SignCount -func (cred *WebAuthnCredential) UpdateSignCount() error { - return cred.updateSignCount(db.DefaultContext) -} - -func (cred *WebAuthnCredential) updateSignCount(ctx context.Context) error { +func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error { _, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred) return err } @@ -113,30 +109,18 @@ func (list WebAuthnCredentialList) ToCredentials() []webauthn.Credential { } // GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user -func GetWebAuthnCredentialsByUID(uid int64) (WebAuthnCredentialList, error) { - return getWebAuthnCredentialsByUID(db.DefaultContext, uid) -} - -func getWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) { +func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) { creds := make(WebAuthnCredentialList, 0) return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds) } // ExistsWebAuthnCredentialsForUID returns if the given user has credentials -func ExistsWebAuthnCredentialsForUID(uid int64) (bool, error) { - return existsWebAuthnCredentialsByUID(db.DefaultContext, uid) -} - -func existsWebAuthnCredentialsByUID(ctx context.Context, uid int64) (bool, error) { +func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) { return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) } // GetWebAuthnCredentialByName returns WebAuthn credential by id -func GetWebAuthnCredentialByName(uid int64, name string) (*WebAuthnCredential, error) { - return getWebAuthnCredentialByName(db.DefaultContext, uid, name) -} - -func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) { +func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) { cred := new(WebAuthnCredential) if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil { return nil, err @@ -147,11 +131,7 @@ func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (* } // GetWebAuthnCredentialByID returns WebAuthn credential by id -func GetWebAuthnCredentialByID(id int64) (*WebAuthnCredential, error) { - return getWebAuthnCredentialByID(db.DefaultContext, id) -} - -func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) { +func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) { cred := new(WebAuthnCredential) if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil { return nil, err @@ -162,16 +142,12 @@ func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredenti } // HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations -func HasWebAuthnRegistrationsByUID(uid int64) (bool, error) { - return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) +func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) { + return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) } // GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID -func GetWebAuthnCredentialByCredID(userID int64, credID []byte) (*WebAuthnCredential, error) { - return getWebAuthnCredentialByCredID(db.DefaultContext, userID, credID) -} - -func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) { +func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) { cred := new(WebAuthnCredential) if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil { return nil, err @@ -182,11 +158,7 @@ func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []b } // CreateCredential will create a new WebAuthnCredential from the given Credential -func CreateCredential(userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) { - return createCredential(db.DefaultContext, userID, name, cred) -} - -func createCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) { +func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) { c := &WebAuthnCredential{ UserID: userID, Name: name, @@ -205,18 +177,14 @@ func createCredential(ctx context.Context, userID int64, name string, cred *weba } // DeleteCredential will delete WebAuthnCredential -func DeleteCredential(id, userID int64) (bool, error) { - return deleteCredential(db.DefaultContext, id, userID) -} - -func deleteCredential(ctx context.Context, id, userID int64) (bool, error) { +func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) { had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{}) return had > 0, err } // WebAuthnCredentials implementns the webauthn.User interface -func WebAuthnCredentials(userID int64) ([]webauthn.Credential, error) { - dbCreds, err := GetWebAuthnCredentialsByUID(userID) +func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) { + dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID) if err != nil { return nil, err } diff --git a/models/auth/webauthn_test.go b/models/auth/webauthn_test.go index 6f2ec087c7..f1cf398adf 100644 --- a/models/auth/webauthn_test.go +++ b/models/auth/webauthn_test.go @@ -7,6 +7,7 @@ import ( "testing" auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" "github.com/go-webauthn/webauthn/webauthn" @@ -16,11 +17,11 @@ import ( func TestGetWebAuthnCredentialByID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - res, err := auth_model.GetWebAuthnCredentialByID(1) + res, err := auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 1) assert.NoError(t, err) assert.Equal(t, "WebAuthn credential", res.Name) - _, err = auth_model.GetWebAuthnCredentialByID(342432) + _, err = auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 342432) assert.Error(t, err) assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err)) } @@ -28,7 +29,7 @@ func TestGetWebAuthnCredentialByID(t *testing.T) { func TestGetWebAuthnCredentialsByUID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - res, err := auth_model.GetWebAuthnCredentialsByUID(32) + res, err := auth_model.GetWebAuthnCredentialsByUID(db.DefaultContext, 32) assert.NoError(t, err) assert.Len(t, res, 1) assert.Equal(t, "WebAuthn credential", res[0].Name) @@ -42,7 +43,7 @@ func TestWebAuthnCredential_UpdateSignCount(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1}) cred.SignCount = 1 - assert.NoError(t, cred.UpdateSignCount()) + assert.NoError(t, cred.UpdateSignCount(db.DefaultContext)) unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1}) } @@ -50,14 +51,14 @@ func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1}) cred.SignCount = 0xffffffff - assert.NoError(t, cred.UpdateSignCount()) + assert.NoError(t, cred.UpdateSignCount(db.DefaultContext)) unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff}) } func TestCreateCredential(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - res, err := auth_model.CreateCredential(1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")}) + res, err := auth_model.CreateCredential(db.DefaultContext, 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")}) assert.NoError(t, err) assert.Equal(t, "WebAuthn Created Credential", res.Name) assert.Equal(t, []byte("Test"), res.CredentialID) diff --git a/models/issues/issue_test.go b/models/issues/issue_test.go index 747fbbc78c..b7fa7eff1c 100644 --- a/models/issues/issue_test.go +++ b/models/issues/issue_test.go @@ -385,7 +385,7 @@ func TestMilestoneList_LoadTotalTrackedTimes(t *testing.T) { unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}), } - assert.NoError(t, miles.LoadTotalTrackedTimes()) + assert.NoError(t, miles.LoadTotalTrackedTimes(db.DefaultContext)) assert.Equal(t, int64(3682), miles[0].TotalTrackedTime) } @@ -394,7 +394,7 @@ func TestLoadTotalTrackedTime(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) - assert.NoError(t, milestone.LoadTotalTrackedTime()) + assert.NoError(t, milestone.LoadTotalTrackedTime(db.DefaultContext)) assert.Equal(t, int64(3682), milestone.TotalTrackedTime) } diff --git a/models/issues/issue_watch.go b/models/issues/issue_watch.go index 1efc0ea687..b7e9504c67 100644 --- a/models/issues/issue_watch.go +++ b/models/issues/issue_watch.go @@ -30,8 +30,8 @@ func init() { type IssueWatchList []*IssueWatch // CreateOrUpdateIssueWatch set watching for a user and issue -func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { - iw, exists, err := GetIssueWatch(db.DefaultContext, userID, issueID) +func CreateOrUpdateIssueWatch(ctx context.Context, userID, issueID int64, isWatching bool) error { + iw, exists, err := GetIssueWatch(ctx, userID, issueID) if err != nil { return err } @@ -43,13 +43,13 @@ func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { IsWatching: isWatching, } - if _, err := db.GetEngine(db.DefaultContext).Insert(iw); err != nil { + if _, err := db.GetEngine(ctx).Insert(iw); err != nil { return err } } else { iw.IsWatching = isWatching - if _, err := db.GetEngine(db.DefaultContext).ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil { + if _, err := db.GetEngine(ctx).ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil { return err } } @@ -69,15 +69,15 @@ func GetIssueWatch(ctx context.Context, userID, issueID int64) (iw *IssueWatch, // CheckIssueWatch check if an user is watching an issue // it takes participants and repo watch into account -func CheckIssueWatch(user *user_model.User, issue *Issue) (bool, error) { - iw, exist, err := GetIssueWatch(db.DefaultContext, user.ID, issue.ID) +func CheckIssueWatch(ctx context.Context, user *user_model.User, issue *Issue) (bool, error) { + iw, exist, err := GetIssueWatch(ctx, user.ID, issue.ID) if err != nil { return false, err } if exist { return iw.IsWatching, nil } - w, err := repo_model.GetWatch(db.DefaultContext, user.ID, issue.RepoID) + w, err := repo_model.GetWatch(ctx, user.ID, issue.RepoID) if err != nil { return false, err } diff --git a/models/issues/issue_watch_test.go b/models/issues/issue_watch_test.go index 4f44487f56..d4ce8d8d3d 100644 --- a/models/issues/issue_watch_test.go +++ b/models/issues/issue_watch_test.go @@ -16,11 +16,11 @@ import ( func TestCreateOrUpdateIssueWatch(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(3, 1, true)) + assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(db.DefaultContext, 3, 1, true)) iw := unittest.AssertExistsAndLoadBean(t, &issues_model.IssueWatch{UserID: 3, IssueID: 1}) assert.True(t, iw.IsWatching) - assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(1, 1, false)) + assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(db.DefaultContext, 1, 1, false)) iw = unittest.AssertExistsAndLoadBean(t, &issues_model.IssueWatch{UserID: 1, IssueID: 1}) assert.False(t, iw.IsWatching) } diff --git a/models/issues/label.go b/models/issues/label.go index 0087c933a6..f8dbb9e39c 100644 --- a/models/issues/label.go +++ b/models/issues/label.go @@ -199,8 +199,8 @@ func NewLabel(ctx context.Context, l *Label) error { } // NewLabels creates new labels -func NewLabels(labels ...*Label) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func NewLabels(ctx context.Context, labels ...*Label) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -221,19 +221,19 @@ func NewLabels(labels ...*Label) error { } // UpdateLabel updates label information. -func UpdateLabel(l *Label) error { +func UpdateLabel(ctx context.Context, l *Label) error { color, err := label.NormalizeColor(l.Color) if err != nil { return err } l.Color = color - return updateLabelCols(db.DefaultContext, l, "name", "description", "color", "exclusive", "archived_unix") + return updateLabelCols(ctx, l, "name", "description", "color", "exclusive", "archived_unix") } // DeleteLabel delete a label -func DeleteLabel(id, labelID int64) error { - l, err := GetLabelByID(db.DefaultContext, labelID) +func DeleteLabel(ctx context.Context, id, labelID int64) error { + l, err := GetLabelByID(ctx, labelID) if err != nil { if IsErrLabelNotExist(err) { return nil @@ -241,7 +241,7 @@ func DeleteLabel(id, labelID int64) error { return err } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -289,9 +289,9 @@ func GetLabelByID(ctx context.Context, labelID int64) (*Label, error) { } // GetLabelsByIDs returns a list of labels by IDs -func GetLabelsByIDs(labelIDs []int64, cols ...string) ([]*Label, error) { +func GetLabelsByIDs(ctx context.Context, labelIDs []int64, cols ...string) ([]*Label, error) { labels := make([]*Label, 0, len(labelIDs)) - return labels, db.GetEngine(db.DefaultContext).Table("label"). + return labels, db.GetEngine(ctx).Table("label"). In("id", labelIDs). Asc("name"). Cols(cols...). @@ -339,9 +339,9 @@ func GetLabelInRepoByID(ctx context.Context, repoID, labelID int64) (*Label, err // GetLabelIDsInRepoByNames returns a list of labelIDs by names in a given // repository. // it silently ignores label names that do not belong to the repository. -func GetLabelIDsInRepoByNames(repoID int64, labelNames []string) ([]int64, error) { +func GetLabelIDsInRepoByNames(ctx context.Context, repoID int64, labelNames []string) ([]int64, error) { labelIDs := make([]int64, 0, len(labelNames)) - return labelIDs, db.GetEngine(db.DefaultContext).Table("label"). + return labelIDs, db.GetEngine(ctx).Table("label"). Where("repo_id = ?", repoID). In("name", labelNames). Asc("name"). @@ -398,8 +398,8 @@ func GetLabelsByRepoID(ctx context.Context, repoID int64, sortType string, listO } // CountLabelsByRepoID count number of all labels that belong to given repository by ID. -func CountLabelsByRepoID(repoID int64) (int64, error) { - return db.GetEngine(db.DefaultContext).Where("repo_id = ?", repoID).Count(&Label{}) +func CountLabelsByRepoID(ctx context.Context, repoID int64) (int64, error) { + return db.GetEngine(ctx).Where("repo_id = ?", repoID).Count(&Label{}) } // GetLabelInOrgByName returns a label by name in given organization. @@ -442,13 +442,13 @@ func GetLabelInOrgByID(ctx context.Context, orgID, labelID int64) (*Label, error // GetLabelIDsInOrgByNames returns a list of labelIDs by names in a given // organization. -func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error) { +func GetLabelIDsInOrgByNames(ctx context.Context, orgID int64, labelNames []string) ([]int64, error) { if orgID <= 0 { return nil, ErrOrgLabelNotExist{0, orgID} } labelIDs := make([]int64, 0, len(labelNames)) - return labelIDs, db.GetEngine(db.DefaultContext).Table("label"). + return labelIDs, db.GetEngine(ctx).Table("label"). Where("org_id = ?", orgID). In("name", labelNames). Asc("name"). @@ -506,8 +506,8 @@ func GetLabelIDsByNames(ctx context.Context, labelNames []string) ([]int64, erro } // CountLabelsByOrgID count all labels that belong to given organization by ID. -func CountLabelsByOrgID(orgID int64) (int64, error) { - return db.GetEngine(db.DefaultContext).Where("org_id = ?", orgID).Count(&Label{}) +func CountLabelsByOrgID(ctx context.Context, orgID int64) (int64, error) { + return db.GetEngine(ctx).Where("org_id = ?", orgID).Count(&Label{}) } func updateLabelCols(ctx context.Context, l *Label, cols ...string) error { diff --git a/models/issues/label_test.go b/models/issues/label_test.go index 3f0e980b31..9f44cd3e03 100644 --- a/models/issues/label_test.go +++ b/models/issues/label_test.go @@ -48,7 +48,7 @@ func TestNewLabels(t *testing.T) { for _, label := range labels { unittest.AssertNotExistsBean(t, label) } - assert.NoError(t, issues_model.NewLabels(labels...)) + assert.NoError(t, issues_model.NewLabels(db.DefaultContext, labels...)) for _, label := range labels { unittest.AssertExistsAndLoadBean(t, label, unittest.Cond("id = ?", label.ID)) } @@ -81,7 +81,7 @@ func TestGetLabelInRepoByName(t *testing.T) { func TestGetLabelInRepoByNames(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - labelIDs, err := issues_model.GetLabelIDsInRepoByNames(1, []string{"label1", "label2"}) + labelIDs, err := issues_model.GetLabelIDsInRepoByNames(db.DefaultContext, 1, []string{"label1", "label2"}) assert.NoError(t, err) assert.Len(t, labelIDs, 2) @@ -93,7 +93,7 @@ func TestGetLabelInRepoByNames(t *testing.T) { func TestGetLabelInRepoByNamesDiscardsNonExistentLabels(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) // label3 doesn't exists.. See labels.yml - labelIDs, err := issues_model.GetLabelIDsInRepoByNames(1, []string{"label1", "label2", "label3"}) + labelIDs, err := issues_model.GetLabelIDsInRepoByNames(db.DefaultContext, 1, []string{"label1", "label2", "label3"}) assert.NoError(t, err) assert.Len(t, labelIDs, 2) @@ -166,7 +166,7 @@ func TestGetLabelInOrgByName(t *testing.T) { func TestGetLabelInOrgByNames(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - labelIDs, err := issues_model.GetLabelIDsInOrgByNames(3, []string{"orglabel3", "orglabel4"}) + labelIDs, err := issues_model.GetLabelIDsInOrgByNames(db.DefaultContext, 3, []string{"orglabel3", "orglabel4"}) assert.NoError(t, err) assert.Len(t, labelIDs, 2) @@ -178,7 +178,7 @@ func TestGetLabelInOrgByNames(t *testing.T) { func TestGetLabelInOrgByNamesDiscardsNonExistentLabels(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) // orglabel99 doesn't exists.. See labels.yml - labelIDs, err := issues_model.GetLabelIDsInOrgByNames(3, []string{"orglabel3", "orglabel4", "orglabel99"}) + labelIDs, err := issues_model.GetLabelIDsInOrgByNames(db.DefaultContext, 3, []string{"orglabel3", "orglabel4", "orglabel99"}) assert.NoError(t, err) assert.Len(t, labelIDs, 2) @@ -269,7 +269,7 @@ func TestUpdateLabel(t *testing.T) { } label.Color = update.Color label.Name = update.Name - assert.NoError(t, issues_model.UpdateLabel(update)) + assert.NoError(t, issues_model.UpdateLabel(db.DefaultContext, update)) newLabel := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 1}) assert.EqualValues(t, label.ID, newLabel.ID) assert.EqualValues(t, label.Color, newLabel.Color) @@ -282,13 +282,13 @@ func TestUpdateLabel(t *testing.T) { func TestDeleteLabel(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) label := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 1}) - assert.NoError(t, issues_model.DeleteLabel(label.RepoID, label.ID)) + assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, label.RepoID, label.ID)) unittest.AssertNotExistsBean(t, &issues_model.Label{ID: label.ID, RepoID: label.RepoID}) - assert.NoError(t, issues_model.DeleteLabel(label.RepoID, label.ID)) + assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, label.RepoID, label.ID)) unittest.AssertNotExistsBean(t, &issues_model.Label{ID: label.ID}) - assert.NoError(t, issues_model.DeleteLabel(unittest.NonexistentID, unittest.NonexistentID)) + assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) unittest.CheckConsistencyFor(t, &issues_model.Label{}, &repo_model.Repository{}) } diff --git a/models/issues/milestone.go b/models/issues/milestone.go index c15b2a41fe..ad1d5d0453 100644 --- a/models/issues/milestone.go +++ b/models/issues/milestone.go @@ -103,8 +103,8 @@ func (m *Milestone) State() api.StateType { } // NewMilestone creates new milestone of repository. -func NewMilestone(m *Milestone) (err error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func NewMilestone(ctx context.Context, m *Milestone) (err error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -140,9 +140,9 @@ func GetMilestoneByRepoID(ctx context.Context, repoID, id int64) (*Milestone, er } // GetMilestoneByRepoIDANDName return a milestone if one exist by name and repo -func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error) { +func GetMilestoneByRepoIDANDName(ctx context.Context, repoID int64, name string) (*Milestone, error) { var mile Milestone - has, err := db.GetEngine(db.DefaultContext).Where("repo_id=? AND name=?", repoID, name).Get(&mile) + has, err := db.GetEngine(ctx).Where("repo_id=? AND name=?", repoID, name).Get(&mile) if err != nil { return nil, err } @@ -153,8 +153,8 @@ func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error) } // UpdateMilestone updates information of given milestone. -func UpdateMilestone(m *Milestone, oldIsClosed bool) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func UpdateMilestone(ctx context.Context, m *Milestone, oldIsClosed bool) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -211,8 +211,8 @@ func UpdateMilestoneCounters(ctx context.Context, id int64) error { } // ChangeMilestoneStatusByRepoIDAndID changes a milestone open/closed status if the milestone ID is in the repo. -func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func ChangeMilestoneStatusByRepoIDAndID(ctx context.Context, repoID, milestoneID int64, isClosed bool) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -238,8 +238,8 @@ func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool } // ChangeMilestoneStatus changes the milestone open/closed status. -func ChangeMilestoneStatus(m *Milestone, isClosed bool) (err error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func ChangeMilestoneStatus(ctx context.Context, m *Milestone, isClosed bool) (err error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -269,8 +269,8 @@ func changeMilestoneStatus(ctx context.Context, m *Milestone, isClosed bool) err } // DeleteMilestoneByRepoID deletes a milestone from a repository. -func DeleteMilestoneByRepoID(repoID, id int64) error { - m, err := GetMilestoneByRepoID(db.DefaultContext, repoID, id) +func DeleteMilestoneByRepoID(ctx context.Context, repoID, id int64) error { + m, err := GetMilestoneByRepoID(ctx, repoID, id) if err != nil { if IsErrMilestoneNotExist(err) { return nil @@ -278,12 +278,12 @@ func DeleteMilestoneByRepoID(repoID, id int64) error { return err } - repo, err := repo_model.GetRepositoryByID(db.DefaultContext, m.RepoID) + repo, err := repo_model.GetRepositoryByID(ctx, m.RepoID) if err != nil { return err } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -332,7 +332,8 @@ func updateRepoMilestoneNum(ctx context.Context, repoID int64) error { return err } -func (m *Milestone) loadTotalTrackedTime(ctx context.Context) error { +// LoadTotalTrackedTime loads the tracked time for the milestone +func (m *Milestone) LoadTotalTrackedTime(ctx context.Context) error { type totalTimesByMilestone struct { MilestoneID int64 Time int64 @@ -355,18 +356,13 @@ func (m *Milestone) loadTotalTrackedTime(ctx context.Context) error { return nil } -// LoadTotalTrackedTime loads the tracked time for the milestone -func (m *Milestone) LoadTotalTrackedTime() error { - return m.loadTotalTrackedTime(db.DefaultContext) -} - // InsertMilestones creates milestones of repository. -func InsertMilestones(ms ...*Milestone) (err error) { +func InsertMilestones(ctx context.Context, ms ...*Milestone) (err error) { if len(ms) == 0 { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/issues/milestone_list.go b/models/issues/milestone_list.go index b0c29106a0..d5c9b1358c 100644 --- a/models/issues/milestone_list.go +++ b/models/issues/milestone_list.go @@ -100,9 +100,9 @@ func GetMilestoneIDsByNames(ctx context.Context, names []string) ([]int64, error } // SearchMilestones search milestones -func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType, keyword string) (MilestoneList, error) { +func SearchMilestones(ctx context.Context, repoCond builder.Cond, page int, isClosed bool, sortType, keyword string) (MilestoneList, error) { miles := make([]*Milestone, 0, setting.UI.IssuePagingNum) - sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) + sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } @@ -131,8 +131,9 @@ func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType, } // GetMilestonesByRepoIDs returns a list of milestones of given repositories and status. -func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) { +func GetMilestonesByRepoIDs(ctx context.Context, repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) { return SearchMilestones( + ctx, builder.In("repo_id", repoIDs), page, isClosed, @@ -141,7 +142,8 @@ func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType s ) } -func (milestones MilestoneList) loadTotalTrackedTimes(ctx context.Context) error { +// LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request +func (milestones MilestoneList) LoadTotalTrackedTimes(ctx context.Context) error { type totalTimesByMilestone struct { MilestoneID int64 Time int64 @@ -181,11 +183,6 @@ func (milestones MilestoneList) loadTotalTrackedTimes(ctx context.Context) error return nil } -// LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request -func (milestones MilestoneList) LoadTotalTrackedTimes() error { - return milestones.loadTotalTrackedTimes(db.DefaultContext) -} - // CountMilestones returns number of milestones in given repository with other options func CountMilestones(ctx context.Context, opts GetMilestonesOption) (int64, error) { return db.GetEngine(ctx). @@ -194,8 +191,8 @@ func CountMilestones(ctx context.Context, opts GetMilestonesOption) (int64, erro } // CountMilestonesByRepoCond map from repo conditions to number of milestones matching the options` -func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64]int64, error) { - sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) +func CountMilestonesByRepoCond(ctx context.Context, repoCond builder.Cond, isClosed bool) (map[int64]int64, error) { + sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed) if repoCond.IsValid() { sess.In("repo_id", builder.Select("id").From("repository").Where(repoCond)) } @@ -219,8 +216,8 @@ func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64] } // CountMilestonesByRepoCondAndKw map from repo conditions and the keyword of milestones' name to number of milestones matching the options` -func CountMilestonesByRepoCondAndKw(repoCond builder.Cond, keyword string, isClosed bool) (map[int64]int64, error) { - sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) +func CountMilestonesByRepoCondAndKw(ctx context.Context, repoCond builder.Cond, keyword string, isClosed bool) (map[int64]int64, error) { + sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } @@ -257,11 +254,11 @@ func (m MilestonesStats) Total() int64 { } // GetMilestonesStatsByRepoCond returns milestone statistic information for dashboard by given conditions. -func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, error) { +func GetMilestonesStatsByRepoCond(ctx context.Context, repoCond builder.Cond) (*MilestonesStats, error) { var err error stats := &MilestonesStats{} - sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", false) + sess := db.GetEngine(ctx).Where("is_closed = ?", false) if repoCond.IsValid() { sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond))) } @@ -270,7 +267,7 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro return nil, err } - sess = db.GetEngine(db.DefaultContext).Where("is_closed = ?", true) + sess = db.GetEngine(ctx).Where("is_closed = ?", true) if repoCond.IsValid() { sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond))) } @@ -283,11 +280,11 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro } // GetMilestonesStatsByRepoCondAndKw returns milestone statistic information for dashboard by given repo conditions and name keyword. -func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (*MilestonesStats, error) { +func GetMilestonesStatsByRepoCondAndKw(ctx context.Context, repoCond builder.Cond, keyword string) (*MilestonesStats, error) { var err error stats := &MilestonesStats{} - sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", false) + sess := db.GetEngine(ctx).Where("is_closed = ?", false) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } @@ -299,7 +296,7 @@ func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (* return nil, err } - sess = db.GetEngine(db.DefaultContext).Where("is_closed = ?", true) + sess = db.GetEngine(ctx).Where("is_closed = ?", true) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } diff --git a/models/issues/milestone_test.go b/models/issues/milestone_test.go index e85d77ebc8..403eeaadb3 100644 --- a/models/issues/milestone_test.go +++ b/models/issues/milestone_test.go @@ -201,12 +201,12 @@ func TestCountMilestonesByRepoIDs(t *testing.T) { repo1OpenCount, repo1ClosedCount := milestonesCount(1) repo2OpenCount, repo2ClosedCount := milestonesCount(2) - openCounts, err := issues_model.CountMilestonesByRepoCond(builder.In("repo_id", []int64{1, 2}), false) + openCounts, err := issues_model.CountMilestonesByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{1, 2}), false) assert.NoError(t, err) assert.EqualValues(t, repo1OpenCount, openCounts[1]) assert.EqualValues(t, repo2OpenCount, openCounts[2]) - closedCounts, err := issues_model.CountMilestonesByRepoCond(builder.In("repo_id", []int64{1, 2}), true) + closedCounts, err := issues_model.CountMilestonesByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{1, 2}), true) assert.NoError(t, err) assert.EqualValues(t, repo1ClosedCount, closedCounts[1]) assert.EqualValues(t, repo2ClosedCount, closedCounts[2]) @@ -218,7 +218,7 @@ func TestGetMilestonesByRepoIDs(t *testing.T) { repo2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 2}) test := func(sortType string, sortCond func(*issues_model.Milestone) int) { for _, page := range []int{0, 1} { - openMilestones, err := issues_model.GetMilestonesByRepoIDs([]int64{repo1.ID, repo2.ID}, page, false, sortType) + openMilestones, err := issues_model.GetMilestonesByRepoIDs(db.DefaultContext, []int64{repo1.ID, repo2.ID}, page, false, sortType) assert.NoError(t, err) assert.Len(t, openMilestones, repo1.NumOpenMilestones+repo2.NumOpenMilestones) values := make([]int, len(openMilestones)) @@ -227,7 +227,7 @@ func TestGetMilestonesByRepoIDs(t *testing.T) { } assert.True(t, sort.IntsAreSorted(values)) - closedMilestones, err := issues_model.GetMilestonesByRepoIDs([]int64{repo1.ID, repo2.ID}, page, true, sortType) + closedMilestones, err := issues_model.GetMilestonesByRepoIDs(db.DefaultContext, []int64{repo1.ID, repo2.ID}, page, true, sortType) assert.NoError(t, err) assert.Len(t, closedMilestones, repo1.NumClosedMilestones+repo2.NumClosedMilestones) values = make([]int, len(closedMilestones)) @@ -262,7 +262,7 @@ func TestGetMilestonesStats(t *testing.T) { test := func(repoID int64) { repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID}) - stats, err := issues_model.GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": repoID})) + stats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.And(builder.Eq{"repo_id": repoID})) assert.NoError(t, err) assert.EqualValues(t, repo.NumMilestones-repo.NumClosedMilestones, stats.OpenCount) assert.EqualValues(t, repo.NumClosedMilestones, stats.ClosedCount) @@ -271,7 +271,7 @@ func TestGetMilestonesStats(t *testing.T) { test(2) test(3) - stats, err := issues_model.GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": unittest.NonexistentID})) + stats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.And(builder.Eq{"repo_id": unittest.NonexistentID})) assert.NoError(t, err) assert.EqualValues(t, 0, stats.OpenCount) assert.EqualValues(t, 0, stats.ClosedCount) @@ -279,7 +279,7 @@ func TestGetMilestonesStats(t *testing.T) { repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1}) repo2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 2}) - milestoneStats, err := issues_model.GetMilestonesStatsByRepoCond(builder.In("repo_id", []int64{repo1.ID, repo2.ID})) + milestoneStats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{repo1.ID, repo2.ID})) assert.NoError(t, err) assert.EqualValues(t, repo1.NumOpenMilestones+repo2.NumOpenMilestones, milestoneStats.OpenCount) assert.EqualValues(t, repo1.NumClosedMilestones+repo2.NumClosedMilestones, milestoneStats.ClosedCount) @@ -293,7 +293,7 @@ func TestNewMilestone(t *testing.T) { Content: "milestoneContent", } - assert.NoError(t, issues_model.NewMilestone(milestone)) + assert.NoError(t, issues_model.NewMilestone(db.DefaultContext, milestone)) unittest.AssertExistsAndLoadBean(t, milestone) unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{}) } @@ -302,22 +302,22 @@ func TestChangeMilestoneStatus(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) - assert.NoError(t, issues_model.ChangeMilestoneStatus(milestone, true)) + assert.NoError(t, issues_model.ChangeMilestoneStatus(db.DefaultContext, milestone, true)) unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}, "is_closed=1") unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{}) - assert.NoError(t, issues_model.ChangeMilestoneStatus(milestone, false)) + assert.NoError(t, issues_model.ChangeMilestoneStatus(db.DefaultContext, milestone, false)) unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}, "is_closed=0") unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{}) } func TestDeleteMilestoneByRepoID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.NoError(t, issues_model.DeleteMilestoneByRepoID(1, 1)) + assert.NoError(t, issues_model.DeleteMilestoneByRepoID(db.DefaultContext, 1, 1)) unittest.AssertNotExistsBean(t, &issues_model.Milestone{ID: 1}) unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: 1}) - assert.NoError(t, issues_model.DeleteMilestoneByRepoID(unittest.NonexistentID, unittest.NonexistentID)) + assert.NoError(t, issues_model.DeleteMilestoneByRepoID(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) } func TestUpdateMilestone(t *testing.T) { @@ -326,7 +326,7 @@ func TestUpdateMilestone(t *testing.T) { milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) milestone.Name = " newMilestoneName " milestone.Content = "newMilestoneContent" - assert.NoError(t, issues_model.UpdateMilestone(milestone, milestone.IsClosed)) + assert.NoError(t, issues_model.UpdateMilestone(db.DefaultContext, milestone, milestone.IsClosed)) milestone = unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) assert.EqualValues(t, "newMilestoneName", milestone.Name) unittest.CheckConsistencyFor(t, &issues_model.Milestone{}) @@ -361,7 +361,7 @@ func TestMigrate_InsertMilestones(t *testing.T) { RepoID: repo.ID, Name: name, } - err := issues_model.InsertMilestones(ms) + err := issues_model.InsertMilestones(db.DefaultContext, ms) assert.NoError(t, err) unittest.AssertExistsAndLoadBean(t, ms) repoModified := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repo.ID}) diff --git a/models/issues/stopwatch.go b/models/issues/stopwatch.go index c8cd5ad33f..2c662bdb06 100644 --- a/models/issues/stopwatch.go +++ b/models/issues/stopwatch.go @@ -81,9 +81,9 @@ type UserStopwatch struct { } // GetUIDsAndNotificationCounts between the two provided times -func GetUIDsAndStopwatch() ([]*UserStopwatch, error) { +func GetUIDsAndStopwatch(ctx context.Context) ([]*UserStopwatch, error) { sws := []*Stopwatch{} - if err := db.GetEngine(db.DefaultContext).Where("issue_id != 0").Find(&sws); err != nil { + if err := db.GetEngine(ctx).Where("issue_id != 0").Find(&sws); err != nil { return nil, err } if len(sws) == 0 { @@ -107,9 +107,9 @@ func GetUIDsAndStopwatch() ([]*UserStopwatch, error) { } // GetUserStopwatches return list of all stopwatches of a user -func GetUserStopwatches(userID int64, listOptions db.ListOptions) ([]*Stopwatch, error) { +func GetUserStopwatches(ctx context.Context, userID int64, listOptions db.ListOptions) ([]*Stopwatch, error) { sws := make([]*Stopwatch, 0, 8) - sess := db.GetEngine(db.DefaultContext).Where("stopwatch.user_id = ?", userID) + sess := db.GetEngine(ctx).Where("stopwatch.user_id = ?", userID) if listOptions.Page != 0 { sess = db.SetSessionPagination(sess, &listOptions) } @@ -122,13 +122,13 @@ func GetUserStopwatches(userID int64, listOptions db.ListOptions) ([]*Stopwatch, } // CountUserStopwatches return count of all stopwatches of a user -func CountUserStopwatches(userID int64) (int64, error) { - return db.GetEngine(db.DefaultContext).Where("user_id = ?", userID).Count(&Stopwatch{}) +func CountUserStopwatches(ctx context.Context, userID int64) (int64, error) { + return db.GetEngine(ctx).Where("user_id = ?", userID).Count(&Stopwatch{}) } // StopwatchExists returns true if the stopwatch exists -func StopwatchExists(userID, issueID int64) bool { - _, exists, _ := getStopwatch(db.DefaultContext, userID, issueID) +func StopwatchExists(ctx context.Context, userID, issueID int64) bool { + _, exists, _ := getStopwatch(ctx, userID, issueID) return exists } @@ -168,15 +168,15 @@ func FinishIssueStopwatchIfPossible(ctx context.Context, user *user_model.User, } // CreateOrStopIssueStopwatch create an issue stopwatch if it's not exist, otherwise finish it -func CreateOrStopIssueStopwatch(user *user_model.User, issue *Issue) error { - _, exists, err := getStopwatch(db.DefaultContext, user.ID, issue.ID) +func CreateOrStopIssueStopwatch(ctx context.Context, user *user_model.User, issue *Issue) error { + _, exists, err := getStopwatch(ctx, user.ID, issue.ID) if err != nil { return err } if exists { - return FinishIssueStopwatch(db.DefaultContext, user, issue) + return FinishIssueStopwatch(ctx, user, issue) } - return CreateIssueStopwatch(db.DefaultContext, user, issue) + return CreateIssueStopwatch(ctx, user, issue) } // FinishIssueStopwatch if stopwatch exist then finish it otherwise return an error @@ -269,8 +269,8 @@ func CreateIssueStopwatch(ctx context.Context, user *user_model.User, issue *Iss } // CancelStopwatch removes the given stopwatch and logs it into issue's timeline. -func CancelStopwatch(user *user_model.User, issue *Issue) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func CancelStopwatch(ctx context.Context, user *user_model.User, issue *Issue) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/issues/stopwatch_test.go b/models/issues/stopwatch_test.go index fa937ecbed..39958a7f36 100644 --- a/models/issues/stopwatch_test.go +++ b/models/issues/stopwatch_test.go @@ -26,20 +26,20 @@ func TestCancelStopwatch(t *testing.T) { issue2, err := issues_model.GetIssueByID(db.DefaultContext, 2) assert.NoError(t, err) - err = issues_model.CancelStopwatch(user1, issue1) + err = issues_model.CancelStopwatch(db.DefaultContext, user1, issue1) assert.NoError(t, err) unittest.AssertNotExistsBean(t, &issues_model.Stopwatch{UserID: user1.ID, IssueID: issue1.ID}) _ = unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{Type: issues_model.CommentTypeCancelTracking, PosterID: user1.ID, IssueID: issue1.ID}) - assert.Nil(t, issues_model.CancelStopwatch(user1, issue2)) + assert.Nil(t, issues_model.CancelStopwatch(db.DefaultContext, user1, issue2)) } func TestStopwatchExists(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.True(t, issues_model.StopwatchExists(1, 1)) - assert.False(t, issues_model.StopwatchExists(1, 2)) + assert.True(t, issues_model.StopwatchExists(db.DefaultContext, 1, 1)) + assert.False(t, issues_model.StopwatchExists(db.DefaultContext, 1, 2)) } func TestHasUserStopwatch(t *testing.T) { @@ -68,11 +68,11 @@ func TestCreateOrStopIssueStopwatch(t *testing.T) { issue2, err := issues_model.GetIssueByID(db.DefaultContext, 2) assert.NoError(t, err) - assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(org3, issue1)) + assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(db.DefaultContext, org3, issue1)) sw := unittest.AssertExistsAndLoadBean(t, &issues_model.Stopwatch{UserID: 3, IssueID: 1}) assert.LessOrEqual(t, sw.CreatedUnix, timeutil.TimeStampNow()) - assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(user2, issue2)) + assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(db.DefaultContext, user2, issue2)) unittest.AssertNotExistsBean(t, &issues_model.Stopwatch{UserID: 2, IssueID: 2}) unittest.AssertExistsAndLoadBean(t, &issues_model.TrackedTime{UserID: 2, IssueID: 2}) } diff --git a/models/organization/mini_org.go b/models/organization/mini_org.go index b1627b5e6c..b1b24624c5 100644 --- a/models/organization/mini_org.go +++ b/models/organization/mini_org.go @@ -4,6 +4,7 @@ package organization import ( + "context" "fmt" "strings" @@ -19,7 +20,7 @@ import ( type MinimalOrg = Organization // GetUserOrgsList returns all organizations the given user has access to -func GetUserOrgsList(user *user_model.User) ([]*MinimalOrg, error) { +func GetUserOrgsList(ctx context.Context, user *user_model.User) ([]*MinimalOrg, error) { schema, err := db.TableInfo(new(user_model.User)) if err != nil { return nil, err @@ -42,7 +43,7 @@ func GetUserOrgsList(user *user_model.User) ([]*MinimalOrg, error) { groupByStr := groupByCols.String() groupByStr = groupByStr[0 : len(groupByStr)-1] - sess := db.GetEngine(db.DefaultContext) + sess := db.GetEngine(ctx) sess = sess.Select(groupByStr+", count(distinct repo_id) as org_count"). Table("user"). Join("INNER", "team", "`team`.org_id = `user`.id"). diff --git a/models/repo/archiver.go b/models/repo/archiver.go index 70f53cfe15..6d0ed42877 100644 --- a/models/repo/archiver.go +++ b/models/repo/archiver.go @@ -72,7 +72,7 @@ var delRepoArchiver = new(RepoArchiver) // DeleteRepoArchiver delete archiver func DeleteRepoArchiver(ctx context.Context, archiver *RepoArchiver) error { - _, err := db.GetEngine(db.DefaultContext).ID(archiver.ID).Delete(delRepoArchiver) + _, err := db.GetEngine(ctx).ID(archiver.ID).Delete(delRepoArchiver) return err } @@ -113,8 +113,8 @@ func UpdateRepoArchiverStatus(ctx context.Context, archiver *RepoArchiver) error } // DeleteAllRepoArchives deletes all repo archives records -func DeleteAllRepoArchives() error { - _, err := db.GetEngine(db.DefaultContext).Where("1=1").Delete(new(RepoArchiver)) +func DeleteAllRepoArchives(ctx context.Context) error { + _, err := db.GetEngine(ctx).Where("1=1").Delete(new(RepoArchiver)) return err } @@ -133,10 +133,10 @@ func (opts FindRepoArchiversOption) toConds() builder.Cond { } // FindRepoArchives find repo archivers -func FindRepoArchives(opts FindRepoArchiversOption) ([]*RepoArchiver, error) { +func FindRepoArchives(ctx context.Context, opts FindRepoArchiversOption) ([]*RepoArchiver, error) { archivers := make([]*RepoArchiver, 0, opts.PageSize) start, limit := opts.GetSkipTake() - err := db.GetEngine(db.DefaultContext).Where(opts.toConds()). + err := db.GetEngine(ctx).Where(opts.toConds()). Asc("created_unix"). Limit(limit, start). Find(&archivers) @@ -144,7 +144,7 @@ func FindRepoArchives(opts FindRepoArchiversOption) ([]*RepoArchiver, error) { } // SetArchiveRepoState sets if a repo is archived -func SetArchiveRepoState(repo *Repository, isArchived bool) (err error) { +func SetArchiveRepoState(ctx context.Context, repo *Repository, isArchived bool) (err error) { repo.IsArchived = isArchived if isArchived { @@ -153,6 +153,6 @@ func SetArchiveRepoState(repo *Repository, isArchived bool) (err error) { repo.ArchivedUnix = timeutil.TimeStamp(0) } - _, err = db.GetEngine(db.DefaultContext).ID(repo.ID).Cols("is_archived", "archived_unix").NoAutoTime().Update(repo) + _, err = db.GetEngine(ctx).ID(repo.ID).Cols("is_archived", "archived_unix").NoAutoTime().Update(repo) return err } diff --git a/models/repo/topic.go b/models/repo/topic.go index 71302388b9..ca533fc1e0 100644 --- a/models/repo/topic.go +++ b/models/repo/topic.go @@ -92,9 +92,9 @@ func SanitizeAndValidateTopics(topics []string) (validTopics, invalidTopics []st } // GetTopicByName retrieves topic by name -func GetTopicByName(name string) (*Topic, error) { +func GetTopicByName(ctx context.Context, name string) (*Topic, error) { var topic Topic - if has, err := db.GetEngine(db.DefaultContext).Where("name = ?", name).Get(&topic); err != nil { + if has, err := db.GetEngine(ctx).Where("name = ?", name).Get(&topic); err != nil { return nil, err } else if !has { return nil, ErrTopicNotExist{name} @@ -192,8 +192,8 @@ func (opts *FindTopicOptions) toConds() builder.Cond { } // FindTopics retrieves the topics via FindTopicOptions -func FindTopics(opts *FindTopicOptions) ([]*Topic, int64, error) { - sess := db.GetEngine(db.DefaultContext).Select("topic.*").Where(opts.toConds()) +func FindTopics(ctx context.Context, opts *FindTopicOptions) ([]*Topic, int64, error) { + sess := db.GetEngine(ctx).Select("topic.*").Where(opts.toConds()) orderBy := "topic.repo_count DESC" if opts.RepoID > 0 { sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") @@ -208,8 +208,8 @@ func FindTopics(opts *FindTopicOptions) ([]*Topic, int64, error) { } // CountTopics counts the number of topics matching the FindTopicOptions -func CountTopics(opts *FindTopicOptions) (int64, error) { - sess := db.GetEngine(db.DefaultContext).Where(opts.toConds()) +func CountTopics(ctx context.Context, opts *FindTopicOptions) (int64, error) { + sess := db.GetEngine(ctx).Where(opts.toConds()) if opts.RepoID > 0 { sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") } @@ -231,8 +231,8 @@ func GetRepoTopicByName(ctx context.Context, repoID int64, topicName string) (*T } // AddTopic adds a topic name to a repository (if it does not already have it) -func AddTopic(repoID int64, topicName string) (*Topic, error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func AddTopic(ctx context.Context, repoID int64, topicName string) (*Topic, error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return nil, err } @@ -261,8 +261,8 @@ func AddTopic(repoID int64, topicName string) (*Topic, error) { } // DeleteTopic removes a topic name from a repository (if it has it) -func DeleteTopic(repoID int64, topicName string) (*Topic, error) { - topic, err := GetRepoTopicByName(db.DefaultContext, repoID, topicName) +func DeleteTopic(ctx context.Context, repoID int64, topicName string) (*Topic, error) { + topic, err := GetRepoTopicByName(ctx, repoID, topicName) if err != nil { return nil, err } @@ -271,26 +271,26 @@ func DeleteTopic(repoID int64, topicName string) (*Topic, error) { return nil, nil } - err = removeTopicFromRepo(db.DefaultContext, repoID, topic) + err = removeTopicFromRepo(ctx, repoID, topic) if err != nil { return nil, err } - err = syncTopicsInRepository(db.GetEngine(db.DefaultContext), repoID) + err = syncTopicsInRepository(db.GetEngine(ctx), repoID) return topic, err } // SaveTopics save topics to a repository -func SaveTopics(repoID int64, topicNames ...string) error { - topics, _, err := FindTopics(&FindTopicOptions{ +func SaveTopics(ctx context.Context, repoID int64, topicNames ...string) error { + topics, _, err := FindTopics(ctx, &FindTopicOptions{ RepoID: repoID, }) if err != nil { return err } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/repo/topic_test.go b/models/repo/topic_test.go index aaed91bdd3..2b609e6d66 100644 --- a/models/repo/topic_test.go +++ b/models/repo/topic_test.go @@ -19,47 +19,47 @@ func TestAddTopic(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - topics, _, err := repo_model.FindTopics(&repo_model.FindTopicOptions{}) + topics, _, err := repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) assert.NoError(t, err) assert.Len(t, topics, totalNrOfTopics) - topics, total, err := repo_model.FindTopics(&repo_model.FindTopicOptions{ + topics, total, err := repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ ListOptions: db.ListOptions{Page: 1, PageSize: 2}, }) assert.NoError(t, err) assert.Len(t, topics, 2) assert.EqualValues(t, 6, total) - topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{ + topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ RepoID: 1, }) assert.NoError(t, err) assert.Len(t, topics, repo1NrOfTopics) - assert.NoError(t, repo_model.SaveTopics(2, "golang")) + assert.NoError(t, repo_model.SaveTopics(db.DefaultContext, 2, "golang")) repo2NrOfTopics := 1 - topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{}) + topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) assert.NoError(t, err) assert.Len(t, topics, totalNrOfTopics) - topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{ + topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ RepoID: 2, }) assert.NoError(t, err) assert.Len(t, topics, repo2NrOfTopics) - assert.NoError(t, repo_model.SaveTopics(2, "golang", "gitea")) + assert.NoError(t, repo_model.SaveTopics(db.DefaultContext, 2, "golang", "gitea")) repo2NrOfTopics = 2 totalNrOfTopics++ - topic, err := repo_model.GetTopicByName("gitea") + topic, err := repo_model.GetTopicByName(db.DefaultContext, "gitea") assert.NoError(t, err) assert.EqualValues(t, 1, topic.RepoCount) - topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{}) + topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) assert.NoError(t, err) assert.Len(t, topics, totalNrOfTopics) - topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{ + topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ RepoID: 2, }) assert.NoError(t, err) diff --git a/models/repo/update.go b/models/repo/update.go index c4fba32ad2..6ddf1a8905 100644 --- a/models/repo/update.go +++ b/models/repo/update.go @@ -16,11 +16,11 @@ import ( ) // UpdateRepositoryOwnerNames updates repository owner_names (this should only be used when the ownerName has changed case) -func UpdateRepositoryOwnerNames(ownerID int64, ownerName string) error { +func UpdateRepositoryOwnerNames(ctx context.Context, ownerID int64, ownerName string) error { if ownerID == 0 { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -36,8 +36,8 @@ func UpdateRepositoryOwnerNames(ownerID int64, ownerName string) error { } // UpdateRepositoryUpdatedTime updates a repository's updated time -func UpdateRepositoryUpdatedTime(repoID int64, updateTime time.Time) error { - _, err := db.GetEngine(db.DefaultContext).Exec("UPDATE repository SET updated_unix = ? WHERE id = ?", updateTime.Unix(), repoID) +func UpdateRepositoryUpdatedTime(ctx context.Context, repoID int64, updateTime time.Time) error { + _, err := db.GetEngine(ctx).Exec("UPDATE repository SET updated_unix = ? WHERE id = ?", updateTime.Unix(), repoID) return err } @@ -107,7 +107,7 @@ func (err ErrRepoFilesAlreadyExist) Unwrap() error { } // CheckCreateRepository check if could created a repository -func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdopt bool) error { +func CheckCreateRepository(ctx context.Context, doer, u *user_model.User, name string, overwriteOrAdopt bool) error { if !doer.CanCreateRepo() { return ErrReachLimitOfRepo{u.MaxRepoCreation} } @@ -116,7 +116,7 @@ func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdo return err } - has, err := IsRepositoryModelOrDirExist(db.DefaultContext, u, name) + has, err := IsRepositoryModelOrDirExist(ctx, u, name) if err != nil { return fmt.Errorf("IsRepositoryExist: %w", err) } else if has { @@ -136,18 +136,18 @@ func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdo } // ChangeRepositoryName changes all corresponding setting from old repository name to new one. -func ChangeRepositoryName(doer *user_model.User, repo *Repository, newRepoName string) (err error) { +func ChangeRepositoryName(ctx context.Context, doer *user_model.User, repo *Repository, newRepoName string) (err error) { oldRepoName := repo.Name newRepoName = strings.ToLower(newRepoName) if err = IsUsableRepoName(newRepoName); err != nil { return err } - if err := repo.LoadOwner(db.DefaultContext); err != nil { + if err := repo.LoadOwner(ctx); err != nil { return err } - has, err := IsRepositoryModelOrDirExist(db.DefaultContext, repo.Owner, newRepoName) + has, err := IsRepositoryModelOrDirExist(ctx, repo.Owner, newRepoName) if err != nil { return fmt.Errorf("IsRepositoryExist: %w", err) } else if has { @@ -171,7 +171,7 @@ func ChangeRepositoryName(doer *user_model.User, repo *Repository, newRepoName s } } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/repo_transfer.go b/models/repo_transfer.go index 1c873cec57..630c243c8e 100644 --- a/models/repo_transfer.go +++ b/models/repo_transfer.go @@ -79,8 +79,8 @@ func (r *RepoTransfer) LoadAttributes(ctx context.Context) error { // CanUserAcceptTransfer checks if the user has the rights to accept/decline a repo transfer. // For user, it checks if it's himself // For organizations, it checks if the user is able to create repos -func (r *RepoTransfer) CanUserAcceptTransfer(u *user_model.User) bool { - if err := r.LoadAttributes(db.DefaultContext); err != nil { +func (r *RepoTransfer) CanUserAcceptTransfer(ctx context.Context, u *user_model.User) bool { + if err := r.LoadAttributes(ctx); err != nil { log.Error("LoadAttributes: %v", err) return false } @@ -89,7 +89,7 @@ func (r *RepoTransfer) CanUserAcceptTransfer(u *user_model.User) bool { return r.RecipientID == u.ID } - allowed, err := organization.CanCreateOrgRepo(db.DefaultContext, r.RecipientID, u.ID) + allowed, err := organization.CanCreateOrgRepo(ctx, r.RecipientID, u.ID) if err != nil { log.Error("CanCreateOrgRepo: %v", err) return false @@ -122,8 +122,8 @@ func deleteRepositoryTransfer(ctx context.Context, repoID int64) error { // CancelRepositoryTransfer marks the repository as ready and remove pending transfer entry, // thus cancel the transfer process. -func CancelRepositoryTransfer(repo *repo_model.Repository) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func CancelRepositoryTransfer(ctx context.Context, repo *repo_model.Repository) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -199,7 +199,7 @@ func CreatePendingRepositoryTransfer(ctx context.Context, doer, newOwner *user_m } // TransferOwnership transfers all corresponding repository items from old user to new one. -func TransferOwnership(doer *user_model.User, newOwnerName string, repo *repo_model.Repository) (err error) { +func TransferOwnership(ctx context.Context, doer *user_model.User, newOwnerName string, repo *repo_model.Repository) (err error) { repoRenamed := false wikiRenamed := false oldOwnerName := doer.Name @@ -234,7 +234,7 @@ func TransferOwnership(doer *user_model.User, newOwnerName string, repo *repo_mo } }() - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/repo_transfer_test.go b/models/repo_transfer_test.go index 7364d4d02c..b55cef9473 100644 --- a/models/repo_transfer_test.go +++ b/models/repo_transfer_test.go @@ -25,7 +25,7 @@ func TestRepositoryTransfer(t *testing.T) { assert.NotNil(t, transfer) // Cancel transfer - assert.NoError(t, CancelRepositoryTransfer(repo)) + assert.NoError(t, CancelRepositoryTransfer(db.DefaultContext, repo)) transfer, err = GetPendingRepositoryTransfer(db.DefaultContext, repo) assert.Error(t, err) @@ -53,5 +53,5 @@ func TestRepositoryTransfer(t *testing.T) { assert.Error(t, err) // Cancel transfer - assert.NoError(t, CancelRepositoryTransfer(repo)) + assert.NoError(t, CancelRepositoryTransfer(db.DefaultContext, repo)) } diff --git a/models/user/follow.go b/models/user/follow.go index 7efecc26a7..f4dd2891ff 100644 --- a/models/user/follow.go +++ b/models/user/follow.go @@ -4,6 +4,8 @@ package user import ( + "context" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/timeutil" ) @@ -21,18 +23,18 @@ func init() { } // IsFollowing returns true if user is following followID. -func IsFollowing(userID, followID int64) bool { - has, _ := db.GetEngine(db.DefaultContext).Get(&Follow{UserID: userID, FollowID: followID}) +func IsFollowing(ctx context.Context, userID, followID int64) bool { + has, _ := db.GetEngine(ctx).Get(&Follow{UserID: userID, FollowID: followID}) return has } // FollowUser marks someone be another's follower. -func FollowUser(userID, followID int64) (err error) { - if userID == followID || IsFollowing(userID, followID) { +func FollowUser(ctx context.Context, userID, followID int64) (err error) { + if userID == followID || IsFollowing(ctx, userID, followID) { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -53,12 +55,12 @@ func FollowUser(userID, followID int64) (err error) { } // UnfollowUser unmarks someone as another's follower. -func UnfollowUser(userID, followID int64) (err error) { - if userID == followID || !IsFollowing(userID, followID) { +func UnfollowUser(ctx context.Context, userID, followID int64) (err error) { + if userID == followID || !IsFollowing(ctx, userID, followID) { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/user/follow_test.go b/models/user/follow_test.go index fc408d5257..c327d935ae 100644 --- a/models/user/follow_test.go +++ b/models/user/follow_test.go @@ -6,6 +6,7 @@ package user_test import ( "testing" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -14,9 +15,9 @@ import ( func TestIsFollowing(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.True(t, user_model.IsFollowing(4, 2)) - assert.False(t, user_model.IsFollowing(2, 4)) - assert.False(t, user_model.IsFollowing(5, unittest.NonexistentID)) - assert.False(t, user_model.IsFollowing(unittest.NonexistentID, 5)) - assert.False(t, user_model.IsFollowing(unittest.NonexistentID, unittest.NonexistentID)) + assert.True(t, user_model.IsFollowing(db.DefaultContext, 4, 2)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, 2, 4)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, 5, unittest.NonexistentID)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, unittest.NonexistentID, 5)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) } diff --git a/models/user/user.go b/models/user/user.go index b3956da1cb..63b95816ce 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -1246,7 +1246,7 @@ func IsUserVisibleToViewer(ctx context.Context, u, viewer *User) bool { } // If they follow - they see each over - follower := IsFollowing(u.ID, viewer.ID) + follower := IsFollowing(ctx, u.ID, viewer.ID) if follower { return true } diff --git a/models/user/user_test.go b/models/user/user_test.go index b15f0cbc59..971117482c 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -449,13 +449,13 @@ func TestFollowUser(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(followerID, followedID int64) { - assert.NoError(t, user_model.FollowUser(followerID, followedID)) + assert.NoError(t, user_model.FollowUser(db.DefaultContext, followerID, followedID)) unittest.AssertExistsAndLoadBean(t, &user_model.Follow{UserID: followerID, FollowID: followedID}) } testSuccess(4, 2) testSuccess(5, 2) - assert.NoError(t, user_model.FollowUser(2, 2)) + assert.NoError(t, user_model.FollowUser(db.DefaultContext, 2, 2)) unittest.CheckConsistencyFor(t, &user_model.User{}) } @@ -464,7 +464,7 @@ func TestUnfollowUser(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(followerID, followedID int64) { - assert.NoError(t, user_model.UnfollowUser(followerID, followedID)) + assert.NoError(t, user_model.UnfollowUser(db.DefaultContext, followerID, followedID)) unittest.AssertNotExistsBean(t, &user_model.Follow{UserID: followerID, FollowID: followedID}) } testSuccess(4, 2) |