diff options
author | Lunny Xiao <xiaolunwen@gmail.com> | 2021-01-26 23:36:53 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-01-26 16:36:53 +0100 |
commit | 6433ba0ec3dfde67f45267aa12bd713c4a44c740 (patch) | |
tree | 8813388f7e58ff23ad24af9ccbdb5f0350cb3a09 /modules/context | |
parent | 3adbbb4255c42cde04d59b6ebf5ead7e3edda3e7 (diff) | |
download | gitea-6433ba0ec3dfde67f45267aa12bd713c4a44c740.tar.gz gitea-6433ba0ec3dfde67f45267aa12bd713c4a44c740.zip |
Move macaron to chi (#14293)
Use [chi](https://github.com/go-chi/chi) instead of the forked [macaron](https://gitea.com/macaron/macaron). Since macaron and chi have conflicts with session share, this big PR becomes a have-to thing. According my previous idea, we can replace macaron step by step but I'm wrong. :( Below is a list of big changes on this PR.
- [x] Define `context.ResponseWriter` interface with an implementation `context.Response`.
- [x] Use chi instead of macaron, and also a customize `Route` to wrap chi so that the router usage is similar as before.
- [x] Create different routers for `web`, `api`, `internal` and `install` so that the codes will be more clear and no magic .
- [x] Use https://github.com/unrolled/render instead of macaron's internal render
- [x] Use https://github.com/NYTimes/gziphandler instead of https://gitea.com/macaron/gzip
- [x] Use https://gitea.com/go-chi/session which is a modified version of https://gitea.com/macaron/session and removed `nodb` support since it will not be maintained. **BREAK**
- [x] Use https://gitea.com/go-chi/captcha which is a modified version of https://gitea.com/macaron/captcha
- [x] Use https://gitea.com/go-chi/cache which is a modified version of https://gitea.com/macaron/cache
- [x] Use https://gitea.com/go-chi/binding which is a modified version of https://gitea.com/macaron/binding
- [x] Use https://github.com/go-chi/cors instead of https://gitea.com/macaron/cors
- [x] Dropped https://gitea.com/macaron/i18n and make a new one in `code.gitea.io/gitea/modules/translation`
- [x] Move validation form structs from `code.gitea.io/gitea/modules/auth` to `code.gitea.io/gitea/modules/forms` to avoid dependency cycle.
- [x] Removed macaron log service because it's not need any more. **BREAK**
- [x] All form structs have to be get by `web.GetForm(ctx)` in the route function but not as a function parameter on routes definition.
- [x] Move Git HTTP protocol implementation to use routers directly.
- [x] Fix the problem that chi routes don't support trailing slash but macaron did.
- [x] `/api/v1/swagger` now will be redirect to `/api/swagger` but not render directly so that `APIContext` will not create a html render.
Notices:
- Chi router don't support request with trailing slash
- Integration test `TestUserHeatmap` maybe mysql version related. It's failed on my macOS(mysql 5.7.29 installed via brew) but succeed on CI.
Co-authored-by: 6543 <6543@obermui.de>
Diffstat (limited to 'modules/context')
-rw-r--r-- | modules/context/api.go | 141 | ||||
-rw-r--r-- | modules/context/auth.go | 129 | ||||
-rw-r--r-- | modules/context/captcha.go | 26 | ||||
-rw-r--r-- | modules/context/context.go | 559 | ||||
-rw-r--r-- | modules/context/csrf.go | 265 | ||||
-rw-r--r-- | modules/context/form.go | 227 | ||||
-rw-r--r-- | modules/context/org.go | 4 | ||||
-rw-r--r-- | modules/context/permission.go | 12 | ||||
-rw-r--r-- | modules/context/private.go | 45 | ||||
-rw-r--r-- | modules/context/repo.go | 570 | ||||
-rw-r--r-- | modules/context/response.go | 29 | ||||
-rw-r--r-- | modules/context/secret.go | 100 | ||||
-rw-r--r-- | modules/context/xsrf.go | 98 | ||||
-rw-r--r-- | modules/context/xsrf_test.go | 90 |
14 files changed, 1800 insertions, 495 deletions
diff --git a/modules/context/api.go b/modules/context/api.go index 7771ec1b14..cf6dc265cd 100644 --- a/modules/context/api.go +++ b/modules/context/api.go @@ -6,18 +6,21 @@ package context import ( + "context" "fmt" + "html" "net/http" "net/url" "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/modules/auth/sso" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/middlewares" "code.gitea.io/gitea/modules/setting" - "gitea.com/macaron/csrf" - "gitea.com/macaron/macaron" + "gitea.com/go-chi/session" ) // APIContext is a specific macaron context for API service @@ -91,7 +94,7 @@ func (ctx *APIContext) Error(status int, title string, obj interface{}) { if status == http.StatusInternalServerError { log.ErrorWithSkip(1, "%s: %s", title, message) - if macaron.Env == macaron.PROD && !(ctx.User != nil && ctx.User.IsAdmin) { + if setting.IsProd() && !(ctx.User != nil && ctx.User.IsAdmin) { message = "" } } @@ -108,7 +111,7 @@ func (ctx *APIContext) InternalServerError(err error) { log.ErrorWithSkip(1, "InternalServerError: %v", err) var message string - if macaron.Env != macaron.PROD || (ctx.User != nil && ctx.User.IsAdmin) { + if !setting.IsProd() || (ctx.User != nil && ctx.User.IsAdmin) { message = err.Error() } @@ -118,6 +121,20 @@ func (ctx *APIContext) InternalServerError(err error) { }) } +var ( + apiContextKey interface{} = "default_api_context" +) + +// WithAPIContext set up api context in request +func WithAPIContext(req *http.Request, ctx *APIContext) *http.Request { + return req.WithContext(context.WithValue(req.Context(), apiContextKey, ctx)) +} + +// GetAPIContext returns a context for API routes +func GetAPIContext(req *http.Request) *APIContext { + return req.Context().Value(apiContextKey).(*APIContext) +} + func genAPILinks(curURL *url.URL, total, pageSize, curPage int) []string { page := NewPagination(total, pageSize, curPage, 0) paginater := page.Paginater @@ -172,7 +189,7 @@ func (ctx *APIContext) RequireCSRF() { headerToken := ctx.Req.Header.Get(ctx.csrf.GetHeaderName()) formValueToken := ctx.Req.FormValue(ctx.csrf.GetFormName()) if len(headerToken) > 0 || len(formValueToken) > 0 { - csrf.Validate(ctx.Context.Context, ctx.csrf) + Validate(ctx.Context, ctx.csrf) } else { ctx.Context.Error(401, "Missing CSRF token.") } @@ -201,42 +218,91 @@ func (ctx *APIContext) CheckForOTP() { } // APIContexter returns apicontext as macaron middleware -func APIContexter() macaron.Handler { - return func(c *Context) { - ctx := &APIContext{ - Context: c, - } - c.Map(ctx) +func APIContexter() func(http.Handler) http.Handler { + var csrfOpts = getCsrfOpts() + + return func(next http.Handler) http.Handler { + + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var locale = middlewares.Locale(w, req) + var ctx = APIContext{ + Context: &Context{ + Resp: NewResponse(w), + Data: map[string]interface{}{}, + Locale: locale, + Session: session.GetSession(req), + Repo: &Repository{ + PullRequest: &PullRequest{}, + }, + Org: &Organization{}, + }, + Org: &APIOrganization{}, + } + + ctx.Req = WithAPIContext(WithContext(req, ctx.Context), &ctx) + ctx.csrf = Csrfer(csrfOpts, ctx.Context) + + // If request sends files, parse them here otherwise the Query() can't be parsed and the CsrfToken will be invalid. + if ctx.Req.Method == "POST" && strings.Contains(ctx.Req.Header.Get("Content-Type"), "multipart/form-data") { + if err := ctx.Req.ParseMultipartForm(setting.Attachment.MaxSize << 20); err != nil && !strings.Contains(err.Error(), "EOF") { // 32MB max size + ctx.InternalServerError(err) + return + } + } + + // Get user from session if logged in. + ctx.User, ctx.IsBasicAuth = sso.SignedInUser(ctx.Req, ctx.Resp, &ctx, ctx.Session) + if ctx.User != nil { + ctx.IsSigned = true + ctx.Data["IsSigned"] = ctx.IsSigned + ctx.Data["SignedUser"] = ctx.User + ctx.Data["SignedUserID"] = ctx.User.ID + ctx.Data["SignedUserName"] = ctx.User.Name + ctx.Data["IsAdmin"] = ctx.User.IsAdmin + } else { + ctx.Data["SignedUserID"] = int64(0) + ctx.Data["SignedUserName"] = "" + } + + ctx.Resp.Header().Set(`X-Frame-Options`, `SAMEORIGIN`) + + ctx.Data["CsrfToken"] = html.EscapeString(ctx.csrf.GetToken()) + + next.ServeHTTP(ctx.Resp, ctx.Req) + }) } } // ReferencesGitRepo injects the GitRepo into the Context -func ReferencesGitRepo(allowEmpty bool) macaron.Handler { - return func(ctx *APIContext) { - // Empty repository does not have reference information. - if !allowEmpty && ctx.Repo.Repository.IsEmpty { - return - } - - // For API calls. - if ctx.Repo.GitRepo == nil { - repoPath := models.RepoPath(ctx.Repo.Owner.Name, ctx.Repo.Repository.Name) - gitRepo, err := git.OpenRepository(repoPath) - if err != nil { - ctx.Error(500, "RepoRef Invalid repo "+repoPath, err) +func ReferencesGitRepo(allowEmpty bool) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx := GetAPIContext(req) + // Empty repository does not have reference information. + if !allowEmpty && ctx.Repo.Repository.IsEmpty { return } - ctx.Repo.GitRepo = gitRepo - // We opened it, we should close it - defer func() { - // If it's been set to nil then assume someone else has closed it. - if ctx.Repo.GitRepo != nil { - ctx.Repo.GitRepo.Close() + + // For API calls. + if ctx.Repo.GitRepo == nil { + repoPath := models.RepoPath(ctx.Repo.Owner.Name, ctx.Repo.Repository.Name) + gitRepo, err := git.OpenRepository(repoPath) + if err != nil { + ctx.Error(500, "RepoRef Invalid repo "+repoPath, err) + return } - }() - } + ctx.Repo.GitRepo = gitRepo + // We opened it, we should close it + defer func() { + // If it's been set to nil then assume someone else has closed it. + if ctx.Repo.GitRepo != nil { + ctx.Repo.GitRepo.Close() + } + }() + } - ctx.Next() + next.ServeHTTP(w, req) + }) } } @@ -266,8 +332,9 @@ func (ctx *APIContext) NotFound(objs ...interface{}) { } // RepoRefForAPI handles repository reference names when the ref name is not explicitly given -func RepoRefForAPI() macaron.Handler { - return func(ctx *APIContext) { +func RepoRefForAPI(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx := GetAPIContext(req) // Empty repository does not have reference information. if ctx.Repo.Repository.IsEmpty { return @@ -319,6 +386,6 @@ func RepoRefForAPI() macaron.Handler { return } - ctx.Next() - } + next.ServeHTTP(w, req) + }) } diff --git a/modules/context/auth.go b/modules/context/auth.go index 02248384e1..8be6ed1907 100644 --- a/modules/context/auth.go +++ b/modules/context/auth.go @@ -7,12 +7,8 @@ package context import ( "code.gitea.io/gitea/models" - "code.gitea.io/gitea/modules/auth" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" - - "gitea.com/macaron/csrf" - "gitea.com/macaron/macaron" ) // ToggleOptions contains required or check options @@ -24,42 +20,23 @@ type ToggleOptions struct { } // Toggle returns toggle options as middleware -func Toggle(options *ToggleOptions) macaron.Handler { +func Toggle(options *ToggleOptions) func(ctx *Context) { return func(ctx *Context) { - isAPIPath := auth.IsAPIPath(ctx.Req.URL.Path) - // Check prohibit login users. if ctx.IsSigned { if !ctx.User.IsActive && setting.Service.RegisterEmailConfirm { ctx.Data["Title"] = ctx.Tr("auth.active_your_account") - if isAPIPath { - ctx.JSON(403, map[string]string{ - "message": "This account is not activated.", - }) - return - } ctx.HTML(200, "user/auth/activate") return - } else if !ctx.User.IsActive || ctx.User.ProhibitLogin { + } + if !ctx.User.IsActive || ctx.User.ProhibitLogin { log.Info("Failed authentication attempt for %s from %s", ctx.User.Name, ctx.RemoteAddr()) ctx.Data["Title"] = ctx.Tr("auth.prohibit_login") - if isAPIPath { - ctx.JSON(403, map[string]string{ - "message": "This account is prohibited from signing in, please contact your site administrator.", - }) - return - } ctx.HTML(200, "user/auth/prohibit_login") return } if ctx.User.MustChangePassword { - if isAPIPath { - ctx.JSON(403, map[string]string{ - "message": "You must change your password. Change it at: " + setting.AppURL + "/user/change_password", - }) - return - } if ctx.Req.URL.Path != "/user/settings/change_password" { ctx.Data["Title"] = ctx.Tr("auth.must_change_password") ctx.Data["ChangePasscodeLink"] = setting.AppSubURL + "/user/change_password" @@ -82,8 +59,8 @@ func Toggle(options *ToggleOptions) macaron.Handler { return } - if !options.SignOutRequired && !options.DisableCSRF && ctx.Req.Method == "POST" && !auth.IsAPIPath(ctx.Req.URL.Path) { - csrf.Validate(ctx.Context, ctx.csrf) + if !options.SignOutRequired && !options.DisableCSRF && ctx.Req.Method == "POST" { + Validate(ctx, ctx.csrf) if ctx.Written() { return } @@ -91,13 +68,6 @@ func Toggle(options *ToggleOptions) macaron.Handler { if options.SignInRequired { if !ctx.IsSigned { - // Restrict API calls with error message. - if isAPIPath { - ctx.JSON(403, map[string]string{ - "message": "Only signed in user is allowed to call APIs.", - }) - return - } if ctx.Req.URL.Path != "/user/events" { ctx.SetCookie("redirect_to", setting.AppSubURL+ctx.Req.URL.RequestURI(), 0, setting.AppSubURL) } @@ -108,19 +78,88 @@ func Toggle(options *ToggleOptions) macaron.Handler { ctx.HTML(200, "user/auth/activate") return } - if ctx.IsSigned && isAPIPath && ctx.IsBasicAuth { + } + + // Redirect to log in page if auto-signin info is provided and has not signed in. + if !options.SignOutRequired && !ctx.IsSigned && + len(ctx.GetCookie(setting.CookieUserName)) > 0 { + if ctx.Req.URL.Path != "/user/events" { + ctx.SetCookie("redirect_to", setting.AppSubURL+ctx.Req.URL.RequestURI(), 0, setting.AppSubURL) + } + ctx.Redirect(setting.AppSubURL + "/user/login") + return + } + + if options.AdminRequired { + if !ctx.User.IsAdmin { + ctx.Error(403) + return + } + ctx.Data["PageIsAdmin"] = true + } + } +} + +// ToggleAPI returns toggle options as middleware +func ToggleAPI(options *ToggleOptions) func(ctx *APIContext) { + return func(ctx *APIContext) { + // Check prohibit login users. + if ctx.IsSigned { + if !ctx.User.IsActive && setting.Service.RegisterEmailConfirm { + ctx.Data["Title"] = ctx.Tr("auth.active_your_account") + ctx.JSON(403, map[string]string{ + "message": "This account is not activated.", + }) + return + } + if !ctx.User.IsActive || ctx.User.ProhibitLogin { + log.Info("Failed authentication attempt for %s from %s", ctx.User.Name, ctx.RemoteAddr()) + ctx.Data["Title"] = ctx.Tr("auth.prohibit_login") + ctx.JSON(403, map[string]string{ + "message": "This account is prohibited from signing in, please contact your site administrator.", + }) + return + } + + if ctx.User.MustChangePassword { + ctx.JSON(403, map[string]string{ + "message": "You must change your password. Change it at: " + setting.AppURL + "/user/change_password", + }) + return + } + } + + // Redirect to dashboard if user tries to visit any non-login page. + if options.SignOutRequired && ctx.IsSigned && ctx.Req.URL.RequestURI() != "/" { + ctx.Redirect(setting.AppSubURL + "/") + return + } + + if options.SignInRequired { + if !ctx.IsSigned { + // Restrict API calls with error message. + ctx.JSON(403, map[string]string{ + "message": "Only signed in user is allowed to call APIs.", + }) + return + } else if !ctx.User.IsActive && setting.Service.RegisterEmailConfirm { + ctx.Data["Title"] = ctx.Tr("auth.active_your_account") + ctx.HTML(200, "user/auth/activate") + return + } + if ctx.IsSigned && ctx.IsBasicAuth { twofa, err := models.GetTwoFactorByUID(ctx.User.ID) if err != nil { if models.IsErrTwoFactorNotEnrolled(err) { return // No 2FA enrollment for this user } - ctx.Error(500) + ctx.InternalServerError(err) return } otpHeader := ctx.Req.Header.Get("X-Gitea-OTP") ok, err := twofa.ValidateTOTP(otpHeader) if err != nil { - ctx.Error(500) + ctx.InternalServerError(err) return } if !ok { @@ -132,19 +171,11 @@ func Toggle(options *ToggleOptions) macaron.Handler { } } - // Redirect to log in page if auto-signin info is provided and has not signed in. - if !options.SignOutRequired && !ctx.IsSigned && !isAPIPath && - len(ctx.GetCookie(setting.CookieUserName)) > 0 { - if ctx.Req.URL.Path != "/user/events" { - ctx.SetCookie("redirect_to", setting.AppSubURL+ctx.Req.URL.RequestURI(), 0, setting.AppSubURL) - } - ctx.Redirect(setting.AppSubURL + "/user/login") - return - } - if options.AdminRequired { if !ctx.User.IsAdmin { - ctx.Error(403) + ctx.JSON(403, map[string]string{ + "message": "You have no permission to request for this.", + }) return } ctx.Data["PageIsAdmin"] = true diff --git a/modules/context/captcha.go b/modules/context/captcha.go new file mode 100644 index 0000000000..956380ed73 --- /dev/null +++ b/modules/context/captcha.go @@ -0,0 +1,26 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package context + +import ( + "sync" + + "code.gitea.io/gitea/modules/setting" + + "gitea.com/go-chi/captcha" +) + +var imageCaptchaOnce sync.Once +var cpt *captcha.Captcha + +// GetImageCaptcha returns global image captcha +func GetImageCaptcha() *captcha.Captcha { + imageCaptchaOnce.Do(func() { + cpt = captcha.NewCaptcha(captcha.Options{ + SubURL: setting.AppSubURL, + }) + }) + return cpt +} diff --git a/modules/context/context.go b/modules/context/context.go index e4121649ae..630129b8c1 100644 --- a/modules/context/context.go +++ b/modules/context/context.go @@ -6,37 +6,55 @@ package context import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "html" "html/template" "io" "net/http" "net/url" "path" + "strconv" "strings" "time" "code.gitea.io/gitea/models" - "code.gitea.io/gitea/modules/auth" "code.gitea.io/gitea/modules/auth/sso" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/middlewares" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/templates" + "code.gitea.io/gitea/modules/translation" "code.gitea.io/gitea/modules/util" - "gitea.com/macaron/cache" - "gitea.com/macaron/csrf" - "gitea.com/macaron/i18n" - "gitea.com/macaron/macaron" - "gitea.com/macaron/session" + "gitea.com/go-chi/cache" + "gitea.com/go-chi/session" + "github.com/go-chi/chi" "github.com/unknwon/com" + "github.com/unknwon/i18n" + "github.com/unrolled/render" + "golang.org/x/crypto/pbkdf2" ) +// Render represents a template render +type Render interface { + TemplateLookup(tmpl string) *template.Template + HTML(w io.Writer, status int, name string, binding interface{}, htmlOpt ...render.HTMLOptions) error +} + // Context represents context of a request. type Context struct { - *macaron.Context + Resp ResponseWriter + Req *http.Request + Data map[string]interface{} + Render Render + translation.Locale Cache cache.Cache - csrf csrf.CSRF - Flash *session.Flash + csrf CSRF + Flash *middlewares.Flash Session session.Store Link string // current request URL @@ -163,13 +181,22 @@ func (ctx *Context) RedirectToFirst(location ...string) { // HTML calls Context.HTML and converts template name to string. func (ctx *Context) HTML(status int, name base.TplName) { log.Debug("Template: %s", name) - ctx.Context.HTML(status, string(name)) + if err := ctx.Render.HTML(ctx.Resp, status, string(name), ctx.Data); err != nil { + ctx.ServerError("Render failed", err) + } +} + +// HTMLString render content to a string but not http.ResponseWriter +func (ctx *Context) HTMLString(name string, data interface{}) (string, error) { + var buf strings.Builder + err := ctx.Render.HTML(&buf, 200, string(name), data) + return buf.String(), err } // RenderWithErr used for page has form validation but need to prompt error to users. func (ctx *Context) RenderWithErr(msg string, tpl base.TplName, form interface{}) { if form != nil { - auth.AssignForm(form, ctx.Data) + middlewares.AssignForm(form, ctx.Data) } ctx.Flash.ErrorMsg = msg ctx.Data["Flash"] = ctx.Flash @@ -184,7 +211,7 @@ func (ctx *Context) NotFound(title string, err error) { func (ctx *Context) notFoundInternal(title string, err error) { if err != nil { log.ErrorWithSkip(2, "%s: %v", title, err) - if macaron.Env != macaron.PROD { + if !setting.IsProd() { ctx.Data["ErrorMsg"] = err } } @@ -203,7 +230,7 @@ func (ctx *Context) ServerError(title string, err error) { func (ctx *Context) serverErrorInternal(title string, err error) { if err != nil { log.ErrorWithSkip(2, "%s: %v", title, err) - if macaron.Env != macaron.PROD { + if !setting.IsProd() { ctx.Data["ErrorMsg"] = err } } @@ -224,6 +251,44 @@ func (ctx *Context) NotFoundOrServerError(title string, errck func(error) bool, ctx.serverErrorInternal(title, err) } +// Header returns a header +func (ctx *Context) Header() http.Header { + return ctx.Resp.Header() +} + +// FIXME: We should differ Query and Form, currently we just use form as query +// Currently to be compatible with macaron, we keep it. + +// Query returns request form as string with default +func (ctx *Context) Query(key string, defaults ...string) string { + return (*Forms)(ctx.Req).MustString(key, defaults...) +} + +// QueryTrim returns request form as string with default and trimmed spaces +func (ctx *Context) QueryTrim(key string, defaults ...string) string { + return (*Forms)(ctx.Req).MustTrimmed(key, defaults...) +} + +// QueryStrings returns request form as strings with default +func (ctx *Context) QueryStrings(key string, defaults ...[]string) []string { + return (*Forms)(ctx.Req).MustStrings(key, defaults...) +} + +// QueryInt returns request form as int with default +func (ctx *Context) QueryInt(key string, defaults ...int) int { + return (*Forms)(ctx.Req).MustInt(key, defaults...) +} + +// QueryInt64 returns request form as int64 with default +func (ctx *Context) QueryInt64(key string, defaults ...int64) int64 { + return (*Forms)(ctx.Req).MustInt64(key, defaults...) +} + +// QueryBool returns request form as bool with default +func (ctx *Context) QueryBool(key string, defaults ...bool) bool { + return (*Forms)(ctx.Req).MustBool(key, defaults...) +} + // HandleText handles HTTP status code func (ctx *Context) HandleText(status int, title string) { if (status/100 == 4) || (status/100 == 5) { @@ -249,66 +314,324 @@ func (ctx *Context) ServeContent(name string, r io.ReadSeeker, params ...interfa ctx.Resp.Header().Set("Cache-Control", "must-revalidate") ctx.Resp.Header().Set("Pragma", "public") ctx.Resp.Header().Set("Access-Control-Expose-Headers", "Content-Disposition") - http.ServeContent(ctx.Resp, ctx.Req.Request, name, modtime, r) + http.ServeContent(ctx.Resp, ctx.Req, name, modtime, r) +} + +// PlainText render content as plain text +func (ctx *Context) PlainText(status int, bs []byte) { + ctx.Resp.WriteHeader(status) + ctx.Resp.Header().Set("Content-Type", "text/plain;charset=utf8") + if _, err := ctx.Resp.Write(bs); err != nil { + ctx.ServerError("Render JSON failed", err) + } +} + +// ServeFile serves given file to response. +func (ctx *Context) ServeFile(file string, names ...string) { + var name string + if len(names) > 0 { + name = names[0] + } else { + name = path.Base(file) + } + ctx.Resp.Header().Set("Content-Description", "File Transfer") + ctx.Resp.Header().Set("Content-Type", "application/octet-stream") + ctx.Resp.Header().Set("Content-Disposition", "attachment; filename="+name) + ctx.Resp.Header().Set("Content-Transfer-Encoding", "binary") + ctx.Resp.Header().Set("Expires", "0") + ctx.Resp.Header().Set("Cache-Control", "must-revalidate") + ctx.Resp.Header().Set("Pragma", "public") + http.ServeFile(ctx.Resp, ctx.Req, file) +} + +// Error returned an error to web browser +func (ctx *Context) Error(status int, contents ...string) { + var v = http.StatusText(status) + if len(contents) > 0 { + v = contents[0] + } + http.Error(ctx.Resp, v, status) +} + +// JSON render content as JSON +func (ctx *Context) JSON(status int, content interface{}) { + ctx.Resp.WriteHeader(status) + ctx.Resp.Header().Set("Content-Type", "application/json;charset=utf8") + if err := json.NewEncoder(ctx.Resp).Encode(content); err != nil { + ctx.ServerError("Render JSON failed", err) + } +} + +// Redirect redirect the request +func (ctx *Context) Redirect(location string, status ...int) { + code := http.StatusFound + if len(status) == 1 { + code = status[0] + } + + http.Redirect(ctx.Resp, ctx.Req, location, code) +} + +// SetCookie set cookies to web browser +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + middlewares.SetCookie(ctx.Resp, name, value, others...) +} + +// GetCookie returns given cookie value from request header. +func (ctx *Context) GetCookie(name string) string { + return middlewares.GetCookie(ctx.Req, name) +} + +// GetSuperSecureCookie returns given cookie value from request header with secret string. +func (ctx *Context) GetSuperSecureCookie(secret, name string) (string, bool) { + val := ctx.GetCookie(name) + if val == "" { + return "", false + } + + text, err := hex.DecodeString(val) + if err != nil { + return "", false + } + + key := pbkdf2.Key([]byte(secret), []byte(secret), 1000, 16, sha256.New) + text, err = com.AESGCMDecrypt(key, text) + return string(text), err == nil +} + +// SetSuperSecureCookie sets given cookie value to response header with secret string. +func (ctx *Context) SetSuperSecureCookie(secret, name, value string, others ...interface{}) { + key := pbkdf2.Key([]byte(secret), []byte(secret), 1000, 16, sha256.New) + text, err := com.AESGCMEncrypt(key, []byte(value)) + if err != nil { + panic("error encrypting cookie: " + err.Error()) + } + + ctx.SetCookie(name, hex.EncodeToString(text), others...) +} + +// GetCookieInt returns cookie result in int type. +func (ctx *Context) GetCookieInt(name string) int { + r, _ := strconv.Atoi(ctx.GetCookie(name)) + return r +} + +// GetCookieInt64 returns cookie result in int64 type. +func (ctx *Context) GetCookieInt64(name string) int64 { + r, _ := strconv.ParseInt(ctx.GetCookie(name), 10, 64) + return r +} + +// GetCookieFloat64 returns cookie result in float64 type. +func (ctx *Context) GetCookieFloat64(name string) float64 { + v, _ := strconv.ParseFloat(ctx.GetCookie(name), 64) + return v +} + +// RemoteAddr returns the client machie ip address +func (ctx *Context) RemoteAddr() string { + return ctx.Req.RemoteAddr +} + +// Params returns the param on route +func (ctx *Context) Params(p string) string { + s, _ := url.PathUnescape(chi.URLParam(ctx.Req, strings.TrimPrefix(p, ":"))) + return s +} + +// ParamsInt64 returns the param on route as int64 +func (ctx *Context) ParamsInt64(p string) int64 { + v, _ := strconv.ParseInt(ctx.Params(p), 10, 64) + return v +} + +// SetParams set params into routes +func (ctx *Context) SetParams(k, v string) { + chiCtx := chi.RouteContext(ctx.Req.Context()) + chiCtx.URLParams.Add(strings.TrimPrefix(k, ":"), url.PathEscape(v)) +} + +// Write writes data to webbrowser +func (ctx *Context) Write(bs []byte) (int, error) { + return ctx.Resp.Write(bs) +} + +// Written returns true if there are something sent to web browser +func (ctx *Context) Written() bool { + return ctx.Resp.Status() > 0 +} + +// Status writes status code +func (ctx *Context) Status(status int) { + ctx.Resp.WriteHeader(status) +} + +// Handler represents a custom handler +type Handler func(*Context) + +// enumerate all content +var ( + contextKey interface{} = "default_context" +) + +// WithContext set up install context in request +func WithContext(req *http.Request, ctx *Context) *http.Request { + return req.WithContext(context.WithValue(req.Context(), contextKey, ctx)) +} + +// GetContext retrieves install context from request +func GetContext(req *http.Request) *Context { + return req.Context().Value(contextKey).(*Context) +} + +func getCsrfOpts() CsrfOptions { + return CsrfOptions{ + Secret: setting.SecretKey, + Cookie: setting.CSRFCookieName, + SetCookie: true, + Secure: setting.SessionConfig.Secure, + CookieHTTPOnly: setting.CSRFCookieHTTPOnly, + Header: "X-Csrf-Token", + CookieDomain: setting.SessionConfig.Domain, + CookiePath: setting.SessionConfig.CookiePath, + } } // Contexter initializes a classic context for a request. -func Contexter() macaron.Handler { - return func(c *macaron.Context, l i18n.Locale, cache cache.Cache, sess session.Store, f *session.Flash, x csrf.CSRF) { - ctx := &Context{ - Context: c, - Cache: cache, - csrf: x, - Flash: f, - Session: sess, - Link: setting.AppSubURL + strings.TrimSuffix(c.Req.URL.EscapedPath(), "/"), - Repo: &Repository{ - PullRequest: &PullRequest{}, - }, - Org: &Organization{}, +func Contexter() func(next http.Handler) http.Handler { + rnd := templates.HTMLRenderer() + + var c cache.Cache + var err error + if setting.CacheService.Enabled { + c, err = cache.NewCacher(cache.Options{ + Adapter: setting.CacheService.Adapter, + AdapterConfig: setting.CacheService.Conn, + Interval: setting.CacheService.Interval, + }) + if err != nil { + panic(err) } - ctx.Data["Language"] = ctx.Locale.Language() - c.Data["Link"] = ctx.Link - ctx.Data["CurrentURL"] = setting.AppSubURL + c.Req.URL.RequestURI() - ctx.Data["PageStartTime"] = time.Now() - // Quick responses appropriate go-get meta with status 200 - // regardless of if user have access to the repository, - // or the repository does not exist at all. - // This is particular a workaround for "go get" command which does not respect - // .netrc file. - if ctx.Query("go-get") == "1" { - ownerName := c.Params(":username") - repoName := c.Params(":reponame") - trimmedRepoName := strings.TrimSuffix(repoName, ".git") - - if ownerName == "" || trimmedRepoName == "" { - _, _ = c.Write([]byte(`<!doctype html> + } + + var csrfOpts = getCsrfOpts() + //var flashEncryptionKey, _ = NewSecret() + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + var locale = middlewares.Locale(resp, req) + var startTime = time.Now() + var link = setting.AppSubURL + strings.TrimSuffix(req.URL.EscapedPath(), "/") + var ctx = Context{ + Resp: NewResponse(resp), + Cache: c, + Locale: locale, + Link: link, + Render: rnd, + Session: session.GetSession(req), + Repo: &Repository{ + PullRequest: &PullRequest{}, + }, + Org: &Organization{}, + Data: map[string]interface{}{ + "CurrentURL": setting.AppSubURL + req.URL.RequestURI(), + "PageStartTime": startTime, + "TmplLoadTimes": func() string { + return time.Since(startTime).String() + }, + "Link": link, + }, + } + + ctx.Req = WithContext(req, &ctx) + ctx.csrf = Csrfer(csrfOpts, &ctx) + + // Get flash. + flashCookie := ctx.GetCookie("macaron_flash") + vals, _ := url.ParseQuery(flashCookie) + if len(vals) > 0 { + f := &middlewares.Flash{ + DataStore: &ctx, + Values: vals, + ErrorMsg: vals.Get("error"), + SuccessMsg: vals.Get("success"), + InfoMsg: vals.Get("info"), + WarningMsg: vals.Get("warning"), + } + ctx.Data["Flash"] = f + } + + f := &middlewares.Flash{ + DataStore: &ctx, + Values: url.Values{}, + ErrorMsg: "", + WarningMsg: "", + InfoMsg: "", + SuccessMsg: "", + } + ctx.Resp.Before(func(resp ResponseWriter) { + if flash := f.Encode(); len(flash) > 0 { + if err == nil { + middlewares.SetCookie(resp, "macaron_flash", flash, 0, + setting.SessionConfig.CookiePath, + middlewares.Domain(setting.SessionConfig.Domain), + middlewares.HTTPOnly(true), + middlewares.Secure(setting.SessionConfig.Secure), + //middlewares.SameSite(opt.SameSite), FIXME: we need a samesite config + ) + return + } + } + + ctx.SetCookie("macaron_flash", "", -1, + setting.SessionConfig.CookiePath, + middlewares.Domain(setting.SessionConfig.Domain), + middlewares.HTTPOnly(true), + middlewares.Secure(setting.SessionConfig.Secure), + //middlewares.SameSite(), FIXME: we need a samesite config + ) + }) + + ctx.Flash = f + + // Quick responses appropriate go-get meta with status 200 + // regardless of if user have access to the repository, + // or the repository does not exist at all. + // This is particular a workaround for "go get" command which does not respect + // .netrc file. + if ctx.Query("go-get") == "1" { + ownerName := ctx.Params(":username") + repoName := ctx.Params(":reponame") + trimmedRepoName := strings.TrimSuffix(repoName, ".git") + + if ownerName == "" || trimmedRepoName == "" { + _, _ = ctx.Write([]byte(`<!doctype html> <html> <body> invalid import path </body> </html> `)) - c.WriteHeader(400) - return - } - branchName := "master" - - repo, err := models.GetRepositoryByOwnerAndName(ownerName, repoName) - if err == nil && len(repo.DefaultBranch) > 0 { - branchName = repo.DefaultBranch - } - prefix := setting.AppURL + path.Join(url.PathEscape(ownerName), url.PathEscape(repoName), "src", "branch", util.PathEscapeSegments(branchName)) - - appURL, _ := url.Parse(setting.AppURL) - - insecure := "" - if appURL.Scheme == string(setting.HTTP) { - insecure = "--insecure " - } - c.Header().Set("Content-Type", "text/html") - c.WriteHeader(http.StatusOK) - _, _ = c.Write([]byte(com.Expand(`<!doctype html> + ctx.Status(400) + return + } + branchName := "master" + + repo, err := models.GetRepositoryByOwnerAndName(ownerName, repoName) + if err == nil && len(repo.DefaultBranch) > 0 { + branchName = repo.DefaultBranch + } + prefix := setting.AppURL + path.Join(url.PathEscape(ownerName), url.PathEscape(repoName), "src", "branch", util.PathEscapeSegments(branchName)) + + appURL, _ := url.Parse(setting.AppURL) + + insecure := "" + if appURL.Scheme == string(setting.HTTP) { + insecure = "--insecure " + } + ctx.Header().Set("Content-Type", "text/html") + ctx.Status(http.StatusOK) + _, _ = ctx.Write([]byte(com.Expand(`<!doctype html> <html> <head> <meta name="go-import" content="{GoGetImport} git {CloneLink}"> @@ -319,60 +642,72 @@ func Contexter() macaron.Handler { </body> </html> `, map[string]string{ - "GoGetImport": ComposeGoGetImport(ownerName, trimmedRepoName), - "CloneLink": models.ComposeHTTPSCloneURL(ownerName, repoName), - "GoDocDirectory": prefix + "{/dir}", - "GoDocFile": prefix + "{/dir}/{file}#L{line}", - "Insecure": insecure, - }))) - return - } - - // Get user from session if logged in. - ctx.User, ctx.IsBasicAuth = sso.SignedInUser(ctx.Req.Request, c.Resp, ctx, ctx.Session) - - if ctx.User != nil { - ctx.IsSigned = true - ctx.Data["IsSigned"] = ctx.IsSigned - ctx.Data["SignedUser"] = ctx.User - ctx.Data["SignedUserID"] = ctx.User.ID - ctx.Data["SignedUserName"] = ctx.User.Name - ctx.Data["IsAdmin"] = ctx.User.IsAdmin - } else { - ctx.Data["SignedUserID"] = int64(0) - ctx.Data["SignedUserName"] = "" - } - - // If request sends files, parse them here otherwise the Query() can't be parsed and the CsrfToken will be invalid. - if ctx.Req.Method == "POST" && strings.Contains(ctx.Req.Header.Get("Content-Type"), "multipart/form-data") { - if err := ctx.Req.ParseMultipartForm(setting.Attachment.MaxSize << 20); err != nil && !strings.Contains(err.Error(), "EOF") { // 32MB max size - ctx.ServerError("ParseMultipartForm", err) + "GoGetImport": ComposeGoGetImport(ownerName, trimmedRepoName), + "CloneLink": models.ComposeHTTPSCloneURL(ownerName, repoName), + "GoDocDirectory": prefix + "{/dir}", + "GoDocFile": prefix + "{/dir}/{file}#L{line}", + "Insecure": insecure, + }))) return } - } - - ctx.Resp.Header().Set(`X-Frame-Options`, `SAMEORIGIN`) - - ctx.Data["CsrfToken"] = html.EscapeString(x.GetToken()) - ctx.Data["CsrfTokenHtml"] = template.HTML(`<input type="hidden" name="_csrf" value="` + ctx.Data["CsrfToken"].(string) + `">`) - log.Debug("Session ID: %s", sess.ID()) - log.Debug("CSRF Token: %v", ctx.Data["CsrfToken"]) - - ctx.Data["IsLandingPageHome"] = setting.LandingPageURL == setting.LandingPageHome - ctx.Data["IsLandingPageExplore"] = setting.LandingPageURL == setting.LandingPageExplore - ctx.Data["IsLandingPageOrganizations"] = setting.LandingPageURL == setting.LandingPageOrganizations - ctx.Data["ShowRegistrationButton"] = setting.Service.ShowRegistrationButton - ctx.Data["ShowMilestonesDashboardPage"] = setting.Service.ShowMilestonesDashboardPage - ctx.Data["ShowFooterBranding"] = setting.ShowFooterBranding - ctx.Data["ShowFooterVersion"] = setting.ShowFooterVersion + // If request sends files, parse them here otherwise the Query() can't be parsed and the CsrfToken will be invalid. + if ctx.Req.Method == "POST" && strings.Contains(ctx.Req.Header.Get("Content-Type"), "multipart/form-data") { + if err := ctx.Req.ParseMultipartForm(setting.Attachment.MaxSize << 20); err != nil && !strings.Contains(err.Error(), "EOF") { // 32MB max size + ctx.ServerError("ParseMultipartForm", err) + return + } + } - ctx.Data["EnableSwagger"] = setting.API.EnableSwagger - ctx.Data["EnableOpenIDSignIn"] = setting.Service.EnableOpenIDSignIn - ctx.Data["DisableMigrations"] = setting.Repository.DisableMigrations + // Get user from session if logged in. + ctx.User, ctx.IsBasicAuth = sso.SignedInUser(ctx.Req, ctx.Resp, &ctx, ctx.Session) + + if ctx.User != nil { + ctx.IsSigned = true + ctx.Data["IsSigned"] = ctx.IsSigned + ctx.Data["SignedUser"] = ctx.User + ctx.Data["SignedUserID"] = ctx.User.ID + ctx.Data["SignedUserName"] = ctx.User.Name + ctx.Data["IsAdmin"] = ctx.User.IsAdmin + } else { + ctx.Data["SignedUserID"] = int64(0) + ctx.Data["SignedUserName"] = "" + } - ctx.Data["ManifestData"] = setting.ManifestData + ctx.Resp.Header().Set(`X-Frame-Options`, `SAMEORIGIN`) + + ctx.Data["CsrfToken"] = html.EscapeString(ctx.csrf.GetToken()) + ctx.Data["CsrfTokenHtml"] = template.HTML(`<input type="hidden" name="_csrf" value="` + ctx.Data["CsrfToken"].(string) + `">`) + log.Debug("Session ID: %s", ctx.Session.ID()) + log.Debug("CSRF Token: %v", ctx.Data["CsrfToken"]) + + ctx.Data["IsLandingPageHome"] = setting.LandingPageURL == setting.LandingPageHome + ctx.Data["IsLandingPageExplore"] = setting.LandingPageURL == setting.LandingPageExplore + ctx.Data["IsLandingPageOrganizations"] = setting.LandingPageURL == setting.LandingPageOrganizations + + ctx.Data["ShowRegistrationButton"] = setting.Service.ShowRegistrationButton + ctx.Data["ShowMilestonesDashboardPage"] = setting.Service.ShowMilestonesDashboardPage + ctx.Data["ShowFooterBranding"] = setting.ShowFooterBranding + ctx.Data["ShowFooterVersion"] = setting.ShowFooterVersion + + ctx.Data["EnableSwagger"] = setting.API.EnableSwagger + ctx.Data["EnableOpenIDSignIn"] = setting.Service.EnableOpenIDSignIn + ctx.Data["DisableMigrations"] = setting.Repository.DisableMigrations + + ctx.Data["ManifestData"] = setting.ManifestData + + ctx.Data["i18n"] = locale + ctx.Data["Tr"] = i18n.Tr + ctx.Data["Lang"] = locale.Language() + ctx.Data["AllLangs"] = translation.AllLangs() + for _, lang := range translation.AllLangs() { + if lang.Lang == locale.Language() { + ctx.Data["LangName"] = lang.Name + break + } + } - c.Map(ctx) + next.ServeHTTP(ctx.Resp, ctx.Req) + }) } } diff --git a/modules/context/csrf.go b/modules/context/csrf.go new file mode 100644 index 0000000000..4a26664bf3 --- /dev/null +++ b/modules/context/csrf.go @@ -0,0 +1,265 @@ +// Copyright 2013 Martini Authors +// Copyright 2014 The Macaron Authors +// Copyright 2021 The Gitea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +// a middleware that generates and validates CSRF tokens. + +package context + +import ( + "net/http" + "time" + + "github.com/unknwon/com" +) + +// CSRF represents a CSRF service and is used to get the current token and validate a suspect token. +type CSRF interface { + // Return HTTP header to search for token. + GetHeaderName() string + // Return form value to search for token. + GetFormName() string + // Return cookie name to search for token. + GetCookieName() string + // Return cookie path + GetCookiePath() string + // Return the flag value used for the csrf token. + GetCookieHTTPOnly() bool + // Return the token. + GetToken() string + // Validate by token. + ValidToken(t string) bool + // Error replies to the request with a custom function when ValidToken fails. + Error(w http.ResponseWriter) +} + +type csrf struct { + // Header name value for setting and getting csrf token. + Header string + // Form name value for setting and getting csrf token. + Form string + // Cookie name value for setting and getting csrf token. + Cookie string + //Cookie domain + CookieDomain string + //Cookie path + CookiePath string + // Cookie HttpOnly flag value used for the csrf token. + CookieHTTPOnly bool + // Token generated to pass via header, cookie, or hidden form value. + Token string + // This value must be unique per user. + ID string + // Secret used along with the unique id above to generate the Token. + Secret string + // ErrorFunc is the custom function that replies to the request when ValidToken fails. + ErrorFunc func(w http.ResponseWriter) +} + +// GetHeaderName returns the name of the HTTP header for csrf token. +func (c *csrf) GetHeaderName() string { + return c.Header +} + +// GetFormName returns the name of the form value for csrf token. +func (c *csrf) GetFormName() string { + return c.Form +} + +// GetCookieName returns the name of the cookie for csrf token. +func (c *csrf) GetCookieName() string { + return c.Cookie +} + +// GetCookiePath returns the path of the cookie for csrf token. +func (c *csrf) GetCookiePath() string { + return c.CookiePath +} + +// GetCookieHTTPOnly returns the flag value used for the csrf token. +func (c *csrf) GetCookieHTTPOnly() bool { + return c.CookieHTTPOnly +} + +// GetToken returns the current token. This is typically used +// to populate a hidden form in an HTML template. +func (c *csrf) GetToken() string { + return c.Token +} + +// ValidToken validates the passed token against the existing Secret and ID. +func (c *csrf) ValidToken(t string) bool { + return ValidToken(t, c.Secret, c.ID, "POST") +} + +// Error replies to the request when ValidToken fails. +func (c *csrf) Error(w http.ResponseWriter) { + c.ErrorFunc(w) +} + +// CsrfOptions maintains options to manage behavior of Generate. +type CsrfOptions struct { + // The global secret value used to generate Tokens. + Secret string + // HTTP header used to set and get token. + Header string + // Form value used to set and get token. + Form string + // Cookie value used to set and get token. + Cookie string + // Cookie domain. + CookieDomain string + // Cookie path. + CookiePath string + CookieHTTPOnly bool + // SameSite set the cookie SameSite type + SameSite http.SameSite + // Key used for getting the unique ID per user. + SessionKey string + // oldSessionKey saves old value corresponding to SessionKey. + oldSessionKey string + // If true, send token via X-CSRFToken header. + SetHeader bool + // If true, send token via _csrf cookie. + SetCookie bool + // Set the Secure flag to true on the cookie. + Secure bool + // Disallow Origin appear in request header. + Origin bool + // The function called when Validate fails. + ErrorFunc func(w http.ResponseWriter) + // Cookie life time. Default is 0 + CookieLifeTime int +} + +func prepareOptions(options []CsrfOptions) CsrfOptions { + var opt CsrfOptions + if len(options) > 0 { + opt = options[0] + } + + // Defaults. + if len(opt.Secret) == 0 { + opt.Secret = string(com.RandomCreateBytes(10)) + } + if len(opt.Header) == 0 { + opt.Header = "X-CSRFToken" + } + if len(opt.Form) == 0 { + opt.Form = "_csrf" + } + if len(opt.Cookie) == 0 { + opt.Cookie = "_csrf" + } + if len(opt.CookiePath) == 0 { + opt.CookiePath = "/" + } + if len(opt.SessionKey) == 0 { + opt.SessionKey = "uid" + } + opt.oldSessionKey = "_old_" + opt.SessionKey + if opt.ErrorFunc == nil { + opt.ErrorFunc = func(w http.ResponseWriter) { + http.Error(w, "Invalid csrf token.", http.StatusBadRequest) + } + } + + return opt +} + +// Csrfer maps CSRF to each request. If this request is a Get request, it will generate a new token. +// Additionally, depending on options set, generated tokens will be sent via Header and/or Cookie. +func Csrfer(opt CsrfOptions, ctx *Context) CSRF { + opt = prepareOptions([]CsrfOptions{opt}) + x := &csrf{ + Secret: opt.Secret, + Header: opt.Header, + Form: opt.Form, + Cookie: opt.Cookie, + CookieDomain: opt.CookieDomain, + CookiePath: opt.CookiePath, + CookieHTTPOnly: opt.CookieHTTPOnly, + ErrorFunc: opt.ErrorFunc, + } + + if opt.Origin && len(ctx.Req.Header.Get("Origin")) > 0 { + return x + } + + x.ID = "0" + uid := ctx.Session.Get(opt.SessionKey) + if uid != nil { + x.ID = com.ToStr(uid) + } + + needsNew := false + oldUID := ctx.Session.Get(opt.oldSessionKey) + if oldUID == nil || oldUID.(string) != x.ID { + needsNew = true + _ = ctx.Session.Set(opt.oldSessionKey, x.ID) + } else { + // If cookie present, map existing token, else generate a new one. + if val := ctx.GetCookie(opt.Cookie); len(val) > 0 { + // FIXME: test coverage. + x.Token = val + } else { + needsNew = true + } + } + + if needsNew { + // FIXME: actionId. + x.Token = GenerateToken(x.Secret, x.ID, "POST") + if opt.SetCookie { + var expires interface{} + if opt.CookieLifeTime == 0 { + expires = time.Now().AddDate(0, 0, 1) + } + ctx.SetCookie(opt.Cookie, x.Token, opt.CookieLifeTime, opt.CookiePath, opt.CookieDomain, opt.Secure, opt.CookieHTTPOnly, expires, + func(c *http.Cookie) { + c.SameSite = opt.SameSite + }, + ) + } + } + + if opt.SetHeader { + ctx.Resp.Header().Add(opt.Header, x.Token) + } + return x +} + +// Validate should be used as a per route middleware. It attempts to get a token from a "X-CSRFToken" +// HTTP header and then a "_csrf" form value. If one of these is found, the token will be validated +// using ValidToken. If this validation fails, custom Error is sent in the reply. +// If neither a header or form value is found, http.StatusBadRequest is sent. +func Validate(ctx *Context, x CSRF) { + if token := ctx.Req.Header.Get(x.GetHeaderName()); len(token) > 0 { + if !x.ValidToken(token) { + ctx.SetCookie(x.GetCookieName(), "", -1, x.GetCookiePath()) + x.Error(ctx.Resp) + } + return + } + if token := ctx.Req.FormValue(x.GetFormName()); len(token) > 0 { + if !x.ValidToken(token) { + ctx.SetCookie(x.GetCookieName(), "", -1, x.GetCookiePath()) + x.Error(ctx.Resp) + } + return + } + + http.Error(ctx.Resp, "Bad Request: no CSRF token present", http.StatusBadRequest) +} diff --git a/modules/context/form.go b/modules/context/form.go new file mode 100644 index 0000000000..c7b76c614c --- /dev/null +++ b/modules/context/form.go @@ -0,0 +1,227 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package context + +import ( + "errors" + "net/http" + "net/url" + "strconv" + "strings" + "text/template" + + "code.gitea.io/gitea/modules/log" +) + +// Forms a new enhancement of http.Request +type Forms http.Request + +// Values returns http.Request values +func (f *Forms) Values() url.Values { + return (*http.Request)(f).Form +} + +// String returns request form as string +func (f *Forms) String(key string) (string, error) { + return (*http.Request)(f).FormValue(key), nil +} + +// Trimmed returns request form as string with trimed spaces left and right +func (f *Forms) Trimmed(key string) (string, error) { + return strings.TrimSpace((*http.Request)(f).FormValue(key)), nil +} + +// Strings returns request form as strings +func (f *Forms) Strings(key string) ([]string, error) { + if (*http.Request)(f).Form == nil { + if err := (*http.Request)(f).ParseMultipartForm(32 << 20); err != nil { + return nil, err + } + } + if v, ok := (*http.Request)(f).Form[key]; ok { + return v, nil + } + return nil, errors.New("not exist") +} + +// Escape returns request form as escaped string +func (f *Forms) Escape(key string) (string, error) { + return template.HTMLEscapeString((*http.Request)(f).FormValue(key)), nil +} + +// Int returns request form as int +func (f *Forms) Int(key string) (int, error) { + return strconv.Atoi((*http.Request)(f).FormValue(key)) +} + +// Int32 returns request form as int32 +func (f *Forms) Int32(key string) (int32, error) { + v, err := strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 32) + return int32(v), err +} + +// Int64 returns request form as int64 +func (f *Forms) Int64(key string) (int64, error) { + return strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 64) +} + +// Uint returns request form as uint +func (f *Forms) Uint(key string) (uint, error) { + v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) + return uint(v), err +} + +// Uint32 returns request form as uint32 +func (f *Forms) Uint32(key string) (uint32, error) { + v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 32) + return uint32(v), err +} + +// Uint64 returns request form as uint64 +func (f *Forms) Uint64(key string) (uint64, error) { + return strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) +} + +// Bool returns request form as bool +func (f *Forms) Bool(key string) (bool, error) { + return strconv.ParseBool((*http.Request)(f).FormValue(key)) +} + +// Float32 returns request form as float32 +func (f *Forms) Float32(key string) (float32, error) { + v, err := strconv.ParseFloat((*http.Request)(f).FormValue(key), 64) + return float32(v), err +} + +// Float64 returns request form as float64 +func (f *Forms) Float64(key string) (float64, error) { + return strconv.ParseFloat((*http.Request)(f).FormValue(key), 64) +} + +// MustString returns request form as string with default +func (f *Forms) MustString(key string, defaults ...string) string { + if v := (*http.Request)(f).FormValue(key); len(v) > 0 { + return v + } + if len(defaults) > 0 { + return defaults[0] + } + return "" +} + +// MustTrimmed returns request form as string with default +func (f *Forms) MustTrimmed(key string, defaults ...string) string { + return strings.TrimSpace(f.MustString(key, defaults...)) +} + +// MustStrings returns request form as strings with default +func (f *Forms) MustStrings(key string, defaults ...[]string) []string { + if (*http.Request)(f).Form == nil { + if err := (*http.Request)(f).ParseMultipartForm(32 << 20); err != nil { + log.Error("ParseMultipartForm: %v", err) + return []string{} + } + } + + if v, ok := (*http.Request)(f).Form[key]; ok { + return v + } + if len(defaults) > 0 { + return defaults[0] + } + return []string{} +} + +// MustEscape returns request form as escaped string with default +func (f *Forms) MustEscape(key string, defaults ...string) string { + if v := (*http.Request)(f).FormValue(key); len(v) > 0 { + return template.HTMLEscapeString(v) + } + if len(defaults) > 0 { + return defaults[0] + } + return "" +} + +// MustInt returns request form as int with default +func (f *Forms) MustInt(key string, defaults ...int) int { + v, err := strconv.Atoi((*http.Request)(f).FormValue(key)) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +// MustInt32 returns request form as int32 with default +func (f *Forms) MustInt32(key string, defaults ...int32) int32 { + v, err := strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return int32(v) +} + +// MustInt64 returns request form as int64 with default +func (f *Forms) MustInt64(key string, defaults ...int64) int64 { + v, err := strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +// MustUint returns request form as uint with default +func (f *Forms) MustUint(key string, defaults ...uint) uint { + v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return uint(v) +} + +// MustUint32 returns request form as uint32 with default +func (f *Forms) MustUint32(key string, defaults ...uint32) uint32 { + v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return uint32(v) +} + +// MustUint64 returns request form as uint64 with default +func (f *Forms) MustUint64(key string, defaults ...uint64) uint64 { + v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +// MustFloat32 returns request form as float32 with default +func (f *Forms) MustFloat32(key string, defaults ...float32) float32 { + v, err := strconv.ParseFloat((*http.Request)(f).FormValue(key), 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return float32(v) +} + +// MustFloat64 returns request form as float64 with default +func (f *Forms) MustFloat64(key string, defaults ...float64) float64 { + v, err := strconv.ParseFloat((*http.Request)(f).FormValue(key), 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +// MustBool returns request form as bool with default +func (f *Forms) MustBool(key string, defaults ...bool) bool { + v, err := strconv.ParseBool((*http.Request)(f).FormValue(key)) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} diff --git a/modules/context/org.go b/modules/context/org.go index f61a39c666..83d385a1e9 100644 --- a/modules/context/org.go +++ b/modules/context/org.go @@ -10,8 +10,6 @@ import ( "code.gitea.io/gitea/models" "code.gitea.io/gitea/modules/setting" - - "gitea.com/macaron/macaron" ) // Organization contains organization context @@ -173,7 +171,7 @@ func HandleOrgAssignment(ctx *Context, args ...bool) { } // OrgAssignment returns a macaron middleware to handle organization assignment -func OrgAssignment(args ...bool) macaron.Handler { +func OrgAssignment(args ...bool) func(ctx *Context) { return func(ctx *Context) { HandleOrgAssignment(ctx, args...) } diff --git a/modules/context/permission.go b/modules/context/permission.go index 151be9f832..6fb8237e22 100644 --- a/modules/context/permission.go +++ b/modules/context/permission.go @@ -7,12 +7,10 @@ package context import ( "code.gitea.io/gitea/models" "code.gitea.io/gitea/modules/log" - - "gitea.com/macaron/macaron" ) // RequireRepoAdmin returns a macaron middleware for requiring repository admin permission -func RequireRepoAdmin() macaron.Handler { +func RequireRepoAdmin() func(ctx *Context) { return func(ctx *Context) { if !ctx.IsSigned || !ctx.Repo.IsAdmin() { ctx.NotFound(ctx.Req.URL.RequestURI(), nil) @@ -22,7 +20,7 @@ func RequireRepoAdmin() macaron.Handler { } // RequireRepoWriter returns a macaron middleware for requiring repository write to the specify unitType -func RequireRepoWriter(unitType models.UnitType) macaron.Handler { +func RequireRepoWriter(unitType models.UnitType) func(ctx *Context) { return func(ctx *Context) { if !ctx.Repo.CanWrite(unitType) { ctx.NotFound(ctx.Req.URL.RequestURI(), nil) @@ -32,7 +30,7 @@ func RequireRepoWriter(unitType models.UnitType) macaron.Handler { } // RequireRepoWriterOr returns a macaron middleware for requiring repository write to one of the unit permission -func RequireRepoWriterOr(unitTypes ...models.UnitType) macaron.Handler { +func RequireRepoWriterOr(unitTypes ...models.UnitType) func(ctx *Context) { return func(ctx *Context) { for _, unitType := range unitTypes { if ctx.Repo.CanWrite(unitType) { @@ -44,7 +42,7 @@ func RequireRepoWriterOr(unitTypes ...models.UnitType) macaron.Handler { } // RequireRepoReader returns a macaron middleware for requiring repository read to the specify unitType -func RequireRepoReader(unitType models.UnitType) macaron.Handler { +func RequireRepoReader(unitType models.UnitType) func(ctx *Context) { return func(ctx *Context) { if !ctx.Repo.CanRead(unitType) { if log.IsTrace() { @@ -70,7 +68,7 @@ func RequireRepoReader(unitType models.UnitType) macaron.Handler { } // RequireRepoReaderOr returns a macaron middleware for requiring repository write to one of the unit permission -func RequireRepoReaderOr(unitTypes ...models.UnitType) macaron.Handler { +func RequireRepoReaderOr(unitTypes ...models.UnitType) func(ctx *Context) { return func(ctx *Context) { for _, unitType := range unitTypes { if ctx.Repo.CanRead(unitType) { diff --git a/modules/context/private.go b/modules/context/private.go new file mode 100644 index 0000000000..a246100050 --- /dev/null +++ b/modules/context/private.go @@ -0,0 +1,45 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package context + +import ( + "context" + "net/http" +) + +// PrivateContext represents a context for private routes +type PrivateContext struct { + *Context +} + +var ( + privateContextKey interface{} = "default_private_context" +) + +// WithPrivateContext set up private context in request +func WithPrivateContext(req *http.Request, ctx *PrivateContext) *http.Request { + return req.WithContext(context.WithValue(req.Context(), privateContextKey, ctx)) +} + +// GetPrivateContext returns a context for Private routes +func GetPrivateContext(req *http.Request) *PrivateContext { + return req.Context().Value(privateContextKey).(*PrivateContext) +} + +// PrivateContexter returns apicontext as macaron middleware +func PrivateContexter() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx := &PrivateContext{ + Context: &Context{ + Resp: NewResponse(w), + Data: map[string]interface{}{}, + }, + } + ctx.Req = WithPrivateContext(req, ctx) + next.ServeHTTP(ctx.Resp, ctx.Req) + }) + } +} diff --git a/modules/context/repo.go b/modules/context/repo.go index 63cb02dc06..79192267fb 100644 --- a/modules/context/repo.go +++ b/modules/context/repo.go @@ -8,6 +8,7 @@ package context import ( "fmt" "io/ioutil" + "net/http" "net/url" "path" "strings" @@ -21,7 +22,6 @@ import ( api "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/util" - "gitea.com/macaron/macaron" "github.com/editorconfig/editorconfig-core-go/v2" "github.com/unknwon/com" ) @@ -81,7 +81,7 @@ func (r *Repository) CanCreateBranch() bool { } // RepoMustNotBeArchived checks if a repo is archived -func RepoMustNotBeArchived() macaron.Handler { +func RepoMustNotBeArchived() func(ctx *Context) { return func(ctx *Context) { if ctx.Repo.Repository.IsArchived { ctx.NotFound("IsArchived", fmt.Errorf(ctx.Tr("repo.archive.title"))) @@ -374,7 +374,7 @@ func repoAssignment(ctx *Context, repo *models.Repository) { } // RepoIDAssignment returns a macaron handler which assigns the repo to the context. -func RepoIDAssignment() macaron.Handler { +func RepoIDAssignment() func(ctx *Context) { return func(ctx *Context) { repoID := ctx.ParamsInt64(":repoid") @@ -394,223 +394,220 @@ func RepoIDAssignment() macaron.Handler { } // RepoAssignment returns a macaron to handle repository assignment -func RepoAssignment() macaron.Handler { - return func(ctx *Context) { - var ( - owner *models.User - err error - ) - - userName := ctx.Params(":username") - repoName := ctx.Params(":reponame") +func RepoAssignment() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var ( + owner *models.User + err error + ctx = GetContext(req) + ) + + userName := ctx.Params(":username") + repoName := ctx.Params(":reponame") + repoName = strings.TrimSuffix(repoName, ".git") + + // Check if the user is the same as the repository owner + if ctx.IsSigned && ctx.User.LowerName == strings.ToLower(userName) { + owner = ctx.User + } else { + owner, err = models.GetUserByName(userName) + if err != nil { + if models.IsErrUserNotExist(err) { + if ctx.Query("go-get") == "1" { + EarlyResponseForGoGetMeta(ctx) + return + } + ctx.NotFound("GetUserByName", nil) + } else { + ctx.ServerError("GetUserByName", err) + } + return + } + } + ctx.Repo.Owner = owner + ctx.Data["Username"] = ctx.Repo.Owner.Name - // Check if the user is the same as the repository owner - if ctx.IsSigned && ctx.User.LowerName == strings.ToLower(userName) { - owner = ctx.User - } else { - owner, err = models.GetUserByName(userName) + // Get repository. + repo, err := models.GetRepositoryByName(owner.ID, repoName) if err != nil { - if models.IsErrUserNotExist(err) { - redirectUserID, err := models.LookupUserRedirect(userName) + if models.IsErrRepoNotExist(err) { + redirectRepoID, err := models.LookupRepoRedirect(owner.ID, repoName) if err == nil { - RedirectToUser(ctx, userName, redirectUserID) - } else if models.IsErrUserRedirectNotExist(err) { + RedirectToRepo(ctx, redirectRepoID) + } else if models.IsErrRepoRedirectNotExist(err) { if ctx.Query("go-get") == "1" { EarlyResponseForGoGetMeta(ctx) return } - ctx.NotFound("GetUserByName", nil) + ctx.NotFound("GetRepositoryByName", nil) } else { - ctx.ServerError("LookupUserRedirect", err) + ctx.ServerError("LookupRepoRedirect", err) } } else { - ctx.ServerError("GetUserByName", err) + ctx.ServerError("GetRepositoryByName", err) } return } - } - ctx.Repo.Owner = owner - ctx.Data["Username"] = ctx.Repo.Owner.Name + repo.Owner = owner - // Get repository. - repo, err := models.GetRepositoryByName(owner.ID, repoName) - if err != nil { - if models.IsErrRepoNotExist(err) { - redirectRepoID, err := models.LookupRepoRedirect(owner.ID, repoName) - if err == nil { - RedirectToRepo(ctx, redirectRepoID) - } else if models.IsErrRepoRedirectNotExist(err) { - if ctx.Query("go-get") == "1" { - EarlyResponseForGoGetMeta(ctx) - return - } - ctx.NotFound("GetRepositoryByName", nil) - } else { - ctx.ServerError("LookupRepoRedirect", err) - } - } else { - ctx.ServerError("GetRepositoryByName", err) + repoAssignment(ctx, repo) + if ctx.Written() { + return } - return - } - repo.Owner = owner - repoAssignment(ctx, repo) - if ctx.Written() { - return - } + ctx.Repo.RepoLink = repo.Link() + ctx.Data["RepoLink"] = ctx.Repo.RepoLink + ctx.Data["RepoRelPath"] = ctx.Repo.Owner.Name + "/" + ctx.Repo.Repository.Name - ctx.Repo.RepoLink = repo.Link() - ctx.Data["RepoLink"] = ctx.Repo.RepoLink - ctx.Data["RepoRelPath"] = ctx.Repo.Owner.Name + "/" + ctx.Repo.Repository.Name + unit, err := ctx.Repo.Repository.GetUnit(models.UnitTypeExternalTracker) + if err == nil { + ctx.Data["RepoExternalIssuesLink"] = unit.ExternalTrackerConfig().ExternalTrackerURL + } - unit, err := ctx.Repo.Repository.GetUnit(models.UnitTypeExternalTracker) - if err == nil { - ctx.Data["RepoExternalIssuesLink"] = unit.ExternalTrackerConfig().ExternalTrackerURL - } + ctx.Data["NumTags"], err = models.GetReleaseCountByRepoID(ctx.Repo.Repository.ID, models.FindReleasesOptions{ + IncludeTags: true, + }) + if err != nil { + ctx.ServerError("GetReleaseCountByRepoID", err) + return + } + ctx.Data["NumReleases"], err = models.GetReleaseCountByRepoID(ctx.Repo.Repository.ID, models.FindReleasesOptions{}) + if err != nil { + ctx.ServerError("GetReleaseCountByRepoID", err) + return + } - ctx.Data["NumTags"], err = models.GetReleaseCountByRepoID(ctx.Repo.Repository.ID, models.FindReleasesOptions{ - IncludeTags: true, - }) - if err != nil { - ctx.ServerError("GetReleaseCountByRepoID", err) - return - } - ctx.Data["NumReleases"], err = models.GetReleaseCountByRepoID(ctx.Repo.Repository.ID, models.FindReleasesOptions{}) - if err != nil { - ctx.ServerError("GetReleaseCountByRepoID", err) - return - } + ctx.Data["Title"] = owner.Name + "/" + repo.Name + ctx.Data["Repository"] = repo + ctx.Data["Owner"] = ctx.Repo.Repository.Owner + ctx.Data["IsRepositoryOwner"] = ctx.Repo.IsOwner() + ctx.Data["IsRepositoryAdmin"] = ctx.Repo.IsAdmin() + ctx.Data["RepoOwnerIsOrganization"] = repo.Owner.IsOrganization() + ctx.Data["CanWriteCode"] = ctx.Repo.CanWrite(models.UnitTypeCode) + ctx.Data["CanWriteIssues"] = ctx.Repo.CanWrite(models.UnitTypeIssues) + ctx.Data["CanWritePulls"] = ctx.Repo.CanWrite(models.UnitTypePullRequests) + + if ctx.Data["CanSignedUserFork"], err = ctx.Repo.Repository.CanUserFork(ctx.User); err != nil { + ctx.ServerError("CanUserFork", err) + return + } - ctx.Data["Title"] = owner.Name + "/" + repo.Name - ctx.Data["Repository"] = repo - ctx.Data["Owner"] = ctx.Repo.Repository.Owner - ctx.Data["IsRepositoryOwner"] = ctx.Repo.IsOwner() - ctx.Data["IsRepositoryAdmin"] = ctx.Repo.IsAdmin() - ctx.Data["RepoOwnerIsOrganization"] = repo.Owner.IsOrganization() - ctx.Data["CanWriteCode"] = ctx.Repo.CanWrite(models.UnitTypeCode) - ctx.Data["CanWriteIssues"] = ctx.Repo.CanWrite(models.UnitTypeIssues) - ctx.Data["CanWritePulls"] = ctx.Repo.CanWrite(models.UnitTypePullRequests) - - if ctx.Data["CanSignedUserFork"], err = ctx.Repo.Repository.CanUserFork(ctx.User); err != nil { - ctx.ServerError("CanUserFork", err) - return - } + ctx.Data["DisableSSH"] = setting.SSH.Disabled + ctx.Data["ExposeAnonSSH"] = setting.SSH.ExposeAnonymous + ctx.Data["DisableHTTP"] = setting.Repository.DisableHTTPGit + ctx.Data["RepoSearchEnabled"] = setting.Indexer.RepoIndexerEnabled + ctx.Data["CloneLink"] = repo.CloneLink() + ctx.Data["WikiCloneLink"] = repo.WikiCloneLink() - ctx.Data["DisableSSH"] = setting.SSH.Disabled - ctx.Data["ExposeAnonSSH"] = setting.SSH.ExposeAnonymous - ctx.Data["DisableHTTP"] = setting.Repository.DisableHTTPGit - ctx.Data["RepoSearchEnabled"] = setting.Indexer.RepoIndexerEnabled - ctx.Data["CloneLink"] = repo.CloneLink() - ctx.Data["WikiCloneLink"] = repo.WikiCloneLink() + if ctx.IsSigned { + ctx.Data["IsWatchingRepo"] = models.IsWatching(ctx.User.ID, repo.ID) + ctx.Data["IsStaringRepo"] = models.IsStaring(ctx.User.ID, repo.ID) + } - if ctx.IsSigned { - ctx.Data["IsWatchingRepo"] = models.IsWatching(ctx.User.ID, repo.ID) - ctx.Data["IsStaringRepo"] = models.IsStaring(ctx.User.ID, repo.ID) - } + if repo.IsFork { + RetrieveBaseRepo(ctx, repo) + if ctx.Written() { + return + } + } - if repo.IsFork { - RetrieveBaseRepo(ctx, repo) - if ctx.Written() { - return + if repo.IsGenerated() { + RetrieveTemplateRepo(ctx, repo) + if ctx.Written() { + return + } } - } - if repo.IsGenerated() { - RetrieveTemplateRepo(ctx, repo) - if ctx.Written() { + // Disable everything when the repo is being created + if ctx.Repo.Repository.IsBeingCreated() { + ctx.Data["BranchName"] = ctx.Repo.Repository.DefaultBranch return } - } - - // Disable everything when the repo is being created - if ctx.Repo.Repository.IsBeingCreated() { - ctx.Data["BranchName"] = ctx.Repo.Repository.DefaultBranch - return - } - gitRepo, err := git.OpenRepository(models.RepoPath(userName, repoName)) - if err != nil { - ctx.ServerError("RepoAssignment Invalid repo "+models.RepoPath(userName, repoName), err) - return - } - ctx.Repo.GitRepo = gitRepo - - // We opened it, we should close it - defer func() { - // If it's been set to nil then assume someone else has closed it. - if ctx.Repo.GitRepo != nil { - ctx.Repo.GitRepo.Close() + gitRepo, err := git.OpenRepository(models.RepoPath(userName, repoName)) + if err != nil { + ctx.ServerError("RepoAssignment Invalid repo "+models.RepoPath(userName, repoName), err) + return } - }() + ctx.Repo.GitRepo = gitRepo - // Stop at this point when the repo is empty. - if ctx.Repo.Repository.IsEmpty { - ctx.Data["BranchName"] = ctx.Repo.Repository.DefaultBranch - ctx.Next() - return - } + // We opened it, we should close it + defer func() { + // If it's been set to nil then assume someone else has closed it. + if ctx.Repo.GitRepo != nil { + ctx.Repo.GitRepo.Close() + } + }() - tags, err := ctx.Repo.GitRepo.GetTags() - if err != nil { - ctx.ServerError("GetTags", err) - return - } - ctx.Data["Tags"] = tags + // Stop at this point when the repo is empty. + if ctx.Repo.Repository.IsEmpty { + ctx.Data["BranchName"] = ctx.Repo.Repository.DefaultBranch + next.ServeHTTP(w, req) + return + } - brs, err := ctx.Repo.GitRepo.GetBranches() - if err != nil { - ctx.ServerError("GetBranches", err) - return - } - ctx.Data["Branches"] = brs - ctx.Data["BranchesCount"] = len(brs) - - ctx.Data["TagName"] = ctx.Repo.TagName - - // If not branch selected, try default one. - // If default branch doesn't exists, fall back to some other branch. - if len(ctx.Repo.BranchName) == 0 { - if len(ctx.Repo.Repository.DefaultBranch) > 0 && gitRepo.IsBranchExist(ctx.Repo.Repository.DefaultBranch) { - ctx.Repo.BranchName = ctx.Repo.Repository.DefaultBranch - } else if len(brs) > 0 { - ctx.Repo.BranchName = brs[0] + tags, err := ctx.Repo.GitRepo.GetTags() + if err != nil { + ctx.ServerError("GetTags", err) + return } - } - ctx.Data["BranchName"] = ctx.Repo.BranchName - ctx.Data["CommitID"] = ctx.Repo.CommitID - - // People who have push access or have forked repository can propose a new pull request. - canPush := ctx.Repo.CanWrite(models.UnitTypeCode) || (ctx.IsSigned && ctx.User.HasForkedRepo(ctx.Repo.Repository.ID)) - canCompare := false - - // Pull request is allowed if this is a fork repository - // and base repository accepts pull requests. - if repo.BaseRepo != nil && repo.BaseRepo.AllowsPulls() { - canCompare = true - ctx.Data["BaseRepo"] = repo.BaseRepo - ctx.Repo.PullRequest.BaseRepo = repo.BaseRepo - ctx.Repo.PullRequest.Allowed = canPush - ctx.Repo.PullRequest.HeadInfo = ctx.Repo.Owner.Name + ":" + ctx.Repo.BranchName - } else if repo.AllowsPulls() { - // Or, this is repository accepts pull requests between branches. - canCompare = true - ctx.Data["BaseRepo"] = repo - ctx.Repo.PullRequest.BaseRepo = repo - ctx.Repo.PullRequest.Allowed = canPush - ctx.Repo.PullRequest.SameRepo = true - ctx.Repo.PullRequest.HeadInfo = ctx.Repo.BranchName - } - ctx.Data["CanCompareOrPull"] = canCompare - ctx.Data["PullRequestCtx"] = ctx.Repo.PullRequest + ctx.Data["Tags"] = tags - if ctx.Query("go-get") == "1" { - ctx.Data["GoGetImport"] = ComposeGoGetImport(owner.Name, repo.Name) - prefix := setting.AppURL + path.Join(owner.Name, repo.Name, "src", "branch", ctx.Repo.BranchName) - ctx.Data["GoDocDirectory"] = prefix + "{/dir}" - ctx.Data["GoDocFile"] = prefix + "{/dir}/{file}#L{line}" - } - ctx.Next() + brs, err := ctx.Repo.GitRepo.GetBranches() + if err != nil { + ctx.ServerError("GetBranches", err) + return + } + ctx.Data["Branches"] = brs + ctx.Data["BranchesCount"] = len(brs) + + ctx.Data["TagName"] = ctx.Repo.TagName + + // If not branch selected, try default one. + // If default branch doesn't exists, fall back to some other branch. + if len(ctx.Repo.BranchName) == 0 { + if len(ctx.Repo.Repository.DefaultBranch) > 0 && gitRepo.IsBranchExist(ctx.Repo.Repository.DefaultBranch) { + ctx.Repo.BranchName = ctx.Repo.Repository.DefaultBranch + } else if len(brs) > 0 { + ctx.Repo.BranchName = brs[0] + } + } + ctx.Data["BranchName"] = ctx.Repo.BranchName + ctx.Data["CommitID"] = ctx.Repo.CommitID + + // People who have push access or have forked repository can propose a new pull request. + canPush := ctx.Repo.CanWrite(models.UnitTypeCode) || (ctx.IsSigned && ctx.User.HasForkedRepo(ctx.Repo.Repository.ID)) + canCompare := false + + // Pull request is allowed if this is a fork repository + // and base repository accepts pull requests. + if repo.BaseRepo != nil && repo.BaseRepo.AllowsPulls() { + canCompare = true + ctx.Data["BaseRepo"] = repo.BaseRepo + ctx.Repo.PullRequest.BaseRepo = repo.BaseRepo + ctx.Repo.PullRequest.Allowed = canPush + ctx.Repo.PullRequest.HeadInfo = ctx.Repo.Owner.Name + ":" + ctx.Repo.BranchName + } else if repo.AllowsPulls() { + // Or, this is repository accepts pull requests between branches. + canCompare = true + ctx.Data["BaseRepo"] = repo + ctx.Repo.PullRequest.BaseRepo = repo + ctx.Repo.PullRequest.Allowed = canPush + ctx.Repo.PullRequest.SameRepo = true + ctx.Repo.PullRequest.HeadInfo = ctx.Repo.BranchName + } + ctx.Data["CanCompareOrPull"] = canCompare + ctx.Data["PullRequestCtx"] = ctx.Repo.PullRequest + + if ctx.Query("go-get") == "1" { + ctx.Data["GoGetImport"] = ComposeGoGetImport(owner.Name, repo.Name) + prefix := setting.AppURL + path.Join(owner.Name, repo.Name, "src", "branch", ctx.Repo.BranchName) + ctx.Data["GoDocDirectory"] = prefix + "{/dir}" + ctx.Data["GoDocFile"] = prefix + "{/dir}/{file}#L{line}" + } + next.ServeHTTP(w, req) + }) } } @@ -636,7 +633,7 @@ const ( // RepoRef handles repository reference names when the ref name is not // explicitly given -func RepoRef() macaron.Handler { +func RepoRef() func(http.Handler) http.Handler { // since no ref name is explicitly specified, ok to just use branch return RepoRefByType(RepoRefBranch) } @@ -715,132 +712,135 @@ func getRefName(ctx *Context, pathType RepoRefType) string { // RepoRefByType handles repository reference name for a specific type // of repository reference -func RepoRefByType(refType RepoRefType) macaron.Handler { - return func(ctx *Context) { - // Empty repository does not have reference information. - if ctx.Repo.Repository.IsEmpty { - return - } - - var ( - refName string - err error - ) - - if ctx.Repo.GitRepo == nil { - repoPath := models.RepoPath(ctx.Repo.Owner.Name, ctx.Repo.Repository.Name) - ctx.Repo.GitRepo, err = git.OpenRepository(repoPath) - if err != nil { - ctx.ServerError("RepoRef Invalid repo "+repoPath, err) +func RepoRefByType(refType RepoRefType) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx := GetContext(req) + // Empty repository does not have reference information. + if ctx.Repo.Repository.IsEmpty { return } - // We opened it, we should close it - defer func() { - // If it's been set to nil then assume someone else has closed it. - if ctx.Repo.GitRepo != nil { - ctx.Repo.GitRepo.Close() - } - }() - } - // Get default branch. - if len(ctx.Params("*")) == 0 { - refName = ctx.Repo.Repository.DefaultBranch - ctx.Repo.BranchName = refName - if !ctx.Repo.GitRepo.IsBranchExist(refName) { - brs, err := ctx.Repo.GitRepo.GetBranches() + var ( + refName string + err error + ) + + if ctx.Repo.GitRepo == nil { + repoPath := models.RepoPath(ctx.Repo.Owner.Name, ctx.Repo.Repository.Name) + ctx.Repo.GitRepo, err = git.OpenRepository(repoPath) if err != nil { - ctx.ServerError("GetBranches", err) - return - } else if len(brs) == 0 { - err = fmt.Errorf("No branches in non-empty repository %s", - ctx.Repo.GitRepo.Path) - ctx.ServerError("GetBranches", err) + ctx.ServerError("RepoRef Invalid repo "+repoPath, err) return } - refName = brs[0] - } - ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetBranchCommit(refName) - if err != nil { - ctx.ServerError("GetBranchCommit", err) - return + // We opened it, we should close it + defer func() { + // If it's been set to nil then assume someone else has closed it. + if ctx.Repo.GitRepo != nil { + ctx.Repo.GitRepo.Close() + } + }() } - ctx.Repo.CommitID = ctx.Repo.Commit.ID.String() - ctx.Repo.IsViewBranch = true - - } else { - refName = getRefName(ctx, refType) - ctx.Repo.BranchName = refName - if refType.RefTypeIncludesBranches() && ctx.Repo.GitRepo.IsBranchExist(refName) { - ctx.Repo.IsViewBranch = true + // Get default branch. + if len(ctx.Params("*")) == 0 { + refName = ctx.Repo.Repository.DefaultBranch + ctx.Repo.BranchName = refName + if !ctx.Repo.GitRepo.IsBranchExist(refName) { + brs, err := ctx.Repo.GitRepo.GetBranches() + if err != nil { + ctx.ServerError("GetBranches", err) + return + } else if len(brs) == 0 { + err = fmt.Errorf("No branches in non-empty repository %s", + ctx.Repo.GitRepo.Path) + ctx.ServerError("GetBranches", err) + return + } + refName = brs[0] + } ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetBranchCommit(refName) if err != nil { ctx.ServerError("GetBranchCommit", err) return } ctx.Repo.CommitID = ctx.Repo.Commit.ID.String() + ctx.Repo.IsViewBranch = true - } else if refType.RefTypeIncludesTags() && ctx.Repo.GitRepo.IsTagExist(refName) { - ctx.Repo.IsViewTag = true - ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetTagCommit(refName) - if err != nil { - ctx.ServerError("GetTagCommit", err) + } else { + refName = getRefName(ctx, refType) + ctx.Repo.BranchName = refName + if refType.RefTypeIncludesBranches() && ctx.Repo.GitRepo.IsBranchExist(refName) { + ctx.Repo.IsViewBranch = true + + ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetBranchCommit(refName) + if err != nil { + ctx.ServerError("GetBranchCommit", err) + return + } + ctx.Repo.CommitID = ctx.Repo.Commit.ID.String() + + } else if refType.RefTypeIncludesTags() && ctx.Repo.GitRepo.IsTagExist(refName) { + ctx.Repo.IsViewTag = true + ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetTagCommit(refName) + if err != nil { + ctx.ServerError("GetTagCommit", err) + return + } + ctx.Repo.CommitID = ctx.Repo.Commit.ID.String() + } else if len(refName) >= 7 && len(refName) <= 40 { + ctx.Repo.IsViewCommit = true + ctx.Repo.CommitID = refName + + ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetCommit(refName) + if err != nil { + ctx.NotFound("GetCommit", err) + return + } + // If short commit ID add canonical link header + if len(refName) < 40 { + ctx.Header().Set("Link", fmt.Sprintf("<%s>; rel=\"canonical\"", + util.URLJoin(setting.AppURL, strings.Replace(ctx.Req.URL.RequestURI(), refName, ctx.Repo.Commit.ID.String(), 1)))) + } + } else { + ctx.NotFound("RepoRef invalid repo", fmt.Errorf("branch or tag not exist: %s", refName)) return } - ctx.Repo.CommitID = ctx.Repo.Commit.ID.String() - } else if len(refName) >= 7 && len(refName) <= 40 { - ctx.Repo.IsViewCommit = true - ctx.Repo.CommitID = refName - ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetCommit(refName) - if err != nil { - ctx.NotFound("GetCommit", err) + if refType == RepoRefLegacy { + // redirect from old URL scheme to new URL scheme + ctx.Redirect(path.Join( + setting.AppSubURL, + strings.TrimSuffix(ctx.Req.URL.Path, ctx.Params("*")), + ctx.Repo.BranchNameSubURL(), + ctx.Repo.TreePath)) return } - // If short commit ID add canonical link header - if len(refName) < 40 { - ctx.Header().Set("Link", fmt.Sprintf("<%s>; rel=\"canonical\"", - util.URLJoin(setting.AppURL, strings.Replace(ctx.Req.URL.RequestURI(), refName, ctx.Repo.Commit.ID.String(), 1)))) - } - } else { - ctx.NotFound("RepoRef invalid repo", fmt.Errorf("branch or tag not exist: %s", refName)) - return } - if refType == RepoRefLegacy { - // redirect from old URL scheme to new URL scheme - ctx.Redirect(path.Join( - setting.AppSubURL, - strings.TrimSuffix(ctx.Req.URL.Path, ctx.Params("*")), - ctx.Repo.BranchNameSubURL(), - ctx.Repo.TreePath)) + ctx.Data["BranchName"] = ctx.Repo.BranchName + ctx.Data["BranchNameSubURL"] = ctx.Repo.BranchNameSubURL() + ctx.Data["CommitID"] = ctx.Repo.CommitID + ctx.Data["TreePath"] = ctx.Repo.TreePath + ctx.Data["IsViewBranch"] = ctx.Repo.IsViewBranch + ctx.Data["IsViewTag"] = ctx.Repo.IsViewTag + ctx.Data["IsViewCommit"] = ctx.Repo.IsViewCommit + ctx.Data["CanCreateBranch"] = ctx.Repo.CanCreateBranch() + + ctx.Repo.CommitsCount, err = ctx.Repo.GetCommitsCount() + if err != nil { + ctx.ServerError("GetCommitsCount", err) return } - } + ctx.Data["CommitsCount"] = ctx.Repo.CommitsCount - ctx.Data["BranchName"] = ctx.Repo.BranchName - ctx.Data["BranchNameSubURL"] = ctx.Repo.BranchNameSubURL() - ctx.Data["CommitID"] = ctx.Repo.CommitID - ctx.Data["TreePath"] = ctx.Repo.TreePath - ctx.Data["IsViewBranch"] = ctx.Repo.IsViewBranch - ctx.Data["IsViewTag"] = ctx.Repo.IsViewTag - ctx.Data["IsViewCommit"] = ctx.Repo.IsViewCommit - ctx.Data["CanCreateBranch"] = ctx.Repo.CanCreateBranch() - - ctx.Repo.CommitsCount, err = ctx.Repo.GetCommitsCount() - if err != nil { - ctx.ServerError("GetCommitsCount", err) - return - } - ctx.Data["CommitsCount"] = ctx.Repo.CommitsCount - - ctx.Next() + next.ServeHTTP(w, req) + }) } } // GitHookService checks if repository Git hooks service has been enabled. -func GitHookService() macaron.Handler { +func GitHookService() func(ctx *Context) { return func(ctx *Context) { if !ctx.User.CanEditGitHook() { ctx.NotFound("GitHookService", nil) @@ -850,7 +850,7 @@ func GitHookService() macaron.Handler { } // UnitTypes returns a macaron middleware to set unit types to context variables. -func UnitTypes() macaron.Handler { +func UnitTypes() func(ctx *Context) { return func(ctx *Context) { ctx.Data["UnitTypeCode"] = models.UnitTypeCode ctx.Data["UnitTypeIssues"] = models.UnitTypeIssues diff --git a/modules/context/response.go b/modules/context/response.go index 549bd30ee0..1881ec7b33 100644 --- a/modules/context/response.go +++ b/modules/context/response.go @@ -11,6 +11,7 @@ type ResponseWriter interface { http.ResponseWriter Flush() Status() int + Before(func(ResponseWriter)) } var ( @@ -20,11 +21,19 @@ var ( // Response represents a response type Response struct { http.ResponseWriter - status int + status int + befores []func(ResponseWriter) + beforeExecuted bool } // Write writes bytes to HTTP endpoint func (r *Response) Write(bs []byte) (int, error) { + if !r.beforeExecuted { + for _, before := range r.befores { + before(r) + } + r.beforeExecuted = true + } size, err := r.ResponseWriter.Write(bs) if err != nil { return 0, err @@ -37,6 +46,12 @@ func (r *Response) Write(bs []byte) (int, error) { // WriteHeader write status code func (r *Response) WriteHeader(statusCode int) { + if !r.beforeExecuted { + for _, before := range r.befores { + before(r) + } + r.beforeExecuted = true + } r.status = statusCode r.ResponseWriter.WriteHeader(statusCode) } @@ -53,10 +68,20 @@ func (r *Response) Status() int { return r.status } +// Before allows for a function to be called before the ResponseWriter has been written to. This is +// useful for setting headers or any other operations that must happen before a response has been written. +func (r *Response) Before(f func(ResponseWriter)) { + r.befores = append(r.befores, f) +} + // NewResponse creates a response func NewResponse(resp http.ResponseWriter) *Response { if v, ok := resp.(*Response); ok { return v } - return &Response{resp, 0} + return &Response{ + ResponseWriter: resp, + status: 0, + befores: make([]func(ResponseWriter), 0), + } } diff --git a/modules/context/secret.go b/modules/context/secret.go new file mode 100644 index 0000000000..fcb488d211 --- /dev/null +++ b/modules/context/secret.go @@ -0,0 +1,100 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package context + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "io" +) + +// NewSecret creates a new secret +func NewSecret() (string, error) { + return NewSecretWithLength(32) +} + +// NewSecretWithLength creates a new secret for a given length +func NewSecretWithLength(length int64) (string, error) { + return randomString(length) +} + +func randomBytes(len int64) ([]byte, error) { + b := make([]byte, len) + if _, err := rand.Read(b); err != nil { + return nil, err + } + return b, nil +} + +func randomString(len int64) (string, error) { + b, err := randomBytes(len) + return base64.URLEncoding.EncodeToString(b), err +} + +// AesEncrypt encrypts text and given key with AES. +func AesEncrypt(key, text []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + b := base64.StdEncoding.EncodeToString(text) + ciphertext := make([]byte, aes.BlockSize+len(b)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + cfb := cipher.NewCFBEncrypter(block, iv) + cfb.XORKeyStream(ciphertext[aes.BlockSize:], []byte(b)) + return ciphertext, nil +} + +// AesDecrypt decrypts text and given key with AES. +func AesDecrypt(key, text []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + if len(text) < aes.BlockSize { + return nil, errors.New("ciphertext too short") + } + iv := text[:aes.BlockSize] + text = text[aes.BlockSize:] + cfb := cipher.NewCFBDecrypter(block, iv) + cfb.XORKeyStream(text, text) + data, err := base64.StdEncoding.DecodeString(string(text)) + if err != nil { + return nil, err + } + return data, nil +} + +// EncryptSecret encrypts a string with given key into a hex string +func EncryptSecret(key string, str string) (string, error) { + keyHash := sha256.Sum256([]byte(key)) + plaintext := []byte(str) + ciphertext, err := AesEncrypt(keyHash[:], plaintext) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// DecryptSecret decrypts a previously encrypted hex string +func DecryptSecret(key string, cipherhex string) (string, error) { + keyHash := sha256.Sum256([]byte(key)) + ciphertext, err := base64.StdEncoding.DecodeString(cipherhex) + if err != nil { + return "", err + } + plaintext, err := AesDecrypt(keyHash[:], ciphertext) + if err != nil { + return "", err + } + return string(plaintext), nil +} diff --git a/modules/context/xsrf.go b/modules/context/xsrf.go new file mode 100644 index 0000000000..10e63a4180 --- /dev/null +++ b/modules/context/xsrf.go @@ -0,0 +1,98 @@ +// Copyright 2012 Google Inc. All Rights Reserved. +// Copyright 2014 The Macaron Authors +// Copyright 2020 The Gitea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "bytes" + "crypto/hmac" + "crypto/sha1" + "crypto/subtle" + "encoding/base64" + "fmt" + "strconv" + "strings" + "time" +) + +// Timeout represents the duration that XSRF tokens are valid. +// It is exported so clients may set cookie timeouts that match generated tokens. +const Timeout = 24 * time.Hour + +// clean sanitizes a string for inclusion in a token by replacing all ":"s. +func clean(s string) string { + return strings.ReplaceAll(s, ":", "_") +} + +// GenerateToken returns a URL-safe secure XSRF token that expires in 24 hours. +// +// key is a secret key for your application. +// userID is a unique identifier for the user. +// actionID is the action the user is taking (e.g. POSTing to a particular path). +func GenerateToken(key, userID, actionID string) string { + return generateTokenAtTime(key, userID, actionID, time.Now()) +} + +// generateTokenAtTime is like Generate, but returns a token that expires 24 hours from now. +func generateTokenAtTime(key, userID, actionID string, now time.Time) string { + h := hmac.New(sha1.New, []byte(key)) + fmt.Fprintf(h, "%s:%s:%d", clean(userID), clean(actionID), now.UnixNano()) + tok := fmt.Sprintf("%s:%d", h.Sum(nil), now.UnixNano()) + return base64.RawURLEncoding.EncodeToString([]byte(tok)) +} + +// ValidToken returns true if token is a valid, unexpired token returned by Generate. +func ValidToken(token, key, userID, actionID string) bool { + return validTokenAtTime(token, key, userID, actionID, time.Now()) +} + +// validTokenAtTime is like Valid, but it uses now to check if the token is expired. +func validTokenAtTime(token, key, userID, actionID string, now time.Time) bool { + // Decode the token. + data, err := base64.RawURLEncoding.DecodeString(token) + if err != nil { + return false + } + + // Extract the issue time of the token. + sep := bytes.LastIndex(data, []byte{':'}) + if sep < 0 { + return false + } + nanos, err := strconv.ParseInt(string(data[sep+1:]), 10, 64) + if err != nil { + return false + } + issueTime := time.Unix(0, nanos) + + // Check that the token is not expired. + if now.Sub(issueTime) >= Timeout { + return false + } + + // Check that the token is not from the future. + // Allow 1 minute grace period in case the token is being verified on a + // machine whose clock is behind the machine that issued the token. + if issueTime.After(now.Add(1 * time.Minute)) { + return false + } + + expected := generateTokenAtTime(key, userID, actionID, issueTime) + + // Check that the token matches the expected value. + // Use constant time comparison to avoid timing attacks. + return subtle.ConstantTimeCompare([]byte(token), []byte(expected)) == 1 +} diff --git a/modules/context/xsrf_test.go b/modules/context/xsrf_test.go new file mode 100644 index 0000000000..c0c711bf07 --- /dev/null +++ b/modules/context/xsrf_test.go @@ -0,0 +1,90 @@ +// Copyright 2012 Google Inc. All Rights Reserved. +// Copyright 2014 The Macaron Authors +// Copyright 2020 The Gitea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "encoding/base64" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const ( + key = "quay" + userID = "12345678" + actionID = "POST /form" +) + +var ( + now = time.Now() + oneMinuteFromNow = now.Add(1 * time.Minute) +) + +func Test_ValidToken(t *testing.T) { + t.Run("Validate token", func(t *testing.T) { + tok := generateTokenAtTime(key, userID, actionID, now) + assert.True(t, validTokenAtTime(tok, key, userID, actionID, oneMinuteFromNow)) + assert.True(t, validTokenAtTime(tok, key, userID, actionID, now.Add(Timeout-1*time.Nanosecond))) + assert.True(t, validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute))) + }) +} + +// Test_SeparatorReplacement tests that separators are being correctly substituted +func Test_SeparatorReplacement(t *testing.T) { + t.Run("Test two separator replacements", func(t *testing.T) { + assert.NotEqual(t, generateTokenAtTime("foo:bar", "baz", "wah", now), + generateTokenAtTime("foo", "bar:baz", "wah", now)) + }) +} + +func Test_InvalidToken(t *testing.T) { + t.Run("Test invalid tokens", func(t *testing.T) { + invalidTokenTests := []struct { + name, key, userID, actionID string + t time.Time + }{ + {"Bad key", "foobar", userID, actionID, oneMinuteFromNow}, + {"Bad userID", key, "foobar", actionID, oneMinuteFromNow}, + {"Bad actionID", key, userID, "foobar", oneMinuteFromNow}, + {"Expired", key, userID, actionID, now.Add(Timeout)}, + {"More than 1 minute from the future", key, userID, actionID, now.Add(-1*time.Nanosecond - 1*time.Minute)}, + } + + tok := generateTokenAtTime(key, userID, actionID, now) + for _, itt := range invalidTokenTests { + assert.False(t, validTokenAtTime(tok, itt.key, itt.userID, itt.actionID, itt.t)) + } + }) +} + +// Test_ValidateBadData primarily tests that no unexpected panics are triggered during parsing +func Test_ValidateBadData(t *testing.T) { + t.Run("Validate bad data", func(t *testing.T) { + badDataTests := []struct { + name, tok string + }{ + {"Invalid Base64", "ASDab24(@)$*=="}, + {"No delimiter", base64.URLEncoding.EncodeToString([]byte("foobar12345678"))}, + {"Invalid time", base64.URLEncoding.EncodeToString([]byte("foobar:foobar"))}, + } + + for _, bdt := range badDataTests { + assert.False(t, validTokenAtTime(bdt.tok, key, userID, actionID, oneMinuteFromNow)) + } + }) +} |