From fa4267c96b5a264ea7dacecfc2bd93778454f132 Mon Sep 17 00:00:00 2001 From: Adil Zouhal Date: Sat, 25 Oct 2025 18:26:45 +0200 Subject: [PATCH] feat(otel): add unified filtering API with O(1) command exclusion Add comprehensive filtering capabilities with performance-optimized command exclusion and unified filtering API. - WithExcludedCommands(): O(1) command exclusion using map lookup - WithProcessFilter(): Custom filtering for individual commands - WithProcessPipelineFilter(): Custom filtering for pipeline operations - WithDialFilterFunc(): Custom filtering for dial operations - Backward compatible with existing filtering APIs - Comprehensive test coverage and updated examples Fixes #3479 --- example/otel/advanced_filtering.go | 113 ++++++ example/otel/client.go | 7 +- extra/redisotel/README.md | 92 +++++ extra/redisotel/config.go | 60 ++- extra/redisotel/tracing.go | 93 +++-- extra/redisotel/unified_filtering_test.go | 423 ++++++++++++++++++++++ 6 files changed, 757 insertions(+), 31 deletions(-) create mode 100644 example/otel/advanced_filtering.go create mode 100644 extra/redisotel/unified_filtering_test.go diff --git a/example/otel/advanced_filtering.go b/example/otel/advanced_filtering.go new file mode 100644 index 0000000000..daddd8d2b3 --- /dev/null +++ b/example/otel/advanced_filtering.go @@ -0,0 +1,113 @@ +package main + +import ( + "context" + "fmt" + "strings" + + "github.com/redis/go-redis/extra/redisotel/v9" + "github.com/redis/go-redis/v9" +) + +func runAdvancedFilteringExamples() { + ctx := context.Background() + + // Example 1: High-performance O(1) command exclusion + fmt.Println("=== Example 1: O(1) Command Exclusion ===") + rdb1 := redis.NewClient(&redis.Options{Addr: ":6379"}) + + if err := redisotel.InstrumentTracing(rdb1, + redisotel.WithExcludedCommands("PING", "INFO", "SELECT")); err != nil { + panic(err) + } + + // These commands won't be traced + rdb1.Ping(ctx) + rdb1.Info(ctx) + + // This command will be traced + rdb1.Set(ctx, "key1", "value1", 0) + fmt.Println("✓ O(1) exclusion configured") + + // Example 2: Custom filtering logic + fmt.Println("\n=== Example 2: Custom Process Filter ===") + rdb2 := redis.NewClient(&redis.Options{Addr: ":6379"}) + + if err := redisotel.InstrumentTracing(rdb2, + redisotel.WithProcessFilter(func(cmd redis.Cmder) bool { + // Exclude commands that start with "DEBUG_" or "INTERNAL_" + name := strings.ToUpper(cmd.Name()) + return strings.HasPrefix(name, "DEBUG_") || + strings.HasPrefix(name, "INTERNAL_") + })); err != nil { + panic(err) + } + fmt.Println("✓ Custom process filter configured") + + // Example 3: Pipeline filtering + fmt.Println("\n=== Example 3: Pipeline Filter ===") + rdb3 := redis.NewClient(&redis.Options{Addr: ":6379"}) + + if err := redisotel.InstrumentTracing(rdb3, + redisotel.WithProcessPipelineFilter(func(cmds []redis.Cmder) bool { + // Exclude large pipelines (>10 commands) from tracing + return len(cmds) > 10 + })); err != nil { + panic(err) + } + + // Small pipeline - will be traced + pipe := rdb3.Pipeline() + pipe.Set(ctx, "key1", "value1", 0) + pipe.Set(ctx, "key2", "value2", 0) + pipe.Exec(ctx) + fmt.Println("✓ Pipeline filter configured") + + // Example 4: Dial filtering + fmt.Println("\n=== Example 4: Dial Filter ===") + rdb4 := redis.NewClient(&redis.Options{Addr: ":6379"}) + + if err := redisotel.InstrumentTracing(rdb4, + redisotel.WithDialFilterFunc(func(network, addr string) bool { + // Don't trace connections to localhost for development + return strings.Contains(addr, "localhost") || + strings.Contains(addr, "127.0.0.1") + })); err != nil { + panic(err) + } + fmt.Println("✓ Dial filter configured") + + // Example 5: Combined approach for optimal performance + fmt.Println("\n=== Example 5: Combined Filtering Approach ===") + rdb5 := redis.NewClient(&redis.Options{Addr: ":6379"}) + + if err := redisotel.InstrumentTracing(rdb5, + // Fast O(1) exclusion for common commands + redisotel.WithExcludedCommands("PING", "INFO", "SELECT"), + // Custom logic for additional filtering + redisotel.WithProcessFilter(func(cmd redis.Cmder) bool { + return strings.HasPrefix(strings.ToUpper(cmd.Name()), "DEBUG_") + })); err != nil { + panic(err) + } + + // Test the combined approach + rdb5.Ping(ctx) // Excluded by O(1) set + rdb5.Set(ctx, "key", "value", 0) // Will be traced + fmt.Println("✓ Combined filtering approach configured") + + // Example 6: Backward compatibility with legacy APIs + fmt.Println("\n=== Example 6: Legacy API Compatibility ===") + rdb6 := redis.NewClient(&redis.Options{Addr: ":6379"}) + + if err := redisotel.InstrumentTracing(rdb6, + // Legacy command filter still works + redisotel.WithCommandFilter(func(cmd redis.Cmder) bool { + return strings.ToUpper(cmd.Name()) == "AUTH" + })); err != nil { + panic(err) + } + fmt.Println("✓ Legacy API compatibility maintained") + + fmt.Println("\n=== All filtering examples completed successfully! ===") +} diff --git a/example/otel/client.go b/example/otel/client.go index 165a9234a7..5a724fd059 100644 --- a/example/otel/client.go +++ b/example/otel/client.go @@ -33,13 +33,18 @@ func main() { rdb := redis.NewClient(&redis.Options{ Addr: ":6379", }) - if err := redisotel.InstrumentTracing(rdb); err != nil { + // Basic tracing with performance-optimized command exclusion + if err := redisotel.InstrumentTracing(rdb, + redisotel.WithExcludedCommands("PING", "INFO")); err != nil { panic(err) } if err := redisotel.InstrumentMetrics(rdb); err != nil { panic(err) } + // Run advanced filtering examples + runAdvancedFilteringExamples() + for i := 0; i < 1e6; i++ { ctx, rootSpan := tracer.Start(ctx, "handleRequest") diff --git a/extra/redisotel/README.md b/extra/redisotel/README.md index 997c17d1c5..f92fffcc43 100644 --- a/extra/redisotel/README.md +++ b/extra/redisotel/README.md @@ -29,6 +29,98 @@ if err := redisotel.InstrumentMetrics(rdb); err != nil { } ``` +## Advanced Tracing Options + +### High-Performance Command Filtering + +For production systems, use O(1) command exclusion for optimal performance: + +```go +// Recommended: O(1) command exclusion +err := redisotel.InstrumentTracing(rdb, + redisotel.WithExcludedCommands("PING", "INFO", "SELECT")) +if err != nil { + panic(err) +} +``` + +### Custom Filtering Logic + +For complex filtering requirements, use custom filter functions: + +```go +// Filter individual commands +err := redisotel.InstrumentTracing(rdb, + redisotel.WithProcessFilter(func(cmd redis.Cmder) bool { + // Return true to exclude the command from tracing + return strings.HasPrefix(cmd.Name(), "INTERNAL_") + })) +if err != nil { + panic(err) +} + +// Filter pipeline commands +err := redisotel.InstrumentTracing(rdb, + redisotel.WithProcessPipelineFilter(func(cmds []redis.Cmder) bool { + // Return true to exclude pipelines with more than 10 commands + return len(cmds) > 10 + })) +if err != nil { + panic(err) +} + +// Filter dial operations +err := redisotel.InstrumentTracing(rdb, + redisotel.WithDialFilterFunc(func(network, addr string) bool { + // Return true to exclude connections to localhost + return strings.Contains(addr, "localhost") + })) +if err != nil { + panic(err) +} +``` + +### Combining Filtering Approaches + +Exclusion sets are checked first for optimal performance: + +```go +err := redisotel.InstrumentTracing(rdb, + // Fast O(1) exclusion for common commands + redisotel.WithExcludedCommands("PING", "INFO"), + // Custom logic for additional cases + redisotel.WithProcessFilter(func(cmd redis.Cmder) bool { + return strings.HasPrefix(cmd.Name(), "DEBUG_") + })) +if err != nil { + panic(err) +} +``` + +### Legacy API Compatibility + +Original filtering APIs remain supported: + +```go +// Legacy command filter +redisotel.WithCommandFilter(func(cmd redis.Cmder) bool { + return cmd.Name() == "AUTH" // Exclude AUTH commands +}) + +// Legacy pipeline filter +redisotel.WithCommandsFilter(func(cmds []redis.Cmder) bool { + for _, cmd := range cmds { + if cmd.Name() == "AUTH" { + return true // Exclude pipelines with AUTH commands + } + } + return false +}) + +// Legacy dial filter +redisotel.WithDialFilter(true) // Enable dial filtering +``` + See [example](../../example/otel) and [Monitoring Go Redis Performance and Errors](https://redis.uptrace.dev/guide/go-redis-monitoring.html) for details. diff --git a/extra/redisotel/config.go b/extra/redisotel/config.go index b9311beafa..280c707af9 100644 --- a/extra/redisotel/config.go +++ b/extra/redisotel/config.go @@ -19,14 +19,21 @@ type config struct { // Tracing options. - tp trace.TracerProvider - tracer trace.Tracer + tp trace.TracerProvider + tracer trace.Tracer + dbStmtEnabled bool + callerEnabled bool + excludedCommands map[string]struct{} - dbStmtEnabled bool - callerEnabled bool + // Legacy filters filterDial bool - filterProcessPipeline func(cmds []redis.Cmder) bool filterProcess func(cmd redis.Cmder) bool + filterProcessPipeline func(cmds []redis.Cmder) bool + + // Unified filters + unifiedProcessFilter func(cmd redis.Cmder) bool + unifiedProcessPipelineFilter func(cmds []redis.Cmder) bool + unifiedDialFilter func(network, addr string) bool // Metrics options. @@ -76,6 +83,7 @@ func newConfig(opts ...baseOption) *config { } return false }, + excludedCommands: make(map[string]struct{}), } for _, opt := range opts { @@ -163,6 +171,48 @@ func WithDialFilter(on bool) TracingOption { }) } +// WithExcludedCommands provides O(1) command exclusion for performance. +// Command names are normalized to uppercase for case-insensitive matching. +func WithExcludedCommands(commands ...string) TracingOption { + return tracingOption(func(conf *config) { + if len(commands) == 0 { + return + } + // Merge into existing map to accumulate exclusions across multiple calls + for _, cmd := range commands { + if cmd != "" { + conf.excludedCommands[strings.ToUpper(cmd)] = struct{}{} + } + } + }) +} + +// WithProcessFilter allows filtering of individual commands with custom logic. +// The filter function should return true to exclude the command from tracing. +func WithProcessFilter(filter func(cmd redis.Cmder) bool) TracingOption { + return tracingOption(func(conf *config) { + conf.unifiedProcessFilter = filter + conf.filterProcess = nil + }) +} + +// WithProcessPipelineFilter allows filtering of pipeline commands with custom logic. +// The filter function should return true to exclude the entire pipeline from tracing. +func WithProcessPipelineFilter(filter func(cmds []redis.Cmder) bool) TracingOption { + return tracingOption(func(conf *config) { + conf.unifiedProcessPipelineFilter = filter + conf.filterProcessPipeline = nil + }) +} + +// WithDialFilterFunc allows filtering of dial operations with custom logic. +// The filter function should return true to exclude the dial operation from tracing. +func WithDialFilterFunc(filter func(network, addr string) bool) TracingOption { + return tracingOption(func(conf *config) { + conf.unifiedDialFilter = filter + }) +} + // DefaultCommandFilter filters out AUTH commands from tracing. func DefaultCommandFilter(cmd redis.Cmder) bool { if strings.ToLower(cmd.Name()) == "auth" { diff --git a/extra/redisotel/tracing.go b/extra/redisotel/tracing.go index a6f361b06e..11984ba440 100644 --- a/extra/redisotel/tracing.go +++ b/extra/redisotel/tracing.go @@ -85,10 +85,52 @@ func newTracingHook(connString string, opts ...TracingOption) *tracingHook { } } +func (th *tracingHook) shouldTrace(cmd redis.Cmder) bool { + if _, excluded := th.conf.excludedCommands[strings.ToUpper(cmd.Name())]; excluded { + return false + } + + if th.conf.unifiedProcessFilter != nil { + return !th.conf.unifiedProcessFilter(cmd) + } + + if th.conf.filterProcess != nil { + return !th.conf.filterProcess(cmd) + } + + return true +} + +func (th *tracingHook) shouldTracePipeline(cmds []redis.Cmder) bool { + for _, cmd := range cmds { + if _, excluded := th.conf.excludedCommands[strings.ToUpper(cmd.Name())]; excluded { + return false + } + } + + if th.conf.unifiedProcessPipelineFilter != nil { + return !th.conf.unifiedProcessPipelineFilter(cmds) + } + + if th.conf.filterProcessPipeline != nil { + return !th.conf.filterProcessPipeline(cmds) + } + + return true +} + +func (th *tracingHook) shouldTraceDial(network, addr string) bool { + if th.conf.unifiedDialFilter != nil { + return !th.conf.unifiedDialFilter(network, addr) + } + + return !th.conf.filterDial +} + func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook { return func(ctx context.Context, network, addr string) (net.Conn, error) { - if th.conf.filterDial { + if !th.shouldTraceDial(network, addr) { return hook(ctx, network, addr) } @@ -107,25 +149,25 @@ func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook { func (th *tracingHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { return func(ctx context.Context, cmd redis.Cmder) error { - // Check if the command should be filtered out - if th.conf.filterProcess != nil && th.conf.filterProcess(cmd) { - // If so, just call the next hook + if !th.shouldTrace(cmd) { return hook(ctx, cmd) } attrs := make([]attribute.KeyValue, 0, 8) if th.conf.callerEnabled { - fn, file, line := funcFileLine("github.com/redis/go-redis") - attrs = append(attrs, - semconv.CodeFunction(fn), - semconv.CodeFilepath(file), - semconv.CodeLineNumber(line), - ) + if fn, file, line := funcFileLine("github.com/redis/go-redis"); fn != "" { + attrs = append(attrs, + semconv.CodeFunction(fn), + semconv.CodeFilepath(file), + semconv.CodeLineNumber(line), + ) + } } if th.conf.dbStmtEnabled { - cmdString := rediscmd.CmdString(cmd) - attrs = append(attrs, semconv.DBStatement(cmdString)) + if cmdString := rediscmd.CmdString(cmd); cmdString != "" { + attrs = append(attrs, semconv.DBStatement(cmdString)) + } } opts := th.spanOpts @@ -147,7 +189,7 @@ func (th *tracingHook) ProcessPipelineHook( ) redis.ProcessPipelineHook { return func(ctx context.Context, cmds []redis.Cmder) error { - if th.conf.filterProcessPipeline != nil && th.conf.filterProcessPipeline(cmds) { + if !th.shouldTracePipeline(cmds) { return hook(ctx, cmds) } @@ -224,7 +266,8 @@ func funcFileLine(pkg string) (string, string, int) { return fn, file, line } -// Database span attributes semantic conventions recommended server address and port +// addServerAttributes adds database span attributes following semantic conventions +// for server address and port as recommended by OpenTelemetry. // https://opentelemetry.io/docs/specs/semconv/database/database-spans/#connection-level-attributes func addServerAttributes(opts []TracingOption, addr string) []TracingOption { host, portString, err := net.SplitHostPort(addr) @@ -232,19 +275,19 @@ func addServerAttributes(opts []TracingOption, addr string) []TracingOption { return opts } - opts = append(opts, WithAttributes( - semconv.ServerAddress(host), - )) - - // Parse the port string to an integer - port, err := strconv.Atoi(portString) - if err != nil { - return opts + if host != "" { + opts = append(opts, WithAttributes( + semconv.ServerAddress(host), + )) } - opts = append(opts, WithAttributes( - semconv.ServerPort(port), - )) + if portString != "" { + if port, err := strconv.Atoi(portString); err == nil && port > 0 { + opts = append(opts, WithAttributes( + semconv.ServerPort(port), + )) + } + } return opts } diff --git a/extra/redisotel/unified_filtering_test.go b/extra/redisotel/unified_filtering_test.go new file mode 100644 index 0000000000..83f35d2710 --- /dev/null +++ b/extra/redisotel/unified_filtering_test.go @@ -0,0 +1,423 @@ +package redisotel + +import ( + "context" + "fmt" + "net" + "strings" + "testing" + + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" + + "github.com/redis/go-redis/v9" +) + +// TestWithExcludedCommands tests the new O(1) command exclusion feature. +func TestWithExcludedCommands(t *testing.T) { + t.Run("exclude multiple commands with O(1) lookup", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithExcludedCommands("PING", "INFO", "SELECT"), + ) + ctx, span := provider.Tracer("redis-test").Start(context.Background(), "redis-test") + defer span.End() + + tests := []struct { + name string + cmdName string + shouldFilter bool + }{ + // Commands that should be excluded + {"ping lowercase should be excluded", "ping", true}, + {"PING uppercase should be excluded", "PING", true}, + {"info lowercase should be excluded", "info", true}, + {"INFO uppercase should be excluded", "INFO", true}, + {"select lowercase should be excluded", "select", true}, + {"SELECT uppercase should be excluded", "SELECT", true}, + // Commands that should not be excluded + {"get command should not be excluded", "get", false}, + {"set command should not be excluded", "set", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := redis.NewCmd(ctx, tt.cmdName) + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if tt.shouldFilter { + // Filtered commands should use parent span, not create new span + if innerSpan.Name() != "redis-test" { + t.Errorf("command %q should be filtered out, got span name %q", tt.cmdName, innerSpan.Name()) + } + } else { + // Should create new span with command name + if innerSpan.Name() != tt.cmdName { + t.Fatalf("%s command should not be filtered", tt.cmdName) + } + } + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) + } + }) + + t.Run("multiple calls to WithExcludedCommands should accumulate", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithExcludedCommands("PING", "INFO"), + WithExcludedCommands("SELECT", "CONFIG"), + ) + ctx, span := provider.Tracer("redis-test").Start(context.Background(), "redis-test") + defer span.End() + + tests := []struct { + cmdName string + shouldFilter bool + }{ + {"PING", true}, // From first call + {"INFO", true}, // From first call + {"SELECT", true}, // From second call + {"CONFIG", true}, // From second call + {"GET", false}, // Not excluded + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("command_%s", tt.cmdName), func(t *testing.T) { + cmd := redis.NewCmd(ctx, tt.cmdName) + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if tt.shouldFilter { + // Filtered commands should use parent span + if innerSpan.Name() != "redis-test" { + t.Errorf("command %q should be filtered out, got span name %q", tt.cmdName, innerSpan.Name()) + } + } else { + // Should create new span with command name (normalized to lowercase) + expectedSpanName := strings.ToLower(tt.cmdName) + if innerSpan.Name() != expectedSpanName { + t.Errorf("command %q should not be filtered, got span name %q, expected %q", tt.cmdName, innerSpan.Name(), expectedSpanName) + } + } + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} + +// TestUnifiedFilteringAPI tests the new unified filtering API. +func TestUnifiedFilteringAPI(t *testing.T) { + t.Run("process filter with custom logic", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithProcessFilter(func(cmd redis.Cmder) bool { + // Filter out commands that start with "dangerous" + return len(cmd.Name()) > 9 && cmd.Name()[:9] == "dangerous" + }), + ) + ctx, span := provider.Tracer("redis-test").Start(context.Background(), "redis-test") + defer span.End() + + tests := []struct { + name string + cmdName string + shouldFilter bool + }{ + {"dangerous command should be filtered", "dangerous_cmd", true}, + {"safe command should not be filtered", "safe_cmd", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := redis.NewCmd(ctx, tt.cmdName) + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if tt.shouldFilter { + if innerSpan.Name() != "redis-test" { + t.Fatalf("%s command should be filtered out", tt.cmdName) + } + } else { + if innerSpan.Name() != tt.cmdName { + t.Fatalf("%s command should not be filtered", tt.cmdName) + } + } + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) + } + }) + + t.Run("process pipeline filter with custom logic", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithProcessPipelineFilter(func(cmds []redis.Cmder) bool { + // Filter pipelines that contain more than 5 commands + return len(cmds) > 5 + }), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + defer span.End() + + tests := []struct { + name string + cmdCount int + shouldFilter bool + }{ + {"small pipeline should not be filtered", 3, false}, + {"large pipeline should be filtered", 7, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmds := make([]redis.Cmder, tt.cmdCount) + for i := 0; i < tt.cmdCount; i++ { + cmds[i] = redis.NewCmd(ctx, "ping") + } + + processPipelineHook := hook.ProcessPipelineHook(func(ctx context.Context, cmds []redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if tt.shouldFilter { + if innerSpan.Name() != "redis-test" { + t.Fatalf("large pipeline should be filtered out") + } + } else { + if innerSpan.Name() != "redis.pipeline ping" { + t.Fatalf("small pipeline should not be filtered") + } + } + return nil + }) + err := processPipelineHook(ctx, cmds) + if err != nil { + t.Fatal(err) + } + }) + } + }) + + t.Run("dial filter with custom logic", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithDialFilterFunc(func(network, addr string) bool { + // Filter connections to localhost + return addr == "localhost:6379" + }), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + defer span.End() + + tests := []struct { + name string + addr string + shouldFilter bool + }{ + {"localhost should be filtered", "localhost:6379", true}, + {"remote should not be filtered", "remote:6379", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dialHook := hook.DialHook(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if tt.shouldFilter { + if innerSpan.Name() != "redis-test" { + t.Fatalf("localhost dial should be filtered out") + } + } else { + if innerSpan.Name() != "redis.dial" { + t.Fatalf("remote dial should not be filtered") + } + } + return nil, nil + }) + _, err := dialHook(ctx, "tcp", tt.addr) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} + +// TestCombinedApproach tests that ExcludedCommands fast path works with custom filters +func TestCombinedApproach(t *testing.T) { + t.Run("excluded commands take precedence over custom filters", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithExcludedCommands("PING"), + WithProcessFilter(func(cmd redis.Cmder) bool { + // This filter would normally allow PING, but excluded commands take precedence + return false // never filter + }), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + cmd := redis.NewCmd(ctx, "ping") + defer span.End() + + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "redis-test" { + t.Fatalf("PING should be excluded despite ProcessFilter saying otherwise") + } + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("custom filter applied when command not in exclusion set", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithExcludedCommands("PING"), + WithProcessFilter(func(cmd redis.Cmder) bool { + // Filter get commands + return cmd.Name() == "get" + }), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + defer span.End() + + tests := []struct { + name string + cmdName string + shouldFilter bool + }{ + {"ping excluded by set", "ping", true}, + {"get filtered by custom filter", "get", true}, + {"set not filtered", "set", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := redis.NewCmd(ctx, tt.cmdName) + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if tt.shouldFilter { + if innerSpan.Name() != "redis-test" { + t.Fatalf("%s command should be filtered out", tt.cmdName) + } + } else { + if innerSpan.Name() != tt.cmdName { + t.Fatalf("%s command should not be filtered", tt.cmdName) + } + } + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} + +// TestBackwardCompatibility ensures existing APIs still work +func TestBackwardCompatibility(t *testing.T) { + t.Run("legacy WithCommandFilter still works", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithCommandFilter(func(cmd redis.Cmder) bool { + return cmd.Name() == "ping" + }), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + cmd := redis.NewCmd(ctx, "ping") + defer span.End() + + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "redis-test" { + t.Fatalf("ping should be filtered by legacy filter") + } + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("legacy WithCommandsFilter still works", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithCommandsFilter(func(cmds []redis.Cmder) bool { + return len(cmds) > 1 + }), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + cmds := []redis.Cmder{ + redis.NewCmd(ctx, "ping"), + redis.NewCmd(ctx, "ping"), + } + defer span.End() + + processPipelineHook := hook.ProcessPipelineHook(func(ctx context.Context, cmds []redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "redis-test" { + t.Fatalf("multi-command pipeline should be filtered") + } + return nil + }) + err := processPipelineHook(ctx, cmds) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("legacy WithDialFilter still works", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithDialFilter(true), // enable filtering + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + defer span.End() + + dialHook := hook.DialHook(func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "redis-test" { + t.Fatalf("dial should be filtered when filterDial is true") + } + return nil, nil + }) + _, err := dialHook(ctx, "tcp", "localhost:6379") + if err != nil { + t.Fatal(err) + } + }) +}