1use anyhow::{Context, Result};
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info, warn};
13
14#[derive(Debug)]
23pub struct LargeModelVisualizer {
24 config: LargeModelVisualizerConfig,
26 layer_cache: Arc<RwLock<HashMap<String, LayerMetadata>>>,
28 state: Arc<RwLock<VisualizationState>>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct LargeModelVisualizerConfig {
35 pub enable_smart_sampling: bool,
37 pub max_full_layers: usize,
39 pub sampling_strategy: SamplingStrategy,
41 pub enable_hierarchical: bool,
43 pub enable_streaming: bool,
45 pub max_memory_mb: usize,
47 pub stream_chunk_size: usize,
49 pub enable_progressive_loading: bool,
51 pub output_format: VisualizationFormat,
53}
54
55impl Default for LargeModelVisualizerConfig {
56 fn default() -> Self {
57 Self {
58 enable_smart_sampling: true,
59 max_full_layers: 50,
60 sampling_strategy: SamplingStrategy::Adaptive,
61 enable_hierarchical: true,
62 enable_streaming: true,
63 max_memory_mb: 1024, stream_chunk_size: 10,
65 enable_progressive_loading: true,
66 output_format: VisualizationFormat::InteractiveSvg,
67 }
68 }
69}
70
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
73pub enum SamplingStrategy {
74 Uniform,
76 Adaptive,
78 Representative,
80 ImportanceBased,
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
86pub enum VisualizationFormat {
87 StaticPng,
89 StaticSvg,
91 InteractiveSvg,
93 InteractiveHtml,
95 TextSummary,
97 JsonMetadata,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct LayerMetadata {
104 pub name: String,
106 pub index: usize,
108 pub layer_type: String,
110 pub param_count: usize,
112 pub memory_mb: f64,
114 pub compute_flops: u64,
116 pub input_shape: Vec<usize>,
118 pub output_shape: Vec<usize>,
120 pub is_sampled: bool,
122}
123
/// Mutable bookkeeping shared across visualizer calls.
#[derive(Clone, Debug, Default)]
struct VisualizationState {
    /// Number of layers currently in the cache.
    total_layers: usize,
    // Never read in this module; kept for future progressive loading.
    #[allow(dead_code)]
    loaded_layers: Vec<String>,
    /// Sum of `memory_mb` over the registered layers.
    current_memory_mb: f64,
    /// Reported by `get_progress`.
    progress: f64,
    // Never read in this module.
    #[allow(dead_code)]
    is_complete: bool,
}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct VisualizationResult {
144 pub output_path: Option<String>,
146 pub inline_data: Option<Vec<u8>>,
148 pub stats: VisualizationStats,
150 pub sampled_layers: Vec<usize>,
152 pub model_stats: ModelStatistics,
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct VisualizationStats {
159 pub layers_visualized: usize,
161 pub total_layers: usize,
163 pub sampling_ratio: f64,
165 pub memory_used_mb: f64,
167 pub time_taken_secs: f64,
169 pub output_size_bytes: usize,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct ModelStatistics {
176 pub total_params: usize,
178 pub total_memory_mb: f64,
180 pub total_gflops: f64,
182 pub max_depth: usize,
184 pub layer_types: HashMap<String, usize>,
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct LayerGroup {
191 pub name: String,
193 pub layers: Vec<usize>,
195 pub collapsed: bool,
197 pub summary: GroupSummary,
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct GroupSummary {
204 pub param_count: usize,
206 pub memory_mb: f64,
208 pub avg_compute_flops: u64,
210}
211
212impl LargeModelVisualizer {
213 pub fn new(config: LargeModelVisualizerConfig) -> Self {
226 info!("Initializing large model visualizer");
227 Self {
228 config,
229 layer_cache: Arc::new(RwLock::new(HashMap::new())),
230 state: Arc::new(RwLock::new(VisualizationState::default())),
231 }
232 }
233
234 pub fn add_layer(&self, metadata: LayerMetadata) -> Result<()> {
239 let mut cache = self.layer_cache.write();
240 let mut state = self.state.write();
241
242 cache.insert(metadata.name.clone(), metadata.clone());
243 state.total_layers = cache.len();
244 state.current_memory_mb += metadata.memory_mb;
245
246 if state.current_memory_mb > self.config.max_memory_mb as f64 {
248 warn!(
249 "Memory limit exceeded: {:.1} MB > {} MB. Consider increasing max_memory_mb or enabling sampling",
250 state.current_memory_mb,
251 self.config.max_memory_mb
252 );
253 }
254
255 Ok(())
256 }
257
258 pub fn determine_sampling(&self) -> Result<Vec<usize>> {
263 let cache = self.layer_cache.read();
264 let state = self.state.read();
265
266 if !self.config.enable_smart_sampling || state.total_layers <= self.config.max_full_layers {
267 return Ok((0..state.total_layers).collect());
269 }
270
271 debug!(
272 "Applying {:?} sampling strategy for {} layers",
273 self.config.sampling_strategy, state.total_layers
274 );
275
276 let sampled_indices = match self.config.sampling_strategy {
277 SamplingStrategy::Uniform => self.uniform_sampling(state.total_layers),
278 SamplingStrategy::Adaptive => self.adaptive_sampling(&cache),
279 SamplingStrategy::Representative => self.representative_sampling(state.total_layers),
280 SamplingStrategy::ImportanceBased => self.importance_sampling(&cache),
281 };
282
283 Ok(sampled_indices)
284 }
285
286 fn uniform_sampling(&self, total_layers: usize) -> Vec<usize> {
288 let max_layers = self.config.max_full_layers;
289 let step = (total_layers as f64 / max_layers as f64).ceil() as usize;
290
291 (0..total_layers).step_by(step).collect()
292 }
293
294 fn adaptive_sampling(&self, cache: &HashMap<String, LayerMetadata>) -> Vec<usize> {
296 let mut layers: Vec<_> = cache.values().collect();
297 layers.sort_by_key(|l| l.index);
298
299 let mut sampled = Vec::new();
300 let max_layers = self.config.max_full_layers;
301
302 if !layers.is_empty() {
304 sampled.push(0);
305 sampled.push(layers.len() - 1);
306 }
307
308 let mut variances = Vec::new();
310 for i in 0..layers.len().saturating_sub(1) {
311 let complexity_diff =
312 (layers[i + 1].param_count as i64 - layers[i].param_count as i64).abs();
313 variances.push((i, complexity_diff));
314 }
315
316 variances.sort_by_key(|item| std::cmp::Reverse(item.1));
318
319 for (idx, _) in variances.iter().take(max_layers.saturating_sub(2)) {
321 sampled.push(*idx);
322 }
323
324 sampled.sort_unstable();
325 sampled.dedup();
326 sampled
327 }
328
329 fn representative_sampling(&self, total_layers: usize) -> Vec<usize> {
331 let mut sampled = Vec::new();
332
333 if total_layers == 0 {
334 return sampled;
335 }
336
337 sampled.extend(0..3.min(total_layers));
339
340 let mid = total_layers / 2;
342 sampled.extend((mid.saturating_sub(1))..=(mid + 1).min(total_layers - 1));
343
344 sampled.extend((total_layers.saturating_sub(3))..total_layers);
346
347 let remaining_budget = self.config.max_full_layers.saturating_sub(sampled.len());
349 let step = (total_layers as f64 / remaining_budget as f64).ceil() as usize;
350
351 for i in (0..total_layers).step_by(step) {
352 sampled.push(i);
353 }
354
355 sampled.sort_unstable();
356 sampled.dedup();
357 sampled
358 }
359
360 fn importance_sampling(&self, cache: &HashMap<String, LayerMetadata>) -> Vec<usize> {
362 let mut layers: Vec<_> = cache.values().collect();
363
364 layers.sort_by(|a, b| {
366 let score_a = (a.param_count as f64) + (a.compute_flops as f64 / 1e9);
367 let score_b = (b.param_count as f64) + (b.compute_flops as f64 / 1e9);
368 score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal)
369 });
370
371 layers.iter().take(self.config.max_full_layers).map(|l| l.index).collect()
372 }
373
374 pub fn create_layer_groups(&self) -> Result<Vec<LayerGroup>> {
378 let cache = self.layer_cache.read();
379
380 if !self.config.enable_hierarchical || cache.len() < 20 {
381 return Ok(Vec::new());
383 }
384
385 let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
386
387 for metadata in cache.values() {
389 groups.entry(metadata.layer_type.clone()).or_default().push(metadata.index);
390 }
391
392 let mut layer_groups = Vec::new();
394
395 for (layer_type, indices) in groups {
396 let group_layers: Vec<_> = indices
398 .iter()
399 .filter_map(|&idx| cache.values().find(|l| l.index == idx))
400 .collect();
401
402 let param_count: usize = group_layers.iter().map(|l| l.param_count).sum();
403 let memory_mb: f64 = group_layers.iter().map(|l| l.memory_mb).sum();
404 let avg_compute_flops = if !group_layers.is_empty() {
405 group_layers.iter().map(|l| l.compute_flops).sum::<u64>()
406 / group_layers.len() as u64
407 } else {
408 0
409 };
410
411 let indices_len = indices.len();
412 layer_groups.push(LayerGroup {
413 name: format!("{} ({} layers)", layer_type, indices_len),
414 layers: indices,
415 collapsed: indices_len > 10, summary: GroupSummary {
417 param_count,
418 memory_mb,
419 avg_compute_flops,
420 },
421 });
422 }
423
424 layer_groups.sort_by_key(|g| g.layers.first().copied().unwrap_or(0));
426
427 Ok(layer_groups)
428 }
429
430 pub fn visualize(&self, output_path: Option<String>) -> Result<VisualizationResult> {
438 info!("Starting large model visualization");
439
440 let start_time = std::time::Instant::now();
441
442 let sampled_layers = self.determine_sampling()?;
444
445 info!(
446 "Visualizing {} out of {} layers",
447 sampled_layers.len(),
448 self.state.read().total_layers
449 );
450
451 let model_stats = self.calculate_model_stats()?;
453
454 let (output_data, output_size) = match self.config.output_format {
456 VisualizationFormat::TextSummary => self.generate_text_summary(&sampled_layers)?,
457 VisualizationFormat::JsonMetadata => self.generate_json_metadata(&sampled_layers)?,
458 VisualizationFormat::StaticSvg => self.generate_static_svg(&sampled_layers)?,
459 VisualizationFormat::InteractiveSvg => {
460 self.generate_interactive_svg(&sampled_layers)?
461 },
462 VisualizationFormat::InteractiveHtml => {
463 self.generate_interactive_html(&sampled_layers)?
464 },
465 VisualizationFormat::StaticPng => {
466 #[cfg(feature = "video")]
471 {
472 self.generate_png(&sampled_layers)?
473 }
474 #[cfg(not(feature = "video"))]
475 {
476 return Err(anyhow::anyhow!(
477 "PNG generation requires the `video` feature. \
478 Rebuild with `--features video`, or use \
479 VisualizationFormat::StaticSvg / InteractiveHtml instead."
480 ));
481 }
482 },
483 };
484
485 if let Some(ref path) = output_path {
487 std::fs::write(path, &output_data)
488 .with_context(|| format!("Failed to write visualization to {}", path))?;
489 info!("Saved visualization to {}", path);
490 }
491
492 let time_taken = start_time.elapsed().as_secs_f64();
493 let state = self.state.read();
494
495 Ok(VisualizationResult {
496 output_path,
497 inline_data: if output_size < 1024 * 1024 { Some(output_data) } else { None }, stats: VisualizationStats {
499 layers_visualized: sampled_layers.len(),
500 total_layers: state.total_layers,
501 sampling_ratio: sampled_layers.len() as f64 / state.total_layers as f64,
502 memory_used_mb: state.current_memory_mb,
503 time_taken_secs: time_taken,
504 output_size_bytes: output_size,
505 },
506 sampled_layers,
507 model_stats,
508 })
509 }
510
511 fn calculate_model_stats(&self) -> Result<ModelStatistics> {
513 let cache = self.layer_cache.read();
514
515 let total_params: usize = cache.values().map(|l| l.param_count).sum();
516 let total_memory_mb: f64 = cache.values().map(|l| l.memory_mb).sum();
517 let total_gflops: f64 = cache.values().map(|l| l.compute_flops).sum::<u64>() as f64 / 1e9;
518 let max_depth = cache.values().map(|l| l.index).max().unwrap_or(0);
519
520 let mut layer_types: HashMap<String, usize> = HashMap::new();
521 for metadata in cache.values() {
522 *layer_types.entry(metadata.layer_type.clone()).or_insert(0) += 1;
523 }
524
525 Ok(ModelStatistics {
526 total_params,
527 total_memory_mb,
528 total_gflops,
529 max_depth,
530 layer_types,
531 })
532 }
533
534 fn generate_text_summary(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
536 let cache = self.layer_cache.read();
537
538 let mut summary = String::from("=== Large Model Visualization Summary ===\n\n");
539
540 summary.push_str(&format!(
541 "Total Layers: {}\n",
542 self.state.read().total_layers
543 ));
544 summary.push_str(&format!("Visualized Layers: {}\n\n", sampled_layers.len()));
545
546 summary.push_str("Layer Details:\n");
547 for &idx in sampled_layers {
548 if let Some(layer) = cache.values().find(|l| l.index == idx) {
549 summary.push_str(&format!(
550 " [{}] {} - {} params, {:.2} MB, {:.1} GFLOPS\n",
551 layer.index,
552 layer.name,
553 layer.param_count,
554 layer.memory_mb,
555 layer.compute_flops as f64 / 1e9
556 ));
557 }
558 }
559
560 let bytes = summary.into_bytes();
561 let size = bytes.len();
562 Ok((bytes, size))
563 }
564
565 fn generate_json_metadata(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
567 let cache = self.layer_cache.read();
568
569 let layers: Vec<_> = sampled_layers
570 .iter()
571 .filter_map(|&idx| cache.values().find(|l| l.index == idx).cloned())
572 .collect();
573
574 let json = serde_json::to_string_pretty(&layers)?;
575 let bytes = json.into_bytes();
576 let size = bytes.len();
577 Ok((bytes, size))
578 }
579
580 fn generate_static_svg(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
582 let cache = self.layer_cache.read();
583
584 let mut svg = String::from(
585 r#"<?xml version="1.0" encoding="UTF-8"?>
586<svg xmlns="http://www.w3.org/2000/svg" width="1200" height="800" viewBox="0 0 1200 800">
587<style>
588.layer { fill: #4a90e2; stroke: #2c5aa0; stroke-width: 2; }
589.layer-text { fill: white; font-family: Arial, sans-serif; font-size: 12px; }
590.title { font-family: Arial, sans-serif; font-size: 20px; font-weight: bold; }
591</style>
592<text x="600" y="30" class="title" text-anchor="middle">Model Architecture</text>
593"#,
594 );
595
596 let layer_height = 60;
597 let layer_width = 200;
598 let x_offset = 500;
599 let y_start = 60;
600
601 for (i, &idx) in sampled_layers.iter().enumerate() {
602 if let Some(layer) = cache.values().find(|l| l.index == idx) {
603 let y = y_start + i * (layer_height + 20);
604
605 svg.push_str(&format!(
606 r#"<rect x="{}" y="{}" width="{}" height="{}" class="layer" />
607<text x="{}" y="{}" class="layer-text" text-anchor="middle">{}</text>
608<text x="{}" y="{}" class="layer-text" text-anchor="middle">{:.1}M params</text>
609"#,
610 x_offset,
611 y,
612 layer_width,
613 layer_height,
614 x_offset + layer_width / 2,
615 y + 25,
616 layer.name,
617 x_offset + layer_width / 2,
618 y + 45,
619 layer.param_count as f64 / 1e6
620 ));
621 }
622 }
623
624 svg.push_str("</svg>");
625
626 let bytes = svg.into_bytes();
627 let size = bytes.len();
628 Ok((bytes, size))
629 }
630
631 fn generate_interactive_svg(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
633 let cache = self.layer_cache.read();
634
635 let layer_height = 60usize;
636 let layer_width = 200usize;
637 let x_offset = 500usize;
638 let y_start = 60usize;
639 let svg_height = y_start + sampled_layers.len() * (layer_height + 20) + 40;
640 let svg_width = 1200usize;
641
642 let mut layer_elems = String::new();
644 for (i, &idx) in sampled_layers.iter().enumerate() {
645 if let Some(layer) = cache.values().find(|l| l.index == idx) {
646 let y = y_start + i * (layer_height + 20);
647 layer_elems.push_str(&format!(
648 r#"<rect x="{x}" y="{y}" width="{w}" height="{h}" class="layer" />
649<text x="{cx}" y="{ty}" class="layer-text" text-anchor="middle">{name}</text>
650<text x="{cx}" y="{py}" class="layer-text" text-anchor="middle">{params:.1}M params</text>
651"#,
652 x = x_offset,
653 y = y,
654 w = layer_width,
655 h = layer_height,
656 cx = x_offset + layer_width / 2,
657 ty = y + 25,
658 py = y + 45,
659 name = layer.name,
660 params = layer.param_count as f64 / 1e6
661 ));
662 }
663 }
664
665 let svg = format!(
670 r#"<?xml version="1.0" encoding="UTF-8"?>
671<svg xmlns="http://www.w3.org/2000/svg"
672 xmlns:xlink="http://www.w3.org/1999/xlink"
673 id="svg-root"
674 width="{width}" height="{height}"
675 viewBox="0 0 {width} {height}"
676 style="cursor:grab;user-select:none;">
677<style>
678.layer {{ fill: #4a90e2; stroke: #2c5aa0; stroke-width: 2; }}
679.layer-text {{ fill: white; font-family: Arial, sans-serif; font-size: 12px; }}
680.title {{ font-family: Arial, sans-serif; font-size: 20px; font-weight: bold; }}
681</style>
682<text x="{title_x}" y="30" class="title" text-anchor="middle">Model Architecture (interactive)</text>
683<g id="viewport">
684{layers}
685</g>
686<script type="text/javascript"><![CDATA[
687(function() {{
688 var svg = document.getElementById('svg-root');
689 var vp = document.getElementById('viewport');
690 var tx = 0, ty = 0, scale = 1.0;
691 var dragging = false;
692 var startX = 0, startY = 0;
693
694 function applyTransform() {{
695 vp.setAttribute('transform',
696 'translate(' + tx + ',' + ty + ') scale(' + scale + ')');
697 }}
698
699 // Pan: mousedown / mousemove / mouseup
700 svg.addEventListener('mousedown', function(e) {{
701 dragging = true;
702 startX = e.clientX - tx;
703 startY = e.clientY - ty;
704 svg.style.cursor = 'grabbing';
705 e.preventDefault();
706 }});
707 window.addEventListener('mousemove', function(e) {{
708 if (!dragging) return;
709 tx = e.clientX - startX;
710 ty = e.clientY - startY;
711 applyTransform();
712 }});
713 window.addEventListener('mouseup', function() {{
714 dragging = false;
715 svg.style.cursor = 'grab';
716 }});
717
718 // Touch pan
719 var lastTouch = null;
720 svg.addEventListener('touchstart', function(e) {{
721 if (e.touches.length === 1) {{
722 lastTouch = e.touches[0];
723 }}
724 e.preventDefault();
725 }}, {{ passive: false }});
726 svg.addEventListener('touchmove', function(e) {{
727 if (e.touches.length === 1 && lastTouch) {{
728 var t = e.touches[0];
729 tx += t.clientX - lastTouch.clientX;
730 ty += t.clientY - lastTouch.clientY;
731 lastTouch = t;
732 applyTransform();
733 }}
734 e.preventDefault();
735 }}, {{ passive: false }});
736 svg.addEventListener('touchend', function() {{ lastTouch = null; }});
737
738 // Zoom: mousewheel
739 svg.addEventListener('wheel', function(e) {{
740 e.preventDefault();
741 var delta = e.deltaY > 0 ? 0.9 : 1.1;
742 // Zoom towards cursor position
743 var rect = svg.getBoundingClientRect();
744 var mx = e.clientX - rect.left;
745 var my = e.clientY - rect.top;
746 tx = mx - (mx - tx) * delta;
747 ty = my - (my - ty) * delta;
748 scale = Math.max(0.1, Math.min(10.0, scale * delta));
749 applyTransform();
750 }}, {{ passive: false }});
751
752 // Double-click to reset
753 svg.addEventListener('dblclick', function() {{
754 tx = 0; ty = 0; scale = 1.0;
755 applyTransform();
756 }});
757}})();
758]]></script>
759</svg>"#,
760 width = svg_width,
761 height = svg_height,
762 title_x = svg_width / 2,
763 layers = layer_elems,
764 );
765
766 let bytes = svg.into_bytes();
767 let size = bytes.len();
768 Ok((bytes, size))
769 }
770
771 fn generate_interactive_html(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
773 let cache = self.layer_cache.read();
774 let model_stats = self.calculate_model_stats()?;
775
776 let mut html = String::from(
777 r#"<!DOCTYPE html>
778<html>
779<head>
780<meta charset="UTF-8">
781<title>Large Model Visualization</title>
782<style>
783body { font-family: Arial, sans-serif; margin: 20px; background: #f5f5f5; }
784.container { max-width: 1200px; margin: 0 auto; }
785.header { background: #4a90e2; color: white; padding: 20px; border-radius: 8px; }
786.stats { display: grid; grid-template-columns: repeat(4, 1fr); gap: 15px; margin: 20px 0; }
787.stat-card { background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
788.layer-list { background: white; padding: 20px; border-radius: 8px; }
789.layer { padding: 10px; margin: 5px 0; background: #f9f9f9; border-left: 4px solid #4a90e2; }
790</style>
791</head>
792<body>
793<div class="container">
794<div class="header">
795<h1>Large Model Visualization</h1>
796<p>Interactive view of model architecture</p>
797</div>
798<div class="stats">
799"#,
800 );
801
802 html.push_str(&format!(
804 r#"<div class="stat-card">
805<h3>{:.1}M</h3>
806<p>Total Parameters</p>
807</div>
808<div class="stat-card">
809<h3>{:.1} GB</h3>
810<p>Total Memory</p>
811</div>
812<div class="stat-card">
813<h3>{}</h3>
814<p>Total Layers</p>
815</div>
816<div class="stat-card">
817<h3>{}/{}</h3>
818<p>Visualized/Total</p>
819</div>
820"#,
821 model_stats.total_params as f64 / 1e6,
822 model_stats.total_memory_mb / 1024.0,
823 model_stats.max_depth + 1,
824 sampled_layers.len(),
825 self.state.read().total_layers
826 ));
827
828 html.push_str("</div><div class=\"layer-list\"><h2>Layer Details</h2>");
829
830 for &idx in sampled_layers {
832 if let Some(layer) = cache.values().find(|l| l.index == idx) {
833 html.push_str(&format!(
834 r#"<div class="layer">
835<strong>[{}] {}</strong><br>
836Type: {} | Parameters: {:.1}M | Memory: {:.2} MB | Compute: {:.1} GFLOPS
837</div>
838"#,
839 layer.index,
840 layer.name,
841 layer.layer_type,
842 layer.param_count as f64 / 1e6,
843 layer.memory_mb,
844 layer.compute_flops as f64 / 1e9
845 ));
846 }
847 }
848
849 html.push_str("</div></div></body></html>");
850
851 let bytes = html.into_bytes();
852 let size = bytes.len();
853 Ok((bytes, size))
854 }
855
856 #[cfg(feature = "video")]
864 fn generate_png(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
865 use image::{ImageBuffer, Rgb};
866 use std::io::Cursor;
867
868 let cache = self.layer_cache.read();
869
870 let mut layers: Vec<&LayerMetadata> = sampled_layers
872 .iter()
873 .filter_map(|&idx| cache.values().find(|l| l.index == idx))
874 .collect();
875 layers.sort_by_key(|l| l.index);
876
877 const IMG_WIDTH: u32 = 1200;
879 const BAR_HEIGHT: u32 = 30;
880 const BAR_PADDING: u32 = 6;
881 const LEFT_MARGIN: u32 = 20;
882 const RIGHT_MARGIN: u32 = 20;
883
884 let row_height = BAR_HEIGHT + BAR_PADDING;
885 let img_height = if layers.is_empty() {
886 100
887 } else {
888 layers.len() as u32 * row_height + 2 * BAR_PADDING + 40 };
890
891 let max_params = layers.iter().map(|l| l.param_count).max().unwrap_or(1).max(1);
892
893 let max_memory = layers.iter().map(|l| l.memory_mb).fold(0.0_f64, f64::max).max(1.0);
894
895 let available_width = IMG_WIDTH - LEFT_MARGIN - RIGHT_MARGIN;
896
897 let mut img = ImageBuffer::<Rgb<u8>, Vec<u8>>::new(IMG_WIDTH, img_height);
898
899 for pixel in img.pixels_mut() {
901 *pixel = Rgb([245u8, 245u8, 250u8]);
902 }
903
904 for x in 0..IMG_WIDTH {
906 for y in 0..36 {
907 img.put_pixel(x, y, Rgb([74u8, 144u8, 226u8]));
908 }
909 }
910
911 for (i, layer) in layers.iter().enumerate() {
913 let bar_top = 40 + i as u32 * row_height;
914
915 let bar_w = ((layer.param_count as f64 / max_params as f64) * available_width as f64)
917 .round() as u32;
918 let bar_w = bar_w.max(4); let t = (layer.memory_mb / max_memory).clamp(0.0, 1.0) as f32;
922 let r = (t * 220.0) as u8;
923 let g = ((1.0 - t) * 100.0 + 40.0) as u8;
924 let b = ((1.0 - t) * 220.0) as u8;
925 let bar_colour = Rgb([r, g, b]);
926
927 for x in LEFT_MARGIN..(LEFT_MARGIN + bar_w).min(IMG_WIDTH - RIGHT_MARGIN) {
928 for y in bar_top..(bar_top + BAR_HEIGHT).min(img_height) {
929 img.put_pixel(x, y, bar_colour);
930 }
931 }
932 }
933
934 let mut png_bytes: Vec<u8> = Vec::new();
936 img.write_to(&mut Cursor::new(&mut png_bytes), image::ImageFormat::Png)
937 .with_context(|| "Failed to PNG-encode large model visualization")?;
938
939 let size = png_bytes.len();
940 Ok((png_bytes, size))
941 }
942
943 pub fn get_progress(&self) -> f64 {
945 self.state.read().progress
946 }
947
948 pub fn get_memory_stats(&self) -> MemoryStats {
950 let state = self.state.read();
951 MemoryStats {
952 current_mb: state.current_memory_mb,
953 max_mb: self.config.max_memory_mb as f64,
954 utilization_pct: (state.current_memory_mb / self.config.max_memory_mb as f64 * 100.0)
955 .min(100.0),
956 }
957 }
958}
959
960#[derive(Debug, Clone, Serialize, Deserialize)]
962pub struct MemoryStats {
963 pub current_mb: f64,
965 pub max_mb: f64,
967 pub utilization_pct: f64,
969}
970
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Linear` layer named `layer_{index}` with 512-wide input/output shapes.
    fn linear_layer(
        index: usize,
        param_count: usize,
        memory_mb: f64,
        compute_flops: u64,
    ) -> LayerMetadata {
        LayerMetadata {
            name: format!("layer_{}", index),
            index,
            layer_type: "Linear".to_string(),
            param_count,
            memory_mb,
            compute_flops,
            input_shape: vec![512],
            output_shape: vec![512],
            is_sampled: false,
        }
    }

    #[test]
    fn test_visualizer_creation() {
        let _visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig::default());
    }

    #[test]
    fn test_add_layers() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig::default());

        for index in 0..10 {
            visualizer.add_layer(linear_layer(index, 1024 * 1024, 4.0, 1_000_000_000))?;
        }

        assert_eq!(visualizer.get_memory_stats().current_mb, 40.0);
        Ok(())
    }

    #[test]
    fn test_uniform_sampling() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig {
            max_full_layers: 5,
            sampling_strategy: SamplingStrategy::Uniform,
            ..Default::default()
        });

        for index in 0..20 {
            visualizer.add_layer(linear_layer(index, 1024 * 1024, 4.0, 1_000_000_000))?;
        }

        assert_eq!(visualizer.determine_sampling()?.len(), 5);
        Ok(())
    }

    #[cfg(feature = "video")]
    #[test]
    fn test_png_visualization() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig {
            output_format: VisualizationFormat::StaticPng,
            ..Default::default()
        });

        for step in 1..=5_usize {
            visualizer.add_layer(linear_layer(
                step - 1,
                1024 * step,
                2.0 * step as f64,
                500_000_000 * step as u64,
            ))?;
        }

        let result = visualizer.visualize(None)?;

        assert_eq!(result.stats.layers_visualized, 5);
        assert!(
            result.stats.output_size_bytes > 0,
            "PNG output must be non-empty"
        );

        let data = result.inline_data.expect("inline data should be present for small PNG");
        assert!(
            data.starts_with(&[0x89, 0x50, 0x4E, 0x47]),
            "Output must start with PNG magic bytes"
        );

        Ok(())
    }

    #[test]
    fn test_text_visualization() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig {
            output_format: VisualizationFormat::TextSummary,
            ..Default::default()
        });

        for step in 1..=5_usize {
            visualizer.add_layer(linear_layer(
                step - 1,
                1024 * 1024 * step,
                4.0 * step as f64,
                1_000_000_000 * step as u64,
            ))?;
        }

        let result = visualizer.visualize(None)?;

        assert_eq!(result.stats.layers_visualized, 5);
        assert!(result.stats.output_size_bytes > 0);
        Ok(())
    }
}
1115}