aboutsummaryrefslogtreecommitdiffstats
path: root/modules/web/router_path.go
blob: b59948581a94eb2b6726bc2e8b921704b37b23a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package web

import (
	"fmt"
	"net/http"
	"regexp"
	"strings"

	"code.gitea.io/gitea/modules/container"
	"code.gitea.io/gitea/modules/util"

	"github.com/go-chi/chi/v5"
)

// RouterPathGroup dispatches requests to a list of regexp-based path matchers registered with
// MatchPath. The path to match is read from the chi URL parameter named by pathParam.
type RouterPathGroup struct {
	// r is the owning router; its chi NotFound handler responds when no matcher accepts a request.
	r *Router
	// pathParam is the name of the chi URL parameter holding the path that matchers are run against.
	pathParam string
	// matchers are tried in registration order; the first one that matches handles the request.
	matchers []*routerPathMatcher
}

// ServeHTTP implements http.Handler. It reads the path to match from the chi URL parameter named
// g.pathParam and hands the request to the first matcher (in registration order) that accepts both
// the request method and the path, running that matcher's middlewares around its handler.
// If no matcher accepts the request, the router's chi NotFound handler responds instead.
func (g *RouterPathGroup) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
	chiCtx := chi.RouteContext(req.Context())
	path := chiCtx.URLParam(g.pathParam)

	var matched *routerPathMatcher
	for _, candidate := range g.matchers {
		if candidate.matchPath(chiCtx, path) {
			matched = candidate
			break
		}
	}
	if matched == nil {
		g.r.chiRouter.NotFoundHandler().ServeHTTP(resp, req)
		return
	}

	// Wrap from the last middleware to the first so that middlewares[0] ends up as the
	// outermost wrapper around the handler.
	chain := matched.handlerFunc
	for idx := len(matched.middlewares) - 1; idx >= 0; idx-- {
		chain = matched.middlewares[idx](chain).ServeHTTP
	}
	chain(resp, req)
}

// MatchPath registers a handler for requests whose method is one of methods (a comma-separated
// list such as "GET" or "GET,HEAD") and whose path matches pattern. The pattern is turned into a
// regexp anchored at both ends, and "<...>" placeholders define path parameters (unlike chi's "{...}"):
//   - "<name>" matches one non-empty path segment ("[^/]+"),
//   - "<name:*>" matches any, possibly empty, remainder including slashes; a "/" directly after it
//     in the pattern becomes optional,
//   - "<name:regexp>" matches the given regexp.
//
// Matchers are tried in registration order and the first one that matches wins.
// It is only meant for special cases that the chi router can't handle: every request has to be
// checked against the rules one by one, which is inefficient, so prefer normal chi routes otherwise.
// It panics if methods contains an unknown HTTP method or the pattern is malformed.
func (g *RouterPathGroup) MatchPath(methods, pattern string, h ...any) {
	matcher := newRouterPathMatcher(methods, pattern, h...)
	g.matchers = append(g.matchers, matcher)
}

// routerPathParam describes one "<name[:expr]>" placeholder of a MatchPath pattern.
type routerPathParam struct {
	// name is the chi URL parameter name the matched value is stored under.
	name string
	// captureGroup selects the capture group, relative to the placeholder's own group, that provides
	// the value: 0 is the whole placeholder, k > 0 is the k-th group nested inside its expression.
	// Only 0 is produced for now (see the TODO in newRouterPathMatcher).
	captureGroup int
	// groupIndex is the absolute index of the placeholder's own capture group in routerPathMatcher.re.
	// It is resolved from the compiled regexp, never guessed from match offsets.
	groupIndex int
}

// routerPathMatcher is a single MatchPath rule: the accepted methods, the compiled path regexp,
// the placeholders it captures, and the middlewares/handler to run when it matches.
type routerPathMatcher struct {
	methods     container.Set[string]
	re          *regexp.Regexp
	params      []routerPathParam
	middlewares []func(http.Handler) http.Handler
	handlerFunc http.HandlerFunc
}

// routerPathParamGroupName returns the synthetic regexp group name used for the idx-th placeholder.
func routerPathParamGroupName(idx int) string {
	return fmt.Sprintf("routerPathParam%d", idx)
}

// matchPath reports whether the request method (chiCtx.RouteMethod) is accepted and path matches
// the rule's regexp. A path without a leading "/" is matched as if it had one. On success the value
// of every placeholder is appended to chiCtx.URLParams under the placeholder's name; on failure
// chiCtx is left untouched.
func (p *routerPathMatcher) matchPath(chiCtx *chi.Context, path string) bool {
	if !p.methods.Contains(chiCtx.RouteMethod) {
		return false
	}
	if !strings.HasPrefix(path, "/") {
		path = "/" + path
	}
	// Go regexp match indexes come in pairs: [matchStart, matchEnd, group1Start, group1End, ...]
	pathMatches := p.re.FindStringSubmatchIndex(path)
	if pathMatches == nil {
		return false
	}
	for _, param := range p.params {
		// The group index was resolved from the compiled regexp. Capture groups in the literal text,
		// and empty groups nested inside a placeholder, therefore can't shift values onto the wrong
		// parameter.
		groupIdx := (param.groupIndex + param.captureGroup) * 2
		value := ""
		// A group that did not take part in the match (e.g. inside an optional part) reports -1.
		if start, end := pathMatches[groupIdx], pathMatches[groupIdx+1]; start >= 0 {
			value = path[start:end]
		}
		chiCtx.URLParams.Add(param.name, value)
	}
	return true
}

// isValidMethod reports whether name is one of the standard HTTP method names defined in net/http.
func isValidMethod(name string) bool {
	switch name {
	case http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, http.MethodHead, http.MethodOptions, http.MethodConnect, http.MethodTrace:
		return true
	}
	return false
}

// newRouterPathMatcher builds the matcher behind MatchPath. methods is a comma-separated list of
// HTTP methods. pattern is converted into an anchored regexp in which every "<name[:expr]>"
// placeholder becomes a capture group:
//   - "<name>" uses "[^/]+" (one non-empty path segment);
//   - "<name:*>" uses a lazy ".*?" followed by an optional "/", and a "/" that directly follows
//     the placeholder in the pattern is consumed by it;
//   - "<name:expr>" uses expr as-is.
//
// Text outside placeholders is copied into the regexp verbatim, so regexp metacharacters there keep
// their regexp meaning. A placeholder ends at the first ">", so expr can't contain ">".
// It panics on an unknown HTTP method, an unterminated placeholder, or an invalid regexp.
func newRouterPathMatcher(methods, pattern string, h ...any) *routerPathMatcher {
	middlewares, handlerFunc := wrapMiddlewareAndHandler(nil, h)
	p := &routerPathMatcher{methods: make(container.Set[string]), middlewares: middlewares, handlerFunc: handlerFunc}
	for _, method := range strings.Split(methods, ",") {
		method = strings.TrimSpace(method)
		if !isValidMethod(method) {
			panic(fmt.Sprintf("invalid HTTP method: %s", method))
		}
		p.methods.Add(method)
	}
	re := []byte{'^'}
	lastEnd := 0
	for lastEnd < len(pattern) {
		start := strings.IndexByte(pattern[lastEnd:], '<')
		if start == -1 {
			re = append(re, pattern[lastEnd:]...)
			break
		}
		end := strings.IndexByte(pattern[lastEnd+start:], '>')
		if end == -1 {
			panic(fmt.Sprintf("invalid pattern: %s", pattern))
		}
		re = append(re, pattern[lastEnd:lastEnd+start]...)
		partName, partExp, _ := strings.Cut(pattern[lastEnd+start+1:lastEnd+start+end], ":")
		lastEnd += start + end + 1

		// Give the placeholder's group a synthetic name so its absolute index can be looked up
		// after compiling. The index must not be inferred from match offsets.
		re = append(re, "(?P<"...)
		re = append(re, routerPathParamGroupName(len(p.params))...)
		re = append(re, '>')
		if partExp == "*" {
			re = append(re, ".*?)/?"...)
			if lastEnd < len(pattern) && pattern[lastEnd] == '/' {
				lastEnd++ // the "*" pattern is able to handle the last slash, so skip it
			}
		} else {
			re = append(re, util.IfZero(partExp, "[^/]+")...)
			re = append(re, ')')
		}
		// TODO: it could support to specify a "capture group" for the name, for example: "/<name[2]:(\d)-(\d)>"
		// it is not used so no need to implement it now
		p.params = append(p.params, routerPathParam{name: partName})
	}
	re = append(re, '$')
	p.re = regexp.MustCompile(string(re))
	for i := range p.params {
		p.params[i].groupIndex = p.re.SubexpIndex(routerPathParamGroupName(i))
	}
	return p
}