spec_ai_core/tools/builtin/
generate_code.rs1use crate::agent::model::{GenerationConfig, ModelProvider};
2use crate::tools::{Tool, ToolResult};
3use anyhow::{anyhow, Context, Result};
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::Value;
7use std::sync::Arc;
8
9pub struct GenerateCodeTool {
11 provider: Arc<dyn ModelProvider>,
12}
13
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use serde_json::json;

    #[tokio::test]
    async fn generate_code_returns_content() {
        let provider = Arc::new(MockProvider::new("fn main() {}"));
        let tool = GenerateCodeTool::new(provider);

        let args = json!({
            "prompt": "write rust main"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);

        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(payload["content"], "fn main() {}");
        assert_eq!(payload["model"], "mock-model");
    }

    /// A prompt made only of whitespace is trimmed to "" and must be rejected
    /// before the provider is called.
    #[tokio::test]
    async fn generate_code_rejects_blank_prompt() {
        let provider = Arc::new(MockProvider::new("unused"));
        let tool = GenerateCodeTool::new(provider);

        let err = tool
            .execute(json!({ "prompt": "  \n\t " }))
            .await
            .expect_err("blank prompt must be rejected");
        assert!(err.to_string().contains("prompt cannot be empty"));
    }

    /// `prompt` is required; omitting it must fail argument parsing.
    #[tokio::test]
    async fn generate_code_rejects_missing_prompt() {
        let provider = Arc::new(MockProvider::new("unused"));
        let tool = GenerateCodeTool::new(provider);

        let err = tool
            .execute(json!({ "max_tokens": 16 }))
            .await
            .expect_err("missing prompt must be rejected");
        assert!(err.to_string().contains("parsing generate_code arguments"));
    }

    /// The advertised name and schema must agree with what `execute` requires.
    #[test]
    fn generate_code_schema_requires_prompt() {
        let provider = Arc::new(MockProvider::new("unused"));
        let tool = GenerateCodeTool::new(provider);

        assert_eq!(tool.name(), "generate_code");
        let schema = tool.parameters();
        assert_eq!(schema["required"], json!(["prompt"]));
        assert_eq!(schema["properties"]["prompt"]["type"], "string");
    }
}
37
38impl GenerateCodeTool {
39 pub fn new(provider: Arc<dyn ModelProvider>) -> Self {
40 Self { provider }
41 }
42}
43
44#[derive(Debug, Deserialize)]
45struct GenerateCodeArgs {
46 prompt: String,
47 #[serde(default)]
48 max_tokens: Option<u32>,
49 #[serde(default)]
50 temperature: Option<f32>,
51}
52
53#[async_trait]
54impl Tool for GenerateCodeTool {
55 fn name(&self) -> &str {
56 "generate_code"
57 }
58
59 fn description(&self) -> &str {
60 "Generate code or code reviews using the configured code model."
61 }
62
63 fn parameters(&self) -> Value {
64 serde_json::json!({
65 "type": "object",
66 "properties": {
67 "prompt": {
68 "type": "string",
69 "description": "Instruction or request for the code model"
70 },
71 "max_tokens": {
72 "type": "integer",
73 "description": "Optional max tokens to generate"
74 },
75 "temperature": {
76 "type": "number",
77 "description": "Optional temperature override (0.0 - 2.0)"
78 }
79 },
80 "required": ["prompt"]
81 })
82 }
83
84 async fn execute(&self, args: Value) -> Result<ToolResult> {
85 let args: GenerateCodeArgs =
86 serde_json::from_value(args).context("parsing generate_code arguments")?;
87
88 let prompt = args.prompt.trim();
89 if prompt.is_empty() {
90 return Err(anyhow!("prompt cannot be empty"));
91 }
92
93 let generation_config = GenerationConfig {
94 temperature: args.temperature.map(|t| t.clamp(0.0, 2.0)),
95 max_tokens: args.max_tokens,
96 stop_sequences: None,
97 top_p: None,
98 frequency_penalty: None,
99 presence_penalty: None,
100 };
101
102 let response = self
103 .provider
104 .generate(prompt, &generation_config)
105 .await
106 .context("calling code model")?;
107
108 let output = serde_json::json!({
109 "model": response.model,
110 "content": response.content,
111 "usage": response.usage,
112 "finish_reason": response.finish_reason
113 });
114
115 Ok(ToolResult::success(
116 serde_json::to_string(&output).context("serializing code model response")?,
117 ))
118 }
119}