diff options
author | Lunny Xiao <xiaolunwen@gmail.com> | 2022-01-02 21:12:35 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-02 21:12:35 +0800 |
commit | de8e3948a5e38f7eaf82d3c0cfd10e995bf68e92 (patch) | |
tree | bbcb011d264e0d614d49c734856b446360c5a4a3 /models/auth/source.go | |
parent | e61b390d545919244141b699b28e3fbc42adc66f (diff) | |
download | gitea-de8e3948a5e38f7eaf82d3c0cfd10e995bf68e92.tar.gz gitea-de8e3948a5e38f7eaf82d3c0cfd10e995bf68e92.zip |
Refactor auth package (#17962)
Diffstat (limited to 'models/auth/source.go')
-rw-r--r-- | models/auth/source.go | 397 |
1 files changed, 397 insertions, 0 deletions
diff --git a/models/auth/source.go b/models/auth/source.go new file mode 100644 index 0000000000..6f4f5addcb --- /dev/null +++ b/models/auth/source.go @@ -0,0 +1,397 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Copyright 2019 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "fmt" + "reflect" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/timeutil" + + "xorm.io/xorm" + "xorm.io/xorm/convert" +) + +// Type represents an login type. +type Type int + +// Note: new type must append to the end of list to maintain compatibility. +const ( + NoType Type = iota + Plain // 1 + LDAP // 2 + SMTP // 3 + PAM // 4 + DLDAP // 5 + OAuth2 // 6 + SSPI // 7 +) + +// String returns the string name of the LoginType +func (typ Type) String() string { + return Names[typ] +} + +// Int returns the int value of the LoginType +func (typ Type) Int() int { + return int(typ) +} + +// Names contains the name of LoginType values. +var Names = map[Type]string{ + LDAP: "LDAP (via BindDN)", + DLDAP: "LDAP (simple auth)", // Via direct bind + SMTP: "SMTP", + PAM: "PAM", + OAuth2: "OAuth2", + SSPI: "SPNEGO with SSPI", +} + +// Config represents login config as far as the db is concerned +type Config interface { + convert.Conversion +} + +// SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set +type SkipVerifiable interface { + IsSkipVerify() bool +} + +// HasTLSer configurations provide a HasTLS to check if TLS can be enabled +type HasTLSer interface { + HasTLS() bool +} + +// UseTLSer configurations provide a HasTLS to check if TLS is enabled +type UseTLSer interface { + UseTLS() bool +} + +// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys +type SSHKeyProvider interface { + ProvidesSSHKeys() bool +} + +// RegisterableSource configurations provide RegisterSource which needs to be run on creation +type RegisterableSource interface { + RegisterSource() error + UnregisterSource() error +} + +var registeredConfigs = map[Type]func() Config{} + +// RegisterTypeConfig register a config for a provided type +func RegisterTypeConfig(typ Type, exemplar Config) { + if reflect.TypeOf(exemplar).Kind() == reflect.Ptr { + // Pointer: + registeredConfigs[typ] = func() Config { + return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config) + } + return + } + + // Not a Pointer + registeredConfigs[typ] = func() Config { + return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config) + } +} + +// SourceSettable configurations can have their authSource set on them +type SourceSettable interface { + SetAuthSource(*Source) +} + +// Source represents an external way for authorizing users. +type Source struct { + ID int64 `xorm:"pk autoincr"` + Type Type + Name string `xorm:"UNIQUE"` + IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"` + IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"` + Cfg convert.Conversion `xorm:"TEXT"` + + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` +} + +// TableName xorm will read the table name from this method +func (Source) TableName() string { + return "login_source" +} + +func init() { + db.RegisterModel(new(Source)) +} + +// BeforeSet is invoked from XORM before setting the value of a field of this object. +func (source *Source) BeforeSet(colName string, val xorm.Cell) { + if colName == "type" { + typ := Type(db.Cell2Int64(val)) + constructor, ok := registeredConfigs[typ] + if !ok { + return + } + source.Cfg = constructor() + if settable, ok := source.Cfg.(SourceSettable); ok { + settable.SetAuthSource(source) + } + } +} + +// TypeName return name of this login source type. +func (source *Source) TypeName() string { + return Names[source.Type] +} + +// IsLDAP returns true of this source is of the LDAP type. +func (source *Source) IsLDAP() bool { + return source.Type == LDAP +} + +// IsDLDAP returns true of this source is of the DLDAP type. +func (source *Source) IsDLDAP() bool { + return source.Type == DLDAP +} + +// IsSMTP returns true of this source is of the SMTP type. +func (source *Source) IsSMTP() bool { + return source.Type == SMTP +} + +// IsPAM returns true of this source is of the PAM type. +func (source *Source) IsPAM() bool { + return source.Type == PAM +} + +// IsOAuth2 returns true of this source is of the OAuth2 type. +func (source *Source) IsOAuth2() bool { + return source.Type == OAuth2 +} + +// IsSSPI returns true of this source is of the SSPI type. +func (source *Source) IsSSPI() bool { + return source.Type == SSPI +} + +// HasTLS returns true of this source supports TLS. +func (source *Source) HasTLS() bool { + hasTLSer, ok := source.Cfg.(HasTLSer) + return ok && hasTLSer.HasTLS() +} + +// UseTLS returns true of this source is configured to use TLS. +func (source *Source) UseTLS() bool { + useTLSer, ok := source.Cfg.(UseTLSer) + return ok && useTLSer.UseTLS() +} + +// SkipVerify returns true if this source is configured to skip SSL +// verification. +func (source *Source) SkipVerify() bool { + skipVerifiable, ok := source.Cfg.(SkipVerifiable) + return ok && skipVerifiable.IsSkipVerify() +} + +// CreateSource inserts a AuthSource in the DB if not already +// existing with the given name. +func CreateSource(source *Source) error { + has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source)) + if err != nil { + return err + } else if has { + return ErrSourceAlreadyExist{source.Name} + } + // Synchronization is only available with LDAP for now + if !source.IsLDAP() { + source.IsSyncEnabled = false + } + + _, err = db.GetEngine(db.DefaultContext).Insert(source) + if err != nil { + return err + } + + if !source.IsActive { + return nil + } + + if settable, ok := source.Cfg.(SourceSettable); ok { + settable.SetAuthSource(source) + } + + registerableSource, ok := source.Cfg.(RegisterableSource) + if !ok { + return nil + } + + err = registerableSource.RegisterSource() + if err != nil { + // remove the AuthSource in case of errors while registering configuration + if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil { + log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err) + } + } + return err +} + +// Sources returns a slice of all login sources found in DB. +func Sources() ([]*Source, error) { + auths := make([]*Source, 0, 6) + return auths, db.GetEngine(db.DefaultContext).Find(&auths) +} + +// SourcesByType returns all sources of the specified type +func SourcesByType(loginType Type) ([]*Source, error) { + sources := make([]*Source, 0, 1) + if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil { + return nil, err + } + return sources, nil +} + +// AllActiveSources returns all active sources +func AllActiveSources() ([]*Source, error) { + sources := make([]*Source, 0, 5) + if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil { + return nil, err + } + return sources, nil +} + +// ActiveSources returns all active sources of the specified type +func ActiveSources(tp Type) ([]*Source, error) { + sources := make([]*Source, 0, 1) + if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil { + return nil, err + } + return sources, nil +} + +// IsSSPIEnabled returns true if there is at least one activated login +// source of type LoginSSPI +func IsSSPIEnabled() bool { + if !db.HasEngine { + return false + } + sources, err := ActiveSources(SSPI) + if err != nil { + log.Error("ActiveSources: %v", err) + return false + } + return len(sources) > 0 +} + +// GetSourceByID returns login source by given ID. +func GetSourceByID(id int64) (*Source, error) { + source := new(Source) + if id == 0 { + source.Cfg = registeredConfigs[NoType]() + // Set this source to active + // FIXME: allow disabling of db based password authentication in future + source.IsActive = true + return source, nil + } + + has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source) + if err != nil { + return nil, err + } else if !has { + return nil, ErrSourceNotExist{id} + } + return source, nil +} + +// UpdateSource updates a Source record in DB. +func UpdateSource(source *Source) error { + var originalSource *Source + if source.IsOAuth2() { + // keep track of the original values so we can restore in case of errors while registering OAuth2 providers + var err error + if originalSource, err = GetSourceByID(source.ID); err != nil { + return err + } + } + + _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source) + if err != nil { + return err + } + + if !source.IsActive { + return nil + } + + if settable, ok := source.Cfg.(SourceSettable); ok { + settable.SetAuthSource(source) + } + + registerableSource, ok := source.Cfg.(RegisterableSource) + if !ok { + return nil + } + + err = registerableSource.RegisterSource() + if err != nil { + // restore original values since we cannot update the provider it self + if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil { + log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err) + } + } + return err +} + +// CountSources returns number of login sources. +func CountSources() int64 { + count, _ := db.GetEngine(db.DefaultContext).Count(new(Source)) + return count +} + +// ErrSourceNotExist represents a "SourceNotExist" kind of error. +type ErrSourceNotExist struct { + ID int64 +} + +// IsErrSourceNotExist checks if an error is a ErrSourceNotExist. +func IsErrSourceNotExist(err error) bool { + _, ok := err.(ErrSourceNotExist) + return ok +} + +func (err ErrSourceNotExist) Error() string { + return fmt.Sprintf("login source does not exist [id: %d]", err.ID) +} + +// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error. +type ErrSourceAlreadyExist struct { + Name string +} + +// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist. +func IsErrSourceAlreadyExist(err error) bool { + _, ok := err.(ErrSourceAlreadyExist) + return ok +} + +func (err ErrSourceAlreadyExist) Error() string { + return fmt.Sprintf("login source already exists [name: %s]", err.Name) +} + +// ErrSourceInUse represents a "SourceInUse" kind of error. +type ErrSourceInUse struct { + ID int64 +} + +// IsErrSourceInUse checks if an error is a ErrSourceInUse. +func IsErrSourceInUse(err error) bool { + _, ok := err.(ErrSourceInUse) + return ok +} + +func (err ErrSourceInUse) Error() string { + return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID) +} |