zai_rs/knowledge/
document_upload_file.rs1use std::{collections::BTreeMap, path::PathBuf};
2
3use validator::Validate;
4
5use super::types::UploadFileResponse;
6use crate::client::http::HttpClient;
7
/// Document slicing strategy, sent as the numeric `knowledge_type` form field.
///
/// The explicit discriminants are the wire codes; note that code 4 is not used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DocumentSliceType {
    /// Code 1 — presumably slices by title/paragraph (confirm against API docs).
    TitleParagraph = 1,
    /// Code 2 — presumably question/answer pairs (confirm against API docs).
    QaPair = 2,
    /// Code 3 — presumably one slice per line (confirm against API docs).
    Line = 3,
    /// Code 5 — custom slicing; `sentence_size` is range-checked for this mode
    /// before sending (presumably also the mode that uses `custom_separator`).
    Custom = 5,
    /// Code 6 — presumably one slice per page (confirm against API docs).
    Page = 6,
    /// Code 7 — presumably the whole document as a single slice (confirm).
    Single = 7,
}

impl DocumentSliceType {
    /// Returns the numeric wire code, i.e. the enum discriminant.
    fn as_i64(self) -> i64 {
        self as i64
    }
}

30#[derive(Debug, Clone, Default, Validate)]
32pub struct UploadFileOptions {
33 pub knowledge_type: Option<DocumentSliceType>,
35 pub custom_separator: Option<Vec<String>>,
37 #[validate(range(min = 20, max = 2000))]
39 pub sentence_size: Option<u32>,
40 pub parse_image: Option<bool>,
42 #[validate(url)]
44 pub callback_url: Option<String>,
45 pub callback_header: Option<BTreeMap<String, String>>,
47 pub word_num_limit: Option<String>,
49 #[validate(length(min = 1))]
51 pub req_id: Option<String>,
52}
53
54pub struct DocumentUploadFileRequest {
56 pub key: String,
58 url: String,
59 files: Vec<PathBuf>,
60 options: UploadFileOptions,
61}
62
63impl DocumentUploadFileRequest {
64 pub fn new(key: String, knowledge_id: impl AsRef<str>) -> Self {
66 let url = format!(
67 "https://open.bigmodel.cn/api/llm-application/open/document/upload_document/{}",
68 knowledge_id.as_ref()
69 );
70 Self {
71 key,
72 url,
73 files: Vec::new(),
74 options: UploadFileOptions::default(),
75 }
76 }
77
78 pub fn add_file_path(mut self, path: impl Into<PathBuf>) -> Self {
80 self.files.push(path.into());
81 self
82 }
83
84 pub fn with_options(mut self, opts: UploadFileOptions) -> Self {
86 self.options = opts;
87 self
88 }
89
90 pub fn options_mut(&mut self) -> &mut UploadFileOptions {
92 &mut self.options
93 }
94
95 fn validate_cross(&self) -> crate::ZaiResult<()> {
97 if let Some(DocumentSliceType::Custom) = self.options.knowledge_type {
99 if let Some(sz) = self.options.sentence_size
101 && !(20..=2000).contains(&sz)
102 {
103 return Err(crate::client::error::ZaiError::ApiError {
104 code: 1200,
105 message: "sentence_size must be 20..=2000 when knowledge_type=Custom (5)"
106 .to_string(),
107 });
108 }
109 }
110 if let Some(ref w) = self.options.word_num_limit
111 && !w.chars().all(|c| c.is_ascii_digit())
112 {
113 return Err(crate::client::error::ZaiError::ApiError {
114 code: 1200,
115 message: "word_num_limit must be numeric string".to_string(),
116 });
117 }
118 if self.files.is_empty() {
119 return Err(crate::client::error::ZaiError::ApiError {
120 code: 1200,
121 message: "at least one file path must be provided".to_string(),
122 });
123 }
124 Ok(())
125 }
126
127 pub async fn send(&self) -> crate::ZaiResult<UploadFileResponse> {
129 self.options.validate()?;
131 self.validate_cross()?;
132
133 let resp = self.post().await?;
134 let parsed = resp.json::<UploadFileResponse>().await?;
135 Ok(parsed)
136 }
137}
138
139impl HttpClient for DocumentUploadFileRequest {
140 type Body = (); type ApiUrl = String;
142 type ApiKey = String;
143
144 fn api_url(&self) -> &Self::ApiUrl {
145 &self.url
146 }
147 fn api_key(&self) -> &Self::ApiKey {
148 &self.key
149 }
150 fn body(&self) -> &Self::Body {
151 &()
152 }
153
154 fn post(
157 &self,
158 ) -> impl std::future::Future<Output = crate::ZaiResult<reqwest::Response>> + Send {
159 let url = self.url.clone();
160 let key = self.key.clone();
161 let files = self.files.clone();
162 let opts = self.options.clone();
163 async move {
164 let mut form = reqwest::multipart::Form::new();
165
166 if let Some(t) = opts.knowledge_type {
168 form = form.text("knowledge_type", t.as_i64().to_string());
169 }
170 if let Some(seps) = opts.custom_separator.as_ref() {
171 let s = serde_json::to_string(seps).unwrap_or("[]".to_string());
172 form = form.text("custom_separator", s);
173 }
174 if let Some(sz) = opts.sentence_size {
175 form = form.text("sentence_size", sz.to_string());
176 }
177 if let Some(pi) = opts.parse_image {
178 form = form.text("parse_image", if pi { "true" } else { "false" }.to_string());
179 }
180 if let Some(u) = opts.callback_url.as_ref() {
181 form = form.text("callback_url", u.clone());
182 }
183 if let Some(h) = opts.callback_header.as_ref() {
184 let s = serde_json::to_string(h).unwrap_or("{}".to_string());
185 form = form.text("callback_header", s);
186 }
187 if let Some(w) = opts.word_num_limit.as_ref() {
188 form = form.text("word_num_limit", w.clone());
189 }
190 if let Some(r) = opts.req_id.as_ref() {
191 form = form.text("req_id", r.clone());
192 }
193
194 for path in files {
196 let fname = path
197 .file_name()
198 .and_then(|s| s.to_str())
199 .map(|s| s.to_string())
200 .unwrap_or_else(|| "upload.bin".to_string());
201 let part = reqwest::multipart::Part::bytes(std::fs::read(&path)?).file_name(fname);
202 form = form.part("files", part);
203 }
204
205 let resp = reqwest::Client::new()
206 .post(url)
207 .bearer_auth(key)
208 .multipart(form)
209 .send()
210 .await?;
211
212 let status = resp.status();
213 if status.is_success() {
214 return Ok(resp);
215 }
216
217 let text = resp.text().await.unwrap_or_default();
219 #[derive(serde::Deserialize)]
220 struct ErrEnv {
221 error: ErrObj,
222 }
223 #[derive(serde::Deserialize)]
224 struct ErrObj {
225 _code: serde_json::Value,
226 message: String,
227 }
228
229 if let Ok(parsed) = serde_json::from_str::<ErrEnv>(&text) {
230 Err(crate::client::error::ZaiError::from_api_response(
231 status.as_u16(),
232 0,
233 parsed.error.message,
234 ))
235 } else {
236 Err(crate::client::error::ZaiError::from_api_response(
237 status.as_u16(),
238 0,
239 text,
240 ))
241 }
242 }
243 }
244}