1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Configuration for [`ArchitectureAnalyzer`]: which analysis stages `analyze` runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureAnalysisConfig {
    /// Run the counting stage, which fills the report's parameter and FLOP totals,
    /// per-layer-type parameter distribution and estimated model size.
    pub enable_parameter_counting: bool,
    /// Attach a receptive field to every `Conv2D`/`Conv3D` layer.
    pub enable_receptive_field_calculation: bool,
    /// Fill the report's `model_depth` and `model_width`.
    pub enable_depth_width_analysis: bool,
    /// Count connections per type and report unusually frequent ones as bottlenecks.
    pub enable_connectivity_patterns: bool,
    /// Detect repeated layer-type blocks and groups of layers with equal parameter counts.
    pub enable_symmetry_detection: bool,
    /// NOTE(review): not read by the analyzer in this file; intended meaning unconfirmed.
    pub max_receptive_field_depth: usize,
    /// NOTE(review): not read by the analyzer in this file; presumably a fraction of
    /// layers to analyze — confirm before relying on it.
    pub sampling_rate: f32,
}
28
29impl Default for ArchitectureAnalysisConfig {
30 fn default() -> Self {
31 Self {
32 enable_parameter_counting: true,
33 enable_receptive_field_calculation: true,
34 enable_depth_width_analysis: true,
35 enable_connectivity_patterns: true,
36 enable_symmetry_detection: true,
37 max_receptive_field_depth: 50,
38 sampling_rate: 1.0,
39 }
40 }
41}
42
/// Coarse classification of a layer.
///
/// Used as the key of the report's parameter distribution, as the unit of block-pattern
/// detection, and to choose a formula in FLOP estimation and receptive-field calculation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum LayerType {
    Linear,
    Conv2D,
    Conv3D,
    BatchNorm,
    LayerNorm,
    Attention,
    Embedding,
    Dropout,
    Activation,
    Pooling,
    Residual,
    /// Any layer that does not fit another variant.
    Unknown,
}
59
/// Description of one layer as registered with [`ArchitectureAnalyzer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerInfo {
    /// Identifier used in symmetry reports (and, presumably, by `ConnectivityPattern`).
    pub id: String,
    /// Display name used in bottleneck messages.
    pub name: String,
    /// Kind of layer.
    pub layer_type: LayerType,
    /// Input tensor shape; how dimensions are read depends on `layer_type`
    /// (see `estimate_flops`).
    pub input_shape: Vec<usize>,
    /// Output tensor shape.
    pub output_shape: Vec<usize>,
    /// Total parameter count.
    pub parameters: usize,
    /// Parameters updated by training (`create_layer_info` sets this to `parameters`).
    pub trainable_parameters: usize,
    /// Memory footprint in bytes (`create_layer_info` uses 4 bytes per parameter).
    pub memory_usage: usize,
    /// Estimated floating-point operations for one forward pass.
    pub flops: u64,
    /// Filled by the analyzer for `Conv2D`/`Conv3D` layers; `None` otherwise.
    pub receptive_field: Option<ReceptiveField>,
}
74
/// Receptive field of a layer, one entry per spatial dimension.
///
/// NOTE(review): the analyzer in this file fills these with fixed placeholder values
/// (kernel 3, stride 1, padding 1 for `Conv2D`/`Conv3D`) rather than the layer's real
/// hyperparameters, and sets `effective_size` equal to `size`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReceptiveField {
    /// Kernel extent per dimension.
    pub size: Vec<usize>,
    /// Stride per dimension.
    pub stride: Vec<usize>,
    /// Padding per dimension.
    pub padding: Vec<usize>,
    /// Effective receptive-field extent per dimension.
    pub effective_size: Vec<usize>,
}
83
/// A directed connection between two layers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectivityPattern {
    /// Source layer id (presumably a `LayerInfo::id`; not validated by the analyzer).
    pub from_layer: String,
    /// Destination layer id (presumably a `LayerInfo::id`; not validated by the analyzer).
    pub to_layer: String,
    /// Kind of connection; connections are counted per kind during analysis.
    pub connection_type: ConnectionType,
    /// NOTE(review): not read by the analyzer in this file; range/meaning unconfirmed.
    pub strength: f32,
}
92
/// Kind of a [`ConnectivityPattern`].
#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub enum ConnectionType {
    /// Plain feed-forward edge to the next layer.
    Sequential,
    Residual,
    Attention,
    Skip,
    Recurrent,
    Branching,
}
102
/// One detected structural symmetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SymmetryInfo {
    /// Kind of symmetry.
    pub symmetry_type: SymmetryType,
    /// Members of the symmetry: `block_<start index>` labels for block symmetries,
    /// layer ids for permutation symmetries (as produced by the analyzer in this file).
    pub symmetric_layers: Vec<String>,
    /// Number of occurrences/members divided by the total number of layers.
    pub confidence: f32,
    /// Human-readable explanation.
    pub description: String,
}
111
/// Kind of [`SymmetryInfo`].
///
/// Only `Block` and `Permutation` are produced by the analyzer in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SymmetryType {
    Translational,
    Rotational,
    Reflection,
    /// Several layers share the same parameter count.
    Permutation,
    /// A contiguous sequence of layer types repeats.
    Block,
}
120
/// Result of [`ArchitectureAnalyzer::analyze`].
///
/// Fields filled by a stage that is disabled in [`ArchitectureAnalysisConfig`] keep
/// their zero/empty initial values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureAnalysisReport {
    /// Sum of `LayerInfo::parameters` over all layers.
    pub total_parameters: usize,
    /// Sum of `LayerInfo::trainable_parameters` over all layers.
    pub trainable_parameters: usize,
    /// Estimated model size in MiB, assuming 4 bytes per parameter.
    pub model_size_mb: f32,
    /// Sum of per-layer FLOP estimates.
    pub total_flops: u64,
    /// Number of registered layers.
    pub model_depth: usize,
    /// Largest per-layer parameter count (not a feature/channel width).
    pub model_width: usize,
    /// Snapshot of the registered layers, including any receptive fields computed.
    pub layers: Vec<LayerInfo>,
    /// Snapshot of the registered connections.
    pub connectivity_patterns: Vec<ConnectivityPattern>,
    /// Detected block and parameter-count symmetries.
    pub symmetries: Vec<SymmetryInfo>,
    /// Total parameters per layer type.
    pub parameter_distribution: HashMap<LayerType, usize>,
    /// Human-readable connectivity, parameter, memory and computation bottleneck messages.
    pub bottlenecks: Vec<String>,
    /// Heuristic efficiency scores.
    pub efficiency_metrics: EfficiencyMetrics,
}
137
/// Heuristic efficiency scores, where higher is better (as computed by the analyzer).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficiencyMetrics {
    /// 1.0 up to 10M parameters (or when none are counted), then `1 / log10(params / 1e6)`.
    pub parameter_efficiency: f32,
    /// 1.0 up to 10 GFLOPs (or when none are counted), then `1 / log10(flops / 1e9)`.
    pub flops_efficiency: f32,
    /// 1.0 up to 1000 MiB estimated size (or zero size), then `1 / log10(size_mb / 100)`.
    pub memory_efficiency: f32,
    /// 1.0 at depth 20, falling linearly to 0.0 at depth 0 or at depth 40 and beyond.
    pub depth_efficiency: f32,
    /// Weighted sum: 0.3 parameter + 0.3 FLOPs + 0.2 memory + 0.2 depth.
    pub overall_score: f32,
}
147
/// Collects layer and connection descriptions of a model and analyzes its architecture.
#[derive(Debug)]
pub struct ArchitectureAnalyzer {
    /// Selects which stages `analyze` runs.
    config: ArchitectureAnalysisConfig,
    /// Registered layers, in registration order.
    layers: Vec<LayerInfo>,
    /// Registered layer-to-layer connections.
    connections: Vec<ConnectivityPattern>,
    /// NOTE(review): only created empty and cleared in this file; no report is stored in
    /// or read from it.
    analysis_cache: HashMap<String, ArchitectureAnalysisReport>,
}
156
157impl ArchitectureAnalyzer {
158 pub fn new(config: ArchitectureAnalysisConfig) -> Self {
160 Self {
161 config,
162 layers: Vec::new(),
163 connections: Vec::new(),
164 analysis_cache: HashMap::new(),
165 }
166 }
167
168 pub fn register_layer(&mut self, layer: LayerInfo) {
170 self.layers.push(layer);
171 }
172
173 pub fn add_connection(&mut self, pattern: ConnectivityPattern) {
175 self.connections.push(pattern);
176 }
177
178 pub async fn analyze(&mut self) -> Result<ArchitectureAnalysisReport> {
180 let mut report = ArchitectureAnalysisReport {
181 total_parameters: 0,
182 trainable_parameters: 0,
183 model_size_mb: 0.0,
184 total_flops: 0,
185 model_depth: 0,
186 model_width: 0,
187 layers: self.layers.clone(),
188 connectivity_patterns: self.connections.clone(),
189 symmetries: Vec::new(),
190 parameter_distribution: HashMap::new(),
191 bottlenecks: Vec::new(),
192 efficiency_metrics: EfficiencyMetrics {
193 parameter_efficiency: 0.0,
194 flops_efficiency: 0.0,
195 memory_efficiency: 0.0,
196 depth_efficiency: 0.0,
197 overall_score: 0.0,
198 },
199 };
200
201 if self.config.enable_parameter_counting {
202 self.count_parameters(&mut report);
203 }
204
205 if self.config.enable_receptive_field_calculation {
206 self.calculate_receptive_fields(&mut report).await?;
207 }
208
209 if self.config.enable_depth_width_analysis {
210 self.analyze_depth_width(&mut report);
211 }
212
213 if self.config.enable_connectivity_patterns {
214 self.analyze_connectivity_patterns(&mut report);
215 }
216
217 if self.config.enable_symmetry_detection {
218 self.detect_symmetries(&mut report);
219 }
220
221 self.calculate_efficiency_metrics(&mut report);
222 self.identify_bottlenecks(&mut report);
223
224 Ok(report)
225 }
226
227 fn count_parameters(&self, report: &mut ArchitectureAnalysisReport) {
229 let mut param_distribution: HashMap<LayerType, usize> = HashMap::new();
230
231 for layer in &self.layers {
232 report.total_parameters += layer.parameters;
233 report.trainable_parameters += layer.trainable_parameters;
234
235 *param_distribution.entry(layer.layer_type.clone()).or_insert(0) += layer.parameters;
236 }
237
238 report.parameter_distribution = param_distribution;
239
240 report.model_size_mb = (report.total_parameters * 4) as f32 / (1024.0 * 1024.0);
242
243 report.total_flops = self.layers.iter().map(|l| l.flops).sum();
245 }
246
247 async fn calculate_receptive_fields(
249 &mut self,
250 report: &mut ArchitectureAnalysisReport,
251 ) -> Result<()> {
252 for layer in &mut self.layers {
253 if matches!(layer.layer_type, LayerType::Conv2D | LayerType::Conv3D) {
254 layer.receptive_field =
255 Some(Self::compute_receptive_field_static(&layer.layer_type));
256 }
257 }
258
259 report.layers = self.layers.clone();
260 Ok(())
261 }
262
263 fn compute_receptive_field_static(layer_type: &LayerType) -> ReceptiveField {
265 match layer_type {
266 LayerType::Conv2D => {
267 let kernel_size = vec![3, 3]; let stride = vec![1, 1];
270 let padding = vec![1, 1];
271
272 ReceptiveField {
273 size: kernel_size.clone(),
274 stride,
275 padding,
276 effective_size: kernel_size,
277 }
278 },
279 LayerType::Conv3D => {
280 let kernel_size = vec![3, 3, 3]; let stride = vec![1, 1, 1];
283 let padding = vec![1, 1, 1];
284
285 ReceptiveField {
286 size: kernel_size.clone(),
287 stride,
288 padding,
289 effective_size: kernel_size,
290 }
291 },
292 _ => {
293 ReceptiveField {
295 size: vec![1],
296 stride: vec![1],
297 padding: vec![0],
298 effective_size: vec![1],
299 }
300 },
301 }
302 }
303
304 #[allow(dead_code)]
306 fn compute_receptive_field(&self, layer: &LayerInfo) -> ReceptiveField {
307 Self::compute_receptive_field_static(&layer.layer_type)
308 }
309
310 fn analyze_depth_width(&self, report: &mut ArchitectureAnalysisReport) {
312 report.model_depth = self.layers.len();
314
315 report.model_width = self.layers.iter().map(|l| l.parameters).max().unwrap_or(0);
317 }
318
319 fn analyze_connectivity_patterns(&self, report: &mut ArchitectureAnalysisReport) {
321 let mut pattern_types: HashMap<ConnectionType, usize> = HashMap::new();
322
323 for connection in &self.connections {
324 *pattern_types.entry(connection.connection_type.clone()).or_insert(0) += 1;
325 }
326
327 for (connection_type, count) in pattern_types {
329 if count > self.layers.len() / 2 {
330 report.bottlenecks.push(format!(
332 "High {:?} connectivity: {} connections",
333 connection_type, count
334 ));
335 }
336 }
337 }
338
339 fn detect_symmetries(&self, report: &mut ArchitectureAnalysisReport) {
341 let mut block_patterns: HashMap<Vec<LayerType>, Vec<usize>> = HashMap::new();
343
344 for window_size in 2..=5.min(self.layers.len()) {
346 for i in 0..=(self.layers.len() - window_size) {
347 let pattern: Vec<LayerType> =
348 self.layers[i..i + window_size].iter().map(|l| l.layer_type.clone()).collect();
349
350 block_patterns.entry(pattern).or_default().push(i);
351 }
352 }
353
354 for (pattern, positions) in block_patterns {
356 if positions.len() > 1 {
357 let confidence = positions.len() as f32 / self.layers.len() as f32;
358
359 if confidence > 0.1 {
360 report.symmetries.push(SymmetryInfo {
362 symmetry_type: SymmetryType::Block,
363 symmetric_layers: positions
364 .iter()
365 .map(|&i| format!("block_{}", i))
366 .collect(),
367 confidence,
368 description: format!(
369 "Repeated block pattern: {:?} appears {} times",
370 pattern,
371 positions.len()
372 ),
373 });
374 }
375 }
376 }
377
378 let mut param_groups: HashMap<usize, Vec<String>> = HashMap::new();
380 for layer in &self.layers {
381 param_groups.entry(layer.parameters).or_default().push(layer.id.clone());
382 }
383
384 for (param_count, layer_ids) in param_groups {
385 if layer_ids.len() > 2 && param_count > 0 {
386 let confidence = layer_ids.len() as f32 / self.layers.len() as f32;
387
388 report.symmetries.push(SymmetryInfo {
389 symmetry_type: SymmetryType::Permutation,
390 symmetric_layers: layer_ids.clone(),
391 confidence,
392 description: format!(
393 "Parameter symmetry: {} layers with {} parameters each",
394 layer_ids.len(),
395 param_count
396 ),
397 });
398 }
399 }
400 }
401
402 fn calculate_efficiency_metrics(&self, report: &mut ArchitectureAnalysisReport) {
404 let total_params = report.total_parameters as f32;
405 let total_flops = report.total_flops as f32;
406 let depth = report.model_depth as f32;
407 let memory = report.model_size_mb;
408
409 report.efficiency_metrics.parameter_efficiency = if total_params > 0.0 {
411 1.0 / (total_params / 1_000_000.0).log10().max(1.0) } else {
413 1.0
414 };
415
416 report.efficiency_metrics.flops_efficiency = if total_flops > 0.0 {
418 1.0 / (total_flops / 1_000_000_000.0).log10().max(1.0) } else {
420 1.0
421 };
422
423 report.efficiency_metrics.memory_efficiency = if memory > 0.0 {
425 1.0 / (memory / 100.0).log10().max(1.0) } else {
427 1.0
428 };
429
430 report.efficiency_metrics.depth_efficiency = if depth > 0.0 {
432 let optimal_depth = 20.0; 1.0 - ((depth - optimal_depth).abs() / optimal_depth).min(1.0)
434 } else {
435 0.0
436 };
437
438 report.efficiency_metrics.overall_score = 0.3
440 * report.efficiency_metrics.parameter_efficiency
441 + 0.3 * report.efficiency_metrics.flops_efficiency
442 + 0.2 * report.efficiency_metrics.memory_efficiency
443 + 0.2 * report.efficiency_metrics.depth_efficiency;
444 }
445
446 fn identify_bottlenecks(&self, report: &mut ArchitectureAnalysisReport) {
448 if let Some(_max_params) = self.layers.iter().map(|l| l.parameters).max() {
450 let avg_params = report.total_parameters / self.layers.len().max(1);
451
452 for layer in &self.layers {
453 if layer.parameters > avg_params * 5 {
454 report.bottlenecks.push(format!(
455 "Parameter bottleneck: Layer '{}' has {} parameters ({}x average)",
456 layer.name,
457 layer.parameters,
458 layer.parameters / avg_params.max(1)
459 ));
460 }
461 }
462 }
463
464 for layer in &self.layers {
466 if layer.memory_usage > 100 * 1024 * 1024 {
467 report.bottlenecks.push(format!(
469 "Memory bottleneck: Layer '{}' uses {:.1}MB memory",
470 layer.name,
471 layer.memory_usage as f32 / (1024.0 * 1024.0)
472 ));
473 }
474 }
475
476 if let Some(_max_flops) = self.layers.iter().map(|l| l.flops).max() {
478 let avg_flops = report.total_flops / self.layers.len().max(1) as u64;
479
480 for layer in &self.layers {
481 if layer.flops > avg_flops * 10 {
482 report.bottlenecks.push(format!(
483 "Computation bottleneck: Layer '{}' requires {} FLOPS ({}x average)",
484 layer.name,
485 layer.flops,
486 layer.flops / avg_flops.max(1)
487 ));
488 }
489 }
490 }
491 }
492
493 pub async fn quick_analysis(&self) -> Result<crate::QuickArchitectureSummary> {
495 let total_parameters = self.layers.iter().map(|l| l.parameters as u64).sum::<u64>();
496 let total_flops = self.layers.iter().map(|l| l.flops).sum::<u64>();
497
498 let model_size_mb = (total_parameters as f64 * 4.0) / (1024.0 * 1024.0);
500
501 let efficiency_score = if total_flops > 0 {
503 (total_parameters as f64 / total_flops as f64 * 1000.0).min(100.0)
504 } else {
505 50.0
506 };
507
508 let mut recommendations = Vec::new();
509 if total_parameters > 1_000_000_000 {
510 recommendations
511 .push("Consider model compression techniques for large model".to_string());
512 }
513 if efficiency_score < 30.0 {
514 recommendations.push("Model architecture could be more efficient".to_string());
515 }
516 if model_size_mb > 1000.0 {
517 recommendations.push("Large model size may impact deployment".to_string());
518 }
519 if recommendations.is_empty() {
520 recommendations.push("Architecture appears well-balanced".to_string());
521 }
522
523 Ok(crate::QuickArchitectureSummary {
524 total_parameters,
525 model_size_mb,
526 efficiency_score,
527 recommendations,
528 })
529 }
530
531 pub async fn generate_report(&self) -> Result<ArchitectureAnalysisReport> {
533 let mut temp_analyzer = ArchitectureAnalyzer {
535 config: self.config.clone(),
536 layers: self.layers.clone(),
537 connections: self.connections.clone(),
538 analysis_cache: HashMap::new(),
539 };
540
541 temp_analyzer.analyze().await
542 }
543
544 pub fn clear(&mut self) {
546 self.layers.clear();
547 self.connections.clear();
548 self.analysis_cache.clear();
549 }
550
551 pub fn get_summary(&self) -> ArchitectureSummary {
553 let total_params: usize = self.layers.iter().map(|l| l.parameters).sum();
554 let total_flops: u64 = self.layers.iter().map(|l| l.flops).sum();
555
556 ArchitectureSummary {
557 total_layers: self.layers.len(),
558 total_parameters: total_params,
559 total_flops,
560 average_layer_size: if !self.layers.is_empty() {
561 total_params / self.layers.len()
562 } else {
563 0
564 },
565 layer_type_distribution: {
566 let mut dist = HashMap::new();
567 for layer in &self.layers {
568 *dist.entry(layer.layer_type.clone()).or_insert(0) += 1;
569 }
570 dist
571 },
572 }
573 }
574}
575
/// Cheap summary of the registered layers, produced by `ArchitectureAnalyzer::get_summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSummary {
    /// Number of registered layers.
    pub total_layers: usize,
    /// Sum of per-layer parameter counts.
    pub total_parameters: usize,
    /// Sum of per-layer FLOP estimates.
    pub total_flops: u64,
    /// Integer mean of parameters per layer (0 when there are no layers).
    pub average_layer_size: usize,
    /// Number of layers of each type.
    pub layer_type_distribution: HashMap<LayerType, usize>,
}
585
586pub fn create_layer_info(
588 id: String,
589 name: String,
590 layer_type: LayerType,
591 input_shape: Vec<usize>,
592 output_shape: Vec<usize>,
593 parameters: usize,
594) -> LayerInfo {
595 let memory_usage = parameters * 4; let flops = estimate_flops(&layer_type, &input_shape, &output_shape, parameters);
597
598 LayerInfo {
599 id,
600 name,
601 layer_type,
602 input_shape,
603 output_shape,
604 parameters,
605 trainable_parameters: parameters, memory_usage,
607 flops,
608 receptive_field: None,
609 }
610}
611
612fn estimate_flops(
614 layer_type: &LayerType,
615 input_shape: &[usize],
616 output_shape: &[usize],
617 parameters: usize,
618) -> u64 {
619 match layer_type {
620 LayerType::Linear => {
621 if input_shape.len() >= 2 && output_shape.len() >= 2 {
622 let batch_size = input_shape[0] as u64;
623 let input_features = input_shape[1] as u64;
624 let output_features = output_shape[1] as u64;
625 batch_size * input_features * output_features * 2
626 } else {
627 parameters as u64 * 2
628 }
629 },
630 LayerType::Conv2D => {
631 if output_shape.len() >= 4 {
632 let batch_size = output_shape[0] as u64;
633 let output_channels = output_shape[1] as u64;
634 let output_h = output_shape[2] as u64;
635 let output_w = output_shape[3] as u64;
636 batch_size
637 * output_channels
638 * output_h
639 * output_w
640 * (parameters as u64 / output_channels).max(1)
641 * 2
642 } else {
643 parameters as u64 * 2
644 }
645 },
646 LayerType::Attention => {
647 if input_shape.len() >= 3 {
648 let batch_size = input_shape[0] as u64;
649 let seq_len = input_shape[1] as u64;
650 let hidden_size = input_shape[2] as u64;
651 batch_size * seq_len * seq_len * hidden_size * 4
652 } else {
653 parameters as u64 * 4
654 }
655 },
656 _ => parameters as u64,
657 }
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Linear` layer with `[1, params]` input/output shapes and `params` parameters.
    fn make_linear_layer(id: &str, params: usize) -> LayerInfo {
        create_layer_info(
            id.to_string(),
            format!("{}_layer", id),
            LayerType::Linear,
            vec![1, params],
            vec![1, params],
            params,
        )
    }

    #[test]
    fn test_config_default() {
        let cfg = ArchitectureAnalysisConfig::default();
        assert!(cfg.enable_parameter_counting);
        assert!(cfg.enable_receptive_field_calculation);
        assert!(cfg.enable_depth_width_analysis);
        assert!(cfg.enable_connectivity_patterns);
        assert!(cfg.enable_symmetry_detection);
        assert!(cfg.max_receptive_field_depth > 0);
        assert!((cfg.sampling_rate - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_analyzer_new_empty() {
        let analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        let summary = analyzer.get_summary();
        assert_eq!(summary.total_layers, 0);
        assert_eq!(summary.total_parameters, 0);
    }

    #[test]
    fn test_register_layer_accumulates() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 512));
        analyzer.register_layer(make_linear_layer("l1", 256));
        assert_eq!(analyzer.get_summary().total_layers, 2);
    }

    #[test]
    fn test_add_connection() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("a", 128));
        analyzer.register_layer(make_linear_layer("b", 128));
        analyzer.add_connection(ConnectivityPattern {
            from_layer: "a".to_string(),
            to_layer: "b".to_string(),
            connection_type: ConnectionType::Sequential,
            strength: 1.0,
        });
        let summary = analyzer.get_summary();
        assert_eq!(summary.total_layers, 2);
    }

    #[test]
    fn test_clear_resets_state() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 64));
        analyzer.clear();
        assert_eq!(analyzer.get_summary().total_layers, 0);
    }

    #[test]
    fn test_create_layer_info_parameters() {
        let layer = make_linear_layer("dense", 1024);
        assert_eq!(layer.parameters, 1024);
        assert_eq!(layer.trainable_parameters, 1024);
        assert_eq!(layer.memory_usage, 1024 * 4);
        // batch 1 * 1024 in * 1024 out * 2
        assert_eq!(layer.flops, 2_097_152);
        assert!(layer.receptive_field.is_none());
    }

    #[test]
    fn test_create_layer_info_conv2d_flops() {
        let layer = create_layer_info(
            "conv".to_string(),
            "conv_layer".to_string(),
            LayerType::Conv2D,
            vec![1, 3, 224, 224],
            vec![1, 64, 112, 112],
            64 * 3 * 3 * 3,
        );
        // 1 * 64 * 112 * 112 output elements * 27 params per channel * 2
        assert_eq!(layer.flops, 43_352_064);
    }

    #[test]
    fn test_create_layer_info_attention_flops() {
        let layer = create_layer_info(
            "attn".to_string(),
            "attention".to_string(),
            LayerType::Attention,
            vec![1, 128, 768],
            vec![1, 128, 768],
            768 * 768 * 4,
        );
        // 1 * 128 * 128 * 768 * 4
        assert_eq!(layer.flops, 50_331_648);
    }

    #[test]
    fn test_summary_totals() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 100));
        analyzer.register_layer(make_linear_layer("l1", 200));
        let s = analyzer.get_summary();
        assert_eq!(s.total_parameters, 300);
        assert_eq!(s.average_layer_size, 150);
    }

    #[test]
    fn test_summary_layer_type_distribution() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 64));
        analyzer.register_layer(make_linear_layer("l1", 64));
        let s = analyzer.get_summary();
        assert_eq!(
            s.layer_type_distribution.get(&LayerType::Linear).copied().unwrap_or(0),
            2
        );
    }

    #[test]
    fn test_layer_type_all_variants() {
        let types = [
            LayerType::Linear,
            LayerType::Conv2D,
            LayerType::Conv3D,
            LayerType::BatchNorm,
            LayerType::LayerNorm,
            LayerType::Attention,
            LayerType::Embedding,
            LayerType::Dropout,
            LayerType::Activation,
            LayerType::Pooling,
            LayerType::Residual,
            LayerType::Unknown,
        ];
        for t in &types {
            assert!(!format!("{:?}", t).is_empty());
        }
    }

    #[test]
    fn test_connection_type_variants() {
        let types = [
            ConnectionType::Sequential,
            ConnectionType::Residual,
            ConnectionType::Attention,
            ConnectionType::Skip,
            ConnectionType::Recurrent,
            ConnectionType::Branching,
        ];
        for t in &types {
            assert!(!format!("{:?}", t).is_empty());
        }
    }

    #[test]
    fn test_symmetry_type_variants() {
        let types = [
            SymmetryType::Translational,
            SymmetryType::Rotational,
            SymmetryType::Reflection,
            SymmetryType::Permutation,
            SymmetryType::Block,
        ];
        for t in &types {
            assert!(!format!("{:?}", t).is_empty());
        }
    }

    #[test]
    fn test_efficiency_metrics_construction() {
        let metrics = EfficiencyMetrics {
            parameter_efficiency: 0.8,
            flops_efficiency: 0.7,
            memory_efficiency: 0.9,
            depth_efficiency: 0.6,
            overall_score: 0.75,
        };
        assert!((metrics.overall_score - 0.75).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_analyze_empty() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        let report = analyzer.analyze().await.expect("analyze should succeed");
        assert_eq!(report.total_parameters, 0);
        assert_eq!(report.model_depth, 0);
    }

    #[tokio::test]
    async fn test_analyze_parameter_counting() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 512));
        analyzer.register_layer(make_linear_layer("l1", 512));
        let report = analyzer.analyze().await.expect("analyze should succeed");
        assert_eq!(report.total_parameters, 1024);
        assert_eq!(report.trainable_parameters, 1024);
    }

    #[tokio::test]
    async fn test_analyze_depth_width() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 128));
        analyzer.register_layer(make_linear_layer("l1", 256));
        analyzer.register_layer(make_linear_layer("l2", 64));
        let report = analyzer.analyze().await.expect("analyze should succeed");
        assert_eq!(report.model_depth, 3);
        assert_eq!(report.model_width, 256);
    }

    #[tokio::test]
    async fn test_analyze_reports_connections() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("a", 128));
        analyzer.register_layer(make_linear_layer("b", 128));
        analyzer.add_connection(ConnectivityPattern {
            from_layer: "a".to_string(),
            to_layer: "b".to_string(),
            connection_type: ConnectionType::Sequential,
            strength: 1.0,
        });
        let report = analyzer.analyze().await.expect("analyze should succeed");
        assert_eq!(report.connectivity_patterns.len(), 1);
        assert_eq!(
            report.connectivity_patterns[0].connection_type,
            ConnectionType::Sequential
        );
    }

    #[tokio::test]
    async fn test_analyze_attaches_receptive_fields_to_conv_layers() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(create_layer_info(
            "conv".to_string(),
            "conv_layer".to_string(),
            LayerType::Conv2D,
            vec![1, 3, 32, 32],
            vec![1, 16, 32, 32],
            16 * 3 * 3 * 3,
        ));
        analyzer.register_layer(make_linear_layer("fc", 64));
        let report = analyzer.analyze().await.expect("analyze should succeed");
        let conv_field = report.layers[0]
            .receptive_field
            .as_ref()
            .expect("conv layer should get a receptive field");
        assert_eq!(conv_field.size, vec![3, 3]);
        assert!(report.layers[1].receptive_field.is_none());
    }

    #[tokio::test]
    async fn test_quick_analysis_returns_summary() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 1000));
        let qs = analyzer.quick_analysis().await.expect("quick_analysis should succeed");
        assert_eq!(qs.total_parameters, 1000);
        assert!(!qs.recommendations.is_empty());
    }

    #[tokio::test]
    async fn test_generate_report_symmetry_detection() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        // Six identical layers: repeated blocks and one equal-parameter group.
        for i in 0..6 {
            analyzer.register_layer(make_linear_layer(&format!("l{}", i), 256));
        }
        let report = analyzer.generate_report().await.expect("report should succeed");
        assert_eq!(report.total_parameters, 6 * 256);
        assert!(report
            .symmetries
            .iter()
            .any(|s| matches!(s.symmetry_type, SymmetryType::Block)));
        assert!(report
            .symmetries
            .iter()
            .any(|s| matches!(s.symmetry_type, SymmetryType::Permutation)));
        // Uniform layers must not be reported as bottlenecks.
        assert!(report.bottlenecks.is_empty());
    }
}