22
33import json
44import logging
5- from typing import Any , Dict , Iterable , List , Optional , Tuple
5+ from typing import Any , Generator , Iterable , Optional
66
77from ..types .content import ContentBlock , Message , Messages
88from ..types .models import Model
@@ -80,7 +80,7 @@ def handle_message_start(event: MessageStartEvent, message: Message) -> Message:
8080 return message
8181
8282
83- def handle_content_block_start (event : ContentBlockStartEvent ) -> Dict [str , Any ]:
83+ def handle_content_block_start (event : ContentBlockStartEvent ) -> dict [str , Any ]:
8484 """Handles the start of a content block by extracting tool usage information if any.
8585
8686 Args:
@@ -102,61 +102,59 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]:
102102
103103
def handle_content_block_delta(
    event: ContentBlockDeltaEvent, state: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Handles content block delta updates by appending text, tool input, or reasoning content to the state.

    Args:
        event: Delta event.
        state: The current state of message processing.

    Returns:
        A tuple of the updated state (with appended text, tool input, or reasoning
        content) and a callback event dict describing the delta for the caller;
        the callback event is empty when the delta carries no recognized content.
    """
    delta_content = event["delta"]

    callback_event: dict[str, Any] = {}

    if "toolUse" in delta_content:
        # Tool input arrives as partial JSON fragments; accumulate them as a string.
        if "input" not in state["current_tool_use"]:
            state["current_tool_use"]["input"] = ""

        state["current_tool_use"]["input"] += delta_content["toolUse"]["input"]
        callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]}

    elif "text" in delta_content:
        state["text"] += delta_content["text"]
        callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}

    elif "reasoningContent" in delta_content:
        if "text" in delta_content["reasoningContent"]:
            # Lazily create the accumulator the first time reasoning text appears.
            if "reasoningText" not in state:
                state["reasoningText"] = ""

            state["reasoningText"] += delta_content["reasoningContent"]["text"]
            callback_event["callback"] = {
                "reasoningText": delta_content["reasoningContent"]["text"],
                "delta": delta_content,
                "reasoning": True,
            }

        elif "signature" in delta_content["reasoningContent"]:
            # Lazily create the accumulator the first time a signature fragment appears.
            if "signature" not in state:
                state["signature"] = ""

            state["signature"] += delta_content["reasoningContent"]["signature"]
            callback_event["callback"] = {
                "reasoning_signature": delta_content["reasoningContent"]["signature"],
                "delta": delta_content,
                "reasoning": True,
            }

    return state, callback_event
157155
158156
159- def handle_content_block_stop (state : Dict [str , Any ]) -> Dict [str , Any ]:
157+ def handle_content_block_stop (state : dict [str , Any ]) -> dict [str , Any ]:
160158 """Handles the end of a content block by finalizing tool usage, text content, or reasoning content.
161159
162160 Args:
@@ -165,7 +163,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]:
165163 Returns:
166164 Updated state with finalized content block.
167165 """
168- content : List [ContentBlock ] = state ["content" ]
166+ content : list [ContentBlock ] = state ["content" ]
169167
170168 current_tool_use = state ["current_tool_use" ]
171169 text = state ["text" ]
@@ -223,7 +221,7 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason:
223221 return event ["stopReason" ]
224222
225223
226- def handle_redact_content (event : RedactContentEvent , messages : Messages , state : Dict [str , Any ]) -> None :
224+ def handle_redact_content (event : RedactContentEvent , messages : Messages , state : dict [str , Any ]) -> None :
227225 """Handles redacting content from the input or output.
228226
229227 Args:
@@ -238,7 +236,7 @@ def handle_redact_content(event: RedactContentEvent, messages: Messages, state:
238236 state ["message" ]["content" ] = [{"text" : event ["redactAssistantContentMessage" ]}]
239237
240238
241- def extract_usage_metrics (event : MetadataEvent ) -> Tuple [Usage , Metrics ]:
239+ def extract_usage_metrics (event : MetadataEvent ) -> tuple [Usage , Metrics ]:
242240 """Extracts usage metrics from the metadata chunk.
243241
244242 Args:
@@ -255,25 +253,20 @@ def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]:
255253
256254def process_stream (
257255 chunks : Iterable [StreamEvent ],
258- callback_handler : Any ,
259256 messages : Messages ,
260- ** kwargs : Any ,
261- ) -> Tuple [StopReason , Message , Usage , Metrics , Any ]:
257+ ) -> Generator [dict [str , Any ], None , None ]:
262258 """Processes the response stream from the API, constructing the final message and extracting usage metrics.
263259
264260 Args:
265261 chunks: The chunks of the response stream from the model.
266- callback_handler: Callback for processing events as they happen.
267262 messages: The agents messages.
268- **kwargs: Additional keyword arguments that will be passed to the callback handler.
269- And also returned in the request_state.
270263
271264 Returns:
272- The reason for stopping, the constructed message, the usage metrics, and the updated request state .
265+ The reason for stopping, the constructed message, and the usage metrics .
273266 """
274267 stop_reason : StopReason = "end_turn"
275268
276- state : Dict [str , Any ] = {
269+ state : dict [str , Any ] = {
277270 "message" : {"role" : "assistant" , "content" : []},
278271 "text" : "" ,
279272 "current_tool_use" : {},
@@ -285,18 +278,16 @@ def process_stream(
285278 usage : Usage = Usage (inputTokens = 0 , outputTokens = 0 , totalTokens = 0 )
286279 metrics : Metrics = Metrics (latencyMs = 0 )
287280
288- kwargs .setdefault ("request_state" , {})
289-
290281 for chunk in chunks :
291- # Callback handler call here allows each event to be visible to the caller
292- callback_handler (event = chunk )
282+ yield {"callback" : {"event" : chunk }}
293283
294284 if "messageStart" in chunk :
295285 state ["message" ] = handle_message_start (chunk ["messageStart" ], state ["message" ])
296286 elif "contentBlockStart" in chunk :
297287 state ["current_tool_use" ] = handle_content_block_start (chunk ["contentBlockStart" ])
298288 elif "contentBlockDelta" in chunk :
299- state = handle_content_block_delta (chunk ["contentBlockDelta" ], state , callback_handler , ** kwargs )
289+ state , callback_event = handle_content_block_delta (chunk ["contentBlockDelta" ], state )
290+ yield callback_event
300291 elif "contentBlockStop" in chunk :
301292 state = handle_content_block_stop (state )
302293 elif "messageStop" in chunk :
@@ -306,35 +297,30 @@ def process_stream(
306297 elif "redactContent" in chunk :
307298 handle_redact_content (chunk ["redactContent" ], messages , state )
308299
309- return stop_reason , state ["message" ], usage , metrics , kwargs [ "request_state" ]
300+ yield { "stop" : ( stop_reason , state ["message" ], usage , metrics )}
310301
311302
def stream_messages(
    model: Model,
    system_prompt: Optional[str],
    messages: Messages,
    tool_config: Optional[ToolConfig],
) -> Generator[dict[str, Any], None, None]:
    """Streams messages to the model and processes the response.

    Args:
        model: Model provider.
        system_prompt: The system prompt to send.
        messages: List of messages to send.
        tool_config: Configuration for the tools to use.

    Yields:
        Events from the processed model response stream; the final event carries
        the reason for stopping, the final message, and the usage metrics.
    """
    logger.debug("model=<%s> | streaming messages", model)

    messages = remove_blank_messages_content_text(messages)

    # Flatten the tool config down to the bare tool specs; an empty list
    # collapses to None so the model call receives "no tools" explicitly.
    if tool_config:
        tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None
    else:
        tool_specs = None

    chunks = model.converse(messages, tool_specs, system_prompt)
    yield from process_stream(chunks, messages)
0 commit comments