1use anyhow::{Context, Result};
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info, warn};
13
14#[derive(Debug)]
23pub struct LargeModelVisualizer {
24 config: LargeModelVisualizerConfig,
26 layer_cache: Arc<RwLock<HashMap<String, LayerMetadata>>>,
28 state: Arc<RwLock<VisualizationState>>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct LargeModelVisualizerConfig {
35 pub enable_smart_sampling: bool,
37 pub max_full_layers: usize,
39 pub sampling_strategy: SamplingStrategy,
41 pub enable_hierarchical: bool,
43 pub enable_streaming: bool,
45 pub max_memory_mb: usize,
47 pub stream_chunk_size: usize,
49 pub enable_progressive_loading: bool,
51 pub output_format: VisualizationFormat,
53}
54
55impl Default for LargeModelVisualizerConfig {
56 fn default() -> Self {
57 Self {
58 enable_smart_sampling: true,
59 max_full_layers: 50,
60 sampling_strategy: SamplingStrategy::Adaptive,
61 enable_hierarchical: true,
62 enable_streaming: true,
63 max_memory_mb: 1024, stream_chunk_size: 10,
65 enable_progressive_loading: true,
66 output_format: VisualizationFormat::InteractiveSvg,
67 }
68 }
69}
70
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
73pub enum SamplingStrategy {
74 Uniform,
76 Adaptive,
78 Representative,
80 ImportanceBased,
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
86pub enum VisualizationFormat {
87 StaticPng,
89 StaticSvg,
91 InteractiveSvg,
93 InteractiveHtml,
95 TextSummary,
97 JsonMetadata,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct LayerMetadata {
104 pub name: String,
106 pub index: usize,
108 pub layer_type: String,
110 pub param_count: usize,
112 pub memory_mb: f64,
114 pub compute_flops: u64,
116 pub input_shape: Vec<usize>,
118 pub output_shape: Vec<usize>,
120 pub is_sampled: bool,
122}
123
/// Mutable counters shared by the visualizer's methods (guarded by an `RwLock`).
#[derive(Debug, Clone, Default)]
struct VisualizationState {
    /// Number of layers currently in the cache.
    total_layers: usize,
    /// Reserved for loaded-layer tracking; nothing in this file reads or writes it yet.
    #[allow(dead_code)]
    loaded_layers: Vec<String>,
    /// Running total of `LayerMetadata::memory_mb` for registered layers.
    current_memory_mb: f64,
    /// Reported by `get_progress`.
    /// NOTE(review): nothing in this file writes it, so it stays at its default of 0.0.
    progress: f64,
    /// Reserved completion flag; nothing in this file reads or writes it yet.
    #[allow(dead_code)]
    is_complete: bool,
}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct VisualizationResult {
144 pub output_path: Option<String>,
146 pub inline_data: Option<Vec<u8>>,
148 pub stats: VisualizationStats,
150 pub sampled_layers: Vec<usize>,
152 pub model_stats: ModelStatistics,
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct VisualizationStats {
159 pub layers_visualized: usize,
161 pub total_layers: usize,
163 pub sampling_ratio: f64,
165 pub memory_used_mb: f64,
167 pub time_taken_secs: f64,
169 pub output_size_bytes: usize,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct ModelStatistics {
176 pub total_params: usize,
178 pub total_memory_mb: f64,
180 pub total_gflops: f64,
182 pub max_depth: usize,
184 pub layer_types: HashMap<String, usize>,
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct LayerGroup {
191 pub name: String,
193 pub layers: Vec<usize>,
195 pub collapsed: bool,
197 pub summary: GroupSummary,
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct GroupSummary {
204 pub param_count: usize,
206 pub memory_mb: f64,
208 pub avg_compute_flops: u64,
210}
211
212impl LargeModelVisualizer {
213 pub fn new(config: LargeModelVisualizerConfig) -> Self {
226 info!("Initializing large model visualizer");
227 Self {
228 config,
229 layer_cache: Arc::new(RwLock::new(HashMap::new())),
230 state: Arc::new(RwLock::new(VisualizationState::default())),
231 }
232 }
233
234 pub fn add_layer(&self, metadata: LayerMetadata) -> Result<()> {
239 let mut cache = self.layer_cache.write();
240 let mut state = self.state.write();
241
242 cache.insert(metadata.name.clone(), metadata.clone());
243 state.total_layers = cache.len();
244 state.current_memory_mb += metadata.memory_mb;
245
246 if state.current_memory_mb > self.config.max_memory_mb as f64 {
248 warn!(
249 "Memory limit exceeded: {:.1} MB > {} MB. Consider increasing max_memory_mb or enabling sampling",
250 state.current_memory_mb,
251 self.config.max_memory_mb
252 );
253 }
254
255 Ok(())
256 }
257
258 pub fn determine_sampling(&self) -> Result<Vec<usize>> {
263 let cache = self.layer_cache.read();
264 let state = self.state.read();
265
266 if !self.config.enable_smart_sampling || state.total_layers <= self.config.max_full_layers {
267 return Ok((0..state.total_layers).collect());
269 }
270
271 debug!(
272 "Applying {:?} sampling strategy for {} layers",
273 self.config.sampling_strategy, state.total_layers
274 );
275
276 let sampled_indices = match self.config.sampling_strategy {
277 SamplingStrategy::Uniform => self.uniform_sampling(state.total_layers),
278 SamplingStrategy::Adaptive => self.adaptive_sampling(&cache),
279 SamplingStrategy::Representative => self.representative_sampling(state.total_layers),
280 SamplingStrategy::ImportanceBased => self.importance_sampling(&cache),
281 };
282
283 Ok(sampled_indices)
284 }
285
286 fn uniform_sampling(&self, total_layers: usize) -> Vec<usize> {
288 let max_layers = self.config.max_full_layers;
289 let step = (total_layers as f64 / max_layers as f64).ceil() as usize;
290
291 (0..total_layers).step_by(step).collect()
292 }
293
294 fn adaptive_sampling(&self, cache: &HashMap<String, LayerMetadata>) -> Vec<usize> {
296 let mut layers: Vec<_> = cache.values().collect();
297 layers.sort_by_key(|l| l.index);
298
299 let mut sampled = Vec::new();
300 let max_layers = self.config.max_full_layers;
301
302 if !layers.is_empty() {
304 sampled.push(0);
305 sampled.push(layers.len() - 1);
306 }
307
308 let mut variances = Vec::new();
310 for i in 0..layers.len().saturating_sub(1) {
311 let complexity_diff =
312 (layers[i + 1].param_count as i64 - layers[i].param_count as i64).abs();
313 variances.push((i, complexity_diff));
314 }
315
316 variances.sort_by_key(|item| std::cmp::Reverse(item.1));
318
319 for (idx, _) in variances.iter().take(max_layers.saturating_sub(2)) {
321 sampled.push(*idx);
322 }
323
324 sampled.sort_unstable();
325 sampled.dedup();
326 sampled
327 }
328
329 fn representative_sampling(&self, total_layers: usize) -> Vec<usize> {
331 let mut sampled = Vec::new();
332
333 if total_layers == 0 {
334 return sampled;
335 }
336
337 sampled.extend(0..3.min(total_layers));
339
340 let mid = total_layers / 2;
342 sampled.extend((mid.saturating_sub(1))..=(mid + 1).min(total_layers - 1));
343
344 sampled.extend((total_layers.saturating_sub(3))..total_layers);
346
347 let remaining_budget = self.config.max_full_layers.saturating_sub(sampled.len());
349 let step = (total_layers as f64 / remaining_budget as f64).ceil() as usize;
350
351 for i in (0..total_layers).step_by(step) {
352 sampled.push(i);
353 }
354
355 sampled.sort_unstable();
356 sampled.dedup();
357 sampled
358 }
359
360 fn importance_sampling(&self, cache: &HashMap<String, LayerMetadata>) -> Vec<usize> {
362 let mut layers: Vec<_> = cache.values().collect();
363
364 layers.sort_by(|a, b| {
366 let score_a = (a.param_count as f64) + (a.compute_flops as f64 / 1e9);
367 let score_b = (b.param_count as f64) + (b.compute_flops as f64 / 1e9);
368 score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal)
369 });
370
371 layers.iter().take(self.config.max_full_layers).map(|l| l.index).collect()
372 }
373
374 pub fn create_layer_groups(&self) -> Result<Vec<LayerGroup>> {
378 let cache = self.layer_cache.read();
379
380 if !self.config.enable_hierarchical || cache.len() < 20 {
381 return Ok(Vec::new());
383 }
384
385 let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
386
387 for metadata in cache.values() {
389 groups.entry(metadata.layer_type.clone()).or_default().push(metadata.index);
390 }
391
392 let mut layer_groups = Vec::new();
394
395 for (layer_type, indices) in groups {
396 let group_layers: Vec<_> = indices
398 .iter()
399 .filter_map(|&idx| cache.values().find(|l| l.index == idx))
400 .collect();
401
402 let param_count: usize = group_layers.iter().map(|l| l.param_count).sum();
403 let memory_mb: f64 = group_layers.iter().map(|l| l.memory_mb).sum();
404 let avg_compute_flops = if !group_layers.is_empty() {
405 group_layers.iter().map(|l| l.compute_flops).sum::<u64>()
406 / group_layers.len() as u64
407 } else {
408 0
409 };
410
411 let indices_len = indices.len();
412 layer_groups.push(LayerGroup {
413 name: format!("{} ({} layers)", layer_type, indices_len),
414 layers: indices,
415 collapsed: indices_len > 10, summary: GroupSummary {
417 param_count,
418 memory_mb,
419 avg_compute_flops,
420 },
421 });
422 }
423
424 layer_groups.sort_by_key(|g| g.layers.first().copied().unwrap_or(0));
426
427 Ok(layer_groups)
428 }
429
430 pub fn visualize(&self, output_path: Option<String>) -> Result<VisualizationResult> {
438 info!("Starting large model visualization");
439
440 let start_time = std::time::Instant::now();
441
442 let sampled_layers = self.determine_sampling()?;
444
445 info!(
446 "Visualizing {} out of {} layers",
447 sampled_layers.len(),
448 self.state.read().total_layers
449 );
450
451 let model_stats = self.calculate_model_stats()?;
453
454 let (output_data, output_size) = match self.config.output_format {
456 VisualizationFormat::TextSummary => self.generate_text_summary(&sampled_layers)?,
457 VisualizationFormat::JsonMetadata => self.generate_json_metadata(&sampled_layers)?,
458 VisualizationFormat::StaticSvg => self.generate_static_svg(&sampled_layers)?,
459 VisualizationFormat::InteractiveSvg => {
460 self.generate_interactive_svg(&sampled_layers)?
461 },
462 VisualizationFormat::InteractiveHtml => {
463 self.generate_interactive_html(&sampled_layers)?
464 },
465 VisualizationFormat::StaticPng => {
466 anyhow::bail!("PNG generation not yet implemented - use SVG or HTML instead")
467 },
468 };
469
470 if let Some(ref path) = output_path {
472 std::fs::write(path, &output_data)
473 .with_context(|| format!("Failed to write visualization to {}", path))?;
474 info!("Saved visualization to {}", path);
475 }
476
477 let time_taken = start_time.elapsed().as_secs_f64();
478 let state = self.state.read();
479
480 Ok(VisualizationResult {
481 output_path,
482 inline_data: if output_size < 1024 * 1024 { Some(output_data) } else { None }, stats: VisualizationStats {
484 layers_visualized: sampled_layers.len(),
485 total_layers: state.total_layers,
486 sampling_ratio: sampled_layers.len() as f64 / state.total_layers as f64,
487 memory_used_mb: state.current_memory_mb,
488 time_taken_secs: time_taken,
489 output_size_bytes: output_size,
490 },
491 sampled_layers,
492 model_stats,
493 })
494 }
495
496 fn calculate_model_stats(&self) -> Result<ModelStatistics> {
498 let cache = self.layer_cache.read();
499
500 let total_params: usize = cache.values().map(|l| l.param_count).sum();
501 let total_memory_mb: f64 = cache.values().map(|l| l.memory_mb).sum();
502 let total_gflops: f64 = cache.values().map(|l| l.compute_flops).sum::<u64>() as f64 / 1e9;
503 let max_depth = cache.values().map(|l| l.index).max().unwrap_or(0);
504
505 let mut layer_types: HashMap<String, usize> = HashMap::new();
506 for metadata in cache.values() {
507 *layer_types.entry(metadata.layer_type.clone()).or_insert(0) += 1;
508 }
509
510 Ok(ModelStatistics {
511 total_params,
512 total_memory_mb,
513 total_gflops,
514 max_depth,
515 layer_types,
516 })
517 }
518
519 fn generate_text_summary(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
521 let cache = self.layer_cache.read();
522
523 let mut summary = String::from("=== Large Model Visualization Summary ===\n\n");
524
525 summary.push_str(&format!(
526 "Total Layers: {}\n",
527 self.state.read().total_layers
528 ));
529 summary.push_str(&format!("Visualized Layers: {}\n\n", sampled_layers.len()));
530
531 summary.push_str("Layer Details:\n");
532 for &idx in sampled_layers {
533 if let Some(layer) = cache.values().find(|l| l.index == idx) {
534 summary.push_str(&format!(
535 " [{}] {} - {} params, {:.2} MB, {:.1} GFLOPS\n",
536 layer.index,
537 layer.name,
538 layer.param_count,
539 layer.memory_mb,
540 layer.compute_flops as f64 / 1e9
541 ));
542 }
543 }
544
545 let bytes = summary.into_bytes();
546 let size = bytes.len();
547 Ok((bytes, size))
548 }
549
550 fn generate_json_metadata(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
552 let cache = self.layer_cache.read();
553
554 let layers: Vec<_> = sampled_layers
555 .iter()
556 .filter_map(|&idx| cache.values().find(|l| l.index == idx).cloned())
557 .collect();
558
559 let json = serde_json::to_string_pretty(&layers)?;
560 let bytes = json.into_bytes();
561 let size = bytes.len();
562 Ok((bytes, size))
563 }
564
565 fn generate_static_svg(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
567 let cache = self.layer_cache.read();
568
569 let mut svg = String::from(
570 r#"<?xml version="1.0" encoding="UTF-8"?>
571<svg xmlns="http://www.w3.org/2000/svg" width="1200" height="800" viewBox="0 0 1200 800">
572<style>
573.layer { fill: #4a90e2; stroke: #2c5aa0; stroke-width: 2; }
574.layer-text { fill: white; font-family: Arial, sans-serif; font-size: 12px; }
575.title { font-family: Arial, sans-serif; font-size: 20px; font-weight: bold; }
576</style>
577<text x="600" y="30" class="title" text-anchor="middle">Model Architecture</text>
578"#,
579 );
580
581 let layer_height = 60;
582 let layer_width = 200;
583 let x_offset = 500;
584 let y_start = 60;
585
586 for (i, &idx) in sampled_layers.iter().enumerate() {
587 if let Some(layer) = cache.values().find(|l| l.index == idx) {
588 let y = y_start + i * (layer_height + 20);
589
590 svg.push_str(&format!(
591 r#"<rect x="{}" y="{}" width="{}" height="{}" class="layer" />
592<text x="{}" y="{}" class="layer-text" text-anchor="middle">{}</text>
593<text x="{}" y="{}" class="layer-text" text-anchor="middle">{:.1}M params</text>
594"#,
595 x_offset,
596 y,
597 layer_width,
598 layer_height,
599 x_offset + layer_width / 2,
600 y + 25,
601 layer.name,
602 x_offset + layer_width / 2,
603 y + 45,
604 layer.param_count as f64 / 1e6
605 ));
606 }
607 }
608
609 svg.push_str("</svg>");
610
611 let bytes = svg.into_bytes();
612 let size = bytes.len();
613 Ok((bytes, size))
614 }
615
616 fn generate_interactive_svg(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
618 self.generate_static_svg(sampled_layers)
621 }
622
623 fn generate_interactive_html(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
625 let cache = self.layer_cache.read();
626 let model_stats = self.calculate_model_stats()?;
627
628 let mut html = String::from(
629 r#"<!DOCTYPE html>
630<html>
631<head>
632<meta charset="UTF-8">
633<title>Large Model Visualization</title>
634<style>
635body { font-family: Arial, sans-serif; margin: 20px; background: #f5f5f5; }
636.container { max-width: 1200px; margin: 0 auto; }
637.header { background: #4a90e2; color: white; padding: 20px; border-radius: 8px; }
638.stats { display: grid; grid-template-columns: repeat(4, 1fr); gap: 15px; margin: 20px 0; }
639.stat-card { background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
640.layer-list { background: white; padding: 20px; border-radius: 8px; }
641.layer { padding: 10px; margin: 5px 0; background: #f9f9f9; border-left: 4px solid #4a90e2; }
642</style>
643</head>
644<body>
645<div class="container">
646<div class="header">
647<h1>Large Model Visualization</h1>
648<p>Interactive view of model architecture</p>
649</div>
650<div class="stats">
651"#,
652 );
653
654 html.push_str(&format!(
656 r#"<div class="stat-card">
657<h3>{:.1}M</h3>
658<p>Total Parameters</p>
659</div>
660<div class="stat-card">
661<h3>{:.1} GB</h3>
662<p>Total Memory</p>
663</div>
664<div class="stat-card">
665<h3>{}</h3>
666<p>Total Layers</p>
667</div>
668<div class="stat-card">
669<h3>{}/{}</h3>
670<p>Visualized/Total</p>
671</div>
672"#,
673 model_stats.total_params as f64 / 1e6,
674 model_stats.total_memory_mb / 1024.0,
675 model_stats.max_depth + 1,
676 sampled_layers.len(),
677 self.state.read().total_layers
678 ));
679
680 html.push_str("</div><div class=\"layer-list\"><h2>Layer Details</h2>");
681
682 for &idx in sampled_layers {
684 if let Some(layer) = cache.values().find(|l| l.index == idx) {
685 html.push_str(&format!(
686 r#"<div class="layer">
687<strong>[{}] {}</strong><br>
688Type: {} | Parameters: {:.1}M | Memory: {:.2} MB | Compute: {:.1} GFLOPS
689</div>
690"#,
691 layer.index,
692 layer.name,
693 layer.layer_type,
694 layer.param_count as f64 / 1e6,
695 layer.memory_mb,
696 layer.compute_flops as f64 / 1e9
697 ));
698 }
699 }
700
701 html.push_str("</div></div></body></html>");
702
703 let bytes = html.into_bytes();
704 let size = bytes.len();
705 Ok((bytes, size))
706 }
707
708 pub fn get_progress(&self) -> f64 {
710 self.state.read().progress
711 }
712
713 pub fn get_memory_stats(&self) -> MemoryStats {
715 let state = self.state.read();
716 MemoryStats {
717 current_mb: state.current_memory_mb,
718 max_mb: self.config.max_memory_mb as f64,
719 utilization_pct: (state.current_memory_mb / self.config.max_memory_mb as f64 * 100.0)
720 .min(100.0),
721 }
722 }
723}
724
725#[derive(Debug, Clone, Serialize, Deserialize)]
727pub struct MemoryStats {
728 pub current_mb: f64,
730 pub max_mb: f64,
732 pub utilization_pct: f64,
734}
735
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Linear` layer named `layer_<index>` whose param count, memory, and FLOPs
    /// are scaled by `scale`.
    fn linear_layer(index: usize, scale: usize) -> LayerMetadata {
        LayerMetadata {
            name: format!("layer_{}", index),
            index,
            layer_type: "Linear".to_string(),
            param_count: 1024 * 1024 * scale,
            memory_mb: 4.0 * scale as f64,
            compute_flops: 1_000_000_000 * scale as u64,
            input_shape: vec![512],
            output_shape: vec![512],
            is_sampled: false,
        }
    }

    /// Creates a visualizer and registers `count` layers produced by `make_layer`.
    fn visualizer_with_layers(
        config: LargeModelVisualizerConfig,
        count: usize,
        make_layer: impl Fn(usize) -> LayerMetadata,
    ) -> Result<LargeModelVisualizer> {
        let visualizer = LargeModelVisualizer::new(config);
        for i in 0..count {
            visualizer.add_layer(make_layer(i))?;
        }
        Ok(visualizer)
    }

    #[test]
    fn test_visualizer_creation() {
        let _visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig::default());
    }

    #[test]
    fn test_add_layers() -> Result<()> {
        let visualizer = visualizer_with_layers(
            LargeModelVisualizerConfig::default(),
            10,
            |i| linear_layer(i, 1),
        )?;

        assert_eq!(visualizer.get_memory_stats().current_mb, 40.0);
        Ok(())
    }

    #[test]
    fn test_uniform_sampling() -> Result<()> {
        let config = LargeModelVisualizerConfig {
            max_full_layers: 5,
            sampling_strategy: SamplingStrategy::Uniform,
            ..Default::default()
        };
        let visualizer = visualizer_with_layers(config, 20, |i| linear_layer(i, 1))?;

        assert_eq!(visualizer.determine_sampling()?.len(), 5);
        Ok(())
    }

    #[test]
    fn test_text_visualization() -> Result<()> {
        let config = LargeModelVisualizerConfig {
            output_format: VisualizationFormat::TextSummary,
            ..Default::default()
        };
        let visualizer = visualizer_with_layers(config, 5, |i| linear_layer(i, i + 1))?;

        let result = visualizer.visualize(None)?;
        assert_eq!(result.stats.layers_visualized, 5);
        assert!(result.stats.output_size_bytes > 0);
        Ok(())
    }
}