]> source.dussan.org Git - gitea.git/commitdiff
Add container.FilterSlice function (#30339)
authoroliverpool <3864879+oliverpool@users.noreply.github.com>
Tue, 9 Apr 2024 12:27:30 +0000 (14:27 +0200)
committerGitHub <noreply@github.com>
Tue, 9 Apr 2024 12:27:30 +0000 (20:27 +0800)
Many places have the following logic:
```go
func (jobs ActionJobList) GetRunIDs() []int64 {
ids := make(container.Set[int64], len(jobs))
for _, j := range jobs {
if j.RunID == 0 {
continue
}
ids.Add(j.RunID)
}
return ids.Values()
}
```

this introduces a `container.FilterMapUnique` function, which reduces
the code above to:
```go
func (jobs ActionJobList) GetRunIDs() []int64 {
return container.FilterMapUnique(jobs, func(j *ActionRunJob) (int64, bool) {
return j.RunID, j.RunID != 0
})
}
```

16 files changed:
models/actions/run_job_list.go
models/actions/run_list.go
models/actions/runner_list.go
models/actions/schedule_list.go
models/actions/schedule_spec_list.go
models/actions/task_list.go
models/activities/action_list.go
models/git/branch_list.go
models/issues/comment.go
models/issues/comment_list.go
models/issues/issue_list.go
models/issues/reaction.go
models/issues/review_list.go
models/repo/repo_list.go
modules/container/filter.go [new file with mode: 0644]
modules/container/filter_test.go [new file with mode: 0644]

index 6ea6cb9d3b30d9f459646ffe88234915aef67b31..6c5d3b3252ebfb513f3762c19c51e2b89f810c37 100644 (file)
@@ -16,14 +16,9 @@ import (
 type ActionJobList []*ActionRunJob
 
 func (jobs ActionJobList) GetRunIDs() []int64 {
-       ids := make(container.Set[int64], len(jobs))
-       for _, j := range jobs {
-               if j.RunID == 0 {
-                       continue
-               }
-               ids.Add(j.RunID)
-       }
-       return ids.Values()
+       return container.FilterSlice(jobs, func(j *ActionRunJob) (int64, bool) {
+               return j.RunID, j.RunID != 0
+       })
 }
 
 func (jobs ActionJobList) LoadRuns(ctx context.Context, withRepo bool) error {
index 388bfc4f86f94f3953140d4125db2cc0385a2311..4046c7d3694365e8987e4a08d6c155ec33e005bf 100644 (file)
@@ -19,19 +19,15 @@ type RunList []*ActionRun
 
 // GetUserIDs returns a slice of user's id
 func (runs RunList) GetUserIDs() []int64 {
-       ids := make(container.Set[int64], len(runs))
-       for _, run := range runs {
-               ids.Add(run.TriggerUserID)
-       }
-       return ids.Values()
+       return container.FilterSlice(runs, func(run *ActionRun) (int64, bool) {
+               return run.TriggerUserID, true
+       })
 }
 
 func (runs RunList) GetRepoIDs() []int64 {
-       ids := make(container.Set[int64], len(runs))
-       for _, run := range runs {
-               ids.Add(run.RepoID)
-       }
-       return ids.Values()
+       return container.FilterSlice(runs, func(run *ActionRun) (int64, bool) {
+               return run.RepoID, true
+       })
 }
 
 func (runs RunList) LoadTriggerUser(ctx context.Context) error {
index 87f0886b470d5fb7671ec02dafc3da7b53ede763..5ce69e07acfa6dbb8151fb624e5356138980d29c 100644 (file)
@@ -16,14 +16,9 @@ type RunnerList []*ActionRunner
 
 // GetUserIDs returns a slice of user's id
 func (runners RunnerList) GetUserIDs() []int64 {
-       ids := make(container.Set[int64], len(runners))
-       for _, runner := range runners {
-               if runner.OwnerID == 0 {
-                       continue
-               }
-               ids.Add(runner.OwnerID)
-       }
-       return ids.Values()
+       return container.FilterSlice(runners, func(runner *ActionRunner) (int64, bool) {
+               return runner.OwnerID, runner.OwnerID != 0
+       })
 }
 
 func (runners RunnerList) LoadOwners(ctx context.Context) error {
index b806550b87d81387d196ee7f839829b61be1ab15..1d35adc420822157e42d80023a5308fb767f0e2c 100644 (file)
@@ -18,19 +18,15 @@ type ScheduleList []*ActionSchedule
 
 // GetUserIDs returns a slice of user's id
 func (schedules ScheduleList) GetUserIDs() []int64 {
-       ids := make(container.Set[int64], len(schedules))
-       for _, schedule := range schedules {
-               ids.Add(schedule.TriggerUserID)
-       }
-       return ids.Values()
+       return container.FilterSlice(schedules, func(schedule *ActionSchedule) (int64, bool) {
+               return schedule.TriggerUserID, true
+       })
 }
 
 func (schedules ScheduleList) GetRepoIDs() []int64 {
-       ids := make(container.Set[int64], len(schedules))
-       for _, schedule := range schedules {
-               ids.Add(schedule.RepoID)
-       }
-       return ids.Values()
+       return container.FilterSlice(schedules, func(schedule *ActionSchedule) (int64, bool) {
+               return schedule.RepoID, true
+       })
 }
 
 func (schedules ScheduleList) LoadTriggerUser(ctx context.Context) error {
index e9ae268a6e4ad7260e213b2dd3405a4a89fd0cd4..f7dac72f8b38e801f5564b72d826069cfccc2e97 100644 (file)
@@ -16,11 +16,9 @@ import (
 type SpecList []*ActionScheduleSpec
 
 func (specs SpecList) GetScheduleIDs() []int64 {
-       ids := make(container.Set[int64], len(specs))
-       for _, spec := range specs {
-               ids.Add(spec.ScheduleID)
-       }
-       return ids.Values()
+       return container.FilterSlice(specs, func(spec *ActionScheduleSpec) (int64, bool) {
+               return spec.ScheduleID, true
+       })
 }
 
 func (specs SpecList) LoadSchedules(ctx context.Context) error {
@@ -46,11 +44,9 @@ func (specs SpecList) LoadSchedules(ctx context.Context) error {
 }
 
 func (specs SpecList) GetRepoIDs() []int64 {
-       ids := make(container.Set[int64], len(specs))
-       for _, spec := range specs {
-               ids.Add(spec.RepoID)
-       }
-       return ids.Values()
+       return container.FilterSlice(specs, func(spec *ActionScheduleSpec) (int64, bool) {
+               return spec.RepoID, true
+       })
 }
 
 func (specs SpecList) LoadRepos(ctx context.Context) error {
index b07d00b8dbd6339906d943cd72867ec12816d5e4..5e17f914417e7dc03ece6a4eeba34b57878b855b 100644 (file)
@@ -16,14 +16,9 @@ import (
 type TaskList []*ActionTask
 
 func (tasks TaskList) GetJobIDs() []int64 {
-       ids := make(container.Set[int64], len(tasks))
-       for _, t := range tasks {
-               if t.JobID == 0 {
-                       continue
-               }
-               ids.Add(t.JobID)
-       }
-       return ids.Values()
+       return container.FilterSlice(tasks, func(t *ActionTask) (int64, bool) {
+               return t.JobID, t.JobID != 0
+       })
 }
 
 func (tasks TaskList) LoadJobs(ctx context.Context) error {
index fdf0f35d4f457393bbc5d455c7700af7e672c8b0..6e23b173b5abcddb259a52d3e9eae12d4dc38500 100644 (file)
@@ -22,11 +22,9 @@ import (
 type ActionList []*Action
 
 func (actions ActionList) getUserIDs() []int64 {
-       userIDs := make(container.Set[int64], len(actions))
-       for _, action := range actions {
-               userIDs.Add(action.ActUserID)
-       }
-       return userIDs.Values()
+       return container.FilterSlice(actions, func(action *Action) (int64, bool) {
+               return action.ActUserID, true
+       })
 }
 
 func (actions ActionList) LoadActUsers(ctx context.Context) (map[int64]*user_model.User, error) {
@@ -50,11 +48,9 @@ func (actions ActionList) LoadActUsers(ctx context.Context) (map[int64]*user_mod
 }
 
 func (actions ActionList) getRepoIDs() []int64 {
-       repoIDs := make(container.Set[int64], len(actions))
-       for _, action := range actions {
-               repoIDs.Add(action.RepoID)
-       }
-       return repoIDs.Values()
+       return container.FilterSlice(actions, func(action *Action) (int64, bool) {
+               return action.RepoID, true
+       })
 }
 
 func (actions ActionList) LoadRepositories(ctx context.Context) error {
@@ -80,18 +76,16 @@ func (actions ActionList) loadRepoOwner(ctx context.Context, userMap map[int64]*
                userMap = make(map[int64]*user_model.User)
        }
 
-       userSet := make(container.Set[int64], len(actions))
-       for _, action := range actions {
+       missingUserIDs := container.FilterSlice(actions, func(action *Action) (int64, bool) {
                if action.Repo == nil {
-                       continue
+                       return 0, false
                }
-               if _, ok := userMap[action.Repo.OwnerID]; !ok {
-                       userSet.Add(action.Repo.OwnerID)
-               }
-       }
+               _, alreadyLoaded := userMap[action.Repo.OwnerID]
+               return action.Repo.OwnerID, !alreadyLoaded
+       })
 
        if err := db.GetEngine(ctx).
-               In("id", userSet.Values()).
+               In("id", missingUserIDs).
                Find(&userMap); err != nil {
                return fmt.Errorf("find user: %w", err)
        }
index 8319e5ecd08321bf8f66463e735e38d4ab303a50..980bd7b4c9df859a350d2c5efb97e59be096fba7 100644 (file)
@@ -17,15 +17,12 @@ import (
 type BranchList []*Branch
 
 func (branches BranchList) LoadDeletedBy(ctx context.Context) error {
-       ids := container.Set[int64]{}
-       for _, branch := range branches {
-               if !branch.IsDeleted {
-                       continue
-               }
-               ids.Add(branch.DeletedByID)
-       }
+       ids := container.FilterSlice(branches, func(branch *Branch) (int64, bool) {
+               return branch.DeletedByID, branch.IsDeleted
+       })
+
        usersMap := make(map[int64]*user_model.User, len(ids))
-       if err := db.GetEngine(ctx).In("id", ids.Values()).Find(&usersMap); err != nil {
+       if err := db.GetEngine(ctx).In("id", ids).Find(&usersMap); err != nil {
                return err
        }
        for _, branch := range branches {
@@ -41,14 +38,13 @@ func (branches BranchList) LoadDeletedBy(ctx context.Context) error {
 }
 
 func (branches BranchList) LoadPusher(ctx context.Context) error {
-       ids := container.Set[int64]{}
-       for _, branch := range branches {
-               if branch.PusherID > 0 { // pusher_id maybe zero because some branches are sync by backend with no pusher
-                       ids.Add(branch.PusherID)
-               }
-       }
+       ids := container.FilterSlice(branches, func(branch *Branch) (int64, bool) {
+               // pusher_id maybe zero because some branches are sync by backend with no pusher
+               return branch.PusherID, branch.PusherID > 0
+       })
+
        usersMap := make(map[int64]*user_model.User, len(ids))
-       if err := db.GetEngine(ctx).In("id", ids.Values()).Find(&usersMap); err != nil {
+       if err := db.GetEngine(ctx).In("id", ids).Find(&usersMap); err != nil {
                return err
        }
        for _, branch := range branches {
index 6f65a5dbbc61625de3342065469f6e092f5de640..353163ebd6f99228f405a7c5d33b918f6e0e8789 100644 (file)
@@ -1272,10 +1272,9 @@ func InsertIssueComments(ctx context.Context, comments []*Comment) error {
                return nil
        }
 
-       issueIDs := make(container.Set[int64])
-       for _, comment := range comments {
-               issueIDs.Add(comment.IssueID)
-       }
+       issueIDs := container.FilterSlice(comments, func(comment *Comment) (int64, bool) {
+               return comment.IssueID, true
+       })
 
        ctx, committer, err := db.TxContext(ctx)
        if err != nil {
@@ -1298,7 +1297,7 @@ func InsertIssueComments(ctx context.Context, comments []*Comment) error {
                }
        }
 
-       for issueID := range issueIDs {
+       for _, issueID := range issueIDs {
                if _, err := db.Exec(ctx, "UPDATE issue set num_comments = (SELECT count(*) FROM comment WHERE issue_id = ? AND `type`=?) WHERE id = ?",
                        issueID, CommentTypeComment, issueID); err != nil {
                        return err
index 0047b054bac1f5eeed4b9c3f44c09d599a893ab3..370b5396e09040b945af9c46a1eb61546a68bac9 100644 (file)
@@ -17,13 +17,9 @@ import (
 type CommentList []*Comment
 
 func (comments CommentList) getPosterIDs() []int64 {
-       posterIDs := make(container.Set[int64], len(comments))
-       for _, comment := range comments {
-               if comment.PosterID > 0 {
-                       posterIDs.Add(comment.PosterID)
-               }
-       }
-       return posterIDs.Values()
+       return container.FilterSlice(comments, func(c *Comment) (int64, bool) {
+               return c.PosterID, c.PosterID > 0
+       })
 }
 
 // LoadPosters loads posters
@@ -44,13 +40,9 @@ func (comments CommentList) LoadPosters(ctx context.Context) error {
 }
 
 func (comments CommentList) getLabelIDs() []int64 {
-       ids := make(container.Set[int64], len(comments))
-       for _, comment := range comments {
-               if comment.LabelID > 0 {
-                       ids.Add(comment.LabelID)
-               }
-       }
-       return ids.Values()
+       return container.FilterSlice(comments, func(comment *Comment) (int64, bool) {
+               return comment.LabelID, comment.LabelID > 0
+       })
 }
 
 func (comments CommentList) loadLabels(ctx context.Context) error {
@@ -94,13 +86,9 @@ func (comments CommentList) loadLabels(ctx context.Context) error {
 }
 
 func (comments CommentList) getMilestoneIDs() []int64 {
-       ids := make(container.Set[int64], len(comments))
-       for _, comment := range comments {
-               if comment.MilestoneID > 0 {
-                       ids.Add(comment.MilestoneID)
-               }
-       }
-       return ids.Values()
+       return container.FilterSlice(comments, func(comment *Comment) (int64, bool) {
+               return comment.MilestoneID, comment.MilestoneID > 0
+       })
 }
 
 func (comments CommentList) loadMilestones(ctx context.Context) error {
@@ -137,13 +125,9 @@ func (comments CommentList) loadMilestones(ctx context.Context) error {
 }
 
 func (comments CommentList) getOldMilestoneIDs() []int64 {
-       ids := make(container.Set[int64], len(comments))
-       for _, comment := range comments {
-               if comment.OldMilestoneID > 0 {
-                       ids.Add(comment.OldMilestoneID)
-               }
-       }
-       return ids.Values()
+       return container.FilterSlice(comments, func(comment *Comment) (int64, bool) {
+               return comment.OldMilestoneID, comment.OldMilestoneID > 0
+       })
 }
 
 func (comments CommentList) loadOldMilestones(ctx context.Context) error {
@@ -180,13 +164,9 @@ func (comments CommentList) loadOldMilestones(ctx context.Context) error {
 }
 
 func (comments CommentList) getAssigneeIDs() []int64 {
-       ids := make(container.Set[int64], len(comments))
-       for _, comment := range comments {
-               if comment.AssigneeID > 0 {
-                       ids.Add(comment.AssigneeID)
-               }
-       }
-       return ids.Values()
+       return container.FilterSlice(comments, func(comment *Comment) (int64, bool) {
+               return comment.AssigneeID, comment.AssigneeID > 0
+       })
 }
 
 func (comments CommentList) loadAssignees(ctx context.Context) error {
@@ -237,14 +217,9 @@ func (comments CommentList) loadAssignees(ctx context.Context) error {
 
 // getIssueIDs returns all the issue ids on this comment list which issue hasn't been loaded
 func (comments CommentList) getIssueIDs() []int64 {
-       ids := make(container.Set[int64], len(comments))
-       for _, comment := range comments {
-               if comment.Issue != nil {
-                       continue
-               }
-               ids.Add(comment.IssueID)
-       }
-       return ids.Values()
+       return container.FilterSlice(comments, func(comment *Comment) (int64, bool) {
+               return comment.IssueID, comment.Issue == nil
+       })
 }
 
 // Issues returns all the issues of comments
@@ -311,16 +286,12 @@ func (comments CommentList) LoadIssues(ctx context.Context) error {
 }
 
 func (comments CommentList) getDependentIssueIDs() []int64 {
-       ids := make(container.Set[int64], len(comments))
-       for _, comment := range comments {
+       return container.FilterSlice(comments, func(comment *Comment) (int64, bool) {
                if comment.DependentIssue != nil {
-                       continue
-               }
-               if comment.DependentIssueID > 0 {
-                       ids.Add(comment.DependentIssueID)
+                       return 0, false
                }
-       }
-       return ids.Values()
+               return comment.DependentIssueID, comment.DependentIssueID > 0
+       })
 }
 
 func (comments CommentList) loadDependentIssues(ctx context.Context) error {
@@ -375,15 +346,9 @@ func (comments CommentList) loadDependentIssues(ctx context.Context) error {
 
 // getAttachmentCommentIDs only return the comment ids which possibly has attachments
 func (comments CommentList) getAttachmentCommentIDs() []int64 {
-       ids := make(container.Set[int64], len(comments))
-       for _, comment := range comments {
-               if comment.Type == CommentTypeComment ||
-                       comment.Type == CommentTypeReview ||
-                       comment.Type == CommentTypeCode {
-                       ids.Add(comment.ID)
-               }
-       }
-       return ids.Values()
+       return container.FilterSlice(comments, func(comment *Comment) (int64, bool) {
+               return comment.ID, comment.Type.HasAttachmentSupport()
+       })
 }
 
 // LoadAttachmentsByIssue loads attachments by issue id
@@ -451,13 +416,9 @@ func (comments CommentList) LoadAttachments(ctx context.Context) (err error) {
 }
 
 func (comments CommentList) getReviewIDs() []int64 {
-       ids := make(container.Set[int64], len(comments))
-       for _, comment := range comments {
-               if comment.ReviewID > 0 {
-                       ids.Add(comment.ReviewID)
-               }
-       }
-       return ids.Values()
+       return container.FilterSlice(comments, func(comment *Comment) (int64, bool) {
+               return comment.ReviewID, comment.ReviewID > 0
+       })
 }
 
 func (comments CommentList) loadReviews(ctx context.Context) error {
index 218891ad35771ed85f7a6e1512b4f080ad5ea1a9..1b05f0aa35a52689c1d53af93a37ae8bd61fc3b0 100644 (file)
@@ -74,11 +74,9 @@ func (issues IssueList) LoadRepositories(ctx context.Context) (repo_model.Reposi
 }
 
 func (issues IssueList) getPosterIDs() []int64 {
-       posterIDs := make(container.Set[int64], len(issues))
-       for _, issue := range issues {
-               posterIDs.Add(issue.PosterID)
-       }
-       return posterIDs.Values()
+       return container.FilterSlice(issues, func(issue *Issue) (int64, bool) {
+               return issue.PosterID, true
+       })
 }
 
 func (issues IssueList) loadPosters(ctx context.Context) error {
@@ -193,11 +191,9 @@ func (issues IssueList) loadLabels(ctx context.Context) error {
 }
 
 func (issues IssueList) getMilestoneIDs() []int64 {
-       ids := make(container.Set[int64], len(issues))
-       for _, issue := range issues {
-               ids.Add(issue.MilestoneID)
-       }
-       return ids.Values()
+       return container.FilterSlice(issues, func(issue *Issue) (int64, bool) {
+               return issue.MilestoneID, true
+       })
 }
 
 func (issues IssueList) loadMilestones(ctx context.Context) error {
index d5448636fe8284d98dc5bca77947b47bdd4e6432..eb7faefc796b94c403b309586b9993173c7dadc2 100644 (file)
@@ -305,14 +305,12 @@ func (list ReactionList) GroupByType() map[string]ReactionList {
 }
 
 func (list ReactionList) getUserIDs() []int64 {
-       userIDs := make(container.Set[int64], len(list))
-       for _, reaction := range list {
+       return container.FilterSlice(list, func(reaction *Reaction) (int64, bool) {
                if reaction.OriginalAuthor != "" {
-                       continue
+                       return 0, false
                }
-               userIDs.Add(reaction.UserID)
-       }
-       return userIDs.Values()
+               return reaction.UserID, true
+       })
 }
 
 func valuesUser(m map[int64]*user_model.User) []*user_model.User {
index ec6cb079886580a8f092b30503747c02039baf76..7b8c3d319c35a0ab8c2f42d6af1feade2d6480e6 100644 (file)
@@ -38,12 +38,11 @@ func (reviews ReviewList) LoadReviewers(ctx context.Context) error {
 }
 
 func (reviews ReviewList) LoadIssues(ctx context.Context) error {
-       issueIDs := container.Set[int64]{}
-       for i := 0; i < len(reviews); i++ {
-               issueIDs.Add(reviews[i].IssueID)
-       }
+       issueIDs := container.FilterSlice(reviews, func(review *Review) (int64, bool) {
+               return review.IssueID, true
+       })
 
-       issues, err := GetIssuesByIDs(ctx, issueIDs.Values())
+       issues, err := GetIssuesByIDs(ctx, issueIDs)
        if err != nil {
                return err
        }
index cb7cd47a8ded5ed6994a731e606b53f5517a1df9..987c7df9b0eb0fb956b1973ac05069b646eb518d 100644 (file)
@@ -104,18 +104,19 @@ func (repos RepositoryList) LoadAttributes(ctx context.Context) error {
                return nil
        }
 
-       set := make(container.Set[int64])
+       userIDs := container.FilterSlice(repos, func(repo *Repository) (int64, bool) {
+               return repo.OwnerID, true
+       })
        repoIDs := make([]int64, len(repos))
        for i := range repos {
-               set.Add(repos[i].OwnerID)
                repoIDs[i] = repos[i].ID
        }
 
        // Load owners.
-       users := make(map[int64]*user_model.User, len(set))
+       users := make(map[int64]*user_model.User, len(userIDs))
        if err := db.GetEngine(ctx).
                Where("id > 0").
-               In("id", set.Values()).
+               In("id", userIDs).
                Find(&users); err != nil {
                return fmt.Errorf("find users: %w", err)
        }
diff --git a/modules/container/filter.go b/modules/container/filter.go
new file mode 100644 (file)
index 0000000..37ec7c3
--- /dev/null
@@ -0,0 +1,21 @@
+// Copyright 2024 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package container
+
+import "slices"
+
+// FilterSlice ranges over the slice and calls include() for each element.
+// If the second returned value is true, the first returned value will be included in the resulting
+// slice (after deduplication).
+func FilterSlice[E any, T comparable](s []E, include func(E) (T, bool)) []T {
+       filtered := make([]T, 0, len(s)) // slice will be clipped before returning
+       seen := make(map[T]bool, len(s))
+       for i := range s {
+               if v, ok := include(s[i]); ok && !seen[v] {
+                       filtered = append(filtered, v)
+                       seen[v] = true
+               }
+       }
+       return slices.Clip(filtered)
+}
diff --git a/modules/container/filter_test.go b/modules/container/filter_test.go
new file mode 100644 (file)
index 0000000..ad304e5
--- /dev/null
@@ -0,0 +1,28 @@
+// Copyright 2024 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package container
+
+import (
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+)
+
+func TestFilterMapUnique(t *testing.T) {
+       result := FilterSlice([]int{
+               0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+       }, func(i int) (int, bool) {
+               switch i {
+               case 0:
+                       return 0, true // included later
+               case 1:
+                       return 0, true // duplicate of previous (should be ignored)
+               case 2:
+                       return 2, false // not included
+               default:
+                       return i, true
+               }
+       })
+       assert.Equal(t, []int{0, 3, 4, 5, 6, 7, 8, 9}, result)
+}