diff options
author | JakobDev <jakobdev@gmx.de> | 2023-09-15 08:13:19 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-15 06:13:19 +0000 |
commit | c548dde205244a39a26ba98377c0f5cc11da7041 (patch) | |
tree | f9d9e1185609703e320ed07fd2ff3f95dbdcc230 /models/auth | |
parent | f8a109440655b77e8554e1744e31bf52a7c63df7 (diff) | |
download | gitea-c548dde205244a39a26ba98377c0f5cc11da7041.tar.gz gitea-c548dde205244a39a26ba98377c0f5cc11da7041.zip |
More refactoring of `db.DefaultContext` (#27083)
Next step of #27065
Diffstat (limited to 'models/auth')
-rw-r--r-- | models/auth/token.go | 31 | ||||
-rw-r--r-- | models/auth/token_test.go | 35 | ||||
-rw-r--r-- | models/auth/twofactor.go | 21 |
3 files changed, 45 insertions, 42 deletions
diff --git a/models/auth/token.go b/models/auth/token.go index fed03803d5..8abcc622bc 100644 --- a/models/auth/token.go +++ b/models/auth/token.go @@ -5,6 +5,7 @@ package auth import ( + "context" "crypto/subtle" "encoding/hex" "fmt" @@ -95,7 +96,7 @@ func init() { } // NewAccessToken creates new access token. -func NewAccessToken(t *AccessToken) error { +func NewAccessToken(ctx context.Context, t *AccessToken) error { salt, err := util.CryptoRandomString(10) if err != nil { return err @@ -108,7 +109,7 @@ func NewAccessToken(t *AccessToken) error { t.Token = hex.EncodeToString(token) t.TokenHash = HashToken(t.Token, t.TokenSalt) t.TokenLastEight = t.Token[len(t.Token)-8:] - _, err = db.GetEngine(db.DefaultContext).Insert(t) + _, err = db.GetEngine(ctx).Insert(t) return err } @@ -137,7 +138,7 @@ func getAccessTokenIDFromCache(token string) int64 { } // GetAccessTokenBySHA returns access token by given token value -func GetAccessTokenBySHA(token string) (*AccessToken, error) { +func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) { if token == "" { return nil, ErrAccessTokenEmpty{} } @@ -158,7 +159,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) { TokenLastEight: lastEight, } // Re-get the token from the db in case it has been deleted in the intervening period - has, err := db.GetEngine(db.DefaultContext).ID(id).Get(accessToken) + has, err := db.GetEngine(ctx).ID(id).Get(accessToken) if err != nil { return nil, err } @@ -169,7 +170,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) { } var tokens []AccessToken - err := db.GetEngine(db.DefaultContext).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens) + err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens) if err != nil { return nil, err } else if len(tokens) == 0 { @@ -189,8 +190,8 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) { } // AccessTokenByNameExists checks if a token name has been used already by a user. -func AccessTokenByNameExists(token *AccessToken) (bool, error) { - return db.GetEngine(db.DefaultContext).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist() +func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) { + return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist() } // ListAccessTokensOptions contain filter options @@ -201,8 +202,8 @@ type ListAccessTokensOptions struct { } // ListAccessTokens returns a list of access tokens belongs to given user. -func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) { - sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID) +func ListAccessTokens(ctx context.Context, opts ListAccessTokensOptions) ([]*AccessToken, error) { + sess := db.GetEngine(ctx).Where("uid=?", opts.UserID) if len(opts.Name) != 0 { sess = sess.Where("name=?", opts.Name) @@ -222,14 +223,14 @@ func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) { } // UpdateAccessToken updates information of access token. -func UpdateAccessToken(t *AccessToken) error { - _, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t) +func UpdateAccessToken(ctx context.Context, t *AccessToken) error { + _, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t) return err } // CountAccessTokens count access tokens belongs to given user by options -func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) { - sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID) +func CountAccessTokens(ctx context.Context, opts ListAccessTokensOptions) (int64, error) { + sess := db.GetEngine(ctx).Where("uid=?", opts.UserID) if len(opts.Name) != 0 { sess = sess.Where("name=?", opts.Name) } @@ -237,8 +238,8 @@ func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) { } // DeleteAccessTokenByID deletes access token by given ID. -func DeleteAccessTokenByID(id, userID int64) error { - cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&AccessToken{ +func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error { + cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{ UID: userID, }) if err != nil { diff --git a/models/auth/token_test.go b/models/auth/token_test.go index 8a1e664950..72c937ffd6 100644 --- a/models/auth/token_test.go +++ b/models/auth/token_test.go @@ -7,6 +7,7 @@ import ( "testing" auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" "github.com/stretchr/testify/assert" @@ -18,7 +19,7 @@ func TestNewAccessToken(t *testing.T) { UID: 3, Name: "Token C", } - assert.NoError(t, auth_model.NewAccessToken(token)) + assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token)) unittest.AssertExistsAndLoadBean(t, token) invalidToken := &auth_model.AccessToken{ @@ -26,7 +27,7 @@ func TestNewAccessToken(t *testing.T) { UID: 2, Name: "Token F", } - assert.Error(t, auth_model.NewAccessToken(invalidToken)) + assert.Error(t, auth_model.NewAccessToken(db.DefaultContext, invalidToken)) } func TestAccessTokenByNameExists(t *testing.T) { @@ -39,16 +40,16 @@ func TestAccessTokenByNameExists(t *testing.T) { } // Check to make sure it doesn't exists already - exist, err := auth_model.AccessTokenByNameExists(token) + exist, err := auth_model.AccessTokenByNameExists(db.DefaultContext, token) assert.NoError(t, err) assert.False(t, exist) // Save it to the database - assert.NoError(t, auth_model.NewAccessToken(token)) + assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token)) unittest.AssertExistsAndLoadBean(t, token) // This token must be found by name in the DB now - exist, err = auth_model.AccessTokenByNameExists(token) + exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, token) assert.NoError(t, err) assert.True(t, exist) @@ -59,32 +60,32 @@ func TestAccessTokenByNameExists(t *testing.T) { // Name matches but different user ID, this shouldn't exists in the // database - exist, err = auth_model.AccessTokenByNameExists(user4Token) + exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, user4Token) assert.NoError(t, err) assert.False(t, exist) } func TestGetAccessTokenBySHA(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - token, err := auth_model.GetAccessTokenBySHA("d2c6c1ba3890b309189a8e618c72a162e4efbf36") + token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "d2c6c1ba3890b309189a8e618c72a162e4efbf36") assert.NoError(t, err) assert.Equal(t, int64(1), token.UID) assert.Equal(t, "Token A", token.Name) assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash) assert.Equal(t, "e4efbf36", token.TokenLastEight) - _, err = auth_model.GetAccessTokenBySHA("notahash") + _, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "notahash") assert.Error(t, err) assert.True(t, auth_model.IsErrAccessTokenNotExist(err)) - _, err = auth_model.GetAccessTokenBySHA("") + _, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "") assert.Error(t, err) assert.True(t, auth_model.IsErrAccessTokenEmpty(err)) } func TestListAccessTokens(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - tokens, err := auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 1}) + tokens, err := auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 1}) assert.NoError(t, err) if assert.Len(t, tokens, 2) { assert.Equal(t, int64(1), tokens[0].UID) @@ -93,39 +94,39 @@ func TestListAccessTokens(t *testing.T) { assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B") } - tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 2}) + tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 2}) assert.NoError(t, err) if assert.Len(t, tokens, 1) { assert.Equal(t, int64(2), tokens[0].UID) assert.Equal(t, "Token A", tokens[0].Name) } - tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 100}) + tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 100}) assert.NoError(t, err) assert.Empty(t, tokens) } func TestUpdateAccessToken(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c") + token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c") assert.NoError(t, err) token.Name = "Token Z" - assert.NoError(t, auth_model.UpdateAccessToken(token)) + assert.NoError(t, auth_model.UpdateAccessToken(db.DefaultContext, token)) unittest.AssertExistsAndLoadBean(t, token) } func TestDeleteAccessTokenByID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c") + token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c") assert.NoError(t, err) assert.Equal(t, int64(1), token.UID) - assert.NoError(t, auth_model.DeleteAccessTokenByID(token.ID, 1)) + assert.NoError(t, auth_model.DeleteAccessTokenByID(db.DefaultContext, token.ID, 1)) unittest.AssertNotExistsBean(t, token) - err = auth_model.DeleteAccessTokenByID(100, 100) + err = auth_model.DeleteAccessTokenByID(db.DefaultContext, 100, 100) assert.Error(t, err) assert.True(t, auth_model.IsErrAccessTokenNotExist(err)) } diff --git a/models/auth/twofactor.go b/models/auth/twofactor.go index 751a281f7e..51061e5205 100644 --- a/models/auth/twofactor.go +++ b/models/auth/twofactor.go @@ -4,6 +4,7 @@ package auth import ( + "context" "crypto/md5" "crypto/subtle" "encoding/base32" @@ -121,22 +122,22 @@ func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) { } // NewTwoFactor creates a new two-factor authentication token. -func NewTwoFactor(t *TwoFactor) error { - _, err := db.GetEngine(db.DefaultContext).Insert(t) +func NewTwoFactor(ctx context.Context, t *TwoFactor) error { + _, err := db.GetEngine(ctx).Insert(t) return err } // UpdateTwoFactor updates a two-factor authentication token. -func UpdateTwoFactor(t *TwoFactor) error { - _, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t) +func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error { + _, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t) return err } // GetTwoFactorByUID returns the two-factor authentication token associated with // the user, if any. -func GetTwoFactorByUID(uid int64) (*TwoFactor, error) { +func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) { twofa := &TwoFactor{} - has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa) + has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa) if err != nil { return nil, err } else if !has { @@ -147,13 +148,13 @@ func GetTwoFactorByUID(uid int64) (*TwoFactor, error) { // HasTwoFactorByUID returns the two-factor authentication token associated with // the user, if any. -func HasTwoFactorByUID(uid int64) (bool, error) { - return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{}) +func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) { + return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{}) } // DeleteTwoFactorByID deletes two-factor authentication token by given ID. -func DeleteTwoFactorByID(id, userID int64) error { - cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{ +func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error { + cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{ UID: userID, }) if err != nil { |