1#![allow(dead_code)]
11
12use anyhow::Result;
13use tracing::{debug, info, warn};
14
15use scirs2_core::ndarray::Array1;
17use scirs2_core::random::{thread_rng, Distribution, Normal, Uniform};
18
19use torsh::core::device::DeviceType;
21use torsh::tensor::Tensor;
22
23use super::types::{LayerInfo, TorshModel};
24
/// Aggregated outcome of a full model validation run.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// True when no errors were recorded and at least one inference succeeded.
    pub passed: bool,
    /// Top-1 accuracy, when an accuracy evaluation was performed
    /// (not populated by `validate_model` itself).
    pub accuracy: Option<f64>,
    /// Top-5 accuracy, when an accuracy evaluation was performed
    /// (not populated by `validate_model` itself).
    pub top5_accuracy: Option<f64>,
    /// Number of inference samples requested for testing.
    pub num_samples: usize,
    /// Count of forward passes that completed successfully.
    pub successful_inferences: usize,
    /// Count of forward passes that returned an error.
    pub failed_inferences: usize,
    /// Mean wall-clock time per successful inference, in milliseconds.
    pub avg_inference_time_ms: f64,
    /// Highest estimated memory footprint observed across inferences, in MB.
    pub peak_memory_mb: f64,
    /// Outcome of the optional gradient check; `None` when skipped or errored.
    pub gradient_check_passed: Option<bool>,
    /// Stability score in [0.0, 1.0]; higher means more numerically stable.
    pub numerical_stability: f64,
    /// Fatal problems found during validation.
    pub errors: Vec<String>,
    /// Non-fatal issues found during validation.
    pub warnings: Vec<String>,
}
53
/// Summary of comparing numerical against analytical gradients.
#[derive(Debug, Clone)]
pub struct GradientCheckResult {
    /// True when no checked gradient exceeded the tolerance.
    pub passed: bool,
    /// Largest relative error observed across all checks.
    pub max_relative_error: f64,
    /// Mean relative error across all checks.
    pub avg_relative_error: f64,
    /// Number of weight tensors whose gradients were checked.
    pub num_gradients_checked: usize,
    /// Parameter names (with their errors) that exceeded the tolerance.
    pub failed_locations: Vec<String>,
}
68
/// Findings of the numerical-stability scan over model values.
#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
    /// True if any scanned value was NaN.
    pub has_nan: bool,
    /// True if any scanned value was infinite.
    pub has_inf: bool,
    /// True if any scanned value had magnitude above 1e6.
    pub has_large_values: bool,
    /// True if any non-zero scanned value had magnitude below 1e-6.
    pub has_tiny_values: bool,
    /// Distribution statistics of (simulated) gradient magnitudes.
    pub gradient_magnitude: GradientStatistics,
    /// Distribution statistics of (simulated) activations.
    pub activation_stats: ActivationStatistics,
}
85
/// Summary statistics over a set of gradient samples.
#[derive(Debug, Clone)]
pub struct GradientStatistics {
    /// Arithmetic mean of the gradient samples.
    pub mean: f64,
    /// Population standard deviation of the gradient samples.
    pub std: f64,
    /// Smallest gradient sample.
    pub min: f64,
    /// Largest gradient sample.
    pub max: f64,
    /// Percentage of samples with |g| < 1e-7 (considered vanishing).
    pub vanishing_percentage: f64,
    /// Percentage of samples with |g| > 10.0 (considered exploding).
    pub exploding_percentage: f64,
}
98
/// Summary statistics over a set of activation samples.
#[derive(Debug, Clone)]
pub struct ActivationStatistics {
    /// Arithmetic mean of the activation samples.
    pub mean: f64,
    /// Population standard deviation of the activation samples.
    pub std: f64,
    /// Smallest activation sample.
    pub min: f64,
    /// Largest activation sample.
    pub max: f64,
    /// Percentage of activations that are exactly zero ("dead" neurons).
    pub dead_neurons_percentage: f64,
}
109
110pub async fn validate_model(
112 model: &TorshModel,
113 num_samples: usize,
114 check_gradients: bool,
115) -> Result<ValidationResult> {
116 info!(
117 "Validating model with {} samples (gradient check: {})",
118 num_samples, check_gradients
119 );
120
121 let mut errors = Vec::new();
122 let mut warnings = Vec::new();
123
124 if let Err(e) = validate_model_structure(model) {
126 errors.push(format!("Model structure validation failed: {}", e));
127 }
128
129 let (successful, failed, avg_time, peak_memory) =
131 run_inference_tests(model, num_samples).await?;
132
133 let gradient_check_result = if check_gradients {
135 match perform_gradient_check(model).await {
136 Ok(result) => Some(result.passed),
137 Err(e) => {
138 warnings.push(format!("Gradient check failed: {}", e));
139 None
140 }
141 }
142 } else {
143 None
144 };
145
146 let stability = analyze_numerical_stability(model).await?;
148 let numerical_stability = calculate_stability_score(&stability);
149
150 if stability.has_nan {
151 errors.push("Model contains NaN values".to_string());
152 }
153 if stability.has_inf {
154 errors.push("Model contains Inf values".to_string());
155 }
156
157 if stability.gradient_magnitude.vanishing_percentage > 50.0 {
158 warnings.push(format!(
159 "High vanishing gradient rate: {:.1}%",
160 stability.gradient_magnitude.vanishing_percentage
161 ));
162 }
163
164 if stability.gradient_magnitude.exploding_percentage > 10.0 {
165 warnings.push(format!(
166 "High exploding gradient rate: {:.1}%",
167 stability.gradient_magnitude.exploding_percentage
168 ));
169 }
170
171 let passed = errors.is_empty() && successful > 0;
172
173 Ok(ValidationResult {
174 passed,
175 accuracy: None, top5_accuracy: None,
177 num_samples,
178 successful_inferences: successful,
179 failed_inferences: failed,
180 avg_inference_time_ms: avg_time,
181 peak_memory_mb: peak_memory,
182 gradient_check_passed: gradient_check_result,
183 numerical_stability,
184 errors,
185 warnings,
186 })
187}
188
189fn validate_model_structure(model: &TorshModel) -> Result<()> {
191 debug!("Validating model structure");
192
193 if model.layers.is_empty() {
195 anyhow::bail!("Model has no layers");
196 }
197
198 for layer in &model.layers {
200 if layer.input_shape.is_empty() {
201 anyhow::bail!("Layer {} has empty input shape", layer.name);
202 }
203 if layer.output_shape.is_empty() {
204 anyhow::bail!("Layer {} has empty output shape", layer.name);
205 }
206
207 if layer.trainable {
209 let weight_name = format!("{}.weight", layer.name);
210 if !model.weights.contains_key(&weight_name) {
211 anyhow::bail!("Trainable layer {} missing weight tensor", layer.name);
212 }
213 }
214 }
215
216 for i in 0..model.layers.len() - 1 {
218 let current = &model.layers[i];
219 let next = &model.layers[i + 1];
220
221 if current.output_shape != next.input_shape {
222 warn!(
223 "Shape mismatch between layers {} and {}: {:?} != {:?}",
224 current.name, next.name, current.output_shape, next.input_shape
225 );
226 }
227 }
228
229 Ok(())
230}
231
232async fn run_inference_tests(
234 model: &TorshModel,
235 num_samples: usize,
236) -> Result<(usize, usize, f64, f64)> {
237 info!("Running {} inference tests", num_samples);
238
239 let input_shape = model
240 .layers
241 .first()
242 .map(|l| l.input_shape.clone())
243 .unwrap_or_else(|| vec![784]);
244
245 let mut successful = 0;
246 let mut failed = 0;
247 let mut total_time = 0.0;
248 let mut peak_memory = 0.0f64;
249
250 for i in 0..num_samples {
251 let input = create_random_input(&input_shape)?;
252
253 let start = std::time::Instant::now();
254
255 match perform_forward_pass(model, &input).await {
256 Ok(output) => {
257 successful += 1;
258 total_time += start.elapsed().as_secs_f64() * 1000.0;
259
260 let memory = estimate_inference_memory(model, &output);
262 peak_memory = peak_memory.max(memory);
263
264 debug!(
265 "Inference {}: successful, output shape: {:?}",
266 i,
267 output.shape().dims()
268 );
269 }
270 Err(e) => {
271 failed += 1;
272 warn!("Inference {} failed: {}", i, e);
273 }
274 }
275
276 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
278 }
279
280 let avg_time = if successful > 0 {
281 total_time / successful as f64
282 } else {
283 0.0
284 };
285
286 Ok((successful, failed, avg_time, peak_memory))
287}
288
289fn create_random_input(shape: &[usize]) -> Result<Tensor<f32>> {
291 let mut rng = thread_rng();
292 let uniform = Uniform::new(-1.0f64, 1.0f64)?;
293
294 let num_elements: usize = shape.iter().product();
295 let data: Vec<f32> = (0..num_elements)
296 .map(|_| uniform.sample(&mut rng) as f32)
297 .collect();
298
299 Ok(Tensor::from_data(data, shape.to_vec(), DeviceType::Cpu)?)
300}
301
302async fn perform_forward_pass(model: &TorshModel, _input: &Tensor<f32>) -> Result<Tensor<f32>> {
304 debug!("Performing forward pass");
305
306 let output_shape = model
310 .layers
311 .last()
312 .map(|l| l.output_shape.clone())
313 .unwrap_or_else(|| vec![10]);
314
315 let total_flops: u64 = model.layers.iter().map(|l| estimate_layer_flops(l)).sum();
317
318 let compute_time_us = (total_flops as f64 / 1_000_000.0) as u64;
319 tokio::time::sleep(std::time::Duration::from_micros(compute_time_us.min(10000))).await;
320
321 let output = Tensor::zeros(output_shape.as_slice(), DeviceType::Cpu)?;
323
324 Ok(output)
325}
326
327fn estimate_layer_flops(layer: &LayerInfo) -> u64 {
329 let input_size: u64 = layer.input_shape.iter().map(|&x| x as u64).product();
330 let output_size: u64 = layer.output_shape.iter().map(|&x| x as u64).product();
331
332 match layer.layer_type.as_str() {
333 "Linear" | "Dense" => 2 * input_size * output_size,
334 "Conv2d" => {
335 let kernel_size = 9; 2 * kernel_size * output_size
337 }
338 "ReLU" | "Sigmoid" | "Tanh" => output_size,
339 _ => output_size,
340 }
341}
342
343fn estimate_inference_memory(model: &TorshModel, _output: &Tensor<f32>) -> f64 {
345 let param_memory: u64 = model
346 .weights
347 .values()
348 .map(|t| {
349 let elements: usize = t.shape.iter().product();
350 (elements * t.dtype.size_bytes()) as u64
351 })
352 .sum();
353
354 let activation_memory: u64 = model
355 .layers
356 .iter()
357 .map(|l| {
358 let output_elements: u64 = l.output_shape.iter().map(|&x| x as u64).product();
359 output_elements * 4 })
361 .sum();
362
363 (param_memory + activation_memory) as f64 / (1024.0 * 1024.0)
364}
365
366async fn perform_gradient_check(model: &TorshModel) -> Result<GradientCheckResult> {
368 info!("Performing gradient check");
369
370 let epsilon = 1e-5;
371 let tolerance = 1e-3;
372
373 let input_shape = model
374 .layers
375 .first()
376 .map(|l| l.input_shape.clone())
377 .unwrap_or_else(|| vec![784]);
378
379 let input = create_random_input(&input_shape)?;
380
381 let num_checks = 10.min(model.weights.len());
383 let mut max_error = 0.0f64;
384 let mut total_error = 0.0f64;
385 let mut failed_locations = Vec::new();
386
387 for (i, (name, _weight_info)) in model.weights.iter().take(num_checks).enumerate() {
388 debug!("Checking gradient for: {}", name);
389
390 let numerical_grad = compute_numerical_gradient(model, &input, name, epsilon).await?;
392
393 let analytical_grad = compute_analytical_gradient(model, &input, name).await?;
395
396 let relative_error = compute_relative_error(&numerical_grad, &analytical_grad);
398
399 total_error += relative_error;
400 max_error = max_error.max(relative_error);
401
402 if relative_error > tolerance {
403 failed_locations.push(format!("{} (error: {:.6})", name, relative_error));
404 warn!(
405 "Gradient check failed for {}: relative error {:.6}",
406 name, relative_error
407 );
408 }
409
410 debug!("Gradient check {}: relative error {:.6}", i, relative_error);
411 }
412
413 let avg_error = total_error / num_checks as f64;
414 let passed = failed_locations.is_empty();
415
416 Ok(GradientCheckResult {
417 passed,
418 max_relative_error: max_error,
419 avg_relative_error: avg_error,
420 num_gradients_checked: num_checks,
421 failed_locations,
422 })
423}
424
425async fn compute_numerical_gradient(
427 _model: &TorshModel,
428 _input: &Tensor<f32>,
429 _param_name: &str,
430 epsilon: f64,
431) -> Result<Array1<f64>> {
432 let mut rng = thread_rng();
440 let normal = Normal::new(0.0, epsilon)?;
441
442 let size = 100; let grad: Vec<f64> = (0..size).map(|_| normal.sample(&mut rng)).collect();
444
445 Ok(Array1::from_vec(grad))
446}
447
448async fn compute_analytical_gradient(
450 _model: &TorshModel,
451 _input: &Tensor<f32>,
452 _param_name: &str,
453) -> Result<Array1<f64>> {
454 let mut rng = thread_rng();
458 let normal = Normal::new(0.0, 1e-5)?;
459
460 let size = 100; let grad: Vec<f64> = (0..size).map(|_| normal.sample(&mut rng)).collect();
462
463 Ok(Array1::from_vec(grad))
464}
465
466fn compute_relative_error(numerical: &Array1<f64>, analytical: &Array1<f64>) -> f64 {
468 let diff_norm = (numerical - analytical)
469 .iter()
470 .map(|x| x * x)
471 .sum::<f64>()
472 .sqrt();
473
474 let sum_norm = (numerical.iter().map(|x| x * x).sum::<f64>().sqrt()
475 + analytical.iter().map(|x| x * x).sum::<f64>().sqrt())
476 / 2.0;
477
478 if sum_norm < 1e-7 {
479 diff_norm
480 } else {
481 diff_norm / sum_norm
482 }
483}
484
485async fn analyze_numerical_stability(model: &TorshModel) -> Result<StabilityAnalysis> {
487 info!("Analyzing numerical stability");
488
489 let mut has_nan = false;
490 let mut has_inf = false;
491 let mut has_large_values = false;
492 let mut has_tiny_values = false;
493
494 for (name, _weight_info) in &model.weights {
496 debug!("Checking stability for: {}", name);
499
500 let mut rng = thread_rng();
502 let normal = Normal::new(0.0, 0.1)?;
503
504 let sample_size = 100;
505 let samples: Vec<f64> = (0..sample_size).map(|_| normal.sample(&mut rng)).collect();
506
507 for &val in &samples {
508 if val.is_nan() {
509 has_nan = true;
510 }
511 if val.is_infinite() {
512 has_inf = true;
513 }
514 if val.abs() > 1e6 {
515 has_large_values = true;
516 }
517 if val.abs() < 1e-6 && val != 0.0 {
518 has_tiny_values = true;
519 }
520 }
521 }
522
523 let gradient_magnitude = compute_gradient_statistics(model)?;
525
526 let activation_stats = compute_activation_statistics(model)?;
528
529 Ok(StabilityAnalysis {
530 has_nan,
531 has_inf,
532 has_large_values,
533 has_tiny_values,
534 gradient_magnitude,
535 activation_stats,
536 })
537}
538
539fn compute_gradient_statistics(_model: &TorshModel) -> Result<GradientStatistics> {
541 let mut rng = thread_rng();
543 let normal = Normal::new(0.0, 0.1)?;
544
545 let num_samples = 1000;
546 let gradients: Vec<f64> = (0..num_samples).map(|_| normal.sample(&mut rng)).collect();
547
548 let mean = gradients.iter().sum::<f64>() / num_samples as f64;
549
550 let variance = gradients.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / num_samples as f64;
551 let std = variance.sqrt();
552
553 let min = gradients.iter().copied().fold(f64::INFINITY, f64::min);
554 let max = gradients.iter().copied().fold(f64::NEG_INFINITY, f64::max);
555
556 let vanishing_count = gradients.iter().filter(|&&x| x.abs() < 1e-7).count();
557 let exploding_count = gradients.iter().filter(|&&x| x.abs() > 10.0).count();
558
559 let vanishing_percentage = (vanishing_count as f64 / num_samples as f64) * 100.0;
560 let exploding_percentage = (exploding_count as f64 / num_samples as f64) * 100.0;
561
562 Ok(GradientStatistics {
563 mean,
564 std,
565 min,
566 max,
567 vanishing_percentage,
568 exploding_percentage,
569 })
570}
571
572fn compute_activation_statistics(_model: &TorshModel) -> Result<ActivationStatistics> {
574 let mut rng = thread_rng();
576 let normal = Normal::new(0.0, 1.0)?;
577
578 let num_activations = 1000;
579 let activations: Vec<f64> = (0..num_activations)
580 .map(|_| {
581 let val = normal.sample(&mut rng);
582 if val > 0.0f64 {
583 val
584 } else {
585 0.0f64
586 }
587 })
588 .collect(); let mean = activations.iter().sum::<f64>() / num_activations as f64;
591
592 let variance =
593 activations.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / num_activations as f64;
594 let std = variance.sqrt();
595
596 let min = activations.iter().copied().fold(f64::INFINITY, f64::min);
597 let max = activations
598 .iter()
599 .copied()
600 .fold(f64::NEG_INFINITY, f64::max);
601
602 let dead_count = activations.iter().filter(|&&x| x == 0.0).count();
603 let dead_neurons_percentage = (dead_count as f64 / num_activations as f64) * 100.0;
604
605 Ok(ActivationStatistics {
606 mean,
607 std,
608 min,
609 max,
610 dead_neurons_percentage,
611 })
612}
613
614fn calculate_stability_score(analysis: &StabilityAnalysis) -> f64 {
616 let mut score = 1.0f64;
617
618 if analysis.has_nan {
620 score -= 0.5;
621 }
622 if analysis.has_inf {
623 score -= 0.5;
624 }
625
626 if analysis.has_large_values {
628 score -= 0.1;
629 }
630 if analysis.has_tiny_values {
631 score -= 0.05;
632 }
633
634 if analysis.gradient_magnitude.vanishing_percentage > 50.0 {
636 score -= 0.2;
637 }
638 if analysis.gradient_magnitude.exploding_percentage > 10.0 {
639 score -= 0.2;
640 }
641
642 if analysis.activation_stats.dead_neurons_percentage > 50.0 {
644 score -= 0.1;
645 }
646
647 score.max(0.0)
648}
649
/// Renders a `ValidationResult` as a human-readable, box-drawn text report.
///
/// Sections: overall status, inference-testing stats, optional gradient-check
/// status, numerical-stability score, then errors and warnings (each section
/// omitted when empty / not applicable).
pub fn format_validation_result(result: &ValidationResult) -> String {
    let mut output = String::new();

    // Fixed-width report header.
    output.push_str("╔═══════════════════════════════════════════════════════════════════════╗\n");
    output.push_str("║                        MODEL VALIDATION REPORT                        ║\n");
    output
        .push_str("╚═══════════════════════════════════════════════════════════════════════╝\n\n");

    // Overall pass/fail verdict.
    let status = if result.passed {
        "✅ PASSED"
    } else {
        "❌ FAILED"
    };
    output.push_str(&format!("Status: {}\n\n", status));

    // Inference-testing section.
    output.push_str("📊 Inference Testing\n");
    output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
    output.push_str(&format!("  Samples tested:      {}\n", result.num_samples));
    output.push_str(&format!(
        "  Successful:          {}\n",
        result.successful_inferences
    ));
    output.push_str(&format!(
        "  Failed:              {}\n",
        result.failed_inferences
    ));
    output.push_str(&format!(
        "  Avg inference time:  {:.2} ms\n",
        result.avg_inference_time_ms
    ));
    output.push_str(&format!(
        "  Peak memory:         {:.2} MB\n",
        result.peak_memory_mb
    ));

    // Accuracy lines appear only when an accuracy evaluation ran.
    if let Some(acc) = result.accuracy {
        output.push_str(&format!("  Accuracy:            {:.2}%\n", acc * 100.0));
    }
    if let Some(top5) = result.top5_accuracy {
        output.push_str(&format!("  Top-5 Accuracy:      {:.2}%\n", top5 * 100.0));
    }

    output.push_str("\n");

    // Gradient-check section (only when the check was actually run).
    if let Some(grad_passed) = result.gradient_check_passed {
        output.push_str("🔍 Gradient Checking\n");
        output
            .push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
        output.push_str(&format!(
            "  Status:              {}\n",
            if grad_passed {
                "✅ PASSED"
            } else {
                "❌ FAILED"
            }
        ));
        output.push_str("\n");
    }

    // Numerical-stability section (always present).
    output.push_str("📈 Numerical Stability\n");
    output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
    output.push_str(&format!(
        "  Stability score:     {:.2}/1.00\n",
        result.numerical_stability
    ));
    output.push_str("\n");

    // Errors section (omitted when there are none).
    if !result.errors.is_empty() {
        output.push_str("❌ Errors\n");
        output
            .push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
        for error in &result.errors {
            output.push_str(&format!("  • {}\n", error));
        }
        output.push_str("\n");
    }

    // Warnings section (omitted when there are none).
    if !result.warnings.is_empty() {
        output.push_str("⚠️  Warnings\n");
        output
            .push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
        for warning in &result.warnings {
            output.push_str(&format!("  • {}\n", warning));
        }
        output.push_str("\n");
    }

    output
}
746
#[cfg(test)]
mod tests {
    use super::super::tensor_integration::create_real_model;
    use super::*;

    /// End-to-end validation on a small real model; all samples should infer.
    #[tokio::test]
    async fn test_model_validation() {
        let model = create_real_model("test", 3, DeviceType::Cpu)
            .expect("create real model should succeed");
        let result = validate_model(&model, 10, false)
            .await
            .expect("operation should succeed");

        // Fix: assert_eq! reports both values on failure, unlike assert!(a == b).
        assert_eq!(result.num_samples, 10);
        assert!(result.successful_inferences > 0);
    }

    /// Structure validation should accept a well-formed generated model.
    #[test]
    fn test_structure_validation() {
        let model = create_real_model("test", 2, DeviceType::Cpu)
            .expect("create real model should succeed");
        assert!(validate_model_structure(&model).is_ok());
    }

    /// Gradient check should examine at least one parameter and report a
    /// non-negative maximum error.
    #[tokio::test]
    async fn test_gradient_check() {
        let model = create_real_model("test", 2, DeviceType::Cpu)
            .expect("create real model should succeed");
        let result = perform_gradient_check(&model)
            .await
            .expect("operation should succeed");

        assert!(result.num_gradients_checked > 0);
        assert!(result.max_relative_error >= 0.0);
    }

    /// Stability analysis of a freshly created model must not flag NaN/Inf.
    #[tokio::test]
    async fn test_stability_analysis() {
        let model = create_real_model("test", 2, DeviceType::Cpu)
            .expect("create real model should succeed");
        let analysis = analyze_numerical_stability(&model)
            .await
            .expect("operation should succeed");

        assert!(!analysis.has_nan);
        assert!(!analysis.has_inf);
    }

    /// The formatted report must contain the header and the status verdict.
    #[test]
    fn test_validation_formatting() {
        let result = ValidationResult {
            passed: true,
            accuracy: Some(0.95),
            top5_accuracy: Some(0.99),
            num_samples: 100,
            successful_inferences: 98,
            failed_inferences: 2,
            avg_inference_time_ms: 5.5,
            peak_memory_mb: 125.3,
            gradient_check_passed: Some(true),
            numerical_stability: 0.92,
            errors: vec![],
            warnings: vec!["High memory usage".to_string()],
        };

        let formatted = format_validation_result(&result);
        assert!(formatted.contains("VALIDATION REPORT"));
        assert!(formatted.contains("PASSED"));
    }
}