Skip to main content

sparrow/provider/
ollama.rs

1use async_trait::async_trait;
2use futures::stream::{self, StreamExt};
3use reqwest::Client;
4use serde_json::json;
5
6use super::{Brain, BrainEvent, BrainRequest, BrainStream, ContentBlock, LatencyClass, ModelCaps};
7
8/// Native Ollama adapter using `/api/chat` with NDJSON streaming.
9/// Ollama does not use OpenAI-compatible tool format; it has its own.
10pub struct OllamaAdapter {
11    model: String,
12    base_url: String,
13    client: Client,
14    caps: ModelCaps,
15}
16
17impl OllamaAdapter {
18    pub fn new(model: &str, base_url: &str) -> Self {
19        Self {
20            model: model.to_string(),
21            base_url: base_url
22                .trim_end_matches("/v1")
23                .trim_end_matches('/')
24                .to_string(),
25            client: Client::new(),
26            caps: ModelCaps {
27                context_window: 32_768,
28                max_output: 8_000,
29                tools: true,
30                vision: false,
31                cost_input_per_mtok: 0.0,
32                cost_output_per_mtok: 0.0,
33                latency: LatencyClass::Medium,
34            },
35        }
36    }
37
38    pub fn with_caps(mut self, caps: ModelCaps) -> Self {
39        self.caps = caps;
40        self
41    }
42
43    /// Convert Sparrow Msg into Ollama's native format
44    fn build_ollama_messages(req: &BrainRequest) -> Vec<serde_json::Value> {
45        let mut messages: Vec<serde_json::Value> = Vec::new();
46
47        if let Some(sys) = &req.system {
48            messages.push(json!({"role": "system", "content": sys}));
49        }
50
51        for msg in &req.messages {
52            let role = match msg.role.as_str() {
53                "assistant" => "assistant",
54                _ => "user",
55            };
56
57            let mut content = String::new();
58            let mut tool_calls: Vec<serde_json::Value> = Vec::new();
59
60            for block in &msg.content {
61                match block {
62                    ContentBlock::Text { text } => {
63                        content.push_str(text);
64                    }
65                    ContentBlock::ToolUse { id: _, name, input } => {
66                        tool_calls.push(json!({
67                            "function": {
68                                "name": name,
69                                "arguments": input,
70                            }
71                        }));
72                    }
73                    ContentBlock::ToolResult {
74                        tool_use_id,
75                        content: blocks,
76                        is_error: _,
77                    } => {
78                        let text: String = blocks
79                            .iter()
80                            .filter_map(|b| match b {
81                                ContentBlock::Text { text } => Some(text.as_str()),
82                                _ => None,
83                            })
84                            .collect::<Vec<_>>()
85                            .join("\n");
86                        // Ollama native: tool results are user messages with tool_call_id
87                        messages.push(json!({
88                            "role": "tool",
89                            "content": text,
90                            "tool_call_id": tool_use_id,
91                        }));
92                    }
93                    _ => {}
94                }
95            }
96
97            if !content.is_empty() || tool_calls.is_empty() {
98                let mut msg_json = json!({"role": role, "content": content});
99                if !tool_calls.is_empty() {
100                    msg_json["tool_calls"] = json!(tool_calls);
101                }
102                messages.push(msg_json);
103            }
104        }
105
106        messages
107    }
108
109    /// Convert Sparrow ToolSpec to Ollama tool format
110    fn build_ollama_tools(tools: &[super::ToolSpec]) -> Vec<serde_json::Value> {
111        if tools.is_empty() {
112            return vec![];
113        }
114        tools
115            .iter()
116            .map(|t| {
117                json!({
118                    "type": "function",
119                    "function": {
120                        "name": t.name,
121                        "description": t.description,
122                        "parameters": t.input_schema,
123                    }
124                })
125            })
126            .collect()
127    }
128}
129
130#[async_trait]
131impl Brain for OllamaAdapter {
132    fn id(&self) -> &str {
133        &self.model
134    }
135
136    fn caps(&self) -> ModelCaps {
137        self.caps.clone()
138    }
139
140    async fn complete(&self, req: BrainRequest) -> anyhow::Result<BrainStream> {
141        let messages = Self::build_ollama_messages(&req);
142        let tools = Self::build_ollama_tools(&req.tools);
143
144        let mut body = json!({
145            "model": self.model,
146            "messages": messages,
147            "stream": true,
148            "options": {
149                "temperature": req.temperature as f64,
150            }
151        });
152
153        if req.max_tokens > 0 {
154            body["options"]["num_predict"] = json!(req.max_tokens);
155        }
156        if !tools.is_empty() {
157            body["tools"] = json!(tools);
158        }
159
160        let url = format!("{}/api/chat", self.base_url);
161
162        let response = self.client.post(&url).json(&body).send().await?;
163
164        if !response.status().is_success() {
165            let status = response.status().as_u16();
166            let body = response.text().await.unwrap_or_default();
167            return Err(anyhow::anyhow!("Ollama API error {}: {}", status, body));
168        }
169
170        let stream = response.bytes_stream();
171
172        // NDJSON across chunk boundaries needs the same line buffer the SSE
173        // providers use (see provider/sse_buffer.rs) — without it a JSON object
174        // split between two TCP chunks gets dropped silently.
175        let event_stream = stream
176            .scan(super::sse_buffer::LineBuffer::new(), |line_buf, chunk| {
177                let events: Vec<BrainEvent> = match chunk {
178                    Ok(bytes) => {
179                        let lines = line_buf.push(&bytes);
180                        let mut parsed = Vec::new();
181                        for line in lines {
182                            let line = line.trim();
183                            if line.is_empty() {
184                                continue;
185                            }
186                            let event: serde_json::Value = match serde_json::from_str(line) {
187                                Ok(v) => v,
188                                Err(_) => continue,
189                            };
190
191                            // Ollama NDJSON: {"message":{"content":"..."}} or {"message":{"tool_calls":[...]}}
192                            if let Some(msg) = event.get("message") {
193                                // Text delta (Ollama streams full message each line, not deltas)
194                                if let Some(content) = msg.get("content").and_then(|v| v.as_str()) {
195                                    if !content.is_empty() {
196                                        parsed.push(BrainEvent::TextDelta(content.to_string()));
197                                    }
198                                }
199                                // Tool calls
200                                if let Some(tc_array) =
201                                    msg.get("tool_calls").and_then(|v| v.as_array())
202                                {
203                                    for tc in tc_array {
204                                        if let Some(func) = tc.get("function") {
205                                            let name = func
206                                                .get("name")
207                                                .and_then(|v| v.as_str())
208                                                .unwrap_or("");
209                                            let args = func.get("arguments");
210                                            // Ollama sends tool_calls as objects; we emit start+end
211                                            let id = format!("tc_{}", name);
212                                            parsed.push(BrainEvent::ToolUseStart {
213                                                id: id.clone(),
214                                                name: name.to_string(),
215                                            });
216                                            if let Some(args) = args {
217                                                parsed.push(BrainEvent::ToolUseDelta {
218                                                    id: id.clone(),
219                                                    json: args.to_string(),
220                                                });
221                                            }
222                                            parsed.push(BrainEvent::ToolUseEnd { id });
223                                        }
224                                    }
225                                }
226                            }
227
228                            // Usage
229                            if let (Some(prompt), Some(completion)) = (
230                                event.get("prompt_eval_count").and_then(|v| v.as_u64()),
231                                event.get("eval_count").and_then(|v| v.as_u64()),
232                            ) {
233                                parsed.push(BrainEvent::Usage(crate::event::TokenUsage {
234                                    input: prompt,
235                                    output: completion,
236                                }));
237                            }
238
239                            // Done
240                            if event.get("done").and_then(|v| v.as_bool()).unwrap_or(false) {
241                                let reason = event
242                                    .get("done_reason")
243                                    .and_then(|v| v.as_str())
244                                    .unwrap_or("stop");
245                                let stop = match reason {
246                                    "stop" => crate::event::StopReason::EndTurn,
247                                    "length" => crate::event::StopReason::MaxTokens,
248                                    "tool_calls" => crate::event::StopReason::ToolUse,
249                                    s => crate::event::StopReason::StopSequence(s.to_string()),
250                                };
251                                parsed.push(BrainEvent::Done(stop));
252                            }
253                        }
254                        parsed
255                    }
256                    Err(e) => vec![BrainEvent::Error(format!("Ollama stream error: {}", e))],
257                };
258                async move { Some(stream::iter(events)) }
259            })
260            .flatten();
261
262        Ok(Box::pin(event_stream))
263    }
264}