Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6c00bbe
feat(multiagent): Add stream async
Oct 2, 2025
b09b539
Merge branch 'main' into multiagent-streaming
Oct 2, 2025
08141a0
fix(graph): improve parallel node calling
Oct 2, 2025
d4f5571
fix: Fix double execution
Oct 2, 2025
fc0a272
fix: improve graph timeout
Oct 3, 2025
ca59221
Merge branch 'main' into multiagent-streaming
Oct 3, 2025
60f16b9
fix: Add integ tests
Oct 3, 2025
a307f37
refactor(multiagent): improve streaming event handling and documentation
Oct 10, 2025
24502fc
fix(multiagent): remove no-op asyncio.gather in parallel execution
Oct 10, 2025
dd5445a
refactor: Fix streaming timeout logic
Oct 13, 2025
050c369
refactor: rename result to multiagent_result
Oct 13, 2025
defb5e5
refactor: simplify timeout logic
Oct 13, 2025
0b49c15
refactor: exception handling in graphs
Oct 14, 2025
6b64254
refactor: use alist in tests
Oct 14, 2025
d035654
Merge branch 'main' into multiagent-streaming
Oct 14, 2025
f018ea0
feat(multiagent): add type details to result events
Oct 14, 2025
3df5ee3
refactor: include node result in node complete event
Oct 14, 2025
19c93cc
refactor: change node complete to node stop
Oct 14, 2025
cd583ad
fix: fix failing integ tests
Oct 14, 2025
d97e5f4
refactor: address pr comments
Oct 17, 2025
45a1ee1
refactor: update multiagent types to use type key and update handoff …
Oct 17, 2025
01cb874
refactor: address comments
Oct 17, 2025
68d0b96
refactor: simplify integ tests
Oct 17, 2025
fb670ba
refactor: revert agent result changes
Oct 17, 2025
7f34e2d
refactor: update handoff event to use ids
Oct 17, 2025
7bede48
fix: remove comment
Oct 17, 2025
ff5bec8
chore: Merge main
Oct 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Union
from typing import Any, AsyncIterator, Union

from ..agent import AgentResult
from ..types.content import ContentBlock
Expand Down Expand Up @@ -98,6 +98,31 @@ async def invoke_async(
"""
raise NotImplementedError("invoke_async not implemented")

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during multi-agent execution.

Default implementation executes invoke_async and yields the result as a single event.
Subclasses can override this method to provide true streaming capabilities.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Additional keyword arguments passed to underlying agents.

Yields:
Dictionary events containing multi-agent execution information including:
- Multi-agent coordination events (node start/complete, handoffs)
- Forwarded single-agent events with node context
- Final result event
"""
# Default implementation for backward compatibility
# Execute invoke_async and yield the result as a single event
result = await self.invoke_async(task, invocation_state, **kwargs)
yield {"result": result}

def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> MultiAgentResult:
Expand Down
367 changes: 282 additions & 85 deletions src/strands/multiagent/graph.py

Large diffs are not rendered by default.

213 changes: 166 additions & 47 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple
from typing import Any, AsyncIterator, Callable, Tuple, cast

from opentelemetry import trace as trace_api

from ..agent import Agent, AgentResult
from ..agent import Agent
from ..agent.state import AgentState
from ..telemetry import get_tracer
from ..tools.decorator import tool
from ..types._events import (
MultiAgentHandoffEvent,
MultiAgentNodeStartEvent,
MultiAgentNodeStopEvent,
MultiAgentNodeStreamEvent,
MultiAgentResultEvent,
)
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
Expand Down Expand Up @@ -266,12 +273,43 @@ async def invoke_async(
) -> SwarmResult:
"""Invoke the swarm asynchronously.

This method uses stream_async internally and consumes all events until completion,
following the same pattern as the Agent class.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues - a new empty dict
is created if None is provided.
Defaults to None to avoid mutable default argument issues.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
events = self.stream_async(task, invocation_state, **kwargs)
final_event = None
async for event in events:
final_event = event

if final_event is None or "result" not in final_event:
raise ValueError("Swarm streaming completed without producing a result event")

return cast(SwarmResult, final_event["result"])

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during swarm execution.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Keyword arguments allowing backward compatible future changes.

Yields:
Dictionary events during swarm execution, such as:
- multi_agent_node_start: When a node begins execution
- multi_agent_node_stream: Forwarded agent events with node context
- multi_agent_handoff: When control is handed off between agents
- multi_agent_node_stop: When a node stops execution
- result: Final swarm result
"""
if invocation_state is None:
invocation_state = {}
Expand All @@ -282,7 +320,7 @@ async def invoke_async(
if self.entry_point:
initial_node = self.nodes[str(self.entry_point.name)]
else:
initial_node = next(iter(self.nodes.values())) # First SwarmNode
initial_node = next(iter(self.nodes.values()))

self.state = SwarmState(
current_node=initial_node,
Expand All @@ -303,15 +341,62 @@ async def invoke_async(
self.execution_timeout,
)

await self._execute_swarm(invocation_state)
async for event in self._execute_swarm(invocation_state):
yield event.as_dict()

except Exception:
logger.exception("swarm execution failed")
self.state.completion_status = Status.FAILED
raise
finally:
self.state.execution_time = round((time.time() - start_time) * 1000)

return self._build_result()
# Yield final result after execution_time is set
result = self._build_result()
yield MultiAgentResultEvent(result=result).as_dict()

async def _stream_with_timeout(
self, async_generator: AsyncIterator[Any], timeout: float | None, timeout_message: str
) -> AsyncIterator[Any]:
"""Wrap an async generator with timeout for total execution time.

Tracks elapsed time from start and enforces timeout across all events.
Each event wait uses remaining time from the total timeout budget.

Args:
async_generator: The generator to wrap
timeout: Total timeout in seconds for entire stream, or None for no timeout
timeout_message: Message to include in timeout exception

Yields:
Events from the wrapped generator as they arrive

Raises:
Exception: If total execution time exceeds timeout
"""
if timeout is None:
# No timeout - just pass through
async for event in async_generator:
yield event
else:
# Track start time for total timeout
start_time = asyncio.get_event_loop().time()

while True:
# Calculate remaining time from total timeout budget
elapsed = asyncio.get_event_loop().time() - start_time
remaining = timeout - elapsed

if remaining <= 0:
raise Exception(timeout_message)

try:
event = await asyncio.wait_for(async_generator.__anext__(), timeout=remaining)
yield event
except StopAsyncIteration:
break
except asyncio.TimeoutError as err:
raise Exception(timeout_message) from err

def _setup_swarm(self, nodes: list[Agent]) -> None:
"""Initialize swarm configuration."""
Expand Down Expand Up @@ -533,14 +618,14 @@ def _build_node_input(self, target_node: SwarmNode) -> str:

return context_text

async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
"""Shared execution logic used by execute_async."""
async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
"""Execute swarm and yield TypedEvent objects."""
try:
# Main execution loop
while True:
if self.state.completion_status != Status.EXECUTING:
reason = f"Completion status is: {self.state.completion_status}"
logger.debug("reason=<%s> | stopping execution", reason)
logger.debug("reason=<%s> | stopping streaming execution", reason)
break

should_continue, reason = self.state.should_continue(
Expand Down Expand Up @@ -568,34 +653,45 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
len(self.state.node_history) + 1,
)

# Store the current node before execution to detect handoffs
previous_node = current_node

# Execute node with timeout protection
# TODO: Implement cancellation token to stop _execute_node from continuing
try:
await asyncio.wait_for(
# Execute with timeout wrapper for async generator streaming
node_stream = self._stream_with_timeout(
self._execute_node(current_node, self.state.task, invocation_state),
timeout=self.node_timeout,
self.node_timeout,
f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s",
)
async for event in node_stream:
yield event

self.state.node_history.append(current_node)

logger.debug("node=<%s> | node execution completed", current_node.node_id)

# Check if the current node is still the same after execution
# If it is, then no handoff occurred and we consider the swarm complete
if self.state.current_node == current_node:
# Check if handoff occurred during execution
if self.state.current_node != previous_node:
# Emit handoff event (single node transition in Swarm)
handoff_event = MultiAgentHandoffEvent(
from_node_ids=[previous_node.node_id],
to_node_ids=[self.state.current_node.node_id],
message=self.state.handoff_message or "Agent handoff occurred",
)
yield handoff_event
logger.debug(
"from_node=<%s>, to_node=<%s> | handoff detected",
previous_node.node_id,
self.state.current_node.node_id,
)
else:
# No handoff occurred, mark swarm as complete
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
self.state.completion_status = Status.COMPLETED
break

except asyncio.TimeoutError:
logger.exception(
"node=<%s>, timeout=<%s>s | node execution timed out after timeout",
current_node.node_id,
self.node_timeout,
)
self.state.completion_status = Status.FAILED
break

except Exception:
logger.exception("node=<%s> | node execution failed", current_node.node_id)
self.state.completion_status = Status.FAILED
Expand All @@ -604,22 +700,26 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
except Exception:
logger.exception("swarm execution failed")
self.state.completion_status = Status.FAILED

elapsed_time = time.time() - self.state.start_time
logger.debug("status=<%s> | swarm execution completed", self.state.completion_status)
logger.debug(
"node_history_length=<%d>, time=<%s>s | metrics",
len(self.state.node_history),
f"{elapsed_time:.2f}",
)
finally:
elapsed_time = time.time() - self.state.start_time
logger.debug("status=<%s> | swarm execution completed", self.state.completion_status)
logger.debug(
"node_history_length=<%d>, time=<%s>s | metrics",
len(self.state.node_history),
f"{elapsed_time:.2f}",
)

async def _execute_node(
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
) -> AgentResult:
"""Execute swarm node."""
) -> AsyncIterator[Any]:
"""Execute swarm node and yield TypedEvent objects."""
start_time = time.time()
node_name = node.node_id

# Emit node start event
start_event = MultiAgentNodeStartEvent(node_id=node_name, node_type="agent")
yield start_event

try:
# Prepare context for node
context_text = self._build_node_input(node)
Expand All @@ -632,10 +732,21 @@ async def _execute_node(
# Include additional ContentBlocks in node input
node_input = node_input + task

# Execute node
result = None
# Execute node with streaming
node.reset_executor_state()
result = await node.executor.invoke_async(node_input, invocation_state=invocation_state)

# Stream agent events with node context and capture final result
result = None
async for event in node.executor.stream_async(node_input, invocation_state=invocation_state):
# Forward agent events with node context
wrapped_event = MultiAgentNodeStreamEvent(node_name, event)
yield wrapped_event
# Capture the final result event
if "result" in event:
result = event["result"]

if result is None:
raise ValueError(f"Node '{node_name}' did not produce a result event")

if result.stop_reason == "interrupt":
node.executor.messages.pop() # remove interrupted tool use message
Expand All @@ -645,14 +756,10 @@ async def _execute_node(

execution_time = round((time.time() - start_time) * 1000)

# Create NodeResult
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
metrics = Metrics(latencyMs=execution_time)
if hasattr(result, "metrics") and result.metrics:
if hasattr(result.metrics, "accumulated_usage"):
usage = result.metrics.accumulated_usage
if hasattr(result.metrics, "accumulated_metrics"):
metrics = result.metrics.accumulated_metrics
# Create NodeResult with extracted metrics
result_metrics = getattr(result, "metrics", None)
usage = getattr(result_metrics, "accumulated_usage", Usage(inputTokens=0, outputTokens=0, totalTokens=0))
metrics = getattr(result_metrics, "accumulated_metrics", Metrics(latencyMs=execution_time))

node_result = NodeResult(
result=result,
Expand All @@ -669,15 +776,20 @@ async def _execute_node(
# Accumulate metrics
self._accumulate_metrics(node_result)

return result
# Emit node stop event with full NodeResult
complete_event = MultiAgentNodeStopEvent(
node_id=node_name,
node_result=node_result,
)
yield complete_event

except Exception as e:
execution_time = round((time.time() - start_time) * 1000)
logger.exception("node=<%s> | node execution failed", node_name)

# Create a NodeResult for the failed node
node_result = NodeResult(
result=e, # Store exception as result
result=e,
execution_time=execution_time,
status=Status.FAILED,
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
Expand All @@ -688,6 +800,13 @@ async def _execute_node(
# Store result in state
self.state.results[node_name] = node_result

# Emit node stop event even for failures
complete_event = MultiAgentNodeStopEvent(
node_id=node_name,
node_result=node_result,
)
yield complete_event

raise

def _accumulate_metrics(self, node_result: NodeResult) -> None:
Expand Down
Loading
Loading