ruvector_scipix/
config.rs

//! Configuration system for Ruvector-Scipix.
//!
//! Comprehensive configuration with TOML file support, environment-variable
//! overrides (see `Config::from_env`), and validation.
4
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7use crate::error::{ScipixError, Result};
8
9/// Main configuration structure
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Config {
12    /// OCR processing configuration
13    pub ocr: OcrConfig,
14
15    /// Model configuration
16    pub model: ModelConfig,
17
18    /// Preprocessing configuration
19    pub preprocess: PreprocessConfig,
20
21    /// Output format configuration
22    pub output: OutputConfig,
23
24    /// Performance tuning
25    pub performance: PerformanceConfig,
26
27    /// Cache configuration
28    pub cache: CacheConfig,
29}
30
31/// OCR engine configuration
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct OcrConfig {
34    /// Confidence threshold (0.0-1.0)
35    pub confidence_threshold: f32,
36
37    /// Maximum processing time in seconds
38    pub timeout: u64,
39
40    /// Enable GPU acceleration
41    pub use_gpu: bool,
42
43    /// Language codes (e.g., ["en", "es"])
44    pub languages: Vec<String>,
45
46    /// Enable equation detection
47    pub detect_equations: bool,
48
49    /// Enable table detection
50    pub detect_tables: bool,
51}
52
53/// Model configuration
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct ModelConfig {
56    /// Path to OCR model
57    pub model_path: String,
58
59    /// Model version
60    pub version: String,
61
62    /// Batch size for processing
63    pub batch_size: usize,
64
65    /// Model precision (fp16, fp32, int8)
66    pub precision: String,
67
68    /// Enable quantization
69    pub quantize: bool,
70}
71
72/// Image preprocessing configuration
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct PreprocessConfig {
75    /// Enable auto-rotation
76    pub auto_rotate: bool,
77
78    /// Enable denoising
79    pub denoise: bool,
80
81    /// Enable contrast enhancement
82    pub enhance_contrast: bool,
83
84    /// Enable binarization
85    pub binarize: bool,
86
87    /// Target DPI for scaling
88    pub target_dpi: u32,
89
90    /// Maximum image dimension
91    pub max_dimension: u32,
92}
93
94/// Output format configuration
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct OutputConfig {
97    /// Output formats (latex, mathml, asciimath)
98    pub formats: Vec<String>,
99
100    /// Include confidence scores
101    pub include_confidence: bool,
102
103    /// Include bounding boxes
104    pub include_bbox: bool,
105
106    /// Pretty print JSON
107    pub pretty_print: bool,
108
109    /// Include metadata
110    pub include_metadata: bool,
111}
112
113/// Performance tuning configuration
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct PerformanceConfig {
116    /// Number of worker threads
117    pub num_threads: usize,
118
119    /// Enable parallel processing
120    pub parallel: bool,
121
122    /// Memory limit in MB
123    pub memory_limit: usize,
124
125    /// Enable profiling
126    pub profile: bool,
127}
128
129/// Cache configuration
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct CacheConfig {
132    /// Enable caching
133    pub enabled: bool,
134
135    /// Cache capacity (number of entries)
136    pub capacity: usize,
137
138    /// Similarity threshold for cache hits (0.0-1.0)
139    pub similarity_threshold: f32,
140
141    /// Cache TTL in seconds
142    pub ttl: u64,
143
144    /// Vector dimension for embeddings
145    pub vector_dimension: usize,
146
147    /// Enable persistent cache
148    pub persistent: bool,
149
150    /// Cache directory path
151    pub cache_dir: String,
152}
153
154impl Default for Config {
155    fn default() -> Self {
156        Self {
157            ocr: OcrConfig {
158                confidence_threshold: 0.7,
159                timeout: 30,
160                use_gpu: false,
161                languages: vec!["en".to_string()],
162                detect_equations: true,
163                detect_tables: true,
164            },
165            model: ModelConfig {
166                model_path: "models/scipix-ocr".to_string(),
167                version: "1.0.0".to_string(),
168                batch_size: 1,
169                precision: "fp32".to_string(),
170                quantize: false,
171            },
172            preprocess: PreprocessConfig {
173                auto_rotate: true,
174                denoise: true,
175                enhance_contrast: true,
176                binarize: false,
177                target_dpi: 300,
178                max_dimension: 4096,
179            },
180            output: OutputConfig {
181                formats: vec!["latex".to_string()],
182                include_confidence: true,
183                include_bbox: false,
184                pretty_print: true,
185                include_metadata: false,
186            },
187            performance: PerformanceConfig {
188                num_threads: num_cpus::get(),
189                parallel: true,
190                memory_limit: 2048,
191                profile: false,
192            },
193            cache: CacheConfig {
194                enabled: true,
195                capacity: 1000,
196                similarity_threshold: 0.95,
197                ttl: 3600,
198                vector_dimension: 512,
199                persistent: false,
200                cache_dir: ".cache/scipix".to_string(),
201            },
202        }
203    }
204}
205
206impl Config {
207    /// Load configuration from TOML file
208    ///
209    /// # Arguments
210    ///
211    /// * `path` - Path to TOML configuration file
212    ///
213    /// # Examples
214    ///
215    /// ```rust,no_run
216    /// use ruvector_scipix::Config;
217    ///
218    /// let config = Config::from_file("scipix.toml").unwrap();
219    /// ```
220    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
221        let content = std::fs::read_to_string(path)?;
222        let config: Config = toml::from_str(&content)?;
223        config.validate()?;
224        Ok(config)
225    }
226
227    /// Save configuration to TOML file
228    ///
229    /// # Arguments
230    ///
231    /// * `path` - Path to save TOML configuration
232    pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
233        let content = toml::to_string_pretty(self)?;
234        std::fs::write(path, content)?;
235        Ok(())
236    }
237
238    /// Load configuration from environment variables
239    ///
240    /// Environment variables should be prefixed with `MATHPIX_`
241    /// and use double underscores for nested fields.
242    ///
243    /// # Examples
244    ///
245    /// ```bash
246    /// export MATHPIX_OCR__CONFIDENCE_THRESHOLD=0.8
247    /// export MATHPIX_MODEL__BATCH_SIZE=4
248    /// ```
249    pub fn from_env() -> Result<Self> {
250        let mut config = Self::default();
251        config.apply_env_overrides()?;
252        Ok(config)
253    }
254
255    /// Apply environment variable overrides
256    fn apply_env_overrides(&mut self) -> Result<()> {
257        // OCR overrides
258        if let Ok(val) = std::env::var("MATHPIX_OCR__CONFIDENCE_THRESHOLD") {
259            self.ocr.confidence_threshold = val.parse()
260                .map_err(|_| ScipixError::Config("Invalid confidence_threshold".to_string()))?;
261        }
262        if let Ok(val) = std::env::var("MATHPIX_OCR__TIMEOUT") {
263            self.ocr.timeout = val.parse()
264                .map_err(|_| ScipixError::Config("Invalid timeout".to_string()))?;
265        }
266        if let Ok(val) = std::env::var("MATHPIX_OCR__USE_GPU") {
267            self.ocr.use_gpu = val.parse()
268                .map_err(|_| ScipixError::Config("Invalid use_gpu".to_string()))?;
269        }
270
271        // Model overrides
272        if let Ok(val) = std::env::var("MATHPIX_MODEL__PATH") {
273            self.model.model_path = val;
274        }
275        if let Ok(val) = std::env::var("MATHPIX_MODEL__BATCH_SIZE") {
276            self.model.batch_size = val.parse()
277                .map_err(|_| ScipixError::Config("Invalid batch_size".to_string()))?;
278        }
279
280        // Cache overrides
281        if let Ok(val) = std::env::var("MATHPIX_CACHE__ENABLED") {
282            self.cache.enabled = val.parse()
283                .map_err(|_| ScipixError::Config("Invalid cache enabled".to_string()))?;
284        }
285        if let Ok(val) = std::env::var("MATHPIX_CACHE__CAPACITY") {
286            self.cache.capacity = val.parse()
287                .map_err(|_| ScipixError::Config("Invalid cache capacity".to_string()))?;
288        }
289
290        Ok(())
291    }
292
293    /// Validate configuration
294    pub fn validate(&self) -> Result<()> {
295        // Validate confidence threshold
296        if self.ocr.confidence_threshold < 0.0 || self.ocr.confidence_threshold > 1.0 {
297            return Err(ScipixError::Config(
298                "confidence_threshold must be between 0.0 and 1.0".to_string()
299            ));
300        }
301
302        // Validate similarity threshold
303        if self.cache.similarity_threshold < 0.0 || self.cache.similarity_threshold > 1.0 {
304            return Err(ScipixError::Config(
305                "similarity_threshold must be between 0.0 and 1.0".to_string()
306            ));
307        }
308
309        // Validate batch size
310        if self.model.batch_size == 0 {
311            return Err(ScipixError::Config(
312                "batch_size must be greater than 0".to_string()
313            ));
314        }
315
316        // Validate precision
317        let valid_precisions = ["fp16", "fp32", "int8"];
318        if !valid_precisions.contains(&self.model.precision.as_str()) {
319            return Err(ScipixError::Config(
320                format!("precision must be one of: {:?}", valid_precisions)
321            ));
322        }
323
324        // Validate output formats
325        let valid_formats = ["latex", "mathml", "asciimath"];
326        for format in &self.output.formats {
327            if !valid_formats.contains(&format.as_str()) {
328                return Err(ScipixError::Config(
329                    format!("Invalid output format: {}. Must be one of: {:?}", format, valid_formats)
330                ));
331            }
332        }
333
334        Ok(())
335    }
336
337    /// Create high-accuracy preset configuration
338    pub fn high_accuracy() -> Self {
339        let mut config = Self::default();
340        config.ocr.confidence_threshold = 0.9;
341        config.model.precision = "fp32".to_string();
342        config.model.quantize = false;
343        config.preprocess.denoise = true;
344        config.preprocess.enhance_contrast = true;
345        config.cache.similarity_threshold = 0.98;
346        config
347    }
348
349    /// Create high-speed preset configuration
350    pub fn high_speed() -> Self {
351        let mut config = Self::default();
352        config.ocr.confidence_threshold = 0.6;
353        config.model.precision = "fp16".to_string();
354        config.model.quantize = true;
355        config.model.batch_size = 4;
356        config.preprocess.denoise = false;
357        config.preprocess.enhance_contrast = false;
358        config.performance.parallel = true;
359        config.cache.similarity_threshold = 0.85;
360        config
361    }
362
363    /// Create minimal configuration
364    pub fn minimal() -> Self {
365        let mut config = Self::default();
366        config.cache.enabled = false;
367        config.preprocess.denoise = false;
368        config.preprocess.enhance_contrast = false;
369        config.performance.parallel = false;
370        config
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.ocr.confidence_threshold, 0.7);
        assert!(config.cache.enabled);
    }

    #[test]
    fn test_high_accuracy_config() {
        let config = Config::high_accuracy();
        assert!(config.validate().is_ok());
        assert_eq!(config.ocr.confidence_threshold, 0.9);
        assert_eq!(config.cache.similarity_threshold, 0.98);
    }

    #[test]
    fn test_high_speed_config() {
        let config = Config::high_speed();
        assert!(config.validate().is_ok());
        assert_eq!(config.model.precision, "fp16");
        assert!(config.model.quantize);
    }

    #[test]
    fn test_minimal_config() {
        let config = Config::minimal();
        assert!(config.validate().is_ok());
        assert!(!config.cache.enabled);
    }

    #[test]
    fn test_invalid_confidence_threshold() {
        let mut config = Config::default();
        config.ocr.confidence_threshold = 1.5;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_negative_confidence_threshold() {
        let mut config = Config::default();
        config.ocr.confidence_threshold = -0.1;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_similarity_threshold() {
        for bad in [-0.01_f32, 1.01] {
            let mut config = Config::default();
            config.cache.similarity_threshold = bad;
            assert!(
                config.validate().is_err(),
                "similarity_threshold {} should be rejected",
                bad
            );
        }
    }

    // Both ends of `0.0..=1.0` are valid for each threshold.
    #[test]
    fn test_threshold_bounds_are_inclusive() {
        let mut config = Config::default();
        config.ocr.confidence_threshold = 0.0;
        config.cache.similarity_threshold = 1.0;
        assert!(config.validate().is_ok());

        config.ocr.confidence_threshold = 1.0;
        config.cache.similarity_threshold = 0.0;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_invalid_batch_size() {
        let mut config = Config::default();
        config.model.batch_size = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_precision() {
        let mut config = Config::default();
        config.model.precision = "invalid".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_all_supported_precisions() {
        for precision in ["fp16", "fp32", "int8"] {
            let mut config = Config::default();
            config.model.precision = precision.to_string();
            assert!(
                config.validate().is_ok(),
                "precision {} should be accepted",
                precision
            );
        }
    }

    #[test]
    fn test_invalid_output_format() {
        let mut config = Config::default();
        config.output.formats = vec!["invalid".to_string()];
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_all_supported_output_formats() {
        let mut config = Config::default();
        config.output.formats = vec![
            "latex".to_string(),
            "mathml".to_string(),
            "asciimath".to_string(),
        ];
        assert!(config.validate().is_ok());
    }

    // A single bad entry must fail validation even when the others are valid.
    #[test]
    fn test_invalid_format_among_valid_ones() {
        let mut config = Config::default();
        config.output.formats = vec!["latex".to_string(), "docx".to_string()];
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_toml_serialization() {
        let config = Config::default();
        let toml_str = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&toml_str).unwrap();
        assert_eq!(config.ocr.confidence_threshold, deserialized.ocr.confidence_threshold);
        assert_eq!(config.ocr.languages, deserialized.ocr.languages);
        assert_eq!(config.model.precision, deserialized.model.precision);
        assert_eq!(config.preprocess.target_dpi, deserialized.preprocess.target_dpi);
        assert_eq!(config.output.formats, deserialized.output.formats);
        assert_eq!(config.performance.num_threads, deserialized.performance.num_threads);
        assert_eq!(
            config.cache.similarity_threshold,
            deserialized.cache.similarity_threshold
        );
        assert_eq!(config.cache.cache_dir, deserialized.cache.cache_dir);
    }

    #[test]
    fn test_file_roundtrip() {
        // The process id keeps concurrent test runs from clobbering each other's file.
        let path = std::env::temp_dir()
            .join(format!("scipix_config_test_{}.toml", std::process::id()));
        let config = Config::high_speed();
        assert!(config.to_file(&path).is_ok());

        let loaded = Config::from_file(&path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);

        let loaded = match loaded {
            Ok(cfg) => cfg,
            Err(_) => panic!("from_file failed to load a file written by to_file"),
        };
        assert_eq!(loaded.model.precision, "fp16");
        assert_eq!(loaded.model.batch_size, 4);
        assert!(loaded.model.quantize);
        assert_eq!(loaded.cache.similarity_threshold, 0.85);
    }
}