vtcode_core/llm/providers/
mistral.rs1use async_trait::async_trait;
2use reqwest::Client as HttpClient;
3use serde_json::{Map, Value};
4
5use crate::config::TimeoutsConfig;
6use crate::config::constants::{env_vars, models, urls};
7use crate::config::core::{AnthropicConfig, ModelConfig, PromptCachingConfig};
8use crate::llm::error_display;
9use crate::llm::provider::{LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream};
10
11use super::{
12 common::{
13 ensure_model, extract_prompt_cache_settings_default, impl_llm_client, override_base_url,
14 parse_json_response, parse_response_openai_format, resolve_model,
15 serialize_messages_openai_format, serialize_tools_openai_format,
16 spawn_openai_compatible_stream, validate_supported_models,
17 },
18 error_handling::handle_openai_http_error,
19};
20
/// Display name handed to the shared error-formatting and response-parsing helpers.
const PROVIDER_NAME: &str = "Mistral";
/// Lowercase key returned by `LLMProvider::name` and handed to the shared
/// serialization and validation helpers.
const PROVIDER_KEY: &str = "mistral";
23
24pub struct MistralProvider {
25 api_key: String,
26 http_client: HttpClient,
27 base_url: String,
28 model: String,
29 prompt_cache_enabled: bool,
30 model_behavior: Option<ModelConfig>,
31}
32
33impl MistralProvider {
34 pub fn new(api_key: String) -> Self {
35 Self::with_model_internal(
36 api_key,
37 models::mistral::DEFAULT_MODEL.to_string(),
38 None,
39 None,
40 TimeoutsConfig::default(),
41 None,
42 )
43 }
44
45 pub fn with_model(api_key: String, model: String) -> Self {
46 Self::with_model_internal(api_key, model, None, None, TimeoutsConfig::default(), None)
47 }
48
49 pub fn new_with_client(
50 api_key: String,
51 model: String,
52 http_client: reqwest::Client,
53 base_url: String,
54 _timeouts: TimeoutsConfig,
55 ) -> Self {
56 Self {
57 api_key,
58 http_client,
59 base_url,
60 model,
61 prompt_cache_enabled: false,
62 model_behavior: None,
63 }
64 }
65
66 pub fn from_config(
67 api_key: Option<String>,
68 model: Option<String>,
69 base_url: Option<String>,
70 prompt_cache: Option<PromptCachingConfig>,
71 timeouts: Option<TimeoutsConfig>,
72 _anthropic: Option<AnthropicConfig>,
73 model_behavior: Option<ModelConfig>,
74 ) -> Self {
75 let api_key_value = api_key.unwrap_or_default();
76 let model_value = resolve_model(model, models::mistral::DEFAULT_MODEL);
77
78 Self::with_model_internal(
79 api_key_value,
80 model_value,
81 prompt_cache,
82 base_url,
83 timeouts.unwrap_or_default(),
84 model_behavior,
85 )
86 }
87
88 fn with_model_internal(
89 api_key: String,
90 model: String,
91 prompt_cache: Option<PromptCachingConfig>,
92 base_url: Option<String>,
93 timeouts: TimeoutsConfig,
94 model_behavior: Option<ModelConfig>,
95 ) -> Self {
96 use crate::llm::http_client::HttpClientFactory;
97
98 let (prompt_cache_enabled, _) =
99 extract_prompt_cache_settings_default(prompt_cache, "mistral");
100
101 Self {
102 api_key,
103 http_client: HttpClientFactory::for_llm(&timeouts),
104 base_url: override_base_url(
105 urls::MISTRAL_API_BASE,
106 base_url,
107 Some(env_vars::MISTRAL_BASE_URL),
108 ),
109 model,
110 prompt_cache_enabled,
111 model_behavior,
112 }
113 }
114
115 fn convert_to_mistral_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
116 let mut payload = Map::with_capacity(12);
117
118 let mut messages = self.serialize_messages(request)?;
119
120 if let Some(system_prompt) = &request.system_prompt {
123 let trimmed = system_prompt.trim();
124 if !trimmed.is_empty() {
125 messages.insert(0, serde_json::json!({"role": "system", "content": trimmed}));
126 }
127 }
128
129 payload.insert("model".to_owned(), Value::String(request.model.clone()));
130 payload.insert("messages".to_owned(), Value::Array(messages));
131
132 if let Some(max_tokens) = request.max_tokens {
133 payload.insert(
134 "max_tokens".to_owned(),
135 Value::Number(serde_json::Number::from(max_tokens as u64)),
136 );
137 }
138
139 if let Some(temperature) = request.temperature {
140 payload.insert(
141 "temperature".to_owned(),
142 Value::Number(Self::float_to_json_number(temperature)?),
143 );
144 }
145
146 if let Some(top_p) = request.top_p {
147 payload.insert(
148 "top_p".to_owned(),
149 Value::Number(Self::float_to_json_number(top_p)?),
150 );
151 }
152
153 if request.stream {
154 payload.insert("stream".to_string(), Value::Bool(true));
155 payload.insert(
156 "stream_options".to_string(),
157 serde_json::json!({"include_usage": true}),
158 );
159 }
160
161 if let Some(tools) = &request.tools
162 && let Some(serialized_tools) = serialize_tools_openai_format(tools)
163 {
164 payload.insert("tools".to_string(), Value::Array(serialized_tools));
165 payload.insert("parallel_tool_calls".to_string(), Value::Bool(false));
166 }
167
168 let has_explicit_choice = request.tool_choice.is_some();
169 if let Some(choice) = &request.tool_choice {
170 payload.insert(
171 "tool_choice".to_string(),
172 choice.to_provider_format(PROVIDER_KEY),
173 );
174 }
175 if !has_explicit_choice && request.tools.as_ref().is_some_and(|t| !t.is_empty()) {
179 payload.insert("tool_choice".to_string(), Value::String("auto".to_owned()));
180 }
181
182 if let Some(effort) = request.reasoning_effort
183 && effort != crate::config::types::ReasoningEffortLevel::None
184 {
185 payload.insert(
186 "reasoning_effort".to_owned(),
187 Value::String("high".to_owned()),
188 );
189 }
190
191 if let Some(meta) = &request.metadata
192 && let Some(user_id) = meta.get("user_id").and_then(|v| v.as_str())
193 {
194 payload.insert("user_id".to_owned(), Value::String(user_id.to_owned()));
195 }
196
197 Ok(Value::Object(payload))
198 }
199
200 fn float_to_json_number(value: f32) -> Result<serde_json::Number, LLMError> {
201 serde_json::Number::from_f64(value as f64).ok_or_else(|| LLMError::InvalidRequest {
202 message: "invalid numeric parameter value (NaN or infinity)".to_string(),
203 metadata: None,
204 })
205 }
206
207 async fn send_request(&self, payload: &Value) -> Result<reqwest::Response, LLMError> {
208 let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
209
210 self.http_client
211 .post(&url)
212 .bearer_auth(&self.api_key)
213 .json(payload)
214 .send()
215 .await
216 .map_err(|e| LLMError::Network {
217 message: error_display::format_llm_error(
218 PROVIDER_NAME,
219 &format!("network error: {}", e),
220 ),
221 metadata: None,
222 })
223 }
224
225 fn serialize_messages(&self, request: &LLMRequest) -> Result<Vec<Value>, LLMError> {
226 serialize_messages_openai_format(request, PROVIDER_KEY)
227 }
228
229 fn parse_response(&self, response_json: Value, model: String) -> Result<LLMResponse, LLMError> {
230 parse_response_openai_format(
231 response_json,
232 PROVIDER_NAME,
233 model,
234 self.prompt_cache_enabled,
235 None as Option<fn(&Value, &Value) -> Option<String>>,
236 )
237 }
238}
239
240#[async_trait]
241impl LLMProvider for MistralProvider {
242 fn name(&self) -> &str {
243 PROVIDER_KEY
244 }
245
246 fn supports_streaming(&self) -> bool {
247 true
248 }
249
250 fn supports_tools(&self, _model: &str) -> bool {
251 true
252 }
253
254 fn supports_structured_output(&self, _model: &str) -> bool {
255 true
256 }
257
258 fn supports_vision(&self, _model: &str) -> bool {
259 true
260 }
261
262 fn supports_reasoning(&self, model: &str) -> bool {
263 let requested = if model.trim().is_empty() {
264 &self.model
265 } else {
266 model
267 };
268
269 self.model_behavior
270 .as_ref()
271 .and_then(|b| b.model_supports_reasoning)
272 .unwrap_or(false)
273 || requested == models::mistral::MISTRAL_LARGE_3
274 }
275
276 fn supports_reasoning_effort(&self, _model: &str) -> bool {
277 self.model_behavior
278 .as_ref()
279 .and_then(|b| b.model_supports_reasoning_effort)
280 .unwrap_or(false)
281 }
282
283 fn effective_context_size(&self, _model: &str) -> usize {
284 256_000
285 }
286
287 async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
288 let model = ensure_model(&mut request, &self.model);
289
290 let payload = self.convert_to_mistral_format(&request)?;
291 let response = self.send_request(&payload).await?;
292 let response = handle_openai_http_error(response, PROVIDER_NAME, "MISTRAL_API_KEY").await?;
293
294 let response_json = parse_json_response(response, PROVIDER_NAME).await?;
295 self.parse_response(response_json, model)
296 }
297
298 async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream, LLMError> {
299 ensure_model(&mut request, &self.model);
300 self.validate_request(&request)?;
301 request.stream = true;
302 let model = request.model.clone();
303
304 let payload = self.convert_to_mistral_format(&request)?;
305 let response = self.send_request(&payload).await?;
306 let response = handle_openai_http_error(response, PROVIDER_NAME, "MISTRAL_API_KEY").await?;
307
308 Ok(spawn_openai_compatible_stream(
309 response,
310 PROVIDER_NAME,
311 model,
312 Some("reasoning_content"),
313 super::shared::OpenAiDeltaOrder::ContentFirst,
314 ))
315 }
316
317 fn supported_models(&self) -> Vec<String> {
318 models::mistral::SUPPORTED_MODELS
319 .iter()
320 .map(|model| model.to_string())
321 .collect()
322 }
323
324 fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
325 validate_supported_models(
326 request,
327 PROVIDER_NAME,
328 PROVIDER_KEY,
329 models::mistral::SUPPORTED_MODELS,
330 )
331 }
332}
333
334impl_llm_client!(MistralProvider);