Skip to main content

zai_rs/file/
upload.rs

1use std::path::PathBuf;
2
3use super::request::FilePurpose;
4use crate::client::http::HttpClient;
5
6/// File upload request (multipart/form-data)
7///
8/// Sends a multipart request with fields:
9/// - purpose: `FilePurpose`
10/// - file: file content
11pub struct FileUploadRequest {
12    pub key: String,
13    purpose: FilePurpose,
14    file_path: PathBuf,
15    file_name: Option<String>,
16    content_type: Option<String>,
17}
18
19impl FileUploadRequest {
20    pub fn new(key: String, purpose: FilePurpose, file_path: impl Into<PathBuf>) -> Self {
21        Self {
22            key,
23            purpose,
24            file_path: file_path.into(),
25            file_name: None,
26            content_type: None,
27        }
28    }
29
30    pub fn with_file_name(mut self, name: impl Into<String>) -> Self {
31        self.file_name = Some(name.into());
32        self
33    }
34
35    pub fn with_content_type(mut self, ct: impl Into<String>) -> Self {
36        self.content_type = Some(ct.into());
37
38        self
39    }
40
41    /// Send the upload request and parse typed response (`FileObject`)
42    pub async fn send(&self) -> crate::ZaiResult<super::response::FileObject> {
43        let resp: reqwest::Response = self.post().await?;
44        let parsed = resp.json::<super::response::FileObject>().await?;
45        Ok(parsed)
46    }
47}
48
49impl HttpClient for FileUploadRequest {
50    type Body = (); // unused
51    type ApiUrl = &'static str;
52    type ApiKey = String;
53
54    fn api_url(&self) -> &Self::ApiUrl {
55        &"https://open.bigmodel.cn/api/paas/v4/files"
56    }
57    fn api_key(&self) -> &Self::ApiKey {
58        &self.key
59    }
60    fn body(&self) -> &Self::Body {
61        &()
62    }
63
64    // Override POST to send multipart/form-data
65
66    fn post(
67        &self,
68    ) -> impl std::future::Future<Output = crate::ZaiResult<reqwest::Response>> + Send {
69        let url: String = "https://open.bigmodel.cn/api/paas/v4/files".to_string();
70
71        let key: String = self.key.clone();
72
73        let purpose = self.purpose.clone();
74        let path = self.file_path.clone();
75        let file_name = self.file_name.clone();
76        let content_type = self.content_type.clone();
77        async move {
78            let mut form =
79                reqwest::multipart::Form::new().text("purpose", purpose.as_str().to_string());
80
81            let fname = file_name
82                .or_else(|| {
83                    path.file_name()
84                        .and_then(|s| s.to_str())
85                        .map(|s| s.to_string())
86                })
87                .unwrap_or_else(|| "upload.bin".to_string());
88
89            let mut part = reqwest::multipart::Part::bytes(std::fs::read(&path)?).file_name(fname);
90            if let Some(ct) = content_type {
91                part =
92                    part.mime_str(&ct)
93                        .map_err(|e| crate::client::error::ZaiError::ApiError {
94                            code: 1200,
95                            message: format!("invalid content-type: {}", e),
96                        })?;
97            }
98            form = form.part("file", part);
99
100            let resp = reqwest::Client::new()
101                .post(url)
102                .bearer_auth(key)
103                .multipart(form)
104                .send()
105                .await?;
106
107            let status = resp.status();
108            if status.is_success() {
109                return Ok(resp);
110            }
111
112            // Parse standard error envelope {"error": { code, message }}
113            let text = resp.text().await.unwrap_or_default();
114            #[derive(serde::Deserialize)]
115            struct ErrEnv {
116                error: ErrObj,
117            }
118            #[derive(serde::Deserialize)]
119            struct ErrObj {
120                _code: serde_json::Value,
121                message: String,
122            }
123
124            if let Ok(parsed) = serde_json::from_str::<ErrEnv>(&text) {
125                Err(crate::client::error::ZaiError::from_api_response(
126                    status.as_u16(),
127                    0,
128                    parsed.error.message,
129                ))
130            } else {
131                Err(crate::client::error::ZaiError::from_api_response(
132                    status.as_u16(),
133                    0,
134                    text,
135                ))
136            }
137        }
138    }
139}