@@ -245,3 +245,81 @@ def test_fetch_decision_exhausts_all_retry_attempts(self, mock_sleep):
245245 self .mock_logger .error .assert_called_with (
246246 Errors .CMAB_FETCH_FAILED .format ('Exhausted all retries for CMAB request.' )
247247 )
248+
249+ def test_custom_prediction_endpoint (self ):
250+ """Test that custom prediction endpoint is used correctly."""
251+ custom_endpoint = "https://custom.endpoint.com/predict/{}"
252+ client = DefaultCmabClient (
253+ http_client = self .mock_http_client ,
254+ logger = self .mock_logger ,
255+ prediction_endpoint = custom_endpoint
256+ )
257+
258+ mock_response = MagicMock ()
259+ mock_response .status_code = 200
260+ mock_response .json .return_value = {
261+ 'predictions' : [{'variation_id' : 'abc123' }]
262+ }
263+ self .mock_http_client .post .return_value = mock_response
264+
265+ result = client .fetch_decision (self .rule_id , self .user_id , self .attributes , self .cmab_uuid )
266+
267+ self .assertEqual (result , 'abc123' )
268+ expected_custom_url = custom_endpoint .format (self .rule_id )
269+ self .mock_http_client .post .assert_called_once_with (
270+ expected_custom_url ,
271+ data = json .dumps (self .expected_body ),
272+ headers = self .expected_headers ,
273+ timeout = 10.0
274+ )
275+
276+ def test_default_prediction_endpoint (self ):
277+ """Test that default prediction endpoint is used when none is provided."""
278+ client = DefaultCmabClient (
279+ http_client = self .mock_http_client ,
280+ logger = self .mock_logger
281+ )
282+
283+ mock_response = MagicMock ()
284+ mock_response .status_code = 200
285+ mock_response .json .return_value = {
286+ 'predictions' : [{'variation_id' : 'def456' }]
287+ }
288+ self .mock_http_client .post .return_value = mock_response
289+
290+ result = client .fetch_decision (self .rule_id , self .user_id , self .attributes , self .cmab_uuid )
291+
292+ self .assertEqual (result , 'def456' )
293+ # Should use the default production endpoint
294+ self .mock_http_client .post .assert_called_once_with (
295+ self .expected_url ,
296+ data = json .dumps (self .expected_body ),
297+ headers = self .expected_headers ,
298+ timeout = 10.0
299+ )
300+
301+ def test_empty_prediction_endpoint_uses_default (self ):
302+ """Test that empty string prediction endpoint falls back to default."""
303+ client = DefaultCmabClient (
304+ http_client = self .mock_http_client ,
305+ logger = self .mock_logger ,
306+ prediction_endpoint = ""
307+ )
308+
309+ mock_response = MagicMock ()
310+ mock_response .status_code = 200
311+ mock_response .json .return_value = {
312+ 'predictions' : [{'variation_id' : 'ghi789' }]
313+ }
314+ self .mock_http_client .post .return_value = mock_response
315+
316+ result = client .fetch_decision (self .rule_id , self .user_id , self .attributes , self .cmab_uuid )
317+
318+ self .assertEqual (result , 'ghi789' )
319+ # Should use the default production endpoint when empty string is provided
320+ self .mock_http_client .post .assert_called_once_with (
321+ self .expected_url ,
322+ data = json .dumps (self .expected_body ),
323+ headers = self .expected_headers ,
324+ timeout = 10.0
325+ )
0 commit comments