vtcode_core/llm/providers/zai.rs

use crate::config::constants::{models, urls};
use crate::config::core::PromptCachingConfig;
use crate::llm::client::LLMClient;
use crate::llm::error_display;
use crate::llm::provider::{
    FinishReason, LLMError, LLMProvider, LLMRequest, LLMResponse, Message, MessageRole, ToolCall,
    ToolChoice, ToolDefinition, Usage,
};
use crate::llm::types as llm_types;
use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde_json::{Value, json};
use std::collections::HashSet;

const PROVIDER_NAME: &str = "Z.AI";
const PROVIDER_KEY: &str = "zai";
const CHAT_COMPLETIONS_PATH: &str = "/paas/v4/chat/completions";

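/// LLM provider for Z.AI's GLM chat-completions API (OpenAI-style payloads sent
/// to the `/paas/v4/chat/completions` endpoint).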
pub struct ZAIProvider {
    api_key: String,
    http_client: HttpClient,
    base_url: String,
    model: String,
}

impl ZAIProvider {
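    /// Serializes tool definitions into the OpenAI-style `tools` array, or
    /// returns `None` when no tools are configured.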
    fn serialize_tools(tools: &[ToolDefinition]) -> Option<Value> {
        if tools.is_empty() {
            return None;
        }

        let serialized = tools
            .iter()
            .map(|tool| {
                json!({
                    "type": tool.tool_type,
                    "function": {
                        "name": tool.function.name,
                        "description": tool.function.description,
                        "parameters": tool.function.parameters,
                    }
                })
            })
            .collect::<Vec<Value>>();

        Some(Value::Array(serialized))
    }

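    /// Shared constructor; falls back to the default Z.AI base URL when none is
    /// provided. The prompt-cache configuration is currently unused.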
    fn with_model_internal(
        api_key: String,
        model: String,
        base_url: Option<String>,
        _prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        Self {
            api_key,
            http_client: HttpClient::new(),
            base_url: base_url.unwrap_or_else(|| urls::Z_AI_API_BASE.to_string()),
            model,
        }
    }

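    /// Creates a provider that targets the default Z.AI model.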
    pub fn new(api_key: String) -> Self {
        Self::with_model_internal(api_key, models::zai::DEFAULT_MODEL.to_string(), None, None)
    }

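    /// Creates a provider that targets a specific model.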
    pub fn with_model(api_key: String, model: String) -> Self {
        Self::with_model_internal(api_key, model, None, None)
    }

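    /// Builds a provider from optional configuration values, substituting the
    /// default model and base URL for anything left unset.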
    pub fn from_config(
        api_key: Option<String>,
        model: Option<String>,
        base_url: Option<String>,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let api_key_value = api_key.unwrap_or_default();
        let model_value = model.unwrap_or_else(|| models::zai::DEFAULT_MODEL.to_string());
        Self::with_model_internal(api_key_value, model_value, base_url, prompt_cache)
    }

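    /// Wraps a plain prompt string in a single-user-message request against the
    /// configured model.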
    fn default_request(&self, prompt: &str) -> LLMRequest {
        LLMRequest {
            messages: vec![Message::user(prompt.to_string())],
            system_prompt: None,
            tools: None,
            model: self.model.clone(),
            max_tokens: None,
            temperature: None,
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        }
    }

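    /// Interprets an `LLMClient` prompt: if it looks like a JSON chat request it
    /// is parsed as one, otherwise it is treated as a raw user message.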
    fn parse_client_prompt(&self, prompt: &str) -> LLMRequest {
        let trimmed = prompt.trim_start();
        if trimmed.starts_with('{') {
            if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
                if let Some(request) = self.parse_chat_request(&value) {
                    return request;
                }
            }
        }

        self.default_request(prompt)
    }

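    /// Converts an OpenAI-style JSON chat payload (`messages`, `model`,
    /// `max_tokens`, `temperature`, `stream`, `tool_choice`, ...) into an
    /// `LLMRequest`. Returns `None` when the `messages` array is missing.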
    fn parse_chat_request(&self, value: &Value) -> Option<LLMRequest> {
        let messages_value = value.get("messages")?.as_array()?;
        let mut system_prompt = value
            .get("system")
            .and_then(|entry| entry.as_str())
            .map(|text| text.to_string());
        let mut messages = Vec::new();

        for entry in messages_value {
            let role = entry
                .get("role")
                .and_then(|r| r.as_str())
                .unwrap_or(crate::config::constants::message_roles::USER);
            let content = entry
                .get("content")
                .map(|c| match c {
                    Value::String(text) => text.to_string(),
                    other => other.to_string(),
                })
                .unwrap_or_default();

            match role {
                "system" => {
                    if system_prompt.is_none() && !content.is_empty() {
                        system_prompt = Some(content);
                    }
                }
                "assistant" => {
                    let tool_calls = entry
                        .get("tool_calls")
                        .and_then(|tc| tc.as_array())
                        .map(|calls| {
                            calls
                                .iter()
                                .filter_map(|call| Self::parse_tool_call(call))
                                .collect::<Vec<_>>()
                        })
                        .filter(|calls| !calls.is_empty());

                    messages.push(Message {
                        role: MessageRole::Assistant,
                        content,
                        tool_calls,
                        tool_call_id: None,
                    });
                }
                "tool" => {
                    if let Some(tool_call_id) = entry.get("tool_call_id").and_then(|v| v.as_str()) {
                        messages.push(Message::tool_response(tool_call_id.to_string(), content));
                    }
                }
                _ => {
                    messages.push(Message::user(content));
                }
            }
        }

        Some(LLMRequest {
            messages,
            system_prompt,
            model: value
                .get("model")
                .and_then(|m| m.as_str())
                .unwrap_or(&self.model)
                .to_string(),
            max_tokens: value
                .get("max_tokens")
                .and_then(|m| m.as_u64())
                .map(|m| m as u32),
            temperature: value
                .get("temperature")
                .and_then(|t| t.as_f64())
                .map(|t| t as f32),
            stream: value
                .get("stream")
                .and_then(|s| s.as_bool())
                .unwrap_or(false),
            tools: None,
            tool_choice: value.get("tool_choice").and_then(|choice| match choice {
                Value::String(s) => match s.as_str() {
                    "auto" => Some(ToolChoice::auto()),
                    "none" => Some(ToolChoice::none()),
                    "any" | "required" => Some(ToolChoice::any()),
                    _ => None,
                },
                _ => None,
            }),
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        })
    }

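    /// Parses a single OpenAI-style tool call object, serializing non-string
    /// argument values back to a JSON string.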
    fn parse_tool_call(value: &Value) -> Option<ToolCall> {
        let id = value.get("id").and_then(|v| v.as_str())?;
        let function = value.get("function")?;
        let name = function.get("name").and_then(|v| v.as_str())?;
        let arguments = function.get("arguments");
        let serialized = arguments.map_or("{}".to_string(), |value| {
            if value.is_string() {
                value.as_str().unwrap_or("").to_string()
            } else {
                value.to_string()
            }
        });

        Some(ToolCall::function(
            id.to_string(),
            name.to_string(),
            serialized,
        ))
    }

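    /// Builds the Z.AI request payload: system prompt and conversation messages,
    /// assistant tool calls, tool responses (dropped when they do not answer a
    /// pending tool call), plus optional sampling, tool, and thinking parameters.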
    fn convert_to_zai_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
        let mut messages = Vec::new();
        let mut active_tool_call_ids: HashSet<String> = HashSet::new();

        if let Some(system_prompt) = &request.system_prompt {
            messages.push(json!({
                "role": crate::config::constants::message_roles::SYSTEM,
                "content": system_prompt
            }));
        }

        for msg in &request.messages {
            let role = msg.role.as_generic_str();
            let mut message = json!({
                "role": role,
                "content": msg.content
            });
            let mut skip_message = false;

            if msg.role == MessageRole::Assistant {
                if let Some(tool_calls) = &msg.tool_calls {
                    if !tool_calls.is_empty() {
                        let tool_calls_json: Vec<Value> = tool_calls
                            .iter()
                            .map(|tc| {
                                active_tool_call_ids.insert(tc.id.clone());
                                json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.function.name,
                                        "arguments": tc.function.arguments,
                                    }
                                })
                            })
                            .collect();
                        message["tool_calls"] = Value::Array(tool_calls_json);
                    }
                }
            }

            if msg.role == MessageRole::Tool {
                match &msg.tool_call_id {
                    Some(tool_call_id) if active_tool_call_ids.contains(tool_call_id) => {
                        message["tool_call_id"] = Value::String(tool_call_id.clone());
                        active_tool_call_ids.remove(tool_call_id);
                    }
                    Some(_) | None => {
                        skip_message = true;
                    }
                }
            }

            if !skip_message {
                messages.push(message);
            }
        }

        if messages.is_empty() {
            let formatted = error_display::format_llm_error(PROVIDER_NAME, "No messages provided");
            return Err(LLMError::InvalidRequest(formatted));
        }

        let mut payload = json!({
            "model": request.model,
            "messages": messages,
            "stream": request.stream,
        });

        if let Some(max_tokens) = request.max_tokens {
            payload["max_tokens"] = json!(max_tokens);
        }

        if let Some(temperature) = request.temperature {
            payload["temperature"] = json!(temperature);
        }

        if let Some(tools) = &request.tools {
            if let Some(serialized) = Self::serialize_tools(tools) {
                payload["tools"] = serialized;
            }
        }

        if let Some(choice) = &request.tool_choice {
            payload["tool_choice"] = choice.to_provider_format("openai");
        }

        if self.supports_reasoning(&request.model) || request.reasoning_effort.is_some() {
            payload["thinking"] = json!({ "type": "enabled" });
        }

        Ok(payload)
    }

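    /// Maps a Z.AI chat-completion response onto `LLMResponse`, extracting the
    /// first choice's content, optional `reasoning_content`, tool calls, finish
    /// reason, and token usage (including cached prompt tokens when reported).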
    fn parse_zai_response(&self, response_json: Value) -> Result<LLMResponse, LLMError> {
        let choices = response_json
            .get("choices")
            .and_then(|c| c.as_array())
            .ok_or_else(|| {
                let formatted = error_display::format_llm_error(
                    PROVIDER_NAME,
                    "Invalid response format: missing choices",
                );
                LLMError::Provider(formatted)
            })?;

        if choices.is_empty() {
            let formatted =
                error_display::format_llm_error(PROVIDER_NAME, "No choices in response");
            return Err(LLMError::Provider(formatted));
        }

        let choice = &choices[0];
        let message = choice.get("message").ok_or_else(|| {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                "Invalid response format: missing message",
            );
            LLMError::Provider(formatted)
        })?;

        let content = message
            .get("content")
            .and_then(|c| c.as_str())
            .map(|s| s.to_string());

        let reasoning = message
            .get("reasoning_content")
            .map(|value| match value {
                Value::String(text) => Some(text.to_string()),
                Value::Array(parts) => {
                    let combined = parts
                        .iter()
                        .filter_map(|part| part.as_str())
                        .collect::<Vec<_>>()
                        .join("");
                    if combined.is_empty() {
                        None
                    } else {
                        Some(combined)
                    }
                }
                _ => None,
            })
            .flatten();

        let tool_calls = message
            .get("tool_calls")
            .and_then(|tc| tc.as_array())
            .map(|calls| {
                calls
                    .iter()
                    .filter_map(|call| Self::parse_tool_call(call))
                    .collect::<Vec<_>>()
            })
            .filter(|calls| !calls.is_empty());

        let finish_reason = choice
            .get("finish_reason")
            .and_then(|fr| fr.as_str())
            .map(Self::map_finish_reason)
            .unwrap_or(FinishReason::Stop);

        let usage = response_json.get("usage").map(|usage_value| Usage {
            prompt_tokens: usage_value
                .get("prompt_tokens")
                .and_then(|pt| pt.as_u64())
                .unwrap_or(0) as u32,
            completion_tokens: usage_value
                .get("completion_tokens")
                .and_then(|ct| ct.as_u64())
                .unwrap_or(0) as u32,
            total_tokens: usage_value
                .get("total_tokens")
                .and_then(|tt| tt.as_u64())
                .unwrap_or(0) as u32,
            cached_prompt_tokens: usage_value
                .get("prompt_tokens_details")
                .and_then(|details| details.get("cached_tokens"))
                .and_then(|value| value.as_u64())
                .map(|value| value as u32),
            cache_creation_tokens: None,
            cache_read_tokens: None,
        });

        Ok(LLMResponse {
            content,
            tool_calls,
            usage,
            finish_reason,
            reasoning,
        })
    }

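    /// Translates Z.AI finish reasons into the shared `FinishReason` enum;
    /// unknown values are preserved as `FinishReason::Error`.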
    fn map_finish_reason(reason: &str) -> FinishReason {
        match reason {
            "stop" => FinishReason::Stop,
            "length" => FinishReason::Length,
            "tool_calls" => FinishReason::ToolCalls,
            "sensitive" => FinishReason::ContentFilter,
            other => FinishReason::Error(other.to_string()),
        }
    }

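    /// Lists the model identifiers this provider accepts.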
    fn available_models() -> Vec<String> {
        models::zai::SUPPORTED_MODELS
            .iter()
            .map(|s| s.to_string())
            .collect()
    }
}

#[async_trait]
impl LLMProvider for ZAIProvider {
    fn name(&self) -> &str {
        PROVIDER_KEY
    }

    fn supports_streaming(&self) -> bool {
        false
    }

    fn supports_reasoning(&self, model: &str) -> bool {
        matches!(
            model,
            models::zai::GLM_4_6
                | models::zai::GLM_4_5
                | models::zai::GLM_4_5_AIR
                | models::zai::GLM_4_5_X
                | models::zai::GLM_4_5_AIRX
        )
    }

    fn supports_reasoning_effort(&self, _model: &str) -> bool {
        false
    }

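    // Sends the request to the Z.AI chat-completions endpoint: validates the
    // model, converts the request, posts it with bearer auth, and maps HTTP,
    // rate-limit, and parse failures onto `LLMError` variants.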
    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
        if request.model.trim().is_empty() {
            request.model = self.model.clone();
        }

        if !Self::available_models().contains(&request.model) {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted));
        }

        for message in &request.messages {
            if let Err(err) = message.validate_for_provider(PROVIDER_KEY) {
                let formatted = error_display::format_llm_error(PROVIDER_NAME, &err);
                return Err(LLMError::InvalidRequest(formatted));
            }
        }

        let payload = self.convert_to_zai_format(&request)?;
        let url = format!("{}{}", self.base_url, CHAT_COMPLETIONS_PATH);

        let response = self
            .http_client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&payload)
            .send()
            .await
            .map_err(|err| {
                let formatted = error_display::format_llm_error(
                    PROVIDER_NAME,
                    &format!("Network error: {}", err),
                );
                LLMError::Network(formatted)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();

            if status.as_u16() == 429 || text.to_lowercase().contains("rate") {
                return Err(LLMError::RateLimit);
            }

            let message = serde_json::from_str::<Value>(&text)
                .ok()
                .and_then(|value| {
                    value
                        .get("message")
                        .and_then(|m| m.as_str())
                        .map(|s| s.to_string())
                })
                .unwrap_or(text);

            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("HTTP {}: {}", status, message),
            );
            return Err(LLMError::Provider(formatted));
        }

        let json: Value = response.json().await.map_err(|err| {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Failed to parse response: {}", err),
            );
            LLMError::Provider(formatted)
        })?;

        self.parse_zai_response(json)
    }

    fn supported_models(&self) -> Vec<String> {
        Self::available_models()
    }

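    // Validates a request without sending it: messages must be non-empty, the
    // model (when set) must be supported, and each message must pass the
    // provider-specific checks.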
    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
        if request.messages.is_empty() {
            let formatted =
                error_display::format_llm_error(PROVIDER_NAME, "Messages cannot be empty");
            return Err(LLMError::InvalidRequest(formatted));
        }

        if !request.model.is_empty() && !Self::available_models().contains(&request.model) {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted));
        }

        for message in &request.messages {
            if let Err(err) = message.validate_for_provider(PROVIDER_KEY) {
                let formatted = error_display::format_llm_error(PROVIDER_NAME, &err);
                return Err(LLMError::InvalidRequest(formatted));
            }
        }

        Ok(())
    }
}

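// String-prompt adapter: `LLMClient::generate` parses the prompt into an
// `LLMRequest`, delegates to `LLMProvider::generate`, and converts the result
// into the lightweight `llm_types::LLMResponse`.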
#[async_trait]
impl LLMClient for ZAIProvider {
    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
        let request = self.parse_client_prompt(prompt);
        let request_model = request.model.clone();
        let response = LLMProvider::generate(self, request).await?;

        Ok(llm_types::LLMResponse {
            content: response.content.unwrap_or_default(),
            model: request_model,
            usage: response.usage.map(|usage| llm_types::Usage {
                prompt_tokens: usage.prompt_tokens as usize,
                completion_tokens: usage.completion_tokens as usize,
                total_tokens: usage.total_tokens as usize,
                cached_prompt_tokens: usage.cached_prompt_tokens.map(|v| v as usize),
                cache_creation_tokens: usage.cache_creation_tokens.map(|v| v as usize),
                cache_read_tokens: usage.cache_read_tokens.map(|v| v as usize),
            }),
            reasoning: response.reasoning,
        })
    }

    fn backend_kind(&self) -> llm_types::BackendKind {
        llm_types::BackendKind::ZAI
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}