// strands_agents/models/ollama.rs

1//! Ollama local model provider.
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6
7use super::{Model, ModelConfig, StreamEventStream};
8use crate::types::{
9    content::{Message, Role, SystemContentBlock},
10    errors::StrandsError,
11    streaming::{
12        ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockDeltaToolUse, ContentBlockStart,
13        ContentBlockStartEvent, ContentBlockStartToolUse, ContentBlockStopEvent, MessageStartEvent,
14        MessageStopEvent, MetadataEvent, Metrics, StopReason, StreamEvent, Usage,
15    },
16    tools::{ToolChoice, ToolSpec},
17};
18
19const DEFAULT_MODEL_ID: &str = "llama3";
20const DEFAULT_HOST: &str = "http://localhost:11434";
21
22/// Ollama local model provider.
23#[derive(Clone)]
24pub struct OllamaModel {
25    config: ModelConfig,
26    host: String,
27    client: Client,
28}
29
30impl std::fmt::Debug for OllamaModel {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("OllamaModel")
33            .field("config", &self.config)
34            .field("host", &self.host)
35            .finish()
36    }
37}
38
39#[derive(Debug, Serialize)]
40struct OllamaRequest {
41    model: String,
42    messages: Vec<OllamaMessage>,
43    stream: bool,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    options: Option<OllamaOptions>,
46    #[serde(skip_serializing_if = "Vec::is_empty")]
47    tools: Vec<OllamaTool>,
48}
49
50#[derive(Debug, Serialize)]
51struct OllamaOptions {
52    #[serde(skip_serializing_if = "Option::is_none")]
53    num_predict: Option<u32>,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    temperature: Option<f32>,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    top_p: Option<f32>,
58}
59
60#[derive(Debug, Serialize)]
61struct OllamaMessage {
62    role: String,
63    content: String,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    images: Option<Vec<String>>,
66    #[serde(skip_serializing_if = "Option::is_none")]
67    tool_calls: Option<Vec<OllamaToolCall>>,
68}
69
70#[derive(Debug, Serialize, Deserialize, Clone)]
71struct OllamaToolCall {
72    function: OllamaFunctionCall,
73}
74
75#[derive(Debug, Serialize, Deserialize, Clone)]
76struct OllamaFunctionCall {
77    name: String,
78    arguments: serde_json::Value,
79}
80
81#[derive(Debug, Serialize)]
82struct OllamaTool {
83    #[serde(rename = "type")]
84    tool_type: String,
85    function: OllamaFunctionDef,
86}
87
88#[derive(Debug, Serialize)]
89struct OllamaFunctionDef {
90    name: String,
91    description: String,
92    parameters: serde_json::Value,
93}
94
95#[derive(Debug, Deserialize)]
96struct OllamaStreamResponse {
97    message: OllamaResponseMessage,
98    done: bool,
99    #[serde(default)]
100    done_reason: Option<String>,
101    #[serde(default)]
102    eval_count: Option<u32>,
103    #[serde(default)]
104    prompt_eval_count: Option<u32>,
105    #[serde(default)]
106    total_duration: Option<u64>,
107}
108
109#[derive(Debug, Deserialize)]
110struct OllamaResponseMessage {
111    #[serde(default)]
112    content: String,
113    #[serde(default)]
114    tool_calls: Option<Vec<OllamaToolCall>>,
115}
116
117impl OllamaModel {
118    pub fn new(model_id: impl Into<String>) -> Self {
119        Self {
120            config: ModelConfig::new(model_id),
121            host: DEFAULT_HOST.to_string(),
122            client: Client::new(),
123        }
124    }
125
126    pub fn with_host(mut self, host: impl Into<String>) -> Self {
127        self.host = host.into();
128        self
129    }
130
131    pub fn with_config(mut self, config: ModelConfig) -> Self {
132        self.config = config;
133        self
134    }
135
136    fn format_messages(&self, messages: &[Message], system_prompt: Option<&str>) -> Vec<OllamaMessage> {
137        let mut formatted = Vec::new();
138
139        if let Some(prompt) = system_prompt {
140            formatted.push(OllamaMessage {
141                role: "system".to_string(),
142                content: prompt.to_string(),
143                images: None,
144                tool_calls: None,
145            });
146        }
147
148        for msg in messages {
149            let role = match msg.role {
150                Role::User => "user",
151                Role::Assistant => "assistant",
152            };
153
154            let mut text_content = String::new();
155            let mut tool_calls = Vec::new();
156
157            for block in &msg.content {
158                if let Some(ref text) = block.text {
159                    text_content.push_str(text);
160                }
161
162                if let Some(ref tu) = block.tool_use {
163                    tool_calls.push(OllamaToolCall {
164                        function: OllamaFunctionCall {
165                            name: tu.name.clone(),
166                            arguments: tu.input.clone(),
167                        },
168                    });
169                }
170
171                if let Some(ref tr) = block.tool_result {
172                    let content = tr
173                        .content
174                        .iter()
175                        .filter_map(|c| c.text.clone())
176                        .collect::<Vec<_>>()
177                        .join("\n");
178
179                    formatted.push(OllamaMessage {
180                        role: "tool".to_string(),
181                        content,
182                        images: None,
183                        tool_calls: None,
184                    });
185                }
186            }
187
188            if !text_content.is_empty() || !tool_calls.is_empty() {
189                formatted.push(OllamaMessage {
190                    role: role.to_string(),
191                    content: text_content,
192                    images: None,
193                    tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) },
194                });
195            }
196        }
197
198        formatted
199    }
200
201    fn format_tools(&self, tool_specs: &[ToolSpec]) -> Vec<OllamaTool> {
202        tool_specs
203            .iter()
204            .map(|spec| OllamaTool {
205                tool_type: "function".to_string(),
206                function: OllamaFunctionDef {
207                    name: spec.name.clone(),
208                    description: spec.description.clone(),
209                    parameters: spec.input_schema.json.clone(),
210                },
211            })
212            .collect()
213    }
214}
215
216impl Default for OllamaModel {
217    fn default() -> Self {
218        Self::new(DEFAULT_MODEL_ID)
219    }
220}
221
222#[async_trait]
223impl Model for OllamaModel {
224    fn config(&self) -> &ModelConfig {
225        &self.config
226    }
227
228    fn update_config(&mut self, config: ModelConfig) {
229        self.config = config;
230    }
231
232    fn stream<'a>(
233        &'a self,
234        messages: &'a [Message],
235        tool_specs: Option<&'a [ToolSpec]>,
236        system_prompt: Option<&'a str>,
237        _tool_choice: Option<ToolChoice>,
238        _system_prompt_content: Option<&'a [SystemContentBlock]>,
239    ) -> StreamEventStream<'a> {
240        let url = format!("{}/api/chat", self.host);
241        let client = self.client.clone();
242
243        let options = OllamaOptions {
244            num_predict: self.config.max_tokens,
245            temperature: self.config.temperature,
246            top_p: self.config.top_p,
247        };
248
249        let request = OllamaRequest {
250            model: self.config.model_id.clone(),
251            messages: self.format_messages(messages, system_prompt),
252            stream: true,
253            options: Some(options),
254            tools: tool_specs.map(|s| self.format_tools(s)).unwrap_or_default(),
255        };
256
257        Box::pin(async_stream::stream! {
258            let response = match client
259                .post(&url)
260                .header("Content-Type", "application/json")
261                .json(&request)
262                .send()
263                .await
264            {
265                Ok(resp) => resp,
266                Err(e) => {
267                    yield Err(StrandsError::NetworkError(e.to_string()));
268                    return;
269                }
270            };
271
272            if !response.status().is_success() {
273                let status = response.status();
274                let body = response.text().await.unwrap_or_default();
275                yield Err(StrandsError::model_error(format!("HTTP {status}: {body}")));
276                return;
277            }
278
279            yield Ok(StreamEvent {
280                message_start: Some(MessageStartEvent { role: Role::Assistant }),
281                ..Default::default()
282            });
283
284            yield Ok(StreamEvent {
285                content_block_start: Some(ContentBlockStartEvent {
286                    content_block_index: Some(0),
287                    start: None,
288                }),
289                ..Default::default()
290            });
291
292            use futures::StreamExt;
293            let mut byte_stream = response.bytes_stream();
294            let mut tool_calls_found: Vec<OllamaToolCall> = Vec::new();
295            let mut final_response: Option<OllamaStreamResponse> = None;
296
297            while let Some(chunk_result) = byte_stream.next().await {
298                let chunk = match chunk_result {
299                    Ok(bytes) => String::from_utf8_lossy(&bytes).to_string(),
300                    Err(e) => {
301                        yield Err(StrandsError::NetworkError(e.to_string()));
302                        return;
303                    }
304                };
305
306                for line in chunk.lines() {
307                    let line = line.trim();
308                    if line.is_empty() {
309                        continue;
310                    }
311
312                    if let Ok(resp) = serde_json::from_str::<OllamaStreamResponse>(line) {
313                        if !resp.message.content.is_empty() {
314                            yield Ok(StreamEvent {
315                                content_block_delta: Some(ContentBlockDeltaEvent {
316                                    content_block_index: Some(0),
317                                    delta: Some(ContentBlockDelta {
318                                        text: Some(resp.message.content.clone()),
319                                        ..Default::default()
320                                    }),
321                                }),
322                                ..Default::default()
323                            });
324                        }
325
326                        if let Some(ref tcs) = resp.message.tool_calls {
327                            tool_calls_found.extend(tcs.clone());
328                        }
329
330                        if resp.done {
331                            final_response = Some(resp);
332                            break;
333                        }
334                    }
335                }
336            }
337
338            yield Ok(StreamEvent {
339                content_block_stop: Some(ContentBlockStopEvent {
340                    content_block_index: Some(0),
341                }),
342                ..Default::default()
343            });
344
345            let mut tool_index = 1u32;
346            for tc in &tool_calls_found {
347                yield Ok(StreamEvent {
348                    content_block_start: Some(ContentBlockStartEvent {
349                        content_block_index: Some(tool_index),
350                        start: Some(ContentBlockStart {
351                            tool_use: Some(ContentBlockStartToolUse {
352                                name: tc.function.name.clone(),
353                                tool_use_id: tc.function.name.clone(),
354                            }),
355                        }),
356                    }),
357                    ..Default::default()
358                });
359
360                yield Ok(StreamEvent {
361                    content_block_delta: Some(ContentBlockDeltaEvent {
362                        content_block_index: Some(tool_index),
363                        delta: Some(ContentBlockDelta {
364                            tool_use: Some(ContentBlockDeltaToolUse {
365                                input: serde_json::to_string(&tc.function.arguments).unwrap_or_default(),
366                            }),
367                            ..Default::default()
368                        }),
369                    }),
370                    ..Default::default()
371                });
372
373                yield Ok(StreamEvent {
374                    content_block_stop: Some(ContentBlockStopEvent {
375                        content_block_index: Some(tool_index),
376                    }),
377                    ..Default::default()
378                });
379
380                tool_index += 1;
381            }
382
383            let stop_reason = if !tool_calls_found.is_empty() {
384                StopReason::ToolUse
385            } else if final_response.as_ref().and_then(|r| r.done_reason.as_ref()).map(|s| s == "length").unwrap_or(false) {
386                StopReason::MaxTokens
387            } else {
388                StopReason::EndTurn
389            };
390
391            yield Ok(StreamEvent {
392                message_stop: Some(MessageStopEvent {
393                    stop_reason: Some(stop_reason),
394                    additional_model_response_fields: None,
395                }),
396                ..Default::default()
397            });
398
399            if let Some(ref resp) = final_response {
400                let input_tokens = resp.prompt_eval_count.unwrap_or(0);
401                let output_tokens = resp.eval_count.unwrap_or(0);
402                let latency_ms = resp.total_duration.map(|d| d / 1_000_000).unwrap_or(0);
403
404                yield Ok(StreamEvent {
405                    metadata: Some(MetadataEvent {
406                        usage: Some(Usage {
407                            input_tokens,
408                            output_tokens,
409                            total_tokens: input_tokens + output_tokens,
410                            cache_read_input_tokens: 0,
411                            cache_write_input_tokens: 0,
412                        }),
413                        metrics: Some(Metrics {
414                            latency_ms,
415                            time_to_first_byte_ms: 0,
416                        }),
417                        trace: None,
418                    }),
419                    ..Default::default()
420                });
421            }
422        })
423    }
424}
425
426#[cfg(test)]
427mod tests {
428    use super::*;
429
430    #[test]
431    fn test_ollama_model_creation() {
432        let model = OllamaModel::new("llama3.2");
433        assert_eq!(model.config().model_id, "llama3.2");
434    }
435
436    #[test]
437    fn test_ollama_with_host() {
438        let model = OllamaModel::new("llama3").with_host("http://192.168.1.100:11434");
439        assert_eq!(model.host, "http://192.168.1.100:11434");
440    }
441
442    #[test]
443    fn test_ollama_default() {
444        let model = OllamaModel::default();
445        assert_eq!(model.config().model_id, "llama3");
446        assert_eq!(model.host, "http://localhost:11434");
447    }
448}
449