diff options
Diffstat (limited to 'models')
74 files changed, 1334 insertions, 935 deletions
diff --git a/models/actions/run.go b/models/actions/run.go index f40bc1eb3d..60fbbcd323 100644 --- a/models/actions/run.go +++ b/models/actions/run.go @@ -88,7 +88,7 @@ func (run *ActionRun) RefLink() string { if refName.IsPull() { return run.Repo.Link() + "/pulls/" + refName.ShortName() } - return git.RefURL(run.Repo.Link(), run.Ref) + return run.Repo.Link() + "/src/" + refName.RefWebLinkPath() } // PrettyRef return #id for pull ref or ShortName for others @@ -154,7 +154,7 @@ func (run *ActionRun) GetPushEventPayload() (*api.PushPayload, error) { } func (run *ActionRun) GetPullRequestEventPayload() (*api.PullRequestPayload, error) { - if run.Event == webhook_module.HookEventPullRequest || run.Event == webhook_module.HookEventPullRequestSync { + if run.Event.IsPullRequest() { var payload api.PullRequestPayload if err := json.Unmarshal([]byte(run.EventPayload), &payload); err != nil { return nil, err @@ -275,7 +275,7 @@ func InsertRun(ctx context.Context, run *ActionRun, jobs []*jobparser.SingleWork return err } run.Index = index - run.Title, _ = util.SplitStringAtByteN(run.Title, 255) + run.Title = util.EllipsisDisplayString(run.Title, 255) if err := db.Insert(ctx, run); err != nil { return err @@ -308,7 +308,7 @@ func InsertRun(ctx context.Context, run *ActionRun, jobs []*jobparser.SingleWork } else { hasWaiting = true } - job.Name, _ = util.SplitStringAtByteN(job.Name, 255) + job.Name = util.EllipsisDisplayString(job.Name, 255) runJobs = append(runJobs, &ActionRunJob{ RunID: run.ID, RepoID: run.RepoID, @@ -402,7 +402,7 @@ func UpdateRun(ctx context.Context, run *ActionRun, cols ...string) error { if len(cols) > 0 { sess.Cols(cols...) } - run.Title, _ = util.SplitStringAtByteN(run.Title, 255) + run.Title = util.EllipsisDisplayString(run.Title, 255) affected, err := sess.Update(run) if err != nil { return err diff --git a/models/actions/runner.go b/models/actions/runner.go index b35a76680c..0d5464a5be 100644 --- a/models/actions/runner.go +++ b/models/actions/runner.go @@ -252,7 +252,7 @@ func GetRunnerByID(ctx context.Context, id int64) (*ActionRunner, error) { // UpdateRunner updates runner's information. func UpdateRunner(ctx context.Context, r *ActionRunner, cols ...string) error { e := db.GetEngine(ctx) - r.Name, _ = util.SplitStringAtByteN(r.Name, 255) + r.Name = util.EllipsisDisplayString(r.Name, 255) var err error if len(cols) == 0 { _, err = e.ID(r.ID).AllCols().Update(r) @@ -279,7 +279,7 @@ func CreateRunner(ctx context.Context, t *ActionRunner) error { // Remove OwnerID to avoid confusion; it's not worth returning an error here. t.OwnerID = 0 } - t.Name, _ = util.SplitStringAtByteN(t.Name, 255) + t.Name = util.EllipsisDisplayString(t.Name, 255) return db.Insert(ctx, t) } diff --git a/models/actions/runner_token.go b/models/actions/runner_token.go index fd6ba7ecad..bbd2af73b6 100644 --- a/models/actions/runner_token.go +++ b/models/actions/runner_token.go @@ -51,7 +51,7 @@ func GetRunnerToken(ctx context.Context, token string) (*ActionRunnerToken, erro if err != nil { return nil, err } else if !has { - return nil, fmt.Errorf("runner token %q: %w", token, util.ErrNotExist) + return nil, fmt.Errorf(`runner token "%s...": %w`, util.TruncateRunes(token, 3), util.ErrNotExist) } return &runnerToken, nil } @@ -68,19 +68,15 @@ func UpdateRunnerToken(ctx context.Context, r *ActionRunnerToken, cols ...string return err } -// NewRunnerToken creates a new active runner token and invalidate all old tokens +// NewRunnerTokenWithValue creates a new active runner token and invalidate all old tokens // ownerID will be ignored and treated as 0 if repoID is non-zero. -func NewRunnerToken(ctx context.Context, ownerID, repoID int64) (*ActionRunnerToken, error) { +func NewRunnerTokenWithValue(ctx context.Context, ownerID, repoID int64, token string) (*ActionRunnerToken, error) { if ownerID != 0 && repoID != 0 { // It's trying to create a runner token that belongs to a repository, but OwnerID has been set accidentally. // Remove OwnerID to avoid confusion; it's not worth returning an error here. ownerID = 0 } - token, err := util.CryptoRandomString(40) - if err != nil { - return nil, err - } runnerToken := &ActionRunnerToken{ OwnerID: ownerID, RepoID: repoID, @@ -95,11 +91,19 @@ func NewRunnerToken(ctx context.Context, ownerID, repoID int64) (*ActionRunnerTo return err } - _, err = db.GetEngine(ctx).Insert(runnerToken) + _, err := db.GetEngine(ctx).Insert(runnerToken) return err }) } +func NewRunnerToken(ctx context.Context, ownerID, repoID int64) (*ActionRunnerToken, error) { + token, err := util.CryptoRandomString(40) + if err != nil { + return nil, err + } + return NewRunnerTokenWithValue(ctx, ownerID, repoID, token) +} + // GetLatestRunnerToken returns the latest runner token func GetLatestRunnerToken(ctx context.Context, ownerID, repoID int64) (*ActionRunnerToken, error) { if ownerID != 0 && repoID != 0 { diff --git a/models/actions/schedule.go b/models/actions/schedule.go index 961ffd0851..e2cc32eedc 100644 --- a/models/actions/schedule.go +++ b/models/actions/schedule.go @@ -68,7 +68,7 @@ func CreateScheduleTask(ctx context.Context, rows []*ActionSchedule) error { // Loop through each schedule row for _, row := range rows { - row.Title, _ = util.SplitStringAtByteN(row.Title, 255) + row.Title = util.EllipsisDisplayString(row.Title, 255) // Create new schedule row if err = db.Insert(ctx, row); err != nil { return err diff --git a/models/actions/task.go b/models/actions/task.go index af74faf937..9f13ff94c9 100644 --- a/models/actions/task.go +++ b/models/actions/task.go @@ -298,7 +298,7 @@ func CreateTaskForRunner(ctx context.Context, runner *ActionRunner) (*ActionTask if len(workflowJob.Steps) > 0 { steps := make([]*ActionTaskStep, len(workflowJob.Steps)) for i, v := range workflowJob.Steps { - name, _ := util.SplitStringAtByteN(v.String(), 255) + name := util.EllipsisDisplayString(v.String(), 255) steps[i] = &ActionTaskStep{ Name: name, TaskID: task.ID, diff --git a/models/activities/action.go b/models/activities/action.go index ff7fdb2f10..f96621b7d5 100644 --- a/models/activities/action.go +++ b/models/activities/action.go @@ -20,12 +20,12 @@ import ( repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/models/unit" user_model "code.gitea.io/gitea/models/user" - "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/modules/util" "xorm.io/builder" "xorm.io/xorm/schemas" @@ -226,7 +226,7 @@ func (a *Action) GetActUserName(ctx context.Context) string { // ShortActUserName gets the action's user name trimmed to max 20 // chars. func (a *Action) ShortActUserName(ctx context.Context) string { - return base.EllipsisString(a.GetActUserName(ctx), 20) + return util.EllipsisDisplayString(a.GetActUserName(ctx), 20) } // GetActDisplayName gets the action's display name based on DEFAULT_SHOW_FULL_NAME, or falls back to the username if it is blank. @@ -260,7 +260,7 @@ func (a *Action) GetRepoUserName(ctx context.Context) string { // ShortRepoUserName returns the name of the action repository owner // trimmed to max 20 chars. func (a *Action) ShortRepoUserName(ctx context.Context) string { - return base.EllipsisString(a.GetRepoUserName(ctx), 20) + return util.EllipsisDisplayString(a.GetRepoUserName(ctx), 20) } // GetRepoName returns the name of the action repository. @@ -275,7 +275,7 @@ func (a *Action) GetRepoName(ctx context.Context) string { // ShortRepoName returns the name of the action repository // trimmed to max 33 chars. func (a *Action) ShortRepoName(ctx context.Context) string { - return base.EllipsisString(a.GetRepoName(ctx), 33) + return util.EllipsisDisplayString(a.GetRepoName(ctx), 33) } // GetRepoPath returns the virtual path to the action repository. @@ -355,7 +355,7 @@ func (a *Action) GetBranch() string { // GetRefLink returns the action's ref link. func (a *Action) GetRefLink(ctx context.Context) string { - return git.RefURL(a.GetRepoLink(ctx), a.RefName) + return a.GetRepoLink(ctx) + "/src/" + git.RefName(a.RefName).RefWebLinkPath() } // GetTag returns the action's repository tag. diff --git a/models/activities/repo_activity.go b/models/activities/repo_activity.go index 3ffad035b7..3ccdbd47d3 100644 --- a/models/activities/repo_activity.go +++ b/models/activities/repo_activity.go @@ -16,6 +16,7 @@ import ( "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/gitrepo" + "xorm.io/builder" "xorm.io/xorm" ) @@ -337,8 +338,10 @@ func newlyCreatedIssues(ctx context.Context, repoID int64, fromTime time.Time) * func activeIssues(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session { sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID). And("issue.is_pull = ?", false). - And("issue.created_unix >= ?", fromTime.Unix()). - Or("issue.closed_unix >= ?", fromTime.Unix()) + And(builder.Or( + builder.Gte{"issue.created_unix": fromTime.Unix()}, + builder.Gte{"issue.closed_unix": fromTime.Unix()}, + )) return sess } diff --git a/models/activities/user_heatmap_test.go b/models/activities/user_heatmap_test.go index a039fd3613..380045d3c5 100644 --- a/models/activities/user_heatmap_test.go +++ b/models/activities/user_heatmap_test.go @@ -64,11 +64,9 @@ func TestGetUserHeatmapDataByUser(t *testing.T) { for _, tc := range testCases { user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: tc.userID}) - doer := &user_model.User{ID: tc.doerID} - _, err := unittest.LoadBeanIfExists(doer) - assert.NoError(t, err) - if tc.doerID == 0 { - doer = nil + var doer *user_model.User + if tc.doerID != 0 { + doer = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: tc.doerID}) } // get the action for comparison diff --git a/models/asymkey/gpg_key_test.go b/models/asymkey/gpg_key_test.go index d3fbb01d82..0bccbb51b5 100644 --- a/models/asymkey/gpg_key_test.go +++ b/models/asymkey/gpg_key_test.go @@ -15,6 +15,7 @@ import ( "github.com/keybase/go-crypto/openpgp/packet" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCheckArmoredGPGKeyString(t *testing.T) { @@ -107,9 +108,8 @@ MkM/fdpyc2hY7Dl/+qFmN5MG5yGmMpQcX+RNNR222ibNC1D3wg== =i9b7 -----END PGP PUBLIC KEY BLOCK-----` keys, err := checkArmoredGPGKeyString(testGPGArmor) - if !assert.NotEmpty(t, keys) { - return - } + require.NotEmpty(t, keys) + ekey := keys[0] assert.NoError(t, err, "Could not parse a valid GPG armored key", ekey) diff --git a/models/auth/webauthn_test.go b/models/auth/webauthn_test.go index f1cf398adf..654427e974 100644 --- a/models/auth/webauthn_test.go +++ b/models/auth/webauthn_test.go @@ -44,7 +44,7 @@ func TestWebAuthnCredential_UpdateSignCount(t *testing.T) { cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1}) cred.SignCount = 1 assert.NoError(t, cred.UpdateSignCount(db.DefaultContext)) - unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1}) + unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1}) } func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) { @@ -52,7 +52,7 @@ func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) { cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1}) cred.SignCount = 0xffffffff assert.NoError(t, cred.UpdateSignCount(db.DefaultContext)) - unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff}) + unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff}) } func TestCreateCredential(t *testing.T) { @@ -63,5 +63,5 @@ func TestCreateCredential(t *testing.T) { assert.Equal(t, "WebAuthn Created Credential", res.Name) assert.Equal(t, []byte("Test"), res.CredentialID) - unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{Name: "WebAuthn Created Credential", UserID: 1}) + unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{Name: "WebAuthn Created Credential", UserID: 1}) } diff --git a/models/db/engine_hook.go b/models/db/engine_hook.go index b4c543c3dd..2c9fc09c99 100644 --- a/models/db/engine_hook.go +++ b/models/db/engine_hook.go @@ -7,23 +7,36 @@ import ( "context" "time" + "code.gitea.io/gitea/modules/gtprof" "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" "xorm.io/xorm/contexts" ) -type SlowQueryHook struct { +type EngineHook struct { Threshold time.Duration Logger log.Logger } -var _ contexts.Hook = (*SlowQueryHook)(nil) +var _ contexts.Hook = (*EngineHook)(nil) -func (*SlowQueryHook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) { - return c.Ctx, nil +func (*EngineHook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) { + ctx, _ := gtprof.GetTracer().Start(c.Ctx, gtprof.TraceSpanDatabase) + return ctx, nil } -func (h *SlowQueryHook) AfterProcess(c *contexts.ContextHook) error { +func (h *EngineHook) AfterProcess(c *contexts.ContextHook) error { + span := gtprof.GetContextSpan(c.Ctx) + if span != nil { + // Do not record SQL parameters here: + // * It shouldn't expose the parameters because they contain sensitive information, end users need to report the trace details safely. + // * Some parameters contain quite long texts, waste memory and are difficult to display. + span.SetAttributeString(gtprof.TraceAttrDbSQL, c.SQL) + span.End() + } else { + setting.PanicInDevOrTesting("span in database engine hook is nil") + } if c.ExecuteTime >= h.Threshold { // 8 is the amount of skips passed to runtime.Caller, so that in the log the correct function // is being displayed (the function that ultimately wants to execute the query in the code) diff --git a/models/db/engine_init.go b/models/db/engine_init.go index da85018957..edca697934 100644 --- a/models/db/engine_init.go +++ b/models/db/engine_init.go @@ -72,7 +72,7 @@ func InitEngine(ctx context.Context) error { xe.SetDefaultContext(ctx) if setting.Database.SlowQueryThreshold > 0 { - xe.AddHook(&SlowQueryHook{ + xe.AddHook(&EngineHook{ Threshold: setting.Database.SlowQueryThreshold, Logger: log.GetLogger("xorm"), }) diff --git a/models/db/engine_test.go b/models/db/engine_test.go index e3dbfbdc24..10a1a33ff0 100644 --- a/models/db/engine_test.go +++ b/models/db/engine_test.go @@ -15,6 +15,7 @@ import ( _ "code.gitea.io/gitea/cmd" // for TestPrimaryKeys "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDumpDatabase(t *testing.T) { @@ -62,9 +63,7 @@ func TestPrimaryKeys(t *testing.T) { // Import "code.gitea.io/gitea/cmd" to make sure each db.RegisterModel in init functions has been called. beans, err := db.NamesToBean() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) whitelist := map[string]string{ "the_table_name_to_skip_checking": "Write a note here to explain why", @@ -79,8 +78,6 @@ func TestPrimaryKeys(t *testing.T) { t.Logf("ignore %q because %q", table.Name, why) continue } - if len(table.PrimaryKeys) == 0 { - t.Errorf("table %q has no primary key", table.Name) - } + assert.NotEmpty(t, table.PrimaryKeys, "table %q has no primary key", table.Name) } } diff --git a/models/db/name.go b/models/db/name.go index 55c9dffb6a..5f98edbb28 100644 --- a/models/db/name.go +++ b/models/db/name.go @@ -5,20 +5,13 @@ package db import ( "fmt" - "regexp" "strings" "unicode/utf8" "code.gitea.io/gitea/modules/util" ) -var ( - // ErrNameEmpty name is empty error - ErrNameEmpty = util.SilentWrap{Message: "name is empty", Err: util.ErrInvalidArgument} - - // AlphaDashDotPattern characters prohibited in a username (anything except A-Za-z0-9_.-) - AlphaDashDotPattern = regexp.MustCompile(`[^\w-\.]`) -) +var ErrNameEmpty = util.SilentWrap{Message: "name is empty", Err: util.ErrInvalidArgument} // ErrNameReserved represents a "reserved name" error. type ErrNameReserved struct { @@ -82,20 +75,20 @@ func (err ErrNameCharsNotAllowed) Unwrap() error { // IsUsableName checks if name is reserved or pattern of name is not allowed // based on given reserved names and patterns. -// Names are exact match, patterns can be prefix or suffix match with placeholder '*'. -func IsUsableName(names, patterns []string, name string) error { +// Names are exact match, patterns can be a prefix or suffix match with placeholder '*'. +func IsUsableName(reservedNames, reservedPatterns []string, name string) error { name = strings.TrimSpace(strings.ToLower(name)) if utf8.RuneCountInString(name) == 0 { return ErrNameEmpty } - for i := range names { - if name == names[i] { + for i := range reservedNames { + if name == reservedNames[i] { return ErrNameReserved{name} } } - for _, pat := range patterns { + for _, pat := range reservedPatterns { if pat[0] == '*' && strings.HasSuffix(name, pat[1:]) || (pat[len(pat)-1] == '*' && strings.HasPrefix(name, pat[:len(pat)-1])) { return ErrNamePatternNotAllowed{pat} diff --git a/models/fixtures/access.yml b/models/fixtures/access.yml index 4171e31fef..596046e950 100644 --- a/models/fixtures/access.yml +++ b/models/fixtures/access.yml @@ -171,3 +171,9 @@ user_id: 40 repo_id: 61 mode: 4 + +- + id: 30 + user_id: 40 + repo_id: 1 + mode: 2 diff --git a/models/fixtures/action_run_job.yml b/models/fixtures/action_run_job.yml index 9b6f5b9a88..8837e6ec2d 100644 --- a/models/fixtures/action_run_job.yml +++ b/models/fixtures/action_run_job.yml @@ -64,7 +64,7 @@ name: job2 attempt: 1 job_id: job2 - needs: [job1] + needs: '["job1"]' task_id: 51 status: 5 started: 1683636528 diff --git a/models/fixtures/label.yml b/models/fixtures/label.yml index 2242b90dcd..acfac74968 100644 --- a/models/fixtures/label.yml +++ b/models/fixtures/label.yml @@ -96,3 +96,14 @@ num_issues: 0 num_closed_issues: 0 archived_unix: 0 + +- + id: 10 + repo_id: 3 + org_id: 0 + name: repo3label1 + color: '#112233' + exclusive: false + num_issues: 0 + num_closed_issues: 0 + archived_unix: 0 diff --git a/models/fixtures/protected_tag.yml b/models/fixtures/protected_tag.yml index dbec52c0c2..1944e7bd84 100644 --- a/models/fixtures/protected_tag.yml +++ b/models/fixtures/protected_tag.yml @@ -2,23 +2,23 @@ id: 1 repo_id: 4 name_pattern: /v.+/ - allowlist_user_i_ds: [] - allowlist_team_i_ds: [] + allowlist_user_i_ds: "[]" + allowlist_team_i_ds: "[]" created_unix: 1715596037 updated_unix: 1715596037 - id: 2 repo_id: 1 name_pattern: v-* - allowlist_user_i_ds: [] - allowlist_team_i_ds: [] + allowlist_user_i_ds: "[]" + allowlist_team_i_ds: "[]" created_unix: 1715596037 updated_unix: 1715596037 - id: 3 repo_id: 1 name_pattern: v-1.1 - allowlist_user_i_ds: [2] - allowlist_team_i_ds: [] + allowlist_user_i_ds: "[2]" + allowlist_team_i_ds: "[]" created_unix: 1715596037 updated_unix: 1715596037 diff --git a/models/fixtures/webhook.yml b/models/fixtures/webhook.yml index f62bae1f31..ebc4062b60 100644 --- a/models/fixtures/webhook.yml +++ b/models/fixtures/webhook.yml @@ -22,6 +22,7 @@ content_type: 1 # json events: '{"push_only":false,"send_everything":false,"choose_events":false,"events":{"create":false,"push":true,"pull_request":true}}' is_active: true + - id: 4 repo_id: 2 @@ -29,3 +30,23 @@ content_type: 1 # json events: '{"push_only":true,"branch_filter":"{master,feature*}"}' is_active: true + +- + id: 5 + repo_id: 0 + owner_id: 0 + url: www.example.com/url5 + content_type: 1 # json + events: '{"push_only":true,"branch_filter":"{master,feature*}"}' + is_active: true + is_system_webhook: true + +- + id: 6 + repo_id: 0 + owner_id: 0 + url: www.example.com/url6 + content_type: 1 # json + events: '{"push_only":true,"branch_filter":"{master,feature*}"}' + is_active: true + is_system_webhook: false diff --git a/models/git/branch.go b/models/git/branch.go index e683ce47e6..d1caa35947 100644 --- a/models/git/branch.go +++ b/models/git/branch.go @@ -167,6 +167,9 @@ func GetBranch(ctx context.Context, repoID int64, branchName string) (*Branch, e BranchName: branchName, } } + // FIXME: this design is not right: it doesn't check `branch.IsDeleted`, it doesn't make sense to make callers to check IsDeleted again and again. + // It causes inconsistency with `GetBranches` and `git.GetBranch`, and will lead to strange bugs + // In the future, there should be 2 functions: `GetBranchExisting` and `GetBranchWithDeleted` return &branch, nil } @@ -440,6 +443,8 @@ type FindRecentlyPushedNewBranchesOptions struct { } type RecentlyPushedNewBranch struct { + BranchRepo *repo_model.Repository + BranchName string BranchDisplayName string BranchLink string BranchCompareURL string @@ -540,7 +545,9 @@ func FindRecentlyPushedNewBranches(ctx context.Context, doer *user_model.User, o branchDisplayName = fmt.Sprintf("%s:%s", branch.Repo.FullName(), branchDisplayName) } newBranches = append(newBranches, &RecentlyPushedNewBranch{ + BranchRepo: branch.Repo, BranchDisplayName: branchDisplayName, + BranchName: branch.Name, BranchLink: fmt.Sprintf("%s/src/branch/%s", branch.Repo.Link(), util.PathEscapeSegments(branch.Name)), BranchCompareURL: branch.Repo.ComposeBranchCompareURL(opts.BaseRepo, branch.Name), CommitTime: branch.CommitTime, diff --git a/models/issues/comment.go b/models/issues/comment.go index e4537aa872..a248708820 100644 --- a/models/issues/comment.go +++ b/models/issues/comment.go @@ -112,8 +112,8 @@ const ( CommentTypePRScheduledToAutoMerge // 34 pr was scheduled to auto merge when checks succeed CommentTypePRUnScheduledToAutoMerge // 35 pr was un scheduled to auto merge when checks succeed - CommentTypePin // 36 pin Issue - CommentTypeUnpin // 37 unpin Issue + CommentTypePin // 36 pin Issue/PullRequest + CommentTypeUnpin // 37 unpin Issue/PullRequest CommentTypeChangeTimeEstimate // 38 Change time estimate ) @@ -197,6 +197,20 @@ func (t CommentType) HasMailReplySupport() bool { return false } +func (t CommentType) CountedAsConversation() bool { + for _, ct := range ConversationCountedCommentType() { + if t == ct { + return true + } + } + return false +} + +// ConversationCountedCommentType returns the comment types that are counted as a conversation +func ConversationCountedCommentType() []CommentType { + return []CommentType{CommentTypeComment, CommentTypeReview} +} + // RoleInRepo presents the user's participation in the repo type RoleInRepo string @@ -592,26 +606,26 @@ func (c *Comment) LoadAttachments(ctx context.Context) error { return nil } -// UpdateAttachments update attachments by UUIDs for the comment -func (c *Comment) UpdateAttachments(ctx context.Context, uuids []string) error { - ctx, committer, err := db.TxContext(ctx) - if err != nil { - return err - } - defer committer.Close() - - attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, uuids) - if err != nil { - return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", uuids, err) +// UpdateCommentAttachments update attachments by UUIDs for the comment +func UpdateCommentAttachments(ctx context.Context, c *Comment, uuids []string) error { + if len(uuids) == 0 { + return nil } - for i := 0; i < len(attachments); i++ { - attachments[i].IssueID = c.IssueID - attachments[i].CommentID = c.ID - if err := repo_model.UpdateAttachment(ctx, attachments[i]); err != nil { - return fmt.Errorf("update attachment [id: %d]: %w", attachments[i].ID, err) + return db.WithTx(ctx, func(ctx context.Context) error { + attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, uuids) + if err != nil { + return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", uuids, err) } - } - return committer.Commit() + for i := 0; i < len(attachments); i++ { + attachments[i].IssueID = c.IssueID + attachments[i].CommentID = c.ID + if err := repo_model.UpdateAttachment(ctx, attachments[i]); err != nil { + return fmt.Errorf("update attachment [id: %d]: %w", attachments[i].ID, err) + } + } + c.Attachments = attachments + return nil + }) } // LoadAssigneeUserAndTeam if comment.Type is CommentTypeAssignees, then load assignees @@ -878,7 +892,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment // Check comment type. switch opts.Type { case CommentTypeCode: - if err = updateAttachments(ctx, opts, comment); err != nil { + if err = UpdateCommentAttachments(ctx, comment, opts.Attachments); err != nil { return err } if comment.ReviewID != 0 { @@ -893,12 +907,12 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment } fallthrough case CommentTypeComment: - if _, err = db.Exec(ctx, "UPDATE `issue` SET num_comments=num_comments+1 WHERE id=?", opts.Issue.ID); err != nil { + if err := UpdateIssueNumComments(ctx, opts.Issue.ID); err != nil { return err } fallthrough case CommentTypeReview: - if err = updateAttachments(ctx, opts, comment); err != nil { + if err = UpdateCommentAttachments(ctx, comment, opts.Attachments); err != nil { return err } case CommentTypeReopen, CommentTypeClose: @@ -910,23 +924,6 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment return UpdateIssueCols(ctx, opts.Issue, "updated_unix") } -func updateAttachments(ctx context.Context, opts *CreateCommentOptions, comment *Comment) error { - attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, opts.Attachments) - if err != nil { - return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", opts.Attachments, err) - } - for i := range attachments { - attachments[i].IssueID = opts.Issue.ID - attachments[i].CommentID = comment.ID - // No assign value could be 0, so ignore AllCols(). - if _, err = db.GetEngine(ctx).ID(attachments[i].ID).Update(attachments[i]); err != nil { - return fmt.Errorf("update attachment [%d]: %w", attachments[i].ID, err) - } - } - comment.Attachments = attachments - return nil -} - func createDeadlineComment(ctx context.Context, doer *user_model.User, issue *Issue, newDeadlineUnix timeutil.TimeStamp) (*Comment, error) { var content string var commentType CommentType @@ -1182,8 +1179,8 @@ func DeleteComment(ctx context.Context, comment *Comment) error { return err } - if comment.Type == CommentTypeComment { - if _, err := e.ID(comment.IssueID).Decr("num_comments").Update(new(Issue)); err != nil { + if comment.Type.CountedAsConversation() { + if err := UpdateIssueNumComments(ctx, comment.IssueID); err != nil { return err } } @@ -1300,6 +1297,21 @@ func (c *Comment) HasOriginalAuthor() bool { return c.OriginalAuthor != "" && c.OriginalAuthorID != 0 } +func UpdateIssueNumCommentsBuilder(issueID int64) *builder.Builder { + subQuery := builder.Select("COUNT(*)").From("`comment`").Where( + builder.Eq{"issue_id": issueID}.And( + builder.In("`type`", ConversationCountedCommentType()), + )) + + return builder.Update(builder.Eq{"num_comments": subQuery}). + From("`issue`").Where(builder.Eq{"id": issueID}) +} + +func UpdateIssueNumComments(ctx context.Context, issueID int64) error { + _, err := db.GetEngine(ctx).Exec(UpdateIssueNumCommentsBuilder(issueID)) + return err +} + // InsertIssueComments inserts many comments of issues. func InsertIssueComments(ctx context.Context, comments []*Comment) error { if len(comments) == 0 { @@ -1332,8 +1344,7 @@ func InsertIssueComments(ctx context.Context, comments []*Comment) error { } for _, issueID := range issueIDs { - if _, err := db.Exec(ctx, "UPDATE issue set num_comments = (SELECT count(*) FROM comment WHERE issue_id = ? AND `type`=?) WHERE id = ?", - issueID, CommentTypeComment, issueID); err != nil { + if err := UpdateIssueNumComments(ctx, issueID); err != nil { return err } } diff --git a/models/issues/comment_list.go b/models/issues/comment_list.go index 61ac1c8f56..c483ada75a 100644 --- a/models/issues/comment_list.go +++ b/models/issues/comment_list.go @@ -26,14 +26,14 @@ func (comments CommentList) LoadPosters(ctx context.Context) error { return c.PosterID, c.Poster == nil && c.PosterID > 0 }) - posterMaps, err := getPostersByIDs(ctx, posterIDs) + posterMaps, err := user_model.GetUsersMapByIDs(ctx, posterIDs) if err != nil { return err } for _, comment := range comments { if comment.Poster == nil { - comment.Poster = getPoster(comment.PosterID, posterMaps) + comment.Poster = user_model.GetPossibleUserFromMap(comment.PosterID, posterMaps) } } return nil @@ -41,7 +41,7 @@ func (comments CommentList) LoadPosters(ctx context.Context) error { func (comments CommentList) getLabelIDs() []int64 { return container.FilterSlice(comments, func(comment *Comment) (int64, bool) { - return comment.LabelID, comment.LabelID > 0 + return comment.LabelID, comment.LabelID > 0 && comment.Label == nil }) } @@ -51,6 +51,9 @@ func (comments CommentList) loadLabels(ctx context.Context) error { } labelIDs := comments.getLabelIDs() + if len(labelIDs) == 0 { + return nil + } commentLabels := make(map[int64]*Label, len(labelIDs)) left := len(labelIDs) for left > 0 { @@ -118,8 +121,8 @@ func (comments CommentList) loadMilestones(ctx context.Context) error { milestoneIDs = milestoneIDs[limit:] } - for _, issue := range comments { - issue.Milestone = milestoneMaps[issue.MilestoneID] + for _, comment := range comments { + comment.Milestone = milestoneMaps[comment.MilestoneID] } return nil } @@ -175,6 +178,9 @@ func (comments CommentList) loadAssignees(ctx context.Context) error { } assigneeIDs := comments.getAssigneeIDs() + if len(assigneeIDs) == 0 { + return nil + } assignees := make(map[int64]*user_model.User, len(assigneeIDs)) left := len(assigneeIDs) for left > 0 { @@ -301,6 +307,9 @@ func (comments CommentList) loadDependentIssues(ctx context.Context) error { e := db.GetEngine(ctx) issueIDs := comments.getDependentIssueIDs() + if len(issueIDs) == 0 { + return nil + } issues := make(map[int64]*Issue, len(issueIDs)) left := len(issueIDs) for left > 0 { @@ -427,6 +436,9 @@ func (comments CommentList) loadReviews(ctx context.Context) error { } reviewIDs := comments.getReviewIDs() + if len(reviewIDs) == 0 { + return nil + } reviews := make(map[int64]*Review, len(reviewIDs)) if err := db.GetEngine(ctx).In("id", reviewIDs).Find(&reviews); err != nil { return err diff --git a/models/issues/comment_test.go b/models/issues/comment_test.go index d81f33f953..ae0bc3ce17 100644 --- a/models/issues/comment_test.go +++ b/models/issues/comment_test.go @@ -45,6 +45,24 @@ func TestCreateComment(t *testing.T) { unittest.AssertInt64InRange(t, now, then, int64(updatedIssue.UpdatedUnix)) } +func Test_UpdateCommentAttachment(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + comment := unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{ID: 1}) + attachment := repo_model.Attachment{ + Name: "test.txt", + } + assert.NoError(t, db.Insert(db.DefaultContext, &attachment)) + + err := issues_model.UpdateCommentAttachments(db.DefaultContext, comment, []string{attachment.UUID}) + assert.NoError(t, err) + + attachment2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Attachment{ID: attachment.ID}) + assert.EqualValues(t, attachment.Name, attachment2.Name) + assert.EqualValues(t, comment.ID, attachment2.CommentID) + assert.EqualValues(t, comment.IssueID, attachment2.IssueID) +} + func TestFetchCodeComments(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) @@ -97,3 +115,12 @@ func TestMigrate_InsertIssueComments(t *testing.T) { unittest.CheckConsistencyFor(t, &issues_model.Issue{}) } + +func Test_UpdateIssueNumComments(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + issue2 := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 2}) + + assert.NoError(t, issues_model.UpdateIssueNumComments(db.DefaultContext, issue2.ID)) + issue2 = unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 2}) + assert.EqualValues(t, 1, issue2.NumComments) +} diff --git a/models/issues/dependency_test.go b/models/issues/dependency_test.go index 6eed483cc9..67418039de 100644 --- a/models/issues/dependency_test.go +++ b/models/issues/dependency_test.go @@ -49,9 +49,13 @@ func TestCreateIssueDependency(t *testing.T) { assert.False(t, left) // Close #2 and check again - _, err = issues_model.ChangeIssueStatus(db.DefaultContext, issue2, user1, true) + _, err = issues_model.CloseIssue(db.DefaultContext, issue2, user1) assert.NoError(t, err) + issue2Closed, err := issues_model.GetIssueByID(db.DefaultContext, 2) + assert.NoError(t, err) + assert.True(t, issue2Closed.IsClosed) + left, err = issues_model.IssueNoDependenciesLeft(db.DefaultContext, issue1) assert.NoError(t, err) assert.True(t, left) @@ -59,4 +63,11 @@ func TestCreateIssueDependency(t *testing.T) { // Test removing the dependency err = issues_model.RemoveIssueDependency(db.DefaultContext, user1, issue1, issue2, issues_model.DependencyTypeBlockedBy) assert.NoError(t, err) + + _, err = issues_model.ReopenIssue(db.DefaultContext, issue2, user1) + assert.NoError(t, err) + + issue2Reopened, err := issues_model.GetIssueByID(db.DefaultContext, 2) + assert.NoError(t, err) + assert.False(t, issue2Reopened.IsClosed) } diff --git a/models/issues/issue.go b/models/issues/issue.go index fe347c2715..564a9fb835 100644 --- a/models/issues/issue.go +++ b/models/issues/issue.go @@ -46,23 +46,6 @@ func (err ErrIssueNotExist) Unwrap() error { return util.ErrNotExist } -// ErrIssueIsClosed represents a "IssueIsClosed" kind of error. -type ErrIssueIsClosed struct { - ID int64 - RepoID int64 - Index int64 -} - -// IsErrIssueIsClosed checks if an error is a ErrIssueNotExist. -func IsErrIssueIsClosed(err error) bool { - _, ok := err.(ErrIssueIsClosed) - return ok -} - -func (err ErrIssueIsClosed) Error() string { - return fmt.Sprintf("issue is closed [id: %d, repo_id: %d, index: %d]", err.ID, err.RepoID, err.Index) -} - // ErrNewIssueInsert is used when the INSERT statement in newIssue fails type ErrNewIssueInsert struct { OriginalError error @@ -78,22 +61,6 @@ func (err ErrNewIssueInsert) Error() string { return err.OriginalError.Error() } -// ErrIssueWasClosed is used when close a closed issue -type ErrIssueWasClosed struct { - ID int64 - Index int64 -} - -// IsErrIssueWasClosed checks if an error is a ErrIssueWasClosed. -func IsErrIssueWasClosed(err error) bool { - _, ok := err.(ErrIssueWasClosed) - return ok -} - -func (err ErrIssueWasClosed) Error() string { - return fmt.Sprintf("Issue [%d] %d was already closed", err.ID, err.Index) -} - var ErrIssueAlreadyChanged = util.NewInvalidArgumentErrorf("the issue is already changed") // Issue represents an issue or pull request of repository. @@ -271,6 +238,9 @@ func (issue *Issue) loadCommentsByType(ctx context.Context, tp CommentType) (err IssueID: issue.ID, Type: tp, }) + for _, comment := range issue.Comments { + comment.Issue = issue + } return err } diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go index 22a4548adc..02fd330f0a 100644 --- a/models/issues/issue_list.go +++ b/models/issues/issue_list.go @@ -81,53 +81,19 @@ func (issues IssueList) LoadPosters(ctx context.Context) error { return issue.PosterID, issue.Poster == nil && issue.PosterID > 0 }) - posterMaps, err := getPostersByIDs(ctx, posterIDs) + posterMaps, err := user_model.GetUsersMapByIDs(ctx, posterIDs) if err != nil { return err } for _, issue := range issues { if issue.Poster == nil { - issue.Poster = getPoster(issue.PosterID, posterMaps) + issue.Poster = user_model.GetPossibleUserFromMap(issue.PosterID, posterMaps) } } return nil } -func getPostersByIDs(ctx context.Context, posterIDs []int64) (map[int64]*user_model.User, error) { - posterMaps := make(map[int64]*user_model.User, len(posterIDs)) - left := len(posterIDs) - for left > 0 { - limit := db.DefaultMaxInSize - if left < limit { - limit = left - } - err := db.GetEngine(ctx). - In("id", posterIDs[:limit]). - Find(&posterMaps) - if err != nil { - return nil, err - } - left -= limit - posterIDs = posterIDs[limit:] - } - return posterMaps, nil -} - -func getPoster(posterID int64, posterMaps map[int64]*user_model.User) *user_model.User { - if posterID == user_model.ActionsUserID { - return user_model.NewActionsUser() - } - if posterID <= 0 { - return nil - } - poster, ok := posterMaps[posterID] - if !ok { - return user_model.NewGhostUser() - } - return poster -} - func (issues IssueList) getIssueIDs() []int64 { ids := make([]int64, 0, len(issues)) for _, issue := range issues { diff --git a/models/issues/issue_update.go b/models/issues/issue_update.go index 5b929c9115..7b3fe04aa5 100644 --- a/models/issues/issue_update.go +++ b/models/issues/issue_update.go @@ -28,38 +28,40 @@ import ( // UpdateIssueCols updates cols of issue func UpdateIssueCols(ctx context.Context, issue *Issue, cols ...string) error { - if _, err := db.GetEngine(ctx).ID(issue.ID).Cols(cols...).Update(issue); err != nil { - return err - } - return nil + _, err := db.GetEngine(ctx).ID(issue.ID).Cols(cols...).Update(issue) + return err } -func changeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.User, isClosed, isMergePull bool) (*Comment, error) { - // Reload the issue - currentIssue, err := GetIssueByID(ctx, issue.ID) - if err != nil { - return nil, err - } +// ErrIssueIsClosed is used when close a closed issue +type ErrIssueIsClosed struct { + ID int64 + RepoID int64 + Index int64 + IsPull bool +} - // Nothing should be performed if current status is same as target status - if currentIssue.IsClosed == isClosed { - if !issue.IsPull { - return nil, ErrIssueWasClosed{ - ID: issue.ID, - } - } - return nil, ErrPullWasClosed{ - ID: issue.ID, - } - } +// IsErrIssueIsClosed checks if an error is a ErrIssueIsClosed. +func IsErrIssueIsClosed(err error) bool { + _, ok := err.(ErrIssueIsClosed) + return ok +} - issue.IsClosed = isClosed - return doChangeIssueStatus(ctx, issue, doer, isMergePull) +func (err ErrIssueIsClosed) Error() string { + return fmt.Sprintf("%s [id: %d, repo_id: %d, index: %d] is already closed", util.Iif(err.IsPull, "Pull Request", "Issue"), err.ID, err.RepoID, err.Index) } -func doChangeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.User, isMergePull bool) (*Comment, error) { +func SetIssueAsClosed(ctx context.Context, issue *Issue, doer *user_model.User, isMergePull bool) (*Comment, error) { + if issue.IsClosed { + return nil, ErrIssueIsClosed{ + ID: issue.ID, + RepoID: issue.RepoID, + Index: issue.Index, + IsPull: issue.IsPull, + } + } + // Check for open dependencies - if issue.IsClosed && issue.Repo.IsDependenciesEnabled(ctx) { + if issue.Repo.IsDependenciesEnabled(ctx) { // only check if dependencies are enabled and we're about to close an issue, otherwise reopening an issue would fail when there are unsatisfied dependencies noDeps, err := IssueNoDependenciesLeft(ctx, issue) if err != nil { @@ -71,16 +73,63 @@ func doChangeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.Use } } - if issue.IsClosed { - issue.ClosedUnix = timeutil.TimeStampNow() - } else { - issue.ClosedUnix = 0 + issue.IsClosed = true + issue.ClosedUnix = timeutil.TimeStampNow() + + if cnt, err := db.GetEngine(ctx).ID(issue.ID).Cols("is_closed", "closed_unix"). + Where("is_closed = ?", false). + Update(issue); err != nil { + return nil, err + } else if cnt != 1 { + return nil, ErrIssueAlreadyChanged } - if err := UpdateIssueCols(ctx, issue, "is_closed", "closed_unix"); err != nil { + return updateIssueNumbers(ctx, issue, doer, util.Iif(isMergePull, CommentTypeMergePull, CommentTypeClose)) +} + +// ErrIssueIsOpen is used when reopen an opened issue +type ErrIssueIsOpen struct { + ID int64 + RepoID int64 + IsPull bool + Index int64 +} + +// IsErrIssueIsOpen checks if an error is a ErrIssueIsOpen. +func IsErrIssueIsOpen(err error) bool { + _, ok := err.(ErrIssueIsOpen) + return ok +} + +func (err ErrIssueIsOpen) Error() string { + return fmt.Sprintf("%s [id: %d, repo_id: %d, index: %d] is already open", util.Iif(err.IsPull, "Pull Request", "Issue"), err.ID, err.RepoID, err.Index) +} + +func setIssueAsReopen(ctx context.Context, issue *Issue, doer *user_model.User) (*Comment, error) { + if !issue.IsClosed { + return nil, ErrIssueIsOpen{ + ID: issue.ID, + RepoID: issue.RepoID, + Index: issue.Index, + IsPull: issue.IsPull, + } + } + + issue.IsClosed = false + issue.ClosedUnix = 0 + + if cnt, err := db.GetEngine(ctx).ID(issue.ID).Cols("is_closed", "closed_unix"). + Where("is_closed = ?", true). + Update(issue); err != nil { return nil, err + } else if cnt != 1 { + return nil, ErrIssueAlreadyChanged } + return updateIssueNumbers(ctx, issue, doer, CommentTypeReopen) +} + +func updateIssueNumbers(ctx context.Context, issue *Issue, doer *user_model.User, cmtType CommentType) (*Comment, error) { // Update issue count of labels if err := issue.LoadLabels(ctx); err != nil { return nil, err @@ -103,14 +152,6 @@ func doChangeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.Use return nil, err } - // New action comment - cmtType := CommentTypeClose - if !issue.IsClosed { - cmtType = CommentTypeReopen - } else if isMergePull { - cmtType = CommentTypeMergePull - } - return CreateComment(ctx, &CreateCommentOptions{ Type: cmtType, Doer: doer, @@ -119,8 +160,8 @@ func doChangeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.Use }) } -// ChangeIssueStatus changes issue status to open or closed. -func ChangeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.User, isClosed bool) (*Comment, error) { +// CloseIssue changes issue status to closed. +func CloseIssue(ctx context.Context, issue *Issue, doer *user_model.User) (*Comment, error) { if err := issue.LoadRepo(ctx); err != nil { return nil, err } @@ -128,7 +169,45 @@ func ChangeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.User, return nil, err } - return changeIssueStatus(ctx, issue, doer, isClosed, false) + ctx, committer, err := db.TxContext(ctx) + if err != nil { + return nil, err + } + defer committer.Close() + + comment, err := SetIssueAsClosed(ctx, issue, doer, false) + if err != nil { + return nil, err + } + if err := committer.Commit(); err != nil { + return nil, err + } + return comment, nil +} + +// ReopenIssue changes issue status to open. +func ReopenIssue(ctx context.Context, issue *Issue, doer *user_model.User) (*Comment, error) { + if err := issue.LoadRepo(ctx); err != nil { + return nil, err + } + if err := issue.LoadPoster(ctx); err != nil { + return nil, err + } + + ctx, committer, err := db.TxContext(ctx) + if err != nil { + return nil, err + } + defer committer.Close() + + comment, err := setIssueAsReopen(ctx, issue, doer) + if err != nil { + return nil, err + } + if err := committer.Commit(); err != nil { + return nil, err + } + return comment, nil } // ChangeIssueTitle changes the title of this issue, as the given user. @@ -139,7 +218,7 @@ func ChangeIssueTitle(ctx context.Context, issue *Issue, doer *user_model.User, } defer committer.Close() - issue.Title, _ = util.SplitStringAtByteN(issue.Title, 255) + issue.Title = util.EllipsisDisplayString(issue.Title, 255) if err = UpdateIssueCols(ctx, issue, "name"); err != nil { return fmt.Errorf("updateIssueCols: %w", err) } @@ -367,19 +446,10 @@ func NewIssueWithIndex(ctx context.Context, doer *user_model.User, opts NewIssue return err } - if len(opts.Attachments) > 0 { - attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, opts.Attachments) - if err != nil { - return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", opts.Attachments, err) - } - - for i := 0; i < len(attachments); i++ { - attachments[i].IssueID = opts.Issue.ID - if _, err = e.ID(attachments[i].ID).Update(attachments[i]); err != nil { - return fmt.Errorf("update attachment [id: %d]: %w", attachments[i].ID, err) - } - } + if err := UpdateIssueAttachments(ctx, opts.Issue.ID, opts.Attachments); err != nil { + return err } + if err = opts.Issue.LoadAttributes(ctx); err != nil { return err } @@ -402,7 +472,7 @@ func NewIssue(ctx context.Context, repo *repo_model.Repository, issue *Issue, la } issue.Index = idx - issue.Title, _ = util.SplitStringAtByteN(issue.Title, 255) + issue.Title = util.EllipsisDisplayString(issue.Title, 255) if err = NewIssueWithIndex(ctx, issue.Poster, NewIssueOptions{ Repo: repo, diff --git a/models/issues/issue_user_test.go b/models/issues/issue_user_test.go index ce47adb53a..7c21aa15ee 100644 --- a/models/issues/issue_user_test.go +++ b/models/issues/issue_user_test.go @@ -12,6 +12,7 @@ import ( "code.gitea.io/gitea/models/unittest" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_NewIssueUsers(t *testing.T) { @@ -27,9 +28,8 @@ func Test_NewIssueUsers(t *testing.T) { } // artificially insert new issue - unittest.AssertSuccessfulInsert(t, newIssue) - - assert.NoError(t, issues_model.NewIssueUsers(db.DefaultContext, repo, newIssue)) + require.NoError(t, db.Insert(db.DefaultContext, newIssue)) + require.NoError(t, issues_model.NewIssueUsers(db.DefaultContext, repo, newIssue)) // issue_user table should now have entries for new issue unittest.AssertExistsAndLoadBean(t, &issues_model.IssueUser{IssueID: newIssue.ID, UID: newIssue.PosterID}) diff --git a/models/issues/issue_xref_test.go b/models/issues/issue_xref_test.go index f1b1bb2a6b..7f257330b7 100644 --- a/models/issues/issue_xref_test.go +++ b/models/issues/issue_xref_test.go @@ -98,7 +98,7 @@ func TestXRef_ResolveCrossReferences(t *testing.T) { i1 := testCreateIssue(t, 1, 2, "title1", "content1", false) i2 := testCreateIssue(t, 1, 2, "title2", "content2", false) i3 := testCreateIssue(t, 1, 2, "title3", "content3", false) - _, err := issues_model.ChangeIssueStatus(db.DefaultContext, i3, d, true) + _, err := issues_model.CloseIssue(db.DefaultContext, i3, d) assert.NoError(t, err) pr := testCreatePR(t, 1, 2, "titlepr", fmt.Sprintf("closes #%d", i1.Index)) diff --git a/models/issues/label.go b/models/issues/label.go index d80578193e..b9d24bbe99 100644 --- a/models/issues/label.go +++ b/models/issues/label.go @@ -349,6 +349,17 @@ func GetLabelIDsInRepoByNames(ctx context.Context, repoID int64, labelNames []st Find(&labelIDs) } +// GetLabelIDsInOrgByNames returns a list of labelIDs by names in a given org. +func GetLabelIDsInOrgByNames(ctx context.Context, orgID int64, labelNames []string) ([]int64, error) { + labelIDs := make([]int64, 0, len(labelNames)) + return labelIDs, db.GetEngine(ctx).Table("label"). + Where("org_id = ?", orgID). + In("name", labelNames). + Asc("name"). + Cols("id"). + Find(&labelIDs) +} + // BuildLabelNamesIssueIDsCondition returns a builder where get issue ids match label names func BuildLabelNamesIssueIDsCondition(labelNames []string) *builder.Builder { return builder.Select("issue_label.issue_id"). diff --git a/models/issues/label_test.go b/models/issues/label_test.go index 1d4b6f4684..185fa11bbc 100644 --- a/models/issues/label_test.go +++ b/models/issues/label_test.go @@ -387,7 +387,7 @@ func TestDeleteIssueLabel(t *testing.T) { expectedNumIssues := label.NumIssues expectedNumClosedIssues := label.NumClosedIssues - if unittest.BeanExists(t, &issues_model.IssueLabel{IssueID: issueID, LabelID: labelID}) { + if unittest.GetBean(t, &issues_model.IssueLabel{IssueID: issueID, LabelID: labelID}) != nil { expectedNumIssues-- if issue.IsClosed { expectedNumClosedIssues-- diff --git a/models/issues/pull.go b/models/issues/pull.go index 853e2a69e6..e3af00224d 100644 --- a/models/issues/pull.go +++ b/models/issues/pull.go @@ -80,22 +80,6 @@ func (err ErrPullRequestAlreadyExists) Unwrap() error { return util.ErrAlreadyExist } -// ErrPullWasClosed is used close a closed pull request -type ErrPullWasClosed struct { - ID int64 - Index int64 -} - -// IsErrPullWasClosed checks if an error is a ErrErrPullWasClosed. -func IsErrPullWasClosed(err error) bool { - _, ok := err.(ErrPullWasClosed) - return ok -} - -func (err ErrPullWasClosed) Error() string { - return fmt.Sprintf("Pull request [%d] %d was already closed", err.ID, err.Index) -} - // PullRequestType defines pull request type type PullRequestType int @@ -301,7 +285,7 @@ func (pr *PullRequest) LoadRequestedReviewers(ctx context.Context) error { return nil } - reviews, err := GetReviewsByIssueID(ctx, pr.Issue.ID) + reviews, _, err := GetReviewsByIssueID(ctx, pr.Issue.ID) if err != nil { return err } @@ -320,7 +304,7 @@ func (pr *PullRequest) LoadRequestedReviewers(ctx context.Context) error { // LoadRequestedReviewersTeams loads the requested reviewers teams. func (pr *PullRequest) LoadRequestedReviewersTeams(ctx context.Context) error { - reviews, err := GetReviewsByIssueID(ctx, pr.Issue.ID) + reviews, _, err := GetReviewsByIssueID(ctx, pr.Issue.ID) if err != nil { return err } @@ -499,65 +483,6 @@ func (pr *PullRequest) IsFromFork() bool { return pr.HeadRepoID != pr.BaseRepoID } -// SetMerged sets a pull request to merged and closes the corresponding issue -func (pr *PullRequest) SetMerged(ctx context.Context) (bool, error) { - if pr.HasMerged { - return false, fmt.Errorf("PullRequest[%d] already merged", pr.Index) - } - if pr.MergedCommitID == "" || pr.MergedUnix == 0 || pr.Merger == nil { - return false, fmt.Errorf("Unable to merge PullRequest[%d], some required fields are empty", pr.Index) - } - - pr.HasMerged = true - sess := db.GetEngine(ctx) - - if _, err := sess.Exec("UPDATE `issue` SET `repo_id` = `repo_id` WHERE `id` = ?", pr.IssueID); err != nil { - return false, err - } - - if _, err := sess.Exec("UPDATE `pull_request` SET `issue_id` = `issue_id` WHERE `id` = ?", pr.ID); err != nil { - return false, err - } - - pr.Issue = nil - if err := pr.LoadIssue(ctx); err != nil { - return false, err - } - - if tmpPr, err := GetPullRequestByID(ctx, pr.ID); err != nil { - return false, err - } else if tmpPr.HasMerged { - if pr.Issue.IsClosed { - return false, nil - } - return false, fmt.Errorf("PullRequest[%d] already merged but it's associated issue [%d] is not closed", pr.Index, pr.IssueID) - } else if pr.Issue.IsClosed { - return false, fmt.Errorf("PullRequest[%d] already closed", pr.Index) - } - - if err := pr.Issue.LoadRepo(ctx); err != nil { - return false, err - } - - if err := pr.Issue.Repo.LoadOwner(ctx); err != nil { - return false, err - } - - if _, err := changeIssueStatus(ctx, pr.Issue, pr.Merger, true, true); err != nil { - return false, fmt.Errorf("Issue.changeStatus: %w", err) - } - - // reset the conflicted files as there cannot be any if we're merged - pr.ConflictedFiles = []string{} - - // We need to save all of the data used to compute this merge as it may have already been changed by TestPatch. FIXME: need to set some state to prevent TestPatch from running whilst we are merging. - if _, err := sess.Where("id = ?", pr.ID).Cols("has_merged, status, merge_base, merged_commit_id, merger_id, merged_unix, conflicted_files").Update(pr); err != nil { - return false, fmt.Errorf("Failed to update pr[%d]: %w", pr.ID, err) - } - - return true, nil -} - // NewPullRequest creates new pull request with labels for repository. func NewPullRequest(ctx context.Context, repo *repo_model.Repository, issue *Issue, labelIDs []int64, uuids []string, pr *PullRequest) (err error) { ctx, committer, err := db.TxContext(ctx) @@ -572,7 +497,7 @@ func NewPullRequest(ctx context.Context, repo *repo_model.Repository, issue *Iss } issue.Index = idx - issue.Title, _ = util.SplitStringAtByteN(issue.Title, 255) + issue.Title = util.EllipsisDisplayString(issue.Title, 255) if err = NewIssueWithIndex(ctx, issue.Poster, NewIssueOptions{ Repo: repo, diff --git a/models/issues/pull_list.go b/models/issues/pull_list.go index 59010aa9d0..1ddb94e566 100644 --- a/models/issues/pull_list.go +++ b/models/issues/pull_list.go @@ -166,6 +166,23 @@ func (prs PullRequestList) getRepositoryIDs() []int64 { return repoIDs.Values() } +func (prs PullRequestList) SetBaseRepo(baseRepo *repo_model.Repository) { + for _, pr := range prs { + if pr.BaseRepo == nil { + pr.BaseRepo = baseRepo + } + } +} + +func (prs PullRequestList) SetHeadRepo(headRepo *repo_model.Repository) { + for _, pr := range prs { + if pr.HeadRepo == nil { + pr.HeadRepo = headRepo + pr.isHeadRepoLoaded = true + } + } +} + func (prs PullRequestList) LoadRepositories(ctx context.Context) error { repoIDs := prs.getRepositoryIDs() reposMap := make(map[int64]*repo_model.Repository, len(repoIDs)) diff --git a/models/issues/review.go b/models/issues/review.go index 8b345e5fd8..3e787273be 100644 --- a/models/issues/review.go +++ b/models/issues/review.go @@ -639,6 +639,10 @@ func InsertReviews(ctx context.Context, reviews []*Review) error { return err } } + + if err := UpdateIssueNumComments(ctx, review.IssueID); err != nil { + return err + } } return committer.Commit() diff --git a/models/issues/review_list.go b/models/issues/review_list.go index bc7d7ec0f0..928f24fb2d 100644 --- a/models/issues/review_list.go +++ b/models/issues/review_list.go @@ -5,6 +5,8 @@ package issues import ( "context" + "slices" + "sort" "code.gitea.io/gitea/models/db" organization_model "code.gitea.io/gitea/models/organization" @@ -153,43 +155,60 @@ func CountReviews(ctx context.Context, opts FindReviewOptions) (int64, error) { return db.GetEngine(ctx).Where(opts.toCond()).Count(&Review{}) } -// GetReviewersFromOriginalAuthorsByIssueID gets the latest review of each original authors for a pull request -func GetReviewersFromOriginalAuthorsByIssueID(ctx context.Context, issueID int64) (ReviewList, error) { +// GetReviewsByIssueID gets the latest review of each reviewer for a pull request +// The first returned parameter is the latest review of each individual reviewer or team +// The second returned parameter is the latest review of each original author which is migrated from other systems +// The reviews are sorted by updated time +func GetReviewsByIssueID(ctx context.Context, issueID int64) (latestReviews, migratedOriginalReviews ReviewList, err error) { reviews := make([]*Review, 0, 10) - // Get latest review of each reviewer, sorted in order they were made - if err := db.GetEngine(ctx).SQL("SELECT * FROM review WHERE id IN (SELECT max(id) as id FROM review WHERE issue_id = ? AND reviewer_team_id = 0 AND type in (?, ?, ?) AND original_author_id <> 0 GROUP BY issue_id, original_author_id) ORDER BY review.updated_unix ASC", - issueID, ReviewTypeApprove, ReviewTypeReject, ReviewTypeRequest). - Find(&reviews); err != nil { - return nil, err + // Get all reviews for the issue id + if err := db.GetEngine(ctx).Where("issue_id=?", issueID).OrderBy("updated_unix ASC").Find(&reviews); err != nil { + return nil, nil, err } - return reviews, nil -} - -// GetReviewsByIssueID gets the latest review of each reviewer for a pull request -func GetReviewsByIssueID(ctx context.Context, issueID int64) (ReviewList, error) { - reviews := make([]*Review, 0, 10) - - sess := db.GetEngine(ctx) + // filter them in memory to get the latest review of each reviewer + // Since the reviews should not be too many for one issue, less than 100 commonly, it's acceptable to do this in memory + // And since there are too less indexes in review table, it will be very slow to filter in the database + reviewersMap := make(map[int64][]*Review) // key is reviewer id + originalReviewersMap := make(map[int64][]*Review) // key is original author id + reviewTeamsMap := make(map[int64][]*Review) // key is reviewer team id + countedReivewTypes := []ReviewType{ReviewTypeApprove, ReviewTypeReject, ReviewTypeRequest} + for _, review := range reviews { + if review.ReviewerTeamID == 0 && slices.Contains(countedReivewTypes, review.Type) && !review.Dismissed { + if review.OriginalAuthorID != 0 { + originalReviewersMap[review.OriginalAuthorID] = append(originalReviewersMap[review.OriginalAuthorID], review) + } else { + reviewersMap[review.ReviewerID] = append(reviewersMap[review.ReviewerID], review) + } + } else if review.ReviewerTeamID != 0 && review.OriginalAuthorID == 0 { + reviewTeamsMap[review.ReviewerTeamID] = append(reviewTeamsMap[review.ReviewerTeamID], review) + } + } - // Get latest review of each reviewer, sorted in order they were made - if err := sess.SQL("SELECT * FROM review WHERE id IN (SELECT max(id) as id FROM review WHERE issue_id = ? AND reviewer_team_id = 0 AND type in (?, ?, ?) AND dismissed = ? AND original_author_id = 0 GROUP BY issue_id, reviewer_id) ORDER BY review.updated_unix ASC", - issueID, ReviewTypeApprove, ReviewTypeReject, ReviewTypeRequest, false). - Find(&reviews); err != nil { - return nil, err + individualReviews := make([]*Review, 0, 10) + for _, reviews := range reviewersMap { + individualReviews = append(individualReviews, reviews[len(reviews)-1]) } + sort.Slice(individualReviews, func(i, j int) bool { + return individualReviews[i].UpdatedUnix < individualReviews[j].UpdatedUnix + }) - teamReviewRequests := make([]*Review, 0, 5) - if err := sess.SQL("SELECT * FROM review WHERE id IN (SELECT max(id) as id FROM review WHERE issue_id = ? AND reviewer_team_id <> 0 AND original_author_id = 0 GROUP BY issue_id, reviewer_team_id) ORDER BY review.updated_unix ASC", - issueID). - Find(&teamReviewRequests); err != nil { - return nil, err + originalReviews := make([]*Review, 0, 10) + for _, reviews := range originalReviewersMap { + originalReviews = append(originalReviews, reviews[len(reviews)-1]) } + sort.Slice(originalReviews, func(i, j int) bool { + return originalReviews[i].UpdatedUnix < originalReviews[j].UpdatedUnix + }) - if len(teamReviewRequests) > 0 { - reviews = append(reviews, teamReviewRequests...) + teamReviewRequests := make([]*Review, 0, 5) + for _, reviews := range reviewTeamsMap { + teamReviewRequests = append(teamReviewRequests, reviews[len(reviews)-1]) } + sort.Slice(teamReviewRequests, func(i, j int) bool { + return teamReviewRequests[i].UpdatedUnix < teamReviewRequests[j].UpdatedUnix + }) - return reviews, nil + return append(individualReviews, teamReviewRequests...), originalReviews, nil } diff --git a/models/issues/review_test.go b/models/issues/review_test.go index 50330e3ff2..2588b8ba41 100644 --- a/models/issues/review_test.go +++ b/models/issues/review_test.go @@ -162,8 +162,9 @@ func TestGetReviewersByIssueID(t *testing.T) { }, ) - allReviews, err := issues_model.GetReviewsByIssueID(db.DefaultContext, issue.ID) + allReviews, migratedReviews, err := issues_model.GetReviewsByIssueID(db.DefaultContext, issue.ID) assert.NoError(t, err) + assert.Empty(t, migratedReviews) for _, review := range allReviews { assert.NoError(t, review.LoadReviewer(db.DefaultContext)) } diff --git a/models/issues/stopwatch.go b/models/issues/stopwatch.go index 629af95b57..7c05a3a883 100644 --- a/models/issues/stopwatch.go +++ b/models/issues/stopwatch.go @@ -46,11 +46,6 @@ func (s Stopwatch) Seconds() int64 { return int64(timeutil.TimeStampNow() - s.CreatedUnix) } -// Duration returns a human-readable duration string based on local server time -func (s Stopwatch) Duration() string { - return util.SecToTime(s.Seconds()) -} - func getStopwatch(ctx context.Context, userID, issueID int64) (sw *Stopwatch, exists bool, err error) { sw = new(Stopwatch) exists, err = db.GetEngine(ctx). @@ -201,7 +196,7 @@ func FinishIssueStopwatch(ctx context.Context, user *user_model.User, issue *Iss Doer: user, Issue: issue, Repo: issue.Repo, - Content: util.SecToTime(timediff), + Content: util.SecToHours(timediff), Type: CommentTypeStopTracking, TimeID: tt.ID, }); err != nil { diff --git a/models/migrations/base/tests.go b/models/migrations/base/tests.go index 2eb85cd8a7..fe6de9c517 100644 --- a/models/migrations/base/tests.go +++ b/models/migrations/base/tests.go @@ -76,7 +76,7 @@ func PrepareTestEnv(t *testing.T, skip int, syncModels ...any) (*xorm.Engine, fu t.Errorf("error whilst initializing fixtures from %s: %v", fixturesDir, err) return x, deferFn } - if err := unittest.LoadFixtures(x); err != nil { + if err := unittest.LoadFixtures(); err != nil { t.Errorf("error whilst loading fixtures from %s: %v", fixturesDir, err) return x, deferFn } diff --git a/models/migrations/migrations.go b/models/migrations/migrations.go index 52d10c4fe8..95364ab705 100644 --- a/models/migrations/migrations.go +++ b/models/migrations/migrations.go @@ -22,6 +22,7 @@ import ( "code.gitea.io/gitea/models/migrations/v1_21" "code.gitea.io/gitea/models/migrations/v1_22" "code.gitea.io/gitea/models/migrations/v1_23" + "code.gitea.io/gitea/models/migrations/v1_24" "code.gitea.io/gitea/models/migrations/v1_6" "code.gitea.io/gitea/models/migrations/v1_7" "code.gitea.io/gitea/models/migrations/v1_8" @@ -369,6 +370,9 @@ func prepareMigrationTasks() []*migration { newMigration(309, "Improve Notification table indices", v1_23.ImproveNotificationTableIndices), newMigration(310, "Add Priority to ProtectedBranch", v1_23.AddPriorityToProtectedBranch), newMigration(311, "Add TimeEstimate to Issue table", v1_23.AddTimeEstimateColumnToIssueTable), + + // Gitea 1.23.0-rc0 ends at migration ID number 311 (database version 312) + newMigration(312, "Add DeleteBranchAfterMerge to AutoMerge", v1_24.AddDeleteBranchAfterMergeForAutoMerge), } return preparedMigrations } diff --git a/models/migrations/v1_21/v276.go b/models/migrations/v1_21/v276.go index 15177bf040..9d22c9052e 100644 --- a/models/migrations/v1_21/v276.go +++ b/models/migrations/v1_21/v276.go @@ -172,7 +172,7 @@ func getRemoteAddress(ownerName, repoName, remoteName string) (string, error) { return "", fmt.Errorf("get remote %s's address of %s/%s failed: %v", remoteName, ownerName, repoName, err) } - u, err := giturl.Parse(remoteURL) + u, err := giturl.ParseGitURL(remoteURL) if err != nil { return "", err } diff --git a/models/migrations/v1_24/v312.go b/models/migrations/v1_24/v312.go new file mode 100644 index 0000000000..9766dc1ccf --- /dev/null +++ b/models/migrations/v1_24/v312.go @@ -0,0 +1,21 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package v1_24 //nolint + +import ( + "xorm.io/xorm" +) + +type pullAutoMerge struct { + DeleteBranchAfterMerge bool +} + +// TableName return database table name for xorm +func (pullAutoMerge) TableName() string { + return "pull_auto_merge" +} + +func AddDeleteBranchAfterMergeForAutoMerge(x *xorm.Engine) error { + return x.Sync(new(pullAutoMerge)) +} diff --git a/models/organization/org_test.go b/models/organization/org_test.go index 2c5b4090df..b882a25be3 100644 --- a/models/organization/org_test.go +++ b/models/organization/org_test.go @@ -16,6 +16,7 @@ import ( "code.gitea.io/gitea/modules/structs" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUser_IsOwnedBy(t *testing.T) { @@ -180,9 +181,8 @@ func TestRestrictedUserOrgMembers(t *testing.T) { ID: 29, IsRestricted: true, }) - if !assert.True(t, restrictedUser.IsRestricted) { - return // ensure fixtures return restricted user - } + // ensure fixtures return restricted user + require.True(t, restrictedUser.IsRestricted) testCases := []struct { name string diff --git a/models/organization/org_user_test.go b/models/organization/org_user_test.go index 55abb63203..c5110b2a34 100644 --- a/models/organization/org_user_test.go +++ b/models/organization/org_user_test.go @@ -131,7 +131,7 @@ func TestAddOrgUser(t *testing.T) { testSuccess := func(orgID, userID int64, isPublic bool) { org := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: orgID}) expectedNumMembers := org.NumMembers - if !unittest.BeanExists(t, &organization.OrgUser{OrgID: orgID, UID: userID}) { + if unittest.GetBean(t, &organization.OrgUser{OrgID: orgID, UID: userID}) == nil { expectedNumMembers++ } assert.NoError(t, organization.AddOrgUser(db.DefaultContext, orgID, userID)) diff --git a/models/packages/package.go b/models/packages/package.go index c12f345f0e..31e1277a6e 100644 --- a/models/packages/package.go +++ b/models/packages/package.go @@ -248,6 +248,18 @@ func GetPackageByID(ctx context.Context, packageID int64) (*Package, error) { return p, nil } +// UpdatePackageNameByID updates the package's name, it is only for internal usage, for example: rename some legacy packages +func UpdatePackageNameByID(ctx context.Context, ownerID int64, packageType Type, packageID int64, name string) error { + var cond builder.Cond = builder.Eq{ + "package.id": packageID, + "package.owner_id": ownerID, + "package.type": packageType, + "package.is_internal": false, + } + _, err := db.GetEngine(ctx).Where(cond).Update(&Package{Name: name, LowerName: strings.ToLower(name)}) + return err +} + // GetPackageByName gets a package by name func GetPackageByName(ctx context.Context, ownerID int64, packageType Type, name string) (*Package, error) { var cond builder.Cond = builder.Eq{ diff --git a/models/perm/access/repo_permission.go b/models/perm/access/repo_permission.go index 0ed116a132..e00b7c5320 100644 --- a/models/perm/access/repo_permission.go +++ b/models/perm/access/repo_permission.go @@ -175,10 +175,14 @@ func (p *Permission) LogString() string { return fmt.Sprintf(format, args...) } -func applyEveryoneRepoPermission(user *user_model.User, perm *Permission) { +func finalProcessRepoUnitPermission(user *user_model.User, perm *Permission) { if user == nil || user.ID <= 0 { + // for anonymous access, it could be: + // AccessMode is None or Read, units has repo units, unitModes is nil return } + + // apply everyone access permissions for _, u := range perm.units { if u.EveryoneAccessMode >= perm_model.AccessModeRead && u.EveryoneAccessMode > perm.everyoneAccessMode[u.Type] { if perm.everyoneAccessMode == nil { @@ -187,17 +191,40 @@ func applyEveryoneRepoPermission(user *user_model.User, perm *Permission) { perm.everyoneAccessMode[u.Type] = u.EveryoneAccessMode } } + + if perm.unitsMode == nil { + // if unitsMode is not set, then it means that the default p.AccessMode applies to all units + return + } + + // remove no permission units + origPermUnits := perm.units + perm.units = make([]*repo_model.RepoUnit, 0, len(perm.units)) + for _, u := range origPermUnits { + shouldKeep := false + for t := range perm.unitsMode { + if shouldKeep = u.Type == t; shouldKeep { + break + } + } + for t := range perm.everyoneAccessMode { + if shouldKeep = shouldKeep || u.Type == t; shouldKeep { + break + } + } + if shouldKeep { + perm.units = append(perm.units, u) + } + } } // GetUserRepoPermission returns the user permissions to the repository func GetUserRepoPermission(ctx context.Context, repo *repo_model.Repository, user *user_model.User) (perm Permission, err error) { defer func() { if err == nil { - applyEveryoneRepoPermission(user, &perm) - } - if log.IsTrace() { - log.Trace("Permission Loaded for user %-v in repo %-v, permissions: %-+v", user, repo, perm) + finalProcessRepoUnitPermission(user, &perm) } + log.Trace("Permission Loaded for user %-v in repo %-v, permissions: %-+v", user, repo, perm) }() if err = repo.LoadUnits(ctx); err != nil { @@ -294,16 +321,6 @@ func GetUserRepoPermission(ctx context.Context, repo *repo_model.Repository, use } } - // remove no permission units - perm.units = make([]*repo_model.RepoUnit, 0, len(repo.Units)) - for t := range perm.unitsMode { - for _, u := range repo.Units { - if u.Type == t { - perm.units = append(perm.units, u) - } - } - } - return perm, err } diff --git a/models/perm/access/repo_permission_test.go b/models/perm/access/repo_permission_test.go index 50070c4368..9862da0673 100644 --- a/models/perm/access/repo_permission_test.go +++ b/models/perm/access/repo_permission_test.go @@ -50,7 +50,7 @@ func TestApplyEveryoneRepoPermission(t *testing.T) { {Type: unit.TypeWiki, EveryoneAccessMode: perm_model.AccessModeRead}, }, } - applyEveryoneRepoPermission(nil, &perm) + finalProcessRepoUnitPermission(nil, &perm) assert.False(t, perm.CanRead(unit.TypeWiki)) perm = Permission{ @@ -59,7 +59,7 @@ func TestApplyEveryoneRepoPermission(t *testing.T) { {Type: unit.TypeWiki, EveryoneAccessMode: perm_model.AccessModeRead}, }, } - applyEveryoneRepoPermission(&user_model.User{ID: 0}, &perm) + finalProcessRepoUnitPermission(&user_model.User{ID: 0}, &perm) assert.False(t, perm.CanRead(unit.TypeWiki)) perm = Permission{ @@ -68,7 +68,7 @@ func TestApplyEveryoneRepoPermission(t *testing.T) { {Type: unit.TypeWiki, EveryoneAccessMode: perm_model.AccessModeRead}, }, } - applyEveryoneRepoPermission(&user_model.User{ID: 1}, &perm) + finalProcessRepoUnitPermission(&user_model.User{ID: 1}, &perm) assert.True(t, perm.CanRead(unit.TypeWiki)) perm = Permission{ @@ -77,20 +77,22 @@ func TestApplyEveryoneRepoPermission(t *testing.T) { {Type: unit.TypeWiki, EveryoneAccessMode: perm_model.AccessModeRead}, }, } - applyEveryoneRepoPermission(&user_model.User{ID: 1}, &perm) + finalProcessRepoUnitPermission(&user_model.User{ID: 1}, &perm) // it should work the same as "EveryoneAccessMode: none" because the default AccessMode should be applied to units assert.True(t, perm.CanWrite(unit.TypeWiki)) perm = Permission{ units: []*repo_model.RepoUnit{ + {Type: unit.TypeCode}, // will be removed {Type: unit.TypeWiki, EveryoneAccessMode: perm_model.AccessModeRead}, }, unitsMode: map[unit.Type]perm_model.AccessMode{ unit.TypeWiki: perm_model.AccessModeWrite, }, } - applyEveryoneRepoPermission(&user_model.User{ID: 1}, &perm) + finalProcessRepoUnitPermission(&user_model.User{ID: 1}, &perm) assert.True(t, perm.CanWrite(unit.TypeWiki)) + assert.Len(t, perm.units, 1) } func TestUnitAccessMode(t *testing.T) { diff --git a/models/project/project.go b/models/project/project.go index 9779908b9d..edeb0b4742 100644 --- a/models/project/project.go +++ b/models/project/project.go @@ -126,6 +126,14 @@ func (p *Project) LoadRepo(ctx context.Context) (err error) { return err } +func ProjectLinkForOrg(org *user_model.User, projectID int64) string { //nolint + return fmt.Sprintf("%s/-/projects/%d", org.HomeLink(), projectID) +} + +func ProjectLinkForRepo(repo *repo_model.Repository, projectID int64) string { //nolint + return fmt.Sprintf("%s/projects/%d", repo.Link(), projectID) +} + // Link returns the project's relative URL. func (p *Project) Link(ctx context.Context) string { if p.OwnerID > 0 { @@ -134,7 +142,7 @@ func (p *Project) Link(ctx context.Context) string { log.Error("LoadOwner: %v", err) return "" } - return fmt.Sprintf("%s/-/projects/%d", p.Owner.HomeLink(), p.ID) + return ProjectLinkForOrg(p.Owner, p.ID) } if p.RepoID > 0 { err := p.LoadRepo(ctx) @@ -142,7 +150,7 @@ func (p *Project) Link(ctx context.Context) string { log.Error("LoadRepo: %v", err) return "" } - return fmt.Sprintf("%s/projects/%d", p.Repo.Link(), p.ID) + return ProjectLinkForRepo(p.Repo, p.ID) } return "" } @@ -256,7 +264,7 @@ func NewProject(ctx context.Context, p *Project) error { return util.NewInvalidArgumentErrorf("project type is not valid") } - p.Title, _ = util.SplitStringAtByteN(p.Title, 255) + p.Title = util.EllipsisDisplayString(p.Title, 255) return db.WithTx(ctx, func(ctx context.Context) error { if err := db.Insert(ctx, p); err != nil { @@ -311,7 +319,7 @@ func UpdateProject(ctx context.Context, p *Project) error { p.CardType = CardTypeTextOnly } - p.Title, _ = util.SplitStringAtByteN(p.Title, 255) + p.Title = util.EllipsisDisplayString(p.Title, 255) _, err := db.GetEngine(ctx).ID(p.ID).Cols( "title", "description", diff --git a/models/pull/automerge.go b/models/pull/automerge.go index f69fcb60d1..3cafacc3a4 100644 --- a/models/pull/automerge.go +++ b/models/pull/automerge.go @@ -15,13 +15,14 @@ import ( // AutoMerge represents a pull request scheduled for merging when checks succeed type AutoMerge struct { - ID int64 `xorm:"pk autoincr"` - PullID int64 `xorm:"UNIQUE"` - DoerID int64 `xorm:"INDEX NOT NULL"` - Doer *user_model.User `xorm:"-"` - MergeStyle repo_model.MergeStyle `xorm:"varchar(30)"` - Message string `xorm:"LONGTEXT"` - CreatedUnix timeutil.TimeStamp `xorm:"created"` + ID int64 `xorm:"pk autoincr"` + PullID int64 `xorm:"UNIQUE"` + DoerID int64 `xorm:"INDEX NOT NULL"` + Doer *user_model.User `xorm:"-"` + MergeStyle repo_model.MergeStyle `xorm:"varchar(30)"` + Message string `xorm:"LONGTEXT"` + DeleteBranchAfterMerge bool + CreatedUnix timeutil.TimeStamp `xorm:"created"` } // TableName return database table name for xorm @@ -49,7 +50,7 @@ func IsErrAlreadyScheduledToAutoMerge(err error) bool { } // ScheduleAutoMerge schedules a pull request to be merged when all checks succeed -func ScheduleAutoMerge(ctx context.Context, doer *user_model.User, pullID int64, style repo_model.MergeStyle, message string) error { +func ScheduleAutoMerge(ctx context.Context, doer *user_model.User, pullID int64, style repo_model.MergeStyle, message string, deleteBranchAfterMerge bool) error { // Check if we already have a merge scheduled for that pull request if exists, _, err := GetScheduledMergeByPullID(ctx, pullID); err != nil { return err @@ -58,10 +59,11 @@ func ScheduleAutoMerge(ctx context.Context, doer *user_model.User, pullID int64, } _, err := db.GetEngine(ctx).Insert(&AutoMerge{ - DoerID: doer.ID, - PullID: pullID, - MergeStyle: style, - Message: message, + DoerID: doer.ID, + PullID: pullID, + MergeStyle: style, + Message: message, + DeleteBranchAfterMerge: deleteBranchAfterMerge, }) return err } diff --git a/models/repo.go b/models/repo.go index 3e9c52fdd9..9bc67079a9 100644 --- a/models/repo.go +++ b/models/repo.go @@ -16,6 +16,8 @@ import ( "code.gitea.io/gitea/models/unit" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/log" + + "xorm.io/builder" ) // Init initialize model @@ -24,7 +26,7 @@ func Init(ctx context.Context) error { } type repoChecker struct { - querySQL func(ctx context.Context) ([]map[string][]byte, error) + querySQL func(ctx context.Context) ([]int64, error) correctSQL func(ctx context.Context, id int64) error desc string } @@ -35,8 +37,7 @@ func repoStatsCheck(ctx context.Context, checker *repoChecker) { log.Error("Select %s: %v", checker.desc, err) return } - for _, result := range results { - id, _ := strconv.ParseInt(string(result["id"]), 10, 64) + for _, id := range results { select { case <-ctx.Done(): log.Warn("CheckRepoStats: Cancelled before checking %s for with id=%d", checker.desc, id) @@ -51,21 +52,23 @@ func repoStatsCheck(ctx context.Context, checker *repoChecker) { } } -func StatsCorrectSQL(ctx context.Context, sql string, id int64) error { - _, err := db.GetEngine(ctx).Exec(sql, id, id) +func StatsCorrectSQL(ctx context.Context, sql any, ids ...any) error { + args := []any{sql} + args = append(args, ids...) + _, err := db.GetEngine(ctx).Exec(args...) return err } func repoStatsCorrectNumWatches(ctx context.Context, id int64) error { - return StatsCorrectSQL(ctx, "UPDATE `repository` SET num_watches=(SELECT COUNT(*) FROM `watch` WHERE repo_id=? AND mode<>2) WHERE id=?", id) + return StatsCorrectSQL(ctx, "UPDATE `repository` SET num_watches=(SELECT COUNT(*) FROM `watch` WHERE repo_id=? AND mode<>2) WHERE id=?", id, id) } func repoStatsCorrectNumStars(ctx context.Context, id int64) error { - return StatsCorrectSQL(ctx, "UPDATE `repository` SET num_stars=(SELECT COUNT(*) FROM `star` WHERE repo_id=?) WHERE id=?", id) + return StatsCorrectSQL(ctx, "UPDATE `repository` SET num_stars=(SELECT COUNT(*) FROM `star` WHERE repo_id=?) WHERE id=?", id, id) } func labelStatsCorrectNumIssues(ctx context.Context, id int64) error { - return StatsCorrectSQL(ctx, "UPDATE `label` SET num_issues=(SELECT COUNT(*) FROM `issue_label` WHERE label_id=?) WHERE id=?", id) + return StatsCorrectSQL(ctx, "UPDATE `label` SET num_issues=(SELECT COUNT(*) FROM `issue_label` WHERE label_id=?) WHERE id=?", id, id) } func labelStatsCorrectNumIssuesRepo(ctx context.Context, id int64) error { @@ -102,11 +105,11 @@ func milestoneStatsCorrectNumIssuesRepo(ctx context.Context, id int64) error { } func userStatsCorrectNumRepos(ctx context.Context, id int64) error { - return StatsCorrectSQL(ctx, "UPDATE `user` SET num_repos=(SELECT COUNT(*) FROM `repository` WHERE owner_id=?) WHERE id=?", id) + return StatsCorrectSQL(ctx, "UPDATE `user` SET num_repos=(SELECT COUNT(*) FROM `repository` WHERE owner_id=?) WHERE id=?", id, id) } func repoStatsCorrectIssueNumComments(ctx context.Context, id int64) error { - return StatsCorrectSQL(ctx, "UPDATE `issue` SET num_comments=(SELECT COUNT(*) FROM `comment` WHERE issue_id=? AND type=0) WHERE id=?", id) + return StatsCorrectSQL(ctx, issues_model.UpdateIssueNumCommentsBuilder(id)) } func repoStatsCorrectNumIssues(ctx context.Context, id int64) error { @@ -125,9 +128,12 @@ func repoStatsCorrectNumClosedPulls(ctx context.Context, id int64) error { return repo_model.UpdateRepoIssueNumbers(ctx, id, true, true) } -func statsQuery(args ...any) func(context.Context) ([]map[string][]byte, error) { - return func(ctx context.Context) ([]map[string][]byte, error) { - return db.GetEngine(ctx).Query(args...) +// statsQuery returns a function that queries the database for a list of IDs +// sql could be a string or a *builder.Builder +func statsQuery(sql any, args ...any) func(context.Context) ([]int64, error) { + return func(ctx context.Context) ([]int64, error) { + var ids []int64 + return ids, db.GetEngine(ctx).SQL(sql, args...).Find(&ids) } } @@ -198,7 +204,16 @@ func CheckRepoStats(ctx context.Context) error { }, // Issue.NumComments { - statsQuery("SELECT `issue`.id FROM `issue` WHERE `issue`.num_comments!=(SELECT COUNT(*) FROM `comment` WHERE issue_id=`issue`.id AND type=0)"), + statsQuery(builder.Select("`issue`.id").From("`issue`").Where( + builder.Neq{ + "`issue`.num_comments": builder.Select("COUNT(*)").From("`comment`").Where( + builder.Expr("issue_id = `issue`.id").And( + builder.In("type", issues_model.ConversationCountedCommentType()), + ), + ), + }, + ), + ), repoStatsCorrectIssueNumComments, "issue count 'num_comments'", }, diff --git a/models/repo/archiver.go b/models/repo/archiver.go index 14ffa1d89b..5a3eac9f14 100644 --- a/models/repo/archiver.go +++ b/models/repo/archiver.go @@ -56,16 +56,11 @@ func repoArchiverForRelativePath(relativePath string) (*RepoArchiver, error) { if err != nil { return nil, util.SilentWrap{Message: fmt.Sprintf("invalid storage path: %s", relativePath), Err: util.ErrInvalidArgument} } - nameExts := strings.SplitN(parts[2], ".", 2) - if len(nameExts) != 2 { + commitID, archiveType := git.SplitArchiveNameType(parts[2]) + if archiveType == git.ArchiveUnknown { return nil, util.SilentWrap{Message: fmt.Sprintf("invalid storage path: %s", relativePath), Err: util.ErrInvalidArgument} } - - return &RepoArchiver{ - RepoID: repoID, - CommitID: parts[1] + nameExts[0], - Type: git.ToArchiveType(nameExts[1]), - }, nil + return &RepoArchiver{RepoID: repoID, CommitID: commitID, Type: archiveType}, nil } // GetRepoArchiver get an archiver diff --git a/models/repo/license.go b/models/repo/license.go index 366b4598cc..9bcf0f7bc9 100644 --- a/models/repo/license.go +++ b/models/repo/license.go @@ -54,6 +54,7 @@ func UpdateRepoLicenses(ctx context.Context, repo *Repository, commitID string, for _, o := range oldLicenses { // Update already existing license if o.License == license { + o.CommitID = commitID if _, err := db.GetEngine(ctx).ID(o.ID).Cols("`commit_id`").Update(o); err != nil { return err } diff --git a/models/repo/release.go b/models/repo/release.go index ba7a3b3159..1c2e4a48e3 100644 --- a/models/repo/release.go +++ b/models/repo/release.go @@ -156,7 +156,7 @@ func IsReleaseExist(ctx context.Context, repoID int64, tagName string) (bool, er // UpdateRelease updates all columns of a release func UpdateRelease(ctx context.Context, rel *Release) error { - rel.Title, _ = util.SplitStringAtByteN(rel.Title, 255) + rel.Title = util.EllipsisDisplayString(rel.Title, 255) _, err := db.GetEngine(ctx).ID(rel.ID).AllCols().Update(rel) return err } diff --git a/models/repo/repo.go b/models/repo/repo.go index 2d9b9de88d..fb8a6642f5 100644 --- a/models/repo/repo.go +++ b/models/repo/repo.go @@ -11,6 +11,7 @@ import ( "net" "net/url" "path/filepath" + "regexp" "strconv" "strings" @@ -19,6 +20,7 @@ import ( user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/git" + giturl "code.gitea.io/gitea/modules/git/url" "code.gitea.io/gitea/modules/httplib" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/markup" @@ -60,13 +62,15 @@ func (err ErrRepoIsArchived) Error() string { } var ( - reservedRepoNames = []string{".", "..", "-"} - reservedRepoPatterns = []string{"*.git", "*.wiki", "*.rss", "*.atom"} + validRepoNamePattern = regexp.MustCompile(`[-.\w]+`) + invalidRepoNamePattern = regexp.MustCompile(`[.]{2,}`) + reservedRepoNames = []string{".", "..", "-"} + reservedRepoPatterns = []string{"*.git", "*.wiki", "*.rss", "*.atom"} ) -// IsUsableRepoName returns true when repository is usable +// IsUsableRepoName returns true when name is usable func IsUsableRepoName(name string) error { - if db.AlphaDashDotPattern.MatchString(name) { + if !validRepoNamePattern.MatchString(name) || invalidRepoNamePattern.MatchString(name) { // Note: usually this error is normally caught up earlier in the UI return db.ErrNameCharsNotAllowed{Name: name} } @@ -276,6 +280,8 @@ func (repo *Repository) IsBroken() bool { } // MarkAsBrokenEmpty marks the repo as broken and empty +// FIXME: the status "broken" and "is_empty" were abused, +// The code always set them together, no way to distinguish whether a repo is really "empty" or "broken" func (repo *Repository) MarkAsBrokenEmpty() { repo.Status = RepositoryBroken repo.IsEmpty = true @@ -632,14 +638,26 @@ type CloneLink struct { } // ComposeHTTPSCloneURL returns HTTPS clone URL based on given owner and repository name. -func ComposeHTTPSCloneURL(owner, repo string) string { - return fmt.Sprintf("%s%s/%s.git", setting.AppURL, url.PathEscape(owner), url.PathEscape(repo)) +func ComposeHTTPSCloneURL(ctx context.Context, owner, repo string) string { + return fmt.Sprintf("%s%s/%s.git", httplib.GuessCurrentAppURL(ctx), url.PathEscape(owner), url.PathEscape(repo)) } -func ComposeSSHCloneURL(ownerName, repoName string) string { +func ComposeSSHCloneURL(doer *user_model.User, ownerName, repoName string) string { sshUser := setting.SSH.User sshDomain := setting.SSH.Domain + if sshUser == "(DOER_USERNAME)" { + // Some users use SSH reverse-proxy and need to use the current signed-in username as the SSH user + // to make the SSH reverse-proxy could prepare the user's public keys ahead. + // For most cases we have the correct "doer", then use it as the SSH user. + // If we can't get the doer, then use the built-in SSH user. + if doer != nil { + sshUser = doer.Name + } else { + sshUser = setting.SSH.BuiltinServerUser + } + } + // non-standard port, it must use full URI if setting.SSH.Port != 22 { sshHost := net.JoinHostPort(sshDomain, strconv.Itoa(setting.SSH.Port)) @@ -657,21 +675,20 @@ func ComposeSSHCloneURL(ownerName, repoName string) string { return fmt.Sprintf("%s@%s:%s/%s.git", sshUser, sshHost, url.PathEscape(ownerName), url.PathEscape(repoName)) } -func (repo *Repository) cloneLink(isWiki bool) *CloneLink { - repoName := repo.Name - if isWiki { - repoName += ".wiki" - } - +func (repo *Repository) cloneLink(ctx context.Context, doer *user_model.User, repoPathName string) *CloneLink { cl := new(CloneLink) - cl.SSH = ComposeSSHCloneURL(repo.OwnerName, repoName) - cl.HTTPS = ComposeHTTPSCloneURL(repo.OwnerName, repoName) + cl.SSH = ComposeSSHCloneURL(doer, repo.OwnerName, repoPathName) + cl.HTTPS = ComposeHTTPSCloneURL(ctx, repo.OwnerName, repoPathName) return cl } // CloneLink returns clone URLs of repository. -func (repo *Repository) CloneLink() (cl *CloneLink) { - return repo.cloneLink(false) +func (repo *Repository) CloneLink(ctx context.Context, doer *user_model.User) (cl *CloneLink) { + return repo.cloneLink(ctx, doer, repo.Name) +} + +func (repo *Repository) CloneLinkGeneral(ctx context.Context) (cl *CloneLink) { + return repo.cloneLink(ctx, nil /* no doer, use a general git user */, repo.Name) } // GetOriginalURLHostname returns the hostname of a URL or the URL @@ -767,47 +784,25 @@ func GetRepositoryByName(ctx context.Context, ownerID int64, name string) (*Repo return &repo, err } -// getRepositoryURLPathSegments returns segments (owner, reponame) extracted from a url -func getRepositoryURLPathSegments(repoURL string) []string { - if strings.HasPrefix(repoURL, setting.AppURL) { - return strings.Split(strings.TrimPrefix(repoURL, setting.AppURL), "/") - } - - sshURLVariants := [4]string{ - setting.SSH.Domain + ":", - setting.SSH.User + "@" + setting.SSH.Domain + ":", - "git+ssh://" + setting.SSH.Domain + "/", - "git+ssh://" + setting.SSH.User + "@" + setting.SSH.Domain + "/", - } - - for _, sshURL := range sshURLVariants { - if strings.HasPrefix(repoURL, sshURL) { - return strings.Split(strings.TrimPrefix(repoURL, sshURL), "/") - } - } - - return nil -} - // GetRepositoryByURL returns the repository by given url func GetRepositoryByURL(ctx context.Context, repoURL string) (*Repository, error) { - // possible urls for git: - // https://my.domain/sub-path/<owner>/<repo>.git - // https://my.domain/sub-path/<owner>/<repo> - // git+ssh://user@my.domain/<owner>/<repo>.git - // git+ssh://user@my.domain/<owner>/<repo> - // user@my.domain:<owner>/<repo>.git - // user@my.domain:<owner>/<repo> - - pathSegments := getRepositoryURLPathSegments(repoURL) - - if len(pathSegments) != 2 { + ret, err := giturl.ParseRepositoryURL(ctx, repoURL) + if err != nil || ret.OwnerName == "" { return nil, fmt.Errorf("unknown or malformed repository URL") } + return GetRepositoryByOwnerAndName(ctx, ret.OwnerName, ret.RepoName) +} - ownerName := pathSegments[0] - repoName := strings.TrimSuffix(pathSegments[1], ".git") - return GetRepositoryByOwnerAndName(ctx, ownerName, repoName) +// GetRepositoryByURLRelax also accepts an SSH clone URL without user part +func GetRepositoryByURLRelax(ctx context.Context, repoURL string) (*Repository, error) { + if !strings.Contains(repoURL, "://") && !strings.Contains(repoURL, "@") { + // convert "example.com:owner/repo" to "@example.com:owner/repo" + p1, p2, p3 := strings.Index(repoURL, "."), strings.Index(repoURL, ":"), strings.Index(repoURL, "/") + if 0 < p1 && p1 < p2 && p2 < p3 { + repoURL = "@" + repoURL + } + } + return GetRepositoryByURL(ctx, repoURL) } // GetRepositoryByID returns the repository by given id if exists. diff --git a/models/repo/repo_test.go b/models/repo/repo_test.go index 6d88d170da..a9e2cdfb75 100644 --- a/models/repo/repo_test.go +++ b/models/repo/repo_test.go @@ -16,6 +16,7 @@ import ( "code.gitea.io/gitea/modules/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -132,60 +133,43 @@ func TestGetRepositoryByURL(t *testing.T) { t.Run("InvalidPath", func(t *testing.T) { repo, err := GetRepositoryByURL(db.DefaultContext, "something") - assert.Nil(t, repo) assert.Error(t, err) }) - t.Run("ValidHttpURL", func(t *testing.T) { - test := func(t *testing.T, url string) { - repo, err := GetRepositoryByURL(db.DefaultContext, url) - - assert.NotNil(t, repo) - assert.NoError(t, err) - - assert.Equal(t, int64(2), repo.ID) - assert.Equal(t, int64(2), repo.OwnerID) - } + testRepo2 := func(t *testing.T, url string) { + repo, err := GetRepositoryByURL(db.DefaultContext, url) + require.NoError(t, err) + assert.EqualValues(t, 2, repo.ID) + assert.EqualValues(t, 2, repo.OwnerID) + } - test(t, "https://try.gitea.io/user2/repo2") - test(t, "https://try.gitea.io/user2/repo2.git") + t.Run("ValidHttpURL", func(t *testing.T) { + testRepo2(t, "https://try.gitea.io/user2/repo2") + testRepo2(t, "https://try.gitea.io/user2/repo2.git") }) t.Run("ValidGitSshURL", func(t *testing.T) { - test := func(t *testing.T, url string) { - repo, err := GetRepositoryByURL(db.DefaultContext, url) - - assert.NotNil(t, repo) - assert.NoError(t, err) + testRepo2(t, "git+ssh://sshuser@try.gitea.io/user2/repo2") + testRepo2(t, "git+ssh://sshuser@try.gitea.io/user2/repo2.git") - assert.Equal(t, int64(2), repo.ID) - assert.Equal(t, int64(2), repo.OwnerID) - } - - test(t, "git+ssh://sshuser@try.gitea.io/user2/repo2") - test(t, "git+ssh://sshuser@try.gitea.io/user2/repo2.git") - - test(t, "git+ssh://try.gitea.io/user2/repo2") - test(t, "git+ssh://try.gitea.io/user2/repo2.git") + testRepo2(t, "git+ssh://try.gitea.io/user2/repo2") + testRepo2(t, "git+ssh://try.gitea.io/user2/repo2.git") }) t.Run("ValidImplicitSshURL", func(t *testing.T) { - test := func(t *testing.T, url string) { - repo, err := GetRepositoryByURL(db.DefaultContext, url) - - assert.NotNil(t, repo) - assert.NoError(t, err) + testRepo2(t, "sshuser@try.gitea.io:user2/repo2") + testRepo2(t, "sshuser@try.gitea.io:user2/repo2.git") + testRelax := func(t *testing.T, url string) { + repo, err := GetRepositoryByURLRelax(db.DefaultContext, url) + require.NoError(t, err) assert.Equal(t, int64(2), repo.ID) assert.Equal(t, int64(2), repo.OwnerID) } - - test(t, "sshuser@try.gitea.io:user2/repo2") - test(t, "sshuser@try.gitea.io:user2/repo2.git") - - test(t, "try.gitea.io:user2/repo2") - test(t, "try.gitea.io:user2/repo2.git") + // TODO: it doesn't seem to be common git ssh URL, should we really support this? + testRelax(t, "try.gitea.io:user2/repo2") + testRelax(t, "try.gitea.io:user2/repo2.git") }) } @@ -199,21 +183,40 @@ func TestComposeSSHCloneURL(t *testing.T) { setting.SSH.Domain = "domain" setting.SSH.Port = 22 setting.Repository.UseCompatSSHURI = false - assert.Equal(t, "git@domain:user/repo.git", ComposeSSHCloneURL("user", "repo")) + assert.Equal(t, "git@domain:user/repo.git", ComposeSSHCloneURL(&user_model.User{Name: "doer"}, "user", "repo")) setting.Repository.UseCompatSSHURI = true - assert.Equal(t, "ssh://git@domain/user/repo.git", ComposeSSHCloneURL("user", "repo")) + assert.Equal(t, "ssh://git@domain/user/repo.git", ComposeSSHCloneURL(&user_model.User{Name: "doer"}, "user", "repo")) // test SSH_DOMAIN while use non-standard SSH port setting.SSH.Port = 123 setting.Repository.UseCompatSSHURI = false - assert.Equal(t, "ssh://git@domain:123/user/repo.git", ComposeSSHCloneURL("user", "repo")) + assert.Equal(t, "ssh://git@domain:123/user/repo.git", ComposeSSHCloneURL(nil, "user", "repo")) setting.Repository.UseCompatSSHURI = true - assert.Equal(t, "ssh://git@domain:123/user/repo.git", ComposeSSHCloneURL("user", "repo")) + assert.Equal(t, "ssh://git@domain:123/user/repo.git", ComposeSSHCloneURL(nil, "user", "repo")) // test IPv6 SSH_DOMAIN setting.Repository.UseCompatSSHURI = false setting.SSH.Domain = "::1" setting.SSH.Port = 22 - assert.Equal(t, "git@[::1]:user/repo.git", ComposeSSHCloneURL("user", "repo")) + assert.Equal(t, "git@[::1]:user/repo.git", ComposeSSHCloneURL(nil, "user", "repo")) + setting.SSH.Port = 123 + assert.Equal(t, "ssh://git@[::1]:123/user/repo.git", ComposeSSHCloneURL(nil, "user", "repo")) + + setting.SSH.User = "(DOER_USERNAME)" + setting.SSH.Domain = "domain" + setting.SSH.Port = 22 + assert.Equal(t, "doer@domain:user/repo.git", ComposeSSHCloneURL(&user_model.User{Name: "doer"}, "user", "repo")) setting.SSH.Port = 123 - assert.Equal(t, "ssh://git@[::1]:123/user/repo.git", ComposeSSHCloneURL("user", "repo")) + assert.Equal(t, "ssh://doer@domain:123/user/repo.git", ComposeSSHCloneURL(&user_model.User{Name: "doer"}, "user", "repo")) +} + +func TestIsUsableRepoName(t *testing.T) { + assert.NoError(t, IsUsableRepoName("a")) + assert.NoError(t, IsUsableRepoName("-1_.")) + assert.NoError(t, IsUsableRepoName(".profile")) + + assert.Error(t, IsUsableRepoName("-")) + assert.Error(t, IsUsableRepoName("🌞")) + assert.Error(t, IsUsableRepoName("the..repo")) + assert.Error(t, IsUsableRepoName("foo.wiki")) + assert.Error(t, IsUsableRepoName("foo.git")) } diff --git a/models/repo/update.go b/models/repo/update.go index e7ca224028..fce357a1ac 100644 --- a/models/repo/update.go +++ b/models/repo/update.go @@ -46,6 +46,12 @@ func UpdateRepositoryCols(ctx context.Context, repo *Repository, cols ...string) return err } +// UpdateRepositoryColsNoAutoTime updates repository's columns and but applies time change automatically +func UpdateRepositoryColsNoAutoTime(ctx context.Context, repo *Repository, cols ...string) error { + _, err := db.GetEngine(ctx).ID(repo.ID).Cols(cols...).NoAutoTime().Update(repo) + return err +} + // ErrReachLimitOfRepo represents a "ReachLimitOfRepo" kind of error. type ErrReachLimitOfRepo struct { Limit int diff --git a/models/repo/wiki.go b/models/repo/wiki.go index b378666a20..4239a815b2 100644 --- a/models/repo/wiki.go +++ b/models/repo/wiki.go @@ -5,6 +5,7 @@ package repo import ( + "context" "fmt" "path/filepath" "strings" @@ -72,8 +73,8 @@ func (err ErrWikiInvalidFileName) Unwrap() error { } // WikiCloneLink returns clone URLs of repository wiki. -func (repo *Repository) WikiCloneLink() *CloneLink { - return repo.cloneLink(true) +func (repo *Repository) WikiCloneLink(ctx context.Context, doer *user_model.User) *CloneLink { + return repo.cloneLink(ctx, doer, repo.Name+".wiki") } // WikiPath returns wiki data path by given user and repository name. diff --git a/models/repo/wiki_test.go b/models/repo/wiki_test.go index 629986f741..0157b7735d 100644 --- a/models/repo/wiki_test.go +++ b/models/repo/wiki_test.go @@ -4,6 +4,7 @@ package repo_test import ( + "context" "path/filepath" "testing" @@ -18,7 +19,7 @@ func TestRepository_WikiCloneLink(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1}) - cloneLink := repo.WikiCloneLink() + cloneLink := repo.WikiCloneLink(context.Background(), nil) assert.Equal(t, "ssh://sshuser@try.gitea.io:3000/user2/repo1.wiki.git", cloneLink.SSH) assert.Equal(t, "https://try.gitea.io/user2/repo1.wiki.git", cloneLink.HTTPS) } diff --git a/models/repo_test.go b/models/repo_test.go index 2a8a4a743e..bcf62237f0 100644 --- a/models/repo_test.go +++ b/models/repo_test.go @@ -7,6 +7,7 @@ import ( "testing" "code.gitea.io/gitea/models/db" + issues_model "code.gitea.io/gitea/models/issues" "code.gitea.io/gitea/models/unittest" "github.com/stretchr/testify/assert" @@ -22,3 +23,16 @@ func TestDoctorUserStarNum(t *testing.T) { assert.NoError(t, DoctorUserStarNum(db.DefaultContext)) } + +func Test_repoStatsCorrectIssueNumComments(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + issue2 := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 2}) + assert.NotNil(t, issue2) + assert.EqualValues(t, 0, issue2.NumComments) // the fixture data is wrong, but we don't fix it here + + assert.NoError(t, repoStatsCorrectIssueNumComments(db.DefaultContext, 2)) + // reload the issue + issue2 = unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 2}) + assert.EqualValues(t, 1, issue2.NumComments) +} diff --git a/models/unittest/fixtures.go b/models/unittest/fixtures.go index 4dde5410d6..fb2d2d0085 100644 --- a/models/unittest/fixtures.go +++ b/models/unittest/fixtures.go @@ -1,97 +1,33 @@ // Copyright 2021 The Gitea Authors. All rights reserved. // SPDX-License-Identifier: MIT -//nolint:forbidigo package unittest import ( "fmt" - "os" - "time" "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/auth/password/hash" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" - "github.com/go-testfixtures/testfixtures/v3" "xorm.io/xorm" "xorm.io/xorm/schemas" ) -var fixturesLoader *testfixtures.Loader - -// GetXORMEngine gets the XORM engine -func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) { - if len(engine) == 1 { - return engine[0] - } - return db.GetEngine(db.DefaultContext).(*xorm.Engine) +type FixturesLoader interface { + Load() error } -// InitFixtures initialize test fixtures for a test database -func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) { - e := GetXORMEngine(engine...) - var fixtureOptionFiles func(*testfixtures.Loader) error - if opts.Dir != "" { - fixtureOptionFiles = testfixtures.Directory(opts.Dir) - } else { - fixtureOptionFiles = testfixtures.Files(opts.Files...) - } - dialect := "unknown" - switch e.Dialect().URI().DBType { - case schemas.POSTGRES: - dialect = "postgres" - case schemas.MYSQL: - dialect = "mysql" - case schemas.MSSQL: - dialect = "mssql" - case schemas.SQLITE: - dialect = "sqlite3" - default: - fmt.Println("Unsupported RDBMS for integration tests") - os.Exit(1) - } - loaderOptions := []func(loader *testfixtures.Loader) error{ - testfixtures.Database(e.DB().DB), - testfixtures.Dialect(dialect), - testfixtures.DangerousSkipTestDatabaseCheck(), - fixtureOptionFiles, - } - - if e.Dialect().URI().DBType == schemas.POSTGRES { - loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences()) - } +var fixturesLoader FixturesLoader - fixturesLoader, err = testfixtures.New(loaderOptions...) - if err != nil { - return err - } - - // register the dummy hash algorithm function used in the test fixtures - _ = hash.Register("dummy", hash.NewDummyHasher) - - setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy") - - return err +// GetXORMEngine gets the XORM engine +func GetXORMEngine() (x *xorm.Engine) { + return db.GetEngine(db.DefaultContext).(*xorm.Engine) } -// LoadFixtures load fixtures for a test database -func LoadFixtures(engine ...*xorm.Engine) error { - e := GetXORMEngine(engine...) - var err error - // (doubt) database transaction conflicts could occur and result in ROLLBACK? just try for a few times. - for i := 0; i < 5; i++ { - if err = fixturesLoader.Load(); err == nil { - break - } - time.Sleep(200 * time.Millisecond) - } - if err != nil { - fmt.Printf("LoadFixtures failed after retries: %v\n", err) - } - // Now if we're running postgres we need to tell it to update the sequences - if e.Dialect().URI().DBType == schemas.POSTGRES { - results, err := e.QueryString(`SELECT 'SELECT SETVAL(' || +func loadFixtureResetSeqPgsql(e *xorm.Engine) error { + results, err := e.QueryString(`SELECT 'SELECT SETVAL(' || quote_literal(quote_ident(PGT.schemaname) || '.' || quote_ident(S.relname)) || ', COALESCE(MAX(' ||quote_ident(C.attname)|| '), 1) ) FROM ' || quote_ident(PGT.schemaname)|| '.'||quote_ident(T.relname)|| ';' @@ -107,22 +43,42 @@ func LoadFixtures(engine ...*xorm.Engine) error { AND D.refobjsubid = C.attnum AND T.relname = PGT.tablename ORDER BY S.relname;`) - if err != nil { - fmt.Printf("Failed to generate sequence update: %v\n", err) - return err - } - for _, r := range results { - for _, value := range r { - _, err = e.Exec(value) - if err != nil { - fmt.Printf("Failed to update sequence: %s Error: %v\n", value, err) - return err - } + if err != nil { + return fmt.Errorf("failed to generate sequence update: %w", err) + } + for _, r := range results { + for _, value := range r { + _, err = e.Exec(value) + if err != nil { + return fmt.Errorf("failed to update sequence: %s, error: %w", value, err) } } } + return nil +} + +// InitFixtures initialize test fixtures for a test database +func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) { + xormEngine := util.IfZero(util.OptionalArg(engine), GetXORMEngine()) + fixturesLoader, err = NewFixturesLoader(xormEngine, opts) + // fixturesLoader = NewFixturesLoaderVendor(xormEngine, opts) + + // register the dummy hash algorithm function used in the test fixtures _ = hash.Register("dummy", hash.NewDummyHasher) setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy") - return err } + +// LoadFixtures load fixtures for a test database +func LoadFixtures() error { + if err := fixturesLoader.Load(); err != nil { + return err + } + // Now if we're running postgres we need to tell it to update the sequences + if GetXORMEngine().Dialect().URI().DBType == schemas.POSTGRES { + if err := loadFixtureResetSeqPgsql(GetXORMEngine()); err != nil { + return err + } + } + return nil +} diff --git a/models/unittest/fixtures_loader.go b/models/unittest/fixtures_loader.go new file mode 100644 index 0000000000..14686caf63 --- /dev/null +++ b/models/unittest/fixtures_loader.go @@ -0,0 +1,201 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unittest + +import ( + "database/sql" + "encoding/hex" + "fmt" + "os" + "path/filepath" + "slices" + "strings" + + "gopkg.in/yaml.v3" + "xorm.io/xorm" + "xorm.io/xorm/schemas" +) + +type fixtureItem struct { + tableName string + tableNameQuoted string + sqlInserts []string + sqlInsertArgs [][]any + + mssqlHasIdentityColumn bool +} + +type fixturesLoaderInternal struct { + db *sql.DB + dbType schemas.DBType + files []string + fixtures map[string]*fixtureItem + quoteObject func(string) string + paramPlaceholder func(idx int) string +} + +func (f *fixturesLoaderInternal) mssqlTableHasIdentityColumn(db *sql.DB, tableName string) (bool, error) { + row := db.QueryRow(`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)`, tableName) + var count int + if err := row.Scan(&count); err != nil { + return false, err + } + return count > 0, nil +} + +func (f *fixturesLoaderInternal) preprocessFixtureRow(row []map[string]any) (err error) { + for _, m := range row { + for k, v := range m { + if s, ok := v.(string); ok { + if strings.HasPrefix(s, "0x") { + if m[k], err = hex.DecodeString(s[2:]); err != nil { + return err + } + } + } + } + } + return nil +} + +func (f *fixturesLoaderInternal) prepareFixtureItem(file string) (_ *fixtureItem, err error) { + fixture := &fixtureItem{} + fixture.tableName, _, _ = strings.Cut(filepath.Base(file), ".") + fixture.tableNameQuoted = f.quoteObject(fixture.tableName) + + if f.dbType == schemas.MSSQL { + fixture.mssqlHasIdentityColumn, err = f.mssqlTableHasIdentityColumn(f.db, fixture.tableName) + if err != nil { + return nil, err + } + } + + data, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("failed to read file %q: %w", file, err) + } + + var rows []map[string]any + if err = yaml.Unmarshal(data, &rows); err != nil { + return nil, fmt.Errorf("failed to unmarshal yaml data from %q: %w", file, err) + } + if err = f.preprocessFixtureRow(rows); err != nil { + return nil, fmt.Errorf("failed to preprocess fixture rows from %q: %w", file, err) + } + + var sqlBuf []byte + var sqlArguments []any + for _, row := range rows { + sqlBuf = append(sqlBuf, fmt.Sprintf("INSERT INTO %s (", fixture.tableNameQuoted)...) + for k, v := range row { + sqlBuf = append(sqlBuf, f.quoteObject(k)...) + sqlBuf = append(sqlBuf, ","...) + sqlArguments = append(sqlArguments, v) + } + sqlBuf = sqlBuf[:len(sqlBuf)-1] + sqlBuf = append(sqlBuf, ") VALUES ("...) + paramIdx := 1 + for range row { + sqlBuf = append(sqlBuf, f.paramPlaceholder(paramIdx)...) + sqlBuf = append(sqlBuf, ',') + paramIdx++ + } + sqlBuf[len(sqlBuf)-1] = ')' + fixture.sqlInserts = append(fixture.sqlInserts, string(sqlBuf)) + fixture.sqlInsertArgs = append(fixture.sqlInsertArgs, slices.Clone(sqlArguments)) + sqlBuf = sqlBuf[:0] + sqlArguments = sqlArguments[:0] + } + return fixture, nil +} + +func (f *fixturesLoaderInternal) loadFixtures(tx *sql.Tx, file string) (err error) { + fixture := f.fixtures[file] + if fixture == nil { + if fixture, err = f.prepareFixtureItem(file); err != nil { + return err + } + f.fixtures[file] = fixture + } + + _, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", fixture.tableNameQuoted)) // sqlite3 doesn't support truncate + if err != nil { + return err + } + + if fixture.mssqlHasIdentityColumn { + _, err = tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", fixture.tableNameQuoted)) + if err != nil { + return err + } + defer func() { _, err = tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", fixture.tableNameQuoted)) }() + } + for i := range fixture.sqlInserts { + _, err = tx.Exec(fixture.sqlInserts[i], fixture.sqlInsertArgs[i]...) + } + if err != nil { + return err + } + return nil +} + +func (f *fixturesLoaderInternal) Load() error { + tx, err := f.db.Begin() + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + for _, file := range f.files { + if err := f.loadFixtures(tx, file); err != nil { + return fmt.Errorf("failed to load fixtures from %s: %w", file, err) + } + } + return tx.Commit() +} + +func FixturesFileFullPaths(dir string, files []string) ([]string, error) { + if files != nil && len(files) == 0 { + return nil, nil // load nothing + } + files = slices.Clone(files) + if len(files) == 0 { + entries, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + for _, e := range entries { + files = append(files, e.Name()) + } + } + for i, file := range files { + if !filepath.IsAbs(file) { + files[i] = filepath.Join(dir, file) + } + } + return files, nil +} + +func NewFixturesLoader(x *xorm.Engine, opts FixturesOptions) (FixturesLoader, error) { + files, err := FixturesFileFullPaths(opts.Dir, opts.Files) + if err != nil { + return nil, fmt.Errorf("failed to get fixtures files: %w", err) + } + f := &fixturesLoaderInternal{db: x.DB().DB, dbType: x.Dialect().URI().DBType, files: files, fixtures: map[string]*fixtureItem{}} + switch f.dbType { + case schemas.SQLITE: + f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) } + f.paramPlaceholder = func(idx int) string { return "?" } + case schemas.POSTGRES: + f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) } + f.paramPlaceholder = func(idx int) string { return fmt.Sprintf(`$%d`, idx) } + case schemas.MYSQL: + f.quoteObject = func(s string) string { return fmt.Sprintf("`%s`", s) } + f.paramPlaceholder = func(idx int) string { return "?" } + case schemas.MSSQL: + f.quoteObject = func(s string) string { return fmt.Sprintf("[%s]", s) } + f.paramPlaceholder = func(idx int) string { return "?" } + } + return f, nil +} diff --git a/models/unittest/fixtures_test.go b/models/unittest/fixtures_test.go new file mode 100644 index 0000000000..a4c55f4e55 --- /dev/null +++ b/models/unittest/fixtures_test.go @@ -0,0 +1,114 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unittest_test + +import ( + "path/filepath" + "testing" + + "code.gitea.io/gitea/models/unittest" + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/test" + + "github.com/stretchr/testify/require" + "xorm.io/xorm" +) + +var NewFixturesLoaderVendor = func(e *xorm.Engine, opts unittest.FixturesOptions) (unittest.FixturesLoader, error) { + return nil, nil +} + +/* +// the old code is kept here in case we are still interested in benchmarking the two implementations +func init() { + NewFixturesLoaderVendor = func(e *xorm.Engine, opts unittest.FixturesOptions) (unittest.FixturesLoader, error) { + return NewFixturesLoaderVendorGoTestfixtures(e, opts) + } +} + +func NewFixturesLoaderVendorGoTestfixtures(e *xorm.Engine, opts unittest.FixturesOptions) (*testfixtures.Loader, error) { + files, err := unittest.FixturesFileFullPaths(opts.Dir, opts.Files) + if err != nil { + return nil, fmt.Errorf("failed to get fixtures files: %w", err) + } + var dialect string + switch e.Dialect().URI().DBType { + case schemas.POSTGRES: + dialect = "postgres" + case schemas.MYSQL: + dialect = "mysql" + case schemas.MSSQL: + dialect = "mssql" + case schemas.SQLITE: + dialect = "sqlite3" + default: + return nil, fmt.Errorf("unsupported RDBMS for integration tests: %q", e.Dialect().URI().DBType) + } + loaderOptions := []func(loader *testfixtures.Loader) error{ + testfixtures.Database(e.DB().DB), + testfixtures.Dialect(dialect), + testfixtures.DangerousSkipTestDatabaseCheck(), + testfixtures.Files(files...), + } + if e.Dialect().URI().DBType == schemas.POSTGRES { + loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences()) + } + return testfixtures.New(loaderOptions...) +} +*/ + +func prepareTestFixturesLoaders(t testing.TB) unittest.FixturesOptions { + _ = user_model.User{} + opts := unittest.FixturesOptions{Dir: filepath.Join(test.SetupGiteaRoot(), "models", "fixtures"), Files: []string{ + "user.yml", + }} + require.NoError(t, unittest.CreateTestEngine(opts)) + return opts +} + +func TestFixturesLoader(t *testing.T) { + opts := prepareTestFixturesLoaders(t) + loaderInternal, err := unittest.NewFixturesLoader(unittest.GetXORMEngine(), opts) + require.NoError(t, err) + loaderVendor, err := NewFixturesLoaderVendor(unittest.GetXORMEngine(), opts) + require.NoError(t, err) + t.Run("Internal", func(t *testing.T) { + require.NoError(t, loaderInternal.Load()) + require.NoError(t, loaderInternal.Load()) + }) + t.Run("Vendor", func(t *testing.T) { + if loaderVendor == nil { + t.Skip() + } + require.NoError(t, loaderVendor.Load()) + require.NoError(t, loaderVendor.Load()) + }) +} + +func BenchmarkFixturesLoader(b *testing.B) { + opts := prepareTestFixturesLoaders(b) + require.NoError(b, unittest.CreateTestEngine(opts)) + loaderInternal, err := unittest.NewFixturesLoader(unittest.GetXORMEngine(), opts) + require.NoError(b, err) + loaderVendor, err := NewFixturesLoaderVendor(unittest.GetXORMEngine(), opts) + require.NoError(b, err) + + // BenchmarkFixturesLoader/Vendor + // BenchmarkFixturesLoader/Vendor-12 1696 719416 ns/op + // BenchmarkFixturesLoader/Internal + // BenchmarkFixturesLoader/Internal-12 1746 670457 ns/op + b.Run("Internal", func(b *testing.B) { + for i := 0; i < b.N; i++ { + require.NoError(b, loaderInternal.Load()) + } + }) + b.Run("Vendor", func(b *testing.B) { + if loaderVendor == nil { + b.Skip() + } + for i := 0; i < b.N; i++ { + require.NoError(b, loaderVendor.Load()) + } + }) +} diff --git a/models/unittest/fscopy.go b/models/unittest/fscopy.go index 4d7ee2151d..b7ba6b7ef5 100644 --- a/models/unittest/fscopy.go +++ b/models/unittest/fscopy.go @@ -11,35 +11,13 @@ import ( "code.gitea.io/gitea/modules/util" ) -// Copy copies file from source to target path. -func Copy(src, dest string) error { - // Gather file information to set back later. - si, err := os.Lstat(src) - if err != nil { - return err - } - - // Handle symbolic link. - if si.Mode()&os.ModeSymlink != 0 { - target, err := os.Readlink(src) - if err != nil { - return err - } - // NOTE: os.Chmod and os.Chtimes don't recognize symbolic link, - // which will lead "no such file or directory" error. - return os.Symlink(target, dest) - } - - return util.CopyFile(src, dest) -} - -// Sync synchronizes the two files. This is skipped if both files +// SyncFile synchronizes the two files. This is skipped if both files // exist and the size, modtime, and mode match. -func Sync(srcPath, destPath string) error { +func SyncFile(srcPath, destPath string) error { dest, err := os.Stat(destPath) if err != nil { if os.IsNotExist(err) { - return Copy(srcPath, destPath) + return util.CopyFile(srcPath, destPath) } return err } @@ -55,7 +33,7 @@ func Sync(srcPath, destPath string) error { return nil } - return Copy(srcPath, destPath) + return util.CopyFile(srcPath, destPath) } // SyncDirs synchronizes files recursively from source to target directory. @@ -66,37 +44,45 @@ func SyncDirs(srcPath, destPath string) error { return err } + // the keep file is used to keep the directory in a git repository, it doesn't need to be synced + // and go-git doesn't work with the ".keep" file (it would report errors like "ref is empty") + const keepFile = ".keep" + // find and delete all untracked files - destFiles, err := util.StatDir(destPath, true) + destFiles, err := util.ListDirRecursively(destPath, &util.ListDirOptions{IncludeDir: true}) if err != nil { return err } for _, destFile := range destFiles { destFilePath := filepath.Join(destPath, destFile) + shouldRemove := filepath.Base(destFilePath) == keepFile if _, err = os.Stat(filepath.Join(srcPath, destFile)); err != nil { if os.IsNotExist(err) { - // if src file does not exist, remove dest file - if err = os.RemoveAll(destFilePath); err != nil { - return err - } + shouldRemove = true } else { return err } } + // if src file does not exist, remove dest file + if shouldRemove { + if err = os.RemoveAll(destFilePath); err != nil { + return err + } + } } // sync src files to dest - srcFiles, err := util.StatDir(srcPath, true) + srcFiles, err := util.ListDirRecursively(srcPath, &util.ListDirOptions{IncludeDir: true}) if err != nil { return err } for _, srcFile := range srcFiles { destFilePath := filepath.Join(destPath, srcFile) - // util.StatDir appends a slash to the directory name + // util.ListDirRecursively appends a slash to the directory name if strings.HasSuffix(srcFile, "/") { err = os.MkdirAll(destFilePath, os.ModePerm) - } else { - err = Sync(filepath.Join(srcPath, srcFile), destFilePath) + } else if filepath.Base(destFilePath) != keepFile { + err = SyncFile(filepath.Join(srcPath, srcFile), destFilePath) } if err != nil { return err diff --git a/models/unittest/reflection.go b/models/unittest/reflection.go index 141fc66b99..bc96a05973 100644 --- a/models/unittest/reflection.go +++ b/models/unittest/reflection.go @@ -4,7 +4,7 @@ package unittest import ( - "log" + "fmt" "reflect" ) @@ -14,7 +14,7 @@ func fieldByName(v reflect.Value, field string) reflect.Value { } f := v.FieldByName(field) if !f.IsValid() { - log.Panicf("can not read %s for %v", field, v) + panic(fmt.Errorf("can not read %s for %v", field, v)) } return f } diff --git a/models/unittest/testdb.go b/models/unittest/testdb.go index 5794d5109e..7a9ca9698d 100644 --- a/models/unittest/testdb.go +++ b/models/unittest/testdb.go @@ -28,16 +28,7 @@ import ( "xorm.io/xorm/names" ) -// giteaRoot a path to the gitea root -var ( - giteaRoot string - fixturesDir string -) - -// FixturesDir returns the fixture directory -func FixturesDir() string { - return fixturesDir -} +var giteaRoot string func fatalTestError(fmtStr string, args ...any) { _, _ = fmt.Fprintf(os.Stderr, fmtStr, args...) @@ -68,6 +59,7 @@ func InitSettings() { _ = hash.Register("dummy", hash.NewDummyHasher) setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy") + setting.InitGiteaEnvVarsForTesting() } // TestOptions represents test options @@ -79,44 +71,20 @@ type TestOptions struct { // MainTest a reusable TestMain(..) function for unit tests that need to use a // test database. Creates the test database, and sets necessary settings. -func MainTest(m *testing.M, testOpts ...*TestOptions) { - searchDir, _ := os.Getwd() - for searchDir != "" { - if _, err := os.Stat(filepath.Join(searchDir, "go.mod")); err == nil { - break // The "go.mod" should be the one for Gitea repository - } - if dir := filepath.Dir(searchDir); dir == searchDir { - searchDir = "" // reaches the root of filesystem - } else { - searchDir = dir - } - } - if searchDir == "" { - panic("The tests should run in a Gitea repository, there should be a 'go.mod' in the root") - } - - giteaRoot = searchDir +func MainTest(m *testing.M, testOptsArg ...*TestOptions) { + testOpts := util.OptionalArg(testOptsArg, &TestOptions{}) + giteaRoot = test.SetupGiteaRoot() setting.CustomPath = filepath.Join(giteaRoot, "custom") InitSettings() - fixturesDir = filepath.Join(giteaRoot, "models", "fixtures") - var opts FixturesOptions - if len(testOpts) == 0 || len(testOpts[0].FixtureFiles) == 0 { - opts.Dir = fixturesDir - } else { - for _, f := range testOpts[0].FixtureFiles { - if len(f) != 0 { - opts.Files = append(opts.Files, filepath.Join(fixturesDir, f)) - } - } - } - - if err := CreateTestEngine(opts); err != nil { + fixturesOpts := FixturesOptions{Dir: filepath.Join(giteaRoot, "models", "fixtures"), Files: testOpts.FixtureFiles} + if err := CreateTestEngine(fixturesOpts); err != nil { fatalTestError("Error creating test engine: %v\n", err) } setting.IsInTesting = true setting.AppURL = "https://try.gitea.io/" + setting.Domain = "try.gitea.io" setting.RunUser = "runuser" setting.SSH.User = "sshuser" setting.SSH.BuiltinServerUser = "builtinuser" @@ -172,16 +140,16 @@ func MainTest(m *testing.M, testOpts ...*TestOptions) { fatalTestError("git.Init: %v\n", err) } - if len(testOpts) > 0 && testOpts[0].SetUp != nil { - if err := testOpts[0].SetUp(); err != nil { + if testOpts.SetUp != nil { + if err := testOpts.SetUp(); err != nil { fatalTestError("set up failed: %v\n", err) } } exitStatus := m.Run() - if len(testOpts) > 0 && testOpts[0].TearDown != nil { - if err := testOpts[0].TearDown(); err != nil { + if testOpts.TearDown != nil { + if err := testOpts.TearDown(); err != nil { fatalTestError("tear down failed: %v\n", err) } } @@ -206,7 +174,7 @@ func CreateTestEngine(opts FixturesOptions) error { x, err := xorm.NewEngine("sqlite3", "file::memory:?cache=shared&_txlock=immediate") if err != nil { if strings.Contains(err.Error(), "unknown driver") { - return fmt.Errorf(`sqlite3 requires: import _ "github.com/mattn/go-sqlite3" or -tags sqlite,sqlite_unlock_notify%s%w`, "\n", err) + return fmt.Errorf(`sqlite3 requires: -tags sqlite,sqlite_unlock_notify%s%w`, "\n", err) } return err } diff --git a/models/unittest/unit_tests.go b/models/unittest/unit_tests.go index 4ac858e04e..1c5595aef8 100644 --- a/models/unittest/unit_tests.go +++ b/models/unittest/unit_tests.go @@ -4,13 +4,17 @@ package unittest import ( + "fmt" "math" + "os" + "strings" "code.gitea.io/gitea/models/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "xorm.io/builder" + "xorm.io/xorm" ) // Code in this file is mainly used by unittest.CheckConsistencyFor, which is not in the unit test for various reasons. @@ -51,22 +55,23 @@ func whereOrderConditions(e db.Engine, conditions []any) db.Engine { return e.OrderBy(orderBy) } -// LoadBeanIfExists loads beans from fixture database if exist -func LoadBeanIfExists(bean any, conditions ...any) (bool, error) { +func getBeanIfExists(bean any, conditions ...any) (bool, error) { e := db.GetEngine(db.DefaultContext) return whereOrderConditions(e, conditions).Get(bean) } -// BeanExists for testing, check if a bean exists -func BeanExists(t assert.TestingT, bean any, conditions ...any) bool { - exists, err := LoadBeanIfExists(bean, conditions...) - assert.NoError(t, err) - return exists +func GetBean[T any](t require.TestingT, bean T, conditions ...any) (ret T) { + exists, err := getBeanIfExists(bean, conditions...) + require.NoError(t, err) + if exists { + return bean + } + return ret } // AssertExistsAndLoadBean assert that a bean exists and load it from the test database func AssertExistsAndLoadBean[T any](t require.TestingT, bean T, conditions ...any) T { - exists, err := LoadBeanIfExists(bean, conditions...) + exists, err := getBeanIfExists(bean, conditions...) require.NoError(t, err) require.True(t, exists, "Expected to find %+v (of type %T, with conditions %+v), but did not", @@ -112,25 +117,11 @@ func GetCount(t assert.TestingT, bean any, conditions ...any) int { // AssertNotExistsBean assert that a bean does not exist in the test database func AssertNotExistsBean(t assert.TestingT, bean any, conditions ...any) { - exists, err := LoadBeanIfExists(bean, conditions...) + exists, err := getBeanIfExists(bean, conditions...) assert.NoError(t, err) assert.False(t, exists) } -// AssertExistsIf asserts that a bean exists or does not exist, depending on -// what is expected. -func AssertExistsIf(t assert.TestingT, expected bool, bean any, conditions ...any) { - exists, err := LoadBeanIfExists(bean, conditions...) - assert.NoError(t, err) - assert.Equal(t, expected, exists) -} - -// AssertSuccessfulInsert assert that beans is successfully inserted -func AssertSuccessfulInsert(t assert.TestingT, beans ...any) { - err := db.Insert(db.DefaultContext, beans...) - assert.NoError(t, err) -} - // AssertCount assert the count of a bean func AssertCount(t assert.TestingT, bean, expected any) bool { return assert.EqualValues(t, expected, GetCount(t, bean)) @@ -155,3 +146,39 @@ func AssertCountByCond(t assert.TestingT, tableName string, cond builder.Cond, e return assert.EqualValues(t, expected, GetCountByCond(t, tableName, cond), "Failed consistency test, the counted bean (of table %s) was %+v", tableName, cond) } + +// DumpQueryResult dumps the result of a query for debugging purpose +func DumpQueryResult(t require.TestingT, sqlOrBean any, sqlArgs ...any) { + x := db.GetEngine(db.DefaultContext).(*xorm.Engine) + goDB := x.DB().DB + sql, ok := sqlOrBean.(string) + if !ok { + sql = fmt.Sprintf("SELECT * FROM %s", db.TableName(sqlOrBean)) + } else if !strings.Contains(sql, " ") { + sql = fmt.Sprintf("SELECT * FROM %s", sql) + } + rows, err := goDB.Query(sql, sqlArgs...) + require.NoError(t, err) + defer rows.Close() + columns, err := rows.Columns() + require.NoError(t, err) + + _, _ = fmt.Fprintf(os.Stdout, "====== DumpQueryResult: %s ======\n", sql) + idx := 0 + for rows.Next() { + row := make([]any, len(columns)) + rowPointers := make([]any, len(columns)) + for i := range row { + rowPointers[i] = &row[i] + } + require.NoError(t, rows.Scan(rowPointers...)) + _, _ = fmt.Fprintf(os.Stdout, "- # row[%d]\n", idx) + for i, col := range columns { + _, _ = fmt.Fprintf(os.Stdout, " %s: %v\n", col, row[i]) + } + idx++ + } + if idx == 0 { + _, _ = fmt.Fprintf(os.Stdout, "(no result, columns: %s)\n", strings.Join(columns, ", ")) + } +} diff --git a/models/user/email_address.go b/models/user/email_address.go index 5c04909ed7..74ba5f617a 100644 --- a/models/user/email_address.go +++ b/models/user/email_address.go @@ -357,8 +357,8 @@ func VerifyActiveEmailCode(ctx context.Context, code, email string) *EmailAddres if user := GetVerifyUser(ctx, code); user != nil { // time limit code prefix := code[:base.TimeLimitCodeLength] - data := fmt.Sprintf("%d%s%s%s%s", user.ID, email, user.LowerName, user.Passwd, user.Rands) - + opts := &TimeLimitCodeOptions{Purpose: TimeLimitCodeActivateEmail, NewEmail: email} + data := makeTimeLimitCodeHashData(opts, user) if base.VerifyTimeLimitCode(time.Now(), data, setting.Service.ActiveCodeLives, prefix) { emailAddress := &EmailAddress{UID: user.ID, Email: email} if has, _ := db.GetEngine(ctx).Get(emailAddress); has { @@ -486,10 +486,10 @@ func ActivateUserEmail(ctx context.Context, userID int64, email string, activate // Activate/deactivate a user's primary email address and account if addr.IsPrimary { - user, exist, err := db.Get[User](ctx, builder.Eq{"id": userID, "email": email}) + user, exist, err := db.Get[User](ctx, builder.Eq{"id": userID}) if err != nil { return err - } else if !exist { + } else if !exist || !strings.EqualFold(user.Email, email) { return fmt.Errorf("no user with ID: %d and Email: %s", userID, email) } diff --git a/models/user/openid_test.go b/models/user/openid_test.go index 27e6edd1e0..708af9e653 100644 --- a/models/user/openid_test.go +++ b/models/user/openid_test.go @@ -11,6 +11,7 @@ import ( user_model "code.gitea.io/gitea/models/user" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetUserOpenIDs(t *testing.T) { @@ -34,30 +35,23 @@ func TestGetUserOpenIDs(t *testing.T) { func TestToggleUserOpenIDVisibility(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) oids, err := user_model.GetUserOpenIDs(db.DefaultContext, int64(2)) - if !assert.NoError(t, err) || !assert.Len(t, oids, 1) { - return - } + require.NoError(t, err) + require.Len(t, oids, 1) assert.True(t, oids[0].Show) err = user_model.ToggleUserOpenIDVisibility(db.DefaultContext, oids[0].ID) - if !assert.NoError(t, err) { - return - } + require.NoError(t, err) oids, err = user_model.GetUserOpenIDs(db.DefaultContext, int64(2)) - if !assert.NoError(t, err) || !assert.Len(t, oids, 1) { - return - } + require.NoError(t, err) + require.Len(t, oids, 1) + assert.False(t, oids[0].Show) err = user_model.ToggleUserOpenIDVisibility(db.DefaultContext, oids[0].ID) - if !assert.NoError(t, err) { - return - } + require.NoError(t, err) oids, err = user_model.GetUserOpenIDs(db.DefaultContext, int64(2)) - if !assert.NoError(t, err) { - return - } + require.NoError(t, err) if assert.Len(t, oids, 1) { assert.True(t, oids[0].Show) } diff --git a/models/user/search.go b/models/user/search.go index 6af3389237..85915f4020 100644 --- a/models/user/search.go +++ b/models/user/search.go @@ -39,8 +39,6 @@ type SearchUserOptions struct { IsTwoFactorEnabled optional.Option[bool] IsProhibitLogin optional.Option[bool] IncludeReserved bool - - ExtraParamStrings map[string]string } func (opts *SearchUserOptions) toSearchQueryBase(ctx context.Context) *xorm.Session { diff --git a/models/user/user.go b/models/user/user.go index 72caafc3ba..19879fbcc7 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -181,7 +181,8 @@ func (u *User) BeforeUpdate() { u.MaxRepoCreation = -1 } - // Organization does not need email + // FIXME: this email doesn't need to be in lowercase, because the emails are mainly managed by the email table with lower_email field + // This trick could be removed in new releases to display the user inputed email as-is. u.Email = strings.ToLower(u.Email) if !u.IsOrganization() { if len(u.AvatarEmail) == 0 { @@ -190,9 +191,9 @@ func (u *User) BeforeUpdate() { } u.LowerName = strings.ToLower(u.Name) - u.Location = base.TruncateString(u.Location, 255) - u.Website = base.TruncateString(u.Website, 255) - u.Description = base.TruncateString(u.Description, 255) + u.Location = util.TruncateRunes(u.Location, 255) + u.Website = util.TruncateRunes(u.Website, 255) + u.Description = util.TruncateRunes(u.Description, 255) } // AfterLoad is invoked from XORM after filling all the fields of this object. @@ -310,17 +311,6 @@ func (u *User) OrganisationLink() string { return setting.AppSubURL + "/org/" + url.PathEscape(u.Name) } -// GenerateEmailActivateCode generates an activate code based on user information and given e-mail. -func (u *User) GenerateEmailActivateCode(email string) string { - code := base.CreateTimeLimitCode( - fmt.Sprintf("%d%s%s%s%s", u.ID, email, u.LowerName, u.Passwd, u.Rands), - setting.Service.ActiveCodeLives, time.Now(), nil) - - // Add tail hex username - code += hex.EncodeToString([]byte(u.LowerName)) - return code -} - // GetUserFollowers returns range of user's followers. func GetUserFollowers(ctx context.Context, u, viewer *User, listOptions db.ListOptions) ([]*User, int64, error) { sess := db.GetEngine(ctx). @@ -501,9 +491,9 @@ func (u *User) GitName() string { // ShortName ellipses username to length func (u *User) ShortName(length int) string { if setting.UI.DefaultShowFullName && len(u.FullName) > 0 { - return base.EllipsisString(u.FullName, length) + return util.EllipsisDisplayString(u.FullName, length) } - return base.EllipsisString(u.Name, length) + return util.EllipsisDisplayString(u.Name, length) } // IsMailable checks if a user is eligible @@ -863,12 +853,38 @@ func GetVerifyUser(ctx context.Context, code string) (user *User) { return nil } -// VerifyUserActiveCode verifies active code when active account -func VerifyUserActiveCode(ctx context.Context, code string) (user *User) { +type TimeLimitCodePurpose string + +const ( + TimeLimitCodeActivateAccount TimeLimitCodePurpose = "activate_account" + TimeLimitCodeActivateEmail TimeLimitCodePurpose = "activate_email" + TimeLimitCodeResetPassword TimeLimitCodePurpose = "reset_password" +) + +type TimeLimitCodeOptions struct { + Purpose TimeLimitCodePurpose + NewEmail string +} + +func makeTimeLimitCodeHashData(opts *TimeLimitCodeOptions, u *User) string { + return fmt.Sprintf("%s|%d|%s|%s|%s|%s", opts.Purpose, u.ID, strings.ToLower(util.IfZero(opts.NewEmail, u.Email)), u.LowerName, u.Passwd, u.Rands) +} + +// GenerateUserTimeLimitCode generates a time-limit code based on user information and given e-mail. +// TODO: need to use cache or db to store it to make sure a code can only be consumed once +func GenerateUserTimeLimitCode(opts *TimeLimitCodeOptions, u *User) string { + data := makeTimeLimitCodeHashData(opts, u) + code := base.CreateTimeLimitCode(data, setting.Service.ActiveCodeLives, time.Now(), nil) + code += hex.EncodeToString([]byte(u.LowerName)) // Add tail hex username + return code +} + +// VerifyUserTimeLimitCode verifies the time-limit code +func VerifyUserTimeLimitCode(ctx context.Context, opts *TimeLimitCodeOptions, code string) (user *User) { if user = GetVerifyUser(ctx, code); user != nil { // time limit code prefix := code[:base.TimeLimitCodeLength] - data := fmt.Sprintf("%d%s%s%s%s", user.ID, user.Email, user.LowerName, user.Passwd, user.Rands) + data := makeTimeLimitCodeHashData(opts, user) if base.VerifyTimeLimitCode(time.Now(), data, setting.Service.ActiveCodeLives, prefix) { return user } diff --git a/models/user/user_list.go b/models/user/user_list.go new file mode 100644 index 0000000000..c66d59f0d9 --- /dev/null +++ b/models/user/user_list.go @@ -0,0 +1,47 @@ +// Copyright 2025 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package user + +import ( + "context" + + "code.gitea.io/gitea/models/db" +) + +func GetUsersMapByIDs(ctx context.Context, userIDs []int64) (map[int64]*User, error) { + userMaps := make(map[int64]*User, len(userIDs)) + left := len(userIDs) + for left > 0 { + limit := db.DefaultMaxInSize + if left < limit { + limit = left + } + err := db.GetEngine(ctx). + In("id", userIDs[:limit]). + Find(&userMaps) + if err != nil { + return nil, err + } + left -= limit + userIDs = userIDs[limit:] + } + return userMaps, nil +} + +func GetPossibleUserFromMap(userID int64, usererMaps map[int64]*User) *User { + switch userID { + case GhostUserID: + return NewGhostUser() + case ActionsUserID: + return NewActionsUser() + case 0: + return nil + default: + user, ok := usererMaps[userID] + if !ok { + return NewGhostUser() + } + return user + } +} diff --git a/models/webhook/webhook.go b/models/webhook/webhook.go index 894357e36a..97ad373027 100644 --- a/models/webhook/webhook.go +++ b/models/webhook/webhook.go @@ -167,186 +167,39 @@ func (w *Webhook) UpdateEvent() error { return err } -// HasCreateEvent returns true if hook enabled create event. -func (w *Webhook) HasCreateEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.Create) -} - -// HasDeleteEvent returns true if hook enabled delete event. -func (w *Webhook) HasDeleteEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.Delete) -} - -// HasForkEvent returns true if hook enabled fork event. -func (w *Webhook) HasForkEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.Fork) -} - -// HasIssuesEvent returns true if hook enabled issues event. -func (w *Webhook) HasIssuesEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.Issues) -} - -// HasIssuesAssignEvent returns true if hook enabled issues assign event. -func (w *Webhook) HasIssuesAssignEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.IssueAssign) -} - -// HasIssuesLabelEvent returns true if hook enabled issues label event. -func (w *Webhook) HasIssuesLabelEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.IssueLabel) -} - -// HasIssuesMilestoneEvent returns true if hook enabled issues milestone event. -func (w *Webhook) HasIssuesMilestoneEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.IssueMilestone) -} - -// HasIssueCommentEvent returns true if hook enabled issue_comment event. -func (w *Webhook) HasIssueCommentEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.IssueComment) -} - -// HasPushEvent returns true if hook enabled push event. -func (w *Webhook) HasPushEvent() bool { - return w.PushOnly || w.SendEverything || - (w.ChooseEvents && w.HookEvents.Push) -} - -// HasPullRequestEvent returns true if hook enabled pull request event. -func (w *Webhook) HasPullRequestEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.PullRequest) -} - -// HasPullRequestAssignEvent returns true if hook enabled pull request assign event. -func (w *Webhook) HasPullRequestAssignEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.PullRequestAssign) -} - -// HasPullRequestLabelEvent returns true if hook enabled pull request label event. -func (w *Webhook) HasPullRequestLabelEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.PullRequestLabel) -} - -// HasPullRequestMilestoneEvent returns true if hook enabled pull request milestone event. -func (w *Webhook) HasPullRequestMilestoneEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.PullRequestMilestone) -} - -// HasPullRequestCommentEvent returns true if hook enabled pull_request_comment event. -func (w *Webhook) HasPullRequestCommentEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.PullRequestComment) -} - -// HasPullRequestApprovedEvent returns true if hook enabled pull request review event. -func (w *Webhook) HasPullRequestApprovedEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.PullRequestReview) -} - -// HasPullRequestRejectedEvent returns true if hook enabled pull request review event. -func (w *Webhook) HasPullRequestRejectedEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.PullRequestReview) -} - -// HasPullRequestReviewCommentEvent returns true if hook enabled pull request review event. -func (w *Webhook) HasPullRequestReviewCommentEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.PullRequestReview) -} - -// HasPullRequestSyncEvent returns true if hook enabled pull request sync event. -func (w *Webhook) HasPullRequestSyncEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.PullRequestSync) -} - -// HasWikiEvent returns true if hook enabled wiki event. -func (w *Webhook) HasWikiEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvent.Wiki) -} - -// HasReleaseEvent returns if hook enabled release event. -func (w *Webhook) HasReleaseEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.Release) -} - -// HasRepositoryEvent returns if hook enabled repository event. -func (w *Webhook) HasRepositoryEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.Repository) -} - -// HasPackageEvent returns if hook enabled package event. -func (w *Webhook) HasPackageEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.Package) -} - -// HasPullRequestReviewRequestEvent returns true if hook enabled pull request review request event. -func (w *Webhook) HasPullRequestReviewRequestEvent() bool { - return w.SendEverything || - (w.ChooseEvents && w.HookEvents.PullRequestReviewRequest) -} - -// EventCheckers returns event checkers -func (w *Webhook) EventCheckers() []struct { - Has func() bool - Type webhook_module.HookEventType -} { - return []struct { - Has func() bool - Type webhook_module.HookEventType - }{ - {w.HasCreateEvent, webhook_module.HookEventCreate}, - {w.HasDeleteEvent, webhook_module.HookEventDelete}, - {w.HasForkEvent, webhook_module.HookEventFork}, - {w.HasPushEvent, webhook_module.HookEventPush}, - {w.HasIssuesEvent, webhook_module.HookEventIssues}, - {w.HasIssuesAssignEvent, webhook_module.HookEventIssueAssign}, - {w.HasIssuesLabelEvent, webhook_module.HookEventIssueLabel}, - {w.HasIssuesMilestoneEvent, webhook_module.HookEventIssueMilestone}, - {w.HasIssueCommentEvent, webhook_module.HookEventIssueComment}, - {w.HasPullRequestEvent, webhook_module.HookEventPullRequest}, - {w.HasPullRequestAssignEvent, webhook_module.HookEventPullRequestAssign}, - {w.HasPullRequestLabelEvent, webhook_module.HookEventPullRequestLabel}, - {w.HasPullRequestMilestoneEvent, webhook_module.HookEventPullRequestMilestone}, - {w.HasPullRequestCommentEvent, webhook_module.HookEventPullRequestComment}, - {w.HasPullRequestApprovedEvent, webhook_module.HookEventPullRequestReviewApproved}, - {w.HasPullRequestRejectedEvent, webhook_module.HookEventPullRequestReviewRejected}, - {w.HasPullRequestCommentEvent, webhook_module.HookEventPullRequestReviewComment}, - {w.HasPullRequestSyncEvent, webhook_module.HookEventPullRequestSync}, - {w.HasWikiEvent, webhook_module.HookEventWiki}, - {w.HasRepositoryEvent, webhook_module.HookEventRepository}, - {w.HasReleaseEvent, webhook_module.HookEventRelease}, - {w.HasPackageEvent, webhook_module.HookEventPackage}, - {w.HasPullRequestReviewRequestEvent, webhook_module.HookEventPullRequestReviewRequest}, +func (w *Webhook) HasEvent(evt webhook_module.HookEventType) bool { + if w.SendEverything { + return true } + if w.PushOnly { + return evt == webhook_module.HookEventPush + } + checkEvt := evt + switch evt { + case webhook_module.HookEventPullRequestReviewApproved, webhook_module.HookEventPullRequestReviewRejected, webhook_module.HookEventPullRequestReviewComment: + checkEvt = webhook_module.HookEventPullRequestReview + } + return w.HookEvents[checkEvt] } // EventsArray returns an array of hook events func (w *Webhook) EventsArray() []string { - events := make([]string, 0, 7) + if w.SendEverything { + events := make([]string, 0, len(webhook_module.AllEvents())) + for _, evt := range webhook_module.AllEvents() { + events = append(events, string(evt)) + } + return events + } + + if w.PushOnly { + return []string{string(webhook_module.HookEventPush)} + } - for _, c := range w.EventCheckers() { - if c.Has() { - events = append(events, string(c.Type)) + events := make([]string, 0, len(w.HookEvents)) + for event, enabled := range w.HookEvents { + if enabled { + events = append(events, string(event)) } } return events diff --git a/models/webhook/webhook_system.go b/models/webhook/webhook_system.go index a2a9ee321a..58d9d4a5c1 100644 --- a/models/webhook/webhook_system.go +++ b/models/webhook/webhook_system.go @@ -11,6 +11,19 @@ import ( "code.gitea.io/gitea/modules/optional" ) +// GetSystemOrDefaultWebhooks returns webhooks by given argument or all if argument is missing. +func GetSystemOrDefaultWebhooks(ctx context.Context, isSystemWebhook optional.Option[bool]) ([]*Webhook, error) { + webhooks := make([]*Webhook, 0, 5) + if !isSystemWebhook.Has() { + return webhooks, db.GetEngine(ctx).Where("repo_id=? AND owner_id=?", 0, 0). + Find(&webhooks) + } + + return webhooks, db.GetEngine(ctx). + Where("repo_id=? AND owner_id=? AND is_system_webhook=?", 0, 0, isSystemWebhook.Value()). + Find(&webhooks) +} + // GetDefaultWebhooks returns all admin-default webhooks. func GetDefaultWebhooks(ctx context.Context) ([]*Webhook, error) { webhooks := make([]*Webhook, 0, 5) diff --git a/models/webhook/webhook_system_test.go b/models/webhook/webhook_system_test.go new file mode 100644 index 0000000000..96157ed9c9 --- /dev/null +++ b/models/webhook/webhook_system_test.go @@ -0,0 +1,37 @@ +// Copyright 2017 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package webhook + +import ( + "testing" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/unittest" + "code.gitea.io/gitea/modules/optional" + + "github.com/stretchr/testify/assert" +) + +func TestGetSystemOrDefaultWebhooks(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + hooks, err := GetSystemOrDefaultWebhooks(db.DefaultContext, optional.None[bool]()) + assert.NoError(t, err) + if assert.Len(t, hooks, 2) { + assert.Equal(t, int64(5), hooks[0].ID) + assert.Equal(t, int64(6), hooks[1].ID) + } + + hooks, err = GetSystemOrDefaultWebhooks(db.DefaultContext, optional.Some(true)) + assert.NoError(t, err) + if assert.Len(t, hooks, 1) { + assert.Equal(t, int64(5), hooks[0].ID) + } + + hooks, err = GetSystemOrDefaultWebhooks(db.DefaultContext, optional.Some(false)) + assert.NoError(t, err) + if assert.Len(t, hooks, 1) { + assert.Equal(t, int64(6), hooks[0].ID) + } +} diff --git a/models/webhook/webhook_test.go b/models/webhook/webhook_test.go index c6c3f40d46..ee53d6da92 100644 --- a/models/webhook/webhook_test.go +++ b/models/webhook/webhook_test.go @@ -54,9 +54,9 @@ func TestWebhook_UpdateEvent(t *testing.T) { SendEverything: false, ChooseEvents: false, HookEvents: webhook_module.HookEvents{ - Create: false, - Push: true, - PullRequest: false, + webhook_module.HookEventCreate: false, + webhook_module.HookEventPush: true, + webhook_module.HookEventPullRequest: false, }, } webhook.HookEvent = hookEvent @@ -68,13 +68,13 @@ func TestWebhook_UpdateEvent(t *testing.T) { } func TestWebhook_EventsArray(t *testing.T) { - assert.Equal(t, []string{ + assert.EqualValues(t, []string{ "create", "delete", "fork", "push", "issues", "issue_assign", "issue_label", "issue_milestone", "issue_comment", "pull_request", "pull_request_assign", "pull_request_label", "pull_request_milestone", "pull_request_comment", "pull_request_review_approved", "pull_request_review_rejected", - "pull_request_review_comment", "pull_request_sync", "wiki", "repository", "release", - "package", "pull_request_review_request", + "pull_request_review_comment", "pull_request_sync", "pull_request_review_request", "wiki", "repository", "release", + "package", "status", }, (&Webhook{ HookEvent: &webhook_module.HookEvent{SendEverything: true}, |