// trustformers_debug/architecture_analysis.rs
1//! Architecture Analysis
2//!
3//! Comprehensive analysis tools for neural network architectures including
4//! parameter counting, receptive field calculation, and connectivity analysis.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Configuration for architecture analysis
///
/// Each `enable_*` flag toggles one analysis pass inside
/// `ArchitectureAnalyzer::analyze`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureAnalysisConfig {
    /// Enable parameter counting
    pub enable_parameter_counting: bool,
    /// Enable receptive field calculation
    pub enable_receptive_field_calculation: bool,
    /// Enable depth/width analysis
    pub enable_depth_width_analysis: bool,
    /// Enable connectivity pattern detection
    pub enable_connectivity_patterns: bool,
    /// Enable symmetry detection
    pub enable_symmetry_detection: bool,
    /// Maximum depth to analyze for receptive fields
    /// NOTE(review): not read by any method in this module — confirm callers.
    pub max_receptive_field_depth: usize,
    /// Sampling rate for large models (0.0 to 1.0)
    /// NOTE(review): not read by any method in this module — confirm callers.
    pub sampling_rate: f32,
}
28
/// Defaults: every analysis pass enabled, receptive-field depth capped at 50,
/// full (1.0) sampling.
impl Default for ArchitectureAnalysisConfig {
    fn default() -> Self {
        Self {
            enable_parameter_counting: true,
            enable_receptive_field_calculation: true,
            enable_depth_width_analysis: true,
            enable_connectivity_patterns: true,
            enable_symmetry_detection: true,
            max_receptive_field_depth: 50,
            sampling_rate: 1.0,
        }
    }
}
42
/// Layer type for architecture analysis
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum LayerType {
    /// Fully-connected (dense) layer.
    Linear,
    /// 2D convolution.
    Conv2D,
    /// 3D convolution.
    Conv3D,
    /// Batch normalization.
    BatchNorm,
    /// Layer normalization.
    LayerNorm,
    /// Attention / self-attention block.
    Attention,
    /// Embedding lookup.
    Embedding,
    /// Dropout regularization.
    Dropout,
    /// Standalone activation function.
    Activation,
    /// Pooling layer (kind not distinguished here).
    Pooling,
    /// Residual/skip-add block.
    Residual,
    /// Fallback for unrecognized layers.
    Unknown,
}
59
/// Information about a single layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerInfo {
    /// Unique identifier used to reference this layer in reports.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Category of the layer (see `LayerType`).
    pub layer_type: LayerType,
    /// Input tensor shape — presumably batch-first; not validated here.
    pub input_shape: Vec<usize>,
    /// Output tensor shape.
    pub output_shape: Vec<usize>,
    /// Total parameter count.
    pub parameters: usize,
    /// Parameters updated during training (subset of `parameters`).
    pub trainable_parameters: usize,
    /// Estimated memory footprint in bytes.
    pub memory_usage: usize,
    /// Estimated floating-point operations for a forward pass.
    pub flops: u64,
    /// Receptive field; populated for conv layers during analysis.
    pub receptive_field: Option<ReceptiveField>,
}
74
/// Receptive field information
///
/// Each vector holds one entry per spatial dimension (two for 2D convs,
/// three for 3D).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReceptiveField {
    /// Kernel extent per dimension.
    pub size: Vec<usize>,
    /// Stride per dimension.
    pub stride: Vec<usize>,
    /// Padding per dimension.
    pub padding: Vec<usize>,
    /// Effective extent; the default computation in this module sets it
    /// equal to `size`.
    pub effective_size: Vec<usize>,
}
83
/// Connectivity pattern between layers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectivityPattern {
    /// Id of the source layer.
    pub from_layer: String,
    /// Id of the destination layer.
    pub to_layer: String,
    /// Kind of connection (sequential, residual, ...).
    pub connection_type: ConnectionType,
    /// Connection strength — presumably in [0, 1]; not enforced here.
    pub strength: f32,
}
92
#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub enum ConnectionType {
    /// Plain feed-forward link to the next layer.
    Sequential,
    /// Residual (additive shortcut) connection.
    Residual,
    /// Attention-mediated connection.
    Attention,
    /// Skip connection across one or more layers.
    Skip,
    /// Recurrent (cyclic) connection.
    Recurrent,
    /// One-to-many branch point.
    Branching,
}
102
/// Symmetry information in the architecture
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SymmetryInfo {
    /// Category of detected symmetry.
    pub symmetry_type: SymmetryType,
    /// Identifiers of the layers (or block positions) involved.
    pub symmetric_layers: Vec<String>,
    /// Fraction of the model's layers participating in the pattern.
    pub confidence: f32,
    /// Human-readable explanation of the detected pattern.
    pub description: String,
}
111
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SymmetryType {
    /// Invariance under translation.
    Translational,
    /// Invariance under rotation.
    Rotational,
    /// Mirror symmetry.
    Reflection,
    /// Interchangeable layers (detected here via equal parameter counts).
    Permutation,
    /// Repeated block/layer-pattern structure.
    Block,
}
120
/// Architecture analysis results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureAnalysisReport {
    /// Total parameter count across all layers.
    pub total_parameters: usize,
    /// Subset of parameters that are trainable.
    pub trainable_parameters: usize,
    /// Estimated model size in MB (4 bytes per f32 parameter).
    pub model_size_mb: f32,
    /// Sum of per-layer FLOP estimates.
    pub total_flops: u64,
    /// Number of layers (sequential depth).
    pub model_depth: usize,
    /// Largest per-layer parameter count.
    pub model_width: usize,
    /// Snapshot of the analyzed layers.
    pub layers: Vec<LayerInfo>,
    /// Snapshot of the registered connections.
    pub connectivity_patterns: Vec<ConnectivityPattern>,
    /// Detected symmetries (block / permutation heuristics).
    pub symmetries: Vec<SymmetryInfo>,
    /// Parameter count aggregated per layer type.
    pub parameter_distribution: HashMap<LayerType, usize>,
    /// Human-readable descriptions of suspected bottlenecks.
    pub bottlenecks: Vec<String>,
    /// Heuristic efficiency scores.
    pub efficiency_metrics: EfficiencyMetrics,
}
137
/// Model efficiency metrics
///
/// All sub-scores are heuristics in [0, 1]; `overall_score` is their
/// weighted average (weights 0.3 / 0.3 / 0.2 / 0.2).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficiencyMetrics {
    /// Inverse-log score of total parameter count.
    pub parameter_efficiency: f32,
    /// Inverse-log score of total FLOPS.
    pub flops_efficiency: f32,
    /// Inverse-log score of estimated model size.
    pub memory_efficiency: f32,
    /// Closeness of the depth to an assumed optimum of 20 layers.
    pub depth_efficiency: f32,
    /// Weighted combination of the four scores above.
    pub overall_score: f32,
}
147
/// Architecture analyzer
///
/// Collects registered layers and connections, then produces an
/// `ArchitectureAnalysisReport` via `analyze`.
#[derive(Debug)]
pub struct ArchitectureAnalyzer {
    // Which analysis passes run.
    config: ArchitectureAnalysisConfig,
    // Layers in registration order (treated as the sequential network order).
    layers: Vec<LayerInfo>,
    // Connections added via `add_connection`.
    connections: Vec<ConnectivityPattern>,
    // NOTE(review): only ever cleared in this file, never read or written —
    // confirm whether caching is still intended.
    analysis_cache: HashMap<String, ArchitectureAnalysisReport>,
}
156
157impl ArchitectureAnalyzer {
158    /// Create a new architecture analyzer
159    pub fn new(config: ArchitectureAnalysisConfig) -> Self {
160        Self {
161            config,
162            layers: Vec::new(),
163            connections: Vec::new(),
164            analysis_cache: HashMap::new(),
165        }
166    }
167
    /// Register a layer for analysis.
    ///
    /// Layers are kept in registration order, which the analysis passes
    /// treat as the sequential order of the network.
    pub fn register_layer(&mut self, layer: LayerInfo) {
        self.layers.push(layer);
    }
172
    /// Add a connection between layers.
    ///
    /// Connections are only inspected by `analyze_connectivity_patterns`;
    /// layer ids inside the pattern are not validated here.
    pub fn add_connection(&mut self, pattern: ConnectivityPattern) {
        self.connections.push(pattern);
    }
177
    /// Analyze the registered architecture.
    ///
    /// Builds a zeroed report, runs each pass enabled in the config —
    /// parameter counting, receptive fields, depth/width, connectivity,
    /// symmetry, in that fixed order — then always computes efficiency
    /// metrics and bottlenecks, which read the totals earlier passes fill in.
    ///
    /// # Errors
    /// Propagates any error from receptive-field calculation.
    pub async fn analyze(&mut self) -> Result<ArchitectureAnalysisReport> {
        // All-zero skeleton; the enabled passes below fill it in.
        let mut report = ArchitectureAnalysisReport {
            total_parameters: 0,
            trainable_parameters: 0,
            model_size_mb: 0.0,
            total_flops: 0,
            model_depth: 0,
            model_width: 0,
            layers: self.layers.clone(),
            connectivity_patterns: self.connections.clone(),
            symmetries: Vec::new(),
            parameter_distribution: HashMap::new(),
            bottlenecks: Vec::new(),
            efficiency_metrics: EfficiencyMetrics {
                parameter_efficiency: 0.0,
                flops_efficiency: 0.0,
                memory_efficiency: 0.0,
                depth_efficiency: 0.0,
                overall_score: 0.0,
            },
        };

        if self.config.enable_parameter_counting {
            self.count_parameters(&mut report);
        }

        if self.config.enable_receptive_field_calculation {
            self.calculate_receptive_fields(&mut report).await?;
        }

        if self.config.enable_depth_width_analysis {
            self.analyze_depth_width(&mut report);
        }

        if self.config.enable_connectivity_patterns {
            self.analyze_connectivity_patterns(&mut report);
        }

        if self.config.enable_symmetry_detection {
            self.detect_symmetries(&mut report);
        }

        // Always run; they read totals populated (or left at zero) above.
        self.calculate_efficiency_metrics(&mut report);
        self.identify_bottlenecks(&mut report);

        Ok(report)
    }
226
227    /// Count parameters in all layers
228    fn count_parameters(&self, report: &mut ArchitectureAnalysisReport) {
229        let mut param_distribution: HashMap<LayerType, usize> = HashMap::new();
230
231        for layer in &self.layers {
232            report.total_parameters += layer.parameters;
233            report.trainable_parameters += layer.trainable_parameters;
234
235            *param_distribution.entry(layer.layer_type.clone()).or_insert(0) += layer.parameters;
236        }
237
238        report.parameter_distribution = param_distribution;
239
240        // Estimate model size (4 bytes per float32 parameter)
241        report.model_size_mb = (report.total_parameters * 4) as f32 / (1024.0 * 1024.0);
242
243        // Calculate total FLOPS
244        report.total_flops = self.layers.iter().map(|l| l.flops).sum();
245    }
246
247    /// Calculate receptive fields for convolutional layers
248    async fn calculate_receptive_fields(
249        &mut self,
250        report: &mut ArchitectureAnalysisReport,
251    ) -> Result<()> {
252        for layer in &mut self.layers {
253            if matches!(layer.layer_type, LayerType::Conv2D | LayerType::Conv3D) {
254                layer.receptive_field =
255                    Some(Self::compute_receptive_field_static(&layer.layer_type));
256            }
257        }
258
259        report.layers = self.layers.clone();
260        Ok(())
261    }
262
263    /// Compute receptive field for a convolutional layer (static version)
264    fn compute_receptive_field_static(layer_type: &LayerType) -> ReceptiveField {
265        match layer_type {
266            LayerType::Conv2D => {
267                // Simple 2D convolution receptive field calculation
268                let kernel_size = vec![3, 3]; // Default 3x3 kernel
269                let stride = vec![1, 1];
270                let padding = vec![1, 1];
271
272                ReceptiveField {
273                    size: kernel_size.clone(),
274                    stride,
275                    padding,
276                    effective_size: kernel_size,
277                }
278            },
279            LayerType::Conv3D => {
280                // Simple 3D convolution receptive field calculation
281                let kernel_size = vec![3, 3, 3]; // Default 3x3x3 kernel
282                let stride = vec![1, 1, 1];
283                let padding = vec![1, 1, 1];
284
285                ReceptiveField {
286                    size: kernel_size.clone(),
287                    stride,
288                    padding,
289                    effective_size: kernel_size,
290                }
291            },
292            _ => {
293                // For non-conv layers, receptive field is 1
294                ReceptiveField {
295                    size: vec![1],
296                    stride: vec![1],
297                    padding: vec![0],
298                    effective_size: vec![1],
299                }
300            },
301        }
302    }
303
    /// Compute receptive field for a convolutional layer.
    ///
    /// Instance-method wrapper around the static computation; nothing in
    /// this file calls it (hence the dead_code allowance).
    #[allow(dead_code)]
    fn compute_receptive_field(&self, layer: &LayerInfo) -> ReceptiveField {
        Self::compute_receptive_field_static(&layer.layer_type)
    }
309
310    /// Analyze model depth and width
311    fn analyze_depth_width(&self, report: &mut ArchitectureAnalysisReport) {
312        // Calculate depth (number of sequential layers)
313        report.model_depth = self.layers.len();
314
315        // Calculate width (maximum number of parameters in a single layer)
316        report.model_width = self.layers.iter().map(|l| l.parameters).max().unwrap_or(0);
317    }
318
319    /// Analyze connectivity patterns
320    fn analyze_connectivity_patterns(&self, report: &mut ArchitectureAnalysisReport) {
321        let mut pattern_types: HashMap<ConnectionType, usize> = HashMap::new();
322
323        for connection in &self.connections {
324            *pattern_types.entry(connection.connection_type.clone()).or_insert(0) += 1;
325        }
326
327        // Find unusual connectivity patterns
328        for (connection_type, count) in pattern_types {
329            if count > self.layers.len() / 2 {
330                // High connectivity, might indicate bottlenecks
331                report.bottlenecks.push(format!(
332                    "High {:?} connectivity: {} connections",
333                    connection_type, count
334                ));
335            }
336        }
337    }
338
339    /// Detect architectural symmetries
340    fn detect_symmetries(&self, report: &mut ArchitectureAnalysisReport) {
341        // Detect block symmetries (repeated layer patterns)
342        let mut block_patterns: HashMap<Vec<LayerType>, Vec<usize>> = HashMap::new();
343
344        // Look for patterns of 2-5 consecutive layers
345        for window_size in 2..=5.min(self.layers.len()) {
346            for i in 0..=(self.layers.len() - window_size) {
347                let pattern: Vec<LayerType> =
348                    self.layers[i..i + window_size].iter().map(|l| l.layer_type.clone()).collect();
349
350                block_patterns.entry(pattern).or_default().push(i);
351            }
352        }
353
354        // Find repeated patterns
355        for (pattern, positions) in block_patterns {
356            if positions.len() > 1 {
357                let confidence = positions.len() as f32 / self.layers.len() as f32;
358
359                if confidence > 0.1 {
360                    // At least 10% of the model
361                    report.symmetries.push(SymmetryInfo {
362                        symmetry_type: SymmetryType::Block,
363                        symmetric_layers: positions
364                            .iter()
365                            .map(|&i| format!("block_{}", i))
366                            .collect(),
367                        confidence,
368                        description: format!(
369                            "Repeated block pattern: {:?} appears {} times",
370                            pattern,
371                            positions.len()
372                        ),
373                    });
374                }
375            }
376        }
377
378        // Detect parameter symmetries
379        let mut param_groups: HashMap<usize, Vec<String>> = HashMap::new();
380        for layer in &self.layers {
381            param_groups.entry(layer.parameters).or_default().push(layer.id.clone());
382        }
383
384        for (param_count, layer_ids) in param_groups {
385            if layer_ids.len() > 2 && param_count > 0 {
386                let confidence = layer_ids.len() as f32 / self.layers.len() as f32;
387
388                report.symmetries.push(SymmetryInfo {
389                    symmetry_type: SymmetryType::Permutation,
390                    symmetric_layers: layer_ids.clone(),
391                    confidence,
392                    description: format!(
393                        "Parameter symmetry: {} layers with {} parameters each",
394                        layer_ids.len(),
395                        param_count
396                    ),
397                });
398            }
399        }
400    }
401
402    /// Calculate efficiency metrics
403    fn calculate_efficiency_metrics(&self, report: &mut ArchitectureAnalysisReport) {
404        let total_params = report.total_parameters as f32;
405        let total_flops = report.total_flops as f32;
406        let depth = report.model_depth as f32;
407        let memory = report.model_size_mb;
408
409        // Parameter efficiency: fewer parameters for same capability is better
410        report.efficiency_metrics.parameter_efficiency = if total_params > 0.0 {
411            1.0 / (total_params / 1_000_000.0).log10().max(1.0) // Inverse log scale
412        } else {
413            1.0
414        };
415
416        // FLOPS efficiency: fewer FLOPS for same capability is better
417        report.efficiency_metrics.flops_efficiency = if total_flops > 0.0 {
418            1.0 / (total_flops / 1_000_000_000.0).log10().max(1.0) // Inverse log scale
419        } else {
420            1.0
421        };
422
423        // Memory efficiency: less memory usage is better
424        report.efficiency_metrics.memory_efficiency = if memory > 0.0 {
425            1.0 / (memory / 100.0).log10().max(1.0) // Inverse log scale
426        } else {
427            1.0
428        };
429
430        // Depth efficiency: moderate depth is best (not too shallow, not too deep)
431        report.efficiency_metrics.depth_efficiency = if depth > 0.0 {
432            let optimal_depth = 20.0; // Assumed optimal depth
433            1.0 - ((depth - optimal_depth).abs() / optimal_depth).min(1.0)
434        } else {
435            0.0
436        };
437
438        // Overall efficiency score (weighted average)
439        report.efficiency_metrics.overall_score = 0.3
440            * report.efficiency_metrics.parameter_efficiency
441            + 0.3 * report.efficiency_metrics.flops_efficiency
442            + 0.2 * report.efficiency_metrics.memory_efficiency
443            + 0.2 * report.efficiency_metrics.depth_efficiency;
444    }
445
446    /// Identify potential bottlenecks
447    fn identify_bottlenecks(&self, report: &mut ArchitectureAnalysisReport) {
448        // Find layers with disproportionately high parameter counts
449        if let Some(_max_params) = self.layers.iter().map(|l| l.parameters).max() {
450            let avg_params = report.total_parameters / self.layers.len().max(1);
451
452            for layer in &self.layers {
453                if layer.parameters > avg_params * 5 {
454                    report.bottlenecks.push(format!(
455                        "Parameter bottleneck: Layer '{}' has {} parameters ({}x average)",
456                        layer.name,
457                        layer.parameters,
458                        layer.parameters / avg_params.max(1)
459                    ));
460                }
461            }
462        }
463
464        // Find layers with very large memory usage
465        for layer in &self.layers {
466            if layer.memory_usage > 100 * 1024 * 1024 {
467                // > 100MB
468                report.bottlenecks.push(format!(
469                    "Memory bottleneck: Layer '{}' uses {:.1}MB memory",
470                    layer.name,
471                    layer.memory_usage as f32 / (1024.0 * 1024.0)
472                ));
473            }
474        }
475
476        // Find layers with very high FLOPS
477        if let Some(_max_flops) = self.layers.iter().map(|l| l.flops).max() {
478            let avg_flops = report.total_flops / self.layers.len().max(1) as u64;
479
480            for layer in &self.layers {
481                if layer.flops > avg_flops * 10 {
482                    report.bottlenecks.push(format!(
483                        "Computation bottleneck: Layer '{}' requires {} FLOPS ({}x average)",
484                        layer.name,
485                        layer.flops,
486                        layer.flops / avg_flops.max(1)
487                    ));
488                }
489            }
490        }
491    }
492
493    /// Quick architecture analysis for simplified interface
494    pub async fn quick_analysis(&self) -> Result<crate::QuickArchitectureSummary> {
495        let total_parameters = self.layers.iter().map(|l| l.parameters as u64).sum::<u64>();
496        let total_flops = self.layers.iter().map(|l| l.flops).sum::<u64>();
497
498        // Estimate model size in MB (4 bytes per float32 parameter)
499        let model_size_mb = (total_parameters as f64 * 4.0) / (1024.0 * 1024.0);
500
501        // Calculate efficiency score based on parameters to FLOPS ratio
502        let efficiency_score = if total_flops > 0 {
503            (total_parameters as f64 / total_flops as f64 * 1000.0).min(100.0)
504        } else {
505            50.0
506        };
507
508        let mut recommendations = Vec::new();
509        if total_parameters > 1_000_000_000 {
510            recommendations
511                .push("Consider model compression techniques for large model".to_string());
512        }
513        if efficiency_score < 30.0 {
514            recommendations.push("Model architecture could be more efficient".to_string());
515        }
516        if model_size_mb > 1000.0 {
517            recommendations.push("Large model size may impact deployment".to_string());
518        }
519        if recommendations.is_empty() {
520            recommendations.push("Architecture appears well-balanced".to_string());
521        }
522
523        Ok(crate::QuickArchitectureSummary {
524            total_parameters,
525            model_size_mb,
526            efficiency_score,
527            recommendations,
528        })
529    }
530
531    /// Generate a comprehensive report
532    pub async fn generate_report(&self) -> Result<ArchitectureAnalysisReport> {
533        // Create a temporary clone to avoid mutable borrow issues
534        let mut temp_analyzer = ArchitectureAnalyzer {
535            config: self.config.clone(),
536            layers: self.layers.clone(),
537            connections: self.connections.clone(),
538            analysis_cache: HashMap::new(),
539        };
540
541        temp_analyzer.analyze().await
542    }
543
    /// Clear all registered layers and connections.
    ///
    /// Also drops any cached reports, returning the analyzer to its
    /// freshly-constructed state (the config is kept).
    pub fn clear(&mut self) {
        self.layers.clear();
        self.connections.clear();
        self.analysis_cache.clear();
    }
550
551    /// Get summary statistics
552    pub fn get_summary(&self) -> ArchitectureSummary {
553        let total_params: usize = self.layers.iter().map(|l| l.parameters).sum();
554        let total_flops: u64 = self.layers.iter().map(|l| l.flops).sum();
555
556        ArchitectureSummary {
557            total_layers: self.layers.len(),
558            total_parameters: total_params,
559            total_flops,
560            average_layer_size: if !self.layers.is_empty() {
561                total_params / self.layers.len()
562            } else {
563                0
564            },
565            layer_type_distribution: {
566                let mut dist = HashMap::new();
567                for layer in &self.layers {
568                    *dist.entry(layer.layer_type.clone()).or_insert(0) += 1;
569                }
570                dist
571            },
572        }
573    }
574}
575
/// Summary statistics for architecture
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSummary {
    /// Number of registered layers.
    pub total_layers: usize,
    /// Sum of per-layer parameter counts.
    pub total_parameters: usize,
    /// Sum of per-layer FLOP estimates.
    pub total_flops: u64,
    /// Mean parameters per layer (0 when there are no layers).
    pub average_layer_size: usize,
    /// Count of layers per layer type.
    pub layer_type_distribution: HashMap<LayerType, usize>,
}
585
586/// Convenience function to create a layer info
587pub fn create_layer_info(
588    id: String,
589    name: String,
590    layer_type: LayerType,
591    input_shape: Vec<usize>,
592    output_shape: Vec<usize>,
593    parameters: usize,
594) -> LayerInfo {
595    let memory_usage = parameters * 4; // 4 bytes per float32
596    let flops = estimate_flops(&layer_type, &input_shape, &output_shape, parameters);
597
598    LayerInfo {
599        id,
600        name,
601        layer_type,
602        input_shape,
603        output_shape,
604        parameters,
605        trainable_parameters: parameters, // Assume all parameters are trainable by default
606        memory_usage,
607        flops,
608        receptive_field: None,
609    }
610}
611
612/// Estimate FLOPS for a layer
613fn estimate_flops(
614    layer_type: &LayerType,
615    input_shape: &[usize],
616    output_shape: &[usize],
617    parameters: usize,
618) -> u64 {
619    match layer_type {
620        LayerType::Linear => {
621            if input_shape.len() >= 2 && output_shape.len() >= 2 {
622                let batch_size = input_shape[0] as u64;
623                let input_features = input_shape[1] as u64;
624                let output_features = output_shape[1] as u64;
625                batch_size * input_features * output_features * 2
626            } else {
627                parameters as u64 * 2
628            }
629        },
630        LayerType::Conv2D => {
631            if output_shape.len() >= 4 {
632                let batch_size = output_shape[0] as u64;
633                let output_channels = output_shape[1] as u64;
634                let output_h = output_shape[2] as u64;
635                let output_w = output_shape[3] as u64;
636                batch_size
637                    * output_channels
638                    * output_h
639                    * output_w
640                    * (parameters as u64 / output_channels).max(1)
641                    * 2
642            } else {
643                parameters as u64 * 2
644            }
645        },
646        LayerType::Attention => {
647            if input_shape.len() >= 3 {
648                let batch_size = input_shape[0] as u64;
649                let seq_len = input_shape[1] as u64;
650                let hidden_size = input_shape[2] as u64;
651                batch_size * seq_len * seq_len * hidden_size * 4
652            } else {
653                parameters as u64 * 4
654            }
655        },
656        _ => parameters as u64,
657    }
658}
659
660// ─────────────────────────────────────────────────────────────────────────────
661// Tests
662// ─────────────────────────────────────────────────────────────────────────────
663
#[cfg(test)]
mod tests {
    // Unit tests plus tokio-based async tests for the analyzer's passes.
    use super::*;

    // Shared fixture: a Linear layer whose feature count equals its
    // parameter count (so FLOPS are deterministic and non-zero).
    fn make_linear_layer(id: &str, params: usize) -> LayerInfo {
        create_layer_info(
            id.to_string(),
            format!("{}_layer", id),
            LayerType::Linear,
            vec![1, params],
            vec![1, params],
            params,
        )
    }

    // ── Config ────────────────────────────────────────────────────────────

    #[test]
    fn test_config_default() {
        let cfg = ArchitectureAnalysisConfig::default();
        assert!(cfg.enable_parameter_counting);
        assert!(cfg.enable_receptive_field_calculation);
        assert!(cfg.enable_depth_width_analysis);
        assert!(cfg.enable_connectivity_patterns);
        assert!(cfg.enable_symmetry_detection);
        assert!(cfg.max_receptive_field_depth > 0);
        assert!((cfg.sampling_rate - 1.0).abs() < 1e-6);
    }

    // ── ArchitectureAnalyzer ───────────────────────────────────────────────

    #[test]
    fn test_analyzer_new_empty() {
        let analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        let summary = analyzer.get_summary();
        assert_eq!(summary.total_layers, 0);
        assert_eq!(summary.total_parameters, 0);
    }

    #[test]
    fn test_register_layer_accumulates() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 512));
        analyzer.register_layer(make_linear_layer("l1", 256));
        assert_eq!(analyzer.get_summary().total_layers, 2);
    }

    #[test]
    fn test_add_connection() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("a", 128));
        analyzer.register_layer(make_linear_layer("b", 128));
        analyzer.add_connection(ConnectivityPattern {
            from_layer: "a".to_string(),
            to_layer: "b".to_string(),
            connection_type: ConnectionType::Sequential,
            strength: 1.0,
        });
        let summary = analyzer.get_summary();
        assert_eq!(summary.total_layers, 2);
    }

    #[test]
    fn test_clear_resets_state() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 64));
        analyzer.clear();
        assert_eq!(analyzer.get_summary().total_layers, 0);
    }

    // ── LayerInfo via create_layer_info ────────────────────────────────────

    #[test]
    fn test_create_layer_info_parameters() {
        let layer = make_linear_layer("dense", 1024);
        assert_eq!(layer.parameters, 1024);
        assert_eq!(layer.trainable_parameters, 1024);
        assert_eq!(layer.memory_usage, 1024 * 4);
        assert!(layer.receptive_field.is_none());
    }

    #[test]
    fn test_create_layer_info_conv2d_flops() {
        let layer = create_layer_info(
            "conv".to_string(),
            "conv_layer".to_string(),
            LayerType::Conv2D,
            vec![1, 3, 224, 224],
            vec![1, 64, 112, 112],
            64 * 3 * 3 * 3,
        );
        assert!(layer.flops > 0);
    }

    #[test]
    fn test_create_layer_info_attention_flops() {
        let layer = create_layer_info(
            "attn".to_string(),
            "attention".to_string(),
            LayerType::Attention,
            vec![1, 128, 768],
            vec![1, 128, 768],
            768 * 768 * 4,
        );
        assert!(layer.flops > 0);
    }

    // ── ArchitectureSummary ────────────────────────────────────────────────

    #[test]
    fn test_summary_totals() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 100));
        analyzer.register_layer(make_linear_layer("l1", 200));
        let s = analyzer.get_summary();
        assert_eq!(s.total_parameters, 300);
        assert_eq!(s.average_layer_size, 150);
    }

    #[test]
    fn test_summary_layer_type_distribution() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 64));
        analyzer.register_layer(make_linear_layer("l1", 64));
        let s = analyzer.get_summary();
        assert_eq!(
            s.layer_type_distribution.get(&LayerType::Linear).copied().unwrap_or(0),
            2
        );
    }

    // ── LayerType variants ─────────────────────────────────────────────────

    #[test]
    fn test_layer_type_all_variants() {
        let types = [
            LayerType::Linear,
            LayerType::Conv2D,
            LayerType::Conv3D,
            LayerType::BatchNorm,
            LayerType::LayerNorm,
            LayerType::Attention,
            LayerType::Embedding,
            LayerType::Dropout,
            LayerType::Activation,
            LayerType::Pooling,
            LayerType::Residual,
            LayerType::Unknown,
        ];
        for t in &types {
            assert!(!format!("{:?}", t).is_empty());
        }
    }

    // ── ConnectionType variants ────────────────────────────────────────────

    #[test]
    fn test_connection_type_variants() {
        let types = [
            ConnectionType::Sequential,
            ConnectionType::Residual,
            ConnectionType::Attention,
            ConnectionType::Skip,
            ConnectionType::Recurrent,
            ConnectionType::Branching,
        ];
        for t in &types {
            assert!(!format!("{:?}", t).is_empty());
        }
    }

    // ── SymmetryType variants ──────────────────────────────────────────────

    #[test]
    fn test_symmetry_type_variants() {
        let types = [
            SymmetryType::Translational,
            SymmetryType::Rotational,
            SymmetryType::Reflection,
            SymmetryType::Permutation,
            SymmetryType::Block,
        ];
        for t in &types {
            assert!(!format!("{:?}", t).is_empty());
        }
    }

    // ── EfficiencyMetrics ──────────────────────────────────────────────────

    #[test]
    fn test_efficiency_metrics_construction() {
        let metrics = EfficiencyMetrics {
            parameter_efficiency: 0.8,
            flops_efficiency: 0.7,
            memory_efficiency: 0.9,
            depth_efficiency: 0.6,
            overall_score: 0.75,
        };
        assert!((metrics.overall_score - 0.75).abs() < 1e-6);
    }

    // ── async analyze ──────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_analyze_empty() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        let report = analyzer.analyze().await.expect("analyze should succeed");
        assert_eq!(report.total_parameters, 0);
        assert_eq!(report.model_depth, 0);
    }

    #[tokio::test]
    async fn test_analyze_parameter_counting() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 512));
        analyzer.register_layer(make_linear_layer("l1", 512));
        let report = analyzer.analyze().await.expect("analyze should succeed");
        assert_eq!(report.total_parameters, 1024);
        assert_eq!(report.trainable_parameters, 1024);
    }

    #[tokio::test]
    async fn test_analyze_depth_width() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 128));
        analyzer.register_layer(make_linear_layer("l1", 256));
        analyzer.register_layer(make_linear_layer("l2", 64));
        let report = analyzer.analyze().await.expect("analyze should succeed");
        assert_eq!(report.model_depth, 3);
        assert_eq!(report.model_width, 256);
    }

    #[tokio::test]
    async fn test_quick_analysis_returns_summary() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        analyzer.register_layer(make_linear_layer("l0", 1000));
        let qs = analyzer.quick_analysis().await.expect("quick_analysis should succeed");
        assert_eq!(qs.total_parameters, 1000);
        assert!(!qs.recommendations.is_empty());
    }

    #[tokio::test]
    async fn test_generate_report_symmetry_detection() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());
        // Register 6 identical layers — should trigger permutation symmetry detection.
        for i in 0..6 {
            analyzer.register_layer(make_linear_layer(&format!("l{}", i), 256));
        }
        let report = analyzer.generate_report().await.expect("report should succeed");
        // Symmetry detection is heuristic; just verify the report is populated
        assert_eq!(report.total_parameters, 6 * 256);
    }
}