]> source.dussan.org Git - gitea.git/commitdiff
Fix milestones too many SQL variables bug (#10880) (#10904)
authorLunny Xiao <xiaolunwen@gmail.com>
Tue, 31 Mar 2020 13:40:37 +0000 (21:40 +0800)
committerGitHub <noreply@github.com>
Tue, 31 Mar 2020 13:40:37 +0000 (08:40 -0500)
* Fix milestones too many SQL variables bug

* Fix test

* Don't display repositories with no milestone and fix tests

* Remove unused code and add some comments

models/issue_milestone.go
models/issue_milestone_test.go
models/repo_list.go
routers/user/home.go
routers/user/home_test.go

index 4ceaaa9e764c6f047643db232837ca7f06770222..29e32e7d96d500e09b4ddeaa0579e52f8d3bf7b2 100644 (file)
@@ -521,10 +521,12 @@ func DeleteMilestoneByRepoID(repoID, id int64) error {
        return sess.Commit()
 }
 
-// CountMilestonesByRepoIDs map from repoIDs to number of milestones matching the options`
-func CountMilestonesByRepoIDs(repoIDs []int64, isClosed bool) (map[int64]int64, error) {
+// CountMilestones map from repo conditions to number of milestones matching the options`
+func CountMilestones(repoCond builder.Cond, isClosed bool) (map[int64]int64, error) {
        sess := x.Where("is_closed = ?", isClosed)
-       sess.In("repo_id", repoIDs)
+       if repoCond.IsValid() {
+               sess.In("repo_id", builder.Select("id").From("repository").Where(repoCond))
+       }
 
        countsSlice := make([]*struct {
                RepoID int64
@@ -544,11 +546,21 @@ func CountMilestonesByRepoIDs(repoIDs []int64, isClosed bool) (map[int64]int64,
        return countMap, nil
 }
 
-// GetMilestonesByRepoIDs returns a list of milestones of given repositories and status.
-func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) {
+// CountMilestonesByRepoIDs map from repoIDs to number of milestones matching the options`
+func CountMilestonesByRepoIDs(repoIDs []int64, isClosed bool) (map[int64]int64, error) {
+       return CountMilestones(
+               builder.In("repo_id", repoIDs),
+               isClosed,
+       )
+}
+
+// SearchMilestones search milestones
+func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType string) (MilestoneList, error) {
        miles := make([]*Milestone, 0, setting.UI.IssuePagingNum)
        sess := x.Where("is_closed = ?", isClosed)
-       sess.In("repo_id", repoIDs)
+       if repoCond.IsValid() {
+               sess.In("repo_id", builder.Select("id").From("repository").Where(repoCond))
+       }
        if page > 0 {
                sess = sess.Limit(setting.UI.IssuePagingNum, (page-1)*setting.UI.IssuePagingNum)
        }
@@ -570,25 +582,45 @@ func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType s
        return miles, sess.Find(&miles)
 }
 
+// GetMilestonesByRepoIDs returns a list of milestones of given repositories and status.
+func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) {
+       return SearchMilestones(
+               builder.In("repo_id", repoIDs),
+               page,
+               isClosed,
+               sortType,
+       )
+}
+
 // MilestonesStats represents milestone statistic information.
 type MilestonesStats struct {
        OpenCount, ClosedCount int64
 }
 
+// Total returns the total counts of milestones
+func (m MilestonesStats) Total() int64 {
+       return m.OpenCount + m.ClosedCount
+}
+
 // GetMilestonesStats returns milestone statistic information for dashboard by given conditions.
-func GetMilestonesStats(userRepoIDs []int64) (*MilestonesStats, error) {
+func GetMilestonesStats(repoCond builder.Cond) (*MilestonesStats, error) {
        var err error
        stats := &MilestonesStats{}
 
-       stats.OpenCount, err = x.Where("is_closed = ?", false).
-               And(builder.In("repo_id", userRepoIDs)).
-               Count(new(Milestone))
+       sess := x.Where("is_closed = ?", false)
+       if repoCond.IsValid() {
+               sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond)))
+       }
+       stats.OpenCount, err = sess.Count(new(Milestone))
        if err != nil {
                return nil, err
        }
-       stats.ClosedCount, err = x.Where("is_closed = ?", true).
-               And(builder.In("repo_id", userRepoIDs)).
-               Count(new(Milestone))
+
+       sess = x.Where("is_closed = ?", true)
+       if repoCond.IsValid() {
+               sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond)))
+       }
+       stats.ClosedCount, err = sess.Count(new(Milestone))
        if err != nil {
                return nil, err
        }
index cea2c4ea142d818c7e219db93bcad955ccbcf392..9adfd15f1486c344b1df29476443d475fc40f351 100644 (file)
@@ -11,6 +11,7 @@ import (
 
        api "code.gitea.io/gitea/modules/structs"
        "code.gitea.io/gitea/modules/timeutil"
+       "xorm.io/builder"
 
        "github.com/stretchr/testify/assert"
 )
@@ -370,7 +371,7 @@ func TestGetMilestonesStats(t *testing.T) {
        repo1 := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
        repo2 := AssertExistsAndLoadBean(t, &Repository{ID: 2}).(*Repository)
 
-       milestoneStats, err := GetMilestonesStats([]int64{repo1.ID, repo2.ID})
+       milestoneStats, err := GetMilestonesStats(builder.In("repo_id", []int64{repo1.ID, repo2.ID}))
        assert.NoError(t, err)
        assert.EqualValues(t, repo1.NumOpenMilestones+repo2.NumOpenMilestones, milestoneStats.OpenCount)
        assert.EqualValues(t, repo1.NumClosedMilestones+repo2.NumClosedMilestones, milestoneStats.ClosedCount)
index 1dd5cf2f063c1023551f7b3e296dd7b1d09e1358..bc5aac524afcdca35dbe40c5d9f1995ca6205e8e 100644 (file)
@@ -144,6 +144,10 @@ type SearchRepoOptions struct {
        TopicOnly bool
        // include description in keyword search
        IncludeDescription bool
+       // None -> include has milestones AND has no milestone
+       // True -> include just has milestones
+       // False -> include just has no milestone
+       HasMilestones util.OptionalBool
 }
 
 //SearchOrderBy is used to sort the result
@@ -171,12 +175,9 @@ const (
        SearchOrderByForksReverse          SearchOrderBy = "num_forks DESC"
 )
 
-// SearchRepository returns repositories based on search options,
+// SearchRepositoryCondition returns repositories based on search options,
 // it returns results in given range and number of total results.
-func SearchRepository(opts *SearchRepoOptions) (RepositoryList, int64, error) {
-       if opts.Page <= 0 {
-               opts.Page = 1
-       }
+func SearchRepositoryCondition(opts *SearchRepoOptions) builder.Cond {
        var cond = builder.NewCond()
 
        if opts.Private {
@@ -276,6 +277,29 @@ func SearchRepository(opts *SearchRepoOptions) (RepositoryList, int64, error) {
                cond = cond.And(builder.Eq{"is_mirror": opts.Mirror == util.OptionalBoolTrue})
        }
 
+       switch opts.HasMilestones {
+       case util.OptionalBoolTrue:
+               cond = cond.And(builder.Gt{"num_milestones": 0})
+       case util.OptionalBoolFalse:
+               cond = cond.And(builder.Eq{"num_milestones": 0}.Or(builder.IsNull{"num_milestones"}))
+       }
+
+       return cond
+}
+
+// SearchRepository returns repositories based on search options,
+// it returns results in given range and number of total results.
+func SearchRepository(opts *SearchRepoOptions) (RepositoryList, int64, error) {
+       cond := SearchRepositoryCondition(opts)
+       return SearchRepositoryByCondition(opts, cond)
+}
+
+// SearchRepositoryByCondition search repositories by condition
+func SearchRepositoryByCondition(opts *SearchRepoOptions, cond builder.Cond) (RepositoryList, int64, error) {
+       if opts.Page <= 0 {
+               opts.Page = 1
+       }
+
        if len(opts.OrderBy) == 0 {
                opts.OrderBy = SearchOrderByAlphabetically
        }
@@ -296,11 +320,11 @@ func SearchRepository(opts *SearchRepoOptions) (RepositoryList, int64, error) {
        }
 
        repos := make(RepositoryList, 0, opts.PageSize)
-       if err = sess.
-               Where(cond).
-               OrderBy(opts.OrderBy.String()).
-               Limit(opts.PageSize, (opts.Page-1)*opts.PageSize).
-               Find(&repos); err != nil {
+       sess.Where(cond).OrderBy(opts.OrderBy.String())
+       if opts.PageSize > 0 {
+               sess.Limit(opts.PageSize, (opts.Page-1)*opts.PageSize)
+       }
+       if err = sess.Find(&repos); err != nil {
                return nil, 0, fmt.Errorf("Repo: %v", err)
        }
 
index 819853ac3841ba2a7035caf40a6434ae553030ea..3b357ffdc1611376e2fa018e92d997a60f8f5a90 100644 (file)
@@ -24,7 +24,7 @@ import (
 
        "github.com/keybase/go-crypto/openpgp"
        "github.com/keybase/go-crypto/openpgp/armor"
-       "github.com/unknwon/com"
+       "xorm.io/builder"
 )
 
 const (
@@ -171,135 +171,114 @@ func Milestones(ctx *context.Context) {
                return
        }
 
-       sortType := ctx.Query("sort")
-       page := ctx.QueryInt("page")
-       if page <= 1 {
-               page = 1
-       }
+       var (
+               repoOpts = models.SearchRepoOptions{
+                       OwnerID:       ctxUser.ID,
+                       Private:       true,
+                       AllPublic:     false,                 // Include also all public repositories of users and public organisations
+                       AllLimited:    false,                 // Include also all public repositories of limited organisations
+                       HasMilestones: util.OptionalBoolTrue, // Just needs display repos has milestones
+                       IsProfile:     false,
+               }
 
-       reposQuery := ctx.Query("repos")
-       isShowClosed := ctx.Query("state") == "closed"
+               userRepoCond = models.SearchRepositoryCondition(&repoOpts) // all repo condition user could visit
+               repoCond     = userRepoCond
+               repoIDs      []int64
 
-       // Get repositories.
-       var err error
-       var userRepoIDs []int64
-       if ctxUser.IsOrganization() {
-               env, err := ctxUser.AccessibleReposEnv(ctx.User.ID)
-               if err != nil {
-                       ctx.ServerError("AccessibleReposEnv", err)
-                       return
-               }
-               userRepoIDs, err = env.RepoIDs(1, ctxUser.NumRepos)
-               if err != nil {
-                       ctx.ServerError("env.RepoIDs", err)
-                       return
-               }
-               userRepoIDs, err = models.FilterOutRepoIdsWithoutUnitAccess(ctx.User, userRepoIDs, models.UnitTypeIssues, models.UnitTypePullRequests)
-               if err != nil {
-                       ctx.ServerError("FilterOutRepoIdsWithoutUnitAccess", err)
-                       return
-               }
-       } else {
-               userRepoIDs, err = ctxUser.GetAccessRepoIDs(models.UnitTypeIssues, models.UnitTypePullRequests)
-               if err != nil {
-                       ctx.ServerError("ctxUser.GetAccessRepoIDs", err)
-                       return
-               }
-       }
-       if len(userRepoIDs) == 0 {
-               userRepoIDs = []int64{-1}
+               reposQuery   = ctx.Query("repos")
+               isShowClosed = ctx.Query("state") == "closed"
+               sortType     = ctx.Query("sort")
+               page         = ctx.QueryInt("page")
+       )
+
+       if page <= 1 {
+               page = 1
        }
 
-       var repoIDs []int64
        if len(reposQuery) != 0 {
                if issueReposQueryPattern.MatchString(reposQuery) {
                        // remove "[" and "]" from string
                        reposQuery = reposQuery[1 : len(reposQuery)-1]
                        //for each ID (delimiter ",") add to int to repoIDs
-                       reposSet := false
+
                        for _, rID := range strings.Split(reposQuery, ",") {
                                // Ensure nonempty string entries
                                if rID != "" && rID != "0" {
-                                       reposSet = true
                                        rIDint64, err := strconv.ParseInt(rID, 10, 64)
                                        // If the repo id specified by query is not parseable or not accessible by user, just ignore it.
-                                       if err == nil && com.IsSliceContainsInt64(userRepoIDs, rIDint64) {
+                                       if err == nil {
                                                repoIDs = append(repoIDs, rIDint64)
                                        }
                                }
                        }
-                       if reposSet && len(repoIDs) == 0 {
-                               // force an empty result
-                               repoIDs = []int64{-1}
+                       if len(repoIDs) > 0 {
+                               // Don't just let repoCond = builder.In("id", repoIDs) because user may has no permission on repoIDs
+                               // But the original repoCond has a limitation
+                               repoCond = repoCond.And(builder.In("id", repoIDs))
                        }
                } else {
                        log.Warn("issueReposQueryPattern not match with query")
                }
        }
 
-       if len(repoIDs) == 0 {
-               repoIDs = userRepoIDs
-       }
-
-       counts, err := models.CountMilestonesByRepoIDs(userRepoIDs, isShowClosed)
+       counts, err := models.CountMilestones(userRepoCond, isShowClosed)
        if err != nil {
                ctx.ServerError("CountMilestonesByRepoIDs", err)
                return
        }
 
-       milestones, err := models.GetMilestonesByRepoIDs(repoIDs, page, isShowClosed, sortType)
+       milestones, err := models.SearchMilestones(repoCond, page, isShowClosed, sortType)
        if err != nil {
                ctx.ServerError("GetMilestonesByRepoIDs", err)
                return
        }
 
-       showReposMap := make(map[int64]*models.Repository, len(counts))
-       for rID := range counts {
-               if rID == -1 {
-                       break
-               }
-               repo, err := models.GetRepositoryByID(rID)
-               if err != nil {
-                       if models.IsErrRepoNotExist(err) {
-                               ctx.NotFound("GetRepositoryByID", err)
-                               return
-                       } else if err != nil {
-                               ctx.ServerError("GetRepositoryByID", fmt.Errorf("[%d]%v", rID, err))
-                               return
-                       }
-               }
-               showReposMap[rID] = repo
-       }
-
-       showRepos := models.RepositoryListOfMap(showReposMap)
-       sort.Sort(showRepos)
-       if err = showRepos.LoadAttributes(); err != nil {
-               ctx.ServerError("LoadAttributes", err)
+       showRepos, _, err := models.SearchRepositoryByCondition(&repoOpts, userRepoCond)
+       if err != nil {
+               ctx.ServerError("SearchRepositoryByCondition", err)
                return
        }
+       sort.Sort(showRepos)
+
+       for i := 0; i < len(milestones); {
+               for _, repo := range showRepos {
+                       if milestones[i].RepoID == repo.ID {
+                               milestones[i].Repo = repo
+                               break
+                       }
+               }
+               if milestones[i].Repo == nil {
+                       log.Warn("Cannot find milestone %d 's repository %d", milestones[i].ID, milestones[i].RepoID)
+                       milestones = append(milestones[:i], milestones[i+1:]...)
+                       continue
+               }
 
-       for _, m := range milestones {
-               m.Repo = showReposMap[m.RepoID]
-               m.RenderedContent = string(markdown.Render([]byte(m.Content), m.Repo.Link(), m.Repo.ComposeMetas()))
-               if m.Repo.IsTimetrackerEnabled() {
-                       err := m.LoadTotalTrackedTime()
+               milestones[i].RenderedContent = string(markdown.Render([]byte(milestones[i].Content), milestones[i].Repo.Link(), milestones[i].Repo.ComposeMetas()))
+               if milestones[i].Repo.IsTimetrackerEnabled() {
+                       err := milestones[i].LoadTotalTrackedTime()
                        if err != nil {
                                ctx.ServerError("LoadTotalTrackedTime", err)
                                return
                        }
                }
+               i++
        }
 
-       milestoneStats, err := models.GetMilestonesStats(repoIDs)
+       milestoneStats, err := models.GetMilestonesStats(repoCond)
        if err != nil {
                ctx.ServerError("GetMilestoneStats", err)
                return
        }
 
-       totalMilestoneStats, err := models.GetMilestonesStats(userRepoIDs)
-       if err != nil {
-               ctx.ServerError("GetMilestoneStats", err)
-               return
+       var totalMilestoneStats *models.MilestonesStats
+       if len(repoIDs) == 0 {
+               totalMilestoneStats = milestoneStats
+       } else {
+               totalMilestoneStats, err = models.GetMilestonesStats(userRepoCond)
+               if err != nil {
+                       ctx.ServerError("GetMilestoneStats", err)
+                       return
+               }
        }
 
        var pagerCount int
@@ -318,7 +297,7 @@ func Milestones(ctx *context.Context) {
        ctx.Data["Counts"] = counts
        ctx.Data["MilestoneStats"] = milestoneStats
        ctx.Data["SortType"] = sortType
-       if len(repoIDs) != len(userRepoIDs) {
+       if milestoneStats.Total() != totalMilestoneStats.Total() {
                ctx.Data["RepoIDs"] = repoIDs
        }
        ctx.Data["IsShowClosed"] = isShowClosed
index 39186d93eeaa9f1b85742f2e599ea3024035f308..ff48953d440238efcbd1f110b6518971acfdbdd6 100644 (file)
@@ -48,7 +48,7 @@ func TestMilestones(t *testing.T) {
        assert.EqualValues(t, "furthestduedate", ctx.Data["SortType"])
        assert.EqualValues(t, 1, ctx.Data["Total"])
        assert.Len(t, ctx.Data["Milestones"], 1)
-       assert.Len(t, ctx.Data["Repos"], 1)
+       assert.Len(t, ctx.Data["Repos"], 2) // both repo 42 and 1 have milestones and both are owned by user 2
 }
 
 func TestMilestonesForSpecificRepo(t *testing.T) {
@@ -68,5 +68,5 @@ func TestMilestonesForSpecificRepo(t *testing.T) {
        assert.EqualValues(t, "furthestduedate", ctx.Data["SortType"])
        assert.EqualValues(t, 1, ctx.Data["Total"])
        assert.Len(t, ctx.Data["Milestones"], 1)
-       assert.Len(t, ctx.Data["Repos"], 1)
+       assert.Len(t, ctx.Data["Repos"], 2) // both repo 42 and 1 have milestones and both are owned by user 2
 }