Skip to content
3 changes: 1 addition & 2 deletions src/google/adk/plugins/global_instruction_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ async def before_model_callback(
return None

# Resolve the global instruction (handle both string and InstructionProvider)
readonly_context = ReadonlyContext(callback_context.invocation_context)
final_global_instruction = await self._resolve_global_instruction(
readonly_context
callback_context
)

if not final_global_instruction:
Expand Down
12 changes: 6 additions & 6 deletions tests/unittests/plugins/test_global_instruction_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def test_global_instruction_plugin_with_string():
mock_invocation_context.session = mock_session

mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context.invocation_context = mock_invocation_context
mock_callback_context._invocation_context = mock_invocation_context

llm_request = LlmRequest(
model="gemini-1.5-flash",
Expand Down Expand Up @@ -80,10 +80,10 @@ async def build_global_instruction(readonly_context: ReadonlyContext) -> str:
)

mock_invocation_context = Mock(spec=InvocationContext)
mock_invocation_context.session = mock_session

mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context.invocation_context = mock_invocation_context
mock_callback_context._invocation_context = mock_invocation_context
mock_callback_context.session = mock_session

llm_request = LlmRequest(
model="gemini-1.5-flash",
Expand Down Expand Up @@ -119,7 +119,7 @@ async def test_global_instruction_plugin_empty_instruction():
mock_invocation_context.session = mock_session

mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context.invocation_context = mock_invocation_context
mock_callback_context._invocation_context = mock_invocation_context

llm_request = LlmRequest(
model="gemini-1.5-flash",
Expand Down Expand Up @@ -156,7 +156,7 @@ async def test_global_instruction_plugin_leads_existing():
mock_invocation_context.session = mock_session

mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context.invocation_context = mock_invocation_context
mock_callback_context._invocation_context = mock_invocation_context

llm_request = LlmRequest(
model="gemini-1.5-flash",
Expand Down Expand Up @@ -191,7 +191,7 @@ async def test_global_instruction_plugin_prepends_to_list():
mock_invocation_context.session = mock_session

mock_callback_context = Mock(spec=CallbackContext)
mock_callback_context.invocation_context = mock_invocation_context
mock_callback_context._invocation_context = mock_invocation_context

llm_request = LlmRequest(
model="gemini-1.5-flash",
Expand Down