aboutsummaryrefslogtreecommitdiffstats
path: root/models/auth
diff options
context:
space:
mode:
Diffstat (limited to 'models/auth')
-rw-r--r--models/auth/token.go31
-rw-r--r--models/auth/token_test.go35
-rw-r--r--models/auth/twofactor.go21
3 files changed, 45 insertions, 42 deletions
diff --git a/models/auth/token.go b/models/auth/token.go
index fed03803d5..8abcc622bc 100644
--- a/models/auth/token.go
+++ b/models/auth/token.go
@@ -5,6 +5,7 @@
package auth
import (
+ "context"
"crypto/subtle"
"encoding/hex"
"fmt"
@@ -95,7 +96,7 @@ func init() {
}
// NewAccessToken creates new access token.
-func NewAccessToken(t *AccessToken) error {
+func NewAccessToken(ctx context.Context, t *AccessToken) error {
salt, err := util.CryptoRandomString(10)
if err != nil {
return err
@@ -108,7 +109,7 @@ func NewAccessToken(t *AccessToken) error {
t.Token = hex.EncodeToString(token)
t.TokenHash = HashToken(t.Token, t.TokenSalt)
t.TokenLastEight = t.Token[len(t.Token)-8:]
- _, err = db.GetEngine(db.DefaultContext).Insert(t)
+ _, err = db.GetEngine(ctx).Insert(t)
return err
}
@@ -137,7 +138,7 @@ func getAccessTokenIDFromCache(token string) int64 {
}
// GetAccessTokenBySHA returns access token by given token value
-func GetAccessTokenBySHA(token string) (*AccessToken, error) {
+func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
if token == "" {
return nil, ErrAccessTokenEmpty{}
}
@@ -158,7 +159,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
TokenLastEight: lastEight,
}
// Re-get the token from the db in case it has been deleted in the intervening period
- has, err := db.GetEngine(db.DefaultContext).ID(id).Get(accessToken)
+ has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
if err != nil {
return nil, err
}
@@ -169,7 +170,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
}
var tokens []AccessToken
- err := db.GetEngine(db.DefaultContext).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
+ err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
if err != nil {
return nil, err
} else if len(tokens) == 0 {
@@ -189,8 +190,8 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) {
}
// AccessTokenByNameExists checks if a token name has been used already by a user.
-func AccessTokenByNameExists(token *AccessToken) (bool, error) {
- return db.GetEngine(db.DefaultContext).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
+func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
+ return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
}
// ListAccessTokensOptions contain filter options
@@ -201,8 +202,8 @@ type ListAccessTokensOptions struct {
}
// ListAccessTokens returns a list of access tokens belongs to given user.
-func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
- sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
+func ListAccessTokens(ctx context.Context, opts ListAccessTokensOptions) ([]*AccessToken, error) {
+ sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)
if len(opts.Name) != 0 {
sess = sess.Where("name=?", opts.Name)
@@ -222,14 +223,14 @@ func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) {
}
// UpdateAccessToken updates information of access token.
-func UpdateAccessToken(t *AccessToken) error {
- _, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
+func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
+ _, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}
// CountAccessTokens count access tokens belongs to given user by options
-func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) {
- sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID)
+func CountAccessTokens(ctx context.Context, opts ListAccessTokensOptions) (int64, error) {
+ sess := db.GetEngine(ctx).Where("uid=?", opts.UserID)
if len(opts.Name) != 0 {
sess = sess.Where("name=?", opts.Name)
}
@@ -237,8 +238,8 @@ func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) {
}
// DeleteAccessTokenByID deletes access token by given ID.
-func DeleteAccessTokenByID(id, userID int64) error {
- cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&AccessToken{
+func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
+ cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
UID: userID,
})
if err != nil {
diff --git a/models/auth/token_test.go b/models/auth/token_test.go
index 8a1e664950..72c937ffd6 100644
--- a/models/auth/token_test.go
+++ b/models/auth/token_test.go
@@ -7,6 +7,7 @@ import (
"testing"
auth_model "code.gitea.io/gitea/models/auth"
+ "code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
@@ -18,7 +19,7 @@ func TestNewAccessToken(t *testing.T) {
UID: 3,
Name: "Token C",
}
- assert.NoError(t, auth_model.NewAccessToken(token))
+ assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
invalidToken := &auth_model.AccessToken{
@@ -26,7 +27,7 @@ func TestNewAccessToken(t *testing.T) {
UID: 2,
Name: "Token F",
}
- assert.Error(t, auth_model.NewAccessToken(invalidToken))
+ assert.Error(t, auth_model.NewAccessToken(db.DefaultContext, invalidToken))
}
func TestAccessTokenByNameExists(t *testing.T) {
@@ -39,16 +40,16 @@ func TestAccessTokenByNameExists(t *testing.T) {
}
// Check to make sure it doesn't exists already
- exist, err := auth_model.AccessTokenByNameExists(token)
+ exist, err := auth_model.AccessTokenByNameExists(db.DefaultContext, token)
assert.NoError(t, err)
assert.False(t, exist)
// Save it to the database
- assert.NoError(t, auth_model.NewAccessToken(token))
+ assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
// This token must be found by name in the DB now
- exist, err = auth_model.AccessTokenByNameExists(token)
+ exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, token)
assert.NoError(t, err)
assert.True(t, exist)
@@ -59,32 +60,32 @@ func TestAccessTokenByNameExists(t *testing.T) {
// Name matches but different user ID, this shouldn't exists in the
// database
- exist, err = auth_model.AccessTokenByNameExists(user4Token)
+ exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, user4Token)
assert.NoError(t, err)
assert.False(t, exist)
}
func TestGetAccessTokenBySHA(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
- token, err := auth_model.GetAccessTokenBySHA("d2c6c1ba3890b309189a8e618c72a162e4efbf36")
+ token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)
assert.Equal(t, "Token A", token.Name)
assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
assert.Equal(t, "e4efbf36", token.TokenLastEight)
- _, err = auth_model.GetAccessTokenBySHA("notahash")
+ _, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "notahash")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
- _, err = auth_model.GetAccessTokenBySHA("")
+ _, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
}
func TestListAccessTokens(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
- tokens, err := auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 1})
+ tokens, err := auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 1})
assert.NoError(t, err)
if assert.Len(t, tokens, 2) {
assert.Equal(t, int64(1), tokens[0].UID)
@@ -93,39 +94,39 @@ func TestListAccessTokens(t *testing.T) {
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
}
- tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 2})
+ tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 2})
assert.NoError(t, err)
if assert.Len(t, tokens, 1) {
assert.Equal(t, int64(2), tokens[0].UID)
assert.Equal(t, "Token A", tokens[0].Name)
}
- tokens, err = auth_model.ListAccessTokens(auth_model.ListAccessTokensOptions{UserID: 100})
+ tokens, err = auth_model.ListAccessTokens(db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 100})
assert.NoError(t, err)
assert.Empty(t, tokens)
}
func TestUpdateAccessToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
- token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
+ token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
token.Name = "Token Z"
- assert.NoError(t, auth_model.UpdateAccessToken(token))
+ assert.NoError(t, auth_model.UpdateAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
}
func TestDeleteAccessTokenByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
- token, err := auth_model.GetAccessTokenBySHA("4c6f36e6cf498e2a448662f915d932c09c5a146c")
+ token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)
- assert.NoError(t, auth_model.DeleteAccessTokenByID(token.ID, 1))
+ assert.NoError(t, auth_model.DeleteAccessTokenByID(db.DefaultContext, token.ID, 1))
unittest.AssertNotExistsBean(t, token)
- err = auth_model.DeleteAccessTokenByID(100, 100)
+ err = auth_model.DeleteAccessTokenByID(db.DefaultContext, 100, 100)
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
}
diff --git a/models/auth/twofactor.go b/models/auth/twofactor.go
index 751a281f7e..51061e5205 100644
--- a/models/auth/twofactor.go
+++ b/models/auth/twofactor.go
@@ -4,6 +4,7 @@
package auth
import (
+ "context"
"crypto/md5"
"crypto/subtle"
"encoding/base32"
@@ -121,22 +122,22 @@ func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
}
// NewTwoFactor creates a new two-factor authentication token.
-func NewTwoFactor(t *TwoFactor) error {
- _, err := db.GetEngine(db.DefaultContext).Insert(t)
+func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
+ _, err := db.GetEngine(ctx).Insert(t)
return err
}
// UpdateTwoFactor updates a two-factor authentication token.
-func UpdateTwoFactor(t *TwoFactor) error {
- _, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
+func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
+ _, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}
// GetTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
-func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
+func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
twofa := &TwoFactor{}
- has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa)
+ has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
if err != nil {
return nil, err
} else if !has {
@@ -147,13 +148,13 @@ func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
// HasTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
-func HasTwoFactorByUID(uid int64) (bool, error) {
- return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{})
+func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
+ return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
}
// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
-func DeleteTwoFactorByID(id, userID int64) error {
- cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{
+func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
+ cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
UID: userID,
})
if err != nil {