Skip to main content

sparrow/provider/
ollama.rs

1use async_trait::async_trait;
2use futures::stream::{self, StreamExt};
3use reqwest::Client;
4use serde_json::json;
5
6use super::{Brain, BrainEvent, BrainRequest, BrainStream, ContentBlock, LatencyClass, ModelCaps};
7
8/// Native Ollama adapter using `/api/chat` with NDJSON streaming.
9/// Ollama does not use OpenAI-compatible tool format; it has its own.
10pub struct OllamaAdapter {
11    model: String,
12    base_url: String,
13    client: Client,
14    caps: ModelCaps,
15}
16
17impl OllamaAdapter {
18    pub fn new(model: &str, base_url: &str) -> Self {
19        Self {
20            model: model.to_string(),
21            base_url: base_url
22                .trim_end_matches("/v1")
23                .trim_end_matches('/')
24                .to_string(),
25            client: Client::new(),
26            caps: ModelCaps {
27                context_window: 32_768,
28                max_output: 8_000,
29                tools: true,
30                vision: false,
31                cost_input_per_mtok: 0.0,
32                cost_output_per_mtok: 0.0,
33                latency: LatencyClass::Medium,
34            },
35        }
36    }
37
38    pub fn with_caps(mut self, caps: ModelCaps) -> Self {
39        self.caps = caps;
40        self
41    }
42
43    async fn live_caps(&self) -> ModelCaps {
44        let mut caps = self.caps.clone();
45        let url = format!("{}/api/show", self.base_url);
46        let Ok(response) = self
47            .client
48            .post(&url)
49            .json(&json!({ "model": self.model }))
50            .send()
51            .await
52        else {
53            return caps;
54        };
55        if !response.status().is_success() {
56            return caps;
57        }
58        let Ok(payload) = response.json::<serde_json::Value>().await else {
59            return caps;
60        };
61
62        if let Some(capabilities) = payload.get("capabilities").and_then(|v| v.as_array()) {
63            caps.tools = capabilities.iter().any(|cap| cap.as_str() == Some("tools"));
64        }
65        if let Some(ctx) = find_context_window(&payload) {
66            caps.context_window = ctx;
67            caps.max_output = (ctx / 8).clamp(4_096, 32_000);
68        }
69        caps
70    }
71
72    /// Convert Sparrow Msg into Ollama's native format
73    fn build_ollama_messages(req: &BrainRequest) -> Vec<serde_json::Value> {
74        let mut messages: Vec<serde_json::Value> = Vec::new();
75
76        if let Some(sys) = &req.system {
77            messages.push(json!({"role": "system", "content": sys}));
78        }
79
80        for msg in &req.messages {
81            let role = match msg.role.as_str() {
82                "assistant" => "assistant",
83                _ => "user",
84            };
85
86            let mut content = String::new();
87            let mut tool_calls: Vec<serde_json::Value> = Vec::new();
88
89            for block in &msg.content {
90                match block {
91                    ContentBlock::Text { text } => {
92                        content.push_str(text);
93                    }
94                    ContentBlock::ToolUse { id: _, name, input } => {
95                        tool_calls.push(json!({
96                            "function": {
97                                "name": name,
98                                "arguments": input,
99                            }
100                        }));
101                    }
102                    ContentBlock::ToolResult {
103                        tool_use_id,
104                        content: blocks,
105                        is_error: _,
106                    } => {
107                        let text: String = blocks
108                            .iter()
109                            .filter_map(|b| match b {
110                                ContentBlock::Text { text } => Some(text.as_str()),
111                                _ => None,
112                            })
113                            .collect::<Vec<_>>()
114                            .join("\n");
115                        // Ollama native: tool results are user messages with tool_call_id
116                        messages.push(json!({
117                            "role": "tool",
118                            "content": text,
119                            "tool_call_id": tool_use_id,
120                        }));
121                    }
122                    _ => {}
123                }
124            }
125
126            if !content.is_empty() || tool_calls.is_empty() {
127                let mut msg_json = json!({"role": role, "content": content});
128                if !tool_calls.is_empty() {
129                    msg_json["tool_calls"] = json!(tool_calls);
130                }
131                messages.push(msg_json);
132            }
133        }
134
135        messages
136    }
137
138    /// Convert Sparrow ToolSpec to Ollama tool format
139    fn build_ollama_tools(tools: &[super::ToolSpec]) -> Vec<serde_json::Value> {
140        if tools.is_empty() {
141            return vec![];
142        }
143        tools
144            .iter()
145            .map(|t| {
146                json!({
147                    "type": "function",
148                    "function": {
149                        "name": t.name,
150                        "description": t.description,
151                        "parameters": t.input_schema,
152                    }
153                })
154            })
155            .collect()
156    }
157}
158
159#[async_trait]
160impl Brain for OllamaAdapter {
161    fn id(&self) -> &str {
162        &self.model
163    }
164
165    fn caps(&self) -> ModelCaps {
166        self.caps.clone()
167    }
168
169    async fn complete(&self, req: BrainRequest) -> anyhow::Result<BrainStream> {
170        let caps = self.live_caps().await;
171        let messages = Self::build_ollama_messages(&req);
172        let tools = if caps.tools {
173            Self::build_ollama_tools(&req.tools)
174        } else {
175            Vec::new()
176        };
177
178        let mut body = json!({
179            "model": self.model,
180            "messages": messages,
181            "stream": true,
182            "options": {
183                "temperature": req.temperature as f64,
184            }
185        });
186
187        if req.max_tokens > 0 {
188            body["options"]["num_predict"] = json!(req.max_tokens);
189        }
190        if caps.context_window > 0 {
191            body["options"]["num_ctx"] = json!(caps.context_window);
192        }
193        if !tools.is_empty() {
194            body["tools"] = json!(tools);
195        }
196
197        let url = format!("{}/api/chat", self.base_url);
198
199        let response = self.client.post(&url).json(&body).send().await?;
200
201        if !response.status().is_success() {
202            let status = response.status().as_u16();
203            let body = response.text().await.unwrap_or_default();
204            return Err(anyhow::anyhow!("Ollama API error {}: {}", status, body));
205        }
206
207        let stream = response.bytes_stream();
208
209        // NDJSON across chunk boundaries needs the same line buffer the SSE
210        // providers use (see provider/sse_buffer.rs) — without it a JSON object
211        // split between two TCP chunks gets dropped silently.
212        let event_stream = stream
213            .scan(super::sse_buffer::LineBuffer::new(), |line_buf, chunk| {
214                let events: Vec<BrainEvent> = match chunk {
215                    Ok(bytes) => {
216                        let lines = line_buf.push(&bytes);
217                        let mut parsed = Vec::new();
218                        for line in lines {
219                            let line = line.trim();
220                            if line.is_empty() {
221                                continue;
222                            }
223                            let event: serde_json::Value = match serde_json::from_str(line) {
224                                Ok(v) => v,
225                                Err(_) => continue,
226                            };
227
228                            // Ollama NDJSON: {"message":{"content":"..."}} or {"message":{"tool_calls":[...]}}
229                            if let Some(msg) = event.get("message") {
230                                // Text delta (Ollama streams full message each line, not deltas)
231                                if let Some(content) = msg.get("content").and_then(|v| v.as_str()) {
232                                    if !content.is_empty() {
233                                        parsed.push(BrainEvent::TextDelta(content.to_string()));
234                                    }
235                                }
236                                // Tool calls
237                                if let Some(tc_array) =
238                                    msg.get("tool_calls").and_then(|v| v.as_array())
239                                {
240                                    for tc in tc_array {
241                                        if let Some(func) = tc.get("function") {
242                                            let name = func
243                                                .get("name")
244                                                .and_then(|v| v.as_str())
245                                                .unwrap_or("");
246                                            let args = func.get("arguments");
247                                            // Ollama sends tool_calls as objects; we emit start+end
248                                            let id = format!("tc_{}", name);
249                                            parsed.push(BrainEvent::ToolUseStart {
250                                                id: id.clone(),
251                                                name: name.to_string(),
252                                            });
253                                            if let Some(args) = args {
254                                                parsed.push(BrainEvent::ToolUseDelta {
255                                                    id: id.clone(),
256                                                    json: args.to_string(),
257                                                });
258                                            }
259                                            parsed.push(BrainEvent::ToolUseEnd { id });
260                                        }
261                                    }
262                                }
263                            }
264
265                            // Usage
266                            if let (Some(prompt), Some(completion)) = (
267                                event.get("prompt_eval_count").and_then(|v| v.as_u64()),
268                                event.get("eval_count").and_then(|v| v.as_u64()),
269                            ) {
270                                parsed.push(BrainEvent::Usage(crate::event::TokenUsage {
271                                    input: prompt,
272                                    output: completion,
273                                }));
274                            }
275
276                            // Done
277                            if event.get("done").and_then(|v| v.as_bool()).unwrap_or(false) {
278                                let reason = event
279                                    .get("done_reason")
280                                    .and_then(|v| v.as_str())
281                                    .unwrap_or("stop");
282                                let stop = match reason {
283                                    "stop" => crate::event::StopReason::EndTurn,
284                                    "length" => crate::event::StopReason::MaxTokens,
285                                    "tool_calls" => crate::event::StopReason::ToolUse,
286                                    s => crate::event::StopReason::StopSequence(s.to_string()),
287                                };
288                                parsed.push(BrainEvent::Done(stop));
289                            }
290                        }
291                        parsed
292                    }
293                    Err(e) => vec![BrainEvent::Error(format!("Ollama stream error: {}", e))],
294                };
295                async move { Some(stream::iter(events)) }
296            })
297            .flatten();
298
299        Ok(Box::pin(event_stream))
300    }
301}
302
303fn find_context_window(value: &serde_json::Value) -> Option<u64> {
304    fn visit(value: &serde_json::Value, best: &mut Option<u64>) {
305        match value {
306            serde_json::Value::Object(map) => {
307                for (key, child) in map {
308                    let key = key.to_ascii_lowercase();
309                    if (key.ends_with("context_length")
310                        || key == "num_ctx"
311                        || key == "context_window")
312                        && child.as_u64().is_some()
313                    {
314                        *best = (*best).max(child.as_u64());
315                    }
316                    visit(child, best);
317                }
318            }
319            serde_json::Value::Array(items) => {
320                for item in items {
321                    visit(item, best);
322                }
323            }
324            _ => {}
325        }
326    }
327    let mut best = None;
328    visit(value, &mut best);
329    best
330}