You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

route_headers.go 4.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package middleware
  2. import (
  3. "net/http"
  4. "strings"
  5. )
  6. // RouteHeaders is a neat little header-based router that allows you to direct
  7. // the flow of a request through a middleware stack based on a request header.
  8. //
  9. // For example, lets say you'd like to setup multiple routers depending on the
  10. // request Host header, you could then do something as so:
  11. //
  12. // r := chi.NewRouter()
  13. // rSubdomain := chi.NewRouter()
  14. //
  15. // r.Use(middleware.RouteHeaders().
  16. // Route("Host", "example.com", middleware.New(r)).
  17. // Route("Host", "*.example.com", middleware.New(rSubdomain)).
  18. // Handler)
  19. //
  20. // r.Get("/", h)
  21. // rSubdomain.Get("/", h2)
  22. //
  23. //
  24. // Another example, imagine you want to setup multiple CORS handlers, where for
  25. // your origin servers you allow authorized requests, but for third-party public
  26. // requests, authorization is disabled.
  27. //
  28. // r := chi.NewRouter()
  29. //
  30. // r.Use(middleware.RouteHeaders().
  31. // Route("Origin", "https://app.skyweaver.net", cors.Handler(cors.Options{
  32. // AllowedOrigins: []string{"https://api.skyweaver.net"},
  33. // AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
  34. // AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"},
  35. // AllowCredentials: true, // <----------<<< allow credentials
  36. // })).
  37. // Route("Origin", "*", cors.Handler(cors.Options{
  38. // AllowedOrigins: []string{"*"},
  39. // AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
  40. // AllowedHeaders: []string{"Accept", "Content-Type"},
  41. // AllowCredentials: false, // <----------<<< do not allow credentials
  42. // })).
  43. // Handler)
  44. //
  45. func RouteHeaders() HeaderRouter {
  46. return HeaderRouter{}
  47. }
  48. type HeaderRouter map[string][]HeaderRoute
  49. func (hr HeaderRouter) Route(header, match string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
  50. header = strings.ToLower(header)
  51. k := hr[header]
  52. if k == nil {
  53. hr[header] = []HeaderRoute{}
  54. }
  55. hr[header] = append(hr[header], HeaderRoute{MatchOne: NewPattern(match), Middleware: middlewareHandler})
  56. return hr
  57. }
  58. func (hr HeaderRouter) RouteAny(header string, match []string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
  59. header = strings.ToLower(header)
  60. k := hr[header]
  61. if k == nil {
  62. hr[header] = []HeaderRoute{}
  63. }
  64. patterns := []Pattern{}
  65. for _, m := range match {
  66. patterns = append(patterns, NewPattern(m))
  67. }
  68. hr[header] = append(hr[header], HeaderRoute{MatchAny: patterns, Middleware: middlewareHandler})
  69. return hr
  70. }
  71. func (hr HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) HeaderRouter {
  72. hr["*"] = []HeaderRoute{{Middleware: handler}}
  73. return hr
  74. }
  75. func (hr HeaderRouter) Handler(next http.Handler) http.Handler {
  76. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  77. if len(hr) == 0 {
  78. // skip if no routes set
  79. next.ServeHTTP(w, r)
  80. }
  81. // find first matching header route, and continue
  82. for header, matchers := range hr {
  83. headerValue := r.Header.Get(header)
  84. if headerValue == "" {
  85. continue
  86. }
  87. headerValue = strings.ToLower(headerValue)
  88. for _, matcher := range matchers {
  89. if matcher.IsMatch(headerValue) {
  90. matcher.Middleware(next).ServeHTTP(w, r)
  91. return
  92. }
  93. }
  94. }
  95. // if no match, check for "*" default route
  96. matcher, ok := hr["*"]
  97. if !ok || matcher[0].Middleware == nil {
  98. next.ServeHTTP(w, r)
  99. return
  100. }
  101. matcher[0].Middleware(next).ServeHTTP(w, r)
  102. })
  103. }
  104. type HeaderRoute struct {
  105. Middleware func(next http.Handler) http.Handler
  106. MatchOne Pattern
  107. MatchAny []Pattern
  108. }
  109. func (r HeaderRoute) IsMatch(value string) bool {
  110. if len(r.MatchAny) > 0 {
  111. for _, m := range r.MatchAny {
  112. if m.Match(value) {
  113. return true
  114. }
  115. }
  116. } else if r.MatchOne.Match(value) {
  117. return true
  118. }
  119. return false
  120. }
  // Pattern is a string pattern containing at most one '*' wildcard.
  //
  // Without a wildcard it matches a value exactly. With one, the value must
  // start with prefix and end with suffix; the wildcard covers whatever lies
  // in between, possibly nothing. Matching is case-sensitive.
  type Pattern struct {
  	prefix   string
  	suffix   string
  	wildcard bool
  }

  // NewPattern compiles value into a Pattern. Only the first '*' is treated as
  // the wildcard: everything before it becomes the prefix, and everything after
  // it becomes the suffix. That includes any further '*' characters, which are
  // matched literally.
  func NewPattern(value string) Pattern {
  	i := strings.IndexByte(value, '*')
  	if i < 0 {
  		return Pattern{prefix: value}
  	}
  	return Pattern{prefix: value[:i], suffix: value[i+1:], wildcard: true}
  }

  // Match reports whether v matches the pattern.
  func (p Pattern) Match(v string) bool {
  	if !p.wildcard {
  		return p.prefix == v
  	}
  	// v must be long enough to hold prefix and suffix without them
  	// overlapping (e.g. "ab*ba" must not match "aba"). Summing the lengths
  	// avoids building a temporary concatenated string on every call.
  	return len(v) >= len(p.prefix)+len(p.suffix) &&
  		strings.HasPrefix(v, p.prefix) &&
  		strings.HasSuffix(v, p.suffix)
  }