1use crate::GlobalOptions;
6use std::fs::File;
7use std::io::Read;
8use std::path::{Path, PathBuf};
9use voirs_sdk::Result;
10
/// Format-agnostic summary of an inspected model file, produced by the
/// per-format inspectors (`inspect_safetensors`, `inspect_pytorch`, `inspect_onnx`).
#[derive(Debug)]
pub struct ModelInspection {
    /// Human-readable guess at the model family (e.g. "HiFi-GAN Vocoder"),
    /// inferred from tensor names or the file name.
    pub model_type: String,
    /// Container format name: "SafeTensors", "PyTorch", or "ONNX".
    pub format: String,
    /// Size of the model file in bytes.
    pub file_size: u64,
    /// Total parameter count; `None` when it cannot be determined or estimated.
    pub parameter_count: Option<usize>,
    /// Per-tensor details; populated only for detailed SafeTensors inspections.
    pub layers: Vec<LayerInfo>,
    /// Free-form key/value diagnostics gathered during inspection.
    pub metadata: Vec<(String, String)>,
}
21
/// Shape and size information for a single named tensor in a model checkpoint.
#[derive(Debug)]
pub struct LayerInfo {
    /// Tensor name exactly as stored in the checkpoint.
    pub name: String,
    /// Coarse category inferred from the name (see `infer_layer_type`).
    pub layer_type: String,
    /// Tensor dimensions, in storage order.
    pub shape: Vec<usize>,
    /// Number of elements in the tensor (product of `shape`).
    pub param_count: usize,
}
30
31pub async fn run_model_inspect(
33 model_path: &Path,
34 detailed: bool,
35 export_path: Option<&PathBuf>,
36 verify: bool,
37 global: &GlobalOptions,
38) -> Result<()> {
39 if !global.quiet {
40 println!("š Inspecting model: {}", model_path.display());
41 println!();
42 }
43
44 if !model_path.exists() {
46 return Err(voirs_sdk::VoirsError::config_error(format!(
47 "Model file not found: {}",
48 model_path.display()
49 )));
50 }
51
52 let metadata = std::fs::metadata(model_path).map_err(|e| {
54 voirs_sdk::VoirsError::config_error(format!("Failed to read file metadata: {}", e))
55 })?;
56
57 let file_size = metadata.len();
58
59 let format = detect_model_format(model_path)?;
61
62 if !global.quiet {
63 println!("š File Information:");
64 println!(" Format: {}", format);
65 println!(
66 " Size: {} bytes ({:.2} MB)",
67 file_size,
68 file_size as f64 / 1_048_576.0
69 );
70 println!();
71 }
72
73 let inspection = match format.as_str() {
75 "SafeTensors" => inspect_safetensors(model_path, detailed)?,
76 "PyTorch" => inspect_pytorch(model_path, detailed)?,
77 "ONNX" => inspect_onnx(model_path, detailed)?,
78 _ => {
79 return Err(voirs_sdk::VoirsError::config_error(format!(
80 "Unsupported model format: {}",
81 format
82 )));
83 }
84 };
85
86 display_inspection(&inspection, detailed, global.quiet);
88
89 if verify {
91 verify_model_integrity(model_path, &format, global.quiet)?;
92 }
93
94 if let Some(export_path) = export_path {
96 export_architecture(&inspection, export_path)?;
97 if !global.quiet {
98 println!("\nā
Architecture exported to: {}", export_path.display());
99 }
100 }
101
102 Ok(())
103}
104
105fn detect_model_format(path: &Path) -> Result<String> {
107 let ext = path
108 .extension()
109 .and_then(|e| e.to_str())
110 .ok_or_else(|| voirs_sdk::VoirsError::config_error("No file extension found"))?;
111
112 match ext.to_lowercase().as_str() {
113 "safetensors" | "st" => Ok("SafeTensors".to_string()),
114 "pt" | "pth" | "bin" => Ok("PyTorch".to_string()),
115 "onnx" => Ok("ONNX".to_string()),
116 _ => Err(voirs_sdk::VoirsError::config_error(format!(
117 "Unknown model format: {}",
118 ext
119 ))),
120 }
121}
122
123fn inspect_safetensors(path: &Path, detailed: bool) -> Result<ModelInspection> {
125 use safetensors::SafeTensors;
126
127 let mut file = File::open(path)
128 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to open file: {}", e)))?;
129
130 let mut buffer = Vec::new();
131 file.read_to_end(&mut buffer)
132 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to read file: {}", e)))?;
133
134 let tensors = SafeTensors::deserialize(&buffer).map_err(|e| {
135 voirs_sdk::VoirsError::config_error(format!("Failed to deserialize SafeTensors: {}", e))
136 })?;
137
138 let mut layers = Vec::new();
139 let mut total_params = 0;
140
141 for name in tensors.names() {
142 let tensor = tensors.tensor(name).map_err(|e| {
143 voirs_sdk::VoirsError::config_error(format!("Failed to get tensor: {}", e))
144 })?;
145
146 let shape: Vec<usize> = tensor.shape().to_vec();
147 let param_count: usize = shape.iter().product();
148 total_params += param_count;
149
150 if detailed {
151 layers.push(LayerInfo {
152 name: name.to_string(),
153 layer_type: infer_layer_type(name),
154 shape,
155 param_count,
156 });
157 }
158 }
159
160 let mut metadata = Vec::new();
161 Ok(ModelInspection {
165 model_type: infer_model_type(tensors.names()),
166 format: "SafeTensors".to_string(),
167 file_size: buffer.len() as u64,
168 parameter_count: Some(total_params),
169 layers,
170 metadata,
171 })
172}
173
174fn inspect_pytorch(path: &Path, detailed: bool) -> Result<ModelInspection> {
176 let metadata_result = std::fs::metadata(path);
177 let file_size = metadata_result.map(|m| m.len()).unwrap_or(0);
178
179 let mut file = File::open(path)
181 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to open file: {}", e)))?;
182
183 let mut magic = [0u8; 8];
184 let _ = file.read(&mut magic);
185
186 let mut metadata = vec![];
187 let is_valid_pickle = magic.starts_with(b"\x80") || magic.starts_with(b"PK");
188
189 if is_valid_pickle {
190 metadata.push(("format_valid".to_string(), "true".to_string()));
191 metadata.push((
192 "pickle_protocol".to_string(),
193 format!("{}", magic[1] as char),
194 ));
195 } else {
196 metadata.push(("format_valid".to_string(), "false".to_string()));
197 metadata.push((
198 "warning".to_string(),
199 "File may not be a valid PyTorch checkpoint".to_string(),
200 ));
201 }
202
203 let estimated_params = if file_size > 1024 {
205 Some(((file_size as f64 / 4.0) * 0.9) as usize) } else {
207 None
208 };
209
210 metadata.push((
211 "note".to_string(),
212 "Full inspection requires PyTorch/tch-rs bindings".to_string(),
213 ));
214 metadata.push((
215 "recommendation".to_string(),
216 "Convert to SafeTensors format for detailed inspection".to_string(),
217 ));
218
219 Ok(ModelInspection {
220 model_type: infer_pytorch_model_type(path),
221 format: "PyTorch".to_string(),
222 file_size,
223 parameter_count: estimated_params,
224 layers: vec![],
225 metadata,
226 })
227}
228
229fn inspect_onnx(path: &Path, detailed: bool) -> Result<ModelInspection> {
231 let metadata_result = std::fs::metadata(path);
232 let file_size = metadata_result.map(|m| m.len()).unwrap_or(0);
233
234 let mut file = File::open(path)
236 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to open file: {}", e)))?;
237
238 let mut buffer = vec![0u8; 256]; let bytes_read = file.read(&mut buffer).unwrap_or(0);
240
241 let mut metadata = vec![];
242
243 let has_onnx_marker = buffer.windows(4).any(|w| w == b"ONNX" || w == b"onnx");
245 let has_protobuf = bytes_read > 0 && (buffer[0] == 0x08 || buffer[0] == 0x0a);
246
247 if has_onnx_marker && has_protobuf {
248 metadata.push(("format_valid".to_string(), "true".to_string()));
249 metadata.push(("protobuf_format".to_string(), "detected".to_string()));
250 } else if has_protobuf {
251 metadata.push(("format_valid".to_string(), "likely".to_string()));
252 metadata.push((
253 "warning".to_string(),
254 "Protobuf detected but no ONNX marker found".to_string(),
255 ));
256 } else {
257 metadata.push(("format_valid".to_string(), "false".to_string()));
258 metadata.push((
259 "warning".to_string(),
260 "File may not be a valid ONNX model".to_string(),
261 ));
262 }
263
264 if let Some(ir_version) = extract_onnx_ir_version(&buffer[..bytes_read]) {
266 metadata.push(("ir_version".to_string(), ir_version.to_string()));
267 }
268
269 let estimated_params = if file_size > 1024 {
271 Some(((file_size as f64 / 4.5) * 0.85) as usize) } else {
273 None
274 };
275
276 metadata.push((
277 "note".to_string(),
278 "Full inspection requires tract-onnx or onnxruntime bindings".to_string(),
279 ));
280 metadata.push((
281 "recommendation".to_string(),
282 "Use 'onnx' Python tools for detailed inspection, or convert to SafeTensors".to_string(),
283 ));
284
285 Ok(ModelInspection {
286 model_type: infer_onnx_model_type(path),
287 format: "ONNX".to_string(),
288 file_size,
289 parameter_count: estimated_params,
290 layers: vec![],
291 metadata,
292 })
293}
294
/// Map a tensor name to a coarse layer category via substring heuristics.
/// Rules are checked in order and the first match wins, so the more specific
/// weight-based patterns take precedence over the generic ones.
fn infer_layer_type(name: &str) -> String {
    let has = |needle: &str| name.contains(needle);

    let category = if has("weight") && has("conv") {
        "Convolution"
    } else if has("weight") && has("linear") {
        "Linear"
    } else if has("weight") && has("attention") {
        "Attention"
    } else if has("norm") || has("bn") {
        "Normalization"
    } else if has("embedding") {
        "Embedding"
    } else if has("bias") {
        "Bias"
    } else {
        "Other"
    };

    category.to_string()
}
313
/// Infer the overall model family from the set of tensor names.
///
/// All names are joined into one lowercase haystack and matched against an
/// ordered rule table; the first rule with any matching marker wins.
fn infer_model_type(names: Vec<&str>) -> String {
    let haystack = names.join(" ").to_lowercase();

    // (marker substrings, model label) — order encodes precedence.
    let rules: &[(&[&str], &str)] = &[
        (&["diffwave", "residual_blocks"], "DiffWave Vocoder"),
        (&["hifigan", "generator"], "HiFi-GAN Vocoder"),
        (&["vits", "posterior_encoder"], "VITS Acoustic Model"),
        (&["fastspeech"], "FastSpeech2 Acoustic Model"),
        (&["g2p", "phoneme"], "G2P Model"),
    ];

    rules
        .iter()
        .find(|(markers, _)| markers.iter().any(|m| haystack.contains(m)))
        .map(|(_, label)| (*label).to_string())
        .unwrap_or_else(|| "Unknown Model Type".to_string())
}
332
/// Infer the model family of a PyTorch checkpoint from its file name.
///
/// Matching is case-insensitive and ordered: vocoder markers are tried first,
/// then acoustic, G2P, encoder, and decoder markers.
fn infer_pytorch_model_type(path: &Path) -> String {
    let filename = path
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("")
        .to_lowercase();

    // (marker substrings, model label) — order encodes precedence.
    let rules: &[(&[&str], &str)] = &[
        (&["vocoder", "hifigan", "diffwave"], "Vocoder Model (PyTorch)"),
        (&["acoustic", "vits", "fastspeech"], "Acoustic Model (PyTorch)"),
        (&["g2p", "phoneme"], "G2P Model (PyTorch)"),
        (&["encoder"], "Encoder Model (PyTorch)"),
        (&["decoder"], "Decoder Model (PyTorch)"),
    ];

    rules
        .iter()
        .find(|(markers, _)| markers.iter().any(|m| filename.contains(m)))
        .map(|(_, label)| (*label).to_string())
        .unwrap_or_else(|| "Unknown Model Type (PyTorch)".to_string())
}
359
/// Infer the model family of an ONNX model from its file name.
///
/// Same rule order as the PyTorch variant: vocoder markers first, then
/// acoustic, G2P, encoder, and decoder markers; matching is case-insensitive.
fn infer_onnx_model_type(path: &Path) -> String {
    let filename = path
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("")
        .to_lowercase();

    // (marker substrings, model label) — order encodes precedence.
    let rules: &[(&[&str], &str)] = &[
        (&["vocoder", "hifigan", "diffwave"], "Vocoder Model (ONNX)"),
        (&["acoustic", "vits", "fastspeech"], "Acoustic Model (ONNX)"),
        (&["g2p", "phoneme"], "G2P Model (ONNX)"),
        (&["encoder"], "Encoder Model (ONNX)"),
        (&["decoder"], "Decoder Model (ONNX)"),
    ];

    rules
        .iter()
        .find(|(markers, _)| markers.iter().any(|m| filename.contains(m)))
        .map(|(_, label)| (*label).to_string())
        .unwrap_or_else(|| "Unknown Model Type (ONNX)".to_string())
}
386
/// Best-effort scan for the ONNX `ir_version` field in a protobuf header.
///
/// Looks for the field-1 varint tag byte (0x08) immediately followed by a
/// plausible version number (1..=19) and returns that number. Purely
/// heuristic: 0x08 can occur in other contexts, so the result may be a false
/// positive; callers treat it as advisory metadata only.
///
/// Fix: the previous index loop ran to `len - 2` while only reading `i + 1`,
/// so the final byte pair was never examined (e.g. `[0x08, 0x07]` → `None`).
/// `windows(2)` covers every adjacent pair.
fn extract_onnx_ir_version(buffer: &[u8]) -> Option<u8> {
    buffer
        .windows(2)
        .find(|pair| pair[0] == 0x08 && pair[1] > 0 && pair[1] < 20)
        .map(|pair| pair[1])
}
400
401fn display_inspection(inspection: &ModelInspection, detailed: bool, quiet: bool) {
403 if quiet {
404 return;
405 }
406
407 println!("š¬ Model Analysis:");
408 println!(" Type: {}", inspection.model_type);
409
410 if let Some(count) = inspection.parameter_count {
411 println!(
412 " Parameters: {:?} ({:.2}M)",
413 count,
414 count as f64 / 1_000_000.0
415 );
416 }
417
418 if !inspection.metadata.is_empty() {
419 println!("\nš Metadata:");
420 for (key, value) in &inspection.metadata {
421 println!(" {}: {}", key, value);
422 }
423 }
424
425 if detailed && !inspection.layers.is_empty() {
426 println!("\nš§© Layers ({} total):", inspection.layers.len());
427 for layer in &inspection.layers {
428 println!(" {} [{}]", layer.name, layer.layer_type);
429 println!(" Shape: {:?}", layer.shape);
430 println!(" Parameters: {}", layer.param_count);
431 }
432 } else if !inspection.layers.is_empty() {
433 println!(
434 " Layers: {} (use --detailed for full list)",
435 inspection.layers.len()
436 );
437 }
438}
439
440fn verify_model_integrity(path: &Path, format: &str, quiet: bool) -> Result<()> {
442 use safetensors::SafeTensors;
443
444 if !quiet {
445 println!("\nš Verifying model integrity...");
446 }
447
448 let _file = File::open(path)
450 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to open file: {}", e)))?;
451
452 let checksum = calculate_file_checksum(path)?;
454 if !quiet {
455 println!(" SHA-256: {}", checksum);
456 }
457
458 match format {
460 "SafeTensors" => {
461 let mut file = File::open(path)?;
463 let mut buffer = Vec::new();
464 file.read_to_end(&mut buffer)?;
465 SafeTensors::deserialize(&buffer).map_err(|e| {
466 voirs_sdk::VoirsError::config_error(format!("SafeTensors validation failed: {}", e))
467 })?;
468
469 if !quiet {
470 println!(" Format: Valid SafeTensors");
471 }
472 }
473 "PyTorch" => {
474 let mut file = File::open(path)?;
476 let mut magic = [0u8; 2];
477 file.read_exact(&mut magic).ok();
478
479 if magic[0] == 0x80 || magic.starts_with(b"PK") {
480 if !quiet {
481 println!(" Format: Valid PyTorch/Pickle");
482 }
483 } else if !quiet {
484 println!(" Format: Warning - may not be valid PyTorch");
485 }
486 }
487 "ONNX" => {
488 let mut file = File::open(path)?;
490 let mut buffer = vec![0u8; 64];
491 let _ = file.read(&mut buffer);
492
493 let has_onnx = buffer.windows(4).any(|w| w == b"ONNX");
494 if has_onnx && !quiet {
495 println!(" Format: Valid ONNX");
496 } else if !quiet {
497 println!(" Format: Warning - may not be valid ONNX");
498 }
499 }
500 _ => {
501 }
503 }
504
505 if !quiet {
506 println!("ā
Model integrity verified");
507 }
508
509 Ok(())
510}
511
512fn calculate_file_checksum(path: &Path) -> Result<String> {
514 use sha2::{Digest, Sha256};
515
516 let mut file = File::open(path)
517 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to open file: {}", e)))?;
518
519 let mut hasher = Sha256::new();
520 let mut buffer = vec![0u8; 8192];
521
522 loop {
523 let bytes_read = file.read(&mut buffer).map_err(|e| {
524 voirs_sdk::VoirsError::config_error(format!("Failed to read file: {}", e))
525 })?;
526
527 if bytes_read == 0 {
528 break;
529 }
530
531 hasher.update(&buffer[..bytes_read]);
532 }
533
534 let result = hasher.finalize();
535 Ok(format!("{:x}", result))
536}
537
538fn export_architecture(inspection: &ModelInspection, path: &PathBuf) -> Result<()> {
540 use serde_json;
541
542 let json = serde_json::to_string_pretty(&serde_json::json!({
543 "model_type": inspection.model_type,
544 "format": inspection.format,
545 "file_size": inspection.file_size,
546 "parameter_count": inspection.parameter_count,
547 "layer_count": inspection.layers.len(),
548 "layers": inspection.layers.iter().map(|l| serde_json::json!({
549 "name": l.name,
550 "type": l.layer_type,
551 "shape": l.shape,
552 "parameters": l.param_count,
553 })).collect::<Vec<_>>(),
554 "metadata": inspection.metadata.iter().map(|(k, v)| serde_json::json!({
555 "key": k,
556 "value": v,
557 })).collect::<Vec<_>>(),
558 }))
559 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to serialize: {}", e)))?;
560
561 std::fs::write(path, json)
562 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to write file: {}", e)))?;
563
564 Ok(())
565}