You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

pull_list.go 5.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. // Copyright 2019 The Gitea Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package models
  5. import (
  6. "fmt"
  7. "code.gitea.io/gitea/models/db"
  8. "code.gitea.io/gitea/modules/base"
  9. "code.gitea.io/gitea/modules/git"
  10. "code.gitea.io/gitea/modules/log"
  11. "xorm.io/xorm"
  12. )
  13. // PullRequestsOptions holds the options for PRs
  14. type PullRequestsOptions struct {
  15. db.ListOptions
  16. State string
  17. SortType string
  18. Labels []string
  19. MilestoneID int64
  20. }
  21. func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) {
  22. sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", baseRepoID)
  23. sess.Join("INNER", "issue", "pull_request.issue_id = issue.id")
  24. switch opts.State {
  25. case "closed", "open":
  26. sess.And("issue.is_closed=?", opts.State == "closed")
  27. }
  28. if labelIDs, err := base.StringsToInt64s(opts.Labels); err != nil {
  29. return nil, err
  30. } else if len(labelIDs) > 0 {
  31. sess.Join("INNER", "issue_label", "issue.id = issue_label.issue_id").
  32. In("issue_label.label_id", labelIDs)
  33. }
  34. if opts.MilestoneID > 0 {
  35. sess.And("issue.milestone_id=?", opts.MilestoneID)
  36. }
  37. return sess, nil
  38. }
  39. // GetUnmergedPullRequestsByHeadInfo returns all pull requests that are open and has not been merged
  40. // by given head information (repo and branch).
  41. func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) {
  42. prs := make([]*PullRequest, 0, 2)
  43. return prs, db.GetEngine(db.DefaultContext).
  44. Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ? AND flow = ?",
  45. repoID, branch, false, false, PullRequestFlowGithub).
  46. Join("INNER", "issue", "issue.id = pull_request.issue_id").
  47. Find(&prs)
  48. }
  49. // GetUnmergedPullRequestsByBaseInfo returns all pull requests that are open and has not been merged
  50. // by given base information (repo and branch).
  51. func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) {
  52. prs := make([]*PullRequest, 0, 2)
  53. return prs, db.GetEngine(db.DefaultContext).
  54. Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?",
  55. repoID, branch, false, false).
  56. Join("INNER", "issue", "issue.id=pull_request.issue_id").
  57. Find(&prs)
  58. }
  59. // GetPullRequestIDsByCheckStatus returns all pull requests according the special checking status.
  60. func GetPullRequestIDsByCheckStatus(status PullRequestStatus) ([]int64, error) {
  61. prs := make([]int64, 0, 10)
  62. return prs, db.GetEngine(db.DefaultContext).Table("pull_request").
  63. Where("status=?", status).
  64. Cols("pull_request.id").
  65. Find(&prs)
  66. }
  67. // PullRequests returns all pull requests for a base Repo by the given conditions
  68. func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, int64, error) {
  69. if opts.Page <= 0 {
  70. opts.Page = 1
  71. }
  72. countSession, err := listPullRequestStatement(baseRepoID, opts)
  73. if err != nil {
  74. log.Error("listPullRequestStatement: %v", err)
  75. return nil, 0, err
  76. }
  77. maxResults, err := countSession.Count(new(PullRequest))
  78. if err != nil {
  79. log.Error("Count PRs: %v", err)
  80. return nil, maxResults, err
  81. }
  82. findSession, err := listPullRequestStatement(baseRepoID, opts)
  83. sortIssuesSession(findSession, opts.SortType, 0)
  84. if err != nil {
  85. log.Error("listPullRequestStatement: %v", err)
  86. return nil, maxResults, err
  87. }
  88. findSession = db.SetSessionPagination(findSession, opts)
  89. prs := make([]*PullRequest, 0, opts.PageSize)
  90. return prs, maxResults, findSession.Find(&prs)
  91. }
  92. // PullRequestList defines a list of pull requests
  93. type PullRequestList []*PullRequest
  94. func (prs PullRequestList) loadAttributes(e db.Engine) error {
  95. if len(prs) == 0 {
  96. return nil
  97. }
  98. // Load issues.
  99. issueIDs := prs.getIssueIDs()
  100. issues := make([]*Issue, 0, len(issueIDs))
  101. if err := e.
  102. Where("id > 0").
  103. In("id", issueIDs).
  104. Find(&issues); err != nil {
  105. return fmt.Errorf("find issues: %v", err)
  106. }
  107. set := make(map[int64]*Issue)
  108. for i := range issues {
  109. set[issues[i].ID] = issues[i]
  110. }
  111. for i := range prs {
  112. prs[i].Issue = set[prs[i].IssueID]
  113. }
  114. return nil
  115. }
  116. func (prs PullRequestList) getIssueIDs() []int64 {
  117. issueIDs := make([]int64, 0, len(prs))
  118. for i := range prs {
  119. issueIDs = append(issueIDs, prs[i].IssueID)
  120. }
  121. return issueIDs
  122. }
  123. // LoadAttributes load all the prs attributes
  124. func (prs PullRequestList) LoadAttributes() error {
  125. return prs.loadAttributes(db.GetEngine(db.DefaultContext))
  126. }
  127. func (prs PullRequestList) invalidateCodeComments(e db.Engine, doer *User, repo *git.Repository, branch string) error {
  128. if len(prs) == 0 {
  129. return nil
  130. }
  131. issueIDs := prs.getIssueIDs()
  132. var codeComments []*Comment
  133. if err := e.
  134. Where("type = ? and invalidated = ?", CommentTypeCode, false).
  135. In("issue_id", issueIDs).
  136. Find(&codeComments); err != nil {
  137. return fmt.Errorf("find code comments: %v", err)
  138. }
  139. for _, comment := range codeComments {
  140. if err := comment.CheckInvalidation(repo, doer, branch); err != nil {
  141. return err
  142. }
  143. }
  144. return nil
  145. }
  146. // InvalidateCodeComments will lookup the prs for code comments which got invalidated by change
  147. func (prs PullRequestList) InvalidateCodeComments(doer *User, repo *git.Repository, branch string) error {
  148. return prs.invalidateCodeComments(db.GetEngine(db.DefaultContext), doer, repo, branch)
  149. }