]> source.dussan.org Git - gitea.git/commitdiff
Protect against NPEs in notifications list (#10879) (#10883)
authorzeripath <art27@cantab.net>
Mon, 30 Mar 2020 07:23:02 +0000 (08:23 +0100)
committerGitHub <noreply@github.com>
Mon, 30 Mar 2020 07:23:02 +0000 (15:23 +0800)
* Protect against NPEs in notifications list (#10879)

Unfortunately there appears to be potential race with notifications
being set before the associated issue has been committed.

This PR adds protection in to the notifications list to log any failures
and remove these notifications from the display.

References #10815 - and prevents the panic but does not completely fix
this.

Signed-off-by: Andrew Thornton <art27@cantab.net>
* add log import

* Update models/notification.go

Co-Authored-By: Lauris BH <lauris@nix.lv>
Co-authored-by: Lauris BH <lauris@nix.lv>
models/notification.go
routers/user/notification.go

index 5c03b492574fddf77680d9acbbbf88182c518571..6c8129f88b50ae0a5db356755fe9578a180728f5 100644 (file)
@@ -7,6 +7,7 @@ package models
 import (
        "fmt"
 
+       "code.gitea.io/gitea/modules/log"
        "code.gitea.io/gitea/modules/timeutil"
 )
 
@@ -281,9 +282,9 @@ func (nl NotificationList) getPendingRepoIDs() []int64 {
 }
 
 // LoadRepos loads repositories from database
-func (nl NotificationList) LoadRepos() (RepositoryList, error) {
+func (nl NotificationList) LoadRepos() (RepositoryList, []int, error) {
        if len(nl) == 0 {
-               return RepositoryList{}, nil
+               return RepositoryList{}, []int{}, nil
        }
 
        var repoIDs = nl.getPendingRepoIDs()
@@ -298,7 +299,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
                        In("id", repoIDs[:limit]).
                        Rows(new(Repository))
                if err != nil {
-                       return nil, err
+                       return nil, nil, err
                }
 
                for rows.Next() {
@@ -306,7 +307,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
                        err = rows.Scan(&repo)
                        if err != nil {
                                rows.Close()
-                               return nil, err
+                               return nil, nil, err
                        }
 
                        repos[repo.ID] = &repo
@@ -317,14 +318,21 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
                repoIDs = repoIDs[limit:]
        }
 
+       failed := []int{}
+
        var reposList = make(RepositoryList, 0, len(repoIDs))
-       for _, notification := range nl {
+       for i, notification := range nl {
                if notification.Repository == nil {
                        notification.Repository = repos[notification.RepoID]
                }
+               if notification.Repository == nil {
+                       log.Error("Notification[%d]: RepoID: %d not found", notification.ID, notification.RepoID)
+                       failed = append(failed, i)
+                       continue
+               }
                var found bool
                for _, r := range reposList {
-                       if r.ID == notification.Repository.ID {
+                       if r.ID == notification.RepoID {
                                found = true
                                break
                        }
@@ -333,7 +341,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
                        reposList = append(reposList, notification.Repository)
                }
        }
-       return reposList, nil
+       return reposList, failed, nil
 }
 
 func (nl NotificationList) getPendingIssueIDs() []int64 {
@@ -350,9 +358,9 @@ func (nl NotificationList) getPendingIssueIDs() []int64 {
 }
 
 // LoadIssues loads issues from database
-func (nl NotificationList) LoadIssues() error {
+func (nl NotificationList) LoadIssues() ([]int, error) {
        if len(nl) == 0 {
-               return nil
+               return []int{}, nil
        }
 
        var issueIDs = nl.getPendingIssueIDs()
@@ -367,7 +375,7 @@ func (nl NotificationList) LoadIssues() error {
                        In("id", issueIDs[:limit]).
                        Rows(new(Issue))
                if err != nil {
-                       return err
+                       return nil, err
                }
 
                for rows.Next() {
@@ -375,7 +383,7 @@ func (nl NotificationList) LoadIssues() error {
                        err = rows.Scan(&issue)
                        if err != nil {
                                rows.Close()
-                               return err
+                               return nil, err
                        }
 
                        issues[issue.ID] = &issue
@@ -386,13 +394,38 @@ func (nl NotificationList) LoadIssues() error {
                issueIDs = issueIDs[limit:]
        }
 
-       for _, notification := range nl {
+       failures := []int{}
+
+       for i, notification := range nl {
                if notification.Issue == nil {
                        notification.Issue = issues[notification.IssueID]
+                       if notification.Issue == nil {
+                               log.Error("Notification[%d]: IssueID: %d Not Found", notification.ID, notification.IssueID)
+                               failures = append(failures, i)
+                               continue
+                       }
                        notification.Issue.Repo = notification.Repository
                }
        }
-       return nil
+       return failures, nil
+}
+
+// Without returns the notification list without the failures
+func (nl NotificationList) Without(failures []int) NotificationList {
+       if len(failures) == 0 {
+               return nl
+       }
+       remaining := make([]*Notification, 0, len(nl))
+       last := -1
+       var i int
+       for _, i = range failures {
+               remaining = append(remaining, nl[last+1:i]...)
+               last = i
+       }
+       if len(nl) > i {
+               remaining = append(remaining, nl[i+1:]...)
+       }
+       return remaining
 }
 
 func (nl NotificationList) getPendingCommentIDs() []int64 {
@@ -409,9 +442,9 @@ func (nl NotificationList) getPendingCommentIDs() []int64 {
 }
 
 // LoadComments loads comments from database
-func (nl NotificationList) LoadComments() error {
+func (nl NotificationList) LoadComments() ([]int, error) {
        if len(nl) == 0 {
-               return nil
+               return []int{}, nil
        }
 
        var commentIDs = nl.getPendingCommentIDs()
@@ -426,7 +459,7 @@ func (nl NotificationList) LoadComments() error {
                        In("id", commentIDs[:limit]).
                        Rows(new(Comment))
                if err != nil {
-                       return err
+                       return nil, err
                }
 
                for rows.Next() {
@@ -434,7 +467,7 @@ func (nl NotificationList) LoadComments() error {
                        err = rows.Scan(&comment)
                        if err != nil {
                                rows.Close()
-                               return err
+                               return nil, err
                        }
 
                        comments[comment.ID] = &comment
@@ -445,13 +478,19 @@ func (nl NotificationList) LoadComments() error {
                commentIDs = commentIDs[limit:]
        }
 
-       for _, notification := range nl {
+       failures := []int{}
+       for i, notification := range nl {
                if notification.CommentID > 0 && notification.Comment == nil && comments[notification.CommentID] != nil {
                        notification.Comment = comments[notification.CommentID]
+                       if notification.Comment == nil {
+                               log.Error("Notification[%d]: CommentID[%d] failed to load", notification.ID, notification.CommentID)
+                               failures = append(failures, i)
+                               continue
+                       }
                        notification.Comment.Issue = notification.Issue
                }
        }
-       return nil
+       return failures, nil
 }
 
 // GetNotificationCount returns the notification count for user
index cd6617a23321455a8aa71cd51dbf39f5d6f06f5b..2065057ed182250075de34b0cefdf7208281fd39 100644 (file)
@@ -68,24 +68,39 @@ func Notifications(c *context.Context) {
                return
        }
 
-       repos, err := notifications.LoadRepos()
+       failCount := 0
+
+       repos, failures, err := notifications.LoadRepos()
        if err != nil {
                c.ServerError("LoadRepos", err)
                return
        }
+       notifications = notifications.Without(failures)
        if err := repos.LoadAttributes(); err != nil {
                c.ServerError("LoadAttributes", err)
                return
        }
+       failCount += len(failures)
 
-       if err := notifications.LoadIssues(); err != nil {
+       failures, err = notifications.LoadIssues()
+       if err != nil {
                c.ServerError("LoadIssues", err)
                return
        }
-       if err := notifications.LoadComments(); err != nil {
+       notifications = notifications.Without(failures)
+       failCount += len(failures)
+
+       failures, err = notifications.LoadComments()
+       if err != nil {
                c.ServerError("LoadComments", err)
                return
        }
+       notifications = notifications.Without(failures)
+       failCount += len(failures)
+
+       if failCount > 0 {
+               c.Flash.Error(fmt.Sprintf("ERROR: %d notifications were removed due to missing parts - check the logs", failCount))
+       }
 
        total, err := models.GetNotificationCount(c.User, status)
        if err != nil {