diff --git a/examples/mem_scheduler/orm_examples.py b/examples/mem_scheduler/orm_examples.py new file mode 100644 index 000000000..bbb57b4ab --- /dev/null +++ b/examples/mem_scheduler/orm_examples.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +""" +ORM Examples for MemScheduler + +This script demonstrates how to use the BaseDBManager's new environment variable loading methods +for MySQL and Redis connections. +""" + +import multiprocessing +import os +import sys + +from pathlib import Path + + +# Add the src directory to the Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + +from memos.log import get_logger +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager, DatabaseError +from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager, SimpleListManager + + +logger = get_logger(__name__) + + +def test_mysql_engine_from_env(): + """Test loading MySQL engine from environment variables""" + print("\n" + "=" * 60) + print("Testing MySQL Engine from Environment Variables") + print("=" * 60) + + try: + # Test loading MySQL engine from current environment variables + mysql_engine = BaseDBManager.load_mysql_engine_from_env() + if mysql_engine is None: + print("❌ Failed to create MySQL engine - check environment variables") + return + + print(f"✅ Successfully created MySQL engine: {mysql_engine}") + print(f" Engine URL: {mysql_engine.url}") + + # Test connection + with mysql_engine.connect() as conn: + from sqlalchemy import text + + result = conn.execute(text("SELECT 'MySQL connection test successful' as message")) + message = result.fetchone()[0] + print(f" Connection test: {message}") + + mysql_engine.dispose() + print(" MySQL engine disposed successfully") + + except DatabaseError as e: + print(f"❌ DatabaseError: {e}") + except Exception as e: + print(f"❌ Unexpected error: {e}") + + +def test_redis_connection_from_env(): + """Test loading Redis connection from environment variables""" + print("\n" + "=" * 60) + print("Testing Redis Connection from Environment Variables") + print("=" * 60) + + try: + # Test loading Redis connection from current environment variables + redis_client = BaseDBManager.load_redis_engine_from_env() + if redis_client is None: + print("❌ Failed to create Redis connection - check environment variables") + return + + print(f"✅ Successfully created Redis connection: {redis_client}") + + # Test basic Redis operations + redis_client.set("test_key", "Hello from ORM Examples!") + value = redis_client.get("test_key") + print(f" Redis test - Set/Get: {value}") + + # Test Redis info + info = redis_client.info("server") + redis_version = info.get("redis_version", "unknown") + print(f" Redis server version: {redis_version}") + + # Clean up test key + redis_client.delete("test_key") + print(" Test key cleaned up") + + redis_client.close() + print(" Redis connection closed successfully") + + except DatabaseError as e: + print(f"❌ DatabaseError: {e}") + except Exception as e: + print(f"❌ Unexpected error: {e}") + + +def test_environment_variables(): + """Test and display current environment variables""" + print("\n" + "=" * 60) + print("Current Environment Variables") + print("=" * 60) + + # MySQL environment variables + mysql_vars = [ + "MYSQL_HOST", + "MYSQL_PORT", + "MYSQL_USERNAME", + "MYSQL_PASSWORD", + "MYSQL_DATABASE", + "MYSQL_CHARSET", + ] + + print("\nMySQL Environment Variables:") + for var in mysql_vars: + value = os.getenv(var, "Not set") + # Mask password for security + if "PASSWORD" in var and value != "Not 
set": + value = "*" * len(value) + print(f" {var}: {value}") + + # Redis environment variables + redis_vars = [ + "REDIS_HOST", + "REDIS_PORT", + "REDIS_DB", + "REDIS_PASSWORD", + "MEMSCHEDULER_REDIS_HOST", + "MEMSCHEDULER_REDIS_PORT", + "MEMSCHEDULER_REDIS_DB", + "MEMSCHEDULER_REDIS_PASSWORD", + ] + + print("\nRedis Environment Variables:") + for var in redis_vars: + value = os.getenv(var, "Not set") + # Mask password for security + if "PASSWORD" in var and value != "Not set": + value = "*" * len(value) + print(f" {var}: {value}") + + +def test_manual_env_loading(): + """Test loading environment variables manually from .env file""" + print("\n" + "=" * 60) + print("Testing Manual Environment Loading") + print("=" * 60) + + env_file_path = "/Users/travistang/Documents/codes/memos/.env" + + if not os.path.exists(env_file_path): + print(f"❌ Environment file not found: {env_file_path}") + return + + try: + from dotenv import load_dotenv + + # Load environment variables + load_dotenv(env_file_path) + print(f"✅ Successfully loaded environment variables from {env_file_path}") + + # Test some key variables + test_vars = ["OPENAI_API_KEY", "MOS_CHAT_MODEL", "TZ"] + for var in test_vars: + value = os.getenv(var, "Not set") + if "KEY" in var and value != "Not set": + value = f"{value[:10]}..." if len(value) > 10 else value + print(f" {var}: {value}") + + except ImportError: + print("❌ python-dotenv not installed. Install with: pip install python-dotenv") + except Exception as e: + print(f"❌ Error loading environment file: {e}") + + +def test_redis_lockable_orm_with_list(): + """Test RedisDBManager with list[str] type synchronization""" + print("\n" + "=" * 60) + print("Testing RedisDBManager with list[str]") + print("=" * 60) + + try: + from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager + + # Create a simple list manager instance + list_manager = SimpleListManager(["apple", "banana", "cherry"]) + print(f"Original list manager: {list_manager}") + + # Create RedisDBManager instance + redis_client = BaseDBManager.load_redis_engine_from_env() + if redis_client is None: + print("❌ Failed to create Redis connection - check environment variables") + return + + db_manager = RedisDBManager( + redis_client=redis_client, + user_id="test_user", + mem_cube_id="test_list_cube", + obj=list_manager, + ) + + # Save to Redis + db_manager.save_to_db(list_manager) + print("✅ List manager saved to Redis") + + # Load from Redis + loaded_manager = db_manager.load_from_db() + if loaded_manager: + print(f"Loaded list manager: {loaded_manager}") + print(f"Items match: {list_manager.items == loaded_manager.items}") + else: + print("❌ Failed to load list manager from Redis") + + # Clean up + redis_client.delete("lockable_orm:test_user:test_list_cube:data") + redis_client.delete("lockable_orm:test_user:test_list_cube:lock") + redis_client.delete("lockable_orm:test_user:test_list_cube:version") + redis_client.close() + + except Exception as e: + print(f"❌ Error in RedisDBManager test: {e}") + + +def modify_list_process(process_id: int, items_to_add: list[str]): + """Function to be run in separate processes to modify the list using merge_items""" + try: + from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager + + # Create Redis connection + redis_client = BaseDBManager.load_redis_engine_from_env() + if redis_client is None: + print(f"Process {process_id}: Failed to create Redis connection") + return + + # Create a temporary list manager for this process with items to add + temp_manager = 
SimpleListManager() + + db_manager = RedisDBManager( + redis_client=redis_client, + user_id="test_user", + mem_cube_id="multiprocess_list", + obj=temp_manager, + ) + + print(f"Process {process_id}: Starting modification with items: {items_to_add}") + for item in items_to_add: + db_manager.obj.add_item(item) + # Use sync_with_orm which internally uses merge_items + db_manager.sync_with_orm(size_limit=None) + + print(f"Process {process_id}: Successfully synchronized with Redis") + + redis_client.close() + + except Exception as e: + print(f"Process {process_id}: Error - {e}") + import traceback + + traceback.print_exc() + + +def test_multiprocess_synchronization(): + """Test multiprocess synchronization with RedisDBManager""" + print("\n" + "=" * 60) + print("Testing Multiprocess Synchronization") + print("=" * 60) + + try: + # Initialize Redis with empty list + redis_client = BaseDBManager.load_redis_engine_from_env() + if redis_client is None: + print("❌ Failed to create Redis connection") + return + + # Initialize with empty list + initial_manager = SimpleListManager([]) + db_manager = RedisDBManager( + redis_client=redis_client, + user_id="test_user", + mem_cube_id="multiprocess_list", + obj=initial_manager, + ) + db_manager.save_to_db(initial_manager) + print("✅ Initialized empty list manager in Redis") + + # Define items for each process to add + process_items = [ + ["item1", "item2"], + ["item3", "item4"], + ["item5", "item6"], + ["item1", "item7"], # item1 is duplicate, should not be added twice + ] + + # Create and start processes + processes = [] + for i, items in enumerate(process_items): + p = multiprocessing.Process(target=modify_list_process, args=(i + 1, items)) + processes.append(p) + p.start() + + # Wait for all processes to complete + for p in processes: + p.join() + + print("\n" + "-" * 40) + print("All processes completed. 
Checking final result...") + + # Load final result + final_db_manager = RedisDBManager( + redis_client=redis_client, + user_id="test_user", + mem_cube_id="multiprocess_list", + obj=SimpleListManager([]), + ) + final_manager = final_db_manager.load_from_db() + + if final_manager: + print(f"Final synchronized list manager: {final_manager}") + print(f"Final list length: {len(final_manager)}") + print("Expected items: {'item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7'}") + print(f"Actual items: {set(final_manager.items)}") + + # Check if all unique items are present + expected_items = {"item1", "item2", "item3", "item4", "item5", "item6", "item7"} + actual_items = set(final_manager.items) + + if expected_items == actual_items: + print("✅ All processes contributed correctly - synchronization successful!") + else: + print(f"❌ Expected items: {expected_items}") + print(f" Actual items: {actual_items}") + else: + print("❌ Failed to load final result") + + # Clean up + redis_client.delete("lockable_orm:test_user:multiprocess_list:data") + redis_client.delete("lockable_orm:test_user:multiprocess_list:lock") + redis_client.delete("lockable_orm:test_user:multiprocess_list:version") + redis_client.close() + + except Exception as e: + print(f"❌ Error in multiprocess synchronization test: {e}") + + +def main(): + """Main function to run all tests""" + print("ORM Examples - Environment Variable Loading Tests") + print("=" * 80) + + # Test environment variables display + test_environment_variables() + + # Test manual environment loading + test_manual_env_loading() + + # Test MySQL engine loading + test_mysql_engine_from_env() + + # Test Redis connection loading + test_redis_connection_from_env() + + # Test RedisLockableORM with list[str] + test_redis_lockable_orm_with_list() + + # Test multiprocess synchronization + test_multiprocess_synchronization() + + print("\n" + "=" * 80) + print("All tests completed!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 86751b008..d14c05993 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field # Import message types from core types module +from memos.mem_scheduler.schemas.general_schemas import SearchMode from memos.types import MessageDict, PermissionDict @@ -170,7 +171,7 @@ class APISearchRequest(BaseRequest): query: str = Field(..., description="Search query") user_id: str = Field(None, description="User ID") mem_cube_id: str | None = Field(None, description="Cube ID to search in") - mode: str = Field("fast", description="search mode fast or fine") + mode: SearchMode = Field(SearchMode.FINE, description="search mode: fast, fine, or mixture") internet_search: bool = Field(False, description="Whether to use internet search") moscube: bool = Field(False, description="Whether to use MemOSCube") top_k: int = Field(10, description="Number of results to return") diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index a332de583..9f982ddd3 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,3 +1,4 @@ +import json import os import traceback @@ -18,6 +19,7 @@ from memos.configs.internet_retriever import InternetRetrieverConfigFactory from memos.configs.llm import LLMConfigFactory from memos.configs.mem_reader import MemReaderConfigFactory +from memos.configs.mem_scheduler import 
SchedulerConfigFactory from memos.configs.reranker import RerankerConfigFactory from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory @@ -26,6 +28,14 @@ from memos.mem_cube.navie import NaiveMemCube from memos.mem_os.product_server import MOSServer from memos.mem_reader.factory import MemReaderFactory +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager +from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.mem_scheduler.schemas.general_schemas import ( + API_MIX_SEARCH_LABEL, + SearchMode, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, @@ -134,6 +144,34 @@ def init_server(): llm=llm, online_bot=False, ) + + # Initialize Scheduler + scheduler_config_dict = APIConfig.get_scheduler_config() + scheduler_config = SchedulerConfigFactory( + backend="optimized_scheduler", config=scheduler_config_dict + ) + mem_scheduler = SchedulerFactory.from_config(scheduler_config) + mem_scheduler.initialize_modules( + chat_llm=llm, + process_llm=mem_reader.llm, + db_engine=BaseDBManager.create_default_sqlite_engine(), + ) + mem_scheduler.start() + + # Initialize SchedulerAPIModule + api_module = mem_scheduler.api_module + + naive_mem_cube = NaiveMemCube( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=reranker, + internet_retriever=internet_retriever, + memory_manager=memory_manager, + default_cube_config=default_cube_config, + ) + return ( graph_db, mem_reader, @@ -144,6 +182,9 @@ def init_server(): memory_manager, default_cube_config, mos_server, + mem_scheduler, + naive_mem_cube, + api_module, ) @@ -158,24 +199,12 @@ def init_server(): memory_manager, default_cube_config, mos_server, + mem_scheduler, + naive_mem_cube, + api_module, ) = init_server() -def _create_naive_mem_cube() -> NaiveMemCube: - """Create a NaiveMemCube instance with initialized components.""" - naive_mem_cube = NaiveMemCube( - llm=llm, - embedder=embedder, - mem_reader=mem_reader, - graph_db=graph_db, - reranker=reranker, - internet_retriever=internet_retriever, - memory_manager=memory_manager, - default_cube_config=default_cube_config, - ) - return naive_mem_cube - - def _format_memory_item(memory_data: Any) -> dict[str, Any]: """Format a single memory item for API response.""" memory = memory_data.model_dump() @@ -207,18 +236,146 @@ def search_memories(search_req: APISearchRequest): "act_mem": [], "para_mem": [], } + + search_mode = search_req.mode + + if search_mode == SearchMode.FAST: + formatted_memories = fast_search_memories(search_req=search_req, user_context=user_context) + elif search_mode == SearchMode.FINE: + formatted_memories = fine_search_memories(search_req=search_req, user_context=user_context) + elif search_mode == SearchMode.MIXTURE: + formatted_memories = mix_search_memories(search_req=search_req, user_context=user_context) + else: + logger.error(f"Unsupported search mode: {search_mode}") + raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}") + + memories_result["text_mem"].append( + { + "cube_id": search_req.mem_cube_id, + "memories": formatted_memories, + } + ) + + return SearchResponse( + message="Search completed successfully", + data=memories_result, + ) + + 
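+# Mixture mode (below) combines both search paths: fast results are gathered
+# synchronously, a fine-grained search is submitted to the scheduler's dispatcher
+# as an async task, and any pre-computed fine memories are merged in with
+# duplicates removed by content. If the scheduler is unavailable, the function
+# falls back to a synchronous fine search.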
+def mix_search_memories( + search_req: APISearchRequest, + user_context: UserContext, +): + """ + Mix search memories: fast search + async fine search + """ + # Get fast memories first + fast_memories = fast_search_memories(search_req, user_context) + + # Check if scheduler and dispatcher are available for async execution + if mem_scheduler and hasattr(mem_scheduler, "dispatcher") and mem_scheduler.dispatcher: + try: + # Create message for async fine search + message_content = { + "search_req": { + "query": search_req.query, + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "top_k": search_req.top_k, + "internet_search": search_req.internet_search, + "moscube": search_req.moscube, + "chat_history": search_req.chat_history, + }, + "user_context": {"mem_cube_id": user_context.mem_cube_id}, + } + + message = ScheduleMessageItem( + item_id=f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}", + user_id=search_req.user_id, + mem_cube_id=user_context.mem_cube_id, + label=API_MIX_SEARCH_LABEL, + mem_cube=naive_mem_cube, + content=json.dumps(message_content), + timestamp=get_utc_now(), + ) + + # Submit async task + mem_scheduler.dispatcher.submit_message(message) + logger.info(f"Submitted async fine search task for user {search_req.user_id}") + + # Try to get pre-computed fine memories if available + try: + pre_fine_memories = api_module.get_pre_fine_memories( + user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id + ) + if pre_fine_memories: + # Merge fast and pre-computed fine memories + all_memories = fast_memories + pre_fine_memories + # Remove duplicates based on content + seen_contents = set() + unique_memories = [] + for memory in all_memories: + content_key = memory.get("content", "") + if content_key not in seen_contents: + seen_contents.add(content_key) + unique_memories.append(memory) + return unique_memories + except Exception as e: + logger.warning(f"Failed to get pre-computed fine memories: {e}") + + except Exception as e: + logger.error(f"Failed to submit async fine search task: {e}") + # Fall back to synchronous execution + + # Fallback: synchronous fine search + try: + fine_memories = fine_search_memories(search_req, user_context) + + # Merge fast and fine memories + all_memories = fast_memories + fine_memories + + # Remove duplicates based on content + seen_contents = set() + unique_memories = [] + for memory in all_memories: + content_key = memory.get("content", "") + if content_key not in seen_contents: + seen_contents.add(content_key) + unique_memories.append(memory) + + # Sync search data to Redis + try: + api_module.sync_search_data( + user_id=search_req.user_id, + mem_cube_id=user_context.mem_cube_id, + query=search_req.query, + formatted_memories=unique_memories, + ) + except Exception as e: + logger.error(f"Failed to sync search data: {e}") + + return unique_memories + + except Exception as e: + logger.error(f"Fine search failed: {e}") + return fast_memories + + +def fine_search_memories( + search_req: APISearchRequest, + user_context: UserContext, +): target_session_id = search_req.session_id if not target_session_id: target_session_id = "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None # Create MemCube and perform search - naive_mem_cube = _create_naive_mem_cube() search_results = naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, - mode=search_req.mode, + mode=SearchMode.FINE, manual_close_internet=not 
search_req.internet_search, moscube=search_req.moscube, search_filter=search_filter, @@ -230,17 +387,36 @@ def search_memories(search_req: APISearchRequest): ) formatted_memories = [_format_memory_item(data) for data in search_results] - memories_result["text_mem"].append( - { - "cube_id": search_req.mem_cube_id, - "memories": formatted_memories, - } - ) + return formatted_memories - return SearchResponse( - message="Search completed successfully", - data=memories_result, + +def fast_search_memories( + search_req: APISearchRequest, + user_context: UserContext, +): + target_session_id = search_req.session_id + if not target_session_id: + target_session_id = "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + # Create MemCube and perform search + search_results = naive_mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, ) + formatted_memories = [_format_memory_item(data) for data in search_results] + + return formatted_memories @router.post("/add", summary="Add memories", response_model=MemoryResponse) @@ -252,7 +428,6 @@ def add_memories(add_req: APIADDRequest): mem_cube_id=add_req.mem_cube_id, session_id=add_req.session_id or "default_session", ) - naive_mem_cube = _create_naive_mem_cube() target_session_id = add_req.session_id if not target_session_id: target_session_id = "default_session" @@ -296,7 +471,6 @@ def chat_complete(chat_req: APIChatCompleteRequest): """Chat with MemOS for a specific user. 
Returns complete response (non-streaming).""" try: # Collect all responses from the generator - naive_mem_cube = _create_naive_mem_cube() content, references = mos_server.chat( query=chat_req.query, user_id=chat_req.user_id, diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index 2d6155ec2..bc22cfb63 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -11,8 +11,14 @@ from memos.mem_scheduler.schemas.general_schemas import ( BASE_DIR, DEFAULT_ACT_MEM_DUMP_PATH, + DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT, DEFAULT_CONSUME_INTERVAL_SECONDS, + DEFAULT_CONTEXT_WINDOW_SIZE, + DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, DEFAULT_THREAD_POOL_MAX_WORKERS, + DEFAULT_TOP_K, + DEFAULT_USE_REDIS_QUEUE, + DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, ) @@ -20,7 +26,8 @@ class BaseSchedulerConfig(BaseConfig): """Base configuration class for mem_scheduler.""" top_k: int = Field( - default=10, description="Number of top candidates to consider in initial retrieval" + default=DEFAULT_TOP_K, + description="Number of top candidates to consider in initial retrieval", ) enable_parallel_dispatch: bool = Field( default=True, description="Whether to enable parallel message processing using thread pool" @@ -39,6 +46,19 @@ class BaseSchedulerConfig(BaseConfig): default=None, description="Path to the authentication configuration file containing private credentials", ) + # Redis queue configuration + use_redis_queue: bool = Field( + default=DEFAULT_USE_REDIS_QUEUE, + description="Whether to use Redis queue instead of local memory queue", + ) + redis_config: dict[str, Any] = Field( + default_factory=lambda: {"host": "localhost", "port": 6379, "db": 0}, + description="Redis connection configuration", + ) + max_internal_message_queue_size: int = Field( + default=DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, + description="Maximum size of internal message queue when not using Redis", + ) class GeneralSchedulerConfig(BaseSchedulerConfig): @@ -47,7 +67,8 @@ class GeneralSchedulerConfig(BaseSchedulerConfig): default=300, description="Interval in seconds for updating activation memory" ) context_window_size: int | None = Field( - default=10, description="Size of the context window for conversation history" + default=DEFAULT_CONTEXT_WINDOW_SIZE, + description="Size of the context window for conversation history", ) act_mem_dump_path: str | None = Field( default=DEFAULT_ACT_MEM_DUMP_PATH, # Replace with DEFAULT_ACT_MEM_DUMP_PATH @@ -57,10 +78,12 @@ class GeneralSchedulerConfig(BaseSchedulerConfig): default=False, description="Whether to enable automatic activation memory updates" ) working_mem_monitor_capacity: int = Field( - default=30, description="Capacity of the working memory monitor" + default=DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, + description="Capacity of the working memory monitor", ) activation_mem_monitor_capacity: int = Field( - default=20, description="Capacity of the activation memory monitor" + default=DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT, + description="Capacity of the activation memory monitor", ) # Database configuration for ORM persistence @@ -77,6 +100,14 @@ class GeneralSchedulerConfig(BaseSchedulerConfig): ) +class OptimizedSchedulerConfig(GeneralSchedulerConfig): + """Configuration for the optimized scheduler. + + This class inherits all fields from `GeneralSchedulerConfig` + and is used to distinguish optimized scheduling logic via type. 
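+    The "optimized_scheduler" backend in SchedulerConfigFactory resolves to this class.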
+ """ + + class SchedulerConfigFactory(BaseConfig): """Factory class for creating scheduler configurations.""" @@ -86,7 +117,7 @@ class SchedulerConfigFactory(BaseConfig): model_config = ConfigDict(extra="forbid", strict=True) backend_to_class: ClassVar[dict[str, Any]] = { "general_scheduler": GeneralSchedulerConfig, - "optimized_scheduler": GeneralSchedulerConfig, # optimized_scheduler uses same config as general_scheduler + "optimized_scheduler": OptimizedSchedulerConfig, # optimized_scheduler uses same config as general_scheduler } @field_validator("backend") diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py index 00081b581..be0d1d95f 100644 --- a/src/memos/llms/hf.py +++ b/src/memos/llms/hf.py @@ -379,10 +379,52 @@ def build_kv_cache(self, messages) -> DynamicCache: raise ValueError( "Prompt after chat template is empty, cannot build KV cache. Check your messages input." ) - kv = DynamicCache() + # Create cache and perform forward pass without pre-existing cache with torch.no_grad(): - self.model(**inputs, use_cache=True, past_key_values=kv) - for i, (k, v) in enumerate(zip(kv.key_cache, kv.value_cache, strict=False)): - kv.key_cache[i] = k[:, :, :seq_len, :] - kv.value_cache[i] = v[:, :, :seq_len, :] - return kv + outputs = self.model(**inputs, use_cache=True) + + # Get the cache from model outputs + if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None: + kv = outputs.past_key_values + + # Convert from legacy tuple format to DynamicCache if needed + if isinstance(kv, tuple): + kv = DynamicCache.from_legacy_cache(kv) + + # Handle compatibility between old and new transformers versions + # In newer versions, DynamicCache uses 'layers' attribute + # In older versions, it uses 'key_cache' and 'value_cache' attributes + if hasattr(kv, "layers"): + # New version: trim cache using layers attribute + for layer in kv.layers: + if hasattr(layer, "key_cache") and hasattr(layer, "value_cache"): + # Trim each layer's cache to the sequence length + if layer.key_cache is not None: + layer.key_cache = layer.key_cache[:, :, :seq_len, :] + if layer.value_cache is not None: + layer.value_cache = layer.value_cache[:, :, :seq_len, :] + elif hasattr(layer, "keys") and hasattr(layer, "values"): + # Alternative attribute names in some versions + if layer.keys is not None: + layer.keys = layer.keys[:, :, :seq_len, :] + if layer.values is not None: + layer.values = layer.values[:, :, :seq_len, :] + elif hasattr(kv, "key_cache") and hasattr(kv, "value_cache"): + # Old version: trim cache using key_cache and value_cache attributes + for i in range(len(kv.key_cache)): + if kv.key_cache[i] is not None: + kv.key_cache[i] = kv.key_cache[i][:, :, :seq_len, :] + if kv.value_cache[i] is not None: + kv.value_cache[i] = kv.value_cache[i][:, :, :seq_len, :] + else: + # Fallback: log warning but continue without trimming + logger.warning( + f"DynamicCache object of type {type(kv)} has unexpected structure. " + f"Cache trimming skipped. 
Available attributes: {dir(kv)}" + ) + + return kv + else: + raise RuntimeError( + "Failed to build KV cache: no cache data available from model outputs" + ) diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 0010897c0..cedffd6fb 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -310,18 +310,20 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = past_key_values = None if self.config.enable_activation_memory: - assert self.config.chat_model.backend == "huggingface", ( - "Activation memory only used for huggingface backend." - ) - # TODO this only one cubes - for mem_cube_id, mem_cube in self.mem_cubes.items(): - if mem_cube_id not in user_cube_ids: - continue - if mem_cube.act_mem: - kv_cache = next(iter(mem_cube.act_mem.get_all()), None) - past_key_values = ( - kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None - ) + if self.config.chat_model.backend != "huggingface": + logger.error( + "Activation memory only used for huggingface backend. Skipping activation memory." + ) + else: + # TODO this only one cubes + for mem_cube_id, mem_cube in self.mem_cubes.items(): + if mem_cube_id not in user_cube_ids: + continue + if mem_cube.act_mem: + kv_cache = next(iter(mem_cube.act_mem.get_all()), None) + past_key_values = ( + kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None + ) break # Generate response response = self.chat_llm.generate(current_messages, past_key_values=past_key_values) diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py index 2e5b32548..6fc64c5e3 100644 --- a/src/memos/mem_os/main.py +++ b/src/memos/mem_os/main.py @@ -312,23 +312,25 @@ def _generate_enhanced_response_with_context( # Handle activation memory if enabled (same as core method) past_key_values = None if self.config.enable_activation_memory: - assert self.config.chat_model.backend == "huggingface", ( - "Activation memory only used for huggingface backend." - ) - # Get accessible cubes for the user - target_user_id = user_id if user_id is not None else self.user_id - accessible_cubes = self.user_manager.get_user_cubes(target_user_id) - user_cube_ids = [cube.cube_id for cube in accessible_cubes] - - for mem_cube_id, mem_cube in self.mem_cubes.items(): - if mem_cube_id not in user_cube_ids: - continue - if mem_cube.act_mem: - kv_cache = next(iter(mem_cube.act_mem.get_all()), None) - past_key_values = ( - kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None - ) - break + if self.config.chat_model.backend != "huggingface": + logger.error( + "Activation memory only used for huggingface backend. Skipping activation memory." 
+ ) + else: + # Get accessible cubes for the user + target_user_id = user_id if user_id is not None else self.user_id + accessible_cubes = self.user_manager.get_user_cubes(target_user_id) + user_cube_ids = [cube.cube_id for cube in accessible_cubes] + + for mem_cube_id, mem_cube in self.mem_cubes.items(): + if mem_cube_id not in user_cube_ids: + continue + if mem_cube.act_mem: + kv_cache = next(iter(mem_cube.act_mem.get_all()), None) + past_key_values = ( + kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None + ) + break try: # Generate the enhanced response using the chat LLM with same parameters as core diff --git a/src/memos/mem_scheduler/analyzer/api_analyzer.py b/src/memos/mem_scheduler/analyzer/api_analyzer.py index e69de29bb..45a39e0de 100644 --- a/src/memos/mem_scheduler/analyzer/api_analyzer.py +++ b/src/memos/mem_scheduler/analyzer/api_analyzer.py @@ -0,0 +1,619 @@ +""" +API Analyzer for Scheduler + +This module provides the APIAnalyzerForScheduler class that handles API requests +for search and add operations with reusable instance variables. +""" + +import http.client +import json + +from typing import Any +from urllib.parse import urlparse + +import requests + +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class APIAnalyzerForScheduler: + """ + API Analyzer class for scheduler operations. + + This class provides methods to interact with APIs for search and add operations, + with reusable instance variables for better performance and configuration management. + """ + + def __init__( + self, + base_url: str = "http://127.0.0.1:8002", + default_headers: dict[str, str] | None = None, + timeout: int = 30, + ): + """ + Initialize the APIAnalyzerForScheduler. + + Args: + base_url: Base URL for API requests + default_headers: Default headers to use for all requests + timeout: Request timeout in seconds + """ + self.base_url = base_url.rstrip("/") + self.timeout = timeout + + # Default headers + self.default_headers = default_headers or {"Content-Type": "application/json"} + + # Parse URL for http.client usage + parsed_url = urlparse(self.base_url) + self.host = parsed_url.hostname + self.port = parsed_url.port or 8002 + self.is_https = parsed_url.scheme == "https" + + # Reusable connection for http.client + self._connection = None + + # Attributes + self.user_id = "test_user_id" + self.mem_cube_id = "test_mem_cube_id" + + logger.info(f"APIAnalyzerForScheduler initialized with base_url: {self.base_url}") + + def _get_connection(self) -> http.client.HTTPConnection | http.client.HTTPSConnection: + """ + Get or create a reusable HTTP connection. + + Returns: + HTTP connection object + """ + if self._connection is None: + if self.is_https: + self._connection = http.client.HTTPSConnection(self.host, self.port) + else: + self._connection = http.client.HTTPConnection(self.host, self.port) + return self._connection + + def _close_connection(self): + """Close the HTTP connection if it exists.""" + if self._connection: + self._connection.close() + self._connection = None + + def search( + self, user_id: str, mem_cube_id: str, query: str, top: int = 50, use_requests: bool = True + ) -> dict[str, Any]: + """ + Search for memories using the product/search API endpoint. 
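+        Uses the requests library by default, or the reusable http.client connection when use_requests is False.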
+ + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + query: Search query string + top: Number of top results to return + use_requests: Whether to use requests library (True) or http.client (False) + + Returns: + Dictionary containing the API response + """ + payload = {"user_id": user_id, "mem_cube_id": mem_cube_id, "query": query, "top": top} + + try: + if use_requests: + return self._search_with_requests(payload) + else: + return self._search_with_http_client(payload) + except Exception as e: + logger.error(f"Error in search operation: {e}") + return {"error": str(e), "success": False} + + def _search_with_requests(self, payload: dict[str, Any]) -> dict[str, Any]: + """ + Perform search using requests library. + + Args: + payload: Request payload + + Returns: + Dictionary containing the API response + """ + url = f"{self.base_url}/product/search" + + response = requests.post( + url, headers=self.default_headers, data=json.dumps(payload), timeout=self.timeout + ) + + logger.info(f"Search request to {url} completed with status: {response.status_code}") + + try: + return { + "success": True, + "status_code": response.status_code, + "data": response.json() if response.content else {}, + "text": response.text, + } + except json.JSONDecodeError: + return { + "success": True, + "status_code": response.status_code, + "data": {}, + "text": response.text, + } + + def _search_with_http_client(self, payload: dict[str, Any]) -> dict[str, Any]: + """ + Perform search using http.client. + + Args: + payload: Request payload + + Returns: + Dictionary containing the API response + """ + conn = self._get_connection() + + try: + conn.request("POST", "/product/search", json.dumps(payload), self.default_headers) + + response = conn.getresponse() + data = response.read() + response_text = data.decode("utf-8") + + logger.info(f"Search request completed with status: {response.status}") + + try: + response_data = json.loads(response_text) if response_text else {} + except json.JSONDecodeError: + response_data = {} + + return { + "success": True, + "status_code": response.status, + "data": response_data, + "text": response_text, + } + except Exception as e: + logger.error(f"Error in http.client search: {e}") + return {"error": str(e), "success": False} + + def add( + self, messages: list, user_id: str, mem_cube_id: str, use_requests: bool = True + ) -> dict[str, Any]: + """ + Add memories using the product/add API endpoint. + + Args: + messages: List of message objects with role and content + user_id: User identifier + mem_cube_id: Memory cube identifier + use_requests: Whether to use requests library (True) or http.client (False) + + Returns: + Dictionary containing the API response + """ + payload = {"messages": messages, "user_id": user_id, "mem_cube_id": mem_cube_id} + + try: + if use_requests: + return self._add_with_requests(payload) + else: + return self._add_with_http_client(payload) + except Exception as e: + logger.error(f"Error in add operation: {e}") + return {"error": str(e), "success": False} + + def _add_with_requests(self, payload: dict[str, Any]) -> dict[str, Any]: + """ + Perform add using requests library. 
+ + Args: + payload: Request payload + + Returns: + Dictionary containing the API response + """ + url = f"{self.base_url}/product/add" + + response = requests.post( + url, headers=self.default_headers, data=json.dumps(payload), timeout=self.timeout + ) + + logger.info(f"Add request to {url} completed with status: {response.status_code}") + + try: + return { + "success": True, + "status_code": response.status_code, + "data": response.json() if response.content else {}, + "text": response.text, + } + except json.JSONDecodeError: + return { + "success": True, + "status_code": response.status_code, + "data": {}, + "text": response.text, + } + + def _add_with_http_client(self, payload: dict[str, Any]) -> dict[str, Any]: + """ + Perform add using http.client. + + Args: + payload: Request payload + + Returns: + Dictionary containing the API response + """ + conn = self._get_connection() + + try: + conn.request("POST", "/product/add", json.dumps(payload), self.default_headers) + + response = conn.getresponse() + data = response.read() + response_text = data.decode("utf-8") + + logger.info(f"Add request completed with status: {response.status}") + + try: + response_data = json.loads(response_text) if response_text else {} + except json.JSONDecodeError: + response_data = {} + + return { + "success": True, + "status_code": response.status, + "data": response_data, + "text": response_text, + } + except Exception as e: + logger.error(f"Error in http.client add: {e}") + return {"error": str(e), "success": False} + + def update_base_url(self, new_base_url: str): + """ + Update the base URL and reinitialize connection parameters. + + Args: + new_base_url: New base URL for API requests + """ + self._close_connection() + self.base_url = new_base_url.rstrip("/") + + # Re-parse URL + parsed_url = urlparse(self.base_url) + self.host = parsed_url.hostname + self.port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80) + self.is_https = parsed_url.scheme == "https" + + logger.info(f"Base URL updated to: {self.base_url}") + + def update_headers(self, headers: dict[str, str]): + """ + Update default headers. 
+ + Args: + headers: New headers to merge with existing ones + """ + self.default_headers.update(headers) + logger.info("Headers updated") + + def __del__(self): + """Cleanup method to close connection when object is destroyed.""" + self._close_connection() + + def analyze_service(self): + # Example add operation + messages = [ + {"role": "user", "content": "Where should I go for New Year's Eve in Shanghai?"}, + { + "role": "assistant", + "content": "You could head to the Bund for the countdown, attend a rooftop party, or enjoy the fireworks at Disneyland Shanghai.", + }, + ] + + add_result = self.add( + messages=messages, user_id="test_user_id", mem_cube_id="test_mem_cube_id" + ) + print("Add result:", add_result) + + # Example search operation + search_result = self.search( + user_id="test_user_id", + mem_cube_id="test_mem_cube_id", + query="What are some good places to celebrate New Year's Eve in Shanghai?", + top=50, + ) + print("Search result:", search_result) + + def analyze_features(self): + try: + # Test basic search functionality + search_result = self.search( + user_id="test_user_id", + mem_cube_id="test_mem_cube_id", + query="What are some good places to celebrate New Year's Eve in Shanghai?", + top=50, + ) + print("Search result:", search_result) + except Exception as e: + logger.error(f"Feature analysis failed: {e}") + + +class DirectSearchMemoriesAnalyzer: + """ + Direct analyzer for testing search_memories function + Used for debugging and analyzing search_memories function behavior without starting a full API server + """ + + def __init__(self): + """Initialize the analyzer""" + # Import necessary modules + try: + from memos.api.product_models import APIADDRequest, APISearchRequest + from memos.api.routers.server_router import add_memories, search_memories + from memos.types import MessageDict, UserContext + + self.APISearchRequest = APISearchRequest + self.APIADDRequest = APIADDRequest + self.search_memories = search_memories + self.add_memories = add_memories + self.UserContext = UserContext + self.MessageDict = MessageDict + + logger.info("DirectSearchMemoriesAnalyzer initialized successfully") + except ImportError as e: + logger.error(f"Failed to import modules: {e}") + raise + + def create_test_search_request( + self, + query="test query", + user_id="test_user", + mem_cube_id="test_cube", + mode="fast", + top_k=10, + chat_history=None, + session_id=None, + ): + """ + Create a test APISearchRequest object with the given parameters. + + Args: + query: Search query string + user_id: User ID for the request + mem_cube_id: Memory cube ID for the request + mode: Search mode ("fast" or "fine") + top_k: Number of results to return + chat_history: Chat history for context (optional) + session_id: Session ID for the request (optional) + + Returns: + APISearchRequest: A configured request object + """ + return self.APISearchRequest( + query=query, + user_id=user_id, + mem_cube_id=mem_cube_id, + mode=mode, + top_k=top_k, + chat_history=chat_history, + session_id=session_id, + ) + + def create_test_add_request( + self, + user_id="test_user", + mem_cube_id="test_cube", + messages=None, + memory_content=None, + session_id=None, + ): + """ + Create a test APIADDRequest object with the given parameters. 
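+        If neither messages nor memory_content is provided, a default two-turn test conversation is used.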
+ + Args: + user_id: User ID for the request + mem_cube_id: Memory cube ID for the request + messages: List of messages to add (optional) + memory_content: Direct memory content to add (optional) + session_id: Session ID for the request (optional) + + Returns: + APIADDRequest: A configured request object + """ + if messages is None and memory_content is None: + # Default test messages + messages = [ + {"role": "user", "content": "What's the weather like today?"}, + { + "role": "assistant", + "content": "I don't have access to real-time weather data, but you can check a weather app or website for current conditions.", + }, + ] + + # Ensure we have a valid session_id + if session_id is None: + session_id = "test_session_" + str(hash(user_id + mem_cube_id))[:8] + + return self.APIADDRequest( + user_id=user_id, + mem_cube_id=mem_cube_id, + messages=messages, + memory_content=memory_content, + session_id=session_id, + doc_path=None, + source="api_analyzer_test", + chat_history=None, + operation=None, + ) + + def test_add_memories_basic(self, user_id="test_user_add", mem_cube_id="test_cube_add"): + """Basic add_memories test""" + print("=" * 60) + print("Starting basic add_memories test") + print("=" * 60) + + try: + # Create test request with default messages + add_req = self.create_test_add_request(user_id=user_id, mem_cube_id=mem_cube_id) + + print("Test request created:") + print(f" User ID: {add_req.user_id}") + print(f" Mem Cube ID: {add_req.mem_cube_id}") + print(f" Messages: {add_req.messages}") + print(f" Session ID: {add_req.session_id}") + + # Call add_memories function + print("\nCalling add_memories function...") + result = self.add_memories(add_req) + + print(f"Add result: {result}") + print("Basic add_memories test completed successfully") + return result + + except Exception as e: + print(f"Basic add_memories test failed: {e}") + import traceback + + traceback.print_exc() + return None + + def test_search_memories_basic(self, query: str, mode: str, topk: int): + """Basic search_memories test""" + print("=" * 60) + print("Starting basic search_memories test") + print("=" * 60) + + try: + # Create test request + search_req = self.create_test_search_request( + query=query, + user_id="test_user_id", + mem_cube_id="test_mem_cube_id", + mode=mode, + top_k=topk, + ) + + print("Test request parameters:") + print(f" - query: {search_req.query}") + print(f" - user_id: {search_req.user_id}") + print(f" - mem_cube_id: {search_req.mem_cube_id}") + print(f" - mode: {search_req.mode}") + print(f" - top_k: {search_req.top_k}") + print(f" - internet_search: {search_req.internet_search}") + print(f" - moscube: {search_req.moscube}") + print() + + # Call search_memories function + print("Calling search_memories function...") + result = self.search_memories(search_req) + + print("✅ Function call successful!") + print(f"Return result type: {type(result)}") + print(f"Return result: {result}") + + # Analyze return result + if hasattr(result, "message"): + print(f"Message: {result.message}") + if hasattr(result, "data"): + print(f"Data type: {type(result.data)}") + if result.data and isinstance(result.data, dict): + for key, value in result.data.items(): + print(f" {key}: {len(value) if isinstance(value, list) else value}") + + return result + + except Exception as e: + print(f"❌ Test failed: {e}") + import traceback + + print("Detailed error information:") + traceback.print_exc() + return None + + def run_all_tests(self): + """Run all available tests""" + print("🚀 Starting comprehensive test suite") + 
print("=" * 80) + + # Test add_memories functions (more likely to have dependency issues) + print("\n\n📝 Testing ADD_MEMORIES functions:") + try: + print("\n" + "-" * 40) + self.test_add_memories_basic() + print("✅ Basic add memories test completed") + except Exception as e: + print(f"❌ Basic add memories test failed: {e}") + + # Test search_memories functions first (less likely to fail) + print("\n🔍 Testing SEARCH_MEMORIES functions:") + try: + self.test_search_memories_basic( + query="What are some good places to celebrate New Year's Eve in Shanghai?", + mode="fast", + topk=3, + ) + print("✅ Search memories test completed successfully") + except Exception as e: + print(f"❌ Search memories test failed: {e}") + + print("\n" + "=" * 80) + print("✅ All tests completed!") + + +# Example usage +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="API Analyzer for Memory Scheduler") + parser.add_argument( + "--mode", + choices=["direct", "api"], + default="direct", + help="Test mode: 'direct' for direct function testing, 'api' for API testing (default: direct)", + ) + + args = parser.parse_args() + + if args.mode == "direct": + # Direct test mode for search_memories and add_memories functions + print("Using direct test mode") + try: + direct_analyzer = DirectSearchMemoriesAnalyzer() + direct_analyzer.run_all_tests() + except Exception as e: + print(f"Direct test mode failed: {e}") + import traceback + + traceback.print_exc() + else: + # Original API test mode + print("Using API test mode") + analyzer = APIAnalyzerForScheduler() + + # Test add operation + messages = [ + {"role": "user", "content": "Where should I go for New Year's Eve in Shanghai?"}, + { + "role": "assistant", + "content": "You could head to the Bund for the countdown, attend a rooftop party, or enjoy the fireworks at Disneyland Shanghai.", + }, + ] + + add_result = analyzer.add( + messages=messages, user_id="test_user_id", mem_cube_id="test_mem_cube_id" + ) + print("Add result:", add_result) + + # Test search operation + search_result = analyzer.search( + user_id="test_user_id", + mem_cube_id="test_mem_cube_id", + query="What are some good places to celebrate New Year's Eve in Shanghai?", + top=50, + ) + print("Search result:", search_result) diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py index 7cd085ada..ace67eff6 100644 --- a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py @@ -485,18 +485,20 @@ def chat(self, query: str, user_id: str | None = None) -> str: past_key_values = None if self.config.enable_activation_memory: - assert self.config.chat_model.backend == "huggingface", ( - "Activation memory only used for huggingface backend." - ) - # TODO this only one cubes - for mem_cube_id, mem_cube in self.mem_cubes.items(): - if mem_cube_id not in user_cube_ids: - continue - if mem_cube.act_mem: - kv_cache = next(iter(mem_cube.act_mem.get_all()), None) - past_key_values = ( - kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None - ) + if self.config.chat_model.backend != "huggingface": + logger.error( + "Activation memory only used for huggingface backend. Skipping activation memory." 
+ ) + else: + # TODO this only one cubes + for mem_cube_id, mem_cube in self.mem_cubes.items(): + if mem_cube_id not in user_cube_ids: + continue + if mem_cube.act_mem: + kv_cache = next(iter(mem_cube.act_mem.get_all()), None) + past_key_values = ( + kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None + ) break # Generate response response = self.chat_llm.generate(current_messages, past_key_values=past_key_values) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 1e8b042b1..e475ea225 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -22,8 +22,12 @@ from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_INTERVAL_SECONDS, + DEFAULT_CONTEXT_WINDOW_SIZE, + DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, DEFAULT_STARTUP_MODE, DEFAULT_THREAD_POOL_MAX_WORKERS, + DEFAULT_TOP_K, + DEFAULT_USE_REDIS_QUEUE, STARTUP_BY_PROCESS, MemCubeID, TreeTextMemory_SEARCH_METHOD, @@ -34,6 +38,7 @@ ScheduleMessageItem, ) from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem +from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.filter_utils import ( transform_name_to_key, ) @@ -57,11 +62,13 @@ def __init__(self, config: BaseSchedulerConfig): self.config = config # hyper-parameters - self.top_k = self.config.get("top_k", 10) - self.context_window_size = self.config.get("context_window_size", 5) + self.top_k = self.config.get("top_k", DEFAULT_TOP_K) + self.context_window_size = self.config.get( + "context_window_size", DEFAULT_CONTEXT_WINDOW_SIZE + ) self.enable_activation_memory = self.config.get("enable_activation_memory", False) self.act_mem_dump_path = self.config.get("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH) - self.search_method = TreeTextMemory_SEARCH_METHOD + self.search_method = self.config.get("search_method", TreeTextMemory_SEARCH_METHOD) self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", True) self.thread_pool_max_workers = self.config.get( "thread_pool_max_workers", DEFAULT_THREAD_POOL_MAX_WORKERS @@ -86,13 +93,22 @@ def __init__(self, config: BaseSchedulerConfig): # optional configs self.disable_handlers: list | None = self.config.get("disable_handlers", None) - # internal message queue + # message queue configuration + self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE) self.max_internal_message_queue_size = self.config.get( - "max_internal_message_queue_size", 10000 - ) - self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( - maxsize=self.max_internal_message_queue_size + "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE ) + + # Initialize message queue based on configuration + if self.use_redis_queue: + self.memos_message_queue = None # Will use Redis instead + # Initialize Redis if using Redis queue with auto-initialization + self.auto_initialize_redis() + else: + self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( + maxsize=self.max_internal_message_queue_size + ) + self.max_web_log_queue_size = self.config.get("max_web_log_queue_size", 50) self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue( maxsize=self.max_web_log_queue_size @@ -390,7 +406,7 @@ def update_activation_memory( cache_item = act_mem.extract(new_text_memory) cache_item.records.text_memories = new_text_memories - cache_item.records.timestamp = datetime.utcnow() + cache_item.records.timestamp = 
get_utc_now() act_mem.add([cache_item]) act_mem.dump(self.act_mem_dump_path) @@ -471,7 +487,7 @@ def update_activation_memory_periodically( mem_cube=mem_cube, ) - self.monitor.last_activation_mem_update_time = datetime.utcnow() + self.monitor.last_activation_mem_update_time = get_utc_now() logger.debug( f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}" @@ -480,14 +496,14 @@ def update_activation_memory_periodically( else: logger.info( f"Skipping update - {interval_seconds} second interval not yet reached. " - f"Last update time is {self.monitor.last_activation_mem_update_time} and now is" - f"{datetime.utcnow()}" + f"Last update time is {self.monitor.last_activation_mem_update_time} and now is " + f"{get_utc_now()}" ) except Exception as e: logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True) - def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): - """Submit multiple messages to the message queue.""" + async def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): + """Submit messages to the message queue (either local queue or Redis).""" if isinstance(messages, ScheduleMessageItem): messages = [messages] # transform single message to list @@ -497,13 +513,20 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt logger.error(error_msg) raise TypeError(error_msg) - # Check if this handler is disabled if self.disable_handlers and message.label in self.disable_handlers: logger.info(f"Skipping disabled handler: {message.label} - {message.content}") continue - self.memos_message_queue.put(message) - logger.info(f"Submitted message: {message.label} - {message.content}") + if self.use_redis_queue: + # Use Redis stream for message queue + await self.redis_add_message_stream(message.to_dict()) + logger.info(f"Submitted message to Redis: {message.label} - {message.content}") + else: + # Use local queue + self.memos_message_queue.put(message) + logger.info( + f"Submitted message to local queue: {message.label} - {message.content}" + ) def _submit_web_logs( self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem] @@ -556,36 +579,64 @@ def _message_consumer(self) -> None: Continuously checks the queue for messages and dispatches them. Runs in a dedicated thread to process messages at regular intervals. + For Redis queue, this method starts the Redis listener. 
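+        Messages read from the Redis stream are decoded and dispatched through the same dispatcher as local-queue messages.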
""" - while self._running: # Use a running flag for graceful shutdown - try: - # Get all available messages at once (thread-safe approach) - messages = [] - while True: - try: - # Use get_nowait() directly without empty() check to avoid race conditions - message = self.memos_message_queue.get_nowait() - messages.append(message) - except queue.Empty: - # No more messages available - break - - if messages: - try: - self.dispatcher.dispatch(messages) - except Exception as e: - logger.error(f"Error dispatching messages: {e!s}") - finally: - # Mark all messages as processed - for _ in messages: - self.memos_message_queue.task_done() - - # Sleep briefly to prevent busy waiting - time.sleep(self._consume_interval) # Adjust interval as needed - - except Exception as e: - logger.error(f"Unexpected error in message consumer: {e!s}") - time.sleep(self._consume_interval) # Prevent tight error loops + if self.use_redis_queue: + # For Redis queue, start the Redis listener + def redis_message_handler(message_data): + """Handler for Redis messages""" + try: + # Redis message data needs to be decoded from bytes to string + decoded_data = {} + for key, value in message_data.items(): + if isinstance(key, bytes): + key = key.decode("utf-8") + if isinstance(value, bytes): + value = value.decode("utf-8") + decoded_data[key] = value + + message = ScheduleMessageItem.from_dict(decoded_data) + self.dispatcher.dispatch([message]) + except Exception as e: + logger.error(f"Error processing Redis message: {e}") + logger.error(f"Message data: {message_data}") + + self.redis_start_listening(handler=redis_message_handler) + + # Keep the thread alive while Redis listener is running + while self._running: + time.sleep(self._consume_interval) + else: + # Original local queue logic + while self._running: # Use a running flag for graceful shutdown + try: + # Get all available messages at once (thread-safe approach) + messages = [] + while True: + try: + # Use get_nowait() directly without empty() check to avoid race conditions + message = self.memos_message_queue.get_nowait() + messages.append(message) + except queue.Empty: + # No more messages available + break + + if messages: + try: + self.dispatcher.dispatch(messages) + except Exception as e: + logger.error(f"Error dispatching messages: {e!s}") + finally: + # Mark all messages as processed + for _ in messages: + self.memos_message_queue.task_done() + + # Sleep briefly to prevent busy waiting + time.sleep(self._consume_interval) # Adjust interval as needed + + except Exception as e: + logger.error(f"Unexpected error in message consumer: {e!s}") + time.sleep(self._consume_interval) # Prevent tight error loops def start(self) -> None: """ @@ -722,14 +773,77 @@ def unregister_handlers(self, labels: list[str]) -> dict[str, bool]: return self.dispatcher.unregister_handlers(labels) + def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, dict]: + """ + Get currently running tasks, optionally filtered by a custom function. + + This method delegates to the dispatcher's get_running_tasks method. + + Args: + filter_func: Optional function to filter tasks. Should accept a RunningTaskItem + and return True if the task should be included in results. + + Returns: + dict[str, dict]: Dictionary mapping task IDs to task information dictionaries. 
+ Each task dict contains: item_id, user_id, mem_cube_id, task_info, + task_name, start_time, end_time, status, result, error_message, messages + + Examples: + # Get all running tasks + all_tasks = scheduler.get_running_tasks() + + # Get tasks for specific user + user_tasks = scheduler.get_running_tasks( + filter_func=lambda task: task.user_id == "user123" + ) + + # Get tasks with specific status + active_tasks = scheduler.get_running_tasks( + filter_func=lambda task: task.status == "running" + ) + """ + if not self.dispatcher: + logger.warning("Dispatcher is not initialized, returning empty tasks dict") + return {} + + running_tasks = self.dispatcher.get_running_tasks(filter_func=filter_func) + + # Convert RunningTaskItem objects to dictionaries for easier consumption + result = {} + for task_id, task_item in running_tasks.items(): + result[task_id] = { + "item_id": task_item.item_id, + "user_id": task_item.user_id, + "mem_cube_id": task_item.mem_cube_id, + "task_info": task_item.task_info, + "task_name": task_item.task_name, + "start_time": task_item.start_time, + "end_time": task_item.end_time, + "status": task_item.status, + "result": task_item.result, + "error_message": task_item.error_message, + "messages": task_item.messages, + } + + return result + def _cleanup_queues(self) -> None: """Ensure all queues are emptied and marked as closed.""" - try: - while not self.memos_message_queue.empty(): - self.memos_message_queue.get_nowait() - self.memos_message_queue.task_done() - except queue.Empty: - pass + if self.use_redis_queue: + # For Redis queue, stop the listener and close connection + try: + self.redis_stop_listening() + self.redis_close() + except Exception as e: + logger.error(f"Error cleaning up Redis connection: {e}") + else: + # Original local queue cleanup + try: + while not self.memos_message_queue.empty(): + self.memos_message_queue.get_nowait() + self.memos_message_queue.task_done() + except queue.Empty: + pass try: while not self._web_log_message_queue.empty(): diff --git a/src/memos/mem_scheduler/general_modules/api_misc.py b/src/memos/mem_scheduler/general_modules/api_misc.py new file mode 100644 index 000000000..6139a895a --- /dev/null +++ b/src/memos/mem_scheduler/general_modules/api_misc.py @@ -0,0 +1,115 @@ +import threading + +from typing import Any + +from memos.log import get_logger +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager, SimpleListManager + + +logger = get_logger(__name__) + + +class SchedulerAPIModule(BaseSchedulerModule): + def __init__(self): + super().__init__() + + self.search_history_managers: dict[str, RedisDBManager] = {} + + def get_search_history_manager(self, user_id: str, mem_cube_id: str) -> RedisDBManager: + """Get or create a Redis manager for search history.""" + key = f"search_history:{user_id}:{mem_cube_id}" + if key not in self.search_history_managers: + self.search_history_managers[key] = RedisDBManager( + user_id=user_id, mem_cube_id=mem_cube_id + ) + return self.search_history_managers[key] + + def sync_search_data( + self, user_id: str, mem_cube_id: str, query: str, formatted_memories: Any + ) -> None: + """ + Sync search data to Redis, maintaining a list of size 5. 
+ + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + query: Search query string + formatted_memories: Formatted search results + """ + try: + # Get the search history manager + manager = self.get_search_history_manager(user_id, mem_cube_id) + + # Create search data entry + search_entry = { + "query": query, + "formatted_memories": formatted_memories, + "timestamp": threading.current_thread().ident, # Use thread ID as simple timestamp + } + + # Load existing search history + existing_data = manager.load_from_db() + + if existing_data is None: + search_history = SimpleListManager([]) + else: + # If existing data is a SimpleListManager, use it; otherwise create new one + if isinstance(existing_data, SimpleListManager): + search_history = existing_data + else: + search_history = SimpleListManager([]) + + # Add new entry and keep only latest 5 + search_history.add_item(str(search_entry)) + if len(search_history) > 5: + # Keep only the latest 5 items + search_history.items = search_history.items[-5:] + + # Save back to Redis + manager.save_to_db(search_history) + + logger.info( + f"Synced search data for user {user_id}, mem_cube {mem_cube_id}. History size: {len(search_history)}" + ) + + except Exception as e: + logger.error(f"Failed to sync search data: {e}", exc_info=True) + + def get_pre_fine_memories(self, user_id: str, mem_cube_id: str) -> list: + """ + Get the most recent pre-computed fine memories from search history. + + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + + Returns: + List of formatted memories from the most recent search, or empty list if none found + """ + try: + manager = self.get_search_history_manager(user_id, mem_cube_id) + search_history_key = "search_history_list" + existing_data = manager.load_from_db(search_history_key) + + if existing_data is None: + return [] + + search_history = ( + existing_data.obj_instance + if hasattr(existing_data, "obj_instance") + else existing_data + ) + + if not search_history or len(search_history) == 0: + return [] + + # Return the formatted_memories from the most recent search + latest_entry = search_history[-1] + return ( + latest_entry.get("formatted_memories", []) if isinstance(latest_entry, dict) else [] + ) + + except Exception as e: + logger.error(f"Failed to get pre-computed fine memories: {e}", exc_info=True) + return [] diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index 4584beb96..c357e31b5 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -101,15 +101,43 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): return wrapped_handler - def get_running_tasks(self) -> dict[str, RunningTaskItem]: + def get_running_tasks( + self, filter_func: Callable[[RunningTaskItem], bool] | None = None + ) -> dict[str, RunningTaskItem]: """ - Get a copy of currently running tasks. + Get a copy of currently running tasks, optionally filtered by a custom function. + + Args: + filter_func: Optional function that takes a RunningTaskItem and returns True if it should be included. + Common filters can be created using helper methods like filter_by_user_id, filter_by_task_name, etc. 
Returns: Dictionary of running tasks keyed by task ID + + Examples: + # Get all running tasks + all_tasks = dispatcher.get_running_tasks() + + # Get tasks for specific user + user_tasks = dispatcher.get_running_tasks(lambda task: task.user_id == "user123") + + # Get tasks for specific task name + handler_tasks = dispatcher.get_running_tasks(lambda task: task.task_name == "test_handler") + + # Get tasks with multiple conditions + filtered_tasks = dispatcher.get_running_tasks( + lambda task: task.user_id == "user123" and task.status == "running" + ) """ with self._task_lock: - return self._running_tasks.copy() + if filter_func is None: + return self._running_tasks.copy() + + return { + task_id: task_item + for task_id, task_item in self._running_tasks.items() + if filter_func(task_item) + } def get_running_task_count(self) -> int: """ diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index f47cc0cc5..31bb9b3da 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -148,7 +148,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages) + grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) self.validate_schedule_messages(messages=messages, label=QUERY_LABEL) @@ -170,7 +170,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ logger.info(f"Messages {messages} assigned to {ANSWER_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages) + grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) self.validate_schedule_messages(messages=messages, label=ANSWER_LABEL) diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py index 13fe07354..0ebb7da4f 100644 --- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -1,7 +1,6 @@ import threading import time -from datetime import datetime from time import perf_counter from memos.configs.mem_scheduler import BaseSchedulerConfig @@ -14,6 +13,7 @@ DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES, DEFAULT_STUCK_THREAD_TOLERANCE, ) +from memos.mem_scheduler.utils.db_utils import get_utc_now logger = get_logger(__name__) @@ -84,7 +84,7 @@ def register_pool( "max_workers": max_workers, "restart": restart_on_failure, "failure_count": 0, - "last_active": datetime.utcnow(), + "last_active": get_utc_now(), "healthy": True, } logger.info(f"Registered thread pool '{name}' for monitoring") @@ -122,54 +122,6 @@ def _monitor_loop(self) -> None: logger.debug("Monitor loop exiting") - def start(self) -> bool: - """ - Start the monitoring thread. 
- - Returns: - bool: True if monitor started successfully, False if already running - """ - if self._running: - logger.warning("Dispatcher Monitor is already running") - return False - - self._running = True - self._monitor_thread = threading.Thread( - target=self._monitor_loop, name="threadpool_monitor", daemon=True - ) - self._monitor_thread.start() - logger.info("Dispatcher Monitor monitor started") - return True - - def stop(self) -> None: - """ - Stop the monitoring thread and clean up all managed thread pools. - Ensures proper shutdown of all monitored executors. - """ - if not self._running: - return - - # Stop the monitoring loop - self._running = False - if self._monitor_thread and self._monitor_thread.is_alive(): - self._monitor_thread.join(timeout=5) - - # Shutdown all registered pools - with self._pool_lock: - for name, pool_info in self._pools.items(): - executor = pool_info["executor"] - if not executor._shutdown: # pylint: disable=protected-access - try: - logger.info(f"Shutting down thread pool '{name}'") - executor.shutdown(wait=True, cancel_futures=True) - logger.info(f"Successfully shut down thread pool '{name}'") - except Exception as e: - logger.error(f"Error shutting down pool '{name}': {e!s}", exc_info=True) - - # Clear the pool registry - self._pools.clear() - logger.info("Thread pool monitor and all pools stopped") - def _check_pools_health(self) -> None: """Check health of all registered thread pools.""" for name, pool_info in list(self._pools.items()): @@ -182,7 +134,6 @@ def _check_pools_health(self) -> None: if is_healthy: pool_info["failure_count"] = 0 pool_info["healthy"] = True - return else: pool_info["failure_count"] += 1 pool_info["healthy"] = False @@ -269,27 +220,24 @@ def _check_pool_health( f"Found {len(stuck_tasks)} stuck tasks (tolerance: {effective_tolerance})", ) - # Check thread activity - active_threads = sum( - 1 - for t in threading.enumerate() - if t.name.startswith(executor._thread_name_prefix) # pylint: disable=protected-access - ) - - # Check if no threads are active but should be - if active_threads == 0 and pool_info["max_workers"] > 0: - return False, "No active worker threads" - + # Only check for stuck threads, not inactive threads # Check if threads are stuck (no activity for specified intervals) - time_delta = (datetime.utcnow() - pool_info["last_active"]).total_seconds() + time_delta = (get_utc_now() - pool_info["last_active"]).total_seconds() if time_delta >= self.check_interval * stuck_max_interval: return False, f"No recent activity for {time_delta:.1f} seconds" # If we got here, pool appears healthy - pool_info["last_active"] = datetime.utcnow() + pool_info["last_active"] = get_utc_now() # Log health status with comprehensive information if self.dispatcher: + # Check thread activity + active_threads = sum( + 1 + for t in threading.enumerate() + if t.name.startswith(executor._thread_name_prefix) # pylint: disable=protected-access + ) + task_count = self.dispatcher.get_running_task_count() max_workers = pool_info.get("max_workers", 0) stuck_count = len(stuck_tasks) @@ -338,7 +286,7 @@ def _restart_pool(self, name: str, pool_info: dict) -> None: pool_info["executor"] = new_executor pool_info["failure_count"] = 0 pool_info["healthy"] = True - pool_info["last_active"] = datetime.utcnow() + pool_info["last_active"] = get_utc_now() elapsed_time = perf_counter() - start_time if elapsed_time > 1: @@ -379,3 +327,52 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit point.""" self.stop() + + def 
start(self) -> bool: + """ + Start the monitoring thread. + + Returns: + bool: True if monitor started successfully, False if already running + """ + if self._running: + logger.warning("Dispatcher Monitor is already running") + return False + + self._running = True + self._monitor_thread = threading.Thread( + target=self._monitor_loop, name="threadpool_monitor", daemon=True + ) + self._monitor_thread.start() + logger.info("Dispatcher Monitor monitor started") + return True + + def stop(self) -> None: + """ + Stop the monitoring thread and clean up all managed thread pools. + Ensures proper shutdown of all monitored executors. + """ + if not self._running: + return + + # Stop the monitoring loop + self._running = False + if self._monitor_thread and self._monitor_thread.is_alive(): + self._monitor_thread.join(timeout=5) + + # Shutdown all registered pools + with self._pool_lock: + for name, pool_info in self._pools.items(): + executor = pool_info["executor"] + if not executor._shutdown: # pylint: disable=protected-access + try: + logger.info(f"Shutting down thread pool '{name}'") + executor.shutdown(wait=True, cancel_futures=True) + logger.info(f"Successfully shut down thread pool '{name}'") + except Exception as e: + logger.error(f"Error shutting down pool '{name}': {e!s}", exc_info=True) + + # Clear the pool registry + self._pools.clear() + + logger.info("Thread pool monitor and all pools stopped") diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py index 87d996549..22fb78445 100644 --- a/src/memos/mem_scheduler/monitors/general_monitor.py +++ b/src/memos/mem_scheduler/monitors/general_monitor.py @@ -28,6 +28,7 @@ MemoryMonitorManager, QueryMonitorQueue, ) +from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.misc_utils import extract_json_dict from memos.memories.textual.tree import TreeTextMemory @@ -64,7 +65,7 @@ def __init__( "No database engine provided; falling back to default temporary SQLite engine. " "This is intended for testing only. Consider providing a configured engine for production use." 
) - self.db_engine = BaseDBManager.create_default_engine() + self.db_engine = BaseDBManager.create_default_sqlite_engine() self.query_monitors: dict[UserID, dict[MemCubeID, DBManagerForQueryMonitorQueue]] = {} self.working_memory_monitors: dict[ @@ -256,7 +257,7 @@ def update_activation_memory_monitors( activation_db_manager.sync_with_orm(size_limit=self.activation_mem_monitor_capacity) def timed_trigger(self, last_time: datetime, interval_seconds: float) -> bool: - now = datetime.utcnow() + now = get_utc_now() elapsed = (now - last_time).total_seconds() if elapsed >= interval_seconds: return True diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index dd08954a9..fb5f4ce7c 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -1,14 +1,21 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +from memos.api.product_models import APISearchRequest from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.general_modules.api_misc import SchedulerAPIModule from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.general_schemas import ( + API_MIX_SEARCH_LABEL, + QUERY_LABEL, MemCubeID, + SearchMode, UserID, ) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.types import UserContext if TYPE_CHECKING: @@ -19,10 +26,116 @@ class OptimizedScheduler(GeneralScheduler): - """Optimized scheduler with improved working memory management""" + """Optimized scheduler with improved working memory management and support for api""" def __init__(self, config: GeneralSchedulerConfig): super().__init__(config) + self.api_module = SchedulerAPIModule() + self.message_consumers = { + API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer, + } + + def _format_memory_item(self, memory_data: Any) -> dict[str, Any]: + """Format a single memory item for API response.""" + memory = memory_data.model_dump() + memory_id = memory["id"] + ref_id = f"[{memory_id.split('-')[0]}]" + + memory["ref_id"] = ref_id + memory["metadata"]["embedding"] = [] + memory["metadata"]["sources"] = [] + memory["metadata"]["ref_id"] = ref_id + memory["metadata"]["id"] = memory_id + memory["metadata"]["memory"] = memory["memory"] + + return memory + + def fine_search_memories( + self, + search_req: APISearchRequest, + user_context: UserContext, + mem_cube: GeneralMemCube, + ): + """Fine search memories function copied from server_router to avoid circular import""" + target_session_id = search_req.session_id + if not target_session_id: + target_session_id = "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + # Create MemCube and perform search + search_results = mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FINE, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + formatted_memories = [self._format_memory_item(data) for data in search_results] + + return formatted_memories + + def 
update_search_memories_to_redis( + self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem] + ): + mem_cube = messages[0].mem_cube + + # for status update + self._set_current_context_from_message(msg=messages[0]) + + # update query monitors + for msg in messages: + self.monitor.register_query_monitor_if_not_exists( + user_id=user_id, mem_cube_id=mem_cube_id + ) + + content_dict = msg.content + search_req = content_dict["search_req"] + user_context = content_dict["user_context"] + + formatted_memories = self.fine_search_memories( + search_req=search_req, user_context=user_context, mem_cube=mem_cube + ) + + # Sync search data to Redis + try: + self.api_module.sync_search_data( + user_id=search_req.user_id, + mem_cube_id=user_context.mem_cube_id, + query=search_req.query, + formatted_memories=formatted_memories, + ) + except Exception as e: + logger.error(f"Failed to sync search data: {e}") + + def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: + """ + Process and handle query trigger messages from the queue. + + Args: + messages: List of query messages to process + """ + logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.") + + # Process the query in a session turn + grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) + + self.validate_schedule_messages(messages=messages, label=QUERY_LABEL) + + for user_id in grouped_messages: + for mem_cube_id in grouped_messages[user_id]: + messages = grouped_messages[user_id][mem_cube_id] + if len(messages) == 0: + return + self.update_search_memories_to_redis( + user_id=user_id, mem_cube_id=mem_cube_id, messages=messages + ) def replace_working_memory( self, diff --git a/src/memos/mem_scheduler/orm_modules/base_model.py b/src/memos/mem_scheduler/orm_modules/base_model.py index 9d75a12bd..cf3fc904c 100644 --- a/src/memos/mem_scheduler/orm_modules/base_model.py +++ b/src/memos/mem_scheduler/orm_modules/base_model.py @@ -10,13 +10,16 @@ from sqlalchemy import Boolean, Column, DateTime, String, Text, and_, create_engine from sqlalchemy.engine import Engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session, declarative_base, sessionmaker from memos.log import get_logger from memos.mem_user.user_manager import UserManager +class DatabaseError(Exception): + """Exception raised for database-related errors""" + + T = TypeVar("T") # The model type (MemoryMonitorManager, QueryMonitorManager, etc.) 
ORM = TypeVar("ORM") # The ORM model type @@ -561,7 +564,7 @@ def close(self): logger.error(f"Error during close operation: {e}") @staticmethod - def create_default_engine() -> Engine: + def create_default_sqlite_engine() -> Engine: """Create SQLAlchemy engine with default database path Returns: @@ -633,3 +636,211 @@ def create_mysql_db_path( else: db_path = f"mysql+pymysql://{username}@{host}:{port}/{database}?charset={charset}" return db_path + + @staticmethod + def load_mysql_engine_from_env(env_file_path: str | None = None) -> Engine | None: + """Load MySQL engine from environment variables + + Args: + env_file_path: Path to .env file (optional, defaults to loading from current environment) + + Returns: + SQLAlchemy Engine instance configured for MySQL + + Raises: + DatabaseError: If required environment variables are missing or connection fails + """ + # Load environment variables from file if provided + if env_file_path: + if os.path.exists(env_file_path): + from dotenv import load_dotenv + + load_dotenv(env_file_path) + logger.info(f"Loaded environment variables from {env_file_path}") + else: + logger.warning( + f"Environment file not found: {env_file_path}, using current environment variables" + ) + else: + logger.info("Using current environment variables (no env_file_path provided)") + + # Get MySQL configuration from environment variables + mysql_host = os.getenv("MYSQL_HOST") + mysql_port_str = os.getenv("MYSQL_PORT") + mysql_username = os.getenv("MYSQL_USERNAME") + mysql_password = os.getenv("MYSQL_PASSWORD") + mysql_database = os.getenv("MYSQL_DATABASE") + mysql_charset = os.getenv("MYSQL_CHARSET") + + # Check required environment variables + required_vars = { + "MYSQL_HOST": mysql_host, + "MYSQL_USERNAME": mysql_username, + "MYSQL_PASSWORD": mysql_password, + "MYSQL_DATABASE": mysql_database, + } + + missing_vars = [var for var, value in required_vars.items() if not value] + if missing_vars: + error_msg = f"Missing required MySQL environment variables: {', '.join(missing_vars)}" + logger.error(error_msg) + return None + + # Parse port with validation + try: + mysql_port = int(mysql_port_str) if mysql_port_str else 3306 + except ValueError: + error_msg = f"Invalid MYSQL_PORT value: {mysql_port_str}. Must be a valid integer." 
+ logger.error(error_msg) + return None + + # Set default charset if not provided + if not mysql_charset: + mysql_charset = "utf8mb4" + + # Create MySQL connection URL + db_url = BaseDBManager.create_mysql_db_path( + host=mysql_host, + port=mysql_port, + username=mysql_username, + password=mysql_password, + database=mysql_database, + charset=mysql_charset, + ) + + try: + # Create and test the engine + engine = create_engine(db_url, echo=False) + + # Test connection + with engine.connect() as conn: + from sqlalchemy import text + + conn.execute(text("SELECT 1")) + + logger.info( + f"Successfully created MySQL engine: {mysql_host}:{mysql_port}/{mysql_database}" + ) + return engine + + except Exception as e: + error_msg = f"Failed to create MySQL engine from environment variables: {e}" + logger.error(error_msg) + raise DatabaseError(error_msg) from e + + @staticmethod + def load_redis_engine_from_env(env_file_path: str | None = None) -> Any: + """Load Redis connection from environment variables + + Args: + env_file_path: Path to .env file (optional, defaults to loading from current environment) + + Returns: + Redis connection instance + + Raises: + DatabaseError: If required environment variables are missing or connection fails + """ + try: + import redis + except ImportError as e: + error_msg = "Redis package not installed. Install with: pip install redis" + logger.error(error_msg) + raise DatabaseError(error_msg) from e + + # Load environment variables from file if provided + if env_file_path: + if os.path.exists(env_file_path): + from dotenv import load_dotenv + + load_dotenv(env_file_path) + logger.info(f"Loaded environment variables from {env_file_path}") + else: + logger.warning( + f"Environment file not found: {env_file_path}, using current environment variables" + ) + else: + logger.info("Using current environment variables (no env_file_path provided)") + + # Get Redis configuration from environment variables + redis_host = os.getenv("REDIS_HOST") or os.getenv("MEMSCHEDULER_REDIS_HOST") + redis_port_str = os.getenv("REDIS_PORT") or os.getenv("MEMSCHEDULER_REDIS_PORT") + redis_db_str = os.getenv("REDIS_DB") or os.getenv("MEMSCHEDULER_REDIS_DB") + redis_password = os.getenv("REDIS_PASSWORD") or os.getenv("MEMSCHEDULER_REDIS_PASSWORD") + + # Check required environment variables + if not redis_host: + error_msg = ( + "Missing required Redis environment variable: REDIS_HOST or MEMSCHEDULER_REDIS_HOST" + ) + logger.error(error_msg) + return None + + # Parse port with validation + try: + redis_port = int(redis_port_str) if redis_port_str else 6379 + except ValueError: + error_msg = f"Invalid REDIS_PORT value: {redis_port_str}. Must be a valid integer." + logger.error(error_msg) + return None + + # Parse database with validation + try: + redis_db = int(redis_db_str) if redis_db_str else 0 + except ValueError: + error_msg = f"Invalid REDIS_DB value: {redis_db_str}. Must be a valid integer." 
+ logger.error(error_msg) + return None + + # Optional timeout settings + socket_timeout = os.getenv( + "REDIS_SOCKET_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_TIMEOUT", None) + ) + socket_connect_timeout = os.getenv( + "REDIS_SOCKET_CONNECT_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_CONNECT_TIMEOUT", None) + ) + + try: + # Build Redis connection parameters + redis_kwargs = { + "host": redis_host, + "port": redis_port, + "db": redis_db, + "decode_responses": True, + } + + if redis_password: + redis_kwargs["password"] = redis_password + + if socket_timeout: + try: + redis_kwargs["socket_timeout"] = float(socket_timeout) + except ValueError: + logger.warning( + f"Invalid REDIS_SOCKET_TIMEOUT value: {socket_timeout}, ignoring" + ) + + if socket_connect_timeout: + try: + redis_kwargs["socket_connect_timeout"] = float(socket_connect_timeout) + except ValueError: + logger.warning( + f"Invalid REDIS_SOCKET_CONNECT_TIMEOUT value: {socket_connect_timeout}, ignoring" + ) + + # Create Redis connection + redis_client = redis.Redis(**redis_kwargs) + + # Test connection + if not redis_client.ping(): + raise ConnectionError("Redis ping failed") + + logger.info( + f"Successfully created Redis connection: {redis_host}:{redis_port}/{redis_db}" + ) + return redis_client + + except Exception as e: + error_msg = f"Failed to create Redis connection from environment variables: {e}" + logger.error(error_msg) + raise DatabaseError(error_msg) from e diff --git a/src/memos/mem_scheduler/orm_modules/redis_model.py b/src/memos/mem_scheduler/orm_modules/redis_model.py new file mode 100644 index 000000000..ccfe1b1c8 --- /dev/null +++ b/src/memos/mem_scheduler/orm_modules/redis_model.py @@ -0,0 +1,699 @@ +import json +import time + +from typing import Any, TypeVar + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import declarative_base + +from memos.log import get_logger +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager +from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorManager +from memos.mem_scheduler.utils.db_utils import get_utc_now + + +T = TypeVar("T") # The model type (MemoryMonitorManager, QueryMonitorManager, etc.) +ORM = TypeVar("ORM") # The ORM model type + +logger = get_logger(__name__) + +Base = declarative_base() + + +class SimpleListManager: + """Simple wrapper class for list[str] to work with RedisDBManager""" + + def __init__(self, items: list[str] | None = None): + self.items = items or [] + + def to_json(self) -> str: + """Serialize to JSON string""" + return json.dumps({"items": self.items}) + + @classmethod + def from_json(cls, json_str: str) -> "SimpleListManager": + """Deserialize from JSON string""" + data = json.loads(json_str) + return cls(items=data.get("items", [])) + + def add_item(self, item: str): + """Add an item to the list""" + self.items.append(item) + + def __len__(self): + return len(self.items) + + def __str__(self): + return f"SimpleListManager(items={self.items})" + + +class RedisLockableORM: + """Redis-based implementation of LockableORM interface + + This class provides Redis-based storage for lockable ORM objects, + mimicking the SQLAlchemy LockableORM interface but using Redis as the backend. 
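Each (user_id, mem_cube_id) pair is persisted under the key prefix lockable_orm:<user_id>:<mem_cube_id>, with separate :data, :lock, and :version keys.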
+ """ + + def __init__(self, redis_client, user_id: str, mem_cube_id: str): + self.redis_client = redis_client + self.user_id = user_id + self.mem_cube_id = mem_cube_id + self.serialized_data = None + self.lock_acquired = False + self.lock_expiry = None + self.version_control = "0" + + def _get_key_prefix(self) -> str: + """Generate Redis key prefix for this ORM instance""" + return f"lockable_orm:{self.user_id}:{self.mem_cube_id}" + + def _get_data_key(self) -> str: + """Get Redis key for serialized data""" + return f"{self._get_key_prefix()}:data" + + def _get_lock_key(self) -> str: + """Get Redis key for lock information""" + return f"{self._get_key_prefix()}:lock" + + def _get_version_key(self) -> str: + """Get Redis key for version control""" + return f"{self._get_key_prefix()}:version" + + def save(self): + """Save this ORM instance to Redis""" + try: + # Save serialized data + if self.serialized_data: + self.redis_client.set(self._get_data_key(), self.serialized_data) + + # Note: Lock information is now managed by acquire_lock/release_locks methods + # We don't save lock info here to avoid conflicts with atomic lock operations + + # Save version control + self.redis_client.set(self._get_version_key(), self.version_control) + + logger.debug(f"Saved RedisLockableORM to Redis: {self._get_key_prefix()}") + + except Exception as e: + logger.error(f"Failed to save RedisLockableORM to Redis: {e}") + raise + + def load(self): + """Load this ORM instance from Redis""" + try: + # Load serialized data + data = self.redis_client.get(self._get_data_key()) + if data: + self.serialized_data = data.decode() if isinstance(data, bytes) else data + else: + self.serialized_data = None + + # Note: Lock information is now managed by acquire_lock/release_locks methods + # We don't load lock info here to avoid conflicts with atomic lock operations + self.lock_acquired = False + self.lock_expiry = None + + # Load version control + version = self.redis_client.get(self._get_version_key()) + if version: + self.version_control = version.decode() if isinstance(version, bytes) else version + else: + self.version_control = "0" + + logger.debug(f"Loaded RedisLockableORM from Redis: {self._get_key_prefix()}") + # Return True if we found any data, False otherwise + return self.serialized_data is not None + + except Exception as e: + logger.error(f"Failed to load RedisLockableORM from Redis: {e}") + return False + + def delete(self): + """Delete this ORM instance from Redis""" + try: + keys_to_delete = [self._get_data_key(), self._get_lock_key(), self._get_version_key()] + self.redis_client.delete(*keys_to_delete) + logger.debug(f"Deleted RedisLockableORM from Redis: {self._get_key_prefix()}") + except Exception as e: + logger.error(f"Failed to delete RedisLockableORM from Redis: {e}") + raise + + +class RedisDBManager(BaseDBManager): + """Redis-based database manager for any serializable object + + This class handles persistence, synchronization, and locking + for any object that implements to_json/from_json methods using Redis as the backend storage. 
+ """ + + def __init__( + self, + engine: Engine | None = None, + user_id: str | None = None, + mem_cube_id: str | None = None, + obj: Any | None = None, + lock_timeout: int = 10, + redis_client=None, + redis_config: dict | None = None, + ): + """Initialize the Redis database manager + + Args: + engine: SQLAlchemy engine (not used for Redis, kept for compatibility) + user_id: Unique identifier for the user + mem_cube_id: Unique identifier for the memory cube + obj: Optional object instance to manage (must have to_json/from_json methods) + lock_timeout: Timeout in seconds for lock acquisition + redis_client: Redis client instance (optional) + redis_config: Redis configuration dictionary (optional) + """ + # Initialize Redis client + self.redis_client = redis_client + self.redis_config = redis_config or {} + + if self.redis_client is None: + self._init_redis_client() + + # Initialize base attributes without calling parent's init_manager + self.user_id = user_id + self.mem_cube_id = mem_cube_id + self.obj = obj + self.obj_type = type(obj) if obj is not None else None # Store the actual object type + self.lock_timeout = lock_timeout + self.engine = engine # Keep for compatibility but not used + self.SessionLocal = None # Not used for Redis + self.last_version_control = None + + logger.info( + f"RedisDBManager initialized for user_id: {user_id}, mem_cube_id: {mem_cube_id}" + ) + logger.info(f"Redis client: {type(self.redis_client).__name__}") + + # Test Redis connection + try: + self.redis_client.ping() + logger.info("Redis connection successful") + except Exception as e: + logger.warning(f"Redis ping failed: {e}") + # Don't raise error here as it might be a mock client in tests + + def _init_redis_client(self): + """Initialize Redis client from config or environment""" + try: + import redis + + # Try to get Redis client from environment first + if not self.redis_client: + self.redis_client = self.load_redis_engine_from_env() + + # If still no client, try from config + if not self.redis_client and self.redis_config: + redis_kwargs = { + "host": self.redis_config.get("host", "localhost"), + "port": self.redis_config.get("port", 6379), + "db": self.redis_config.get("db", 0), + "decode_responses": True, + } + + if self.redis_config.get("password"): + redis_kwargs["password"] = self.redis_config["password"] + + self.redis_client = redis.Redis(**redis_kwargs) + + # Final fallback to localhost + if not self.redis_client: + logger.warning("No Redis configuration found, using localhost defaults") + self.redis_client = redis.Redis( + host="localhost", port=6379, db=0, decode_responses=True + ) + + # Test connection + if not self.redis_client.ping(): + raise ConnectionError("Redis ping failed") + + logger.info("Redis client initialized successfully") + + except ImportError: + logger.error("Redis package not installed. Install with: pip install redis") + raise + except Exception as e: + logger.error(f"Failed to initialize Redis client: {e}") + raise + + @property + def orm_class(self) -> type[RedisLockableORM]: + """Return the Redis-based ORM class""" + return RedisLockableORM + + @property + def obj_class(self) -> type: + """Return the actual object class""" + return self.obj_type if self.obj_type is not None else MemoryMonitorManager + + def merge_items( + self, + orm_instance: RedisLockableORM, + obj_instance: Any, + size_limit: int, + ): + """Merge items from Redis with current object instance + + This method provides a generic way to merge data from Redis with the current + object instance. 
It handles different object types and their specific merge logic. + + Args: + orm_instance: Redis ORM instance from database + obj_instance: Current object instance (any type with to_json/from_json methods) + size_limit: Maximum number of items to keep after merge + """ + logger.debug(f"Starting merge_items with size_limit={size_limit}") + + try: + if not orm_instance.serialized_data: + logger.warning("No serialized data in Redis ORM instance to merge") + return obj_instance + + # Deserialize the database object using the actual object type + if self.obj_type is not None: + db_obj = self.obj_type.from_json(orm_instance.serialized_data) + else: + db_obj = MemoryMonitorManager.from_json(orm_instance.serialized_data) + + # Handle different object types with specific merge logic based on type + obj_type = type(obj_instance) + if obj_type.__name__ == "MemoryMonitorManager" or hasattr(obj_instance, "memories"): + # MemoryMonitorManager-like objects + return self._merge_memory_monitor_items(obj_instance, db_obj, size_limit) + elif obj_type.__name__ == "SimpleListManager" or hasattr(obj_instance, "items"): + # SimpleListManager-like objects + return self._merge_list_items(obj_instance, db_obj, size_limit) + else: + # Generic objects - just return the current instance + logger.info( + f"No specific merge logic for object type {obj_type.__name__}, returning current instance" + ) + return obj_instance + + except Exception as e: + logger.error(f"Failed to deserialize database instance: {e}", exc_info=True) + logger.warning("Skipping merge due to deserialization error, using current object only") + return obj_instance + + def _merge_memory_monitor_items(self, obj_instance, db_obj, size_limit: int): + """Merge MemoryMonitorManager items""" + # Create a mapping of existing memories by their mapping key + current_memories_dict = obj_instance.memories_mapping_dict + + # Add memories from database that don't exist in current object + for db_memory in db_obj.memories: + if db_memory.tree_memory_item_mapping_key not in current_memories_dict: + obj_instance.memories.append(db_memory) + + # Apply size limit if specified + if size_limit and len(obj_instance.memories) > size_limit: + # Sort by recording_count and keep the most recorded ones + obj_instance.memories.sort(key=lambda x: x.recording_count, reverse=True) + obj_instance.memories = obj_instance.memories[:size_limit] + logger.info( + f"Applied size limit {size_limit}, kept {len(obj_instance.memories)} memories" + ) + + logger.info(f"Merged {len(obj_instance.memories)} memory items") + return obj_instance + + def _merge_list_items(self, obj_instance, db_obj, size_limit: int): + """Merge SimpleListManager-like items""" + merged_items = [] + seen_items = set() + + # First, add all items from current object (higher priority) + for item in obj_instance.items: + if item not in seen_items: + merged_items.append(item) + seen_items.add(item) + + # Then, add items from database that aren't in current object + for item in db_obj.items: + if item not in seen_items: + merged_items.append(item) + seen_items.add(item) + + # Apply size limit if specified (keep most recent items) + if size_limit is not None and size_limit > 0 and len(merged_items) > size_limit: + merged_items = merged_items[:size_limit] + logger.debug(f"Applied size limit of {size_limit}, kept {len(merged_items)} items") + + # Update the object with merged items + obj_instance.items = merged_items + + logger.info(f"Merged {len(merged_items)} list items (size_limit: {size_limit})") + return obj_instance + 
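For reference, the precedence implemented by _merge_list_items above can be traced with a small standalone sketch (illustrative only, no Redis required; merge_lists is a hypothetical helper, not part of this module):

def merge_lists(current: list[str], persisted: list[str], size_limit: int | None = None) -> list[str]:
    # In-process items come first, then persisted items not already present;
    # order is preserved and duplicates are dropped.
    merged: list[str] = []
    seen: set[str] = set()
    for item in current + persisted:
        if item not in seen:
            merged.append(item)
            seen.add(item)
    # A positive size_limit truncates from the front, so current items are
    # kept preferentially, matching merged_items[:size_limit] above.
    if size_limit is not None and size_limit > 0:
        merged = merged[:size_limit]
    return merged

assert merge_lists(["a", "b"], ["b", "c", "d"], size_limit=3) == ["a", "b", "c"]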
+ def _get_redis_orm_instance(self) -> RedisLockableORM: + """Get or create a Redis ORM instance""" + orm_instance = RedisLockableORM( + redis_client=self.redis_client, user_id=self.user_id, mem_cube_id=self.mem_cube_id + ) + return orm_instance + + def _get_key_prefix(self) -> str: + """Generate Redis key prefix for this ORM instance""" + return f"lockable_orm:{self.user_id}:{self.mem_cube_id}" + + def acquire_lock(self, block: bool = True, **kwargs) -> bool: + """Acquire a distributed lock using Redis with atomic operations + + Args: + block: Whether to block until lock is acquired + **kwargs: Additional filter criteria (ignored for Redis) + + Returns: + True if lock was acquired, False otherwise + """ + try: + lock_key = f"{self._get_key_prefix()}:lock" + now = get_utc_now() + + # Use Redis SET with NX (only if not exists) and EX (expiry) for atomic lock acquisition + lock_value = f"{self.user_id}:{self.mem_cube_id}:{now.timestamp()}" + + while True: + # Try to acquire lock atomically + result = self.redis_client.set( + lock_key, + lock_value, + nx=True, # Only set if key doesn't exist + ex=self.lock_timeout, # Set expiry in seconds + ) + + if result: + # Successfully acquired lock + logger.info(f"Redis lock acquired for {self.user_id}/{self.mem_cube_id}") + return True + + if not block: + logger.warning( + f"Redis lock is held for {self.user_id}/{self.mem_cube_id}, cannot acquire" + ) + return False + + # Wait a bit before retrying + logger.info( + f"Waiting for Redis lock to be released for {self.user_id}/{self.mem_cube_id}" + ) + time.sleep(0.1) + + except Exception as e: + logger.error(f"Failed to acquire Redis lock for {self.user_id}/{self.mem_cube_id}: {e}") + return False + + def release_locks(self, user_id: str, mem_cube_id: str, **kwargs): + """Release Redis locks for the specified user and memory cube + + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + **kwargs: Additional filter criteria (ignored for Redis) + """ + try: + lock_key = f"lockable_orm:{user_id}:{mem_cube_id}:lock" + + # Delete the lock key to release the lock + result = self.redis_client.delete(lock_key) + + if result: + logger.info(f"Redis lock released for {user_id}/{mem_cube_id}") + else: + logger.warning(f"No Redis lock found to release for {user_id}/{mem_cube_id}") + + except Exception as e: + logger.error(f"Failed to release Redis lock for {user_id}/{mem_cube_id}: {e}") + + def sync_with_orm(self, size_limit: int | None = None) -> None: + """Synchronize data between Redis and the business object + + Args: + size_limit: Optional maximum number of items to keep after synchronization + """ + logger.info( + f"Starting Redis sync_with_orm for {self.user_id}/{self.mem_cube_id} with size_limit={size_limit}" + ) + + try: + # Acquire lock before any operations + lock_status = self.acquire_lock(block=True) + if not lock_status: + logger.error("Failed to acquire Redis lock for synchronization") + return + + # Get existing data from Redis + orm_instance = self._get_redis_orm_instance() + exists = orm_instance.load() + + # If no existing record, create a new one + if not exists: + if self.obj is None: + logger.warning("No object to synchronize and no existing Redis record") + return + + orm_instance.serialized_data = self.obj.to_json() + orm_instance.version_control = "0" + orm_instance.save() + + logger.info("No existing Redis record found. 
Created a new one.") + self.last_version_control = "0" + return + + # Check version control and merge data + if self.obj is not None: + current_redis_tag = orm_instance.version_control + new_tag = self._increment_version_control(current_redis_tag) + + # Check if this is the first sync or if we need to merge + if self.last_version_control is None: + logger.info("First Redis sync, merging data from Redis") + # Always merge on first sync to load data from Redis + try: + self.merge_items( + orm_instance=orm_instance, obj_instance=self.obj, size_limit=size_limit + ) + except Exception as merge_error: + logger.error( + f"Error during Redis merge_items: {merge_error}", exc_info=True + ) + logger.warning("Continuing with current object data without merge") + elif current_redis_tag == self.last_version_control: + logger.info( + f"Redis version control unchanged ({current_redis_tag}), directly update" + ) + else: + logger.info( + f"Redis version control changed from {self.last_version_control} to {current_redis_tag}, merging data" + ) + try: + self.merge_items( + orm_instance=orm_instance, obj_instance=self.obj, size_limit=size_limit + ) + except Exception as merge_error: + logger.error( + f"Error during Redis merge_items: {merge_error}", exc_info=True + ) + logger.warning("Continuing with current object data without merge") + + # Write merged data back to Redis + orm_instance.serialized_data = self.obj.to_json() + orm_instance.version_control = new_tag + orm_instance.save() + + logger.info(f"Updated Redis serialized_data for {self.user_id}/{self.mem_cube_id}") + self.last_version_control = orm_instance.version_control + else: + logger.warning("No current object to merge with Redis data") + + logger.info(f"Redis synchronization completed for {self.user_id}/{self.mem_cube_id}") + + except Exception as e: + logger.error( + f"Error during Redis synchronization for {self.user_id}/{self.mem_cube_id}: {e}", + exc_info=True, + ) + finally: + # Always release locks + self.release_locks(user_id=self.user_id, mem_cube_id=self.mem_cube_id) + + def save_to_db(self, obj_instance: Any) -> None: + """Save the current state of the business object to Redis + + Args: + obj_instance: The object instance to save (must have to_json method) + """ + try: + # Acquire lock before operations + lock_status = self.acquire_lock(block=True) + if not lock_status: + logger.error("Failed to acquire Redis lock for saving") + return + + # Get or create Redis ORM instance + orm_instance = self._get_redis_orm_instance() + exists = orm_instance.load() + + if not exists: + # Create new record + orm_instance.serialized_data = obj_instance.to_json() + orm_instance.version_control = "0" + orm_instance.save() + + logger.info(f"Created new Redis record for {self.user_id}/{self.mem_cube_id}") + self.last_version_control = "0" + else: + # Update existing record with version control + current_version = orm_instance.version_control + new_version = self._increment_version_control(current_version) + + orm_instance.serialized_data = obj_instance.to_json() + orm_instance.version_control = new_version + orm_instance.save() + + logger.info( + f"Updated existing Redis record for {self.user_id}/{self.mem_cube_id} with version {new_version}" + ) + self.last_version_control = new_version + + except Exception as e: + logger.error(f"Error saving to Redis for {self.user_id}/{self.mem_cube_id}: {e}") + finally: + # Always release locks + self.release_locks(user_id=self.user_id, mem_cube_id=self.mem_cube_id) + + def load_from_db(self, acquire_lock: bool = 
False) -> Any | None: + """Load the business object from Redis + + Args: + acquire_lock: Whether to acquire a lock during the load operation + + Returns: + The deserialized object instance, or None if not found + """ + try: + if acquire_lock: + lock_status = self.acquire_lock(block=True) + if not lock_status: + logger.error("Failed to acquire Redis lock for loading") + return None + + # Load from Redis + orm_instance = self._get_redis_orm_instance() + exists = orm_instance.load() + + if not exists or not orm_instance.serialized_data: + logger.info(f"No Redis record found for {self.user_id}/{self.mem_cube_id}") + return None + + # Deserialize the business object using the actual object type + if self.obj_type is not None: + db_instance = self.obj_type.from_json(orm_instance.serialized_data) + else: + db_instance = MemoryMonitorManager.from_json(orm_instance.serialized_data) + self.last_version_control = orm_instance.version_control + + logger.info( + f"Successfully loaded object from Redis for {self.user_id}/{self.mem_cube_id} with version {orm_instance.version_control}" + ) + return db_instance + + except Exception as e: + logger.error(f"Error loading from Redis for {self.user_id}/{self.mem_cube_id}: {e}") + return None + finally: + if acquire_lock: + self.release_locks(user_id=self.user_id, mem_cube_id=self.mem_cube_id) + + def close(self): + """Close the Redis manager and clean up resources""" + try: + # Release any locks held by this manager instance + if self.user_id and self.mem_cube_id: + self.release_locks(user_id=self.user_id, mem_cube_id=self.mem_cube_id) + logger.info(f"Released Redis locks for {self.user_id}/{self.mem_cube_id}") + + # Close Redis connection + if self.redis_client: + self.redis_client.close() + logger.info("Redis connection closed") + + # Call parent close method for any additional cleanup + super().close() + + except Exception as e: + logger.error(f"Error during Redis close operation: {e}") + + @classmethod + def from_env( + cls, + user_id: str, + mem_cube_id: str, + obj: Any | None = None, + lock_timeout: int = 10, + env_file_path: str | None = None, + ) -> "RedisDBManager": + """Create RedisDBManager from environment variables + + Args: + user_id: User identifier + mem_cube_id: Memory cube identifier + obj: Optional MemoryMonitorManager instance + lock_timeout: Lock timeout in seconds + env_file_path: Optional path to .env file + + Returns: + RedisDBManager instance + """ + try: + redis_client = cls.load_redis_engine_from_env(env_file_path) + return cls( + user_id=user_id, + mem_cube_id=mem_cube_id, + obj=obj, + lock_timeout=lock_timeout, + redis_client=redis_client, + ) + except Exception as e: + logger.error(f"Failed to create RedisDBManager from environment: {e}") + raise + + def list_keys(self, pattern: str | None = None) -> list[str]: + """List all Redis keys for this manager's data + + Args: + pattern: Optional pattern to filter keys + + Returns: + List of Redis keys + """ + try: + if pattern is None: + pattern = f"lockable_orm:{self.user_id}:{self.mem_cube_id}:*" + + keys = self.redis_client.keys(pattern) + return [key.decode() if isinstance(key, bytes) else key for key in keys] + + except Exception as e: + logger.error(f"Error listing Redis keys: {e}") + return [] + + def health_check(self) -> dict[str, bool]: + """Check the health of Redis connection + + Returns: + Dictionary with health status + """ + try: + redis_healthy = self.redis_client.ping() + return { + "redis": redis_healthy, + "mysql": False, # Not applicable for Redis manager + } + except 
Exception as e: + logger.error(f"Redis health check failed: {e}") + return {"redis": False, "mysql": False} diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index 248c42e80..f0868e8df 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -1,7 +1,16 @@ +from enum import Enum from pathlib import Path from typing import NewType +class SearchMode(str, Enum): + """Enumeration for search modes.""" + + FAST = "fast" + FINE = "fine" + MIXTURE = "mixture" + + FILE_PATH = Path(__file__).absolute() BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent @@ -10,6 +19,8 @@ ADD_LABEL = "add" MEM_READ_LABEL = "mem_read" MEM_ORGANIZE_LABEL = "mem_organize" +API_MIX_SEARCH_LABEL = "api_mix_search" + TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search" TreeTextMemory_FINE_SEARCH_METHOD = "tree_text_memory_fine_search" @@ -24,6 +35,10 @@ DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300 DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2 DEFAULT_STUCK_THREAD_TOLERANCE = 10 +DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 100000 +DEFAULT_TOP_K = 10 +DEFAULT_CONTEXT_WINDOW_SIZE = 5 +DEFAULT_USE_REDIS_QUEUE = False # startup mode configuration STARTUP_BY_THREAD = "thread" diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index 9b5bd5d81..efdaa44ef 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -8,6 +8,7 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_modules.misc import DictConversionMixin +from memos.mem_scheduler.utils.db_utils import get_utc_now from .general_schemas import NOT_INITIALIZED @@ -39,7 +40,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): mem_cube: GeneralMemCube | str = Field(..., description="memcube for schedule") content: str = Field(..., description="Content of the schedule message") timestamp: datetime = Field( - default_factory=lambda: datetime.utcnow(), description="submit time for schedule_messages" + default_factory=get_utc_now, description="submit time for schedule_messages" ) # Pydantic V2 model configuration @@ -88,9 +89,9 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": return cls( item_id=data.get("item_id", str(uuid4())), user_id=data["user_id"], - cube_id=data["cube_id"], + mem_cube_id=data["cube_id"], label=data["label"], - cube="Not Applicable", # Custom cube deserialization + mem_cube="Not Applicable", # Custom cube deserialization content=data["content"], timestamp=datetime.fromisoformat(data["timestamp"]), ) @@ -131,7 +132,7 @@ class ScheduleLogForWebItem(BaseModel, DictConversionMixin): description="Maximum capacities of memory partitions", ) timestamp: datetime = Field( - default_factory=lambda: datetime.utcnow(), + default_factory=get_utc_now, description="Timestamp indicating when the log entry was created", ) diff --git a/src/memos/mem_scheduler/schemas/task_schemas.py b/src/memos/mem_scheduler/schemas/task_schemas.py index d189797ae..168a25b5d 100644 --- a/src/memos/mem_scheduler/schemas/task_schemas.py +++ b/src/memos/mem_scheduler/schemas/task_schemas.py @@ -7,6 +7,7 @@ from memos.log import get_logger from memos.mem_scheduler.general_modules.misc import DictConversionMixin +from memos.mem_scheduler.utils.db_utils import get_utc_now logger = get_logger(__name__) @@ -26,7 +27,7 @@ class 
RunningTaskItem(BaseModel, DictConversionMixin): mem_cube_id: str = Field(..., description="Required memory cube identifier", min_length=1) task_info: str = Field(..., description="Information about the task being executed") task_name: str = Field(..., description="Name/type of the task handler") - start_time: datetime = Field(description="Task start time", default_factory=datetime.utcnow) + start_time: datetime = Field(description="Task start time", default_factory=get_utc_now) end_time: datetime | None = Field(default=None, description="Task completion time") status: str = Field(default="running", description="Task status: running, completed, failed") result: Any | None = Field(default=None, description="Task execution result") @@ -37,13 +38,13 @@ class RunningTaskItem(BaseModel, DictConversionMixin): def mark_completed(self, result: Any | None = None) -> None: """Mark task as completed with optional result.""" - self.end_time = datetime.utcnow() + self.end_time = get_utc_now() self.status = "completed" self.result = result def mark_failed(self, error_message: str) -> None: """Mark task as failed with error message.""" - self.end_time = datetime.utcnow() + self.end_time = get_utc_now() self.status = "failed" self.error_message = error_message diff --git a/src/memos/mem_scheduler/utils/db_utils.py b/src/memos/mem_scheduler/utils/db_utils.py index 5d7cc52c3..4c7402a9d 100644 --- a/src/memos/mem_scheduler/utils/db_utils.py +++ b/src/memos/mem_scheduler/utils/db_utils.py @@ -1,5 +1,22 @@ import os import sqlite3 +import sys + +from datetime import datetime, timezone + + +# Compatibility handling: Python 3.11+ supports UTC, earlier versions use timezone.utc +if sys.version_info >= (3, 11): + from datetime import UTC + + def get_utc_now(): + """Get current UTC datetime with compatibility for different Python versions""" + return datetime.now(UTC) +else: + + def get_utc_now(): + """Get current UTC datetime with compatibility for different Python versions""" + return datetime.now(timezone.utc) def print_db_tables(db_path: str): diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index 8865c2232..b240f4369 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -67,39 +67,42 @@ def initialize_rabbitmq( """ Establish connection to RabbitMQ using pika. 
""" - from pika.adapters.select_connection import SelectConnection - - if config is None: - if config_path is None and AuthConfig.default_config_exists(): - auth_config = AuthConfig.from_local_config() - elif Path(config_path).exists(): - auth_config = AuthConfig.from_local_config(config_path=config_path) + try: + from pika.adapters.select_connection import SelectConnection + + if config is None: + if config_path is None and AuthConfig.default_config_exists(): + auth_config = AuthConfig.from_local_config() + elif Path(config_path).exists(): + auth_config = AuthConfig.from_local_config(config_path=config_path) + else: + logger.error("Fail to initialize auth_config") + return + self.rabbitmq_config = auth_config.rabbitmq + elif isinstance(config, RabbitMQConfig): + self.rabbitmq_config = config + elif isinstance(config, dict): + self.rabbitmq_config = AuthConfig.from_dict(config).rabbitmq else: - logger.error("Fail to initialize auth_config") - return - self.rabbitmq_config = auth_config.rabbitmq - elif isinstance(config, RabbitMQConfig): - self.rabbitmq_config = config - elif isinstance(config, dict): - self.rabbitmq_config = AuthConfig.from_dict(config).rabbitmq - else: - logger.error("Not implemented") - - # Start connection process - parameters = self.get_rabbitmq_connection_param() - self.rabbitmq_connection = SelectConnection( - parameters, - on_open_callback=self.on_rabbitmq_connection_open, - on_open_error_callback=self.on_rabbitmq_connection_error, - on_close_callback=self.on_rabbitmq_connection_closed, - ) + logger.error("Not implemented") + + # Start connection process + parameters = self.get_rabbitmq_connection_param() + self.rabbitmq_connection = SelectConnection( + parameters, + on_open_callback=self.on_rabbitmq_connection_open, + on_open_error_callback=self.on_rabbitmq_connection_error, + on_close_callback=self.on_rabbitmq_connection_closed, + ) - # Start IOLoop in dedicated thread - self._io_loop_thread = threading.Thread( - target=self.rabbitmq_connection.ioloop.start, daemon=True - ) - self._io_loop_thread.start() - logger.info("RabbitMQ connection process started") + # Start IOLoop in dedicated thread + self._io_loop_thread = threading.Thread( + target=self.rabbitmq_connection.ioloop.start, daemon=True + ) + self._io_loop_thread.start() + logger.info("RabbitMQ connection process started") + except Exception: + logger.error("Fail to initialize auth_config", exc_info=True) def get_rabbitmq_queue_size(self) -> int: """Get the current number of messages in the queue. 
diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py index 5b04ec280..239557bc9 100644 --- a/src/memos/mem_scheduler/webservice_modules/redis_service.py +++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py @@ -1,5 +1,8 @@ import asyncio +import os +import subprocess import threading +import time from collections.abc import Callable from typing import Any @@ -27,10 +30,14 @@ def __init__(self): super().__init__() # settings for redis - self.redis_host: str = None - self.redis_port: int = None - self.redis_db: int = None + self.redis_host: str | None = None + self.redis_port: int | None = None + self.redis_db: int | None = None + self.redis_password: str | None = None + self.socket_timeout: float | None = None + self.socket_connect_timeout: float | None = None self._redis_conn = None + self._local_redis_process = None self.query_list_capacity = 1000 self._redis_listener_running = False @@ -46,19 +53,40 @@ def redis(self, value: Any) -> None: self._redis_conn = value def initialize_redis( - self, redis_host: str = "localhost", redis_port: int = 6379, redis_db: int = 0 + self, + redis_host: str = "localhost", + redis_port: int = 6379, + redis_db: int = 0, + redis_password: str | None = None, + socket_timeout: float | None = None, + socket_connect_timeout: float | None = None, ): import redis self.redis_host = redis_host self.redis_port = redis_port self.redis_db = redis_db + self.redis_password = redis_password + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout try: logger.debug(f"Connecting to Redis at {redis_host}:{redis_port}/{redis_db}") - self._redis_conn = redis.Redis( - host=self.redis_host, port=self.redis_port, db=self.redis_db, decode_responses=True - ) + redis_kwargs = { + "host": self.redis_host, + "port": self.redis_port, + "db": self.redis_db, + "password": redis_password, + "decode_responses": True, + } + + # Add timeout parameters if provided + if socket_timeout is not None: + redis_kwargs["socket_timeout"] = socket_timeout + if socket_connect_timeout is not None: + redis_kwargs["socket_connect_timeout"] = socket_connect_timeout + + self._redis_conn = redis.Redis(**redis_kwargs) # test conn if not self._redis_conn.ping(): logger.error("Redis connection failed") @@ -68,6 +96,183 @@ def initialize_redis( self._redis_conn.xtrim("user:queries:stream", self.query_list_capacity) return self._redis_conn + @require_python_package( + import_name="redis", + install_command="pip install redis", + install_link="https://redis.readthedocs.io/en/stable/", + ) + def auto_initialize_redis(self) -> bool: + """ + Auto-initialize Redis with fallback strategies: + 1. Try to initialize from config + 2. Try to initialize from environment variables + 3. 
Try to start local Redis server as fallback + + Returns: + bool: True if Redis connection is successfully established, False otherwise + """ + import redis + + # Strategy 1: Try to initialize from config + if hasattr(self, "config") and hasattr(self.config, "redis_config"): + try: + redis_config = self.config.redis_config + logger.info("Attempting to initialize Redis from config") + + self._redis_conn = redis.Redis( + host=redis_config.get("host", "localhost"), + port=redis_config.get("port", 6379), + db=redis_config.get("db", 0), + password=redis_config.get("password", None), + decode_responses=True, + ) + + # Test connection + if self._redis_conn.ping(): + logger.info("Redis initialized successfully from config") + self.redis_host = redis_config.get("host", "localhost") + self.redis_port = redis_config.get("port", 6379) + self.redis_db = redis_config.get("db", 0) + self.redis_password = redis_config.get("password", None) + self.socket_timeout = redis_config.get("socket_timeout", None) + self.socket_connect_timeout = redis_config.get("socket_connect_timeout", None) + return True + else: + logger.warning("Redis config connection test failed") + self._redis_conn = None + except Exception as e: + logger.warning(f"Failed to initialize Redis from config: {e}") + self._redis_conn = None + + # Strategy 2: Try to initialize from environment variables + try: + redis_host = os.getenv("MEMSCHEDULER_REDIS_HOST", "localhost") + redis_port = int(os.getenv("MEMSCHEDULER_REDIS_PORT", "6379")) + redis_db = int(os.getenv("MEMSCHEDULER_REDIS_DB", "0")) + redis_password = os.getenv("MEMSCHEDULER_REDIS_PASSWORD", None) + socket_timeout = os.getenv("MEMSCHEDULER_REDIS_TIMEOUT", None) + socket_connect_timeout = os.getenv("MEMSCHEDULER_REDIS_CONNECT_TIMEOUT", None) + + logger.info( + f"Attempting to initialize Redis from environment variables: {redis_host}:{redis_port}" + ) + + redis_kwargs = { + "host": redis_host, + "port": redis_port, + "db": redis_db, + "password": redis_password, + "decode_responses": True, + } + + # Add timeout parameters if provided + if socket_timeout is not None: + try: + redis_kwargs["socket_timeout"] = float(socket_timeout) + except ValueError: + logger.warning( + f"Invalid MEMSCHEDULER_REDIS_TIMEOUT value: {socket_timeout}, ignoring" + ) + + if socket_connect_timeout is not None: + try: + redis_kwargs["socket_connect_timeout"] = float(socket_connect_timeout) + except ValueError: + logger.warning( + f"Invalid MEMSCHEDULER_REDIS_CONNECT_TIMEOUT value: {socket_connect_timeout}, ignoring" + ) + + self._redis_conn = redis.Redis(**redis_kwargs) + + # Test connection + if self._redis_conn.ping(): + logger.info("Redis initialized successfully from environment variables") + self.redis_host = redis_host + self.redis_port = redis_port + self.redis_db = redis_db + self.redis_password = redis_password + self.socket_timeout = float(socket_timeout) if socket_timeout is not None else None + self.socket_connect_timeout = ( + float(socket_connect_timeout) if socket_connect_timeout is not None else None + ) + return True + else: + logger.warning("Redis environment connection test failed") + self._redis_conn = None + except Exception as e: + logger.warning(f"Failed to initialize Redis from environment variables: {e}") + self._redis_conn = None + + # Strategy 3: Try to start local Redis server as fallback + try: + logger.warning( + "Attempting to start local Redis server as fallback (not recommended for production)" + ) + + # Try to start Redis server locally + self._local_redis_process = 
subprocess.Popen( + ["redis-server", "--port", "6379", "--daemonize", "no"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + preexec_fn=os.setsid if hasattr(os, "setsid") else None, + ) + + # Wait a moment for Redis to start + time.sleep(0.5) + + # Try to connect to local Redis + self._redis_conn = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True) + + # Test connection + if self._redis_conn.ping(): + logger.warning("Local Redis server started and connected successfully") + logger.warning("WARNING: Using local Redis server - not suitable for production!") + self.redis_host = "localhost" + self.redis_port = 6379 + self.redis_db = 0 + self.redis_password = None + self.socket_timeout = None + self.socket_connect_timeout = None + return True + else: + logger.error("Local Redis server connection test failed") + self._cleanup_local_redis() + return False + + except Exception as e: + logger.error(f"Failed to start local Redis server: {e}") + self._cleanup_local_redis() + return False + + def _cleanup_local_redis(self): + """Clean up local Redis process if it exists""" + if self._local_redis_process: + try: + self._local_redis_process.terminate() + self._local_redis_process.wait(timeout=5) + logger.info("Local Redis process terminated") + except subprocess.TimeoutExpired: + logger.warning("Local Redis process did not terminate gracefully, killing it") + self._local_redis_process.kill() + self._local_redis_process.wait() + except Exception as e: + logger.error(f"Error cleaning up local Redis process: {e}") + finally: + self._local_redis_process = None + + def _cleanup_redis_resources(self): + """Clean up Redis connection and local process""" + if self._redis_conn: + try: + self._redis_conn.close() + logger.info("Redis connection closed") + except Exception as e: + logger.error(f"Error closing Redis connection: {e}") + finally: + self._redis_conn = None + + self._cleanup_local_redis() + async def redis_add_message_stream(self, message: dict): logger.debug(f"add_message_stream: {message}") return self._redis_conn.xadd("user:queries:stream", message) @@ -150,7 +355,5 @@ def redis_stop_listening(self): logger.info("Redis stream listener stopped") def redis_close(self): - """Close Redis connection""" - if self._redis_conn is not None: - self._redis_conn.close() - self._redis_conn = None + """Close Redis connection and clean up resources""" + self._cleanup_redis_resources() diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py index 2fa08590f..98d611dbf 100644 --- a/src/memos/memories/activation/kv.py +++ b/src/memos/memories/activation/kv.py @@ -237,16 +237,36 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache: def move_dynamic_cache_htod(dynamic_cache: DynamicCache, device: str) -> DynamicCache: """ + Move DynamicCache from CPU to GPU device. + Compatible with both old and new transformers versions. + In SimpleMemChat.run(), if self.config.enable_activation_memory is enabled, we load serialized kv cache from a [class KVCacheMemory] object, which has a kv_cache_memories on CPU. So before inferring with DynamicCache, we should move it to GPU in-place first. 
""" - # Currently, we put this function outside [class KVCacheMemory] - for i in range(len(dynamic_cache.key_cache)): - if dynamic_cache.key_cache[i] is not None: - dynamic_cache.key_cache[i] = dynamic_cache.key_cache[i].to(device, non_blocking=True) - if dynamic_cache.value_cache[i] is not None: - dynamic_cache.value_cache[i] = dynamic_cache.value_cache[i].to( - device, non_blocking=True - ) + # Handle compatibility between old and new transformers versions + if hasattr(dynamic_cache, "layers"): + # New version: use layers attribute + for layer in dynamic_cache.layers: + if hasattr(layer, "key_cache") and layer.key_cache is not None: + layer.key_cache = layer.key_cache.to(device, non_blocking=True) + if hasattr(layer, "value_cache") and layer.value_cache is not None: + layer.value_cache = layer.value_cache.to(device, non_blocking=True) + elif hasattr(layer, "keys") and hasattr(layer, "values"): + # Alternative attribute names in some versions + if layer.keys is not None: + layer.keys = layer.keys.to(device, non_blocking=True) + if layer.values is not None: + layer.values = layer.values.to(device, non_blocking=True) + elif hasattr(dynamic_cache, "key_cache") and hasattr(dynamic_cache, "value_cache"): + # Old version: use key_cache and value_cache attributes + for i in range(len(dynamic_cache.key_cache)): + if dynamic_cache.key_cache[i] is not None: + dynamic_cache.key_cache[i] = dynamic_cache.key_cache[i].to( + device, non_blocking=True + ) + if dynamic_cache.value_cache[i] is not None: + dynamic_cache.value_cache[i] = dynamic_cache.value_cache[i].to( + device, non_blocking=True + ) return dynamic_cache diff --git a/tests/llms/test_hf.py b/tests/llms/test_hf.py index 8a266e58d..595995ad1 100644 --- a/tests/llms/test_hf.py +++ b/tests/llms/test_hf.py @@ -93,15 +93,50 @@ def test_build_kv_cache_and_generation(self): add_generation_prompt=True, ) llm = self._create_llm(config) + + # Ensure the mock model returns an object with past_key_values attribute + forward_output = MagicMock() + forward_output.logits = torch.ones(1, 1, 100) + + # Create a DynamicCache that's compatible with both old and new transformers versions + kv_cache = DynamicCache() + + # Mock the DynamicCache to have both old and new version attributes for compatibility + # New version uses 'layers' attribute + mock_layer = MagicMock() + mock_layer.key_cache = torch.tensor([[[[1.0, 2.0]]]]) + mock_layer.value_cache = torch.tensor([[[[3.0, 4.0]]]]) + kv_cache.layers = [mock_layer] + + # Old version uses 'key_cache' and 'value_cache' lists + kv_cache.key_cache = [torch.tensor([[[[1.0, 2.0]]]])] + kv_cache.value_cache = [torch.tensor([[[[3.0, 4.0]]]])] + + forward_output.past_key_values = kv_cache + # Make sure the mock model call returns the forward_output when called with **kwargs + self.mock_model.return_value = forward_output + kv_cache = llm.build_kv_cache("The capital of France is Paris.") self.assertIsInstance(kv_cache, DynamicCache) resp = llm.generate( [{"role": "user", "content": "What's its population?"}], past_key_values=kv_cache ) self.assertEqual(resp, self.standard_response) - first_kwargs = self.mock_model.call_args_list[0][1] - self.assertIs(first_kwargs["past_key_values"], kv_cache) - self.assertTrue(first_kwargs["use_cache"]) + # Check that the model was called with past_key_values during _prefill + # The model should be called multiple times during generation with cache + found_past_key_values = False + for call_args in self.mock_model.call_args_list: + if len(call_args) > 1 and "past_key_values" in 
call_args[1]: + found_past_key_values = True + break + self.assertTrue(found_past_key_values, "Model should be called with past_key_values") + # Check that use_cache was used + found_use_cache = False + for call_args in self.mock_model.call_args_list: + if len(call_args) > 1 and call_args[1].get("use_cache"): + found_use_cache = True + break + self.assertTrue(found_use_cache, "Model should be called with use_cache=True") def test_think_prefix_removal(self): config = HFLLMConfig( diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py index ed2093dea..e3064660b 100644 --- a/tests/mem_scheduler/test_dispatcher.py +++ b/tests/mem_scheduler/test_dispatcher.py @@ -233,7 +233,7 @@ def test_dispatch_parallel(self): self.assertEqual(len(label2_messages), 1) self.assertEqual(label2_messages[0].item_id, "msg2") - def test_group_messages_by_user_and_cube(self): + def test_group_messages_by_user_and_mem_cube(self): """Test grouping messages by user and cube.""" # Check actual grouping logic with patch("memos.mem_scheduler.general_modules.dispatcher.logger.debug"): @@ -261,47 +261,6 @@ def test_group_messages_by_user_and_cube(self): for msg in expected[user_id][cube_id]: self.assertIn(msg.item_id, [m.item_id for m in result[user_id][cube_id]]) - def test_thread_race(self): - """Test the ThreadRace integration.""" - - # Define test tasks - def task1(stop_flag): - time.sleep(0.1) - return "result1" - - def task2(stop_flag): - time.sleep(0.2) - return "result2" - - # Run competitive tasks - tasks = { - "task1": task1, - "task2": task2, - } - - result = self.dispatcher.run_competitive_tasks(tasks, timeout=1.0) - - # Verify the result - self.assertIsNotNone(result) - self.assertEqual(result[0], "task1") # task1 should win - self.assertEqual(result[1], "result1") - - def test_thread_race_timeout(self): - """Test ThreadRace with timeout.""" - - # Define a task that takes longer than the timeout - def slow_task(stop_flag): - time.sleep(0.5) - return "slow_result" - - tasks = {"slow": slow_task} - - # Run with a short timeout - result = self.dispatcher.run_competitive_tasks(tasks, timeout=0.1) - - # Verify no result was returned due to timeout - self.assertIsNone(result) - def test_thread_race_cooperative_termination(self): """Test that ThreadRace properly terminates slower threads when one completes.""" @@ -459,3 +418,190 @@ def test_dispatcher_monitor_logs_stuck_task_messages(self): self.assertIn("Messages: 2 items", expected_log) self.assertIn("Stuck message 1", expected_log) self.assertIn("Stuck message 2", expected_log) + + def test_get_running_tasks_no_filter(self): + """Test get_running_tasks without filter returns all running tasks.""" + # Create test tasks manually + task1 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube1", + task_info="Test task 1", + task_name="handler1", + ) + task2 = RunningTaskItem( + user_id="user2", + mem_cube_id="cube2", + task_info="Test task 2", + task_name="handler2", + ) + + # Add tasks to dispatcher's running tasks + with self.dispatcher._task_lock: + self.dispatcher._running_tasks[task1.item_id] = task1 + self.dispatcher._running_tasks[task2.item_id] = task2 + + # Get all running tasks + running_tasks = self.dispatcher.get_running_tasks() + + # Verify all tasks are returned + self.assertEqual(len(running_tasks), 2) + self.assertIn(task1.item_id, running_tasks) + self.assertIn(task2.item_id, running_tasks) + self.assertEqual(running_tasks[task1.item_id], task1) + self.assertEqual(running_tasks[task2.item_id], task2) + + # Clean 
up + with self.dispatcher._task_lock: + self.dispatcher._running_tasks.clear() + + def test_get_running_tasks_filter_by_user_id(self): + """Test get_running_tasks with user_id filter.""" + # Create test tasks with different user_ids + task1 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube1", + task_info="Test task 1", + task_name="handler1", + ) + task2 = RunningTaskItem( + user_id="user2", + mem_cube_id="cube2", + task_info="Test task 2", + task_name="handler2", + ) + task3 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube3", + task_info="Test task 3", + task_name="handler3", + ) + + # Add tasks to dispatcher's running tasks + with self.dispatcher._task_lock: + self.dispatcher._running_tasks[task1.item_id] = task1 + self.dispatcher._running_tasks[task2.item_id] = task2 + self.dispatcher._running_tasks[task3.item_id] = task3 + + # Filter by user_id + user1_tasks = self.dispatcher.get_running_tasks(lambda task: task.user_id == "user1") + + # Verify only user1 tasks are returned + self.assertEqual(len(user1_tasks), 2) + self.assertIn(task1.item_id, user1_tasks) + self.assertIn(task3.item_id, user1_tasks) + self.assertNotIn(task2.item_id, user1_tasks) + + # Clean up + with self.dispatcher._task_lock: + self.dispatcher._running_tasks.clear() + + def test_get_running_tasks_filter_by_multiple_conditions(self): + """Test get_running_tasks with multiple filter conditions.""" + # Create test tasks with different attributes + task1 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube1", + task_info="Test task 1", + task_name="test_handler", + ) + task2 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube2", + task_info="Test task 2", + task_name="other_handler", + ) + task3 = RunningTaskItem( + user_id="user2", + mem_cube_id="cube1", + task_info="Test task 3", + task_name="test_handler", + ) + + # Add tasks to dispatcher's running tasks + with self.dispatcher._task_lock: + self.dispatcher._running_tasks[task1.item_id] = task1 + self.dispatcher._running_tasks[task2.item_id] = task2 + self.dispatcher._running_tasks[task3.item_id] = task3 + + # Filter by multiple conditions: user_id == "user1" AND task_name == "test_handler" + filtered_tasks = self.dispatcher.get_running_tasks( + lambda task: task.user_id == "user1" and task.task_name == "test_handler" + ) + + # Verify only task1 matches both conditions + self.assertEqual(len(filtered_tasks), 1) + self.assertIn(task1.item_id, filtered_tasks) + self.assertNotIn(task2.item_id, filtered_tasks) + self.assertNotIn(task3.item_id, filtered_tasks) + + # Clean up + with self.dispatcher._task_lock: + self.dispatcher._running_tasks.clear() + + def test_get_running_tasks_filter_by_status(self): + """Test get_running_tasks with status filter.""" + # Create test tasks with different statuses + task1 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube1", + task_info="Test task 1", + task_name="handler1", + ) + task2 = RunningTaskItem( + user_id="user2", + mem_cube_id="cube2", + task_info="Test task 2", + task_name="handler2", + ) + + # Manually set different statuses + task1.status = "running" + task2.status = "completed" + + # Add tasks to dispatcher's running tasks + with self.dispatcher._task_lock: + self.dispatcher._running_tasks[task1.item_id] = task1 + self.dispatcher._running_tasks[task2.item_id] = task2 + + # Filter by status + running_status_tasks = self.dispatcher.get_running_tasks( + lambda task: task.status == "running" + ) + + # Verify only running tasks are returned + self.assertEqual(len(running_status_tasks), 1) + 
self.assertIn(task1.item_id, running_status_tasks) + self.assertNotIn(task2.item_id, running_status_tasks) + + # Clean up + with self.dispatcher._task_lock: + self.dispatcher._running_tasks.clear() + + def test_get_running_tasks_thread_safety(self): + """Test get_running_tasks is thread-safe.""" + # Create test task + task1 = RunningTaskItem( + user_id="user1", + mem_cube_id="cube1", + task_info="Test task 1", + task_name="handler1", + ) + + # Add task to dispatcher's running tasks + with self.dispatcher._task_lock: + self.dispatcher._running_tasks[task1.item_id] = task1 + + # Get running tasks (should work without deadlock) + running_tasks = self.dispatcher.get_running_tasks() + + # Verify task is returned + self.assertEqual(len(running_tasks), 1) + self.assertIn(task1.item_id, running_tasks) + + # Test with filter (should also work without deadlock) + filtered_tasks = self.dispatcher.get_running_tasks(lambda task: task.user_id == "user1") + self.assertEqual(len(filtered_tasks), 1) + + # Clean up + with self.dispatcher._task_lock: + self.dispatcher._running_tasks.clear() diff --git a/tests/mem_scheduler/test_orm.py b/tests/mem_scheduler/test_orm.py index ddf4fea8b..a43231e4a 100644 --- a/tests/mem_scheduler/test_orm.py +++ b/tests/mem_scheduler/test_orm.py @@ -13,6 +13,7 @@ DBManagerForMemoryMonitorManager, DBManagerForQueryMonitorQueue, ) +from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager from memos.mem_scheduler.schemas.monitor_schemas import ( MemoryMonitorItem, MemoryMonitorManager, @@ -297,3 +298,150 @@ def test_concurrent_access(self, temp_db, query_queue_obj): manager1.close() manager2.close() + + +class TestRedisDBManager: + """Test class for RedisDBManager functionality""" + + @pytest.fixture + def memory_manager_obj(self): + """Create a MemoryMonitorManager object for testing""" + return MemoryMonitorManager( + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + memories=[ + MemoryMonitorItem( + item_id="redis-test-123", + memory_text="Redis test memory", + tree_memory_item=None, + tree_memory_item_mapping_key="redis_test_key", + keywords_score=0.8, + sorting_score=0.9, + importance_score=0.7, + recording_count=3, + ) + ], + ) + + @pytest.fixture + def mock_redis_client(self): + """Create a mock Redis client for testing""" + try: + from unittest.mock import MagicMock + + # Create a mock Redis client + mock_client = MagicMock() + + # Mock Redis data storage + mock_data = {} + + def mock_set(key, value, nx=False, ex=None, **kwargs): + if nx and key in mock_data: + # NX means "only set if not exists" + return False # Redis returns False when NX fails + mock_data[key] = value + return True + + def mock_get(key): + return mock_data.get(key) + + def mock_hset(key, mapping=None, **kwargs): + if key not in mock_data: + mock_data[key] = {} + if mapping: + mock_data[key].update(mapping) + if kwargs: + mock_data[key].update(kwargs) + return len(mapping) if mapping else len(kwargs) + + def mock_hgetall(key): + return mock_data.get(key, {}) + + def mock_delete(*keys): + deleted = 0 + for key in keys: + if key in mock_data: + del mock_data[key] + deleted += 1 + return deleted + + def mock_keys(pattern): + import fnmatch + + return [key for key in mock_data if fnmatch.fnmatch(key, pattern)] + + def mock_ping(): + return True + + def mock_close(): + pass + + # Configure mock methods + mock_client.set = mock_set + mock_client.get = mock_get + mock_client.hset = mock_hset + mock_client.hgetall = mock_hgetall + mock_client.delete = mock_delete + mock_client.keys = 
mock_keys + mock_client.ping = mock_ping + mock_client.close = mock_close + + return mock_client + + except ImportError: + pytest.skip("Redis package not available for testing") + + @pytest.fixture + def redis_manager(self, mock_redis_client, memory_manager_obj): + """Create RedisDBManager instance with mock Redis client""" + manager = RedisDBManager( + user_id=TEST_USER_ID, + mem_cube_id=TEST_MEM_CUBE_ID, + obj=memory_manager_obj, + lock_timeout=10, + redis_client=mock_redis_client, + ) + yield manager + manager.close() + + def test_redis_manager_initialization(self, mock_redis_client): + """Test RedisDBManager initialization""" + manager = RedisDBManager( + user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID, redis_client=mock_redis_client + ) + + assert manager.user_id == TEST_USER_ID + assert manager.mem_cube_id == TEST_MEM_CUBE_ID + assert manager.redis_client is mock_redis_client + assert manager.orm_class.__name__ == "RedisLockableORM" + assert manager.obj_class == MemoryMonitorManager + + manager.close() + + def test_redis_lockable_orm_save_load(self, mock_redis_client): + """Test RedisLockableORM save and load operations""" + from memos.mem_scheduler.orm_modules.redis_model import RedisLockableORM + + orm = RedisLockableORM( + redis_client=mock_redis_client, user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID + ) + + # Test save + orm.serialized_data = '{"test": "data"}' + orm.version_control = "1" + orm.lock_acquired = True + orm.lock_expiry = datetime.now() + + orm.save() + + # Test load + new_orm = RedisLockableORM( + redis_client=mock_redis_client, user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID + ) + + exists = new_orm.load() + assert exists + assert new_orm.serialized_data == '{"test": "data"}' + assert new_orm.version_control == "1" + # Note: lock_acquired is False after load by design - locks are managed separately + assert not new_orm.lock_acquired diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py index 15338006d..369b4a6f1 100644 --- a/tests/mem_scheduler/test_scheduler.py +++ b/tests/mem_scheduler/test_scheduler.py @@ -26,6 +26,7 @@ ) from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, + ScheduleMessageItem, ) from memos.memories.textual.tree import TreeTextMemory @@ -36,6 +37,9 @@ class TestGeneralScheduler(unittest.TestCase): + # Control whether to run activation memory tests that require GPU, default is False + RUN_ACTIVATION_MEMORY_TESTS = True + def _create_mock_auth_config(self): """Create a mock AuthConfig for testing purposes.""" # Create mock configs with valid test values @@ -68,6 +72,19 @@ def setUp(self): self.llm = MagicMock(spec=BaseLLM) self.mem_cube = MagicMock(spec=GeneralMemCube) self.tree_text_memory = MagicMock(spec=TreeTextMemory) + # Add memory_manager mock to prevent AttributeError in scheduler_logger + self.tree_text_memory.memory_manager = MagicMock() + self.tree_text_memory.memory_manager.memory_size = { + "LongTermMemory": 10000, + "UserMemory": 10000, + "WorkingMemory": 20, + } + # Mock get_current_memory_size method + self.tree_text_memory.get_current_memory_size.return_value = { + "LongTermMemory": 100, + "UserMemory": 50, + "WorkingMemory": 10, + } self.mem_cube.text_mem = self.tree_text_memory self.mem_cube.act_mem = MagicMock() @@ -185,8 +202,72 @@ def test_scheduler_startup_mode_thread(self): # Stop the scheduler self.scheduler.stop() - # Verify cleanup - self.assertFalse(self.scheduler._running) + def test_redis_message_queue(self): + """Test Redis message queue 
functionality for sending and receiving messages.""" + import asyncio + import time + + from unittest.mock import MagicMock, patch + + # Mock Redis connection and operations + mock_redis = MagicMock() + mock_redis.xadd = MagicMock(return_value=b"1234567890-0") + + # Track received messages + received_messages = [] + + def redis_handler(messages: list[ScheduleMessageItem]) -> None: + """Handler for Redis messages.""" + received_messages.extend(messages) + + # Register Redis handler + redis_label = "test_redis" + handlers = {redis_label: redis_handler} + self.scheduler.register_handlers(handlers) + + # Enable Redis queue for this test + with ( + patch.object(self.scheduler, "use_redis_queue", True), + patch.object(self.scheduler, "_redis_conn", mock_redis), + ): + # Start scheduler + self.scheduler.start() + + # Create test message for Redis + redis_message = ScheduleMessageItem( + label=redis_label, + content="Redis test message", + user_id="redis_user", + mem_cube_id="redis_cube", + mem_cube="redis_mem_cube_obj", + timestamp=datetime.now(), + ) + + # Submit message to Redis queue + asyncio.run(self.scheduler.submit_messages(redis_message)) + + # Verify Redis xadd was called + mock_redis.xadd.assert_called_once() + call_args = mock_redis.xadd.call_args + self.assertEqual(call_args[0][0], "user:queries:stream") + + # Verify message data was serialized correctly + message_data = call_args[0][1] + self.assertEqual(message_data["label"], redis_label) + self.assertEqual(message_data["content"], "Redis test message") + self.assertEqual(message_data["user_id"], "redis_user") + self.assertEqual(message_data["cube_id"], "redis_cube") # Note: to_dict uses cube_id + + # Simulate Redis message consumption + # This would normally be handled by the Redis consumer in the scheduler + time.sleep(0.1) # Brief wait for async operations + + # Stop scheduler + self.scheduler.stop() + + print("Redis message queue test completed successfully!") + + # Removed test_robustness method - was too time-consuming for CI/CD pipeline def test_scheduler_startup_mode_process(self): """Test scheduler with process startup mode.""" @@ -219,3 +300,284 @@ def test_scheduler_startup_mode_constants(self): """Test that startup mode constants are properly defined.""" self.assertEqual(STARTUP_BY_THREAD, "thread") self.assertEqual(STARTUP_BY_PROCESS, "process") + + def test_activation_memory_update(self): + """Test activation memory update functionality with DynamicCache handling.""" + if not self.RUN_ACTIVATION_MEMORY_TESTS: + self.skipTest( + "Skipping activation memory test. Set RUN_ACTIVATION_MEMORY_TESTS=True to enable." 
+ ) + + from unittest.mock import Mock + + from transformers import DynamicCache + + from memos.memories.activation.kv import KVCacheMemory + + # Mock the mem_cube with activation memory + mock_kv_cache_memory = Mock(spec=KVCacheMemory) + self.mem_cube.act_mem = mock_kv_cache_memory + + # Mock get_all to return empty list (no existing cache items) + mock_kv_cache_memory.get_all.return_value = [] + + # Create a mock DynamicCache with layers attribute + mock_cache = Mock(spec=DynamicCache) + mock_cache.layers = [] + + # Create mock layers with key_cache and value_cache + for _ in range(2): # Simulate 2 layers + mock_layer = Mock() + mock_layer.key_cache = Mock() + mock_layer.value_cache = Mock() + mock_cache.layers.append(mock_layer) + + # Mock the extract method to return a KVCacheItem + mock_cache_item = Mock() + mock_cache_item.records = Mock() + mock_cache_item.records.text_memories = [] + mock_cache_item.records.timestamp = None + mock_kv_cache_memory.extract.return_value = mock_cache_item + + # Test data + test_memories = ["Test memory 1", "Test memory 2"] + user_id = "test_user" + mem_cube_id = "test_cube" + + # Call the method under test + try: + self.scheduler.update_activation_memory( + new_memories=test_memories, + label=QUERY_LABEL, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=self.mem_cube, + ) + + # Verify that extract was called + mock_kv_cache_memory.extract.assert_called_once() + + # Verify that add was called with the extracted cache item + mock_kv_cache_memory.add.assert_called_once() + + # Verify that dump was called + mock_kv_cache_memory.dump.assert_called_once() + + print("✅ Activation memory update test passed - DynamicCache layers handled correctly") + + except Exception as e: + self.fail(f"Activation memory update failed: {e}") + + def test_dynamic_cache_layers_access(self): + """Test DynamicCache layers attribute access for compatibility.""" + if not self.RUN_ACTIVATION_MEMORY_TESTS: + self.skipTest( + "Skipping activation memory test. Set RUN_ACTIVATION_MEMORY_TESTS=True to enable." 
+ ) + + from unittest.mock import Mock + + from transformers import DynamicCache + + # Create a real DynamicCache instance + cache = DynamicCache() + + # Check if it has layers attribute (may vary by transformers version) + if hasattr(cache, "layers"): + self.assertIsInstance(cache.layers, list, "DynamicCache.layers should be a list") + + # Test with mock layers + mock_layer = Mock() + mock_layer.key_cache = Mock() + mock_layer.value_cache = Mock() + cache.layers.append(mock_layer) + + # Verify we can access layer attributes + self.assertEqual(len(cache.layers), 1) + self.assertTrue(hasattr(cache.layers[0], "key_cache")) + self.assertTrue(hasattr(cache.layers[0], "value_cache")) + + print("✅ DynamicCache layers access test passed") + else: + # If layers attribute doesn't exist, verify our fix handles this case + print("⚠️ DynamicCache doesn't have 'layers' attribute in this transformers version") + print("✅ Test passed - our code should handle this gracefully") + + def test_get_running_tasks_with_filter(self): + """Test get_running_tasks method with filter function.""" + # Mock dispatcher and its get_running_tasks method + mock_task_item1 = MagicMock() + mock_task_item1.item_id = "task_1" + mock_task_item1.user_id = "user_1" + mock_task_item1.mem_cube_id = "cube_1" + mock_task_item1.task_info = {"type": "query"} + mock_task_item1.task_name = "test_task_1" + mock_task_item1.start_time = datetime.now() + mock_task_item1.end_time = None + mock_task_item1.status = "running" + mock_task_item1.result = None + mock_task_item1.error_message = None + mock_task_item1.messages = [] + + # Define a filter function + def user_filter(task): + return task.user_id == "user_1" + + # Mock the filtered result (only task_1 matches the filter) + with patch.object( + self.scheduler.dispatcher, "get_running_tasks", return_value={"task_1": mock_task_item1} + ) as mock_get_running_tasks: + # Call get_running_tasks with filter + result = self.scheduler.get_running_tasks(filter_func=user_filter) + + # Verify result + self.assertIsInstance(result, dict) + self.assertIn("task_1", result) + self.assertEqual(len(result), 1) + + # Verify dispatcher method was called with filter + mock_get_running_tasks.assert_called_once_with(filter_func=user_filter) + + def test_get_running_tasks_empty_result(self): + """Test get_running_tasks method when no tasks are running.""" + # Mock dispatcher to return empty dict + with patch.object( + self.scheduler.dispatcher, "get_running_tasks", return_value={} + ) as mock_get_running_tasks: + # Call get_running_tasks + result = self.scheduler.get_running_tasks() + + # Verify empty result + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 0) + + # Verify dispatcher method was called + mock_get_running_tasks.assert_called_once_with(filter_func=None) + + def test_get_running_tasks_no_dispatcher(self): + """Test get_running_tasks method when dispatcher is None.""" + # Temporarily set dispatcher to None + original_dispatcher = self.scheduler.dispatcher + self.scheduler.dispatcher = None + + # Call get_running_tasks + result = self.scheduler.get_running_tasks() + + # Verify empty result and warning behavior + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 0) + + # Restore dispatcher + self.scheduler.dispatcher = original_dispatcher + + def test_get_running_tasks_multiple_tasks(self): + """Test get_running_tasks method with multiple tasks.""" + # Mock multiple task items + mock_task_item1 = MagicMock() + mock_task_item1.item_id = "task_1" + 
mock_task_item1.user_id = "user_1" + mock_task_item1.mem_cube_id = "cube_1" + mock_task_item1.task_info = {"type": "query"} + mock_task_item1.task_name = "test_task_1" + mock_task_item1.start_time = datetime.now() + mock_task_item1.end_time = None + mock_task_item1.status = "running" + mock_task_item1.result = None + mock_task_item1.error_message = None + mock_task_item1.messages = [] + + mock_task_item2 = MagicMock() + mock_task_item2.item_id = "task_2" + mock_task_item2.user_id = "user_2" + mock_task_item2.mem_cube_id = "cube_2" + mock_task_item2.task_info = {"type": "answer"} + mock_task_item2.task_name = "test_task_2" + mock_task_item2.start_time = datetime.now() + mock_task_item2.end_time = None + mock_task_item2.status = "completed" + mock_task_item2.result = "success" + mock_task_item2.error_message = None + mock_task_item2.messages = ["message1", "message2"] + + with patch.object( + self.scheduler.dispatcher, + "get_running_tasks", + return_value={"task_1": mock_task_item1, "task_2": mock_task_item2}, + ) as mock_get_running_tasks: + # Call get_running_tasks + result = self.scheduler.get_running_tasks() + + # Verify result structure + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 2) + self.assertIn("task_1", result) + self.assertIn("task_2", result) + + # Verify task_1 details + task1_dict = result["task_1"] + self.assertEqual(task1_dict["item_id"], "task_1") + self.assertEqual(task1_dict["user_id"], "user_1") + self.assertEqual(task1_dict["status"], "running") + + # Verify task_2 details + task2_dict = result["task_2"] + self.assertEqual(task2_dict["item_id"], "task_2") + self.assertEqual(task2_dict["user_id"], "user_2") + self.assertEqual(task2_dict["status"], "completed") + self.assertEqual(task2_dict["result"], "success") + self.assertEqual(task2_dict["messages"], ["message1", "message2"]) + + # Verify dispatcher method was called + mock_get_running_tasks.assert_called_once_with(filter_func=None) + + def test_message_handler_receives_submitted_message(self): + """Test that handlers receive messages after scheduler startup and message submission.""" + # Create a mock handler that tracks received messages + received_messages = [] + + def mock_handler(messages: list[ScheduleMessageItem]) -> None: + """Mock handler that records received messages.""" + received_messages.extend(messages) + + # Register the mock handler + test_label = "test_handler" + handlers = {test_label: mock_handler} + self.scheduler.register_handlers(handlers) + + # Verify handler is registered + self.assertIn(test_label, self.scheduler.handlers) + self.assertEqual(self.scheduler.handlers[test_label], mock_handler) + + # Start the scheduler + self.scheduler.start() + + # Create and submit a test message + test_message = ScheduleMessageItem( + label=test_label, + content="Test message content", + user_id="test_user", + mem_cube_id="test_mem_cube", + mem_cube="test_mem_cube_obj", # Required field - can be string or GeneralMemCube + timestamp=datetime.now(), + ) + + import asyncio + + asyncio.run(self.scheduler.submit_messages(test_message)) + + # Wait for message processing to complete + import time + + time.sleep(2.0) # Allow sufficient time for message processing + + # Verify the handler received the message + self.assertEqual( + len(received_messages), 1, f"Expected 1 message, got {len(received_messages)}" + ) + self.assertEqual(received_messages[0].label, test_label) + self.assertEqual(received_messages[0].content, "Test message content") + self.assertEqual(received_messages[0].user_id, 
"test_user") + self.assertEqual(received_messages[0].mem_cube_id, "test_mem_cube") + + # Stop the scheduler + self.scheduler.stop() diff --git a/tests/test_hello_world.py b/tests/test_hello_world.py index 986839bc9..e9c81c7f0 100644 --- a/tests/test_hello_world.py +++ b/tests/test_hello_world.py @@ -118,6 +118,8 @@ def test_memos_yuqingchen_hello_world_logger_called(): def test_memos_chen_tang_hello_world(): + import warnings + from memos.memories.textual.general import GeneralTextMemory # Define return values for os.getenv @@ -130,7 +132,10 @@ def mock_getenv(key, default=None): } return mock_values.get(key, default) - # Use patch to mock os.getenv - with patch("os.getenv", side_effect=mock_getenv): - memory = memos_chentang_hello_world() - assert isinstance(memory, GeneralTextMemory) + # Filter Pydantic serialization warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") + # Use patch to mock os.getenv + with patch("os.getenv", side_effect=mock_getenv): + memory = memos_chentang_hello_world() + assert isinstance(memory, GeneralTextMemory)