aboutsummaryrefslogtreecommitdiffstats
path: root/models/user
diff options
context:
space:
mode:
authorJakobDev <jakobdev@gmx.de>2023-09-16 16:39:12 +0200
committerGitHub <noreply@github.com>2023-09-16 14:39:12 +0000
commitf91dbbba98c841f11d99be998ed5dd98122a457c (patch)
tree9c6c935ccf745c5a1716f1330922354809cd39e0 /models/user
parenta1b2a118123e0abd1d4737f4a6c0cf56d15eff57 (diff)
downloadgitea-f91dbbba98c841f11d99be998ed5dd98122a457c.tar.gz
gitea-f91dbbba98c841f11d99be998ed5dd98122a457c.zip
Next round of `db.DefaultContext` refactor (#27089)
Part of #27065
Diffstat (limited to 'models/user')
-rw-r--r--models/user/follow.go18
-rw-r--r--models/user/follow_test.go11
-rw-r--r--models/user/user.go2
-rw-r--r--models/user/user_test.go6
4 files changed, 20 insertions, 17 deletions
diff --git a/models/user/follow.go b/models/user/follow.go
index 7efecc26a7..f4dd2891ff 100644
--- a/models/user/follow.go
+++ b/models/user/follow.go
@@ -4,6 +4,8 @@
package user
import (
+ "context"
+
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
)
@@ -21,18 +23,18 @@ func init() {
}
// IsFollowing returns true if user is following followID.
-func IsFollowing(userID, followID int64) bool {
- has, _ := db.GetEngine(db.DefaultContext).Get(&Follow{UserID: userID, FollowID: followID})
+func IsFollowing(ctx context.Context, userID, followID int64) bool {
+ has, _ := db.GetEngine(ctx).Get(&Follow{UserID: userID, FollowID: followID})
return has
}
// FollowUser marks someone be another's follower.
-func FollowUser(userID, followID int64) (err error) {
- if userID == followID || IsFollowing(userID, followID) {
+func FollowUser(ctx context.Context, userID, followID int64) (err error) {
+ if userID == followID || IsFollowing(ctx, userID, followID) {
return nil
}
- ctx, committer, err := db.TxContext(db.DefaultContext)
+ ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@@ -53,12 +55,12 @@ func FollowUser(userID, followID int64) (err error) {
}
// UnfollowUser unmarks someone as another's follower.
-func UnfollowUser(userID, followID int64) (err error) {
- if userID == followID || !IsFollowing(userID, followID) {
+func UnfollowUser(ctx context.Context, userID, followID int64) (err error) {
+ if userID == followID || !IsFollowing(ctx, userID, followID) {
return nil
}
- ctx, committer, err := db.TxContext(db.DefaultContext)
+ ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
diff --git a/models/user/follow_test.go b/models/user/follow_test.go
index fc408d5257..c327d935ae 100644
--- a/models/user/follow_test.go
+++ b/models/user/follow_test.go
@@ -6,6 +6,7 @@ package user_test
import (
"testing"
+ "code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
@@ -14,9 +15,9 @@ import (
func TestIsFollowing(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
- assert.True(t, user_model.IsFollowing(4, 2))
- assert.False(t, user_model.IsFollowing(2, 4))
- assert.False(t, user_model.IsFollowing(5, unittest.NonexistentID))
- assert.False(t, user_model.IsFollowing(unittest.NonexistentID, 5))
- assert.False(t, user_model.IsFollowing(unittest.NonexistentID, unittest.NonexistentID))
+ assert.True(t, user_model.IsFollowing(db.DefaultContext, 4, 2))
+ assert.False(t, user_model.IsFollowing(db.DefaultContext, 2, 4))
+ assert.False(t, user_model.IsFollowing(db.DefaultContext, 5, unittest.NonexistentID))
+ assert.False(t, user_model.IsFollowing(db.DefaultContext, unittest.NonexistentID, 5))
+ assert.False(t, user_model.IsFollowing(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID))
}
diff --git a/models/user/user.go b/models/user/user.go
index b3956da1cb..63b95816ce 100644
--- a/models/user/user.go
+++ b/models/user/user.go
@@ -1246,7 +1246,7 @@ func IsUserVisibleToViewer(ctx context.Context, u, viewer *User) bool {
}
// If they follow - they see each over
- follower := IsFollowing(u.ID, viewer.ID)
+ follower := IsFollowing(ctx, u.ID, viewer.ID)
if follower {
return true
}
diff --git a/models/user/user_test.go b/models/user/user_test.go
index b15f0cbc59..971117482c 100644
--- a/models/user/user_test.go
+++ b/models/user/user_test.go
@@ -449,13 +449,13 @@ func TestFollowUser(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
testSuccess := func(followerID, followedID int64) {
- assert.NoError(t, user_model.FollowUser(followerID, followedID))
+ assert.NoError(t, user_model.FollowUser(db.DefaultContext, followerID, followedID))
unittest.AssertExistsAndLoadBean(t, &user_model.Follow{UserID: followerID, FollowID: followedID})
}
testSuccess(4, 2)
testSuccess(5, 2)
- assert.NoError(t, user_model.FollowUser(2, 2))
+ assert.NoError(t, user_model.FollowUser(db.DefaultContext, 2, 2))
unittest.CheckConsistencyFor(t, &user_model.User{})
}
@@ -464,7 +464,7 @@ func TestUnfollowUser(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
testSuccess := func(followerID, followedID int64) {
- assert.NoError(t, user_model.UnfollowUser(followerID, followedID))
+ assert.NoError(t, user_model.UnfollowUser(db.DefaultContext, followerID, followedID))
unittest.AssertNotExistsBean(t, &user_model.Follow{UserID: followerID, FollowID: followedID})
}
testSuccess(4, 2)