zai_rs/knowledge/
document_upload_file.rs

1use std::collections::BTreeMap;
2use std::path::PathBuf;
3
4use crate::client::http::HttpClient;
5use validator::Validate;
6
7use super::types::UploadFileResponse;
8
/// Slice type, sent to the API as the `knowledge_type` form field.
///
/// Each variant's explicit discriminant is the numeric code that goes on the
/// wire. The codes are not contiguous: there is no variant for `4`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DocumentSliceType {
    /// 1: Title-paragraph slicing (txt, doc, pdf, url, docx, ppt, pptx, md)
    TitleParagraph = 1,
    /// 2: Q&A slicing (txt, doc, pdf, url, docx, ppt, pptx, md)
    QaPair = 2,
    /// 3: Line slicing (xls, xlsx, csv)
    Line = 3,
    /// 5: Custom slicing (txt, doc, pdf, url, docx, ppt, pptx, md)
    Custom = 5,
    /// 6: Page slicing (pdf, ppt, pptx)
    Page = 6,
    /// 7: Single slice (xls, xlsx, csv)
    Single = 7,
}

impl DocumentSliceType {
    /// Returns the numeric `knowledge_type` code of this variant.
    fn as_i64(self) -> i64 {
        // Fieldless enum with explicit discriminants: the cast yields exactly
        // the value declared on the variant.
        self as i64
    }
}
30
31/// Optional parameters for file upload
32#[derive(Debug, Clone, Default, Validate)]
33pub struct UploadFileOptions {
34    /// Document type; if omitted, the server parses dynamically
35    pub knowledge_type: Option<DocumentSliceType>,
36    /// Custom slicing rules; used when knowledge_type = 5
37    pub custom_separator: Option<Vec<String>>,
38    /// Custom slice size; used when knowledge_type = 5; valid range: 20..=2000
39    #[validate(range(min = 20, max = 2000))]
40    pub sentence_size: Option<u32>,
41    /// Whether to parse images
42    pub parse_image: Option<bool>,
43    /// Callback URL
44    #[validate(url)]
45    pub callback_url: Option<String>,
46    /// Callback headers
47    pub callback_header: Option<BTreeMap<String, String>>,
48    /// Document word number limit (must be numeric string per API)
49    pub word_num_limit: Option<String>,
50    /// Request id
51    #[validate(length(min = 1))]
52    pub req_id: Option<String>,
53}
54
55/// File upload request (multipart/form-data)
56pub struct DocumentUploadFileRequest {
57    /// Bearer API key
58    pub key: String,
59    url: String,
60    files: Vec<PathBuf>,
61    options: UploadFileOptions,
62}
63
64impl DocumentUploadFileRequest {
65    /// Create a new request for a specific knowledge base id
66    pub fn new(key: String, knowledge_id: impl AsRef<str>) -> Self {
67        let url = format!(
68            "https://open.bigmodel.cn/api/llm-application/open/document/upload_document/{}",
69            knowledge_id.as_ref()
70        );
71        Self {
72            key,
73            url,
74            files: Vec::new(),
75            options: UploadFileOptions::default(),
76        }
77    }
78
79    /// Add a local file path to upload
80    pub fn add_file_path(mut self, path: impl Into<PathBuf>) -> Self {
81        self.files.push(path.into());
82        self
83    }
84
85    /// Set optional parameters
86    pub fn with_options(mut self, opts: UploadFileOptions) -> Self {
87        self.options = opts;
88        self
89    }
90
91    /// Mutable access to options for incremental configuration
92    pub fn options_mut(&mut self) -> &mut UploadFileOptions {
93        &mut self.options
94    }
95
96    /// Validate cross-field constraints not expressible via `validator`
97    fn validate_cross(&self) -> anyhow::Result<()> {
98        // When knowledge_type is Custom (5), sentence_size should be within 20..=2000
99        if let Some(DocumentSliceType::Custom) = self.options.knowledge_type {
100            // sentence_size recommended; API shows default 300; we ensure range if provided
101            if let Some(sz) = self.options.sentence_size {
102                if !(20..=2000).contains(&sz) {
103                    anyhow::bail!("sentence_size must be 20..=2000 when knowledge_type=Custom (5)");
104                }
105            }
106        }
107        if let Some(ref w) = self.options.word_num_limit {
108            if !w.chars().all(|c| c.is_ascii_digit()) {
109                anyhow::bail!("word_num_limit must be numeric string");
110            }
111        }
112        if self.files.is_empty() {
113            anyhow::bail!("at least one file path must be provided");
114        }
115        Ok(())
116    }
117
118    /// Send multipart request and parse typed response
119    pub async fn send(&self) -> anyhow::Result<UploadFileResponse> {
120        // Field validations
121        self.options.validate()?;
122        self.validate_cross()?;
123
124        let resp = self.post().await?;
125        let parsed = resp.json::<UploadFileResponse>().await?;
126        Ok(parsed)
127    }
128}
129
130impl HttpClient for DocumentUploadFileRequest {
131    type Body = (); // unused
132    type ApiUrl = String;
133    type ApiKey = String;
134
135    fn api_url(&self) -> &Self::ApiUrl {
136        &self.url
137    }
138    fn api_key(&self) -> &Self::ApiKey {
139        &self.key
140    }
141    fn body(&self) -> &Self::Body {
142        &()
143    }
144
145    // Override POST to send multipart/form-data
146    fn post(&self) -> impl std::future::Future<Output = anyhow::Result<reqwest::Response>> + Send {
147        let url = self.url.clone();
148        let key = self.key.clone();
149        let files = self.files.clone();
150        let opts = self.options.clone();
151        async move {
152            let mut form = reqwest::multipart::Form::new();
153
154            // Optional fields
155            if let Some(t) = opts.knowledge_type {
156                form = form.text("knowledge_type", t.as_i64().to_string());
157            }
158            if let Some(seps) = opts.custom_separator.as_ref() {
159                let s = serde_json::to_string(seps).unwrap_or("[]".to_string());
160                form = form.text("custom_separator", s);
161            }
162            if let Some(sz) = opts.sentence_size {
163                form = form.text("sentence_size", sz.to_string());
164            }
165            if let Some(pi) = opts.parse_image {
166                form = form.text("parse_image", if pi { "true" } else { "false" }.to_string());
167            }
168            if let Some(u) = opts.callback_url.as_ref() {
169                form = form.text("callback_url", u.clone());
170            }
171            if let Some(h) = opts.callback_header.as_ref() {
172                let s = serde_json::to_string(h).unwrap_or("{}".to_string());
173                form = form.text("callback_header", s);
174            }
175            if let Some(w) = opts.word_num_limit.as_ref() {
176                form = form.text("word_num_limit", w.clone());
177            }
178            if let Some(r) = opts.req_id.as_ref() {
179                form = form.text("req_id", r.clone());
180            }
181
182            // Files: use field name "files" per API
183            for path in files {
184                let fname = path
185                    .file_name()
186                    .and_then(|s| s.to_str())
187                    .map(|s| s.to_string())
188                    .unwrap_or_else(|| "upload.bin".to_string());
189                let part = reqwest::multipart::Part::bytes(std::fs::read(&path)?).file_name(fname);
190                form = form.part("files", part);
191            }
192
193            let resp = reqwest::Client::new()
194                .post(url)
195                .bearer_auth(key)
196                .multipart(form)
197                .send()
198                .await?;
199
200            let status = resp.status();
201            if status.is_success() {
202                return Ok(resp);
203            }
204
205            // Standard error envelope {"error": { code, message }}
206            let text = resp.text().await.unwrap_or_default();
207            #[derive(serde::Deserialize)]
208            struct ErrEnv {
209                error: ErrObj,
210            }
211            #[derive(serde::Deserialize)]
212            struct ErrObj {
213                code: serde_json::Value,
214                message: String,
215            }
216            if let Ok(parsed) = serde_json::from_str::<ErrEnv>(&text) {
217                return Err(anyhow::anyhow!(
218                    "HTTP {} {} | code={} | message={}",
219                    status.as_u16(),
220                    status.canonical_reason().unwrap_or(""),
221                    parsed.error.code,
222                    parsed.error.message
223                ));
224            }
225            Err(anyhow::anyhow!(
226                "HTTP {} {} | body={}",
227                status.as_u16(),
228                status.canonical_reason().unwrap_or(""),
229                text
230            ))
231        }
232    }
233}