diff --git a/README.md b/README.md index c1db495..8bd45b2 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,7 @@ Decompress(dst, src []byte) ([]byte, error) NewWriter(w io.Writer) *Writer NewWriterLevel(w io.Writer, level int) *Writer NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer +NewWriterLevelDictWindowSize(w io.Writer, level int, dict []byte, windowSize int) *Writer // Write compresses the input data and write it to the underlying writer (w *Writer) Write(p []byte) (int, error) @@ -89,6 +90,7 @@ NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer // to call Close, which frees up C objects. NewReader(r io.Reader) io.ReadCloser NewReaderDict(r io.Reader, dict []byte) io.ReadCloser +NewReaderDictMaxWindowSize(r io.Reader, dict []byte, maxWindowSize int) io.ReadCloser ``` ### Benchmarks (benchmarked with v0.5.0) diff --git a/zstd.go b/zstd.go index 8499bf1..f74ed59 100644 --- a/zstd.go +++ b/zstd.go @@ -37,6 +37,8 @@ const ( decompressSizeBufferLimit = 1000 * 1000 zstdFrameHeaderSizeMin = 2 // From zstd.h. Since it's experimental API, hardcoding it + + zstdVersion = C.ZSTD_VERSION_NUMBER ) // CompressBound returns the worst case size needed for a destination buffer, diff --git a/zstd_stream.go b/zstd_stream.go index 714ecfe..b67ddf5 100644 --- a/zstd_stream.go +++ b/zstd_stream.go @@ -1,6 +1,7 @@ package zstd /* +#define ZSTD_STATIC_LINKING_ONLY 1 #include "zstd.h" typedef struct compressStream2_result_s { @@ -65,6 +66,7 @@ import ( "errors" "fmt" "io" + "math/bits" "runtime" "sync" "unsafe" @@ -117,11 +119,21 @@ func NewWriterLevel(w io.Writer, level int) *Writer { // compress with. If the dictionary is empty or nil it is ignored. The dictionary // should not be modified until the writer is closed. 
func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer { + return NewWriterLevelDictWindowSize(w, level, dict, 0) +} + +// NewWriterLevelDictWindowSize is like NewWriterLevelDict but allows configuring +// the window size, specified in bytes. If windowSize is 0, the default window size is used. +// The windowSize must be a power of 2 between 1KB and 1GB on 32-bit platforms, +// or 1KB and 2GB on 64-bit platforms. +// A larger window size allows for better compression ratios for repetitive data +// but requires more memory during compression and decompression. +func NewWriterLevelDictWindowSize(w io.Writer, level int, dict []byte, windowSize int) *Writer { var err error ctx := C.ZSTD_createCStream() // Load dictionnary if any - if dict != nil { + if len(dict) > 0 { err = getError(int(C.ZSTD_CCtx_loadDictionary(ctx, unsafe.Pointer(&dict[0]), C.size_t(len(dict)), @@ -133,6 +145,17 @@ func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer { err = getError(int(C.ZSTD_CCtx_setParameter(ctx, C.ZSTD_c_compressionLevel, C.int(level)))) } + if err == nil && windowSize > 0 { + // Only set windowLog if windowSize is a power of 2 + if windowSize&(windowSize-1) == 0 { + // Convert windowSize to windowLog using bits.TrailingZeros + windowLog := bits.TrailingZeros(uint(windowSize)) + err = getError(int(C.ZSTD_CCtx_setParameter(ctx, C.ZSTD_c_windowLog, C.int(windowLog)))) + } else { + err = fmt.Errorf("window size must be a power of 2") + } + } + return &Writer{ CompressionLevel: level, ctx: ctx, @@ -373,6 +396,15 @@ func NewReader(r io.Reader) io.ReadCloser { // NewReaderDict is like NewReader but uses a preset dictionary. NewReaderDict // ignores the dictionary if it is nil. func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { + return NewReaderDictMaxWindowSize(r, dict, 0) +} + +// NewReaderDictMaxWindowSize is like NewReaderDict but allows configuring the maximum +// window size for decompression, specified in bytes. 
+// If maxWindowSize is 0, the default window size limit is used. +// Setting a maximum window size protects against allocating too much memory for +// decompression (potential attack scenario) when processing untrusted inputs. +func NewReaderDictMaxWindowSize(r io.Reader, dict []byte, maxWindowSize int) io.ReadCloser { var err error ctx := C.ZSTD_createDStream() if len(dict) == 0 { @@ -387,6 +419,11 @@ func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { C.size_t(len(dict))))) } } + + if err == nil && maxWindowSize > 0 { + err = getError(int(C.ZSTD_DCtx_setMaxWindowSize(ctx, C.size_t(maxWindowSize)))) + } + compressionBufferP := cPool.Get().(*[]byte) decompressionBufferP := dPool.Get().(*[]byte) return &reader{ diff --git a/zstd_stream_test.go b/zstd_stream_test.go index d0960f4..d37bcf9 100644 --- a/zstd_stream_test.go +++ b/zstd_stream_test.go @@ -414,6 +414,123 @@ func TestStreamSetNbWorkers(t *testing.T) { testCompressionDecompression(t, nil, []byte(s), nbWorkers) } +func TestStreamWindowSize(t *testing.T) { + dict := []byte(strings.Repeat("dictdata_for_compression_test", 1000)) + data := []byte(strings.Repeat("abcdefghijklmnopqrstuvwxyz", 10000)) + testCases := []struct { + name string + dict []byte + }{ + {"NilDict", nil}, + {"ValidDict", dict}, + {"EmptyDict", []byte{}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test with valid window size (power of 2) + t.Run("ValidWindowSize", func(t *testing.T) { + var buf bytes.Buffer + w := NewWriterLevelDictWindowSize(&buf, DefaultCompression, tc.dict, 1<<17) // 128 KB + + _, err := w.Write(data) + failOnError(t, "Write error", err) + failOnError(t, "Close error", w.Close()) + + // Test decoding + r := NewReader(&buf) + decompressed, err := ioutil.ReadAll(r) + failOnError(t, "ReadAll error", err) + if !bytes.Equal(decompressed, data) { + t.Fatalf("got %q; want %q", decompressed, data) + } + failOnError(t, "Reader close error", r.Close()) + }) + + // Test with invalid 
window size (not a power of 2) + t.Run("InvalidWindowSize", func(t *testing.T) { + var buf bytes.Buffer + w := NewWriterLevelDictWindowSize(&buf, DefaultCompression, tc.dict, 123456) + _, err := w.Write(data) + if err == nil { + t.Fatal("Expected error for invalid window size, got nil") + } + if !strings.Contains(err.Error(), "window size must be a power of 2") { + t.Fatalf("Unexpected error message: %v", err) + } + }) + }) + } +} + +func TestStreamMaxWindowSize(t *testing.T) { + dict := []byte(strings.Repeat("dictdata_for_compression_test", 1000)) + testCases := []struct { + name string + dict []byte + }{ + {"NilDict", nil}, + {"ValidDict", dict}, + {"EmptyDict", []byte{}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create compressed data with a 128KB window size + data := []byte(strings.Repeat("abcdefghijklmnopqrstuvwxyz", 10000)) + var buf bytes.Buffer + w := NewWriterLevelDictWindowSize(&buf, DefaultCompression, tc.dict, 1<<17) // 128 KB + + _, err := w.Write(data) + failOnError(t, "Write error", err) + failOnError(t, "Flush error", w.Flush()) + failOnError(t, "Close error", w.Close()) + compressedData := buf.Bytes() + + // Normal decompression should work + t.Run("NormalDecompression", func(t *testing.T) { + r1 := NewReader(bytes.NewReader(compressedData)) + decompressed1, err := ioutil.ReadAll(r1) + failOnError(t, "ReadAll error (normal)", err) + if !bytes.Equal(decompressed1, data) { + t.Fatal("Regular decompression failed to match original data") + } + failOnError(t, "Reader close error", r1.Close()) + }) + + // Decompression with max window size > original window should work + t.Run("LargerMaxWindowSize", func(t *testing.T) { + r2 := NewReaderDictMaxWindowSize(bytes.NewReader(compressedData), tc.dict, 1<<18) + decompressed2, err := ioutil.ReadAll(r2) + failOnError(t, "ReadAll error (large max window)", err) + if !bytes.Equal(decompressed2, data) { + t.Fatalf("Decompression with larger max window failed to match 
original data - got len=%d, want len=%d", + len(decompressed2), len(data)) + } + failOnError(t, "Reader close error", r2.Close()) + }) + + // Decompression with max window size < original window should fail + t.Run("SmallerMaxWindowSize", func(t *testing.T) { + // workaround for regression when setting window size & using dictionary (facebook/zstd#2442) + if zstdVersion < 10409 && zstdVersion > 10405 && len(tc.dict) > 0 { + t.Skip("Skipping: Zstd v1.4.5 - v1.4.9 won't set window size when streaming with dictionary") + } + // We set it to 64KB, less than the 128KB used for compression + r3 := NewReaderDictMaxWindowSize(bytes.NewReader(compressedData), tc.dict, 1<<16) + _, err = ioutil.ReadAll(r3) + if err == nil { + t.Fatal("Expected error when max window size is too small, got nil") + } + if !strings.Contains(err.Error(), "Frame requires too much memory") { + t.Fatalf("Unexpected error message: %v", err) + } + r3.Close() + }) + }) + } +} + func BenchmarkStreamCompression(b *testing.B) { if raw == nil { b.Fatal(ErrNoPayloadEnv)