Skip to main content

trustformers_debug/
large_model_viz.rs

//! Large Model Visualization with Memory Efficiency
//!
//! This module provides optimized visualization for large transformer models,
//! using smart layer sampling, hierarchical rendering, and memory-efficient
//! techniques to handle models with billions of parameters.
6
7use anyhow::{Context, Result};
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info, warn};
13
14/// Large model visualizer with memory-efficient rendering
15///
16/// Features:
17/// - Smart layer sampling (visualize representative layers)
18/// - Hierarchical graph rendering (collapse/expand sections)
19/// - Streaming visualization (process in chunks)
20/// - Memory-bounded caching
21/// - Progressive loading
22#[derive(Debug)]
23pub struct LargeModelVisualizer {
24    /// Configuration
25    config: LargeModelVisualizerConfig,
26    /// Cached layer metadata
27    layer_cache: Arc<RwLock<HashMap<String, LayerMetadata>>>,
28    /// Visualization state
29    state: Arc<RwLock<VisualizationState>>,
30}
31
32/// Configuration for large model visualization
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct LargeModelVisualizerConfig {
35    /// Enable smart layer sampling
36    pub enable_smart_sampling: bool,
37    /// Maximum layers to visualize fully (rest are sampled)
38    pub max_full_layers: usize,
39    /// Sampling strategy
40    pub sampling_strategy: SamplingStrategy,
41    /// Enable hierarchical rendering
42    pub enable_hierarchical: bool,
43    /// Enable streaming mode for very large models
44    pub enable_streaming: bool,
45    /// Maximum memory for visualization (MB)
46    pub max_memory_mb: usize,
47    /// Chunk size for streaming (number of layers)
48    pub stream_chunk_size: usize,
49    /// Enable progressive detail loading
50    pub enable_progressive_loading: bool,
51    /// Visualization format
52    pub output_format: VisualizationFormat,
53}
54
55impl Default for LargeModelVisualizerConfig {
56    fn default() -> Self {
57        Self {
58            enable_smart_sampling: true,
59            max_full_layers: 50,
60            sampling_strategy: SamplingStrategy::Adaptive,
61            enable_hierarchical: true,
62            enable_streaming: true,
63            max_memory_mb: 1024, // 1 GB
64            stream_chunk_size: 10,
65            enable_progressive_loading: true,
66            output_format: VisualizationFormat::InteractiveSvg,
67        }
68    }
69}
70
71/// Layer sampling strategy for large models
72#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
73pub enum SamplingStrategy {
74    /// Uniform sampling (evenly spaced layers)
75    Uniform,
76    /// Adaptive sampling (more samples where complexity varies)
77    Adaptive,
78    /// Representative sampling (first, middle, last + interesting layers)
79    Representative,
80    /// Importance-based (based on parameter count, compute cost)
81    ImportanceBased,
82}
83
84/// Visualization output format
85#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
86pub enum VisualizationFormat {
87    /// Static PNG image (memory efficient)
88    StaticPng,
89    /// Static SVG (scalable but larger)
90    StaticSvg,
91    /// Interactive SVG with zoom/pan
92    InteractiveSvg,
93    /// Interactive HTML with JavaScript
94    InteractiveHtml,
95    /// Text-based summary (minimal memory)
96    TextSummary,
97    /// JSON metadata only
98    JsonMetadata,
99}
100
101/// Metadata about a model layer
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct LayerMetadata {
104    /// Layer name
105    pub name: String,
106    /// Layer index
107    pub index: usize,
108    /// Layer type
109    pub layer_type: String,
110    /// Number of parameters
111    pub param_count: usize,
112    /// Estimated memory (MB)
113    pub memory_mb: f64,
114    /// Estimated compute cost (FLOPS)
115    pub compute_flops: u64,
116    /// Input shape
117    pub input_shape: Vec<usize>,
118    /// Output shape
119    pub output_shape: Vec<usize>,
120    /// Is this layer sampled for visualization?
121    pub is_sampled: bool,
122}
123
/// Mutable bookkeeping shared behind the visualizer's state lock.
#[derive(Clone, Debug, Default)]
struct VisualizationState {
    /// Number of distinct layers registered so far.
    total_layers: usize,
    /// Names of layers loaded for rendering.
    // Reserved for progressive loading; nothing in this module reads it yet.
    #[allow(dead_code)]
    loaded_layers: Vec<String>,
    /// Sum of the registered layers' `memory_mb`.
    current_memory_mb: f64,
    /// Fraction of the visualization completed, in `0.0..=1.0`.
    progress: f64,
    /// Whether a visualization run has finished.
    // Kept for completion reporting; nothing in this module reads it yet.
    #[allow(dead_code)]
    is_complete: bool,
}
140
141/// Visualization result
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct VisualizationResult {
144    /// Output file path (if saved to file)
145    pub output_path: Option<String>,
146    /// Inline data (if small enough)
147    pub inline_data: Option<Vec<u8>>,
148    /// Visualization statistics
149    pub stats: VisualizationStats,
150    /// Sampled layer indices
151    pub sampled_layers: Vec<usize>,
152    /// Total model statistics
153    pub model_stats: ModelStatistics,
154}
155
156/// Statistics about the visualization
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct VisualizationStats {
159    /// Number of layers visualized
160    pub layers_visualized: usize,
161    /// Number of layers in model
162    pub total_layers: usize,
163    /// Sampling ratio
164    pub sampling_ratio: f64,
165    /// Memory used for visualization (MB)
166    pub memory_used_mb: f64,
167    /// Time taken (seconds)
168    pub time_taken_secs: f64,
169    /// Output size (bytes)
170    pub output_size_bytes: usize,
171}
172
173/// Overall model statistics
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct ModelStatistics {
176    /// Total parameters
177    pub total_params: usize,
178    /// Total memory footprint (MB)
179    pub total_memory_mb: f64,
180    /// Total compute cost (GFLOPS)
181    pub total_gflops: f64,
182    /// Deepest layer index
183    pub max_depth: usize,
184    /// Layer type distribution
185    pub layer_types: HashMap<String, usize>,
186}
187
188/// Layer group for hierarchical visualization
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct LayerGroup {
191    /// Group name
192    pub name: String,
193    /// Layer indices in this group
194    pub layers: Vec<usize>,
195    /// Is this group collapsed?
196    pub collapsed: bool,
197    /// Summary statistics for group
198    pub summary: GroupSummary,
199}
200
201/// Summary for a layer group
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct GroupSummary {
204    /// Total parameters in group
205    pub param_count: usize,
206    /// Total memory (MB)
207    pub memory_mb: f64,
208    /// Average compute cost per layer
209    pub avg_compute_flops: u64,
210}
211
212impl LargeModelVisualizer {
213    /// Create a new large model visualizer
214    ///
215    /// # Arguments
216    /// * `config` - Visualizer configuration
217    ///
218    /// # Example
219    /// ```rust
220    /// use trustformers_debug::{LargeModelVisualizer, LargeModelVisualizerConfig};
221    ///
222    /// let config = LargeModelVisualizerConfig::default();
223    /// let visualizer = LargeModelVisualizer::new(config);
224    /// ```
225    pub fn new(config: LargeModelVisualizerConfig) -> Self {
226        info!("Initializing large model visualizer");
227        Self {
228            config,
229            layer_cache: Arc::new(RwLock::new(HashMap::new())),
230            state: Arc::new(RwLock::new(VisualizationState::default())),
231        }
232    }
233
234    /// Add layer metadata to the visualizer
235    ///
236    /// # Arguments
237    /// * `metadata` - Layer metadata
238    pub fn add_layer(&self, metadata: LayerMetadata) -> Result<()> {
239        let mut cache = self.layer_cache.write();
240        let mut state = self.state.write();
241
242        cache.insert(metadata.name.clone(), metadata.clone());
243        state.total_layers = cache.len();
244        state.current_memory_mb += metadata.memory_mb;
245
246        // Check memory limit
247        if state.current_memory_mb > self.config.max_memory_mb as f64 {
248            warn!(
249                "Memory limit exceeded: {:.1} MB > {} MB. Consider increasing max_memory_mb or enabling sampling",
250                state.current_memory_mb,
251                self.config.max_memory_mb
252            );
253        }
254
255        Ok(())
256    }
257
258    /// Analyze model and determine sampling strategy
259    ///
260    /// # Returns
261    /// Indices of layers to visualize in detail
262    pub fn determine_sampling(&self) -> Result<Vec<usize>> {
263        let cache = self.layer_cache.read();
264        let state = self.state.read();
265
266        if !self.config.enable_smart_sampling || state.total_layers <= self.config.max_full_layers {
267            // Visualize all layers
268            return Ok((0..state.total_layers).collect());
269        }
270
271        debug!(
272            "Applying {:?} sampling strategy for {} layers",
273            self.config.sampling_strategy, state.total_layers
274        );
275
276        let sampled_indices = match self.config.sampling_strategy {
277            SamplingStrategy::Uniform => self.uniform_sampling(state.total_layers),
278            SamplingStrategy::Adaptive => self.adaptive_sampling(&cache),
279            SamplingStrategy::Representative => self.representative_sampling(state.total_layers),
280            SamplingStrategy::ImportanceBased => self.importance_sampling(&cache),
281        };
282
283        Ok(sampled_indices)
284    }
285
286    /// Uniform sampling: evenly spaced layers
287    fn uniform_sampling(&self, total_layers: usize) -> Vec<usize> {
288        let max_layers = self.config.max_full_layers;
289        let step = (total_layers as f64 / max_layers as f64).ceil() as usize;
290
291        (0..total_layers).step_by(step).collect()
292    }
293
294    /// Adaptive sampling: more samples where complexity varies
295    fn adaptive_sampling(&self, cache: &HashMap<String, LayerMetadata>) -> Vec<usize> {
296        let mut layers: Vec<_> = cache.values().collect();
297        layers.sort_by_key(|l| l.index);
298
299        let mut sampled = Vec::new();
300        let max_layers = self.config.max_full_layers;
301
302        // Always include first and last layers
303        if !layers.is_empty() {
304            sampled.push(0);
305            sampled.push(layers.len() - 1);
306        }
307
308        // Calculate complexity variance between consecutive layers
309        let mut variances = Vec::new();
310        for i in 0..layers.len().saturating_sub(1) {
311            let complexity_diff =
312                (layers[i + 1].param_count as i64 - layers[i].param_count as i64).abs();
313            variances.push((i, complexity_diff));
314        }
315
316        // Sort by variance (descending)
317        variances.sort_by_key(|item| std::cmp::Reverse(item.1));
318
319        // Sample layers with highest variance
320        for (idx, _) in variances.iter().take(max_layers.saturating_sub(2)) {
321            sampled.push(*idx);
322        }
323
324        sampled.sort_unstable();
325        sampled.dedup();
326        sampled
327    }
328
329    /// Representative sampling: first, middle, last + interesting layers
330    fn representative_sampling(&self, total_layers: usize) -> Vec<usize> {
331        let mut sampled = Vec::new();
332
333        if total_layers == 0 {
334            return sampled;
335        }
336
337        // First layers
338        sampled.extend(0..3.min(total_layers));
339
340        // Middle layers
341        let mid = total_layers / 2;
342        sampled.extend((mid.saturating_sub(1))..=(mid + 1).min(total_layers - 1));
343
344        // Last layers
345        sampled.extend((total_layers.saturating_sub(3))..total_layers);
346
347        // Add evenly spaced samples in between
348        let remaining_budget = self.config.max_full_layers.saturating_sub(sampled.len());
349        let step = (total_layers as f64 / remaining_budget as f64).ceil() as usize;
350
351        for i in (0..total_layers).step_by(step) {
352            sampled.push(i);
353        }
354
355        sampled.sort_unstable();
356        sampled.dedup();
357        sampled
358    }
359
360    /// Importance-based sampling: prioritize large/complex layers
361    fn importance_sampling(&self, cache: &HashMap<String, LayerMetadata>) -> Vec<usize> {
362        let mut layers: Vec<_> = cache.values().collect();
363
364        // Calculate importance score (weighted sum of params and compute)
365        layers.sort_by(|a, b| {
366            let score_a = (a.param_count as f64) + (a.compute_flops as f64 / 1e9);
367            let score_b = (b.param_count as f64) + (b.compute_flops as f64 / 1e9);
368            score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal)
369        });
370
371        layers.iter().take(self.config.max_full_layers).map(|l| l.index).collect()
372    }
373
374    /// Create hierarchical layer groups
375    ///
376    /// Groups layers by type or sequential blocks for collapsible visualization
377    pub fn create_layer_groups(&self) -> Result<Vec<LayerGroup>> {
378        let cache = self.layer_cache.read();
379
380        if !self.config.enable_hierarchical || cache.len() < 20 {
381            // Not worth grouping small models
382            return Ok(Vec::new());
383        }
384
385        let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
386
387        // Group by layer type
388        for metadata in cache.values() {
389            groups.entry(metadata.layer_type.clone()).or_default().push(metadata.index);
390        }
391
392        // Create LayerGroup objects
393        let mut layer_groups = Vec::new();
394
395        for (layer_type, indices) in groups {
396            // Calculate summary
397            let group_layers: Vec<_> = indices
398                .iter()
399                .filter_map(|&idx| cache.values().find(|l| l.index == idx))
400                .collect();
401
402            let param_count: usize = group_layers.iter().map(|l| l.param_count).sum();
403            let memory_mb: f64 = group_layers.iter().map(|l| l.memory_mb).sum();
404            let avg_compute_flops = if !group_layers.is_empty() {
405                group_layers.iter().map(|l| l.compute_flops).sum::<u64>()
406                    / group_layers.len() as u64
407            } else {
408                0
409            };
410
411            let indices_len = indices.len();
412            layer_groups.push(LayerGroup {
413                name: format!("{} ({} layers)", layer_type, indices_len),
414                layers: indices,
415                collapsed: indices_len > 10, // Auto-collapse large groups
416                summary: GroupSummary {
417                    param_count,
418                    memory_mb,
419                    avg_compute_flops,
420                },
421            });
422        }
423
424        // Sort by first layer index
425        layer_groups.sort_by_key(|g| g.layers.first().copied().unwrap_or(0));
426
427        Ok(layer_groups)
428    }
429
430    /// Generate visualization with memory-efficient rendering
431    ///
432    /// # Arguments
433    /// * `output_path` - Optional output file path
434    ///
435    /// # Returns
436    /// Visualization result with statistics
437    pub fn visualize(&self, output_path: Option<String>) -> Result<VisualizationResult> {
438        info!("Starting large model visualization");
439
440        let start_time = std::time::Instant::now();
441
442        // Determine which layers to visualize
443        let sampled_layers = self.determine_sampling()?;
444
445        info!(
446            "Visualizing {} out of {} layers",
447            sampled_layers.len(),
448            self.state.read().total_layers
449        );
450
451        // Calculate model statistics
452        let model_stats = self.calculate_model_stats()?;
453
454        // Generate visualization based on format
455        let (output_data, output_size) = match self.config.output_format {
456            VisualizationFormat::TextSummary => self.generate_text_summary(&sampled_layers)?,
457            VisualizationFormat::JsonMetadata => self.generate_json_metadata(&sampled_layers)?,
458            VisualizationFormat::StaticSvg => self.generate_static_svg(&sampled_layers)?,
459            VisualizationFormat::InteractiveSvg => {
460                self.generate_interactive_svg(&sampled_layers)?
461            },
462            VisualizationFormat::InteractiveHtml => {
463                self.generate_interactive_html(&sampled_layers)?
464            },
465            VisualizationFormat::StaticPng => {
466                anyhow::bail!("PNG generation not yet implemented - use SVG or HTML instead")
467            },
468        };
469
470        // Save to file if path provided
471        if let Some(ref path) = output_path {
472            std::fs::write(path, &output_data)
473                .with_context(|| format!("Failed to write visualization to {}", path))?;
474            info!("Saved visualization to {}", path);
475        }
476
477        let time_taken = start_time.elapsed().as_secs_f64();
478        let state = self.state.read();
479
480        Ok(VisualizationResult {
481            output_path,
482            inline_data: if output_size < 1024 * 1024 { Some(output_data) } else { None }, // Include inline if < 1MB
483            stats: VisualizationStats {
484                layers_visualized: sampled_layers.len(),
485                total_layers: state.total_layers,
486                sampling_ratio: sampled_layers.len() as f64 / state.total_layers as f64,
487                memory_used_mb: state.current_memory_mb,
488                time_taken_secs: time_taken,
489                output_size_bytes: output_size,
490            },
491            sampled_layers,
492            model_stats,
493        })
494    }
495
496    /// Calculate overall model statistics
497    fn calculate_model_stats(&self) -> Result<ModelStatistics> {
498        let cache = self.layer_cache.read();
499
500        let total_params: usize = cache.values().map(|l| l.param_count).sum();
501        let total_memory_mb: f64 = cache.values().map(|l| l.memory_mb).sum();
502        let total_gflops: f64 = cache.values().map(|l| l.compute_flops).sum::<u64>() as f64 / 1e9;
503        let max_depth = cache.values().map(|l| l.index).max().unwrap_or(0);
504
505        let mut layer_types: HashMap<String, usize> = HashMap::new();
506        for metadata in cache.values() {
507            *layer_types.entry(metadata.layer_type.clone()).or_insert(0) += 1;
508        }
509
510        Ok(ModelStatistics {
511            total_params,
512            total_memory_mb,
513            total_gflops,
514            max_depth,
515            layer_types,
516        })
517    }
518
519    /// Generate text summary (minimal memory)
520    fn generate_text_summary(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
521        let cache = self.layer_cache.read();
522
523        let mut summary = String::from("=== Large Model Visualization Summary ===\n\n");
524
525        summary.push_str(&format!(
526            "Total Layers: {}\n",
527            self.state.read().total_layers
528        ));
529        summary.push_str(&format!("Visualized Layers: {}\n\n", sampled_layers.len()));
530
531        summary.push_str("Layer Details:\n");
532        for &idx in sampled_layers {
533            if let Some(layer) = cache.values().find(|l| l.index == idx) {
534                summary.push_str(&format!(
535                    "  [{}] {} - {} params, {:.2} MB, {:.1} GFLOPS\n",
536                    layer.index,
537                    layer.name,
538                    layer.param_count,
539                    layer.memory_mb,
540                    layer.compute_flops as f64 / 1e9
541                ));
542            }
543        }
544
545        let bytes = summary.into_bytes();
546        let size = bytes.len();
547        Ok((bytes, size))
548    }
549
550    /// Generate JSON metadata
551    fn generate_json_metadata(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
552        let cache = self.layer_cache.read();
553
554        let layers: Vec<_> = sampled_layers
555            .iter()
556            .filter_map(|&idx| cache.values().find(|l| l.index == idx).cloned())
557            .collect();
558
559        let json = serde_json::to_string_pretty(&layers)?;
560        let bytes = json.into_bytes();
561        let size = bytes.len();
562        Ok((bytes, size))
563    }
564
565    /// Generate static SVG
566    fn generate_static_svg(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
567        let cache = self.layer_cache.read();
568
569        let mut svg = String::from(
570            r#"<?xml version="1.0" encoding="UTF-8"?>
571<svg xmlns="http://www.w3.org/2000/svg" width="1200" height="800" viewBox="0 0 1200 800">
572<style>
573.layer { fill: #4a90e2; stroke: #2c5aa0; stroke-width: 2; }
574.layer-text { fill: white; font-family: Arial, sans-serif; font-size: 12px; }
575.title { font-family: Arial, sans-serif; font-size: 20px; font-weight: bold; }
576</style>
577<text x="600" y="30" class="title" text-anchor="middle">Model Architecture</text>
578"#,
579        );
580
581        let layer_height = 60;
582        let layer_width = 200;
583        let x_offset = 500;
584        let y_start = 60;
585
586        for (i, &idx) in sampled_layers.iter().enumerate() {
587            if let Some(layer) = cache.values().find(|l| l.index == idx) {
588                let y = y_start + i * (layer_height + 20);
589
590                svg.push_str(&format!(
591                    r#"<rect x="{}" y="{}" width="{}" height="{}" class="layer" />
592<text x="{}" y="{}" class="layer-text" text-anchor="middle">{}</text>
593<text x="{}" y="{}" class="layer-text" text-anchor="middle">{:.1}M params</text>
594"#,
595                    x_offset,
596                    y,
597                    layer_width,
598                    layer_height,
599                    x_offset + layer_width / 2,
600                    y + 25,
601                    layer.name,
602                    x_offset + layer_width / 2,
603                    y + 45,
604                    layer.param_count as f64 / 1e6
605                ));
606            }
607        }
608
609        svg.push_str("</svg>");
610
611        let bytes = svg.into_bytes();
612        let size = bytes.len();
613        Ok((bytes, size))
614    }
615
616    /// Generate interactive SVG with zoom/pan
617    fn generate_interactive_svg(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
618        // For now, delegate to static SVG
619        // TODO: Add pan/zoom JavaScript
620        self.generate_static_svg(sampled_layers)
621    }
622
623    /// Generate interactive HTML with JavaScript
624    fn generate_interactive_html(&self, sampled_layers: &[usize]) -> Result<(Vec<u8>, usize)> {
625        let cache = self.layer_cache.read();
626        let model_stats = self.calculate_model_stats()?;
627
628        let mut html = String::from(
629            r#"<!DOCTYPE html>
630<html>
631<head>
632<meta charset="UTF-8">
633<title>Large Model Visualization</title>
634<style>
635body { font-family: Arial, sans-serif; margin: 20px; background: #f5f5f5; }
636.container { max-width: 1200px; margin: 0 auto; }
637.header { background: #4a90e2; color: white; padding: 20px; border-radius: 8px; }
638.stats { display: grid; grid-template-columns: repeat(4, 1fr); gap: 15px; margin: 20px 0; }
639.stat-card { background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
640.layer-list { background: white; padding: 20px; border-radius: 8px; }
641.layer { padding: 10px; margin: 5px 0; background: #f9f9f9; border-left: 4px solid #4a90e2; }
642</style>
643</head>
644<body>
645<div class="container">
646<div class="header">
647<h1>Large Model Visualization</h1>
648<p>Interactive view of model architecture</p>
649</div>
650<div class="stats">
651"#,
652        );
653
654        // Add stats cards
655        html.push_str(&format!(
656            r#"<div class="stat-card">
657<h3>{:.1}M</h3>
658<p>Total Parameters</p>
659</div>
660<div class="stat-card">
661<h3>{:.1} GB</h3>
662<p>Total Memory</p>
663</div>
664<div class="stat-card">
665<h3>{}</h3>
666<p>Total Layers</p>
667</div>
668<div class="stat-card">
669<h3>{}/{}</h3>
670<p>Visualized/Total</p>
671</div>
672"#,
673            model_stats.total_params as f64 / 1e6,
674            model_stats.total_memory_mb / 1024.0,
675            model_stats.max_depth + 1,
676            sampled_layers.len(),
677            self.state.read().total_layers
678        ));
679
680        html.push_str("</div><div class=\"layer-list\"><h2>Layer Details</h2>");
681
682        // Add layer details
683        for &idx in sampled_layers {
684            if let Some(layer) = cache.values().find(|l| l.index == idx) {
685                html.push_str(&format!(
686                    r#"<div class="layer">
687<strong>[{}] {}</strong><br>
688Type: {} | Parameters: {:.1}M | Memory: {:.2} MB | Compute: {:.1} GFLOPS
689</div>
690"#,
691                    layer.index,
692                    layer.name,
693                    layer.layer_type,
694                    layer.param_count as f64 / 1e6,
695                    layer.memory_mb,
696                    layer.compute_flops as f64 / 1e9
697                ));
698            }
699        }
700
701        html.push_str("</div></div></body></html>");
702
703        let bytes = html.into_bytes();
704        let size = bytes.len();
705        Ok((bytes, size))
706    }
707
708    /// Get current visualization progress (0.0-1.0)
709    pub fn get_progress(&self) -> f64 {
710        self.state.read().progress
711    }
712
713    /// Get memory usage statistics
714    pub fn get_memory_stats(&self) -> MemoryStats {
715        let state = self.state.read();
716        MemoryStats {
717            current_mb: state.current_memory_mb,
718            max_mb: self.config.max_memory_mb as f64,
719            utilization_pct: (state.current_memory_mb / self.config.max_memory_mb as f64 * 100.0)
720                .min(100.0),
721        }
722    }
723}
724
725/// Memory usage statistics
726#[derive(Debug, Clone, Serialize, Deserialize)]
727pub struct MemoryStats {
728    /// Current memory usage (MB)
729    pub current_mb: f64,
730    /// Maximum allowed memory (MB)
731    pub max_mb: f64,
732    /// Utilization percentage
733    pub utilization_pct: f64,
734}
735
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `"Linear"` layer named `layer_{index}` with 512-wide input and output.
    fn linear_layer(
        index: usize,
        param_count: usize,
        memory_mb: f64,
        compute_flops: u64,
    ) -> LayerMetadata {
        LayerMetadata {
            name: format!("layer_{}", index),
            index,
            layer_type: "Linear".to_string(),
            param_count,
            memory_mb,
            compute_flops,
            input_shape: vec![512],
            output_shape: vec![512],
            is_sampled: false,
        }
    }

    #[test]
    fn test_visualizer_creation() {
        let _visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig::default());
    }

    #[test]
    fn test_add_layers() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig::default());

        (0..10).try_for_each(|i| {
            visualizer.add_layer(linear_layer(i, 1024 * 1024, 4.0, 1_000_000_000))
        })?;

        // Ten layers at 4 MB each.
        assert_eq!(visualizer.get_memory_stats().current_mb, 40.0);

        Ok(())
    }

    #[test]
    fn test_uniform_sampling() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig {
            max_full_layers: 5,
            sampling_strategy: SamplingStrategy::Uniform,
            ..Default::default()
        });

        // Twenty identical layers, sampled down to a budget of five.
        for i in 0..20 {
            visualizer.add_layer(linear_layer(i, 1024 * 1024, 4.0, 1_000_000_000))?;
        }

        assert_eq!(visualizer.determine_sampling()?.len(), 5);

        Ok(())
    }

    #[test]
    fn test_text_visualization() -> Result<()> {
        let visualizer = LargeModelVisualizer::new(LargeModelVisualizerConfig {
            output_format: VisualizationFormat::TextSummary,
            ..Default::default()
        });

        // Five layers of increasing size.
        for i in 0..5 {
            let scale = i + 1;
            visualizer.add_layer(linear_layer(
                i,
                1024 * 1024 * scale,
                4.0 * scale as f64,
                1_000_000_000 * scale as u64,
            ))?;
        }

        let result = visualizer.visualize(None)?;

        assert_eq!(result.stats.layers_visualized, 5);
        assert!(result.stats.output_size_bytes > 0);

        Ok(())
    }
}