// spec_ai_core/tools/builtin/generate_code.rs

1use crate::agent::model::{GenerationConfig, ModelProvider};
2use crate::tools::{Tool, ToolResult};
3use anyhow::{anyhow, Context, Result};
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::Value;
7use std::sync::Arc;
8
/// Generate code using a dedicated code model (configured via `model.code_model`).
///
/// Thin tool wrapper: `execute` parses JSON arguments, then delegates the
/// actual generation to the wrapped provider.
pub struct GenerateCodeTool {
    // Model backend that performs the generation; shared ownership so the
    // same provider instance can back multiple tools.
    provider: Arc<dyn ModelProvider>,
}
13
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use serde_json::json;

    /// The tool's JSON payload should carry the provider's content and the
    /// model name it reports.
    #[tokio::test]
    async fn generate_code_returns_content() {
        let mock = Arc::new(MockProvider::new("fn main() {}"));
        let tool = GenerateCodeTool::new(mock);

        let result = tool
            .execute(json!({ "prompt": "write rust main" }))
            .await
            .unwrap();
        assert!(result.success);

        let payload: serde_json::Value =
            serde_json::from_str(&result.output).expect("tool output must be valid JSON");
        assert_eq!(payload["content"], "fn main() {}");
        assert_eq!(payload["model"], "mock-model");
    }
}
37
38impl GenerateCodeTool {
39    pub fn new(provider: Arc<dyn ModelProvider>) -> Self {
40        Self { provider }
41    }
42}
43
44#[derive(Debug, Deserialize)]
45struct GenerateCodeArgs {
46    prompt: String,
47    #[serde(default)]
48    max_tokens: Option<u32>,
49    #[serde(default)]
50    temperature: Option<f32>,
51}
52
53#[async_trait]
54impl Tool for GenerateCodeTool {
55    fn name(&self) -> &str {
56        "generate_code"
57    }
58
59    fn description(&self) -> &str {
60        "Generate code or code reviews using the configured code model."
61    }
62
63    fn parameters(&self) -> Value {
64        serde_json::json!({
65            "type": "object",
66            "properties": {
67                "prompt": {
68                    "type": "string",
69                    "description": "Instruction or request for the code model"
70                },
71                "max_tokens": {
72                    "type": "integer",
73                    "description": "Optional max tokens to generate"
74                },
75                "temperature": {
76                    "type": "number",
77                    "description": "Optional temperature override (0.0 - 2.0)"
78                }
79            },
80            "required": ["prompt"]
81        })
82    }
83
84    async fn execute(&self, args: Value) -> Result<ToolResult> {
85        let args: GenerateCodeArgs =
86            serde_json::from_value(args).context("parsing generate_code arguments")?;
87
88        let prompt = args.prompt.trim();
89        if prompt.is_empty() {
90            return Err(anyhow!("prompt cannot be empty"));
91        }
92
93        let generation_config = GenerationConfig {
94            temperature: args.temperature.map(|t| t.clamp(0.0, 2.0)),
95            max_tokens: args.max_tokens,
96            stop_sequences: None,
97            top_p: None,
98            frequency_penalty: None,
99            presence_penalty: None,
100        };
101
102        let response = self
103            .provider
104            .generate(prompt, &generation_config)
105            .await
106            .context("calling code model")?;
107
108        let output = serde_json::json!({
109            "model": response.model,
110            "content": response.content,
111            "usage": response.usage,
112            "finish_reason": response.finish_reason
113        });
114
115        Ok(ToolResult::success(
116            serde_json::to_string(&output).context("serializing code model response")?,
117        ))
118    }
119}