diff options
author | wxiaoguang <wxiaoguang@gmail.com> | 2023-05-21 09:50:53 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-21 09:50:53 +0800 |
commit | 6b33152b7dc81b38e5832a30c52cfad1902e86d0 (patch) | |
tree | 020272cc3b2d0566d286ed01f85ae74a9f48c177 | |
parent | 6ba4f897231229c06ac98bf2e067665e3ef0bf23 (diff) | |
download | gitea-6b33152b7dc81b38e5832a30c52cfad1902e86d0.tar.gz gitea-6b33152b7dc81b38e5832a30c52cfad1902e86d0.zip |
Decouple the different contexts from each other (#24786)
Replace #16455
Close #21803
Mixing different Gitea contexts together causes some problems:
1. Unable to respond proper content when error occurs, eg: Web should
respond HTML while API should respond JSON
2. Unclear dependency, eg: it's unclear when Context is used in
APIContext, which fields should be initialized, which methods are
necessary.
To make things clear, this PR introduces a Base context, it only
provides basic Req/Resp/Data features.
This PR mainly moves code. There are still many legacy problems and
TODOs in code, leave unrelated changes to future PRs.
57 files changed, 882 insertions, 778 deletions
diff --git a/modules/context/api.go b/modules/context/api.go index e263dcbe8d..092ad73f31 100644 --- a/modules/context/api.go +++ b/modules/context/api.go @@ -13,18 +13,32 @@ import ( "code.gitea.io/gitea/models/auth" repo_model "code.gitea.io/gitea/models/repo" - "code.gitea.io/gitea/modules/cache" + "code.gitea.io/gitea/models/unit" + user_model "code.gitea.io/gitea/models/user" + mc "code.gitea.io/gitea/modules/cache" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/httpcache" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" - "code.gitea.io/gitea/modules/web/middleware" + + "gitea.com/go-chi/cache" ) // APIContext is a specific context for API service type APIContext struct { - *Context - Org *APIOrganization + *Base + + Cache cache.Cache + + Doer *user_model.User // current signed-in user + IsSigned bool + IsBasicAuth bool + + ContextUser *user_model.User // the user which is being visited, in most cases it differs from Doer + + Repo *Repository + Org *APIOrganization + Package *Package } // Currently, we have the following common fields in error response: @@ -128,11 +142,6 @@ type apiContextKeyType struct{} var apiContextKey = apiContextKeyType{} -// WithAPIContext set up api context in request -func WithAPIContext(req *http.Request, ctx *APIContext) *http.Request { - return req.WithContext(context.WithValue(req.Context(), apiContextKey, ctx)) -} - // GetAPIContext returns a context for API routes func GetAPIContext(req *http.Request) *APIContext { return req.Context().Value(apiContextKey).(*APIContext) @@ -195,21 +204,21 @@ func (ctx *APIContext) CheckForOTP() { } otpHeader := ctx.Req.Header.Get("X-Gitea-OTP") - twofa, err := auth.GetTwoFactorByUID(ctx.Context.Doer.ID) + twofa, err := auth.GetTwoFactorByUID(ctx.Doer.ID) if err != nil { if auth.IsErrTwoFactorNotEnrolled(err) { return // No 2FA enrollment for this user } - ctx.Context.Error(http.StatusInternalServerError) + ctx.Error(http.StatusInternalServerError, "GetTwoFactorByUID", err) return } ok, err := twofa.ValidateTOTP(otpHeader) if err != nil { - ctx.Context.Error(http.StatusInternalServerError) + ctx.Error(http.StatusInternalServerError, "ValidateTOTP", err) return } if !ok { - ctx.Context.Error(http.StatusUnauthorized) + ctx.Error(http.StatusUnauthorized, "", nil) return } } @@ -218,23 +227,17 @@ func (ctx *APIContext) CheckForOTP() { func APIContexter() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - locale := middleware.Locale(w, req) - ctx := APIContext{ - Context: &Context{ - Resp: NewResponse(w), - Data: middleware.GetContextData(req.Context()), - Locale: locale, - Cache: cache.GetCache(), - Repo: &Repository{ - PullRequest: &PullRequest{}, - }, - Org: &Organization{}, - }, - Org: &APIOrganization{}, + base, baseCleanUp := NewBaseContext(w, req) + ctx := &APIContext{ + Base: base, + Cache: mc.GetCache(), + Repo: &Repository{PullRequest: &PullRequest{}}, + Org: &APIOrganization{}, } - defer ctx.Close() + defer baseCleanUp() - ctx.Req = WithAPIContext(WithContext(req, ctx.Context), &ctx) + ctx.Base.AppendContextValue(apiContextKey, ctx) + ctx.Base.AppendContextValueFunc(git.RepositoryContextKey, func() any { return ctx.Repo.GitRepo }) // If request sends files, parse them here otherwise the Query() can't be parsed and the CsrfToken will be invalid. if ctx.Req.Method == "POST" && strings.Contains(ctx.Req.Header.Get("Content-Type"), "multipart/form-data") { @@ -247,8 +250,6 @@ func APIContexter() func(http.Handler) http.Handler { httpcache.SetCacheControlInHeader(ctx.Resp.Header(), 0, "no-transform") ctx.Resp.Header().Set(`X-Frame-Options`, setting.CORSConfig.XFrameOptions) - ctx.Data["Context"] = &ctx - next.ServeHTTP(ctx.Resp, ctx.Req) }) } @@ -301,7 +302,7 @@ func ReferencesGitRepo(allowEmpty ...bool) func(ctx *APIContext) (cancel context return func() { // If it's been set to nil then assume someone else has closed it. if ctx.Repo.GitRepo != nil { - ctx.Repo.GitRepo.Close() + _ = ctx.Repo.GitRepo.Close() } } } @@ -337,7 +338,7 @@ func RepoRefForAPI(next http.Handler) http.Handler { } var err error - refName := getRefName(ctx.Context, RepoRefAny) + refName := getRefName(ctx.Base, ctx.Repo, RepoRefAny) if ctx.Repo.GitRepo.IsBranchExist(refName) { ctx.Repo.Commit, err = ctx.Repo.GitRepo.GetBranchCommit(refName) @@ -368,3 +369,53 @@ func RepoRefForAPI(next http.Handler) http.Handler { next.ServeHTTP(w, req) }) } + +// HasAPIError returns true if error occurs in form validation. +func (ctx *APIContext) HasAPIError() bool { + hasErr, ok := ctx.Data["HasError"] + if !ok { + return false + } + return hasErr.(bool) +} + +// GetErrMsg returns error message in form validation. +func (ctx *APIContext) GetErrMsg() string { + msg, _ := ctx.Data["ErrorMsg"].(string) + if msg == "" { + msg = "invalid form data" + } + return msg +} + +// NotFoundOrServerError use error check function to determine if the error +// is about not found. It responds with 404 status code for not found error, +// or error context description for logging purpose of 500 server error. +func (ctx *APIContext) NotFoundOrServerError(logMsg string, errCheck func(error) bool, logErr error) { + if errCheck(logErr) { + ctx.JSON(http.StatusNotFound, nil) + return + } + ctx.Error(http.StatusInternalServerError, "NotFoundOrServerError", logMsg) +} + +// IsUserSiteAdmin returns true if current user is a site admin +func (ctx *APIContext) IsUserSiteAdmin() bool { + return ctx.IsSigned && ctx.Doer.IsAdmin +} + +// IsUserRepoAdmin returns true if current user is admin in current repo +func (ctx *APIContext) IsUserRepoAdmin() bool { + return ctx.Repo.IsAdmin() +} + +// IsUserRepoWriter returns true if current user has write privilege in current repo +func (ctx *APIContext) IsUserRepoWriter(unitTypes []unit.Type) bool { + for _, unitType := range unitTypes { + if ctx.Repo.CanWrite(unitType) { + return true + } + } + + return false +} diff --git a/modules/context/base.go b/modules/context/base.go new file mode 100644 index 0000000000..ac9b52d51c --- /dev/null +++ b/modules/context/base.go @@ -0,0 +1,300 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package context + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "code.gitea.io/gitea/modules/httplib" + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/translation" + "code.gitea.io/gitea/modules/util" + "code.gitea.io/gitea/modules/web/middleware" + + "github.com/go-chi/chi/v5" +) + +type contextValuePair struct { + key any + valueFn func() any +} + +type Base struct { + originCtx context.Context + contextValues []contextValuePair + + Resp ResponseWriter + Req *http.Request + + // Data is prepared by ContextDataStore middleware, this field only refers to the pre-created/prepared ContextData. + // Although it's mainly used for MVC templates, sometimes it's also used to pass data between middlewares/handler + Data middleware.ContextData + + // Locale is mainly for Web context, although the API context also uses it in some cases: message response, form validation + Locale translation.Locale +} + +func (b *Base) Deadline() (deadline time.Time, ok bool) { + return b.originCtx.Deadline() +} + +func (b *Base) Done() <-chan struct{} { + return b.originCtx.Done() +} + +func (b *Base) Err() error { + return b.originCtx.Err() +} + +func (b *Base) Value(key any) any { + for _, pair := range b.contextValues { + if pair.key == key { + return pair.valueFn() + } + } + return b.originCtx.Value(key) +} + +func (b *Base) AppendContextValueFunc(key any, valueFn func() any) any { + b.contextValues = append(b.contextValues, contextValuePair{key, valueFn}) + return b +} + +func (b *Base) AppendContextValue(key, value any) any { + b.contextValues = append(b.contextValues, contextValuePair{key, func() any { return value }}) + return b +} + +func (b *Base) GetData() middleware.ContextData { + return b.Data +} + +// AppendAccessControlExposeHeaders append headers by name to "Access-Control-Expose-Headers" header +func (b *Base) AppendAccessControlExposeHeaders(names ...string) { + val := b.RespHeader().Get("Access-Control-Expose-Headers") + if len(val) != 0 { + b.RespHeader().Set("Access-Control-Expose-Headers", fmt.Sprintf("%s, %s", val, strings.Join(names, ", "))) + } else { + b.RespHeader().Set("Access-Control-Expose-Headers", strings.Join(names, ", ")) + } +} + +// SetTotalCountHeader set "X-Total-Count" header +func (b *Base) SetTotalCountHeader(total int64) { + b.RespHeader().Set("X-Total-Count", fmt.Sprint(total)) + b.AppendAccessControlExposeHeaders("X-Total-Count") +} + +// Written returns true if there are something sent to web browser +func (b *Base) Written() bool { + return b.Resp.Status() > 0 +} + +// Status writes status code +func (b *Base) Status(status int) { + b.Resp.WriteHeader(status) +} + +// Write writes data to web browser +func (b *Base) Write(bs []byte) (int, error) { + return b.Resp.Write(bs) +} + +// RespHeader returns the response header +func (b *Base) RespHeader() http.Header { + return b.Resp.Header() +} + +// Error returned an error to web browser +func (b *Base) Error(status int, contents ...string) { + v := http.StatusText(status) + if len(contents) > 0 { + v = contents[0] + } + http.Error(b.Resp, v, status) +} + +// JSON render content as JSON +func (b *Base) JSON(status int, content interface{}) { + b.Resp.Header().Set("Content-Type", "application/json;charset=utf-8") + b.Resp.WriteHeader(status) + if err := json.NewEncoder(b.Resp).Encode(content); err != nil { + log.Error("Render JSON failed: %v", err) + } +} + +// RemoteAddr returns the client machine ip address +func (b *Base) RemoteAddr() string { + return b.Req.RemoteAddr +} + +// Params returns the param on route +func (b *Base) Params(p string) string { + s, _ := url.PathUnescape(chi.URLParam(b.Req, strings.TrimPrefix(p, ":"))) + return s +} + +// ParamsInt64 returns the param on route as int64 +func (b *Base) ParamsInt64(p string) int64 { + v, _ := strconv.ParseInt(b.Params(p), 10, 64) + return v +} + +// SetParams set params into routes +func (b *Base) SetParams(k, v string) { + chiCtx := chi.RouteContext(b) + chiCtx.URLParams.Add(strings.TrimPrefix(k, ":"), url.PathEscape(v)) +} + +// FormString returns the first value matching the provided key in the form as a string +func (b *Base) FormString(key string) string { + return b.Req.FormValue(key) +} + +// FormStrings returns a string slice for the provided key from the form +func (b *Base) FormStrings(key string) []string { + if b.Req.Form == nil { + if err := b.Req.ParseMultipartForm(32 << 20); err != nil { + return nil + } + } + if v, ok := b.Req.Form[key]; ok { + return v + } + return nil +} + +// FormTrim returns the first value for the provided key in the form as a space trimmed string +func (b *Base) FormTrim(key string) string { + return strings.TrimSpace(b.Req.FormValue(key)) +} + +// FormInt returns the first value for the provided key in the form as an int +func (b *Base) FormInt(key string) int { + v, _ := strconv.Atoi(b.Req.FormValue(key)) + return v +} + +// FormInt64 returns the first value for the provided key in the form as an int64 +func (b *Base) FormInt64(key string) int64 { + v, _ := strconv.ParseInt(b.Req.FormValue(key), 10, 64) + return v +} + +// FormBool returns true if the value for the provided key in the form is "1", "true" or "on" +func (b *Base) FormBool(key string) bool { + s := b.Req.FormValue(key) + v, _ := strconv.ParseBool(s) + v = v || strings.EqualFold(s, "on") + return v +} + +// FormOptionalBool returns an OptionalBoolTrue or OptionalBoolFalse if the value +// for the provided key exists in the form else it returns OptionalBoolNone +func (b *Base) FormOptionalBool(key string) util.OptionalBool { + value := b.Req.FormValue(key) + if len(value) == 0 { + return util.OptionalBoolNone + } + s := b.Req.FormValue(key) + v, _ := strconv.ParseBool(s) + v = v || strings.EqualFold(s, "on") + return util.OptionalBoolOf(v) +} + +func (b *Base) SetFormString(key, value string) { + _ = b.Req.FormValue(key) // force parse form + b.Req.Form.Set(key, value) +} + +// PlainTextBytes renders bytes as plain text +func (b *Base) plainTextInternal(skip, status int, bs []byte) { + statusPrefix := status / 100 + if statusPrefix == 4 || statusPrefix == 5 { + log.Log(skip, log.TRACE, "plainTextInternal (status=%d): %s", status, string(bs)) + } + b.Resp.Header().Set("Content-Type", "text/plain;charset=utf-8") + b.Resp.Header().Set("X-Content-Type-Options", "nosniff") + b.Resp.WriteHeader(status) + if _, err := b.Resp.Write(bs); err != nil { + log.ErrorWithSkip(skip, "plainTextInternal (status=%d): write bytes failed: %v", status, err) + } +} + +// PlainTextBytes renders bytes as plain text +func (b *Base) PlainTextBytes(status int, bs []byte) { + b.plainTextInternal(2, status, bs) +} + +// PlainText renders content as plain text +func (b *Base) PlainText(status int, text string) { + b.plainTextInternal(2, status, []byte(text)) +} + +// Redirect redirects the request +func (b *Base) Redirect(location string, status ...int) { + code := http.StatusSeeOther + if len(status) == 1 { + code = status[0] + } + + if strings.Contains(location, "://") || strings.HasPrefix(location, "//") { + // Some browsers (Safari) have buggy behavior for Cookie + Cache + External Redirection, eg: /my-path => https://other/path + // 1. the first request to "/my-path" contains cookie + // 2. some time later, the request to "/my-path" doesn't contain cookie (caused by Prevent web tracking) + // 3. Gitea's Sessioner doesn't see the session cookie, so it generates a new session id, and returns it to browser + // 4. then the browser accepts the empty session, then the user is logged out + // So in this case, we should remove the session cookie from the response header + removeSessionCookieHeader(b.Resp) + } + http.Redirect(b.Resp, b.Req, location, code) +} + +type ServeHeaderOptions httplib.ServeHeaderOptions + +func (b *Base) SetServeHeaders(opt *ServeHeaderOptions) { + httplib.ServeSetHeaders(b.Resp, (*httplib.ServeHeaderOptions)(opt)) +} + +// ServeContent serves content to http request +func (b *Base) ServeContent(r io.ReadSeeker, opts *ServeHeaderOptions) { + httplib.ServeSetHeaders(b.Resp, (*httplib.ServeHeaderOptions)(opts)) + http.ServeContent(b.Resp, b.Req, opts.Filename, opts.LastModified, r) +} + +// Close frees all resources hold by Context +func (b *Base) cleanUp() { + if b.Req != nil && b.Req.MultipartForm != nil { + _ = b.Req.MultipartForm.RemoveAll() // remove the temp files buffered to tmp directory + } +} + +func (b *Base) Tr(msg string, args ...any) string { + return b.Locale.Tr(msg, args...) +} + +func (b *Base) TrN(cnt any, key1, keyN string, args ...any) string { + return b.Locale.TrN(cnt, key1, keyN, args...) +} + +func NewBaseContext(resp http.ResponseWriter, req *http.Request) (b *Base, closeFunc func()) { + b = &Base{ + originCtx: req.Context(), + Req: req, + Resp: WrapResponseWriter(resp), + Locale: middleware.Locale(resp, req), + Data: middleware.GetContextData(req.Context()), + } + b.AppendContextValue(translation.ContextKey, b.Locale) + b.Req = b.Req.WithContext(b) + return b, b.cleanUp +} diff --git a/modules/context/context.go b/modules/context/context.go index 9ba1985f36..1e15081479 100644 --- a/modules/context/context.go +++ b/modules/context/context.go @@ -5,7 +5,6 @@ package context import ( - "context" "html" "html/template" "io" @@ -36,38 +35,27 @@ type Render interface { // Context represents context of a request. type Context struct { - Resp ResponseWriter - Req *http.Request - Render Render + *Base - Data middleware.ContextData // data used by MVC templates - PageData map[string]any // data used by JavaScript modules in one page, it's `window.config.pageData` + Render Render + PageData map[string]any // data used by JavaScript modules in one page, it's `window.config.pageData` - Locale translation.Locale Cache cache.Cache Csrf CSRFProtector Flash *middleware.Flash Session session.Store - Link string // current request URL (without query string) - Doer *user_model.User + Link string // current request URL (without query string) + + Doer *user_model.User // current signed-in user IsSigned bool IsBasicAuth bool - ContextUser *user_model.User - Repo *Repository - Org *Organization - Package *Package -} + ContextUser *user_model.User // the user which is being visited, in most cases it differs from Doer -// Close frees all resources hold by Context -func (ctx *Context) Close() error { - var err error - if ctx.Req != nil && ctx.Req.MultipartForm != nil { - err = ctx.Req.MultipartForm.RemoveAll() // remove the temp files buffered to tmp directory - } - // TODO: close opened repo, and more - return err + Repo *Repository + Org *Organization + Package *Package } // TrHTMLEscapeArgs runs ".Locale.Tr()" but pre-escapes all arguments with html.EscapeString. @@ -80,55 +68,30 @@ func (ctx *Context) TrHTMLEscapeArgs(msg string, args ...string) string { return ctx.Locale.Tr(msg, trArgs...) } -func (ctx *Context) Tr(msg string, args ...any) string { - return ctx.Locale.Tr(msg, args...) -} - -func (ctx *Context) TrN(cnt any, key1, keyN string, args ...any) string { - return ctx.Locale.TrN(cnt, key1, keyN, args...) -} - -// Deadline is part of the interface for context.Context and we pass this to the request context -func (ctx *Context) Deadline() (deadline time.Time, ok bool) { - return ctx.Req.Context().Deadline() -} - -// Done is part of the interface for context.Context and we pass this to the request context -func (ctx *Context) Done() <-chan struct{} { - return ctx.Req.Context().Done() -} - -// Err is part of the interface for context.Context and we pass this to the request context -func (ctx *Context) Err() error { - return ctx.Req.Context().Err() -} - -// Value is part of the interface for context.Context and we pass this to the request context -func (ctx *Context) Value(key interface{}) interface{} { - if key == git.RepositoryContextKey && ctx.Repo != nil { - return ctx.Repo.GitRepo - } - if key == translation.ContextKey && ctx.Locale != nil { - return ctx.Locale - } - return ctx.Req.Context().Value(key) -} - type contextKeyType struct{} var contextKey interface{} = contextKeyType{} -// WithContext set up install context in request -func WithContext(req *http.Request, ctx *Context) *http.Request { - return req.WithContext(context.WithValue(req.Context(), contextKey, ctx)) +func GetContext(req *http.Request) *Context { + ctx, _ := req.Context().Value(contextKey).(*Context) + return ctx } -// GetContext retrieves install context from request -func GetContext(req *http.Request) *Context { - if ctx, ok := req.Context().Value(contextKey).(*Context); ok { - return ctx +// ValidateContext is a special context for form validation middleware. It may be different from other contexts. +type ValidateContext struct { + *Base +} + +// GetValidateContext gets a context for middleware form validation +func GetValidateContext(req *http.Request) (ctx *ValidateContext) { + if ctxAPI, ok := req.Context().Value(apiContextKey).(*APIContext); ok { + ctx = &ValidateContext{Base: ctxAPI.Base} + } else if ctxWeb, ok := req.Context().Value(contextKey).(*Context); ok { + ctx = &ValidateContext{Base: ctxWeb.Base} + } else { + panic("invalid context, expect either APIContext or Context") } - return nil + return ctx } // Contexter initializes a classic context for a request. @@ -150,20 +113,17 @@ func Contexter() func(next http.Handler) http.Handler { } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - ctx := Context{ - Resp: NewResponse(resp), + base, baseCleanUp := NewBaseContext(resp, req) + ctx := &Context{ + Base: base, Cache: mc.GetCache(), - Locale: middleware.Locale(resp, req), Link: setting.AppSubURL + strings.TrimSuffix(req.URL.EscapedPath(), "/"), Render: rnd, Session: session.GetSession(req), - Repo: &Repository{ - PullRequest: &PullRequest{}, - }, - Org: &Organization{}, - Data: middleware.GetContextData(req.Context()), + Repo: &Repository{PullRequest: &PullRequest{}}, + Org: &Organization{}, } - defer ctx.Close() + defer baseCleanUp() ctx.Data.MergeFrom(middleware.CommonTemplateContextData()) ctx.Data["Context"] = &ctx @@ -175,15 +135,17 @@ func Contexter() func(next http.Handler) http.Handler { ctx.PageData = map[string]any{} ctx.Data["PageData"] = ctx.PageData - ctx.Req = WithContext(req, &ctx) - ctx.Csrf = PrepareCSRFProtector(csrfOpts, &ctx) + ctx.Base.AppendContextValue(contextKey, ctx) + ctx.Base.AppendContextValueFunc(git.RepositoryContextKey, func() any { return ctx.Repo.GitRepo }) + + ctx.Csrf = PrepareCSRFProtector(csrfOpts, ctx) // Get the last flash message from cookie lastFlashCookie := middleware.GetSiteCookie(ctx.Req, CookieNameFlash) if vals, _ := url.ParseQuery(lastFlashCookie); len(vals) > 0 { // store last Flash message into the template data, to render it ctx.Data["Flash"] = &middleware.Flash{ - DataStore: &ctx, + DataStore: ctx, Values: vals, ErrorMsg: vals.Get("error"), SuccessMsg: vals.Get("success"), @@ -193,7 +155,7 @@ func Contexter() func(next http.Handler) http.Handler { } // prepare an empty Flash message for current request - ctx.Flash = &middleware.Flash{DataStore: &ctx, Values: url.Values{}} + ctx.Flash = &middleware.Flash{DataStore: ctx, Values: url.Values{}} ctx.Resp.Before(func(resp ResponseWriter) { if val := ctx.Flash.Encode(); val != "" { middleware.SetSiteCookie(ctx.Resp, CookieNameFlash, val, 0) @@ -235,3 +197,24 @@ func Contexter() func(next http.Handler) http.Handler { }) } } + +// HasError returns true if error occurs in form validation. +// Attention: this function changes ctx.Data and ctx.Flash +func (ctx *Context) HasError() bool { + hasErr, ok := ctx.Data["HasError"] + if !ok { + return false + } + ctx.Flash.ErrorMsg = ctx.GetErrMsg() + ctx.Data["Flash"] = ctx.Flash + return hasErr.(bool) +} + +// GetErrMsg returns error message in form validation. +func (ctx *Context) GetErrMsg() string { + msg, _ := ctx.Data["ErrorMsg"].(string) + if msg == "" { + msg = "invalid form data" + } + return msg +} diff --git a/modules/context/context_data.go b/modules/context/context_data.go deleted file mode 100644 index cdf4ff9afe..0000000000 --- a/modules/context/context_data.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2023 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package context - -import "code.gitea.io/gitea/modules/web/middleware" - -// GetData returns the data -func (ctx *Context) GetData() middleware.ContextData { - return ctx.Data -} - -// HasAPIError returns true if error occurs in form validation. -func (ctx *Context) HasAPIError() bool { - hasErr, ok := ctx.Data["HasError"] - if !ok { - return false - } - return hasErr.(bool) -} - -// GetErrMsg returns error message -func (ctx *Context) GetErrMsg() string { - return ctx.Data["ErrorMsg"].(string) -} - -// HasError returns true if error occurs in form validation. -// Attention: this function changes ctx.Data and ctx.Flash -func (ctx *Context) HasError() bool { - hasErr, ok := ctx.Data["HasError"] - if !ok { - return false - } - ctx.Flash.ErrorMsg = ctx.Data["ErrorMsg"].(string) - ctx.Data["Flash"] = ctx.Flash - return hasErr.(bool) -} - -// HasValue returns true if value of given name exists. -func (ctx *Context) HasValue(name string) bool { - _, ok := ctx.Data[name] - return ok -} diff --git a/modules/context/context_form.go b/modules/context/context_form.go deleted file mode 100644 index 5c02152582..0000000000 --- a/modules/context/context_form.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2021 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package context - -import ( - "strconv" - "strings" - - "code.gitea.io/gitea/modules/util" -) - -// FormString returns the first value matching the provided key in the form as a string -func (ctx *Context) FormString(key string) string { - return ctx.Req.FormValue(key) -} - -// FormStrings returns a string slice for the provided key from the form -func (ctx *Context) FormStrings(key string) []string { - if ctx.Req.Form == nil { - if err := ctx.Req.ParseMultipartForm(32 << 20); err != nil { - return nil - } - } - if v, ok := ctx.Req.Form[key]; ok { - return v - } - return nil -} - -// FormTrim returns the first value for the provided key in the form as a space trimmed string -func (ctx *Context) FormTrim(key string) string { - return strings.TrimSpace(ctx.Req.FormValue(key)) -} - -// FormInt returns the first value for the provided key in the form as an int -func (ctx *Context) FormInt(key string) int { - v, _ := strconv.Atoi(ctx.Req.FormValue(key)) - return v -} - -// FormInt64 returns the first value for the provided key in the form as an int64 -func (ctx *Context) FormInt64(key string) int64 { - v, _ := strconv.ParseInt(ctx.Req.FormValue(key), 10, 64) - return v -} - -// FormBool returns true if the value for the provided key in the form is "1", "true" or "on" -func (ctx *Context) FormBool(key string) bool { - s := ctx.Req.FormValue(key) - v, _ := strconv.ParseBool(s) - v = v || strings.EqualFold(s, "on") - return v -} - -// FormOptionalBool returns an OptionalBoolTrue or OptionalBoolFalse if the value -// for the provided key exists in the form else it returns OptionalBoolNone -func (ctx *Context) FormOptionalBool(key string) util.OptionalBool { - value := ctx.Req.FormValue(key) - if len(value) == 0 { - return util.OptionalBoolNone - } - s := ctx.Req.FormValue(key) - v, _ := strconv.ParseBool(s) - v = v || strings.EqualFold(s, "on") - return util.OptionalBoolOf(v) -} - -func (ctx *Context) SetFormString(key, value string) { - _ = ctx.Req.FormValue(key) // force parse form - ctx.Req.Form.Set(key, value) -} diff --git a/modules/context/context_request.go b/modules/context/context_request.go index 0b87552c08..984b9ac793 100644 --- a/modules/context/context_request.go +++ b/modules/context/context_request.go @@ -6,36 +6,9 @@ package context import ( "io" "net/http" - "net/url" - "strconv" "strings" - - "github.com/go-chi/chi/v5" ) -// RemoteAddr returns the client machine ip address -func (ctx *Context) RemoteAddr() string { - return ctx.Req.RemoteAddr -} - -// Params returns the param on route -func (ctx *Context) Params(p string) string { - s, _ := url.PathUnescape(chi.URLParam(ctx.Req, strings.TrimPrefix(p, ":"))) - return s -} - -// ParamsInt64 returns the param on route as int64 -func (ctx *Context) ParamsInt64(p string) int64 { - v, _ := strconv.ParseInt(ctx.Params(p), 10, 64) - return v -} - -// SetParams set params into routes -func (ctx *Context) SetParams(k, v string) { - chiCtx := chi.RouteContext(ctx) - chiCtx.URLParams.Add(strings.TrimPrefix(k, ":"), url.PathEscape(v)) -} - // UploadStream returns the request body or the first form file // Only form files need to get closed. func (ctx *Context) UploadStream() (rd io.ReadCloser, needToClose bool, err error) { diff --git a/modules/context/context_response.go b/modules/context/context_response.go index 8adff96994..aeeb51ba37 100644 --- a/modules/context/context_response.go +++ b/modules/context/context_response.go @@ -16,49 +16,17 @@ import ( user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/base" - "code.gitea.io/gitea/modules/json" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/templates" "code.gitea.io/gitea/modules/web/middleware" ) -// SetTotalCountHeader set "X-Total-Count" header -func (ctx *Context) SetTotalCountHeader(total int64) { - ctx.RespHeader().Set("X-Total-Count", fmt.Sprint(total)) - ctx.AppendAccessControlExposeHeaders("X-Total-Count") -} - -// AppendAccessControlExposeHeaders append headers by name to "Access-Control-Expose-Headers" header -func (ctx *Context) AppendAccessControlExposeHeaders(names ...string) { - val := ctx.RespHeader().Get("Access-Control-Expose-Headers") - if len(val) != 0 { - ctx.RespHeader().Set("Access-Control-Expose-Headers", fmt.Sprintf("%s, %s", val, strings.Join(names, ", "))) - } else { - ctx.RespHeader().Set("Access-Control-Expose-Headers", strings.Join(names, ", ")) - } -} - -// Written returns true if there are something sent to web browser -func (ctx *Context) Written() bool { - return ctx.Resp.Status() > 0 -} - -// Status writes status code -func (ctx *Context) Status(status int) { - ctx.Resp.WriteHeader(status) -} - -// Write writes data to web browser -func (ctx *Context) Write(bs []byte) (int, error) { - return ctx.Resp.Write(bs) -} - // RedirectToUser redirect to a differently-named user -func RedirectToUser(ctx *Context, userName string, redirectUserID int64) { +func RedirectToUser(ctx *Base, userName string, redirectUserID int64) { user, err := user_model.GetUserByID(ctx, redirectUserID) if err != nil { - ctx.ServerError("GetUserByID", err) + ctx.Error(http.StatusInternalServerError, "unable to get user") return } @@ -211,69 +179,3 @@ func (ctx *Context) NotFoundOrServerError(logMsg string, errCheck func(error) bo } ctx.serverErrorInternal(logMsg, logErr) } - -// PlainTextBytes renders bytes as plain text -func (ctx *Context) plainTextInternal(skip, status int, bs []byte) { - statusPrefix := status / 100 - if statusPrefix == 4 || statusPrefix == 5 { - log.Log(skip, log.TRACE, "plainTextInternal (status=%d): %s", status, string(bs)) - } - ctx.Resp.Header().Set("Content-Type", "text/plain;charset=utf-8") - ctx.Resp.Header().Set("X-Content-Type-Options", "nosniff") - ctx.Resp.WriteHeader(status) - if _, err := ctx.Resp.Write(bs); err != nil { - log.ErrorWithSkip(skip, "plainTextInternal (status=%d): write bytes failed: %v", status, err) - } -} - -// PlainTextBytes renders bytes as plain text -func (ctx *Context) PlainTextBytes(status int, bs []byte) { - ctx.plainTextInternal(2, status, bs) -} - -// PlainText renders content as plain text -func (ctx *Context) PlainText(status int, text string) { - ctx.plainTextInternal(2, status, []byte(text)) -} - -// RespHeader returns the response header -func (ctx *Context) RespHeader() http.Header { - return ctx.Resp.Header() -} - -// Error returned an error to web browser -func (ctx *Context) Error(status int, contents ...string) { - v := http.StatusText(status) - if len(contents) > 0 { - v = contents[0] - } - http.Error(ctx.Resp, v, status) -} - -// JSON render content as JSON -func (ctx *Context) JSON(status int, content interface{}) { - ctx.Resp.Header().Set("Content-Type", "application/json;charset=utf-8") - ctx.Resp.WriteHeader(status) - if err := json.NewEncoder(ctx.Resp).Encode(content); err != nil { - ctx.ServerError("Render JSON failed", err) - } -} - -// Redirect redirects the request -func (ctx *Context) Redirect(location string, status ...int) { - code := http.StatusSeeOther - if len(status) == 1 { - code = status[0] - } - - if strings.Contains(location, "://") || strings.HasPrefix(location, "//") { - // Some browsers (Safari) have buggy behavior for Cookie + Cache + External Redirection, eg: /my-path => https://other/path - // 1. the first request to "/my-path" contains cookie - // 2. some time later, the request to "/my-path" doesn't contain cookie (caused by Prevent web tracking) - // 3. Gitea's Sessioner doesn't see the session cookie, so it generates a new session id, and returns it to browser - // 4. then the browser accepts the empty session, then the user is logged out - // So in this case, we should remove the session cookie from the response header - removeSessionCookieHeader(ctx.Resp) - } - http.Redirect(ctx.Resp, ctx.Req, location, code) -} diff --git a/modules/context/context_serve.go b/modules/context/context_serve.go deleted file mode 100644 index 5569efbc7e..0000000000 --- a/modules/context/context_serve.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2023 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package context - -import ( - "io" - "net/http" - - "code.gitea.io/gitea/modules/httplib" -) - -type ServeHeaderOptions httplib.ServeHeaderOptions - -func (ctx *Context) SetServeHeaders(opt *ServeHeaderOptions) { - httplib.ServeSetHeaders(ctx.Resp, (*httplib.ServeHeaderOptions)(opt)) -} - -// ServeContent serves content to http request -func (ctx *Context) ServeContent(r io.ReadSeeker, opts *ServeHeaderOptions) { - httplib.ServeSetHeaders(ctx.Resp, (*httplib.ServeHeaderOptions)(opts)) - http.ServeContent(ctx.Resp, ctx.Req, opts.Filename, opts.LastModified, r) -} diff --git a/modules/context/org.go b/modules/context/org.go index 39a3038f91..355ba0ebd0 100644 --- a/modules/context/org.go +++ b/modules/context/org.go @@ -47,7 +47,7 @@ func GetOrganizationByParams(ctx *Context) { if organization.IsErrOrgNotExist(err) { redirectUserID, err := user_model.LookupUserRedirect(orgName) if err == nil { - RedirectToUser(ctx, orgName, redirectUserID) + RedirectToUser(ctx.Base, orgName, redirectUserID) } else if user_model.IsErrUserRedirectNotExist(err) { ctx.NotFound("GetUserByName", err) } else { diff --git a/modules/context/package.go b/modules/context/package.go index fe5bdac19d..b1fd7088dd 100644 --- a/modules/context/package.go +++ b/modules/context/package.go @@ -4,7 +4,6 @@ package context import ( - gocontext "context" "fmt" "net/http" @@ -16,7 +15,6 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/templates" - "code.gitea.io/gitea/modules/web/middleware" ) // Package contains owner, access mode and optional the package descriptor @@ -26,10 +24,16 @@ type Package struct { Descriptor *packages_model.PackageDescriptor } +type packageAssignmentCtx struct { + *Base + Doer *user_model.User + ContextUser *user_model.User +} + // PackageAssignment returns a middleware to handle Context.Package assignment func PackageAssignment() func(ctx *Context) { return func(ctx *Context) { - packageAssignment(ctx, func(status int, title string, obj interface{}) { + errorFn := func(status int, title string, obj interface{}) { err, ok := obj.(error) if !ok { err = fmt.Errorf("%s", obj) @@ -39,68 +43,72 @@ func PackageAssignment() func(ctx *Context) { } else { ctx.ServerError(title, err) } - }) + } + paCtx := &packageAssignmentCtx{Base: ctx.Base, Doer: ctx.Doer, ContextUser: ctx.ContextUser} + ctx.Package = packageAssignment(paCtx, errorFn) } } // PackageAssignmentAPI returns a middleware to handle Context.Package assignment func PackageAssignmentAPI() func(ctx *APIContext) { return func(ctx *APIContext) { - packageAssignment(ctx.Context, ctx.Error) + paCtx := &packageAssignmentCtx{Base: ctx.Base, Doer: ctx.Doer, ContextUser: ctx.ContextUser} + ctx.Package = packageAssignment(paCtx, ctx.Error) } } -func packageAssignment(ctx *Context, errCb func(int, string, interface{})) { - ctx.Package = &Package{ +func packageAssignment(ctx *packageAssignmentCtx, errCb func(int, string, interface{})) *Package { + pkg := &Package{ Owner: ctx.ContextUser, } - var err error - ctx.Package.AccessMode, err = determineAccessMode(ctx) + pkg.AccessMode, err = determineAccessMode(ctx.Base, pkg, ctx.Doer) if err != nil { errCb(http.StatusInternalServerError, "determineAccessMode", err) - return + return pkg } packageType := ctx.Params("type") name := ctx.Params("name") version := ctx.Params("version") if packageType != "" && name != "" && version != "" { - pv, err := packages_model.GetVersionByNameAndVersion(ctx, ctx.Package.Owner.ID, packages_model.Type(packageType), name, version) + pv, err := packages_model.GetVersionByNameAndVersion(ctx, pkg.Owner.ID, packages_model.Type(packageType), name, version) if err != nil { if err == packages_model.ErrPackageNotExist { errCb(http.StatusNotFound, "GetVersionByNameAndVersion", err) } else { errCb(http.StatusInternalServerError, "GetVersionByNameAndVersion", err) } - return + return pkg } - ctx.Package.Descriptor, err = packages_model.GetPackageDescriptor(ctx, pv) + pkg.Descriptor, err = packages_model.GetPackageDescriptor(ctx, pv) if err != nil { errCb(http.StatusInternalServerError, "GetPackageDescriptor", err) - return + return pkg } } + + return pkg } -func determineAccessMode(ctx *Context) (perm.AccessMode, error) { - if setting.Service.RequireSignInView && ctx.Doer == nil { +func determineAccessMode(ctx *Base, pkg *Package, doer *user_model.User) (perm.AccessMode, error) { + if setting.Service.RequireSignInView && doer == nil { return perm.AccessModeNone, nil } - if ctx.Doer != nil && !ctx.Doer.IsGhost() && (!ctx.Doer.IsActive || ctx.Doer.ProhibitLogin) { + if doer != nil && !doer.IsGhost() && (!doer.IsActive || doer.ProhibitLogin) { return perm.AccessModeNone, nil } // TODO: ActionUser permission check accessMode := perm.AccessModeNone - if ctx.Package.Owner.IsOrganization() { - org := organization.OrgFromUser(ctx.Package.Owner) + if pkg.Owner.IsOrganization() { + org := organization.OrgFromUser(pkg.Owner) - if ctx.Doer != nil && !ctx.Doer.IsGhost() { + if doer != nil && !doer.IsGhost() { // 1. If user is logged in, check all team packages permissions - teams, err := organization.GetUserOrgTeams(ctx, org.ID, ctx.Doer.ID) + teams, err := organization.GetUserOrgTeams(ctx, org.ID, doer.ID) if err != nil { return accessMode, err } @@ -110,19 +118,19 @@ func determineAccessMode(ctx *Context) (perm.AccessMode, error) { accessMode = perm } } - } else if organization.HasOrgOrUserVisible(ctx, ctx.Package.Owner, ctx.Doer) { + } else if organization.HasOrgOrUserVisible(ctx, pkg.Owner, doer) { // 2. If user is non-login, check if org is visible to non-login user accessMode = perm.AccessModeRead } } else { - if ctx.Doer != nil && !ctx.Doer.IsGhost() { + if doer != nil && !doer.IsGhost() { // 1. Check if user is package owner - if ctx.Doer.ID == ctx.Package.Owner.ID { + if doer.ID == pkg.Owner.ID { accessMode = perm.AccessModeOwner - } else if ctx.Package.Owner.Visibility == structs.VisibleTypePublic || ctx.Package.Owner.Visibility == structs.VisibleTypeLimited { // 2. Check if package owner is public or limited + } else if pkg.Owner.Visibility == structs.VisibleTypePublic || pkg.Owner.Visibility == structs.VisibleTypeLimited { // 2. Check if package owner is public or limited accessMode = perm.AccessModeRead } - } else if ctx.Package.Owner.Visibility == structs.VisibleTypePublic { // 3. Check if package owner is public + } else if pkg.Owner.Visibility == structs.VisibleTypePublic { // 3. Check if package owner is public accessMode = perm.AccessModeRead } } @@ -131,19 +139,18 @@ func determineAccessMode(ctx *Context) (perm.AccessMode, error) { } // PackageContexter initializes a package context for a request. -func PackageContexter(ctx gocontext.Context) func(next http.Handler) http.Handler { - rnd := templates.HTMLRenderer() +func PackageContexter() func(next http.Handler) http.Handler { + renderer := templates.HTMLRenderer() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - ctx := Context{ - Resp: NewResponse(resp), - Data: middleware.GetContextData(req.Context()), - Render: rnd, + base, baseCleanUp := NewBaseContext(resp, req) + ctx := &Context{ + Base: base, + Render: renderer, // it is still needed when rendering 500 page in a package handler } - defer ctx.Close() - - ctx.Req = WithContext(req, &ctx) + defer baseCleanUp() + ctx.Base.AppendContextValue(contextKey, ctx) next.ServeHTTP(ctx.Resp, ctx.Req) }) } diff --git a/modules/context/private.go b/modules/context/private.go index f621dd6839..41ca8a4709 100644 --- a/modules/context/private.go +++ b/modules/context/private.go @@ -11,13 +11,14 @@ import ( "code.gitea.io/gitea/modules/graceful" "code.gitea.io/gitea/modules/process" - "code.gitea.io/gitea/modules/web/middleware" ) // PrivateContext represents a context for private routes type PrivateContext struct { - *Context + *Base Override context.Context + + Repo *Repository } // Deadline is part of the interface for context.Context and we pass this to the request context @@ -25,7 +26,7 @@ func (ctx *PrivateContext) Deadline() (deadline time.Time, ok bool) { if ctx.Override != nil { return ctx.Override.Deadline() } - return ctx.Req.Context().Deadline() + return ctx.Base.Deadline() } // Done is part of the interface for context.Context and we pass this to the request context @@ -33,7 +34,7 @@ func (ctx *PrivateContext) Done() <-chan struct{} { if ctx.Override != nil { return ctx.Override.Done() } - return ctx.Req.Context().Done() + return ctx.Base.Done() } // Err is part of the interface for context.Context and we pass this to the request context @@ -41,16 +42,11 @@ func (ctx *PrivateContext) Err() error { if ctx.Override != nil { return ctx.Override.Err() } - return ctx.Req.Context().Err() + return ctx.Base.Err() } var privateContextKey interface{} = "default_private_context" -// WithPrivateContext set up private context in request -func WithPrivateContext(req *http.Request, ctx *PrivateContext) *http.Request { - return req.WithContext(context.WithValue(req.Context(), privateContextKey, ctx)) -} - // GetPrivateContext returns a context for Private routes func GetPrivateContext(req *http.Request) *PrivateContext { return req.Context().Value(privateContextKey).(*PrivateContext) @@ -60,16 +56,11 @@ func GetPrivateContext(req *http.Request) *PrivateContext { func PrivateContexter() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ctx := &PrivateContext{ - Context: &Context{ - Resp: NewResponse(w), - Data: middleware.GetContextData(req.Context()), - }, - } - defer ctx.Close() + base, baseCleanUp := NewBaseContext(w, req) + ctx := &PrivateContext{Base: base} + defer baseCleanUp() + ctx.Base.AppendContextValue(privateContextKey, ctx) - ctx.Req = WithPrivateContext(req, ctx) - ctx.Data["Context"] = ctx next.ServeHTTP(ctx.Resp, ctx.Req) }) } diff --git a/modules/context/repo.go b/modules/context/repo.go index 5e90e8aec0..fd5f208576 100644 --- a/modules/context/repo.go +++ b/modules/context/repo.go @@ -331,13 +331,14 @@ func EarlyResponseForGoGetMeta(ctx *Context) { } // RedirectToRepo redirect to a differently-named repository -func RedirectToRepo(ctx *Context, redirectRepoID int64) { +func RedirectToRepo(ctx *Base, redirectRepoID int64) { ownerName := ctx.Params(":username") previousRepoName := ctx.Params(":reponame") repo, err := repo_model.GetRepositoryByID(ctx, redirectRepoID) if err != nil { - ctx.ServerError("GetRepositoryByID", err) + log.Error("GetRepositoryByID: %v", err) + ctx.Error(http.StatusInternalServerError, "GetRepositoryByID") return } @@ -456,7 +457,7 @@ func RepoAssignment(ctx *Context) (cancel context.CancelFunc) { } if redirectUserID, err := user_model.LookupUserRedirect(userName); err == nil { - RedirectToUser(ctx, userName, redirectUserID) + RedirectToUser(ctx.Base, userName, redirectUserID) } else if user_model.IsErrUserRedirectNotExist(err) { ctx.NotFound("GetUserByName", nil) } else { @@ -498,7 +499,7 @@ func RepoAssignment(ctx *Context) (cancel context.CancelFunc) { if repo_model.IsErrRepoNotExist(err) { redirectRepoID, err := repo_model.LookupRedirect(owner.ID, repoName) if err == nil { - RedirectToRepo(ctx, redirectRepoID) + RedirectToRepo(ctx.Base, redirectRepoID) } else if repo_model.IsErrRedirectNotExist(err) { if ctx.FormString("go-get") == "1" { EarlyResponseForGoGetMeta(ctx) @@ -781,46 +782,46 @@ func (rt RepoRefType) RefTypeIncludesTags() bool { return false } -func getRefNameFromPath(ctx *Context, path string, isExist func(string) bool) string { +func getRefNameFromPath(ctx *Base, repo *Repository, path string, isExist func(string) bool) string { refName := "" parts := strings.Split(path, "/") for i, part := range parts { refName = strings.TrimPrefix(refName+"/"+part, "/") if isExist(refName) { - ctx.Repo.TreePath = strings.Join(parts[i+1:], "/") + repo.TreePath = strings.Join(parts[i+1:], "/") return refName } } return "" } -func getRefName(ctx *Context, pathType RepoRefType) string { +func getRefName(ctx *Base, repo *Repository, pathType RepoRefType) string { path := ctx.Params("*") switch pathType { case RepoRefLegacy, RepoRefAny: - if refName := getRefName(ctx, RepoRefBranch); len(refName) > 0 { + if refName := getRefName(ctx, repo, RepoRefBranch); len(refName) > 0 { return refName } - if refName := getRefName(ctx, RepoRefTag); len(refName) > 0 { + if refName := getRefName(ctx, repo, RepoRefTag); len(refName) > 0 { return refName } // For legacy and API support only full commit sha parts := strings.Split(path, "/") if len(parts) > 0 && len(parts[0]) == git.SHAFullLength { - ctx.Repo.TreePath = strings.Join(parts[1:], "/") + repo.TreePath = strings.Join(parts[1:], "/") return parts[0] } - if refName := getRefName(ctx, RepoRefBlob); len(refName) > 0 { + if refName := getRefName(ctx, repo, RepoRefBlob); len(refName) > 0 { return refName } - ctx.Repo.TreePath = path - return ctx.Repo.Repository.DefaultBranch + repo.TreePath = path + return repo.Repository.DefaultBranch case RepoRefBranch: - ref := getRefNameFromPath(ctx, path, ctx.Repo.GitRepo.IsBranchExist) + ref := getRefNameFromPath(ctx, repo, path, repo.GitRepo.IsBranchExist) if len(ref) == 0 { // maybe it's a renamed branch - return getRefNameFromPath(ctx, path, func(s string) bool { - b, exist, err := git_model.FindRenamedBranch(ctx, ctx.Repo.Repository.ID, s) + return getRefNameFromPath(ctx, repo, path, func(s string) bool { + b, exist, err := git_model.FindRenamedBranch(ctx, repo.Repository.ID, s) if err != nil { log.Error("FindRenamedBranch", err) return false @@ -839,15 +840,15 @@ func getRefName(ctx *Context, pathType RepoRefType) string { return ref case RepoRefTag: - return getRefNameFromPath(ctx, path, ctx.Repo.GitRepo.IsTagExist) + return getRefNameFromPath(ctx, repo, path, repo.GitRepo.IsTagExist) case RepoRefCommit: parts := strings.Split(path, "/") if len(parts) > 0 && len(parts[0]) >= 7 && len(parts[0]) <= git.SHAFullLength { - ctx.Repo.TreePath = strings.Join(parts[1:], "/") + repo.TreePath = strings.Join(parts[1:], "/") return parts[0] } case RepoRefBlob: - _, err := ctx.Repo.GitRepo.GetBlob(path) + _, err := repo.GitRepo.GetBlob(path) if err != nil { return "" } @@ -922,7 +923,7 @@ func RepoRefByType(refType RepoRefType, ignoreNotExistErr ...bool) func(*Context } ctx.Repo.IsViewBranch = true } else { - refName = getRefName(ctx, refType) + refName = getRefName(ctx.Base, ctx.Repo, refType) ctx.Repo.RefName = refName isRenamedBranch, has := ctx.Data["IsRenamedBranch"].(bool) if isRenamedBranch && has { diff --git a/modules/context/response.go b/modules/context/response.go index 40eb5c0d35..ca52ea137d 100644 --- a/modules/context/response.go +++ b/modules/context/response.go @@ -10,10 +10,9 @@ import ( // ResponseWriter represents a response writer for HTTP type ResponseWriter interface { http.ResponseWriter - Flush() + http.Flusher Status() int Before(func(ResponseWriter)) - Size() int } var _ ResponseWriter = &Response{} @@ -27,11 +26,6 @@ type Response struct { beforeExecuted bool } -// Size return written size -func (r *Response) Size() int { - return r.written -} - // Write writes bytes to HTTP endpoint func (r *Response) Write(bs []byte) (int, error) { if !r.beforeExecuted { @@ -65,7 +59,7 @@ func (r *Response) WriteHeader(statusCode int) { } } -// Flush flush cached data +// Flush flushes cached data func (r *Response) Flush() { if f, ok := r.ResponseWriter.(http.Flusher); ok { f.Flush() @@ -83,8 +77,7 @@ func (r *Response) Before(f func(ResponseWriter)) { r.befores = append(r.befores, f) } -// NewResponse creates a response -func NewResponse(resp http.ResponseWriter) *Response { +func WrapResponseWriter(resp http.ResponseWriter) *Response { if v, ok := resp.(*Response); ok { return v } diff --git a/modules/context/utils.go b/modules/context/utils.go index 1fa99953a2..c0f619aa23 100644 --- a/modules/context/utils.go +++ b/modules/context/utils.go @@ -10,7 +10,7 @@ import ( ) // GetQueryBeforeSince return parsed time (unix format) from URL query's before and since -func GetQueryBeforeSince(ctx *Context) (before, since int64, err error) { +func GetQueryBeforeSince(ctx *Base) (before, since int64, err error) { qCreatedBefore, err := prepareQueryArg(ctx, "before") if err != nil { return 0, 0, err @@ -48,7 +48,7 @@ func parseTime(value string) (int64, error) { } // prepareQueryArg unescape and trim a query arg -func prepareQueryArg(ctx *Context, name string) (value string, err error) { +func prepareQueryArg(ctx *Base, name string) (value string, err error) { value, err = url.PathUnescape(ctx.FormString(name)) value = strings.TrimSpace(value) return value, err diff --git a/modules/test/context_tests.go b/modules/test/context_tests.go index 5ba2126126..349c7e3e80 100644 --- a/modules/test/context_tests.go +++ b/modules/test/context_tests.go @@ -4,7 +4,7 @@ package test import ( - scontext "context" + gocontext "context" "io" "net/http" "net/http/httptest" @@ -28,18 +28,33 @@ import ( // MockContext mock context for unit tests // TODO: move this function to other packages, because it depends on "models" package func MockContext(t *testing.T, path string) *context.Context { - resp := &mockResponseWriter{} - ctx := context.Context{ + resp := httptest.NewRecorder() + requestURL, err := url.Parse(path) + assert.NoError(t, err) + req := &http.Request{ + URL: requestURL, + Form: url.Values{}, + } + + base, baseCleanUp := context.NewBaseContext(resp, req) + base.Data = middleware.ContextData{} + base.Locale = &translation.MockLocale{} + ctx := &context.Context{ + Base: base, Render: &mockRender{}, - Data: make(middleware.ContextData), - Flash: &middleware.Flash{ - Values: make(url.Values), - }, - Resp: context.NewResponse(resp), - Locale: &translation.MockLocale{}, + Flash: &middleware.Flash{Values: url.Values{}}, } - defer ctx.Close() + _ = baseCleanUp // during test, it doesn't need to do clean up. TODO: this can be improved later + chiCtx := chi.NewRouteContext() + ctx.Base.AppendContextValue(chi.RouteCtxKey, chiCtx) + return ctx +} + +// MockAPIContext mock context for unit tests +// TODO: move this function to other packages, because it depends on "models" package +func MockAPIContext(t *testing.T, path string) *context.APIContext { + resp := httptest.NewRecorder() requestURL, err := url.Parse(path) assert.NoError(t, err) req := &http.Request{ @@ -47,41 +62,79 @@ func MockContext(t *testing.T, path string) *context.Context { Form: url.Values{}, } + base, baseCleanUp := context.NewBaseContext(resp, req) + base.Data = middleware.ContextData{} + base.Locale = &translation.MockLocale{} + ctx := &context.APIContext{Base: base} + _ = baseCleanUp // during test, it doesn't need to do clean up. TODO: this can be improved later + chiCtx := chi.NewRouteContext() - req = req.WithContext(scontext.WithValue(req.Context(), chi.RouteCtxKey, chiCtx)) - ctx.Req = context.WithContext(req, &ctx) - return &ctx + ctx.Base.AppendContextValue(chi.RouteCtxKey, chiCtx) + return ctx } // LoadRepo load a repo into a test context. -func LoadRepo(t *testing.T, ctx *context.Context, repoID int64) { - ctx.Repo = &context.Repository{} - ctx.Repo.Repository = unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID}) +func LoadRepo(t *testing.T, ctx gocontext.Context, repoID int64) { + var doer *user_model.User + repo := &context.Repository{} + switch ctx := ctx.(type) { + case *context.Context: + ctx.Repo = repo + doer = ctx.Doer + case *context.APIContext: + ctx.Repo = repo + doer = ctx.Doer + default: + assert.Fail(t, "context is not *context.Context or *context.APIContext") + return + } + + repo.Repository = unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID}) var err error - ctx.Repo.Owner, err = user_model.GetUserByID(ctx, ctx.Repo.Repository.OwnerID) + repo.Owner, err = user_model.GetUserByID(ctx, repo.Repository.OwnerID) assert.NoError(t, err) - ctx.Repo.RepoLink = ctx.Repo.Repository.Link() - ctx.Repo.Permission, err = access_model.GetUserRepoPermission(ctx, ctx.Repo.Repository, ctx.Doer) + repo.RepoLink = repo.Repository.Link() + repo.Permission, err = access_model.GetUserRepoPermission(ctx, repo.Repository, doer) assert.NoError(t, err) } // LoadRepoCommit loads a repo's commit into a test context. -func LoadRepoCommit(t *testing.T, ctx *context.Context) { - gitRepo, err := git.OpenRepository(ctx, ctx.Repo.Repository.RepoPath()) +func LoadRepoCommit(t *testing.T, ctx gocontext.Context) { + var repo *context.Repository + switch ctx := ctx.(type) { + case *context.Context: + repo = ctx.Repo + case *context.APIContext: + repo = ctx.Repo + default: + assert.Fail(t, "context is not *context.Context or *context.APIContext") + return + } + + gitRepo, err := git.OpenRepository(ctx, repo.Repository.RepoPath()) assert.NoError(t, err) defer gitRepo.Close() branch, err := gitRepo.GetHEADBranch() assert.NoError(t, err) assert.NotNil(t, branch) if branch != nil { - ctx.Repo.Commit, err = gitRepo.GetBranchCommit(branch.Name) + repo.Commit, err = gitRepo.GetBranchCommit(branch.Name) assert.NoError(t, err) } } // LoadUser load a user into a test context. -func LoadUser(t *testing.T, ctx *context.Context, userID int64) { - ctx.Doer = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: userID}) +func LoadUser(t *testing.T, ctx gocontext.Context, userID int64) { + doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: userID}) + switch ctx := ctx.(type) { + case *context.Context: + ctx.Doer = doer + case *context.APIContext: + ctx.Doer = doer + default: + assert.Fail(t, "context is not *context.Context or *context.APIContext") + return + } } // LoadGitRepo load a git repo into a test context. Requires that ctx.Repo has @@ -93,32 +146,6 @@ func LoadGitRepo(t *testing.T, ctx *context.Context) { assert.NoError(t, err) } -type mockResponseWriter struct { - httptest.ResponseRecorder - size int -} - -func (rw *mockResponseWriter) Write(b []byte) (int, error) { - rw.size += len(b) - return rw.ResponseRecorder.Write(b) -} - -func (rw *mockResponseWriter) Status() int { - return rw.ResponseRecorder.Code -} - -func (rw *mockResponseWriter) Written() bool { - return rw.ResponseRecorder.Code > 0 -} - -func (rw *mockResponseWriter) Size() int { - return rw.size -} - -func (rw *mockResponseWriter) Push(target string, opts *http.PushOptions) error { - return nil -} - type mockRender struct{} func (tr *mockRender) TemplateLookup(tmpl string) (templates.TemplateExecutor, error) { diff --git a/modules/translation/translation.go b/modules/translation/translation.go index 49dfa84d1b..dba4de6607 100644 --- a/modules/translation/translation.go +++ b/modules/translation/translation.go @@ -38,10 +38,12 @@ type LangType struct { } var ( - lock *sync.RWMutex + lock *sync.RWMutex + + allLangs []*LangType + allLangMap map[string]*LangType + matcher language.Matcher - allLangs []*LangType - allLangMap map[string]*LangType supportedTags []language.Tag ) @@ -251,3 +253,9 @@ func (l *locale) PrettyNumber(v any) string { } return l.msgPrinter.Sprintf("%v", number.Decimal(v)) } + +func init() { + // prepare a default matcher, especially for tests + supportedTags = []language.Tag{language.English} + matcher = language.NewMatcher(supportedTags) +} diff --git a/modules/web/handler.go b/modules/web/handler.go index bfb83820c8..5013bac93f 100644 --- a/modules/web/handler.go +++ b/modules/web/handler.go @@ -10,6 +10,7 @@ import ( "reflect" "code.gitea.io/gitea/modules/context" + "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/web/routing" ) @@ -25,6 +26,10 @@ var argTypeProvider = map[reflect.Type]func(req *http.Request) ResponseStatusPro reflect.TypeOf(&context.PrivateContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetPrivateContext(req) }, } +func RegisterHandleTypeProvider[T any](fn func(req *http.Request) ResponseStatusProvider) { + argTypeProvider[reflect.TypeOf((*T)(nil)).Elem()] = fn +} + // responseWriter is a wrapper of http.ResponseWriter, to check whether the response has been written type responseWriter struct { respWriter http.ResponseWriter @@ -78,7 +83,13 @@ func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) { } } -func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value) []reflect.Value { +func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value { + defer func() { + if err := recover(); err != nil { + log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err) + panic(err) + } + }() isPreCheck := req == nil argsIn := make([]reflect.Value, fn.Type().NumIn()) @@ -155,7 +166,7 @@ func toHandlerProvider(handler any) func(next http.Handler) http.Handler { } // prepare the arguments for the handler and do pre-check - argsIn := prepareHandleArgsIn(resp, req, fn) + argsIn := prepareHandleArgsIn(resp, req, fn, funcInfo) if req == nil { preCheckHandler(fn, argsIn) return // it's doing pre-check, just return diff --git a/routers/api/actions/artifacts.go b/routers/api/actions/artifacts.go index 61d432c862..4b10cd7ad1 100644 --- a/routers/api/actions/artifacts.go +++ b/routers/api/actions/artifacts.go @@ -3,7 +3,7 @@ package actions -// Github Actions Artifacts API Simple Description +// GitHub Actions Artifacts API Simple Description // // 1. Upload artifact // 1.1. Post upload url @@ -63,7 +63,6 @@ package actions import ( "compress/gzip" - gocontext "context" "crypto/md5" "encoding/base64" "errors" @@ -92,9 +91,25 @@ const ( const artifactRouteBase = "/_apis/pipelines/workflows/{run_id}/artifacts" -func ArtifactsRoutes(goctx gocontext.Context, prefix string) *web.Route { +type artifactContextKeyType struct{} + +var artifactContextKey = artifactContextKeyType{} + +type ArtifactContext struct { + *context.Base + + ActionTask *actions.ActionTask +} + +func init() { + web.RegisterHandleTypeProvider[*ArtifactContext](func(req *http.Request) web.ResponseStatusProvider { + return req.Context().Value(artifactContextKey).(*ArtifactContext) + }) +} + +func ArtifactsRoutes(prefix string) *web.Route { m := web.NewRoute() - m.Use(withContexter(goctx)) + m.Use(ArtifactContexter()) r := artifactRoutes{ prefix: prefix, @@ -115,15 +130,14 @@ func ArtifactsRoutes(goctx gocontext.Context, prefix string) *web.Route { return m } -// withContexter initializes a package context for a request. -func withContexter(goctx gocontext.Context) func(next http.Handler) http.Handler { +func ArtifactContexter() func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - ctx := context.Context{ - Resp: context.NewResponse(resp), - Data: map[string]interface{}{}, - } - defer ctx.Close() + base, baseCleanUp := context.NewBaseContext(resp, req) + defer baseCleanUp() + + ctx := &ArtifactContext{Base: base} + ctx.AppendContextValue(artifactContextKey, ctx) // action task call server api with Bearer ACTIONS_RUNTIME_TOKEN // we should verify the ACTIONS_RUNTIME_TOKEN @@ -132,6 +146,7 @@ func withContexter(goctx gocontext.Context) func(next http.Handler) http.Handler ctx.Error(http.StatusUnauthorized, "Bad authorization header") return } + authToken := strings.TrimPrefix(authHeader, "Bearer ") task, err := actions.GetRunningTaskByToken(req.Context(), authToken) if err != nil { @@ -139,16 +154,14 @@ func withContexter(goctx gocontext.Context) func(next http.Handler) http.Handler ctx.Error(http.StatusInternalServerError, "Error runner api getting task") return } - ctx.Data["task"] = task - if err := task.LoadJob(goctx); err != nil { + if err := task.LoadJob(req.Context()); err != nil { log.Error("Error runner api getting job: %v", err) ctx.Error(http.StatusInternalServerError, "Error runner api getting job") return } - ctx.Req = context.WithContext(req, &ctx) - + ctx.ActionTask = task next.ServeHTTP(ctx.Resp, ctx.Req) }) } @@ -175,13 +188,8 @@ type getUploadArtifactResponse struct { FileContainerResourceURL string `json:"fileContainerResourceUrl"` } -func (ar artifactRoutes) validateRunID(ctx *context.Context) (*actions.ActionTask, int64, bool) { - task, ok := ctx.Data["task"].(*actions.ActionTask) - if !ok { - log.Error("Error getting task in context") - ctx.Error(http.StatusInternalServerError, "Error getting task in context") - return nil, 0, false - } +func (ar artifactRoutes) validateRunID(ctx *ArtifactContext) (*actions.ActionTask, int64, bool) { + task := ctx.ActionTask runID := ctx.ParamsInt64("run_id") if task.Job.RunID != runID { log.Error("Error runID not match") @@ -192,7 +200,7 @@ func (ar artifactRoutes) validateRunID(ctx *context.Context) (*actions.ActionTas } // getUploadArtifactURL generates a URL for uploading an artifact -func (ar artifactRoutes) getUploadArtifactURL(ctx *context.Context) { +func (ar artifactRoutes) getUploadArtifactURL(ctx *ArtifactContext) { task, runID, ok := ar.validateRunID(ctx) if !ok { return @@ -220,7 +228,7 @@ func (ar artifactRoutes) getUploadArtifactURL(ctx *context.Context) { // getUploadFileSize returns the size of the file to be uploaded. // The raw size is the size of the file as reported by the header X-TFS-FileLength. -func (ar artifactRoutes) getUploadFileSize(ctx *context.Context) (int64, int64, error) { +func (ar artifactRoutes) getUploadFileSize(ctx *ArtifactContext) (int64, int64, error) { contentLength := ctx.Req.ContentLength xTfsLength, _ := strconv.ParseInt(ctx.Req.Header.Get(artifactXTfsFileLengthHeader), 10, 64) if xTfsLength > 0 { @@ -229,7 +237,7 @@ func (ar artifactRoutes) getUploadFileSize(ctx *context.Context) (int64, int64, return contentLength, contentLength, nil } -func (ar artifactRoutes) saveUploadChunk(ctx *context.Context, +func (ar artifactRoutes) saveUploadChunk(ctx *ArtifactContext, artifact *actions.ActionArtifact, contentSize, runID int64, ) (int64, error) { @@ -273,7 +281,7 @@ func (ar artifactRoutes) saveUploadChunk(ctx *context.Context, // The rules are from https://github.com/actions/toolkit/blob/main/packages/artifact/src/internal/path-and-artifact-name-validation.ts#L32 var invalidArtifactNameChars = strings.Join([]string{"\\", "/", "\"", ":", "<", ">", "|", "*", "?", "\r", "\n"}, "") -func (ar artifactRoutes) uploadArtifact(ctx *context.Context) { +func (ar artifactRoutes) uploadArtifact(ctx *ArtifactContext) { _, runID, ok := ar.validateRunID(ctx) if !ok { return @@ -341,7 +349,7 @@ func (ar artifactRoutes) uploadArtifact(ctx *context.Context) { // comfirmUploadArtifact comfirm upload artifact. // if all chunks are uploaded, merge them to one file. -func (ar artifactRoutes) comfirmUploadArtifact(ctx *context.Context) { +func (ar artifactRoutes) comfirmUploadArtifact(ctx *ArtifactContext) { _, runID, ok := ar.validateRunID(ctx) if !ok { return @@ -364,7 +372,7 @@ type chunkItem struct { Path string } -func (ar artifactRoutes) mergeArtifactChunks(ctx *context.Context, runID int64) error { +func (ar artifactRoutes) mergeArtifactChunks(ctx *ArtifactContext, runID int64) error { storageDir := fmt.Sprintf("tmp%d", runID) var chunks []*chunkItem if err := ar.fs.IterateObjects(storageDir, func(path string, obj storage.Object) error { @@ -415,14 +423,20 @@ func (ar artifactRoutes) mergeArtifactChunks(ctx *context.Context, runID int64) // use multiReader readers := make([]io.Reader, 0, len(allChunks)) - readerClosers := make([]io.Closer, 0, len(allChunks)) + closeReaders := func() { + for _, r := range readers { + _ = r.(io.Closer).Close() // it guarantees to be io.Closer by the following loop's Open function + } + readers = nil + } + defer closeReaders() + for _, c := range allChunks { - reader, err := ar.fs.Open(c.Path) - if err != nil { + var readCloser io.ReadCloser + if readCloser, err = ar.fs.Open(c.Path); err != nil { return fmt.Errorf("open chunk error: %v, %s", err, c.Path) } - readers = append(readers, reader) - readerClosers = append(readerClosers, reader) + readers = append(readers, readCloser) } mergedReader := io.MultiReader(readers...) @@ -445,11 +459,6 @@ func (ar artifactRoutes) mergeArtifactChunks(ctx *context.Context, runID int64) return fmt.Errorf("merged file size is not equal to chunk length") } - // close readers - for _, r := range readerClosers { - r.Close() - } - // save storage path to artifact log.Debug("[artifact] merge chunks to artifact: %d, %s", artifact.ID, storagePath) artifact.StoragePath = storagePath @@ -458,6 +467,8 @@ func (ar artifactRoutes) mergeArtifactChunks(ctx *context.Context, runID int64) return fmt.Errorf("update artifact error: %v", err) } + closeReaders() // close before delete + // drop chunks for _, c := range cs { if err := ar.fs.Delete(c.Path); err != nil { @@ -479,21 +490,21 @@ type ( } ) -func (ar artifactRoutes) listArtifacts(ctx *context.Context) { +func (ar artifactRoutes) listArtifacts(ctx *ArtifactContext) { _, runID, ok := ar.validateRunID(ctx) if !ok { return } - artficats, err := actions.ListArtifactsByRunID(ctx, runID) + artifacts, err := actions.ListArtifactsByRunID(ctx, runID) if err != nil { log.Error("Error getting artifacts: %v", err) ctx.Error(http.StatusInternalServerError, err.Error()) return } - artficatsData := make([]listArtifactsResponseItem, 0, len(artficats)) - for _, a := range artficats { + artficatsData := make([]listArtifactsResponseItem, 0, len(artifacts)) + for _, a := range artifacts { artficatsData = append(artficatsData, listArtifactsResponseItem{ Name: a.ArtifactName, FileContainerResourceURL: ar.buildArtifactURL(runID, a.ID, "path"), @@ -517,7 +528,7 @@ type ( } ) -func (ar artifactRoutes) getDownloadArtifactURL(ctx *context.Context) { +func (ar artifactRoutes) getDownloadArtifactURL(ctx *ArtifactContext) { _, runID, ok := ar.validateRunID(ctx) if !ok { return @@ -546,7 +557,7 @@ func (ar artifactRoutes) getDownloadArtifactURL(ctx *context.Context) { ctx.JSON(http.StatusOK, respData) } -func (ar artifactRoutes) downloadArtifact(ctx *context.Context) { +func (ar artifactRoutes) downloadArtifact(ctx *ArtifactContext) { _, runID, ok := ar.validateRunID(ctx) if !ok { return diff --git a/routers/api/packages/api.go b/routers/api/packages/api.go index aaceb8a92b..e715997e82 100644 --- a/routers/api/packages/api.go +++ b/routers/api/packages/api.go @@ -98,7 +98,7 @@ func verifyAuth(r *web.Route, authMethods []auth.Method) { func CommonRoutes(ctx gocontext.Context) *web.Route { r := web.NewRoute() - r.Use(context.PackageContexter(ctx)) + r.Use(context.PackageContexter()) verifyAuth(r, []auth.Method{ &auth.OAuth2{}, @@ -574,7 +574,7 @@ func CommonRoutes(ctx gocontext.Context) *web.Route { func ContainerRoutes(ctx gocontext.Context) *web.Route { r := web.NewRoute() - r.Use(context.PackageContexter(ctx)) + r.Use(context.PackageContexter()) verifyAuth(r, []auth.Method{ &auth.Basic{}, diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index a67a5420ac..f1e1cf946a 100644 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -149,7 +149,7 @@ func repoAssignment() func(ctx *context.APIContext) { if err != nil { if user_model.IsErrUserNotExist(err) { if redirectUserID, err := user_model.LookupUserRedirect(userName); err == nil { - context.RedirectToUser(ctx.Context, userName, redirectUserID) + context.RedirectToUser(ctx.Base, userName, redirectUserID) } else if user_model.IsErrUserRedirectNotExist(err) { ctx.NotFound("GetUserByName", err) } else { @@ -170,7 +170,7 @@ func repoAssignment() func(ctx *context.APIContext) { if repo_model.IsErrRepoNotExist(err) { redirectRepoID, err := repo_model.LookupRedirect(owner.ID, repoName) if err == nil { - context.RedirectToRepo(ctx.Context, redirectRepoID) + context.RedirectToRepo(ctx.Base, redirectRepoID) } else if repo_model.IsErrRedirectNotExist(err) { ctx.NotFound() } else { @@ -274,7 +274,7 @@ func reqToken(requiredScope auth_model.AccessTokenScope) func(ctx *context.APICo ctx.Error(http.StatusForbidden, "reqToken", "token does not have required scope: "+requiredScope) return } - if ctx.Context.IsBasicAuth { + if ctx.IsBasicAuth { ctx.CheckForOTP() return } @@ -295,7 +295,7 @@ func reqExploreSignIn() func(ctx *context.APIContext) { func reqBasicAuth() func(ctx *context.APIContext) { return func(ctx *context.APIContext) { - if !ctx.Context.IsBasicAuth { + if !ctx.IsBasicAuth { ctx.Error(http.StatusUnauthorized, "reqBasicAuth", "auth required") return } @@ -375,7 +375,7 @@ func reqAnyRepoReader() func(ctx *context.APIContext) { // reqOrgOwnership user should be an organization owner, or a site admin func reqOrgOwnership() func(ctx *context.APIContext) { return func(ctx *context.APIContext) { - if ctx.Context.IsUserSiteAdmin() { + if ctx.IsUserSiteAdmin() { return } @@ -407,7 +407,7 @@ func reqOrgOwnership() func(ctx *context.APIContext) { // reqTeamMembership user should be an team member, or a site admin func reqTeamMembership() func(ctx *context.APIContext) { return func(ctx *context.APIContext) { - if ctx.Context.IsUserSiteAdmin() { + if ctx.IsUserSiteAdmin() { return } if ctx.Org.Team == nil { @@ -444,7 +444,7 @@ func reqTeamMembership() func(ctx *context.APIContext) { // reqOrgMembership user should be an organization member, or a site admin func reqOrgMembership() func(ctx *context.APIContext) { return func(ctx *context.APIContext) { - if ctx.Context.IsUserSiteAdmin() { + if ctx.IsUserSiteAdmin() { return } @@ -512,7 +512,7 @@ func orgAssignment(args ...bool) func(ctx *context.APIContext) { if organization.IsErrOrgNotExist(err) { redirectUserID, err := user_model.LookupUserRedirect(ctx.Params(":org")) if err == nil { - context.RedirectToUser(ctx.Context, ctx.Params(":org"), redirectUserID) + context.RedirectToUser(ctx.Base, ctx.Params(":org"), redirectUserID) } else if user_model.IsErrUserRedirectNotExist(err) { ctx.NotFound("GetOrgByName", err) } else { diff --git a/routers/api/v1/misc/markup.go b/routers/api/v1/misc/markup.go index 93d5754444..7b24b353b6 100644 --- a/routers/api/v1/misc/markup.go +++ b/routers/api/v1/misc/markup.go @@ -41,7 +41,7 @@ func Markup(ctx *context.APIContext) { return } - common.RenderMarkup(ctx.Context, form.Mode, form.Text, form.Context, form.FilePath, form.Wiki) + common.RenderMarkup(ctx.Base, ctx.Repo, form.Mode, form.Text, form.Context, form.FilePath, form.Wiki) } // Markdown render markdown document to HTML @@ -76,7 +76,7 @@ func Markdown(ctx *context.APIContext) { mode = form.Mode } - common.RenderMarkup(ctx.Context, mode, form.Text, form.Context, "", form.Wiki) + common.RenderMarkup(ctx.Base, ctx.Repo, mode, form.Text, form.Context, "", form.Wiki) } // MarkdownRaw render raw markdown HTML diff --git a/routers/api/v1/misc/markup_test.go b/routers/api/v1/misc/markup_test.go index 68776613b2..fdf540fd65 100644 --- a/routers/api/v1/misc/markup_test.go +++ b/routers/api/v1/misc/markup_test.go @@ -16,7 +16,6 @@ import ( "code.gitea.io/gitea/modules/markup" "code.gitea.io/gitea/modules/setting" api "code.gitea.io/gitea/modules/structs" - "code.gitea.io/gitea/modules/templates" "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/modules/web" "code.gitea.io/gitea/modules/web/middleware" @@ -30,26 +29,16 @@ const ( AppSubURL = AppURL + Repo + "/" ) -func createContext(req *http.Request) (*context.Context, *httptest.ResponseRecorder) { - rnd := templates.HTMLRenderer() +func createAPIContext(req *http.Request) (*context.APIContext, *httptest.ResponseRecorder) { resp := httptest.NewRecorder() - c := &context.Context{ - Req: req, - Resp: context.NewResponse(resp), - Render: rnd, - Data: make(middleware.ContextData), - } - defer c.Close() + base, baseCleanUp := context.NewBaseContext(resp, req) + base.Data = middleware.ContextData{} + c := &context.APIContext{Base: base} + _ = baseCleanUp // during test, it doesn't need to do clean up. TODO: this can be improved later return c, resp } -func wrap(ctx *context.Context) *context.APIContext { - return &context.APIContext{ - Context: ctx, - } -} - func testRenderMarkup(t *testing.T, mode, filePath, text, responseBody string, responseCode int) { setting.AppURL = AppURL @@ -65,8 +54,7 @@ func testRenderMarkup(t *testing.T, mode, filePath, text, responseBody string, r Method: "POST", URL: requrl, } - m, resp := createContext(req) - ctx := wrap(m) + ctx, resp := createAPIContext(req) options.Text = text web.SetForm(ctx, &options) @@ -90,8 +78,7 @@ func testRenderMarkdown(t *testing.T, mode, text, responseBody string, responseC Method: "POST", URL: requrl, } - m, resp := createContext(req) - ctx := wrap(m) + ctx, resp := createAPIContext(req) options.Text = text web.SetForm(ctx, &options) @@ -211,8 +198,7 @@ func TestAPI_RenderSimple(t *testing.T) { Method: "POST", URL: requrl, } - m, resp := createContext(req) - ctx := wrap(m) + ctx, resp := createAPIContext(req) for i := 0; i < len(simpleCases); i += 2 { options.Text = simpleCases[i] @@ -231,8 +217,7 @@ func TestAPI_RenderRaw(t *testing.T) { Method: "POST", URL: requrl, } - m, resp := createContext(req) - ctx := wrap(m) + ctx, resp := createAPIContext(req) for i := 0; i < len(simpleCases); i += 2 { ctx.Req.Body = io.NopCloser(strings.NewReader(simpleCases[i])) diff --git a/routers/api/v1/notify/notifications.go b/routers/api/v1/notify/notifications.go index 3b6a9bfdc2..b22ea8a771 100644 --- a/routers/api/v1/notify/notifications.go +++ b/routers/api/v1/notify/notifications.go @@ -25,7 +25,7 @@ func NewAvailable(ctx *context.APIContext) { } func getFindNotificationOptions(ctx *context.APIContext) *activities_model.FindNotificationOptions { - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return nil diff --git a/routers/api/v1/repo/file.go b/routers/api/v1/repo/file.go index eb63dda590..786407827c 100644 --- a/routers/api/v1/repo/file.go +++ b/routers/api/v1/repo/file.go @@ -80,7 +80,7 @@ func GetRawFile(ctx *context.APIContext) { ctx.RespHeader().Set(giteaObjectTypeHeader, string(files_service.GetObjectTypeFromTreeEntry(entry))) - if err := common.ServeBlob(ctx.Context, blob, lastModified); err != nil { + if err := common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, lastModified); err != nil { ctx.Error(http.StatusInternalServerError, "ServeBlob", err) } } @@ -137,7 +137,7 @@ func GetRawFileOrLFS(ctx *context.APIContext) { } // OK not cached - serve! - if err := common.ServeBlob(ctx.Context, blob, lastModified); err != nil { + if err := common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, lastModified); err != nil { ctx.ServerError("ServeBlob", err) } return @@ -159,7 +159,7 @@ func GetRawFileOrLFS(ctx *context.APIContext) { } if err := dataRc.Close(); err != nil { - log.Error("Error whilst closing blob %s reader in %-v. Error: %v", blob.ID, ctx.Context.Repo.Repository, err) + log.Error("Error whilst closing blob %s reader in %-v. Error: %v", blob.ID, ctx.Repo.Repository, err) } // Check if the blob represents a pointer @@ -173,7 +173,7 @@ func GetRawFileOrLFS(ctx *context.APIContext) { } // OK not cached - serve! - common.ServeContentByReader(ctx.Context, ctx.Repo.TreePath, blob.Size(), bytes.NewReader(buf)) + common.ServeContentByReader(ctx.Base, ctx.Repo.TreePath, blob.Size(), bytes.NewReader(buf)) return } @@ -187,7 +187,7 @@ func GetRawFileOrLFS(ctx *context.APIContext) { return } - common.ServeContentByReader(ctx.Context, ctx.Repo.TreePath, blob.Size(), bytes.NewReader(buf)) + common.ServeContentByReader(ctx.Base, ctx.Repo.TreePath, blob.Size(), bytes.NewReader(buf)) return } else if err != nil { ctx.ServerError("GetLFSMetaObjectByOid", err) @@ -215,7 +215,7 @@ func GetRawFileOrLFS(ctx *context.APIContext) { } defer lfsDataRc.Close() - common.ServeContentByReadSeeker(ctx.Context, ctx.Repo.TreePath, lastModified, lfsDataRc) + common.ServeContentByReadSeeker(ctx.Base, ctx.Repo.TreePath, lastModified, lfsDataRc) } func getBlobForEntry(ctx *context.APIContext) (blob *git.Blob, entry *git.TreeEntry, lastModified time.Time) { diff --git a/routers/api/v1/repo/hook_test.go b/routers/api/v1/repo/hook_test.go index 34dc990c3d..56658b45d5 100644 --- a/routers/api/v1/repo/hook_test.go +++ b/routers/api/v1/repo/hook_test.go @@ -9,7 +9,6 @@ import ( "code.gitea.io/gitea/models/unittest" "code.gitea.io/gitea/models/webhook" - "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/test" "github.com/stretchr/testify/assert" @@ -18,12 +17,12 @@ import ( func TestTestHook(t *testing.T) { unittest.PrepareTestEnv(t) - ctx := test.MockContext(t, "user2/repo1/wiki/_pages") + ctx := test.MockAPIContext(t, "user2/repo1/wiki/_pages") ctx.SetParams(":id", "1") test.LoadRepo(t, ctx, 1) test.LoadRepoCommit(t, ctx) test.LoadUser(t, ctx, 2) - TestHook(&context.APIContext{Context: ctx, Org: nil}) + TestHook(ctx) assert.EqualValues(t, http.StatusNoContent, ctx.Resp.Status()) unittest.AssertExistsAndLoadBean(t, &webhook.HookTask{ diff --git a/routers/api/v1/repo/issue.go b/routers/api/v1/repo/issue.go index 95528d664d..49252f7a4b 100644 --- a/routers/api/v1/repo/issue.go +++ b/routers/api/v1/repo/issue.go @@ -116,7 +116,7 @@ func SearchIssues(ctx *context.APIContext) { // "200": // "$ref": "#/responses/IssueList" - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return @@ -368,7 +368,7 @@ func ListIssues(ctx *context.APIContext) { // responses: // "200": // "$ref": "#/responses/IssueList" - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return diff --git a/routers/api/v1/repo/issue_comment.go b/routers/api/v1/repo/issue_comment.go index 6ae6063303..7c8f30f116 100644 --- a/routers/api/v1/repo/issue_comment.go +++ b/routers/api/v1/repo/issue_comment.go @@ -59,7 +59,7 @@ func ListIssueComments(ctx *context.APIContext) { // "200": // "$ref": "#/responses/CommentList" - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return @@ -156,7 +156,7 @@ func ListIssueCommentsAndTimeline(ctx *context.APIContext) { // "200": // "$ref": "#/responses/TimelineList" - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return @@ -259,7 +259,7 @@ func ListRepoIssueComments(ctx *context.APIContext) { // "200": // "$ref": "#/responses/CommentList" - before, since, err := context.GetQueryBeforeSince(ctx.Context) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return diff --git a/routers/api/v1/repo/issue_tracked_time.go b/routers/api/v1/repo/issue_tracked_time.go index 16bb8cb73d..1ff934950c 100644 --- a/routers/api/v1/repo/issue_tracked_time.go +++ b/routers/api/v1/repo/issue_tracked_time.go @@ -103,7 +103,7 @@ func ListTrackedTimes(ctx *context.APIContext) { opts.UserID = user.ID } - if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Context); err != nil { + if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Base); err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return } @@ -522,7 +522,7 @@ func ListTrackedTimesByRepository(ctx *context.APIContext) { } var err error - if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Context); err != nil { + if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Base); err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return } @@ -596,7 +596,7 @@ func ListMyTrackedTimes(ctx *context.APIContext) { } var err error - if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Context); err != nil { + if opts.CreatedBeforeUnix, opts.CreatedAfterUnix, err = context.GetQueryBeforeSince(ctx.Base); err != nil { ctx.Error(http.StatusUnprocessableEntity, "GetQueryBeforeSince", err) return } diff --git a/routers/api/v1/repo/migrate.go b/routers/api/v1/repo/migrate.go index efce39e520..b458cd122b 100644 --- a/routers/api/v1/repo/migrate.go +++ b/routers/api/v1/repo/migrate.go @@ -79,7 +79,7 @@ func Migrate(ctx *context.APIContext) { return } - if ctx.HasError() { + if ctx.HasAPIError() { ctx.Error(http.StatusUnprocessableEntity, "", ctx.GetErrMsg()) return } diff --git a/routers/api/v1/repo/repo_test.go b/routers/api/v1/repo/repo_test.go index 59c3bde819..e1bdea5c82 100644 --- a/routers/api/v1/repo/repo_test.go +++ b/routers/api/v1/repo/repo_test.go @@ -9,7 +9,6 @@ import ( repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/models/unittest" - "code.gitea.io/gitea/modules/context" api "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/test" "code.gitea.io/gitea/modules/web" @@ -20,7 +19,7 @@ import ( func TestRepoEdit(t *testing.T) { unittest.PrepareTestEnv(t) - ctx := test.MockContext(t, "user2/repo1") + ctx := test.MockAPIContext(t, "user2/repo1") test.LoadRepo(t, ctx, 1) test.LoadUser(t, ctx, 2) ctx.Repo.Owner = ctx.Doer @@ -54,9 +53,8 @@ func TestRepoEdit(t *testing.T) { Archived: &archived, } - apiCtx := &context.APIContext{Context: ctx, Org: nil} - web.SetForm(apiCtx, &opts) - Edit(apiCtx) + web.SetForm(ctx, &opts) + Edit(ctx) assert.EqualValues(t, http.StatusOK, ctx.Resp.Status()) unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ @@ -67,7 +65,7 @@ func TestRepoEdit(t *testing.T) { func TestRepoEditNameChange(t *testing.T) { unittest.PrepareTestEnv(t) - ctx := test.MockContext(t, "user2/repo1") + ctx := test.MockAPIContext(t, "user2/repo1") test.LoadRepo(t, ctx, 1) test.LoadUser(t, ctx, 2) ctx.Repo.Owner = ctx.Doer @@ -76,9 +74,8 @@ func TestRepoEditNameChange(t *testing.T) { Name: &name, } - apiCtx := &context.APIContext{Context: ctx, Org: nil} - web.SetForm(apiCtx, &opts) - Edit(apiCtx) + web.SetForm(ctx, &opts) + Edit(ctx) assert.EqualValues(t, http.StatusOK, ctx.Resp.Status()) unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ diff --git a/routers/api/v1/repo/status.go b/routers/api/v1/repo/status.go index 5158f38e14..c1110ebce5 100644 --- a/routers/api/v1/repo/status.go +++ b/routers/api/v1/repo/status.go @@ -183,7 +183,7 @@ func getCommitStatuses(ctx *context.APIContext, sha string) { ctx.Error(http.StatusBadRequest, "ref/sha not given", nil) return } - sha = utils.MustConvertToSHA1(ctx.Context, sha) + sha = utils.MustConvertToSHA1(ctx.Base, ctx.Repo, sha) repo := ctx.Repo.Repository listOptions := utils.GetListOptions(ctx) diff --git a/routers/api/v1/user/helper.go b/routers/api/v1/user/helper.go index 28f600ad92..4b642910b1 100644 --- a/routers/api/v1/user/helper.go +++ b/routers/api/v1/user/helper.go @@ -17,7 +17,7 @@ func GetUserByParamsName(ctx *context.APIContext, name string) *user_model.User if err != nil { if user_model.IsErrUserNotExist(err) { if redirectUserID, err2 := user_model.LookupUserRedirect(username); err2 == nil { - context.RedirectToUser(ctx.Context, username, redirectUserID) + context.RedirectToUser(ctx.Base, username, redirectUserID) } else { ctx.NotFound("GetUserByName", err) } diff --git a/routers/api/v1/utils/git.go b/routers/api/v1/utils/git.go index eaf0f5fd37..32f5c85319 100644 --- a/routers/api/v1/utils/git.go +++ b/routers/api/v1/utils/git.go @@ -4,6 +4,7 @@ package utils import ( + gocontext "context" "fmt" "net/http" @@ -33,7 +34,7 @@ func ResolveRefOrSha(ctx *context.APIContext, ref string) string { } } - sha = MustConvertToSHA1(ctx.Context, sha) + sha = MustConvertToSHA1(ctx, ctx.Repo, sha) if ctx.Repo.GitRepo != nil { err := ctx.Repo.GitRepo.AddLastCommitCache(ctx.Repo.Repository.GetCommitsCountCacheKey(ref, ref != sha), ctx.Repo.Repository.FullName(), sha) @@ -69,7 +70,7 @@ func searchRefCommitByType(ctx *context.APIContext, refType, filter string) (str } // ConvertToSHA1 returns a full-length SHA1 from a potential ID string -func ConvertToSHA1(ctx *context.Context, commitID string) (git.SHA1, error) { +func ConvertToSHA1(ctx gocontext.Context, repo *context.Repository, commitID string) (git.SHA1, error) { if len(commitID) == git.SHAFullLength && git.IsValidSHAPattern(commitID) { sha1, err := git.NewIDFromString(commitID) if err == nil { @@ -77,7 +78,7 @@ func ConvertToSHA1(ctx *context.Context, commitID string) (git.SHA1, error) { } } - gitRepo, closer, err := git.RepositoryFromContextOrOpen(ctx, ctx.Repo.Repository.RepoPath()) + gitRepo, closer, err := git.RepositoryFromContextOrOpen(ctx, repo.Repository.RepoPath()) if err != nil { return git.SHA1{}, fmt.Errorf("RepositoryFromContextOrOpen: %w", err) } @@ -87,8 +88,8 @@ func ConvertToSHA1(ctx *context.Context, commitID string) (git.SHA1, error) { } // MustConvertToSHA1 returns a full-length SHA1 string from a potential ID string, or returns origin input if it can't convert to SHA1 -func MustConvertToSHA1(ctx *context.Context, commitID string) string { - sha, err := ConvertToSHA1(ctx, commitID) +func MustConvertToSHA1(ctx gocontext.Context, repo *context.Repository, commitID string) string { + sha, err := ConvertToSHA1(ctx, repo, commitID) if err != nil { return commitID } diff --git a/routers/common/markup.go b/routers/common/markup.go index 3acd12721e..5f412014d7 100644 --- a/routers/common/markup.go +++ b/routers/common/markup.go @@ -19,7 +19,7 @@ import ( ) // RenderMarkup renders markup text for the /markup and /markdown endpoints -func RenderMarkup(ctx *context.Context, mode, text, urlPrefix, filePath string, wiki bool) { +func RenderMarkup(ctx *context.Base, repo *context.Repository, mode, text, urlPrefix, filePath string, wiki bool) { var markupType string relativePath := "" @@ -63,11 +63,11 @@ func RenderMarkup(ctx *context.Context, mode, text, urlPrefix, filePath string, } meta := map[string]string{} - if ctx.Repo != nil && ctx.Repo.Repository != nil { + if repo != nil && repo.Repository != nil { if mode == "comment" { - meta = ctx.Repo.Repository.ComposeMetas() + meta = repo.Repository.ComposeMetas() } else { - meta = ctx.Repo.Repository.ComposeDocumentMetas() + meta = repo.Repository.ComposeDocumentMetas() } } if mode != "comment" { diff --git a/routers/common/middleware.go b/routers/common/middleware.go index c1ee9dd765..a25ff1ee00 100644 --- a/routers/common/middleware.go +++ b/routers/common/middleware.go @@ -42,7 +42,7 @@ func ProtocolMiddlewares() (handlers []any) { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { ctx, _, finished := process.GetManager().AddTypedContext(req.Context(), fmt.Sprintf("%s: %s", req.Method, req.RequestURI), process.RequestProcessType, true) defer finished() - next.ServeHTTP(context.NewResponse(resp), req.WithContext(cache.WithCacheContext(ctx))) + next.ServeHTTP(context.WrapResponseWriter(resp), req.WithContext(cache.WithCacheContext(ctx))) }) }) diff --git a/routers/common/serve.go b/routers/common/serve.go index 59b993328e..3094ee6a6e 100644 --- a/routers/common/serve.go +++ b/routers/common/serve.go @@ -15,7 +15,7 @@ import ( ) // ServeBlob download a git.Blob -func ServeBlob(ctx *context.Context, blob *git.Blob, lastModified time.Time) error { +func ServeBlob(ctx *context.Base, filePath string, blob *git.Blob, lastModified time.Time) error { if httpcache.HandleGenericETagTimeCache(ctx.Req, ctx.Resp, `"`+blob.ID.String()+`"`, lastModified) { return nil } @@ -30,14 +30,14 @@ func ServeBlob(ctx *context.Context, blob *git.Blob, lastModified time.Time) err } }() - httplib.ServeContentByReader(ctx.Req, ctx.Resp, ctx.Repo.TreePath, blob.Size(), dataRc) + httplib.ServeContentByReader(ctx.Req, ctx.Resp, filePath, blob.Size(), dataRc) return nil } -func ServeContentByReader(ctx *context.Context, filePath string, size int64, reader io.Reader) { +func ServeContentByReader(ctx *context.Base, filePath string, size int64, reader io.Reader) { httplib.ServeContentByReader(ctx.Req, ctx.Resp, filePath, size, reader) } -func ServeContentByReadSeeker(ctx *context.Context, filePath string, modTime time.Time, reader io.ReadSeeker) { +func ServeContentByReadSeeker(ctx *context.Base, filePath string, modTime time.Time, reader io.ReadSeeker) { httplib.ServeContentByReadSeeker(ctx.Req, ctx.Resp, filePath, modTime, reader) } diff --git a/routers/init.go b/routers/init.go index 087d8c2915..5737ef3dc0 100644 --- a/routers/init.go +++ b/routers/init.go @@ -198,7 +198,7 @@ func NormalRoutes(ctx context.Context) *web.Route { // In Github, it uses ACTIONS_RUNTIME_URL=https://pipelines.actions.githubusercontent.com/fLgcSHkPGySXeIFrg8W8OBSfeg3b5Fls1A1CwX566g8PayEGlg/ // TODO: this prefix should be generated with a token string with runner ? prefix = "/api/actions_pipeline" - r.Mount(prefix, actions_router.ArtifactsRoutes(ctx, prefix)) + r.Mount(prefix, actions_router.ArtifactsRoutes(prefix)) } return r diff --git a/routers/install/install.go b/routers/install/install.go index 714ddd5548..89b91a5a48 100644 --- a/routers/install/install.go +++ b/routers/install/install.go @@ -58,15 +58,14 @@ func Contexter() func(next http.Handler) http.Handler { dbTypeNames := getSupportedDbTypeNames() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + base, baseCleanUp := context.NewBaseContext(resp, req) ctx := context.Context{ - Resp: context.NewResponse(resp), + Base: base, Flash: &middleware.Flash{}, - Locale: middleware.Locale(resp, req), Render: rnd, - Data: middleware.GetContextData(req.Context()), Session: session.GetSession(req), } - defer ctx.Close() + defer baseCleanUp() ctx.Data.MergeFrom(middleware.CommonTemplateContextData()) ctx.Data.MergeFrom(middleware.ContextData{ @@ -78,7 +77,6 @@ func Contexter() func(next http.Handler) http.Handler { "PasswordHashAlgorithms": hash.RecommendedHashAlgorithms, }) - ctx.Req = context.WithContext(req, &ctx) next.ServeHTTP(resp, ctx.Req) }) } @@ -249,15 +247,8 @@ func SubmitInstall(ctx *context.Context) { ctx.Data["CurDbType"] = form.DbType if ctx.HasError() { - if ctx.HasValue("Err_SMTPUser") { - ctx.Data["Err_SMTP"] = true - } - if ctx.HasValue("Err_AdminName") || - ctx.HasValue("Err_AdminPasswd") || - ctx.HasValue("Err_AdminEmail") { - ctx.Data["Err_Admin"] = true - } - + ctx.Data["Err_SMTP"] = ctx.Data["Err_SMTPUser"] != nil + ctx.Data["Err_Admin"] = ctx.Data["Err_AdminName"] != nil || ctx.Data["Err_AdminPasswd"] != nil || ctx.Data["Err_AdminEmail"] != nil ctx.HTML(http.StatusOK, tplInstall) return } diff --git a/routers/web/misc/markup.go b/routers/web/misc/markup.go index 1690378945..c91da9a7f1 100644 --- a/routers/web/misc/markup.go +++ b/routers/web/misc/markup.go @@ -5,8 +5,6 @@ package misc import ( - "net/http" - "code.gitea.io/gitea/modules/context" api "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/web" @@ -16,11 +14,5 @@ import ( // Markup render markup document to HTML func Markup(ctx *context.Context) { form := web.GetForm(ctx).(*api.MarkupOption) - - if ctx.HasAPIError() { - ctx.Error(http.StatusUnprocessableEntity, "", ctx.GetErrMsg()) - return - } - - common.RenderMarkup(ctx, form.Mode, form.Text, form.Context, form.FilePath, form.Wiki) + common.RenderMarkup(ctx.Base, ctx.Repo, form.Mode, form.Text, form.Context, form.FilePath, form.Wiki) } diff --git a/routers/web/repo/attachment.go b/routers/web/repo/attachment.go index c46ec29841..fb95e63ecf 100644 --- a/routers/web/repo/attachment.go +++ b/routers/web/repo/attachment.go @@ -153,7 +153,7 @@ func ServeAttachment(ctx *context.Context, uuid string) { } defer fr.Close() - common.ServeContentByReadSeeker(ctx, attach.Name, attach.CreatedUnix.AsTime(), fr) + common.ServeContentByReadSeeker(ctx.Base, attach.Name, attach.CreatedUnix.AsTime(), fr) } // GetAttachment serve attachments diff --git a/routers/web/repo/download.go b/routers/web/repo/download.go index 1c87f9bed7..a498180f35 100644 --- a/routers/web/repo/download.go +++ b/routers/web/repo/download.go @@ -47,7 +47,7 @@ func ServeBlobOrLFS(ctx *context.Context, blob *git.Blob, lastModified time.Time log.Error("ServeBlobOrLFS: Close: %v", err) } closed = true - return common.ServeBlob(ctx, blob, lastModified) + return common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, lastModified) } if httpcache.HandleGenericETagCache(ctx.Req, ctx.Resp, `"`+pointer.Oid+`"`) { return nil @@ -71,7 +71,7 @@ func ServeBlobOrLFS(ctx *context.Context, blob *git.Blob, lastModified time.Time log.Error("ServeBlobOrLFS: Close: %v", err) } }() - common.ServeContentByReadSeeker(ctx, ctx.Repo.TreePath, lastModified, lfsDataRc) + common.ServeContentByReadSeeker(ctx.Base, ctx.Repo.TreePath, lastModified, lfsDataRc) return nil } if err = dataRc.Close(); err != nil { @@ -79,7 +79,7 @@ func ServeBlobOrLFS(ctx *context.Context, blob *git.Blob, lastModified time.Time } closed = true - return common.ServeBlob(ctx, blob, lastModified) + return common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, lastModified) } func getBlobForEntry(ctx *context.Context) (blob *git.Blob, lastModified time.Time) { @@ -120,7 +120,7 @@ func SingleDownload(ctx *context.Context) { return } - if err := common.ServeBlob(ctx, blob, lastModified); err != nil { + if err := common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, lastModified); err != nil { ctx.ServerError("ServeBlob", err) } } @@ -148,7 +148,7 @@ func DownloadByID(ctx *context.Context) { } return } - if err = common.ServeBlob(ctx, blob, time.Time{}); err != nil { + if err = common.ServeBlob(ctx.Base, ctx.Repo.TreePath, blob, time.Time{}); err != nil { ctx.ServerError("ServeBlob", err) } } diff --git a/routers/web/repo/http.go b/routers/web/repo/http.go index 4e45a9b6e2..b6ebd25915 100644 --- a/routers/web/repo/http.go +++ b/routers/web/repo/http.go @@ -109,7 +109,7 @@ func httpBase(ctx *context.Context) (h *serviceHandler) { if err != nil { if repo_model.IsErrRepoNotExist(err) { if redirectRepoID, err := repo_model.LookupRedirect(owner.ID, reponame); err == nil { - context.RedirectToRepo(ctx, redirectRepoID) + context.RedirectToRepo(ctx.Base, redirectRepoID) return } repoExist = false diff --git a/routers/web/repo/issue.go b/routers/web/repo/issue.go index 88d2a97a7a..7a0dc9940b 100644 --- a/routers/web/repo/issue.go +++ b/routers/web/repo/issue.go @@ -2344,7 +2344,7 @@ func UpdatePullReviewRequest(ctx *context.Context) { // SearchIssues searches for issues across the repositories that the user has access to func SearchIssues(ctx *context.Context) { - before, since, err := context.GetQueryBeforeSince(ctx) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, err.Error()) return @@ -2545,7 +2545,7 @@ func getUserIDForFilter(ctx *context.Context, queryName string) int64 { // ListIssues list the issues of a repository func ListIssues(ctx *context.Context) { - before, since, err := context.GetQueryBeforeSince(ctx) + before, since, err := context.GetQueryBeforeSince(ctx.Base) if err != nil { ctx.Error(http.StatusUnprocessableEntity, err.Error()) return diff --git a/routers/web/repo/wiki.go b/routers/web/repo/wiki.go index a335c114be..115418887d 100644 --- a/routers/web/repo/wiki.go +++ b/routers/web/repo/wiki.go @@ -671,7 +671,7 @@ func WikiRaw(ctx *context.Context) { } if entry != nil { - if err = common.ServeBlob(ctx, entry.Blob(), time.Time{}); err != nil { + if err = common.ServeBlob(ctx.Base, ctx.Repo.TreePath, entry.Blob(), time.Time{}); err != nil { ctx.ServerError("ServeBlob", err) } return diff --git a/services/auth/middleware.go b/services/auth/middleware.go index 3b2f883d00..d1955a4c90 100644 --- a/services/auth/middleware.go +++ b/services/auth/middleware.go @@ -8,6 +8,7 @@ import ( "strings" "code.gitea.io/gitea/models/auth" + user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -17,11 +18,15 @@ import ( // Auth is a middleware to authenticate a web user func Auth(authMethod Method) func(*context.Context) { return func(ctx *context.Context) { - if err := authShared(ctx, authMethod); err != nil { + ar, err := authShared(ctx.Base, ctx.Session, authMethod) + if err != nil { log.Error("Failed to verify user: %v", err) ctx.Error(http.StatusUnauthorized, "Verify") return } + ctx.Doer = ar.Doer + ctx.IsSigned = ar.Doer != nil + ctx.IsBasicAuth = ar.IsBasicAuth if ctx.Doer == nil { // ensure the session uid is deleted _ = ctx.Session.Delete("uid") @@ -32,32 +37,41 @@ func Auth(authMethod Method) func(*context.Context) { // APIAuth is a middleware to authenticate an api user func APIAuth(authMethod Method) func(*context.APIContext) { return func(ctx *context.APIContext) { - if err := authShared(ctx.Context, authMethod); err != nil { + ar, err := authShared(ctx.Base, nil, authMethod) + if err != nil { ctx.Error(http.StatusUnauthorized, "APIAuth", err) + return } + ctx.Doer = ar.Doer + ctx.IsSigned = ar.Doer != nil + ctx.IsBasicAuth = ar.IsBasicAuth } } -func authShared(ctx *context.Context, authMethod Method) error { - var err error - ctx.Doer, err = authMethod.Verify(ctx.Req, ctx.Resp, ctx, ctx.Session) +type authResult struct { + Doer *user_model.User + IsBasicAuth bool +} + +func authShared(ctx *context.Base, sessionStore SessionStore, authMethod Method) (ar authResult, err error) { + ar.Doer, err = authMethod.Verify(ctx.Req, ctx.Resp, ctx, sessionStore) if err != nil { - return err + return ar, err } - if ctx.Doer != nil { - if ctx.Locale.Language() != ctx.Doer.Language { + if ar.Doer != nil { + if ctx.Locale.Language() != ar.Doer.Language { ctx.Locale = middleware.Locale(ctx.Resp, ctx.Req) } - ctx.IsBasicAuth = ctx.Data["AuthedMethod"].(string) == BasicMethodName - ctx.IsSigned = true - ctx.Data["IsSigned"] = ctx.IsSigned - ctx.Data[middleware.ContextDataKeySignedUser] = ctx.Doer - ctx.Data["SignedUserID"] = ctx.Doer.ID - ctx.Data["IsAdmin"] = ctx.Doer.IsAdmin + ar.IsBasicAuth = ctx.Data["AuthedMethod"].(string) == BasicMethodName + + ctx.Data["IsSigned"] = true + ctx.Data[middleware.ContextDataKeySignedUser] = ar.Doer + ctx.Data["SignedUserID"] = ar.Doer.ID + ctx.Data["IsAdmin"] = ar.Doer.IsAdmin } else { ctx.Data["SignedUserID"] = int64(0) } - return nil + return ar, nil } // VerifyOptions contains required or check options @@ -68,7 +82,7 @@ type VerifyOptions struct { DisableCSRF bool } -// Checks authentication according to options +// VerifyAuthWithOptions checks authentication according to options func VerifyAuthWithOptions(options *VerifyOptions) func(ctx *context.Context) { return func(ctx *context.Context) { // Check prohibit login users. @@ -153,7 +167,7 @@ func VerifyAuthWithOptions(options *VerifyOptions) func(ctx *context.Context) { } } -// Checks authentication according to options +// VerifyAuthWithOptionsAPI checks authentication according to options func VerifyAuthWithOptionsAPI(options *VerifyOptions) func(ctx *context.APIContext) { return func(ctx *context.APIContext) { // Check prohibit login users. @@ -197,7 +211,9 @@ func VerifyAuthWithOptionsAPI(options *VerifyOptions) func(ctx *context.APIConte return } else if !ctx.Doer.IsActive && setting.Service.RegisterEmailConfirm { ctx.Data["Title"] = ctx.Tr("auth.active_your_account") - ctx.HTML(http.StatusOK, "user/auth/activate") + ctx.JSON(http.StatusForbidden, map[string]string{ + "message": "This account is not activated.", + }) return } if ctx.IsSigned && ctx.IsBasicAuth { diff --git a/services/context/user.go b/services/context/user.go index c713667bca..4e74aa50bd 100644 --- a/services/context/user.go +++ b/services/context/user.go @@ -15,7 +15,7 @@ import ( // UserAssignmentWeb returns a middleware to handle context-user assignment for web routes func UserAssignmentWeb() func(ctx *context.Context) { return func(ctx *context.Context) { - userAssignment(ctx, func(status int, title string, obj interface{}) { + errorFn := func(status int, title string, obj interface{}) { err, ok := obj.(error) if !ok { err = fmt.Errorf("%s", obj) @@ -25,7 +25,8 @@ func UserAssignmentWeb() func(ctx *context.Context) { } else { ctx.ServerError(title, err) } - }) + } + ctx.ContextUser = userAssignment(ctx.Base, ctx.Doer, errorFn) } } @@ -53,18 +54,18 @@ func UserIDAssignmentAPI() func(ctx *context.APIContext) { // UserAssignmentAPI returns a middleware to handle context-user assignment for api routes func UserAssignmentAPI() func(ctx *context.APIContext) { return func(ctx *context.APIContext) { - userAssignment(ctx.Context, ctx.Error) + ctx.ContextUser = userAssignment(ctx.Base, ctx.Doer, ctx.Error) } } -func userAssignment(ctx *context.Context, errCb func(int, string, interface{})) { +func userAssignment(ctx *context.Base, doer *user_model.User, errCb func(int, string, interface{})) (contextUser *user_model.User) { username := ctx.Params(":username") - if ctx.IsSigned && ctx.Doer.LowerName == strings.ToLower(username) { - ctx.ContextUser = ctx.Doer + if doer != nil && doer.LowerName == strings.ToLower(username) { + contextUser = doer } else { var err error - ctx.ContextUser, err = user_model.GetUserByName(ctx, username) + contextUser, err = user_model.GetUserByName(ctx, username) if err != nil { if user_model.IsErrUserNotExist(err) { if redirectUserID, err := user_model.LookupUserRedirect(username); err == nil { @@ -79,4 +80,5 @@ func userAssignment(ctx *context.Context, errCb func(int, string, interface{})) } } } + return contextUser } diff --git a/services/forms/admin.go b/services/forms/admin.go index a749f863f3..4b3cacc606 100644 --- a/services/forms/admin.go +++ b/services/forms/admin.go @@ -27,7 +27,7 @@ type AdminCreateUserForm struct { // Validate validates form fields func (f *AdminCreateUserForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -55,7 +55,7 @@ type AdminEditUserForm struct { // Validate validates form fields func (f *AdminEditUserForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -67,6 +67,6 @@ type AdminDashboardForm struct { // Validate validates form fields func (f *AdminDashboardForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/auth_form.go b/services/forms/auth_form.go index 5625aa1e2e..25acbbb99e 100644 --- a/services/forms/auth_form.go +++ b/services/forms/auth_form.go @@ -86,6 +86,6 @@ type AuthenticationForm struct { // Validate validates fields func (f *AuthenticationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/org.go b/services/forms/org.go index d753531371..c333bead31 100644 --- a/services/forms/org.go +++ b/services/forms/org.go @@ -30,7 +30,7 @@ type CreateOrgForm struct { // Validate validates the fields func (f *CreateOrgForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -48,7 +48,7 @@ type UpdateOrgSettingForm struct { // Validate validates the fields func (f *UpdateOrgSettingForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -70,6 +70,6 @@ type CreateTeamForm struct { // Validate validates the fields func (f *CreateTeamForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/package_form.go b/services/forms/package_form.go index dfec98fff4..cf8abfb8fb 100644 --- a/services/forms/package_form.go +++ b/services/forms/package_form.go @@ -25,6 +25,6 @@ type PackageCleanupRuleForm struct { } func (f *PackageCleanupRuleForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/repo_branch_form.go b/services/forms/repo_branch_form.go index bf1183fc43..5deb0ae463 100644 --- a/services/forms/repo_branch_form.go +++ b/services/forms/repo_branch_form.go @@ -21,7 +21,7 @@ type NewBranchForm struct { // Validate validates the fields func (f *NewBranchForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -33,6 +33,6 @@ type RenameBranchForm struct { // Validate validates the fields func (f *RenameBranchForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/repo_form.go b/services/forms/repo_form.go index d705ecad3f..cacfb64b17 100644 --- a/services/forms/repo_form.go +++ b/services/forms/repo_form.go @@ -54,7 +54,7 @@ type CreateRepoForm struct { // Validate validates the fields func (f *CreateRepoForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -87,7 +87,7 @@ type MigrateRepoForm struct { // Validate validates the fields func (f *MigrateRepoForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -176,7 +176,7 @@ type RepoSettingForm struct { // Validate validates the fields func (f *RepoSettingForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -215,7 +215,7 @@ type ProtectBranchForm struct { // Validate validates the fields func (f *ProtectBranchForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -280,7 +280,7 @@ type NewWebhookForm struct { // Validate validates the fields func (f *NewWebhookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -294,7 +294,7 @@ type NewGogshookForm struct { // Validate validates the fields func (f *NewGogshookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -310,7 +310,7 @@ type NewSlackHookForm struct { // Validate validates the fields func (f *NewSlackHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) if !webhook.IsValidSlackChannel(strings.TrimSpace(f.Channel)) { errs = append(errs, binding.Error{ FieldNames: []string{"Channel"}, @@ -331,7 +331,7 @@ type NewDiscordHookForm struct { // Validate validates the fields func (f *NewDiscordHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -343,7 +343,7 @@ type NewDingtalkHookForm struct { // Validate validates the fields func (f *NewDingtalkHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -356,7 +356,7 @@ type NewTelegramHookForm struct { // Validate validates the fields func (f *NewTelegramHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -370,7 +370,7 @@ type NewMatrixHookForm struct { // Validate validates the fields func (f *NewMatrixHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -382,7 +382,7 @@ type NewMSTeamsHookForm struct { // Validate validates the fields func (f *NewMSTeamsHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -394,7 +394,7 @@ type NewFeishuHookForm struct { // Validate validates the fields func (f *NewFeishuHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -406,7 +406,7 @@ type NewWechatWorkHookForm struct { // Validate validates the fields func (f *NewWechatWorkHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -420,7 +420,7 @@ type NewPackagistHookForm struct { // Validate validates the fields func (f *NewPackagistHookForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -447,7 +447,7 @@ type CreateIssueForm struct { // Validate validates the fields func (f *CreateIssueForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -460,7 +460,7 @@ type CreateCommentForm struct { // Validate validates the fields func (f *CreateCommentForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -471,7 +471,7 @@ type ReactionForm struct { // Validate validates the fields func (f *ReactionForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -482,7 +482,7 @@ type IssueLockForm struct { // Validate validates the fields func (i *IssueLockForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, i, ctx.Locale) } @@ -550,7 +550,7 @@ type CreateMilestoneForm struct { // Validate validates the fields func (f *CreateMilestoneForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -572,7 +572,7 @@ type CreateLabelForm struct { // Validate validates the fields func (f *CreateLabelForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -583,7 +583,7 @@ type InitializeLabelsForm struct { // Validate validates the fields func (f *InitializeLabelsForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -611,7 +611,7 @@ type MergePullRequestForm struct { // Validate validates the fields func (f *MergePullRequestForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -629,7 +629,7 @@ type CodeCommentForm struct { // Validate validates the fields func (f *CodeCommentForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -643,7 +643,7 @@ type SubmitReviewForm struct { // Validate validates the fields func (f *SubmitReviewForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -704,7 +704,7 @@ type NewReleaseForm struct { // Validate validates the fields func (f *NewReleaseForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -719,7 +719,7 @@ type EditReleaseForm struct { // Validate validates the fields func (f *EditReleaseForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -740,7 +740,7 @@ type NewWikiForm struct { // Validate validates the fields // FIXME: use code generation to generate this method. func (f *NewWikiForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -765,7 +765,7 @@ type EditRepoFileForm struct { // Validate validates the fields func (f *EditRepoFileForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -776,7 +776,7 @@ type EditPreviewDiffForm struct { // Validate validates the fields func (f *EditPreviewDiffForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -800,7 +800,7 @@ type CherryPickForm struct { // Validate validates the fields func (f *CherryPickForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -825,7 +825,7 @@ type UploadRepoFileForm struct { // Validate validates the fields func (f *UploadRepoFileForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -836,7 +836,7 @@ type RemoveUploadFileForm struct { // Validate validates the fields func (f *RemoveUploadFileForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -859,7 +859,7 @@ type DeleteRepoFileForm struct { // Validate validates the fields func (f *DeleteRepoFileForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -878,7 +878,7 @@ type AddTimeManuallyForm struct { // Validate validates the fields func (f *AddTimeManuallyForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -894,6 +894,6 @@ type DeadlineForm struct { // Validate validates the fields func (f *DeadlineForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/repo_tag_form.go b/services/forms/repo_tag_form.go index 1209d2346f..4dd99f9e32 100644 --- a/services/forms/repo_tag_form.go +++ b/services/forms/repo_tag_form.go @@ -21,6 +21,6 @@ type ProtectTagForm struct { // Validate validates the fields func (f *ProtectTagForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/runner.go b/services/forms/runner.go index 9063060346..22dea49e31 100644 --- a/services/forms/runner.go +++ b/services/forms/runner.go @@ -20,6 +20,6 @@ type EditRunnerForm struct { // Validate validates form fields func (f *EditRunnerForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/user_form.go b/services/forms/user_form.go index 285bc398b2..fa8129bf85 100644 --- a/services/forms/user_form.go +++ b/services/forms/user_form.go @@ -78,7 +78,7 @@ type InstallForm struct { // Validate validates the fields func (f *InstallForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -99,7 +99,7 @@ type RegisterForm struct { // Validate validates the fields func (f *RegisterForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -148,7 +148,7 @@ type MustChangePasswordForm struct { // Validate validates the fields func (f *MustChangePasswordForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -162,7 +162,7 @@ type SignInForm struct { // Validate validates the fields func (f *SignInForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -182,7 +182,7 @@ type AuthorizationForm struct { // Validate validates the fields func (f *AuthorizationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -197,7 +197,7 @@ type GrantApplicationForm struct { // Validate validates the fields func (f *GrantApplicationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -216,7 +216,7 @@ type AccessTokenForm struct { // Validate validates the fields func (f *AccessTokenForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -227,7 +227,7 @@ type IntrospectTokenForm struct { // Validate validates the fields func (f *IntrospectTokenForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -252,7 +252,7 @@ type UpdateProfileForm struct { // Validate validates the fields func (f *UpdateProfileForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -263,7 +263,7 @@ type UpdateLanguageForm struct { // Validate validates the fields func (f *UpdateLanguageForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -283,7 +283,7 @@ type AvatarForm struct { // Validate validates the fields func (f *AvatarForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -294,7 +294,7 @@ type AddEmailForm struct { // Validate validates the fields func (f *AddEmailForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -305,7 +305,7 @@ type UpdateThemeForm struct { // Validate validates the field func (f *UpdateThemeForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -332,7 +332,7 @@ type ChangePasswordForm struct { // Validate validates the fields func (f *ChangePasswordForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -343,7 +343,7 @@ type AddOpenIDForm struct { // Validate validates the fields func (f *AddOpenIDForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -360,7 +360,7 @@ type AddKeyForm struct { // Validate validates the fields func (f *AddKeyForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -372,7 +372,7 @@ type AddSecretForm struct { // Validate validates the fields func (f *AddSecretForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -384,7 +384,7 @@ type NewAccessTokenForm struct { // Validate validates the fields func (f *NewAccessTokenForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -403,7 +403,7 @@ type EditOAuth2ApplicationForm struct { // Validate validates the fields func (f *EditOAuth2ApplicationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -414,7 +414,7 @@ type TwoFactorAuthForm struct { // Validate validates the fields func (f *TwoFactorAuthForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -425,7 +425,7 @@ type TwoFactorScratchAuthForm struct { // Validate validates the fields func (f *TwoFactorScratchAuthForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -436,7 +436,7 @@ type WebauthnRegistrationForm struct { // Validate validates the fields func (f *WebauthnRegistrationForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -447,7 +447,7 @@ type WebauthnDeleteForm struct { // Validate validates the fields func (f *WebauthnDeleteForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -459,6 +459,6 @@ type PackageSettingForm struct { // Validate validates the fields func (f *PackageSettingForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/forms/user_form_auth_openid.go b/services/forms/user_form_auth_openid.go index f95eb98405..d8137a8d13 100644 --- a/services/forms/user_form_auth_openid.go +++ b/services/forms/user_form_auth_openid.go @@ -20,7 +20,7 @@ type SignInOpenIDForm struct { // Validate validates the fields func (f *SignInOpenIDForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -32,7 +32,7 @@ type SignUpOpenIDForm struct { // Validate validates the fields func (f *SignUpOpenIDForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } @@ -44,6 +44,6 @@ type ConnectOpenIDForm struct { // Validate validates the fields func (f *ConnectOpenIDForm) Validate(req *http.Request, errs binding.Errors) binding.Errors { - ctx := context.GetContext(req) + ctx := context.GetValidateContext(req) return middleware.Validate(errs, ctx.Data, f, ctx.Locale) } diff --git a/services/markup/processorhelper_test.go b/services/markup/processorhelper_test.go index 6c9c1c27e7..2f48e03b22 100644 --- a/services/markup/processorhelper_test.go +++ b/services/markup/processorhelper_test.go @@ -6,6 +6,7 @@ package markup import ( "context" "net/http" + "net/http/httptest" "testing" "code.gitea.io/gitea/models/db" @@ -36,12 +37,12 @@ func TestProcessorHelper(t *testing.T) { assert.False(t, ProcessorHelper().IsUsernameMentionable(context.Background(), userNoSuch)) // when using web context, use user.IsUserVisibleToViewer to check - var err error - giteaCtx := &gitea_context.Context{} - giteaCtx.Req, err = http.NewRequest("GET", "/", nil) + req, err := http.NewRequest("GET", "/", nil) assert.NoError(t, err) + base, baseCleanUp := gitea_context.NewBaseContext(httptest.NewRecorder(), req) + defer baseCleanUp() + giteaCtx := &gitea_context.Context{Base: base} - giteaCtx.Doer = nil assert.True(t, ProcessorHelper().IsUsernameMentionable(giteaCtx, userPublic)) assert.False(t, ProcessorHelper().IsUsernameMentionable(giteaCtx, userPrivate)) |