Skip to main content

sparrow_providers/
responses.rs

1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::Client;
4use serde_json::json;
5
6use super::{Brain, BrainEvent, BrainRequest, BrainStream, LatencyClass, ModelCaps};
7
8// ─── OpenAI Responses API adapter ───────────────────────────────────────────────
9
10/// OpenAI's newer Responses API (vs Chat Completions).
11/// Uses the /v1/responses endpoint with its own message format.
12pub struct OpenAIResponsesAdapter {
13    model: String,
14    api_key: String,
15    base_url: String,
16    client: Client,
17    caps: ModelCaps,
18}
19
20impl OpenAIResponsesAdapter {
21    pub fn new(model: &str, api_key: impl Into<String>, base_url: Option<&str>) -> Self {
22        let model = model.to_string();
23        Self {
24            model,
25            api_key: api_key.into(),
26            base_url: base_url.unwrap_or("https://api.openai.com/v1").to_string(),
27            client: Client::new(),
28            caps: ModelCaps {
29                context_window: 128_000,
30                max_output: 16_000,
31                tools: true,
32                vision: true,
33                cost_input_per_mtok: 2.5,
34                cost_output_per_mtok: 10.0,
35                latency: LatencyClass::Medium,
36            },
37        }
38    }
39
40    /// Override capabilities with those from the static registry (cost, context, etc.)
41    pub fn with_caps(mut self, caps: ModelCaps) -> Self {
42        self.caps = caps;
43        self
44    }
45}
46
47fn build_responses_body(model: &str, req: &BrainRequest) -> serde_json::Value {
48    // Convert to Responses API format
49    let mut input: Vec<serde_json::Value> = Vec::new();
50
51    if let Some(sys) = &req.system {
52        input.push(json!({
53            "role": "system",
54            "content": sys,
55        }));
56    }
57
58    for msg in &req.messages {
59        let mut reasoning_buf = String::new();
60        let mut content_blocks: Vec<serde_json::Value> = Vec::new();
61        for block in &msg.content {
62            match block {
63                super::ContentBlock::Text { text } => {
64                    if msg.role == "assistant" {
65                        content_blocks.push(json!({
66                            "type": "output_text",
67                            "text": text,
68                        }));
69                    } else {
70                        content_blocks.push(json!({
71                            "type": "input_text",
72                            "text": text,
73                        }));
74                    }
75                }
76                super::ContentBlock::Image { source } => {
77                    content_blocks.push(json!({
78                        "type": "input_image",
79                        "image_url": image_source_url(source),
80                    }));
81                }
82                super::ContentBlock::Reasoning { text } => {
83                    if !reasoning_buf.is_empty() {
84                        reasoning_buf.push('\n');
85                    }
86                    reasoning_buf.push_str(text);
87                }
88                _ => {}
89            }
90        }
91        let content = if content_blocks.len() == 1
92            && content_blocks[0]["type"].as_str() == Some("input_text")
93        {
94            content_blocks[0]["text"].clone()
95        } else if content_blocks.len() == 1
96            && content_blocks[0]["type"].as_str() == Some("output_text")
97        {
98            content_blocks[0]["text"].clone()
99        } else {
100            json!(content_blocks)
101        };
102
103        let mut item = json!({
104            "role": msg.role,
105            "content": content,
106        });
107        if msg.role == "assistant" && !reasoning_buf.is_empty() {
108            item["reasoning_content"] = json!(reasoning_buf);
109        }
110        input.push(item);
111    }
112
113    let mut body = json!({
114        "model": model,
115        "input": input,
116        "stream": true,
117        "temperature": req.temperature,
118        "max_output_tokens": req.max_tokens,
119    });
120
121    if req.cache.enabled {
122        if let Some(key) = &req.cache.key {
123            body["prompt_cache_key"] = json!(key);
124        }
125        body["prompt_cache_retention"] = json!(req.cache.ttl.openai_retention());
126    }
127
128    body
129}
130
131fn image_source_url(source: &super::ImageSource) -> String {
132    match source {
133        super::ImageSource::Base64 { media_type, data } => {
134            format!("data:{};base64,{}", media_type, data)
135        }
136        super::ImageSource::Url { url } => url.clone(),
137    }
138}
139
140fn push_responses_events(val: &serde_json::Value, events: &mut Vec<BrainEvent>) {
141    let event_type = val["type"].as_str().unwrap_or("");
142    if let Some(delta) = val["delta"].as_str() {
143        if event_type.contains("reasoning") || event_type.contains("thinking") {
144            events.push(BrainEvent::ReasoningDelta(delta.to_string()));
145        } else {
146            events.push(BrainEvent::TextDelta(delta.to_string()));
147        }
148    }
149
150    for key in [
151        "reasoning_content",
152        "reasoning",
153        "thinking",
154        "reasoning_summary_text",
155    ] {
156        if let Some(text) = val.get(key).and_then(|v| v.as_str()) {
157            if !text.is_empty() {
158                events.push(BrainEvent::ReasoningDelta(text.to_string()));
159            }
160        }
161    }
162
163    if let Some(response) = val.get("response") {
164        collect_nested_reasoning(response, events);
165    }
166    if event_type == "response.completed" {
167        events.push(BrainEvent::Done(sparrow_core::event::StopReason::EndTurn));
168    }
169}
170
171fn collect_nested_reasoning(value: &serde_json::Value, events: &mut Vec<BrainEvent>) {
172    match value {
173        serde_json::Value::Array(items) => {
174            for item in items {
175                collect_nested_reasoning(item, events);
176            }
177        }
178        serde_json::Value::Object(map) => {
179            for (key, value) in map {
180                if matches!(
181                    key.as_str(),
182                    "reasoning_content" | "reasoning" | "thinking" | "reasoning_summary_text"
183                ) {
184                    if let Some(text) = value.as_str() {
185                        if !text.is_empty() {
186                            events.push(BrainEvent::ReasoningDelta(text.to_string()));
187                        }
188                    }
189                }
190                collect_nested_reasoning(value, events);
191            }
192        }
193        _ => {}
194    }
195}
196
197#[async_trait]
198impl Brain for OpenAIResponsesAdapter {
199    fn id(&self) -> &str {
200        &self.model
201    }
202
203    fn caps(&self) -> ModelCaps {
204        self.caps.clone()
205    }
206
207    async fn complete(&self, req: BrainRequest) -> anyhow::Result<BrainStream> {
208        let body = build_responses_body(&self.model, &req);
209
210        let response = self
211            .client
212            .post(format!("{}/responses", self.base_url))
213            .header("Authorization", format!("Bearer {}", self.api_key))
214            .json(&body)
215            .send()
216            .await?;
217
218        if !response.status().is_success() {
219            let status = response.status().as_u16();
220            let body = response.text().await.unwrap_or_default();
221            return Err(anyhow::anyhow!(
222                "OpenAI Responses API error {}: {}",
223                status,
224                body
225            ));
226        }
227
228        let stream = response.bytes_stream();
229        // SSE frames split across TCP chunks must be reassembled — see
230        // provider/sse_buffer.rs. Without this, words/characters mid-stream
231        // silently disappear (the original symptom: streamed text mangling).
232        let event_stream = futures::stream::unfold(
233            (stream, false, super::sse_buffer::LineBuffer::new()),
234            |(mut stream, done, mut buf)| async move {
235                if done {
236                    return None;
237                }
238                match futures::StreamExt::next(&mut stream).await {
239                    Some(Ok(bytes)) => {
240                        let mut events = Vec::new();
241                        for line in buf.push(&bytes) {
242                            let line = line.trim();
243                            if line.is_empty() || !line.starts_with("data: ") {
244                                continue;
245                            }
246                            let data = &line[6..];
247                            if data == "[DONE]" {
248                                events.push(BrainEvent::Done(
249                                    sparrow_core::event::StopReason::EndTurn,
250                                ));
251                                continue;
252                            }
253                            if let Ok(val) = serde_json::from_str::<serde_json::Value>(data) {
254                                push_responses_events(&val, &mut events);
255                            }
256                        }
257                        Some((futures::stream::iter(events), (stream, false, buf)))
258                    }
259                    Some(Err(e)) => Some((
260                        futures::stream::iter(vec![BrainEvent::Error(format!(
261                            "stream error: {}",
262                            e
263                        ))]),
264                        (stream, true, buf),
265                    )),
266                    None => None,
267                }
268            },
269        )
270        .flatten();
271
272        Ok(Box::pin(event_stream))
273    }
274}
275
276// ─── AWS Bedrock adapter (Converse API) ────────────────────────────────────────
277
278pub struct BedrockAdapter {
279    model_id: String,
280    // Kept for future SigV4 implementation; suppressed from dead-code warnings.
281    #[allow(dead_code)]
282    region: String,
283    #[allow(dead_code)]
284    access_key: String,
285    #[allow(dead_code)]
286    secret_key: String,
287    #[allow(dead_code)]
288    client: Client,
289    caps: ModelCaps,
290}
291
292impl BedrockAdapter {
293    pub fn new(
294        model_id: &str,
295        region: &str,
296        access_key: impl Into<String>,
297        secret_key: impl Into<String>,
298    ) -> Self {
299        let model_id = model_id.to_string();
300        Self {
301            model_id,
302            region: region.to_string(),
303            access_key: access_key.into(),
304            secret_key: secret_key.into(),
305            client: Client::new(),
306            caps: ModelCaps {
307                context_window: 200_000,
308                max_output: 8_000,
309                tools: true,
310                vision: true,
311                cost_input_per_mtok: 3.0,
312                cost_output_per_mtok: 15.0,
313                latency: LatencyClass::Medium,
314            },
315        }
316    }
317}
318
319#[async_trait]
320impl Brain for BedrockAdapter {
321    fn id(&self) -> &str {
322        &self.model_id
323    }
324
325    fn caps(&self) -> ModelCaps {
326        self.caps.clone()
327    }
328
329    async fn complete(&self, _req: BrainRequest) -> anyhow::Result<BrainStream> {
330        // Bedrock Converse requires AWS SigV4 request signing and parses an
331        // EventStream binary frame format on the response. Neither is implemented
332        // here — the previous code sent fake `X-Amz-Access-Key` / `X-Amz-Secret-Key`
333        // headers (which Bedrock ignores) and wrapped raw response bytes in a
334        // single `TextDelta`, producing garbled output even when authentication
335        // somehow succeeded. Rather than ship a stub that silently fails, we
336        // surface a clear error so callers can route around it.
337        anyhow::bail!(
338            "Bedrock provider is not implemented (model={}). \
339             AWS SigV4 signing + Bedrock EventStream parsing are missing. \
340             Use anthropic:* or openai:* directly, or pin a different provider in your config.",
341            self.model_id
342        )
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ContentBlock, Msg, PromptCacheConfig, PromptCacheTtl};

    /// Builds an otherwise-default request holding a single message.
    fn single_message_request(role: &str, content: Vec<ContentBlock>) -> BrainRequest {
        BrainRequest {
            messages: vec![Msg {
                role: role.into(),
                content,
            }],
            ..BrainRequest::default()
        }
    }

    /// Shorthand for a text content block.
    fn text_block(text: &str) -> ContentBlock {
        ContentBlock::Text { text: text.into() }
    }

    /// Enabled caching must add both the cache key and the retention policy.
    #[test]
    fn responses_body_adds_prompt_cache_controls() {
        let mut req = single_message_request("user", vec![text_block("dynamic task")]);
        req.system = Some("stable sparrow system".into());
        req.cache = PromptCacheConfig {
            enabled: true,
            ttl: PromptCacheTtl::OneHour,
            key: Some("sparrow-repo-abc".into()),
        };

        let body = build_responses_body("gpt-test", &req);

        assert_eq!(body["prompt_cache_key"], "sparrow-repo-abc");
        assert_eq!(body["prompt_cache_retention"], "in_memory");
    }

    /// Assistant reasoning is sent back as `reasoning_content`, while the lone
    /// visible text block collapses to a plain string.
    #[test]
    fn responses_body_reinjects_assistant_reasoning_content() {
        let req = single_message_request(
            "assistant",
            vec![
                ContentBlock::Reasoning {
                    text: "private reasoning state".into(),
                },
                text_block("visible answer"),
            ],
        );

        let body = build_responses_body("gpt-test", &req);
        let item = &body["input"][0];

        assert_eq!(item["content"], "visible answer");
        assert_eq!(item["reasoning_content"], "private reasoning state");
    }

    /// Mixed text + image content stays an array, with the image rendered as a
    /// `data:` URL.
    #[test]
    fn responses_body_serializes_image_blocks() {
        let image = ContentBlock::Image {
            source: crate::ImageSource::Base64 {
                media_type: "image/png".into(),
                data: "iVBORw0KGgo=".into(),
            },
        };
        let req = single_message_request("user", vec![text_block("describe this"), image]);

        let body = build_responses_body("gpt-test", &req);
        let content = &body["input"][0]["content"];

        assert_eq!(content[0]["type"], "input_text");
        assert_eq!(content[1]["type"], "input_image");
        assert_eq!(
            content[1]["image_url"],
            "data:image/png;base64,iVBORw0KGgo="
        );
    }

    /// A reasoning-summary delta must surface only as a reasoning event.
    #[test]
    fn responses_events_capture_reasoning_delta_without_visible_text() {
        let event = json!({
            "type": "response.reasoning_summary_text.delta",
            "delta": "reasoning chunk"
        });

        let mut events = Vec::new();
        push_responses_events(&event, &mut events);

        assert_eq!(events.len(), 1);
        assert!(matches!(
            &events[0],
            BrainEvent::ReasoningDelta(text) if text == "reasoning chunk"
        ));
    }
}