
compress.go

package middleware

import (
	"bufio"
	"compress/flate"
	"compress/gzip"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"strings"
	"sync"
)

var defaultCompressibleContentTypes = []string{
	"text/html",
	"text/css",
	"text/plain",
	"text/javascript",
	"application/javascript",
	"application/x-javascript",
	"application/json",
	"application/atom+xml",
	"application/rss+xml",
	"image/svg+xml",
}
// Compress is a middleware that compresses the response body of the given
// content types to a data format chosen based on the Accept-Encoding request
// header. It uses the given compression level.
//
// NOTE: make sure to set the Content-Type header on your response, otherwise
// this middleware will not compress the response body. For example, in your
// handler you should set w.Header().Set("Content-Type", http.DetectContentType(yourBody))
// or set it manually.
//
// Passing a compression level of 5 is a sensible value.
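//
// For example, a minimal sketch (assuming a chi router named r and a handler
// that sets its own Content-Type):
//
//  r.Use(middleware.Compress(5, "text/html", "application/json"))
//  r.Get("/", func(w http.ResponseWriter, r *http.Request) {
//    w.Header().Set("Content-Type", "text/html")
//    w.Write([]byte("<html><body>hello</body></html>"))
//  })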
func Compress(level int, types ...string) func(next http.Handler) http.Handler {
	compressor := NewCompressor(level, types...)
	return compressor.Handler
}

// Compressor represents a set of encoding configurations.
type Compressor struct {
	// The mapping of encoder names to encoder functions.
	encoders map[string]EncoderFunc
	// The mapping of pooled encoders to pools.
	pooledEncoders map[string]*sync.Pool
	// The set of content types allowed to be compressed.
	allowedTypes     map[string]struct{}
	allowedWildcards map[string]struct{}
	// The list of encoders in order of decreasing precedence.
	encodingPrecedence []string
	level              int // The compression level.
}
// NewCompressor creates a new Compressor that will handle encoding responses.
//
// The level should be one of the ones defined in the flate package.
// The types are the content types that are allowed to be compressed.
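//
// For example, a sketch that also uses a "/*" wildcard content type (the
// router variable r is illustrative and not part of this package):
//
//  compressor := middleware.NewCompressor(flate.DefaultCompression, "text/html", "application/*")
//  r.Use(compressor.Handler)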
func NewCompressor(level int, types ...string) *Compressor {
	// If types are provided, set those as the allowed types. If none are
	// provided, use the default list.
	allowedTypes := make(map[string]struct{})
	allowedWildcards := make(map[string]struct{})
	if len(types) > 0 {
		for _, t := range types {
			if strings.Contains(strings.TrimSuffix(t, "/*"), "*") {
				panic(fmt.Sprintf("middleware/compress: Unsupported content-type wildcard pattern '%s'. Only '/*' supported", t))
			}
			if strings.HasSuffix(t, "/*") {
				allowedWildcards[strings.TrimSuffix(t, "/*")] = struct{}{}
			} else {
				allowedTypes[t] = struct{}{}
			}
		}
	} else {
		for _, t := range defaultCompressibleContentTypes {
			allowedTypes[t] = struct{}{}
		}
	}

	c := &Compressor{
		level:            level,
		encoders:         make(map[string]EncoderFunc),
		pooledEncoders:   make(map[string]*sync.Pool),
		allowedTypes:     allowedTypes,
		allowedWildcards: allowedWildcards,
	}

	// Set the default encoders. The precedence order uses the reverse
	// ordering in which the encoders were added. This means adding new
	// encoders will move them to the front of the order.
	//
	// TODO:
	// lzma: Opera.
	// sdch: Chrome, Android. Gzip output + dictionary header.
	// br: Brotli, see https://github.com/go-chi/chi/pull/326

	// HTTP 1.1 "deflate" (RFC 2616) stands for DEFLATE data (RFC 1951)
	// wrapped with zlib (RFC 1950). The zlib wrapper uses an Adler-32
	// checksum instead of the CRC-32 used in "gzip" and is thus faster.
	//
	// But.. some old browsers (MSIE, Safari 5.1) incorrectly expect
	// raw DEFLATE data only, without the mentioned zlib wrapper.
	// Because of this major confusion, most modern browsers try it
	// both ways, first looking for zlib headers.
	// Quote by Mark Adler: http://stackoverflow.com/a/9186091/385548
	//
	// The list of browsers having problems is quite big, see:
	// http://zoompf.com/blog/2012/02/lose-the-wait-http-compression
	// https://web.archive.org/web/20120321182910/http://www.vervestudios.co/projects/compression-tests/results
	//
	// That's why we prefer gzip over deflate. It's just more reliable
	// and not significantly slower than deflate.
	c.SetEncoder("deflate", encoderDeflate)

	// TODO: Exception for old MSIE browsers that can't handle non-HTML?
	// https://zoompf.com/blog/2012/02/lose-the-wait-http-compression
	c.SetEncoder("gzip", encoderGzip)

	// NOTE: Not implemented, intentionally:
	// case "compress": // LZW. Deprecated.
	// case "bzip2":    // Too slow on-the-fly.
	// case "zopfli":   // Too slow on-the-fly.
	// case "xz":       // Too slow on-the-fly.

	return c
}
// SetEncoder can be used to set the implementation of a compression algorithm.
//
// The encoding should be a standardised identifier. See:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
//
// For example, add the Brotli algorithm:
//
//  import brotli_enc "gopkg.in/kothar/brotli-go.v0/enc"
//
//  compressor := middleware.NewCompressor(5, "text/html")
//  compressor.SetEncoder("br", func(w io.Writer, level int) io.Writer {
//    params := brotli_enc.NewBrotliParams()
//    params.SetQuality(level)
//    return brotli_enc.NewBrotliWriter(params, w)
//  })
func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
	encoding = strings.ToLower(encoding)
	if encoding == "" {
		panic("the encoding can not be empty")
	}
	if fn == nil {
		panic("attempted to set a nil encoder function")
	}

	// If we are adding a new encoder that is already registered, we have to
	// clear that one out first.
	if _, ok := c.pooledEncoders[encoding]; ok {
		delete(c.pooledEncoders, encoding)
	}
	if _, ok := c.encoders[encoding]; ok {
		delete(c.encoders, encoding)
	}

	// If the encoder supports resetting (ioResetterWriter), it can be pooled.
	encoder := fn(ioutil.Discard, c.level)
	if encoder != nil {
		if _, ok := encoder.(ioResetterWriter); ok {
			pool := &sync.Pool{
				New: func() interface{} {
					return fn(ioutil.Discard, c.level)
				},
			}
			c.pooledEncoders[encoding] = pool
		}
	}
	// If the encoder is not in the pooledEncoders, add it to the normal encoders.
	if _, ok := c.pooledEncoders[encoding]; !ok {
		c.encoders[encoding] = fn
	}

	for i, v := range c.encodingPrecedence {
		if v == encoding {
			c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...)
		}
	}

	c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...)
}
// Handler returns a new middleware that will compress the response based on
// the current Compressor.
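//
// For example, a sketch that wraps a plain http.Handler directly, without a
// router (mux is an assumed *http.ServeMux):
//
//  compressor := middleware.NewCompressor(5)
//  http.ListenAndServe(":3000", compressor.Handler(mux))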
func (c *Compressor) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		encoder, encoding, cleanup := c.selectEncoder(r.Header, w)

		cw := &compressResponseWriter{
			ResponseWriter:   w,
			w:                w,
			contentTypes:     c.allowedTypes,
			contentWildcards: c.allowedWildcards,
			encoding:         encoding,
			compressable:     false, // determined in post-handler
		}
		if encoder != nil {
			cw.w = encoder
		}
		// Re-add the encoder to the pool if applicable.
		defer cleanup()
		defer cw.Close()

		next.ServeHTTP(cw, r)
	})
}
// selectEncoder returns the encoder, the name of the encoder, and a cleanup function.
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) {
	header := h.Get("Accept-Encoding")

	// Parse the names of all accepted algorithms from the header.
	accepted := strings.Split(strings.ToLower(header), ",")

	// Find the first supported encoder from the accepted list, in order of
	// precedence.
	for _, name := range c.encodingPrecedence {
		if matchAcceptEncoding(accepted, name) {
			if pool, ok := c.pooledEncoders[name]; ok {
				encoder := pool.Get().(ioResetterWriter)
				cleanup := func() {
					pool.Put(encoder)
				}
				encoder.Reset(w)
				return encoder, name, cleanup
			}
			if fn, ok := c.encoders[name]; ok {
				return fn(w, c.level), name, func() {}
			}
		}
	}

	// No encoder found to match the accepted encoding.
	return nil, "", func() {}
}

func matchAcceptEncoding(accepted []string, encoding string) bool {
	for _, v := range accepted {
		if strings.Contains(v, encoding) {
			return true
		}
	}
	return false
}
// An EncoderFunc is a function that wraps the provided io.Writer with a
// streaming compression algorithm and returns it.
//
// In case of failure, the function should return nil.
type EncoderFunc func(w io.Writer, level int) io.Writer

// Interface for types that allow resetting io.Writers.
type ioResetterWriter interface {
	io.Writer
	Reset(w io.Writer)
}
type compressResponseWriter struct {
	http.ResponseWriter

	// The streaming encoder writer to be used if there is one. Otherwise,
	// this is just the normal writer.
	w                io.Writer
	contentTypes     map[string]struct{}
	contentWildcards map[string]struct{}
	encoding         string
	wroteHeader      bool
	compressable     bool
}

func (cw *compressResponseWriter) isCompressable() bool {
	// Parse the first part of the Content-Type response header.
	contentType := cw.Header().Get("Content-Type")
	if idx := strings.Index(contentType, ";"); idx >= 0 {
		contentType = contentType[0:idx]
	}

	// Is the content type compressable?
	if _, ok := cw.contentTypes[contentType]; ok {
		return true
	}
	if idx := strings.Index(contentType, "/"); idx > 0 {
		contentType = contentType[0:idx]
		_, ok := cw.contentWildcards[contentType]
		return ok
	}
	return false
}

func (cw *compressResponseWriter) WriteHeader(code int) {
	if cw.wroteHeader {
		cw.ResponseWriter.WriteHeader(code) // Allow multiple calls to propagate.
		return
	}
	cw.wroteHeader = true
	defer cw.ResponseWriter.WriteHeader(code)

	// Already compressed data?
	if cw.Header().Get("Content-Encoding") != "" {
		return
	}

	if !cw.isCompressable() {
		cw.compressable = false
		return
	}

	if cw.encoding != "" {
		cw.compressable = true
		cw.Header().Set("Content-Encoding", cw.encoding)
		cw.Header().Set("Vary", "Accept-Encoding")

		// The content-length after compression is unknown.
		cw.Header().Del("Content-Length")
	}
}

func (cw *compressResponseWriter) Write(p []byte) (int, error) {
	if !cw.wroteHeader {
		cw.WriteHeader(http.StatusOK)
	}
	return cw.writer().Write(p)
}

func (cw *compressResponseWriter) writer() io.Writer {
	if cw.compressable {
		return cw.w
	}
	return cw.ResponseWriter
}

type compressFlusher interface {
	Flush() error
}
func (cw *compressResponseWriter) Flush() {
	if f, ok := cw.writer().(http.Flusher); ok {
		f.Flush()
	}
	// If the underlying writer has a compression flush signature
	// (Flush() error), call that Flush() method instead.
	if f, ok := cw.writer().(compressFlusher); ok {
		f.Flush()

		// Also flush the underlying response writer.
		if f, ok := cw.ResponseWriter.(http.Flusher); ok {
			f.Flush()
		}
	}
}
func (cw *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hj, ok := cw.writer().(http.Hijacker); ok {
		return hj.Hijack()
	}
	return nil, nil, errors.New("chi/middleware: http.Hijacker is unavailable on the writer")
}

func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) error {
	if ps, ok := cw.writer().(http.Pusher); ok {
		return ps.Push(target, opts)
	}
	return errors.New("chi/middleware: http.Pusher is unavailable on the writer")
}

func (cw *compressResponseWriter) Close() error {
	if c, ok := cw.writer().(io.WriteCloser); ok {
		return c.Close()
	}
	return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer")
}

func encoderGzip(w io.Writer, level int) io.Writer {
	gw, err := gzip.NewWriterLevel(w, level)
	if err != nil {
		return nil
	}
	return gw
}

func encoderDeflate(w io.Writer, level int) io.Writer {
	dw, err := flate.NewWriter(w, level)
	if err != nil {
		return nil
	}
	return dw
}