Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 273 additions & 1 deletion src/agents/extensions/memory/sqlalchemy_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@
)

await Runner.run(agent, "Hello", session=session)

# Set session metadata
await session.set_metadata({
"owner_id": "user_456",
"title": "Customer Support Chat",
"tags": ["support", "billing"]
})

# Get metadata
metadata = await session.get_metadata(keys=["owner_id", "title"])

# Find sessions by metadata
user_sessions = await session.find_sessions_by_metadata("owner_id", "user_456")
"""

from __future__ import annotations
Expand All @@ -34,6 +47,7 @@
Index,
Integer,
MetaData,
PrimaryKeyConstraint,
String,
Table,
Text,
Expand All @@ -55,6 +69,7 @@ class SQLAlchemySession(SessionABC):
_metadata: MetaData
_sessions: Table
_messages: Table
_session_metadata: Table

def __init__(
self,
Expand All @@ -64,6 +79,7 @@ def __init__(
create_tables: bool = False,
sessions_table: str = "agent_sessions",
messages_table: str = "agent_messages",
session_metadata_table: str = "agent_sessions_metadata",
):
"""Initializes a new SQLAlchemySession.

Expand All @@ -77,6 +93,8 @@ def __init__(
development and testing when migrations aren't used.
sessions_table (str, optional): Override the default table name for sessions if needed.
messages_table (str, optional): Override the default table name for messages if needed.
session_metadata_table (str, optional): Override the default table name for session
metadata if needed.
"""
self.session_id = session_id
self._engine = engine
Expand Down Expand Up @@ -127,6 +145,34 @@ def __init__(
sqlite_autoincrement=True,
)

self._session_metadata = Table(
session_metadata_table,
self._metadata,
Column(
"session_id",
String,
ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"),
nullable=False,
),
Column("key", String(255), nullable=False),
Column("value", Text, nullable=True),
Column(
"created_at",
TIMESTAMP(timezone=False),
server_default=sql_text("CURRENT_TIMESTAMP"),
nullable=False,
),
Column(
"updated_at",
TIMESTAMP(timezone=False),
server_default=sql_text("CURRENT_TIMESTAMP"),
onupdate=sql_text("CURRENT_TIMESTAMP"),
nullable=False,
),
PrimaryKeyConstraint("session_id", "key", name="pk_session_metadata"),
Index("idx_session_metadata_key_value", "key", "value"),
)

# Async session factory
self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)

Expand Down Expand Up @@ -169,6 +215,21 @@ async def _deserialize_item(self, item: str) -> TResponseInputItem:
"""Deserialize a JSON string to an item. Can be overridden by subclasses."""
return json.loads(item) # type: ignore[no-any-return]

def _serialize_metadata_value(self, value: Any) -> str:
"""Serialize metadata value to string (JSON for dicts/lists)."""
if isinstance(value, (dict, list)):
return json.dumps(value, separators=(",", ":"))
return str(value)

def _deserialize_metadata_value(self, value_str: str | None) -> Any:
"""Deserialize metadata value (auto-parse JSON)."""
if value_str is None:
return None
try:
return json.loads(value_str)
except (json.JSONDecodeError, TypeError):
return value_str

# ------------------------------------------------------------------
# Session protocol implementation
# ------------------------------------------------------------------
Expand Down Expand Up @@ -309,13 +370,224 @@ async def pop_item(self) -> TResponseInputItem | None:
return None

async def clear_session(self) -> None:
"""Clear all items for this session."""
"""Clear all items and metadata for this session."""
await self._ensure_tables()
async with self._session_factory() as sess:
async with sess.begin():
# Delete metadata
await sess.execute(
delete(self._session_metadata).where(
self._session_metadata.c.session_id == self.session_id
)
)
# Delete messages
await sess.execute(
delete(self._messages).where(self._messages.c.session_id == self.session_id)
)
# Delete session
await sess.execute(
delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
)

# ------------------------------------------------------------------
# Session metadata operations
# ------------------------------------------------------------------
async def set_metadata(self, metadata: dict[str, Any]) -> None:
"""Set metadata key-value pairs for this session (performs UPSERT).

Args:
metadata: Dictionary of key-value pairs to set. Values can be strings,
numbers, booleans, or JSON-serializable dicts/lists.

Example:
await session.set_metadata({
"owner_id": "user_123",
"title": "My Chat",
"tags": ["work", "important"]
})
"""
if not metadata:
return

await self._ensure_tables()

# Detect dialect and import correct insert function
dialect_name = self._engine.dialect.name

if dialect_name == "postgresql":
from sqlalchemy.dialects.postgresql import insert
elif dialect_name == "sqlite":
from sqlalchemy.dialects.sqlite import insert # type: ignore[assignment]
elif dialect_name == "mysql":
from sqlalchemy.dialects.mysql import insert # type: ignore[assignment]
else:
raise ValueError(f"Unsupported dialect: {dialect_name}")

async with self._session_factory() as sess:
async with sess.begin():
# Auto-create session row if it doesn't exist
existing = await sess.execute(
select(self._sessions.c.session_id).where(
self._sessions.c.session_id == self.session_id
)
)
if not existing.scalar_one_or_none():
await sess.execute(
insert(self._sessions).values({"session_id": self.session_id})
)

# UPSERT each metadata key-value pair
for key, value in metadata.items():
value_str = self._serialize_metadata_value(value)

stmt = insert(self._session_metadata).values(
{
"session_id": self.session_id,
"key": key,
"value": value_str,
}
)

# Use dialect-specific UPSERT
if dialect_name == "mysql":
# MySQL uses ON DUPLICATE KEY UPDATE
stmt = stmt.on_duplicate_key_update(
value=stmt.inserted.value,
updated_at=sql_text("CURRENT_TIMESTAMP"),
)
else:
# PostgreSQL and SQLite use ON CONFLICT DO UPDATE
stmt = stmt.on_conflict_do_update(
index_elements=["session_id", "key"],
set_={
"value": stmt.excluded.value,
"updated_at": sql_text("CURRENT_TIMESTAMP"),
},
)

await sess.execute(stmt)

async def get_metadata(self, keys: list[str] | None = None) -> dict[str, Any]:
"""Get metadata for this session.

Args:
keys: Optional list of specific keys to retrieve. If None, returns all metadata.
Missing keys will have None as their value.

Returns:
Dictionary of metadata key-value pairs. Values are auto-deserialized from JSON.

Example:
# Get all metadata
all_meta = await session.get_metadata()

# Get specific keys (missing keys return None)
meta = await session.get_metadata(keys=["owner_id", "title"])
# Returns: {"owner_id": "user_123", "title": "My Chat"}
"""
await self._ensure_tables()

async with self._session_factory() as sess:
if keys is None:
# Get all metadata for this session
stmt = select(self._session_metadata.c.key, self._session_metadata.c.value).where(
self._session_metadata.c.session_id == self.session_id
)

result = await sess.execute(stmt)
rows = result.all()

return {key: self._deserialize_metadata_value(value) for key, value in rows}
else:
# Get specific keys
stmt = select(self._session_metadata.c.key, self._session_metadata.c.value).where(
self._session_metadata.c.session_id == self.session_id,
self._session_metadata.c.key.in_(keys),
)

result = await sess.execute(stmt)
rows = result.all()

# Build dict with None for missing keys
found_keys = {key: self._deserialize_metadata_value(value) for key, value in rows}

return {key: found_keys.get(key, None) for key in keys}

async def delete_metadata(self, keys: list[str] | None = None) -> None:
"""Delete metadata for this session.

Args:
keys: Optional list of specific keys to delete. If None, deletes all metadata
for this session.

Example:
# Delete specific keys
await session.delete_metadata(keys=["title", "tags"])

# Delete all metadata
await session.delete_metadata()
"""
await self._ensure_tables()

async with self._session_factory() as sess:
async with sess.begin():
if keys is None:
# Delete all metadata for this session
await sess.execute(
delete(self._session_metadata).where(
self._session_metadata.c.session_id == self.session_id
)
)
else:
# Delete specific keys
await sess.execute(
delete(self._session_metadata).where(
self._session_metadata.c.session_id == self.session_id,
self._session_metadata.c.key.in_(keys),
)
)

async def find_sessions_by_metadata(
self, key: str, value: Any, limit: int | None = 100
) -> list[str]:
"""Find session IDs that have matching metadata (cross-session query).

This is an instance method that queries across ALL sessions in the database,
not just the current session_id.

Args:
key: Metadata key to search for
value: Metadata value to match (supports simple types: str, int, bool)
limit: Maximum number of session IDs to return. Pass None for unlimited results.

Returns:
List of session IDs matching the criteria

Example:
# Find all sessions for a specific user (limited to 100)
session_ids = await session.find_sessions_by_metadata("owner_id", "user_123")
# Returns: ["chat_1", "chat_2", "chat_3"]

# Find all sessions without limit
all_sessions = await session.find_sessions_by_metadata(
"owner_id", "user_123", limit=None
)
"""
await self._ensure_tables()

# Serialize value for comparison
value_str = self._serialize_metadata_value(value)

async with self._session_factory() as sess:
stmt = (
select(self._session_metadata.c.session_id)
.where(
self._session_metadata.c.key == key, self._session_metadata.c.value == value_str
)
.distinct()
)
if limit is not None:
stmt = stmt.limit(limit)

result = await sess.execute(stmt)
return [row[0] for row in result.all()]
Loading