Skip to main content

sparrow/provider/
openai_compat.rs

1use async_trait::async_trait;
2use futures::stream::{self, StreamExt};
3use reqwest::Client;
4use serde_json::json;
5use std::collections::HashMap;
6
7use super::{Brain, BrainEvent, BrainRequest, BrainStream, ContentBlock, LatencyClass, ModelCaps};
8
9/// OpenAI-compatible adapter. Covers OpenAI, Groq, NVIDIA NIM, Together, Cerebras,
10/// OpenRouter, NovitaAI, Nous Portal, HuggingFace, Ollama, and custom endpoints.
11pub struct OpenAICompatAdapter {
12    model: String,
13    api_key: String,
14    base_url: String,
15    client: Client,
16    caps: ModelCaps,
17}
18
19impl OpenAICompatAdapter {
20    pub fn new(model: &str, api_key: impl Into<String>, base_url: &str) -> Self {
21        let model = model.to_string();
22        Self {
23            model,
24            api_key: api_key.into(),
25            base_url: base_url.to_string(),
26            client: Client::new(),
27            caps: ModelCaps::default(),
28        }
29    }
30
31    pub fn with_caps(mut self, caps: ModelCaps) -> Self {
32        self.caps = caps;
33        self
34    }
35
36    /// Create an Ollama adapter (OpenAI-compatible API on localhost)
37    pub fn ollama(model: &str, base_url: &str) -> Self {
38        // Ollama doesn't require an API key
39        Self::new(model, "ollama", base_url).with_caps(ModelCaps {
40            context_window: 32_768,
41            max_output: 8_000,
42            tools: true,
43            vision: false,
44            cost_input_per_mtok: 0.0,
45            cost_output_per_mtok: 0.0,
46            latency: LatencyClass::Medium,
47        })
48    }
49}
50
51fn build_chat_body(model: &str, req: &BrainRequest) -> serde_json::Value {
52    let mut messages: Vec<serde_json::Value> = Vec::new();
53
54    // Add system message
55    if let Some(sys) = &req.system {
56        messages.push(json!({
57            "role": "system",
58            "content": sys,
59        }));
60    }
61
62    // Convert messages
63    for msg in &req.messages {
64        if msg.role == "system" {
65            messages.push(json!({
66                "role": "system",
67                "content": msg.content.iter()
68                    .filter_map(|b| match b {
69                        ContentBlock::Text { text } => Some(text.clone()),
70                        _ => None,
71                    })
72                    .collect::<Vec<_>>()
73                    .join("\n"),
74            }));
75            continue;
76        }
77
78        let mut content: Vec<serde_json::Value> = Vec::new();
79        let mut tool_calls: Vec<serde_json::Value> = Vec::new();
80        let mut reasoning_buf = String::new();
81        let mut emitted_tool_result = false;
82
83        for block in &msg.content {
84            match block {
85                ContentBlock::Text { text } => {
86                    content.push(json!({"type": "text", "text": text}));
87                }
88                ContentBlock::Image { source } => {
89                    content.push(json!({
90                        "type": "image_url",
91                        "image_url": {
92                            "url": image_source_url(source),
93                        }
94                    }));
95                }
96                ContentBlock::Reasoning { text } => {
97                    // DeepSeek / Moonshot / Qwen "thinking mode" require the
98                    // model's previous reasoning_content to be echoed back
99                    // on the next turn or the API rejects with 400. We aggregate
100                    // all reasoning blocks of this message and ship them as a
101                    // single `reasoning_content` field.
102                    if !reasoning_buf.is_empty() {
103                        reasoning_buf.push('\n');
104                    }
105                    reasoning_buf.push_str(text);
106                }
107                ContentBlock::ToolUse { id, name, input } => {
108                    tool_calls.push(json!({
109                        "id": id,
110                        "type": "function",
111                        "function": {
112                            "name": name,
113                            "arguments": serde_json::to_string(input).unwrap_or_default(),
114                        }
115                    }));
116                }
117                ContentBlock::ToolResult {
118                    tool_use_id,
119                    content: tool_content,
120                    ..
121                } => {
122                    let text = tool_content
123                        .iter()
124                        .filter_map(|b| match b {
125                            ContentBlock::Text { text } => Some(text.clone()),
126                            _ => None,
127                        })
128                        .collect::<Vec<_>>()
129                        .join("\n");
130                    messages.push(json!({
131                        "role": "tool",
132                        "tool_call_id": tool_use_id,
133                        "content": text,
134                    }));
135                    emitted_tool_result = true;
136                    continue; // tool results are separate messages
137                }
138            }
139        }
140
141        if emitted_tool_result && content.is_empty() && tool_calls.is_empty() {
142            continue;
143        }
144
145        let mut msg_json = json!({ "role": msg.role });
146
147        if !tool_calls.is_empty() {
148            msg_json["tool_calls"] = json!(tool_calls);
149        }
150        if !content.is_empty() {
151            if content.len() == 1 && content[0]["type"] == "text" {
152                msg_json["content"] = json!(content[0]["text"]);
153            } else {
154                msg_json["content"] = json!(content);
155            }
156        }
157        if !reasoning_buf.is_empty() && msg.role == "assistant" {
158            msg_json["reasoning_content"] = json!(reasoning_buf);
159        }
160
161        messages.push(msg_json);
162    }
163
164    // Build tools
165    let tools: Vec<serde_json::Value> = req
166        .tools
167        .iter()
168        .map(|t| {
169            json!({
170                "type": "function",
171                "function": {
172                    "name": t.name,
173                    "description": t.description,
174                    "parameters": t.input_schema,
175                }
176            })
177        })
178        .collect();
179
180    let mut body = json!({
181        "model": model,
182        "messages": messages,
183        "stream": true,
184        "stream_options": {
185            "include_usage": true
186        },
187        "temperature": req.temperature,
188    });
189
190    if req.max_tokens > 0 {
191        body["max_tokens"] = json!(req.max_tokens);
192    }
193    if !tools.is_empty() {
194        body["tools"] = json!(tools);
195    }
196    if !req.stop.is_empty() {
197        body["stop"] = json!(req.stop);
198    }
199    if req.cache.enabled {
200        if let Some(key) = &req.cache.key {
201            body["prompt_cache_key"] = json!(key);
202        }
203        body["prompt_cache_retention"] = json!(req.cache.ttl.openai_retention());
204    }
205
206    body
207}
208
209fn image_source_url(source: &super::ImageSource) -> String {
210    match source {
211        super::ImageSource::Base64 { media_type, data } => {
212            format!("data:{};base64,{}", media_type, data)
213        }
214        super::ImageSource::Url { url } => url.clone(),
215    }
216}
217
218#[async_trait]
219impl Brain for OpenAICompatAdapter {
220    fn id(&self) -> &str {
221        &self.model
222    }
223
224    fn caps(&self) -> ModelCaps {
225        self.caps.clone()
226    }
227
228    async fn complete(&self, req: BrainRequest) -> anyhow::Result<BrainStream> {
229        let body = build_chat_body(&self.model, &req);
230
231        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
232
233        let response = self
234            .client
235            .post(&url)
236            .header("Authorization", format!("Bearer {}", self.api_key))
237            .json(&body)
238            .send()
239            .await?;
240
241        if !response.status().is_success() {
242            let status = response.status().as_u16();
243            let body = response.text().await.unwrap_or_default();
244            return Err(anyhow::anyhow!(
245                "OpenAI-compatible API error {}: {}",
246                status,
247                body
248            ));
249        }
250
251        #[derive(Default)]
252        struct ToolCallState {
253            id: String,
254            started: bool,
255        }
256
257        let stream = response.bytes_stream();
258
259        // SSE state: tool-call accumulator + line buffer that survives chunk
260        // boundaries. Without the buffer, a JSON event split across two TCP
261        // chunks was parsed in halves and silently dropped — producing the
262        // "à rebours" → "àours" mangling.
263        struct SseState {
264            tools: HashMap<u64, ToolCallState>,
265            lines: super::sse_buffer::LineBuffer,
266        }
267
268        let event_stream = stream
269            .scan(
270                SseState {
271                    tools: HashMap::new(),
272                    lines: super::sse_buffer::LineBuffer::new(),
273                },
274                |state, chunk| {
275                    let events: Vec<BrainEvent> = match chunk {
276                        Ok(bytes) => {
277                            let lines = state.lines.push(&bytes);
278                            let tool_state = &mut state.tools;
279                            let mut parsed = Vec::new();
280                            for line in lines {
281                                let line = line.trim();
282                                if line.is_empty() || !line.starts_with("data: ") {
283                                    continue;
284                                }
285                                let data = &line[6..];
286                                if data == "[DONE]" {
287                                    continue;
288                                }
289                                let event: serde_json::Value = match serde_json::from_str(data) {
290                                    Ok(v) => v,
291                                    Err(e) => {
292                                        tracing::debug!(
293                                            "JSON parse error: {} — data: {}",
294                                            e,
295                                            &data[..data.len().min(200)]
296                                        );
297                                        continue;
298                                    }
299                                };
300
301                                if let Some(choices) = event["choices"].as_array() {
302                                    for choice in choices {
303                                        if let Some(delta) = choice["delta"].as_object() {
304                                            if let Some(text) =
305                                                delta.get("content").and_then(|v| v.as_str())
306                                            {
307                                                if !text.is_empty() {
308                                                    parsed.push(BrainEvent::TextDelta(
309                                                        text.to_string(),
310                                                    ));
311                                                }
312                                            }
313                                            // DeepSeek / Moonshot thinking-mode emit
314                                            // reasoning trace alongside content. Capture
315                                            // it as a dedicated event so the engine can
316                                            // echo it back on the next turn (required
317                                            // by DeepSeek's contract).
318                                            // Several providers report this under
319                                            // different keys; check the known aliases.
320                                            for key in [
321                                                "reasoning_content",
322                                                "reasoning",
323                                                "thinking",
324                                                "thought",
325                                            ] {
326                                                if let Some(rtext) =
327                                                    delta.get(key).and_then(|v| v.as_str())
328                                                {
329                                                    if !rtext.is_empty() {
330                                                        parsed.push(BrainEvent::ReasoningDelta(
331                                                            rtext.to_string(),
332                                                        ));
333                                                    }
334                                                }
335                                            }
336                                        }
337                                        // Some providers (non-streaming chunk at end of
338                                        // turn) bundle the reasoning under
339                                        // `message.reasoning_content` rather than
340                                        // streaming it through `delta`. Cover that path
341                                        // too — duplicate captures are harmless because
342                                        // the engine joins them.
343                                        if let Some(msg_obj) =
344                                            choice.get("message").and_then(|v| v.as_object())
345                                        {
346                                            for key in
347                                                ["reasoning_content", "reasoning", "thinking"]
348                                            {
349                                                if let Some(rtext) =
350                                                    msg_obj.get(key).and_then(|v| v.as_str())
351                                                {
352                                                    if !rtext.is_empty() {
353                                                        parsed.push(BrainEvent::ReasoningDelta(
354                                                            rtext.to_string(),
355                                                        ));
356                                                    }
357                                                }
358                                            }
359                                        }
360                                        if let Some(delta) = choice["delta"].as_object() {
361                                            // (Re-open the original tool_calls block.)
362                                            let _ = delta; // keep this branch syntactically anchored
363                                            if let Some(tool_calls) =
364                                                delta.get("tool_calls").and_then(|v| v.as_array())
365                                            {
366                                                for tc in tool_calls {
367                                                    let idx = tc
368                                                        .get("index")
369                                                        .and_then(|v| v.as_u64())
370                                                        .unwrap_or(0);
371                                                    let id = tc
372                                                        .get("id")
373                                                        .and_then(|v| v.as_str())
374                                                        .map(|s| s.to_string());
375                                                    let state = tool_state.entry(idx).or_default();
376                                                    if let Some(id) = id {
377                                                        state.id = id;
378                                                    }
379                                                    if let Some(func) = tc
380                                                        .get("function")
381                                                        .and_then(|v| v.as_object())
382                                                    {
383                                                        if let Some(name) = func
384                                                            .get("name")
385                                                            .and_then(|v| v.as_str())
386                                                        {
387                                                            if !state.started {
388                                                                if state.id.is_empty() {
389                                                                    state.id = format!(
390                                                                        "tool-call-{}",
391                                                                        idx
392                                                                    );
393                                                                }
394                                                                state.started = true;
395                                                                parsed.push(
396                                                                    BrainEvent::ToolUseStart {
397                                                                        id: state.id.clone(),
398                                                                        name: name.to_string(),
399                                                                    },
400                                                                );
401                                                            }
402                                                        }
403                                                        if let Some(args) = func
404                                                            .get("arguments")
405                                                            .and_then(|v| v.as_str())
406                                                        {
407                                                            if !state.id.is_empty()
408                                                                && !args.is_empty()
409                                                            {
410                                                                parsed.push(
411                                                                    BrainEvent::ToolUseDelta {
412                                                                        id: state.id.clone(),
413                                                                        json: args.to_string(),
414                                                                    },
415                                                                );
416                                                            }
417                                                        }
418                                                    }
419                                                }
420                                            }
421                                        }
422
423                                        if let Some(reason) =
424                                            choice.get("finish_reason").and_then(|v| v.as_str())
425                                        {
426                                            if !reason.is_empty() && reason != "null" {
427                                                let stop = match reason {
428                                                    "stop" => crate::event::StopReason::EndTurn,
429                                                    "length" => crate::event::StopReason::MaxTokens,
430                                                    "tool_calls" => {
431                                                        for (_, state) in tool_state.drain() {
432                                                            if !state.id.is_empty() {
433                                                                parsed.push(
434                                                                    BrainEvent::ToolUseEnd {
435                                                                        id: state.id,
436                                                                    },
437                                                                );
438                                                            }
439                                                        }
440                                                        crate::event::StopReason::ToolUse
441                                                    }
442                                                    s => crate::event::StopReason::StopSequence(
443                                                        s.to_string(),
444                                                    ),
445                                                };
446                                                parsed.push(BrainEvent::Done(stop));
447                                            }
448                                        }
449                                    }
450                                }
451
452                                if let Some(usage) = event.get("usage").and_then(|u| u.as_object())
453                                {
454                                    // Use .get() — indexing a serde_json::Map with [] panics on a
455                                    // missing key, and some providers (e.g. MiniMax) omit fields.
456                                    parsed.push(BrainEvent::Usage(crate::event::TokenUsage {
457                                        input: usage
458                                            .get("prompt_tokens")
459                                            .and_then(|v| v.as_u64())
460                                            .unwrap_or(0),
461                                        output: usage
462                                            .get("completion_tokens")
463                                            .and_then(|v| v.as_u64())
464                                            .unwrap_or(0),
465                                    }));
466                                }
467                            }
468                            parsed
469                        }
470                        Err(e) => vec![BrainEvent::Error(format!("stream error: {}", e))],
471                    };
472                    futures::future::ready(Some(stream::iter(events)))
473                },
474            )
475            .flatten();
476
477        Ok(Box::pin(event_stream))
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{Msg, PromptCacheConfig, PromptCacheTtl};

    /// Builds a request holding a single message; every other field is default.
    fn single_message_request(role: &'static str, content: Vec<ContentBlock>) -> BrainRequest {
        BrainRequest {
            messages: vec![Msg {
                role: role.into(),
                content,
            }],
            ..BrainRequest::default()
        }
    }

    /// Enabling the prompt cache adds the cache key and retention fields.
    #[test]
    fn openai_chat_body_adds_prompt_cache_controls() {
        let mut request = single_message_request(
            "user",
            vec![ContentBlock::Text {
                text: "dynamic task".into(),
            }],
        );
        request.system = Some("stable sparrow system".into());
        request.cache = PromptCacheConfig {
            enabled: true,
            ttl: PromptCacheTtl::OneHour,
            key: Some("sparrow-repo-abc".into()),
        };

        let payload = build_chat_body("gpt-test", &request);

        assert_eq!(payload["prompt_cache_key"], "sparrow-repo-abc");
        assert_eq!(payload["prompt_cache_retention"], "in_memory");
    }

    /// A text + image message is sent as an array of typed content parts.
    #[test]
    fn openai_chat_body_serializes_image_blocks() {
        let question = ContentBlock::Text {
            text: "what is in this image?".into(),
        };
        let picture = ContentBlock::Image {
            source: crate::provider::ImageSource::Base64 {
                media_type: "image/png".into(),
                data: "iVBORw0KGgo=".into(),
            },
        };
        let request = single_message_request("user", vec![question, picture]);

        let payload = build_chat_body("gpt-test", &request);
        let parts = &payload["messages"][0]["content"];

        assert_eq!(parts[0]["type"], "text");
        assert_eq!(parts[1]["type"], "image_url");
        assert_eq!(
            parts[1]["image_url"]["url"],
            "data:image/png;base64,iVBORw0KGgo="
        );
    }

    /// Assistant reasoning blocks are echoed back as `reasoning_content`.
    #[test]
    fn openai_chat_body_reinjects_assistant_reasoning_content() {
        let request = single_message_request(
            "assistant",
            vec![
                ContentBlock::Reasoning {
                    text: "opaque provider reasoning".into(),
                },
                ContentBlock::Text {
                    text: "visible answer".into(),
                },
            ],
        );

        let payload = build_chat_body("deepseek-test", &request);
        let assistant = &payload["messages"][0];

        assert_eq!(assistant["content"], "visible answer");
        assert_eq!(assistant["reasoning_content"], "opaque provider reasoning");
    }
}