vtcode_core/llm/providers/zai.rs
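//! Z.AI provider: a non-streaming chat-completions backend for the GLM model
//! family, implementing both [`LLMProvider`] and [`LLMClient`].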

use crate::config::constants::{models, urls};
use crate::config::core::PromptCachingConfig;
use crate::llm::client::LLMClient;
use crate::llm::error_display;
use crate::llm::provider::{
    FinishReason, LLMError, LLMProvider, LLMRequest, LLMResponse, Message, MessageRole, ToolCall,
    ToolChoice, ToolDefinition, Usage,
};
use crate::llm::types as llm_types;
use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde_json::{Value, json};
use std::collections::HashSet;

const PROVIDER_NAME: &str = "Z.AI";
const PROVIDER_KEY: &str = "zai";
const CHAT_COMPLETIONS_PATH: &str = "/paas/v4/chat/completions";

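/// Z.AI chat-completions provider.
///
/// Requests are POSTed to `{base_url}/paas/v4/chat/completions` with bearer
/// authentication, using an OpenAI-compatible message and tool format.
///
/// Minimal usage sketch (hypothetical call site; the async context and the
/// `ZAI_API_KEY` variable name are assumptions, the APIs are the ones defined
/// in this file):
///
/// ```ignore
/// let mut client = ZAIProvider::new(std::env::var("ZAI_API_KEY")?);
/// let reply = LLMClient::generate(&mut client, "Hello from vtcode").await?;
/// println!("{}", reply.content);
/// ```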
pub struct ZAIProvider {
    api_key: String,
    http_client: HttpClient,
    base_url: String,
    model: String,
}

impl ZAIProvider {
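    /// Serializes tool definitions into the OpenAI-style `tools` array, or
    /// returns `None` when no tools are configured.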
    fn serialize_tools(tools: &[ToolDefinition]) -> Option<Value> {
        if tools.is_empty() {
            return None;
        }

        let serialized = tools
            .iter()
            .map(|tool| {
                json!({
                    "type": tool.tool_type,
                    "function": {
                        "name": tool.function.name,
                        "description": tool.function.description,
                        "parameters": tool.function.parameters,
                    }
                })
            })
            .collect::<Vec<Value>>();

        Some(Value::Array(serialized))
    }

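    /// Shared constructor: falls back to the default Z.AI base URL when none
    /// is supplied. Prompt-caching configuration is accepted but currently
    /// unused.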
    fn with_model_internal(
        api_key: String,
        model: String,
        base_url: Option<String>,
        _prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        Self {
            api_key,
            http_client: HttpClient::new(),
            base_url: base_url.unwrap_or_else(|| urls::Z_AI_API_BASE.to_string()),
            model,
        }
    }

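    /// Creates a provider targeting the default Z.AI model.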
    pub fn new(api_key: String) -> Self {
        Self::with_model_internal(api_key, models::zai::DEFAULT_MODEL.to_string(), None, None)
    }

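    /// Creates a provider targeting a specific model.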
    pub fn with_model(api_key: String, model: String) -> Self {
        Self::with_model_internal(api_key, model, None, None)
    }

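    /// Builds a provider from optional configuration values, substituting the
    /// default model and base URL where values are missing.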
    pub fn from_config(
        api_key: Option<String>,
        model: Option<String>,
        base_url: Option<String>,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let api_key_value = api_key.unwrap_or_default();
        let model_value = model.unwrap_or_else(|| models::zai::DEFAULT_MODEL.to_string());
        Self::with_model_internal(api_key_value, model_value, base_url, prompt_cache)
    }

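    /// Wraps a plain prompt in a single-turn, non-streaming request against
    /// the configured model.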
    fn default_request(&self, prompt: &str) -> LLMRequest {
        LLMRequest {
            messages: vec![Message::user(prompt.to_string())],
            system_prompt: None,
            tools: None,
            model: self.model.clone(),
            max_tokens: None,
            temperature: None,
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        }
    }

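    /// Interprets a client prompt either as a serialized chat request (a JSON
    /// object carrying a `messages` array) or, failing that, as a plain user
    /// prompt.
    ///
    /// Sketch of the accepted JSON shape (illustrative values only):
    ///
    /// ```json
    /// {
    ///   "model": "glm-4.6",
    ///   "system": "You are a helpful assistant.",
    ///   "messages": [{ "role": "user", "content": "Hello" }],
    ///   "temperature": 0.2,
    ///   "stream": false
    /// }
    /// ```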
    fn parse_client_prompt(&self, prompt: &str) -> LLMRequest {
        let trimmed = prompt.trim_start();
        if trimmed.starts_with('{') {
            if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
                if let Some(request) = self.parse_chat_request(&value) {
                    return request;
                }
            }
        }

        self.default_request(prompt)
    }

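    /// Converts a JSON chat request into an [`LLMRequest`], mapping system,
    /// assistant, tool, and user entries plus optional sampling settings.
    /// Returns `None` when the `messages` array is absent.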
    fn parse_chat_request(&self, value: &Value) -> Option<LLMRequest> {
        let messages_value = value.get("messages")?.as_array()?;
        let mut system_prompt = value
            .get("system")
            .and_then(|entry| entry.as_str())
            .map(|text| text.to_string());
        let mut messages = Vec::new();

        for entry in messages_value {
            let role = entry
                .get("role")
                .and_then(|r| r.as_str())
                .unwrap_or(crate::config::constants::message_roles::USER);
            let content = entry
                .get("content")
                .map(|c| match c {
                    Value::String(text) => text.to_string(),
                    other => other.to_string(),
                })
                .unwrap_or_default();

            match role {
                "system" => {
                    if system_prompt.is_none() && !content.is_empty() {
                        system_prompt = Some(content);
                    }
                }
                "assistant" => {
                    let tool_calls = entry
                        .get("tool_calls")
                        .and_then(|tc| tc.as_array())
                        .map(|calls| {
                            calls
                                .iter()
                                .filter_map(|call| Self::parse_tool_call(call))
                                .collect::<Vec<_>>()
                        })
                        .filter(|calls| !calls.is_empty());

                    messages.push(Message {
                        role: MessageRole::Assistant,
                        content,
                        tool_calls,
                        tool_call_id: None,
                    });
                }
                "tool" => {
                    if let Some(tool_call_id) = entry.get("tool_call_id").and_then(|v| v.as_str()) {
                        messages.push(Message::tool_response(tool_call_id.to_string(), content));
                    }
                }
                _ => {
                    messages.push(Message::user(content));
                }
            }
        }

        Some(LLMRequest {
            messages,
            system_prompt,
            model: value
                .get("model")
                .and_then(|m| m.as_str())
                .unwrap_or(&self.model)
                .to_string(),
            max_tokens: value
                .get("max_tokens")
                .and_then(|m| m.as_u64())
                .map(|m| m as u32),
            temperature: value
                .get("temperature")
                .and_then(|t| t.as_f64())
                .map(|t| t as f32),
            stream: value
                .get("stream")
                .and_then(|s| s.as_bool())
                .unwrap_or(false),
            tools: None,
            tool_choice: value.get("tool_choice").and_then(|choice| match choice {
                Value::String(s) => match s.as_str() {
                    "auto" => Some(ToolChoice::auto()),
                    "none" => Some(ToolChoice::none()),
                    "any" | "required" => Some(ToolChoice::any()),
                    _ => None,
                },
                _ => None,
            }),
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        })
    }

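    /// Extracts a function tool call from a message entry, keeping string
    /// arguments verbatim and serializing any other JSON value.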
    fn parse_tool_call(value: &Value) -> Option<ToolCall> {
        let id = value.get("id").and_then(|v| v.as_str())?;
        let function = value.get("function")?;
        let name = function.get("name").and_then(|v| v.as_str())?;
        let arguments = function.get("arguments");
        let serialized = arguments.map_or("{}".to_string(), |value| {
            if value.is_string() {
                value.as_str().unwrap_or("").to_string()
            } else {
                value.to_string()
            }
        });

        Some(ToolCall::function(
            id.to_string(),
            name.to_string(),
            serialized,
        ))
    }

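    /// Converts an [`LLMRequest`] into the Z.AI chat-completions payload.
    /// Tool messages that do not answer a previously emitted tool call are
    /// dropped, and extended thinking is enabled for reasoning-capable models.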
    fn convert_to_zai_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
        let mut messages = Vec::new();
        let mut active_tool_call_ids: HashSet<String> = HashSet::new();

        if let Some(system_prompt) = &request.system_prompt {
            messages.push(json!({
                "role": crate::config::constants::message_roles::SYSTEM,
                "content": system_prompt
            }));
        }

        for msg in &request.messages {
            let role = msg.role.as_generic_str();
            let mut message = json!({
                "role": role,
                "content": msg.content
            });
            let mut skip_message = false;

            if msg.role == MessageRole::Assistant {
                if let Some(tool_calls) = &msg.tool_calls {
                    if !tool_calls.is_empty() {
                        let tool_calls_json: Vec<Value> = tool_calls
                            .iter()
                            .map(|tc| {
                                active_tool_call_ids.insert(tc.id.clone());
                                json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.function.name,
                                        "arguments": tc.function.arguments,
                                    }
                                })
                            })
                            .collect();
                        message["tool_calls"] = Value::Array(tool_calls_json);
                    }
                }
            }

            if msg.role == MessageRole::Tool {
                match &msg.tool_call_id {
                    Some(tool_call_id) if active_tool_call_ids.contains(tool_call_id) => {
                        message["tool_call_id"] = Value::String(tool_call_id.clone());
                        active_tool_call_ids.remove(tool_call_id);
                    }
                    Some(_) | None => {
                        skip_message = true;
                    }
                }
            }

            if !skip_message {
                messages.push(message);
            }
        }

        if messages.is_empty() {
            let formatted = error_display::format_llm_error(PROVIDER_NAME, "No messages provided");
            return Err(LLMError::InvalidRequest(formatted));
        }

        let mut payload = json!({
            "model": request.model,
            "messages": messages,
            "stream": request.stream,
        });

        if let Some(max_tokens) = request.max_tokens {
            payload["max_tokens"] = json!(max_tokens);
        }

        if let Some(temperature) = request.temperature {
            payload["temperature"] = json!(temperature);
        }

        if let Some(tools) = &request.tools {
            if let Some(serialized) = Self::serialize_tools(tools) {
                payload["tools"] = serialized;
            }
        }

        if let Some(choice) = &request.tool_choice {
            payload["tool_choice"] = choice.to_provider_format("openai");
        }

        if self.supports_reasoning(&request.model) || request.reasoning_effort.is_some() {
            payload["thinking"] = json!({ "type": "enabled" });
        }

        Ok(payload)
    }

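    /// Parses a chat-completions response body into an [`LLMResponse`],
    /// including optional reasoning content, tool calls, and token usage.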
    fn parse_zai_response(&self, response_json: Value) -> Result<LLMResponse, LLMError> {
        let choices = response_json
            .get("choices")
            .and_then(|c| c.as_array())
            .ok_or_else(|| {
                let formatted = error_display::format_llm_error(
                    PROVIDER_NAME,
                    "Invalid response format: missing choices",
                );
                LLMError::Provider(formatted)
            })?;

        if choices.is_empty() {
            let formatted =
                error_display::format_llm_error(PROVIDER_NAME, "No choices in response");
            return Err(LLMError::Provider(formatted));
        }

        let choice = &choices[0];
        let message = choice.get("message").ok_or_else(|| {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                "Invalid response format: missing message",
            );
            LLMError::Provider(formatted)
        })?;

        let content = message
            .get("content")
            .and_then(|c| c.as_str())
            .map(|s| s.to_string());

        let reasoning = message
            .get("reasoning_content")
            .map(|value| match value {
                Value::String(text) => Some(text.to_string()),
                Value::Array(parts) => {
                    let combined = parts
                        .iter()
                        .filter_map(|part| part.as_str())
                        .collect::<Vec<_>>()
                        .join("");
                    if combined.is_empty() {
                        None
                    } else {
                        Some(combined)
                    }
                }
                _ => None,
            })
            .flatten();

        let tool_calls = message
            .get("tool_calls")
            .and_then(|tc| tc.as_array())
            .map(|calls| {
                calls
                    .iter()
                    .filter_map(|call| Self::parse_tool_call(call))
                    .collect::<Vec<_>>()
            })
            .filter(|calls| !calls.is_empty());

        let finish_reason = choice
            .get("finish_reason")
            .and_then(|fr| fr.as_str())
            .map(Self::map_finish_reason)
            .unwrap_or(FinishReason::Stop);

        let usage = response_json.get("usage").map(|usage_value| Usage {
            prompt_tokens: usage_value
                .get("prompt_tokens")
                .and_then(|pt| pt.as_u64())
                .unwrap_or(0) as u32,
            completion_tokens: usage_value
                .get("completion_tokens")
                .and_then(|ct| ct.as_u64())
                .unwrap_or(0) as u32,
            total_tokens: usage_value
                .get("total_tokens")
                .and_then(|tt| tt.as_u64())
                .unwrap_or(0) as u32,
            cached_prompt_tokens: usage_value
                .get("prompt_tokens_details")
                .and_then(|details| details.get("cached_tokens"))
                .and_then(|value| value.as_u64())
                .map(|value| value as u32),
            cache_creation_tokens: None,
            cache_read_tokens: None,
        });

        Ok(LLMResponse {
            content,
            tool_calls,
            usage,
            finish_reason,
            reasoning,
        })
    }

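    /// Maps Z.AI finish reasons onto provider-agnostic [`FinishReason`]
    /// variants; unrecognized values are surfaced as `FinishReason::Error`.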
    fn map_finish_reason(reason: &str) -> FinishReason {
        match reason {
            "stop" => FinishReason::Stop,
            "length" => FinishReason::Length,
            "tool_calls" => FinishReason::ToolCalls,
            "sensitive" => FinishReason::ContentFilter,
            other => FinishReason::Error(other.to_string()),
        }
    }

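    /// Lists the model identifiers accepted by this provider.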
    fn available_models() -> Vec<String> {
        models::zai::SUPPORTED_MODELS
            .iter()
            .map(|s| s.to_string())
            .collect()
    }
}

#[async_trait]
impl LLMProvider for ZAIProvider {
    fn name(&self) -> &str {
        PROVIDER_KEY
    }

    fn supports_streaming(&self) -> bool {
        false
    }

    fn supports_reasoning(&self, model: &str) -> bool {
        matches!(
            model,
            models::zai::GLM_4_6
                | models::zai::GLM_4_5
                | models::zai::GLM_4_5_AIR
                | models::zai::GLM_4_5_X
                | models::zai::GLM_4_5_AIRX
        )
    }

    fn supports_reasoning_effort(&self, _model: &str) -> bool {
        false
    }

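    /// Sends a non-streaming chat-completions request after validating the
    /// model and messages, mapping HTTP 429 and rate-related error bodies to
    /// [`LLMError::RateLimit`].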
    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
        if request.model.trim().is_empty() {
            request.model = self.model.clone();
        }

        if !Self::available_models().contains(&request.model) {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted));
        }

        for message in &request.messages {
            if let Err(err) = message.validate_for_provider(PROVIDER_KEY) {
                let formatted = error_display::format_llm_error(PROVIDER_NAME, &err);
                return Err(LLMError::InvalidRequest(formatted));
            }
        }

        let payload = self.convert_to_zai_format(&request)?;
        let url = format!("{}{}", self.base_url, CHAT_COMPLETIONS_PATH);

        let response = self
            .http_client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&payload)
            .send()
            .await
            .map_err(|err| {
                let formatted = error_display::format_llm_error(
                    PROVIDER_NAME,
                    &format!("Network error: {}", err),
                );
                LLMError::Network(formatted)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();

            if status.as_u16() == 429 || text.to_lowercase().contains("rate") {
                return Err(LLMError::RateLimit);
            }

            let message = serde_json::from_str::<Value>(&text)
                .ok()
                .and_then(|value| {
                    value
                        .get("message")
                        .and_then(|m| m.as_str())
                        .map(|s| s.to_string())
                })
                .unwrap_or(text);

            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("HTTP {}: {}", status, message),
            );
            return Err(LLMError::Provider(formatted));
        }

        let json: Value = response.json().await.map_err(|err| {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Failed to parse response: {}", err),
            );
            LLMError::Provider(formatted)
        })?;

        self.parse_zai_response(json)
    }

    fn supported_models(&self) -> Vec<String> {
        Self::available_models()
    }

    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
        if request.messages.is_empty() {
            let formatted =
                error_display::format_llm_error(PROVIDER_NAME, "Messages cannot be empty");
            return Err(LLMError::InvalidRequest(formatted));
        }

        if !request.model.is_empty() && !Self::available_models().contains(&request.model) {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted));
        }

        for message in &request.messages {
            if let Err(err) = message.validate_for_provider(PROVIDER_KEY) {
                let formatted = error_display::format_llm_error(PROVIDER_NAME, &err);
                return Err(LLMError::InvalidRequest(formatted));
            }
        }

        Ok(())
    }
}

#[async_trait]
impl LLMClient for ZAIProvider {
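    /// Adapts the provider to the plain-text [`LLMClient`] interface: the
    /// prompt is parsed via `parse_client_prompt` and the provider response is
    /// flattened into `llm_types::LLMResponse`.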
    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
        let request = self.parse_client_prompt(prompt);
        let request_model = request.model.clone();
        let response = LLMProvider::generate(self, request).await?;

        Ok(llm_types::LLMResponse {
            content: response.content.unwrap_or_default(),
            model: request_model,
            usage: response.usage.map(|usage| llm_types::Usage {
                prompt_tokens: usage.prompt_tokens as usize,
                completion_tokens: usage.completion_tokens as usize,
                total_tokens: usage.total_tokens as usize,
                cached_prompt_tokens: usage.cached_prompt_tokens.map(|v| v as usize),
                cache_creation_tokens: usage.cache_creation_tokens.map(|v| v as usize),
                cache_read_tokens: usage.cache_read_tokens.map(|v| v as usize),
            }),
            reasoning: response.reasoning,
        })
    }

    fn backend_kind(&self) -> llm_types::BackendKind {
        llm_types::BackendKind::ZAI
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}