diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/pyproject.toml b/pyproject.toml index f60494f255..ab19ef0810 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -210,6 +210,7 @@ known_third_party = ["google.adk"] [tool.pytest.ini_options] testpaths = ["tests"] +pythonpath = "src" asyncio_default_fixture_loop_scope = "function" asyncio_mode = "auto" diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 287fce6796..78da37f44d 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -14,6 +14,7 @@ from __future__ import annotations +import copy import importlib import json import logging @@ -35,10 +36,14 @@ from starlette.types import Lifespan from watchdog.observers import Observer +from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager +from ..memory.in_memory_memory_service import InMemoryMemoryService from ..runners import Runner +from ..sessions.in_memory_session_service import InMemorySessionService +from ..sessions.vertex_ai_session_service import VertexAiSessionService from .adk_web_server import AdkWebServer from .service_registry import load_services_module from .utils import envs @@ -489,7 +494,22 @@ def create_a2a_runner_loader(captured_app_name: str): """Factory function to create A2A runner with proper closure.""" async def _get_a2a_runner_async() -> Runner: - return await adk_web_server.get_runner_async(captured_app_name) + original_runner = await adk_web_server.get_runner_async( + captured_app_name + ) + # Check if the session service is Agent Engine session Service + if isinstance( + original_runner.session_service, VertexAiSessionService + ): + # VertexAiSessionService is not compliant with A2A (impossible to create session on the fly with contextID) + # So, change it to InMemorySessionService. Put the other service in memory because persistence do not make sense + runner = copy.copy(original_runner) + runner.session_service = InMemorySessionService() + runner.artifact_service = InMemoryArtifactService() + runner.memory_service = InMemoryMemoryService() + runner.credential_service = InMemoryCredentialService() + return runner + return original_runner return _get_a2a_runner_async diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 7643125d81..b884bc4048 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -17,6 +17,7 @@ import tempfile from unittest.mock import AsyncMock from unittest.mock import create_autospec +from unittest.mock import MagicMock from unittest.mock import Mock from unittest.mock import patch @@ -1002,7 +1003,7 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): mock_a2a_task, self.agent.name, self.mock_context, - self.mock_a2a_part_converter, + self.agent._a2a_part_converter, ) # Check the parts are updated as Thought assert result.content.parts[0].thought is True @@ -1770,7 +1771,7 @@ async def test_run_async_impl_successful_request(self): ) # Tuple with parts and context_id # Mock A2A client - mock_a2a_client = create_autospec(spec=A2AClient, instance=True) + mock_a2a_client = MagicMock(spec=A2AClient) mock_response = Mock() mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] @@ -1909,7 +1910,7 @@ async def test_run_async_impl_with_meta_provider(self): ) # Tuple with parts and context_id # Mock A2A client - mock_a2a_client = create_autospec(spec=A2AClient, instance=True) + mock_a2a_client = MagicMock(spec=A2AClient) mock_response = Mock() mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] @@ -2046,7 +2047,7 @@ async def test_run_async_impl_successful_request(self): ) # Tuple with parts and context_id # Mock A2A client - mock_a2a_client = create_autospec(spec=A2AClient, instance=True) + mock_a2a_client = MagicMock(spec=A2AClient) mock_response = Mock() mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 0c69605349..34c71d90bf 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -32,6 +32,8 @@ from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App from google.adk.artifacts.base_artifact_service import ArtifactVersion +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService from google.adk.cli import fast_api as fast_api_module from google.adk.cli.fast_api import get_fast_api_app from google.adk.errors.input_validation_error import InputValidationError @@ -42,10 +44,12 @@ from google.adk.evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager from google.adk.events.event import Event from google.adk.events.event_actions import EventActions +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.session import Session from google.adk.sessions.state import State +from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService from google.genai import types from pydantic import BaseModel import pytest @@ -1291,6 +1295,107 @@ def test_a2a_agent_discovery(test_app_with_a2a): logger.info("A2A agent discovery test passed") +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) +def test_a2a_runner_factory_creates_isolated_runner(temp_agents_dir_with_a2a): + """Verify the A2A runner factory creates a copy of the runner with in-memory services.""" + # 1. Setup Mocks for the original runner and its services + original_runner = Runner( + agent=MagicMock(), + app_name="test_app", + session_service=VertexAiSessionService(), + ) + original_runner.memory_service = MagicMock() + original_runner.artifact_service = MagicMock() + original_runner.credential_service = MagicMock() + + # Mock the AdkWebServer to control the runner it returns + mock_web_server_instance = MagicMock() + mock_web_server_instance.get_runner_async = AsyncMock( + return_value=original_runner + ) + # The factory captures the app_name, so we need to mock list_agents + mock_web_server_instance.list_agents.return_value = ["test_a2a_agent"] + + # 2. Patch dependencies in the fast_api module + with ( + patch("google.adk.cli.fast_api.AdkWebServer") as mock_web_server, + patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, + patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store, + patch( + "google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor" + ) as mock_executor, + patch( + "a2a.server.request_handlers.DefaultRequestHandler" + ) as mock_handler, + patch("a2a.types.AgentCard") as mock_agent_card, + patch("a2a.utils.constants.AGENT_CARD_WELL_KNOWN_PATH", "/agent.json"), + ): + mock_web_server.return_value = mock_web_server_instance + mock_task_store.return_value = MagicMock() + mock_executor.return_value = MagicMock() + mock_handler.return_value = MagicMock() + mock_agent_card.return_value = MagicMock() + + # Change to temp directory + original_cwd = os.getcwd() + os.chdir(temp_agents_dir_with_a2a) + try: + # 3. Call get_fast_api_app to trigger the factory creation + get_fast_api_app( + agents_dir=".", + web=False, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=[], + a2a=True, # Enable A2A to create the factory + host="127.0.0.1", + port=8000, + ) + finally: + os.chdir(original_cwd) + + # 4. Capture the factory from the mocked A2aAgentExecutor + assert mock_executor.call_args is not None, "A2aAgentExecutor not called" + kwargs = mock_executor.call_args.kwargs + assert "runner" in kwargs + runner_factory = kwargs["runner"] + + # 5. Execute the factory to get the new runner + # Since runner_factory is an async function, we need to run it. + # We run it in a separate thread to avoid event loop conflicts if an event loop is already running. + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=1) as executor: + a2a_runner = executor.submit(asyncio.run, runner_factory()).result() + + # 6. Assert that the new runner is a separate, modified copy + assert a2a_runner is not original_runner, "Runner should be a copy" + + # Assert that services have been replaced with InMemory versions + assert isinstance(a2a_runner.memory_service, InMemoryMemoryService) + assert isinstance(a2a_runner.session_service, InMemorySessionService) + assert isinstance(a2a_runner.artifact_service, InMemoryArtifactService) + assert isinstance(a2a_runner.credential_service, InMemoryCredentialService) + + # Assert that the original runner's services are unchanged + assert not isinstance(original_runner.memory_service, InMemoryMemoryService) + assert not isinstance( + original_runner.session_service, InMemorySessionService + ) + assert not isinstance( + original_runner.artifact_service, InMemoryArtifactService + ) + assert not isinstance( + original_runner.credential_service, InMemoryCredentialService + ) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) def test_a2a_disabled_by_default(test_app): """Test that A2A functionality is disabled by default.""" # The regular test_app fixture has a2a=False