diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 3355c07eff012..f597332563a14 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - - +import os from typing import Any, Optional from selenium.webdriver.common.bidi.common import command_builder @@ -236,3 +235,46 @@ def get_client_windows(self) -> list[ClientWindowInfo]: """ result = self.conn.execute(command_builder("browser.getClientWindows", {})) return [ClientWindowInfo.from_dict(window) for window in result["clientWindows"]] + + def set_download_behavior( + self, + *, + allowed: Optional[bool] = None, + destination_folder: Optional[str | os.PathLike] = None, + user_contexts: Optional[list[str]] = None, + ) -> None: + """Set the download behavior for the browser or specific user contexts. + + Args: + allowed: True to allow downloads, False to deny downloads, or None to + clear download behavior (revert to default). + destination_folder: Required when allowed is True. Specifies the folder + to store downloads in. + user_contexts: Optional list of user context IDs to apply this + behavior to. If omitted, updates the default behavior. + + Raises: + ValueError: If allowed=True and destination_folder is missing, or if + allowed=False and destination_folder is provided. + """ + params: dict[str, Any] = {} + + if allowed is None: + params["downloadBehavior"] = None + else: + if allowed: + if not destination_folder: + raise ValueError("destination_folder is required when allowed=True.") + params["downloadBehavior"] = { + "type": "allowed", + "destinationFolder": os.fspath(destination_folder), + } + else: + if destination_folder: + raise ValueError("destination_folder should not be provided when allowed=False.") + params["downloadBehavior"] = {"type": "denied"} + + if user_contexts is not None: + params["userContexts"] = user_contexts + + self.conn.execute(command_builder("browser.setDownloadBehavior", params)) diff --git a/py/test/selenium/webdriver/common/bidi_browser_tests.py b/py/test/selenium/webdriver/common/bidi_browser_tests.py index 74b406b54c22e..e89ccd3f0ff59 100644 --- a/py/test/selenium/webdriver/common/bidi_browser_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browser_tests.py @@ -16,17 +16,21 @@ # under the License. import http.server +import os import socketserver import threading import pytest +from selenium.common.exceptions import TimeoutException from selenium.webdriver.common.bidi.browser import ClientWindowInfo, ClientWindowState +from selenium.webdriver.common.bidi.browsing_context import ReadinessState from selenium.webdriver.common.bidi.session import UserPromptHandler, UserPromptHandlerType from selenium.webdriver.common.by import By from selenium.webdriver.common.proxy import Proxy, ProxyType from selenium.webdriver.common.utils import free_port from selenium.webdriver.common.window import WindowTypes +from selenium.webdriver.support.ui import WebDriverWait class FakeProxyHandler(http.server.SimpleHTTPRequestHandler): @@ -262,3 +266,96 @@ def test_create_user_context_with_unhandled_prompt_behavior(driver, pages): # Clean up driver.browser.remove_user_context(user_context) + + +@pytest.mark.xfail_firefox +def test_set_download_behavior_allowed(driver, pages, tmp_path): + print(f"Driver info: {driver.capabilities}") + try: + driver.browser.set_download_behavior(allowed=True, destination_folder=tmp_path) + + context_id = driver.current_window_handle + url = pages.url("downloads/download.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + driver.find_element(By.ID, "file-1").click() + + WebDriverWait(driver, 5).until(lambda d: "file_1.txt" in os.listdir(tmp_path)) + + files = os.listdir(tmp_path) + assert "file_1.txt" in files, f"Expected file_1.txt in {tmp_path}, but found: {files}" + finally: + driver.browser.set_download_behavior(allowed=None) + + +@pytest.mark.xfail_firefox +def test_set_download_behavior_denied(driver, pages, tmp_path): + try: + driver.browser.set_download_behavior(allowed=False) + + context_id = driver.current_window_handle + url = pages.url("downloads/download.html") + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + + driver.find_element(By.ID, "file-1").click() + + try: + WebDriverWait(driver, 3, poll_frequency=0.2).until(lambda _: len(os.listdir(tmp_path)) > 0) + files = os.listdir(tmp_path) + pytest.fail(f"A file was downloaded unexpectedly: {files}") + except TimeoutException: + pass # Expected, no file downloaded + finally: + driver.browser.set_download_behavior(allowed=None) + + +@pytest.mark.xfail_firefox +def test_set_download_behavior_user_context(driver, pages, tmp_path): + user_context = driver.browser.create_user_context() + + try: + bc = driver.browsing_context.create(type=WindowTypes.WINDOW, user_context=user_context) + driver.switch_to.window(bc) + + try: + driver.browser.set_download_behavior( + allowed=True, destination_folder=tmp_path, user_contexts=[user_context] + ) + + url = pages.url("downloads/download.html") + driver.browsing_context.navigate(context=bc, url=url, wait=ReadinessState.COMPLETE) + + driver.find_element(By.ID, "file-1").click() + + WebDriverWait(driver, 5).until(lambda d: "file_1.txt" in os.listdir(tmp_path)) + + files = os.listdir(tmp_path) + assert "file_1.txt" in files, f"Expected file_1.txt in {tmp_path}, but found: {files}" + + initial_file_count = len(files) + + driver.browser.set_download_behavior(allowed=False, user_contexts=[user_context]) + + driver.find_element(By.ID, "file-2").click() + + try: + WebDriverWait(driver, 3, poll_frequency=0.2).until( + lambda _: len(os.listdir(tmp_path)) > initial_file_count + ) + files_after = os.listdir(tmp_path) + pytest.fail(f"A file was downloaded unexpectedly: {files_after}") + except TimeoutException: + pass # Expected, no file downloaded + finally: + driver.browser.set_download_behavior(allowed=None, user_contexts=[user_context]) + finally: + driver.browser.remove_user_context(user_context) + + +@pytest.mark.xfail_firefox +def test_set_download_behavior_validation(driver): + with pytest.raises(ValueError, match="destination_folder is required when allowed=True"): + driver.browser.set_download_behavior(allowed=True) + + with pytest.raises(ValueError, match="destination_folder should not be provided when allowed=False"): + driver.browser.set_download_behavior(allowed=False, destination_folder="/tmp")