1#![allow(dead_code)]
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::path::{Path, PathBuf};
8use tracing::{debug, info};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Config {
13 pub general: GeneralConfig,
15
16 pub model: ModelConfig,
18
19 pub training: TrainingConfig,
21
22 pub hub: HubConfig,
24
25 pub benchmark: BenchmarkConfig,
27
28 pub dev: DevConfig,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct GeneralConfig {
34 pub output_dir: PathBuf,
36
37 pub cache_dir: PathBuf,
39
40 pub default_device: String,
42
43 pub num_workers: usize,
45
46 pub memory_limit_gb: Option<f64>,
48
49 pub show_progress: bool,
51
52 pub default_dtype: String,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ModelConfig {
58 pub default_format: String,
60
61 pub optimization: OptimizationConfig,
63
64 pub validation: ValidationConfig,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct OptimizationConfig {
70 pub auto_optimize: bool,
72
73 pub quantization: QuantizationConfig,
75
76 pub pruning: PruningConfig,
78
79 pub fusion: FusionConfig,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct QuantizationConfig {
85 pub enabled: bool,
87
88 pub method: String,
90
91 pub precision: String,
93
94 pub calibration_samples: usize,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct PruningConfig {
100 pub enabled: bool,
102
103 pub sparsity: f64,
105
106 pub method: String,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct FusionConfig {
112 pub enabled: bool,
114
115 pub patterns: Vec<String>,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct ValidationConfig {
121 pub enabled: bool,
123
124 pub dataset_path: Option<PathBuf>,
126
127 pub num_samples: usize,
129
130 pub accuracy_threshold: f64,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct TrainingConfig {
136 pub config_dir: PathBuf,
138
139 pub checkpoint_dir: PathBuf,
141
142 pub log_dir: PathBuf,
144
145 pub auto_resume: bool,
147
148 pub checkpoint_frequency: usize,
150
151 pub early_stopping_patience: usize,
153
154 pub mixed_precision: bool,
156
157 pub distributed: DistributedConfig,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct DistributedConfig {
163 pub backend: String,
165
166 pub master_addr: String,
168
169 pub master_port: u16,
171
172 pub auto_detect: bool,
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct HubConfig {
178 pub api_endpoint: String,
180
181 pub auth_token: Option<String>,
183
184 pub organization: Option<String>,
186
187 pub cache_dir: PathBuf,
189
190 pub verify_signatures: bool,
192
193 pub timeout_seconds: u64,
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct BenchmarkConfig {
199 pub warmup_iterations: usize,
201
202 pub benchmark_iterations: usize,
204
205 pub batch_sizes: Vec<usize>,
207
208 pub track_memory: bool,
210
211 pub track_power: bool,
213
214 pub output_dir: PathBuf,
216
217 pub detailed_profiling: bool,
219}
220
221#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct DevConfig {
223 pub enabled: bool,
225
226 pub debug_logging: bool,
228
229 pub experimental_features: bool,
231
232 pub codegen: CodegenConfig,
234
235 pub testing: TestingConfig,
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct CodegenConfig {
241 pub enabled: bool,
243
244 pub output_dir: PathBuf,
246
247 pub templates_dir: PathBuf,
249}
250
251#[derive(Debug, Clone, Serialize, Deserialize)]
252pub struct TestingConfig {
253 pub enabled: bool,
255
256 pub test_data_dir: PathBuf,
258
259 pub numerical_tolerance: f64,
261}
262
263impl Default for Config {
264 fn default() -> Self {
265 let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
266 let torsh_dir = home_dir.join(".torsh");
267
268 Self {
269 general: GeneralConfig {
270 output_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
271 cache_dir: torsh_dir.join("cache"),
272 default_device: "cpu".to_string(),
273 num_workers: std::thread::available_parallelism()
274 .map(|n| n.get())
275 .unwrap_or(4),
276 memory_limit_gb: None,
277 show_progress: true,
278 default_dtype: "f32".to_string(),
279 },
280 model: ModelConfig {
281 default_format: "torsh".to_string(),
282 optimization: OptimizationConfig {
283 auto_optimize: false,
284 quantization: QuantizationConfig {
285 enabled: false,
286 method: "dynamic".to_string(),
287 precision: "int8".to_string(),
288 calibration_samples: 1000,
289 },
290 pruning: PruningConfig {
291 enabled: false,
292 sparsity: 0.5,
293 method: "magnitude".to_string(),
294 },
295 fusion: FusionConfig {
296 enabled: true,
297 patterns: vec!["conv_bn_relu".to_string(), "linear_relu".to_string()],
298 },
299 },
300 validation: ValidationConfig {
301 enabled: true,
302 dataset_path: None,
303 num_samples: 1000,
304 accuracy_threshold: 0.95,
305 },
306 },
307 training: TrainingConfig {
308 config_dir: torsh_dir.join("configs"),
309 checkpoint_dir: PathBuf::from("./checkpoints"),
310 log_dir: PathBuf::from("./logs"),
311 auto_resume: false,
312 checkpoint_frequency: 1,
313 early_stopping_patience: 10,
314 mixed_precision: true,
315 distributed: DistributedConfig {
316 backend: "nccl".to_string(),
317 master_addr: "localhost".to_string(),
318 master_port: 29500,
319 auto_detect: true,
320 },
321 },
322 hub: HubConfig {
323 api_endpoint: "https://hub.torsh.dev".to_string(),
324 auth_token: None,
325 organization: None,
326 cache_dir: torsh_dir.join("hub"),
327 verify_signatures: true,
328 timeout_seconds: 300,
329 },
330 benchmark: BenchmarkConfig {
331 warmup_iterations: 10,
332 benchmark_iterations: 100,
333 batch_sizes: vec![1, 4, 8, 16, 32, 64],
334 track_memory: true,
335 track_power: false,
336 output_dir: PathBuf::from("./benchmarks"),
337 detailed_profiling: false,
338 },
339 dev: DevConfig {
340 enabled: false,
341 debug_logging: false,
342 experimental_features: false,
343 codegen: CodegenConfig {
344 enabled: false,
345 output_dir: PathBuf::from("./generated"),
346 templates_dir: torsh_dir.join("templates"),
347 },
348 testing: TestingConfig {
349 enabled: true,
350 test_data_dir: PathBuf::from("./test_data"),
351 numerical_tolerance: 1e-6,
352 },
353 },
354 }
355 }
356}
357
358pub async fn load_config(config_path: Option<&Path>) -> Result<Config> {
360 let config_path = if let Some(path) = config_path {
361 path.to_path_buf()
362 } else {
363 get_default_config_path()?
364 };
365
366 if config_path.exists() {
367 debug!("Loading configuration from: {}", config_path.display());
368 load_config_from_file(&config_path).await
369 } else {
370 info!("Configuration file not found, using defaults");
371 let config = Config::default();
372
373 if let Some(parent) = config_path.parent() {
375 tokio::fs::create_dir_all(parent).await.with_context(|| {
376 format!("Failed to create config directory: {}", parent.display())
377 })?;
378 }
379
380 save_config(&config, &config_path)
382 .await
383 .with_context(|| "Failed to save default configuration")?;
384
385 Ok(config)
386 }
387}
388
389async fn load_config_from_file(path: &Path) -> Result<Config> {
391 let content = tokio::fs::read_to_string(path)
392 .await
393 .with_context(|| format!("Failed to read config file: {}", path.display()))?;
394
395 let config =
396 match path.extension().and_then(|ext| ext.to_str()) {
397 Some("yaml") | Some("yml") => serde_yaml::from_str(&content)
398 .with_context(|| "Failed to parse YAML configuration")?,
399 Some("json") => serde_json::from_str(&content)
400 .with_context(|| "Failed to parse JSON configuration")?,
401 Some("toml") => {
402 toml::from_str(&content).with_context(|| "Failed to parse TOML configuration")?
403 }
404 _ => {
405 if content.trim_start().starts_with('{') {
407 serde_json::from_str(&content)
408 .with_context(|| "Failed to parse JSON configuration")?
409 } else {
410 serde_yaml::from_str(&content)
411 .with_context(|| "Failed to parse YAML configuration")?
412 }
413 }
414 };
415
416 Ok(config)
417}
418
419pub async fn save_config(config: &Config, path: &Path) -> Result<()> {
421 let content = match path.extension().and_then(|ext| ext.to_str()) {
422 Some("json") => serde_json::to_string_pretty(config)
423 .with_context(|| "Failed to serialize configuration to JSON")?,
424 Some("toml") => toml::to_string_pretty(config)
425 .with_context(|| "Failed to serialize configuration to TOML")?,
426 _ => {
427 serde_yaml::to_string(config)
429 .with_context(|| "Failed to serialize configuration to YAML")?
430 }
431 };
432
433 tokio::fs::write(path, content)
434 .await
435 .with_context(|| format!("Failed to write config file: {}", path.display()))?;
436
437 info!("Configuration saved to: {}", path.display());
438 Ok(())
439}
440
441fn get_default_config_path() -> Result<PathBuf> {
443 let config_dir = dirs::config_dir()
444 .or_else(|| dirs::home_dir().map(|h| h.join(".config")))
445 .unwrap_or_else(|| PathBuf::from("."));
446
447 Ok(config_dir.join("torsh").join("config.yaml"))
448}
449
450pub async fn init_config_dirs(config: &Config) -> Result<()> {
452 let dirs = [
453 &config.general.cache_dir,
454 &config.training.config_dir,
455 &config.training.checkpoint_dir,
456 &config.training.log_dir,
457 &config.hub.cache_dir,
458 &config.benchmark.output_dir,
459 ];
460
461 for dir in dirs {
462 if !dir.exists() {
463 tokio::fs::create_dir_all(dir)
464 .await
465 .with_context(|| format!("Failed to create directory: {}", dir.display()))?;
466 debug!("Created directory: {}", dir.display());
467 }
468 }
469
470 Ok(())
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.general.default_device, "cpu");
        assert_eq!(config.model.default_format, "torsh");
    }

    /// Round-trips the defaults through every supported serializer, including TOML,
    /// which requires plain values to precede nested tables.
    #[test]
    fn test_config_serialization() {
        let config = Config::default();

        let yaml = serde_yaml::to_string(&config).unwrap();
        let parsed: Config = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(config.general.default_device, parsed.general.default_device);

        let json = serde_json::to_string_pretty(&config).unwrap();
        let parsed: Config = serde_json::from_str(&json).unwrap();
        assert_eq!(config.general.default_device, parsed.general.default_device);

        let toml_text = toml::to_string_pretty(&config).unwrap();
        let parsed: Config = toml::from_str(&toml_text).unwrap();
        assert_eq!(config.general.default_device, parsed.general.default_device);
        assert_eq!(config.benchmark.batch_sizes, parsed.benchmark.batch_sizes);
        assert_eq!(
            config.model.optimization.fusion.patterns,
            parsed.model.optimization.fusion.patterns
        );
    }

    #[tokio::test]
    async fn test_config_file_operations() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("test_config.yaml");

        let config = Config::default();

        save_config(&config, &config_path).await.unwrap();
        assert!(config_path.exists());

        let loaded_config = load_config_from_file(&config_path).await.unwrap();
        assert_eq!(
            config.general.default_device,
            loaded_config.general.default_device
        );
    }

    /// Each extension selects a matching writer in `save_config` and reader in
    /// `load_config_from_file`.
    #[tokio::test]
    async fn test_config_file_formats_by_extension() {
        let temp_dir = tempdir().unwrap();
        let config = Config::default();

        for name in ["config.json", "config.toml", "config.yml"] {
            let path = temp_dir.path().join(name);
            save_config(&config, &path).await.unwrap();
            let loaded = load_config_from_file(&path).await.unwrap();
            assert_eq!(
                loaded.general.default_device, config.general.default_device,
                "{name}"
            );
            assert_eq!(loaded.hub.api_endpoint, config.hub.api_endpoint, "{name}");
        }
    }

    /// Without a known extension, a leading `{` selects JSON and anything else YAML.
    #[tokio::test]
    async fn test_load_without_extension_sniffs_format() {
        let temp_dir = tempdir().unwrap();
        let config = Config::default();

        let json_path = temp_dir.path().join("config_json");
        tokio::fs::write(&json_path, serde_json::to_string(&config).unwrap())
            .await
            .unwrap();
        let loaded = load_config_from_file(&json_path).await.unwrap();
        assert_eq!(loaded.general.default_device, "cpu");

        let yaml_path = temp_dir.path().join("config_yaml");
        tokio::fs::write(&yaml_path, serde_yaml::to_string(&config).unwrap())
            .await
            .unwrap();
        let loaded = load_config_from_file(&yaml_path).await.unwrap();
        assert_eq!(loaded.model.default_format, "torsh");
    }

    /// A missing file (in a missing directory) yields defaults and is persisted, so a
    /// second call reads it back.
    #[tokio::test]
    async fn test_load_config_creates_default_file() {
        let temp_dir = tempdir().unwrap();
        let path = temp_dir.path().join("nested").join("config.yaml");

        let config = load_config(Some(&path)).await.unwrap();
        assert!(path.exists());
        assert_eq!(config.general.default_device, "cpu");

        let reloaded = load_config(Some(&path)).await.unwrap();
        assert_eq!(reloaded.model.default_format, config.model.default_format);
    }

    /// Points every managed directory into a temp dir so the test does not create
    /// `./checkpoints` etc. in the working directory; running twice must succeed.
    #[tokio::test]
    async fn test_init_config_dirs_creates_missing_dirs() {
        let temp_dir = tempdir().unwrap();
        let root = temp_dir.path();

        let mut config = Config::default();
        config.general.cache_dir = root.join("cache");
        config.training.config_dir = root.join("configs");
        config.training.checkpoint_dir = root.join("checkpoints");
        config.training.log_dir = root.join("logs");
        config.hub.cache_dir = root.join("hub");
        config.benchmark.output_dir = root.join("benchmarks");

        init_config_dirs(&config).await.unwrap();
        for name in ["cache", "configs", "checkpoints", "logs", "hub", "benchmarks"] {
            assert!(root.join(name).is_dir(), "{name}");
        }

        init_config_dirs(&config).await.unwrap();
    }
}