Skip to main content

sparrow/tools/
media.rs

1//! Multimodal tools: image generation + text-to-speech (§15).
2//!
3//! Both target OpenAI-compatible endpoints so any provider that exposes
4//! `/images/generations` or `/audio/speech` works. The user supplies the key
5//! (env var, resolved at call time). No fake success: a missing key or a
6//! non-2xx response returns a real error.
7
8use async_trait::async_trait;
9use serde_json::json;
10
11use super::{Tool, ToolCtx, ToolResult};
12use crate::event::{Block, RiskLevel};
13
/// Returns the first non-blank value among the given environment variables.
///
/// Names are tried in order, so callers list the most specific variable first
/// (e.g. `IMAGE_API_KEY` before `OPENAI_API_KEY`). Unset, non-UTF-8, empty and
/// whitespace-only values are skipped.
///
/// The returned value is trimmed. A stray space or newline (common when keys are
/// pasted into `.env` files) would otherwise end up in the `Authorization`
/// header, and a trailing newline makes that header value invalid.
fn resolve_key(env_names: &[&str]) -> Option<String> {
    env_names
        .iter()
        .filter_map(|name| std::env::var(name).ok())
        .map(|value| value.trim().to_string())
        .find(|value| !value.is_empty())
}
24
25// ─── Image generation ─────────────────────────────────────────────────────────
26
/// Generate an image from a prompt via an OpenAI-compatible images endpoint.
///
/// Configuration is read from the environment once, at construction:
/// - `IMAGE_API_BASE` — API base URL (default `https://api.openai.com/v1`)
/// - `IMAGE_MODEL` — model name (default `gpt-image-1`)
///
/// Unset, empty or whitespace-only variables fall back to the default. The API
/// key is not captured here; it is looked up on every call.
pub struct ImageGen {
    base_url: String,
    model: String,
}

impl ImageGen {
    /// Builds the tool from `IMAGE_API_BASE` / `IMAGE_MODEL`, using the OpenAI
    /// defaults for any variable that is unset or blank.
    pub fn new() -> Self {
        // Blank values count as unset. Otherwise `IMAGE_MODEL=` in a .env file
        // would send an empty model name, and a blank base would produce a
        // host-less endpoint.
        let env_or = |name: &str, default: &str| {
            std::env::var(name)
                .ok()
                .map(|v| v.trim().to_string())
                .filter(|v| !v.is_empty())
                .unwrap_or_else(|| default.to_string())
        };
        Self {
            base_url: env_or("IMAGE_API_BASE", "https://api.openai.com/v1"),
            model: env_or("IMAGE_MODEL", "gpt-image-1"),
        }
    }
}
42
43#[async_trait]
44impl Tool for ImageGen {
45    fn name(&self) -> &str {
46        "image_generate"
47    }
48    fn description(&self) -> &str {
49        "Generate an image from a text prompt. Saves a PNG into the workspace and returns its path."
50    }
51    fn schema(&self) -> serde_json::Value {
52        json!({
53            "type": "object",
54            "properties": {
55                "prompt": { "type": "string", "description": "Image description" },
56                "filename": { "type": "string", "description": "Output filename (default: generated.png)" },
57                "size": { "type": "string", "description": "e.g. 1024x1024" }
58            },
59            "required": ["prompt"]
60        })
61    }
62    fn risk(&self) -> RiskLevel {
63        RiskLevel::Network
64    }
65    async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
66        let Some(key) = resolve_key(&["IMAGE_API_KEY", "OPENAI_API_KEY"]) else {
67            return Ok(ToolResult::error(
68                "No image API key. Set IMAGE_API_KEY or OPENAI_API_KEY.",
69            ));
70        };
71        let prompt = args["prompt"].as_str().unwrap_or("");
72        let size = args["size"].as_str().unwrap_or("1024x1024");
73        let filename = args["filename"].as_str().unwrap_or("generated.png");
74
75        let endpoint = format!("{}/images/generations", self.base_url.trim_end_matches('/'));
76        if let Err(why) = crate::tools::search_and_web::validate_public_url(&endpoint) {
77            return Ok(ToolResult::error(format!(
78                "Refused IMAGE_API_BASE ({}): {}",
79                why, endpoint
80            )));
81        }
82        let client = reqwest::Client::new();
83        let resp = client
84            .post(&endpoint)
85            .bearer_auth(&key)
86            .json(&json!({
87                "model": self.model,
88                "prompt": prompt,
89                "size": size,
90                "n": 1,
91                "response_format": "b64_json"
92            }))
93            .send()
94            .await?;
95        if !resp.status().is_success() {
96            let status = resp.status();
97            let body = resp.text().await.unwrap_or_default();
98            return Ok(ToolResult::error(format!(
99                "image API error {}: {}",
100                status, body
101            )));
102        }
103        let value: serde_json::Value = resp.json().await?;
104        let b64 = value["data"][0]["b64_json"].as_str();
105        let url = value["data"][0]["url"].as_str();
106
107        if let Some(b64) = b64 {
108            let bytes = base64_decode::decode(b64)
109                .map_err(|e| anyhow::anyhow!("invalid base64 image: {}", e))?;
110            let path = super::resolve_workspace_path(&ctx.workspace_root, filename)?;
111            std::fs::write(&path, &bytes)?;
112            Ok(ToolResult::ok(vec![Block::Text(format!(
113                "image saved to {} ({} bytes)",
114                path.display(),
115                bytes.len()
116            ))]))
117        } else if let Some(url) = url {
118            Ok(ToolResult::ok(vec![Block::Text(format!(
119                "image generated: {}",
120                url
121            ))]))
122        } else {
123            Ok(ToolResult::error("image API returned no data"))
124        }
125    }
126}
127
128// ─── Text to speech ─────────────────────────────────────────────────────────
129
/// Text-to-speech via an OpenAI-compatible `/audio/speech` endpoint.
///
/// Configuration is read from the environment once, at construction:
/// - `TTS_API_BASE` — API base URL (default `https://api.openai.com/v1`)
/// - `TTS_MODEL` — model name (default `gpt-4o-mini-tts`)
///
/// Unset, empty or whitespace-only variables fall back to the default. The API
/// key is not captured here; it is looked up on every call.
pub struct Tts {
    base_url: String,
    model: String,
}

impl Tts {
    /// Builds the tool from `TTS_API_BASE` / `TTS_MODEL`, using the OpenAI
    /// defaults for any variable that is unset or blank.
    pub fn new() -> Self {
        // Blank values count as unset, so `TTS_MODEL=` doesn't send an empty
        // model name.
        let env_or = |name: &str, default: &str| {
            std::env::var(name)
                .ok()
                .map(|v| v.trim().to_string())
                .filter(|v| !v.is_empty())
                .unwrap_or_else(|| default.to_string())
        };
        Self {
            base_url: env_or("TTS_API_BASE", "https://api.openai.com/v1"),
            model: env_or("TTS_MODEL", "gpt-4o-mini-tts"),
        }
    }
}
144
145#[async_trait]
146impl Tool for Tts {
147    fn name(&self) -> &str {
148        "text_to_speech"
149    }
150    fn description(&self) -> &str {
151        "Synthesize speech from text via an OpenAI-compatible /audio/speech endpoint. Saves an audio file into the workspace."
152    }
153    fn schema(&self) -> serde_json::Value {
154        json!({
155            "type": "object",
156            "properties": {
157                "text": { "type": "string" },
158                "voice": { "type": "string", "description": "e.g. alloy" },
159                "filename": { "type": "string", "description": "default: speech.mp3" }
160            },
161            "required": ["text"]
162        })
163    }
164    fn risk(&self) -> RiskLevel {
165        RiskLevel::Network
166    }
167    async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
168        let Some(key) = resolve_key(&["TTS_API_KEY", "OPENAI_API_KEY"]) else {
169            return Ok(ToolResult::error(
170                "No TTS API key. Set TTS_API_KEY or OPENAI_API_KEY.",
171            ));
172        };
173        let text = args["text"].as_str().unwrap_or("");
174        let voice = args["voice"].as_str().unwrap_or("alloy");
175        let filename = args["filename"].as_str().unwrap_or("speech.mp3");
176
177        let endpoint = format!("{}/audio/speech", self.base_url.trim_end_matches('/'));
178        if let Err(why) = crate::tools::search_and_web::validate_public_url(&endpoint) {
179            return Ok(ToolResult::error(format!(
180                "Refused TTS_API_BASE ({}): {}",
181                why, endpoint
182            )));
183        }
184        let client = reqwest::Client::new();
185        let resp = client
186            .post(&endpoint)
187            .bearer_auth(&key)
188            .json(&json!({ "model": self.model, "input": text, "voice": voice }))
189            .send()
190            .await?;
191        if !resp.status().is_success() {
192            let status = resp.status();
193            let body = resp.text().await.unwrap_or_default();
194            return Ok(ToolResult::error(format!(
195                "tts API error {}: {}",
196                status, body
197            )));
198        }
199        let bytes = resp.bytes().await?;
200        let path = super::resolve_workspace_path(&ctx.workspace_root, filename)?;
201        std::fs::write(&path, &bytes)?;
202        Ok(ToolResult::ok(vec![Block::Text(format!(
203            "audio saved to {} ({} bytes)",
204            path.display(),
205            bytes.len()
206        ))]))
207    }
208}
209
210// ─── Speech to text (Transcribe) ────────────────────────────────────────────────
211//
212// Voice-mode building block: posts a workspace audio file to an OpenAI-compatible
213// `/audio/transcriptions` endpoint and returns the transcript as text. Missing
214// key or non-2xx response is an HONEST error — never a fake success.
215
/// Speech-to-text via an OpenAI-compatible `/audio/transcriptions` endpoint.
///
/// Configuration is read from the environment once, at construction:
/// - `TRANSCRIBE_API_BASE` — API base URL (default `https://api.openai.com/v1`)
/// - `TRANSCRIBE_MODEL` — model name (default `whisper-1`)
///
/// Unset, empty or whitespace-only variables fall back to the default. The API
/// key is not captured here; it is looked up on every call.
pub struct Transcribe {
    base_url: String,
    model: String,
}

impl Transcribe {
    /// Builds the tool from `TRANSCRIBE_API_BASE` / `TRANSCRIBE_MODEL`, using
    /// the OpenAI defaults for any variable that is unset or blank.
    pub fn new() -> Self {
        // Blank values count as unset, so `TRANSCRIBE_MODEL=` doesn't send an
        // empty model name.
        let env_or = |name: &str, default: &str| {
            std::env::var(name)
                .ok()
                .map(|v| v.trim().to_string())
                .filter(|v| !v.is_empty())
                .unwrap_or_else(|| default.to_string())
        };
        Self {
            base_url: env_or("TRANSCRIBE_API_BASE", "https://api.openai.com/v1"),
            model: env_or("TRANSCRIBE_MODEL", "whisper-1"),
        }
    }
}
230
231#[async_trait]
232impl Tool for Transcribe {
233    fn name(&self) -> &str {
234        "transcribe"
235    }
236    fn description(&self) -> &str {
237        "Transcribe an audio file in the workspace to text via an OpenAI-compatible /audio/transcriptions endpoint."
238    }
239    fn schema(&self) -> serde_json::Value {
240        json!({
241            "type": "object",
242            "properties": {
243                "path": { "type": "string", "description": "Workspace-relative path to the audio file" },
244                "language": { "type": "string", "description": "Optional ISO-639-1 language hint" }
245            },
246            "required": ["path"]
247        })
248    }
249    fn risk(&self) -> RiskLevel {
250        RiskLevel::Network
251    }
252    async fn call(&self, args: serde_json::Value, ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
253        let Some(key) = resolve_key(&["TRANSCRIBE_API_KEY", "OPENAI_API_KEY"]) else {
254            return Ok(ToolResult::error(
255                "No transcription API key. Set TRANSCRIBE_API_KEY or OPENAI_API_KEY.",
256            ));
257        };
258        let path = args["path"].as_str().unwrap_or("");
259        if path.is_empty() {
260            return Ok(ToolResult::error("transcribe: missing 'path' argument"));
261        }
262        let full = super::resolve_workspace_path(&ctx.workspace_root, path)?;
263        if !full.exists() {
264            return Ok(ToolResult::error(format!("audio file not found: {}", path)));
265        }
266        let bytes = std::fs::read(&full)?;
267        let filename = full
268            .file_name()
269            .map(|s| s.to_string_lossy().to_string())
270            .unwrap_or_else(|| "audio.bin".into());
271        let mime = mime_guess::from_path(&full)
272            .first_or_octet_stream()
273            .to_string();
274
275        let part = reqwest::multipart::Part::bytes(bytes)
276            .file_name(filename)
277            .mime_str(&mime)
278            .unwrap_or_else(|_| reqwest::multipart::Part::text("")); // mime parse rarely fails
279        let mut form = reqwest::multipart::Form::new()
280            .text("model", self.model.clone())
281            .part("file", part);
282        if let Some(lang) = args["language"].as_str() {
283            if !lang.is_empty() {
284                form = form.text("language", lang.to_string());
285            }
286        }
287
288        let endpoint = format!(
289            "{}/audio/transcriptions",
290            self.base_url.trim_end_matches('/')
291        );
292        if let Err(why) = crate::tools::search_and_web::validate_public_url(&endpoint) {
293            return Ok(ToolResult::error(format!(
294                "Refused TRANSCRIBE_API_BASE ({}): {}",
295                why, endpoint
296            )));
297        }
298        let client = reqwest::Client::new();
299        let resp = client
300            .post(&endpoint)
301            .bearer_auth(&key)
302            .multipart(form)
303            .send()
304            .await?;
305        if !resp.status().is_success() {
306            let status = resp.status();
307            let body = resp.text().await.unwrap_or_default();
308            return Ok(ToolResult::error(format!(
309                "transcribe API error {}: {}",
310                status, body
311            )));
312        }
313        let value: serde_json::Value = resp.json().await?;
314        let text = value["text"].as_str().unwrap_or("").to_string();
315        Ok(ToolResult::ok(vec![Block::Text(text)]))
316    }
317}
318
// Minimal base64 decoder (avoid adding a crate). Standard alphabet, no padding strictness.
mod base64_decode {
    /// Decodes standard-alphabet base64 (`A-Z a-z 0-9 + /`) into raw bytes.
    ///
    /// The decoder is lenient by design:
    /// - `=`, `\n` and `\r` are skipped wherever they appear.
    /// - Leftover bits that do not fill a whole byte at the end are dropped.
    ///
    /// # Errors
    ///
    /// Returns `"invalid base64 char"` at the first byte that is neither in the
    /// alphabet nor one of the skipped bytes.
    pub fn decode(s: &str) -> Result<Vec<u8>, &'static str> {
        let mut decoded = Vec::with_capacity(s.len() / 4 * 3);
        // `acc` collects 6-bit groups; `pending` counts the bits not yet emitted.
        let mut acc: u32 = 0;
        let mut pending: u32 = 0;
        for byte in s.bytes().filter(|&b| !matches!(b, b'=' | b'\n' | b'\r')) {
            let sextet = sextet_of(byte).ok_or("invalid base64 char")?;
            // High bits shifted out of `acc` are already emitted, so losing them is fine.
            acc = (acc << 6) | u32::from(sextet);
            pending += 6;
            if pending >= 8 {
                pending -= 8;
                decoded.push((acc >> pending) as u8);
            }
        }
        Ok(decoded)
    }

    /// Returns the 6-bit value of one base64 alphabet byte, or `None` if `c` is
    /// outside the alphabet.
    fn sextet_of(c: u8) -> Option<u8> {
        Some(match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        })
    }
}