Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions optimizely/cmab/cmab_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DEFAULT_MAX_BACKOFF = 10 # in seconds
DEFAULT_BACKOFF_MULTIPLIER = 2.0
MAX_WAIT_TIME = 10.0
DEFAULT_PREDICTION_ENDPOINT = "https://prediction.cmab.optimizely.com/predict/{}"


class CmabRetryConfig:
Expand Down Expand Up @@ -52,17 +53,21 @@ class DefaultCmabClient:
"""
def __init__(self, http_client: Optional[requests.Session] = None,
retry_config: Optional[CmabRetryConfig] = None,
logger: Optional[_logging.Logger] = None):
logger: Optional[_logging.Logger] = None,
prediction_endpoint: Optional[str] = None):
"""Initialize the CMAB client.

Args:
http_client (Optional[requests.Session]): HTTP client for making requests.
retry_config (Optional[CmabRetryConfig]): Configuration for retry logic.
logger (Optional[_logging.Logger]): Logger for logging messages.
prediction_endpoint (Optional[str]): Custom prediction endpoint URL template.
Use {} as placeholder for rule_id.
"""
self.http_client = http_client or requests.Session()
self.retry_config = retry_config
self.logger = _logging.adapt_logger(logger or _logging.NoOpLogger())
self.prediction_endpoint = prediction_endpoint or DEFAULT_PREDICTION_ENDPOINT

def fetch_decision(
self,
Expand All @@ -84,7 +89,7 @@ def fetch_decision(
Returns:
str: The variation ID.
"""
url = f"https://prediction.cmab.optimizely.com/predict/{rule_id}"
url = self.prediction_endpoint.format(rule_id)
cmab_attributes = [
{"id": key, "value": value, "type": "custom_attribute"}
for key, value in attributes.items()
Expand Down
6 changes: 5 additions & 1 deletion optimizely/helpers/sdk_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(
odp_event_manager: Optional[OdpEventManager] = None,
odp_segment_request_timeout: Optional[int] = None,
odp_event_request_timeout: Optional[int] = None,
odp_event_flush_interval: Optional[int] = None
odp_event_flush_interval: Optional[int] = None,
cmab_prediction_endpoint: Optional[str] = None
) -> None:
"""
Args:
Expand All @@ -52,6 +53,8 @@ def __init__(
send successfully (optional).
odp_event_request_timeout: Time to wait in seconds for send_odp_events request to send successfully.
odp_event_flush_interval: Time to wait for events to accumulate before sending a batch in seconds (optional).
cmab_prediction_endpoint: Custom CMAB prediction endpoint URL template (optional).
Use {} as placeholder for rule_id. Defaults to production endpoint if not provided.
"""

self.odp_disabled = odp_disabled
Expand All @@ -63,3 +66,4 @@ def __init__(
self.fetch_segments_timeout = odp_segment_request_timeout
self.odp_event_timeout = odp_event_request_timeout
self.odp_flush_interval = odp_event_flush_interval
self.cmab_prediction_endpoint = cmab_prediction_endpoint
8 changes: 7 additions & 1 deletion optimizely/optimizely.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,15 @@ def __init__(
if cmab_service:
self.cmab_service = cmab_service
else:
# Get custom prediction endpoint from settings if provided
cmab_prediction_endpoint = None
if self.sdk_settings and self.sdk_settings.cmab_prediction_endpoint:
cmab_prediction_endpoint = self.sdk_settings.cmab_prediction_endpoint

self.cmab_client = DefaultCmabClient(
retry_config=CmabRetryConfig(),
logger=self.logger
logger=self.logger,
prediction_endpoint=cmab_prediction_endpoint
)
self.cmab_cache: LRUCache[str, CmabCacheValue] = LRUCache(DEFAULT_CMAB_CACHE_SIZE,
DEFAULT_CMAB_CACHE_TIMEOUT)
Expand Down
78 changes: 78 additions & 0 deletions tests/test_cmab_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,81 @@ def test_fetch_decision_exhausts_all_retry_attempts(self, mock_sleep):
self.mock_logger.error.assert_called_with(
Errors.CMAB_FETCH_FAILED.format('Exhausted all retries for CMAB request.')
)

def test_custom_prediction_endpoint(self):
"""Test that custom prediction endpoint is used correctly."""
custom_endpoint = "https://custom.endpoint.com/predict/{}"
client = DefaultCmabClient(
http_client=self.mock_http_client,
logger=self.mock_logger,
prediction_endpoint=custom_endpoint
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'predictions': [{'variation_id': 'abc123'}]
}
self.mock_http_client.post.return_value = mock_response

result = client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid)

self.assertEqual(result, 'abc123')
expected_custom_url = custom_endpoint.format(self.rule_id)
self.mock_http_client.post.assert_called_once_with(
expected_custom_url,
data=json.dumps(self.expected_body),
headers=self.expected_headers,
timeout=10.0
)

def test_default_prediction_endpoint(self):
"""Test that default prediction endpoint is used when none is provided."""
client = DefaultCmabClient(
http_client=self.mock_http_client,
logger=self.mock_logger
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'predictions': [{'variation_id': 'def456'}]
}
self.mock_http_client.post.return_value = mock_response

result = client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid)

self.assertEqual(result, 'def456')
# Should use the default production endpoint
self.mock_http_client.post.assert_called_once_with(
self.expected_url,
data=json.dumps(self.expected_body),
headers=self.expected_headers,
timeout=10.0
)

def test_empty_prediction_endpoint_uses_default(self):
"""Test that empty string prediction endpoint falls back to default."""
client = DefaultCmabClient(
http_client=self.mock_http_client,
logger=self.mock_logger,
prediction_endpoint=""
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'predictions': [{'variation_id': 'ghi789'}]
}
self.mock_http_client.post.return_value = mock_response

result = client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid)

self.assertEqual(result, 'ghi789')
# Should use the default production endpoint when empty string is provided
self.mock_http_client.post.assert_called_once_with(
self.expected_url,
data=json.dumps(self.expected_body),
headers=self.expected_headers,
timeout=10.0
)
6 changes: 4 additions & 2 deletions tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,10 @@ def test_fetch_datafile__exception_polling_thread_failed(self, _):
log_messages = [args[0] for args, _ in mock_logger.error.call_args_list]
for message in log_messages:
print(message)
if "Thread for background datafile polling failed. " \
"Error: timestamp too large to convert to C PyTime_t" not in message:
# Check for key parts of the error message (version-agnostic for Python 3.11+)
if not ("Thread for background datafile polling failed" in message and
"timestamp too large to convert to C" in message and
"PyTime_t" in message):
assert False

def test_is_running(self, _):
Expand Down
Loading