1use crate::error::{Result, ScipixError};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Config {
12 pub ocr: OcrConfig,
14
15 pub model: ModelConfig,
17
18 pub preprocess: PreprocessConfig,
20
21 pub output: OutputConfig,
23
24 pub performance: PerformanceConfig,
26
27 pub cache: CacheConfig,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct OcrConfig {
34 pub confidence_threshold: f32,
36
37 pub timeout: u64,
39
40 pub use_gpu: bool,
42
43 pub languages: Vec<String>,
45
46 pub detect_equations: bool,
48
49 pub detect_tables: bool,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct ModelConfig {
56 pub model_path: String,
58
59 pub version: String,
61
62 pub batch_size: usize,
64
65 pub precision: String,
67
68 pub quantize: bool,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct PreprocessConfig {
75 pub auto_rotate: bool,
77
78 pub denoise: bool,
80
81 pub enhance_contrast: bool,
83
84 pub binarize: bool,
86
87 pub target_dpi: u32,
89
90 pub max_dimension: u32,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct OutputConfig {
97 pub formats: Vec<String>,
99
100 pub include_confidence: bool,
102
103 pub include_bbox: bool,
105
106 pub pretty_print: bool,
108
109 pub include_metadata: bool,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct PerformanceConfig {
116 pub num_threads: usize,
118
119 pub parallel: bool,
121
122 pub memory_limit: usize,
124
125 pub profile: bool,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct CacheConfig {
132 pub enabled: bool,
134
135 pub capacity: usize,
137
138 pub similarity_threshold: f32,
140
141 pub ttl: u64,
143
144 pub vector_dimension: usize,
146
147 pub persistent: bool,
149
150 pub cache_dir: String,
152}
153
154impl Default for Config {
155 fn default() -> Self {
156 Self {
157 ocr: OcrConfig {
158 confidence_threshold: 0.7,
159 timeout: 30,
160 use_gpu: false,
161 languages: vec!["en".to_string()],
162 detect_equations: true,
163 detect_tables: true,
164 },
165 model: ModelConfig {
166 model_path: "models/scipix-ocr".to_string(),
167 version: "1.0.0".to_string(),
168 batch_size: 1,
169 precision: "fp32".to_string(),
170 quantize: false,
171 },
172 preprocess: PreprocessConfig {
173 auto_rotate: true,
174 denoise: true,
175 enhance_contrast: true,
176 binarize: false,
177 target_dpi: 300,
178 max_dimension: 4096,
179 },
180 output: OutputConfig {
181 formats: vec!["latex".to_string()],
182 include_confidence: true,
183 include_bbox: false,
184 pretty_print: true,
185 include_metadata: false,
186 },
187 performance: PerformanceConfig {
188 num_threads: num_cpus::get(),
189 parallel: true,
190 memory_limit: 2048,
191 profile: false,
192 },
193 cache: CacheConfig {
194 enabled: true,
195 capacity: 1000,
196 similarity_threshold: 0.95,
197 ttl: 3600,
198 vector_dimension: 512,
199 persistent: false,
200 cache_dir: ".cache/scipix".to_string(),
201 },
202 }
203 }
204}
205
206impl Config {
207 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
221 let content = std::fs::read_to_string(path)?;
222 let config: Config = toml::from_str(&content)?;
223 config.validate()?;
224 Ok(config)
225 }
226
227 pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
233 let content = toml::to_string_pretty(self)?;
234 std::fs::write(path, content)?;
235 Ok(())
236 }
237
238 pub fn from_env() -> Result<Self> {
250 let mut config = Self::default();
251 config.apply_env_overrides()?;
252 Ok(config)
253 }
254
255 fn apply_env_overrides(&mut self) -> Result<()> {
257 if let Ok(val) = std::env::var("MATHPIX_OCR__CONFIDENCE_THRESHOLD") {
259 self.ocr.confidence_threshold = val
260 .parse()
261 .map_err(|_| ScipixError::Config("Invalid confidence_threshold".to_string()))?;
262 }
263 if let Ok(val) = std::env::var("MATHPIX_OCR__TIMEOUT") {
264 self.ocr.timeout = val
265 .parse()
266 .map_err(|_| ScipixError::Config("Invalid timeout".to_string()))?;
267 }
268 if let Ok(val) = std::env::var("MATHPIX_OCR__USE_GPU") {
269 self.ocr.use_gpu = val
270 .parse()
271 .map_err(|_| ScipixError::Config("Invalid use_gpu".to_string()))?;
272 }
273
274 if let Ok(val) = std::env::var("MATHPIX_MODEL__PATH") {
276 self.model.model_path = val;
277 }
278 if let Ok(val) = std::env::var("MATHPIX_MODEL__BATCH_SIZE") {
279 self.model.batch_size = val
280 .parse()
281 .map_err(|_| ScipixError::Config("Invalid batch_size".to_string()))?;
282 }
283
284 if let Ok(val) = std::env::var("MATHPIX_CACHE__ENABLED") {
286 self.cache.enabled = val
287 .parse()
288 .map_err(|_| ScipixError::Config("Invalid cache enabled".to_string()))?;
289 }
290 if let Ok(val) = std::env::var("MATHPIX_CACHE__CAPACITY") {
291 self.cache.capacity = val
292 .parse()
293 .map_err(|_| ScipixError::Config("Invalid cache capacity".to_string()))?;
294 }
295
296 Ok(())
297 }
298
299 pub fn validate(&self) -> Result<()> {
301 if self.ocr.confidence_threshold < 0.0 || self.ocr.confidence_threshold > 1.0 {
303 return Err(ScipixError::Config(
304 "confidence_threshold must be between 0.0 and 1.0".to_string(),
305 ));
306 }
307
308 if self.cache.similarity_threshold < 0.0 || self.cache.similarity_threshold > 1.0 {
310 return Err(ScipixError::Config(
311 "similarity_threshold must be between 0.0 and 1.0".to_string(),
312 ));
313 }
314
315 if self.model.batch_size == 0 {
317 return Err(ScipixError::Config(
318 "batch_size must be greater than 0".to_string(),
319 ));
320 }
321
322 let valid_precisions = ["fp16", "fp32", "int8"];
324 if !valid_precisions.contains(&self.model.precision.as_str()) {
325 return Err(ScipixError::Config(format!(
326 "precision must be one of: {:?}",
327 valid_precisions
328 )));
329 }
330
331 let valid_formats = ["latex", "mathml", "asciimath"];
333 for format in &self.output.formats {
334 if !valid_formats.contains(&format.as_str()) {
335 return Err(ScipixError::Config(format!(
336 "Invalid output format: {}. Must be one of: {:?}",
337 format, valid_formats
338 )));
339 }
340 }
341
342 Ok(())
343 }
344
345 pub fn high_accuracy() -> Self {
347 let mut config = Self::default();
348 config.ocr.confidence_threshold = 0.9;
349 config.model.precision = "fp32".to_string();
350 config.model.quantize = false;
351 config.preprocess.denoise = true;
352 config.preprocess.enhance_contrast = true;
353 config.cache.similarity_threshold = 0.98;
354 config
355 }
356
357 pub fn high_speed() -> Self {
359 let mut config = Self::default();
360 config.ocr.confidence_threshold = 0.6;
361 config.model.precision = "fp16".to_string();
362 config.model.quantize = true;
363 config.model.batch_size = 4;
364 config.preprocess.denoise = false;
365 config.preprocess.enhance_contrast = false;
366 config.performance.parallel = true;
367 config.cache.similarity_threshold = 0.85;
368 config
369 }
370
371 pub fn minimal() -> Self {
373 let mut config = Self::default();
374 config.cache.enabled = false;
375 config.preprocess.denoise = false;
376 config.preprocess.enhance_contrast = false;
377 config.performance.parallel = false;
378 config
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration validates and carries the expected baseline values.
    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.ocr.confidence_threshold, 0.7);
        assert!(config.cache.enabled);
    }

    /// The high-accuracy preset validates and raises both thresholds.
    #[test]
    fn test_high_accuracy_config() {
        let config = Config::high_accuracy();
        assert!(config.validate().is_ok());
        assert_eq!(config.ocr.confidence_threshold, 0.9);
        assert_eq!(config.cache.similarity_threshold, 0.98);
    }

    /// The high-speed preset validates and switches to quantized fp16.
    #[test]
    fn test_high_speed_config() {
        let config = Config::high_speed();
        assert!(config.validate().is_ok());
        assert_eq!(config.model.precision, "fp16");
        assert!(config.model.quantize);
    }

    /// The minimal preset validates and turns the cache off.
    #[test]
    fn test_minimal_config() {
        let config = Config::minimal();
        assert!(config.validate().is_ok());
        assert!(!config.cache.enabled);
    }

    /// A confidence threshold above 1.0 is rejected.
    #[test]
    fn test_invalid_confidence_threshold() {
        let mut config = Config::default();
        config.ocr.confidence_threshold = 1.5;
        assert!(config.validate().is_err());
    }

    /// A confidence threshold below 0.0 is rejected.
    #[test]
    fn test_negative_confidence_threshold() {
        let mut config = Config::default();
        config.ocr.confidence_threshold = -0.1;
        assert!(config.validate().is_err());
    }

    /// A similarity threshold outside `0.0..=1.0` is rejected on either side.
    #[test]
    fn test_invalid_similarity_threshold() {
        let mut config = Config::default();
        config.cache.similarity_threshold = 1.5;
        assert!(config.validate().is_err());

        config.cache.similarity_threshold = -0.5;
        assert!(config.validate().is_err());
    }

    /// A batch size of zero is rejected.
    #[test]
    fn test_invalid_batch_size() {
        let mut config = Config::default();
        config.model.batch_size = 0;
        assert!(config.validate().is_err());
    }

    /// An unknown precision name is rejected.
    #[test]
    fn test_invalid_precision() {
        let mut config = Config::default();
        config.model.precision = "invalid".to_string();
        assert!(config.validate().is_err());
    }

    /// An unknown output format is rejected.
    #[test]
    fn test_invalid_output_format() {
        let mut config = Config::default();
        config.output.formats = vec!["invalid".to_string()];
        assert!(config.validate().is_err());
    }

    /// A configuration survives a TOML serialize/deserialize round trip.
    #[test]
    fn test_toml_serialization() {
        let config = Config::default();
        let toml_str = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&toml_str).unwrap();
        assert_eq!(
            config.ocr.confidence_threshold,
            deserialized.ocr.confidence_threshold
        );
        // Spot-check one string, one list and one integer field as well, so a
        // broken section mapping does not go unnoticed.
        assert_eq!(config.model.precision, deserialized.model.precision);
        assert_eq!(config.output.formats, deserialized.output.formats);
        assert_eq!(config.cache.capacity, deserialized.cache.capacity);
    }
}