1use async_trait::async_trait;
9use serde_json::json;
10
11use super::{Tool, ToolCtx, ToolResult};
12use crate::event::{Block, RiskLevel};
13
/// Returns the first usable API key found in the given environment variables.
///
/// Variables are consulted in the order given. A variable that is unset, is
/// not valid Unicode, or holds only whitespace is skipped.
///
/// The returned key has surrounding whitespace removed, so a value padded with
/// spaces or ending in a newline is not passed on verbatim to the caller.
///
/// Returns `None` when no listed variable holds a non-blank value.
fn resolve_key(env_names: &[&str]) -> Option<String> {
    env_names.iter().find_map(|name| {
        let raw = std::env::var(name).ok()?;
        let key = raw.trim();
        if key.is_empty() {
            None
        } else {
            Some(key.to_owned())
        }
    })
}
24
/// Image generation tool configuration.
///
/// Both fields are captured once, from the process environment, when the
/// value is constructed.
pub struct ImageGen {
    // Base URL of the image API; the request path is appended to it.
    base_url: String,
    // Model identifier sent with every generation request.
    model: String,
}

impl ImageGen {
    /// Builds the tool from `IMAGE_API_BASE` and `IMAGE_MODEL`.
    ///
    /// When a variable cannot be read, the built-in default is used:
    /// `https://api.openai.com/v1` for the base URL and `gpt-image-1` for the
    /// model.
    pub fn new() -> Self {
        let base_url = match std::env::var("IMAGE_API_BASE") {
            Ok(configured) => configured,
            Err(_) => String::from("https://api.openai.com/v1"),
        };
        let model = match std::env::var("IMAGE_MODEL") {
            Ok(configured) => configured,
            Err(_) => String::from("gpt-image-1"),
        };
        Self { base_url, model }
    }
}
42
43#[async_trait]
44impl Tool for ImageGen {
45 fn name(&self) -> &str {
46 "image_generate"
47 }
48 fn description(&self) -> &str {
49 "Generate an image from a text prompt. Saves a PNG into the workspace and returns its path."
50 }
51 fn schema(&self) -> serde_json::Value {
52 json!({
53 "type": "object",
54 "properties": {
55 "prompt": { "type": "string", "description": "Image description" },
56 "filename": { "type": "string", "description": "Output filename (default: generated.png)" },
57 "size": { "type": "string", "description": "e.g. 1024x1024" }
58 },
59 "required": ["prompt"]
60 })
61 }
62 fn risk(&self) -> RiskLevel {
63 RiskLevel::Network
64 }
65 async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
66 let Some(key) = resolve_key(&["IMAGE_API_KEY", "OPENAI_API_KEY"]) else {
67 return Ok(ToolResult::error(
68 "No image API key. Set IMAGE_API_KEY or OPENAI_API_KEY.",
69 ));
70 };
71 let prompt = args["prompt"].as_str().unwrap_or("");
72 let size = args["size"].as_str().unwrap_or("1024x1024");
73 let filename = args["filename"].as_str().unwrap_or("generated.png");
74
75 let endpoint = format!("{}/images/generations", self.base_url.trim_end_matches('/'));
76 if let Err(why) = crate::tools::search_and_web::validate_public_url(&endpoint) {
77 return Ok(ToolResult::error(format!(
78 "Refused IMAGE_API_BASE ({}): {}",
79 why, endpoint
80 )));
81 }
82 let client = reqwest::Client::new();
83 let resp = client
84 .post(&endpoint)
85 .bearer_auth(&key)
86 .json(&json!({
87 "model": self.model,
88 "prompt": prompt,
89 "size": size,
90 "n": 1,
91 "response_format": "b64_json"
92 }))
93 .send()
94 .await?;
95 if !resp.status().is_success() {
96 let status = resp.status();
97 let body = resp.text().await.unwrap_or_default();
98 return Ok(ToolResult::error(format!(
99 "image API error {}: {}",
100 status, body
101 )));
102 }
103 let value: serde_json::Value = resp.json().await?;
104 let b64 = value["data"][0]["b64_json"].as_str();
105 let url = value["data"][0]["url"].as_str();
106
107 if let Some(b64) = b64 {
108 let bytes = base64_decode::decode(b64)
109 .map_err(|e| anyhow::anyhow!("invalid base64 image: {}", e))?;
110 let path = super::resolve_workspace_path(&ctx.workspace_root, filename)?;
111 std::fs::write(&path, &bytes)?;
112 Ok(ToolResult::ok(vec![Block::Text(format!(
113 "image saved to {} ({} bytes)",
114 path.display(),
115 bytes.len()
116 ))]))
117 } else if let Some(url) = url {
118 Ok(ToolResult::ok(vec![Block::Text(format!(
119 "image generated: {}",
120 url
121 ))]))
122 } else {
123 Ok(ToolResult::error("image API returned no data"))
124 }
125 }
126}
127
/// Text-to-speech tool configuration.
///
/// Both fields are captured once, from the process environment, when the
/// value is constructed.
pub struct Tts {
    // Base URL of the speech API; the request path is appended to it.
    base_url: String,
    // Model identifier sent with every synthesis request.
    model: String,
}

impl Tts {
    /// Builds the tool from `TTS_API_BASE` and `TTS_MODEL`.
    ///
    /// When a variable cannot be read, the built-in default is used:
    /// `https://api.openai.com/v1` for the base URL and `gpt-4o-mini-tts` for
    /// the model.
    pub fn new() -> Self {
        let base_url = match std::env::var("TTS_API_BASE") {
            Ok(configured) => configured,
            Err(_) => String::from("https://api.openai.com/v1"),
        };
        let model = match std::env::var("TTS_MODEL") {
            Ok(configured) => configured,
            Err(_) => String::from("gpt-4o-mini-tts"),
        };
        Self { base_url, model }
    }
}
144
145#[async_trait]
146impl Tool for Tts {
147 fn name(&self) -> &str {
148 "text_to_speech"
149 }
150 fn description(&self) -> &str {
151 "Synthesize speech from text via an OpenAI-compatible /audio/speech endpoint. Saves an audio file into the workspace."
152 }
153 fn schema(&self) -> serde_json::Value {
154 json!({
155 "type": "object",
156 "properties": {
157 "text": { "type": "string" },
158 "voice": { "type": "string", "description": "e.g. alloy" },
159 "filename": { "type": "string", "description": "default: speech.mp3" }
160 },
161 "required": ["text"]
162 })
163 }
164 fn risk(&self) -> RiskLevel {
165 RiskLevel::Network
166 }
167 async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
168 let Some(key) = resolve_key(&["TTS_API_KEY", "OPENAI_API_KEY"]) else {
169 return Ok(ToolResult::error(
170 "No TTS API key. Set TTS_API_KEY or OPENAI_API_KEY.",
171 ));
172 };
173 let text = args["text"].as_str().unwrap_or("");
174 let voice = args["voice"].as_str().unwrap_or("alloy");
175 let filename = args["filename"].as_str().unwrap_or("speech.mp3");
176
177 let endpoint = format!("{}/audio/speech", self.base_url.trim_end_matches('/'));
178 if let Err(why) = crate::tools::search_and_web::validate_public_url(&endpoint) {
179 return Ok(ToolResult::error(format!(
180 "Refused TTS_API_BASE ({}): {}",
181 why, endpoint
182 )));
183 }
184 let client = reqwest::Client::new();
185 let resp = client
186 .post(&endpoint)
187 .bearer_auth(&key)
188 .json(&json!({ "model": self.model, "input": text, "voice": voice }))
189 .send()
190 .await?;
191 if !resp.status().is_success() {
192 let status = resp.status();
193 let body = resp.text().await.unwrap_or_default();
194 return Ok(ToolResult::error(format!(
195 "tts API error {}: {}",
196 status, body
197 )));
198 }
199 let bytes = resp.bytes().await?;
200 let path = super::resolve_workspace_path(&ctx.workspace_root, filename)?;
201 std::fs::write(&path, &bytes)?;
202 Ok(ToolResult::ok(vec![Block::Text(format!(
203 "audio saved to {} ({} bytes)",
204 path.display(),
205 bytes.len()
206 ))]))
207 }
208}
209
/// Audio transcription tool configuration.
///
/// Both fields are captured once, from the process environment, when the
/// value is constructed.
pub struct Transcribe {
    // Base URL of the transcription API; the request path is appended to it.
    base_url: String,
    // Model identifier sent with every transcription request.
    model: String,
}

impl Transcribe {
    /// Builds the tool from `TRANSCRIBE_API_BASE` and `TRANSCRIBE_MODEL`.
    ///
    /// When a variable cannot be read, the built-in default is used:
    /// `https://api.openai.com/v1` for the base URL and `whisper-1` for the
    /// model.
    pub fn new() -> Self {
        let base_url = match std::env::var("TRANSCRIBE_API_BASE") {
            Ok(configured) => configured,
            Err(_) => String::from("https://api.openai.com/v1"),
        };
        let model = match std::env::var("TRANSCRIBE_MODEL") {
            Ok(configured) => configured,
            Err(_) => String::from("whisper-1"),
        };
        Self { base_url, model }
    }
}
230
231#[async_trait]
232impl Tool for Transcribe {
233 fn name(&self) -> &str {
234 "transcribe"
235 }
236 fn description(&self) -> &str {
237 "Transcribe an audio file in the workspace to text via an OpenAI-compatible /audio/transcriptions endpoint."
238 }
239 fn schema(&self) -> serde_json::Value {
240 json!({
241 "type": "object",
242 "properties": {
243 "path": { "type": "string", "description": "Workspace-relative path to the audio file" },
244 "language": { "type": "string", "description": "Optional ISO-639-1 language hint" }
245 },
246 "required": ["path"]
247 })
248 }
249 fn risk(&self) -> RiskLevel {
250 RiskLevel::Network
251 }
252 async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
253 let Some(key) = resolve_key(&["TRANSCRIBE_API_KEY", "OPENAI_API_KEY"]) else {
254 return Ok(ToolResult::error(
255 "No transcription API key. Set TRANSCRIBE_API_KEY or OPENAI_API_KEY.",
256 ));
257 };
258 let path = args["path"].as_str().unwrap_or("");
259 if path.is_empty() {
260 return Ok(ToolResult::error("transcribe: missing 'path' argument"));
261 }
262 let full = super::resolve_workspace_path(&ctx.workspace_root, path)?;
263 if !full.exists() {
264 return Ok(ToolResult::error(format!("audio file not found: {}", path)));
265 }
266 let bytes = std::fs::read(&full)?;
267 let filename = full
268 .file_name()
269 .map(|s| s.to_string_lossy().to_string())
270 .unwrap_or_else(|| "audio.bin".into());
271 let mime = mime_guess::from_path(&full)
272 .first_or_octet_stream()
273 .to_string();
274
275 let part = reqwest::multipart::Part::bytes(bytes)
276 .file_name(filename)
277 .mime_str(&mime)
278 .unwrap_or_else(|_| reqwest::multipart::Part::text("")); let mut form = reqwest::multipart::Form::new()
280 .text("model", self.model.clone())
281 .part("file", part);
282 if let Some(lang) = args["language"].as_str() {
283 if !lang.is_empty() {
284 form = form.text("language", lang.to_string());
285 }
286 }
287
288 let endpoint = format!(
289 "{}/audio/transcriptions",
290 self.base_url.trim_end_matches('/')
291 );
292 if let Err(why) = crate::tools::search_and_web::validate_public_url(&endpoint) {
293 return Ok(ToolResult::error(format!(
294 "Refused TRANSCRIBE_API_BASE ({}): {}",
295 why, endpoint
296 )));
297 }
298 let client = reqwest::Client::new();
299 let resp = client
300 .post(&endpoint)
301 .bearer_auth(&key)
302 .multipart(form)
303 .send()
304 .await?;
305 if !resp.status().is_success() {
306 let status = resp.status();
307 let body = resp.text().await.unwrap_or_default();
308 return Ok(ToolResult::error(format!(
309 "transcribe API error {}: {}",
310 status, body
311 )));
312 }
313 let value: serde_json::Value = resp.json().await?;
314 let text = value["text"].as_str().unwrap_or("").to_string();
315 Ok(ToolResult::ok(vec![Block::Text(text)]))
316 }
317}
318
/// Minimal decoder for the standard base64 alphabet (`A-Z`, `a-z`, `0-9`,
/// `+`, `/`).
mod base64_decode {
    /// Maps one alphabet byte to its 6-bit value, or `None` for any byte
    /// outside the alphabet.
    fn sextet(byte: u8) -> Option<u32> {
        let value = match byte {
            b'A'..=b'Z' => byte - b'A',
            b'a'..=b'z' => byte - b'a' + 26,
            b'0'..=b'9' => byte - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        Some(u32::from(value))
    }

    /// Decodes `s` into raw bytes.
    ///
    /// `=`, `\n` and `\r` are skipped wherever they occur; padding is not
    /// validated. Bits left over after the last complete byte are dropped.
    ///
    /// # Errors
    ///
    /// Returns `Err("invalid base64 char")` on the first byte that is neither
    /// skipped nor part of the alphabet.
    pub fn decode(s: &str) -> Result<Vec<u8>, &'static str> {
        let mut decoded = Vec::with_capacity(s.len() / 4 * 3);
        // `acc` collects sextets; `pending` counts the low bits of `acc` that
        // have not been emitted yet. Older bits simply shift out of the u32.
        let mut acc: u32 = 0;
        let mut pending: u32 = 0;
        let payload = s
            .bytes()
            .filter(|b| !matches!(b, b'=' | b'\n' | b'\r'));
        for byte in payload {
            let six = sextet(byte).ok_or("invalid base64 char")?;
            acc = (acc << 6) | six;
            pending += 6;
            if pending >= 8 {
                pending -= 8;
                // Truncation to u8 keeps exactly the next output byte.
                decoded.push((acc >> pending) as u8);
            }
        }
        Ok(decoded)
    }
}