voirs_cli/commands/model_inspect.rs

//! Model inspection and analysis command implementation.
//!
//! Provides tools to inspect model architecture, verify integrity, and analyze checkpoints.
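//!
//! A minimal call-site sketch (hypothetical values; the real argument parsing and
//! `GlobalOptions` construction live in the CLI entry point):
//!
//! ```ignore
//! use std::path::Path;
//!
//! // Inspect a checkpoint in detail, verify its integrity, and skip the JSON export.
//! run_model_inspect(Path::new("model.safetensors"), true, None, true, &global).await?;
//! ```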

use crate::GlobalOptions;
use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};
use voirs_sdk::Result;

/// Model inspection result
#[derive(Debug)]
pub struct ModelInspection {
    pub model_type: String,
    pub format: String,
    pub file_size: u64,
    pub parameter_count: Option<usize>,
    pub layers: Vec<LayerInfo>,
    pub metadata: Vec<(String, String)>,
}

/// Layer information
#[derive(Debug)]
pub struct LayerInfo {
    pub name: String,
    pub layer_type: String,
    pub shape: Vec<usize>,
    pub param_count: usize,
}

/// Run model inspection command.
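///
/// `detailed` expands the per-layer listing, `export_path` writes the architecture
/// summary as JSON, and `verify` runs the checksum and format-validation pass.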
pub async fn run_model_inspect(
    model_path: &Path,
    detailed: bool,
    export_path: Option<&PathBuf>,
    verify: bool,
    global: &GlobalOptions,
) -> Result<()> {
    if !global.quiet {
        println!("šŸ” Inspecting model: {}", model_path.display());
        println!();
    }

    // Check if file exists
    if !model_path.exists() {
        return Err(voirs_sdk::VoirsError::config_error(format!(
            "Model file not found: {}",
            model_path.display()
        )));
    }

    // Read file metadata
    let metadata = std::fs::metadata(model_path).map_err(|e| {
        voirs_sdk::VoirsError::config_error(format!("Failed to read file metadata: {}", e))
    })?;

    let file_size = metadata.len();

    // Determine model format
    let format = detect_model_format(model_path)?;

    if !global.quiet {
        println!("šŸ“„ File Information:");
        println!("  Format: {}", format);
        println!(
            "  Size: {} bytes ({:.2} MB)",
            file_size,
            file_size as f64 / 1_048_576.0
        );
        println!();
    }

    // Perform format-specific inspection
    let inspection = match format.as_str() {
        "SafeTensors" => inspect_safetensors(model_path, detailed)?,
        "PyTorch" => inspect_pytorch(model_path, detailed)?,
        "ONNX" => inspect_onnx(model_path, detailed)?,
        _ => {
            return Err(voirs_sdk::VoirsError::config_error(format!(
                "Unsupported model format: {}",
                format
            )));
        }
    };

    // Display inspection results
    display_inspection(&inspection, detailed, global.quiet);

    // Verify integrity if requested
    if verify {
        verify_model_integrity(model_path, &format, global.quiet)?;
    }

    // Export architecture if requested
    if let Some(export_path) = export_path {
        export_architecture(&inspection, export_path)?;
        if !global.quiet {
            println!("\nāœ… Architecture exported to: {}", export_path.display());
        }
    }

    Ok(())
}

/// Detect model file format
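///
/// Dispatch is by file extension only. A doctest sketch (marked `ignore` since this helper is private):
///
/// ```ignore
/// assert_eq!(detect_model_format(Path::new("model.onnx")).unwrap(), "ONNX");
/// assert_eq!(detect_model_format(Path::new("checkpoint.pth")).unwrap(), "PyTorch");
/// assert_eq!(detect_model_format(Path::new("weights.safetensors")).unwrap(), "SafeTensors");
/// ```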
fn detect_model_format(path: &Path) -> Result<String> {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .ok_or_else(|| voirs_sdk::VoirsError::config_error("No file extension found"))?;

    match ext.to_lowercase().as_str() {
        "safetensors" | "st" => Ok("SafeTensors".to_string()),
        "pt" | "pth" | "bin" => Ok("PyTorch".to_string()),
        "onnx" => Ok("ONNX".to_string()),
        _ => Err(voirs_sdk::VoirsError::config_error(format!(
            "Unknown model format: {}",
            ext
        ))),
    }
}

/// Inspect SafeTensors model
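///
/// A SafeTensors file is an 8-byte little-endian header length, a JSON header describing
/// each tensor's dtype, shape, and data offsets, then the raw tensor bytes;
/// `SafeTensors::deserialize` parses that header, so the whole file is read into memory here.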
fn inspect_safetensors(path: &Path, detailed: bool) -> Result<ModelInspection> {
    use safetensors::SafeTensors;

    let mut file = File::open(path)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to open file: {}", e)))?;

    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to read file: {}", e)))?;

    let tensors = SafeTensors::deserialize(&buffer).map_err(|e| {
        voirs_sdk::VoirsError::config_error(format!("Failed to deserialize SafeTensors: {}", e))
    })?;

    let mut layers = Vec::new();
    let mut total_params = 0;

    for name in tensors.names() {
        let tensor = tensors.tensor(name).map_err(|e| {
            voirs_sdk::VoirsError::config_error(format!("Failed to get tensor: {}", e))
        })?;

        let shape: Vec<usize> = tensor.shape().to_vec();
        let param_count: usize = shape.iter().product();
        total_params += param_count;

        if detailed {
            layers.push(LayerInfo {
                name: name.to_string(),
                layer_type: infer_layer_type(name),
                shape,
                param_count,
            });
        }
    }

    let metadata = Vec::new();
    // SafeTensors metadata is not publicly accessible in this version
    // metadata.push(("metadata".to_string(), "See SafeTensors spec".to_string()));

    Ok(ModelInspection {
        model_type: infer_model_type(tensors.names()),
        format: "SafeTensors".to_string(),
        file_size: buffer.len() as u64,
        parameter_count: Some(total_params),
        layers,
        metadata,
    })
}

/// Inspect PyTorch model
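///
/// Shallow check only: legacy checkpoints are raw pickle streams (first byte `0x80`), while
/// modern `torch.save` output is a ZIP archive (magic `PK`). Neither is unpickled here, so the
/// parameter count is a size-based estimate.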
fn inspect_pytorch(path: &Path, _detailed: bool) -> Result<ModelInspection> {
    let metadata_result = std::fs::metadata(path);
    let file_size = metadata_result.map(|m| m.len()).unwrap_or(0);

    // Try to validate PyTorch format by checking pickle magic bytes
    let mut file = File::open(path)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to open file: {}", e)))?;

    let mut magic = [0u8; 8];
    let _ = file.read(&mut magic);

    let mut metadata = vec![];
    let is_valid_pickle = magic.starts_with(b"\x80") || magic.starts_with(b"PK");

    if is_valid_pickle {
        metadata.push(("format_valid".to_string(), "true".to_string()));
        // The protocol byte is only meaningful for raw pickle streams, not zip-based "PK" checkpoints.
        if magic[0] == 0x80 {
            metadata.push(("pickle_protocol".to_string(), magic[1].to_string()));
        }
    } else {
        metadata.push(("format_valid".to_string(), "false".to_string()));
        metadata.push((
            "warning".to_string(),
            "File may not be a valid PyTorch checkpoint".to_string(),
        ));
    }

    // Try to estimate parameter count from file size (rough heuristic)
    let estimated_params = if file_size > 1024 {
        Some(((file_size as f64 / 4.0) * 0.9) as usize) // Rough estimate: ~4 bytes per float32 param
    } else {
        None
    };

    metadata.push((
        "note".to_string(),
        "Full inspection requires PyTorch/tch-rs bindings".to_string(),
    ));
    metadata.push((
        "recommendation".to_string(),
        "Convert to SafeTensors format for detailed inspection".to_string(),
    ));

    Ok(ModelInspection {
        model_type: infer_pytorch_model_type(path),
        format: "PyTorch".to_string(),
        file_size,
        parameter_count: estimated_params,
        layers: vec![],
        metadata,
    })
}

/// Inspect ONNX model
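///
/// Heuristic only: looks for protobuf field tags and an "ONNX"/"onnx" marker string in the
/// first 256 bytes instead of decoding the full `ModelProto`.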
fn inspect_onnx(path: &Path, _detailed: bool) -> Result<ModelInspection> {
    let metadata_result = std::fs::metadata(path);
    let file_size = metadata_result.map(|m| m.len()).unwrap_or(0);

    // Try to validate ONNX format by checking protobuf header
    let mut file = File::open(path)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to open file: {}", e)))?;

    let mut buffer = vec![0u8; 256]; // Read first 256 bytes for analysis
    let bytes_read = file.read(&mut buffer).unwrap_or(0);

    let mut metadata = vec![];

    // Check for protobuf magic bytes and common ONNX markers
    let has_onnx_marker = buffer.windows(4).any(|w| w == b"ONNX" || w == b"onnx");
    let has_protobuf = bytes_read > 0 && (buffer[0] == 0x08 || buffer[0] == 0x0a);

    if has_onnx_marker && has_protobuf {
        metadata.push(("format_valid".to_string(), "true".to_string()));
        metadata.push(("protobuf_format".to_string(), "detected".to_string()));
    } else if has_protobuf {
        metadata.push(("format_valid".to_string(), "likely".to_string()));
        metadata.push((
            "warning".to_string(),
            "Protobuf detected but no ONNX marker found".to_string(),
        ));
    } else {
        metadata.push(("format_valid".to_string(), "false".to_string()));
        metadata.push((
            "warning".to_string(),
            "File may not be a valid ONNX model".to_string(),
        ));
    }

    // Try to find IR version in the protobuf data
    if let Some(ir_version) = extract_onnx_ir_version(&buffer[..bytes_read]) {
        metadata.push(("ir_version".to_string(), ir_version.to_string()));
    }

    // Estimate parameters from file size
    let estimated_params = if file_size > 1024 {
        Some(((file_size as f64 / 4.5) * 0.85) as usize) // Rough estimate accounting for protobuf overhead
    } else {
        None
    };

    metadata.push((
        "note".to_string(),
        "Full inspection requires tract-onnx or onnxruntime bindings".to_string(),
    ));
    metadata.push((
        "recommendation".to_string(),
        "Use 'onnx' Python tools for detailed inspection, or convert to SafeTensors".to_string(),
    ));

    Ok(ModelInspection {
        model_type: infer_onnx_model_type(path),
        format: "ONNX".to_string(),
        file_size,
        parameter_count: estimated_params,
        layers: vec![],
        metadata,
    })
}

/// Infer layer type from tensor name
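///
/// A doctest sketch over typical checkpoint tensor names (marked `ignore` since this helper is private):
///
/// ```ignore
/// assert_eq!(infer_layer_type("decoder.conv_layers.0.weight"), "Convolution");
/// assert_eq!(infer_layer_type("encoder.attention.0.weight"), "Attention");
/// assert_eq!(infer_layer_type("proj.bias"), "Bias");
/// ```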
fn infer_layer_type(name: &str) -> String {
    if name.contains("weight") && name.contains("conv") {
        "Convolution".to_string()
    } else if name.contains("weight") && name.contains("linear") {
        "Linear".to_string()
    } else if name.contains("weight") && name.contains("attention") {
        "Attention".to_string()
    } else if name.contains("norm") || name.contains("bn") {
        "Normalization".to_string()
    } else if name.contains("embedding") {
        "Embedding".to_string()
    } else if name.contains("bias") {
        "Bias".to_string()
    } else {
        "Other".to_string()
    }
}

/// Infer model type from tensor names
fn infer_model_type(names: Vec<&str>) -> String {
    let names_str = names.join(" ").to_lowercase();

    if names_str.contains("diffwave") || names_str.contains("residual_blocks") {
        "DiffWave Vocoder".to_string()
    } else if names_str.contains("hifigan") || names_str.contains("generator") {
        "HiFi-GAN Vocoder".to_string()
    } else if names_str.contains("vits") || names_str.contains("posterior_encoder") {
        "VITS Acoustic Model".to_string()
    } else if names_str.contains("fastspeech") {
        "FastSpeech2 Acoustic Model".to_string()
    } else if names_str.contains("g2p") || names_str.contains("phoneme") {
        "G2P Model".to_string()
    } else {
        "Unknown Model Type".to_string()
    }
}

/// Infer PyTorch model type from filename
fn infer_pytorch_model_type(path: &Path) -> String {
    let filename = path
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("")
        .to_lowercase();

    if filename.contains("vocoder") || filename.contains("hifigan") || filename.contains("diffwave")
    {
        "Vocoder Model (PyTorch)".to_string()
    } else if filename.contains("acoustic")
        || filename.contains("vits")
        || filename.contains("fastspeech")
    {
        "Acoustic Model (PyTorch)".to_string()
    } else if filename.contains("g2p") || filename.contains("phoneme") {
        "G2P Model (PyTorch)".to_string()
    } else if filename.contains("encoder") {
        "Encoder Model (PyTorch)".to_string()
    } else if filename.contains("decoder") {
        "Decoder Model (PyTorch)".to_string()
    } else {
        "Unknown Model Type (PyTorch)".to_string()
    }
}

/// Infer ONNX model type from filename
fn infer_onnx_model_type(path: &Path) -> String {
    let filename = path
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("")
        .to_lowercase();

    if filename.contains("vocoder") || filename.contains("hifigan") || filename.contains("diffwave")
    {
        "Vocoder Model (ONNX)".to_string()
    } else if filename.contains("acoustic")
        || filename.contains("vits")
        || filename.contains("fastspeech")
    {
        "Acoustic Model (ONNX)".to_string()
    } else if filename.contains("g2p") || filename.contains("phoneme") {
        "G2P Model (ONNX)".to_string()
    } else if filename.contains("encoder") {
        "Encoder Model (ONNX)".to_string()
    } else if filename.contains("decoder") {
        "Decoder Model (ONNX)".to_string()
    } else {
        "Unknown Model Type (ONNX)".to_string()
    }
}

/// Extract ONNX IR version from protobuf data
fn extract_onnx_ir_version(buffer: &[u8]) -> Option<u8> {
    // IR version is typically one of the first fields in ONNX protobuf
    // Field tag 1 (ir_version) has wire type 0 (varint)
    // This is a simple heuristic search
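    // Example: a ModelProto with ir_version = 8 typically begins with the bytes 0x08 0x08
    // (the field-1/varint key followed by the version value).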
    for i in 0..buffer.len().saturating_sub(2) {
        // Look for field tag 0x08 (field 1, wire type 0)
        if buffer[i] == 0x08 && buffer[i + 1] > 0 && buffer[i + 1] < 20 {
            return Some(buffer[i + 1]);
        }
    }
    None
}

/// Display inspection results
fn display_inspection(inspection: &ModelInspection, detailed: bool, quiet: bool) {
    if quiet {
        return;
    }

    println!("šŸ”¬ Model Analysis:");
    println!("  Type: {}", inspection.model_type);

    if let Some(count) = inspection.parameter_count {
        println!(
            "  Parameters: {} ({:.2}M)",
            count,
            count as f64 / 1_000_000.0
        );
    }

    if !inspection.metadata.is_empty() {
        println!("\nšŸ“‹ Metadata:");
        for (key, value) in &inspection.metadata {
            println!("  {}: {}", key, value);
        }
    }

    if detailed && !inspection.layers.is_empty() {
        println!("\n🧩 Layers ({} total):", inspection.layers.len());
        for layer in &inspection.layers {
            println!("  {} [{}]", layer.name, layer.layer_type);
            println!("    Shape: {:?}", layer.shape);
            println!("    Parameters: {}", layer.param_count);
        }
    } else if !inspection.layers.is_empty() {
        println!(
            "  Layers: {} (use --detailed for full list)",
            inspection.layers.len()
        );
    }
}

/// Verify model integrity
fn verify_model_integrity(path: &Path, format: &str, quiet: bool) -> Result<()> {
    use safetensors::SafeTensors;

    if !quiet {
        println!("\nšŸ” Verifying model integrity...");
    }

    // Basic file existence and readability check
    let _file = File::open(path)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to open file: {}", e)))?;

    // Calculate checksum for integrity tracking
    let checksum = calculate_file_checksum(path)?;
    if !quiet {
        println!("  SHA-256: {}", checksum);
    }

    // Format-specific validation
    match format {
        "SafeTensors" => {
            // Try to deserialize
            let mut file = File::open(path)?;
            let mut buffer = Vec::new();
            file.read_to_end(&mut buffer)?;
            SafeTensors::deserialize(&buffer).map_err(|e| {
                voirs_sdk::VoirsError::config_error(format!("SafeTensors validation failed: {}", e))
            })?;

            if !quiet {
                println!("  Format: Valid SafeTensors");
            }
        }
        "PyTorch" => {
            // Validate pickle format
            let mut file = File::open(path)?;
            let mut magic = [0u8; 2];
            file.read_exact(&mut magic).ok();

            if magic[0] == 0x80 || magic.starts_with(b"PK") {
                if !quiet {
                    println!("  Format: Valid PyTorch/Pickle");
                }
            } else if !quiet {
                println!("  Format: Warning - may not be valid PyTorch");
            }
        }
        "ONNX" => {
            // Validate ONNX protobuf format
            let mut file = File::open(path)?;
            let mut buffer = vec![0u8; 64];
            let _ = file.read(&mut buffer);

            let has_onnx = buffer.windows(4).any(|w| w == b"ONNX");
            if has_onnx && !quiet {
                println!("  Format: Valid ONNX");
            } else if !quiet {
                println!("  Format: Warning - may not be valid ONNX");
            }
        }
        _ => {
            // Basic checks only for other formats
        }
    }

    if !quiet {
        println!("āœ… Model integrity verified");
    }

    Ok(())
}

/// Calculate SHA-256 checksum of a file
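///
/// Streams the file in 8 KiB chunks and returns the digest as a 64-character lowercase hex string.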
fn calculate_file_checksum(path: &Path) -> Result<String> {
    use sha2::{Digest, Sha256};

    let mut file = File::open(path)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to open file: {}", e)))?;

    let mut hasher = Sha256::new();
    let mut buffer = vec![0u8; 8192];

    loop {
        let bytes_read = file.read(&mut buffer).map_err(|e| {
            voirs_sdk::VoirsError::config_error(format!("Failed to read file: {}", e))
        })?;

        if bytes_read == 0 {
            break;
        }

        hasher.update(&buffer[..bytes_read]);
    }

    let result = hasher.finalize();
    Ok(format!("{:x}", result))
}

/// Export architecture to file
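///
/// Writes a pretty-printed JSON object with `model_type`, `format`, `file_size`,
/// `parameter_count`, `layer_count`, `layers`, and `metadata` keys.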
fn export_architecture(inspection: &ModelInspection, path: &PathBuf) -> Result<()> {
    use serde_json;

    let json = serde_json::to_string_pretty(&serde_json::json!({
        "model_type": inspection.model_type,
        "format": inspection.format,
        "file_size": inspection.file_size,
        "parameter_count": inspection.parameter_count,
        "layer_count": inspection.layers.len(),
        "layers": inspection.layers.iter().map(|l| serde_json::json!({
            "name": l.name,
            "type": l.layer_type,
            "shape": l.shape,
            "parameters": l.param_count,
        })).collect::<Vec<_>>(),
        "metadata": inspection.metadata.iter().map(|(k, v)| serde_json::json!({
            "key": k,
            "value": v,
        })).collect::<Vec<_>>(),
    }))
    .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to serialize: {}", e)))?;

    std::fs::write(path, json)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to write file: {}", e)))?;

    Ok(())
}