diff --git a/veadk/integrations/ve_identity/auth_processor.py b/veadk/integrations/ve_identity/auth_processor.py index 289dc0b..096a10d 100644 --- a/veadk/integrations/ve_identity/auth_processor.py +++ b/veadk/integrations/ve_identity/auth_processor.py @@ -25,6 +25,7 @@ from google.genai import types from google.adk.auth.auth_credential import OAuth2Auth +from veadk.integrations.ve_identity.auth_config import _get_default_region from veadk.processors.base_run_processor import BaseRunProcessor from veadk.integrations.ve_identity.identity_client import IdentityClient from veadk.integrations.ve_identity.models import AuthRequestConfig, OAuth2AuthPoller @@ -178,6 +179,10 @@ def __init__(self, *, config: Optional[AuthRequestConfig] = None): f"Please open this URL in your browser to authorize: {url}" ) ) + # Use provided region or get from config + if self.config.region is None: + self.config.region = _get_default_region() + self._identity_client = self.config.identity_client or IdentityClient( region=self.config.region ) diff --git a/veadk/integrations/ve_identity/token_manager.py b/veadk/integrations/ve_identity/token_manager.py index a101aef..8477a32 100644 --- a/veadk/integrations/ve_identity/token_manager.py +++ b/veadk/integrations/ve_identity/token_manager.py @@ -22,6 +22,7 @@ from google.adk.tools.tool_context import ToolContext from google.adk.agents.readonly_context import ReadonlyContext +from veadk.integrations.ve_identity.auth_config import _get_default_region from veadk.utils.logger import get_logger from veadk.integrations.ve_identity.identity_client import IdentityClient @@ -49,13 +50,16 @@ class WorkloadTokenManager: def __init__( self, identity_client: IdentityClient = None, - region: Optional[str] = "cn-beijing", + region: Optional[str] = None, ): """Initialize the token manager. Args: identity_client: The IdentityClient instance to use for token requests. """ + if region is None: + region = _get_default_region() + self._identity_client = identity_client or IdentityClient(region=region) def _build_cache_key( diff --git a/veadk/runner.py b/veadk/runner.py index c61bb52..ebd357a 100644 --- a/veadk/runner.py +++ b/veadk/runner.py @@ -32,6 +32,7 @@ from veadk.config import getenv from veadk.evaluation import EvalSetRecorder from veadk.memory.short_term_memory import ShortTermMemory +from veadk.processors.base_run_processor import BaseRunProcessor from veadk.types import MediaMessage from veadk.utils.logger import get_logger from veadk.utils.misc import formatted_timestamp, read_file_to_bytes @@ -418,6 +419,7 @@ def __init__( app_name: str = "veadk_default_app", user_id: str = "veadk_default_user", upload_inline_data_to_tos: bool = False, + run_processor: "BaseRunProcessor | None" = None, *args, **kwargs, ) -> None: @@ -438,6 +440,8 @@ def __init__( app_name (str): Application name. Defaults to ``"veadk_default_app"``. user_id (str): Default user ID. Defaults to ``"veadk_default_user"``. upload_inline_data_to_tos (bool): Whether to enable inline media upload. Defaults to ``False``. + run_processor (BaseRunProcessor | None): Optional run processor for intercepting agent execution. + If not provided, will try to get from agent. If agent doesn't have one, uses NoOpRunProcessor. *args: Positional args passed through to ``ADKRunner``. **kwargs: Keyword args passed through to ``ADKRunner``; may include ``session_service`` and ``memory_service`` to override defaults. @@ -456,6 +460,16 @@ def __init__( session_service = kwargs.pop("session_service", None) memory_service = kwargs.pop("memory_service", None) + # Handle run_processor: priority is runner arg > agent.run_processor > NoOpRunProcessor + if run_processor is not None: + self.run_processor = run_processor + elif hasattr(agent, "run_processor") and agent.run_processor is not None: # type: ignore + self.run_processor = agent.run_processor # type: ignore + else: + from veadk.processors import NoOpRunProcessor + + self.run_processor = NoOpRunProcessor() + if session_service: if short_term_memory: logger.warning( @@ -511,6 +525,7 @@ async def run( run_config: RunConfig | None = None, save_tracing_data: bool = False, upload_inline_data_to_tos: bool = False, + run_processor: "BaseRunProcessor | None" = None, ): """Run a conversation with multi-turn text and multimodal inputs. @@ -527,6 +542,8 @@ async def run( config is created using the environment var ``MODEL_AGENT_MAX_LLM_CALLS``. save_tracing_data (bool): Whether to dump tracing data to disk after the run. Defaults to ``False``. upload_inline_data_to_tos (bool): Whether to enable media upload only for this run. Defaults to ``False``. + run_processor (BaseRunProcessor | None): Optional run processor to use for this run. + If not provided, uses the runner's default run_processor. Defaults to None. Returns: str: The textual output from the last event, if present; otherwise an empty string. @@ -567,12 +584,20 @@ async def run( final_output = "" for converted_message in converted_messages: try: - async for event in self.run_async( - user_id=user_id, - session_id=session_id, - new_message=converted_message, - run_config=run_config, - ): + + @(run_processor or self.run_processor).process_run( + runner=self, message=converted_message + ) + async def event_generator(): + async for event in self.run_async( + user_id=user_id, + session_id=session_id, + new_message=converted_message, + run_config=run_config, + ): + yield event + + async for event in event_generator(): if event.get_function_calls(): for function_call in event.get_function_calls(): logger.debug(f"Function call: {function_call}")