Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions py/selenium/webdriver/common/bidi/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


import os
from typing import Any, Optional

from selenium.webdriver.common.bidi.common import command_builder
Expand Down Expand Up @@ -236,3 +235,46 @@ def get_client_windows(self) -> list[ClientWindowInfo]:
"""
result = self.conn.execute(command_builder("browser.getClientWindows", {}))
return [ClientWindowInfo.from_dict(window) for window in result["clientWindows"]]

def set_download_behavior(
self,
*,
allowed: Optional[bool] = None,
destination_folder: Optional[str | os.PathLike] = None,
user_contexts: Optional[list[str]] = None,
) -> None:
"""Set the download behavior for the browser or specific user contexts.

Args:
allowed: True to allow downloads, False to deny downloads, or None to
clear download behavior (revert to default).
destination_folder: Required when allowed is True. Specifies the folder
to store downloads in.
user_contexts: Optional list of user context IDs to apply this
behavior to. If omitted, updates the default behavior.

Raises:
ValueError: If allowed=True and destination_folder is missing, or if
allowed=False and destination_folder is provided.
"""
params: dict[str, Any] = {}

if allowed is None:
params["downloadBehavior"] = None
else:
if allowed:
if not destination_folder:
raise ValueError("destination_folder is required when allowed=True.")
params["downloadBehavior"] = {
"type": "allowed",
"destinationFolder": os.fspath(destination_folder),
}
else:
if destination_folder:
raise ValueError("destination_folder should not be provided when allowed=False.")
params["downloadBehavior"] = {"type": "denied"}

if user_contexts is not None:
params["userContexts"] = user_contexts

self.conn.execute(command_builder("browser.setDownloadBehavior", params))
97 changes: 97 additions & 0 deletions py/test/selenium/webdriver/common/bidi_browser_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@
# under the License.

import http.server
import os
import socketserver
import threading

import pytest

from selenium.common.exceptions import TimeoutException
from selenium.webdriver.common.bidi.browser import ClientWindowInfo, ClientWindowState
from selenium.webdriver.common.bidi.browsing_context import ReadinessState
from selenium.webdriver.common.bidi.session import UserPromptHandler, UserPromptHandlerType
from selenium.webdriver.common.by import By
from selenium.webdriver.common.proxy import Proxy, ProxyType
from selenium.webdriver.common.utils import free_port
from selenium.webdriver.common.window import WindowTypes
from selenium.webdriver.support.ui import WebDriverWait


class FakeProxyHandler(http.server.SimpleHTTPRequestHandler):
Expand Down Expand Up @@ -262,3 +266,96 @@ def test_create_user_context_with_unhandled_prompt_behavior(driver, pages):

# Clean up
driver.browser.remove_user_context(user_context)


@pytest.mark.xfail_firefox
def test_set_download_behavior_allowed(driver, pages, tmp_path):
print(f"Driver info: {driver.capabilities}")
try:
driver.browser.set_download_behavior(allowed=True, destination_folder=tmp_path)

context_id = driver.current_window_handle
url = pages.url("downloads/download.html")
driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE)

driver.find_element(By.ID, "file-1").click()

WebDriverWait(driver, 5).until(lambda d: "file_1.txt" in os.listdir(tmp_path))

files = os.listdir(tmp_path)
assert "file_1.txt" in files, f"Expected file_1.txt in {tmp_path}, but found: {files}"
finally:
driver.browser.set_download_behavior(allowed=None)


@pytest.mark.xfail_firefox
def test_set_download_behavior_denied(driver, pages, tmp_path):
try:
driver.browser.set_download_behavior(allowed=False)

context_id = driver.current_window_handle
url = pages.url("downloads/download.html")
driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE)

driver.find_element(By.ID, "file-1").click()

try:
WebDriverWait(driver, 3, poll_frequency=0.2).until(lambda _: len(os.listdir(tmp_path)) > 0)
files = os.listdir(tmp_path)
pytest.fail(f"A file was downloaded unexpectedly: {files}")
except TimeoutException:
pass # Expected, no file downloaded
finally:
driver.browser.set_download_behavior(allowed=None)


@pytest.mark.xfail_firefox
def test_set_download_behavior_user_context(driver, pages, tmp_path):
user_context = driver.browser.create_user_context()

try:
bc = driver.browsing_context.create(type=WindowTypes.WINDOW, user_context=user_context)
driver.switch_to.window(bc)

try:
driver.browser.set_download_behavior(
allowed=True, destination_folder=tmp_path, user_contexts=[user_context]
)

url = pages.url("downloads/download.html")
driver.browsing_context.navigate(context=bc, url=url, wait=ReadinessState.COMPLETE)

driver.find_element(By.ID, "file-1").click()

WebDriverWait(driver, 5).until(lambda d: "file_1.txt" in os.listdir(tmp_path))

files = os.listdir(tmp_path)
assert "file_1.txt" in files, f"Expected file_1.txt in {tmp_path}, but found: {files}"

initial_file_count = len(files)

driver.browser.set_download_behavior(allowed=False, user_contexts=[user_context])

driver.find_element(By.ID, "file-2").click()

try:
WebDriverWait(driver, 3, poll_frequency=0.2).until(
lambda _: len(os.listdir(tmp_path)) > initial_file_count
)
files_after = os.listdir(tmp_path)
pytest.fail(f"A file was downloaded unexpectedly: {files_after}")
except TimeoutException:
pass # Expected, no file downloaded
finally:
driver.browser.set_download_behavior(allowed=None, user_contexts=[user_context])
finally:
driver.browser.remove_user_context(user_context)


@pytest.mark.xfail_firefox
def test_set_download_behavior_validation(driver):
with pytest.raises(ValueError, match="destination_folder is required when allowed=True"):
driver.browser.set_download_behavior(allowed=True)

with pytest.raises(ValueError, match="destination_folder should not be provided when allowed=False"):
driver.browser.set_download_behavior(allowed=False, destination_folder="/tmp")
Loading