diff --git a/optimizely/cmab/cmab_client.py b/optimizely/cmab/cmab_client.py index 25c18518..4880b0bb 100644 --- a/optimizely/cmab/cmab_client.py +++ b/optimizely/cmab/cmab_client.py @@ -25,6 +25,7 @@ DEFAULT_MAX_BACKOFF = 10 # in seconds DEFAULT_BACKOFF_MULTIPLIER = 2.0 MAX_WAIT_TIME = 10.0 +DEFAULT_PREDICTION_ENDPOINT = "https://prediction.cmab.optimizely.com/predict/{}" class CmabRetryConfig: @@ -52,17 +53,21 @@ class DefaultCmabClient: """ def __init__(self, http_client: Optional[requests.Session] = None, retry_config: Optional[CmabRetryConfig] = None, - logger: Optional[_logging.Logger] = None): + logger: Optional[_logging.Logger] = None, + prediction_endpoint: Optional[str] = None): """Initialize the CMAB client. Args: http_client (Optional[requests.Session]): HTTP client for making requests. retry_config (Optional[CmabRetryConfig]): Configuration for retry logic. logger (Optional[_logging.Logger]): Logger for logging messages. + prediction_endpoint (Optional[str]): Custom prediction endpoint URL template. + Use {} as placeholder for rule_id. """ self.http_client = http_client or requests.Session() self.retry_config = retry_config self.logger = _logging.adapt_logger(logger or _logging.NoOpLogger()) + self.prediction_endpoint = prediction_endpoint or DEFAULT_PREDICTION_ENDPOINT def fetch_decision( self, @@ -84,7 +89,7 @@ def fetch_decision( Returns: str: The variation ID. """ - url = f"https://prediction.cmab.optimizely.com/predict/{rule_id}" + url = self.prediction_endpoint.format(rule_id) cmab_attributes = [ {"id": key, "value": value, "type": "custom_attribute"} for key, value in attributes.items() diff --git a/optimizely/helpers/sdk_settings.py b/optimizely/helpers/sdk_settings.py index 6b31ee9c..e5e7aeb1 100644 --- a/optimizely/helpers/sdk_settings.py +++ b/optimizely/helpers/sdk_settings.py @@ -33,7 +33,8 @@ def __init__( odp_event_manager: Optional[OdpEventManager] = None, odp_segment_request_timeout: Optional[int] = None, odp_event_request_timeout: Optional[int] = None, - odp_event_flush_interval: Optional[int] = None + odp_event_flush_interval: Optional[int] = None, + cmab_prediction_endpoint: Optional[str] = None ) -> None: """ Args: @@ -52,6 +53,8 @@ def __init__( send successfully (optional). odp_event_request_timeout: Time to wait in seconds for send_odp_events request to send successfully. odp_event_flush_interval: Time to wait for events to accumulate before sending a batch in seconds (optional). + cmab_prediction_endpoint: Custom CMAB prediction endpoint URL template (optional). + Use {} as placeholder for rule_id. Defaults to production endpoint if not provided. """ self.odp_disabled = odp_disabled @@ -63,3 +66,4 @@ def __init__( self.fetch_segments_timeout = odp_segment_request_timeout self.odp_event_timeout = odp_event_request_timeout self.odp_flush_interval = odp_event_flush_interval + self.cmab_prediction_endpoint = cmab_prediction_endpoint diff --git a/optimizely/optimizely.py b/optimizely/optimizely.py index 4a47bbdb..6a1acaf2 100644 --- a/optimizely/optimizely.py +++ b/optimizely/optimizely.py @@ -178,9 +178,15 @@ def __init__( if cmab_service: self.cmab_service = cmab_service else: + # Get custom prediction endpoint from settings if provided + cmab_prediction_endpoint = None + if self.sdk_settings and self.sdk_settings.cmab_prediction_endpoint: + cmab_prediction_endpoint = self.sdk_settings.cmab_prediction_endpoint + self.cmab_client = DefaultCmabClient( retry_config=CmabRetryConfig(), - logger=self.logger + logger=self.logger, + prediction_endpoint=cmab_prediction_endpoint ) self.cmab_cache: LRUCache[str, CmabCacheValue] = LRUCache(DEFAULT_CMAB_CACHE_SIZE, DEFAULT_CMAB_CACHE_TIMEOUT) diff --git a/tests/test_cmab_client.py b/tests/test_cmab_client.py index 3aac5fd9..3613da76 100644 --- a/tests/test_cmab_client.py +++ b/tests/test_cmab_client.py @@ -245,3 +245,81 @@ def test_fetch_decision_exhausts_all_retry_attempts(self, mock_sleep): self.mock_logger.error.assert_called_with( Errors.CMAB_FETCH_FAILED.format('Exhausted all retries for CMAB request.') ) + + def test_custom_prediction_endpoint(self): + """Test that custom prediction endpoint is used correctly.""" + custom_endpoint = "https://custom.endpoint.com/predict/{}" + client = DefaultCmabClient( + http_client=self.mock_http_client, + logger=self.mock_logger, + prediction_endpoint=custom_endpoint + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'predictions': [{'variation_id': 'abc123'}] + } + self.mock_http_client.post.return_value = mock_response + + result = client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + + self.assertEqual(result, 'abc123') + expected_custom_url = custom_endpoint.format(self.rule_id) + self.mock_http_client.post.assert_called_once_with( + expected_custom_url, + data=json.dumps(self.expected_body), + headers=self.expected_headers, + timeout=10.0 + ) + + def test_default_prediction_endpoint(self): + """Test that default prediction endpoint is used when none is provided.""" + client = DefaultCmabClient( + http_client=self.mock_http_client, + logger=self.mock_logger + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'predictions': [{'variation_id': 'def456'}] + } + self.mock_http_client.post.return_value = mock_response + + result = client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + + self.assertEqual(result, 'def456') + # Should use the default production endpoint + self.mock_http_client.post.assert_called_once_with( + self.expected_url, + data=json.dumps(self.expected_body), + headers=self.expected_headers, + timeout=10.0 + ) + + def test_empty_prediction_endpoint_uses_default(self): + """Test that empty string prediction endpoint falls back to default.""" + client = DefaultCmabClient( + http_client=self.mock_http_client, + logger=self.mock_logger, + prediction_endpoint="" + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'predictions': [{'variation_id': 'ghi789'}] + } + self.mock_http_client.post.return_value = mock_response + + result = client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid) + + self.assertEqual(result, 'ghi789') + # Should use the default production endpoint when empty string is provided + self.mock_http_client.post.assert_called_once_with( + self.expected_url, + data=json.dumps(self.expected_body), + headers=self.expected_headers, + timeout=10.0 + ) diff --git a/tests/test_config_manager.py b/tests/test_config_manager.py index 56674381..1930520e 100644 --- a/tests/test_config_manager.py +++ b/tests/test_config_manager.py @@ -517,8 +517,10 @@ def test_fetch_datafile__exception_polling_thread_failed(self, _): log_messages = [args[0] for args, _ in mock_logger.error.call_args_list] for message in log_messages: print(message) - if "Thread for background datafile polling failed. " \ - "Error: timestamp too large to convert to C PyTime_t" not in message: + # Check for key parts of the error message (version-agnostic for Python 3.11+) + if not ("Thread for background datafile polling failed" in message and + "timestamp too large to convert to C" in message and + "PyTime_t" in message): assert False def test_is_running(self, _):