torsh_cli/
config.rs

1//! Configuration management for ToRSh CLI
2
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::path::{Path, PathBuf};
6use tracing::{debug, info};
7
8/// CLI configuration
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Config {
11    /// General settings
12    pub general: GeneralConfig,
13
14    /// Model operations settings
15    pub model: ModelConfig,
16
17    /// Training settings
18    pub training: TrainingConfig,
19
20    /// Hub settings
21    pub hub: HubConfig,
22
23    /// Benchmark settings
24    pub benchmark: BenchmarkConfig,
25
26    /// Development settings
27    pub dev: DevConfig,
28}
29
/// General settings shared across commands (the `general` section of the
/// configuration file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralConfig {
    /// Default output directory
    pub output_dir: PathBuf,

    /// Default cache directory (by default `cache` inside the `.torsh`
    /// directory under the user's home)
    pub cache_dir: PathBuf,

    /// Default device (cpu, cuda, cuda:0, etc.); the built-in default is `cpu`
    pub default_device: String,

    /// Number of worker threads; the built-in default is the machine's
    /// available parallelism, or 4 if that cannot be queried
    pub num_workers: usize,

    /// Memory limit in GB; `None` (the built-in default) means no limit is set
    pub memory_limit_gb: Option<f64>,

    /// Enable progress bars (built-in default: `true`)
    pub show_progress: bool,

    /// Default data type (f32, f16, bf16); the built-in default is `f32`
    pub default_dtype: String,
}
53
/// Model-operation settings (the `model` section of the configuration file).
///
/// Nested sections are declared after plain fields: TOML places a table's
/// plain keys before its sub-tables, so keep this ordering when adding fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Default model format for conversion (built-in default: `torsh`)
    pub default_format: String,

    /// Model optimization settings
    pub optimization: OptimizationConfig,

    /// Model validation settings
    pub validation: ValidationConfig,
}
65
/// Model optimization settings (`model.optimization` in the configuration
/// file), grouping quantization, pruning and operator fusion.
///
/// Keep the nested sections after the plain `auto_optimize` flag so TOML
/// output has plain keys before sub-tables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
    /// Enable automatic optimization (built-in default: `false`)
    pub auto_optimize: bool,

    /// Quantization settings
    pub quantization: QuantizationConfig,

    /// Pruning settings
    pub pruning: PruningConfig,

    /// Fusion settings
    pub fusion: FusionConfig,
}
80
/// Quantization settings (`model.optimization.quantization`).
///
/// `method` and `precision` are free-form strings; this module does not
/// check them against the listed options when loading a file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Enable quantization by default (built-in default: `false`)
    pub enabled: bool,

    /// Default quantization method (dynamic, static, qat); built-in default `dynamic`
    pub method: String,

    /// Default precision (int8, int4, mixed); built-in default `int8`
    pub precision: String,

    /// Calibration dataset size (built-in default: 1000)
    pub calibration_samples: usize,
}
95
/// Pruning settings (`model.optimization.pruning`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
    /// Enable pruning by default (built-in default: `false`)
    pub enabled: bool,

    /// Default sparsity target (0.0-1.0); built-in default 0.5.
    /// NOTE(review): the range is not enforced when the file is loaded here.
    pub sparsity: f64,

    /// Pruning method (magnitude, gradient, structured); built-in default `magnitude`
    pub method: String,
}
107
/// Operator-fusion settings (`model.optimization.fusion`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Enable operator fusion (built-in default: `true`)
    pub enabled: bool,

    /// Fusion patterns to apply, by name; the built-in default is
    /// `conv_bn_relu` and `linear_relu`
    pub patterns: Vec<String>,
}
116
/// Model validation settings (`model.validation`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationConfig {
    /// Enable model validation by default (built-in default: `true`)
    pub enabled: bool,

    /// Validation dataset path; `None` (the built-in default) when unset
    pub dataset_path: Option<PathBuf>,

    /// Number of validation samples (built-in default: 1000)
    pub num_samples: usize,

    /// Accuracy threshold for validation (built-in default: 0.95).
    /// Presumably a fraction in 0.0-1.0 — confirm against the consumer.
    pub accuracy_threshold: f64,
}
131
/// Training settings (the `training` section of the configuration file).
///
/// `distributed` is declared last so TOML output has all plain keys before
/// the sub-table; keep new plain fields above it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Default training configuration directory (by default `configs` inside
    /// the `.torsh` directory under the user's home)
    pub config_dir: PathBuf,

    /// Default checkpoint directory (built-in default `./checkpoints`,
    /// relative to the process working directory)
    pub checkpoint_dir: PathBuf,

    /// Default logging directory (built-in default `./logs`, relative to the
    /// process working directory)
    pub log_dir: PathBuf,

    /// Auto-resume from latest checkpoint (built-in default: `false`)
    pub auto_resume: bool,

    /// Save checkpoint every N epochs (built-in default: 1)
    pub checkpoint_frequency: usize,

    /// Early stopping patience (built-in default: 10; presumably counted in
    /// epochs — confirm)
    pub early_stopping_patience: usize,

    /// Mixed precision training (built-in default: `true`)
    pub mixed_precision: bool,

    /// Distributed training settings
    pub distributed: DistributedConfig,
}
158
/// Distributed-training settings (`training.distributed`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Backend for distributed training (nccl, gloo, mpi); built-in default `nccl`.
    /// NOTE(review): the built-in default device is `cpu` while the default
    /// backend is `nccl` — confirm this combination is intended.
    pub backend: String,

    /// Master address for distributed training (built-in default: `localhost`)
    pub master_addr: String,

    /// Master port for distributed training (built-in default: 29500)
    pub master_port: u16,

    /// Auto-detect distributed environment (built-in default: `true`)
    pub auto_detect: bool,
}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct HubConfig {
176    /// Hub API endpoint
177    pub api_endpoint: String,
178
179    /// Authentication token
180    pub auth_token: Option<String>,
181
182    /// Default organization
183    pub organization: Option<String>,
184
185    /// Model cache directory
186    pub cache_dir: PathBuf,
187
188    /// Enable model signature verification
189    pub verify_signatures: bool,
190
191    /// Connection timeout in seconds
192    pub timeout_seconds: u64,
193}
194
/// Benchmark settings (the `benchmark` section of the configuration file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkConfig {
    /// Default number of warmup iterations (built-in default: 10)
    pub warmup_iterations: usize,

    /// Default number of benchmark iterations (built-in default: 100)
    pub benchmark_iterations: usize,

    /// Default batch sizes to test (built-in default: 1, 4, 8, 16, 32, 64)
    pub batch_sizes: Vec<usize>,

    /// Enable memory tracking (built-in default: `true`)
    pub track_memory: bool,

    /// Enable power tracking (if available); built-in default `false`
    pub track_power: bool,

    /// Benchmark output directory (built-in default `./benchmarks`, relative
    /// to the process working directory)
    pub output_dir: PathBuf,

    /// Enable detailed profiling (built-in default: `false`)
    pub detailed_profiling: bool,
}
218
/// Development settings (the `dev` section of the configuration file).
///
/// Nested `codegen` and `testing` sections are declared after the plain
/// flags so TOML output has plain keys before sub-tables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DevConfig {
    /// Enable development mode features (built-in default: `false`)
    pub enabled: bool,

    /// Enable debug logging (built-in default: `false`)
    pub debug_logging: bool,

    /// Enable experimental features (built-in default: `false`)
    pub experimental_features: bool,

    /// Code generation settings
    pub codegen: CodegenConfig,

    /// Testing settings
    pub testing: TestingConfig,
}
236
/// Code-generation settings (`dev.codegen`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodegenConfig {
    /// Enable automatic code generation (built-in default: `false`)
    pub enabled: bool,

    /// Output directory for generated code (built-in default `./generated`)
    pub output_dir: PathBuf,

    /// Code generation templates directory (by default `templates` inside the
    /// `.torsh` directory under the user's home)
    pub templates_dir: PathBuf,
}
248
/// Testing settings (`dev.testing`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestingConfig {
    /// Enable automatic testing (built-in default: `true`)
    pub enabled: bool,

    /// Test data directory (built-in default `./test_data`)
    pub test_data_dir: PathBuf,

    /// Tolerance for numerical tests (built-in default: 1e-6)
    pub numerical_tolerance: f64,
}
260
261impl Default for Config {
262    fn default() -> Self {
263        let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
264        let torsh_dir = home_dir.join(".torsh");
265
266        Self {
267            general: GeneralConfig {
268                output_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
269                cache_dir: torsh_dir.join("cache"),
270                default_device: "cpu".to_string(),
271                num_workers: std::thread::available_parallelism()
272                    .map(|n| n.get())
273                    .unwrap_or(4),
274                memory_limit_gb: None,
275                show_progress: true,
276                default_dtype: "f32".to_string(),
277            },
278            model: ModelConfig {
279                default_format: "torsh".to_string(),
280                optimization: OptimizationConfig {
281                    auto_optimize: false,
282                    quantization: QuantizationConfig {
283                        enabled: false,
284                        method: "dynamic".to_string(),
285                        precision: "int8".to_string(),
286                        calibration_samples: 1000,
287                    },
288                    pruning: PruningConfig {
289                        enabled: false,
290                        sparsity: 0.5,
291                        method: "magnitude".to_string(),
292                    },
293                    fusion: FusionConfig {
294                        enabled: true,
295                        patterns: vec!["conv_bn_relu".to_string(), "linear_relu".to_string()],
296                    },
297                },
298                validation: ValidationConfig {
299                    enabled: true,
300                    dataset_path: None,
301                    num_samples: 1000,
302                    accuracy_threshold: 0.95,
303                },
304            },
305            training: TrainingConfig {
306                config_dir: torsh_dir.join("configs"),
307                checkpoint_dir: PathBuf::from("./checkpoints"),
308                log_dir: PathBuf::from("./logs"),
309                auto_resume: false,
310                checkpoint_frequency: 1,
311                early_stopping_patience: 10,
312                mixed_precision: true,
313                distributed: DistributedConfig {
314                    backend: "nccl".to_string(),
315                    master_addr: "localhost".to_string(),
316                    master_port: 29500,
317                    auto_detect: true,
318                },
319            },
320            hub: HubConfig {
321                api_endpoint: "https://hub.torsh.dev".to_string(),
322                auth_token: None,
323                organization: None,
324                cache_dir: torsh_dir.join("hub"),
325                verify_signatures: true,
326                timeout_seconds: 300,
327            },
328            benchmark: BenchmarkConfig {
329                warmup_iterations: 10,
330                benchmark_iterations: 100,
331                batch_sizes: vec![1, 4, 8, 16, 32, 64],
332                track_memory: true,
333                track_power: false,
334                output_dir: PathBuf::from("./benchmarks"),
335                detailed_profiling: false,
336            },
337            dev: DevConfig {
338                enabled: false,
339                debug_logging: false,
340                experimental_features: false,
341                codegen: CodegenConfig {
342                    enabled: false,
343                    output_dir: PathBuf::from("./generated"),
344                    templates_dir: torsh_dir.join("templates"),
345                },
346                testing: TestingConfig {
347                    enabled: true,
348                    test_data_dir: PathBuf::from("./test_data"),
349                    numerical_tolerance: 1e-6,
350                },
351            },
352        }
353    }
354}
355
356/// Load configuration from file or create default
357pub async fn load_config(config_path: Option<&Path>) -> Result<Config> {
358    let config_path = if let Some(path) = config_path {
359        path.to_path_buf()
360    } else {
361        get_default_config_path()?
362    };
363
364    if config_path.exists() {
365        debug!("Loading configuration from: {}", config_path.display());
366        load_config_from_file(&config_path).await
367    } else {
368        info!("Configuration file not found, using defaults");
369        let config = Config::default();
370
371        // Create config directory if it doesn't exist
372        if let Some(parent) = config_path.parent() {
373            tokio::fs::create_dir_all(parent).await.with_context(|| {
374                format!("Failed to create config directory: {}", parent.display())
375            })?;
376        }
377
378        // Save default configuration
379        save_config(&config, &config_path)
380            .await
381            .with_context(|| "Failed to save default configuration")?;
382
383        Ok(config)
384    }
385}
386
387/// Load configuration from a specific file
388async fn load_config_from_file(path: &Path) -> Result<Config> {
389    let content = tokio::fs::read_to_string(path)
390        .await
391        .with_context(|| format!("Failed to read config file: {}", path.display()))?;
392
393    let config =
394        match path.extension().and_then(|ext| ext.to_str()) {
395            Some("yaml") | Some("yml") => serde_yaml::from_str(&content)
396                .with_context(|| "Failed to parse YAML configuration")?,
397            Some("json") => serde_json::from_str(&content)
398                .with_context(|| "Failed to parse JSON configuration")?,
399            Some("toml") => {
400                toml::from_str(&content).with_context(|| "Failed to parse TOML configuration")?
401            }
402            _ => {
403                // Try to detect format
404                if content.trim_start().starts_with('{') {
405                    serde_json::from_str(&content)
406                        .with_context(|| "Failed to parse JSON configuration")?
407                } else {
408                    serde_yaml::from_str(&content)
409                        .with_context(|| "Failed to parse YAML configuration")?
410                }
411            }
412        };
413
414    Ok(config)
415}
416
417/// Save configuration to file
418pub async fn save_config(config: &Config, path: &Path) -> Result<()> {
419    let content = match path.extension().and_then(|ext| ext.to_str()) {
420        Some("json") => serde_json::to_string_pretty(config)
421            .with_context(|| "Failed to serialize configuration to JSON")?,
422        Some("toml") => toml::to_string_pretty(config)
423            .with_context(|| "Failed to serialize configuration to TOML")?,
424        _ => {
425            // Default to YAML
426            serde_yaml::to_string(config)
427                .with_context(|| "Failed to serialize configuration to YAML")?
428        }
429    };
430
431    tokio::fs::write(path, content)
432        .await
433        .with_context(|| format!("Failed to write config file: {}", path.display()))?;
434
435    info!("Configuration saved to: {}", path.display());
436    Ok(())
437}
438
439/// Get the default configuration file path
440fn get_default_config_path() -> Result<PathBuf> {
441    let config_dir = dirs::config_dir()
442        .or_else(|| dirs::home_dir().map(|h| h.join(".config")))
443        .unwrap_or_else(|| PathBuf::from("."));
444
445    Ok(config_dir.join("torsh").join("config.yaml"))
446}
447
448/// Initialize configuration directories
449pub async fn init_config_dirs(config: &Config) -> Result<()> {
450    let dirs = [
451        &config.general.cache_dir,
452        &config.training.config_dir,
453        &config.training.checkpoint_dir,
454        &config.training.log_dir,
455        &config.hub.cache_dir,
456        &config.benchmark.output_dir,
457    ];
458
459    for dir in dirs {
460        if !dir.exists() {
461            tokio::fs::create_dir_all(dir)
462                .await
463                .with_context(|| format!("Failed to create directory: {}", dir.display()))?;
464            debug!("Created directory: {}", dir.display());
465        }
466    }
467
468    Ok(())
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Serialize to compact JSON so whole configs can be compared without
    /// requiring `PartialEq` on every configuration struct.
    fn fingerprint(config: &Config) -> String {
        serde_json::to_string(config).expect("configuration serializes to JSON")
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.general.default_device, "cpu");
        assert_eq!(config.model.default_format, "torsh");
    }

    #[test]
    fn test_config_serialization() {
        let config = Config::default();

        // YAML round trip
        let yaml = serde_yaml::to_string(&config).unwrap();
        let parsed: Config = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(fingerprint(&config), fingerprint(&parsed));

        // JSON round trip
        let json = serde_json::to_string_pretty(&config).unwrap();
        let parsed: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(fingerprint(&config), fingerprint(&parsed));

        // TOML round trip (sensitive to plain keys preceding sub-tables)
        let toml_text = toml::to_string_pretty(&config).unwrap();
        let parsed: Config = toml::from_str(&toml_text).unwrap();
        assert_eq!(fingerprint(&config), fingerprint(&parsed));
    }

    #[tokio::test]
    async fn test_config_file_operations() {
        let temp_dir = tempdir().unwrap();
        let config = Config::default();

        // Exercise every format `save_config` can emit.
        for name in ["test_config.yaml", "test_config.json", "test_config.toml"] {
            let config_path = temp_dir.path().join(name);

            save_config(&config, &config_path).await.unwrap();
            assert!(config_path.exists(), "{} was not written", name);

            let loaded_config = load_config_from_file(&config_path).await.unwrap();
            assert_eq!(
                fingerprint(&config),
                fingerprint(&loaded_config),
                "round trip through {} changed the configuration",
                name
            );
        }
    }

    #[tokio::test]
    async fn test_load_config_writes_defaults_when_missing() {
        let temp_dir = tempdir().unwrap();
        // Parent directory does not exist yet; `load_config` must create it.
        let config_path = temp_dir.path().join("nested").join("config.yaml");

        let created = load_config(Some(&config_path)).await.unwrap();
        assert!(config_path.exists());

        // Second call must read back exactly what the first call wrote.
        let reloaded = load_config(Some(&config_path)).await.unwrap();
        assert_eq!(fingerprint(&created), fingerprint(&reloaded));
    }
}