diff options
author | Lunny Xiao <xiaolunwen@gmail.com> | 2022-01-02 21:12:35 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-02 21:12:35 +0800 |
commit | de8e3948a5e38f7eaf82d3c0cfd10e995bf68e92 (patch) | |
tree | bbcb011d264e0d614d49c734856b446360c5a4a3 /models/auth | |
parent | e61b390d545919244141b699b28e3fbc42adc66f (diff) | |
download | gitea-de8e3948a5e38f7eaf82d3c0cfd10e995bf68e92.tar.gz gitea-de8e3948a5e38f7eaf82d3c0cfd10e995bf68e92.zip |
Refactor auth package (#17962)
Diffstat (limited to 'models/auth')
-rw-r--r-- | models/auth/main_test.go | 22 | ||||
-rw-r--r-- | models/auth/oauth2.go | 564 | ||||
-rw-r--r-- | models/auth/oauth2_test.go | 233 | ||||
-rw-r--r-- | models/auth/session.go | 126 | ||||
-rw-r--r-- | models/auth/source.go | 397 | ||||
-rw-r--r-- | models/auth/source_test.go | 60 | ||||
-rw-r--r-- | models/auth/twofactor.go | 156 | ||||
-rw-r--r-- | models/auth/u2f.go | 154 | ||||
-rw-r--r-- | models/auth/u2f_test.go | 100 |
9 files changed, 1812 insertions, 0 deletions
diff --git a/models/auth/main_test.go b/models/auth/main_test.go new file mode 100644 index 0000000000..94a1f405d9 --- /dev/null +++ b/models/auth/main_test.go @@ -0,0 +1,22 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "path/filepath" + "testing" + + "code.gitea.io/gitea/models/unittest" +) + +func TestMain(m *testing.M) { + unittest.MainTest(m, filepath.Join("..", ".."), + "login_source.yml", + "oauth2_application.yml", + "oauth2_authorization_code.yml", + "oauth2_grant.yml", + "u2f_registration.yml", + ) +} diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go new file mode 100644 index 0000000000..e7030fce28 --- /dev/null +++ b/models/auth/oauth2.go @@ -0,0 +1,564 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" + "net/url" + "strings" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/secret" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/modules/util" + + uuid "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" + "xorm.io/xorm" +) + +// OAuth2Application represents an OAuth2 client (RFC 6749) +type OAuth2Application struct { + ID int64 `xorm:"pk autoincr"` + UID int64 `xorm:"INDEX"` + Name string + ClientID string `xorm:"unique"` + ClientSecret string + RedirectURIs []string `xorm:"redirect_uris JSON TEXT"` + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` +} + +func init() { + db.RegisterModel(new(OAuth2Application)) + db.RegisterModel(new(OAuth2AuthorizationCode)) + db.RegisterModel(new(OAuth2Grant)) +} + +// TableName sets the table name to `oauth2_application` +func (app *OAuth2Application) TableName() string { + return "oauth2_application" +} + +// PrimaryRedirectURI returns the first redirect uri or an empty string if empty +func (app *OAuth2Application) PrimaryRedirectURI() string { + if len(app.RedirectURIs) == 0 { + return "" + } + return app.RedirectURIs[0] +} + +// ContainsRedirectURI checks if redirectURI is allowed for app +func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool { + return util.IsStringInSlice(redirectURI, app.RedirectURIs, true) +} + +// GenerateClientSecret will generate the client secret and returns the plaintext and saves the hash at the database +func (app *OAuth2Application) GenerateClientSecret() (string, error) { + clientSecret, err := secret.New() + if err != nil { + return "", err + } + hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost) + if err != nil { + return "", err + } + app.ClientSecret = string(hashedSecret) + if _, err := db.GetEngine(db.DefaultContext).ID(app.ID).Cols("client_secret").Update(app); err != nil { + return "", err + } + return clientSecret, nil +} + +// ValidateClientSecret validates the given secret by the hash saved in database +func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool { + return bcrypt.CompareHashAndPassword([]byte(app.ClientSecret), secret) == nil +} + +// GetGrantByUserID returns a OAuth2Grant by its user and application ID +func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) { + return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID) +} + +func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) { + grant = new(OAuth2Grant) + if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil { + return nil, err + } else if !has { + return nil, nil + } + return grant, nil +} + +// CreateGrant generates a grant for an user +func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) { + return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope) +} + +func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) { + grant := &OAuth2Grant{ + ApplicationID: app.ID, + UserID: userID, + Scope: scope, + } + _, err := e.Insert(grant) + if err != nil { + return nil, err + } + return grant, nil +} + +// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found. +func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) { + return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID) +} + +func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) { + app = new(OAuth2Application) + has, err := e.Where("client_id = ?", clientID).Get(app) + if !has { + return nil, ErrOAuthClientIDInvalid{ClientID: clientID} + } + return +} + +// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found. +func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) { + return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id) +} + +func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) { + app = new(OAuth2Application) + has, err := e.ID(id).Get(app) + if err != nil { + return nil, err + } + if !has { + return nil, ErrOAuthApplicationNotFound{ID: id} + } + return app, nil +} + +// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user +func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) { + return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID) +} + +func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) { + apps = make([]*OAuth2Application, 0) + err = e.Where("uid = ?", userID).Find(&apps) + return +} + +// CreateOAuth2ApplicationOptions holds options to create an oauth2 application +type CreateOAuth2ApplicationOptions struct { + Name string + UserID int64 + RedirectURIs []string +} + +// CreateOAuth2Application inserts a new oauth2 application +func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { + return createOAuth2Application(db.GetEngine(db.DefaultContext), opts) +} + +func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { + clientID := uuid.New().String() + app := &OAuth2Application{ + UID: opts.UserID, + Name: opts.Name, + ClientID: clientID, + RedirectURIs: opts.RedirectURIs, + } + if _, err := e.Insert(app); err != nil { + return nil, err + } + return app, nil +} + +// UpdateOAuth2ApplicationOptions holds options to update an oauth2 application +type UpdateOAuth2ApplicationOptions struct { + ID int64 + Name string + UserID int64 + RedirectURIs []string +} + +// UpdateOAuth2Application updates an oauth2 application +func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) { + ctx, committer, err := db.TxContext() + if err != nil { + return nil, err + } + defer committer.Close() + sess := db.GetEngine(ctx) + + app, err := getOAuth2ApplicationByID(sess, opts.ID) + if err != nil { + return nil, err + } + if app.UID != opts.UserID { + return nil, fmt.Errorf("UID mismatch") + } + + app.Name = opts.Name + app.RedirectURIs = opts.RedirectURIs + + if err = updateOAuth2Application(sess, app); err != nil { + return nil, err + } + app.ClientSecret = "" + + return app, committer.Commit() +} + +func updateOAuth2Application(e db.Engine, app *OAuth2Application) error { + if _, err := e.ID(app.ID).Update(app); err != nil { + return err + } + return nil +} + +func deleteOAuth2Application(sess db.Engine, id, userid int64) error { + if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil { + return err + } else if deleted == 0 { + return ErrOAuthApplicationNotFound{ID: id} + } + codes := make([]*OAuth2AuthorizationCode, 0) + // delete correlating auth codes + if err := sess.Join("INNER", "oauth2_grant", + "oauth2_authorization_code.grant_id = oauth2_grant.id AND oauth2_grant.application_id = ?", id).Find(&codes); err != nil { + return err + } + codeIDs := make([]int64, 0) + for _, grant := range codes { + codeIDs = append(codeIDs, grant.ID) + } + + if _, err := sess.In("id", codeIDs).Delete(new(OAuth2AuthorizationCode)); err != nil { + return err + } + + if _, err := sess.Where("application_id = ?", id).Delete(new(OAuth2Grant)); err != nil { + return err + } + return nil +} + +// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app. +func DeleteOAuth2Application(id, userid int64) error { + ctx, committer, err := db.TxContext() + if err != nil { + return err + } + defer committer.Close() + if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil { + return err + } + return committer.Commit() +} + +// ListOAuth2Applications returns a list of oauth2 applications belongs to given user. +func ListOAuth2Applications(uid int64, listOptions db.ListOptions) ([]*OAuth2Application, int64, error) { + sess := db.GetEngine(db.DefaultContext). + Where("uid=?", uid). + Desc("id") + + if listOptions.Page != 0 { + sess = db.SetSessionPagination(sess, &listOptions) + + apps := make([]*OAuth2Application, 0, listOptions.PageSize) + total, err := sess.FindAndCount(&apps) + return apps, total, err + } + + apps := make([]*OAuth2Application, 0, 5) + total, err := sess.FindAndCount(&apps) + return apps, total, err +} + +////////////////////////////////////////////////////// + +// OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime. +type OAuth2AuthorizationCode struct { + ID int64 `xorm:"pk autoincr"` + Grant *OAuth2Grant `xorm:"-"` + GrantID int64 + Code string `xorm:"INDEX unique"` + CodeChallenge string + CodeChallengeMethod string + RedirectURI string + ValidUntil timeutil.TimeStamp `xorm:"index"` +} + +// TableName sets the table name to `oauth2_authorization_code` +func (code *OAuth2AuthorizationCode) TableName() string { + return "oauth2_authorization_code" +} + +// GenerateRedirectURI generates a redirect URI for a successful authorization request. State will be used if not empty. +func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect *url.URL, err error) { + if redirect, err = url.Parse(code.RedirectURI); err != nil { + return + } + q := redirect.Query() + if state != "" { + q.Set("state", state) + } + q.Set("code", code.Code) + redirect.RawQuery = q.Encode() + return +} + +// Invalidate deletes the auth code from the database to invalidate this code +func (code *OAuth2AuthorizationCode) Invalidate() error { + return code.invalidate(db.GetEngine(db.DefaultContext)) +} + +func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error { + _, err := e.Delete(code) + return err +} + +// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation. +func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool { + return code.validateCodeChallenge(verifier) +} + +func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool { + switch code.CodeChallengeMethod { + case "S256": + // base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6 + h := sha256.Sum256([]byte(verifier)) + hashedVerifier := base64.RawURLEncoding.EncodeToString(h[:]) + return hashedVerifier == code.CodeChallenge + case "plain": + return verifier == code.CodeChallenge + case "": + return true + default: + // unsupported method -> return false + return false + } +} + +// GetOAuth2AuthorizationByCode returns an authorization by its code +func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) { + return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code) +} + +func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) { + auth = new(OAuth2AuthorizationCode) + if has, err := e.Where("code = ?", code).Get(auth); err != nil { + return nil, err + } else if !has { + return nil, nil + } + auth.Grant = new(OAuth2Grant) + if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil { + return nil, err + } else if !has { + return nil, nil + } + return auth, nil +} + +////////////////////////////////////////////////////// + +// OAuth2Grant represents the permission of an user for a specific application to access resources +type OAuth2Grant struct { + ID int64 `xorm:"pk autoincr"` + UserID int64 `xorm:"INDEX unique(user_application)"` + Application *OAuth2Application `xorm:"-"` + ApplicationID int64 `xorm:"INDEX unique(user_application)"` + Counter int64 `xorm:"NOT NULL DEFAULT 1"` + Scope string `xorm:"TEXT"` + Nonce string `xorm:"TEXT"` + CreatedUnix timeutil.TimeStamp `xorm:"created"` + UpdatedUnix timeutil.TimeStamp `xorm:"updated"` +} + +// TableName sets the table name to `oauth2_grant` +func (grant *OAuth2Grant) TableName() string { + return "oauth2_grant" +} + +// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database +func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) { + return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod) +} + +func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) { + var codeSecret string + if codeSecret, err = secret.New(); err != nil { + return &OAuth2AuthorizationCode{}, err + } + code = &OAuth2AuthorizationCode{ + Grant: grant, + GrantID: grant.ID, + RedirectURI: redirectURI, + Code: codeSecret, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + } + if _, err := e.Insert(code); err != nil { + return nil, err + } + return code, nil +} + +// IncreaseCounter increases the counter and updates the grant +func (grant *OAuth2Grant) IncreaseCounter() error { + return grant.increaseCount(db.GetEngine(db.DefaultContext)) +} + +func (grant *OAuth2Grant) increaseCount(e db.Engine) error { + _, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant)) + if err != nil { + return err + } + updatedGrant, err := getOAuth2GrantByID(e, grant.ID) + if err != nil { + return err + } + grant.Counter = updatedGrant.Counter + return nil +} + +// ScopeContains returns true if the grant scope contains the specified scope +func (grant *OAuth2Grant) ScopeContains(scope string) bool { + for _, currentScope := range strings.Split(grant.Scope, " ") { + if scope == currentScope { + return true + } + } + return false +} + +// SetNonce updates the current nonce value of a grant +func (grant *OAuth2Grant) SetNonce(nonce string) error { + return grant.setNonce(db.GetEngine(db.DefaultContext), nonce) +} + +func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error { + grant.Nonce = nonce + _, err := e.ID(grant.ID).Cols("nonce").Update(grant) + if err != nil { + return err + } + return nil +} + +// GetOAuth2GrantByID returns the grant with the given ID +func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) { + return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id) +} + +func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) { + grant = new(OAuth2Grant) + if has, err := e.ID(id).Get(grant); err != nil { + return nil, err + } else if !has { + return nil, nil + } + return +} + +// GetOAuth2GrantsByUserID lists all grants of a certain user +func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) { + return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid) +} + +func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) { + type joinedOAuth2Grant struct { + Grant *OAuth2Grant `xorm:"extends"` + Application *OAuth2Application `xorm:"extends"` + } + var results *xorm.Rows + var err error + if results, err = e. + Table("oauth2_grant"). + Where("user_id = ?", uid). + Join("INNER", "oauth2_application", "application_id = oauth2_application.id"). + Rows(new(joinedOAuth2Grant)); err != nil { + return nil, err + } + defer results.Close() + grants := make([]*OAuth2Grant, 0) + for results.Next() { + joinedGrant := new(joinedOAuth2Grant) + if err := results.Scan(joinedGrant); err != nil { + return nil, err + } + joinedGrant.Grant.Application = joinedGrant.Application + grants = append(grants, joinedGrant.Grant) + } + return grants, nil +} + +// RevokeOAuth2Grant deletes the grant with grantID and userID +func RevokeOAuth2Grant(grantID, userID int64) error { + return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID) +} + +func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error { + _, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID}) + return err +} + +// ErrOAuthClientIDInvalid will be thrown if client id cannot be found +type ErrOAuthClientIDInvalid struct { + ClientID string +} + +// IsErrOauthClientIDInvalid checks if an error is a ErrReviewNotExist. +func IsErrOauthClientIDInvalid(err error) bool { + _, ok := err.(ErrOAuthClientIDInvalid) + return ok +} + +// Error returns the error message +func (err ErrOAuthClientIDInvalid) Error() string { + return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID) +} + +// ErrOAuthApplicationNotFound will be thrown if id cannot be found +type ErrOAuthApplicationNotFound struct { + ID int64 +} + +// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist. +func IsErrOAuthApplicationNotFound(err error) bool { + _, ok := err.(ErrOAuthApplicationNotFound) + return ok +} + +// Error returns the error message +func (err ErrOAuthApplicationNotFound) Error() string { + return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID) +} + +// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources +func GetActiveOAuth2ProviderSources() ([]*Source, error) { + sources := make([]*Source, 0, 1) + if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil { + return nil, err + } + return sources, nil +} + +// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name +func GetActiveOAuth2SourceByName(name string) (*Source, error) { + authSource := new(Source) + has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource) + if !has || err != nil { + return nil, err + } + + return authSource, nil +} diff --git a/models/auth/oauth2_test.go b/models/auth/oauth2_test.go new file mode 100644 index 0000000000..b712fc285f --- /dev/null +++ b/models/auth/oauth2_test.go @@ -0,0 +1,233 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "testing" + + "code.gitea.io/gitea/models/unittest" + + "github.com/stretchr/testify/assert" +) + +//////////////////// Application + +func TestOAuth2Application_GenerateClientSecret(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application) + secret, err := app.GenerateClientSecret() + assert.NoError(t, err) + assert.True(t, len(secret) > 0) + unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1, ClientSecret: app.ClientSecret}) +} + +func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) { + assert.NoError(b, unittest.PrepareTestDatabase()) + app := unittest.AssertExistsAndLoadBean(b, &OAuth2Application{ID: 1}).(*OAuth2Application) + for i := 0; i < b.N; i++ { + _, _ = app.GenerateClientSecret() + } +} + +func TestOAuth2Application_ContainsRedirectURI(t *testing.T) { + app := &OAuth2Application{ + RedirectURIs: []string{"a", "b", "c"}, + } + assert.True(t, app.ContainsRedirectURI("a")) + assert.True(t, app.ContainsRedirectURI("b")) + assert.True(t, app.ContainsRedirectURI("c")) + assert.False(t, app.ContainsRedirectURI("d")) +} + +func TestOAuth2Application_ValidateClientSecret(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application) + secret, err := app.GenerateClientSecret() + assert.NoError(t, err) + assert.True(t, app.ValidateClientSecret([]byte(secret))) + assert.False(t, app.ValidateClientSecret([]byte("fewijfowejgfiowjeoifew"))) +} + +func TestGetOAuth2ApplicationByClientID(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + app, err := GetOAuth2ApplicationByClientID("da7da3ba-9a13-4167-856f-3899de0b0138") + assert.NoError(t, err) + assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID) + + app, err = GetOAuth2ApplicationByClientID("invalid client id") + assert.Error(t, err) + assert.Nil(t, app) +} + +func TestCreateOAuth2Application(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + app, err := CreateOAuth2Application(CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1}) + assert.NoError(t, err) + assert.Equal(t, "newapp", app.Name) + assert.Len(t, app.ClientID, 36) + unittest.AssertExistsAndLoadBean(t, &OAuth2Application{Name: "newapp"}) +} + +func TestOAuth2Application_TableName(t *testing.T) { + assert.Equal(t, "oauth2_application", new(OAuth2Application).TableName()) +} + +func TestOAuth2Application_GetGrantByUserID(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application) + grant, err := app.GetGrantByUserID(1) + assert.NoError(t, err) + assert.Equal(t, int64(1), grant.UserID) + + grant, err = app.GetGrantByUserID(34923458) + assert.NoError(t, err) + assert.Nil(t, grant) +} + +func TestOAuth2Application_CreateGrant(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application) + grant, err := app.CreateGrant(2, "") + assert.NoError(t, err) + assert.NotNil(t, grant) + assert.Equal(t, int64(2), grant.UserID) + assert.Equal(t, int64(1), grant.ApplicationID) + assert.Equal(t, "", grant.Scope) +} + +//////////////////// Grant + +func TestGetOAuth2GrantByID(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + grant, err := GetOAuth2GrantByID(1) + assert.NoError(t, err) + assert.Equal(t, int64(1), grant.ID) + + grant, err = GetOAuth2GrantByID(34923458) + assert.NoError(t, err) + assert.Nil(t, grant) +} + +func TestOAuth2Grant_IncreaseCounter(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 1}).(*OAuth2Grant) + assert.NoError(t, grant.IncreaseCounter()) + assert.Equal(t, int64(2), grant.Counter) + unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 2}) +} + +func TestOAuth2Grant_ScopeContains(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Scope: "openid profile"}).(*OAuth2Grant) + assert.True(t, grant.ScopeContains("openid")) + assert.True(t, grant.ScopeContains("profile")) + assert.False(t, grant.ScopeContains("profil")) + assert.False(t, grant.ScopeContains("profile2")) +} + +func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1}).(*OAuth2Grant) + code, err := grant.GenerateNewAuthorizationCode("https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256") + assert.NoError(t, err) + assert.NotNil(t, code) + assert.True(t, len(code.Code) > 32) // secret length > 32 +} + +func TestOAuth2Grant_TableName(t *testing.T) { + assert.Equal(t, "oauth2_grant", new(OAuth2Grant).TableName()) +} + +func TestGetOAuth2GrantsByUserID(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + result, err := GetOAuth2GrantsByUserID(1) + assert.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, int64(1), result[0].ID) + assert.Equal(t, result[0].ApplicationID, result[0].Application.ID) + + result, err = GetOAuth2GrantsByUserID(34134) + assert.NoError(t, err) + assert.Empty(t, result) +} + +func TestRevokeOAuth2Grant(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + assert.NoError(t, RevokeOAuth2Grant(1, 1)) + unittest.AssertNotExistsBean(t, &OAuth2Grant{ID: 1, UserID: 1}) +} + +//////////////////// Authorization Code + +func TestGetOAuth2AuthorizationByCode(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + code, err := GetOAuth2AuthorizationByCode("authcode") + assert.NoError(t, err) + assert.NotNil(t, code) + assert.Equal(t, "authcode", code.Code) + assert.Equal(t, int64(1), code.ID) + + code, err = GetOAuth2AuthorizationByCode("does not exist") + assert.NoError(t, err) + assert.Nil(t, code) +} + +func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) { + // test plain + code := &OAuth2AuthorizationCode{ + CodeChallengeMethod: "plain", + CodeChallenge: "test123", + } + assert.True(t, code.ValidateCodeChallenge("test123")) + assert.False(t, code.ValidateCodeChallenge("ierwgjoergjio")) + + // test S256 + code = &OAuth2AuthorizationCode{ + CodeChallengeMethod: "S256", + CodeChallenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", + } + assert.True(t, code.ValidateCodeChallenge("N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt")) + assert.False(t, code.ValidateCodeChallenge("wiogjerogorewngoenrgoiuenorg")) + + // test unknown + code = &OAuth2AuthorizationCode{ + CodeChallengeMethod: "monkey", + CodeChallenge: "foiwgjioriogeiogjerger", + } + assert.False(t, code.ValidateCodeChallenge("foiwgjioriogeiogjerger")) + + // test no code challenge + code = &OAuth2AuthorizationCode{ + CodeChallengeMethod: "", + CodeChallenge: "foierjiogerogerg", + } + assert.True(t, code.ValidateCodeChallenge("")) +} + +func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) { + code := &OAuth2AuthorizationCode{ + RedirectURI: "https://example.com/callback", + Code: "thecode", + } + + redirect, err := code.GenerateRedirectURI("thestate") + assert.NoError(t, err) + assert.Equal(t, "https://example.com/callback?code=thecode&state=thestate", redirect.String()) + + redirect, err = code.GenerateRedirectURI("") + assert.NoError(t, err) + assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String()) +} + +func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + code := unittest.AssertExistsAndLoadBean(t, &OAuth2AuthorizationCode{Code: "authcode"}).(*OAuth2AuthorizationCode) + assert.NoError(t, code.Invalidate()) + unittest.AssertNotExistsBean(t, &OAuth2AuthorizationCode{Code: "authcode"}) +} + +func TestOAuth2AuthorizationCode_TableName(t *testing.T) { + assert.Equal(t, "oauth2_authorization_code", new(OAuth2AuthorizationCode).TableName()) +} diff --git a/models/auth/session.go b/models/auth/session.go new file mode 100644 index 0000000000..5b130c64b6 --- /dev/null +++ b/models/auth/session.go @@ -0,0 +1,126 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "fmt" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/timeutil" +) + +// Session represents a session compatible for go-chi session +type Session struct { + Key string `xorm:"pk CHAR(16)"` // has to be Key to match with go-chi/session + Data []byte `xorm:"BLOB"` // on MySQL this has a maximum size of 64Kb - this may need to be increased + Expiry timeutil.TimeStamp // has to be Expiry to match with go-chi/session +} + +func init() { + db.RegisterModel(new(Session)) +} + +// UpdateSession updates the session with provided id +func UpdateSession(key string, data []byte) error { + _, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{ + Data: data, + Expiry: timeutil.TimeStampNow(), + }) + return err +} + +// ReadSession reads the data for the provided session +func ReadSession(key string) (*Session, error) { + session := Session{ + Key: key, + } + + ctx, committer, err := db.TxContext() + if err != nil { + return nil, err + } + defer committer.Close() + + if has, err := db.GetByBean(ctx, &session); err != nil { + return nil, err + } else if !has { + session.Expiry = timeutil.TimeStampNow() + if err := db.Insert(ctx, &session); err != nil { + return nil, err + } + } + + return &session, committer.Commit() +} + +// ExistSession checks if a session exists +func ExistSession(key string) (bool, error) { + session := Session{ + Key: key, + } + return db.GetEngine(db.DefaultContext).Get(&session) +} + +// DestroySession destroys a session +func DestroySession(key string) error { + _, err := db.GetEngine(db.DefaultContext).Delete(&Session{ + Key: key, + }) + return err +} + +// RegenerateSession regenerates a session from the old id +func RegenerateSession(oldKey, newKey string) (*Session, error) { + ctx, committer, err := db.TxContext() + if err != nil { + return nil, err + } + defer committer.Close() + + if has, err := db.GetByBean(ctx, &Session{ + Key: newKey, + }); err != nil { + return nil, err + } else if has { + return nil, fmt.Errorf("session Key: %s already exists", newKey) + } + + if has, err := db.GetByBean(ctx, &Session{ + Key: oldKey, + }); err != nil { + return nil, err + } else if !has { + if err := db.Insert(ctx, &Session{ + Key: oldKey, + Expiry: timeutil.TimeStampNow(), + }); err != nil { + return nil, err + } + } + + if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil { + return nil, err + } + + s := Session{ + Key: newKey, + } + if _, err := db.GetByBean(ctx, &s); err != nil { + return nil, err + } + + return &s, committer.Commit() +} + +// CountSessions returns the number of sessions +func CountSessions() (int64, error) { + return db.GetEngine(db.DefaultContext).Count(&Session{}) +} + +// CleanupSessions cleans up expired sessions +func CleanupSessions(maxLifetime int64) error { + _, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{}) + return err +} diff --git a/models/auth/source.go b/models/auth/source.go new file mode 100644 index 0000000000..6f4f5addcb --- /dev/null +++ b/models/auth/source.go @@ -0,0 +1,397 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Copyright 2019 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "fmt" + "reflect" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/timeutil" + + "xorm.io/xorm" + "xorm.io/xorm/convert" +) + +// Type represents an login type. +type Type int + +// Note: new type must append to the end of list to maintain compatibility. +const ( + NoType Type = iota + Plain // 1 + LDAP // 2 + SMTP // 3 + PAM // 4 + DLDAP // 5 + OAuth2 // 6 + SSPI // 7 +) + +// String returns the string name of the LoginType +func (typ Type) String() string { + return Names[typ] +} + +// Int returns the int value of the LoginType +func (typ Type) Int() int { + return int(typ) +} + +// Names contains the name of LoginType values. +var Names = map[Type]string{ + LDAP: "LDAP (via BindDN)", + DLDAP: "LDAP (simple auth)", // Via direct bind + SMTP: "SMTP", + PAM: "PAM", + OAuth2: "OAuth2", + SSPI: "SPNEGO with SSPI", +} + +// Config represents login config as far as the db is concerned +type Config interface { + convert.Conversion +} + +// SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set +type SkipVerifiable interface { + IsSkipVerify() bool +} + +// HasTLSer configurations provide a HasTLS to check if TLS can be enabled +type HasTLSer interface { + HasTLS() bool +} + +// UseTLSer configurations provide a HasTLS to check if TLS is enabled +type UseTLSer interface { + UseTLS() bool +} + +// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys +type SSHKeyProvider interface { + ProvidesSSHKeys() bool +} + +// RegisterableSource configurations provide RegisterSource which needs to be run on creation +type RegisterableSource interface { + RegisterSource() error + UnregisterSource() error +} + +var registeredConfigs = map[Type]func() Config{} + +// RegisterTypeConfig register a config for a provided type +func RegisterTypeConfig(typ Type, exemplar Config) { + if reflect.TypeOf(exemplar).Kind() == reflect.Ptr { + // Pointer: + registeredConfigs[typ] = func() Config { + return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config) + } + return + } + + // Not a Pointer + registeredConfigs[typ] = func() Config { + return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config) + } +} + +// SourceSettable configurations can have their authSource set on them +type SourceSettable interface { + SetAuthSource(*Source) +} + +// Source represents an external way for authorizing users. +type Source struct { + ID int64 `xorm:"pk autoincr"` + Type Type + Name string `xorm:"UNIQUE"` + IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"` + IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"` + Cfg convert.Conversion `xorm:"TEXT"` + + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` +} + +// TableName xorm will read the table name from this method +func (Source) TableName() string { + return "login_source" +} + +func init() { + db.RegisterModel(new(Source)) +} + +// BeforeSet is invoked from XORM before setting the value of a field of this object. +func (source *Source) BeforeSet(colName string, val xorm.Cell) { + if colName == "type" { + typ := Type(db.Cell2Int64(val)) + constructor, ok := registeredConfigs[typ] + if !ok { + return + } + source.Cfg = constructor() + if settable, ok := source.Cfg.(SourceSettable); ok { + settable.SetAuthSource(source) + } + } +} + +// TypeName return name of this login source type. +func (source *Source) TypeName() string { + return Names[source.Type] +} + +// IsLDAP returns true of this source is of the LDAP type. +func (source *Source) IsLDAP() bool { + return source.Type == LDAP +} + +// IsDLDAP returns true of this source is of the DLDAP type. +func (source *Source) IsDLDAP() bool { + return source.Type == DLDAP +} + +// IsSMTP returns true of this source is of the SMTP type. +func (source *Source) IsSMTP() bool { + return source.Type == SMTP +} + +// IsPAM returns true of this source is of the PAM type. +func (source *Source) IsPAM() bool { + return source.Type == PAM +} + +// IsOAuth2 returns true of this source is of the OAuth2 type. +func (source *Source) IsOAuth2() bool { + return source.Type == OAuth2 +} + +// IsSSPI returns true of this source is of the SSPI type. +func (source *Source) IsSSPI() bool { + return source.Type == SSPI +} + +// HasTLS returns true of this source supports TLS. +func (source *Source) HasTLS() bool { + hasTLSer, ok := source.Cfg.(HasTLSer) + return ok && hasTLSer.HasTLS() +} + +// UseTLS returns true of this source is configured to use TLS. +func (source *Source) UseTLS() bool { + useTLSer, ok := source.Cfg.(UseTLSer) + return ok && useTLSer.UseTLS() +} + +// SkipVerify returns true if this source is configured to skip SSL +// verification. +func (source *Source) SkipVerify() bool { + skipVerifiable, ok := source.Cfg.(SkipVerifiable) + return ok && skipVerifiable.IsSkipVerify() +} + +// CreateSource inserts a AuthSource in the DB if not already +// existing with the given name. +func CreateSource(source *Source) error { + has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source)) + if err != nil { + return err + } else if has { + return ErrSourceAlreadyExist{source.Name} + } + // Synchronization is only available with LDAP for now + if !source.IsLDAP() { + source.IsSyncEnabled = false + } + + _, err = db.GetEngine(db.DefaultContext).Insert(source) + if err != nil { + return err + } + + if !source.IsActive { + return nil + } + + if settable, ok := source.Cfg.(SourceSettable); ok { + settable.SetAuthSource(source) + } + + registerableSource, ok := source.Cfg.(RegisterableSource) + if !ok { + return nil + } + + err = registerableSource.RegisterSource() + if err != nil { + // remove the AuthSource in case of errors while registering configuration + if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil { + log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err) + } + } + return err +} + +// Sources returns a slice of all login sources found in DB. +func Sources() ([]*Source, error) { + auths := make([]*Source, 0, 6) + return auths, db.GetEngine(db.DefaultContext).Find(&auths) +} + +// SourcesByType returns all sources of the specified type +func SourcesByType(loginType Type) ([]*Source, error) { + sources := make([]*Source, 0, 1) + if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil { + return nil, err + } + return sources, nil +} + +// AllActiveSources returns all active sources +func AllActiveSources() ([]*Source, error) { + sources := make([]*Source, 0, 5) + if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil { + return nil, err + } + return sources, nil +} + +// ActiveSources returns all active sources of the specified type +func ActiveSources(tp Type) ([]*Source, error) { + sources := make([]*Source, 0, 1) + if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil { + return nil, err + } + return sources, nil +} + +// IsSSPIEnabled returns true if there is at least one activated login +// source of type LoginSSPI +func IsSSPIEnabled() bool { + if !db.HasEngine { + return false + } + sources, err := ActiveSources(SSPI) + if err != nil { + log.Error("ActiveSources: %v", err) + return false + } + return len(sources) > 0 +} + +// GetSourceByID returns login source by given ID. +func GetSourceByID(id int64) (*Source, error) { + source := new(Source) + if id == 0 { + source.Cfg = registeredConfigs[NoType]() + // Set this source to active + // FIXME: allow disabling of db based password authentication in future + source.IsActive = true + return source, nil + } + + has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source) + if err != nil { + return nil, err + } else if !has { + return nil, ErrSourceNotExist{id} + } + return source, nil +} + +// UpdateSource updates a Source record in DB. +func UpdateSource(source *Source) error { + var originalSource *Source + if source.IsOAuth2() { + // keep track of the original values so we can restore in case of errors while registering OAuth2 providers + var err error + if originalSource, err = GetSourceByID(source.ID); err != nil { + return err + } + } + + _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source) + if err != nil { + return err + } + + if !source.IsActive { + return nil + } + + if settable, ok := source.Cfg.(SourceSettable); ok { + settable.SetAuthSource(source) + } + + registerableSource, ok := source.Cfg.(RegisterableSource) + if !ok { + return nil + } + + err = registerableSource.RegisterSource() + if err != nil { + // restore original values since we cannot update the provider it self + if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil { + log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err) + } + } + return err +} + +// CountSources returns number of login sources. +func CountSources() int64 { + count, _ := db.GetEngine(db.DefaultContext).Count(new(Source)) + return count +} + +// ErrSourceNotExist represents a "SourceNotExist" kind of error. +type ErrSourceNotExist struct { + ID int64 +} + +// IsErrSourceNotExist checks if an error is a ErrSourceNotExist. +func IsErrSourceNotExist(err error) bool { + _, ok := err.(ErrSourceNotExist) + return ok +} + +func (err ErrSourceNotExist) Error() string { + return fmt.Sprintf("login source does not exist [id: %d]", err.ID) +} + +// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error. +type ErrSourceAlreadyExist struct { + Name string +} + +// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist. +func IsErrSourceAlreadyExist(err error) bool { + _, ok := err.(ErrSourceAlreadyExist) + return ok +} + +func (err ErrSourceAlreadyExist) Error() string { + return fmt.Sprintf("login source already exists [name: %s]", err.Name) +} + +// ErrSourceInUse represents a "SourceInUse" kind of error. +type ErrSourceInUse struct { + ID int64 +} + +// IsErrSourceInUse checks if an error is a ErrSourceInUse. +func IsErrSourceInUse(err error) bool { + _, ok := err.(ErrSourceInUse) + return ok +} + +func (err ErrSourceInUse) Error() string { + return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID) +} diff --git a/models/auth/source_test.go b/models/auth/source_test.go new file mode 100644 index 0000000000..6a8e286910 --- /dev/null +++ b/models/auth/source_test.go @@ -0,0 +1,60 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "strings" + "testing" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/unittest" + "code.gitea.io/gitea/modules/json" + + "github.com/stretchr/testify/assert" + "xorm.io/xorm/schemas" +) + +type TestSource struct { + Provider string + ClientID string + ClientSecret string + OpenIDConnectAutoDiscoveryURL string + IconURL string +} + +// FromDB fills up a LDAPConfig from serialized format. +func (source *TestSource) FromDB(bs []byte) error { + return json.Unmarshal(bs, &source) +} + +// ToDB exports a LDAPConfig to a serialized format. +func (source *TestSource) ToDB() ([]byte, error) { + return json.Marshal(source) +} + +func TestDumpAuthSource(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + authSourceSchema, err := db.TableInfo(new(Source)) + assert.NoError(t, err) + + RegisterTypeConfig(OAuth2, new(TestSource)) + + CreateSource(&Source{ + Type: OAuth2, + Name: "TestSource", + IsActive: false, + Cfg: &TestSource{ + Provider: "ConvertibleSourceName", + ClientID: "42", + }, + }) + + sb := new(strings.Builder) + + db.DumpTables([]*schemas.Table{authSourceSchema}, sb) + + assert.Contains(t, sb.String(), `"Provider":"ConvertibleSourceName"`) +} diff --git a/models/auth/twofactor.go b/models/auth/twofactor.go new file mode 100644 index 0000000000..883e6ce01c --- /dev/null +++ b/models/auth/twofactor.go @@ -0,0 +1,156 @@ +// Copyright 2017 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "crypto/md5" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "fmt" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/secret" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/modules/util" + + "github.com/pquerna/otp/totp" + "golang.org/x/crypto/pbkdf2" +) + +// +// Two-factor authentication +// + +// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication. +type ErrTwoFactorNotEnrolled struct { + UID int64 +} + +// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled. +func IsErrTwoFactorNotEnrolled(err error) bool { + _, ok := err.(ErrTwoFactorNotEnrolled) + return ok +} + +func (err ErrTwoFactorNotEnrolled) Error() string { + return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID) +} + +// TwoFactor represents a two-factor authentication token. +type TwoFactor struct { + ID int64 `xorm:"pk autoincr"` + UID int64 `xorm:"UNIQUE"` + Secret string + ScratchSalt string + ScratchHash string + LastUsedPasscode string `xorm:"VARCHAR(10)"` + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` +} + +func init() { + db.RegisterModel(new(TwoFactor)) +} + +// GenerateScratchToken recreates the scratch token the user is using. +func (t *TwoFactor) GenerateScratchToken() (string, error) { + token, err := util.RandomString(8) + if err != nil { + return "", err + } + t.ScratchSalt, _ = util.RandomString(10) + t.ScratchHash = HashToken(token, t.ScratchSalt) + return token, nil +} + +// HashToken return the hashable salt +func HashToken(token, salt string) string { + tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New) + return fmt.Sprintf("%x", tempHash) +} + +// VerifyScratchToken verifies if the specified scratch token is valid. +func (t *TwoFactor) VerifyScratchToken(token string) bool { + if len(token) == 0 { + return false + } + tempHash := HashToken(token, t.ScratchSalt) + return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1 +} + +func (t *TwoFactor) getEncryptionKey() []byte { + k := md5.Sum([]byte(setting.SecretKey)) + return k[:] +} + +// SetSecret sets the 2FA secret. +func (t *TwoFactor) SetSecret(secretString string) error { + secretBytes, err := secret.AesEncrypt(t.getEncryptionKey(), []byte(secretString)) + if err != nil { + return err + } + t.Secret = base64.StdEncoding.EncodeToString(secretBytes) + return nil +} + +// ValidateTOTP validates the provided passcode. +func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) { + decodedStoredSecret, err := base64.StdEncoding.DecodeString(t.Secret) + if err != nil { + return false, err + } + secretBytes, err := secret.AesDecrypt(t.getEncryptionKey(), decodedStoredSecret) + if err != nil { + return false, err + } + secretStr := string(secretBytes) + return totp.Validate(passcode, secretStr), nil +} + +// NewTwoFactor creates a new two-factor authentication token. +func NewTwoFactor(t *TwoFactor) error { + _, err := db.GetEngine(db.DefaultContext).Insert(t) + return err +} + +// UpdateTwoFactor updates a two-factor authentication token. +func UpdateTwoFactor(t *TwoFactor) error { + _, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t) + return err +} + +// GetTwoFactorByUID returns the two-factor authentication token associated with +// the user, if any. +func GetTwoFactorByUID(uid int64) (*TwoFactor, error) { + twofa := &TwoFactor{} + has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa) + if err != nil { + return nil, err + } else if !has { + return nil, ErrTwoFactorNotEnrolled{uid} + } + return twofa, nil +} + +// HasTwoFactorByUID returns the two-factor authentication token associated with +// the user, if any. +func HasTwoFactorByUID(uid int64) (bool, error) { + return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{}) +} + +// DeleteTwoFactorByID deletes two-factor authentication token by given ID. +func DeleteTwoFactorByID(id, userID int64) error { + cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{ + UID: userID, + }) + if err != nil { + return err + } else if cnt != 1 { + return ErrTwoFactorNotEnrolled{userID} + } + return nil +} diff --git a/models/auth/u2f.go b/models/auth/u2f.go new file mode 100644 index 0000000000..71943b237c --- /dev/null +++ b/models/auth/u2f.go @@ -0,0 +1,154 @@ +// Copyright 2018 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "fmt" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/timeutil" + + "github.com/tstranex/u2f" +) + +// ____ ________________________________ .__ __ __ .__ +// | | \_____ \_ _____/\______ \ ____ ____ |__| _______/ |_____________ _/ |_|__| ____ ____ +// | | // ____/| __) | _// __ \ / ___\| |/ ___/\ __\_ __ \__ \\ __\ |/ _ \ / \ +// | | // \| \ | | \ ___// /_/ > |\___ \ | | | | \// __ \| | | ( <_> ) | \ +// |______/ \_______ \___ / |____|_ /\___ >___ /|__/____ > |__| |__| (____ /__| |__|\____/|___| / +// \/ \/ \/ \/_____/ \/ \/ \/ + +// ErrU2FRegistrationNotExist represents a "ErrU2FRegistrationNotExist" kind of error. +type ErrU2FRegistrationNotExist struct { + ID int64 +} + +func (err ErrU2FRegistrationNotExist) Error() string { + return fmt.Sprintf("U2F registration does not exist [id: %d]", err.ID) +} + +// IsErrU2FRegistrationNotExist checks if an error is a ErrU2FRegistrationNotExist. +func IsErrU2FRegistrationNotExist(err error) bool { + _, ok := err.(ErrU2FRegistrationNotExist) + return ok +} + +// U2FRegistration represents the registration data and counter of a security key +type U2FRegistration struct { + ID int64 `xorm:"pk autoincr"` + Name string + UserID int64 `xorm:"INDEX"` + Raw []byte + Counter uint32 `xorm:"BIGINT"` + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` +} + +func init() { + db.RegisterModel(new(U2FRegistration)) +} + +// TableName returns a better table name for U2FRegistration +func (reg U2FRegistration) TableName() string { + return "u2f_registration" +} + +// Parse will convert the db entry U2FRegistration to an u2f.Registration struct +func (reg *U2FRegistration) Parse() (*u2f.Registration, error) { + r := new(u2f.Registration) + return r, r.UnmarshalBinary(reg.Raw) +} + +func (reg *U2FRegistration) updateCounter(e db.Engine) error { + _, err := e.ID(reg.ID).Cols("counter").Update(reg) + return err +} + +// UpdateCounter will update the database value of counter +func (reg *U2FRegistration) UpdateCounter() error { + return reg.updateCounter(db.GetEngine(db.DefaultContext)) +} + +// U2FRegistrationList is a list of *U2FRegistration +type U2FRegistrationList []*U2FRegistration + +// ToRegistrations will convert all U2FRegistrations to u2f.Registrations +func (list U2FRegistrationList) ToRegistrations() []u2f.Registration { + regs := make([]u2f.Registration, 0, len(list)) + for _, reg := range list { + r, err := reg.Parse() + if err != nil { + log.Error("parsing u2f registration: %v", err) + continue + } + regs = append(regs, *r) + } + + return regs +} + +func getU2FRegistrationsByUID(e db.Engine, uid int64) (U2FRegistrationList, error) { + regs := make(U2FRegistrationList, 0) + return regs, e.Where("user_id = ?", uid).Find(®s) +} + +// GetU2FRegistrationByID returns U2F registration by id +func GetU2FRegistrationByID(id int64) (*U2FRegistration, error) { + return getU2FRegistrationByID(db.GetEngine(db.DefaultContext), id) +} + +func getU2FRegistrationByID(e db.Engine, id int64) (*U2FRegistration, error) { + reg := new(U2FRegistration) + if found, err := e.ID(id).Get(reg); err != nil { + return nil, err + } else if !found { + return nil, ErrU2FRegistrationNotExist{ID: id} + } + return reg, nil +} + +// GetU2FRegistrationsByUID returns all U2F registrations of the given user +func GetU2FRegistrationsByUID(uid int64) (U2FRegistrationList, error) { + return getU2FRegistrationsByUID(db.GetEngine(db.DefaultContext), uid) +} + +// HasU2FRegistrationsByUID returns whether a given user has U2F registrations +func HasU2FRegistrationsByUID(uid int64) (bool, error) { + return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&U2FRegistration{}) +} + +func createRegistration(e db.Engine, userID int64, name string, reg *u2f.Registration) (*U2FRegistration, error) { + raw, err := reg.MarshalBinary() + if err != nil { + return nil, err + } + r := &U2FRegistration{ + UserID: userID, + Name: name, + Counter: 0, + Raw: raw, + } + _, err = e.InsertOne(r) + if err != nil { + return nil, err + } + return r, nil +} + +// CreateRegistration will create a new U2FRegistration from the given Registration +func CreateRegistration(userID int64, name string, reg *u2f.Registration) (*U2FRegistration, error) { + return createRegistration(db.GetEngine(db.DefaultContext), userID, name, reg) +} + +// DeleteRegistration will delete U2FRegistration +func DeleteRegistration(reg *U2FRegistration) error { + return deleteRegistration(db.GetEngine(db.DefaultContext), reg) +} + +func deleteRegistration(e db.Engine, reg *U2FRegistration) error { + _, err := e.Delete(reg) + return err +} diff --git a/models/auth/u2f_test.go b/models/auth/u2f_test.go new file mode 100644 index 0000000000..32ad17839c --- /dev/null +++ b/models/auth/u2f_test.go @@ -0,0 +1,100 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "encoding/hex" + "testing" + + "code.gitea.io/gitea/models/unittest" + + "github.com/stretchr/testify/assert" + "github.com/tstranex/u2f" +) + +func TestGetU2FRegistrationByID(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + res, err := GetU2FRegistrationByID(1) + assert.NoError(t, err) + assert.Equal(t, "U2F Key", res.Name) + + _, err = GetU2FRegistrationByID(342432) + assert.Error(t, err) + assert.True(t, IsErrU2FRegistrationNotExist(err)) +} + +func TestGetU2FRegistrationsByUID(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + res, err := GetU2FRegistrationsByUID(32) + + assert.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, "U2F Key", res[0].Name) +} + +func TestU2FRegistration_TableName(t *testing.T) { + assert.Equal(t, "u2f_registration", U2FRegistration{}.TableName()) +} + +func TestU2FRegistration_UpdateCounter(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration) + reg.Counter = 1 + assert.NoError(t, reg.UpdateCounter()) + unittest.AssertExistsIf(t, true, &U2FRegistration{ID: 1, Counter: 1}) +} + +func TestU2FRegistration_UpdateLargeCounter(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration) + reg.Counter = 0xffffffff + assert.NoError(t, reg.UpdateCounter()) + unittest.AssertExistsIf(t, true, &U2FRegistration{ID: 1, Counter: 0xffffffff}) +} + +func TestCreateRegistration(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + res, err := CreateRegistration(1, "U2F Created Key", &u2f.Registration{Raw: []byte("Test")}) + assert.NoError(t, err) + assert.Equal(t, "U2F Created Key", res.Name) + assert.Equal(t, []byte("Test"), res.Raw) + + unittest.AssertExistsIf(t, true, &U2FRegistration{Name: "U2F Created Key", UserID: 1}) +} + +func TestDeleteRegistration(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration) + + assert.NoError(t, DeleteRegistration(reg)) + unittest.AssertNotExistsBean(t, &U2FRegistration{ID: 1}) +} + +const validU2FRegistrationResponseHex = "0504b174bc49c7ca254b70d2e5c207cee9cf174820ebd77ea3c65508c26da51b657c1cc6b952f8621697936482da0a6d3d3826a59095daf6cd7c03e2e60385d2f6d9402a552dfdb7477ed65fd84133f86196010b2215b57da75d315b7b9e8fe2e3925a6019551bab61d16591659cbaf00b4950f7abfe6660e2e006f76868b772d70c253082013c3081e4a003020102020a47901280001155957352300a06082a8648ce3d0403023017311530130603550403130c476e756262792050696c6f74301e170d3132303831343138323933325a170d3133303831343138323933325a3031312f302d0603550403132650696c6f74476e756262792d302e342e312d34373930313238303030313135353935373335323059301306072a8648ce3d020106082a8648ce3d030107034200048d617e65c9508e64bcc5673ac82a6799da3c1446682c258c463fffdf58dfd2fa3e6c378b53d795c4a4dffb4199edd7862f23abaf0203b4b8911ba0569994e101300a06082a8648ce3d0403020347003044022060cdb6061e9c22262d1aac1d96d8c70829b2366531dda268832cb836bcd30dfa0220631b1459f09e6330055722c8d89b7f48883b9089b88d60d1d9795902b30410df304502201471899bcc3987e62e8202c9b39c33c19033f7340352dba80fcab017db9230e402210082677d673d891933ade6f617e5dbde2e247e70423fd5ad7804a6d3d3961ef871" + +func TestToRegistrations_SkipInvalidItemsWithoutCrashing(t *testing.T) { + regKeyRaw, _ := hex.DecodeString(validU2FRegistrationResponseHex) + regs := U2FRegistrationList{ + &U2FRegistration{ID: 1}, + &U2FRegistration{ID: 2, Name: "U2F Key", UserID: 2, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800}, + } + + actual := regs.ToRegistrations() + assert.Len(t, actual, 1) +} + +func TestToRegistrations(t *testing.T) { + regKeyRaw, _ := hex.DecodeString(validU2FRegistrationResponseHex) + regs := U2FRegistrationList{ + &U2FRegistration{ID: 1, Name: "U2F Key", UserID: 1, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800}, + &U2FRegistration{ID: 2, Name: "U2F Key", UserID: 2, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800}, + } + + actual := regs.ToRegistrations() + assert.Len(t, actual, 2) +} |