// trustformers_debug/model_diagnostics/architecture.rs
//! Model architecture analysis and optimization.
//!
//! This module provides comprehensive model architecture analysis including
//! parameter efficiency assessment, computational complexity analysis,
//! memory efficiency evaluation, and architecture optimization recommendations.
use anyhow::Result;
use std::collections::HashMap;

use super::types::{ArchitecturalAnalysis, ModelArchitectureInfo};
12/// Architecture analyzer for evaluating model design and efficiency.
13#[derive(Debug)]
14pub struct ArchitectureAnalyzer {
15    /// Current architecture information
16    architecture_info: Option<ModelArchitectureInfo>,
17    /// Analysis configuration
18    config: ArchitectureAnalysisConfig,
19}
20
/// Configuration for architecture analysis.
#[derive(Debug, Clone)]
pub struct ArchitectureAnalysisConfig {
    /// Target parameter efficiency threshold (0.0 to 1.0).
    pub target_parameter_efficiency: f64,
    /// Target memory efficiency threshold (0.0 to 1.0).
    pub target_memory_efficiency: f64,
    /// Maximum acceptable model size in megabytes.
    pub max_model_size_mb: f64,
    /// Layer types favored when generating optimization recommendations.
    pub preferred_layer_types: Vec<String>,
}

impl Default for ArchitectureAnalysisConfig {
    /// Reasonable defaults: 70% parameter efficiency, 80% memory efficiency,
    /// a 1 GB size budget, and a preference for attention/linear/norm layers.
    fn default() -> Self {
        Self {
            target_parameter_efficiency: 0.7,
            target_memory_efficiency: 0.8,
            max_model_size_mb: 1024.0, // 1 GB budget
            preferred_layer_types: ["Attention", "Linear", "Normalization"]
                .iter()
                .map(|s| s.to_string())
                .collect(),
        }
    }
}
49impl ArchitectureAnalyzer {
50    /// Create a new architecture analyzer.
51    pub fn new() -> Self {
52        Self {
53            architecture_info: None,
54            config: ArchitectureAnalysisConfig::default(),
55        }
56    }
57
58    /// Create a new architecture analyzer with custom configuration.
59    pub fn with_config(config: ArchitectureAnalysisConfig) -> Self {
60        Self {
61            architecture_info: None,
62            config,
63        }
64    }
65
66    /// Record architecture information.
67    pub fn record_architecture(&mut self, arch_info: ModelArchitectureInfo) {
68        self.architecture_info = Some(arch_info);
69    }
70
71    /// Get current architecture information.
72    pub fn get_architecture_info(&self) -> Option<&ModelArchitectureInfo> {
73        self.architecture_info.as_ref()
74    }
75
76    /// Perform comprehensive architecture analysis.
77    pub fn analyze_architecture(&self) -> Result<ArchitecturalAnalysis> {
78        let arch_info = self
79            .architecture_info
80            .as_ref()
81            .ok_or_else(|| anyhow::anyhow!("No architecture information available"))?;
82
83        let parameter_efficiency = self.calculate_parameter_efficiency(arch_info);
84        let computational_complexity = self.assess_computational_complexity(arch_info);
85        let memory_efficiency = self.calculate_memory_efficiency(arch_info);
86        let recommendations = self.generate_architecture_recommendations(arch_info);
87        let bottlenecks = self.identify_architectural_bottlenecks(arch_info);
88
89        Ok(ArchitecturalAnalysis {
90            parameter_efficiency,
91            computational_complexity,
92            memory_efficiency,
93            recommendations,
94            bottlenecks,
95        })
96    }
97
98    /// Calculate parameter efficiency score.
99    pub fn calculate_parameter_efficiency(&self, arch_info: &ModelArchitectureInfo) -> f64 {
100        if arch_info.total_parameters == 0 {
101            return 0.0;
102        }
103
104        let trainable_ratio =
105            arch_info.trainable_parameters as f64 / arch_info.total_parameters as f64;
106        let size_penalty = if arch_info.model_size_mb > self.config.max_model_size_mb {
107            0.8 // Penalize oversized models
108        } else {
109            1.0
110        };
111
112        // Consider layer type distribution
113        let layer_efficiency = self.calculate_layer_type_efficiency(arch_info);
114
115        (trainable_ratio * size_penalty * layer_efficiency).min(1.0)
116    }
117
118    /// Assess computational complexity of the architecture.
119    pub fn assess_computational_complexity(&self, arch_info: &ModelArchitectureInfo) -> String {
120        let param_count = arch_info.total_parameters;
121        let depth = arch_info.depth;
122        let width = arch_info.width;
123
124        // Estimate computational complexity based on parameters and architecture
125        let complexity_score = (param_count as f64).log10() + (depth as f64 * width as f64).log10();
126
127        match complexity_score {
128            x if x < 8.0 => "Low".to_string(),
129            x if x < 10.0 => "Medium".to_string(),
130            x if x < 12.0 => "High".to_string(),
131            _ => "Very High".to_string(),
132        }
133    }
134
135    /// Calculate memory efficiency score.
136    pub fn calculate_memory_efficiency(&self, arch_info: &ModelArchitectureInfo) -> f64 {
137        if arch_info.model_size_mb == 0.0 {
138            return 0.0;
139        }
140
141        // Theoretical minimum memory based on parameters
142        let theoretical_min_mb = (arch_info.total_parameters as f64 * 4.0) / (1024.0 * 1024.0); // 4 bytes per float32
143        let efficiency = theoretical_min_mb / arch_info.model_size_mb;
144
145        // Factor in layer organization efficiency
146        let layer_organization_bonus = self.calculate_layer_organization_efficiency(arch_info);
147
148        (efficiency * layer_organization_bonus).min(1.0)
149    }
150
151    /// Generate architecture optimization recommendations.
152    pub fn generate_architecture_recommendations(
153        &self,
154        arch_info: &ModelArchitectureInfo,
155    ) -> Vec<String> {
156        let mut recommendations = Vec::new();
157
158        // Parameter efficiency recommendations
159        let param_efficiency = self.calculate_parameter_efficiency(arch_info);
160        if param_efficiency < self.config.target_parameter_efficiency {
161            recommendations.push(
162                "Consider reducing model size or improving parameter utilization".to_string(),
163            );
164            recommendations.push("Evaluate layer pruning opportunities".to_string());
165        }
166
167        // Memory efficiency recommendations
168        let memory_efficiency = self.calculate_memory_efficiency(arch_info);
169        if memory_efficiency < self.config.target_memory_efficiency {
170            recommendations.push("Consider weight quantization to reduce memory usage".to_string());
171            recommendations.push("Evaluate model compression techniques".to_string());
172        }
173
174        // Model size recommendations
175        if arch_info.model_size_mb > self.config.max_model_size_mb {
176            recommendations.push("Model size exceeds recommended limits".to_string());
177            recommendations.push("Consider architectural changes to reduce model size".to_string());
178        }
179
180        // Layer type recommendations
181        let layer_recommendations = self.analyze_layer_type_distribution(arch_info);
182        recommendations.extend(layer_recommendations);
183
184        // Depth and width recommendations
185        if arch_info.depth > 50 {
186            recommendations
187                .push("Very deep model detected - consider residual connections".to_string());
188        }
189
190        if arch_info.width > 4096 {
191            recommendations
192                .push("Very wide model detected - consider factorization techniques".to_string());
193        }
194
195        recommendations
196    }
197
198    /// Identify architectural bottlenecks.
199    pub fn identify_architectural_bottlenecks(
200        &self,
201        arch_info: &ModelArchitectureInfo,
202    ) -> Vec<String> {
203        let mut bottlenecks = Vec::new();
204
205        // Check for imbalanced layer distribution
206        if let Some(dominant_layer) = self.find_dominant_layer_type(arch_info) {
207            if arch_info.layer_types.get(&dominant_layer).unwrap_or(&0)
208                > &(arch_info.layer_count / 2)
209            {
210                bottlenecks.push(format!("Over-reliance on {} layers", dominant_layer));
211            }
212        }
213
214        // Check for activation function bottlenecks
215        if let Some(dominant_activation) = self.find_dominant_activation(arch_info) {
216            if arch_info.activation_functions.get(&dominant_activation).unwrap_or(&0)
217                > &(arch_info.layer_count * 3 / 4)
218            {
219                bottlenecks.push(format!(
220                    "Limited activation function diversity: {} dominates",
221                    dominant_activation
222                ));
223            }
224        }
225
226        // Check for depth/width imbalance
227        let aspect_ratio = arch_info.depth as f64 / arch_info.width as f64;
228        if aspect_ratio > 0.1 {
229            bottlenecks.push("Model may be too deep relative to width".to_string());
230        } else if aspect_ratio < 0.001 {
231            bottlenecks.push("Model may be too wide relative to depth".to_string());
232        }
233
234        // Check for parameter distribution
235        let params_per_layer = arch_info.total_parameters as f64 / arch_info.layer_count as f64;
236        if params_per_layer > 1_000_000.0 {
237            bottlenecks.push("High parameter density per layer detected".to_string());
238        }
239
240        bottlenecks
241    }
242
243    /// Calculate efficiency based on layer types.
244    fn calculate_layer_type_efficiency(&self, arch_info: &ModelArchitectureInfo) -> f64 {
245        let total_layers = arch_info.layer_count as f64;
246        if total_layers == 0.0 {
247            return 0.0;
248        }
249
250        let mut efficiency_score = 0.0;
251        for (layer_type, count) in &arch_info.layer_types {
252            let weight =
253                if self.config.preferred_layer_types.contains(layer_type) { 1.0 } else { 0.8 };
254            efficiency_score += (*count as f64 / total_layers) * weight;
255        }
256
257        efficiency_score.min(1.0)
258    }
259
260    /// Calculate layer organization efficiency.
261    fn calculate_layer_organization_efficiency(&self, arch_info: &ModelArchitectureInfo) -> f64 {
262        // Bonus for good layer type diversity
263        let diversity_bonus = (arch_info.layer_types.len() as f64 / 10.0).min(1.2);
264
265        // Bonus for activation function diversity
266        let activation_bonus = (arch_info.activation_functions.len() as f64 / 5.0).min(1.1);
267
268        // Penalty for extreme aspect ratios
269        let aspect_ratio = arch_info.depth as f64 / arch_info.width as f64;
270        let aspect_penalty = if !(0.002..=0.05).contains(&aspect_ratio) { 0.9 } else { 1.0 };
271
272        diversity_bonus * activation_bonus * aspect_penalty
273    }
274
275    /// Analyze layer type distribution for recommendations.
276    fn analyze_layer_type_distribution(&self, arch_info: &ModelArchitectureInfo) -> Vec<String> {
277        let mut recommendations = Vec::new();
278
279        // Check for missing important layer types
280        if !arch_info.layer_types.contains_key("Normalization") {
281            recommendations
282                .push("Consider adding normalization layers for training stability".to_string());
283        }
284
285        if !arch_info.layer_types.contains_key("Dropout") {
286            recommendations.push("Consider adding dropout layers for regularization".to_string());
287        }
288
289        // Check for layer type imbalances
290        let total_layers = arch_info.layer_count;
291        for (layer_type, count) in &arch_info.layer_types {
292            let ratio = *count as f64 / total_layers as f64;
293            match layer_type.as_str() {
294                "Linear" if ratio > 0.8 => {
295                    recommendations.push(
296                        "High proportion of linear layers - consider adding non-linearity"
297                            .to_string(),
298                    );
299                },
300                "Convolutional" if ratio > 0.9 => {
301                    recommendations.push(
302                        "Very CNN-heavy architecture - consider hybrid approaches".to_string(),
303                    );
304                },
305                "Attention" if ratio > 0.7 => {
306                    recommendations.push(
307                        "Attention-heavy architecture - consider computational efficiency"
308                            .to_string(),
309                    );
310                },
311                _ => {},
312            }
313        }
314
315        recommendations
316    }
317
318    /// Find the dominant layer type.
319    fn find_dominant_layer_type(&self, arch_info: &ModelArchitectureInfo) -> Option<String> {
320        arch_info
321            .layer_types
322            .iter()
323            .max_by_key(|(_, count)| *count)
324            .map(|(layer_type, _)| layer_type.clone())
325    }
326
327    /// Find the dominant activation function.
328    fn find_dominant_activation(&self, arch_info: &ModelArchitectureInfo) -> Option<String> {
329        arch_info
330            .activation_functions
331            .iter()
332            .max_by_key(|(_, count)| *count)
333            .map(|(activation, _)| activation.clone())
334    }
335
336    /// Generate detailed architecture report.
337    pub fn generate_architecture_report(&self) -> Result<ArchitectureReport> {
338        let arch_info = self
339            .architecture_info
340            .as_ref()
341            .ok_or_else(|| anyhow::anyhow!("No architecture information available"))?;
342
343        let analysis = self.analyze_architecture()?;
344
345        let overall_score = self.calculate_overall_architecture_score(&analysis);
346
347        Ok(ArchitectureReport {
348            model_summary: ModelSummary {
349                total_parameters: arch_info.total_parameters,
350                trainable_parameters: arch_info.trainable_parameters,
351                model_size_mb: arch_info.model_size_mb,
352                layer_count: arch_info.layer_count,
353                depth: arch_info.depth,
354                width: arch_info.width,
355            },
356            efficiency_metrics: EfficiencyMetrics {
357                parameter_efficiency: analysis.parameter_efficiency,
358                memory_efficiency: analysis.memory_efficiency,
359                computational_complexity: analysis.computational_complexity,
360            },
361            layer_distribution: arch_info.layer_types.clone(),
362            activation_distribution: arch_info.activation_functions.clone(),
363            recommendations: analysis.recommendations,
364            bottlenecks: analysis.bottlenecks,
365            overall_score,
366        })
367    }
368
369    /// Calculate overall architecture score.
370    fn calculate_overall_architecture_score(&self, analysis: &ArchitecturalAnalysis) -> f64 {
371        let complexity_penalty = match analysis.computational_complexity.as_str() {
372            "Low" => 1.0,
373            "Medium" => 0.9,
374            "High" => 0.8,
375            "Very High" => 0.7,
376            _ => 0.8,
377        };
378
379        let bottleneck_penalty = 1.0 - (analysis.bottlenecks.len() as f64 * 0.1).min(0.5);
380
381        (analysis.parameter_efficiency * 0.4
382            + analysis.memory_efficiency * 0.4
383            + complexity_penalty * 0.2)
384            * bottleneck_penalty
385    }
386
387    /// Clear architecture information.
388    pub fn clear(&mut self) {
389        self.architecture_info = None;
390    }
391}
392
393impl Default for ArchitectureAnalyzer {
394    fn default() -> Self {
395        Self::new()
396    }
397}
398
399/// Comprehensive architecture report.
400#[derive(Debug, Clone)]
401pub struct ArchitectureReport {
402    /// Model summary statistics
403    pub model_summary: ModelSummary,
404    /// Efficiency metrics
405    pub efficiency_metrics: EfficiencyMetrics,
406    /// Layer type distribution
407    pub layer_distribution: HashMap<String, usize>,
408    /// Activation function distribution
409    pub activation_distribution: HashMap<String, usize>,
410    /// Optimization recommendations
411    pub recommendations: Vec<String>,
412    /// Identified bottlenecks
413    pub bottlenecks: Vec<String>,
414    /// Overall architecture score (0.0 to 1.0)
415    pub overall_score: f64,
416}
417
/// Summary statistics describing a model's size and shape.
#[derive(Debug, Clone)]
pub struct ModelSummary {
    /// Total number of parameters.
    pub total_parameters: usize,
    /// Number of trainable parameters.
    pub trainable_parameters: usize,
    /// Model size in megabytes.
    pub model_size_mb: f64,
    /// Total number of layers.
    pub layer_count: usize,
    /// Model depth (number of sequential stages).
    pub depth: usize,
    /// Model width (hidden dimension).
    pub width: usize,
}
/// Efficiency metrics computed for an analyzed architecture.
#[derive(Debug, Clone)]
pub struct EfficiencyMetrics {
    /// Parameter efficiency score (0.0 to 1.0).
    pub parameter_efficiency: f64,
    /// Memory efficiency score (0.0 to 1.0).
    pub memory_efficiency: f64,
    /// Computational complexity label ("Low" through "Very High").
    pub computational_complexity: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a small but representative architecture fixture.
    fn create_test_architecture() -> ModelArchitectureInfo {
        let layer_types = HashMap::from([
            ("Linear".to_string(), 10),
            ("Attention".to_string(), 5),
            ("Normalization".to_string(), 15),
        ]);
        let activation_functions = HashMap::from([
            ("ReLU".to_string(), 10),
            ("GELU".to_string(), 20),
        ]);

        ModelArchitectureInfo {
            total_parameters: 1_000_000,
            trainable_parameters: 950_000,
            model_size_mb: 50.0,
            layer_count: 30,
            layer_types,
            depth: 12,
            width: 768,
            activation_functions,
        }
    }

    #[test]
    fn test_architecture_analyzer_creation() {
        // A fresh analyzer starts with no recorded architecture.
        assert!(ArchitectureAnalyzer::new().architecture_info.is_none());
    }

    #[test]
    fn test_record_architecture() {
        let mut analyzer = ArchitectureAnalyzer::new();
        analyzer.record_architecture(create_test_architecture());
        assert!(analyzer.architecture_info.is_some());
    }

    #[test]
    fn test_parameter_efficiency_calculation() {
        let analyzer = ArchitectureAnalyzer::new();
        let score = analyzer.calculate_parameter_efficiency(&create_test_architecture());
        assert!(score > 0.0 && score <= 1.0);
    }

    #[test]
    fn test_computational_complexity_assessment() {
        let analyzer = ArchitectureAnalyzer::new();
        let label = analyzer.assess_computational_complexity(&create_test_architecture());
        assert!(["Low", "Medium", "High", "Very High"].contains(&label.as_str()));
    }

    #[test]
    fn test_memory_efficiency_calculation() {
        let analyzer = ArchitectureAnalyzer::new();
        let score = analyzer.calculate_memory_efficiency(&create_test_architecture());
        assert!(score > 0.0 && score <= 1.0);
    }

    #[test]
    fn test_architecture_analysis() {
        let mut analyzer = ArchitectureAnalyzer::new();
        analyzer.record_architecture(create_test_architecture());

        let analysis = analyzer.analyze_architecture().expect("operation failed in test");
        assert!(analysis.parameter_efficiency > 0.0);
        assert!(analysis.memory_efficiency > 0.0);
        assert!(!analysis.computational_complexity.is_empty());
    }
}