From c09412eabd683672a5dca9a0b91aa582d2af9a1d Mon Sep 17 00:00:00 2001 From: ptiurin Date: Thu, 21 Aug 2025 17:04:28 +0100 Subject: [PATCH 01/11] working file caching --- src/firebolt/async_db/connection.py | 3 + src/firebolt/client/auth/base.py | 63 +++++--- .../client/auth/client_credentials.py | 15 -- src/firebolt/client/auth/service_account.py | 15 -- src/firebolt/client/auth/username_password.py | 15 -- src/firebolt/db/connection.py | 3 + src/firebolt/utils/cache.py | 150 +++++++++++++++++- src/firebolt/utils/file_operations.py | 114 +++++++++++++ 8 files changed, 309 insertions(+), 69 deletions(-) create mode 100644 src/firebolt/utils/file_operations.py diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index b63270ef7c1..5ad6938413d 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -237,6 +237,9 @@ async def connect( if not auth: raise ConfigurationError("auth is required to connect.") + if account_name: + auth._account_name = account_name + api_endpoint = fix_url_schema(api_endpoint) # Type checks assert auth is not None diff --git a/src/firebolt/client/auth/base.py b/src/firebolt/client/auth/base.py index 4c4132e3c2c..067836efe3e 100644 --- a/src/firebolt/client/auth/base.py +++ b/src/firebolt/client/auth/base.py @@ -8,8 +8,12 @@ from httpx import Auth as HttpxAuth from httpx import Request, Response, codes -from firebolt.utils.token_storage import TokenSecureStorage -from firebolt.utils.util import Timer, cached_property, get_internal_error_code +from firebolt.utils.cache import ( + ConnectionInfo, + SecureCacheKey, + _firebolt_cache, +) +from firebolt.utils.util import Timer, get_internal_error_code logger = logging.getLogger(__name__) @@ -38,6 +42,7 @@ class Auth(HttpxAuth): __slots__ = ( "_token", + "_account_name", "_expires", "_use_token_cache", ) @@ -47,7 +52,8 @@ class Auth(HttpxAuth): def __init__(self, use_token_cache: bool = True): self._use_token_cache = use_token_cache - self._token: Optional[str] = self._get_cached_token() + self._account_name: Optional[str] = None + self._token: Optional[str] = None self._expires: Optional[int] = None self._lock = Lock() @@ -103,36 +109,49 @@ def expired(self) -> bool: """ return self._expires is not None and self._expires <= int(time()) - @cached_property - def _token_storage(self) -> Optional[TokenSecureStorage]: - """Token filesystem cache storage. - - This is evaluated lazily, only if caching is enabled. - - Returns: - Optional[TokenSecureStorage]: Token filesystem cache storage if any - """ - return None - def _get_cached_token(self) -> Optional[str]: - """If caching is enabled, get token from filesystem cache. + """If caching is enabled, get token from cache. If caching is disabled, None is returned. Returns: Optional[str]: Token if any, and if caching is enabled; None otherwise """ - if not self._use_token_cache or not self._token_storage: + if not self._use_token_cache: return None - return self._token_storage.get_cached_token() + + cache_key = SecureCacheKey( + [self.principal, self.secret, self._account_name], self.secret + ) + connection_info = _firebolt_cache.get(cache_key) + + if connection_info and connection_info.token: + return connection_info.token + + return None def _cache_token(self) -> None: - """If caching isenabled, cache token to filesystem.""" - if not self._use_token_cache or not self._token_storage: + """If caching is enabled, cache token.""" + if not self._use_token_cache: return - # Only cache if token and expiration are retrieved - if self._token and self._expires: - self._token_storage.cache_token(self._token, self._expires) + # Only cache if token is retrieved + if self._token: + cache_key = SecureCacheKey( + [self.principal, self.secret, self._account_name], self.secret + ) + + # Get existing connection info or create new one + connection_info = _firebolt_cache.get(cache_key) + if connection_info is None: + connection_info = ConnectionInfo( + id="NONE" + ) # This is triggered first so there will be no id + + # Update token information + connection_info.token = self._token + + # Cache it + _firebolt_cache.set(cache_key, connection_info) @abstractmethod def get_new_token_generator(self) -> Generator[Request, Response, None]: diff --git a/src/firebolt/client/auth/client_credentials.py b/src/firebolt/client/auth/client_credentials.py index d729a581fb3..d08eae2798d 100644 --- a/src/firebolt/client/auth/client_credentials.py +++ b/src/firebolt/client/auth/client_credentials.py @@ -1,10 +1,6 @@ -from typing import Optional - from firebolt.client.auth.base import AuthRequest, FireboltAuthVersion from firebolt.client.auth.request_auth_base import _RequestBasedAuth -from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL -from firebolt.utils.util import cached_property class ClientCredentials(_RequestBasedAuth): @@ -79,17 +75,6 @@ def get_firebolt_version(self) -> FireboltAuthVersion: """ return FireboltAuthVersion.V2 - @cached_property - def _token_storage(self) -> Optional[TokenSecureStorage]: - """Token filesystem cache storage. - - This is evaluated lazily, only if caching is enabled - - Returns: - TokenSecureStorage: Token filesystem cache storage - """ - return TokenSecureStorage(username=self.client_id, password=self.client_secret) - def _make_auth_request(self) -> AuthRequest: """Get new token using username and password. diff --git a/src/firebolt/client/auth/service_account.py b/src/firebolt/client/auth/service_account.py index 7e5b91c4a4b..529b6f7b0ab 100644 --- a/src/firebolt/client/auth/service_account.py +++ b/src/firebolt/client/auth/service_account.py @@ -1,10 +1,6 @@ -from typing import Optional - from firebolt.client.auth.base import AuthRequest, FireboltAuthVersion from firebolt.client.auth.request_auth_base import _RequestBasedAuth -from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL -from firebolt.utils.util import cached_property class ServiceAccount(_RequestBasedAuth): @@ -77,17 +73,6 @@ def copy(self) -> "ServiceAccount": """ return ServiceAccount(self.client_id, self.client_secret, self._use_token_cache) - @cached_property - def _token_storage(self) -> Optional[TokenSecureStorage]: - """Token filesystem cache storage. - - This is evaluated lazily, only if caching is enabled - - Returns: - TokenSecureStorage: Token filesystem cache storage - """ - return TokenSecureStorage(username=self.client_id, password=self.client_secret) - def _make_auth_request(self) -> AuthRequest: """Get new token using username and password. diff --git a/src/firebolt/client/auth/username_password.py b/src/firebolt/client/auth/username_password.py index 29050641d16..672ee23470f 100644 --- a/src/firebolt/client/auth/username_password.py +++ b/src/firebolt/client/auth/username_password.py @@ -1,10 +1,6 @@ -from typing import Optional - from firebolt.client.auth.base import AuthRequest, FireboltAuthVersion from firebolt.client.auth.request_auth_base import _RequestBasedAuth -from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.urls import AUTH_URL -from firebolt.utils.util import cached_property class UsernamePassword(_RequestBasedAuth): @@ -77,17 +73,6 @@ def copy(self) -> "UsernamePassword": """ return UsernamePassword(self.username, self.password, self._use_token_cache) - @cached_property - def _token_storage(self) -> Optional[TokenSecureStorage]: - """Token filesystem cache storage. - - This is evaluated lazily, only if caching is enabled - - Returns: - TokenSecureStorage: Token filesystem cache storage - """ - return TokenSecureStorage(username=self.username, password=self.password) - def _make_auth_request(self) -> AuthRequest: """Get new token using username and password. diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 72021ce07ff..14eada03a14 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -66,6 +66,9 @@ def connect( if not auth: raise ConfigurationError("auth is required to connect.") + if account_name: + auth._account_name = account_name + api_endpoint = fix_url_schema(api_endpoint) # Type checks assert auth is not None diff --git a/src/firebolt/utils/cache.py b/src/firebolt/utils/cache.py index e8f0d8bd9d1..43299009ab5 100644 --- a/src/firebolt/utils/cache.py +++ b/src/firebolt/utils/cache.py @@ -1,6 +1,11 @@ +import logging import os import time -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field +from json import JSONDecodeError +from json import dumps as json_dumps +from json import loads as json_loads +from os import makedirs, path from typing import ( Any, Callable, @@ -12,10 +17,21 @@ TypeVar, ) +from appdirs import user_data_dir + +from firebolt.utils.file_operations import ( + FernetEncrypter, + generate_encrypted_file_name, + generate_salt, +) + T = TypeVar("T") # Cache expiry configuration CACHE_EXPIRY_SECONDS = 3600 # 1 hour +APPNAME = "firebolt" + +logger = logging.getLogger(__name__) class ReprCacheable(Protocol): @@ -47,6 +63,22 @@ class ConnectionInfo: system_engine: Optional[EngineInfo] = None databases: Dict[str, DatabaseInfo] = field(default_factory=dict) engines: Dict[str, EngineInfo] = field(default_factory=dict) + token: Optional[str] = None + + def __post_init__(self) -> None: + """ + Post-initialization processing to convert dicts to dataclasses. + """ + if self.system_engine and isinstance(self.system_engine, dict): + self.system_engine = EngineInfo(**self.system_engine) + self.databases = { + k: DatabaseInfo(**v) + for k, v in self.databases.items() + if isinstance(v, dict) + } + self.engines = { + k: EngineInfo(**v) for k, v in self.engines.items() if isinstance(v, dict) + } def noop_if_disabled(func: Callable) -> Callable: @@ -150,4 +182,118 @@ def __hash__(self) -> int: return hash(self.key) -_firebolt_cache = UtilCache[ConnectionInfo](cache_name="connection_info") +class FileBasedCache(UtilCache[ConnectionInfo]): + """ + File-based cache that persists to disk with encryption. + Extends UtilCache to provide persistent storage using encrypted files. + """ + + def __init__(self, cache_name: str = ""): + super().__init__(cache_name) + self._data_dir = user_data_dir(appname=APPNAME) # TODO: change to new dir + makedirs(self._data_dir, exist_ok=True) + + def _get_file_path(self, key: SecureCacheKey) -> str: + """Get the file path for a cache key.""" + cache_key = self.create_key(key) + encrypted_filename = generate_encrypted_file_name(cache_key, key.encryption_key) + return path.join(self._data_dir, encrypted_filename) + + def _read_data_json(self, file_path: str, encrypter: FernetEncrypter) -> dict: + """Read and decrypt JSON data from file.""" + if not path.exists(file_path): + return {} + + try: + with open(file_path, "r") as f: + encrypted_data = f.read() + + decrypted_data = encrypter.decrypt(encrypted_data) + if decrypted_data is None: + logger.debug("Decryption failed for %s", file_path) + return {} + + return json_loads(decrypted_data) if decrypted_data else {} + except (JSONDecodeError, IOError) as e: + logger.debug( + "Failed to read or decode data from %s error: %s", file_path, e + ) + return {} + + def _write_data_json( + self, file_path: str, data: dict, encrypter: FernetEncrypter + ) -> None: + """Encrypt and write JSON data to file.""" + try: + json_str = json_dumps(data) + logger.debug("Writing data to %s", file_path) + encrypted_data = encrypter.encrypt(json_str) + with open(file_path, "w") as f: + f.write(encrypted_data) + except (IOError, OSError) as e: + # Silently proceed if we can't write to disk + logger.debug("Failed to write data to %s error: %s", file_path, e) + + def get(self, key: SecureCacheKey) -> Optional[ConnectionInfo]: + """Get value from cache, checking both memory and disk.""" + if self.disabled: + return None + + # First try memory cache + memory_result = super().get(key) + if memory_result is not None: + logger.debug("Cache hit in memory") + return memory_result + + # If not in memory, try to load from disk + file_path = self._get_file_path(key) + encrypter = FernetEncrypter(generate_salt(), key.encryption_key) + raw_data = self._read_data_json(file_path, encrypter) + if not raw_data: + return None + logger.debug("Cache hit on disk") + data = ConnectionInfo(**raw_data) + + # Add to memory cache and return + super().set(key, data) + return data + + def set(self, key: SecureCacheKey, value: ConnectionInfo) -> None: + """Set value in both memory and disk cache.""" + if self.disabled: + return + + logger.debug("Setting value in cache") + # First set in memory + super().set(key, value) + + file_path = self._get_file_path(key) + encrypter = FernetEncrypter(generate_salt(), key.encryption_key) + data = asdict(value) + + self._write_data_json(file_path, data, encrypter) + + def delete(self, key: SecureCacheKey) -> None: + """Delete value from both memory and disk cache.""" + if self.disabled: + return + + # Delete from memory + super().delete(key) + + # Delete from disk + file_path = self._get_file_path(key) + try: + if path.exists(file_path): + os.remove(file_path) + except OSError: + logger.debug("Failed to delete file %s", file_path) + # Silently proceed if we can't delete the file + + def clear(self) -> None: + # Clear memory only, as deleting every file is not safe + logger.debug("Clearing memory cache") + super().clear() + + +_firebolt_cache = FileBasedCache(cache_name="connection_info") diff --git a/src/firebolt/utils/file_operations.py b/src/firebolt/utils/file_operations.py new file mode 100644 index 00000000000..be08dcbddcb --- /dev/null +++ b/src/firebolt/utils/file_operations.py @@ -0,0 +1,114 @@ +from base64 import b64decode, urlsafe_b64encode +from hashlib import sha256 +from typing import Optional + +from cryptography.fernet import Fernet, InvalidToken +from cryptography.hazmat.backends import default_backend # type: ignore +from cryptography.hazmat.primitives import hashes, padding # type: ignore +from cryptography.hazmat.primitives.ciphers import ( # type: ignore + Cipher, + algorithms, + modes, +) +from cryptography.hazmat.primitives.kdf.pbkdf2 import ( + PBKDF2HMAC, # type: ignore +) + + +class FernetEncrypter: + """PBKDF2HMAC based encrypter. + + Username and password combination is used as a key. + + Args: + salt (str): Salt value for encryption + username: Username for key + password: Password for key + """ + + def __init__(self, salt: str, encryption_key: str): + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + salt=b64decode(salt), + length=32, + iterations=39000, + backend=default_backend(), + ) + self.fernet = Fernet( + urlsafe_b64encode(kdf.derive(bytes(encryption_key, encoding="utf-8"))) + ) + + def encrypt(self, data: str) -> str: + """Encrypt data string. + + Args: + data (str): Data for encryption + + Returns: + str: Encrypted data + + """ + return self.fernet.encrypt(bytes(data, encoding="utf-8")).decode("utf-8") + + def decrypt(self, data: str) -> Optional[str]: + """Decrypt encrypted data. + + Args: + data (str): Encrypted data + + Returns: + Optional[str]: Decrypted data + + """ + try: + return self.fernet.decrypt(bytes(data, encoding="utf-8")).decode("utf-8") + except InvalidToken: + return None + + +def generate_salt() -> str: + """Generate salt for FernetEncrypter. + + Returns: + str: Generated salt + """ + return "salt" + + +def generate_encrypted_file_name(cache_key: str, encryption_key: str) -> str: + """Generate encrypted file name from cache key using AES encryption. + + Args: + cache_key (str): The cache key to encrypt + encryption_key (str): The encryption key + + Returns: + str: Base64URL encoded AES encrypted filename ending in .txt + """ + # Derive a 256-bit key from the encryption_key using PBKDF2 + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + salt=b"firebolt_cache_salt", # Fixed salt for deterministic key derivation + length=32, # 256 bits + iterations=10000, + backend=default_backend(), + ) + aes_key = kdf.derive(encryption_key.encode("utf-8")) + + # Pad the cache_key to be a multiple of 16 bytes (AES block size) + padder = padding.PKCS7(128).padder() + padded_data = padder.update(cache_key.encode("utf-8")) + padded_data += padder.finalize() + + # Use a fixed IV for deterministic encryption + # (same input always produces same output) + # This is acceptable for cache file names where we need deterministic results + iv = sha256(cache_key.encode("utf-8")).digest()[:16] + + # Encrypt the padded cache_key + cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv), backend=default_backend()) + encryptor = cipher.encryptor() + encrypted_data = encryptor.update(padded_data) + encryptor.finalize() + + # Base64URL encode the encrypted data and add .txt extension + return urlsafe_b64encode(encrypted_data).decode("ascii").rstrip("=") + ".txt" From eed8a071b4ddca54e0c48d274df79e3a2b6e2f4c Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 22 Aug 2025 13:56:27 +0100 Subject: [PATCH 02/11] Changes after testing --- src/firebolt/async_db/connection.py | 3 +-- src/firebolt/client/auth/base.py | 21 ++++++++++++++++----- src/firebolt/db/connection.py | 3 +-- src/firebolt/utils/cache.py | 26 ++++++++++++++++++-------- 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 5ad6938413d..15fe127fc9d 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -237,8 +237,7 @@ async def connect( if not auth: raise ConfigurationError("auth is required to connect.") - if account_name: - auth._account_name = account_name + auth.account = account_name api_endpoint = fix_url_schema(api_endpoint) # Type checks diff --git a/src/firebolt/client/auth/base.py b/src/firebolt/client/auth/base.py index 067836efe3e..d17e4746b48 100644 --- a/src/firebolt/client/auth/base.py +++ b/src/firebolt/client/auth/base.py @@ -2,7 +2,7 @@ from abc import abstractmethod from enum import IntEnum from time import time -from typing import AsyncGenerator, Generator, Optional +from typing import AsyncGenerator, Generator, Optional, Tuple from anyio import Lock from httpx import Auth as HttpxAuth @@ -57,6 +57,17 @@ def __init__(self, use_token_cache: bool = True): self._expires: Optional[int] = None self._lock = Lock() + @property + def account(self) -> Optional[str]: + return self._account_name + + @account.setter + def account(self, value: str) -> None: + self._account_name = value + # Now we have all the elements to fetch the cached token + if not self._token: + self._token, self._expires = self._get_cached_token() + def copy(self) -> "Auth": """Make another auth object with same credentials. @@ -109,7 +120,7 @@ def expired(self) -> bool: """ return self._expires is not None and self._expires <= int(time()) - def _get_cached_token(self) -> Optional[str]: + def _get_cached_token(self) -> Tuple[Optional[str], Optional[int]]: """If caching is enabled, get token from cache. If caching is disabled, None is returned. @@ -118,7 +129,7 @@ def _get_cached_token(self) -> Optional[str]: Optional[str]: Token if any, and if caching is enabled; None otherwise """ if not self._use_token_cache: - return None + return (None, None) cache_key = SecureCacheKey( [self.principal, self.secret, self._account_name], self.secret @@ -126,9 +137,9 @@ def _get_cached_token(self) -> Optional[str]: connection_info = _firebolt_cache.get(cache_key) if connection_info and connection_info.token: - return connection_info.token + return (connection_info.token, connection_info.expiry_time) - return None + return (None, None) def _cache_token(self) -> None: """If caching is enabled, cache token.""" diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index 14eada03a14..09cc40bc868 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -66,8 +66,7 @@ def connect( if not auth: raise ConfigurationError("auth is required to connect.") - if account_name: - auth._account_name = account_name + auth.account = account_name api_endpoint = fix_url_schema(api_endpoint) # Type checks diff --git a/src/firebolt/utils/cache.py b/src/firebolt/utils/cache.py index 43299009ab5..2988e76981a 100644 --- a/src/firebolt/utils/cache.py +++ b/src/firebolt/utils/cache.py @@ -71,14 +71,24 @@ def __post_init__(self) -> None: """ if self.system_engine and isinstance(self.system_engine, dict): self.system_engine = EngineInfo(**self.system_engine) - self.databases = { - k: DatabaseInfo(**v) - for k, v in self.databases.items() - if isinstance(v, dict) - } - self.engines = { - k: EngineInfo(**v) for k, v in self.engines.items() if isinstance(v, dict) - } + + # Convert dict values to dataclasses, keep existing dataclass objects + new_databases = {} + for k, db in self.databases.items(): + if isinstance(db, dict): + new_databases[k] = DatabaseInfo(**db) + else: + new_databases[k] = db + self.databases = new_databases + + # Convert dict values to dataclasses, keep existing dataclass objects + new_engines = {} + for k, engine in self.engines.items(): + if isinstance(engine, dict): + new_engines[k] = EngineInfo(**engine) + else: + new_engines[k] = engine + self.engines = new_engines def noop_if_disabled(func: Callable) -> Callable: From 60403ddd50c7dd7c017fb4c515011eb927572942 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 22 Aug 2025 15:38:42 +0100 Subject: [PATCH 03/11] unit tests --- tests/unit/V1/async_db/test_connection.py | 17 ++-- tests/unit/V1/client/test_client.py | 11 ++- tests/unit/V1/db/test_connection.py | 17 ++-- .../unit/V1/service/test_resource_manager.py | 17 ++-- tests/unit/async_db/test_connection.py | 24 +++-- tests/unit/client/auth/test_auth.py | 43 +++++---- tests/unit/client/auth/test_auth_async.py | 45 ++++----- tests/unit/client/test_client.py | 11 ++- tests/unit/db/test_connection.py | 25 +++-- tests/unit/service/test_resource_manager.py | 22 +++-- tests/unit/test_cache_helpers.py | 63 ++++++++++++ tests/unit/utils/test_cache.py | 63 ++++++++++++ tests/unit/utils/test_file_operations.py | 96 +++++++++++++++++++ 13 files changed, 357 insertions(+), 97 deletions(-) create mode 100644 tests/unit/test_cache_helpers.py create mode 100644 tests/unit/utils/test_file_operations.py diff --git a/tests/unit/V1/async_db/test_connection.py b/tests/unit/V1/async_db/test_connection.py index 908d665e4da..51791367127 100644 --- a/tests/unit/V1/async_db/test_connection.py +++ b/tests/unit/V1/async_db/test_connection.py @@ -12,14 +12,15 @@ from firebolt.async_db.connection import Connection, connect from firebolt.client.auth import Auth, Token, UsernamePassword from firebolt.common._types import ColType +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundError, ConfigurationError, ConnectionClosedError, FireboltEngineError, ) -from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.urls import ACCOUNT_ENGINE_ID_BY_NAME_URL +from tests.unit.test_cache_helpers import get_cached_token async def test_closed_connection(connection: Connection) -> None: @@ -303,6 +304,7 @@ async def test_connection_token_caching( access_token: str, account_id_callback: Callable, account_id_url: str, + enable_cache: Callable, ) -> None: httpx_mock.add_callback(check_credentials_callback, url=auth_url, is_reusable=True) httpx_mock.add_callback( @@ -331,9 +333,11 @@ async def test_connection_token_caching( assert await connection.cursor().execute("select*") == len( python_query_data ) - ts = TokenSecureStorage(username=user, password=password) - assert ts.get_cached_token() == access_token, "Invalid token value cached" + # Verify token was cached using the new cache system + cached_token = get_cached_token(user, password, account_name) + assert cached_token == access_token, "Invalid token value cached" + _firebolt_cache.clear() # Do the same, but with use_token_cache=False with Patcher(): async with await connect( @@ -350,10 +354,9 @@ async def test_connection_token_caching( assert await connection.cursor().execute("select*") == len( python_query_data ) - ts = TokenSecureStorage(username=user, password=password) - assert ( - ts.get_cached_token() is None - ), "Token is cached even though caching is disabled" + # Verify token was not cached when caching is disabled + cached_token = get_cached_token(user, password, account_name) + assert cached_token is None, "Token is cached even though caching is disabled" async def test_connect_with_auth( diff --git a/tests/unit/V1/client/test_client.py b/tests/unit/V1/client/test_client.py index 763f7a4a5bf..7f4d70e03b9 100644 --- a/tests/unit/V1/client/test_client.py +++ b/tests/unit/V1/client/test_client.py @@ -10,9 +10,9 @@ from firebolt.client import ClientV1 as Client from firebolt.client.auth import Token, UsernamePassword from firebolt.client.resource_manager_hooks import raise_on_4xx_5xx -from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.urls import AUTH_URL from firebolt.utils.util import fix_url_schema +from tests.unit.test_cache_helpers import cache_token def test_client_retry( @@ -130,16 +130,19 @@ def test_refresh_with_hooks( test_username: str, test_password: str, test_token: str, + enable_cache: Callable, ) -> None: """ When hooks are used, the invalid token, fetched from cache, is refreshed """ - tss = TokenSecureStorage(test_username, test_password) - tss.cache_token(test_token, 2**32) + cache_token(test_username, test_password, test_token, 2**32) + auth = UsernamePassword(test_username, test_password) + # Simulate what connect() would do + auth.account = None client = Client( - auth=UsernamePassword(test_username, test_password), + auth=auth, event_hooks={ "response": [raise_on_4xx_5xx], }, diff --git a/tests/unit/V1/db/test_connection.py b/tests/unit/V1/db/test_connection.py index 1b47a3e8043..89621c5ed0c 100644 --- a/tests/unit/V1/db/test_connection.py +++ b/tests/unit/V1/db/test_connection.py @@ -13,14 +13,15 @@ from firebolt.common._types import ColType from firebolt.db import Connection, connect from firebolt.db.cursor import CursorV1 as Cursor +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import ( AccountNotFoundError, ConfigurationError, ConnectionClosedError, FireboltEngineError, ) -from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.urls import ACCOUNT_ENGINE_ID_BY_NAME_URL +from tests.unit.test_cache_helpers import get_cached_token def test_closed_connection(connection: Connection) -> None: @@ -272,6 +273,7 @@ def test_connection_token_caching( access_token: str, account_id_callback: Callable, account_id_url: str, + enable_cache: Callable, ) -> None: httpx_mock.add_callback(check_credentials_callback, url=auth_url, is_reusable=True) httpx_mock.add_callback( @@ -294,9 +296,11 @@ def test_connection_token_caching( api_endpoint=api_endpoint, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) - ts = TokenSecureStorage(username=user, password=password) - assert ts.get_cached_token() == access_token, "Invalid token value cached" + # Verify token was cached using the new cache system + cached_token = get_cached_token(user, password, account_name) + assert cached_token == access_token, "Invalid token value cached" + _firebolt_cache.clear() # Do the same, but with use_token_cache=False with Patcher(): with connect( @@ -307,10 +311,9 @@ def test_connection_token_caching( api_endpoint=api_endpoint, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) - ts = TokenSecureStorage(username=user, password=password) - assert ( - ts.get_cached_token() is None - ), "Token is cached even though caching is disabled" + # Verify token was not cached when caching is disabled + cached_token = get_cached_token(user, password, account_name) + assert cached_token is None, "Token is cached even though caching is disabled" def test_connect_with_auth( diff --git a/tests/unit/V1/service/test_resource_manager.py b/tests/unit/V1/service/test_resource_manager.py index 0c1d6994a01..32d0d119eeb 100644 --- a/tests/unit/V1/service/test_resource_manager.py +++ b/tests/unit/V1/service/test_resource_manager.py @@ -8,8 +8,9 @@ from firebolt.client.auth import Auth, Token, UsernamePassword from firebolt.common.settings import Settings from firebolt.service.manager import ResourceManager +from firebolt.utils.cache import _firebolt_cache from firebolt.utils.exception import AccountNotFoundError -from firebolt.utils.token_storage import TokenSecureStorage +from tests.unit.test_cache_helpers import get_cached_token def test_rm_credentials( @@ -82,6 +83,7 @@ def test_rm_token_cache( account_id_url: Pattern, account_id_callback: Callable, access_token: str, + enable_cache: Callable, ) -> None: """Credentials, that are passed to rm are processed properly.""" url = "https://url" @@ -112,9 +114,11 @@ def test_rm_token_cache( rm = ResourceManager(local_settings) rm._client.get(url) - ts = TokenSecureStorage(user, password) - assert ts.get_cached_token() == access_token, "Invalid token value cached" + # Verify token was cached using the new cache system + cached_token = get_cached_token(user, password) + assert cached_token == access_token, "Invalid token value cached" + _firebolt_cache.clear() # Do the same, but with use_token_cache=False with Patcher(): local_settings = Settings( @@ -125,10 +129,9 @@ def test_rm_token_cache( rm = ResourceManager(local_settings) rm._client.get(url) - ts = TokenSecureStorage(user, password) - assert ( - ts.get_cached_token() is None - ), "Token is cached even though caching is disabled" + # Verify token was not cached when caching is disabled + cached_token = get_cached_token(user, password) + assert cached_token is None, "Token is cached even though caching is disabled" def test_rm_invalid_account_name( diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 416c65f93f3..89fb5bf4551 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -16,7 +16,7 @@ ConnectionClosedError, FireboltError, ) -from firebolt.utils.token_storage import TokenSecureStorage +from tests.unit.test_cache_helpers import get_cached_token @mark.skip("__slots__ is broken on Connection class") @@ -213,6 +213,8 @@ def system_engine_callback_counter(request, **kwargs): else: assert system_engine_call_counter != 1, "System engine URL was cached" + _firebolt_cache.clear() + async def test_connect_engine_failed( db_name: str, @@ -307,6 +309,7 @@ async def test_connection_token_caching( python_query_data: List[List[ColType]], mock_connection_flow: Callable, mock_query: Callable, + enable_cache: Callable, ) -> None: mock_connection_flow() mock_query() @@ -323,9 +326,11 @@ async def test_connection_token_caching( assert await connection.cursor().execute("select*") == len( python_query_data ) - ts = TokenSecureStorage(username=client_id, password=client_secret) - assert ts.get_cached_token() == access_token, "Invalid token value cached" + # Verify token was cached using the new cache system + cached_token = get_cached_token(client_id, client_secret, account_name) + assert cached_token == access_token, "Invalid token value cached" + _firebolt_cache.clear() with Patcher(): async with await connect( database=db_name, @@ -337,9 +342,11 @@ async def test_connection_token_caching( assert await connection.cursor().execute("select*") == len( python_query_data ) - ts = TokenSecureStorage(username=client_id, password=client_secret) - assert ts.get_cached_token() == access_token, "Invalid token value cached" + # Verify token was cached using the new cache system (second check) + cached_token = get_cached_token(client_id, client_secret, account_name) + assert cached_token == access_token, "Invalid token value cached" + _firebolt_cache.clear() # Do the same, but with use_token_cache=False with Patcher(): async with await connect( @@ -352,10 +359,9 @@ async def test_connection_token_caching( assert await connection.cursor().execute("select*") == len( python_query_data ) - ts = TokenSecureStorage(username=client_id, password=client_secret) - assert ( - ts.get_cached_token() is None - ), "Token is cached even though caching is disabled" + # Verify token was not cached when caching is disabled + cached_token = get_cached_token(client_id, client_secret, account_name) + assert cached_token is None, "Token is cached even though caching is disabled" async def test_connect_with_user_agent( diff --git a/tests/unit/client/auth/test_auth.py b/tests/unit/client/auth/test_auth.py index 4a372b0e084..0f63a0039f7 100644 --- a/tests/unit/client/auth/test_auth.py +++ b/tests/unit/client/auth/test_auth.py @@ -1,13 +1,14 @@ from types import MethodType -from unittest.mock import PropertyMock, patch +from typing import Generator from httpx import Request, codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark from pytest_httpx import HTTPXMock -from firebolt.client.auth import Auth -from firebolt.utils.token_storage import TokenSecureStorage +from firebolt.client.auth import Auth, ClientCredentials +from firebolt.utils.cache import _firebolt_cache +from tests.unit.test_cache_helpers import get_cached_token from tests.unit.util import execute_generator_requests @@ -89,6 +90,7 @@ def test_auth_token_storage( client_id: str, client_secret: str, access_token: str, + enable_cache: Generator, ) -> None: # Mock auth flow def set_token(token: str) -> callable: @@ -101,29 +103,28 @@ def inner(self): url = "https://host" httpx_mock.add_response(status_code=codes.OK, url=url, is_reusable=True) - with Patcher(), patch( - "firebolt.client.auth.base.Auth._token_storage", - new_callable=PropertyMock, - return_value=TokenSecureStorage(client_id, client_secret), - ): - auth = Auth(use_token_cache=True) + + # Test with caching enabled + with Patcher(): + auth = ClientCredentials(client_id, client_secret, use_token_cache=True) # Get token auth.get_new_token_generator = MethodType(set_token(access_token), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - st = TokenSecureStorage(client_id, client_secret) - assert st.get_cached_token() == access_token, "Invalid token value cached" + # Verify token was cached using the new cache system + cached_token = get_cached_token(client_id, client_secret, None) + assert cached_token == access_token, "Invalid token value cached" + + # Clear cache before second test + _firebolt_cache.clear() - with Patcher(), patch( - "firebolt.client.auth.base.Auth._token_storage", - new_callable=PropertyMock, - return_value=TokenSecureStorage(client_id, client_secret), - ): - auth = Auth(use_token_cache=False) + # Test with caching disabled + with Patcher(): + auth = ClientCredentials(client_id, client_secret, use_token_cache=False) # Get token auth.get_new_token_generator = MethodType(set_token(access_token), auth) execute_generator_requests(auth.auth_flow(Request("GET", url))) - st = TokenSecureStorage(client_id, client_secret) - assert ( - st.get_cached_token() is None - ), "Token cached even though caching is disabled" + + # Verify token was not cached + cached_token = get_cached_token(client_id, client_secret, None) + assert cached_token is None, "Token cached even though caching is disabled" diff --git a/tests/unit/client/auth/test_auth_async.py b/tests/unit/client/auth/test_auth_async.py index f8ca484ccbe..ac766dde926 100644 --- a/tests/unit/client/auth/test_auth_async.py +++ b/tests/unit/client/auth/test_auth_async.py @@ -1,13 +1,14 @@ from types import MethodType -from unittest.mock import PropertyMock, patch +from typing import Generator from httpx import Request, codes from pyfakefs.fake_filesystem_unittest import Patcher from pytest import mark from pytest_httpx import HTTPXMock -from firebolt.client import Auth -from firebolt.utils.token_storage import TokenSecureStorage +from firebolt.client.auth import Auth, ClientCredentials +from firebolt.utils.cache import _firebolt_cache +from tests.unit.test_cache_helpers import get_cached_token from tests.unit.util import async_execute_generator_requests @@ -91,7 +92,8 @@ async def test_auth_token_storage( httpx_mock: HTTPXMock, client_id: str, client_secret: str, - access_token, + access_token: str, + enable_cache: Generator, ) -> None: # Mock auth flow def set_token(token: str) -> callable: @@ -104,33 +106,32 @@ def inner(self): url = "https://host" httpx_mock.add_response(status_code=codes.OK, url=url, is_reusable=True) - with Patcher(), patch( - "firebolt.client.auth.base.Auth._token_storage", - new_callable=PropertyMock, - return_value=TokenSecureStorage(client_id, client_secret), - ): - auth = Auth(use_token_cache=True) + + # Test with caching enabled + with Patcher(): + auth = ClientCredentials(client_id, client_secret, use_token_cache=True) # Get token auth.get_new_token_generator = MethodType(set_token(access_token), auth) await async_execute_generator_requests( auth.async_auth_flow(Request("GET", url)) ) - st = TokenSecureStorage(client_id, client_secret) - assert st.get_cached_token() == access_token, "Invalid token value cached" + # Verify token was cached using the new cache system + cached_token = get_cached_token(client_id, client_secret, None) + assert cached_token == access_token, "Invalid token value cached" + + # Clear cache before second test + _firebolt_cache.clear() - with Patcher(), patch( - "firebolt.client.auth.base.Auth._token_storage", - new_callable=PropertyMock, - return_value=TokenSecureStorage(client_id, client_secret), - ): - auth = Auth(use_token_cache=False) + # Test with caching disabled + with Patcher(): + auth = ClientCredentials(client_id, client_secret, use_token_cache=False) # Get token auth.get_new_token_generator = MethodType(set_token(access_token), auth) await async_execute_generator_requests( auth.async_auth_flow(Request("GET", url)) ) - st = TokenSecureStorage(client_id, client_secret) - assert ( - st.get_cached_token() is None - ), "Token cached even though caching is disabled" + + # Verify token was not cached + cached_token = get_cached_token(client_id, client_secret, None) + assert cached_token is None, "Token cached even though caching is disabled" diff --git a/tests/unit/client/test_client.py b/tests/unit/client/test_client.py index 1900151c67d..33b9b87f479 100644 --- a/tests/unit/client/test_client.py +++ b/tests/unit/client/test_client.py @@ -8,9 +8,9 @@ from firebolt.client import ClientV2 as Client from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.resource_manager_hooks import raise_on_4xx_5xx -from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL from tests.unit.conftest import Response +from tests.unit.test_cache_helpers import cache_token def test_client_retry( @@ -99,17 +99,20 @@ def test_refresh_with_hooks( client_id: str, client_secret: str, access_token: str, + enable_cache: Callable, ) -> None: """ When hooks are used, the invalid token, fetched from cache, is refreshed """ - tss = TokenSecureStorage(client_id, client_secret) - tss.cache_token(access_token, 2**32) + cache_token(client_id, client_secret, access_token, 2**32, account_name) + + auth = ClientCredentials(client_id, client_secret) + auth.account = account_name client = Client( account_name=account_name, - auth=ClientCredentials(client_id, client_secret), + auth=auth, event_hooks={ "response": [raise_on_4xx_5xx], }, diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 148589c40cf..57677ce6cb4 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -20,7 +20,7 @@ ConnectionClosedError, FireboltError, ) -from firebolt.utils.token_storage import TokenSecureStorage +from tests.unit.test_cache_helpers import get_cached_token def test_connection_attributes(connection: Connection) -> None: @@ -262,6 +262,8 @@ def system_engine_callback_counter(request, **kwargs): else: assert system_engine_call_counter != 1, "System engine URL was cached" + _firebolt_cache.clear() + def test_connect_system_engine_404( db_name: str, @@ -331,6 +333,7 @@ def test_connection_token_caching( python_query_data: List[List[ColType]], mock_connection_flow: Callable, mock_query: Callable, + enable_cache: Callable, ) -> None: mock_connection_flow() mock_query() @@ -345,8 +348,11 @@ def test_connection_token_caching( api_endpoint=api_endpoint, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) - ts = TokenSecureStorage(username=client_id, password=client_secret) - assert ts.get_cached_token() == access_token, "Invalid token value cached" + # Verify token was cached using the new cache system + cached_token = get_cached_token(client_id, client_secret, account_name) + assert cached_token == access_token, "Invalid token value cached" + + _firebolt_cache.clear() with Patcher(): with connect( @@ -357,9 +363,11 @@ def test_connection_token_caching( api_endpoint=api_endpoint, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) - ts = TokenSecureStorage(username=client_id, password=client_secret) - assert ts.get_cached_token() == access_token, "Invalid token value cached" + # Verify token was cached using the new cache system (second test) + cached_token = get_cached_token(client_id, client_secret, account_name) + assert cached_token == access_token, "Invalid token value cached" + _firebolt_cache.clear() # Do the same, but with use_token_cache=False with Patcher(): with connect( @@ -370,10 +378,9 @@ def test_connection_token_caching( api_endpoint=api_endpoint, ) as connection: assert connection.cursor().execute("select*") == len(python_query_data) - ts = TokenSecureStorage(username=client_id, password=client_secret) - assert ( - ts.get_cached_token() is None - ), "Token is cached even though caching is disabled" + # Verify token was not cached when caching is disabled + cached_token = get_cached_token(client_id, client_secret, account_name) + assert cached_token is None, "Token is cached even though caching is disabled" def test_connect_with_user_agent( diff --git a/tests/unit/service/test_resource_manager.py b/tests/unit/service/test_resource_manager.py index 93febfdb2f3..c41ef6a0b1d 100644 --- a/tests/unit/service/test_resource_manager.py +++ b/tests/unit/service/test_resource_manager.py @@ -6,7 +6,8 @@ from firebolt.client.auth import Auth, ClientCredentials from firebolt.service.manager import ResourceManager -from firebolt.utils.token_storage import TokenSecureStorage +from firebolt.utils.cache import _firebolt_cache +from tests.unit.test_cache_helpers import get_cached_token def test_rm_credentials( @@ -38,6 +39,7 @@ def test_rm_token_cache( account_name: str, access_token: str, mock_system_engine_connection_flow: Callable, + enable_cache: Callable, ) -> None: """Credentials, that are passed to rm are cached properly.""" url = "https://url" @@ -55,10 +57,15 @@ def test_rm_token_cache( ) rm._client.get(url) - ts = TokenSecureStorage(auth.client_id, auth.client_secret) - assert ts.get_cached_token() == access_token, "Invalid token value cached" + # Verify token was cached using the new cache system + cached_token = get_cached_token( + auth.client_id, auth.client_secret, account_name + ) + assert cached_token == access_token, "Invalid token value cached" # Do the same, but with use_token_cache=False + _firebolt_cache.clear() # Clear cache before testing disabled cache + with Patcher(): rm = ResourceManager( auth=ClientCredentials( @@ -69,7 +76,8 @@ def test_rm_token_cache( ) rm._client.get(url) - ts = TokenSecureStorage(auth.client_id, auth.client_secret) - assert ( - ts.get_cached_token() is None - ), "Token is cached even though caching is disabled" + # Verify token was not cached when caching is disabled + cached_token = get_cached_token( + auth.client_id, auth.client_secret, account_name + ) + assert cached_token is None, "Token is cached even though caching is disabled" diff --git a/tests/unit/test_cache_helpers.py b/tests/unit/test_cache_helpers.py new file mode 100644 index 00000000000..8ac8f593884 --- /dev/null +++ b/tests/unit/test_cache_helpers.py @@ -0,0 +1,63 @@ +"""Helper functions for cache-related tests.""" +from typing import Optional + +from firebolt.utils.cache import ( + ConnectionInfo, + SecureCacheKey, + _firebolt_cache, +) + + +def get_cached_token( + principal: str, secret: str, account_name: Optional[str] = None +) -> Optional[str]: + """Get cached token for the given credentials. + + This is a test helper function for backward compatibility. + + Args: + principal: Username or client ID + secret: Password or client secret + account_name: Account name (optional) + + Returns: + Cached token if available, None otherwise + """ + cache_key = SecureCacheKey([principal, secret, account_name], secret) + connection_info = _firebolt_cache.get(cache_key) + + if connection_info and connection_info.token: + return connection_info.token + return None + + +def cache_token( + principal: str, + secret: str, + token: str, + expiry: Optional[int] = None, + account_name: Optional[str] = None, +) -> None: + """Cache token for the given credentials. + + This is a test helper function for backward compatibility. + + Args: + principal: Username or client ID + secret: Password or client secret + token: Token to cache + expiry: Token expiry time (ignored, we use our own expiry) + account_name: Account name (optional) + """ + cache_key = SecureCacheKey([principal, secret, account_name], secret) + + # Get existing connection info or create new one + connection_info = _firebolt_cache.get(cache_key) + if connection_info is None: + connection_info = ConnectionInfo(id="NONE") + + # Update token information + connection_info.token = token + + # Cache it + _firebolt_cache.set(cache_key, connection_info) diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py index 6fce309fbe3..0328f5d092b 100644 --- a/tests/unit/utils/test_cache.py +++ b/tests/unit/utils/test_cache.py @@ -351,3 +351,66 @@ def test_cache_disable_enable_behavior( else: # Keep cache enabled - should continue working assert cache.get(sample_cache_key) is not None + + +def test_helper_functions(): + """Test the backward compatibility helper functions.""" + from tests.unit.test_cache_helpers import cache_token, get_cached_token + from firebolt.utils.cache import _firebolt_cache + + _firebolt_cache.enable() + _firebolt_cache.clear() + + # Test caching and retrieving tokens + principal = "test_user" + secret = "test_secret" + token = "test_token" + account_name = "test_account" + + # Cache token + cache_token(principal, secret, token, 9999, account_name) + + # Retrieve token + cached_token = get_cached_token(principal, secret, account_name) + assert cached_token == token + + # Test with None account name + cache_token(principal, secret, token, 9999, None) + cached_token_none = get_cached_token(principal, secret, None) + assert cached_token_none == token + + +def test_connection_info_post_init(): + """Test ConnectionInfo.__post_init__ method.""" + # Test with dictionary inputs that should be converted to dataclasses + engine_dict = {"url": "http://test.com", "params": {"key": "value"}} + db_dict = {"name": "test_db"} + + connection_info = ConnectionInfo( + id="test", + system_engine=engine_dict, + databases={"db1": db_dict}, + engines={"engine1": engine_dict} + ) + + # Should convert dicts to dataclasses + from firebolt.utils.cache import EngineInfo, DatabaseInfo + assert isinstance(connection_info.system_engine, EngineInfo) + assert isinstance(connection_info.databases["db1"], DatabaseInfo) + assert isinstance(connection_info.engines["engine1"], EngineInfo) + + # Test with already converted dataclass objects + engine_obj = EngineInfo(url="http://test.com", params={"key": "value"}) + db_obj = DatabaseInfo(name="test_db") + + connection_info2 = ConnectionInfo( + id="test2", + system_engine=engine_obj, + databases={"db1": db_obj}, + engines={"engine1": engine_obj} + ) + + # Should remain as dataclasses + assert connection_info2.system_engine is engine_obj + assert connection_info2.databases["db1"] is db_obj + assert connection_info2.engines["engine1"] is engine_obj diff --git a/tests/unit/utils/test_file_operations.py b/tests/unit/utils/test_file_operations.py new file mode 100644 index 00000000000..c0752f5c39d --- /dev/null +++ b/tests/unit/utils/test_file_operations.py @@ -0,0 +1,96 @@ +from firebolt.utils.file_operations import ( + FernetEncrypter, + generate_encrypted_file_name, + generate_salt, +) + + +def test_generate_encrypted_file_name_returns_same_value(): + # Test the function with sample inputs + assert generate_encrypted_file_name( + "test_key", "test_encryption_key" + ) == generate_encrypted_file_name("test_key", "test_encryption_key") + + +def test_generate_encrypted_file_name_different_inputs_different_outputs(): + # Test that different inputs produce different outputs + result1 = generate_encrypted_file_name("test_key1", "test_encryption_key") + result2 = generate_encrypted_file_name("test_key2", "test_encryption_key") + assert result1 != result2 + + +def test_generate_encrypted_file_name_different_keys_different_outputs(): + # Test that different encryption keys produce different outputs + result1 = generate_encrypted_file_name("test_key", "test_encryption_key1") + result2 = generate_encrypted_file_name("test_key", "test_encryption_key2") + assert result1 != result2 + + +def test_generate_encrypted_file_name_format(): + # Test that the output has the correct format + result = generate_encrypted_file_name("test_key", "test_encryption_key") + assert result.endswith(".txt") + assert len(result) > 10 # Should be a reasonable length with .txt extension + + +def test_generate_salt(): + """Test salt generation.""" + salt = generate_salt() + assert salt == "salt" + + +def test_fernet_encrypter(): + """Test FernetEncrypter encryption and decryption.""" + salt = generate_salt() + encryption_key = "test_encryption_key" + + encrypter = FernetEncrypter(salt, encryption_key) + + test_data = "Hello, World! This is test data." + + # Encrypt data + encrypted = encrypter.encrypt(test_data) + assert encrypted != test_data + assert len(encrypted) > len(test_data) + + # Decrypt data + decrypted = encrypter.decrypt(encrypted) + assert decrypted == test_data + + +def test_fernet_encrypter_invalid_data(): + """Test FernetEncrypter with invalid encrypted data.""" + salt = generate_salt() + encryption_key = "test_encryption_key" + + encrypter = FernetEncrypter(salt, encryption_key) + + # Try to decrypt invalid data + result = encrypter.decrypt("invalid_encrypted_data") + assert result is None + + +def test_fernet_encrypter_different_keys(): + """Test that different keys produce different encrypted data.""" + salt = generate_salt() + encryption_key1 = "test_key_1" + encryption_key2 = "test_key_2" + + encrypter1 = FernetEncrypter(salt, encryption_key1) + encrypter2 = FernetEncrypter(salt, encryption_key2) + + test_data = "Hello, World!" + + encrypted1 = encrypter1.encrypt(test_data) + encrypted2 = encrypter2.encrypt(test_data) + + # Different keys should produce different encrypted data + assert encrypted1 != encrypted2 + + # Each encrypter should decrypt its own data correctly + assert encrypter1.decrypt(encrypted1) == test_data + assert encrypter2.decrypt(encrypted2) == test_data + + # But shouldn't be able to decrypt the other's data + assert encrypter1.decrypt(encrypted2) is None + assert encrypter2.decrypt(encrypted1) is None From c237956041d6cc8e8367ad04e83c4eb896b1d87a Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 22 Aug 2025 17:01:36 +0100 Subject: [PATCH 04/11] fix pre commit --- src/firebolt/utils/cache.py | 38 ++++++++++++++++++++--------- src/firebolt/utils/usage_tracker.py | 8 ++++-- tests/unit/utils/test_cache.py | 29 +++++++++++----------- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/src/firebolt/utils/cache.py b/src/firebolt/utils/cache.py index 2988e76981a..d5fc25a933f 100644 --- a/src/firebolt/utils/cache.py +++ b/src/firebolt/utils/cache.py @@ -192,20 +192,34 @@ def __hash__(self) -> int: return hash(self.key) -class FileBasedCache(UtilCache[ConnectionInfo]): +class FileBasedCache: """ File-based cache that persists to disk with encryption. - Extends UtilCache to provide persistent storage using encrypted files. + Uses composition to combine in-memory caching with persistent storage + using encrypted files. """ - def __init__(self, cache_name: str = ""): - super().__init__(cache_name) + def __init__(self, memory_cache: UtilCache[ConnectionInfo], cache_name: str = ""): + self.memory_cache = memory_cache self._data_dir = user_data_dir(appname=APPNAME) # TODO: change to new dir makedirs(self._data_dir, exist_ok=True) + # FileBasedCache has its own disabled state, independent of memory cache + cache_env_var = f"FIREBOLT_SDK_DISABLE_CACHE_${cache_name}" + self.disabled = os.getenv("FIREBOLT_SDK_DISABLE_CACHE", False) or os.getenv( + cache_env_var, False + ) + + def disable(self) -> None: + """Disable the file-based cache.""" + self.disabled = True + + def enable(self) -> None: + """Enable the file-based cache.""" + self.disabled = False def _get_file_path(self, key: SecureCacheKey) -> str: """Get the file path for a cache key.""" - cache_key = self.create_key(key) + cache_key = self.memory_cache.create_key(key) encrypted_filename = generate_encrypted_file_name(cache_key, key.encryption_key) return path.join(self._data_dir, encrypted_filename) @@ -250,7 +264,7 @@ def get(self, key: SecureCacheKey) -> Optional[ConnectionInfo]: return None # First try memory cache - memory_result = super().get(key) + memory_result = self.memory_cache.get(key) if memory_result is not None: logger.debug("Cache hit in memory") return memory_result @@ -265,7 +279,7 @@ def get(self, key: SecureCacheKey) -> Optional[ConnectionInfo]: data = ConnectionInfo(**raw_data) # Add to memory cache and return - super().set(key, data) + self.memory_cache.set(key, data) return data def set(self, key: SecureCacheKey, value: ConnectionInfo) -> None: @@ -275,7 +289,7 @@ def set(self, key: SecureCacheKey, value: ConnectionInfo) -> None: logger.debug("Setting value in cache") # First set in memory - super().set(key, value) + self.memory_cache.set(key, value) file_path = self._get_file_path(key) encrypter = FernetEncrypter(generate_salt(), key.encryption_key) @@ -289,7 +303,7 @@ def delete(self, key: SecureCacheKey) -> None: return # Delete from memory - super().delete(key) + self.memory_cache.delete(key) # Delete from disk file_path = self._get_file_path(key) @@ -303,7 +317,9 @@ def delete(self, key: SecureCacheKey) -> None: def clear(self) -> None: # Clear memory only, as deleting every file is not safe logger.debug("Clearing memory cache") - super().clear() + self.memory_cache.clear() -_firebolt_cache = FileBasedCache(cache_name="connection_info") +_firebolt_cache = FileBasedCache( + UtilCache[ConnectionInfo](cache_name="memory_cache"), cache_name="file_cache" +) diff --git a/src/firebolt/utils/usage_tracker.py b/src/firebolt/utils/usage_tracker.py index 83b30b6e9cf..c30ca362624 100644 --- a/src/firebolt/utils/usage_tracker.py +++ b/src/firebolt/utils/usage_tracker.py @@ -8,7 +8,11 @@ from typing import Dict, List, Optional, Tuple from firebolt import __version__ -from firebolt.utils.cache import ConnectionInfo, ReprCacheable, _firebolt_cache +from firebolt.utils.cache import ( + ConnectionInfo, + SecureCacheKey, + _firebolt_cache, +) @dataclass @@ -228,7 +232,7 @@ def get_user_agent_header( def get_cache_tracking_params( - cache_key: ReprCacheable, conn_id: str + cache_key: SecureCacheKey, conn_id: str ) -> List[Tuple[str, str]]: ua_parameters = [] ua_parameters.append(("connId", conn_id)) diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py index 0328f5d092b..8c70856a00e 100644 --- a/tests/unit/utils/test_cache.py +++ b/tests/unit/utils/test_cache.py @@ -355,25 +355,25 @@ def test_cache_disable_enable_behavior( def test_helper_functions(): """Test the backward compatibility helper functions.""" - from tests.unit.test_cache_helpers import cache_token, get_cached_token from firebolt.utils.cache import _firebolt_cache - + from tests.unit.test_cache_helpers import cache_token, get_cached_token + _firebolt_cache.enable() _firebolt_cache.clear() - + # Test caching and retrieving tokens principal = "test_user" secret = "test_secret" token = "test_token" account_name = "test_account" - + # Cache token cache_token(principal, secret, token, 9999, account_name) - + # Retrieve token cached_token = get_cached_token(principal, secret, account_name) assert cached_token == token - + # Test with None account name cache_token(principal, secret, token, 9999, None) cached_token_none = get_cached_token(principal, secret, None) @@ -385,31 +385,32 @@ def test_connection_info_post_init(): # Test with dictionary inputs that should be converted to dataclasses engine_dict = {"url": "http://test.com", "params": {"key": "value"}} db_dict = {"name": "test_db"} - + connection_info = ConnectionInfo( id="test", system_engine=engine_dict, databases={"db1": db_dict}, - engines={"engine1": engine_dict} + engines={"engine1": engine_dict}, ) - + # Should convert dicts to dataclasses - from firebolt.utils.cache import EngineInfo, DatabaseInfo + from firebolt.utils.cache import DatabaseInfo, EngineInfo + assert isinstance(connection_info.system_engine, EngineInfo) assert isinstance(connection_info.databases["db1"], DatabaseInfo) assert isinstance(connection_info.engines["engine1"], EngineInfo) - + # Test with already converted dataclass objects engine_obj = EngineInfo(url="http://test.com", params={"key": "value"}) db_obj = DatabaseInfo(name="test_db") - + connection_info2 = ConnectionInfo( id="test2", system_engine=engine_obj, databases={"db1": db_obj}, - engines={"engine1": engine_obj} + engines={"engine1": engine_obj}, ) - + # Should remain as dataclasses assert connection_info2.system_engine is engine_obj assert connection_info2.databases["db1"] is db_obj From 0ab388b52dacb25fedfdd33e01e3fce0ab9259cd Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 22 Aug 2025 17:37:26 +0100 Subject: [PATCH 05/11] file name generation similar to java --- src/firebolt/utils/file_operations.py | 58 ++++++++++----------------- 1 file changed, 22 insertions(+), 36 deletions(-) diff --git a/src/firebolt/utils/file_operations.py b/src/firebolt/utils/file_operations.py index be08dcbddcb..85b86e8baa6 100644 --- a/src/firebolt/utils/file_operations.py +++ b/src/firebolt/utils/file_operations.py @@ -1,15 +1,11 @@ -from base64 import b64decode, urlsafe_b64encode +from base64 import b64decode, b64encode, urlsafe_b64encode from hashlib import sha256 from typing import Optional from cryptography.fernet import Fernet, InvalidToken from cryptography.hazmat.backends import default_backend # type: ignore -from cryptography.hazmat.primitives import hashes, padding # type: ignore -from cryptography.hazmat.primitives.ciphers import ( # type: ignore - Cipher, - algorithms, - modes, -) +from cryptography.hazmat.primitives import hashes # type: ignore +from cryptography.hazmat.primitives.ciphers.aead import AESGCM # type: ignore from cryptography.hazmat.primitives.kdf.pbkdf2 import ( PBKDF2HMAC, # type: ignore ) @@ -76,39 +72,29 @@ def generate_salt() -> str: def generate_encrypted_file_name(cache_key: str, encryption_key: str) -> str: - """Generate encrypted file name from cache key using AES encryption. + """Generate encrypted file name from cache key using AES-GCM encryption. + + This implementation matches the Java EncryptionService to ensure compatibility. Args: cache_key (str): The cache key to encrypt encryption_key (str): The encryption key Returns: - str: Base64URL encoded AES encrypted filename ending in .txt + str: Base64 encoded AES-GCM encrypted filename """ - # Derive a 256-bit key from the encryption_key using PBKDF2 - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - salt=b"firebolt_cache_salt", # Fixed salt for deterministic key derivation - length=32, # 256 bits - iterations=10000, - backend=default_backend(), - ) - aes_key = kdf.derive(encryption_key.encode("utf-8")) - - # Pad the cache_key to be a multiple of 16 bytes (AES block size) - padder = padding.PKCS7(128).padder() - padded_data = padder.update(cache_key.encode("utf-8")) - padded_data += padder.finalize() - - # Use a fixed IV for deterministic encryption - # (same input always produces same output) - # This is acceptable for cache file names where we need deterministic results - iv = sha256(cache_key.encode("utf-8")).digest()[:16] - - # Encrypt the padded cache_key - cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv), backend=default_backend()) - encryptor = cipher.encryptor() - encrypted_data = encryptor.update(padded_data) + encryptor.finalize() - - # Base64URL encode the encrypted data and add .txt extension - return urlsafe_b64encode(encrypted_data).decode("ascii").rstrip("=") + ".txt" + # Derive AES key using SHA-256 + key_hash = sha256(encryption_key.encode("utf-8")).digest() + aes_key = key_hash[:32] # Use first 32 bytes for AES-256 + + # Generate deterministic nonce + nonce_input = (encryption_key + encryption_key).encode("utf-8") + nonce_hash = sha256(nonce_input).digest() + nonce = nonce_hash[:12] # AES-GCM nonce should be 12 bytes + + # Encrypt using AES-GCM + aesgcm = AESGCM(aes_key) + encrypted_data = aesgcm.encrypt(nonce, cache_key.encode("utf-8"), None) + + # Base64 encode + return b64encode(encrypted_data).decode("ascii") From e0e0b0fd2af52ef3dec6b9217ad1c52f3dfb52c5 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 22 Aug 2025 18:01:16 +0100 Subject: [PATCH 06/11] better java match --- src/firebolt/common/token_storage.py | 2 - src/firebolt/utils/file_operations.py | 14 +- src/firebolt/utils/token_storage.py | 181 ---------------------- tests/unit/client/auth/test_auth.py | 2 +- tests/unit/client/auth/test_auth_async.py | 2 +- tests/unit/common/test_token_storage.py | 146 ----------------- 6 files changed, 12 insertions(+), 335 deletions(-) delete mode 100644 src/firebolt/common/token_storage.py delete mode 100644 src/firebolt/utils/token_storage.py delete mode 100644 tests/unit/common/test_token_storage.py diff --git a/src/firebolt/common/token_storage.py b/src/firebolt/common/token_storage.py deleted file mode 100644 index d070223dfe1..00000000000 --- a/src/firebolt/common/token_storage.py +++ /dev/null @@ -1,2 +0,0 @@ -# Prevent backward compatibility errors related to new module structure -from firebolt.utils.token_storage import * # NOQA diff --git a/src/firebolt/utils/file_operations.py b/src/firebolt/utils/file_operations.py index 85b86e8baa6..c3bc4389b56 100644 --- a/src/firebolt/utils/file_operations.py +++ b/src/firebolt/utils/file_operations.py @@ -74,14 +74,14 @@ def generate_salt() -> str: def generate_encrypted_file_name(cache_key: str, encryption_key: str) -> str: """Generate encrypted file name from cache key using AES-GCM encryption. - This implementation matches the Java EncryptionService to ensure compatibility. + This implementation matches the Java FilenameGenerator to ensure compatibility. Args: cache_key (str): The cache key to encrypt encryption_key (str): The encryption key Returns: - str: Base64 encoded AES-GCM encrypted filename + str: Double base64 encoded AES-GCM encrypted filename ending in .txt """ # Derive AES key using SHA-256 key_hash = sha256(encryption_key.encode("utf-8")).digest() @@ -96,5 +96,11 @@ def generate_encrypted_file_name(cache_key: str, encryption_key: str) -> str: aesgcm = AESGCM(aes_key) encrypted_data = aesgcm.encrypt(nonce, cache_key.encode("utf-8"), None) - # Base64 encode - return b64encode(encrypted_data).decode("ascii") + first_base64 = b64encode(encrypted_data).decode("ascii") + + # URL-safe base64 encode without padding (matches Java FilenameGenerator) + second_base64 = ( + urlsafe_b64encode(first_base64.encode("ascii")).decode("ascii").rstrip("=") + ) + + return second_base64 + ".txt" diff --git a/src/firebolt/utils/token_storage.py b/src/firebolt/utils/token_storage.py deleted file mode 100644 index 131c76f8e43..00000000000 --- a/src/firebolt/utils/token_storage.py +++ /dev/null @@ -1,181 +0,0 @@ -from base64 import b64decode, b64encode, urlsafe_b64encode -from hashlib import sha256 -from json import JSONDecodeError -from json import dump as json_dump -from json import load as json_load -from os import makedirs, path, urandom -from time import time -from typing import Optional - -from appdirs import user_data_dir -from cryptography.fernet import Fernet, InvalidToken -from cryptography.hazmat.backends import default_backend # type: ignore -from cryptography.hazmat.primitives import hashes # type: ignore -from cryptography.hazmat.primitives.kdf.pbkdf2 import ( - PBKDF2HMAC, # type: ignore -) - -APPNAME = "firebolt" - - -def generate_salt() -> str: - """Generate salt for FernetExcrypter. - - Returns: - str: Generated salt - """ - return b64encode(urandom(16)).decode("ascii") - - -def generate_file_name(username: str, password: str) -> str: - """Generate unique file name based on username and password. - - Username and password values are not exposed. - - Args: - username (str): Username - password (str): Password - - Returns: - str: File name 64 characters long - - """ - username_hash = sha256(username.encode("utf-8")).hexdigest()[:32] - password_hash = sha256(password.encode("utf-8")).hexdigest()[:32] - - return f"{username_hash}{password_hash}.json" - - -class TokenSecureStorage: - """File system storage for token. - - Token is encrypted using username and password. - - Args: - username (str): Username - password (str): Password - """ - - def __init__(self, username: str, password: str): - self._data_dir = user_data_dir(appname=APPNAME) - makedirs(self._data_dir, exist_ok=True) - - self._token_file = path.join( - self._data_dir, generate_file_name(username, password) - ) - - self.salt = self._get_salt() - self.encrypter = FernetEncrypter(self.salt, username, password) - - def _get_salt(self) -> str: - """Get salt from the file if exists, or generate a new one. - - Returns: - str: Salt - """ - res = self._read_data_json() - return res.get("salt", generate_salt()) - - def _read_data_json(self) -> dict: - """Read json token file. - - Returns: - dict: JSON object - """ - if not path.exists(self._token_file): - return {} - - with open(self._token_file) as f: - try: - return json_load(f) - except JSONDecodeError: - return {} - - def get_cached_token(self) -> Optional[str]: - """Get decrypted token. - - If token is not found, cannot be decrypted with username and password, - or is expired - None will be returned. - - Returns: - Optional[str]: Decrypted token or None - """ - res = self._read_data_json() - if "token" not in res: - return None - - # Ignore expired tokens - if "expiration" in res and res["expiration"] <= int(time()): - return None - - return self.encrypter.decrypt(res["token"]) - - def cache_token(self, token: str, expiration_ts: int) -> None: - """Encrypt and store token in file system. - - Expiration timestamp is also stored with token in order to later - be able to check if it's expired. - - Args: - token (str): Token to store - expiration_ts (int): Token expiration timestamp - """ - token = self.encrypter.encrypt(token) - - with open(self._token_file, "w") as f: - json_dump( - {"token": token, "salt": self.salt, "expiration": expiration_ts}, f - ) - - -class FernetEncrypter: - """PBDKF2HMAC based encrypter. - - Username and password combination is used as a key. - - Args: - salt (str): Salt value for encryption - username: Username for key - password: Password for key - """ - - def __init__(self, salt: str, username: str, password: str): - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - salt=b64decode(salt), - length=32, - iterations=39000, - backend=default_backend(), - ) - self.fernet = Fernet( - urlsafe_b64encode( - kdf.derive(bytes(f"{username}{password}", encoding="utf-8")) - ) - ) - - def encrypt(self, data: str) -> str: - """Encrypt data string. - - Args: - data (str): Data for encryption - - Returns: - str: Encrypted data - - """ - return self.fernet.encrypt(bytes(data, encoding="utf-8")).decode("utf-8") - - def decrypt(self, data: str) -> Optional[str]: - """Decrypt encrypted data. - - Args: - data (str): Encrypted data - - Returns: - Optional[str]: Decrypted data - - """ - try: - return self.fernet.decrypt(bytes(data, encoding="utf-8")).decode("utf-8") - except InvalidToken: - return None diff --git a/tests/unit/client/auth/test_auth.py b/tests/unit/client/auth/test_auth.py index 0f63a0039f7..533d55e44f6 100644 --- a/tests/unit/client/auth/test_auth.py +++ b/tests/unit/client/auth/test_auth.py @@ -85,7 +85,7 @@ def test_auth_adds_header(access_token: str) -> None: @mark.nofakefs -def test_auth_token_storage( +def test_auth_token_caching( httpx_mock: HTTPXMock, client_id: str, client_secret: str, diff --git a/tests/unit/client/auth/test_auth_async.py b/tests/unit/client/auth/test_auth_async.py index ac766dde926..92d1ae73edd 100644 --- a/tests/unit/client/auth/test_auth_async.py +++ b/tests/unit/client/auth/test_auth_async.py @@ -88,7 +88,7 @@ async def test_auth_adds_header(access_token: str) -> None: @mark.nofakefs -async def test_auth_token_storage( +async def test_auth_token_caching( httpx_mock: HTTPXMock, client_id: str, client_secret: str, diff --git a/tests/unit/common/test_token_storage.py b/tests/unit/common/test_token_storage.py deleted file mode 100644 index 4e441e877fa..00000000000 --- a/tests/unit/common/test_token_storage.py +++ /dev/null @@ -1,146 +0,0 @@ -import os -from unittest.mock import patch - -from appdirs import user_config_dir -from pyfakefs.fake_filesystem import FakeFilesystem - -from firebolt.utils.token_storage import ( - FernetEncrypter, - TokenSecureStorage, - generate_salt, -) - - -def test_encrypter_happy_path(): - """ - Simple encrypt/decrypt using FernetEncrypter. - """ - salt = generate_salt() - encrypter1 = FernetEncrypter(salt, username="username", password="password") - encrypter2 = FernetEncrypter(salt, username="username", password="password") - - token = "some string to encrypt" - encrypted_token = encrypter1.encrypt(token) - - assert token == encrypter2.decrypt(encrypted_token) - - -def test_encrypter_wrong_parameter(): - """ - Test that decryption only works if the correct salt - username and password is provided; otherwise None is returned. - """ - salt1 = generate_salt() - salt2 = generate_salt() - - encrypter1 = FernetEncrypter(salt1, username="username", password="password") - - token = "some string to encrypt" - encrypted_token = encrypter1.encrypt(token) - - encrypter2 = FernetEncrypter(salt2, username="username", password="password") - assert encrypter2.decrypt(encrypted_token) is None - - encrypter2 = FernetEncrypter(salt1, username="username1", password="password") - assert encrypter2.decrypt(encrypted_token) is None - - encrypter2 = FernetEncrypter(salt1, username="username", password="password1") - assert encrypter2.decrypt(encrypted_token) is None - - encrypter2 = FernetEncrypter(salt1, username="username", password="password") - assert encrypter2.decrypt(encrypted_token) == token - - -@patch("firebolt.utils.token_storage.time", return_value=0) -def test_token_storage_happy_path(fs: FakeFilesystem): - """ - Test storage happy path cache token and get token - """ - settings = {"username": "username", "password": "password"} - assert TokenSecureStorage(**settings).get_cached_token() is None - - token = "some string to encrypt" - TokenSecureStorage(**settings).cache_token(token, 1) - - assert token == TokenSecureStorage(**settings).get_cached_token() - token = "some new string to encrypt" - - TokenSecureStorage(**settings).cache_token(token, 1) - assert token == TokenSecureStorage(**settings).get_cached_token() - - -@patch("firebolt.utils.token_storage.time", return_value=0) -def test_token_storage_wrong_parameter(fs: FakeFilesystem): - """ - Test getting token with different username or password. - """ - settings = {"username": "username", "password": "password"} - token = "some string to encrypt" - TokenSecureStorage(**settings).cache_token(token, 1) - - assert ( - TokenSecureStorage( - username="username", password="wrong_password" - ).get_cached_token() - is None - ) - assert ( - TokenSecureStorage( - username="wrong_username", password="password" - ).get_cached_token() - is None - ) - assert TokenSecureStorage(**settings).get_cached_token() == token - - -def test_token_storage_json_broken(fs: FakeFilesystem): - """ - Check that the TokenSecureStorage properly handles broken json. - """ - settings = {"username": "username", "password": "password"} - - data_dir = os.path.join(user_config_dir(), "firebolt") - fs.create_dir(data_dir) - fs.create_file(os.path.join(data_dir, "token.json"), contents="{Not a valid json") - - assert TokenSecureStorage(**settings).get_cached_token() is None - - -@patch("firebolt.utils.token_storage.time", return_value=0) -def test_multiple_tokens(fs: FakeFilesystem) -> None: - """ - Check that the TokenSecureStorage properly handles multiple tokens hashed. - """ - settings1 = {"username": "username1", "password": "password1"} - settings2 = {"username": "username2", "password": "password2"} - token1 = "token1" - token2 = "token2" - token3 = "token3" - - st1 = TokenSecureStorage(**settings1) - st2 = TokenSecureStorage(**settings2) - - st1.cache_token(token1, 1) - - assert st1.get_cached_token() == token1 - assert st2.get_cached_token() is None - - st2.cache_token(token2, 1) - - assert st1.get_cached_token() == token1 - assert st2.get_cached_token() == token2 - - st1.cache_token(token3, 1) - assert st1.get_cached_token() == token3 - assert st2.get_cached_token() == token2 - - -@patch("firebolt.utils.token_storage.time", return_value=0) -def test_expired_token(fs: FakeFilesystem) -> None: - """ - Check that TokenSecureStorage ignores expired tokens. - """ - tss = TokenSecureStorage(username="username", password="password") - tss.cache_token("token", 0) - - assert tss.get_cached_token() is None From 88b4b03606e506d837cda2f9f7d0a5a6a694be3f Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 22 Aug 2025 19:13:55 +0100 Subject: [PATCH 07/11] improve coverage --- tests/unit/utils/test_cache.py | 147 ++++++++++++++++++++++++++++++++- 1 file changed, 146 insertions(+), 1 deletion(-) diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py index 8c70856a00e..6db44e8ecf5 100644 --- a/tests/unit/utils/test_cache.py +++ b/tests/unit/utils/test_cache.py @@ -1,15 +1,19 @@ +import json +import os import time from typing import Generator -from unittest.mock import patch +from unittest.mock import mock_open, patch from pytest import fixture, mark from firebolt.utils.cache import ( CACHE_EXPIRY_SECONDS, ConnectionInfo, + FileBasedCache, SecureCacheKey, UtilCache, ) +from firebolt.utils.file_operations import FernetEncrypter, generate_salt @fixture @@ -78,6 +82,26 @@ def test_string(): return "test_value" +@fixture +def file_based_cache() -> Generator[FileBasedCache, None, None]: + """Create a fresh FileBasedCache instance for testing.""" + memory_cache = UtilCache[ConnectionInfo](cache_name="test_memory_cache") + memory_cache.enable() + cache = FileBasedCache(memory_cache, cache_name="test_file_cache") + cache.enable() + yield cache + cache.clear() + + +@fixture +def encrypter_with_key(): + """Create a FernetEncrypter instance for testing.""" + from firebolt.utils.file_operations import FernetEncrypter, generate_salt + + salt = generate_salt() + return FernetEncrypter(salt, "test_encryption_key") + + def test_cache_set_and_get(cache, sample_cache_key, sample_connection_info): """Test basic cache set and get operations.""" # Test cache miss initially @@ -415,3 +439,124 @@ def test_connection_info_post_init(): assert connection_info2.system_engine is engine_obj assert connection_info2.databases["db1"] is db_obj assert connection_info2.engines["engine1"] is engine_obj + + +@mark.nofakefs +def test_file_based_cache_read_data_json_file_not_exists( + file_based_cache, encrypter_with_key +): + """Test _read_data_json returns empty dict when file doesn't exist.""" + # Test with a non-existent file path + result = file_based_cache._read_data_json( + "/path/to/nonexistent/file.txt", encrypter_with_key + ) + assert result == {} + + +def test_file_based_cache_read_data_json_valid_data( + file_based_cache, encrypter_with_key +): + """Test _read_data_json successfully reads and decrypts valid JSON data.""" + # Create test data + test_data = {"id": "test_connection", "token": "test_token"} + test_file_path = "/test_cache/valid_data.txt" + + # Create directory and file with encrypted JSON data + os.makedirs(os.path.dirname(test_file_path), exist_ok=True) + json_str = json.dumps(test_data) + encrypted_data = encrypter_with_key.encrypt(json_str) + + with open(test_file_path, "w") as f: + f.write(encrypted_data) + + # Test reading the valid encrypted data + result = file_based_cache._read_data_json(test_file_path, encrypter_with_key) + assert result == test_data + assert result["id"] == "test_connection" + assert result["token"] == "test_token" + + +def test_file_based_cache_read_data_json_decryption_failure(file_based_cache): + """Test _read_data_json returns empty dict when decryption fails.""" + # Create encrypters with different keys + salt = generate_salt() + encrypter1 = FernetEncrypter(salt, "test_key_1") + encrypter2 = FernetEncrypter(salt, "test_key_2") # Different key + + test_file_path = "/test_cache/decryption_test.txt" + + # Create directory and file with data encrypted by encrypter1 + os.makedirs(os.path.dirname(test_file_path), exist_ok=True) + encrypted_data = encrypter1.encrypt('{"test": "data"}') + + with open(test_file_path, "w") as f: + f.write(encrypted_data) + + # Try to decrypt with encrypter2 (should fail) + result = file_based_cache._read_data_json(test_file_path, encrypter2) + assert result == {} + + +def test_file_based_cache_read_data_json_invalid_json( + file_based_cache, encrypter_with_key +): + """Test _read_data_json returns empty dict when JSON is invalid.""" + test_file_path = "/test_cache/invalid_json.txt" + + # Create directory and file with encrypted invalid JSON + os.makedirs(os.path.dirname(test_file_path), exist_ok=True) + invalid_json = "invalid json data {{" + encrypted_data = encrypter_with_key.encrypt(invalid_json) + + with open(test_file_path, "w") as f: + f.write(encrypted_data) + + # Test reading the invalid JSON + result = file_based_cache._read_data_json(test_file_path, encrypter_with_key) + assert result == {} + + +@mark.nofakefs +def test_file_based_cache_read_data_json_io_error(file_based_cache, encrypter_with_key): + """Test _read_data_json returns empty dict when IOError occurs.""" + # Mock open to raise IOError + with patch("builtins.open", mock_open()) as mock_file: + mock_file.side_effect = IOError("File read error") + + result = file_based_cache._read_data_json("test_file.txt", encrypter_with_key) + assert result == {} + + +def test_file_based_cache_read_data_json_empty_encrypted_data( + file_based_cache, encrypter_with_key +): + """Test _read_data_json handles empty encrypted data.""" + test_file_path = "/test_cache/empty_data.txt" + + # Create directory and file with empty encrypted data + os.makedirs(os.path.dirname(test_file_path), exist_ok=True) + encrypted_empty = encrypter_with_key.encrypt("") + + with open(test_file_path, "w") as f: + f.write(encrypted_empty) + + # Test reading empty decrypted data + result = file_based_cache._read_data_json(test_file_path, encrypter_with_key) + assert result == {} + + +def test_file_based_cache_read_data_json_invalid_encrypted_format( + file_based_cache, encrypter_with_key +): + """Test _read_data_json handles invalid encrypted data format.""" + test_file_path = "/test_cache/invalid_encrypted.txt" + + # Create directory and file with invalid encrypted data format + os.makedirs(os.path.dirname(test_file_path), exist_ok=True) + + with open(test_file_path, "w") as f: + f.write("not_encrypted_data_at_all") + + # Test reading invalid encrypted format + result = file_based_cache._read_data_json(test_file_path, encrypter_with_key) + assert result == {} From c6700973105f81837482f044b0aa4042fbe972da Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 22 Aug 2025 19:51:31 +0100 Subject: [PATCH 08/11] fix bug when expired cache from file would be loaded into memory --- src/firebolt/utils/cache.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/firebolt/utils/cache.py b/src/firebolt/utils/cache.py index d5fc25a933f..6a19dee7ddd 100644 --- a/src/firebolt/utils/cache.py +++ b/src/firebolt/utils/cache.py @@ -143,12 +143,13 @@ def _is_expired(self, value: T) -> bool: return False @noop_if_disabled - def set(self, key: ReprCacheable, value: T) -> None: + def set(self, key: ReprCacheable, value: T, preserve_expiry: bool = False) -> None: if not self.disabled: # Set expiry_time for ConnectionInfo objects if hasattr(value, "expiry_time"): - current_time = int(time.time()) - value.expiry_time = current_time + CACHE_EXPIRY_SECONDS + if not preserve_expiry or value.expiry_time is None: + current_time = int(time.time()) + value.expiry_time = current_time + CACHE_EXPIRY_SECONDS s_key = self.create_key(key) self._cache[s_key] = value @@ -275,11 +276,23 @@ def get(self, key: SecureCacheKey) -> Optional[ConnectionInfo]: raw_data = self._read_data_json(file_path, encrypter) if not raw_data: return None + logger.debug("Cache hit on disk") data = ConnectionInfo(**raw_data) - # Add to memory cache and return - self.memory_cache.set(key, data) + # Check if the loaded data is expired + if self.memory_cache._is_expired(data): + # Data is expired, delete the file and return None + try: + if path.exists(file_path): + os.remove(file_path) + logger.debug("Deleted expired file %s", file_path) + except OSError: + logger.debug("Failed to delete expired file %s", file_path) + return None + + # Data is not expired, add to memory cache preserving original expiry time + self.memory_cache.set(key, data, preserve_expiry=True) return data def set(self, key: SecureCacheKey, value: ConnectionInfo) -> None: From 90a5b4df822ce071baad3e5f8c645a734678e0c5 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 22 Aug 2025 19:58:54 +0100 Subject: [PATCH 09/11] additional tests --- tests/unit/utils/test_cache.py | 367 +++++++++++++++++++++++++++++++++ 1 file changed, 367 insertions(+) diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py index 6db44e8ecf5..9218bb620aa 100644 --- a/tests/unit/utils/test_cache.py +++ b/tests/unit/utils/test_cache.py @@ -560,3 +560,370 @@ def test_file_based_cache_read_data_json_invalid_encrypted_format( # Test reading invalid encrypted format result = file_based_cache._read_data_json(test_file_path, encrypter_with_key) assert result == {} + + +def test_file_based_cache_delete_method(file_based_cache, encrypter_with_key): + """Test FileBasedCache delete method removes data from both memory and file.""" + # Create test data + sample_key = SecureCacheKey(["delete", "test"], "test_secret") + sample_data = ConnectionInfo(id="test_delete_connection", token="test_token") + + # Set data in cache (both memory and file) + file_based_cache.set(sample_key, sample_data) + + # Verify data exists in memory cache + memory_result = file_based_cache.memory_cache.get(sample_key) + assert memory_result is not None + assert memory_result.id == "test_delete_connection" + + # Verify file exists on disk + file_path = file_based_cache._get_file_path(sample_key) + assert os.path.exists(file_path) + + # Delete the data + file_based_cache.delete(sample_key) + + # Verify data is removed from memory cache + memory_result_after_delete = file_based_cache.memory_cache.get(sample_key) + assert memory_result_after_delete is None + + # Verify file is removed from disk + assert not os.path.exists(file_path) + + # Verify get returns None + cache_result = file_based_cache.get(sample_key) + assert cache_result is None + + +@mark.nofakefs +def test_file_based_cache_delete_method_file_removal_failure( + file_based_cache, encrypter_with_key +): + """Test FileBasedCache delete method handles file removal failures gracefully.""" + sample_key = SecureCacheKey(["delete", "failure"], "test_secret") + sample_data = ConnectionInfo(id="test_connection", token="test_token") + + # Set data in memory cache only (no file operations due to @mark.nofakefs) + file_based_cache.memory_cache.set(sample_key, sample_data) + + # Mock path.exists to return True and os.remove to raise OSError + with patch("firebolt.utils.cache.path.exists", return_value=True), patch( + "firebolt.utils.cache.os.remove" + ) as mock_remove: + mock_remove.side_effect = OSError("Permission denied") + + # Delete should not raise an exception despite file removal failure + file_based_cache.delete(sample_key) + + # Verify data is still removed from memory cache + memory_result = file_based_cache.memory_cache.get(sample_key) + assert memory_result is None + + +def test_file_based_cache_get_from_file_when_not_in_memory( + file_based_cache, encrypter_with_key +): + """Test FileBasedCache get method retrieves data from file when not in memory.""" + # Create test data + sample_key = SecureCacheKey(["file", "only"], "test_secret") + sample_data = ConnectionInfo( + id="test_file_connection", + token="test_file_token", + expiry_time=int(time.time()) + 3600, # Valid for 1 hour + ) + + # First set data in cache (both memory and file) + file_based_cache.set(sample_key, sample_data) + + # Verify data exists + initial_result = file_based_cache.get(sample_key) + assert initial_result is not None + assert initial_result.id == "test_file_connection" + + # Clear memory cache but keep file + file_based_cache.memory_cache.clear() + + # Verify memory cache is empty + memory_result = file_based_cache.memory_cache.get(sample_key) + assert memory_result is None + + # Verify file still exists + file_path = file_based_cache._get_file_path(sample_key) + assert os.path.exists(file_path) + + # Get should retrieve from file and reload into memory + file_result = file_based_cache.get(sample_key) + assert file_result is not None + assert file_result.id == "test_file_connection" + assert file_result.token == "test_file_token" + + # Verify data is now back in memory cache + memory_result_after_load = file_based_cache.memory_cache.get(sample_key) + assert memory_result_after_load is not None + assert memory_result_after_load.id == "test_file_connection" + + +def test_file_based_cache_get_from_corrupted_file(file_based_cache, encrypter_with_key): + """Test FileBasedCache get method handles corrupted file gracefully.""" + sample_key = SecureCacheKey(["corrupted", "file"], "test_secret") + + # Create corrupted file manually + file_path = file_based_cache._get_file_path(sample_key) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + with open(file_path, "w") as f: + f.write("corrupted_data_that_cannot_be_decrypted") + + # Verify file exists + assert os.path.exists(file_path) + + # Get should return None due to decryption failure + result = file_based_cache.get(sample_key) + assert result is None + + # Verify nothing is loaded into memory cache + memory_result = file_based_cache.memory_cache.get(sample_key) + assert memory_result is None + + +def test_file_based_cache_disabled_behavior(file_based_cache, encrypter_with_key): + """Test FileBasedCache methods when cache is disabled.""" + sample_key = SecureCacheKey(["disabled", "test"], "test_secret") + sample_data = ConnectionInfo(id="test_connection", token="test_token") + + # Disable the cache + file_based_cache.disable() + + # Set should do nothing when disabled + file_based_cache.set(sample_key, sample_data) + + # Get should return None when disabled + result = file_based_cache.get(sample_key) + assert result is None + + # Enable cache, set data, then disable again + file_based_cache.enable() + file_based_cache.set(sample_key, sample_data) + + # Verify data is set + enabled_result = file_based_cache.get(sample_key) + assert enabled_result is not None + + # Disable and verify get returns None + file_based_cache.disable() + disabled_result = file_based_cache.get(sample_key) + assert disabled_result is None + + # Delete should do nothing when disabled + file_based_cache.delete(sample_key) # Should not raise exception + + +def test_file_based_cache_preserves_expiry_from_file( + file_based_cache, encrypter_with_key, fixed_time +): + """Test that FileBasedCache preserves original expiry time when loading from file.""" + sample_key = SecureCacheKey(["preserve", "expiry"], "test_secret") + + # Create data and set it at an earlier time + sample_data = ConnectionInfo(id="test_connection") + + # Set data at fixed_time - this will give it expiry of fixed_time + CACHE_EXPIRY_SECONDS + with patch("time.time", return_value=fixed_time): + file_based_cache.set(sample_key, sample_data) + + # Verify the expiry time that was set + memory_result = file_based_cache.memory_cache.get(sample_key) + expected_expiry = fixed_time + CACHE_EXPIRY_SECONDS + assert memory_result.expiry_time == expected_expiry + + # Clear memory cache to force file load on next get + file_based_cache.memory_cache.clear() + + # Get data from file (should preserve the original expiry time from file) + result = file_based_cache.get(sample_key) + + assert result is not None + assert ( + result.expiry_time == expected_expiry + ) # Should preserve original expiry from file + assert result.id == "test_connection" + + # Verify it's also in memory cache with preserved expiry + memory_result_after_load = file_based_cache.memory_cache.get(sample_key) + assert memory_result_after_load is not None + assert memory_result_after_load.expiry_time == expected_expiry + + +def test_file_based_cache_deletes_expired_file_on_get( + file_based_cache, encrypter_with_key, fixed_time +): + """Test that FileBasedCache deletes expired files on get and returns cache miss.""" + sample_key = SecureCacheKey(["expired", "file"], "test_secret") + sample_data = ConnectionInfo(id="test_connection") + + # Set data at an early time so it gets an early expiry + early_time = fixed_time - 7200 # 2 hours before + with patch("time.time", return_value=early_time): + file_based_cache.set(sample_key, sample_data) + + # Verify the expiry time that was set (should be early_time + CACHE_EXPIRY_SECONDS) + memory_result = file_based_cache.memory_cache.get(sample_key) + expected_expiry = early_time + CACHE_EXPIRY_SECONDS + assert memory_result.expiry_time == expected_expiry + + # Verify file was created + file_path = file_based_cache._get_file_path(sample_key) + assert os.path.exists(file_path) + + # Clear memory cache to force file load + file_based_cache.memory_cache.clear() + + # Now try to get at a time when the data should be expired + # The data expires at early_time + CACHE_EXPIRY_SECONDS + # Let's try to get it after that expiry time + expired_check_time = early_time + CACHE_EXPIRY_SECONDS + 1 + with patch("time.time", return_value=expired_check_time): + result = file_based_cache.get(sample_key) + + # Should return None due to expiry + assert result is None + + # File should be deleted + assert not os.path.exists(file_path) + + # Memory cache should not contain the data + memory_result = file_based_cache.memory_cache.get(sample_key) + assert memory_result is None + + +def test_file_based_cache_expiry_edge_case_exactly_expired( + file_based_cache, encrypter_with_key, fixed_time +): + """Test behavior when data expires exactly at the current time.""" + sample_key = SecureCacheKey(["edge", "case"], "test_secret") + sample_data = ConnectionInfo(id="test_connection") + + # Set data such that it will expire exactly at fixed_time + set_time = fixed_time - CACHE_EXPIRY_SECONDS + with patch("time.time", return_value=set_time): + file_based_cache.set(sample_key, sample_data) + + # Verify the expiry time that was set + memory_result = file_based_cache.memory_cache.get(sample_key) + expected_expiry = set_time + CACHE_EXPIRY_SECONDS # This equals fixed_time + assert memory_result.expiry_time == expected_expiry == fixed_time + + file_path = file_based_cache._get_file_path(sample_key) + assert os.path.exists(file_path) + + # Clear memory cache + file_based_cache.memory_cache.clear() + + # Try to get exactly at expiry time (should be considered expired) + with patch("time.time", return_value=fixed_time): + result = file_based_cache.get(sample_key) + + # Should return None as data is expired (>= check in _is_expired) + assert result is None + + # File should be deleted + assert not os.path.exists(file_path) + + +def test_file_based_cache_non_expired_file_loads_correctly( + file_based_cache, encrypter_with_key, fixed_time +): + """Test that non-expired data from file loads correctly with preserved expiry.""" + sample_key = SecureCacheKey(["non", "expired"], "test_secret") + + sample_data = ConnectionInfo(id="test_connection", token="test_token") + + # Set data at an earlier time so it's not expired yet + set_time = fixed_time - 900 # 15 minutes before + with patch("time.time", return_value=set_time): + file_based_cache.set(sample_key, sample_data) + + # Verify expiry time + memory_result = file_based_cache.memory_cache.get(sample_key) + expected_expiry = set_time + CACHE_EXPIRY_SECONDS + assert memory_result.expiry_time == expected_expiry + + # Clear memory cache to force file load + file_based_cache.memory_cache.clear() + + # Get data at fixed_time (data should not be expired since expected_expiry > fixed_time) + with patch("time.time", return_value=fixed_time): + # Ensure the data is not expired + assert expected_expiry > fixed_time, "Data should not be expired for this test" + + result = file_based_cache.get(sample_key) + + # Should successfully load data + assert result is not None + assert result.id == "test_connection" + assert result.token == "test_token" + assert result.expiry_time == expected_expiry # Preserved original expiry + + # Verify file still exists (not deleted) + file_path = file_based_cache._get_file_path(sample_key) + assert os.path.exists(file_path) + + # Verify it's in memory cache with preserved expiry + memory_result = file_based_cache.memory_cache.get(sample_key) + assert memory_result is not None + assert memory_result.expiry_time == expected_expiry + + +def test_memory_cache_set_preserve_expiry_parameter( + cache, sample_cache_key, fixed_time +): + """Test UtilCache.set preserve_expiry parameter functionality.""" + # Create connection info with specific expiry time + original_expiry = fixed_time + 1800 + sample_data = ConnectionInfo(id="test_connection", expiry_time=original_expiry) + + with patch("time.time", return_value=fixed_time): + # Test preserve_expiry=True + cache.set(sample_cache_key, sample_data, preserve_expiry=True) + + result = cache.get(sample_cache_key) + assert result is not None + assert result.expiry_time == original_expiry # Should preserve original + + cache.clear() + + # Test preserve_expiry=False (default behavior) + cache.set(sample_cache_key, sample_data, preserve_expiry=False) + + result = cache.get(sample_cache_key) + assert result is not None + expected_new_expiry = fixed_time + CACHE_EXPIRY_SECONDS + assert result.expiry_time == expected_new_expiry # Should get new expiry + + cache.clear() + + # Test default behavior (preserve_expiry not specified) + cache.set(sample_cache_key, sample_data) + + result = cache.get(sample_cache_key) + assert result is not None + assert result.expiry_time == expected_new_expiry # Should get new expiry + + +def test_memory_cache_set_preserve_expiry_with_none_expiry( + cache, sample_cache_key, fixed_time +): + """Test UtilCache.set preserve_expiry when original expiry_time is None.""" + # Create connection info with None expiry time + sample_data = ConnectionInfo(id="test_connection", expiry_time=None) + + with patch("time.time", return_value=fixed_time): + # Even with preserve_expiry=True, None expiry should get new expiry + cache.set(sample_cache_key, sample_data, preserve_expiry=True) + + result = cache.get(sample_cache_key) + assert result is not None + expected_expiry = fixed_time + CACHE_EXPIRY_SECONDS + assert ( + result.expiry_time == expected_expiry + ) # Should get new expiry despite preserve=True From aeebeabe57501f8e77dc8a078830514c2e4f4b9b Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 6 Oct 2025 14:01:22 +0100 Subject: [PATCH 10/11] fix mypy issue --- src/firebolt/client/auth/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/firebolt/client/auth/base.py b/src/firebolt/client/auth/base.py index d17e4746b48..2a7ddf0fd97 100644 --- a/src/firebolt/client/auth/base.py +++ b/src/firebolt/client/auth/base.py @@ -62,7 +62,7 @@ def account(self) -> Optional[str]: return self._account_name @account.setter - def account(self, value: str) -> None: + def account(self, value: Optional[str]) -> None: self._account_name = value # Now we have all the elements to fetch the cached token if not self._token: From 7f55541d4b38b8e6d5e4c997d391e164986f68c1 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Mon, 6 Oct 2025 14:20:20 +0100 Subject: [PATCH 11/11] safer test --- .../resource_manager/V2/test_engine.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/integration/resource_manager/V2/test_engine.py b/tests/integration/resource_manager/V2/test_engine.py index 96f43bd2139..7c122b62a9d 100644 --- a/tests/integration/resource_manager/V2/test_engine.py +++ b/tests/integration/resource_manager/V2/test_engine.py @@ -27,29 +27,30 @@ def test_create_start_stop_engine( auto_stop=120, ) assert engine.name == name + database = None try: database = rm.databases.create(name=name) assert database.name == name - try: - engine.attach_to_database(database) - assert engine.database == database + engine.attach_to_database(database) + assert engine.database == database - engine.start() - assert engine.current_status == EngineStatus.RUNNING + engine.start() + assert engine.current_status == EngineStatus.RUNNING - engine.stop() - assert engine.current_status in { - EngineStatus.STOPPING, - EngineStatus.STOPPED, - } - finally: - database.delete() + engine.stop() + assert engine.current_status in { + EngineStatus.STOPPING, + EngineStatus.STOPPED, + } finally: - engine.stop() - engine.delete() + if engine: + engine.stop() + engine.delete() + if database: + database.delete() ParamValue = namedtuple("ParamValue", "set expected")