77import json
88import logging
99import mimetypes
10- from typing import Any , Iterable , Optional , TypedDict , cast
10+ from typing import Any , Callable , Iterable , Optional , Type , TypedDict , TypeVar , cast
1111
1212import anthropic
13+ from pydantic import BaseModel
1314from typing_extensions import Required , Unpack , override
1415
16+ from ..event_loop .streaming import process_stream
17+ from ..handlers .callback_handler import PrintingCallbackHandler
18+ from ..tools import convert_pydantic_to_tool_spec
1519from ..types .content import ContentBlock , Messages
1620from ..types .exceptions import ContextWindowOverflowException , ModelThrottledException
1721from ..types .models import Model
2024
2125logger = logging .getLogger (__name__ )
2226
27+ T = TypeVar ("T" , bound = BaseModel )
28+
2329
2430class AnthropicModel (Model ):
2531 """Anthropic model provider implementation."""
@@ -356,10 +362,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
356362 with self .client .messages .stream (** request ) as stream :
357363 for event in stream :
358364 if event .type in AnthropicModel .EVENT_TYPES :
359- yield event .dict ()
365+ yield event .model_dump ()
360366
361367 usage = event .message .usage # type: ignore
362- yield {"type" : "metadata" , "usage" : usage .dict ()}
368+ yield {"type" : "metadata" , "usage" : usage .model_dump ()}
363369
364370 except anthropic .RateLimitError as error :
365371 raise ModelThrottledException (str (error )) from error
@@ -369,3 +375,42 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
369375 raise ContextWindowOverflowException (str (error )) from error
370376
371377 raise error
378+
379+ @override
380+ def structured_output (
381+ self , output_model : Type [T ], prompt : Messages , callback_handler : Optional [Callable ] = None
382+ ) -> T :
383+ """Get structured output from the model.
384+
385+ Args:
386+ output_model(Type[BaseModel]): The output model to use for the agent.
387+ prompt(Messages): The prompt messages to use for the agent.
388+ callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
389+ """
390+ tool_spec = convert_pydantic_to_tool_spec (output_model )
391+
392+ response = self .converse (messages = prompt , tool_specs = [tool_spec ])
393+ # process the stream and get the tool use input
394+ results = process_stream (
395+ response , callback_handler = callback_handler or PrintingCallbackHandler (), messages = prompt
396+ )
397+
398+ stop_reason , messages , _ , _ , _ = results
399+
400+ if stop_reason != "tool_use" :
401+ raise ValueError ("No valid tool use or tool use input was found in the Anthropic response." )
402+
403+ content = messages ["content" ]
404+ output_response : dict [str , Any ] | None = None
405+ for block in content :
406+ # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip.
407+ # if the tool use name never matches, raise an error.
408+ if block .get ("toolUse" ) and block ["toolUse" ]["name" ] == tool_spec ["name" ]:
409+ output_response = block ["toolUse" ]["input" ]
410+ else :
411+ continue
412+
413+ if output_response is None :
414+ raise ValueError ("No valid tool use or tool use input was found in the Anthropic response." )
415+
416+ return output_model (** output_response )
0 commit comments