Skip to main content

ruvector_scipix/
config.rs

//! Configuration system for Ruvector-Scipix
//!
//! Comprehensive configuration with TOML support, environment overrides, and validation.
4
5use crate::error::{Result, ScipixError};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8
9/// Main configuration structure
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Config {
12    /// OCR processing configuration
13    pub ocr: OcrConfig,
14
15    /// Model configuration
16    pub model: ModelConfig,
17
18    /// Preprocessing configuration
19    pub preprocess: PreprocessConfig,
20
21    /// Output format configuration
22    pub output: OutputConfig,
23
24    /// Performance tuning
25    pub performance: PerformanceConfig,
26
27    /// Cache configuration
28    pub cache: CacheConfig,
29}
30
31/// OCR engine configuration
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct OcrConfig {
34    /// Confidence threshold (0.0-1.0)
35    pub confidence_threshold: f32,
36
37    /// Maximum processing time in seconds
38    pub timeout: u64,
39
40    /// Enable GPU acceleration
41    pub use_gpu: bool,
42
43    /// Language codes (e.g., ["en", "es"])
44    pub languages: Vec<String>,
45
46    /// Enable equation detection
47    pub detect_equations: bool,
48
49    /// Enable table detection
50    pub detect_tables: bool,
51}
52
53/// Model configuration
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct ModelConfig {
56    /// Path to OCR model
57    pub model_path: String,
58
59    /// Model version
60    pub version: String,
61
62    /// Batch size for processing
63    pub batch_size: usize,
64
65    /// Model precision (fp16, fp32, int8)
66    pub precision: String,
67
68    /// Enable quantization
69    pub quantize: bool,
70}
71
72/// Image preprocessing configuration
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct PreprocessConfig {
75    /// Enable auto-rotation
76    pub auto_rotate: bool,
77
78    /// Enable denoising
79    pub denoise: bool,
80
81    /// Enable contrast enhancement
82    pub enhance_contrast: bool,
83
84    /// Enable binarization
85    pub binarize: bool,
86
87    /// Target DPI for scaling
88    pub target_dpi: u32,
89
90    /// Maximum image dimension
91    pub max_dimension: u32,
92}
93
94/// Output format configuration
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct OutputConfig {
97    /// Output formats (latex, mathml, asciimath)
98    pub formats: Vec<String>,
99
100    /// Include confidence scores
101    pub include_confidence: bool,
102
103    /// Include bounding boxes
104    pub include_bbox: bool,
105
106    /// Pretty print JSON
107    pub pretty_print: bool,
108
109    /// Include metadata
110    pub include_metadata: bool,
111}
112
113/// Performance tuning configuration
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct PerformanceConfig {
116    /// Number of worker threads
117    pub num_threads: usize,
118
119    /// Enable parallel processing
120    pub parallel: bool,
121
122    /// Memory limit in MB
123    pub memory_limit: usize,
124
125    /// Enable profiling
126    pub profile: bool,
127}
128
129/// Cache configuration
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct CacheConfig {
132    /// Enable caching
133    pub enabled: bool,
134
135    /// Cache capacity (number of entries)
136    pub capacity: usize,
137
138    /// Similarity threshold for cache hits (0.0-1.0)
139    pub similarity_threshold: f32,
140
141    /// Cache TTL in seconds
142    pub ttl: u64,
143
144    /// Vector dimension for embeddings
145    pub vector_dimension: usize,
146
147    /// Enable persistent cache
148    pub persistent: bool,
149
150    /// Cache directory path
151    pub cache_dir: String,
152}
153
154impl Default for Config {
155    fn default() -> Self {
156        Self {
157            ocr: OcrConfig {
158                confidence_threshold: 0.7,
159                timeout: 30,
160                use_gpu: false,
161                languages: vec!["en".to_string()],
162                detect_equations: true,
163                detect_tables: true,
164            },
165            model: ModelConfig {
166                model_path: "models/scipix-ocr".to_string(),
167                version: "1.0.0".to_string(),
168                batch_size: 1,
169                precision: "fp32".to_string(),
170                quantize: false,
171            },
172            preprocess: PreprocessConfig {
173                auto_rotate: true,
174                denoise: true,
175                enhance_contrast: true,
176                binarize: false,
177                target_dpi: 300,
178                max_dimension: 4096,
179            },
180            output: OutputConfig {
181                formats: vec!["latex".to_string()],
182                include_confidence: true,
183                include_bbox: false,
184                pretty_print: true,
185                include_metadata: false,
186            },
187            performance: PerformanceConfig {
188                num_threads: num_cpus::get(),
189                parallel: true,
190                memory_limit: 2048,
191                profile: false,
192            },
193            cache: CacheConfig {
194                enabled: true,
195                capacity: 1000,
196                similarity_threshold: 0.95,
197                ttl: 3600,
198                vector_dimension: 512,
199                persistent: false,
200                cache_dir: ".cache/scipix".to_string(),
201            },
202        }
203    }
204}
205
206impl Config {
207    /// Load configuration from TOML file
208    ///
209    /// # Arguments
210    ///
211    /// * `path` - Path to TOML configuration file
212    ///
213    /// # Examples
214    ///
215    /// ```rust,no_run
216    /// use ruvector_scipix::Config;
217    ///
218    /// let config = Config::from_file("scipix.toml").unwrap();
219    /// ```
220    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
221        let content = std::fs::read_to_string(path)?;
222        let config: Config = toml::from_str(&content)?;
223        config.validate()?;
224        Ok(config)
225    }
226
227    /// Save configuration to TOML file
228    ///
229    /// # Arguments
230    ///
231    /// * `path` - Path to save TOML configuration
232    pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
233        let content = toml::to_string_pretty(self)?;
234        std::fs::write(path, content)?;
235        Ok(())
236    }
237
238    /// Load configuration from environment variables
239    ///
240    /// Environment variables should be prefixed with `MATHPIX_`
241    /// and use double underscores for nested fields.
242    ///
243    /// # Examples
244    ///
245    /// ```bash
246    /// export MATHPIX_OCR__CONFIDENCE_THRESHOLD=0.8
247    /// export MATHPIX_MODEL__BATCH_SIZE=4
248    /// ```
249    pub fn from_env() -> Result<Self> {
250        let mut config = Self::default();
251        config.apply_env_overrides()?;
252        Ok(config)
253    }
254
255    /// Apply environment variable overrides
256    fn apply_env_overrides(&mut self) -> Result<()> {
257        // OCR overrides
258        if let Ok(val) = std::env::var("MATHPIX_OCR__CONFIDENCE_THRESHOLD") {
259            self.ocr.confidence_threshold = val
260                .parse()
261                .map_err(|_| ScipixError::Config("Invalid confidence_threshold".to_string()))?;
262        }
263        if let Ok(val) = std::env::var("MATHPIX_OCR__TIMEOUT") {
264            self.ocr.timeout = val
265                .parse()
266                .map_err(|_| ScipixError::Config("Invalid timeout".to_string()))?;
267        }
268        if let Ok(val) = std::env::var("MATHPIX_OCR__USE_GPU") {
269            self.ocr.use_gpu = val
270                .parse()
271                .map_err(|_| ScipixError::Config("Invalid use_gpu".to_string()))?;
272        }
273
274        // Model overrides
275        if let Ok(val) = std::env::var("MATHPIX_MODEL__PATH") {
276            self.model.model_path = val;
277        }
278        if let Ok(val) = std::env::var("MATHPIX_MODEL__BATCH_SIZE") {
279            self.model.batch_size = val
280                .parse()
281                .map_err(|_| ScipixError::Config("Invalid batch_size".to_string()))?;
282        }
283
284        // Cache overrides
285        if let Ok(val) = std::env::var("MATHPIX_CACHE__ENABLED") {
286            self.cache.enabled = val
287                .parse()
288                .map_err(|_| ScipixError::Config("Invalid cache enabled".to_string()))?;
289        }
290        if let Ok(val) = std::env::var("MATHPIX_CACHE__CAPACITY") {
291            self.cache.capacity = val
292                .parse()
293                .map_err(|_| ScipixError::Config("Invalid cache capacity".to_string()))?;
294        }
295
296        Ok(())
297    }
298
299    /// Validate configuration
300    pub fn validate(&self) -> Result<()> {
301        // Validate confidence threshold
302        if self.ocr.confidence_threshold < 0.0 || self.ocr.confidence_threshold > 1.0 {
303            return Err(ScipixError::Config(
304                "confidence_threshold must be between 0.0 and 1.0".to_string(),
305            ));
306        }
307
308        // Validate similarity threshold
309        if self.cache.similarity_threshold < 0.0 || self.cache.similarity_threshold > 1.0 {
310            return Err(ScipixError::Config(
311                "similarity_threshold must be between 0.0 and 1.0".to_string(),
312            ));
313        }
314
315        // Validate batch size
316        if self.model.batch_size == 0 {
317            return Err(ScipixError::Config(
318                "batch_size must be greater than 0".to_string(),
319            ));
320        }
321
322        // Validate precision
323        let valid_precisions = ["fp16", "fp32", "int8"];
324        if !valid_precisions.contains(&self.model.precision.as_str()) {
325            return Err(ScipixError::Config(format!(
326                "precision must be one of: {:?}",
327                valid_precisions
328            )));
329        }
330
331        // Validate output formats
332        let valid_formats = ["latex", "mathml", "asciimath"];
333        for format in &self.output.formats {
334            if !valid_formats.contains(&format.as_str()) {
335                return Err(ScipixError::Config(format!(
336                    "Invalid output format: {}. Must be one of: {:?}",
337                    format, valid_formats
338                )));
339            }
340        }
341
342        Ok(())
343    }
344
345    /// Create high-accuracy preset configuration
346    pub fn high_accuracy() -> Self {
347        let mut config = Self::default();
348        config.ocr.confidence_threshold = 0.9;
349        config.model.precision = "fp32".to_string();
350        config.model.quantize = false;
351        config.preprocess.denoise = true;
352        config.preprocess.enhance_contrast = true;
353        config.cache.similarity_threshold = 0.98;
354        config
355    }
356
357    /// Create high-speed preset configuration
358    pub fn high_speed() -> Self {
359        let mut config = Self::default();
360        config.ocr.confidence_threshold = 0.6;
361        config.model.precision = "fp16".to_string();
362        config.model.quantize = true;
363        config.model.batch_size = 4;
364        config.preprocess.denoise = false;
365        config.preprocess.enhance_contrast = false;
366        config.performance.parallel = true;
367        config.cache.similarity_threshold = 0.85;
368        config
369    }
370
371    /// Create minimal configuration
372    pub fn minimal() -> Self {
373        let mut config = Self::default();
374        config.cache.enabled = false;
375        config.preprocess.denoise = false;
376        config.preprocess.enhance_contrast = false;
377        config.performance.parallel = false;
378        config
379    }
380}
381
#[cfg(test)]
mod tests {
    //! Unit tests for the defaults, the presets, validation, and the TOML
    //! round trip.

    use super::*;

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.ocr.confidence_threshold, 0.7);
        assert!(config.cache.enabled);
    }

    #[test]
    fn test_high_accuracy_config() {
        let config = Config::high_accuracy();
        assert!(config.validate().is_ok());
        assert_eq!(config.ocr.confidence_threshold, 0.9);
        assert_eq!(config.cache.similarity_threshold, 0.98);
    }

    #[test]
    fn test_high_speed_config() {
        let config = Config::high_speed();
        assert!(config.validate().is_ok());
        assert_eq!(config.model.precision, "fp16");
        assert!(config.model.quantize);
    }

    #[test]
    fn test_minimal_config() {
        let config = Config::minimal();
        assert!(config.validate().is_ok());
        assert!(!config.cache.enabled);
    }

    #[test]
    fn test_invalid_confidence_threshold() {
        let mut config = Config::default();
        config.ocr.confidence_threshold = 1.5;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_negative_confidence_threshold() {
        // The lower bound of the range must be enforced as well.
        let mut config = Config::default();
        config.ocr.confidence_threshold = -0.1;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_similarity_threshold() {
        let mut config = Config::default();
        config.cache.similarity_threshold = 1.5;
        assert!(config.validate().is_err());

        config.cache.similarity_threshold = -0.1;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_threshold_bounds_are_inclusive() {
        // Exactly 0.0 and exactly 1.0 are valid threshold values.
        let mut config = Config::default();
        config.ocr.confidence_threshold = 0.0;
        config.cache.similarity_threshold = 1.0;
        assert!(config.validate().is_ok());

        config.ocr.confidence_threshold = 1.0;
        config.cache.similarity_threshold = 0.0;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_invalid_batch_size() {
        let mut config = Config::default();
        config.model.batch_size = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_precision() {
        let mut config = Config::default();
        config.model.precision = "invalid".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_output_format() {
        let mut config = Config::default();
        config.output.formats = vec!["invalid".to_string()];
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_toml_serialization() {
        let config = Config::default();
        let toml_str = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&toml_str).unwrap();
        assert_eq!(
            config.ocr.confidence_threshold,
            deserialized.ocr.confidence_threshold
        );
        // Spot-check one field from several other sections so that a section
        // dropped or renamed during the round trip is noticed.
        assert_eq!(config.ocr.languages, deserialized.ocr.languages);
        assert_eq!(config.model.precision, deserialized.model.precision);
        assert_eq!(config.model.batch_size, deserialized.model.batch_size);
        assert_eq!(config.output.formats, deserialized.output.formats);
        assert_eq!(
            config.performance.num_threads,
            deserialized.performance.num_threads
        );
        assert_eq!(config.cache.capacity, deserialized.cache.capacity);
        assert_eq!(config.cache.cache_dir, deserialized.cache.cache_dir);
        assert!(deserialized.validate().is_ok());
    }
}