diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index e1fc885bb..b2589a3fe 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -19,6 +19,19 @@ ) await Runner.run(agent, "Hello", session=session) + + # Set session metadata + await session.set_metadata({ + "owner_id": "user_456", + "title": "Customer Support Chat", + "tags": ["support", "billing"] + }) + + # Get metadata + metadata = await session.get_metadata(keys=["owner_id", "title"]) + + # Find sessions by metadata + user_sessions = await session.find_sessions_by_metadata("owner_id", "user_456") """ from __future__ import annotations @@ -34,6 +47,7 @@ Index, Integer, MetaData, + PrimaryKeyConstraint, String, Table, Text, @@ -55,6 +69,7 @@ class SQLAlchemySession(SessionABC): _metadata: MetaData _sessions: Table _messages: Table + _session_metadata: Table def __init__( self, @@ -64,6 +79,7 @@ def __init__( create_tables: bool = False, sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + session_metadata_table: str = "agent_sessions_metadata", ): """Initializes a new SQLAlchemySession. @@ -77,6 +93,8 @@ def __init__( development and testing when migrations aren't used. sessions_table (str, optional): Override the default table name for sessions if needed. messages_table (str, optional): Override the default table name for messages if needed. + session_metadata_table (str, optional): Override the default table name for session + metadata if needed. """ self.session_id = session_id self._engine = engine @@ -127,6 +145,34 @@ def __init__( sqlite_autoincrement=True, ) + self._session_metadata = Table( + session_metadata_table, + self._metadata, + Column( + "session_id", + String, + ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"), + nullable=False, + ), + Column("key", String(255), nullable=False), + Column("value", Text, nullable=True), + Column( + "created_at", + TIMESTAMP(timezone=False), + server_default=sql_text("CURRENT_TIMESTAMP"), + nullable=False, + ), + Column( + "updated_at", + TIMESTAMP(timezone=False), + server_default=sql_text("CURRENT_TIMESTAMP"), + onupdate=sql_text("CURRENT_TIMESTAMP"), + nullable=False, + ), + PrimaryKeyConstraint("session_id", "key", name="pk_session_metadata"), + Index("idx_session_metadata_key_value", "key", "value"), + ) + # Async session factory self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False) @@ -169,6 +215,21 @@ async def _deserialize_item(self, item: str) -> TResponseInputItem: """Deserialize a JSON string to an item. Can be overridden by subclasses.""" return json.loads(item) # type: ignore[no-any-return] + def _serialize_metadata_value(self, value: Any) -> str: + """Serialize metadata value to string (JSON for dicts/lists).""" + if isinstance(value, (dict, list)): + return json.dumps(value, separators=(",", ":")) + return str(value) + + def _deserialize_metadata_value(self, value_str: str | None) -> Any: + """Deserialize metadata value (auto-parse JSON).""" + if value_str is None: + return None + try: + return json.loads(value_str) + except (json.JSONDecodeError, TypeError): + return value_str + # ------------------------------------------------------------------ # Session protocol implementation # ------------------------------------------------------------------ @@ -309,13 +370,224 @@ async def pop_item(self) -> TResponseInputItem | None: return None async def clear_session(self) -> None: - """Clear all items for this session.""" + """Clear all items and metadata for this session.""" await self._ensure_tables() async with self._session_factory() as sess: async with sess.begin(): + # Delete metadata + await sess.execute( + delete(self._session_metadata).where( + self._session_metadata.c.session_id == self.session_id + ) + ) + # Delete messages await sess.execute( delete(self._messages).where(self._messages.c.session_id == self.session_id) ) + # Delete session await sess.execute( delete(self._sessions).where(self._sessions.c.session_id == self.session_id) ) + + # ------------------------------------------------------------------ + # Session metadata operations + # ------------------------------------------------------------------ + async def set_metadata(self, metadata: dict[str, Any]) -> None: + """Set metadata key-value pairs for this session (performs UPSERT). + + Args: + metadata: Dictionary of key-value pairs to set. Values can be strings, + numbers, booleans, or JSON-serializable dicts/lists. + + Example: + await session.set_metadata({ + "owner_id": "user_123", + "title": "My Chat", + "tags": ["work", "important"] + }) + """ + if not metadata: + return + + await self._ensure_tables() + + # Detect dialect and import correct insert function + dialect_name = self._engine.dialect.name + + if dialect_name == "postgresql": + from sqlalchemy.dialects.postgresql import insert + elif dialect_name == "sqlite": + from sqlalchemy.dialects.sqlite import insert # type: ignore[assignment] + elif dialect_name == "mysql": + from sqlalchemy.dialects.mysql import insert # type: ignore[assignment] + else: + raise ValueError(f"Unsupported dialect: {dialect_name}") + + async with self._session_factory() as sess: + async with sess.begin(): + # Auto-create session row if it doesn't exist + existing = await sess.execute( + select(self._sessions.c.session_id).where( + self._sessions.c.session_id == self.session_id + ) + ) + if not existing.scalar_one_or_none(): + await sess.execute( + insert(self._sessions).values({"session_id": self.session_id}) + ) + + # UPSERT each metadata key-value pair + for key, value in metadata.items(): + value_str = self._serialize_metadata_value(value) + + stmt = insert(self._session_metadata).values( + { + "session_id": self.session_id, + "key": key, + "value": value_str, + } + ) + + # Use dialect-specific UPSERT + if dialect_name == "mysql": + # MySQL uses ON DUPLICATE KEY UPDATE + stmt = stmt.on_duplicate_key_update( + value=stmt.inserted.value, + updated_at=sql_text("CURRENT_TIMESTAMP"), + ) + else: + # PostgreSQL and SQLite use ON CONFLICT DO UPDATE + stmt = stmt.on_conflict_do_update( + index_elements=["session_id", "key"], + set_={ + "value": stmt.excluded.value, + "updated_at": sql_text("CURRENT_TIMESTAMP"), + }, + ) + + await sess.execute(stmt) + + async def get_metadata(self, keys: list[str] | None = None) -> dict[str, Any]: + """Get metadata for this session. + + Args: + keys: Optional list of specific keys to retrieve. If None, returns all metadata. + Missing keys will have None as their value. + + Returns: + Dictionary of metadata key-value pairs. Values are auto-deserialized from JSON. + + Example: + # Get all metadata + all_meta = await session.get_metadata() + + # Get specific keys (missing keys return None) + meta = await session.get_metadata(keys=["owner_id", "title"]) + # Returns: {"owner_id": "user_123", "title": "My Chat"} + """ + await self._ensure_tables() + + async with self._session_factory() as sess: + if keys is None: + # Get all metadata for this session + stmt = select(self._session_metadata.c.key, self._session_metadata.c.value).where( + self._session_metadata.c.session_id == self.session_id + ) + + result = await sess.execute(stmt) + rows = result.all() + + return {key: self._deserialize_metadata_value(value) for key, value in rows} + else: + # Get specific keys + stmt = select(self._session_metadata.c.key, self._session_metadata.c.value).where( + self._session_metadata.c.session_id == self.session_id, + self._session_metadata.c.key.in_(keys), + ) + + result = await sess.execute(stmt) + rows = result.all() + + # Build dict with None for missing keys + found_keys = {key: self._deserialize_metadata_value(value) for key, value in rows} + + return {key: found_keys.get(key, None) for key in keys} + + async def delete_metadata(self, keys: list[str] | None = None) -> None: + """Delete metadata for this session. + + Args: + keys: Optional list of specific keys to delete. If None, deletes all metadata + for this session. + + Example: + # Delete specific keys + await session.delete_metadata(keys=["title", "tags"]) + + # Delete all metadata + await session.delete_metadata() + """ + await self._ensure_tables() + + async with self._session_factory() as sess: + async with sess.begin(): + if keys is None: + # Delete all metadata for this session + await sess.execute( + delete(self._session_metadata).where( + self._session_metadata.c.session_id == self.session_id + ) + ) + else: + # Delete specific keys + await sess.execute( + delete(self._session_metadata).where( + self._session_metadata.c.session_id == self.session_id, + self._session_metadata.c.key.in_(keys), + ) + ) + + async def find_sessions_by_metadata( + self, key: str, value: Any, limit: int | None = 100 + ) -> list[str]: + """Find session IDs that have matching metadata (cross-session query). + + This is an instance method that queries across ALL sessions in the database, + not just the current session_id. + + Args: + key: Metadata key to search for + value: Metadata value to match (supports simple types: str, int, bool) + limit: Maximum number of session IDs to return. Pass None for unlimited results. + + Returns: + List of session IDs matching the criteria + + Example: + # Find all sessions for a specific user (limited to 100) + session_ids = await session.find_sessions_by_metadata("owner_id", "user_123") + # Returns: ["chat_1", "chat_2", "chat_3"] + + # Find all sessions without limit + all_sessions = await session.find_sessions_by_metadata( + "owner_id", "user_123", limit=None + ) + """ + await self._ensure_tables() + + # Serialize value for comparison + value_str = self._serialize_metadata_value(value) + + async with self._session_factory() as sess: + stmt = ( + select(self._session_metadata.c.session_id) + .where( + self._session_metadata.c.key == key, self._session_metadata.c.value == value_str + ) + .distinct() + ) + if limit is not None: + stmt = stmt.limit(limit) + + result = await sess.execute(stmt) + return [row[0] for row in result.all()] diff --git a/tests/extensions/memory/test_sqlalchemy_session.py b/tests/extensions/memory/test_sqlalchemy_session.py index 496d0b027..ef44ed3be 100644 --- a/tests/extensions/memory/test_sqlalchemy_session.py +++ b/tests/extensions/memory/test_sqlalchemy_session.py @@ -130,11 +130,16 @@ async def test_runner_integration(agent: Agent): async def test_session_isolation(agent: Agent): """Test that different session IDs result in isolated conversation histories.""" + from sqlalchemy.ext.asyncio import create_async_engine + + # Create ONE shared engine + engine = create_async_engine(DB_URL) + session_id_1 = "session_1" - session1 = SQLAlchemySession.from_url(session_id_1, url=DB_URL, create_tables=True) + session1 = SQLAlchemySession(session_id_1, engine=engine, create_tables=True) session_id_2 = "session_2" - session2 = SQLAlchemySession.from_url(session_id_2, url=DB_URL, create_tables=True) + session2 = SQLAlchemySession(session_id_2, engine=engine, create_tables=True) # Interact with session 1 assert isinstance(agent.model, FakeModel) @@ -219,19 +224,20 @@ async def test_get_items_same_timestamp_consistent_order(): ) ) id_map = { - json.loads(message_json)["id"]: row_id - for row_id, message_json in rows.fetchall() + json.loads(message_json)["id"]: row_id for row_id, message_json in rows.fetchall() } shared = datetime(2025, 10, 15, 17, 26, 39, 132483) older = shared - timedelta(milliseconds=1) await sess.execute( update(session._messages) - .where(session._messages.c.id.in_( - [ - id_map["rs_same_ts"], - id_map["msg_same_ts"], - ] - )) + .where( + session._messages.c.id.in_( + [ + id_map["rs_same_ts"], + id_map["msg_same_ts"], + ] + ) + ) .values(created_at=shared) ) await sess.execute( @@ -320,9 +326,7 @@ async def test_pop_item_same_timestamp_returns_latest(): async with session._session_factory() as sess: await sess.execute( text( - "UPDATE agent_messages " - "SET created_at = :created_at " - "WHERE session_id = :session_id" + "UPDATE agent_messages SET created_at = :created_at WHERE session_id = :session_id" ), { "created_at": "2025-10-15 17:26:39.132483", @@ -391,3 +395,222 @@ async def recording_execute(statement: Any, *args: Any, **kwargs: Any) -> Any: assert _item_ids(retrieved_full) == ["rs_first", "msg_second"] assert _item_ids(retrieved_limited) == ["rs_first", "msg_second"] + + +# ------------------------------------------------------------------ +# Session metadata tests +# ------------------------------------------------------------------ +async def test_set_and_get_metadata(): + """Test basic metadata set and get operations.""" + session = SQLAlchemySession.from_url("metadata_test", url=DB_URL, create_tables=True) + + # Set metadata + await session.set_metadata({"owner_id": "user_123", "title": "Test Chat"}) + + # Get all metadata + metadata = await session.get_metadata() + assert metadata["owner_id"] == "user_123" + assert metadata["title"] == "Test Chat" + + +async def test_get_metadata_with_keys(): + """Test getting specific metadata keys including missing keys.""" + session = SQLAlchemySession.from_url("metadata_keys_test", url=DB_URL, create_tables=True) + + await session.set_metadata({"owner_id": "user_123", "title": "Chat"}) + + # Get specific keys + metadata = await session.get_metadata(keys=["owner_id", "nonexistent"]) + assert metadata["owner_id"] == "user_123" + assert metadata["nonexistent"] is None + + +async def test_get_metadata_all(): + """Test getting all metadata for a session.""" + session = SQLAlchemySession.from_url("metadata_all_test", url=DB_URL, create_tables=True) + + await session.set_metadata({"owner_id": "user_123", "title": "Chat", "tags": ["work"]}) + + all_meta = await session.get_metadata() + assert len(all_meta) == 3 + assert all_meta["owner_id"] == "user_123" + assert all_meta["title"] == "Chat" + assert all_meta["tags"] == ["work"] + + +async def test_metadata_json_serialization(): + """Test that dicts and lists are automatically serialized to JSON.""" + session = SQLAlchemySession.from_url("json_test", url=DB_URL, create_tables=True) + + await session.set_metadata( + { + "tags": ["work", "important"], + "config": {"theme": "dark", "lang": "en"}, + "nested": { + "level1": { + "level2": {"items": ["a", "b", "c"], "count": 3}, + "tags": ["nested", "deep"], + } + }, + } + ) + + metadata = await session.get_metadata() + assert metadata["tags"] == ["work", "important"] + assert metadata["config"] == {"theme": "dark", "lang": "en"} + assert metadata["nested"]["level1"]["level2"]["items"] == ["a", "b", "c"] + assert metadata["nested"]["level1"]["level2"]["count"] == 3 + assert metadata["nested"]["level1"]["tags"] == ["nested", "deep"] + + +async def test_metadata_auto_create_session(): + """Test that setting metadata auto-creates session row.""" + session = SQLAlchemySession.from_url("auto_create_test", url=DB_URL, create_tables=True) + + # Set metadata before adding any messages + await session.set_metadata({"owner_id": "user_123"}) + + # Verify metadata was set + metadata = await session.get_metadata() + assert metadata["owner_id"] == "user_123" + + +async def test_metadata_upsert(): + """Test that setting metadata updates existing keys.""" + session = SQLAlchemySession.from_url("upsert_test", url=DB_URL, create_tables=True) + + await session.set_metadata({"title": "Original Title"}) + await session.set_metadata({"title": "Updated Title", "new_key": "value"}) + + metadata = await session.get_metadata() + assert metadata["title"] == "Updated Title" + assert metadata["new_key"] == "value" + + +async def test_delete_metadata_specific_keys(): + """Test deleting specific metadata keys.""" + session = SQLAlchemySession.from_url("delete_keys_test", url=DB_URL, create_tables=True) + + await session.set_metadata({"owner_id": "user_123", "title": "Chat", "tags": ["work"]}) + await session.delete_metadata(keys=["title", "tags"]) + + metadata = await session.get_metadata() + assert "owner_id" in metadata + assert "title" not in metadata + assert "tags" not in metadata + + +async def test_delete_metadata_all(): + """Test deleting all metadata for a session.""" + session = SQLAlchemySession.from_url("delete_all_test", url=DB_URL, create_tables=True) + + await session.set_metadata({"owner_id": "user_123", "title": "Chat"}) + await session.delete_metadata() + + metadata = await session.get_metadata() + assert len(metadata) == 0 + + +async def test_find_sessions_by_metadata(): + """Test finding sessions by metadata (cross-session query).""" + from sqlalchemy.ext.asyncio import create_async_engine + + # Create ONE shared engine for all sessions + engine = create_async_engine(DB_URL) + + session1 = SQLAlchemySession("find_test_1", engine=engine, create_tables=True) + session2 = SQLAlchemySession("find_test_2", engine=engine, create_tables=True) + session3 = SQLAlchemySession("find_test_3", engine=engine, create_tables=True) + session4 = SQLAlchemySession("find_test_4", engine=engine, create_tables=True) + + await session1.set_metadata({"owner_id": "user_123"}) + await session2.set_metadata({"owner_id": "user_123"}) + await session3.set_metadata({"owner_id": "user_456"}) + await session4.set_metadata({"owner_id": "user_123"}) + + # Find sessions for user_123 + matching = await session1.find_sessions_by_metadata("owner_id", "user_123") + assert "find_test_1" in matching + assert "find_test_2" in matching + assert "find_test_3" not in matching + assert "find_test_4" in matching + + +async def test_find_sessions_by_metadata_no_matches(): + """Test finding sessions when there are no matches.""" + session = SQLAlchemySession.from_url("no_match_test", url=DB_URL, create_tables=True) + + await session.set_metadata({"owner_id": "user_123"}) + + matching = await session.find_sessions_by_metadata("owner_id", "nonexistent_user") + assert len(matching) == 0 + + +async def test_find_sessions_by_metadata_unlimited(): + """Test finding sessions without limit.""" + from sqlalchemy.ext.asyncio import create_async_engine + + # Create ONE shared engine + engine = create_async_engine(DB_URL) + + # Create 105 sessions with same metadata - all using shared engine + for i in range(105): + session = SQLAlchemySession(f"session_{i}", engine=engine, create_tables=True) + await session.set_metadata({"owner_id": "user_123"}) + + # Query with default limit (100) - using same shared engine + query_session = SQLAlchemySession("query_session", engine=engine, create_tables=False) + limited = await query_session.find_sessions_by_metadata("owner_id", "user_123") + assert len(limited) == 100 + + # Query with no limit + unlimited = await query_session.find_sessions_by_metadata("owner_id", "user_123", limit=None) + assert len(unlimited) == 105 + + +async def test_clear_session_deletes_metadata(): + """Test that clear_session() removes metadata.""" + session = SQLAlchemySession.from_url("clear_meta_test", url=DB_URL, create_tables=True) + + await session.set_metadata({"owner_id": "user_123"}) + await session.add_items([{"role": "user", "content": "Hello"}]) + + await session.clear_session() + + metadata = await session.get_metadata() + assert len(metadata) == 0 + + items = await session.get_items() + assert len(items) == 0 + + +async def test_metadata_isolation_between_sessions(): + """Test that metadata is isolated between different sessions.""" + from sqlalchemy.ext.asyncio import create_async_engine + + # Create ONE shared engine (same database) + engine = create_async_engine(DB_URL) + + # Two sessions with different session_ids, but same database + session1 = SQLAlchemySession("iso_test_1", engine=engine, create_tables=True) + session2 = SQLAlchemySession("iso_test_2", engine=engine, create_tables=True) + + await session1.set_metadata({"owner_id": "user_123"}) + await session2.set_metadata({"owner_id": "user_456"}) + + meta1 = await session1.get_metadata() + meta2 = await session2.get_metadata() + + # Verify isolation within same database + assert meta1["owner_id"] == "user_123" + assert meta2["owner_id"] == "user_456" + + +async def test_metadata_empty_dict(): + """Test that setting empty metadata dict is a no-op.""" + session = SQLAlchemySession.from_url("empty_dict_test", url=DB_URL, create_tables=True) + + await session.set_metadata({}) + + metadata = await session.get_metadata() + assert len(metadata) == 0 diff --git a/tests/test_openai_chatcompletions_converter.py b/tests/test_openai_chatcompletions_converter.py index 4c5674388..18dfdf045 100644 --- a/tests/test_openai_chatcompletions_converter.py +++ b/tests/test_openai_chatcompletions_converter.py @@ -303,10 +303,13 @@ def test_extract_all_content_rejects_invalid_input_audio(): """ input_audio requires both data and format fields to be present. """ - audio_missing_data = cast(ResponseInputAudioParam, { - "type": "input_audio", - "input_audio": {"format": "wav"}, - }) + audio_missing_data = cast( + ResponseInputAudioParam, + { + "type": "input_audio", + "input_audio": {"format": "wav"}, + }, + ) with pytest.raises(UserError): Converter.extract_all_content([audio_missing_data]) diff --git a/tests/test_strict_schema_oneof.py b/tests/test_strict_schema_oneof.py index d6a145b57..7e289e70f 100644 --- a/tests/test_strict_schema_oneof.py +++ b/tests/test_strict_schema_oneof.py @@ -239,9 +239,7 @@ def test_oneof_with_refs(): schema = { "type": "object", "properties": { - "value": { - "oneOf": [{"$ref": "#/$defs/StringType"}, {"$ref": "#/$defs/IntType"}] - } + "value": {"oneOf": [{"$ref": "#/$defs/StringType"}, {"$ref": "#/$defs/IntType"}]} }, "$defs": { "StringType": {"type": "string"}, @@ -254,9 +252,7 @@ def test_oneof_with_refs(): expected = { "type": "object", "properties": { - "value": { - "anyOf": [{"$ref": "#/$defs/StringType"}, {"$ref": "#/$defs/IntType"}] - } + "value": {"anyOf": [{"$ref": "#/$defs/StringType"}, {"$ref": "#/$defs/IntType"}]} }, "$defs": { "StringType": {"type": "string"},