Skip to main content

vtcode_core/llm/providers/
deepseek.rs

1#![allow(clippy::collapsible_if)]
2
3use crate::config::TimeoutsConfig;
4use crate::config::constants::{env_vars, models, urls};
5use crate::config::core::{
6    AnthropicConfig, DeepSeekPromptCacheSettings, ModelConfig, PromptCachingConfig,
7};
8use crate::llm::error_display;
9use crate::llm::provider::{
10    LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream, LLMStreamEvent,
11};
12use async_stream::try_stream;
13use async_trait::async_trait;
14use reqwest::Client as HttpClient;
15use serde_json::{Map, Value};
16
17use super::{
18    common::{
19        ensure_model, extract_prompt_cache_settings, impl_llm_client, map_finish_reason_common,
20        override_base_url, parse_json_response, parse_response_openai_format, resolve_model,
21        serialize_messages_openai_format, serialize_tools_openai_format, validate_supported_models,
22    },
23    error_handling::handle_openai_http_error,
24    extract_reasoning_trace,
25};
26
/// Display name of the provider, passed to the shared error-formatting,
/// HTTP-error, response-parsing and model-validation helpers.
const PROVIDER_NAME: &str = "DeepSeek";

/// Lowercase provider key: returned by `name()` and passed to the shared
/// message serializer, tool-choice formatter and model validator.
const PROVIDER_KEY: &str = "deepseek";
29
30pub struct DeepSeekProvider {
31    api_key: String,
32    http_client: HttpClient,
33    base_url: String,
34    model: String,
35    prompt_cache_enabled: bool,
36    prompt_cache_settings: DeepSeekPromptCacheSettings,
37    model_behavior: Option<ModelConfig>,
38}
39
40impl DeepSeekProvider {
41    pub fn new(api_key: String) -> Self {
42        Self::with_model_internal(
43            api_key,
44            models::deepseek::DEFAULT_MODEL.to_string(),
45            None,
46            None,
47            TimeoutsConfig::default(),
48            None,
49        )
50    }
51
52    pub fn with_model(api_key: String, model: String) -> Self {
53        Self::with_model_internal(api_key, model, None, None, TimeoutsConfig::default(), None)
54    }
55
56    pub fn new_with_client(
57        api_key: String,
58        model: String,
59        http_client: reqwest::Client,
60        base_url: String,
61        _timeouts: TimeoutsConfig,
62    ) -> Self {
63        Self {
64            api_key,
65            http_client,
66            base_url,
67            model,
68            prompt_cache_enabled: false,
69            prompt_cache_settings: DeepSeekPromptCacheSettings::default(),
70            model_behavior: None,
71        }
72    }
73
74    pub fn from_config(
75        api_key: Option<String>,
76        model: Option<String>,
77        base_url: Option<String>,
78        prompt_cache: Option<PromptCachingConfig>,
79        timeouts: Option<TimeoutsConfig>,
80        _anthropic: Option<AnthropicConfig>,
81        model_behavior: Option<ModelConfig>,
82    ) -> Self {
83        let api_key_value = api_key.unwrap_or_default();
84        let model_value = resolve_model(model, models::deepseek::DEFAULT_MODEL);
85
86        Self::with_model_internal(
87            api_key_value,
88            model_value,
89            prompt_cache,
90            base_url,
91            timeouts.unwrap_or_default(),
92            model_behavior,
93        )
94    }
95
96    fn with_model_internal(
97        api_key: String,
98        model: String,
99        prompt_cache: Option<PromptCachingConfig>,
100        base_url: Option<String>,
101        timeouts: TimeoutsConfig,
102        model_behavior: Option<ModelConfig>,
103    ) -> Self {
104        use crate::llm::http_client::HttpClientFactory;
105
106        let (prompt_cache_enabled, prompt_cache_settings) = extract_prompt_cache_settings(
107            prompt_cache,
108            |providers| &providers.deepseek,
109            |cfg, provider_settings| cfg.enabled && provider_settings.enabled,
110        );
111
112        Self {
113            api_key,
114            http_client: HttpClientFactory::for_llm(&timeouts),
115            base_url: override_base_url(
116                urls::DEEPSEEK_API_BASE,
117                base_url,
118                Some(env_vars::DEEPSEEK_BASE_URL),
119            ),
120            model,
121            prompt_cache_enabled,
122            prompt_cache_settings,
123            model_behavior,
124        }
125    }
126
127    #[must_use]
128    #[inline]
129    fn is_thinking_enabled(request: &LLMRequest) -> bool {
130        request
131            .reasoning_effort
132            .is_some_and(|e| e != crate::config::types::ReasoningEffortLevel::None)
133    }
134
135    fn float_to_json_number(value: f32) -> Result<serde_json::Number, LLMError> {
136        serde_json::Number::from_f64(value as f64).ok_or_else(|| LLMError::InvalidRequest {
137            message: "invalid numeric parameter value (NaN or infinity)".to_string(),
138            metadata: None,
139        })
140    }
141
142    fn convert_to_deepseek_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
143        // est. 8–12 keys: model, messages, system (cond), max_tokens (cond),
144        // temperature/top_p (cond), stream (cond), tools (cond), tool_choice (cond),
145        // thinking (cond), user_id (cond)
146        let mut payload = Map::with_capacity(12);
147
148        payload.insert("model".to_owned(), Value::String(request.model.clone()));
149        payload.insert(
150            "messages".to_owned(),
151            Value::Array(self.serialize_messages(request)?),
152        );
153
154        if let Some(system_prompt) = &request.system_prompt {
155            payload.insert(
156                "system".to_owned(),
157                Value::String(system_prompt.trim().to_owned()),
158            );
159        }
160
161        if let Some(max_tokens) = request.max_tokens {
162            payload.insert(
163                "max_tokens".to_owned(),
164                Value::Number(serde_json::Number::from(max_tokens as u64)),
165            );
166        }
167
168        let thinking_enabled = Self::is_thinking_enabled(request);
169
170        // Thinking mode does not support temperature, top_p, presence_penalty,
171        // or frequency_penalty. Suppress them to avoid wasted payload bytes.
172        if !thinking_enabled {
173            if let Some(temperature) = request.temperature {
174                payload.insert(
175                    "temperature".to_owned(),
176                    Value::Number(Self::float_to_json_number(temperature)?),
177                );
178            }
179
180            if let Some(top_p) = request.top_p {
181                payload.insert(
182                    "top_p".to_owned(),
183                    Value::Number(Self::float_to_json_number(top_p)?),
184                );
185            }
186        }
187
188        if request.stream {
189            payload.insert("stream".to_string(), Value::Bool(true));
190            // Request usage info in the final streaming chunk.
191            payload.insert(
192                "stream_options".to_string(),
193                serde_json::json!({"include_usage": true}),
194            );
195        }
196
197        if let Some(tools) = &request.tools
198            && let Some(serialized_tools) = serialize_tools_openai_format(tools)
199        {
200            payload.insert("tools".to_string(), Value::Array(serialized_tools));
201        }
202
203        if let Some(choice) = &request.tool_choice {
204            payload.insert(
205                "tool_choice".to_string(),
206                choice.to_provider_format(PROVIDER_KEY),
207            );
208        }
209
210        if let Some(effort) = request.reasoning_effort {
211            use crate::config::models::Provider;
212            use crate::llm::rig_adapter::RigProviderCapabilities;
213            if effort == crate::config::types::ReasoningEffortLevel::None {
214                payload.insert(
215                    "thinking".to_owned(),
216                    serde_json::json!({"type": "disabled"}),
217                );
218            } else if let Some(reasoning_params) =
219                RigProviderCapabilities::new(Provider::DeepSeek, &request.model)
220                    .reasoning_parameters(effort)
221            {
222                if let Some(params_obj) = reasoning_params.as_object() {
223                    for (k, v) in params_obj {
224                        payload[k] = v.clone();
225                    }
226                }
227            }
228        }
229
230        // Pass through user_id for KV cache isolation and traffic management.
231        if let Some(meta) = &request.metadata {
232            if let Some(user_id) = meta.get("user_id").and_then(|v| v.as_str()) {
233                payload.insert("user_id".to_owned(), Value::String(user_id.to_owned()));
234            }
235        }
236
237        Ok(Value::Object(payload))
238    }
239
240    async fn send_request(&self, payload: &Value) -> Result<reqwest::Response, LLMError> {
241        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
242
243        self.http_client
244            .post(&url)
245            .bearer_auth(&self.api_key)
246            .json(payload)
247            .send()
248            .await
249            .map_err(|e| LLMError::Network {
250                message: error_display::format_llm_error(
251                    PROVIDER_NAME,
252                    &format!("network error: {}", e),
253                ),
254                metadata: None,
255            })
256    }
257
258    fn serialize_messages(&self, request: &LLMRequest) -> Result<Vec<Value>, LLMError> {
259        serialize_messages_openai_format(request, PROVIDER_KEY)
260    }
261
262    fn parse_response(&self, response_json: Value, model: String) -> Result<LLMResponse, LLMError> {
263        let include_cache = self.prompt_cache_enabled && self.prompt_cache_settings.surface_metrics;
264
265        // Custom reasoning extractor for DeepSeek
266        let reasoning_extractor = |message: &Value, choice: &Value| {
267            message
268                .get("reasoning_content")
269                .and_then(extract_reasoning_trace)
270                .or_else(|| message.get("reasoning").and_then(extract_reasoning_trace))
271                .or_else(|| {
272                    choice
273                        .get("reasoning_content")
274                        .and_then(extract_reasoning_trace)
275                })
276        };
277
278        parse_response_openai_format(
279            response_json,
280            PROVIDER_NAME,
281            model,
282            include_cache,
283            Some(reasoning_extractor),
284        )
285    }
286}
287
288#[async_trait]
289impl LLMProvider for DeepSeekProvider {
290    fn name(&self) -> &str {
291        PROVIDER_KEY
292    }
293
294    fn supports_reasoning(&self, model: &str) -> bool {
295        let requested = if model.trim().is_empty() {
296            &self.model
297        } else {
298            model
299        };
300
301        // Codex-inspired robustness: Setting model_supports_reasoning to false
302        // does NOT disable it for known reasoning models.
303        requested == models::deepseek::DEEPSEEK_V4_PRO
304            || self
305                .model_behavior
306                .as_ref()
307                .and_then(|b| b.model_supports_reasoning)
308                .unwrap_or(false)
309    }
310
311    fn supports_reasoning_effort(&self, _model: &str) -> bool {
312        // Same robustness logic for reasoning effort
313        self.model_behavior
314            .as_ref()
315            .and_then(|b| b.model_supports_reasoning_effort)
316            .unwrap_or(false)
317    }
318
319    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
320        let model = ensure_model(&mut request, &self.model);
321
322        let payload = self.convert_to_deepseek_format(&request)?;
323        let response = self.send_request(&payload).await?;
324        let response =
325            handle_openai_http_error(response, PROVIDER_NAME, "DEEPSEEK_API_KEY").await?;
326
327        let response_json = parse_json_response(response, PROVIDER_NAME).await?;
328        self.parse_response(response_json, model)
329    }
330
331    async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream, LLMError> {
332        ensure_model(&mut request, &self.model);
333        self.validate_request(&request)?;
334        request.stream = true;
335        let model = request.model.clone();
336
337        let payload = self.convert_to_deepseek_format(&request)?;
338        let response = self.send_request(&payload).await?;
339        let response =
340            handle_openai_http_error(response, PROVIDER_NAME, "DEEPSEEK_API_KEY").await?;
341
342        let bytes_stream = response.bytes_stream();
343        let (event_tx, event_rx) =
344            tokio::sync::mpsc::unbounded_channel::<Result<LLMStreamEvent, LLMError>>();
345        let tx = event_tx.clone();
346
347        let model_clone = model.clone();
348        tokio::spawn(async move {
349            let mut aggregator =
350                crate::llm::providers::shared::StreamAggregator::new(model_clone.clone());
351
352            let result = crate::llm::providers::shared::process_openai_stream(
353                bytes_stream,
354                PROVIDER_NAME,
355                model_clone,
356                |value| {
357                    if let Some(choices) = value.get("choices").and_then(|c| c.as_array())
358                        && let Some(choice) = choices.first()
359                    {
360                        if let Some(delta) = choice.get("delta") {
361                            if let Some(reasoning) =
362                                delta.get("reasoning_content").and_then(|r| r.as_str())
363                            {
364                                if let Some(delta) = aggregator.handle_reasoning(reasoning) {
365                                    let _ = tx.send(Ok(LLMStreamEvent::Reasoning { delta }));
366                                }
367                            }
368
369                            if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
370                                for event in aggregator.handle_content(content) {
371                                    let _ = tx.send(Ok(event));
372                                }
373                            }
374
375                            if let Some(tool_calls) =
376                                delta.get("tool_calls").and_then(|tc| tc.as_array())
377                            {
378                                aggregator.handle_tool_calls(tool_calls);
379                            }
380                        }
381
382                        if let Some(reason) = choice.get("finish_reason").and_then(|r| r.as_str()) {
383                            aggregator.set_finish_reason(map_finish_reason_common(reason));
384                        }
385                    }
386
387                    if let Some(_usage_value) = value.get("usage") {
388                        if let Some(usage) =
389                            crate::llm::providers::common::parse_usage_openai_format(&value, true)
390                        {
391                            aggregator.set_usage(usage);
392                        }
393                    }
394                    Ok(())
395                },
396            )
397            .await;
398
399            match result {
400                Ok(_) => {
401                    let response = aggregator.finalize();
402                    let _ = tx.send(Ok(LLMStreamEvent::Completed {
403                        response: Box::new(response),
404                    }));
405                }
406                Err(err) => {
407                    let _ = tx.send(Err(err));
408                }
409            }
410        });
411
412        let stream = try_stream! {
413            let mut receiver = event_rx;
414            while let Some(event) = receiver.recv().await {
415                yield event?;
416            }
417        };
418
419        Ok(Box::pin(stream))
420    }
421
422    fn supported_models(&self) -> Vec<String> {
423        models::deepseek::SUPPORTED_MODELS
424            .iter()
425            .map(|model| model.to_string())
426            .collect()
427    }
428
429    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
430        validate_supported_models(
431            request,
432            PROVIDER_NAME,
433            PROVIDER_KEY,
434            models::deepseek::SUPPORTED_MODELS,
435        )
436    }
437
438    async fn get_balance(&self) -> Result<Option<vtcode_commons::llm::BalanceInfo>, LLMError> {
439        // Strip /v1 suffix to get the root API URL for the balance endpoint.
440        let base = self.base_url.trim_end_matches('/');
441        let root = base.strip_suffix("/v1").unwrap_or(base);
442        let url = format!("{}/user/balance", root);
443
444        let response = self
445            .http_client
446            .get(&url)
447            .bearer_auth(&self.api_key)
448            .send()
449            .await
450            .map_err(|e| LLMError::Network {
451                message: error_display::format_llm_error(
452                    PROVIDER_NAME,
453                    &format!("balance request failed: {}", e),
454                ),
455                metadata: None,
456            })?;
457
458        if !response.status().is_success() {
459            let status = response.status();
460            let body = response.text().await.unwrap_or_default();
461            return Err(LLMError::Provider {
462                message: error_display::format_llm_error(
463                    PROVIDER_NAME,
464                    &format!("balance API returned {}: {}", status, body),
465                ),
466                metadata: None,
467            });
468        }
469
470        let balance_resp: vtcode_commons::llm::DeepSeekBalanceResponse =
471            response.json().await.map_err(|e| LLMError::Provider {
472                message: error_display::format_llm_error(
473                    PROVIDER_NAME,
474                    &format!("failed to parse balance response: {}", e),
475                ),
476                metadata: None,
477            })?;
478
479        Ok(Some(balance_resp.into()))
480    }
481}
482
483impl_llm_client!(DeepSeekProvider);