diff --git a/hset_benchmark_test.go b/hset_benchmark_test.go index 8d141f4193..649c935241 100644 --- a/hset_benchmark_test.go +++ b/hset_benchmark_test.go @@ -3,6 +3,7 @@ package redis_test import ( "context" "fmt" + "sync" "testing" "time" @@ -100,7 +101,82 @@ func benchmarkHSETOperations(b *testing.B, rdb *redis.Client, ctx context.Contex avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N) b.ReportMetric(float64(avgTimePerOp), "ns/op") // report average time in milliseconds from totalTimes - avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes)) + sumTime := time.Duration(0) + for _, t := range totalTimes { + sumTime += t + } + avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes)) + b.ReportMetric(float64(avgTimePerOpMs), "ms") +} + +// benchmarkHSETOperationsConcurrent performs the actual HSET benchmark for a given scale +func benchmarkHSETOperationsConcurrent(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) { + hashKey := fmt.Sprintf("benchmark_hash_%d", operations) + + b.ResetTimer() + b.StartTimer() + totalTimes := []time.Duration{} + + for i := 0; i < b.N; i++ { + b.StopTimer() + // Clean up the hash before each iteration + rdb.Del(ctx, hashKey) + b.StartTimer() + + startTime := time.Now() + // Perform the specified number of HSET operations + + wg := sync.WaitGroup{} + timesCh := make(chan time.Duration, operations) + errCh := make(chan error, operations) + + for j := 0; j < operations; j++ { + wg.Add(1) + go func(j int) { + defer wg.Done() + field := fmt.Sprintf("field_%d", j) + value := fmt.Sprintf("value_%d", j) + + err := rdb.HSet(ctx, hashKey, field, value).Err() + if err != nil { + errCh <- err + return + } + timesCh <- time.Since(startTime) + }(j) + } + + wg.Wait() + close(timesCh) + close(errCh) + + // Check for errors + for err := range errCh { + b.Errorf("HSET operation failed: %v", err) + } + + for d := range timesCh { + totalTimes = append(totalTimes, d) + } + } + + // Stop the timer to calculate metrics + b.StopTimer() + + // Report operations per second + opsPerSec := float64(operations*b.N) / b.Elapsed().Seconds() + b.ReportMetric(opsPerSec, "ops/sec") + + // Report average time per operation + avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N) + b.ReportMetric(float64(avgTimePerOp), "ns/op") + // report average time in milliseconds from totalTimes + + sumTime := time.Duration(0) + for _, t := range totalTimes { + sumTime += t + } + avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes)) b.ReportMetric(float64(avgTimePerOpMs), "ms") } @@ -134,6 +210,37 @@ func BenchmarkHSETPipelined(b *testing.B) { } } +func BenchmarkHSET_Concurrent(b *testing.B) { + ctx := context.Background() + + // Setup Redis client + rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + DB: 0, + PoolSize: 100, + }) + defer rdb.Close() + + // Test connection + if err := rdb.Ping(ctx).Err(); err != nil { + b.Skipf("Redis server not available: %v", err) + } + + // Clean up before and after tests + defer func() { + rdb.FlushDB(ctx) + }() + + // Reduced scales to avoid overwhelming the system with too many concurrent goroutines + scales := []int{1, 10, 100, 1000} + + for _, scale := range scales { + b.Run(fmt.Sprintf("HSET_%d_operations_concurrent", scale), func(b *testing.B) { + benchmarkHSETOperationsConcurrent(b, rdb, ctx, scale) + }) + } +} + // benchmarkHSETPipelined performs HSET benchmark using pipelining func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context, operations int) { hashKey := fmt.Sprintf("benchmark_hash_pipelined_%d", operations) @@ -177,7 +284,11 @@ func benchmarkHSETPipelined(b *testing.B, rdb *redis.Client, ctx context.Context avgTimePerOp := b.Elapsed().Nanoseconds() / int64(operations*b.N) b.ReportMetric(float64(avgTimePerOp), "ns/op") // report average time in milliseconds from totalTimes - avgTimePerOpMs := totalTimes[0].Milliseconds() / int64(len(totalTimes)) + sumTime := time.Duration(0) + for _, t := range totalTimes { + sumTime += t + } + avgTimePerOpMs := sumTime.Milliseconds() / int64(len(totalTimes)) b.ReportMetric(float64(avgTimePerOpMs), "ms") } diff --git a/internal/auth/streaming/pool_hook.go b/internal/auth/streaming/pool_hook.go index a5647be0a3..f37fe557c0 100644 --- a/internal/auth/streaming/pool_hook.go +++ b/internal/auth/streaming/pool_hook.go @@ -34,9 +34,10 @@ type ReAuthPoolHook struct { shouldReAuth map[uint64]func(error) shouldReAuthLock sync.RWMutex - // workers is a semaphore channel limiting concurrent re-auth operations + // workers is a semaphore limiting concurrent re-auth operations // Initialized with poolSize tokens to prevent pool exhaustion - workers chan struct{} + // Uses FastSemaphore for consistency and better performance + workers *internal.FastSemaphore // reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth reAuthTimeout time.Duration @@ -59,16 +60,10 @@ type ReAuthPoolHook struct { // The poolSize parameter is used to initialize the worker semaphore, ensuring that // re-auth operations don't exhaust the connection pool. func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook { - workers := make(chan struct{}, poolSize) - // Initialize the workers channel with tokens (semaphore pattern) - for i := 0; i < poolSize; i++ { - workers <- struct{}{} - } - return &ReAuthPoolHook{ shouldReAuth: make(map[uint64]func(error)), scheduledReAuth: make(map[uint64]bool), - workers: workers, + workers: internal.NewFastSemaphore(int32(poolSize)), reAuthTimeout: reAuthTimeout, } } @@ -162,10 +157,10 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, r.scheduledLock.Unlock() r.shouldReAuthLock.Unlock() go func() { - <-r.workers + r.workers.AcquireBlocking() // safety first if conn == nil || (conn != nil && conn.IsClosed()) { - r.workers <- struct{}{} + r.workers.Release() return } defer func() { @@ -176,7 +171,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, r.scheduledLock.Lock() delete(r.scheduledReAuth, connID) r.scheduledLock.Unlock() - r.workers <- struct{}{} + r.workers.Release() }() // Create timeout context for connection acquisition diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 4d38184a0b..56be70985f 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -1,3 +1,4 @@ +// Package pool implements the pool management package pool import ( @@ -17,6 +18,35 @@ import ( var noDeadline = time.Time{} +// Global time cache updated every 50ms by background goroutine. +// This avoids expensive time.Now() syscalls in hot paths like getEffectiveReadTimeout. +// Max staleness: 50ms, which is acceptable for timeout deadline checks (timeouts are typically 3-30 seconds). +var globalTimeCache struct { + nowNs atomic.Int64 +} + +func init() { + // Initialize immediately + globalTimeCache.nowNs.Store(time.Now().UnixNano()) + + // Start background updater + go func() { + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for range ticker.C { + globalTimeCache.nowNs.Store(time.Now().UnixNano()) + } + }() +} + +// getCachedTimeNs returns the current time in nanoseconds from the global cache. +// This is updated every 50ms by a background goroutine, avoiding expensive syscalls. +// Max staleness: 50ms. +func getCachedTimeNs() int64 { + return globalTimeCache.nowNs.Load() +} + // Global atomic counter for connection IDs var connIDCounter uint64 @@ -79,6 +109,7 @@ type Conn struct { expiresAt time.Time // maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers + // Using atomic operations for lock-free access to avoid mutex contention relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds @@ -260,11 +291,13 @@ func (cn *Conn) CompareAndSwapUsed(old, new bool) bool { if !old && new { // Acquiring: IDLE → IN_USE - _, err := cn.stateMachine.TryTransition([]ConnState{StateIdle}, StateInUse) + // Use predefined slice to avoid allocation + _, err := cn.stateMachine.TryTransition(validFromCreatedOrIdle, StateInUse) return err == nil } else { // Releasing: IN_USE → IDLE - _, err := cn.stateMachine.TryTransition([]ConnState{StateInUse}, StateIdle) + // Use predefined slice to avoid allocation + _, err := cn.stateMachine.TryTransition(validFromInUse, StateIdle) return err == nil } } @@ -454,7 +487,8 @@ func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Durati return time.Duration(readTimeoutNs) } - nowNs := time.Now().UnixNano() + // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) + nowNs := getCachedTimeNs() // Check if deadline has passed if nowNs < deadlineNs { // Deadline is in the future, use relaxed timeout @@ -487,7 +521,8 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat return time.Duration(writeTimeoutNs) } - nowNs := time.Now().UnixNano() + // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) + nowNs := getCachedTimeNs() // Check if deadline has passed if nowNs < deadlineNs { // Deadline is in the future, use relaxed timeout @@ -632,7 +667,8 @@ func (cn *Conn) MarkQueuedForHandoff() error { // The connection is typically in IN_USE state when OnPut is called (normal Put flow) // But in some edge cases or tests, it might be in IDLE or CREATED state // The pool will detect this state change and preserve it (not overwrite with IDLE) - finalState, err := cn.stateMachine.TryTransition([]ConnState{StateInUse, StateIdle, StateCreated}, StateUnusable) + // Use predefined slice to avoid allocation + finalState, err := cn.stateMachine.TryTransition(validFromCreatedInUseOrIdle, StateUnusable) if err != nil { // Check if already in UNUSABLE state (race condition or retry) // ShouldHandoff should be false now, but check just in case @@ -658,6 +694,42 @@ func (cn *Conn) GetStateMachine() *ConnStateMachine { return cn.stateMachine } +// TryAcquire attempts to acquire the connection for use. +// This is an optimized inline method for the hot path (Get operation). +// +// It tries to transition from IDLE -> IN_USE or CREATED -> IN_USE. +// Returns true if the connection was successfully acquired, false otherwise. +// +// Performance: This is faster than calling GetStateMachine() + TryTransitionFast() +// +// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's +// methods. This breaks encapsulation but is necessary for performance. +// The IDLE->IN_USE and CREATED->IN_USE transitions don't need +// waiter notification, and benchmarks show 1-3% improvement. If the state machine ever +// needs to notify waiters on these transitions, update this to use TryTransitionFast(). +func (cn *Conn) TryAcquire() bool { + // The || operator short-circuits, so only 1 CAS in the common case + return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) || + cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateInUse)) +} + +// Release releases the connection back to the pool. +// This is an optimized inline method for the hot path (Put operation). +// +// It tries to transition from IN_USE -> IDLE. +// Returns true if the connection was successfully released, false otherwise. +// +// Performance: This is faster than calling GetStateMachine() + TryTransitionFast(). +// +// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's +// methods. This breaks encapsulation but is necessary for performance. +// If the state machine ever needs to notify waiters +// on this transition, update this to use TryTransitionFast(). +func (cn *Conn) Release() bool { + // Inline the hot path - single CAS operation + return cn.stateMachine.state.CompareAndSwap(uint32(StateInUse), uint32(StateIdle)) +} + // ClearHandoffState clears the handoff state after successful handoff. // Makes the connection usable again. func (cn *Conn) ClearHandoffState() { @@ -800,8 +872,12 @@ func (cn *Conn) MaybeHasData() bool { return false } +// deadline computes the effective deadline time based on context and timeout. +// It updates the usedAt timestamp to now. +// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation). func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { - tm := time.Now() + // Use cached time for deadline calculation (called 2x per command: read + write) + tm := time.Unix(0, getCachedTimeNs()) cn.SetUsedAt(tm) if timeout > 0 { diff --git a/internal/pool/conn_state.go b/internal/pool/conn_state.go index 93147d1766..32fc505835 100644 --- a/internal/pool/conn_state.go +++ b/internal/pool/conn_state.go @@ -41,6 +41,13 @@ const ( StateClosed ) +// Predefined state slices to avoid allocations in hot paths +var ( + validFromInUse = []ConnState{StateInUse} + validFromCreatedOrIdle = []ConnState{StateCreated, StateIdle} + validFromCreatedInUseOrIdle = []ConnState{StateCreated, StateInUse, StateIdle} +) + // String returns a human-readable string representation of the state. func (s ConnState) String() string { switch s { @@ -92,8 +99,9 @@ type ConnStateMachine struct { state atomic.Uint32 // FIFO queue for waiters - only locked during waiter add/remove/notify - mu sync.Mutex - waiters *list.List // List of *waiter + mu sync.Mutex + waiters *list.List // List of *waiter + waiterCount atomic.Int32 // Fast lock-free check for waiters (avoids mutex in hot path) } // NewConnStateMachine creates a new connection state machine. @@ -114,6 +122,23 @@ func (sm *ConnStateMachine) GetState() ConnState { return ConnState(sm.state.Load()) } +// TryTransitionFast is an optimized version for the hot path (Get/Put operations). +// It only handles simple state transitions without waiter notification. +// This is safe because: +// 1. Get/Put don't need to wait for state changes +// 2. Background operations (handoff/reauth) use UNUSABLE state, which this won't match +// 3. If a background operation is in progress (state is UNUSABLE), this fails fast +// +// Returns true if transition succeeded, false otherwise. +// Use this for performance-critical paths where you don't need error details. +// +// Performance: Single CAS operation - as fast as the old atomic bool! +// For multiple from states, use: sm.TryTransitionFast(State1, Target) || sm.TryTransitionFast(State2, Target) +// The || operator short-circuits, so only 1 CAS is executed in the common case. +func (sm *ConnStateMachine) TryTransitionFast(fromState, targetState ConnState) bool { + return sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) +} + // TryTransition attempts an immediate state transition without waiting. // Returns the current state after the transition attempt and an error if the transition failed. // The returned state is the CURRENT state (after the attempt), not the previous state. @@ -126,17 +151,15 @@ func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetSta // Try each valid from state with CAS // This ensures only ONE goroutine can successfully transition at a time for _, fromState := range validFromStates { - // Fast path: check if we're already in target state - if fromState == targetState && sm.GetState() == targetState { - return targetState, nil - } - // Try to atomically swap from fromState to targetState // If successful, we won the race and can proceed if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) { // Success! We transitioned atomically - // Notify any waiters - sm.notifyWaiters() + // Hot path optimization: only check for waiters if transition succeeded + // This avoids atomic load on every Get/Put when no waiters exist + if sm.waiterCount.Load() > 0 { + sm.notifyWaiters() + } return targetState, nil } } @@ -213,6 +236,7 @@ func (sm *ConnStateMachine) AwaitAndTransition( // Add to FIFO queue sm.mu.Lock() elem := sm.waiters.PushBack(w) + sm.waiterCount.Add(1) sm.mu.Unlock() // Wait for state change or timeout @@ -221,10 +245,13 @@ func (sm *ConnStateMachine) AwaitAndTransition( // Timeout or cancellation - remove from queue sm.mu.Lock() sm.waiters.Remove(elem) + sm.waiterCount.Add(-1) sm.mu.Unlock() return sm.GetState(), ctx.Err() case err := <-w.done: // Transition completed (or failed) + // Note: waiterCount is decremented either in notifyWaiters (when the waiter is notified and removed) + // or here (on timeout/cancellation). return sm.GetState(), err } } @@ -232,9 +259,16 @@ func (sm *ConnStateMachine) AwaitAndTransition( // notifyWaiters checks if any waiters can proceed and notifies them in FIFO order. // This is called after every state transition. func (sm *ConnStateMachine) notifyWaiters() { + // Fast path: check atomic counter without acquiring lock + // This eliminates mutex overhead in the common case (no waiters) + if sm.waiterCount.Load() == 0 { + return + } + sm.mu.Lock() defer sm.mu.Unlock() + // Double-check after acquiring lock (waiters might have been processed) if sm.waiters.Len() == 0 { return } @@ -255,6 +289,7 @@ func (sm *ConnStateMachine) notifyWaiters() { if _, valid := w.validStates[currentState]; valid { // Remove from queue first sm.waiters.Remove(elem) + sm.waiterCount.Add(-1) // Use CAS to ensure state hasn't changed since we checked // This prevents race condition where another thread changes state @@ -267,6 +302,7 @@ func (sm *ConnStateMachine) notifyWaiters() { } else { // State changed - re-add waiter to front of queue and retry sm.waiters.PushFront(w) + sm.waiterCount.Add(1) // Continue to next iteration to re-read state processed = true break diff --git a/internal/pool/export_test.go b/internal/pool/export_test.go index 20456b8100..2d17803854 100644 --- a/internal/pool/export_test.go +++ b/internal/pool/export_test.go @@ -20,5 +20,5 @@ func (p *ConnPool) CheckMinIdleConns() { } func (p *ConnPool) QueueLen() int { - return len(p.queue) + return int(p.semaphore.Len()) } diff --git a/internal/pool/hooks.go b/internal/pool/hooks.go index bfbd9e14e0..1c365dbabe 100644 --- a/internal/pool/hooks.go +++ b/internal/pool/hooks.go @@ -140,3 +140,16 @@ func (phm *PoolHookManager) GetHooks() []PoolHook { copy(hooks, phm.hooks) return hooks } + +// Clone creates a copy of the hook manager with the same hooks. +// This is used for lock-free atomic updates of the hook manager. +func (phm *PoolHookManager) Clone() *PoolHookManager { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + + newManager := &PoolHookManager{ + hooks: make([]PoolHook, len(phm.hooks)), + } + copy(newManager.hooks, phm.hooks) + return newManager +} diff --git a/internal/pool/hooks_test.go b/internal/pool/hooks_test.go index ec1d6da351..b8f504dfd2 100644 --- a/internal/pool/hooks_test.go +++ b/internal/pool/hooks_test.go @@ -202,26 +202,29 @@ func TestPoolWithHooks(t *testing.T) { pool.AddPoolHook(testHook) // Verify hooks are initialized - if pool.hookManager == nil { + manager := pool.hookManager.Load() + if manager == nil { t.Error("Expected hookManager to be initialized") } - if pool.hookManager.GetHookCount() != 1 { - t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount()) + if manager.GetHookCount() != 1 { + t.Errorf("Expected 1 hook in pool, got %d", manager.GetHookCount()) } // Test adding hook to pool additionalHook := &TestHook{ShouldPool: true, ShouldAccept: true} pool.AddPoolHook(additionalHook) - if pool.hookManager.GetHookCount() != 2 { - t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount()) + manager = pool.hookManager.Load() + if manager.GetHookCount() != 2 { + t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount()) } // Test removing hook from pool pool.RemovePoolHook(additionalHook) - if pool.hookManager.GetHookCount() != 1 { - t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount()) + manager = pool.hookManager.Load() + if manager.GetHookCount() != 1 { + t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount()) } } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 59b8e19434..2dedca0591 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -27,6 +27,12 @@ var ( // ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable. ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable") + // errHookRequestedRemoval is returned when a hook requests connection removal. + errHookRequestedRemoval = errors.New("hook requested removal") + + // errConnNotPooled is returned when trying to return a non-pooled connection to the pool. + errConnNotPooled = errors.New("connection not pooled") + // popAttempts is the maximum number of attempts to find a usable connection // when popping from the idle connection pool. This handles cases where connections // are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues). @@ -45,14 +51,6 @@ var ( noExpiration = maxTime ) -var timers = sync.Pool{ - New: func() interface{} { - t := time.NewTimer(time.Hour) - t.Stop() - return t - }, -} - // Stats contains pool state information and accumulated stats. type Stats struct { Hits uint32 // number of times free connection was found in the pool @@ -132,7 +130,9 @@ type ConnPool struct { dialErrorsNum uint32 // atomic lastDialError atomic.Value - queue chan struct{} + // Fast atomic semaphore for connection limiting + // Replaces the old channel-based queue for better performance + semaphore *internal.FastSemaphore connsMu sync.Mutex conns map[uint64]*Conn @@ -148,8 +148,8 @@ type ConnPool struct { _closed uint32 // atomic // Pool hooks manager for flexible connection processing - hookManagerMu sync.RWMutex - hookManager *PoolHookManager + // Using atomic.Pointer for lock-free reads in hot paths (Get/Put) + hookManager atomic.Pointer[PoolHookManager] } var _ Pooler = (*ConnPool)(nil) @@ -158,7 +158,7 @@ func NewConnPool(opt *Options) *ConnPool { p := &ConnPool{ cfg: opt, - queue: make(chan struct{}, opt.PoolSize), + semaphore: internal.NewFastSemaphore(opt.PoolSize), conns: make(map[uint64]*Conn), idleConns: make([]*Conn, 0, opt.PoolSize), } @@ -176,27 +176,37 @@ func NewConnPool(opt *Options) *ConnPool { // initializeHooks sets up the pool hooks system. func (p *ConnPool) initializeHooks() { - p.hookManager = NewPoolHookManager() + manager := NewPoolHookManager() + p.hookManager.Store(manager) } // AddPoolHook adds a pool hook to the pool. func (p *ConnPool) AddPoolHook(hook PoolHook) { - p.hookManagerMu.Lock() - defer p.hookManagerMu.Unlock() - - if p.hookManager == nil { + // Lock-free read of current manager + manager := p.hookManager.Load() + if manager == nil { p.initializeHooks() + manager = p.hookManager.Load() } - p.hookManager.AddHook(hook) + + // Create new manager with added hook + newManager := manager.Clone() + newManager.AddHook(hook) + + // Atomically swap to new manager + p.hookManager.Store(newManager) } // RemovePoolHook removes a pool hook from the pool. func (p *ConnPool) RemovePoolHook(hook PoolHook) { - p.hookManagerMu.Lock() - defer p.hookManagerMu.Unlock() + manager := p.hookManager.Load() + if manager != nil { + // Create new manager with removed hook + newManager := manager.Clone() + newManager.RemoveHook(hook) - if p.hookManager != nil { - p.hookManager.RemoveHook(hook) + // Atomically swap to new manager + p.hookManager.Store(newManager) } } @@ -213,31 +223,32 @@ func (p *ConnPool) checkMinIdleConns() { // Only create idle connections if we haven't reached the total pool size limit // MinIdleConns should be a subset of PoolSize, not additional connections for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns { - select { - case p.queue <- struct{}{}: - p.poolSize.Add(1) - p.idleConnsLen.Add(1) - go func() { - defer func() { - if err := recover(); err != nil { - p.poolSize.Add(-1) - p.idleConnsLen.Add(-1) - - p.freeTurn() - internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err) - } - }() - - err := p.addIdleConn() - if err != nil && err != ErrClosed { + // Try to acquire a semaphore token + if !p.semaphore.TryAcquire() { + // Semaphore is full, can't create more connections + return + } + + p.poolSize.Add(1) + p.idleConnsLen.Add(1) + go func() { + defer func() { + if err := recover(); err != nil { p.poolSize.Add(-1) p.idleConnsLen.Add(-1) + + p.freeTurn() + internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err) } - p.freeTurn() }() - default: - return - } + + err := p.addIdleConn() + if err != nil && err != ErrClosed { + p.poolSize.Add(-1) + p.idleConnsLen.Add(-1) + } + p.freeTurn() + }() } } @@ -281,7 +292,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { return nil, ErrClosed } - if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) { + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= p.cfg.MaxActiveConns { return nil, ErrPoolExhausted } @@ -296,7 +307,7 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { // when first used. Do NOT transition to IDLE here - that happens after initialization completes. // The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success) - if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) { + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > p.cfg.MaxActiveConns { _ = cn.Close() return nil, ErrPoolExhausted } @@ -441,14 +452,12 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { return nil, err } - now := time.Now() + // Use cached time for health checks (max 50ms staleness is acceptable) + now := time.Unix(0, getCachedTimeNs()) attempts := 0 - // Get hooks manager once for this getConn call for performance. - // Note: Hooks added/removed during this call won't be reflected. - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() + // Lock-free atomic read - no mutex overhead! + hookManager := p.hookManager.Load() for { if attempts >= getAttempts { @@ -476,19 +485,20 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } // Process connection using the hooks system + // Combine error and rejection checks to reduce branches if hookManager != nil { acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false) - if err != nil { - internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) - _ = p.CloseConn(cn) - continue - } - if !acceptConn { - internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) - // Return connection to pool without freeing the turn that this Get() call holds. - // We use putConnWithoutTurn() to run all the Put hooks and logic without freeing a turn. - p.putConnWithoutTurn(ctx, cn) - cn = nil + if err != nil || !acceptConn { + if err != nil { + internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) + _ = p.CloseConn(cn) + } else { + internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) + // Return connection to pool without freeing the turn that this Get() call holds. + // We use putConnWithoutTurn() to run all the Put hooks and logic without freeing a turn. + p.putConnWithoutTurn(ctx, cn) + cn = nil + } continue } } @@ -521,44 +531,36 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } func (p *ConnPool) waitTurn(ctx context.Context) error { + // Fast path: check context first select { case <-ctx.Done(): return ctx.Err() default: } - select { - case p.queue <- struct{}{}: + // Fast path: try to acquire without blocking + if p.semaphore.TryAcquire() { return nil - default: } + // Slow path: need to wait start := time.Now() - timer := timers.Get().(*time.Timer) - defer timers.Put(timer) - timer.Reset(p.cfg.PoolTimeout) + err := p.semaphore.Acquire(ctx, p.cfg.PoolTimeout, ErrPoolTimeout) - select { - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - return ctx.Err() - case p.queue <- struct{}{}: + switch err { + case nil: + // Successfully acquired after waiting p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano()) atomic.AddUint32(&p.stats.WaitCount, 1) - if !timer.Stop() { - <-timer.C - } - return nil - case <-timer.C: + case ErrPoolTimeout: atomic.AddUint32(&p.stats.Timeouts, 1) - return ErrPoolTimeout } + + return err } func (p *ConnPool) freeTurn() { - <-p.queue + p.semaphore.Release() } func (p *ConnPool) popIdle() (*Conn, error) { @@ -592,15 +594,16 @@ func (p *ConnPool) popIdle() (*Conn, error) { } attempts++ - // Try to atomically transition to IN_USE using state machine - // Accept both CREATED (uninitialized) and IDLE (initialized) states - _, err := cn.GetStateMachine().TryTransition([]ConnState{StateCreated, StateIdle}, StateInUse) - if err == nil { + // Hot path optimization: try IDLE → IN_USE or CREATED → IN_USE transition + // Using inline TryAcquire() method for better performance (avoids pointer dereference) + if cn.TryAcquire() { // Successfully acquired the connection p.idleConnsLen.Add(-1) break } + // Connection is in UNUSABLE, INITIALIZING, or other state - skip it + // Connection is not in a valid state (might be UNUSABLE for handoff/re-auth, INITIALIZING, etc.) // Put it back in the pool and try the next one if p.cfg.PoolFIFO { @@ -651,9 +654,8 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { // It's a push notification, allow pooling (client will handle it) } - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() + // Lock-free atomic read - no mutex overhead! + hookManager := p.hookManager.Load() if hookManager != nil { shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn) @@ -664,41 +666,35 @@ func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { } } - // If hooks say to remove the connection, do so - if shouldRemove { - p.removeConnInternal(ctx, cn, errors.New("hook requested removal"), freeTurn) - return - } - - // If processor says not to pool the connection, remove it - if !shouldPool { - p.removeConnInternal(ctx, cn, errors.New("hook requested no pooling"), freeTurn) + // Combine all removal checks into one - reduces branches + if shouldRemove || !shouldPool { + p.removeConnInternal(ctx, cn, errHookRequestedRemoval, freeTurn) return } if !cn.pooled { - p.removeConnInternal(ctx, cn, errors.New("connection not pooled"), freeTurn) + p.removeConnInternal(ctx, cn, errConnNotPooled, freeTurn) return } var shouldCloseConn bool if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns { - // Try to transition to IDLE state BEFORE adding to pool - // Only transition if connection is still IN_USE (hooks might have changed state) - // This prevents: - // 1. Race condition where another goroutine could acquire a connection that's still in IN_USE state - // 2. Overwriting state changes made by hooks (e.g., IN_USE → UNUSABLE for handoff) - currentState, err := cn.GetStateMachine().TryTransition([]ConnState{StateInUse}, StateIdle) - if err != nil { - // Hook changed the state (e.g., to UNUSABLE for handoff) + // Hot path optimization: try fast IN_USE → IDLE transition + // Using inline Release() method for better performance (avoids pointer dereference) + transitionedToIdle := cn.Release() + + if !transitionedToIdle { + // Fast path failed - hook might have changed state (e.g., to UNUSABLE for handoff) // Keep the state set by the hook and pool the connection anyway + currentState := cn.GetStateMachine().GetState() internal.Logger.Printf(ctx, "Connection state changed by hook to %v, pooling as-is", currentState) } // unusable conns are expected to become usable at some point (background process is reconnecting them) // put them at the opposite end of the queue - if !cn.IsUsable() { + // Optimization: if we just transitioned to IDLE, we know it's usable - skip the check + if !transitionedToIdle && !cn.IsUsable() { if p.cfg.PoolFIFO { p.connsMu.Lock() p.idleConns = append(p.idleConns, cn) @@ -742,9 +738,8 @@ func (p *ConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error // removeConnInternal is the internal implementation of Remove that optionally frees a turn. func (p *ConnPool) removeConnInternal(ctx context.Context, cn *Conn, reason error, freeTurn bool) { - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() + // Lock-free atomic read - no mutex overhead! + hookManager := p.hookManager.Load() if hookManager != nil { hookManager.ProcessOnRemove(ctx, cn, reason) @@ -877,36 +872,53 @@ func (p *ConnPool) Close() error { } func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { - // slight optimization, check expiresAt first. - if cn.expiresAt.Before(now) { - return false + // Performance optimization: check conditions from cheapest to most expensive, + // and from most likely to fail to least likely to fail. + + // Only fails if ConnMaxLifetime is set AND connection is old. + // Most pools don't set ConnMaxLifetime, so this rarely fails. + if p.cfg.ConnMaxLifetime > 0 { + if cn.expiresAt.Before(now) { + return false // Connection has exceeded max lifetime + } } - // Check if connection has exceeded idle timeout - if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { - return false + // Most pools set ConnMaxIdleTime, and idle connections are common. + // Checking this first allows us to fail fast without expensive syscalls. + if p.cfg.ConnMaxIdleTime > 0 { + if now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { + return false // Connection has been idle too long + } } - cn.SetUsedAt(now) - // Check basic connection health - // Use GetNetConn() to safely access netConn and avoid data races + // Only run this if the cheap checks passed. if err := connCheck(cn.getNetConn()); err != nil { // If there's unexpected data, it might be push notifications (RESP3) - // However, push notification processing is now handled by the client - // before WithReader to ensure proper context is available to handlers if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead { - // we know that there is something in the buffer, so peek at the next reply type without - // the potential to block + // Peek at the reply type to check if it's a push notification if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush { // For RESP3 connections with push notifications, we allow some buffered data // The client will process these notifications before using the connection - internal.Logger.Printf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID()) - return true // Connection is healthy, client will handle notifications + internal.Logger.Printf( + context.Background(), + "push: conn[%d] has buffered data, likely push notifications - will be processed by client", + cn.GetID(), + ) + + // Update timestamp for healthy connection + cn.SetUsedAt(now) + + // Connection is healthy, client will handle notifications + return true } - return false // Unexpected data, not push notifications, connection is unhealthy - } else { + // Not a push notification - treat as unhealthy return false } + // Connection failed health check + return false } + + // Only update UsedAt if connection is healthy (avoids unnecessary atomic store) + cn.SetUsedAt(now) return true } diff --git a/internal/pool/pubsub.go b/internal/pool/pubsub.go index ed87d1bbc7..5b29659eac 100644 --- a/internal/pool/pubsub.go +++ b/internal/pool/pubsub.go @@ -24,7 +24,7 @@ type PubSubPool struct { stats PubSubStats } -// PubSubPool implements a pool for PubSub connections. +// NewPubSubPool implements a pool for PubSub connections. // It intentionally does not implement the Pooler interface func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool { return &PubSubPool{ diff --git a/internal/proto/peek_push_notification_test.go b/internal/proto/peek_push_notification_test.go index 58a794b849..491867591d 100644 --- a/internal/proto/peek_push_notification_test.go +++ b/internal/proto/peek_push_notification_test.go @@ -370,9 +370,17 @@ func BenchmarkPeekPushNotificationName(b *testing.B) { buf := createValidPushNotification(tc.notification, "data") data := buf.Bytes() + // Reuse both bytes.Reader and proto.Reader to avoid allocations + bytesReader := bytes.NewReader(data) + reader := NewReader(bytesReader) + b.ResetTimer() + b.ReportAllocs() for i := 0; i < b.N; i++ { - reader := NewReader(bytes.NewReader(data)) + // Reset the bytes.Reader to the beginning without allocating + bytesReader.Reset(data) + // Reset the proto.Reader to reuse the bufio buffer + reader.Reset(bytesReader) _, err := reader.PeekPushNotificationName() if err != nil { b.Errorf("PeekPushNotificationName should not error: %v", err) diff --git a/internal/semaphore.go b/internal/semaphore.go new file mode 100644 index 0000000000..091b663586 --- /dev/null +++ b/internal/semaphore.go @@ -0,0 +1,161 @@ +package internal + +import ( + "context" + "sync" + "sync/atomic" + "time" +) + +var semTimers = sync.Pool{ + New: func() interface{} { + t := time.NewTimer(time.Hour) + t.Stop() + return t + }, +} + +// FastSemaphore is a counting semaphore implementation using atomic operations. +// It's optimized for the fast path (no blocking) while still supporting timeouts and context cancellation. +// +// Performance characteristics: +// - Fast path (no blocking): Single atomic CAS operation +// - Slow path (blocking): Falls back to channel-based waiting +// - Release: Single atomic decrement + optional channel notification +// +// This is significantly faster than a pure channel-based semaphore because: +// 1. The fast path avoids channel operations entirely (no scheduler involvement) +// 2. Atomic operations are much cheaper than channel send/receive +type FastSemaphore struct { + // Current number of acquired tokens (atomic) + count atomic.Int32 + + // Maximum number of tokens (capacity) + max int32 + + // Channel for blocking waiters + // Only used when fast path fails (semaphore is full) + waitCh chan struct{} +} + +// NewFastSemaphore creates a new fast semaphore with the given capacity. +func NewFastSemaphore(capacity int32) *FastSemaphore { + return &FastSemaphore{ + max: capacity, + waitCh: make(chan struct{}, capacity), + } +} + +// TryAcquire attempts to acquire a token without blocking. +// Returns true if successful, false if the semaphore is full. +// +// This is the fast path - just a single CAS operation. +func (s *FastSemaphore) TryAcquire() bool { + for { + current := s.count.Load() + if current >= s.max { + return false // Semaphore is full + } + if s.count.CompareAndSwap(current, current+1) { + return true // Successfully acquired + } + // CAS failed due to concurrent modification, retry + } +} + +// Acquire acquires a token, blocking if necessary until one is available or the context is cancelled. +// Returns an error if the context is cancelled or the timeout expires. +// Returns timeoutErr when the timeout expires. +// +// Performance optimization: +// 1. First try fast path (no blocking) +// 2. If that fails, fall back to channel-based waiting +func (s *FastSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error { + // Fast path: try to acquire without blocking + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // Try fast acquire first + if s.TryAcquire() { + return nil + } + + // Fast path failed, need to wait + // Use timer pool to avoid allocation + timer := semTimers.Get().(*time.Timer) + defer semTimers.Put(timer) + timer.Reset(timeout) + + start := time.Now() + + for { + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return ctx.Err() + + case <-s.waitCh: + // Someone released a token, try to acquire it + if s.TryAcquire() { + if !timer.Stop() { + <-timer.C + } + return nil + } + // Failed to acquire (race with another goroutine), continue waiting + + case <-timer.C: + return timeoutErr + } + + // Periodically check if we can acquire (handles race conditions) + if time.Since(start) > timeout { + return timeoutErr + } + } +} + +// AcquireBlocking acquires a token, blocking indefinitely until one is available. +// This is useful for cases where you don't need timeout or context cancellation. +// Returns immediately if a token is available (fast path). +func (s *FastSemaphore) AcquireBlocking() { + // Try fast path first + if s.TryAcquire() { + return + } + + // Slow path: wait for a token + for { + <-s.waitCh + if s.TryAcquire() { + return + } + // Failed to acquire (race with another goroutine), continue waiting + } +} + +// Release releases a token back to the semaphore. +// This wakes up one waiting goroutine if any are blocked. +func (s *FastSemaphore) Release() { + s.count.Add(-1) + + // Try to wake up a waiter (non-blocking) + // If no one is waiting, this is a no-op + select { + case s.waitCh <- struct{}{}: + // Successfully notified a waiter + default: + // No waiters, that's fine + } +} + +// Len returns the current number of acquired tokens. +// Used by tests to check semaphore state. +func (s *FastSemaphore) Len() int32 { + return s.count.Load() +} diff --git a/redis_test.go b/redis_test.go index 0906d420b1..5cce3f25be 100644 --- a/redis_test.go +++ b/redis_test.go @@ -323,6 +323,7 @@ var _ = Describe("Client", func() { cn, err = client.Pool().Get(context.Background()) Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) + Expect(cn.UsedAt().UnixNano()).To(BeNumerically(">", createdAt.UnixNano())) Expect(cn.UsedAt().After(createdAt)).To(BeTrue()) })