1use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::path::{Path, PathBuf};
6use tracing::{debug, info};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Config {
11 pub general: GeneralConfig,
13
14 pub model: ModelConfig,
16
17 pub training: TrainingConfig,
19
20 pub hub: HubConfig,
22
23 pub benchmark: BenchmarkConfig,
25
26 pub dev: DevConfig,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct GeneralConfig {
32 pub output_dir: PathBuf,
34
35 pub cache_dir: PathBuf,
37
38 pub default_device: String,
40
41 pub num_workers: usize,
43
44 pub memory_limit_gb: Option<f64>,
46
47 pub show_progress: bool,
49
50 pub default_dtype: String,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct ModelConfig {
56 pub default_format: String,
58
59 pub optimization: OptimizationConfig,
61
62 pub validation: ValidationConfig,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct OptimizationConfig {
68 pub auto_optimize: bool,
70
71 pub quantization: QuantizationConfig,
73
74 pub pruning: PruningConfig,
76
77 pub fusion: FusionConfig,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct QuantizationConfig {
83 pub enabled: bool,
85
86 pub method: String,
88
89 pub precision: String,
91
92 pub calibration_samples: usize,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct PruningConfig {
98 pub enabled: bool,
100
101 pub sparsity: f64,
103
104 pub method: String,
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct FusionConfig {
110 pub enabled: bool,
112
113 pub patterns: Vec<String>,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct ValidationConfig {
119 pub enabled: bool,
121
122 pub dataset_path: Option<PathBuf>,
124
125 pub num_samples: usize,
127
128 pub accuracy_threshold: f64,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct TrainingConfig {
134 pub config_dir: PathBuf,
136
137 pub checkpoint_dir: PathBuf,
139
140 pub log_dir: PathBuf,
142
143 pub auto_resume: bool,
145
146 pub checkpoint_frequency: usize,
148
149 pub early_stopping_patience: usize,
151
152 pub mixed_precision: bool,
154
155 pub distributed: DistributedConfig,
157}
158
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct DistributedConfig {
161 pub backend: String,
163
164 pub master_addr: String,
166
167 pub master_port: u16,
169
170 pub auto_detect: bool,
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct HubConfig {
176 pub api_endpoint: String,
178
179 pub auth_token: Option<String>,
181
182 pub organization: Option<String>,
184
185 pub cache_dir: PathBuf,
187
188 pub verify_signatures: bool,
190
191 pub timeout_seconds: u64,
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct BenchmarkConfig {
197 pub warmup_iterations: usize,
199
200 pub benchmark_iterations: usize,
202
203 pub batch_sizes: Vec<usize>,
205
206 pub track_memory: bool,
208
209 pub track_power: bool,
211
212 pub output_dir: PathBuf,
214
215 pub detailed_profiling: bool,
217}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct DevConfig {
221 pub enabled: bool,
223
224 pub debug_logging: bool,
226
227 pub experimental_features: bool,
229
230 pub codegen: CodegenConfig,
232
233 pub testing: TestingConfig,
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct CodegenConfig {
239 pub enabled: bool,
241
242 pub output_dir: PathBuf,
244
245 pub templates_dir: PathBuf,
247}
248
249#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct TestingConfig {
251 pub enabled: bool,
253
254 pub test_data_dir: PathBuf,
256
257 pub numerical_tolerance: f64,
259}
260
261impl Default for Config {
262 fn default() -> Self {
263 let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
264 let torsh_dir = home_dir.join(".torsh");
265
266 Self {
267 general: GeneralConfig {
268 output_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
269 cache_dir: torsh_dir.join("cache"),
270 default_device: "cpu".to_string(),
271 num_workers: std::thread::available_parallelism()
272 .map(|n| n.get())
273 .unwrap_or(4),
274 memory_limit_gb: None,
275 show_progress: true,
276 default_dtype: "f32".to_string(),
277 },
278 model: ModelConfig {
279 default_format: "torsh".to_string(),
280 optimization: OptimizationConfig {
281 auto_optimize: false,
282 quantization: QuantizationConfig {
283 enabled: false,
284 method: "dynamic".to_string(),
285 precision: "int8".to_string(),
286 calibration_samples: 1000,
287 },
288 pruning: PruningConfig {
289 enabled: false,
290 sparsity: 0.5,
291 method: "magnitude".to_string(),
292 },
293 fusion: FusionConfig {
294 enabled: true,
295 patterns: vec!["conv_bn_relu".to_string(), "linear_relu".to_string()],
296 },
297 },
298 validation: ValidationConfig {
299 enabled: true,
300 dataset_path: None,
301 num_samples: 1000,
302 accuracy_threshold: 0.95,
303 },
304 },
305 training: TrainingConfig {
306 config_dir: torsh_dir.join("configs"),
307 checkpoint_dir: PathBuf::from("./checkpoints"),
308 log_dir: PathBuf::from("./logs"),
309 auto_resume: false,
310 checkpoint_frequency: 1,
311 early_stopping_patience: 10,
312 mixed_precision: true,
313 distributed: DistributedConfig {
314 backend: "nccl".to_string(),
315 master_addr: "localhost".to_string(),
316 master_port: 29500,
317 auto_detect: true,
318 },
319 },
320 hub: HubConfig {
321 api_endpoint: "https://hub.torsh.dev".to_string(),
322 auth_token: None,
323 organization: None,
324 cache_dir: torsh_dir.join("hub"),
325 verify_signatures: true,
326 timeout_seconds: 300,
327 },
328 benchmark: BenchmarkConfig {
329 warmup_iterations: 10,
330 benchmark_iterations: 100,
331 batch_sizes: vec![1, 4, 8, 16, 32, 64],
332 track_memory: true,
333 track_power: false,
334 output_dir: PathBuf::from("./benchmarks"),
335 detailed_profiling: false,
336 },
337 dev: DevConfig {
338 enabled: false,
339 debug_logging: false,
340 experimental_features: false,
341 codegen: CodegenConfig {
342 enabled: false,
343 output_dir: PathBuf::from("./generated"),
344 templates_dir: torsh_dir.join("templates"),
345 },
346 testing: TestingConfig {
347 enabled: true,
348 test_data_dir: PathBuf::from("./test_data"),
349 numerical_tolerance: 1e-6,
350 },
351 },
352 }
353 }
354}
355
356pub async fn load_config(config_path: Option<&Path>) -> Result<Config> {
358 let config_path = if let Some(path) = config_path {
359 path.to_path_buf()
360 } else {
361 get_default_config_path()?
362 };
363
364 if config_path.exists() {
365 debug!("Loading configuration from: {}", config_path.display());
366 load_config_from_file(&config_path).await
367 } else {
368 info!("Configuration file not found, using defaults");
369 let config = Config::default();
370
371 if let Some(parent) = config_path.parent() {
373 tokio::fs::create_dir_all(parent).await.with_context(|| {
374 format!("Failed to create config directory: {}", parent.display())
375 })?;
376 }
377
378 save_config(&config, &config_path)
380 .await
381 .with_context(|| "Failed to save default configuration")?;
382
383 Ok(config)
384 }
385}
386
387async fn load_config_from_file(path: &Path) -> Result<Config> {
389 let content = tokio::fs::read_to_string(path)
390 .await
391 .with_context(|| format!("Failed to read config file: {}", path.display()))?;
392
393 let config =
394 match path.extension().and_then(|ext| ext.to_str()) {
395 Some("yaml") | Some("yml") => serde_yaml::from_str(&content)
396 .with_context(|| "Failed to parse YAML configuration")?,
397 Some("json") => serde_json::from_str(&content)
398 .with_context(|| "Failed to parse JSON configuration")?,
399 Some("toml") => {
400 toml::from_str(&content).with_context(|| "Failed to parse TOML configuration")?
401 }
402 _ => {
403 if content.trim_start().starts_with('{') {
405 serde_json::from_str(&content)
406 .with_context(|| "Failed to parse JSON configuration")?
407 } else {
408 serde_yaml::from_str(&content)
409 .with_context(|| "Failed to parse YAML configuration")?
410 }
411 }
412 };
413
414 Ok(config)
415}
416
417pub async fn save_config(config: &Config, path: &Path) -> Result<()> {
419 let content = match path.extension().and_then(|ext| ext.to_str()) {
420 Some("json") => serde_json::to_string_pretty(config)
421 .with_context(|| "Failed to serialize configuration to JSON")?,
422 Some("toml") => toml::to_string_pretty(config)
423 .with_context(|| "Failed to serialize configuration to TOML")?,
424 _ => {
425 serde_yaml::to_string(config)
427 .with_context(|| "Failed to serialize configuration to YAML")?
428 }
429 };
430
431 tokio::fs::write(path, content)
432 .await
433 .with_context(|| format!("Failed to write config file: {}", path.display()))?;
434
435 info!("Configuration saved to: {}", path.display());
436 Ok(())
437}
438
439fn get_default_config_path() -> Result<PathBuf> {
441 let config_dir = dirs::config_dir()
442 .or_else(|| dirs::home_dir().map(|h| h.join(".config")))
443 .unwrap_or_else(|| PathBuf::from("."));
444
445 Ok(config_dir.join("torsh").join("config.yaml"))
446}
447
448pub async fn init_config_dirs(config: &Config) -> Result<()> {
450 let dirs = [
451 &config.general.cache_dir,
452 &config.training.config_dir,
453 &config.training.checkpoint_dir,
454 &config.training.log_dir,
455 &config.hub.cache_dir,
456 &config.benchmark.output_dir,
457 ];
458
459 for dir in dirs {
460 if !dir.exists() {
461 tokio::fs::create_dir_all(dir)
462 .await
463 .with_context(|| format!("Failed to create directory: {}", dir.display()))?;
464 debug!("Created directory: {}", dir.display());
465 }
466 }
467
468 Ok(())
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Saves `Config::default()` as `file_name` in a fresh temporary directory, reads it back
    /// with `load_config_from_file`, and returns `(original, loaded)`.
    async fn save_and_reload(file_name: &str) -> (Config, Config) {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join(file_name);
        let config = Config::default();

        save_config(&config, &config_path).await.unwrap();
        assert!(config_path.exists());

        let loaded = load_config_from_file(&config_path).await.unwrap();
        (config, loaded)
    }

    /// Asserts that a representative field from each section matches between two configs.
    /// Float fields are deliberately skipped to avoid exact-float comparisons after
    /// text round trips.
    fn assert_same_settings(expected: &Config, actual: &Config) {
        assert_eq!(expected.general.default_device, actual.general.default_device);
        assert_eq!(expected.general.cache_dir, actual.general.cache_dir);
        assert_eq!(
            expected.model.optimization.fusion.patterns,
            actual.model.optimization.fusion.patterns
        );
        assert_eq!(
            expected.training.distributed.master_port,
            actual.training.distributed.master_port
        );
        assert_eq!(expected.hub.api_endpoint, actual.hub.api_endpoint);
        assert_eq!(expected.benchmark.batch_sizes, actual.benchmark.batch_sizes);
        assert_eq!(expected.dev.testing.test_data_dir, actual.dev.testing.test_data_dir);
    }

    #[tokio::test]
    async fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.general.default_device, "cpu");
        assert_eq!(config.model.default_format, "torsh");
    }

    #[tokio::test]
    async fn test_config_serialization() {
        let config = Config::default();

        let yaml = serde_yaml::to_string(&config).unwrap();
        let parsed: Config = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(config.general.default_device, parsed.general.default_device);

        let json = serde_json::to_string_pretty(&config).unwrap();
        let parsed: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(config.general.default_device, parsed.general.default_device);
    }

    #[tokio::test]
    async fn test_config_file_operations() {
        let (config, loaded) = save_and_reload("test_config.yaml").await;
        assert_same_settings(&config, &loaded);
    }

    #[tokio::test]
    async fn test_json_config_file_round_trip() {
        let (config, loaded) = save_and_reload("test_config.json").await;
        assert_same_settings(&config, &loaded);
    }

    #[tokio::test]
    async fn test_toml_config_file_round_trip() {
        let (config, loaded) = save_and_reload("test_config.toml").await;
        assert_same_settings(&config, &loaded);
    }

    #[tokio::test]
    async fn test_extensionless_json_is_detected() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("config");
        let config = Config::default();

        let json = serde_json::to_string_pretty(&config).unwrap();
        tokio::fs::write(&config_path, json).await.unwrap();

        let loaded = load_config_from_file(&config_path).await.unwrap();
        assert_same_settings(&config, &loaded);
    }

    #[tokio::test]
    async fn test_load_config_writes_defaults_when_missing() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("nested").join("config.yaml");

        // First call: no file yet, so defaults are returned and persisted (creating `nested/`).
        let created = load_config(Some(&config_path)).await.unwrap();
        assert!(config_path.exists());

        // Second call: the persisted file is read back.
        let reloaded = load_config(Some(&config_path)).await.unwrap();
        assert_same_settings(&created, &reloaded);
    }
}