// trustformers_debug/architecture_analysis.rs

1//! Architecture Analysis
2//!
3//! Comprehensive analysis tools for neural network architectures including
4//! parameter counting, receptive field calculation, and connectivity analysis.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Configuration for architecture analysis
///
/// Each flag toggles one analysis pass in [`ArchitectureAnalyzer::analyze`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureAnalysisConfig {
    /// Enable parameter counting
    pub enable_parameter_counting: bool,
    /// Enable receptive field calculation
    pub enable_receptive_field_calculation: bool,
    /// Enable depth/width analysis
    pub enable_depth_width_analysis: bool,
    /// Enable connectivity pattern detection
    pub enable_connectivity_patterns: bool,
    /// Enable symmetry detection
    pub enable_symmetry_detection: bool,
    /// Maximum depth to analyze for receptive fields
    /// NOTE(review): not currently consulted by `ArchitectureAnalyzer::analyze` in this file
    pub max_receptive_field_depth: usize,
    /// Sampling rate for large models (0.0 to 1.0)
    /// NOTE(review): not currently consulted by `ArchitectureAnalyzer::analyze` in this file
    pub sampling_rate: f32,
}
28
impl Default for ArchitectureAnalysisConfig {
    /// Defaults: every analysis pass enabled, receptive-field depth capped
    /// at 50, and no sampling (rate 1.0 = analyze everything).
    fn default() -> Self {
        Self {
            enable_parameter_counting: true,
            enable_receptive_field_calculation: true,
            enable_depth_width_analysis: true,
            enable_connectivity_patterns: true,
            enable_symmetry_detection: true,
            max_receptive_field_depth: 50,
            sampling_rate: 1.0,
        }
    }
}
42
/// Layer type for architecture analysis
///
/// Coarse classification used as the key for parameter distributions and to
/// drive receptive-field and FLOPS estimation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum LayerType {
    /// Fully-connected (dense) layer
    Linear,
    /// 2D convolution
    Conv2D,
    /// 3D convolution
    Conv3D,
    /// Batch normalization
    BatchNorm,
    /// Layer normalization
    LayerNorm,
    /// Attention layer
    Attention,
    /// Embedding lookup
    Embedding,
    /// Dropout layer
    Dropout,
    /// Standalone activation function
    Activation,
    /// Pooling layer
    Pooling,
    /// Residual/skip block
    Residual,
    /// Anything not covered by the variants above
    Unknown,
}
59
/// Information about a single layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerInfo {
    /// Unique identifier for the layer
    pub id: String,
    /// Human-readable layer name
    pub name: String,
    /// Coarse layer classification
    pub layer_type: LayerType,
    /// Input tensor shape (interpretation depends on the layer type;
    /// `estimate_flops` reads it as e.g. [batch, features] for Linear)
    pub input_shape: Vec<usize>,
    /// Output tensor shape
    pub output_shape: Vec<usize>,
    /// Total parameter count
    pub parameters: usize,
    /// Parameters updated during training (assumed equal to `parameters`
    /// when built via `create_layer_info`)
    pub trainable_parameters: usize,
    /// Estimated memory footprint in bytes (4 bytes per f32 parameter
    /// when built via `create_layer_info`)
    pub memory_usage: usize,
    /// Estimated floating-point operations (see `estimate_flops`)
    pub flops: u64,
    /// Receptive field, filled in for conv layers during analysis
    pub receptive_field: Option<ReceptiveField>,
}
74
/// Receptive field information
///
/// All vectors hold one entry per spatial dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReceptiveField {
    /// Kernel size per spatial dimension
    pub size: Vec<usize>,
    /// Stride per spatial dimension
    pub stride: Vec<usize>,
    /// Padding per spatial dimension
    pub padding: Vec<usize>,
    /// Effective receptive-field size per spatial dimension
    pub effective_size: Vec<usize>,
}
83
/// Connectivity pattern between layers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectivityPattern {
    /// Source layer identifier
    pub from_layer: String,
    /// Destination layer identifier
    pub to_layer: String,
    /// Kind of connection (sequential, residual, ...)
    pub connection_type: ConnectionType,
    /// Connection strength — semantics defined by the caller that registers
    /// the pattern; not interpreted by the analyzer in this file
    pub strength: f32,
}
92
/// Kind of inter-layer connection tracked by the analyzer.
#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub enum ConnectionType {
    /// Plain feed-forward connection to the next layer
    Sequential,
    /// Residual (additive) connection
    Residual,
    /// Attention-based connection
    Attention,
    /// Skip connection bypassing one or more layers
    Skip,
    /// Recurrent (backward-in-time) connection
    Recurrent,
    /// One layer feeding multiple downstream branches
    Branching,
}
102
/// Symmetry information in the architecture
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SymmetryInfo {
    /// Which kind of symmetry was detected
    pub symmetry_type: SymmetryType,
    /// Identifiers of the layers/blocks participating in the symmetry
    pub symmetric_layers: Vec<String>,
    /// Fraction of the model covered (participating count / total layers)
    pub confidence: f32,
    /// Human-readable description of the finding
    pub description: String,
}
111
/// Kinds of architectural symmetry the detector can report.
///
/// Only `Block` (repeated layer-type runs) and `Permutation` (equal
/// parameter counts) are produced by `detect_symmetries` in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SymmetryType {
    /// Translational symmetry
    Translational,
    /// Rotational symmetry
    Rotational,
    /// Reflection symmetry
    Reflection,
    /// Layers interchangeable without changing the model (equal shapes/params)
    Permutation,
    /// Repeated block pattern (e.g. stacked transformer blocks)
    Block,
}
120
/// Architecture analysis results
///
/// Produced by [`ArchitectureAnalyzer::analyze`]; fields for disabled
/// passes keep their zero/empty defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureAnalysisReport {
    /// Sum of parameters across all layers
    pub total_parameters: usize,
    /// Sum of trainable parameters across all layers
    pub trainable_parameters: usize,
    /// Estimated size assuming 4 bytes per f32 parameter
    pub model_size_mb: f32,
    /// Sum of per-layer FLOPS estimates
    pub total_flops: u64,
    /// Number of registered layers
    pub model_depth: usize,
    /// Largest single-layer parameter count
    pub model_width: usize,
    /// Snapshot of the registered layers (with receptive fields filled in)
    pub layers: Vec<LayerInfo>,
    /// Snapshot of the registered connections
    pub connectivity_patterns: Vec<ConnectivityPattern>,
    /// Detected block/permutation symmetries
    pub symmetries: Vec<SymmetryInfo>,
    /// Parameters grouped by layer type
    pub parameter_distribution: HashMap<LayerType, usize>,
    /// Human-readable bottleneck warnings
    pub bottlenecks: Vec<String>,
    /// Heuristic efficiency scores
    pub efficiency_metrics: EfficiencyMetrics,
}
137
/// Model efficiency metrics
///
/// Heuristic scores computed by `calculate_efficiency_metrics`; each
/// sub-score is on an inverse-log (or distance-to-optimum) scale.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficiencyMetrics {
    /// Inverse-log score of total parameter count (smaller models score higher)
    pub parameter_efficiency: f32,
    /// Inverse-log score of total FLOPS
    pub flops_efficiency: f32,
    /// Inverse-log score of estimated memory footprint
    pub memory_efficiency: f32,
    /// Distance-from-assumed-optimal-depth score (optimum hard-coded to 20)
    pub depth_efficiency: f32,
    /// Weighted average of the four scores above (0.3/0.3/0.2/0.2)
    pub overall_score: f32,
}
147
/// Architecture analyzer
///
/// Collects layers and connections via `register_layer`/`add_connection`,
/// then produces an [`ArchitectureAnalysisReport`] on demand.
#[derive(Debug)]
pub struct ArchitectureAnalyzer {
    /// Which analysis passes to run
    config: ArchitectureAnalysisConfig,
    /// Registered layers, in registration order
    layers: Vec<LayerInfo>,
    /// Registered inter-layer connections
    connections: Vec<ConnectivityPattern>,
    /// Cache keyed by analysis id.
    /// NOTE(review): never written to by any method in this file — only
    /// created empty and cleared; confirm whether callers rely on it.
    analysis_cache: HashMap<String, ArchitectureAnalysisReport>,
}
156
157impl ArchitectureAnalyzer {
158    /// Create a new architecture analyzer
159    pub fn new(config: ArchitectureAnalysisConfig) -> Self {
160        Self {
161            config,
162            layers: Vec::new(),
163            connections: Vec::new(),
164            analysis_cache: HashMap::new(),
165        }
166    }
167
    /// Register a layer for analysis
    ///
    /// Layers are kept in registration order, which is what
    /// `analyze_depth_width` and `detect_symmetries` treat as the
    /// sequential order of the model.
    pub fn register_layer(&mut self, layer: LayerInfo) {
        self.layers.push(layer);
    }
172
    /// Add a connection between layers
    ///
    /// Connections feed `analyze_connectivity_patterns`; the analyzer does
    /// not validate that `from_layer`/`to_layer` match registered layer ids.
    pub fn add_connection(&mut self, pattern: ConnectivityPattern) {
        self.connections.push(pattern);
    }
177
178    /// Analyze the registered architecture
179    pub async fn analyze(&mut self) -> Result<ArchitectureAnalysisReport> {
180        let mut report = ArchitectureAnalysisReport {
181            total_parameters: 0,
182            trainable_parameters: 0,
183            model_size_mb: 0.0,
184            total_flops: 0,
185            model_depth: 0,
186            model_width: 0,
187            layers: self.layers.clone(),
188            connectivity_patterns: self.connections.clone(),
189            symmetries: Vec::new(),
190            parameter_distribution: HashMap::new(),
191            bottlenecks: Vec::new(),
192            efficiency_metrics: EfficiencyMetrics {
193                parameter_efficiency: 0.0,
194                flops_efficiency: 0.0,
195                memory_efficiency: 0.0,
196                depth_efficiency: 0.0,
197                overall_score: 0.0,
198            },
199        };
200
201        if self.config.enable_parameter_counting {
202            self.count_parameters(&mut report);
203        }
204
205        if self.config.enable_receptive_field_calculation {
206            self.calculate_receptive_fields(&mut report).await?;
207        }
208
209        if self.config.enable_depth_width_analysis {
210            self.analyze_depth_width(&mut report);
211        }
212
213        if self.config.enable_connectivity_patterns {
214            self.analyze_connectivity_patterns(&mut report);
215        }
216
217        if self.config.enable_symmetry_detection {
218            self.detect_symmetries(&mut report);
219        }
220
221        self.calculate_efficiency_metrics(&mut report);
222        self.identify_bottlenecks(&mut report);
223
224        Ok(report)
225    }
226
227    /// Count parameters in all layers
228    fn count_parameters(&self, report: &mut ArchitectureAnalysisReport) {
229        let mut param_distribution: HashMap<LayerType, usize> = HashMap::new();
230
231        for layer in &self.layers {
232            report.total_parameters += layer.parameters;
233            report.trainable_parameters += layer.trainable_parameters;
234
235            *param_distribution.entry(layer.layer_type.clone()).or_insert(0) += layer.parameters;
236        }
237
238        report.parameter_distribution = param_distribution;
239
240        // Estimate model size (4 bytes per float32 parameter)
241        report.model_size_mb = (report.total_parameters * 4) as f32 / (1024.0 * 1024.0);
242
243        // Calculate total FLOPS
244        report.total_flops = self.layers.iter().map(|l| l.flops).sum();
245    }
246
247    /// Calculate receptive fields for convolutional layers
248    async fn calculate_receptive_fields(
249        &mut self,
250        report: &mut ArchitectureAnalysisReport,
251    ) -> Result<()> {
252        for layer in &mut self.layers {
253            if matches!(layer.layer_type, LayerType::Conv2D | LayerType::Conv3D) {
254                layer.receptive_field =
255                    Some(Self::compute_receptive_field_static(&layer.layer_type));
256            }
257        }
258
259        report.layers = self.layers.clone();
260        Ok(())
261    }
262
263    /// Compute receptive field for a convolutional layer (static version)
264    fn compute_receptive_field_static(layer_type: &LayerType) -> ReceptiveField {
265        match layer_type {
266            LayerType::Conv2D => {
267                // Simple 2D convolution receptive field calculation
268                let kernel_size = vec![3, 3]; // Default 3x3 kernel
269                let stride = vec![1, 1];
270                let padding = vec![1, 1];
271
272                ReceptiveField {
273                    size: kernel_size.clone(),
274                    stride,
275                    padding,
276                    effective_size: kernel_size,
277                }
278            },
279            LayerType::Conv3D => {
280                // Simple 3D convolution receptive field calculation
281                let kernel_size = vec![3, 3, 3]; // Default 3x3x3 kernel
282                let stride = vec![1, 1, 1];
283                let padding = vec![1, 1, 1];
284
285                ReceptiveField {
286                    size: kernel_size.clone(),
287                    stride,
288                    padding,
289                    effective_size: kernel_size,
290                }
291            },
292            _ => {
293                // For non-conv layers, receptive field is 1
294                ReceptiveField {
295                    size: vec![1],
296                    stride: vec![1],
297                    padding: vec![0],
298                    effective_size: vec![1],
299                }
300            },
301        }
302    }
303
    /// Compute receptive field for a convolutional layer
    ///
    /// Thin instance wrapper around [`Self::compute_receptive_field_static`];
    /// currently unused within this file (hence the `dead_code` allow).
    #[allow(dead_code)]
    fn compute_receptive_field(&self, layer: &LayerInfo) -> ReceptiveField {
        Self::compute_receptive_field_static(&layer.layer_type)
    }
309
310    /// Analyze model depth and width
311    fn analyze_depth_width(&self, report: &mut ArchitectureAnalysisReport) {
312        // Calculate depth (number of sequential layers)
313        report.model_depth = self.layers.len();
314
315        // Calculate width (maximum number of parameters in a single layer)
316        report.model_width = self.layers.iter().map(|l| l.parameters).max().unwrap_or(0);
317    }
318
319    /// Analyze connectivity patterns
320    fn analyze_connectivity_patterns(&self, report: &mut ArchitectureAnalysisReport) {
321        let mut pattern_types: HashMap<ConnectionType, usize> = HashMap::new();
322
323        for connection in &self.connections {
324            *pattern_types.entry(connection.connection_type.clone()).or_insert(0) += 1;
325        }
326
327        // Find unusual connectivity patterns
328        for (connection_type, count) in pattern_types {
329            if count > self.layers.len() / 2 {
330                // High connectivity, might indicate bottlenecks
331                report.bottlenecks.push(format!(
332                    "High {:?} connectivity: {} connections",
333                    connection_type, count
334                ));
335            }
336        }
337    }
338
339    /// Detect architectural symmetries
340    fn detect_symmetries(&self, report: &mut ArchitectureAnalysisReport) {
341        // Detect block symmetries (repeated layer patterns)
342        let mut block_patterns: HashMap<Vec<LayerType>, Vec<usize>> = HashMap::new();
343
344        // Look for patterns of 2-5 consecutive layers
345        for window_size in 2..=5.min(self.layers.len()) {
346            for i in 0..=(self.layers.len() - window_size) {
347                let pattern: Vec<LayerType> =
348                    self.layers[i..i + window_size].iter().map(|l| l.layer_type.clone()).collect();
349
350                block_patterns.entry(pattern).or_insert_with(Vec::new).push(i);
351            }
352        }
353
354        // Find repeated patterns
355        for (pattern, positions) in block_patterns {
356            if positions.len() > 1 {
357                let confidence = positions.len() as f32 / self.layers.len() as f32;
358
359                if confidence > 0.1 {
360                    // At least 10% of the model
361                    report.symmetries.push(SymmetryInfo {
362                        symmetry_type: SymmetryType::Block,
363                        symmetric_layers: positions
364                            .iter()
365                            .map(|&i| format!("block_{}", i))
366                            .collect(),
367                        confidence,
368                        description: format!(
369                            "Repeated block pattern: {:?} appears {} times",
370                            pattern,
371                            positions.len()
372                        ),
373                    });
374                }
375            }
376        }
377
378        // Detect parameter symmetries
379        let mut param_groups: HashMap<usize, Vec<String>> = HashMap::new();
380        for layer in &self.layers {
381            param_groups
382                .entry(layer.parameters)
383                .or_insert_with(Vec::new)
384                .push(layer.id.clone());
385        }
386
387        for (param_count, layer_ids) in param_groups {
388            if layer_ids.len() > 2 && param_count > 0 {
389                let confidence = layer_ids.len() as f32 / self.layers.len() as f32;
390
391                report.symmetries.push(SymmetryInfo {
392                    symmetry_type: SymmetryType::Permutation,
393                    symmetric_layers: layer_ids.clone(),
394                    confidence,
395                    description: format!(
396                        "Parameter symmetry: {} layers with {} parameters each",
397                        layer_ids.len(),
398                        param_count
399                    ),
400                });
401            }
402        }
403    }
404
405    /// Calculate efficiency metrics
406    fn calculate_efficiency_metrics(&self, report: &mut ArchitectureAnalysisReport) {
407        let total_params = report.total_parameters as f32;
408        let total_flops = report.total_flops as f32;
409        let depth = report.model_depth as f32;
410        let memory = report.model_size_mb;
411
412        // Parameter efficiency: fewer parameters for same capability is better
413        report.efficiency_metrics.parameter_efficiency = if total_params > 0.0 {
414            1.0 / (total_params / 1_000_000.0).log10().max(1.0) // Inverse log scale
415        } else {
416            1.0
417        };
418
419        // FLOPS efficiency: fewer FLOPS for same capability is better
420        report.efficiency_metrics.flops_efficiency = if total_flops > 0.0 {
421            1.0 / (total_flops / 1_000_000_000.0).log10().max(1.0) // Inverse log scale
422        } else {
423            1.0
424        };
425
426        // Memory efficiency: less memory usage is better
427        report.efficiency_metrics.memory_efficiency = if memory > 0.0 {
428            1.0 / (memory / 100.0).log10().max(1.0) // Inverse log scale
429        } else {
430            1.0
431        };
432
433        // Depth efficiency: moderate depth is best (not too shallow, not too deep)
434        report.efficiency_metrics.depth_efficiency = if depth > 0.0 {
435            let optimal_depth = 20.0; // Assumed optimal depth
436            1.0 - ((depth - optimal_depth).abs() / optimal_depth).min(1.0)
437        } else {
438            0.0
439        };
440
441        // Overall efficiency score (weighted average)
442        report.efficiency_metrics.overall_score = 0.3
443            * report.efficiency_metrics.parameter_efficiency
444            + 0.3 * report.efficiency_metrics.flops_efficiency
445            + 0.2 * report.efficiency_metrics.memory_efficiency
446            + 0.2 * report.efficiency_metrics.depth_efficiency;
447    }
448
449    /// Identify potential bottlenecks
450    fn identify_bottlenecks(&self, report: &mut ArchitectureAnalysisReport) {
451        // Find layers with disproportionately high parameter counts
452        if let Some(_max_params) = self.layers.iter().map(|l| l.parameters).max() {
453            let avg_params = report.total_parameters / self.layers.len().max(1);
454
455            for layer in &self.layers {
456                if layer.parameters > avg_params * 5 {
457                    report.bottlenecks.push(format!(
458                        "Parameter bottleneck: Layer '{}' has {} parameters ({}x average)",
459                        layer.name,
460                        layer.parameters,
461                        layer.parameters / avg_params.max(1)
462                    ));
463                }
464            }
465        }
466
467        // Find layers with very large memory usage
468        for layer in &self.layers {
469            if layer.memory_usage > 100 * 1024 * 1024 {
470                // > 100MB
471                report.bottlenecks.push(format!(
472                    "Memory bottleneck: Layer '{}' uses {:.1}MB memory",
473                    layer.name,
474                    layer.memory_usage as f32 / (1024.0 * 1024.0)
475                ));
476            }
477        }
478
479        // Find layers with very high FLOPS
480        if let Some(_max_flops) = self.layers.iter().map(|l| l.flops).max() {
481            let avg_flops = report.total_flops / self.layers.len().max(1) as u64;
482
483            for layer in &self.layers {
484                if layer.flops > avg_flops * 10 {
485                    report.bottlenecks.push(format!(
486                        "Computation bottleneck: Layer '{}' requires {} FLOPS ({}x average)",
487                        layer.name,
488                        layer.flops,
489                        layer.flops / avg_flops.max(1)
490                    ));
491                }
492            }
493        }
494    }
495
496    /// Quick architecture analysis for simplified interface
497    pub async fn quick_analysis(&self) -> Result<crate::QuickArchitectureSummary> {
498        let total_parameters = self.layers.iter().map(|l| l.parameters as u64).sum::<u64>();
499        let total_flops = self.layers.iter().map(|l| l.flops).sum::<u64>();
500
501        // Estimate model size in MB (4 bytes per float32 parameter)
502        let model_size_mb = (total_parameters as f64 * 4.0) / (1024.0 * 1024.0);
503
504        // Calculate efficiency score based on parameters to FLOPS ratio
505        let efficiency_score = if total_flops > 0 {
506            (total_parameters as f64 / total_flops as f64 * 1000.0).min(100.0)
507        } else {
508            50.0
509        };
510
511        let mut recommendations = Vec::new();
512        if total_parameters > 1_000_000_000 {
513            recommendations
514                .push("Consider model compression techniques for large model".to_string());
515        }
516        if efficiency_score < 30.0 {
517            recommendations.push("Model architecture could be more efficient".to_string());
518        }
519        if model_size_mb > 1000.0 {
520            recommendations.push("Large model size may impact deployment".to_string());
521        }
522        if recommendations.is_empty() {
523            recommendations.push("Architecture appears well-balanced".to_string());
524        }
525
526        Ok(crate::QuickArchitectureSummary {
527            total_parameters,
528            model_size_mb,
529            efficiency_score,
530            recommendations,
531        })
532    }
533
    /// Generate a comprehensive report
    ///
    /// Runs the full `analyze` pipeline on a throwaway copy of this
    /// analyzer so the method can take `&self` (`analyze` needs `&mut`);
    /// the clone — and its empty cache — is discarded afterwards.
    pub async fn generate_report(&self) -> Result<ArchitectureAnalysisReport> {
        // Create a temporary clone to avoid mutable borrow issues
        let mut temp_analyzer = ArchitectureAnalyzer {
            config: self.config.clone(),
            layers: self.layers.clone(),
            connections: self.connections.clone(),
            analysis_cache: HashMap::new(),
        };

        temp_analyzer.analyze().await
    }
546
    /// Clear all registered layers and connections
    ///
    /// Also drops any cached analysis results; the analyzer returns to the
    /// same empty state produced by `new`.
    pub fn clear(&mut self) {
        self.layers.clear();
        self.connections.clear();
        self.analysis_cache.clear();
    }
553
554    /// Get summary statistics
555    pub fn get_summary(&self) -> ArchitectureSummary {
556        let total_params: usize = self.layers.iter().map(|l| l.parameters).sum();
557        let total_flops: u64 = self.layers.iter().map(|l| l.flops).sum();
558
559        ArchitectureSummary {
560            total_layers: self.layers.len(),
561            total_parameters: total_params,
562            total_flops,
563            average_layer_size: if !self.layers.is_empty() {
564                total_params / self.layers.len()
565            } else {
566                0
567            },
568            layer_type_distribution: {
569                let mut dist = HashMap::new();
570                for layer in &self.layers {
571                    *dist.entry(layer.layer_type.clone()).or_insert(0) += 1;
572                }
573                dist
574            },
575        }
576    }
577}
578
/// Summary statistics for architecture
///
/// Lightweight aggregate produced by [`ArchitectureAnalyzer::get_summary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSummary {
    /// Number of registered layers
    pub total_layers: usize,
    /// Sum of parameters across all layers
    pub total_parameters: usize,
    /// Sum of per-layer FLOPS estimates
    pub total_flops: u64,
    /// Integer mean of parameters per layer (0 when there are no layers)
    pub average_layer_size: usize,
    /// Number of layers of each type
    pub layer_type_distribution: HashMap<LayerType, usize>,
}
588
589/// Convenience function to create a layer info
590pub fn create_layer_info(
591    id: String,
592    name: String,
593    layer_type: LayerType,
594    input_shape: Vec<usize>,
595    output_shape: Vec<usize>,
596    parameters: usize,
597) -> LayerInfo {
598    let memory_usage = parameters * 4; // 4 bytes per float32
599    let flops = estimate_flops(&layer_type, &input_shape, &output_shape, parameters);
600
601    LayerInfo {
602        id,
603        name,
604        layer_type,
605        input_shape,
606        output_shape,
607        parameters,
608        trainable_parameters: parameters, // Assume all parameters are trainable by default
609        memory_usage,
610        flops,
611        receptive_field: None,
612    }
613}
614
615/// Estimate FLOPS for a layer
616fn estimate_flops(
617    layer_type: &LayerType,
618    input_shape: &[usize],
619    output_shape: &[usize],
620    parameters: usize,
621) -> u64 {
622    match layer_type {
623        LayerType::Linear => {
624            // Matrix multiplication: input_features * output_features * batch_size
625            if input_shape.len() >= 2 && output_shape.len() >= 2 {
626                let batch_size = input_shape[0] as u64;
627                let input_features = input_shape[1] as u64;
628                let output_features = output_shape[1] as u64;
629                batch_size * input_features * output_features * 2 // Multiply-add
630            } else {
631                parameters as u64 * 2
632            }
633        },
634        LayerType::Conv2D => {
635            // Convolution: output_h * output_w * kernel_h * kernel_w * input_channels * output_channels
636            if output_shape.len() >= 4 {
637                let batch_size = output_shape[0] as u64;
638                let output_channels = output_shape[1] as u64;
639                let output_h = output_shape[2] as u64;
640                let output_w = output_shape[3] as u64;
641                batch_size
642                    * output_channels
643                    * output_h
644                    * output_w
645                    * (parameters as u64 / output_channels).max(1)
646                    * 2
647            } else {
648                parameters as u64 * 2
649            }
650        },
651        LayerType::Attention => {
652            // Attention: roughly O(sequence_length^2 * hidden_size)
653            if input_shape.len() >= 3 {
654                let batch_size = input_shape[0] as u64;
655                let seq_len = input_shape[1] as u64;
656                let hidden_size = input_shape[2] as u64;
657                batch_size * seq_len * seq_len * hidden_size * 4 // Q, K, V, output projections
658            } else {
659                parameters as u64 * 4
660            }
661        },
662        _ => {
663            // For other layers, estimate based on parameters
664            parameters as u64
665        },
666    }
667}