summaryrefslogtreecommitdiffstats
path: root/models/auth
diff options
context:
space:
mode:
authorLunny Xiao <xiaolunwen@gmail.com>2022-01-02 21:12:35 +0800
committerGitHub <noreply@github.com>2022-01-02 21:12:35 +0800
commitde8e3948a5e38f7eaf82d3c0cfd10e995bf68e92 (patch)
treebbcb011d264e0d614d49c734856b446360c5a4a3 /models/auth
parente61b390d545919244141b699b28e3fbc42adc66f (diff)
downloadgitea-de8e3948a5e38f7eaf82d3c0cfd10e995bf68e92.tar.gz
gitea-de8e3948a5e38f7eaf82d3c0cfd10e995bf68e92.zip
Refactor auth package (#17962)
Diffstat (limited to 'models/auth')
-rw-r--r--models/auth/main_test.go22
-rw-r--r--models/auth/oauth2.go564
-rw-r--r--models/auth/oauth2_test.go233
-rw-r--r--models/auth/session.go126
-rw-r--r--models/auth/source.go397
-rw-r--r--models/auth/source_test.go60
-rw-r--r--models/auth/twofactor.go156
-rw-r--r--models/auth/u2f.go154
-rw-r--r--models/auth/u2f_test.go100
9 files changed, 1812 insertions, 0 deletions
diff --git a/models/auth/main_test.go b/models/auth/main_test.go
new file mode 100644
index 0000000000..94a1f405d9
--- /dev/null
+++ b/models/auth/main_test.go
@@ -0,0 +1,22 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package auth
+
+import (
+ "path/filepath"
+ "testing"
+
+ "code.gitea.io/gitea/models/unittest"
+)
+
+func TestMain(m *testing.M) {
+ unittest.MainTest(m, filepath.Join("..", ".."),
+ "login_source.yml",
+ "oauth2_application.yml",
+ "oauth2_authorization_code.yml",
+ "oauth2_grant.yml",
+ "u2f_registration.yml",
+ )
+}
diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go
new file mode 100644
index 0000000000..e7030fce28
--- /dev/null
+++ b/models/auth/oauth2.go
@@ -0,0 +1,564 @@
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package auth
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "net/url"
+ "strings"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/secret"
+ "code.gitea.io/gitea/modules/timeutil"
+ "code.gitea.io/gitea/modules/util"
+
+ uuid "github.com/google/uuid"
+ "golang.org/x/crypto/bcrypt"
+ "xorm.io/xorm"
+)
+
+// OAuth2Application represents an OAuth2 client (RFC 6749)
+type OAuth2Application struct {
+ ID int64 `xorm:"pk autoincr"`
+ UID int64 `xorm:"INDEX"`
+ Name string
+ ClientID string `xorm:"unique"`
+ ClientSecret string
+ RedirectURIs []string `xorm:"redirect_uris JSON TEXT"`
+ CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
+}
+
+func init() {
+ db.RegisterModel(new(OAuth2Application))
+ db.RegisterModel(new(OAuth2AuthorizationCode))
+ db.RegisterModel(new(OAuth2Grant))
+}
+
+// TableName sets the table name to `oauth2_application`
+func (app *OAuth2Application) TableName() string {
+ return "oauth2_application"
+}
+
+// PrimaryRedirectURI returns the first redirect uri or an empty string if empty
+func (app *OAuth2Application) PrimaryRedirectURI() string {
+ if len(app.RedirectURIs) == 0 {
+ return ""
+ }
+ return app.RedirectURIs[0]
+}
+
+// ContainsRedirectURI checks if redirectURI is allowed for app
+func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool {
+ return util.IsStringInSlice(redirectURI, app.RedirectURIs, true)
+}
+
+// GenerateClientSecret will generate the client secret and returns the plaintext and saves the hash at the database
+func (app *OAuth2Application) GenerateClientSecret() (string, error) {
+ clientSecret, err := secret.New()
+ if err != nil {
+ return "", err
+ }
+ hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
+ if err != nil {
+ return "", err
+ }
+ app.ClientSecret = string(hashedSecret)
+ if _, err := db.GetEngine(db.DefaultContext).ID(app.ID).Cols("client_secret").Update(app); err != nil {
+ return "", err
+ }
+ return clientSecret, nil
+}
+
+// ValidateClientSecret validates the given secret by the hash saved in database
+func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
+ return bcrypt.CompareHashAndPassword([]byte(app.ClientSecret), secret) == nil
+}
+
+// GetGrantByUserID returns a OAuth2Grant by its user and application ID
+func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) {
+ return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID)
+}
+
+func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) {
+ grant = new(OAuth2Grant)
+ if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
+ return nil, err
+ } else if !has {
+ return nil, nil
+ }
+ return grant, nil
+}
+
+// CreateGrant generates a grant for an user
+func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) {
+ return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope)
+}
+
+func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) {
+ grant := &OAuth2Grant{
+ ApplicationID: app.ID,
+ UserID: userID,
+ Scope: scope,
+ }
+ _, err := e.Insert(grant)
+ if err != nil {
+ return nil, err
+ }
+ return grant, nil
+}
+
+// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
+func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) {
+ return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID)
+}
+
+func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) {
+ app = new(OAuth2Application)
+ has, err := e.Where("client_id = ?", clientID).Get(app)
+ if !has {
+ return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
+ }
+ return
+}
+
+// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
+func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) {
+ return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id)
+}
+
+func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) {
+ app = new(OAuth2Application)
+ has, err := e.ID(id).Get(app)
+ if err != nil {
+ return nil, err
+ }
+ if !has {
+ return nil, ErrOAuthApplicationNotFound{ID: id}
+ }
+ return app, nil
+}
+
+// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user
+func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) {
+ return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID)
+}
+
+func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) {
+ apps = make([]*OAuth2Application, 0)
+ err = e.Where("uid = ?", userID).Find(&apps)
+ return
+}
+
+// CreateOAuth2ApplicationOptions holds options to create an oauth2 application
+type CreateOAuth2ApplicationOptions struct {
+ Name string
+ UserID int64
+ RedirectURIs []string
+}
+
+// CreateOAuth2Application inserts a new oauth2 application
+func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
+ return createOAuth2Application(db.GetEngine(db.DefaultContext), opts)
+}
+
+func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
+ clientID := uuid.New().String()
+ app := &OAuth2Application{
+ UID: opts.UserID,
+ Name: opts.Name,
+ ClientID: clientID,
+ RedirectURIs: opts.RedirectURIs,
+ }
+ if _, err := e.Insert(app); err != nil {
+ return nil, err
+ }
+ return app, nil
+}
+
+// UpdateOAuth2ApplicationOptions holds options to update an oauth2 application
+type UpdateOAuth2ApplicationOptions struct {
+ ID int64
+ Name string
+ UserID int64
+ RedirectURIs []string
+}
+
+// UpdateOAuth2Application updates an oauth2 application
+func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
+ ctx, committer, err := db.TxContext()
+ if err != nil {
+ return nil, err
+ }
+ defer committer.Close()
+ sess := db.GetEngine(ctx)
+
+ app, err := getOAuth2ApplicationByID(sess, opts.ID)
+ if err != nil {
+ return nil, err
+ }
+ if app.UID != opts.UserID {
+ return nil, fmt.Errorf("UID mismatch")
+ }
+
+ app.Name = opts.Name
+ app.RedirectURIs = opts.RedirectURIs
+
+ if err = updateOAuth2Application(sess, app); err != nil {
+ return nil, err
+ }
+ app.ClientSecret = ""
+
+ return app, committer.Commit()
+}
+
+func updateOAuth2Application(e db.Engine, app *OAuth2Application) error {
+ if _, err := e.ID(app.ID).Update(app); err != nil {
+ return err
+ }
+ return nil
+}
+
+func deleteOAuth2Application(sess db.Engine, id, userid int64) error {
+ if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil {
+ return err
+ } else if deleted == 0 {
+ return ErrOAuthApplicationNotFound{ID: id}
+ }
+ codes := make([]*OAuth2AuthorizationCode, 0)
+ // delete correlating auth codes
+ if err := sess.Join("INNER", "oauth2_grant",
+ "oauth2_authorization_code.grant_id = oauth2_grant.id AND oauth2_grant.application_id = ?", id).Find(&codes); err != nil {
+ return err
+ }
+ codeIDs := make([]int64, 0)
+ for _, grant := range codes {
+ codeIDs = append(codeIDs, grant.ID)
+ }
+
+ if _, err := sess.In("id", codeIDs).Delete(new(OAuth2AuthorizationCode)); err != nil {
+ return err
+ }
+
+ if _, err := sess.Where("application_id = ?", id).Delete(new(OAuth2Grant)); err != nil {
+ return err
+ }
+ return nil
+}
+
+// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
+func DeleteOAuth2Application(id, userid int64) error {
+ ctx, committer, err := db.TxContext()
+ if err != nil {
+ return err
+ }
+ defer committer.Close()
+ if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil {
+ return err
+ }
+ return committer.Commit()
+}
+
+// ListOAuth2Applications returns a list of oauth2 applications belongs to given user.
+func ListOAuth2Applications(uid int64, listOptions db.ListOptions) ([]*OAuth2Application, int64, error) {
+ sess := db.GetEngine(db.DefaultContext).
+ Where("uid=?", uid).
+ Desc("id")
+
+ if listOptions.Page != 0 {
+ sess = db.SetSessionPagination(sess, &listOptions)
+
+ apps := make([]*OAuth2Application, 0, listOptions.PageSize)
+ total, err := sess.FindAndCount(&apps)
+ return apps, total, err
+ }
+
+ apps := make([]*OAuth2Application, 0, 5)
+ total, err := sess.FindAndCount(&apps)
+ return apps, total, err
+}
+
+//////////////////////////////////////////////////////
+
+// OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime.
+type OAuth2AuthorizationCode struct {
+ ID int64 `xorm:"pk autoincr"`
+ Grant *OAuth2Grant `xorm:"-"`
+ GrantID int64
+ Code string `xorm:"INDEX unique"`
+ CodeChallenge string
+ CodeChallengeMethod string
+ RedirectURI string
+ ValidUntil timeutil.TimeStamp `xorm:"index"`
+}
+
+// TableName sets the table name to `oauth2_authorization_code`
+func (code *OAuth2AuthorizationCode) TableName() string {
+ return "oauth2_authorization_code"
+}
+
+// GenerateRedirectURI generates a redirect URI for a successful authorization request. State will be used if not empty.
+func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect *url.URL, err error) {
+ if redirect, err = url.Parse(code.RedirectURI); err != nil {
+ return
+ }
+ q := redirect.Query()
+ if state != "" {
+ q.Set("state", state)
+ }
+ q.Set("code", code.Code)
+ redirect.RawQuery = q.Encode()
+ return
+}
+
+// Invalidate deletes the auth code from the database to invalidate this code
+func (code *OAuth2AuthorizationCode) Invalidate() error {
+ return code.invalidate(db.GetEngine(db.DefaultContext))
+}
+
+func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error {
+ _, err := e.Delete(code)
+ return err
+}
+
+// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
+func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
+ return code.validateCodeChallenge(verifier)
+}
+
+func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool {
+ switch code.CodeChallengeMethod {
+ case "S256":
+ // base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
+ h := sha256.Sum256([]byte(verifier))
+ hashedVerifier := base64.RawURLEncoding.EncodeToString(h[:])
+ return hashedVerifier == code.CodeChallenge
+ case "plain":
+ return verifier == code.CodeChallenge
+ case "":
+ return true
+ default:
+ // unsupported method -> return false
+ return false
+ }
+}
+
+// GetOAuth2AuthorizationByCode returns an authorization by its code
+func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) {
+ return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code)
+}
+
+func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
+ auth = new(OAuth2AuthorizationCode)
+ if has, err := e.Where("code = ?", code).Get(auth); err != nil {
+ return nil, err
+ } else if !has {
+ return nil, nil
+ }
+ auth.Grant = new(OAuth2Grant)
+ if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil {
+ return nil, err
+ } else if !has {
+ return nil, nil
+ }
+ return auth, nil
+}
+
+//////////////////////////////////////////////////////
+
+// OAuth2Grant represents the permission of an user for a specific application to access resources
+type OAuth2Grant struct {
+ ID int64 `xorm:"pk autoincr"`
+ UserID int64 `xorm:"INDEX unique(user_application)"`
+ Application *OAuth2Application `xorm:"-"`
+ ApplicationID int64 `xorm:"INDEX unique(user_application)"`
+ Counter int64 `xorm:"NOT NULL DEFAULT 1"`
+ Scope string `xorm:"TEXT"`
+ Nonce string `xorm:"TEXT"`
+ CreatedUnix timeutil.TimeStamp `xorm:"created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
+}
+
+// TableName sets the table name to `oauth2_grant`
+func (grant *OAuth2Grant) TableName() string {
+ return "oauth2_grant"
+}
+
+// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
+func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) {
+ return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod)
+}
+
+func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
+ var codeSecret string
+ if codeSecret, err = secret.New(); err != nil {
+ return &OAuth2AuthorizationCode{}, err
+ }
+ code = &OAuth2AuthorizationCode{
+ Grant: grant,
+ GrantID: grant.ID,
+ RedirectURI: redirectURI,
+ Code: codeSecret,
+ CodeChallenge: codeChallenge,
+ CodeChallengeMethod: codeChallengeMethod,
+ }
+ if _, err := e.Insert(code); err != nil {
+ return nil, err
+ }
+ return code, nil
+}
+
+// IncreaseCounter increases the counter and updates the grant
+func (grant *OAuth2Grant) IncreaseCounter() error {
+ return grant.increaseCount(db.GetEngine(db.DefaultContext))
+}
+
+func (grant *OAuth2Grant) increaseCount(e db.Engine) error {
+ _, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
+ if err != nil {
+ return err
+ }
+ updatedGrant, err := getOAuth2GrantByID(e, grant.ID)
+ if err != nil {
+ return err
+ }
+ grant.Counter = updatedGrant.Counter
+ return nil
+}
+
+// ScopeContains returns true if the grant scope contains the specified scope
+func (grant *OAuth2Grant) ScopeContains(scope string) bool {
+ for _, currentScope := range strings.Split(grant.Scope, " ") {
+ if scope == currentScope {
+ return true
+ }
+ }
+ return false
+}
+
+// SetNonce updates the current nonce value of a grant
+func (grant *OAuth2Grant) SetNonce(nonce string) error {
+ return grant.setNonce(db.GetEngine(db.DefaultContext), nonce)
+}
+
+func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
+ grant.Nonce = nonce
+ _, err := e.ID(grant.ID).Cols("nonce").Update(grant)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// GetOAuth2GrantByID returns the grant with the given ID
+func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) {
+ return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id)
+}
+
+func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
+ grant = new(OAuth2Grant)
+ if has, err := e.ID(id).Get(grant); err != nil {
+ return nil, err
+ } else if !has {
+ return nil, nil
+ }
+ return
+}
+
+// GetOAuth2GrantsByUserID lists all grants of a certain user
+func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) {
+ return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid)
+}
+
+func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
+ type joinedOAuth2Grant struct {
+ Grant *OAuth2Grant `xorm:"extends"`
+ Application *OAuth2Application `xorm:"extends"`
+ }
+ var results *xorm.Rows
+ var err error
+ if results, err = e.
+ Table("oauth2_grant").
+ Where("user_id = ?", uid).
+ Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
+ Rows(new(joinedOAuth2Grant)); err != nil {
+ return nil, err
+ }
+ defer results.Close()
+ grants := make([]*OAuth2Grant, 0)
+ for results.Next() {
+ joinedGrant := new(joinedOAuth2Grant)
+ if err := results.Scan(joinedGrant); err != nil {
+ return nil, err
+ }
+ joinedGrant.Grant.Application = joinedGrant.Application
+ grants = append(grants, joinedGrant.Grant)
+ }
+ return grants, nil
+}
+
+// RevokeOAuth2Grant deletes the grant with grantID and userID
+func RevokeOAuth2Grant(grantID, userID int64) error {
+ return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID)
+}
+
+func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error {
+ _, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID})
+ return err
+}
+
+// ErrOAuthClientIDInvalid will be thrown if client id cannot be found
+type ErrOAuthClientIDInvalid struct {
+ ClientID string
+}
+
+// IsErrOauthClientIDInvalid checks if an error is a ErrReviewNotExist.
+func IsErrOauthClientIDInvalid(err error) bool {
+ _, ok := err.(ErrOAuthClientIDInvalid)
+ return ok
+}
+
+// Error returns the error message
+func (err ErrOAuthClientIDInvalid) Error() string {
+ return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID)
+}
+
+// ErrOAuthApplicationNotFound will be thrown if id cannot be found
+type ErrOAuthApplicationNotFound struct {
+ ID int64
+}
+
+// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist.
+func IsErrOAuthApplicationNotFound(err error) bool {
+ _, ok := err.(ErrOAuthApplicationNotFound)
+ return ok
+}
+
+// Error returns the error message
+func (err ErrOAuthApplicationNotFound) Error() string {
+ return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID)
+}
+
+// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources
+func GetActiveOAuth2ProviderSources() ([]*Source, error) {
+ sources := make([]*Source, 0, 1)
+ if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil {
+ return nil, err
+ }
+ return sources, nil
+}
+
+// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name
+func GetActiveOAuth2SourceByName(name string) (*Source, error) {
+ authSource := new(Source)
+ has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource)
+ if !has || err != nil {
+ return nil, err
+ }
+
+ return authSource, nil
+}
diff --git a/models/auth/oauth2_test.go b/models/auth/oauth2_test.go
new file mode 100644
index 0000000000..b712fc285f
--- /dev/null
+++ b/models/auth/oauth2_test.go
@@ -0,0 +1,233 @@
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package auth
+
+import (
+ "testing"
+
+ "code.gitea.io/gitea/models/unittest"
+
+ "github.com/stretchr/testify/assert"
+)
+
+//////////////////// Application
+
+func TestOAuth2Application_GenerateClientSecret(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
+ secret, err := app.GenerateClientSecret()
+ assert.NoError(t, err)
+ assert.True(t, len(secret) > 0)
+ unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1, ClientSecret: app.ClientSecret})
+}
+
+func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) {
+ assert.NoError(b, unittest.PrepareTestDatabase())
+ app := unittest.AssertExistsAndLoadBean(b, &OAuth2Application{ID: 1}).(*OAuth2Application)
+ for i := 0; i < b.N; i++ {
+ _, _ = app.GenerateClientSecret()
+ }
+}
+
+func TestOAuth2Application_ContainsRedirectURI(t *testing.T) {
+ app := &OAuth2Application{
+ RedirectURIs: []string{"a", "b", "c"},
+ }
+ assert.True(t, app.ContainsRedirectURI("a"))
+ assert.True(t, app.ContainsRedirectURI("b"))
+ assert.True(t, app.ContainsRedirectURI("c"))
+ assert.False(t, app.ContainsRedirectURI("d"))
+}
+
+func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
+ secret, err := app.GenerateClientSecret()
+ assert.NoError(t, err)
+ assert.True(t, app.ValidateClientSecret([]byte(secret)))
+ assert.False(t, app.ValidateClientSecret([]byte("fewijfowejgfiowjeoifew")))
+}
+
+func TestGetOAuth2ApplicationByClientID(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ app, err := GetOAuth2ApplicationByClientID("da7da3ba-9a13-4167-856f-3899de0b0138")
+ assert.NoError(t, err)
+ assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
+
+ app, err = GetOAuth2ApplicationByClientID("invalid client id")
+ assert.Error(t, err)
+ assert.Nil(t, app)
+}
+
+func TestCreateOAuth2Application(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ app, err := CreateOAuth2Application(CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
+ assert.NoError(t, err)
+ assert.Equal(t, "newapp", app.Name)
+ assert.Len(t, app.ClientID, 36)
+ unittest.AssertExistsAndLoadBean(t, &OAuth2Application{Name: "newapp"})
+}
+
+func TestOAuth2Application_TableName(t *testing.T) {
+ assert.Equal(t, "oauth2_application", new(OAuth2Application).TableName())
+}
+
+func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
+ grant, err := app.GetGrantByUserID(1)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(1), grant.UserID)
+
+ grant, err = app.GetGrantByUserID(34923458)
+ assert.NoError(t, err)
+ assert.Nil(t, grant)
+}
+
+func TestOAuth2Application_CreateGrant(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
+ grant, err := app.CreateGrant(2, "")
+ assert.NoError(t, err)
+ assert.NotNil(t, grant)
+ assert.Equal(t, int64(2), grant.UserID)
+ assert.Equal(t, int64(1), grant.ApplicationID)
+ assert.Equal(t, "", grant.Scope)
+}
+
+//////////////////// Grant
+
+func TestGetOAuth2GrantByID(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ grant, err := GetOAuth2GrantByID(1)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(1), grant.ID)
+
+ grant, err = GetOAuth2GrantByID(34923458)
+ assert.NoError(t, err)
+ assert.Nil(t, grant)
+}
+
+func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 1}).(*OAuth2Grant)
+ assert.NoError(t, grant.IncreaseCounter())
+ assert.Equal(t, int64(2), grant.Counter)
+ unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 2})
+}
+
+func TestOAuth2Grant_ScopeContains(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Scope: "openid profile"}).(*OAuth2Grant)
+ assert.True(t, grant.ScopeContains("openid"))
+ assert.True(t, grant.ScopeContains("profile"))
+ assert.False(t, grant.ScopeContains("profil"))
+ assert.False(t, grant.ScopeContains("profile2"))
+}
+
+func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1}).(*OAuth2Grant)
+ code, err := grant.GenerateNewAuthorizationCode("https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
+ assert.NoError(t, err)
+ assert.NotNil(t, code)
+ assert.True(t, len(code.Code) > 32) // secret length > 32
+}
+
+func TestOAuth2Grant_TableName(t *testing.T) {
+ assert.Equal(t, "oauth2_grant", new(OAuth2Grant).TableName())
+}
+
+func TestGetOAuth2GrantsByUserID(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ result, err := GetOAuth2GrantsByUserID(1)
+ assert.NoError(t, err)
+ assert.Len(t, result, 1)
+ assert.Equal(t, int64(1), result[0].ID)
+ assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
+
+ result, err = GetOAuth2GrantsByUserID(34134)
+ assert.NoError(t, err)
+ assert.Empty(t, result)
+}
+
+func TestRevokeOAuth2Grant(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ assert.NoError(t, RevokeOAuth2Grant(1, 1))
+ unittest.AssertNotExistsBean(t, &OAuth2Grant{ID: 1, UserID: 1})
+}
+
+//////////////////// Authorization Code
+
+func TestGetOAuth2AuthorizationByCode(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ code, err := GetOAuth2AuthorizationByCode("authcode")
+ assert.NoError(t, err)
+ assert.NotNil(t, code)
+ assert.Equal(t, "authcode", code.Code)
+ assert.Equal(t, int64(1), code.ID)
+
+ code, err = GetOAuth2AuthorizationByCode("does not exist")
+ assert.NoError(t, err)
+ assert.Nil(t, code)
+}
+
+func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) {
+ // test plain
+ code := &OAuth2AuthorizationCode{
+ CodeChallengeMethod: "plain",
+ CodeChallenge: "test123",
+ }
+ assert.True(t, code.ValidateCodeChallenge("test123"))
+ assert.False(t, code.ValidateCodeChallenge("ierwgjoergjio"))
+
+ // test S256
+ code = &OAuth2AuthorizationCode{
+ CodeChallengeMethod: "S256",
+ CodeChallenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg",
+ }
+ assert.True(t, code.ValidateCodeChallenge("N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt"))
+ assert.False(t, code.ValidateCodeChallenge("wiogjerogorewngoenrgoiuenorg"))
+
+ // test unknown
+ code = &OAuth2AuthorizationCode{
+ CodeChallengeMethod: "monkey",
+ CodeChallenge: "foiwgjioriogeiogjerger",
+ }
+ assert.False(t, code.ValidateCodeChallenge("foiwgjioriogeiogjerger"))
+
+ // test no code challenge
+ code = &OAuth2AuthorizationCode{
+ CodeChallengeMethod: "",
+ CodeChallenge: "foierjiogerogerg",
+ }
+ assert.True(t, code.ValidateCodeChallenge(""))
+}
+
+func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
+ code := &OAuth2AuthorizationCode{
+ RedirectURI: "https://example.com/callback",
+ Code: "thecode",
+ }
+
+ redirect, err := code.GenerateRedirectURI("thestate")
+ assert.NoError(t, err)
+ assert.Equal(t, "https://example.com/callback?code=thecode&state=thestate", redirect.String())
+
+ redirect, err = code.GenerateRedirectURI("")
+ assert.NoError(t, err)
+ assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String())
+}
+
+func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ code := unittest.AssertExistsAndLoadBean(t, &OAuth2AuthorizationCode{Code: "authcode"}).(*OAuth2AuthorizationCode)
+ assert.NoError(t, code.Invalidate())
+ unittest.AssertNotExistsBean(t, &OAuth2AuthorizationCode{Code: "authcode"})
+}
+
+func TestOAuth2AuthorizationCode_TableName(t *testing.T) {
+ assert.Equal(t, "oauth2_authorization_code", new(OAuth2AuthorizationCode).TableName())
+}
diff --git a/models/auth/session.go b/models/auth/session.go
new file mode 100644
index 0000000000..5b130c64b6
--- /dev/null
+++ b/models/auth/session.go
@@ -0,0 +1,126 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package auth
+
+import (
+ "fmt"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/timeutil"
+)
+
+// Session represents a session compatible for go-chi session
+type Session struct {
+ Key string `xorm:"pk CHAR(16)"` // has to be Key to match with go-chi/session
+ Data []byte `xorm:"BLOB"` // on MySQL this has a maximum size of 64Kb - this may need to be increased
+ Expiry timeutil.TimeStamp // has to be Expiry to match with go-chi/session
+}
+
+func init() {
+ db.RegisterModel(new(Session))
+}
+
+// UpdateSession updates the session with provided id
+func UpdateSession(key string, data []byte) error {
+ _, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{
+ Data: data,
+ Expiry: timeutil.TimeStampNow(),
+ })
+ return err
+}
+
+// ReadSession reads the data for the provided session
+func ReadSession(key string) (*Session, error) {
+ session := Session{
+ Key: key,
+ }
+
+ ctx, committer, err := db.TxContext()
+ if err != nil {
+ return nil, err
+ }
+ defer committer.Close()
+
+ if has, err := db.GetByBean(ctx, &session); err != nil {
+ return nil, err
+ } else if !has {
+ session.Expiry = timeutil.TimeStampNow()
+ if err := db.Insert(ctx, &session); err != nil {
+ return nil, err
+ }
+ }
+
+ return &session, committer.Commit()
+}
+
+// ExistSession checks if a session exists
+func ExistSession(key string) (bool, error) {
+ session := Session{
+ Key: key,
+ }
+ return db.GetEngine(db.DefaultContext).Get(&session)
+}
+
+// DestroySession destroys a session
+func DestroySession(key string) error {
+ _, err := db.GetEngine(db.DefaultContext).Delete(&Session{
+ Key: key,
+ })
+ return err
+}
+
+// RegenerateSession regenerates a session from the old id
+func RegenerateSession(oldKey, newKey string) (*Session, error) {
+ ctx, committer, err := db.TxContext()
+ if err != nil {
+ return nil, err
+ }
+ defer committer.Close()
+
+ if has, err := db.GetByBean(ctx, &Session{
+ Key: newKey,
+ }); err != nil {
+ return nil, err
+ } else if has {
+ return nil, fmt.Errorf("session Key: %s already exists", newKey)
+ }
+
+ if has, err := db.GetByBean(ctx, &Session{
+ Key: oldKey,
+ }); err != nil {
+ return nil, err
+ } else if !has {
+ if err := db.Insert(ctx, &Session{
+ Key: oldKey,
+ Expiry: timeutil.TimeStampNow(),
+ }); err != nil {
+ return nil, err
+ }
+ }
+
+ if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
+ return nil, err
+ }
+
+ s := Session{
+ Key: newKey,
+ }
+ if _, err := db.GetByBean(ctx, &s); err != nil {
+ return nil, err
+ }
+
+ return &s, committer.Commit()
+}
+
+// CountSessions returns the number of sessions
+func CountSessions() (int64, error) {
+ return db.GetEngine(db.DefaultContext).Count(&Session{})
+}
+
+// CleanupSessions cleans up expired sessions
+func CleanupSessions(maxLifetime int64) error {
+ _, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
+ return err
+}
diff --git a/models/auth/source.go b/models/auth/source.go
new file mode 100644
index 0000000000..6f4f5addcb
--- /dev/null
+++ b/models/auth/source.go
@@ -0,0 +1,397 @@
+// Copyright 2014 The Gogs Authors. All rights reserved.
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package auth
+
+import (
+ "fmt"
+ "reflect"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/log"
+ "code.gitea.io/gitea/modules/timeutil"
+
+ "xorm.io/xorm"
+ "xorm.io/xorm/convert"
+)
+
+// Type represents an login type.
+type Type int
+
+// Note: new type must append to the end of list to maintain compatibility.
+const (
+ NoType Type = iota
+ Plain // 1
+ LDAP // 2
+ SMTP // 3
+ PAM // 4
+ DLDAP // 5
+ OAuth2 // 6
+ SSPI // 7
+)
+
+// String returns the string name of the LoginType
+func (typ Type) String() string {
+ return Names[typ]
+}
+
+// Int returns the int value of the LoginType
+func (typ Type) Int() int {
+ return int(typ)
+}
+
+// Names contains the name of LoginType values.
+var Names = map[Type]string{
+ LDAP: "LDAP (via BindDN)",
+ DLDAP: "LDAP (simple auth)", // Via direct bind
+ SMTP: "SMTP",
+ PAM: "PAM",
+ OAuth2: "OAuth2",
+ SSPI: "SPNEGO with SSPI",
+}
+
+// Config represents login config as far as the db is concerned
+type Config interface {
+ convert.Conversion
+}
+
+// SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set
+type SkipVerifiable interface {
+ IsSkipVerify() bool
+}
+
+// HasTLSer configurations provide a HasTLS to check if TLS can be enabled
+type HasTLSer interface {
+ HasTLS() bool
+}
+
+// UseTLSer configurations provide a HasTLS to check if TLS is enabled
+type UseTLSer interface {
+ UseTLS() bool
+}
+
+// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys
+type SSHKeyProvider interface {
+ ProvidesSSHKeys() bool
+}
+
+// RegisterableSource configurations provide RegisterSource which needs to be run on creation
+type RegisterableSource interface {
+ RegisterSource() error
+ UnregisterSource() error
+}
+
+var registeredConfigs = map[Type]func() Config{}
+
+// RegisterTypeConfig register a config for a provided type
+func RegisterTypeConfig(typ Type, exemplar Config) {
+ if reflect.TypeOf(exemplar).Kind() == reflect.Ptr {
+ // Pointer:
+ registeredConfigs[typ] = func() Config {
+ return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config)
+ }
+ return
+ }
+
+ // Not a Pointer
+ registeredConfigs[typ] = func() Config {
+ return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config)
+ }
+}
+
+// SourceSettable configurations can have their authSource set on them
+type SourceSettable interface {
+ SetAuthSource(*Source)
+}
+
+// Source represents an external way for authorizing users.
+type Source struct {
+ ID int64 `xorm:"pk autoincr"`
+ Type Type
+ Name string `xorm:"UNIQUE"`
+ IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"`
+ IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"`
+ Cfg convert.Conversion `xorm:"TEXT"`
+
+ CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
+}
+
+// TableName xorm will read the table name from this method
+func (Source) TableName() string {
+ return "login_source"
+}
+
+func init() {
+ db.RegisterModel(new(Source))
+}
+
+// BeforeSet is invoked from XORM before setting the value of a field of this object.
+func (source *Source) BeforeSet(colName string, val xorm.Cell) {
+ if colName == "type" {
+ typ := Type(db.Cell2Int64(val))
+ constructor, ok := registeredConfigs[typ]
+ if !ok {
+ return
+ }
+ source.Cfg = constructor()
+ if settable, ok := source.Cfg.(SourceSettable); ok {
+ settable.SetAuthSource(source)
+ }
+ }
+}
+
+// TypeName return name of this login source type.
+func (source *Source) TypeName() string {
+ return Names[source.Type]
+}
+
+// IsLDAP returns true of this source is of the LDAP type.
+func (source *Source) IsLDAP() bool {
+ return source.Type == LDAP
+}
+
+// IsDLDAP returns true of this source is of the DLDAP type.
+func (source *Source) IsDLDAP() bool {
+ return source.Type == DLDAP
+}
+
+// IsSMTP returns true of this source is of the SMTP type.
+func (source *Source) IsSMTP() bool {
+ return source.Type == SMTP
+}
+
+// IsPAM returns true of this source is of the PAM type.
+func (source *Source) IsPAM() bool {
+ return source.Type == PAM
+}
+
+// IsOAuth2 returns true of this source is of the OAuth2 type.
+func (source *Source) IsOAuth2() bool {
+ return source.Type == OAuth2
+}
+
+// IsSSPI returns true of this source is of the SSPI type.
+func (source *Source) IsSSPI() bool {
+ return source.Type == SSPI
+}
+
+// HasTLS returns true of this source supports TLS.
+func (source *Source) HasTLS() bool {
+ hasTLSer, ok := source.Cfg.(HasTLSer)
+ return ok && hasTLSer.HasTLS()
+}
+
+// UseTLS returns true of this source is configured to use TLS.
+func (source *Source) UseTLS() bool {
+ useTLSer, ok := source.Cfg.(UseTLSer)
+ return ok && useTLSer.UseTLS()
+}
+
+// SkipVerify returns true if this source is configured to skip SSL
+// verification.
+func (source *Source) SkipVerify() bool {
+ skipVerifiable, ok := source.Cfg.(SkipVerifiable)
+ return ok && skipVerifiable.IsSkipVerify()
+}
+
+// CreateSource inserts a AuthSource in the DB if not already
+// existing with the given name.
+func CreateSource(source *Source) error {
+ has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source))
+ if err != nil {
+ return err
+ } else if has {
+ return ErrSourceAlreadyExist{source.Name}
+ }
+ // Synchronization is only available with LDAP for now
+ if !source.IsLDAP() {
+ source.IsSyncEnabled = false
+ }
+
+ _, err = db.GetEngine(db.DefaultContext).Insert(source)
+ if err != nil {
+ return err
+ }
+
+ if !source.IsActive {
+ return nil
+ }
+
+ if settable, ok := source.Cfg.(SourceSettable); ok {
+ settable.SetAuthSource(source)
+ }
+
+ registerableSource, ok := source.Cfg.(RegisterableSource)
+ if !ok {
+ return nil
+ }
+
+ err = registerableSource.RegisterSource()
+ if err != nil {
+ // remove the AuthSource in case of errors while registering configuration
+ if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil {
+ log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
+ }
+ }
+ return err
+}
+
+// Sources returns a slice of all login sources found in DB.
+func Sources() ([]*Source, error) {
+ auths := make([]*Source, 0, 6)
+ return auths, db.GetEngine(db.DefaultContext).Find(&auths)
+}
+
+// SourcesByType returns all sources of the specified type
+func SourcesByType(loginType Type) ([]*Source, error) {
+ sources := make([]*Source, 0, 1)
+ if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil {
+ return nil, err
+ }
+ return sources, nil
+}
+
+// AllActiveSources returns all active sources
+func AllActiveSources() ([]*Source, error) {
+ sources := make([]*Source, 0, 5)
+ if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil {
+ return nil, err
+ }
+ return sources, nil
+}
+
+// ActiveSources returns all active sources of the specified type
+func ActiveSources(tp Type) ([]*Source, error) {
+ sources := make([]*Source, 0, 1)
+ if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
+ return nil, err
+ }
+ return sources, nil
+}
+
+// IsSSPIEnabled returns true if there is at least one activated login
+// source of type LoginSSPI
+func IsSSPIEnabled() bool {
+ if !db.HasEngine {
+ return false
+ }
+ sources, err := ActiveSources(SSPI)
+ if err != nil {
+ log.Error("ActiveSources: %v", err)
+ return false
+ }
+ return len(sources) > 0
+}
+
+// GetSourceByID returns login source by given ID.
+func GetSourceByID(id int64) (*Source, error) {
+ source := new(Source)
+ if id == 0 {
+ source.Cfg = registeredConfigs[NoType]()
+ // Set this source to active
+ // FIXME: allow disabling of db based password authentication in future
+ source.IsActive = true
+ return source, nil
+ }
+
+ has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source)
+ if err != nil {
+ return nil, err
+ } else if !has {
+ return nil, ErrSourceNotExist{id}
+ }
+ return source, nil
+}
+
+// UpdateSource updates a Source record in DB.
+func UpdateSource(source *Source) error {
+ var originalSource *Source
+ if source.IsOAuth2() {
+ // keep track of the original values so we can restore in case of errors while registering OAuth2 providers
+ var err error
+ if originalSource, err = GetSourceByID(source.ID); err != nil {
+ return err
+ }
+ }
+
+ _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source)
+ if err != nil {
+ return err
+ }
+
+ if !source.IsActive {
+ return nil
+ }
+
+ if settable, ok := source.Cfg.(SourceSettable); ok {
+ settable.SetAuthSource(source)
+ }
+
+ registerableSource, ok := source.Cfg.(RegisterableSource)
+ if !ok {
+ return nil
+ }
+
+ err = registerableSource.RegisterSource()
+ if err != nil {
+ // restore original values since we cannot update the provider it self
+ if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil {
+ log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
+ }
+ }
+ return err
+}
+
+// CountSources returns number of login sources.
+func CountSources() int64 {
+ count, _ := db.GetEngine(db.DefaultContext).Count(new(Source))
+ return count
+}
+
+// ErrSourceNotExist represents a "SourceNotExist" kind of error.
+type ErrSourceNotExist struct {
+ ID int64
+}
+
+// IsErrSourceNotExist checks if an error is a ErrSourceNotExist.
+func IsErrSourceNotExist(err error) bool {
+ _, ok := err.(ErrSourceNotExist)
+ return ok
+}
+
+func (err ErrSourceNotExist) Error() string {
+ return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
+}
+
+// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error.
+type ErrSourceAlreadyExist struct {
+ Name string
+}
+
+// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist.
+func IsErrSourceAlreadyExist(err error) bool {
+ _, ok := err.(ErrSourceAlreadyExist)
+ return ok
+}
+
+func (err ErrSourceAlreadyExist) Error() string {
+ return fmt.Sprintf("login source already exists [name: %s]", err.Name)
+}
+
+// ErrSourceInUse represents a "SourceInUse" kind of error.
+type ErrSourceInUse struct {
+ ID int64
+}
+
+// IsErrSourceInUse checks if an error is a ErrSourceInUse.
+func IsErrSourceInUse(err error) bool {
+ _, ok := err.(ErrSourceInUse)
+ return ok
+}
+
+func (err ErrSourceInUse) Error() string {
+ return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
+}
diff --git a/models/auth/source_test.go b/models/auth/source_test.go
new file mode 100644
index 0000000000..6a8e286910
--- /dev/null
+++ b/models/auth/source_test.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package auth
+
+import (
+ "strings"
+ "testing"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/models/unittest"
+ "code.gitea.io/gitea/modules/json"
+
+ "github.com/stretchr/testify/assert"
+ "xorm.io/xorm/schemas"
+)
+
+type TestSource struct {
+ Provider string
+ ClientID string
+ ClientSecret string
+ OpenIDConnectAutoDiscoveryURL string
+ IconURL string
+}
+
+// FromDB fills up a LDAPConfig from serialized format.
+func (source *TestSource) FromDB(bs []byte) error {
+ return json.Unmarshal(bs, &source)
+}
+
+// ToDB exports a LDAPConfig to a serialized format.
+func (source *TestSource) ToDB() ([]byte, error) {
+ return json.Marshal(source)
+}
+
+func TestDumpAuthSource(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+
+ authSourceSchema, err := db.TableInfo(new(Source))
+ assert.NoError(t, err)
+
+ RegisterTypeConfig(OAuth2, new(TestSource))
+
+ CreateSource(&Source{
+ Type: OAuth2,
+ Name: "TestSource",
+ IsActive: false,
+ Cfg: &TestSource{
+ Provider: "ConvertibleSourceName",
+ ClientID: "42",
+ },
+ })
+
+ sb := new(strings.Builder)
+
+ db.DumpTables([]*schemas.Table{authSourceSchema}, sb)
+
+ assert.Contains(t, sb.String(), `"Provider":"ConvertibleSourceName"`)
+}
diff --git a/models/auth/twofactor.go b/models/auth/twofactor.go
new file mode 100644
index 0000000000..883e6ce01c
--- /dev/null
+++ b/models/auth/twofactor.go
@@ -0,0 +1,156 @@
+// Copyright 2017 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package auth
+
+import (
+ "crypto/md5"
+ "crypto/sha256"
+ "crypto/subtle"
+ "encoding/base64"
+ "fmt"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/secret"
+ "code.gitea.io/gitea/modules/setting"
+ "code.gitea.io/gitea/modules/timeutil"
+ "code.gitea.io/gitea/modules/util"
+
+ "github.com/pquerna/otp/totp"
+ "golang.org/x/crypto/pbkdf2"
+)
+
+//
+// Two-factor authentication
+//
+
+// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication.
+type ErrTwoFactorNotEnrolled struct {
+ UID int64
+}
+
+// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled.
+func IsErrTwoFactorNotEnrolled(err error) bool {
+ _, ok := err.(ErrTwoFactorNotEnrolled)
+ return ok
+}
+
+func (err ErrTwoFactorNotEnrolled) Error() string {
+ return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID)
+}
+
+// TwoFactor represents a two-factor authentication token.
+type TwoFactor struct {
+ ID int64 `xorm:"pk autoincr"`
+ UID int64 `xorm:"UNIQUE"`
+ Secret string
+ ScratchSalt string
+ ScratchHash string
+ LastUsedPasscode string `xorm:"VARCHAR(10)"`
+ CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
+}
+
+func init() {
+ db.RegisterModel(new(TwoFactor))
+}
+
+// GenerateScratchToken recreates the scratch token the user is using.
+func (t *TwoFactor) GenerateScratchToken() (string, error) {
+ token, err := util.RandomString(8)
+ if err != nil {
+ return "", err
+ }
+ t.ScratchSalt, _ = util.RandomString(10)
+ t.ScratchHash = HashToken(token, t.ScratchSalt)
+ return token, nil
+}
+
+// HashToken return the hashable salt
+func HashToken(token, salt string) string {
+ tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New)
+ return fmt.Sprintf("%x", tempHash)
+}
+
+// VerifyScratchToken verifies if the specified scratch token is valid.
+func (t *TwoFactor) VerifyScratchToken(token string) bool {
+ if len(token) == 0 {
+ return false
+ }
+ tempHash := HashToken(token, t.ScratchSalt)
+ return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1
+}
+
+func (t *TwoFactor) getEncryptionKey() []byte {
+ k := md5.Sum([]byte(setting.SecretKey))
+ return k[:]
+}
+
+// SetSecret sets the 2FA secret.
+func (t *TwoFactor) SetSecret(secretString string) error {
+ secretBytes, err := secret.AesEncrypt(t.getEncryptionKey(), []byte(secretString))
+ if err != nil {
+ return err
+ }
+ t.Secret = base64.StdEncoding.EncodeToString(secretBytes)
+ return nil
+}
+
+// ValidateTOTP validates the provided passcode.
+func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
+ decodedStoredSecret, err := base64.StdEncoding.DecodeString(t.Secret)
+ if err != nil {
+ return false, err
+ }
+ secretBytes, err := secret.AesDecrypt(t.getEncryptionKey(), decodedStoredSecret)
+ if err != nil {
+ return false, err
+ }
+ secretStr := string(secretBytes)
+ return totp.Validate(passcode, secretStr), nil
+}
+
+// NewTwoFactor creates a new two-factor authentication token.
+func NewTwoFactor(t *TwoFactor) error {
+ _, err := db.GetEngine(db.DefaultContext).Insert(t)
+ return err
+}
+
+// UpdateTwoFactor updates a two-factor authentication token.
+func UpdateTwoFactor(t *TwoFactor) error {
+ _, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
+ return err
+}
+
+// GetTwoFactorByUID returns the two-factor authentication token associated with
+// the user, if any.
+func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
+ twofa := &TwoFactor{}
+ has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa)
+ if err != nil {
+ return nil, err
+ } else if !has {
+ return nil, ErrTwoFactorNotEnrolled{uid}
+ }
+ return twofa, nil
+}
+
+// HasTwoFactorByUID returns the two-factor authentication token associated with
+// the user, if any.
+func HasTwoFactorByUID(uid int64) (bool, error) {
+ return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{})
+}
+
+// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
+func DeleteTwoFactorByID(id, userID int64) error {
+ cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{
+ UID: userID,
+ })
+ if err != nil {
+ return err
+ } else if cnt != 1 {
+ return ErrTwoFactorNotEnrolled{userID}
+ }
+ return nil
+}
diff --git a/models/auth/u2f.go b/models/auth/u2f.go
new file mode 100644
index 0000000000..71943b237c
--- /dev/null
+++ b/models/auth/u2f.go
@@ -0,0 +1,154 @@
+// Copyright 2018 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package auth
+
+import (
+ "fmt"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/log"
+ "code.gitea.io/gitea/modules/timeutil"
+
+ "github.com/tstranex/u2f"
+)
+
+// ____ ________________________________ .__ __ __ .__
+// | | \_____ \_ _____/\______ \ ____ ____ |__| _______/ |_____________ _/ |_|__| ____ ____
+// | | // ____/| __) | _// __ \ / ___\| |/ ___/\ __\_ __ \__ \\ __\ |/ _ \ / \
+// | | // \| \ | | \ ___// /_/ > |\___ \ | | | | \// __ \| | | ( <_> ) | \
+// |______/ \_______ \___ / |____|_ /\___ >___ /|__/____ > |__| |__| (____ /__| |__|\____/|___| /
+// \/ \/ \/ \/_____/ \/ \/ \/
+
+// ErrU2FRegistrationNotExist represents a "ErrU2FRegistrationNotExist" kind of error.
+type ErrU2FRegistrationNotExist struct {
+ ID int64
+}
+
+func (err ErrU2FRegistrationNotExist) Error() string {
+ return fmt.Sprintf("U2F registration does not exist [id: %d]", err.ID)
+}
+
+// IsErrU2FRegistrationNotExist checks if an error is a ErrU2FRegistrationNotExist.
+func IsErrU2FRegistrationNotExist(err error) bool {
+ _, ok := err.(ErrU2FRegistrationNotExist)
+ return ok
+}
+
+// U2FRegistration represents the registration data and counter of a security key
+type U2FRegistration struct {
+ ID int64 `xorm:"pk autoincr"`
+ Name string
+ UserID int64 `xorm:"INDEX"`
+ Raw []byte
+ Counter uint32 `xorm:"BIGINT"`
+ CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
+}
+
+func init() {
+ db.RegisterModel(new(U2FRegistration))
+}
+
+// TableName returns a better table name for U2FRegistration
+func (reg U2FRegistration) TableName() string {
+ return "u2f_registration"
+}
+
+// Parse will convert the db entry U2FRegistration to an u2f.Registration struct
+func (reg *U2FRegistration) Parse() (*u2f.Registration, error) {
+ r := new(u2f.Registration)
+ return r, r.UnmarshalBinary(reg.Raw)
+}
+
+func (reg *U2FRegistration) updateCounter(e db.Engine) error {
+ _, err := e.ID(reg.ID).Cols("counter").Update(reg)
+ return err
+}
+
+// UpdateCounter will update the database value of counter
+func (reg *U2FRegistration) UpdateCounter() error {
+ return reg.updateCounter(db.GetEngine(db.DefaultContext))
+}
+
+// U2FRegistrationList is a list of *U2FRegistration
+type U2FRegistrationList []*U2FRegistration
+
+// ToRegistrations will convert all U2FRegistrations to u2f.Registrations
+func (list U2FRegistrationList) ToRegistrations() []u2f.Registration {
+ regs := make([]u2f.Registration, 0, len(list))
+ for _, reg := range list {
+ r, err := reg.Parse()
+ if err != nil {
+ log.Error("parsing u2f registration: %v", err)
+ continue
+ }
+ regs = append(regs, *r)
+ }
+
+ return regs
+}
+
+func getU2FRegistrationsByUID(e db.Engine, uid int64) (U2FRegistrationList, error) {
+ regs := make(U2FRegistrationList, 0)
+ return regs, e.Where("user_id = ?", uid).Find(&regs)
+}
+
+// GetU2FRegistrationByID returns U2F registration by id
+func GetU2FRegistrationByID(id int64) (*U2FRegistration, error) {
+ return getU2FRegistrationByID(db.GetEngine(db.DefaultContext), id)
+}
+
+func getU2FRegistrationByID(e db.Engine, id int64) (*U2FRegistration, error) {
+ reg := new(U2FRegistration)
+ if found, err := e.ID(id).Get(reg); err != nil {
+ return nil, err
+ } else if !found {
+ return nil, ErrU2FRegistrationNotExist{ID: id}
+ }
+ return reg, nil
+}
+
+// GetU2FRegistrationsByUID returns all U2F registrations of the given user
+func GetU2FRegistrationsByUID(uid int64) (U2FRegistrationList, error) {
+ return getU2FRegistrationsByUID(db.GetEngine(db.DefaultContext), uid)
+}
+
+// HasU2FRegistrationsByUID returns whether a given user has U2F registrations
+func HasU2FRegistrationsByUID(uid int64) (bool, error) {
+ return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&U2FRegistration{})
+}
+
+func createRegistration(e db.Engine, userID int64, name string, reg *u2f.Registration) (*U2FRegistration, error) {
+ raw, err := reg.MarshalBinary()
+ if err != nil {
+ return nil, err
+ }
+ r := &U2FRegistration{
+ UserID: userID,
+ Name: name,
+ Counter: 0,
+ Raw: raw,
+ }
+ _, err = e.InsertOne(r)
+ if err != nil {
+ return nil, err
+ }
+ return r, nil
+}
+
+// CreateRegistration will create a new U2FRegistration from the given Registration
+func CreateRegistration(userID int64, name string, reg *u2f.Registration) (*U2FRegistration, error) {
+ return createRegistration(db.GetEngine(db.DefaultContext), userID, name, reg)
+}
+
+// DeleteRegistration will delete U2FRegistration
+func DeleteRegistration(reg *U2FRegistration) error {
+ return deleteRegistration(db.GetEngine(db.DefaultContext), reg)
+}
+
+func deleteRegistration(e db.Engine, reg *U2FRegistration) error {
+ _, err := e.Delete(reg)
+ return err
+}
diff --git a/models/auth/u2f_test.go b/models/auth/u2f_test.go
new file mode 100644
index 0000000000..32ad17839c
--- /dev/null
+++ b/models/auth/u2f_test.go
@@ -0,0 +1,100 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package auth
+
+import (
+ "encoding/hex"
+ "testing"
+
+ "code.gitea.io/gitea/models/unittest"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/tstranex/u2f"
+)
+
+func TestGetU2FRegistrationByID(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+
+ res, err := GetU2FRegistrationByID(1)
+ assert.NoError(t, err)
+ assert.Equal(t, "U2F Key", res.Name)
+
+ _, err = GetU2FRegistrationByID(342432)
+ assert.Error(t, err)
+ assert.True(t, IsErrU2FRegistrationNotExist(err))
+}
+
+func TestGetU2FRegistrationsByUID(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+
+ res, err := GetU2FRegistrationsByUID(32)
+
+ assert.NoError(t, err)
+ assert.Len(t, res, 1)
+ assert.Equal(t, "U2F Key", res[0].Name)
+}
+
+func TestU2FRegistration_TableName(t *testing.T) {
+ assert.Equal(t, "u2f_registration", U2FRegistration{}.TableName())
+}
+
+func TestU2FRegistration_UpdateCounter(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration)
+ reg.Counter = 1
+ assert.NoError(t, reg.UpdateCounter())
+ unittest.AssertExistsIf(t, true, &U2FRegistration{ID: 1, Counter: 1})
+}
+
+func TestU2FRegistration_UpdateLargeCounter(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration)
+ reg.Counter = 0xffffffff
+ assert.NoError(t, reg.UpdateCounter())
+ unittest.AssertExistsIf(t, true, &U2FRegistration{ID: 1, Counter: 0xffffffff})
+}
+
+func TestCreateRegistration(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+
+ res, err := CreateRegistration(1, "U2F Created Key", &u2f.Registration{Raw: []byte("Test")})
+ assert.NoError(t, err)
+ assert.Equal(t, "U2F Created Key", res.Name)
+ assert.Equal(t, []byte("Test"), res.Raw)
+
+ unittest.AssertExistsIf(t, true, &U2FRegistration{Name: "U2F Created Key", UserID: 1})
+}
+
+func TestDeleteRegistration(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+ reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration)
+
+ assert.NoError(t, DeleteRegistration(reg))
+ unittest.AssertNotExistsBean(t, &U2FRegistration{ID: 1})
+}
+
+const validU2FRegistrationResponseHex = "0504b174bc49c7ca254b70d2e5c207cee9cf174820ebd77ea3c65508c26da51b657c1cc6b952f8621697936482da0a6d3d3826a59095daf6cd7c03e2e60385d2f6d9402a552dfdb7477ed65fd84133f86196010b2215b57da75d315b7b9e8fe2e3925a6019551bab61d16591659cbaf00b4950f7abfe6660e2e006f76868b772d70c253082013c3081e4a003020102020a47901280001155957352300a06082a8648ce3d0403023017311530130603550403130c476e756262792050696c6f74301e170d3132303831343138323933325a170d3133303831343138323933325a3031312f302d0603550403132650696c6f74476e756262792d302e342e312d34373930313238303030313135353935373335323059301306072a8648ce3d020106082a8648ce3d030107034200048d617e65c9508e64bcc5673ac82a6799da3c1446682c258c463fffdf58dfd2fa3e6c378b53d795c4a4dffb4199edd7862f23abaf0203b4b8911ba0569994e101300a06082a8648ce3d0403020347003044022060cdb6061e9c22262d1aac1d96d8c70829b2366531dda268832cb836bcd30dfa0220631b1459f09e6330055722c8d89b7f48883b9089b88d60d1d9795902b30410df304502201471899bcc3987e62e8202c9b39c33c19033f7340352dba80fcab017db9230e402210082677d673d891933ade6f617e5dbde2e247e70423fd5ad7804a6d3d3961ef871"
+
+func TestToRegistrations_SkipInvalidItemsWithoutCrashing(t *testing.T) {
+ regKeyRaw, _ := hex.DecodeString(validU2FRegistrationResponseHex)
+ regs := U2FRegistrationList{
+ &U2FRegistration{ID: 1},
+ &U2FRegistration{ID: 2, Name: "U2F Key", UserID: 2, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800},
+ }
+
+ actual := regs.ToRegistrations()
+ assert.Len(t, actual, 1)
+}
+
+func TestToRegistrations(t *testing.T) {
+ regKeyRaw, _ := hex.DecodeString(validU2FRegistrationResponseHex)
+ regs := U2FRegistrationList{
+ &U2FRegistration{ID: 1, Name: "U2F Key", UserID: 1, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800},
+ &U2FRegistration{ID: 2, Name: "U2F Key", UserID: 2, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800},
+ }
+
+ actual := regs.ToRegistrations()
+ assert.Len(t, actual, 2)
+}