From 0cbc08b8c4e460ad91caf6c7003535647f3d4c6e Mon Sep 17 00:00:00 2001 From: Jordan Woods <13803242+jorwoods@users.noreply.github.com> Date: Sat, 5 Jul 2025 07:22:48 -0500 Subject: [PATCH] feat: retrieve tableau server product name Closes #1592 Calls a VizPortal API to retrieve the detailed product name, and attaches it to the TSC.Server object which will make determining what requests to build for things like subscriptions easier. Adds this functionality into the pre-existing `use_server_version` function because: 1. Users of the library will already know to call that method. 2. Users already aware that method makes API calls. If the request errors or if the payload doesn't match the expected format, defaults to assuming TSC is talking with an on-prem Server instance. Co-authored-by: emeric-dsj --- .../server/endpoint/endpoint.py | 4 +-- .../server/endpoint/server_info_endpoint.py | 31 +++++++++++++++++-- tableauserverclient/server/server.py | 2 ++ .../getServerSettingsUnauthenticated.json | 1 + test/http/test_http_requests.py | 17 +++++++++- test/test_server_info.py | 25 +++++++++++++++ 6 files changed, 75 insertions(+), 5 deletions(-) create mode 100644 test/assets/getServerSettingsUnauthenticated.json diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index 21462af5f..6a3ea1913 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -189,8 +189,8 @@ def log_response_safely(self, server_response: "Response") -> str: loggable_response = helpers.strings.redact_xml(server_response.content.decode(server_response.encoding)) return loggable_response - def get_unauthenticated_request(self, url): - return self._make_request(self.parent_srv.session.get, url) + def get_unauthenticated_request(self, url, parameters=None): + return self._make_request(self.parent_srv.session.get, url, parameters=parameters) def get_request(self, url, request_object=None, parameters=None): if request_object is not None: diff --git a/tableauserverclient/server/endpoint/server_info_endpoint.py b/tableauserverclient/server/endpoint/server_info_endpoint.py index dc934496a..58e687fe7 100644 --- a/tableauserverclient/server/endpoint/server_info_endpoint.py +++ b/tableauserverclient/server/endpoint/server_info_endpoint.py @@ -1,5 +1,5 @@ import logging -from typing import Union +from typing import Literal, Union, TYPE_CHECKING from .endpoint import Endpoint, api from .exceptions import ServerResponseError @@ -9,10 +9,15 @@ ) from tableauserverclient.models import ServerInfoItem +if TYPE_CHECKING: + from tableauserverclient.server import Server + +Products = Literal["TableauServer", "TableauOnline"] + class ServerInfo(Endpoint): def __init__(self, server): - self.parent_srv = server + self.parent_srv: "Server" = server self._info = None @property @@ -80,3 +85,25 @@ def get(self) -> Union[ServerInfoItem, None]: logging.getLogger(self.__class__.__name__).debug(e) logging.getLogger(self.__class__.__name__).debug(server_response.content) return self._info + + def _get_product_info(self) -> Products: + """ + Retrieve the server product information to determine if the server is + Tableau Server or Tableau Online. + """ + method = "getServerSettingsUnauthenticated" + response = self.parent_srv.session.post( + f"{self.parent_srv.server_address}/vizportal/api/web/v1/{method}", + headers={"Content-Type": "application/json"}, + verify=self.parent_srv.http_options.get("verify", True), + json={"method": method, "params": {}}, + ) + if not response.ok: + return "TableauServer" + else: + try: + return response.json().get("result", {}).get("product", "TableauServer") + except Exception as e: + logging.getLogger(self.__class__.__name__).debug(e) + logging.getLogger(self.__class__.__name__).debug("Failed to parse product info response.") + return "TableauServer" diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index 9202e3e63..f91142644 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -145,6 +145,7 @@ def __init__(self, server_address, use_server_version=False, http_options=None, self._site_id = None self._user_id = None self._ssl_context = None + self._product = "TableauServer" # default product type # TODO: this needs to change to default to https, but without breaking existing code if not server_address.startswith("http://") and not server_address.startswith("https://"): @@ -269,6 +270,7 @@ def _determine_highest_version(self): def use_server_version(self): self.version = self._determine_highest_version() + self._product = self.server_info._get_product_info() def use_highest_version(self): self.use_server_version() diff --git a/test/assets/getServerSettingsUnauthenticated.json b/test/assets/getServerSettingsUnauthenticated.json new file mode 100644 index 000000000..9c3464353 --- /dev/null +++ b/test/assets/getServerSettingsUnauthenticated.json @@ -0,0 +1 @@ +{"result": {"product": "TableauOnline"}} diff --git a/test/http/test_http_requests.py b/test/http/test_http_requests.py index ce845502d..d96c4389b 100644 --- a/test/http/test_http_requests.py +++ b/test/http/test_http_requests.py @@ -27,6 +27,19 @@ def __init__(self, status_code): return MockResponse(200) +# This method will be used by the mock to replace requests.get +def mocked_requests_post(*args, **kwargs): + class MockResponse: + def __init__(self, status_code): + self.headers = {} + self.encoding = None + self.content = '{"result": {"product": "TableauOnline"}}' + self.status_code = status_code + self.ok = True + + return MockResponse(200) + + class ServerTests(unittest.TestCase): def test_init_server_model_empty_throws(self): with self.assertRaises(TypeError): @@ -46,7 +59,8 @@ def test_init_server_model_bad_server_name_not_version_check(self): server = TSC.Server("fake-url", use_server_version=False) @mock.patch("requests.sessions.Session.get", side_effect=mocked_requests_get) - def test_init_server_model_bad_server_name_do_version_check(self, mock_get): + @mock.patch("requests.sessions.Session.post", side_effect=mocked_requests_post) + def test_init_server_model_bad_server_name_do_version_check(self, mock_get, mock_post): server = TSC.Server("fake-url", use_server_version=True) def test_init_server_model_bad_server_name_not_version_check_random_options(self): @@ -114,4 +128,5 @@ def test_session_factory_adds_headers(self): test_request_bin = "http://capture-this-with-mock.com" with requests_mock.mock() as m: m.get(url="http://capture-this-with-mock.com/api/2.4/serverInfo", request_headers=SessionTests.test_header) + m.post(f"{test_request_bin}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) server = TSC.Server(test_request_bin, use_server_version=True, session_factory=SessionTests.session_factory) diff --git a/test/test_server_info.py b/test/test_server_info.py index fa1472c9a..eb5809fab 100644 --- a/test/test_server_info.py +++ b/test/test_server_info.py @@ -1,3 +1,4 @@ +import json import os.path import unittest @@ -13,6 +14,7 @@ SERVER_INFO_404 = os.path.join(TEST_ASSET_DIR, "server_info_404.xml") SERVER_INFO_AUTH_INFO_XML = os.path.join(TEST_ASSET_DIR, "server_info_auth_info.xml") SERVER_INFO_WRONG_SITE = os.path.join(TEST_ASSET_DIR, "server_info_wrong_site.html") +SERVER_PRODUCT_INFO = os.path.join(TEST_ASSET_DIR, "getServerSettingsUnauthenticated.json") class ServerInfoTests(unittest.TestCase): @@ -26,6 +28,7 @@ def test_server_info_get(self): response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: m.get(self.server.server_info.baseurl, text=response_xml) + m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) actual = self.server.server_info.get() self.assertEqual("10.1.0", actual.product_version) @@ -43,6 +46,8 @@ def test_server_info_use_highest_version_downgrades(self): # Return a 404 for serverInfo so we can pretend this is an old Server m.get(self.server.server_address + "/api/2.4/serverInfo", text=si_response_xml, status_code=404) m.get(self.server.server_address + "/auth?format=xml", text=auth_response_xml) + m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) + self.server.use_server_version() # does server-version[9.2] lookup in PRODUCT_TO_REST_VERSION self.assertEqual(self.server.version, "2.2") @@ -52,6 +57,7 @@ def test_server_info_use_highest_version_upgrades(self): si_response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: m.get(self.server.server_address + "/api/2.8/serverInfo", text=si_response_xml) + m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) # Pretend we're old self.server.version = "2.8" self.server.use_server_version() @@ -63,6 +69,7 @@ def test_server_use_server_version_flag(self): si_response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: m.get("http://test/api/2.4/serverInfo", text=si_response_xml) + m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) server = TSC.Server("http://test", use_server_version=True) self.assertEqual(server.version, "2.5") @@ -73,3 +80,21 @@ def test_server_wrong_site(self): m.get(self.server.server_info.baseurl, text=response, status_code=404) with self.assertRaises(NonXMLResponseError): self.server.server_info.get() + + def test_server_info_product(self): + with open(SERVER_PRODUCT_INFO) as f: + product_info_json = json.load(f) + + with requests_mock.mock() as m: + m.post( + f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", + json=product_info_json, + ) + self.server.use_server_version() + assert self.server._product == "TableauOnline" + + def test_server_info_product_no_response(self): + with requests_mock.mock() as m: + m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) + self.server.use_server_version() + assert self.server._product == "TableauServer"