diff options
Diffstat (limited to 'routers')
-rw-r--r-- | routers/init.go | 2 | ||||
-rw-r--r-- | routers/web/admin/auths.go | 84 | ||||
-rw-r--r-- | routers/web/auth/auth.go | 35 | ||||
-rw-r--r-- | routers/web/auth/linkaccount.go | 45 | ||||
-rw-r--r-- | routers/web/auth/oauth.go | 19 | ||||
-rw-r--r-- | routers/web/auth/openid.go | 5 | ||||
-rw-r--r-- | routers/web/auth/saml.go | 172 | ||||
-rw-r--r-- | routers/web/web.go | 5 |
8 files changed, 328 insertions, 39 deletions
diff --git a/routers/init.go b/routers/init.go index e0a7150ba3..9ae8c368a2 100644 --- a/routers/init.go +++ b/routers/init.go @@ -35,6 +35,7 @@ import ( actions_service "code.gitea.io/gitea/services/actions" "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/auth/source/oauth2" + "code.gitea.io/gitea/services/auth/source/saml" "code.gitea.io/gitea/services/automerge" "code.gitea.io/gitea/services/cron" feed_service "code.gitea.io/gitea/services/feed" @@ -138,6 +139,7 @@ func InitWebInstalled(ctx context.Context) { log.Info("ORM engine initialization successful!") mustInit(system.Init) mustInitCtx(ctx, oauth2.Init) + mustInitCtx(ctx, saml.Init) mustInit(release_service.Init) diff --git a/routers/web/admin/auths.go b/routers/web/admin/auths.go index 7fdd18dfae..187b569d39 100644 --- a/routers/web/admin/auths.go +++ b/routers/web/admin/auths.go @@ -1,9 +1,12 @@ // Copyright 2014 The Gogs Authors. All rights reserved. +// Copyright 2024 The Gitea Authors. All rights reserved. // SPDX-License-Identifier: MIT package admin import ( + "crypto/tls" + "crypto/x509" "errors" "fmt" "net/http" @@ -25,6 +28,7 @@ import ( "code.gitea.io/gitea/services/auth/source/ldap" "code.gitea.io/gitea/services/auth/source/oauth2" pam_service "code.gitea.io/gitea/services/auth/source/pam" + "code.gitea.io/gitea/services/auth/source/saml" "code.gitea.io/gitea/services/auth/source/smtp" "code.gitea.io/gitea/services/auth/source/sspi" "code.gitea.io/gitea/services/forms" @@ -71,6 +75,7 @@ var ( {auth.SMTP.String(), auth.SMTP}, {auth.OAuth2.String(), auth.OAuth2}, {auth.SSPI.String(), auth.SSPI}, + {auth.SAML.String(), auth.SAML}, } if pam.Supported { items = append(items, dropdownItem{auth.Names[auth.PAM], auth.PAM}) @@ -83,6 +88,16 @@ var ( {ldap.SecurityProtocolNames[ldap.SecurityProtocolLDAPS], ldap.SecurityProtocolLDAPS}, {ldap.SecurityProtocolNames[ldap.SecurityProtocolStartTLS], ldap.SecurityProtocolStartTLS}, } + + nameIDFormats = []dropdownItem{ + {saml.NameIDFormatNames[saml.SAML20Persistent], saml.SAML20Persistent}, // use this as default value + {saml.NameIDFormatNames[saml.SAML11Email], saml.SAML11Email}, + {saml.NameIDFormatNames[saml.SAML11Persistent], saml.SAML11Persistent}, + {saml.NameIDFormatNames[saml.SAML11Unspecified], saml.SAML11Unspecified}, + {saml.NameIDFormatNames[saml.SAML20Email], saml.SAML20Email}, + {saml.NameIDFormatNames[saml.SAML20Transient], saml.SAML20Transient}, + {saml.NameIDFormatNames[saml.SAML20Unspecified], saml.SAML20Unspecified}, + } ) // NewAuthSource render adding a new auth source page @@ -98,6 +113,8 @@ func NewAuthSource(ctx *context.Context) { ctx.Data["is_sync_enabled"] = true ctx.Data["AuthSources"] = authSources ctx.Data["SecurityProtocols"] = securityProtocols + ctx.Data["CurrentNameIDFormat"] = saml.NameIDFormatNames[saml.SAML20Persistent] + ctx.Data["NameIDFormats"] = nameIDFormats ctx.Data["SMTPAuths"] = smtp.Authenticators oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers @@ -231,6 +248,52 @@ func parseSSPIConfig(ctx *context.Context, form forms.AuthenticationForm) (*sspi }, nil } +func parseSAMLConfig(ctx *context.Context, form forms.AuthenticationForm) (*saml.Source, error) { + if util.IsEmptyString(form.IdentityProviderMetadata) && util.IsEmptyString(form.IdentityProviderMetadataURL) { + return nil, fmt.Errorf("%s %s", ctx.Tr("form.SAMLMetadata"), ctx.Tr("form.require_error")) + } + + if !util.IsEmptyString(form.IdentityProviderMetadataURL) { + _, err := url.Parse(form.IdentityProviderMetadataURL) + if err != nil { + return nil, fmt.Errorf("%s", ctx.Tr("form.SAMLMetadataURL")) + } + } + + // check the integrity of the certificate and private key (autogenerated if these form fields are blank) + if !util.IsEmptyString(form.ServiceProviderCertificate) && !util.IsEmptyString(form.ServiceProviderPrivateKey) { + keyPair, err := tls.X509KeyPair([]byte(form.ServiceProviderCertificate), []byte(form.ServiceProviderPrivateKey)) + if err != nil { + return nil, err + } + keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) + if err != nil { + return nil, err + } + } else { + privateKey, cert, err := saml.GenerateSAMLSPKeypair() + if err != nil { + return nil, err + } + + form.ServiceProviderPrivateKey = privateKey + form.ServiceProviderCertificate = cert + } + + return &saml.Source{ + IdentityProviderMetadata: form.IdentityProviderMetadata, + IdentityProviderMetadataURL: form.IdentityProviderMetadataURL, + InsecureSkipAssertionSignatureValidation: form.InsecureSkipAssertionSignatureValidation, + NameIDFormat: saml.NameIDFormat(form.NameIDFormat), + ServiceProviderCertificate: form.ServiceProviderCertificate, + ServiceProviderPrivateKey: form.ServiceProviderPrivateKey, + EmailAssertionKey: form.EmailAssertionKey, + NameAssertionKey: form.NameAssertionKey, + UsernameAssertionKey: form.UsernameAssertionKey, + IconURL: form.SAMLIconURL, + }, nil +} + // NewAuthSourcePost response for adding an auth source func NewAuthSourcePost(ctx *context.Context) { form := *web.GetForm(ctx).(*forms.AuthenticationForm) @@ -244,6 +307,8 @@ func NewAuthSourcePost(ctx *context.Context) { ctx.Data["SMTPAuths"] = smtp.Authenticators oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers + ctx.Data["CurrentNameIDFormat"] = saml.NameIDFormatNames[saml.NameIDFormat(form.NameIDFormat)] + ctx.Data["NameIDFormats"] = nameIDFormats ctx.Data["SSPIAutoCreateUsers"] = true ctx.Data["SSPIAutoActivateUsers"] = true @@ -290,6 +355,13 @@ func NewAuthSourcePost(ctx *context.Context) { ctx.RenderWithErr(ctx.Tr("admin.auths.login_source_of_type_exist"), tplAuthNew, form) return } + case auth.SAML: + var err error + config, err = parseSAMLConfig(ctx, form) + if err != nil { + ctx.RenderWithErr(err.Error(), tplAuthNew, form) + return + } default: ctx.Error(http.StatusBadRequest) return @@ -336,6 +408,7 @@ func EditAuthSource(ctx *context.Context) { ctx.Data["SMTPAuths"] = smtp.Authenticators oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers + ctx.Data["NameIDFormats"] = nameIDFormats source, err := auth.GetSourceByID(ctx, ctx.ParamsInt64(":authid")) if err != nil { @@ -344,6 +417,9 @@ func EditAuthSource(ctx *context.Context) { } ctx.Data["Source"] = source ctx.Data["HasTLS"] = source.HasTLS() + if source.IsSAML() { + ctx.Data["CurrentNameIDFormat"] = saml.NameIDFormatNames[source.Cfg.(*saml.Source).NameIDFormat] + } if source.IsOAuth2() { type Named interface { @@ -378,6 +454,8 @@ func EditAuthSourcePost(ctx *context.Context) { } ctx.Data["Source"] = source ctx.Data["HasTLS"] = source.HasTLS() + ctx.Data["CurrentNameIDFormat"] = saml.NameIDFormatNames[saml.SAML20Persistent] + ctx.Data["NameIDFormats"] = nameIDFormats if ctx.HasError() { ctx.HTML(http.StatusOK, tplAuthEdit) @@ -412,6 +490,12 @@ func EditAuthSourcePost(ctx *context.Context) { ctx.RenderWithErr(err.Error(), tplAuthEdit, form) return } + case auth.SAML: + config, err = parseSAMLConfig(ctx, form) + if err != nil { + ctx.RenderWithErr(err.Error(), tplAuthEdit, form) + return + } default: ctx.Error(http.StatusBadRequest) return diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index 3de1f3373d..f5955ec5ff 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -28,6 +28,7 @@ import ( "code.gitea.io/gitea/routers/utils" auth_service "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/auth/source/oauth2" + "code.gitea.io/gitea/services/auth/source/saml" "code.gitea.io/gitea/services/externalaccount" "code.gitea.io/gitea/services/forms" "code.gitea.io/gitea/services/mailer" @@ -170,6 +171,14 @@ func SignIn(ctx *context.Context) { return } ctx.Data["OAuth2Providers"] = oauth2Providers + + samlProviders, err := saml.GetSAMLProviders(ctx, util.OptionalBoolTrue) + if err != nil { + ctx.ServerError("UserSignIn", err) + return + } + ctx.Data["SAMLProviders"] = samlProviders + ctx.Data["Title"] = ctx.Tr("sign_in") ctx.Data["SignInLink"] = setting.AppSubURL + "/user/login" ctx.Data["PageIsSignIn"] = true @@ -193,6 +202,14 @@ func SignInPost(ctx *context.Context) { return } ctx.Data["OAuth2Providers"] = oauth2Providers + + samlProviders, err := saml.GetSAMLProviders(ctx, util.OptionalBoolTrue) + if err != nil { + ctx.ServerError("UserSignIn", err) + return + } + ctx.Data["SAMLProviders"] = samlProviders + ctx.Data["Title"] = ctx.Tr("sign_in") ctx.Data["SignInLink"] = setting.AppSubURL + "/user/login" ctx.Data["PageIsSignIn"] = true @@ -504,7 +521,7 @@ func SignUpPost(ctx *context.Context) { Passwd: form.Password, } - if !createAndHandleCreatedUser(ctx, tplSignUp, form, u, nil, nil, false) { + if !createAndHandleCreatedUser(ctx, tplSignUp, form, u, nil, nil, false, auth.NoType) { // error already handled return } @@ -515,16 +532,16 @@ func SignUpPost(ctx *context.Context) { // createAndHandleCreatedUser calls createUserInContext and // then handleUserCreated. -func createAndHandleCreatedUser(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool) bool { - if !createUserInContext(ctx, tpl, form, u, overwrites, gothUser, allowLink) { +func createAndHandleCreatedUser(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool, authType auth.Type) bool { + if !createUserInContext(ctx, tpl, form, u, overwrites, gothUser, allowLink, authType) { return false } - return handleUserCreated(ctx, u, gothUser) + return handleUserCreated(ctx, u, gothUser, authType) } // createUserInContext creates a user and handles errors within a given context. // Optionally a template can be specified. -func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool) (ok bool) { +func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool, authType auth.Type) (ok bool) { if err := user_model.CreateUser(ctx, u, overwrites); err != nil { if allowLink && (user_model.IsErrUserAlreadyExist(err) || user_model.IsErrEmailAlreadyUsed(err)) { if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingAuto { @@ -541,10 +558,10 @@ func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *us } // TODO: probably we should respect 'remember' user's choice... - linkAccount(ctx, user, *gothUser, true) + linkAccount(ctx, user, *gothUser, true, authType) return false // user is already created here, all redirects are handled } else if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingLogin { - showLinkingLogin(ctx, *gothUser) + showLinkingLogin(ctx, *gothUser, authType) return false // user will be created only after linking login } } @@ -590,7 +607,7 @@ func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *us // handleUserCreated does additional steps after a new user is created. // It auto-sets admin for the only user, updates the optional external user and // sends a confirmation email if required. -func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth.User) (ok bool) { +func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth.User, authType auth.Type) (ok bool) { // Auto-set admin for the only user. if user_model.CountUsers(ctx, nil) == 1 { opts := &user_service.UpdateOptions{ @@ -606,7 +623,7 @@ func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth. // update external user information if gothUser != nil { - if err := externalaccount.UpdateExternalUser(ctx, u, *gothUser); err != nil { + if err := externalaccount.UpdateExternalUser(ctx, u, *gothUser, authType); err != nil { if !errors.Is(err, util.ErrNotExist) { log.Error("UpdateExternalUser failed: %v", err) } diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go index 1d94e52fe3..c62ae84083 100644 --- a/routers/web/auth/linkaccount.go +++ b/routers/web/auth/linkaccount.go @@ -48,13 +48,13 @@ func LinkAccount(ctx *context.Context) { ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin" ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup" - gothUser := ctx.Session.Get("linkAccountGothUser") - if gothUser == nil { + externalLinkUser := ctx.Session.Get("linkAccountUser") + if externalLinkUser == nil { ctx.ServerError("UserSignIn", errors.New("not in LinkAccount session")) return } - gu, _ := gothUser.(goth.User) + gu := externalLinkUser.(auth.LinkAccountUser).GothUser uname, err := getUserName(&gu) if err != nil { ctx.ServerError("UserSignIn", err) @@ -135,12 +135,14 @@ func LinkAccountPostSignIn(ctx *context.Context) { ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin" ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup" - gothUser := ctx.Session.Get("linkAccountGothUser") - if gothUser == nil { + externalLinkUserInterface := ctx.Session.Get("linkAccountUser") + if externalLinkUserInterface == nil { ctx.ServerError("UserSignIn", errors.New("not in LinkAccount session")) return } + externalLinkUser := externalLinkUserInterface.(auth.LinkAccountUser) + if ctx.HasError() { ctx.HTML(http.StatusOK, tplLinkAccount) return @@ -152,10 +154,10 @@ func LinkAccountPostSignIn(ctx *context.Context) { return } - linkAccount(ctx, u, gothUser.(goth.User), signInForm.Remember) + linkAccount(ctx, u, externalLinkUser.GothUser, signInForm.Remember, externalLinkUser.Type) } -func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, remember bool) { +func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, remember bool, authType auth.Type) { updateAvatarIfNeed(ctx, gothUser.AvatarURL, u) // If this user is enrolled in 2FA, we can't sign the user in just yet. @@ -168,7 +170,7 @@ func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, r return } - err = externalaccount.LinkAccountToUser(ctx, u, gothUser) + err = externalaccount.LinkAccountToUser(ctx, u, gothUser, authType) if err != nil { ctx.ServerError("UserLinkAccount", err) return @@ -222,14 +224,14 @@ func LinkAccountPostRegister(ctx *context.Context) { ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin" ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup" - gothUserInterface := ctx.Session.Get("linkAccountGothUser") - if gothUserInterface == nil { + externalLinkUser := ctx.Session.Get("linkAccountUser") + if externalLinkUser == nil { ctx.ServerError("UserSignUp", errors.New("not in LinkAccount session")) return } - gothUser, ok := gothUserInterface.(goth.User) + linkUser, ok := externalLinkUser.(auth.LinkAccountUser) if !ok { - ctx.ServerError("UserSignUp", fmt.Errorf("session linkAccountGothUser type is %t but not goth.User", gothUserInterface)) + ctx.ServerError("UserSignUp", fmt.Errorf("session linkAccountUser type is %t but not goth.User", externalLinkUser)) return } @@ -275,7 +277,7 @@ func LinkAccountPostRegister(ctx *context.Context) { } } - authSource, err := auth.GetActiveOAuth2SourceByName(ctx, gothUser.Provider) + authSource, err := auth.GetActiveAuthSourceByName(ctx, linkUser.GothUser.Provider, linkUser.Type) if err != nil { ctx.ServerError("CreateUser", err) return @@ -285,21 +287,24 @@ func LinkAccountPostRegister(ctx *context.Context) { Name: form.UserName, Email: form.Email, Passwd: form.Password, - LoginType: auth.OAuth2, + LoginType: authSource.Type, LoginSource: authSource.ID, - LoginName: gothUser.UserID, + LoginName: linkUser.GothUser.UserID, } - if !createAndHandleCreatedUser(ctx, tplLinkAccount, form, u, nil, &gothUser, false) { + if !createAndHandleCreatedUser(ctx, tplLinkAccount, form, u, nil, &linkUser.GothUser, false, linkUser.Type) { // error already handled return } - source := authSource.Cfg.(*oauth2.Source) - if err := syncGroupsToTeams(ctx, source, &gothUser, u); err != nil { - ctx.ServerError("SyncGroupsToTeams", err) - return + if linkUser.Type == auth.OAuth2 { + source := authSource.Cfg.(*oauth2.Source) + if err := syncGroupsToTeams(ctx, source, &linkUser.GothUser, u); err != nil { + ctx.ServerError("SyncGroupsToTeams", err) + return + } } + // TODO we will support some form of group mapping for SAML handleSignIn(ctx, u, false) } diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index ee0770ef37..d00644dd5f 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -841,7 +841,7 @@ func handleAuthorizeError(ctx *context.Context, authErr AuthorizeError, redirect func SignInOAuth(ctx *context.Context) { provider := ctx.Params(":provider") - authSource, err := auth.GetActiveOAuth2SourceByName(ctx, provider) + authSource, err := auth.GetActiveAuthSourceByName(ctx, provider, auth.OAuth2) if err != nil { ctx.ServerError("SignIn", err) return @@ -892,7 +892,7 @@ func SignInOAuthCallback(ctx *context.Context) { } // first look if the provider is still active - authSource, err := auth.GetActiveOAuth2SourceByName(ctx, provider) + authSource, err := auth.GetActiveAuthSourceByName(ctx, provider, auth.OAuth2) if err != nil { ctx.ServerError("SignIn", err) return @@ -935,7 +935,7 @@ func SignInOAuthCallback(ctx *context.Context) { if u == nil { if ctx.Doer != nil { // attach user to already logged in user - err = externalaccount.LinkAccountToUser(ctx, ctx.Doer, gothUser) + err = externalaccount.LinkAccountToUser(ctx, ctx.Doer, gothUser, auth.OAuth2) if err != nil { ctx.ServerError("UserLinkAccount", err) return @@ -988,7 +988,7 @@ func SignInOAuthCallback(ctx *context.Context) { u.IsAdmin = isAdmin.ValueOrDefault(false) u.IsRestricted = isRestricted.ValueOrDefault(false) - if !createAndHandleCreatedUser(ctx, base.TplName(""), nil, u, overwriteDefault, &gothUser, setting.OAuth2Client.AccountLinking != setting.OAuth2AccountLinkingDisabled) { + if !createAndHandleCreatedUser(ctx, base.TplName(""), nil, u, overwriteDefault, &gothUser, setting.OAuth2Client.AccountLinking != setting.OAuth2AccountLinkingDisabled, auth.OAuth2) { // error already handled return } @@ -999,7 +999,7 @@ func SignInOAuthCallback(ctx *context.Context) { } } else { // no existing user is found, request attach or new account - showLinkingLogin(ctx, gothUser) + showLinkingLogin(ctx, gothUser, auth.OAuth2) return } } @@ -1063,9 +1063,12 @@ func getUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, gothUser *g return isAdmin, isRestricted } -func showLinkingLogin(ctx *context.Context, gothUser goth.User) { +func showLinkingLogin(ctx *context.Context, gothUser goth.User, authType auth.Type) { if err := updateSession(ctx, nil, map[string]any{ - "linkAccountGothUser": gothUser, + "linkAccountUser": auth.LinkAccountUser{ + Type: authType, + GothUser: gothUser, + }, }); err != nil { ctx.ServerError("updateSession", err) return @@ -1144,7 +1147,7 @@ func handleOAuth2SignIn(ctx *context.Context, source *auth.Source, u *user_model } // update external user information - if err := externalaccount.UpdateExternalUser(ctx, u, gothUser); err != nil { + if err := externalaccount.UpdateExternalUser(ctx, u, gothUser, auth.OAuth2); err != nil { if !errors.Is(err, util.ErrNotExist) { log.Error("UpdateExternalUser failed: %v", err) } diff --git a/routers/web/auth/openid.go b/routers/web/auth/openid.go index 29ef772b1c..bf377b4496 100644 --- a/routers/web/auth/openid.go +++ b/routers/web/auth/openid.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" + auth_model "code.gitea.io/gitea/models/auth" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/auth/openid" "code.gitea.io/gitea/modules/base" @@ -363,7 +364,7 @@ func RegisterOpenIDPost(ctx *context.Context) { Email: form.Email, Passwd: password, } - if !createUserInContext(ctx, tplSignUpOID, form, u, nil, nil, false) { + if !createUserInContext(ctx, tplSignUpOID, form, u, nil, nil, false, auth_model.NoType) { // error already handled return } @@ -379,7 +380,7 @@ func RegisterOpenIDPost(ctx *context.Context) { return } - if !handleUserCreated(ctx, u, nil) { + if !handleUserCreated(ctx, u, nil, auth_model.NoType) { // error already handled return } diff --git a/routers/web/auth/saml.go b/routers/web/auth/saml.go new file mode 100644 index 0000000000..29d689d2e9 --- /dev/null +++ b/routers/web/auth/saml.go @@ -0,0 +1,172 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "errors" + "fmt" + "net/http" + "strings" + + "code.gitea.io/gitea/models/auth" + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/context" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" + "code.gitea.io/gitea/modules/web/middleware" + "code.gitea.io/gitea/services/auth/source/saml" + "code.gitea.io/gitea/services/externalaccount" + + "github.com/markbates/goth" +) + +func SignInSAML(ctx *context.Context) { + provider := ctx.Params(":provider") + + loginSource, err := auth.GetActiveAuthSourceByName(ctx, provider, auth.SAML) + if err != nil || loginSource == nil { + ctx.NotFound("SAMLMetadata", err) + return + } + + if err = loginSource.Cfg.(*saml.Source).Callout(ctx.Req, ctx.Resp); err != nil { + if strings.Contains(err.Error(), "no provider for ") { + ctx.Error(http.StatusNotFound) + return + } + ctx.ServerError("SignIn", err) + } +} + +func SignInSAMLCallback(ctx *context.Context) { + provider := ctx.Params(":provider") + loginSource, err := auth.GetActiveAuthSourceByName(ctx, provider, auth.SAML) + if err != nil || loginSource == nil { + ctx.NotFound("SignInSAMLCallback", err) + return + } + + if loginSource == nil { + ctx.ServerError("SignIn", fmt.Errorf("no valid provider found, check configured callback url in provider")) + return + } + + u, gothUser, err := samlUserLoginCallback(*ctx, loginSource, ctx.Req, ctx.Resp) + if err != nil { + ctx.ServerError("SignInSAMLCallback", err) + return + } + + if u == nil { + if ctx.Doer != nil { + // attach user to already logged in user + err = externalaccount.LinkAccountToUser(ctx, ctx.Doer, gothUser, auth.SAML) + if err != nil { + ctx.ServerError("LinkAccountToUser", err) + return + } + + ctx.Redirect(setting.AppSubURL + "/user/settings/security") + return + } else if !setting.Service.AllowOnlyInternalRegistration && false { + // TODO: allow auto registration from saml users (OAuth2 uses the following setting.OAuth2Client.EnableAutoRegistration) + } else { + // no existing user is found, request attach or new account + showLinkingLogin(ctx, gothUser, auth.SAML) + return + } + } + + handleSamlSignIn(ctx, loginSource, u, gothUser) +} + +func handleSamlSignIn(ctx *context.Context, source *auth.Source, u *user_model.User, gothUser goth.User) { + if err := updateSession(ctx, nil, map[string]any{ + "uid": u.ID, + "uname": u.Name, + }); err != nil { + ctx.ServerError("updateSession", err) + return + } + + // Clear whatever CSRF cookie has right now, force to generate a new one + ctx.Csrf.DeleteCookie(ctx) + + // Register last login + u.SetLastLogin() + + // update external user information + if err := externalaccount.UpdateExternalUser(ctx, u, gothUser, auth.SAML); err != nil { + if !errors.Is(err, util.ErrNotExist) { + log.Error("UpdateExternalUser failed: %v", err) + } + } + + if err := resetLocale(ctx, u); err != nil { + ctx.ServerError("resetLocale", err) + return + } + + if redirectTo := ctx.GetSiteCookie("redirect_to"); len(redirectTo) > 0 { + middleware.DeleteRedirectToCookie(ctx.Resp) + ctx.RedirectToFirst(redirectTo) + return + } + + ctx.Redirect(setting.AppSubURL + "/") +} + +func samlUserLoginCallback(ctx context.Context, authSource *auth.Source, request *http.Request, response http.ResponseWriter) (*user_model.User, goth.User, error) { + samlSource := authSource.Cfg.(*saml.Source) + + gothUser, err := samlSource.Callback(request, response) + if err != nil { + return nil, gothUser, err + } + + user := &user_model.User{ + LoginName: gothUser.UserID, + LoginType: auth.SAML, + LoginSource: authSource.ID, + } + + hasUser, err := user_model.GetUser(ctx, user) + if err != nil { + return nil, goth.User{}, err + } + + if hasUser { + return user, gothUser, nil + } + + // search in external linked users + externalLoginUser := &user_model.ExternalLoginUser{ + ExternalID: gothUser.UserID, + LoginSourceID: authSource.ID, + } + hasUser, err = user_model.GetExternalLogin(ctx, externalLoginUser) + if err != nil { + return nil, goth.User{}, err + } + if hasUser { + user, err = user_model.GetUserByID(request.Context(), externalLoginUser.UserID) + return user, gothUser, err + } + + // no user found to login + return nil, gothUser, nil +} + +func SAMLMetadata(ctx *context.Context) { + provider := ctx.Params(":provider") + loginSource, err := auth.GetActiveAuthSourceByName(ctx, provider, auth.SAML) + if err != nil || loginSource == nil { + ctx.NotFound("SAMLMetadata", err) + return + } + if err = loginSource.Cfg.(*saml.Source).Metadata(ctx.Req, ctx.Resp); err != nil { + ctx.ServerError("SAMLMetadata", err) + } +} diff --git a/routers/web/web.go b/routers/web/web.go index 864164972e..77c8319f06 100644 --- a/routers/web/web.go +++ b/routers/web/web.go @@ -667,6 +667,11 @@ func registerRoutes(m *web.Route) { m.Get("/{provider}", auth.SignInOAuth) m.Get("/{provider}/callback", auth.SignInOAuthCallback) }) + m.Group("/saml", func() { + m.Get("/{provider}", auth.SignInSAML) // redir to SAML IDP + m.Post("/{provider}/acs", auth.SignInSAMLCallback) + m.Get("/{provider}/metadata", auth.SAMLMetadata) + }) }) // ***** END: User ***** |