diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1be1508..fd87867 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -49,7 +49,7 @@ jobs: with: fetch-depth: 1 - - uses: actions/cache@v2 + - uses: actions/cache@v4 with: path: | ~/go/pkg/mod diff --git a/sqlhooks.go b/sqlhooks.go index 3d52576..7b93bde 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -81,6 +81,10 @@ type Conn struct { } func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return conn.prepareContext(ctx, query) +} + +func (conn *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) { var ( stmt driver.Stmt err error @@ -93,7 +97,7 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt } if err != nil { - return stmt, err + return nil, err } return &Stmt{stmt, conn.hooks, query}, nil @@ -139,21 +143,39 @@ func (conn *ExecerContext) execContext(ctx context.Context, query string, args [ } func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return execWithHooks(ctx, query, args, conn.hooks, func(ctx context.Context) (driver.Result, error) { + results, err := conn.execContext(ctx, query, args) + if err == nil || !errors.Is(err, driver.ErrSkip) { + return results, err + } + // If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query. + // We need to avoid executing the hooks twice since they were already run in ExecContext. + // This matches the behavior in database/sql when ExecContext returns ErrSkip. + stmt, err := conn.prepareContext(ctx, query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.execContext(ctx, args) + }) +} + +func execWithHooks(ctx context.Context, query string, args []driver.NamedValue, hooks Hooks, execer func(context.Context) (driver.Result, error)) (driver.Result, error) { var err error list := namedToInterface(args) // Exec `Before` Hooks - if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil { + if ctx, err = hooks.Before(ctx, query, list...); err != nil { return nil, err } - results, err := conn.execContext(ctx, query, args) + results, err := execer(ctx) if err != nil { - return results, handlerErr(ctx, conn.hooks, err, query, list...) + return results, handlerErr(ctx, hooks, err, query, list...) } - if _, err := conn.hooks.After(ctx, query, list...); err != nil { + if _, err := hooks.After(ctx, query, list...); err != nil { return nil, err } @@ -201,21 +223,43 @@ func (conn *QueryerContext) queryContext(ctx context.Context, query string, args } func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return queryWithHooks(ctx, query, args, conn.hooks, func(ctx context.Context) (driver.Rows, error) { + rows, err := conn.queryContext(ctx, query, args) + if err == nil || !errors.Is(err, driver.ErrSkip) { + return rows, err + } + // If driver.ErrSkip is returned, we fall back to using Prepare + Statement to handle the query. + // We need to avoid executing the hooks twice since they were already run in QueryContext. + // This matches the behavior in database/sql when QueryContext returns ErrSkip. + stmt, err := conn.prepareContext(ctx, query) + if err != nil { + return nil, err + } + rows, err = stmt.queryContext(ctx, args) + if err != nil { + _ = stmt.Close() + return nil, err + } + return &rowsWrapper{rows: rows, closeStmt: stmt}, nil + }) +} + +func queryWithHooks(ctx context.Context, query string, args []driver.NamedValue, hooks Hooks, queryer func(context.Context) (driver.Rows, error)) (driver.Rows, error) { var err error list := namedToInterface(args) // Query `Before` Hooks - if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil { + if ctx, err = hooks.Before(ctx, query, list...); err != nil { return nil, err } - results, err := conn.queryContext(ctx, query, args) + results, err := queryer(ctx) if err != nil { - return results, handlerErr(ctx, conn.hooks, err, query, list...) + return results, handlerErr(ctx, hooks, err, query, list...) } - if _, err := conn.hooks.After(ctx, query, list...); err != nil { + if _, err := hooks.After(ctx, query, list...); err != nil { return nil, err } @@ -264,25 +308,9 @@ func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (dr } func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - var err error - - list := namedToInterface(args) - - // Exec `Before` Hooks - if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil { - return nil, err - } - - results, err := stmt.execContext(ctx, args) - if err != nil { - return results, handlerErr(ctx, stmt.hooks, err, stmt.query, list...) - } - - if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil { - return nil, err - } - - return results, err + return execWithHooks(ctx, stmt.query, args, stmt.hooks, func(ctx context.Context) (driver.Result, error) { + return stmt.execContext(ctx, args) + }) } func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { @@ -298,25 +326,9 @@ func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (d } func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - var err error - - list := namedToInterface(args) - - // Exec Before Hooks - if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil { - return nil, err - } - - rows, err := stmt.queryContext(ctx, args) - if err != nil { - return rows, handlerErr(ctx, stmt.hooks, err, stmt.query, list...) - } - - if _, err := stmt.hooks.After(ctx, stmt.query, list...); err != nil { - return nil, err - } - - return rows, err + return queryWithHooks(ctx, stmt.query, args, stmt.hooks, func(ctx context.Context) (driver.Rows, error) { + return stmt.queryContext(ctx, args) + }) } func (stmt *Stmt) Close() error { return stmt.Stmt.Close() } @@ -350,6 +362,27 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { return dargs, nil } +type rowsWrapper struct { + rows driver.Rows + closeStmt driver.Stmt // if non-nil, statement to Close on close +} + +func (r *rowsWrapper) Close() error { + err := r.rows.Close() + if r.closeStmt != nil { + _ = r.closeStmt.Close() + } + return err +} + +func (r *rowsWrapper) Columns() []string { + return r.rows.Columns() +} + +func (r *rowsWrapper) Next(dest []driver.Value) error { + return r.rows.Next(dest) +} + /* type hooks struct { } diff --git a/sqlhooks_test.go b/sqlhooks_test.go index 26b8b87..1948afe 100644 --- a/sqlhooks_test.go +++ b/sqlhooks_test.go @@ -68,62 +68,62 @@ func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite { } func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) { - var before, after bool + var beforeCount, afterCount int s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { - before = true + beforeCount++ return ctx, nil } s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { - after = true + afterCount++ return ctx, nil } t.Run("Query", func(t *testing.T) { - before, after = false, false + beforeCount, afterCount = 0, 0 _, err := s.db.Query(query, args...) require.NoError(t, err) - assert.True(t, before, "Before Hook did not run for query: "+query) - assert.True(t, after, "After Hook did not run for query: "+query) + assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query) + assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query) }) t.Run("QueryContext", func(t *testing.T) { - before, after = false, false + beforeCount, afterCount = 0, 0 _, err := s.db.QueryContext(context.Background(), query, args...) require.NoError(t, err) - assert.True(t, before, "Before Hook did not run for query: "+query) - assert.True(t, after, "After Hook did not run for query: "+query) + assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query) + assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query) }) t.Run("Exec", func(t *testing.T) { - before, after = false, false + beforeCount, afterCount = 0, 0 _, err := s.db.Exec(query, args...) require.NoError(t, err) - assert.True(t, before, "Before Hook did not run for query: "+query) - assert.True(t, after, "After Hook did not run for query: "+query) + assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query) + assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query) }) t.Run("ExecContext", func(t *testing.T) { - before, after = false, false + beforeCount, afterCount = 0, 0 _, err := s.db.ExecContext(context.Background(), query, args...) require.NoError(t, err) - assert.True(t, before, "Before Hook did not run for query: "+query) - assert.True(t, after, "After Hook did not run for query: "+query) + assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query) + assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query) }) t.Run("Statements", func(t *testing.T) { - before, after = false, false + beforeCount, afterCount = 0, 0 stmt, err := s.db.Prepare(query) require.NoError(t, err) // Hooks just run when the stmt is executed (Query or Exec) - assert.False(t, before, "Before Hook run before execution: "+query) - assert.False(t, after, "After Hook run before execution: "+query) + assert.Equal(t, 0, beforeCount, "Before Hook run before execution: "+query) + assert.Equal(t, 0, afterCount, "After Hook run before execution: "+query) _, err = stmt.Query(args...) require.NoError(t, err) - assert.True(t, before, "Before Hook did not run for query: "+query) - assert.True(t, after, "After Hook did not run for query: "+query) + assert.Equal(t, 1, beforeCount, "Before Hook didn't execute only once: "+query) + assert.Equal(t, 1, afterCount, "After Hook didn't execute only once: "+query) }) }