From 213b8c7a487bb815c72489c1964569b23e049881 Mon Sep 17 00:00:00 2001 From: ccoVeille <3875889+ccoVeille@users.noreply.github.com> Date: Fri, 24 Oct 2025 14:59:40 +0200 Subject: [PATCH] feat: add optional logger wherever possible This commit introduces an optional logger parameter to various structs. This enhancement allows users to provide custom logging implementations. --- internal/pool/pool.go | 33 ++++++---- maintnotifications/circuit_breaker.go | 26 ++++++-- maintnotifications/config.go | 5 ++ maintnotifications/handoff_worker.go | 44 +++++++------ maintnotifications/manager.go | 16 +++-- maintnotifications/pool_hook.go | 12 +++- .../push_notification_handler.go | 62 +++++++++++-------- options.go | 5 ++ osscluster.go | 37 ++++++++--- pubsub.go | 43 +++++++++++-- redis.go | 47 +++++++++----- ring.go | 14 ++++- sentinel.go | 43 +++++++++---- 13 files changed, 277 insertions(+), 110 deletions(-) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 4915bf623d..d73cd74875 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -114,6 +114,9 @@ type Options struct { // DialerRetryTimeout is the backoff duration between retry attempts. // Default: 100ms DialerRetryTimeout time.Duration + + // Optional logger for connection pool operations. + Logger *internal.Logging } type lastDialErrorWrap struct { @@ -218,7 +221,7 @@ func (p *ConnPool) checkMinIdleConns() { p.idleConnsLen.Add(-1) p.freeTurn() - internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err) + p.logf(context.Background(), "addIdleConn panic: %+v", err) } }() @@ -373,7 +376,7 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { return cn, nil } - internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr) + p.logf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr) // All retries failed - handle error tracking p.setLastDialError(lastErr) if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) { @@ -446,7 +449,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { for { if attempts >= getAttempts { - internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts) + p.logf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts) break } attempts++ @@ -473,12 +476,12 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { if hookManager != nil { acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false) if err != nil { - internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) + p.logf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) _ = p.CloseConn(cn) continue } if !acceptConn { - internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) + p.logf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) p.Put(ctx, cn) cn = nil continue @@ -504,7 +507,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { // this should not happen with a new connection, but we handle it gracefully if err != nil || !acceptConn { // Failed to process connection, discard it - internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err) + p.logf(ctx, "redis: connection pool: 
failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err) _ = p.CloseConn(newcn) return nil, err } @@ -605,7 +608,7 @@ func (p *ConnPool) popIdle() (*Conn, error) { // If we exhausted all attempts without finding a usable connection, return nil if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() { - internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts) + p.logf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts) return nil, nil } @@ -622,7 +625,7 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { // Peek at the reply type to check if it's a push notification if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush { // Not a push notification or error peeking, remove connection - internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it") + p.logf(ctx, "Conn has unread data (not push notification), removing it") p.Remove(ctx, cn, err) } // It's a push notification, allow pooling (client will handle it) @@ -635,7 +638,7 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { if hookManager != nil { shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn) if err != nil { - internal.Logger.Printf(ctx, "Connection hook error: %v", err) + p.logf(ctx, "Connection hook error: %v", err) p.Remove(ctx, cn, err) return } @@ -737,7 +740,7 @@ func (p *ConnPool) removeConn(cn *Conn) { // this can be idle conn for idx, ic := range p.idleConns { if ic.GetID() == cid { - internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid) + p.logf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid) p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...) p.idleConnsLen.Add(-1) break @@ -853,7 +856,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush { // For RESP3 connections with push notifications, we allow some buffered data // The client will process these notifications before using the connection - internal.Logger.Printf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID()) + p.logf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID()) return true // Connection is healthy, client will handle notifications } return false // Unexpected data, not push notifications, connection is unhealthy @@ -863,3 +866,11 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { } return true } + +func (p *ConnPool) logf(ctx context.Context, format string, args ...any) { + logger := internal.Logger + if p.cfg.Logger != nil { + logger = *p.cfg.Logger + } + logger.Printf(ctx, format, args...) 
+} diff --git a/maintnotifications/circuit_breaker.go b/maintnotifications/circuit_breaker.go index cb76b6447f..77ce24adb4 100644 --- a/maintnotifications/circuit_breaker.go +++ b/maintnotifications/circuit_breaker.go @@ -103,7 +103,7 @@ func (cb *CircuitBreaker) Execute(fn func() error) error { cb.requests.Store(0) cb.successes.Store(0) if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint)) + cb.logf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint)) } // Fall through to half-open logic } else { @@ -145,7 +145,7 @@ func (cb *CircuitBreaker) recordFailure() { if failures >= int64(cb.failureThreshold) { if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) { if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures)) + cb.logf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures)) } } } @@ -153,7 +153,7 @@ func (cb *CircuitBreaker) recordFailure() { // Any failure in half-open state immediately opens the circuit if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) { if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint)) + cb.logf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint)) } } } @@ -177,7 +177,7 @@ func (cb *CircuitBreaker) recordSuccess() { if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) { cb.failures.Store(0) if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes)) + cb.logf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes)) } } } @@ -202,6 +202,14 @@ func (cb *CircuitBreaker) GetStats() CircuitBreakerStats { } } +func (cb *CircuitBreaker) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + if cb.config != nil && cb.config.Logger != nil { + logger = *cb.config.Logger + } + logger.Printf(ctx, format, args...) +} + // CircuitBreakerStats provides statistics about a circuit breaker type CircuitBreakerStats struct { Endpoint string @@ -326,7 +334,7 @@ func (cbm *CircuitBreakerManager) cleanup() { // Log cleanup results if len(toDelete) > 0 && internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count)) + cbm.logf(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count)) } cbm.lastCleanup.Store(now.Unix()) @@ -351,3 +359,11 @@ func (cbm *CircuitBreakerManager) Reset() { return true }) } + +func (cbm *CircuitBreakerManager) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + if cbm.config != nil && cbm.config.Logger != nil { + logger = *cbm.config.Logger + } + logger.Printf(ctx, format, args...) +} diff --git a/maintnotifications/config.go b/maintnotifications/config.go index cbf4f6b22b..b85e39b46a 100644 --- a/maintnotifications/config.go +++ b/maintnotifications/config.go @@ -128,6 +128,9 @@ type Config struct { // After this many retries, the connection will be removed from the pool. // Default: 3 MaxHandoffRetries int + + // Logger is an optional custom logger for maintenance notifications. 
+ Logger *internal.Logging } func (c *Config) IsEnabled() bool { @@ -341,6 +344,8 @@ func (c *Config) Clone() *Config { // Configuration fields MaxHandoffRetries: c.MaxHandoffRetries, + + Logger: c.Logger, } } diff --git a/maintnotifications/handoff_worker.go b/maintnotifications/handoff_worker.go index 22df2c8008..c5234c403d 100644 --- a/maintnotifications/handoff_worker.go +++ b/maintnotifications/handoff_worker.go @@ -121,7 +121,7 @@ func (hwm *handoffWorkerManager) onDemandWorker() { defer func() { // Handle panics to ensure proper cleanup if r := recover(); r != nil { - internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r)) + hwm.logf(context.Background(), logs.WorkerPanicRecovered(r)) } // Decrement active worker count when exiting @@ -146,13 +146,13 @@ func (hwm *handoffWorkerManager) onDemandWorker() { select { case <-hwm.shutdown: if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown()) + hwm.logf(context.Background(), logs.WorkerExitingDueToShutdown()) } return case <-timer.C: // Worker has been idle for too long, exit to save resources if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout)) + hwm.logf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout)) } return case request := <-hwm.handoffQueue: @@ -160,7 +160,7 @@ func (hwm *handoffWorkerManager) onDemandWorker() { select { case <-hwm.shutdown: if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing()) + hwm.logf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing()) } // Clean up the request before exiting hwm.pending.Delete(request.ConnID) @@ -178,7 +178,7 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) { // Remove from pending map defer hwm.pending.Delete(request.Conn.GetID()) if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint)) + hwm.logf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint)) } // Create a context with handoff timeout from config @@ -226,12 +226,12 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) { if hwm.config != nil { maxRetries = hwm.config.MaxHandoffRetries } - internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err)) + hwm.logf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err)) } time.AfterFunc(afterTime, func() { if err := hwm.queueHandoff(request.Conn); err != nil { if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err)) + hwm.logf(context.Background(), logs.CannotQueueHandoffForRetry(err)) } hwm.closeConnFromRequest(context.Background(), request, err) } @@ -259,7 +259,7 @@ func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error { // if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff if !shouldHandoff && conn.HandoffRetries() == 0 { if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID())) + hwm.logf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID())) } return 
errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID())) } @@ -302,7 +302,7 @@ func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error { queueLen := len(hwm.handoffQueue) queueCap := cap(hwm.handoffQueue) if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap)) + hwm.logf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap)) } } } @@ -356,7 +356,7 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c // Check if circuit breaker is open before attempting handoff if circuitBreaker.IsOpen() { - internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint)) + hwm.logf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint)) return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open } @@ -385,7 +385,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( connID uint64, ) (shouldRetry bool, err error) { retries := conn.IncrementAndGetHandoffRetries(1) - internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String())) + hwm.logf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String())) maxRetries := 3 // Default fallback if hwm.config != nil { maxRetries = hwm.config.MaxHandoffRetries @@ -393,7 +393,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( if retries > maxRetries { if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries)) + hwm.logf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries)) } // won't retry on ErrMaxHandoffRetriesReached return false, ErrMaxHandoffRetriesReached @@ -405,7 +405,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( // Create new connection to the new endpoint newNetConn, err := endpointDialer(ctx) if err != nil { - internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err)) + hwm.logf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err)) // will retry // Maybe a network error - retry after a delay return true, err @@ -425,7 +425,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000"))) + hwm.logf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000"))) } } @@ -447,7 +447,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( // - clear the handoff state (shouldHandoff, endpoint, seqID) // - reset the handoff retries to 0 conn.ClearHandoffState() - internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint)) + hwm.logf(ctx, logs.HandoffSucceeded(connID, newEndpoint)) // successfully completed the handoff, no retry needed and no error return false, nil @@ -478,15 +478,23 @@ func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, reque if pooler != nil { pooler.Remove(ctx, conn, err) if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err)) + hwm.logf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err)) } } else { err := conn.Close() // Close the connection if no pool provided if err != nil { - internal.Logger.Printf(ctx, "redis: failed to close connection: %v", 
err) + hwm.logf(ctx, "redis: failed to close connection: %v", err) } if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err)) + hwm.logf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err)) } } } + +func (hwm *handoffWorkerManager) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + if hwm.config != nil && hwm.config.Logger != nil { + logger = *hwm.config.Logger + } + logger.Printf(ctx, format, args...) +} diff --git a/maintnotifications/manager.go b/maintnotifications/manager.go index 775c163e14..5024d0526f 100644 --- a/maintnotifications/manager.go +++ b/maintnotifications/manager.go @@ -151,12 +151,12 @@ func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoi if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded { // Duplicate MOVING notification, ignore if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID)) + hm.logf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID)) } return nil } if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID)) + hm.logf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID)) } // Increment active operation count atomically @@ -176,13 +176,13 @@ func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) { // Remove from active operations atomically if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded { if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID)) + hm.logf(context.Background(), logs.UntrackingMovingOperation(connID, seqID)) } // Decrement active operation count only if operation existed hm.activeOperationCount.Add(-1) } else { if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID)) + hm.logf(context.Background(), logs.OperationNotTracked(connID, seqID)) } } } @@ -318,3 +318,11 @@ func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) { defer hm.hooksMu.Unlock() hm.hooks = append(hm.hooks, notificationHook) } + +func (hm *Manager) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + if hm.config != nil && hm.config.Logger != nil { + logger = *hm.config.Logger + } + logger.Printf(ctx, format, args...) 
+} diff --git a/maintnotifications/pool_hook.go b/maintnotifications/pool_hook.go index 9fd24b4a7b..f6f70cdad1 100644 --- a/maintnotifications/pool_hook.go +++ b/maintnotifications/pool_hook.go @@ -150,7 +150,7 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool if err := ph.workerManager.queueHandoff(conn); err != nil { // Failed to queue handoff, remove the connection - internal.Logger.Printf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err)) + ph.logf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err)) // Don't pool, remove connection, no error to caller return false, true, nil } @@ -170,7 +170,7 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool // Other error - remove the connection return false, true, nil } - internal.Logger.Printf(ctx, logs.MarkedForHandoff(conn.GetID())) + ph.logf(ctx, logs.MarkedForHandoff(conn.GetID())) return true, false, nil } @@ -182,3 +182,11 @@ func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) { func (ph *PoolHook) Shutdown(ctx context.Context) error { return ph.workerManager.shutdownWorkers(ctx) } + +func (ph *PoolHook) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + if ph.config != nil && ph.config.Logger != nil { + logger = *ph.config.Logger + } + logger.Printf(ctx, format, args...) +} diff --git a/maintnotifications/push_notification_handler.go b/maintnotifications/push_notification_handler.go index 937b4ae82e..70f6aa8147 100644 --- a/maintnotifications/push_notification_handler.go +++ b/maintnotifications/push_notification_handler.go @@ -21,13 +21,13 @@ type NotificationHandler struct { // HandlePushNotification processes push notifications with hook support. func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { if len(notification) == 0 { - internal.Logger.Printf(ctx, logs.InvalidNotificationFormat(notification)) + snh.logf(ctx, logs.InvalidNotificationFormat(notification)) return ErrInvalidNotification } notificationType, ok := notification[0].(string) if !ok { - internal.Logger.Printf(ctx, logs.InvalidNotificationTypeFormat(notification[0])) + snh.logf(ctx, logs.InvalidNotificationTypeFormat(notification[0])) return ErrInvalidNotification } @@ -64,19 +64,19 @@ func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, hand // ["MOVING", seqNum, timeS, endpoint] - per-connection handoff func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { if len(notification) < 3 { - internal.Logger.Printf(ctx, logs.InvalidNotification("MOVING", notification)) + snh.logf(ctx, logs.InvalidNotification("MOVING", notification)) return ErrInvalidNotification } seqID, ok := notification[1].(int64) if !ok { - internal.Logger.Printf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1])) + snh.logf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1])) return ErrInvalidNotification } // Extract timeS timeS, ok := notification[2].(int64) if !ok { - internal.Logger.Printf(ctx, logs.InvalidTimeSInMovingNotification(notification[2])) + snh.logf(ctx, logs.InvalidTimeSInMovingNotification(notification[2])) return ErrInvalidNotification } @@ -90,7 +90,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus if notification[3] == nil || stringified == internal.RedisNull { newEndpoint = "" } else { - 
internal.Logger.Printf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3])) + snh.logf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3])) return ErrInvalidNotification } } @@ -99,7 +99,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus // Get the connection that received this notification conn := handlerCtx.Conn if conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MOVING")) + snh.logf(ctx, logs.NoConnectionInHandlerContext("MOVING")) return ErrInvalidNotification } @@ -108,7 +108,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus if pc, ok := conn.(*pool.Conn); ok { poolConn = pc } else { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx)) + snh.logf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx)) return ErrInvalidNotification } @@ -125,7 +125,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus // If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds if newEndpoint == "" || newEndpoint == internal.RedisNull { if internal.LogLevel.DebugOrAbove() { - internal.Logger.Printf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2)) + snh.logf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2)) } // same as current endpoint newEndpoint = snh.manager.options.GetAddr() @@ -139,7 +139,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus } if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil { // Log error but don't fail the goroutine - use background context since original may be cancelled - internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err)) + snh.logf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err)) } }) return nil @@ -150,7 +150,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error { if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil { - internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err)) + snh.logf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err)) // Connection is already marked for handoff, which is acceptable // This can happen if multiple MOVING notifications are received for the same connection return nil @@ -171,24 +171,24 @@ func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx // MIGRATING notifications indicate that a connection is about to be migrated // Apply relaxed timeouts to the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATING", notification)) + snh.logf(ctx, logs.InvalidNotification("MIGRATING", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATING")) + snh.logf(ctx, logs.NoConnectionInHandlerContext("MIGRATING")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx)) + snh.logf(ctx, 
logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Apply relaxed timeout to this specific connection if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout)) + snh.logf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout)) } conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) return nil @@ -199,25 +199,25 @@ func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx p // MIGRATED notifications indicate that a connection migration has completed // Restore normal timeouts for the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATED", notification)) + snh.logf(ctx, logs.InvalidNotification("MIGRATED", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATED")) + snh.logf(ctx, logs.NoConnectionInHandlerContext("MIGRATED")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx)) + snh.logf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Clear relaxed timeout for this specific connection if internal.LogLevel.InfoOrAbove() { connID := conn.GetID() - internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID)) + snh.logf(ctx, logs.UnrelaxedTimeout(connID)) } conn.ClearRelaxedTimeout() return nil @@ -228,25 +228,25 @@ func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCt // FAILING_OVER notifications indicate that a connection is about to failover // Apply relaxed timeouts to the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("FAILING_OVER", notification)) + snh.logf(ctx, logs.InvalidNotification("FAILING_OVER", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER")) + snh.logf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx)) + snh.logf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Apply relaxed timeout to this specific connection if internal.LogLevel.InfoOrAbove() { connID := conn.GetID() - internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout)) + snh.logf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout)) } conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) return nil @@ -257,26 +257,34 @@ func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx // FAILED_OVER notifications indicate that a connection failover has completed // Restore normal timeouts for the specific connection that received this notification if len(notification) < 2 { - 
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILED_OVER", notification)) + snh.logf(ctx, logs.InvalidNotification("FAILED_OVER", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER")) + snh.logf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx)) + snh.logf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Clear relaxed timeout for this specific connection if internal.LogLevel.InfoOrAbove() { connID := conn.GetID() - internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID)) + snh.logf(ctx, logs.UnrelaxedTimeout(connID)) } conn.ClearRelaxedTimeout() return nil } + +func (snh *NotificationHandler) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + if snh.manager != nil && snh.manager.config != nil && snh.manager.config.Logger != nil { + logger = *snh.manager.config.Logger + } + logger.Printf(ctx, format, args...) +} diff --git a/options.go b/options.go index 79e4b6df7d..41b6e4d844 100644 --- a/options.go +++ b/options.go @@ -14,6 +14,7 @@ import ( "time" "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/util" @@ -264,6 +265,10 @@ type Options struct { // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. // If nil, maintnotifications are in "auto" mode and will be enabled if the server supports it. MaintNotificationsConfig *maintnotifications.Config + + // Logger is the logger used by the client for logging. + // If none is provided, the global logger [internal.Logger] is used. + Logger *internal.Logging } func (opt *Options) init() { diff --git a/osscluster.go b/osscluster.go index 7925d2c603..65872b132f 100644 --- a/osscluster.go +++ b/osscluster.go @@ -148,6 +148,9 @@ type ClusterOptions struct { // If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it. // The ClusterClient does not directly work with maintnotifications, it is up to the clients in the Nodes map to work with maintnotifications. MaintNotificationsConfig *maintnotifications.Config + + // Logger is an optional logger for logging cluster-related messages. + Logger *internal.Logging } func (opt *ClusterOptions) init() { @@ -390,6 +393,8 @@ func (opt *ClusterOptions) clientOptions() *Options { UnstableResp3: opt.UnstableResp3, MaintNotificationsConfig: maintNotificationsConfig, PushNotificationProcessor: opt.PushNotificationProcessor, + + Logger: opt.Logger, } } @@ -703,6 +708,14 @@ func (c *clusterNodes) Random() (*clusterNode, error) { return c.GetOrCreate(addrs[n]) } +func (c *clusterNodes) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + if c.opt.Logger != nil { + logger = *c.opt.Logger + } + logger.Printf(ctx, format, args...) 
+} + //------------------------------------------------------------------------------ type clusterSlot struct { @@ -900,12 +913,12 @@ func (c *clusterState) slotClosestNode(slot int) (*clusterNode, error) { // if all nodes are failing, we will pick the temporarily failing node with lowest latency if minLatency < maximumNodeLatency && closestNode != nil { - internal.Logger.Printf(context.TODO(), "redis: all nodes are marked as failed, picking the temporarily failing node with lowest latency") + c.nodes.logf(context.TODO(), "redis: all nodes are marked as failed, picking the temporarily failing node with lowest latency") return closestNode, nil } // If all nodes are having the maximum latency(all pings are failing) - return a random node across the cluster - internal.Logger.Printf(context.TODO(), "redis: pings to all nodes are failing, picking a random node across the cluster") + c.nodes.logf(context.TODO(), "redis: pings to all nodes are failing, picking a random node across the cluster") return c.nodes.Random() } @@ -1740,7 +1753,7 @@ func (c *ClusterClient) txPipelineReadQueued( if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logf(ctx, "push: error processing pending notifications before reading reply: %v", err) } if err := statusCmd.readReply(rd); err != nil { return err @@ -1751,7 +1764,7 @@ func (c *ClusterClient) txPipelineReadQueued( if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logf(ctx, "push: error processing pending notifications before reading reply: %v", err) } err := statusCmd.readReply(rd) if err != nil { @@ -1770,7 +1783,7 @@ func (c *ClusterClient) txPipelineReadQueued( if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logf(ctx, "push: error processing pending notifications before reading reply: %v", err) } // Parse number of replies. 
line, err := rd.ReadLine() @@ -2022,13 +2035,13 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { cmdsInfo, err := c.cmdsInfoCache.Get(ctx) if err != nil { - internal.Logger.Printf(context.TODO(), "getting command info: %s", err) + c.logf(context.TODO(), "getting command info: %s", err) return nil } info := cmdsInfo[name] if info == nil { - internal.Logger.Printf(context.TODO(), "info for cmd=%s not found", name) + c.logf(context.TODO(), "info for cmd=%s not found", name) } return info } @@ -2126,6 +2139,16 @@ func (c *ClusterClient) context(ctx context.Context) context.Context { return context.Background() } +func (c *ClusterClient) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + + if c.opt.Logger != nil { + logger = *c.opt.Logger + } + + logger.Printf(ctx, format, args...) +} + func appendIfNotExist[T comparable](vals []T, newVal T) []T { for _, v := range vals { if v == newVal { diff --git a/pubsub.go b/pubsub.go index 959a5c45b1..8fed3a1665 100644 --- a/pubsub.go +++ b/pubsub.go @@ -141,6 +141,17 @@ func mapKeys(m map[string]struct{}) []string { return s } +// logf is a wrapper around the logger to log messages with context. +// +// it uses the client logger if set, otherwise it uses the global logger. +func (c *PubSub) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + if c.opt.Logger != nil { + logger = *c.opt.Logger + } + logger.Printf(ctx, format, args...) +} + func (c *PubSub) _subscribe( ctx context.Context, cn *pool.Conn, redisCmd string, channels []string, ) error { @@ -190,7 +201,7 @@ func (c *PubSub) reconnect(ctx context.Context, reason error) { // Update the address in the options oldAddr := c.cn.RemoteAddr().String() c.opt.Addr = newEndpoint - internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr) + c.logf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr) } } _ = c.closeTheCn(reason) @@ -475,7 +486,7 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err) + c.logf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err) } return c.cmd.readReply(rd) }) @@ -631,9 +642,18 @@ func WithChannelSendTimeout(d time.Duration) ChannelOption { } } +func WithLogger(logger *internal.Logging) ChannelOption { + return func(c *channel) { + c.Logger = logger + } +} + type channel struct { pubSub *PubSub + // Optional logger for logging channel-related messages. 
+ Logger *internal.Logging + msgCh chan *Message allCh chan interface{} ping chan struct{} @@ -733,12 +753,12 @@ func (c *channel) initMsgChan() { <-timer.C } case <-timer.C: - internal.Logger.Printf( + c.logf( ctx, "redis: %s channel is full for %s (message is dropped)", c, c.chanSendTimeout) } default: - internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) + c.logf(ctx, "redis: unknown message type: %T", msg) } } }() @@ -787,13 +807,24 @@ func (c *channel) initAllChan() { <-timer.C } case <-timer.C: - internal.Logger.Printf( + c.logf( ctx, "redis: %s channel is full for %s (message is dropped)", c, c.chanSendTimeout) } default: - internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) + c.logf(ctx, "redis: unknown message type: %T", msg) } } }() } + +func (c *channel) logf(ctx context.Context, format string, args ...any) { + logger := internal.Logger + switch { + case c.pubSub.opt.Logger != nil: + logger = *c.pubSub.opt.Logger + case c.Logger != nil: + logger = *c.Logger + } + logger.Printf(ctx, format, args...) +} diff --git a/redis.go b/redis.go index dcd7b59a78..d148a05dae 100644 --- a/redis.go +++ b/redis.go @@ -228,6 +228,9 @@ type baseClient struct { // streamingCredentialsManager is used to manage streaming credentials streamingCredentialsManager *streaming.Manager + + // Logger is the logger used by the client for logging. + logger *internal.Logging } func (c *baseClient) clone() *baseClient { @@ -242,6 +245,7 @@ func (c *baseClient) clone() *baseClient { pushProcessor: c.pushProcessor, maintNotificationsManager: maintNotificationsManager, streamingCredentialsManager: c.streamingCredentialsManager, + logger: c.logger, } return clone } @@ -330,16 +334,16 @@ func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) { // Close the connection to force a reconnection. 
err := c.connPool.CloseConn(poolCn) if err != nil { - internal.Logger.Printf(context.Background(), "redis: failed to close connection: %v", err) + c.logf(context.Background(), "redis: failed to close connection: %v", err) // try to close the network connection directly // so that no resource is leaked err := poolCn.Close() if err != nil { - internal.Logger.Printf(context.Background(), "redis: failed to close network connection: %v", err) + c.logf(context.Background(), "redis: failed to close network connection: %v", err) } } } - internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err) + c.logf(context.Background(), "redis: re-authentication failed: %v", err) } } } @@ -475,13 +479,13 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { c.optLock.Unlock() return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr) default: // will handle auto and any other - internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr) + c.logf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr) c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled c.optLock.Unlock() // auto mode, disable maintnotifications and continue if err := c.disableMaintNotificationsUpgrades(); err != nil { // Log error but continue - auto mode should be resilient - internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", err) + c.logf(ctx, "failed to disable maintnotifications in auto mode: %v", err) } } } else { @@ -536,7 +540,7 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) } else { // process any pending push notifications before returning the connection to the pool if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err) + c.logf(ctx, "push: error processing pending notifications before releasing connection: %v", err) } c.connPool.Put(ctx, cn) } @@ -603,7 +607,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { // Process any pending push notifications before executing the command if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err) + c.logf(ctx, "push: error processing pending notifications before command: %v", err) } if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { @@ -626,7 +630,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logf(ctx, "push: error processing pending notifications before reading reply: %v", err) } return readReplyFunc(rd) }); err != nil { @@ -672,6 +676,17 @@ func (c *baseClient) context(ctx context.Context) context.Context { return context.Background() } +// logf is a wrapper around the logger to log messages with context. 
+// it uses the client logger if set, otherwise it uses the global logger. +func (c *baseClient) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + if c.logger != nil { + logger = *c.logger + } + + logger.Printf(ctx, format, args...) +} + // createInitConnFunc creates a connection initialization function that can be used for reconnections. func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error { return func(ctx context.Context, cn *pool.Conn) error { @@ -783,7 +798,7 @@ func (c *baseClient) generalProcessPipeline( lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { // Process any pending push notifications before executing the pipeline if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before processing pipeline: %v", err) + c.logf(ctx, "push: error processing pending notifications before processing pipeline: %v", err) } var err error canRetry, err = p(ctx, cn, cmds) @@ -805,7 +820,7 @@ func (c *baseClient) pipelineProcessCmds( ) (bool, error) { // Process any pending push notifications before executing the pipeline if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before writing pipeline: %v", err) + c.logf(ctx, "push: error processing pending notifications before writing pipeline: %v", err) } if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { @@ -829,7 +844,7 @@ func (c *baseClient) pipelineReadCmds(ctx context.Context, cn *pool.Conn, rd *pr for i, cmd := range cmds { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logf(ctx, "push: error processing pending notifications before reading reply: %v", err) } err := cmd.readReply(rd) cmd.SetErr(err) @@ -847,7 +862,7 @@ func (c *baseClient) txPipelineProcessCmds( ) (bool, error) { // Process any pending push notifications before executing the transaction pipeline if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err) + c.logf(ctx, "push: error processing pending notifications before transaction: %v", err) } if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { @@ -881,7 +896,7 @@ func (c *baseClient) txPipelineProcessCmds( func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logf(ctx, "push: error processing pending notifications before reading reply: %v", err) } // Parse +OK. 
if err := statusCmd.readReply(rd); err != nil { @@ -892,7 +907,7 @@ func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd for _, cmd := range cmds { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logf(ctx, "push: error processing pending notifications before reading reply: %v", err) } if err := statusCmd.readReply(rd); err != nil { cmd.SetErr(err) @@ -904,7 +919,7 @@ func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logf(ctx, "push: error processing pending notifications before reading reply: %v", err) } // Parse number of replies. line, err := rd.ReadLine() @@ -978,7 +993,7 @@ func NewClient(opt *Options) *Client { if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 { err := c.enableMaintNotificationsUpgrades() if err != nil { - internal.Logger.Printf(context.Background(), "failed to initialize maintnotifications: %v", err) + c.logf(context.Background(), "failed to initialize maintnotifications: %v", err) if opt.MaintNotificationsConfig.Mode == maintnotifications.ModeEnabled { /* Design decision: panic here to fail fast if maintnotifications cannot be enabled when explicitly requested. diff --git a/ring.go b/ring.go index 3381460abd..a8398c178a 100644 --- a/ring.go +++ b/ring.go @@ -154,6 +154,8 @@ type RingOptions struct { DisableIdentity bool IdentitySuffix string UnstableResp3 bool + + Logger *internal.Logging } func (opt *RingOptions) init() { @@ -345,7 +347,7 @@ func (c *ringSharding) SetAddrs(addrs map[string]string) { cleanup := func(shards map[string]*ringShard) { for addr, shard := range shards { if err := shard.Client.Close(); err != nil { - internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err) + c.logf(context.Background(), "shard.Close %s failed: %s", addr, err) } } } @@ -490,7 +492,7 @@ func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) { for _, shard := range c.List() { isUp := c.opt.HeartbeatFn(ctx, shard.Client) if shard.Vote(isUp) { - internal.Logger.Printf(ctx, "ring shard state changed: %s", shard) + c.logf(ctx, "ring shard state changed: %s", shard) rebalance = true } } @@ -559,6 +561,14 @@ func (c *ringSharding) Close() error { return firstErr } +func (c *ringSharding) logf(ctx context.Context, format string, args ...any) { + logger := internal.Logger + if c.opt.Logger != nil { + logger = *c.opt.Logger + } + logger.Printf(ctx, format, args...) +} + //------------------------------------------------------------------------------ // Ring is a Redis client that uses consistent hashing to distribute diff --git a/sentinel.go b/sentinel.go index f1222a340b..249d0bfe30 100644 --- a/sentinel.go +++ b/sentinel.go @@ -148,6 +148,9 @@ type FailoverOptions struct { // If nil, maintnotifications upgrades are disabled. 
// (however if Mode is nil, it defaults to "auto" - enable if server supports it) //MaintNotificationsConfig *maintnotifications.Config + + // Optional logger for logging + Logger *internal.Logging } func (opt *FailoverOptions) clientOptions() *Options { @@ -194,6 +197,8 @@ func (opt *FailoverOptions) clientOptions() *Options { IdentitySuffix: opt.IdentitySuffix, UnstableResp3: opt.UnstableResp3, + + Logger: opt.Logger, } } @@ -238,6 +243,8 @@ func (opt *FailoverOptions) sentinelOptions(addr string) *Options { IdentitySuffix: opt.IdentitySuffix, UnstableResp3: opt.UnstableResp3, + + Logger: opt.Logger, } } @@ -287,6 +294,8 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions { DisableIndentity: opt.DisableIndentity, IdentitySuffix: opt.IdentitySuffix, FailingTimeoutSeconds: opt.FailingTimeoutSeconds, + + Logger: opt.Logger, } } @@ -818,7 +827,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { return "", err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", + c.logf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", c.opt.MasterName, err) } else { return addr, nil @@ -836,7 +845,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { return "", err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", + c.logf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", c.opt.MasterName, err) } else { return addr, nil @@ -860,7 +869,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { sentinelCli := NewSentinelClient(c.opt.sentinelOptions(addr)) addrVal, err := sentinelCli.GetMasterAddrByName(ctx, c.opt.MasterName).Result() if err != nil { - internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName addr=%s, master=%q failed: %s", + c.logf(ctx, "sentinel: GetMasterAddrByName addr=%s, master=%q failed: %s", addr, c.opt.MasterName, err) _ = sentinelCli.Close() errCh <- err @@ -871,7 +880,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { // Push working sentinel to the top c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0] c.setSentinel(ctx, sentinelCli) - internal.Logger.Printf(ctx, "sentinel: selected addr=%s masterAddr=%s", addr, masterAddr) + c.logf(ctx, "sentinel: selected addr=%s masterAddr=%s", addr, masterAddr) cancel() }) }(i, sentinelAddr) @@ -914,7 +923,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo return nil, err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s", + c.logf(ctx, "sentinel: Replicas name=%q failed: %s", c.opt.MasterName, err) } else if len(addrs) > 0 { return addrs, nil @@ -932,7 +941,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo return nil, err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s", + c.logf(ctx, "sentinel: Replicas name=%q failed: %s", c.opt.MasterName, err) } else if len(addrs) > 0 { return addrs, nil @@ -953,7 +962,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil, err } - internal.Logger.Printf(ctx, "sentinel: Replicas master=%q failed: %s", + c.logf(ctx, "sentinel: Replicas master=%q failed: %s", c.opt.MasterName, err) continue } @@ -986,7 +995,7 @@ func (c 
*sentinelFailover) getMasterAddr(ctx context.Context, sentinel *Sentinel func (c *sentinelFailover) getReplicaAddrs(ctx context.Context, sentinel *SentinelClient) ([]string, error) { addrs, err := sentinel.Replicas(ctx, c.opt.MasterName).Result() if err != nil { - internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s", + c.logf(ctx, "sentinel: Replicas name=%q failed: %s", c.opt.MasterName, err) return nil, err } @@ -1034,7 +1043,7 @@ func (c *sentinelFailover) trySwitchMaster(ctx context.Context, addr string) { } c.masterAddr = addr - internal.Logger.Printf(ctx, "sentinel: new master=%q addr=%q", + c.logf(ctx, "sentinel: new master=%q addr=%q", c.opt.MasterName, addr) if c.onFailover != nil { c.onFailover(ctx, addr) @@ -1055,7 +1064,7 @@ func (c *sentinelFailover) setSentinel(ctx context.Context, sentinel *SentinelCl func (c *sentinelFailover) discoverSentinels(ctx context.Context) { sentinels, err := c.sentinel.Sentinels(ctx, c.opt.MasterName).Result() if err != nil { - internal.Logger.Printf(ctx, "sentinel: Sentinels master=%q failed: %s", c.opt.MasterName, err) + c.logf(ctx, "sentinel: Sentinels master=%q failed: %s", c.opt.MasterName, err) return } for _, sentinel := range sentinels { @@ -1070,7 +1079,7 @@ func (c *sentinelFailover) discoverSentinels(ctx context.Context) { if ip != "" && port != "" { sentinelAddr := net.JoinHostPort(ip, port) if !contains(c.sentinelAddrs, sentinelAddr) { - internal.Logger.Printf(ctx, "sentinel: discovered new sentinel=%q for master=%q", + c.logf(ctx, "sentinel: discovered new sentinel=%q for master=%q", sentinelAddr, c.opt.MasterName) c.sentinelAddrs = append(c.sentinelAddrs, sentinelAddr) } @@ -1090,7 +1099,7 @@ func (c *sentinelFailover) listen(pubsub *PubSub) { if msg.Channel == "+switch-master" { parts := strings.Split(msg.Payload, " ") if parts[0] != c.opt.MasterName { - internal.Logger.Printf(pubsub.getContext(), "sentinel: ignore addr for master=%q", parts[0]) + c.logf(pubsub.getContext(), "sentinel: ignore addr for master=%q", parts[0]) continue } addr := net.JoinHostPort(parts[3], parts[4]) @@ -1103,6 +1112,16 @@ func (c *sentinelFailover) listen(pubsub *PubSub) { } } +func (c *sentinelFailover) logf(ctx context.Context, format string, args ...interface{}) { + logger := internal.Logger + + if c.opt.Logger != nil { + logger = *c.opt.Logger + } + + logger.Printf(ctx, format, args...) +} + func contains(slice []string, str string) bool { for _, s := range slice { if s == str {
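
---

The pattern repeated across the patch is a per-struct `logf` helper that prefers an instance-level logger and falls back to the package-level `internal.Logger` when none is configured. The sketch below illustrates that fallback shape in isolation; `Logging`, `stdLogger`, `prefixLogger`, and the simplified `Options`/`ConnPool` types are stand-ins invented for the example (the real `internal.Logging` type lives in go-redis' internal package and is not reproduced here), so treat this as a minimal illustration of the technique rather than the library's actual API.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"
)

// Logging mirrors the Printf-style contract the patch relies on.
// Stand-in for illustration; not the real internal.Logging type.
type Logging interface {
	Printf(ctx context.Context, format string, v ...interface{})
}

// stdLogger plays the role of the package-level default (internal.Logger).
type stdLogger struct{ l *log.Logger }

func (s stdLogger) Printf(_ context.Context, format string, v ...interface{}) {
	s.l.Printf(format, v...)
}

var defaultLogger Logging = stdLogger{l: log.New(os.Stderr, "redis: ", log.LstdFlags)}

// Options carries the optional logger as a pointer to the interface,
// so a nil field unambiguously means "not configured" (same shape as the patch).
type Options struct {
	Logger *Logging
}

type ConnPool struct {
	cfg Options
}

// logf prefers the instance logger and falls back to the package default,
// which is the pattern each struct in the patch implements.
func (p *ConnPool) logf(ctx context.Context, format string, args ...interface{}) {
	logger := defaultLogger
	if p.cfg.Logger != nil {
		logger = *p.cfg.Logger
	}
	logger.Printf(ctx, format, args...)
}

// prefixLogger is an example of a user-provided implementation.
type prefixLogger struct{ prefix string }

func (p prefixLogger) Printf(_ context.Context, format string, v ...interface{}) {
	fmt.Printf(p.prefix+format+"\n", v...)
}

func main() {
	ctx := context.Background()

	// No custom logger set: messages go to the package-level default.
	(&ConnPool{}).logf(ctx, "dial failed after %d attempts", 3)

	// Custom logger set: the instance-level logger wins.
	var custom Logging = prefixLogger{prefix: "[pool] "}
	p := &ConnPool{cfg: Options{Logger: &custom}}
	p.logf(ctx, "removing idle conn[%d]", 42)
}
```

The pointer-to-interface field keeps the zero value meaning "use the global logger", so existing constructors keep working unchanged and cloned configs (e.g. `Config.Clone`, `baseClient.clone` in the patch) only need to copy the pointer.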