diff --git a/.github/CODING_AGENT_ISSUE.md b/.github/CODING_AGENT_ISSUE.md new file mode 100644 index 0000000000..db39b23495 --- /dev/null +++ b/.github/CODING_AGENT_ISSUE.md @@ -0,0 +1,94 @@ +# GitHub Issue: CodingAgent Feature Request + +Use this content to create an issue at: +https://github.com/google/adk-python/issues/new?template=feature_request.md + +--- + +## Title + +feat(agents): Add CodingAgent (agents that think in code) + +--- + +## Is your feature request related to a problem? Please describe. + +ADK’s current default agent interaction pattern is “tool selection from a fixed action set”. This is powerful, but it breaks down in three increasingly common situations: + +1) Long-context work beyond model context windows +- Many real tasks require operating over very large corpora: codebases, logs, datasets, multi-file configs, or long documents. +- If the agent must keep the relevant source text and intermediate results inside the LLM context, it becomes context-window bound and expensive. +- Recent work such as “Recursive Language Models” (arXiv:2512.24601) proposes treating long prompts as an external environment and letting the model programmatically examine/decompose/recursively process snippets. This suggests a practical direction for agents: move heavy inspection, decomposition, and intermediate state out of the prompt and into an execution environment. + - https://arxiv.org/abs/2512.24601 + +2) Expressiveness and composability limits of pure tool-calling +- Tool-calling assumes we can enumerate actions up-front. In open-ended tasks, the agent needs to compose multiple operations, iterate, cache intermediate artifacts, and implement “one-off” transformations without requiring new bespoke tools each time. +- A code-based action space lets the agent compose operations naturally (loops, conditionals, helper functions), which reduces the need for an explosion of tools. + +3) Developer experience gap for building “coding agents” and sub-agent architectures +- Users increasingly want agent systems like Claude Code / OpenCode: multi-step coding workflows with sub-agents (planner, tester, refactorer, etc.) and strong “think in code” execution. +- ADK has strong orchestration primitives; adding a first-class code-executing agent unlocks building these systems within ADK while keeping sandboxing and tool integration. + +Related inspiration: HuggingFace “smolagents” positions CodeAgent as a first-class concept (“agents that think in code”) and supports sandbox backends (Docker, etc.). +- https://github.com/huggingface/smolagents + +--- + +## Describe the solution you’d like + +Add a new experimental agent type: CodingAgent. + +CodingAgent should: +- Generate Python code as the primary action representation (in `tool_code` blocks). +- Execute that code in a sandboxed environment (Docker-based initially). +- Allow generated code to call ADK tools safely via an IPC bridge (e.g., HTTP) rather than exposing the host runtime directly. +- Support iterative execution (ReAct-style loop): generate → run → observe stdout/tool results → refine → final answer. + +Why this solves the problem +- Long-context: aligns with the “external environment” framing in arXiv:2512.24601 by enabling the agent to iteratively inspect, decompose, and process large inputs using code and persisted artifacts, instead of forcing all content into the model context. 
+- Composability: code enables arbitrary composition (loops, conditionals, helper functions) without requiring every combination to be implemented as a first-class tool. +- Coding-agent architectures: makes it straightforward to build higher-level workflows and multi-agent hierarchies where sub-agents can generate/run code for specialized tasks. + +High-level architecture + +User → CodingAgent (LLM) → sandbox executor (Docker Python) + ↘ tool IPC server on host ↙ + +Proposed execution environments (progressive) +- v1: Docker Python sandbox (existing ContainerCodeExecutor integration) +- future: REPL / Jupyter-kernel style execution modes for interactive, stateful sessions (still sandboxed) + +--- + +## Describe alternatives you’ve considered + +1) “Just add a code-execution tool” to existing agents +- Pros: minimal surface-area change. +- Cons: code execution becomes an occasional tool call rather than the agent’s primary action space; harder to support tight generate→execute→iterate loops and long-context strategies that rely on an external environment. + +2) Require users to write bespoke tools for every operation +- Pros: explicit and controlled. +- Cons: does not scale; real workflows need ad-hoc transformations and composition that explode the tool surface area. + +3) Run code on the host interpreter +- Pros: simplest. +- Cons: unacceptable security risk; sandboxing is required for a general-purpose code agent. + +--- + +## Additional context + +Future directions enabled by CodingAgent +- Long-context scaffolds inspired by arXiv:2512.24601: treat large inputs (files, repo trees, logs) as an “environment” the agent queries/decomposes recursively using code, storing intermediate state outside the LLM context. +- Sub-agent coding workflows (Claude Code / OpenCode style): planner/tester/refactor sub-agents coordinated by ADK, each using code execution. +- Multiple sandbox backends (like smolagents): Docker initially, with optional future support for other sandboxes and interactive execution modes. + +Links +- smolagents (inspiration): https://github.com/huggingface/smolagents +- Recursive Language Models (long-context framing): https://arxiv.org/abs/2512.24601 + +Labels to add +- enhancement +- agents +- new-feature +- experimental diff --git a/.github/CODING_AGENT_PLAN.md b/.github/CODING_AGENT_PLAN.md new file mode 100644 index 0000000000..a20436094f --- /dev/null +++ b/.github/CODING_AGENT_PLAN.md @@ -0,0 +1,148 @@ +# CodingAgent - Implementation Plan & Status + +This document tracks the implementation of CodingAgent, an experimental agent type that generates and executes Python code in sandboxed containers. 
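For orientation, here is what constructing the agent looks like in practice. This is a minimal sketch based on the sample agent tracked below; the API is experimental and may change:

```python
from google.adk.agents import CodingAgent
from google.adk.code_executors import ContainerCodeExecutor
from google.adk.code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS


def get_current_time() -> dict:
  """Plain callables are exposed to generated code as tool functions."""
  from datetime import datetime

  return {"timestamp": datetime.now().isoformat()}


root_agent = CodingAgent(
    name="data_analyst",
    model="gemini-2.5-flash",
    instruction="You are a data analyst. Analyze data and provide insights.",
    tools=[get_current_time],  # callables become tool stubs inside the sandbox
    code_executor=ContainerCodeExecutor(image="python:3.11-slim"),
    authorized_imports=DEFAULT_SAFE_IMPORTS | {"pandas", "matplotlib.*"},
    max_iterations=10,
    error_retry_attempts=2,
)
```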
+ +## Overview + +CodingAgent is a ReAct-style agent that: +- Generates Python code to solve tasks using an LLM (Gemini) +- Executes code in sandboxed Docker containers +- Calls ADK tools from generated code via HTTP IPC +- Iterates until a final answer is produced + +## Implementation Status + +### Core Components ✅ Complete + +| Component | File | Status | Lines | +|-----------|------|--------|-------| +| CodingAgent | `src/google/adk/agents/coding_agent.py` | ✅ Complete | ~610 | +| CodingAgentConfig | `src/google/adk/agents/coding_agent_config.py` | ✅ Complete | ~225 | +| CodingAgentCodeExecutor | `src/google/adk/code_executors/coding_agent_code_executor.py` | ✅ Complete | ~505 | +| ToolCodeGenerator | `src/google/adk/code_executors/tool_code_generator.py` | ✅ Complete | ~475 | +| ToolExecutionServer | `src/google/adk/code_executors/tool_execution_server.py` | ✅ Complete | ~365 | +| AllowlistValidator | `src/google/adk/code_executors/allowlist_validator.py` | ✅ Complete | ~355 | + +### Sample Agent ✅ Complete + +| File | Status | Description | +|------|--------|-------------| +| `contributing/samples/coding_agent/agent.py` | ✅ Complete | Data Analysis Agent (~360 lines) | +| `contributing/samples/coding_agent/README.md` | ✅ Complete | Documentation (~290 lines) | +| `contributing/samples/coding_agent/__init__.py` | ✅ Complete | Module init | + +### Unit Tests ✅ Complete + +| Test File | Status | Lines | +|-----------|--------|-------| +| `tests/unittests/agents/test_coding_agent.py` | ✅ Complete | ~310 | +| `tests/unittests/code_executors/test_allowlist_validator.py` | ✅ Complete | ~320 | +| `tests/unittests/code_executors/test_tool_code_generator.py` | ✅ Complete | ~320 | + +### Manual E2E Tests ✅ Passed + +| Test Scenario | Status | Notes | +|--------------|--------|-------| +| Basic math query ("What is 25 * 17?") | ✅ Passed | Returns 425 | +| Data analysis (Titanic survival rate) | ✅ Passed | Returns 38.38% | +| Visualization (bar chart by class) | ✅ Passed | Chart saved to host | +| Multi-step analysis | ✅ Passed | Stats + visualization + insights | +| Tool calling via HTTP IPC | ✅ Passed | fetch_url, save_chart work | +| Error handling (pip warnings) | ✅ Passed | Ignores non-fatal stderr | +| Chart saving to host system | ✅ Passed | Saved to /tmp/adk_charts/ | + +## Architecture + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ User Query │────▶│ CodingAgent │────▶│ Docker Container│ +│ │ │ (Gemini LLM) │ │ (Python 3.11) │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + │ │ + │ │ Executes + ▼ │ generated code + ┌──────────────┐ │ + │ Tool Server │◀────────────────┘ + │ (HTTP IPC) │ Tool calls via HTTP + └──────────────┘ +``` + +### How Tool IPC Works + +1. CodingAgent starts ToolExecutionServer on host (port 8765) +2. Code is generated with tool stubs that make HTTP POST requests +3. Container reaches host via `host.docker.internal` (macOS/Windows) or bridge gateway (Linux) +4. Tool server executes actual tool functions with proper context +5. 
Results returned to container via HTTP response + +## Key Design Decisions + +| Decision | Choice | Rationale | +|----------|--------|-----------| +| Container image | `python:3.11-slim` + runtime pip | Simpler for users, no custom Dockerfile | +| Tool communication | HTTP IPC | Works across container boundary, secure | +| Import validation | Allowlist-based | Security without blocking legitimate use | +| Chart saving | `save_chart` tool | Transfers data to host filesystem | +| Error handling | Distinguish warnings from errors | pip warnings shouldn't fail execution | + +## Sample Agent: Data Analyst + +### Tools Available + +| Tool | Description | +|------|-------------| +| `fetch_url(url)` | Fetch CSV/JSON/text from URLs | +| `get_sample_datasets()` | List available datasets (Titanic, Iris, Tips) | +| `get_current_time()` | Get current timestamp | +| `save_chart(image_data, filename)` | Save base64 chart to host | +| `list_saved_charts()` | List saved charts | + +### Example Queries + +1. "What is the survival rate on the Titanic?" +2. "Create a bar chart showing survival rate by passenger class" +3. "Analyze the iris dataset and create a scatter plot colored by species" +4. "Perform comprehensive analysis: stats, survival rates, visualization, insights" + +## Files Changed Summary + +``` + .github/CODING_AGENT_PLAN.md | Plan document + contributing/samples/coding_agent/README.md | 290 lines + contributing/samples/coding_agent/__init__.py | 17 lines + contributing/samples/coding_agent/agent.py | 360 lines + src/google/adk/agents/__init__.py | +2 exports + src/google/adk/agents/coding_agent.py | 610 lines + src/google/adk/agents/coding_agent_config.py | 225 lines + src/google/adk/code_executors/__init__.py | +6 exports + src/google/adk/code_executors/allowlist_validator.py | 355 lines + src/google/adk/code_executors/coding_agent_code_executor.py | 505 lines + src/google/adk/code_executors/tool_code_generator.py | 475 lines + src/google/adk/code_executors/tool_execution_server.py | 365 lines + tests/unittests/agents/test_coding_agent.py | 310 lines + tests/unittests/code_executors/test_allowlist_validator.py | 320 lines + tests/unittests/code_executors/test_tool_code_generator.py | 320 lines +``` + +**Total: ~4,200 lines of new code** + +## PR Checklist + +- [x] Implementation complete +- [x] Unit tests written and passing +- [x] Manual E2E tests passing +- [x] Sample agent created with README +- [x] Code follows ADK style guide (relative imports, `from __future__ import annotations`) +- [x] Marked as `@experimental` +- [ ] Run `./autoformat.sh` before PR +- [ ] Run full test suite: `pytest tests/unittests` +- [ ] Create GitHub issue (see `.github/CODING_AGENT_ISSUE.md`) +- [ ] Submit PR with testing plan + +## Future Enhancements (Out of Scope) + +- Stateful execution (persist variables across turns) +- Custom container images with pre-installed packages +- VertexAI code execution integration +- Support for JavaScript/TypeScript +- Streaming output during execution diff --git a/.github/ISSUE_TEMPLATE/coding_agent_feature.md b/.github/ISSUE_TEMPLATE/coding_agent_feature.md new file mode 100644 index 0000000000..f2b433a399 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/coding_agent_feature.md @@ -0,0 +1,226 @@ +--- +name: "Feature: CodingAgent - Code-generating Agent with Sandboxed Execution" +about: "New experimental agent type that generates and executes Python code" +title: "feat(agents): Add CodingAgent for code generation and sandboxed execution" +labels: "enhancement, agents, new-feature" 
+assignees: '' +--- + +## Summary + +Add a new experimental agent type called **CodingAgent** that generates Python code to solve tasks, executes it in a sandboxed Docker container, and iterates using a ReAct-style loop. This mirrors the popular "Code Interpreter" pattern seen in other AI platforms. + +## Is your feature request related to a problem? + +Currently, ADK agents can only interact with the world through pre-defined tools. While powerful, this approach has limitations: + +1. **Limited flexibility**: Users must anticipate all possible operations and create tools for each +2. **No computational capability**: Agents cannot perform complex calculations, data analysis, or create visualizations without custom tools +3. **No iteration**: Standard tool-calling doesn't easily support multi-step reasoning with intermediate computations +4. **Competitive gap**: Other platforms (OpenAI Code Interpreter, Anthropic's computer use) offer code execution capabilities + +**User pain points:** +- "I want my agent to analyze a CSV file and create a chart" - requires building custom tools +- "I need multi-step calculations with intermediate results" - awkward with standard tools +- "I want the agent to figure out HOW to solve a problem, not just call predefined functions" + +## Describe the solution + +### CodingAgent Overview + +A new agent type that: +1. Receives a task from the user +2. Generates Python code to accomplish the task +3. Executes the code in a sandboxed Docker container +4. Processes results and either provides an answer or continues iterating +5. Can call ADK tools from within generated code via HTTP IPC + +### Architecture + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ User Query │────▶│ CodingAgent │────▶│ Docker Container│ +│ │ │ (Gemini LLM) │ │ (Python 3.11) │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + │ │ + │ │ Executes + ▼ │ generated code + ┌──────────────┐ │ + │ Tool Server │◀────────────────┘ + │ (HTTP IPC) │ Tool calls via HTTP + └──────────────┘ +``` + +### Key Components + +| Component | File | Description | +|-----------|------|-------------| +| CodingAgent | `src/google/adk/agents/coding_agent.py` | Main agent class with ReAct loop | +| CodingAgentConfig | `src/google/adk/agents/coding_agent_config.py` | Pydantic configuration | +| CodingAgentCodeExecutor | `src/google/adk/code_executors/coding_agent_code_executor.py` | Wrapper that injects tools | +| ToolCodeGenerator | `src/google/adk/code_executors/tool_code_generator.py` | Generates Python stubs for tools | +| ToolExecutionServer | `src/google/adk/code_executors/tool_execution_server.py` | HTTP server for tool IPC | +| AllowlistValidator | `src/google/adk/code_executors/allowlist_validator.py` | Import security validation | + +### API Design + +```python +from google.adk.agents import CodingAgent +from google.adk.code_executors import ContainerCodeExecutor + +def fetch_data(url: str) -> dict: + """Fetch data from a URL.""" + # Implementation... + +root_agent = CodingAgent( + name="data_analyst", + model="gemini-2.5-flash", + instruction="You are a data analyst. Analyze data and provide insights.", + tools=[fetch_data], # Tools available to generated code + code_executor=ContainerCodeExecutor(image="python:3.11-slim"), + authorized_imports=DEFAULT_SAFE_IMPORTS | {"pandas", "matplotlib"}, + max_iterations=10, + error_retry_attempts=2, +) +``` + +### Security Features + +1. **Sandboxed execution**: All code runs in isolated Docker containers +2. 
**Import allowlisting**: Only authorized imports are permitted (configurable) +3. **Tool isolation**: Tools execute on host via HTTP, not in container +4. **No filesystem access**: Container has no access to host filesystem +5. **Network egress control**: Container networking can be restricted so generated code reaches only the tool server (the sample agent relaxes this to allow runtime pip installs) + +### Sample Agent + +A complete Data Analysis Agent sample is included: +- Fetches datasets from URLs (Titanic, Iris, Tips) +- Analyzes data with pandas +- Creates visualizations with matplotlib +- Saves charts to host system via `save_chart` tool + +## Describe alternatives you've considered + +### Alternative 1: Extend LlmAgent with code execution +- **Pros**: Simpler architecture, reuses existing agent +- **Cons**: Conflates two distinct patterns, harder to configure + +### Alternative 2: Code execution as a tool only +- **Pros**: Minimal changes, fits existing model +- **Cons**: No ReAct loop, no iteration, limited capability + +### Alternative 3: Use external code execution service +- **Pros**: Offloads security concerns +- **Cons**: Adds external dependency, latency, cost + +**Chosen approach**: Dedicated CodingAgent provides cleanest separation of concerns, explicit configuration, and full control over the execution environment. + +## Implementation Details + +### Files Added/Modified + +**New files (agents):** +- `src/google/adk/agents/coding_agent.py` (~550 lines) +- `src/google/adk/agents/coding_agent_config.py` (~225 lines) + +**New files (code_executors):** +- `src/google/adk/code_executors/coding_agent_code_executor.py` (~500 lines) +- `src/google/adk/code_executors/tool_code_generator.py` (~475 lines) +- `src/google/adk/code_executors/tool_execution_server.py` (~365 lines) +- `src/google/adk/code_executors/allowlist_validator.py` (~350 lines) + +**Modified files:** +- `src/google/adk/agents/__init__.py` - Export CodingAgent +- `src/google/adk/code_executors/__init__.py` - Export new components + +**Sample agent:** +- `contributing/samples/coding_agent/agent.py` (~360 lines) +- `contributing/samples/coding_agent/README.md` (~290 lines) + +**Tests:** +- `tests/unittests/agents/test_coding_agent.py` (~310 lines) +- `tests/unittests/code_executors/test_allowlist_validator.py` (~320 lines) +- `tests/unittests/code_executors/test_tool_code_generator.py` (~320 lines) + +### How Tool IPC Works + +1. When CodingAgent starts, it launches a ToolExecutionServer on the host +2. Generated code includes tool stubs that make HTTP POST requests +3. Tool server receives requests, executes actual tool functions +4. Results are returned to container via HTTP response +5. On macOS/Windows: uses `host.docker.internal` +6. On Linux: uses Docker bridge network gateway + +### Experimental Status + +This feature is marked as **@experimental** because: +- API may change based on user feedback +- Security model is being refined +- Performance optimizations are ongoing + +## Testing Plan + +### Unit Tests + +```bash +pytest tests/unittests/agents/test_coding_agent.py -v +pytest tests/unittests/code_executors/test_allowlist_validator.py -v +pytest tests/unittests/code_executors/test_tool_code_generator.py -v +``` + +### Manual E2E Tests + +**Test 1: Basic Query** +``` +Query: "What is 25 times 17?" +Expected: Agent generates code, calculates, returns "425" +``` + +**Test 2: Data Analysis** +``` +Query: "What is the survival rate on the Titanic?" 
+Expected: Agent fetches data, analyzes with pandas, returns "38.38%" +``` + +**Test 3: Visualization** +``` +Query: "Create a bar chart of Titanic survival by class" +Expected: Agent creates chart, saves to /tmp/adk_charts/, reports filepath +``` + +**Test 4: Multi-step Analysis** +``` +Query: "Analyze Titanic data: show stats, survival by sex/class, and provide 3 insights" +Expected: Agent performs multiple steps, creates visualization, provides comprehensive answer +``` + +## Additional Context + +### Related Work +- OpenAI Code Interpreter +- Anthropic Computer Use +- Google AI Studio code execution + +### Dependencies +- Docker (required for sandboxed execution) +- ContainerCodeExecutor (existing ADK component) + +### Future Enhancements +- [ ] Support for stateful execution (persist variables across turns) +- [ ] Custom container images with pre-installed packages +- [ ] Integration with VertexAI code execution +- [ ] Support for additional languages (JavaScript, etc.) + +### Screenshots + +**Data Analysis Agent in action:** +``` +User: Create a bar chart showing survival rate by passenger class from Titanic + +Agent: [Installs packages, fetches data, creates visualization] +Response: The bar chart has been saved to /tmp/adk_charts/survival_by_class.png +- 1st Class: 63% survival rate +- 2nd Class: 47% survival rate +- 3rd Class: 24% survival rate +``` diff --git a/.github/prompts/plan-codingAgent.prompt.md b/.github/prompts/plan-codingAgent.prompt.md new file mode 100644 index 0000000000..11454da141 --- /dev/null +++ b/.github/prompts/plan-codingAgent.prompt.md @@ -0,0 +1,149 @@ +# Plan: CodingAgent Implementation for ADK-Python + +Create a production-ready `CodingAgent` class that generates Python code to execute tools via ReAct loop, with HTTP-based tool injection, dual-layer security (allowlist + container), configurable statefulness with history re-execution, and full ADK telemetry integration. + +## Steps + +1. **Create `CodingAgentConfig`** in [src/google/adk/agents/coding_agent_config.py](src/google/adk/agents/coding_agent_config.py) - Pydantic config extending `BaseAgentConfig` with fields: `model`, `instruction`, `tools`, `code_executor`, `authorized_imports` (frozenset), `max_iterations` (default 10), `error_retry_attempts` (default 2), `stateful` (default False), `tool_server_host` (default `host.docker.internal`, fallback `172.17.0.1`), `tool_server_port` (default 8765). + +2. **Create `CodingAgentState`** in [src/google/adk/agents/coding_agent.py](src/google/adk/agents/coding_agent.py) - State extending `BaseAgentState` with: `iteration_count`, `error_count`, `execution_history` (list of `ExecutionStep` with `code`, `result`, `tool_traces`, `success` fields for re-execution optimization). + +3. **Create `ToolCodeGenerator`** in [src/google/adk/code_executors/tool_code_generator.py](src/google/adk/code_executors/tool_code_generator.py) - Functions: `generate_runtime_header()` (HTTP client + `_call_adk_tool()` + trace collection), `generate_tool_stubs()` (typed function stubs from `BaseTool._get_declaration()`), `generate_final_answer_stub()`, `generate_system_prompt()` (tool docs + 1-2 few-shot examples showing `tool_code` format + `final_answer()` usage). + +4. 
**Create `AllowlistValidator`** in [src/google/adk/code_executors/allowlist_validator.py](src/google/adk/code_executors/allowlist_validator.py) - `validate_imports()` using AST extraction, `DEFAULT_SAFE_IMPORTS` frozenset, `ImportValidationError` with violation details, `is_import_allowed()` helper supporting wildcards (e.g., `collections.*`). + +5. **Create `ToolExecutionServer`** in [src/google/adk/code_executors/tool_execution_server.py](src/google/adk/code_executors/tool_execution_server.py) - FastAPI server with `POST /tool_call` routing to `BaseTool.run_async()` with full `ToolContext`, `GET /tool_trace`, lifecycle `start()/stop()`, configurable host detection (`host.docker.internal` → `172.17.0.1` fallback). + +6. **Create `CodingAgentCodeExecutor`** in [src/google/adk/code_executors/coding_agent_code_executor.py](src/google/adk/code_executors/coding_agent_code_executor.py) - Composable wrapper with: tool stub prepending, allowlist pre-validation, server lifecycle, history re-execution (skip successful steps via hash comparison), trace extraction (`__TOOL_TRACE__:`), final answer detection (`__FINAL_ANSWER__:`). + +7. **Create `CodingAgent` class** in [src/google/adk/agents/coding_agent.py](src/google/adk/agents/coding_agent.py) - `_run_async_impl()` ReAct loop: build system prompt with few-shot examples, call `canonical_model`, parse code blocks, validate imports, execute via `CodingAgentCodeExecutor`, detect `final_answer()` OR no-code fallback, yield events with `state_delta`, retry errors with LLM feedback up to `error_retry_attempts`. + +8. **Add telemetry** in [src/google/adk/telemetry/tracing.py](src/google/adk/telemetry/tracing.py) - Add `trace_code_generation()`, `trace_code_execution()`, `trace_import_validation()`, `trace_tool_ipc()` following existing patterns with code content, duration, and error attributes. + +9. **Update exports** in [src/google/adk/agents/__init__.py](src/google/adk/agents/__init__.py) and [src/google/adk/code_executors/__init__.py](src/google/adk/code_executors/__init__.py) - Add all new classes to `__all__` with lazy loading for executor components. + +10. **Create comprehensive tests** in `tests/unittests/agents/test_coding_agent.py` and `tests/unittests/code_executors/test_coding_agent_*.py` - Cover: ReAct loop, final answer detection + fallback, allowlist validation, error retry, stateful history re-execution with skip optimization, tool traces, host fallback logic. + +11. **Create sample agent** in `contributing/samples/coding_agent/` - Example with `web_search`, `calculator`, `read_file` tools demonstrating multi-step code generation with `ContainerCodeExecutor`. 
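Step 4's validator is the security-critical piece, so a sketch helps pin down the intent. The following is illustrative only — final names and signatures belong to `allowlist_validator.py` — but it shows the AST extraction and wildcard matching the step calls for:

```python
import ast


class ImportValidationError(Exception):
  """Raised when generated code imports a module outside the allowlist."""


def is_import_allowed(module: str, allowlist: frozenset) -> bool:
  """Exact match, or wildcard match: 'collections.*' covers 'collections.abc'."""
  if module in allowlist:
    return True
  parts = module.split(".")
  return any(
      ".".join(parts[:i]) + ".*" in allowlist for i in range(1, len(parts) + 1)
  )


def validate_imports(code: str, allowlist: frozenset) -> None:
  """Extract imports via AST and reject anything outside the allowlist."""
  violations = set()
  for node in ast.walk(ast.parse(code)):
    if isinstance(node, ast.Import):
      violations.update(
          a.name for a in node.names if not is_import_allowed(a.name, allowlist)
      )
    elif isinstance(node, ast.ImportFrom) and node.module:
      if not is_import_allowed(node.module, allowlist):
        violations.add(node.module)
  if violations:
    raise ImportValidationError(f"Unauthorized imports: {sorted(violations)}")
```

Running this check before any code reaches the container is what gives the defense-in-depth described in the design decisions below.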
+ +## Key Implementation Details + +### Few-shot examples in system prompt + +```python +SYSTEM_PROMPT_EXAMPLES = ''' +Example 1 - Using tools: +```tool_code +result = web_search(query="Python async best practices") +print(result["snippets"][0]) +``` + +Example 2 - Final answer: +```tool_code +data = read_file(path="data.csv") +total = sum(float(row["amount"]) for row in data["rows"]) +final_answer(f"The total amount is ${total:.2f}") +``` +''' +``` + +### History re-execution optimization + +```python +def _should_skip_step(self, step: ExecutionStep, code_hash: str) -> bool: + """Skip if code unchanged and previously succeeded.""" + return step.success and step.code_hash == code_hash +``` + +### Host detection with fallback + +```python +def _resolve_tool_server_host(self) -> str: + if self.tool_server_host: + return self.tool_server_host + # Try host.docker.internal first (Docker Desktop) + # Fallback to 172.17.0.1 (Linux bridge network) + return detect_docker_host_address() +``` + +## Architecture Diagram + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ CodingAgent._run_async_impl() │ +│ ┌─────────────┐ ┌──────────────┐ ┌─────────────────────┐ │ +│ │ Build prompt│──▶│ Call LLM │──▶│ Parse code blocks │ │ +│ │ + tool docs │ │ (canonical_ │ │ (delimiters) │ │ +│ └─────────────┘ │ model) │ └─────────┬───────────┘ │ +│ └──────────────┘ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ CodingAgentCodeExecutor.execute_code() │ │ +│ │ ┌─────────────┐ ┌──────────────┐ ┌───────────────┐ │ │ +│ │ │ Validate │─▶│ Prepend tool │─▶│ Execute in │ │ │ +│ │ │ imports │ │ stubs + │ │ container │ │ │ +│ │ │ (allowlist) │ │ runtime │ │ │ │ │ +│ │ └─────────────┘ └──────────────┘ └───────┬───────┘ │ │ +│ └─────────────────────────────────────────────┼───────────┘ │ +│ │ │ +│ ┌──────────────────────────────────────┘ │ +│ │ HTTP IPC (host.docker.internal) │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ ToolExecutionServer (FastAPI) │ │ +│ │ POST /tool_call ──▶ BaseTool.run_async(ToolContext) │ │ +│ │ GET /tool_trace ──▶ call_traces[] │ │ +│ └─────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────┐ ┌─────────────────────────────────────┐ │ +│ │ Check final_ │◀──│ Extract traces + clean stdout │ │ +│ │ answer() OR │ │ (__TOOL_TRACE__, __FINAL_ANSWER__) │ │ +│ │ fallback │ └─────────────────────────────────────┘ │ +│ └────────┬────────┘ │ +│ │ if done: yield final Event │ +│ │ else: feed result back to LLM (loop) │ +└───────────┴─────────────────────────────────────────────────────┘ +``` + +## File Structure + +``` +src/google/adk/ +├── agents/ +│ ├── coding_agent.py # CodingAgent + CodingAgentState +│ ├── coding_agent_config.py # CodingAgentConfig +│ └── __init__.py # Updated exports +├── code_executors/ +│ ├── tool_code_generator.py # Stub generation + system prompt +│ ├── allowlist_validator.py # Import validation +│ ├── tool_execution_server.py # FastAPI IPC server +│ ├── coding_agent_code_executor.py # Main executor wrapper +│ └── __init__.py # Updated exports +└── telemetry/ + └── tracing.py # New trace functions + +tests/unittests/ +├── agents/ +│ └── test_coding_agent.py +└── code_executors/ + ├── test_tool_code_generator.py + ├── test_allowlist_validator.py + └── test_coding_agent_code_executor.py + +contributing/samples/ +└── coding_agent/ + ├── __init__.py + └── agent.py +``` + +## Design Decisions + +| Decision | Choice | Rationale | +|----------|--------|-----------| +| Tool 
injection | HTTP IPC (Code Prepending) | Native ADK integration, full `ToolContext` access, async support | +| Security | Allowlist + Container | Defense-in-depth: import validation before container isolation | +| Final answer | Explicit `final_answer()` + fallback | Reliability with graceful degradation | +| Stateful mode | Re-execute history | Safer than pickle, with skip optimization for speed | +| Async tools | Sync wrapper via host server | Host handles async natively, container code stays simple | +| Docker host | Configurable with fallback | `host.docker.internal` → `172.17.0.1` for cross-platform | +| Retries | Default 2 with LLM feedback | Matches `BaseCodeExecutor.error_retry_attempts` pattern | diff --git a/.gitignore b/.gitignore index 47f633c5c5..5f8ad01f2e 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,5 @@ CLAUDE.md .rooignore .bolt/ .v0/ + +CODING_AGENT_PLAN.md diff --git a/contributing/samples/coding_agent/README.md b/contributing/samples/coding_agent/README.md new file mode 100644 index 0000000000..cb5ce8b68c --- /dev/null +++ b/contributing/samples/coding_agent/README.md @@ -0,0 +1,290 @@ +# Data Analysis Agent + +A CodingAgent sample that demonstrates AI-powered data analysis with Python code execution. + +## Overview + +This sample showcases the CodingAgent's ability to: + +- Fetch datasets from URLs (CSV, JSON, text) +- Analyze data using pandas +- Create visualizations with matplotlib +- Generate statistical summaries and insights +- Execute multi-step reasoning through code + +The agent writes and executes Python code in a sandboxed Docker container, calling tools via HTTP IPC when needed. + +## Architecture + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ User Query │────>│ CodingAgent │────>│ Docker Container│ +│ │ │ (Gemini 2.5) │ │ (Python 3.11) │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + │ │ + │ │ Executes + v │ pandas/matplotlib + ┌──────────────┐ │ code + │ Tool Server │<────────────────┘ + │ (HTTP IPC) │ Tool calls + └──────────────┘ (fetch_url, etc.) +``` + +### How It Works + +1. User sends a natural language query +2. CodingAgent (powered by Gemini 2.5) generates Python code +3. Code is executed in a sandboxed Docker container +4. Tools (like `fetch_url`) are called via HTTP to the Tool Server on the host +5. Results are returned to the container, LLM iterates if needed +6. Final answer is provided via `final_answer()` function + +## Prerequisites + +- **Docker**: Must be installed and running +- **API Key**: Set `GOOGLE_API_KEY` in `.env` file or environment +- **Python**: 3.10+ (for running ADK CLI) + +## Quick Start + +### 1. Set up your API key + +Create a `.env` file in this directory: + +```bash +echo "GOOGLE_API_KEY=your_api_key_here" > contributing/samples/coding_agent/.env +``` + +### 2. Run the agent + +**Using CLI (interactive):** + +```bash +adk run contributing/samples/coding_agent +``` + +**Using Web UI:** + +```bash +adk web contributing/samples +``` + +Then navigate to `http://localhost:8000` and select `coding_agent`. + +## Available Tools + +### `fetch_url(url: str) -> dict` + +Fetches content from a URL. Supports CSV, JSON, and plain text. 
+ +**Returns:** +- `content`: The fetched content as a string +- `content_type`: MIME type of the content +- `size`: Size in bytes +- `success`: Boolean indicating success +- `error`: Error message (only on failure) + +**Example:** +```python +result = fetch_url("https://example.com/data.csv") +if result["success"]: + csv_content = result["content"] +``` + +### `get_sample_datasets() -> dict` + +Returns available sample datasets with URLs and descriptions. + +**Available datasets:** +- `titanic`: Titanic passenger survival data (891 rows) +- `iris`: Iris flower classification data (150 rows) +- `tips`: Restaurant tipping data (244 rows) + +### `get_current_time() -> dict` + +Returns current date and time information. + +### `save_chart(image_data: str, filename: str) -> dict` + +Saves a chart image to the **host system** (not the Docker container). This is essential for making visualizations accessible outside of Docker. + +**Parameters:** +- `image_data`: Base64-encoded image data (PNG recommended) +- `filename`: Name for the saved file (e.g., "chart.png") + +**Returns:** +- `success`: Boolean indicating success +- `filepath`: Full path where the file was saved (e.g., `/tmp/adk_charts/chart.png`) +- `size`: Size of the saved file in bytes +- `error`: Error message (only on failure) + +**Example usage in generated code:** +```python +import base64 +import io +import matplotlib.pyplot as plt + +# Create your plot +plt.figure(figsize=(10, 6)) +plt.bar(['A', 'B', 'C'], [1, 2, 3]) + +# Save to buffer and encode +buf = io.BytesIO() +plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') +buf.seek(0) +image_data = base64.b64encode(buf.read()).decode('utf-8') +plt.close() + +# Save to host system +result = save_chart(image_data=image_data, filename="my_chart.png") +print(f"Chart saved to: {result['filepath']}") +``` + +### `list_saved_charts() -> dict` + +Lists all charts saved to the host system. + +**Returns:** +- `charts`: List of saved chart filenames +- `directory`: Directory path (`/tmp/adk_charts`) +- `count`: Number of charts found + +## Sample Datasets + +| Dataset | Description | Columns | +|---------|-------------|---------| +| **Titanic** | Passenger survival data from the Titanic disaster | PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked | +| **Iris** | Classic flower classification dataset | sepal_length, sepal_width, petal_length, petal_width, species | +| **Tips** | Restaurant tipping behavior data | total_bill, tip, sex, smoker, day, time, size | + +## Example Interactions + +### Basic Analysis + +``` +User: What is the survival rate on the Titanic? + +Agent: [Fetches Titanic CSV, calculates survival statistics] +Response: The overall survival rate on the Titanic was 38.4% (342 out of 891 passengers survived). +``` + +### Visualization + +``` +User: Create a bar chart showing survival rate by passenger class on the Titanic + +Agent: [Fetches data, creates matplotlib chart, uses save_chart tool] +Response: I've created a bar chart showing survival rates by class: +- 1st Class: 63% survival rate +- 2nd Class: 47% survival rate +- 3rd Class: 24% survival rate +Chart saved to: /tmp/adk_charts/survival_by_class.png +``` + +Charts are saved to `/tmp/adk_charts/` on your host system, making them accessible outside Docker. 
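Under the hood, the generated code for analyses like these typically bridges `fetch_url` into pandas through an in-memory buffer. A representative sketch (illustrative — the actual code is written by the model per query):

```python
import io

import pandas as pd

# fetch_url is the tool stub injected into the sandbox; it returns the CSV body
# as a string, so parse it without touching the container filesystem.
result = fetch_url("https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv")
df = pd.read_csv(io.StringIO(result["content"]))

# 'Survived' is 0/1, so its mean is the survival rate.
print(f"Overall survival rate: {df['Survived'].mean():.1%}")
```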
+ +### Exploratory Analysis + +``` +User: Analyze the iris dataset and give me key insights + +Agent: [Fetches iris data, runs statistical analysis, creates visualizations] +Response: Key insights from the Iris dataset: +1. The dataset contains 150 samples across 3 species (50 each) +2. Setosa is clearly separable from the other two species +3. Petal length is the best single predictor for species classification +... +``` + +### Custom URL + +``` +User: Fetch the tips dataset and tell me which day has the highest average tip + +Agent: [Fetches tips CSV, analyzes by day] +Response: Sunday has the highest average tip at $3.26, followed by Saturday at $2.99. +``` + +## Troubleshooting + +### Docker not running + +``` +Error: Cannot connect to Docker daemon +``` + +**Solution:** Start Docker Desktop or the Docker service. + +### API key issues + +``` +Error: Missing key inputs argument! +``` + +**Solution:** Ensure `GOOGLE_API_KEY` is set in `.env` file or environment. + +### Container networking issues + +``` +Error: Connection refused to tool server +``` + +**Solution:** On macOS/Windows, Docker Desktop should handle `host.docker.internal` automatically. On Linux, you may need to configure the Docker bridge network. + +### Package installation in container + +The agent automatically installs pandas and matplotlib at runtime. If you see import errors, ensure the container has internet access. + +## Extending the Agent + +### Adding Custom Tools + +Add new tool functions in `agent.py`: + +```python +def my_custom_tool(param: str) -> dict: + """Description of what this tool does. + + Args: + param: Description of the parameter. + + Returns: + Dictionary with results. + """ + # Implementation + return {"result": "..."} + +# Add to the tools list +root_agent = CodingAgent( + ... + tools=[fetch_url, get_sample_datasets, get_current_time, my_custom_tool], + ... +) +``` + +### Using a Custom Container Image + +For faster execution with pre-installed packages: + +```python +code_executor=ContainerCodeExecutor( + image="my-custom-image:latest", # Image with pandas/matplotlib pre-installed +) +``` + +### Adjusting Iteration Limits + +```python +root_agent = CodingAgent( + ... + max_iterations=15, # More iterations for complex tasks + error_retry_attempts=3, # More retries on errors + ... +) +``` + +## Related Documentation + +- [CodingAgent Documentation](https://google.github.io/adk-docs/agents/coding-agent) +- [ContainerCodeExecutor](https://google.github.io/adk-docs/code-executors/container) +- [ADK Tools](https://google.github.io/adk-docs/tools) diff --git a/contributing/samples/coding_agent/__init__.py b/contributing/samples/coding_agent/__init__.py new file mode 100644 index 0000000000..373c27ec12 --- /dev/null +++ b/contributing/samples/coding_agent/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample CodingAgent that demonstrates code generation and tool usage.""" + +from . 
import agent diff --git a/contributing/samples/coding_agent/agent.py b/contributing/samples/coding_agent/agent.py new file mode 100644 index 0000000000..e67e141806 --- /dev/null +++ b/contributing/samples/coding_agent/agent.py @@ -0,0 +1,404 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conversational Data Analysis Agent using CodingAgent. + +This sample demonstrates a CodingAgent configured as a conversational data +analyst that can: +- Have multi-turn conversations about data analysis +- Fetch datasets from URLs (CSV, JSON, text) +- Analyze data using pandas +- Create visualizations using matplotlib +- Generate statistical summaries and insights +- Remember context from previous messages + +Prerequisites: +- Docker must be installed and running +- Set GOOGLE_API_KEY or configure Vertex AI credentials + +Usage: + adk run contributing/samples/coding_agent + adk web contributing/samples + +Example conversation: + User: "What datasets do you have available?" + Agent: "I have access to three sample datasets: titanic, iris, and tips..." + + User: "Let's look at the titanic data. What's the survival rate?" + Agent: "I'll analyze the Titanic dataset for you... The overall survival + rate was 38.4%. Would you like me to break this down by gender + or passenger class?" + + User: "Yes, show me by passenger class with a chart" + Agent: "Here's a bar chart showing survival rates by class... First class + passengers had the highest survival rate at 63%..." +""" + +import base64 +import binascii +from datetime import datetime +import os +import socket +import urllib.error +import urllib.request + +from google.adk.agents import CodingAgent +from google.adk.code_executors import ContainerCodeExecutor +from google.adk.code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS + +# Sample dataset URLs +SAMPLE_DATASETS = { + "titanic": { + "url": ( + "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv" + ), + "description": ( + "Titanic passenger data with survival information. 891 passengers" + " with features like age, sex, class, fare, and survival status." + ), + "columns": ( + "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch," + " Ticket, Fare, Cabin, Embarked" + ), + }, + "iris": { + "url": ( + "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv" + ), + "description": ( + "Iris flower dataset. 150 samples of 3 species with sepal and petal" + " measurements." + ), + "columns": ("sepal_length, sepal_width, petal_length, petal_width, species"), + }, + "tips": { + "url": ( + "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/tips.csv" + ), + "description": ( + "Restaurant tips dataset. 244 records with bill amount, tip, and" + " customer info." + ), + "columns": "total_bill, tip, sex, smoker, day, time, size", + }, +} + + +def fetch_url(url: str) -> dict: + """Fetch content from a URL. + + Fetches data from the specified URL and returns the content along with + metadata. 
Supports CSV, JSON, and plain text content. + + Args: + url: The URL to fetch content from. + + Returns: + Dictionary containing: + - content: The fetched content as a string + - content_type: The MIME type of the content + - size: Size of the content in bytes + - url: The original URL + - success: Whether the fetch was successful + - error: Error message if fetch failed (only present on failure) + """ + try: + req = urllib.request.Request( + url, + headers={"User-Agent": "Mozilla/5.0 (compatible; ADK-DataAnalyst/1.0)"}, + ) + with urllib.request.urlopen(req, timeout=30) as response: + content = response.read().decode("utf-8") + content_type = response.headers.get("Content-Type", "text/plain") + return { + "content": content, + "content_type": content_type, + "size": len(content), + "url": url, + "success": True, + } + except (urllib.error.URLError, socket.timeout) as e: + return { + "content": "", + "url": url, + "success": False, + "error": f"Failed to fetch URL: {str(e)}", + } + except (UnicodeDecodeError, OSError) as e: + return { + "content": "", + "url": url, + "success": False, + "error": f"Failed to read or decode response: {str(e)}", + } + + +def get_sample_datasets() -> dict: + """Get available sample datasets with their URLs and descriptions. + + Returns a dictionary of sample datasets that can be used for analysis. + Each dataset includes a URL, description, and column information. + + Returns: + Dictionary with dataset names as keys, each containing: + - url: Direct URL to download the CSV file + - description: Brief description of the dataset + - columns: Comma-separated list of column names + """ + return SAMPLE_DATASETS + + +def get_current_time() -> dict: + """Get the current date and time. + + Returns: + Dictionary containing: + - timestamp: ISO format timestamp + - year, month, day: Date components + - hour, minute, second: Time components + - weekday: Name of the day of the week + """ + now = datetime.now() + return { + "timestamp": now.isoformat(), + "year": now.year, + "month": now.month, + "day": now.day, + "hour": now.hour, + "minute": now.minute, + "second": now.second, + "weekday": now.strftime("%A"), + } + + +# Directory on host system to save charts +HOST_CHARTS_DIR = "/tmp/adk_charts" + + +def save_chart(image_data: str, filename: str) -> dict: + """Save a chart image to the host system. + + This tool saves base64-encoded image data to the host machine's filesystem, + making charts accessible outside the Docker container. + + To use this tool, first save your matplotlib figure to a bytes buffer, + then encode it as base64: + + Example: + import base64 + import io + import matplotlib.pyplot as plt + + # Create your plot + plt.figure() + plt.plot([1, 2, 3], [1, 4, 9]) + + # Save to buffer and encode + buf = io.BytesIO() + plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') + buf.seek(0) + image_data = base64.b64encode(buf.read()).decode('utf-8') + plt.close() + + # Save to host system + result = save_chart(image_data=image_data, filename="my_chart.png") + + Args: + image_data: Base64-encoded image data (PNG format recommended). + filename: Name for the saved file (e.g., "chart.png"). 
+ + Returns: + Dictionary containing: + - success: Whether the save was successful + - filepath: Full path where the file was saved on the host + - size: Size of the saved file in bytes + - error: Error message if save failed (only present on failure) + """ + try: + # Ensure the output directory exists + os.makedirs(HOST_CHARTS_DIR, exist_ok=True) + + # Sanitize filename + safe_filename = os.path.basename(filename) + if not safe_filename: + safe_filename = "chart.png" + + filepath = os.path.join(HOST_CHARTS_DIR, safe_filename) + + # Decode and save + image_bytes = base64.b64decode(image_data) + with open(filepath, "wb") as f: + f.write(image_bytes) + + return { + "success": True, + "filepath": filepath, + "size": len(image_bytes), + "message": f"Chart saved to {filepath}", + } + except binascii.Error as e: + return { + "success": False, + "error": f"Invalid base64 data: {str(e)}", + } + except OSError as e: + return { + "success": False, + "error": f"Failed to save file: {str(e)}", + } + + +def list_saved_charts() -> dict: + """List all charts saved on the host system. + + Returns: + Dictionary containing: + - success: Whether the operation was successful + - charts: List of saved chart filenames + - directory: The directory where charts are saved + - count: Number of charts found + """ + try: + if not os.path.exists(HOST_CHARTS_DIR): + return { + "success": True, + "charts": [], + "directory": HOST_CHARTS_DIR, + "count": 0, + } + + charts = [ + f + for f in os.listdir(HOST_CHARTS_DIR) + if f.lower().endswith((".png", ".jpg", ".jpeg", ".svg", ".pdf")) + ] + return { + "success": True, + "charts": charts, + "directory": HOST_CHARTS_DIR, + "count": len(charts), + } + except OSError as e: + return { + "success": False, + "error": f"Failed to list charts: {str(e)}", + } + + +# Additional imports allowed for data analysis +DATA_ANALYSIS_IMPORTS = frozenset( + { + # Data analysis + "pandas", + "pandas.*", + "numpy", + "numpy.*", + # Visualization + "matplotlib", + "matplotlib.*", + "seaborn", + "seaborn.*", + # Data I/O + "csv", + "io", + "io.*", + # Encoding for chart saving + "base64", + # Subprocess for pip installs + "subprocess", + } +) + + +# Create the Data Analysis Agent +root_agent = CodingAgent( + name="data_analyst", + description=( + "An AI data analyst that analyzes datasets, creates visualizations, " + "and generates insights using Python code execution." + ), + model="gemini-2.5-flash", + instruction="""You are a friendly, conversational data analyst assistant. You help users analyze datasets, create visualizations, and generate insights using Python code execution. + +## Your Personality +- Be conversational and engaging - ask clarifying questions when needed +- Explain your analysis in plain language that anyone can understand +- Offer suggestions for follow-up analyses or visualizations +- Remember context from previous messages in our conversation + +## When the user asks for analysis or visualization: + +1. First, install required packages (only needed once per session): +```tool_code +import subprocess +subprocess.run(["pip", "install", "-q", "pandas", "matplotlib", "seaborn", "numpy"], check=True) +print("Packages installed successfully") +``` + +2. Use the available tools: + - `get_sample_datasets()` - See available sample datasets (titanic, iris, tips) + - `fetch_url(url)` - Fetch data from any URL + - `save_chart(image_data, filename)` - Save visualizations + - `list_saved_charts()` - See previously saved charts + +3. 
When creating charts, ALWAYS use this pattern to save them: +```tool_code +import base64 +import io +import matplotlib.pyplot as plt + +# Create your plot +plt.figure(figsize=(10, 6)) +# ... your plotting code ... + +# Save to buffer and encode as base64 +buf = io.BytesIO() +plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') +buf.seek(0) +image_data = base64.b64encode(buf.read()).decode('utf-8') +plt.close() + +# Save to host system +result = save_chart(image_data=image_data, filename="descriptive_name.png") +print(f"Chart saved: {result}") +``` + +## Response Style +- After completing an analysis, summarize your findings in a conversational way +- Mention the chart location if you created one (charts are saved to /tmp/adk_charts/) +- Ask if the user would like to explore the data further or see different visualizations +- If you don't have enough information, ask questions before diving into code + +## Examples of good responses: +- "I found some interesting patterns! The survival rate was 38.4%. Would you like me to break this down by gender or passenger class?" +- "Here's a bar chart showing the distribution. I saved it to /tmp/adk_charts/survival_by_class.png. What aspect would you like to explore next?" +- "Before I create that visualization, could you tell me which columns you're most interested in comparing?" + +Remember: You're having a conversation, not just executing tasks. Engage with the user! +""", + tools=[ + fetch_url, + get_sample_datasets, + get_current_time, + save_chart, + list_saved_charts, + ], + code_executor=ContainerCodeExecutor( + image="python:3.11-slim", + ), + authorized_imports=DEFAULT_SAFE_IMPORTS | DATA_ANALYSIS_IMPORTS, + max_iterations=10, + error_retry_attempts=2, + stateful=True, +) diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py index 35198179a5..b718513135 100644 --- a/src/google/adk/agents/__init__.py +++ b/src/google/adk/agents/__init__.py @@ -13,6 +13,8 @@ # limitations under the License. from .base_agent import BaseAgent +from .coding_agent import CodingAgent +from .coding_agent import CodingAgentState from .invocation_context import InvocationContext from .live_request_queue import LiveRequest from .live_request_queue import LiveRequestQueue @@ -25,15 +27,17 @@ from .sequential_agent import SequentialAgent __all__ = [ - 'Agent', - 'BaseAgent', - 'LlmAgent', - 'LoopAgent', - 'McpInstructionProvider', - 'ParallelAgent', - 'SequentialAgent', - 'InvocationContext', - 'LiveRequest', - 'LiveRequestQueue', - 'RunConfig', + "Agent", + "BaseAgent", + "CodingAgent", + "CodingAgentState", + "LlmAgent", + "LoopAgent", + "McpInstructionProvider", + "ParallelAgent", + "SequentialAgent", + "InvocationContext", + "LiveRequest", + "LiveRequestQueue", + "RunConfig", ] diff --git a/src/google/adk/agents/coding_agent.py b/src/google/adk/agents/coding_agent.py new file mode 100644 index 0000000000..d540ad4ba6 --- /dev/null +++ b/src/google/adk/agents/coding_agent.py @@ -0,0 +1,660 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""CodingAgent - An agent that generates and executes Python code. + +This module provides the CodingAgent class, which implements a ReAct-style +agent that generates Python code to accomplish tasks using available tools. +""" + +from __future__ import annotations + +import logging +import re +import time +import uuid +from typing import Any +from typing import AsyncGenerator +from typing import Callable +from typing import ClassVar +from typing import Dict +from typing import FrozenSet +from typing import List +from typing import Optional +from typing import Type +from typing import Union + +from google.genai import types +from pydantic import Field +from pydantic import model_validator +from typing_extensions import override + +from ..code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS +from ..code_executors.base_code_executor import BaseCodeExecutor +from ..code_executors.code_execution_utils import CodeExecutionInput +from ..code_executors.code_execution_utils import CodeExecutionResult +from ..code_executors.code_execution_utils import CodeExecutionUtils +from ..code_executors.coding_agent_code_executor import CodingAgentCodeExecutor +from ..code_executors.coding_agent_code_executor import CodingAgentExecutionResult +from ..code_executors.tool_code_generator import generate_system_prompt +from ..events.event import Event +from ..events.event_actions import EventActions +from ..models.base_llm import BaseLlm +from ..models.llm_request import LlmRequest +from ..models.llm_response import LlmResponse +from ..models.registry import LLMRegistry +from ..telemetry.tracing import trace_call_llm +from ..telemetry.tracing import trace_code_execution +from ..telemetry.tracing import trace_code_generation +from ..telemetry.tracing import tracer +from ..tools.base_tool import BaseTool +from ..tools.base_toolset import BaseToolset +from ..tools.function_tool import FunctionTool +from ..tools.tool_context import ToolContext +from ..utils.feature_decorator import experimental +from .base_agent import BaseAgent +from .base_agent import BaseAgentState +from .base_agent_config import BaseAgentConfig +from .coding_agent_config import CodingAgentConfig +from .invocation_context import InvocationContext +from .readonly_context import ReadonlyContext + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class CodingAgentState(BaseAgentState): + """State for CodingAgent tracking execution progress. + + Attributes: + iteration_count: Number of ReAct loop iterations completed. + error_count: Number of consecutive errors encountered. + execution_history: List of execution steps with code, results, and traces. + """ + + iteration_count: int = 0 + error_count: int = 0 + execution_history: List[Dict[str, Any]] = Field(default_factory=list) + + +ToolUnion = Union[Callable[..., Any], BaseTool, BaseToolset] + + +async def _convert_tool_union_to_tools( + tool_union: ToolUnion, + ctx: Optional[ReadonlyContext] = None, +) -> List[BaseTool]: + """Convert a tool union to a list of BaseTool instances. + + Args: + tool_union: A callable, BaseTool, or BaseToolset. + ctx: Optional context for toolset resolution. + + Returns: + List of BaseTool instances. 
+ """ + if isinstance(tool_union, BaseTool): + return [tool_union] + if callable(tool_union): + return [FunctionTool(func=tool_union)] + # BaseToolset + if ctx: + return await tool_union.get_tools_with_prefix(ctx) + return await tool_union.get_tools_with_prefix(None) + + +@experimental +class CodingAgent(BaseAgent): + """Agent that generates Python code to solve tasks using available tools. + + CodingAgent implements a ReAct-style loop where it: + 1. Receives a task from the user + 2. Generates Python code that calls available tools + 3. Executes the code in a sandboxed environment + 4. Processes the results and either provides an answer or continues + + Tools are made available as Python functions that the generated code + can call. The code execution happens in a container for security, + with tool calls routed via HTTP to the host. + + Attributes: + model: The LLM model to use for code generation. + instruction: Additional instructions for the agent. + tools: List of tools available to the agent. + code_executor: The underlying code executor (e.g., ContainerCodeExecutor). + authorized_imports: Set of allowed Python imports. + max_iterations: Maximum ReAct loop iterations. + error_retry_attempts: Number of retries on execution errors. + stateful: Whether to maintain state across iterations. + tool_server_host: Host for the tool execution server. + tool_server_port: Port for the tool execution server. + """ + + DEFAULT_MODEL: ClassVar[str] = "gemini-2.5-flash" + + config_type: ClassVar[Type[BaseAgentConfig]] = CodingAgentConfig + + model: Union[str, BaseLlm] = "" + """The model to use for code generation.""" + + instruction: str = "" + """Additional instructions for the agent.""" + + tools: List[ToolUnion] = Field(default_factory=list) + """Tools available to the agent.""" + + code_executor: Optional[BaseCodeExecutor] = None + """The underlying code executor. If not set, uses ContainerCodeExecutor.""" + + authorized_imports: FrozenSet[str] = DEFAULT_SAFE_IMPORTS + """Set of allowed import patterns.""" + + max_iterations: int = 10 + """Maximum number of ReAct loop iterations.""" + + error_retry_attempts: int = 2 + """Number of retries on execution errors.""" + + stateful: bool = False + """Whether to maintain state across iterations.""" + + tool_server_host: Optional[str] = None + """Host for the tool execution server.""" + + tool_server_port: int = 8765 + """Port for the tool execution server.""" + + # Internal state + _coding_executor: Optional[CodingAgentCodeExecutor] = None + _resolved_tools: Optional[List[BaseTool]] = None + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + @property + def canonical_model(self) -> BaseLlm: + """Get the resolved model as BaseLlm.""" + if isinstance(self.model, BaseLlm): + return self.model + elif self.model: + return LLMRegistry.new_llm(self.model) + else: + # Find model from ancestors + ancestor_agent = self.parent_agent + while ancestor_agent is not None: + if hasattr(ancestor_agent, "canonical_model"): + return ancestor_agent.canonical_model + ancestor_agent = ancestor_agent.parent_agent + return LLMRegistry.new_llm(self.DEFAULT_MODEL) + + async def _resolve_tools( + self, + ctx: Optional[ReadonlyContext] = None, + ) -> List[BaseTool]: + """Resolve tool unions to BaseTool instances. + + Args: + ctx: Optional context for toolset resolution. + + Returns: + List of resolved BaseTool instances. 
+ """ + if self._resolved_tools is not None: + return self._resolved_tools + + resolved = [] + for tool_union in self.tools: + resolved.extend(await _convert_tool_union_to_tools(tool_union, ctx)) + + self._resolved_tools = resolved + return resolved + + async def _get_coding_executor( + self, + ctx: InvocationContext, + ) -> CodingAgentCodeExecutor: + """Get or create the CodingAgentCodeExecutor. + + Args: + ctx: The invocation context. + + Returns: + The configured code executor. + """ + if self._coding_executor is not None: + return self._coding_executor + + # Resolve tools + tools = await self._resolve_tools(ReadonlyContext(ctx)) + + # Get or create underlying executor + if self.code_executor: + underlying = self.code_executor + else: + # Default to ContainerCodeExecutor + try: + from ..code_executors.container_code_executor import ( + ContainerCodeExecutor, + ) + + underlying = ContainerCodeExecutor( + image="python:3.11-slim", + ) + except ImportError as e: + raise ImportError( + "CodingAgent requires ContainerCodeExecutor. " + 'Please install with: pip install "google-adk[extensions]" ' + "or provide a custom code_executor." + ) from e + + # Create the CodingAgentCodeExecutor wrapper + self._coding_executor = CodingAgentCodeExecutor( + underlying_executor=underlying, + tools=tools, + authorized_imports=self.authorized_imports, + tool_server_host=self.tool_server_host, + tool_server_port=self.tool_server_port, + stateful=self.stateful, + error_retry_attempts=self.error_retry_attempts, + ) + + return self._coding_executor + + def _build_system_prompt(self, tools: List[BaseTool]) -> str: + """Build the system prompt with tool documentation. + + Args: + tools: List of available tools. + + Returns: + The complete system prompt. + """ + return generate_system_prompt( + tools=tools, + custom_instruction=self.instruction, + ) + + def _extract_code_block(self, response_text: str) -> Optional[str]: + """Extract code from the model response. + + Args: + response_text: The model's response text. + + Returns: + The extracted code, or None if no code block found. + """ + # Try tool_code blocks first + pattern = r"```tool_code\n(.*?)```" + match = re.search(pattern, response_text, re.DOTALL) + if match: + return match.group(1).strip() + + # Fall back to python blocks + pattern = r"```python\n(.*?)```" + match = re.search(pattern, response_text, re.DOTALL) + if match: + return match.group(1).strip() + + return None + + def _is_real_error(self, stderr: str) -> bool: + """Check if stderr contains a real error vs just warnings. + + Args: + stderr: The stderr output from code execution. + + Returns: + True if stderr contains a real error, False if just warnings. 
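+
+    Example (illustrative):
+
+      _is_real_error("DeprecationWarning: foo is deprecated")  # -> False
+      _is_real_error("Traceback ... NameError: name 'x' is not defined")  # -> True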
+ """ + if not stderr: + return False + + # Patterns that indicate this is just a warning, not an error + warning_patterns = [ + "WARNING: Running pip as the 'root' user", + "[notice] A new release of pip", + "[notice] To update, run:", + "pip install --upgrade pip", + "UserWarning:", + "DeprecationWarning:", + "FutureWarning:", + "RuntimeWarning:", + ] + + # Check if ALL lines are just warnings + lines = stderr.strip().split("\n") + real_error_lines = [] + for line in lines: + line_stripped = line.strip() + if not line_stripped: + continue + is_warning = any( + pattern.lower() in line_stripped.lower() for pattern in warning_patterns + ) + if not is_warning: + real_error_lines.append(line) + + # Also check for actual error keywords + error_keywords = [ + "error:", + "traceback", + "exception", + "syntaxerror", + "nameerror", + "typeerror", + "valueerror", + "importerror", + "modulenotfounderror", + "attributeerror", + "keyerror", + "indexerror", + "zerodivisionerror", + ] + + stderr_lower = stderr.lower() + has_error_keyword = any(keyword in stderr_lower for keyword in error_keywords) + + # Consider it a real error if there are non-warning lines with error keywords + return bool(real_error_lines) and has_error_keyword + + def _build_error_feedback( + self, + error: str, + code: str, + ) -> str: + """Build feedback message for execution errors. + + Args: + error: The error message. + code: The code that caused the error. + + Returns: + Formatted error feedback for the LLM. + """ + return f"""The code execution failed with the following error: + +``` +{error} +``` + +The code that failed was: +```python +{code} +``` + +Please fix the error and try again. Common issues: +- Unauthorized imports (only use allowed imports) +- Tool call errors (check the tool documentation) +- Python syntax errors +""" + + @override + async def _run_async_impl( + self, + ctx: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Core implementation of the ReAct loop. + + Args: + ctx: The invocation context. + + Yields: + Events generated during execution. 
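+
+    Note (behavioral sketch): a single final event is yielded after the
+    loop finishes; intermediate code and results are recorded in the agent
+    state's ``execution_history`` rather than emitted as events.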
+ """ + # Load or initialize state + state = self._load_agent_state(ctx, CodingAgentState) + if state is None: + state = CodingAgentState() + + # Resolve tools and get executor + tools = await self._resolve_tools(ReadonlyContext(ctx)) + coding_executor = await self._get_coding_executor(ctx) + + # Create tool context for the executor + tool_context = ToolContext(invocation_context=ctx) + coding_executor.set_context(ctx, tool_context) + + # Build system prompt + system_prompt = self._build_system_prompt(tools) + + # Get the model + model = self.canonical_model + + # Build initial request with conversation history + contents = [] + events = ctx._get_events(current_invocation=True, current_branch=True) + for event in events: + if event.content: + contents.append(event.content) + + iteration = 0 + error_count = 0 + final_answer = None + + while iteration < self.max_iterations: + iteration += 1 + state.iteration_count = iteration + + # Build LLM request + llm_request = LlmRequest( + model=model.model, + contents=contents, + config=types.GenerateContentConfig( + system_instruction=system_prompt, + ), + ) + + # Generate event ID for tracing + event_id = f"coding_agent_llm_{uuid.uuid4().hex[:8]}" + + # Call the model with tracing (generate_content_async returns an async generator) + llm_response = None + generation_start = time.time() + with tracer.start_as_current_span("call_llm"): + async for response in model.generate_content_async( + llm_request, stream=False + ): + llm_response = response + break + + # Record trace for the LLM call + if llm_response: + trace_call_llm( + ctx, + event_id, + llm_request, + llm_response, + ) + + generation_duration_ms = (time.time() - generation_start) * 1000 + + # Extract response text + response_text = "" + if llm_response and llm_response.content and llm_response.content.parts: + response_text = "".join( + part.text for part in llm_response.content.parts if part.text + ) + + # Check for code block + code = self._extract_code_block(response_text) + + if not code: + # No code generated - treat as final response + # Check if the response looks like a final answer + final_answer = response_text + break + + # Trace code generation + with tracer.start_as_current_span("generate_code"): + trace_code_generation( + agent_name=self.name, + code=code, + iteration=iteration, + duration_ms=generation_duration_ms, + ) + + # Execute the code with tracing + code_input = CodeExecutionInput(code=code) + execution_start = time.time() + with tracer.start_as_current_span("execute_code"): + exec_result = coding_executor.execute_code_extended( + invocation_context=ctx, + code_execution_input=code_input, + ) + execution_duration_ms = (time.time() - execution_start) * 1000 + + # Trace code execution + trace_code_execution( + agent_name=self.name, + code=code, + stdout=exec_result.clean_stdout, + stderr=exec_result.code_result.stderr or "", + duration_ms=execution_duration_ms, + success=not self._is_real_error( + exec_result.code_result.stderr or "" + ), + has_final_answer=exec_result.has_final_answer, + ) + + # Record execution in state + state.execution_history.append( + { + "iteration": iteration, + "code": code, + "stdout": exec_result.clean_stdout, + "stderr": exec_result.code_result.stderr, + "tool_traces": exec_result.tool_traces, + "has_final_answer": exec_result.has_final_answer, + } + ) + + # Check for errors - ignore warnings from pip and other non-fatal stderr + stderr = exec_result.code_result.stderr or "" + is_real_error = self._is_real_error(stderr) + + if is_real_error: + 
error_count += 1 + state.error_count = error_count + + if error_count > self.error_retry_attempts: + # Too many errors - give up + final_answer = ( + "I encountered too many errors while executing code. " + f"Last error: {stderr}" + ) + break + + # Build error feedback and add to conversation + error_feedback = self._build_error_feedback( + stderr, + code, + ) + contents.append( + types.Content( + role="model", + parts=[types.Part(text=response_text)], + ) + ) + contents.append( + types.Content( + role="user", + parts=[types.Part(text=error_feedback)], + ) + ) + continue + + # Reset error count on success + error_count = 0 + state.error_count = 0 + + # Check for final answer + if exec_result.has_final_answer: + final_answer = exec_result.final_answer + break + + # Add execution result to conversation and continue + contents.append( + types.Content( + role="model", + parts=[types.Part(text=response_text)], + ) + ) + + # Add execution output as user message + output_text = f"""Code execution result: +``` +{exec_result.clean_stdout} +``` +""" + contents.append( + types.Content( + role="user", + parts=[types.Part(text=output_text)], + ) + ) + + # Build final event + if final_answer is None: + final_answer = ( + "I was unable to complete the task within the allowed iterations." + ) + + # Convert final_answer to string if needed + if not isinstance(final_answer, str): + import json + + try: + final_answer = json.dumps(final_answer) + except (TypeError, ValueError): + final_answer = str(final_answer) + + # Update state in context + ctx.agent_states[self.name] = state.model_dump() + + # Yield final event + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + branch=ctx.branch, + content=types.Content( + role="model", + parts=[types.Part(text=final_answer)], + ), + actions=EventActions( + agent_state=state.model_dump(), + ), + ) + + @model_validator(mode="after") + def _validate_model(self) -> CodingAgent: + """Validate the model after construction.""" + return self + + def cleanup(self) -> None: + """Clean up resources.""" + if self._coding_executor: + self._coding_executor.cleanup() + self._coding_executor = None + self._resolved_tools = None + + def __enter__(self) -> "CodingAgent": + """Enter context manager and return self.""" + return self + + def __exit__(self, exc_type, exc, traceback) -> bool: + """Exit context manager and clean up resources.""" + self.cleanup() + return False diff --git a/src/google/adk/agents/coding_agent_config.py b/src/google/adk/agents/coding_agent_config.py new file mode 100644 index 0000000000..89e42f1691 --- /dev/null +++ b/src/google/adk/agents/coding_agent_config.py @@ -0,0 +1,175 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
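+
+"""Config for CodingAgent."""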
+ +from __future__ import annotations + +from typing import FrozenSet +from typing import List +from typing import Literal +from typing import Optional +from typing import Union + +from pydantic import Field + +from ..code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS +from ..tools.tool_configs import ToolConfig +from ..utils.feature_decorator import experimental +from .base_agent_config import BaseAgentConfig +from .common_configs import CodeConfig + +# Additional data science imports commonly used in coding agents +_DATA_SCIENCE_IMPORTS: FrozenSet[str] = frozenset({ + "numpy", + "numpy.*", + "pandas", + "pandas.*", + "scipy", + "scipy.*", + "matplotlib", + "matplotlib.*", +}) + +# Extended safe imports including common data science packages +_EXTENDED_SAFE_IMPORTS: FrozenSet[str] = ( + DEFAULT_SAFE_IMPORTS | _DATA_SCIENCE_IMPORTS +) + + +@experimental +class CodingAgentConfig(BaseAgentConfig): + """Configuration for CodingAgent. + + This config extends BaseAgentConfig with fields specific to agents that + generate and execute Python code to accomplish tasks using tools. + """ + + agent_class: Union[Literal["CodingAgent"], str] = Field( + default="CodingAgent", + description="The class of the agent. Must be CodingAgent.", + ) + + model: str = Field( + default="", + description=( + "The model to use for the agent. When not set, the agent will " + "inherit the model from its ancestor or use the default model." + ), + ) + + model_code: Optional[CodeConfig] = Field( + default=None, + description=( + "Optional. Code reference to a custom model instance. " + "Takes precedence over the model field if both are set." + ), + ) + + instruction: str = Field( + default="", + description=( + "Dynamic instructions for the agent, guiding its behavior. " + "Can contain placeholders like {variable_name} that will be " + "resolved at runtime using session state and context." + ), + ) + + tools: Optional[List[ToolConfig]] = Field( + default=None, + description=( + "Optional. The list of tools available to the agent. " + "Tools are exposed as Python functions that the agent can call " + "in the generated code." + ), + ) + + code_executor: Optional[CodeConfig] = Field( + default=None, + description=( + "Optional. Code reference to a custom code executor instance. " + "If not set, a default ContainerCodeExecutor will be used." + ), + ) + + authorized_imports: FrozenSet[str] = Field( + default=_EXTENDED_SAFE_IMPORTS, + description=( + "Set of allowed import names/patterns. Supports wildcards " + '(e.g., "collections.*" allows all collections submodules). ' + "Any imports not in this set will be rejected before execution." + ), + ) + + max_iterations: int = Field( + default=10, + ge=1, + le=100, + description=( + "Maximum number of ReAct loop iterations. Each iteration " + "involves generating code, executing it, and processing results." + ), + ) + + error_retry_attempts: int = Field( + default=2, + ge=0, + le=10, + description=( + "Number of times to retry code execution on errors. " + "Error messages are fed back to the LLM for correction." + ), + ) + + stateful: bool = Field( + default=False, + description=( + "Whether to maintain state across iterations. If True, " + "execution history is preserved and re-executed to restore state." + ), + ) + + tool_server_host: Optional[str] = Field( + default=None, + description=( + "Host address for the tool execution server. If not set, " + "auto-detection will try host.docker.internal first, " + "then fall back to 172.17.0.1 for Linux." 
+ ), + ) + + tool_server_port: int = Field( + default=8765, + ge=1024, + le=65535, + description="Port for the tool execution server.", + ) + + before_model_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called before calling the LLM.", + ) + + after_model_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called after calling the LLM.", + ) + + before_tool_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called before calling a tool.", + ) + + after_tool_callbacks: Optional[List[CodeConfig]] = Field( + default=None, + description="Optional. Callbacks to be called after calling a tool.", + ) diff --git a/src/google/adk/code_executors/__init__.py b/src/google/adk/code_executors/__init__.py index 1cf04a477d..5e5fa7d1e7 100644 --- a/src/google/adk/code_executors/__init__.py +++ b/src/google/adk/code_executors/__init__.py @@ -21,59 +21,118 @@ from .code_executor_context import CodeExecutorContext from .unsafe_local_code_executor import UnsafeLocalCodeExecutor -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) __all__ = [ - 'BaseCodeExecutor', - 'BuiltInCodeExecutor', - 'CodeExecutorContext', - 'UnsafeLocalCodeExecutor', - 'VertexAiCodeExecutor', - 'ContainerCodeExecutor', - 'GkeCodeExecutor', - 'AgentEngineSandboxCodeExecutor', + "BaseCodeExecutor", + "BuiltInCodeExecutor", + "CodeExecutorContext", + "UnsafeLocalCodeExecutor", + "VertexAiCodeExecutor", + "ContainerCodeExecutor", + "GkeCodeExecutor", + "AgentEngineSandboxCodeExecutor", + # CodingAgent components + "AllowlistValidator", + "CodingAgentCodeExecutor", + "ToolCodeGenerator", + "ToolExecutionServer", ] def __getattr__(name: str): - if name == 'VertexAiCodeExecutor': - try: - from .vertex_ai_code_executor import VertexAiCodeExecutor - - return VertexAiCodeExecutor - except ImportError as e: - raise ImportError( - 'VertexAiCodeExecutor requires additional dependencies. ' - 'Please install with: pip install "google-adk[extensions]"' - ) from e - elif name == 'ContainerCodeExecutor': - try: - from .container_code_executor import ContainerCodeExecutor - - return ContainerCodeExecutor - except ImportError as e: - raise ImportError( - 'ContainerCodeExecutor requires additional dependencies. ' - 'Please install with: pip install "google-adk[extensions]"' - ) from e - elif name == 'GkeCodeExecutor': - try: - from .gke_code_executor import GkeCodeExecutor - - return GkeCodeExecutor - except ImportError as e: - raise ImportError( - 'GkeCodeExecutor requires additional dependencies. ' - 'Please install with: pip install "google-adk[extensions]"' - ) from e - elif name == 'AgentEngineSandboxCodeExecutor': - try: - from .agent_engine_sandbox_code_executor import AgentEngineSandboxCodeExecutor - - return AgentEngineSandboxCodeExecutor - except ImportError as e: - raise ImportError( - 'AgentEngineSandboxCodeExecutor requires additional dependencies. ' - 'Please install with: pip install "google-adk[extensions]"' - ) from e - raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + if name == "VertexAiCodeExecutor": + try: + from .vertex_ai_code_executor import VertexAiCodeExecutor + + return VertexAiCodeExecutor + except ImportError as e: + raise ImportError( + "VertexAiCodeExecutor requires additional dependencies. 
" + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "ContainerCodeExecutor": + try: + from .container_code_executor import ContainerCodeExecutor + + return ContainerCodeExecutor + except ImportError as e: + raise ImportError( + "ContainerCodeExecutor requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "GkeCodeExecutor": + try: + from .gke_code_executor import GkeCodeExecutor + + return GkeCodeExecutor + except ImportError as e: + raise ImportError( + "GkeCodeExecutor requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "AgentEngineSandboxCodeExecutor": + try: + from .agent_engine_sandbox_code_executor import ( + AgentEngineSandboxCodeExecutor, + ) + + return AgentEngineSandboxCodeExecutor + except ImportError as e: + raise ImportError( + "AgentEngineSandboxCodeExecutor requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "AllowlistValidator": + try: + from .allowlist_validator import AllowlistValidator + + return AllowlistValidator + except ImportError as e: + raise ImportError( + "AllowlistValidator requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "CodingAgentCodeExecutor": + try: + from .coding_agent_code_executor import CodingAgentCodeExecutor + + return CodingAgentCodeExecutor + except ImportError as e: + raise ImportError( + "CodingAgentCodeExecutor requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "ToolCodeGenerator": + try: + from .tool_code_generator import generate_full_code_with_stubs + from .tool_code_generator import generate_runtime_header + from .tool_code_generator import generate_system_prompt + from .tool_code_generator import generate_tool_stubs + + # Return module-like object with functions + class ToolCodeGenerator: + generate_full_code_with_stubs = staticmethod( + generate_full_code_with_stubs + ) + generate_runtime_header = staticmethod(generate_runtime_header) + generate_system_prompt = staticmethod(generate_system_prompt) + generate_tool_stubs = staticmethod(generate_tool_stubs) + + return ToolCodeGenerator + except ImportError as e: + raise ImportError( + "ToolCodeGenerator requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + elif name == "ToolExecutionServer": + try: + from .tool_execution_server import ToolExecutionServer + + return ToolExecutionServer + except ImportError as e: + raise ImportError( + "ToolExecutionServer requires additional dependencies. " + 'Please install with: pip install "google-adk[extensions]"' + ) from e + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/src/google/adk/code_executors/allowlist_validator.py b/src/google/adk/code_executors/allowlist_validator.py new file mode 100644 index 0000000000..57afc0e534 --- /dev/null +++ b/src/google/adk/code_executors/allowlist_validator.py @@ -0,0 +1,353 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Import allowlist validation for code execution security.""" + +from __future__ import annotations + +import ast +from dataclasses import dataclass +from dataclasses import field +import fnmatch +import logging +from typing import FrozenSet +from typing import List +from typing import Set + +logger = logging.getLogger("google_adk." + __name__) + + +# Default set of safe imports that are always allowed +DEFAULT_SAFE_IMPORTS: FrozenSet[str] = frozenset({ + # Standard library - safe modules + "json", + "math", + "re", + "datetime", + "collections", + "collections.*", + "itertools", + "functools", + "operator", + "string", + "textwrap", + "unicodedata", + "decimal", + "fractions", + "random", + "statistics", + "typing", + "typing.*", + "dataclasses", + "enum", + "abc", + "copy", + "pprint", + "reprlib", + "numbers", + "cmath", + "time", + "calendar", + "hashlib", + "hmac", + "base64", + "binascii", + "html", + "html.*", + "urllib.parse", + "uuid", + "struct", + "codecs", + "locale", + "gettext", + "bisect", + "heapq", + "array", + "weakref", + "types", + "contextlib", + "warnings", + "traceback", + "linecache", + "difflib", + "graphlib", + "zoneinfo", +}) + + +class ImportValidationError(Exception): + """Exception raised when import validation fails. + + Attributes: + violations: List of import violations found. + code: The code that was validated. + """ + + def __init__( + self, + violations: List[str], + code: str, + ): + self.violations = violations + self.code = code + violation_str = "\n".join(f" - {v}" for v in violations) + super().__init__( + "Import validation failed. Unauthorized imports" + f" found:\n{violation_str}" + ) + + +@dataclass +class ImportInfo: + """Information about an import statement. + + Attributes: + module: The module being imported. + names: Names being imported from the module (for 'from' imports). + alias: Alias for the import (if any). + line_number: Line number in the source code. + is_from_import: Whether this is a 'from X import Y' style import. + """ + + module: str + names: List[str] = field(default_factory=list) + alias: str = "" + line_number: int = 0 + is_from_import: bool = False + + +def extract_imports(code: str) -> List[ImportInfo]: + """Extract all import statements from Python code using AST. + + Args: + code: Python source code to analyze. + + Returns: + List of ImportInfo objects describing each import. + + Raises: + SyntaxError: If the code cannot be parsed. 
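+
+  Example (illustrative):
+
+    extract_imports("from collections import deque")
+    # -> [ImportInfo(module="collections", names=["deque"], line_number=1,
+    #                is_from_import=True)]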
+  """
+  imports = []
+
+  try:
+    tree = ast.parse(code)
+  except SyntaxError as e:
+    logger.warning("Failed to parse code for import extraction: %s", e)
+    raise
+
+  for node in ast.walk(tree):
+    if isinstance(node, ast.Import):
+      for alias in node.names:
+        imports.append(
+            ImportInfo(
+                module=alias.name,
+                alias=alias.asname or "",
+                line_number=node.lineno,
+                is_from_import=False,
+            )
+        )
+    elif isinstance(node, ast.ImportFrom):
+      module = node.module or ""
+      for alias in node.names:
+        imports.append(
+            ImportInfo(
+                module=module,
+                names=[alias.name],
+                alias=alias.asname or "",
+                line_number=node.lineno,
+                is_from_import=True,
+            )
+        )
+
+  return imports
+
+
+def is_import_allowed(
+    import_name: str,
+    allowlist: FrozenSet[str],
+) -> bool:
+  """Check if an import is allowed by the allowlist.
+
+  Supports wildcards:
+  - 'collections.*' allows 'collections.abc', 'collections.defaultdict', etc.
+  - 'numpy' allows only 'numpy', not 'numpy.array'
+  - 'numpy.*' allows 'numpy.array', 'numpy.linalg', etc.
+
+  Args:
+    import_name: The full import name to check.
+    allowlist: Set of allowed import patterns.
+
+  Returns:
+    True if the import is allowed, False otherwise.
+  """
+  # Direct match
+  if import_name in allowlist:
+    return True
+
+  # Check wildcard patterns
+  for pattern in allowlist:
+    if "*" in pattern:
+      if fnmatch.fnmatch(import_name, pattern):
+        return True
+      # Also check if the import is a submodule of an allowed module,
+      # e.g., 'collections.*' should allow 'collections.abc.Callable'.
+      # Note: removesuffix (not rstrip) is required here; rstrip(".*")
+      # would strip any trailing '.' and '*' characters individually.
+      base_pattern = pattern.removesuffix(".*")
+      if import_name.startswith(base_pattern + "."):
+        return True
+
+  # Check if a parent module is allowed with a wildcard
+  parts = import_name.split(".")
+  for i in range(len(parts)):
+    parent = ".".join(parts[: i + 1])
+    wildcard_pattern = parent + ".*"
+    if wildcard_pattern in allowlist:
+      return True
+
+  return False
+
+
+def validate_imports(
+    code: str,
+    allowlist: FrozenSet[str],
+) -> List[str]:
+  """Validate that all imports in code are in the allowlist.
+
+  Args:
+    code: Python source code to validate.
+    allowlist: Set of allowed import patterns.
+
+  Returns:
+    List of violations (empty if all imports are valid). Syntax errors are
+    reported as a violation rather than raised.
+  """
+  violations = []
+
+  try:
+    imports = extract_imports(code)
+  except SyntaxError as e:
+    # If we can't parse, we can't validate - return the syntax error as a
+    # violation
+    violations.append(f"Syntax error in code: {e}")
+    return violations
+
+  for import_info in imports:
+    module = import_info.module
+
+    if import_info.is_from_import:
+      # For 'from X import Y', check both the module and the full path
+      for name in import_info.names:
+        full_name = f"{module}.{name}" if module else name
+        # Allow if either the module or the full import path is allowed
+        if not (
+            is_import_allowed(module, allowlist)
+            or is_import_allowed(full_name, allowlist)
+        ):
+          violations.append(
+              f"Line {import_info.line_number}: "
+              f'Unauthorized import "from {module} import {name}"'
+          )
+    else:
+      # For 'import X', just check the module
+      if not is_import_allowed(module, allowlist):
+        violations.append(
+            f'Line {import_info.line_number}: Unauthorized import "{module}"'
+        )
+
+  return violations
+
+
+def validate_imports_strict(
+    code: str,
+    allowlist: FrozenSet[str],
+) -> None:
+  """Validate imports and raise exception if any violations found.
+
+  Args:
+    code: Python source code to validate.
+    allowlist: Set of allowed import patterns.
+ + Raises: + ImportValidationError: If unauthorized imports are found. + """ + violations = validate_imports(code, allowlist) + if violations: + raise ImportValidationError(violations, code) + + +class AllowlistValidator: + """Validator for checking Python code imports against an allowlist. + + This class provides a stateful validator that can be reused for multiple + validations with the same allowlist. + + Attributes: + allowlist: The set of allowed import patterns. + """ + + def __init__( + self, + allowlist: FrozenSet[str] = DEFAULT_SAFE_IMPORTS, + additional_imports: FrozenSet[str] = frozenset(), + ): + """Initialize the validator with an allowlist. + + Args: + allowlist: Base set of allowed import patterns. + additional_imports: Additional imports to allow beyond the base set. + """ + self.allowlist = allowlist | additional_imports + + def validate(self, code: str) -> List[str]: + """Validate imports in code. + + Args: + code: Python source code to validate. + + Returns: + List of violations (empty if all imports are valid). + """ + return validate_imports(code, self.allowlist) + + def validate_strict(self, code: str) -> None: + """Validate imports and raise if any violations found. + + Args: + code: Python source code to validate. + + Raises: + ImportValidationError: If unauthorized imports are found. + """ + validate_imports_strict(code, self.allowlist) + + def is_allowed(self, import_name: str) -> bool: + """Check if a single import is allowed. + + Args: + import_name: The import name to check. + + Returns: + True if allowed, False otherwise. + """ + return is_import_allowed(import_name, self.allowlist) + + def add_allowed_imports(self, imports: Set[str]) -> None: + """Add additional allowed imports. + + Args: + imports: Set of import patterns to allow. + """ + self.allowlist = self.allowlist | frozenset(imports) diff --git a/src/google/adk/code_executors/coding_agent_code_executor.py b/src/google/adk/code_executors/coding_agent_code_executor.py new file mode 100644 index 0000000000..c1394a8ef8 --- /dev/null +++ b/src/google/adk/code_executors/coding_agent_code_executor.py @@ -0,0 +1,515 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Code executor for CodingAgent with tool injection support. + +This module provides a code executor that wraps an underlying executor +(e.g., ContainerCodeExecutor) and adds tool injection via HTTP IPC. 
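+
+Typical flow (sketch): validate imports against the allowlist, start the
+tool server, prepend the runtime header and tool stubs to the user code,
+run it in the underlying executor, then extract tool traces and any final
+answer from stdout.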
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import field +import hashlib +import json +import logging +import re +from typing import Any +from typing import Dict +from typing import FrozenSet +from typing import List +from typing import Optional +from typing import TYPE_CHECKING + +from pydantic import Field +from pydantic import PrivateAttr +from typing_extensions import override + +from ..tools.base_tool import BaseTool +from .allowlist_validator import AllowlistValidator +from .allowlist_validator import DEFAULT_SAFE_IMPORTS +from .allowlist_validator import ImportValidationError +from .base_code_executor import BaseCodeExecutor +from .code_execution_utils import CodeExecutionInput +from .code_execution_utils import CodeExecutionResult +from .tool_code_generator import generate_full_code_with_stubs +from .tool_execution_server import detect_docker_host_address +from .tool_execution_server import ToolExecutionServer +from .tool_execution_server import ToolTrace + +if TYPE_CHECKING: + from ..agents.invocation_context import InvocationContext + from ..tools.tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +# Markers for extracting data from execution output +TOOL_TRACE_MARKER = "__TOOL_TRACE__:" +FINAL_ANSWER_MARKER = "__FINAL_ANSWER__:" + + +@dataclass +class ExecutionStep: + """Record of a single code execution step. + + Attributes: + code: The code that was executed. + code_hash: Hash of the code for comparison. + result: The execution result. + tool_traces: Tool call traces from this step. + success: Whether the execution succeeded. + final_answer: The final answer if one was provided. + """ + + code: str + code_hash: str = "" + result: Optional[CodeExecutionResult] = None + tool_traces: List[Dict[str, Any]] = field(default_factory=list) + success: bool = False + final_answer: Optional[Any] = None + + def __post_init__(self): + if not self.code_hash: + self.code_hash = hashlib.sha256(self.code.encode()).hexdigest()[:16] + + +@dataclass +class CodingAgentExecutionResult: + """Extended execution result with CodingAgent-specific fields. + + Attributes: + code_result: The underlying code execution result. + tool_traces: List of tool call traces. + final_answer: The final answer if one was provided. + has_final_answer: Whether a final answer was extracted. + clean_stdout: Stdout with trace markers removed. + """ + + code_result: CodeExecutionResult + tool_traces: List[Dict[str, Any]] = field(default_factory=list) + final_answer: Optional[Any] = None + has_final_answer: bool = False + clean_stdout: str = "" + + +class CodingAgentCodeExecutor(BaseCodeExecutor): + """Code executor with tool injection for CodingAgent. + + This executor wraps an underlying code executor and adds: + - Tool stub prepending for HTTP-based tool calls + - Import allowlist validation before execution + - Tool execution server lifecycle management + - Trace extraction from execution output + - Final answer detection + - History re-execution for stateful mode + + Attributes: + underlying_executor: The actual code executor to use. + tools: List of tools to make available. + authorized_imports: Set of allowed import patterns. + tool_server_host: Host for the tool server. + tool_server_port: Port for the tool server. + execution_history: List of execution steps for stateful mode. 
+ """ + + underlying_executor: BaseCodeExecutor + tools: List[BaseTool] = Field(default_factory=list) + authorized_imports: FrozenSet[str] = DEFAULT_SAFE_IMPORTS + tool_server_host: Optional[str] = None + tool_server_port: int = 8765 + + # Internal state - use PrivateAttr for Pydantic + _tool_server: Optional[ToolExecutionServer] = PrivateAttr(default=None) + _validator: Optional[AllowlistValidator] = PrivateAttr(default=None) + _invocation_context: Optional[InvocationContext] = PrivateAttr(default=None) + _tool_context: Optional[ToolContext] = PrivateAttr(default=None) + _execution_history: List[ExecutionStep] = PrivateAttr(default_factory=list) + + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + + def model_post_init(self, __context): + """Initialize after model construction.""" + self._validator = AllowlistValidator( + allowlist=self.authorized_imports, + ) + self._execution_history = [] + + def set_context( + self, + invocation_context: InvocationContext, + tool_context: Optional[ToolContext] = None, + ) -> None: + """Set the execution context. + + Args: + invocation_context: The invocation context. + tool_context: The tool context. + """ + self._invocation_context = invocation_context + self._tool_context = tool_context + if self._tool_server: + self._tool_server.set_context(invocation_context, tool_context) + + def _start_tool_server(self) -> None: + """Start the tool execution server if not already running.""" + if self._tool_server is not None: + return + + host = self.tool_server_host or "0.0.0.0" + self._tool_server = ToolExecutionServer( + host=host, + port=self.tool_server_port, + tools=self.tools, + invocation_context=self._invocation_context, + ) + self._tool_server.start() + + def _stop_tool_server(self) -> None: + """Stop the tool execution server.""" + if self._tool_server: + self._tool_server.stop() + self._tool_server = None + + def _get_tool_server_url(self) -> str: + """Get the URL for the tool server. + + Returns: + The tool server URL accessible from containers. + """ + if self.tool_server_host: + host = self.tool_server_host + else: + host = detect_docker_host_address() + return f"http://{host}:{self.tool_server_port}" + + def _validate_imports(self, code: str) -> None: + """Validate imports in the code against the allowlist. + + Args: + code: The code to validate. + + Raises: + ImportValidationError: If unauthorized imports are found. + """ + if self._validator: + self._validator.validate_strict(code) + + def _extract_traces_and_answer( + self, + result: CodeExecutionResult, + ) -> CodingAgentExecutionResult: + """Extract tool traces and final answer from execution output. + + Args: + result: The raw execution result. + + Returns: + Extended result with extracted data. 
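+
+    Example (illustrative stdout):
+
+      hello
+      __TOOL_TRACE__:[{"tool_name": "web_search", "success": true}]
+      __FINAL_ANSWER__:42
+
+    produces ``clean_stdout == "hello"``, one tool trace, and
+    ``final_answer == 42``.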
+ """ + tool_traces = [] + final_answer = None + has_final_answer = False + clean_lines = [] + + for line in result.stdout.split("\n"): + if line.startswith(TOOL_TRACE_MARKER): + try: + trace_json = line[len(TOOL_TRACE_MARKER) :] + traces = json.loads(trace_json) + tool_traces.extend(traces) + except json.JSONDecodeError as e: + logger.warning("Failed to parse tool trace: %s", e) + elif line.startswith(FINAL_ANSWER_MARKER): + answer_str = line[len(FINAL_ANSWER_MARKER) :] + try: + final_answer = json.loads(answer_str) + except json.JSONDecodeError: + # Not JSON, treat as string + final_answer = answer_str + has_final_answer = True + else: + clean_lines.append(line) + + clean_stdout = "\n".join(clean_lines).strip() + + return CodingAgentExecutionResult( + code_result=result, + tool_traces=tool_traces, + final_answer=final_answer, + has_final_answer=has_final_answer, + clean_stdout=clean_stdout, + ) + + def _should_skip_step(self, step: ExecutionStep, code_hash: str) -> bool: + """Check if an execution step can be skipped. + + For stateful mode, we can skip re-executing code if: + - The code hasn't changed (same hash) + - The previous execution succeeded + + Args: + step: The previous execution step. + code_hash: Hash of the current code. + + Returns: + True if the step can be skipped. + """ + return step.success and step.code_hash == code_hash + + def _prepend_tool_stubs(self, code: str) -> str: + """Prepend runtime header and tool stubs to user code. + + Args: + code: The user code to wrap. + + Returns: + Complete code with tool stubs. + """ + return generate_full_code_with_stubs( + user_code=code, + tools=self.tools, + tool_server_url=self._get_tool_server_url(), + ) + + def _replay_history( + self, + invocation_context: InvocationContext, + ) -> Optional[CodeExecutionResult]: + """Replay execution history for stateful mode. + + This re-executes previous successful steps to restore state + before executing new code. + + Args: + invocation_context: The invocation context. + + Returns: + The result of the last replayed step, or None if no replay needed. + """ + if not self.stateful or not self._execution_history: + return None + + last_result = None + for step in self._execution_history: + if step.success: + # Re-execute to restore state + full_code = self._prepend_tool_stubs(step.code) + input_data = CodeExecutionInput(code=full_code) + last_result = self.underlying_executor.execute_code( + invocation_context=invocation_context, + code_execution_input=input_data, + ) + if last_result.stderr: + raise RuntimeError( + f'Failed to replay history step with hash {step.code_hash}. ' + f'Error: {last_result.stderr}' + ) + logger.debug("Replayed history step: %s", step.code_hash) + + return last_result + + @override + def execute_code( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodeExecutionResult: + """Execute code with tool injection. + + Args: + invocation_context: The invocation context. + code_execution_input: The code to execute. + + Returns: + The execution result. 
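+
+    Example (sketch; trace markers are stripped from the returned stdout):
+
+      result = executor.execute_code(ctx, CodeExecutionInput(code="print(1 + 1)"))
+      # result.stdout == "2"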
+ """ + user_code = code_execution_input.code + + # Validate imports first (security check before execution) + try: + self._validate_imports(user_code) + except ImportValidationError as e: + return CodeExecutionResult( + stdout="", + stderr=str(e), + output_files=[], + ) + + # Start tool server if needed + self._start_tool_server() + + # Set context on tool server + if self._tool_server: + self._tool_server.set_context( + invocation_context, + self._tool_context, + ) + self._tool_server.clear_traces() + + # Replay history for stateful mode + if self.stateful: + self._replay_history(invocation_context) + + # Prepend tool stubs to user code + full_code = self._prepend_tool_stubs(user_code) + + # Execute the code + input_with_stubs = CodeExecutionInput( + code=full_code, + input_files=code_execution_input.input_files, + execution_id=code_execution_input.execution_id, + ) + + result = self.underlying_executor.execute_code( + invocation_context=invocation_context, + code_execution_input=input_with_stubs, + ) + + # Extract traces and final answer + extended_result = self._extract_traces_and_answer(result) + + # Record execution step for stateful mode + step = ExecutionStep( + code=user_code, + result=result, + tool_traces=extended_result.tool_traces, + success=not result.stderr, + final_answer=extended_result.final_answer, + ) + self._execution_history.append(step) + + # Return result with clean stdout (traces stripped) + return CodeExecutionResult( + stdout=extended_result.clean_stdout, + stderr=result.stderr, + output_files=result.output_files, + ) + + def execute_code_extended( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodingAgentExecutionResult: + """Execute code and return extended result with traces. + + Args: + invocation_context: The invocation context. + code_execution_input: The code to execute. + + Returns: + Extended execution result with tool traces and final answer. 
+ """ + user_code = code_execution_input.code + + # Validate imports first + try: + self._validate_imports(user_code) + except ImportValidationError as e: + return CodingAgentExecutionResult( + code_result=CodeExecutionResult( + stdout="", + stderr=str(e), + output_files=[], + ), + tool_traces=[], + final_answer=None, + has_final_answer=False, + clean_stdout="", + ) + + # Start tool server if needed + self._start_tool_server() + + # Set context on tool server + if self._tool_server: + self._tool_server.set_context( + invocation_context, + self._tool_context, + ) + self._tool_server.clear_traces() + + # Replay history for stateful mode + if self.stateful: + self._replay_history(invocation_context) + + # Prepend tool stubs to user code + full_code = self._prepend_tool_stubs(user_code) + + # Execute the code + input_with_stubs = CodeExecutionInput( + code=full_code, + input_files=code_execution_input.input_files, + execution_id=code_execution_input.execution_id, + ) + + result = self.underlying_executor.execute_code( + invocation_context=invocation_context, + code_execution_input=input_with_stubs, + ) + + # Extract traces and final answer + extended_result = self._extract_traces_and_answer(result) + + # Record execution step for stateful mode + step = ExecutionStep( + code=user_code, + result=result, + tool_traces=extended_result.tool_traces, + success=not result.stderr, + final_answer=extended_result.final_answer, + ) + self._execution_history.append(step) + + return extended_result + + def get_execution_history(self) -> List[ExecutionStep]: + """Get the execution history. + + Returns: + List of execution steps. + """ + return self._execution_history.copy() + + def clear_execution_history(self) -> None: + """Clear the execution history.""" + self._execution_history.clear() + + def get_tool_traces(self) -> List[ToolTrace]: + """Get tool traces from the server. + + Returns: + List of tool traces. + """ + if self._tool_server: + return self._tool_server.get_traces() + return [] + + def cleanup(self) -> None: + """Clean up resources.""" + self._stop_tool_server() + self._execution_history.clear() + + def __enter__(self) -> CodingAgentCodeExecutor: + """Enter context manager and return self.""" + return self + + def __exit__(self, exc_type, exc, traceback) -> bool: + """Exit context manager and clean up resources.""" + self.cleanup() + return False diff --git a/src/google/adk/code_executors/tool_code_generator.py b/src/google/adk/code_executors/tool_code_generator.py new file mode 100644 index 0000000000..e0c0b51440 --- /dev/null +++ b/src/google/adk/code_executors/tool_code_generator.py @@ -0,0 +1,479 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tool code generator for CodingAgent. + +This module provides functions to generate Python code stubs and runtime +headers that allow generated code to call ADK tools via HTTP IPC. 
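+
+The generated preamble defines ``_call_adk_tool`` plus one stub per tool, so
+user code can call tools as ordinary Python functions.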
+""" + +from __future__ import annotations + +import json +import logging +import textwrap +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..tools.base_tool import BaseTool + +logger = logging.getLogger("google_adk." + __name__) + + +# Runtime header template that provides HTTP-based tool calling +RUNTIME_HEADER_TEMPLATE = ''' +# ============================================================================ +# ADK CodingAgent Runtime Header - DO NOT MODIFY +# ============================================================================ +import json as __adk_json +import urllib.request as __adk_urllib_request +import urllib.error as __adk_urllib_error + +__ADK_TOOL_SERVER_URL = "{tool_server_url}" +__ADK_TOOL_TRACES = [] + +def _call_adk_tool(__tool_name: str, **kwargs) -> dict: + """Call an ADK tool via HTTP IPC. + + Args: + __tool_name: Name of the tool to call. + **kwargs: Arguments to pass to the tool. + + Returns: + The tool result as a dictionary. + """ + global __ADK_TOOL_TRACES + + request_data = __adk_json.dumps({{ + "tool_name": __tool_name, + "args": kwargs, + }}).encode("utf-8") + + req = __adk_urllib_request.Request( + __ADK_TOOL_SERVER_URL + "/tool_call", + data=request_data, + headers={{"Content-Type": "application/json"}}, + method="POST", + ) + + try: + with __adk_urllib_request.urlopen(req, timeout=300) as response: + result = __adk_json.loads(response.read().decode("utf-8")) + # Record the trace + __ADK_TOOL_TRACES.append({{ + "tool_name": __tool_name, + "args": kwargs, + "result": result, + "success": True, + }}) + return result + except __adk_urllib_error.HTTPError as e: + error_body = e.read().decode("utf-8") if e.fp else str(e) + __ADK_TOOL_TRACES.append({{ + "tool_name": __tool_name, + "args": kwargs, + "error": error_body, + "success": False, + }}) + raise RuntimeError(f"Tool call failed: {{error_body}}") from e + except __adk_urllib_error.URLError as e: + __ADK_TOOL_TRACES.append({{ + "tool_name": __tool_name, + "args": kwargs, + "error": str(e), + "success": False, + }}) + raise RuntimeError(f"Tool server connection failed: {{e}}") from e + +def __get_tool_traces() -> list: + """Get all tool call traces.""" + return __ADK_TOOL_TRACES + +def __clear_tool_traces(): + """Clear all tool call traces.""" + global __ADK_TOOL_TRACES + __ADK_TOOL_TRACES = [] + +# Final answer function for terminating execution +__FINAL_ANSWER_VALUE = None + +def final_answer(result): + """Mark the final answer and terminate the code execution. + + Args: + result: The final result to return to the user. + """ + global __FINAL_ANSWER_VALUE + __FINAL_ANSWER_VALUE = result + print(f"__FINAL_ANSWER__:{{__adk_json.dumps(result) if not isinstance(result, str) else result}}") + +# ============================================================================ +# End of Runtime Header +# ============================================================================ + +''' + + +def generate_runtime_header( + tool_server_url: str, +) -> str: + """Generate the runtime header with HTTP client and helper functions. + + Args: + tool_server_url: URL of the tool execution server. + + Returns: + Python code string containing the runtime header. + """ + return RUNTIME_HEADER_TEMPLATE.format(tool_server_url=tool_server_url) + + +def _get_schema_type(schema: Any) -> str: + """Get the type from a schema (dict or Pydantic Schema object). 
+ + Args: + schema: JSON schema dict or google.genai.types.Schema object. + + Returns: + The type as a lowercase string. + """ + if hasattr(schema, "type"): + # Pydantic Schema object from google.genai.types + schema_type = schema.type + if schema_type is None: + return "any" + # Handle enum (Type.STRING -> "string") + if hasattr(schema_type, "value"): + return schema_type.value.lower() + return str(schema_type).lower() + elif isinstance(schema, dict): + return schema.get("type", "any") + return "any" + + +def _get_schema_attr(schema: Any, attr: str, default: Any = None) -> Any: + """Get an attribute from a schema (dict or Pydantic Schema object). + + Args: + schema: JSON schema dict or google.genai.types.Schema object. + attr: The attribute name to get. + default: Default value if attribute not found. + + Returns: + The attribute value or default. + """ + if hasattr(schema, attr): + return getattr(schema, attr, default) + elif isinstance(schema, dict): + return schema.get(attr, default) + return default + + +def _get_python_type_hint(schema: Any) -> str: + """Convert JSON schema type to Python type hint. + + Args: + schema: JSON schema dict or google.genai.types.Schema object. + + Returns: + Python type hint string. + """ + schema_type = _get_schema_type(schema) + + type_mapping = { + "string": "str", + "integer": "int", + "number": "float", + "boolean": "bool", + "array": "list", + "object": "dict", + } + + if schema_type == "array": + items = _get_schema_attr(schema, "items", {}) + if items: + item_type = _get_python_type_hint(items) + return f"list[{item_type}]" + return "list" + elif schema_type == "object": + return "dict" + + return type_mapping.get(schema_type, "Any") + + +def _generate_tool_stub(tool: BaseTool) -> str: + """Generate a Python function stub for a single tool. + + Args: + tool: The BaseTool to generate a stub for. + + Returns: + Python code string for the tool stub function. + """ + decl = tool._get_declaration() + if not decl: + logger.warning( + "Tool %s has no declaration, skipping stub generation", tool.name + ) + return "" + + # Build parameter list with type hints + params = [] + param_docs = [] + + if decl.parameters and decl.parameters.properties: + required = set(decl.parameters.required or []) + + for param_name, param_schema in decl.parameters.properties.items(): + type_hint = _get_python_type_hint(param_schema) + description = _get_schema_attr(param_schema, "description", "") + + if param_name in required: + params.append(f"{param_name}: {type_hint}") + else: + params.append(f"{param_name}: {type_hint} = None") + + param_docs.append(f" {param_name}: {description}") + + param_str = ", ".join(params) + param_doc_str = "\n".join(param_docs) if param_docs else " None" + + # Build the function stub + stub = f''' +def {tool.name}({param_str}) -> dict: + """{tool.description} + + Args: +{param_doc_str} + + Returns: + Tool execution result as a dictionary. + """ + kwargs = {{k: v for k, v in locals().items() if v is not None}} + response = _call_adk_tool("{tool.name}", **kwargs) + # On success, the response is a dict with a "result" key. + # On failure, _call_adk_tool raises an exception. + return response["result"] + +''' + return stub + + +def generate_tool_stubs(tools: List[BaseTool]) -> str: + """Generate Python function stubs for all tools. + + Args: + tools: List of tools to generate stubs for. + + Returns: + Python code string containing all tool stubs. 
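+
+  Example (illustrative; the shape of one stub for a hypothetical
+  ``web_search`` tool with a single required ``query`` parameter):
+
+    def web_search(query: str) -> dict:
+      kwargs = {k: v for k, v in locals().items() if v is not None}
+      return _call_adk_tool("web_search", **kwargs)["result"]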
+ """ + stubs = [ + ( + "# ============================================================================" + ), + "# Tool Function Stubs", + ( + "# ============================================================================" + ), + "", + ] + + for tool in tools: + stub = _generate_tool_stub(tool) + if stub: + stubs.append(stub) + + return "\n".join(stubs) + + +def generate_final_answer_stub() -> str: + """Generate the final_answer function documentation. + + This is included in the runtime header, but we generate additional + documentation here for the system prompt. + + Returns: + Documentation string about the final_answer function. + """ + return """ +The `final_answer(result)` function is available to mark your final result. +Call this function when you have completed the task and have a result to return. +Example: `final_answer("The calculation result is 42")` +""" + + +# Few-shot examples for the system prompt +SYSTEM_PROMPT_EXAMPLES = """ +## Examples + +### Example 1 - Using a tool to search for information: +```tool_code +# Search for relevant information +result = web_search(query="Python async best practices") +# Display the findings +for snippet in result.get("snippets", [])[:3]: + print(snippet) +``` + +### Example 2 - Processing data and providing a final answer: +```tool_code +# Read and process data +data = read_file(path="sales_data.csv") +rows = data.get("rows", []) + +# Calculate the total +total = sum(float(row.get("amount", 0)) for row in rows) + +# Provide the final answer +final_answer(f"The total sales amount is ${total:.2f}") +``` + +### Example 3 - Multi-step reasoning with tool calls: +```tool_code +# Step 1: Get the current weather +weather = get_weather(city="San Francisco") +temp = weather.get("temperature", "unknown") +print(f"Current temperature: {temp}") +``` + +Then, after seeing the output: +```tool_code +# Step 2: Based on temperature, provide recommendation +if temp > 70: + recommendation = "It's warm! Consider light clothing." +else: + recommendation = "It might be cool. Bring a jacket." + +final_answer(recommendation) +``` +""" + + +def generate_system_prompt( + tools: List[BaseTool], + custom_instruction: str = "", +) -> str: + """Generate the system prompt for the CodingAgent. + + Args: + tools: List of available tools. + custom_instruction: Additional custom instructions. + + Returns: + Complete system prompt string. + """ + # Build tool documentation + tool_docs = [] + for tool in tools: + decl = tool._get_declaration() + if decl: + params_doc = "" + if decl.parameters and decl.parameters.properties: + param_lines = [] + required = set(decl.parameters.required or []) + for name, schema in decl.parameters.properties.items(): + type_hint = _get_python_type_hint(schema) + req_marker = " (required)" if name in required else " (optional)" + desc = _get_schema_attr(schema, "description", "") + param_lines.append(f" - {name}: {type_hint}{req_marker} - {desc}") + params_doc = "\n".join(param_lines) + + tool_docs.append(f""" +### {tool.name} +{tool.description} + +Parameters: +{params_doc if params_doc else " None"} +""") + + tools_section = "\n".join(tool_docs) if tool_docs else "No tools available." + + system_prompt = f"""You are a coding agent that solves tasks by writing and executing Python code. + +## How to Respond + +1. **Write Python code** in code blocks marked with ```tool_code +2. **Use available tools** by calling them as Python functions +3. **Print intermediate results** to see outputs and make decisions +4. 
**Call final_answer()** when you have the final result
+
+## Available Tools
+
+{tools_section}
+
+## Special Functions
+
+- `final_answer(result)`: Call this to provide your final answer and complete the task.
+- `print(...)`: Use print statements to see intermediate results.
+
+{SYSTEM_PROMPT_EXAMPLES}
+
+## Important Guidelines
+
+1. **One step at a time**: Write code for one logical step, wait for output, then continue.
+2. **Always print results**: Use print() to see what tools return.
+3. **Handle errors gracefully**: If a tool fails, try an alternative approach.
+4. **Call final_answer()**: When done, call final_answer() with your result.
+5. **Install packages if needed**: If you need external libraries (pandas, matplotlib, numpy, etc.) and `subprocess` is among the allowed imports, install them first:
+   ```python
+   import subprocess
+   subprocess.run(["pip", "install", "-q", "pandas", "matplotlib", "seaborn"], check=True)
+   ```
+   Then import and use them normally. If `subprocess` is rejected by import validation, use only the allowed standard-library modules instead.
+
+{custom_instruction}
+"""
+
+  return system_prompt.strip()
+
+
+def generate_full_code_with_stubs(
+    user_code: str,
+    tools: List[BaseTool],
+    tool_server_url: str,
+) -> str:
+  """Generate complete executable code with runtime header and tool stubs.
+
+  Args:
+    user_code: The user-generated code to execute.
+    tools: List of available tools.
+    tool_server_url: URL of the tool execution server.
+
+  Returns:
+    Complete Python code ready for execution.
+  """
+  runtime_header = generate_runtime_header(tool_server_url)
+  tool_stubs = generate_tool_stubs(tools)
+
+  full_code = f"""{runtime_header}
+{tool_stubs}
+# ============================================================================
+# User Code
+# ============================================================================
+
+{user_code}
+
+# ============================================================================
+# Output tool traces for extraction
+# ============================================================================
+import json as __output_json
+print("__TOOL_TRACE__:" + __output_json.dumps(__get_tool_traces()))
+"""
+
+  return full_code
diff --git a/src/google/adk/code_executors/tool_execution_server.py b/src/google/adk/code_executors/tool_execution_server.py
new file mode 100644
index 0000000000..de13b15820
--- /dev/null
+++ b/src/google/adk/code_executors/tool_execution_server.py
@@ -0,0 +1,366 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tool execution server for CodingAgent.
+
+This module provides a FastAPI server that handles tool execution requests
+from code running in containers. It routes requests to the appropriate
+ADK tools with full ToolContext support.
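+
+Example request (illustrative; default port 8765):
+
+  POST /tool_call
+  {"tool_name": "web_search", "args": {"query": "adk"}}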
+""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from dataclasses import field +import json +import logging +import os +import socket +import threading +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import TYPE_CHECKING + +from fastapi import FastAPI +from fastapi import HTTPException +from pydantic import BaseModel +import uvicorn + +if TYPE_CHECKING: + from ..agents.invocation_context import InvocationContext + from ..tools.base_tool import BaseTool + from ..tools.tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +class ToolCallRequest(BaseModel): + """Request model for tool calls.""" + + tool_name: str + args: Dict[str, Any] + + +class ToolCallResponse(BaseModel): + """Response model for tool calls.""" + + result: Any + success: bool + error: Optional[str] = None + + +@dataclass +class ToolTrace: + """Record of a tool call for debugging and telemetry.""" + + tool_name: str + args: Dict[str, Any] + result: Any = None + error: Optional[str] = None + success: bool = False + duration_ms: float = 0.0 + + +def detect_docker_host_address() -> str: + """Detect the appropriate host address for Docker containers. + + On macOS and Windows (Docker Desktop), use host.docker.internal. + On Linux, use 172.17.0.1 (default Docker bridge network gateway). + + Note: host.docker.internal only resolves from within containers, + not from the host machine, so we check the platform instead. + + Returns: + The detected host address. + """ + import platform + + system = platform.system().lower() + + # macOS and Windows use Docker Desktop which supports host.docker.internal + if system in ("darwin", "windows"): + return "host.docker.internal" + + # Linux: use Docker bridge network gateway + return "172.17.0.1" + + +class ToolExecutionServer: + """FastAPI server for executing ADK tools via HTTP. + + This server is designed to run on the host machine and receive tool + execution requests from code running in Docker containers. + + Attributes: + host: Host address to bind the server to. + port: Port to bind the server to. + tools: Dictionary mapping tool names to tool instances. + invocation_context: The current invocation context. + tool_context: The current tool context. + traces: List of tool call traces. + """ + + def __init__( + self, + host: str = "0.0.0.0", + port: int = 8765, + tools: Optional[List[BaseTool]] = None, + invocation_context: Optional[InvocationContext] = None, + ): + """Initialize the tool execution server. + + Args: + host: Host address to bind to. + port: Port to bind to. + tools: List of tools to make available. + invocation_context: The invocation context for tool calls. 
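+
+    Example (illustrative; ``my_tools`` is a placeholder, and the
+    context-manager form relies on __enter__/__exit__ defined below):
+
+        with ToolExecutionServer(port=8765, tools=my_tools) as server:
+            url = server.get_url()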
+ """ + self.host = host + self.port = port + self.tools: Dict[str, BaseTool] = {} + self.invocation_context = invocation_context + self.tool_context: Optional[ToolContext] = None + self.traces: List[ToolTrace] = [] + self._server: Optional[uvicorn.Server] = None + self._server_thread: Optional[threading.Thread] = None + self._app = self._create_app() + + if tools: + for tool in tools: + self.register_tool(tool) + + def _create_app(self) -> FastAPI: + """Create the FastAPI application with routes.""" + app = FastAPI( + title="ADK Tool Execution Server", + description="HTTP server for executing ADK tools from containers", + version="1.0.0", + ) + + @app.post("/tool_call", response_model=ToolCallResponse) + async def handle_tool_call(request: ToolCallRequest) -> ToolCallResponse: + """Handle a tool call request.""" + return await self._execute_tool(request.tool_name, request.args) + + @app.get("/tool_trace") + async def get_tool_traces() -> List[Dict[str, Any]]: + """Get all tool call traces.""" + return [ + { + "tool_name": t.tool_name, + "args": t.args, + "result": t.result, + "error": t.error, + "success": t.success, + "duration_ms": t.duration_ms, + } + for t in self.traces + ] + + @app.delete("/tool_trace") + async def clear_tool_traces() -> Dict[str, str]: + """Clear all tool call traces.""" + self.traces.clear() + return {"status": "cleared"} + + @app.get("/health") + async def health_check() -> Dict[str, str]: + """Health check endpoint.""" + return {"status": "healthy"} + + @app.get("/tools") + async def list_tools() -> List[str]: + """List available tools.""" + return list(self.tools.keys()) + + return app + + def register_tool(self, tool: BaseTool) -> None: + """Register a tool with the server. + + Args: + tool: The tool to register. + """ + self.tools[tool.name] = tool + logger.debug("Registered tool: %s", tool.name) + + def set_context( + self, + invocation_context: InvocationContext, + tool_context: Optional[ToolContext] = None, + ) -> None: + """Set the context for tool execution. + + Args: + invocation_context: The invocation context. + tool_context: The tool context. + """ + self.invocation_context = invocation_context + self.tool_context = tool_context + + async def _execute_tool( + self, + tool_name: str, + args: Dict[str, Any], + ) -> ToolCallResponse: + """Execute a tool and return the result. + + Args: + tool_name: Name of the tool to execute. + args: Arguments to pass to the tool. + + Returns: + The tool execution response. 
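+
+    Raises:
+      HTTPException: With status 404 if the tool is not registered, or
+        status 500 if the tool raises during execution.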
+    """
+    import time
+
+    start_time = time.time()
+    trace = ToolTrace(tool_name=tool_name, args=args)
+
+    if tool_name not in self.tools:
+      trace.error = f"Tool not found: {tool_name}"
+      trace.success = False
+      trace.duration_ms = (time.time() - start_time) * 1000
+      self.traces.append(trace)
+      raise HTTPException(status_code=404, detail=trace.error)
+
+    tool = self.tools[tool_name]
+
+    try:
+      # Create a tool context if we have an invocation context
+      if self.invocation_context and not self.tool_context:
+        from ..tools.tool_context import ToolContext
+
+        self.tool_context = ToolContext(
+            invocation_context=self.invocation_context,
+        )
+
+      if self.tool_context:
+        result = await tool.run_async(args=args, tool_context=self.tool_context)
+      else:
+        # No ToolContext is available; fall back to executing the tool
+        # without one. This shouldn't happen in normal operation.
+        logger.warning("Executing tool %s without proper context", tool_name)
+        result = await tool.run_async(args=args, tool_context=None)
+
+      trace.result = result
+      trace.success = True
+      trace.duration_ms = (time.time() - start_time) * 1000
+      self.traces.append(trace)
+
+      return ToolCallResponse(result=result, success=True)
+
+    except Exception as e:
+      trace.error = str(e)
+      trace.success = False
+      trace.duration_ms = (time.time() - start_time) * 1000
+      self.traces.append(trace)
+      logger.error("Tool execution failed: %s - %s", tool_name, e)
+      raise HTTPException(status_code=500, detail=str(e)) from e
+
+  def start(self) -> None:
+    """Start the server in a background thread."""
+    if self._server_thread and self._server_thread.is_alive():
+      logger.warning("Server already running")
+      return
+
+    config = uvicorn.Config(
+        app=self._app,
+        host=self.host,
+        port=self.port,
+        log_level="warning",
+    )
+    self._server = uvicorn.Server(config)
+
+    def run_server():
+      asyncio.run(self._server.serve())
+
+    self._server_thread = threading.Thread(target=run_server, daemon=True)
+    self._server_thread.start()
+
+    # Wait for server to be ready
+    self._wait_for_server()
+    logger.info("Tool execution server started on %s:%d", self.host, self.port)
+
+  def _wait_for_server(self, timeout: float = 10.0) -> None:
+    """Wait for the server to be ready.
+
+    Args:
+      timeout: Maximum time to wait in seconds.
+    """
+    import time
+
+    start = time.time()
+    while time.time() - start < timeout:
+      try:
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.settimeout(1)
+        result = sock.connect_ex(("127.0.0.1", self.port))
+        sock.close()
+        if result == 0:
+          return
+      except Exception:
+        pass
+      time.sleep(0.1)
+
+    logger.warning("Server may not be fully ready after %.1f seconds", timeout)
+
+  def stop(self) -> None:
+    """Stop the server."""
+    if self._server:
+      self._server.should_exit = True
+    if self._server_thread:
+      self._server_thread.join(timeout=5.0)
+    self._server = None
+    self._server_thread = None
+    logger.info("Tool execution server stopped")
+
+  def get_url(self, for_container: bool = True) -> str:
+    """Get the URL for the server.
+
+    Args:
+      for_container: If True, return URL accessible from Docker containers.
+
+    Returns:
+      The server URL.
+    """
+    if for_container:
+      host = detect_docker_host_address()
+    else:
+      host = "localhost" if self.host == "0.0.0.0" else self.host
+    return f"http://{host}:{self.port}"
+
+  def clear_traces(self) -> None:
+    """Clear all tool call traces."""
+    self.traces.clear()
+
+  def get_traces(self) -> List[ToolTrace]:
+    """Get all tool call traces.
+
+    Returns:
+      List of tool traces.
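+
+    The returned list is a shallow copy, so callers can inspect or
+    mutate it without affecting the server's internal trace log.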
+ """ + return self.traces.copy() + + def __enter__(self) -> ToolExecutionServer: + """Context manager entry.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit.""" + self.stop() diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index bd964a5d07..609d96c14d 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -30,6 +30,7 @@ import logging import os from typing import Any +from typing import Optional from typing import TYPE_CHECKING from google.genai import types @@ -62,14 +63,14 @@ # By default some ADK spans include attributes with potential PII data. # This env, when set to false, allows to disable populating those attributes. -ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS = 'ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS' +ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS = "ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS" # Standard OTEL env variable to enable logging of prompt/response content. OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = ( - 'OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT' + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT" ) -USER_CONTENT_ELIDED = '' +USER_CONTENT_ELIDED = "" # Needed to avoid circular imports if TYPE_CHECKING: @@ -81,18 +82,18 @@ from ..tools.base_tool import BaseTool tracer = trace.get_tracer( - instrumenting_module_name='gcp.vertex.agent', + instrumenting_module_name="gcp.vertex.agent", instrumenting_library_version=version.__version__, schema_url=Schemas.V1_36_0.value, ) otel_logger = _logs.get_logger( - instrumenting_module_name='gcp.vertex.agent', + instrumenting_module_name="gcp.vertex.agent", instrumenting_library_version=version.__version__, schema_url=Schemas.V1_36_0.value, ) -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) def _safe_json_serialize(obj) -> str: @@ -108,10 +109,10 @@ def _safe_json_serialize(obj) -> str: try: # Try direct JSON serialization first return json.dumps( - obj, ensure_ascii=False, default=lambda o: '' + obj, ensure_ascii=False, default=lambda o: "" ) except (TypeError, OverflowError): - return '' + return "" def trace_agent_invocation( @@ -138,7 +139,7 @@ def trace_agent_invocation( """ # Required - span.set_attribute(GEN_AI_OPERATION_NAME, 'invoke_agent') + span.set_attribute(GEN_AI_OPERATION_NAME, "invoke_agent") # Conditionally Required span.set_attribute(GEN_AI_AGENT_DESCRIPTION, agent.description) @@ -161,7 +162,7 @@ def trace_tool_call( """ span = trace.get_current_span() - span.set_attribute(GEN_AI_OPERATION_NAME, 'execute_tool') + span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") span.set_attribute(GEN_AI_TOOL_DESCRIPTION, tool.description) span.set_attribute(GEN_AI_TOOL_NAME, tool.name) @@ -171,20 +172,20 @@ def trace_tool_call( # Setting empty llm request and response (as UI expect these) while not # applicable for tool_response. 
- span.set_attribute('gcp.vertex.agent.llm_request', '{}') - span.set_attribute('gcp.vertex.agent.llm_response', '{}') + span.set_attribute("gcp.vertex.agent.llm_request", "{}") + span.set_attribute("gcp.vertex.agent.llm_response", "{}") if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.tool_call_args', + "gcp.vertex.agent.tool_call_args", _safe_json_serialize(args), ) else: - span.set_attribute('gcp.vertex.agent.tool_call_args', '{}') + span.set_attribute("gcp.vertex.agent.tool_call_args", "{}") # Tracing tool response - tool_call_id = '' - tool_response = '' + tool_call_id = "" + tool_response = "" if ( function_response_event is not None and function_response_event.content is not None @@ -201,16 +202,16 @@ def trace_tool_call( span.set_attribute(GEN_AI_TOOL_CALL_ID, tool_call_id) if not isinstance(tool_response, dict): - tool_response = {'result': tool_response} + tool_response = {"result": tool_response} if function_response_event is not None: - span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id) + span.set_attribute("gcp.vertex.agent.event_id", function_response_event.id) if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.tool_response', + "gcp.vertex.agent.tool_response", _safe_json_serialize(tool_response), ) else: - span.set_attribute('gcp.vertex.agent.tool_response', '{}') + span.set_attribute("gcp.vertex.agent.tool_response", "{}") def trace_merged_tool_calls( @@ -229,34 +230,34 @@ def trace_merged_tool_calls( span = trace.get_current_span() - span.set_attribute(GEN_AI_OPERATION_NAME, 'execute_tool') - span.set_attribute(GEN_AI_TOOL_NAME, '(merged tools)') - span.set_attribute(GEN_AI_TOOL_DESCRIPTION, '(merged tools)') + span.set_attribute(GEN_AI_OPERATION_NAME, "execute_tool") + span.set_attribute(GEN_AI_TOOL_NAME, "(merged tools)") + span.set_attribute(GEN_AI_TOOL_DESCRIPTION, "(merged tools)") span.set_attribute(GEN_AI_TOOL_CALL_ID, response_event_id) # TODO(b/441461932): See if these are still necessary - span.set_attribute('gcp.vertex.agent.tool_call_args', 'N/A') - span.set_attribute('gcp.vertex.agent.event_id', response_event_id) + span.set_attribute("gcp.vertex.agent.tool_call_args", "N/A") + span.set_attribute("gcp.vertex.agent.event_id", response_event_id) try: function_response_event_json = function_response_event.model_dumps_json( exclude_none=True ) except Exception: # pylint: disable=broad-exception-caught - function_response_event_json = '' + function_response_event_json = "" if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.tool_response', + "gcp.vertex.agent.tool_response", function_response_event_json, ) else: - span.set_attribute('gcp.vertex.agent.tool_response', '{}') + span.set_attribute("gcp.vertex.agent.tool_response", "{}") # Setting empty llm request and response (as UI expect these) while not # applicable for tool_response. - span.set_attribute('gcp.vertex.agent.llm_request', '{}') + span.set_attribute("gcp.vertex.agent.llm_request", "{}") span.set_attribute( - 'gcp.vertex.agent.llm_response', - '{}', + "gcp.vertex.agent.llm_response", + "{}", ) @@ -281,57 +282,57 @@ def trace_call_llm( span = span or trace.get_current_span() # Special standard Open Telemetry GenaI attributes that indicate # that this is a span related to a Generative AI system. 
- span.set_attribute('gen_ai.system', 'gcp.vertex.agent') - span.set_attribute('gen_ai.request.model', llm_request.model) + span.set_attribute("gen_ai.system", "gcp.vertex.agent") + span.set_attribute("gen_ai.request.model", llm_request.model) span.set_attribute( - 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id + "gcp.vertex.agent.invocation_id", invocation_context.invocation_id ) span.set_attribute( - 'gcp.vertex.agent.session_id', invocation_context.session.id + "gcp.vertex.agent.session_id", invocation_context.session.id ) - span.set_attribute('gcp.vertex.agent.event_id', event_id) + span.set_attribute("gcp.vertex.agent.event_id", event_id) # Consider removing once GenAI SDK provides a way to record this info. if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.llm_request', + "gcp.vertex.agent.llm_request", _safe_json_serialize(_build_llm_request_for_trace(llm_request)), ) else: - span.set_attribute('gcp.vertex.agent.llm_request', '{}') + span.set_attribute("gcp.vertex.agent.llm_request", "{}") # Consider removing once GenAI SDK provides a way to record this info. if llm_request.config: if llm_request.config.top_p: span.set_attribute( - 'gen_ai.request.top_p', + "gen_ai.request.top_p", llm_request.config.top_p, ) if llm_request.config.max_output_tokens: span.set_attribute( - 'gen_ai.request.max_tokens', + "gen_ai.request.max_tokens", llm_request.config.max_output_tokens, ) try: llm_response_json = llm_response.model_dump_json(exclude_none=True) except Exception: # pylint: disable=broad-exception-caught - llm_response_json = '' + llm_response_json = "" if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.llm_response', + "gcp.vertex.agent.llm_response", llm_response_json, ) else: - span.set_attribute('gcp.vertex.agent.llm_response', '{}') + span.set_attribute("gcp.vertex.agent.llm_response", "{}") if llm_response.usage_metadata is not None: span.set_attribute( - 'gen_ai.usage.input_tokens', + "gen_ai.usage.input_tokens", llm_response.usage_metadata.prompt_token_count, ) if llm_response.usage_metadata.candidates_token_count is not None: span.set_attribute( - 'gen_ai.usage.output_tokens', + "gen_ai.usage.output_tokens", llm_response.usage_metadata.candidates_token_count, ) if llm_response.finish_reason: @@ -340,7 +341,7 @@ def trace_call_llm( except AttributeError: finish_reason_str = str(llm_response.finish_reason).lower() span.set_attribute( - 'gen_ai.response.finish_reasons', + "gen_ai.response.finish_reasons", [finish_reason_str], ) @@ -362,23 +363,23 @@ def trace_send_data( """ span = trace.get_current_span() span.set_attribute( - 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id + "gcp.vertex.agent.invocation_id", invocation_context.invocation_id ) - span.set_attribute('gcp.vertex.agent.event_id', event_id) + span.set_attribute("gcp.vertex.agent.event_id", event_id) # Once instrumentation is added to the GenAI SDK, consider whether this # information still needs to be recorded by the Agent Development Kit. 
if _should_add_request_response_to_spans(): span.set_attribute( - 'gcp.vertex.agent.data', + "gcp.vertex.agent.data", _safe_json_serialize([ types.Content(role=content.role, parts=content.parts).model_dump( - exclude_none=True, mode='json' + exclude_none=True, mode="json" ) for content in data ]), ) else: - span.set_attribute('gcp.vertex.agent.data', '{}') + span.set_attribute("gcp.vertex.agent.data", "{}") def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: @@ -396,18 +397,18 @@ def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: """ # Some fields in LlmRequest are function pointers and cannot be serialized. result = { - 'model': llm_request.model, - 'config': llm_request.config.model_dump( - exclude_none=True, exclude='response_schema', mode='json' + "model": llm_request.model, + "config": llm_request.config.model_dump( + exclude_none=True, exclude="response_schema", mode="json" ), - 'contents': [], + "contents": [], } # We do not want to send bytes data to the trace. for content in llm_request.contents: parts = [part for part in content.parts if not part.inline_data] - result['contents'].append( + result["contents"].append( types.Content(role=content.role, parts=parts).model_dump( - exclude_none=True, mode='json' + exclude_none=True, mode="json" ) ) return result @@ -419,8 +420,8 @@ def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: # to false. def _should_add_request_response_to_spans() -> bool: disabled_via_env_var = os.getenv( - ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, 'true' - ).lower() in ('false', '0') + ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, "true" + ).lower() in ("false", "0") return not disabled_via_env_var @@ -438,7 +439,7 @@ def use_generate_content_span( common_attributes = { GEN_AI_CONVERSATION_ID: invocation_context.session.id, - 'gcp.vertex.agent.event_id': model_response_event.id, + "gcp.vertex.agent.event_id": model_response_event.id, } if ( _is_gemini_agent(invocation_context.agent) @@ -455,8 +456,8 @@ def use_generate_content_span( def _should_log_prompt_response_content() -> bool: return os.getenv( - OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, '' - ).lower() in ('1', 'true') + OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, "" + ).lower() in ("1", "true") def _serialize_content(content: types.ContentUnion) -> AnyValue: @@ -481,9 +482,9 @@ def _serialize_content_with_elision( def _instrumented_with_opentelemetry_instrumentation_google_genai() -> bool: maybe_wrapped_function = Models.generate_content - while wrapped := getattr(maybe_wrapped_function, '__wrapped__', None): + while wrapped := getattr(maybe_wrapped_function, "__wrapped__", None): if ( - 'opentelemetry/instrumentation/google_genai' + "opentelemetry/instrumentation/google_genai" in maybe_wrapped_function.__code__.co_filename ): return True @@ -515,15 +516,15 @@ def _use_native_generate_content_span( f"generate_content {llm_request.model or ''}" ) as span: span.set_attribute(GEN_AI_SYSTEM, _guess_gemini_system_name()) - span.set_attribute(GEN_AI_OPERATION_NAME, 'generate_content') - span.set_attribute(GEN_AI_REQUEST_MODEL, llm_request.model or '') + span.set_attribute(GEN_AI_OPERATION_NAME, "generate_content") + span.set_attribute(GEN_AI_REQUEST_MODEL, llm_request.model or "") span.set_attributes(common_attributes) otel_logger.emit( LogRecord( - event_name='gen_ai.system.message', + event_name="gen_ai.system.message", body={ - 'content': _serialize_content_with_elision( + "content": _serialize_content_with_elision( 
llm_request.config.system_instruction ) }, @@ -534,8 +535,8 @@ def _use_native_generate_content_span( for content in llm_request.contents: otel_logger.emit( LogRecord( - event_name='gen_ai.user.message', - body={'content': _serialize_content_with_elision(content)}, + event_name="gen_ai.user.message", + body={"content": _serialize_content_with_elision(content)}, attributes={GEN_AI_SYSTEM: _guess_gemini_system_name()}, ) ) @@ -566,12 +567,12 @@ def trace_generate_content_result(span: Span | None, llm_response: LlmResponse): otel_logger.emit( LogRecord( - event_name='gen_ai.choice', + event_name="gen_ai.choice", body={ - 'content': _serialize_content_with_elision(llm_response.content), - 'index': 0, # ADK always returns a single candidate + "content": _serialize_content_with_elision(llm_response.content), + "index": 0, # ADK always returns a single candidate } - | {'finish_reason': llm_response.finish_reason.value} + | {"finish_reason": llm_response.finish_reason.value} if llm_response.finish_reason is not None else {}, attributes={GEN_AI_SYSTEM: _guess_gemini_system_name()}, @@ -582,6 +583,184 @@ def trace_generate_content_result(span: Span | None, llm_response: LlmResponse): def _guess_gemini_system_name() -> str: return ( GenAiSystemValues.VERTEX_AI.name.lower() - if os.getenv('GOOGLE_GENAI_USE_VERTEXAI', '').lower() in ('true', '1') + if os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in ("true", "1") else GenAiSystemValues.GEMINI.name.lower() ) + + +# ============================================================================= +# CodingAgent-specific tracing functions +# ============================================================================= + + +def trace_code_generation( + agent_name: str, + code: str, + iteration: int, + duration_ms: float, +) -> None: + """Traces code generation by CodingAgent. + + Args: + agent_name: Name of the CodingAgent. + code: The generated code. + iteration: Current iteration number. + duration_ms: Time taken for code generation in milliseconds. + """ + span = trace.get_current_span() + + span.set_attribute(GEN_AI_OPERATION_NAME, "generate_code") + span.set_attribute(GEN_AI_AGENT_NAME, agent_name) + span.set_attribute("gcp.vertex.agent.coding_agent.iteration", iteration) + span.set_attribute( + "gcp.vertex.agent.coding_agent.generation_duration_ms", duration_ms + ) + + if _should_add_request_response_to_spans(): + # Truncate code if too long for span attribute + max_code_length = 10000 + truncated_code = code[:max_code_length] + if len(code) > max_code_length: + truncated_code += "\n... [truncated]" + span.set_attribute( + "gcp.vertex.agent.coding_agent.generated_code", truncated_code + ) + else: + span.set_attribute("gcp.vertex.agent.coding_agent.generated_code", "{}") + + +def trace_code_execution( + agent_name: str, + code: str, + stdout: str, + stderr: str, + duration_ms: float, + success: bool, + has_final_answer: bool, +) -> None: + """Traces code execution by CodingAgent. + + Args: + agent_name: Name of the CodingAgent. + code: The executed code. + stdout: Standard output from execution. + stderr: Standard error from execution. + duration_ms: Time taken for execution in milliseconds. + success: Whether execution was successful. + has_final_answer: Whether execution produced a final answer. 
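+
+  Example (a sketch; ``exec_result`` and its fields are assumptions about
+  the caller's executor result object, not part of this API):
+
+    trace_code_execution(
+        agent_name=agent.name,
+        code=code,
+        stdout=exec_result.stdout,
+        stderr=exec_result.stderr,
+        duration_ms=(time.time() - start) * 1000,
+        success=exec_result.error is None,
+        has_final_answer="__FINAL_ANSWER__" in exec_result.stdout,
+    )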
+ """ + span = trace.get_current_span() + + span.set_attribute(GEN_AI_OPERATION_NAME, "execute_code") + span.set_attribute(GEN_AI_AGENT_NAME, agent_name) + span.set_attribute( + "gcp.vertex.agent.coding_agent.execution_duration_ms", duration_ms + ) + span.set_attribute("gcp.vertex.agent.coding_agent.execution_success", success) + span.set_attribute( + "gcp.vertex.agent.coding_agent.has_final_answer", has_final_answer + ) + + if _should_add_request_response_to_spans(): + # Truncate outputs if too long + max_output_length = 5000 + truncated_stdout = stdout[:max_output_length] + if len(stdout) > max_output_length: + truncated_stdout += "\n... [truncated]" + truncated_stderr = stderr[:max_output_length] + if len(stderr) > max_output_length: + truncated_stderr += "\n... [truncated]" + truncated_code = code[:max_output_length] + if len(code) > max_output_length: + truncated_code += "\n... [truncated]" + + span.set_attribute( + "gcp.vertex.agent.coding_agent.executed_code", truncated_code + ) + span.set_attribute("gcp.vertex.agent.coding_agent.stdout", truncated_stdout) + span.set_attribute("gcp.vertex.agent.coding_agent.stderr", truncated_stderr) + else: + span.set_attribute("gcp.vertex.agent.coding_agent.executed_code", "{}") + span.set_attribute("gcp.vertex.agent.coding_agent.stdout", "{}") + span.set_attribute("gcp.vertex.agent.coding_agent.stderr", "{}") + + +def trace_import_validation( + agent_name: str, + code: str, + violations: list[str], + duration_ms: float, +) -> None: + """Traces import validation by CodingAgent. + + Args: + agent_name: Name of the CodingAgent. + code: The code that was validated. + violations: List of import violations found. + duration_ms: Time taken for validation in milliseconds. + """ + span = trace.get_current_span() + + span.set_attribute(GEN_AI_OPERATION_NAME, "validate_imports") + span.set_attribute(GEN_AI_AGENT_NAME, agent_name) + span.set_attribute( + "gcp.vertex.agent.coding_agent.validation_duration_ms", duration_ms + ) + span.set_attribute( + "gcp.vertex.agent.coding_agent.violation_count", len(violations) + ) + span.set_attribute( + "gcp.vertex.agent.coding_agent.validation_passed", len(violations) == 0 + ) + + if _should_add_request_response_to_spans() and violations: + span.set_attribute( + "gcp.vertex.agent.coding_agent.violations", + _safe_json_serialize(violations), + ) + + +def trace_tool_ipc( + agent_name: str, + tool_name: str, + args: dict[str, Any], + result: Any, + duration_ms: float, + success: bool, + error: Optional[str] = None, +) -> None: + """Traces tool IPC calls from container to host. + + Args: + agent_name: Name of the CodingAgent. + tool_name: Name of the tool called. + args: Arguments passed to the tool. + result: Result returned by the tool. + duration_ms: Time taken for the IPC call in milliseconds. + success: Whether the call was successful. + error: Error message if call failed. 
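+
+  Note: like the other helpers in this module, this annotates the
+  current span via trace.get_current_span() rather than starting a new
+  one, so callers should invoke it from inside an active span.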
+ """ + span = trace.get_current_span() + + span.set_attribute(GEN_AI_OPERATION_NAME, "tool_ipc") + span.set_attribute(GEN_AI_AGENT_NAME, agent_name) + span.set_attribute(GEN_AI_TOOL_NAME, tool_name) + span.set_attribute( + "gcp.vertex.agent.coding_agent.ipc_duration_ms", duration_ms + ) + span.set_attribute("gcp.vertex.agent.coding_agent.ipc_success", success) + + if error: + span.set_attribute("gcp.vertex.agent.coding_agent.ipc_error", error) + + if _should_add_request_response_to_spans(): + span.set_attribute( + "gcp.vertex.agent.coding_agent.tool_args", _safe_json_serialize(args) + ) + span.set_attribute( + "gcp.vertex.agent.coding_agent.tool_result", + _safe_json_serialize(result), + ) + else: + span.set_attribute("gcp.vertex.agent.coding_agent.tool_args", "{}") + span.set_attribute("gcp.vertex.agent.coding_agent.tool_result", "{}") diff --git a/tests/unittests/agents/test_coding_agent.py b/tests/unittests/agents/test_coding_agent.py new file mode 100644 index 0000000000..657b1e87da --- /dev/null +++ b/tests/unittests/agents/test_coding_agent.py @@ -0,0 +1,310 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for CodingAgent.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.agents.coding_agent import CodingAgent +from google.adk.agents.coding_agent import CodingAgentState +from google.adk.agents.coding_agent_config import _EXTENDED_SAFE_IMPORTS +from google.adk.agents.coding_agent_config import CodingAgentConfig +from google.adk.tools.base_tool import BaseTool +import pytest + + +class TestCodingAgentConfig: + """Tests for CodingAgentConfig.""" + + def test_default_values(self): + """Test that default values are set correctly.""" + config = CodingAgentConfig(name="test_agent") + + assert config.name == "test_agent" + assert config.agent_class == "CodingAgent" + assert config.max_iterations == 10 + assert config.error_retry_attempts == 2 + assert config.stateful is False + assert config.tool_server_port == 8765 + assert config.authorized_imports == _EXTENDED_SAFE_IMPORTS + + def test_custom_values(self): + """Test that custom values can be set.""" + custom_imports = frozenset({"json", "math"}) + config = CodingAgentConfig( + name="custom_agent", + model="gemini-2.0-flash", + max_iterations=20, + error_retry_attempts=5, + stateful=True, + tool_server_port=9999, + authorized_imports=custom_imports, + ) + + assert config.name == "custom_agent" + assert config.model == "gemini-2.0-flash" + assert config.max_iterations == 20 + assert config.error_retry_attempts == 5 + assert config.stateful is True + assert config.tool_server_port == 9999 + assert config.authorized_imports == custom_imports + + def test_max_iterations_bounds(self): + """Test max_iterations validation.""" + # Valid bounds + config = CodingAgentConfig(name="test", max_iterations=1) + assert config.max_iterations == 1 + + config = CodingAgentConfig(name="test", 
max_iterations=100) + assert config.max_iterations == 100 + + # Invalid bounds + with pytest.raises(ValueError): + CodingAgentConfig(name="test", max_iterations=0) + + with pytest.raises(ValueError): + CodingAgentConfig(name="test", max_iterations=101) + + def test_port_bounds(self): + """Test tool_server_port validation.""" + # Valid bounds + config = CodingAgentConfig(name="test", tool_server_port=1024) + assert config.tool_server_port == 1024 + + config = CodingAgentConfig(name="test", tool_server_port=65535) + assert config.tool_server_port == 65535 + + # Invalid bounds + with pytest.raises(ValueError): + CodingAgentConfig(name="test", tool_server_port=1023) + + with pytest.raises(ValueError): + CodingAgentConfig(name="test", tool_server_port=65536) + + +class TestCodingAgentState: + """Tests for CodingAgentState.""" + + def test_default_state(self): + """Test default state values.""" + state = CodingAgentState() + + assert state.iteration_count == 0 + assert state.error_count == 0 + assert state.execution_history == [] + + def test_state_with_history(self): + """Test state with execution history.""" + history = [ + {"iteration": 1, "code": "print('hello')", "success": True}, + {"iteration": 2, "code": "print('world')", "success": True}, + ] + state = CodingAgentState( + iteration_count=2, + error_count=0, + execution_history=history, + ) + + assert state.iteration_count == 2 + assert len(state.execution_history) == 2 + + def test_state_serialization(self): + """Test state can be serialized and deserialized.""" + state = CodingAgentState( + iteration_count=5, + error_count=1, + execution_history=[{"iteration": 1, "code": "x = 1"}], + ) + + dumped = state.model_dump() + restored = CodingAgentState.model_validate(dumped) + + assert restored.iteration_count == 5 + assert restored.error_count == 1 + assert len(restored.execution_history) == 1 + + +class TestCodingAgent: + """Tests for CodingAgent.""" + + def test_agent_creation(self): + """Test basic agent creation.""" + agent = CodingAgent( + name="test_coding_agent", + description="A test coding agent", + ) + + assert agent.name == "test_coding_agent" + assert agent.description == "A test coding agent" + assert agent.max_iterations == 10 + assert agent.error_retry_attempts == 2 + + def test_agent_with_custom_config(self): + """Test agent with custom configuration.""" + agent = CodingAgent( + name="custom_agent", + model="gemini-2.0-flash", + max_iterations=5, + error_retry_attempts=3, + stateful=True, + ) + + assert agent.name == "custom_agent" + assert agent.model == "gemini-2.0-flash" + assert agent.max_iterations == 5 + assert agent.error_retry_attempts == 3 + assert agent.stateful is True + + def test_extract_code_block_tool_code(self): + """Test code extraction from tool_code blocks.""" + agent = CodingAgent(name="test") + + response = """Here's some code: +```tool_code +result = search(query="test") +print(result) +``` +That should work.""" + + code = agent._extract_code_block(response) + assert code == 'result = search(query="test")\nprint(result)' + + def test_extract_code_block_python(self): + """Test code extraction from python blocks.""" + agent = CodingAgent(name="test") + + response = """Here's some code: +```python +x = 1 + 2 +print(x) +``` +Done.""" + + code = agent._extract_code_block(response) + assert code == "x = 1 + 2\nprint(x)" + + def test_extract_code_block_prefers_tool_code(self): + """Test that tool_code blocks are preferred over python blocks.""" + agent = CodingAgent(name="test") + + response = """Code: 
+```tool_code +tool_result = tool_call() +``` +Also: +```python +python_code = True +```""" + + code = agent._extract_code_block(response) + assert code == "tool_result = tool_call()" + + def test_extract_code_block_no_code(self): + """Test code extraction when no code block present.""" + agent = CodingAgent(name="test") + + response = "This is just text without any code blocks." + code = agent._extract_code_block(response) + assert code is None + + def test_build_error_feedback(self): + """Test error feedback formatting.""" + agent = CodingAgent(name="test") + + error = "NameError: name 'undefined_var' is not defined" + code = "print(undefined_var)" + + feedback = agent._build_error_feedback(error, code) + + assert "NameError" in feedback + assert "undefined_var" in feedback + assert code in feedback + assert "fix the error" in feedback.lower() + + def test_default_model(self): + """Test that default model is used when not specified.""" + agent = CodingAgent(name="test") + + # canonical_model property should return a BaseLlm + model = agent.canonical_model + assert model is not None + + def test_cleanup(self): + """Test that cleanup releases resources.""" + agent = CodingAgent(name="test") + agent._resolved_tools = [MagicMock()] + agent._coding_executor = MagicMock() + + agent.cleanup() + + assert agent._resolved_tools is None + assert agent._coding_executor is None + + +class TestCodingAgentTools: + """Tests for CodingAgent tool handling.""" + + def test_agent_with_function_tools(self): + """Test agent with function tools.""" + + def my_tool(query: str) -> dict: + """A test tool.""" + return {"result": query} + + agent = CodingAgent( + name="test", + tools=[my_tool], + ) + + assert len(agent.tools) == 1 + + def test_agent_with_base_tool(self): + """Test agent with BaseTool instances.""" + + class MockTool(BaseTool): + + def __init__(self): + super().__init__(name="mock_tool", description="A mock tool") + + async def run_async(self, *, args, tool_context): + return {"result": "mock"} + + tool = MockTool() + agent = CodingAgent( + name="test", + tools=[tool], + ) + + assert len(agent.tools) == 1 + + @pytest.mark.asyncio + async def test_resolve_tools(self): + """Test tool resolution.""" + + def test_func(x: int) -> int: + """Test function.""" + return x * 2 + + agent = CodingAgent( + name="test", + tools=[test_func], + ) + + tools = await agent._resolve_tools() + assert len(tools) == 1 + assert tools[0].name == "test_func" diff --git a/tests/unittests/code_executors/test_allowlist_validator.py b/tests/unittests/code_executors/test_allowlist_validator.py new file mode 100644 index 0000000000..58ae45771a --- /dev/null +++ b/tests/unittests/code_executors/test_allowlist_validator.py @@ -0,0 +1,320 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for AllowlistValidator.""" + +from __future__ import annotations + +from google.adk.code_executors.allowlist_validator import AllowlistValidator +from google.adk.code_executors.allowlist_validator import DEFAULT_SAFE_IMPORTS +from google.adk.code_executors.allowlist_validator import extract_imports +from google.adk.code_executors.allowlist_validator import ImportValidationError +from google.adk.code_executors.allowlist_validator import is_import_allowed +from google.adk.code_executors.allowlist_validator import validate_imports +import pytest + + +class TestExtractImports: + """Tests for extract_imports function.""" + + def test_simple_import(self): + """Test extracting simple import statements.""" + code = "import json" + imports = extract_imports(code) + + assert len(imports) == 1 + assert imports[0].module == "json" + assert imports[0].is_from_import is False + + def test_multiple_imports(self): + """Test extracting multiple imports.""" + code = """ +import json +import math +import re +""" + imports = extract_imports(code) + + assert len(imports) == 3 + modules = {i.module for i in imports} + assert modules == {"json", "math", "re"} + + def test_from_import(self): + """Test extracting from imports.""" + code = "from collections import defaultdict" + imports = extract_imports(code) + + assert len(imports) == 1 + assert imports[0].module == "collections" + assert imports[0].names == ["defaultdict"] + assert imports[0].is_from_import is True + + def test_from_import_multiple(self): + """Test extracting from imports with multiple names.""" + code = "from typing import List, Dict, Optional" + imports = extract_imports(code) + + assert len(imports) == 3 + for imp in imports: + assert imp.module == "typing" + assert imp.is_from_import is True + + def test_import_with_alias(self): + """Test extracting imports with aliases.""" + code = "import numpy as np" + imports = extract_imports(code) + + assert len(imports) == 1 + assert imports[0].module == "numpy" + assert imports[0].alias == "np" + + def test_submodule_import(self): + """Test extracting submodule imports.""" + code = "import os.path" + imports = extract_imports(code) + + assert len(imports) == 1 + assert imports[0].module == "os.path" + + def test_from_submodule_import(self): + """Test extracting from submodule imports.""" + code = "from collections.abc import Mapping" + imports = extract_imports(code) + + assert len(imports) == 1 + assert imports[0].module == "collections.abc" + assert imports[0].names == ["Mapping"] + + def test_syntax_error(self): + """Test handling of syntax errors.""" + code = "import json\nthis is not valid python" + + with pytest.raises(SyntaxError): + extract_imports(code) + + def test_no_imports(self): + """Test code with no imports.""" + code = "x = 1 + 2\nprint(x)" + imports = extract_imports(code) + + assert len(imports) == 0 + + +class TestIsImportAllowed: + """Tests for is_import_allowed function.""" + + def test_direct_match(self): + """Test direct import match.""" + allowlist = frozenset({"json", "math"}) + + assert is_import_allowed("json", allowlist) is True + assert is_import_allowed("math", allowlist) is True + assert is_import_allowed("os", allowlist) is False + + def test_wildcard_match(self): + """Test wildcard pattern matching.""" + allowlist = frozenset({"collections.*"}) + + assert is_import_allowed("collections.abc", allowlist) is True + assert is_import_allowed("collections.defaultdict", allowlist) is True + assert is_import_allowed("itertools", allowlist) is False + + def 
test_deep_wildcard_match(self): + """Test wildcard matching for deep submodules.""" + allowlist = frozenset({"collections.*"}) + + assert is_import_allowed("collections.abc.Mapping", allowlist) is True + + def test_exact_vs_wildcard(self): + """Test that exact matches work without wildcard.""" + allowlist = frozenset({"numpy"}) + + assert is_import_allowed("numpy", allowlist) is True + # Without wildcard, submodules are not allowed + assert is_import_allowed("numpy.array", allowlist) is False + + def test_multiple_patterns(self): + """Test multiple patterns in allowlist.""" + allowlist = frozenset({"json", "typing.*", "collections"}) + + assert is_import_allowed("json", allowlist) is True + assert is_import_allowed("typing.List", allowlist) is True + assert is_import_allowed("collections", allowlist) is True + assert is_import_allowed("collections.abc", allowlist) is False + + +class TestValidateImports: + """Tests for validate_imports function.""" + + def test_all_allowed(self): + """Test code with all imports allowed.""" + code = """ +import json +import math +from typing import List +""" + allowlist = frozenset({"json", "math", "typing.*"}) + + violations = validate_imports(code, allowlist) + assert len(violations) == 0 + + def test_some_violations(self): + """Test code with some unauthorized imports.""" + code = """ +import json +import os +import subprocess +""" + allowlist = frozenset({"json"}) + + violations = validate_imports(code, allowlist) + assert len(violations) == 2 + assert any("os" in v for v in violations) + assert any("subprocess" in v for v in violations) + + def test_from_import_violations(self): + """Test from import violations.""" + code = "from os import system" + allowlist = frozenset({"json"}) + + violations = validate_imports(code, allowlist) + assert len(violations) == 1 + assert "os" in violations[0] + + def test_syntax_error_violation(self): + """Test that syntax errors are reported as violations.""" + code = "import json\n$$$invalid" + allowlist = frozenset({"json"}) + + violations = validate_imports(code, allowlist) + assert len(violations) == 1 + assert "Syntax error" in violations[0] + + +class TestImportValidationError: + """Tests for ImportValidationError exception.""" + + def test_error_message(self): + """Test error message formatting.""" + violations = ["Unauthorized import: os", "Unauthorized import: subprocess"] + code = "import os\nimport subprocess" + + error = ImportValidationError(violations, code) + + assert "Import validation failed" in str(error) + assert "os" in str(error) + assert "subprocess" in str(error) + + def test_error_attributes(self): + """Test error attributes.""" + violations = ["violation1", "violation2"] + code = "some code" + + error = ImportValidationError(violations, code) + + assert error.violations == violations + assert error.code == code + + +class TestAllowlistValidator: + """Tests for AllowlistValidator class.""" + + def test_default_allowlist(self): + """Test validator with default allowlist.""" + validator = AllowlistValidator() + + # These should be in the default safe imports + assert validator.is_allowed("json") is True + assert validator.is_allowed("math") is True + assert validator.is_allowed("typing") is True + + # These should not be in the default safe imports + assert validator.is_allowed("os") is False + assert validator.is_allowed("subprocess") is False + + def test_custom_allowlist(self): + """Test validator with custom allowlist.""" + custom = frozenset({"custom_module"}) + validator = 
AllowlistValidator(allowlist=custom) + + assert validator.is_allowed("custom_module") is True + assert validator.is_allowed("json") is False + + def test_additional_imports(self): + """Test adding additional imports to default.""" + additional = frozenset({"custom_module", "another_module"}) + validator = AllowlistValidator(additional_imports=additional) + + # Should have both default and additional + assert validator.is_allowed("json") is True + assert validator.is_allowed("custom_module") is True + assert validator.is_allowed("another_module") is True + + def test_validate_method(self): + """Test validate method returns violations.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) + + violations = validator.validate("import json\nimport os") + assert len(violations) == 1 + assert "os" in violations[0] + + def test_validate_strict_raises(self): + """Test validate_strict raises on violations.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) + + with pytest.raises(ImportValidationError): + validator.validate_strict("import os") + + def test_validate_strict_passes(self): + """Test validate_strict passes when no violations.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) + + # Should not raise + validator.validate_strict("import json") + + def test_add_allowed_imports(self): + """Test adding imports after construction.""" + validator = AllowlistValidator(allowlist=frozenset({"json"})) + + assert validator.is_allowed("os") is False + + validator.add_allowed_imports({"os"}) + + assert validator.is_allowed("os") is True + + +class TestDefaultSafeImports: + """Tests for the default safe imports list.""" + + def test_common_safe_imports_included(self): + """Test that common safe imports are in the default list.""" + assert "json" in DEFAULT_SAFE_IMPORTS + assert "math" in DEFAULT_SAFE_IMPORTS + assert "re" in DEFAULT_SAFE_IMPORTS + assert "datetime" in DEFAULT_SAFE_IMPORTS + assert "typing" in DEFAULT_SAFE_IMPORTS + assert "collections" in DEFAULT_SAFE_IMPORTS + + def test_dangerous_imports_not_included(self): + """Test that dangerous imports are not in the default list.""" + assert "os" not in DEFAULT_SAFE_IMPORTS + assert "subprocess" not in DEFAULT_SAFE_IMPORTS + assert "sys" not in DEFAULT_SAFE_IMPORTS + assert "socket" not in DEFAULT_SAFE_IMPORTS + assert "ctypes" not in DEFAULT_SAFE_IMPORTS + + def test_wildcard_patterns_included(self): + """Test that wildcard patterns are included.""" + assert "collections.*" in DEFAULT_SAFE_IMPORTS + assert "typing.*" in DEFAULT_SAFE_IMPORTS diff --git a/tests/unittests/code_executors/test_tool_code_generator.py b/tests/unittests/code_executors/test_tool_code_generator.py new file mode 100644 index 0000000000..5a067e47e2 --- /dev/null +++ b/tests/unittests/code_executors/test_tool_code_generator.py @@ -0,0 +1,319 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for ToolCodeGenerator.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from google.adk.code_executors.tool_code_generator import generate_full_code_with_stubs +from google.adk.code_executors.tool_code_generator import generate_runtime_header +from google.adk.code_executors.tool_code_generator import generate_system_prompt +from google.adk.code_executors.tool_code_generator import generate_tool_stubs +from google.adk.tools.base_tool import BaseTool +from google.genai import types +import pytest + + +class MockTool(BaseTool): + """Mock tool for testing.""" + + def __init__( + self, + name: str = "mock_tool", + description: str = "A mock tool for testing", + params: dict = None, + ): + super().__init__(name=name, description=description) + self._params = params or {} + + def _get_declaration(self): + properties = {} + required = [] + + for param_name, param_info in self._params.items(): + properties[param_name] = { + "type": param_info.get("type", "string"), + "description": param_info.get("description", ""), + } + if param_info.get("required", False): + required.append(param_name) + + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type="object", + properties=properties, + required=required if required else None, + ), + ) + + async def run_async(self, *, args, tool_context): + return {"result": "mock"} + + +class TestGenerateRuntimeHeader: + """Tests for generate_runtime_header function.""" + + def test_generates_valid_header(self): + """Test that the header contains required elements.""" + url = "http://localhost:8765" + header = generate_runtime_header(url) + + # Should contain the URL + assert url in header + + # Should contain helper functions + assert "_call_adk_tool" in header + assert "final_answer" in header + assert "__get_tool_traces" in header + + # Should be valid Python syntax + compile(header, "", "exec") + + def test_header_with_different_urls(self): + """Test header generation with different URLs.""" + urls = [ + "http://localhost:8765", + "http://host.docker.internal:9999", + "http://172.17.0.1:8765", + ] + + for url in urls: + header = generate_runtime_header(url) + assert url in header + + def test_header_contains_trace_collection(self): + """Test that header contains trace collection code.""" + header = generate_runtime_header("http://localhost:8765") + + assert "__ADK_TOOL_TRACES" in header + assert "__get_tool_traces" in header + assert "__clear_tool_traces" in header + + def test_header_contains_final_answer_marker(self): + """Test that header contains final answer marker.""" + header = generate_runtime_header("http://localhost:8765") + + assert "__FINAL_ANSWER__" in header + assert "final_answer" in header + + +class TestGenerateToolStubs: + """Tests for generate_tool_stubs function.""" + + def test_generates_stub_for_tool(self): + """Test generating stub for a single tool.""" + tool = MockTool( + name="search", + description="Search for information", + params={ + "query": { + "type": "string", + "description": "The search query", + "required": True, + } + }, + ) + + stubs = generate_tool_stubs([tool]) + + # Should contain function definition + assert "def search(" in stubs + assert "query" in stubs + + # Should be valid Python + compile(stubs, "", "exec") + + def test_generates_stubs_for_multiple_tools(self): + """Test generating stubs for multiple tools.""" + tools = [ + MockTool(name="tool1", description="First tool"), + MockTool(name="tool2", 
description="Second tool"), + MockTool(name="tool3", description="Third tool"), + ] + + stubs = generate_tool_stubs(tools) + + assert "def tool1(" in stubs + assert "def tool2(" in stubs + assert "def tool3(" in stubs + + def test_stub_includes_docstring(self): + """Test that stubs include docstrings.""" + tool = MockTool( + name="my_tool", + description="A tool that does something useful", + ) + + stubs = generate_tool_stubs([tool]) + + assert '"""' in stubs + assert "A tool that does something useful" in stubs + + def test_stub_includes_type_hints(self): + """Test that stubs include type hints.""" + tool = MockTool( + name="typed_tool", + description="A typed tool", + params={ + "count": {"type": "integer", "description": "A count"}, + "name": {"type": "string", "description": "A name"}, + "enabled": {"type": "boolean", "description": "Is enabled"}, + }, + ) + + stubs = generate_tool_stubs([tool]) + + assert "int" in stubs + assert "str" in stubs + assert "bool" in stubs + + def test_empty_tool_list(self): + """Test generating stubs for empty tool list.""" + stubs = generate_tool_stubs([]) + + # Should still be valid Python + compile(stubs, "", "exec") + + +class TestGenerateSystemPrompt: + """Tests for generate_system_prompt function.""" + + def test_generates_prompt_with_tools(self): + """Test generating system prompt with tools.""" + tools = [ + MockTool( + name="search", + description="Search the web", + params={"query": {"type": "string", "required": True}}, + ), + ] + + prompt = generate_system_prompt(tools) + + # Should contain tool documentation + assert "search" in prompt + assert "Search the web" in prompt + + # Should contain usage instructions + assert "tool_code" in prompt + assert "final_answer" in prompt + + def test_generates_prompt_with_custom_instruction(self): + """Test generating prompt with custom instruction.""" + tools = [] + custom = "Always be polite and helpful." 
+ + prompt = generate_system_prompt(tools, custom_instruction=custom) + + assert custom in prompt + + def test_generates_prompt_with_examples(self): + """Test that prompt contains examples.""" + tools = [] + prompt = generate_system_prompt(tools) + + assert "Example" in prompt + assert "```tool_code" in prompt + + def test_generates_prompt_with_parameter_docs(self): + """Test that prompt includes parameter documentation.""" + tools = [ + MockTool( + name="get_weather", + description="Get weather for a city", + params={ + "city": { + "type": "string", + "description": "The city name", + "required": True, + }, + "units": { + "type": "string", + "description": "Temperature units", + "required": False, + }, + }, + ), + ] + + prompt = generate_system_prompt(tools) + + assert "city" in prompt + assert "units" in prompt + assert "required" in prompt.lower() or "optional" in prompt.lower() + + +class TestGenerateFullCodeWithStubs: + """Tests for generate_full_code_with_stubs function.""" + + def test_generates_complete_code(self): + """Test generating complete executable code.""" + tools = [MockTool(name="my_tool", description="A tool")] + user_code = "result = my_tool()\nprint(result)" + + full_code = generate_full_code_with_stubs( + user_code=user_code, + tools=tools, + tool_server_url="http://localhost:8765", + ) + + # Should contain runtime header + assert "_call_adk_tool" in full_code + + # Should contain tool stub + assert "def my_tool(" in full_code + + # Should contain user code + assert user_code in full_code + + # Should be valid Python + compile(full_code, "", "exec") + + def test_generated_code_outputs_traces(self): + """Test that generated code outputs traces.""" + tools = [] + user_code = "x = 1" + + full_code = generate_full_code_with_stubs( + user_code=user_code, + tools=tools, + tool_server_url="http://localhost:8765", + ) + + assert "__TOOL_TRACE__" in full_code + + def test_generated_code_is_executable(self): + """Test that generated code can be compiled.""" + tools = [ + MockTool(name="tool_a", description="Tool A"), + MockTool(name="tool_b", description="Tool B"), + ] + user_code = """ +result_a = tool_a() +result_b = tool_b() +print(result_a, result_b) +""" + + full_code = generate_full_code_with_stubs( + user_code=user_code, + tools=tools, + tool_server_url="http://localhost:8765", + ) + + # Should compile without errors + compile(full_code, "", "exec")
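+
+  def test_user_code_follows_stubs(self):
+    """Test that sections appear in the documented assembly order."""
+    # generate_full_code_with_stubs() lays out the runtime header, then
+    # the tool stubs, then the user code; check that ordering holds.
+    tools = [MockTool(name="my_tool", description="A tool")]
+    user_code = "result = my_tool()"
+
+    full_code = generate_full_code_with_stubs(
+        user_code=user_code,
+        tools=tools,
+        tool_server_url="http://localhost:8765",
+    )
+
+    assert full_code.index("_call_adk_tool") < full_code.index("def my_tool(")
+    assert full_code.index("def my_tool(") < full_code.index(user_code)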