// vtcode_core/llm/providers/moonshot.rs

1use crate::config::TimeoutsConfig;
2use crate::config::constants::{env_vars, models, urls};
3use crate::config::core::{AnthropicConfig, ModelConfig, PromptCachingConfig};
4use crate::llm::error_display;
5use crate::llm::provider::{
6    LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream, LLMStreamEvent,
7};
8use async_stream::try_stream;
9use async_trait::async_trait;
10use reqwest::Client as HttpClient;
11use serde_json::{Map, Value};
12
13use super::common::{
14    ensure_model, impl_llm_client, map_finish_reason_common, override_base_url,
15    parse_json_response, parse_response_openai_format, resolve_model,
16    serialize_messages_openai_format,
17};
18use super::error_handling::handle_openai_http_error;
19
/// Human-readable provider label passed to the shared error-formatting and
/// response-parsing helpers.
const PROVIDER_NAME: &str = "Moonshot";
/// Lowercase provider identifier passed to the shared OpenAI-format helpers
/// (message serialization, tool-choice conversion, request validation).
const PROVIDER_KEY: &str = "moonshot";
22
23pub struct MoonshotProvider {
24    api_key: String,
25    http_client: HttpClient,
26    base_url: String,
27    model: String,
28}
29
30impl MoonshotProvider {
31    pub fn new(api_key: String) -> Self {
32        Self::with_model_internal(
33            api_key,
34            models::moonshot::DEFAULT_MODEL.to_string(),
35            None,
36            None,
37            None,
38        )
39    }
40
41    pub fn with_model(api_key: String, model: String) -> Self {
42        Self::with_model_internal(api_key, model, None, None, None)
43    }
44
45    pub fn new_with_client(
46        api_key: String,
47        model: String,
48        http_client: reqwest::Client,
49        base_url: String,
50        _timeouts: TimeoutsConfig,
51    ) -> Self {
52        Self {
53            api_key,
54            http_client,
55            base_url,
56            model: model.trim().to_string(),
57        }
58    }
59
60    pub fn from_config(
61        api_key: Option<String>,
62        model: Option<String>,
63        base_url: Option<String>,
64        _prompt_cache: Option<PromptCachingConfig>,
65        timeouts: Option<TimeoutsConfig>,
66        _anthropic: Option<AnthropicConfig>,
67        _model_behavior: Option<ModelConfig>,
68    ) -> Self {
69        let api_key_value = api_key.unwrap_or_default();
70        let model_value = resolve_model(model, models::moonshot::DEFAULT_MODEL);
71
72        Self::with_model_internal(
73            api_key_value,
74            model_value,
75            base_url,
76            timeouts,
77            _model_behavior,
78        )
79    }
80
81    fn with_model_internal(
82        api_key: String,
83        model: String,
84        base_url: Option<String>,
85        timeouts: Option<TimeoutsConfig>,
86        _model_behavior: Option<ModelConfig>,
87    ) -> Self {
88        use crate::llm::http_client::HttpClientFactory;
89
90        let timeouts = timeouts.unwrap_or_default();
91
92        Self {
93            api_key,
94            http_client: HttpClientFactory::for_llm(&timeouts),
95            base_url: override_base_url(
96                urls::MOONSHOT_API_BASE,
97                base_url,
98                Some(env_vars::MOONSHOT_BASE_URL),
99            ),
100            model: model.trim().to_string(),
101        }
102    }
103
104    fn convert_to_moonshot_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
105        let mut payload = Map::new();
106
107        payload.insert("model".to_owned(), Value::String(request.model.clone()));
108        payload.insert(
109            "messages".to_owned(),
110            Value::Array(serialize_messages_openai_format(request, PROVIDER_KEY)?),
111        );
112
113        if let Some(max_tokens) = request.max_tokens {
114            payload.insert(
115                "max_tokens".to_owned(),
116                Value::Number(serde_json::Number::from(max_tokens as u64)),
117            );
118        }
119
120        if let Some(temperature) = request.temperature {
121            payload.insert(
122                "temperature".to_owned(),
123                Value::Number(serde_json::Number::from_f64(temperature as f64).ok_or_else(
124                    || LLMError::InvalidRequest {
125                        message: "Invalid temperature value".to_string(),
126                        metadata: None,
127                    },
128                )?),
129            );
130        }
131
132        // Add reasoning_effort for Kimi K2 Thinking model
133        if let Some(effort) = request.reasoning_effort
134            && self.supports_reasoning_effort(&request.model)
135        {
136            payload.insert(
137                "reasoning_effort".to_string(),
138                Value::String(effort.as_str().to_string()),
139            );
140        }
141
142        if request.stream {
143            payload.insert("stream".to_string(), Value::Bool(true));
144        }
145
146        // Add tools if present (Moonshot supports function calling)
147        if let Some(tools) = &request.tools
148            && let Some(serialized_tools) = super::common::serialize_tools_openai_format(tools)
149        {
150            payload.insert("tools".to_string(), Value::Array(serialized_tools));
151        }
152
153        if let Some(choice) = &request.tool_choice {
154            payload.insert(
155                "tool_choice".to_string(),
156                choice.to_provider_format(PROVIDER_KEY),
157            );
158        }
159
160        Ok(Value::Object(payload))
161    }
162}
163
164#[async_trait]
165impl LLMProvider for MoonshotProvider {
166    fn name(&self) -> &str {
167        "moonshot"
168    }
169
170    fn supports_reasoning(&self, model: &str) -> bool {
171        model.contains("k2-thinking") || model.contains("kimi-k2-thinking")
172    }
173
174    fn supports_reasoning_effort(&self, model: &str) -> bool {
175        model.contains("k2-thinking") || model.contains("kimi-k2-thinking")
176    }
177
178    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
179        ensure_model(&mut request, &self.model);
180        request.model = request.model.trim().to_string();
181        let model = request.model.clone();
182
183        let payload = self.convert_to_moonshot_format(&request)?;
184        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
185
186        let response = self
187            .http_client
188            .post(&url)
189            .bearer_auth(&self.api_key)
190            .json(&payload)
191            .send()
192            .await
193            .map_err(|e| {
194                let formatted_error = error_display::format_llm_error(
195                    PROVIDER_NAME,
196                    &format!("Network error: {}", e),
197                );
198                LLMError::Network {
199                    message: formatted_error,
200                    metadata: None,
201                }
202            })?;
203
204        let response =
205            handle_openai_http_error(response, PROVIDER_NAME, "MOONSHOT_API_KEY").await?;
206        let response_json = parse_json_response(response, PROVIDER_NAME).await?;
207
208        parse_response_openai_format::<fn(&Value, &Value) -> Option<String>>(
209            response_json,
210            PROVIDER_NAME,
211            model,
212            false,
213            None,
214        )
215    }
216
217    async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream, LLMError> {
218        ensure_model(&mut request, &self.model);
219        request.model = request.model.trim().to_string();
220        let model = request.model.clone();
221
222        self.validate_request(&request)?;
223        request.stream = true;
224
225        let payload = self.convert_to_moonshot_format(&request)?;
226        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
227
228        let response = self
229            .http_client
230            .post(&url)
231            .bearer_auth(&self.api_key)
232            .json(&payload)
233            .send()
234            .await
235            .map_err(|e| {
236                let formatted_error = error_display::format_llm_error(
237                    PROVIDER_NAME,
238                    &format!("Network error: {}", e),
239                );
240                LLMError::Network {
241                    message: formatted_error,
242                    metadata: None,
243                }
244            })?;
245
246        let response =
247            handle_openai_http_error(response, PROVIDER_NAME, "MOONSHOT_API_KEY").await?;
248
249        let bytes_stream = response.bytes_stream();
250        let (event_tx, event_rx) =
251            tokio::sync::mpsc::unbounded_channel::<Result<LLMStreamEvent, LLMError>>();
252        let tx = event_tx.clone();
253
254        let model_clone = model.clone();
255        tokio::spawn(async move {
256            let mut aggregator =
257                crate::llm::providers::shared::StreamAggregator::new(model_clone.clone());
258
259            let result = crate::llm::providers::shared::process_openai_stream(
260                bytes_stream,
261                PROVIDER_NAME,
262                model_clone,
263                |value| {
264                    if let Some(choices) = value.get("choices").and_then(|c| c.as_array())
265                        && let Some(choice) = choices.first()
266                    {
267                        if let Some(delta) = choice.get("delta") {
268                            // Handle reasoning_content field (Kimi K2 Thinking models)
269                            if let Some(reasoning) =
270                                delta.get("reasoning_content").and_then(|c| c.as_str())
271                                && let Some(d) = aggregator.handle_reasoning(reasoning)
272                            {
273                                let _ = tx.send(Ok(LLMStreamEvent::Reasoning { delta: d }));
274                            }
275
276                            // Handle regular content
277                            if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
278                                for event in aggregator.handle_content(content) {
279                                    let _ = tx.send(Ok(event));
280                                }
281                            }
282                        }
283
284                        if let Some(reason) = choice.get("finish_reason").and_then(|r| r.as_str()) {
285                            aggregator.set_finish_reason(map_finish_reason_common(reason));
286                        }
287                    }
288
289                    if let Some(_usage_value) = value.get("usage")
290                        && let Some(usage) =
291                            crate::llm::providers::common::parse_usage_openai_format(&value, false)
292                    {
293                        aggregator.set_usage(usage);
294                    }
295                    Ok(())
296                },
297            )
298            .await;
299
300            match result {
301                Ok(_) => {
302                    let response = aggregator.finalize();
303                    let _ = tx.send(Ok(LLMStreamEvent::Completed {
304                        response: Box::new(response),
305                    }));
306                }
307                Err(err) => {
308                    let _ = tx.send(Err(err));
309                }
310            }
311        });
312
313        let stream = try_stream! {
314            let mut receiver = event_rx;
315            while let Some(event) = receiver.recv().await {
316                yield event?;
317            }
318        };
319
320        Ok(Box::pin(stream))
321    }
322
323    fn supported_models(&self) -> Vec<String> {
324        models::moonshot::SUPPORTED_MODELS
325            .iter()
326            .map(|model| model.to_string())
327            .collect()
328    }
329
330    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
331        // Moonshot publishes new official aliases and preview slugs faster than VT Code's
332        // curated picker list is refreshed, so let the upstream API be the source of truth
333        // for model identifiers and keep local validation focused on request shape.
334        super::common::validate_request_common(request, PROVIDER_NAME, PROVIDER_KEY, None)
335    }
336}
337
338impl_llm_client!(MoonshotProvider);