Skip to main content

zai_rs/knowledge/
document_upload_file.rs

1use std::{collections::BTreeMap, path::PathBuf};
2
3use validator::Validate;
4
5use super::types::UploadFileResponse;
6use crate::client::http::HttpClient;
7
/// Slicing strategy for an uploaded document (the `knowledge_type` form field).
///
/// Every variant carries an explicit discriminant, which is the numeric code
/// written into the request. Note that no variant uses the code 4.
#[derive(Clone, Copy, Debug)]
pub enum DocumentSliceType {
    /// Code 1 — slice by title and paragraph (txt, doc, pdf, url, docx, ppt, pptx, md).
    TitleParagraph = 1,
    /// Code 2 — slice into question/answer pairs (txt, doc, pdf, url, docx, ppt, pptx, md).
    QaPair = 2,
    /// Code 3 — slice line by line (xls, xlsx, csv).
    Line = 3,
    /// Code 5 — slice with caller-supplied rules (txt, doc, pdf, url, docx, ppt, pptx, md).
    Custom = 5,
    /// Code 6 — slice page by page (pdf, ppt, pptx).
    Page = 6,
    /// Code 7 — keep the content as a single slice (xls, xlsx, csv).
    Single = 7,
}

impl DocumentSliceType {
    /// Returns the numeric code of this slice type, i.e. the enum discriminant
    /// widened to `i64`.
    fn as_i64(self) -> i64 {
        // The discriminants above are the wire codes, so a plain cast suffices.
        self as i64
    }
}
29
30/// Optional parameters for file upload
31#[derive(Debug, Clone, Default, Validate)]
32pub struct UploadFileOptions {
33    /// Document type; if omitted, the server parses dynamically
34    pub knowledge_type: Option<DocumentSliceType>,
35    /// Custom slicing rules; used when knowledge_type = 5
36    pub custom_separator: Option<Vec<String>>,
37    /// Custom slice size; used when knowledge_type = 5; valid range: 20..=2000
38    #[validate(range(min = 20, max = 2000))]
39    pub sentence_size: Option<u32>,
40    /// Whether to parse images
41    pub parse_image: Option<bool>,
42    /// Callback URL
43    #[validate(url)]
44    pub callback_url: Option<String>,
45    /// Callback headers
46    pub callback_header: Option<BTreeMap<String, String>>,
47    /// Document word number limit (must be numeric string per API)
48    pub word_num_limit: Option<String>,
49    /// Request id
50    #[validate(length(min = 1))]
51    pub req_id: Option<String>,
52}
53
54/// File upload request (multipart/form-data)
55pub struct DocumentUploadFileRequest {
56    /// Bearer API key
57    pub key: String,
58    url: String,
59    files: Vec<PathBuf>,
60    options: UploadFileOptions,
61}
62
63impl DocumentUploadFileRequest {
64    /// Create a new request for a specific knowledge base id
65    pub fn new(key: String, knowledge_id: impl AsRef<str>) -> Self {
66        let url = format!(
67            "https://open.bigmodel.cn/api/llm-application/open/document/upload_document/{}",
68            knowledge_id.as_ref()
69        );
70        Self {
71            key,
72            url,
73            files: Vec::new(),
74            options: UploadFileOptions::default(),
75        }
76    }
77
78    /// Add a local file path to upload
79    pub fn add_file_path(mut self, path: impl Into<PathBuf>) -> Self {
80        self.files.push(path.into());
81        self
82    }
83
84    /// Set optional parameters
85    pub fn with_options(mut self, opts: UploadFileOptions) -> Self {
86        self.options = opts;
87        self
88    }
89
90    /// Mutable access to options for incremental configuration
91    pub fn options_mut(&mut self) -> &mut UploadFileOptions {
92        &mut self.options
93    }
94
95    /// Validate cross-field constraints not expressible via `validator`
96    fn validate_cross(&self) -> crate::ZaiResult<()> {
97        // When knowledge_type is Custom (5), sentence_size should be within 20..=2000
98        if let Some(DocumentSliceType::Custom) = self.options.knowledge_type {
99            // sentence_size recommended; API shows default 300; we ensure range if provided
100            if let Some(sz) = self.options.sentence_size
101                && !(20..=2000).contains(&sz)
102            {
103                return Err(crate::client::error::ZaiError::ApiError {
104                    code: 1200,
105                    message: "sentence_size must be 20..=2000 when knowledge_type=Custom (5)"
106                        .to_string(),
107                });
108            }
109        }
110        if let Some(ref w) = self.options.word_num_limit
111            && !w.chars().all(|c| c.is_ascii_digit())
112        {
113            return Err(crate::client::error::ZaiError::ApiError {
114                code: 1200,
115                message: "word_num_limit must be numeric string".to_string(),
116            });
117        }
118        if self.files.is_empty() {
119            return Err(crate::client::error::ZaiError::ApiError {
120                code: 1200,
121                message: "at least one file path must be provided".to_string(),
122            });
123        }
124        Ok(())
125    }
126
127    /// Send multipart request and parse typed response
128    pub async fn send(&self) -> crate::ZaiResult<UploadFileResponse> {
129        // Field validations
130        self.options.validate()?;
131        self.validate_cross()?;
132
133        let resp = self.post().await?;
134        let parsed = resp.json::<UploadFileResponse>().await?;
135        Ok(parsed)
136    }
137}
138
139impl HttpClient for DocumentUploadFileRequest {
140    type Body = (); // unused
141    type ApiUrl = String;
142    type ApiKey = String;
143
144    fn api_url(&self) -> &Self::ApiUrl {
145        &self.url
146    }
147    fn api_key(&self) -> &Self::ApiKey {
148        &self.key
149    }
150    fn body(&self) -> &Self::Body {
151        &()
152    }
153
154    // Override POST to send multipart/form-data
155
156    fn post(
157        &self,
158    ) -> impl std::future::Future<Output = crate::ZaiResult<reqwest::Response>> + Send {
159        let url = self.url.clone();
160        let key = self.key.clone();
161        let files = self.files.clone();
162        let opts = self.options.clone();
163        async move {
164            let mut form = reqwest::multipart::Form::new();
165
166            // Optional fields
167            if let Some(t) = opts.knowledge_type {
168                form = form.text("knowledge_type", t.as_i64().to_string());
169            }
170            if let Some(seps) = opts.custom_separator.as_ref() {
171                let s = serde_json::to_string(seps).unwrap_or("[]".to_string());
172                form = form.text("custom_separator", s);
173            }
174            if let Some(sz) = opts.sentence_size {
175                form = form.text("sentence_size", sz.to_string());
176            }
177            if let Some(pi) = opts.parse_image {
178                form = form.text("parse_image", if pi { "true" } else { "false" }.to_string());
179            }
180            if let Some(u) = opts.callback_url.as_ref() {
181                form = form.text("callback_url", u.clone());
182            }
183            if let Some(h) = opts.callback_header.as_ref() {
184                let s = serde_json::to_string(h).unwrap_or("{}".to_string());
185                form = form.text("callback_header", s);
186            }
187            if let Some(w) = opts.word_num_limit.as_ref() {
188                form = form.text("word_num_limit", w.clone());
189            }
190            if let Some(r) = opts.req_id.as_ref() {
191                form = form.text("req_id", r.clone());
192            }
193
194            // Files: use field name "files" per API
195            for path in files {
196                let fname = path
197                    .file_name()
198                    .and_then(|s| s.to_str())
199                    .map(|s| s.to_string())
200                    .unwrap_or_else(|| "upload.bin".to_string());
201                let part = reqwest::multipart::Part::bytes(std::fs::read(&path)?).file_name(fname);
202                form = form.part("files", part);
203            }
204
205            let resp = reqwest::Client::new()
206                .post(url)
207                .bearer_auth(key)
208                .multipart(form)
209                .send()
210                .await?;
211
212            let status = resp.status();
213            if status.is_success() {
214                return Ok(resp);
215            }
216
217            // Standard error envelope {"error": { code, message }}
218            let text = resp.text().await.unwrap_or_default();
219            #[derive(serde::Deserialize)]
220            struct ErrEnv {
221                error: ErrObj,
222            }
223            #[derive(serde::Deserialize)]
224            struct ErrObj {
225                _code: serde_json::Value,
226                message: String,
227            }
228
229            if let Ok(parsed) = serde_json::from_str::<ErrEnv>(&text) {
230                Err(crate::client::error::ZaiError::from_api_response(
231                    status.as_u16(),
232                    0,
233                    parsed.error.message,
234                ))
235            } else {
236                Err(crate::client::error::ZaiError::from_api_response(
237                    status.as_u16(),
238                    0,
239                    text,
240                ))
241            }
242        }
243    }
244}