Skip to main content

trustformers_debug/
large_model_viz.rs

1//! Large Model Visualization with Memory Efficiency
2//!
3//! This module provides optimized visualization for large transformer models,
4//! using smart sampling, hierarchical rendering, and memory-efficient techniques
5//! to handle models with billions of parameters.
6
7use anyhow::{Context, Result};
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info, warn};
13
/// Large model visualizer with memory-efficient rendering
///
/// Layer metadata is registered with [`LargeModelVisualizer::add_layer`]; a
/// subset of layers is then chosen according to [`LargeModelVisualizerConfig`]
/// and rendered in the configured [`VisualizationFormat`].
///
/// Features:
/// - Smart layer sampling (visualize representative layers)
/// - Hierarchical grouping of layers by type (collapse/expand sections)
/// - Streaming visualization (process in chunks)
/// - Memory budget tracking: `add_layer` sums per-layer `memory_mb` and logs a
///   warning when `max_memory_mb` is exceeded (entries are not evicted)
/// - Progressive loading
#[derive(Debug)]
pub struct LargeModelVisualizer {
    /// Configuration supplied to [`LargeModelVisualizer::new`]
    config: LargeModelVisualizerConfig,
    /// Layer metadata keyed by layer name; re-adding a name replaces the entry
    layer_cache: Arc<RwLock<HashMap<String, LayerMetadata>>>,
    /// Aggregate bookkeeping (layer count, summed memory) updated by `add_layer`
    state: Arc<RwLock<VisualizationState>>,
}
31
/// Configuration for large model visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LargeModelVisualizerConfig {
    /// Enable smart layer sampling; when false every layer is visualized
    pub enable_smart_sampling: bool,
    /// Maximum layers to visualize fully (rest are sampled). Models with at
    /// most this many layers are rendered in full; otherwise this is the
    /// sample budget handed to the sampling strategy.
    pub max_full_layers: usize,
    /// Sampling strategy used when the model exceeds `max_full_layers`
    pub sampling_strategy: SamplingStrategy,
    /// Enable hierarchical rendering (grouping by layer type in
    /// `create_layer_groups`)
    pub enable_hierarchical: bool,
    /// Enable streaming mode for very large models
    // NOTE(review): not read by `visualize` or the text/JSON/SVG/HTML
    // generators — confirm where it is consumed.
    pub enable_streaming: bool,
    /// Maximum memory for visualization (MB). Compared against the summed
    /// `memory_mb` of added layers; exceeding it only logs a warning.
    pub max_memory_mb: usize,
    /// Chunk size for streaming (number of layers)
    pub stream_chunk_size: usize,
    /// Enable progressive detail loading
    // NOTE(review): like `enable_streaming`, no use is evident in the
    // rendering paths — confirm.
    pub enable_progressive_loading: bool,
    /// Visualization format produced by `visualize`
    pub output_format: VisualizationFormat,
}
54
55impl Default for LargeModelVisualizerConfig {
56    fn default() -> Self {
57        Self {
58            enable_smart_sampling: true,
59            max_full_layers: 50,
60            sampling_strategy: SamplingStrategy::Adaptive,
61            enable_hierarchical: true,
62            enable_streaming: true,
63            max_memory_mb: 1024, // 1 GB
64            stream_chunk_size: 10,
65            enable_progressive_loading: true,
66            output_format: VisualizationFormat::InteractiveSvg,
67        }
68    }
69}
70
/// Layer sampling strategy for large models
///
/// Consulted by [`LargeModelVisualizer::determine_sampling`] only when smart
/// sampling is enabled and the model has more than `max_full_layers` layers.
/// Every strategy uses `max_full_layers` as its sample budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SamplingStrategy {
    /// Uniform sampling: every `ceil(total / max_full_layers)`-th layer,
    /// starting from 0
    Uniform,
    /// Adaptive sampling: first and last layer plus the layers whose
    /// `param_count` differs most from the following layer
    Adaptive,
    /// Representative sampling: first, middle and last three layers plus
    /// evenly spaced layers filling the remaining budget
    Representative,
    /// Importance-based: layers with the largest
    /// `param_count + compute_flops / 1e9` score
    ImportanceBased,
}
83
/// Visualization output format
///
/// Selected via `LargeModelVisualizerConfig::output_format` and dispatched in
/// [`LargeModelVisualizer::visualize`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum VisualizationFormat {
    /// Static PNG image (memory efficient).
    /// Requires the `video` cargo feature; without it `visualize` returns an
    /// error for this format.
    StaticPng,
    /// Static SVG (scalable but larger): one labelled box per sampled layer
    StaticSvg,
    /// Interactive SVG with zoom/pan driven by an embedded script
    InteractiveSvg,
    /// HTML page with summary cards and a per-layer list
    // NOTE(review): the current HTML generator emits no JavaScript despite
    // the name.
    InteractiveHtml,
    /// Text-based summary (minimal memory)
    TextSummary,
    /// JSON metadata only: a pretty-printed array of the sampled layers'
    /// [`LayerMetadata`]
    JsonMetadata,
}
100
/// Metadata about a model layer
///
/// Registered with [`LargeModelVisualizer::add_layer`]. The visualizer keys
/// its cache by `name` and looks layers up for rendering by `index`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerMetadata {
    /// Layer name (cache key: adding another layer with the same name
    /// replaces this one)
    pub name: String,
    /// Layer index
    // NOTE(review): the "visualize everything", uniform and representative
    // sampling paths emit positions `0..total_layers`, so indices are
    // presumably expected to be contiguous from 0 — confirm with producers.
    pub index: usize,
    /// Layer type; used for grouping and per-type counts
    pub layer_type: String,
    /// Number of parameters
    pub param_count: usize,
    /// Estimated memory (MB); summed into the visualizer's memory accounting
    pub memory_mb: f64,
    /// Estimated compute cost as a count of floating-point operations (not a
    /// rate); renderers divide by 1e9 and label the value "GFLOPS"
    pub compute_flops: u64,
    /// Input shape
    pub input_shape: Vec<usize>,
    /// Output shape
    pub output_shape: Vec<usize>,
    /// Is this layer sampled for visualization?
    // NOTE(review): `determine_sampling` returns an index list and does not
    // update this flag — confirm whether anything sets it.
    pub is_sampled: bool,
}
123
/// Current visualization state
///
/// Mutable bookkeeping stored behind `LargeModelVisualizer::state`.
#[derive(Debug, Clone, Default)]
struct VisualizationState {
    /// Total layers in model (set to the layer cache's size by `add_layer`)
    total_layers: usize,
    /// Layers currently loaded
    // `dead_code` is allowed because nothing reads this field yet; presumably
    // reserved for progressive loading — TODO confirm or remove.
    #[allow(dead_code)]
    loaded_layers: Vec<String>,
    /// Current memory usage (MB): the summed `memory_mb` of added layers,
    /// compared against `max_memory_mb` in `add_layer`
    current_memory_mb: f64,
    /// Visualization progress (0.0-1.0)
    progress: f64,
    /// Is visualization complete?
    // `dead_code` is allowed because nothing reads this field yet — TODO
    // confirm intended use.
    #[allow(dead_code)]
    is_complete: bool,
}
140
/// Visualization result returned by [`LargeModelVisualizer::visualize`]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationResult {
    /// Output file path echoed from the `visualize` argument; when `Some`,
    /// the rendered bytes were written to it
    pub output_path: Option<String>,
    /// Rendered bytes, included only when smaller than 1 MiB (1024 * 1024 bytes)
    pub inline_data: Option<Vec<u8>>,
    /// Visualization statistics
    pub stats: VisualizationStats,
    /// Sampled layer indices, as returned by `determine_sampling`
    pub sampled_layers: Vec<usize>,
    /// Total model statistics (over all layers, sampled or not)
    pub model_stats: ModelStatistics,
}
155
/// Statistics about the visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationStats {
    /// Number of layers visualized (length of the sampled index list)
    pub layers_visualized: usize,
    /// Number of layers in model
    pub total_layers: usize,
    /// Sampling ratio: visualized layers divided by total layers
    pub sampling_ratio: f64,
    /// Memory accounted for the model (MB): the summed `memory_mb` of all
    /// added layers, not memory measured while rendering
    pub memory_used_mb: f64,
    /// Wall-clock time spent in `visualize` (seconds)
    pub time_taken_secs: f64,
    /// Size of the rendered output (bytes)
    pub output_size_bytes: usize,
}
172
/// Overall model statistics, aggregated over every added layer (sampled or not)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelStatistics {
    /// Total parameters
    pub total_params: usize,
    /// Total memory footprint (MB), summed from per-layer estimates
    pub total_memory_mb: f64,
    /// Total compute cost in billions of floating-point operations
    /// (sum of `compute_flops` divided by 1e9; a count, not a rate)
    pub total_gflops: f64,
    /// Deepest layer index: the largest `index` seen (0 when there are no layers)
    pub max_depth: usize,
    /// Layer type distribution: number of layers per `layer_type`
    pub layer_types: HashMap<String, usize>,
}
187
/// Layer group for hierarchical visualization
///
/// Produced by [`LargeModelVisualizer::create_layer_groups`], one group per
/// distinct `layer_type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerGroup {
    /// Group name, formatted as `"<layer_type> (<n> layers)"`
    pub name: String,
    /// Layer indices in this group
    pub layers: Vec<usize>,
    /// Is this group collapsed? Set for groups with more than 10 layers.
    pub collapsed: bool,
    /// Summary statistics for group
    pub summary: GroupSummary,
}
200
/// Summary for a layer group
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupSummary {
    /// Total parameters in group
    pub param_count: usize,
    /// Total memory (MB)
    pub memory_mb: f64,
    /// Average compute cost per layer in floating-point operations
    /// (integer mean, truncated)
    pub avg_compute_flops: u64,
}
211
212impl LargeModelVisualizer {
213    /// Create a new large model visualizer
214    ///
215    /// # Arguments
216    /// * `config` - Visualizer configuration
217    ///
218    /// # Example
219    /// ```rust
220    /// use trustformers_debug::{LargeModelVisualizer, LargeModelVisualizerConfig};
221    ///
222    /// let config = LargeModelVisualizerConfig::default();
223    /// let visualizer = LargeModelVisualizer::new(config);
224    /// ```
225    pub fn new(config: LargeModelVisualizerConfig) -> Self {
226        info!("Initializing large model visualizer");
227        Self {
228            config,
229            layer_cache: Arc::new(RwLock::new(HashMap::new())),
230            state: Arc::new(RwLock::new(VisualizationState::default())),
231        }
232    }
233
234    /// Add layer metadata to the visualizer
235    ///
236    /// # Arguments
237    /// * `metadata` - Layer metadata
238    pub fn add_layer(&self, metadata: LayerMetadata) -> Result<()> {
239        let mut cache = self.layer_cache.write();
240        let mut state = self.state.write();
241
242        cache.insert(metadata.name.clone(), metadata.clone());
243        state.total_layers = cache.len();
244        state.current_memory_mb += metadata.memory_mb;
245
246        // Check memory limit
247        if state.current_memory_mb > self.config.max_memory_mb as f64 {
248            warn!(
249                "Memory limit exceeded: {:.1} MB > {} MB. Consider increasing max_memory_mb or enabling sampling",
250                state.current_memory_mb,
251                self.config.max_memory_mb
252            );
253        }
254
255        Ok(())
256    }
257
258    /// Analyze model and determine sampling strategy
259    ///
260    /// # Returns
261    /// Indices of layers to visualize in detail
262    pub fn determine_sampling(&self) -> Result<Vec<usize>> {
263        let cache = self.layer_cache.read();
264        let state = self.state.read();
265
266        if !self.config.enable_smart_sampling || state.total_layers <= self.config.max_full_layers {
267            // Visualize all layers
268            return Ok((0..state.total_layers).collect());
269        }
270
271        debug!(
272            "Applying {:?} sampling strategy for {} layers",
273            self.config.sampling_strategy, state.total_layers
274        );
275
276        let sampled_indices = match self.config.sampling_strategy {
277            SamplingStrategy::Uniform => self.uniform_sampling(state.total_layers),
278            SamplingStrategy::Adaptive => self.adaptive_sampling(&cache),
279            SamplingStrategy::Representative => self.representative_sampling(state.total_layers),
280            SamplingStrategy::ImportanceBased => self.importance_sampling(&cache),
281        };
282
283        Ok(sampled_indices)
284    }
285
286    /// Uniform sampling: evenly spaced layers
287    fn uniform_sampling(&self, total_layers: usize) -> Vec<usize> {
288        let max_layers = self.config.max_full_layers;
289        let step = (total_layers as f64 / max_layers as f64).ceil() as usize;
290
291        (0..total_layers).step_by(step).collect()
292    }
293
294    /// Adaptive sampling: more samples where complexity varies
295    fn adaptive_sampling(&self, cache: &HashMap<String, LayerMetadata>) -> Vec<usize> {
296        let mut layers: Vec<_> = cache.values().collect();
297        layers.sort_by_key(|l| l.index);
298
299        let mut sampled = Vec::new();
300        let max_layers = self.config.max_full_layers;
301
302        // Always include first and last layers
303        if !layers.is_empty() {
304            sampled.push(0);
305            sampled.push(layers.len() - 1);
306        }
307
308        // Calculate complexity variance between consecutive layers
309        let mut variances = Vec::new();
310        for i in 0..layers.len().saturating_sub(1) {
311            let complexity_diff =
312                (layers[i + 1].param_count as i64 - layers[i].param_count as i64).abs();
313            variances.push((i, complexity_diff));
314        }
315
316        // Sort by variance (descending)
317        variances.sort_by_key(|item| std::cmp::Reverse(item.1));
318
319        // Sample layers with highest variance
320        for (idx, _) in variances.iter().take(max_layers.saturating_sub(2)) {
321            sampled.push(*idx);
322        }
323
324        sampled.sort_unstable();
325        sampled.dedup();
326        sampled
327    }
328
329    /// Representative sampling: first, middle, last + interesting layers
330    fn representative_sampling(&self, total_layers: usize) -> Vec<usize> {
331        let mut sampled = Vec::new();
332
333        if total_layers == 0 {
334            return sampled;
335        }
336
337        // First layers
338        sampled.extend(0..3.min(total_layers));
339
340        // Middle layers
341        let mid = total_layers / 2;
342        sampled.extend((mid.saturating_sub(1))..=(mid + 1).min(total_layers - 1));
343
344        // Last layers
345        sampled.extend((total_layers.saturating_sub(3))..total_layers);
346
347        // Add evenly spaced samples in between
348        let remaining_budget = self.config.max_full_layers.saturating_sub(sampled.len());
349        let step = (total_layers as f64 / remaining_budget as f64).ceil() as usize;
350
351        for i in (0..total_layers).step_by(step) {
352            sampled.push(i);
353        }
354
355        sampled.sort_unstable();
356        sampled.dedup();
357        sampled
358    }
359
360    /// Importance-based sampling: prioritize large/complex layers
361    fn importance_sampling(&self, cache: &HashMap<String, LayerMetadata>) -> Vec<usize> {
362        let mut layers: Vec<_> = cache.values().collect();
363
364        // Calculate importance score (weighted sum of params and compute)
365        layers.sort_by(|a, b| {
366            let score_a = (a.param_count as f64) + (a.compute_flops as f64 / 1e9);
367            let score_b = (b.param_count as f64) + (b.compute_flops as f64 / 1e9);
368            score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal)
369        });
370
371        layers.iter().take(self.config.max_full_layers).map(|l| l.index).collect()
372    }
373
374    /// Create hierarchical layer groups
375    ///
376    /// Groups layers by type or sequential blocks for collapsible visualization
377    pub fn create_layer_groups(&self) -> Result<Vec<LayerGroup>> {
378        let cache = self.layer_cache.read();
379
380        if !self.config.enable_hierarchical || cache.len() < 20 {
381            // Not worth grouping small models
382            return Ok(Vec::new());
383        }
384
385        let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
386
387        // Group by layer type
388        for metadata in cache.values() {
389            groups.entry(metadata.layer_type.clone()).or_default().push(metadata.index);
390        }
391
392        // Create LayerGroup objects
393        let mut layer_groups = Vec::new();
394
395        for (layer_type, indices) in groups {
396            // Calculate summary
397            let group_layers: Vec<_> = indices
398                .iter()
399                .filter_map(|&idx| cache.values().find(|l| l.index == idx))
400                .collect();
401
402            let param_count: usize = group_layers.iter().map(|l| l.param_count).sum();
403            let memory_mb: f64 = group_layers.iter().map(|l| l.memory_mb).sum();
404            let avg_compute_flops = if !group_layers.is_empty() {
405                group_layers.iter().map(|l| l.compute_flops).sum::<u64>()
406                    / group_layers.len() as u64
407            } else {
408                0
409            };
410
411            let indices_len = indices.len();
412            layer_groups.push(LayerGroup {
413                name: format!("{} ({} layers)", layer_type, indices_len),
414                layers: indices,
415                collapsed: indices_len > 10, // Auto-collapse large groups
416                summary: GroupSummary {
417                    param_count,
418                    memory_mb,
419                    avg_compute_flops,
420                },
421            });
422        }
423
424        // Sort by first layer index
425        layer_groups.sort_by_key(|g| g.layers.first().copied().unwrap_or(0));
426
427        Ok(layer_groups)
428    }
429
430    /// Generate visualization with memory-efficient rendering
431    ///
432    /// # Arguments
433    /// * `output_path` - Optional output file path
434    ///
435    /// # Returns
436    /// Visualization result with statistics
437    pub fn visualize(&self, output_path: Option<String>) -> Result<VisualizationResult> {
438        info!("Starting large model visualization");
439
440        let start_time = std::time::Instant::now();
441
442        // Determine which layers to visualize
443        let sampled_layers = self.determine_sampling()?;
444
445        info!(
446            "Visualizing {} out of {} layers",
447            sampled_layers.len(),
448            self.state.read().total_layers
449        );
450
451        // Calculate model statistics
452        let model_stats = self.calculate_model_stats()?;
453
454        // Generate visualization based on format
455        let (output_data, output_size) = match self.config.output_format {
456            VisualizationFormat::TextSummary => self.generate_text_summary(&sampled_layers)?,
457            VisualizationFormat::JsonMetadata => self.generate_json_metadata(&sampled_layers)?,
458            VisualizationFormat::StaticSvg => self.generate_static_svg(&sampled_layers)?,
459            VisualizationFormat::InteractiveSvg => {
460                self.generate_interactive_svg(&sampled_layers)?
461            },
462            VisualizationFormat::InteractiveHtml => {
463                self.generate_interactive_html(&sampled_layers)?
464            },
465            VisualizationFormat::StaticPng => {
466                // PNG output requires the `video` or `gif` feature which gates the `image` crate.
467                // To enable: rebuild with `--features video` (or `--features gif`).
468                // Without that feature, fall back to a descriptive error so callers can
469                // switch to SVG/HTML output which works without any extra features.
470                #[cfg(feature = "video")]
471                {
472                    self.generate_png(&sampled_layers)?
473                }
474                #[cfg(not(feature = "video"))]
475                {
476                    return Err(anyhow::anyhow!(
477                        "PNG generation requires the `video` feature. \
478                         Rebuild with `--features video`, or use \
479                         VisualizationFormat::StaticSvg / InteractiveHtml instead."
480                    ));
481                }
482            },
483        };
484
485        // Save to file if path provided
486        if let Some(ref path) = output_path {
487            std::fs::write(path, &output_data)
488                .with_context(|| format!("Failed to write visualization to {}", path))?;
489            info!("Saved visualization to {}", path);
490        }
491
492        let time_taken = start_time.elapsed().as_secs_f64();
493        let state = self.state.read();
494
495        Ok(VisualizationResult {
496            output_path,
497            inline_data: if output_size < 1024 * 1024 { Some(output_data) } else { None }, // Include inline if < 1MB
498            stats: VisualizationStats {
499                layers_visualized: sampled_layers.len(),
500                total_layers: state.total_layers,
501                sampling_ratio: sampled_layers.len() as f64 / state.total_layers as f64,
502                memory_used_mb: state.current_memory_mb,
503                time_taken_secs: time_taken,
504                output_size_bytes: output_size,
505            },
506            sampled_layers,
507            model_stats,
508        })
509    }
510
511    /// Calculate overall model statistics
512    fn calculate_model_stats(&self) -> Result<ModelStatistics> {
513        let cache = self.layer_cache.read();
514
515        let total_params: usize = cache.values().map(|l| l.param_count).sum();
516        let total_memory_mb: f64 = cache.values().map(|l| l.memory_mb).sum();
517        let total_gflops: f64 = cache.values().map(|l| l.compute_flops).sum::<u64>() as f64 / 1e9;
518        let max_depth = cache.values().map(|l| l.index).max().unwrap_or(0);
519
520        let mut layer_types: HashMap<String, usize> = HashMap::new();
521        for metadata in cache.values() {
522            *layer_types.entry(metadata.layer_type.clone()).or_insert(0) += 1;
523        }
524
525        Ok(ModelStatistics {
526            total_params,
527            total_memory_mb,
528            total_gflops,
529            max_depth,
530            layer_types,
531        })
532    }
533
534    /// Generate text summary (minimal memory)
535    fn generate_text_summary(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
536        let cache = self.layer_cache.read();
537
538        let mut summary = String::from("=== Large Model Visualization Summary ===\n\n");
539
540        summary.push_str(&format!(
541            "Total Layers: {}\n",
542            self.state.read().total_layers
543        ));
544        summary.push_str(&format!("Visualized Layers: {}\n\n", sampled_layers.len()));
545
546        summary.push_str("Layer Details:\n");
547        for &idx in sampled_layers {
548            if let Some(layer) = cache.values().find(|l| l.index == idx) {
549                summary.push_str(&format!(
550                    "  [{}] {} - {} params, {:.2} MB, {:.1} GFLOPS\n",
551                    layer.index,
552                    layer.name,
553                    layer.param_count,
554                    layer.memory_mb,
555                    layer.compute_flops as f64 / 1e9
556                ));
557            }
558        }
559
560        let bytes = summary.into_bytes();
561        let size = bytes.len();
562        Ok((bytes, size))
563    }
564
565    /// Generate JSON metadata
566    fn generate_json_metadata(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
567        let cache = self.layer_cache.read();
568
569        let layers: Vec<_> = sampled_layers
570            .iter()
571            .filter_map(|&idx| cache.values().find(|l| l.index == idx).cloned())
572            .collect();
573
574        let json = serde_json::to_string_pretty(&layers)?;
575        let bytes = json.into_bytes();
576        let size = bytes.len();
577        Ok((bytes, size))
578    }
579
580    /// Generate static SVG
581    fn generate_static_svg(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
582        let cache = self.layer_cache.read();
583
584        let mut svg = String::from(
585            r#"<?xml version="1.0" encoding="UTF-8"?>
586<svg xmlns="http://www.w3.org/2000/svg" width="1200" height="800" viewBox="0 0 1200 800">
587<style>
588.layer { fill: #4a90e2; stroke: #2c5aa0; stroke-width: 2; }
589.layer-text { fill: white; font-family: Arial, sans-serif; font-size: 12px; }
590.title { font-family: Arial, sans-serif; font-size: 20px; font-weight: bold; }
591</style>
592<text x="600" y="30" class="title" text-anchor="middle">Model Architecture</text>
593"#,
594        );
595
596        let layer_height = 60;
597        let layer_width = 200;
598        let x_offset = 500;
599        let y_start = 60;
600
601        for (i, &idx) in sampled_layers.iter().enumerate() {
602            if let Some(layer) = cache.values().find(|l| l.index == idx) {
603                let y = y_start + i * (layer_height + 20);
604
605                svg.push_str(&format!(
606                    r#"<rect x="{}" y="{}" width="{}" height="{}" class="layer" />
607<text x="{}" y="{}" class="layer-text" text-anchor="middle">{}</text>
608<text x="{}" y="{}" class="layer-text" text-anchor="middle">{:.1}M params</text>
609"#,
610                    x_offset,
611                    y,
612                    layer_width,
613                    layer_height,
614                    x_offset + layer_width / 2,
615                    y + 25,
616                    layer.name,
617                    x_offset + layer_width / 2,
618                    y + 45,
619                    layer.param_count as f64 / 1e6
620                ));
621            }
622        }
623
624        svg.push_str("</svg>");
625
626        let bytes = svg.into_bytes();
627        let size = bytes.len();
628        Ok((bytes, size))
629    }
630
631    /// Generate interactive SVG with zoom/pan via embedded ECMAScript
632    fn generate_interactive_svg(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
633        let cache = self.layer_cache.read();
634
635        let layer_height = 60usize;
636        let layer_width = 200usize;
637        let x_offset = 500usize;
638        let y_start = 60usize;
639        let svg_height = y_start + sampled_layers.len() * (layer_height + 20) + 40;
640        let svg_width = 1200usize;
641
642        // Build the inner layer elements first
643        let mut layer_elems = String::new();
644        for (i, &idx) in sampled_layers.iter().enumerate() {
645            if let Some(layer) = cache.values().find(|l| l.index == idx) {
646                let y = y_start + i * (layer_height + 20);
647                layer_elems.push_str(&format!(
648                    r#"<rect x="{x}" y="{y}" width="{w}" height="{h}" class="layer" />
649<text x="{cx}" y="{ty}" class="layer-text" text-anchor="middle">{name}</text>
650<text x="{cx}" y="{py}" class="layer-text" text-anchor="middle">{params:.1}M params</text>
651"#,
652                    x = x_offset,
653                    y = y,
654                    w = layer_width,
655                    h = layer_height,
656                    cx = x_offset + layer_width / 2,
657                    ty = y + 25,
658                    py = y + 45,
659                    name = layer.name,
660                    params = layer.param_count as f64 / 1e6
661                ));
662            }
663        }
664
665        // Compose full SVG with embedded pan/zoom JavaScript.
666        // The <script> block uses an SVG foreignObject-free approach: it attaches
667        // pointer-event listeners directly to the root <svg> element and manipulates
668        // a <g id="viewport"> transform, which is valid SVG+JS in any modern browser.
669        let svg = format!(
670            r#"<?xml version="1.0" encoding="UTF-8"?>
671<svg xmlns="http://www.w3.org/2000/svg"
672     xmlns:xlink="http://www.w3.org/1999/xlink"
673     id="svg-root"
674     width="{width}" height="{height}"
675     viewBox="0 0 {width} {height}"
676     style="cursor:grab;user-select:none;">
677<style>
678.layer {{ fill: #4a90e2; stroke: #2c5aa0; stroke-width: 2; }}
679.layer-text {{ fill: white; font-family: Arial, sans-serif; font-size: 12px; }}
680.title {{ font-family: Arial, sans-serif; font-size: 20px; font-weight: bold; }}
681</style>
682<text x="{title_x}" y="30" class="title" text-anchor="middle">Model Architecture (interactive)</text>
683<g id="viewport">
684{layers}
685</g>
686<script type="text/javascript"><![CDATA[
687(function() {{
688  var svg   = document.getElementById('svg-root');
689  var vp    = document.getElementById('viewport');
690  var tx = 0, ty = 0, scale = 1.0;
691  var dragging = false;
692  var startX = 0, startY = 0;
693
694  function applyTransform() {{
695    vp.setAttribute('transform',
696      'translate(' + tx + ',' + ty + ') scale(' + scale + ')');
697  }}
698
699  // Pan: mousedown / mousemove / mouseup
700  svg.addEventListener('mousedown', function(e) {{
701    dragging = true;
702    startX = e.clientX - tx;
703    startY = e.clientY - ty;
704    svg.style.cursor = 'grabbing';
705    e.preventDefault();
706  }});
707  window.addEventListener('mousemove', function(e) {{
708    if (!dragging) return;
709    tx = e.clientX - startX;
710    ty = e.clientY - startY;
711    applyTransform();
712  }});
713  window.addEventListener('mouseup', function() {{
714    dragging = false;
715    svg.style.cursor = 'grab';
716  }});
717
718  // Touch pan
719  var lastTouch = null;
720  svg.addEventListener('touchstart', function(e) {{
721    if (e.touches.length === 1) {{
722      lastTouch = e.touches[0];
723    }}
724    e.preventDefault();
725  }}, {{ passive: false }});
726  svg.addEventListener('touchmove', function(e) {{
727    if (e.touches.length === 1 && lastTouch) {{
728      var t = e.touches[0];
729      tx += t.clientX - lastTouch.clientX;
730      ty += t.clientY - lastTouch.clientY;
731      lastTouch = t;
732      applyTransform();
733    }}
734    e.preventDefault();
735  }}, {{ passive: false }});
736  svg.addEventListener('touchend', function() {{ lastTouch = null; }});
737
738  // Zoom: mousewheel
739  svg.addEventListener('wheel', function(e) {{
740    e.preventDefault();
741    var delta = e.deltaY > 0 ? 0.9 : 1.1;
742    // Zoom towards cursor position
743    var rect  = svg.getBoundingClientRect();
744    var mx = e.clientX - rect.left;
745    var my = e.clientY - rect.top;
746    tx = mx - (mx - tx) * delta;
747    ty = my - (my - ty) * delta;
748    scale = Math.max(0.1, Math.min(10.0, scale * delta));
749    applyTransform();
750  }}, {{ passive: false }});
751
752  // Double-click to reset
753  svg.addEventListener('dblclick', function() {{
754    tx = 0; ty = 0; scale = 1.0;
755    applyTransform();
756  }});
757}})();
758]]></script>
759</svg>"#,
760            width = svg_width,
761            height = svg_height,
762            title_x = svg_width / 2,
763            layers = layer_elems,
764        );
765
766        let bytes = svg.into_bytes();
767        let size = bytes.len();
768        Ok((bytes, size))
769    }
770
771    /// Generate interactive HTML with JavaScript
772    fn generate_interactive_html(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
773        let cache = self.layer_cache.read();
774        let model_stats = self.calculate_model_stats()?;
775
776        let mut html = String::from(
777            r#"<!DOCTYPE html>
778<html>
779<head>
780<meta charset="UTF-8">
781<title>Large Model Visualization</title>
782<style>
783body { font-family: Arial, sans-serif; margin: 20px; background: #f5f5f5; }
784.container { max-width: 1200px; margin: 0 auto; }
785.header { background: #4a90e2; color: white; padding: 20px; border-radius: 8px; }
786.stats { display: grid; grid-template-columns: repeat(4, 1fr); gap: 15px; margin: 20px 0; }
787.stat-card { background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
788.layer-list { background: white; padding: 20px; border-radius: 8px; }
789.layer { padding: 10px; margin: 5px 0; background: #f9f9f9; border-left: 4px solid #4a90e2; }
790</style>
791</head>
792<body>
793<div class="container">
794<div class="header">
795<h1>Large Model Visualization</h1>
796<p>Interactive view of model architecture</p>
797</div>
798<div class="stats">
799"#,
800        );
801
802        // Add stats cards
803        html.push_str(&format!(
804            r#"<div class="stat-card">
805<h3>{:.1}M</h3>
806<p>Total Parameters</p>
807</div>
808<div class="stat-card">
809<h3>{:.1} GB</h3>
810<p>Total Memory</p>
811</div>
812<div class="stat-card">
813<h3>{}</h3>
814<p>Total Layers</p>
815</div>
816<div class="stat-card">
817<h3>{}/{}</h3>
818<p>Visualized/Total</p>
819</div>
820"#,
821            model_stats.total_params as f64 / 1e6,
822            model_stats.total_memory_mb / 1024.0,
823            model_stats.max_depth + 1,
824            sampled_layers.len(),
825            self.state.read().total_layers
826        ));
827
828        html.push_str("</div><div class=\"layer-list\"><h2>Layer Details</h2>");
829
830        // Add layer details
831        for &idx in sampled_layers {
832            if let Some(layer) = cache.values().find(|l| l.index == idx) {
833                html.push_str(&format!(
834                    r#"<div class="layer">
835<strong>[{}] {}</strong><br>
836Type: {} | Parameters: {:.1}M | Memory: {:.2} MB | Compute: {:.1} GFLOPS
837</div>
838"#,
839                    layer.index,
840                    layer.name,
841                    layer.layer_type,
842                    layer.param_count as f64 / 1e6,
843                    layer.memory_mb,
844                    layer.compute_flops as f64 / 1e9
845                ));
846            }
847        }
848
849        html.push_str("</div></div></body></html>");
850
851        let bytes = html.into_bytes();
852        let size = bytes.len();
853        Ok((bytes, size))
854    }
855
856    /// Generate a static PNG heatmap visualization of the sampled layers.
857    ///
858    /// Each layer is rendered as a horizontal bar whose width is proportional to
859    /// `param_count` and whose colour encodes `memory_mb` (blue → red gradient).
860    /// The resulting image is PNG-encoded and returned as a raw byte vector.
861    ///
862    /// Requires the `video` feature (which enables the `image` crate).
863    #[cfg(feature = "video")]
864    fn generate_png(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
865        use image::{ImageBuffer, Rgb};
866        use std::io::Cursor;
867
868        let cache = self.layer_cache.read();
869
870        // Gather layers in index order.
871        let mut layers: Vec<&LayerMetadata> = sampled_layers
872            .iter()
873            .filter_map(|&idx| cache.values().find(|l| l.index == idx))
874            .collect();
875        layers.sort_by_key(|l| l.index);
876
877        // Image layout constants.
878        const IMG_WIDTH: u32 = 1200;
879        const BAR_HEIGHT: u32 = 30;
880        const BAR_PADDING: u32 = 6;
881        const LEFT_MARGIN: u32 = 20;
882        const RIGHT_MARGIN: u32 = 20;
883
884        let row_height = BAR_HEIGHT + BAR_PADDING;
885        let img_height = if layers.is_empty() {
886            100
887        } else {
888            layers.len() as u32 * row_height + 2 * BAR_PADDING + 40 // +40 for title row
889        };
890
891        let max_params = layers.iter().map(|l| l.param_count).max().unwrap_or(1).max(1);
892
893        let max_memory = layers.iter().map(|l| l.memory_mb).fold(0.0_f64, f64::max).max(1.0);
894
895        let available_width = IMG_WIDTH - LEFT_MARGIN - RIGHT_MARGIN;
896
897        let mut img = ImageBuffer::<Rgb<u8>, Vec<u8>>::new(IMG_WIDTH, img_height);
898
899        // Background: near-white.
900        for pixel in img.pixels_mut() {
901            *pixel = Rgb([245u8, 245u8, 250u8]);
902        }
903
904        // Title bar.
905        for x in 0..IMG_WIDTH {
906            for y in 0..36 {
907                img.put_pixel(x, y, Rgb([74u8, 144u8, 226u8]));
908            }
909        }
910
911        // Draw each layer as a horizontal heatmap bar.
912        for (i, layer) in layers.iter().enumerate() {
913            let bar_top = 40 + i as u32 * row_height;
914
915            // Bar width proportional to param_count.
916            let bar_w = ((layer.param_count as f64 / max_params as f64) * available_width as f64)
917                .round() as u32;
918            let bar_w = bar_w.max(4); // always visible
919
920            // Colour: blue (low memory) → red (high memory) gradient.
921            let t = (layer.memory_mb / max_memory).clamp(0.0, 1.0) as f32;
922            let r = (t * 220.0) as u8;
923            let g = ((1.0 - t) * 100.0 + 40.0) as u8;
924            let b = ((1.0 - t) * 220.0) as u8;
925            let bar_colour = Rgb([r, g, b]);
926
927            for x in LEFT_MARGIN..(LEFT_MARGIN + bar_w).min(IMG_WIDTH - RIGHT_MARGIN) {
928                for y in bar_top..(bar_top + BAR_HEIGHT).min(img_height) {
929                    img.put_pixel(x, y, bar_colour);
930                }
931            }
932        }
933
934        // Encode as PNG into an in-memory buffer.
935        let mut png_bytes: Vec<u8> = Vec::new();
936        img.write_to(&mut Cursor::new(&mut png_bytes), image::ImageFormat::Png)
937            .with_context(|| "Failed to PNG-encode large model visualization")?;
938
939        let size = png_bytes.len();
940        Ok((png_bytes, size))
941    }
942
943    /// Get current visualization progress (0.0-1.0)
944    pub fn get_progress(&self) -> f64 {
945        self.state.read().progress
946    }
947
948    /// Get memory usage statistics
949    pub fn get_memory_stats(&self) -> MemoryStats {
950        let state = self.state.read();
951        MemoryStats {
952            current_mb: state.current_memory_mb,
953            max_mb: self.config.max_memory_mb as f64,
954            utilization_pct: (state.current_memory_mb / self.config.max_memory_mb as f64 * 100.0)
955                .min(100.0),
956        }
957    }
958}
959
960/// Memory usage statistics
961#[derive(Debug, Clone, Serialize, Deserialize)]
962pub struct MemoryStats {
963    /// Current memory usage (MB)
964    pub current_mb: f64,
965    /// Maximum allowed memory (MB)
966    pub max_mb: f64,
967    /// Utilization percentage
968    pub utilization_pct: f64,
969}
970
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Linear` layer fixture named `layer_{index}`.
    ///
    /// The fixture has 512-wide input and output shapes and is not yet marked
    /// as sampled.
    fn linear_layer(
        index: usize,
        param_count: usize,
        memory_mb: f64,
        compute_flops: u64,
    ) -> LayerMetadata {
        LayerMetadata {
            name: format!("layer_{}", index),
            index,
            layer_type: "Linear".to_string(),
            param_count,
            memory_mb,
            compute_flops,
            input_shape: vec![512],
            output_shape: vec![512],
            is_sampled: false,
        }
    }

    #[test]
    fn test_visualizer_creation() {
        let _visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig::default());
    }

    #[test]
    fn test_add_layers() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig::default());

        for i in 0..10 {
            visualizer.add_layer(linear_layer(i, 1024 * 1024, 4.0, 1_000_000_000))?;
        }

        // Ten layers at 4 MB each.
        assert_eq!(visualizer.get_memory_stats().current_mb, 40.0);

        Ok(())
    }

    #[test]
    fn test_uniform_sampling() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig {
            max_full_layers: 5,
            sampling_strategy: SamplingStrategy::Uniform,
            ..Default::default()
        });

        // Twenty identical layers, of which only five should be sampled.
        for i in 0..20 {
            visualizer.add_layer(linear_layer(i, 1024 * 1024, 4.0, 1_000_000_000))?;
        }

        assert_eq!(visualizer.determine_sampling()?.len(), 5);

        Ok(())
    }

    #[cfg(feature = "video")]
    #[test]
    fn test_png_visualization() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig {
            output_format: VisualizationFormat::StaticPng,
            ..Default::default()
        });

        for i in 0..5_usize {
            let scale = i + 1;
            visualizer.add_layer(linear_layer(
                i,
                1024 * scale,
                2.0 * scale as f64,
                500_000_000 * scale as u64,
            ))?;
        }

        let result = visualizer.visualize(None)?;

        // Basic sanity checks
        assert_eq!(result.stats.layers_visualized, 5);
        assert!(
            result.stats.output_size_bytes > 0,
            "PNG output must be non-empty"
        );

        // Verify PNG magic bytes: 0x89 P N G
        let png = result
            .inline_data
            .expect("inline data should be present for small PNG");
        assert!(
            png.starts_with(&[0x89, 0x50, 0x4E, 0x47]),
            "Output must start with PNG magic bytes"
        );

        Ok(())
    }

    #[test]
    fn test_text_visualization() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig {
            output_format: VisualizationFormat::TextSummary,
            ..Default::default()
        });

        // Five layers of increasing size.
        for i in 0..5 {
            let scale = i + 1;
            visualizer.add_layer(linear_layer(
                i,
                1024 * 1024 * scale,
                4.0 * scale as f64,
                1_000_000_000 * scale as u64,
            ))?;
        }

        let result = visualizer.visualize(None)?;

        assert_eq!(result.stats.layers_visualized, 5);
        assert!(result.stats.output_size_bytes > 0);

        Ok(())
    }
}