2 changes: 2 additions & 0 deletions README.md
@@ -71,6 +71,7 @@ Decompress(dst, src []byte) ([]byte, error)
NewWriter(w io.Writer) *Writer
NewWriterLevel(w io.Writer, level int) *Writer
NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer
NewWriterLevelDictWindowSize(w io.Writer, level int, dict []byte, windowSize int) *Writer

// Write compresses the input data and writes it to the underlying writer
(w *Writer) Write(p []byte) (int, error)
@@ -89,6 +90,7 @@ NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer
// to call Close, which frees up C objects.
NewReader(r io.Reader) io.ReadCloser
NewReaderDict(r io.Reader, dict []byte) io.ReadCloser
NewReaderDictMaxWindowSize(r io.Reader, dict []byte, maxWindowSize int) io.ReadCloser
```
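
For illustration (editorial addition, not part of the diff), here is a minimal round trip through the two new constructors. The log.Fatal error handling is only for brevity, and construction errors are assumed to surface on Write/Read, as they do with the existing constructors:

```
package main

import (
	"bytes"
	"io/ioutil"
	"log"

	"github.com/DataDog/zstd"
)

func main() {
	var buf bytes.Buffer
	// Compress with a 1 MB window (must be a power of 2; 0 selects the default).
	w := zstd.NewWriterLevelDictWindowSize(&buf, zstd.DefaultCompression, nil, 1<<20)
	if _, err := w.Write([]byte("hello zstd")); err != nil {
		log.Fatal(err)
	}
	if err := w.Close(); err != nil {
		log.Fatal(err)
	}

	// Decompress, refusing any frame that needs more than a 1 MB window.
	r := zstd.NewReaderDictMaxWindowSize(&buf, nil, 1<<20)
	defer r.Close()
	out, err := ioutil.ReadAll(r)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("decoded: %s", out)
}
```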

### Benchmarks (benchmarked with v0.5.0)
2 changes: 2 additions & 0 deletions zstd.go
@@ -37,6 +37,8 @@ const (
decompressSizeBufferLimit = 1000 * 1000

zstdFrameHeaderSizeMin = 2 // From zstd.h; hardcoded because it is part of the experimental API

zstdVersion = C.ZSTD_VERSION_NUMBER
)
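
(Editorial note, not part of the diff.) ZSTD_VERSION_NUMBER packs the version as major*10000 + minor*100 + patch, so v1.4.5 is 10405; the version guards in the test file below compare zstdVersion against integers in this encoding. A hypothetical decoder:

```
// verString renders a ZSTD_VERSION_NUMBER-style value, e.g. 10405 -> "1.4.5".
// Hypothetical helper for illustration only; assumes "fmt" is imported.
func verString(v int) string {
	return fmt.Sprintf("%d.%d.%d", v/10000, (v/100)%100, v%100)
}
```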

// CompressBound returns the worst case size needed for a destination buffer,
39 changes: 38 additions & 1 deletion zstd_stream.go
@@ -1,6 +1,7 @@
package zstd

/*
#define ZSTD_STATIC_LINKING_ONLY 1
Collaborator commented on this line:
🫤 Unfortunately this means only static linking is allowed, which is a deal-breaker for the many users of this library who rely on dynamic linking. As this would break compatibility, we can't merge it into 1.x.

#include "zstd.h"

typedef struct compressStream2_result_s {
@@ -65,6 +66,7 @@ import (
"errors"
"fmt"
"io"
"math/bits"
"runtime"
"sync"
"unsafe"
@@ -117,11 +119,21 @@ func NewWriterLevel(w io.Writer, level int) *Writer {
// compress with. If the dictionary is empty or nil it is ignored. The dictionary
// should not be modified until the writer is closed.
func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer {
return NewWriterLevelDictWindowSize(w, level, dict, 0)
}

// NewWriterLevelDictWindowSize is like NewWriterLevelDict but allows configuring
// the window size, specified in bytes. If windowSize is 0, the default window size is used.
// The window size must be a power of 2 between 1 KB and 1 GB on 32-bit platforms,
// or between 1 KB and 2 GB on 64-bit platforms.
// A larger window size allows better compression ratios for repetitive data,
// but requires more memory during compression and decompression.
func NewWriterLevelDictWindowSize(w io.Writer, level int, dict []byte, windowSize int) *Writer {
var err error
ctx := C.ZSTD_createCStream()

// Load dictionary, if any
- if dict != nil {
+ if len(dict) > 0 {
err = getError(int(C.ZSTD_CCtx_loadDictionary(ctx,
unsafe.Pointer(&dict[0]),
C.size_t(len(dict)),
@@ -133,6 +145,17 @@ func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer {
err = getError(int(C.ZSTD_CCtx_setParameter(ctx, C.ZSTD_c_compressionLevel, C.int(level))))
}

if err == nil && windowSize > 0 {
// Only set windowLog if windowSize is a power of 2
if windowSize&(windowSize-1) == 0 {
// Convert windowSize to windowLog using bits.TrailingZeros
windowLog := bits.TrailingZeros(uint(windowSize))
err = getError(int(C.ZSTD_CCtx_setParameter(ctx, C.ZSTD_c_windowLog, C.int(windowLog))))
} else {
err = fmt.Errorf("window size must be a power of 2")
}
}

return &Writer{
CompressionLevel: level,
ctx: ctx,
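
As a worked example of the conversion in the hunk above (illustrative, not part of the diff): a power-of-2 windowSize has exactly one set bit, and bits.TrailingZeros returns that bit's position, which is exactly the base-2 logarithm zstd expects for ZSTD_c_windowLog.

```
// windowLogFor mirrors the check above: 1<<17 (128 KB) -> 17, 1<<20 (1 MB) -> 20.
// Hypothetical helper for illustration; assumes "fmt" and "math/bits" are imported.
func windowLogFor(windowSize int) (int, error) {
	if windowSize <= 0 || windowSize&(windowSize-1) != 0 {
		return 0, fmt.Errorf("window size must be a power of 2")
	}
	return bits.TrailingZeros(uint(windowSize)), nil
}
```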
@@ -373,6 +396,15 @@ func NewReader(r io.Reader) io.ReadCloser {
// NewReaderDict is like NewReader but uses a preset dictionary. NewReaderDict
// ignores the dictionary if it is nil.
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
return NewReaderDictMaxWindowSize(r, dict, 0)
}

// NewReaderDictMaxWindowSize is like NewReaderDict but allows configuring the maximum
// window size for decompression, specified in bytes.
// If maxWindowSize is 0, the default window size limit is used.
// Setting a maximum window size protects against allocating too much memory for
// decompression (potential attack scenario) when processing untrusted inputs.
func NewReaderDictMaxWindowSize(r io.Reader, dict []byte, maxWindowSize int) io.ReadCloser {
var err error
ctx := C.ZSTD_createDStream()
if len(dict) == 0 {
@@ -387,6 +419,11 @@ func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
C.size_t(len(dict)))))
}
}

if err == nil && maxWindowSize > 0 {
err = getError(int(C.ZSTD_DCtx_setMaxWindowSize(ctx, C.size_t(maxWindowSize))))
}

compressionBufferP := cPool.Get().(*[]byte)
decompressionBufferP := dPool.Get().(*[]byte)
return &reader{
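One intended use of the new reader constructor (editorial sketch; the 8 MB cap and the helper name are illustrative, not from this PR): bounding decoder memory when the compressed stream comes from an untrusted source.

```
// decodeUntrusted caps the decoder window at 8 MB; a frame requiring a larger
// window fails with an error instead of forcing a huge allocation.
// Assumes "io", "io/ioutil", and the zstd package are imported.
func decodeUntrusted(src io.Reader) ([]byte, error) {
	r := zstd.NewReaderDictMaxWindowSize(src, nil, 8<<20)
	defer r.Close()
	return ioutil.ReadAll(r)
}
```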
117 changes: 117 additions & 0 deletions zstd_stream_test.go
@@ -414,6 +414,123 @@ func TestStreamSetNbWorkers(t *testing.T) {
testCompressionDecompression(t, nil, []byte(s), nbWorkers)
}

func TestStreamWindowSize(t *testing.T) {
dict := []byte(strings.Repeat("dictdata_for_compression_test", 1000))
data := []byte(strings.Repeat("abcdefghijklmnopqrstuvwxyz", 10000))
testCases := []struct {
name string
dict []byte
}{
{"NilDict", nil},
{"ValidDict", dict},
{"EmptyDict", []byte{}},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Test with valid window size (power of 2)
t.Run("ValidWindowSize", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriterLevelDictWindowSize(&buf, DefaultCompression, tc.dict, 1<<17) // 128 KB

_, err := w.Write(data)
failOnError(t, "Write error", err)
failOnError(t, "Close error", w.Close())

// Test decoding
r := NewReader(&buf)
decompressed, err := ioutil.ReadAll(r)
failOnError(t, "ReadAll error", err)
if !bytes.Equal(decompressed, data) {
t.Fatalf("got %q; want %q", decompressed, data)
}
failOnError(t, "Reader close error", r.Close())
})

// Test with invalid window size (not a power of 2)
t.Run("InvalidWindowSize", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriterLevelDictWindowSize(&buf, DefaultCompression, tc.dict, 123456)
_, err := w.Write(data)
if err == nil {
t.Fatal("Expected error for invalid window size, got nil")
}
if !strings.Contains(err.Error(), "window size must be a power of 2") {
t.Fatalf("Unexpected error message: %v", err)
}
})
})
}
}

func TestStreamMaxWindowSize(t *testing.T) {
dict := []byte(strings.Repeat("dictdata_for_compression_test", 1000))
testCases := []struct {
name string
dict []byte
}{
{"NilDict", nil},
{"ValidDict", dict},
{"EmptyDict", []byte{}},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create compressed data with a 128KB window size
data := []byte(strings.Repeat("abcdefghijklmnopqrstuvwxyz", 10000))
var buf bytes.Buffer
w := NewWriterLevelDictWindowSize(&buf, DefaultCompression, tc.dict, 1<<17) // 128 KB

_, err := w.Write(data)
failOnError(t, "Write error", err)
failOnError(t, "Flush error", w.Flush())
failOnError(t, "Close error", w.Close())
compressedData := buf.Bytes()

// Normal decompression should work
t.Run("NormalDecompression", func(t *testing.T) {
r1 := NewReader(bytes.NewReader(compressedData))
decompressed1, err := ioutil.ReadAll(r1)
failOnError(t, "ReadAll error (normal)", err)
if !bytes.Equal(decompressed1, data) {
t.Fatal("Regular decompression failed to match original data")
}
failOnError(t, "Reader close error", r1.Close())
})

// Decompression with max window size > original window should work
t.Run("LargerMaxWindowSize", func(t *testing.T) {
r2 := NewReaderDictMaxWindowSize(bytes.NewReader(compressedData), tc.dict, 1<<18)
decompressed2, err := ioutil.ReadAll(r2)
failOnError(t, "ReadAll error (large max window)", err)
if !bytes.Equal(decompressed2, data) {
t.Fatalf("Decompression with larger max window failed to match original data - got len=%d, want len=%d",
len(decompressed2), len(data))
}
failOnError(t, "Reader close error", r2.Close())
})

// Decompression with max window size < original window should fail
t.Run("SmallerMaxWindowSize", func(t *testing.T) {
// workaround for regression when setting window size & using dictionary (facebook/zstd#2442)
if zstdVersion < 10409 && zstdVersion > 10405 && len(tc.dict) > 0 {
t.Skip("Skipping: Zstd v1.4.5 - v1.4.9 won't set window size when streaming with dictionary")
}
// We set it to 64KB, less than the 128KB used for compression
r3 := NewReaderDictMaxWindowSize(bytes.NewReader(compressedData), tc.dict, 1<<16)
_, err = ioutil.ReadAll(r3)
if err == nil {
t.Fatal("Expected error when max window size is too small, got nil")
}
if !strings.Contains(err.Error(), "Frame requires too much memory") {
t.Fatalf("Unexpected error message: %v", err)
}
r3.Close()
})
})
}
}

func BenchmarkStreamCompression(b *testing.B) {
if raw == nil {
b.Fatal(ErrNoPayloadEnv)