1use crate::ElementList;
2use reqwest::multipart::Form;
3use serde::{Deserialize, Serialize};
4
5#[derive(Serialize, Deserialize, Debug, PartialEq)]
7#[serde(rename_all = "snake_case")]
8pub enum ChunkingStrategy {
9 Basic,
10 ByPage,
11 BySimilarity,
12 ByTitle,
13}
14
15#[derive(Serialize, Deserialize, Debug, PartialEq)]
17#[serde(rename_all = "snake_case")]
18pub enum Strategy {
19 Fast,
20 HiRes,
21 Auto,
22 OcrOnly,
23}
24
25#[derive(Serialize, Deserialize, Debug, PartialEq)]
27pub enum OutputFormat {
28 #[serde(rename = "application/json")]
29 ApplicationJson,
30
31 #[serde(rename = "text/csv")]
32 TextCsv,
33}
34
35#[derive(Debug, Serialize, Deserialize)]
36pub struct PartitionParameters {
37 pub coordinates: bool,
39
40 pub encoding: Option<String>,
42
43 pub extract_image_block_types: Vec<String>,
45
46 pub gz_uncompressed_content_type: Option<String>,
48
49 pub hi_res_model_name: Option<String>,
51
52 pub include_page_breaks: bool,
54
55 pub languages: Option<Vec<String>>,
57
58 pub output_format: String,
60
61 pub skip_infer_table_types: Vec<String>,
63
64 pub starting_page_number: Option<i32>,
66
67 pub strategy: Strategy,
69
70 pub unique_element_ids: bool,
72
73 pub xml_keep_tags: bool,
75
76 pub chunking_strategy: Option<ChunkingStrategy>,
78
79 pub combine_under_n_chars: Option<i32>,
81
82 pub include_orig_elements: bool,
84
85 pub max_characters: Option<i32>,
87
88 pub multipage_sections: bool,
90
91 pub new_after_n_chars: Option<i32>,
93
94 pub overlap: i32,
96
97 pub overlap_all: bool,
99
100 pub similarity_threshold: Option<f64>,
102}
103
104impl Default for PartitionParameters {
105 fn default() -> Self {
106 PartitionParameters {
107 coordinates: false,
108 encoding: Some("utf-8".to_string()),
109 extract_image_block_types: vec![],
110 gz_uncompressed_content_type: None,
111 hi_res_model_name: None,
112 include_page_breaks: false,
113 languages: None,
114 output_format: "application/json".to_string(),
115 skip_infer_table_types: vec![],
116 starting_page_number: None,
117 strategy: Strategy::Auto,
118 unique_element_ids: false,
119 xml_keep_tags: false,
120 chunking_strategy: None,
121 combine_under_n_chars: None,
122 include_orig_elements: true,
123 max_characters: None,
124 multipage_sections: true,
125 new_after_n_chars: None,
126 overlap: 0,
127 overlap_all: false,
128 similarity_threshold: None,
129 }
130 }
131}
132
133impl From<PartitionParameters> for Form {
134 fn from(value: PartitionParameters) -> Self {
135 let mut form = Form::new();
136 form = form.text("coordinates", value.coordinates.to_string());
137 if let Some(encoding) = value.encoding.clone() {
138 form = form.text("encoding", encoding);
139 }
140 form = form.text(
141 "extract_image_block_types",
142 serde_json::to_string(&value.extract_image_block_types).unwrap(),
143 );
144 if let Some(gz_uncompressed_content_type) = value.gz_uncompressed_content_type.clone() {
145 form = form.text("gz_uncompressed_content_type", gz_uncompressed_content_type);
146 }
147 if let Some(hi_res_model_name) = value.hi_res_model_name.clone() {
148 form = form.text("hi_res_model_name", hi_res_model_name);
149 }
150 form = form.text("include_page_breaks", value.include_page_breaks.to_string());
151 if let Some(languages) = value.languages.clone() {
152 form = form.text("languages", serde_json::to_string(&languages).unwrap());
153 }
154 form = form.text("output_format", value.output_format.clone());
155 form = form.text(
156 "skip_infer_table_types",
157 serde_json::to_string(&value.skip_infer_table_types).unwrap(),
158 );
159 if let Some(starting_page_number) = value.starting_page_number {
160 form = form.text("starting_page_number", starting_page_number.to_string());
161 }
162 form = form.text("strategy", {
163 let s = String::from(
164 serde_json::to_string(&value.strategy)
165 .expect("Could not convert Strategy enum to string.")
166 .trim_matches('"'),
167 );
168 s
169 });
170 form = form.text("unique_element_ids", value.unique_element_ids.to_string());
171 form = form.text("xml_keep_tags", value.xml_keep_tags.to_string());
172 if let Some(chunking_strategy) = value
173 .chunking_strategy
174 .as_ref()
175 .map(serde_json::to_string)
176 .transpose()
177 .expect("Could not convert Chunking Strategy enum to string.")
178 {
179 form = form.text(
180 "chunking_strategy",
181 chunking_strategy.trim_matches('"').to_string(),
182 );
183 }
184 if let Some(combine_under_n_chars) = value.combine_under_n_chars {
185 form = form.text("combine_under_n_chars", combine_under_n_chars.to_string());
186 }
187 form = form.text(
188 "include_orig_elements",
189 value.include_orig_elements.to_string(),
190 );
191 if let Some(max_characters) = value.max_characters {
192 form = form.text("max_characters", max_characters.to_string());
193 }
194 form = form.text("multipage_sections", value.multipage_sections.to_string());
195 if let Some(new_after_n_chars) = value.new_after_n_chars {
196 form = form.text("new_after_n_chars", new_after_n_chars.to_string());
197 }
198 form = form.text("overlap", value.overlap.to_string());
199 form = form.text("overlap_all", value.overlap_all.to_string());
200 form
201 }
202}
203
204#[derive(Serialize, Deserialize, Debug)]
205#[serde(untagged)]
206pub enum LocElement {
207 Str(String),
208 Int(i64),
209}
210
211#[derive(Serialize, Deserialize, Debug)]
212pub struct ValidationError {
213 pub loc: Vec<LocElement>,
214 pub msg: String,
215 pub r#type: String,
216}
217
218#[derive(Debug, Deserialize, Serialize)]
219#[serde(untagged)]
220pub enum PartitionResponse {
221 Success(ElementList),
223
224 ValidationFailure(ValidationError),
226
227 UnknownFailure(serde_json::Value),
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Every request sends these unless overridden, so pin each default value
    /// (this test previously only printed the struct and asserted nothing).
    #[test]
    fn test_default_partition_params() {
        let params = PartitionParameters::default();
        assert!(!params.coordinates);
        assert_eq!(params.encoding.as_deref(), Some("utf-8"));
        assert!(params.extract_image_block_types.is_empty());
        assert!(params.gz_uncompressed_content_type.is_none());
        assert!(params.hi_res_model_name.is_none());
        assert!(!params.include_page_breaks);
        assert!(params.languages.is_none());
        assert_eq!(params.output_format, "application/json");
        assert!(params.skip_infer_table_types.is_empty());
        assert!(params.starting_page_number.is_none());
        assert_eq!(params.strategy, Strategy::Auto);
        assert!(!params.unique_element_ids);
        assert!(!params.xml_keep_tags);
        assert!(params.chunking_strategy.is_none());
        assert!(params.combine_under_n_chars.is_none());
        assert!(params.include_orig_elements);
        assert!(params.max_characters.is_none());
        assert!(params.multipage_sections);
        assert!(params.new_after_n_chars.is_none());
        assert_eq!(params.overlap, 0);
        assert!(!params.overlap_all);
        assert!(params.similarity_threshold.is_none());
    }

    #[test]
    fn test_deserialize_chunking_strategy() {
        let json = r#""basic""#;
        let strategy: ChunkingStrategy = serde_json::from_str(json).unwrap();
        assert_eq!(strategy, ChunkingStrategy::Basic);
    }

    #[test]
    fn test_deserialize_strategy() {
        let json = r#""auto""#;
        let strategy: Strategy = serde_json::from_str(json).unwrap();
        assert_eq!(strategy, Strategy::Auto);
    }

    #[test]
    fn test_deserialize_output_format() {
        let json = r#""application/json""#;
        let format: OutputFormat = serde_json::from_str(json).unwrap();
        assert_eq!(format, OutputFormat::ApplicationJson);
    }

    /// Multi-word variants must use snake_case on the wire.
    #[test]
    fn test_serialize_strategy_names() {
        assert_eq!(serde_json::to_string(&Strategy::HiRes).unwrap(), r#""hi_res""#);
        assert_eq!(serde_json::to_string(&Strategy::OcrOnly).unwrap(), r#""ocr_only""#);
    }

    #[test]
    fn test_serialize_chunking_strategy_names() {
        assert_eq!(
            serde_json::to_string(&ChunkingStrategy::ByPage).unwrap(),
            r#""by_page""#
        );
        assert_eq!(
            serde_json::to_string(&ChunkingStrategy::BySimilarity).unwrap(),
            r#""by_similarity""#
        );
    }

    #[test]
    fn test_serialize_output_format() {
        assert_eq!(
            serde_json::to_string(&OutputFormat::TextCsv).unwrap(),
            r#""text/csv""#
        );
    }

    #[test]
    fn test_deserialize_partition_parameters() {
        let json = r#"{
            "coordinates": true,
            "encoding": "utf-8",
            "extract_image_block_types": [],
            "gz_uncompressed_content_type": null,
            "hi_res_model_name": null,
            "include_page_breaks": true,
            "languages": null,
            "output_format": "application/json",
            "skip_infer_table_types": [],
            "starting_page_number": null,
            "strategy": "auto",
            "unique_element_ids": false,
            "xml_keep_tags": false,
            "chunking_strategy": null,
            "combine_under_n_chars": null,
            "include_orig_elements": true,
            "max_characters": null,
            "multipage_sections": true,
            "new_after_n_chars": null,
            "overlap": 0,
            "overlap_all": false,
            "similarity_threshold": null
        }"#;
        let params: PartitionParameters = serde_json::from_str(json).unwrap();
        assert!(params.coordinates);
        assert_eq!(params.encoding.as_deref(), Some("utf-8"));
        assert!(params.include_page_breaks);
        assert_eq!(params.output_format, "application/json");
        assert_eq!(params.strategy, Strategy::Auto);
        assert!(params.include_orig_elements);
        assert!(params.multipage_sections);
        assert_eq!(params.overlap, 0);
        assert!(!params.overlap_all);
        assert!(params.similarity_threshold.is_none());
    }

    /// Non-default chunking settings survive a JSON round trip.
    #[test]
    fn test_partition_parameters_round_trip() {
        let original = PartitionParameters {
            chunking_strategy: Some(ChunkingStrategy::BySimilarity),
            similarity_threshold: Some(0.5),
            strategy: Strategy::HiRes,
            ..PartitionParameters::default()
        };
        let json = serde_json::to_string(&original).unwrap();
        let parsed: PartitionParameters = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.chunking_strategy, Some(ChunkingStrategy::BySimilarity));
        assert_eq!(parsed.similarity_threshold, Some(0.5));
        assert_eq!(parsed.strategy, Strategy::HiRes);
        assert_eq!(parsed.encoding, original.encoding);
    }
}