1use serde::{Deserialize, Serialize};
6use std::path::Path;
7use crate::error::{ScipixError, Result};
8
/// Top-level configuration, grouping the settings of each subsystem.
///
/// The type is serde-serializable; `Config::from_file` and `Config::to_file`
/// read and write it as TOML, with one section per field below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// OCR settings (see [`OcrConfig`]).
    pub ocr: OcrConfig,

    /// Model settings (see [`ModelConfig`]).
    pub model: ModelConfig,

    /// Image preprocessing settings (see [`PreprocessConfig`]).
    pub preprocess: PreprocessConfig,

    /// Output settings (see [`OutputConfig`]).
    pub output: OutputConfig,

    /// Threading / resource settings (see [`PerformanceConfig`]).
    pub performance: PerformanceConfig,

    /// Cache settings (see [`CacheConfig`]).
    pub cache: CacheConfig,
}
30
/// Settings for the OCR stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrConfig {
    /// Confidence threshold; `Config::validate` requires it to lie in `0.0..=1.0`.
    /// Default: `0.7`.
    /// NOTE(review): presumably results scoring below this are discarded — confirm
    /// against the code that consumes it.
    pub confidence_threshold: f32,

    /// OCR timeout. Default: `30`.
    /// NOTE(review): the unit is not visible in this file — presumably seconds; confirm.
    pub timeout: u64,

    /// Whether GPU use is requested. Default: `false`.
    pub use_gpu: bool,

    /// Languages to recognize. Default: `["en"]`.
    /// NOTE(review): presumably language codes — confirm the expected format.
    pub languages: Vec<String>,

    /// Whether equation detection is enabled. Default: `true`.
    pub detect_equations: bool,

    /// Whether table detection is enabled. Default: `true`.
    pub detect_tables: bool,
}
52
/// Settings describing which model is used and how it is run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Location of the model. Default: `"models/scipix-ocr"`.
    pub model_path: String,

    /// Model version string. Default: `"1.0.0"`.
    pub version: String,

    /// Batch size; `Config::validate` rejects `0`. Default: `1`.
    pub batch_size: usize,

    /// Numeric precision; `Config::validate` accepts only `"fp16"`, `"fp32"`
    /// or `"int8"`. Default: `"fp32"`.
    pub precision: String,

    /// Whether quantization is enabled. Default: `false`.
    pub quantize: bool,
}
71
/// Switches and limits for image preprocessing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreprocessConfig {
    /// Whether automatic rotation is enabled. Default: `true`.
    pub auto_rotate: bool,

    /// Whether denoising is enabled. Default: `true`.
    pub denoise: bool,

    /// Whether contrast enhancement is enabled. Default: `true`.
    pub enhance_contrast: bool,

    /// Whether binarization is enabled. Default: `false`.
    pub binarize: bool,

    /// Target DPI. Default: `300`.
    pub target_dpi: u32,

    /// Maximum image dimension. Default: `4096`.
    /// NOTE(review): presumably pixels along the longer side — confirm.
    pub max_dimension: u32,
}
93
/// Settings controlling what is emitted as output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputConfig {
    /// Output formats; `Config::validate` requires every entry to be one of
    /// `"latex"`, `"mathml"` or `"asciimath"`. Default: `["latex"]`.
    pub formats: Vec<String>,

    /// Whether confidence values are included. Default: `true`.
    pub include_confidence: bool,

    /// Whether bounding boxes are included. Default: `false`.
    pub include_bbox: bool,

    /// Whether output is pretty-printed. Default: `true`.
    pub pretty_print: bool,

    /// Whether metadata is included. Default: `false`.
    pub include_metadata: bool,
}
112
/// Threading and resource settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceConfig {
    /// Number of threads. The default is whatever `num_cpus::get()` reports.
    pub num_threads: usize,

    /// Whether parallel processing is enabled. Default: `true`.
    pub parallel: bool,

    /// Memory limit. Default: `2048`.
    /// NOTE(review): the unit is not visible in this file — presumably megabytes; confirm.
    pub memory_limit: usize,

    /// Whether profiling is enabled. Default: `false`.
    pub profile: bool,
}
128
/// Settings for the result cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Whether the cache is enabled. Default: `true`.
    pub enabled: bool,

    /// Cache capacity. Default: `1000`.
    /// NOTE(review): presumably a maximum entry count — confirm.
    pub capacity: usize,

    /// Similarity threshold; `Config::validate` requires it to lie in `0.0..=1.0`.
    /// Default: `0.95`.
    pub similarity_threshold: f32,

    /// Time-to-live for entries. Default: `3600`.
    /// NOTE(review): the unit is not visible in this file — presumably seconds; confirm.
    pub ttl: u64,

    /// Dimension of the vectors used by the cache. Default: `512`.
    pub vector_dimension: usize,

    /// Whether the cache is persistent. Default: `false`.
    pub persistent: bool,

    /// Cache directory. Default: `".cache/scipix"`.
    pub cache_dir: String,
}
153
154impl Default for Config {
155 fn default() -> Self {
156 Self {
157 ocr: OcrConfig {
158 confidence_threshold: 0.7,
159 timeout: 30,
160 use_gpu: false,
161 languages: vec!["en".to_string()],
162 detect_equations: true,
163 detect_tables: true,
164 },
165 model: ModelConfig {
166 model_path: "models/scipix-ocr".to_string(),
167 version: "1.0.0".to_string(),
168 batch_size: 1,
169 precision: "fp32".to_string(),
170 quantize: false,
171 },
172 preprocess: PreprocessConfig {
173 auto_rotate: true,
174 denoise: true,
175 enhance_contrast: true,
176 binarize: false,
177 target_dpi: 300,
178 max_dimension: 4096,
179 },
180 output: OutputConfig {
181 formats: vec!["latex".to_string()],
182 include_confidence: true,
183 include_bbox: false,
184 pretty_print: true,
185 include_metadata: false,
186 },
187 performance: PerformanceConfig {
188 num_threads: num_cpus::get(),
189 parallel: true,
190 memory_limit: 2048,
191 profile: false,
192 },
193 cache: CacheConfig {
194 enabled: true,
195 capacity: 1000,
196 similarity_threshold: 0.95,
197 ttl: 3600,
198 vector_dimension: 512,
199 persistent: false,
200 cache_dir: ".cache/scipix".to_string(),
201 },
202 }
203 }
204}
205
206impl Config {
207 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
221 let content = std::fs::read_to_string(path)?;
222 let config: Config = toml::from_str(&content)?;
223 config.validate()?;
224 Ok(config)
225 }
226
227 pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
233 let content = toml::to_string_pretty(self)?;
234 std::fs::write(path, content)?;
235 Ok(())
236 }
237
238 pub fn from_env() -> Result<Self> {
250 let mut config = Self::default();
251 config.apply_env_overrides()?;
252 Ok(config)
253 }
254
255 fn apply_env_overrides(&mut self) -> Result<()> {
257 if let Ok(val) = std::env::var("MATHPIX_OCR__CONFIDENCE_THRESHOLD") {
259 self.ocr.confidence_threshold = val.parse()
260 .map_err(|_| ScipixError::Config("Invalid confidence_threshold".to_string()))?;
261 }
262 if let Ok(val) = std::env::var("MATHPIX_OCR__TIMEOUT") {
263 self.ocr.timeout = val.parse()
264 .map_err(|_| ScipixError::Config("Invalid timeout".to_string()))?;
265 }
266 if let Ok(val) = std::env::var("MATHPIX_OCR__USE_GPU") {
267 self.ocr.use_gpu = val.parse()
268 .map_err(|_| ScipixError::Config("Invalid use_gpu".to_string()))?;
269 }
270
271 if let Ok(val) = std::env::var("MATHPIX_MODEL__PATH") {
273 self.model.model_path = val;
274 }
275 if let Ok(val) = std::env::var("MATHPIX_MODEL__BATCH_SIZE") {
276 self.model.batch_size = val.parse()
277 .map_err(|_| ScipixError::Config("Invalid batch_size".to_string()))?;
278 }
279
280 if let Ok(val) = std::env::var("MATHPIX_CACHE__ENABLED") {
282 self.cache.enabled = val.parse()
283 .map_err(|_| ScipixError::Config("Invalid cache enabled".to_string()))?;
284 }
285 if let Ok(val) = std::env::var("MATHPIX_CACHE__CAPACITY") {
286 self.cache.capacity = val.parse()
287 .map_err(|_| ScipixError::Config("Invalid cache capacity".to_string()))?;
288 }
289
290 Ok(())
291 }
292
293 pub fn validate(&self) -> Result<()> {
295 if self.ocr.confidence_threshold < 0.0 || self.ocr.confidence_threshold > 1.0 {
297 return Err(ScipixError::Config(
298 "confidence_threshold must be between 0.0 and 1.0".to_string()
299 ));
300 }
301
302 if self.cache.similarity_threshold < 0.0 || self.cache.similarity_threshold > 1.0 {
304 return Err(ScipixError::Config(
305 "similarity_threshold must be between 0.0 and 1.0".to_string()
306 ));
307 }
308
309 if self.model.batch_size == 0 {
311 return Err(ScipixError::Config(
312 "batch_size must be greater than 0".to_string()
313 ));
314 }
315
316 let valid_precisions = ["fp16", "fp32", "int8"];
318 if !valid_precisions.contains(&self.model.precision.as_str()) {
319 return Err(ScipixError::Config(
320 format!("precision must be one of: {:?}", valid_precisions)
321 ));
322 }
323
324 let valid_formats = ["latex", "mathml", "asciimath"];
326 for format in &self.output.formats {
327 if !valid_formats.contains(&format.as_str()) {
328 return Err(ScipixError::Config(
329 format!("Invalid output format: {}. Must be one of: {:?}", format, valid_formats)
330 ));
331 }
332 }
333
334 Ok(())
335 }
336
337 pub fn high_accuracy() -> Self {
339 let mut config = Self::default();
340 config.ocr.confidence_threshold = 0.9;
341 config.model.precision = "fp32".to_string();
342 config.model.quantize = false;
343 config.preprocess.denoise = true;
344 config.preprocess.enhance_contrast = true;
345 config.cache.similarity_threshold = 0.98;
346 config
347 }
348
349 pub fn high_speed() -> Self {
351 let mut config = Self::default();
352 config.ocr.confidence_threshold = 0.6;
353 config.model.precision = "fp16".to_string();
354 config.model.quantize = true;
355 config.model.batch_size = 4;
356 config.preprocess.denoise = false;
357 config.preprocess.enhance_contrast = false;
358 config.performance.parallel = true;
359 config.cache.similarity_threshold = 0.85;
360 config
361 }
362
363 pub fn minimal() -> Self {
365 let mut config = Self::default();
366 config.cache.enabled = false;
367 config.preprocess.denoise = false;
368 config.preprocess.enhance_contrast = false;
369 config.performance.parallel = false;
370 config
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration validates and carries the expected defaults.
    #[test]
    fn test_default_config() {
        let defaults = Config::default();
        assert!(defaults.validate().is_ok());

        let Config { ocr, cache, .. } = &defaults;
        assert_eq!(ocr.confidence_threshold, 0.7);
        assert!(cache.enabled);
    }

    /// The accuracy preset validates and tightens both thresholds.
    #[test]
    fn test_high_accuracy_config() {
        let preset = Config::high_accuracy();
        assert!(preset.validate().is_ok());

        let Config { ocr, cache, .. } = &preset;
        assert_eq!(ocr.confidence_threshold, 0.9);
        assert_eq!(cache.similarity_threshold, 0.98);
    }

    /// The speed preset validates and switches to quantized fp16.
    #[test]
    fn test_high_speed_config() {
        let preset = Config::high_speed();
        assert!(preset.validate().is_ok());

        let model = &preset.model;
        assert_eq!(model.precision, "fp16");
        assert!(model.quantize);
    }

    /// The minimal preset validates and has the cache turned off.
    #[test]
    fn test_minimal_config() {
        let preset = Config::minimal();
        assert!(preset.validate().is_ok());
        assert!(!preset.cache.enabled);
    }

    /// A confidence threshold above 1.0 is rejected.
    #[test]
    fn test_invalid_confidence_threshold() {
        let mut broken = Config::default();
        broken.ocr.confidence_threshold = 1.5;
        let outcome = broken.validate();
        assert!(outcome.is_err());
    }

    /// A batch size of zero is rejected.
    #[test]
    fn test_invalid_batch_size() {
        let mut broken = Config::default();
        broken.model.batch_size = 0;
        let outcome = broken.validate();
        assert!(outcome.is_err());
    }

    /// An unknown precision name is rejected.
    #[test]
    fn test_invalid_precision() {
        let mut broken = Config::default();
        broken.model.precision = String::from("invalid");
        let outcome = broken.validate();
        assert!(outcome.is_err());
    }

    /// An unknown output format is rejected.
    #[test]
    fn test_invalid_output_format() {
        let mut broken = Config::default();
        broken.output.formats = vec![String::from("invalid")];
        let outcome = broken.validate();
        assert!(outcome.is_err());
    }

    /// Serializing to TOML and parsing it back preserves the confidence threshold.
    #[test]
    fn test_toml_serialization() {
        let original = Config::default();
        let encoded = toml::to_string(&original).unwrap();
        let decoded: Config = toml::from_str(&encoded).unwrap();
        assert_eq!(
            original.ocr.confidence_threshold,
            decoded.ocr.confidence_threshold
        );
    }
}