// torsh_cli/utils.rs

1//! Utility functions for ToRSh CLI
2
3// Framework infrastructure - components designed for future use
4#![allow(dead_code)]
5use anyhow::{Context, Result};
6use byte_unit::Byte;
7use chrono::Local;
8use colored::*;
9// Console utilities available when needed
10use indicatif::{ProgressBar, ProgressStyle};
11use serde_json::Value;
12use std::collections::HashMap;
13use std::fmt::Write as FmtWrite;
14use std::path::{Path, PathBuf};
15use std::time::{Duration, Instant};
16use sysinfo::System;
17use tracing::{debug, info};
18
19/// Display the ToRSh banner
20pub fn display_banner() {
21    let banner = r#"
22  ______         _____   _____ _
23 |__   _|       |  __ \ / ____| |
24    | | ___  _ _| |__) | (___ | |__
25    | |/ _ \| '__|  _  / \___ \| '_ \
26   _| | (_) | |  | | \ \ ____) | | | |
27  |_| \___/|_|  |_|  \_\_____/|_| |_|
28
29"#;
30
31    println!("{}", banner.bright_cyan().bold());
32    println!(
33        "{}",
34        "ToRSh CLI - Advanced Deep Learning Framework Tools"
35            .bright_white()
36            .bold()
37    );
38    println!(
39        "{}",
40        format!("Version: {} | Build: {}", env!("CARGO_PKG_VERSION"), "dev").bright_black()
41    );
42    println!();
43}
44
45/// Output formatting utilities
46pub mod output {
47    use super::*;
48    use serde::Serialize;
49
50    /// Format output based on the specified format
51    pub fn format_output<T: Serialize>(data: &T, format: &str) -> Result<String> {
52        match format {
53            "json" => {
54                serde_json::to_string_pretty(data).with_context(|| "Failed to serialize to JSON")
55            }
56            "yaml" => serde_yaml::to_string(data).with_context(|| "Failed to serialize to YAML"),
57            "table" => {
58                // For table format, we'll need to implement custom formatting
59                // This is a simplified version
60                format_as_table(data)
61            }
62            _ => {
63                anyhow::bail!("Unsupported output format: {}", format)
64            }
65        }
66    }
67
68    /// Format data as a table (simplified implementation)
69    fn format_as_table<T: Serialize>(data: &T) -> Result<String> {
70        let json_value = serde_json::to_value(data)?;
71        format_json_as_table(&json_value, 0)
72    }
73
74    fn format_json_as_table(value: &Value, indent: usize) -> Result<String> {
75        let mut output = String::new();
76        let indent_str = "  ".repeat(indent);
77
78        match value {
79            Value::Object(map) => {
80                for (key, val) in map {
81                    match val {
82                        Value::Object(_) | Value::Array(_) => {
83                            writeln!(output, "{}{}:", indent_str, key.bright_cyan())?;
84                            output.push_str(&format_json_as_table(val, indent + 1)?);
85                        }
86                        _ => {
87                            writeln!(
88                                output,
89                                "{}{}: {}",
90                                indent_str,
91                                key.bright_cyan(),
92                                format_json_value(val)
93                            )?;
94                        }
95                    }
96                }
97            }
98            Value::Array(arr) => {
99                for (i, val) in arr.iter().enumerate() {
100                    writeln!(output, "{}[{}]:", indent_str, i.to_string().bright_yellow())?;
101                    output.push_str(&format_json_as_table(val, indent + 1)?);
102                }
103            }
104            _ => {
105                writeln!(output, "{}{}", indent_str, format_json_value(value))?;
106            }
107        }
108
109        Ok(output)
110    }
111
112    fn format_json_value(value: &Value) -> String {
113        match value {
114            Value::String(s) => s.green().to_string(),
115            Value::Number(n) => n.to_string().yellow().to_string(),
116            Value::Bool(b) => {
117                if *b {
118                    "true".bright_green().to_string()
119                } else {
120                    "false".bright_red().to_string()
121                }
122            }
123            Value::Null => "null".bright_black().to_string(),
124            _ => value.to_string(),
125        }
126    }
127
128    /// Print a formatted table
129    pub fn print_table<T: Serialize>(title: &str, data: &T, format: &str) -> Result<()> {
130        println!("{}", title.bright_cyan().bold());
131        println!("{}", "=".repeat(title.len()).bright_cyan());
132        println!();
133
134        let formatted = format_output(data, format)?;
135        println!("{}", formatted);
136
137        Ok(())
138    }
139
140    /// Print a success message
141    pub fn print_success(message: &str) {
142        println!("{} {}", "✓".bright_green().bold(), message);
143    }
144
145    /// Print an error message
146    pub fn print_error(message: &str) {
147        eprintln!("{} {}", "✗".bright_red().bold(), message);
148    }
149
150    /// Print a warning message
151    pub fn print_warning(message: &str) {
152        println!("{} {}", "⚠".bright_yellow().bold(), message);
153    }
154
155    /// Print an info message
156    pub fn print_info(message: &str) {
157        println!("{} {}", "ℹ".bright_blue().bold(), message);
158    }
159}
160
161/// Progress bar utilities
162pub mod progress {
163    use super::*;
164
165    /// Create a progress bar with custom style
166    pub fn create_progress_bar(len: u64, message: &str) -> ProgressBar {
167        let pb = ProgressBar::new(len);
168        pb.set_style(
169            ProgressStyle::default_bar()
170                .template("{msg} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos:>7}/{len:7} {eta}")
171                .expect("Invalid progress bar template")
172                .progress_chars("█▉▊▋▌▍▎▏  "),
173        );
174        pb.set_message(message.to_string());
175        pb
176    }
177
178    /// Create a spinner for indeterminate progress
179    pub fn create_spinner(message: &str) -> ProgressBar {
180        let pb = ProgressBar::new_spinner();
181        pb.set_style(
182            ProgressStyle::default_spinner()
183                .template("{spinner:.cyan} {msg}")
184                .expect("Invalid spinner template")
185                .tick_chars("⠁⠂⠄⡀⢀⠠⠐⠈ "),
186        );
187        pb.set_message(message.to_string());
188        pb
189    }
190}
191
192/// File system utilities
193pub mod fs {
194    use super::*;
195
196    /// Get file size as human-readable string
197    pub fn format_file_size(size: u64) -> String {
198        Byte::from_u128(size as u128)
199            .unwrap_or_else(|| Byte::from_u128(0).expect("zero bytes should always be valid"))
200            .get_appropriate_unit(byte_unit::UnitType::Binary)
201            .to_string()
202    }
203
204    /// Get directory size recursively
205    pub fn get_directory_size(
206        path: &Path,
207    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<u64>> + Send + '_>> {
208        Box::pin(async move {
209            let mut total_size = 0u64;
210            let mut read_dir = tokio::fs::read_dir(path).await?;
211
212            while let Some(entry) = read_dir.next_entry().await? {
213                let metadata = entry.metadata().await?;
214                if metadata.is_file() {
215                    total_size += metadata.len();
216                } else if metadata.is_dir() {
217                    total_size += get_directory_size(&entry.path()).await?;
218                }
219            }
220
221            Ok(total_size)
222        })
223    }
224
225    /// Find files matching a pattern
226    pub fn find_files(directory: &Path, pattern: &str) -> Result<Vec<PathBuf>> {
227        let mut files = Vec::new();
228        let walker = walkdir::WalkDir::new(directory);
229
230        for entry in walker {
231            let entry = entry?;
232            if entry.file_type().is_file() {
233                let path = entry.path();
234                if glob::Pattern::new(pattern)?.matches_path(path) {
235                    files.push(path.to_path_buf());
236                }
237            }
238        }
239
240        Ok(files)
241    }
242
243    /// Create a backup of a file
244    pub async fn backup_file(file_path: &Path) -> Result<PathBuf> {
245        let timestamp = chrono::Local::now().format("%Y%m%d_%H%M%S");
246        let backup_path = file_path.with_extension(format!(
247            "{}.backup_{}",
248            file_path.extension().unwrap_or_default().to_string_lossy(),
249            timestamp
250        ));
251
252        tokio::fs::copy(file_path, &backup_path).await?;
253        info!("Created backup: {}", backup_path.display());
254
255        Ok(backup_path)
256    }
257
258    /// Clean up temporary files
259    pub async fn cleanup_temp_files(temp_dir: &Path) -> Result<()> {
260        if temp_dir.exists() {
261            tokio::fs::remove_dir_all(temp_dir).await?;
262            debug!("Cleaned up temporary directory: {}", temp_dir.display());
263        }
264        Ok(())
265    }
266}
267
268/// System information utilities
269pub mod system {
270    use super::*;
271
272    /// System information structure
273    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
274    pub struct SystemInfo {
275        pub os: String,
276        pub kernel_version: String,
277        pub total_memory: String,
278        pub available_memory: String,
279        pub cpu_count: usize,
280        pub cpu_brand: String,
281        pub cpu_frequency: u64,
282        pub load_average: Vec<f64>,
283        pub uptime: String,
284    }
285
286    /// Get comprehensive system information
287    pub fn get_system_info() -> SystemInfo {
288        let mut sys = System::new_all();
289        sys.refresh_all();
290
291        SystemInfo {
292            os: format!(
293                "{} {}",
294                System::name().unwrap_or_default(),
295                System::os_version().unwrap_or_default()
296            ),
297            kernel_version: System::kernel_version().unwrap_or_default(),
298            total_memory: format_memory(sys.total_memory()),
299            available_memory: format_memory(sys.available_memory()),
300            cpu_count: sys.cpus().len(),
301            cpu_brand: sys
302                .cpus()
303                .first()
304                .map(|cpu| cpu.brand())
305                .unwrap_or("Unknown")
306                .to_string(),
307            cpu_frequency: sys.cpus().first().map(|cpu| cpu.frequency()).unwrap_or(0),
308            load_average: {
309                let load = System::load_average();
310                vec![load.one, load.five, load.fifteen]
311            },
312            uptime: format_duration(Duration::from_secs(System::uptime())),
313        }
314    }
315
316    /// Format memory size
317    fn format_memory(memory_kb: u64) -> String {
318        let memory_bytes = memory_kb * 1024;
319        Byte::from_u128(memory_bytes as u128)
320            .unwrap_or_else(|| Byte::from_u128(0).expect("zero bytes should always be valid"))
321            .get_appropriate_unit(byte_unit::UnitType::Binary)
322            .to_string()
323    }
324
325    /// Check if GPU is available with real hardware detection
326    pub fn check_gpu_availability() -> HashMap<String, bool> {
327        let mut gpu_info = HashMap::new();
328
329        // Check for CUDA with actual detection
330        // #[cfg(feature = "cuda")]
331        // {
332        //     gpu_info.insert("CUDA".to_string(), detect_cuda_availability());
333        // }
334        // #[cfg(not(feature = "cuda"))]
335        {
336            // Still check for CUDA runtime even if not compiled with CUDA support
337            gpu_info.insert("CUDA".to_string(), detect_cuda_runtime());
338        }
339
340        // Check for ROCm with actual detection
341        // #[cfg(feature = "rocm")]
342        // {
343        //     gpu_info.insert("ROCm".to_string(), detect_rocm_availability());
344        // }
345        // #[cfg(not(feature = "rocm"))]
346        {
347            gpu_info.insert("ROCm".to_string(), detect_rocm_runtime());
348        }
349
350        // Check for Metal (macOS) with actual detection
351        #[cfg(target_os = "macos")]
352        {
353            gpu_info.insert("Metal".to_string(), detect_metal_availability());
354        }
355
356        // Check for Vulkan support
357        gpu_info.insert("Vulkan".to_string(), detect_vulkan_availability());
358
359        // Check for OpenCL
360        gpu_info.insert("OpenCL".to_string(), detect_opencl_availability());
361
362        gpu_info
363    }
364
365    /// Detect CUDA availability at runtime
366    // #[cfg(feature = "cuda")]
367    #[allow(dead_code)]
368    fn detect_cuda_availability() -> bool {
369        // This would use CUDA runtime API calls
370        // For now, check if CUDA libraries are present
371        detect_cuda_runtime()
372    }
373
374    fn detect_cuda_runtime() -> bool {
375        // Check for CUDA runtime by looking for nvidia-smi command
376        std::process::Command::new("nvidia-smi")
377            .arg("--query-gpu=name")
378            .arg("--format=csv,noheader")
379            .output()
380            .map(|output| output.status.success())
381            .unwrap_or(false)
382    }
383
384    /// Detect ROCm availability
385    // #[cfg(feature = "rocm")]
386    #[allow(dead_code)]
387    fn detect_rocm_availability() -> bool {
388        detect_rocm_runtime()
389    }
390
391    fn detect_rocm_runtime() -> bool {
392        // Check for ROCm by looking for rocm-smi command
393        std::process::Command::new("rocm-smi")
394            .arg("--showproductname")
395            .output()
396            .map(|output| output.status.success())
397            .unwrap_or(false)
398    }
399
400    /// Detect Metal availability (macOS only)
401    #[cfg(target_os = "macos")]
402    fn detect_metal_availability() -> bool {
403        // Check if Metal is available by running system_profiler
404        std::process::Command::new("system_profiler")
405            .arg("SPDisplaysDataType")
406            .output()
407            .map(|output| {
408                output.status.success() && String::from_utf8_lossy(&output.stdout).contains("Metal")
409            })
410            .unwrap_or(true) // Assume available on macOS if detection fails
411    }
412
413    fn detect_vulkan_availability() -> bool {
414        // Check for Vulkan by looking for vulkaninfo command
415        std::process::Command::new("vulkaninfo")
416            .arg("--summary")
417            .output()
418            .map(|output| output.status.success())
419            .unwrap_or(false)
420    }
421
422    fn detect_opencl_availability() -> bool {
423        // Check for OpenCL by looking for clinfo command
424        std::process::Command::new("clinfo")
425            .output()
426            .map(|output| output.status.success())
427            .unwrap_or(false)
428    }
429
430    /// Get comprehensive device information
431    pub fn get_device_info() -> HashMap<String, serde_json::Value> {
432        let mut device_info = HashMap::new();
433
434        // Get system info for CPU details
435        let sys_info = get_system_info();
436
437        // CPU Information
438        device_info.insert(
439            "cpu".to_string(),
440            serde_json::json!({
441                "available": true,
442                "device_type": "cpu",
443                "description": "CPU device",
444                "brand": sys_info.cpu_brand,
445                "cores": sys_info.cpu_count,
446                "frequency_mhz": sys_info.cpu_frequency,
447                "capabilities": get_cpu_capabilities(),
448            }),
449        );
450
451        // GPU Information with detailed detection
452        let gpu_availability = check_gpu_availability();
453        for (gpu_type, available) in gpu_availability {
454            let detailed_info = if available {
455                match gpu_type.as_str() {
456                    "CUDA" => get_cuda_device_details(),
457                    "ROCm" => get_rocm_device_details(),
458                    "Metal" => get_metal_device_details(),
459                    "Vulkan" => get_vulkan_device_details(),
460                    "OpenCL" => get_opencl_device_details(),
461                    _ => serde_json::json!({}),
462                }
463            } else {
464                serde_json::json!({
465                    "reason": "Runtime or drivers not detected"
466                })
467            };
468
469            device_info.insert(
470                gpu_type.to_lowercase(),
471                serde_json::json!({
472                    "available": available,
473                    "device_type": "gpu",
474                    "description": format!("{} GPU device", gpu_type),
475                    "details": detailed_info
476                }),
477            );
478        }
479
480        device_info
481    }
482
483    /// Get CPU capabilities (SIMD instructions, etc.)
484    fn get_cpu_capabilities() -> Vec<String> {
485        let mut capabilities = Vec::new();
486
487        // Check for common SIMD instruction sets
488        #[cfg(target_arch = "x86_64")]
489        {
490            if is_x86_feature_detected!("sse") {
491                capabilities.push("SSE".to_string());
492            }
493            if is_x86_feature_detected!("sse2") {
494                capabilities.push("SSE2".to_string());
495            }
496            if is_x86_feature_detected!("sse3") {
497                capabilities.push("SSE3".to_string());
498            }
499            if is_x86_feature_detected!("sse4.1") {
500                capabilities.push("SSE4.1".to_string());
501            }
502            if is_x86_feature_detected!("sse4.2") {
503                capabilities.push("SSE4.2".to_string());
504            }
505            if is_x86_feature_detected!("avx") {
506                capabilities.push("AVX".to_string());
507            }
508            if is_x86_feature_detected!("avx2") {
509                capabilities.push("AVX2".to_string());
510            }
511            if is_x86_feature_detected!("fma") {
512                capabilities.push("FMA".to_string());
513            }
514        }
515
516        #[cfg(target_arch = "aarch64")]
517        {
518            if std::arch::is_aarch64_feature_detected!("neon") {
519                capabilities.push("NEON".to_string());
520            }
521        }
522
523        capabilities
524    }
525
526    /// Get detailed CUDA device information
527    fn get_cuda_device_details() -> serde_json::Value {
528        // Use nvidia-smi to get device details
529        if let Ok(output) = std::process::Command::new("nvidia-smi")
530            .arg("--query-gpu=name,memory.total,driver_version,cuda_version")
531            .arg("--format=csv,noheader,nounits")
532            .output()
533        {
534            if output.status.success() {
535                let info = String::from_utf8_lossy(&output.stdout);
536                let lines: Vec<&str> = info.trim().split('\n').collect();
537
538                return serde_json::json!({
539                    "devices": lines.iter().enumerate().map(|(i, line)| {
540                        let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
541                        if parts.len() >= 4 {
542                            serde_json::json!({
543                                "id": i,
544                                "name": parts[0],
545                                "memory_mb": parts[1],
546                                "driver_version": parts[2],
547                                "cuda_version": parts[3]
548                            })
549                        } else {
550                            serde_json::json!({
551                                "id": i,
552                                "name": "Unknown GPU",
553                                "error": "Failed to parse GPU info"
554                            })
555                        }
556                    }).collect::<Vec<_>>()
557                });
558            }
559        }
560
561        serde_json::json!({ "error": "Failed to query CUDA devices" })
562    }
563
564    /// Get detailed ROCm device information
565    fn get_rocm_device_details() -> serde_json::Value {
566        if let Ok(output) = std::process::Command::new("rocm-smi")
567            .arg("--showproductname")
568            .arg("--showmeminfo=vram")
569            .output()
570        {
571            if output.status.success() {
572                return serde_json::json!({
573                    "detected": true,
574                    "raw_output": String::from_utf8_lossy(&output.stdout)
575                });
576            }
577        }
578
579        serde_json::json!({ "error": "Failed to query ROCm devices" })
580    }
581
582    /// Get detailed Metal device information (macOS only)
583    #[cfg(target_os = "macos")]
584    fn get_metal_device_details() -> serde_json::Value {
585        if let Ok(output) = std::process::Command::new("system_profiler")
586            .arg("SPDisplaysDataType")
587            .arg("-detailLevel")
588            .arg("full")
589            .output()
590        {
591            if output.status.success() {
592                let info = String::from_utf8_lossy(&output.stdout);
593                return serde_json::json!({
594                    "detected": true,
595                    "metal_support": info.contains("Metal"),
596                    "summary": "Metal GPU acceleration available"
597                });
598            }
599        }
600
601        serde_json::json!({ "error": "Failed to query Metal devices" })
602    }
603
604    #[cfg(not(target_os = "macos"))]
605    fn get_metal_device_details() -> serde_json::Value {
606        serde_json::json!({ "error": "Metal is only available on macOS" })
607    }
608
609    /// Get Vulkan device information
610    fn get_vulkan_device_details() -> serde_json::Value {
611        if let Ok(output) = std::process::Command::new("vulkaninfo")
612            .arg("--summary")
613            .output()
614        {
615            if output.status.success() {
616                return serde_json::json!({
617                    "detected": true,
618                    "summary": "Vulkan runtime available"
619                });
620            }
621        }
622
623        serde_json::json!({ "error": "Failed to query Vulkan devices" })
624    }
625
626    /// Get OpenCL device information
627    fn get_opencl_device_details() -> serde_json::Value {
628        if let Ok(output) = std::process::Command::new("clinfo").arg("--list").output() {
629            if output.status.success() {
630                let info = String::from_utf8_lossy(&output.stdout);
631                return serde_json::json!({
632                    "detected": true,
633                    "devices_summary": info.lines().take(10).collect::<Vec<_>>()
634                });
635            }
636        }
637
638        serde_json::json!({ "error": "Failed to query OpenCL devices" })
639    }
640}
641
642/// Time and duration utilities
643pub mod time {
644    use super::*;
645
646    /// Format duration as human-readable string
647    pub fn format_duration(duration: Duration) -> String {
648        let secs = duration.as_secs();
649        if secs < 60 {
650            format!("{}s", secs)
651        } else if secs < 3600 {
652            format!("{}m {}s", secs / 60, secs % 60)
653        } else if secs < 86400 {
654            format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
655        } else {
656            format!("{}d {}h", secs / 86400, (secs % 86400) / 3600)
657        }
658    }
659
660    /// Get current timestamp as string
661    pub fn current_timestamp() -> String {
662        Local::now().format("%Y-%m-%d %H:%M:%S").to_string()
663    }
664
665    /// Parse human-readable duration
666    pub fn parse_duration(s: &str) -> Result<Duration> {
667        humantime::parse_duration(s).with_context(|| format!("Failed to parse duration: {}", s))
668    }
669
670    /// Measure execution time
671    pub async fn measure_time<F, T>(f: F) -> (T, Duration)
672    where
673        F: std::future::Future<Output = T>,
674    {
675        let start = Instant::now();
676        let result = f.await;
677        let duration = start.elapsed();
678        (result, duration)
679    }
680}
681
682/// Network utilities
683pub mod network {
684    use super::*;
685
686    /// Download a file with progress
687    pub async fn download_file_with_progress(
688        url: &str,
689        output_path: &Path,
690        show_progress: bool,
691    ) -> Result<()> {
692        let client = reqwest::Client::new();
693        let response = client.get(url).send().await?;
694
695        let total_size = response.content_length().unwrap_or(0);
696
697        let pb = if show_progress && total_size > 0 {
698            Some(progress::create_progress_bar(
699                total_size,
700                &format!(
701                    "Downloading {}",
702                    output_path
703                        .file_name()
704                        .unwrap_or_default()
705                        .to_string_lossy()
706                ),
707            ))
708        } else {
709            None
710        };
711
712        let mut file = tokio::fs::File::create(output_path).await?;
713        let mut downloaded = 0u64;
714        let mut stream = response.bytes_stream();
715
716        use futures_util::StreamExt;
717        use tokio::io::AsyncWriteExt;
718
719        while let Some(chunk) = stream.next().await {
720            let chunk = chunk?;
721            file.write_all(&chunk).await?;
722            downloaded += chunk.len() as u64;
723
724            if let Some(pb) = &pb {
725                pb.set_position(downloaded);
726            }
727        }
728
729        if let Some(pb) = pb {
730            pb.finish_with_message("Download completed");
731        }
732
733        Ok(())
734    }
735
736    /// Check if URL is accessible
737    pub async fn check_url_accessible(url: &str) -> bool {
738        let client = reqwest::Client::new();
739        client.head(url).send().await.is_ok()
740    }
741}
742
743/// Validation utilities
744pub mod validation {
745    use super::*;
746
747    /// Validate file exists and is readable
748    pub fn validate_file_exists(path: &Path) -> Result<()> {
749        if !path.exists() {
750            anyhow::bail!("File does not exist: {}", path.display());
751        }
752        if !path.is_file() {
753            anyhow::bail!("Path is not a file: {}", path.display());
754        }
755        Ok(())
756    }
757
758    /// Validate directory exists and is accessible
759    pub fn validate_directory_exists(path: &Path) -> Result<()> {
760        if !path.exists() {
761            anyhow::bail!("Directory does not exist: {}", path.display());
762        }
763        if !path.is_dir() {
764            anyhow::bail!("Path is not a directory: {}", path.display());
765        }
766        Ok(())
767    }
768
769    /// Validate model format
770    pub fn validate_model_format(format: &str) -> Result<()> {
771        let supported_formats = ["torsh", "pytorch", "onnx", "tensorflow", "tflite"];
772        if !supported_formats.contains(&format) {
773            anyhow::bail!(
774                "Unsupported model format: {}. Supported formats: {}",
775                format,
776                supported_formats.join(", ")
777            );
778        }
779        Ok(())
780    }
781
782    /// Validate device string
783    pub fn validate_device(device: &str) -> Result<()> {
784        if device == "cpu" {
785            return Ok(());
786        }
787
788        if device.starts_with("cuda") {
789            let parts: Vec<&str> = device.split(':').collect();
790            if parts.len() == 2 {
791                if parts[1].parse::<usize>().is_err() {
792                    anyhow::bail!("Invalid CUDA device ID: {}", parts[1]);
793                }
794                return Ok(());
795            } else if parts.len() == 1 && parts[0] == "cuda" {
796                return Ok(());
797            }
798        }
799
800        if device == "metal" {
801            return Ok(());
802        }
803
804        anyhow::bail!(
805            "Invalid device format: {}. Use 'cpu', 'cuda', 'cuda:N', or 'metal'",
806            device
807        );
808    }
809}
810
811/// Interactive utilities
812pub mod interactive {
813    use super::*;
814    use dialoguer::{Confirm, Input, Select};
815
816    /// Ask user for confirmation
817    pub fn confirm(message: &str, default: bool) -> Result<bool> {
818        Confirm::new()
819            .with_prompt(message)
820            .default(default)
821            .interact()
822            .with_context(|| "Failed to get user confirmation")
823    }
824
825    /// Get text input from user
826    pub fn input<T>(message: &str, default: Option<T>) -> Result<T>
827    where
828        T: Clone + std::fmt::Display + std::str::FromStr,
829        T::Err: std::fmt::Display + std::fmt::Debug + Send + Sync + 'static,
830    {
831        let mut input = Input::new().with_prompt(message);
832
833        if let Some(default_value) = default {
834            input = input.default(default_value);
835        }
836
837        input
838            .interact_text()
839            .with_context(|| "Failed to get user input")
840    }
841
842    /// Select from a list of options
843    pub fn select(message: &str, options: &[String]) -> Result<usize> {
844        Select::new()
845            .with_prompt(message)
846            .items(options)
847            .interact()
848            .with_context(|| "Failed to get user selection")
849    }
850}
851
852/// Export format_duration function at module level
853pub use time::format_duration;
854
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_format_duration() {
        assert_eq!(time::format_duration(Duration::from_secs(0)), "0s");
        assert_eq!(time::format_duration(Duration::from_millis(1500)), "1s");
        assert_eq!(time::format_duration(Duration::from_secs(30)), "30s");
        assert_eq!(time::format_duration(Duration::from_secs(90)), "1m 30s");
        assert_eq!(time::format_duration(Duration::from_secs(3661)), "1h 1m");
        // 1 day + 1 hour + 61 seconds exercises the day branch.
        assert_eq!(time::format_duration(Duration::from_secs(90_061)), "1d 1h");
    }

    #[test]
    fn test_validation() {
        assert!(validation::validate_model_format("torsh").is_ok());
        assert!(validation::validate_model_format("onnx").is_ok());
        assert!(validation::validate_model_format("invalid").is_err());

        assert!(validation::validate_device("cpu").is_ok());
        assert!(validation::validate_device("cuda").is_ok());
        assert!(validation::validate_device("cuda:0").is_ok());
        assert!(validation::validate_device("metal").is_ok());
        assert!(validation::validate_device("cuda:").is_err());
        assert!(validation::validate_device("cuda:x").is_err());
        assert!(validation::validate_device("cuda:0:1").is_err());
        assert!(validation::validate_device("invalid").is_err());
    }

    #[test]
    fn test_path_validation() {
        let temp_dir = tempdir().expect("failed to create temp dir");
        let file = temp_dir.path().join("model.bin");
        std::fs::write(&file, b"x").expect("failed to write test file");

        assert!(validation::validate_file_exists(&file).is_ok());
        assert!(validation::validate_file_exists(temp_dir.path()).is_err());
        assert!(validation::validate_file_exists(&temp_dir.path().join("missing")).is_err());

        assert!(validation::validate_directory_exists(temp_dir.path()).is_ok());
        assert!(validation::validate_directory_exists(&file).is_err());
    }

    #[tokio::test]
    async fn test_file_operations() {
        let temp_dir = tempdir().expect("failed to create temp dir");
        let test_file = temp_dir.path().join("test.txt");

        tokio::fs::write(&test_file, "test content").await.unwrap();

        let size = fs::get_directory_size(temp_dir.path()).await.unwrap();
        assert!(size > 0);

        let backup = fs::backup_file(&test_file).await.unwrap();
        assert!(backup.exists());
        let backup_name = backup.file_name().unwrap().to_string_lossy().into_owned();
        assert!(backup_name.starts_with("test.txt.backup_"));
        assert_eq!(
            tokio::fs::read_to_string(&backup).await.unwrap(),
            "test content"
        );

        let scratch = temp_dir.path().join("scratch");
        tokio::fs::create_dir(&scratch).await.unwrap();
        fs::cleanup_temp_files(&scratch).await.unwrap();
        assert!(!scratch.exists());
        // Cleaning up a directory that is already gone is not an error.
        fs::cleanup_temp_files(&scratch).await.unwrap();
    }

    #[test]
    fn test_output_formatting() {
        use serde_json::json;

        let data = json!({
            "name": "test",
            "value": 42,
            "active": true
        });

        let json_output = output::format_output(&data, "json").unwrap();
        assert!(json_output.contains("test"));

        let yaml_output = output::format_output(&data, "yaml").unwrap();
        assert!(yaml_output.contains("name: test"));

        let table_output = output::format_output(&data, "table").unwrap();
        assert!(table_output.contains("name"));
        assert!(table_output.contains("test"));

        assert!(output::format_output(&data, "xml").is_err());
    }
}