use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::{Model, ModelConfig, StreamEventStream};
use crate::types::{
    content::{Message, Role, SystemContentBlock},
    errors::StrandsError,
    streaming::{
        ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockDeltaToolUse, ContentBlockStart,
        ContentBlockStartEvent, ContentBlockStartToolUse, ContentBlockStopEvent, MessageStartEvent,
        MessageStopEvent, MetadataEvent, Metrics, StopReason, StreamEvent, Usage,
    },
    tools::{ToolChoice, ToolSpec},
};

const DEFAULT_MODEL_ID: &str = "llama3";
const DEFAULT_HOST: &str = "http://localhost:11434";

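/// A [`Model`] implementation backed by an Ollama server's streaming
/// `/api/chat` endpoint.
///
/// A minimal usage sketch (assumes an Ollama server is reachable at the
/// configured host):
///
/// ```ignore
/// let model = OllamaModel::new("llama3.2").with_host("http://localhost:11434");
/// ```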
#[derive(Clone)]
pub struct OllamaModel {
    config: ModelConfig,
    host: String,
    client: Client,
}

// Manual Debug impl so the internal HTTP client is omitted from the output.
impl std::fmt::Debug for OllamaModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OllamaModel")
            .field("config", &self.config)
            .field("host", &self.host)
            .finish()
    }
}

#[derive(Debug, Serialize)]
struct OllamaRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<OllamaOptions>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OllamaTool>,
}

#[derive(Debug, Serialize)]
struct OllamaOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}

#[derive(Debug, Serialize)]
struct OllamaMessage {
    role: String,
    content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    images: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OllamaToolCall>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
struct OllamaToolCall {
    function: OllamaFunctionCall,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
struct OllamaFunctionCall {
    name: String,
    arguments: serde_json::Value,
}

#[derive(Debug, Serialize)]
struct OllamaTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OllamaFunctionDef,
}

#[derive(Debug, Serialize)]
struct OllamaFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

/// One NDJSON line of Ollama's streaming `/api/chat` response.
#[derive(Debug, Deserialize)]
struct OllamaStreamResponse {
    message: OllamaResponseMessage,
    done: bool,
    #[serde(default)]
    done_reason: Option<String>,
    #[serde(default)]
    eval_count: Option<u32>,
    #[serde(default)]
    prompt_eval_count: Option<u32>,
    /// Total request time in nanoseconds, reported on the final chunk.
    #[serde(default)]
    total_duration: Option<u64>,
}

#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
    #[serde(default)]
    content: String,
    #[serde(default)]
    tool_calls: Option<Vec<OllamaToolCall>>,
}

impl OllamaModel {
    /// Creates a model with the given id, targeting the default local host.
    pub fn new(model_id: impl Into<String>) -> Self {
        Self {
            config: ModelConfig::new(model_id),
            host: DEFAULT_HOST.to_string(),
            client: Client::new(),
        }
    }

    /// Overrides the Ollama server base URL (default `http://localhost:11434`).
    pub fn with_host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Replaces the entire model configuration.
    pub fn with_config(mut self, config: ModelConfig) -> Self {
        self.config = config;
        self
    }

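    /// Converts the agent's [`Message`] history into Ollama's chat format:
    /// an optional system prompt becomes a leading "system" message, tool
    /// results are emitted as separate "tool" messages, and tool-use blocks
    /// become `tool_calls` on the corresponding assistant turn.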
    fn format_messages(&self, messages: &[Message], system_prompt: Option<&str>) -> Vec<OllamaMessage> {
        let mut formatted = Vec::new();

        if let Some(prompt) = system_prompt {
            formatted.push(OllamaMessage {
                role: "system".to_string(),
                content: prompt.to_string(),
                images: None,
                tool_calls: None,
            });
        }

        for msg in messages {
            let role = match msg.role {
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            let mut text_content = String::new();
            let mut tool_calls = Vec::new();

            for block in &msg.content {
                if let Some(ref text) = block.text {
                    text_content.push_str(text);
                }

                if let Some(ref tu) = block.tool_use {
                    tool_calls.push(OllamaToolCall {
                        function: OllamaFunctionCall {
                            name: tu.name.clone(),
                            arguments: tu.input.clone(),
                        },
                    });
                }

                if let Some(ref tr) = block.tool_result {
                    // Ollama has no structured tool-result type, so the
                    // result's text parts are joined into a "tool" message.
                    let content = tr
                        .content
                        .iter()
                        .filter_map(|c| c.text.clone())
                        .collect::<Vec<_>>()
                        .join("\n");

                    formatted.push(OllamaMessage {
                        role: "tool".to_string(),
                        content,
                        images: None,
                        tool_calls: None,
                    });
                }
            }

            if !text_content.is_empty() || !tool_calls.is_empty() {
                formatted.push(OllamaMessage {
                    role: role.to_string(),
                    content: text_content,
                    images: None,
                    tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) },
                });
            }
        }

        formatted
    }

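    /// Maps the provider-agnostic [`ToolSpec`]s onto Ollama's function-tool
    /// definitions.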
    fn format_tools(&self, tool_specs: &[ToolSpec]) -> Vec<OllamaTool> {
        tool_specs
            .iter()
            .map(|spec| OllamaTool {
                tool_type: "function".to_string(),
                function: OllamaFunctionDef {
                    name: spec.name.clone(),
                    description: spec.description.clone(),
                    parameters: spec.input_schema.json.clone(),
                },
            })
            .collect()
    }
}

impl Default for OllamaModel {
    fn default() -> Self {
        Self::new(DEFAULT_MODEL_ID)
    }
}

#[async_trait]
impl Model for OllamaModel {
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    fn update_config(&mut self, config: ModelConfig) {
        self.config = config;
    }

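    /// Streams a chat completion from Ollama's `/api/chat` endpoint,
    /// translating its NDJSON chunks into the provider-agnostic
    /// [`StreamEvent`] sequence: `message_start`, text deltas on content
    /// block 0, one synthetic block per tool call (indices 1..), then
    /// `message_stop` and a final metadata event. The `_tool_choice` and
    /// `_system_prompt_content` parameters are currently ignored.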
    fn stream<'a>(
        &'a self,
        messages: &'a [Message],
        tool_specs: Option<&'a [ToolSpec]>,
        system_prompt: Option<&'a str>,
        _tool_choice: Option<ToolChoice>,
        _system_prompt_content: Option<&'a [SystemContentBlock]>,
    ) -> StreamEventStream<'a> {
        let url = format!("{}/api/chat", self.host);
        let client = self.client.clone();

        let options = OllamaOptions {
            num_predict: self.config.max_tokens,
            temperature: self.config.temperature,
            top_p: self.config.top_p,
        };

        let request = OllamaRequest {
            model: self.config.model_id.clone(),
            messages: self.format_messages(messages, system_prompt),
            stream: true,
            options: Some(options),
            tools: tool_specs.map(|s| self.format_tools(s)).unwrap_or_default(),
        };

        Box::pin(async_stream::stream! {
            let response = match client
                .post(&url)
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await
            {
                Ok(resp) => resp,
                Err(e) => {
                    yield Err(StrandsError::NetworkError(e.to_string()));
                    return;
                }
            };

            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();
                yield Err(StrandsError::model_error(format!("HTTP {status}: {body}")));
                return;
            }

            yield Ok(StreamEvent {
                message_start: Some(MessageStartEvent { role: Role::Assistant }),
                ..Default::default()
            });

            yield Ok(StreamEvent {
                content_block_start: Some(ContentBlockStartEvent {
                    content_block_index: Some(0),
                    start: None,
                }),
                ..Default::default()
            });

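            // Ollama streams newline-delimited JSON: every line is a complete
            // OllamaStreamResponse, and the final line carries done=true plus
            // token counts and timing for the whole request.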
            use futures::StreamExt;
            let mut byte_stream = response.bytes_stream();
            let mut tool_calls_found: Vec<OllamaToolCall> = Vec::new();
            let mut final_response: Option<OllamaStreamResponse> = None;
            // Buffer raw bytes across reads: a single NDJSON line (or a
            // multi-byte UTF-8 character) may be split over network chunks.
            let mut buf: Vec<u8> = Vec::new();

            'read: while let Some(chunk_result) = byte_stream.next().await {
                match chunk_result {
                    Ok(bytes) => buf.extend_from_slice(&bytes),
                    Err(e) => {
                        yield Err(StrandsError::NetworkError(e.to_string()));
                        return;
                    }
                }

                // Drain every complete line; anything after the last newline
                // stays buffered for the next chunk.
                while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
                    let line_bytes: Vec<u8> = buf.drain(..=pos).collect();
                    let line = String::from_utf8_lossy(&line_bytes);
                    let line = line.trim();
                    if line.is_empty() {
                        continue;
                    }

                    if let Ok(resp) = serde_json::from_str::<OllamaStreamResponse>(line) {
                        if !resp.message.content.is_empty() {
                            yield Ok(StreamEvent {
                                content_block_delta: Some(ContentBlockDeltaEvent {
                                    content_block_index: Some(0),
                                    delta: Some(ContentBlockDelta {
                                        text: Some(resp.message.content.clone()),
                                        ..Default::default()
                                    }),
                                }),
                                ..Default::default()
                            });
                        }

                        if let Some(ref tcs) = resp.message.tool_calls {
                            tool_calls_found.extend(tcs.clone());
                        }

                        if resp.done {
                            final_response = Some(resp);
                            // Stop reading entirely once the server reports
                            // completion instead of waiting for EOF.
                            break 'read;
                        }
                    }
                }
            }

            yield Ok(StreamEvent {
                content_block_stop: Some(ContentBlockStopEvent {
                    content_block_index: Some(0),
                }),
                ..Default::default()
            });

            let mut tool_index = 1u32;
            for tc in &tool_calls_found {
                yield Ok(StreamEvent {
                    content_block_start: Some(ContentBlockStartEvent {
                        content_block_index: Some(tool_index),
                        start: Some(ContentBlockStart {
                            tool_use: Some(ContentBlockStartToolUse {
                                name: tc.function.name.clone(),
                                // Ollama does not assign tool-call IDs, so the
                                // function name doubles as the identifier.
                                tool_use_id: tc.function.name.clone(),
                            }),
                        }),
                    }),
                    ..Default::default()
                });

                yield Ok(StreamEvent {
                    content_block_delta: Some(ContentBlockDeltaEvent {
                        content_block_index: Some(tool_index),
                        delta: Some(ContentBlockDelta {
                            tool_use: Some(ContentBlockDeltaToolUse {
                                input: serde_json::to_string(&tc.function.arguments).unwrap_or_default(),
                            }),
                            ..Default::default()
                        }),
                    }),
                    ..Default::default()
                });

                yield Ok(StreamEvent {
                    content_block_stop: Some(ContentBlockStopEvent {
                        content_block_index: Some(tool_index),
                    }),
                    ..Default::default()
                });

                tool_index += 1;
            }

            let stop_reason = if !tool_calls_found.is_empty() {
                StopReason::ToolUse
            } else if final_response
                .as_ref()
                .and_then(|r| r.done_reason.as_ref())
                .map(|s| s == "length")
                .unwrap_or(false)
            {
                StopReason::MaxTokens
            } else {
                StopReason::EndTurn
            };

            yield Ok(StreamEvent {
                message_stop: Some(MessageStopEvent {
                    stop_reason: Some(stop_reason),
                    additional_model_response_fields: None,
                }),
                ..Default::default()
            });

            if let Some(ref resp) = final_response {
                let input_tokens = resp.prompt_eval_count.unwrap_or(0);
                let output_tokens = resp.eval_count.unwrap_or(0);
                // total_duration is reported in nanoseconds.
                let latency_ms = resp.total_duration.map(|d| d / 1_000_000).unwrap_or(0);

                yield Ok(StreamEvent {
                    metadata: Some(MetadataEvent {
                        usage: Some(Usage {
                            input_tokens,
                            output_tokens,
                            total_tokens: input_tokens + output_tokens,
                            cache_read_input_tokens: 0,
                            cache_write_input_tokens: 0,
                        }),
                        metrics: Some(Metrics {
                            latency_ms,
                            // Not reported by Ollama.
                            time_to_first_byte_ms: 0,
                        }),
                        trace: None,
                    }),
                    ..Default::default()
                });
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ollama_model_creation() {
        let model = OllamaModel::new("llama3.2");
        assert_eq!(model.config().model_id, "llama3.2");
    }

    #[test]
    fn test_ollama_with_host() {
        let model = OllamaModel::new("llama3").with_host("http://192.168.1.100:11434");
        assert_eq!(model.host, "http://192.168.1.100:11434");
    }

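    // A small additional check (sketch): with_config should replace the
    // entire ModelConfig, including the model id.
    #[test]
    fn test_ollama_with_config() {
        let config = ModelConfig::new("mistral");
        let model = OllamaModel::new("llama3").with_config(config);
        assert_eq!(model.config().model_id, "mistral");
    }
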
    #[test]
    fn test_ollama_default() {
        let model = OllamaModel::default();
        assert_eq!(model.config().model_id, "llama3");
        assert_eq!(model.host, "http://localhost:11434");
    }
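
    // Sketches of the wire-format handling: the sample JSON mirrors Ollama's
    // documented /api/chat NDJSON stream, and the serialization test checks
    // the skip_serializing_if attributes on OllamaRequest.
    #[test]
    fn test_stream_response_parsing() {
        let line = r#"{"message":{"content":"Hi"},"done":false}"#;
        let resp: OllamaStreamResponse = serde_json::from_str(line).unwrap();
        assert_eq!(resp.message.content, "Hi");
        assert!(!resp.done);

        let final_line = r#"{"message":{"content":""},"done":true,"done_reason":"stop","eval_count":12,"prompt_eval_count":5,"total_duration":1500000000}"#;
        let resp: OllamaStreamResponse = serde_json::from_str(final_line).unwrap();
        assert!(resp.done);
        assert_eq!(resp.done_reason.as_deref(), Some("stop"));
        assert_eq!(resp.eval_count, Some(12));
    }

    #[test]
    fn test_request_serialization_skips_empty_fields() {
        let request = OllamaRequest {
            model: "llama3".to_string(),
            messages: vec![],
            stream: true,
            options: None,
            tools: vec![],
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(!json.contains("options"));
        assert!(!json.contains("tools"));
    }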
}