vtcode_core/llm/providers/zai.rs

use crate::config::constants::{models, urls};
use crate::config::core::{PromptCachingConfig, ZaiPromptCacheSettings};
use crate::llm::client::LLMClient;
use crate::llm::error_display;
use crate::llm::provider::{
    FinishReason, LLMError, LLMProvider, LLMRequest, LLMResponse, Message, MessageRole, ToolCall,
    ToolChoice, ToolDefinition, Usage,
};
use crate::llm::types as llm_types;
use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde_json::{Value, json};
use std::collections::HashSet;

const PROVIDER_NAME: &str = "Z.AI";
const PROVIDER_KEY: &str = "zai";
const CHAT_COMPLETIONS_PATH: &str = "/paas/v4/chat/completions";

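/// LLM provider backed by the Z.AI chat completions API, holding the API key,
/// HTTP client, base URL, target model, and prompt-cache settings.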
pub struct ZAIProvider {
    api_key: String,
    http_client: HttpClient,
    base_url: String,
    model: String,
    prompt_cache_settings: ZaiPromptCacheSettings,
}

impl ZAIProvider {
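    /// Serialize tool definitions into the OpenAI-style `tools` array expected by Z.AI.
    /// Returns `None` when no tools are supplied.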
    fn serialize_tools(tools: &[ToolDefinition]) -> Option<Value> {
        if tools.is_empty() {
            return None;
        }

        let serialized = tools
            .iter()
            .map(|tool| {
                json!({
                    "type": tool.tool_type,
                    "function": {
                        "name": tool.function.name,
                        "description": tool.function.description,
                        "parameters": tool.function.parameters,
                    }
                })
            })
            .collect::<Vec<Value>>();

        Some(Value::Array(serialized))
    }

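    /// Pull the Z.AI-specific prompt cache settings out of the shared caching config,
    /// falling back to defaults when no configuration is provided.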
    fn extract_prompt_cache_settings(
        prompt_cache: Option<PromptCachingConfig>,
    ) -> ZaiPromptCacheSettings {
        if let Some(cfg) = prompt_cache {
            cfg.providers.zai
        } else {
            ZaiPromptCacheSettings::default()
        }
    }

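    /// Shared constructor used by the public builders; applies the default base URL
    /// when none is supplied.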
    fn with_model_internal(
        api_key: String,
        model: String,
        base_url: Option<String>,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        Self {
            api_key,
            http_client: HttpClient::new(),
            base_url: base_url.unwrap_or_else(|| urls::Z_AI_API_BASE.to_string()),
            model,
            prompt_cache_settings: Self::extract_prompt_cache_settings(prompt_cache),
        }
    }

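    /// Create a provider that targets the default Z.AI model.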
    pub fn new(api_key: String) -> Self {
        Self::with_model_internal(api_key, models::zai::DEFAULT_MODEL.to_string(), None, None)
    }

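    /// Create a provider that targets a specific model.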
    pub fn with_model(api_key: String, model: String) -> Self {
        Self::with_model_internal(api_key, model, None, None)
    }

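    /// Build a provider from optional configuration values, substituting defaults
    /// for any missing fields.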
    pub fn from_config(
        api_key: Option<String>,
        model: Option<String>,
        base_url: Option<String>,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let api_key_value = api_key.unwrap_or_default();
        let model_value = model.unwrap_or_else(|| models::zai::DEFAULT_MODEL.to_string());
        Self::with_model_internal(api_key_value, model_value, base_url, prompt_cache)
    }

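    /// Wrap a plain prompt string in a minimal, non-streaming request for the configured model.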
    fn default_request(&self, prompt: &str) -> LLMRequest {
        LLMRequest {
            messages: vec![Message::user(prompt.to_string())],
            system_prompt: None,
            tools: None,
            model: self.model.clone(),
            max_tokens: None,
            temperature: None,
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        }
    }

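    /// Interpret a client prompt: if it is a JSON chat-completion payload, parse it into an
    /// [`LLMRequest`]; otherwise treat it as a plain user message.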
    fn parse_client_prompt(&self, prompt: &str) -> LLMRequest {
        let trimmed = prompt.trim_start();
        if trimmed.starts_with('{') {
            if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
                if let Some(request) = self.parse_chat_request(&value) {
                    return request;
                }
            }
        }

        self.default_request(prompt)
    }

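    /// Convert an OpenAI-style JSON chat payload into an [`LLMRequest`], mapping roles,
    /// tool calls, and common sampling parameters. Returns `None` if `messages` is absent.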
    fn parse_chat_request(&self, value: &Value) -> Option<LLMRequest> {
        let messages_value = value.get("messages")?.as_array()?;
        let mut system_prompt = value
            .get("system")
            .and_then(|entry| entry.as_str())
            .map(|text| text.to_string());
        let mut messages = Vec::new();

        for entry in messages_value {
            let role = entry
                .get("role")
                .and_then(|r| r.as_str())
                .unwrap_or(crate::config::constants::message_roles::USER);
            let content = entry
                .get("content")
                .map(|c| match c {
                    Value::String(text) => text.to_string(),
                    other => other.to_string(),
                })
                .unwrap_or_default();

            match role {
                "system" => {
                    if system_prompt.is_none() && !content.is_empty() {
                        system_prompt = Some(content);
                    }
                }
                "assistant" => {
                    let tool_calls = entry
                        .get("tool_calls")
                        .and_then(|tc| tc.as_array())
                        .map(|calls| {
                            calls
                                .iter()
                                .filter_map(|call| Self::parse_tool_call(call))
                                .collect::<Vec<_>>()
                        })
                        .filter(|calls| !calls.is_empty());

                    messages.push(Message {
                        role: MessageRole::Assistant,
                        content,
                        tool_calls,
                        tool_call_id: None,
                    });
                }
                "tool" => {
                    if let Some(tool_call_id) = entry.get("tool_call_id").and_then(|v| v.as_str()) {
                        messages.push(Message::tool_response(tool_call_id.to_string(), content));
                    }
                }
                _ => {
                    messages.push(Message::user(content));
                }
            }
        }

        Some(LLMRequest {
            messages,
            system_prompt,
            model: value
                .get("model")
                .and_then(|m| m.as_str())
                .unwrap_or(&self.model)
                .to_string(),
            max_tokens: value
                .get("max_tokens")
                .and_then(|m| m.as_u64())
                .map(|m| m as u32),
            temperature: value
                .get("temperature")
                .and_then(|t| t.as_f64())
                .map(|t| t as f32),
            stream: value
                .get("stream")
                .and_then(|s| s.as_bool())
                .unwrap_or(false),
            tools: None,
            tool_choice: value.get("tool_choice").and_then(|choice| match choice {
                Value::String(s) => match s.as_str() {
                    "auto" => Some(ToolChoice::auto()),
                    "none" => Some(ToolChoice::none()),
                    "any" | "required" => Some(ToolChoice::any()),
                    _ => None,
                },
                _ => None,
            }),
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        })
    }

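    /// Parse a single OpenAI-style tool call object, serializing its arguments to a JSON
    /// string (defaulting to `"{}"` when absent).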
    fn parse_tool_call(value: &Value) -> Option<ToolCall> {
        let id = value.get("id").and_then(|v| v.as_str())?;
        let function = value.get("function")?;
        let name = function.get("name").and_then(|v| v.as_str())?;
        let arguments = function.get("arguments");
        let serialized = arguments.map_or("{}".to_string(), |value| {
            if value.is_string() {
                value.as_str().unwrap_or("").to_string()
            } else {
                value.to_string()
            }
        });

        Some(ToolCall::function(
            id.to_string(),
            name.to_string(),
            serialized,
        ))
    }

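    /// Build the Z.AI chat-completions request body. Tool responses that do not match a
    /// pending assistant tool call are dropped, and thinking mode is enabled for models
    /// that support reasoning.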
    fn convert_to_zai_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
        let mut messages = Vec::new();
        let mut active_tool_call_ids: HashSet<String> = HashSet::new();

        if let Some(system_prompt) = &request.system_prompt {
            messages.push(json!({
                "role": crate::config::constants::message_roles::SYSTEM,
                "content": system_prompt
            }));
        }

        for msg in &request.messages {
            let role = msg.role.as_generic_str();
            let mut message = json!({
                "role": role,
                "content": msg.content
            });
            let mut skip_message = false;

            if msg.role == MessageRole::Assistant {
                if let Some(tool_calls) = &msg.tool_calls {
                    if !tool_calls.is_empty() {
                        let tool_calls_json: Vec<Value> = tool_calls
                            .iter()
                            .map(|tc| {
                                active_tool_call_ids.insert(tc.id.clone());
                                json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.function.name,
                                        "arguments": tc.function.arguments,
                                    }
                                })
                            })
                            .collect();
                        message["tool_calls"] = Value::Array(tool_calls_json);
                    }
                }
            }

            if msg.role == MessageRole::Tool {
                match &msg.tool_call_id {
                    Some(tool_call_id) if active_tool_call_ids.contains(tool_call_id) => {
                        message["tool_call_id"] = Value::String(tool_call_id.clone());
                        active_tool_call_ids.remove(tool_call_id);
                    }
                    Some(_) | None => {
                        skip_message = true;
                    }
                }
            }

            if !skip_message {
                messages.push(message);
            }
        }

        if messages.is_empty() {
            let formatted = error_display::format_llm_error(PROVIDER_NAME, "No messages provided");
            return Err(LLMError::InvalidRequest(formatted));
        }

        let mut payload = json!({
            "model": request.model,
            "messages": messages,
            "stream": request.stream,
        });

        if let Some(max_tokens) = request.max_tokens {
            payload["max_tokens"] = json!(max_tokens);
        }

        if let Some(temperature) = request.temperature {
            payload["temperature"] = json!(temperature);
        }

        if let Some(tools) = &request.tools {
            if let Some(serialized) = Self::serialize_tools(tools) {
                payload["tools"] = serialized;
            }
        }

        if let Some(choice) = &request.tool_choice {
            payload["tool_choice"] = choice.to_provider_format("openai");
        }

        if self.supports_reasoning(&request.model) || request.reasoning_effort.is_some() {
            payload["thinking"] = json!({ "type": "enabled" });
        }

        Ok(payload)
    }

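    /// Parse a Z.AI chat-completions response into an [`LLMResponse`], extracting content,
    /// reasoning, tool calls, the finish reason, and token usage (including cached prompt
    /// tokens when reported).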
    fn parse_zai_response(&self, response_json: Value) -> Result<LLMResponse, LLMError> {
        let choices = response_json
            .get("choices")
            .and_then(|c| c.as_array())
            .ok_or_else(|| {
                let formatted = error_display::format_llm_error(
                    PROVIDER_NAME,
                    "Invalid response format: missing choices",
                );
                LLMError::Provider(formatted)
            })?;

        if choices.is_empty() {
            let formatted =
                error_display::format_llm_error(PROVIDER_NAME, "No choices in response");
            return Err(LLMError::Provider(formatted));
        }

        let choice = &choices[0];
        let message = choice.get("message").ok_or_else(|| {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                "Invalid response format: missing message",
            );
            LLMError::Provider(formatted)
        })?;

        let content = message
            .get("content")
            .and_then(|c| c.as_str())
            .map(|s| s.to_string());

        let reasoning = message
            .get("reasoning_content")
            .map(|value| match value {
                Value::String(text) => Some(text.to_string()),
                Value::Array(parts) => {
                    let combined = parts
                        .iter()
                        .filter_map(|part| part.as_str())
                        .collect::<Vec<_>>()
                        .join("");
                    if combined.is_empty() {
                        None
                    } else {
                        Some(combined)
                    }
                }
                _ => None,
            })
            .flatten();

        let tool_calls = message
            .get("tool_calls")
            .and_then(|tc| tc.as_array())
            .map(|calls| {
                calls
                    .iter()
                    .filter_map(|call| Self::parse_tool_call(call))
                    .collect::<Vec<_>>()
            })
            .filter(|calls| !calls.is_empty());

        let finish_reason = choice
            .get("finish_reason")
            .and_then(|fr| fr.as_str())
            .map(Self::map_finish_reason)
            .unwrap_or(FinishReason::Stop);

        let usage = response_json.get("usage").map(|usage_value| Usage {
            prompt_tokens: usage_value
                .get("prompt_tokens")
                .and_then(|pt| pt.as_u64())
                .unwrap_or(0) as u32,
            completion_tokens: usage_value
                .get("completion_tokens")
                .and_then(|ct| ct.as_u64())
                .unwrap_or(0) as u32,
            total_tokens: usage_value
                .get("total_tokens")
                .and_then(|tt| tt.as_u64())
                .unwrap_or(0) as u32,
            cached_prompt_tokens: usage_value
                .get("prompt_tokens_details")
                .and_then(|details| details.get("cached_tokens"))
                .and_then(|value| value.as_u64())
                .map(|value| value as u32),
            cache_creation_tokens: None,
            cache_read_tokens: None,
        });

        Ok(LLMResponse {
            content,
            tool_calls,
            usage,
            finish_reason,
            reasoning,
        })
    }

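    /// Map Z.AI finish reasons onto the shared [`FinishReason`] enum; unrecognized values
    /// are preserved as [`FinishReason::Error`].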
    fn map_finish_reason(reason: &str) -> FinishReason {
        match reason {
            "stop" => FinishReason::Stop,
            "length" => FinishReason::Length,
            "tool_calls" => FinishReason::ToolCalls,
            "sensitive" => FinishReason::ContentFilter,
            other => FinishReason::Error(other.to_string()),
        }
    }

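    /// List the Z.AI model identifiers supported by this provider.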
    fn available_models() -> Vec<String> {
        models::zai::SUPPORTED_MODELS
            .iter()
            .map(|s| s.to_string())
            .collect()
    }
}

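// Provider-facing implementation: request validation and non-streaming chat completion calls.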
#[async_trait]
impl LLMProvider for ZAIProvider {
    fn name(&self) -> &str {
        PROVIDER_KEY
    }

    fn supports_streaming(&self) -> bool {
        false
    }

    fn supports_reasoning(&self, model: &str) -> bool {
        matches!(
            model,
            models::zai::GLM_4_6
                | models::zai::GLM_4_5
                | models::zai::GLM_4_5_AIR
                | models::zai::GLM_4_5_X
                | models::zai::GLM_4_5_AIRX
        )
    }

    fn supports_reasoning_effort(&self, _model: &str) -> bool {
        false
    }

    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
        if request.model.trim().is_empty() {
            request.model = self.model.clone();
        }

        if !Self::available_models().contains(&request.model) {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted));
        }

        for message in &request.messages {
            if let Err(err) = message.validate_for_provider(PROVIDER_KEY) {
                let formatted = error_display::format_llm_error(PROVIDER_NAME, &err);
                return Err(LLMError::InvalidRequest(formatted));
            }
        }

        let payload = self.convert_to_zai_format(&request)?;
        let url = format!("{}{}", self.base_url, CHAT_COMPLETIONS_PATH);

        let response = self
            .http_client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&payload)
            .send()
            .await
            .map_err(|err| {
                let formatted = error_display::format_llm_error(
                    PROVIDER_NAME,
                    &format!("Network error: {}", err),
                );
                LLMError::Network(formatted)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();

            if status.as_u16() == 429 || text.to_lowercase().contains("rate") {
                return Err(LLMError::RateLimit);
            }

            let message = serde_json::from_str::<Value>(&text)
                .ok()
                .and_then(|value| {
                    value
                        .get("message")
                        .and_then(|m| m.as_str())
                        .map(|s| s.to_string())
                })
                .unwrap_or(text);

            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("HTTP {}: {}", status, message),
            );
            return Err(LLMError::Provider(formatted));
        }

        let json: Value = response.json().await.map_err(|err| {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Failed to parse response: {}", err),
            );
            LLMError::Provider(formatted)
        })?;

        self.parse_zai_response(json)
    }

    fn supported_models(&self) -> Vec<String> {
        Self::available_models()
    }

    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
        if request.messages.is_empty() {
            let formatted =
                error_display::format_llm_error(PROVIDER_NAME, "Messages cannot be empty");
            return Err(LLMError::InvalidRequest(formatted));
        }

        if !request.model.is_empty() && !Self::available_models().contains(&request.model) {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted));
        }

        for message in &request.messages {
            if let Err(err) = message.validate_for_provider(PROVIDER_KEY) {
                let formatted = error_display::format_llm_error(PROVIDER_NAME, &err);
                return Err(LLMError::InvalidRequest(formatted));
            }
        }

        Ok(())
    }
}

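// Client-facing implementation: accepts plain or JSON prompts and adapts responses to llm_types.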
#[async_trait]
impl LLMClient for ZAIProvider {
    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
        let request = self.parse_client_prompt(prompt);
        let request_model = request.model.clone();
        let response = LLMProvider::generate(self, request).await?;

        Ok(llm_types::LLMResponse {
            content: response.content.unwrap_or_default(),
            model: request_model,
            usage: response.usage.map(|usage| llm_types::Usage {
                prompt_tokens: usage.prompt_tokens as usize,
                completion_tokens: usage.completion_tokens as usize,
                total_tokens: usage.total_tokens as usize,
                cached_prompt_tokens: usage.cached_prompt_tokens.map(|v| v as usize),
                cache_creation_tokens: usage.cache_creation_tokens.map(|v| v as usize),
                cache_read_tokens: usage.cache_read_tokens.map(|v| v as usize),
            }),
            reasoning: response.reasoning,
        })
    }

    fn backend_kind(&self) -> llm_types::BackendKind {
        llm_types::BackendKind::ZAI
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}