use crate::config::constants::{models, urls};
use crate::config::core::{PromptCachingConfig, ZaiPromptCacheSettings};
use crate::llm::client::LLMClient;
use crate::llm::error_display;
use crate::llm::provider::{
    FinishReason, LLMError, LLMProvider, LLMRequest, LLMResponse, Message, MessageRole, ToolCall,
    ToolChoice, ToolDefinition, Usage,
};
use crate::llm::types as llm_types;
use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde_json::{Value, json};
use std::collections::HashSet;

const PROVIDER_NAME: &str = "Z.AI";
const PROVIDER_KEY: &str = "zai";
const CHAT_COMPLETIONS_PATH: &str = "/paas/v4/chat/completions";

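/// Z.AI chat-completions provider for the GLM model family.
///
/// Holds the API key, a shared HTTP client, the resolved base URL, the target
/// model, and the provider-specific prompt-cache settings from configuration.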
pub struct ZAIProvider {
    api_key: String,
    http_client: HttpClient,
    base_url: String,
    model: String,
    prompt_cache_settings: ZaiPromptCacheSettings,
}

impl ZAIProvider {
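    /// Serializes tool definitions into the OpenAI-style `tools` array expected
    /// by the Z.AI API, or returns `None` when no tools are configured.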
    fn serialize_tools(tools: &[ToolDefinition]) -> Option<Value> {
        if tools.is_empty() {
            return None;
        }

        let serialized = tools
            .iter()
            .map(|tool| {
                json!({
                    "type": tool.tool_type,
                    "function": {
                        "name": tool.function.name,
                        "description": tool.function.description,
                        "parameters": tool.function.parameters,
                    }
                })
            })
            .collect::<Vec<Value>>();

        Some(Value::Array(serialized))
    }

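    /// Pulls the Z.AI-specific prompt-cache settings out of the optional
    /// prompt-caching configuration, falling back to the defaults.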
    fn extract_prompt_cache_settings(
        prompt_cache: Option<PromptCachingConfig>,
    ) -> ZaiPromptCacheSettings {
        if let Some(cfg) = prompt_cache {
            cfg.providers.zai
        } else {
            ZaiPromptCacheSettings::default()
        }
    }

    fn with_model_internal(
        api_key: String,
        model: String,
        base_url: Option<String>,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        Self {
            api_key,
            http_client: HttpClient::new(),
            base_url: base_url.unwrap_or_else(|| urls::Z_AI_API_BASE.to_string()),
            model,
            prompt_cache_settings: Self::extract_prompt_cache_settings(prompt_cache),
        }
    }

    pub fn new(api_key: String) -> Self {
        Self::with_model_internal(api_key, models::zai::DEFAULT_MODEL.to_string(), None, None)
    }

    pub fn with_model(api_key: String, model: String) -> Self {
        Self::with_model_internal(api_key, model, None, None)
    }

    pub fn from_config(
        api_key: Option<String>,
        model: Option<String>,
        base_url: Option<String>,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let api_key_value = api_key.unwrap_or_default();
        let model_value = model.unwrap_or_else(|| models::zai::DEFAULT_MODEL.to_string());
        Self::with_model_internal(api_key_value, model_value, base_url, prompt_cache)
    }

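    /// Builds a minimal non-streaming request that wraps a raw prompt in a
    /// single user message and targets the provider's configured model.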
    fn default_request(&self, prompt: &str) -> LLMRequest {
        LLMRequest {
            messages: vec![Message::user(prompt.to_string())],
            system_prompt: None,
            tools: None,
            model: self.model.clone(),
            max_tokens: None,
            temperature: None,
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        }
    }

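    /// Interprets a client prompt: if it looks like a JSON chat payload it is
    /// parsed into a structured request, otherwise the text is treated as a
    /// plain user prompt.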
    fn parse_client_prompt(&self, prompt: &str) -> LLMRequest {
        let trimmed = prompt.trim_start();
        if trimmed.starts_with('{') {
            if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
                if let Some(request) = self.parse_chat_request(&value) {
                    return request;
                }
            }
        }

        self.default_request(prompt)
    }

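    /// Converts a JSON chat payload into an `LLMRequest`, mapping system,
    /// assistant, tool, and user messages and picking up model, max_tokens,
    /// temperature, stream, and tool_choice overrides when present.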
    fn parse_chat_request(&self, value: &Value) -> Option<LLMRequest> {
        let messages_value = value.get("messages")?.as_array()?;
        let mut system_prompt = value
            .get("system")
            .and_then(|entry| entry.as_str())
            .map(|text| text.to_string());
        let mut messages = Vec::new();

        for entry in messages_value {
            let role = entry
                .get("role")
                .and_then(|r| r.as_str())
                .unwrap_or(crate::config::constants::message_roles::USER);
            let content = entry
                .get("content")
                .map(|c| match c {
                    Value::String(text) => text.to_string(),
                    other => other.to_string(),
                })
                .unwrap_or_default();

            match role {
                "system" => {
                    if system_prompt.is_none() && !content.is_empty() {
                        system_prompt = Some(content);
                    }
                }
                "assistant" => {
                    let tool_calls = entry
                        .get("tool_calls")
                        .and_then(|tc| tc.as_array())
                        .map(|calls| {
                            calls
                                .iter()
                                .filter_map(|call| Self::parse_tool_call(call))
                                .collect::<Vec<_>>()
                        })
                        .filter(|calls| !calls.is_empty());

                    messages.push(Message {
                        role: MessageRole::Assistant,
                        content,
                        tool_calls,
                        tool_call_id: None,
                    });
                }
                "tool" => {
                    if let Some(tool_call_id) = entry.get("tool_call_id").and_then(|v| v.as_str()) {
                        messages.push(Message::tool_response(tool_call_id.to_string(), content));
                    }
                }
                _ => {
                    messages.push(Message::user(content));
                }
            }
        }

        Some(LLMRequest {
            messages,
            system_prompt,
            model: value
                .get("model")
                .and_then(|m| m.as_str())
                .unwrap_or(&self.model)
                .to_string(),
            max_tokens: value
                .get("max_tokens")
                .and_then(|m| m.as_u64())
                .map(|m| m as u32),
            temperature: value
                .get("temperature")
                .and_then(|t| t.as_f64())
                .map(|t| t as f32),
            stream: value
                .get("stream")
                .and_then(|s| s.as_bool())
                .unwrap_or(false),
            tools: None,
            tool_choice: value.get("tool_choice").and_then(|choice| match choice {
                Value::String(s) => match s.as_str() {
                    "auto" => Some(ToolChoice::auto()),
                    "none" => Some(ToolChoice::none()),
                    "any" | "required" => Some(ToolChoice::any()),
                    _ => None,
                },
                _ => None,
            }),
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        })
    }

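    /// Parses a single OpenAI-style tool call object, normalizing the
    /// `arguments` field to a JSON string (defaulting to `"{}"` when absent).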
    fn parse_tool_call(value: &Value) -> Option<ToolCall> {
        let id = value.get("id").and_then(|v| v.as_str())?;
        let function = value.get("function")?;
        let name = function.get("name").and_then(|v| v.as_str())?;
        let arguments = function.get("arguments");
        let serialized = arguments.map_or("{}".to_string(), |value| {
            if value.is_string() {
                value.as_str().unwrap_or("").to_string()
            } else {
                value.to_string()
            }
        });

        Some(ToolCall::function(
            id.to_string(),
            name.to_string(),
            serialized,
        ))
    }

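    /// Builds the Z.AI chat-completions payload: emits the system prompt as a
    /// leading system message, forwards assistant tool calls, drops tool
    /// results that do not answer a pending tool call, and attaches optional
    /// sampling, tool, and thinking parameters.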
    fn convert_to_zai_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
        let mut messages = Vec::new();
        let mut active_tool_call_ids: HashSet<String> = HashSet::new();

        if let Some(system_prompt) = &request.system_prompt {
            messages.push(json!({
                "role": crate::config::constants::message_roles::SYSTEM,
                "content": system_prompt
            }));
        }

        for msg in &request.messages {
            let role = msg.role.as_generic_str();
            let mut message = json!({
                "role": role,
                "content": msg.content
            });
            let mut skip_message = false;

            if msg.role == MessageRole::Assistant {
                if let Some(tool_calls) = &msg.tool_calls {
                    if !tool_calls.is_empty() {
                        let tool_calls_json: Vec<Value> = tool_calls
                            .iter()
                            .map(|tc| {
                                active_tool_call_ids.insert(tc.id.clone());
                                json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.function.name,
                                        "arguments": tc.function.arguments,
                                    }
                                })
                            })
                            .collect();
                        message["tool_calls"] = Value::Array(tool_calls_json);
                    }
                }
            }

            if msg.role == MessageRole::Tool {
                match &msg.tool_call_id {
                    Some(tool_call_id) if active_tool_call_ids.contains(tool_call_id) => {
                        message["tool_call_id"] = Value::String(tool_call_id.clone());
                        active_tool_call_ids.remove(tool_call_id);
                    }
                    Some(_) | None => {
                        skip_message = true;
                    }
                }
            }

            if !skip_message {
                messages.push(message);
            }
        }

        if messages.is_empty() {
            let formatted = error_display::format_llm_error(PROVIDER_NAME, "No messages provided");
            return Err(LLMError::InvalidRequest(formatted));
        }

        let mut payload = json!({
            "model": request.model,
            "messages": messages,
            "stream": request.stream,
        });

        if let Some(max_tokens) = request.max_tokens {
            payload["max_tokens"] = json!(max_tokens);
        }

        if let Some(temperature) = request.temperature {
            payload["temperature"] = json!(temperature);
        }

        if let Some(tools) = &request.tools {
            if let Some(serialized) = Self::serialize_tools(tools) {
                payload["tools"] = serialized;
            }
        }

        if let Some(choice) = &request.tool_choice {
            payload["tool_choice"] = choice.to_provider_format("openai");
        }

        if self.supports_reasoning(&request.model) || request.reasoning_effort.is_some() {
            payload["thinking"] = json!({ "type": "enabled" });
        }

        Ok(payload)
    }

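    /// Parses a Z.AI chat-completions response into an `LLMResponse`, pulling
    /// out the first choice's content, optional `reasoning_content`, tool
    /// calls, finish reason, and token usage (including cached prompt tokens).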
    fn parse_zai_response(&self, response_json: Value) -> Result<LLMResponse, LLMError> {
        let choices = response_json
            .get("choices")
            .and_then(|c| c.as_array())
            .ok_or_else(|| {
                let formatted = error_display::format_llm_error(
                    PROVIDER_NAME,
                    "Invalid response format: missing choices",
                );
                LLMError::Provider(formatted)
            })?;

        if choices.is_empty() {
            let formatted =
                error_display::format_llm_error(PROVIDER_NAME, "No choices in response");
            return Err(LLMError::Provider(formatted));
        }

        let choice = &choices[0];
        let message = choice.get("message").ok_or_else(|| {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                "Invalid response format: missing message",
            );
            LLMError::Provider(formatted)
        })?;

        let content = message
            .get("content")
            .and_then(|c| c.as_str())
            .map(|s| s.to_string());

        let reasoning = message
            .get("reasoning_content")
            .map(|value| match value {
                Value::String(text) => Some(text.to_string()),
                Value::Array(parts) => {
                    let combined = parts
                        .iter()
                        .filter_map(|part| part.as_str())
                        .collect::<Vec<_>>()
                        .join("");
                    if combined.is_empty() {
                        None
                    } else {
                        Some(combined)
                    }
                }
                _ => None,
            })
            .flatten();

        let tool_calls = message
            .get("tool_calls")
            .and_then(|tc| tc.as_array())
            .map(|calls| {
                calls
                    .iter()
                    .filter_map(|call| Self::parse_tool_call(call))
                    .collect::<Vec<_>>()
            })
            .filter(|calls| !calls.is_empty());

        let finish_reason = choice
            .get("finish_reason")
            .and_then(|fr| fr.as_str())
            .map(Self::map_finish_reason)
            .unwrap_or(FinishReason::Stop);

        let usage = response_json.get("usage").map(|usage_value| Usage {
            prompt_tokens: usage_value
                .get("prompt_tokens")
                .and_then(|pt| pt.as_u64())
                .unwrap_or(0) as u32,
            completion_tokens: usage_value
                .get("completion_tokens")
                .and_then(|ct| ct.as_u64())
                .unwrap_or(0) as u32,
            total_tokens: usage_value
                .get("total_tokens")
                .and_then(|tt| tt.as_u64())
                .unwrap_or(0) as u32,
            cached_prompt_tokens: usage_value
                .get("prompt_tokens_details")
                .and_then(|details| details.get("cached_tokens"))
                .and_then(|value| value.as_u64())
                .map(|value| value as u32),
            cache_creation_tokens: None,
            cache_read_tokens: None,
        });

        Ok(LLMResponse {
            content,
            tool_calls,
            usage,
            finish_reason,
            reasoning,
        })
    }

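    /// Maps Z.AI finish-reason strings onto the shared `FinishReason` enum;
    /// unrecognized values are preserved as `FinishReason::Error`.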
    fn map_finish_reason(reason: &str) -> FinishReason {
        match reason {
            "stop" => FinishReason::Stop,
            "length" => FinishReason::Length,
            "tool_calls" => FinishReason::ToolCalls,
            "sensitive" => FinishReason::ContentFilter,
            other => FinishReason::Error(other.to_string()),
        }
    }

    fn available_models() -> Vec<String> {
        models::zai::SUPPORTED_MODELS
            .iter()
            .map(|s| s.to_string())
            .collect()
    }
}

#[async_trait]
impl LLMProvider for ZAIProvider {
    fn name(&self) -> &str {
        PROVIDER_KEY
    }

    fn supports_streaming(&self) -> bool {
        false
    }

    fn supports_reasoning(&self, model: &str) -> bool {
        matches!(
            model,
            models::zai::GLM_4_6
                | models::zai::GLM_4_5
                | models::zai::GLM_4_5_AIR
                | models::zai::GLM_4_5_X
                | models::zai::GLM_4_5_AIRX
        )
    }

    fn supports_reasoning_effort(&self, _model: &str) -> bool {
        false
    }

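    /// Sends a chat-completions request to Z.AI: fills in the configured model
    /// when the request leaves it blank, validates the model and messages,
    /// POSTs the converted payload, and maps network errors, rate limits, and
    /// malformed bodies onto `LLMError` variants before parsing the response.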
    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
        if request.model.trim().is_empty() {
            request.model = self.model.clone();
        }

        if !Self::available_models().contains(&request.model) {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted));
        }

        for message in &request.messages {
            if let Err(err) = message.validate_for_provider(PROVIDER_KEY) {
                let formatted = error_display::format_llm_error(PROVIDER_NAME, &err);
                return Err(LLMError::InvalidRequest(formatted));
            }
        }

        let payload = self.convert_to_zai_format(&request)?;
        let url = format!("{}{}", self.base_url, CHAT_COMPLETIONS_PATH);

        let response = self
            .http_client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&payload)
            .send()
            .await
            .map_err(|err| {
                let formatted = error_display::format_llm_error(
                    PROVIDER_NAME,
                    &format!("Network error: {}", err),
                );
                LLMError::Network(formatted)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();

            if status.as_u16() == 429 || text.to_lowercase().contains("rate") {
                return Err(LLMError::RateLimit);
            }

            let message = serde_json::from_str::<Value>(&text)
                .ok()
                .and_then(|value| {
                    value
                        .get("message")
                        .and_then(|m| m.as_str())
                        .map(|s| s.to_string())
                })
                .unwrap_or(text);

            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("HTTP {}: {}", status, message),
            );
            return Err(LLMError::Provider(formatted));
        }

        let json: Value = response.json().await.map_err(|err| {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Failed to parse response: {}", err),
            );
            LLMError::Provider(formatted)
        })?;

        self.parse_zai_response(json)
    }

    fn supported_models(&self) -> Vec<String> {
        Self::available_models()
    }

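    /// Validates a request without sending it: messages must be non-empty, the
    /// model (when set) must be supported, and every message must pass the
    /// provider-specific validation rules.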
    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
        if request.messages.is_empty() {
            let formatted =
                error_display::format_llm_error(PROVIDER_NAME, "Messages cannot be empty");
            return Err(LLMError::InvalidRequest(formatted));
        }

        if !request.model.is_empty() && !Self::available_models().contains(&request.model) {
            let formatted = error_display::format_llm_error(
                PROVIDER_NAME,
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted));
        }

        for message in &request.messages {
            if let Err(err) = message.validate_for_provider(PROVIDER_KEY) {
                let formatted = error_display::format_llm_error(PROVIDER_NAME, &err);
                return Err(LLMError::InvalidRequest(formatted));
            }
        }

        Ok(())
    }
}

#[async_trait]
impl LLMClient for ZAIProvider {
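    /// Adapter for the simpler `LLMClient` interface: turns a raw prompt into
    /// an `LLMRequest`, delegates to `LLMProvider::generate`, and converts the
    /// provider response and usage counters into `llm_types::LLMResponse`.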
    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
        let request = self.parse_client_prompt(prompt);
        let request_model = request.model.clone();
        let response = LLMProvider::generate(self, request).await?;

        Ok(llm_types::LLMResponse {
            content: response.content.unwrap_or_default(),
            model: request_model,
            usage: response.usage.map(|usage| llm_types::Usage {
                prompt_tokens: usage.prompt_tokens as usize,
                completion_tokens: usage.completion_tokens as usize,
                total_tokens: usage.total_tokens as usize,
                cached_prompt_tokens: usage.cached_prompt_tokens.map(|v| v as usize),
                cache_creation_tokens: usage.cache_creation_tokens.map(|v| v as usize),
                cache_read_tokens: usage.cache_read_tokens.map(|v| v as usize),
            }),
            reasoning: response.reasoning,
        })
    }

    fn backend_kind(&self) -> llm_types::BackendKind {
        llm_types::BackendKind::ZAI
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}