//! Z.AI Provider Implementation (Zhipu / GLM family)
//!
//! Z.AI exposes an OpenAI-compatible endpoint at <https://open.bigmodel.cn/api/paas/v4>.
//! Models: glm-4-plus, glm-4.7, glm-4.6, glm-4-air, glm-4-flash, etc.
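//!
//! A minimal usage sketch (illustrative, not a compiled doc-test; assumes
//! `LlmRequest` exposes exactly the fields referenced in `complete` below):
//!
//! ```ignore
//! let provider = ZaiProvider::new(std::env::var("ZAI_API_KEY")?);
//! let response = provider
//!     .complete(LlmRequest {
//!         model: models::GLM_4_PLUS.to_string(),
//!         messages: vec![LlmMessage {
//!             role: LlmRole::User,
//!             content: "Hello".to_string(),
//!         }],
//!         temperature: Some(0.7),
//!         max_tokens: Some(256),
//!         stream: false,
//!     })
//!     .await?;
//! println!("{}", response.content);
//! ```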

use crate::proxy::{LlmMessage, LlmProvider, LlmRequest, LlmResponse, LlmRole, LlmUsage};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// Chat-completion client for Z.AI's OpenAI-compatible API.
pub struct ZaiProvider {
    client: Client,
    api_key: String,
    base_url: String,
}

impl ZaiProvider {
    /// Creates a provider that targets the public Z.AI endpoint.
    pub fn new(api_key: String) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: "https://open.bigmodel.cn/api/paas/v4".to_string(),
        }
    }

    /// Creates a provider with a custom base URL, e.g. for a gateway or proxy.
    pub fn with_base_url(api_key: String, base_url: String) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url,
        }
    }
}
35
36impl Default for ZaiProvider {
37    fn default() -> Self {
38        let api_key = std::env::var("ZAI_API_KEY")
39            .or_else(|_| std::env::var("ZHIPU_API_KEY"))
40            .unwrap_or_default();
41        Self::new(api_key)
42    }
43}
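// Environment-based construction (a sketch; assumes `ZAI_API_KEY` or
// `ZHIPU_API_KEY` is exported):
//
//     let provider = ZaiProvider::default();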

#[async_trait]
impl LlmProvider for ZaiProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let url = format!("{}/chat/completions", self.base_url);

        // Fall back to a sensible default when no model is specified.
        let model = if request.model.is_empty() || request.model == "default" {
            "glm-4-plus".to_string()
        } else {
            request.model.clone()
        };

        let zai_request = ZaiChatRequest {
            model,
            messages: request.messages.into_iter().map(Into::into).collect(),
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            // `complete` parses a single JSON body, so streaming is always
            // disabled here regardless of what the caller requested.
            stream: false,
        };

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&zai_request)
            .send()
            .await
            .context("Failed to send request to Z.AI")?;

        // Capture the status before consuming the body so the error message
        // still carries it even if reading the body fails.
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!("Z.AI API error ({status}): {error_text}"));
        }

        let zai_response: ZaiChatResponse = response
            .json()
            .await
            .context("Failed to parse Z.AI response")?;

        let content = zai_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default();

        Ok(LlmResponse {
            content,
            model: zai_response.model.unwrap_or_else(|| "glm".to_string()),
            usage: zai_response.usage.map(Into::into),
        })
    }

    fn name(&self) -> &'static str {
        "ZAI"
    }
}

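// The wire types below follow the standard OpenAI-compatible chat-completions
// schema. For orientation, a request serializes to JSON roughly like this
// (a sketch of the generic shape, not a captured Z.AI payload):
//
//     {
//       "model": "glm-4-plus",
//       "messages": [{"role": "user", "content": "Hello"}],
//       "temperature": 0.7,
//       "stream": false
//     }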
#[derive(Debug, Serialize)]
struct ZaiChatRequest {
    model: String,
    messages: Vec<ZaiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    stream: bool,
}

#[derive(Debug, Serialize, Deserialize)]
struct ZaiMessage {
    role: String,
    content: String,
}

impl From<LlmMessage> for ZaiMessage {
    fn from(msg: LlmMessage) -> Self {
        Self {
            role: match msg.role {
                LlmRole::System => "system".to_string(),
                LlmRole::User => "user".to_string(),
                LlmRole::Assistant => "assistant".to_string(),
            },
            content: msg.content,
        }
    }
}

#[derive(Debug, Deserialize)]
struct ZaiChatResponse {
    model: Option<String>,
    choices: Vec<ZaiChoice>,
    usage: Option<ZaiUsage>,
}

#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
}

#[derive(Debug, Deserialize)]
struct ZaiUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

impl From<ZaiUsage> for LlmUsage {
    fn from(u: ZaiUsage) -> Self {
        Self {
            prompt_tokens: u.prompt_tokens,
            completion_tokens: u.completion_tokens,
            total_tokens: u.total_tokens,
        }
    }
}

/// Well-known GLM model identifiers accepted by the Z.AI endpoint.
pub mod models {
    pub const GLM_4_PLUS: &str = "glm-4-plus";
    pub const GLM_4_7: &str = "glm-4.7";
    pub const GLM_4_6: &str = "glm-4.6";
    pub const GLM_4_AIR: &str = "glm-4-air";
    pub const GLM_4_FLASH: &str = "glm-4-flash";
}
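
// A minimal serialization check (a sketch; assumes `serde_json` is available
// as a dev-dependency):
#[cfg(test)]
mod tests {
    use super::*;

    // `None` optionals should be omitted from the wire format, matching the
    // `skip_serializing_if` attributes on `ZaiChatRequest`.
    #[test]
    fn request_omits_unset_optionals() {
        let request = ZaiChatRequest {
            model: models::GLM_4_FLASH.to_string(),
            messages: vec![ZaiMessage {
                role: "user".to_string(),
                content: "hi".to_string(),
            }],
            temperature: None,
            max_tokens: None,
            stream: false,
        };
        let json = serde_json::to_string(&request).expect("request serializes");
        assert!(!json.contains("temperature"));
        assert!(!json.contains("max_tokens"));
    }
}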