// strands_agents/models/mistral.rs

//! Mistral AI model provider.
//!
//! Docs: https://docs.mistral.ai/

use std::fmt;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::models::{Model, ModelConfig, StreamEventStream};
use crate::types::content::{Message, Role, SystemContentBlock};
use crate::types::errors::StrandsError;
use crate::types::streaming::{StopReason, StreamEvent};
use crate::types::tools::{ToolChoice, ToolSpec};

/// Configuration for Mistral models.
///
/// Build one with [`MistralConfig::new`] and the `with_*` methods. When `api_key` is
/// `None`, the provider falls back to the `MISTRAL_API_KEY` environment variable.
#[derive(Clone, Default)]
pub struct MistralConfig {
    /// Mistral model ID (e.g., "mistral-large-latest").
    pub model_id: String,
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Sampling temperature (controls randomness); see Mistral's API reference for the
    /// accepted range.
    pub temperature: Option<f32>,
    /// Controls diversity via nucleus sampling.
    pub top_p: Option<f32>,
    /// API key for authentication. Never printed by the `Debug` implementation.
    pub api_key: Option<String>,
}

impl fmt::Debug for MistralConfig {
    /// Formats every field except the API key, which is shown only as present/absent so
    /// that logging a config cannot leak credentials.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MistralConfig")
            .field("model_id", &self.model_id)
            .field("max_tokens", &self.max_tokens)
            .field("temperature", &self.temperature)
            .field("top_p", &self.top_p)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

impl MistralConfig {
    /// Creates a config for `model_id` with every optional setting unset.
    pub fn new(model_id: impl Into<String>) -> Self {
        Self {
            model_id: model_id.into(),
            ..Default::default()
        }
    }

    /// Sets the API key; it takes precedence over `MISTRAL_API_KEY`.
    #[must_use]
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Sets the maximum number of tokens to generate.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the sampling temperature.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the nucleus-sampling `top_p` value.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }
}

53/// Mistral API request format.
54#[derive(Debug, Serialize)]
55struct MistralRequest {
56    model: String,
57    messages: Vec<MistralMessage>,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    max_tokens: Option<u32>,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    temperature: Option<f32>,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    top_p: Option<f32>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    tools: Option<Vec<MistralTool>>,
66    stream: bool,
67}
68
69#[derive(Debug, Serialize, Deserialize)]
70struct MistralMessage {
71    role: String,
72    content: String,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    tool_calls: Option<Vec<MistralToolCall>>,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    tool_call_id: Option<String>,
77}
78
79#[derive(Debug, Serialize, Deserialize)]
80struct MistralToolCall {
81    id: String,
82    #[serde(rename = "type")]
83    call_type: String,
84    function: MistralFunction,
85}
86
87#[derive(Debug, Serialize, Deserialize)]
88struct MistralFunction {
89    name: String,
90    arguments: String,
91}
92
93#[derive(Debug, Serialize)]
94struct MistralTool {
95    #[serde(rename = "type")]
96    tool_type: String,
97    function: MistralFunctionDef,
98}
99
100#[derive(Debug, Serialize)]
101struct MistralFunctionDef {
102    name: String,
103    description: String,
104    parameters: serde_json::Value,
105}
106
107/// Mistral AI model provider.
108pub struct MistralModel {
109    config: ModelConfig,
110    mistral_config: MistralConfig,
111    client: reqwest::Client,
112}
113
114impl MistralModel {
115    const BASE_URL: &'static str = "https://api.mistral.ai/v1";
116
117    pub fn new(config: MistralConfig) -> Self {
118        let model_config = ModelConfig {
119            model_id: config.model_id.clone(),
120            max_tokens: config.max_tokens,
121            temperature: config.temperature,
122            top_p: config.top_p,
123            ..Default::default()
124        };
125
126        Self {
127            config: model_config,
128            mistral_config: config,
129            client: reqwest::Client::new(),
130        }
131    }
132
133    fn api_key(&self) -> Result<String, StrandsError> {
134        self.mistral_config
135            .api_key
136            .clone()
137            .or_else(|| std::env::var("MISTRAL_API_KEY").ok())
138            .ok_or_else(|| StrandsError::ConfigurationError {
139                message: "Mistral API key not configured. Set MISTRAL_API_KEY or provide api_key".into(),
140            })
141    }
142
143    fn convert_messages(&self, messages: &[Message], system_prompt: Option<&str>) -> Vec<MistralMessage> {
144        let mut result = Vec::new();
145
146        if let Some(prompt) = system_prompt {
147            result.push(MistralMessage {
148                role: "system".to_string(),
149                content: prompt.to_string(),
150                tool_calls: None,
151                tool_call_id: None,
152            });
153        }
154
155        for msg in messages {
156            let role = match msg.role {
157                Role::User => "user",
158                Role::Assistant => "assistant",
159            };
160
161            let content = msg.text_content();
162
163            let tool_calls: Option<Vec<MistralToolCall>> = {
164                let calls: Vec<_> = msg
165                    .content
166                    .iter()
167                    .filter_map(|b| b.tool_use.as_ref())
168                    .map(|tu| MistralToolCall {
169                        id: tu.tool_use_id.clone(),
170                        call_type: "function".to_string(),
171                        function: MistralFunction {
172                            name: tu.name.clone(),
173                            arguments: serde_json::to_string(&tu.input).unwrap_or_default(),
174                        },
175                    })
176                    .collect();
177
178                if calls.is_empty() {
179                    None
180                } else {
181                    Some(calls)
182                }
183            };
184
185            if tool_calls.is_some() {
186                result.push(MistralMessage {
187                    role: role.to_string(),
188                    content,
189                    tool_calls,
190                    tool_call_id: None,
191                });
192            } else if msg.has_tool_result() {
193                for block in &msg.content {
194                    if let Some(tr) = &block.tool_result {
195                        let content_text = tr
196                            .content
197                            .iter()
198                            .filter_map(|c| c.text.as_ref())
199                            .cloned()
200                            .collect::<Vec<_>>()
201                            .join("");
202
203                        result.push(MistralMessage {
204                            role: "tool".to_string(),
205                            content: content_text,
206                            tool_calls: None,
207                            tool_call_id: Some(tr.tool_use_id.clone()),
208                        });
209                    }
210                }
211            } else {
212                result.push(MistralMessage {
213                    role: role.to_string(),
214                    content,
215                    tool_calls: None,
216                    tool_call_id: None,
217                });
218            }
219        }
220
221        result
222    }
223
224    fn convert_tools(&self, tool_specs: &[ToolSpec]) -> Vec<MistralTool> {
225        tool_specs
226            .iter()
227            .map(|spec| MistralTool {
228                tool_type: "function".to_string(),
229                function: MistralFunctionDef {
230                    name: spec.name.clone(),
231                    description: spec.description.clone(),
232                    parameters: serde_json::to_value(&spec.input_schema).unwrap_or_default(),
233                },
234            })
235            .collect()
236    }
237}
238
239#[async_trait]
240impl Model for MistralModel {
241    fn config(&self) -> &ModelConfig {
242        &self.config
243    }
244
245    fn update_config(&mut self, config: ModelConfig) {
246        self.config = config;
247    }
248
249    fn stream<'a>(
250        &'a self,
251        messages: &'a [Message],
252        tool_specs: Option<&'a [ToolSpec]>,
253        system_prompt: Option<&'a str>,
254        _tool_choice: Option<ToolChoice>,
255        _system_prompt_content: Option<&'a [SystemContentBlock]>,
256    ) -> StreamEventStream<'a> {
257        let messages = messages.to_vec();
258        let tool_specs = tool_specs.map(|t| t.to_vec());
259        let system_prompt = system_prompt.map(|s| s.to_string());
260
261        Box::pin(async_stream::stream! {
262            let api_key = match self.api_key() {
263                Ok(key) => key,
264                Err(e) => {
265                    yield Err(e);
266                    return;
267                }
268            };
269
270            let mistral_messages = self.convert_messages(&messages, system_prompt.as_deref());
271            let tools = tool_specs.as_ref().map(|specs| self.convert_tools(specs));
272
273            let request = MistralRequest {
274                model: self.config.model_id.clone(),
275                messages: mistral_messages,
276                max_tokens: self.config.max_tokens,
277                temperature: self.config.temperature,
278                top_p: self.config.top_p,
279                tools,
280                stream: true,
281            };
282
283            let url = format!("{}/chat/completions", Self::BASE_URL);
284
285            let response = match self.client
286                .post(&url)
287                .header("Authorization", format!("Bearer {}", api_key))
288                .header("Content-Type", "application/json")
289                .json(&request)
290                .send()
291                .await
292            {
293                Ok(resp) => resp,
294                Err(e) => {
295                    yield Err(StrandsError::NetworkError(e.to_string()));
296                    return;
297                }
298            };
299
300            if !response.status().is_success() {
301                let status = response.status();
302                let body = response.text().await.unwrap_or_default();
303
304                if status.as_u16() == 429 {
305                    yield Err(StrandsError::ModelThrottled {
306                        message: "Mistral rate limit exceeded".into(),
307                    });
308                } else {
309                    yield Err(StrandsError::ModelError {
310                        message: format!("Mistral API error {}: {}", status, body),
311                        source: None,
312                    });
313                }
314                return;
315            }
316
317            yield Ok(StreamEvent::message_start(Role::Assistant));
318
319            let body = match response.text().await {
320                Ok(b) => b,
321                Err(e) => {
322                    yield Err(StrandsError::NetworkError(e.to_string()));
323                    return;
324                }
325            };
326
327            for line in body.lines() {
328                if line.starts_with("data: ") {
329                    let data = &line[6..];
330                    if data == "[DONE]" {
331                        break;
332                    }
333
334                    if let Ok(chunk) = serde_json::from_str::<serde_json::Value>(data) {
335                        if let Some(choices) = chunk.get("choices").and_then(|c| c.as_array()) {
336                            for choice in choices {
337                                if let Some(delta) = choice.get("delta") {
338                                    if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
339                                        yield Ok(StreamEvent::text_delta(0, content));
340                                    }
341                                }
342                            }
343                        }
344                    }
345                }
346            }
347
348            yield Ok(StreamEvent::message_stop(StopReason::EndTurn));
349        })
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mistral_config() {
        let config = MistralConfig::new("mistral-large-latest")
            .with_api_key("test-key")
            .with_temperature(0.7);

        assert_eq!(config.model_id, "mistral-large-latest");
        assert_eq!(config.api_key, Some("test-key".to_string()));
        assert_eq!(config.temperature, Some(0.7));
    }

    #[test]
    fn test_mistral_config_defaults() {
        let config = MistralConfig::new("mistral-small-latest");

        assert_eq!(config.model_id, "mistral-small-latest");
        assert_eq!(config.max_tokens, None);
        assert_eq!(config.temperature, None);
        assert_eq!(config.top_p, None);
        assert_eq!(config.api_key, None);
    }

    #[test]
    fn test_model_config_mirrors_mistral_config() {
        let model = MistralModel::new(
            MistralConfig::new("mistral-large-latest")
                .with_max_tokens(256)
                .with_temperature(0.2),
        );
        let config = model.config();

        assert_eq!(config.model_id, "mistral-large-latest");
        assert_eq!(config.max_tokens, Some(256));
        assert_eq!(config.temperature, Some(0.2));
        assert_eq!(config.top_p, None);
    }

    #[test]
    fn test_convert_messages_system_prompt_only() {
        let model = MistralModel::new(MistralConfig::new("mistral-large-latest"));

        let converted = model.convert_messages(&[], Some("Be concise."));
        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].role, "system");
        assert_eq!(converted[0].content, "Be concise.");
        assert!(converted[0].tool_calls.is_none());
        assert!(converted[0].tool_call_id.is_none());

        assert!(model.convert_messages(&[], None).is_empty());
        assert!(model.convert_tools(&[]).is_empty());
    }

    #[test]
    fn test_request_omits_unset_options() {
        let request = MistralRequest {
            model: "mistral-large-latest".to_string(),
            messages: Vec::new(),
            max_tokens: None,
            temperature: None,
            top_p: None,
            tools: None,
            stream: true,
        };

        let json = serde_json::to_value(&request).expect("request serializes");
        assert_eq!(
            json,
            serde_json::json!({
                "model": "mistral-large-latest",
                "messages": [],
                "stream": true
            })
        );
    }
}
