diff --git a/tableauserverclient/__init__.py b/tableauserverclient/__init__.py index b041fcda..cd0ec3e0 100644 --- a/tableauserverclient/__init__.py +++ b/tableauserverclient/__init__.py @@ -13,6 +13,8 @@ DatabaseItem, DataFreshnessPolicyItem, DatasourceItem, + ExtensionsServer, + ExtensionsSiteSettings, FavoriteItem, FlowItem, FlowRunItem, @@ -36,6 +38,7 @@ ProjectItem, Resource, RevisionItem, + SafeExtension, ScheduleItem, SiteAuthConfiguration, SiteOIDCConfiguration, @@ -88,6 +91,8 @@ "DEFAULT_NAMESPACE", "DQWItem", "ExcelRequestOptions", + "ExtensionsServer", + "ExtensionsSiteSettings", "FailedSignInError", "FavoriteItem", "FileuploadItem", @@ -121,6 +126,7 @@ "RequestOptions", "Resource", "RevisionItem", + "SafeExtension", "ScheduleItem", "Server", "ServerInfoItem", diff --git a/tableauserverclient/models/__init__.py b/tableauserverclient/models/__init__.py index 67f6553f..aa28e0db 100644 --- a/tableauserverclient/models/__init__.py +++ b/tableauserverclient/models/__init__.py @@ -10,6 +10,7 @@ from tableauserverclient.models.datasource_item import DatasourceItem from tableauserverclient.models.dqw_item import DQWItem from tableauserverclient.models.exceptions import UnpopulatedPropertyError +from tableauserverclient.models.extensions_item import ExtensionsServer, ExtensionsSiteSettings, SafeExtension from tableauserverclient.models.favorites_item import FavoriteItem from tableauserverclient.models.fileupload_item import FileuploadItem from tableauserverclient.models.flow_item import FlowItem @@ -113,4 +114,7 @@ "LinkedTaskStepItem", "LinkedTaskFlowRunItem", "ExtractItem", + "ExtensionsServer", + "ExtensionsSiteSettings", + "SafeExtension", ] diff --git a/tableauserverclient/models/extensions_item.py b/tableauserverclient/models/extensions_item.py new file mode 100644 index 00000000..9b6e1089 --- /dev/null +++ b/tableauserverclient/models/extensions_item.py @@ -0,0 +1,186 @@ +from typing import overload +from typing_extensions import Self + +from defusedxml.ElementTree import fromstring + +from tableauserverclient.models.property_decorators import property_is_boolean + + +class ExtensionsServer: + def __init__(self) -> None: + self._enabled: bool | None = None + self._block_list: list[str] | None = None + + @property + def enabled(self) -> bool | None: + """Indicates whether the extensions server is enabled.""" + return self._enabled + + @enabled.setter + @property_is_boolean + def enabled(self, value: bool | None) -> None: + self._enabled = value + + @property + def block_list(self) -> list[str] | None: + """List of blocked extensions.""" + return self._block_list + + @block_list.setter + def block_list(self, value: list[str] | None) -> None: + self._block_list = value + + @classmethod + def from_response(cls: type[Self], response, ns) -> Self: + xml = fromstring(response) + obj = cls() + element = xml.find(".//t:extensionsServerSettings", namespaces=ns) + if element is None: + raise ValueError("Missing extensionsServerSettings element in response") + + if (enabled_element := element.find("./t:extensionsGloballyEnabled", namespaces=ns)) is not None: + obj.enabled = string_to_bool(enabled_element.text) + obj.block_list = [e.text for e in element.findall("./t:blockList", namespaces=ns)] + + return obj + + +class SafeExtension: + def __init__( + self, url: str | None = None, full_data_allowed: bool | None = None, prompt_needed: bool | None = None + ) -> None: + self.url = url + self._full_data_allowed = full_data_allowed + self._prompt_needed = prompt_needed + + @property + def full_data_allowed(self) -> bool | None: + return self._full_data_allowed + + @full_data_allowed.setter + @property_is_boolean + def full_data_allowed(self, value: bool | None) -> None: + self._full_data_allowed = value + + @property + def prompt_needed(self) -> bool | None: + return self._prompt_needed + + @prompt_needed.setter + @property_is_boolean + def prompt_needed(self, value: bool | None) -> None: + self._prompt_needed = value + + +class ExtensionsSiteSettings: + def __init__(self) -> None: + self._enabled: bool | None = None + self._use_default_setting: bool | None = None + self.safe_list: list[SafeExtension] | None = None + self._allow_trusted: bool | None = None + self._include_tableau_built: bool | None = None + self._include_partner_built: bool | None = None + self._include_sandboxed: bool | None = None + + @property + def enabled(self) -> bool | None: + return self._enabled + + @enabled.setter + @property_is_boolean + def enabled(self, value: bool | None) -> None: + self._enabled = value + + @property + def use_default_setting(self) -> bool | None: + return self._use_default_setting + + @use_default_setting.setter + @property_is_boolean + def use_default_setting(self, value: bool | None) -> None: + self._use_default_setting = value + + @property + def allow_trusted(self) -> bool | None: + return self._allow_trusted + + @allow_trusted.setter + @property_is_boolean + def allow_trusted(self, value: bool | None) -> None: + self._allow_trusted = value + + @property + def include_tableau_built(self) -> bool | None: + return self._include_tableau_built + + @include_tableau_built.setter + @property_is_boolean + def include_tableau_built(self, value: bool | None) -> None: + self._include_tableau_built = value + + @property + def include_partner_built(self) -> bool | None: + return self._include_partner_built + + @include_partner_built.setter + @property_is_boolean + def include_partner_built(self, value: bool | None) -> None: + self._include_partner_built = value + + @property + def include_sandboxed(self) -> bool | None: + return self._include_sandboxed + + @include_sandboxed.setter + @property_is_boolean + def include_sandboxed(self, value: bool | None) -> None: + self._include_sandboxed = value + + @classmethod + def from_response(cls: type[Self], response, ns) -> Self: + xml = fromstring(response) + element = xml.find(".//t:extensionsSiteSettings", namespaces=ns) + obj = cls() + if element is None: + raise ValueError("Missing extensionsSiteSettings element in response") + + if (enabled_element := element.find("./t:extensionsEnabled", namespaces=ns)) is not None: + obj.enabled = string_to_bool(enabled_element.text) + if (default_settings_element := element.find("./t:useDefaultSetting", namespaces=ns)) is not None: + obj.use_default_setting = string_to_bool(default_settings_element.text) + if (allow_trusted_element := element.find("./t:allowTrusted", namespaces=ns)) is not None: + obj.allow_trusted = string_to_bool(allow_trusted_element.text) + if (include_tableau_built_element := element.find("./t:includeTableauBuilt", namespaces=ns)) is not None: + obj.include_tableau_built = string_to_bool(include_tableau_built_element.text) + if (include_partner_built_element := element.find("./t:includePartnerBuilt", namespaces=ns)) is not None: + obj.include_partner_built = string_to_bool(include_partner_built_element.text) + if (include_sandboxed_element := element.find("./t:includeSandboxed", namespaces=ns)) is not None: + obj.include_sandboxed = string_to_bool(include_sandboxed_element.text) + + safe_list = [] + for safe_extension_element in element.findall("./t:safeList", namespaces=ns): + url = safe_extension_element.find("./t:url", namespaces=ns) + full_data_allowed = safe_extension_element.find("./t:fullDataAllowed", namespaces=ns) + prompt_needed = safe_extension_element.find("./t:promptNeeded", namespaces=ns) + + safe_extension = SafeExtension( + url=url.text if url is not None else None, + full_data_allowed=string_to_bool(full_data_allowed.text) if full_data_allowed is not None else None, + prompt_needed=string_to_bool(prompt_needed.text) if prompt_needed is not None else None, + ) + safe_list.append(safe_extension) + + obj.safe_list = safe_list + return obj + + +@overload +def string_to_bool(s: str) -> bool: ... + + +@overload +def string_to_bool(s: None) -> None: ... + + +def string_to_bool(s): + return s.lower() == "true" if s is not None else None diff --git a/tableauserverclient/models/property_decorators.py b/tableauserverclient/models/property_decorators.py index 5048b349..0fcc9745 100644 --- a/tableauserverclient/models/property_decorators.py +++ b/tableauserverclient/models/property_decorators.py @@ -1,7 +1,7 @@ import datetime import re from functools import wraps -from typing import Any, Optional +from typing import Any, Optional, Tuple from collections.abc import Container from tableauserverclient.datetime_helpers import parse_datetime diff --git a/tableauserverclient/server/endpoint/__init__.py b/tableauserverclient/server/endpoint/__init__.py index 3c1266f9..d944bc42 100644 --- a/tableauserverclient/server/endpoint/__init__.py +++ b/tableauserverclient/server/endpoint/__init__.py @@ -6,6 +6,7 @@ from tableauserverclient.server.endpoint.datasources_endpoint import Datasources from tableauserverclient.server.endpoint.endpoint import Endpoint, QuerysetEndpoint from tableauserverclient.server.endpoint.exceptions import ServerResponseError, MissingRequiredFieldError +from tableauserverclient.server.endpoint.extensions_endpoint import Extensions from tableauserverclient.server.endpoint.favorites_endpoint import Favorites from tableauserverclient.server.endpoint.fileuploads_endpoint import Fileuploads from tableauserverclient.server.endpoint.flow_runs_endpoint import FlowRuns @@ -42,6 +43,7 @@ "QuerysetEndpoint", "MissingRequiredFieldError", "Endpoint", + "Extensions", "Favorites", "Fileuploads", "FlowRuns", diff --git a/tableauserverclient/server/endpoint/extensions_endpoint.py b/tableauserverclient/server/endpoint/extensions_endpoint.py new file mode 100644 index 00000000..ccef53de --- /dev/null +++ b/tableauserverclient/server/endpoint/extensions_endpoint.py @@ -0,0 +1,79 @@ +from tableauserverclient.models.extensions_item import ExtensionsServer, ExtensionsSiteSettings +from tableauserverclient.server.endpoint.endpoint import Endpoint +from tableauserverclient.server.endpoint.endpoint import api +from tableauserverclient.server.request_factory import RequestFactory + + +class Extensions(Endpoint): + def __init__(self, parent_srv): + super().__init__(parent_srv) + + @property + def _server_baseurl(self) -> str: + return f"{self.parent_srv.baseurl}/settings/extensions" + + @property + def baseurl(self) -> str: + return f"{self.parent_srv.baseurl}/sites/{self.parent_srv.site_id}/settings/extensions" + + @api(version="3.21") + def get_server_settings(self) -> ExtensionsServer: + """Lists the settings for extensions of a server + + Returns + ------- + ExtensionsServer + The server extensions settings + """ + response = self.get_request(self._server_baseurl) + return ExtensionsServer.from_response(response.content, self.parent_srv.namespace) + + @api(version="3.21") + def update_server_settings(self, extensions_server: ExtensionsServer) -> ExtensionsServer: + """Updates the settings for extensions of a server. Overwrites all existing settings. Any + sites omitted from the block list will be unblocked. + + Parameters + ---------- + extensions_server : ExtensionsServer + The server extensions settings to update + + Returns + ------- + ExtensionsServer + The updated server extensions settings + """ + req = RequestFactory.Extensions.update_server_extensions(extensions_server) + response = self.put_request(self._server_baseurl, req) + return ExtensionsServer.from_response(response.content, self.parent_srv.namespace) + + @api(version="3.21") + def get(self) -> ExtensionsSiteSettings: + """Lists the extensions settings for the site + + Returns + ------- + list[ExtensionsSiteSettings] + The site extensions settings + """ + response = self.get_request(self.baseurl) + return ExtensionsSiteSettings.from_response(response.content, self.parent_srv.namespace) + + @api(version="3.21") + def update(self, extensions_site_settings: ExtensionsSiteSettings) -> ExtensionsSiteSettings: + """Updates the extensions settings for the site. Overwrites all existing settings. + Any extensions omitted from the safe extensions list will be removed. + + Parameters + ---------- + extensions_site_settings : ExtensionsSiteSettings + The site extensions settings to update + + Returns + ------- + ExtensionsSiteSettings + The updated site extensions settings + """ + req = RequestFactory.Extensions.update_site_extensions(extensions_site_settings) + response = self.put_request(self.baseurl, req) + return ExtensionsSiteSettings.from_response(response.content, self.parent_srv.namespace) diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 877a18c3..b22fc606 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -1621,6 +1621,61 @@ def update_req(self, xml_request: ET.Element, oidc_item: SiteOIDCConfiguration) return ET.tostring(xml_request) +class ExtensionsRequest: + @_tsrequest_wrapped + def update_server_extensions(self, xml_request: ET.Element, extensions_server: "ExtensionsServer") -> None: + extensions_element = ET.SubElement(xml_request, "extensionsServerSettings") + if not isinstance(extensions_server.enabled, bool): + raise ValueError(f"Extensions Server missing enabled: {extensions_server}") + enabled_element = ET.SubElement(extensions_element, "extensionsGloballyEnabled") + enabled_element.text = str(extensions_server.enabled).lower() + + if extensions_server.block_list is None: + return + for blocked in extensions_server.block_list: + blocked_element = ET.SubElement(extensions_element, "blockList") + blocked_element.text = blocked + return + + @_tsrequest_wrapped + def update_site_extensions(self, xml_request: ET.Element, extensions_site_settings: ExtensionsSiteSettings) -> None: + ext_element = ET.SubElement(xml_request, "extensionsSiteSettings") + if not isinstance(extensions_site_settings.enabled, bool): + raise ValueError(f"Extensions Site Settings missing enabled: {extensions_site_settings}") + enabled_element = ET.SubElement(ext_element, "extensionsEnabled") + enabled_element.text = str(extensions_site_settings.enabled).lower() + if not isinstance(extensions_site_settings.use_default_setting, bool): + raise ValueError( + f"Extensions Site Settings missing use_default_setting: {extensions_site_settings.use_default_setting}" + ) + default_element = ET.SubElement(ext_element, "useDefaultSetting") + default_element.text = str(extensions_site_settings.use_default_setting).lower() + if extensions_site_settings.allow_trusted is not None: + allow_trusted_element = ET.SubElement(ext_element, "allowTrusted") + allow_trusted_element.text = str(extensions_site_settings.allow_trusted).lower() + if extensions_site_settings.include_sandboxed is not None: + include_sandboxed_element = ET.SubElement(ext_element, "includeSandboxed") + include_sandboxed_element.text = str(extensions_site_settings.include_sandboxed).lower() + if extensions_site_settings.include_tableau_built is not None: + include_tableau_built_element = ET.SubElement(ext_element, "includeTableauBuilt") + include_tableau_built_element.text = str(extensions_site_settings.include_tableau_built).lower() + if extensions_site_settings.include_partner_built is not None: + include_partner_built_element = ET.SubElement(ext_element, "includePartnerBuilt") + include_partner_built_element.text = str(extensions_site_settings.include_partner_built).lower() + + for safe in extensions_site_settings.safe_list or []: + safe_element = ET.SubElement(ext_element, "safeList") + if safe.url is not None: + url_element = ET.SubElement(safe_element, "url") + url_element.text = safe.url + if safe.full_data_allowed is not None: + full_data_element = ET.SubElement(safe_element, "fullDataAllowed") + full_data_element.text = str(safe.full_data_allowed).lower() + if safe.prompt_needed is not None: + prompt_element = ET.SubElement(safe_element, "promptNeeded") + prompt_element.text = str(safe.prompt_needed).lower() + + class RequestFactory: Auth = AuthRequest() Connection = Connection() @@ -1631,6 +1686,7 @@ class RequestFactory: Database = DatabaseRequest() DQW = DQWRequest() Empty = EmptyRequest() + Extensions = ExtensionsRequest() Favorite = FavoriteRequest() Fileupload = FileuploadRequest() Flow = FlowRequest() diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index 9202e3e6..b497e908 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -39,6 +39,7 @@ Tags, VirtualConnections, OIDC, + Extensions, ) from tableauserverclient.server.exceptions import ( ServerInfoEndpointNotFoundError, @@ -185,6 +186,7 @@ def __init__(self, server_address, use_server_version=False, http_options=None, self.tags = Tags(self) self.virtual_connections = VirtualConnections(self) self.oidc = OIDC(self) + self.extensions = Extensions(self) self._session = self._session_factory() self._http_options = dict() # must set this before making a server call diff --git a/test/assets/extensions_server_settings_false.xml b/test/assets/extensions_server_settings_false.xml new file mode 100644 index 00000000..16fd3e85 --- /dev/null +++ b/test/assets/extensions_server_settings_false.xml @@ -0,0 +1,6 @@ + + + + false + + diff --git a/test/assets/extensions_server_settings_true.xml b/test/assets/extensions_server_settings_true.xml new file mode 100644 index 00000000..c562d471 --- /dev/null +++ b/test/assets/extensions_server_settings_true.xml @@ -0,0 +1,8 @@ + + + + true + https://test.com + https://example.com + + diff --git a/test/assets/extensions_site_settings.xml b/test/assets/extensions_site_settings.xml new file mode 100644 index 00000000..2a62d299 --- /dev/null +++ b/test/assets/extensions_site_settings.xml @@ -0,0 +1,16 @@ + + + + true + false + true + false + >false + false + + http://localhost:9123/Dynamic.html + true + true + + + diff --git a/test/test_extensions.py b/test/test_extensions.py new file mode 100644 index 00000000..0b5a85ec --- /dev/null +++ b/test/test_extensions.py @@ -0,0 +1,134 @@ +from pathlib import Path +from xml.etree.ElementTree import Element + +from defusedxml.ElementTree import fromstring +import requests_mock +import pytest + +import tableauserverclient as TSC + + +TEST_ASSET_DIR = Path(__file__).parent / "assets" + +GET_SERVER_EXT_SETTINGS = TEST_ASSET_DIR / "extensions_server_settings_true.xml" +GET_SERVER_EXT_SETTINGS_FALSE = TEST_ASSET_DIR / "extensions_server_settings_false.xml" +GET_SITE_SETTINGS = TEST_ASSET_DIR / "extensions_site_settings.xml" + + +@pytest.fixture(scope="function") +def server() -> TSC.Server: + server = TSC.Server("http://test", False) + + # Fake sign in + server._site_id = "dad65087-b08b-4603-af4e-2887b8aafc67" + server._auth_token = "j80k54ll2lfMZ0tv97mlPvvSCRyD0DOM" + server.version = "3.21" + + return server + + +def test_get_server_extensions_settings(server: TSC.Server) -> None: + with requests_mock.mock() as m: + m.get(server.extensions._server_baseurl, text=GET_SERVER_EXT_SETTINGS.read_text()) + ext_settings = server.extensions.get_server_settings() + + assert ext_settings.enabled is True + assert ext_settings.block_list is not None + assert set(ext_settings.block_list) == {"https://test.com", "https://example.com"} + + +def test_get_server_extensions_settings_false(server: TSC.Server) -> None: + with requests_mock.mock() as m: + m.get(server.extensions._server_baseurl, text=GET_SERVER_EXT_SETTINGS_FALSE.read_text()) + ext_settings = server.extensions.get_server_settings() + + assert ext_settings.enabled is False + assert ext_settings.block_list is not None + assert len(ext_settings.block_list) == 0 + + +def test_update_server_extensions_settings(server: TSC.Server) -> None: + with requests_mock.mock() as m: + m.put(server.extensions._server_baseurl, text=GET_SERVER_EXT_SETTINGS_FALSE.read_text()) + + ext_settings = TSC.ExtensionsServer() + ext_settings.enabled = False + ext_settings.block_list = [] + + updated_settings = server.extensions.update_server_settings(ext_settings) + + assert updated_settings.enabled is False + assert updated_settings.block_list is not None + assert len(updated_settings.block_list) == 0 + + +def test_get_site_settings(server: TSC.Server) -> None: + with requests_mock.mock() as m: + m.get(server.extensions.baseurl, text=GET_SITE_SETTINGS.read_text()) + site_settings = server.extensions.get() + + assert isinstance(site_settings, TSC.ExtensionsSiteSettings) + assert site_settings.enabled is True + assert site_settings.use_default_setting is False + assert site_settings.safe_list is not None + assert site_settings.allow_trusted is True + assert site_settings.include_partner_built is False + assert site_settings.include_sandboxed is False + assert site_settings.include_tableau_built is False + assert len(site_settings.safe_list) == 1 + first_safe = site_settings.safe_list[0] + assert first_safe.url == "http://localhost:9123/Dynamic.html" + assert first_safe.full_data_allowed is True + assert first_safe.prompt_needed is True + + +def test_update_site_settings(server: TSC.Server) -> None: + with requests_mock.mock() as m: + m.put(server.extensions.baseurl, text=GET_SITE_SETTINGS.read_text()) + + site_settings = TSC.ExtensionsSiteSettings() + site_settings.enabled = True + site_settings.use_default_setting = False + safe_extension = TSC.SafeExtension( + url="http://localhost:9123/Dynamic.html", + full_data_allowed=True, + prompt_needed=True, + ) + site_settings.safe_list = [safe_extension] + + updated_settings = server.extensions.update(site_settings) + history = m.request_history + + assert isinstance(updated_settings, TSC.ExtensionsSiteSettings) + assert updated_settings.enabled is True + assert updated_settings.use_default_setting is False + assert updated_settings.safe_list is not None + assert len(updated_settings.safe_list) == 1 + first_safe = updated_settings.safe_list[0] + assert first_safe.url == "http://localhost:9123/Dynamic.html" + assert first_safe.full_data_allowed is True + assert first_safe.prompt_needed is True + + # Verify that the request body was as expected + assert len(history) == 1 + xml_payload = fromstring(history[0].body) + extensions_site_settings_elem = xml_payload.find(".//extensionsSiteSettings") + assert extensions_site_settings_elem is not None + enabled_elem = extensions_site_settings_elem.find("extensionsEnabled") + assert enabled_elem is not None + assert enabled_elem.text == "true" + use_default_elem = extensions_site_settings_elem.find("useDefaultSetting") + assert use_default_elem is not None + assert use_default_elem.text == "false" + safe_list_elements = list(extensions_site_settings_elem.findall("safeList")) + assert len(safe_list_elements) == 1 + safe_extension_elem = safe_list_elements[0] + url_elem = safe_extension_elem.find("url") + assert url_elem is not None + assert url_elem.text == "http://localhost:9123/Dynamic.html" + full_data_allowed_elem = safe_extension_elem.find("fullDataAllowed") + assert full_data_allowed_elem is not None + assert full_data_allowed_elem.text == "true" + prompt_needed_elem = safe_extension_elem.find("promptNeeded") + assert prompt_needed_elem is not None + assert prompt_needed_elem.text == "true"