From 6e6ae1eefd1564c50f68f0727ba307a784f23a89 Mon Sep 17 00:00:00 2001 From: Yuanji Date: Fri, 4 Feb 2022 11:04:26 +0900 Subject: [PATCH] feat(middleware): add Deflate middleware --- middleware/compress.go | 97 ++++++++++++++---- middleware/compress_test.go | 197 ++++++++++++++++++++++++++++++++++++ 2 files changed, 272 insertions(+), 22 deletions(-) diff --git a/middleware/compress.go b/middleware/compress.go index ac6672e9d..ad7ec9dfd 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -3,6 +3,7 @@ package middleware import ( "bufio" "compress/gzip" + "compress/zlib" "io" "io/ioutil" "net" @@ -14,33 +15,50 @@ import ( ) type ( - // GzipConfig defines the config for Gzip middleware. - GzipConfig struct { + compressConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // Gzip compression level. + // Compression level. // Optional. Default value -1. Level int `yaml:"level"` } - gzipResponseWriter struct { + // GzipConfig defines the config for Gzip middleware. + GzipConfig compressConfig + // DeflateConfig defines the config for Deflate middleware. + DeflateConfig compressConfig + + compressResponseWriter struct { io.Writer http.ResponseWriter wroteBody bool } + + resetWriteCloser interface { + Reset(w io.Writer) + io.WriteCloser + } + + flusher interface { + Flush() error + } ) const ( - gzipScheme = "gzip" + gzipScheme = "gzip" + deflateScheme = "deflate" ) var ( - // DefaultGzipConfig is the default Gzip middleware config. - DefaultGzipConfig = GzipConfig{ + defaultConfig = compressConfig{ Skipper: DefaultSkipper, Level: -1, } + // DefaultGzipConfig is the default Gzip middleware config. + DefaultGzipConfig = GzipConfig(defaultConfig) + // DefaultDeflateConfig is the default Deflate middleware config. + DefaultDeflateConfig = DeflateConfig(defaultConfig) ) // Gzip returns a middleware which compresses HTTP response using gzip compression @@ -49,18 +67,41 @@ func Gzip() echo.MiddlewareFunc { return GzipWithConfig(DefaultGzipConfig) } +// Deflate returns a middleware which compresses HTTP response using deflate(zlib) compression +func Deflate() echo.MiddlewareFunc { + return DeflateWithConfig(DefaultDeflateConfig) +} + // GzipWithConfig return Gzip middleware with config. // See: `Gzip()`. func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { + return compressWithConfig(compressConfig(config), gzipScheme) +} + +// DeflateWithConfig return Deflate middleware with config. +// See: `Deflate()`. +func DeflateWithConfig(config DeflateConfig) echo.MiddlewareFunc { + return compressWithConfig(compressConfig(config), deflateScheme) +} + +func compressWithConfig(config compressConfig, encoding string) echo.MiddlewareFunc { // Defaults if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = defaultConfig.Skipper } if config.Level == 0 { - config.Level = DefaultGzipConfig.Level + config.Level = defaultConfig.Level } - pool := gzipCompressPool(config) + var pool sync.Pool + switch encoding { + case gzipScheme: + pool = gzipCompressPool(config) + case deflateScheme: + pool = deflateCompressPool(config) + default: + panic("echo: either gzip or deflate is currently supported") + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -70,19 +111,19 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { res := c.Response() res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) - if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { - res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), encoding) { + res.Header().Set(echo.HeaderContentEncoding, encoding) // Issue #806 i := pool.Get() - w, ok := i.(*gzip.Writer) + w, ok := i.(resetWriteCloser) if !ok { return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) } rw := res.Writer w.Reset(rw) - grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} + grw := &compressResponseWriter{Writer: w, ResponseWriter: rw} defer func() { if !grw.wroteBody { - if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { + if res.Header().Get(echo.HeaderContentEncoding) == encoding { res.Header().Del(echo.HeaderContentEncoding) } // We have to reset response to it's pristine state when @@ -101,12 +142,12 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { } } -func (w *gzipResponseWriter) WriteHeader(code int) { +func (w *compressResponseWriter) WriteHeader(code int) { w.Header().Del(echo.HeaderContentLength) // Issue #444 w.ResponseWriter.WriteHeader(code) } -func (w *gzipResponseWriter) Write(b []byte) (int, error) { +func (w *compressResponseWriter) Write(b []byte) (int, error) { if w.Header().Get(echo.HeaderContentType) == "" { w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) } @@ -114,25 +155,25 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { return w.Writer.Write(b) } -func (w *gzipResponseWriter) Flush() { - w.Writer.(*gzip.Writer).Flush() +func (w *compressResponseWriter) Flush() { + w.Writer.(flusher).Flush() if flusher, ok := w.ResponseWriter.(http.Flusher); ok { flusher.Flush() } } -func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (w *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return w.ResponseWriter.(http.Hijacker).Hijack() } -func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { +func (w *compressResponseWriter) Push(target string, opts *http.PushOptions) error { if p, ok := w.ResponseWriter.(http.Pusher); ok { return p.Push(target, opts) } return http.ErrNotSupported } -func gzipCompressPool(config GzipConfig) sync.Pool { +func gzipCompressPool(config compressConfig) sync.Pool { return sync.Pool{ New: func() interface{} { w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) @@ -143,3 +184,15 @@ func gzipCompressPool(config GzipConfig) sync.Pool { }, } } + +func deflateCompressPool(config compressConfig) sync.Pool { + return sync.Pool{ + New: func() interface{} { + w, err := zlib.NewWriterLevel(ioutil.Discard, config.Level) + if err != nil { + return err + } + return w + }, + } +} diff --git a/middleware/compress_test.go b/middleware/compress_test.go index b62bffef5..6f34b3162 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -3,6 +3,7 @@ package middleware import ( "bytes" "compress/gzip" + "compress/zlib" "io" "io/ioutil" "net/http" @@ -205,3 +206,199 @@ func BenchmarkGzip(b *testing.B) { h(c) } } + +func TestDeflate(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Skip if no Accept-Encoding header + h := Deflate()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + h(c) + + assert := assert.New(t) + + assert.Equal("test", rec.Body.String()) + + // Deflate + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, deflateScheme) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) + assert.Equal(deflateScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) + r, err := zlib.NewReader(rec.Body) + if assert.NoError(err) { + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal("test", buf.String()) + } + + chunkBuf := make([]byte, 5) + + // Deflate chunked + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, deflateScheme) + rec = httptest.NewRecorder() + + c = e.NewContext(req, rec) + Deflate()(func(c echo.Context) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Transfer-Encoding", "chunked") + + // Write and flush the first part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + // Read the first part of the data + assert.True(rec.Flushed) + assert.Equal(deflateScheme, rec.Header().Get(echo.HeaderContentEncoding)) + + // See also https://github.com/golang/go/issues/26535#issuecomment-759649380 + rr, _ := r.(zlib.Resetter) + rr.Reset(rec.Body, nil) + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(err) + assert.Equal("test\n", string(chunkBuf)) + + // Write and flush the second part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(err) + assert.Equal("test\n", string(chunkBuf)) + + // Write the final part of the data and return + c.Response().Write([]byte("test")) + return nil + })(c) + + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal("test", buf.String()) +} + +func TestDeflateNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, deflateScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Deflate()(func(c echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + +func TestDeflateEmpty(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, deflateScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Deflate()(func(c echo.Context) error { + return c.String(http.StatusOK, "") + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, deflateScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType)) + r, err := zlib.NewReader(rec.Body) + if assert.NoError(t, err) { + var buf bytes.Buffer + buf.ReadFrom(r) + assert.Equal(t, "", buf.String()) + } + } +} + +func TestDeflateErrorReturned(t *testing.T) { + e := echo.New() + e.Use(Deflate()) + e.GET("/", func(c echo.Context) error { + return echo.ErrNotFound + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, deflateScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestDeflateErrorReturnedInvalidConfig(t *testing.T) { + e := echo.New() + // Invalid level + e.Use(DeflateWithConfig(DeflateConfig{Level: 12})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, deflateScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "zlib") +} + +// Issue #806 +func TestDeflateWithStatic(t *testing.T) { + e := echo.New() + e.Use(Deflate()) + e.Static("/test", "../_fixture/images") + req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) + req.Header.Set(echo.HeaderAcceptEncoding, deflateScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + // Data is written out in chunks when Content-Length == "", so only + // validate the content length if it's not set. + if cl := rec.Header().Get("Content-Length"); cl != "" { + assert.Equal(t, cl, rec.Body.Len()) + } + r, err := zlib.NewReader(rec.Body) + if assert.NoError(t, err) { + defer r.Close() + want, err := ioutil.ReadFile("../_fixture/images/walle.png") + if assert.NoError(t, err) { + buf := new(bytes.Buffer) + buf.ReadFrom(r) + assert.Equal(t, want, buf.Bytes()) + } + } +} + +func BenchmarkDeflate(b *testing.B) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, deflateScheme) + + h := Deflate()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Deflate + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h(c) + } +}