From c13dc090229d7c4d4bd6932e17b8aa16398e00f8 Mon Sep 17 00:00:00 2001 From: cce <51567+cce@users.noreply.github.com> Date: Thu, 6 Mar 2025 13:17:28 -0500 Subject: [PATCH 1/4] Add NewWriterLevelDictWindowSize and NewReaderDictMaxWindowSize to limit memory usage during compression and decompression. More details available in RFC 9659. - NewWriterLevelDictWindowSize added to set specific window size (in bytes) - NewReaderDictMaxWindowSize added to set a maximum window size (in bytes) that this decompressor will allow. --- README.md | 2 + zstd_stream.go | 41 +++++++++++++++- zstd_stream_test.go | 113 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c1db495..8bd45b2 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,7 @@ Decompress(dst, src []byte) ([]byte, error) NewWriter(w io.Writer) *Writer NewWriterLevel(w io.Writer, level int) *Writer NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer +NewWriterLevelDictWindowSize(w io.Writer, level int, dict []byte, windowSize int) *Writer // Write compresses the input data and write it to the underlying writer (w *Writer) Write(p []byte) (int, error) @@ -89,6 +90,7 @@ NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer // to call Close, which frees up C objects. NewReader(r io.Reader) io.ReadCloser NewReaderDict(r io.Reader, dict []byte) io.ReadCloser +NewReaderDictMaxWindowSize(r io.Reader, dict []byte, maxWindowSize int) io.ReadCloser ``` ### Benchmarks (benchmarked with v0.5.0) diff --git a/zstd_stream.go b/zstd_stream.go index 714ecfe..042f7c6 100644 --- a/zstd_stream.go +++ b/zstd_stream.go @@ -1,6 +1,7 @@ package zstd /* +#define ZSTD_STATIC_LINKING_ONLY 1 #include "zstd.h" typedef struct compressStream2_result_s { @@ -65,6 +66,7 @@ import ( "errors" "fmt" "io" + "math/bits" "runtime" "sync" "unsafe" @@ -117,11 +119,22 @@ func NewWriterLevel(w io.Writer, level int) *Writer { // compress with. 
If the dictionary is empty or nil it is ignored. The dictionary // should not be modified until the writer is closed. func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer { + return NewWriterLevelDictWindowSize(w, level, dict, 0) +} + +// NewWriterLevelDictWindowSize is like NewWriterLevelDict but allows configuring +// the window size. windowSize is specified in bytes and will be converted to a windowLog +// parameter (log2 of the window size). If windowSize is 0, the default window size is used. +// The windowSize must be a power of 2 between 1KB and 8MB on 32-bit platforms +// or 1KB and 2GB on 64-bit platforms. +// A larger window size allows for better compression ratios for repetitive data +// but requires more memory during compression and decompression. +func NewWriterLevelDictWindowSize(w io.Writer, level int, dict []byte, windowSize int) *Writer { var err error ctx := C.ZSTD_createCStream() // Load dictionnary if any - if dict != nil { + if len(dict) > 0 { err = getError(int(C.ZSTD_CCtx_loadDictionary(ctx, unsafe.Pointer(&dict[0]), C.size_t(len(dict)), @@ -133,6 +146,17 @@ func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer { err = getError(int(C.ZSTD_CCtx_setParameter(ctx, C.ZSTD_c_compressionLevel, C.int(level)))) } + if err == nil && windowSize > 0 { + // Only set windowLog if windowSize is a power of 2 + if windowSize&(windowSize-1) == 0 { + // Convert windowSize to windowLog using bits.TrailingZeros + windowLog := bits.TrailingZeros(uint(windowSize)) + err = getError(int(C.ZSTD_CCtx_setParameter(ctx, C.ZSTD_c_windowLog, C.int(windowLog)))) + } else { + err = fmt.Errorf("window size must be a power of 2") + } + } + return &Writer{ CompressionLevel: level, ctx: ctx, @@ -373,8 +397,18 @@ func NewReader(r io.Reader) io.ReadCloser { // NewReaderDict is like NewReader but uses a preset dictionary. NewReaderDict // ignores the dictionary if it is nil. 
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { + return NewReaderDictMaxWindowSize(r, dict, 0) +} + +// NewReaderDictMaxWindowSize is like NewReaderDict but allows configuring the maximum +// window size for decompression. maxWindowSize is specified in bytes, not as a log value. +// If maxWindowSize is 0, the default window size limit is used. +// Setting a maximum window size protects against allocating too much memory for +// decompression (potential attack scenario) when processing untrusted inputs. +func NewReaderDictMaxWindowSize(r io.Reader, dict []byte, maxWindowSize int) io.ReadCloser { var err error ctx := C.ZSTD_createDStream() + if len(dict) == 0 { err = getError(int(C.ZSTD_initDStream(ctx))) } else { @@ -387,6 +421,11 @@ func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { C.size_t(len(dict))))) } } + + if err == nil && maxWindowSize > 0 { + err = getError(int(C.ZSTD_DCtx_setMaxWindowSize(ctx, C.size_t(maxWindowSize)))) + } + compressionBufferP := cPool.Get().(*[]byte) decompressionBufferP := dPool.Get().(*[]byte) return &reader{ diff --git a/zstd_stream_test.go b/zstd_stream_test.go index d0960f4..3aec74d 100644 --- a/zstd_stream_test.go +++ b/zstd_stream_test.go @@ -414,6 +414,119 @@ func TestStreamSetNbWorkers(t *testing.T) { testCompressionDecompression(t, nil, []byte(s), nbWorkers) } +func TestStreamWindowSize(t *testing.T) { + dict := []byte("dictdata_for_compression_test") + data := []byte("hello world") + testCases := []struct { + name string + dict []byte + }{ + {"NilDict", nil}, + {"ValidDict", dict}, + {"EmptyDict", []byte{}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test with valid window size (power of 2) + t.Run("ValidWindowSize", func(t *testing.T) { + var buf bytes.Buffer + w := NewWriterLevelDictWindowSize(&buf, DefaultCompression, tc.dict, 1<<17) // 128 KB + + _, err := w.Write(data) + failOnError(t, "Write error", err) + failOnError(t, "Close error", w.Close()) + + // Test 
decoding + r := NewReader(&buf) + decompressed, err := ioutil.ReadAll(r) + failOnError(t, "ReadAll error", err) + if !bytes.Equal(decompressed, data) { + t.Fatalf("got %q; want %q", decompressed, data) + } + failOnError(t, "Reader close error", r.Close()) + }) + + // Test with invalid window size (not a power of 2) + t.Run("InvalidWindowSize", func(t *testing.T) { + var buf bytes.Buffer + w := NewWriterLevelDictWindowSize(&buf, DefaultCompression, tc.dict, 123456) + _, err := w.Write(data) + if err == nil { + t.Fatal("Expected error for invalid window size, got nil") + } + if !strings.Contains(err.Error(), "window size must be a power of 2") { + t.Fatalf("Unexpected error message: %v", err) + } + }) + }) + } +} + +func TestStreamMaxWindowSize(t *testing.T) { + dict := []byte("dictdata_for_compression_test") + testCases := []struct { + name string + dict []byte + }{ + {"NilDict", nil}, + {"ValidDict", dict}, + {"EmptyDict", []byte{}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create compressed data with a 128KB window size + data := strings.Repeat("abcdefghijklmnopqrstuvwxyz", 1000) + var buf bytes.Buffer + w := NewWriterLevelDictWindowSize(&buf, DefaultCompression, tc.dict, 1<<17) // 128 KB + + _, err := w.Write([]byte(data)) + failOnError(t, "Write error", err) + failOnError(t, "Flush error", w.Flush()) + failOnError(t, "Close error", w.Close()) + compressedData := buf.Bytes() + + // Normal decompression should work + t.Run("NormalDecompression", func(t *testing.T) { + r1 := NewReader(bytes.NewReader(compressedData)) + decompressed1, err := io.ReadAll(r1) + failOnError(t, "ReadAll error (normal)", err) + if !bytes.Equal(decompressed1, []byte(data)) { + t.Fatal("Regular decompression failed to match original data") + } + failOnError(t, "Reader close error", r1.Close()) + }) + + // Decompression with max window size > original window should work + t.Run("LargerMaxWindowSize", func(t *testing.T) { + r2 := 
NewReaderDictMaxWindowSize(bytes.NewReader(compressedData), tc.dict, 1<<18) + decompressed2, err := io.ReadAll(r2) + failOnError(t, "ReadAll error (large max window)", err) + if !bytes.Equal(decompressed2, []byte(data)) { + t.Fatalf("Decompression with larger max window failed to match original data - got len=%d, want len=%d", + len(decompressed2), len(data)) + } + failOnError(t, "Reader close error", r2.Close()) + }) + + // Decompression with max window size < original window should fail + t.Run("SmallerMaxWindowSize", func(t *testing.T) { + // We set it to 64KB, less than the 128KB used for compression + r3 := NewReaderDictMaxWindowSize(bytes.NewReader(compressedData), tc.dict, 1<<16) + _, err = io.ReadAll(r3) + if err == nil { + t.Fatal("Expected error when max window size is too small, got nil") + } + if !strings.Contains(err.Error(), "Frame requires too much memory") { + t.Fatalf("Unexpected error message: %v", err) + } + r3.Close() + }) + }) + } +} + func BenchmarkStreamCompression(b *testing.B) { if raw == nil { b.Fatal(ErrNoPayloadEnv) From 1d7ca9c2abc9814b77a322f80f979084752c8a76 Mon Sep 17 00:00:00 2001 From: cce <51567+cce@users.noreply.github.com> Date: Thu, 6 Mar 2025 14:18:10 -0500 Subject: [PATCH 2/4] fix tests using older versions of go --- zstd_stream_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zstd_stream_test.go b/zstd_stream_test.go index 3aec74d..79f710f 100644 --- a/zstd_stream_test.go +++ b/zstd_stream_test.go @@ -490,7 +490,7 @@ func TestStreamMaxWindowSize(t *testing.T) { // Normal decompression should work t.Run("NormalDecompression", func(t *testing.T) { r1 := NewReader(bytes.NewReader(compressedData)) - decompressed1, err := io.ReadAll(r1) + decompressed1, err := ioutil.ReadAll(r1) failOnError(t, "ReadAll error (normal)", err) if !bytes.Equal(decompressed1, []byte(data)) { t.Fatal("Regular decompression failed to match original data") @@ -501,7 +501,7 @@ func TestStreamMaxWindowSize(t *testing.T) { // 
Decompression with max window size > original window should work t.Run("LargerMaxWindowSize", func(t *testing.T) { r2 := NewReaderDictMaxWindowSize(bytes.NewReader(compressedData), tc.dict, 1<<18) - decompressed2, err := io.ReadAll(r2) + decompressed2, err := ioutil.ReadAll(r2) failOnError(t, "ReadAll error (large max window)", err) if !bytes.Equal(decompressed2, []byte(data)) { t.Fatalf("Decompression with larger max window failed to match original data - got len=%d, want len=%d", @@ -514,7 +514,7 @@ func TestStreamMaxWindowSize(t *testing.T) { t.Run("SmallerMaxWindowSize", func(t *testing.T) { // We set it to 64KB, less than the 128KB used for compression r3 := NewReaderDictMaxWindowSize(bytes.NewReader(compressedData), tc.dict, 1<<16) - _, err = io.ReadAll(r3) + _, err = ioutil.ReadAll(r3) if err == nil { t.Fatal("Expected error when max window size is too small, got nil") } From 1bcd920342ce4e049616b7ce27d8650ae89c05ab Mon Sep 17 00:00:00 2001 From: cce <51567+cce@users.noreply.github.com> Date: Thu, 6 Mar 2025 22:32:19 -0500 Subject: [PATCH 3/4] Fix CI external_libzstd job failing on apt package providing zstd v1.4.8 due to facebook/zstd#2442, fixed in v1.4.9 by facebook/zstd#2451 --- zstd.go | 2 ++ zstd_stream_test.go | 18 +++++++++++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/zstd.go b/zstd.go index 8499bf1..f74ed59 100644 --- a/zstd.go +++ b/zstd.go @@ -37,6 +37,8 @@ const ( decompressSizeBufferLimit = 1000 * 1000 zstdFrameHeaderSizeMin = 2 // From zstd.h. 
Since it's experimental API, hardcoding it + + zstdVersion = C.ZSTD_VERSION_NUMBER ) // CompressBound returns the worst case size needed for a destination buffer, diff --git a/zstd_stream_test.go b/zstd_stream_test.go index 79f710f..d37bcf9 100644 --- a/zstd_stream_test.go +++ b/zstd_stream_test.go @@ -415,8 +415,8 @@ func TestStreamSetNbWorkers(t *testing.T) { } func TestStreamWindowSize(t *testing.T) { - dict := []byte("dictdata_for_compression_test") - data := []byte("hello world") + dict := []byte(strings.Repeat("dictdata_for_compression_test", 1000)) + data := []byte(strings.Repeat("abcdefghijklmnopqrstuvwxyz", 10000)) testCases := []struct { name string dict []byte @@ -464,7 +464,7 @@ func TestStreamWindowSize(t *testing.T) { } func TestStreamMaxWindowSize(t *testing.T) { - dict := []byte("dictdata_for_compression_test") + dict := []byte(strings.Repeat("dictdata_for_compression_test", 1000)) testCases := []struct { name string dict []byte @@ -477,11 +477,11 @@ func TestStreamMaxWindowSize(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create compressed data with a 128KB window size - data := strings.Repeat("abcdefghijklmnopqrstuvwxyz", 1000) + data := []byte(strings.Repeat("abcdefghijklmnopqrstuvwxyz", 10000)) var buf bytes.Buffer w := NewWriterLevelDictWindowSize(&buf, DefaultCompression, tc.dict, 1<<17) // 128 KB - _, err := w.Write([]byte(data)) + _, err := w.Write(data) failOnError(t, "Write error", err) failOnError(t, "Flush error", w.Flush()) failOnError(t, "Close error", w.Close()) @@ -492,7 +492,7 @@ func TestStreamMaxWindowSize(t *testing.T) { r1 := NewReader(bytes.NewReader(compressedData)) decompressed1, err := ioutil.ReadAll(r1) failOnError(t, "ReadAll error (normal)", err) - if !bytes.Equal(decompressed1, []byte(data)) { + if !bytes.Equal(decompressed1, data) { t.Fatal("Regular decompression failed to match original data") } failOnError(t, "Reader close error", r1.Close()) @@ -503,7 +503,7 @@ func 
TestStreamMaxWindowSize(t *testing.T) { r2 := NewReaderDictMaxWindowSize(bytes.NewReader(compressedData), tc.dict, 1<<18) decompressed2, err := ioutil.ReadAll(r2) failOnError(t, "ReadAll error (large max window)", err) - if !bytes.Equal(decompressed2, []byte(data)) { + if !bytes.Equal(decompressed2, data) { t.Fatalf("Decompression with larger max window failed to match original data - got len=%d, want len=%d", len(decompressed2), len(data)) } @@ -512,6 +512,10 @@ func TestStreamMaxWindowSize(t *testing.T) { // Decompression with max window size < original window should fail t.Run("SmallerMaxWindowSize", func(t *testing.T) { + // workaround for regression when setting window size & using dictionary (facebook/zstd#2442) + if zstdVersion < 10409 && zstdVersion > 10405 && len(tc.dict) > 0 { + t.Skip("Skipping: Zstd v1.4.5 - v1.4.9 won't set window size when streaming with dictionary") + } // We set it to 64KB, less than the 128KB used for compression r3 := NewReaderDictMaxWindowSize(bytes.NewReader(compressedData), tc.dict, 1<<16) _, err = ioutil.ReadAll(r3) From 1b2e25ff7c6533286b4462d5662916f58cf0c1fa Mon Sep 17 00:00:00 2001 From: cce <51567+cce@users.noreply.github.com> Date: Thu, 6 Mar 2025 22:46:23 -0500 Subject: [PATCH 4/4] update comments --- zstd_stream.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/zstd_stream.go b/zstd_stream.go index 042f7c6..b67ddf5 100644 --- a/zstd_stream.go +++ b/zstd_stream.go @@ -123,9 +123,8 @@ func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer { } // NewWriterLevelDictWindowSize is like NewWriterLevelDict but allows configuring -// the window size. windowSize is specified in bytes and will be converted to a windowLog -// parameter (log2 of the window size). If windowSize is 0, the default window size is used. +// the window size, specified in bytes. If windowSize is 0, the default window size is used. 
+// The windowSize must be a power of 2 between 1KB and 1GB on 32-bit platforms, // or 1KB and 2GB on 64-bit platforms. // A larger window size allows for better compression ratios for repetitive data // but requires more memory during compression and decompression. @@ -401,14 +400,13 @@ func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { } // NewReaderDictMaxWindowSize is like NewReaderDict but allows configuring the maximum -// window size for decompression. maxWindowSize is specified in bytes, not as a log value. +// window size for decompression, specified in bytes. // If maxWindowSize is 0, the default window size limit is used. // Setting a maximum window size protects against allocating too much memory for // decompression (potential attack scenario) when processing untrusted inputs. func NewReaderDictMaxWindowSize(r io.Reader, dict []byte, maxWindowSize int) io.ReadCloser { var err error ctx := C.ZSTD_createDStream() - if len(dict) == 0 { err = getError(int(C.ZSTD_initDStream(ctx))) } else {