summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--integrations/repofiles_update_test.go13
-rw-r--r--models/issue_watch.go40
-rw-r--r--models/notification.go66
-rw-r--r--models/repo_watch.go6
-rw-r--r--models/user.go2
-rw-r--r--modules/git/repo_branch.go3
-rw-r--r--modules/test/context_tests.go7
-rw-r--r--routers/api/v1/repo/issue_subscription.go9
8 files changed, 66 insertions, 80 deletions
diff --git a/integrations/repofiles_update_test.go b/integrations/repofiles_update_test.go
index a7beec4955..c422483bf8 100644
--- a/integrations/repofiles_update_test.go
+++ b/integrations/repofiles_update_test.go
@@ -207,11 +207,14 @@ func TestCreateOrUpdateRepoFileForCreate(t *testing.T) {
commitID, _ := gitRepo.GetBranchCommitID(opts.NewBranch)
expectedFileResponse := getExpectedFileResponseForRepofilesCreate(commitID)
- assert.EqualValues(t, expectedFileResponse.Content, fileResponse.Content)
- assert.EqualValues(t, expectedFileResponse.Commit.SHA, fileResponse.Commit.SHA)
- assert.EqualValues(t, expectedFileResponse.Commit.HTMLURL, fileResponse.Commit.HTMLURL)
- assert.EqualValues(t, expectedFileResponse.Commit.Author.Email, fileResponse.Commit.Author.Email)
- assert.EqualValues(t, expectedFileResponse.Commit.Author.Name, fileResponse.Commit.Author.Name)
+ assert.NotNil(t, expectedFileResponse)
+ if expectedFileResponse != nil {
+ assert.EqualValues(t, expectedFileResponse.Content, fileResponse.Content)
+ assert.EqualValues(t, expectedFileResponse.Commit.SHA, fileResponse.Commit.SHA)
+ assert.EqualValues(t, expectedFileResponse.Commit.HTMLURL, fileResponse.Commit.HTMLURL)
+ assert.EqualValues(t, expectedFileResponse.Commit.Author.Email, fileResponse.Commit.Author.Email)
+ assert.EqualValues(t, expectedFileResponse.Commit.Author.Name, fileResponse.Commit.Author.Name)
+ }
})
}
diff --git a/models/issue_watch.go b/models/issue_watch.go
index c4732d784e..9046e4d2f7 100644
--- a/models/issue_watch.go
+++ b/models/issue_watch.go
@@ -68,10 +68,14 @@ func getIssueWatch(e Engine, userID, issueID int64) (iw *IssueWatch, exists bool
// but avoids joining with `user` for performance reasons
// User permissions must be verified elsewhere if required
func GetIssueWatchersIDs(issueID int64) ([]int64, error) {
+ return getIssueWatchersIDs(x, issueID, true)
+}
+
+func getIssueWatchersIDs(e Engine, issueID int64, watching bool) ([]int64, error) {
ids := make([]int64, 0, 64)
- return ids, x.Table("issue_watch").
+ return ids, e.Table("issue_watch").
Where("issue_id=?", issueID).
- And("is_watching = ?", true).
+ And("is_watching = ?", watching).
Select("user_id").
Find(&ids)
}
@@ -99,39 +103,9 @@ func getIssueWatchers(e Engine, issueID int64, listOptions ListOptions) (IssueWa
}
func removeIssueWatchersByRepoID(e Engine, userID int64, repoID int64) error {
- iw := &IssueWatch{
- IsWatching: false,
- }
_, err := e.
Join("INNER", "issue", "`issue`.id = `issue_watch`.issue_id AND `issue`.repo_id = ?", repoID).
- Cols("is_watching", "updated_unix").
Where("`issue_watch`.user_id = ?", userID).
- Update(iw)
+ Delete(new(IssueWatch))
return err
}
-
-// LoadWatchUsers return watching users
-func (iwl IssueWatchList) LoadWatchUsers() (users UserList, err error) {
- return iwl.loadWatchUsers(x)
-}
-
-func (iwl IssueWatchList) loadWatchUsers(e Engine) (users UserList, err error) {
- if len(iwl) == 0 {
- return []*User{}, nil
- }
-
- var userIDs = make([]int64, 0, len(iwl))
- for _, iw := range iwl {
- if iw.IsWatching {
- userIDs = append(userIDs, iw.UserID)
- }
- }
-
- if len(userIDs) == 0 {
- return []*User{}, nil
- }
-
- err = e.In("id", userIDs).Find(&users)
-
- return
-}
diff --git a/models/notification.go b/models/notification.go
index e7217a6e04..c52d6c557a 100644
--- a/models/notification.go
+++ b/models/notification.go
@@ -133,55 +133,42 @@ func CreateOrUpdateIssueNotifications(issueID, commentID int64, notificationAuth
}
func createOrUpdateIssueNotifications(e Engine, issueID, commentID int64, notificationAuthorID int64) error {
- issueWatches, err := getIssueWatchers(e, issueID, ListOptions{})
+ // init
+ toNotify := make(map[int64]struct{}, 32)
+ notifications, err := getNotificationsByIssueID(e, issueID)
if err != nil {
return err
}
-
issue, err := getIssueByID(e, issueID)
if err != nil {
return err
}
- watches, err := getWatchers(e, issue.RepoID)
+ issueWatches, err := getIssueWatchersIDs(e, issueID, true)
if err != nil {
return err
}
+ for _, id := range issueWatches {
+ toNotify[id] = struct{}{}
+ }
- notifications, err := getNotificationsByIssueID(e, issueID)
+ repoWatches, err := getRepoWatchersIDs(e, issue.RepoID)
if err != nil {
return err
}
-
- alreadyNotified := make(map[int64]struct{}, len(issueWatches)+len(watches))
-
- notifyUser := func(userID int64) error {
- // do not send notification for the own issuer/commenter
- if userID == notificationAuthorID {
- return nil
- }
-
- if _, ok := alreadyNotified[userID]; ok {
- return nil
- }
- alreadyNotified[userID] = struct{}{}
-
- if notificationExists(notifications, issue.ID, userID) {
- return updateIssueNotification(e, userID, issue.ID, commentID, notificationAuthorID)
- }
- return createIssueNotification(e, userID, issue, commentID, notificationAuthorID)
+ for _, id := range repoWatches {
+ toNotify[id] = struct{}{}
}
- for _, issueWatch := range issueWatches {
- // ignore if user unwatched the issue
- if !issueWatch.IsWatching {
- alreadyNotified[issueWatch.UserID] = struct{}{}
- continue
- }
-
- if err := notifyUser(issueWatch.UserID); err != nil {
- return err
- }
+ // dont notify user who cause notification
+ delete(toNotify, notificationAuthorID)
+ // explicit unwatch on issue
+ issueUnWatches, err := getIssueWatchersIDs(e, issueID, false)
+ if err != nil {
+ return err
+ }
+ for _, id := range issueUnWatches {
+ delete(toNotify, id)
}
err = issue.loadRepo(e)
@@ -189,16 +176,23 @@ func createOrUpdateIssueNotifications(e Engine, issueID, commentID int64, notifi
return err
}
- for _, watch := range watches {
+ // notify
+ for userID := range toNotify {
issue.Repo.Units = nil
- if issue.IsPull && !issue.Repo.checkUnitUser(e, watch.UserID, false, UnitTypePullRequests) {
+ if issue.IsPull && !issue.Repo.checkUnitUser(e, userID, false, UnitTypePullRequests) {
continue
}
- if !issue.IsPull && !issue.Repo.checkUnitUser(e, watch.UserID, false, UnitTypeIssues) {
+ if !issue.IsPull && !issue.Repo.checkUnitUser(e, userID, false, UnitTypeIssues) {
continue
}
- if err := notifyUser(watch.UserID); err != nil {
+ if notificationExists(notifications, issue.ID, userID) {
+ if err = updateIssueNotification(e, userID, issue.ID, commentID, notificationAuthorID); err != nil {
+ return err
+ }
+ continue
+ }
+ if err = createIssueNotification(e, userID, issue, commentID, notificationAuthorID); err != nil {
return err
}
}
diff --git a/models/repo_watch.go b/models/repo_watch.go
index a9d56eff03..11cfa88918 100644
--- a/models/repo_watch.go
+++ b/models/repo_watch.go
@@ -144,8 +144,12 @@ func GetWatchers(repoID int64) ([]*Watch, error) {
// but avoids joining with `user` for performance reasons
// User permissions must be verified elsewhere if required
func GetRepoWatchersIDs(repoID int64) ([]int64, error) {
+ return getRepoWatchersIDs(x, repoID)
+}
+
+func getRepoWatchersIDs(e Engine, repoID int64) ([]int64, error) {
ids := make([]int64, 0, 64)
- return ids, x.Table("watch").
+ return ids, e.Table("watch").
Where("watch.repo_id=?", repoID).
And("watch.mode<>?", RepoWatchModeDont).
Select("user_id").
diff --git a/models/user.go b/models/user.go
index bf59c1240b..8be15ba6df 100644
--- a/models/user.go
+++ b/models/user.go
@@ -1409,7 +1409,7 @@ func GetUserNamesByIDs(ids []int64) ([]string, error) {
}
// GetUsersByIDs returns all resolved users from a list of Ids.
-func GetUsersByIDs(ids []int64) ([]*User, error) {
+func GetUsersByIDs(ids []int64) (UserList, error) {
ous := make([]*User, 0, len(ids))
if len(ids) == 0 {
return ous, nil
diff --git a/modules/git/repo_branch.go b/modules/git/repo_branch.go
index e79bab76a6..3d0e6497ed 100644
--- a/modules/git/repo_branch.go
+++ b/modules/git/repo_branch.go
@@ -48,6 +48,9 @@ type Branch struct {
// GetHEADBranch returns corresponding branch of HEAD.
func (repo *Repository) GetHEADBranch() (*Branch, error) {
+ if repo == nil {
+ return nil, fmt.Errorf("nil repo")
+ }
stdout, err := NewCommand("symbolic-ref", "HEAD").RunInDir(repo.Path)
if err != nil {
return nil, err
diff --git a/modules/test/context_tests.go b/modules/test/context_tests.go
index cf9c5fbc54..f9f0ec5d42 100644
--- a/modules/test/context_tests.go
+++ b/modules/test/context_tests.go
@@ -58,8 +58,11 @@ func LoadRepoCommit(t *testing.T, ctx *context.Context) {
defer gitRepo.Close()
branch, err := gitRepo.GetHEADBranch()
assert.NoError(t, err)
- ctx.Repo.Commit, err = gitRepo.GetBranchCommit(branch.Name)
- assert.NoError(t, err)
+ assert.NotNil(t, branch)
+ if branch != nil {
+ ctx.Repo.Commit, err = gitRepo.GetBranchCommit(branch.Name)
+ assert.NoError(t, err)
+ }
}
// LoadUser load a user into a test context.
diff --git a/routers/api/v1/repo/issue_subscription.go b/routers/api/v1/repo/issue_subscription.go
index 274da966fd..0406edd207 100644
--- a/routers/api/v1/repo/issue_subscription.go
+++ b/routers/api/v1/repo/issue_subscription.go
@@ -190,9 +190,14 @@ func GetIssueSubscribers(ctx *context.APIContext) {
return
}
- users, err := iwl.LoadWatchUsers()
+ var userIDs = make([]int64, 0, len(iwl))
+ for _, iw := range iwl {
+ userIDs = append(userIDs, iw.UserID)
+ }
+
+ users, err := models.GetUsersByIDs(userIDs)
if err != nil {
- ctx.Error(http.StatusInternalServerError, "LoadWatchUsers", err)
+ ctx.Error(http.StatusInternalServerError, "GetUsersByIDs", err)
return
}