瀏覽代碼

Protect against NPEs in notifications list (#10879) (#10883)

* Protect against NPEs in notifications list (#10879)

Unfortunately there appears to be potential race with notifications
being set before the associated issue has been committed.

This PR adds protection in to the notifications list to log any failures
and remove these notifications from the display.

References #10815 - and prevents the panic but does not completely fix
this.

Signed-off-by: Andrew Thornton <art27@cantab.net>

* add log import

* Update models/notification.go

Co-Authored-By: Lauris BH <lauris@nix.lv>

Co-authored-by: Lauris BH <lauris@nix.lv>
tags/v1.11.4
zeripath 4 年之前
父節點
當前提交
596eebb2b6
沒有連結到貢獻者的電子郵件帳戶。
共有 2 個檔案被更改,包括 76 行新增22 行删除
  1. 58
    19
      models/notification.go
  2. 18
    3
      routers/user/notification.go

+ 58
- 19
models/notification.go 查看文件

@@ -7,6 +7,7 @@ package models
import (
"fmt"

"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
)

@@ -281,9 +282,9 @@ func (nl NotificationList) getPendingRepoIDs() []int64 {
}

// LoadRepos loads repositories from database
func (nl NotificationList) LoadRepos() (RepositoryList, error) {
func (nl NotificationList) LoadRepos() (RepositoryList, []int, error) {
if len(nl) == 0 {
return RepositoryList{}, nil
return RepositoryList{}, []int{}, nil
}

var repoIDs = nl.getPendingRepoIDs()
@@ -298,7 +299,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
In("id", repoIDs[:limit]).
Rows(new(Repository))
if err != nil {
return nil, err
return nil, nil, err
}

for rows.Next() {
@@ -306,7 +307,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
err = rows.Scan(&repo)
if err != nil {
rows.Close()
return nil, err
return nil, nil, err
}

repos[repo.ID] = &repo
@@ -317,14 +318,21 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
repoIDs = repoIDs[limit:]
}

failed := []int{}

var reposList = make(RepositoryList, 0, len(repoIDs))
for _, notification := range nl {
for i, notification := range nl {
if notification.Repository == nil {
notification.Repository = repos[notification.RepoID]
}
if notification.Repository == nil {
log.Error("Notification[%d]: RepoID: %d not found", notification.ID, notification.RepoID)
failed = append(failed, i)
continue
}
var found bool
for _, r := range reposList {
if r.ID == notification.Repository.ID {
if r.ID == notification.RepoID {
found = true
break
}
@@ -333,7 +341,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, error) {
reposList = append(reposList, notification.Repository)
}
}
return reposList, nil
return reposList, failed, nil
}

func (nl NotificationList) getPendingIssueIDs() []int64 {
@@ -350,9 +358,9 @@ func (nl NotificationList) getPendingIssueIDs() []int64 {
}

// LoadIssues loads issues from database
func (nl NotificationList) LoadIssues() error {
func (nl NotificationList) LoadIssues() ([]int, error) {
if len(nl) == 0 {
return nil
return []int{}, nil
}

var issueIDs = nl.getPendingIssueIDs()
@@ -367,7 +375,7 @@ func (nl NotificationList) LoadIssues() error {
In("id", issueIDs[:limit]).
Rows(new(Issue))
if err != nil {
return err
return nil, err
}

for rows.Next() {
@@ -375,7 +383,7 @@ func (nl NotificationList) LoadIssues() error {
err = rows.Scan(&issue)
if err != nil {
rows.Close()
return err
return nil, err
}

issues[issue.ID] = &issue
@@ -386,13 +394,38 @@ func (nl NotificationList) LoadIssues() error {
issueIDs = issueIDs[limit:]
}

for _, notification := range nl {
failures := []int{}

for i, notification := range nl {
if notification.Issue == nil {
notification.Issue = issues[notification.IssueID]
if notification.Issue == nil {
log.Error("Notification[%d]: IssueID: %d Not Found", notification.ID, notification.IssueID)
failures = append(failures, i)
continue
}
notification.Issue.Repo = notification.Repository
}
}
return nil
return failures, nil
}

// Without returns the notification list without the failures
func (nl NotificationList) Without(failures []int) NotificationList {
if len(failures) == 0 {
return nl
}
remaining := make([]*Notification, 0, len(nl))
last := -1
var i int
for _, i = range failures {
remaining = append(remaining, nl[last+1:i]...)
last = i
}
if len(nl) > i {
remaining = append(remaining, nl[i+1:]...)
}
return remaining
}

func (nl NotificationList) getPendingCommentIDs() []int64 {
@@ -409,9 +442,9 @@ func (nl NotificationList) getPendingCommentIDs() []int64 {
}

// LoadComments loads comments from database
func (nl NotificationList) LoadComments() error {
func (nl NotificationList) LoadComments() ([]int, error) {
if len(nl) == 0 {
return nil
return []int{}, nil
}

var commentIDs = nl.getPendingCommentIDs()
@@ -426,7 +459,7 @@ func (nl NotificationList) LoadComments() error {
In("id", commentIDs[:limit]).
Rows(new(Comment))
if err != nil {
return err
return nil, err
}

for rows.Next() {
@@ -434,7 +467,7 @@ func (nl NotificationList) LoadComments() error {
err = rows.Scan(&comment)
if err != nil {
rows.Close()
return err
return nil, err
}

comments[comment.ID] = &comment
@@ -445,13 +478,19 @@ func (nl NotificationList) LoadComments() error {
commentIDs = commentIDs[limit:]
}

for _, notification := range nl {
failures := []int{}
for i, notification := range nl {
if notification.CommentID > 0 && notification.Comment == nil && comments[notification.CommentID] != nil {
notification.Comment = comments[notification.CommentID]
if notification.Comment == nil {
log.Error("Notification[%d]: CommentID[%d] failed to load", notification.ID, notification.CommentID)
failures = append(failures, i)
continue
}
notification.Comment.Issue = notification.Issue
}
}
return nil
return failures, nil
}

// GetNotificationCount returns the notification count for user

+ 18
- 3
routers/user/notification.go 查看文件

@@ -68,24 +68,39 @@ func Notifications(c *context.Context) {
return
}

repos, err := notifications.LoadRepos()
failCount := 0

repos, failures, err := notifications.LoadRepos()
if err != nil {
c.ServerError("LoadRepos", err)
return
}
notifications = notifications.Without(failures)
if err := repos.LoadAttributes(); err != nil {
c.ServerError("LoadAttributes", err)
return
}
failCount += len(failures)

if err := notifications.LoadIssues(); err != nil {
failures, err = notifications.LoadIssues()
if err != nil {
c.ServerError("LoadIssues", err)
return
}
if err := notifications.LoadComments(); err != nil {
notifications = notifications.Without(failures)
failCount += len(failures)

failures, err = notifications.LoadComments()
if err != nil {
c.ServerError("LoadComments", err)
return
}
notifications = notifications.Without(failures)
failCount += len(failures)

if failCount > 0 {
c.Flash.Error(fmt.Sprintf("ERROR: %d notifications were removed due to missing parts - check the logs", failCount))
}

total, err := models.GetNotificationCount(c.User, status)
if err != nil {

Loading…
取消
儲存