diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 995e7676afbca..295e59273413b 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -156,6 +156,7 @@ All warnings for upcoming changes in pandas will have the base class :class:`pan Other enhancements ^^^^^^^^^^^^^^^^^^ +- :func:`DataFrame.to_sql` now accepts a ``hints`` parameter to pass database-specific query hints for optimizing insert performance. The hints are specified as a dictionary mapping dialect names to hint strings (e.g., ``{'oracle': '/*+ APPEND PARALLEL(4) */', 'mysql': 'DELAYED'}``). Users are responsible for providing correctly formatted hint strings for their target database (:issue:`61370`) - :func:`pandas.merge` propagates the ``attrs`` attribute to the result if all inputs have identical ``attrs``, as has so far already been the case for :func:`pandas.concat`. diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 43078ef3a263c..f7edd418f4f14 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2798,6 +2798,7 @@ def to_sql( chunksize: int | None = None, dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, + hints: dict[str, str] | None = None, ) -> int | None: """ Write records stored in a DataFrame to a SQL database. @@ -2861,6 +2862,21 @@ def to_sql( Details and a sample callable implementation can be found in the section :ref:`insert method `. + hints : dict[str, str], optional + Dictionary of SQL hints to optimize insertion performance, keyed by + database dialect name (e.g., 'oracle', 'mysql', 'postgresql', 'mssql'). + Each value should be a complete hint string formatted exactly as required + by the target database. The user is responsible for providing correctly + formatted hint strings. + + Examples: ``{'oracle': '/*+ APPEND PARALLEL(4) */', 'mysql': 'DELAYED'}`` + + .. note:: + - Hints are database-specific and ignored for unsupported dialects. + - SQLite raises a ``UserWarning`` (hints not supported). + - ADBC connections raise ``NotImplementedError``. + + .. versionadded:: 3.0.0 Returns ------- @@ -3044,6 +3060,7 @@ def to_sql( chunksize=chunksize, dtype=dtype, method=method, + hints=hints, ) @final diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 0247a4b1da8dd..c5b45ee2a0724 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -18,7 +18,6 @@ datetime, time, ) -from functools import partial import re from typing import ( TYPE_CHECKING, @@ -235,6 +234,18 @@ def _wrap_result_adbc( return df +def _process_sql_hints(hints: dict[str, str] | None, dialect_name: str) -> str | None: + if hints is None: + return None + + dialect_name = dialect_name.lower() + for key, value in hints.items(): + if key.lower() == dialect_name: + return value + + return None + + # ----------------------------------------------------------------------------- # -- Read and write to DataFrames @@ -753,6 +764,7 @@ def to_sql( dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, engine: str = "auto", + hints: dict[str, str] | None = None, **engine_kwargs, ) -> int | None: """ @@ -813,6 +825,23 @@ def to_sql( .. versionadded:: 1.3.0 + hints : dict[str, str], optional + SQL hints to optimize insertion performance, keyed by database dialect name. + Each hint should be a complete string formatted exactly as required by the + target database. The user is responsible for constructing dialect-specific + syntax. + + Examples: ``{'oracle': '/*+ APPEND PARALLEL(4) */'}`` + ``{'mysql': 'DELAYED'}`` + ``{'mssql': 'WITH (TABLOCK)'}`` + + .. note:: + - Hints are database-specific and will be ignored for unsupported dialects + - SQLite will raise a UserWarning (hints not supported) + - ADBC connections will raise NotImplementedError + + .. versionadded:: 3.0.0 + **engine_kwargs Any additional kwargs are passed to the engine. @@ -855,6 +884,7 @@ def to_sql( dtype=dtype, method=method, engine=engine, + hints=hints, **engine_kwargs, ) @@ -1004,7 +1034,13 @@ def create(self) -> None: else: self._execute_create() - def _execute_insert(self, conn, keys: list[str], data_iter) -> int: + def _execute_insert( + self, + conn, + keys: list[str], + data_iter, + hint_str: str | None = None, + ) -> int: """ Execute SQL statement inserting data @@ -1016,11 +1052,23 @@ def _execute_insert(self, conn, keys: list[str], data_iter) -> int: data_iter : generator of list Each item contains a list of values to be inserted """ - data = [dict(zip(keys, row, strict=True)) for row in data_iter] - result = self.pd_sql.execute(self.table.insert(), data) + data = [dict(zip(keys, row, strict=False)) for row in data_iter] + + if hint_str: + stmt = self.table.insert().prefix_with(hint_str) + else: + stmt = self.table.insert() + + result = self.pd_sql.execute(stmt, data) return result.rowcount - def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int: + def _execute_insert_multi( + self, + conn, + keys: list[str], + data_iter, + hint_str: str | None = None, + ) -> int: """ Alternative to _execute_insert for DBs support multi-value INSERT. @@ -1029,11 +1077,15 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int: but performance degrades quickly with increase of columns. """ - from sqlalchemy import insert - data = [dict(zip(keys, row, strict=True)) for row in data_iter] - stmt = insert(self.table).values(data) + data = [dict(zip(keys, row, strict=False)) for row in data_iter] + + if hint_str: + stmt = insert(self.table).values(data).prefix_with(hint_str) + else: + stmt = insert(self.table).values(data) + result = self.pd_sql.execute(stmt) return result.rowcount @@ -1090,6 +1142,8 @@ def insert( self, chunksize: int | None = None, method: Literal["multi"] | Callable | None = None, + hints: dict[str, str] | None = None, + dialect_name: str | None = None, ) -> int | None: # set insert method if method is None: @@ -1097,7 +1151,11 @@ def insert( elif method == "multi": exec_insert = self._execute_insert_multi elif callable(method): - exec_insert = partial(method, self) + + def callable_wrapper(conn, keys, data_iter, hint_str=None): + return method(self, conn, keys, data_iter) + + exec_insert = callable_wrapper else: raise ValueError(f"Invalid parameter `method`: {method}") @@ -1114,6 +1172,9 @@ def insert( raise ValueError("chunksize argument should be non-zero") chunks = (nrows // chunksize) + 1 + + hint_str = _process_sql_hints(hints, dialect_name) if dialect_name else None + total_inserted = None with self.pd_sql.run_transaction() as conn: for i in range(chunks): @@ -1125,7 +1186,7 @@ def insert( chunk_iter = zip( *(arr[start_i:end_i] for arr in data_list), strict=True ) - num_inserted = exec_insert(conn, keys, chunk_iter) + num_inserted = exec_insert(conn, keys, chunk_iter, hint_str) # GH 46891 if num_inserted is not None: if total_inserted is None: @@ -1509,6 +1570,7 @@ def to_sql( chunksize: int | None = None, dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, + hints: dict[str, str] | None = None, engine: str = "auto", **engine_kwargs, ) -> int | None: @@ -1545,6 +1607,8 @@ def insert_records( schema=None, chunksize: int | None = None, method=None, + hints: dict[str, str] | None = None, + dialect_name: str | None = None, **engine_kwargs, ) -> int | None: """ @@ -1569,6 +1633,8 @@ def insert_records( schema=None, chunksize: int | None = None, method=None, + hints: dict[str, str] | None = None, + dialect_name: str | None = None, **engine_kwargs, ) -> int | None: from sqlalchemy import exc @@ -1980,6 +2046,7 @@ def to_sql( chunksize: int | None = None, dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, + hints: dict[str, str] | None = None, engine: str = "auto", **engine_kwargs, ) -> int | None: @@ -2053,6 +2120,8 @@ def to_sql( schema=schema, chunksize=chunksize, method=method, + hints=hints, + dialect_name=self.con.dialect.name, **engine_kwargs, ) @@ -2344,6 +2413,7 @@ def to_sql( chunksize: int | None = None, dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, + hints: dict[str, str] | None = None, engine: str = "auto", **engine_kwargs, ) -> int | None: @@ -2394,6 +2464,8 @@ def to_sql( raise NotImplementedError( "engine != 'auto' not implemented for ADBC drivers" ) + if hints: + raise NotImplementedError("'hints' is not implemented for ADBC drivers") if schema: table_name = f"{schema}.{name}" @@ -2575,7 +2647,9 @@ def insert_statement(self, *, num_rows: int) -> str: ) return insert_statement - def _execute_insert(self, conn, keys, data_iter) -> int: + def _execute_insert( + self, conn, keys: list[str], data_iter, hint_str: str | None = None + ) -> int: from sqlite3 import Error data_list = list(data_iter) @@ -2585,7 +2659,9 @@ def _execute_insert(self, conn, keys, data_iter) -> int: raise DatabaseError("Execution failed") from exc return conn.rowcount - def _execute_insert_multi(self, conn, keys, data_iter) -> int: + def _execute_insert_multi( + self, conn, keys: list[str], data_iter, hint_str: str | None = None + ) -> int: data_list = list(data_iter) flattened_data = [x for row in data_list for x in row] conn.execute(self.insert_statement(num_rows=len(data_list)), flattened_data) @@ -2821,6 +2897,7 @@ def to_sql( chunksize: int | None = None, dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, + hints: dict[str, str] | None = None, engine: str = "auto", **engine_kwargs, ) -> int | None: @@ -2863,6 +2940,13 @@ def to_sql( Details and a sample callable implementation can be found in the section :ref:`insert method `. """ + if hints: + warnings.warn( + "SQL hints are not supported for SQLite and will be ignored.", + UserWarning, + stacklevel=find_stack_level(), + ) + if dtype: if not is_dict_like(dtype): # error: Value expression in dictionary comprehension has incompatible diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 5865c46b4031e..21ac722246836 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -4398,3 +4398,193 @@ def test_xsqlite_if_exists(sqlite_buildin): (5, "E"), ] drop_table(table_name, sqlite_buildin) + + +# ----------------------------------------------------------------------------- +# -- Testing SQL Hints + + +class TestProcessSQLHints: + """Tests for _process_sql_hints helper function.""" + + def test_process_sql_hints_oracle_string(self): + """Test hint processing with Oracle dialect - user provides complete string.""" + hints = {"oracle": "/*+ APPEND PARALLEL */"} + result = sql._process_sql_hints(hints, "oracle") + assert result == "/*+ APPEND PARALLEL */" + + def test_process_sql_hints_oracle_simple(self): + """Test hint processing with simple Oracle hint string.""" + hints = {"oracle": "/*+ PARALLEL */"} + result = sql._process_sql_hints(hints, "oracle") + assert result == "/*+ PARALLEL */" + + def test_process_sql_hints_case_insensitive(self): + """Test that dialect names are case-insensitive.""" + hints = {"ORACLE": "/*+ APPEND */"} + result = sql._process_sql_hints(hints, "oracle") + assert result == "/*+ APPEND */" + + hints = {"oracle": "/*+ APPEND */"} + result = sql._process_sql_hints(hints, "ORACLE") + assert result == "/*+ APPEND */" + + def test_process_sql_hints_no_match(self): + """Test that None is returned when dialect doesn't match.""" + hints = {"mysql": "HIGH_PRIORITY"} + result = sql._process_sql_hints(hints, "oracle") + assert result is None + + def test_process_sql_hints_none(self): + """Test that None input returns None.""" + result = sql._process_sql_hints(None, "oracle") + assert result is None + + def test_process_sql_hints_empty_dict(self): + """Test that empty dict returns None.""" + result = sql._process_sql_hints({}, "oracle") + assert result is None + + def test_process_sql_hints_mysql(self): + """Test hint processing for MySQL dialect.""" + hints = {"mysql": "HIGH_PRIORITY"} + result = sql._process_sql_hints(hints, "mysql") + assert result == "HIGH_PRIORITY" + + def test_process_sql_hints_mssql(self): + """Test hint processing for SQL Server dialect.""" + hints = {"mssql": "WITH (TABLOCK)"} + result = sql._process_sql_hints(hints, "mssql") + assert result == "WITH (TABLOCK)" + + def test_process_sql_hints_multiple_dialects(self): + """Test extraction from dict with multiple dialects.""" + hints = { + "oracle": "/*+ PARALLEL */", + "mysql": "DELAYED", + "postgresql": "/* comment */", + } + assert sql._process_sql_hints(hints, "oracle") == "/*+ PARALLEL */" + assert sql._process_sql_hints(hints, "mysql") == "DELAYED" + assert sql._process_sql_hints(hints, "postgresql") == "/* comment */" + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_to_sql_with_hints_parameter(conn, test_frame1, request): + """Test that to_sql accepts hints parameter without error.""" + conn = request.getfixturevalue(conn) + + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql( + test_frame1, + "test_hints", + hints={"oracle": "/*+ APPEND */"}, + if_exists="replace", + ) + assert pandasSQL.has_table("test_hints") + assert count_rows(conn, "test_hints") == len(test_frame1) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_to_sql_hints_none_default(conn, test_frame1, request): + """Test that hints defaults to None and doesn't break existing code.""" + conn = request.getfixturevalue(conn) + + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_no_hints") + assert pandasSQL.has_table("test_no_hints") + assert count_rows(conn, "test_no_hints") == len(test_frame1) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_to_sql_hints_with_method(conn, test_frame1, request): + """Test that hints work alongside custom method parameter.""" + conn = request.getfixturevalue(conn) + + check = [] + + def sample(pd_table, conn, keys, data_iter): + check.append(1) + data = [dict(zip(keys, row)) for row in data_iter] + conn.execute(pd_table.table.insert(), data) + + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql( + test_frame1, + "test_hints_method", + method=sample, + hints={"oracle": "/*+ APPEND */"}, + ) + assert pandasSQL.has_table("test_hints_method") + + assert check == [1] + assert count_rows(conn, "test_hints_method") == len(test_frame1) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +@pytest.mark.parametrize("method", [None, "multi"]) +def test_to_sql_hints_with_different_methods(conn, method, test_frame1, request): + """Test hints work with different insertion methods.""" + conn = request.getfixturevalue(conn) + + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql( + test_frame1, + "test_hints_methods", + method=method, + hints={"oracle": "/*+ APPEND PARALLEL */"}, + if_exists="replace", + ) + assert pandasSQL.has_table("test_hints_methods") + + assert count_rows(conn, "test_hints_methods") == len(test_frame1) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_to_sql_hints_multidb_dict(conn, test_frame1, request): + """Test that multi-database hints dict works (only matching dialect used).""" + conn = request.getfixturevalue(conn) + + hints = { + "oracle": "/*+ APPEND PARALLEL */", + "mysql": "HIGH_PRIORITY", + "postgresql": "/* pg hint */", + "sqlite": "IGNORED", + } + + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql( + test_frame1, "test_multidb_hints", hints=hints, if_exists="replace" + ) + assert pandasSQL.has_table("test_multidb_hints") + + assert count_rows(conn, "test_multidb_hints") == len(test_frame1) + + +def test_to_sql_hints_adbc_not_supported(sqlite_adbc_conn, test_frame1): + """Test that ADBC connections raise NotImplementedError for hints.""" + pytest.importorskip("adbc_driver_manager.dbapi") + + df = test_frame1.copy() + msg = "'hints' is not implemented for ADBC drivers" + + with pytest.raises(NotImplementedError, match=msg): + df.to_sql("test", sqlite_adbc_conn, hints={"mysql": "SOME_HINT"}) + + +def test_to_sql_hints_sqlite_builtin(sqlite_buildin, test_frame1): + """Test that sqlite builtin connection handles hints gracefully.""" + df = test_frame1.copy() + + msg = "SQL hints are not supported for SQLite and will be ignored." + with tm.assert_produces_warning(UserWarning, match=msg): + result = df.to_sql( + "test_sqlite_hints", + sqlite_buildin, + if_exists="replace", + hints={"sqlite": "IGNORED"}, + ) + + assert result == len(test_frame1) + result_df = pd.read_sql("SELECT * FROM test_sqlite_hints", sqlite_buildin) + assert len(result_df) == len(test_frame1)