* initial stuff for oauth2 login, fails on: * login button on the signIn page to start the OAuth2 flow and a callback for each provider Only GitHub is implemented for now * show login button only when the OAuth2 consumer is configured (and activated) * create macaron group for oauth2 urls * prevent net/http in modules (other then oauth2) * use a new data sessions oauth2 folder for storing the oauth2 session data * add missing 2FA when this is enabled on the user * add password option for OAuth2 user , for use with git over http and login to the GUI * add tip for registering a GitHub OAuth application * at startup of Gitea register all configured providers and also on adding/deleting of new providers * custom handling of errors in oauth2 request init + show better tip * add ExternalLoginUser model and migration script to add it to database * link a external account to an existing account (still need to handle wrong login and signup) and remove if user is removed * remove the linked external account from the user his settings * if user is unknown we allow him to register a new account or link it to some existing account * sign up with button on signin page (als change OAuth2Provider structure so we can store basic stuff about providers) * from gorilla/sessions docs: "Important Note: If you aren't using gorilla/mux, you need to wrap your handlers with context.ClearHandler as or else you will leak memory!" (we're using gorilla/sessions for storing oauth2 sessions) * use updated goth lib that now supports getting the OAuth2 user if the AccessToken is still valid instead of re-authenticating (prevent flooding the OAuth2 provider)tags/v1.1.0
@@ -41,6 +41,7 @@ import ( | |||
"github.com/go-macaron/toolbox" | |||
"github.com/urfave/cli" | |||
macaron "gopkg.in/macaron.v1" | |||
context2 "github.com/gorilla/context" | |||
) | |||
// CmdWeb represents the available web sub-command. | |||
@@ -210,6 +211,13 @@ func runWeb(ctx *cli.Context) error { | |||
m.Post("/sign_up", bindIgnErr(auth.RegisterForm{}), user.SignUpPost) | |||
m.Get("/reset_password", user.ResetPasswd) | |||
m.Post("/reset_password", user.ResetPasswdPost) | |||
m.Group("/oauth2", func() { | |||
m.Get("/:provider", user.SignInOAuth) | |||
m.Get("/:provider/callback", user.SignInOAuthCallback) | |||
}) | |||
m.Get("/link_account", user.LinkAccount) | |||
m.Post("/link_account_signin", bindIgnErr(auth.SignInForm{}), user.LinkAccountPostSignIn) | |||
m.Post("/link_account_signup", bindIgnErr(auth.RegisterForm{}), user.LinkAccountPostRegister) | |||
m.Group("/two_factor", func() { | |||
m.Get("", user.TwoFactor) | |||
m.Post("", bindIgnErr(auth.TwoFactorAuthForm{}), user.TwoFactorPost) | |||
@@ -236,6 +244,7 @@ func runWeb(ctx *cli.Context) error { | |||
Post(bindIgnErr(auth.NewAccessTokenForm{}), user.SettingsApplicationsPost) | |||
m.Post("/applications/delete", user.SettingsDeleteApplication) | |||
m.Route("/delete", "GET,POST", user.SettingsDelete) | |||
m.Combo("/account_link").Get(user.SettingsAccountLinks).Post(user.SettingsDeleteAccountLink) | |||
m.Group("/two_factor", func() { | |||
m.Get("", user.SettingsTwoFactor) | |||
m.Post("/regenerate_scratch", user.SettingsTwoFactorRegenerateScratch) | |||
@@ -671,11 +680,11 @@ func runWeb(ctx *cli.Context) error { | |||
var err error | |||
switch setting.Protocol { | |||
case setting.HTTP: | |||
err = runHTTP(listenAddr, m) | |||
err = runHTTP(listenAddr, context2.ClearHandler(m)) | |||
case setting.HTTPS: | |||
err = runHTTPS(listenAddr, setting.CertFile, setting.KeyFile, m) | |||
err = runHTTPS(listenAddr, setting.CertFile, setting.KeyFile, context2.ClearHandler(m)) | |||
case setting.FCGI: | |||
err = fcgi.Serve(nil, m) | |||
err = fcgi.Serve(nil, context2.ClearHandler(m)) | |||
case setting.UnixSocket: | |||
if err := os.Remove(listenAddr); err != nil && !os.IsNotExist(err) { | |||
log.Fatal(4, "Failed to remove unix socket directory %s: %v", listenAddr, err) | |||
@@ -691,7 +700,7 @@ func runWeb(ctx *cli.Context) error { | |||
if err = os.Chmod(listenAddr, os.FileMode(setting.UnixSocketPermission)); err != nil { | |||
log.Fatal(4, "Failed to set permission of unix socket: %v", err) | |||
} | |||
err = http.Serve(listener, m) | |||
err = http.Serve(listener, context2.ClearHandler(m)) | |||
default: | |||
log.Fatal(4, "Invalid protocol: %s", setting.Protocol) | |||
} |
@@ -847,3 +847,43 @@ func IsErrUploadNotExist(err error) bool { | |||
func (err ErrUploadNotExist) Error() string { | |||
return fmt.Sprintf("attachment does not exist [id: %d, uuid: %s]", err.ID, err.UUID) | |||
} | |||
// ___________ __ .__ .____ .__ ____ ___ | |||
// \_ _____/__ ____/ |_ ___________ ____ _____ | | | | ____ ____ |__| ____ | | \______ ___________ | |||
// | __)_\ \/ /\ __\/ __ \_ __ \/ \\__ \ | | | | / _ \ / ___\| |/ \ | | / ___// __ \_ __ \ | |||
// | \> < | | \ ___/| | \/ | \/ __ \| |__ | |__( <_> ) /_/ > | | \ | | /\___ \\ ___/| | \/ | |||
// /_______ /__/\_ \ |__| \___ >__| |___| (____ /____/ |_______ \____/\___ /|__|___| / |______//____ >\___ >__| | |||
// \/ \/ \/ \/ \/ \/ /_____/ \/ \/ \/ | |||
// ErrExternalLoginUserAlreadyExist represents a "ExternalLoginUserAlreadyExist" kind of error. | |||
type ErrExternalLoginUserAlreadyExist struct { | |||
ExternalID string | |||
UserID int64 | |||
LoginSourceID int64 | |||
} | |||
// IsErrExternalLoginUserAlreadyExist checks if an error is a ExternalLoginUserAlreadyExist. | |||
func IsErrExternalLoginUserAlreadyExist(err error) bool { | |||
_, ok := err.(ErrExternalLoginUserAlreadyExist) | |||
return ok | |||
} | |||
func (err ErrExternalLoginUserAlreadyExist) Error() string { | |||
return fmt.Sprintf("external login user already exists [externalID: %s, userID: %d, loginSourceID: %d]", err.ExternalID, err.UserID, err.LoginSourceID) | |||
} | |||
// ErrExternalLoginUserNotExist represents a "ExternalLoginUserNotExist" kind of error. | |||
type ErrExternalLoginUserNotExist struct { | |||
UserID int64 | |||
LoginSourceID int64 | |||
} | |||
// IsErrExternalLoginUserNotExist checks if an error is a ExternalLoginUserNotExist. | |||
func IsErrExternalLoginUserNotExist(err error) bool { | |||
_, ok := err.(ErrExternalLoginUserNotExist) | |||
return ok | |||
} | |||
func (err ErrExternalLoginUserNotExist) Error() string { | |||
return fmt.Sprintf("external login user link does not exists [userID: %d, loginSourceID: %d]", err.UserID, err.LoginSourceID) | |||
} |
@@ -0,0 +1,74 @@ | |||
// Copyright 2017 The Gitea Authors. All rights reserved. | |||
// Use of this source code is governed by a MIT-style | |||
// license that can be found in the LICENSE file. | |||
package models | |||
import "github.com/markbates/goth" | |||
// ExternalLoginUser makes the connecting between some existing user and additional external login sources | |||
type ExternalLoginUser struct { | |||
ExternalID string `xorm:"NOT NULL"` | |||
UserID int64 `xorm:"NOT NULL"` | |||
LoginSourceID int64 `xorm:"NOT NULL"` | |||
} | |||
// GetExternalLogin checks if a externalID in loginSourceID scope already exists | |||
func GetExternalLogin(externalLoginUser *ExternalLoginUser) (bool, error) { | |||
return x.Get(externalLoginUser) | |||
} | |||
// ListAccountLinks returns a map with the ExternalLoginUser and its LoginSource | |||
func ListAccountLinks(user *User) ([]*ExternalLoginUser, error) { | |||
externalAccounts := make([]*ExternalLoginUser, 0, 5) | |||
err := x.Where("user_id=?", user.ID). | |||
Desc("login_source_id"). | |||
Find(&externalAccounts) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return externalAccounts, nil | |||
} | |||
// LinkAccountToUser link the gothUser to the user | |||
func LinkAccountToUser(user *User, gothUser goth.User) error { | |||
loginSource, err := GetActiveOAuth2LoginSourceByName(gothUser.Provider) | |||
if err != nil { | |||
return err | |||
} | |||
externalLoginUser := &ExternalLoginUser{ | |||
ExternalID: gothUser.UserID, | |||
UserID: user.ID, | |||
LoginSourceID: loginSource.ID, | |||
} | |||
has, err := x.Get(externalLoginUser) | |||
if err != nil { | |||
return err | |||
} else if has { | |||
return ErrExternalLoginUserAlreadyExist{gothUser.UserID, user.ID, loginSource.ID} | |||
} | |||
_, err = x.Insert(externalLoginUser) | |||
return err | |||
} | |||
// RemoveAccountLink will remove all external login sources for the given user | |||
func RemoveAccountLink(user *User, loginSourceID int64) (int64, error) { | |||
deleted, err := x.Delete(&ExternalLoginUser{UserID: user.ID, LoginSourceID: loginSourceID}) | |||
if err != nil { | |||
return deleted, err | |||
} | |||
if deleted < 1 { | |||
return deleted, ErrExternalLoginUserNotExist{user.ID, loginSourceID} | |||
} | |||
return deleted, err | |||
} | |||
// RemoveAllAccountLinks will remove all external login sources for the given user | |||
func RemoveAllAccountLinks(user *User) error { | |||
_, err := x.Delete(&ExternalLoginUser{UserID: user.ID}) | |||
return err | |||
} |
@@ -22,6 +22,7 @@ import ( | |||
"code.gitea.io/gitea/modules/auth/ldap" | |||
"code.gitea.io/gitea/modules/auth/pam" | |||
"code.gitea.io/gitea/modules/log" | |||
"code.gitea.io/gitea/modules/auth/oauth2" | |||
) | |||
// LoginType represents an login type. | |||
@@ -30,19 +31,21 @@ type LoginType int | |||
// Note: new type must append to the end of list to maintain compatibility. | |||
const ( | |||
LoginNoType LoginType = iota | |||
LoginPlain // 1 | |||
LoginLDAP // 2 | |||
LoginSMTP // 3 | |||
LoginPAM // 4 | |||
LoginDLDAP // 5 | |||
LoginPlain // 1 | |||
LoginLDAP // 2 | |||
LoginSMTP // 3 | |||
LoginPAM // 4 | |||
LoginDLDAP // 5 | |||
LoginOAuth2 // 6 | |||
) | |||
// LoginNames contains the name of LoginType values. | |||
var LoginNames = map[LoginType]string{ | |||
LoginLDAP: "LDAP (via BindDN)", | |||
LoginDLDAP: "LDAP (simple auth)", // Via direct bind | |||
LoginSMTP: "SMTP", | |||
LoginPAM: "PAM", | |||
LoginLDAP: "LDAP (via BindDN)", | |||
LoginDLDAP: "LDAP (simple auth)", // Via direct bind | |||
LoginSMTP: "SMTP", | |||
LoginPAM: "PAM", | |||
LoginOAuth2: "OAuth2", | |||
} | |||
// SecurityProtocolNames contains the name of SecurityProtocol values. | |||
@@ -57,6 +60,7 @@ var ( | |||
_ core.Conversion = &LDAPConfig{} | |||
_ core.Conversion = &SMTPConfig{} | |||
_ core.Conversion = &PAMConfig{} | |||
_ core.Conversion = &OAuth2Config{} | |||
) | |||
// LDAPConfig holds configuration for LDAP login source. | |||
@@ -115,6 +119,23 @@ func (cfg *PAMConfig) ToDB() ([]byte, error) { | |||
return json.Marshal(cfg) | |||
} | |||
// OAuth2Config holds configuration for the OAuth2 login source. | |||
type OAuth2Config struct { | |||
Provider string | |||
ClientID string | |||
ClientSecret string | |||
} | |||
// FromDB fills up an OAuth2Config from serialized format. | |||
func (cfg *OAuth2Config) FromDB(bs []byte) error { | |||
return json.Unmarshal(bs, cfg) | |||
} | |||
// ToDB exports an SMTPConfig to a serialized format. | |||
func (cfg *OAuth2Config) ToDB() ([]byte, error) { | |||
return json.Marshal(cfg) | |||
} | |||
// LoginSource represents an external way for authorizing users. | |||
type LoginSource struct { | |||
ID int64 `xorm:"pk autoincr"` | |||
@@ -162,6 +183,8 @@ func (source *LoginSource) BeforeSet(colName string, val xorm.Cell) { | |||
source.Cfg = new(SMTPConfig) | |||
case LoginPAM: | |||
source.Cfg = new(PAMConfig) | |||
case LoginOAuth2: | |||
source.Cfg = new(OAuth2Config) | |||
default: | |||
panic("unrecognized login source type: " + com.ToStr(*val)) | |||
} | |||
@@ -203,6 +226,11 @@ func (source *LoginSource) IsPAM() bool { | |||
return source.Type == LoginPAM | |||
} | |||
// IsOAuth2 returns true of this source is of the OAuth2 type. | |||
func (source *LoginSource) IsOAuth2() bool { | |||
return source.Type == LoginOAuth2 | |||
} | |||
// HasTLS returns true of this source supports TLS. | |||
func (source *LoginSource) HasTLS() bool { | |||
return ((source.IsLDAP() || source.IsDLDAP()) && | |||
@@ -250,6 +278,11 @@ func (source *LoginSource) PAM() *PAMConfig { | |||
return source.Cfg.(*PAMConfig) | |||
} | |||
// OAuth2 returns OAuth2Config for this source, if of OAuth2 type. | |||
func (source *LoginSource) OAuth2() *OAuth2Config { | |||
return source.Cfg.(*OAuth2Config) | |||
} | |||
// CreateLoginSource inserts a LoginSource in the DB if not already | |||
// existing with the given name. | |||
func CreateLoginSource(source *LoginSource) error { | |||
@@ -261,12 +294,16 @@ func CreateLoginSource(source *LoginSource) error { | |||
} | |||
_, err = x.Insert(source) | |||
if err == nil && source.IsOAuth2() { | |||
oAuth2Config := source.OAuth2() | |||
oauth2.RegisterProvider(source.Name, oAuth2Config.Provider, oAuth2Config.ClientID, oAuth2Config.ClientSecret) | |||
} | |||
return err | |||
} | |||
// LoginSources returns a slice of all login sources found in DB. | |||
func LoginSources() ([]*LoginSource, error) { | |||
auths := make([]*LoginSource, 0, 5) | |||
auths := make([]*LoginSource, 0, 6) | |||
return auths, x.Find(&auths) | |||
} | |||
@@ -285,6 +322,11 @@ func GetLoginSourceByID(id int64) (*LoginSource, error) { | |||
// UpdateSource updates a LoginSource record in DB. | |||
func UpdateSource(source *LoginSource) error { | |||
_, err := x.Id(source.ID).AllCols().Update(source) | |||
if err == nil && source.IsOAuth2() { | |||
oAuth2Config := source.OAuth2() | |||
oauth2.RemoveProvider(source.Name) | |||
oauth2.RegisterProvider(source.Name, oAuth2Config.Provider, oAuth2Config.ClientID, oAuth2Config.ClientSecret) | |||
} | |||
return err | |||
} | |||
@@ -296,6 +338,18 @@ func DeleteSource(source *LoginSource) error { | |||
} else if count > 0 { | |||
return ErrLoginSourceInUse{source.ID} | |||
} | |||
count, err = x.Count(&ExternalLoginUser{LoginSourceID: source.ID}) | |||
if err != nil { | |||
return err | |||
} else if count > 0 { | |||
return ErrLoginSourceInUse{source.ID} | |||
} | |||
if source.IsOAuth2() { | |||
oauth2.RemoveProvider(source.Name) | |||
} | |||
_, err = x.Id(source.ID).Delete(new(LoginSource)) | |||
return err | |||
} | |||
@@ -444,7 +498,7 @@ func LoginViaSMTP(user *User, login, password string, sourceID int64, cfg *SMTPC | |||
idx := strings.Index(login, "@") | |||
if idx == -1 { | |||
return nil, ErrUserNotExist{0, login, 0} | |||
} else if !com.IsSliceContainsStr(strings.Split(cfg.AllowedDomains, ","), login[idx+1:]) { | |||
} else if !com.IsSliceContainsStr(strings.Split(cfg.AllowedDomains, ","), login[idx + 1:]) { | |||
return nil, ErrUserNotExist{0, login, 0} | |||
} | |||
} | |||
@@ -526,6 +580,27 @@ func LoginViaPAM(user *User, login, password string, sourceID int64, cfg *PAMCon | |||
return user, CreateUser(user) | |||
} | |||
// ________ _____ __ .__ ________ | |||
// \_____ \ / _ \ __ ___/ |_| |__ \_____ \ | |||
// / | \ / /_\ \| | \ __\ | \ / ____/ | |||
// / | \/ | \ | /| | | Y \/ \ | |||
// \_______ /\____|__ /____/ |__| |___| /\_______ \ | |||
// \/ \/ \/ \/ | |||
// OAuth2Provider describes the display values of a single OAuth2 provider | |||
type OAuth2Provider struct { | |||
Name string | |||
DisplayName string | |||
Image string | |||
} | |||
// OAuth2Providers contains the map of registered OAuth2 providers in Gitea (based on goth) | |||
// key is used to map the OAuth2Provider with the goth provider type (also in LoginSource.OAuth2Config.Provider) | |||
// value is used to store display data | |||
var OAuth2Providers = map[string]OAuth2Provider{ | |||
"github": {Name: "github", DisplayName:"GitHub", Image: "/img/github.png"}, | |||
} | |||
// ExternalUserLogin attempts a login using external source types. | |||
func ExternalUserLogin(user *User, login, password string, source *LoginSource, autoRegister bool) (*User, error) { | |||
if !source.IsActived { | |||
@@ -560,7 +635,7 @@ func UserSignIn(username, password string) (*User, error) { | |||
if hasUser { | |||
switch user.LoginType { | |||
case LoginNoType, LoginPlain: | |||
case LoginNoType, LoginPlain, LoginOAuth2: | |||
if user.ValidatePassword(password) { | |||
return user, nil | |||
} | |||
@@ -580,12 +655,16 @@ func UserSignIn(username, password string) (*User, error) { | |||
} | |||
} | |||
sources := make([]*LoginSource, 0, 3) | |||
sources := make([]*LoginSource, 0, 5) | |||
if err = x.UseBool().Find(&sources, &LoginSource{IsActived: true}); err != nil { | |||
return nil, err | |||
} | |||
for _, source := range sources { | |||
if source.IsOAuth2() { | |||
// don't try to authenticate against OAuth2 sources | |||
continue | |||
} | |||
authUser, err := ExternalUserLogin(nil, username, password, source, true) | |||
if err == nil { | |||
return authUser, nil | |||
@@ -596,3 +675,58 @@ func UserSignIn(username, password string) (*User, error) { | |||
return nil, ErrUserNotExist{user.ID, user.Name, 0} | |||
} | |||
// GetActiveOAuth2ProviderLoginSources returns all actived LoginOAuth2 sources | |||
func GetActiveOAuth2ProviderLoginSources() ([]*LoginSource, error) { | |||
sources := make([]*LoginSource, 0, 1) | |||
if err := x.UseBool().Find(&sources, &LoginSource{IsActived: true, Type: LoginOAuth2}); err != nil { | |||
return nil, err | |||
} | |||
return sources, nil | |||
} | |||
// GetActiveOAuth2LoginSourceByName returns a OAuth2 LoginSource based on the given name | |||
func GetActiveOAuth2LoginSourceByName(name string) (*LoginSource, error) { | |||
loginSource := &LoginSource{ | |||
Name: name, | |||
Type: LoginOAuth2, | |||
IsActived: true, | |||
} | |||
has, err := x.UseBool().Get(loginSource) | |||
if !has || err != nil { | |||
return nil, err | |||
} | |||
return loginSource, nil | |||
} | |||
// GetActiveOAuth2Providers returns the map of configured active OAuth2 providers | |||
// key is used as technical name (like in the callbackURL) | |||
// values to display | |||
func GetActiveOAuth2Providers() (map[string]OAuth2Provider, error) { | |||
// Maybe also seperate used and unused providers so we can force the registration of only 1 active provider for each type | |||
loginSources, err := GetActiveOAuth2ProviderLoginSources() | |||
if err != nil { | |||
return nil, err | |||
} | |||
providers := make(map[string]OAuth2Provider) | |||
for _, source := range loginSources { | |||
providers[source.Name] = OAuth2Providers[source.OAuth2().Provider] | |||
} | |||
return providers, nil | |||
} | |||
// InitOAuth2 initialize the OAuth2 lib and register all active OAuth2 providers in the library | |||
func InitOAuth2() { | |||
oauth2.Init() | |||
loginSources, _ := GetActiveOAuth2ProviderLoginSources() | |||
for _, source := range loginSources { | |||
oAuth2Config := source.OAuth2() | |||
oauth2.RegisterProvider(source.Name, oAuth2Config.Provider, oAuth2Config.ClientID, oAuth2Config.ClientSecret) | |||
} | |||
} |
@@ -84,6 +84,8 @@ var migrations = []Migration{ | |||
NewMigration("create repo unit table and add units for all repos", addUnitsToTables), | |||
// v17 -> v18 | |||
NewMigration("set protect branches updated with created", setProtectedBranchUpdatedWithCreated), | |||
// v18 -> v19 | |||
NewMigration("add external login user", addExternalLoginUser), | |||
} | |||
// Migrate database to current version |
@@ -0,0 +1,25 @@ | |||
// Copyright 2016 Gitea. All rights reserved. | |||
// Use of this source code is governed by a MIT-style | |||
// license that can be found in the LICENSE file. | |||
package migrations | |||
import ( | |||
"fmt" | |||
"github.com/go-xorm/xorm" | |||
) | |||
// ExternalLoginUser makes the connecting between some existing user and additional external login sources | |||
type ExternalLoginUser struct { | |||
ExternalID string `xorm:"NOT NULL"` | |||
UserID int64 `xorm:"NOT NULL"` | |||
LoginSourceID int64 `xorm:"NOT NULL"` | |||
} | |||
func addExternalLoginUser(x *xorm.Engine) error { | |||
if err := x.Sync2(new(ExternalLoginUser)); err != nil { | |||
return fmt.Errorf("Sync2: %v", err) | |||
} | |||
return nil | |||
} |
@@ -196,6 +196,11 @@ func (u *User) IsLocal() bool { | |||
return u.LoginType <= LoginPlain | |||
} | |||
// IsOAuth2 returns true if user login type is LoginOAuth2. | |||
func (u *User) IsOAuth2() bool { | |||
return u.LoginType == LoginOAuth2 | |||
} | |||
// HasForkedRepo checks if user has already forked a repository with given ID. | |||
func (u *User) HasForkedRepo(repoID int64) bool { | |||
_, has := HasForkedRepo(u.ID, repoID) | |||
@@ -397,6 +402,11 @@ func (u *User) ValidatePassword(passwd string) bool { | |||
return subtle.ConstantTimeCompare([]byte(u.Passwd), []byte(newUser.Passwd)) == 1 | |||
} | |||
// IsPasswordSet checks if the password is set or left empty | |||
func (u *User) IsPasswordSet() bool { | |||
return !u.ValidatePassword("") | |||
} | |||
// UploadAvatar saves custom avatar for user. | |||
// FIXME: split uploads to different subdirs in case we have massive users. | |||
func (u *User) UploadAvatar(data []byte) error { | |||
@@ -947,6 +957,12 @@ func deleteUser(e *xorm.Session, u *User) error { | |||
return fmt.Errorf("clear assignee: %v", err) | |||
} | |||
// ***** START: ExternalLoginUser ***** | |||
if err = RemoveAllAccountLinks(u); err != nil { | |||
return fmt.Errorf("ExternalLoginUser: %v", err) | |||
} | |||
// ***** END: ExternalLoginUser ***** | |||
if _, err = e.Id(u.ID).Delete(new(User)); err != nil { | |||
return fmt.Errorf("Delete: %v", err) | |||
} | |||
@@ -1190,6 +1206,11 @@ func GetUserByEmail(email string) (*User, error) { | |||
return nil, ErrUserNotExist{0, email, 0} | |||
} | |||
// GetUser checks if a user already exists | |||
func GetUser(user *User) (bool, error) { | |||
return x.Get(user) | |||
} | |||
// SearchUserOptions contains the options for searching | |||
type SearchUserOptions struct { | |||
Keyword string |
@@ -179,7 +179,7 @@ func AssignForm(form interface{}, data map[string]interface{}) { | |||
func getRuleBody(field reflect.StructField, prefix string) string { | |||
for _, rule := range strings.Split(field.Tag.Get("binding"), ";") { | |||
if strings.HasPrefix(rule, prefix) { | |||
return rule[len(prefix) : len(rule)-1] | |||
return rule[len(prefix): len(rule) - 1] | |||
} | |||
} | |||
return "" | |||
@@ -237,7 +237,7 @@ func validate(errs binding.Errors, data map[string]interface{}, f Form, l macaro | |||
} | |||
if errs[0].FieldNames[0] == field.Name { | |||
data["Err_"+field.Name] = true | |||
data["Err_" + field.Name] = true | |||
trName := field.Tag.Get("locale") | |||
if len(trName) == 0 { |
@@ -12,7 +12,7 @@ import ( | |||
// AuthenticationForm form for authentication | |||
type AuthenticationForm struct { | |||
ID int64 | |||
Type int `binding:"Range(2,5)"` | |||
Type int `binding:"Range(2,6)"` | |||
Name string `binding:"Required;MaxSize(30)"` | |||
Host string | |||
Port int | |||
@@ -36,6 +36,9 @@ type AuthenticationForm struct { | |||
TLS bool | |||
SkipVerify bool | |||
PAMServiceName string | |||
Oauth2Provider string | |||
Oauth2Key string | |||
Oauth2Secret string | |||
} | |||
// Validate validates fields |
@@ -0,0 +1,105 @@ | |||
// Copyright 2017 The Gitea Authors. All rights reserved. | |||
// Use of this source code is governed by a MIT-style | |||
// license that can be found in the LICENSE file. | |||
package oauth2 | |||
import ( | |||
"code.gitea.io/gitea/modules/setting" | |||
"code.gitea.io/gitea/modules/log" | |||
"github.com/gorilla/sessions" | |||
"github.com/markbates/goth" | |||
"github.com/markbates/goth/gothic" | |||
"net/http" | |||
"os" | |||
"github.com/satori/go.uuid" | |||
"path/filepath" | |||
"github.com/markbates/goth/providers/github" | |||
) | |||
var ( | |||
sessionUsersStoreKey = "gitea-oauth2-sessions" | |||
providerHeaderKey = "gitea-oauth2-provider" | |||
) | |||
// Init initialize the setup of the OAuth2 library | |||
func Init() { | |||
sessionDir := filepath.Join(setting.AppDataPath, "sessions", "oauth2") | |||
if err := os.MkdirAll(sessionDir, 0700); err != nil { | |||
log.Fatal(4, "Fail to create dir %s: %v", sessionDir, err) | |||
} | |||
gothic.Store = sessions.NewFilesystemStore(sessionDir, []byte(sessionUsersStoreKey)) | |||
gothic.SetState = func(req *http.Request) string { | |||
return uuid.NewV4().String() | |||
} | |||
gothic.GetProviderName = func(req *http.Request) (string, error) { | |||
return req.Header.Get(providerHeaderKey), nil | |||
} | |||
} | |||
// Auth OAuth2 auth service | |||
func Auth(provider string, request *http.Request, response http.ResponseWriter) error { | |||
// not sure if goth is thread safe (?) when using multiple providers | |||
request.Header.Set(providerHeaderKey, provider) | |||
// don't use the default gothic begin handler to prevent issues when some error occurs | |||
// normally the gothic library will write some custom stuff to the response instead of our own nice error page | |||
//gothic.BeginAuthHandler(response, request) | |||
url, err := gothic.GetAuthURL(response, request) | |||
if err == nil { | |||
http.Redirect(response, request, url, http.StatusTemporaryRedirect) | |||
} | |||
return err | |||
} | |||
// ProviderCallback handles OAuth callback, resolve to a goth user and send back to original url | |||
// this will trigger a new authentication request, but because we save it in the session we can use that | |||
func ProviderCallback(provider string, request *http.Request, response http.ResponseWriter) (goth.User, error) { | |||
// not sure if goth is thread safe (?) when using multiple providers | |||
request.Header.Set(providerHeaderKey, provider) | |||
user, err := gothic.CompleteUserAuth(response, request) | |||
if err != nil { | |||
return user, err | |||
} | |||
return user, nil | |||
} | |||
// RegisterProvider register a OAuth2 provider in goth lib | |||
func RegisterProvider(providerName, providerType, clientID, clientSecret string) { | |||
provider := createProvider(providerName, providerType, clientID, clientSecret) | |||
if provider != nil { | |||
goth.UseProviders(provider) | |||
} | |||
} | |||
// RemoveProvider removes the given OAuth2 provider from the goth lib | |||
func RemoveProvider(providerName string) { | |||
delete(goth.GetProviders(), providerName) | |||
} | |||
// used to create different types of goth providers | |||
func createProvider(providerName, providerType, clientID, clientSecret string) goth.Provider { | |||
callbackURL := setting.AppURL + "user/oauth2/" + providerName + "/callback" | |||
var provider goth.Provider | |||
switch providerType { | |||
case "github": | |||
provider = github.New(clientID, clientSecret, callbackURL, "user:email") | |||
} | |||
// always set the name if provider is created so we can support multiple setups of 1 provider | |||
if provider != nil { | |||
provider.SetName(providerName) | |||
} | |||
return provider | |||
} |
@@ -143,7 +143,7 @@ func (f *AddEmailForm) Validate(ctx *macaron.Context, errs binding.Errors) bindi | |||
// ChangePasswordForm form for changing password | |||
type ChangePasswordForm struct { | |||
OldPassword string `form:"old_password" binding:"Required;MinSize(1);MaxSize(255)"` | |||
OldPassword string `form:"old_password" binding:"MaxSize(255)"` | |||
Password string `form:"password" binding:"Required;MaxSize(255)"` | |||
Retype string `form:"retype"` | |||
} |
@@ -5,8 +5,11 @@ dashboard = Dashboard | |||
explore = Explore | |||
help = Help | |||
sign_in = Sign In | |||
sign_in_with = Sign in with | |||
sign_out = Sign Out | |||
sign_up = Sign Up | |||
link_account = Link Account | |||
link_account_signin_or_signup = Login with existing credentials to link your existing account to these new account, or sign up for a new account | |||
register = Register | |||
website = Website | |||
version = Version | |||
@@ -277,6 +280,7 @@ applications = Applications | |||
orgs = Organizations | |||
delete = Delete Account | |||
twofa = Two-Factor Authentication | |||
account_link = External Accounts | |||
uid = Uid | |||
public_profile = Public Profile | |||
@@ -379,6 +383,13 @@ then_enter_passcode = Then enter the passcode the application gives you: | |||
passcode_invalid = That passcode is invalid. Try again. | |||
twofa_enrolled = Your account has now been enrolled in two-factor authentication. Make sure to save your scratch token (%s), as it will only be shown once! | |||
manage_account_links = Manage account links | |||
manage_account_links_desc = External accounts linked to this account | |||
account_links_not_available = There are no external accounts linked to this account | |||
remove_account_link = Remove linked account | |||
remove_account_link_desc = Delete this account link will remove all related access for your account. Do you want to continue? | |||
remove_account_link_success = Account link has been removed successfully! | |||
delete_account = Delete Your Account | |||
delete_prompt = The operation will delete your account permanently, and <strong>CANNOT</strong> be undone! | |||
confirm_delete_account = Confirm Deletion | |||
@@ -1106,8 +1117,12 @@ auths.allowed_domains_helper = Leave it empty to not restrict any domains. Multi | |||
auths.enable_tls = Enable TLS Encryption | |||
auths.skip_tls_verify = Skip TLS Verify | |||
auths.pam_service_name = PAM Service Name | |||
auths.oauth2_provider = OAuth2 provider | |||
auths.oauth2_clientID = Client ID (Key) | |||
auths.oauth2_clientSecret = Client Secret | |||
auths.enable_auto_register = Enable Auto Registration | |||
auths.tips = Tips | |||
auths.tip.github = Register a new OAuth application on https://github.com/settings/applications/new and use <host>/user/oauth2/<Authentication Name>/callback as "Authorization callback URL" | |||
auths.edit = Edit Authentication Setting | |||
auths.activated = This authentication is activated | |||
auths.new_success = New authentication '%s' has been added successfully. |
@@ -2983,3 +2983,24 @@ footer .ui.language .menu { | |||
.ui.user.list .item .description a:hover { | |||
text-decoration: underline; | |||
} | |||
.user.link-account:not(.icon) { | |||
padding-top: 15px; | |||
padding-bottom: 5px; | |||
} | |||
.signin .oauth2 div { | |||
display: inline-block; | |||
} | |||
.signin .oauth2 div p { | |||
margin: 10px 5px 0 0; | |||
float: left; | |||
} | |||
.signin .oauth2 a { | |||
margin-right: 5px; | |||
} | |||
.signin .oauth2 a:last-child { | |||
margin-right: 0px; | |||
} | |||
.signin .oauth2 img { | |||
width: 32px; | |||
height: 32px; | |||
} |
@@ -1019,9 +1019,9 @@ function initAdmin() { | |||
// New authentication | |||
if ($('.admin.new.authentication').length > 0) { | |||
$('#auth_type').change(function () { | |||
$('.ldap, .dldap, .smtp, .pam, .has-tls').hide(); | |||
$('.ldap, .dldap, .smtp, .pam, .oauth2, .has-tls').hide(); | |||
$('.ldap input[required], .dldap input[required], .smtp input[required], .pam input[required], .has-tls input[required]').removeAttr('required'); | |||
$('.ldap input[required], .dldap input[required], .smtp input[required], .pam input[required], .oauth2 input[required] .has-tls input[required]').removeAttr('required'); | |||
var authType = $(this).val(); | |||
switch (authType) { | |||
@@ -1042,6 +1042,10 @@ function initAdmin() { | |||
$('.dldap').show(); | |||
$('.dldap div.required:not(.ldap) input').attr('required', 'required'); | |||
break; | |||
case '6': // OAuth2 | |||
$('.oauth2').show(); | |||
$('.oauth2 input').attr('required', 'required'); | |||
break; | |||
} | |||
if (authType == '2' || authType == '5') { |
@@ -53,6 +53,7 @@ var ( | |||
{models.LoginNames[models.LoginDLDAP], models.LoginDLDAP}, | |||
{models.LoginNames[models.LoginSMTP], models.LoginSMTP}, | |||
{models.LoginNames[models.LoginPAM], models.LoginPAM}, | |||
{models.LoginNames[models.LoginOAuth2], models.LoginOAuth2}, | |||
} | |||
securityProtocols = []dropdownItem{ | |||
{models.SecurityProtocolNames[ldap.SecurityProtocolUnencrypted], ldap.SecurityProtocolUnencrypted}, | |||
@@ -75,6 +76,14 @@ func NewAuthSource(ctx *context.Context) { | |||
ctx.Data["AuthSources"] = authSources | |||
ctx.Data["SecurityProtocols"] = securityProtocols | |||
ctx.Data["SMTPAuths"] = models.SMTPAuths | |||
ctx.Data["OAuth2Providers"] = models.OAuth2Providers | |||
// only the first as default | |||
for key := range models.OAuth2Providers { | |||
ctx.Data["oauth2_provider"] = key | |||
break | |||
} | |||
ctx.HTML(200, tplAuthNew) | |||
} | |||
@@ -113,6 +122,14 @@ func parseSMTPConfig(form auth.AuthenticationForm) *models.SMTPConfig { | |||
} | |||
} | |||
func parseOAuth2Config(form auth.AuthenticationForm) *models.OAuth2Config { | |||
return &models.OAuth2Config{ | |||
Provider: form.Oauth2Provider, | |||
ClientID: form.Oauth2Key, | |||
ClientSecret: form.Oauth2Secret, | |||
} | |||
} | |||
// NewAuthSourcePost response for adding an auth source | |||
func NewAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) { | |||
ctx.Data["Title"] = ctx.Tr("admin.auths.new") | |||
@@ -124,6 +141,7 @@ func NewAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) { | |||
ctx.Data["AuthSources"] = authSources | |||
ctx.Data["SecurityProtocols"] = securityProtocols | |||
ctx.Data["SMTPAuths"] = models.SMTPAuths | |||
ctx.Data["OAuth2Providers"] = models.OAuth2Providers | |||
hasTLS := false | |||
var config core.Conversion | |||
@@ -138,6 +156,8 @@ func NewAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) { | |||
config = &models.PAMConfig{ | |||
ServiceName: form.PAMServiceName, | |||
} | |||
case models.LoginOAuth2: | |||
config = parseOAuth2Config(form) | |||
default: | |||
ctx.Error(400) | |||
return | |||
@@ -178,6 +198,7 @@ func EditAuthSource(ctx *context.Context) { | |||
ctx.Data["SecurityProtocols"] = securityProtocols | |||
ctx.Data["SMTPAuths"] = models.SMTPAuths | |||
ctx.Data["OAuth2Providers"] = models.OAuth2Providers | |||
source, err := models.GetLoginSourceByID(ctx.ParamsInt64(":authid")) | |||
if err != nil { | |||
@@ -187,16 +208,20 @@ func EditAuthSource(ctx *context.Context) { | |||
ctx.Data["Source"] = source | |||
ctx.Data["HasTLS"] = source.HasTLS() | |||
if source.IsOAuth2() { | |||
ctx.Data["CurrentOAuth2Provider"] = models.OAuth2Providers[source.OAuth2().Provider] | |||
} | |||
ctx.HTML(200, tplAuthEdit) | |||
} | |||
// EditAuthSourcePost resposne for editing auth source | |||
// EditAuthSourcePost response for editing auth source | |||
func EditAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) { | |||
ctx.Data["Title"] = ctx.Tr("admin.auths.edit") | |||
ctx.Data["PageIsAdmin"] = true | |||
ctx.Data["PageIsAdminAuthentications"] = true | |||
ctx.Data["SMTPAuths"] = models.SMTPAuths | |||
ctx.Data["OAuth2Providers"] = models.OAuth2Providers | |||
source, err := models.GetLoginSourceByID(ctx.ParamsInt64(":authid")) | |||
if err != nil { | |||
@@ -221,6 +246,8 @@ func EditAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) { | |||
config = &models.PAMConfig{ | |||
ServiceName: form.PAMServiceName, | |||
} | |||
case models.LoginOAuth2: | |||
config = parseOAuth2Config(form) | |||
default: | |||
ctx.Error(400) | |||
return |
@@ -54,6 +54,7 @@ func GlobalInit() { | |||
log.Fatal(4, "Failed to initialize ORM engine: %v", err) | |||
} | |||
models.HasEngine = true | |||
models.InitOAuth2() | |||
models.LoadRepoConfig() | |||
models.NewRepoContext() |
@@ -59,7 +59,7 @@ func HTTP(ctx *context.Context) { | |||
isWiki := false | |||
if strings.HasSuffix(reponame, ".wiki") { | |||
isWiki = true | |||
reponame = reponame[:len(reponame)-5] | |||
reponame = reponame[:len(reponame) - 5] | |||
} | |||
repoUser, err := models.GetUserByName(username) | |||
@@ -191,9 +191,9 @@ func HTTP(ctx *context.Context) { | |||
var lastLine int64 | |||
for { | |||
head := input[lastLine : lastLine+2] | |||
head := input[lastLine: lastLine + 2] | |||
if head[0] == '0' && head[1] == '0' { | |||
size, err := strconv.ParseInt(string(input[lastLine+2:lastLine+4]), 16, 32) | |||
size, err := strconv.ParseInt(string(input[lastLine + 2:lastLine + 4]), 16, 32) | |||
if err != nil { | |||
log.Error(4, "%v", err) | |||
return | |||
@@ -204,7 +204,7 @@ func HTTP(ctx *context.Context) { | |||
break | |||
} | |||
line := input[lastLine : lastLine+size] | |||
line := input[lastLine: lastLine + size] | |||
idx := bytes.IndexRune(line, '\000') | |||
if idx > -1 { | |||
line = line[:idx] | |||
@@ -370,7 +370,7 @@ func gitCommand(dir string, args ...string) []byte { | |||
func getGitConfig(option, dir string) string { | |||
out := string(gitCommand(dir, "config", option)) | |||
return out[0 : len(out)-1] | |||
return out[0: len(out) - 1] | |||
} | |||
func getConfigSetting(service, dir string) bool { | |||
@@ -501,7 +501,7 @@ func updateServerInfo(dir string) []byte { | |||
} | |||
func packetWrite(str string) []byte { | |||
s := strconv.FormatInt(int64(len(str)+4), 16) | |||
s := strconv.FormatInt(int64(len(str) + 4), 16) | |||
if len(s)%4 != 0 { | |||
s = strings.Repeat("0", 4-len(s)%4) + s | |||
} |
@@ -17,6 +17,10 @@ import ( | |||
"code.gitea.io/gitea/modules/context" | |||
"code.gitea.io/gitea/modules/log" | |||
"code.gitea.io/gitea/modules/setting" | |||
"net/http" | |||
"code.gitea.io/gitea/modules/auth/oauth2" | |||
"github.com/markbates/goth" | |||
"strings" | |||
) | |||
const ( | |||
@@ -30,6 +34,7 @@ const ( | |||
tplResetPassword base.TplName = "user/auth/reset_passwd" | |||
tplTwofa base.TplName = "user/auth/twofa" | |||
tplTwofaScratch base.TplName = "user/auth/twofa_scratch" | |||
tplLinkAccount base.TplName = "user/auth/link_account" | |||
) | |||
// AutoSignIn reads cookie and try to auto-login. | |||
@@ -61,7 +66,7 @@ func AutoSignIn(ctx *context.Context) (bool, error) { | |||
} | |||
if val, _ := ctx.GetSuperSecureCookie( | |||
base.EncodeMD5(u.Rands+u.Passwd), setting.CookieRememberName); val != u.Name { | |||
base.EncodeMD5(u.Rands + u.Passwd), setting.CookieRememberName); val != u.Name { | |||
return false, nil | |||
} | |||
@@ -109,6 +114,13 @@ func SignIn(ctx *context.Context) { | |||
return | |||
} | |||
oauth2Providers, err := models.GetActiveOAuth2Providers() | |||
if err != nil { | |||
ctx.Handle(500, "UserSignIn", err) | |||
return | |||
} | |||
ctx.Data["OAuth2Providers"] = oauth2Providers | |||
ctx.HTML(200, tplSignIn) | |||
} | |||
@@ -116,6 +128,13 @@ func SignIn(ctx *context.Context) { | |||
func SignInPost(ctx *context.Context, form auth.SignInForm) { | |||
ctx.Data["Title"] = ctx.Tr("sign_in") | |||
oauth2Providers, err := models.GetActiveOAuth2Providers() | |||
if err != nil { | |||
ctx.Handle(500, "UserSignIn", err) | |||
return | |||
} | |||
ctx.Data["OAuth2Providers"] = oauth2Providers | |||
if ctx.HasError() { | |||
ctx.HTML(200, tplSignIn) | |||
return | |||
@@ -277,7 +296,7 @@ func handleSignInFull(ctx *context.Context, u *models.User, remember bool, obeyR | |||
if remember { | |||
days := 86400 * setting.LogInRememberDays | |||
ctx.SetCookie(setting.CookieUserName, u.Name, days, setting.AppSubURL) | |||
ctx.SetSuperSecureCookie(base.EncodeMD5(u.Rands+u.Passwd), | |||
ctx.SetSuperSecureCookie(base.EncodeMD5(u.Rands + u.Passwd), | |||
setting.CookieRememberName, u.Name, days, setting.AppSubURL) | |||
} | |||
@@ -309,6 +328,333 @@ func handleSignInFull(ctx *context.Context, u *models.User, remember bool, obeyR | |||
} | |||
} | |||
// SignInOAuth handles the OAuth2 login buttons | |||
func SignInOAuth(ctx *context.Context) { | |||
provider := ctx.Params(":provider") | |||
loginSource, err := models.GetActiveOAuth2LoginSourceByName(provider) | |||
if err != nil { | |||
ctx.Handle(500, "SignIn", err) | |||
return | |||
} | |||
// try to do a direct callback flow, so we don't authenticate the user again but use the valid accesstoken to get the user | |||
user, gothUser, err := oAuth2UserLoginCallback(loginSource, ctx.Req.Request, ctx.Resp) | |||
if err == nil && user != nil { | |||
// we got the user without going through the whole OAuth2 authentication flow again | |||
handleOAuth2SignIn(user, gothUser, ctx, err) | |||
return | |||
} | |||
err = oauth2.Auth(loginSource.Name, ctx.Req.Request, ctx.Resp) | |||
if err != nil { | |||
ctx.Handle(500, "SignIn", err) | |||
} | |||
// redirect is done in oauth2.Auth | |||
} | |||
// SignInOAuthCallback handles the callback from the given provider | |||
func SignInOAuthCallback(ctx *context.Context) { | |||
provider := ctx.Params(":provider") | |||
// first look if the provider is still active | |||
loginSource, err := models.GetActiveOAuth2LoginSourceByName(provider) | |||
if err != nil { | |||
ctx.Handle(500, "SignIn", err) | |||
return | |||
} | |||
if loginSource == nil { | |||
ctx.Handle(500, "SignIn", errors.New("No valid provider found, check configured callback url in provider")) | |||
return | |||
} | |||
u, gothUser, err := oAuth2UserLoginCallback(loginSource, ctx.Req.Request, ctx.Resp) | |||
handleOAuth2SignIn(u, gothUser, ctx, err) | |||
} | |||
func handleOAuth2SignIn(u *models.User, gothUser goth.User, ctx *context.Context, err error) { | |||
if err != nil { | |||
ctx.Handle(500, "UserSignIn", err) | |||
return | |||
} | |||
if u == nil { | |||
// no existing user is found, request attach or new account | |||
ctx.Session.Set("linkAccountGothUser", gothUser) | |||
ctx.Redirect(setting.AppSubURL + "/user/link_account") | |||
return | |||
} | |||
// If this user is enrolled in 2FA, we can't sign the user in just yet. | |||
// Instead, redirect them to the 2FA authentication page. | |||
_, err = models.GetTwoFactorByUID(u.ID) | |||
if err != nil { | |||
if models.IsErrTwoFactorNotEnrolled(err) { | |||
ctx.Session.Set("uid", u.ID) | |||
ctx.Session.Set("uname", u.Name) | |||
// Clear whatever CSRF has right now, force to generate a new one | |||
ctx.SetCookie(setting.CSRFCookieName, "", -1, setting.AppSubURL) | |||
// Register last login | |||
u.SetLastLogin() | |||
if err := models.UpdateUser(u); err != nil { | |||
ctx.Handle(500, "UpdateUser", err) | |||
return | |||
} | |||
if redirectTo, _ := url.QueryUnescape(ctx.GetCookie("redirect_to")); len(redirectTo) > 0 { | |||
ctx.SetCookie("redirect_to", "", -1, setting.AppSubURL) | |||
ctx.Redirect(redirectTo) | |||
return | |||
} | |||
ctx.Redirect(setting.AppSubURL + "/") | |||
} else { | |||
ctx.Handle(500, "UserSignIn", err) | |||
} | |||
return | |||
} | |||
// User needs to use 2FA, save data and redirect to 2FA page. | |||
ctx.Session.Set("twofaUid", u.ID) | |||
ctx.Session.Set("twofaRemember", false) | |||
ctx.Redirect(setting.AppSubURL + "/user/two_factor") | |||
} | |||
// OAuth2UserLoginCallback attempts to handle the callback from the OAuth2 provider and if successful | |||
// login the user | |||
func oAuth2UserLoginCallback(loginSource *models.LoginSource, request *http.Request, response http.ResponseWriter) (*models.User, goth.User, error) { | |||
gothUser, err := oauth2.ProviderCallback(loginSource.Name, request, response) | |||
if err != nil { | |||
return nil, goth.User{}, err | |||
} | |||
user := &models.User{ | |||
LoginName: gothUser.UserID, | |||
LoginType: models.LoginOAuth2, | |||
LoginSource: loginSource.ID, | |||
} | |||
hasUser, err := models.GetUser(user) | |||
if err != nil { | |||
return nil, goth.User{}, err | |||
} | |||
if hasUser { | |||
return user, goth.User{}, nil | |||
} | |||
// search in external linked users | |||
externalLoginUser := &models.ExternalLoginUser{ | |||
ExternalID: gothUser.UserID, | |||
LoginSourceID: loginSource.ID, | |||
} | |||
hasUser, err = models.GetExternalLogin(externalLoginUser) | |||
if err != nil { | |||
return nil, goth.User{}, err | |||
} | |||
if hasUser { | |||
user, err = models.GetUserByID(externalLoginUser.UserID) | |||
return user, goth.User{}, err | |||
} | |||
// no user found to login | |||
return nil, gothUser, nil | |||
} | |||
// LinkAccount shows the page where the user can decide to login or create a new account | |||
func LinkAccount(ctx *context.Context) { | |||
ctx.Data["Title"] = ctx.Tr("link_account") | |||
ctx.Data["LinkAccountMode"] = true | |||
ctx.Data["EnableCaptcha"] = setting.Service.EnableCaptcha | |||
ctx.Data["DisableRegistration"] = setting.Service.DisableRegistration | |||
ctx.Data["ShowRegistrationButton"] = false | |||
// use this to set the right link into the signIn and signUp templates in the link_account template | |||
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin" | |||
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup" | |||
gothUser := ctx.Session.Get("linkAccountGothUser") | |||
if gothUser == nil { | |||
ctx.Handle(500, "UserSignIn", errors.New("not in LinkAccount session")) | |||
return | |||
} | |||
ctx.Data["user_name"] = gothUser.(goth.User).NickName | |||
ctx.Data["email"] = gothUser.(goth.User).Email | |||
ctx.HTML(200, tplLinkAccount) | |||
} | |||
// LinkAccountPostSignIn handle the coupling of external account with another account using signIn | |||
func LinkAccountPostSignIn(ctx *context.Context, signInForm auth.SignInForm) { | |||
ctx.Data["Title"] = ctx.Tr("link_account") | |||
ctx.Data["LinkAccountMode"] = true | |||
ctx.Data["LinkAccountModeSignIn"] = true | |||
ctx.Data["EnableCaptcha"] = setting.Service.EnableCaptcha | |||
ctx.Data["DisableRegistration"] = setting.Service.DisableRegistration | |||
ctx.Data["ShowRegistrationButton"] = false | |||
// use this to set the right link into the signIn and signUp templates in the link_account template | |||
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin" | |||
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup" | |||
gothUser := ctx.Session.Get("linkAccountGothUser") | |||
if gothUser == nil { | |||
ctx.Handle(500, "UserSignIn", errors.New("not in LinkAccount session")) | |||
return | |||
} | |||
if ctx.HasError() { | |||
ctx.HTML(200, tplLinkAccount) | |||
return | |||
} | |||
u, err := models.UserSignIn(signInForm.UserName, signInForm.Password) | |||
if err != nil { | |||
if models.IsErrUserNotExist(err) { | |||
ctx.RenderWithErr(ctx.Tr("form.username_password_incorrect"), tplLinkAccount, &signInForm) | |||
} else { | |||
ctx.Handle(500, "UserLinkAccount", err) | |||
} | |||
return | |||
} | |||
// If this user is enrolled in 2FA, we can't sign the user in just yet. | |||
// Instead, redirect them to the 2FA authentication page. | |||
_, err = models.GetTwoFactorByUID(u.ID) | |||
if err != nil { | |||
if models.IsErrTwoFactorNotEnrolled(err) { | |||
models.LinkAccountToUser(u, gothUser.(goth.User)) | |||
handleSignIn(ctx, u, signInForm.Remember) | |||
} else { | |||
ctx.Handle(500, "UserLinkAccount", err) | |||
} | |||
return | |||
} | |||
// User needs to use 2FA, save data and redirect to 2FA page. | |||
ctx.Session.Set("twofaUid", u.ID) | |||
ctx.Session.Set("twofaRemember", signInForm.Remember) | |||
ctx.Session.Set("linkAccount", true) | |||
ctx.Redirect(setting.AppSubURL + "/user/two_factor") | |||
} | |||
// LinkAccountPostRegister handle the creation of a new account for an external account using signUp | |||
func LinkAccountPostRegister(ctx *context.Context, cpt *captcha.Captcha, form auth.RegisterForm) { | |||
ctx.Data["Title"] = ctx.Tr("link_account") | |||
ctx.Data["LinkAccountMode"] = true | |||
ctx.Data["LinkAccountModeRegister"] = true | |||
ctx.Data["EnableCaptcha"] = setting.Service.EnableCaptcha | |||
ctx.Data["DisableRegistration"] = setting.Service.DisableRegistration | |||
ctx.Data["ShowRegistrationButton"] = false | |||
// use this to set the right link into the signIn and signUp templates in the link_account template | |||
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin" | |||
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup" | |||
gothUser := ctx.Session.Get("linkAccountGothUser") | |||
if gothUser == nil { | |||
ctx.Handle(500, "UserSignUp", errors.New("not in LinkAccount session")) | |||
return | |||
} | |||
if ctx.HasError() { | |||
ctx.HTML(200, tplLinkAccount) | |||
return | |||
} | |||
if setting.Service.DisableRegistration { | |||
ctx.Error(403) | |||
return | |||
} | |||
if setting.Service.EnableCaptcha && !cpt.VerifyReq(ctx.Req) { | |||
ctx.Data["Err_Captcha"] = true | |||
ctx.RenderWithErr(ctx.Tr("form.captcha_incorrect"), tplLinkAccount, &form) | |||
return | |||
} | |||
if (len(strings.TrimSpace(form.Password)) > 0 || len(strings.TrimSpace(form.Retype)) > 0) && form.Password != form.Retype { | |||
ctx.Data["Err_Password"] = true | |||
ctx.RenderWithErr(ctx.Tr("form.password_not_match"), tplLinkAccount, &form) | |||
return | |||
} | |||
if len(strings.TrimSpace(form.Password)) > 0 && len(form.Password) < setting.MinPasswordLength { | |||
ctx.Data["Err_Password"] = true | |||
ctx.RenderWithErr(ctx.Tr("auth.password_too_short", setting.MinPasswordLength), tplLinkAccount, &form) | |||
return | |||
} | |||
loginSource, err := models.GetActiveOAuth2LoginSourceByName(gothUser.(goth.User).Provider) | |||
if err != nil { | |||
ctx.Handle(500, "CreateUser", err) | |||
} | |||
u := &models.User{ | |||
Name: form.UserName, | |||
Email: form.Email, | |||
Passwd: form.Password, | |||
IsActive: !setting.Service.RegisterEmailConfirm, | |||
LoginType: models.LoginOAuth2, | |||
LoginSource: loginSource.ID, | |||
LoginName: gothUser.(goth.User).UserID, | |||
} | |||
if err := models.CreateUser(u); err != nil { | |||
switch { | |||
case models.IsErrUserAlreadyExist(err): | |||
ctx.Data["Err_UserName"] = true | |||
ctx.RenderWithErr(ctx.Tr("form.username_been_taken"), tplLinkAccount, &form) | |||
case models.IsErrEmailAlreadyUsed(err): | |||
ctx.Data["Err_Email"] = true | |||
ctx.RenderWithErr(ctx.Tr("form.email_been_used"), tplLinkAccount, &form) | |||
case models.IsErrNameReserved(err): | |||
ctx.Data["Err_UserName"] = true | |||
ctx.RenderWithErr(ctx.Tr("user.form.name_reserved", err.(models.ErrNameReserved).Name), tplLinkAccount, &form) | |||
case models.IsErrNamePatternNotAllowed(err): | |||
ctx.Data["Err_UserName"] = true | |||
ctx.RenderWithErr(ctx.Tr("user.form.name_pattern_not_allowed", err.(models.ErrNamePatternNotAllowed).Pattern), tplLinkAccount, &form) | |||
default: | |||
ctx.Handle(500, "CreateUser", err) | |||
} | |||
return | |||
} | |||
log.Trace("Account created: %s", u.Name) | |||
// Auto-set admin for the only user. | |||
if models.CountUsers() == 1 { | |||
u.IsAdmin = true | |||
u.IsActive = true | |||
if err := models.UpdateUser(u); err != nil { | |||
ctx.Handle(500, "UpdateUser", err) | |||
return | |||
} | |||
} | |||
// Send confirmation email | |||
if setting.Service.RegisterEmailConfirm && u.ID > 1 { | |||
models.SendActivateAccountMail(ctx.Context, u) | |||
ctx.Data["IsSendRegisterMail"] = true | |||
ctx.Data["Email"] = u.Email | |||
ctx.Data["Hours"] = setting.Service.ActiveCodeLives / 60 | |||
ctx.HTML(200, TplActivate) | |||
if err := ctx.Cache.Put("MailResendLimit_"+u.LowerName, u.LowerName, 180); err != nil { | |||
log.Error(4, "Set cache(MailResendLimit) fail: %v", err) | |||
} | |||
return | |||
} | |||
ctx.Redirect(setting.AppSubURL + "/user/login") | |||
} | |||
// SignOut sign out from login status | |||
func SignOut(ctx *context.Context) { | |||
ctx.Session.Delete("uid") | |||
@@ -328,11 +674,7 @@ func SignUp(ctx *context.Context) { | |||
ctx.Data["EnableCaptcha"] = setting.Service.EnableCaptcha | |||
if setting.Service.DisableRegistration { | |||
ctx.Data["DisableRegistration"] = true | |||
ctx.HTML(200, tplSignUp) | |||
return | |||
} | |||
ctx.Data["DisableRegistration"] = setting.Service.DisableRegistration | |||
ctx.HTML(200, tplSignUp) | |||
} | |||
@@ -540,7 +882,7 @@ func ForgotPasswdPost(ctx *context.Context) { | |||
return | |||
} | |||
if !u.IsLocal() { | |||
if !u.IsLocal() && !u.IsOAuth2() { | |||
ctx.Data["Err_Email"] = true | |||
ctx.RenderWithErr(ctx.Tr("auth.non_local_account"), tplForgotPassword, nil) | |||
return |
@@ -37,6 +37,7 @@ const ( | |||
tplSettingsApplications base.TplName = "user/settings/applications" | |||
tplSettingsTwofa base.TplName = "user/settings/twofa" | |||
tplSettingsTwofaEnroll base.TplName = "user/settings/twofa_enroll" | |||
tplSettingsAccountLink base.TplName = "user/settings/account_link" | |||
tplSettingsDelete base.TplName = "user/settings/delete" | |||
tplSecurity base.TplName = "user/security" | |||
) | |||
@@ -200,7 +201,7 @@ func SettingsPasswordPost(ctx *context.Context, form auth.ChangePasswordForm) { | |||
return | |||
} | |||
if !ctx.User.ValidatePassword(form.OldPassword) { | |||
if ctx.User.IsPasswordSet() && !ctx.User.ValidatePassword(form.OldPassword) { | |||
ctx.Flash.Error(ctx.Tr("settings.password_incorrect")) | |||
} else if form.Password != form.Retype { | |||
ctx.Flash.Error(ctx.Tr("form.password_not_match")) | |||
@@ -631,6 +632,49 @@ func SettingsTwoFactorEnrollPost(ctx *context.Context, form auth.TwoFactorAuthFo | |||
ctx.Redirect(setting.AppSubURL + "/user/settings/two_factor") | |||
} | |||
// SettingsAccountLinks render the account links settings page | |||
func SettingsAccountLinks(ctx *context.Context) { | |||
ctx.Data["Title"] = ctx.Tr("settings") | |||
ctx.Data["PageIsSettingsAccountLink"] = true | |||
accountLinks, err := models.ListAccountLinks(ctx.User) | |||
if err != nil { | |||
ctx.Handle(500, "ListAccountLinks", err) | |||
return | |||
} | |||
// map the provider display name with the LoginSource | |||
sources := make(map[*models.LoginSource]string) | |||
for _, externalAccount := range accountLinks { | |||
if loginSource, err := models.GetLoginSourceByID(externalAccount.LoginSourceID); err == nil { | |||
var providerDisplayName string | |||
if loginSource.IsOAuth2() { | |||
providerTechnicalName := loginSource.OAuth2().Provider | |||
providerDisplayName = models.OAuth2Providers[providerTechnicalName].DisplayName | |||
} else { | |||
providerDisplayName = loginSource.Name | |||
} | |||
sources[loginSource] = providerDisplayName | |||
} | |||
} | |||
ctx.Data["AccountLinks"] = sources | |||
ctx.HTML(200, tplSettingsAccountLink) | |||
} | |||
// SettingsDeleteAccountLink delete a single account link | |||
func SettingsDeleteAccountLink(ctx *context.Context) { | |||
if _, err := models.RemoveAccountLink(ctx.User, ctx.QueryInt64("loginSourceID")); err != nil { | |||
ctx.Flash.Error("RemoveAccountLink: " + err.Error()) | |||
} else { | |||
ctx.Flash.Success(ctx.Tr("settings.remove_account_link_success")) | |||
} | |||
ctx.JSON(200, map[string]interface{}{ | |||
"redirect": setting.AppSubURL + "/user/settings/account_link", | |||
}) | |||
} | |||
// SettingsDelete render user suicide page and response for delete user himself | |||
func SettingsDelete(ctx *context.Context) { | |||
ctx.Data["Title"] = ctx.Tr("settings") |
@@ -142,6 +142,32 @@ | |||
</div> | |||
{{end}} | |||
<!-- OAuth2 --> | |||
{{if .Source.IsOAuth2}} | |||
{{ $cfg:=.Source.OAuth2 }} | |||
<div class="inline required field"> | |||
<label>{{.i18n.Tr "admin.auths.oauth2_provider"}}</label> | |||
<div class="ui selection type dropdown"> | |||
<input type="hidden" id="oauth2_provider" name="oauth2_provider" value="{{$cfg.Provider}}" required> | |||
<div class="text">{{.CurrentOAuth2Provider.DisplayName}}</div> | |||
<i class="dropdown icon"></i> | |||
<div class="menu"> | |||
{{range $key, $value := .OAuth2Providers}} | |||
<div class="item" data-value="{{$key}}">{{$value.DisplayName}}</div> | |||
{{end}} | |||
</div> | |||
</div> | |||
</div> | |||
<div class="required field"> | |||
<label for="oauth2_key">{{.i18n.Tr "admin.auths.oauth2_clientID"}}</label> | |||
<input id="oauth2_key" name="oauth2_key" value="{{$cfg.ClientID}}" required> | |||
</div> | |||
<div class="required field"> | |||
<label for="oauth2_secret">{{.i18n.Tr "admin.auths.oauth2_clientSecret"}}</label> | |||
<input id="oauth2_secret" name="oauth2_secret" value="{{$cfg.ClientSecret}}" required> | |||
</div> | |||
{{end}} | |||
<div class="inline field {{if not .Source.IsSMTP}}hide{{end}}"> | |||
<div class="ui checkbox"> | |||
<label><strong>{{.i18n.Tr "admin.auths.enable_tls"}}</strong></label> |
@@ -133,6 +133,31 @@ | |||
<input id="pam_service_name" name="pam_service_name" value="{{.pam_service_name}}" /> | |||
</div> | |||
<!-- OAuth2 --> | |||
<div class="oauth2 field {{if not (eq .type 6)}}hide{{end}}"> | |||
<div class="inline required field"> | |||
<label>{{.i18n.Tr "admin.auths.oauth2_provider"}}</label> | |||
<div class="ui selection type dropdown"> | |||
<input type="hidden" id="oauth2_provider" name="oauth2_provider" value="{{.oauth2_provider}}"> | |||
<div class="text">{{.oauth2_provider}}</div> | |||
<i class="dropdown icon"></i> | |||
<div class="menu"> | |||
{{range $key, $value := .OAuth2Providers}} | |||
<div class="item" data-value="{{$key}}">{{$value.DisplayName}}</div> | |||
{{end}} | |||
</div> | |||
</div> | |||
</div> | |||
<div class="required field"> | |||
<label for="oauth2_key">{{.i18n.Tr "admin.auths.oauth2_clientID"}}</label> | |||
<input id="oauth2_key" name="oauth2_key" value="{{.oauth2_key}}"> | |||
</div> | |||
<div class="required field"> | |||
<label for="oauth2_secret">{{.i18n.Tr "admin.auths.oauth2_clientSecret"}}</label> | |||
<input id="oauth2_secret" name="oauth2_secret" value="{{.oauth2_secret}}"> | |||
</div> | |||
</div> | |||
<div class="ldap field"> | |||
<div class="ui checkbox"> | |||
<label><strong>{{.i18n.Tr "admin.auths.attributes_in_bind"}}</strong></label> | |||
@@ -170,6 +195,8 @@ | |||
<div class="ui attached segment"> | |||
<h5>GMail Settings:</h5> | |||
<p>Host: smtp.gmail.com, Port: 587, Enable TLS Encryption: true</p> | |||
<h5>OAuth GitHub:</h5> | |||
<p>{{.i18n.Tr "admin.auths.tip.github"}}</p> | |||
</div> | |||
</div> | |||
</div> |
@@ -43,7 +43,7 @@ | |||
<input id="email" name="email" type="email" value="{{.User.Email}}" autofocus required> | |||
</div> | |||
<input class="fake" type="password"> | |||
<div class="local field {{if .Err_Password}}error{{end}} {{if not (eq .User.LoginSource 0)}}hide{{end}}"> | |||
<div class="local field {{if .Err_Password}}error{{end}} {{if not (or (.User.IsLocal) (.User.IsOAuth2))}}hide{{end}}"> | |||
<label for="password">{{.i18n.Tr "password"}}</label> | |||
<input id="password" name="password" type="password"> | |||
<p class="help">{{.i18n.Tr "admin.users.password_helper"}}</p> |
@@ -0,0 +1,13 @@ | |||
{{template "base/head" .}} | |||
<div class="user link-account"> | |||
<div class="ui middle very relaxed page grid"> | |||
<div class="column"> | |||
<p class="large center"> | |||
{{.i18n.Tr "link_account_signin_or_signup"}} | |||
</p> | |||
</div> | |||
</div> | |||
</div> | |||
{{template "user/auth/signin_inner" .}} | |||
{{template "user/auth/signup_inner" .}} | |||
{{template "base/footer" .}} |
@@ -1,44 +1,3 @@ | |||
{{template "base/head" .}} | |||
<div class="user signin"> | |||
<div class="ui middle very relaxed page grid"> | |||
<div class="column"> | |||
<form class="ui form" action="{{.Link}}" method="post"> | |||
{{.CsrfTokenHtml}} | |||
<h3 class="ui top attached header"> | |||
{{.i18n.Tr "sign_in"}} | |||
</h3> | |||
<div class="ui attached segment"> | |||
{{template "base/alert" .}} | |||
<div class="required inline field {{if .Err_UserName}}error{{end}}"> | |||
<label for="user_name">{{.i18n.Tr "home.uname_holder"}}</label> | |||
<input id="user_name" name="user_name" value="{{.user_name}}" autofocus required> | |||
</div> | |||
<div class="required inline field {{if .Err_Password}}error{{end}}"> | |||
<label for="password">{{.i18n.Tr "password"}}</label> | |||
<input id="password" name="password" type="password" value="{{.password}}" autocomplete="off" required> | |||
</div> | |||
<div class="inline field"> | |||
<label></label> | |||
<div class="ui checkbox"> | |||
<label>{{.i18n.Tr "auth.remember_me"}}</label> | |||
<input name="remember" type="checkbox"> | |||
</div> | |||
</div> | |||
<div class="inline field"> | |||
<label></label> | |||
<button class="ui green button">{{.i18n.Tr "sign_in"}}</button> | |||
<a href="{{AppSubUrl}}/user/forget_password">{{.i18n.Tr "auth.forget_password"}}</a> | |||
</div> | |||
{{if .ShowRegistrationButton}} | |||
<div class="inline field"> | |||
<label></label> | |||
<a href="{{AppSubUrl}}/user/sign_up">{{.i18n.Tr "auth.sign_up_now" | Str2html}}</a> | |||
</div> | |||
{{end}} | |||
</div> | |||
</form> | |||
</div> | |||
</div> | |||
</div> | |||
{{template "user/auth/signin_inner" .}} | |||
{{template "base/footer" .}} |
@@ -0,0 +1,57 @@ | |||
<div class="user signin{{if .LinkAccountMode}} icon{{end}}"> | |||
<div class="ui middle very relaxed page grid"> | |||
<div class="column"> | |||
<form class="ui form" action="{{if not .LinkAccountMode}}{{.Link}}{{else}}{{.SignInLink}}{{end}}" method="post"> | |||
{{.CsrfTokenHtml}} | |||
<h3 class="ui top attached header"> | |||
{{.i18n.Tr "sign_in"}} | |||
</h3> | |||
<div class="ui attached segment"> | |||
{{if or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeSignIn)}} | |||
{{template "base/alert" .}} | |||
{{end}} | |||
<div class="required inline field {{if and (.Err_UserName) (or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeSignIn))}}error{{end}}"> | |||
<label for="user_name">{{.i18n.Tr "home.uname_holder"}}</label> | |||
<input id="user_name" name="user_name" value="{{.user_name}}" autofocus required> | |||
</div> | |||
<div class="required inline field {{if and (.Err_Password) (or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeSignIn))}}error{{end}}"> | |||
<label for="password">{{.i18n.Tr "password"}}</label> | |||
<input id="password" name="password" type="password" value="{{.password}}" autocomplete="off" required> | |||
</div> | |||
{{if not .LinkAccountMode}} | |||
<div class="inline field"> | |||
<label></label> | |||
<div class="ui checkbox"> | |||
<label>{{.i18n.Tr "auth.remember_me"}}</label> | |||
<input name="remember" type="checkbox"> | |||
</div> | |||
</div> | |||
{{end}} | |||
<div class="inline field"> | |||
<label></label> | |||
<button class="ui green button">{{.i18n.Tr "sign_in"}}</button> | |||
<a href="{{AppSubUrl}}/user/forget_password">{{.i18n.Tr "auth.forget_password"}}</a> | |||
</div> | |||
{{if .ShowRegistrationButton}} | |||
<div class="inline field"> | |||
<label></label> | |||
<a href="{{AppSubUrl}}/user/sign_up">{{.i18n.Tr "auth.sign_up_now" | Str2html}}</a> | |||
</div> | |||
{{end}} | |||
{{if .OAuth2Providers}} | |||
<div class="ui attached segment"> | |||
<div class="oauth2 center"> | |||
<div> | |||
<p>{{.i18n.Tr "sign_in_with"}}</p>{{range $key, $value := .OAuth2Providers}}<a href="{{AppSubUrl}}/user/oauth2/{{$key}}"><img alt="{{$value.DisplayName}}" title="{{$value.DisplayName}}" src="{{AppSubUrl}}{{$value.Image}}"></a>{{end}} | |||
</div> | |||
</div> | |||
</div> | |||
{{end}} | |||
</div> | |||
</form> | |||
</div> | |||
</div> | |||
</div> |
@@ -1,56 +1,3 @@ | |||
{{template "base/head" .}} | |||
<div class="user signup"> | |||
<div class="ui middle very relaxed page grid"> | |||
<div class="column"> | |||
<form class="ui form" action="{{.Link}}" method="post"> | |||
{{.CsrfTokenHtml}} | |||
<h3 class="ui top attached header"> | |||
{{if .IsSocialLogin}}{{.i18n.Tr "social_sign_in" | Str2html}}{{else}}{{.i18n.Tr "sign_up"}}{{end}} | |||
</h3> | |||
<div class="ui attached segment"> | |||
{{template "base/alert" .}} | |||
{{if .DisableRegistration}} | |||
<p>{{.i18n.Tr "auth.disable_register_prompt"}}</p> | |||
{{else}} | |||
<div class="required inline field {{if .Err_UserName}}error{{end}}"> | |||
<label for="user_name">{{.i18n.Tr "username"}}</label> | |||
<input id="user_name" name="user_name" value="{{.user_name}}" autofocus required> | |||
</div> | |||
<div class="required inline field {{if .Err_Email}}error{{end}}"> | |||
<label for="email">{{.i18n.Tr "email"}}</label> | |||
<input id="email" name="email" type="email" value="{{.email}}" required> | |||
</div> | |||
<div class="required inline field {{if .Err_Password}}error{{end}}"> | |||
<label for="password">{{.i18n.Tr "password"}}</label> | |||
<input id="password" name="password" type="password" value="{{.password}}" autocomplete="off" required> | |||
</div> | |||
<div class="required inline field {{if .Err_Password}}error{{end}}"> | |||
<label for="retype">{{.i18n.Tr "re_type"}}</label> | |||
<input id="retype" name="retype" type="password" value="{{.retype}}" autocomplete="off" required> | |||
</div> | |||
{{if .EnableCaptcha}} | |||
<div class="inline field"> | |||
<label></label> | |||
{{.Captcha.CreateHtml}} | |||
</div> | |||
<div class="required inline field {{if .Err_Captcha}}error{{end}}"> | |||
<label for="captcha">{{.i18n.Tr "captcha"}}</label> | |||
<input id="captcha" name="captcha" value="{{.captcha}}" autocomplete="off"> | |||
</div> | |||
{{end}} | |||
<div class="inline field"> | |||
<label></label> | |||
<button class="ui green button">{{.i18n.Tr "auth.create_new_account"}}</button> | |||
</div> | |||
<div class="inline field"> | |||
<label></label> | |||
<a href="{{AppSubUrl}}/user/login">{{if .IsSocialLogin}}{{.i18n.Tr "auth.social_register_helper_msg"}}{{else}}{{.i18n.Tr "auth.register_helper_msg"}}{{end}}</a> | |||
</div> | |||
{{end}} | |||
</div> | |||
</form> | |||
</div> | |||
</div> | |||
</div> | |||
{{template "user/auth/signup_inner" .}} | |||
{{template "base/footer" .}} |
@@ -0,0 +1,59 @@ | |||
<div class="user signup{{if .LinkAccountMode}} icon{{end}}"> | |||
<div class="ui middle very relaxed page grid"> | |||
<div class="column"> | |||
<form class="ui form" action="{{if not .LinkAccountMode}}{{.Link}}{{else}}{{.SignUpLink}}{{end}}" method="post"> | |||
{{.CsrfTokenHtml}} | |||
<h3 class="ui top attached header"> | |||
{{.i18n.Tr "sign_up"}} | |||
</h3> | |||
<div class="ui attached segment"> | |||
{{if or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeRegister)}} | |||
{{template "base/alert" .}} | |||
{{end}} | |||
{{if .DisableRegistration}} | |||
<p>{{.i18n.Tr "auth.disable_register_prompt"}}</p> | |||
{{else}} | |||
<div class="required inline field {{if and (.Err_UserName) (or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeRegister))}}error{{end}}"> | |||
<label for="user_name">{{.i18n.Tr "username"}}</label> | |||
<input id="user_name" name="user_name" value="{{.user_name}}" autofocus required> | |||
</div> | |||
<div class="required inline field {{if .Err_Email}}error{{end}}"> | |||
<label for="email">{{.i18n.Tr "email"}}</label> | |||
<input id="email" name="email" type="email" value="{{.email}}" required> | |||
</div> | |||
<div class="required inline field {{if and (.Err_Password) (or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeRegister))}}error{{end}}"> | |||
<label for="password">{{.i18n.Tr "password"}}</label> | |||
<input id="password" name="password" type="password" value="{{.password}}" autocomplete="off" required> | |||
</div> | |||
<div class="required inline field {{if and (.Err_Password) (or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeRegister))}}error{{end}}"> | |||
<label for="retype">{{.i18n.Tr "re_type"}}</label> | |||
<input id="retype" name="retype" type="password" value="{{.retype}}" autocomplete="off" required> | |||
</div> | |||
{{if .EnableCaptcha}} | |||
<div class="inline field"> | |||
<label></label> | |||
{{.Captcha.CreateHtml}} | |||
</div> | |||
<div class="required inline field {{if .Err_Captcha}}error{{end}}"> | |||
<label for="captcha">{{.i18n.Tr "captcha"}}</label> | |||
<input id="captcha" name="captcha" value="{{.captcha}}" autocomplete="off"> | |||
</div> | |||
{{end}} | |||
<div class="inline field"> | |||
<label></label> | |||
<button class="ui green button">{{.i18n.Tr "auth.create_new_account"}}</button> | |||
</div> | |||
{{if not .LinkAccountMode}} | |||
<div class="inline field"> | |||
<label></label> | |||
<a href="{{AppSubUrl}}/user/login">{{.i18n.Tr "auth.register_helper_msg"}}</a> | |||
</div> | |||
{{end}} | |||
{{end}} | |||
</div> | |||
</form> | |||
</div> | |||
</div> | |||
</div> |
@@ -0,0 +1,48 @@ | |||
{{template "base/head" .}} | |||
<div class="user settings account_link"> | |||
<div class="ui container"> | |||
<div class="ui grid"> | |||
{{template "user/settings/navbar" .}} | |||
<div class="twelve wide column content"> | |||
{{template "base/alert" .}} | |||
<h4 class="ui top attached header"> | |||
{{.i18n.Tr "settings.manage_account_links"}} | |||
</h4> | |||
<div class="ui attached segment"> | |||
<div class="ui key list"> | |||
<div class="item"> | |||
{{.i18n.Tr "settings.manage_account_links_desc"}} | |||
</div> | |||
{{if .AccountLinks}} | |||
{{range $loginSource, $provider := .AccountLinks}} | |||
<div class="item ui grid"> | |||
<div class="column"> | |||
<strong>{{$provider}}</strong> | |||
{{if $loginSource.IsActived}}<span class="text red">{{$.i18n.Tr "settings.active"}}</span>{{end}} | |||
<div class="ui right"> | |||
<button class="ui red tiny button delete-button" data-url="{{$.Link}}" data-id="{{$loginSource.ID}}"> | |||
{{$.i18n.Tr "settings.delete_key"}} | |||
</button> | |||
</div> | |||
</div> | |||
</div> | |||
{{end}} | |||
{{end}} | |||
</div> | |||
</div> | |||
</div> | |||
</div> | |||
</div> | |||
</div> | |||
<div class="ui small basic delete modal"> | |||
<div class="ui icon header"> | |||
<i class="trash icon"></i> | |||
{{.i18n.Tr "settings.remove_account_link"}} | |||
</div> | |||
<div class="content"> | |||
<p>{{.i18n.Tr "settings.remove_account_link_desc"}}</p> | |||
</div> | |||
{{template "base/delete_modal_actions" .}} | |||
</div> | |||
{{template "base/footer" .}} |
@@ -22,6 +22,9 @@ | |||
<a class="{{if .PageIsSettingsTwofa}}active{{end}} item" href="{{AppSubUrl}}/user/settings/two_factor"> | |||
{{.i18n.Tr "settings.twofa"}} | |||
</a> | |||
<a class="{{if .PageIsSettingsAccountLink}}active{{end}} item" href="{{AppSubUrl}}/user/settings/account_link"> | |||
{{.i18n.Tr "settings.account_link"}} | |||
</a> | |||
<a class="{{if .PageIsSettingsDelete}}active{{end}} item" href="{{AppSubUrl}}/user/settings/delete"> | |||
{{.i18n.Tr "settings.delete"}} | |||
</a> |
@@ -9,13 +9,15 @@ | |||
{{.i18n.Tr "settings.change_password"}} | |||
</h4> | |||
<div class="ui attached segment"> | |||
{{if .SignedUser.IsLocal}} | |||
{{if or (.SignedUser.IsLocal) (.SignedUser.IsOAuth2)}} | |||
<form class="ui form" action="{{.Link}}" method="post"> | |||
{{.CsrfTokenHtml}} | |||
{{if .SignedUser.IsPasswordSet}} | |||
<div class="required field {{if .Err_OldPassword}}error{{end}}"> | |||
<label for="old_password">{{.i18n.Tr "settings.old_password"}}</label> | |||
<input id="old_password" name="old_password" type="password" autocomplete="off" autofocus required> | |||
</div> | |||
{{end}} | |||
<div class="required field {{if .Err_Password}}error{{end}}"> | |||
<label for="password">{{.i18n.Tr "settings.new_password"}}</label> | |||
<input id="password" name="password" type="password" autocomplete="off" required> |
@@ -0,0 +1,27 @@ | |||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
* Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
* Redistributions in binary form must reproduce the above | |||
copyright notice, this list of conditions and the following disclaimer | |||
in the documentation and/or other materials provided with the | |||
distribution. | |||
* Neither the name of Google Inc. nor the names of its | |||
contributors may be used to endorse or promote products derived from | |||
this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
@@ -0,0 +1,10 @@ | |||
context | |||
======= | |||
[![Build Status](https://travis-ci.org/gorilla/context.png?branch=master)](https://travis-ci.org/gorilla/context) | |||
gorilla/context is a general purpose registry for global request variables. | |||
> Note: gorilla/context, having been born well before `context.Context` existed, does not play well | |||
> with the shallow copying of the request that [`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext) (added to net/http Go 1.7 onwards) performs. You should either use *just* gorilla/context, or moving forward, the new `http.Request.Context()`. | |||
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context |
@@ -0,0 +1,143 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package context | |||
import ( | |||
"net/http" | |||
"sync" | |||
"time" | |||
) | |||
var ( | |||
mutex sync.RWMutex | |||
data = make(map[*http.Request]map[interface{}]interface{}) | |||
datat = make(map[*http.Request]int64) | |||
) | |||
// Set stores a value for a given key in a given request. | |||
func Set(r *http.Request, key, val interface{}) { | |||
mutex.Lock() | |||
if data[r] == nil { | |||
data[r] = make(map[interface{}]interface{}) | |||
datat[r] = time.Now().Unix() | |||
} | |||
data[r][key] = val | |||
mutex.Unlock() | |||
} | |||
// Get returns a value stored for a given key in a given request. | |||
func Get(r *http.Request, key interface{}) interface{} { | |||
mutex.RLock() | |||
if ctx := data[r]; ctx != nil { | |||
value := ctx[key] | |||
mutex.RUnlock() | |||
return value | |||
} | |||
mutex.RUnlock() | |||
return nil | |||
} | |||
// GetOk returns stored value and presence state like multi-value return of map access. | |||
func GetOk(r *http.Request, key interface{}) (interface{}, bool) { | |||
mutex.RLock() | |||
if _, ok := data[r]; ok { | |||
value, ok := data[r][key] | |||
mutex.RUnlock() | |||
return value, ok | |||
} | |||
mutex.RUnlock() | |||
return nil, false | |||
} | |||
// GetAll returns all stored values for the request as a map. Nil is returned for invalid requests. | |||
func GetAll(r *http.Request) map[interface{}]interface{} { | |||
mutex.RLock() | |||
if context, ok := data[r]; ok { | |||
result := make(map[interface{}]interface{}, len(context)) | |||
for k, v := range context { | |||
result[k] = v | |||
} | |||
mutex.RUnlock() | |||
return result | |||
} | |||
mutex.RUnlock() | |||
return nil | |||
} | |||
// GetAllOk returns all stored values for the request as a map and a boolean value that indicates if | |||
// the request was registered. | |||
func GetAllOk(r *http.Request) (map[interface{}]interface{}, bool) { | |||
mutex.RLock() | |||
context, ok := data[r] | |||
result := make(map[interface{}]interface{}, len(context)) | |||
for k, v := range context { | |||
result[k] = v | |||
} | |||
mutex.RUnlock() | |||
return result, ok | |||
} | |||
// Delete removes a value stored for a given key in a given request. | |||
func Delete(r *http.Request, key interface{}) { | |||
mutex.Lock() | |||
if data[r] != nil { | |||
delete(data[r], key) | |||
} | |||
mutex.Unlock() | |||
} | |||
// Clear removes all values stored for a given request. | |||
// | |||
// This is usually called by a handler wrapper to clean up request | |||
// variables at the end of a request lifetime. See ClearHandler(). | |||
func Clear(r *http.Request) { | |||
mutex.Lock() | |||
clear(r) | |||
mutex.Unlock() | |||
} | |||
// clear is Clear without the lock. | |||
func clear(r *http.Request) { | |||
delete(data, r) | |||
delete(datat, r) | |||
} | |||
// Purge removes request data stored for longer than maxAge, in seconds. | |||
// It returns the amount of requests removed. | |||
// | |||
// If maxAge <= 0, all request data is removed. | |||
// | |||
// This is only used for sanity check: in case context cleaning was not | |||
// properly set some request data can be kept forever, consuming an increasing | |||
// amount of memory. In case this is detected, Purge() must be called | |||
// periodically until the problem is fixed. | |||
func Purge(maxAge int) int { | |||
mutex.Lock() | |||
count := 0 | |||
if maxAge <= 0 { | |||
count = len(data) | |||
data = make(map[*http.Request]map[interface{}]interface{}) | |||
datat = make(map[*http.Request]int64) | |||
} else { | |||
min := time.Now().Unix() - int64(maxAge) | |||
for r := range data { | |||
if datat[r] < min { | |||
clear(r) | |||
count++ | |||
} | |||
} | |||
} | |||
mutex.Unlock() | |||
return count | |||
} | |||
// ClearHandler wraps an http.Handler and clears request values at the end | |||
// of a request lifetime. | |||
func ClearHandler(h http.Handler) http.Handler { | |||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
defer Clear(r) | |||
h.ServeHTTP(w, r) | |||
}) | |||
} |
@@ -0,0 +1,88 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
/* | |||
Package context stores values shared during a request lifetime. | |||
Note: gorilla/context, having been born well before `context.Context` existed, | |||
does not play well > with the shallow copying of the request that | |||
[`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext) | |||
(added to net/http Go 1.7 onwards) performs. You should either use *just* | |||
gorilla/context, or moving forward, the new `http.Request.Context()`. | |||
For example, a router can set variables extracted from the URL and later | |||
application handlers can access those values, or it can be used to store | |||
sessions values to be saved at the end of a request. There are several | |||
others common uses. | |||
The idea was posted by Brad Fitzpatrick to the go-nuts mailing list: | |||
http://groups.google.com/group/golang-nuts/msg/e2d679d303aa5d53 | |||
Here's the basic usage: first define the keys that you will need. The key | |||
type is interface{} so a key can be of any type that supports equality. | |||
Here we define a key using a custom int type to avoid name collisions: | |||
package foo | |||
import ( | |||
"github.com/gorilla/context" | |||
) | |||
type key int | |||
const MyKey key = 0 | |||
Then set a variable. Variables are bound to an http.Request object, so you | |||
need a request instance to set a value: | |||
context.Set(r, MyKey, "bar") | |||
The application can later access the variable using the same key you provided: | |||
func MyHandler(w http.ResponseWriter, r *http.Request) { | |||
// val is "bar". | |||
val := context.Get(r, foo.MyKey) | |||
// returns ("bar", true) | |||
val, ok := context.GetOk(r, foo.MyKey) | |||
// ... | |||
} | |||
And that's all about the basic usage. We discuss some other ideas below. | |||
Any type can be stored in the context. To enforce a given type, make the key | |||
private and wrap Get() and Set() to accept and return values of a specific | |||
type: | |||
type key int | |||
const mykey key = 0 | |||
// GetMyKey returns a value for this package from the request values. | |||
func GetMyKey(r *http.Request) SomeType { | |||
if rv := context.Get(r, mykey); rv != nil { | |||
return rv.(SomeType) | |||
} | |||
return nil | |||
} | |||
// SetMyKey sets a value for this package in the request values. | |||
func SetMyKey(r *http.Request, val SomeType) { | |||
context.Set(r, mykey, val) | |||
} | |||
Variables must be cleared at the end of a request, to remove all values | |||
that were stored. This can be done in an http.Handler, after a request was | |||
served. Just call Clear() passing the request: | |||
context.Clear(r) | |||
...or use ClearHandler(), which conveniently wraps an http.Handler to clear | |||
variables at the end of a request lifetime. | |||
The Routers from the packages gorilla/mux and gorilla/pat call Clear() | |||
so if you are using either of them you don't need to clear the context manually. | |||
*/ | |||
package context |
@@ -0,0 +1,27 @@ | |||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
* Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
* Redistributions in binary form must reproduce the above | |||
copyright notice, this list of conditions and the following disclaimer | |||
in the documentation and/or other materials provided with the | |||
distribution. | |||
* Neither the name of Google Inc. nor the names of its | |||
contributors may be used to endorse or promote products derived from | |||
this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
@@ -0,0 +1,299 @@ | |||
gorilla/mux | |||
=== | |||
[![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux) | |||
[![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux) | |||
![Gorilla Logo](http://www.gorillatoolkit.org/static/images/gorilla-icon-64.png) | |||
http://www.gorillatoolkit.org/pkg/mux | |||
Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to | |||
their respective handler. | |||
The name mux stands for "HTTP request multiplexer". Like the standard `http.ServeMux`, `mux.Router` matches incoming requests against a list of registered routes and calls a handler for the route that matches the URL or other conditions. The main features are: | |||
* It implements the `http.Handler` interface so it is compatible with the standard `http.ServeMux`. | |||
* Requests can be matched based on URL host, path, path prefix, schemes, header and query values, HTTP methods or using custom matchers. | |||
* URL hosts and paths can have variables with an optional regular expression. | |||
* Registered URLs can be built, or "reversed", which helps maintaining references to resources. | |||
* Routes can be used as subrouters: nested routes are only tested if the parent route matches. This is useful to define groups of routes that share common conditions like a host, a path prefix or other repeated attributes. As a bonus, this optimizes request matching. | |||
--- | |||
* [Install](#install) | |||
* [Examples](#examples) | |||
* [Matching Routes](#matching-routes) | |||
* [Static Files](#static-files) | |||
* [Registered URLs](#registered-urls) | |||
* [Full Example](#full-example) | |||
--- | |||
## Install | |||
With a [correctly configured](https://golang.org/doc/install#testing) Go toolchain: | |||
```sh | |||
go get -u github.com/gorilla/mux | |||
``` | |||
## Examples | |||
Let's start registering a couple of URL paths and handlers: | |||
```go | |||
func main() { | |||
r := mux.NewRouter() | |||
r.HandleFunc("/", HomeHandler) | |||
r.HandleFunc("/products", ProductsHandler) | |||
r.HandleFunc("/articles", ArticlesHandler) | |||
http.Handle("/", r) | |||
} | |||
``` | |||
Here we register three routes mapping URL paths to handlers. This is equivalent to how `http.HandleFunc()` works: if an incoming request URL matches one of the paths, the corresponding handler is called passing (`http.ResponseWriter`, `*http.Request`) as parameters. | |||
Paths can have variables. They are defined using the format `{name}` or `{name:pattern}`. If a regular expression pattern is not defined, the matched variable will be anything until the next slash. For example: | |||
```go | |||
r := mux.NewRouter() | |||
r.HandleFunc("/products/{key}", ProductHandler) | |||
r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler) | |||
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler) | |||
``` | |||
The names are used to create a map of route variables which can be retrieved calling `mux.Vars()`: | |||
```go | |||
vars := mux.Vars(request) | |||
category := vars["category"] | |||
``` | |||
And this is all you need to know about the basic usage. More advanced options are explained below. | |||
### Matching Routes | |||
Routes can also be restricted to a domain or subdomain. Just define a host pattern to be matched. They can also have variables: | |||
```go | |||
r := mux.NewRouter() | |||
// Only matches if domain is "www.example.com". | |||
r.Host("www.example.com") | |||
// Matches a dynamic subdomain. | |||
r.Host("{subdomain:[a-z]+}.domain.com") | |||
``` | |||
There are several other matchers that can be added. To match path prefixes: | |||
```go | |||
r.PathPrefix("/products/") | |||
``` | |||
...or HTTP methods: | |||
```go | |||
r.Methods("GET", "POST") | |||
``` | |||
...or URL schemes: | |||
```go | |||
r.Schemes("https") | |||
``` | |||
...or header values: | |||
```go | |||
r.Headers("X-Requested-With", "XMLHttpRequest") | |||
``` | |||
...or query values: | |||
```go | |||
r.Queries("key", "value") | |||
``` | |||
...or to use a custom matcher function: | |||
```go | |||
r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool { | |||
return r.ProtoMajor == 0 | |||
}) | |||
``` | |||
...and finally, it is possible to combine several matchers in a single route: | |||
```go | |||
r.HandleFunc("/products", ProductsHandler). | |||
Host("www.example.com"). | |||
Methods("GET"). | |||
Schemes("http") | |||
``` | |||
Setting the same matching conditions again and again can be boring, so we have a way to group several routes that share the same requirements. We call it "subrouting". | |||
For example, let's say we have several URLs that should only match when the host is `www.example.com`. Create a route for that host and get a "subrouter" from it: | |||
```go | |||
r := mux.NewRouter() | |||
s := r.Host("www.example.com").Subrouter() | |||
``` | |||
Then register routes in the subrouter: | |||
```go | |||
s.HandleFunc("/products/", ProductsHandler) | |||
s.HandleFunc("/products/{key}", ProductHandler) | |||
s.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler) | |||
``` | |||
The three URL paths we registered above will only be tested if the domain is `www.example.com`, because the subrouter is tested first. This is not only convenient, but also optimizes request matching. You can create subrouters combining any attribute matchers accepted by a route. | |||
Subrouters can be used to create domain or path "namespaces": you define subrouters in a central place and then parts of the app can register its paths relatively to a given subrouter. | |||
There's one more thing about subroutes. When a subrouter has a path prefix, the inner routes use it as base for their paths: | |||
```go | |||
r := mux.NewRouter() | |||
s := r.PathPrefix("/products").Subrouter() | |||
// "/products/" | |||
s.HandleFunc("/", ProductsHandler) | |||
// "/products/{key}/" | |||
s.HandleFunc("/{key}/", ProductHandler) | |||
// "/products/{key}/details" | |||
s.HandleFunc("/{key}/details", ProductDetailsHandler) | |||
``` | |||
### Static Files | |||
Note that the path provided to `PathPrefix()` represents a "wildcard": calling | |||
`PathPrefix("/static/").Handler(...)` means that the handler will be passed any | |||
request that matches "/static/*". This makes it easy to serve static files with mux: | |||
```go | |||
func main() { | |||
var dir string | |||
flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir") | |||
flag.Parse() | |||
r := mux.NewRouter() | |||
// This will serve files under http://localhost:8000/static/<filename> | |||
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir)))) | |||
srv := &http.Server{ | |||
Handler: r, | |||
Addr: "127.0.0.1:8000", | |||
// Good practice: enforce timeouts for servers you create! | |||
WriteTimeout: 15 * time.Second, | |||
ReadTimeout: 15 * time.Second, | |||
} | |||
log.Fatal(srv.ListenAndServe()) | |||
} | |||
``` | |||
### Registered URLs | |||
Now let's see how to build registered URLs. | |||
Routes can be named. All routes that define a name can have their URLs built, or "reversed". We define a name calling `Name()` on a route. For example: | |||
```go | |||
r := mux.NewRouter() | |||
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). | |||
Name("article") | |||
``` | |||
To build a URL, get the route and call the `URL()` method, passing a sequence of key/value pairs for the route variables. For the previous route, we would do: | |||
```go | |||
url, err := r.Get("article").URL("category", "technology", "id", "42") | |||
``` | |||
...and the result will be a `url.URL` with the following path: | |||
``` | |||
"/articles/technology/42" | |||
``` | |||
This also works for host variables: | |||
```go | |||
r := mux.NewRouter() | |||
r.Host("{subdomain}.domain.com"). | |||
Path("/articles/{category}/{id:[0-9]+}"). | |||
HandlerFunc(ArticleHandler). | |||
Name("article") | |||
// url.String() will be "http://news.domain.com/articles/technology/42" | |||
url, err := r.Get("article").URL("subdomain", "news", | |||
"category", "technology", | |||
"id", "42") | |||
``` | |||
All variables defined in the route are required, and their values must conform to the corresponding patterns. These requirements guarantee that a generated URL will always match a registered route -- the only exception is for explicitly defined "build-only" routes which never match. | |||
Regex support also exists for matching Headers within a route. For example, we could do: | |||
```go | |||
r.HeadersRegexp("Content-Type", "application/(text|json)") | |||
``` | |||
...and the route will match both requests with a Content-Type of `application/json` as well as `application/text` | |||
There's also a way to build only the URL host or path for a route: use the methods `URLHost()` or `URLPath()` instead. For the previous route, we would do: | |||
```go | |||
// "http://news.domain.com/" | |||
host, err := r.Get("article").URLHost("subdomain", "news") | |||
// "/articles/technology/42" | |||
path, err := r.Get("article").URLPath("category", "technology", "id", "42") | |||
``` | |||
And if you use subrouters, host and path defined separately can be built as well: | |||
```go | |||
r := mux.NewRouter() | |||
s := r.Host("{subdomain}.domain.com").Subrouter() | |||
s.Path("/articles/{category}/{id:[0-9]+}"). | |||
HandlerFunc(ArticleHandler). | |||
Name("article") | |||
// "http://news.domain.com/articles/technology/42" | |||
url, err := r.Get("article").URL("subdomain", "news", | |||
"category", "technology", | |||
"id", "42") | |||
``` | |||
## Full Example | |||
Here's a complete, runnable example of a small `mux` based server: | |||
```go | |||
package main | |||
import ( | |||
"net/http" | |||
"log" | |||
"github.com/gorilla/mux" | |||
) | |||
func YourHandler(w http.ResponseWriter, r *http.Request) { | |||
w.Write([]byte("Gorilla!\n")) | |||
} | |||
func main() { | |||
r := mux.NewRouter() | |||
// Routes consist of a path and a handler function. | |||
r.HandleFunc("/", YourHandler) | |||
// Bind to a port and pass our router in | |||
log.Fatal(http.ListenAndServe(":8000", r)) | |||
} | |||
``` | |||
## License | |||
BSD licensed. See the LICENSE file for details. |
@@ -0,0 +1,26 @@ | |||
// +build !go1.7 | |||
package mux | |||
import ( | |||
"net/http" | |||
"github.com/gorilla/context" | |||
) | |||
func contextGet(r *http.Request, key interface{}) interface{} { | |||
return context.Get(r, key) | |||
} | |||
func contextSet(r *http.Request, key, val interface{}) *http.Request { | |||
if val == nil { | |||
return r | |||
} | |||
context.Set(r, key, val) | |||
return r | |||
} | |||
func contextClear(r *http.Request) { | |||
context.Clear(r) | |||
} |
@@ -0,0 +1,24 @@ | |||
// +build go1.7 | |||
package mux | |||
import ( | |||
"context" | |||
"net/http" | |||
) | |||
func contextGet(r *http.Request, key interface{}) interface{} { | |||
return r.Context().Value(key) | |||
} | |||
func contextSet(r *http.Request, key, val interface{}) *http.Request { | |||
if val == nil { | |||
return r | |||
} | |||
return r.WithContext(context.WithValue(r.Context(), key, val)) | |||
} | |||
func contextClear(r *http.Request) { | |||
return | |||
} |
@@ -0,0 +1,235 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
/* | |||
Package mux implements a request router and dispatcher. | |||
The name mux stands for "HTTP request multiplexer". Like the standard | |||
http.ServeMux, mux.Router matches incoming requests against a list of | |||
registered routes and calls a handler for the route that matches the URL | |||
or other conditions. The main features are: | |||
* Requests can be matched based on URL host, path, path prefix, schemes, | |||
header and query values, HTTP methods or using custom matchers. | |||
* URL hosts and paths can have variables with an optional regular | |||
expression. | |||
* Registered URLs can be built, or "reversed", which helps maintaining | |||
references to resources. | |||
* Routes can be used as subrouters: nested routes are only tested if the | |||
parent route matches. This is useful to define groups of routes that | |||
share common conditions like a host, a path prefix or other repeated | |||
attributes. As a bonus, this optimizes request matching. | |||
* It implements the http.Handler interface so it is compatible with the | |||
standard http.ServeMux. | |||
Let's start registering a couple of URL paths and handlers: | |||
func main() { | |||
r := mux.NewRouter() | |||
r.HandleFunc("/", HomeHandler) | |||
r.HandleFunc("/products", ProductsHandler) | |||
r.HandleFunc("/articles", ArticlesHandler) | |||
http.Handle("/", r) | |||
} | |||
Here we register three routes mapping URL paths to handlers. This is | |||
equivalent to how http.HandleFunc() works: if an incoming request URL matches | |||
one of the paths, the corresponding handler is called passing | |||
(http.ResponseWriter, *http.Request) as parameters. | |||
Paths can have variables. They are defined using the format {name} or | |||
{name:pattern}. If a regular expression pattern is not defined, the matched | |||
variable will be anything until the next slash. For example: | |||
r := mux.NewRouter() | |||
r.HandleFunc("/products/{key}", ProductHandler) | |||
r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler) | |||
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler) | |||
Groups can be used inside patterns, as long as they are non-capturing (?:re). For example: | |||
r.HandleFunc("/articles/{category}/{sort:(?:asc|desc|new)}", ArticlesCategoryHandler) | |||
The names are used to create a map of route variables which can be retrieved | |||
calling mux.Vars(): | |||
vars := mux.Vars(request) | |||
category := vars["category"] | |||
And this is all you need to know about the basic usage. More advanced options | |||
are explained below. | |||
Routes can also be restricted to a domain or subdomain. Just define a host | |||
pattern to be matched. They can also have variables: | |||
r := mux.NewRouter() | |||
// Only matches if domain is "www.example.com". | |||
r.Host("www.example.com") | |||
// Matches a dynamic subdomain. | |||
r.Host("{subdomain:[a-z]+}.domain.com") | |||
There are several other matchers that can be added. To match path prefixes: | |||
r.PathPrefix("/products/") | |||
...or HTTP methods: | |||
r.Methods("GET", "POST") | |||
...or URL schemes: | |||
r.Schemes("https") | |||
...or header values: | |||
r.Headers("X-Requested-With", "XMLHttpRequest") | |||
...or query values: | |||
r.Queries("key", "value") | |||
...or to use a custom matcher function: | |||
r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool { | |||
return r.ProtoMajor == 0 | |||
}) | |||
...and finally, it is possible to combine several matchers in a single route: | |||
r.HandleFunc("/products", ProductsHandler). | |||
Host("www.example.com"). | |||
Methods("GET"). | |||
Schemes("http") | |||
Setting the same matching conditions again and again can be boring, so we have | |||
a way to group several routes that share the same requirements. | |||
We call it "subrouting". | |||
For example, let's say we have several URLs that should only match when the | |||
host is "www.example.com". Create a route for that host and get a "subrouter" | |||
from it: | |||
r := mux.NewRouter() | |||
s := r.Host("www.example.com").Subrouter() | |||
Then register routes in the subrouter: | |||
s.HandleFunc("/products/", ProductsHandler) | |||
s.HandleFunc("/products/{key}", ProductHandler) | |||
s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler) | |||
The three URL paths we registered above will only be tested if the domain is | |||
"www.example.com", because the subrouter is tested first. This is not | |||
only convenient, but also optimizes request matching. You can create | |||
subrouters combining any attribute matchers accepted by a route. | |||
Subrouters can be used to create domain or path "namespaces": you define | |||
subrouters in a central place and then parts of the app can register its | |||
paths relatively to a given subrouter. | |||
There's one more thing about subroutes. When a subrouter has a path prefix, | |||
the inner routes use it as base for their paths: | |||
r := mux.NewRouter() | |||
s := r.PathPrefix("/products").Subrouter() | |||
// "/products/" | |||
s.HandleFunc("/", ProductsHandler) | |||
// "/products/{key}/" | |||
s.HandleFunc("/{key}/", ProductHandler) | |||
// "/products/{key}/details" | |||
s.HandleFunc("/{key}/details", ProductDetailsHandler) | |||
Note that the path provided to PathPrefix() represents a "wildcard": calling | |||
PathPrefix("/static/").Handler(...) means that the handler will be passed any | |||
request that matches "/static/*". This makes it easy to serve static files with mux: | |||
func main() { | |||
var dir string | |||
flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir") | |||
flag.Parse() | |||
r := mux.NewRouter() | |||
// This will serve files under http://localhost:8000/static/<filename> | |||
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir)))) | |||
srv := &http.Server{ | |||
Handler: r, | |||
Addr: "127.0.0.1:8000", | |||
// Good practice: enforce timeouts for servers you create! | |||
WriteTimeout: 15 * time.Second, | |||
ReadTimeout: 15 * time.Second, | |||
} | |||
log.Fatal(srv.ListenAndServe()) | |||
} | |||
Now let's see how to build registered URLs. | |||
Routes can be named. All routes that define a name can have their URLs built, | |||
or "reversed". We define a name calling Name() on a route. For example: | |||
r := mux.NewRouter() | |||
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). | |||
Name("article") | |||
To build a URL, get the route and call the URL() method, passing a sequence of | |||
key/value pairs for the route variables. For the previous route, we would do: | |||
url, err := r.Get("article").URL("category", "technology", "id", "42") | |||
...and the result will be a url.URL with the following path: | |||
"/articles/technology/42" | |||
This also works for host variables: | |||
r := mux.NewRouter() | |||
r.Host("{subdomain}.domain.com"). | |||
Path("/articles/{category}/{id:[0-9]+}"). | |||
HandlerFunc(ArticleHandler). | |||
Name("article") | |||
// url.String() will be "http://news.domain.com/articles/technology/42" | |||
url, err := r.Get("article").URL("subdomain", "news", | |||
"category", "technology", | |||
"id", "42") | |||
All variables defined in the route are required, and their values must | |||
conform to the corresponding patterns. These requirements guarantee that a | |||
generated URL will always match a registered route -- the only exception is | |||
for explicitly defined "build-only" routes which never match. | |||
Regex support also exists for matching Headers within a route. For example, we could do: | |||
r.HeadersRegexp("Content-Type", "application/(text|json)") | |||
...and the route will match both requests with a Content-Type of `application/json` as well as | |||
`application/text` | |||
There's also a way to build only the URL host or path for a route: | |||
use the methods URLHost() or URLPath() instead. For the previous route, | |||
we would do: | |||
// "http://news.domain.com/" | |||
host, err := r.Get("article").URLHost("subdomain", "news") | |||
// "/articles/technology/42" | |||
path, err := r.Get("article").URLPath("category", "technology", "id", "42") | |||
And if you use subrouters, host and path defined separately can be built | |||
as well: | |||
r := mux.NewRouter() | |||
s := r.Host("{subdomain}.domain.com").Subrouter() | |||
s.Path("/articles/{category}/{id:[0-9]+}"). | |||
HandlerFunc(ArticleHandler). | |||
Name("article") | |||
// "http://news.domain.com/articles/technology/42" | |||
url, err := r.Get("article").URL("subdomain", "news", | |||
"category", "technology", | |||
"id", "42") | |||
*/ | |||
package mux |
@@ -0,0 +1,542 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package mux | |||
import ( | |||
"errors" | |||
"fmt" | |||
"net/http" | |||
"path" | |||
"regexp" | |||
"strings" | |||
) | |||
// NewRouter returns a new router instance. | |||
func NewRouter() *Router { | |||
return &Router{namedRoutes: make(map[string]*Route), KeepContext: false} | |||
} | |||
// Router registers routes to be matched and dispatches a handler. | |||
// | |||
// It implements the http.Handler interface, so it can be registered to serve | |||
// requests: | |||
// | |||
// var router = mux.NewRouter() | |||
// | |||
// func main() { | |||
// http.Handle("/", router) | |||
// } | |||
// | |||
// Or, for Google App Engine, register it in a init() function: | |||
// | |||
// func init() { | |||
// http.Handle("/", router) | |||
// } | |||
// | |||
// This will send all incoming requests to the router. | |||
type Router struct { | |||
// Configurable Handler to be used when no route matches. | |||
NotFoundHandler http.Handler | |||
// Parent route, if this is a subrouter. | |||
parent parentRoute | |||
// Routes to be matched, in order. | |||
routes []*Route | |||
// Routes by name for URL building. | |||
namedRoutes map[string]*Route | |||
// See Router.StrictSlash(). This defines the flag for new routes. | |||
strictSlash bool | |||
// See Router.SkipClean(). This defines the flag for new routes. | |||
skipClean bool | |||
// If true, do not clear the request context after handling the request. | |||
// This has no effect when go1.7+ is used, since the context is stored | |||
// on the request itself. | |||
KeepContext bool | |||
// see Router.UseEncodedPath(). This defines a flag for all routes. | |||
useEncodedPath bool | |||
} | |||
// Match matches registered routes against the request. | |||
func (r *Router) Match(req *http.Request, match *RouteMatch) bool { | |||
for _, route := range r.routes { | |||
if route.Match(req, match) { | |||
return true | |||
} | |||
} | |||
// Closest match for a router (includes sub-routers) | |||
if r.NotFoundHandler != nil { | |||
match.Handler = r.NotFoundHandler | |||
return true | |||
} | |||
return false | |||
} | |||
// ServeHTTP dispatches the handler registered in the matched route. | |||
// | |||
// When there is a match, the route variables can be retrieved calling | |||
// mux.Vars(request). | |||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { | |||
if !r.skipClean { | |||
path := req.URL.Path | |||
if r.useEncodedPath { | |||
path = getPath(req) | |||
} | |||
// Clean path to canonical form and redirect. | |||
if p := cleanPath(path); p != path { | |||
// Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query. | |||
// This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue: | |||
// http://code.google.com/p/go/issues/detail?id=5252 | |||
url := *req.URL | |||
url.Path = p | |||
p = url.String() | |||
w.Header().Set("Location", p) | |||
w.WriteHeader(http.StatusMovedPermanently) | |||
return | |||
} | |||
} | |||
var match RouteMatch | |||
var handler http.Handler | |||
if r.Match(req, &match) { | |||
handler = match.Handler | |||
req = setVars(req, match.Vars) | |||
req = setCurrentRoute(req, match.Route) | |||
} | |||
if handler == nil { | |||
handler = http.NotFoundHandler() | |||
} | |||
if !r.KeepContext { | |||
defer contextClear(req) | |||
} | |||
handler.ServeHTTP(w, req) | |||
} | |||
// Get returns a route registered with the given name. | |||
func (r *Router) Get(name string) *Route { | |||
return r.getNamedRoutes()[name] | |||
} | |||
// GetRoute returns a route registered with the given name. This method | |||
// was renamed to Get() and remains here for backwards compatibility. | |||
func (r *Router) GetRoute(name string) *Route { | |||
return r.getNamedRoutes()[name] | |||
} | |||
// StrictSlash defines the trailing slash behavior for new routes. The initial | |||
// value is false. | |||
// | |||
// When true, if the route path is "/path/", accessing "/path" will redirect | |||
// to the former and vice versa. In other words, your application will always | |||
// see the path as specified in the route. | |||
// | |||
// When false, if the route path is "/path", accessing "/path/" will not match | |||
// this route and vice versa. | |||
// | |||
// Special case: when a route sets a path prefix using the PathPrefix() method, | |||
// strict slash is ignored for that route because the redirect behavior can't | |||
// be determined from a prefix alone. However, any subrouters created from that | |||
// route inherit the original StrictSlash setting. | |||
func (r *Router) StrictSlash(value bool) *Router { | |||
r.strictSlash = value | |||
return r | |||
} | |||
// SkipClean defines the path cleaning behaviour for new routes. The initial | |||
// value is false. Users should be careful about which routes are not cleaned | |||
// | |||
// When true, if the route path is "/path//to", it will remain with the double | |||
// slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/ | |||
// | |||
// When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will | |||
// become /fetch/http/xkcd.com/534 | |||
func (r *Router) SkipClean(value bool) *Router { | |||
r.skipClean = value | |||
return r | |||
} | |||
// UseEncodedPath tells the router to match the encoded original path | |||
// to the routes. | |||
// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to". | |||
// This behavior has the drawback of needing to match routes against | |||
// r.RequestURI instead of r.URL.Path. Any modifications (such as http.StripPrefix) | |||
// to r.URL.Path will not affect routing when this flag is on and thus may | |||
// induce unintended behavior. | |||
// | |||
// If not called, the router will match the unencoded path to the routes. | |||
// For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to" | |||
func (r *Router) UseEncodedPath() *Router { | |||
r.useEncodedPath = true | |||
return r | |||
} | |||
// ---------------------------------------------------------------------------- | |||
// parentRoute | |||
// ---------------------------------------------------------------------------- | |||
// getNamedRoutes returns the map where named routes are registered. | |||
func (r *Router) getNamedRoutes() map[string]*Route { | |||
if r.namedRoutes == nil { | |||
if r.parent != nil { | |||
r.namedRoutes = r.parent.getNamedRoutes() | |||
} else { | |||
r.namedRoutes = make(map[string]*Route) | |||
} | |||
} | |||
return r.namedRoutes | |||
} | |||
// getRegexpGroup returns regexp definitions from the parent route, if any. | |||
func (r *Router) getRegexpGroup() *routeRegexpGroup { | |||
if r.parent != nil { | |||
return r.parent.getRegexpGroup() | |||
} | |||
return nil | |||
} | |||
func (r *Router) buildVars(m map[string]string) map[string]string { | |||
if r.parent != nil { | |||
m = r.parent.buildVars(m) | |||
} | |||
return m | |||
} | |||
// ---------------------------------------------------------------------------- | |||
// Route factories | |||
// ---------------------------------------------------------------------------- | |||
// NewRoute registers an empty route. | |||
func (r *Router) NewRoute() *Route { | |||
route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath} | |||
r.routes = append(r.routes, route) | |||
return route | |||
} | |||
// Handle registers a new route with a matcher for the URL path. | |||
// See Route.Path() and Route.Handler(). | |||
func (r *Router) Handle(path string, handler http.Handler) *Route { | |||
return r.NewRoute().Path(path).Handler(handler) | |||
} | |||
// HandleFunc registers a new route with a matcher for the URL path. | |||
// See Route.Path() and Route.HandlerFunc(). | |||
func (r *Router) HandleFunc(path string, f func(http.ResponseWriter, | |||
*http.Request)) *Route { | |||
return r.NewRoute().Path(path).HandlerFunc(f) | |||
} | |||
// Headers registers a new route with a matcher for request header values. | |||
// See Route.Headers(). | |||
func (r *Router) Headers(pairs ...string) *Route { | |||
return r.NewRoute().Headers(pairs...) | |||
} | |||
// Host registers a new route with a matcher for the URL host. | |||
// See Route.Host(). | |||
func (r *Router) Host(tpl string) *Route { | |||
return r.NewRoute().Host(tpl) | |||
} | |||
// MatcherFunc registers a new route with a custom matcher function. | |||
// See Route.MatcherFunc(). | |||
func (r *Router) MatcherFunc(f MatcherFunc) *Route { | |||
return r.NewRoute().MatcherFunc(f) | |||
} | |||
// Methods registers a new route with a matcher for HTTP methods. | |||
// See Route.Methods(). | |||
func (r *Router) Methods(methods ...string) *Route { | |||
return r.NewRoute().Methods(methods...) | |||
} | |||
// Path registers a new route with a matcher for the URL path. | |||
// See Route.Path(). | |||
func (r *Router) Path(tpl string) *Route { | |||
return r.NewRoute().Path(tpl) | |||
} | |||
// PathPrefix registers a new route with a matcher for the URL path prefix. | |||
// See Route.PathPrefix(). | |||
func (r *Router) PathPrefix(tpl string) *Route { | |||
return r.NewRoute().PathPrefix(tpl) | |||
} | |||
// Queries registers a new route with a matcher for URL query values. | |||
// See Route.Queries(). | |||
func (r *Router) Queries(pairs ...string) *Route { | |||
return r.NewRoute().Queries(pairs...) | |||
} | |||
// Schemes registers a new route with a matcher for URL schemes. | |||
// See Route.Schemes(). | |||
func (r *Router) Schemes(schemes ...string) *Route { | |||
return r.NewRoute().Schemes(schemes...) | |||
} | |||
// BuildVarsFunc registers a new route with a custom function for modifying | |||
// route variables before building a URL. | |||
func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route { | |||
return r.NewRoute().BuildVarsFunc(f) | |||
} | |||
// Walk walks the router and all its sub-routers, calling walkFn for each route | |||
// in the tree. The routes are walked in the order they were added. Sub-routers | |||
// are explored depth-first. | |||
func (r *Router) Walk(walkFn WalkFunc) error { | |||
return r.walk(walkFn, []*Route{}) | |||
} | |||
// SkipRouter is used as a return value from WalkFuncs to indicate that the | |||
// router that walk is about to descend down to should be skipped. | |||
var SkipRouter = errors.New("skip this router") | |||
// WalkFunc is the type of the function called for each route visited by Walk. | |||
// At every invocation, it is given the current route, and the current router, | |||
// and a list of ancestor routes that lead to the current route. | |||
type WalkFunc func(route *Route, router *Router, ancestors []*Route) error | |||
func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error { | |||
for _, t := range r.routes { | |||
if t.regexp == nil || t.regexp.path == nil || t.regexp.path.template == "" { | |||
continue | |||
} | |||
err := walkFn(t, r, ancestors) | |||
if err == SkipRouter { | |||
continue | |||
} | |||
if err != nil { | |||
return err | |||
} | |||
for _, sr := range t.matchers { | |||
if h, ok := sr.(*Router); ok { | |||
err := h.walk(walkFn, ancestors) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
} | |||
if h, ok := t.handler.(*Router); ok { | |||
ancestors = append(ancestors, t) | |||
err := h.walk(walkFn, ancestors) | |||
if err != nil { | |||
return err | |||
} | |||
ancestors = ancestors[:len(ancestors)-1] | |||
} | |||
} | |||
return nil | |||
} | |||
// ---------------------------------------------------------------------------- | |||
// Context | |||
// ---------------------------------------------------------------------------- | |||
// RouteMatch stores information about a matched route. | |||
type RouteMatch struct { | |||
Route *Route | |||
Handler http.Handler | |||
Vars map[string]string | |||
} | |||
type contextKey int | |||
const ( | |||
varsKey contextKey = iota | |||
routeKey | |||
) | |||
// Vars returns the route variables for the current request, if any. | |||
func Vars(r *http.Request) map[string]string { | |||
if rv := contextGet(r, varsKey); rv != nil { | |||
return rv.(map[string]string) | |||
} | |||
return nil | |||
} | |||
// CurrentRoute returns the matched route for the current request, if any. | |||
// This only works when called inside the handler of the matched route | |||
// because the matched route is stored in the request context which is cleared | |||
// after the handler returns, unless the KeepContext option is set on the | |||
// Router. | |||
func CurrentRoute(r *http.Request) *Route { | |||
if rv := contextGet(r, routeKey); rv != nil { | |||
return rv.(*Route) | |||
} | |||
return nil | |||
} | |||
func setVars(r *http.Request, val interface{}) *http.Request { | |||
return contextSet(r, varsKey, val) | |||
} | |||
func setCurrentRoute(r *http.Request, val interface{}) *http.Request { | |||
return contextSet(r, routeKey, val) | |||
} | |||
// ---------------------------------------------------------------------------- | |||
// Helpers | |||
// ---------------------------------------------------------------------------- | |||
// getPath returns the escaped path if possible; doing what URL.EscapedPath() | |||
// which was added in go1.5 does | |||
func getPath(req *http.Request) string { | |||
if req.RequestURI != "" { | |||
// Extract the path from RequestURI (which is escaped unlike URL.Path) | |||
// as detailed here as detailed in https://golang.org/pkg/net/url/#URL | |||
// for < 1.5 server side workaround | |||
// http://localhost/path/here?v=1 -> /path/here | |||
path := req.RequestURI | |||
path = strings.TrimPrefix(path, req.URL.Scheme+`://`) | |||
path = strings.TrimPrefix(path, req.URL.Host) | |||
if i := strings.LastIndex(path, "?"); i > -1 { | |||
path = path[:i] | |||
} | |||
if i := strings.LastIndex(path, "#"); i > -1 { | |||
path = path[:i] | |||
} | |||
return path | |||
} | |||
return req.URL.Path | |||
} | |||
// cleanPath returns the canonical path for p, eliminating . and .. elements. | |||
// Borrowed from the net/http package. | |||
func cleanPath(p string) string { | |||
if p == "" { | |||
return "/" | |||
} | |||
if p[0] != '/' { | |||
p = "/" + p | |||
} | |||
np := path.Clean(p) | |||
// path.Clean removes trailing slash except for root; | |||
// put the trailing slash back if necessary. | |||
if p[len(p)-1] == '/' && np != "/" { | |||
np += "/" | |||
} | |||
return np | |||
} | |||
// uniqueVars returns an error if two slices contain duplicated strings. | |||
func uniqueVars(s1, s2 []string) error { | |||
for _, v1 := range s1 { | |||
for _, v2 := range s2 { | |||
if v1 == v2 { | |||
return fmt.Errorf("mux: duplicated route variable %q", v2) | |||
} | |||
} | |||
} | |||
return nil | |||
} | |||
// checkPairs returns the count of strings passed in, and an error if | |||
// the count is not an even number. | |||
func checkPairs(pairs ...string) (int, error) { | |||
length := len(pairs) | |||
if length%2 != 0 { | |||
return length, fmt.Errorf( | |||
"mux: number of parameters must be multiple of 2, got %v", pairs) | |||
} | |||
return length, nil | |||
} | |||
// mapFromPairsToString converts variadic string parameters to a | |||
// string to string map. | |||
func mapFromPairsToString(pairs ...string) (map[string]string, error) { | |||
length, err := checkPairs(pairs...) | |||
if err != nil { | |||
return nil, err | |||
} | |||
m := make(map[string]string, length/2) | |||
for i := 0; i < length; i += 2 { | |||
m[pairs[i]] = pairs[i+1] | |||
} | |||
return m, nil | |||
} | |||
// mapFromPairsToRegex converts variadic string paramers to a | |||
// string to regex map. | |||
func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) { | |||
length, err := checkPairs(pairs...) | |||
if err != nil { | |||
return nil, err | |||
} | |||
m := make(map[string]*regexp.Regexp, length/2) | |||
for i := 0; i < length; i += 2 { | |||
regex, err := regexp.Compile(pairs[i+1]) | |||
if err != nil { | |||
return nil, err | |||
} | |||
m[pairs[i]] = regex | |||
} | |||
return m, nil | |||
} | |||
// matchInArray returns true if the given string value is in the array. | |||
func matchInArray(arr []string, value string) bool { | |||
for _, v := range arr { | |||
if v == value { | |||
return true | |||
} | |||
} | |||
return false | |||
} | |||
// matchMapWithString returns true if the given key/value pairs exist in a given map. | |||
func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool { | |||
for k, v := range toCheck { | |||
// Check if key exists. | |||
if canonicalKey { | |||
k = http.CanonicalHeaderKey(k) | |||
} | |||
if values := toMatch[k]; values == nil { | |||
return false | |||
} else if v != "" { | |||
// If value was defined as an empty string we only check that the | |||
// key exists. Otherwise we also check for equality. | |||
valueExists := false | |||
for _, value := range values { | |||
if v == value { | |||
valueExists = true | |||
break | |||
} | |||
} | |||
if !valueExists { | |||
return false | |||
} | |||
} | |||
} | |||
return true | |||
} | |||
// matchMapWithRegex returns true if the given key/value pairs exist in a given map compiled against | |||
// the given regex | |||
func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool { | |||
for k, v := range toCheck { | |||
// Check if key exists. | |||
if canonicalKey { | |||
k = http.CanonicalHeaderKey(k) | |||
} | |||
if values := toMatch[k]; values == nil { | |||
return false | |||
} else if v != nil { | |||
// If value was defined as an empty string we only check that the | |||
// key exists. Otherwise we also check for equality. | |||
valueExists := false | |||
for _, value := range values { | |||
if v.MatchString(value) { | |||
valueExists = true | |||
break | |||
} | |||
} | |||
if !valueExists { | |||
return false | |||
} | |||
} | |||
} | |||
return true | |||
} |
@@ -0,0 +1,316 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package mux | |||
import ( | |||
"bytes" | |||
"fmt" | |||
"net/http" | |||
"net/url" | |||
"regexp" | |||
"strconv" | |||
"strings" | |||
) | |||
// newRouteRegexp parses a route template and returns a routeRegexp, | |||
// used to match a host, a path or a query string. | |||
// | |||
// It will extract named variables, assemble a regexp to be matched, create | |||
// a "reverse" template to build URLs and compile regexps to validate variable | |||
// values used in URL building. | |||
// | |||
// Previously we accepted only Python-like identifiers for variable | |||
// names ([a-zA-Z_][a-zA-Z0-9_]*), but currently the only restriction is that | |||
// name and pattern can't be empty, and names can't contain a colon. | |||
func newRouteRegexp(tpl string, matchHost, matchPrefix, matchQuery, strictSlash, useEncodedPath bool) (*routeRegexp, error) { | |||
// Check if it is well-formed. | |||
idxs, errBraces := braceIndices(tpl) | |||
if errBraces != nil { | |||
return nil, errBraces | |||
} | |||
// Backup the original. | |||
template := tpl | |||
// Now let's parse it. | |||
defaultPattern := "[^/]+" | |||
if matchQuery { | |||
defaultPattern = "[^?&]*" | |||
} else if matchHost { | |||
defaultPattern = "[^.]+" | |||
matchPrefix = false | |||
} | |||
// Only match strict slash if not matching | |||
if matchPrefix || matchHost || matchQuery { | |||
strictSlash = false | |||
} | |||
// Set a flag for strictSlash. | |||
endSlash := false | |||
if strictSlash && strings.HasSuffix(tpl, "/") { | |||
tpl = tpl[:len(tpl)-1] | |||
endSlash = true | |||
} | |||
varsN := make([]string, len(idxs)/2) | |||
varsR := make([]*regexp.Regexp, len(idxs)/2) | |||
pattern := bytes.NewBufferString("") | |||
pattern.WriteByte('^') | |||
reverse := bytes.NewBufferString("") | |||
var end int | |||
var err error | |||
for i := 0; i < len(idxs); i += 2 { | |||
// Set all values we are interested in. | |||
raw := tpl[end:idxs[i]] | |||
end = idxs[i+1] | |||
parts := strings.SplitN(tpl[idxs[i]+1:end-1], ":", 2) | |||
name := parts[0] | |||
patt := defaultPattern | |||
if len(parts) == 2 { | |||
patt = parts[1] | |||
} | |||
// Name or pattern can't be empty. | |||
if name == "" || patt == "" { | |||
return nil, fmt.Errorf("mux: missing name or pattern in %q", | |||
tpl[idxs[i]:end]) | |||
} | |||
// Build the regexp pattern. | |||
fmt.Fprintf(pattern, "%s(?P<%s>%s)", regexp.QuoteMeta(raw), varGroupName(i/2), patt) | |||
// Build the reverse template. | |||
fmt.Fprintf(reverse, "%s%%s", raw) | |||
// Append variable name and compiled pattern. | |||
varsN[i/2] = name | |||
varsR[i/2], err = regexp.Compile(fmt.Sprintf("^%s$", patt)) | |||
if err != nil { | |||
return nil, err | |||
} | |||
} | |||
// Add the remaining. | |||
raw := tpl[end:] | |||
pattern.WriteString(regexp.QuoteMeta(raw)) | |||
if strictSlash { | |||
pattern.WriteString("[/]?") | |||
} | |||
if matchQuery { | |||
// Add the default pattern if the query value is empty | |||
if queryVal := strings.SplitN(template, "=", 2)[1]; queryVal == "" { | |||
pattern.WriteString(defaultPattern) | |||
} | |||
} | |||
if !matchPrefix { | |||
pattern.WriteByte('$') | |||
} | |||
reverse.WriteString(raw) | |||
if endSlash { | |||
reverse.WriteByte('/') | |||
} | |||
// Compile full regexp. | |||
reg, errCompile := regexp.Compile(pattern.String()) | |||
if errCompile != nil { | |||
return nil, errCompile | |||
} | |||
// Done! | |||
return &routeRegexp{ | |||
template: template, | |||
matchHost: matchHost, | |||
matchQuery: matchQuery, | |||
strictSlash: strictSlash, | |||
useEncodedPath: useEncodedPath, | |||
regexp: reg, | |||
reverse: reverse.String(), | |||
varsN: varsN, | |||
varsR: varsR, | |||
}, nil | |||
} | |||
// routeRegexp stores a regexp to match a host or path and information to | |||
// collect and validate route variables. | |||
type routeRegexp struct { | |||
// The unmodified template. | |||
template string | |||
// True for host match, false for path or query string match. | |||
matchHost bool | |||
// True for query string match, false for path and host match. | |||
matchQuery bool | |||
// The strictSlash value defined on the route, but disabled if PathPrefix was used. | |||
strictSlash bool | |||
// Determines whether to use encoded path from getPath function or unencoded | |||
// req.URL.Path for path matching | |||
useEncodedPath bool | |||
// Expanded regexp. | |||
regexp *regexp.Regexp | |||
// Reverse template. | |||
reverse string | |||
// Variable names. | |||
varsN []string | |||
// Variable regexps (validators). | |||
varsR []*regexp.Regexp | |||
} | |||
// Match matches the regexp against the URL host or path. | |||
func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool { | |||
if !r.matchHost { | |||
if r.matchQuery { | |||
return r.matchQueryString(req) | |||
} | |||
path := req.URL.Path | |||
if r.useEncodedPath { | |||
path = getPath(req) | |||
} | |||
return r.regexp.MatchString(path) | |||
} | |||
return r.regexp.MatchString(getHost(req)) | |||
} | |||
// url builds a URL part using the given values. | |||
func (r *routeRegexp) url(values map[string]string) (string, error) { | |||
urlValues := make([]interface{}, len(r.varsN)) | |||
for k, v := range r.varsN { | |||
value, ok := values[v] | |||
if !ok { | |||
return "", fmt.Errorf("mux: missing route variable %q", v) | |||
} | |||
urlValues[k] = value | |||
} | |||
rv := fmt.Sprintf(r.reverse, urlValues...) | |||
if !r.regexp.MatchString(rv) { | |||
// The URL is checked against the full regexp, instead of checking | |||
// individual variables. This is faster but to provide a good error | |||
// message, we check individual regexps if the URL doesn't match. | |||
for k, v := range r.varsN { | |||
if !r.varsR[k].MatchString(values[v]) { | |||
return "", fmt.Errorf( | |||
"mux: variable %q doesn't match, expected %q", values[v], | |||
r.varsR[k].String()) | |||
} | |||
} | |||
} | |||
return rv, nil | |||
} | |||
// getURLQuery returns a single query parameter from a request URL. | |||
// For a URL with foo=bar&baz=ding, we return only the relevant key | |||
// value pair for the routeRegexp. | |||
func (r *routeRegexp) getURLQuery(req *http.Request) string { | |||
if !r.matchQuery { | |||
return "" | |||
} | |||
templateKey := strings.SplitN(r.template, "=", 2)[0] | |||
for key, vals := range req.URL.Query() { | |||
if key == templateKey && len(vals) > 0 { | |||
return key + "=" + vals[0] | |||
} | |||
} | |||
return "" | |||
} | |||
func (r *routeRegexp) matchQueryString(req *http.Request) bool { | |||
return r.regexp.MatchString(r.getURLQuery(req)) | |||
} | |||
// braceIndices returns the first level curly brace indices from a string. | |||
// It returns an error in case of unbalanced braces. | |||
func braceIndices(s string) ([]int, error) { | |||
var level, idx int | |||
var idxs []int | |||
for i := 0; i < len(s); i++ { | |||
switch s[i] { | |||
case '{': | |||
if level++; level == 1 { | |||
idx = i | |||
} | |||
case '}': | |||
if level--; level == 0 { | |||
idxs = append(idxs, idx, i+1) | |||
} else if level < 0 { | |||
return nil, fmt.Errorf("mux: unbalanced braces in %q", s) | |||
} | |||
} | |||
} | |||
if level != 0 { | |||
return nil, fmt.Errorf("mux: unbalanced braces in %q", s) | |||
} | |||
return idxs, nil | |||
} | |||
// varGroupName builds a capturing group name for the indexed variable. | |||
func varGroupName(idx int) string { | |||
return "v" + strconv.Itoa(idx) | |||
} | |||
// ---------------------------------------------------------------------------- | |||
// routeRegexpGroup | |||
// ---------------------------------------------------------------------------- | |||
// routeRegexpGroup groups the route matchers that carry variables. | |||
type routeRegexpGroup struct { | |||
host *routeRegexp | |||
path *routeRegexp | |||
queries []*routeRegexp | |||
} | |||
// setMatch extracts the variables from the URL once a route matches. | |||
func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) { | |||
// Store host variables. | |||
if v.host != nil { | |||
host := getHost(req) | |||
matches := v.host.regexp.FindStringSubmatchIndex(host) | |||
if len(matches) > 0 { | |||
extractVars(host, matches, v.host.varsN, m.Vars) | |||
} | |||
} | |||
path := req.URL.Path | |||
if r.useEncodedPath { | |||
path = getPath(req) | |||
} | |||
// Store path variables. | |||
if v.path != nil { | |||
matches := v.path.regexp.FindStringSubmatchIndex(path) | |||
if len(matches) > 0 { | |||
extractVars(path, matches, v.path.varsN, m.Vars) | |||
// Check if we should redirect. | |||
if v.path.strictSlash { | |||
p1 := strings.HasSuffix(path, "/") | |||
p2 := strings.HasSuffix(v.path.template, "/") | |||
if p1 != p2 { | |||
u, _ := url.Parse(req.URL.String()) | |||
if p1 { | |||
u.Path = u.Path[:len(u.Path)-1] | |||
} else { | |||
u.Path += "/" | |||
} | |||
m.Handler = http.RedirectHandler(u.String(), 301) | |||
} | |||
} | |||
} | |||
} | |||
// Store query string variables. | |||
for _, q := range v.queries { | |||
queryURL := q.getURLQuery(req) | |||
matches := q.regexp.FindStringSubmatchIndex(queryURL) | |||
if len(matches) > 0 { | |||
extractVars(queryURL, matches, q.varsN, m.Vars) | |||
} | |||
} | |||
} | |||
// getHost tries its best to return the request host. | |||
func getHost(r *http.Request) string { | |||
if r.URL.IsAbs() { | |||
return r.URL.Host | |||
} | |||
host := r.Host | |||
// Slice off any port information. | |||
if i := strings.Index(host, ":"); i != -1 { | |||
host = host[:i] | |||
} | |||
return host | |||
} | |||
func extractVars(input string, matches []int, names []string, output map[string]string) { | |||
for i, name := range names { | |||
output[name] = input[matches[2*i+2]:matches[2*i+3]] | |||
} | |||
} |
@@ -0,0 +1,636 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package mux | |||
import ( | |||
"errors" | |||
"fmt" | |||
"net/http" | |||
"net/url" | |||
"regexp" | |||
"strings" | |||
) | |||
// Route stores information to match a request and build URLs. | |||
type Route struct { | |||
// Parent where the route was registered (a Router). | |||
parent parentRoute | |||
// Request handler for the route. | |||
handler http.Handler | |||
// List of matchers. | |||
matchers []matcher | |||
// Manager for the variables from host and path. | |||
regexp *routeRegexpGroup | |||
// If true, when the path pattern is "/path/", accessing "/path" will | |||
// redirect to the former and vice versa. | |||
strictSlash bool | |||
// If true, when the path pattern is "/path//to", accessing "/path//to" | |||
// will not redirect | |||
skipClean bool | |||
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to" | |||
useEncodedPath bool | |||
// If true, this route never matches: it is only used to build URLs. | |||
buildOnly bool | |||
// The name used to build URLs. | |||
name string | |||
// Error resulted from building a route. | |||
err error | |||
buildVarsFunc BuildVarsFunc | |||
} | |||
func (r *Route) SkipClean() bool { | |||
return r.skipClean | |||
} | |||
// Match matches the route against the request. | |||
func (r *Route) Match(req *http.Request, match *RouteMatch) bool { | |||
if r.buildOnly || r.err != nil { | |||
return false | |||
} | |||
// Match everything. | |||
for _, m := range r.matchers { | |||
if matched := m.Match(req, match); !matched { | |||
return false | |||
} | |||
} | |||
// Yay, we have a match. Let's collect some info about it. | |||
if match.Route == nil { | |||
match.Route = r | |||
} | |||
if match.Handler == nil { | |||
match.Handler = r.handler | |||
} | |||
if match.Vars == nil { | |||
match.Vars = make(map[string]string) | |||
} | |||
// Set variables. | |||
if r.regexp != nil { | |||
r.regexp.setMatch(req, match, r) | |||
} | |||
return true | |||
} | |||
// ---------------------------------------------------------------------------- | |||
// Route attributes | |||
// ---------------------------------------------------------------------------- | |||
// GetError returns an error resulted from building the route, if any. | |||
func (r *Route) GetError() error { | |||
return r.err | |||
} | |||
// BuildOnly sets the route to never match: it is only used to build URLs. | |||
func (r *Route) BuildOnly() *Route { | |||
r.buildOnly = true | |||
return r | |||
} | |||
// Handler -------------------------------------------------------------------- | |||
// Handler sets a handler for the route. | |||
func (r *Route) Handler(handler http.Handler) *Route { | |||
if r.err == nil { | |||
r.handler = handler | |||
} | |||
return r | |||
} | |||
// HandlerFunc sets a handler function for the route. | |||
func (r *Route) HandlerFunc(f func(http.ResponseWriter, *http.Request)) *Route { | |||
return r.Handler(http.HandlerFunc(f)) | |||
} | |||
// GetHandler returns the handler for the route, if any. | |||
func (r *Route) GetHandler() http.Handler { | |||
return r.handler | |||
} | |||
// Name ----------------------------------------------------------------------- | |||
// Name sets the name for the route, used to build URLs. | |||
// If the name was registered already it will be overwritten. | |||
func (r *Route) Name(name string) *Route { | |||
if r.name != "" { | |||
r.err = fmt.Errorf("mux: route already has name %q, can't set %q", | |||
r.name, name) | |||
} | |||
if r.err == nil { | |||
r.name = name | |||
r.getNamedRoutes()[name] = r | |||
} | |||
return r | |||
} | |||
// GetName returns the name for the route, if any. | |||
func (r *Route) GetName() string { | |||
return r.name | |||
} | |||
// ---------------------------------------------------------------------------- | |||
// Matchers | |||
// ---------------------------------------------------------------------------- | |||
// matcher types try to match a request. | |||
type matcher interface { | |||
Match(*http.Request, *RouteMatch) bool | |||
} | |||
// addMatcher adds a matcher to the route. | |||
func (r *Route) addMatcher(m matcher) *Route { | |||
if r.err == nil { | |||
r.matchers = append(r.matchers, m) | |||
} | |||
return r | |||
} | |||
// addRegexpMatcher adds a host or path matcher and builder to a route. | |||
func (r *Route) addRegexpMatcher(tpl string, matchHost, matchPrefix, matchQuery bool) error { | |||
if r.err != nil { | |||
return r.err | |||
} | |||
r.regexp = r.getRegexpGroup() | |||
if !matchHost && !matchQuery { | |||
if len(tpl) == 0 || tpl[0] != '/' { | |||
return fmt.Errorf("mux: path must start with a slash, got %q", tpl) | |||
} | |||
if r.regexp.path != nil { | |||
tpl = strings.TrimRight(r.regexp.path.template, "/") + tpl | |||
} | |||
} | |||
rr, err := newRouteRegexp(tpl, matchHost, matchPrefix, matchQuery, r.strictSlash, r.useEncodedPath) | |||
if err != nil { | |||
return err | |||
} | |||
for _, q := range r.regexp.queries { | |||
if err = uniqueVars(rr.varsN, q.varsN); err != nil { | |||
return err | |||
} | |||
} | |||
if matchHost { | |||
if r.regexp.path != nil { | |||
if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil { | |||
return err | |||
} | |||
} | |||
r.regexp.host = rr | |||
} else { | |||
if r.regexp.host != nil { | |||
if err = uniqueVars(rr.varsN, r.regexp.host.varsN); err != nil { | |||
return err | |||
} | |||
} | |||
if matchQuery { | |||
r.regexp.queries = append(r.regexp.queries, rr) | |||
} else { | |||
r.regexp.path = rr | |||
} | |||
} | |||
r.addMatcher(rr) | |||
return nil | |||
} | |||
// Headers -------------------------------------------------------------------- | |||
// headerMatcher matches the request against header values. | |||
type headerMatcher map[string]string | |||
func (m headerMatcher) Match(r *http.Request, match *RouteMatch) bool { | |||
return matchMapWithString(m, r.Header, true) | |||
} | |||
// Headers adds a matcher for request header values. | |||
// It accepts a sequence of key/value pairs to be matched. For example: | |||
// | |||
// r := mux.NewRouter() | |||
// r.Headers("Content-Type", "application/json", | |||
// "X-Requested-With", "XMLHttpRequest") | |||
// | |||
// The above route will only match if both request header values match. | |||
// If the value is an empty string, it will match any value if the key is set. | |||
func (r *Route) Headers(pairs ...string) *Route { | |||
if r.err == nil { | |||
var headers map[string]string | |||
headers, r.err = mapFromPairsToString(pairs...) | |||
return r.addMatcher(headerMatcher(headers)) | |||
} | |||
return r | |||
} | |||
// headerRegexMatcher matches the request against the route given a regex for the header | |||
type headerRegexMatcher map[string]*regexp.Regexp | |||
func (m headerRegexMatcher) Match(r *http.Request, match *RouteMatch) bool { | |||
return matchMapWithRegex(m, r.Header, true) | |||
} | |||
// HeadersRegexp accepts a sequence of key/value pairs, where the value has regex | |||
// support. For example: | |||
// | |||
// r := mux.NewRouter() | |||
// r.HeadersRegexp("Content-Type", "application/(text|json)", | |||
// "X-Requested-With", "XMLHttpRequest") | |||
// | |||
// The above route will only match if both the request header matches both regular expressions. | |||
// It the value is an empty string, it will match any value if the key is set. | |||
func (r *Route) HeadersRegexp(pairs ...string) *Route { | |||
if r.err == nil { | |||
var headers map[string]*regexp.Regexp | |||
headers, r.err = mapFromPairsToRegex(pairs...) | |||
return r.addMatcher(headerRegexMatcher(headers)) | |||
} | |||
return r | |||
} | |||
// Host ----------------------------------------------------------------------- | |||
// Host adds a matcher for the URL host. | |||
// It accepts a template with zero or more URL variables enclosed by {}. | |||
// Variables can define an optional regexp pattern to be matched: | |||
// | |||
// - {name} matches anything until the next dot. | |||
// | |||
// - {name:pattern} matches the given regexp pattern. | |||
// | |||
// For example: | |||
// | |||
// r := mux.NewRouter() | |||
// r.Host("www.example.com") | |||
// r.Host("{subdomain}.domain.com") | |||
// r.Host("{subdomain:[a-z]+}.domain.com") | |||
// | |||
// Variable names must be unique in a given route. They can be retrieved | |||
// calling mux.Vars(request). | |||
func (r *Route) Host(tpl string) *Route { | |||
r.err = r.addRegexpMatcher(tpl, true, false, false) | |||
return r | |||
} | |||
// MatcherFunc ---------------------------------------------------------------- | |||
// MatcherFunc is the function signature used by custom matchers. | |||
type MatcherFunc func(*http.Request, *RouteMatch) bool | |||
// Match returns the match for a given request. | |||
func (m MatcherFunc) Match(r *http.Request, match *RouteMatch) bool { | |||
return m(r, match) | |||
} | |||
// MatcherFunc adds a custom function to be used as request matcher. | |||
func (r *Route) MatcherFunc(f MatcherFunc) *Route { | |||
return r.addMatcher(f) | |||
} | |||
// Methods -------------------------------------------------------------------- | |||
// methodMatcher matches the request against HTTP methods. | |||
type methodMatcher []string | |||
func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool { | |||
return matchInArray(m, r.Method) | |||
} | |||
// Methods adds a matcher for HTTP methods. | |||
// It accepts a sequence of one or more methods to be matched, e.g.: | |||
// "GET", "POST", "PUT". | |||
func (r *Route) Methods(methods ...string) *Route { | |||
for k, v := range methods { | |||
methods[k] = strings.ToUpper(v) | |||
} | |||
return r.addMatcher(methodMatcher(methods)) | |||
} | |||
// Path ----------------------------------------------------------------------- | |||
// Path adds a matcher for the URL path. | |||
// It accepts a template with zero or more URL variables enclosed by {}. The | |||
// template must start with a "/". | |||
// Variables can define an optional regexp pattern to be matched: | |||
// | |||
// - {name} matches anything until the next slash. | |||
// | |||
// - {name:pattern} matches the given regexp pattern. | |||
// | |||
// For example: | |||
// | |||
// r := mux.NewRouter() | |||
// r.Path("/products/").Handler(ProductsHandler) | |||
// r.Path("/products/{key}").Handler(ProductsHandler) | |||
// r.Path("/articles/{category}/{id:[0-9]+}"). | |||
// Handler(ArticleHandler) | |||
// | |||
// Variable names must be unique in a given route. They can be retrieved | |||
// calling mux.Vars(request). | |||
func (r *Route) Path(tpl string) *Route { | |||
r.err = r.addRegexpMatcher(tpl, false, false, false) | |||
return r | |||
} | |||
// PathPrefix ----------------------------------------------------------------- | |||
// PathPrefix adds a matcher for the URL path prefix. This matches if the given | |||
// template is a prefix of the full URL path. See Route.Path() for details on | |||
// the tpl argument. | |||
// | |||
// Note that it does not treat slashes specially ("/foobar/" will be matched by | |||
// the prefix "/foo") so you may want to use a trailing slash here. | |||
// | |||
// Also note that the setting of Router.StrictSlash() has no effect on routes | |||
// with a PathPrefix matcher. | |||
func (r *Route) PathPrefix(tpl string) *Route { | |||
r.err = r.addRegexpMatcher(tpl, false, true, false) | |||
return r | |||
} | |||
// Query ---------------------------------------------------------------------- | |||
// Queries adds a matcher for URL query values. | |||
// It accepts a sequence of key/value pairs. Values may define variables. | |||
// For example: | |||
// | |||
// r := mux.NewRouter() | |||
// r.Queries("foo", "bar", "id", "{id:[0-9]+}") | |||
// | |||
// The above route will only match if the URL contains the defined queries | |||
// values, e.g.: ?foo=bar&id=42. | |||
// | |||
// It the value is an empty string, it will match any value if the key is set. | |||
// | |||
// Variables can define an optional regexp pattern to be matched: | |||
// | |||
// - {name} matches anything until the next slash. | |||
// | |||
// - {name:pattern} matches the given regexp pattern. | |||
func (r *Route) Queries(pairs ...string) *Route { | |||
length := len(pairs) | |||
if length%2 != 0 { | |||
r.err = fmt.Errorf( | |||
"mux: number of parameters must be multiple of 2, got %v", pairs) | |||
return nil | |||
} | |||
for i := 0; i < length; i += 2 { | |||
if r.err = r.addRegexpMatcher(pairs[i]+"="+pairs[i+1], false, false, true); r.err != nil { | |||
return r | |||
} | |||
} | |||
return r | |||
} | |||
// Schemes -------------------------------------------------------------------- | |||
// schemeMatcher matches the request against URL schemes. | |||
type schemeMatcher []string | |||
func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool { | |||
return matchInArray(m, r.URL.Scheme) | |||
} | |||
// Schemes adds a matcher for URL schemes. | |||
// It accepts a sequence of schemes to be matched, e.g.: "http", "https". | |||
func (r *Route) Schemes(schemes ...string) *Route { | |||
for k, v := range schemes { | |||
schemes[k] = strings.ToLower(v) | |||
} | |||
return r.addMatcher(schemeMatcher(schemes)) | |||
} | |||
// BuildVarsFunc -------------------------------------------------------------- | |||
// BuildVarsFunc is the function signature used by custom build variable | |||
// functions (which can modify route variables before a route's URL is built). | |||
type BuildVarsFunc func(map[string]string) map[string]string | |||
// BuildVarsFunc adds a custom function to be used to modify build variables | |||
// before a route's URL is built. | |||
func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route { | |||
r.buildVarsFunc = f | |||
return r | |||
} | |||
// Subrouter ------------------------------------------------------------------ | |||
// Subrouter creates a subrouter for the route. | |||
// | |||
// It will test the inner routes only if the parent route matched. For example: | |||
// | |||
// r := mux.NewRouter() | |||
// s := r.Host("www.example.com").Subrouter() | |||
// s.HandleFunc("/products/", ProductsHandler) | |||
// s.HandleFunc("/products/{key}", ProductHandler) | |||
// s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler) | |||
// | |||
// Here, the routes registered in the subrouter won't be tested if the host | |||
// doesn't match. | |||
func (r *Route) Subrouter() *Router { | |||
router := &Router{parent: r, strictSlash: r.strictSlash} | |||
r.addMatcher(router) | |||
return router | |||
} | |||
// ---------------------------------------------------------------------------- | |||
// URL building | |||
// ---------------------------------------------------------------------------- | |||
// URL builds a URL for the route. | |||
// | |||
// It accepts a sequence of key/value pairs for the route variables. For | |||
// example, given this route: | |||
// | |||
// r := mux.NewRouter() | |||
// r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). | |||
// Name("article") | |||
// | |||
// ...a URL for it can be built using: | |||
// | |||
// url, err := r.Get("article").URL("category", "technology", "id", "42") | |||
// | |||
// ...which will return an url.URL with the following path: | |||
// | |||
// "/articles/technology/42" | |||
// | |||
// This also works for host variables: | |||
// | |||
// r := mux.NewRouter() | |||
// r.Host("{subdomain}.domain.com"). | |||
// HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). | |||
// Name("article") | |||
// | |||
// // url.String() will be "http://news.domain.com/articles/technology/42" | |||
// url, err := r.Get("article").URL("subdomain", "news", | |||
// "category", "technology", | |||
// "id", "42") | |||
// | |||
// All variables defined in the route are required, and their values must | |||
// conform to the corresponding patterns. | |||
func (r *Route) URL(pairs ...string) (*url.URL, error) { | |||
if r.err != nil { | |||
return nil, r.err | |||
} | |||
if r.regexp == nil { | |||
return nil, errors.New("mux: route doesn't have a host or path") | |||
} | |||
values, err := r.prepareVars(pairs...) | |||
if err != nil { | |||
return nil, err | |||
} | |||
var scheme, host, path string | |||
if r.regexp.host != nil { | |||
// Set a default scheme. | |||
scheme = "http" | |||
if host, err = r.regexp.host.url(values); err != nil { | |||
return nil, err | |||
} | |||
} | |||
if r.regexp.path != nil { | |||
if path, err = r.regexp.path.url(values); err != nil { | |||
return nil, err | |||
} | |||
} | |||
return &url.URL{ | |||
Scheme: scheme, | |||
Host: host, | |||
Path: path, | |||
}, nil | |||
} | |||
// URLHost builds the host part of the URL for a route. See Route.URL(). | |||
// | |||
// The route must have a host defined. | |||
func (r *Route) URLHost(pairs ...string) (*url.URL, error) { | |||
if r.err != nil { | |||
return nil, r.err | |||
} | |||
if r.regexp == nil || r.regexp.host == nil { | |||
return nil, errors.New("mux: route doesn't have a host") | |||
} | |||
values, err := r.prepareVars(pairs...) | |||
if err != nil { | |||
return nil, err | |||
} | |||
host, err := r.regexp.host.url(values) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return &url.URL{ | |||
Scheme: "http", | |||
Host: host, | |||
}, nil | |||
} | |||
// URLPath builds the path part of the URL for a route. See Route.URL(). | |||
// | |||
// The route must have a path defined. | |||
func (r *Route) URLPath(pairs ...string) (*url.URL, error) { | |||
if r.err != nil { | |||
return nil, r.err | |||
} | |||
if r.regexp == nil || r.regexp.path == nil { | |||
return nil, errors.New("mux: route doesn't have a path") | |||
} | |||
values, err := r.prepareVars(pairs...) | |||
if err != nil { | |||
return nil, err | |||
} | |||
path, err := r.regexp.path.url(values) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return &url.URL{ | |||
Path: path, | |||
}, nil | |||
} | |||
// GetPathTemplate returns the template used to build the | |||
// route match. | |||
// This is useful for building simple REST API documentation and for instrumentation | |||
// against third-party services. | |||
// An error will be returned if the route does not define a path. | |||
func (r *Route) GetPathTemplate() (string, error) { | |||
if r.err != nil { | |||
return "", r.err | |||
} | |||
if r.regexp == nil || r.regexp.path == nil { | |||
return "", errors.New("mux: route doesn't have a path") | |||
} | |||
return r.regexp.path.template, nil | |||
} | |||
// GetHostTemplate returns the template used to build the | |||
// route match. | |||
// This is useful for building simple REST API documentation and for instrumentation | |||
// against third-party services. | |||
// An error will be returned if the route does not define a host. | |||
func (r *Route) GetHostTemplate() (string, error) { | |||
if r.err != nil { | |||
return "", r.err | |||
} | |||
if r.regexp == nil || r.regexp.host == nil { | |||
return "", errors.New("mux: route doesn't have a host") | |||
} | |||
return r.regexp.host.template, nil | |||
} | |||
// prepareVars converts the route variable pairs into a map. If the route has a | |||
// BuildVarsFunc, it is invoked. | |||
func (r *Route) prepareVars(pairs ...string) (map[string]string, error) { | |||
m, err := mapFromPairsToString(pairs...) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return r.buildVars(m), nil | |||
} | |||
func (r *Route) buildVars(m map[string]string) map[string]string { | |||
if r.parent != nil { | |||
m = r.parent.buildVars(m) | |||
} | |||
if r.buildVarsFunc != nil { | |||
m = r.buildVarsFunc(m) | |||
} | |||
return m | |||
} | |||
// ---------------------------------------------------------------------------- | |||
// parentRoute | |||
// ---------------------------------------------------------------------------- | |||
// parentRoute allows routes to know about parent host and path definitions. | |||
type parentRoute interface { | |||
getNamedRoutes() map[string]*Route | |||
getRegexpGroup() *routeRegexpGroup | |||
buildVars(map[string]string) map[string]string | |||
} | |||
// getNamedRoutes returns the map where named routes are registered. | |||
func (r *Route) getNamedRoutes() map[string]*Route { | |||
if r.parent == nil { | |||
// During tests router is not always set. | |||
r.parent = NewRouter() | |||
} | |||
return r.parent.getNamedRoutes() | |||
} | |||
// getRegexpGroup returns regexp definitions from this route. | |||
func (r *Route) getRegexpGroup() *routeRegexpGroup { | |||
if r.regexp == nil { | |||
if r.parent == nil { | |||
// During tests router is not always set. | |||
r.parent = NewRouter() | |||
} | |||
regexp := r.parent.getRegexpGroup() | |||
if regexp == nil { | |||
r.regexp = new(routeRegexpGroup) | |||
} else { | |||
// Copy. | |||
r.regexp = &routeRegexpGroup{ | |||
host: regexp.host, | |||
path: regexp.path, | |||
queries: regexp.queries, | |||
} | |||
} | |||
} | |||
return r.regexp | |||
} |
@@ -0,0 +1,27 @@ | |||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
* Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
* Redistributions in binary form must reproduce the above | |||
copyright notice, this list of conditions and the following disclaimer | |||
in the documentation and/or other materials provided with the | |||
distribution. | |||
* Neither the name of Google Inc. nor the names of its | |||
contributors may be used to endorse or promote products derived from | |||
this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
@@ -0,0 +1,78 @@ | |||
securecookie | |||
============ | |||
[![GoDoc](https://godoc.org/github.com/gorilla/securecookie?status.svg)](https://godoc.org/github.com/gorilla/securecookie) [![Build Status](https://travis-ci.org/gorilla/securecookie.png?branch=master)](https://travis-ci.org/gorilla/securecookie) | |||
securecookie encodes and decodes authenticated and optionally encrypted | |||
cookie values. | |||
Secure cookies can't be forged, because their values are validated using HMAC. | |||
When encrypted, the content is also inaccessible to malicious eyes. It is still | |||
recommended that sensitive data not be stored in cookies, and that HTTPS be used | |||
to prevent cookie [replay attacks](https://en.wikipedia.org/wiki/Replay_attack). | |||
## Examples | |||
To use it, first create a new SecureCookie instance: | |||
```go | |||
// Hash keys should be at least 32 bytes long | |||
var hashKey = []byte("very-secret") | |||
// Block keys should be 16 bytes (AES-128) or 32 bytes (AES-256) long. | |||
// Shorter keys may weaken the encryption used. | |||
var blockKey = []byte("a-lot-secret") | |||
var s = securecookie.New(hashKey, blockKey) | |||
``` | |||
The hashKey is required, used to authenticate the cookie value using HMAC. | |||
It is recommended to use a key with 32 or 64 bytes. | |||
The blockKey is optional, used to encrypt the cookie value -- set it to nil | |||
to not use encryption. If set, the length must correspond to the block size | |||
of the encryption algorithm. For AES, used by default, valid lengths are | |||
16, 24, or 32 bytes to select AES-128, AES-192, or AES-256. | |||
Strong keys can be created using the convenience function GenerateRandomKey(). | |||
Once a SecureCookie instance is set, use it to encode a cookie value: | |||
```go | |||
func SetCookieHandler(w http.ResponseWriter, r *http.Request) { | |||
value := map[string]string{ | |||
"foo": "bar", | |||
} | |||
if encoded, err := s.Encode("cookie-name", value); err == nil { | |||
cookie := &http.Cookie{ | |||
Name: "cookie-name", | |||
Value: encoded, | |||
Path: "/", | |||
Secure: true, | |||
HttpOnly: true, | |||
} | |||
http.SetCookie(w, cookie) | |||
} | |||
} | |||
``` | |||
Later, use the same SecureCookie instance to decode and validate a cookie | |||
value: | |||
```go | |||
func ReadCookieHandler(w http.ResponseWriter, r *http.Request) { | |||
if cookie, err := r.Cookie("cookie-name"); err == nil { | |||
value := make(map[string]string) | |||
if err = s2.Decode("cookie-name", cookie.Value, &value); err == nil { | |||
fmt.Fprintf(w, "The value of foo is %q", value["foo"]) | |||
} | |||
} | |||
} | |||
``` | |||
We stored a map[string]string, but secure cookies can hold any value that | |||
can be encoded using `encoding/gob`. To store custom types, they must be | |||
registered first using gob.Register(). For basic types this is not needed; | |||
it works out of the box. An optional JSON encoder that uses `encoding/json` is | |||
available for types compatible with JSON. | |||
## License | |||
BSD licensed. See the LICENSE file for details. |
@@ -0,0 +1,61 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
/* | |||
Package securecookie encodes and decodes authenticated and optionally | |||
encrypted cookie values. | |||
Secure cookies can't be forged, because their values are validated using HMAC. | |||
When encrypted, the content is also inaccessible to malicious eyes. | |||
To use it, first create a new SecureCookie instance: | |||
var hashKey = []byte("very-secret") | |||
var blockKey = []byte("a-lot-secret") | |||
var s = securecookie.New(hashKey, blockKey) | |||
The hashKey is required, used to authenticate the cookie value using HMAC. | |||
It is recommended to use a key with 32 or 64 bytes. | |||
The blockKey is optional, used to encrypt the cookie value -- set it to nil | |||
to not use encryption. If set, the length must correspond to the block size | |||
of the encryption algorithm. For AES, used by default, valid lengths are | |||
16, 24, or 32 bytes to select AES-128, AES-192, or AES-256. | |||
Strong keys can be created using the convenience function GenerateRandomKey(). | |||
Once a SecureCookie instance is set, use it to encode a cookie value: | |||
func SetCookieHandler(w http.ResponseWriter, r *http.Request) { | |||
value := map[string]string{ | |||
"foo": "bar", | |||
} | |||
if encoded, err := s.Encode("cookie-name", value); err == nil { | |||
cookie := &http.Cookie{ | |||
Name: "cookie-name", | |||
Value: encoded, | |||
Path: "/", | |||
} | |||
http.SetCookie(w, cookie) | |||
} | |||
} | |||
Later, use the same SecureCookie instance to decode and validate a cookie | |||
value: | |||
func ReadCookieHandler(w http.ResponseWriter, r *http.Request) { | |||
if cookie, err := r.Cookie("cookie-name"); err == nil { | |||
value := make(map[string]string) | |||
if err = s2.Decode("cookie-name", cookie.Value, &value); err == nil { | |||
fmt.Fprintf(w, "The value of foo is %q", value["foo"]) | |||
} | |||
} | |||
} | |||
We stored a map[string]string, but secure cookies can hold any value that | |||
can be encoded using encoding/gob. To store custom types, they must be | |||
registered first using gob.Register(). For basic types this is not needed; | |||
it works out of the box. | |||
*/ | |||
package securecookie |
@@ -0,0 +1,25 @@ | |||
// +build gofuzz | |||
package securecookie | |||
var hashKey = []byte("very-secret12345") | |||
var blockKey = []byte("a-lot-secret1234") | |||
var s = New(hashKey, blockKey) | |||
type Cookie struct { | |||
B bool | |||
I int | |||
S string | |||
} | |||
func Fuzz(data []byte) int { | |||
datas := string(data) | |||
var c Cookie | |||
if err := s.Decode("fuzz", datas, &c); err != nil { | |||
return 0 | |||
} | |||
if _, err := s.Encode("fuzz", c); err != nil { | |||
panic(err) | |||
} | |||
return 1 | |||
} |
@@ -0,0 +1,646 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package securecookie | |||
import ( | |||
"bytes" | |||
"crypto/aes" | |||
"crypto/cipher" | |||
"crypto/hmac" | |||
"crypto/rand" | |||
"crypto/sha256" | |||
"crypto/subtle" | |||
"encoding/base64" | |||
"encoding/gob" | |||
"encoding/json" | |||
"fmt" | |||
"hash" | |||
"io" | |||
"strconv" | |||
"strings" | |||
"time" | |||
) | |||
// Error is the interface of all errors returned by functions in this library. | |||
type Error interface { | |||
error | |||
// IsUsage returns true for errors indicating the client code probably | |||
// uses this library incorrectly. For example, the client may have | |||
// failed to provide a valid hash key, or may have failed to configure | |||
// the Serializer adequately for encoding value. | |||
IsUsage() bool | |||
// IsDecode returns true for errors indicating that a cookie could not | |||
// be decoded and validated. Since cookies are usually untrusted | |||
// user-provided input, errors of this type should be expected. | |||
// Usually, the proper action is simply to reject the request. | |||
IsDecode() bool | |||
// IsInternal returns true for unexpected errors occurring in the | |||
// securecookie implementation. | |||
IsInternal() bool | |||
// Cause, if it returns a non-nil value, indicates that this error was | |||
// propagated from some underlying library. If this method returns nil, | |||
// this error was raised directly by this library. | |||
// | |||
// Cause is provided principally for debugging/logging purposes; it is | |||
// rare that application logic should perform meaningfully different | |||
// logic based on Cause. See, for example, the caveats described on | |||
// (MultiError).Cause(). | |||
Cause() error | |||
} | |||
// errorType is a bitmask giving the error type(s) of an cookieError value. | |||
type errorType int | |||
const ( | |||
usageError = errorType(1 << iota) | |||
decodeError | |||
internalError | |||
) | |||
type cookieError struct { | |||
typ errorType | |||
msg string | |||
cause error | |||
} | |||
func (e cookieError) IsUsage() bool { return (e.typ & usageError) != 0 } | |||
func (e cookieError) IsDecode() bool { return (e.typ & decodeError) != 0 } | |||
func (e cookieError) IsInternal() bool { return (e.typ & internalError) != 0 } | |||
func (e cookieError) Cause() error { return e.cause } | |||
func (e cookieError) Error() string { | |||
parts := []string{"securecookie: "} | |||
if e.msg == "" { | |||
parts = append(parts, "error") | |||
} else { | |||
parts = append(parts, e.msg) | |||
} | |||
if c := e.Cause(); c != nil { | |||
parts = append(parts, " - caused by: ", c.Error()) | |||
} | |||
return strings.Join(parts, "") | |||
} | |||
var ( | |||
errGeneratingIV = cookieError{typ: internalError, msg: "failed to generate random iv"} | |||
errNoCodecs = cookieError{typ: usageError, msg: "no codecs provided"} | |||
errHashKeyNotSet = cookieError{typ: usageError, msg: "hash key is not set"} | |||
errBlockKeyNotSet = cookieError{typ: usageError, msg: "block key is not set"} | |||
errEncodedValueTooLong = cookieError{typ: usageError, msg: "the value is too long"} | |||
errValueToDecodeTooLong = cookieError{typ: decodeError, msg: "the value is too long"} | |||
errTimestampInvalid = cookieError{typ: decodeError, msg: "invalid timestamp"} | |||
errTimestampTooNew = cookieError{typ: decodeError, msg: "timestamp is too new"} | |||
errTimestampExpired = cookieError{typ: decodeError, msg: "expired timestamp"} | |||
errDecryptionFailed = cookieError{typ: decodeError, msg: "the value could not be decrypted"} | |||
errValueNotByte = cookieError{typ: decodeError, msg: "value not a []byte."} | |||
errValueNotBytePtr = cookieError{typ: decodeError, msg: "value not a pointer to []byte."} | |||
// ErrMacInvalid indicates that cookie decoding failed because the HMAC | |||
// could not be extracted and verified. Direct use of this error | |||
// variable is deprecated; it is public only for legacy compatibility, | |||
// and may be privatized in the future, as it is rarely useful to | |||
// distinguish between this error and other Error implementations. | |||
ErrMacInvalid = cookieError{typ: decodeError, msg: "the value is not valid"} | |||
) | |||
// Codec defines an interface to encode and decode cookie values. | |||
type Codec interface { | |||
Encode(name string, value interface{}) (string, error) | |||
Decode(name, value string, dst interface{}) error | |||
} | |||
// New returns a new SecureCookie. | |||
// | |||
// hashKey is required, used to authenticate values using HMAC. Create it using | |||
// GenerateRandomKey(). It is recommended to use a key with 32 or 64 bytes. | |||
// | |||
// blockKey is optional, used to encrypt values. Create it using | |||
// GenerateRandomKey(). The key length must correspond to the block size | |||
// of the encryption algorithm. For AES, used by default, valid lengths are | |||
// 16, 24, or 32 bytes to select AES-128, AES-192, or AES-256. | |||
// The default encoder used for cookie serialization is encoding/gob. | |||
// | |||
// Note that keys created using GenerateRandomKey() are not automatically | |||
// persisted. New keys will be created when the application is restarted, and | |||
// previously issued cookies will not be able to be decoded. | |||
func New(hashKey, blockKey []byte) *SecureCookie { | |||
s := &SecureCookie{ | |||
hashKey: hashKey, | |||
blockKey: blockKey, | |||
hashFunc: sha256.New, | |||
maxAge: 86400 * 30, | |||
maxLength: 4096, | |||
sz: GobEncoder{}, | |||
} | |||
if hashKey == nil { | |||
s.err = errHashKeyNotSet | |||
} | |||
if blockKey != nil { | |||
s.BlockFunc(aes.NewCipher) | |||
} | |||
return s | |||
} | |||
// SecureCookie encodes and decodes authenticated and optionally encrypted | |||
// cookie values. | |||
type SecureCookie struct { | |||
hashKey []byte | |||
hashFunc func() hash.Hash | |||
blockKey []byte | |||
block cipher.Block | |||
maxLength int | |||
maxAge int64 | |||
minAge int64 | |||
err error | |||
sz Serializer | |||
// For testing purposes, the function that returns the current timestamp. | |||
// If not set, it will use time.Now().UTC().Unix(). | |||
timeFunc func() int64 | |||
} | |||
// Serializer provides an interface for providing custom serializers for cookie | |||
// values. | |||
type Serializer interface { | |||
Serialize(src interface{}) ([]byte, error) | |||
Deserialize(src []byte, dst interface{}) error | |||
} | |||
// GobEncoder encodes cookie values using encoding/gob. This is the simplest | |||
// encoder and can handle complex types via gob.Register. | |||
type GobEncoder struct{} | |||
// JSONEncoder encodes cookie values using encoding/json. Users who wish to | |||
// encode complex types need to satisfy the json.Marshaller and | |||
// json.Unmarshaller interfaces. | |||
type JSONEncoder struct{} | |||
// NopEncoder does not encode cookie values, and instead simply accepts a []byte | |||
// (as an interface{}) and returns a []byte. This is particularly useful when | |||
// you encoding an object upstream and do not wish to re-encode it. | |||
type NopEncoder struct{} | |||
// MaxLength restricts the maximum length, in bytes, for the cookie value. | |||
// | |||
// Default is 4096, which is the maximum value accepted by Internet Explorer. | |||
func (s *SecureCookie) MaxLength(value int) *SecureCookie { | |||
s.maxLength = value | |||
return s | |||
} | |||
// MaxAge restricts the maximum age, in seconds, for the cookie value. | |||
// | |||
// Default is 86400 * 30. Set it to 0 for no restriction. | |||
func (s *SecureCookie) MaxAge(value int) *SecureCookie { | |||
s.maxAge = int64(value) | |||
return s | |||
} | |||
// MinAge restricts the minimum age, in seconds, for the cookie value. | |||
// | |||
// Default is 0 (no restriction). | |||
func (s *SecureCookie) MinAge(value int) *SecureCookie { | |||
s.minAge = int64(value) | |||
return s | |||
} | |||
// HashFunc sets the hash function used to create HMAC. | |||
// | |||
// Default is crypto/sha256.New. | |||
func (s *SecureCookie) HashFunc(f func() hash.Hash) *SecureCookie { | |||
s.hashFunc = f | |||
return s | |||
} | |||
// BlockFunc sets the encryption function used to create a cipher.Block. | |||
// | |||
// Default is crypto/aes.New. | |||
func (s *SecureCookie) BlockFunc(f func([]byte) (cipher.Block, error)) *SecureCookie { | |||
if s.blockKey == nil { | |||
s.err = errBlockKeyNotSet | |||
} else if block, err := f(s.blockKey); err == nil { | |||
s.block = block | |||
} else { | |||
s.err = cookieError{cause: err, typ: usageError} | |||
} | |||
return s | |||
} | |||
// Encoding sets the encoding/serialization method for cookies. | |||
// | |||
// Default is encoding/gob. To encode special structures using encoding/gob, | |||
// they must be registered first using gob.Register(). | |||
func (s *SecureCookie) SetSerializer(sz Serializer) *SecureCookie { | |||
s.sz = sz | |||
return s | |||
} | |||
// Encode encodes a cookie value. | |||
// | |||
// It serializes, optionally encrypts, signs with a message authentication code, | |||
// and finally encodes the value. | |||
// | |||
// The name argument is the cookie name. It is stored with the encoded value. | |||
// The value argument is the value to be encoded. It can be any value that can | |||
// be encoded using the currently selected serializer; see SetSerializer(). | |||
// | |||
// It is the client's responsibility to ensure that value, when encoded using | |||
// the current serialization/encryption settings on s and then base64-encoded, | |||
// is shorter than the maximum permissible length. | |||
func (s *SecureCookie) Encode(name string, value interface{}) (string, error) { | |||
if s.err != nil { | |||
return "", s.err | |||
} | |||
if s.hashKey == nil { | |||
s.err = errHashKeyNotSet | |||
return "", s.err | |||
} | |||
var err error | |||
var b []byte | |||
// 1. Serialize. | |||
if b, err = s.sz.Serialize(value); err != nil { | |||
return "", cookieError{cause: err, typ: usageError} | |||
} | |||
// 2. Encrypt (optional). | |||
if s.block != nil { | |||
if b, err = encrypt(s.block, b); err != nil { | |||
return "", cookieError{cause: err, typ: usageError} | |||
} | |||
} | |||
b = encode(b) | |||
// 3. Create MAC for "name|date|value". Extra pipe to be used later. | |||
b = []byte(fmt.Sprintf("%s|%d|%s|", name, s.timestamp(), b)) | |||
mac := createMac(hmac.New(s.hashFunc, s.hashKey), b[:len(b)-1]) | |||
// Append mac, remove name. | |||
b = append(b, mac...)[len(name)+1:] | |||
// 4. Encode to base64. | |||
b = encode(b) | |||
// 5. Check length. | |||
if s.maxLength != 0 && len(b) > s.maxLength { | |||
return "", errEncodedValueTooLong | |||
} | |||
// Done. | |||
return string(b), nil | |||
} | |||
// Decode decodes a cookie value. | |||
// | |||
// It decodes, verifies a message authentication code, optionally decrypts and | |||
// finally deserializes the value. | |||
// | |||
// The name argument is the cookie name. It must be the same name used when | |||
// it was stored. The value argument is the encoded cookie value. The dst | |||
// argument is where the cookie will be decoded. It must be a pointer. | |||
func (s *SecureCookie) Decode(name, value string, dst interface{}) error { | |||
if s.err != nil { | |||
return s.err | |||
} | |||
if s.hashKey == nil { | |||
s.err = errHashKeyNotSet | |||
return s.err | |||
} | |||
// 1. Check length. | |||
if s.maxLength != 0 && len(value) > s.maxLength { | |||
return errValueToDecodeTooLong | |||
} | |||
// 2. Decode from base64. | |||
b, err := decode([]byte(value)) | |||
if err != nil { | |||
return err | |||
} | |||
// 3. Verify MAC. Value is "date|value|mac". | |||
parts := bytes.SplitN(b, []byte("|"), 3) | |||
if len(parts) != 3 { | |||
return ErrMacInvalid | |||
} | |||
h := hmac.New(s.hashFunc, s.hashKey) | |||
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])-1]...) | |||
if err = verifyMac(h, b, parts[2]); err != nil { | |||
return err | |||
} | |||
// 4. Verify date ranges. | |||
var t1 int64 | |||
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { | |||
return errTimestampInvalid | |||
} | |||
t2 := s.timestamp() | |||
if s.minAge != 0 && t1 > t2-s.minAge { | |||
return errTimestampTooNew | |||
} | |||
if s.maxAge != 0 && t1 < t2-s.maxAge { | |||
return errTimestampExpired | |||
} | |||
// 5. Decrypt (optional). | |||
b, err = decode(parts[1]) | |||
if err != nil { | |||
return err | |||
} | |||
if s.block != nil { | |||
if b, err = decrypt(s.block, b); err != nil { | |||
return err | |||
} | |||
} | |||
// 6. Deserialize. | |||
if err = s.sz.Deserialize(b, dst); err != nil { | |||
return cookieError{cause: err, typ: decodeError} | |||
} | |||
// Done. | |||
return nil | |||
} | |||
// timestamp returns the current timestamp, in seconds. | |||
// | |||
// For testing purposes, the function that generates the timestamp can be | |||
// overridden. If not set, it will return time.Now().UTC().Unix(). | |||
func (s *SecureCookie) timestamp() int64 { | |||
if s.timeFunc == nil { | |||
return time.Now().UTC().Unix() | |||
} | |||
return s.timeFunc() | |||
} | |||
// Authentication ------------------------------------------------------------- | |||
// createMac creates a message authentication code (MAC). | |||
func createMac(h hash.Hash, value []byte) []byte { | |||
h.Write(value) | |||
return h.Sum(nil) | |||
} | |||
// verifyMac verifies that a message authentication code (MAC) is valid. | |||
func verifyMac(h hash.Hash, value []byte, mac []byte) error { | |||
mac2 := createMac(h, value) | |||
// Check that both MACs are of equal length, as subtle.ConstantTimeCompare | |||
// does not do this prior to Go 1.4. | |||
if len(mac) == len(mac2) && subtle.ConstantTimeCompare(mac, mac2) == 1 { | |||
return nil | |||
} | |||
return ErrMacInvalid | |||
} | |||
// Encryption ----------------------------------------------------------------- | |||
// encrypt encrypts a value using the given block in counter mode. | |||
// | |||
// A random initialization vector (http://goo.gl/zF67k) with the length of the | |||
// block size is prepended to the resulting ciphertext. | |||
func encrypt(block cipher.Block, value []byte) ([]byte, error) { | |||
iv := GenerateRandomKey(block.BlockSize()) | |||
if iv == nil { | |||
return nil, errGeneratingIV | |||
} | |||
// Encrypt it. | |||
stream := cipher.NewCTR(block, iv) | |||
stream.XORKeyStream(value, value) | |||
// Return iv + ciphertext. | |||
return append(iv, value...), nil | |||
} | |||
// decrypt decrypts a value using the given block in counter mode. | |||
// | |||
// The value to be decrypted must be prepended by a initialization vector | |||
// (http://goo.gl/zF67k) with the length of the block size. | |||
func decrypt(block cipher.Block, value []byte) ([]byte, error) { | |||
size := block.BlockSize() | |||
if len(value) > size { | |||
// Extract iv. | |||
iv := value[:size] | |||
// Extract ciphertext. | |||
value = value[size:] | |||
// Decrypt it. | |||
stream := cipher.NewCTR(block, iv) | |||
stream.XORKeyStream(value, value) | |||
return value, nil | |||
} | |||
return nil, errDecryptionFailed | |||
} | |||
// Serialization -------------------------------------------------------------- | |||
// Serialize encodes a value using gob. | |||
func (e GobEncoder) Serialize(src interface{}) ([]byte, error) { | |||
buf := new(bytes.Buffer) | |||
enc := gob.NewEncoder(buf) | |||
if err := enc.Encode(src); err != nil { | |||
return nil, cookieError{cause: err, typ: usageError} | |||
} | |||
return buf.Bytes(), nil | |||
} | |||
// Deserialize decodes a value using gob. | |||
func (e GobEncoder) Deserialize(src []byte, dst interface{}) error { | |||
dec := gob.NewDecoder(bytes.NewBuffer(src)) | |||
if err := dec.Decode(dst); err != nil { | |||
return cookieError{cause: err, typ: decodeError} | |||
} | |||
return nil | |||
} | |||
// Serialize encodes a value using encoding/json. | |||
func (e JSONEncoder) Serialize(src interface{}) ([]byte, error) { | |||
buf := new(bytes.Buffer) | |||
enc := json.NewEncoder(buf) | |||
if err := enc.Encode(src); err != nil { | |||
return nil, cookieError{cause: err, typ: usageError} | |||
} | |||
return buf.Bytes(), nil | |||
} | |||
// Deserialize decodes a value using encoding/json. | |||
func (e JSONEncoder) Deserialize(src []byte, dst interface{}) error { | |||
dec := json.NewDecoder(bytes.NewReader(src)) | |||
if err := dec.Decode(dst); err != nil { | |||
return cookieError{cause: err, typ: decodeError} | |||
} | |||
return nil | |||
} | |||
// Serialize passes a []byte through as-is. | |||
func (e NopEncoder) Serialize(src interface{}) ([]byte, error) { | |||
if b, ok := src.([]byte); ok { | |||
return b, nil | |||
} | |||
return nil, errValueNotByte | |||
} | |||
// Deserialize passes a []byte through as-is. | |||
func (e NopEncoder) Deserialize(src []byte, dst interface{}) error { | |||
if dat, ok := dst.(*[]byte); ok { | |||
*dat = src | |||
return nil | |||
} | |||
return errValueNotBytePtr | |||
} | |||
// Encoding ------------------------------------------------------------------- | |||
// encode encodes a value using base64. | |||
func encode(value []byte) []byte { | |||
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value))) | |||
base64.URLEncoding.Encode(encoded, value) | |||
return encoded | |||
} | |||
// decode decodes a cookie using base64. | |||
func decode(value []byte) ([]byte, error) { | |||
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value))) | |||
b, err := base64.URLEncoding.Decode(decoded, value) | |||
if err != nil { | |||
return nil, cookieError{cause: err, typ: decodeError, msg: "base64 decode failed"} | |||
} | |||
return decoded[:b], nil | |||
} | |||
// Helpers -------------------------------------------------------------------- | |||
// GenerateRandomKey creates a random key with the given length in bytes. | |||
// On failure, returns nil. | |||
// | |||
// Callers should explicitly check for the possibility of a nil return, treat | |||
// it as a failure of the system random number generator, and not continue. | |||
func GenerateRandomKey(length int) []byte { | |||
k := make([]byte, length) | |||
if _, err := io.ReadFull(rand.Reader, k); err != nil { | |||
return nil | |||
} | |||
return k | |||
} | |||
// CodecsFromPairs returns a slice of SecureCookie instances. | |||
// | |||
// It is a convenience function to create a list of codecs for key rotation. Note | |||
// that the generated Codecs will have the default options applied: callers | |||
// should iterate over each Codec and type-assert the underlying *SecureCookie to | |||
// change these. | |||
// | |||
// Example: | |||
// | |||
// codecs := securecookie.CodecsFromPairs( | |||
// []byte("new-hash-key"), | |||
// []byte("new-block-key"), | |||
// []byte("old-hash-key"), | |||
// []byte("old-block-key"), | |||
// ) | |||
// | |||
// // Modify each instance. | |||
// for _, s := range codecs { | |||
// if cookie, ok := s.(*securecookie.SecureCookie); ok { | |||
// cookie.MaxAge(86400 * 7) | |||
// cookie.SetSerializer(securecookie.JSONEncoder{}) | |||
// cookie.HashFunc(sha512.New512_256) | |||
// } | |||
// } | |||
// | |||
func CodecsFromPairs(keyPairs ...[]byte) []Codec { | |||
codecs := make([]Codec, len(keyPairs)/2+len(keyPairs)%2) | |||
for i := 0; i < len(keyPairs); i += 2 { | |||
var blockKey []byte | |||
if i+1 < len(keyPairs) { | |||
blockKey = keyPairs[i+1] | |||
} | |||
codecs[i/2] = New(keyPairs[i], blockKey) | |||
} | |||
return codecs | |||
} | |||
// EncodeMulti encodes a cookie value using a group of codecs. | |||
// | |||
// The codecs are tried in order. Multiple codecs are accepted to allow | |||
// key rotation. | |||
// | |||
// On error, may return a MultiError. | |||
func EncodeMulti(name string, value interface{}, codecs ...Codec) (string, error) { | |||
if len(codecs) == 0 { | |||
return "", errNoCodecs | |||
} | |||
var errors MultiError | |||
for _, codec := range codecs { | |||
encoded, err := codec.Encode(name, value) | |||
if err == nil { | |||
return encoded, nil | |||
} | |||
errors = append(errors, err) | |||
} | |||
return "", errors | |||
} | |||
// DecodeMulti decodes a cookie value using a group of codecs. | |||
// | |||
// The codecs are tried in order. Multiple codecs are accepted to allow | |||
// key rotation. | |||
// | |||
// On error, may return a MultiError. | |||
func DecodeMulti(name string, value string, dst interface{}, codecs ...Codec) error { | |||
if len(codecs) == 0 { | |||
return errNoCodecs | |||
} | |||
var errors MultiError | |||
for _, codec := range codecs { | |||
err := codec.Decode(name, value, dst) | |||
if err == nil { | |||
return nil | |||
} | |||
errors = append(errors, err) | |||
} | |||
return errors | |||
} | |||
// MultiError groups multiple errors. | |||
type MultiError []error | |||
func (m MultiError) IsUsage() bool { return m.any(func(e Error) bool { return e.IsUsage() }) } | |||
func (m MultiError) IsDecode() bool { return m.any(func(e Error) bool { return e.IsDecode() }) } | |||
func (m MultiError) IsInternal() bool { return m.any(func(e Error) bool { return e.IsInternal() }) } | |||
// Cause returns nil for MultiError; there is no unique underlying cause in the | |||
// general case. | |||
// | |||
// Note: we could conceivably return a non-nil Cause only when there is exactly | |||
// one child error with a Cause. However, it would be brittle for client code | |||
// to rely on the arity of causes inside a MultiError, so we have opted not to | |||
// provide this functionality. Clients which really wish to access the Causes | |||
// of the underlying errors are free to iterate through the errors themselves. | |||
func (m MultiError) Cause() error { return nil } | |||
func (m MultiError) Error() string { | |||
s, n := "", 0 | |||
for _, e := range m { | |||
if e != nil { | |||
if n == 0 { | |||
s = e.Error() | |||
} | |||
n++ | |||
} | |||
} | |||
switch n { | |||
case 0: | |||
return "(0 errors)" | |||
case 1: | |||
return s | |||
case 2: | |||
return s + " (and 1 other error)" | |||
} | |||
return fmt.Sprintf("%s (and %d other errors)", s, n-1) | |||
} | |||
// any returns true if any element of m is an Error for which pred returns true. | |||
func (m MultiError) any(pred func(Error) bool) bool { | |||
for _, e := range m { | |||
if ourErr, ok := e.(Error); ok && pred(ourErr) { | |||
return true | |||
} | |||
} | |||
return false | |||
} |
@@ -0,0 +1,27 @@ | |||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
* Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
* Redistributions in binary form must reproduce the above | |||
copyright notice, this list of conditions and the following disclaimer | |||
in the documentation and/or other materials provided with the | |||
distribution. | |||
* Neither the name of Google Inc. nor the names of its | |||
contributors may be used to endorse or promote products derived from | |||
this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
@@ -0,0 +1,81 @@ | |||
sessions | |||
======== | |||
[![GoDoc](https://godoc.org/github.com/gorilla/sessions?status.svg)](https://godoc.org/github.com/gorilla/sessions) [![Build Status](https://travis-ci.org/gorilla/sessions.png?branch=master)](https://travis-ci.org/gorilla/sessions) | |||
gorilla/sessions provides cookie and filesystem sessions and infrastructure for | |||
custom session backends. | |||
The key features are: | |||
* Simple API: use it as an easy way to set signed (and optionally | |||
encrypted) cookies. | |||
* Built-in backends to store sessions in cookies or the filesystem. | |||
* Flash messages: session values that last until read. | |||
* Convenient way to switch session persistency (aka "remember me") and set | |||
other attributes. | |||
* Mechanism to rotate authentication and encryption keys. | |||
* Multiple sessions per request, even using different backends. | |||
* Interfaces and infrastructure for custom session backends: sessions from | |||
different stores can be retrieved and batch-saved using a common API. | |||
Let's start with an example that shows the sessions API in a nutshell: | |||
```go | |||
import ( | |||
"net/http" | |||
"github.com/gorilla/sessions" | |||
) | |||
var store = sessions.NewCookieStore([]byte("something-very-secret")) | |||
func MyHandler(w http.ResponseWriter, r *http.Request) { | |||
// Get a session. We're ignoring the error resulted from decoding an | |||
// existing session: Get() always returns a session, even if empty. | |||
session, _ := store.Get(r, "session-name") | |||
// Set some session values. | |||
session.Values["foo"] = "bar" | |||
session.Values[42] = 43 | |||
// Save it before we write to the response/return from the handler. | |||
session.Save(r, w) | |||
} | |||
``` | |||
First we initialize a session store calling `NewCookieStore()` and passing a | |||
secret key used to authenticate the session. Inside the handler, we call | |||
`store.Get()` to retrieve an existing session or a new one. Then we set some | |||
session values in session.Values, which is a `map[interface{}]interface{}`. | |||
And finally we call `session.Save()` to save the session in the response. | |||
Important Note: If you aren't using gorilla/mux, you need to wrap your handlers | |||
with | |||
[`context.ClearHandler`](http://www.gorillatoolkit.org/pkg/context#ClearHandler) | |||
as or else you will leak memory! An easy way to do this is to wrap the top-level | |||
mux when calling http.ListenAndServe: | |||
More examples are available [on the Gorilla | |||
website](http://www.gorillatoolkit.org/pkg/sessions). | |||
## Store Implementations | |||
Other implementations of the `sessions.Store` interface: | |||
* [github.com/starJammer/gorilla-sessions-arangodb](https://github.com/starJammer/gorilla-sessions-arangodb) - ArangoDB | |||
* [github.com/yosssi/boltstore](https://github.com/yosssi/boltstore) - Bolt | |||
* [github.com/srinathgs/couchbasestore](https://github.com/srinathgs/couchbasestore) - Couchbase | |||
* [github.com/denizeren/dynamostore](https://github.com/denizeren/dynamostore) - Dynamodb on AWS | |||
* [github.com/bradleypeabody/gorilla-sessions-memcache](https://github.com/bradleypeabody/gorilla-sessions-memcache) - Memcache | |||
* [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine | |||
* [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB | |||
* [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL | |||
* [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster | |||
* [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL | |||
* [github.com/boj/redistore](https://github.com/boj/redistore) - Redis | |||
* [github.com/boj/rethinkstore](https://github.com/boj/rethinkstore) - RethinkDB | |||
* [github.com/boj/riakstore](https://github.com/boj/riakstore) - Riak | |||
* [github.com/michaeljs1990/sqlitestore](https://github.com/michaeljs1990/sqlitestore) - SQLite | |||
* [github.com/wader/gormstore](https://github.com/wader/gormstore) - GORM (MySQL, PostgreSQL, SQLite) | |||
* [github.com/gernest/qlstore](https://github.com/gernest/qlstore) - ql | |||
## License | |||
BSD licensed. See the LICENSE file for details. |
@@ -0,0 +1,199 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
/* | |||
Package sessions provides cookie and filesystem sessions and | |||
infrastructure for custom session backends. | |||
The key features are: | |||
* Simple API: use it as an easy way to set signed (and optionally | |||
encrypted) cookies. | |||
* Built-in backends to store sessions in cookies or the filesystem. | |||
* Flash messages: session values that last until read. | |||
* Convenient way to switch session persistency (aka "remember me") and set | |||
other attributes. | |||
* Mechanism to rotate authentication and encryption keys. | |||
* Multiple sessions per request, even using different backends. | |||
* Interfaces and infrastructure for custom session backends: sessions from | |||
different stores can be retrieved and batch-saved using a common API. | |||
Let's start with an example that shows the sessions API in a nutshell: | |||
import ( | |||
"net/http" | |||
"github.com/gorilla/sessions" | |||
) | |||
var store = sessions.NewCookieStore([]byte("something-very-secret")) | |||
func MyHandler(w http.ResponseWriter, r *http.Request) { | |||
// Get a session. We're ignoring the error resulted from decoding an | |||
// existing session: Get() always returns a session, even if empty. | |||
session, err := store.Get(r, "session-name") | |||
if err != nil { | |||
http.Error(w, err.Error(), http.StatusInternalServerError) | |||
return | |||
} | |||
// Set some session values. | |||
session.Values["foo"] = "bar" | |||
session.Values[42] = 43 | |||
// Save it before we write to the response/return from the handler. | |||
session.Save(r, w) | |||
} | |||
First we initialize a session store calling NewCookieStore() and passing a | |||
secret key used to authenticate the session. Inside the handler, we call | |||
store.Get() to retrieve an existing session or a new one. Then we set some | |||
session values in session.Values, which is a map[interface{}]interface{}. | |||
And finally we call session.Save() to save the session in the response. | |||
Note that in production code, we should check for errors when calling | |||
session.Save(r, w), and either display an error message or otherwise handle it. | |||
Save must be called before writing to the response, otherwise the session | |||
cookie will not be sent to the client. | |||
Important Note: If you aren't using gorilla/mux, you need to wrap your handlers | |||
with context.ClearHandler as or else you will leak memory! An easy way to do this | |||
is to wrap the top-level mux when calling http.ListenAndServe: | |||
http.ListenAndServe(":8080", context.ClearHandler(http.DefaultServeMux)) | |||
The ClearHandler function is provided by the gorilla/context package. | |||
That's all you need to know for the basic usage. Let's take a look at other | |||
options, starting with flash messages. | |||
Flash messages are session values that last until read. The term appeared with | |||
Ruby On Rails a few years back. When we request a flash message, it is removed | |||
from the session. To add a flash, call session.AddFlash(), and to get all | |||
flashes, call session.Flashes(). Here is an example: | |||
func MyHandler(w http.ResponseWriter, r *http.Request) { | |||
// Get a session. | |||
session, err := store.Get(r, "session-name") | |||
if err != nil { | |||
http.Error(w, err.Error(), http.StatusInternalServerError) | |||
return | |||
} | |||
// Get the previously flashes, if any. | |||
if flashes := session.Flashes(); len(flashes) > 0 { | |||
// Use the flash values. | |||
} else { | |||
// Set a new flash. | |||
session.AddFlash("Hello, flash messages world!") | |||
} | |||
session.Save(r, w) | |||
} | |||
Flash messages are useful to set information to be read after a redirection, | |||
like after form submissions. | |||
There may also be cases where you want to store a complex datatype within a | |||
session, such as a struct. Sessions are serialised using the encoding/gob package, | |||
so it is easy to register new datatypes for storage in sessions: | |||
import( | |||
"encoding/gob" | |||
"github.com/gorilla/sessions" | |||
) | |||
type Person struct { | |||
FirstName string | |||
LastName string | |||
Email string | |||
Age int | |||
} | |||
type M map[string]interface{} | |||
func init() { | |||
gob.Register(&Person{}) | |||
gob.Register(&M{}) | |||
} | |||
As it's not possible to pass a raw type as a parameter to a function, gob.Register() | |||
relies on us passing it a value of the desired type. In the example above we've passed | |||
it a pointer to a struct and a pointer to a custom type representing a | |||
map[string]interface. (We could have passed non-pointer values if we wished.) This will | |||
then allow us to serialise/deserialise values of those types to and from our sessions. | |||
Note that because session values are stored in a map[string]interface{}, there's | |||
a need to type-assert data when retrieving it. We'll use the Person struct we registered above: | |||
func MyHandler(w http.ResponseWriter, r *http.Request) { | |||
session, err := store.Get(r, "session-name") | |||
if err != nil { | |||
http.Error(w, err.Error(), http.StatusInternalServerError) | |||
return | |||
} | |||
// Retrieve our struct and type-assert it | |||
val := session.Values["person"] | |||
var person = &Person{} | |||
if person, ok := val.(*Person); !ok { | |||
// Handle the case that it's not an expected type | |||
} | |||
// Now we can use our person object | |||
} | |||
By default, session cookies last for a month. This is probably too long for | |||
some cases, but it is easy to change this and other attributes during | |||
runtime. Sessions can be configured individually or the store can be | |||
configured and then all sessions saved using it will use that configuration. | |||
We access session.Options or store.Options to set a new configuration. The | |||
fields are basically a subset of http.Cookie fields. Let's change the | |||
maximum age of a session to one week: | |||
session.Options = &sessions.Options{ | |||
Path: "/", | |||
MaxAge: 86400 * 7, | |||
HttpOnly: true, | |||
} | |||
Sometimes we may want to change authentication and/or encryption keys without | |||
breaking existing sessions. The CookieStore supports key rotation, and to use | |||
it you just need to set multiple authentication and encryption keys, in pairs, | |||
to be tested in order: | |||
var store = sessions.NewCookieStore( | |||
[]byte("new-authentication-key"), | |||
[]byte("new-encryption-key"), | |||
[]byte("old-authentication-key"), | |||
[]byte("old-encryption-key"), | |||
) | |||
New sessions will be saved using the first pair. Old sessions can still be | |||
read because the first pair will fail, and the second will be tested. This | |||
makes it easy to "rotate" secret keys and still be able to validate existing | |||
sessions. Note: for all pairs the encryption key is optional; set it to nil | |||
or omit it and and encryption won't be used. | |||
Multiple sessions can be used in the same request, even with different | |||
session backends. When this happens, calling Save() on each session | |||
individually would be cumbersome, so we have a way to save all sessions | |||
at once: it's sessions.Save(). Here's an example: | |||
var store = sessions.NewCookieStore([]byte("something-very-secret")) | |||
func MyHandler(w http.ResponseWriter, r *http.Request) { | |||
// Get a session and set a value. | |||
session1, _ := store.Get(r, "session-one") | |||
session1.Values["foo"] = "bar" | |||
// Get another session and set another value. | |||
session2, _ := store.Get(r, "session-two") | |||
session2.Values[42] = 43 | |||
// Save all sessions. | |||
sessions.Save(r, w) | |||
} | |||
This is possible because when we call Get() from a session store, it adds the | |||
session to a common registry. Save() uses it to save all registered sessions. | |||
*/ | |||
package sessions |
@@ -0,0 +1,102 @@ | |||
// This file contains code adapted from the Go standard library | |||
// https://github.com/golang/go/blob/39ad0fd0789872f9469167be7fe9578625ff246e/src/net/http/lex.go | |||
package sessions | |||
import "strings" | |||
var isTokenTable = [127]bool{ | |||
'!': true, | |||
'#': true, | |||
'$': true, | |||
'%': true, | |||
'&': true, | |||
'\'': true, | |||
'*': true, | |||
'+': true, | |||
'-': true, | |||
'.': true, | |||
'0': true, | |||
'1': true, | |||
'2': true, | |||
'3': true, | |||
'4': true, | |||
'5': true, | |||
'6': true, | |||
'7': true, | |||
'8': true, | |||
'9': true, | |||
'A': true, | |||
'B': true, | |||
'C': true, | |||
'D': true, | |||
'E': true, | |||
'F': true, | |||
'G': true, | |||
'H': true, | |||
'I': true, | |||
'J': true, | |||
'K': true, | |||
'L': true, | |||
'M': true, | |||
'N': true, | |||
'O': true, | |||
'P': true, | |||
'Q': true, | |||
'R': true, | |||
'S': true, | |||
'T': true, | |||
'U': true, | |||
'W': true, | |||
'V': true, | |||
'X': true, | |||
'Y': true, | |||
'Z': true, | |||
'^': true, | |||
'_': true, | |||
'`': true, | |||
'a': true, | |||
'b': true, | |||
'c': true, | |||
'd': true, | |||
'e': true, | |||
'f': true, | |||
'g': true, | |||
'h': true, | |||
'i': true, | |||
'j': true, | |||
'k': true, | |||
'l': true, | |||
'm': true, | |||
'n': true, | |||
'o': true, | |||
'p': true, | |||
'q': true, | |||
'r': true, | |||
's': true, | |||
't': true, | |||
'u': true, | |||
'v': true, | |||
'w': true, | |||
'x': true, | |||
'y': true, | |||
'z': true, | |||
'|': true, | |||
'~': true, | |||
} | |||
func isToken(r rune) bool { | |||
i := int(r) | |||
return i < len(isTokenTable) && isTokenTable[i] | |||
} | |||
func isNotToken(r rune) bool { | |||
return !isToken(r) | |||
} | |||
func isCookieNameValid(raw string) bool { | |||
if raw == "" { | |||
return false | |||
} | |||
return strings.IndexFunc(raw, isNotToken) < 0 | |||
} |
@@ -0,0 +1,241 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package sessions | |||
import ( | |||
"encoding/gob" | |||
"fmt" | |||
"net/http" | |||
"time" | |||
"github.com/gorilla/context" | |||
) | |||
// Default flashes key. | |||
const flashesKey = "_flash" | |||
// Options -------------------------------------------------------------------- | |||
// Options stores configuration for a session or session store. | |||
// | |||
// Fields are a subset of http.Cookie fields. | |||
type Options struct { | |||
Path string | |||
Domain string | |||
// MaxAge=0 means no 'Max-Age' attribute specified. | |||
// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'. | |||
// MaxAge>0 means Max-Age attribute present and given in seconds. | |||
MaxAge int | |||
Secure bool | |||
HttpOnly bool | |||
} | |||
// Session -------------------------------------------------------------------- | |||
// NewSession is called by session stores to create a new session instance. | |||
func NewSession(store Store, name string) *Session { | |||
return &Session{ | |||
Values: make(map[interface{}]interface{}), | |||
store: store, | |||
name: name, | |||
} | |||
} | |||
// Session stores the values and optional configuration for a session. | |||
type Session struct { | |||
// The ID of the session, generated by stores. It should not be used for | |||
// user data. | |||
ID string | |||
// Values contains the user-data for the session. | |||
Values map[interface{}]interface{} | |||
Options *Options | |||
IsNew bool | |||
store Store | |||
name string | |||
} | |||
// Flashes returns a slice of flash messages from the session. | |||
// | |||
// A single variadic argument is accepted, and it is optional: it defines | |||
// the flash key. If not defined "_flash" is used by default. | |||
func (s *Session) Flashes(vars ...string) []interface{} { | |||
var flashes []interface{} | |||
key := flashesKey | |||
if len(vars) > 0 { | |||
key = vars[0] | |||
} | |||
if v, ok := s.Values[key]; ok { | |||
// Drop the flashes and return it. | |||
delete(s.Values, key) | |||
flashes = v.([]interface{}) | |||
} | |||
return flashes | |||
} | |||
// AddFlash adds a flash message to the session. | |||
// | |||
// A single variadic argument is accepted, and it is optional: it defines | |||
// the flash key. If not defined "_flash" is used by default. | |||
func (s *Session) AddFlash(value interface{}, vars ...string) { | |||
key := flashesKey | |||
if len(vars) > 0 { | |||
key = vars[0] | |||
} | |||
var flashes []interface{} | |||
if v, ok := s.Values[key]; ok { | |||
flashes = v.([]interface{}) | |||
} | |||
s.Values[key] = append(flashes, value) | |||
} | |||
// Save is a convenience method to save this session. It is the same as calling | |||
// store.Save(request, response, session). You should call Save before writing to | |||
// the response or returning from the handler. | |||
func (s *Session) Save(r *http.Request, w http.ResponseWriter) error { | |||
return s.store.Save(r, w, s) | |||
} | |||
// Name returns the name used to register the session. | |||
func (s *Session) Name() string { | |||
return s.name | |||
} | |||
// Store returns the session store used to register the session. | |||
func (s *Session) Store() Store { | |||
return s.store | |||
} | |||
// Registry ------------------------------------------------------------------- | |||
// sessionInfo stores a session tracked by the registry. | |||
type sessionInfo struct { | |||
s *Session | |||
e error | |||
} | |||
// contextKey is the type used to store the registry in the context. | |||
type contextKey int | |||
// registryKey is the key used to store the registry in the context. | |||
const registryKey contextKey = 0 | |||
// GetRegistry returns a registry instance for the current request. | |||
func GetRegistry(r *http.Request) *Registry { | |||
registry := context.Get(r, registryKey) | |||
if registry != nil { | |||
return registry.(*Registry) | |||
} | |||
newRegistry := &Registry{ | |||
request: r, | |||
sessions: make(map[string]sessionInfo), | |||
} | |||
context.Set(r, registryKey, newRegistry) | |||
return newRegistry | |||
} | |||
// Registry stores sessions used during a request. | |||
type Registry struct { | |||
request *http.Request | |||
sessions map[string]sessionInfo | |||
} | |||
// Get registers and returns a session for the given name and session store. | |||
// | |||
// It returns a new session if there are no sessions registered for the name. | |||
func (s *Registry) Get(store Store, name string) (session *Session, err error) { | |||
if !isCookieNameValid(name) { | |||
return nil, fmt.Errorf("sessions: invalid character in cookie name: %s", name) | |||
} | |||
if info, ok := s.sessions[name]; ok { | |||
session, err = info.s, info.e | |||
} else { | |||
session, err = store.New(s.request, name) | |||
session.name = name | |||
s.sessions[name] = sessionInfo{s: session, e: err} | |||
} | |||
session.store = store | |||
return | |||
} | |||
// Save saves all sessions registered for the current request. | |||
func (s *Registry) Save(w http.ResponseWriter) error { | |||
var errMulti MultiError | |||
for name, info := range s.sessions { | |||
session := info.s | |||
if session.store == nil { | |||
errMulti = append(errMulti, fmt.Errorf( | |||
"sessions: missing store for session %q", name)) | |||
} else if err := session.store.Save(s.request, w, session); err != nil { | |||
errMulti = append(errMulti, fmt.Errorf( | |||
"sessions: error saving session %q -- %v", name, err)) | |||
} | |||
} | |||
if errMulti != nil { | |||
return errMulti | |||
} | |||
return nil | |||
} | |||
// Helpers -------------------------------------------------------------------- | |||
func init() { | |||
gob.Register([]interface{}{}) | |||
} | |||
// Save saves all sessions used during the current request. | |||
func Save(r *http.Request, w http.ResponseWriter) error { | |||
return GetRegistry(r).Save(w) | |||
} | |||
// NewCookie returns an http.Cookie with the options set. It also sets | |||
// the Expires field calculated based on the MaxAge value, for Internet | |||
// Explorer compatibility. | |||
func NewCookie(name, value string, options *Options) *http.Cookie { | |||
cookie := &http.Cookie{ | |||
Name: name, | |||
Value: value, | |||
Path: options.Path, | |||
Domain: options.Domain, | |||
MaxAge: options.MaxAge, | |||
Secure: options.Secure, | |||
HttpOnly: options.HttpOnly, | |||
} | |||
if options.MaxAge > 0 { | |||
d := time.Duration(options.MaxAge) * time.Second | |||
cookie.Expires = time.Now().Add(d) | |||
} else if options.MaxAge < 0 { | |||
// Set it to the past to expire now. | |||
cookie.Expires = time.Unix(1, 0) | |||
} | |||
return cookie | |||
} | |||
// Error ---------------------------------------------------------------------- | |||
// MultiError stores multiple errors. | |||
// | |||
// Borrowed from the App Engine SDK. | |||
type MultiError []error | |||
func (m MultiError) Error() string { | |||
s, n := "", 0 | |||
for _, e := range m { | |||
if e != nil { | |||
if n == 0 { | |||
s = e.Error() | |||
} | |||
n++ | |||
} | |||
} | |||
switch n { | |||
case 0: | |||
return "(0 errors)" | |||
case 1: | |||
return s | |||
case 2: | |||
return s + " (and 1 other error)" | |||
} | |||
return fmt.Sprintf("%s (and %d other errors)", s, n-1) | |||
} |
@@ -0,0 +1,295 @@ | |||
// Copyright 2012 The Gorilla Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package sessions | |||
import ( | |||
"encoding/base32" | |||
"io/ioutil" | |||
"net/http" | |||
"os" | |||
"path/filepath" | |||
"strings" | |||
"sync" | |||
"github.com/gorilla/securecookie" | |||
) | |||
// Store is an interface for custom session stores. | |||
// | |||
// See CookieStore and FilesystemStore for examples. | |||
type Store interface { | |||
// Get should return a cached session. | |||
Get(r *http.Request, name string) (*Session, error) | |||
// New should create and return a new session. | |||
// | |||
// Note that New should never return a nil session, even in the case of | |||
// an error if using the Registry infrastructure to cache the session. | |||
New(r *http.Request, name string) (*Session, error) | |||
// Save should persist session to the underlying store implementation. | |||
Save(r *http.Request, w http.ResponseWriter, s *Session) error | |||
} | |||
// CookieStore ---------------------------------------------------------------- | |||
// NewCookieStore returns a new CookieStore. | |||
// | |||
// Keys are defined in pairs to allow key rotation, but the common case is | |||
// to set a single authentication key and optionally an encryption key. | |||
// | |||
// The first key in a pair is used for authentication and the second for | |||
// encryption. The encryption key can be set to nil or omitted in the last | |||
// pair, but the authentication key is required in all pairs. | |||
// | |||
// It is recommended to use an authentication key with 32 or 64 bytes. | |||
// The encryption key, if set, must be either 16, 24, or 32 bytes to select | |||
// AES-128, AES-192, or AES-256 modes. | |||
// | |||
// Use the convenience function securecookie.GenerateRandomKey() to create | |||
// strong keys. | |||
func NewCookieStore(keyPairs ...[]byte) *CookieStore { | |||
cs := &CookieStore{ | |||
Codecs: securecookie.CodecsFromPairs(keyPairs...), | |||
Options: &Options{ | |||
Path: "/", | |||
MaxAge: 86400 * 30, | |||
}, | |||
} | |||
cs.MaxAge(cs.Options.MaxAge) | |||
return cs | |||
} | |||
// CookieStore stores sessions using secure cookies. | |||
type CookieStore struct { | |||
Codecs []securecookie.Codec | |||
Options *Options // default configuration | |||
} | |||
// Get returns a session for the given name after adding it to the registry. | |||
// | |||
// It returns a new session if the sessions doesn't exist. Access IsNew on | |||
// the session to check if it is an existing session or a new one. | |||
// | |||
// It returns a new session and an error if the session exists but could | |||
// not be decoded. | |||
func (s *CookieStore) Get(r *http.Request, name string) (*Session, error) { | |||
return GetRegistry(r).Get(s, name) | |||
} | |||
// New returns a session for the given name without adding it to the registry. | |||
// | |||
// The difference between New() and Get() is that calling New() twice will | |||
// decode the session data twice, while Get() registers and reuses the same | |||
// decoded session after the first call. | |||
func (s *CookieStore) New(r *http.Request, name string) (*Session, error) { | |||
session := NewSession(s, name) | |||
opts := *s.Options | |||
session.Options = &opts | |||
session.IsNew = true | |||
var err error | |||
if c, errCookie := r.Cookie(name); errCookie == nil { | |||
err = securecookie.DecodeMulti(name, c.Value, &session.Values, | |||
s.Codecs...) | |||
if err == nil { | |||
session.IsNew = false | |||
} | |||
} | |||
return session, err | |||
} | |||
// Save adds a single session to the response. | |||
func (s *CookieStore) Save(r *http.Request, w http.ResponseWriter, | |||
session *Session) error { | |||
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, | |||
s.Codecs...) | |||
if err != nil { | |||
return err | |||
} | |||
http.SetCookie(w, NewCookie(session.Name(), encoded, session.Options)) | |||
return nil | |||
} | |||
// MaxAge sets the maximum age for the store and the underlying cookie | |||
// implementation. Individual sessions can be deleted by setting Options.MaxAge | |||
// = -1 for that session. | |||
func (s *CookieStore) MaxAge(age int) { | |||
s.Options.MaxAge = age | |||
// Set the maxAge for each securecookie instance. | |||
for _, codec := range s.Codecs { | |||
if sc, ok := codec.(*securecookie.SecureCookie); ok { | |||
sc.MaxAge(age) | |||
} | |||
} | |||
} | |||
// FilesystemStore ------------------------------------------------------------ | |||
var fileMutex sync.RWMutex | |||
// NewFilesystemStore returns a new FilesystemStore. | |||
// | |||
// The path argument is the directory where sessions will be saved. If empty | |||
// it will use os.TempDir(). | |||
// | |||
// See NewCookieStore() for a description of the other parameters. | |||
func NewFilesystemStore(path string, keyPairs ...[]byte) *FilesystemStore { | |||
if path == "" { | |||
path = os.TempDir() | |||
} | |||
fs := &FilesystemStore{ | |||
Codecs: securecookie.CodecsFromPairs(keyPairs...), | |||
Options: &Options{ | |||
Path: "/", | |||
MaxAge: 86400 * 30, | |||
}, | |||
path: path, | |||
} | |||
fs.MaxAge(fs.Options.MaxAge) | |||
return fs | |||
} | |||
// FilesystemStore stores sessions in the filesystem. | |||
// | |||
// It also serves as a reference for custom stores. | |||
// | |||
// This store is still experimental and not well tested. Feedback is welcome. | |||
type FilesystemStore struct { | |||
Codecs []securecookie.Codec | |||
Options *Options // default configuration | |||
path string | |||
} | |||
// MaxLength restricts the maximum length of new sessions to l. | |||
// If l is 0 there is no limit to the size of a session, use with caution. | |||
// The default for a new FilesystemStore is 4096. | |||
func (s *FilesystemStore) MaxLength(l int) { | |||
for _, c := range s.Codecs { | |||
if codec, ok := c.(*securecookie.SecureCookie); ok { | |||
codec.MaxLength(l) | |||
} | |||
} | |||
} | |||
// Get returns a session for the given name after adding it to the registry. | |||
// | |||
// See CookieStore.Get(). | |||
func (s *FilesystemStore) Get(r *http.Request, name string) (*Session, error) { | |||
return GetRegistry(r).Get(s, name) | |||
} | |||
// New returns a session for the given name without adding it to the registry. | |||
// | |||
// See CookieStore.New(). | |||
func (s *FilesystemStore) New(r *http.Request, name string) (*Session, error) { | |||
session := NewSession(s, name) | |||
opts := *s.Options | |||
session.Options = &opts | |||
session.IsNew = true | |||
var err error | |||
if c, errCookie := r.Cookie(name); errCookie == nil { | |||
err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...) | |||
if err == nil { | |||
err = s.load(session) | |||
if err == nil { | |||
session.IsNew = false | |||
} | |||
} | |||
} | |||
return session, err | |||
} | |||
// Save adds a single session to the response. | |||
// | |||
// If the Options.MaxAge of the session is <= 0 then the session file will be | |||
// deleted from the store path. With this process it enforces the properly | |||
// session cookie handling so no need to trust in the cookie management in the | |||
// web browser. | |||
func (s *FilesystemStore) Save(r *http.Request, w http.ResponseWriter, | |||
session *Session) error { | |||
// Delete if max-age is <= 0 | |||
if session.Options.MaxAge <= 0 { | |||
if err := s.erase(session); err != nil { | |||
return err | |||
} | |||
http.SetCookie(w, NewCookie(session.Name(), "", session.Options)) | |||
return nil | |||
} | |||
if session.ID == "" { | |||
// Because the ID is used in the filename, encode it to | |||
// use alphanumeric characters only. | |||
session.ID = strings.TrimRight( | |||
base32.StdEncoding.EncodeToString( | |||
securecookie.GenerateRandomKey(32)), "=") | |||
} | |||
if err := s.save(session); err != nil { | |||
return err | |||
} | |||
encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, | |||
s.Codecs...) | |||
if err != nil { | |||
return err | |||
} | |||
http.SetCookie(w, NewCookie(session.Name(), encoded, session.Options)) | |||
return nil | |||
} | |||
// MaxAge sets the maximum age for the store and the underlying cookie | |||
// implementation. Individual sessions can be deleted by setting Options.MaxAge | |||
// = -1 for that session. | |||
func (s *FilesystemStore) MaxAge(age int) { | |||
s.Options.MaxAge = age | |||
// Set the maxAge for each securecookie instance. | |||
for _, codec := range s.Codecs { | |||
if sc, ok := codec.(*securecookie.SecureCookie); ok { | |||
sc.MaxAge(age) | |||
} | |||
} | |||
} | |||
// save writes encoded session.Values to a file. | |||
func (s *FilesystemStore) save(session *Session) error { | |||
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, | |||
s.Codecs...) | |||
if err != nil { | |||
return err | |||
} | |||
filename := filepath.Join(s.path, "session_"+session.ID) | |||
fileMutex.Lock() | |||
defer fileMutex.Unlock() | |||
return ioutil.WriteFile(filename, []byte(encoded), 0600) | |||
} | |||
// load reads a file and decodes its content into session.Values. | |||
func (s *FilesystemStore) load(session *Session) error { | |||
filename := filepath.Join(s.path, "session_"+session.ID) | |||
fileMutex.RLock() | |||
defer fileMutex.RUnlock() | |||
fdata, err := ioutil.ReadFile(filename) | |||
if err != nil { | |||
return err | |||
} | |||
if err = securecookie.DecodeMulti(session.Name(), string(fdata), | |||
&session.Values, s.Codecs...); err != nil { | |||
return err | |||
} | |||
return nil | |||
} | |||
// delete session file | |||
func (s *FilesystemStore) erase(session *Session) error { | |||
filename := filepath.Join(s.path, "session_"+session.ID) | |||
fileMutex.RLock() | |||
defer fileMutex.RUnlock() | |||
err := os.Remove(filename) | |||
return err | |||
} |
@@ -0,0 +1,22 @@ | |||
Copyright (c) 2014 Mark Bates | |||
MIT License | |||
Permission is hereby granted, free of charge, to any person obtaining | |||
a copy of this software and associated documentation files (the | |||
"Software"), to deal in the Software without restriction, including | |||
without limitation the rights to use, copy, modify, merge, publish, | |||
distribute, sublicense, and/or sell copies of the Software, and to | |||
permit persons to whom the Software is furnished to do so, subject to | |||
the following conditions: | |||
The above copyright notice and this permission notice shall be | |||
included in all copies or substantial portions of the Software. | |||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | |||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | |||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE | |||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION | |||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION | |||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
@@ -0,0 +1,143 @@ | |||
# Goth: Multi-Provider Authentication for Go [![GoDoc](https://godoc.org/github.com/markbates/goth?status.svg)](https://godoc.org/github.com/markbates/goth) [![Build Status](https://travis-ci.org/markbates/goth.svg)](https://travis-ci.org/markbates/goth) | |||
Package goth provides a simple, clean, and idiomatic way to write authentication | |||
packages for Go web applications. | |||
Unlike other similar packages, Goth, lets you write OAuth, OAuth2, or any other | |||
protocol providers, as long as they implement the `Provider` and `Session` interfaces. | |||
This package was inspired by [https://github.com/intridea/omniauth](https://github.com/intridea/omniauth). | |||
## Installation | |||
```text | |||
$ go get github.com/markbates/goth | |||
``` | |||
## Supported Providers | |||
* Amazon | |||
* Auth0 | |||
* Bitbucket | |||
* Box | |||
* Cloud Foundry | |||
* Dailymotion | |||
* Deezer | |||
* Digital Ocean | |||
* Discord | |||
* Dropbox | |||
* Fitbit | |||
* GitHub | |||
* Gitlab | |||
* Google+ | |||
* Heroku | |||
* InfluxCloud | |||
* Intercom | |||
* Lastfm | |||
* Meetup | |||
* OneDrive | |||
* OpenID Connect (auto discovery) | |||
* Paypal | |||
* SalesForce | |||
* Slack | |||
* Soundcloud | |||
* Spotify | |||
* Steam | |||
* Stripe | |||
* Twitch | |||
* Uber | |||
* Wepay | |||
* Yahoo | |||
* Yammer | |||
## Examples | |||
See the [examples](examples) folder for a working application that lets users authenticate | |||
through Twitter, Facebook, Google Plus etc. | |||
To run the example either clone the source from GitHub | |||
```text | |||
$ git clone git@github.com:markbates/goth.git | |||
``` | |||
or use | |||
```text | |||
$ go get github.com/markbates/goth | |||
``` | |||
```text | |||
$ cd goth/examples | |||
$ go get -v | |||
$ go build | |||
$ ./examples | |||
``` | |||
Now open up your browser and go to [http://localhost:3000](http://localhost:3000) to see the example. | |||
To actually use the different providers, please make sure you configure them given the system environments as defined in the examples/main.go file | |||
## Issues | |||
Issues always stand a significantly better chance of getting fixed if the are accompanied by a | |||
pull request. | |||
## Contributing | |||
Would I love to see more providers? Certainly! Would you love to contribute one? Hopefully, yes! | |||
1. Fork it | |||
2. Create your feature branch (git checkout -b my-new-feature) | |||
3. Write Tests! | |||
4. Commit your changes (git commit -am 'Add some feature') | |||
5. Push to the branch (git push origin my-new-feature) | |||
6. Create new Pull Request | |||
## Contributors | |||
* Mark Bates | |||
* Tyler Bunnell | |||
* Corey McGrillis | |||
* willemvd | |||
* Rakesh Goyal | |||
* Andy Grunwald | |||
* Glenn Walker | |||
* Kevin Fitzpatrick | |||
* Ben Tranter | |||
* Sharad Ganapathy | |||
* Andrew Chilton | |||
* sharadgana | |||
* Aurorae | |||
* Craig P Jolicoeur | |||
* Zac Bergquist | |||
* Geoff Franks | |||
* Raphael Geronimi | |||
* Noah Shibley | |||
* lumost | |||
* oov | |||
* Felix Lamouroux | |||
* Rafael Quintela | |||
* Tyler | |||
* DenSm | |||
* Samy KACIMI | |||
* dante gray | |||
* Noah | |||
* Jacob Walker | |||
* Marin Martinic | |||
* Roy | |||
* Omni Adams | |||
* Sasa Brankovic | |||
* dkhamsing | |||
* Dante Swift | |||
* Attila Domokos | |||
* Albin Gilles | |||
* Syed Zubairuddin | |||
* Johnny Boursiquot | |||
* Jerome Touffe-Blin | |||
* bryanl | |||
* Masanobu YOSHIOKA | |||
* Jonathan Hall | |||
* HaiMing.Yin | |||
* Sairam Kunala |
@@ -0,0 +1,10 @@ | |||
/* | |||
Package goth provides a simple, clean, and idiomatic way to write authentication | |||
packages for Go web applications. | |||
This package was inspired by https://github.com/intridea/omniauth. | |||
See the examples folder for a working application that lets users authenticate | |||
through Twitter or Facebook. | |||
*/ | |||
package goth |
@@ -0,0 +1,219 @@ | |||
/* | |||
Package gothic wraps common behaviour when using Goth. This makes it quick, and easy, to get up | |||
and running with Goth. Of course, if you want complete control over how things flow, in regards | |||
to the authentication process, feel free and use Goth directly. | |||
See https://github.com/markbates/goth/examples/main.go to see this in action. | |||
*/ | |||
package gothic | |||
import ( | |||
"errors" | |||
"fmt" | |||
"net/http" | |||
"os" | |||
"github.com/gorilla/mux" | |||
"github.com/gorilla/sessions" | |||
"github.com/markbates/goth" | |||
) | |||
// SessionName is the key used to access the session store. | |||
const SessionName = "_gothic_session" | |||
// Store can/should be set by applications using gothic. The default is a cookie store. | |||
var Store sessions.Store | |||
var defaultStore sessions.Store | |||
var keySet = false | |||
func init() { | |||
key := []byte(os.Getenv("SESSION_SECRET")) | |||
keySet = len(key) != 0 | |||
Store = sessions.NewCookieStore([]byte(key)) | |||
defaultStore = Store | |||
} | |||
/* | |||
BeginAuthHandler is a convienence handler for starting the authentication process. | |||
It expects to be able to get the name of the provider from the query parameters | |||
as either "provider" or ":provider". | |||
BeginAuthHandler will redirect the user to the appropriate authentication end-point | |||
for the requested provider. | |||
See https://github.com/markbates/goth/examples/main.go to see this in action. | |||
*/ | |||
func BeginAuthHandler(res http.ResponseWriter, req *http.Request) { | |||
url, err := GetAuthURL(res, req) | |||
if err != nil { | |||
res.WriteHeader(http.StatusBadRequest) | |||
fmt.Fprintln(res, err) | |||
return | |||
} | |||
http.Redirect(res, req, url, http.StatusTemporaryRedirect) | |||
} | |||
// SetState sets the state string associated with the given request. | |||
// If no state string is associated with the request, one will be generated. | |||
// This state is sent to the provider and can be retrieved during the | |||
// callback. | |||
var SetState = func(req *http.Request) string { | |||
state := req.URL.Query().Get("state") | |||
if len(state) > 0 { | |||
return state | |||
} | |||
return "state" | |||
} | |||
// GetState gets the state returned by the provider during the callback. | |||
// This is used to prevent CSRF attacks, see | |||
// http://tools.ietf.org/html/rfc6749#section-10.12 | |||
var GetState = func(req *http.Request) string { | |||
return req.URL.Query().Get("state") | |||
} | |||
/* | |||
GetAuthURL starts the authentication process with the requested provided. | |||
It will return a URL that should be used to send users to. | |||
It expects to be able to get the name of the provider from the query parameters | |||
as either "provider" or ":provider". | |||
I would recommend using the BeginAuthHandler instead of doing all of these steps | |||
yourself, but that's entirely up to you. | |||
*/ | |||
func GetAuthURL(res http.ResponseWriter, req *http.Request) (string, error) { | |||
if !keySet && defaultStore == Store { | |||
fmt.Println("goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store.") | |||
} | |||
providerName, err := GetProviderName(req) | |||
if err != nil { | |||
return "", err | |||
} | |||
provider, err := goth.GetProvider(providerName) | |||
if err != nil { | |||
return "", err | |||
} | |||
sess, err := provider.BeginAuth(SetState(req)) | |||
if err != nil { | |||
return "", err | |||
} | |||
url, err := sess.GetAuthURL() | |||
if err != nil { | |||
return "", err | |||
} | |||
err = storeInSession(providerName, sess.Marshal(), req, res) | |||
if err != nil { | |||
return "", err | |||
} | |||
return url, err | |||
} | |||
/* | |||
CompleteUserAuth does what it says on the tin. It completes the authentication | |||
process and fetches all of the basic information about the user from the provider. | |||
It expects to be able to get the name of the provider from the query parameters | |||
as either "provider" or ":provider". | |||
See https://github.com/markbates/goth/examples/main.go to see this in action. | |||
*/ | |||
var CompleteUserAuth = func(res http.ResponseWriter, req *http.Request) (goth.User, error) { | |||
if !keySet && defaultStore == Store { | |||
fmt.Println("goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store.") | |||
} | |||
providerName, err := GetProviderName(req) | |||
if err != nil { | |||
return goth.User{}, err | |||
} | |||
provider, err := goth.GetProvider(providerName) | |||
if err != nil { | |||
return goth.User{}, err | |||
} | |||
value, err := getFromSession(providerName, req) | |||
if err != nil { | |||
return goth.User{}, err | |||
} | |||
sess, err := provider.UnmarshalSession(value) | |||
if err != nil { | |||
return goth.User{}, err | |||
} | |||
user, err := provider.FetchUser(sess) | |||
if err == nil { | |||
// user can be found with existing session data | |||
return user, err | |||
} | |||
// get new token and retry fetch | |||
_, err = sess.Authorize(provider, req.URL.Query()) | |||
if err != nil { | |||
return goth.User{}, err | |||
} | |||
err = storeInSession(providerName, sess.Marshal(), req, res) | |||
if err != nil { | |||
return goth.User{}, err | |||
} | |||
return provider.FetchUser(sess) | |||
} | |||
// GetProviderName is a function used to get the name of a provider | |||
// for a given request. By default, this provider is fetched from | |||
// the URL query string. If you provide it in a different way, | |||
// assign your own function to this variable that returns the provider | |||
// name for your request. | |||
var GetProviderName = getProviderName | |||
func getProviderName(req *http.Request) (string, error) { | |||
provider := req.URL.Query().Get("provider") | |||
if provider == "" { | |||
if p, ok := mux.Vars(req)["provider"]; ok { | |||
return p, nil | |||
} | |||
} | |||
if provider == "" { | |||
provider = req.URL.Query().Get(":provider") | |||
} | |||
if provider == "" { | |||
return provider, errors.New("you must select a provider") | |||
} | |||
return provider, nil | |||
} | |||
func storeInSession(key string, value string, req *http.Request, res http.ResponseWriter) error { | |||
session, _ := Store.Get(req, key + SessionName) | |||
session.Values[key] = value | |||
return session.Save(req, res) | |||
} | |||
func getFromSession(key string, req *http.Request) (string, error) { | |||
session, _ := Store.Get(req, key + SessionName) | |||
value := session.Values[key] | |||
if value == nil { | |||
return "", errors.New("could not find a matching session for this request") | |||
} | |||
return value.(string), nil | |||
} |
@@ -0,0 +1,75 @@ | |||
package goth | |||
import ( | |||
"fmt" | |||
"net/http" | |||
"golang.org/x/net/context" | |||
"golang.org/x/oauth2" | |||
) | |||
// Provider needs to be implemented for each 3rd party authentication provider | |||
// e.g. Facebook, Twitter, etc... | |||
type Provider interface { | |||
Name() string | |||
SetName(name string) | |||
BeginAuth(state string) (Session, error) | |||
UnmarshalSession(string) (Session, error) | |||
FetchUser(Session) (User, error) | |||
Debug(bool) | |||
RefreshToken(refreshToken string) (*oauth2.Token, error) //Get new access token based on the refresh token | |||
RefreshTokenAvailable() bool //Refresh token is provided by auth provider or not | |||
} | |||
const NoAuthUrlErrorMessage = "an AuthURL has not been set" | |||
// Providers is list of known/available providers. | |||
type Providers map[string]Provider | |||
var providers = Providers{} | |||
// UseProviders adds a list of available providers for use with Goth. | |||
// Can be called multiple times. If you pass the same provider more | |||
// than once, the last will be used. | |||
func UseProviders(viders ...Provider) { | |||
for _, provider := range viders { | |||
providers[provider.Name()] = provider | |||
} | |||
} | |||
// GetProviders returns a list of all the providers currently in use. | |||
func GetProviders() Providers { | |||
return providers | |||
} | |||
// GetProvider returns a previously created provider. If Goth has not | |||
// been told to use the named provider it will return an error. | |||
func GetProvider(name string) (Provider, error) { | |||
provider := providers[name] | |||
if provider == nil { | |||
return nil, fmt.Errorf("no provider for %s exists", name) | |||
} | |||
return provider, nil | |||
} | |||
// ClearProviders will remove all providers currently in use. | |||
// This is useful, mostly, for testing purposes. | |||
func ClearProviders() { | |||
providers = Providers{} | |||
} | |||
// ContextForClient provides a context for use with oauth2. | |||
func ContextForClient(h *http.Client) context.Context { | |||
if h == nil { | |||
return oauth2.NoContext | |||
} | |||
return context.WithValue(oauth2.NoContext, oauth2.HTTPClient, h) | |||
} | |||
// HTTPClientWithFallBack to be used in all fetch operations. | |||
func HTTPClientWithFallBack(h *http.Client) *http.Client { | |||
if h != nil { | |||
return h | |||
} | |||
return http.DefaultClient | |||
} |
@@ -0,0 +1,224 @@ | |||
// Package github implements the OAuth2 protocol for authenticating users through Github. | |||
// This package can be used as a reference implementation of an OAuth2 provider for Goth. | |||
package github | |||
import ( | |||
"bytes" | |||
"encoding/json" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"io/ioutil" | |||
"net/http" | |||
"net/url" | |||
"strconv" | |||
"strings" | |||
"github.com/markbates/goth" | |||
"golang.org/x/oauth2" | |||
) | |||
// These vars define the Authentication, Token, and API URLS for GitHub. If | |||
// using GitHub enterprise you should change these values before calling New. | |||
// | |||
// Examples: | |||
// github.AuthURL = "https://github.acme.com/login/oauth/authorize | |||
// github.TokenURL = "https://github.acme.com/login/oauth/access_token | |||
// github.ProfileURL = "https://github.acme.com/api/v3/user | |||
// github.EmailURL = "https://github.acme.com/api/v3/user/emails | |||
var ( | |||
AuthURL = "https://github.com/login/oauth/authorize" | |||
TokenURL = "https://github.com/login/oauth/access_token" | |||
ProfileURL = "https://api.github.com/user" | |||
EmailURL = "https://api.github.com/user/emails" | |||
) | |||
// New creates a new Github provider, and sets up important connection details. | |||
// You should always call `github.New` to get a new Provider. Never try to create | |||
// one manually. | |||
func New(clientKey, secret, callbackURL string, scopes ...string) *Provider { | |||
p := &Provider{ | |||
ClientKey: clientKey, | |||
Secret: secret, | |||
CallbackURL: callbackURL, | |||
providerName: "github", | |||
} | |||
p.config = newConfig(p, scopes) | |||
return p | |||
} | |||
// Provider is the implementation of `goth.Provider` for accessing Github. | |||
type Provider struct { | |||
ClientKey string | |||
Secret string | |||
CallbackURL string | |||
HTTPClient *http.Client | |||
config *oauth2.Config | |||
providerName string | |||
} | |||
// Name is the name used to retrieve this provider later. | |||
func (p *Provider) Name() string { | |||
return p.providerName | |||
} | |||
// SetName is to update the name of the provider (needed in case of multiple providers of 1 type) | |||
func (p *Provider) SetName(name string) { | |||
p.providerName = name | |||
} | |||
func (p *Provider) Client() *http.Client { | |||
return goth.HTTPClientWithFallBack(p.HTTPClient) | |||
} | |||
// Debug is a no-op for the github package. | |||
func (p *Provider) Debug(debug bool) {} | |||
// BeginAuth asks Github for an authentication end-point. | |||
func (p *Provider) BeginAuth(state string) (goth.Session, error) { | |||
url := p.config.AuthCodeURL(state) | |||
session := &Session{ | |||
AuthURL: url, | |||
} | |||
return session, nil | |||
} | |||
// FetchUser will go to Github and access basic information about the user. | |||
func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { | |||
sess := session.(*Session) | |||
user := goth.User{ | |||
AccessToken: sess.AccessToken, | |||
Provider: p.Name(), | |||
} | |||
if user.AccessToken == "" { | |||
// data is not yet retrieved since accessToken is still empty | |||
return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName) | |||
} | |||
response, err := p.Client().Get(ProfileURL + "?access_token=" + url.QueryEscape(sess.AccessToken)) | |||
if err != nil { | |||
return user, err | |||
} | |||
defer response.Body.Close() | |||
if response.StatusCode != http.StatusOK { | |||
return user, fmt.Errorf("GitHub API responded with a %d trying to fetch user information", response.StatusCode) | |||
} | |||
bits, err := ioutil.ReadAll(response.Body) | |||
if err != nil { | |||
return user, err | |||
} | |||
err = json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData) | |||
if err != nil { | |||
return user, err | |||
} | |||
err = userFromReader(bytes.NewReader(bits), &user) | |||
if err != nil { | |||
return user, err | |||
} | |||
if user.Email == "" { | |||
for _, scope := range p.config.Scopes { | |||
if strings.TrimSpace(scope) == "user" || strings.TrimSpace(scope) == "user:email" { | |||
user.Email, err = getPrivateMail(p, sess) | |||
if err != nil { | |||
return user, err | |||
} | |||
break | |||
} | |||
} | |||
} | |||
return user, err | |||
} | |||
func userFromReader(reader io.Reader, user *goth.User) error { | |||
u := struct { | |||
ID int `json:"id"` | |||
Email string `json:"email"` | |||
Bio string `json:"bio"` | |||
Name string `json:"name"` | |||
Login string `json:"login"` | |||
Picture string `json:"avatar_url"` | |||
Location string `json:"location"` | |||
}{} | |||
err := json.NewDecoder(reader).Decode(&u) | |||
if err != nil { | |||
return err | |||
} | |||
user.Name = u.Name | |||
user.NickName = u.Login | |||
user.Email = u.Email | |||
user.Description = u.Bio | |||
user.AvatarURL = u.Picture | |||
user.UserID = strconv.Itoa(u.ID) | |||
user.Location = u.Location | |||
return err | |||
} | |||
func getPrivateMail(p *Provider, sess *Session) (email string, err error) { | |||
response, err := p.Client().Get(EmailURL + "?access_token=" + url.QueryEscape(sess.AccessToken)) | |||
if err != nil { | |||
if response != nil { | |||
response.Body.Close() | |||
} | |||
return email, err | |||
} | |||
defer response.Body.Close() | |||
if response.StatusCode != http.StatusOK { | |||
return email, fmt.Errorf("GitHub API responded with a %d trying to fetch user email", response.StatusCode) | |||
} | |||
var mailList = []struct { | |||
Email string `json:"email"` | |||
Primary bool `json:"primary"` | |||
Verified bool `json:"verified"` | |||
}{} | |||
err = json.NewDecoder(response.Body).Decode(&mailList) | |||
if err != nil { | |||
return email, err | |||
} | |||
for _, v := range mailList { | |||
if v.Primary && v.Verified { | |||
return v.Email, nil | |||
} | |||
} | |||
// can't get primary email - shouldn't be possible | |||
return | |||
} | |||
func newConfig(provider *Provider, scopes []string) *oauth2.Config { | |||
c := &oauth2.Config{ | |||
ClientID: provider.ClientKey, | |||
ClientSecret: provider.Secret, | |||
RedirectURL: provider.CallbackURL, | |||
Endpoint: oauth2.Endpoint{ | |||
AuthURL: AuthURL, | |||
TokenURL: TokenURL, | |||
}, | |||
Scopes: []string{}, | |||
} | |||
for _, scope := range scopes { | |||
c.Scopes = append(c.Scopes, scope) | |||
} | |||
return c | |||
} | |||
//RefreshToken refresh token is not provided by github | |||
func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { | |||
return nil, errors.New("Refresh token is not provided by github") | |||
} | |||
//RefreshTokenAvailable refresh token is not provided by github | |||
func (p *Provider) RefreshTokenAvailable() bool { | |||
return false | |||
} |
@@ -0,0 +1,56 @@ | |||
package github | |||
import ( | |||
"encoding/json" | |||
"errors" | |||
"strings" | |||
"github.com/markbates/goth" | |||
) | |||
// Session stores data during the auth process with Github. | |||
type Session struct { | |||
AuthURL string | |||
AccessToken string | |||
} | |||
// GetAuthURL will return the URL set by calling the `BeginAuth` function on the Github provider. | |||
func (s Session) GetAuthURL() (string, error) { | |||
if s.AuthURL == "" { | |||
return "", errors.New(goth.NoAuthUrlErrorMessage) | |||
} | |||
return s.AuthURL, nil | |||
} | |||
// Authorize the session with Github and return the access token to be stored for future use. | |||
func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) { | |||
p := provider.(*Provider) | |||
token, err := p.config.Exchange(goth.ContextForClient(p.Client()), params.Get("code")) | |||
if err != nil { | |||
return "", err | |||
} | |||
if !token.Valid() { | |||
return "", errors.New("Invalid token received from provider") | |||
} | |||
s.AccessToken = token.AccessToken | |||
return token.AccessToken, err | |||
} | |||
// Marshal the session into a string | |||
func (s Session) Marshal() string { | |||
b, _ := json.Marshal(s) | |||
return string(b) | |||
} | |||
func (s Session) String() string { | |||
return s.Marshal() | |||
} | |||
// UnmarshalSession will unmarshal a JSON string into a session. | |||
func (p *Provider) UnmarshalSession(data string) (goth.Session, error) { | |||
sess := &Session{} | |||
err := json.NewDecoder(strings.NewReader(data)).Decode(sess) | |||
return sess, err | |||
} |
@@ -0,0 +1,21 @@ | |||
package goth | |||
// Params is used to pass data to sessions for authorization. An existing | |||
// implementation, and the one most likely to be used, is `url.Values`. | |||
type Params interface { | |||
Get(string) string | |||
} | |||
// Session needs to be implemented as part of the provider package. | |||
// It will be marshaled and persisted between requests to "tie" | |||
// the start and the end of the authorization process with a | |||
// 3rd party provider. | |||
type Session interface { | |||
// GetAuthURL returns the URL for the authentication end-point for the provider. | |||
GetAuthURL() (string, error) | |||
// Marshal generates a string representation of the Session for storing between requests. | |||
Marshal() string | |||
// Authorize should validate the data from the provider and return an access token | |||
// that can be stored for later access to the provider. | |||
Authorize(Provider, Params) (string, error) | |||
} |
@@ -0,0 +1,30 @@ | |||
package goth | |||
import ( | |||
"encoding/gob" | |||
"time" | |||
) | |||
func init() { | |||
gob.Register(User{}) | |||
} | |||
// User contains the information common amongst most OAuth and OAuth2 providers. | |||
// All of the "raw" datafrom the provider can be found in the `RawData` field. | |||
type User struct { | |||
RawData map[string]interface{} | |||
Provider string | |||
Email string | |||
Name string | |||
FirstName string | |||
LastName string | |||
NickName string | |||
Description string | |||
UserID string | |||
AvatarURL string | |||
Location string | |||
AccessToken string | |||
AccessTokenSecret string | |||
RefreshToken string | |||
ExpiresAt time.Time | |||
} |
@@ -0,0 +1,3 @@ | |||
# This source code refers to The Go Authors for copyright purposes. | |||
# The master list of authors is in the main Go distribution, | |||
# visible at http://tip.golang.org/AUTHORS. |
@@ -0,0 +1,31 @@ | |||
# Contributing to Go | |||
Go is an open source project. | |||
It is the work of hundreds of contributors. We appreciate your help! | |||
## Filing issues | |||
When [filing an issue](https://github.com/golang/oauth2/issues), make sure to answer these five questions: | |||
1. What version of Go are you using (`go version`)? | |||
2. What operating system and processor architecture are you using? | |||
3. What did you do? | |||
4. What did you expect to see? | |||
5. What did you see instead? | |||
General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker. | |||
The gophers there will answer or ask you to file an issue if you've tripped over a bug. | |||
## Contributing code | |||
Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html) | |||
before sending patches. | |||
**We do not accept GitHub pull requests** | |||
(we use [Gerrit](https://code.google.com/p/gerrit/) instead for code review). | |||
Unless otherwise noted, the Go source files are distributed under | |||
the BSD-style license found in the LICENSE file. | |||
@@ -0,0 +1,3 @@ | |||
# This source code was written by the Go contributors. | |||
# The master list of contributors is in the main Go distribution, | |||
# visible at http://tip.golang.org/CONTRIBUTORS. |
@@ -0,0 +1,27 @@ | |||
Copyright (c) 2009 The oauth2 Authors. All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
* Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
* Redistributions in binary form must reproduce the above | |||
copyright notice, this list of conditions and the following disclaimer | |||
in the documentation and/or other materials provided with the | |||
distribution. | |||
* Neither the name of Google Inc. nor the names of its | |||
contributors may be used to endorse or promote products derived from | |||
this software without specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
@@ -0,0 +1,65 @@ | |||
# OAuth2 for Go | |||
[![Build Status](https://travis-ci.org/golang/oauth2.svg?branch=master)](https://travis-ci.org/golang/oauth2) | |||
[![GoDoc](https://godoc.org/golang.org/x/oauth2?status.svg)](https://godoc.org/golang.org/x/oauth2) | |||
oauth2 package contains a client implementation for OAuth 2.0 spec. | |||
## Installation | |||
~~~~ | |||
go get golang.org/x/oauth2 | |||
~~~~ | |||
See godoc for further documentation and examples. | |||
* [godoc.org/golang.org/x/oauth2](http://godoc.org/golang.org/x/oauth2) | |||
* [godoc.org/golang.org/x/oauth2/google](http://godoc.org/golang.org/x/oauth2/google) | |||
## App Engine | |||
In change 96e89be (March 2015) we removed the `oauth2.Context2` type in favor | |||
of the [`context.Context`](https://golang.org/x/net/context#Context) type from | |||
the `golang.org/x/net/context` package | |||
This means its no longer possible to use the "Classic App Engine" | |||
`appengine.Context` type with the `oauth2` package. (You're using | |||
Classic App Engine if you import the package `"appengine"`.) | |||
To work around this, you may use the new `"google.golang.org/appengine"` | |||
package. This package has almost the same API as the `"appengine"` package, | |||
but it can be fetched with `go get` and used on "Managed VMs" and well as | |||
Classic App Engine. | |||
See the [new `appengine` package's readme](https://github.com/golang/appengine#updating-a-go-app-engine-app) | |||
for information on updating your app. | |||
If you don't want to update your entire app to use the new App Engine packages, | |||
you may use both sets of packages in parallel, using only the new packages | |||
with the `oauth2` package. | |||
import ( | |||
"golang.org/x/net/context" | |||
"golang.org/x/oauth2" | |||
"golang.org/x/oauth2/google" | |||
newappengine "google.golang.org/appengine" | |||
newurlfetch "google.golang.org/appengine/urlfetch" | |||
"appengine" | |||
) | |||
func handler(w http.ResponseWriter, r *http.Request) { | |||
var c appengine.Context = appengine.NewContext(r) | |||
c.Infof("Logging a message with the old package") | |||
var ctx context.Context = newappengine.NewContext(r) | |||
client := &http.Client{ | |||
Transport: &oauth2.Transport{ | |||
Source: google.AppEngineTokenSource(ctx, "scope"), | |||
Base: &newurlfetch.Transport{Context: ctx}, | |||
}, | |||
} | |||
client.Get("...") | |||
} | |||
@@ -0,0 +1,25 @@ | |||
// Copyright 2014 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// +build appengine | |||
// App Engine hooks. | |||
package oauth2 | |||
import ( | |||
"net/http" | |||
"golang.org/x/net/context" | |||
"golang.org/x/oauth2/internal" | |||
"google.golang.org/appengine/urlfetch" | |||
) | |||
func init() { | |||
internal.RegisterContextClientFunc(contextClientAppEngine) | |||
} | |||
func contextClientAppEngine(ctx context.Context) (*http.Client, error) { | |||
return urlfetch.Client(ctx), nil | |||
} |
@@ -0,0 +1,76 @@ | |||
// Copyright 2014 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// Package internal contains support packages for oauth2 package. | |||
package internal | |||
import ( | |||
"bufio" | |||
"crypto/rsa" | |||
"crypto/x509" | |||
"encoding/pem" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"strings" | |||
) | |||
// ParseKey converts the binary contents of a private key file | |||
// to an *rsa.PrivateKey. It detects whether the private key is in a | |||
// PEM container or not. If so, it extracts the the private key | |||
// from PEM container before conversion. It only supports PEM | |||
// containers with no passphrase. | |||
func ParseKey(key []byte) (*rsa.PrivateKey, error) { | |||
block, _ := pem.Decode(key) | |||
if block != nil { | |||
key = block.Bytes | |||
} | |||
parsedKey, err := x509.ParsePKCS8PrivateKey(key) | |||
if err != nil { | |||
parsedKey, err = x509.ParsePKCS1PrivateKey(key) | |||
if err != nil { | |||
return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err) | |||
} | |||
} | |||
parsed, ok := parsedKey.(*rsa.PrivateKey) | |||
if !ok { | |||
return nil, errors.New("private key is invalid") | |||
} | |||
return parsed, nil | |||
} | |||
func ParseINI(ini io.Reader) (map[string]map[string]string, error) { | |||
result := map[string]map[string]string{ | |||
"": map[string]string{}, // root section | |||
} | |||
scanner := bufio.NewScanner(ini) | |||
currentSection := "" | |||
for scanner.Scan() { | |||
line := strings.TrimSpace(scanner.Text()) | |||
if strings.HasPrefix(line, ";") { | |||
// comment. | |||
continue | |||
} | |||
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { | |||
currentSection = strings.TrimSpace(line[1 : len(line)-1]) | |||
result[currentSection] = map[string]string{} | |||
continue | |||
} | |||
parts := strings.SplitN(line, "=", 2) | |||
if len(parts) == 2 && parts[0] != "" { | |||
result[currentSection][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) | |||
} | |||
} | |||
if err := scanner.Err(); err != nil { | |||
return nil, fmt.Errorf("error scanning ini: %v", err) | |||
} | |||
return result, nil | |||
} | |||
func CondVal(v string) []string { | |||
if v == "" { | |||
return nil | |||
} | |||
return []string{v} | |||
} |
@@ -0,0 +1,227 @@ | |||
// Copyright 2014 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// Package internal contains support packages for oauth2 package. | |||
package internal | |||
import ( | |||
"encoding/json" | |||
"fmt" | |||
"io" | |||
"io/ioutil" | |||
"mime" | |||
"net/http" | |||
"net/url" | |||
"strconv" | |||
"strings" | |||
"time" | |||
"golang.org/x/net/context" | |||
) | |||
// Token represents the crendentials used to authorize | |||
// the requests to access protected resources on the OAuth 2.0 | |||
// provider's backend. | |||
// | |||
// This type is a mirror of oauth2.Token and exists to break | |||
// an otherwise-circular dependency. Other internal packages | |||
// should convert this Token into an oauth2.Token before use. | |||
type Token struct { | |||
// AccessToken is the token that authorizes and authenticates | |||
// the requests. | |||
AccessToken string | |||
// TokenType is the type of token. | |||
// The Type method returns either this or "Bearer", the default. | |||
TokenType string | |||
// RefreshToken is a token that's used by the application | |||
// (as opposed to the user) to refresh the access token | |||
// if it expires. | |||
RefreshToken string | |||
// Expiry is the optional expiration time of the access token. | |||
// | |||
// If zero, TokenSource implementations will reuse the same | |||
// token forever and RefreshToken or equivalent | |||
// mechanisms for that TokenSource will not be used. | |||
Expiry time.Time | |||
// Raw optionally contains extra metadata from the server | |||
// when updating a token. | |||
Raw interface{} | |||
} | |||
// tokenJSON is the struct representing the HTTP response from OAuth2 | |||
// providers returning a token in JSON form. | |||
type tokenJSON struct { | |||
AccessToken string `json:"access_token"` | |||
TokenType string `json:"token_type"` | |||
RefreshToken string `json:"refresh_token"` | |||
ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number | |||
Expires expirationTime `json:"expires"` // broken Facebook spelling of expires_in | |||
} | |||
func (e *tokenJSON) expiry() (t time.Time) { | |||
if v := e.ExpiresIn; v != 0 { | |||
return time.Now().Add(time.Duration(v) * time.Second) | |||
} | |||
if v := e.Expires; v != 0 { | |||
return time.Now().Add(time.Duration(v) * time.Second) | |||
} | |||
return | |||
} | |||
type expirationTime int32 | |||
func (e *expirationTime) UnmarshalJSON(b []byte) error { | |||
var n json.Number | |||
err := json.Unmarshal(b, &n) | |||
if err != nil { | |||
return err | |||
} | |||
i, err := n.Int64() | |||
if err != nil { | |||
return err | |||
} | |||
*e = expirationTime(i) | |||
return nil | |||
} | |||
var brokenAuthHeaderProviders = []string{ | |||
"https://accounts.google.com/", | |||
"https://api.dropbox.com/", | |||
"https://api.dropboxapi.com/", | |||
"https://api.instagram.com/", | |||
"https://api.netatmo.net/", | |||
"https://api.odnoklassniki.ru/", | |||
"https://api.pushbullet.com/", | |||
"https://api.soundcloud.com/", | |||
"https://api.twitch.tv/", | |||
"https://app.box.com/", | |||
"https://connect.stripe.com/", | |||
"https://login.microsoftonline.com/", | |||
"https://login.salesforce.com/", | |||
"https://oauth.sandbox.trainingpeaks.com/", | |||
"https://oauth.trainingpeaks.com/", | |||
"https://oauth.vk.com/", | |||
"https://openapi.baidu.com/", | |||
"https://slack.com/", | |||
"https://test-sandbox.auth.corp.google.com", | |||
"https://test.salesforce.com/", | |||
"https://user.gini.net/", | |||
"https://www.douban.com/", | |||
"https://www.googleapis.com/", | |||
"https://www.linkedin.com/", | |||
"https://www.strava.com/oauth/", | |||
"https://www.wunderlist.com/oauth/", | |||
"https://api.patreon.com/", | |||
"https://sandbox.codeswholesale.com/oauth/token", | |||
"https://api.codeswholesale.com/oauth/token", | |||
} | |||
func RegisterBrokenAuthHeaderProvider(tokenURL string) { | |||
brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL) | |||
} | |||
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL | |||
// implements the OAuth2 spec correctly | |||
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background. | |||
// In summary: | |||
// - Reddit only accepts client secret in the Authorization header | |||
// - Dropbox accepts either it in URL param or Auth header, but not both. | |||
// - Google only accepts URL param (not spec compliant?), not Auth header | |||
// - Stripe only accepts client secret in Auth header with Bearer method, not Basic | |||
func providerAuthHeaderWorks(tokenURL string) bool { | |||
for _, s := range brokenAuthHeaderProviders { | |||
if strings.HasPrefix(tokenURL, s) { | |||
// Some sites fail to implement the OAuth2 spec fully. | |||
return false | |||
} | |||
} | |||
// Assume the provider implements the spec properly | |||
// otherwise. We can add more exceptions as they're | |||
// discovered. We will _not_ be adding configurable hooks | |||
// to this package to let users select server bugs. | |||
return true | |||
} | |||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) { | |||
hc, err := ContextClient(ctx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
v.Set("client_id", clientID) | |||
bustedAuth := !providerAuthHeaderWorks(tokenURL) | |||
if bustedAuth && clientSecret != "" { | |||
v.Set("client_secret", clientSecret) | |||
} | |||
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode())) | |||
if err != nil { | |||
return nil, err | |||
} | |||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | |||
if !bustedAuth { | |||
req.SetBasicAuth(clientID, clientSecret) | |||
} | |||
r, err := hc.Do(req) | |||
if err != nil { | |||
return nil, err | |||
} | |||
defer r.Body.Close() | |||
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) | |||
if err != nil { | |||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) | |||
} | |||
if code := r.StatusCode; code < 200 || code > 299 { | |||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) | |||
} | |||
var token *Token | |||
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) | |||
switch content { | |||
case "application/x-www-form-urlencoded", "text/plain": | |||
vals, err := url.ParseQuery(string(body)) | |||
if err != nil { | |||
return nil, err | |||
} | |||
token = &Token{ | |||
AccessToken: vals.Get("access_token"), | |||
TokenType: vals.Get("token_type"), | |||
RefreshToken: vals.Get("refresh_token"), | |||
Raw: vals, | |||
} | |||
e := vals.Get("expires_in") | |||
if e == "" { | |||
// TODO(jbd): Facebook's OAuth2 implementation is broken and | |||
// returns expires_in field in expires. Remove the fallback to expires, | |||
// when Facebook fixes their implementation. | |||
e = vals.Get("expires") | |||
} | |||
expires, _ := strconv.Atoi(e) | |||
if expires != 0 { | |||
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) | |||
} | |||
default: | |||
var tj tokenJSON | |||
if err = json.Unmarshal(body, &tj); err != nil { | |||
return nil, err | |||
} | |||
token = &Token{ | |||
AccessToken: tj.AccessToken, | |||
TokenType: tj.TokenType, | |||
RefreshToken: tj.RefreshToken, | |||
Expiry: tj.expiry(), | |||
Raw: make(map[string]interface{}), | |||
} | |||
json.Unmarshal(body, &token.Raw) // no error checks for optional fields | |||
} | |||
// Don't overwrite `RefreshToken` with an empty value | |||
// if this was a token refreshing request. | |||
if token.RefreshToken == "" { | |||
token.RefreshToken = v.Get("refresh_token") | |||
} | |||
return token, nil | |||
} |
@@ -0,0 +1,69 @@ | |||
// Copyright 2014 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// Package internal contains support packages for oauth2 package. | |||
package internal | |||
import ( | |||
"net/http" | |||
"golang.org/x/net/context" | |||
) | |||
// HTTPClient is the context key to use with golang.org/x/net/context's | |||
// WithValue function to associate an *http.Client value with a context. | |||
var HTTPClient ContextKey | |||
// ContextKey is just an empty struct. It exists so HTTPClient can be | |||
// an immutable public variable with a unique type. It's immutable | |||
// because nobody else can create a ContextKey, being unexported. | |||
type ContextKey struct{} | |||
// ContextClientFunc is a func which tries to return an *http.Client | |||
// given a Context value. If it returns an error, the search stops | |||
// with that error. If it returns (nil, nil), the search continues | |||
// down the list of registered funcs. | |||
type ContextClientFunc func(context.Context) (*http.Client, error) | |||
var contextClientFuncs []ContextClientFunc | |||
func RegisterContextClientFunc(fn ContextClientFunc) { | |||
contextClientFuncs = append(contextClientFuncs, fn) | |||
} | |||
func ContextClient(ctx context.Context) (*http.Client, error) { | |||
if ctx != nil { | |||
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { | |||
return hc, nil | |||
} | |||
} | |||
for _, fn := range contextClientFuncs { | |||
c, err := fn(ctx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if c != nil { | |||
return c, nil | |||
} | |||
} | |||
return http.DefaultClient, nil | |||
} | |||
func ContextTransport(ctx context.Context) http.RoundTripper { | |||
hc, err := ContextClient(ctx) | |||
// This is a rare error case (somebody using nil on App Engine). | |||
if err != nil { | |||
return ErrorTransport{err} | |||
} | |||
return hc.Transport | |||
} | |||
// ErrorTransport returns the specified error on RoundTrip. | |||
// This RoundTripper should be used in rare error cases where | |||
// error handling can be postponed to response handling time. | |||
type ErrorTransport struct{ Err error } | |||
func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) { | |||
return nil, t.Err | |||
} |
@@ -0,0 +1,341 @@ | |||
// Copyright 2014 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// Package oauth2 provides support for making | |||
// OAuth2 authorized and authenticated HTTP requests. | |||
// It can additionally grant authorization with Bearer JWT. | |||
package oauth2 // import "golang.org/x/oauth2" | |||
import ( | |||
"bytes" | |||
"errors" | |||
"net/http" | |||
"net/url" | |||
"strings" | |||
"sync" | |||
"golang.org/x/net/context" | |||
"golang.org/x/oauth2/internal" | |||
) | |||
// NoContext is the default context you should supply if not using | |||
// your own context.Context (see https://golang.org/x/net/context). | |||
// | |||
// Deprecated: Use context.Background() or context.TODO() instead. | |||
var NoContext = context.TODO() | |||
// RegisterBrokenAuthHeaderProvider registers an OAuth2 server | |||
// identified by the tokenURL prefix as an OAuth2 implementation | |||
// which doesn't support the HTTP Basic authentication | |||
// scheme to authenticate with the authorization server. | |||
// Once a server is registered, credentials (client_id and client_secret) | |||
// will be passed as query parameters rather than being present | |||
// in the Authorization header. | |||
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background. | |||
func RegisterBrokenAuthHeaderProvider(tokenURL string) { | |||
internal.RegisterBrokenAuthHeaderProvider(tokenURL) | |||
} | |||
// Config describes a typical 3-legged OAuth2 flow, with both the | |||
// client application information and the server's endpoint URLs. | |||
// For the client credentials 2-legged OAuth2 flow, see the clientcredentials | |||
// package (https://golang.org/x/oauth2/clientcredentials). | |||
type Config struct { | |||
// ClientID is the application's ID. | |||
ClientID string | |||
// ClientSecret is the application's secret. | |||
ClientSecret string | |||
// Endpoint contains the resource server's token endpoint | |||
// URLs. These are constants specific to each server and are | |||
// often available via site-specific packages, such as | |||
// google.Endpoint or github.Endpoint. | |||
Endpoint Endpoint | |||
// RedirectURL is the URL to redirect users going through | |||
// the OAuth flow, after the resource owner's URLs. | |||
RedirectURL string | |||
// Scope specifies optional requested permissions. | |||
Scopes []string | |||
} | |||
// A TokenSource is anything that can return a token. | |||
type TokenSource interface { | |||
// Token returns a token or an error. | |||
// Token must be safe for concurrent use by multiple goroutines. | |||
// The returned Token must not be modified. | |||
Token() (*Token, error) | |||
} | |||
// Endpoint contains the OAuth 2.0 provider's authorization and token | |||
// endpoint URLs. | |||
type Endpoint struct { | |||
AuthURL string | |||
TokenURL string | |||
} | |||
var ( | |||
// AccessTypeOnline and AccessTypeOffline are options passed | |||
// to the Options.AuthCodeURL method. They modify the | |||
// "access_type" field that gets sent in the URL returned by | |||
// AuthCodeURL. | |||
// | |||
// Online is the default if neither is specified. If your | |||
// application needs to refresh access tokens when the user | |||
// is not present at the browser, then use offline. This will | |||
// result in your application obtaining a refresh token the | |||
// first time your application exchanges an authorization | |||
// code for a user. | |||
AccessTypeOnline AuthCodeOption = SetAuthURLParam("access_type", "online") | |||
AccessTypeOffline AuthCodeOption = SetAuthURLParam("access_type", "offline") | |||
// ApprovalForce forces the users to view the consent dialog | |||
// and confirm the permissions request at the URL returned | |||
// from AuthCodeURL, even if they've already done so. | |||
ApprovalForce AuthCodeOption = SetAuthURLParam("approval_prompt", "force") | |||
) | |||
// An AuthCodeOption is passed to Config.AuthCodeURL. | |||
type AuthCodeOption interface { | |||
setValue(url.Values) | |||
} | |||
type setParam struct{ k, v string } | |||
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) } | |||
// SetAuthURLParam builds an AuthCodeOption which passes key/value parameters | |||
// to a provider's authorization endpoint. | |||
func SetAuthURLParam(key, value string) AuthCodeOption { | |||
return setParam{key, value} | |||
} | |||
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page | |||
// that asks for permissions for the required scopes explicitly. | |||
// | |||
// State is a token to protect the user from CSRF attacks. You must | |||
// always provide a non-zero string and validate that it matches the | |||
// the state query parameter on your redirect callback. | |||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. | |||
// | |||
// Opts may include AccessTypeOnline or AccessTypeOffline, as well | |||
// as ApprovalForce. | |||
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { | |||
var buf bytes.Buffer | |||
buf.WriteString(c.Endpoint.AuthURL) | |||
v := url.Values{ | |||
"response_type": {"code"}, | |||
"client_id": {c.ClientID}, | |||
"redirect_uri": internal.CondVal(c.RedirectURL), | |||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), | |||
"state": internal.CondVal(state), | |||
} | |||
for _, opt := range opts { | |||
opt.setValue(v) | |||
} | |||
if strings.Contains(c.Endpoint.AuthURL, "?") { | |||
buf.WriteByte('&') | |||
} else { | |||
buf.WriteByte('?') | |||
} | |||
buf.WriteString(v.Encode()) | |||
return buf.String() | |||
} | |||
// PasswordCredentialsToken converts a resource owner username and password | |||
// pair into a token. | |||
// | |||
// Per the RFC, this grant type should only be used "when there is a high | |||
// degree of trust between the resource owner and the client (e.g., the client | |||
// is part of the device operating system or a highly privileged application), | |||
// and when other authorization grant types are not available." | |||
// See https://tools.ietf.org/html/rfc6749#section-4.3 for more info. | |||
// | |||
// The HTTP client to use is derived from the context. | |||
// If nil, http.DefaultClient is used. | |||
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) { | |||
return retrieveToken(ctx, c, url.Values{ | |||
"grant_type": {"password"}, | |||
"username": {username}, | |||
"password": {password}, | |||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), | |||
}) | |||
} | |||
// Exchange converts an authorization code into a token. | |||
// | |||
// It is used after a resource provider redirects the user back | |||
// to the Redirect URI (the URL obtained from AuthCodeURL). | |||
// | |||
// The HTTP client to use is derived from the context. | |||
// If a client is not provided via the context, http.DefaultClient is used. | |||
// | |||
// The code will be in the *http.Request.FormValue("code"). Before | |||
// calling Exchange, be sure to validate FormValue("state"). | |||
func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) { | |||
return retrieveToken(ctx, c, url.Values{ | |||
"grant_type": {"authorization_code"}, | |||
"code": {code}, | |||
"redirect_uri": internal.CondVal(c.RedirectURL), | |||
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), | |||
}) | |||
} | |||
// Client returns an HTTP client using the provided token. | |||
// The token will auto-refresh as necessary. The underlying | |||
// HTTP transport will be obtained using the provided context. | |||
// The returned client and its Transport should not be modified. | |||
func (c *Config) Client(ctx context.Context, t *Token) *http.Client { | |||
return NewClient(ctx, c.TokenSource(ctx, t)) | |||
} | |||
// TokenSource returns a TokenSource that returns t until t expires, | |||
// automatically refreshing it as necessary using the provided context. | |||
// | |||
// Most users will use Config.Client instead. | |||
func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { | |||
tkr := &tokenRefresher{ | |||
ctx: ctx, | |||
conf: c, | |||
} | |||
if t != nil { | |||
tkr.refreshToken = t.RefreshToken | |||
} | |||
return &reuseTokenSource{ | |||
t: t, | |||
new: tkr, | |||
} | |||
} | |||
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token" | |||
// HTTP requests to renew a token using a RefreshToken. | |||
type tokenRefresher struct { | |||
ctx context.Context // used to get HTTP requests | |||
conf *Config | |||
refreshToken string | |||
} | |||
// WARNING: Token is not safe for concurrent access, as it | |||
// updates the tokenRefresher's refreshToken field. | |||
// Within this package, it is used by reuseTokenSource which | |||
// synchronizes calls to this method with its own mutex. | |||
func (tf *tokenRefresher) Token() (*Token, error) { | |||
if tf.refreshToken == "" { | |||
return nil, errors.New("oauth2: token expired and refresh token is not set") | |||
} | |||
tk, err := retrieveToken(tf.ctx, tf.conf, url.Values{ | |||
"grant_type": {"refresh_token"}, | |||
"refresh_token": {tf.refreshToken}, | |||
}) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if tf.refreshToken != tk.RefreshToken { | |||
tf.refreshToken = tk.RefreshToken | |||
} | |||
return tk, err | |||
} | |||
// reuseTokenSource is a TokenSource that holds a single token in memory | |||
// and validates its expiry before each call to retrieve it with | |||
// Token. If it's expired, it will be auto-refreshed using the | |||
// new TokenSource. | |||
type reuseTokenSource struct { | |||
new TokenSource // called when t is expired. | |||
mu sync.Mutex // guards t | |||
t *Token | |||
} | |||
// Token returns the current token if it's still valid, else will | |||
// refresh the current token (using r.Context for HTTP client | |||
// information) and return the new one. | |||
func (s *reuseTokenSource) Token() (*Token, error) { | |||
s.mu.Lock() | |||
defer s.mu.Unlock() | |||
if s.t.Valid() { | |||
return s.t, nil | |||
} | |||
t, err := s.new.Token() | |||
if err != nil { | |||
return nil, err | |||
} | |||
s.t = t | |||
return t, nil | |||
} | |||
// StaticTokenSource returns a TokenSource that always returns the same token. | |||
// Because the provided token t is never refreshed, StaticTokenSource is only | |||
// useful for tokens that never expire. | |||
func StaticTokenSource(t *Token) TokenSource { | |||
return staticTokenSource{t} | |||
} | |||
// staticTokenSource is a TokenSource that always returns the same Token. | |||
type staticTokenSource struct { | |||
t *Token | |||
} | |||
func (s staticTokenSource) Token() (*Token, error) { | |||
return s.t, nil | |||
} | |||
// HTTPClient is the context key to use with golang.org/x/net/context's | |||
// WithValue function to associate an *http.Client value with a context. | |||
var HTTPClient internal.ContextKey | |||
// NewClient creates an *http.Client from a Context and TokenSource. | |||
// The returned client is not valid beyond the lifetime of the context. | |||
// | |||
// As a special case, if src is nil, a non-OAuth2 client is returned | |||
// using the provided context. This exists to support related OAuth2 | |||
// packages. | |||
func NewClient(ctx context.Context, src TokenSource) *http.Client { | |||
if src == nil { | |||
c, err := internal.ContextClient(ctx) | |||
if err != nil { | |||
return &http.Client{Transport: internal.ErrorTransport{Err: err}} | |||
} | |||
return c | |||
} | |||
return &http.Client{ | |||
Transport: &Transport{ | |||
Base: internal.ContextTransport(ctx), | |||
Source: ReuseTokenSource(nil, src), | |||
}, | |||
} | |||
} | |||
// ReuseTokenSource returns a TokenSource which repeatedly returns the | |||
// same token as long as it's valid, starting with t. | |||
// When its cached token is invalid, a new token is obtained from src. | |||
// | |||
// ReuseTokenSource is typically used to reuse tokens from a cache | |||
// (such as a file on disk) between runs of a program, rather than | |||
// obtaining new tokens unnecessarily. | |||
// | |||
// The initial token t may be nil, in which case the TokenSource is | |||
// wrapped in a caching version if it isn't one already. This also | |||
// means it's always safe to wrap ReuseTokenSource around any other | |||
// TokenSource without adverse effects. | |||
func ReuseTokenSource(t *Token, src TokenSource) TokenSource { | |||
// Don't wrap a reuseTokenSource in itself. That would work, | |||
// but cause an unnecessary number of mutex operations. | |||
// Just build the equivalent one. | |||
if rt, ok := src.(*reuseTokenSource); ok { | |||
if t == nil { | |||
// Just use it directly. | |||
return rt | |||
} | |||
src = rt.new | |||
} | |||
return &reuseTokenSource{ | |||
t: t, | |||
new: src, | |||
} | |||
} |
@@ -0,0 +1,158 @@ | |||
// Copyright 2014 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package oauth2 | |||
import ( | |||
"net/http" | |||
"net/url" | |||
"strconv" | |||
"strings" | |||
"time" | |||
"golang.org/x/net/context" | |||
"golang.org/x/oauth2/internal" | |||
) | |||
// expiryDelta determines how earlier a token should be considered | |||
// expired than its actual expiration time. It is used to avoid late | |||
// expirations due to client-server time mismatches. | |||
const expiryDelta = 10 * time.Second | |||
// Token represents the crendentials used to authorize | |||
// the requests to access protected resources on the OAuth 2.0 | |||
// provider's backend. | |||
// | |||
// Most users of this package should not access fields of Token | |||
// directly. They're exported mostly for use by related packages | |||
// implementing derivative OAuth2 flows. | |||
type Token struct { | |||
// AccessToken is the token that authorizes and authenticates | |||
// the requests. | |||
AccessToken string `json:"access_token"` | |||
// TokenType is the type of token. | |||
// The Type method returns either this or "Bearer", the default. | |||
TokenType string `json:"token_type,omitempty"` | |||
// RefreshToken is a token that's used by the application | |||
// (as opposed to the user) to refresh the access token | |||
// if it expires. | |||
RefreshToken string `json:"refresh_token,omitempty"` | |||
// Expiry is the optional expiration time of the access token. | |||
// | |||
// If zero, TokenSource implementations will reuse the same | |||
// token forever and RefreshToken or equivalent | |||
// mechanisms for that TokenSource will not be used. | |||
Expiry time.Time `json:"expiry,omitempty"` | |||
// raw optionally contains extra metadata from the server | |||
// when updating a token. | |||
raw interface{} | |||
} | |||
// Type returns t.TokenType if non-empty, else "Bearer". | |||
func (t *Token) Type() string { | |||
if strings.EqualFold(t.TokenType, "bearer") { | |||
return "Bearer" | |||
} | |||
if strings.EqualFold(t.TokenType, "mac") { | |||
return "MAC" | |||
} | |||
if strings.EqualFold(t.TokenType, "basic") { | |||
return "Basic" | |||
} | |||
if t.TokenType != "" { | |||
return t.TokenType | |||
} | |||
return "Bearer" | |||
} | |||
// SetAuthHeader sets the Authorization header to r using the access | |||
// token in t. | |||
// | |||
// This method is unnecessary when using Transport or an HTTP Client | |||
// returned by this package. | |||
func (t *Token) SetAuthHeader(r *http.Request) { | |||
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken) | |||
} | |||
// WithExtra returns a new Token that's a clone of t, but using the | |||
// provided raw extra map. This is only intended for use by packages | |||
// implementing derivative OAuth2 flows. | |||
func (t *Token) WithExtra(extra interface{}) *Token { | |||
t2 := new(Token) | |||
*t2 = *t | |||
t2.raw = extra | |||
return t2 | |||
} | |||
// Extra returns an extra field. | |||
// Extra fields are key-value pairs returned by the server as a | |||
// part of the token retrieval response. | |||
func (t *Token) Extra(key string) interface{} { | |||
if raw, ok := t.raw.(map[string]interface{}); ok { | |||
return raw[key] | |||
} | |||
vals, ok := t.raw.(url.Values) | |||
if !ok { | |||
return nil | |||
} | |||
v := vals.Get(key) | |||
switch s := strings.TrimSpace(v); strings.Count(s, ".") { | |||
case 0: // Contains no "."; try to parse as int | |||
if i, err := strconv.ParseInt(s, 10, 64); err == nil { | |||
return i | |||
} | |||
case 1: // Contains a single "."; try to parse as float | |||
if f, err := strconv.ParseFloat(s, 64); err == nil { | |||
return f | |||
} | |||
} | |||
return v | |||
} | |||
// expired reports whether the token is expired. | |||
// t must be non-nil. | |||
func (t *Token) expired() bool { | |||
if t.Expiry.IsZero() { | |||
return false | |||
} | |||
return t.Expiry.Add(-expiryDelta).Before(time.Now()) | |||
} | |||
// Valid reports whether t is non-nil, has an AccessToken, and is not expired. | |||
func (t *Token) Valid() bool { | |||
return t != nil && t.AccessToken != "" && !t.expired() | |||
} | |||
// tokenFromInternal maps an *internal.Token struct into | |||
// a *Token struct. | |||
func tokenFromInternal(t *internal.Token) *Token { | |||
if t == nil { | |||
return nil | |||
} | |||
return &Token{ | |||
AccessToken: t.AccessToken, | |||
TokenType: t.TokenType, | |||
RefreshToken: t.RefreshToken, | |||
Expiry: t.Expiry, | |||
raw: t.Raw, | |||
} | |||
} | |||
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token. | |||
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along | |||
// with an error.. | |||
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { | |||
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return tokenFromInternal(tk), nil | |||
} |
@@ -0,0 +1,132 @@ | |||
// Copyright 2014 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package oauth2 | |||
import ( | |||
"errors" | |||
"io" | |||
"net/http" | |||
"sync" | |||
) | |||
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests, | |||
// wrapping a base RoundTripper and adding an Authorization header | |||
// with a token from the supplied Sources. | |||
// | |||
// Transport is a low-level mechanism. Most code will use the | |||
// higher-level Config.Client method instead. | |||
type Transport struct { | |||
// Source supplies the token to add to outgoing requests' | |||
// Authorization headers. | |||
Source TokenSource | |||
// Base is the base RoundTripper used to make HTTP requests. | |||
// If nil, http.DefaultTransport is used. | |||
Base http.RoundTripper | |||
mu sync.Mutex // guards modReq | |||
modReq map[*http.Request]*http.Request // original -> modified | |||
} | |||
// RoundTrip authorizes and authenticates the request with an | |||
// access token. If no token exists or token is expired, | |||
// tries to refresh/fetch a new token. | |||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { | |||
if t.Source == nil { | |||
return nil, errors.New("oauth2: Transport's Source is nil") | |||
} | |||
token, err := t.Source.Token() | |||
if err != nil { | |||
return nil, err | |||
} | |||
req2 := cloneRequest(req) // per RoundTripper contract | |||
token.SetAuthHeader(req2) | |||
t.setModReq(req, req2) | |||
res, err := t.base().RoundTrip(req2) | |||
if err != nil { | |||
t.setModReq(req, nil) | |||
return nil, err | |||
} | |||
res.Body = &onEOFReader{ | |||
rc: res.Body, | |||
fn: func() { t.setModReq(req, nil) }, | |||
} | |||
return res, nil | |||
} | |||
// CancelRequest cancels an in-flight request by closing its connection. | |||
func (t *Transport) CancelRequest(req *http.Request) { | |||
type canceler interface { | |||
CancelRequest(*http.Request) | |||
} | |||
if cr, ok := t.base().(canceler); ok { | |||
t.mu.Lock() | |||
modReq := t.modReq[req] | |||
delete(t.modReq, req) | |||
t.mu.Unlock() | |||
cr.CancelRequest(modReq) | |||
} | |||
} | |||
func (t *Transport) base() http.RoundTripper { | |||
if t.Base != nil { | |||
return t.Base | |||
} | |||
return http.DefaultTransport | |||
} | |||
func (t *Transport) setModReq(orig, mod *http.Request) { | |||
t.mu.Lock() | |||
defer t.mu.Unlock() | |||
if t.modReq == nil { | |||
t.modReq = make(map[*http.Request]*http.Request) | |||
} | |||
if mod == nil { | |||
delete(t.modReq, orig) | |||
} else { | |||
t.modReq[orig] = mod | |||
} | |||
} | |||
// cloneRequest returns a clone of the provided *http.Request. | |||
// The clone is a shallow copy of the struct and its Header map. | |||
func cloneRequest(r *http.Request) *http.Request { | |||
// shallow copy of the struct | |||
r2 := new(http.Request) | |||
*r2 = *r | |||
// deep copy of the Header | |||
r2.Header = make(http.Header, len(r.Header)) | |||
for k, s := range r.Header { | |||
r2.Header[k] = append([]string(nil), s...) | |||
} | |||
return r2 | |||
} | |||
type onEOFReader struct { | |||
rc io.ReadCloser | |||
fn func() | |||
} | |||
func (r *onEOFReader) Read(p []byte) (n int, err error) { | |||
n, err = r.rc.Read(p) | |||
if err == io.EOF { | |||
r.runFunc() | |||
} | |||
return | |||
} | |||
func (r *onEOFReader) Close() error { | |||
err := r.rc.Close() | |||
r.runFunc() | |||
return err | |||
} | |||
func (r *onEOFReader) runFunc() { | |||
if fn := r.fn; fn != nil { | |||
fn() | |||
r.fn = nil | |||
} | |||
} |
@@ -550,6 +550,24 @@ | |||
"revision": "d8eeeb8bae8896dd8e1b7e514ab0d396c4f12a1b", | |||
"revisionTime": "2016-11-03T02:43:54Z" | |||
}, | |||
{ | |||
"checksumSHA1": "O3KUfEXQPfdQ+tCMpP2RAIRJJqY=", | |||
"path": "github.com/markbates/goth", | |||
"revision": "450379d2950a65070b23cc93c53436553add4484", | |||
"revisionTime": "2017-02-06T19:46:32Z" | |||
}, | |||
{ | |||
"checksumSHA1": "MkFKwLV3icyUo4oP0BgEs+7+R1Y=", | |||
"path": "github.com/markbates/goth/gothic", | |||
"revision": "450379d2950a65070b23cc93c53436553add4484", | |||
"revisionTime": "2017-02-06T19:46:32Z" | |||
}, | |||
{ | |||
"checksumSHA1": "ZFqznX3/ZW65I4QeepiHQdE69nA=", | |||
"path": "github.com/markbates/goth/providers/github", | |||
"revision": "450379d2950a65070b23cc93c53436553add4484", | |||
"revisionTime": "2017-02-06T19:46:32Z" | |||
}, | |||
{ | |||
"checksumSHA1": "9FJUwn3EIgASVki+p8IHgWVC5vQ=", | |||
"path": "github.com/mattn/go-sqlite3", |