//! trustformers_debug/model_diagnostics/architecture.rs
//!
//! Model architecture analysis: efficiency scoring, complexity assessment,
//! bottleneck detection, and report generation.

use anyhow::Result;
use std::collections::HashMap;

use super::types::{ArchitecturalAnalysis, ModelArchitectureInfo};

/// Analyzes a recorded model architecture for parameter/memory efficiency,
/// computational complexity, and structural bottlenecks.
#[derive(Debug)]
pub struct ArchitectureAnalyzer {
    /// Most recently recorded architecture snapshot; `None` until
    /// `record_architecture` has been called.
    architecture_info: Option<ModelArchitectureInfo>,
    /// Thresholds and preferences that steer the analysis.
    config: ArchitectureAnalysisConfig,
}
20
/// Tunable targets and limits used by [`ArchitectureAnalyzer`].
#[derive(Debug, Clone)]
pub struct ArchitectureAnalysisConfig {
    /// Minimum parameter-efficiency score (0.0–1.0) before recommendations
    /// about model size / pruning are emitted.
    pub target_parameter_efficiency: f64,
    /// Minimum memory-efficiency score (0.0–1.0) before quantization /
    /// compression recommendations are emitted.
    pub target_memory_efficiency: f64,
    /// Model size budget in megabytes; exceeding it incurs a scoring penalty.
    pub max_model_size_mb: f64,
    /// Layer-type names that receive full weight in the efficiency score
    /// (other types are discounted).
    pub preferred_layer_types: Vec<String>,
}
33
34impl Default for ArchitectureAnalysisConfig {
35 fn default() -> Self {
36 Self {
37 target_parameter_efficiency: 0.7,
38 target_memory_efficiency: 0.8,
39 max_model_size_mb: 1024.0, preferred_layer_types: vec![
41 "Attention".to_string(),
42 "Linear".to_string(),
43 "Normalization".to_string(),
44 ],
45 }
46 }
47}
48
49impl ArchitectureAnalyzer {
50 pub fn new() -> Self {
52 Self {
53 architecture_info: None,
54 config: ArchitectureAnalysisConfig::default(),
55 }
56 }
57
58 pub fn with_config(config: ArchitectureAnalysisConfig) -> Self {
60 Self {
61 architecture_info: None,
62 config,
63 }
64 }
65
66 pub fn record_architecture(&mut self, arch_info: ModelArchitectureInfo) {
68 self.architecture_info = Some(arch_info);
69 }
70
71 pub fn get_architecture_info(&self) -> Option<&ModelArchitectureInfo> {
73 self.architecture_info.as_ref()
74 }
75
76 pub fn analyze_architecture(&self) -> Result<ArchitecturalAnalysis> {
78 let arch_info = self
79 .architecture_info
80 .as_ref()
81 .ok_or_else(|| anyhow::anyhow!("No architecture information available"))?;
82
83 let parameter_efficiency = self.calculate_parameter_efficiency(arch_info);
84 let computational_complexity = self.assess_computational_complexity(arch_info);
85 let memory_efficiency = self.calculate_memory_efficiency(arch_info);
86 let recommendations = self.generate_architecture_recommendations(arch_info);
87 let bottlenecks = self.identify_architectural_bottlenecks(arch_info);
88
89 Ok(ArchitecturalAnalysis {
90 parameter_efficiency,
91 computational_complexity,
92 memory_efficiency,
93 recommendations,
94 bottlenecks,
95 })
96 }
97
98 pub fn calculate_parameter_efficiency(&self, arch_info: &ModelArchitectureInfo) -> f64 {
100 if arch_info.total_parameters == 0 {
101 return 0.0;
102 }
103
104 let trainable_ratio =
105 arch_info.trainable_parameters as f64 / arch_info.total_parameters as f64;
106 let size_penalty = if arch_info.model_size_mb > self.config.max_model_size_mb {
107 0.8 } else {
109 1.0
110 };
111
112 let layer_efficiency = self.calculate_layer_type_efficiency(arch_info);
114
115 (trainable_ratio * size_penalty * layer_efficiency).min(1.0)
116 }
117
118 pub fn assess_computational_complexity(&self, arch_info: &ModelArchitectureInfo) -> String {
120 let param_count = arch_info.total_parameters;
121 let depth = arch_info.depth;
122 let width = arch_info.width;
123
124 let complexity_score = (param_count as f64).log10() + (depth as f64 * width as f64).log10();
126
127 match complexity_score {
128 x if x < 8.0 => "Low".to_string(),
129 x if x < 10.0 => "Medium".to_string(),
130 x if x < 12.0 => "High".to_string(),
131 _ => "Very High".to_string(),
132 }
133 }
134
135 pub fn calculate_memory_efficiency(&self, arch_info: &ModelArchitectureInfo) -> f64 {
137 if arch_info.model_size_mb == 0.0 {
138 return 0.0;
139 }
140
141 let theoretical_min_mb = (arch_info.total_parameters as f64 * 4.0) / (1024.0 * 1024.0); let efficiency = theoretical_min_mb / arch_info.model_size_mb;
144
145 let layer_organization_bonus = self.calculate_layer_organization_efficiency(arch_info);
147
148 (efficiency * layer_organization_bonus).min(1.0)
149 }
150
151 pub fn generate_architecture_recommendations(
153 &self,
154 arch_info: &ModelArchitectureInfo,
155 ) -> Vec<String> {
156 let mut recommendations = Vec::new();
157
158 let param_efficiency = self.calculate_parameter_efficiency(arch_info);
160 if param_efficiency < self.config.target_parameter_efficiency {
161 recommendations.push(
162 "Consider reducing model size or improving parameter utilization".to_string(),
163 );
164 recommendations.push("Evaluate layer pruning opportunities".to_string());
165 }
166
167 let memory_efficiency = self.calculate_memory_efficiency(arch_info);
169 if memory_efficiency < self.config.target_memory_efficiency {
170 recommendations.push("Consider weight quantization to reduce memory usage".to_string());
171 recommendations.push("Evaluate model compression techniques".to_string());
172 }
173
174 if arch_info.model_size_mb > self.config.max_model_size_mb {
176 recommendations.push("Model size exceeds recommended limits".to_string());
177 recommendations.push("Consider architectural changes to reduce model size".to_string());
178 }
179
180 let layer_recommendations = self.analyze_layer_type_distribution(arch_info);
182 recommendations.extend(layer_recommendations);
183
184 if arch_info.depth > 50 {
186 recommendations
187 .push("Very deep model detected - consider residual connections".to_string());
188 }
189
190 if arch_info.width > 4096 {
191 recommendations
192 .push("Very wide model detected - consider factorization techniques".to_string());
193 }
194
195 recommendations
196 }
197
198 pub fn identify_architectural_bottlenecks(
200 &self,
201 arch_info: &ModelArchitectureInfo,
202 ) -> Vec<String> {
203 let mut bottlenecks = Vec::new();
204
205 if let Some(dominant_layer) = self.find_dominant_layer_type(arch_info) {
207 if arch_info.layer_types.get(&dominant_layer).unwrap_or(&0)
208 > &(arch_info.layer_count / 2)
209 {
210 bottlenecks.push(format!("Over-reliance on {} layers", dominant_layer));
211 }
212 }
213
214 if let Some(dominant_activation) = self.find_dominant_activation(arch_info) {
216 if arch_info.activation_functions.get(&dominant_activation).unwrap_or(&0)
217 > &(arch_info.layer_count * 3 / 4)
218 {
219 bottlenecks.push(format!(
220 "Limited activation function diversity: {} dominates",
221 dominant_activation
222 ));
223 }
224 }
225
226 let aspect_ratio = arch_info.depth as f64 / arch_info.width as f64;
228 if aspect_ratio > 0.1 {
229 bottlenecks.push("Model may be too deep relative to width".to_string());
230 } else if aspect_ratio < 0.001 {
231 bottlenecks.push("Model may be too wide relative to depth".to_string());
232 }
233
234 let params_per_layer = arch_info.total_parameters as f64 / arch_info.layer_count as f64;
236 if params_per_layer > 1_000_000.0 {
237 bottlenecks.push("High parameter density per layer detected".to_string());
238 }
239
240 bottlenecks
241 }
242
243 fn calculate_layer_type_efficiency(&self, arch_info: &ModelArchitectureInfo) -> f64 {
245 let total_layers = arch_info.layer_count as f64;
246 if total_layers == 0.0 {
247 return 0.0;
248 }
249
250 let mut efficiency_score = 0.0;
251 for (layer_type, count) in &arch_info.layer_types {
252 let weight =
253 if self.config.preferred_layer_types.contains(layer_type) { 1.0 } else { 0.8 };
254 efficiency_score += (*count as f64 / total_layers) * weight;
255 }
256
257 efficiency_score.min(1.0)
258 }
259
260 fn calculate_layer_organization_efficiency(&self, arch_info: &ModelArchitectureInfo) -> f64 {
262 let diversity_bonus = (arch_info.layer_types.len() as f64 / 10.0).min(1.2);
264
265 let activation_bonus = (arch_info.activation_functions.len() as f64 / 5.0).min(1.1);
267
268 let aspect_ratio = arch_info.depth as f64 / arch_info.width as f64;
270 let aspect_penalty = if !(0.002..=0.05).contains(&aspect_ratio) { 0.9 } else { 1.0 };
271
272 diversity_bonus * activation_bonus * aspect_penalty
273 }
274
275 fn analyze_layer_type_distribution(&self, arch_info: &ModelArchitectureInfo) -> Vec<String> {
277 let mut recommendations = Vec::new();
278
279 if !arch_info.layer_types.contains_key("Normalization") {
281 recommendations
282 .push("Consider adding normalization layers for training stability".to_string());
283 }
284
285 if !arch_info.layer_types.contains_key("Dropout") {
286 recommendations.push("Consider adding dropout layers for regularization".to_string());
287 }
288
289 let total_layers = arch_info.layer_count;
291 for (layer_type, count) in &arch_info.layer_types {
292 let ratio = *count as f64 / total_layers as f64;
293 match layer_type.as_str() {
294 "Linear" if ratio > 0.8 => {
295 recommendations.push(
296 "High proportion of linear layers - consider adding non-linearity"
297 .to_string(),
298 );
299 },
300 "Convolutional" if ratio > 0.9 => {
301 recommendations.push(
302 "Very CNN-heavy architecture - consider hybrid approaches".to_string(),
303 );
304 },
305 "Attention" if ratio > 0.7 => {
306 recommendations.push(
307 "Attention-heavy architecture - consider computational efficiency"
308 .to_string(),
309 );
310 },
311 _ => {},
312 }
313 }
314
315 recommendations
316 }
317
318 fn find_dominant_layer_type(&self, arch_info: &ModelArchitectureInfo) -> Option<String> {
320 arch_info
321 .layer_types
322 .iter()
323 .max_by_key(|(_, count)| *count)
324 .map(|(layer_type, _)| layer_type.clone())
325 }
326
327 fn find_dominant_activation(&self, arch_info: &ModelArchitectureInfo) -> Option<String> {
329 arch_info
330 .activation_functions
331 .iter()
332 .max_by_key(|(_, count)| *count)
333 .map(|(activation, _)| activation.clone())
334 }
335
336 pub fn generate_architecture_report(&self) -> Result<ArchitectureReport> {
338 let arch_info = self
339 .architecture_info
340 .as_ref()
341 .ok_or_else(|| anyhow::anyhow!("No architecture information available"))?;
342
343 let analysis = self.analyze_architecture()?;
344
345 let overall_score = self.calculate_overall_architecture_score(&analysis);
346
347 Ok(ArchitectureReport {
348 model_summary: ModelSummary {
349 total_parameters: arch_info.total_parameters,
350 trainable_parameters: arch_info.trainable_parameters,
351 model_size_mb: arch_info.model_size_mb,
352 layer_count: arch_info.layer_count,
353 depth: arch_info.depth,
354 width: arch_info.width,
355 },
356 efficiency_metrics: EfficiencyMetrics {
357 parameter_efficiency: analysis.parameter_efficiency,
358 memory_efficiency: analysis.memory_efficiency,
359 computational_complexity: analysis.computational_complexity,
360 },
361 layer_distribution: arch_info.layer_types.clone(),
362 activation_distribution: arch_info.activation_functions.clone(),
363 recommendations: analysis.recommendations,
364 bottlenecks: analysis.bottlenecks,
365 overall_score,
366 })
367 }
368
369 fn calculate_overall_architecture_score(&self, analysis: &ArchitecturalAnalysis) -> f64 {
371 let complexity_penalty = match analysis.computational_complexity.as_str() {
372 "Low" => 1.0,
373 "Medium" => 0.9,
374 "High" => 0.8,
375 "Very High" => 0.7,
376 _ => 0.8,
377 };
378
379 let bottleneck_penalty = 1.0 - (analysis.bottlenecks.len() as f64 * 0.1).min(0.5);
380
381 (analysis.parameter_efficiency * 0.4
382 + analysis.memory_efficiency * 0.4
383 + complexity_penalty * 0.2)
384 * bottleneck_penalty
385 }
386
387 pub fn clear(&mut self) {
389 self.architecture_info = None;
390 }
391}
392
393impl Default for ArchitectureAnalyzer {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
/// Complete architecture report combining the raw model summary with the
/// derived analysis results.
#[derive(Debug, Clone)]
pub struct ArchitectureReport {
    /// Raw size/shape figures copied from the recorded architecture.
    pub model_summary: ModelSummary,
    /// Derived efficiency and complexity metrics.
    pub efficiency_metrics: EfficiencyMetrics,
    /// Layer-type name -> occurrence count.
    pub layer_distribution: HashMap<String, usize>,
    /// Activation-function name -> occurrence count.
    pub activation_distribution: HashMap<String, usize>,
    /// Human-readable improvement suggestions.
    pub recommendations: Vec<String>,
    /// Human-readable descriptions of detected structural bottlenecks.
    pub bottlenecks: Vec<String>,
    /// Aggregate architecture quality score.
    pub overall_score: f64,
}
417
/// Raw size and shape figures of an analyzed model.
#[derive(Debug, Clone)]
pub struct ModelSummary {
    /// Total number of parameters (trainable + frozen).
    pub total_parameters: usize,
    /// Number of trainable parameters.
    pub trainable_parameters: usize,
    /// Reported model size in megabytes.
    pub model_size_mb: f64,
    /// Total number of layers.
    pub layer_count: usize,
    /// Network depth (number of sequential stages).
    pub depth: usize,
    /// Network width (e.g. hidden dimension).
    pub width: usize,
}
434
/// Derived efficiency metrics included in an [`ArchitectureReport`].
#[derive(Debug, Clone)]
pub struct EfficiencyMetrics {
    /// Parameter-efficiency score in `[0.0, 1.0]`.
    pub parameter_efficiency: f64,
    /// Memory-efficiency score in `[0.0, 1.0]`.
    pub memory_efficiency: f64,
    /// Complexity grade: "Low", "Medium", "High", or "Very High".
    pub computational_complexity: String,
}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small transformer-like architecture fixture shared by the
    /// tests below: 30 layers, 1M parameters (95% trainable), 50 MB.
    fn create_test_architecture() -> ModelArchitectureInfo {
        let mut layer_types = HashMap::new();
        layer_types.insert(String::from("Normalization"), 15);
        layer_types.insert(String::from("Linear"), 10);
        layer_types.insert(String::from("Attention"), 5);

        let mut activation_functions = HashMap::new();
        activation_functions.insert(String::from("GELU"), 20);
        activation_functions.insert(String::from("ReLU"), 10);

        ModelArchitectureInfo {
            total_parameters: 1_000_000,
            trainable_parameters: 950_000,
            model_size_mb: 50.0,
            layer_count: 30,
            layer_types,
            depth: 12,
            width: 768,
            activation_functions,
        }
    }

    #[test]
    fn test_architecture_analyzer_creation() {
        // A fresh analyzer starts with no recorded architecture.
        assert!(ArchitectureAnalyzer::new().architecture_info.is_none());
    }

    #[test]
    fn test_record_architecture() {
        let mut analyzer = ArchitectureAnalyzer::new();
        analyzer.record_architecture(create_test_architecture());
        assert!(analyzer.architecture_info.is_some());
    }

    #[test]
    fn test_parameter_efficiency_calculation() {
        let analyzer = ArchitectureAnalyzer::new();
        let fixture = create_test_architecture();

        let score = analyzer.calculate_parameter_efficiency(&fixture);
        // Score must be a valid, non-degenerate fraction.
        assert!(score > 0.0);
        assert!(score <= 1.0);
    }

    #[test]
    fn test_computational_complexity_assessment() {
        let analyzer = ArchitectureAnalyzer::new();
        let fixture = create_test_architecture();

        let grade = analyzer.assess_computational_complexity(&fixture);
        let known_grades = ["Low", "Medium", "High", "Very High"];
        assert!(known_grades.contains(&grade.as_str()));
    }

    #[test]
    fn test_memory_efficiency_calculation() {
        let analyzer = ArchitectureAnalyzer::new();
        let fixture = create_test_architecture();

        let score = analyzer.calculate_memory_efficiency(&fixture);
        assert!(score > 0.0);
        assert!(score <= 1.0);
    }

    #[test]
    fn test_architecture_analysis() {
        let mut analyzer = ArchitectureAnalyzer::new();
        analyzer.record_architecture(create_test_architecture());

        let analysis = analyzer.analyze_architecture().expect("operation failed in test");

        assert!(analysis.parameter_efficiency > 0.0);
        assert!(analysis.memory_efficiency > 0.0);
        assert!(!analysis.computational_complexity.is_empty());
    }
}