diff options
author | JakobDev <jakobdev@gmx.de> | 2023-09-16 16:39:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-16 14:39:12 +0000 |
commit | f91dbbba98c841f11d99be998ed5dd98122a457c (patch) | |
tree | 9c6c935ccf745c5a1716f1330922354809cd39e0 /models/user | |
parent | a1b2a118123e0abd1d4737f4a6c0cf56d15eff57 (diff) | |
download | gitea-f91dbbba98c841f11d99be998ed5dd98122a457c.tar.gz gitea-f91dbbba98c841f11d99be998ed5dd98122a457c.zip |
Next round of `db.DefaultContext` refactor (#27089)
Part of #27065
Diffstat (limited to 'models/user')
-rw-r--r-- | models/user/follow.go | 18 | ||||
-rw-r--r-- | models/user/follow_test.go | 11 | ||||
-rw-r--r-- | models/user/user.go | 2 | ||||
-rw-r--r-- | models/user/user_test.go | 6 |
4 files changed, 20 insertions, 17 deletions
diff --git a/models/user/follow.go b/models/user/follow.go index 7efecc26a7..f4dd2891ff 100644 --- a/models/user/follow.go +++ b/models/user/follow.go @@ -4,6 +4,8 @@ package user import ( + "context" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/timeutil" ) @@ -21,18 +23,18 @@ func init() { } // IsFollowing returns true if user is following followID. -func IsFollowing(userID, followID int64) bool { - has, _ := db.GetEngine(db.DefaultContext).Get(&Follow{UserID: userID, FollowID: followID}) +func IsFollowing(ctx context.Context, userID, followID int64) bool { + has, _ := db.GetEngine(ctx).Get(&Follow{UserID: userID, FollowID: followID}) return has } // FollowUser marks someone be another's follower. -func FollowUser(userID, followID int64) (err error) { - if userID == followID || IsFollowing(userID, followID) { +func FollowUser(ctx context.Context, userID, followID int64) (err error) { + if userID == followID || IsFollowing(ctx, userID, followID) { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -53,12 +55,12 @@ func FollowUser(userID, followID int64) (err error) { } // UnfollowUser unmarks someone as another's follower. -func UnfollowUser(userID, followID int64) (err error) { - if userID == followID || !IsFollowing(userID, followID) { +func UnfollowUser(ctx context.Context, userID, followID int64) (err error) { + if userID == followID || !IsFollowing(ctx, userID, followID) { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/user/follow_test.go b/models/user/follow_test.go index fc408d5257..c327d935ae 100644 --- a/models/user/follow_test.go +++ b/models/user/follow_test.go @@ -6,6 +6,7 @@ package user_test import ( "testing" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -14,9 +15,9 @@ import ( func TestIsFollowing(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.True(t, user_model.IsFollowing(4, 2)) - assert.False(t, user_model.IsFollowing(2, 4)) - assert.False(t, user_model.IsFollowing(5, unittest.NonexistentID)) - assert.False(t, user_model.IsFollowing(unittest.NonexistentID, 5)) - assert.False(t, user_model.IsFollowing(unittest.NonexistentID, unittest.NonexistentID)) + assert.True(t, user_model.IsFollowing(db.DefaultContext, 4, 2)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, 2, 4)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, 5, unittest.NonexistentID)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, unittest.NonexistentID, 5)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) } diff --git a/models/user/user.go b/models/user/user.go index b3956da1cb..63b95816ce 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -1246,7 +1246,7 @@ func IsUserVisibleToViewer(ctx context.Context, u, viewer *User) bool { } // If they follow - they see each over - follower := IsFollowing(u.ID, viewer.ID) + follower := IsFollowing(ctx, u.ID, viewer.ID) if follower { return true } diff --git a/models/user/user_test.go b/models/user/user_test.go index b15f0cbc59..971117482c 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -449,13 +449,13 @@ func TestFollowUser(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(followerID, followedID int64) { - assert.NoError(t, user_model.FollowUser(followerID, followedID)) + assert.NoError(t, user_model.FollowUser(db.DefaultContext, followerID, followedID)) unittest.AssertExistsAndLoadBean(t, &user_model.Follow{UserID: followerID, FollowID: followedID}) } testSuccess(4, 2) testSuccess(5, 2) - assert.NoError(t, user_model.FollowUser(2, 2)) + assert.NoError(t, user_model.FollowUser(db.DefaultContext, 2, 2)) unittest.CheckConsistencyFor(t, &user_model.User{}) } @@ -464,7 +464,7 @@ func TestUnfollowUser(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(followerID, followedID int64) { - assert.NoError(t, user_model.UnfollowUser(followerID, followedID)) + assert.NoError(t, user_model.UnfollowUser(db.DefaultContext, followerID, followedID)) unittest.AssertNotExistsBean(t, &user_model.Follow{UserID: followerID, FollowID: followedID}) } testSuccess(4, 2) |