diff options
Diffstat (limited to 'models/notification.go')
-rw-r--r-- | models/notification.go | 108 |
1 files changed, 45 insertions, 63 deletions
diff --git a/models/notification.go b/models/notification.go index d0b7852cd2..548362d190 100644 --- a/models/notification.go +++ b/models/notification.go @@ -119,22 +119,18 @@ func (opts *FindNotificationOptions) ToCond() builder.Cond { } // ToSession will convert the given options to a xorm Session by using the conditions from ToCond and joining with issue table if required -func (opts *FindNotificationOptions) ToSession(e db.Engine) *xorm.Session { - sess := e.Where(opts.ToCond()) +func (opts *FindNotificationOptions) ToSession(ctx context.Context) *xorm.Session { + sess := db.GetEngine(ctx).Where(opts.ToCond()) if opts.Page != 0 { sess = db.SetSessionPagination(sess, opts) } return sess } -func getNotifications(e db.Engine, options *FindNotificationOptions) (nl NotificationList, err error) { - err = options.ToSession(e).OrderBy("notification.updated_unix DESC").Find(&nl) - return -} - // GetNotifications returns all notifications that fit to the given options. -func GetNotifications(opts *FindNotificationOptions) (NotificationList, error) { - return getNotifications(db.GetEngine(db.DefaultContext), opts) +func GetNotifications(ctx context.Context, options *FindNotificationOptions) (nl NotificationList, err error) { + err = options.ToSession(ctx).OrderBy("notification.updated_unix DESC").Find(&nl) + return } // CountNotifications count all notifications that fit to the given options and ignore pagination. @@ -201,15 +197,14 @@ func CreateOrUpdateIssueNotifications(issueID, commentID, notificationAuthorID, } func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, notificationAuthorID, receiverID int64) error { - e := db.GetEngine(ctx) // init var toNotify map[int64]struct{} - notifications, err := getNotificationsByIssueID(e, issueID) + notifications, err := getNotificationsByIssueID(ctx, issueID) if err != nil { return err } - issue, err := getIssueByID(e, issueID) + issue, err := getIssueByID(ctx, issueID) if err != nil { return err } @@ -219,7 +214,7 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n toNotify[receiverID] = struct{}{} } else { toNotify = make(map[int64]struct{}, 32) - issueWatches, err := getIssueWatchersIDs(e, issueID, true) + issueWatches, err := GetIssueWatchersIDs(ctx, issueID, true) if err != nil { return err } @@ -235,7 +230,7 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n toNotify[id] = struct{}{} } } - issueParticipants, err := issue.getParticipantIDsByIssue(e) + issueParticipants, err := issue.getParticipantIDsByIssue(ctx) if err != nil { return err } @@ -246,7 +241,7 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n // dont notify user who cause notification delete(toNotify, notificationAuthorID) // explicit unwatch on issue - issueUnWatches, err := getIssueWatchersIDs(e, issueID, false) + issueUnWatches, err := GetIssueWatchersIDs(ctx, issueID, false) if err != nil { return err } @@ -263,7 +258,7 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n // notify for userID := range toNotify { issue.Repo.Units = nil - user, err := user_model.GetUserByIDEngine(e, userID) + user, err := user_model.GetUserByIDCtx(ctx, userID) if err != nil { if user_model.IsErrUserNotExist(err) { continue @@ -279,20 +274,20 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n } if notificationExists(notifications, issue.ID, userID) { - if err = updateIssueNotification(e, userID, issue.ID, commentID, notificationAuthorID); err != nil { + if err = updateIssueNotification(ctx, userID, issue.ID, commentID, notificationAuthorID); err != nil { return err } continue } - if err = createIssueNotification(e, userID, issue, commentID, notificationAuthorID); err != nil { + if err = createIssueNotification(ctx, userID, issue, commentID, notificationAuthorID); err != nil { return err } } return nil } -func getNotificationsByIssueID(e db.Engine, issueID int64) (notifications []*Notification, err error) { - err = e. +func getNotificationsByIssueID(ctx context.Context, issueID int64) (notifications []*Notification, err error) { + err = db.GetEngine(ctx). Where("issue_id = ?", issueID). Find(¬ifications) return @@ -308,7 +303,7 @@ func notificationExists(notifications []*Notification, issueID, userID int64) bo return false } -func createIssueNotification(e db.Engine, userID int64, issue *Issue, commentID, updatedByID int64) error { +func createIssueNotification(ctx context.Context, userID int64, issue *Issue, commentID, updatedByID int64) error { notification := &Notification{ UserID: userID, RepoID: issue.RepoID, @@ -324,12 +319,11 @@ func createIssueNotification(e db.Engine, userID int64, issue *Issue, commentID, notification.Source = NotificationSourceIssue } - _, err := e.Insert(notification) - return err + return db.Insert(ctx, notification) } -func updateIssueNotification(e db.Engine, userID, issueID, commentID, updatedByID int64) error { - notification, err := getIssueNotification(e, userID, issueID) +func updateIssueNotification(ctx context.Context, userID, issueID, commentID, updatedByID int64) error { + notification, err := getIssueNotification(ctx, userID, issueID) if err != nil { return err } @@ -346,13 +340,13 @@ func updateIssueNotification(e db.Engine, userID, issueID, commentID, updatedByI cols = []string{"update_by"} } - _, err = e.ID(notification.ID).Cols(cols...).Update(notification) + _, err = db.GetEngine(ctx).ID(notification.ID).Cols(cols...).Update(notification) return err } -func getIssueNotification(e db.Engine, userID, issueID int64) (*Notification, error) { +func getIssueNotification(ctx context.Context, userID, issueID int64) (*Notification, error) { notification := new(Notification) - _, err := e. + _, err := db.GetEngine(ctx). Where("user_id = ?", userID). And("issue_id = ?", issueID). Get(notification) @@ -360,16 +354,12 @@ func getIssueNotification(e db.Engine, userID, issueID int64) (*Notification, er } // NotificationsForUser returns notifications for a given user and status -func NotificationsForUser(user *user_model.User, statuses []NotificationStatus, page, perPage int) (NotificationList, error) { - return notificationsForUser(db.GetEngine(db.DefaultContext), user, statuses, page, perPage) -} - -func notificationsForUser(e db.Engine, user *user_model.User, statuses []NotificationStatus, page, perPage int) (notifications []*Notification, err error) { +func NotificationsForUser(ctx context.Context, user *user_model.User, statuses []NotificationStatus, page, perPage int) (notifications NotificationList, err error) { if len(statuses) == 0 { return } - sess := e. + sess := db.GetEngine(ctx). Where("user_id = ?", user.ID). In("status", statuses). OrderBy("updated_unix DESC") @@ -383,12 +373,8 @@ func notificationsForUser(e db.Engine, user *user_model.User, statuses []Notific } // CountUnread count unread notifications for a user -func CountUnread(user *user_model.User) int64 { - return countUnread(db.GetEngine(db.DefaultContext), user.ID) -} - -func countUnread(e db.Engine, userID int64) int64 { - exist, err := e.Where("user_id = ?", userID).And("status = ?", NotificationStatusUnread).Count(new(Notification)) +func CountUnread(ctx context.Context, userID int64) int64 { + exist, err := db.GetEngine(ctx).Where("user_id = ?", userID).And("status = ?", NotificationStatusUnread).Count(new(Notification)) if err != nil { log.Error("countUnread", err) return 0 @@ -402,17 +388,16 @@ func (n *Notification) LoadAttributes() (err error) { } func (n *Notification) loadAttributes(ctx context.Context) (err error) { - e := db.GetEngine(ctx) if err = n.loadRepo(ctx); err != nil { return } if err = n.loadIssue(ctx); err != nil { return } - if err = n.loadUser(e); err != nil { + if err = n.loadUser(ctx); err != nil { return } - if err = n.loadComment(e); err != nil { + if err = n.loadComment(ctx); err != nil { return } return @@ -430,7 +415,7 @@ func (n *Notification) loadRepo(ctx context.Context) (err error) { func (n *Notification) loadIssue(ctx context.Context) (err error) { if n.Issue == nil && n.IssueID != 0 { - n.Issue, err = getIssueByID(db.GetEngine(ctx), n.IssueID) + n.Issue, err = getIssueByID(ctx, n.IssueID) if err != nil { return fmt.Errorf("getIssueByID [%d]: %v", n.IssueID, err) } @@ -439,9 +424,9 @@ func (n *Notification) loadIssue(ctx context.Context) (err error) { return nil } -func (n *Notification) loadComment(e db.Engine) (err error) { +func (n *Notification) loadComment(ctx context.Context) (err error) { if n.Comment == nil && n.CommentID != 0 { - n.Comment, err = getCommentByID(e, n.CommentID) + n.Comment, err = GetCommentByID(ctx, n.CommentID) if err != nil { if IsErrCommentNotExist(err) { return ErrCommentNotExist{ @@ -455,9 +440,9 @@ func (n *Notification) loadComment(e db.Engine) (err error) { return nil } -func (n *Notification) loadUser(e db.Engine) (err error) { +func (n *Notification) loadUser(ctx context.Context) (err error) { if n.User == nil { - n.User, err = user_model.GetUserByIDEngine(e, n.UserID) + n.User, err = user_model.GetUserByIDCtx(ctx, n.UserID) if err != nil { return fmt.Errorf("getUserByID [%d]: %v", n.UserID, err) } @@ -739,12 +724,8 @@ func (nl NotificationList) LoadComments() ([]int, error) { } // GetNotificationCount returns the notification count for user -func GetNotificationCount(user *user_model.User, status NotificationStatus) (int64, error) { - return getNotificationCount(db.GetEngine(db.DefaultContext), user, status) -} - -func getNotificationCount(e db.Engine, user *user_model.User, status NotificationStatus) (count int64, err error) { - count, err = e. +func GetNotificationCount(ctx context.Context, user *user_model.User, status NotificationStatus) (count int64, err error) { + count, err = db.GetEngine(ctx). Where("user_id = ?", user.ID). And("status = ?", status). Count(&Notification{}) @@ -766,8 +747,8 @@ func GetUIDsAndNotificationCounts(since, until timeutil.TimeStamp) ([]UserIDCoun return res, db.GetEngine(db.DefaultContext).SQL(sql, since, until, NotificationStatusUnread).Find(&res) } -func setIssueNotificationStatusReadIfUnread(e db.Engine, userID, issueID int64) error { - notification, err := getIssueNotification(e, userID, issueID) +func setIssueNotificationStatusReadIfUnread(ctx context.Context, userID, issueID int64) error { + notification, err := getIssueNotification(ctx, userID, issueID) // ignore if not exists if err != nil { return nil @@ -779,12 +760,13 @@ func setIssueNotificationStatusReadIfUnread(e db.Engine, userID, issueID int64) notification.Status = NotificationStatusRead - _, err = e.ID(notification.ID).Update(notification) + _, err = db.GetEngine(ctx).ID(notification.ID).Update(notification) return err } -func setRepoNotificationStatusReadIfUnread(e db.Engine, userID, repoID int64) error { - _, err := e.Where(builder.Eq{ +// SetRepoReadBy sets repo to be visited by given user. +func SetRepoReadBy(ctx context.Context, userID, repoID int64) error { + _, err := db.GetEngine(ctx).Where(builder.Eq{ "user_id": userID, "status": NotificationStatusUnread, "source": NotificationSourceRepository, @@ -795,7 +777,7 @@ func setRepoNotificationStatusReadIfUnread(e db.Engine, userID, repoID int64) er // SetNotificationStatus change the notification status func SetNotificationStatus(notificationID int64, user *user_model.User, status NotificationStatus) (*Notification, error) { - notification, err := getNotificationByID(db.GetEngine(db.DefaultContext), notificationID) + notification, err := getNotificationByID(db.DefaultContext, notificationID) if err != nil { return notification, err } @@ -812,12 +794,12 @@ func SetNotificationStatus(notificationID int64, user *user_model.User, status N // GetNotificationByID return notification by ID func GetNotificationByID(notificationID int64) (*Notification, error) { - return getNotificationByID(db.GetEngine(db.DefaultContext), notificationID) + return getNotificationByID(db.DefaultContext, notificationID) } -func getNotificationByID(e db.Engine, notificationID int64) (*Notification, error) { +func getNotificationByID(ctx context.Context, notificationID int64) (*Notification, error) { notification := new(Notification) - ok, err := e. + ok, err := db.GetEngine(ctx). Where("id = ?", notificationID). Get(notification) if err != nil { |