11import json
22import logging
33from abc import ABC , abstractmethod
4- from typing import Any , List , Optional
4+ from typing import Any , List , Optional , Type
55
6- from sqlalchemy import Column , Integer , Text , create_engine
6+ from sqlalchemy import Column , Integer , Select , Text , create_engine , select
77
88try :
99 from sqlalchemy .orm import declarative_base
2222class BaseMessageConverter (ABC ):
2323 """The class responsible for converting BaseMessage to your SQLAlchemy model."""
2424
25+ @abstractmethod
26+ def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
27+ raise NotImplementedError
28+
2529 @abstractmethod
2630 def from_sql_model (self , sql_message : Any ) -> BaseMessage :
2731 """Convert a SQLAlchemy model to a BaseMessage instance."""
@@ -51,7 +55,7 @@ def create_message_model(table_name, DynamicBase): # type: ignore
5155
5256 """
5357
54- # Model decleared inside a function to have a dynamic table name
58+ # Model declared inside a function to have a dynamic table name
5559 class Message (DynamicBase ):
5660 __tablename__ = table_name
5761 id = Column (Integer , primary_key = True )
@@ -82,6 +86,8 @@ def get_sql_model_class(self) -> Any:
8286class SQLChatMessageHistory (BaseChatMessageHistory ):
8387 """Chat message history stored in an SQL database."""
8488
89+ DEFAULT_MESSAGE_CONVERTER : Type [BaseMessageConverter ] = DefaultMessageConverter
90+
8591 def __init__ (
8692 self ,
8793 session_id : str ,
@@ -93,7 +99,9 @@ def __init__(
9399 self .connection_string = connection_string
94100 self .engine = create_engine (connection_string , echo = False )
95101 self .session_id_field_name = session_id_field_name
96- self .converter = custom_message_converter or DefaultMessageConverter (table_name )
102+ self .converter = custom_message_converter or self .DEFAULT_MESSAGE_CONVERTER (
103+ table_name
104+ )
97105 self .sql_model_class = self .converter .get_sql_model_class ()
98106 if not hasattr (self .sql_model_class , session_id_field_name ):
99107 raise ValueError ("SQL model class must have session_id column" )
@@ -105,21 +113,25 @@ def __init__(
105113 def _create_table_if_not_exists (self ) -> None :
106114 self .sql_model_class .metadata .create_all (self .engine )
107115
116+ def _messages_query (self ) -> Select :
117+ """Construct an SQLAlchemy selectable to query for messages"""
118+ return (
119+ select (self .sql_model_class )
120+ .where (
121+ getattr (self .sql_model_class , self .session_id_field_name )
122+ == self .session_id
123+ )
124+ .order_by (self .sql_model_class .id .asc ())
125+ )
126+
108127 @property
109128 def messages (self ) -> List [BaseMessage ]: # type: ignore
110129 """Retrieve all messages from db"""
111130 with self .Session () as session :
112- result = (
113- session .query (self .sql_model_class )
114- .where (
115- getattr (self .sql_model_class , self .session_id_field_name )
116- == self .session_id
117- )
118- .order_by (self .sql_model_class .id .asc ())
119- )
131+ result = session .execute (self ._messages_query ())
120132 messages = []
121133 for record in result :
122- messages .append (self .converter .from_sql_model (record ))
134+ messages .append (self .converter .from_sql_model (record [ 0 ] ))
123135 return messages
124136
125137 def add_message (self , message : BaseMessage ) -> None :
0 commit comments