use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Configuration controlling which architecture analyses are run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureAnalysisConfig {
    /// Count total and trainable parameters per layer.
    pub enable_parameter_counting: bool,
    /// Estimate receptive fields for convolutional layers.
    pub enable_receptive_field_calculation: bool,
    /// Compute model depth and width.
    pub enable_depth_width_analysis: bool,
    /// Analyze connection types between layers.
    pub enable_connectivity_patterns: bool,
    /// Detect repeated blocks and parameter symmetries.
    pub enable_symmetry_detection: bool,
    pub max_receptive_field_depth: usize,
    pub sampling_rate: f32,
}

impl Default for ArchitectureAnalysisConfig {
    fn default() -> Self {
        Self {
            enable_parameter_counting: true,
            enable_receptive_field_calculation: true,
            enable_depth_width_analysis: true,
            enable_connectivity_patterns: true,
            enable_symmetry_detection: true,
            max_receptive_field_depth: 50,
            sampling_rate: 1.0,
        }
    }
}

/// Categories of layers recognized by the analyzer.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum LayerType {
    Linear,
    Conv2D,
    Conv3D,
    BatchNorm,
    LayerNorm,
    Attention,
    Embedding,
    Dropout,
    Activation,
    Pooling,
    Residual,
    Unknown,
}

/// Metadata describing a single layer of the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerInfo {
    pub id: String,
    pub name: String,
    pub layer_type: LayerType,
    pub input_shape: Vec<usize>,
    pub output_shape: Vec<usize>,
    pub parameters: usize,
    pub trainable_parameters: usize,
    pub memory_usage: usize,
    pub flops: u64,
    pub receptive_field: Option<ReceptiveField>,
}

/// Receptive field description for a convolutional layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReceptiveField {
    pub size: Vec<usize>,
    pub stride: Vec<usize>,
    pub padding: Vec<usize>,
    pub effective_size: Vec<usize>,
}

/// A directed connection between two layers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectivityPattern {
    pub from_layer: String,
    pub to_layer: String,
    pub connection_type: ConnectionType,
    pub strength: f32,
}

#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub enum ConnectionType {
    Sequential,
    Residual,
    Attention,
    Skip,
    Recurrent,
    Branching,
}

/// A structural symmetry detected in the architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SymmetryInfo {
    pub symmetry_type: SymmetryType,
    pub symmetric_layers: Vec<String>,
    pub confidence: f32,
    pub description: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SymmetryType {
    Translational,
    Rotational,
    Reflection,
    Permutation,
    Block,
}

/// Complete analysis report produced by [`ArchitectureAnalyzer::analyze`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureAnalysisReport {
    pub total_parameters: usize,
    pub trainable_parameters: usize,
    pub model_size_mb: f32,
    pub total_flops: u64,
    pub model_depth: usize,
    pub model_width: usize,
    pub layers: Vec<LayerInfo>,
    pub connectivity_patterns: Vec<ConnectivityPattern>,
    pub symmetries: Vec<SymmetryInfo>,
    pub parameter_distribution: HashMap<LayerType, usize>,
    pub bottlenecks: Vec<String>,
    pub efficiency_metrics: EfficiencyMetrics,
}

/// Heuristic efficiency scores; higher values indicate a more efficient architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficiencyMetrics {
    pub parameter_efficiency: f32,
    pub flops_efficiency: f32,
    pub memory_efficiency: f32,
    pub depth_efficiency: f32,
    pub overall_score: f32,
}

/// Collects layer and connection metadata and produces architecture reports.
#[derive(Debug)]
pub struct ArchitectureAnalyzer {
    config: ArchitectureAnalysisConfig,
    layers: Vec<LayerInfo>,
    connections: Vec<ConnectivityPattern>,
    analysis_cache: HashMap<String, ArchitectureAnalysisReport>,
}

impl ArchitectureAnalyzer {
    /// Create a new analyzer with the given configuration.
    pub fn new(config: ArchitectureAnalysisConfig) -> Self {
        Self {
            config,
            layers: Vec::new(),
            connections: Vec::new(),
            analysis_cache: HashMap::new(),
        }
    }

    /// Register a layer to be included in subsequent analyses.
    pub fn register_layer(&mut self, layer: LayerInfo) {
        self.layers.push(layer);
    }

    /// Record a connection between two registered layers.
    pub fn add_connection(&mut self, pattern: ConnectivityPattern) {
        self.connections.push(pattern);
    }

    /// Run all analyses enabled in the configuration and return a full report.
    pub async fn analyze(&mut self) -> Result<ArchitectureAnalysisReport> {
        let mut report = ArchitectureAnalysisReport {
            total_parameters: 0,
            trainable_parameters: 0,
            model_size_mb: 0.0,
            total_flops: 0,
            model_depth: 0,
            model_width: 0,
            layers: self.layers.clone(),
            connectivity_patterns: self.connections.clone(),
            symmetries: Vec::new(),
            parameter_distribution: HashMap::new(),
            bottlenecks: Vec::new(),
            efficiency_metrics: EfficiencyMetrics {
                parameter_efficiency: 0.0,
                flops_efficiency: 0.0,
                memory_efficiency: 0.0,
                depth_efficiency: 0.0,
                overall_score: 0.0,
            },
        };

        if self.config.enable_parameter_counting {
            self.count_parameters(&mut report);
        }

        if self.config.enable_receptive_field_calculation {
            self.calculate_receptive_fields(&mut report).await?;
        }

        if self.config.enable_depth_width_analysis {
            self.analyze_depth_width(&mut report);
        }

        if self.config.enable_connectivity_patterns {
            self.analyze_connectivity_patterns(&mut report);
        }

        if self.config.enable_symmetry_detection {
            self.detect_symmetries(&mut report);
        }

        self.calculate_efficiency_metrics(&mut report);
        self.identify_bottlenecks(&mut report);

        Ok(report)
    }

    fn count_parameters(&self, report: &mut ArchitectureAnalysisReport) {
        let mut param_distribution: HashMap<LayerType, usize> = HashMap::new();

        for layer in &self.layers {
            report.total_parameters += layer.parameters;
            report.trainable_parameters += layer.trainable_parameters;

            *param_distribution.entry(layer.layer_type.clone()).or_insert(0) += layer.parameters;
        }

        report.parameter_distribution = param_distribution;

        // Assume 4 bytes (f32) per parameter.
        report.model_size_mb = (report.total_parameters * 4) as f32 / (1024.0 * 1024.0);

        report.total_flops = self.layers.iter().map(|l| l.flops).sum();
    }

    /// Attach estimated receptive fields to convolutional layers.
    async fn calculate_receptive_fields(
        &mut self,
        report: &mut ArchitectureAnalysisReport,
    ) -> Result<()> {
        for layer in &mut self.layers {
            if matches!(layer.layer_type, LayerType::Conv2D | LayerType::Conv3D) {
                layer.receptive_field =
                    Some(Self::compute_receptive_field_static(&layer.layer_type));
            }
        }

        report.layers = self.layers.clone();
        Ok(())
    }

    fn compute_receptive_field_static(layer_type: &LayerType) -> ReceptiveField {
        match layer_type {
            LayerType::Conv2D => {
                // Assume a common 3x3 kernel with stride 1 and padding 1.
                let kernel_size = vec![3, 3];
                let stride = vec![1, 1];
                let padding = vec![1, 1];

                ReceptiveField {
                    size: kernel_size.clone(),
                    stride,
                    padding,
                    effective_size: kernel_size,
                }
            },
            LayerType::Conv3D => {
                // Assume a common 3x3x3 kernel with stride 1 and padding 1.
                let kernel_size = vec![3, 3, 3];
                let stride = vec![1, 1, 1];
                let padding = vec![1, 1, 1];

                ReceptiveField {
                    size: kernel_size.clone(),
                    stride,
                    padding,
                    effective_size: kernel_size,
                }
            },
            _ => {
                // Non-convolutional layers are treated as having a unit receptive field.
                ReceptiveField {
                    size: vec![1],
                    stride: vec![1],
                    padding: vec![0],
                    effective_size: vec![1],
                }
            },
        }
    }

    #[allow(dead_code)]
    fn compute_receptive_field(&self, layer: &LayerInfo) -> ReceptiveField {
        Self::compute_receptive_field_static(&layer.layer_type)
    }

    fn analyze_depth_width(&self, report: &mut ArchitectureAnalysisReport) {
        // Depth: number of registered layers.
        report.model_depth = self.layers.len();

        // Width: parameter count of the largest single layer, used as a proxy.
        report.model_width = self.layers.iter().map(|l| l.parameters).max().unwrap_or(0);
    }

    fn analyze_connectivity_patterns(&self, report: &mut ArchitectureAnalysisReport) {
        let mut pattern_types: HashMap<ConnectionType, usize> = HashMap::new();

        for connection in &self.connections {
            *pattern_types.entry(connection.connection_type.clone()).or_insert(0) += 1;
        }

        // Flag connection types that occur unusually often relative to the layer count.
        for (connection_type, count) in pattern_types {
            if count > self.layers.len() / 2 {
                report.bottlenecks.push(format!(
                    "High {:?} connectivity: {} connections",
                    connection_type, count
                ));
            }
        }
    }

    fn detect_symmetries(&self, report: &mut ArchitectureAnalysisReport) {
        // Look for repeated blocks of layer types using sliding windows of size 2 to 5.
        let mut block_patterns: HashMap<Vec<LayerType>, Vec<usize>> = HashMap::new();

        for window_size in 2..=5.min(self.layers.len()) {
            for i in 0..=(self.layers.len() - window_size) {
                let pattern: Vec<LayerType> =
                    self.layers[i..i + window_size].iter().map(|l| l.layer_type.clone()).collect();

                block_patterns.entry(pattern).or_insert_with(Vec::new).push(i);
            }
        }

        for (pattern, positions) in block_patterns {
            if positions.len() > 1 {
                let confidence = positions.len() as f32 / self.layers.len() as f32;

                if confidence > 0.1 {
                    report.symmetries.push(SymmetryInfo {
                        symmetry_type: SymmetryType::Block,
                        symmetric_layers: positions
                            .iter()
                            .map(|&i| format!("block_{}", i))
                            .collect(),
                        confidence,
                        description: format!(
                            "Repeated block pattern: {:?} appears {} times",
                            pattern,
                            positions.len()
                        ),
                    });
                }
            }
        }

        // Group layers by parameter count to find potential permutation symmetries.
        let mut param_groups: HashMap<usize, Vec<String>> = HashMap::new();
        for layer in &self.layers {
            param_groups
                .entry(layer.parameters)
                .or_insert_with(Vec::new)
                .push(layer.id.clone());
        }

        for (param_count, layer_ids) in param_groups {
            if layer_ids.len() > 2 && param_count > 0 {
                let confidence = layer_ids.len() as f32 / self.layers.len() as f32;

                report.symmetries.push(SymmetryInfo {
                    symmetry_type: SymmetryType::Permutation,
                    symmetric_layers: layer_ids.clone(),
                    confidence,
                    description: format!(
                        "Parameter symmetry: {} layers with {} parameters each",
                        layer_ids.len(),
                        param_count
                    ),
                });
            }
        }
    }

    fn calculate_efficiency_metrics(&self, report: &mut ArchitectureAnalysisReport) {
        let total_params = report.total_parameters as f32;
        let total_flops = report.total_flops as f32;
        let depth = report.model_depth as f32;
        let memory = report.model_size_mb;

        // Parameter efficiency: penalize models above ~10M parameters on a log scale.
        report.efficiency_metrics.parameter_efficiency = if total_params > 0.0 {
            1.0 / (total_params / 1_000_000.0).log10().max(1.0)
        } else {
            1.0
        };

        // FLOP efficiency: penalize models above ~10 GFLOPs on a log scale.
        report.efficiency_metrics.flops_efficiency = if total_flops > 0.0 {
            1.0 / (total_flops / 1_000_000_000.0).log10().max(1.0)
        } else {
            1.0
        };

        // Memory efficiency: penalize models above ~1 GB on a log scale.
        report.efficiency_metrics.memory_efficiency = if memory > 0.0 {
            1.0 / (memory / 100.0).log10().max(1.0)
        } else {
            1.0
        };

        // Depth efficiency: peaks at an assumed optimal depth of 20 layers.
        report.efficiency_metrics.depth_efficiency = if depth > 0.0 {
            let optimal_depth = 20.0;
            1.0 - ((depth - optimal_depth).abs() / optimal_depth).min(1.0)
        } else {
            0.0
        };

        // Overall score: weighted combination of the individual metrics.
        report.efficiency_metrics.overall_score = 0.3
            * report.efficiency_metrics.parameter_efficiency
            + 0.3 * report.efficiency_metrics.flops_efficiency
            + 0.2 * report.efficiency_metrics.memory_efficiency
            + 0.2 * report.efficiency_metrics.depth_efficiency;
    }

    fn identify_bottlenecks(&self, report: &mut ArchitectureAnalysisReport) {
        // Parameter bottlenecks: layers far above the average parameter count.
        if let Some(_max_params) = self.layers.iter().map(|l| l.parameters).max() {
            let avg_params = report.total_parameters / self.layers.len().max(1);

            for layer in &self.layers {
                if layer.parameters > avg_params * 5 {
                    report.bottlenecks.push(format!(
                        "Parameter bottleneck: Layer '{}' has {} parameters ({}x average)",
                        layer.name,
                        layer.parameters,
                        layer.parameters / avg_params.max(1)
                    ));
                }
            }
        }

        // Memory bottlenecks: layers using more than 100 MB.
        for layer in &self.layers {
            if layer.memory_usage > 100 * 1024 * 1024 {
                report.bottlenecks.push(format!(
                    "Memory bottleneck: Layer '{}' uses {:.1}MB memory",
                    layer.name,
                    layer.memory_usage as f32 / (1024.0 * 1024.0)
                ));
            }
        }

        // Computation bottlenecks: layers far above the average FLOP count.
        if let Some(_max_flops) = self.layers.iter().map(|l| l.flops).max() {
            let avg_flops = report.total_flops / self.layers.len().max(1) as u64;

            for layer in &self.layers {
                if layer.flops > avg_flops * 10 {
                    report.bottlenecks.push(format!(
                        "Computation bottleneck: Layer '{}' requires {} FLOPs ({}x average)",
                        layer.name,
                        layer.flops,
                        layer.flops / avg_flops.max(1)
                    ));
                }
            }
        }
    }

    /// Produce a lightweight summary without running the full analysis pipeline.
    pub async fn quick_analysis(&self) -> Result<crate::QuickArchitectureSummary> {
        let total_parameters = self.layers.iter().map(|l| l.parameters as u64).sum::<u64>();
        let total_flops = self.layers.iter().map(|l| l.flops).sum::<u64>();

        // Assume 4 bytes (f32) per parameter.
        let model_size_mb = (total_parameters as f64 * 4.0) / (1024.0 * 1024.0);

        // Rough parameters-per-FLOP heuristic, clamped to at most 100.
        let efficiency_score = if total_flops > 0 {
            (total_parameters as f64 / total_flops as f64 * 1000.0).min(100.0)
        } else {
            50.0
        };

        let mut recommendations = Vec::new();
        if total_parameters > 1_000_000_000 {
            recommendations
                .push("Consider model compression techniques for large model".to_string());
        }
        if efficiency_score < 30.0 {
            recommendations.push("Model architecture could be more efficient".to_string());
        }
        if model_size_mb > 1000.0 {
            recommendations.push("Large model size may impact deployment".to_string());
        }
        if recommendations.is_empty() {
            recommendations.push("Architecture appears well-balanced".to_string());
        }

        Ok(crate::QuickArchitectureSummary {
            total_parameters,
            model_size_mb,
            efficiency_score,
            recommendations,
        })
    }

    /// Run the full analysis on a temporary copy, leaving this analyzer's state unchanged.
    pub async fn generate_report(&self) -> Result<ArchitectureAnalysisReport> {
        let mut temp_analyzer = ArchitectureAnalyzer {
            config: self.config.clone(),
            layers: self.layers.clone(),
            connections: self.connections.clone(),
            analysis_cache: HashMap::new(),
        };

        temp_analyzer.analyze().await
    }

    /// Remove all registered layers, connections, and cached reports.
    pub fn clear(&mut self) {
        self.layers.clear();
        self.connections.clear();
        self.analysis_cache.clear();
    }

    /// Build a quick synchronous summary of the registered layers.
    pub fn get_summary(&self) -> ArchitectureSummary {
        let total_params: usize = self.layers.iter().map(|l| l.parameters).sum();
        let total_flops: u64 = self.layers.iter().map(|l| l.flops).sum();

        ArchitectureSummary {
            total_layers: self.layers.len(),
            total_parameters: total_params,
            total_flops,
            average_layer_size: if !self.layers.is_empty() {
                total_params / self.layers.len()
            } else {
                0
            },
            layer_type_distribution: {
                let mut dist = HashMap::new();
                for layer in &self.layers {
                    *dist.entry(layer.layer_type.clone()).or_insert(0) += 1;
                }
                dist
            },
        }
    }
}

/// Compact summary returned by [`ArchitectureAnalyzer::get_summary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSummary {
    pub total_layers: usize,
    pub total_parameters: usize,
    pub total_flops: u64,
    pub average_layer_size: usize,
    pub layer_type_distribution: HashMap<LayerType, usize>,
}

/// Convenience constructor for [`LayerInfo`] with estimated memory usage and FLOPs.
pub fn create_layer_info(
    id: String,
    name: String,
    layer_type: LayerType,
    input_shape: Vec<usize>,
    output_shape: Vec<usize>,
    parameters: usize,
) -> LayerInfo {
    // Assume 4 bytes (f32) per parameter.
    let memory_usage = parameters * 4;
    let flops = estimate_flops(&layer_type, &input_shape, &output_shape, parameters);

    LayerInfo {
        id,
        name,
        layer_type,
        input_shape,
        output_shape,
        parameters,
        // Assume all parameters are trainable by default.
        trainable_parameters: parameters,
        memory_usage,
        flops,
        receptive_field: None,
    }
}

/// Rough estimate of FLOPs per forward pass based on layer type and tensor shapes.
fn estimate_flops(
    layer_type: &LayerType,
    input_shape: &[usize],
    output_shape: &[usize],
    parameters: usize,
) -> u64 {
    match layer_type {
        LayerType::Linear => {
            if input_shape.len() >= 2 && output_shape.len() >= 2 {
                let batch_size = input_shape[0] as u64;
                let input_features = input_shape[1] as u64;
                let output_features = output_shape[1] as u64;
                // Two FLOPs (multiply and add) per weight per sample.
                batch_size * input_features * output_features * 2
            } else {
                parameters as u64 * 2
            }
        },
        LayerType::Conv2D => {
            if output_shape.len() >= 4 {
                // Expected output layout: [batch, channels, height, width].
                let batch_size = output_shape[0] as u64;
                let output_channels = output_shape[1] as u64;
                let output_h = output_shape[2] as u64;
                let output_w = output_shape[3] as u64;
                batch_size
                    * output_channels
                    * output_h
                    * output_w
                    * (parameters as u64 / output_channels.max(1)).max(1)
                    * 2
            } else {
                parameters as u64 * 2
            }
        },
        LayerType::Attention => {
            if input_shape.len() >= 3 {
                // Expected input layout: [batch, sequence, hidden]; self-attention
                // is quadratic in sequence length.
                let batch_size = input_shape[0] as u64;
                let seq_len = input_shape[1] as u64;
                let hidden_size = input_shape[2] as u64;
                batch_size * seq_len * seq_len * hidden_size * 4
            } else {
                parameters as u64 * 4
            }
        },
        _ => {
            // Fall back to one FLOP per parameter for other layer types.
            parameters as u64
        },
    }
}
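
// Usage sketch: build a couple of hypothetical layers with `create_layer_info`,
// register them on an `ArchitectureAnalyzer`, and inspect the synchronous
// `get_summary` output. The layer names, shapes, and parameter counts below are
// illustrative assumptions; the async `analyze()` path is not exercised here
// because driving it would additionally require an async runtime (e.g. tokio).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn summary_reflects_registered_layers() {
        let mut analyzer = ArchitectureAnalyzer::new(ArchitectureAnalysisConfig::default());

        // Hypothetical two-layer MLP: 784 -> 128 -> 10, batch size 32.
        analyzer.register_layer(create_layer_info(
            "fc1".to_string(),
            "fc1".to_string(),
            LayerType::Linear,
            vec![32, 784],
            vec![32, 128],
            784 * 128 + 128,
        ));
        analyzer.register_layer(create_layer_info(
            "fc2".to_string(),
            "fc2".to_string(),
            LayerType::Linear,
            vec![32, 128],
            vec![32, 10],
            128 * 10 + 10,
        ));

        let summary = analyzer.get_summary();
        assert_eq!(summary.total_layers, 2);
        assert_eq!(summary.total_parameters, 784 * 128 + 128 + 128 * 10 + 10);
        // estimate_flops counts 2 * batch * in_features * out_features per Linear layer.
        assert_eq!(summary.total_flops, 2 * 32 * 784 * 128 + 2 * 32 * 128 * 10);
        assert_eq!(summary.layer_type_distribution[&LayerType::Linear], 2);
    }
}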