Skip to main content

zai_rs/model/ocr/
data.rs

1use std::path::Path;
2
3use validator::Validate;
4
5use super::request::{OcrBody, OcrLanguageType, OcrToolType};
6use crate::client::http::{HttpClient, HttpClientConfig, http_client_with_config};
7
/// OCR recognition request (multipart/form-data)
///
/// Built with `new` plus the `with_*` methods, then submitted with `send`.
pub struct OcrRequest {
    /// API key; sent as the bearer token when the request is posted.
    pub key: String,
    /// Non-file options of the request (tool type, language type,
    /// probability, request/user id), filled in through the `with_*` methods.
    pub body: OcrBody,
    /// Local path of the image to upload. `None` until `with_file_path` is
    /// called; validation and posting both fail while it is unset.
    file_path: Option<String>,
}
14
15impl OcrRequest {
16    pub fn new(key: String) -> Self {
17        Self {
18            key,
19            body: OcrBody::new(),
20            file_path: None,
21        }
22    }
23
24    pub fn with_file_path(mut self, path: impl Into<String>) -> Self {
25        self.file_path = Some(path.into());
26        self
27    }
28
29    pub fn with_tool_type(mut self, tool_type: OcrToolType) -> Self {
30        self.body = self.body.with_tool_type(tool_type);
31        self
32    }
33
34    pub fn with_language_type(mut self, language_type: OcrLanguageType) -> Self {
35        self.body = self.body.with_language_type(language_type);
36        self
37    }
38
39    pub fn with_probability(mut self, probability: bool) -> Self {
40        self.body = self.body.with_probability(probability);
41        self
42    }
43
44    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
45        self.body = self.body.with_request_id(request_id);
46        self
47    }
48
49    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
50        self.body = self.body.with_user_id(user_id);
51        self
52    }
53
54    pub fn validate(&self) -> crate::ZaiResult<()> {
55        // Check body constraints
56        self.body
57            .validate()
58            .map_err(crate::client::error::ZaiError::from)?;
59
60        // Ensure file path exists
61        let p =
62            self.file_path
63                .as_ref()
64                .ok_or_else(|| crate::client::error::ZaiError::ApiError {
65                    code: 1200,
66                    message: "file_path is required".to_string(),
67                })?;
68
69        if !Path::new(p).exists() {
70            return Err(crate::client::error::ZaiError::FileError {
71                code: 0,
72                message: format!("file_path not found: {}", p),
73            });
74        }
75
76        // Validate file size (max 8MB)
77        let metadata = std::fs::metadata(p)?;
78        let file_size = metadata.len();
79        const MAX_SIZE: u64 = 8 * 1024 * 1024; // 8MB
80        if file_size > MAX_SIZE {
81            return Err(crate::client::error::ZaiError::FileError {
82                code: 0,
83                message: format!("file_size exceeds 8MB limit: {} bytes", file_size),
84            });
85        }
86
87        // Validate file extension
88        let ext = Path::new(p)
89            .extension()
90            .and_then(|s| s.to_str())
91            .map(|s| s.to_ascii_lowercase());
92        let valid_ext = matches!(
93            ext.as_deref(),
94            Some("png") | Some("jpg") | Some("jpeg") | Some("bmp")
95        );
96        if !valid_ext {
97            return Err(crate::client::error::ZaiError::FileError {
98                code: 0,
99                message: format!(
100                    "invalid file format: {:?}. Only PNG, JPG, JPEG, BMP are supported",
101                    ext
102                ),
103            });
104        }
105
106        Ok(())
107    }
108
109    pub async fn send(&self) -> crate::ZaiResult<super::response::OcrResponse> {
110        self.validate()?;
111
112        let resp = self.post().await?;
113
114        let parsed = resp.json::<super::response::OcrResponse>().await?;
115
116        Ok(parsed)
117    }
118}
119
120impl HttpClient for OcrRequest {
121    type Body = OcrBody;
122    type ApiUrl = &'static str;
123    type ApiKey = String;
124
125    fn api_url(&self) -> &Self::ApiUrl {
126        &"https://open.bigmodel.cn/api/paas/v4/files/ocr"
127    }
128
129    fn api_key(&self) -> &Self::ApiKey {
130        &self.key
131    }
132
133    fn body(&self) -> &Self::Body {
134        &self.body
135    }
136
137    fn post(
138        &self,
139    ) -> impl std::future::Future<Output = crate::ZaiResult<reqwest::Response>> + Send {
140        let key = self.key.clone();
141        let url = (*self.api_url()).to_string();
142        let body = self.body.clone();
143        let file_path_opt = self.file_path.clone();
144
145        async move {
146            let file_path =
147                file_path_opt.ok_or_else(|| crate::client::error::ZaiError::ApiError {
148                    code: 1200,
149                    message: "file_path is required".to_string(),
150                })?;
151
152            let mut form = reqwest::multipart::Form::new();
153
154            // file
155            let file_name = Path::new(&file_path)
156                .file_name()
157                .and_then(|s| s.to_str())
158                .unwrap_or("image.png");
159            let file_bytes = tokio::fs::read(&file_path).await?;
160
161            // Determine MIME type by extension
162            let ext = Path::new(&file_path)
163                .extension()
164                .and_then(|s| s.to_str())
165                .map(|s| s.to_ascii_lowercase());
166            let mime = match ext.as_deref() {
167                Some("png") => "image/png",
168                Some("jpg") | Some("jpeg") => "image/jpeg",
169                Some("bmp") => "image/bmp",
170                _ => "image/png",
171            };
172
173            let part = reqwest::multipart::Part::bytes(file_bytes)
174                .file_name(file_name.to_string())
175                .mime_str(mime)?;
176            form = form.part("file", part);
177
178            // tool_type (required, default to hand_write)
179            let tool_type_str = match &body.tool_type {
180                Some(OcrToolType::HandWrite) => "hand_write",
181                None => "hand_write",
182            };
183            form = form.text("tool_type", tool_type_str);
184
185            // language_type (optional)
186            if let Some(lang) = &body.language_type {
187                let lang_str = serde_json::to_string(lang)
188                    .unwrap_or_default()
189                    .trim_matches('"')
190                    .to_string();
191                form = form.text("language_type", lang_str);
192            }
193
194            // probability (optional)
195            if let Some(prob) = body.probability {
196                form = form.text("probability", prob.to_string());
197            }
198
199            // Use shared HTTP client with connection pooling
200            let client = http_client_with_config(&HttpClientConfig::default());
201            let resp = client
202                .post(url)
203                .bearer_auth(key)
204                .multipart(form)
205                .send()
206                .await?;
207
208            Ok(resp)
209        }
210    }
211}