strands_agents/models/llamacpp.rs

//! llama.cpp model provider.
//!
//! Provides integration with llama.cpp servers running in OpenAI-compatible mode.
//! Docs: https://github.com/ggml-org/llama.cpp
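//!
//! # Example
//!
//! A minimal construction sketch (not taken from upstream docs). It assumes the
//! crate is exposed as `strands_agents` and that a llama.cpp server is listening
//! on `http://localhost:8080`, so the block is marked `ignore`.
//!
//! ```ignore
//! use strands_agents::models::llamacpp::{LlamaCppConfig, LlamaCppModel};
//!
//! // Point the provider at a local llama.cpp server and tune sampling.
//! let config = LlamaCppConfig::new("http://localhost:8080")
//!     .with_model_id("default")
//!     .with_temperature(0.7)
//!     .with_max_tokens(512);
//! let model = LlamaCppModel::new(config);
//! ```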

use std::collections::HashMap;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::models::{Model, ModelConfig, StreamEventStream};
use crate::types::content::{Message, Role, SystemContentBlock};
use crate::types::errors::StrandsError;
use crate::types::streaming::{StopReason, StreamEvent};
use crate::types::tools::{ToolChoice, ToolSpec};

/// Configuration for llama.cpp models.
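///
/// Additional server parameters can be attached with [`LlamaCppConfig::with_param`].
/// An illustrative sketch (it assumes the target llama.cpp server accepts a
/// `top_k` sampling parameter on its OpenAI-compatible endpoint):
///
/// ```ignore
/// let config = LlamaCppConfig::new("http://localhost:8080")
///     .with_param("top_k", serde_json::json!(40));
/// ```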
#[derive(Debug, Clone)]
pub struct LlamaCppConfig {
    /// Model identifier (default: "default").
    pub model_id: String,
    /// Base URL for the llama.cpp server.
    pub base_url: String,
    /// Additional model parameters.
    pub params: HashMap<String, serde_json::Value>,
}

impl Default for LlamaCppConfig {
    fn default() -> Self {
        Self {
            model_id: "default".to_string(),
            base_url: "http://localhost:8080".to_string(),
            params: HashMap::new(),
        }
    }
}

impl LlamaCppConfig {
    /// Creates a configuration pointing at the given llama.cpp server URL.
    pub fn new(base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            ..Default::default()
        }
    }

    /// Sets the model identifier sent in requests.
    pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
        self.model_id = model_id.into();
        self
    }

    /// Sets an arbitrary model parameter passed through to the server.
    pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.params.insert(key.into(), value);
        self
    }

    /// Sets the sampling temperature.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.params.insert("temperature".to_string(), serde_json::json!(temperature));
        self
    }

    /// Sets the maximum number of tokens to generate.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.params.insert("max_tokens".to_string(), serde_json::json!(max_tokens));
        self
    }
}

/// OpenAI-compatible request format for llama.cpp.
///
/// Extra parameters from [`LlamaCppConfig::params`] are flattened into the top
/// level of the serialized JSON body.
#[derive(Debug, Serialize)]
struct LlamaCppRequest {
    model: String,
    messages: Vec<LlamaCppMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<LlamaCppTool>>,
    stream: bool,
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize)]
struct LlamaCppMessage {
    role: String,
    content: serde_json::Value,
}

#[derive(Debug, Serialize)]
struct LlamaCppTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: LlamaCppFunction,
}

#[derive(Debug, Serialize)]
struct LlamaCppFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

/// llama.cpp model provider.
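///
/// Sends chat requests to `{base_url}/v1/chat/completions` and parses the
/// server's SSE response into [`StreamEvent`]s. A minimal construction sketch
/// (marked `ignore` because the `strands_agents` crate path is assumed here):
///
/// ```ignore
/// use strands_agents::models::llamacpp::{LlamaCppConfig, LlamaCppModel};
///
/// // Targets the default local server at http://localhost:8080.
/// let model = LlamaCppModel::new(LlamaCppConfig::default());
/// ```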
pub struct LlamaCppModel {
    config: ModelConfig,
    llamacpp_config: LlamaCppConfig,
    client: reqwest::Client,
}

impl LlamaCppModel {
    /// Creates a new provider from the given configuration.
    pub fn new(config: LlamaCppConfig) -> Self {
        let model_config = ModelConfig::new(&config.model_id);

        Self {
            config: model_config,
            llamacpp_config: config,
            client: reqwest::Client::new(),
        }
    }

    /// Converts messages to the OpenAI-compatible format. Only the text content
    /// of each message is forwarded; the optional system prompt becomes the
    /// leading `system` message.
    fn convert_messages(&self, messages: &[Message], system_prompt: Option<&str>) -> Vec<LlamaCppMessage> {
        let mut result = Vec::new();

        if let Some(prompt) = system_prompt {
            result.push(LlamaCppMessage {
                role: "system".to_string(),
                content: serde_json::json!(prompt),
            });
        }

        for msg in messages {
            let role = match msg.role {
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            let content = msg.text_content();

            result.push(LlamaCppMessage {
                role: role.to_string(),
                content: serde_json::json!(content),
            });
        }

        result
    }

    /// Converts tool specs to OpenAI-style function-calling tool definitions.
    fn convert_tools(&self, tool_specs: &[ToolSpec]) -> Vec<LlamaCppTool> {
        tool_specs
            .iter()
            .map(|spec| LlamaCppTool {
                tool_type: "function".to_string(),
                function: LlamaCppFunction {
                    name: spec.name.clone(),
                    description: spec.description.clone(),
                    parameters: serde_json::to_value(&spec.input_schema).unwrap_or_default(),
                },
            })
            .collect()
    }
}

#[async_trait]
impl Model for LlamaCppModel {
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    fn update_config(&mut self, config: ModelConfig) {
        self.config = config;
    }

    fn stream<'a>(
        &'a self,
        messages: &'a [Message],
        tool_specs: Option<&'a [ToolSpec]>,
        system_prompt: Option<&'a str>,
        _tool_choice: Option<ToolChoice>,
        _system_prompt_content: Option<&'a [SystemContentBlock]>,
    ) -> StreamEventStream<'a> {
        // Clone borrowed inputs into owned values so the returned stream is
        // self-contained.
        let messages = messages.to_vec();
        let tool_specs = tool_specs.map(|t| t.to_vec());
        let system_prompt = system_prompt.map(|s| s.to_string());

        Box::pin(async_stream::stream! {
            let llamacpp_messages = self.convert_messages(&messages, system_prompt.as_deref());
            let tools = tool_specs.as_ref().map(|specs| self.convert_tools(specs));

            let max_tokens = self.llamacpp_config.params
                .get("max_tokens")
                .and_then(|v| v.as_u64())
                .map(|v| v as u32);

            let temperature = self.llamacpp_config.params
                .get("temperature")
                .and_then(|v| v.as_f64())
                .map(|v| v as f32);

            // Remaining params are flattened into the request body; drop the keys
            // already serialized as dedicated fields so the JSON has no duplicates.
            let mut extra = self.llamacpp_config.params.clone();
            extra.remove("max_tokens");
            extra.remove("temperature");

            let request = LlamaCppRequest {
                model: self.config.model_id.clone(),
                messages: llamacpp_messages,
                max_tokens,
                temperature,
                tools,
                stream: true,
                extra,
            };

            let url = format!("{}/v1/chat/completions", self.llamacpp_config.base_url);

            let response = match self.client
                .post(&url)
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await
            {
                Ok(resp) => resp,
                Err(e) => {
                    yield Err(StrandsError::NetworkError(e.to_string()));
                    return;
                }
            };

            // Map HTTP failures onto provider errors; 429 is surfaced as throttling.
            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();

                if status.as_u16() == 429 {
                    yield Err(StrandsError::ModelThrottled {
                        message: "llama.cpp rate limit exceeded".into(),
                    });
                } else {
                    yield Err(StrandsError::ModelError {
                        message: format!("llama.cpp API error {}: {}", status, body),
                        source: None,
                    });
                }
                return;
            }

            yield Ok(StreamEvent::message_start(Role::Assistant));

            // The SSE body is buffered in full before parsing; each `data:` line
            // carries an OpenAI-style chunk whose `choices[].delta.content`
            // holds a text fragment.
            let body = match response.text().await {
                Ok(b) => b,
                Err(e) => {
                    yield Err(StrandsError::NetworkError(e.to_string()));
                    return;
                }
            };

            for line in body.lines() {
                if let Some(data) = line.strip_prefix("data: ") {
                    if data == "[DONE]" {
                        break;
                    }

                    if let Ok(chunk) = serde_json::from_str::<serde_json::Value>(data) {
                        if let Some(choices) = chunk.get("choices").and_then(|c| c.as_array()) {
                            for choice in choices {
                                if let Some(delta) = choice.get("delta") {
                                    if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
                                        yield Ok(StreamEvent::text_delta(0, content));
                                    }
                                }
                            }
                        }
                    }
                }
            }

            yield Ok(StreamEvent::message_stop(StopReason::EndTurn));
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llamacpp_config() {
        let config = LlamaCppConfig::new("http://localhost:8080")
            .with_model_id("my-model")
            .with_temperature(0.7);

        assert_eq!(config.base_url, "http://localhost:8080");
        assert_eq!(config.model_id, "my-model");
        assert!(config.params.contains_key("temperature"));
    }
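
    // Additional sketch covering the defaults defined above and the
    // `with_max_tokens` builder; values mirror `LlamaCppConfig::default()`.
    #[test]
    fn test_llamacpp_config_defaults_and_max_tokens() {
        let config = LlamaCppConfig::default().with_max_tokens(256);

        assert_eq!(config.model_id, "default");
        assert_eq!(config.base_url, "http://localhost:8080");
        assert_eq!(config.params.get("max_tokens"), Some(&serde_json::json!(256)));
    }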
}