zai_rs/knowledge/
document_upload_file.rs1use std::collections::BTreeMap;
2use std::path::PathBuf;
3
4use crate::client::http::HttpClient;
5use validator::Validate;
6
7use super::types::UploadFileResponse;
8
/// Slicing strategy applied to an uploaded document.
///
/// Each variant has an explicit integer discriminant, which is what gets
/// written to the `knowledge_type` multipart form field. The codes are not
/// contiguous: no variant uses the value `4`.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names; confirm exact semantics against the service documentation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DocumentSliceType {
    /// Title/paragraph based slicing (code `1`).
    TitleParagraph = 1,
    /// Question/answer pair slicing (code `2`).
    QaPair = 2,
    /// Line based slicing (code `3`).
    Line = 3,
    /// Custom slicing (code `5`).
    Custom = 5,
    /// Page based slicing (code `6`).
    Page = 6,
    /// Single-slice mode (code `7`).
    Single = 7,
}

impl DocumentSliceType {
    /// Returns the variant's integer discriminant as an `i64`.
    const fn as_i64(self) -> i64 {
        self as i64
    }
}
30
31#[derive(Debug, Clone, Default, Validate)]
33pub struct UploadFileOptions {
34 pub knowledge_type: Option<DocumentSliceType>,
36 pub custom_separator: Option<Vec<String>>,
38 #[validate(range(min = 20, max = 2000))]
40 pub sentence_size: Option<u32>,
41 pub parse_image: Option<bool>,
43 #[validate(url)]
45 pub callback_url: Option<String>,
46 pub callback_header: Option<BTreeMap<String, String>>,
48 pub word_num_limit: Option<String>,
50 #[validate(length(min = 1))]
52 pub req_id: Option<String>,
53}
54
55pub struct DocumentUploadFileRequest {
57 pub key: String,
59 url: String,
60 files: Vec<PathBuf>,
61 options: UploadFileOptions,
62}
63
64impl DocumentUploadFileRequest {
65 pub fn new(key: String, knowledge_id: impl AsRef<str>) -> Self {
67 let url = format!(
68 "https://open.bigmodel.cn/api/llm-application/open/document/upload_document/{}",
69 knowledge_id.as_ref()
70 );
71 Self {
72 key,
73 url,
74 files: Vec::new(),
75 options: UploadFileOptions::default(),
76 }
77 }
78
79 pub fn add_file_path(mut self, path: impl Into<PathBuf>) -> Self {
81 self.files.push(path.into());
82 self
83 }
84
85 pub fn with_options(mut self, opts: UploadFileOptions) -> Self {
87 self.options = opts;
88 self
89 }
90
91 pub fn options_mut(&mut self) -> &mut UploadFileOptions {
93 &mut self.options
94 }
95
96 fn validate_cross(&self) -> anyhow::Result<()> {
98 if let Some(DocumentSliceType::Custom) = self.options.knowledge_type {
100 if let Some(sz) = self.options.sentence_size {
102 if !(20..=2000).contains(&sz) {
103 anyhow::bail!("sentence_size must be 20..=2000 when knowledge_type=Custom (5)");
104 }
105 }
106 }
107 if let Some(ref w) = self.options.word_num_limit {
108 if !w.chars().all(|c| c.is_ascii_digit()) {
109 anyhow::bail!("word_num_limit must be numeric string");
110 }
111 }
112 if self.files.is_empty() {
113 anyhow::bail!("at least one file path must be provided");
114 }
115 Ok(())
116 }
117
118 pub async fn send(&self) -> anyhow::Result<UploadFileResponse> {
120 self.options.validate()?;
122 self.validate_cross()?;
123
124 let resp = self.post().await?;
125 let parsed = resp.json::<UploadFileResponse>().await?;
126 Ok(parsed)
127 }
128}
129
130impl HttpClient for DocumentUploadFileRequest {
131 type Body = (); type ApiUrl = String;
133 type ApiKey = String;
134
135 fn api_url(&self) -> &Self::ApiUrl {
136 &self.url
137 }
138 fn api_key(&self) -> &Self::ApiKey {
139 &self.key
140 }
141 fn body(&self) -> &Self::Body {
142 &()
143 }
144
145 fn post(&self) -> impl std::future::Future<Output = anyhow::Result<reqwest::Response>> + Send {
147 let url = self.url.clone();
148 let key = self.key.clone();
149 let files = self.files.clone();
150 let opts = self.options.clone();
151 async move {
152 let mut form = reqwest::multipart::Form::new();
153
154 if let Some(t) = opts.knowledge_type {
156 form = form.text("knowledge_type", t.as_i64().to_string());
157 }
158 if let Some(seps) = opts.custom_separator.as_ref() {
159 let s = serde_json::to_string(seps).unwrap_or("[]".to_string());
160 form = form.text("custom_separator", s);
161 }
162 if let Some(sz) = opts.sentence_size {
163 form = form.text("sentence_size", sz.to_string());
164 }
165 if let Some(pi) = opts.parse_image {
166 form = form.text("parse_image", if pi { "true" } else { "false" }.to_string());
167 }
168 if let Some(u) = opts.callback_url.as_ref() {
169 form = form.text("callback_url", u.clone());
170 }
171 if let Some(h) = opts.callback_header.as_ref() {
172 let s = serde_json::to_string(h).unwrap_or("{}".to_string());
173 form = form.text("callback_header", s);
174 }
175 if let Some(w) = opts.word_num_limit.as_ref() {
176 form = form.text("word_num_limit", w.clone());
177 }
178 if let Some(r) = opts.req_id.as_ref() {
179 form = form.text("req_id", r.clone());
180 }
181
182 for path in files {
184 let fname = path
185 .file_name()
186 .and_then(|s| s.to_str())
187 .map(|s| s.to_string())
188 .unwrap_or_else(|| "upload.bin".to_string());
189 let part = reqwest::multipart::Part::bytes(std::fs::read(&path)?).file_name(fname);
190 form = form.part("files", part);
191 }
192
193 let resp = reqwest::Client::new()
194 .post(url)
195 .bearer_auth(key)
196 .multipart(form)
197 .send()
198 .await?;
199
200 let status = resp.status();
201 if status.is_success() {
202 return Ok(resp);
203 }
204
205 let text = resp.text().await.unwrap_or_default();
207 #[derive(serde::Deserialize)]
208 struct ErrEnv {
209 error: ErrObj,
210 }
211 #[derive(serde::Deserialize)]
212 struct ErrObj {
213 code: serde_json::Value,
214 message: String,
215 }
216 if let Ok(parsed) = serde_json::from_str::<ErrEnv>(&text) {
217 return Err(anyhow::anyhow!(
218 "HTTP {} {} | code={} | message={}",
219 status.as_u16(),
220 status.canonical_reason().unwrap_or(""),
221 parsed.error.code,
222 parsed.error.message
223 ));
224 }
225 Err(anyhow::anyhow!(
226 "HTTP {} {} | body={}",
227 status.as_u16(),
228 status.canonical_reason().unwrap_or(""),
229 text
230 ))
231 }
232 }
233}