zai_rs/file/upload.rs

1use std::path::PathBuf;
2
3use super::request::FilePurpose;
4use crate::client::http::HttpClient;
5
6/// File upload request (multipart/form-data)
7///
8/// Sends a multipart request with fields:
9/// - purpose: `FilePurpose`
10/// - file: file content
11pub struct FileUploadRequest {
12    pub key: String,
13    purpose: FilePurpose,
14    file_path: PathBuf,
15    file_name: Option<String>,
16    content_type: Option<String>,
17}
18
19impl FileUploadRequest {
20    pub fn new(key: String, purpose: FilePurpose, file_path: impl Into<PathBuf>) -> Self {
21        Self {
22            key,
23            purpose,
24            file_path: file_path.into(),
25            file_name: None,
26            content_type: None,
27        }
28    }
29
30    pub fn with_file_name(mut self, name: impl Into<String>) -> Self {
31        self.file_name = Some(name.into());
32        self
33    }
34
35    pub fn with_content_type(mut self, ct: impl Into<String>) -> Self {
36        self.content_type = Some(ct.into());
37        self
38    }
39}
40
41impl FileUploadRequest {
42    /// Send the upload request and parse typed response (`FileObject`).
43    pub async fn send(&self) -> anyhow::Result<super::response::FileObject> {
44        let resp: reqwest::Response = self.post().await?;
45        let parsed = resp.json::<super::response::FileObject>().await?;
46        Ok(parsed)
47    }
48}
49
50impl HttpClient for FileUploadRequest {
51    type Body = (); // unused
52    type ApiUrl = &'static str;
53    type ApiKey = String;
54
55    fn api_url(&self) -> &Self::ApiUrl {
56        &"https://open.bigmodel.cn/api/paas/v4/files"
57    }
58    fn api_key(&self) -> &Self::ApiKey {
59        &self.key
60    }
61    fn body(&self) -> &Self::Body {
62        &()
63    }
64
65    // Override POST to send multipart/form-data
66    fn post(&self) -> impl std::future::Future<Output = anyhow::Result<reqwest::Response>> + Send {
67        let url: String = "https://open.bigmodel.cn/api/paas/v4/files".to_string();
68        let key: String = self.key.clone();
69        let purpose = self.purpose.clone();
70        let path = self.file_path.clone();
71        let file_name = self.file_name.clone();
72        let content_type = self.content_type.clone();
73        async move {
74            let mut form =
75                reqwest::multipart::Form::new().text("purpose", purpose.as_str().to_string());
76
77            let fname = file_name
78                .or_else(|| {
79                    path.file_name()
80                        .and_then(|s| s.to_str())
81                        .map(|s| s.to_string())
82                })
83                .unwrap_or_else(|| "upload.bin".to_string());
84
85            let mut part = reqwest::multipart::Part::bytes(std::fs::read(&path)?).file_name(fname);
86            if let Some(ct) = content_type {
87                part = part
88                    .mime_str(&ct)
89                    .map_err(|e| anyhow::anyhow!("invalid content-type: {}", e))?;
90            }
91            form = form.part("file", part);
92
93            let resp = reqwest::Client::new()
94                .post(url)
95                .bearer_auth(key)
96                .multipart(form)
97                .send()
98                .await?;
99
100            let status = resp.status();
101            if status.is_success() {
102                return Ok(resp);
103            }
104
105            // Parse standard error envelope {"error": { code, message }}
106            let text = resp.text().await.unwrap_or_default();
107            #[derive(serde::Deserialize)]
108            struct ErrEnv {
109                error: ErrObj,
110            }
111            #[derive(serde::Deserialize)]
112            struct ErrObj {
113                code: serde_json::Value,
114                message: String,
115            }
116            if let Ok(parsed) = serde_json::from_str::<ErrEnv>(&text) {
117                return Err(anyhow::anyhow!(
118                    "HTTP {} {} | code={} | message={}",
119                    status.as_u16(),
120                    status.canonical_reason().unwrap_or(""),
121                    parsed.error.code,
122                    parsed.error.message
123                ));
124            } else {
125                return Err(anyhow::anyhow!(
126                    "HTTP {} {} | body={}",
127                    status.as_u16(),
128                    status.canonical_reason().unwrap_or(""),
129                    text
130                ));
131            }
132        }
133    }
134}