Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions veadk/integrations/ve_identity/auth_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from google.genai import types
from google.adk.auth.auth_credential import OAuth2Auth

from veadk.integrations.ve_identity.auth_config import _get_default_region
from veadk.processors.base_run_processor import BaseRunProcessor
from veadk.integrations.ve_identity.identity_client import IdentityClient
from veadk.integrations.ve_identity.models import AuthRequestConfig, OAuth2AuthPoller
Expand Down Expand Up @@ -178,6 +179,10 @@ def __init__(self, *, config: Optional[AuthRequestConfig] = None):
f"Please open this URL in your browser to authorize: {url}"
)
)
# Use provided region or get from config
if self.config.region is None:
self.config.region = _get_default_region()

self._identity_client = self.config.identity_client or IdentityClient(
region=self.config.region
)
Expand Down
6 changes: 5 additions & 1 deletion veadk/integrations/ve_identity/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from google.adk.tools.tool_context import ToolContext
from google.adk.agents.readonly_context import ReadonlyContext

from veadk.integrations.ve_identity.auth_config import _get_default_region
from veadk.utils.logger import get_logger

from veadk.integrations.ve_identity.identity_client import IdentityClient
Expand Down Expand Up @@ -49,13 +50,16 @@ class WorkloadTokenManager:
def __init__(
self,
identity_client: IdentityClient = None,
region: Optional[str] = "cn-beijing",
region: Optional[str] = None,
):
"""Initialize the token manager.

Args:
identity_client: The IdentityClient instance to use for token requests.
"""
if region is None:
region = _get_default_region()

self._identity_client = identity_client or IdentityClient(region=region)

def _build_cache_key(
Expand Down
37 changes: 31 additions & 6 deletions veadk/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from veadk.config import getenv
from veadk.evaluation import EvalSetRecorder
from veadk.memory.short_term_memory import ShortTermMemory
from veadk.processors.base_run_processor import BaseRunProcessor
from veadk.types import MediaMessage
from veadk.utils.logger import get_logger
from veadk.utils.misc import formatted_timestamp, read_file_to_bytes
Expand Down Expand Up @@ -418,6 +419,7 @@ def __init__(
app_name: str = "veadk_default_app",
user_id: str = "veadk_default_user",
upload_inline_data_to_tos: bool = False,
run_processor: "BaseRunProcessor | None" = None,
*args,
**kwargs,
) -> None:
Expand All @@ -438,6 +440,8 @@ def __init__(
app_name (str): Application name. Defaults to ``"veadk_default_app"``.
user_id (str): Default user ID. Defaults to ``"veadk_default_user"``.
upload_inline_data_to_tos (bool): Whether to enable inline media upload. Defaults to ``False``.
run_processor (BaseRunProcessor | None): Optional run processor for intercepting agent execution.
If not provided, will try to get from agent. If agent doesn't have one, uses NoOpRunProcessor.
*args: Positional args passed through to ``ADKRunner``.
**kwargs: Keyword args passed through to ``ADKRunner``; may include
``session_service`` and ``memory_service`` to override defaults.
Expand All @@ -456,6 +460,16 @@ def __init__(
session_service = kwargs.pop("session_service", None)
memory_service = kwargs.pop("memory_service", None)

# Handle run_processor: priority is runner arg > agent.run_processor > NoOpRunProcessor
if run_processor is not None:
self.run_processor = run_processor
elif hasattr(agent, "run_processor") and agent.run_processor is not None: # type: ignore
self.run_processor = agent.run_processor # type: ignore
else:
from veadk.processors import NoOpRunProcessor

self.run_processor = NoOpRunProcessor()

if session_service:
if short_term_memory:
logger.warning(
Expand Down Expand Up @@ -511,6 +525,7 @@ async def run(
run_config: RunConfig | None = None,
save_tracing_data: bool = False,
upload_inline_data_to_tos: bool = False,
run_processor: "BaseRunProcessor | None" = None,
):
"""Run a conversation with multi-turn text and multimodal inputs.

Expand All @@ -527,6 +542,8 @@ async def run(
config is created using the environment var ``MODEL_AGENT_MAX_LLM_CALLS``.
save_tracing_data (bool): Whether to dump tracing data to disk after the run. Defaults to ``False``.
upload_inline_data_to_tos (bool): Whether to enable media upload only for this run. Defaults to ``False``.
run_processor (BaseRunProcessor | None): Optional run processor to use for this run.
If not provided, uses the runner's default run_processor. Defaults to None.

Returns:
str: The textual output from the last event, if present; otherwise an empty string.
Expand Down Expand Up @@ -567,12 +584,20 @@ async def run(
final_output = ""
for converted_message in converted_messages:
try:
async for event in self.run_async(
user_id=user_id,
session_id=session_id,
new_message=converted_message,
run_config=run_config,
):

@(run_processor or self.run_processor).process_run(
runner=self, message=converted_message
)
async def event_generator():
async for event in self.run_async(
user_id=user_id,
session_id=session_id,
new_message=converted_message,
run_config=run_config,
):
yield event

async for event in event_generator():
if event.get_function_calls():
for function_call in event.get_function_calls():
logger.debug(f"Function call: {function_call}")
Expand Down