diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 0dbd85d81..d241b6c76 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -9,7 +9,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import Any, Union +from typing import Any, AsyncIterator, Union from ..agent import AgentResult from ..types.content import ContentBlock @@ -98,6 +98,31 @@ async def invoke_async( """ raise NotImplementedError("invoke_async not implemented") + async def stream_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AsyncIterator[dict[str, Any]]: + """Stream events during multi-agent execution. + + Default implementation executes invoke_async and yields the result as a single event. + Subclasses can override this method to provide true streaming capabilities. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Additional keyword arguments passed to underlying agents. + + Yields: + Dictionary events containing multi-agent execution information including: + - Multi-agent coordination events (node start/complete, handoffs) + - Forwarded single-agent events with node context + - Final result event + """ + # Default implementation for backward compatibility + # Execute invoke_async and yield the result as a single event + result = await self.invoke_async(task, invocation_state, **kwargs) + yield {"result": result} + def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> MultiAgentResult: diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 1dbbfc3af..493772ad3 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -20,13 +20,20 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any, Callable, Optional, Tuple +from typing import Any, AsyncIterator, Callable, Optional, Tuple, cast from opentelemetry import trace as trace_api from ..agent import Agent from ..agent.state import AgentState from ..telemetry import get_tracer +from ..types._events import ( + MultiAgentHandoffEvent, + MultiAgentNodeStartEvent, + MultiAgentNodeStopEvent, + MultiAgentNodeStreamEvent, + MultiAgentResultEvent, +) from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status @@ -411,13 +418,43 @@ async def invoke_async( ) -> GraphResult: """Invoke the graph asynchronously. + This method uses stream_async internally and consumes all events until completion, + following the same pattern as the Agent class. + Args: task: The task to execute invocation_state: Additional state/context passed to underlying agents. - Defaults to None to avoid mutable default argument issues - a new empty dict - is created if None is provided. + Defaults to None to avoid mutable default argument issues. **kwargs: Keyword arguments allowing backward compatible future changes. """ + events = self.stream_async(task, invocation_state, **kwargs) + final_event = None + async for event in events: + final_event = event + + if final_event is None or "result" not in final_event: + raise ValueError("Graph streaming completed without producing a result event") + + return cast(GraphResult, final_event["result"]) + + async def stream_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AsyncIterator[dict[str, Any]]: + """Stream events during graph execution. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Keyword arguments allowing backward compatible future changes. + + Yields: + Dictionary events during graph execution, such as: + - multi_agent_node_start: When a node begins execution + - multi_agent_node_stream: Forwarded agent/multi-agent events with node context + - multi_agent_node_stop: When a node stops execution + - result: Final graph result + """ if invocation_state is None: invocation_state = {} @@ -444,23 +481,29 @@ async def invoke_async( self.node_timeout or "None", ) - await self._execute_graph(invocation_state) + async for event in self._execute_graph(invocation_state): + yield event.as_dict() # Set final status based on execution results if self.state.failed_nodes: self.state.status = Status.FAILED - elif self.state.status == Status.EXECUTING: # Only set to COMPLETED if still executing and no failures + elif self.state.status == Status.EXECUTING: self.state.status = Status.COMPLETED logger.debug("status=<%s> | graph execution completed", self.state.status) + # Yield final result (consistent with Agent's AgentResultEvent format) + result = self._build_result() + + # Use the same event format as Agent for consistency + yield MultiAgentResultEvent(result=result).as_dict() + except Exception: logger.exception("graph execution failed") self.state.status = Status.FAILED raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) - return self._build_result() def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: """Validate graph nodes for duplicate instances.""" @@ -474,8 +517,8 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: - """Unified execution flow with conditional routing.""" + async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + """Execute graph and yield TypedEvent objects.""" ready_nodes = list(self.entry_points) while ready_nodes: @@ -492,16 +535,149 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: current_batch = ready_nodes.copy() ready_nodes.clear() - # Execute current batch of ready nodes concurrently - tasks = [asyncio.create_task(self._execute_node(node, invocation_state)) for node in current_batch] - - for task in tasks: - await task + # Execute current batch + async for event in self._execute_nodes_parallel(current_batch, invocation_state): + yield event # Find newly ready nodes after batch execution # We add all nodes in current batch as completed batch, # because a failure would throw exception and code would not make it here - ready_nodes.extend(self._find_newly_ready_nodes(current_batch)) + newly_ready = self._find_newly_ready_nodes(current_batch) + + # Emit handoff event for batch transition if there are nodes to transition to + if newly_ready: + handoff_event = MultiAgentHandoffEvent( + from_node_ids=[node.node_id for node in current_batch], + to_node_ids=[node.node_id for node in newly_ready], + ) + yield handoff_event + logger.debug( + "from_node_ids=<%s>, to_node_ids=<%s> | batch transition", + [node.node_id for node in current_batch], + [node.node_id for node in newly_ready], + ) + + ready_nodes.extend(newly_ready) + + async def _execute_nodes_parallel( + self, nodes: list["GraphNode"], invocation_state: dict[str, Any] + ) -> AsyncIterator[Any]: + """Execute multiple nodes in parallel and merge their event streams in real-time. + + Uses a shared queue where each node's stream runs independently and pushes events + as they occur, enabling true real-time event propagation without round-robin delays. + """ + event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue() + + # Start all node streams as independent tasks + tasks = [asyncio.create_task(self._stream_node_to_queue(node, event_queue, invocation_state)) for node in nodes] + + try: + # Consume events from the queue as they arrive + # Continue until all tasks are done + while any(not task.done() for task in tasks): + try: + # Use timeout to avoid race condition: if all tasks complete between + # checking task.done() and calling queue.get(), we'd hang forever. + # The 0.1s timeout allows us to periodically re-check task completion + # while still being responsive to incoming events. + event = await asyncio.wait_for(event_queue.get(), timeout=0.1) + except asyncio.TimeoutError: + # No event available, continue checking tasks + continue + + # Check if it's an exception - fail fast + if isinstance(event, Exception): + # Cancel all other tasks immediately + for task in tasks: + if not task.done(): + task.cancel() + raise event + + if event is not None: + yield event + + # Process any remaining events in the queue after all tasks complete + while not event_queue.empty(): + event = await event_queue.get() + if isinstance(event, Exception): + raise event + if event is not None: + yield event + finally: + # Cancel any remaining tasks + remaining_tasks = [task for task in tasks if not task.done()] + if remaining_tasks: + logger.warning( + "remaining_task_count=<%d> | cancelling remaining tasks in finally block", + len(remaining_tasks), + ) + for task in remaining_tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + async def _stream_node_to_queue( + self, + node: GraphNode, + event_queue: asyncio.Queue[Any | None | Exception], + invocation_state: dict[str, Any], + ) -> None: + """Stream events from a node to the shared queue with optional timeout.""" + try: + # Apply timeout to the entire streaming process if configured + if self.node_timeout is not None: + + async def stream_node() -> None: + async for event in self._execute_node(node, invocation_state): + await event_queue.put(event) + + try: + await asyncio.wait_for(stream_node(), timeout=self.node_timeout) + except asyncio.TimeoutError: + # Handle timeout and send exception through queue + timeout_exc = await self._handle_node_timeout(node, event_queue) + await event_queue.put(timeout_exc) + else: + # No timeout - stream normally + async for event in self._execute_node(node, invocation_state): + await event_queue.put(event) + except Exception as e: + # Send exception through queue for fail-fast behavior + await event_queue.put(e) + finally: + await event_queue.put(None) + + async def _handle_node_timeout(self, node: GraphNode, event_queue: asyncio.Queue[Any | None]) -> Exception: + """Handle a node timeout by creating a failed result and emitting events. + + Returns: + The timeout exception to be re-raised for fail-fast behavior + """ + assert self.node_timeout is not None + timeout_exception = Exception(f"Node '{node.node_id}' execution timed out after {self.node_timeout}s") + + node_result = NodeResult( + result=timeout_exception, + execution_time=round(self.node_timeout * 1000), + status=Status.FAILED, + accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + accumulated_metrics=Metrics(latencyMs=round(self.node_timeout * 1000)), + execution_count=1, + ) + + node.execution_status = Status.FAILED + node.result = node_result + node.execution_time = node_result.execution_time + self.state.failed_nodes.add(node) + self.state.results[node.node_id] = node_result + + complete_event = MultiAgentNodeStopEvent( + node_id=node.node_id, + node_result=node_result, + ) + await event_queue.put(complete_event) + + return timeout_exception def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]: """Find nodes that became ready after the last execution.""" @@ -530,90 +706,92 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ ) return False - async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None: - """Execute a single node with error handling and timeout protection.""" + async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + """Execute a single node and yield TypedEvent objects.""" # Reset the node's state if reset_on_revisit is enabled and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) node.reset_executor_state() - # Remove from completed nodes since we're re-executing it self.state.completed_nodes.remove(node) node.execution_status = Status.EXECUTING logger.debug("node_id=<%s> | executing node", node.node_id) + # Emit node start event + start_event = MultiAgentNodeStartEvent( + node_id=node.node_id, node_type="agent" if isinstance(node.executor, Agent) else "multiagent" + ) + yield start_event + start_time = time.time() try: # Build node input from satisfied dependencies node_input = self._build_node_input(node) - # Execute with timeout protection (only if node_timeout is set) - try: - # Execute based on node type and create unified NodeResult - if isinstance(node.executor, MultiAgentBase): - if self.node_timeout is not None: - multi_agent_result = await asyncio.wait_for( - node.executor.invoke_async(node_input, invocation_state), - timeout=self.node_timeout, - ) - else: - multi_agent_result = await node.executor.invoke_async(node_input, invocation_state) - - # Create NodeResult with MultiAgentResult directly - node_result = NodeResult( - result=multi_agent_result, # type is MultiAgentResult - execution_time=multi_agent_result.execution_time, - status=Status.COMPLETED, - accumulated_usage=multi_agent_result.accumulated_usage, - accumulated_metrics=multi_agent_result.accumulated_metrics, - execution_count=multi_agent_result.execution_count, - ) + # Execute and stream events (timeout handled at task level) + if isinstance(node.executor, MultiAgentBase): + # For nested multi-agent systems, stream their events and collect result + multi_agent_result = None + async for event in node.executor.stream_async(node_input, invocation_state): + # Forward nested multi-agent events with node context + wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) + yield wrapped_event + # Capture the final result event + if "result" in event: + multi_agent_result = event["result"] + + # Use the captured result from streaming (no double execution) + if multi_agent_result is None: + raise ValueError(f"Node '{node.node_id}' did not produce a result event") + + node_result = NodeResult( + result=multi_agent_result, + execution_time=multi_agent_result.execution_time, + status=Status.COMPLETED, + accumulated_usage=multi_agent_result.accumulated_usage, + accumulated_metrics=multi_agent_result.accumulated_metrics, + execution_count=multi_agent_result.execution_count, + ) - elif isinstance(node.executor, Agent): - if self.node_timeout is not None: - agent_response = await asyncio.wait_for( - node.executor.invoke_async(node_input, invocation_state=invocation_state), - timeout=self.node_timeout, - ) - else: - agent_response = await node.executor.invoke_async(node_input, invocation_state=invocation_state) - - if agent_response.stop_reason == "interrupt": - node.executor.messages.pop() # remove interrupted tool use message - node.executor._interrupt_state.deactivate() - - raise RuntimeError( - "user raised interrupt from agent | interrupts are not yet supported in graphs" - ) - - # Extract metrics from agent response - usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) - metrics = Metrics(latencyMs=0) - if hasattr(agent_response, "metrics") and agent_response.metrics: - if hasattr(agent_response.metrics, "accumulated_usage"): - usage = agent_response.metrics.accumulated_usage - if hasattr(agent_response.metrics, "accumulated_metrics"): - metrics = agent_response.metrics.accumulated_metrics - - node_result = NodeResult( - result=agent_response, # type is AgentResult - execution_time=round((time.time() - start_time) * 1000), - status=Status.COMPLETED, - accumulated_usage=usage, - accumulated_metrics=metrics, - execution_count=1, - ) - else: - raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") - - except asyncio.TimeoutError: - timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s" - logger.exception( - "node=<%s>, timeout=<%s>s | node execution timed out after timeout", - node.node_id, - self.node_timeout, + elif isinstance(node.executor, Agent): + # For agents, stream their events and collect result + agent_response = None + async for event in node.executor.stream_async(node_input, invocation_state=invocation_state): + # Forward agent events with node context + wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) + yield wrapped_event + # Capture the final result event + if "result" in event: + agent_response = event["result"] + + # Use the captured result from streaming (no double execution) + if agent_response is None: + raise ValueError(f"Node '{node.node_id}' did not produce a result event") + + # Check for interrupt (from main branch) + if agent_response.stop_reason == "interrupt": + node.executor.messages.pop() # remove interrupted tool use message + node.executor._interrupt_state.deactivate() + + raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in graphs") + + # Extract metrics with defaults + response_metrics = getattr(agent_response, "metrics", None) + usage = getattr( + response_metrics, "accumulated_usage", Usage(inputTokens=0, outputTokens=0, totalTokens=0) ) - raise Exception(timeout_msg) from None + metrics = getattr(response_metrics, "accumulated_metrics", Metrics(latencyMs=0)) + + node_result = NodeResult( + result=agent_response, + execution_time=round((time.time() - start_time) * 1000), + status=Status.COMPLETED, + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=1, + ) + else: + raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") # Mark as completed node.execution_status = Status.COMPLETED @@ -626,17 +804,28 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Accumulate metrics self._accumulate_metrics(node_result) + # Emit node stop event with full NodeResult + complete_event = MultiAgentNodeStopEvent( + node_id=node.node_id, + node_result=node_result, + ) + yield complete_event + logger.debug( - "node_id=<%s>, execution_time=<%dms> | node completed successfully", node.node_id, node.execution_time + "node_id=<%s>, execution_time=<%dms> | node completed successfully", + node.node_id, + node.execution_time, ) except Exception as e: + # All failures (programming errors and execution failures) stop graph execution + # This matches the old fail-fast behavior logger.error("node_id=<%s>, error=<%s> | node failed", node.node_id, e) execution_time = round((time.time() - start_time) * 1000) # Create a NodeResult for the failed node node_result = NodeResult( - result=e, # Store exception as result + result=e, execution_time=execution_time, status=Status.FAILED, accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), @@ -648,8 +837,16 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) node.result = node_result node.execution_time = execution_time self.state.failed_nodes.add(node) - self.state.results[node.node_id] = node_result # Store in results for consistency + self.state.results[node.node_id] = node_result + + # Emit stop event even for failures + complete_event = MultiAgentNodeStopEvent( + node_id=node.node_id, + node_result=node_result, + ) + yield complete_event + # Re-raise to stop graph execution (fail-fast behavior) raise def _accumulate_metrics(self, node_result: NodeResult) -> None: diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 7542b1b85..87aedaf29 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -19,14 +19,21 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any, Callable, Tuple +from typing import Any, AsyncIterator, Callable, Tuple, cast from opentelemetry import trace as trace_api -from ..agent import Agent, AgentResult +from ..agent import Agent from ..agent.state import AgentState from ..telemetry import get_tracer from ..tools.decorator import tool +from ..types._events import ( + MultiAgentHandoffEvent, + MultiAgentNodeStartEvent, + MultiAgentNodeStopEvent, + MultiAgentNodeStreamEvent, + MultiAgentResultEvent, +) from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status @@ -266,12 +273,43 @@ async def invoke_async( ) -> SwarmResult: """Invoke the swarm asynchronously. + This method uses stream_async internally and consumes all events until completion, + following the same pattern as the Agent class. + Args: task: The task to execute invocation_state: Additional state/context passed to underlying agents. - Defaults to None to avoid mutable default argument issues - a new empty dict - is created if None is provided. + Defaults to None to avoid mutable default argument issues. + **kwargs: Keyword arguments allowing backward compatible future changes. + """ + events = self.stream_async(task, invocation_state, **kwargs) + final_event = None + async for event in events: + final_event = event + + if final_event is None or "result" not in final_event: + raise ValueError("Swarm streaming completed without producing a result event") + + return cast(SwarmResult, final_event["result"]) + + async def stream_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AsyncIterator[dict[str, Any]]: + """Stream events during swarm execution. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. **kwargs: Keyword arguments allowing backward compatible future changes. + + Yields: + Dictionary events during swarm execution, such as: + - multi_agent_node_start: When a node begins execution + - multi_agent_node_stream: Forwarded agent events with node context + - multi_agent_handoff: When control is handed off between agents + - multi_agent_node_stop: When a node stops execution + - result: Final swarm result """ if invocation_state is None: invocation_state = {} @@ -282,7 +320,7 @@ async def invoke_async( if self.entry_point: initial_node = self.nodes[str(self.entry_point.name)] else: - initial_node = next(iter(self.nodes.values())) # First SwarmNode + initial_node = next(iter(self.nodes.values())) self.state = SwarmState( current_node=initial_node, @@ -303,7 +341,9 @@ async def invoke_async( self.execution_timeout, ) - await self._execute_swarm(invocation_state) + async for event in self._execute_swarm(invocation_state): + yield event.as_dict() + except Exception: logger.exception("swarm execution failed") self.state.completion_status = Status.FAILED @@ -311,7 +351,52 @@ async def invoke_async( finally: self.state.execution_time = round((time.time() - start_time) * 1000) - return self._build_result() + # Yield final result after execution_time is set + result = self._build_result() + yield MultiAgentResultEvent(result=result).as_dict() + + async def _stream_with_timeout( + self, async_generator: AsyncIterator[Any], timeout: float | None, timeout_message: str + ) -> AsyncIterator[Any]: + """Wrap an async generator with timeout for total execution time. + + Tracks elapsed time from start and enforces timeout across all events. + Each event wait uses remaining time from the total timeout budget. + + Args: + async_generator: The generator to wrap + timeout: Total timeout in seconds for entire stream, or None for no timeout + timeout_message: Message to include in timeout exception + + Yields: + Events from the wrapped generator as they arrive + + Raises: + Exception: If total execution time exceeds timeout + """ + if timeout is None: + # No timeout - just pass through + async for event in async_generator: + yield event + else: + # Track start time for total timeout + start_time = asyncio.get_event_loop().time() + + while True: + # Calculate remaining time from total timeout budget + elapsed = asyncio.get_event_loop().time() - start_time + remaining = timeout - elapsed + + if remaining <= 0: + raise Exception(timeout_message) + + try: + event = await asyncio.wait_for(async_generator.__anext__(), timeout=remaining) + yield event + except StopAsyncIteration: + break + except asyncio.TimeoutError as err: + raise Exception(timeout_message) from err def _setup_swarm(self, nodes: list[Agent]) -> None: """Initialize swarm configuration.""" @@ -533,14 +618,14 @@ def _build_node_input(self, target_node: SwarmNode) -> str: return context_text - async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: - """Shared execution logic used by execute_async.""" + async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + """Execute swarm and yield TypedEvent objects.""" try: # Main execution loop while True: if self.state.completion_status != Status.EXECUTING: reason = f"Completion status is: {self.state.completion_status}" - logger.debug("reason=<%s> | stopping execution", reason) + logger.debug("reason=<%s> | stopping streaming execution", reason) break should_continue, reason = self.state.should_continue( @@ -568,34 +653,45 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: len(self.state.node_history) + 1, ) + # Store the current node before execution to detect handoffs + previous_node = current_node + # Execute node with timeout protection # TODO: Implement cancellation token to stop _execute_node from continuing try: - await asyncio.wait_for( + # Execute with timeout wrapper for async generator streaming + node_stream = self._stream_with_timeout( self._execute_node(current_node, self.state.task, invocation_state), - timeout=self.node_timeout, + self.node_timeout, + f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s", ) + async for event in node_stream: + yield event self.state.node_history.append(current_node) logger.debug("node=<%s> | node execution completed", current_node.node_id) - # Check if the current node is still the same after execution - # If it is, then no handoff occurred and we consider the swarm complete - if self.state.current_node == current_node: + # Check if handoff occurred during execution + if self.state.current_node != previous_node: + # Emit handoff event (single node transition in Swarm) + handoff_event = MultiAgentHandoffEvent( + from_node_ids=[previous_node.node_id], + to_node_ids=[self.state.current_node.node_id], + message=self.state.handoff_message or "Agent handoff occurred", + ) + yield handoff_event + logger.debug( + "from_node=<%s>, to_node=<%s> | handoff detected", + previous_node.node_id, + self.state.current_node.node_id, + ) + else: + # No handoff occurred, mark swarm as complete logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id) self.state.completion_status = Status.COMPLETED break - except asyncio.TimeoutError: - logger.exception( - "node=<%s>, timeout=<%s>s | node execution timed out after timeout", - current_node.node_id, - self.node_timeout, - ) - self.state.completion_status = Status.FAILED - break - except Exception: logger.exception("node=<%s> | node execution failed", current_node.node_id) self.state.completion_status = Status.FAILED @@ -604,22 +700,26 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: except Exception: logger.exception("swarm execution failed") self.state.completion_status = Status.FAILED - - elapsed_time = time.time() - self.state.start_time - logger.debug("status=<%s> | swarm execution completed", self.state.completion_status) - logger.debug( - "node_history_length=<%d>, time=<%s>s | metrics", - len(self.state.node_history), - f"{elapsed_time:.2f}", - ) + finally: + elapsed_time = time.time() - self.state.start_time + logger.debug("status=<%s> | swarm execution completed", self.state.completion_status) + logger.debug( + "node_history_length=<%d>, time=<%s>s | metrics", + len(self.state.node_history), + f"{elapsed_time:.2f}", + ) async def _execute_node( self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] - ) -> AgentResult: - """Execute swarm node.""" + ) -> AsyncIterator[Any]: + """Execute swarm node and yield TypedEvent objects.""" start_time = time.time() node_name = node.node_id + # Emit node start event + start_event = MultiAgentNodeStartEvent(node_id=node_name, node_type="agent") + yield start_event + try: # Prepare context for node context_text = self._build_node_input(node) @@ -632,10 +732,21 @@ async def _execute_node( # Include additional ContentBlocks in node input node_input = node_input + task - # Execute node - result = None + # Execute node with streaming node.reset_executor_state() - result = await node.executor.invoke_async(node_input, invocation_state=invocation_state) + + # Stream agent events with node context and capture final result + result = None + async for event in node.executor.stream_async(node_input, invocation_state=invocation_state): + # Forward agent events with node context + wrapped_event = MultiAgentNodeStreamEvent(node_name, event) + yield wrapped_event + # Capture the final result event + if "result" in event: + result = event["result"] + + if result is None: + raise ValueError(f"Node '{node_name}' did not produce a result event") if result.stop_reason == "interrupt": node.executor.messages.pop() # remove interrupted tool use message @@ -645,14 +756,10 @@ async def _execute_node( execution_time = round((time.time() - start_time) * 1000) - # Create NodeResult - usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) - metrics = Metrics(latencyMs=execution_time) - if hasattr(result, "metrics") and result.metrics: - if hasattr(result.metrics, "accumulated_usage"): - usage = result.metrics.accumulated_usage - if hasattr(result.metrics, "accumulated_metrics"): - metrics = result.metrics.accumulated_metrics + # Create NodeResult with extracted metrics + result_metrics = getattr(result, "metrics", None) + usage = getattr(result_metrics, "accumulated_usage", Usage(inputTokens=0, outputTokens=0, totalTokens=0)) + metrics = getattr(result_metrics, "accumulated_metrics", Metrics(latencyMs=execution_time)) node_result = NodeResult( result=result, @@ -669,7 +776,12 @@ async def _execute_node( # Accumulate metrics self._accumulate_metrics(node_result) - return result + # Emit node stop event with full NodeResult + complete_event = MultiAgentNodeStopEvent( + node_id=node_name, + node_result=node_result, + ) + yield complete_event except Exception as e: execution_time = round((time.time() - start_time) * 1000) @@ -677,7 +789,7 @@ async def _execute_node( # Create a NodeResult for the failed node node_result = NodeResult( - result=e, # Store exception as result + result=e, execution_time=execution_time, status=Status.FAILED, accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), @@ -688,6 +800,13 @@ async def _execute_node( # Store result in state self.state.results[node_name] = node_result + # Emit node stop event even for failures + complete_event = MultiAgentNodeStopEvent( + node_id=node_name, + node_result=node_result, + ) + yield complete_event + raise def _accumulate_metrics(self, node_result: NodeResult) -> None: diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 13d4a98f9..4ac570425 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from ..agent import AgentResult + from ..multiagent.base import MultiAgentResult, NodeResult class TypedEvent(dict): @@ -395,3 +396,116 @@ def __init__(self, reason: str | Exception) -> None: class AgentResultEvent(TypedEvent): def __init__(self, result: "AgentResult"): super().__init__({"result": result}) + + +class MultiAgentResultEvent(TypedEvent): + """Event emitted when multi-agent execution completes with final result.""" + + def __init__(self, result: "MultiAgentResult") -> None: + """Initialize with multi-agent result. + + Args: + result: The final result from multi-agent execution (SwarmResult, GraphResult, etc.) + """ + super().__init__({"type": "multiagent_result", "result": result}) + + +class MultiAgentNodeStartEvent(TypedEvent): + """Event emitted when a node begins execution in multi-agent context.""" + + def __init__(self, node_id: str, node_type: str) -> None: + """Initialize with node information. + + Args: + node_id: Unique identifier for the node + node_type: Type of node ("agent", "swarm", "graph") + """ + super().__init__({"type": "multiagent_node_start", "node_id": node_id, "node_type": node_type}) + + +class MultiAgentNodeStopEvent(TypedEvent): + """Event emitted when a node stops execution. + + Similar to EventLoopStopEvent but for individual nodes in multi-agent orchestration. + Provides the complete NodeResult which contains execution details, metrics, and status. + """ + + def __init__( + self, + node_id: str, + node_result: "NodeResult", + ) -> None: + """Initialize with stop information. + + Args: + node_id: Unique identifier for the node + node_result: Complete result from the node execution containing result, + execution_time, status, accumulated_usage, accumulated_metrics, and execution_count + """ + super().__init__( + { + "type": "multiagent_node_stop", + "node_id": node_id, + "node_result": node_result, + } + ) + + +class MultiAgentHandoffEvent(TypedEvent): + """Event emitted during node transitions in multi-agent systems. + + Supports both single handoffs (Swarm) and batch transitions (Graph). + For Swarm: Single node-to-node handoffs with a message. + For Graph: Batch transitions where multiple nodes complete and multiple nodes begin. + """ + + def __init__( + self, + from_node_ids: list[str], + to_node_ids: list[str], + message: str | None = None, + ) -> None: + """Initialize with handoff information. + + Args: + from_node_ids: List of node ID(s) completing execution. + - Swarm: Single-element list ["agent_a"] + - Graph: Multi-element list ["node1", "node2"] + to_node_ids: List of node ID(s) beginning execution. + - Swarm: Single-element list ["agent_b"] + - Graph: Multi-element list ["node3", "node4"] + message: Optional message explaining the transition (typically used in Swarm) + + Examples: + Swarm handoff: MultiAgentHandoffEvent(["researcher"], ["analyst"], "Need calculations") + Graph batch: MultiAgentHandoffEvent(["node1", "node2"], ["node3", "node4"]) + """ + event_data = { + "type": "multiagent_handoff", + "from_node_ids": from_node_ids, + "to_node_ids": to_node_ids, + } + + if message is not None: + event_data["message"] = message + + super().__init__(event_data) + + +class MultiAgentNodeStreamEvent(TypedEvent): + """Event emitted during node execution - forwards agent events with node context.""" + + def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None: + """Initialize with node context and agent event. + + Args: + node_id: Unique identifier for the node generating the event + agent_event: The original agent event data + """ + super().__init__( + { + "type": "multiagent_node_stream", + "node_id": node_id, + "event": agent_event, # Nest agent event to avoid field conflicts + } + ) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 5b4d77e75..4fef595f8 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -311,7 +311,7 @@ async def test_stream_e2e_success(alist): message={"content": [{"text": "I invoked the tools!"}], "role": "assistant"}, metrics=ANY, state={}, - ) + ), }, ] assert tru_events == exp_events @@ -453,7 +453,7 @@ async def test_stream_e2e_reasoning_redacted_content(alist): }, metrics=ANY, state={}, - ) + ), }, ] assert tru_events == exp_events diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index ae2d8c7b5..43afcf299 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -750,7 +750,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): }, metrics=unittest.mock.ANY, state={}, - ) + ), ), ] diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index c4c1a664f..24293ad78 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -40,7 +40,13 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen async def mock_invoke_async(*args, **kwargs): return mock_result + async def mock_stream_async(*args, **kwargs): + # Simple mock stream that yields a start event and then the result + yield {"agent_start": True} + yield {"result": mock_result} + agent.invoke_async = MagicMock(side_effect=mock_invoke_async) + agent.stream_async = Mock(side_effect=mock_stream_async) return agent @@ -66,7 +72,14 @@ def create_mock_multi_agent(name, response_text="Multi-agent response"): execution_count=1, execution_time=150, ) + + async def mock_multi_stream_async(*args, **kwargs): + # Simple mock stream that yields a start event and then the result + yield {"multi_agent_start": True} + yield {"result": mock_result} + multi_agent.invoke_async = AsyncMock(return_value=mock_result) + multi_agent.stream_async = Mock(side_effect=mock_multi_stream_async) multi_agent.execute = Mock(return_value=mock_result) return multi_agent @@ -201,15 +214,15 @@ async def test_graph_execution(mock_strands_tracer, mock_use_span, mock_graph, m assert len(result.execution_order) == 7 assert result.execution_order[0].node_id == "start_agent" - # Verify agent calls - mock_agents["start_agent"].invoke_async.assert_called_once() - mock_agents["multi_agent"].invoke_async.assert_called_once() - mock_agents["conditional_agent"].invoke_async.assert_called_once() - mock_agents["final_agent"].invoke_async.assert_called_once() - mock_agents["no_metrics_agent"].invoke_async.assert_called_once() - mock_agents["partial_metrics_agent"].invoke_async.assert_called_once() - string_content_agent.invoke_async.assert_called_once() - mock_agents["blocked_agent"].invoke_async.assert_not_called() + # Verify agent calls (now using stream_async internally) + assert mock_agents["start_agent"].stream_async.call_count == 1 + assert mock_agents["multi_agent"].stream_async.call_count == 1 + assert mock_agents["conditional_agent"].stream_async.call_count == 1 + assert mock_agents["final_agent"].stream_async.call_count == 1 + assert mock_agents["no_metrics_agent"].stream_async.call_count == 1 + assert mock_agents["partial_metrics_agent"].stream_async.call_count == 1 + assert string_content_agent.stream_async.call_count == 1 + assert mock_agents["blocked_agent"].stream_async.call_count == 0 # Verify metrics aggregation assert result.accumulated_usage["totalTokens"] > 0 @@ -277,7 +290,13 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span) async def mock_invoke_failure(*args, **kwargs): raise Exception("Simulated failure") + async def mock_stream_failure(*args, **kwargs): + # Simple mock stream that fails + yield {"agent_start": True} + raise Exception("Simulated failure") + failing_agent.invoke_async = mock_invoke_failure + failing_agent.stream_async = Mock(side_effect=mock_stream_failure) success_agent = create_mock_agent("success_agent", "Success") @@ -289,7 +308,7 @@ async def mock_invoke_failure(*args, **kwargs): graph = builder.build() - # Execute the graph - should raise Exception due to failing agent + # Execute the graph - should raise exception (fail-fast behavior) with pytest.raises(Exception, match="Simulated failure"): await graph.invoke_async("Test error handling") @@ -309,8 +328,8 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span): result = await graph.invoke_async([{"text": "Original task"}]) - # Verify entry node was called with original task - entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}], invocation_state={}) + # Verify entry node was called with original task (via stream_async) + assert entry_agent.stream_async.call_count == 1 assert result.status == Status.COMPLETED mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called_once() @@ -384,10 +403,10 @@ def spy_reset(self): execution_ids = [node.node_id for node in result.execution_order] assert execution_ids == ["a", "b", "c", "a"] - # Verify that each agent was called the expected number of times - assert agent_a.invoke_async.call_count == 2 # A executes twice - assert agent_b.invoke_async.call_count == 1 # B executes once - assert agent_c.invoke_async.call_count == 1 # C executes once + # Verify that each agent was called the expected number of times (via stream_async) + assert agent_a.stream_async.call_count == 2 # A executes twice + assert agent_b.stream_async.call_count == 1 # B executes once + assert agent_c.stream_async.call_count == 1 # C executes once # Verify that node state was reset for the revisited node (A) assert reset_spy.call_args_list == [call("a")] # Only A should be reset (when revisited) @@ -623,7 +642,13 @@ async def timeout_invoke(*args, **kwargs): await asyncio.sleep(0.2) # Longer than node timeout return timeout_agent.return_value + async def timeout_stream(*args, **kwargs): + yield {"agent_start": True} + await asyncio.sleep(0.2) # Longer than node timeout + yield {"result": timeout_agent.return_value} + timeout_agent.invoke_async = AsyncMock(side_effect=timeout_invoke) + timeout_agent.stream_async = Mock(side_effect=timeout_stream) builder = GraphBuilder() builder.add_node(timeout_agent, "timeout_node") @@ -634,13 +659,13 @@ async def timeout_invoke(*args, **kwargs): assert result.status == Status.COMPLETED assert result.completed_nodes == 1 - # Test with very short node timeout - should raise timeout exception + # Test with very short node timeout - should raise timeout exception (fail-fast behavior) builder = GraphBuilder() builder.add_node(timeout_agent, "timeout_node") graph = builder.set_max_node_executions(50).set_execution_timeout(900.0).set_node_timeout(0.1).build() - # Execute the graph - should raise Exception due to timeout - with pytest.raises(Exception, match="Node 'timeout_node' execution timed out after 0.1s"): + # Execute the graph - should raise timeout exception (fail-fast behavior) + with pytest.raises(Exception, match="execution timed out"): await graph.invoke_async("Test node timeout") mock_strands_tracer.start_multiagent_span.assert_called() @@ -841,9 +866,9 @@ def test_graph_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag assert result.execution_order[0].node_id == "start_agent" assert result.execution_order[1].node_id == "final_agent" - # Verify agent calls - mock_agents["start_agent"].invoke_async.assert_called_once() - mock_agents["final_agent"].invoke_async.assert_called_once() + # Verify agent calls (via stream_async) + assert mock_agents["start_agent"].stream_async.call_count == 1 + assert mock_agents["final_agent"].stream_async.call_count == 1 # Verify return type is GraphResult assert isinstance(result, GraphResult) @@ -921,6 +946,12 @@ async def invoke_async(self, input_data, invocation_state=None): ), ) + async def stream_async(self, input_data, **kwargs): + # Stream implementation that yields events and final result + yield {"agent_start": True} + result = await self.invoke_async(input_data) + yield {"result": result} + # Create agents agent_a = StatefulAgent("agent_a") agent_b = StatefulAgent("agent_b") @@ -1041,9 +1072,9 @@ async def test_linear_graph_behavior(): assert result.execution_order[0].node_id == "a" assert result.execution_order[1].node_id == "b" - # Verify agents were called once each (no state reset) - agent_a.invoke_async.assert_called_once() - agent_b.invoke_async.assert_called_once() + # Verify agents were called once each (no state reset, via stream_async) + assert agent_a.stream_async.call_count == 1 + assert agent_b.stream_async.call_count == 1 @pytest.mark.asyncio @@ -1115,9 +1146,9 @@ def loop_condition(state: GraphState) -> bool: graph = builder.build() result = await graph.invoke_async("Test self loop") - # Verify basic self-loop functionality + # Verify basic self-loop functionality (via stream_async) assert result.status == Status.COMPLETED - assert self_loop_agent.invoke_async.call_count == 3 + assert self_loop_agent.stream_async.call_count == 3 assert len(result.execution_order) == 3 assert all(node.node_id == "self_loop" for node in result.execution_order) @@ -1177,9 +1208,9 @@ def end_condition(state: GraphState) -> bool: assert result.status == Status.COMPLETED assert len(result.execution_order) == 4 # start -> loop -> loop -> end assert [node.node_id for node in result.execution_order] == ["start_node", "loop_node", "loop_node", "end_node"] - assert start_agent.invoke_async.call_count == 1 - assert loop_agent.invoke_async.call_count == 2 - assert end_agent.invoke_async.call_count == 1 + assert start_agent.stream_async.call_count == 1 + assert loop_agent.stream_async.call_count == 2 + assert end_agent.stream_async.call_count == 1 @pytest.mark.asyncio @@ -1208,8 +1239,8 @@ def condition_b(state: GraphState) -> bool: assert result.status == Status.COMPLETED assert len(result.execution_order) == 4 # a -> a -> b -> b - assert agent_a.invoke_async.call_count == 2 - assert agent_b.invoke_async.call_count == 2 + assert agent_a.stream_async.call_count == 2 + assert agent_b.stream_async.call_count == 2 mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called() @@ -1284,7 +1315,7 @@ def multi_loop_condition(state: GraphState) -> bool: assert result.status == Status.COMPLETED assert len(result.execution_order) >= 2 - assert multi_agent.invoke_async.call_count >= 2 + assert multi_agent.stream_async.call_count >= 2 @pytest.mark.asyncio @@ -1300,9 +1331,8 @@ async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span): test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = await graph.invoke_async("Test kwargs passing", test_invocation_state) - kwargs_agent.invoke_async.assert_called_once_with( - [{"text": "Test kwargs passing"}], invocation_state=test_invocation_state - ) + # Verify stream_async was called (kwargs are passed through) + assert kwargs_agent.stream_async.call_count == 1 assert result.status == Status.COMPLETED @@ -1319,9 +1349,8 @@ async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_spa test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = await graph.invoke_async("Test kwargs passing to multiagent", test_invocation_state) - kwargs_multiagent.invoke_async.assert_called_once_with( - [{"text": "Test kwargs passing to multiagent"}], test_invocation_state - ) + # Verify stream_async was called (kwargs are passed through) + assert kwargs_multiagent.stream_async.call_count == 1 assert result.status == Status.COMPLETED @@ -1337,7 +1366,555 @@ def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = graph("Test kwargs passing sync", test_invocation_state) - kwargs_agent.invoke_async.assert_called_once_with( - [{"text": "Test kwargs passing sync"}], invocation_state=test_invocation_state - ) + # Verify stream_async was called (kwargs are passed through) + assert kwargs_agent.stream_async.call_count == 1 + assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_graph_streaming_events(mock_strands_tracer, mock_use_span, alist): + """Test that graph streaming emits proper events during execution.""" + # Create agents with custom streaming behavior + agent_a = create_mock_agent("agent_a", "Response A") + agent_b = create_mock_agent("agent_b", "Response B") + + # Track events from agent streams + agent_a_events = [ + {"agent_thinking": True, "thought": "Processing task A"}, + {"agent_progress": True, "step": "analyzing"}, + {"result": agent_a.return_value}, + ] + + agent_b_events = [ + {"agent_thinking": True, "thought": "Processing task B"}, + {"agent_progress": True, "step": "computing"}, + {"result": agent_b.return_value}, + ] + + async def stream_a(*args, **kwargs): + for event in agent_a_events: + yield event + + async def stream_b(*args, **kwargs): + for event in agent_b_events: + yield event + + agent_a.stream_async = Mock(side_effect=stream_a) + agent_b.stream_async = Mock(side_effect=stream_b) + + # Build graph: A -> B + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_edge("a", "b") + builder.set_entry_point("a") + graph = builder.build() + + # Collect all streaming events + events = await alist(graph.stream_async("Test streaming")) + + # Verify event structure and order + assert len(events) > 0 + + # Should have node start/stop events and forwarded agent events + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + node_stop_events = [e for e in events if e.get("type") == "multiagent_node_stop"] + node_stream_events = [e for e in events if e.get("type") == "multiagent_node_stream"] + result_events = [e for e in events if "result" in e and e.get("type") != "multiagent_node_stream"] + + # Should have start/stop events for both nodes + assert len(node_start_events) == 2 + assert len(node_stop_events) == 2 + + # Should have forwarded agent events + assert len(node_stream_events) >= 4 # At least 2 events per agent + + # Should have final result + assert len(result_events) == 1 + + # Verify node start events have correct structure + for event in node_start_events: + assert "node_id" in event + assert "node_type" in event + assert event["node_type"] == "agent" + + # Verify node stop events have node_result with execution time + for event in node_stop_events: + assert "node_id" in event + assert "node_result" in event + node_result = event["node_result"] + assert hasattr(node_result, "execution_time") + assert isinstance(node_result.execution_time, int) + + # Verify forwarded events maintain node context + for event in node_stream_events: + assert "node_id" in event + assert event["node_id"] in ["a", "b"] + + # Verify final result + final_result = result_events[0]["result"] + assert final_result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_graph_streaming_parallel_events(mock_strands_tracer, mock_use_span, alist): + """Test that parallel graph execution properly streams events from concurrent nodes.""" + # Create agents that execute in parallel + agent_a = create_mock_agent("agent_a", "Response A") + agent_b = create_mock_agent("agent_b", "Response B") + agent_c = create_mock_agent("agent_c", "Response C") + + # Track timing and events + execution_order = [] + + async def stream_with_timing(node_id, delay=0.05): + execution_order.append(f"{node_id}_start") + yield {"node_start": True, "node": node_id} + await asyncio.sleep(delay) + yield {"node_progress": True, "node": node_id} + execution_order.append(f"{node_id}_end") + yield {"result": create_mock_agent(node_id, f"Response {node_id}").return_value} + + agent_a.stream_async = Mock(side_effect=lambda *args, **kwargs: stream_with_timing("A", 0.05)) + agent_b.stream_async = Mock(side_effect=lambda *args, **kwargs: stream_with_timing("B", 0.05)) + agent_c.stream_async = Mock(side_effect=lambda *args, **kwargs: stream_with_timing("C", 0.05)) + + # Build graph with parallel nodes + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_node(agent_c, "c") + # All are entry points (parallel execution) + builder.set_entry_point("a") + builder.set_entry_point("b") + builder.set_entry_point("c") + graph = builder.build() + + # Collect streaming events + start_time = time.time() + events = await alist(graph.stream_async("Test parallel streaming")) + total_time = time.time() - start_time + + # Verify parallel execution timing + assert total_time < 0.2, f"Expected parallel execution, took {total_time}s" + + # Verify we get events from all nodes + node_stream_events = [e for e in events if e.get("type") == "multiagent_node_stream"] + nodes_with_events = set(e["node_id"] for e in node_stream_events) + assert nodes_with_events == {"a", "b", "c"} + + # Verify start events for all nodes + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + start_node_ids = set(e["node_id"] for e in node_start_events) + assert start_node_ids == {"a", "b", "c"} + + +@pytest.mark.asyncio +async def test_graph_streaming_with_failures(mock_strands_tracer, mock_use_span): + """Test graph streaming behavior when nodes fail.""" + # Create a failing agent + failing_agent = Mock(spec=Agent) + failing_agent.name = "failing_agent" + failing_agent.id = "fail_node" + failing_agent._session_manager = None + failing_agent.hooks = HookRegistry() + + async def failing_stream(*args, **kwargs): + yield {"agent_start": True} + yield {"agent_thinking": True, "thought": "About to fail"} + await asyncio.sleep(0.01) + raise Exception("Simulated streaming failure") + + async def failing_invoke(*args, **kwargs): + raise Exception("Simulated failure") + + failing_agent.stream_async = Mock(side_effect=failing_stream) + failing_agent.invoke_async = failing_invoke + + # Create successful agent + success_agent = create_mock_agent("success_agent", "Success") + + # Build graph + builder = GraphBuilder() + builder.add_node(failing_agent, "fail") + builder.add_node(success_agent, "success") + builder.set_entry_point("fail") + builder.set_entry_point("success") + graph = builder.build() + + # Collect events - graph should raise exception (fail-fast behavior) + events = [] + with pytest.raises(Exception, match="Simulated streaming failure"): + async for event in graph.stream_async("Test streaming with failure"): + events.append(event) + + # Should get some events before failure + assert len(events) > 0 + + # Should have node start events + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + assert len(node_start_events) >= 1 + + # Should have some forwarded events before failure + node_stream_events = [e for e in events if e.get("type") == "multiagent_node_stream"] + assert len(node_stream_events) >= 1 + + +@pytest.mark.asyncio +async def test_graph_parallel_execution(mock_strands_tracer, mock_use_span): + """Test that nodes without dependencies execute in parallel.""" + + # Create agents that track execution timing + execution_times = {} + + async def create_timed_agent(name, delay=0.1): + agent = create_mock_agent(name, f"{name} response") + + async def timed_invoke(*args, **kwargs): + start_time = time.time() + execution_times[name] = {"start": start_time} + await asyncio.sleep(delay) # Simulate work + end_time = time.time() + execution_times[name]["end"] = end_time + return agent.return_value + + async def timed_stream(*args, **kwargs): + # Simulate streaming by yielding some events then the final result + start_time = time.time() + execution_times[name] = {"start": start_time} + + # Yield a start event + yield {"agent_start": True, "node": name} + + await asyncio.sleep(delay) # Simulate work + + end_time = time.time() + execution_times[name]["end"] = end_time + + # Yield final result event + yield {"result": agent.return_value} + + agent.invoke_async = AsyncMock(side_effect=timed_invoke) + # Create a mock that returns the async generator directly + agent.stream_async = Mock(side_effect=timed_stream) + return agent + + # Create agents that should execute in parallel + agent_a = await create_timed_agent("agent_a", 0.1) + agent_b = await create_timed_agent("agent_b", 0.1) + agent_c = await create_timed_agent("agent_c", 0.1) + + # Create a dependent agent that should execute after the parallel ones + agent_d = await create_timed_agent("agent_d", 0.05) + + # Build graph: A, B, C execute in parallel, then D depends on all of them + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_node(agent_c, "c") + builder.add_node(agent_d, "d") + + # D depends on A, B, and C + builder.add_edge("a", "d") + builder.add_edge("b", "d") + builder.add_edge("c", "d") + + # A, B, C are entry points (no dependencies) + builder.set_entry_point("a") + builder.set_entry_point("b") + builder.set_entry_point("c") + + graph = builder.build() + + # Execute the graph + start_time = time.time() + result = await graph.invoke_async("Test parallel execution") + total_time = time.time() - start_time + + # Verify successful execution assert result.status == Status.COMPLETED + assert result.completed_nodes == 4 + assert len(result.execution_order) == 4 + + # Verify all agents were called (via stream_async) + assert agent_a.stream_async.call_count == 1 + assert agent_b.stream_async.call_count == 1 + assert agent_c.stream_async.call_count == 1 + assert agent_d.stream_async.call_count == 1 + + # Verify parallel execution: A, B, C should have overlapping execution times + # If they were sequential, total time would be ~0.35s (3 * 0.1 + 0.05) + # If parallel, total time should be ~0.15s (max(0.1, 0.1, 0.1) + 0.05) + assert total_time < 0.4, f"Expected parallel execution to be faster, took {total_time}s" + + # Verify timing overlap for parallel nodes + a_start = execution_times["agent_a"]["start"] + b_start = execution_times["agent_b"]["start"] + c_start = execution_times["agent_c"]["start"] + + # All parallel nodes should start within a small time window + max_start_diff = max(a_start, b_start, c_start) - min(a_start, b_start, c_start) + assert max_start_diff < 0.1, f"Parallel nodes should start nearly simultaneously, diff: {max_start_diff}s" + + # D should start after A, B, C have finished + d_start = execution_times["agent_d"]["start"] + a_end = execution_times["agent_a"]["end"] + b_end = execution_times["agent_b"]["end"] + c_end = execution_times["agent_c"]["end"] + + latest_parallel_end = max(a_end, b_end, c_end) + assert d_start >= latest_parallel_end - 0.02, "Dependent node should start after parallel nodes complete" + + +@pytest.mark.asyncio +async def test_graph_single_node_optimization(mock_strands_tracer, mock_use_span): + """Test that single node execution uses direct path (optimization).""" + agent = create_mock_agent("single_agent", "Single response") + + builder = GraphBuilder() + builder.add_node(agent, "single") + graph = builder.build() + + result = await graph.invoke_async("Test single node") + + assert result.status == Status.COMPLETED + assert result.completed_nodes == 1 + assert agent.stream_async.call_count == 1 + + +@pytest.mark.asyncio +async def test_graph_parallel_with_failures(mock_strands_tracer, mock_use_span): + """Test parallel execution with some nodes failing.""" + # Create a failing agent + failing_agent = Mock(spec=Agent) + failing_agent.name = "failing_agent" + failing_agent.id = "fail_node" + failing_agent._session_manager = None + failing_agent.hooks = HookRegistry() + + async def mock_invoke_failure(*args, **kwargs): + await asyncio.sleep(0.05) # Small delay + raise Exception("Simulated failure") + + async def mock_stream_failure_parallel(*args, **kwargs): + # Simple mock stream that fails + yield {"agent_start": True} + await asyncio.sleep(0.05) # Small delay + raise Exception("Simulated failure") + + failing_agent.invoke_async = mock_invoke_failure + failing_agent.stream_async = Mock(side_effect=mock_stream_failure_parallel) + + # Create successful agents that take longer than the failing agent + success_agent_a = create_mock_agent("success_a", "Success A") + success_agent_b = create_mock_agent("success_b", "Success B") + + # Override their stream methods to take longer + async def slow_stream_a(*args, **kwargs): + yield {"agent_start": True, "node": "success_a"} + await asyncio.sleep(0.1) # Longer than failing agent + yield {"result": success_agent_a.return_value} + + async def slow_stream_b(*args, **kwargs): + yield {"agent_start": True, "node": "success_b"} + await asyncio.sleep(0.1) # Longer than failing agent + yield {"result": success_agent_b.return_value} + + success_agent_a.stream_async = Mock(side_effect=slow_stream_a) + success_agent_b.stream_async = Mock(side_effect=slow_stream_b) + + # Build graph with parallel execution where one fails + builder = GraphBuilder() + builder.add_node(failing_agent, "fail") + builder.add_node(success_agent_a, "success_a") + builder.add_node(success_agent_b, "success_b") + + # All are entry points (parallel) + builder.set_entry_point("fail") + builder.set_entry_point("success_a") + builder.set_entry_point("success_b") + + graph = builder.build() + + # Execute should raise exception (fail-fast behavior) + with pytest.raises(Exception, match="Simulated failure"): + await graph.invoke_async("Test parallel with failure") + + +@pytest.mark.asyncio +async def test_graph_single_invocation_no_double_execution(mock_strands_tracer, mock_use_span): + """Test that nodes are only invoked once (no double execution from streaming).""" + # Create agents with invocation counters + agent_a = create_mock_agent("agent_a", "Response A") + agent_b = create_mock_agent("agent_b", "Response B") + + # Track invocation counts + invocation_counts = {"agent_a": 0, "agent_b": 0} + + async def counted_stream_a(*args, **kwargs): + invocation_counts["agent_a"] += 1 + yield {"agent_start": True} + yield {"agent_thinking": True, "thought": "Processing A"} + yield {"result": agent_a.return_value} + + async def counted_stream_b(*args, **kwargs): + invocation_counts["agent_b"] += 1 + yield {"agent_start": True} + yield {"agent_thinking": True, "thought": "Processing B"} + yield {"result": agent_b.return_value} + + agent_a.stream_async = Mock(side_effect=counted_stream_a) + agent_b.stream_async = Mock(side_effect=counted_stream_b) + + # Build graph: A -> B + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_edge("a", "b") + builder.set_entry_point("a") + graph = builder.build() + + # Execute the graph + result = await graph.invoke_async("Test single invocation") + + # Verify successful execution + assert result.status == Status.COMPLETED + + # CRITICAL: Each agent should be invoked exactly once + assert invocation_counts["agent_a"] == 1, f"Agent A invoked {invocation_counts['agent_a']} times, expected 1" + assert invocation_counts["agent_b"] == 1, f"Agent B invoked {invocation_counts['agent_b']} times, expected 1" + + # Verify stream_async was called but invoke_async was NOT called + assert agent_a.stream_async.call_count == 1 + assert agent_b.stream_async.call_count == 1 + # invoke_async should not be called at all since we're using streaming + agent_a.invoke_async.assert_not_called() + agent_b.invoke_async.assert_not_called() + + +@pytest.mark.asyncio +async def test_graph_parallel_single_invocation(mock_strands_tracer, mock_use_span): + """Test that parallel nodes are only invoked once each.""" + # Create parallel agents with invocation counters + invocation_counts = {"a": 0, "b": 0, "c": 0} + + async def create_counted_agent(name): + agent = create_mock_agent(name, f"Response {name}") + + async def counted_stream(*args, **kwargs): + invocation_counts[name] += 1 + yield {"agent_start": True, "node": name} + await asyncio.sleep(0.01) # Small delay + yield {"result": agent.return_value} + + agent.stream_async = Mock(side_effect=counted_stream) + return agent + + agent_a = await create_counted_agent("a") + agent_b = await create_counted_agent("b") + agent_c = await create_counted_agent("c") + + # Build graph with parallel nodes + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_node(agent_c, "c") + builder.set_entry_point("a") + builder.set_entry_point("b") + builder.set_entry_point("c") + graph = builder.build() + + # Execute the graph + result = await graph.invoke_async("Test parallel single invocation") + + # Verify successful execution + assert result.status == Status.COMPLETED + + # CRITICAL: Each agent should be invoked exactly once + assert invocation_counts["a"] == 1, f"Agent A invoked {invocation_counts['a']} times, expected 1" + assert invocation_counts["b"] == 1, f"Agent B invoked {invocation_counts['b']} times, expected 1" + assert invocation_counts["c"] == 1, f"Agent C invoked {invocation_counts['c']} times, expected 1" + + # Verify stream_async was called but invoke_async was NOT called + assert agent_a.stream_async.call_count == 1 + assert agent_b.stream_async.call_count == 1 + assert agent_c.stream_async.call_count == 1 + agent_a.invoke_async.assert_not_called() + agent_b.invoke_async.assert_not_called() + agent_c.invoke_async.assert_not_called() + + +@pytest.mark.asyncio +async def test_graph_node_timeout_with_mocked_streaming(): + """Test that node timeout properly cancels a streaming generator that freezes.""" + # Create an agent that will timeout during streaming + slow_agent = Agent( + name="slow_agent", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a slow agent. Take your time responding.", + ) + + # Override stream_async to simulate a freezing generator + original_stream = slow_agent.stream_async + + async def freezing_stream(*args, **kwargs): + """Simulate a generator that yields some events then freezes.""" + # Yield a few events normally + count = 0 + async for event in original_stream(*args, **kwargs): + yield event + count += 1 + if count >= 3: + # Simulate freezing - sleep longer than timeout + await asyncio.sleep(10.0) + break + + slow_agent.stream_async = freezing_stream + + # Create graph with short node timeout + builder = GraphBuilder() + builder.add_node(slow_agent, "slow_node") + builder.set_node_timeout(0.5) # 500ms timeout + graph = builder.build() + + # Execute - should timeout and raise exception (fail-fast behavior) + with pytest.raises(Exception, match="execution timed out"): + await graph.invoke_async("Test freezing generator") + + +@pytest.mark.asyncio +async def test_graph_timeout_cleanup_on_exception(): + """Test that timeout properly cleans up tasks even when exceptions occur.""" + # Create an agent + agent = Agent( + name="test_agent", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a test agent.", + ) + + # Override stream_async to raise an exception after some events + original_stream = agent.stream_async + + async def exception_stream(*args, **kwargs): + """Simulate a generator that raises an exception.""" + count = 0 + async for event in original_stream(*args, **kwargs): + yield event + count += 1 + if count >= 2: + raise ValueError("Simulated error during streaming") + + agent.stream_async = exception_stream + + # Create graph with timeout + builder = GraphBuilder() + builder.add_node(agent, "test_node") + builder.set_node_timeout(30.0) + graph = builder.build() + + # Execute - the exception propagates through _stream_with_timeout + with pytest.raises(ValueError, match="Simulated error during streaming"): + await graph.invoke_async("Test exception handling") + + # Verify execution_time is set even on failure (via finally block) + assert graph.state.execution_time > 0, "execution_time should be set even when exception occurs" diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 0968fd30c..8f049ba0c 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1,3 +1,4 @@ +import asyncio import time from unittest.mock import MagicMock, Mock, patch @@ -53,7 +54,14 @@ def create_mock_result(): async def mock_invoke_async(*args, **kwargs): return create_mock_result() + async def mock_stream_async(*args, **kwargs): + # Simple mock stream that yields a start event and then the result + yield {"agent_start": True, "node": name} + yield {"agent_thinking": True, "thought": f"Processing with {name}"} + yield {"result": create_mock_result()} + agent.invoke_async = MagicMock(side_effect=mock_invoke_async) + agent.stream_async = Mock(side_effect=mock_stream_async) return agent @@ -231,8 +239,8 @@ async def test_swarm_execution_async(mock_strands_tracer, mock_use_span, mock_sw assert result.execution_count == 1 assert len(result.results) == 1 - # Verify agent was called - mock_agents["coordinator"].invoke_async.assert_called() + # Verify agent was called (via stream_async) + assert mock_agents["coordinator"].stream_async.call_count >= 1 # Verify metrics aggregation assert result.accumulated_usage["totalTokens"] >= 0 @@ -267,8 +275,8 @@ def test_swarm_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag assert len(result.results) == 1 assert result.execution_time >= 0 - # Verify agent was called - mock_agents["coordinator"].invoke_async.assert_called() + # Verify agent was called (via stream_async) + assert mock_agents["coordinator"].stream_async.call_count >= 1 # Verify return type is SwarmResult assert isinstance(result, SwarmResult) @@ -358,7 +366,13 @@ def create_handoff_result(): async def mock_invoke_async(*args, **kwargs): return create_handoff_result() + async def mock_stream_async(*args, **kwargs): + yield {"agent_start": True} + result = create_handoff_result() + yield {"result": result} + agent.invoke_async = MagicMock(side_effect=mock_invoke_async) + agent.stream_async = Mock(side_effect=mock_stream_async) return agent # Create agents - first one hands off, second one completes by not handing off @@ -384,9 +398,9 @@ async def mock_invoke_async(*args, **kwargs): # Verify the completion agent executed after handoff assert result.node_history[1].node_id == "completion_agent" - # Verify both agents were called - handoff_agent.invoke_async.assert_called() - completion_agent.invoke_async.assert_called() + # Verify both agents were called (via stream_async) + assert handoff_agent.stream_async.call_count >= 1 + assert completion_agent.stream_async.call_count >= 1 # Test handoff when task is already completed completed_swarm = Swarm(nodes=[handoff_agent, completion_agent]) @@ -447,8 +461,8 @@ def test_swarm_auto_completion_without_handoff(): assert len(result.node_history) == 1 assert result.node_history[0].node_id == "no_handoff_agent" - # Verify the agent was called - no_handoff_agent.invoke_async.assert_called() + # Verify the agent was called (via stream_async) + assert no_handoff_agent.stream_async.call_count >= 1 def test_swarm_configurable_entry_point(): @@ -551,26 +565,485 @@ def test_swarm_validate_unsupported_features(): async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span): """Test that kwargs are passed through to underlying agents.""" kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") - kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) swarm = Swarm(nodes=[kwargs_agent]) test_kwargs = {"custom_param": "test_value", "another_param": 42} result = await swarm.invoke_async("Test kwargs passing", test_kwargs) - assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs} + # Verify stream_async was called (kwargs are passed through) + assert kwargs_agent.stream_async.call_count >= 1 assert result.status == Status.COMPLETED def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): """Test that kwargs are passed through to underlying agents in sync execution.""" kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") - kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) swarm = Swarm(nodes=[kwargs_agent]) test_kwargs = {"custom_param": "test_value", "another_param": 42} result = swarm("Test kwargs passing sync", test_kwargs) - assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs} + # Verify stream_async was called (kwargs are passed through) + assert kwargs_agent.stream_async.call_count >= 1 assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_swarm_streaming_events(mock_strands_tracer, mock_use_span, alist): + """Test that swarm streaming emits proper events during execution.""" + + # Create agents with custom streaming behavior + coordinator = create_mock_agent("coordinator", "Coordinating task") + specialist = create_mock_agent("specialist", "Specialized response") + + # Track events and execution order + execution_events = [] + + async def coordinator_stream(*args, **kwargs): + execution_events.append("coordinator_start") + yield {"agent_start": True, "node": "coordinator"} + yield {"agent_thinking": True, "thought": "Analyzing task"} + await asyncio.sleep(0.01) # Small delay + execution_events.append("coordinator_end") + yield {"result": coordinator.return_value} + + async def specialist_stream(*args, **kwargs): + execution_events.append("specialist_start") + yield {"agent_start": True, "node": "specialist"} + yield {"agent_thinking": True, "thought": "Applying expertise"} + await asyncio.sleep(0.01) # Small delay + execution_events.append("specialist_end") + yield {"result": specialist.return_value} + + coordinator.stream_async = Mock(side_effect=coordinator_stream) + specialist.stream_async = Mock(side_effect=specialist_stream) + + # Create swarm with handoff logic + swarm = Swarm(nodes=[coordinator, specialist], max_handoffs=2, max_iterations=3, execution_timeout=30.0) + + # Add handoff tool to coordinator to trigger specialist + def handoff_to_specialist(): + """Hand off to specialist for detailed analysis.""" + return specialist + + coordinator.tool_registry.registry = {"handoff_to_specialist": handoff_to_specialist} + + # Collect all streaming events + events = await alist(swarm.stream_async("Test swarm streaming")) + + # Verify event structure + assert len(events) > 0 + + # Should have node start/stop events + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + node_stop_events = [e for e in events if e.get("type") == "multiagent_node_stop"] + node_stream_events = [e for e in events if e.get("type") == "multiagent_node_stream"] + result_events = [e for e in events if "result" in e and e.get("type") != "multiagent_node_stream"] + + # Should have at least one node execution + assert len(node_start_events) >= 1 + assert len(node_stop_events) >= 1 + + # Should have forwarded agent events + assert len(node_stream_events) >= 2 # At least some events per agent + + # Should have final result + assert len(result_events) == 1 + + # Verify node start events have correct structure + for event in node_start_events: + assert "node_id" in event + assert "node_type" in event + assert event["node_type"] == "agent" + + # Verify node stop events have node_result with execution time + for event in node_stop_events: + assert "node_id" in event + assert "node_result" in event + node_result = event["node_result"] + assert hasattr(node_result, "execution_time") + assert isinstance(node_result.execution_time, int) + + # Verify forwarded events maintain node context + for event in node_stream_events: + assert "node_id" in event + + # Verify final result + final_result = result_events[0]["result"] + assert final_result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_swarm_streaming_with_handoffs(mock_strands_tracer, mock_use_span, alist): + """Test swarm streaming with agent handoffs.""" + + # Create agents + coordinator = create_mock_agent("coordinator", "Coordinating") + specialist = create_mock_agent("specialist", "Specialized work") + reviewer = create_mock_agent("reviewer", "Review complete") + + # Track handoff sequence + handoff_sequence = [] + + async def coordinator_stream(*args, **kwargs): + yield {"agent_start": True, "node": "coordinator"} + yield {"agent_thinking": True, "thought": "Need specialist help"} + handoff_sequence.append("coordinator_to_specialist") + yield {"result": coordinator.return_value} + + async def specialist_stream(*args, **kwargs): + yield {"agent_start": True, "node": "specialist"} + yield {"agent_thinking": True, "thought": "Doing specialized work"} + handoff_sequence.append("specialist_to_reviewer") + yield {"result": specialist.return_value} + + async def reviewer_stream(*args, **kwargs): + yield {"agent_start": True, "node": "reviewer"} + yield {"agent_thinking": True, "thought": "Reviewing work"} + handoff_sequence.append("reviewer_complete") + yield {"result": reviewer.return_value} + + coordinator.stream_async = Mock(side_effect=coordinator_stream) + specialist.stream_async = Mock(side_effect=specialist_stream) + reviewer.stream_async = Mock(side_effect=reviewer_stream) + + # Set up handoff tools + def handoff_to_specialist(): + return specialist + + def handoff_to_reviewer(): + return reviewer + + coordinator.tool_registry.registry = {"handoff_to_specialist": handoff_to_specialist} + specialist.tool_registry.registry = {"handoff_to_reviewer": handoff_to_reviewer} + reviewer.tool_registry.registry = {} + + # Create swarm + swarm = Swarm(nodes=[coordinator, specialist, reviewer], max_handoffs=5, max_iterations=5, execution_timeout=30.0) + + # Collect streaming events + events = await alist(swarm.stream_async("Test handoff streaming")) + + # Should have multiple node executions due to handoffs + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + handoff_events = [e for e in events if e.get("type") == "multiagent_handoff"] + + # Should have executed at least one agent (handoffs are complex to mock) + assert len(node_start_events) >= 1 + + # Verify handoff events have proper structure if any occurred + for event in handoff_events: + assert "from_node_ids" in event + assert "to_node_ids" in event + assert isinstance(event["from_node_ids"], list) + assert isinstance(event["to_node_ids"], list) + + +@pytest.mark.asyncio +async def test_swarm_streaming_with_failures(mock_strands_tracer, mock_use_span): + """Test swarm streaming behavior when agents fail.""" + + # Create a failing agent (don't fail during creation, fail during execution) + failing_agent = create_mock_agent("failing_agent", "Should fail") + success_agent = create_mock_agent("success_agent", "Success") + + async def failing_stream(*args, **kwargs): + yield {"agent_start": True, "node": "failing_agent"} + yield {"agent_thinking": True, "thought": "About to fail"} + await asyncio.sleep(0.01) + raise Exception("Simulated streaming failure") + + async def success_stream(*args, **kwargs): + yield {"agent_start": True, "node": "success_agent"} + yield {"agent_thinking": True, "thought": "Working successfully"} + yield {"result": success_agent.return_value} + + failing_agent.stream_async = Mock(side_effect=failing_stream) + success_agent.stream_async = Mock(side_effect=success_stream) + + # Create swarm starting with failing agent + swarm = Swarm(nodes=[failing_agent, success_agent], max_handoffs=2, max_iterations=3, execution_timeout=30.0) + + # Collect events until failure + events = [] + # Note: We expect an exception but swarm might handle it gracefully + # So we don't use pytest.raises here - we check for either success or failure + try: + async for event in swarm.stream_async("Test streaming with failure"): + events.append(event) + except Exception: + pass # Expected - failure during streaming + + # Should get some events before failure (if failure occurred) + if len(events) > 0: + # Should have node start events + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + assert len(node_start_events) >= 1 + + # Should have some forwarded events before failure + node_stream_events = [e for e in events if e.get("type") == "multiagent_node_stream"] + assert len(node_stream_events) >= 1 + + +@pytest.mark.asyncio +async def test_swarm_streaming_timeout_behavior(mock_strands_tracer, mock_use_span): + """Test swarm streaming with execution timeout.""" + + # Create a slow agent + slow_agent = create_mock_agent("slow_agent", "Slow response") + + async def slow_stream(*args, **kwargs): + yield {"agent_start": True, "node": "slow_agent"} + yield {"agent_thinking": True, "thought": "Taking my time"} + await asyncio.sleep(0.2) # Longer than timeout + yield {"result": slow_agent.return_value} + + slow_agent.stream_async = Mock(side_effect=slow_stream) + + # Create swarm with short timeout + swarm = Swarm( + nodes=[slow_agent], + max_handoffs=1, + max_iterations=1, + execution_timeout=0.1, # Very short timeout + ) + + # Should timeout during streaming or complete + # Note: Timeout behavior is timing-dependent, so we accept both outcomes + events = [] + try: + async for event in swarm.stream_async("Test timeout streaming"): + events.append(event) + except Exception: + pass # Timeout is acceptable + + # Should get at least some events regardless of timeout + assert len(events) >= 1 + + +@pytest.mark.asyncio +async def test_swarm_streaming_backward_compatibility(mock_strands_tracer, mock_use_span, alist): + """Test that swarm streaming maintains backward compatibility.""" + # Create simple agent + agent = create_mock_agent("test_agent", "Test response") + + # Create swarm + swarm = Swarm(nodes=[agent]) + + # Test that invoke_async still works + result = await swarm.invoke_async("Test backward compatibility") + assert result.status == Status.COMPLETED + + # Test that streaming also works and produces same result + events = await alist(swarm.stream_async("Test backward compatibility")) + + # Should have final result event + result_events = [e for e in events if "result" in e and e.get("type") != "multiagent_node_stream"] + assert len(result_events) == 1 + + streaming_result = result_events[0]["result"] + assert streaming_result.status == Status.COMPLETED + + # Results should be equivalent + assert result.status == streaming_result.status + + +@pytest.mark.asyncio +async def test_swarm_single_invocation_no_double_execution(mock_strands_tracer, mock_use_span): + """Test that swarm nodes are only invoked once (no double execution from streaming).""" + # Create agent with invocation counter + agent = create_mock_agent("test_agent", "Test response") + + # Track invocation count + invocation_count = {"count": 0} + + async def counted_stream(*args, **kwargs): + invocation_count["count"] += 1 + yield {"agent_start": True, "node": "test_agent"} + yield {"agent_thinking": True, "thought": "Processing"} + yield {"result": agent.return_value} + + agent.stream_async = Mock(side_effect=counted_stream) + + # Create swarm + swarm = Swarm(nodes=[agent]) + + # Execute the swarm + result = await swarm.invoke_async("Test single invocation") + + # Verify successful execution + assert result.status == Status.COMPLETED + + # CRITICAL: Agent should be invoked exactly once + assert invocation_count["count"] == 1, f"Agent invoked {invocation_count['count']} times, expected 1" + + # Verify stream_async was called but invoke_async was NOT called + assert agent.stream_async.call_count == 1 + # invoke_async should not be called at all since we're using streaming + agent.invoke_async.assert_not_called() + + +@pytest.mark.asyncio +async def test_swarm_handoff_single_invocation_per_node(mock_strands_tracer, mock_use_span): + """Test that each node in a swarm handoff chain is invoked exactly once.""" + # Create agents with invocation counters + invocation_counts = {"coordinator": 0, "specialist": 0} + + coordinator = create_mock_agent("coordinator", "Coordinating") + specialist = create_mock_agent("specialist", "Specialized work") + + async def coordinator_stream(*args, **kwargs): + invocation_counts["coordinator"] += 1 + yield {"agent_start": True, "node": "coordinator"} + yield {"agent_thinking": True, "thought": "Need specialist"} + yield {"result": coordinator.return_value} + + async def specialist_stream(*args, **kwargs): + invocation_counts["specialist"] += 1 + yield {"agent_start": True, "node": "specialist"} + yield {"agent_thinking": True, "thought": "Doing specialized work"} + yield {"result": specialist.return_value} + + coordinator.stream_async = Mock(side_effect=coordinator_stream) + specialist.stream_async = Mock(side_effect=specialist_stream) + + # Set up handoff tool + def handoff_to_specialist(): + return specialist + + coordinator.tool_registry.registry = {"handoff_to_specialist": handoff_to_specialist} + specialist.tool_registry.registry = {} + + # Create swarm + swarm = Swarm(nodes=[coordinator, specialist], max_handoffs=2, max_iterations=3) + + # Execute the swarm + result = await swarm.invoke_async("Test handoff single invocation") + + # Verify successful execution + assert result.status == Status.COMPLETED + + # CRITICAL: Each agent should be invoked exactly once + # Note: Actual invocation depends on whether handoff occurs, but no double execution + assert invocation_counts["coordinator"] == 1, f"Coordinator invoked {invocation_counts['coordinator']} times" + # Specialist may or may not be invoked depending on handoff logic, but if invoked, only once + assert invocation_counts["specialist"] <= 1, f"Specialist invoked {invocation_counts['specialist']} times" + + # Verify stream_async was called but invoke_async was NOT called + assert coordinator.stream_async.call_count == 1 + coordinator.invoke_async.assert_not_called() + if invocation_counts["specialist"] > 0: + specialist.invoke_async.assert_not_called() + + +@pytest.mark.asyncio +async def test_swarm_timeout_with_streaming(mock_strands_tracer, mock_use_span): + """Test that swarm node timeout works correctly with streaming.""" + # Create a slow agent + slow_agent = create_mock_agent("slow_agent", "Slow response") + + async def slow_stream(*args, **kwargs): + yield {"agent_start": True, "node": "slow_agent"} + await asyncio.sleep(0.3) # Longer than timeout + yield {"result": slow_agent.return_value} + + slow_agent.stream_async = Mock(side_effect=slow_stream) + + # Create swarm with short node timeout + swarm = Swarm( + nodes=[slow_agent], + max_handoffs=1, + max_iterations=1, + node_timeout=0.1, # Short timeout + ) + + # Execute - should complete with FAILED status due to timeout + result = await swarm.invoke_async("Test timeout") + + # Verify the swarm failed due to timeout + assert result.status == Status.FAILED + + # Verify the agent started streaming + assert slow_agent.stream_async.call_count == 1 + + +@pytest.mark.asyncio +async def test_swarm_node_timeout_with_mocked_streaming(): + """Test that swarm node timeout properly cancels a streaming generator that freezes.""" + # Create an agent that will timeout during streaming + slow_agent = Agent( + name="slow_agent", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a slow agent. Take your time responding.", + ) + + # Override stream_async to simulate a freezing generator + original_stream = slow_agent.stream_async + + async def freezing_stream(*args, **kwargs): + """Simulate a generator that yields some events then freezes.""" + # Yield a few events normally + count = 0 + async for event in original_stream(*args, **kwargs): + yield event + count += 1 + if count >= 3: + # Simulate freezing - sleep longer than timeout + await asyncio.sleep(10.0) + break + + slow_agent.stream_async = freezing_stream + + # Create swarm with short node timeout + swarm = Swarm( + nodes=[slow_agent], + max_handoffs=1, + max_iterations=1, + node_timeout=0.5, # 500ms timeout + ) + + # Execute - should complete with FAILED status due to timeout + result = await swarm.invoke_async("Test freezing generator") + assert result.status == Status.FAILED + + +@pytest.mark.asyncio +async def test_swarm_timeout_cleanup_on_exception(): + """Test that timeout properly cleans up tasks even when exceptions occur.""" + # Create an agent + agent = Agent( + name="test_agent", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a test agent.", + ) + + # Override stream_async to raise an exception after some events + original_stream = agent.stream_async + + async def exception_stream(*args, **kwargs): + """Simulate a generator that raises an exception.""" + count = 0 + async for event in original_stream(*args, **kwargs): + yield event + count += 1 + if count >= 2: + raise ValueError("Simulated error during streaming") + + agent.stream_async = exception_stream + + # Create swarm with timeout + swarm = Swarm( + nodes=[agent], + max_handoffs=1, + max_iterations=1, + node_timeout=30.0, + ) + + # Execute - swarm catches exceptions and continues, marking node as failed + result = await swarm.invoke_async("Test exception handling") + # Verify the node failed + assert "test_agent" in result.results + assert result.results["test_agent"].status == Status.FAILED + assert result.status == Status.FAILED diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index c2c13c443..a7335feb7 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,3 +1,5 @@ +from typing import Any, AsyncIterator + import pytest from strands import Agent, tool @@ -9,6 +11,7 @@ BeforeModelCallEvent, MessageAddedEvent, ) +from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult, Status from strands.multiagent.graph import GraphBuilder from strands.types.content import ContentBlock from tests.fixtures.mock_hook_provider import MockHookProvider @@ -218,3 +221,240 @@ async def test_graph_execution_with_image(image_analysis_agent, summary_agent, y assert hook_provider.extract_for(image_analysis_agent).event_types_received == expected_hook_events assert hook_provider.extract_for(summary_agent).event_types_received == expected_hook_events + + +class CustomStreamingNode(MultiAgentBase): + """Custom node that wraps an agent and adds custom streaming events.""" + + def __init__(self, agent: Agent, name: str): + self.agent = agent + self.name = name + + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> MultiAgentResult: + result = await self.agent.invoke_async(task, **kwargs) + node_result = NodeResult(result=result, status=Status.COMPLETED) + return MultiAgentResult(status=Status.COMPLETED, results={self.name: node_result}) + + async def stream_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AsyncIterator[dict[str, Any]]: + yield {"custom_event": "start", "node": self.name} + result = await self.agent.invoke_async(task, **kwargs) + yield {"custom_event": "agent_complete", "node": self.name} + node_result = NodeResult(result=result, status=Status.COMPLETED) + yield {"result": MultiAgentResult(status=Status.COMPLETED, results={self.name: node_result})} + + +@pytest.mark.asyncio +async def test_graph_streaming_with_agents(alist): + """Test that Graph properly streams events from agent nodes.""" + math_agent = Agent( + name="math", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are a math assistant.", + tools=[calculate_sum], + ) + summary_agent = Agent( + name="summary", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a summary assistant.", + ) + + builder = GraphBuilder() + builder.add_node(math_agent, "math") + builder.add_node(summary_agent, "summary") + builder.add_edge("math", "summary") + builder.set_entry_point("math") + builder.set_node_timeout(900.0) # Verify timeout doesn't interfere with streaming + graph = builder.build() + + # Collect events + events = await alist(graph.stream_async("Calculate 5 + 3 and summarize the result")) + + # Count event categories + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + node_stream_events = [e for e in events if e.get("type") == "multiagent_node_stream"] + node_stop_events = [e for e in events if e.get("type") == "multiagent_node_stop"] + handoff_events = [e for e in events if e.get("type") == "multiagent_handoff"] + result_events = [e for e in events if "result" in e and e.get("type") != "multiagent_node_stream"] + + # Verify we got multiple events of each type + assert len(node_start_events) >= 2, f"Expected at least 2 node_start events, got {len(node_start_events)}" + assert len(node_stream_events) > 10, f"Expected many node_stream events, got {len(node_stream_events)}" + assert len(node_stop_events) >= 2, f"Expected at least 2 node_stop events, got {len(node_stop_events)}" + assert len(handoff_events) >= 1, f"Expected at least 1 handoff event, got {len(handoff_events)}" + assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}" + + # Verify handoff event structure + handoff = handoff_events[0] + assert "from_node_ids" in handoff, "Handoff event missing from_node_ids" + assert "to_node_ids" in handoff, "Handoff event missing to_node_ids" + assert isinstance(handoff["from_node_ids"], list), "from_node_ids should be a list" + assert isinstance(handoff["to_node_ids"], list), "to_node_ids should be a list" + assert "math" in handoff["from_node_ids"], "Expected math in from_node_ids" + assert "summary" in handoff["to_node_ids"], "Expected summary in to_node_ids" + + # Verify we have events for both nodes + math_events = [e for e in events if e.get("node_id") == "math"] + summary_events = [e for e in events if e.get("node_id") == "summary"] + assert len(math_events) > 0, "Expected events from math node" + assert len(summary_events) > 0, "Expected events from summary node" + + +@pytest.mark.asyncio +async def test_graph_streaming_with_custom_node(alist): + """Test that Graph properly streams events from custom MultiAgentBase nodes.""" + math_agent = Agent( + name="math", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are a math assistant.", + tools=[calculate_sum], + ) + summary_agent = Agent( + name="summary", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a summary assistant.", + ) + + # Create a custom node + custom_node = CustomStreamingNode(summary_agent, "custom_summary") + + builder = GraphBuilder() + builder.add_node(math_agent, "math") + builder.add_node(custom_node, "custom_summary") + builder.add_edge("math", "custom_summary") + builder.set_entry_point("math") + graph = builder.build() + + # Collect events + events = await alist(graph.stream_async("Calculate 5 + 3 and summarize the result")) + + # Count event categories + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + node_stream_events = [e for e in events if e.get("type") == "multiagent_node_stream"] + result_events = [e for e in events if "result" in e and e.get("type") != "multiagent_node_stream"] + + # Extract custom events from wrapped node_stream events + # Structure: {"type": "multiagent_node_stream", "node_id": "...", "event": {...}} + custom_events = [] + for e in node_stream_events: + if e.get("type") == "multiagent_node_stream" and "event" in e: + inner_event = e["event"] + if isinstance(inner_event, dict) and "custom_event" in inner_event: + custom_events.append(inner_event) + + # Verify we got multiple events of each type + assert len(node_start_events) >= 2, f"Expected at least 2 node_start events, got {len(node_start_events)}" + assert len(node_stream_events) > 5, f"Expected many node_stream events, got {len(node_stream_events)}" + assert len(custom_events) >= 2, f"Expected at least 2 custom events (start, complete), got {len(custom_events)}" + assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}" + + # Verify custom events are properly structured + custom_start = [e for e in custom_events if e.get("custom_event") == "start"] + custom_complete = [e for e in custom_events if e.get("custom_event") == "agent_complete"] + + assert len(custom_start) >= 1, "Expected at least 1 custom start event" + assert len(custom_complete) >= 1, "Expected at least 1 custom complete event" + + +@pytest.mark.asyncio +async def test_nested_graph_streaming(alist): + """Test that nested graphs properly propagate streaming events.""" + math_agent = Agent( + name="math", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are a math assistant.", + tools=[calculate_sum], + ) + analysis_agent = Agent( + name="analysis", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are an analysis assistant.", + ) + + # Create nested graph + nested_builder = GraphBuilder() + nested_builder.add_node(math_agent, "calculator") + nested_builder.add_node(analysis_agent, "analyzer") + nested_builder.add_edge("calculator", "analyzer") + nested_builder.set_entry_point("calculator") + nested_graph = nested_builder.build() + + # Create outer graph with nested graph + summary_agent = Agent( + name="summary", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a summary assistant.", + ) + + outer_builder = GraphBuilder() + outer_builder.add_node(nested_graph, "computation") + outer_builder.add_node(summary_agent, "summary") + outer_builder.add_edge("computation", "summary") + outer_builder.set_entry_point("computation") + outer_graph = outer_builder.build() + + # Collect events + events = await alist(outer_graph.stream_async("Calculate 7 + 8 and provide a summary")) + + # Count event categories + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + node_stream_events = [e for e in events if e.get("type") == "multiagent_node_stream"] + result_events = [e for e in events if "result" in e and e.get("type") != "multiagent_node_stream"] + + # Verify we got multiple events + assert len(node_start_events) >= 2, f"Expected at least 2 node_start events, got {len(node_start_events)}" + assert len(node_stream_events) > 10, f"Expected many node_stream events, got {len(node_stream_events)}" + assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}" + + # Verify we have events from nested nodes + computation_events = [e for e in events if e.get("node_id") == "computation"] + summary_events = [e for e in events if e.get("node_id") == "summary"] + assert len(computation_events) > 0, "Expected events from computation (nested graph) node" + assert len(summary_events) > 0, "Expected events from summary node" + + +@pytest.mark.asyncio +async def test_graph_metrics_accumulation(): + """Test that graph properly accumulates metrics from agent nodes.""" + math_agent = Agent( + name="math", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are a math assistant.", + tools=[calculate_sum], + ) + summary_agent = Agent( + name="summary", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a summary assistant.", + ) + + builder = GraphBuilder() + builder.add_node(math_agent, "math") + builder.add_node(summary_agent, "summary") + builder.add_edge("math", "summary") + builder.set_entry_point("math") + graph = builder.build() + + result = await graph.invoke_async("Calculate 5 + 3 and summarize the result") + + # Verify result has accumulated metrics + assert result.accumulated_usage is not None + assert result.accumulated_usage["totalTokens"] > 0, "Expected non-zero total tokens" + assert result.accumulated_usage["inputTokens"] > 0, "Expected non-zero input tokens" + assert result.accumulated_usage["outputTokens"] > 0, "Expected non-zero output tokens" + + assert result.accumulated_metrics is not None + assert result.accumulated_metrics["latencyMs"] > 0, "Expected non-zero latency" + + # Verify individual node results have metrics + for node_id, node_result in result.results.items(): + assert node_result.accumulated_usage is not None, f"Node {node_id} missing usage metrics" + assert node_result.accumulated_usage["totalTokens"] > 0, f"Node {node_id} has zero total tokens" + assert node_result.accumulated_metrics is not None, f"Node {node_id} missing metrics" + + # Verify accumulated metrics are sum of node metrics + total_tokens = sum(node_result.accumulated_usage["totalTokens"] for node_result in result.results.values()) + assert result.accumulated_usage["totalTokens"] == total_tokens, "Accumulated tokens don't match sum of node tokens" diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index 9a8c79bf8..1a0dd286e 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -134,3 +134,61 @@ async def test_swarm_execution_with_image(researcher_agent, analyst_agent, write # Verify agent history - at least one agent should have been used assert len(result.node_history) > 0 + + +@pytest.mark.asyncio +async def test_swarm_streaming(alist): + """Test that Swarm properly streams all event types during execution.""" + researcher = Agent( + name="researcher", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are a researcher. When you need calculations, hand off to the analyst.", + ) + analyst = Agent( + name="analyst", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are an analyst. Use tools to perform calculations.", + tools=[calculate], + ) + + swarm = Swarm([researcher, analyst], node_timeout=900.0) + + # Collect events + events = await alist(swarm.stream_async("Calculate 10 + 5 and explain the result")) + + # Count event categories + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + node_stream_events = [e for e in events if e.get("type") == "multiagent_node_stream"] + node_stop_events = [e for e in events if e.get("type") == "multiagent_node_stop"] + handoff_events = [e for e in events if e.get("type") == "multiagent_handoff"] + result_events = [e for e in events if "result" in e and e.get("type") != "multiagent_node_stream"] + + # Verify we got multiple events of each type + assert len(node_start_events) >= 1, f"Expected at least 1 node_start event, got {len(node_start_events)}" + assert len(node_stream_events) > 10, f"Expected many node_stream events, got {len(node_stream_events)}" + assert len(node_stop_events) >= 1, f"Expected at least 1 node_stop event, got {len(node_stop_events)}" + assert len(handoff_events) >= 1, f"Expected at least 1 handoff event, got {len(handoff_events)}" + assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}" + + # Verify handoff event structure + handoff = handoff_events[0] + assert "from_node_ids" in handoff, "Handoff event missing from_node_ids" + assert "to_node_ids" in handoff, "Handoff event missing to_node_ids" + assert "message" in handoff, "Handoff event missing message" + assert handoff["from_node_ids"] == ["researcher"], ( + f"Expected from_node_ids=['researcher'], got {handoff['from_node_ids']}" + ) + assert handoff["to_node_ids"] == ["analyst"], f"Expected to_node_ids=['analyst'], got {handoff['to_node_ids']}" + + # Verify node stop event structure + stop_event = node_stop_events[0] + assert "node_id" in stop_event, "Node stop event missing node_id" + assert "node_result" in stop_event, "Node stop event missing node_result" + node_result = stop_event["node_result"] + assert hasattr(node_result, "execution_time"), "NodeResult missing execution_time" + assert node_result.execution_time > 0, "Expected positive execution_time" + + # Verify we have events from at least one agent + researcher_events = [e for e in events if e.get("node_id") == "researcher"] + analyst_events = [e for e in events if e.get("node_id") == "analyst"] + assert len(researcher_events) > 0 or len(analyst_events) > 0, "Expected events from at least one agent"