1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6 AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapseError,
7 TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9
10use crate::backend::{ProviderBackend, ProviderRequest, ProviderResponse};
11
/// Configuration for an Ollama-served chat model.
#[derive(Debug, Clone)]
pub struct OllamaConfig {
    /// Name of the model to run (e.g. "llama3").
    pub model: String,
    /// Base URL of the Ollama server; `/api/chat` is appended when
    /// building requests.
    pub base_url: String,
    /// Nucleus-sampling parameter, forwarded under the request `options`.
    pub top_p: Option<f64>,
    /// Stop sequences, forwarded under the request `options`.
    pub stop: Option<Vec<String>>,
    /// RNG seed, forwarded under the request `options`.
    pub seed: Option<u64>,
}
20
21impl OllamaConfig {
22 pub fn new(model: impl Into<String>) -> Self {
23 Self {
24 model: model.into(),
25 base_url: "http://localhost:11434".to_string(),
26 top_p: None,
27 stop: None,
28 seed: None,
29 }
30 }
31
32 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
33 self.base_url = url.into();
34 self
35 }
36
37 pub fn with_top_p(mut self, top_p: f64) -> Self {
38 self.top_p = Some(top_p);
39 self
40 }
41
42 pub fn with_stop(mut self, stop: Vec<String>) -> Self {
43 self.stop = Some(stop);
44 self
45 }
46
47 pub fn with_seed(mut self, seed: u64) -> Self {
48 self.seed = Some(seed);
49 self
50 }
51}
52
53pub struct OllamaChatModel {
54 config: OllamaConfig,
55 backend: Arc<dyn ProviderBackend>,
56}
57
58impl OllamaChatModel {
59 pub fn new(config: OllamaConfig, backend: Arc<dyn ProviderBackend>) -> Self {
60 Self { config, backend }
61 }
62
63 fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
64 let messages: Vec<Value> = request.messages.iter().map(message_to_ollama).collect();
65
66 let mut body = json!({
67 "model": self.config.model,
68 "messages": messages,
69 "stream": stream,
70 });
71
72 if !request.tools.is_empty() {
73 body["tools"] = json!(request
74 .tools
75 .iter()
76 .map(tool_def_to_ollama)
77 .collect::<Vec<_>>());
78 }
79 if let Some(ref choice) = request.tool_choice {
80 body["tool_choice"] = match choice {
81 ToolChoice::Auto => json!("auto"),
82 ToolChoice::Required => json!("required"),
83 ToolChoice::None => json!("none"),
84 ToolChoice::Specific(name) => json!({
85 "type": "function",
86 "function": {"name": name}
87 }),
88 };
89 }
90
91 {
92 let mut options = json!({});
93 let mut has_options = false;
94 if let Some(top_p) = self.config.top_p {
95 options["top_p"] = json!(top_p);
96 has_options = true;
97 }
98 if let Some(ref stop) = self.config.stop {
99 options["stop"] = json!(stop);
100 has_options = true;
101 }
102 if let Some(seed) = self.config.seed {
103 options["seed"] = json!(seed);
104 has_options = true;
105 }
106 if has_options {
107 body["options"] = options;
108 }
109 }
110
111 ProviderRequest {
112 url: format!("{}/api/chat", self.config.base_url),
113 headers: vec![("Content-Type".to_string(), "application/json".to_string())],
114 body,
115 }
116 }
117}
118
119fn message_to_ollama(msg: &Message) -> Value {
120 match msg {
121 Message::System { content, .. } => json!({
122 "role": "system",
123 "content": content,
124 }),
125 Message::Human { content, .. } => json!({
126 "role": "user",
127 "content": content,
128 }),
129 Message::AI {
130 content,
131 tool_calls,
132 ..
133 } => {
134 let mut obj = json!({
135 "role": "assistant",
136 "content": content,
137 });
138 if !tool_calls.is_empty() {
139 obj["tool_calls"] = json!(tool_calls
140 .iter()
141 .map(|tc| json!({
142 "function": {
143 "name": tc.name,
144 "arguments": tc.arguments,
145 }
146 }))
147 .collect::<Vec<_>>());
148 }
149 obj
150 }
151 Message::Tool {
152 content,
153 tool_call_id: _,
154 ..
155 } => json!({
156 "role": "tool",
157 "content": content,
158 }),
159 Message::Chat {
160 custom_role,
161 content,
162 ..
163 } => json!({
164 "role": custom_role,
165 "content": content,
166 }),
167 Message::Remove { .. } => json!(null), }
169}
170
171fn tool_def_to_ollama(def: &ToolDefinition) -> Value {
172 json!({
173 "type": "function",
174 "function": {
175 "name": def.name,
176 "description": def.description,
177 "parameters": def.parameters,
178 }
179 })
180}
181
182fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapseError> {
183 check_error_status(resp)?;
184
185 let message_val = &resp.body["message"];
186 let content = message_val["content"].as_str().unwrap_or("").to_string();
187 let tool_calls = parse_tool_calls(message_val);
188
189 let usage = parse_usage(&resp.body);
190
191 let message = if tool_calls.is_empty() {
192 Message::ai(content)
193 } else {
194 Message::ai_with_tool_calls(content, tool_calls)
195 };
196
197 Ok(ChatResponse { message, usage })
198}
199
200fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapseError> {
201 if resp.status >= 400 {
202 let msg = resp.body["error"]
203 .as_str()
204 .unwrap_or("unknown Ollama error")
205 .to_string();
206 return Err(SynapseError::Model(format!(
207 "Ollama API error ({}): {}",
208 resp.status, msg
209 )));
210 }
211 Ok(())
212}
213
214fn parse_tool_calls(message: &Value) -> Vec<ToolCall> {
215 message["tool_calls"]
216 .as_array()
217 .map(|arr| {
218 arr.iter()
219 .enumerate()
220 .filter_map(|(i, tc)| {
221 let name = tc["function"]["name"].as_str()?.to_string();
222 let arguments = tc["function"]["arguments"].clone();
223 Some(ToolCall {
224 id: format!("ollama-{i}"),
225 name,
226 arguments,
227 })
228 })
229 .collect()
230 })
231 .unwrap_or_default()
232}
233
234fn parse_usage(body: &Value) -> Option<TokenUsage> {
235 let prompt = body["prompt_eval_count"].as_u64();
236 let completion = body["eval_count"].as_u64();
237 match (prompt, completion) {
238 (Some(p), Some(c)) => Some(TokenUsage {
239 input_tokens: p as u32,
240 output_tokens: c as u32,
241 total_tokens: (p + c) as u32,
242 input_details: None,
243 output_details: None,
244 }),
245 _ => None,
246 }
247}
248
249fn parse_ndjson_chunk(line: &str) -> Option<AIMessageChunk> {
250 let v: Value = serde_json::from_str(line).ok()?;
251
252 let content = v["message"]["content"].as_str().unwrap_or("").to_string();
254 let tool_calls = parse_tool_calls(&v["message"]);
255 let done = v["done"].as_bool().unwrap_or(false);
256
257 let usage = if done { parse_usage(&v) } else { None };
258
259 Some(AIMessageChunk {
260 content,
261 tool_calls,
262 usage,
263 ..Default::default()
264 })
265}
266
267#[async_trait]
268impl ChatModel for OllamaChatModel {
269 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapseError> {
270 let provider_req = self.build_request(&request, false);
271 let resp = self.backend.send(provider_req).await?;
272 parse_response(&resp)
273 }
274
275 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
276 Box::pin(async_stream::stream! {
277 let provider_req = self.build_request(&request, true);
278 let byte_stream = self.backend.send_stream(provider_req).await;
279
280 let byte_stream = match byte_stream {
281 Ok(s) => s,
282 Err(e) => {
283 yield Err(e);
284 return;
285 }
286 };
287
288 use futures::StreamExt;
289
290 let mut buffer = String::new();
292 let mut byte_stream = std::pin::pin!(byte_stream);
293
294 while let Some(result) = byte_stream.next().await {
295 match result {
296 Ok(bytes) => {
297 buffer.push_str(&String::from_utf8_lossy(&bytes));
298 while let Some(pos) = buffer.find('\n') {
299 let line = buffer[..pos].trim().to_string();
300 buffer = buffer[pos + 1..].to_string();
301 if line.is_empty() {
302 continue;
303 }
304 if let Some(chunk) = parse_ndjson_chunk(&line) {
305 yield Ok(chunk);
306 }
307 }
308 }
309 Err(e) => {
310 yield Err(e);
311 break;
312 }
313 }
314 }
315
316 let remaining = buffer.trim().to_string();
318 if !remaining.is_empty() {
319 if let Some(chunk) = parse_ndjson_chunk(&remaining) {
320 yield Ok(chunk);
321 }
322 }
323 })
324 }
325}