From 3a9a71a7e9fbc062b7f691b4a34d8587e5dc7ac3 Mon Sep 17 00:00:00 2001 From: fangwentong Date: Fri, 9 May 2025 12:06:40 +0800 Subject: [PATCH 1/2] fix: handle driver.ErrSkip to avoid duplicate hooks execution with MySQL driver When InterpolateParams=false is set in MySQL driver, it returns driver.ErrSkip which causes the SQL package to fall back to prepared statements, resulting in hooks being executed twice. This change handles driver.ErrSkip internally to ensure hooks are only executed once per logical operation. --- sqlhooks.go | 127 +++++++++++++++++++++++++++++------------------ sqlhooks_test.go | 40 +++++++-------- 2 files changed, 100 insertions(+), 67 deletions(-) diff --git a/sqlhooks.go b/sqlhooks.go index 3d52576..7b93bde 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -81,6 +81,10 @@ type Conn struct { } func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return conn.prepareContext(ctx, query) +} + +func (conn *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) { var ( stmt driver.Stmt err error @@ -93,7 +97,7 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt } if err != nil { - return stmt, err + return nil, err } return &Stmt{stmt, conn.hooks, query}, nil @@ -139,21 +143,39 @@ func (conn *ExecerContext) execContext(ctx context.Context, query string, args [ } func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return execWithHooks(ctx, query, args, conn.hooks, func(ctx context.Context) (driver.Result, error) { + results, err := conn.execContext(ctx, query, args) + if err == nil || !errors.Is(err, driver.ErrSkip) { + return results, err + } + // If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query. + // We need to avoid executing the hooks twice since they were already run in ExecContext. + // This matches the behavior in database/sql when ExecContext returns ErrSkip. + stmt, err := conn.prepareContext(ctx, query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.execContext(ctx, args) + }) +} + +func execWithHooks(ctx context.Context, query string, args []driver.NamedValue, hooks Hooks, execer func(context.Context) (driver.Result, error)) (driver.Result, error) { var err error list := namedToInterface(args) // Exec `Before` Hooks - if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil { + if ctx, err = hooks.Before(ctx, query, list...); err != nil { return nil, err } - results, err := conn.execContext(ctx, query, args) + results, err := execer(ctx) if err != nil { - return results, handlerErr(ctx, conn.hooks, err, query, list...) + return results, handlerErr(ctx, hooks, err, query, list...) } - if _, err := conn.hooks.After(ctx, query, list...); err != nil { + if _, err := hooks.After(ctx, query, list...); err != nil { return nil, err } @@ -201,21 +223,43 @@ func (conn *QueryerContext) queryContext(ctx context.Context, query string, args } func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return queryWithHooks(ctx, query, args, conn.hooks, func(ctx context.Context) (driver.Rows, error) { + rows, err := conn.queryContext(ctx, query, args) + if err == nil || !errors.Is(err, driver.ErrSkip) { + return rows, err + } + // If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query. + // We need to avoid executing the hooks twice since they were already run in QueryContext. + // This matches the behavior in database/sql when QueryContext returns ErrSkip. + stmt, err := conn.prepareContext(ctx, query) + if err != nil { + return nil, err + } + rows, err = stmt.queryContext(ctx, args) + if err != nil { + _ = stmt.Close() + return nil, err + } + return &rowsWrapper{rows: rows, closeStmt: stmt}, nil + }) +} + +func queryWithHooks(ctx context.Context, query string, args []driver.NamedValue, hooks Hooks, queryer func(context.Context) (driver.Rows, error)) (driver.Rows, error) { var err error list := namedToInterface(args) // Query `Before` Hooks - if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil { + if ctx, err = hooks.Before(ctx, query, list...); err != nil { return nil, err } - results, err := conn.queryContext(ctx, query, args) + results, err := queryer(ctx) if err != nil { - return results, handlerErr(ctx, conn.hooks, err, query, list...) + return results, handlerErr(ctx, hooks, err, query, list...) } - if _, err := conn.hooks.After(ctx, query, list...); err != nil { + if _, err := hooks.After(ctx, query, list...); err != nil { return nil, err } @@ -264,25 +308,9 @@ func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (dr } func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - var err error - - list := namedToInterface(args) - - // Exec `Before` Hooks - if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil { - return nil, err - } - - results, err := stmt.execContext(ctx, args) - if err != nil { - return results, handlerErr(ctx, stmt.hooks, err, stmt.query, list...) - } - - if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil { - return nil, err - } - - return results, err + return execWithHooks(ctx, stmt.query, args, stmt.hooks, func(ctx context.Context) (driver.Result, error) { + return stmt.execContext(ctx, args) + }) } func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { @@ -298,25 +326,9 @@ func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (d } func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - var err error - - list := namedToInterface(args) - - // Exec Before Hooks - if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil { - return nil, err - } - - rows, err := stmt.queryContext(ctx, args) - if err != nil { - return rows, handlerErr(ctx, stmt.hooks, err, stmt.query, list...) - } - - if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil { - return nil, err - } - - return rows, err + return queryWithHooks(ctx, stmt.query, args, stmt.hooks, func(ctx context.Context) (driver.Rows, error) { + return stmt.queryContext(ctx, args) + }) } func (stmt *Stmt) Close() error { return stmt.Stmt.Close() } @@ -350,6 +362,27 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { return dargs, nil } +type rowsWrapper struct { + rows driver.Rows + closeStmt driver.Stmt // if non-nil, statement to Close on close +} + +func (r *rowsWrapper) Close() error { + err := r.rows.Close() + if r.closeStmt != nil { + _ = r.closeStmt.Close() + } + return err +} + +func (r *rowsWrapper) Columns() []string { + return r.rows.Columns() +} + +func (r *rowsWrapper) Next(dest []driver.Value) error { + return r.rows.Next(dest) +} + /* type hooks struct { } diff --git a/sqlhooks_test.go b/sqlhooks_test.go index 26b8b87..1948afe 100644 --- a/sqlhooks_test.go +++ b/sqlhooks_test.go @@ -68,62 +68,62 @@ func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite { } func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) { - var before, after bool + var beforeCount, afterCount int s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { - before = true + beforeCount++ return ctx, nil } s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { - after = true + afterCount++ return ctx, nil } t.Run("Query", func(t *testing.T) { - before, after = false, false + beforeCount, afterCount = 0, 0 _, err := s.db.Query(query, args...) require.NoError(t, err) - assert.True(t, before, "Before Hook did not run for query: "+query) - assert.True(t, after, "After Hook did not run for query: "+query) + assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query) + assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query) }) t.Run("QueryContext", func(t *testing.T) { - before, after = false, false + beforeCount, afterCount = 0, 0 _, err := s.db.QueryContext(context.Background(), query, args...) require.NoError(t, err) - assert.True(t, before, "Before Hook did not run for query: "+query) - assert.True(t, after, "After Hook did not run for query: "+query) + assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query) + assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query) }) t.Run("Exec", func(t *testing.T) { - before, after = false, false + beforeCount, afterCount = 0, 0 _, err := s.db.Exec(query, args...) require.NoError(t, err) - assert.True(t, before, "Before Hook did not run for query: "+query) - assert.True(t, after, "After Hook did not run for query: "+query) + assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query) + assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query) }) t.Run("ExecContext", func(t *testing.T) { - before, after = false, false + beforeCount, afterCount = 0, 0 _, err := s.db.ExecContext(context.Background(), query, args...) require.NoError(t, err) - assert.True(t, before, "Before Hook did not run for query: "+query) - assert.True(t, after, "After Hook did not run for query: "+query) + assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query) + assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query) }) t.Run("Statements", func(t *testing.T) { - before, after = false, false + beforeCount, afterCount = 0, 0 stmt, err := s.db.Prepare(query) require.NoError(t, err) // Hooks just run when the stmt is executed (Query or Exec) - assert.False(t, before, "Before Hook run before execution: "+query) - assert.False(t, after, "After Hook run before execution: "+query) + assert.Equal(t, 0, beforeCount, "Before Hook run before execution: "+query) + assert.Equal(t, 0, afterCount, "After Hook run before execution: "+query) _, err = stmt.Query(args...) require.NoError(t, err) - assert.True(t, before, "Before Hook did not run for query: "+query) - assert.True(t, after, "After Hook did not run for query: "+query) + assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query) + assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query) }) } From 8c3d8136efba26f22f1297baaaa19468b3c9b3a0 Mon Sep 17 00:00:00 2001 From: fangwentong Date: Wed, 14 May 2025 10:51:21 +0800 Subject: [PATCH 2/2] ci: bump actions/cache version from v2 to v4, fix Github Action --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1be1508..fd87867 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -49,7 +49,7 @@ jobs: with: fetch-depth: 1 - - uses: actions/cache@v2 + - uses: actions/cache@v4 with: path: | ~/go/pkg/mod