Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ async def connect(
if not auth:
raise ConfigurationError("auth is required to connect.")

auth.account = account_name

api_endpoint = fix_url_schema(api_endpoint)
# Type checks
assert auth is not None
Expand Down
78 changes: 54 additions & 24 deletions src/firebolt/client/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
from abc import abstractmethod
from enum import IntEnum
from time import time
from typing import AsyncGenerator, Generator, Optional
from typing import AsyncGenerator, Generator, Optional, Tuple

from anyio import Lock
from httpx import Auth as HttpxAuth
from httpx import Request, Response, codes

from firebolt.utils.token_storage import TokenSecureStorage
from firebolt.utils.util import Timer, cached_property, get_internal_error_code
from firebolt.utils.cache import (
ConnectionInfo,
SecureCacheKey,
_firebolt_cache,
)
from firebolt.utils.util import Timer, get_internal_error_code

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,6 +42,7 @@ class Auth(HttpxAuth):

__slots__ = (
"_token",
"_account_name",
"_expires",
"_use_token_cache",
)
Expand All @@ -47,10 +52,22 @@ class Auth(HttpxAuth):

def __init__(self, use_token_cache: bool = True):
self._use_token_cache = use_token_cache
self._token: Optional[str] = self._get_cached_token()
self._account_name: Optional[str] = None
self._token: Optional[str] = None
self._expires: Optional[int] = None
self._lock = Lock()

@property
def account(self) -> Optional[str]:
return self._account_name

@account.setter
def account(self, value: Optional[str]) -> None:
self._account_name = value
# Now we have all the elements to fetch the cached token
if not self._token:
self._token, self._expires = self._get_cached_token()

def copy(self) -> "Auth":
"""Make another auth object with same credentials.

Expand Down Expand Up @@ -103,36 +120,49 @@ def expired(self) -> bool:
"""
return self._expires is not None and self._expires <= int(time())

@cached_property
def _token_storage(self) -> Optional[TokenSecureStorage]:
"""Token filesystem cache storage.
def _get_cached_token(self) -> Tuple[Optional[str], Optional[int]]:
"""If caching is enabled, get token from cache.

This is evaluated lazily, only if caching is enabled.
If caching is disabled, None is returned.

Returns:
Optional[TokenSecureStorage]: Token filesystem cache storage if any
Optional[str]: Token if any, and if caching is enabled; None otherwise
"""
return None
if not self._use_token_cache:
return (None, None)

def _get_cached_token(self) -> Optional[str]:
"""If caching is enabled, get token from filesystem cache.
cache_key = SecureCacheKey(
[self.principal, self.secret, self._account_name], self.secret
)
connection_info = _firebolt_cache.get(cache_key)

If caching is disabled, None is returned.
if connection_info and connection_info.token:
return (connection_info.token, connection_info.expiry_time)

Returns:
Optional[str]: Token if any, and if caching is enabled; None otherwise
"""
if not self._use_token_cache or not self._token_storage:
return None
return self._token_storage.get_cached_token()
return (None, None)

def _cache_token(self) -> None:
"""If caching isenabled, cache token to filesystem."""
if not self._use_token_cache or not self._token_storage:
"""If caching is enabled, cache token."""
if not self._use_token_cache:
return
# Only cache if token and expiration are retrieved
if self._token and self._expires:
self._token_storage.cache_token(self._token, self._expires)
# Only cache if token is retrieved
if self._token:
cache_key = SecureCacheKey(
[self.principal, self.secret, self._account_name], self.secret
)

# Get existing connection info or create new one
connection_info = _firebolt_cache.get(cache_key)
if connection_info is None:
connection_info = ConnectionInfo(
id="NONE"
) # This is triggered first so there will be no id

# Update token information
connection_info.token = self._token

# Cache it
_firebolt_cache.set(cache_key, connection_info)

@abstractmethod
def get_new_token_generator(self) -> Generator[Request, Response, None]:
Expand Down
15 changes: 0 additions & 15 deletions src/firebolt/client/auth/client_credentials.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from typing import Optional

from firebolt.client.auth.base import AuthRequest, FireboltAuthVersion
from firebolt.client.auth.request_auth_base import _RequestBasedAuth
from firebolt.utils.token_storage import TokenSecureStorage
from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL
from firebolt.utils.util import cached_property


class ClientCredentials(_RequestBasedAuth):
Expand Down Expand Up @@ -79,17 +75,6 @@ def get_firebolt_version(self) -> FireboltAuthVersion:
"""
return FireboltAuthVersion.V2

@cached_property
def _token_storage(self) -> Optional[TokenSecureStorage]:
"""Token filesystem cache storage.

This is evaluated lazily, only if caching is enabled

Returns:
TokenSecureStorage: Token filesystem cache storage
"""
return TokenSecureStorage(username=self.client_id, password=self.client_secret)

def _make_auth_request(self) -> AuthRequest:
"""Get new token using username and password.

Expand Down
15 changes: 0 additions & 15 deletions src/firebolt/client/auth/service_account.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from typing import Optional

from firebolt.client.auth.base import AuthRequest, FireboltAuthVersion
from firebolt.client.auth.request_auth_base import _RequestBasedAuth
from firebolt.utils.token_storage import TokenSecureStorage
from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL
from firebolt.utils.util import cached_property


class ServiceAccount(_RequestBasedAuth):
Expand Down Expand Up @@ -77,17 +73,6 @@ def copy(self) -> "ServiceAccount":
"""
return ServiceAccount(self.client_id, self.client_secret, self._use_token_cache)

@cached_property
def _token_storage(self) -> Optional[TokenSecureStorage]:
"""Token filesystem cache storage.

This is evaluated lazily, only if caching is enabled

Returns:
TokenSecureStorage: Token filesystem cache storage
"""
return TokenSecureStorage(username=self.client_id, password=self.client_secret)

def _make_auth_request(self) -> AuthRequest:
"""Get new token using username and password.

Expand Down
15 changes: 0 additions & 15 deletions src/firebolt/client/auth/username_password.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from typing import Optional

from firebolt.client.auth.base import AuthRequest, FireboltAuthVersion
from firebolt.client.auth.request_auth_base import _RequestBasedAuth
from firebolt.utils.token_storage import TokenSecureStorage
from firebolt.utils.urls import AUTH_URL
from firebolt.utils.util import cached_property


class UsernamePassword(_RequestBasedAuth):
Expand Down Expand Up @@ -77,17 +73,6 @@ def copy(self) -> "UsernamePassword":
"""
return UsernamePassword(self.username, self.password, self._use_token_cache)

@cached_property
def _token_storage(self) -> Optional[TokenSecureStorage]:
"""Token filesystem cache storage.

This is evaluated lazily, only if caching is enabled

Returns:
TokenSecureStorage: Token filesystem cache storage
"""
return TokenSecureStorage(username=self.username, password=self.password)

def _make_auth_request(self) -> AuthRequest:
"""Get new token using username and password.

Expand Down
2 changes: 0 additions & 2 deletions src/firebolt/common/token_storage.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/firebolt/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def connect(
if not auth:
raise ConfigurationError("auth is required to connect.")

auth.account = account_name

api_endpoint = fix_url_schema(api_endpoint)
# Type checks
assert auth is not None
Expand Down
Loading