From 4312a9274683c513ff5be74a960dfaf7d783be52 Mon Sep 17 00:00:00 2001 From: Sudhendra Date: Sat, 17 Jan 2026 12:33:03 -0600 Subject: [PATCH 01/10] Initial implementation of CodingAgent for adk contribution --- .github/prompts/plan-codingAgent.prompt.md | 149 +++++ contributing/samples/coding_agent/__init__.py | 17 + contributing/samples/coding_agent/agent.py | 179 ++++++ src/google/adk/agents/__init__.py | 26 +- src/google/adk/agents/coding_agent.py | 548 ++++++++++++++++++ src/google/adk/agents/coding_agent_config.py | 226 ++++++++ src/google/adk/code_executors/__init__.py | 159 +++-- .../adk/code_executors/allowlist_validator.py | 354 +++++++++++ .../coding_agent_code_executor.py | 505 ++++++++++++++++ .../adk/code_executors/tool_code_generator.py | 469 +++++++++++++++ .../code_executors/tool_execution_server.py | 366 ++++++++++++ src/google/adk/telemetry/tracing.py | 329 ++++++++--- tests/unittests/agents/test_coding_agent.py | 309 ++++++++++ .../test_allowlist_validator.py | 321 ++++++++++ .../test_tool_code_generator.py | 320 ++++++++++ 15 files changed, 4141 insertions(+), 136 deletions(-) create mode 100644 .github/prompts/plan-codingAgent.prompt.md create mode 100644 contributing/samples/coding_agent/__init__.py create mode 100644 contributing/samples/coding_agent/agent.py create mode 100644 src/google/adk/agents/coding_agent.py create mode 100644 src/google/adk/agents/coding_agent_config.py create mode 100644 src/google/adk/code_executors/allowlist_validator.py create mode 100644 src/google/adk/code_executors/coding_agent_code_executor.py create mode 100644 src/google/adk/code_executors/tool_code_generator.py create mode 100644 src/google/adk/code_executors/tool_execution_server.py create mode 100644 tests/unittests/agents/test_coding_agent.py create mode 100644 tests/unittests/code_executors/test_allowlist_validator.py create mode 100644 tests/unittests/code_executors/test_tool_code_generator.py diff --git a/.github/prompts/plan-codingAgent.prompt.md b/.github/prompts/plan-codingAgent.prompt.md new file mode 100644 index 0000000000..11454da141 --- /dev/null +++ b/.github/prompts/plan-codingAgent.prompt.md @@ -0,0 +1,149 @@ +# Plan: CodingAgent Implementation for ADK-Python + +Create a production-ready `CodingAgent` class that generates Python code to execute tools via ReAct loop, with HTTP-based tool injection, dual-layer security (allowlist + container), configurable statefulness with history re-execution, and full ADK telemetry integration. + +## Steps + +1. **Create `CodingAgentConfig`** in [src/google/adk/agents/coding_agent_config.py](src/google/adk/agents/coding_agent_config.py) - Pydantic config extending `BaseAgentConfig` with fields: `model`, `instruction`, `tools`, `code_executor`, `authorized_imports` (frozenset), `max_iterations` (default 10), `error_retry_attempts` (default 2), `stateful` (default False), `tool_server_host` (default `host.docker.internal`, fallback `172.17.0.1`), `tool_server_port` (default 8765). + +2. **Create `CodingAgentState`** in [src/google/adk/agents/coding_agent.py](src/google/adk/agents/coding_agent.py) - State extending `BaseAgentState` with: `iteration_count`, `error_count`, `execution_history` (list of `ExecutionStep` with `code`, `result`, `tool_traces`, `success` fields for re-execution optimization). + +3. **Create `ToolCodeGenerator`** in [src/google/adk/code_executors/tool_code_generator.py](src/google/adk/code_executors/tool_code_generator.py) - Functions: `generate_runtime_header()` (HTTP client + `_call_adk_tool()` + trace collection), `generate_tool_stubs()` (typed function stubs from `BaseTool._get_declaration()`), `generate_final_answer_stub()`, `generate_system_prompt()` (tool docs + 1-2 few-shot examples showing `tool_code` format + `final_answer()` usage). + +4. **Create `AllowlistValidator`** in [src/google/adk/code_executors/allowlist_validator.py](src/google/adk/code_executors/allowlist_validator.py) - `validate_imports()` using AST extraction, `DEFAULT_SAFE_IMPORTS` frozenset, `ImportValidationError` with violation details, `is_import_allowed()` helper supporting wildcards (e.g., `collections.*`). + +5. **Create `ToolExecutionServer`** in [src/google/adk/code_executors/tool_execution_server.py](src/google/adk/code_executors/tool_execution_server.py) - FastAPI server with `POST /tool_call` routing to `BaseTool.run_async()` with full `ToolContext`, `GET /tool_trace`, lifecycle `start()/stop()`, configurable host detection (`host.docker.internal` → `172.17.0.1` fallback). + +6. **Create `CodingAgentCodeExecutor`** in [src/google/adk/code_executors/coding_agent_code_executor.py](src/google/adk/code_executors/coding_agent_code_executor.py) - Composable wrapper with: tool stub prepending, allowlist pre-validation, server lifecycle, history re-execution (skip successful steps via hash comparison), trace extraction (`__TOOL_TRACE__:`), final answer detection (`__FINAL_ANSWER__:`). + +7. **Create `CodingAgent` class** in [src/google/adk/agents/coding_agent.py](src/google/adk/agents/coding_agent.py) - `_run_async_impl()` ReAct loop: build system prompt with few-shot examples, call `canonical_model`, parse code blocks, validate imports, execute via `CodingAgentCodeExecutor`, detect `final_answer()` OR no-code fallback, yield events with `state_delta`, retry errors with LLM feedback up to `error_retry_attempts`. + +8. **Add telemetry** in [src/google/adk/telemetry/tracing.py](src/google/adk/telemetry/tracing.py) - Add `trace_code_generation()`, `trace_code_execution()`, `trace_import_validation()`, `trace_tool_ipc()` following existing patterns with code content, duration, and error attributes. + +9. **Update exports** in [src/google/adk/agents/__init__.py](src/google/adk/agents/__init__.py) and [src/google/adk/code_executors/__init__.py](src/google/adk/code_executors/__init__.py) - Add all new classes to `__all__` with lazy loading for executor components. + +10. **Create comprehensive tests** in `tests/unittests/agents/test_coding_agent.py` and `tests/unittests/code_executors/test_coding_agent_*.py` - Cover: ReAct loop, final answer detection + fallback, allowlist validation, error retry, stateful history re-execution with skip optimization, tool traces, host fallback logic. + +11. **Create sample agent** in `contributing/samples/coding_agent/` - Example with `web_search`, `calculator`, `read_file` tools demonstrating multi-step code generation with `ContainerCodeExecutor`. + +## Key Implementation Details + +### Few-shot examples in system prompt + +```python +SYSTEM_PROMPT_EXAMPLES = ''' +Example 1 - Using tools: +```tool_code +result = web_search(query="Python async best practices") +print(result["snippets"][0]) +``` + +Example 2 - Final answer: +```tool_code +data = read_file(path="data.csv") +total = sum(float(row["amount"]) for row in data["rows"]) +final_answer(f"The total amount is ${total:.2f}") +``` +''' +``` + +### History re-execution optimization + +```python +def _should_skip_step(self, step: ExecutionStep, code_hash: str) -> bool: + """Skip if code unchanged and previously succeeded.""" + return step.success and step.code_hash == code_hash +``` + +### Host detection with fallback + +```python +def _resolve_tool_server_host(self) -> str: + if self.tool_server_host: + return self.tool_server_host + # Try host.docker.internal first (Docker Desktop) + # Fallback to 172.17.0.1 (Linux bridge network) + return detect_docker_host_address() +``` + +## Architecture Diagram + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ CodingAgent._run_async_impl() │ +│ ┌─────────────┐ ┌──────────────┐ ┌─────────────────────┐ │ +│ │ Build prompt│──▶│ Call LLM │──▶│ Parse code blocks │ │ +│ │ + tool docs │ │ (canonical_ │ │ (delimiters) │ │ +│ └─────────────┘ │ model) │ └─────────┬───────────┘ │ +│ └──────────────┘ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ CodingAgentCodeExecutor.execute_code() │ │ +│ │ ┌─────────────┐ ┌──────────────┐ ┌───────────────┐ │ │ +│ │ │ Validate │─▶│ Prepend tool │─▶│ Execute in │ │ │ +│ │ │ imports │ │ stubs + │ │ container │ │ │ +│ │ │ (allowlist) │ │ runtime │ │ │ │ │ +│ │ └─────────────┘ └──────────────┘ └───────┬───────┘ │ │ +│ └─────────────────────────────────────────────┼───────────┘ │ +│ │ │ +│ ┌──────────────────────────────────────┘ │ +│ │ HTTP IPC (host.docker.internal) │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ ToolExecutionServer (FastAPI) │ │ +│ │ POST /tool_call ──▶ BaseTool.run_async(ToolContext) │ │ +│ │ GET /tool_trace ──▶ call_traces[] │ │ +│ └─────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────┐ ┌─────────────────────────────────────┐ │ +│ │ Check final_ │◀──│ Extract traces + clean stdout │ │ +│ │ answer() OR │ │ (__TOOL_TRACE__, __FINAL_ANSWER__) │ │ +│ │ fallback │ └─────────────────────────────────────┘ │ +│ └────────┬────────┘ │ +│ │ if done: yield final Event │ +│ │ else: feed result back to LLM (loop) │ +└───────────┴─────────────────────────────────────────────────────┘ +``` + +## File Structure + +``` +src/google/adk/ +├── agents/ +│ ├── coding_agent.py # CodingAgent + CodingAgentState +│ ├── coding_agent_config.py # CodingAgentConfig +│ └── __init__.py # Updated exports +├── code_executors/ +│ ├── tool_code_generator.py # Stub generation + system prompt +│ ├── allowlist_validator.py # Import validation +│ ├── tool_execution_server.py # FastAPI IPC server +│ ├── coding_agent_code_executor.py # Main executor wrapper +│ └── __init__.py # Updated exports +└── telemetry/ + └── tracing.py # New trace functions + +tests/unittests/ +├── agents/ +│ └── test_coding_agent.py +└── code_executors/ + ├── test_tool_code_generator.py + ├── test_allowlist_validator.py + └── test_coding_agent_code_executor.py + +contributing/samples/ +└── coding_agent/ + ├── __init__.py + └── agent.py +``` + +## Design Decisions + +| Decision | Choice | Rationale | +|----------|--------|-----------| +| Tool injection | HTTP IPC (Code Prepending) | Native ADK integration, full `ToolContext` access, async support | +| Security | Allowlist + Container | Defense-in-depth: import validation before container isolation | +| Final answer | Explicit `final_answer()` + fallback | Reliability with graceful degradation | +| Stateful mode | Re-execute history | Safer than pickle, with skip optimization for speed | +| Async tools | Sync wrapper via host server | Host handles async natively, container code stays simple | +| Docker host | Configurable with fallback | `host.docker.internal` → `172.17.0.1` for cross-platform | +| Retries | Default 2 with LLM feedback | Matches `BaseCodeExecutor.error_retry_attempts` pattern | diff --git a/contributing/samples/coding_agent/__init__.py b/contributing/samples/coding_agent/__init__.py new file mode 100644 index 0000000000..373c27ec12 --- /dev/null +++ b/contributing/samples/coding_agent/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample CodingAgent that demonstrates code generation and tool usage.""" + +from . import agent diff --git a/contributing/samples/coding_agent/agent.py b/contributing/samples/coding_agent/agent.py new file mode 100644 index 0000000000..d0f6bae10c --- /dev/null +++ b/contributing/samples/coding_agent/agent.py @@ -0,0 +1,179 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample CodingAgent demonstrating code generation with tool usage. + +This sample shows how to create a CodingAgent that can: +- Generate Python code to solve tasks +- Call tools as Python functions from within the generated code +- Execute code in a sandboxed container environment +- Provide final answers after multi-step reasoning + +Prerequisites: +- Docker must be installed and running +- Set GOOGLE_API_KEY or configure Vertex AI credentials + +Usage: + adk run contributing/samples/coding_agent + adk web contributing/samples + +Example queries: +- "What is 15% of 847?" +- "Calculate the compound interest on $10,000 at 5% annual rate for 3 years" +- "Search for the latest Python release and summarize the key features" +""" + +from google.adk.agents import CodingAgent +from google.adk.code_executors import ContainerCodeExecutor + + +# Define sample tools that the CodingAgent can use +def calculator(expression: str) -> dict: + """Evaluate a mathematical expression. + + Args: + expression: A mathematical expression to evaluate (e.g., "2 + 2 * 3"). + + Returns: + Dictionary with the result or error message. + """ + try: + # Safe evaluation of mathematical expressions + allowed_names = { + "abs": abs, + "round": round, + "min": min, + "max": max, + "sum": sum, + "pow": pow, + } + result = eval(expression, {"__builtins__": {}}, allowed_names) + return {"result": result, "expression": expression} + except Exception as e: + return {"error": str(e), "expression": expression} + + +def web_search(query: str, max_results: int = 5) -> dict: + """Search the web for information. + + Args: + query: The search query. + max_results: Maximum number of results to return. + + Returns: + Dictionary with search results. + """ + # This is a mock implementation for demonstration + # In production, you would integrate with a real search API + return { + "query": query, + "results": [ + { + "title": f"Result {i + 1} for: {query}", + "snippet": f"This is a sample result snippet for '{query}'...", + "url": f"https://example.com/result{i + 1}", + } + for i in range(min(max_results, 3)) + ], + "total_results": max_results, + } + + +def read_file(path: str) -> dict: + """Read contents of a file. + + Args: + path: Path to the file to read. + + Returns: + Dictionary with file contents or error. + """ + # This is a mock implementation for demonstration + # In production, you would implement actual file reading with proper security + mock_files = { + "data.csv": { + "content": "name,amount\nAlice,100\nBob,200\nCharlie,150", + "rows": [ + {"name": "Alice", "amount": "100"}, + {"name": "Bob", "amount": "200"}, + {"name": "Charlie", "amount": "150"}, + ], + }, + "config.json": { + "content": '{"setting": "value"}', + "data": {"setting": "value"}, + }, + } + + if path in mock_files: + return {"path": path, **mock_files[path]} + return {"error": f"File not found: {path}", "path": path} + + +def get_current_time() -> dict: + """Get the current date and time. + + Returns: + Dictionary with current timestamp information. + """ + from datetime import datetime + + now = datetime.now() + return { + "timestamp": now.isoformat(), + "year": now.year, + "month": now.month, + "day": now.day, + "hour": now.hour, + "minute": now.minute, + "weekday": now.strftime("%A"), + } + + +# Create the CodingAgent with tools +root_agent = CodingAgent( + name="code_assistant", + description=( + "An AI assistant that solves tasks by writing and executing Python code. " + "It can perform calculations, search for information, read files, and more." + ), + model="gemini-2.5-flash", + instruction=""" +You are a helpful coding assistant that solves problems by writing Python code. + +When given a task: +1. Think about what tools and computations you need +2. Write clear, well-commented Python code +3. Use the available tools as needed +4. Print intermediate results to verify your work +5. Call final_answer() with your result + +Always show your reasoning through code comments and print statements. +If a task cannot be completed with the available tools, explain why. +""", + tools=[ + calculator, + web_search, + read_file, + get_current_time, + ], + # Use ContainerCodeExecutor for sandboxed execution + # Note: Docker must be installed and running + code_executor=ContainerCodeExecutor( + image="python:3.11-slim", + ), + max_iterations=10, + error_retry_attempts=2, + stateful=False, +) diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py index 35198179a5..b718513135 100644 --- a/src/google/adk/agents/__init__.py +++ b/src/google/adk/agents/__init__.py @@ -13,6 +13,8 @@ # limitations under the License. from .base_agent import BaseAgent +from .coding_agent import CodingAgent +from .coding_agent import CodingAgentState from .invocation_context import InvocationContext from .live_request_queue import LiveRequest from .live_request_queue import LiveRequestQueue @@ -25,15 +27,17 @@ from .sequential_agent import SequentialAgent __all__ = [ - 'Agent', - 'BaseAgent', - 'LlmAgent', - 'LoopAgent', - 'McpInstructionProvider', - 'ParallelAgent', - 'SequentialAgent', - 'InvocationContext', - 'LiveRequest', - 'LiveRequestQueue', - 'RunConfig', + "Agent", + "BaseAgent", + "CodingAgent", + "CodingAgentState", + "LlmAgent", + "LoopAgent", + "McpInstructionProvider", + "ParallelAgent", + "SequentialAgent", + "InvocationContext", + "LiveRequest", + "LiveRequestQueue", + "RunConfig", ] diff --git a/src/google/adk/agents/coding_agent.py b/src/google/adk/agents/coding_agent.py new file mode 100644 index 0000000000..46ccfcb8ed --- /dev/null +++ b/src/google/adk/agents/coding_agent.py @@ -0,0 +1,548 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CodingAgent - An agent that generates and executes Python code. + +This module provides the CodingAgent class, which implements a ReAct-style +agent that generates Python code to accomplish tasks using available tools. +""" + +from __future__ import annotations + +import logging +import re +from typing import Any +from typing import AsyncGenerator +from typing import Callable +from typing import ClassVar +from typing import Dict +from typing import FrozenSet +from typing import List +from typing import Optional +from typing import Type +from typing import Union + +from google.genai import types +from pydantic import Field +from pydantic import model_validator +from typing_extensions import override + +from ..code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS +from ..code_executors.base_code_executor import BaseCodeExecutor +from ..code_executors.code_execution_utils import CodeExecutionInput +from ..code_executors.code_execution_utils import CodeExecutionResult +from ..code_executors.code_execution_utils import CodeExecutionUtils +from ..code_executors.coding_agent_code_executor import CodingAgentCodeExecutor +from ..code_executors.coding_agent_code_executor import CodingAgentExecutionResult +from ..code_executors.tool_code_generator import generate_system_prompt +from ..events.event import Event +from ..events.event_actions import EventActions +from ..models.base_llm import BaseLlm +from ..models.llm_request import LlmRequest +from ..models.llm_response import LlmResponse +from ..models.registry import LLMRegistry +from ..tools.base_tool import BaseTool +from ..tools.base_toolset import BaseToolset +from ..tools.function_tool import FunctionTool +from ..tools.tool_context import ToolContext +from ..utils.feature_decorator import experimental +from .base_agent import BaseAgent +from .base_agent import BaseAgentState +from .base_agent_config import BaseAgentConfig +from .coding_agent_config import CodingAgentConfig +from .invocation_context import InvocationContext +from .readonly_context import ReadonlyContext + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class CodingAgentState(BaseAgentState): + """State for CodingAgent tracking execution progress. + + Attributes: + iteration_count: Number of ReAct loop iterations completed. + error_count: Number of consecutive errors encountered. + execution_history: List of execution steps with code, results, and traces. + """ + + iteration_count: int = 0 + error_count: int = 0 + execution_history: List[Dict[str, Any]] = Field(default_factory=list) + + +ToolUnion = Union[Callable[..., Any], BaseTool, BaseToolset] + + +async def _convert_tool_union_to_tools( + tool_union: ToolUnion, + ctx: Optional[ReadonlyContext] = None, +) -> List[BaseTool]: + """Convert a tool union to a list of BaseTool instances. + + Args: + tool_union: A callable, BaseTool, or BaseToolset. + ctx: Optional context for toolset resolution. + + Returns: + List of BaseTool instances. + """ + if isinstance(tool_union, BaseTool): + return [tool_union] + if callable(tool_union): + return [FunctionTool(func=tool_union)] + # BaseToolset + if ctx: + return await tool_union.get_tools_with_prefix(ctx) + return await tool_union.get_tools_with_prefix(None) + + +@experimental +class CodingAgent(BaseAgent): + """Agent that generates Python code to solve tasks using available tools. + + CodingAgent implements a ReAct-style loop where it: + 1. Receives a task from the user + 2. Generates Python code that calls available tools + 3. Executes the code in a sandboxed environment + 4. Processes the results and either provides an answer or continues + + Tools are made available as Python functions that the generated code + can call. The code execution happens in a container for security, + with tool calls routed via HTTP to the host. + + Attributes: + model: The LLM model to use for code generation. + instruction: Additional instructions for the agent. + tools: List of tools available to the agent. + code_executor: The underlying code executor (e.g., ContainerCodeExecutor). + authorized_imports: Set of allowed Python imports. + max_iterations: Maximum ReAct loop iterations. + error_retry_attempts: Number of retries on execution errors. + stateful: Whether to maintain state across iterations. + tool_server_host: Host for the tool execution server. + tool_server_port: Port for the tool execution server. + """ + + DEFAULT_MODEL: ClassVar[str] = "gemini-2.5-flash" + + config_type: ClassVar[Type[BaseAgentConfig]] = CodingAgentConfig + + model: Union[str, BaseLlm] = "" + """The model to use for code generation.""" + + instruction: str = "" + """Additional instructions for the agent.""" + + tools: List[ToolUnion] = Field(default_factory=list) + """Tools available to the agent.""" + + code_executor: Optional[BaseCodeExecutor] = None + """The underlying code executor. If not set, uses ContainerCodeExecutor.""" + + authorized_imports: FrozenSet[str] = DEFAULT_SAFE_IMPORTS + """Set of allowed import patterns.""" + + max_iterations: int = 10 + """Maximum number of ReAct loop iterations.""" + + error_retry_attempts: int = 2 + """Number of retries on execution errors.""" + + stateful: bool = False + """Whether to maintain state across iterations.""" + + tool_server_host: Optional[str] = None + """Host for the tool execution server.""" + + tool_server_port: int = 8765 + """Port for the tool execution server.""" + + # Internal state + _coding_executor: Optional[CodingAgentCodeExecutor] = None + _resolved_tools: Optional[List[BaseTool]] = None + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + @property + def canonical_model(self) -> BaseLlm: + """Get the resolved model as BaseLlm.""" + if isinstance(self.model, BaseLlm): + return self.model + elif self.model: + return LLMRegistry.new_llm(self.model) + else: + # Find model from ancestors + ancestor_agent = self.parent_agent + while ancestor_agent is not None: + if hasattr(ancestor_agent, "canonical_model"): + return ancestor_agent.canonical_model + ancestor_agent = ancestor_agent.parent_agent + return LLMRegistry.new_llm(self.DEFAULT_MODEL) + + async def _resolve_tools( + self, + ctx: Optional[ReadonlyContext] = None, + ) -> List[BaseTool]: + """Resolve tool unions to BaseTool instances. + + Args: + ctx: Optional context for toolset resolution. + + Returns: + List of resolved BaseTool instances. + """ + if self._resolved_tools is not None: + return self._resolved_tools + + resolved = [] + for tool_union in self.tools: + resolved.extend(await _convert_tool_union_to_tools(tool_union, ctx)) + + self._resolved_tools = resolved + return resolved + + async def _get_coding_executor( + self, + ctx: InvocationContext, + ) -> CodingAgentCodeExecutor: + """Get or create the CodingAgentCodeExecutor. + + Args: + ctx: The invocation context. + + Returns: + The configured code executor. + """ + if self._coding_executor is not None: + return self._coding_executor + + # Resolve tools + tools = await self._resolve_tools(ReadonlyContext(ctx)) + + # Get or create underlying executor + if self.code_executor: + underlying = self.code_executor + else: + # Default to ContainerCodeExecutor + try: + from ..code_executors.container_code_executor import ( + ContainerCodeExecutor, + ) + + underlying = ContainerCodeExecutor( + image="python:3.11-slim", + ) + except ImportError as e: + raise ImportError( + "CodingAgent requires ContainerCodeExecutor. " + 'Please install with: pip install "google-adk[extensions]" ' + "or provide a custom code_executor." + ) from e + + # Create the CodingAgentCodeExecutor wrapper + self._coding_executor = CodingAgentCodeExecutor( + underlying_executor=underlying, + tools=tools, + authorized_imports=self.authorized_imports, + tool_server_host=self.tool_server_host, + tool_server_port=self.tool_server_port, + stateful=self.stateful, + error_retry_attempts=self.error_retry_attempts, + ) + + return self._coding_executor + + def _build_system_prompt(self, tools: List[BaseTool]) -> str: + """Build the system prompt with tool documentation. + + Args: + tools: List of available tools. + + Returns: + The complete system prompt. + """ + return generate_system_prompt( + tools=tools, + custom_instruction=self.instruction, + ) + + def _extract_code_block(self, response_text: str) -> Optional[str]: + """Extract code from the model response. + + Args: + response_text: The model's response text. + + Returns: + The extracted code, or None if no code block found. + """ + # Try tool_code blocks first + pattern = r"```tool_code\n(.*?)```" + match = re.search(pattern, response_text, re.DOTALL) + if match: + return match.group(1).strip() + + # Fall back to python blocks + pattern = r"```python\n(.*?)```" + match = re.search(pattern, response_text, re.DOTALL) + if match: + return match.group(1).strip() + + return None + + def _build_error_feedback( + self, + error: str, + code: str, + ) -> str: + """Build feedback message for execution errors. + + Args: + error: The error message. + code: The code that caused the error. + + Returns: + Formatted error feedback for the LLM. + """ + return f"""The code execution failed with the following error: + +``` +{error} +``` + +The code that failed was: +```python +{code} +``` + +Please fix the error and try again. Common issues: +- Unauthorized imports (only use allowed imports) +- Tool call errors (check the tool documentation) +- Python syntax errors +""" + + @override + async def _run_async_impl( + self, + ctx: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Core implementation of the ReAct loop. + + Args: + ctx: The invocation context. + + Yields: + Events generated during execution. + """ + # Load or initialize state + state = self._load_agent_state(ctx, CodingAgentState) + if state is None: + state = CodingAgentState() + + # Resolve tools and get executor + tools = await self._resolve_tools(ReadonlyContext(ctx)) + coding_executor = await self._get_coding_executor(ctx) + + # Create tool context for the executor + tool_context = ToolContext(invocation_context=ctx) + coding_executor.set_context(ctx, tool_context) + + # Build system prompt + system_prompt = self._build_system_prompt(tools) + + # Get the model + model = self.canonical_model + + # Build initial request with conversation history + contents = [] + events = ctx._get_events(current_invocation=True, current_branch=True) + for event in events: + if event.content: + contents.append(event.content) + + iteration = 0 + error_count = 0 + final_answer = None + + while iteration < self.max_iterations: + iteration += 1 + state.iteration_count = iteration + + # Build LLM request + llm_request = LlmRequest( + model=model.model, + contents=contents, + config=types.GenerateContentConfig( + system_instruction=system_prompt, + ), + ) + + # Call the model (generate_content_async returns an async generator) + llm_response = None + async for response in model.generate_content_async( + llm_request, stream=False + ): + llm_response = response + break + + # Extract response text + response_text = "" + if llm_response and llm_response.content and llm_response.content.parts: + response_text = "".join( + part.text for part in llm_response.content.parts if part.text + ) + + # Check for code block + code = self._extract_code_block(response_text) + + if not code: + # No code generated - treat as final response + # Check if the response looks like a final answer + final_answer = response_text + break + + # Execute the code + code_input = CodeExecutionInput(code=code) + exec_result = coding_executor.execute_code_extended( + invocation_context=ctx, + code_execution_input=code_input, + ) + + # Record execution in state + state.execution_history.append( + { + "iteration": iteration, + "code": code, + "stdout": exec_result.clean_stdout, + "stderr": exec_result.code_result.stderr, + "tool_traces": exec_result.tool_traces, + "has_final_answer": exec_result.has_final_answer, + } + ) + + # Check for errors + if exec_result.code_result.stderr: + error_count += 1 + state.error_count = error_count + + if error_count > self.error_retry_attempts: + # Too many errors - give up + final_answer = ( + f"I encountered too many errors while executing code. " + f"Last error: {exec_result.code_result.stderr}" + ) + break + + # Build error feedback and add to conversation + error_feedback = self._build_error_feedback( + exec_result.code_result.stderr, + code, + ) + contents.append( + types.Content( + role="model", + parts=[types.Part(text=response_text)], + ) + ) + contents.append( + types.Content( + role="user", + parts=[types.Part(text=error_feedback)], + ) + ) + continue + + # Reset error count on success + error_count = 0 + state.error_count = 0 + + # Check for final answer + if exec_result.has_final_answer: + final_answer = exec_result.final_answer + break + + # Add execution result to conversation and continue + contents.append( + types.Content( + role="model", + parts=[types.Part(text=response_text)], + ) + ) + + # Add execution output as user message + output_text = f"""Code execution result: +``` +{exec_result.clean_stdout} +``` +""" + contents.append( + types.Content( + role="user", + parts=[types.Part(text=output_text)], + ) + ) + + # Build final event + if final_answer is None: + final_answer = ( + "I was unable to complete the task within the allowed iterations." + ) + + # Convert final_answer to string if needed + if not isinstance(final_answer, str): + import json + + try: + final_answer = json.dumps(final_answer) + except (TypeError, ValueError): + final_answer = str(final_answer) + + # Update state in context + ctx.agent_states[self.name] = state.model_dump() + + # Yield final event + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + content=types.Content( + role="model", + parts=[types.Part(text=final_answer)], + ), + actions=EventActions( + agent_state=state.model_dump(), + ), + ) + + @model_validator(mode="after") + def _validate_model(self) -> CodingAgent: + """Validate the model after construction.""" + return self + + def cleanup(self) -> None: + """Clean up resources.""" + if self._coding_executor: + self._coding_executor.cleanup() + self._coding_executor = None + self._resolved_tools = None + + def __del__(self): + """Destructor to clean up resources.""" + try: + self.cleanup() + except Exception: + pass diff --git a/src/google/adk/agents/coding_agent_config.py b/src/google/adk/agents/coding_agent_config.py new file mode 100644 index 0000000000..ae78c9ade0 --- /dev/null +++ b/src/google/adk/agents/coding_agent_config.py @@ -0,0 +1,226 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import FrozenSet +from typing import List +from typing import Literal +from typing import Optional +from typing import Union + +from pydantic import Field + +from ..utils.feature_decorator import experimental +from .base_agent_config import BaseAgentConfig +from .common_configs import CodeConfig +from ..tools.tool_configs import ToolConfig + + +# Default set of safe imports for Python code execution +DEFAULT_SAFE_IMPORTS: FrozenSet[str] = frozenset( + { + # Standard library - safe modules + "json", + "math", + "re", + "datetime", + "collections", + "collections.*", + "itertools", + "functools", + "operator", + "string", + "textwrap", + "unicodedata", + "decimal", + "fractions", + "random", + "statistics", + "typing", + "typing.*", + "dataclasses", + "enum", + "abc", + "copy", + "pprint", + "reprlib", + "numbers", + "cmath", + "time", + "calendar", + "hashlib", + "hmac", + "base64", + "binascii", + "html", + "html.*", + "urllib.parse", + "uuid", + "struct", + "codecs", + "locale", + "gettext", + "bisect", + "heapq", + "array", + "weakref", + "types", + "contextlib", + "warnings", + "traceback", + "linecache", + "difflib", + "graphlib", + "zoneinfo", + # Common data science (can be enabled explicitly) + "numpy", + "numpy.*", + "pandas", + "pandas.*", + "scipy", + "scipy.*", + "matplotlib", + "matplotlib.*", + } +) + + +@experimental +class CodingAgentConfig(BaseAgentConfig): + """Configuration for CodingAgent. + + This config extends BaseAgentConfig with fields specific to agents that + generate and execute Python code to accomplish tasks using tools. + """ + + agent_class: Union[Literal["CodingAgent"], str] = Field( + default="CodingAgent", + description="The class of the agent. Must be CodingAgent.", + ) + + model: str = Field( + default="", + description=( + "The model to use for the agent. When not set, the agent will " + "inherit the model from its ancestor or use the default model." + ), + ) + + model_code: Optional[CodeConfig] = Field( + default=None, + description=( + "Optional. Code reference to a custom model instance. " + "Takes precedence over the model field if both are set." + ), + ) + + instruction: str = Field( + default="", + description=( + "Dynamic instructions for the agent, guiding its behavior. " + "Can contain placeholders like {variable_name} that will be " + "resolved at runtime using session state and context." + ), + ) + + tools: Optional[List[ToolConfig]] = Field( + default=None, + description=( + "Optional. The list of tools available to the agent. " + "Tools are exposed as Python functions that the agent can call " + "in the generated code." + ), + ) + + code_executor: Optional[CodeConfig] = Field( + default=None, + description=( + "Optional. Code reference to a custom code executor instance. " + "If not set, a default ContainerCodeExecutor will be used." + ), + ) + + authorized_imports: FrozenSet[str] = Field( + default=DEFAULT_SAFE_IMPORTS, + description=( + "Set of allowed import names/patterns. Supports wildcards " + '(e.g., "collections.*" allows all collections submodules). ' + "Any imports not in this set will be rejected before execution." + ), + ) + + max_iterations: int = Field( + default=10, + ge=1, + le=100, + description=( + "Maximum number of ReAct loop iterations. Each iteration " + "involves generating code, executing it, and processing results." + ), + ) + + error_retry_attempts: int = Field( + default=2, + ge=0, + le=10, + description=( + "Number of times to retry code execution on errors. " + "Error messages are fed back to the LLM for correction." + ), + ) + + stateful: bool = Field( + default=False, + description=( + "Whether to maintain state across iterations. If True, " + "execution history is preserved and re-executed to restore state." + ), + ) + + tool_server_host: Optional[str] = Field( + default=None, + description=( + "Host address for the tool execution server. If not set, " + "auto-detection will try host.docker.internal first, " + "then fall back to 172.17.0.1 for Linux." + ), + ) + + tool_server_port: int = Field( + default=8765, + ge=1024, + le=65535, + description="Port for the tool execution server.", + ) + + before_model_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called before calling the LLM.", + ) + + after_model_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called after calling the LLM.", + ) + + before_tool_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called before calling a tool.", + ) + + after_tool_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called after calling a tool.", + ) diff --git a/src/google/adk/code_executors/__init__.py b/src/google/adk/code_executors/__init__.py index 1cf04a477d..5e5fa7d1e7 100644 --- a/src/google/adk/code_executors/__init__.py +++ b/src/google/adk/code_executors/__init__.py @@ -21,59 +21,118 @@ from .code_executor_context import CodeExecutorContext from .unsafe_local_code_executor import UnsafeLocalCodeExecutor -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) __all__ = [ - 'BaseCodeExecutor', - 'BuiltInCodeExecutor', - 'CodeExecutorContext', - 'UnsafeLocalCodeExecutor', - 'VertexAiCodeExecutor', - 'ContainerCodeExecutor', - 'GkeCodeExecutor', - 'AgentEngineSandboxCodeExecutor', + "BaseCodeExecutor", + "BuiltInCodeExecutor", + "CodeExecutorContext", + "UnsafeLocalCodeExecutor", + "VertexAiCodeExecutor", + "ContainerCodeExecutor", + "GkeCodeExecutor", + "AgentEngineSandboxCodeExecutor", + # CodingAgent components + "AllowlistValidator", + "CodingAgentCodeExecutor", + "ToolCodeGenerator", + "ToolExecutionServer", ] def __getattr__(name: str): - if name == 'VertexAiCodeExecutor': - try: - from .vertex_ai_code_executor import VertexAiCodeExecutor - - return VertexAiCodeExecutor - except ImportError as e: - raise ImportError( - 'VertexAiCodeExecutor requires additional dependencies. ' - 'Please install with: pip install "google-adk[extensions]"' - ) from e - elif name == 'ContainerCodeExecutor': - try: - from .container_code_executor import ContainerCodeExecutor - - return ContainerCodeExecutor - except ImportError as e: - raise ImportError( - 'ContainerCodeExecutor requires additional dependencies. ' - 'Please install with: pip install "google-adk[extensions]"' - ) from e - elif name == 'GkeCodeExecutor': - try: - from .gke_code_executor import GkeCodeExecutor - - return GkeCodeExecutor - except ImportError as e: - raise ImportError( - 'GkeCodeExecutor requires additional dependencies. ' - 'Please install with: pip install "google-adk[extensions]"' - ) from e - elif name == 'AgentEngineSandboxCodeExecutor': - try: - from .agent_engine_sandbox_code_executor import AgentEngineSandboxCodeExecutor - - return AgentEngineSandboxCodeExecutor - except ImportError as e: - raise ImportError( - 'AgentEngineSandboxCodeExecutor requires additional dependencies. ' - 'Please install with: pip install "google-adk[extensions]"' - ) from e - raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + if name == "VertexAiCodeExecutor": + try: + from .vertex_ai_code_executor import VertexAiCodeExecutor + + return VertexAiCodeExecutor + except ImportError as e: + raise ImportError( + "VertexAiCodeExecutor requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "ContainerCodeExecutor": + try: + from .container_code_executor import ContainerCodeExecutor + + return ContainerCodeExecutor + except ImportError as e: + raise ImportError( + "ContainerCodeExecutor requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "GkeCodeExecutor": + try: + from .gke_code_executor import GkeCodeExecutor + + return GkeCodeExecutor + except ImportError as e: + raise ImportError( + "GkeCodeExecutor requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "AgentEngineSandboxCodeExecutor": + try: + from .agent_engine_sandbox_code_executor import ( + AgentEngineSandboxCodeExecutor, + ) + + return AgentEngineSandboxCodeExecutor + except ImportError as e: + raise ImportError( + "AgentEngineSandboxCodeExecutor requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "AllowlistValidator": + try: + from .allowlist_validator import AllowlistValidator + + return AllowlistValidator + except ImportError as e: + raise ImportError( + "AllowlistValidator requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "CodingAgentCodeExecutor": + try: + from .coding_agent_code_executor import CodingAgentCodeExecutor + + return CodingAgentCodeExecutor + except ImportError as e: + raise ImportError( + "CodingAgentCodeExecutor requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "ToolCodeGenerator": + try: + from .tool_code_generator import generate_full_code_with_stubs + from .tool_code_generator import generate_runtime_header + from .tool_code_generator import generate_system_prompt + from .tool_code_generator import generate_tool_stubs + + # Return module-like object with functions + class ToolCodeGenerator: + generate_full_code_with_stubs = staticmethod( + generate_full_code_with_stubs + ) + generate_runtime_header = staticmethod(generate_runtime_header) + generate_system_prompt = staticmethod(generate_system_prompt) + generate_tool_stubs = staticmethod(generate_tool_stubs) + + return ToolCodeGenerator + except ImportError as e: + raise ImportError( + "ToolCodeGenerator requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "ToolExecutionServer": + try: + from .tool_execution_server import ToolExecutionServer + + return ToolExecutionServer + except ImportError as e: + raise ImportError( + "ToolExecutionServer requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/src/google/adk/code_executors/allowlist_validator.py b/src/google/adk/code_executors/allowlist_validator.py new file mode 100644 index 0000000000..68319bf16f --- /dev/null +++ b/src/google/adk/code_executors/allowlist_validator.py @@ -0,0 +1,354 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Import allowlist validation for code execution security.""" + +from __future__ import annotations + +import ast +import fnmatch +import logging +from dataclasses import dataclass +from dataclasses import field +from typing import FrozenSet +from typing import List +from typing import Set + +logger = logging.getLogger("google_adk." + __name__) + + +# Default set of safe imports that are always allowed +DEFAULT_SAFE_IMPORTS: FrozenSet[str] = frozenset( + { + # Standard library - safe modules + "json", + "math", + "re", + "datetime", + "collections", + "collections.*", + "itertools", + "functools", + "operator", + "string", + "textwrap", + "unicodedata", + "decimal", + "fractions", + "random", + "statistics", + "typing", + "typing.*", + "dataclasses", + "enum", + "abc", + "copy", + "pprint", + "reprlib", + "numbers", + "cmath", + "time", + "calendar", + "hashlib", + "hmac", + "base64", + "binascii", + "html", + "html.*", + "urllib.parse", + "uuid", + "struct", + "codecs", + "locale", + "gettext", + "bisect", + "heapq", + "array", + "weakref", + "types", + "contextlib", + "warnings", + "traceback", + "linecache", + "difflib", + "graphlib", + "zoneinfo", + } +) + + +class ImportValidationError(Exception): + """Exception raised when import validation fails. + + Attributes: + violations: List of import violations found. + code: The code that was validated. + """ + + def __init__( + self, + violations: List[str], + code: str, + ): + self.violations = violations + self.code = code + violation_str = "\n".join(f" - {v}" for v in violations) + super().__init__( + f"Import validation failed. Unauthorized imports found:\n{violation_str}" + ) + + +@dataclass +class ImportInfo: + """Information about an import statement. + + Attributes: + module: The module being imported. + names: Names being imported from the module (for 'from' imports). + alias: Alias for the import (if any). + line_number: Line number in the source code. + is_from_import: Whether this is a 'from X import Y' style import. + """ + + module: str + names: List[str] = field(default_factory=list) + alias: str = "" + line_number: int = 0 + is_from_import: bool = False + + +def extract_imports(code: str) -> List[ImportInfo]: + """Extract all import statements from Python code using AST. + + Args: + code: Python source code to analyze. + + Returns: + List of ImportInfo objects describing each import. + + Raises: + SyntaxError: If the code cannot be parsed. + """ + imports = [] + + try: + tree = ast.parse(code) + except SyntaxError as e: + logger.warning("Failed to parse code for import extraction: %s", e) + raise + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + imports.append( + ImportInfo( + module=alias.name, + alias=alias.asname or "", + line_number=node.lineno, + is_from_import=False, + ) + ) + elif isinstance(node, ast.ImportFrom): + module = node.module or "" + names = [alias.name for alias in node.names] + for alias in node.names: + imports.append( + ImportInfo( + module=module, + names=[alias.name], + alias=alias.asname or "", + line_number=node.lineno, + is_from_import=True, + ) + ) + + return imports + + +def is_import_allowed( + import_name: str, + allowlist: FrozenSet[str], +) -> bool: + """Check if an import is allowed by the allowlist. + + Supports wildcards: + - 'collections.*' allows 'collections.abc', 'collections.defaultdict', etc. + - 'numpy' allows only 'numpy', not 'numpy.array' + - 'numpy.*' allows 'numpy.array', 'numpy.linalg', etc. + + Args: + import_name: The full import name to check. + allowlist: Set of allowed import patterns. + + Returns: + True if the import is allowed, False otherwise. + """ + # Direct match + if import_name in allowlist: + return True + + # Check wildcard patterns + for pattern in allowlist: + if "*" in pattern: + if fnmatch.fnmatch(import_name, pattern): + return True + # Also check if the import is a submodule of an allowed module + # e.g., 'collections.*' should allow 'collections.abc.Callable' + base_pattern = pattern.rstrip(".*") + if import_name.startswith(base_pattern + "."): + return True + + # Check if parent module is allowed with wildcard + parts = import_name.split(".") + for i in range(len(parts)): + parent = ".".join(parts[: i + 1]) + wildcard_pattern = parent + ".*" + if wildcard_pattern in allowlist: + return True + + return False + + +def validate_imports( + code: str, + allowlist: FrozenSet[str], +) -> List[str]: + """Validate that all imports in code are in the allowlist. + + Args: + code: Python source code to validate. + allowlist: Set of allowed import patterns. + + Returns: + List of violations (empty if all imports are valid). + + Raises: + ImportValidationError: If unauthorized imports are found. + """ + violations = [] + + try: + imports = extract_imports(code) + except SyntaxError as e: + # If we can't parse, we can't validate - return syntax error as violation + violations.append(f"Syntax error in code: {e}") + return violations + + for import_info in imports: + module = import_info.module + + if import_info.is_from_import: + # For 'from X import Y', check both the module and the full path + for name in import_info.names: + full_name = f"{module}.{name}" if module else name + # Check if module is allowed OR full import path is allowed + if not ( + is_import_allowed(module, allowlist) + or is_import_allowed(full_name, allowlist) + ): + violations.append( + f"Line {import_info.line_number}: " + f'Unauthorized import "from {module} import {name}"' + ) + else: + # For 'import X', just check the module + if not is_import_allowed(module, allowlist): + violations.append( + f'Line {import_info.line_number}: Unauthorized import "{module}"' + ) + + return violations + + +def validate_imports_strict( + code: str, + allowlist: FrozenSet[str], +) -> None: + """Validate imports and raise exception if any violations found. + + Args: + code: Python source code to validate. + allowlist: Set of allowed import patterns. + + Raises: + ImportValidationError: If unauthorized imports are found. + """ + violations = validate_imports(code, allowlist) + if violations: + raise ImportValidationError(violations, code) + + +class AllowlistValidator: + """Validator for checking Python code imports against an allowlist. + + This class provides a stateful validator that can be reused for multiple + validations with the same allowlist. + + Attributes: + allowlist: The set of allowed import patterns. + """ + + def __init__( + self, + allowlist: FrozenSet[str] = DEFAULT_SAFE_IMPORTS, + additional_imports: FrozenSet[str] = frozenset(), + ): + """Initialize the validator with an allowlist. + + Args: + allowlist: Base set of allowed import patterns. + additional_imports: Additional imports to allow beyond the base set. + """ + self.allowlist = allowlist | additional_imports + + def validate(self, code: str) -> List[str]: + """Validate imports in code. + + Args: + code: Python source code to validate. + + Returns: + List of violations (empty if all imports are valid). + """ + return validate_imports(code, self.allowlist) + + def validate_strict(self, code: str) -> None: + """Validate imports and raise if any violations found. + + Args: + code: Python source code to validate. + + Raises: + ImportValidationError: If unauthorized imports are found. + """ + validate_imports_strict(code, self.allowlist) + + def is_allowed(self, import_name: str) -> bool: + """Check if a single import is allowed. + + Args: + import_name: The import name to check. + + Returns: + True if allowed, False otherwise. + """ + return is_import_allowed(import_name, self.allowlist) + + def add_allowed_imports(self, imports: Set[str]) -> None: + """Add additional allowed imports. + + Args: + imports: Set of import patterns to allow. + """ + self.allowlist = self.allowlist | frozenset(imports) diff --git a/src/google/adk/code_executors/coding_agent_code_executor.py b/src/google/adk/code_executors/coding_agent_code_executor.py new file mode 100644 index 0000000000..22099a53e4 --- /dev/null +++ b/src/google/adk/code_executors/coding_agent_code_executor.py @@ -0,0 +1,505 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Code executor for CodingAgent with tool injection support. + +This module provides a code executor that wraps an underlying executor +(e.g., ContainerCodeExecutor) and adds tool injection via HTTP IPC. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import re +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Dict +from typing import FrozenSet +from typing import List +from typing import Optional +from typing import TYPE_CHECKING + +from pydantic import Field +from pydantic import PrivateAttr +from typing_extensions import override + +from .allowlist_validator import AllowlistValidator +from .allowlist_validator import DEFAULT_SAFE_IMPORTS +from .allowlist_validator import ImportValidationError +from .base_code_executor import BaseCodeExecutor +from .code_execution_utils import CodeExecutionInput +from .code_execution_utils import CodeExecutionResult +from .tool_code_generator import generate_full_code_with_stubs +from .tool_execution_server import detect_docker_host_address +from .tool_execution_server import ToolExecutionServer +from .tool_execution_server import ToolTrace +from ..tools.base_tool import BaseTool + +if TYPE_CHECKING: + from ..agents.invocation_context import InvocationContext + from ..tools.tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +# Markers for extracting data from execution output +TOOL_TRACE_MARKER = "__TOOL_TRACE__:" +FINAL_ANSWER_MARKER = "__FINAL_ANSWER__:" + + +@dataclass +class ExecutionStep: + """Record of a single code execution step. + + Attributes: + code: The code that was executed. + code_hash: Hash of the code for comparison. + result: The execution result. + tool_traces: Tool call traces from this step. + success: Whether the execution succeeded. + final_answer: The final answer if one was provided. + """ + + code: str + code_hash: str = "" + result: Optional[CodeExecutionResult] = None + tool_traces: List[Dict[str, Any]] = field(default_factory=list) + success: bool = False + final_answer: Optional[Any] = None + + def __post_init__(self): + if not self.code_hash: + self.code_hash = hashlib.sha256(self.code.encode()).hexdigest()[:16] + + +@dataclass +class CodingAgentExecutionResult: + """Extended execution result with CodingAgent-specific fields. + + Attributes: + code_result: The underlying code execution result. + tool_traces: List of tool call traces. + final_answer: The final answer if one was provided. + has_final_answer: Whether a final answer was extracted. + clean_stdout: Stdout with trace markers removed. + """ + + code_result: CodeExecutionResult + tool_traces: List[Dict[str, Any]] = field(default_factory=list) + final_answer: Optional[Any] = None + has_final_answer: bool = False + clean_stdout: str = "" + + +class CodingAgentCodeExecutor(BaseCodeExecutor): + """Code executor with tool injection for CodingAgent. + + This executor wraps an underlying code executor and adds: + - Tool stub prepending for HTTP-based tool calls + - Import allowlist validation before execution + - Tool execution server lifecycle management + - Trace extraction from execution output + - Final answer detection + - History re-execution for stateful mode + + Attributes: + underlying_executor: The actual code executor to use. + tools: List of tools to make available. + authorized_imports: Set of allowed import patterns. + tool_server_host: Host for the tool server. + tool_server_port: Port for the tool server. + execution_history: List of execution steps for stateful mode. + """ + + underlying_executor: BaseCodeExecutor + tools: List[BaseTool] = Field(default_factory=list) + authorized_imports: FrozenSet[str] = DEFAULT_SAFE_IMPORTS + tool_server_host: Optional[str] = None + tool_server_port: int = 8765 + + # Internal state - use PrivateAttr for Pydantic + _tool_server: Optional[ToolExecutionServer] = PrivateAttr(default=None) + _validator: Optional[AllowlistValidator] = PrivateAttr(default=None) + _invocation_context: Optional[InvocationContext] = PrivateAttr(default=None) + _tool_context: Optional[ToolContext] = PrivateAttr(default=None) + _execution_history: List[ExecutionStep] = PrivateAttr(default_factory=list) + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + def model_post_init(self, __context): + """Initialize after model construction.""" + self._validator = AllowlistValidator( + allowlist=self.authorized_imports, + ) + self._execution_history = [] + + def set_context( + self, + invocation_context: InvocationContext, + tool_context: Optional[ToolContext] = None, + ) -> None: + """Set the execution context. + + Args: + invocation_context: The invocation context. + tool_context: The tool context. + """ + self._invocation_context = invocation_context + self._tool_context = tool_context + if self._tool_server: + self._tool_server.set_context(invocation_context, tool_context) + + def _start_tool_server(self) -> None: + """Start the tool execution server if not already running.""" + if self._tool_server is not None: + return + + host = self.tool_server_host or "0.0.0.0" + self._tool_server = ToolExecutionServer( + host=host, + port=self.tool_server_port, + tools=self.tools, + invocation_context=self._invocation_context, + ) + self._tool_server.start() + + def _stop_tool_server(self) -> None: + """Stop the tool execution server.""" + if self._tool_server: + self._tool_server.stop() + self._tool_server = None + + def _get_tool_server_url(self) -> str: + """Get the URL for the tool server. + + Returns: + The tool server URL accessible from containers. + """ + if self.tool_server_host: + host = self.tool_server_host + else: + host = detect_docker_host_address() + return f"http://{host}:{self.tool_server_port}" + + def _validate_imports(self, code: str) -> None: + """Validate imports in the code against the allowlist. + + Args: + code: The code to validate. + + Raises: + ImportValidationError: If unauthorized imports are found. + """ + if self._validator: + self._validator.validate_strict(code) + + def _extract_traces_and_answer( + self, + result: CodeExecutionResult, + ) -> CodingAgentExecutionResult: + """Extract tool traces and final answer from execution output. + + Args: + result: The raw execution result. + + Returns: + Extended result with extracted data. + """ + tool_traces = [] + final_answer = None + has_final_answer = False + clean_lines = [] + + for line in result.stdout.split("\n"): + if line.startswith(TOOL_TRACE_MARKER): + try: + trace_json = line[len(TOOL_TRACE_MARKER) :] + traces = json.loads(trace_json) + tool_traces.extend(traces) + except json.JSONDecodeError as e: + logger.warning("Failed to parse tool trace: %s", e) + elif line.startswith(FINAL_ANSWER_MARKER): + answer_str = line[len(FINAL_ANSWER_MARKER) :] + try: + final_answer = json.loads(answer_str) + except json.JSONDecodeError: + # Not JSON, treat as string + final_answer = answer_str + has_final_answer = True + else: + clean_lines.append(line) + + clean_stdout = "\n".join(clean_lines).strip() + + return CodingAgentExecutionResult( + code_result=result, + tool_traces=tool_traces, + final_answer=final_answer, + has_final_answer=has_final_answer, + clean_stdout=clean_stdout, + ) + + def _should_skip_step(self, step: ExecutionStep, code_hash: str) -> bool: + """Check if an execution step can be skipped. + + For stateful mode, we can skip re-executing code if: + - The code hasn't changed (same hash) + - The previous execution succeeded + + Args: + step: The previous execution step. + code_hash: Hash of the current code. + + Returns: + True if the step can be skipped. + """ + return step.success and step.code_hash == code_hash + + def _prepend_tool_stubs(self, code: str) -> str: + """Prepend runtime header and tool stubs to user code. + + Args: + code: The user code to wrap. + + Returns: + Complete code with tool stubs. + """ + return generate_full_code_with_stubs( + user_code=code, + tools=self.tools, + tool_server_url=self._get_tool_server_url(), + ) + + def _replay_history( + self, + invocation_context: InvocationContext, + ) -> Optional[CodeExecutionResult]: + """Replay execution history for stateful mode. + + This re-executes previous successful steps to restore state + before executing new code. + + Args: + invocation_context: The invocation context. + + Returns: + The result of the last replayed step, or None if no replay needed. + """ + if not self.stateful or not self._execution_history: + return None + + last_result = None + for step in self._execution_history: + if step.success: + # Re-execute to restore state + full_code = self._prepend_tool_stubs(step.code) + input_data = CodeExecutionInput(code=full_code) + last_result = self.underlying_executor.execute_code( + invocation_context=invocation_context, + code_execution_input=input_data, + ) + logger.debug("Replayed history step: %s", step.code_hash) + + return last_result + + @override + def execute_code( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodeExecutionResult: + """Execute code with tool injection. + + Args: + invocation_context: The invocation context. + code_execution_input: The code to execute. + + Returns: + The execution result. + """ + user_code = code_execution_input.code + + # Validate imports first (security check before execution) + try: + self._validate_imports(user_code) + except ImportValidationError as e: + return CodeExecutionResult( + stdout="", + stderr=str(e), + output_files=[], + ) + + # Start tool server if needed + self._start_tool_server() + + # Set context on tool server + if self._tool_server: + self._tool_server.set_context( + invocation_context, + self._tool_context, + ) + self._tool_server.clear_traces() + + # Replay history for stateful mode + if self.stateful: + self._replay_history(invocation_context) + + # Prepend tool stubs to user code + full_code = self._prepend_tool_stubs(user_code) + + # Execute the code + input_with_stubs = CodeExecutionInput( + code=full_code, + input_files=code_execution_input.input_files, + execution_id=code_execution_input.execution_id, + ) + + result = self.underlying_executor.execute_code( + invocation_context=invocation_context, + code_execution_input=input_with_stubs, + ) + + # Extract traces and final answer + extended_result = self._extract_traces_and_answer(result) + + # Record execution step for stateful mode + step = ExecutionStep( + code=user_code, + result=result, + tool_traces=extended_result.tool_traces, + success=not result.stderr, + final_answer=extended_result.final_answer, + ) + self._execution_history.append(step) + + # Return result with clean stdout (traces stripped) + return CodeExecutionResult( + stdout=extended_result.clean_stdout, + stderr=result.stderr, + output_files=result.output_files, + ) + + def execute_code_extended( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodingAgentExecutionResult: + """Execute code and return extended result with traces. + + Args: + invocation_context: The invocation context. + code_execution_input: The code to execute. + + Returns: + Extended execution result with tool traces and final answer. + """ + user_code = code_execution_input.code + + # Validate imports first + try: + self._validate_imports(user_code) + except ImportValidationError as e: + return CodingAgentExecutionResult( + code_result=CodeExecutionResult( + stdout="", + stderr=str(e), + output_files=[], + ), + tool_traces=[], + final_answer=None, + has_final_answer=False, + clean_stdout="", + ) + + # Start tool server if needed + self._start_tool_server() + + # Set context on tool server + if self._tool_server: + self._tool_server.set_context( + invocation_context, + self._tool_context, + ) + self._tool_server.clear_traces() + + # Replay history for stateful mode + if self.stateful: + self._replay_history(invocation_context) + + # Prepend tool stubs to user code + full_code = self._prepend_tool_stubs(user_code) + + # Execute the code + input_with_stubs = CodeExecutionInput( + code=full_code, + input_files=code_execution_input.input_files, + execution_id=code_execution_input.execution_id, + ) + + result = self.underlying_executor.execute_code( + invocation_context=invocation_context, + code_execution_input=input_with_stubs, + ) + + # Extract traces and final answer + extended_result = self._extract_traces_and_answer(result) + + # Record execution step for stateful mode + step = ExecutionStep( + code=user_code, + result=result, + tool_traces=extended_result.tool_traces, + success=not result.stderr, + final_answer=extended_result.final_answer, + ) + self._execution_history.append(step) + + return extended_result + + def get_execution_history(self) -> List[ExecutionStep]: + """Get the execution history. + + Returns: + List of execution steps. + """ + return self._execution_history.copy() + + def clear_execution_history(self) -> None: + """Clear the execution history.""" + self._execution_history.clear() + + def get_tool_traces(self) -> List[ToolTrace]: + """Get tool traces from the server. + + Returns: + List of tool traces. + """ + if self._tool_server: + return self._tool_server.get_traces() + return [] + + def cleanup(self) -> None: + """Clean up resources.""" + self._stop_tool_server() + self._execution_history.clear() + + def __del__(self): + """Destructor to clean up resources.""" + self.cleanup() diff --git a/src/google/adk/code_executors/tool_code_generator.py b/src/google/adk/code_executors/tool_code_generator.py new file mode 100644 index 0000000000..e2e18271dd --- /dev/null +++ b/src/google/adk/code_executors/tool_code_generator.py @@ -0,0 +1,469 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tool code generator for CodingAgent. + +This module provides functions to generate Python code stubs and runtime +headers that allow generated code to call ADK tools via HTTP IPC. +""" + +from __future__ import annotations + +import json +import logging +import textwrap +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..tools.base_tool import BaseTool + +logger = logging.getLogger("google_adk." + __name__) + + +# Runtime header template that provides HTTP-based tool calling +RUNTIME_HEADER_TEMPLATE = ''' +# ============================================================================ +# ADK CodingAgent Runtime Header - DO NOT MODIFY +# ============================================================================ +import json as __adk_json +import urllib.request as __adk_urllib_request +import urllib.error as __adk_urllib_error + +__ADK_TOOL_SERVER_URL = "{tool_server_url}" +__ADK_TOOL_TRACES = [] + +def _call_adk_tool(__tool_name: str, **kwargs) -> dict: + """Call an ADK tool via HTTP IPC. + + Args: + __tool_name: Name of the tool to call. + **kwargs: Arguments to pass to the tool. + + Returns: + The tool result as a dictionary. + """ + global __ADK_TOOL_TRACES + + request_data = __adk_json.dumps({{ + "tool_name": __tool_name, + "args": kwargs, + }}).encode("utf-8") + + req = __adk_urllib_request.Request( + __ADK_TOOL_SERVER_URL + "/tool_call", + data=request_data, + headers={{"Content-Type": "application/json"}}, + method="POST", + ) + + try: + with __adk_urllib_request.urlopen(req, timeout=300) as response: + result = __adk_json.loads(response.read().decode("utf-8")) + # Record the trace + __ADK_TOOL_TRACES.append({{ + "tool_name": __tool_name, + "args": kwargs, + "result": result, + "success": True, + }}) + return result + except __adk_urllib_error.HTTPError as e: + error_body = e.read().decode("utf-8") if e.fp else str(e) + __ADK_TOOL_TRACES.append({{ + "tool_name": __tool_name, + "args": kwargs, + "error": error_body, + "success": False, + }}) + raise RuntimeError(f"Tool call failed: {{error_body}}") from e + except __adk_urllib_error.URLError as e: + __ADK_TOOL_TRACES.append({{ + "tool_name": __tool_name, + "args": kwargs, + "error": str(e), + "success": False, + }}) + raise RuntimeError(f"Tool server connection failed: {{e}}") from e + +def __get_tool_traces() -> list: + """Get all tool call traces.""" + return __ADK_TOOL_TRACES + +def __clear_tool_traces(): + """Clear all tool call traces.""" + global __ADK_TOOL_TRACES + __ADK_TOOL_TRACES = [] + +# Final answer function for terminating execution +__FINAL_ANSWER_VALUE = None + +def final_answer(result): + """Mark the final answer and terminate the code execution. + + Args: + result: The final result to return to the user. + """ + global __FINAL_ANSWER_VALUE + __FINAL_ANSWER_VALUE = result + print(f"__FINAL_ANSWER__:{{__adk_json.dumps(result) if not isinstance(result, str) else result}}") + +# ============================================================================ +# End of Runtime Header +# ============================================================================ + +''' + + +def generate_runtime_header( + tool_server_url: str, +) -> str: + """Generate the runtime header with HTTP client and helper functions. + + Args: + tool_server_url: URL of the tool execution server. + + Returns: + Python code string containing the runtime header. + """ + return RUNTIME_HEADER_TEMPLATE.format(tool_server_url=tool_server_url) + + +def _get_schema_type(schema: Any) -> str: + """Get the type from a schema (dict or Pydantic Schema object). + + Args: + schema: JSON schema dict or google.genai.types.Schema object. + + Returns: + The type as a lowercase string. + """ + if hasattr(schema, "type"): + # Pydantic Schema object from google.genai.types + schema_type = schema.type + if schema_type is None: + return "any" + # Handle enum (Type.STRING -> "string") + if hasattr(schema_type, "value"): + return schema_type.value.lower() + return str(schema_type).lower() + elif isinstance(schema, dict): + return schema.get("type", "any") + return "any" + + +def _get_schema_attr(schema: Any, attr: str, default: Any = None) -> Any: + """Get an attribute from a schema (dict or Pydantic Schema object). + + Args: + schema: JSON schema dict or google.genai.types.Schema object. + attr: The attribute name to get. + default: Default value if attribute not found. + + Returns: + The attribute value or default. + """ + if hasattr(schema, attr): + return getattr(schema, attr, default) + elif isinstance(schema, dict): + return schema.get(attr, default) + return default + + +def _get_python_type_hint(schema: Any) -> str: + """Convert JSON schema type to Python type hint. + + Args: + schema: JSON schema dict or google.genai.types.Schema object. + + Returns: + Python type hint string. + """ + schema_type = _get_schema_type(schema) + + type_mapping = { + "string": "str", + "integer": "int", + "number": "float", + "boolean": "bool", + "array": "list", + "object": "dict", + } + + if schema_type == "array": + items = _get_schema_attr(schema, "items", {}) + if items: + item_type = _get_python_type_hint(items) + return f"list[{item_type}]" + return "list" + elif schema_type == "object": + return "dict" + + return type_mapping.get(schema_type, "Any") + + +def _generate_tool_stub(tool: BaseTool) -> str: + """Generate a Python function stub for a single tool. + + Args: + tool: The BaseTool to generate a stub for. + + Returns: + Python code string for the tool stub function. + """ + decl = tool._get_declaration() + if not decl: + logger.warning( + "Tool %s has no declaration, skipping stub generation", tool.name + ) + return "" + + # Build parameter list with type hints + params = [] + param_docs = [] + + if decl.parameters and decl.parameters.properties: + required = set(decl.parameters.required or []) + + for param_name, param_schema in decl.parameters.properties.items(): + type_hint = _get_python_type_hint(param_schema) + description = _get_schema_attr(param_schema, "description", "") + + if param_name in required: + params.append(f"{param_name}: {type_hint}") + else: + params.append(f"{param_name}: {type_hint} = None") + + param_docs.append(f" {param_name}: {description}") + + param_str = ", ".join(params) + param_doc_str = "\n".join(param_docs) if param_docs else " None" + + # Build the function stub + stub = f''' +def {tool.name}({param_str}) -> dict: + """{tool.description} + + Args: +{param_doc_str} + + Returns: + Tool execution result as a dictionary. + """ + kwargs = {{k: v for k, v in locals().items() if v is not None}} + return _call_adk_tool("{tool.name}", **kwargs) + +''' + return stub + + +def generate_tool_stubs(tools: List[BaseTool]) -> str: + """Generate Python function stubs for all tools. + + Args: + tools: List of tools to generate stubs for. + + Returns: + Python code string containing all tool stubs. + """ + stubs = [ + "# ============================================================================", + "# Tool Function Stubs", + "# ============================================================================", + "", + ] + + for tool in tools: + stub = _generate_tool_stub(tool) + if stub: + stubs.append(stub) + + return "\n".join(stubs) + + +def generate_final_answer_stub() -> str: + """Generate the final_answer function documentation. + + This is included in the runtime header, but we generate additional + documentation here for the system prompt. + + Returns: + Documentation string about the final_answer function. + """ + return """ +The `final_answer(result)` function is available to mark your final result. +Call this function when you have completed the task and have a result to return. +Example: `final_answer("The calculation result is 42")` +""" + + +# Few-shot examples for the system prompt +SYSTEM_PROMPT_EXAMPLES = """ +## Examples + +### Example 1 - Using a tool to search for information: +```tool_code +# Search for relevant information +result = web_search(query="Python async best practices") +# Display the findings +for snippet in result.get("snippets", [])[:3]: + print(snippet) +``` + +### Example 2 - Processing data and providing a final answer: +```tool_code +# Read and process data +data = read_file(path="sales_data.csv") +rows = data.get("rows", []) + +# Calculate the total +total = sum(float(row.get("amount", 0)) for row in rows) + +# Provide the final answer +final_answer(f"The total sales amount is ${total:.2f}") +``` + +### Example 3 - Multi-step reasoning with tool calls: +```tool_code +# Step 1: Get the current weather +weather = get_weather(city="San Francisco") +temp = weather.get("temperature", "unknown") +print(f"Current temperature: {temp}") +``` + +Then, after seeing the output: +```tool_code +# Step 2: Based on temperature, provide recommendation +if temp > 70: + recommendation = "It's warm! Consider light clothing." +else: + recommendation = "It might be cool. Bring a jacket." + +final_answer(recommendation) +``` +""" + + +def generate_system_prompt( + tools: List[BaseTool], + custom_instruction: str = "", +) -> str: + """Generate the system prompt for the CodingAgent. + + Args: + tools: List of available tools. + custom_instruction: Additional custom instructions. + + Returns: + Complete system prompt string. + """ + # Build tool documentation + tool_docs = [] + for tool in tools: + decl = tool._get_declaration() + if decl: + params_doc = "" + if decl.parameters and decl.parameters.properties: + param_lines = [] + required = set(decl.parameters.required or []) + for name, schema in decl.parameters.properties.items(): + type_hint = _get_python_type_hint(schema) + req_marker = " (required)" if name in required else " (optional)" + desc = _get_schema_attr(schema, "description", "") + param_lines.append( + f" - {name}: {type_hint}{req_marker} - {desc}" + ) + params_doc = "\n".join(param_lines) + + tool_docs.append(f""" +### {tool.name} +{tool.description} + +Parameters: +{params_doc if params_doc else " None"} +""") + + tools_section = "\n".join(tool_docs) if tool_docs else "No tools available." + + system_prompt = f"""You are a coding agent that solves tasks by writing and executing Python code. + +## How to Respond + +1. **Write Python code** in code blocks marked with ```tool_code +2. **Use available tools** by calling them as Python functions +3. **Print intermediate results** to see outputs and make decisions +4. **Call final_answer()** when you have the final result + +## Available Tools + +{tools_section} + +## Special Functions + +- `final_answer(result)`: Call this to provide your final answer and complete the task. +- `print(...)`: Use print statements to see intermediate results. + +{SYSTEM_PROMPT_EXAMPLES} + +## Important Guidelines + +1. **One step at a time**: Write code for one logical step, wait for output, then continue. +2. **Always print results**: Use print() to see what tools return. +3. **Handle errors gracefully**: If a tool fails, try an alternative approach. +4. **Call final_answer()**: When done, call final_answer() with your result. +5. **No external imports**: Only use the provided tools and standard library. + +{custom_instruction} +""" + + return system_prompt.strip() + + +def generate_full_code_with_stubs( + user_code: str, + tools: List[BaseTool], + tool_server_url: str, +) -> str: + """Generate complete executable code with runtime header and tool stubs. + + Args: + user_code: The user-generated code to execute. + tools: List of available tools. + tool_server_url: URL of the tool execution server. + + Returns: + Complete Python code ready for execution. + """ + runtime_header = generate_runtime_header(tool_server_url) + tool_stubs = generate_tool_stubs(tools) + + full_code = f"""{runtime_header} +{tool_stubs} +# ============================================================================ +# User Code +# ============================================================================ + +{user_code} + +# ============================================================================ +# Output tool traces for extraction +# ============================================================================ +import json as __output_json +print("__TOOL_TRACE__:" + __output_json.dumps(__get_tool_traces())) +""" + + return full_code diff --git a/src/google/adk/code_executors/tool_execution_server.py b/src/google/adk/code_executors/tool_execution_server.py new file mode 100644 index 0000000000..fa1c2efda1 --- /dev/null +++ b/src/google/adk/code_executors/tool_execution_server.py @@ -0,0 +1,366 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tool execution server for CodingAgent. + +This module provides a FastAPI server that handles tool execution requests +from code running in containers. It routes requests to the appropriate +ADK tools with full ToolContext support. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import socket +import threading +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import TYPE_CHECKING + +import uvicorn +from fastapi import FastAPI +from fastapi import HTTPException +from pydantic import BaseModel + +if TYPE_CHECKING: + from ..agents.invocation_context import InvocationContext + from ..tools.base_tool import BaseTool + from ..tools.tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +class ToolCallRequest(BaseModel): + """Request model for tool calls.""" + + tool_name: str + args: Dict[str, Any] + + +class ToolCallResponse(BaseModel): + """Response model for tool calls.""" + + result: Any + success: bool + error: Optional[str] = None + + +@dataclass +class ToolTrace: + """Record of a tool call for debugging and telemetry.""" + + tool_name: str + args: Dict[str, Any] + result: Any = None + error: Optional[str] = None + success: bool = False + duration_ms: float = 0.0 + + +def detect_docker_host_address() -> str: + """Detect the appropriate host address for Docker containers. + + On macOS and Windows (Docker Desktop), use host.docker.internal. + On Linux, use 172.17.0.1 (default Docker bridge network gateway). + + Note: host.docker.internal only resolves from within containers, + not from the host machine, so we check the platform instead. + + Returns: + The detected host address. + """ + import platform + + system = platform.system().lower() + + # macOS and Windows use Docker Desktop which supports host.docker.internal + if system in ("darwin", "windows"): + return "host.docker.internal" + + # Linux: use Docker bridge network gateway + return "172.17.0.1" + + +class ToolExecutionServer: + """FastAPI server for executing ADK tools via HTTP. + + This server is designed to run on the host machine and receive tool + execution requests from code running in Docker containers. + + Attributes: + host: Host address to bind the server to. + port: Port to bind the server to. + tools: Dictionary mapping tool names to tool instances. + invocation_context: The current invocation context. + tool_context: The current tool context. + traces: List of tool call traces. + """ + + def __init__( + self, + host: str = "0.0.0.0", + port: int = 8765, + tools: Optional[List[BaseTool]] = None, + invocation_context: Optional[InvocationContext] = None, + ): + """Initialize the tool execution server. + + Args: + host: Host address to bind to. + port: Port to bind to. + tools: List of tools to make available. + invocation_context: The invocation context for tool calls. + """ + self.host = host + self.port = port + self.tools: Dict[str, BaseTool] = {} + self.invocation_context = invocation_context + self.tool_context: Optional[ToolContext] = None + self.traces: List[ToolTrace] = [] + self._server: Optional[uvicorn.Server] = None + self._server_thread: Optional[threading.Thread] = None + self._app = self._create_app() + + if tools: + for tool in tools: + self.register_tool(tool) + + def _create_app(self) -> FastAPI: + """Create the FastAPI application with routes.""" + app = FastAPI( + title="ADK Tool Execution Server", + description="HTTP server for executing ADK tools from containers", + version="1.0.0", + ) + + @app.post("/tool_call", response_model=ToolCallResponse) + async def handle_tool_call(request: ToolCallRequest) -> ToolCallResponse: + """Handle a tool call request.""" + return await self._execute_tool(request.tool_name, request.args) + + @app.get("/tool_trace") + async def get_tool_traces() -> List[Dict[str, Any]]: + """Get all tool call traces.""" + return [ + { + "tool_name": t.tool_name, + "args": t.args, + "result": t.result, + "error": t.error, + "success": t.success, + "duration_ms": t.duration_ms, + } + for t in self.traces + ] + + @app.delete("/tool_trace") + async def clear_tool_traces() -> Dict[str, str]: + """Clear all tool call traces.""" + self.traces.clear() + return {"status": "cleared"} + + @app.get("/health") + async def health_check() -> Dict[str, str]: + """Health check endpoint.""" + return {"status": "healthy"} + + @app.get("/tools") + async def list_tools() -> List[str]: + """List available tools.""" + return list(self.tools.keys()) + + return app + + def register_tool(self, tool: BaseTool) -> None: + """Register a tool with the server. + + Args: + tool: The tool to register. + """ + self.tools[tool.name] = tool + logger.debug("Registered tool: %s", tool.name) + + def set_context( + self, + invocation_context: InvocationContext, + tool_context: Optional[ToolContext] = None, + ) -> None: + """Set the context for tool execution. + + Args: + invocation_context: The invocation context. + tool_context: The tool context. + """ + self.invocation_context = invocation_context + self.tool_context = tool_context + + async def _execute_tool( + self, + tool_name: str, + args: Dict[str, Any], + ) -> ToolCallResponse: + """Execute a tool and return the result. + + Args: + tool_name: Name of the tool to execute. + args: Arguments to pass to the tool. + + Returns: + The tool execution response. + """ + import time + + start_time = time.time() + trace = ToolTrace(tool_name=tool_name, args=args) + + if tool_name not in self.tools: + trace.error = f"Tool not found: {tool_name}" + trace.success = False + trace.duration_ms = (time.time() - start_time) * 1000 + self.traces.append(trace) + raise HTTPException(status_code=404, detail=trace.error) + + tool = self.tools[tool_name] + + try: + # Create a tool context if we have an invocation context + if self.invocation_context and not self.tool_context: + from ..tools.tool_context import ToolContext + + self.tool_context = ToolContext( + invocation_context=self.invocation_context, + ) + + if self.tool_context: + result = await tool.run_async(args=args, tool_context=self.tool_context) + else: + # If no context available, create a minimal mock context + # This is a fallback and shouldn't happen in normal operation + logger.warning("Executing tool %s without proper context", tool_name) + result = await tool.run_async(args=args, tool_context=None) + + trace.result = result + trace.success = True + trace.duration_ms = (time.time() - start_time) * 1000 + self.traces.append(trace) + + return ToolCallResponse(result=result, success=True) + + except Exception as e: + trace.error = str(e) + trace.success = False + trace.duration_ms = (time.time() - start_time) * 1000 + self.traces.append(trace) + logger.error("Tool execution failed: %s - %s", tool_name, e) + raise HTTPException(status_code=500, detail=str(e)) from e + + def start(self) -> None: + """Start the server in a background thread.""" + if self._server_thread and self._server_thread.is_alive(): + logger.warning("Server already running") + return + + config = uvicorn.Config( + app=self._app, + host=self.host, + port=self.port, + log_level="warning", + ) + self._server = uvicorn.Server(config) + + def run_server(): + asyncio.run(self._server.serve()) + + self._server_thread = threading.Thread(target=run_server, daemon=True) + self._server_thread.start() + + # Wait for server to be ready + self._wait_for_server() + logger.info("Tool execution server started on %s:%d", self.host, self.port) + + def _wait_for_server(self, timeout: float = 10.0) -> None: + """Wait for the server to be ready. + + Args: + timeout: Maximum time to wait in seconds. + """ + import time + + start = time.time() + while time.time() - start < timeout: + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + result = sock.connect_ex(("127.0.0.1", self.port)) + sock.close() + if result == 0: + return + except Exception: + pass + time.sleep(0.1) + + logger.warning("Server may not be fully ready after %.1f seconds", timeout) + + def stop(self) -> None: + """Stop the server.""" + if self._server: + self._server.should_exit = True + if self._server_thread: + self._server_thread.join(timeout=5.0) + self._server = None + self._server_thread = None + logger.info("Tool execution server stopped") + + def get_url(self, for_container: bool = True) -> str: + """Get the URL for the server. + + Args: + for_container: If True, return URL accessible from Docker containers. + + Returns: + The server URL. + """ + if for_container: + host = detect_docker_host_address() + else: + host = "localhost" if self.host == "0.0.0.0" else self.host + return f"http://{host}:{self.port}" + + def clear_traces(self) -> None: + """Clear all tool call traces.""" + self.traces.clear() + + def get_traces(self) -> List[ToolTrace]: + """Get all tool call traces. + + Returns: + List of tool traces. + """ + return self.traces.copy() + + def __enter__(self) -> ToolExecutionServer: + """Context manager entry.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit.""" + self.stop() diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index bd964a5d07..609d96c14d 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -30,6 +30,7 @@ import logging import os from typing import Any +from typing import Optional from typing import TYPE_CHECKING from google.genai import types @@ -62,14 +63,14 @@ # By default some ADK spans include attributes with potential PII data. # This env, when set to false, allows to disable populating those attributes. -ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS = 'ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS' +ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS = "ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS" # Standard OTEL env variable to enable logging of prompt/response content. OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = ( - 'OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT' + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT" ) -USER_CONTENT_ELIDED = '' +USER_CONTENT_ELIDED = "" # Needed to avoid circular imports if TYPE_CHECKING: @@ -81,18 +82,18 @@ from ..tools.base_tool import BaseTool tracer = trace.get_tracer( - instrumenting_module_name='gcp.vertex.agent', + instrumenting_module_name="gcp.vertex.agent", instrumenting_library_version=version.__version__, schema_url=Schemas.V1_36_0.value, ) otel_logger = _logs.get_logger( - instrumenting_module_name='gcp.vertex.agent', + instrumenting_module_name="gcp.vertex.agent", instrumenting_library_version=version.__version__, schema_url=Schemas.V1_36_0.value, ) -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) def _safe_json_serialize(obj) -> str: @@ -108,10 +109,10 @@ def _safe_json_serialize(obj) -> str: try: # Try direct JSON serialization first return json.dumps( - obj, ensure_ascii=False, default=lambda o: '' + obj, ensure_ascii=False, default=lambda o: "" ) except (TypeError, OverflowError): - return '' + return "" def trace_agent_invocation( @@ -138,7 +139,7 @@ def trace_agent_invocation( """ # Required - span.set_attribute(GEN_AI_OPERATION_NAME, 'invoke_agent') + span.set_attribute(GEN_AI_OPERATION_NAME, "invoke_agent") # Conditionally Required span.set_attribute(GEN_AI_AGENT_DESCRIPTION, agent.description) @@ -161,7 +162,7 @@ def trace_tool_call( """ span = trace.get_current_span() - span.set_attribute(GEN_AI_OPERATION_NAME, 'execute_tool') + span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") span.set_attribute(GEN_AI_TOOL_DESCRIPTION, tool.description) span.set_attribute(GEN_AI_TOOL_NAME, tool.name) @@ -171,20 +172,20 @@ def trace_tool_call( # Setting empty llm request and response (as UI expect these) while not # applicable for tool_response. - span.set_attribute('gcp.vertex.agent.llm_request', '{}') - span.set_attribute('gcp.vertex.agent.llm_response', '{}') + span.set_attribute("gcp.vertex.agent.llm_request", "{}") + span.set_attribute("gcp.vertex.agent.llm_response", "{}") if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.tool_call_args', + "gcp.vertex.agent.tool_call_args", _safe_json_serialize(args), ) else: - span.set_attribute('gcp.vertex.agent.tool_call_args', '{}') + span.set_attribute("gcp.vertex.agent.tool_call_args", "{}") # Tracing tool response - tool_call_id = '' - tool_response = '' + tool_call_id = "" + tool_response = "" if ( function_response_event is not None and function_response_event.content is not None @@ -201,16 +202,16 @@ def trace_tool_call( span.set_attribute(GEN_AI_TOOL_CALL_ID, tool_call_id) if not isinstance(tool_response, dict): - tool_response = {'result': tool_response} + tool_response = {"result": tool_response} if function_response_event is not None: - span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id) + span.set_attribute("gcp.vertex.agent.event_id", function_response_event.id) if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.tool_response', + "gcp.vertex.agent.tool_response", _safe_json_serialize(tool_response), ) else: - span.set_attribute('gcp.vertex.agent.tool_response', '{}') + span.set_attribute("gcp.vertex.agent.tool_response", "{}") def trace_merged_tool_calls( @@ -229,34 +230,34 @@ def trace_merged_tool_calls( span = trace.get_current_span() - span.set_attribute(GEN_AI_OPERATION_NAME, 'execute_tool') - span.set_attribute(GEN_AI_TOOL_NAME, '(merged tools)') - span.set_attribute(GEN_AI_TOOL_DESCRIPTION, '(merged tools)') + span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") + span.set_attribute(GEN_AI_TOOL_NAME, "(merged tools)") + span.set_attribute(GEN_AI_TOOL_DESCRIPTION, "(merged tools)") span.set_attribute(GEN_AI_TOOL_CALL_ID, response_event_id) # TODO(b/441461932): See if these are still necessary - span.set_attribute('gcp.vertex.agent.tool_call_args', 'N/A') - span.set_attribute('gcp.vertex.agent.event_id', response_event_id) + span.set_attribute("gcp.vertex.agent.tool_call_args", "N/A") + span.set_attribute("gcp.vertex.agent.event_id", response_event_id) try: function_response_event_json = function_response_event.model_dumps_json( exclude_none=True ) except Exception: # pylint: disable=broad-exception-caught - function_response_event_json = '' + function_response_event_json = "" if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.tool_response', + "gcp.vertex.agent.tool_response", function_response_event_json, ) else: - span.set_attribute('gcp.vertex.agent.tool_response', '{}') + span.set_attribute("gcp.vertex.agent.tool_response", "{}") # Setting empty llm request and response (as UI expect these) while not # applicable for tool_response. - span.set_attribute('gcp.vertex.agent.llm_request', '{}') + span.set_attribute("gcp.vertex.agent.llm_request", "{}") span.set_attribute( - 'gcp.vertex.agent.llm_response', - '{}', + "gcp.vertex.agent.llm_response", + "{}", ) @@ -281,57 +282,57 @@ def trace_call_llm( span = span or trace.get_current_span() # Special standard Open Telemetry GenaI attributes that indicate # that this is a span related to a Generative AI system. - span.set_attribute('gen_ai.system', 'gcp.vertex.agent') - span.set_attribute('gen_ai.request.model', llm_request.model) + span.set_attribute("gen_ai.system", "gcp.vertex.agent") + span.set_attribute("gen_ai.request.model", llm_request.model) span.set_attribute( - 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id + "gcp.vertex.agent.invocation_id", invocation_context.invocation_id ) span.set_attribute( - 'gcp.vertex.agent.session_id', invocation_context.session.id + "gcp.vertex.agent.session_id", invocation_context.session.id ) - span.set_attribute('gcp.vertex.agent.event_id', event_id) + span.set_attribute("gcp.vertex.agent.event_id", event_id) # Consider removing once GenAI SDK provides a way to record this info. if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.llm_request', + "gcp.vertex.agent.llm_request", _safe_json_serialize(_build_llm_request_for_trace(llm_request)), ) else: - span.set_attribute('gcp.vertex.agent.llm_request', '{}') + span.set_attribute("gcp.vertex.agent.llm_request", "{}") # Consider removing once GenAI SDK provides a way to record this info. if llm_request.config: if llm_request.config.top_p: span.set_attribute( - 'gen_ai.request.top_p', + "gen_ai.request.top_p", llm_request.config.top_p, ) if llm_request.config.max_output_tokens: span.set_attribute( - 'gen_ai.request.max_tokens', + "gen_ai.request.max_tokens", llm_request.config.max_output_tokens, ) try: llm_response_json = llm_response.model_dump_json(exclude_none=True) except Exception: # pylint: disable=broad-exception-caught - llm_response_json = '' + llm_response_json = "" if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.llm_response', + "gcp.vertex.agent.llm_response", llm_response_json, ) else: - span.set_attribute('gcp.vertex.agent.llm_response', '{}') + span.set_attribute("gcp.vertex.agent.llm_response", "{}") if llm_response.usage_metadata is not None: span.set_attribute( - 'gen_ai.usage.input_tokens', + "gen_ai.usage.input_tokens", llm_response.usage_metadata.prompt_token_count, ) if llm_response.usage_metadata.candidates_token_count is not None: span.set_attribute( - 'gen_ai.usage.output_tokens', + "gen_ai.usage.output_tokens", llm_response.usage_metadata.candidates_token_count, ) if llm_response.finish_reason: @@ -340,7 +341,7 @@ def trace_call_llm( except AttributeError: finish_reason_str = str(llm_response.finish_reason).lower() span.set_attribute( - 'gen_ai.response.finish_reasons', + "gen_ai.response.finish_reasons", [finish_reason_str], ) @@ -362,23 +363,23 @@ def trace_send_data( """ span = trace.get_current_span() span.set_attribute( - 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id + "gcp.vertex.agent.invocation_id", invocation_context.invocation_id ) - span.set_attribute('gcp.vertex.agent.event_id', event_id) + span.set_attribute("gcp.vertex.agent.event_id", event_id) # Once instrumentation is added to the GenAI SDK, consider whether this # information still needs to be recorded by the Agent Development Kit. if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.data', + "gcp.vertex.agent.data", _safe_json_serialize([ types.Content(role=content.role, parts=content.parts).model_dump( - exclude_none=True, mode='json' + exclude_none=True, mode="json" ) for content in data ]), ) else: - span.set_attribute('gcp.vertex.agent.data', '{}') + span.set_attribute("gcp.vertex.agent.data", "{}") def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: @@ -396,18 +397,18 @@ def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: """ # Some fields in LlmRequest are function pointers and cannot be serialized. result = { - 'model': llm_request.model, - 'config': llm_request.config.model_dump( - exclude_none=True, exclude='response_schema', mode='json' + "model": llm_request.model, + "config": llm_request.config.model_dump( + exclude_none=True, exclude="response_schema", mode="json" ), - 'contents': [], + "contents": [], } # We do not want to send bytes data to the trace. for content in llm_request.contents: parts = [part for part in content.parts if not part.inline_data] - result['contents'].append( + result["contents"].append( types.Content(role=content.role, parts=parts).model_dump( - exclude_none=True, mode='json' + exclude_none=True, mode="json" ) ) return result @@ -419,8 +420,8 @@ def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: # to false. def _should_add_request_response_to_spans() -> bool: disabled_via_env_var = os.getenv( - ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, 'true' - ).lower() in ('false', '0') + ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, "true" + ).lower() in ("false", "0") return not disabled_via_env_var @@ -438,7 +439,7 @@ def use_generate_content_span( common_attributes = { GEN_AI_CONVERSATION_ID: invocation_context.session.id, - 'gcp.vertex.agent.event_id': model_response_event.id, + "gcp.vertex.agent.event_id": model_response_event.id, } if ( _is_gemini_agent(invocation_context.agent) @@ -455,8 +456,8 @@ def use_generate_content_span( def _should_log_prompt_response_content() -> bool: return os.getenv( - OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, '' - ).lower() in ('1', 'true') + OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, "" + ).lower() in ("1", "true") def _serialize_content(content: types.ContentUnion) -> AnyValue: @@ -481,9 +482,9 @@ def _serialize_content_with_elision( def _instrumented_with_opentelemetry_instrumentation_google_genai() -> bool: maybe_wrapped_function = Models.generate_content - while wrapped := getattr(maybe_wrapped_function, '__wrapped__', None): + while wrapped := getattr(maybe_wrapped_function, "__wrapped__", None): if ( - 'opentelemetry/instrumentation/google_genai' + "opentelemetry/instrumentation/google_genai" in maybe_wrapped_function.__code__.co_filename ): return True @@ -515,15 +516,15 @@ def _use_native_generate_content_span( f"generate_content {llm_request.model or ''}" ) as span: span.set_attribute(GEN_AI_SYSTEM, _guess_gemini_system_name()) - span.set_attribute(GEN_AI_OPERATION_NAME, 'generate_content') - span.set_attribute(GEN_AI_REQUEST_MODEL, llm_request.model or '') + span.set_attribute(GEN_AI_OPERATION_NAME, "generate_content") + span.set_attribute(GEN_AI_REQUEST_MODEL, llm_request.model or "") span.set_attributes(common_attributes) otel_logger.emit( LogRecord( - event_name='gen_ai.system.message', + event_name="gen_ai.system.message", body={ - 'content': _serialize_content_with_elision( + "content": _serialize_content_with_elision( llm_request.config.system_instruction ) }, @@ -534,8 +535,8 @@ def _use_native_generate_content_span( for content in llm_request.contents: otel_logger.emit( LogRecord( - event_name='gen_ai.user.message', - body={'content': _serialize_content_with_elision(content)}, + event_name="gen_ai.user.message", + body={"content": _serialize_content_with_elision(content)}, attributes={GEN_AI_SYSTEM: _guess_gemini_system_name()}, ) ) @@ -566,12 +567,12 @@ def trace_generate_content_result(span: Span | None, llm_response: LlmResponse): otel_logger.emit( LogRecord( - event_name='gen_ai.choice', + event_name="gen_ai.choice", body={ - 'content': _serialize_content_with_elision(llm_response.content), - 'index': 0, # ADK always returns a single candidate + "content": _serialize_content_with_elision(llm_response.content), + "index": 0, # ADK always returns a single candidate } - | {'finish_reason': llm_response.finish_reason.value} + | {"finish_reason": llm_response.finish_reason.value} if llm_response.finish_reason is not None else {}, attributes={GEN_AI_SYSTEM: _guess_gemini_system_name()}, @@ -582,6 +583,184 @@ def trace_generate_content_result(span: Span | None, llm_response: LlmResponse): def _guess_gemini_system_name() -> str: return ( GenAiSystemValues.VERTEX_AI.name.lower() - if os.getenv('GOOGLE_GENAI_USE_VERTEXAI', '').lower() in ('true', '1') + if os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in ("true", "1") else GenAiSystemValues.GEMINI.name.lower() ) + + +# ============================================================================= +# CodingAgent-specific tracing functions +# ============================================================================= + + +def trace_code_generation( + agent_name: str, + code: str, + iteration: int, + duration_ms: float, +) -> None: + """Traces code generation by CodingAgent. + + Args: + agent_name: Name of the CodingAgent. + code: The generated code. + iteration: Current iteration number. + duration_ms: Time taken for code generation in milliseconds. + """ + span = trace.get_current_span() + + span.set_attribute(GEN_AI_OPERATION_NAME, "generate_code") + span.set_attribute(GEN_AI_AGENT_NAME, agent_name) + span.set_attribute("gcp.vertex.agent.coding_agent.iteration", iteration) + span.set_attribute( + "gcp.vertex.agent.coding_agent.generation_duration_ms", duration_ms + ) + + if _should_add_request_response_to_spans(): + # Truncate code if too long for span attribute + max_code_length = 10000 + truncated_code = code[:max_code_length] + if len(code) > max_code_length: + truncated_code += "\n... [truncated]" + span.set_attribute( + "gcp.vertex.agent.coding_agent.generated_code", truncated_code + ) + else: + span.set_attribute("gcp.vertex.agent.coding_agent.generated_code", "{}") + + +def trace_code_execution( + agent_name: str, + code: str, + stdout: str, + stderr: str, + duration_ms: float, + success: bool, + has_final_answer: bool, +) -> None: + """Traces code execution by CodingAgent. + + Args: + agent_name: Name of the CodingAgent. + code: The executed code. + stdout: Standard output from execution. + stderr: Standard error from execution. + duration_ms: Time taken for execution in milliseconds. + success: Whether execution was successful. + has_final_answer: Whether execution produced a final answer. + """ + span = trace.get_current_span() + + span.set_attribute(GEN_AI_OPERATION_NAME, "execute_code") + span.set_attribute(GEN_AI_AGENT_NAME, agent_name) + span.set_attribute( + "gcp.vertex.agent.coding_agent.execution_duration_ms", duration_ms + ) + span.set_attribute("gcp.vertex.agent.coding_agent.execution_success", success) + span.set_attribute( + "gcp.vertex.agent.coding_agent.has_final_answer", has_final_answer + ) + + if _should_add_request_response_to_spans(): + # Truncate outputs if too long + max_output_length = 5000 + truncated_stdout = stdout[:max_output_length] + if len(stdout) > max_output_length: + truncated_stdout += "\n... [truncated]" + truncated_stderr = stderr[:max_output_length] + if len(stderr) > max_output_length: + truncated_stderr += "\n... [truncated]" + truncated_code = code[:max_output_length] + if len(code) > max_output_length: + truncated_code += "\n... [truncated]" + + span.set_attribute( + "gcp.vertex.agent.coding_agent.executed_code", truncated_code + ) + span.set_attribute("gcp.vertex.agent.coding_agent.stdout", truncated_stdout) + span.set_attribute("gcp.vertex.agent.coding_agent.stderr", truncated_stderr) + else: + span.set_attribute("gcp.vertex.agent.coding_agent.executed_code", "{}") + span.set_attribute("gcp.vertex.agent.coding_agent.stdout", "{}") + span.set_attribute("gcp.vertex.agent.coding_agent.stderr", "{}") + + +def trace_import_validation( + agent_name: str, + code: str, + violations: list[str], + duration_ms: float, +) -> None: + """Traces import validation by CodingAgent. + + Args: + agent_name: Name of the CodingAgent. + code: The code that was validated. + violations: List of import violations found. + duration_ms: Time taken for validation in milliseconds. + """ + span = trace.get_current_span() + + span.set_attribute(GEN_AI_OPERATION_NAME, "validate_imports") + span.set_attribute(GEN_AI_AGENT_NAME, agent_name) + span.set_attribute( + "gcp.vertex.agent.coding_agent.validation_duration_ms", duration_ms + ) + span.set_attribute( + "gcp.vertex.agent.coding_agent.violation_count", len(violations) + ) + span.set_attribute( + "gcp.vertex.agent.coding_agent.validation_passed", len(violations) == 0 + ) + + if _should_add_request_response_to_spans() and violations: + span.set_attribute( + "gcp.vertex.agent.coding_agent.violations", + _safe_json_serialize(violations), + ) + + +def trace_tool_ipc( + agent_name: str, + tool_name: str, + args: dict[str, Any], + result: Any, + duration_ms: float, + success: bool, + error: Optional[str] = None, +) -> None: + """Traces tool IPC calls from container to host. + + Args: + agent_name: Name of the CodingAgent. + tool_name: Name of the tool called. + args: Arguments passed to the tool. + result: Result returned by the tool. + duration_ms: Time taken for the IPC call in milliseconds. + success: Whether the call was successful. + error: Error message if call failed. + """ + span = trace.get_current_span() + + span.set_attribute(GEN_AI_OPERATION_NAME, "tool_ipc") + span.set_attribute(GEN_AI_AGENT_NAME, agent_name) + span.set_attribute(GEN_AI_TOOL_NAME, tool_name) + span.set_attribute( + "gcp.vertex.agent.coding_agent.ipc_duration_ms", duration_ms + ) + span.set_attribute("gcp.vertex.agent.coding_agent.ipc_success", success) + + if error: + span.set_attribute("gcp.vertex.agent.coding_agent.ipc_error", error) + + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.coding_agent.tool_args", _safe_json_serialize(args) + ) + span.set_attribute( + "gcp.vertex.agent.coding_agent.tool_result", + _safe_json_serialize(result), + ) + else: + span.set_attribute("gcp.vertex.agent.coding_agent.tool_args", "{}") + span.set_attribute("gcp.vertex.agent.coding_agent.tool_result", "{}") diff --git a/tests/unittests/agents/test_coding_agent.py b/tests/unittests/agents/test_coding_agent.py new file mode 100644 index 0000000000..2516507a3f --- /dev/null +++ b/tests/unittests/agents/test_coding_agent.py @@ -0,0 +1,309 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for CodingAgent.""" + +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.agents.coding_agent import CodingAgent +from google.adk.agents.coding_agent import CodingAgentState +from google.adk.agents.coding_agent_config import CodingAgentConfig +from google.adk.agents.coding_agent_config import DEFAULT_SAFE_IMPORTS +from google.adk.tools.base_tool import BaseTool + + +class TestCodingAgentConfig: + """Tests for CodingAgentConfig.""" + + def test_default_values(self): + """Test that default values are set correctly.""" + config = CodingAgentConfig(name="test_agent") + + assert config.name == "test_agent" + assert config.agent_class == "CodingAgent" + assert config.max_iterations == 10 + assert config.error_retry_attempts == 2 + assert config.stateful is False + assert config.tool_server_port == 8765 + assert config.authorized_imports == DEFAULT_SAFE_IMPORTS + + def test_custom_values(self): + """Test that custom values can be set.""" + custom_imports = frozenset({"json", "math"}) + config = CodingAgentConfig( + name="custom_agent", + model="gemini-2.0-flash", + max_iterations=20, + error_retry_attempts=5, + stateful=True, + tool_server_port=9999, + authorized_imports=custom_imports, + ) + + assert config.name == "custom_agent" + assert config.model == "gemini-2.0-flash" + assert config.max_iterations == 20 + assert config.error_retry_attempts == 5 + assert config.stateful is True + assert config.tool_server_port == 9999 + assert config.authorized_imports == custom_imports + + def test_max_iterations_bounds(self): + """Test max_iterations validation.""" + # Valid bounds + config = CodingAgentConfig(name="test", max_iterations=1) + assert config.max_iterations == 1 + + config = CodingAgentConfig(name="test", max_iterations=100) + assert config.max_iterations == 100 + + # Invalid bounds + with pytest.raises(ValueError): + CodingAgentConfig(name="test", max_iterations=0) + + with pytest.raises(ValueError): + CodingAgentConfig(name="test", max_iterations=101) + + def test_port_bounds(self): + """Test tool_server_port validation.""" + # Valid bounds + config = CodingAgentConfig(name="test", tool_server_port=1024) + assert config.tool_server_port == 1024 + + config = CodingAgentConfig(name="test", tool_server_port=65535) + assert config.tool_server_port == 65535 + + # Invalid bounds + with pytest.raises(ValueError): + CodingAgentConfig(name="test", tool_server_port=1023) + + with pytest.raises(ValueError): + CodingAgentConfig(name="test", tool_server_port=65536) + + +class TestCodingAgentState: + """Tests for CodingAgentState.""" + + def test_default_state(self): + """Test default state values.""" + state = CodingAgentState() + + assert state.iteration_count == 0 + assert state.error_count == 0 + assert state.execution_history == [] + + def test_state_with_history(self): + """Test state with execution history.""" + history = [ + {"iteration": 1, "code": "print('hello')", "success": True}, + {"iteration": 2, "code": "print('world')", "success": True}, + ] + state = CodingAgentState( + iteration_count=2, + error_count=0, + execution_history=history, + ) + + assert state.iteration_count == 2 + assert len(state.execution_history) == 2 + + def test_state_serialization(self): + """Test state can be serialized and deserialized.""" + state = CodingAgentState( + iteration_count=5, + error_count=1, + execution_history=[{"iteration": 1, "code": "x = 1"}], + ) + + dumped = state.model_dump() + restored = CodingAgentState.model_validate(dumped) + + assert restored.iteration_count == 5 + assert restored.error_count == 1 + assert len(restored.execution_history) == 1 + + +class TestCodingAgent: + """Tests for CodingAgent.""" + + def test_agent_creation(self): + """Test basic agent creation.""" + agent = CodingAgent( + name="test_coding_agent", + description="A test coding agent", + ) + + assert agent.name == "test_coding_agent" + assert agent.description == "A test coding agent" + assert agent.max_iterations == 10 + assert agent.error_retry_attempts == 2 + + def test_agent_with_custom_config(self): + """Test agent with custom configuration.""" + agent = CodingAgent( + name="custom_agent", + model="gemini-2.0-flash", + max_iterations=5, + error_retry_attempts=3, + stateful=True, + ) + + assert agent.name == "custom_agent" + assert agent.model == "gemini-2.0-flash" + assert agent.max_iterations == 5 + assert agent.error_retry_attempts == 3 + assert agent.stateful is True + + def test_extract_code_block_tool_code(self): + """Test code extraction from tool_code blocks.""" + agent = CodingAgent(name="test") + + response = """Here's some code: +```tool_code +result = search(query="test") +print(result) +``` +That should work.""" + + code = agent._extract_code_block(response) + assert code == 'result = search(query="test")\nprint(result)' + + def test_extract_code_block_python(self): + """Test code extraction from python blocks.""" + agent = CodingAgent(name="test") + + response = """Here's some code: +```python +x = 1 + 2 +print(x) +``` +Done.""" + + code = agent._extract_code_block(response) + assert code == "x = 1 + 2\nprint(x)" + + def test_extract_code_block_prefers_tool_code(self): + """Test that tool_code blocks are preferred over python blocks.""" + agent = CodingAgent(name="test") + + response = """Code: +```tool_code +tool_result = tool_call() +``` +Also: +```python +python_code = True +```""" + + code = agent._extract_code_block(response) + assert code == "tool_result = tool_call()" + + def test_extract_code_block_no_code(self): + """Test code extraction when no code block present.""" + agent = CodingAgent(name="test") + + response = "This is just text without any code blocks." + code = agent._extract_code_block(response) + assert code is None + + def test_build_error_feedback(self): + """Test error feedback formatting.""" + agent = CodingAgent(name="test") + + error = "NameError: name 'undefined_var' is not defined" + code = "print(undefined_var)" + + feedback = agent._build_error_feedback(error, code) + + assert "NameError" in feedback + assert "undefined_var" in feedback + assert code in feedback + assert "fix the error" in feedback.lower() + + def test_default_model(self): + """Test that default model is used when not specified.""" + agent = CodingAgent(name="test") + + # canonical_model property should return a BaseLlm + model = agent.canonical_model + assert model is not None + + def test_cleanup(self): + """Test that cleanup releases resources.""" + agent = CodingAgent(name="test") + agent._resolved_tools = [MagicMock()] + agent._coding_executor = MagicMock() + + agent.cleanup() + + assert agent._resolved_tools is None + assert agent._coding_executor is None + + +class TestCodingAgentTools: + """Tests for CodingAgent tool handling.""" + + def test_agent_with_function_tools(self): + """Test agent with function tools.""" + + def my_tool(query: str) -> dict: + """A test tool.""" + return {"result": query} + + agent = CodingAgent( + name="test", + tools=[my_tool], + ) + + assert len(agent.tools) == 1 + + def test_agent_with_base_tool(self): + """Test agent with BaseTool instances.""" + + class MockTool(BaseTool): + def __init__(self): + super().__init__(name="mock_tool", description="A mock tool") + + async def run_async(self, *, args, tool_context): + return {"result": "mock"} + + tool = MockTool() + agent = CodingAgent( + name="test", + tools=[tool], + ) + + assert len(agent.tools) == 1 + + @pytest.mark.asyncio + async def test_resolve_tools(self): + """Test tool resolution.""" + + def test_func(x: int) -> int: + """Test function.""" + return x * 2 + + agent = CodingAgent( + name="test", + tools=[test_func], + ) + + tools = await agent._resolve_tools() + assert len(tools) == 1 + assert tools[0].name == "test_func" diff --git a/tests/unittests/code_executors/test_allowlist_validator.py b/tests/unittests/code_executors/test_allowlist_validator.py new file mode 100644 index 0000000000..41d12c714f --- /dev/null +++ b/tests/unittests/code_executors/test_allowlist_validator.py @@ -0,0 +1,321 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for AllowlistValidator.""" + +from __future__ import annotations + +import pytest + +from google.adk.code_executors.allowlist_validator import AllowlistValidator +from google.adk.code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS +from google.adk.code_executors.allowlist_validator import extract_imports +from google.adk.code_executors.allowlist_validator import ImportValidationError +from google.adk.code_executors.allowlist_validator import is_import_allowed +from google.adk.code_executors.allowlist_validator import validate_imports + + +class TestExtractImports: + """Tests for extract_imports function.""" + + def test_simple_import(self): + """Test extracting simple import statements.""" + code = "import json" + imports = extract_imports(code) + + assert len(imports) == 1 + assert imports[0].module == "json" + assert imports[0].is_from_import is False + + def test_multiple_imports(self): + """Test extracting multiple imports.""" + code = """ +import json +import math +import re +""" + imports = extract_imports(code) + + assert len(imports) == 3 + modules = {i.module for i in imports} + assert modules == {"json", "math", "re"} + + def test_from_import(self): + """Test extracting from imports.""" + code = "from collections import defaultdict" + imports = extract_imports(code) + + assert len(imports) == 1 + assert imports[0].module == "collections" + assert imports[0].names == ["defaultdict"] + assert imports[0].is_from_import is True + + def test_from_import_multiple(self): + """Test extracting from imports with multiple names.""" + code = "from typing import List, Dict, Optional" + imports = extract_imports(code) + + assert len(imports) == 3 + for imp in imports: + assert imp.module == "typing" + assert imp.is_from_import is True + + def test_import_with_alias(self): + """Test extracting imports with aliases.""" + code = "import numpy as np" + imports = extract_imports(code) + + assert len(imports) == 1 + assert imports[0].module == "numpy" + assert imports[0].alias == "np" + + def test_submodule_import(self): + """Test extracting submodule imports.""" + code = "import os.path" + imports = extract_imports(code) + + assert len(imports) == 1 + assert imports[0].module == "os.path" + + def test_from_submodule_import(self): + """Test extracting from submodule imports.""" + code = "from collections.abc import Mapping" + imports = extract_imports(code) + + assert len(imports) == 1 + assert imports[0].module == "collections.abc" + assert imports[0].names == ["Mapping"] + + def test_syntax_error(self): + """Test handling of syntax errors.""" + code = "import json\nthis is not valid python" + + with pytest.raises(SyntaxError): + extract_imports(code) + + def test_no_imports(self): + """Test code with no imports.""" + code = "x = 1 + 2\nprint(x)" + imports = extract_imports(code) + + assert len(imports) == 0 + + +class TestIsImportAllowed: + """Tests for is_import_allowed function.""" + + def test_direct_match(self): + """Test direct import match.""" + allowlist = frozenset({"json", "math"}) + + assert is_import_allowed("json", allowlist) is True + assert is_import_allowed("math", allowlist) is True + assert is_import_allowed("os", allowlist) is False + + def test_wildcard_match(self): + """Test wildcard pattern matching.""" + allowlist = frozenset({"collections.*"}) + + assert is_import_allowed("collections.abc", allowlist) is True + assert is_import_allowed("collections.defaultdict", allowlist) is True + assert is_import_allowed("itertools", allowlist) is False + + def test_deep_wildcard_match(self): + """Test wildcard matching for deep submodules.""" + allowlist = frozenset({"collections.*"}) + + assert is_import_allowed("collections.abc.Mapping", allowlist) is True + + def test_exact_vs_wildcard(self): + """Test that exact matches work without wildcard.""" + allowlist = frozenset({"numpy"}) + + assert is_import_allowed("numpy", allowlist) is True + # Without wildcard, submodules are not allowed + assert is_import_allowed("numpy.array", allowlist) is False + + def test_multiple_patterns(self): + """Test multiple patterns in allowlist.""" + allowlist = frozenset({"json", "typing.*", "collections"}) + + assert is_import_allowed("json", allowlist) is True + assert is_import_allowed("typing.List", allowlist) is True + assert is_import_allowed("collections", allowlist) is True + assert is_import_allowed("collections.abc", allowlist) is False + + +class TestValidateImports: + """Tests for validate_imports function.""" + + def test_all_allowed(self): + """Test code with all imports allowed.""" + code = """ +import json +import math +from typing import List +""" + allowlist = frozenset({"json", "math", "typing.*"}) + + violations = validate_imports(code, allowlist) + assert len(violations) == 0 + + def test_some_violations(self): + """Test code with some unauthorized imports.""" + code = """ +import json +import os +import subprocess +""" + allowlist = frozenset({"json"}) + + violations = validate_imports(code, allowlist) + assert len(violations) == 2 + assert any("os" in v for v in violations) + assert any("subprocess" in v for v in violations) + + def test_from_import_violations(self): + """Test from import violations.""" + code = "from os import system" + allowlist = frozenset({"json"}) + + violations = validate_imports(code, allowlist) + assert len(violations) == 1 + assert "os" in violations[0] + + def test_syntax_error_violation(self): + """Test that syntax errors are reported as violations.""" + code = "import json\n$$$invalid" + allowlist = frozenset({"json"}) + + violations = validate_imports(code, allowlist) + assert len(violations) == 1 + assert "Syntax error" in violations[0] + + +class TestImportValidationError: + """Tests for ImportValidationError exception.""" + + def test_error_message(self): + """Test error message formatting.""" + violations = ["Unauthorized import: os", "Unauthorized import: subprocess"] + code = "import os\nimport subprocess" + + error = ImportValidationError(violations, code) + + assert "Import validation failed" in str(error) + assert "os" in str(error) + assert "subprocess" in str(error) + + def test_error_attributes(self): + """Test error attributes.""" + violations = ["violation1", "violation2"] + code = "some code" + + error = ImportValidationError(violations, code) + + assert error.violations == violations + assert error.code == code + + +class TestAllowlistValidator: + """Tests for AllowlistValidator class.""" + + def test_default_allowlist(self): + """Test validator with default allowlist.""" + validator = AllowlistValidator() + + # These should be in the default safe imports + assert validator.is_allowed("json") is True + assert validator.is_allowed("math") is True + assert validator.is_allowed("typing") is True + + # These should not be in the default safe imports + assert validator.is_allowed("os") is False + assert validator.is_allowed("subprocess") is False + + def test_custom_allowlist(self): + """Test validator with custom allowlist.""" + custom = frozenset({"custom_module"}) + validator = AllowlistValidator(allowlist=custom) + + assert validator.is_allowed("custom_module") is True + assert validator.is_allowed("json") is False + + def test_additional_imports(self): + """Test adding additional imports to default.""" + additional = frozenset({"custom_module", "another_module"}) + validator = AllowlistValidator(additional_imports=additional) + + # Should have both default and additional + assert validator.is_allowed("json") is True + assert validator.is_allowed("custom_module") is True + assert validator.is_allowed("another_module") is True + + def test_validate_method(self): + """Test validate method returns violations.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) + + violations = validator.validate("import json\nimport os") + assert len(violations) == 1 + assert "os" in violations[0] + + def test_validate_strict_raises(self): + """Test validate_strict raises on violations.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) + + with pytest.raises(ImportValidationError): + validator.validate_strict("import os") + + def test_validate_strict_passes(self): + """Test validate_strict passes when no violations.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) + + # Should not raise + validator.validate_strict("import json") + + def test_add_allowed_imports(self): + """Test adding imports after construction.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) + + assert validator.is_allowed("os") is False + + validator.add_allowed_imports({"os"}) + + assert validator.is_allowed("os") is True + + +class TestDefaultSafeImports: + """Tests for the default safe imports list.""" + + def test_common_safe_imports_included(self): + """Test that common safe imports are in the default list.""" + assert "json" in DEFAULT_SAFE_IMPORTS + assert "math" in DEFAULT_SAFE_IMPORTS + assert "re" in DEFAULT_SAFE_IMPORTS + assert "datetime" in DEFAULT_SAFE_IMPORTS + assert "typing" in DEFAULT_SAFE_IMPORTS + assert "collections" in DEFAULT_SAFE_IMPORTS + + def test_dangerous_imports_not_included(self): + """Test that dangerous imports are not in the default list.""" + assert "os" not in DEFAULT_SAFE_IMPORTS + assert "subprocess" not in DEFAULT_SAFE_IMPORTS + assert "sys" not in DEFAULT_SAFE_IMPORTS + assert "socket" not in DEFAULT_SAFE_IMPORTS + assert "ctypes" not in DEFAULT_SAFE_IMPORTS + + def test_wildcard_patterns_included(self): + """Test that wildcard patterns are included.""" + assert "collections.*" in DEFAULT_SAFE_IMPORTS + assert "typing.*" in DEFAULT_SAFE_IMPORTS diff --git a/tests/unittests/code_executors/test_tool_code_generator.py b/tests/unittests/code_executors/test_tool_code_generator.py new file mode 100644 index 0000000000..39a676da42 --- /dev/null +++ b/tests/unittests/code_executors/test_tool_code_generator.py @@ -0,0 +1,320 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ToolCodeGenerator.""" + +from __future__ import annotations + +import pytest +from unittest.mock import MagicMock + +from google.genai import types + +from google.adk.code_executors.tool_code_generator import generate_full_code_with_stubs +from google.adk.code_executors.tool_code_generator import generate_runtime_header +from google.adk.code_executors.tool_code_generator import generate_system_prompt +from google.adk.code_executors.tool_code_generator import generate_tool_stubs +from google.adk.tools.base_tool import BaseTool + + +class MockTool(BaseTool): + """Mock tool for testing.""" + + def __init__( + self, + name: str = "mock_tool", + description: str = "A mock tool for testing", + params: dict = None, + ): + super().__init__(name=name, description=description) + self._params = params or {} + + def _get_declaration(self): + properties = {} + required = [] + + for param_name, param_info in self._params.items(): + properties[param_name] = { + "type": param_info.get("type", "string"), + "description": param_info.get("description", ""), + } + if param_info.get("required", False): + required.append(param_name) + + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type="object", + properties=properties, + required=required if required else None, + ), + ) + + async def run_async(self, *, args, tool_context): + return {"result": "mock"} + + +class TestGenerateRuntimeHeader: + """Tests for generate_runtime_header function.""" + + def test_generates_valid_header(self): + """Test that the header contains required elements.""" + url = "http://localhost:8765" + header = generate_runtime_header(url) + + # Should contain the URL + assert url in header + + # Should contain helper functions + assert "_call_adk_tool" in header + assert "final_answer" in header + assert "__get_tool_traces" in header + + # Should be valid Python syntax + compile(header, "", "exec") + + def test_header_with_different_urls(self): + """Test header generation with different URLs.""" + urls = [ + "http://localhost:8765", + "http://host.docker.internal:9999", + "http://172.17.0.1:8765", + ] + + for url in urls: + header = generate_runtime_header(url) + assert url in header + + def test_header_contains_trace_collection(self): + """Test that header contains trace collection code.""" + header = generate_runtime_header("http://localhost:8765") + + assert "__ADK_TOOL_TRACES" in header + assert "__get_tool_traces" in header + assert "__clear_tool_traces" in header + + def test_header_contains_final_answer_marker(self): + """Test that header contains final answer marker.""" + header = generate_runtime_header("http://localhost:8765") + + assert "__FINAL_ANSWER__" in header + assert "final_answer" in header + + +class TestGenerateToolStubs: + """Tests for generate_tool_stubs function.""" + + def test_generates_stub_for_tool(self): + """Test generating stub for a single tool.""" + tool = MockTool( + name="search", + description="Search for information", + params={ + "query": { + "type": "string", + "description": "The search query", + "required": True, + } + }, + ) + + stubs = generate_tool_stubs([tool]) + + # Should contain function definition + assert "def search(" in stubs + assert "query" in stubs + + # Should be valid Python + compile(stubs, "", "exec") + + def test_generates_stubs_for_multiple_tools(self): + """Test generating stubs for multiple tools.""" + tools = [ + MockTool(name="tool1", description="First tool"), + MockTool(name="tool2", description="Second tool"), + MockTool(name="tool3", description="Third tool"), + ] + + stubs = generate_tool_stubs(tools) + + assert "def tool1(" in stubs + assert "def tool2(" in stubs + assert "def tool3(" in stubs + + def test_stub_includes_docstring(self): + """Test that stubs include docstrings.""" + tool = MockTool( + name="my_tool", + description="A tool that does something useful", + ) + + stubs = generate_tool_stubs([tool]) + + assert '"""' in stubs + assert "A tool that does something useful" in stubs + + def test_stub_includes_type_hints(self): + """Test that stubs include type hints.""" + tool = MockTool( + name="typed_tool", + description="A typed tool", + params={ + "count": {"type": "integer", "description": "A count"}, + "name": {"type": "string", "description": "A name"}, + "enabled": {"type": "boolean", "description": "Is enabled"}, + }, + ) + + stubs = generate_tool_stubs([tool]) + + assert "int" in stubs + assert "str" in stubs + assert "bool" in stubs + + def test_empty_tool_list(self): + """Test generating stubs for empty tool list.""" + stubs = generate_tool_stubs([]) + + # Should still be valid Python + compile(stubs, "", "exec") + + +class TestGenerateSystemPrompt: + """Tests for generate_system_prompt function.""" + + def test_generates_prompt_with_tools(self): + """Test generating system prompt with tools.""" + tools = [ + MockTool( + name="search", + description="Search the web", + params={"query": {"type": "string", "required": True}}, + ), + ] + + prompt = generate_system_prompt(tools) + + # Should contain tool documentation + assert "search" in prompt + assert "Search the web" in prompt + + # Should contain usage instructions + assert "tool_code" in prompt + assert "final_answer" in prompt + + def test_generates_prompt_with_custom_instruction(self): + """Test generating prompt with custom instruction.""" + tools = [] + custom = "Always be polite and helpful." + + prompt = generate_system_prompt(tools, custom_instruction=custom) + + assert custom in prompt + + def test_generates_prompt_with_examples(self): + """Test that prompt contains examples.""" + tools = [] + prompt = generate_system_prompt(tools) + + assert "Example" in prompt + assert "```tool_code" in prompt + + def test_generates_prompt_with_parameter_docs(self): + """Test that prompt includes parameter documentation.""" + tools = [ + MockTool( + name="get_weather", + description="Get weather for a city", + params={ + "city": { + "type": "string", + "description": "The city name", + "required": True, + }, + "units": { + "type": "string", + "description": "Temperature units", + "required": False, + }, + }, + ), + ] + + prompt = generate_system_prompt(tools) + + assert "city" in prompt + assert "units" in prompt + assert "required" in prompt.lower() or "optional" in prompt.lower() + + +class TestGenerateFullCodeWithStubs: + """Tests for generate_full_code_with_stubs function.""" + + def test_generates_complete_code(self): + """Test generating complete executable code.""" + tools = [MockTool(name="my_tool", description="A tool")] + user_code = "result = my_tool()\nprint(result)" + + full_code = generate_full_code_with_stubs( + user_code=user_code, + tools=tools, + tool_server_url="http://localhost:8765", + ) + + # Should contain runtime header + assert "_call_adk_tool" in full_code + + # Should contain tool stub + assert "def my_tool(" in full_code + + # Should contain user code + assert user_code in full_code + + # Should be valid Python + compile(full_code, "", "exec") + + def test_generated_code_outputs_traces(self): + """Test that generated code outputs traces.""" + tools = [] + user_code = "x = 1" + + full_code = generate_full_code_with_stubs( + user_code=user_code, + tools=tools, + tool_server_url="http://localhost:8765", + ) + + assert "__TOOL_TRACE__" in full_code + + def test_generated_code_is_executable(self): + """Test that generated code can be compiled.""" + tools = [ + MockTool(name="tool_a", description="Tool A"), + MockTool(name="tool_b", description="Tool B"), + ] + user_code = """ +result_a = tool_a() +result_b = tool_b() +print(result_a, result_b) +""" + + full_code = generate_full_code_with_stubs( + user_code=user_code, + tools=tools, + tool_server_url="http://localhost:8765", + ) + + # Should compile without errors + compile(full_code, "", "exec") From 5bc161ea4a63742bf20d2c1523779b58fbc615f1 Mon Sep 17 00:00:00 2001 From: Sudhendra Date: Sat, 17 Jan 2026 19:25:30 -0600 Subject: [PATCH 02/10] successful data analyst code_agent test --- .gitignore | 2 + contributing/samples/coding_agent/README.md | 290 ++++++++++++++ contributing/samples/coding_agent/agent.py | 378 +++++++++++++----- src/google/adk/agents/coding_agent.py | 71 +++- .../adk/code_executors/tool_code_generator.py | 13 +- 5 files changed, 649 insertions(+), 105 deletions(-) create mode 100644 contributing/samples/coding_agent/README.md diff --git a/.gitignore b/.gitignore index 47f633c5c5..5f8ad01f2e 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,5 @@ CLAUDE.md .rooignore .bolt/ .v0/ + +CODING_AGENT_PLAN.md diff --git a/contributing/samples/coding_agent/README.md b/contributing/samples/coding_agent/README.md new file mode 100644 index 0000000000..cb5ce8b68c --- /dev/null +++ b/contributing/samples/coding_agent/README.md @@ -0,0 +1,290 @@ +# Data Analysis Agent + +A CodingAgent sample that demonstrates AI-powered data analysis with Python code execution. + +## Overview + +This sample showcases the CodingAgent's ability to: + +- Fetch datasets from URLs (CSV, JSON, text) +- Analyze data using pandas +- Create visualizations with matplotlib +- Generate statistical summaries and insights +- Execute multi-step reasoning through code + +The agent writes and executes Python code in a sandboxed Docker container, calling tools via HTTP IPC when needed. + +## Architecture + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ User Query │────>│ CodingAgent │────>│ Docker Container│ +│ │ │ (Gemini 2.5) │ │ (Python 3.11) │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + │ │ + │ │ Executes + v │ pandas/matplotlib + ┌──────────────┐ │ code + │ Tool Server │<────────────────┘ + │ (HTTP IPC) │ Tool calls + └──────────────┘ (fetch_url, etc.) +``` + +### How It Works + +1. User sends a natural language query +2. CodingAgent (powered by Gemini 2.5) generates Python code +3. Code is executed in a sandboxed Docker container +4. Tools (like `fetch_url`) are called via HTTP to the Tool Server on the host +5. Results are returned to the container, LLM iterates if needed +6. Final answer is provided via `final_answer()` function + +## Prerequisites + +- **Docker**: Must be installed and running +- **API Key**: Set `GOOGLE_API_KEY` in `.env` file or environment +- **Python**: 3.10+ (for running ADK CLI) + +## Quick Start + +### 1. Set up your API key + +Create a `.env` file in this directory: + +```bash +echo "GOOGLE_API_KEY=your_api_key_here" > contributing/samples/coding_agent/.env +``` + +### 2. Run the agent + +**Using CLI (interactive):** + +```bash +adk run contributing/samples/coding_agent +``` + +**Using Web UI:** + +```bash +adk web contributing/samples +``` + +Then navigate to `http://localhost:8000` and select `coding_agent`. + +## Available Tools + +### `fetch_url(url: str) -> dict` + +Fetches content from a URL. Supports CSV, JSON, and plain text. + +**Returns:** +- `content`: The fetched content as a string +- `content_type`: MIME type of the content +- `size`: Size in bytes +- `success`: Boolean indicating success +- `error`: Error message (only on failure) + +**Example:** +```python +result = fetch_url("https://example.com/data.csv") +if result["success"]: + csv_content = result["content"] +``` + +### `get_sample_datasets() -> dict` + +Returns available sample datasets with URLs and descriptions. + +**Available datasets:** +- `titanic`: Titanic passenger survival data (891 rows) +- `iris`: Iris flower classification data (150 rows) +- `tips`: Restaurant tipping data (244 rows) + +### `get_current_time() -> dict` + +Returns current date and time information. + +### `save_chart(image_data: str, filename: str) -> dict` + +Saves a chart image to the **host system** (not the Docker container). This is essential for making visualizations accessible outside of Docker. + +**Parameters:** +- `image_data`: Base64-encoded image data (PNG recommended) +- `filename`: Name for the saved file (e.g., "chart.png") + +**Returns:** +- `success`: Boolean indicating success +- `filepath`: Full path where the file was saved (e.g., `/tmp/adk_charts/chart.png`) +- `size`: Size of the saved file in bytes +- `error`: Error message (only on failure) + +**Example usage in generated code:** +```python +import base64 +import io +import matplotlib.pyplot as plt + +# Create your plot +plt.figure(figsize=(10, 6)) +plt.bar(['A', 'B', 'C'], [1, 2, 3]) + +# Save to buffer and encode +buf = io.BytesIO() +plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') +buf.seek(0) +image_data = base64.b64encode(buf.read()).decode('utf-8') +plt.close() + +# Save to host system +result = save_chart(image_data=image_data, filename="my_chart.png") +print(f"Chart saved to: {result['filepath']}") +``` + +### `list_saved_charts() -> dict` + +Lists all charts saved to the host system. + +**Returns:** +- `charts`: List of saved chart filenames +- `directory`: Directory path (`/tmp/adk_charts`) +- `count`: Number of charts found + +## Sample Datasets + +| Dataset | Description | Columns | +|---------|-------------|---------| +| **Titanic** | Passenger survival data from the Titanic disaster | PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked | +| **Iris** | Classic flower classification dataset | sepal_length, sepal_width, petal_length, petal_width, species | +| **Tips** | Restaurant tipping behavior data | total_bill, tip, sex, smoker, day, time, size | + +## Example Interactions + +### Basic Analysis + +``` +User: What is the survival rate on the Titanic? + +Agent: [Fetches Titanic CSV, calculates survival statistics] +Response: The overall survival rate on the Titanic was 38.4% (342 out of 891 passengers survived). +``` + +### Visualization + +``` +User: Create a bar chart showing survival rate by passenger class on the Titanic + +Agent: [Fetches data, creates matplotlib chart, uses save_chart tool] +Response: I've created a bar chart showing survival rates by class: +- 1st Class: 63% survival rate +- 2nd Class: 47% survival rate +- 3rd Class: 24% survival rate +Chart saved to: /tmp/adk_charts/survival_by_class.png +``` + +Charts are saved to `/tmp/adk_charts/` on your host system, making them accessible outside Docker. + +### Exploratory Analysis + +``` +User: Analyze the iris dataset and give me key insights + +Agent: [Fetches iris data, runs statistical analysis, creates visualizations] +Response: Key insights from the Iris dataset: +1. The dataset contains 150 samples across 3 species (50 each) +2. Setosa is clearly separable from the other two species +3. Petal length is the best single predictor for species classification +... +``` + +### Custom URL + +``` +User: Fetch the tips dataset and tell me which day has the highest average tip + +Agent: [Fetches tips CSV, analyzes by day] +Response: Sunday has the highest average tip at $3.26, followed by Saturday at $2.99. +``` + +## Troubleshooting + +### Docker not running + +``` +Error: Cannot connect to Docker daemon +``` + +**Solution:** Start Docker Desktop or the Docker service. + +### API key issues + +``` +Error: Missing key inputs argument! +``` + +**Solution:** Ensure `GOOGLE_API_KEY` is set in `.env` file or environment. + +### Container networking issues + +``` +Error: Connection refused to tool server +``` + +**Solution:** On macOS/Windows, Docker Desktop should handle `host.docker.internal` automatically. On Linux, you may need to configure the Docker bridge network. + +### Package installation in container + +The agent automatically installs pandas and matplotlib at runtime. If you see import errors, ensure the container has internet access. + +## Extending the Agent + +### Adding Custom Tools + +Add new tool functions in `agent.py`: + +```python +def my_custom_tool(param: str) -> dict: + """Description of what this tool does. + + Args: + param: Description of the parameter. + + Returns: + Dictionary with results. + """ + # Implementation + return {"result": "..."} + +# Add to the tools list +root_agent = CodingAgent( + ... + tools=[fetch_url, get_sample_datasets, get_current_time, my_custom_tool], + ... +) +``` + +### Using a Custom Container Image + +For faster execution with pre-installed packages: + +```python +code_executor=ContainerCodeExecutor( + image="my-custom-image:latest", # Image with pandas/matplotlib pre-installed +) +``` + +### Adjusting Iteration Limits + +```python +root_agent = CodingAgent( + ... + max_iterations=15, # More iterations for complex tasks + error_retry_attempts=3, # More retries on errors + ... +) +``` + +## Related Documentation + +- [CodingAgent Documentation](https://google.github.io/adk-docs/agents/coding-agent) +- [ContainerCodeExecutor](https://google.github.io/adk-docs/code-executors/container) +- [ADK Tools](https://google.github.io/adk-docs/tools) diff --git a/contributing/samples/coding_agent/agent.py b/contributing/samples/coding_agent/agent.py index d0f6bae10c..1c234864b9 100644 --- a/contributing/samples/coding_agent/agent.py +++ b/contributing/samples/coding_agent/agent.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Sample CodingAgent demonstrating code generation with tool usage. +"""Data Analysis Agent using CodingAgent. -This sample shows how to create a CodingAgent that can: -- Generate Python code to solve tasks -- Call tools as Python functions from within the generated code -- Execute code in a sandboxed container environment -- Provide final answers after multi-step reasoning +This sample demonstrates a CodingAgent configured as a data analyst that can: +- Fetch datasets from URLs (CSV, JSON, text) +- Analyze data using pandas +- Create visualizations using matplotlib +- Generate statistical summaries and insights Prerequisites: - Docker must be installed and running @@ -29,106 +29,117 @@ adk web contributing/samples Example queries: -- "What is 15% of 847?" -- "Calculate the compound interest on $10,000 at 5% annual rate for 3 years" -- "Search for the latest Python release and summarize the key features" +- "What is the survival rate on the Titanic?" +- "Create a bar chart showing survival rate by passenger class" +- "Analyze the iris dataset and create a scatter plot" """ +import base64 +import binascii +import os +import urllib.error +import urllib.request +from datetime import datetime + from google.adk.agents import CodingAgent from google.adk.code_executors import ContainerCodeExecutor +from google.adk.code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS -# Define sample tools that the CodingAgent can use -def calculator(expression: str) -> dict: - """Evaluate a mathematical expression. +# Sample dataset URLs +SAMPLE_DATASETS = { + "titanic": { + "url": "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv", + "description": "Titanic passenger data with survival information. 891 passengers with features like age, sex, class, fare, and survival status.", + "columns": "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked", + }, + "iris": { + "url": "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv", + "description": "Iris flower dataset. 150 samples of 3 species with sepal and petal measurements.", + "columns": "sepal_length, sepal_width, petal_length, petal_width, species", + }, + "tips": { + "url": "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/tips.csv", + "description": "Restaurant tips dataset. 244 records with bill amount, tip, and customer info.", + "columns": "total_bill, tip, sex, smoker, day, time, size", + }, +} - Args: - expression: A mathematical expression to evaluate (e.g., "2 + 2 * 3"). - Returns: - Dictionary with the result or error message. - """ - try: - # Safe evaluation of mathematical expressions - allowed_names = { - "abs": abs, - "round": round, - "min": min, - "max": max, - "sum": sum, - "pow": pow, - } - result = eval(expression, {"__builtins__": {}}, allowed_names) - return {"result": result, "expression": expression} - except Exception as e: - return {"error": str(e), "expression": expression} +def fetch_url(url: str) -> dict: + """Fetch content from a URL. - -def web_search(query: str, max_results: int = 5) -> dict: - """Search the web for information. + Fetches data from the specified URL and returns the content along with + metadata. Supports CSV, JSON, and plain text content. Args: - query: The search query. - max_results: Maximum number of results to return. + url: The URL to fetch content from. Returns: - Dictionary with search results. + Dictionary containing: + - content: The fetched content as a string + - content_type: The MIME type of the content + - size: Size of the content in bytes + - url: The original URL + - success: Whether the fetch was successful + - error: Error message if fetch failed (only present on failure) """ - # This is a mock implementation for demonstration - # In production, you would integrate with a real search API - return { - "query": query, - "results": [ - { - "title": f"Result {i + 1} for: {query}", - "snippet": f"This is a sample result snippet for '{query}'...", - "url": f"https://example.com/result{i + 1}", + try: + req = urllib.request.Request( + url, + headers={"User-Agent": "Mozilla/5.0 (compatible; ADK-DataAnalyst/1.0)"}, + ) + with urllib.request.urlopen(req, timeout=30) as response: + content = response.read().decode("utf-8") + content_type = response.headers.get("Content-Type", "text/plain") + return { + "content": content, + "content_type": content_type, + "size": len(content), + "url": url, + "success": True, } - for i in range(min(max_results, 3)) - ], - "total_results": max_results, - } + except urllib.error.URLError as e: + return { + "content": "", + "url": url, + "success": False, + "error": f"Failed to fetch URL: {str(e)}", + } + except Exception as e: + return { + "content": "", + "url": url, + "success": False, + "error": f"Unexpected error: {str(e)}", + } -def read_file(path: str) -> dict: - """Read contents of a file. +def get_sample_datasets() -> dict: + """Get available sample datasets with their URLs and descriptions. - Args: - path: Path to the file to read. + Returns a dictionary of sample datasets that can be used for analysis. + Each dataset includes a URL, description, and column information. Returns: - Dictionary with file contents or error. + Dictionary with dataset names as keys, each containing: + - url: Direct URL to download the CSV file + - description: Brief description of the dataset + - columns: Comma-separated list of column names """ - # This is a mock implementation for demonstration - # In production, you would implement actual file reading with proper security - mock_files = { - "data.csv": { - "content": "name,amount\nAlice,100\nBob,200\nCharlie,150", - "rows": [ - {"name": "Alice", "amount": "100"}, - {"name": "Bob", "amount": "200"}, - {"name": "Charlie", "amount": "150"}, - ], - }, - "config.json": { - "content": '{"setting": "value"}', - "data": {"setting": "value"}, - }, - } - - if path in mock_files: - return {"path": path, **mock_files[path]} - return {"error": f"File not found: {path}", "path": path} + return SAMPLE_DATASETS def get_current_time() -> dict: """Get the current date and time. Returns: - Dictionary with current timestamp information. + Dictionary containing: + - timestamp: ISO format timestamp + - year, month, day: Date components + - hour, minute, second: Time components + - weekday: Name of the day of the week """ - from datetime import datetime - now = datetime.now() return { "timestamp": now.isoformat(), @@ -137,42 +148,211 @@ def get_current_time() -> dict: "day": now.day, "hour": now.hour, "minute": now.minute, + "second": now.second, "weekday": now.strftime("%A"), } -# Create the CodingAgent with tools +# Directory on host system to save charts +HOST_CHARTS_DIR = "/tmp/adk_charts" + + +def save_chart(image_data: str, filename: str) -> dict: + """Save a chart image to the host system. + + This tool saves base64-encoded image data to the host machine's filesystem, + making charts accessible outside the Docker container. + + To use this tool, first save your matplotlib figure to a bytes buffer, + then encode it as base64: + + Example: + import base64 + import io + import matplotlib.pyplot as plt + + # Create your plot + plt.figure() + plt.plot([1, 2, 3], [1, 4, 9]) + + # Save to buffer and encode + buf = io.BytesIO() + plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') + buf.seek(0) + image_data = base64.b64encode(buf.read()).decode('utf-8') + plt.close() + + # Save to host system + result = save_chart(image_data=image_data, filename="my_chart.png") + + Args: + image_data: Base64-encoded image data (PNG format recommended). + filename: Name for the saved file (e.g., "chart.png"). + + Returns: + Dictionary containing: + - success: Whether the save was successful + - filepath: Full path where the file was saved on the host + - size: Size of the saved file in bytes + - error: Error message if save failed (only present on failure) + """ + try: + # Ensure the output directory exists + os.makedirs(HOST_CHARTS_DIR, exist_ok=True) + + # Sanitize filename + safe_filename = os.path.basename(filename) + if not safe_filename: + safe_filename = "chart.png" + + filepath = os.path.join(HOST_CHARTS_DIR, safe_filename) + + # Decode and save + image_bytes = base64.b64decode(image_data) + with open(filepath, "wb") as f: + f.write(image_bytes) + + return { + "success": True, + "filepath": filepath, + "size": len(image_bytes), + "message": f"Chart saved to {filepath}", + } + except binascii.Error as e: + return { + "success": False, + "error": f"Invalid base64 data: {str(e)}", + } + except OSError as e: + return { + "success": False, + "error": f"Failed to save file: {str(e)}", + } + except Exception as e: + return { + "success": False, + "error": f"Unexpected error: {str(e)}", + } + + +def list_saved_charts() -> dict: + """List all charts saved on the host system. + + Returns: + Dictionary containing: + - success: Whether the operation was successful + - charts: List of saved chart filenames + - directory: The directory where charts are saved + - count: Number of charts found + """ + try: + if not os.path.exists(HOST_CHARTS_DIR): + return { + "success": True, + "charts": [], + "directory": HOST_CHARTS_DIR, + "count": 0, + } + + charts = [ + f + for f in os.listdir(HOST_CHARTS_DIR) + if f.lower().endswith((".png", ".jpg", ".jpeg", ".svg", ".pdf")) + ] + return { + "success": True, + "charts": charts, + "directory": HOST_CHARTS_DIR, + "count": len(charts), + } + except Exception as e: + return { + "success": False, + "error": f"Failed to list charts: {str(e)}", + } + + +# Additional imports allowed for data analysis +DATA_ANALYSIS_IMPORTS = frozenset( + { + # Data analysis + "pandas", + "pandas.*", + "numpy", + "numpy.*", + # Visualization + "matplotlib", + "matplotlib.*", + "seaborn", + "seaborn.*", + # Data I/O + "csv", + "io", + "io.*", + # Encoding for chart saving + "base64", + # Subprocess for pip installs + "subprocess", + } +) + + +# Create the Data Analysis Agent root_agent = CodingAgent( - name="code_assistant", + name="data_analyst", description=( - "An AI assistant that solves tasks by writing and executing Python code. " - "It can perform calculations, search for information, read files, and more." + "An AI data analyst that analyzes datasets, creates visualizations, " + "and generates insights using Python code execution." ), model="gemini-2.5-flash", - instruction=""" -You are a helpful coding assistant that solves problems by writing Python code. - -When given a task: -1. Think about what tools and computations you need -2. Write clear, well-commented Python code -3. Use the available tools as needed -4. Print intermediate results to verify your work -5. Call final_answer() with your result - -Always show your reasoning through code comments and print statements. -If a task cannot be completed with the available tools, explain why. + instruction="""You are a data analyst. Analyze data, create visualizations, and provide insights. + +IMPORTANT: First install required packages before using them: +```tool_code +import subprocess +subprocess.run(["pip", "install", "-q", "pandas", "matplotlib", "seaborn", "numpy"], check=True) +print("Packages installed successfully") +``` + +Then use the available tools to fetch datasets. Write Python code to analyze data using pandas and create charts with matplotlib. + +CRITICAL: You MUST use the save_chart() tool to save charts - do NOT use plt.savefig() to a file path directly. The save_chart() tool transfers the chart to the host system. Here is the REQUIRED pattern: +```tool_code +import base64 +import io +import matplotlib.pyplot as plt + +# Create your plot +plt.figure(figsize=(10, 6)) +# ... your plotting code ... + +# Save to buffer and encode as base64 +buf = io.BytesIO() +plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') +buf.seek(0) +image_data = base64.b64encode(buf.read()).decode('utf-8') +plt.close() + +# Use the save_chart tool to save to host system +result = save_chart(image_data=image_data, filename="my_chart.png") +print(f"Chart saved: {result}") +``` + +The chart will be saved to the host system at /tmp/adk_charts/. Always report this filepath in your final answer. + +Call final_answer() with your findings when done. """, tools=[ - calculator, - web_search, - read_file, + fetch_url, + get_sample_datasets, get_current_time, + save_chart, + list_saved_charts, ], - # Use ContainerCodeExecutor for sandboxed execution - # Note: Docker must be installed and running code_executor=ContainerCodeExecutor( image="python:3.11-slim", ), + authorized_imports=DEFAULT_SAFE_IMPORTS | DATA_ANALYSIS_IMPORTS, max_iterations=10, error_retry_attempts=2, stateful=False, diff --git a/src/google/adk/agents/coding_agent.py b/src/google/adk/agents/coding_agent.py index 46ccfcb8ed..42128dacd8 100644 --- a/src/google/adk/agents/coding_agent.py +++ b/src/google/adk/agents/coding_agent.py @@ -304,6 +304,66 @@ def _extract_code_block(self, response_text: str) -> Optional[str]: return None + def _is_real_error(self, stderr: str) -> bool: + """Check if stderr contains a real error vs just warnings. + + Args: + stderr: The stderr output from code execution. + + Returns: + True if stderr contains a real error, False if just warnings. + """ + if not stderr: + return False + + # Patterns that indicate this is just a warning, not an error + warning_patterns = [ + "WARNING: Running pip as the 'root' user", + "[notice] A new release of pip", + "[notice] To update, run:", + "pip install --upgrade pip", + "UserWarning:", + "DeprecationWarning:", + "FutureWarning:", + "RuntimeWarning:", + ] + + # Check if ALL lines are just warnings + lines = stderr.strip().split("\n") + real_error_lines = [] + for line in lines: + line_stripped = line.strip() + if not line_stripped: + continue + is_warning = any( + pattern.lower() in line_stripped.lower() for pattern in warning_patterns + ) + if not is_warning: + real_error_lines.append(line) + + # Also check for actual error keywords + error_keywords = [ + "error:", + "traceback", + "exception", + "syntaxerror", + "nameerror", + "typeerror", + "valueerror", + "importerror", + "modulenotfounderror", + "attributeerror", + "keyerror", + "indexerror", + "zerodivisionerror", + ] + + stderr_lower = stderr.lower() + has_error_keyword = any(keyword in stderr_lower for keyword in error_keywords) + + # Consider it a real error if there are non-warning lines with error keywords + return bool(real_error_lines) and has_error_keyword + def _build_error_feedback( self, error: str, @@ -434,8 +494,11 @@ async def _run_async_impl( } ) - # Check for errors - if exec_result.code_result.stderr: + # Check for errors - ignore warnings from pip and other non-fatal stderr + stderr = exec_result.code_result.stderr or "" + is_real_error = self._is_real_error(stderr) + + if is_real_error: error_count += 1 state.error_count = error_count @@ -443,13 +506,13 @@ async def _run_async_impl( # Too many errors - give up final_answer = ( f"I encountered too many errors while executing code. " - f"Last error: {exec_result.code_result.stderr}" + f"Last error: {stderr}" ) break # Build error feedback and add to conversation error_feedback = self._build_error_feedback( - exec_result.code_result.stderr, + stderr, code, ) contents.append( diff --git a/src/google/adk/code_executors/tool_code_generator.py b/src/google/adk/code_executors/tool_code_generator.py index e2e18271dd..ef73c61376 100644 --- a/src/google/adk/code_executors/tool_code_generator.py +++ b/src/google/adk/code_executors/tool_code_generator.py @@ -265,7 +265,11 @@ def {tool.name}({param_str}) -> dict: Tool execution result as a dictionary. """ kwargs = {{k: v for k, v in locals().items() if v is not None}} - return _call_adk_tool("{tool.name}", **kwargs) + response = _call_adk_tool("{tool.name}", **kwargs) + # Extract the result from the tool server response + if isinstance(response, dict) and "result" in response: + return response["result"] + return response ''' return stub @@ -425,7 +429,12 @@ def generate_system_prompt( 2. **Always print results**: Use print() to see what tools return. 3. **Handle errors gracefully**: If a tool fails, try an alternative approach. 4. **Call final_answer()**: When done, call final_answer() with your result. -5. **No external imports**: Only use the provided tools and standard library. +5. **Install packages if needed**: If you need external libraries (pandas, matplotlib, numpy, etc.), install them first using subprocess: + ```python + import subprocess + subprocess.run(["pip", "install", "-q", "pandas", "matplotlib", "seaborn"], check=True) + ``` + Then import and use them normally. {custom_instruction} """ From f93db6678902e8b3ac29e5dd3a5839f176f855f5 Mon Sep 17 00:00:00 2001 From: Sudhendra Date: Sat, 17 Jan 2026 19:30:15 -0600 Subject: [PATCH 03/10] feat(agents): Add save_chart tool and improve error handling for CodingAgent - Add save_chart tool to save visualizations to host filesystem - Add list_saved_charts tool to list saved charts - Add _is_real_error method to distinguish between warnings and errors - Fix pip warnings being treated as execution errors - Update system prompt with package installation instructions - Add base64 to authorized imports for chart encoding - Update README with new tool documentation - Create GitHub issue template for CodingAgent feature --- .github/CODING_AGENT_ISSUE.md | 165 +++ .github/CODING_AGENT_PLAN.md | 148 +++ .../ISSUE_TEMPLATE/coding_agent_feature.md | 226 ++++ contributing/samples/coding_agent/agent.py | 451 ++++---- src/google/adk/agents/coding_agent.py | 973 +++++++++--------- .../adk/code_executors/allowlist_validator.py | 533 +++++----- .../coding_agent_code_executor.py | 828 +++++++-------- .../adk/code_executors/tool_code_generator.py | 330 +++--- .../code_executors/tool_execution_server.py | 590 +++++------ tests/unittests/agents/test_coding_agent.py | 445 ++++---- .../test_allowlist_validator.py | 415 ++++---- .../test_tool_code_generator.py | 493 +++++---- 12 files changed, 3076 insertions(+), 2521 deletions(-) create mode 100644 .github/CODING_AGENT_ISSUE.md create mode 100644 .github/CODING_AGENT_PLAN.md create mode 100644 .github/ISSUE_TEMPLATE/coding_agent_feature.md diff --git a/.github/CODING_AGENT_ISSUE.md b/.github/CODING_AGENT_ISSUE.md new file mode 100644 index 0000000000..5f6372beca --- /dev/null +++ b/.github/CODING_AGENT_ISSUE.md @@ -0,0 +1,165 @@ +# GitHub Issue: CodingAgent Feature Request + +**Use this content to create an issue at: https://github.com/google/adk-python/issues/new?template=feature_request.md** + +--- + +## Title + +`feat(agents): Add CodingAgent for code generation and sandboxed execution` + +--- + +## Is your feature request related to a problem? Please describe. + +Currently, ADK agents can only interact with the world through pre-defined tools. While powerful, this approach has limitations: + +1. **Limited flexibility**: Users must anticipate all possible operations and create tools for each +2. **No computational capability**: Agents cannot perform complex calculations, data analysis, or create visualizations without custom tools +3. **No iteration**: Standard tool-calling doesn't easily support multi-step reasoning with intermediate computations +4. **Competitive gap**: Other platforms (OpenAI Code Interpreter, Anthropic's computer use) offer code execution capabilities + +**User pain points:** +- "I want my agent to analyze a CSV file and create a chart" - requires building custom tools +- "I need multi-step calculations with intermediate results" - awkward with standard tools +- "I want the agent to figure out HOW to solve a problem, not just call predefined functions" + +--- + +## Describe the solution you'd like + +A new experimental agent type called **CodingAgent** that: + +1. Receives a task from the user +2. Generates Python code to accomplish the task (using `tool_code` blocks) +3. Executes the code in a sandboxed Docker container +4. Processes results and either provides an answer or continues iterating (ReAct loop) +5. Can call ADK tools from within generated code via HTTP IPC + +### Architecture + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ User Query │────▶│ CodingAgent │────▶│ Docker Container│ +│ │ │ (Gemini LLM) │ │ (Python 3.11) │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + │ │ + │ │ Executes + ▼ │ generated code + ┌──────────────┐ │ + │ Tool Server │◀────────────────┘ + │ (HTTP IPC) │ Tool calls via HTTP + └──────────────┘ +``` + +### API Design + +```python +from google.adk.agents import CodingAgent +from google.adk.code_executors import ContainerCodeExecutor + +def fetch_data(url: str) -> dict: + """Fetch data from a URL.""" + # Implementation... + +root_agent = CodingAgent( + name="data_analyst", + model="gemini-2.5-flash", + instruction="You are a data analyst. Analyze data and provide insights.", + tools=[fetch_data], # Tools available to generated code + code_executor=ContainerCodeExecutor(image="python:3.11-slim"), + authorized_imports=DEFAULT_SAFE_IMPORTS | {"pandas", "matplotlib"}, + max_iterations=10, + error_retry_attempts=2, +) +``` + +### Key Components + +| Component | Description | +|-----------|-------------| +| CodingAgent | Main agent class with ReAct loop | +| CodingAgentCodeExecutor | Wrapper that injects tool stubs into code | +| ToolCodeGenerator | Generates Python function stubs for tools | +| ToolExecutionServer | HTTP server for tool IPC from container | +| AllowlistValidator | Import security validation | + +### Security Features + +1. **Sandboxed execution**: All code runs in isolated Docker containers +2. **Import allowlisting**: Only authorized imports are permitted (configurable) +3. **Tool isolation**: Tools execute on host via HTTP, not in container +4. **No filesystem access**: Container has no access to host filesystem + +--- + +## Describe alternatives you've considered + +### Alternative 1: Extend LlmAgent with code execution +- **Pros**: Simpler architecture, reuses existing agent +- **Cons**: Conflates two distinct patterns, harder to configure + +### Alternative 2: Code execution as a tool only +- **Pros**: Minimal changes, fits existing model +- **Cons**: No ReAct loop, no iteration, limited capability + +### Alternative 3: Use external code execution service +- **Pros**: Offloads security concerns +- **Cons**: Adds external dependency, latency, cost + +**Chosen approach**: Dedicated CodingAgent provides cleanest separation of concerns, explicit configuration, and full control over the execution environment. + +--- + +## Additional context + +### Implementation Status + +I have a working implementation ready for PR submission: + +**New files (~2,500 lines of production code):** +- `src/google/adk/agents/coding_agent.py` - Main agent class +- `src/google/adk/agents/coding_agent_config.py` - Configuration +- `src/google/adk/code_executors/coding_agent_code_executor.py` - Executor wrapper +- `src/google/adk/code_executors/tool_code_generator.py` - Code generation +- `src/google/adk/code_executors/tool_execution_server.py` - HTTP IPC server +- `src/google/adk/code_executors/allowlist_validator.py` - Security validation + +**Sample agent:** +- `contributing/samples/coding_agent/` - Data Analysis Agent demo + +**Unit tests (~950 lines):** +- `tests/unittests/agents/test_coding_agent.py` +- `tests/unittests/code_executors/test_allowlist_validator.py` +- `tests/unittests/code_executors/test_tool_code_generator.py` + +### Tested Scenarios + +| Test | Status | +|------|--------| +| Basic math queries | ✅ Passed | +| Data analysis with pandas | ✅ Passed | +| Visualization with matplotlib | ✅ Passed | +| Multi-step analysis | ✅ Passed | +| Tool calling via HTTP IPC | ✅ Passed | +| Chart saving to host system | ✅ Passed | +| Error handling and retries | ✅ Passed | + +### Related Work +- OpenAI Code Interpreter +- Anthropic Computer Use +- Google AI Studio code execution + +### Future Enhancements (out of scope for initial PR) +- Stateful execution (persist variables across turns) +- Custom container images with pre-installed packages +- Integration with VertexAI code execution +- Support for additional languages + +--- + +## Labels to add + +- `enhancement` +- `agents` +- `new-feature` diff --git a/.github/CODING_AGENT_PLAN.md b/.github/CODING_AGENT_PLAN.md new file mode 100644 index 0000000000..a20436094f --- /dev/null +++ b/.github/CODING_AGENT_PLAN.md @@ -0,0 +1,148 @@ +# CodingAgent - Implementation Plan & Status + +This document tracks the implementation of CodingAgent, an experimental agent type that generates and executes Python code in sandboxed containers. + +## Overview + +CodingAgent is a ReAct-style agent that: +- Generates Python code to solve tasks using an LLM (Gemini) +- Executes code in sandboxed Docker containers +- Calls ADK tools from generated code via HTTP IPC +- Iterates until a final answer is produced + +## Implementation Status + +### Core Components ✅ Complete + +| Component | File | Status | Lines | +|-----------|------|--------|-------| +| CodingAgent | `src/google/adk/agents/coding_agent.py` | ✅ Complete | ~610 | +| CodingAgentConfig | `src/google/adk/agents/coding_agent_config.py` | ✅ Complete | ~225 | +| CodingAgentCodeExecutor | `src/google/adk/code_executors/coding_agent_code_executor.py` | ✅ Complete | ~505 | +| ToolCodeGenerator | `src/google/adk/code_executors/tool_code_generator.py` | ✅ Complete | ~475 | +| ToolExecutionServer | `src/google/adk/code_executors/tool_execution_server.py` | ✅ Complete | ~365 | +| AllowlistValidator | `src/google/adk/code_executors/allowlist_validator.py` | ✅ Complete | ~355 | + +### Sample Agent ✅ Complete + +| File | Status | Description | +|------|--------|-------------| +| `contributing/samples/coding_agent/agent.py` | ✅ Complete | Data Analysis Agent (~360 lines) | +| `contributing/samples/coding_agent/README.md` | ✅ Complete | Documentation (~290 lines) | +| `contributing/samples/coding_agent/__init__.py` | ✅ Complete | Module init | + +### Unit Tests ✅ Complete + +| Test File | Status | Lines | +|-----------|--------|-------| +| `tests/unittests/agents/test_coding_agent.py` | ✅ Complete | ~310 | +| `tests/unittests/code_executors/test_allowlist_validator.py` | ✅ Complete | ~320 | +| `tests/unittests/code_executors/test_tool_code_generator.py` | ✅ Complete | ~320 | + +### Manual E2E Tests ✅ Passed + +| Test Scenario | Status | Notes | +|--------------|--------|-------| +| Basic math query ("What is 25 * 17?") | ✅ Passed | Returns 425 | +| Data analysis (Titanic survival rate) | ✅ Passed | Returns 38.38% | +| Visualization (bar chart by class) | ✅ Passed | Chart saved to host | +| Multi-step analysis | ✅ Passed | Stats + visualization + insights | +| Tool calling via HTTP IPC | ✅ Passed | fetch_url, save_chart work | +| Error handling (pip warnings) | ✅ Passed | Ignores non-fatal stderr | +| Chart saving to host system | ✅ Passed | Saved to /tmp/adk_charts/ | + +## Architecture + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ User Query │────▶│ CodingAgent │────▶│ Docker Container│ +│ │ │ (Gemini LLM) │ │ (Python 3.11) │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + │ │ + │ │ Executes + ▼ │ generated code + ┌──────────────┐ │ + │ Tool Server │◀────────────────┘ + │ (HTTP IPC) │ Tool calls via HTTP + └──────────────┘ +``` + +### How Tool IPC Works + +1. CodingAgent starts ToolExecutionServer on host (port 8765) +2. Code is generated with tool stubs that make HTTP POST requests +3. Container reaches host via `host.docker.internal` (macOS/Windows) or bridge gateway (Linux) +4. Tool server executes actual tool functions with proper context +5. Results returned to container via HTTP response + +## Key Design Decisions + +| Decision | Choice | Rationale | +|----------|--------|-----------| +| Container image | `python:3.11-slim` + runtime pip | Simpler for users, no custom Dockerfile | +| Tool communication | HTTP IPC | Works across container boundary, secure | +| Import validation | Allowlist-based | Security without blocking legitimate use | +| Chart saving | `save_chart` tool | Transfers data to host filesystem | +| Error handling | Distinguish warnings from errors | pip warnings shouldn't fail execution | + +## Sample Agent: Data Analyst + +### Tools Available + +| Tool | Description | +|------|-------------| +| `fetch_url(url)` | Fetch CSV/JSON/text from URLs | +| `get_sample_datasets()` | List available datasets (Titanic, Iris, Tips) | +| `get_current_time()` | Get current timestamp | +| `save_chart(image_data, filename)` | Save base64 chart to host | +| `list_saved_charts()` | List saved charts | + +### Example Queries + +1. "What is the survival rate on the Titanic?" +2. "Create a bar chart showing survival rate by passenger class" +3. "Analyze the iris dataset and create a scatter plot colored by species" +4. "Perform comprehensive analysis: stats, survival rates, visualization, insights" + +## Files Changed Summary + +``` + .github/CODING_AGENT_PLAN.md | Plan document + contributing/samples/coding_agent/README.md | 290 lines + contributing/samples/coding_agent/__init__.py | 17 lines + contributing/samples/coding_agent/agent.py | 360 lines + src/google/adk/agents/__init__.py | +2 exports + src/google/adk/agents/coding_agent.py | 610 lines + src/google/adk/agents/coding_agent_config.py | 225 lines + src/google/adk/code_executors/__init__.py | +6 exports + src/google/adk/code_executors/allowlist_validator.py | 355 lines + src/google/adk/code_executors/coding_agent_code_executor.py | 505 lines + src/google/adk/code_executors/tool_code_generator.py | 475 lines + src/google/adk/code_executors/tool_execution_server.py | 365 lines + tests/unittests/agents/test_coding_agent.py | 310 lines + tests/unittests/code_executors/test_allowlist_validator.py | 320 lines + tests/unittests/code_executors/test_tool_code_generator.py | 320 lines +``` + +**Total: ~4,200 lines of new code** + +## PR Checklist + +- [x] Implementation complete +- [x] Unit tests written and passing +- [x] Manual E2E tests passing +- [x] Sample agent created with README +- [x] Code follows ADK style guide (relative imports, `from __future__ import annotations`) +- [x] Marked as `@experimental` +- [ ] Run `./autoformat.sh` before PR +- [ ] Run full test suite: `pytest tests/unittests` +- [ ] Create GitHub issue (see `.github/CODING_AGENT_ISSUE.md`) +- [ ] Submit PR with testing plan + +## Future Enhancements (Out of Scope) + +- Stateful execution (persist variables across turns) +- Custom container images with pre-installed packages +- VertexAI code execution integration +- Support for JavaScript/TypeScript +- Streaming output during execution diff --git a/.github/ISSUE_TEMPLATE/coding_agent_feature.md b/.github/ISSUE_TEMPLATE/coding_agent_feature.md new file mode 100644 index 0000000000..f2b433a399 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/coding_agent_feature.md @@ -0,0 +1,226 @@ +--- +name: "Feature: CodingAgent - Code-generating Agent with Sandboxed Execution" +about: "New experimental agent type that generates and executes Python code" +title: "feat(agents): Add CodingAgent for code generation and sandboxed execution" +labels: "enhancement, agents, new-feature" +assignees: '' +--- + +## Summary + +Add a new experimental agent type called **CodingAgent** that generates Python code to solve tasks, executes it in a sandboxed Docker container, and iterates using a ReAct-style loop. This mirrors the popular "Code Interpreter" pattern seen in other AI platforms. + +## Is your feature request related to a problem? + +Currently, ADK agents can only interact with the world through pre-defined tools. While powerful, this approach has limitations: + +1. **Limited flexibility**: Users must anticipate all possible operations and create tools for each +2. **No computational capability**: Agents cannot perform complex calculations, data analysis, or create visualizations without custom tools +3. **No iteration**: Standard tool-calling doesn't easily support multi-step reasoning with intermediate computations +4. **Competitive gap**: Other platforms (OpenAI Code Interpreter, Anthropic's computer use) offer code execution capabilities + +**User pain points:** +- "I want my agent to analyze a CSV file and create a chart" - requires building custom tools +- "I need multi-step calculations with intermediate results" - awkward with standard tools +- "I want the agent to figure out HOW to solve a problem, not just call predefined functions" + +## Describe the solution + +### CodingAgent Overview + +A new agent type that: +1. Receives a task from the user +2. Generates Python code to accomplish the task +3. Executes the code in a sandboxed Docker container +4. Processes results and either provides an answer or continues iterating +5. Can call ADK tools from within generated code via HTTP IPC + +### Architecture + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ User Query │────▶│ CodingAgent │────▶│ Docker Container│ +│ │ │ (Gemini LLM) │ │ (Python 3.11) │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + │ │ + │ │ Executes + ▼ │ generated code + ┌──────────────┐ │ + │ Tool Server │◀────────────────┘ + │ (HTTP IPC) │ Tool calls via HTTP + └──────────────┘ +``` + +### Key Components + +| Component | File | Description | +|-----------|------|-------------| +| CodingAgent | `src/google/adk/agents/coding_agent.py` | Main agent class with ReAct loop | +| CodingAgentConfig | `src/google/adk/agents/coding_agent_config.py` | Pydantic configuration | +| CodingAgentCodeExecutor | `src/google/adk/code_executors/coding_agent_code_executor.py` | Wrapper that injects tools | +| ToolCodeGenerator | `src/google/adk/code_executors/tool_code_generator.py` | Generates Python stubs for tools | +| ToolExecutionServer | `src/google/adk/code_executors/tool_execution_server.py` | HTTP server for tool IPC | +| AllowlistValidator | `src/google/adk/code_executors/allowlist_validator.py` | Import security validation | + +### API Design + +```python +from google.adk.agents import CodingAgent +from google.adk.code_executors import ContainerCodeExecutor + +def fetch_data(url: str) -> dict: + """Fetch data from a URL.""" + # Implementation... + +root_agent = CodingAgent( + name="data_analyst", + model="gemini-2.5-flash", + instruction="You are a data analyst. Analyze data and provide insights.", + tools=[fetch_data], # Tools available to generated code + code_executor=ContainerCodeExecutor(image="python:3.11-slim"), + authorized_imports=DEFAULT_SAFE_IMPORTS | {"pandas", "matplotlib"}, + max_iterations=10, + error_retry_attempts=2, +) +``` + +### Security Features + +1. **Sandboxed execution**: All code runs in isolated Docker containers +2. **Import allowlisting**: Only authorized imports are permitted (configurable) +3. **Tool isolation**: Tools execute on host via HTTP, not in container +4. **No filesystem access**: Container has no access to host filesystem +5. **Network isolation**: Container can only reach tool server + +### Sample Agent + +A complete Data Analysis Agent sample is included: +- Fetches datasets from URLs (Titanic, Iris, Tips) +- Analyzes data with pandas +- Creates visualizations with matplotlib +- Saves charts to host system via `save_chart` tool + +## Describe alternatives you've considered + +### Alternative 1: Extend LlmAgent with code execution +- **Pros**: Simpler architecture, reuses existing agent +- **Cons**: Conflates two distinct patterns, harder to configure + +### Alternative 2: Code execution as a tool only +- **Pros**: Minimal changes, fits existing model +- **Cons**: No ReAct loop, no iteration, limited capability + +### Alternative 3: Use external code execution service +- **Pros**: Offloads security concerns +- **Cons**: Adds external dependency, latency, cost + +**Chosen approach**: Dedicated CodingAgent provides cleanest separation of concerns, explicit configuration, and full control over the execution environment. + +## Implementation Details + +### Files Added/Modified + +**New files (agents):** +- `src/google/adk/agents/coding_agent.py` (~550 lines) +- `src/google/adk/agents/coding_agent_config.py` (~225 lines) + +**New files (code_executors):** +- `src/google/adk/code_executors/coding_agent_code_executor.py` (~500 lines) +- `src/google/adk/code_executors/tool_code_generator.py` (~475 lines) +- `src/google/adk/code_executors/tool_execution_server.py` (~365 lines) +- `src/google/adk/code_executors/allowlist_validator.py` (~350 lines) + +**Modified files:** +- `src/google/adk/agents/__init__.py` - Export CodingAgent +- `src/google/adk/code_executors/__init__.py` - Export new components + +**Sample agent:** +- `contributing/samples/coding_agent/agent.py` (~360 lines) +- `contributing/samples/coding_agent/README.md` (~290 lines) + +**Tests:** +- `tests/unittests/agents/test_coding_agent.py` (~310 lines) +- `tests/unittests/code_executors/test_allowlist_validator.py` (~320 lines) +- `tests/unittests/code_executors/test_tool_code_generator.py` (~320 lines) + +### How Tool IPC Works + +1. When CodingAgent starts, it launches a ToolExecutionServer on the host +2. Generated code includes tool stubs that make HTTP POST requests +3. Tool server receives requests, executes actual tool functions +4. Results are returned to container via HTTP response +5. On macOS/Windows: uses `host.docker.internal` +6. On Linux: uses Docker bridge network gateway + +### Experimental Status + +This feature is marked as **@experimental** because: +- API may change based on user feedback +- Security model is being refined +- Performance optimizations are ongoing + +## Testing Plan + +### Unit Tests + +```bash +pytest tests/unittests/agents/test_coding_agent.py -v +pytest tests/unittests/code_executors/test_allowlist_validator.py -v +pytest tests/unittests/code_executors/test_tool_code_generator.py -v +``` + +### Manual E2E Tests + +**Test 1: Basic Query** +``` +Query: "What is 25 times 17?" +Expected: Agent generates code, calculates, returns "425" +``` + +**Test 2: Data Analysis** +``` +Query: "What is the survival rate on the Titanic?" +Expected: Agent fetches data, analyzes with pandas, returns "38.38%" +``` + +**Test 3: Visualization** +``` +Query: "Create a bar chart of Titanic survival by class" +Expected: Agent creates chart, saves to /tmp/adk_charts/, reports filepath +``` + +**Test 4: Multi-step Analysis** +``` +Query: "Analyze Titanic data: show stats, survival by sex/class, and provide 3 insights" +Expected: Agent performs multiple steps, creates visualization, provides comprehensive answer +``` + +## Additional Context + +### Related Work +- OpenAI Code Interpreter +- Anthropic Computer Use +- Google AI Studio code execution + +### Dependencies +- Docker (required for sandboxed execution) +- ContainerCodeExecutor (existing ADK component) + +### Future Enhancements +- [ ] Support for stateful execution (persist variables across turns) +- [ ] Custom container images with pre-installed packages +- [ ] Integration with VertexAI code execution +- [ ] Support for additional languages (JavaScript, etc.) + +### Screenshots + +**Data Analysis Agent in action:** +``` +User: Create a bar chart showing survival rate by passenger class from Titanic + +Agent: [Installs packages, fetches data, creates visualization] +Response: The bar chart has been saved to /tmp/adk_charts/survival_by_class.png +- 1st Class: 63% survival rate +- 2nd Class: 47% survival rate +- 3rd Class: 24% survival rate +``` diff --git a/contributing/samples/coding_agent/agent.py b/contributing/samples/coding_agent/agent.py index 1c234864b9..4596d95318 100644 --- a/contributing/samples/coding_agent/agent.py +++ b/contributing/samples/coding_agent/agent.py @@ -36,121 +36,140 @@ import base64 import binascii +from datetime import datetime import os import urllib.error import urllib.request -from datetime import datetime from google.adk.agents import CodingAgent from google.adk.code_executors import ContainerCodeExecutor from google.adk.code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS - # Sample dataset URLs SAMPLE_DATASETS = { "titanic": { - "url": "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv", - "description": "Titanic passenger data with survival information. 891 passengers with features like age, sex, class, fare, and survival status.", - "columns": "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked", + "url": ( + "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv" + ), + "description": ( + "Titanic passenger data with survival information. 891 passengers" + " with features like age, sex, class, fare, and survival status." + ), + "columns": ( + "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch," + " Ticket, Fare, Cabin, Embarked" + ), }, "iris": { - "url": "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv", - "description": "Iris flower dataset. 150 samples of 3 species with sepal and petal measurements.", - "columns": "sepal_length, sepal_width, petal_length, petal_width, species", + "url": ( + "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv" + ), + "description": ( + "Iris flower dataset. 150 samples of 3 species with sepal and petal" + " measurements." + ), + "columns": ( + "sepal_length, sepal_width, petal_length, petal_width, species" + ), }, "tips": { - "url": "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/tips.csv", - "description": "Restaurant tips dataset. 244 records with bill amount, tip, and customer info.", + "url": ( + "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/tips.csv" + ), + "description": ( + "Restaurant tips dataset. 244 records with bill amount, tip, and" + " customer info." + ), "columns": "total_bill, tip, sex, smoker, day, time, size", }, } def fetch_url(url: str) -> dict: - """Fetch content from a URL. - - Fetches data from the specified URL and returns the content along with - metadata. Supports CSV, JSON, and plain text content. - - Args: - url: The URL to fetch content from. - - Returns: - Dictionary containing: - - content: The fetched content as a string - - content_type: The MIME type of the content - - size: Size of the content in bytes - - url: The original URL - - success: Whether the fetch was successful - - error: Error message if fetch failed (only present on failure) - """ - try: - req = urllib.request.Request( - url, - headers={"User-Agent": "Mozilla/5.0 (compatible; ADK-DataAnalyst/1.0)"}, - ) - with urllib.request.urlopen(req, timeout=30) as response: - content = response.read().decode("utf-8") - content_type = response.headers.get("Content-Type", "text/plain") - return { - "content": content, - "content_type": content_type, - "size": len(content), - "url": url, - "success": True, - } - except urllib.error.URLError as e: - return { - "content": "", - "url": url, - "success": False, - "error": f"Failed to fetch URL: {str(e)}", - } - except Exception as e: - return { - "content": "", - "url": url, - "success": False, - "error": f"Unexpected error: {str(e)}", - } + """Fetch content from a URL. + + Fetches data from the specified URL and returns the content along with + metadata. Supports CSV, JSON, and plain text content. + + Args: + url: The URL to fetch content from. + + Returns: + Dictionary containing: + - content: The fetched content as a string + - content_type: The MIME type of the content + - size: Size of the content in bytes + - url: The original URL + - success: Whether the fetch was successful + - error: Error message if fetch failed (only present on failure) + """ + try: + req = urllib.request.Request( + url, + headers={"User-Agent": "Mozilla/5.0 (compatible; ADK-DataAnalyst/1.0)"}, + ) + with urllib.request.urlopen(req, timeout=30) as response: + content = response.read().decode("utf-8") + content_type = response.headers.get("Content-Type", "text/plain") + return { + "content": content, + "content_type": content_type, + "size": len(content), + "url": url, + "success": True, + } + except urllib.error.URLError as e: + return { + "content": "", + "url": url, + "success": False, + "error": f"Failed to fetch URL: {str(e)}", + } + except Exception as e: + return { + "content": "", + "url": url, + "success": False, + "error": f"Unexpected error: {str(e)}", + } def get_sample_datasets() -> dict: - """Get available sample datasets with their URLs and descriptions. + """Get available sample datasets with their URLs and descriptions. - Returns a dictionary of sample datasets that can be used for analysis. - Each dataset includes a URL, description, and column information. + Returns a dictionary of sample datasets that can be used for analysis. + Each dataset includes a URL, description, and column information. - Returns: - Dictionary with dataset names as keys, each containing: - - url: Direct URL to download the CSV file - - description: Brief description of the dataset - - columns: Comma-separated list of column names - """ - return SAMPLE_DATASETS + Returns: + Dictionary with dataset names as keys, each containing: + - url: Direct URL to download the CSV file + - description: Brief description of the dataset + - columns: Comma-separated list of column names + """ + return SAMPLE_DATASETS def get_current_time() -> dict: - """Get the current date and time. - - Returns: - Dictionary containing: - - timestamp: ISO format timestamp - - year, month, day: Date components - - hour, minute, second: Time components - - weekday: Name of the day of the week - """ - now = datetime.now() - return { - "timestamp": now.isoformat(), - "year": now.year, - "month": now.month, - "day": now.day, - "hour": now.hour, - "minute": now.minute, - "second": now.second, - "weekday": now.strftime("%A"), - } + """Get the current date and time. + + Returns: + Dictionary containing: + - timestamp: ISO format timestamp + - year, month, day: Date components + - hour, minute, second: Time components + - weekday: Name of the day of the week + """ + now = datetime.now() + return { + "timestamp": now.isoformat(), + "year": now.year, + "month": now.month, + "day": now.day, + "hour": now.hour, + "minute": now.minute, + "second": now.second, + "weekday": now.strftime("%A"), + } # Directory on host system to save charts @@ -158,143 +177,141 @@ def get_current_time() -> dict: def save_chart(image_data: str, filename: str) -> dict: - """Save a chart image to the host system. - - This tool saves base64-encoded image data to the host machine's filesystem, - making charts accessible outside the Docker container. - - To use this tool, first save your matplotlib figure to a bytes buffer, - then encode it as base64: - - Example: - import base64 - import io - import matplotlib.pyplot as plt - - # Create your plot - plt.figure() - plt.plot([1, 2, 3], [1, 4, 9]) - - # Save to buffer and encode - buf = io.BytesIO() - plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') - buf.seek(0) - image_data = base64.b64encode(buf.read()).decode('utf-8') - plt.close() - - # Save to host system - result = save_chart(image_data=image_data, filename="my_chart.png") - - Args: - image_data: Base64-encoded image data (PNG format recommended). - filename: Name for the saved file (e.g., "chart.png"). - - Returns: - Dictionary containing: - - success: Whether the save was successful - - filepath: Full path where the file was saved on the host - - size: Size of the saved file in bytes - - error: Error message if save failed (only present on failure) - """ - try: - # Ensure the output directory exists - os.makedirs(HOST_CHARTS_DIR, exist_ok=True) - - # Sanitize filename - safe_filename = os.path.basename(filename) - if not safe_filename: - safe_filename = "chart.png" - - filepath = os.path.join(HOST_CHARTS_DIR, safe_filename) - - # Decode and save - image_bytes = base64.b64decode(image_data) - with open(filepath, "wb") as f: - f.write(image_bytes) - - return { - "success": True, - "filepath": filepath, - "size": len(image_bytes), - "message": f"Chart saved to {filepath}", - } - except binascii.Error as e: - return { - "success": False, - "error": f"Invalid base64 data: {str(e)}", - } - except OSError as e: - return { - "success": False, - "error": f"Failed to save file: {str(e)}", - } - except Exception as e: - return { - "success": False, - "error": f"Unexpected error: {str(e)}", - } + """Save a chart image to the host system. + + This tool saves base64-encoded image data to the host machine's filesystem, + making charts accessible outside the Docker container. + + To use this tool, first save your matplotlib figure to a bytes buffer, + then encode it as base64: + + Example: + import base64 + import io + import matplotlib.pyplot as plt + + # Create your plot + plt.figure() + plt.plot([1, 2, 3], [1, 4, 9]) + + # Save to buffer and encode + buf = io.BytesIO() + plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') + buf.seek(0) + image_data = base64.b64encode(buf.read()).decode('utf-8') + plt.close() + + # Save to host system + result = save_chart(image_data=image_data, filename="my_chart.png") + + Args: + image_data: Base64-encoded image data (PNG format recommended). + filename: Name for the saved file (e.g., "chart.png"). + + Returns: + Dictionary containing: + - success: Whether the save was successful + - filepath: Full path where the file was saved on the host + - size: Size of the saved file in bytes + - error: Error message if save failed (only present on failure) + """ + try: + # Ensure the output directory exists + os.makedirs(HOST_CHARTS_DIR, exist_ok=True) + + # Sanitize filename + safe_filename = os.path.basename(filename) + if not safe_filename: + safe_filename = "chart.png" + + filepath = os.path.join(HOST_CHARTS_DIR, safe_filename) + + # Decode and save + image_bytes = base64.b64decode(image_data) + with open(filepath, "wb") as f: + f.write(image_bytes) + + return { + "success": True, + "filepath": filepath, + "size": len(image_bytes), + "message": f"Chart saved to {filepath}", + } + except binascii.Error as e: + return { + "success": False, + "error": f"Invalid base64 data: {str(e)}", + } + except OSError as e: + return { + "success": False, + "error": f"Failed to save file: {str(e)}", + } + except Exception as e: + return { + "success": False, + "error": f"Unexpected error: {str(e)}", + } def list_saved_charts() -> dict: - """List all charts saved on the host system. - - Returns: - Dictionary containing: - - success: Whether the operation was successful - - charts: List of saved chart filenames - - directory: The directory where charts are saved - - count: Number of charts found - """ - try: - if not os.path.exists(HOST_CHARTS_DIR): - return { - "success": True, - "charts": [], - "directory": HOST_CHARTS_DIR, - "count": 0, - } - - charts = [ - f - for f in os.listdir(HOST_CHARTS_DIR) - if f.lower().endswith((".png", ".jpg", ".jpeg", ".svg", ".pdf")) - ] - return { - "success": True, - "charts": charts, - "directory": HOST_CHARTS_DIR, - "count": len(charts), - } - except Exception as e: - return { - "success": False, - "error": f"Failed to list charts: {str(e)}", - } + """List all charts saved on the host system. + + Returns: + Dictionary containing: + - success: Whether the operation was successful + - charts: List of saved chart filenames + - directory: The directory where charts are saved + - count: Number of charts found + """ + try: + if not os.path.exists(HOST_CHARTS_DIR): + return { + "success": True, + "charts": [], + "directory": HOST_CHARTS_DIR, + "count": 0, + } + + charts = [ + f + for f in os.listdir(HOST_CHARTS_DIR) + if f.lower().endswith((".png", ".jpg", ".jpeg", ".svg", ".pdf")) + ] + return { + "success": True, + "charts": charts, + "directory": HOST_CHARTS_DIR, + "count": len(charts), + } + except Exception as e: + return { + "success": False, + "error": f"Failed to list charts: {str(e)}", + } # Additional imports allowed for data analysis -DATA_ANALYSIS_IMPORTS = frozenset( - { - # Data analysis - "pandas", - "pandas.*", - "numpy", - "numpy.*", - # Visualization - "matplotlib", - "matplotlib.*", - "seaborn", - "seaborn.*", - # Data I/O - "csv", - "io", - "io.*", - # Encoding for chart saving - "base64", - # Subprocess for pip installs - "subprocess", - } -) +DATA_ANALYSIS_IMPORTS = frozenset({ + # Data analysis + "pandas", + "pandas.*", + "numpy", + "numpy.*", + # Visualization + "matplotlib", + "matplotlib.*", + "seaborn", + "seaborn.*", + # Data I/O + "csv", + "io", + "io.*", + # Encoding for chart saving + "base64", + # Subprocess for pip installs + "subprocess", +}) # Create the Data Analysis Agent diff --git a/src/google/adk/agents/coding_agent.py b/src/google/adk/agents/coding_agent.py index 42128dacd8..850fbe505b 100644 --- a/src/google/adk/agents/coding_agent.py +++ b/src/google/adk/agents/coding_agent.py @@ -69,17 +69,17 @@ @experimental class CodingAgentState(BaseAgentState): - """State for CodingAgent tracking execution progress. + """State for CodingAgent tracking execution progress. - Attributes: - iteration_count: Number of ReAct loop iterations completed. - error_count: Number of consecutive errors encountered. - execution_history: List of execution steps with code, results, and traces. - """ + Attributes: + iteration_count: Number of ReAct loop iterations completed. + error_count: Number of consecutive errors encountered. + execution_history: List of execution steps with code, results, and traces. + """ - iteration_count: int = 0 - error_count: int = 0 - execution_history: List[Dict[str, Any]] = Field(default_factory=list) + iteration_count: int = 0 + error_count: int = 0 + execution_history: List[Dict[str, Any]] = Field(default_factory=list) ToolUnion = Union[Callable[..., Any], BaseTool, BaseToolset] @@ -89,296 +89,297 @@ async def _convert_tool_union_to_tools( tool_union: ToolUnion, ctx: Optional[ReadonlyContext] = None, ) -> List[BaseTool]: - """Convert a tool union to a list of BaseTool instances. + """Convert a tool union to a list of BaseTool instances. + + Args: + tool_union: A callable, BaseTool, or BaseToolset. + ctx: Optional context for toolset resolution. + + Returns: + List of BaseTool instances. + """ + if isinstance(tool_union, BaseTool): + return [tool_union] + if callable(tool_union): + return [FunctionTool(func=tool_union)] + # BaseToolset + if ctx: + return await tool_union.get_tools_with_prefix(ctx) + return await tool_union.get_tools_with_prefix(None) + + +@experimental +class CodingAgent(BaseAgent): + """Agent that generates Python code to solve tasks using available tools. + + CodingAgent implements a ReAct-style loop where it: + 1. Receives a task from the user + 2. Generates Python code that calls available tools + 3. Executes the code in a sandboxed environment + 4. Processes the results and either provides an answer or continues + + Tools are made available as Python functions that the generated code + can call. The code execution happens in a container for security, + with tool calls routed via HTTP to the host. + + Attributes: + model: The LLM model to use for code generation. + instruction: Additional instructions for the agent. + tools: List of tools available to the agent. + code_executor: The underlying code executor (e.g., ContainerCodeExecutor). + authorized_imports: Set of allowed Python imports. + max_iterations: Maximum ReAct loop iterations. + error_retry_attempts: Number of retries on execution errors. + stateful: Whether to maintain state across iterations. + tool_server_host: Host for the tool execution server. + tool_server_port: Port for the tool execution server. + """ + + DEFAULT_MODEL: ClassVar[str] = "gemini-2.5-flash" + + config_type: ClassVar[Type[BaseAgentConfig]] = CodingAgentConfig + + model: Union[str, BaseLlm] = "" + """The model to use for code generation.""" + + instruction: str = "" + """Additional instructions for the agent.""" + + tools: List[ToolUnion] = Field(default_factory=list) + """Tools available to the agent.""" + + code_executor: Optional[BaseCodeExecutor] = None + """The underlying code executor. If not set, uses ContainerCodeExecutor.""" + + authorized_imports: FrozenSet[str] = DEFAULT_SAFE_IMPORTS + """Set of allowed import patterns.""" + + max_iterations: int = 10 + """Maximum number of ReAct loop iterations.""" + + error_retry_attempts: int = 2 + """Number of retries on execution errors.""" + + stateful: bool = False + """Whether to maintain state across iterations.""" + + tool_server_host: Optional[str] = None + """Host for the tool execution server.""" + + tool_server_port: int = 8765 + """Port for the tool execution server.""" + + # Internal state + _coding_executor: Optional[CodingAgentCodeExecutor] = None + _resolved_tools: Optional[List[BaseTool]] = None + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + @property + def canonical_model(self) -> BaseLlm: + """Get the resolved model as BaseLlm.""" + if isinstance(self.model, BaseLlm): + return self.model + elif self.model: + return LLMRegistry.new_llm(self.model) + else: + # Find model from ancestors + ancestor_agent = self.parent_agent + while ancestor_agent is not None: + if hasattr(ancestor_agent, "canonical_model"): + return ancestor_agent.canonical_model + ancestor_agent = ancestor_agent.parent_agent + return LLMRegistry.new_llm(self.DEFAULT_MODEL) + + async def _resolve_tools( + self, + ctx: Optional[ReadonlyContext] = None, + ) -> List[BaseTool]: + """Resolve tool unions to BaseTool instances. Args: - tool_union: A callable, BaseTool, or BaseToolset. ctx: Optional context for toolset resolution. Returns: - List of BaseTool instances. + List of resolved BaseTool instances. """ - if isinstance(tool_union, BaseTool): - return [tool_union] - if callable(tool_union): - return [FunctionTool(func=tool_union)] - # BaseToolset - if ctx: - return await tool_union.get_tools_with_prefix(ctx) - return await tool_union.get_tools_with_prefix(None) + if self._resolved_tools is not None: + return self._resolved_tools + resolved = [] + for tool_union in self.tools: + resolved.extend(await _convert_tool_union_to_tools(tool_union, ctx)) -@experimental -class CodingAgent(BaseAgent): - """Agent that generates Python code to solve tasks using available tools. - - CodingAgent implements a ReAct-style loop where it: - 1. Receives a task from the user - 2. Generates Python code that calls available tools - 3. Executes the code in a sandboxed environment - 4. Processes the results and either provides an answer or continues - - Tools are made available as Python functions that the generated code - can call. The code execution happens in a container for security, - with tool calls routed via HTTP to the host. - - Attributes: - model: The LLM model to use for code generation. - instruction: Additional instructions for the agent. - tools: List of tools available to the agent. - code_executor: The underlying code executor (e.g., ContainerCodeExecutor). - authorized_imports: Set of allowed Python imports. - max_iterations: Maximum ReAct loop iterations. - error_retry_attempts: Number of retries on execution errors. - stateful: Whether to maintain state across iterations. - tool_server_host: Host for the tool execution server. - tool_server_port: Port for the tool execution server. + self._resolved_tools = resolved + return resolved + + async def _get_coding_executor( + self, + ctx: InvocationContext, + ) -> CodingAgentCodeExecutor: + """Get or create the CodingAgentCodeExecutor. + + Args: + ctx: The invocation context. + + Returns: + The configured code executor. """ + if self._coding_executor is not None: + return self._coding_executor + + # Resolve tools + tools = await self._resolve_tools(ReadonlyContext(ctx)) + + # Get or create underlying executor + if self.code_executor: + underlying = self.code_executor + else: + # Default to ContainerCodeExecutor + try: + from ..code_executors.container_code_executor import ContainerCodeExecutor + + underlying = ContainerCodeExecutor( + image="python:3.11-slim", + ) + except ImportError as e: + raise ImportError( + "CodingAgent requires ContainerCodeExecutor. " + 'Please install with: pip install "google-adk[extensions]" ' + "or provide a custom code_executor." + ) from e + + # Create the CodingAgentCodeExecutor wrapper + self._coding_executor = CodingAgentCodeExecutor( + underlying_executor=underlying, + tools=tools, + authorized_imports=self.authorized_imports, + tool_server_host=self.tool_server_host, + tool_server_port=self.tool_server_port, + stateful=self.stateful, + error_retry_attempts=self.error_retry_attempts, + ) + + return self._coding_executor + + def _build_system_prompt(self, tools: List[BaseTool]) -> str: + """Build the system prompt with tool documentation. - DEFAULT_MODEL: ClassVar[str] = "gemini-2.5-flash" + Args: + tools: List of available tools. - config_type: ClassVar[Type[BaseAgentConfig]] = CodingAgentConfig + Returns: + The complete system prompt. + """ + return generate_system_prompt( + tools=tools, + custom_instruction=self.instruction, + ) - model: Union[str, BaseLlm] = "" - """The model to use for code generation.""" - - instruction: str = "" - """Additional instructions for the agent.""" + def _extract_code_block(self, response_text: str) -> Optional[str]: + """Extract code from the model response. - tools: List[ToolUnion] = Field(default_factory=list) - """Tools available to the agent.""" - - code_executor: Optional[BaseCodeExecutor] = None - """The underlying code executor. If not set, uses ContainerCodeExecutor.""" - - authorized_imports: FrozenSet[str] = DEFAULT_SAFE_IMPORTS - """Set of allowed import patterns.""" - - max_iterations: int = 10 - """Maximum number of ReAct loop iterations.""" - - error_retry_attempts: int = 2 - """Number of retries on execution errors.""" - - stateful: bool = False - """Whether to maintain state across iterations.""" - - tool_server_host: Optional[str] = None - """Host for the tool execution server.""" - - tool_server_port: int = 8765 - """Port for the tool execution server.""" - - # Internal state - _coding_executor: Optional[CodingAgentCodeExecutor] = None - _resolved_tools: Optional[List[BaseTool]] = None - - class Config: - """Pydantic config.""" - - arbitrary_types_allowed = True - - @property - def canonical_model(self) -> BaseLlm: - """Get the resolved model as BaseLlm.""" - if isinstance(self.model, BaseLlm): - return self.model - elif self.model: - return LLMRegistry.new_llm(self.model) - else: - # Find model from ancestors - ancestor_agent = self.parent_agent - while ancestor_agent is not None: - if hasattr(ancestor_agent, "canonical_model"): - return ancestor_agent.canonical_model - ancestor_agent = ancestor_agent.parent_agent - return LLMRegistry.new_llm(self.DEFAULT_MODEL) - - async def _resolve_tools( - self, - ctx: Optional[ReadonlyContext] = None, - ) -> List[BaseTool]: - """Resolve tool unions to BaseTool instances. - - Args: - ctx: Optional context for toolset resolution. - - Returns: - List of resolved BaseTool instances. - """ - if self._resolved_tools is not None: - return self._resolved_tools - - resolved = [] - for tool_union in self.tools: - resolved.extend(await _convert_tool_union_to_tools(tool_union, ctx)) - - self._resolved_tools = resolved - return resolved - - async def _get_coding_executor( - self, - ctx: InvocationContext, - ) -> CodingAgentCodeExecutor: - """Get or create the CodingAgentCodeExecutor. - - Args: - ctx: The invocation context. - - Returns: - The configured code executor. - """ - if self._coding_executor is not None: - return self._coding_executor - - # Resolve tools - tools = await self._resolve_tools(ReadonlyContext(ctx)) - - # Get or create underlying executor - if self.code_executor: - underlying = self.code_executor - else: - # Default to ContainerCodeExecutor - try: - from ..code_executors.container_code_executor import ( - ContainerCodeExecutor, - ) - - underlying = ContainerCodeExecutor( - image="python:3.11-slim", - ) - except ImportError as e: - raise ImportError( - "CodingAgent requires ContainerCodeExecutor. " - 'Please install with: pip install "google-adk[extensions]" ' - "or provide a custom code_executor." - ) from e - - # Create the CodingAgentCodeExecutor wrapper - self._coding_executor = CodingAgentCodeExecutor( - underlying_executor=underlying, - tools=tools, - authorized_imports=self.authorized_imports, - tool_server_host=self.tool_server_host, - tool_server_port=self.tool_server_port, - stateful=self.stateful, - error_retry_attempts=self.error_retry_attempts, - ) + Args: + response_text: The model's response text. + + Returns: + The extracted code, or None if no code block found. + """ + # Try tool_code blocks first + pattern = r"```tool_code\n(.*?)```" + match = re.search(pattern, response_text, re.DOTALL) + if match: + return match.group(1).strip() - return self._coding_executor + # Fall back to python blocks + pattern = r"```python\n(.*?)```" + match = re.search(pattern, response_text, re.DOTALL) + if match: + return match.group(1).strip() - def _build_system_prompt(self, tools: List[BaseTool]) -> str: - """Build the system prompt with tool documentation. + return None - Args: - tools: List of available tools. + def _is_real_error(self, stderr: str) -> bool: + """Check if stderr contains a real error vs just warnings. - Returns: - The complete system prompt. - """ - return generate_system_prompt( - tools=tools, - custom_instruction=self.instruction, - ) + Args: + stderr: The stderr output from code execution. - def _extract_code_block(self, response_text: str) -> Optional[str]: - """Extract code from the model response. - - Args: - response_text: The model's response text. - - Returns: - The extracted code, or None if no code block found. - """ - # Try tool_code blocks first - pattern = r"```tool_code\n(.*?)```" - match = re.search(pattern, response_text, re.DOTALL) - if match: - return match.group(1).strip() - - # Fall back to python blocks - pattern = r"```python\n(.*?)```" - match = re.search(pattern, response_text, re.DOTALL) - if match: - return match.group(1).strip() - - return None - - def _is_real_error(self, stderr: str) -> bool: - """Check if stderr contains a real error vs just warnings. - - Args: - stderr: The stderr output from code execution. - - Returns: - True if stderr contains a real error, False if just warnings. - """ - if not stderr: - return False - - # Patterns that indicate this is just a warning, not an error - warning_patterns = [ - "WARNING: Running pip as the 'root' user", - "[notice] A new release of pip", - "[notice] To update, run:", - "pip install --upgrade pip", - "UserWarning:", - "DeprecationWarning:", - "FutureWarning:", - "RuntimeWarning:", - ] - - # Check if ALL lines are just warnings - lines = stderr.strip().split("\n") - real_error_lines = [] - for line in lines: - line_stripped = line.strip() - if not line_stripped: - continue - is_warning = any( - pattern.lower() in line_stripped.lower() for pattern in warning_patterns - ) - if not is_warning: - real_error_lines.append(line) - - # Also check for actual error keywords - error_keywords = [ - "error:", - "traceback", - "exception", - "syntaxerror", - "nameerror", - "typeerror", - "valueerror", - "importerror", - "modulenotfounderror", - "attributeerror", - "keyerror", - "indexerror", - "zerodivisionerror", - ] - - stderr_lower = stderr.lower() - has_error_keyword = any(keyword in stderr_lower for keyword in error_keywords) - - # Consider it a real error if there are non-warning lines with error keywords - return bool(real_error_lines) and has_error_keyword - - def _build_error_feedback( - self, - error: str, - code: str, - ) -> str: - """Build feedback message for execution errors. - - Args: - error: The error message. - code: The code that caused the error. - - Returns: - Formatted error feedback for the LLM. - """ - return f"""The code execution failed with the following error: + Returns: + True if stderr contains a real error, False if just warnings. + """ + if not stderr: + return False + + # Patterns that indicate this is just a warning, not an error + warning_patterns = [ + "WARNING: Running pip as the 'root' user", + "[notice] A new release of pip", + "[notice] To update, run:", + "pip install --upgrade pip", + "UserWarning:", + "DeprecationWarning:", + "FutureWarning:", + "RuntimeWarning:", + ] + + # Check if ALL lines are just warnings + lines = stderr.strip().split("\n") + real_error_lines = [] + for line in lines: + line_stripped = line.strip() + if not line_stripped: + continue + is_warning = any( + pattern.lower() in line_stripped.lower() + for pattern in warning_patterns + ) + if not is_warning: + real_error_lines.append(line) + + # Also check for actual error keywords + error_keywords = [ + "error:", + "traceback", + "exception", + "syntaxerror", + "nameerror", + "typeerror", + "valueerror", + "importerror", + "modulenotfounderror", + "attributeerror", + "keyerror", + "indexerror", + "zerodivisionerror", + ] + + stderr_lower = stderr.lower() + has_error_keyword = any( + keyword in stderr_lower for keyword in error_keywords + ) + + # Consider it a real error if there are non-warning lines with error keywords + return bool(real_error_lines) and has_error_keyword + + def _build_error_feedback( + self, + error: str, + code: str, + ) -> str: + """Build feedback message for execution errors. + + Args: + error: The error message. + code: The code that caused the error. + + Returns: + Formatted error feedback for the LLM. + """ + return f"""The code execution failed with the following error: ``` {error} @@ -395,217 +396,215 @@ def _build_error_feedback( - Python syntax errors """ - @override - async def _run_async_impl( - self, - ctx: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Core implementation of the ReAct loop. - - Args: - ctx: The invocation context. - - Yields: - Events generated during execution. - """ - # Load or initialize state - state = self._load_agent_state(ctx, CodingAgentState) - if state is None: - state = CodingAgentState() - - # Resolve tools and get executor - tools = await self._resolve_tools(ReadonlyContext(ctx)) - coding_executor = await self._get_coding_executor(ctx) - - # Create tool context for the executor - tool_context = ToolContext(invocation_context=ctx) - coding_executor.set_context(ctx, tool_context) - - # Build system prompt - system_prompt = self._build_system_prompt(tools) - - # Get the model - model = self.canonical_model - - # Build initial request with conversation history - contents = [] - events = ctx._get_events(current_invocation=True, current_branch=True) - for event in events: - if event.content: - contents.append(event.content) - - iteration = 0 - error_count = 0 - final_answer = None - - while iteration < self.max_iterations: - iteration += 1 - state.iteration_count = iteration - - # Build LLM request - llm_request = LlmRequest( - model=model.model, - contents=contents, - config=types.GenerateContentConfig( - system_instruction=system_prompt, - ), - ) + @override + async def _run_async_impl( + self, + ctx: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Core implementation of the ReAct loop. - # Call the model (generate_content_async returns an async generator) - llm_response = None - async for response in model.generate_content_async( - llm_request, stream=False - ): - llm_response = response - break - - # Extract response text - response_text = "" - if llm_response and llm_response.content and llm_response.content.parts: - response_text = "".join( - part.text for part in llm_response.content.parts if part.text - ) - - # Check for code block - code = self._extract_code_block(response_text) - - if not code: - # No code generated - treat as final response - # Check if the response looks like a final answer - final_answer = response_text - break - - # Execute the code - code_input = CodeExecutionInput(code=code) - exec_result = coding_executor.execute_code_extended( - invocation_context=ctx, - code_execution_input=code_input, - ) + Args: + ctx: The invocation context. - # Record execution in state - state.execution_history.append( - { - "iteration": iteration, - "code": code, - "stdout": exec_result.clean_stdout, - "stderr": exec_result.code_result.stderr, - "tool_traces": exec_result.tool_traces, - "has_final_answer": exec_result.has_final_answer, - } - ) + Yields: + Events generated during execution. + """ + # Load or initialize state + state = self._load_agent_state(ctx, CodingAgentState) + if state is None: + state = CodingAgentState() + + # Resolve tools and get executor + tools = await self._resolve_tools(ReadonlyContext(ctx)) + coding_executor = await self._get_coding_executor(ctx) + + # Create tool context for the executor + tool_context = ToolContext(invocation_context=ctx) + coding_executor.set_context(ctx, tool_context) + + # Build system prompt + system_prompt = self._build_system_prompt(tools) + + # Get the model + model = self.canonical_model + + # Build initial request with conversation history + contents = [] + events = ctx._get_events(current_invocation=True, current_branch=True) + for event in events: + if event.content: + contents.append(event.content) + + iteration = 0 + error_count = 0 + final_answer = None + + while iteration < self.max_iterations: + iteration += 1 + state.iteration_count = iteration + + # Build LLM request + llm_request = LlmRequest( + model=model.model, + contents=contents, + config=types.GenerateContentConfig( + system_instruction=system_prompt, + ), + ) + + # Call the model (generate_content_async returns an async generator) + llm_response = None + async for response in model.generate_content_async( + llm_request, stream=False + ): + llm_response = response + break + + # Extract response text + response_text = "" + if llm_response and llm_response.content and llm_response.content.parts: + response_text = "".join( + part.text for part in llm_response.content.parts if part.text + ) - # Check for errors - ignore warnings from pip and other non-fatal stderr - stderr = exec_result.code_result.stderr or "" - is_real_error = self._is_real_error(stderr) - - if is_real_error: - error_count += 1 - state.error_count = error_count - - if error_count > self.error_retry_attempts: - # Too many errors - give up - final_answer = ( - f"I encountered too many errors while executing code. " - f"Last error: {stderr}" - ) - break - - # Build error feedback and add to conversation - error_feedback = self._build_error_feedback( - stderr, - code, - ) - contents.append( - types.Content( - role="model", - parts=[types.Part(text=response_text)], - ) - ) - contents.append( - types.Content( - role="user", - parts=[types.Part(text=error_feedback)], - ) - ) - continue - - # Reset error count on success - error_count = 0 - state.error_count = 0 - - # Check for final answer - if exec_result.has_final_answer: - final_answer = exec_result.final_answer - break - - # Add execution result to conversation and continue - contents.append( - types.Content( - role="model", - parts=[types.Part(text=response_text)], - ) + # Check for code block + code = self._extract_code_block(response_text) + + if not code: + # No code generated - treat as final response + # Check if the response looks like a final answer + final_answer = response_text + break + + # Execute the code + code_input = CodeExecutionInput(code=code) + exec_result = coding_executor.execute_code_extended( + invocation_context=ctx, + code_execution_input=code_input, + ) + + # Record execution in state + state.execution_history.append({ + "iteration": iteration, + "code": code, + "stdout": exec_result.clean_stdout, + "stderr": exec_result.code_result.stderr, + "tool_traces": exec_result.tool_traces, + "has_final_answer": exec_result.has_final_answer, + }) + + # Check for errors - ignore warnings from pip and other non-fatal stderr + stderr = exec_result.code_result.stderr or "" + is_real_error = self._is_real_error(stderr) + + if is_real_error: + error_count += 1 + state.error_count = error_count + + if error_count > self.error_retry_attempts: + # Too many errors - give up + final_answer = ( + "I encountered too many errors while executing code. " + f"Last error: {stderr}" + ) + break + + # Build error feedback and add to conversation + error_feedback = self._build_error_feedback( + stderr, + code, + ) + contents.append( + types.Content( + role="model", + parts=[types.Part(text=response_text)], ) - - # Add execution output as user message - output_text = f"""Code execution result: + ) + contents.append( + types.Content( + role="user", + parts=[types.Part(text=error_feedback)], + ) + ) + continue + + # Reset error count on success + error_count = 0 + state.error_count = 0 + + # Check for final answer + if exec_result.has_final_answer: + final_answer = exec_result.final_answer + break + + # Add execution result to conversation and continue + contents.append( + types.Content( + role="model", + parts=[types.Part(text=response_text)], + ) + ) + + # Add execution output as user message + output_text = f"""Code execution result: ``` {exec_result.clean_stdout} ``` """ - contents.append( - types.Content( - role="user", - parts=[types.Part(text=output_text)], - ) - ) - - # Build final event - if final_answer is None: - final_answer = ( - "I was unable to complete the task within the allowed iterations." - ) - - # Convert final_answer to string if needed - if not isinstance(final_answer, str): - import json - - try: - final_answer = json.dumps(final_answer) - except (TypeError, ValueError): - final_answer = str(final_answer) - - # Update state in context - ctx.agent_states[self.name] = state.model_dump() - - # Yield final event - yield Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - content=types.Content( - role="model", - parts=[types.Part(text=final_answer)], - ), - actions=EventActions( - agent_state=state.model_dump(), - ), - ) - - @model_validator(mode="after") - def _validate_model(self) -> CodingAgent: - """Validate the model after construction.""" - return self - - def cleanup(self) -> None: - """Clean up resources.""" - if self._coding_executor: - self._coding_executor.cleanup() - self._coding_executor = None - self._resolved_tools = None - - def __del__(self): - """Destructor to clean up resources.""" - try: - self.cleanup() - except Exception: - pass + contents.append( + types.Content( + role="user", + parts=[types.Part(text=output_text)], + ) + ) + + # Build final event + if final_answer is None: + final_answer = ( + "I was unable to complete the task within the allowed iterations." + ) + + # Convert final_answer to string if needed + if not isinstance(final_answer, str): + import json + + try: + final_answer = json.dumps(final_answer) + except (TypeError, ValueError): + final_answer = str(final_answer) + + # Update state in context + ctx.agent_states[self.name] = state.model_dump() + + # Yield final event + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + content=types.Content( + role="model", + parts=[types.Part(text=final_answer)], + ), + actions=EventActions( + agent_state=state.model_dump(), + ), + ) + + @model_validator(mode="after") + def _validate_model(self) -> CodingAgent: + """Validate the model after construction.""" + return self + + def cleanup(self) -> None: + """Clean up resources.""" + if self._coding_executor: + self._coding_executor.cleanup() + self._coding_executor = None + self._resolved_tools = None + + def __del__(self): + """Destructor to clean up resources.""" + try: + self.cleanup() + except Exception: + pass diff --git a/src/google/adk/code_executors/allowlist_validator.py b/src/google/adk/code_executors/allowlist_validator.py index 68319bf16f..57afc0e534 100644 --- a/src/google/adk/code_executors/allowlist_validator.py +++ b/src/google/adk/code_executors/allowlist_validator.py @@ -17,10 +17,10 @@ from __future__ import annotations import ast -import fnmatch -import logging from dataclasses import dataclass from dataclasses import field +import fnmatch +import logging from typing import FrozenSet from typing import List from typing import Set @@ -29,326 +29,325 @@ # Default set of safe imports that are always allowed -DEFAULT_SAFE_IMPORTS: FrozenSet[str] = frozenset( - { - # Standard library - safe modules - "json", - "math", - "re", - "datetime", - "collections", - "collections.*", - "itertools", - "functools", - "operator", - "string", - "textwrap", - "unicodedata", - "decimal", - "fractions", - "random", - "statistics", - "typing", - "typing.*", - "dataclasses", - "enum", - "abc", - "copy", - "pprint", - "reprlib", - "numbers", - "cmath", - "time", - "calendar", - "hashlib", - "hmac", - "base64", - "binascii", - "html", - "html.*", - "urllib.parse", - "uuid", - "struct", - "codecs", - "locale", - "gettext", - "bisect", - "heapq", - "array", - "weakref", - "types", - "contextlib", - "warnings", - "traceback", - "linecache", - "difflib", - "graphlib", - "zoneinfo", - } -) +DEFAULT_SAFE_IMPORTS: FrozenSet[str] = frozenset({ + # Standard library - safe modules + "json", + "math", + "re", + "datetime", + "collections", + "collections.*", + "itertools", + "functools", + "operator", + "string", + "textwrap", + "unicodedata", + "decimal", + "fractions", + "random", + "statistics", + "typing", + "typing.*", + "dataclasses", + "enum", + "abc", + "copy", + "pprint", + "reprlib", + "numbers", + "cmath", + "time", + "calendar", + "hashlib", + "hmac", + "base64", + "binascii", + "html", + "html.*", + "urllib.parse", + "uuid", + "struct", + "codecs", + "locale", + "gettext", + "bisect", + "heapq", + "array", + "weakref", + "types", + "contextlib", + "warnings", + "traceback", + "linecache", + "difflib", + "graphlib", + "zoneinfo", +}) class ImportValidationError(Exception): - """Exception raised when import validation fails. - - Attributes: - violations: List of import violations found. - code: The code that was validated. - """ - - def __init__( - self, - violations: List[str], - code: str, - ): - self.violations = violations - self.code = code - violation_str = "\n".join(f" - {v}" for v in violations) - super().__init__( - f"Import validation failed. Unauthorized imports found:\n{violation_str}" - ) + """Exception raised when import validation fails. + + Attributes: + violations: List of import violations found. + code: The code that was validated. + """ + + def __init__( + self, + violations: List[str], + code: str, + ): + self.violations = violations + self.code = code + violation_str = "\n".join(f" - {v}" for v in violations) + super().__init__( + "Import validation failed. Unauthorized imports" + f" found:\n{violation_str}" + ) @dataclass class ImportInfo: - """Information about an import statement. - - Attributes: - module: The module being imported. - names: Names being imported from the module (for 'from' imports). - alias: Alias for the import (if any). - line_number: Line number in the source code. - is_from_import: Whether this is a 'from X import Y' style import. - """ - - module: str - names: List[str] = field(default_factory=list) - alias: str = "" - line_number: int = 0 - is_from_import: bool = False + """Information about an import statement. + Attributes: + module: The module being imported. + names: Names being imported from the module (for 'from' imports). + alias: Alias for the import (if any). + line_number: Line number in the source code. + is_from_import: Whether this is a 'from X import Y' style import. + """ -def extract_imports(code: str) -> List[ImportInfo]: - """Extract all import statements from Python code using AST. + module: str + names: List[str] = field(default_factory=list) + alias: str = "" + line_number: int = 0 + is_from_import: bool = False - Args: - code: Python source code to analyze. - Returns: - List of ImportInfo objects describing each import. +def extract_imports(code: str) -> List[ImportInfo]: + """Extract all import statements from Python code using AST. + + Args: + code: Python source code to analyze. + + Returns: + List of ImportInfo objects describing each import. + + Raises: + SyntaxError: If the code cannot be parsed. + """ + imports = [] + + try: + tree = ast.parse(code) + except SyntaxError as e: + logger.warning("Failed to parse code for import extraction: %s", e) + raise + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + imports.append( + ImportInfo( + module=alias.name, + alias=alias.asname or "", + line_number=node.lineno, + is_from_import=False, + ) + ) + elif isinstance(node, ast.ImportFrom): + module = node.module or "" + names = [alias.name for alias in node.names] + for alias in node.names: + imports.append( + ImportInfo( + module=module, + names=[alias.name], + alias=alias.asname or "", + line_number=node.lineno, + is_from_import=True, + ) + ) - Raises: - SyntaxError: If the code cannot be parsed. - """ - imports = [] - - try: - tree = ast.parse(code) - except SyntaxError as e: - logger.warning("Failed to parse code for import extraction: %s", e) - raise - - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for alias in node.names: - imports.append( - ImportInfo( - module=alias.name, - alias=alias.asname or "", - line_number=node.lineno, - is_from_import=False, - ) - ) - elif isinstance(node, ast.ImportFrom): - module = node.module or "" - names = [alias.name for alias in node.names] - for alias in node.names: - imports.append( - ImportInfo( - module=module, - names=[alias.name], - alias=alias.asname or "", - line_number=node.lineno, - is_from_import=True, - ) - ) - - return imports + return imports def is_import_allowed( import_name: str, allowlist: FrozenSet[str], ) -> bool: - """Check if an import is allowed by the allowlist. - - Supports wildcards: - - 'collections.*' allows 'collections.abc', 'collections.defaultdict', etc. - - 'numpy' allows only 'numpy', not 'numpy.array' - - 'numpy.*' allows 'numpy.array', 'numpy.linalg', etc. - - Args: - import_name: The full import name to check. - allowlist: Set of allowed import patterns. - - Returns: - True if the import is allowed, False otherwise. - """ - # Direct match - if import_name in allowlist: + """Check if an import is allowed by the allowlist. + + Supports wildcards: + - 'collections.*' allows 'collections.abc', 'collections.defaultdict', etc. + - 'numpy' allows only 'numpy', not 'numpy.array' + - 'numpy.*' allows 'numpy.array', 'numpy.linalg', etc. + + Args: + import_name: The full import name to check. + allowlist: Set of allowed import patterns. + + Returns: + True if the import is allowed, False otherwise. + """ + # Direct match + if import_name in allowlist: + return True + + # Check wildcard patterns + for pattern in allowlist: + if "*" in pattern: + if fnmatch.fnmatch(import_name, pattern): + return True + # Also check if the import is a submodule of an allowed module + # e.g., 'collections.*' should allow 'collections.abc.Callable' + base_pattern = pattern.rstrip(".*") + if import_name.startswith(base_pattern + "."): return True - # Check wildcard patterns - for pattern in allowlist: - if "*" in pattern: - if fnmatch.fnmatch(import_name, pattern): - return True - # Also check if the import is a submodule of an allowed module - # e.g., 'collections.*' should allow 'collections.abc.Callable' - base_pattern = pattern.rstrip(".*") - if import_name.startswith(base_pattern + "."): - return True - - # Check if parent module is allowed with wildcard - parts = import_name.split(".") - for i in range(len(parts)): - parent = ".".join(parts[: i + 1]) - wildcard_pattern = parent + ".*" - if wildcard_pattern in allowlist: - return True + # Check if parent module is allowed with wildcard + parts = import_name.split(".") + for i in range(len(parts)): + parent = ".".join(parts[: i + 1]) + wildcard_pattern = parent + ".*" + if wildcard_pattern in allowlist: + return True - return False + return False def validate_imports( code: str, allowlist: FrozenSet[str], ) -> List[str]: - """Validate that all imports in code are in the allowlist. + """Validate that all imports in code are in the allowlist. - Args: - code: Python source code to validate. - allowlist: Set of allowed import patterns. + Args: + code: Python source code to validate. + allowlist: Set of allowed import patterns. - Returns: - List of violations (empty if all imports are valid). + Returns: + List of violations (empty if all imports are valid). - Raises: - ImportValidationError: If unauthorized imports are found. - """ - violations = [] - - try: - imports = extract_imports(code) - except SyntaxError as e: - # If we can't parse, we can't validate - return syntax error as violation - violations.append(f"Syntax error in code: {e}") - return violations - - for import_info in imports: - module = import_info.module - - if import_info.is_from_import: - # For 'from X import Y', check both the module and the full path - for name in import_info.names: - full_name = f"{module}.{name}" if module else name - # Check if module is allowed OR full import path is allowed - if not ( - is_import_allowed(module, allowlist) - or is_import_allowed(full_name, allowlist) - ): - violations.append( - f"Line {import_info.line_number}: " - f'Unauthorized import "from {module} import {name}"' - ) - else: - # For 'import X', just check the module - if not is_import_allowed(module, allowlist): - violations.append( - f'Line {import_info.line_number}: Unauthorized import "{module}"' - ) + Raises: + ImportValidationError: If unauthorized imports are found. + """ + violations = [] + try: + imports = extract_imports(code) + except SyntaxError as e: + # If we can't parse, we can't validate - return syntax error as violation + violations.append(f"Syntax error in code: {e}") return violations + for import_info in imports: + module = import_info.module + + if import_info.is_from_import: + # For 'from X import Y', check both the module and the full path + for name in import_info.names: + full_name = f"{module}.{name}" if module else name + # Check if module is allowed OR full import path is allowed + if not ( + is_import_allowed(module, allowlist) + or is_import_allowed(full_name, allowlist) + ): + violations.append( + f"Line {import_info.line_number}: " + f'Unauthorized import "from {module} import {name}"' + ) + else: + # For 'import X', just check the module + if not is_import_allowed(module, allowlist): + violations.append( + f'Line {import_info.line_number}: Unauthorized import "{module}"' + ) + + return violations + def validate_imports_strict( code: str, allowlist: FrozenSet[str], ) -> None: - """Validate imports and raise exception if any violations found. + """Validate imports and raise exception if any violations found. - Args: - code: Python source code to validate. - allowlist: Set of allowed import patterns. + Args: + code: Python source code to validate. + allowlist: Set of allowed import patterns. - Raises: - ImportValidationError: If unauthorized imports are found. - """ - violations = validate_imports(code, allowlist) - if violations: - raise ImportValidationError(violations, code) + Raises: + ImportValidationError: If unauthorized imports are found. + """ + violations = validate_imports(code, allowlist) + if violations: + raise ImportValidationError(violations, code) class AllowlistValidator: - """Validator for checking Python code imports against an allowlist. + """Validator for checking Python code imports against an allowlist. - This class provides a stateful validator that can be reused for multiple - validations with the same allowlist. + This class provides a stateful validator that can be reused for multiple + validations with the same allowlist. - Attributes: - allowlist: The set of allowed import patterns. - """ + Attributes: + allowlist: The set of allowed import patterns. + """ - def __init__( - self, - allowlist: FrozenSet[str] = DEFAULT_SAFE_IMPORTS, - additional_imports: FrozenSet[str] = frozenset(), - ): - """Initialize the validator with an allowlist. + def __init__( + self, + allowlist: FrozenSet[str] = DEFAULT_SAFE_IMPORTS, + additional_imports: FrozenSet[str] = frozenset(), + ): + """Initialize the validator with an allowlist. - Args: - allowlist: Base set of allowed import patterns. - additional_imports: Additional imports to allow beyond the base set. - """ - self.allowlist = allowlist | additional_imports + Args: + allowlist: Base set of allowed import patterns. + additional_imports: Additional imports to allow beyond the base set. + """ + self.allowlist = allowlist | additional_imports - def validate(self, code: str) -> List[str]: - """Validate imports in code. + def validate(self, code: str) -> List[str]: + """Validate imports in code. - Args: - code: Python source code to validate. + Args: + code: Python source code to validate. - Returns: - List of violations (empty if all imports are valid). - """ - return validate_imports(code, self.allowlist) + Returns: + List of violations (empty if all imports are valid). + """ + return validate_imports(code, self.allowlist) - def validate_strict(self, code: str) -> None: - """Validate imports and raise if any violations found. + def validate_strict(self, code: str) -> None: + """Validate imports and raise if any violations found. - Args: - code: Python source code to validate. + Args: + code: Python source code to validate. - Raises: - ImportValidationError: If unauthorized imports are found. - """ - validate_imports_strict(code, self.allowlist) + Raises: + ImportValidationError: If unauthorized imports are found. + """ + validate_imports_strict(code, self.allowlist) - def is_allowed(self, import_name: str) -> bool: - """Check if a single import is allowed. + def is_allowed(self, import_name: str) -> bool: + """Check if a single import is allowed. - Args: - import_name: The import name to check. + Args: + import_name: The import name to check. - Returns: - True if allowed, False otherwise. - """ - return is_import_allowed(import_name, self.allowlist) + Returns: + True if allowed, False otherwise. + """ + return is_import_allowed(import_name, self.allowlist) - def add_allowed_imports(self, imports: Set[str]) -> None: - """Add additional allowed imports. + def add_allowed_imports(self, imports: Set[str]) -> None: + """Add additional allowed imports. - Args: - imports: Set of import patterns to allow. - """ - self.allowlist = self.allowlist | frozenset(imports) + Args: + imports: Set of import patterns to allow. + """ + self.allowlist = self.allowlist | frozenset(imports) diff --git a/src/google/adk/code_executors/coding_agent_code_executor.py b/src/google/adk/code_executors/coding_agent_code_executor.py index 22099a53e4..2f5e72532a 100644 --- a/src/google/adk/code_executors/coding_agent_code_executor.py +++ b/src/google/adk/code_executors/coding_agent_code_executor.py @@ -20,12 +20,12 @@ from __future__ import annotations +from dataclasses import dataclass +from dataclasses import field import hashlib import json import logging import re -from dataclasses import dataclass -from dataclasses import field from typing import Any from typing import Dict from typing import FrozenSet @@ -37,6 +37,7 @@ from pydantic import PrivateAttr from typing_extensions import override +from ..tools.base_tool import BaseTool from .allowlist_validator import AllowlistValidator from .allowlist_validator import DEFAULT_SAFE_IMPORTS from .allowlist_validator import ImportValidationError @@ -47,11 +48,10 @@ from .tool_execution_server import detect_docker_host_address from .tool_execution_server import ToolExecutionServer from .tool_execution_server import ToolTrace -from ..tools.base_tool import BaseTool if TYPE_CHECKING: - from ..agents.invocation_context import InvocationContext - from ..tools.tool_context import ToolContext + from ..agents.invocation_context import InvocationContext + from ..tools.tool_context import ToolContext logger = logging.getLogger("google_adk." + __name__) @@ -63,443 +63,443 @@ @dataclass class ExecutionStep: - """Record of a single code execution step. - - Attributes: - code: The code that was executed. - code_hash: Hash of the code for comparison. - result: The execution result. - tool_traces: Tool call traces from this step. - success: Whether the execution succeeded. - final_answer: The final answer if one was provided. - """ + """Record of a single code execution step. + + Attributes: + code: The code that was executed. + code_hash: Hash of the code for comparison. + result: The execution result. + tool_traces: Tool call traces from this step. + success: Whether the execution succeeded. + final_answer: The final answer if one was provided. + """ - code: str - code_hash: str = "" - result: Optional[CodeExecutionResult] = None - tool_traces: List[Dict[str, Any]] = field(default_factory=list) - success: bool = False - final_answer: Optional[Any] = None + code: str + code_hash: str = "" + result: Optional[CodeExecutionResult] = None + tool_traces: List[Dict[str, Any]] = field(default_factory=list) + success: bool = False + final_answer: Optional[Any] = None - def __post_init__(self): - if not self.code_hash: - self.code_hash = hashlib.sha256(self.code.encode()).hexdigest()[:16] + def __post_init__(self): + if not self.code_hash: + self.code_hash = hashlib.sha256(self.code.encode()).hexdigest()[:16] @dataclass class CodingAgentExecutionResult: - """Extended execution result with CodingAgent-specific fields. - - Attributes: - code_result: The underlying code execution result. - tool_traces: List of tool call traces. - final_answer: The final answer if one was provided. - has_final_answer: Whether a final answer was extracted. - clean_stdout: Stdout with trace markers removed. - """ + """Extended execution result with CodingAgent-specific fields. - code_result: CodeExecutionResult - tool_traces: List[Dict[str, Any]] = field(default_factory=list) - final_answer: Optional[Any] = None - has_final_answer: bool = False - clean_stdout: str = "" + Attributes: + code_result: The underlying code execution result. + tool_traces: List of tool call traces. + final_answer: The final answer if one was provided. + has_final_answer: Whether a final answer was extracted. + clean_stdout: Stdout with trace markers removed. + """ + + code_result: CodeExecutionResult + tool_traces: List[Dict[str, Any]] = field(default_factory=list) + final_answer: Optional[Any] = None + has_final_answer: bool = False + clean_stdout: str = "" class CodingAgentCodeExecutor(BaseCodeExecutor): - """Code executor with tool injection for CodingAgent. - - This executor wraps an underlying code executor and adds: - - Tool stub prepending for HTTP-based tool calls - - Import allowlist validation before execution - - Tool execution server lifecycle management - - Trace extraction from execution output - - Final answer detection - - History re-execution for stateful mode - - Attributes: - underlying_executor: The actual code executor to use. - tools: List of tools to make available. - authorized_imports: Set of allowed import patterns. - tool_server_host: Host for the tool server. - tool_server_port: Port for the tool server. - execution_history: List of execution steps for stateful mode. + """Code executor with tool injection for CodingAgent. + + This executor wraps an underlying code executor and adds: + - Tool stub prepending for HTTP-based tool calls + - Import allowlist validation before execution + - Tool execution server lifecycle management + - Trace extraction from execution output + - Final answer detection + - History re-execution for stateful mode + + Attributes: + underlying_executor: The actual code executor to use. + tools: List of tools to make available. + authorized_imports: Set of allowed import patterns. + tool_server_host: Host for the tool server. + tool_server_port: Port for the tool server. + execution_history: List of execution steps for stateful mode. + """ + + underlying_executor: BaseCodeExecutor + tools: List[BaseTool] = Field(default_factory=list) + authorized_imports: FrozenSet[str] = DEFAULT_SAFE_IMPORTS + tool_server_host: Optional[str] = None + tool_server_port: int = 8765 + + # Internal state - use PrivateAttr for Pydantic + _tool_server: Optional[ToolExecutionServer] = PrivateAttr(default=None) + _validator: Optional[AllowlistValidator] = PrivateAttr(default=None) + _invocation_context: Optional[InvocationContext] = PrivateAttr(default=None) + _tool_context: Optional[ToolContext] = PrivateAttr(default=None) + _execution_history: List[ExecutionStep] = PrivateAttr(default_factory=list) + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + def model_post_init(self, __context): + """Initialize after model construction.""" + self._validator = AllowlistValidator( + allowlist=self.authorized_imports, + ) + self._execution_history = [] + + def set_context( + self, + invocation_context: InvocationContext, + tool_context: Optional[ToolContext] = None, + ) -> None: + """Set the execution context. + + Args: + invocation_context: The invocation context. + tool_context: The tool context. """ + self._invocation_context = invocation_context + self._tool_context = tool_context + if self._tool_server: + self._tool_server.set_context(invocation_context, tool_context) + + def _start_tool_server(self) -> None: + """Start the tool execution server if not already running.""" + if self._tool_server is not None: + return + + host = self.tool_server_host or "0.0.0.0" + self._tool_server = ToolExecutionServer( + host=host, + port=self.tool_server_port, + tools=self.tools, + invocation_context=self._invocation_context, + ) + self._tool_server.start() + + def _stop_tool_server(self) -> None: + """Stop the tool execution server.""" + if self._tool_server: + self._tool_server.stop() + self._tool_server = None + + def _get_tool_server_url(self) -> str: + """Get the URL for the tool server. + + Returns: + The tool server URL accessible from containers. + """ + if self.tool_server_host: + host = self.tool_server_host + else: + host = detect_docker_host_address() + return f"http://{host}:{self.tool_server_port}" - underlying_executor: BaseCodeExecutor - tools: List[BaseTool] = Field(default_factory=list) - authorized_imports: FrozenSet[str] = DEFAULT_SAFE_IMPORTS - tool_server_host: Optional[str] = None - tool_server_port: int = 8765 - - # Internal state - use PrivateAttr for Pydantic - _tool_server: Optional[ToolExecutionServer] = PrivateAttr(default=None) - _validator: Optional[AllowlistValidator] = PrivateAttr(default=None) - _invocation_context: Optional[InvocationContext] = PrivateAttr(default=None) - _tool_context: Optional[ToolContext] = PrivateAttr(default=None) - _execution_history: List[ExecutionStep] = PrivateAttr(default_factory=list) - - class Config: - """Pydantic config.""" - - arbitrary_types_allowed = True - - def model_post_init(self, __context): - """Initialize after model construction.""" - self._validator = AllowlistValidator( - allowlist=self.authorized_imports, - ) - self._execution_history = [] - - def set_context( - self, - invocation_context: InvocationContext, - tool_context: Optional[ToolContext] = None, - ) -> None: - """Set the execution context. - - Args: - invocation_context: The invocation context. - tool_context: The tool context. - """ - self._invocation_context = invocation_context - self._tool_context = tool_context - if self._tool_server: - self._tool_server.set_context(invocation_context, tool_context) - - def _start_tool_server(self) -> None: - """Start the tool execution server if not already running.""" - if self._tool_server is not None: - return - - host = self.tool_server_host or "0.0.0.0" - self._tool_server = ToolExecutionServer( - host=host, - port=self.tool_server_port, - tools=self.tools, - invocation_context=self._invocation_context, - ) - self._tool_server.start() - - def _stop_tool_server(self) -> None: - """Stop the tool execution server.""" - if self._tool_server: - self._tool_server.stop() - self._tool_server = None - - def _get_tool_server_url(self) -> str: - """Get the URL for the tool server. - - Returns: - The tool server URL accessible from containers. - """ - if self.tool_server_host: - host = self.tool_server_host - else: - host = detect_docker_host_address() - return f"http://{host}:{self.tool_server_port}" - - def _validate_imports(self, code: str) -> None: - """Validate imports in the code against the allowlist. - - Args: - code: The code to validate. - - Raises: - ImportValidationError: If unauthorized imports are found. - """ - if self._validator: - self._validator.validate_strict(code) - - def _extract_traces_and_answer( - self, - result: CodeExecutionResult, - ) -> CodingAgentExecutionResult: - """Extract tool traces and final answer from execution output. - - Args: - result: The raw execution result. - - Returns: - Extended result with extracted data. - """ - tool_traces = [] - final_answer = None - has_final_answer = False - clean_lines = [] - - for line in result.stdout.split("\n"): - if line.startswith(TOOL_TRACE_MARKER): - try: - trace_json = line[len(TOOL_TRACE_MARKER) :] - traces = json.loads(trace_json) - tool_traces.extend(traces) - except json.JSONDecodeError as e: - logger.warning("Failed to parse tool trace: %s", e) - elif line.startswith(FINAL_ANSWER_MARKER): - answer_str = line[len(FINAL_ANSWER_MARKER) :] - try: - final_answer = json.loads(answer_str) - except json.JSONDecodeError: - # Not JSON, treat as string - final_answer = answer_str - has_final_answer = True - else: - clean_lines.append(line) - - clean_stdout = "\n".join(clean_lines).strip() - - return CodingAgentExecutionResult( - code_result=result, - tool_traces=tool_traces, - final_answer=final_answer, - has_final_answer=has_final_answer, - clean_stdout=clean_stdout, - ) - - def _should_skip_step(self, step: ExecutionStep, code_hash: str) -> bool: - """Check if an execution step can be skipped. - - For stateful mode, we can skip re-executing code if: - - The code hasn't changed (same hash) - - The previous execution succeeded + def _validate_imports(self, code: str) -> None: + """Validate imports in the code against the allowlist. - Args: - step: The previous execution step. - code_hash: Hash of the current code. + Args: + code: The code to validate. - Returns: - True if the step can be skipped. - """ - return step.success and step.code_hash == code_hash + Raises: + ImportValidationError: If unauthorized imports are found. + """ + if self._validator: + self._validator.validate_strict(code) - def _prepend_tool_stubs(self, code: str) -> str: - """Prepend runtime header and tool stubs to user code. + def _extract_traces_and_answer( + self, + result: CodeExecutionResult, + ) -> CodingAgentExecutionResult: + """Extract tool traces and final answer from execution output. - Args: - code: The user code to wrap. + Args: + result: The raw execution result. - Returns: - Complete code with tool stubs. - """ - return generate_full_code_with_stubs( - user_code=code, - tools=self.tools, - tool_server_url=self._get_tool_server_url(), - ) + Returns: + Extended result with extracted data. + """ + tool_traces = [] + final_answer = None + has_final_answer = False + clean_lines = [] - def _replay_history( - self, - invocation_context: InvocationContext, - ) -> Optional[CodeExecutionResult]: - """Replay execution history for stateful mode. - - This re-executes previous successful steps to restore state - before executing new code. - - Args: - invocation_context: The invocation context. - - Returns: - The result of the last replayed step, or None if no replay needed. - """ - if not self.stateful or not self._execution_history: - return None - - last_result = None - for step in self._execution_history: - if step.success: - # Re-execute to restore state - full_code = self._prepend_tool_stubs(step.code) - input_data = CodeExecutionInput(code=full_code) - last_result = self.underlying_executor.execute_code( - invocation_context=invocation_context, - code_execution_input=input_data, - ) - logger.debug("Replayed history step: %s", step.code_hash) - - return last_result - - @override - def execute_code( - self, - invocation_context: InvocationContext, - code_execution_input: CodeExecutionInput, - ) -> CodeExecutionResult: - """Execute code with tool injection. - - Args: - invocation_context: The invocation context. - code_execution_input: The code to execute. - - Returns: - The execution result. - """ - user_code = code_execution_input.code - - # Validate imports first (security check before execution) + for line in result.stdout.split("\n"): + if line.startswith(TOOL_TRACE_MARKER): try: - self._validate_imports(user_code) - except ImportValidationError as e: - return CodeExecutionResult( - stdout="", - stderr=str(e), - output_files=[], - ) - - # Start tool server if needed - self._start_tool_server() - - # Set context on tool server - if self._tool_server: - self._tool_server.set_context( - invocation_context, - self._tool_context, - ) - self._tool_server.clear_traces() - - # Replay history for stateful mode - if self.stateful: - self._replay_history(invocation_context) - - # Prepend tool stubs to user code - full_code = self._prepend_tool_stubs(user_code) - - # Execute the code - input_with_stubs = CodeExecutionInput( - code=full_code, - input_files=code_execution_input.input_files, - execution_id=code_execution_input.execution_id, - ) - - result = self.underlying_executor.execute_code( - invocation_context=invocation_context, - code_execution_input=input_with_stubs, - ) - - # Extract traces and final answer - extended_result = self._extract_traces_and_answer(result) - - # Record execution step for stateful mode - step = ExecutionStep( - code=user_code, - result=result, - tool_traces=extended_result.tool_traces, - success=not result.stderr, - final_answer=extended_result.final_answer, - ) - self._execution_history.append(step) - - # Return result with clean stdout (traces stripped) - return CodeExecutionResult( - stdout=extended_result.clean_stdout, - stderr=result.stderr, - output_files=result.output_files, - ) - - def execute_code_extended( - self, - invocation_context: InvocationContext, - code_execution_input: CodeExecutionInput, - ) -> CodingAgentExecutionResult: - """Execute code and return extended result with traces. - - Args: - invocation_context: The invocation context. - code_execution_input: The code to execute. - - Returns: - Extended execution result with tool traces and final answer. - """ - user_code = code_execution_input.code - - # Validate imports first + trace_json = line[len(TOOL_TRACE_MARKER) :] + traces = json.loads(trace_json) + tool_traces.extend(traces) + except json.JSONDecodeError as e: + logger.warning("Failed to parse tool trace: %s", e) + elif line.startswith(FINAL_ANSWER_MARKER): + answer_str = line[len(FINAL_ANSWER_MARKER) :] try: - self._validate_imports(user_code) - except ImportValidationError as e: - return CodingAgentExecutionResult( - code_result=CodeExecutionResult( - stdout="", - stderr=str(e), - output_files=[], - ), - tool_traces=[], - final_answer=None, - has_final_answer=False, - clean_stdout="", - ) - - # Start tool server if needed - self._start_tool_server() - - # Set context on tool server - if self._tool_server: - self._tool_server.set_context( - invocation_context, - self._tool_context, - ) - self._tool_server.clear_traces() - - # Replay history for stateful mode - if self.stateful: - self._replay_history(invocation_context) - - # Prepend tool stubs to user code - full_code = self._prepend_tool_stubs(user_code) - - # Execute the code - input_with_stubs = CodeExecutionInput( - code=full_code, - input_files=code_execution_input.input_files, - execution_id=code_execution_input.execution_id, - ) + final_answer = json.loads(answer_str) + except json.JSONDecodeError: + # Not JSON, treat as string + final_answer = answer_str + has_final_answer = True + else: + clean_lines.append(line) + + clean_stdout = "\n".join(clean_lines).strip() + + return CodingAgentExecutionResult( + code_result=result, + tool_traces=tool_traces, + final_answer=final_answer, + has_final_answer=has_final_answer, + clean_stdout=clean_stdout, + ) + + def _should_skip_step(self, step: ExecutionStep, code_hash: str) -> bool: + """Check if an execution step can be skipped. + + For stateful mode, we can skip re-executing code if: + - The code hasn't changed (same hash) + - The previous execution succeeded + + Args: + step: The previous execution step. + code_hash: Hash of the current code. + + Returns: + True if the step can be skipped. + """ + return step.success and step.code_hash == code_hash - result = self.underlying_executor.execute_code( - invocation_context=invocation_context, - code_execution_input=input_with_stubs, - ) + def _prepend_tool_stubs(self, code: str) -> str: + """Prepend runtime header and tool stubs to user code. - # Extract traces and final answer - extended_result = self._extract_traces_and_answer(result) + Args: + code: The user code to wrap. - # Record execution step for stateful mode - step = ExecutionStep( - code=user_code, - result=result, - tool_traces=extended_result.tool_traces, - success=not result.stderr, - final_answer=extended_result.final_answer, + Returns: + Complete code with tool stubs. + """ + return generate_full_code_with_stubs( + user_code=code, + tools=self.tools, + tool_server_url=self._get_tool_server_url(), + ) + + def _replay_history( + self, + invocation_context: InvocationContext, + ) -> Optional[CodeExecutionResult]: + """Replay execution history for stateful mode. + + This re-executes previous successful steps to restore state + before executing new code. + + Args: + invocation_context: The invocation context. + + Returns: + The result of the last replayed step, or None if no replay needed. + """ + if not self.stateful or not self._execution_history: + return None + + last_result = None + for step in self._execution_history: + if step.success: + # Re-execute to restore state + full_code = self._prepend_tool_stubs(step.code) + input_data = CodeExecutionInput(code=full_code) + last_result = self.underlying_executor.execute_code( + invocation_context=invocation_context, + code_execution_input=input_data, ) - self._execution_history.append(step) - - return extended_result + logger.debug("Replayed history step: %s", step.code_hash) - def get_execution_history(self) -> List[ExecutionStep]: - """Get the execution history. + return last_result - Returns: - List of execution steps. - """ - return self._execution_history.copy() + @override + def execute_code( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodeExecutionResult: + """Execute code with tool injection. - def clear_execution_history(self) -> None: - """Clear the execution history.""" - self._execution_history.clear() + Args: + invocation_context: The invocation context. + code_execution_input: The code to execute. - def get_tool_traces(self) -> List[ToolTrace]: - """Get tool traces from the server. + Returns: + The execution result. + """ + user_code = code_execution_input.code + + # Validate imports first (security check before execution) + try: + self._validate_imports(user_code) + except ImportValidationError as e: + return CodeExecutionResult( + stdout="", + stderr=str(e), + output_files=[], + ) + + # Start tool server if needed + self._start_tool_server() + + # Set context on tool server + if self._tool_server: + self._tool_server.set_context( + invocation_context, + self._tool_context, + ) + self._tool_server.clear_traces() + + # Replay history for stateful mode + if self.stateful: + self._replay_history(invocation_context) + + # Prepend tool stubs to user code + full_code = self._prepend_tool_stubs(user_code) + + # Execute the code + input_with_stubs = CodeExecutionInput( + code=full_code, + input_files=code_execution_input.input_files, + execution_id=code_execution_input.execution_id, + ) + + result = self.underlying_executor.execute_code( + invocation_context=invocation_context, + code_execution_input=input_with_stubs, + ) + + # Extract traces and final answer + extended_result = self._extract_traces_and_answer(result) + + # Record execution step for stateful mode + step = ExecutionStep( + code=user_code, + result=result, + tool_traces=extended_result.tool_traces, + success=not result.stderr, + final_answer=extended_result.final_answer, + ) + self._execution_history.append(step) + + # Return result with clean stdout (traces stripped) + return CodeExecutionResult( + stdout=extended_result.clean_stdout, + stderr=result.stderr, + output_files=result.output_files, + ) + + def execute_code_extended( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodingAgentExecutionResult: + """Execute code and return extended result with traces. + + Args: + invocation_context: The invocation context. + code_execution_input: The code to execute. + + Returns: + Extended execution result with tool traces and final answer. + """ + user_code = code_execution_input.code + + # Validate imports first + try: + self._validate_imports(user_code) + except ImportValidationError as e: + return CodingAgentExecutionResult( + code_result=CodeExecutionResult( + stdout="", + stderr=str(e), + output_files=[], + ), + tool_traces=[], + final_answer=None, + has_final_answer=False, + clean_stdout="", + ) + + # Start tool server if needed + self._start_tool_server() + + # Set context on tool server + if self._tool_server: + self._tool_server.set_context( + invocation_context, + self._tool_context, + ) + self._tool_server.clear_traces() + + # Replay history for stateful mode + if self.stateful: + self._replay_history(invocation_context) + + # Prepend tool stubs to user code + full_code = self._prepend_tool_stubs(user_code) + + # Execute the code + input_with_stubs = CodeExecutionInput( + code=full_code, + input_files=code_execution_input.input_files, + execution_id=code_execution_input.execution_id, + ) + + result = self.underlying_executor.execute_code( + invocation_context=invocation_context, + code_execution_input=input_with_stubs, + ) + + # Extract traces and final answer + extended_result = self._extract_traces_and_answer(result) + + # Record execution step for stateful mode + step = ExecutionStep( + code=user_code, + result=result, + tool_traces=extended_result.tool_traces, + success=not result.stderr, + final_answer=extended_result.final_answer, + ) + self._execution_history.append(step) + + return extended_result + + def get_execution_history(self) -> List[ExecutionStep]: + """Get the execution history. + + Returns: + List of execution steps. + """ + return self._execution_history.copy() - Returns: - List of tool traces. - """ - if self._tool_server: - return self._tool_server.get_traces() - return [] + def clear_execution_history(self) -> None: + """Clear the execution history.""" + self._execution_history.clear() - def cleanup(self) -> None: - """Clean up resources.""" - self._stop_tool_server() - self._execution_history.clear() + def get_tool_traces(self) -> List[ToolTrace]: + """Get tool traces from the server. - def __del__(self): - """Destructor to clean up resources.""" - self.cleanup() + Returns: + List of tool traces. + """ + if self._tool_server: + return self._tool_server.get_traces() + return [] + + def cleanup(self) -> None: + """Clean up resources.""" + self._stop_tool_server() + self._execution_history.clear() + + def __del__(self): + """Destructor to clean up resources.""" + self.cleanup() diff --git a/src/google/adk/code_executors/tool_code_generator.py b/src/google/adk/code_executors/tool_code_generator.py index ef73c61376..1c55c36830 100644 --- a/src/google/adk/code_executors/tool_code_generator.py +++ b/src/google/adk/code_executors/tool_code_generator.py @@ -30,7 +30,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ..tools.base_tool import BaseTool + from ..tools.base_tool import BaseTool logger = logging.getLogger("google_adk." + __name__) @@ -132,129 +132,129 @@ def final_answer(result): def generate_runtime_header( tool_server_url: str, ) -> str: - """Generate the runtime header with HTTP client and helper functions. + """Generate the runtime header with HTTP client and helper functions. - Args: - tool_server_url: URL of the tool execution server. + Args: + tool_server_url: URL of the tool execution server. - Returns: - Python code string containing the runtime header. - """ - return RUNTIME_HEADER_TEMPLATE.format(tool_server_url=tool_server_url) + Returns: + Python code string containing the runtime header. + """ + return RUNTIME_HEADER_TEMPLATE.format(tool_server_url=tool_server_url) def _get_schema_type(schema: Any) -> str: - """Get the type from a schema (dict or Pydantic Schema object). - - Args: - schema: JSON schema dict or google.genai.types.Schema object. - - Returns: - The type as a lowercase string. - """ - if hasattr(schema, "type"): - # Pydantic Schema object from google.genai.types - schema_type = schema.type - if schema_type is None: - return "any" - # Handle enum (Type.STRING -> "string") - if hasattr(schema_type, "value"): - return schema_type.value.lower() - return str(schema_type).lower() - elif isinstance(schema, dict): - return schema.get("type", "any") - return "any" + """Get the type from a schema (dict or Pydantic Schema object). + + Args: + schema: JSON schema dict or google.genai.types.Schema object. + + Returns: + The type as a lowercase string. + """ + if hasattr(schema, "type"): + # Pydantic Schema object from google.genai.types + schema_type = schema.type + if schema_type is None: + return "any" + # Handle enum (Type.STRING -> "string") + if hasattr(schema_type, "value"): + return schema_type.value.lower() + return str(schema_type).lower() + elif isinstance(schema, dict): + return schema.get("type", "any") + return "any" def _get_schema_attr(schema: Any, attr: str, default: Any = None) -> Any: - """Get an attribute from a schema (dict or Pydantic Schema object). + """Get an attribute from a schema (dict or Pydantic Schema object). - Args: - schema: JSON schema dict or google.genai.types.Schema object. - attr: The attribute name to get. - default: Default value if attribute not found. + Args: + schema: JSON schema dict or google.genai.types.Schema object. + attr: The attribute name to get. + default: Default value if attribute not found. - Returns: - The attribute value or default. - """ - if hasattr(schema, attr): - return getattr(schema, attr, default) - elif isinstance(schema, dict): - return schema.get(attr, default) - return default + Returns: + The attribute value or default. + """ + if hasattr(schema, attr): + return getattr(schema, attr, default) + elif isinstance(schema, dict): + return schema.get(attr, default) + return default def _get_python_type_hint(schema: Any) -> str: - """Convert JSON schema type to Python type hint. + """Convert JSON schema type to Python type hint. - Args: - schema: JSON schema dict or google.genai.types.Schema object. + Args: + schema: JSON schema dict or google.genai.types.Schema object. - Returns: - Python type hint string. - """ - schema_type = _get_schema_type(schema) + Returns: + Python type hint string. + """ + schema_type = _get_schema_type(schema) - type_mapping = { - "string": "str", - "integer": "int", - "number": "float", - "boolean": "bool", - "array": "list", - "object": "dict", - } + type_mapping = { + "string": "str", + "integer": "int", + "number": "float", + "boolean": "bool", + "array": "list", + "object": "dict", + } - if schema_type == "array": - items = _get_schema_attr(schema, "items", {}) - if items: - item_type = _get_python_type_hint(items) - return f"list[{item_type}]" - return "list" - elif schema_type == "object": - return "dict" + if schema_type == "array": + items = _get_schema_attr(schema, "items", {}) + if items: + item_type = _get_python_type_hint(items) + return f"list[{item_type}]" + return "list" + elif schema_type == "object": + return "dict" - return type_mapping.get(schema_type, "Any") + return type_mapping.get(schema_type, "Any") def _generate_tool_stub(tool: BaseTool) -> str: - """Generate a Python function stub for a single tool. - - Args: - tool: The BaseTool to generate a stub for. - - Returns: - Python code string for the tool stub function. - """ - decl = tool._get_declaration() - if not decl: - logger.warning( - "Tool %s has no declaration, skipping stub generation", tool.name - ) - return "" + """Generate a Python function stub for a single tool. + + Args: + tool: The BaseTool to generate a stub for. + + Returns: + Python code string for the tool stub function. + """ + decl = tool._get_declaration() + if not decl: + logger.warning( + "Tool %s has no declaration, skipping stub generation", tool.name + ) + return "" - # Build parameter list with type hints - params = [] - param_docs = [] + # Build parameter list with type hints + params = [] + param_docs = [] - if decl.parameters and decl.parameters.properties: - required = set(decl.parameters.required or []) + if decl.parameters and decl.parameters.properties: + required = set(decl.parameters.required or []) - for param_name, param_schema in decl.parameters.properties.items(): - type_hint = _get_python_type_hint(param_schema) - description = _get_schema_attr(param_schema, "description", "") + for param_name, param_schema in decl.parameters.properties.items(): + type_hint = _get_python_type_hint(param_schema) + description = _get_schema_attr(param_schema, "description", "") - if param_name in required: - params.append(f"{param_name}: {type_hint}") - else: - params.append(f"{param_name}: {type_hint} = None") + if param_name in required: + params.append(f"{param_name}: {type_hint}") + else: + params.append(f"{param_name}: {type_hint} = None") - param_docs.append(f" {param_name}: {description}") + param_docs.append(f" {param_name}: {description}") - param_str = ", ".join(params) - param_doc_str = "\n".join(param_docs) if param_docs else " None" + param_str = ", ".join(params) + param_doc_str = "\n".join(param_docs) if param_docs else " None" - # Build the function stub - stub = f''' + # Build the function stub + stub = f''' def {tool.name}({param_str}) -> dict: """{tool.description} @@ -272,43 +272,47 @@ def {tool.name}({param_str}) -> dict: return response ''' - return stub + return stub def generate_tool_stubs(tools: List[BaseTool]) -> str: - """Generate Python function stubs for all tools. + """Generate Python function stubs for all tools. - Args: - tools: List of tools to generate stubs for. + Args: + tools: List of tools to generate stubs for. - Returns: - Python code string containing all tool stubs. - """ - stubs = [ - "# ============================================================================", - "# Tool Function Stubs", - "# ============================================================================", - "", - ] + Returns: + Python code string containing all tool stubs. + """ + stubs = [ + ( + "# ============================================================================" + ), + "# Tool Function Stubs", + ( + "# ============================================================================" + ), + "", + ] - for tool in tools: - stub = _generate_tool_stub(tool) - if stub: - stubs.append(stub) + for tool in tools: + stub = _generate_tool_stub(tool) + if stub: + stubs.append(stub) - return "\n".join(stubs) + return "\n".join(stubs) def generate_final_answer_stub() -> str: - """Generate the final_answer function documentation. + """Generate the final_answer function documentation. - This is included in the runtime header, but we generate additional - documentation here for the system prompt. + This is included in the runtime header, but we generate additional + documentation here for the system prompt. - Returns: - Documentation string about the final_answer function. - """ - return """ + Returns: + Documentation string about the final_answer function. + """ + return """ The `final_answer(result)` function is available to mark your final result. Call this function when you have completed the task and have a result to return. Example: `final_answer("The calculation result is 42")` @@ -366,34 +370,32 @@ def generate_system_prompt( tools: List[BaseTool], custom_instruction: str = "", ) -> str: - """Generate the system prompt for the CodingAgent. - - Args: - tools: List of available tools. - custom_instruction: Additional custom instructions. - - Returns: - Complete system prompt string. - """ - # Build tool documentation - tool_docs = [] - for tool in tools: - decl = tool._get_declaration() - if decl: - params_doc = "" - if decl.parameters and decl.parameters.properties: - param_lines = [] - required = set(decl.parameters.required or []) - for name, schema in decl.parameters.properties.items(): - type_hint = _get_python_type_hint(schema) - req_marker = " (required)" if name in required else " (optional)" - desc = _get_schema_attr(schema, "description", "") - param_lines.append( - f" - {name}: {type_hint}{req_marker} - {desc}" - ) - params_doc = "\n".join(param_lines) - - tool_docs.append(f""" + """Generate the system prompt for the CodingAgent. + + Args: + tools: List of available tools. + custom_instruction: Additional custom instructions. + + Returns: + Complete system prompt string. + """ + # Build tool documentation + tool_docs = [] + for tool in tools: + decl = tool._get_declaration() + if decl: + params_doc = "" + if decl.parameters and decl.parameters.properties: + param_lines = [] + required = set(decl.parameters.required or []) + for name, schema in decl.parameters.properties.items(): + type_hint = _get_python_type_hint(schema) + req_marker = " (required)" if name in required else " (optional)" + desc = _get_schema_attr(schema, "description", "") + param_lines.append(f" - {name}: {type_hint}{req_marker} - {desc}") + params_doc = "\n".join(param_lines) + + tool_docs.append(f""" ### {tool.name} {tool.description} @@ -401,9 +403,9 @@ def generate_system_prompt( {params_doc if params_doc else " None"} """) - tools_section = "\n".join(tool_docs) if tool_docs else "No tools available." + tools_section = "\n".join(tool_docs) if tool_docs else "No tools available." - system_prompt = f"""You are a coding agent that solves tasks by writing and executing Python code. + system_prompt = f"""You are a coding agent that solves tasks by writing and executing Python code. ## How to Respond @@ -439,7 +441,7 @@ def generate_system_prompt( {custom_instruction} """ - return system_prompt.strip() + return system_prompt.strip() def generate_full_code_with_stubs( @@ -447,20 +449,20 @@ def generate_full_code_with_stubs( tools: List[BaseTool], tool_server_url: str, ) -> str: - """Generate complete executable code with runtime header and tool stubs. + """Generate complete executable code with runtime header and tool stubs. - Args: - user_code: The user-generated code to execute. - tools: List of available tools. - tool_server_url: URL of the tool execution server. + Args: + user_code: The user-generated code to execute. + tools: List of available tools. + tool_server_url: URL of the tool execution server. - Returns: - Complete Python code ready for execution. - """ - runtime_header = generate_runtime_header(tool_server_url) - tool_stubs = generate_tool_stubs(tools) + Returns: + Complete Python code ready for execution. + """ + runtime_header = generate_runtime_header(tool_server_url) + tool_stubs = generate_tool_stubs(tools) - full_code = f"""{runtime_header} + full_code = f"""{runtime_header} {tool_stubs} # ============================================================================ # User Code @@ -475,4 +477,4 @@ def generate_full_code_with_stubs( print("__TOOL_TRACE__:" + __output_json.dumps(__get_tool_traces())) """ - return full_code + return full_code diff --git a/src/google/adk/code_executors/tool_execution_server.py b/src/google/adk/code_executors/tool_execution_server.py index fa1c2efda1..de13b15820 100644 --- a/src/google/adk/code_executors/tool_execution_server.py +++ b/src/google/adk/code_executors/tool_execution_server.py @@ -22,345 +22,345 @@ from __future__ import annotations import asyncio +from dataclasses import dataclass +from dataclasses import field import json import logging import os import socket import threading -from dataclasses import dataclass -from dataclasses import field from typing import Any from typing import Dict from typing import List from typing import Optional from typing import TYPE_CHECKING -import uvicorn from fastapi import FastAPI from fastapi import HTTPException from pydantic import BaseModel +import uvicorn if TYPE_CHECKING: - from ..agents.invocation_context import InvocationContext - from ..tools.base_tool import BaseTool - from ..tools.tool_context import ToolContext + from ..agents.invocation_context import InvocationContext + from ..tools.base_tool import BaseTool + from ..tools.tool_context import ToolContext logger = logging.getLogger("google_adk." + __name__) class ToolCallRequest(BaseModel): - """Request model for tool calls.""" + """Request model for tool calls.""" - tool_name: str - args: Dict[str, Any] + tool_name: str + args: Dict[str, Any] class ToolCallResponse(BaseModel): - """Response model for tool calls.""" + """Response model for tool calls.""" - result: Any - success: bool - error: Optional[str] = None + result: Any + success: bool + error: Optional[str] = None @dataclass class ToolTrace: - """Record of a tool call for debugging and telemetry.""" + """Record of a tool call for debugging and telemetry.""" - tool_name: str - args: Dict[str, Any] - result: Any = None - error: Optional[str] = None - success: bool = False - duration_ms: float = 0.0 + tool_name: str + args: Dict[str, Any] + result: Any = None + error: Optional[str] = None + success: bool = False + duration_ms: float = 0.0 def detect_docker_host_address() -> str: - """Detect the appropriate host address for Docker containers. + """Detect the appropriate host address for Docker containers. - On macOS and Windows (Docker Desktop), use host.docker.internal. - On Linux, use 172.17.0.1 (default Docker bridge network gateway). + On macOS and Windows (Docker Desktop), use host.docker.internal. + On Linux, use 172.17.0.1 (default Docker bridge network gateway). - Note: host.docker.internal only resolves from within containers, - not from the host machine, so we check the platform instead. + Note: host.docker.internal only resolves from within containers, + not from the host machine, so we check the platform instead. - Returns: - The detected host address. - """ - import platform + Returns: + The detected host address. + """ + import platform - system = platform.system().lower() + system = platform.system().lower() - # macOS and Windows use Docker Desktop which supports host.docker.internal - if system in ("darwin", "windows"): - return "host.docker.internal" + # macOS and Windows use Docker Desktop which supports host.docker.internal + if system in ("darwin", "windows"): + return "host.docker.internal" - # Linux: use Docker bridge network gateway - return "172.17.0.1" + # Linux: use Docker bridge network gateway + return "172.17.0.1" class ToolExecutionServer: - """FastAPI server for executing ADK tools via HTTP. - - This server is designed to run on the host machine and receive tool - execution requests from code running in Docker containers. - - Attributes: - host: Host address to bind the server to. - port: Port to bind the server to. - tools: Dictionary mapping tool names to tool instances. - invocation_context: The current invocation context. - tool_context: The current tool context. - traces: List of tool call traces. + """FastAPI server for executing ADK tools via HTTP. + + This server is designed to run on the host machine and receive tool + execution requests from code running in Docker containers. + + Attributes: + host: Host address to bind the server to. + port: Port to bind the server to. + tools: Dictionary mapping tool names to tool instances. + invocation_context: The current invocation context. + tool_context: The current tool context. + traces: List of tool call traces. + """ + + def __init__( + self, + host: str = "0.0.0.0", + port: int = 8765, + tools: Optional[List[BaseTool]] = None, + invocation_context: Optional[InvocationContext] = None, + ): + """Initialize the tool execution server. + + Args: + host: Host address to bind to. + port: Port to bind to. + tools: List of tools to make available. + invocation_context: The invocation context for tool calls. """ + self.host = host + self.port = port + self.tools: Dict[str, BaseTool] = {} + self.invocation_context = invocation_context + self.tool_context: Optional[ToolContext] = None + self.traces: List[ToolTrace] = [] + self._server: Optional[uvicorn.Server] = None + self._server_thread: Optional[threading.Thread] = None + self._app = self._create_app() + + if tools: + for tool in tools: + self.register_tool(tool) + + def _create_app(self) -> FastAPI: + """Create the FastAPI application with routes.""" + app = FastAPI( + title="ADK Tool Execution Server", + description="HTTP server for executing ADK tools from containers", + version="1.0.0", + ) + + @app.post("/tool_call", response_model=ToolCallResponse) + async def handle_tool_call(request: ToolCallRequest) -> ToolCallResponse: + """Handle a tool call request.""" + return await self._execute_tool(request.tool_name, request.args) + + @app.get("/tool_trace") + async def get_tool_traces() -> List[Dict[str, Any]]: + """Get all tool call traces.""" + return [ + { + "tool_name": t.tool_name, + "args": t.args, + "result": t.result, + "error": t.error, + "success": t.success, + "duration_ms": t.duration_ms, + } + for t in self.traces + ] + + @app.delete("/tool_trace") + async def clear_tool_traces() -> Dict[str, str]: + """Clear all tool call traces.""" + self.traces.clear() + return {"status": "cleared"} + + @app.get("/health") + async def health_check() -> Dict[str, str]: + """Health check endpoint.""" + return {"status": "healthy"} + + @app.get("/tools") + async def list_tools() -> List[str]: + """List available tools.""" + return list(self.tools.keys()) + + return app + + def register_tool(self, tool: BaseTool) -> None: + """Register a tool with the server. + + Args: + tool: The tool to register. + """ + self.tools[tool.name] = tool + logger.debug("Registered tool: %s", tool.name) + + def set_context( + self, + invocation_context: InvocationContext, + tool_context: Optional[ToolContext] = None, + ) -> None: + """Set the context for tool execution. + + Args: + invocation_context: The invocation context. + tool_context: The tool context. + """ + self.invocation_context = invocation_context + self.tool_context = tool_context - def __init__( - self, - host: str = "0.0.0.0", - port: int = 8765, - tools: Optional[List[BaseTool]] = None, - invocation_context: Optional[InvocationContext] = None, - ): - """Initialize the tool execution server. - - Args: - host: Host address to bind to. - port: Port to bind to. - tools: List of tools to make available. - invocation_context: The invocation context for tool calls. - """ - self.host = host - self.port = port - self.tools: Dict[str, BaseTool] = {} - self.invocation_context = invocation_context - self.tool_context: Optional[ToolContext] = None - self.traces: List[ToolTrace] = [] - self._server: Optional[uvicorn.Server] = None - self._server_thread: Optional[threading.Thread] = None - self._app = self._create_app() - - if tools: - for tool in tools: - self.register_tool(tool) - - def _create_app(self) -> FastAPI: - """Create the FastAPI application with routes.""" - app = FastAPI( - title="ADK Tool Execution Server", - description="HTTP server for executing ADK tools from containers", - version="1.0.0", - ) + async def _execute_tool( + self, + tool_name: str, + args: Dict[str, Any], + ) -> ToolCallResponse: + """Execute a tool and return the result. + + Args: + tool_name: Name of the tool to execute. + args: Arguments to pass to the tool. + + Returns: + The tool execution response. + """ + import time + + start_time = time.time() + trace = ToolTrace(tool_name=tool_name, args=args) - @app.post("/tool_call", response_model=ToolCallResponse) - async def handle_tool_call(request: ToolCallRequest) -> ToolCallResponse: - """Handle a tool call request.""" - return await self._execute_tool(request.tool_name, request.args) - - @app.get("/tool_trace") - async def get_tool_traces() -> List[Dict[str, Any]]: - """Get all tool call traces.""" - return [ - { - "tool_name": t.tool_name, - "args": t.args, - "result": t.result, - "error": t.error, - "success": t.success, - "duration_ms": t.duration_ms, - } - for t in self.traces - ] - - @app.delete("/tool_trace") - async def clear_tool_traces() -> Dict[str, str]: - """Clear all tool call traces.""" - self.traces.clear() - return {"status": "cleared"} - - @app.get("/health") - async def health_check() -> Dict[str, str]: - """Health check endpoint.""" - return {"status": "healthy"} - - @app.get("/tools") - async def list_tools() -> List[str]: - """List available tools.""" - return list(self.tools.keys()) - - return app - - def register_tool(self, tool: BaseTool) -> None: - """Register a tool with the server. - - Args: - tool: The tool to register. - """ - self.tools[tool.name] = tool - logger.debug("Registered tool: %s", tool.name) - - def set_context( - self, - invocation_context: InvocationContext, - tool_context: Optional[ToolContext] = None, - ) -> None: - """Set the context for tool execution. - - Args: - invocation_context: The invocation context. - tool_context: The tool context. - """ - self.invocation_context = invocation_context - self.tool_context = tool_context - - async def _execute_tool( - self, - tool_name: str, - args: Dict[str, Any], - ) -> ToolCallResponse: - """Execute a tool and return the result. - - Args: - tool_name: Name of the tool to execute. - args: Arguments to pass to the tool. - - Returns: - The tool execution response. - """ - import time - - start_time = time.time() - trace = ToolTrace(tool_name=tool_name, args=args) - - if tool_name not in self.tools: - trace.error = f"Tool not found: {tool_name}" - trace.success = False - trace.duration_ms = (time.time() - start_time) * 1000 - self.traces.append(trace) - raise HTTPException(status_code=404, detail=trace.error) - - tool = self.tools[tool_name] - - try: - # Create a tool context if we have an invocation context - if self.invocation_context and not self.tool_context: - from ..tools.tool_context import ToolContext - - self.tool_context = ToolContext( - invocation_context=self.invocation_context, - ) - - if self.tool_context: - result = await tool.run_async(args=args, tool_context=self.tool_context) - else: - # If no context available, create a minimal mock context - # This is a fallback and shouldn't happen in normal operation - logger.warning("Executing tool %s without proper context", tool_name) - result = await tool.run_async(args=args, tool_context=None) - - trace.result = result - trace.success = True - trace.duration_ms = (time.time() - start_time) * 1000 - self.traces.append(trace) - - return ToolCallResponse(result=result, success=True) - - except Exception as e: - trace.error = str(e) - trace.success = False - trace.duration_ms = (time.time() - start_time) * 1000 - self.traces.append(trace) - logger.error("Tool execution failed: %s - %s", tool_name, e) - raise HTTPException(status_code=500, detail=str(e)) from e - - def start(self) -> None: - """Start the server in a background thread.""" - if self._server_thread and self._server_thread.is_alive(): - logger.warning("Server already running") - return - - config = uvicorn.Config( - app=self._app, - host=self.host, - port=self.port, - log_level="warning", + if tool_name not in self.tools: + trace.error = f"Tool not found: {tool_name}" + trace.success = False + trace.duration_ms = (time.time() - start_time) * 1000 + self.traces.append(trace) + raise HTTPException(status_code=404, detail=trace.error) + + tool = self.tools[tool_name] + + try: + # Create a tool context if we have an invocation context + if self.invocation_context and not self.tool_context: + from ..tools.tool_context import ToolContext + + self.tool_context = ToolContext( + invocation_context=self.invocation_context, ) - self._server = uvicorn.Server(config) - - def run_server(): - asyncio.run(self._server.serve()) - - self._server_thread = threading.Thread(target=run_server, daemon=True) - self._server_thread.start() - - # Wait for server to be ready - self._wait_for_server() - logger.info("Tool execution server started on %s:%d", self.host, self.port) - - def _wait_for_server(self, timeout: float = 10.0) -> None: - """Wait for the server to be ready. - - Args: - timeout: Maximum time to wait in seconds. - """ - import time - - start = time.time() - while time.time() - start < timeout: - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1) - result = sock.connect_ex(("127.0.0.1", self.port)) - sock.close() - if result == 0: - return - except Exception: - pass - time.sleep(0.1) - - logger.warning("Server may not be fully ready after %.1f seconds", timeout) - - def stop(self) -> None: - """Stop the server.""" - if self._server: - self._server.should_exit = True - if self._server_thread: - self._server_thread.join(timeout=5.0) - self._server = None - self._server_thread = None - logger.info("Tool execution server stopped") - - def get_url(self, for_container: bool = True) -> str: - """Get the URL for the server. - - Args: - for_container: If True, return URL accessible from Docker containers. - - Returns: - The server URL. - """ - if for_container: - host = detect_docker_host_address() - else: - host = "localhost" if self.host == "0.0.0.0" else self.host - return f"http://{host}:{self.port}" - - def clear_traces(self) -> None: - """Clear all tool call traces.""" - self.traces.clear() - - def get_traces(self) -> List[ToolTrace]: - """Get all tool call traces. - - Returns: - List of tool traces. - """ - return self.traces.copy() - - def __enter__(self) -> ToolExecutionServer: - """Context manager entry.""" - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - """Context manager exit.""" - self.stop() + + if self.tool_context: + result = await tool.run_async(args=args, tool_context=self.tool_context) + else: + # If no context available, create a minimal mock context + # This is a fallback and shouldn't happen in normal operation + logger.warning("Executing tool %s without proper context", tool_name) + result = await tool.run_async(args=args, tool_context=None) + + trace.result = result + trace.success = True + trace.duration_ms = (time.time() - start_time) * 1000 + self.traces.append(trace) + + return ToolCallResponse(result=result, success=True) + + except Exception as e: + trace.error = str(e) + trace.success = False + trace.duration_ms = (time.time() - start_time) * 1000 + self.traces.append(trace) + logger.error("Tool execution failed: %s - %s", tool_name, e) + raise HTTPException(status_code=500, detail=str(e)) from e + + def start(self) -> None: + """Start the server in a background thread.""" + if self._server_thread and self._server_thread.is_alive(): + logger.warning("Server already running") + return + + config = uvicorn.Config( + app=self._app, + host=self.host, + port=self.port, + log_level="warning", + ) + self._server = uvicorn.Server(config) + + def run_server(): + asyncio.run(self._server.serve()) + + self._server_thread = threading.Thread(target=run_server, daemon=True) + self._server_thread.start() + + # Wait for server to be ready + self._wait_for_server() + logger.info("Tool execution server started on %s:%d", self.host, self.port) + + def _wait_for_server(self, timeout: float = 10.0) -> None: + """Wait for the server to be ready. + + Args: + timeout: Maximum time to wait in seconds. + """ + import time + + start = time.time() + while time.time() - start < timeout: + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + result = sock.connect_ex(("127.0.0.1", self.port)) + sock.close() + if result == 0: + return + except Exception: + pass + time.sleep(0.1) + + logger.warning("Server may not be fully ready after %.1f seconds", timeout) + + def stop(self) -> None: + """Stop the server.""" + if self._server: + self._server.should_exit = True + if self._server_thread: + self._server_thread.join(timeout=5.0) + self._server = None + self._server_thread = None + logger.info("Tool execution server stopped") + + def get_url(self, for_container: bool = True) -> str: + """Get the URL for the server. + + Args: + for_container: If True, return URL accessible from Docker containers. + + Returns: + The server URL. + """ + if for_container: + host = detect_docker_host_address() + else: + host = "localhost" if self.host == "0.0.0.0" else self.host + return f"http://{host}:{self.port}" + + def clear_traces(self) -> None: + """Clear all tool call traces.""" + self.traces.clear() + + def get_traces(self) -> List[ToolTrace]: + """Get all tool call traces. + + Returns: + List of tool traces. + """ + return self.traces.copy() + + def __enter__(self) -> ToolExecutionServer: + """Context manager entry.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit.""" + self.stop() diff --git a/tests/unittests/agents/test_coding_agent.py b/tests/unittests/agents/test_coding_agent.py index 2516507a3f..a54ec1cb23 100644 --- a/tests/unittests/agents/test_coding_agent.py +++ b/tests/unittests/agents/test_coding_agent.py @@ -16,7 +16,6 @@ from __future__ import annotations -import pytest from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -26,183 +25,184 @@ from google.adk.agents.coding_agent_config import CodingAgentConfig from google.adk.agents.coding_agent_config import DEFAULT_SAFE_IMPORTS from google.adk.tools.base_tool import BaseTool +import pytest class TestCodingAgentConfig: - """Tests for CodingAgentConfig.""" - - def test_default_values(self): - """Test that default values are set correctly.""" - config = CodingAgentConfig(name="test_agent") - - assert config.name == "test_agent" - assert config.agent_class == "CodingAgent" - assert config.max_iterations == 10 - assert config.error_retry_attempts == 2 - assert config.stateful is False - assert config.tool_server_port == 8765 - assert config.authorized_imports == DEFAULT_SAFE_IMPORTS - - def test_custom_values(self): - """Test that custom values can be set.""" - custom_imports = frozenset({"json", "math"}) - config = CodingAgentConfig( - name="custom_agent", - model="gemini-2.0-flash", - max_iterations=20, - error_retry_attempts=5, - stateful=True, - tool_server_port=9999, - authorized_imports=custom_imports, - ) - - assert config.name == "custom_agent" - assert config.model == "gemini-2.0-flash" - assert config.max_iterations == 20 - assert config.error_retry_attempts == 5 - assert config.stateful is True - assert config.tool_server_port == 9999 - assert config.authorized_imports == custom_imports - - def test_max_iterations_bounds(self): - """Test max_iterations validation.""" - # Valid bounds - config = CodingAgentConfig(name="test", max_iterations=1) - assert config.max_iterations == 1 - - config = CodingAgentConfig(name="test", max_iterations=100) - assert config.max_iterations == 100 - - # Invalid bounds - with pytest.raises(ValueError): - CodingAgentConfig(name="test", max_iterations=0) - - with pytest.raises(ValueError): - CodingAgentConfig(name="test", max_iterations=101) - - def test_port_bounds(self): - """Test tool_server_port validation.""" - # Valid bounds - config = CodingAgentConfig(name="test", tool_server_port=1024) - assert config.tool_server_port == 1024 - - config = CodingAgentConfig(name="test", tool_server_port=65535) - assert config.tool_server_port == 65535 - - # Invalid bounds - with pytest.raises(ValueError): - CodingAgentConfig(name="test", tool_server_port=1023) - - with pytest.raises(ValueError): - CodingAgentConfig(name="test", tool_server_port=65536) + """Tests for CodingAgentConfig.""" + + def test_default_values(self): + """Test that default values are set correctly.""" + config = CodingAgentConfig(name="test_agent") + + assert config.name == "test_agent" + assert config.agent_class == "CodingAgent" + assert config.max_iterations == 10 + assert config.error_retry_attempts == 2 + assert config.stateful is False + assert config.tool_server_port == 8765 + assert config.authorized_imports == DEFAULT_SAFE_IMPORTS + + def test_custom_values(self): + """Test that custom values can be set.""" + custom_imports = frozenset({"json", "math"}) + config = CodingAgentConfig( + name="custom_agent", + model="gemini-2.0-flash", + max_iterations=20, + error_retry_attempts=5, + stateful=True, + tool_server_port=9999, + authorized_imports=custom_imports, + ) + + assert config.name == "custom_agent" + assert config.model == "gemini-2.0-flash" + assert config.max_iterations == 20 + assert config.error_retry_attempts == 5 + assert config.stateful is True + assert config.tool_server_port == 9999 + assert config.authorized_imports == custom_imports + + def test_max_iterations_bounds(self): + """Test max_iterations validation.""" + # Valid bounds + config = CodingAgentConfig(name="test", max_iterations=1) + assert config.max_iterations == 1 + + config = CodingAgentConfig(name="test", max_iterations=100) + assert config.max_iterations == 100 + + # Invalid bounds + with pytest.raises(ValueError): + CodingAgentConfig(name="test", max_iterations=0) + + with pytest.raises(ValueError): + CodingAgentConfig(name="test", max_iterations=101) + + def test_port_bounds(self): + """Test tool_server_port validation.""" + # Valid bounds + config = CodingAgentConfig(name="test", tool_server_port=1024) + assert config.tool_server_port == 1024 + + config = CodingAgentConfig(name="test", tool_server_port=65535) + assert config.tool_server_port == 65535 + + # Invalid bounds + with pytest.raises(ValueError): + CodingAgentConfig(name="test", tool_server_port=1023) + + with pytest.raises(ValueError): + CodingAgentConfig(name="test", tool_server_port=65536) class TestCodingAgentState: - """Tests for CodingAgentState.""" - - def test_default_state(self): - """Test default state values.""" - state = CodingAgentState() - - assert state.iteration_count == 0 - assert state.error_count == 0 - assert state.execution_history == [] - - def test_state_with_history(self): - """Test state with execution history.""" - history = [ - {"iteration": 1, "code": "print('hello')", "success": True}, - {"iteration": 2, "code": "print('world')", "success": True}, - ] - state = CodingAgentState( - iteration_count=2, - error_count=0, - execution_history=history, - ) - - assert state.iteration_count == 2 - assert len(state.execution_history) == 2 - - def test_state_serialization(self): - """Test state can be serialized and deserialized.""" - state = CodingAgentState( - iteration_count=5, - error_count=1, - execution_history=[{"iteration": 1, "code": "x = 1"}], - ) - - dumped = state.model_dump() - restored = CodingAgentState.model_validate(dumped) - - assert restored.iteration_count == 5 - assert restored.error_count == 1 - assert len(restored.execution_history) == 1 + """Tests for CodingAgentState.""" + + def test_default_state(self): + """Test default state values.""" + state = CodingAgentState() + + assert state.iteration_count == 0 + assert state.error_count == 0 + assert state.execution_history == [] + + def test_state_with_history(self): + """Test state with execution history.""" + history = [ + {"iteration": 1, "code": "print('hello')", "success": True}, + {"iteration": 2, "code": "print('world')", "success": True}, + ] + state = CodingAgentState( + iteration_count=2, + error_count=0, + execution_history=history, + ) + + assert state.iteration_count == 2 + assert len(state.execution_history) == 2 + + def test_state_serialization(self): + """Test state can be serialized and deserialized.""" + state = CodingAgentState( + iteration_count=5, + error_count=1, + execution_history=[{"iteration": 1, "code": "x = 1"}], + ) + + dumped = state.model_dump() + restored = CodingAgentState.model_validate(dumped) + + assert restored.iteration_count == 5 + assert restored.error_count == 1 + assert len(restored.execution_history) == 1 class TestCodingAgent: - """Tests for CodingAgent.""" - - def test_agent_creation(self): - """Test basic agent creation.""" - agent = CodingAgent( - name="test_coding_agent", - description="A test coding agent", - ) - - assert agent.name == "test_coding_agent" - assert agent.description == "A test coding agent" - assert agent.max_iterations == 10 - assert agent.error_retry_attempts == 2 - - def test_agent_with_custom_config(self): - """Test agent with custom configuration.""" - agent = CodingAgent( - name="custom_agent", - model="gemini-2.0-flash", - max_iterations=5, - error_retry_attempts=3, - stateful=True, - ) - - assert agent.name == "custom_agent" - assert agent.model == "gemini-2.0-flash" - assert agent.max_iterations == 5 - assert agent.error_retry_attempts == 3 - assert agent.stateful is True - - def test_extract_code_block_tool_code(self): - """Test code extraction from tool_code blocks.""" - agent = CodingAgent(name="test") - - response = """Here's some code: + """Tests for CodingAgent.""" + + def test_agent_creation(self): + """Test basic agent creation.""" + agent = CodingAgent( + name="test_coding_agent", + description="A test coding agent", + ) + + assert agent.name == "test_coding_agent" + assert agent.description == "A test coding agent" + assert agent.max_iterations == 10 + assert agent.error_retry_attempts == 2 + + def test_agent_with_custom_config(self): + """Test agent with custom configuration.""" + agent = CodingAgent( + name="custom_agent", + model="gemini-2.0-flash", + max_iterations=5, + error_retry_attempts=3, + stateful=True, + ) + + assert agent.name == "custom_agent" + assert agent.model == "gemini-2.0-flash" + assert agent.max_iterations == 5 + assert agent.error_retry_attempts == 3 + assert agent.stateful is True + + def test_extract_code_block_tool_code(self): + """Test code extraction from tool_code blocks.""" + agent = CodingAgent(name="test") + + response = """Here's some code: ```tool_code result = search(query="test") print(result) ``` That should work.""" - code = agent._extract_code_block(response) - assert code == 'result = search(query="test")\nprint(result)' + code = agent._extract_code_block(response) + assert code == 'result = search(query="test")\nprint(result)' - def test_extract_code_block_python(self): - """Test code extraction from python blocks.""" - agent = CodingAgent(name="test") + def test_extract_code_block_python(self): + """Test code extraction from python blocks.""" + agent = CodingAgent(name="test") - response = """Here's some code: + response = """Here's some code: ```python x = 1 + 2 print(x) ``` Done.""" - code = agent._extract_code_block(response) - assert code == "x = 1 + 2\nprint(x)" + code = agent._extract_code_block(response) + assert code == "x = 1 + 2\nprint(x)" - def test_extract_code_block_prefers_tool_code(self): - """Test that tool_code blocks are preferred over python blocks.""" - agent = CodingAgent(name="test") + def test_extract_code_block_prefers_tool_code(self): + """Test that tool_code blocks are preferred over python blocks.""" + agent = CodingAgent(name="test") - response = """Code: + response = """Code: ```tool_code tool_result = tool_call() ``` @@ -211,99 +211,100 @@ def test_extract_code_block_prefers_tool_code(self): python_code = True ```""" - code = agent._extract_code_block(response) - assert code == "tool_result = tool_call()" + code = agent._extract_code_block(response) + assert code == "tool_result = tool_call()" - def test_extract_code_block_no_code(self): - """Test code extraction when no code block present.""" - agent = CodingAgent(name="test") + def test_extract_code_block_no_code(self): + """Test code extraction when no code block present.""" + agent = CodingAgent(name="test") - response = "This is just text without any code blocks." - code = agent._extract_code_block(response) - assert code is None + response = "This is just text without any code blocks." + code = agent._extract_code_block(response) + assert code is None - def test_build_error_feedback(self): - """Test error feedback formatting.""" - agent = CodingAgent(name="test") + def test_build_error_feedback(self): + """Test error feedback formatting.""" + agent = CodingAgent(name="test") - error = "NameError: name 'undefined_var' is not defined" - code = "print(undefined_var)" + error = "NameError: name 'undefined_var' is not defined" + code = "print(undefined_var)" - feedback = agent._build_error_feedback(error, code) + feedback = agent._build_error_feedback(error, code) - assert "NameError" in feedback - assert "undefined_var" in feedback - assert code in feedback - assert "fix the error" in feedback.lower() + assert "NameError" in feedback + assert "undefined_var" in feedback + assert code in feedback + assert "fix the error" in feedback.lower() - def test_default_model(self): - """Test that default model is used when not specified.""" - agent = CodingAgent(name="test") + def test_default_model(self): + """Test that default model is used when not specified.""" + agent = CodingAgent(name="test") - # canonical_model property should return a BaseLlm - model = agent.canonical_model - assert model is not None + # canonical_model property should return a BaseLlm + model = agent.canonical_model + assert model is not None - def test_cleanup(self): - """Test that cleanup releases resources.""" - agent = CodingAgent(name="test") - agent._resolved_tools = [MagicMock()] - agent._coding_executor = MagicMock() + def test_cleanup(self): + """Test that cleanup releases resources.""" + agent = CodingAgent(name="test") + agent._resolved_tools = [MagicMock()] + agent._coding_executor = MagicMock() - agent.cleanup() + agent.cleanup() - assert agent._resolved_tools is None - assert agent._coding_executor is None + assert agent._resolved_tools is None + assert agent._coding_executor is None class TestCodingAgentTools: - """Tests for CodingAgent tool handling.""" + """Tests for CodingAgent tool handling.""" + + def test_agent_with_function_tools(self): + """Test agent with function tools.""" - def test_agent_with_function_tools(self): - """Test agent with function tools.""" + def my_tool(query: str) -> dict: + """A test tool.""" + return {"result": query} - def my_tool(query: str) -> dict: - """A test tool.""" - return {"result": query} + agent = CodingAgent( + name="test", + tools=[my_tool], + ) - agent = CodingAgent( - name="test", - tools=[my_tool], - ) + assert len(agent.tools) == 1 - assert len(agent.tools) == 1 + def test_agent_with_base_tool(self): + """Test agent with BaseTool instances.""" - def test_agent_with_base_tool(self): - """Test agent with BaseTool instances.""" + class MockTool(BaseTool): - class MockTool(BaseTool): - def __init__(self): - super().__init__(name="mock_tool", description="A mock tool") + def __init__(self): + super().__init__(name="mock_tool", description="A mock tool") - async def run_async(self, *, args, tool_context): - return {"result": "mock"} + async def run_async(self, *, args, tool_context): + return {"result": "mock"} - tool = MockTool() - agent = CodingAgent( - name="test", - tools=[tool], - ) + tool = MockTool() + agent = CodingAgent( + name="test", + tools=[tool], + ) - assert len(agent.tools) == 1 + assert len(agent.tools) == 1 - @pytest.mark.asyncio - async def test_resolve_tools(self): - """Test tool resolution.""" + @pytest.mark.asyncio + async def test_resolve_tools(self): + """Test tool resolution.""" - def test_func(x: int) -> int: - """Test function.""" - return x * 2 + def test_func(x: int) -> int: + """Test function.""" + return x * 2 - agent = CodingAgent( - name="test", - tools=[test_func], - ) + agent = CodingAgent( + name="test", + tools=[test_func], + ) - tools = await agent._resolve_tools() - assert len(tools) == 1 - assert tools[0].name == "test_func" + tools = await agent._resolve_tools() + assert len(tools) == 1 + assert tools[0].name == "test_func" diff --git a/tests/unittests/code_executors/test_allowlist_validator.py b/tests/unittests/code_executors/test_allowlist_validator.py index 41d12c714f..58ae45771a 100644 --- a/tests/unittests/code_executors/test_allowlist_validator.py +++ b/tests/unittests/code_executors/test_allowlist_validator.py @@ -16,306 +16,305 @@ from __future__ import annotations -import pytest - from google.adk.code_executors.allowlist_validator import AllowlistValidator from google.adk.code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS from google.adk.code_executors.allowlist_validator import extract_imports from google.adk.code_executors.allowlist_validator import ImportValidationError from google.adk.code_executors.allowlist_validator import is_import_allowed from google.adk.code_executors.allowlist_validator import validate_imports +import pytest class TestExtractImports: - """Tests for extract_imports function.""" + """Tests for extract_imports function.""" - def test_simple_import(self): - """Test extracting simple import statements.""" - code = "import json" - imports = extract_imports(code) + def test_simple_import(self): + """Test extracting simple import statements.""" + code = "import json" + imports = extract_imports(code) - assert len(imports) == 1 - assert imports[0].module == "json" - assert imports[0].is_from_import is False + assert len(imports) == 1 + assert imports[0].module == "json" + assert imports[0].is_from_import is False - def test_multiple_imports(self): - """Test extracting multiple imports.""" - code = """ + def test_multiple_imports(self): + """Test extracting multiple imports.""" + code = """ import json import math import re """ - imports = extract_imports(code) + imports = extract_imports(code) - assert len(imports) == 3 - modules = {i.module for i in imports} - assert modules == {"json", "math", "re"} + assert len(imports) == 3 + modules = {i.module for i in imports} + assert modules == {"json", "math", "re"} - def test_from_import(self): - """Test extracting from imports.""" - code = "from collections import defaultdict" - imports = extract_imports(code) + def test_from_import(self): + """Test extracting from imports.""" + code = "from collections import defaultdict" + imports = extract_imports(code) - assert len(imports) == 1 - assert imports[0].module == "collections" - assert imports[0].names == ["defaultdict"] - assert imports[0].is_from_import is True + assert len(imports) == 1 + assert imports[0].module == "collections" + assert imports[0].names == ["defaultdict"] + assert imports[0].is_from_import is True - def test_from_import_multiple(self): - """Test extracting from imports with multiple names.""" - code = "from typing import List, Dict, Optional" - imports = extract_imports(code) + def test_from_import_multiple(self): + """Test extracting from imports with multiple names.""" + code = "from typing import List, Dict, Optional" + imports = extract_imports(code) - assert len(imports) == 3 - for imp in imports: - assert imp.module == "typing" - assert imp.is_from_import is True + assert len(imports) == 3 + for imp in imports: + assert imp.module == "typing" + assert imp.is_from_import is True - def test_import_with_alias(self): - """Test extracting imports with aliases.""" - code = "import numpy as np" - imports = extract_imports(code) + def test_import_with_alias(self): + """Test extracting imports with aliases.""" + code = "import numpy as np" + imports = extract_imports(code) - assert len(imports) == 1 - assert imports[0].module == "numpy" - assert imports[0].alias == "np" + assert len(imports) == 1 + assert imports[0].module == "numpy" + assert imports[0].alias == "np" - def test_submodule_import(self): - """Test extracting submodule imports.""" - code = "import os.path" - imports = extract_imports(code) + def test_submodule_import(self): + """Test extracting submodule imports.""" + code = "import os.path" + imports = extract_imports(code) - assert len(imports) == 1 - assert imports[0].module == "os.path" + assert len(imports) == 1 + assert imports[0].module == "os.path" - def test_from_submodule_import(self): - """Test extracting from submodule imports.""" - code = "from collections.abc import Mapping" - imports = extract_imports(code) + def test_from_submodule_import(self): + """Test extracting from submodule imports.""" + code = "from collections.abc import Mapping" + imports = extract_imports(code) - assert len(imports) == 1 - assert imports[0].module == "collections.abc" - assert imports[0].names == ["Mapping"] + assert len(imports) == 1 + assert imports[0].module == "collections.abc" + assert imports[0].names == ["Mapping"] - def test_syntax_error(self): - """Test handling of syntax errors.""" - code = "import json\nthis is not valid python" + def test_syntax_error(self): + """Test handling of syntax errors.""" + code = "import json\nthis is not valid python" - with pytest.raises(SyntaxError): - extract_imports(code) + with pytest.raises(SyntaxError): + extract_imports(code) - def test_no_imports(self): - """Test code with no imports.""" - code = "x = 1 + 2\nprint(x)" - imports = extract_imports(code) + def test_no_imports(self): + """Test code with no imports.""" + code = "x = 1 + 2\nprint(x)" + imports = extract_imports(code) - assert len(imports) == 0 + assert len(imports) == 0 class TestIsImportAllowed: - """Tests for is_import_allowed function.""" + """Tests for is_import_allowed function.""" - def test_direct_match(self): - """Test direct import match.""" - allowlist = frozenset({"json", "math"}) + def test_direct_match(self): + """Test direct import match.""" + allowlist = frozenset({"json", "math"}) - assert is_import_allowed("json", allowlist) is True - assert is_import_allowed("math", allowlist) is True - assert is_import_allowed("os", allowlist) is False + assert is_import_allowed("json", allowlist) is True + assert is_import_allowed("math", allowlist) is True + assert is_import_allowed("os", allowlist) is False - def test_wildcard_match(self): - """Test wildcard pattern matching.""" - allowlist = frozenset({"collections.*"}) + def test_wildcard_match(self): + """Test wildcard pattern matching.""" + allowlist = frozenset({"collections.*"}) - assert is_import_allowed("collections.abc", allowlist) is True - assert is_import_allowed("collections.defaultdict", allowlist) is True - assert is_import_allowed("itertools", allowlist) is False + assert is_import_allowed("collections.abc", allowlist) is True + assert is_import_allowed("collections.defaultdict", allowlist) is True + assert is_import_allowed("itertools", allowlist) is False - def test_deep_wildcard_match(self): - """Test wildcard matching for deep submodules.""" - allowlist = frozenset({"collections.*"}) + def test_deep_wildcard_match(self): + """Test wildcard matching for deep submodules.""" + allowlist = frozenset({"collections.*"}) - assert is_import_allowed("collections.abc.Mapping", allowlist) is True + assert is_import_allowed("collections.abc.Mapping", allowlist) is True - def test_exact_vs_wildcard(self): - """Test that exact matches work without wildcard.""" - allowlist = frozenset({"numpy"}) + def test_exact_vs_wildcard(self): + """Test that exact matches work without wildcard.""" + allowlist = frozenset({"numpy"}) - assert is_import_allowed("numpy", allowlist) is True - # Without wildcard, submodules are not allowed - assert is_import_allowed("numpy.array", allowlist) is False + assert is_import_allowed("numpy", allowlist) is True + # Without wildcard, submodules are not allowed + assert is_import_allowed("numpy.array", allowlist) is False - def test_multiple_patterns(self): - """Test multiple patterns in allowlist.""" - allowlist = frozenset({"json", "typing.*", "collections"}) + def test_multiple_patterns(self): + """Test multiple patterns in allowlist.""" + allowlist = frozenset({"json", "typing.*", "collections"}) - assert is_import_allowed("json", allowlist) is True - assert is_import_allowed("typing.List", allowlist) is True - assert is_import_allowed("collections", allowlist) is True - assert is_import_allowed("collections.abc", allowlist) is False + assert is_import_allowed("json", allowlist) is True + assert is_import_allowed("typing.List", allowlist) is True + assert is_import_allowed("collections", allowlist) is True + assert is_import_allowed("collections.abc", allowlist) is False class TestValidateImports: - """Tests for validate_imports function.""" + """Tests for validate_imports function.""" - def test_all_allowed(self): - """Test code with all imports allowed.""" - code = """ + def test_all_allowed(self): + """Test code with all imports allowed.""" + code = """ import json import math from typing import List """ - allowlist = frozenset({"json", "math", "typing.*"}) + allowlist = frozenset({"json", "math", "typing.*"}) - violations = validate_imports(code, allowlist) - assert len(violations) == 0 + violations = validate_imports(code, allowlist) + assert len(violations) == 0 - def test_some_violations(self): - """Test code with some unauthorized imports.""" - code = """ + def test_some_violations(self): + """Test code with some unauthorized imports.""" + code = """ import json import os import subprocess """ - allowlist = frozenset({"json"}) + allowlist = frozenset({"json"}) - violations = validate_imports(code, allowlist) - assert len(violations) == 2 - assert any("os" in v for v in violations) - assert any("subprocess" in v for v in violations) + violations = validate_imports(code, allowlist) + assert len(violations) == 2 + assert any("os" in v for v in violations) + assert any("subprocess" in v for v in violations) - def test_from_import_violations(self): - """Test from import violations.""" - code = "from os import system" - allowlist = frozenset({"json"}) + def test_from_import_violations(self): + """Test from import violations.""" + code = "from os import system" + allowlist = frozenset({"json"}) - violations = validate_imports(code, allowlist) - assert len(violations) == 1 - assert "os" in violations[0] + violations = validate_imports(code, allowlist) + assert len(violations) == 1 + assert "os" in violations[0] - def test_syntax_error_violation(self): - """Test that syntax errors are reported as violations.""" - code = "import json\n$$$invalid" - allowlist = frozenset({"json"}) + def test_syntax_error_violation(self): + """Test that syntax errors are reported as violations.""" + code = "import json\n$$$invalid" + allowlist = frozenset({"json"}) - violations = validate_imports(code, allowlist) - assert len(violations) == 1 - assert "Syntax error" in violations[0] + violations = validate_imports(code, allowlist) + assert len(violations) == 1 + assert "Syntax error" in violations[0] class TestImportValidationError: - """Tests for ImportValidationError exception.""" + """Tests for ImportValidationError exception.""" - def test_error_message(self): - """Test error message formatting.""" - violations = ["Unauthorized import: os", "Unauthorized import: subprocess"] - code = "import os\nimport subprocess" + def test_error_message(self): + """Test error message formatting.""" + violations = ["Unauthorized import: os", "Unauthorized import: subprocess"] + code = "import os\nimport subprocess" - error = ImportValidationError(violations, code) + error = ImportValidationError(violations, code) - assert "Import validation failed" in str(error) - assert "os" in str(error) - assert "subprocess" in str(error) + assert "Import validation failed" in str(error) + assert "os" in str(error) + assert "subprocess" in str(error) - def test_error_attributes(self): - """Test error attributes.""" - violations = ["violation1", "violation2"] - code = "some code" + def test_error_attributes(self): + """Test error attributes.""" + violations = ["violation1", "violation2"] + code = "some code" - error = ImportValidationError(violations, code) + error = ImportValidationError(violations, code) - assert error.violations == violations - assert error.code == code + assert error.violations == violations + assert error.code == code class TestAllowlistValidator: - """Tests for AllowlistValidator class.""" + """Tests for AllowlistValidator class.""" - def test_default_allowlist(self): - """Test validator with default allowlist.""" - validator = AllowlistValidator() + def test_default_allowlist(self): + """Test validator with default allowlist.""" + validator = AllowlistValidator() - # These should be in the default safe imports - assert validator.is_allowed("json") is True - assert validator.is_allowed("math") is True - assert validator.is_allowed("typing") is True + # These should be in the default safe imports + assert validator.is_allowed("json") is True + assert validator.is_allowed("math") is True + assert validator.is_allowed("typing") is True - # These should not be in the default safe imports - assert validator.is_allowed("os") is False - assert validator.is_allowed("subprocess") is False + # These should not be in the default safe imports + assert validator.is_allowed("os") is False + assert validator.is_allowed("subprocess") is False - def test_custom_allowlist(self): - """Test validator with custom allowlist.""" - custom = frozenset({"custom_module"}) - validator = AllowlistValidator(allowlist=custom) + def test_custom_allowlist(self): + """Test validator with custom allowlist.""" + custom = frozenset({"custom_module"}) + validator = AllowlistValidator(allowlist=custom) - assert validator.is_allowed("custom_module") is True - assert validator.is_allowed("json") is False + assert validator.is_allowed("custom_module") is True + assert validator.is_allowed("json") is False - def test_additional_imports(self): - """Test adding additional imports to default.""" - additional = frozenset({"custom_module", "another_module"}) - validator = AllowlistValidator(additional_imports=additional) + def test_additional_imports(self): + """Test adding additional imports to default.""" + additional = frozenset({"custom_module", "another_module"}) + validator = AllowlistValidator(additional_imports=additional) - # Should have both default and additional - assert validator.is_allowed("json") is True - assert validator.is_allowed("custom_module") is True - assert validator.is_allowed("another_module") is True + # Should have both default and additional + assert validator.is_allowed("json") is True + assert validator.is_allowed("custom_module") is True + assert validator.is_allowed("another_module") is True - def test_validate_method(self): - """Test validate method returns violations.""" - validator = AllowlistValidator(allowlist=frozenset({"json"})) + def test_validate_method(self): + """Test validate method returns violations.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) - violations = validator.validate("import json\nimport os") - assert len(violations) == 1 - assert "os" in violations[0] + violations = validator.validate("import json\nimport os") + assert len(violations) == 1 + assert "os" in violations[0] - def test_validate_strict_raises(self): - """Test validate_strict raises on violations.""" - validator = AllowlistValidator(allowlist=frozenset({"json"})) + def test_validate_strict_raises(self): + """Test validate_strict raises on violations.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) - with pytest.raises(ImportValidationError): - validator.validate_strict("import os") + with pytest.raises(ImportValidationError): + validator.validate_strict("import os") - def test_validate_strict_passes(self): - """Test validate_strict passes when no violations.""" - validator = AllowlistValidator(allowlist=frozenset({"json"})) + def test_validate_strict_passes(self): + """Test validate_strict passes when no violations.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) - # Should not raise - validator.validate_strict("import json") + # Should not raise + validator.validate_strict("import json") - def test_add_allowed_imports(self): - """Test adding imports after construction.""" - validator = AllowlistValidator(allowlist=frozenset({"json"})) + def test_add_allowed_imports(self): + """Test adding imports after construction.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) - assert validator.is_allowed("os") is False + assert validator.is_allowed("os") is False - validator.add_allowed_imports({"os"}) + validator.add_allowed_imports({"os"}) - assert validator.is_allowed("os") is True + assert validator.is_allowed("os") is True class TestDefaultSafeImports: - """Tests for the default safe imports list.""" - - def test_common_safe_imports_included(self): - """Test that common safe imports are in the default list.""" - assert "json" in DEFAULT_SAFE_IMPORTS - assert "math" in DEFAULT_SAFE_IMPORTS - assert "re" in DEFAULT_SAFE_IMPORTS - assert "datetime" in DEFAULT_SAFE_IMPORTS - assert "typing" in DEFAULT_SAFE_IMPORTS - assert "collections" in DEFAULT_SAFE_IMPORTS - - def test_dangerous_imports_not_included(self): - """Test that dangerous imports are not in the default list.""" - assert "os" not in DEFAULT_SAFE_IMPORTS - assert "subprocess" not in DEFAULT_SAFE_IMPORTS - assert "sys" not in DEFAULT_SAFE_IMPORTS - assert "socket" not in DEFAULT_SAFE_IMPORTS - assert "ctypes" not in DEFAULT_SAFE_IMPORTS - - def test_wildcard_patterns_included(self): - """Test that wildcard patterns are included.""" - assert "collections.*" in DEFAULT_SAFE_IMPORTS - assert "typing.*" in DEFAULT_SAFE_IMPORTS + """Tests for the default safe imports list.""" + + def test_common_safe_imports_included(self): + """Test that common safe imports are in the default list.""" + assert "json" in DEFAULT_SAFE_IMPORTS + assert "math" in DEFAULT_SAFE_IMPORTS + assert "re" in DEFAULT_SAFE_IMPORTS + assert "datetime" in DEFAULT_SAFE_IMPORTS + assert "typing" in DEFAULT_SAFE_IMPORTS + assert "collections" in DEFAULT_SAFE_IMPORTS + + def test_dangerous_imports_not_included(self): + """Test that dangerous imports are not in the default list.""" + assert "os" not in DEFAULT_SAFE_IMPORTS + assert "subprocess" not in DEFAULT_SAFE_IMPORTS + assert "sys" not in DEFAULT_SAFE_IMPORTS + assert "socket" not in DEFAULT_SAFE_IMPORTS + assert "ctypes" not in DEFAULT_SAFE_IMPORTS + + def test_wildcard_patterns_included(self): + """Test that wildcard patterns are included.""" + assert "collections.*" in DEFAULT_SAFE_IMPORTS + assert "typing.*" in DEFAULT_SAFE_IMPORTS diff --git a/tests/unittests/code_executors/test_tool_code_generator.py b/tests/unittests/code_executors/test_tool_code_generator.py index 39a676da42..5a067e47e2 100644 --- a/tests/unittests/code_executors/test_tool_code_generator.py +++ b/tests/unittests/code_executors/test_tool_code_generator.py @@ -16,305 +16,304 @@ from __future__ import annotations -import pytest from unittest.mock import MagicMock -from google.genai import types - from google.adk.code_executors.tool_code_generator import generate_full_code_with_stubs from google.adk.code_executors.tool_code_generator import generate_runtime_header from google.adk.code_executors.tool_code_generator import generate_system_prompt from google.adk.code_executors.tool_code_generator import generate_tool_stubs from google.adk.tools.base_tool import BaseTool +from google.genai import types +import pytest class MockTool(BaseTool): - """Mock tool for testing.""" - - def __init__( - self, - name: str = "mock_tool", - description: str = "A mock tool for testing", - params: dict = None, - ): - super().__init__(name=name, description=description) - self._params = params or {} - - def _get_declaration(self): - properties = {} - required = [] - - for param_name, param_info in self._params.items(): - properties[param_name] = { - "type": param_info.get("type", "string"), - "description": param_info.get("description", ""), - } - if param_info.get("required", False): - required.append(param_name) - - return types.FunctionDeclaration( - name=self.name, - description=self.description, - parameters=types.Schema( - type="object", - properties=properties, - required=required if required else None, - ), - ) - - async def run_async(self, *, args, tool_context): - return {"result": "mock"} + """Mock tool for testing.""" + + def __init__( + self, + name: str = "mock_tool", + description: str = "A mock tool for testing", + params: dict = None, + ): + super().__init__(name=name, description=description) + self._params = params or {} + + def _get_declaration(self): + properties = {} + required = [] + + for param_name, param_info in self._params.items(): + properties[param_name] = { + "type": param_info.get("type", "string"), + "description": param_info.get("description", ""), + } + if param_info.get("required", False): + required.append(param_name) + + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type="object", + properties=properties, + required=required if required else None, + ), + ) + + async def run_async(self, *, args, tool_context): + return {"result": "mock"} class TestGenerateRuntimeHeader: - """Tests for generate_runtime_header function.""" + """Tests for generate_runtime_header function.""" - def test_generates_valid_header(self): - """Test that the header contains required elements.""" - url = "http://localhost:8765" - header = generate_runtime_header(url) + def test_generates_valid_header(self): + """Test that the header contains required elements.""" + url = "http://localhost:8765" + header = generate_runtime_header(url) - # Should contain the URL - assert url in header + # Should contain the URL + assert url in header - # Should contain helper functions - assert "_call_adk_tool" in header - assert "final_answer" in header - assert "__get_tool_traces" in header + # Should contain helper functions + assert "_call_adk_tool" in header + assert "final_answer" in header + assert "__get_tool_traces" in header - # Should be valid Python syntax - compile(header, "", "exec") + # Should be valid Python syntax + compile(header, "", "exec") - def test_header_with_different_urls(self): - """Test header generation with different URLs.""" - urls = [ - "http://localhost:8765", - "http://host.docker.internal:9999", - "http://172.17.0.1:8765", - ] + def test_header_with_different_urls(self): + """Test header generation with different URLs.""" + urls = [ + "http://localhost:8765", + "http://host.docker.internal:9999", + "http://172.17.0.1:8765", + ] - for url in urls: - header = generate_runtime_header(url) - assert url in header + for url in urls: + header = generate_runtime_header(url) + assert url in header - def test_header_contains_trace_collection(self): - """Test that header contains trace collection code.""" - header = generate_runtime_header("http://localhost:8765") + def test_header_contains_trace_collection(self): + """Test that header contains trace collection code.""" + header = generate_runtime_header("http://localhost:8765") - assert "__ADK_TOOL_TRACES" in header - assert "__get_tool_traces" in header - assert "__clear_tool_traces" in header + assert "__ADK_TOOL_TRACES" in header + assert "__get_tool_traces" in header + assert "__clear_tool_traces" in header - def test_header_contains_final_answer_marker(self): - """Test that header contains final answer marker.""" - header = generate_runtime_header("http://localhost:8765") + def test_header_contains_final_answer_marker(self): + """Test that header contains final answer marker.""" + header = generate_runtime_header("http://localhost:8765") - assert "__FINAL_ANSWER__" in header - assert "final_answer" in header + assert "__FINAL_ANSWER__" in header + assert "final_answer" in header class TestGenerateToolStubs: - """Tests for generate_tool_stubs function.""" - - def test_generates_stub_for_tool(self): - """Test generating stub for a single tool.""" - tool = MockTool( - name="search", - description="Search for information", - params={ - "query": { - "type": "string", - "description": "The search query", - "required": True, - } - }, - ) + """Tests for generate_tool_stubs function.""" + + def test_generates_stub_for_tool(self): + """Test generating stub for a single tool.""" + tool = MockTool( + name="search", + description="Search for information", + params={ + "query": { + "type": "string", + "description": "The search query", + "required": True, + } + }, + ) - stubs = generate_tool_stubs([tool]) + stubs = generate_tool_stubs([tool]) - # Should contain function definition - assert "def search(" in stubs - assert "query" in stubs + # Should contain function definition + assert "def search(" in stubs + assert "query" in stubs - # Should be valid Python - compile(stubs, "", "exec") + # Should be valid Python + compile(stubs, "", "exec") - def test_generates_stubs_for_multiple_tools(self): - """Test generating stubs for multiple tools.""" - tools = [ - MockTool(name="tool1", description="First tool"), - MockTool(name="tool2", description="Second tool"), - MockTool(name="tool3", description="Third tool"), - ] + def test_generates_stubs_for_multiple_tools(self): + """Test generating stubs for multiple tools.""" + tools = [ + MockTool(name="tool1", description="First tool"), + MockTool(name="tool2", description="Second tool"), + MockTool(name="tool3", description="Third tool"), + ] - stubs = generate_tool_stubs(tools) + stubs = generate_tool_stubs(tools) - assert "def tool1(" in stubs - assert "def tool2(" in stubs - assert "def tool3(" in stubs + assert "def tool1(" in stubs + assert "def tool2(" in stubs + assert "def tool3(" in stubs - def test_stub_includes_docstring(self): - """Test that stubs include docstrings.""" - tool = MockTool( - name="my_tool", - description="A tool that does something useful", - ) + def test_stub_includes_docstring(self): + """Test that stubs include docstrings.""" + tool = MockTool( + name="my_tool", + description="A tool that does something useful", + ) - stubs = generate_tool_stubs([tool]) + stubs = generate_tool_stubs([tool]) - assert '"""' in stubs - assert "A tool that does something useful" in stubs + assert '"""' in stubs + assert "A tool that does something useful" in stubs - def test_stub_includes_type_hints(self): - """Test that stubs include type hints.""" - tool = MockTool( - name="typed_tool", - description="A typed tool", - params={ - "count": {"type": "integer", "description": "A count"}, - "name": {"type": "string", "description": "A name"}, - "enabled": {"type": "boolean", "description": "Is enabled"}, - }, - ) + def test_stub_includes_type_hints(self): + """Test that stubs include type hints.""" + tool = MockTool( + name="typed_tool", + description="A typed tool", + params={ + "count": {"type": "integer", "description": "A count"}, + "name": {"type": "string", "description": "A name"}, + "enabled": {"type": "boolean", "description": "Is enabled"}, + }, + ) - stubs = generate_tool_stubs([tool]) + stubs = generate_tool_stubs([tool]) - assert "int" in stubs - assert "str" in stubs - assert "bool" in stubs + assert "int" in stubs + assert "str" in stubs + assert "bool" in stubs - def test_empty_tool_list(self): - """Test generating stubs for empty tool list.""" - stubs = generate_tool_stubs([]) + def test_empty_tool_list(self): + """Test generating stubs for empty tool list.""" + stubs = generate_tool_stubs([]) - # Should still be valid Python - compile(stubs, "", "exec") + # Should still be valid Python + compile(stubs, "", "exec") class TestGenerateSystemPrompt: - """Tests for generate_system_prompt function.""" - - def test_generates_prompt_with_tools(self): - """Test generating system prompt with tools.""" - tools = [ - MockTool( - name="search", - description="Search the web", - params={"query": {"type": "string", "required": True}}, - ), - ] - - prompt = generate_system_prompt(tools) - - # Should contain tool documentation - assert "search" in prompt - assert "Search the web" in prompt - - # Should contain usage instructions - assert "tool_code" in prompt - assert "final_answer" in prompt - - def test_generates_prompt_with_custom_instruction(self): - """Test generating prompt with custom instruction.""" - tools = [] - custom = "Always be polite and helpful." - - prompt = generate_system_prompt(tools, custom_instruction=custom) - - assert custom in prompt - - def test_generates_prompt_with_examples(self): - """Test that prompt contains examples.""" - tools = [] - prompt = generate_system_prompt(tools) - - assert "Example" in prompt - assert "```tool_code" in prompt - - def test_generates_prompt_with_parameter_docs(self): - """Test that prompt includes parameter documentation.""" - tools = [ - MockTool( - name="get_weather", - description="Get weather for a city", - params={ - "city": { - "type": "string", - "description": "The city name", - "required": True, - }, - "units": { - "type": "string", - "description": "Temperature units", - "required": False, - }, + """Tests for generate_system_prompt function.""" + + def test_generates_prompt_with_tools(self): + """Test generating system prompt with tools.""" + tools = [ + MockTool( + name="search", + description="Search the web", + params={"query": {"type": "string", "required": True}}, + ), + ] + + prompt = generate_system_prompt(tools) + + # Should contain tool documentation + assert "search" in prompt + assert "Search the web" in prompt + + # Should contain usage instructions + assert "tool_code" in prompt + assert "final_answer" in prompt + + def test_generates_prompt_with_custom_instruction(self): + """Test generating prompt with custom instruction.""" + tools = [] + custom = "Always be polite and helpful." + + prompt = generate_system_prompt(tools, custom_instruction=custom) + + assert custom in prompt + + def test_generates_prompt_with_examples(self): + """Test that prompt contains examples.""" + tools = [] + prompt = generate_system_prompt(tools) + + assert "Example" in prompt + assert "```tool_code" in prompt + + def test_generates_prompt_with_parameter_docs(self): + """Test that prompt includes parameter documentation.""" + tools = [ + MockTool( + name="get_weather", + description="Get weather for a city", + params={ + "city": { + "type": "string", + "description": "The city name", + "required": True, + }, + "units": { + "type": "string", + "description": "Temperature units", + "required": False, }, - ), - ] + }, + ), + ] - prompt = generate_system_prompt(tools) + prompt = generate_system_prompt(tools) - assert "city" in prompt - assert "units" in prompt - assert "required" in prompt.lower() or "optional" in prompt.lower() + assert "city" in prompt + assert "units" in prompt + assert "required" in prompt.lower() or "optional" in prompt.lower() class TestGenerateFullCodeWithStubs: - """Tests for generate_full_code_with_stubs function.""" - - def test_generates_complete_code(self): - """Test generating complete executable code.""" - tools = [MockTool(name="my_tool", description="A tool")] - user_code = "result = my_tool()\nprint(result)" - - full_code = generate_full_code_with_stubs( - user_code=user_code, - tools=tools, - tool_server_url="http://localhost:8765", - ) - - # Should contain runtime header - assert "_call_adk_tool" in full_code - - # Should contain tool stub - assert "def my_tool(" in full_code - - # Should contain user code - assert user_code in full_code - - # Should be valid Python - compile(full_code, "", "exec") - - def test_generated_code_outputs_traces(self): - """Test that generated code outputs traces.""" - tools = [] - user_code = "x = 1" - - full_code = generate_full_code_with_stubs( - user_code=user_code, - tools=tools, - tool_server_url="http://localhost:8765", - ) - - assert "__TOOL_TRACE__" in full_code - - def test_generated_code_is_executable(self): - """Test that generated code can be compiled.""" - tools = [ - MockTool(name="tool_a", description="Tool A"), - MockTool(name="tool_b", description="Tool B"), - ] - user_code = """ + """Tests for generate_full_code_with_stubs function.""" + + def test_generates_complete_code(self): + """Test generating complete executable code.""" + tools = [MockTool(name="my_tool", description="A tool")] + user_code = "result = my_tool()\nprint(result)" + + full_code = generate_full_code_with_stubs( + user_code=user_code, + tools=tools, + tool_server_url="http://localhost:8765", + ) + + # Should contain runtime header + assert "_call_adk_tool" in full_code + + # Should contain tool stub + assert "def my_tool(" in full_code + + # Should contain user code + assert user_code in full_code + + # Should be valid Python + compile(full_code, "", "exec") + + def test_generated_code_outputs_traces(self): + """Test that generated code outputs traces.""" + tools = [] + user_code = "x = 1" + + full_code = generate_full_code_with_stubs( + user_code=user_code, + tools=tools, + tool_server_url="http://localhost:8765", + ) + + assert "__TOOL_TRACE__" in full_code + + def test_generated_code_is_executable(self): + """Test that generated code can be compiled.""" + tools = [ + MockTool(name="tool_a", description="Tool A"), + MockTool(name="tool_b", description="Tool B"), + ] + user_code = """ result_a = tool_a() result_b = tool_b() print(result_a, result_b) """ - full_code = generate_full_code_with_stubs( - user_code=user_code, - tools=tools, - tool_server_url="http://localhost:8765", - ) + full_code = generate_full_code_with_stubs( + user_code=user_code, + tools=tools, + tool_server_url="http://localhost:8765", + ) - # Should compile without errors - compile(full_code, "", "exec") + # Should compile without errors + compile(full_code, "", "exec") From c012e21eccdaa9a555fb1a18419a1b08e44dc82f Mon Sep 17 00:00:00 2001 From: Sudhendra Date: Sat, 17 Jan 2026 19:41:00 -0600 Subject: [PATCH 04/10] docs: Enhance CodingAgent issue with research context and technical depth - Add research foundation from CodeAct (ICML 2024) and DynaSaur (COLM 2025) - Reference HuggingFace smolagents as inspiration (25k+ GitHub stars) - Expand problem statement with context window bottleneck analysis - Add detailed alternatives considered section with rationale - Include future roadmap for stateful execution and alternative sandboxes - Add concrete user pain points and how CodingAgent solves them --- .github/CODING_AGENT_ISSUE.md | 332 ++++++++++++++++++++++++---------- 1 file changed, 238 insertions(+), 94 deletions(-) diff --git a/.github/CODING_AGENT_ISSUE.md b/.github/CODING_AGENT_ISSUE.md index 5f6372beca..05affec6dd 100644 --- a/.github/CODING_AGENT_ISSUE.md +++ b/.github/CODING_AGENT_ISSUE.md @@ -6,160 +6,304 @@ ## Title -`feat(agents): Add CodingAgent for code generation and sandboxed execution` +`feat(agents): Add CodingAgent - Agents that Think in Code with Sandboxed Execution` --- ## Is your feature request related to a problem? Please describe. -Currently, ADK agents can only interact with the world through pre-defined tools. While powerful, this approach has limitations: +### The Fundamental Limitation of Tool-Calling Agents -1. **Limited flexibility**: Users must anticipate all possible operations and create tools for each -2. **No computational capability**: Agents cannot perform complex calculations, data analysis, or create visualizations without custom tools -3. **No iteration**: Standard tool-calling doesn't easily support multi-step reasoning with intermediate computations -4. **Competitive gap**: Other platforms (OpenAI Code Interpreter, Anthropic's computer use) offer code execution capabilities +Current ADK agents operate through a **predefined tool paradigm**: the agent receives a task, selects from a fixed set of tools, and chains tool calls to accomplish goals. While effective for well-scoped problems, this architecture imposes fundamental constraints that limit agent capabilities in real-world scenarios: -**User pain points:** -- "I want my agent to analyze a CSV file and create a chart" - requires building custom tools -- "I need multi-step calculations with intermediate results" - awkward with standard tools -- "I want the agent to figure out HOW to solve a problem, not just call predefined functions" +#### 1. **Constrained Action Space** + +Tool-calling agents are restricted to a finite, pre-enumerated set of actions. As demonstrated in the [CodeAct paper (ICML 2024)](https://arxiv.org/abs/2402.01030), this creates a **combinatorial explosion problem**: complex tasks requiring composition of multiple operations become intractable when each combination must be explicitly defined as a tool. The paper shows that code-based actions achieve **up to 20% higher success rates** by allowing arbitrary composition. + +#### 2. **Context Window Bottleneck** + +Modern LLMs have context windows ranging from 8K to 200K tokens, but complex reasoning tasks can easily exceed these limits. Tool-calling agents must maintain entire conversation histories, tool schemas, and intermediate results in context. Code agents solve this by **offloading computation to the execution environment**—variables persist in the sandbox, not in the context window. This insight is central to systems like Claude Code and OpenCode that can work on entire codebases. + +#### 3. **Inability to Dynamically Create Actions** + +The [DynaSaur paper (COLM 2025)](https://arxiv.org/abs/2411.01747) identifies a critical flaw: "Existing LLM agent systems typically select actions from a fixed and predefined set at every step... this requires substantial human effort to enumerate and implement all possible actions, which is impractical in complex environments." Code agents can **generate novel actions on-the-fly**, adapting to unforeseen scenarios. + +#### 4. **Competitive Gap** + +The industry has converged on code-generating agents as the next evolution: +- **OpenAI Code Interpreter**: Code execution in sandbox +- **Anthropic Claude's Computer Use**: Code-based computer control +- **HuggingFace smolagents**: "Agents that think in code" ([25k+ GitHub stars](https://github.com/huggingface/smolagents)) +- **OpenCode, Claude Code, Cursor**: Production coding agents using this pattern + +ADK currently lacks this capability, forcing users to build complex workarounds or use competing frameworks. + +### Concrete User Pain Points + +``` +"I want my agent to analyze a 50MB CSV and create visualizations" +→ Current: Build custom tools for every possible analysis operation +→ With CodingAgent: Agent writes pandas/matplotlib code dynamically + +"I need multi-step calculations with intermediate results" +→ Current: Chain multiple tool calls, losing state between each +→ With CodingAgent: Variables persist in sandbox across iterations + +"I want the agent to figure out HOW to solve a problem" +→ Current: Limited to predefined solution paths +→ With CodingAgent: Agent generates arbitrary solution code + +"I need to work with long documents without hitting context limits" +→ Current: Complex chunking strategies, lost coherence +→ With CodingAgent: Load documents into sandbox, process incrementally +``` --- ## Describe the solution you'd like -A new experimental agent type called **CodingAgent** that: +### CodingAgent: A ReAct Agent that Thinks in Code -1. Receives a task from the user -2. Generates Python code to accomplish the task (using `tool_code` blocks) -3. Executes the code in a sandboxed Docker container -4. Processes results and either provides an answer or continues iterating (ReAct loop) -5. Can call ADK tools from within generated code via HTTP IPC +Inspired by [HuggingFace's smolagents](https://github.com/huggingface/smolagents) and grounded in recent research ([CodeAct](https://arxiv.org/abs/2402.01030), [DynaSaur](https://arxiv.org/abs/2411.01747)), CodingAgent is a new agent type that: + +1. **Generates Python code** as its action representation (in `tool_code` blocks) +2. **Executes code in sandboxed Docker containers** for security +3. **Calls ADK tools from generated code** via HTTP IPC +4. **Iterates using a ReAct loop** until producing a final answer +5. **Maintains state across iterations** in the execution environment ### Architecture ``` -┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ -│ User Query │────▶│ CodingAgent │────▶│ Docker Container│ -│ │ │ (Gemini LLM) │ │ (Python 3.11) │ -└─────────────────┘ └──────────────────┘ └─────────────────┘ - │ │ - │ │ Executes - ▼ │ generated code - ┌──────────────┐ │ - │ Tool Server │◀────────────────┘ - │ (HTTP IPC) │ Tool calls via HTTP +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────────────┐ +│ User Query │────▶│ CodingAgent │────▶│ Docker Container │ +│ │ │ (Gemini LLM) │ │ (Python 3.11) │ +└─────────────────┘ └──────────────────┘ │ │ + │ │ • pandas, numpy │ + │ │ • matplotlib, seaborn │ + │ ReAct Loop │ • Any pip package │ + │ │ • Persistent state │ + ▼ └───────────┬─────────────┘ + ┌──────────────┐ │ + │ Tool Server │◀───────────────────┘ + │ (HTTP IPC) │ Tool calls via HTTP POST + │ Port 8765 │ (fetch_url, save_chart, etc.) └──────────────┘ ``` +### Why Code Actions Are Superior + +From the CodeAct paper: +> "CodeAct can execute code actions and dynamically revise prior actions or emit new actions upon new observations through multi-turn interactions... CodeAct outperforms widely used alternatives (up to 20% higher success rate)." + +From DynaSaur: +> "The agent interacts with its environment by generating and executing programs written in a general-purpose programming language. Moreover, generated actions are accumulated over time for future reuse." + +Code provides: +- **Composability**: Combine operations arbitrarily (`for url in urls: data.append(fetch(url))`) +- **State persistence**: Variables survive across iterations +- **Dynamic tool creation**: Write new functions as needed +- **Error handling**: Try/except, retries, fallbacks in code +- **Computational offloading**: Process data in sandbox, not context + ### API Design ```python from google.adk.agents import CodingAgent from google.adk.code_executors import ContainerCodeExecutor +from google.adk.code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS def fetch_data(url: str) -> dict: - """Fetch data from a URL.""" + """Fetch data from a URL - available to generated code.""" + # Implementation... + +def save_chart(image_data: str, filename: str) -> dict: + """Save chart to host filesystem - bridges container to host.""" # Implementation... root_agent = CodingAgent( name="data_analyst", model="gemini-2.5-flash", instruction="You are a data analyst. Analyze data and provide insights.", - tools=[fetch_data], # Tools available to generated code + tools=[fetch_data, save_chart], # Tools callable from generated code code_executor=ContainerCodeExecutor(image="python:3.11-slim"), - authorized_imports=DEFAULT_SAFE_IMPORTS | {"pandas", "matplotlib"}, + authorized_imports=DEFAULT_SAFE_IMPORTS | {"pandas", "matplotlib", "numpy"}, max_iterations=10, error_retry_attempts=2, + stateful=False, # Future: True for persistent state across turns ) ``` ### Key Components -| Component | Description | -|-----------|-------------| -| CodingAgent | Main agent class with ReAct loop | -| CodingAgentCodeExecutor | Wrapper that injects tool stubs into code | -| ToolCodeGenerator | Generates Python function stubs for tools | -| ToolExecutionServer | HTTP server for tool IPC from container | -| AllowlistValidator | Import security validation | +| Component | Purpose | +|-----------|---------| +| **CodingAgent** | ReAct loop orchestrator, code extraction, LLM interaction | +| **CodingAgentCodeExecutor** | Wraps underlying executor, injects tool stubs | +| **ToolCodeGenerator** | Generates Python function stubs for ADK tools | +| **ToolExecutionServer** | HTTP server enabling tool calls from container | +| **AllowlistValidator** | Security: validates imports against allowlist | -### Security Features +### Security Model 1. **Sandboxed execution**: All code runs in isolated Docker containers -2. **Import allowlisting**: Only authorized imports are permitted (configurable) -3. **Tool isolation**: Tools execute on host via HTTP, not in container -4. **No filesystem access**: Container has no access to host filesystem +2. **Import allowlisting**: Only explicitly authorized imports permitted +3. **Tool isolation**: Tools execute on host, not in container +4. **No filesystem access**: Container cannot access host filesystem +5. **Network isolation**: Container only reaches tool server +6. **Configurable**: Users control exactly what's permitted --- ## Describe alternatives you've considered -### Alternative 1: Extend LlmAgent with code execution -- **Pros**: Simpler architecture, reuses existing agent -- **Cons**: Conflates two distinct patterns, harder to configure +### Alternative 1: Extend LlmAgent with Code Execution Tool -### Alternative 2: Code execution as a tool only -- **Pros**: Minimal changes, fits existing model -- **Cons**: No ReAct loop, no iteration, limited capability +Add code execution as just another tool that LlmAgent can call. -### Alternative 3: Use external code execution service -- **Pros**: Offloads security concerns -- **Cons**: Adds external dependency, latency, cost +**Pros:** +- Minimal API changes +- Reuses existing agent infrastructure -**Chosen approach**: Dedicated CodingAgent provides cleanest separation of concerns, explicit configuration, and full control over the execution environment. +**Cons:** +- No ReAct loop for code iteration +- Tool-calling overhead for every code snippet +- Mixes paradigms (tool-calling agent calling code execution tool) +- No state persistence between code executions +- Doesn't capture the "thinking in code" pattern ---- +**Why rejected:** This approach treats code execution as an afterthought rather than a first-class paradigm. The power of code agents comes from the tight integration of code generation, execution, and iteration—not from occasionally executing snippets. -## Additional context +### Alternative 2: External Code Execution Service -### Implementation Status +Integrate with external services like E2B, Modal, or Blaxel for code execution. + +**Pros:** +- Offloads security concerns to specialized providers +- Potentially more scalable +- No Docker dependency for users + +**Cons:** +- External dependency and potential costs +- Latency for remote execution +- Less control over execution environment +- Requires API keys and network access +- Not self-contained + +**Why rejected:** While external services are valuable for production deployments, ADK should provide a self-contained solution that works out-of-the-box. Users can later swap ContainerCodeExecutor for cloud-based alternatives. + +### Alternative 3: Unsafe Local Execution + +Execute code directly in the host Python process. -I have a working implementation ready for PR submission: - -**New files (~2,500 lines of production code):** -- `src/google/adk/agents/coding_agent.py` - Main agent class -- `src/google/adk/agents/coding_agent_config.py` - Configuration -- `src/google/adk/code_executors/coding_agent_code_executor.py` - Executor wrapper -- `src/google/adk/code_executors/tool_code_generator.py` - Code generation -- `src/google/adk/code_executors/tool_execution_server.py` - HTTP IPC server -- `src/google/adk/code_executors/allowlist_validator.py` - Security validation - -**Sample agent:** -- `contributing/samples/coding_agent/` - Data Analysis Agent demo - -**Unit tests (~950 lines):** -- `tests/unittests/agents/test_coding_agent.py` -- `tests/unittests/code_executors/test_allowlist_validator.py` -- `tests/unittests/code_executors/test_tool_code_generator.py` - -### Tested Scenarios - -| Test | Status | -|------|--------| -| Basic math queries | ✅ Passed | -| Data analysis with pandas | ✅ Passed | -| Visualization with matplotlib | ✅ Passed | -| Multi-step analysis | ✅ Passed | -| Tool calling via HTTP IPC | ✅ Passed | -| Chart saving to host system | ✅ Passed | -| Error handling and retries | ✅ Passed | - -### Related Work -- OpenAI Code Interpreter -- Anthropic Computer Use -- Google AI Studio code execution - -### Future Enhancements (out of scope for initial PR) -- Stateful execution (persist variables across turns) -- Custom container images with pre-installed packages -- Integration with VertexAI code execution -- Support for additional languages +**Pros:** +- Simplest implementation +- Fastest execution +- No Docker dependency + +**Cons:** +- **Critical security risk**: Arbitrary code execution +- Cannot safely use in production +- No isolation between agent code and host system + +**Why rejected:** Security is non-negotiable for an agent framework. Even with import restrictions, local execution opens attack vectors through creative code generation. + +### Chosen Approach: Dedicated CodingAgent with Docker Sandboxing + +A purpose-built agent class that: +- Makes code generation a first-class paradigm +- Provides secure sandboxed execution via Docker +- Enables tool access through HTTP IPC +- Supports future extensions (stateful execution, alternative sandboxes) + +This approach aligns with the architecture of smolagents while integrating cleanly with ADK's existing infrastructure. --- -## Labels to add +## Additional context + +### Implementation Status + +I have a **complete, tested implementation** ready for PR submission. + +**Production Code (~2,500 lines):** +| File | Lines | Description | +|------|-------|-------------| +| `src/google/adk/agents/coding_agent.py` | ~610 | Main agent class with ReAct loop | +| `src/google/adk/agents/coding_agent_config.py` | ~225 | Pydantic configuration | +| `src/google/adk/code_executors/coding_agent_code_executor.py` | ~505 | Executor wrapper with tool injection | +| `src/google/adk/code_executors/tool_code_generator.py` | ~475 | Python stub generation for tools | +| `src/google/adk/code_executors/tool_execution_server.py` | ~365 | HTTP IPC server for tool calls | +| `src/google/adk/code_executors/allowlist_validator.py` | ~355 | Import security validation | + +**Sample Agent:** +| File | Description | +|------|-------------| +| `contributing/samples/coding_agent/agent.py` | Data Analysis Agent with 5 tools | +| `contributing/samples/coding_agent/README.md` | Comprehensive documentation | + +**Unit Tests (~950 lines):** +| File | Tests | +|------|-------| +| `tests/unittests/agents/test_coding_agent.py` | Agent creation, code extraction, error handling | +| `tests/unittests/code_executors/test_allowlist_validator.py` | Import validation, patterns | +| `tests/unittests/code_executors/test_tool_code_generator.py` | Stub generation, system prompts | + +### E2E Test Results + +| Test Scenario | Status | Notes | +|--------------|--------|-------| +| Basic math ("What is 25 * 17?") | ✅ Passed | Generates code, executes, returns 425 | +| Data analysis (Titanic survival rate) | ✅ Passed | Fetches CSV, uses pandas, returns 38.38% | +| Visualization (bar chart by class) | ✅ Passed | Creates matplotlib chart, saves to host | +| Multi-step analysis | ✅ Passed | Stats → visualization → insights in one query | +| Tool calling via HTTP IPC | ✅ Passed | fetch_url, save_chart work correctly | +| Error handling (pip warnings) | ✅ Passed | Distinguishes warnings from real errors | + +### Research Foundation + +This implementation is grounded in peer-reviewed research: + +1. **CodeAct (ICML 2024)**: [arXiv:2402.01030](https://arxiv.org/abs/2402.01030) + - "Executable Code Actions Elicit Better LLM Agents" + - Shows 20% improvement over JSON/text tool calling + - Introduces the `tool_code` action format we adopt + +2. **DynaSaur (COLM 2025)**: [arXiv:2411.01747](https://arxiv.org/abs/2411.01747) + - "Large Language Agents Beyond Predefined Actions" + - Demonstrates value of dynamically generated actions + - Shows agents can create and accumulate actions over time + +3. **smolagents (HuggingFace)**: [github.com/huggingface/smolagents](https://github.com/huggingface/smolagents) + - "Agents that think in code" - 25k+ stars + - Production-proven architecture + - Supports multiple sandbox backends (E2B, Modal, Docker) + +### Future Roadmap (Out of Scope for Initial PR) + +| Feature | Description | Priority | +|---------|-------------|----------| +| **Stateful execution** | Persist variables across conversation turns | High | +| **Alternative sandboxes** | E2B, Modal, Pyodide+Deno WebAssembly | High | +| **Custom container images** | Pre-installed packages for faster execution | Medium | +| **Jupyter integration** | Execute in Jupyter kernels | Medium | +| **Multi-agent orchestration** | CodingAgent as sub-agent in hierarchies | Medium | +| **Streaming output** | Stream stdout/stderr during execution | Low | +| **Additional languages** | JavaScript/TypeScript support | Low | + +### Enabling New Use Cases + +CodingAgent unlocks capabilities previously impossible or impractical in ADK: + +1. **AI Data Scientists**: Analyze datasets, create visualizations, generate reports +2. **Code Review Agents**: Read codebases, run analysis, suggest improvements +3. **Automation Agents**: Generate scripts to accomplish arbitrary tasks +4. **Research Assistants**: Process papers, extract data, create summaries +5. **Sub-agent Architecture**: Build systems like Claude Code with specialized sub-agents + +### Labels to Add - `enhancement` - `agents` - `new-feature` +- `experimental` From 233fbc4ce909b2c79d46195d44122fe603532262 Mon Sep 17 00:00:00 2001 From: Sudhendra Date: Sat, 17 Jan 2026 19:47:17 -0600 Subject: [PATCH 05/10] docs: Condense CodingAgent issue and align with RLM paper context - Refocus motivation on arXiv:2512.24601 long-context external environment framing - Keep smolagents as primary inspiration for code-thinking agents - Tighten solution, alternatives, and future directions per ADK template --- .github/CODING_AGENT_ISSUE.md | 327 ++++++---------------------------- 1 file changed, 56 insertions(+), 271 deletions(-) diff --git a/.github/CODING_AGENT_ISSUE.md b/.github/CODING_AGENT_ISSUE.md index 05affec6dd..db39b23495 100644 --- a/.github/CODING_AGENT_ISSUE.md +++ b/.github/CODING_AGENT_ISSUE.md @@ -1,309 +1,94 @@ # GitHub Issue: CodingAgent Feature Request -**Use this content to create an issue at: https://github.com/google/adk-python/issues/new?template=feature_request.md** +Use this content to create an issue at: +https://github.com/google/adk-python/issues/new?template=feature_request.md --- ## Title -`feat(agents): Add CodingAgent - Agents that Think in Code with Sandboxed Execution` +feat(agents): Add CodingAgent (agents that think in code) --- ## Is your feature request related to a problem? Please describe. -### The Fundamental Limitation of Tool-Calling Agents +ADK’s current default agent interaction pattern is “tool selection from a fixed action set”. This is powerful, but it breaks down for two increasingly common workloads: -Current ADK agents operate through a **predefined tool paradigm**: the agent receives a task, selects from a fixed set of tools, and chains tool calls to accomplish goals. While effective for well-scoped problems, this architecture imposes fundamental constraints that limit agent capabilities in real-world scenarios: +1) Long-context work beyond model context windows +- Many real tasks require operating over very large corpora: codebases, logs, datasets, multi-file configs, or long documents. +- If the agent must keep the relevant source text and intermediate results inside the LLM context, it becomes context-window bound and expensive. +- Recent work such as “Recursive Language Models” (arXiv:2512.24601) proposes treating long prompts as an external environment and letting the model programmatically examine/decompose/recursively process snippets. This suggests a practical direction for agents: move heavy inspection, decomposition, and intermediate state out of the prompt and into an execution environment. + - https://arxiv.org/abs/2512.24601 -#### 1. **Constrained Action Space** +2) Expressiveness and composability limits of pure tool-calling +- Tool-calling assumes we can enumerate actions up-front. In open-ended tasks, the agent needs to compose multiple operations, iterate, cache intermediate artifacts, and implement “one-off” transformations without requiring new bespoke tools each time. +- A code-based action space lets the agent compose operations naturally (loops, conditionals, helper functions), which reduces the need for an explosion of tools. -Tool-calling agents are restricted to a finite, pre-enumerated set of actions. As demonstrated in the [CodeAct paper (ICML 2024)](https://arxiv.org/abs/2402.01030), this creates a **combinatorial explosion problem**: complex tasks requiring composition of multiple operations become intractable when each combination must be explicitly defined as a tool. The paper shows that code-based actions achieve **up to 20% higher success rates** by allowing arbitrary composition. +3) Developer experience gap for building “coding agents” and sub-agent architectures +- Users increasingly want agent systems like Claude Code / OpenCode: multi-step coding workflows with sub-agents (planner, tester, refactorer, etc.) and strong “think in code” execution. +- ADK has strong orchestration primitives; adding a first-class code-executing agent unlocks building these systems within ADK while keeping sandboxing and tool integration. -#### 2. **Context Window Bottleneck** - -Modern LLMs have context windows ranging from 8K to 200K tokens, but complex reasoning tasks can easily exceed these limits. Tool-calling agents must maintain entire conversation histories, tool schemas, and intermediate results in context. Code agents solve this by **offloading computation to the execution environment**—variables persist in the sandbox, not in the context window. This insight is central to systems like Claude Code and OpenCode that can work on entire codebases. - -#### 3. **Inability to Dynamically Create Actions** - -The [DynaSaur paper (COLM 2025)](https://arxiv.org/abs/2411.01747) identifies a critical flaw: "Existing LLM agent systems typically select actions from a fixed and predefined set at every step... this requires substantial human effort to enumerate and implement all possible actions, which is impractical in complex environments." Code agents can **generate novel actions on-the-fly**, adapting to unforeseen scenarios. - -#### 4. **Competitive Gap** - -The industry has converged on code-generating agents as the next evolution: -- **OpenAI Code Interpreter**: Code execution in sandbox -- **Anthropic Claude's Computer Use**: Code-based computer control -- **HuggingFace smolagents**: "Agents that think in code" ([25k+ GitHub stars](https://github.com/huggingface/smolagents)) -- **OpenCode, Claude Code, Cursor**: Production coding agents using this pattern - -ADK currently lacks this capability, forcing users to build complex workarounds or use competing frameworks. - -### Concrete User Pain Points - -``` -"I want my agent to analyze a 50MB CSV and create visualizations" -→ Current: Build custom tools for every possible analysis operation -→ With CodingAgent: Agent writes pandas/matplotlib code dynamically - -"I need multi-step calculations with intermediate results" -→ Current: Chain multiple tool calls, losing state between each -→ With CodingAgent: Variables persist in sandbox across iterations - -"I want the agent to figure out HOW to solve a problem" -→ Current: Limited to predefined solution paths -→ With CodingAgent: Agent generates arbitrary solution code - -"I need to work with long documents without hitting context limits" -→ Current: Complex chunking strategies, lost coherence -→ With CodingAgent: Load documents into sandbox, process incrementally -``` +Related inspiration: HuggingFace “smolagents” positions CodeAgent as a first-class concept (“agents that think in code”) and supports sandbox backends (Docker, etc.). +- https://github.com/huggingface/smolagents --- -## Describe the solution you'd like - -### CodingAgent: A ReAct Agent that Thinks in Code - -Inspired by [HuggingFace's smolagents](https://github.com/huggingface/smolagents) and grounded in recent research ([CodeAct](https://arxiv.org/abs/2402.01030), [DynaSaur](https://arxiv.org/abs/2411.01747)), CodingAgent is a new agent type that: - -1. **Generates Python code** as its action representation (in `tool_code` blocks) -2. **Executes code in sandboxed Docker containers** for security -3. **Calls ADK tools from generated code** via HTTP IPC -4. **Iterates using a ReAct loop** until producing a final answer -5. **Maintains state across iterations** in the execution environment +## Describe the solution you’d like -### Architecture +Add a new experimental agent type: CodingAgent. -``` -┌─────────────────┐ ┌──────────────────┐ ┌─────────────────────────┐ -│ User Query │────▶│ CodingAgent │────▶│ Docker Container │ -│ │ │ (Gemini LLM) │ │ (Python 3.11) │ -└─────────────────┘ └──────────────────┘ │ │ - │ │ • pandas, numpy │ - │ │ • matplotlib, seaborn │ - │ ReAct Loop │ • Any pip package │ - │ │ • Persistent state │ - ▼ └───────────┬─────────────┘ - ┌──────────────┐ │ - │ Tool Server │◀───────────────────┘ - │ (HTTP IPC) │ Tool calls via HTTP POST - │ Port 8765 │ (fetch_url, save_chart, etc.) - └──────────────┘ -``` +CodingAgent should: +- Generate Python code as the primary action representation (in `tool_code` blocks). +- Execute that code in a sandboxed environment (Docker-based initially). +- Allow generated code to call ADK tools safely via an IPC bridge (e.g., HTTP) rather than exposing the host runtime directly. +- Support iterative execution (ReAct-style loop): generate → run → observe stdout/tool results → refine → final answer. -### Why Code Actions Are Superior +Why this solves the problem +- Long-context: aligns with the “external environment” framing in arXiv:2512.24601 by enabling the agent to iteratively inspect, decompose, and process large inputs using code and persisted artifacts, instead of forcing all content into the model context. +- Composability: code enables arbitrary composition (loops, conditionals, helper functions) without requiring every combination to be implemented as a first-class tool. +- Coding-agent architectures: makes it straightforward to build higher-level workflows and multi-agent hierarchies where sub-agents can generate/run code for specialized tasks. -From the CodeAct paper: -> "CodeAct can execute code actions and dynamically revise prior actions or emit new actions upon new observations through multi-turn interactions... CodeAct outperforms widely used alternatives (up to 20% higher success rate)." +High-level architecture -From DynaSaur: -> "The agent interacts with its environment by generating and executing programs written in a general-purpose programming language. Moreover, generated actions are accumulated over time for future reuse." +User → CodingAgent (LLM) → sandbox executor (Docker Python) + ↘ tool IPC server on host ↙ -Code provides: -- **Composability**: Combine operations arbitrarily (`for url in urls: data.append(fetch(url))`) -- **State persistence**: Variables survive across iterations -- **Dynamic tool creation**: Write new functions as needed -- **Error handling**: Try/except, retries, fallbacks in code -- **Computational offloading**: Process data in sandbox, not context - -### API Design - -```python -from google.adk.agents import CodingAgent -from google.adk.code_executors import ContainerCodeExecutor -from google.adk.code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS - -def fetch_data(url: str) -> dict: - """Fetch data from a URL - available to generated code.""" - # Implementation... - -def save_chart(image_data: str, filename: str) -> dict: - """Save chart to host filesystem - bridges container to host.""" - # Implementation... - -root_agent = CodingAgent( - name="data_analyst", - model="gemini-2.5-flash", - instruction="You are a data analyst. Analyze data and provide insights.", - tools=[fetch_data, save_chart], # Tools callable from generated code - code_executor=ContainerCodeExecutor(image="python:3.11-slim"), - authorized_imports=DEFAULT_SAFE_IMPORTS | {"pandas", "matplotlib", "numpy"}, - max_iterations=10, - error_retry_attempts=2, - stateful=False, # Future: True for persistent state across turns -) -``` - -### Key Components - -| Component | Purpose | -|-----------|---------| -| **CodingAgent** | ReAct loop orchestrator, code extraction, LLM interaction | -| **CodingAgentCodeExecutor** | Wraps underlying executor, injects tool stubs | -| **ToolCodeGenerator** | Generates Python function stubs for ADK tools | -| **ToolExecutionServer** | HTTP server enabling tool calls from container | -| **AllowlistValidator** | Security: validates imports against allowlist | - -### Security Model - -1. **Sandboxed execution**: All code runs in isolated Docker containers -2. **Import allowlisting**: Only explicitly authorized imports permitted -3. **Tool isolation**: Tools execute on host, not in container -4. **No filesystem access**: Container cannot access host filesystem -5. **Network isolation**: Container only reaches tool server -6. **Configurable**: Users control exactly what's permitted +Proposed execution environments (progressive) +- v1: Docker Python sandbox (existing ContainerCodeExecutor integration) +- future: REPL / Jupyter-kernel style execution modes for interactive, stateful sessions (still sandboxed) --- -## Describe alternatives you've considered - -### Alternative 1: Extend LlmAgent with Code Execution Tool - -Add code execution as just another tool that LlmAgent can call. - -**Pros:** -- Minimal API changes -- Reuses existing agent infrastructure - -**Cons:** -- No ReAct loop for code iteration -- Tool-calling overhead for every code snippet -- Mixes paradigms (tool-calling agent calling code execution tool) -- No state persistence between code executions -- Doesn't capture the "thinking in code" pattern +## Describe alternatives you’ve considered -**Why rejected:** This approach treats code execution as an afterthought rather than a first-class paradigm. The power of code agents comes from the tight integration of code generation, execution, and iteration—not from occasionally executing snippets. +1) “Just add a code-execution tool” to existing agents +- Pros: minimal surface-area change. +- Cons: code execution becomes an occasional tool call rather than the agent’s primary action space; harder to support tight generate→execute→iterate loops and long-context strategies that rely on an external environment. -### Alternative 2: External Code Execution Service +2) Require users to write bespoke tools for every operation +- Pros: explicit and controlled. +- Cons: does not scale; real workflows need ad-hoc transformations and composition that explode the tool surface area. -Integrate with external services like E2B, Modal, or Blaxel for code execution. - -**Pros:** -- Offloads security concerns to specialized providers -- Potentially more scalable -- No Docker dependency for users - -**Cons:** -- External dependency and potential costs -- Latency for remote execution -- Less control over execution environment -- Requires API keys and network access -- Not self-contained - -**Why rejected:** While external services are valuable for production deployments, ADK should provide a self-contained solution that works out-of-the-box. Users can later swap ContainerCodeExecutor for cloud-based alternatives. - -### Alternative 3: Unsafe Local Execution - -Execute code directly in the host Python process. - -**Pros:** -- Simplest implementation -- Fastest execution -- No Docker dependency - -**Cons:** -- **Critical security risk**: Arbitrary code execution -- Cannot safely use in production -- No isolation between agent code and host system - -**Why rejected:** Security is non-negotiable for an agent framework. Even with import restrictions, local execution opens attack vectors through creative code generation. - -### Chosen Approach: Dedicated CodingAgent with Docker Sandboxing - -A purpose-built agent class that: -- Makes code generation a first-class paradigm -- Provides secure sandboxed execution via Docker -- Enables tool access through HTTP IPC -- Supports future extensions (stateful execution, alternative sandboxes) - -This approach aligns with the architecture of smolagents while integrating cleanly with ADK's existing infrastructure. +3) Run code on the host interpreter +- Pros: simplest. +- Cons: unacceptable security risk; sandboxing is required for a general-purpose code agent. --- ## Additional context -### Implementation Status - -I have a **complete, tested implementation** ready for PR submission. - -**Production Code (~2,500 lines):** -| File | Lines | Description | -|------|-------|-------------| -| `src/google/adk/agents/coding_agent.py` | ~610 | Main agent class with ReAct loop | -| `src/google/adk/agents/coding_agent_config.py` | ~225 | Pydantic configuration | -| `src/google/adk/code_executors/coding_agent_code_executor.py` | ~505 | Executor wrapper with tool injection | -| `src/google/adk/code_executors/tool_code_generator.py` | ~475 | Python stub generation for tools | -| `src/google/adk/code_executors/tool_execution_server.py` | ~365 | HTTP IPC server for tool calls | -| `src/google/adk/code_executors/allowlist_validator.py` | ~355 | Import security validation | - -**Sample Agent:** -| File | Description | -|------|-------------| -| `contributing/samples/coding_agent/agent.py` | Data Analysis Agent with 5 tools | -| `contributing/samples/coding_agent/README.md` | Comprehensive documentation | - -**Unit Tests (~950 lines):** -| File | Tests | -|------|-------| -| `tests/unittests/agents/test_coding_agent.py` | Agent creation, code extraction, error handling | -| `tests/unittests/code_executors/test_allowlist_validator.py` | Import validation, patterns | -| `tests/unittests/code_executors/test_tool_code_generator.py` | Stub generation, system prompts | - -### E2E Test Results - -| Test Scenario | Status | Notes | -|--------------|--------|-------| -| Basic math ("What is 25 * 17?") | ✅ Passed | Generates code, executes, returns 425 | -| Data analysis (Titanic survival rate) | ✅ Passed | Fetches CSV, uses pandas, returns 38.38% | -| Visualization (bar chart by class) | ✅ Passed | Creates matplotlib chart, saves to host | -| Multi-step analysis | ✅ Passed | Stats → visualization → insights in one query | -| Tool calling via HTTP IPC | ✅ Passed | fetch_url, save_chart work correctly | -| Error handling (pip warnings) | ✅ Passed | Distinguishes warnings from real errors | - -### Research Foundation - -This implementation is grounded in peer-reviewed research: - -1. **CodeAct (ICML 2024)**: [arXiv:2402.01030](https://arxiv.org/abs/2402.01030) - - "Executable Code Actions Elicit Better LLM Agents" - - Shows 20% improvement over JSON/text tool calling - - Introduces the `tool_code` action format we adopt - -2. **DynaSaur (COLM 2025)**: [arXiv:2411.01747](https://arxiv.org/abs/2411.01747) - - "Large Language Agents Beyond Predefined Actions" - - Demonstrates value of dynamically generated actions - - Shows agents can create and accumulate actions over time - -3. **smolagents (HuggingFace)**: [github.com/huggingface/smolagents](https://github.com/huggingface/smolagents) - - "Agents that think in code" - 25k+ stars - - Production-proven architecture - - Supports multiple sandbox backends (E2B, Modal, Docker) - -### Future Roadmap (Out of Scope for Initial PR) - -| Feature | Description | Priority | -|---------|-------------|----------| -| **Stateful execution** | Persist variables across conversation turns | High | -| **Alternative sandboxes** | E2B, Modal, Pyodide+Deno WebAssembly | High | -| **Custom container images** | Pre-installed packages for faster execution | Medium | -| **Jupyter integration** | Execute in Jupyter kernels | Medium | -| **Multi-agent orchestration** | CodingAgent as sub-agent in hierarchies | Medium | -| **Streaming output** | Stream stdout/stderr during execution | Low | -| **Additional languages** | JavaScript/TypeScript support | Low | - -### Enabling New Use Cases - -CodingAgent unlocks capabilities previously impossible or impractical in ADK: - -1. **AI Data Scientists**: Analyze datasets, create visualizations, generate reports -2. **Code Review Agents**: Read codebases, run analysis, suggest improvements -3. **Automation Agents**: Generate scripts to accomplish arbitrary tasks -4. **Research Assistants**: Process papers, extract data, create summaries -5. **Sub-agent Architecture**: Build systems like Claude Code with specialized sub-agents +Future directions enabled by CodingAgent +- Long-context scaffolds inspired by arXiv:2512.24601: treat large inputs (files, repo trees, logs) as an “environment” the agent queries/decomposes recursively using code, storing intermediate state outside the LLM context. +- Sub-agent coding workflows (Claude Code / OpenCode style): planner/tester/refactor sub-agents coordinated by ADK, each using code execution. +- Multiple sandbox backends (like smolagents): Docker initially, with optional future support for other sandboxes and interactive execution modes. -### Labels to Add +Links +- smolagents (inspiration): https://github.com/huggingface/smolagents +- Recursive Language Models (long-context framing): https://arxiv.org/abs/2512.24601 -- `enhancement` -- `agents` -- `new-feature` -- `experimental` +Labels to add +- enhancement +- agents +- new-feature +- experimental From 7d08efce312e0d4e20ddbafb8e1082d5c6a7a683 Mon Sep 17 00:00:00 2001 From: Sudhendra Date: Sat, 17 Jan 2026 23:41:29 -0600 Subject: [PATCH 06/10] tracing --- contributing/samples/coding_agent/agent.py | 478 ++++----- src/google/adk/agents/coding_agent.py | 1022 ++++++++++---------- 2 files changed, 790 insertions(+), 710 deletions(-) diff --git a/contributing/samples/coding_agent/agent.py b/contributing/samples/coding_agent/agent.py index 4596d95318..2691e73420 100644 --- a/contributing/samples/coding_agent/agent.py +++ b/contributing/samples/coding_agent/agent.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Data Analysis Agent using CodingAgent. +"""Conversational Data Analysis Agent using CodingAgent. -This sample demonstrates a CodingAgent configured as a data analyst that can: +This sample demonstrates a CodingAgent configured as a conversational data +analyst that can: +- Have multi-turn conversations about data analysis - Fetch datasets from URLs (CSV, JSON, text) - Analyze data using pandas - Create visualizations using matplotlib - Generate statistical summaries and insights +- Remember context from previous messages Prerequisites: - Docker must be installed and running @@ -28,10 +31,18 @@ adk run contributing/samples/coding_agent adk web contributing/samples -Example queries: -- "What is the survival rate on the Titanic?" -- "Create a bar chart showing survival rate by passenger class" -- "Analyze the iris dataset and create a scatter plot" +Example conversation: + User: "What datasets do you have available?" + Agent: "I have access to three sample datasets: titanic, iris, and tips..." + + User: "Let's look at the titanic data. What's the survival rate?" + Agent: "I'll analyze the Titanic dataset for you... The overall survival + rate was 38.4%. Would you like me to break this down by gender + or passenger class?" + + User: "Yes, show me by passenger class with a chart" + Agent: "Here's a bar chart showing survival rates by class... First class + passengers had the highest survival rate at 63%..." """ import base64 @@ -68,9 +79,7 @@ "Iris flower dataset. 150 samples of 3 species with sepal and petal" " measurements." ), - "columns": ( - "sepal_length, sepal_width, petal_length, petal_width, species" - ), + "columns": ("sepal_length, sepal_width, petal_length, petal_width, species"), }, "tips": { "url": ( @@ -86,90 +95,90 @@ def fetch_url(url: str) -> dict: - """Fetch content from a URL. - - Fetches data from the specified URL and returns the content along with - metadata. Supports CSV, JSON, and plain text content. - - Args: - url: The URL to fetch content from. - - Returns: - Dictionary containing: - - content: The fetched content as a string - - content_type: The MIME type of the content - - size: Size of the content in bytes - - url: The original URL - - success: Whether the fetch was successful - - error: Error message if fetch failed (only present on failure) - """ - try: - req = urllib.request.Request( - url, - headers={"User-Agent": "Mozilla/5.0 (compatible; ADK-DataAnalyst/1.0)"}, - ) - with urllib.request.urlopen(req, timeout=30) as response: - content = response.read().decode("utf-8") - content_type = response.headers.get("Content-Type", "text/plain") - return { - "content": content, - "content_type": content_type, - "size": len(content), - "url": url, - "success": True, - } - except urllib.error.URLError as e: - return { - "content": "", - "url": url, - "success": False, - "error": f"Failed to fetch URL: {str(e)}", - } - except Exception as e: - return { - "content": "", - "url": url, - "success": False, - "error": f"Unexpected error: {str(e)}", - } + """Fetch content from a URL. + + Fetches data from the specified URL and returns the content along with + metadata. Supports CSV, JSON, and plain text content. + + Args: + url: The URL to fetch content from. + + Returns: + Dictionary containing: + - content: The fetched content as a string + - content_type: The MIME type of the content + - size: Size of the content in bytes + - url: The original URL + - success: Whether the fetch was successful + - error: Error message if fetch failed (only present on failure) + """ + try: + req = urllib.request.Request( + url, + headers={"User-Agent": "Mozilla/5.0 (compatible; ADK-DataAnalyst/1.0)"}, + ) + with urllib.request.urlopen(req, timeout=30) as response: + content = response.read().decode("utf-8") + content_type = response.headers.get("Content-Type", "text/plain") + return { + "content": content, + "content_type": content_type, + "size": len(content), + "url": url, + "success": True, + } + except urllib.error.URLError as e: + return { + "content": "", + "url": url, + "success": False, + "error": f"Failed to fetch URL: {str(e)}", + } + except Exception as e: + return { + "content": "", + "url": url, + "success": False, + "error": f"Unexpected error: {str(e)}", + } def get_sample_datasets() -> dict: - """Get available sample datasets with their URLs and descriptions. + """Get available sample datasets with their URLs and descriptions. - Returns a dictionary of sample datasets that can be used for analysis. - Each dataset includes a URL, description, and column information. + Returns a dictionary of sample datasets that can be used for analysis. + Each dataset includes a URL, description, and column information. - Returns: - Dictionary with dataset names as keys, each containing: - - url: Direct URL to download the CSV file - - description: Brief description of the dataset - - columns: Comma-separated list of column names - """ - return SAMPLE_DATASETS + Returns: + Dictionary with dataset names as keys, each containing: + - url: Direct URL to download the CSV file + - description: Brief description of the dataset + - columns: Comma-separated list of column names + """ + return SAMPLE_DATASETS def get_current_time() -> dict: - """Get the current date and time. - - Returns: - Dictionary containing: - - timestamp: ISO format timestamp - - year, month, day: Date components - - hour, minute, second: Time components - - weekday: Name of the day of the week - """ - now = datetime.now() - return { - "timestamp": now.isoformat(), - "year": now.year, - "month": now.month, - "day": now.day, - "hour": now.hour, - "minute": now.minute, - "second": now.second, - "weekday": now.strftime("%A"), - } + """Get the current date and time. + + Returns: + Dictionary containing: + - timestamp: ISO format timestamp + - year, month, day: Date components + - hour, minute, second: Time components + - weekday: Name of the day of the week + """ + now = datetime.now() + return { + "timestamp": now.isoformat(), + "year": now.year, + "month": now.month, + "day": now.day, + "hour": now.hour, + "minute": now.minute, + "second": now.second, + "weekday": now.strftime("%A"), + } # Directory on host system to save charts @@ -177,141 +186,143 @@ def get_current_time() -> dict: def save_chart(image_data: str, filename: str) -> dict: - """Save a chart image to the host system. - - This tool saves base64-encoded image data to the host machine's filesystem, - making charts accessible outside the Docker container. - - To use this tool, first save your matplotlib figure to a bytes buffer, - then encode it as base64: - - Example: - import base64 - import io - import matplotlib.pyplot as plt - - # Create your plot - plt.figure() - plt.plot([1, 2, 3], [1, 4, 9]) - - # Save to buffer and encode - buf = io.BytesIO() - plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') - buf.seek(0) - image_data = base64.b64encode(buf.read()).decode('utf-8') - plt.close() - - # Save to host system - result = save_chart(image_data=image_data, filename="my_chart.png") - - Args: - image_data: Base64-encoded image data (PNG format recommended). - filename: Name for the saved file (e.g., "chart.png"). - - Returns: - Dictionary containing: - - success: Whether the save was successful - - filepath: Full path where the file was saved on the host - - size: Size of the saved file in bytes - - error: Error message if save failed (only present on failure) - """ - try: - # Ensure the output directory exists - os.makedirs(HOST_CHARTS_DIR, exist_ok=True) - - # Sanitize filename - safe_filename = os.path.basename(filename) - if not safe_filename: - safe_filename = "chart.png" - - filepath = os.path.join(HOST_CHARTS_DIR, safe_filename) - - # Decode and save - image_bytes = base64.b64decode(image_data) - with open(filepath, "wb") as f: - f.write(image_bytes) - - return { - "success": True, - "filepath": filepath, - "size": len(image_bytes), - "message": f"Chart saved to {filepath}", - } - except binascii.Error as e: - return { - "success": False, - "error": f"Invalid base64 data: {str(e)}", - } - except OSError as e: - return { - "success": False, - "error": f"Failed to save file: {str(e)}", - } - except Exception as e: - return { - "success": False, - "error": f"Unexpected error: {str(e)}", - } + """Save a chart image to the host system. + + This tool saves base64-encoded image data to the host machine's filesystem, + making charts accessible outside the Docker container. + + To use this tool, first save your matplotlib figure to a bytes buffer, + then encode it as base64: + + Example: + import base64 + import io + import matplotlib.pyplot as plt + + # Create your plot + plt.figure() + plt.plot([1, 2, 3], [1, 4, 9]) + + # Save to buffer and encode + buf = io.BytesIO() + plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') + buf.seek(0) + image_data = base64.b64encode(buf.read()).decode('utf-8') + plt.close() + + # Save to host system + result = save_chart(image_data=image_data, filename="my_chart.png") + + Args: + image_data: Base64-encoded image data (PNG format recommended). + filename: Name for the saved file (e.g., "chart.png"). + + Returns: + Dictionary containing: + - success: Whether the save was successful + - filepath: Full path where the file was saved on the host + - size: Size of the saved file in bytes + - error: Error message if save failed (only present on failure) + """ + try: + # Ensure the output directory exists + os.makedirs(HOST_CHARTS_DIR, exist_ok=True) + + # Sanitize filename + safe_filename = os.path.basename(filename) + if not safe_filename: + safe_filename = "chart.png" + + filepath = os.path.join(HOST_CHARTS_DIR, safe_filename) + + # Decode and save + image_bytes = base64.b64decode(image_data) + with open(filepath, "wb") as f: + f.write(image_bytes) + + return { + "success": True, + "filepath": filepath, + "size": len(image_bytes), + "message": f"Chart saved to {filepath}", + } + except binascii.Error as e: + return { + "success": False, + "error": f"Invalid base64 data: {str(e)}", + } + except OSError as e: + return { + "success": False, + "error": f"Failed to save file: {str(e)}", + } + except Exception as e: + return { + "success": False, + "error": f"Unexpected error: {str(e)}", + } def list_saved_charts() -> dict: - """List all charts saved on the host system. - - Returns: - Dictionary containing: - - success: Whether the operation was successful - - charts: List of saved chart filenames - - directory: The directory where charts are saved - - count: Number of charts found - """ - try: - if not os.path.exists(HOST_CHARTS_DIR): - return { - "success": True, - "charts": [], - "directory": HOST_CHARTS_DIR, - "count": 0, - } - - charts = [ - f - for f in os.listdir(HOST_CHARTS_DIR) - if f.lower().endswith((".png", ".jpg", ".jpeg", ".svg", ".pdf")) - ] - return { - "success": True, - "charts": charts, - "directory": HOST_CHARTS_DIR, - "count": len(charts), - } - except Exception as e: - return { - "success": False, - "error": f"Failed to list charts: {str(e)}", - } + """List all charts saved on the host system. + + Returns: + Dictionary containing: + - success: Whether the operation was successful + - charts: List of saved chart filenames + - directory: The directory where charts are saved + - count: Number of charts found + """ + try: + if not os.path.exists(HOST_CHARTS_DIR): + return { + "success": True, + "charts": [], + "directory": HOST_CHARTS_DIR, + "count": 0, + } + + charts = [ + f + for f in os.listdir(HOST_CHARTS_DIR) + if f.lower().endswith((".png", ".jpg", ".jpeg", ".svg", ".pdf")) + ] + return { + "success": True, + "charts": charts, + "directory": HOST_CHARTS_DIR, + "count": len(charts), + } + except Exception as e: + return { + "success": False, + "error": f"Failed to list charts: {str(e)}", + } # Additional imports allowed for data analysis -DATA_ANALYSIS_IMPORTS = frozenset({ - # Data analysis - "pandas", - "pandas.*", - "numpy", - "numpy.*", - # Visualization - "matplotlib", - "matplotlib.*", - "seaborn", - "seaborn.*", - # Data I/O - "csv", - "io", - "io.*", - # Encoding for chart saving - "base64", - # Subprocess for pip installs - "subprocess", -}) +DATA_ANALYSIS_IMPORTS = frozenset( + { + # Data analysis + "pandas", + "pandas.*", + "numpy", + "numpy.*", + # Visualization + "matplotlib", + "matplotlib.*", + "seaborn", + "seaborn.*", + # Data I/O + "csv", + "io", + "io.*", + # Encoding for chart saving + "base64", + # Subprocess for pip installs + "subprocess", + } +) # Create the Data Analysis Agent @@ -322,18 +333,30 @@ def list_saved_charts() -> dict: "and generates insights using Python code execution." ), model="gemini-2.5-flash", - instruction="""You are a data analyst. Analyze data, create visualizations, and provide insights. + instruction="""You are a friendly, conversational data analyst assistant. You help users analyze datasets, create visualizations, and generate insights using Python code execution. -IMPORTANT: First install required packages before using them: +## Your Personality +- Be conversational and engaging - ask clarifying questions when needed +- Explain your analysis in plain language that anyone can understand +- Offer suggestions for follow-up analyses or visualizations +- Remember context from previous messages in our conversation + +## When the user asks for analysis or visualization: + +1. First, install required packages (only needed once per session): ```tool_code import subprocess subprocess.run(["pip", "install", "-q", "pandas", "matplotlib", "seaborn", "numpy"], check=True) print("Packages installed successfully") ``` -Then use the available tools to fetch datasets. Write Python code to analyze data using pandas and create charts with matplotlib. +2. Use the available tools: + - `get_sample_datasets()` - See available sample datasets (titanic, iris, tips) + - `fetch_url(url)` - Fetch data from any URL + - `save_chart(image_data, filename)` - Save visualizations + - `list_saved_charts()` - See previously saved charts -CRITICAL: You MUST use the save_chart() tool to save charts - do NOT use plt.savefig() to a file path directly. The save_chart() tool transfers the chart to the host system. Here is the REQUIRED pattern: +3. When creating charts, ALWAYS use this pattern to save them: ```tool_code import base64 import io @@ -350,14 +373,23 @@ def list_saved_charts() -> dict: image_data = base64.b64encode(buf.read()).decode('utf-8') plt.close() -# Use the save_chart tool to save to host system -result = save_chart(image_data=image_data, filename="my_chart.png") +# Save to host system +result = save_chart(image_data=image_data, filename="descriptive_name.png") print(f"Chart saved: {result}") ``` -The chart will be saved to the host system at /tmp/adk_charts/. Always report this filepath in your final answer. +## Response Style +- After completing an analysis, summarize your findings in a conversational way +- Mention the chart location if you created one (charts are saved to /tmp/adk_charts/) +- Ask if the user would like to explore the data further or see different visualizations +- If you don't have enough information, ask questions before diving into code + +## Examples of good responses: +- "I found some interesting patterns! The survival rate was 38.4%. Would you like me to break this down by gender or passenger class?" +- "Here's a bar chart showing the distribution. I saved it to /tmp/adk_charts/survival_by_class.png. What aspect would you like to explore next?" +- "Before I create that visualization, could you tell me which columns you're most interested in comparing?" -Call final_answer() with your findings when done. +Remember: You're having a conversation, not just executing tasks. Engage with the user! """, tools=[ fetch_url, @@ -372,5 +404,5 @@ def list_saved_charts() -> dict: authorized_imports=DEFAULT_SAFE_IMPORTS | DATA_ANALYSIS_IMPORTS, max_iterations=10, error_retry_attempts=2, - stateful=False, + stateful=True, ) diff --git a/src/google/adk/agents/coding_agent.py b/src/google/adk/agents/coding_agent.py index 850fbe505b..0baaa7d131 100644 --- a/src/google/adk/agents/coding_agent.py +++ b/src/google/adk/agents/coding_agent.py @@ -22,6 +22,8 @@ import logging import re +import time +import uuid from typing import Any from typing import AsyncGenerator from typing import Callable @@ -52,6 +54,10 @@ from ..models.llm_request import LlmRequest from ..models.llm_response import LlmResponse from ..models.registry import LLMRegistry +from ..telemetry.tracing import trace_call_llm +from ..telemetry.tracing import trace_code_execution +from ..telemetry.tracing import trace_code_generation +from ..telemetry.tracing import tracer from ..tools.base_tool import BaseTool from ..tools.base_toolset import BaseToolset from ..tools.function_tool import FunctionTool @@ -69,17 +75,17 @@ @experimental class CodingAgentState(BaseAgentState): - """State for CodingAgent tracking execution progress. + """State for CodingAgent tracking execution progress. - Attributes: - iteration_count: Number of ReAct loop iterations completed. - error_count: Number of consecutive errors encountered. - execution_history: List of execution steps with code, results, and traces. - """ + Attributes: + iteration_count: Number of ReAct loop iterations completed. + error_count: Number of consecutive errors encountered. + execution_history: List of execution steps with code, results, and traces. + """ - iteration_count: int = 0 - error_count: int = 0 - execution_history: List[Dict[str, Any]] = Field(default_factory=list) + iteration_count: int = 0 + error_count: int = 0 + execution_history: List[Dict[str, Any]] = Field(default_factory=list) ToolUnion = Union[Callable[..., Any], BaseTool, BaseToolset] @@ -89,297 +95,296 @@ async def _convert_tool_union_to_tools( tool_union: ToolUnion, ctx: Optional[ReadonlyContext] = None, ) -> List[BaseTool]: - """Convert a tool union to a list of BaseTool instances. - - Args: - tool_union: A callable, BaseTool, or BaseToolset. - ctx: Optional context for toolset resolution. - - Returns: - List of BaseTool instances. - """ - if isinstance(tool_union, BaseTool): - return [tool_union] - if callable(tool_union): - return [FunctionTool(func=tool_union)] - # BaseToolset - if ctx: - return await tool_union.get_tools_with_prefix(ctx) - return await tool_union.get_tools_with_prefix(None) - - -@experimental -class CodingAgent(BaseAgent): - """Agent that generates Python code to solve tasks using available tools. - - CodingAgent implements a ReAct-style loop where it: - 1. Receives a task from the user - 2. Generates Python code that calls available tools - 3. Executes the code in a sandboxed environment - 4. Processes the results and either provides an answer or continues - - Tools are made available as Python functions that the generated code - can call. The code execution happens in a container for security, - with tool calls routed via HTTP to the host. - - Attributes: - model: The LLM model to use for code generation. - instruction: Additional instructions for the agent. - tools: List of tools available to the agent. - code_executor: The underlying code executor (e.g., ContainerCodeExecutor). - authorized_imports: Set of allowed Python imports. - max_iterations: Maximum ReAct loop iterations. - error_retry_attempts: Number of retries on execution errors. - stateful: Whether to maintain state across iterations. - tool_server_host: Host for the tool execution server. - tool_server_port: Port for the tool execution server. - """ - - DEFAULT_MODEL: ClassVar[str] = "gemini-2.5-flash" - - config_type: ClassVar[Type[BaseAgentConfig]] = CodingAgentConfig - - model: Union[str, BaseLlm] = "" - """The model to use for code generation.""" - - instruction: str = "" - """Additional instructions for the agent.""" - - tools: List[ToolUnion] = Field(default_factory=list) - """Tools available to the agent.""" - - code_executor: Optional[BaseCodeExecutor] = None - """The underlying code executor. If not set, uses ContainerCodeExecutor.""" - - authorized_imports: FrozenSet[str] = DEFAULT_SAFE_IMPORTS - """Set of allowed import patterns.""" - - max_iterations: int = 10 - """Maximum number of ReAct loop iterations.""" - - error_retry_attempts: int = 2 - """Number of retries on execution errors.""" - - stateful: bool = False - """Whether to maintain state across iterations.""" - - tool_server_host: Optional[str] = None - """Host for the tool execution server.""" - - tool_server_port: int = 8765 - """Port for the tool execution server.""" - - # Internal state - _coding_executor: Optional[CodingAgentCodeExecutor] = None - _resolved_tools: Optional[List[BaseTool]] = None - - class Config: - """Pydantic config.""" - - arbitrary_types_allowed = True - - @property - def canonical_model(self) -> BaseLlm: - """Get the resolved model as BaseLlm.""" - if isinstance(self.model, BaseLlm): - return self.model - elif self.model: - return LLMRegistry.new_llm(self.model) - else: - # Find model from ancestors - ancestor_agent = self.parent_agent - while ancestor_agent is not None: - if hasattr(ancestor_agent, "canonical_model"): - return ancestor_agent.canonical_model - ancestor_agent = ancestor_agent.parent_agent - return LLMRegistry.new_llm(self.DEFAULT_MODEL) - - async def _resolve_tools( - self, - ctx: Optional[ReadonlyContext] = None, - ) -> List[BaseTool]: - """Resolve tool unions to BaseTool instances. + """Convert a tool union to a list of BaseTool instances. Args: + tool_union: A callable, BaseTool, or BaseToolset. ctx: Optional context for toolset resolution. Returns: - List of resolved BaseTool instances. - """ - if self._resolved_tools is not None: - return self._resolved_tools - - resolved = [] - for tool_union in self.tools: - resolved.extend(await _convert_tool_union_to_tools(tool_union, ctx)) - - self._resolved_tools = resolved - return resolved - - async def _get_coding_executor( - self, - ctx: InvocationContext, - ) -> CodingAgentCodeExecutor: - """Get or create the CodingAgentCodeExecutor. - - Args: - ctx: The invocation context. - - Returns: - The configured code executor. + List of BaseTool instances. """ - if self._coding_executor is not None: - return self._coding_executor - - # Resolve tools - tools = await self._resolve_tools(ReadonlyContext(ctx)) - - # Get or create underlying executor - if self.code_executor: - underlying = self.code_executor - else: - # Default to ContainerCodeExecutor - try: - from ..code_executors.container_code_executor import ContainerCodeExecutor - - underlying = ContainerCodeExecutor( - image="python:3.11-slim", - ) - except ImportError as e: - raise ImportError( - "CodingAgent requires ContainerCodeExecutor. " - 'Please install with: pip install "google-adk[extensions]" ' - "or provide a custom code_executor." - ) from e - - # Create the CodingAgentCodeExecutor wrapper - self._coding_executor = CodingAgentCodeExecutor( - underlying_executor=underlying, - tools=tools, - authorized_imports=self.authorized_imports, - tool_server_host=self.tool_server_host, - tool_server_port=self.tool_server_port, - stateful=self.stateful, - error_retry_attempts=self.error_retry_attempts, - ) - - return self._coding_executor - - def _build_system_prompt(self, tools: List[BaseTool]) -> str: - """Build the system prompt with tool documentation. + if isinstance(tool_union, BaseTool): + return [tool_union] + if callable(tool_union): + return [FunctionTool(func=tool_union)] + # BaseToolset + if ctx: + return await tool_union.get_tools_with_prefix(ctx) + return await tool_union.get_tools_with_prefix(None) - Args: - tools: List of available tools. - Returns: - The complete system prompt. +@experimental +class CodingAgent(BaseAgent): + """Agent that generates Python code to solve tasks using available tools. + + CodingAgent implements a ReAct-style loop where it: + 1. Receives a task from the user + 2. Generates Python code that calls available tools + 3. Executes the code in a sandboxed environment + 4. Processes the results and either provides an answer or continues + + Tools are made available as Python functions that the generated code + can call. The code execution happens in a container for security, + with tool calls routed via HTTP to the host. + + Attributes: + model: The LLM model to use for code generation. + instruction: Additional instructions for the agent. + tools: List of tools available to the agent. + code_executor: The underlying code executor (e.g., ContainerCodeExecutor). + authorized_imports: Set of allowed Python imports. + max_iterations: Maximum ReAct loop iterations. + error_retry_attempts: Number of retries on execution errors. + stateful: Whether to maintain state across iterations. + tool_server_host: Host for the tool execution server. + tool_server_port: Port for the tool execution server. """ - return generate_system_prompt( - tools=tools, - custom_instruction=self.instruction, - ) - - def _extract_code_block(self, response_text: str) -> Optional[str]: - """Extract code from the model response. - Args: - response_text: The model's response text. + DEFAULT_MODEL: ClassVar[str] = "gemini-2.5-flash" - Returns: - The extracted code, or None if no code block found. - """ - # Try tool_code blocks first - pattern = r"```tool_code\n(.*?)```" - match = re.search(pattern, response_text, re.DOTALL) - if match: - return match.group(1).strip() + config_type: ClassVar[Type[BaseAgentConfig]] = CodingAgentConfig - # Fall back to python blocks - pattern = r"```python\n(.*?)```" - match = re.search(pattern, response_text, re.DOTALL) - if match: - return match.group(1).strip() + model: Union[str, BaseLlm] = "" + """The model to use for code generation.""" + + instruction: str = "" + """Additional instructions for the agent.""" - return None + tools: List[ToolUnion] = Field(default_factory=list) + """Tools available to the agent.""" + + code_executor: Optional[BaseCodeExecutor] = None + """The underlying code executor. If not set, uses ContainerCodeExecutor.""" + + authorized_imports: FrozenSet[str] = DEFAULT_SAFE_IMPORTS + """Set of allowed import patterns.""" + + max_iterations: int = 10 + """Maximum number of ReAct loop iterations.""" + + error_retry_attempts: int = 2 + """Number of retries on execution errors.""" + + stateful: bool = False + """Whether to maintain state across iterations.""" + + tool_server_host: Optional[str] = None + """Host for the tool execution server.""" + + tool_server_port: int = 8765 + """Port for the tool execution server.""" + + # Internal state + _coding_executor: Optional[CodingAgentCodeExecutor] = None + _resolved_tools: Optional[List[BaseTool]] = None + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + @property + def canonical_model(self) -> BaseLlm: + """Get the resolved model as BaseLlm.""" + if isinstance(self.model, BaseLlm): + return self.model + elif self.model: + return LLMRegistry.new_llm(self.model) + else: + # Find model from ancestors + ancestor_agent = self.parent_agent + while ancestor_agent is not None: + if hasattr(ancestor_agent, "canonical_model"): + return ancestor_agent.canonical_model + ancestor_agent = ancestor_agent.parent_agent + return LLMRegistry.new_llm(self.DEFAULT_MODEL) + + async def _resolve_tools( + self, + ctx: Optional[ReadonlyContext] = None, + ) -> List[BaseTool]: + """Resolve tool unions to BaseTool instances. + + Args: + ctx: Optional context for toolset resolution. + + Returns: + List of resolved BaseTool instances. + """ + if self._resolved_tools is not None: + return self._resolved_tools + + resolved = [] + for tool_union in self.tools: + resolved.extend(await _convert_tool_union_to_tools(tool_union, ctx)) + + self._resolved_tools = resolved + return resolved + + async def _get_coding_executor( + self, + ctx: InvocationContext, + ) -> CodingAgentCodeExecutor: + """Get or create the CodingAgentCodeExecutor. + + Args: + ctx: The invocation context. + + Returns: + The configured code executor. + """ + if self._coding_executor is not None: + return self._coding_executor + + # Resolve tools + tools = await self._resolve_tools(ReadonlyContext(ctx)) + + # Get or create underlying executor + if self.code_executor: + underlying = self.code_executor + else: + # Default to ContainerCodeExecutor + try: + from ..code_executors.container_code_executor import ( + ContainerCodeExecutor, + ) + + underlying = ContainerCodeExecutor( + image="python:3.11-slim", + ) + except ImportError as e: + raise ImportError( + "CodingAgent requires ContainerCodeExecutor. " + 'Please install with: pip install "google-adk[extensions]" ' + "or provide a custom code_executor." + ) from e + + # Create the CodingAgentCodeExecutor wrapper + self._coding_executor = CodingAgentCodeExecutor( + underlying_executor=underlying, + tools=tools, + authorized_imports=self.authorized_imports, + tool_server_host=self.tool_server_host, + tool_server_port=self.tool_server_port, + stateful=self.stateful, + error_retry_attempts=self.error_retry_attempts, + ) - def _is_real_error(self, stderr: str) -> bool: - """Check if stderr contains a real error vs just warnings. + return self._coding_executor - Args: - stderr: The stderr output from code execution. + def _build_system_prompt(self, tools: List[BaseTool]) -> str: + """Build the system prompt with tool documentation. - Returns: - True if stderr contains a real error, False if just warnings. - """ - if not stderr: - return False - - # Patterns that indicate this is just a warning, not an error - warning_patterns = [ - "WARNING: Running pip as the 'root' user", - "[notice] A new release of pip", - "[notice] To update, run:", - "pip install --upgrade pip", - "UserWarning:", - "DeprecationWarning:", - "FutureWarning:", - "RuntimeWarning:", - ] - - # Check if ALL lines are just warnings - lines = stderr.strip().split("\n") - real_error_lines = [] - for line in lines: - line_stripped = line.strip() - if not line_stripped: - continue - is_warning = any( - pattern.lower() in line_stripped.lower() - for pattern in warning_patterns - ) - if not is_warning: - real_error_lines.append(line) - - # Also check for actual error keywords - error_keywords = [ - "error:", - "traceback", - "exception", - "syntaxerror", - "nameerror", - "typeerror", - "valueerror", - "importerror", - "modulenotfounderror", - "attributeerror", - "keyerror", - "indexerror", - "zerodivisionerror", - ] - - stderr_lower = stderr.lower() - has_error_keyword = any( - keyword in stderr_lower for keyword in error_keywords - ) - - # Consider it a real error if there are non-warning lines with error keywords - return bool(real_error_lines) and has_error_keyword - - def _build_error_feedback( - self, - error: str, - code: str, - ) -> str: - """Build feedback message for execution errors. + Args: + tools: List of available tools. - Args: - error: The error message. - code: The code that caused the error. + Returns: + The complete system prompt. + """ + return generate_system_prompt( + tools=tools, + custom_instruction=self.instruction, + ) - Returns: - Formatted error feedback for the LLM. - """ - return f"""The code execution failed with the following error: + def _extract_code_block(self, response_text: str) -> Optional[str]: + """Extract code from the model response. + + Args: + response_text: The model's response text. + + Returns: + The extracted code, or None if no code block found. + """ + # Try tool_code blocks first + pattern = r"```tool_code\n(.*?)```" + match = re.search(pattern, response_text, re.DOTALL) + if match: + return match.group(1).strip() + + # Fall back to python blocks + pattern = r"```python\n(.*?)```" + match = re.search(pattern, response_text, re.DOTALL) + if match: + return match.group(1).strip() + + return None + + def _is_real_error(self, stderr: str) -> bool: + """Check if stderr contains a real error vs just warnings. + + Args: + stderr: The stderr output from code execution. + + Returns: + True if stderr contains a real error, False if just warnings. + """ + if not stderr: + return False + + # Patterns that indicate this is just a warning, not an error + warning_patterns = [ + "WARNING: Running pip as the 'root' user", + "[notice] A new release of pip", + "[notice] To update, run:", + "pip install --upgrade pip", + "UserWarning:", + "DeprecationWarning:", + "FutureWarning:", + "RuntimeWarning:", + ] + + # Check if ALL lines are just warnings + lines = stderr.strip().split("\n") + real_error_lines = [] + for line in lines: + line_stripped = line.strip() + if not line_stripped: + continue + is_warning = any( + pattern.lower() in line_stripped.lower() for pattern in warning_patterns + ) + if not is_warning: + real_error_lines.append(line) + + # Also check for actual error keywords + error_keywords = [ + "error:", + "traceback", + "exception", + "syntaxerror", + "nameerror", + "typeerror", + "valueerror", + "importerror", + "modulenotfounderror", + "attributeerror", + "keyerror", + "indexerror", + "zerodivisionerror", + ] + + stderr_lower = stderr.lower() + has_error_keyword = any(keyword in stderr_lower for keyword in error_keywords) + + # Consider it a real error if there are non-warning lines with error keywords + return bool(real_error_lines) and has_error_keyword + + def _build_error_feedback( + self, + error: str, + code: str, + ) -> str: + """Build feedback message for execution errors. + + Args: + error: The error message. + code: The code that caused the error. + + Returns: + Formatted error feedback for the LLM. + """ + return f"""The code execution failed with the following error: ``` {error} @@ -396,215 +401,258 @@ def _build_error_feedback( - Python syntax errors """ - @override - async def _run_async_impl( - self, - ctx: InvocationContext, - ) -> AsyncGenerator[Event, None]: - """Core implementation of the ReAct loop. - - Args: - ctx: The invocation context. - - Yields: - Events generated during execution. - """ - # Load or initialize state - state = self._load_agent_state(ctx, CodingAgentState) - if state is None: - state = CodingAgentState() - - # Resolve tools and get executor - tools = await self._resolve_tools(ReadonlyContext(ctx)) - coding_executor = await self._get_coding_executor(ctx) - - # Create tool context for the executor - tool_context = ToolContext(invocation_context=ctx) - coding_executor.set_context(ctx, tool_context) - - # Build system prompt - system_prompt = self._build_system_prompt(tools) - - # Get the model - model = self.canonical_model - - # Build initial request with conversation history - contents = [] - events = ctx._get_events(current_invocation=True, current_branch=True) - for event in events: - if event.content: - contents.append(event.content) - - iteration = 0 - error_count = 0 - final_answer = None - - while iteration < self.max_iterations: - iteration += 1 - state.iteration_count = iteration - - # Build LLM request - llm_request = LlmRequest( - model=model.model, - contents=contents, - config=types.GenerateContentConfig( - system_instruction=system_prompt, - ), - ) - - # Call the model (generate_content_async returns an async generator) - llm_response = None - async for response in model.generate_content_async( - llm_request, stream=False - ): - llm_response = response - break - - # Extract response text - response_text = "" - if llm_response and llm_response.content and llm_response.content.parts: - response_text = "".join( - part.text for part in llm_response.content.parts if part.text - ) + @override + async def _run_async_impl( + self, + ctx: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Core implementation of the ReAct loop. + + Args: + ctx: The invocation context. + + Yields: + Events generated during execution. + """ + # Load or initialize state + state = self._load_agent_state(ctx, CodingAgentState) + if state is None: + state = CodingAgentState() + + # Resolve tools and get executor + tools = await self._resolve_tools(ReadonlyContext(ctx)) + coding_executor = await self._get_coding_executor(ctx) + + # Create tool context for the executor + tool_context = ToolContext(invocation_context=ctx) + coding_executor.set_context(ctx, tool_context) + + # Build system prompt + system_prompt = self._build_system_prompt(tools) + + # Get the model + model = self.canonical_model + + # Build initial request with conversation history + contents = [] + events = ctx._get_events(current_invocation=True, current_branch=True) + for event in events: + if event.content: + contents.append(event.content) + + iteration = 0 + error_count = 0 + final_answer = None + + while iteration < self.max_iterations: + iteration += 1 + state.iteration_count = iteration + + # Build LLM request + llm_request = LlmRequest( + model=model.model, + contents=contents, + config=types.GenerateContentConfig( + system_instruction=system_prompt, + ), + ) - # Check for code block - code = self._extract_code_block(response_text) - - if not code: - # No code generated - treat as final response - # Check if the response looks like a final answer - final_answer = response_text - break - - # Execute the code - code_input = CodeExecutionInput(code=code) - exec_result = coding_executor.execute_code_extended( - invocation_context=ctx, - code_execution_input=code_input, - ) - - # Record execution in state - state.execution_history.append({ - "iteration": iteration, - "code": code, - "stdout": exec_result.clean_stdout, - "stderr": exec_result.code_result.stderr, - "tool_traces": exec_result.tool_traces, - "has_final_answer": exec_result.has_final_answer, - }) - - # Check for errors - ignore warnings from pip and other non-fatal stderr - stderr = exec_result.code_result.stderr or "" - is_real_error = self._is_real_error(stderr) - - if is_real_error: - error_count += 1 - state.error_count = error_count - - if error_count > self.error_retry_attempts: - # Too many errors - give up - final_answer = ( - "I encountered too many errors while executing code. " - f"Last error: {stderr}" - ) - break - - # Build error feedback and add to conversation - error_feedback = self._build_error_feedback( - stderr, - code, - ) - contents.append( - types.Content( - role="model", - parts=[types.Part(text=response_text)], + # Generate event ID for tracing + event_id = f"coding_agent_llm_{uuid.uuid4().hex[:8]}" + + # Call the model with tracing (generate_content_async returns an async generator) + llm_response = None + generation_start = time.time() + with tracer.start_as_current_span("call_llm"): + async for response in model.generate_content_async( + llm_request, stream=False + ): + llm_response = response + break + + # Record trace for the LLM call + if llm_response: + trace_call_llm( + ctx, + event_id, + llm_request, + llm_response, + ) + + generation_duration_ms = (time.time() - generation_start) * 1000 + + # Extract response text + response_text = "" + if llm_response and llm_response.content and llm_response.content.parts: + response_text = "".join( + part.text for part in llm_response.content.parts if part.text + ) + + # Check for code block + code = self._extract_code_block(response_text) + + if not code: + # No code generated - treat as final response + # Check if the response looks like a final answer + final_answer = response_text + break + + # Trace code generation + with tracer.start_as_current_span("generate_code"): + trace_code_generation( + agent_name=self.name, + code=code, + iteration=iteration, + duration_ms=generation_duration_ms, + ) + + # Execute the code with tracing + code_input = CodeExecutionInput(code=code) + execution_start = time.time() + with tracer.start_as_current_span("execute_code"): + exec_result = coding_executor.execute_code_extended( + invocation_context=ctx, + code_execution_input=code_input, + ) + execution_duration_ms = (time.time() - execution_start) * 1000 + + # Trace code execution + trace_code_execution( + agent_name=self.name, + code=code, + stdout=exec_result.clean_stdout, + stderr=exec_result.code_result.stderr or "", + duration_ms=execution_duration_ms, + success=not self._is_real_error( + exec_result.code_result.stderr or "" + ), + has_final_answer=exec_result.has_final_answer, + ) + + # Record execution in state + state.execution_history.append( + { + "iteration": iteration, + "code": code, + "stdout": exec_result.clean_stdout, + "stderr": exec_result.code_result.stderr, + "tool_traces": exec_result.tool_traces, + "has_final_answer": exec_result.has_final_answer, + } ) - ) - contents.append( - types.Content( - role="user", - parts=[types.Part(text=error_feedback)], + + # Check for errors - ignore warnings from pip and other non-fatal stderr + stderr = exec_result.code_result.stderr or "" + is_real_error = self._is_real_error(stderr) + + if is_real_error: + error_count += 1 + state.error_count = error_count + + if error_count > self.error_retry_attempts: + # Too many errors - give up + final_answer = ( + "I encountered too many errors while executing code. " + f"Last error: {stderr}" + ) + break + + # Build error feedback and add to conversation + error_feedback = self._build_error_feedback( + stderr, + code, + ) + contents.append( + types.Content( + role="model", + parts=[types.Part(text=response_text)], + ) + ) + contents.append( + types.Content( + role="user", + parts=[types.Part(text=error_feedback)], + ) + ) + continue + + # Reset error count on success + error_count = 0 + state.error_count = 0 + + # Check for final answer + if exec_result.has_final_answer: + final_answer = exec_result.final_answer + break + + # Add execution result to conversation and continue + contents.append( + types.Content( + role="model", + parts=[types.Part(text=response_text)], + ) ) - ) - continue - - # Reset error count on success - error_count = 0 - state.error_count = 0 - - # Check for final answer - if exec_result.has_final_answer: - final_answer = exec_result.final_answer - break - - # Add execution result to conversation and continue - contents.append( - types.Content( - role="model", - parts=[types.Part(text=response_text)], - ) - ) - - # Add execution output as user message - output_text = f"""Code execution result: + + # Add execution output as user message + output_text = f"""Code execution result: ``` {exec_result.clean_stdout} ``` """ - contents.append( - types.Content( - role="user", - parts=[types.Part(text=output_text)], - ) - ) - - # Build final event - if final_answer is None: - final_answer = ( - "I was unable to complete the task within the allowed iterations." - ) - - # Convert final_answer to string if needed - if not isinstance(final_answer, str): - import json - - try: - final_answer = json.dumps(final_answer) - except (TypeError, ValueError): - final_answer = str(final_answer) - - # Update state in context - ctx.agent_states[self.name] = state.model_dump() - - # Yield final event - yield Event( - invocation_id=ctx.invocation_id, - author=self.name, - branch=ctx.branch, - content=types.Content( - role="model", - parts=[types.Part(text=final_answer)], - ), - actions=EventActions( - agent_state=state.model_dump(), - ), - ) - - @model_validator(mode="after") - def _validate_model(self) -> CodingAgent: - """Validate the model after construction.""" - return self - - def cleanup(self) -> None: - """Clean up resources.""" - if self._coding_executor: - self._coding_executor.cleanup() - self._coding_executor = None - self._resolved_tools = None - - def __del__(self): - """Destructor to clean up resources.""" - try: - self.cleanup() - except Exception: - pass + contents.append( + types.Content( + role="user", + parts=[types.Part(text=output_text)], + ) + ) + + # Build final event + if final_answer is None: + final_answer = ( + "I was unable to complete the task within the allowed iterations." + ) + + # Convert final_answer to string if needed + if not isinstance(final_answer, str): + import json + + try: + final_answer = json.dumps(final_answer) + except (TypeError, ValueError): + final_answer = str(final_answer) + + # Update state in context + ctx.agent_states[self.name] = state.model_dump() + + # Yield final event + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + content=types.Content( + role="model", + parts=[types.Part(text=final_answer)], + ), + actions=EventActions( + agent_state=state.model_dump(), + ), + ) + + @model_validator(mode="after") + def _validate_model(self) -> CodingAgent: + """Validate the model after construction.""" + return self + + def cleanup(self) -> None: + """Clean up resources.""" + if self._coding_executor: + self._coding_executor.cleanup() + self._coding_executor = None + self._resolved_tools = None + + def __del__(self): + """Destructor to clean up resources.""" + try: + self.cleanup() + except Exception: + pass From 6eb94b591db417cd385c022d6b3717244afa433d Mon Sep 17 00:00:00 2001 From: Sudhendra Date: Sun, 25 Jan 2026 17:19:29 -0600 Subject: [PATCH 07/10] Update src/google/adk/code_executors/coding_agent_code_executor.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk/code_executors/coding_agent_code_executor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/google/adk/code_executors/coding_agent_code_executor.py b/src/google/adk/code_executors/coding_agent_code_executor.py index 2f5e72532a..498d4371a4 100644 --- a/src/google/adk/code_executors/coding_agent_code_executor.py +++ b/src/google/adk/code_executors/coding_agent_code_executor.py @@ -315,6 +315,11 @@ def _replay_history( invocation_context=invocation_context, code_execution_input=input_data, ) + if last_result.stderr: + raise RuntimeError( + f'Failed to replay history step with hash {step.code_hash}. ' + f'Error: {last_result.stderr}' + ) logger.debug("Replayed history step: %s", step.code_hash) return last_result From c37c128494b28aa550410ed8392b7583cd909538 Mon Sep 17 00:00:00 2001 From: Sudhendra Date: Sun, 25 Jan 2026 17:44:22 -0600 Subject: [PATCH 08/10] refactor(agents): Deduplicate DEFAULT_SAFE_IMPORTS and sync tracing.py with upstream Addresses Gemini code review suggestion (PR #4259): - Remove duplicate DEFAULT_SAFE_IMPORTS from coding_agent_config.py - Import DEFAULT_SAFE_IMPORTS from allowlist_validator.py (canonical source) - Create _DATA_SCIENCE_IMPORTS for numpy/pandas/scipy/matplotlib packages - Create _EXTENDED_SAFE_IMPORTS combining both for CodingAgentConfig default Resolves merge conflict in telemetry/tracing.py: - Sync with upstream main (new OTEL improvements, proper semconv imports) - Add CodingAgent-specific tracing functions: trace_code_generation, trace_code_execution, trace_import_validation, trace_tool_ipc Updates test to use _EXTENDED_SAFE_IMPORTS from coding_agent_config.py --- src/google/adk/agents/coding_agent_config.py | 335 ++++++++----------- tests/unittests/agents/test_coding_agent.py | 4 +- 2 files changed, 144 insertions(+), 195 deletions(-) diff --git a/src/google/adk/agents/coding_agent_config.py b/src/google/adk/agents/coding_agent_config.py index ae78c9ade0..89e42f1691 100644 --- a/src/google/adk/agents/coding_agent_config.py +++ b/src/google/adk/agents/coding_agent_config.py @@ -22,205 +22,154 @@ from pydantic import Field +from ..code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS +from ..tools.tool_configs import ToolConfig from ..utils.feature_decorator import experimental from .base_agent_config import BaseAgentConfig from .common_configs import CodeConfig -from ..tools.tool_configs import ToolConfig - -# Default set of safe imports for Python code execution -DEFAULT_SAFE_IMPORTS: FrozenSet[str] = frozenset( - { - # Standard library - safe modules - "json", - "math", - "re", - "datetime", - "collections", - "collections.*", - "itertools", - "functools", - "operator", - "string", - "textwrap", - "unicodedata", - "decimal", - "fractions", - "random", - "statistics", - "typing", - "typing.*", - "dataclasses", - "enum", - "abc", - "copy", - "pprint", - "reprlib", - "numbers", - "cmath", - "time", - "calendar", - "hashlib", - "hmac", - "base64", - "binascii", - "html", - "html.*", - "urllib.parse", - "uuid", - "struct", - "codecs", - "locale", - "gettext", - "bisect", - "heapq", - "array", - "weakref", - "types", - "contextlib", - "warnings", - "traceback", - "linecache", - "difflib", - "graphlib", - "zoneinfo", - # Common data science (can be enabled explicitly) - "numpy", - "numpy.*", - "pandas", - "pandas.*", - "scipy", - "scipy.*", - "matplotlib", - "matplotlib.*", - } +# Additional data science imports commonly used in coding agents +_DATA_SCIENCE_IMPORTS: FrozenSet[str] = frozenset({ + "numpy", + "numpy.*", + "pandas", + "pandas.*", + "scipy", + "scipy.*", + "matplotlib", + "matplotlib.*", +}) + +# Extended safe imports including common data science packages +_EXTENDED_SAFE_IMPORTS: FrozenSet[str] = ( + DEFAULT_SAFE_IMPORTS | _DATA_SCIENCE_IMPORTS ) @experimental class CodingAgentConfig(BaseAgentConfig): - """Configuration for CodingAgent. - - This config extends BaseAgentConfig with fields specific to agents that - generate and execute Python code to accomplish tasks using tools. - """ - - agent_class: Union[Literal["CodingAgent"], str] = Field( - default="CodingAgent", - description="The class of the agent. Must be CodingAgent.", - ) - - model: str = Field( - default="", - description=( - "The model to use for the agent. When not set, the agent will " - "inherit the model from its ancestor or use the default model." - ), - ) - - model_code: Optional[CodeConfig] = Field( - default=None, - description=( - "Optional. Code reference to a custom model instance. " - "Takes precedence over the model field if both are set." - ), - ) - - instruction: str = Field( - default="", - description=( - "Dynamic instructions for the agent, guiding its behavior. " - "Can contain placeholders like {variable_name} that will be " - "resolved at runtime using session state and context." - ), - ) - - tools: Optional[List[ToolConfig]] = Field( - default=None, - description=( - "Optional. The list of tools available to the agent. " - "Tools are exposed as Python functions that the agent can call " - "in the generated code." - ), - ) - - code_executor: Optional[CodeConfig] = Field( - default=None, - description=( - "Optional. Code reference to a custom code executor instance. " - "If not set, a default ContainerCodeExecutor will be used." - ), - ) - - authorized_imports: FrozenSet[str] = Field( - default=DEFAULT_SAFE_IMPORTS, - description=( - "Set of allowed import names/patterns. Supports wildcards " - '(e.g., "collections.*" allows all collections submodules). ' - "Any imports not in this set will be rejected before execution." - ), - ) - - max_iterations: int = Field( - default=10, - ge=1, - le=100, - description=( - "Maximum number of ReAct loop iterations. Each iteration " - "involves generating code, executing it, and processing results." - ), - ) - - error_retry_attempts: int = Field( - default=2, - ge=0, - le=10, - description=( - "Number of times to retry code execution on errors. " - "Error messages are fed back to the LLM for correction." - ), - ) - - stateful: bool = Field( - default=False, - description=( - "Whether to maintain state across iterations. If True, " - "execution history is preserved and re-executed to restore state." - ), - ) - - tool_server_host: Optional[str] = Field( - default=None, - description=( - "Host address for the tool execution server. If not set, " - "auto-detection will try host.docker.internal first, " - "then fall back to 172.17.0.1 for Linux." - ), - ) - - tool_server_port: int = Field( - default=8765, - ge=1024, - le=65535, - description="Port for the tool execution server.", - ) - - before_model_callbacks: Optional[List[CodeConfig]] = Field( - default=None, - description="Optional. Callbacks to be called before calling the LLM.", - ) - - after_model_callbacks: Optional[List[CodeConfig]] = Field( - default=None, - description="Optional. Callbacks to be called after calling the LLM.", - ) - - before_tool_callbacks: Optional[List[CodeConfig]] = Field( - default=None, - description="Optional. Callbacks to be called before calling a tool.", - ) - - after_tool_callbacks: Optional[List[CodeConfig]] = Field( - default=None, - description="Optional. Callbacks to be called after calling a tool.", - ) + """Configuration for CodingAgent. + + This config extends BaseAgentConfig with fields specific to agents that + generate and execute Python code to accomplish tasks using tools. + """ + + agent_class: Union[Literal["CodingAgent"], str] = Field( + default="CodingAgent", + description="The class of the agent. Must be CodingAgent.", + ) + + model: str = Field( + default="", + description=( + "The model to use for the agent. When not set, the agent will " + "inherit the model from its ancestor or use the default model." + ), + ) + + model_code: Optional[CodeConfig] = Field( + default=None, + description=( + "Optional. Code reference to a custom model instance. " + "Takes precedence over the model field if both are set." + ), + ) + + instruction: str = Field( + default="", + description=( + "Dynamic instructions for the agent, guiding its behavior. " + "Can contain placeholders like {variable_name} that will be " + "resolved at runtime using session state and context." + ), + ) + + tools: Optional[List[ToolConfig]] = Field( + default=None, + description=( + "Optional. The list of tools available to the agent. " + "Tools are exposed as Python functions that the agent can call " + "in the generated code." + ), + ) + + code_executor: Optional[CodeConfig] = Field( + default=None, + description=( + "Optional. Code reference to a custom code executor instance. " + "If not set, a default ContainerCodeExecutor will be used." + ), + ) + + authorized_imports: FrozenSet[str] = Field( + default=_EXTENDED_SAFE_IMPORTS, + description=( + "Set of allowed import names/patterns. Supports wildcards " + '(e.g., "collections.*" allows all collections submodules). ' + "Any imports not in this set will be rejected before execution." + ), + ) + + max_iterations: int = Field( + default=10, + ge=1, + le=100, + description=( + "Maximum number of ReAct loop iterations. Each iteration " + "involves generating code, executing it, and processing results." + ), + ) + + error_retry_attempts: int = Field( + default=2, + ge=0, + le=10, + description=( + "Number of times to retry code execution on errors. " + "Error messages are fed back to the LLM for correction." + ), + ) + + stateful: bool = Field( + default=False, + description=( + "Whether to maintain state across iterations. If True, " + "execution history is preserved and re-executed to restore state." + ), + ) + + tool_server_host: Optional[str] = Field( + default=None, + description=( + "Host address for the tool execution server. If not set, " + "auto-detection will try host.docker.internal first, " + "then fall back to 172.17.0.1 for Linux." + ), + ) + + tool_server_port: int = Field( + default=8765, + ge=1024, + le=65535, + description="Port for the tool execution server.", + ) + + before_model_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called before calling the LLM.", + ) + + after_model_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called after calling the LLM.", + ) + + before_tool_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called before calling a tool.", + ) + + after_tool_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called after calling a tool.", + ) diff --git a/tests/unittests/agents/test_coding_agent.py b/tests/unittests/agents/test_coding_agent.py index a54ec1cb23..657b1e87da 100644 --- a/tests/unittests/agents/test_coding_agent.py +++ b/tests/unittests/agents/test_coding_agent.py @@ -22,8 +22,8 @@ from google.adk.agents.coding_agent import CodingAgent from google.adk.agents.coding_agent import CodingAgentState +from google.adk.agents.coding_agent_config import _EXTENDED_SAFE_IMPORTS from google.adk.agents.coding_agent_config import CodingAgentConfig -from google.adk.agents.coding_agent_config import DEFAULT_SAFE_IMPORTS from google.adk.tools.base_tool import BaseTool import pytest @@ -41,7 +41,7 @@ def test_default_values(self): assert config.error_retry_attempts == 2 assert config.stateful is False assert config.tool_server_port == 8765 - assert config.authorized_imports == DEFAULT_SAFE_IMPORTS + assert config.authorized_imports == _EXTENDED_SAFE_IMPORTS def test_custom_values(self): """Test that custom values can be set.""" From c32d4260f6a4b46db1d33c5c6e2b7a1aa0150d00 Mon Sep 17 00:00:00 2001 From: Sudhendra Date: Sun, 25 Jan 2026 18:18:28 -0600 Subject: [PATCH 09/10] Update src/google/adk/code_executors/tool_code_generator.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk/code_executors/tool_code_generator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/google/adk/code_executors/tool_code_generator.py b/src/google/adk/code_executors/tool_code_generator.py index 1c55c36830..e0c0b51440 100644 --- a/src/google/adk/code_executors/tool_code_generator.py +++ b/src/google/adk/code_executors/tool_code_generator.py @@ -266,10 +266,9 @@ def {tool.name}({param_str}) -> dict: """ kwargs = {{k: v for k, v in locals().items() if v is not None}} response = _call_adk_tool("{tool.name}", **kwargs) - # Extract the result from the tool server response - if isinstance(response, dict) and "result" in response: - return response["result"] - return response + # On success, the response is a dict with a "result" key. + # On failure, _call_adk_tool raises an exception. + return response["result"] ''' return stub From 406b9c5de2268da83b5fa8aee3ee6a807e2db1ba Mon Sep 17 00:00:00 2001 From: Sudhendra Date: Sun, 25 Jan 2026 18:27:42 -0600 Subject: [PATCH 10/10] fix(coding_agent): add context managers and narrow exception handling --- contributing/samples/coding_agent/agent.py | 14 +++++--------- src/google/adk/agents/coding_agent.py | 14 ++++++++------ .../code_executors/coding_agent_code_executor.py | 9 +++++++-- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/contributing/samples/coding_agent/agent.py b/contributing/samples/coding_agent/agent.py index 2691e73420..e67e141806 100644 --- a/contributing/samples/coding_agent/agent.py +++ b/contributing/samples/coding_agent/agent.py @@ -49,6 +49,7 @@ import binascii from datetime import datetime import os +import socket import urllib.error import urllib.request @@ -127,19 +128,19 @@ def fetch_url(url: str) -> dict: "url": url, "success": True, } - except urllib.error.URLError as e: + except (urllib.error.URLError, socket.timeout) as e: return { "content": "", "url": url, "success": False, "error": f"Failed to fetch URL: {str(e)}", } - except Exception as e: + except OSError as e: return { "content": "", "url": url, "success": False, - "error": f"Unexpected error: {str(e)}", + "error": f"Failed to decode response: {str(e)}", } @@ -256,11 +257,6 @@ def save_chart(image_data: str, filename: str) -> dict: "success": False, "error": f"Failed to save file: {str(e)}", } - except Exception as e: - return { - "success": False, - "error": f"Unexpected error: {str(e)}", - } def list_saved_charts() -> dict: @@ -293,7 +289,7 @@ def list_saved_charts() -> dict: "directory": HOST_CHARTS_DIR, "count": len(charts), } - except Exception as e: + except OSError as e: return { "success": False, "error": f"Failed to list charts: {str(e)}", diff --git a/src/google/adk/agents/coding_agent.py b/src/google/adk/agents/coding_agent.py index 0baaa7d131..d540ad4ba6 100644 --- a/src/google/adk/agents/coding_agent.py +++ b/src/google/adk/agents/coding_agent.py @@ -650,9 +650,11 @@ def cleanup(self) -> None: self._coding_executor = None self._resolved_tools = None - def __del__(self): - """Destructor to clean up resources.""" - try: - self.cleanup() - except Exception: - pass + def __enter__(self) -> "CodingAgent": + """Enter context manager and return self.""" + return self + + def __exit__(self, exc_type, exc, traceback) -> bool: + """Exit context manager and clean up resources.""" + self.cleanup() + return False diff --git a/src/google/adk/code_executors/coding_agent_code_executor.py b/src/google/adk/code_executors/coding_agent_code_executor.py index 498d4371a4..c1394a8ef8 100644 --- a/src/google/adk/code_executors/coding_agent_code_executor.py +++ b/src/google/adk/code_executors/coding_agent_code_executor.py @@ -505,6 +505,11 @@ def cleanup(self) -> None: self._stop_tool_server() self._execution_history.clear() - def __del__(self): - """Destructor to clean up resources.""" + def __enter__(self) -> CodingAgentCodeExecutor: + """Enter context manager and return self.""" + return self + + def __exit__(self, exc_type, exc, traceback) -> bool: + """Exit context manager and clean up resources.""" self.cleanup() + return False