Skip to main content

vtcode_core/llm/providers/
mistral.rs

1use async_trait::async_trait;
2use reqwest::Client as HttpClient;
3use serde_json::{Map, Value};
4
5use crate::config::TimeoutsConfig;
6use crate::config::constants::{env_vars, models, urls};
7use crate::config::core::{AnthropicConfig, ModelConfig, PromptCachingConfig};
8use crate::llm::error_display;
9use crate::llm::provider::{LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream};
10
11use super::{
12    common::{
13        ensure_model, extract_prompt_cache_settings_default, impl_llm_client, override_base_url,
14        parse_json_response, parse_response_openai_format, resolve_model,
15        serialize_messages_openai_format, serialize_tools_openai_format,
16        spawn_openai_compatible_stream, validate_supported_models,
17    },
18    error_handling::handle_openai_http_error,
19};
20
/// Human-readable provider name; passed to the shared error-formatting,
/// response-parsing and streaming helpers used in this file.
21const PROVIDER_NAME: &str = "Mistral";
/// Machine-readable provider key; returned by `LLMProvider::name` and passed
/// to the shared message/tool-choice serialization and validation helpers.
22const PROVIDER_KEY: &str = "mistral";
23
/// LLM provider that sends chat-completion requests to a Mistral endpoint.
24pub struct MistralProvider {
    /// Credential sent as the bearer token on every request.
25    api_key: String,
    /// HTTP client used to post requests.
26    http_client: HttpClient,
    /// API base URL; `/chat/completions` is appended when sending.
27    base_url: String,
    /// Provider-level model, handed to `ensure_model` as the fallback and used
    /// by `supports_reasoning` when the queried model string is blank.
28    model: String,
    /// Prompt-cache flag forwarded to the shared response parser.
29    prompt_cache_enabled: bool,
    /// Optional per-model overrides consulted by the reasoning capability checks.
30    model_behavior: Option<ModelConfig>,
31}
32
33impl MistralProvider {
34    pub fn new(api_key: String) -> Self {
35        Self::with_model_internal(
36            api_key,
37            models::mistral::DEFAULT_MODEL.to_string(),
38            None,
39            None,
40            TimeoutsConfig::default(),
41            None,
42        )
43    }
44
45    pub fn with_model(api_key: String, model: String) -> Self {
46        Self::with_model_internal(api_key, model, None, None, TimeoutsConfig::default(), None)
47    }
48
49    pub fn new_with_client(
50        api_key: String,
51        model: String,
52        http_client: reqwest::Client,
53        base_url: String,
54        _timeouts: TimeoutsConfig,
55    ) -> Self {
56        Self {
57            api_key,
58            http_client,
59            base_url,
60            model,
61            prompt_cache_enabled: false,
62            model_behavior: None,
63        }
64    }
65
66    pub fn from_config(
67        api_key: Option<String>,
68        model: Option<String>,
69        base_url: Option<String>,
70        prompt_cache: Option<PromptCachingConfig>,
71        timeouts: Option<TimeoutsConfig>,
72        _anthropic: Option<AnthropicConfig>,
73        model_behavior: Option<ModelConfig>,
74    ) -> Self {
75        let api_key_value = api_key.unwrap_or_default();
76        let model_value = resolve_model(model, models::mistral::DEFAULT_MODEL);
77
78        Self::with_model_internal(
79            api_key_value,
80            model_value,
81            prompt_cache,
82            base_url,
83            timeouts.unwrap_or_default(),
84            model_behavior,
85        )
86    }
87
88    fn with_model_internal(
89        api_key: String,
90        model: String,
91        prompt_cache: Option<PromptCachingConfig>,
92        base_url: Option<String>,
93        timeouts: TimeoutsConfig,
94        model_behavior: Option<ModelConfig>,
95    ) -> Self {
96        use crate::llm::http_client::HttpClientFactory;
97
98        let (prompt_cache_enabled, _) =
99            extract_prompt_cache_settings_default(prompt_cache, "mistral");
100
101        Self {
102            api_key,
103            http_client: HttpClientFactory::for_llm(&timeouts),
104            base_url: override_base_url(
105                urls::MISTRAL_API_BASE,
106                base_url,
107                Some(env_vars::MISTRAL_BASE_URL),
108            ),
109            model,
110            prompt_cache_enabled,
111            model_behavior,
112        }
113    }
114
115    fn convert_to_mistral_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
116        let mut payload = Map::with_capacity(12);
117
118        let mut messages = self.serialize_messages(request)?;
119
120        // Mistral API does not support a top-level "system" field.
121        // Inject the system prompt as a system-role message at the start.
122        if let Some(system_prompt) = &request.system_prompt {
123            let trimmed = system_prompt.trim();
124            if !trimmed.is_empty() {
125                messages.insert(0, serde_json::json!({"role": "system", "content": trimmed}));
126            }
127        }
128
129        payload.insert("model".to_owned(), Value::String(request.model.clone()));
130        payload.insert("messages".to_owned(), Value::Array(messages));
131
132        if let Some(max_tokens) = request.max_tokens {
133            payload.insert(
134                "max_tokens".to_owned(),
135                Value::Number(serde_json::Number::from(max_tokens as u64)),
136            );
137        }
138
139        if let Some(temperature) = request.temperature {
140            payload.insert(
141                "temperature".to_owned(),
142                Value::Number(Self::float_to_json_number(temperature)?),
143            );
144        }
145
146        if let Some(top_p) = request.top_p {
147            payload.insert(
148                "top_p".to_owned(),
149                Value::Number(Self::float_to_json_number(top_p)?),
150            );
151        }
152
153        if request.stream {
154            payload.insert("stream".to_string(), Value::Bool(true));
155            payload.insert(
156                "stream_options".to_string(),
157                serde_json::json!({"include_usage": true}),
158            );
159        }
160
161        if let Some(tools) = &request.tools
162            && let Some(serialized_tools) = serialize_tools_openai_format(tools)
163        {
164            payload.insert("tools".to_string(), Value::Array(serialized_tools));
165            payload.insert("parallel_tool_calls".to_string(), Value::Bool(false));
166        }
167
168        let has_explicit_choice = request.tool_choice.is_some();
169        if let Some(choice) = &request.tool_choice {
170            payload.insert(
171                "tool_choice".to_string(),
172                choice.to_provider_format(PROVIDER_KEY),
173            );
174        }
175        // Mistral's default "auto" tool_choice sometimes causes the model to
176        // emit tool call arguments as plain text content. Setting it explicitly
177        // when tools are present helps the model use structured tool_calls.
178        if !has_explicit_choice && request.tools.as_ref().is_some_and(|t| !t.is_empty()) {
179            payload.insert("tool_choice".to_string(), Value::String("auto".to_owned()));
180        }
181
182        if let Some(effort) = request.reasoning_effort
183            && effort != crate::config::types::ReasoningEffortLevel::None
184        {
185            payload.insert(
186                "reasoning_effort".to_owned(),
187                Value::String("high".to_owned()),
188            );
189        }
190
191        if let Some(meta) = &request.metadata
192            && let Some(user_id) = meta.get("user_id").and_then(|v| v.as_str())
193        {
194            payload.insert("user_id".to_owned(), Value::String(user_id.to_owned()));
195        }
196
197        Ok(Value::Object(payload))
198    }
199
200    fn float_to_json_number(value: f32) -> Result<serde_json::Number, LLMError> {
201        serde_json::Number::from_f64(value as f64).ok_or_else(|| LLMError::InvalidRequest {
202            message: "invalid numeric parameter value (NaN or infinity)".to_string(),
203            metadata: None,
204        })
205    }
206
207    async fn send_request(&self, payload: &Value) -> Result<reqwest::Response, LLMError> {
208        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
209
210        self.http_client
211            .post(&url)
212            .bearer_auth(&self.api_key)
213            .json(payload)
214            .send()
215            .await
216            .map_err(|e| LLMError::Network {
217                message: error_display::format_llm_error(
218                    PROVIDER_NAME,
219                    &format!("network error: {}", e),
220                ),
221                metadata: None,
222            })
223    }
224
225    fn serialize_messages(&self, request: &LLMRequest) -> Result<Vec<Value>, LLMError> {
226        serialize_messages_openai_format(request, PROVIDER_KEY)
227    }
228
229    fn parse_response(&self, response_json: Value, model: String) -> Result<LLMResponse, LLMError> {
230        parse_response_openai_format(
231            response_json,
232            PROVIDER_NAME,
233            model,
234            self.prompt_cache_enabled,
235            None as Option<fn(&Value, &Value) -> Option<String>>,
236        )
237    }
238}
239
240#[async_trait]
241impl LLMProvider for MistralProvider {
242    fn name(&self) -> &str {
243        PROVIDER_KEY
244    }
245
246    fn supports_streaming(&self) -> bool {
247        true
248    }
249
250    fn supports_tools(&self, _model: &str) -> bool {
251        true
252    }
253
254    fn supports_structured_output(&self, _model: &str) -> bool {
255        true
256    }
257
258    fn supports_vision(&self, _model: &str) -> bool {
259        true
260    }
261
262    fn supports_reasoning(&self, model: &str) -> bool {
263        let requested = if model.trim().is_empty() {
264            &self.model
265        } else {
266            model
267        };
268
269        self.model_behavior
270            .as_ref()
271            .and_then(|b| b.model_supports_reasoning)
272            .unwrap_or(false)
273            || requested == models::mistral::MISTRAL_LARGE_3
274    }
275
276    fn supports_reasoning_effort(&self, _model: &str) -> bool {
277        self.model_behavior
278            .as_ref()
279            .and_then(|b| b.model_supports_reasoning_effort)
280            .unwrap_or(false)
281    }
282
283    fn effective_context_size(&self, _model: &str) -> usize {
284        256_000
285    }
286
287    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
288        let model = ensure_model(&mut request, &self.model);
289
290        let payload = self.convert_to_mistral_format(&request)?;
291        let response = self.send_request(&payload).await?;
292        let response = handle_openai_http_error(response, PROVIDER_NAME, "MISTRAL_API_KEY").await?;
293
294        let response_json = parse_json_response(response, PROVIDER_NAME).await?;
295        self.parse_response(response_json, model)
296    }
297
298    async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream, LLMError> {
299        ensure_model(&mut request, &self.model);
300        self.validate_request(&request)?;
301        request.stream = true;
302        let model = request.model.clone();
303
304        let payload = self.convert_to_mistral_format(&request)?;
305        let response = self.send_request(&payload).await?;
306        let response = handle_openai_http_error(response, PROVIDER_NAME, "MISTRAL_API_KEY").await?;
307
308        Ok(spawn_openai_compatible_stream(
309            response,
310            PROVIDER_NAME,
311            model,
312            Some("reasoning_content"),
313            super::shared::OpenAiDeltaOrder::ContentFirst,
314        ))
315    }
316
317    fn supported_models(&self) -> Vec<String> {
318        models::mistral::SUPPORTED_MODELS
319            .iter()
320            .map(|model| model.to_string())
321            .collect()
322    }
323
324    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
325        validate_supported_models(
326            request,
327            PROVIDER_NAME,
328            PROVIDER_KEY,
329            models::mistral::SUPPORTED_MODELS,
330        )
331    }
332}
333
// Invokes the shared `impl_llm_client!` macro (imported from `super::common`)
// for this provider. NOTE(review): presumably generates the LLM client trait
// implementation on top of `LLMProvider` — confirm against the macro definition.
334impl_llm_client!(MistralProvider);