Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
740 changes: 740 additions & 0 deletions AsyncFunction_Design.md

Large diffs are not rendered by default.

20 changes: 18 additions & 2 deletions packages/ai/src/microsoft/teams/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,22 @@
Licensed under the MIT License.
"""

from . import plugins, utils
from .agent import Agent
from .ai_model import AIModel
from .chat_prompt import ChatPrompt, ChatSendResult
from .function import Function, FunctionCall, FunctionHandler, FunctionHandlers, FunctionHandlerWithNoParams
from .function import (
DeferredResult,
Function,
FunctionCall,
FunctionHandler,
FunctionHandlers,
FunctionHandlerWithNoParams,
)
from .memory import ListMemory, Memory
from .message import FunctionMessage, Message, ModelMessage, SystemMessage, UserMessage
from .message import DeferredMessage, FunctionMessage, Message, ModelMessage, SystemMessage, UserMessage
from .plugin import AIPluginProtocol, BaseAIPlugin
from .utils import * # noqa: F401, F403

__all__ = [
"ChatSendResult",
Expand All @@ -19,12 +29,18 @@
"ModelMessage",
"SystemMessage",
"FunctionMessage",
"DeferredMessage",
"Function",
"FunctionCall",
"DeferredResult",
"Memory",
"ListMemory",
"AIModel",
"AIPluginProtocol",
"BaseAIPlugin",
"FunctionHandler",
"FunctionHandlerWithNoParams",
"FunctionHandlers",
]
__all__.extend(utils.__all__)
__all__.extend(plugins.__all__)
6 changes: 3 additions & 3 deletions packages/ai/src/microsoft/teams/ai/ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .function import Function
from .memory import Memory
from .message import Message, ModelMessage, SystemMessage
from .message import DeferredMessage, Message, ModelMessage, SystemMessage


class AIModel(Protocol):
Expand All @@ -23,13 +23,13 @@ class AIModel(Protocol):

async def generate_text(
self,
input: Message,
input: Message | None,
*,
system: SystemMessage | None = None,
memory: Memory | None = None,
functions: dict[str, Function[BaseModel]] | None = None,
on_chunk: Callable[[str], Awaitable[None]] | None = None,
) -> ModelMessage:
) -> ModelMessage | list[DeferredMessage]:
"""
Generate a text response from the AI model.
Expand Down
162 changes: 156 additions & 6 deletions packages/ai/src/microsoft/teams/ai/chat_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import inspect
from dataclasses import dataclass
from inspect import isawaitable
from logging import Logger
from typing import Any, Awaitable, Callable, Dict, Optional, Self, TypeVar, Union, cast, overload

from microsoft.teams.common.logging import ConsoleLogger
from pydantic import BaseModel

from .ai_model import AIModel
from .function import Function, FunctionHandler, FunctionHandlers, FunctionHandlerWithNoParams
from .memory import Memory
from .message import Message, ModelMessage, SystemMessage, UserMessage
from .message import DeferredMessage, FunctionMessage, Message, ModelMessage, SystemMessage, UserMessage
from .plugin import AIPluginProtocol

T = TypeVar("T", bound=BaseModel)
Expand All @@ -28,7 +30,8 @@ class ChatSendResult:
calls and plugin processing have been completed.
"""

response: ModelMessage # Final model response after processing
response: ModelMessage | None # Final model response after processing
is_deferred: bool = False


class ChatPrompt:
Expand All @@ -45,6 +48,9 @@ def __init__(
*,
functions: list[Function[Any]] | None = None,
plugins: list[AIPluginProtocol] | None = None,
memory: Memory | None = None,
logger: Logger | None = None,
instructions: str | SystemMessage | None = None,
):
"""
Initialize ChatPrompt with model and optional functions/plugins.
Expand All @@ -53,10 +59,16 @@ def __init__(
model: AI model implementation for text generation
functions: Optional list of functions the model can call
plugins: Optional list of plugins for extending functionality
memory: Optional memory for conversation context and deferred state
logger: Optional logger for debugging and monitoring
instructions: Optional default system instructions for the model
"""
self.model = model
self.functions: dict[str, Function[Any]] = {func.name: func for func in functions} if functions else {}
self.plugins: list[AIPluginProtocol] = plugins or []
self.memory = memory
self.logger = logger or ConsoleLogger().create_logger("@teams/ai/chat_prompt")
self.instructions = instructions

@overload
def with_function(self, function: Function[T]) -> Self: ...
Expand Down Expand Up @@ -134,9 +146,136 @@ def with_plugin(self, plugin: AIPluginProtocol) -> Self:
self.plugins.append(plugin)
return self

async def requires_resuming(self) -> bool:
    """
    Report whether memory holds deferred function calls awaiting input.

    Returns:
        True when at least one DeferredMessage is present in memory,
        False when memory is unset or contains none.
    """
    if self.memory is None:
        return False

    stored = await self.memory.get_all()
    for entry in stored:
        if isinstance(entry, DeferredMessage):
            return True
    return False

async def resolve_deferred(self, activity: Any) -> list[str]:
    """
    Attempt to resolve pending deferred functions using the given activity.

    Each DeferredMessage in memory is offered to the plugins first and then
    to the function's built-in resumer; entries whose resumer cannot handle
    the activity type are left untouched. Successfully resolved entries are
    rewritten in memory as FunctionMessage objects.

    Args:
        activity: Activity data used to resolve the deferred functions

    Returns:
        Resolution strings for every deferred function that was resolved
    """
    if self.memory is None:
        return []

    history = await self.memory.get_all()
    if not any(isinstance(entry, DeferredMessage) for entry in history):
        return []

    resolved: list[str] = []
    rewritten = history.copy()  # never mutate memory's list in place

    for index, entry in enumerate(rewritten):
        if not isinstance(entry, DeferredMessage):
            continue

        # Plugins get the first chance; fall back to the built-in resumer.
        outcome = await self._try_resolve_with_plugins(entry, activity)
        if outcome is None:
            outcome = await self._try_resolve_with_builtin_resumer(entry, activity)
        if outcome is None:
            continue

        rewritten[index] = FunctionMessage(content=outcome, function_id=entry.function_id)
        resolved.append(outcome)

    # Persist only when something actually changed.
    if resolved:
        await self.memory.set_all(rewritten)

    return resolved

async def _try_resolve_with_plugins(self, msg: DeferredMessage, activity: Any) -> str | None:
    """
    Offer a deferred message to each registered plugin in order.

    Args:
        msg: The deferred message awaiting resolution
        activity: Activity data passed to the plugins

    Returns:
        The first non-None plugin result, or None when no plugin handled it
    """
    name = msg.function_name
    state = msg.deferred_result.state
    for candidate in self.plugins:
        outcome = await candidate.on_resume(name, activity, state)
        if outcome is not None:
            return outcome
    return None

async def _try_resolve_with_builtin_resumer(self, msg: DeferredMessage, activity: Any) -> str | None:
    """
    Try to resolve a deferred message using the built-in resumer.

    Args:
        msg: The deferred message to resolve
        activity: Activity data for resolution

    Returns:
        Result string if resolved successfully (an error string if the
        resumer itself raised), or None if the resumer cannot handle this
        activity type and the message should be skipped

    Raises:
        ValueError: If no function with a resumer is registered for the
            deferred message's function name
    """
    resumer_name = msg.function_name
    associated_func = self.functions.get(resumer_name)

    if not associated_func or associated_func.resumer is None:
        raise ValueError(f"Expected a resumer for {resumer_name} but chat prompt was not set up with one")

    # Check if the resumer can handle this type of activity
    if not associated_func.resumer.can_handle(activity):
        return None  # Skip this deferred function

    try:
        # Call the resumer with the activity and saved state; it may be
        # sync or async, so await only when needed.
        result = associated_func.resumer(activity, msg.deferred_result.state)
        if isawaitable(result):
            result = await result
        return result

    except Exception as e:
        # Surface the failure to the model as text instead of raising, so
        # one bad resumer does not abort the whole resolution pass.
        return f"Error resolving {resumer_name}: {str(e)}"

async def resume(self, activity: Any) -> ChatSendResult:
    """
    Resume deferred functions with the provided activity input.

    Resolves whatever deferred functions the activity can satisfy; once
    none remain pending, the conversation continues automatically via a
    send with no new input.

    Args:
        activity: Activity data to use for resolving deferred functions

    Returns:
        ChatSendResult - either indicating still deferred or containing the chat response
    """
    await self.resolve_deferred(activity)

    still_pending = await self.requires_resuming()
    if still_pending:
        # At least one deferred function was not satisfied by this activity.
        return ChatSendResult(response=None, is_deferred=True)

    return await self.send(input=None)

async def send(
self,
input: str | Message,
input: str | Message | None,
*,
memory: Memory | None = None,
on_chunk: Callable[[str], Awaitable[None]] | Callable[[str], None] | None = None,
Expand All @@ -158,11 +297,18 @@ async def send(
if isinstance(input, str):
input = UserMessage(content=input)

# Use constructor instructions as default if none provided
if instructions is None:
instructions = self.instructions

# Convert string instructions to SystemMessage
if isinstance(instructions, str):
instructions = SystemMessage(content=instructions)

current_input = await self._run_before_send_hooks(input)
if input is not None:
current_input = await self._run_before_send_hooks(input)
else:
current_input = None
current_system_message = await self._run_build_instructions_hooks(instructions)
wrapped_functions = await self._build_wrapped_functions()

Expand All @@ -176,10 +322,12 @@ async def on_chunk_fn(chunk: str):
response = await self.model.generate_text(
current_input,
system=current_system_message,
memory=memory,
memory=memory or self.memory,
functions=wrapped_functions,
on_chunk=on_chunk_fn if on_chunk else None,
)
if isinstance(response, list):
return ChatSendResult(response=None, is_deferred=True)

current_response = await self._run_after_send_hooks(response)

Expand Down Expand Up @@ -283,7 +431,9 @@ async def _build_wrapped_functions(self) -> dict[str, Function[BaseModel]] | Non
name=func.name,
description=func.description,
parameter_schema=func.parameter_schema,
handler=self._wrap_function_handler(func.handler, func.name),
handler=self._wrap_function_handler(cast(FunctionHandler[BaseModel], func.handler), func.name)
if func.resumer is None
else func.handler,
)

return wrapped_functions
Expand Down
Loading
Loading