Skip to main content

torsh_cli/
config.rs

1//! Configuration management for ToRSh CLI
2
3// Framework infrastructure - components designed for future use
4#![allow(dead_code)]
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::path::{Path, PathBuf};
8use tracing::{debug, info};
9
10/// CLI configuration
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Config {
13    /// General settings
14    pub general: GeneralConfig,
15
16    /// Model operations settings
17    pub model: ModelConfig,
18
19    /// Training settings
20    pub training: TrainingConfig,
21
22    /// Hub settings
23    pub hub: HubConfig,
24
25    /// Benchmark settings
26    pub benchmark: BenchmarkConfig,
27
28    /// Development settings
29    pub dev: DevConfig,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct GeneralConfig {
34    /// Default output directory
35    pub output_dir: PathBuf,
36
37    /// Default cache directory
38    pub cache_dir: PathBuf,
39
40    /// Default device (cpu, cuda, cuda:0, etc.)
41    pub default_device: String,
42
43    /// Number of worker threads
44    pub num_workers: usize,
45
46    /// Memory limit in GB
47    pub memory_limit_gb: Option<f64>,
48
49    /// Enable progress bars
50    pub show_progress: bool,
51
52    /// Default data type (f32, f16, bf16)
53    pub default_dtype: String,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ModelConfig {
58    /// Default model format for conversion
59    pub default_format: String,
60
61    /// Model optimization settings
62    pub optimization: OptimizationConfig,
63
64    /// Model validation settings
65    pub validation: ValidationConfig,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct OptimizationConfig {
70    /// Enable automatic optimization
71    pub auto_optimize: bool,
72
73    /// Quantization settings
74    pub quantization: QuantizationConfig,
75
76    /// Pruning settings
77    pub pruning: PruningConfig,
78
79    /// Fusion settings
80    pub fusion: FusionConfig,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct QuantizationConfig {
85    /// Enable quantization by default
86    pub enabled: bool,
87
88    /// Default quantization method (dynamic, static, qat)
89    pub method: String,
90
91    /// Default precision (int8, int4, mixed)
92    pub precision: String,
93
94    /// Calibration dataset size
95    pub calibration_samples: usize,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct PruningConfig {
100    /// Enable pruning by default
101    pub enabled: bool,
102
103    /// Default sparsity target (0.0-1.0)
104    pub sparsity: f64,
105
106    /// Pruning method (magnitude, gradient, structured)
107    pub method: String,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct FusionConfig {
112    /// Enable operator fusion
113    pub enabled: bool,
114
115    /// Fusion patterns to apply
116    pub patterns: Vec<String>,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct ValidationConfig {
121    /// Enable model validation by default
122    pub enabled: bool,
123
124    /// Validation dataset path
125    pub dataset_path: Option<PathBuf>,
126
127    /// Number of validation samples
128    pub num_samples: usize,
129
130    /// Accuracy threshold for validation
131    pub accuracy_threshold: f64,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct TrainingConfig {
136    /// Default training configuration directory
137    pub config_dir: PathBuf,
138
139    /// Default checkpoint directory
140    pub checkpoint_dir: PathBuf,
141
142    /// Default logging directory
143    pub log_dir: PathBuf,
144
145    /// Auto-resume from latest checkpoint
146    pub auto_resume: bool,
147
148    /// Save checkpoint every N epochs
149    pub checkpoint_frequency: usize,
150
151    /// Early stopping patience
152    pub early_stopping_patience: usize,
153
154    /// Mixed precision training
155    pub mixed_precision: bool,
156
157    /// Distributed training settings
158    pub distributed: DistributedConfig,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct DistributedConfig {
163    /// Backend for distributed training (nccl, gloo, mpi)
164    pub backend: String,
165
166    /// Master address for distributed training
167    pub master_addr: String,
168
169    /// Master port for distributed training
170    pub master_port: u16,
171
172    /// Auto-detect distributed environment
173    pub auto_detect: bool,
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct HubConfig {
178    /// Hub API endpoint
179    pub api_endpoint: String,
180
181    /// Authentication token
182    pub auth_token: Option<String>,
183
184    /// Default organization
185    pub organization: Option<String>,
186
187    /// Model cache directory
188    pub cache_dir: PathBuf,
189
190    /// Enable model signature verification
191    pub verify_signatures: bool,
192
193    /// Connection timeout in seconds
194    pub timeout_seconds: u64,
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct BenchmarkConfig {
199    /// Default number of warmup iterations
200    pub warmup_iterations: usize,
201
202    /// Default number of benchmark iterations
203    pub benchmark_iterations: usize,
204
205    /// Default batch sizes to test
206    pub batch_sizes: Vec<usize>,
207
208    /// Enable memory tracking
209    pub track_memory: bool,
210
211    /// Enable power tracking (if available)
212    pub track_power: bool,
213
214    /// Benchmark output directory
215    pub output_dir: PathBuf,
216
217    /// Enable detailed profiling
218    pub detailed_profiling: bool,
219}
220
221#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct DevConfig {
223    /// Enable development mode features
224    pub enabled: bool,
225
226    /// Enable debug logging
227    pub debug_logging: bool,
228
229    /// Enable experimental features
230    pub experimental_features: bool,
231
232    /// Code generation settings
233    pub codegen: CodegenConfig,
234
235    /// Testing settings
236    pub testing: TestingConfig,
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct CodegenConfig {
241    /// Enable automatic code generation
242    pub enabled: bool,
243
244    /// Output directory for generated code
245    pub output_dir: PathBuf,
246
247    /// Code generation templates directory
248    pub templates_dir: PathBuf,
249}
250
251#[derive(Debug, Clone, Serialize, Deserialize)]
252pub struct TestingConfig {
253    /// Enable automatic testing
254    pub enabled: bool,
255
256    /// Test data directory
257    pub test_data_dir: PathBuf,
258
259    /// Tolerance for numerical tests
260    pub numerical_tolerance: f64,
261}
262
263impl Default for Config {
264    fn default() -> Self {
265        let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
266        let torsh_dir = home_dir.join(".torsh");
267
268        Self {
269            general: GeneralConfig {
270                output_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
271                cache_dir: torsh_dir.join("cache"),
272                default_device: "cpu".to_string(),
273                num_workers: std::thread::available_parallelism()
274                    .map(|n| n.get())
275                    .unwrap_or(4),
276                memory_limit_gb: None,
277                show_progress: true,
278                default_dtype: "f32".to_string(),
279            },
280            model: ModelConfig {
281                default_format: "torsh".to_string(),
282                optimization: OptimizationConfig {
283                    auto_optimize: false,
284                    quantization: QuantizationConfig {
285                        enabled: false,
286                        method: "dynamic".to_string(),
287                        precision: "int8".to_string(),
288                        calibration_samples: 1000,
289                    },
290                    pruning: PruningConfig {
291                        enabled: false,
292                        sparsity: 0.5,
293                        method: "magnitude".to_string(),
294                    },
295                    fusion: FusionConfig {
296                        enabled: true,
297                        patterns: vec!["conv_bn_relu".to_string(), "linear_relu".to_string()],
298                    },
299                },
300                validation: ValidationConfig {
301                    enabled: true,
302                    dataset_path: None,
303                    num_samples: 1000,
304                    accuracy_threshold: 0.95,
305                },
306            },
307            training: TrainingConfig {
308                config_dir: torsh_dir.join("configs"),
309                checkpoint_dir: PathBuf::from("./checkpoints"),
310                log_dir: PathBuf::from("./logs"),
311                auto_resume: false,
312                checkpoint_frequency: 1,
313                early_stopping_patience: 10,
314                mixed_precision: true,
315                distributed: DistributedConfig {
316                    backend: "nccl".to_string(),
317                    master_addr: "localhost".to_string(),
318                    master_port: 29500,
319                    auto_detect: true,
320                },
321            },
322            hub: HubConfig {
323                api_endpoint: "https://hub.torsh.dev".to_string(),
324                auth_token: None,
325                organization: None,
326                cache_dir: torsh_dir.join("hub"),
327                verify_signatures: true,
328                timeout_seconds: 300,
329            },
330            benchmark: BenchmarkConfig {
331                warmup_iterations: 10,
332                benchmark_iterations: 100,
333                batch_sizes: vec![1, 4, 8, 16, 32, 64],
334                track_memory: true,
335                track_power: false,
336                output_dir: PathBuf::from("./benchmarks"),
337                detailed_profiling: false,
338            },
339            dev: DevConfig {
340                enabled: false,
341                debug_logging: false,
342                experimental_features: false,
343                codegen: CodegenConfig {
344                    enabled: false,
345                    output_dir: PathBuf::from("./generated"),
346                    templates_dir: torsh_dir.join("templates"),
347                },
348                testing: TestingConfig {
349                    enabled: true,
350                    test_data_dir: PathBuf::from("./test_data"),
351                    numerical_tolerance: 1e-6,
352                },
353            },
354        }
355    }
356}
357
358/// Load configuration from file or create default
359pub async fn load_config(config_path: Option<&Path>) -> Result<Config> {
360    let config_path = if let Some(path) = config_path {
361        path.to_path_buf()
362    } else {
363        get_default_config_path()?
364    };
365
366    if config_path.exists() {
367        debug!("Loading configuration from: {}", config_path.display());
368        load_config_from_file(&config_path).await
369    } else {
370        info!("Configuration file not found, using defaults");
371        let config = Config::default();
372
373        // Create config directory if it doesn't exist
374        if let Some(parent) = config_path.parent() {
375            tokio::fs::create_dir_all(parent).await.with_context(|| {
376                format!("Failed to create config directory: {}", parent.display())
377            })?;
378        }
379
380        // Save default configuration
381        save_config(&config, &config_path)
382            .await
383            .with_context(|| "Failed to save default configuration")?;
384
385        Ok(config)
386    }
387}
388
389/// Load configuration from a specific file
390async fn load_config_from_file(path: &Path) -> Result<Config> {
391    let content = tokio::fs::read_to_string(path)
392        .await
393        .with_context(|| format!("Failed to read config file: {}", path.display()))?;
394
395    let config =
396        match path.extension().and_then(|ext| ext.to_str()) {
397            Some("yaml") | Some("yml") => serde_yaml::from_str(&content)
398                .with_context(|| "Failed to parse YAML configuration")?,
399            Some("json") => serde_json::from_str(&content)
400                .with_context(|| "Failed to parse JSON configuration")?,
401            Some("toml") => {
402                toml::from_str(&content).with_context(|| "Failed to parse TOML configuration")?
403            }
404            _ => {
405                // Try to detect format
406                if content.trim_start().starts_with('{') {
407                    serde_json::from_str(&content)
408                        .with_context(|| "Failed to parse JSON configuration")?
409                } else {
410                    serde_yaml::from_str(&content)
411                        .with_context(|| "Failed to parse YAML configuration")?
412                }
413            }
414        };
415
416    Ok(config)
417}
418
419/// Save configuration to file
420pub async fn save_config(config: &Config, path: &Path) -> Result<()> {
421    let content = match path.extension().and_then(|ext| ext.to_str()) {
422        Some("json") => serde_json::to_string_pretty(config)
423            .with_context(|| "Failed to serialize configuration to JSON")?,
424        Some("toml") => toml::to_string_pretty(config)
425            .with_context(|| "Failed to serialize configuration to TOML")?,
426        _ => {
427            // Default to YAML
428            serde_yaml::to_string(config)
429                .with_context(|| "Failed to serialize configuration to YAML")?
430        }
431    };
432
433    tokio::fs::write(path, content)
434        .await
435        .with_context(|| format!("Failed to write config file: {}", path.display()))?;
436
437    info!("Configuration saved to: {}", path.display());
438    Ok(())
439}
440
441/// Get the default configuration file path
442fn get_default_config_path() -> Result<PathBuf> {
443    let config_dir = dirs::config_dir()
444        .or_else(|| dirs::home_dir().map(|h| h.join(".config")))
445        .unwrap_or_else(|| PathBuf::from("."));
446
447    Ok(config_dir.join("torsh").join("config.yaml"))
448}
449
450/// Initialize configuration directories
451pub async fn init_config_dirs(config: &Config) -> Result<()> {
452    let dirs = [
453        &config.general.cache_dir,
454        &config.training.config_dir,
455        &config.training.checkpoint_dir,
456        &config.training.log_dir,
457        &config.hub.cache_dir,
458        &config.benchmark.output_dir,
459    ];
460
461    for dir in dirs {
462        if !dir.exists() {
463            tokio::fs::create_dir_all(dir)
464                .await
465                .with_context(|| format!("Failed to create directory: {}", dir.display()))?;
466            debug!("Created directory: {}", dir.display());
467        }
468    }
469
470    Ok(())
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// The built-in defaults use the `cpu` device and the `torsh` format.
    #[tokio::test]
    async fn test_default_config() {
        let defaults = Config::default();
        assert_eq!(defaults.general.default_device, "cpu");
        assert_eq!(defaults.model.default_format, "torsh");
    }

    /// A default config survives a YAML and a JSON round trip.
    #[tokio::test]
    async fn test_config_serialization() {
        let original = Config::default();

        // YAML round trip
        let as_yaml = serde_yaml::to_string(&original).unwrap();
        let from_yaml: Config = serde_yaml::from_str(&as_yaml).unwrap();
        assert_eq!(
            original.general.default_device,
            from_yaml.general.default_device
        );

        // JSON round trip
        let as_json = serde_json::to_string_pretty(&original).unwrap();
        let from_json: Config = serde_json::from_str(&as_json).unwrap();
        assert_eq!(
            original.general.default_device,
            from_json.general.default_device
        );
    }

    /// Saving to a `.yaml` file and loading it back yields the same device.
    #[tokio::test]
    async fn test_config_file_operations() {
        let scratch = tempdir().unwrap();
        let file_path = scratch.path().join("test_config.yaml");
        let original = Config::default();

        // Write the file and confirm it landed on disk.
        save_config(&original, &file_path).await.unwrap();
        assert!(file_path.exists());

        // Read it back and compare.
        let reloaded = load_config_from_file(&file_path).await.unwrap();
        assert_eq!(
            original.general.default_device,
            reloaded.general.default_device
        );
    }
}
518}