Skip to main content

trustformers_debug/
architecture_analysis.rs

//! Architecture Analysis
//!
//! Comprehensive analysis tools for neural network architectures including
//! parameter counting, receptive field calculation, and connectivity analysis.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Configuration for architecture analysis
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ArchitectureAnalysisConfig {
13    /// Enable parameter counting
14    pub enable_parameter_counting: bool,
15    /// Enable receptive field calculation
16    pub enable_receptive_field_calculation: bool,
17    /// Enable depth/width analysis
18    pub enable_depth_width_analysis: bool,
19    /// Enable connectivity pattern detection
20    pub enable_connectivity_patterns: bool,
21    /// Enable symmetry detection
22    pub enable_symmetry_detection: bool,
23    /// Maximum depth to analyze for receptive fields
24    pub max_receptive_field_depth: usize,
25    /// Sampling rate for large models (0.0 to 1.0)
26    pub sampling_rate: f32,
27}
28
29impl Default for ArchitectureAnalysisConfig {
30    fn default() -> Self {
31        Self {
32            enable_parameter_counting: true,
33            enable_receptive_field_calculation: true,
34            enable_depth_width_analysis: true,
35            enable_connectivity_patterns: true,
36            enable_symmetry_detection: true,
37            max_receptive_field_depth: 50,
38            sampling_rate: 1.0,
39        }
40    }
41}
42
43/// Layer type for architecture analysis
44#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
45pub enum LayerType {
46    Linear,
47    Conv2D,
48    Conv3D,
49    BatchNorm,
50    LayerNorm,
51    Attention,
52    Embedding,
53    Dropout,
54    Activation,
55    Pooling,
56    Residual,
57    Unknown,
58}
59
60/// Information about a single layer
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct LayerInfo {
63    pub id: String,
64    pub name: String,
65    pub layer_type: LayerType,
66    pub input_shape: Vec<usize>,
67    pub output_shape: Vec<usize>,
68    pub parameters: usize,
69    pub trainable_parameters: usize,
70    pub memory_usage: usize,
71    pub flops: u64,
72    pub receptive_field: Option<ReceptiveField>,
73}
74
75/// Receptive field information
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct ReceptiveField {
78    pub size: Vec<usize>,
79    pub stride: Vec<usize>,
80    pub padding: Vec<usize>,
81    pub effective_size: Vec<usize>,
82}
83
84/// Connectivity pattern between layers
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ConnectivityPattern {
87    pub from_layer: String,
88    pub to_layer: String,
89    pub connection_type: ConnectionType,
90    pub strength: f32,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
94pub enum ConnectionType {
95    Sequential,
96    Residual,
97    Attention,
98    Skip,
99    Recurrent,
100    Branching,
101}
102
103/// Symmetry information in the architecture
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct SymmetryInfo {
106    pub symmetry_type: SymmetryType,
107    pub symmetric_layers: Vec<String>,
108    pub confidence: f32,
109    pub description: String,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub enum SymmetryType {
114    Translational,
115    Rotational,
116    Reflection,
117    Permutation,
118    Block,
119}
120
121/// Architecture analysis results
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct ArchitectureAnalysisReport {
124    pub total_parameters: usize,
125    pub trainable_parameters: usize,
126    pub model_size_mb: f32,
127    pub total_flops: u64,
128    pub model_depth: usize,
129    pub model_width: usize,
130    pub layers: Vec<LayerInfo>,
131    pub connectivity_patterns: Vec<ConnectivityPattern>,
132    pub symmetries: Vec<SymmetryInfo>,
133    pub parameter_distribution: HashMap<LayerType, usize>,
134    pub bottlenecks: Vec<String>,
135    pub efficiency_metrics: EfficiencyMetrics,
136}
137
138/// Model efficiency metrics
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct EfficiencyMetrics {
141    pub parameter_efficiency: f32,
142    pub flops_efficiency: f32,
143    pub memory_efficiency: f32,
144    pub depth_efficiency: f32,
145    pub overall_score: f32,
146}
147
148/// Architecture analyzer
149#[derive(Debug)]
150pub struct ArchitectureAnalyzer {
151    config: ArchitectureAnalysisConfig,
152    layers: Vec<LayerInfo>,
153    connections: Vec<ConnectivityPattern>,
154    analysis_cache: HashMap<String, ArchitectureAnalysisReport>,
155}
156
157impl ArchitectureAnalyzer {
158    /// Create a new architecture analyzer
159    pub fn new(config: ArchitectureAnalysisConfig) -> Self {
160        Self {
161            config,
162            layers: Vec::new(),
163            connections: Vec::new(),
164            analysis_cache: HashMap::new(),
165        }
166    }
167
168    /// Register a layer for analysis
169    pub fn register_layer(&mut self, layer: LayerInfo) {
170        self.layers.push(layer);
171    }
172
173    /// Add a connection between layers
174    pub fn add_connection(&mut self, pattern: ConnectivityPattern) {
175        self.connections.push(pattern);
176    }
177
178    /// Analyze the registered architecture
179    pub async fn analyze(&mut self) -> Result<ArchitectureAnalysisReport> {
180        let mut report = ArchitectureAnalysisReport {
181            total_parameters: 0,
182            trainable_parameters: 0,
183            model_size_mb: 0.0,
184            total_flops: 0,
185            model_depth: 0,
186            model_width: 0,
187            layers: self.layers.clone(),
188            connectivity_patterns: self.connections.clone(),
189            symmetries: Vec::new(),
190            parameter_distribution: HashMap::new(),
191            bottlenecks: Vec::new(),
192            efficiency_metrics: EfficiencyMetrics {
193                parameter_efficiency: 0.0,
194                flops_efficiency: 0.0,
195                memory_efficiency: 0.0,
196                depth_efficiency: 0.0,
197                overall_score: 0.0,
198            },
199        };
200
201        if self.config.enable_parameter_counting {
202            self.count_parameters(&mut report);
203        }
204
205        if self.config.enable_receptive_field_calculation {
206            self.calculate_receptive_fields(&mut report).await?;
207        }
208
209        if self.config.enable_depth_width_analysis {
210            self.analyze_depth_width(&mut report);
211        }
212
213        if self.config.enable_connectivity_patterns {
214            self.analyze_connectivity_patterns(&mut report);
215        }
216
217        if self.config.enable_symmetry_detection {
218            self.detect_symmetries(&mut report);
219        }
220
221        self.calculate_efficiency_metrics(&mut report);
222        self.identify_bottlenecks(&mut report);
223
224        Ok(report)
225    }
226
227    /// Count parameters in all layers
228    fn count_parameters(&self, report: &mut ArchitectureAnalysisReport) {
229        let mut param_distribution: HashMap<LayerType, usize> = HashMap::new();
230
231        for layer in &self.layers {
232            report.total_parameters += layer.parameters;
233            report.trainable_parameters += layer.trainable_parameters;
234
235            *param_distribution.entry(layer.layer_type.clone()).or_insert(0) += layer.parameters;
236        }
237
238        report.parameter_distribution = param_distribution;
239
240        // Estimate model size (4 bytes per float32 parameter)
241        report.model_size_mb = (report.total_parameters * 4) as f32 / (1024.0 * 1024.0);
242
243        // Calculate total FLOPS
244        report.total_flops = self.layers.iter().map(|l| l.flops).sum();
245    }
246
247    /// Calculate receptive fields for convolutional layers
248    async fn calculate_receptive_fields(
249        &mut self,
250        report: &mut ArchitectureAnalysisReport,
251    ) -> Result<()> {
252        for layer in &mut self.layers {
253            if matches!(layer.layer_type, LayerType::Conv2D | LayerType::Conv3D) {
254                layer.receptive_field =
255                    Some(Self::compute_receptive_field_static(&layer.layer_type));
256            }
257        }
258
259        report.layers = self.layers.clone();
260        Ok(())
261    }
262
263    /// Compute receptive field for a convolutional layer (static version)
264    fn compute_receptive_field_static(layer_type: &LayerType) -> ReceptiveField {
265        match layer_type {
266            LayerType::Conv2D => {
267                // Simple 2D convolution receptive field calculation
268                let kernel_size = vec![3, 3]; // Default 3x3 kernel
269                let stride = vec![1, 1];
270                let padding = vec![1, 1];
271
272                ReceptiveField {
273                    size: kernel_size.clone(),
274                    stride,
275                    padding,
276                    effective_size: kernel_size,
277                }
278            },
279            LayerType::Conv3D => {
280                // Simple 3D convolution receptive field calculation
281                let kernel_size = vec![3, 3, 3]; // Default 3x3x3 kernel
282                let stride = vec![1, 1, 1];
283                let padding = vec![1, 1, 1];
284
285                ReceptiveField {
286                    size: kernel_size.clone(),
287                    stride,
288                    padding,
289                    effective_size: kernel_size,
290                }
291            },
292            _ => {
293                // For non-conv layers, receptive field is 1
294                ReceptiveField {
295                    size: vec![1],
296                    stride: vec![1],
297                    padding: vec![0],
298                    effective_size: vec![1],
299                }
300            },
301        }
302    }
303
304    /// Compute receptive field for a convolutional layer
305    #[allow(dead_code)]
306    fn compute_receptive_field(&self, layer: &LayerInfo) -> ReceptiveField {
307        Self::compute_receptive_field_static(&layer.layer_type)
308    }
309
310    /// Analyze model depth and width
311    fn analyze_depth_width(&self, report: &mut ArchitectureAnalysisReport) {
312        // Calculate depth (number of sequential layers)
313        report.model_depth = self.layers.len();
314
315        // Calculate width (maximum number of parameters in a single layer)
316        report.model_width = self.layers.iter().map(|l| l.parameters).max().unwrap_or(0);
317    }
318
319    /// Analyze connectivity patterns
320    fn analyze_connectivity_patterns(&self, report: &mut ArchitectureAnalysisReport) {
321        let mut pattern_types: HashMap<ConnectionType, usize> = HashMap::new();
322
323        for connection in &self.connections {
324            *pattern_types.entry(connection.connection_type.clone()).or_insert(0) += 1;
325        }
326
327        // Find unusual connectivity patterns
328        for (connection_type, count) in pattern_types {
329            if count > self.layers.len() / 2 {
330                // High connectivity, might indicate bottlenecks
331                report.bottlenecks.push(format!(
332                    "High {:?} connectivity: {} connections",
333                    connection_type, count
334                ));
335            }
336        }
337    }
338
339    /// Detect architectural symmetries
340    fn detect_symmetries(&self, report: &mut ArchitectureAnalysisReport) {
341        // Detect block symmetries (repeated layer patterns)
342        let mut block_patterns: HashMap<Vec<LayerType>, Vec<usize>> = HashMap::new();
343
344        // Look for patterns of 2-5 consecutive layers
345        for window_size in 2..=5.min(self.layers.len()) {
346            for i in 0..=(self.layers.len() - window_size) {
347                let pattern: Vec<LayerType> =
348                    self.layers[i..i + window_size].iter().map(|l| l.layer_type.clone()).collect();
349
350                block_patterns.entry(pattern).or_default().push(i);
351            }
352        }
353
354        // Find repeated patterns
355        for (pattern, positions) in block_patterns {
356            if positions.len() > 1 {
357                let confidence = positions.len() as f32 / self.layers.len() as f32;
358
359                if confidence > 0.1 {
360                    // At least 10% of the model
361                    report.symmetries.push(SymmetryInfo {
362                        symmetry_type: SymmetryType::Block,
363                        symmetric_layers: positions
364                            .iter()
365                            .map(|&i| format!("block_{}", i))
366                            .collect(),
367                        confidence,
368                        description: format!(
369                            "Repeated block pattern: {:?} appears {} times",
370                            pattern,
371                            positions.len()
372                        ),
373                    });
374                }
375            }
376        }
377
378        // Detect parameter symmetries
379        let mut param_groups: HashMap<usize, Vec<String>> = HashMap::new();
380        for layer in &self.layers {
381            param_groups.entry(layer.parameters).or_default().push(layer.id.clone());
382        }
383
384        for (param_count, layer_ids) in param_groups {
385            if layer_ids.len() > 2 && param_count > 0 {
386                let confidence = layer_ids.len() as f32 / self.layers.len() as f32;
387
388                report.symmetries.push(SymmetryInfo {
389                    symmetry_type: SymmetryType::Permutation,
390                    symmetric_layers: layer_ids.clone(),
391                    confidence,
392                    description: format!(
393                        "Parameter symmetry: {} layers with {} parameters each",
394                        layer_ids.len(),
395                        param_count
396                    ),
397                });
398            }
399        }
400    }
401
402    /// Calculate efficiency metrics
403    fn calculate_efficiency_metrics(&self, report: &mut ArchitectureAnalysisReport) {
404        let total_params = report.total_parameters as f32;
405        let total_flops = report.total_flops as f32;
406        let depth = report.model_depth as f32;
407        let memory = report.model_size_mb;
408
409        // Parameter efficiency: fewer parameters for same capability is better
410        report.efficiency_metrics.parameter_efficiency = if total_params > 0.0 {
411            1.0 / (total_params / 1_000_000.0).log10().max(1.0) // Inverse log scale
412        } else {
413            1.0
414        };
415
416        // FLOPS efficiency: fewer FLOPS for same capability is better
417        report.efficiency_metrics.flops_efficiency = if total_flops > 0.0 {
418            1.0 / (total_flops / 1_000_000_000.0).log10().max(1.0) // Inverse log scale
419        } else {
420            1.0
421        };
422
423        // Memory efficiency: less memory usage is better
424        report.efficiency_metrics.memory_efficiency = if memory > 0.0 {
425            1.0 / (memory / 100.0).log10().max(1.0) // Inverse log scale
426        } else {
427            1.0
428        };
429
430        // Depth efficiency: moderate depth is best (not too shallow, not too deep)
431        report.efficiency_metrics.depth_efficiency = if depth > 0.0 {
432            let optimal_depth = 20.0; // Assumed optimal depth
433            1.0 - ((depth - optimal_depth).abs() / optimal_depth).min(1.0)
434        } else {
435            0.0
436        };
437
438        // Overall efficiency score (weighted average)
439        report.efficiency_metrics.overall_score = 0.3
440            * report.efficiency_metrics.parameter_efficiency
441            + 0.3 * report.efficiency_metrics.flops_efficiency
442            + 0.2 * report.efficiency_metrics.memory_efficiency
443            + 0.2 * report.efficiency_metrics.depth_efficiency;
444    }
445
446    /// Identify potential bottlenecks
447    fn identify_bottlenecks(&self, report: &mut ArchitectureAnalysisReport) {
448        // Find layers with disproportionately high parameter counts
449        if let Some(_max_params) = self.layers.iter().map(|l| l.parameters).max() {
450            let avg_params = report.total_parameters / self.layers.len().max(1);
451
452            for layer in &self.layers {
453                if layer.parameters > avg_params * 5 {
454                    report.bottlenecks.push(format!(
455                        "Parameter bottleneck: Layer '{}' has {} parameters ({}x average)",
456                        layer.name,
457                        layer.parameters,
458                        layer.parameters / avg_params.max(1)
459                    ));
460                }
461            }
462        }
463
464        // Find layers with very large memory usage
465        for layer in &self.layers {
466            if layer.memory_usage > 100 * 1024 * 1024 {
467                // > 100MB
468                report.bottlenecks.push(format!(
469                    "Memory bottleneck: Layer '{}' uses {:.1}MB memory",
470                    layer.name,
471                    layer.memory_usage as f32 / (1024.0 * 1024.0)
472                ));
473            }
474        }
475
476        // Find layers with very high FLOPS
477        if let Some(_max_flops) = self.layers.iter().map(|l| l.flops).max() {
478            let avg_flops = report.total_flops / self.layers.len().max(1) as u64;
479
480            for layer in &self.layers {
481                if layer.flops > avg_flops * 10 {
482                    report.bottlenecks.push(format!(
483                        "Computation bottleneck: Layer '{}' requires {} FLOPS ({}x average)",
484                        layer.name,
485                        layer.flops,
486                        layer.flops / avg_flops.max(1)
487                    ));
488                }
489            }
490        }
491    }
492
493    /// Quick architecture analysis for simplified interface
494    pub async fn quick_analysis(&self) -> Result<crate::QuickArchitectureSummary> {
495        let total_parameters = self.layers.iter().map(|l| l.parameters as u64).sum::<u64>();
496        let total_flops = self.layers.iter().map(|l| l.flops).sum::<u64>();
497
498        // Estimate model size in MB (4 bytes per float32 parameter)
499        let model_size_mb = (total_parameters as f64 * 4.0) / (1024.0 * 1024.0);
500
501        // Calculate efficiency score based on parameters to FLOPS ratio
502        let efficiency_score = if total_flops > 0 {
503            (total_parameters as f64 / total_flops as f64 * 1000.0).min(100.0)
504        } else {
505            50.0
506        };
507
508        let mut recommendations = Vec::new();
509        if total_parameters > 1_000_000_000 {
510            recommendations
511                .push("Consider model compression techniques for large model".to_string());
512        }
513        if efficiency_score < 30.0 {
514            recommendations.push("Model architecture could be more efficient".to_string());
515        }
516        if model_size_mb > 1000.0 {
517            recommendations.push("Large model size may impact deployment".to_string());
518        }
519        if recommendations.is_empty() {
520            recommendations.push("Architecture appears well-balanced".to_string());
521        }
522
523        Ok(crate::QuickArchitectureSummary {
524            total_parameters,
525            model_size_mb,
526            efficiency_score,
527            recommendations,
528        })
529    }
530
531    /// Generate a comprehensive report
532    pub async fn generate_report(&self) -> Result<ArchitectureAnalysisReport> {
533        // Create a temporary clone to avoid mutable borrow issues
534        let mut temp_analyzer = ArchitectureAnalyzer {
535            config: self.config.clone(),
536            layers: self.layers.clone(),
537            connections: self.connections.clone(),
538            analysis_cache: HashMap::new(),
539        };
540
541        temp_analyzer.analyze().await
542    }
543
544    /// Clear all registered layers and connections
545    pub fn clear(&mut self) {
546        self.layers.clear();
547        self.connections.clear();
548        self.analysis_cache.clear();
549    }
550
551    /// Get summary statistics
552    pub fn get_summary(&self) -> ArchitectureSummary {
553        let total_params: usize = self.layers.iter().map(|l| l.parameters).sum();
554        let total_flops: u64 = self.layers.iter().map(|l| l.flops).sum();
555
556        ArchitectureSummary {
557            total_layers: self.layers.len(),
558            total_parameters: total_params,
559            total_flops,
560            average_layer_size: if !self.layers.is_empty() {
561                total_params / self.layers.len()
562            } else {
563                0
564            },
565            layer_type_distribution: {
566                let mut dist = HashMap::new();
567                for layer in &self.layers {
568                    *dist.entry(layer.layer_type.clone()).or_insert(0) += 1;
569                }
570                dist
571            },
572        }
573    }
574}
575
576/// Summary statistics for architecture
577#[derive(Debug, Clone, Serialize, Deserialize)]
578pub struct ArchitectureSummary {
579    pub total_layers: usize,
580    pub total_parameters: usize,
581    pub total_flops: u64,
582    pub average_layer_size: usize,
583    pub layer_type_distribution: HashMap<LayerType, usize>,
584}
585
586/// Convenience function to create a layer info
587pub fn create_layer_info(
588    id: String,
589    name: String,
590    layer_type: LayerType,
591    input_shape: Vec<usize>,
592    output_shape: Vec<usize>,
593    parameters: usize,
594) -> LayerInfo {
595    let memory_usage = parameters * 4; // 4 bytes per float32
596    let flops = estimate_flops(&layer_type, &input_shape, &output_shape, parameters);
597
598    LayerInfo {
599        id,
600        name,
601        layer_type,
602        input_shape,
603        output_shape,
604        parameters,
605        trainable_parameters: parameters, // Assume all parameters are trainable by default
606        memory_usage,
607        flops,
608        receptive_field: None,
609    }
610}
611
612/// Estimate FLOPS for a layer
613fn estimate_flops(
614    layer_type: &LayerType,
615    input_shape: &[usize],
616    output_shape: &[usize],
617    parameters: usize,
618) -> u64 {
619    match layer_type {
620        LayerType::Linear => {
621            // Matrix multiplication: input_features * output_features * batch_size
622            if input_shape.len() >= 2 && output_shape.len() >= 2 {
623                let batch_size = input_shape[0] as u64;
624                let input_features = input_shape[1] as u64;
625                let output_features = output_shape[1] as u64;
626                batch_size * input_features * output_features * 2 // Multiply-add
627            } else {
628                parameters as u64 * 2
629            }
630        },
631        LayerType::Conv2D => {
632            // Convolution: output_h * output_w * kernel_h * kernel_w * input_channels * output_channels
633            if output_shape.len() >= 4 {
634                let batch_size = output_shape[0] as u64;
635                let output_channels = output_shape[1] as u64;
636                let output_h = output_shape[2] as u64;
637                let output_w = output_shape[3] as u64;
638                batch_size
639                    * output_channels
640                    * output_h
641                    * output_w
642                    * (parameters as u64 / output_channels).max(1)
643                    * 2
644            } else {
645                parameters as u64 * 2
646            }
647        },
648        LayerType::Attention => {
649            // Attention: roughly O(sequence_length^2 * hidden_size)
650            if input_shape.len() >= 3 {
651                let batch_size = input_shape[0] as u64;
652                let seq_len = input_shape[1] as u64;
653                let hidden_size = input_shape[2] as u64;
654                batch_size * seq_len * seq_len * hidden_size * 4 // Q, K, V, output projections
655            } else {
656                parameters as u64 * 4
657            }
658        },
659        _ => {
660            // For other layers, estimate based on parameters
661            parameters as u64
662        },
663    }
664}