1use std::path::Path;
2
3use validator::Validate;
4
5use super::request::{OcrBody, OcrLanguageType, OcrToolType};
6use crate::client::http::{HttpClient, HttpClientConfig, http_client_with_config};
7
8pub struct OcrRequest {
10 pub key: String,
11 pub body: OcrBody,
12 file_path: Option<String>,
13}
14
15impl OcrRequest {
16 pub fn new(key: String) -> Self {
17 Self {
18 key,
19 body: OcrBody::new(),
20 file_path: None,
21 }
22 }
23
24 pub fn with_file_path(mut self, path: impl Into<String>) -> Self {
25 self.file_path = Some(path.into());
26 self
27 }
28
29 pub fn with_tool_type(mut self, tool_type: OcrToolType) -> Self {
30 self.body = self.body.with_tool_type(tool_type);
31 self
32 }
33
34 pub fn with_language_type(mut self, language_type: OcrLanguageType) -> Self {
35 self.body = self.body.with_language_type(language_type);
36 self
37 }
38
39 pub fn with_probability(mut self, probability: bool) -> Self {
40 self.body = self.body.with_probability(probability);
41 self
42 }
43
44 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
45 self.body = self.body.with_request_id(request_id);
46 self
47 }
48
49 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
50 self.body = self.body.with_user_id(user_id);
51 self
52 }
53
54 pub fn validate(&self) -> crate::ZaiResult<()> {
55 self.body
57 .validate()
58 .map_err(crate::client::error::ZaiError::from)?;
59
60 let p =
62 self.file_path
63 .as_ref()
64 .ok_or_else(|| crate::client::error::ZaiError::ApiError {
65 code: 1200,
66 message: "file_path is required".to_string(),
67 })?;
68
69 if !Path::new(p).exists() {
70 return Err(crate::client::error::ZaiError::FileError {
71 code: 0,
72 message: format!("file_path not found: {}", p),
73 });
74 }
75
76 let metadata = std::fs::metadata(p)?;
78 let file_size = metadata.len();
79 const MAX_SIZE: u64 = 8 * 1024 * 1024; if file_size > MAX_SIZE {
81 return Err(crate::client::error::ZaiError::FileError {
82 code: 0,
83 message: format!("file_size exceeds 8MB limit: {} bytes", file_size),
84 });
85 }
86
87 let ext = Path::new(p)
89 .extension()
90 .and_then(|s| s.to_str())
91 .map(|s| s.to_ascii_lowercase());
92 let valid_ext = matches!(
93 ext.as_deref(),
94 Some("png") | Some("jpg") | Some("jpeg") | Some("bmp")
95 );
96 if !valid_ext {
97 return Err(crate::client::error::ZaiError::FileError {
98 code: 0,
99 message: format!(
100 "invalid file format: {:?}. Only PNG, JPG, JPEG, BMP are supported",
101 ext
102 ),
103 });
104 }
105
106 Ok(())
107 }
108
109 pub async fn send(&self) -> crate::ZaiResult<super::response::OcrResponse> {
110 self.validate()?;
111
112 let resp = self.post().await?;
113
114 let parsed = resp.json::<super::response::OcrResponse>().await?;
115
116 Ok(parsed)
117 }
118}
119
120impl HttpClient for OcrRequest {
121 type Body = OcrBody;
122 type ApiUrl = &'static str;
123 type ApiKey = String;
124
125 fn api_url(&self) -> &Self::ApiUrl {
126 &"https://open.bigmodel.cn/api/paas/v4/files/ocr"
127 }
128
129 fn api_key(&self) -> &Self::ApiKey {
130 &self.key
131 }
132
133 fn body(&self) -> &Self::Body {
134 &self.body
135 }
136
137 fn post(
138 &self,
139 ) -> impl std::future::Future<Output = crate::ZaiResult<reqwest::Response>> + Send {
140 let key = self.key.clone();
141 let url = (*self.api_url()).to_string();
142 let body = self.body.clone();
143 let file_path_opt = self.file_path.clone();
144
145 async move {
146 let file_path =
147 file_path_opt.ok_or_else(|| crate::client::error::ZaiError::ApiError {
148 code: 1200,
149 message: "file_path is required".to_string(),
150 })?;
151
152 let mut form = reqwest::multipart::Form::new();
153
154 let file_name = Path::new(&file_path)
156 .file_name()
157 .and_then(|s| s.to_str())
158 .unwrap_or("image.png");
159 let file_bytes = tokio::fs::read(&file_path).await?;
160
161 let ext = Path::new(&file_path)
163 .extension()
164 .and_then(|s| s.to_str())
165 .map(|s| s.to_ascii_lowercase());
166 let mime = match ext.as_deref() {
167 Some("png") => "image/png",
168 Some("jpg") | Some("jpeg") => "image/jpeg",
169 Some("bmp") => "image/bmp",
170 _ => "image/png",
171 };
172
173 let part = reqwest::multipart::Part::bytes(file_bytes)
174 .file_name(file_name.to_string())
175 .mime_str(mime)?;
176 form = form.part("file", part);
177
178 let tool_type_str = match &body.tool_type {
180 Some(OcrToolType::HandWrite) => "hand_write",
181 None => "hand_write",
182 };
183 form = form.text("tool_type", tool_type_str);
184
185 if let Some(lang) = &body.language_type {
187 let lang_str = serde_json::to_string(lang)
188 .unwrap_or_default()
189 .trim_matches('"')
190 .to_string();
191 form = form.text("language_type", lang_str);
192 }
193
194 if let Some(prob) = body.probability {
196 form = form.text("probability", prob.to_string());
197 }
198
199 let client = http_client_with_config(&HttpClientConfig::default());
201 let resp = client
202 .post(url)
203 .bearer_auth(key)
204 .multipart(form)
205 .send()
206 .await?;
207
208 Ok(resp)
209 }
210 }
211}