Skip to main content

scirs2_neural/visualization/
activations.rs

1//! Layer activation and feature visualization for neural networks
2//!
3//! This module provides comprehensive tools for visualizing layer activations,
4//! feature maps, activation histograms, and attention patterns.
5
6use super::config::VisualizationConfig;
7use crate::error::{NeuralError, Result};
8use crate::models::sequential::Sequential;
9use scirs2_core::ndarray::ArrayStatCompat;
10use scirs2_core::ndarray::{ArrayD, ScalarOperand};
11use scirs2_core::numeric::Float;
12use scirs2_core::NumAssign;
13use serde::Serialize;
14use statrs::statistics::Statistics;
15use std::collections::HashMap;
16use std::fmt::Debug;
17use std::path::PathBuf;
18/// Layer activation visualizer
19#[allow(dead_code)]
20pub struct ActivationVisualizer<F: Float + Debug + ScalarOperand + NumAssign> {
21    /// Model reference
22    model: Sequential<F>,
23    /// Visualization configuration
24    config: VisualizationConfig,
25    /// Cached activations
26    activation_cache: HashMap<String, ArrayD<F>>,
27}
28/// Activation visualization options
29#[derive(Debug, Clone, Serialize)]
30pub struct ActivationVisualizationOptions {
31    /// Layers to visualize
32    pub target_layers: Vec<String>,
33    /// Visualization type
34    pub visualization_type: ActivationVisualizationType,
35    /// Normalization method
36    pub normalization: ActivationNormalization,
37    /// Color mapping
38    pub colormap: Colormap,
39    /// Aggregation method for multi-channel data
40    pub aggregation: ChannelAggregation,
41}
42
43/// Types of activation visualizations
44#[derive(Debug, Clone, PartialEq, Serialize)]
45pub enum ActivationVisualizationType {
46    /// Feature maps as heatmaps
47    FeatureMaps,
48    /// Activation histograms
49    Histograms,
50    /// Statistics summary
51    Statistics,
52    /// Spatial attention maps
53    AttentionMaps,
54    /// Activation flow
55    ActivationFlow,
56}
57
58/// Activation normalization methods
59#[derive(Debug, Clone, PartialEq, Serialize)]
60pub enum ActivationNormalization {
61    /// No normalization
62    None,
63    /// Min-max normalization to [0, 1]
64    MinMax,
65    /// Z-score normalization
66    ZScore,
67    /// Percentile-based normalization
68    Percentile(f64, f64),
69    /// Custom normalization function
70    Custom(String),
71}
72
73/// Color mapping for visualizations
74#[derive(Debug, Clone, PartialEq, Serialize)]
75pub enum Colormap {
76    /// Viridis colormap
77    Viridis,
78    /// Plasma colormap
79    Plasma,
80    /// Inferno colormap
81    Inferno,
82    /// Jet colormap
83    Jet,
84    /// Grayscale
85    Gray,
86    /// Red-blue diverging
87    RdBu,
88    /// Custom colormap
89    Custom(Vec<String>),
90}
91
92/// Channel aggregation methods
93#[derive(Debug, Clone, PartialEq, Serialize)]
94pub enum ChannelAggregation {
95    /// No aggregation (show all channels)
96    None,
97    /// Average across channels
98    Mean,
99    /// Maximum across channels
100    Max,
101    /// Minimum across channels
102    Min,
103    /// Standard deviation across channels
104    Std,
105    /// Select specific channels
106    Select(Vec<usize>),
107}
108
109/// Activation statistics for a layer
110#[derive(Debug, Clone, Serialize)]
111pub struct ActivationStatistics<F: Float + Debug + serde::Serialize + NumAssign> {
112    /// Layer name
113    pub layer_name: String,
114    /// Mean activation value
115    pub mean: F,
116    /// Standard deviation
117    pub std: F,
118    /// Minimum value
119    pub min: F,
120    /// Maximum value
121    pub max: F,
122    /// Percentiles (5%, 25%, 50%, 75%, 95%)
123    pub percentiles: [F; 5],
124    /// Sparsity (fraction of zero or near-zero activations)
125    pub sparsity: f64,
126    /// Dead neurons (always zero)
127    pub dead_neurons: usize,
128    /// Total neurons
129    pub total_neurons: usize,
130}
131
132/// Feature map information
133#[derive(Debug, Clone, Serialize)]
134pub struct FeatureMapInfo {
135    /// Feature map index
136    pub feature_index: usize,
137    /// Spatial dimensions (height, width)
138    pub spatial_dims: (usize, usize),
139    /// Channel dimension
140    pub channels: usize,
141    /// Activation range (min, max)
142    pub activation_range: (f64, f64),
143}
144
145/// Activation histogram data
146#[derive(Debug, Clone, Serialize)]
147pub struct ActivationHistogram<F: Float + Debug + NumAssign> {
148    /// Layer name
149    pub layer_name: String,
150    /// Histogram bins
151    pub bins: Vec<F>,
152    /// Bin counts
153    pub counts: Vec<usize>,
154    /// Bin edges
155    pub edges: Vec<F>,
156    /// Total sample count
157    pub total_samples: usize,
158}
159
160// Implementation for ActivationVisualizer
161impl<
162        F: Float
163            + Debug
164            + std::fmt::Display
165            + 'static
166            + scirs2_core::numeric::FromPrimitive
167            + ScalarOperand
168            + Send
169            + Sync
170            + serde::Serialize
171            + NumAssign,
172    > ActivationVisualizer<F>
173{
174    /// Create a new activation visualizer
175    pub fn new(model: Sequential<F>, config: VisualizationConfig) -> Self {
176        Self {
177            model,
178            config,
179            activation_cache: HashMap::new(),
180        }
181    }
182    /// Visualize layer activations for given input
183    pub fn visualize_activations(
184        &mut self,
185        input: &ArrayD<F>,
186        options: &ActivationVisualizationOptions,
187    ) -> Result<Vec<PathBuf>> {
188        // Compute activations
189        self.compute_activations(input, &options.target_layers)?;
190        // Generate visualizations based on type
191        match options.visualization_type {
192            ActivationVisualizationType::FeatureMaps => self.generate_feature_maps(options),
193            ActivationVisualizationType::Histograms => self.generate_histograms(options),
194            ActivationVisualizationType::Statistics => self.generate_statistics(options),
195            ActivationVisualizationType::AttentionMaps => self.generate_attention_maps(options),
196            ActivationVisualizationType::ActivationFlow => self.generate_activation_flow(options),
197        }
198    }
199
200    /// Get cached activations for a layer
201    pub fn get_cached_activations(&self, layer_name: &str) -> Option<&ArrayD<F>> {
202        self.activation_cache.get(layer_name)
203    }
204
205    /// Clear the activation cache
206    pub fn clear_cache(&mut self) {
207        self.activation_cache.clear();
208    }
209
210    /// Get activation statistics for all cached layers
211    pub fn get_activation_statistics(&self) -> Result<Vec<ActivationStatistics<F>>> {
212        let mut stats = Vec::new();
213        for (layer_name, activations) in &self.activation_cache {
214            let layer_stats = self.compute_layer_statistics(layer_name, activations)?;
215            stats.push(layer_stats);
216        }
217        Ok(stats)
218    }
219
220    /// Update the visualization configuration
221    pub fn update_config(&mut self, config: VisualizationConfig) {
222        self.config = config;
223    }
224    fn compute_activations(&mut self, input: &ArrayD<F>, target_layers: &[String]) -> Result<()> {
225        let mut current_output = input.clone();
226        // Store input if requested
227        if target_layers.is_empty() || target_layers.contains(&"input".to_string()) {
228            self.activation_cache
229                .insert("input".to_string(), input.clone());
230        }
231        // Run forward pass through each layer and capture activations
232        for (layer_idx, layer) in self.model.layers().iter().enumerate() {
233            current_output = layer.forward(&current_output)?;
234            let layer_name = format!("layer_{}", layer_idx);
235            // Store activation if this layer is requested or if no specific layers requested
236            if target_layers.is_empty() || target_layers.contains(&layer_name) {
237                self.activation_cache
238                    .insert(layer_name, current_output.clone());
239            }
240        }
241        Ok(())
242    }
243    fn generate_feature_maps(
244        &self,
245        options: &ActivationVisualizationOptions,
246    ) -> Result<Vec<PathBuf>> {
247        let mut output_paths = Vec::new();
248        for layer_name in &options.target_layers {
249            if let Some(activations) = self.activation_cache.get(layer_name) {
250                let feature_maps = self.process_activations_for_visualization(
251                    activations,
252                    &options.normalization,
253                    &options.aggregation,
254                )?;
255                // Generate SVG visualization
256                let svg_content =
257                    self.create_feature_map_svg(&feature_maps, layer_name, &options.colormap)?;
258                let output_path = self
259                    .config
260                    .output_dir
261                    .join(format!("{}_feature_maps.svg", layer_name));
262                std::fs::write(&output_path, svg_content).map_err(|e| {
263                    NeuralError::IOError(format!("Failed to write feature map: {}", e))
264                })?;
265                output_paths.push(output_path);
266            }
267        }
268        Ok(output_paths)
269    }
270    fn generate_histograms(
271        &self,
272        options: &ActivationVisualizationOptions,
273    ) -> Result<Vec<PathBuf>> {
274        let mut output_paths = Vec::new();
275        for layer_name in &options.target_layers {
276            if let Some(activations) = self.activation_cache.get(layer_name) {
277                let histogram = self.compute_activation_histogram(layer_name, activations, 50)?;
278                // Generate SVG histogram
279                let svg_content = self.create_histogram_svg(&histogram)?;
280                let output_path = self
281                    .config
282                    .output_dir
283                    .join(format!("{}_histogram.svg", layer_name));
284                std::fs::write(&output_path, svg_content).map_err(|e| {
285                    NeuralError::IOError(format!("Failed to write histogram: {}", e))
286                })?;
287                output_paths.push(output_path);
288            }
289        }
290        Ok(output_paths)
291    }
292    fn generate_statistics(
293        &self,
294        options: &ActivationVisualizationOptions,
295    ) -> Result<Vec<PathBuf>> {
296        let mut all_stats = Vec::new();
297        for layer_name in &options.target_layers {
298            if let Some(activations) = self.activation_cache.get(layer_name) {
299                let stats = self.compute_layer_statistics(layer_name, activations)?;
300                all_stats.push(stats);
301            }
302        }
303        // Generate JSON statistics report
304        let json_content = serde_json::to_string_pretty(&all_stats).map_err(|e| {
305            NeuralError::SerializationError(format!("Failed to serialize statistics: {}", e))
306        })?;
307        let json_path = self.config.output_dir.join("activation_statistics.json");
308        std::fs::write(&json_path, json_content)
309            .map_err(|e| NeuralError::IOError(format!("Failed to write statistics: {}", e)))?;
310        // Generate SVG statistics visualization
311        let svg_content = self.create_statistics_svg(&all_stats)?;
312        let svg_path = self.config.output_dir.join("activation_statistics.svg");
313        std::fs::write(&svg_path, svg_content).map_err(|e| {
314            NeuralError::IOError(format!("Failed to write statistics visualization: {}", e))
315        })?;
316        Ok(vec![json_path, svg_path])
317    }
318    fn generate_attention_maps(
319        &self,
320        options: &ActivationVisualizationOptions,
321    ) -> Result<Vec<PathBuf>> {
322        let mut output_paths = Vec::new();
323        for layer_name in &options.target_layers {
324            if let Some(activations) = self.activation_cache.get(layer_name) {
325                // Check if activations have spatial dimensions suitable for attention maps
326                if activations.ndim() >= 3 {
327                    let attention_map = self.compute_spatial_attention(activations)?;
328                    let svg_content = self.create_attention_map_svg(&attention_map, layer_name)?;
329                    let output_path = self
330                        .config
331                        .output_dir
332                        .join(format!("{}_attention.svg", layer_name));
333                    std::fs::write(&output_path, svg_content).map_err(|e| {
334                        NeuralError::IOError(format!("Failed to write attention map: {}", e))
335                    })?;
336                    output_paths.push(output_path);
337                }
338            }
339        }
340        Ok(output_paths)
341    }
342    fn generate_activation_flow(
343        &self,
344        options: &ActivationVisualizationOptions,
345    ) -> Result<Vec<PathBuf>> {
346        // Compute activation flow between consecutive layers
347        let mut flow_data = Vec::new();
348        let sorted_layers: Vec<_> = options.target_layers.iter().collect();
349        for i in 0..sorted_layers.len().saturating_sub(1) {
350            let from_layer = sorted_layers[i];
351            let to_layer = sorted_layers[i + 1];
352            if let (Some(from_activations), Some(to_activations)) = (
353                self.activation_cache.get(from_layer),
354                self.activation_cache.get(to_layer),
355            ) {
356                let flow_intensity =
357                    self.compute_activation_flow(from_activations, to_activations)?;
358                flow_data.push((from_layer.clone(), to_layer.clone(), flow_intensity));
359            }
360        }
361        if !flow_data.is_empty() {
362            let svg_content = self.create_flow_diagram_svg(&flow_data)?;
363            let output_path = self.config.output_dir.join("activation_flow.svg");
364            std::fs::write(&output_path, svg_content).map_err(|e| {
365                NeuralError::IOError(format!("Failed to write activation flow: {}", e))
366            })?;
367            Ok(vec![output_path])
368        } else {
369            Ok(Vec::new())
370        }
371    }
372    fn compute_layer_statistics(
373        &self,
374        layer_name: &str,
375        activations: &ArrayD<F>,
376    ) -> Result<ActivationStatistics<F>> {
377        let total_elements = activations.len();
378        if total_elements == 0 {
379            return Err(NeuralError::InvalidArgument(
380                "Empty activation tensor".to_string(),
381            ));
382        }
383        // Compute basic statistics
384        let mut sum = F::zero();
385        let mut min_val = F::infinity();
386        let mut max_val = F::neg_infinity();
387        let mut zero_count = 0;
388        for &val in activations.iter() {
389            sum += val;
390            if val < min_val {
391                min_val = val;
392            }
393            if val > max_val {
394                max_val = val;
395            }
396            if val.abs() < F::from(1e-6).unwrap_or(F::zero()) {
397                zero_count += 1;
398            }
399        }
400        let mean = sum / F::from(total_elements).unwrap_or(F::one());
401        // Compute standard deviation
402        let mut variance_sum = F::zero();
403        for &val in activations.iter() {
404            let diff = val - mean;
405            variance_sum += diff * diff;
406        }
407        let variance = variance_sum / F::from(total_elements - 1).unwrap_or(F::one());
408        let std = variance.sqrt();
409        // Compute percentiles (simplified implementation)
410        let mut sorted_values: Vec<F> = activations.iter().copied().collect();
411        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
412        let percentiles = [
413            sorted_values[total_elements * 5 / 100],  // 5%
414            sorted_values[total_elements * 25 / 100], // 25%
415            sorted_values[total_elements * 50 / 100], // 50%
416            sorted_values[total_elements * 75 / 100], // 75%
417            sorted_values[total_elements * 95 / 100], // 95%
418        ];
419        let sparsity = zero_count as f64 / total_elements as f64;
420        Ok(ActivationStatistics {
421            layer_name: layer_name.to_string(),
422            mean,
423            std,
424            min: min_val,
425            max: max_val,
426            percentiles,
427            sparsity,
428            dead_neurons: zero_count,
429            total_neurons: total_elements,
430        })
431    }
432    fn process_activations_for_visualization(
433        &self,
434        activations: &ArrayD<F>,
435        normalization: &ActivationNormalization,
436        aggregation: &ChannelAggregation,
437    ) -> Result<ArrayD<F>> {
438        let mut processed = activations.clone();
439        // Apply channel aggregation first
440        processed = match aggregation {
441            ChannelAggregation::None => processed,
442            ChannelAggregation::Mean => {
443                if processed.ndim() > 2 {
444                    let mean_axis = processed.ndim() - 1; // Usually channel is last dimension
445                    processed
446                        .mean_axis(scirs2_core::ndarray::Axis(mean_axis))
447                        .expect("Operation failed")
448                        .insert_axis(scirs2_core::ndarray::Axis(mean_axis))
449                } else {
450                    processed
451                }
452            }
453            ChannelAggregation::Max => {
454                if processed.ndim() > 2 {
455                    let max_axis = processed.ndim() - 1;
456                    let max_values = processed.fold_axis(
457                        scirs2_core::ndarray::Axis(max_axis),
458                        F::neg_infinity(),
459                        |&acc, &x| acc.max(x),
460                    );
461                    max_values.insert_axis(scirs2_core::ndarray::Axis(max_axis))
462                } else {
463                    processed
464                }
465            }
466            ChannelAggregation::Min => {
467                if processed.ndim() > 2 {
468                    let min_axis = processed.ndim() - 1;
469                    let min_values = processed.fold_axis(
470                        scirs2_core::ndarray::Axis(min_axis),
471                        F::infinity(),
472                        |&acc, &x| acc.min(x),
473                    );
474                    min_values.insert_axis(scirs2_core::ndarray::Axis(min_axis))
475                } else {
476                    processed
477                }
478            }
479            ChannelAggregation::Std => {
480                if processed.ndim() > 2 {
481                    let std_axis = processed.ndim() - 1;
482                    let mean = processed
483                        .mean_axis(scirs2_core::ndarray::Axis(std_axis))
484                        .expect("Operation failed");
485                    let variance =
486                        processed.map_axis(scirs2_core::ndarray::Axis(std_axis), |channel| {
487                            let mean_val = mean.iter().next().copied().unwrap_or(F::zero());
488                            let variance_sum = channel
489                                .iter()
490                                .map(|&x| (x - mean_val) * (x - mean_val))
491                                .fold(F::zero(), |acc, x| acc + x);
492                            (variance_sum / F::from(channel.len()).unwrap_or(F::one())).sqrt()
493                        });
494                    variance.insert_axis(scirs2_core::ndarray::Axis(std_axis))
495                } else {
496                    processed
497                }
498            }
499            ChannelAggregation::Select(channels) => {
500                if processed.ndim() > 2 && !channels.is_empty() {
501                    let channel_axis = processed.ndim() - 1;
502                    let mut selected_slices = Vec::new();
503                    for &channel_idx in channels {
504                        if channel_idx < processed.shape()[channel_axis] {
505                            let slice = processed
506                                .index_axis(scirs2_core::ndarray::Axis(channel_axis), channel_idx);
507                            selected_slices
508                                .push(slice.insert_axis(scirs2_core::ndarray::Axis(channel_axis)));
509                        }
510                    }
511                    if !selected_slices.is_empty() {
512                        scirs2_core::ndarray::concatenate(
513                            scirs2_core::ndarray::Axis(channel_axis),
514                            &selected_slices.iter().map(|x| x.view()).collect::<Vec<_>>(),
515                        )
516                        .map_err(|_| {
517                            NeuralError::DimensionMismatch(
518                                "Failed to concatenate selected channels".to_string(),
519                            )
520                        })?
521                    } else {
522                        processed
523                    }
524                } else {
525                    processed
526                }
527            }
528        };
529        // Apply normalization
530        processed = match normalization {
531            ActivationNormalization::None => processed,
532            ActivationNormalization::MinMax => {
533                let min_val = processed.iter().copied().fold(F::infinity(), F::min);
534                let max_val = processed.iter().copied().fold(F::neg_infinity(), F::max);
535                let range = max_val - min_val;
536                if range > F::zero() {
537                    processed.mapv(|x| (x - min_val) / range)
538                } else {
539                    processed.mapv(|_| F::zero())
540                }
541            }
542            ActivationNormalization::ZScore => {
543                let mean = processed.mean_or(F::zero());
544                let variance = processed
545                    .iter()
546                    .map(|&x| (x - mean) * (x - mean))
547                    .fold(F::zero(), |acc, x| acc + x)
548                    / F::from(processed.len()).unwrap_or(F::one());
549                let std = variance.sqrt();
550                if std > F::zero() {
551                    processed.mapv(|x| (x - mean) / std)
552                } else {
553                    processed.mapv(|_| F::zero())
554                }
555            }
556            ActivationNormalization::Percentile(low, high) => {
557                let mut values: Vec<F> = processed.iter().copied().collect();
558                values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
559                let n = values.len();
560                let low_idx = ((low / 100.0) * n as f64) as usize;
561                let high_idx = ((high / 100.0) * n as f64) as usize;
562                if low_idx < n && high_idx < n && low_idx < high_idx {
563                    let low_val = values[low_idx];
564                    let high_val = values[high_idx];
565                    let range = high_val - low_val;
566                    if range > F::zero() {
567                        processed.mapv(|x| ((x - low_val) / range).max(F::zero()).min(F::one()))
568                    } else {
569                        processed.mapv(|_| F::zero())
570                    }
571                } else {
572                    processed
573                }
574            }
575            ActivationNormalization::Custom(_) => {
576                // Custom normalization would require function pointer or callback
577                // For now, fall back to no normalization
578                processed
579            }
580        };
581        Ok(processed)
582    }
583    fn create_feature_map_svg(
584        &self,
585        feature_maps: &ArrayD<F>,
586        layer_name: &str,
587        colormap: &Colormap,
588    ) -> Result<String> {
589        let width = self.config.style.layout.width;
590        let height = self.config.style.layout.height;
591        // Get color scheme
592        let colors = self.get_colormap_colors(colormap);
593        let mut svg = format!(
594            r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">
595"#,
596            width, height, width, height
597        );
598        // Title
599        svg.push_str(&format!(
600            "<text x=\"{}\" y=\"30\" text-anchor=\"middle\" font-family=\"{}\" font-size=\"{}\" fill=\"#333\">{} Feature Maps</text>\n",
601            width / 2, self.config.style.font.family,
602            (self.config.style.font.size as f32 * self.config.style.font.title_scale) as u32,
603            layer_name
604        ));
605        // Simple grid visualization of feature maps
606        if feature_maps.ndim() >= 2 {
607            let shape = feature_maps.shape();
608            let map_height = shape[0].min(32); // Limit visualization size
609            let map_width = shape[1].min(32);
610            let cell_width = (width - 100) / map_width as u32;
611            let cell_height = (height - 100) / map_height as u32;
612            for i in 0..map_height {
613                for j in 0..map_width {
614                    let value = if let Some(&val) = feature_maps.get([i, j].as_slice()) {
615                        val
616                    } else {
617                        F::zero()
618                    };
619                    let intensity = (value.to_f64().unwrap_or(0.0) * 255.0).clamp(0.0, 255.0) as u8;
620                    let color = if colors.len() > 1 {
621                        // Interpolate between colors
622                        let color_idx =
623                            (intensity as f64 / 255.0 * (colors.len() - 1) as f64) as usize;
624                        colors[color_idx.min(colors.len() - 1)].clone()
625                    } else {
626                        format!("rgb({},{},{})", intensity, intensity, intensity)
627                    };
628                    svg.push_str(&format!(
629                        "<rect x=\"{}\" y=\"{}\" width=\"{}\" height=\"{}\" fill=\"{}\" stroke=\"#ccc\" stroke-width=\"0.5\"/>\n",
630                        50 + j * cell_width as usize,
631                        50 + i * cell_height as usize,
632                        cell_width,
633                        cell_height,
634                        color
635                    ));
636                }
637            }
638        }
639        svg.push_str("</svg>");
640        Ok(svg)
641    }
642    fn compute_activation_histogram(
643        &self,
644        layer_name: &str,
645        activations: &ArrayD<F>,
646        num_bins: usize,
647    ) -> Result<ActivationHistogram<F>> {
648        let values: Vec<F> = activations.iter().copied().collect();
649        if values.is_empty() {
650            return Err(NeuralError::ValidationError(
651                "Empty activations for histogram".to_string(),
652            ));
653        }
654        let min_val = values.iter().copied().fold(F::infinity(), F::min);
655        let max_val = values.iter().copied().fold(F::neg_infinity(), F::max);
656        let range = max_val - min_val;
657        if range <= F::zero() {
658            // All values are the same
659            return Ok(ActivationHistogram {
660                layer_name: layer_name.to_string(),
661                bins: vec![min_val],
662                counts: vec![values.len()],
663                edges: vec![min_val, max_val],
664                total_samples: values.len(),
665            });
666        }
667        let bin_width = range / F::from(num_bins).unwrap_or(F::one());
668        let mut bins = Vec::with_capacity(num_bins);
669        let mut counts = vec![0; num_bins];
670        let mut edges = Vec::with_capacity(num_bins + 1);
671        // Create bin edges and centers
672        for i in 0..=num_bins {
673            edges.push(min_val + F::from(i).unwrap_or(F::zero()) * bin_width);
674        }
675        for i in 0..num_bins {
676            bins.push(
677                min_val
678                    + (F::from(i).unwrap_or(F::zero()) + F::from(0.5).unwrap_or(F::zero()))
679                        * bin_width,
680            );
681        }
682        // Count values in each bin
683        for &value in &values {
684            let bin_idx = ((value - min_val) / bin_width)
685                .to_usize()
686                .unwrap_or(0)
687                .min(num_bins - 1);
688            counts[bin_idx] += 1;
689        }
690        Ok(ActivationHistogram {
691            layer_name: layer_name.to_string(),
692            bins,
693            counts,
694            edges,
695            total_samples: values.len(),
696        })
697    }
698
699    fn create_histogram_svg(&self, histogram: &ActivationHistogram<F>) -> Result<String> {
700        let width = self.config.style.layout.width;
701        let height = self.config.style.layout.height;
702        let mut svg = format!(
703            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
704            <svg width=\"{}\" height=\"{}\" xmlns=\"http://www.w3.org/2000/svg\">\n",
705            width, height
706        );
707        svg.push_str("</svg>");
708        Ok(svg)
709    }
710
711    fn create_statistics_svg(&self, stats: &[ActivationStatistics<F>]) -> Result<String> {
712        let width = self.config.style.layout.width;
713        let height = self.config.style.layout.height;
714        let mut svg = format!(
715            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
716            <svg width=\"{}\" height=\"{}\" xmlns=\"http://www.w3.org/2000/svg\">\n",
717            width, height
718        );
719        svg.push_str("</svg>");
720        Ok(svg)
721    }
722
723    fn compute_spatial_attention(&self, activations: &ArrayD<F>) -> Result<ArrayD<F>> {
724        // Stub implementation
725        Ok(activations.clone())
726    }
727
728    fn create_attention_map_svg(
729        &self,
730        attention_map: &ArrayD<F>,
731        layer_name: &str,
732    ) -> Result<String> {
733        let width = self.config.style.layout.width;
734        let height = self.config.style.layout.height;
735        let mut svg = format!(
736            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
737            <svg width=\"{}\" height=\"{}\" xmlns=\"http://www.w3.org/2000/svg\">\n",
738            width, height
739        );
740        svg.push_str("</svg>");
741        Ok(svg)
742    }
743
744    fn compute_activation_flow(
745        &self,
746        from_activations: &ArrayD<F>,
747        to_activations: &ArrayD<F>,
748    ) -> Result<f64> {
749        Ok(0.0)
750    }
751
752    fn create_flow_diagram_svg(&self, flow_data: &[(String, String, f64)]) -> Result<String> {
753        let width = self.config.style.layout.width;
754        let height = self.config.style.layout.height;
755        let mut svg = format!(
756            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
757            <svg width=\"{}\" height=\"{}\" xmlns=\"http://www.w3.org/2000/svg\">\n",
758            width, height
759        );
760        svg.push_str(&format!(
761            "<text x=\"{}\" y=\"30\" text-anchor=\"middle\" font-family=\"{}\" font-size=\"{}\" fill=\"#333\">Activation Flow Diagram</text>\n",
762            width / 2, self.config.style.font.family, self.config.style.font.size
763        ));
764        svg.push_str("</svg>");
765        Ok(svg)
766    }
767
768    fn get_colormap_colors(&self, colormap: &Colormap) -> Vec<String> {
769        match colormap {
770            Colormap::Viridis => vec![
771                "#440154".to_string(),
772                "#482677".to_string(),
773                "#3f4a8a".to_string(),
774                "#31678e".to_string(),
775                "#26838f".to_string(),
776                "#1f9d8a".to_string(),
777                "#6cce5a".to_string(),
778                "#b6de2b".to_string(),
779                "#fee825".to_string(),
780            ],
781            Colormap::Plasma => vec![
782                "#0c0786".to_string(),
783                "#40039a".to_string(),
784                "#6a0a83".to_string(),
785                "#8b0aa5".to_string(),
786                "#a83eaf".to_string(),
787                "#c06fad".to_string(),
788                "#d8a1a3".to_string(),
789                "#f0d3a3".to_string(),
790                "#fcffa4".to_string(),
791            ],
792            Colormap::Inferno => vec![
793                "#000003".to_string(),
794                "#1f0c48".to_string(),
795                "#581845".to_string(),
796                "#8b1538".to_string(),
797                "#b71f2b".to_string(),
798                "#db4c26".to_string(),
799                "#ed7953".to_string(),
800                "#fbad76".to_string(),
801            ],
802            Colormap::Jet => vec![
803                "#00007f".to_string(),
804                "#0000ff".to_string(),
805                "#007fff".to_string(),
806                "#00ffff".to_string(),
807                "#7fff00".to_string(),
808                "#ffff00".to_string(),
809                "#ff7f00".to_string(),
810                "#ff0000".to_string(),
811                "#7f0000".to_string(),
812            ],
813            Colormap::Gray => vec![
814                "#000000".to_string(),
815                "#404040".to_string(),
816                "#808080".to_string(),
817                "#c0c0c0".to_string(),
818                "#ffffff".to_string(),
819            ],
820            Colormap::RdBu => vec![
821                "#053061".to_string(),
822                "#2166ac".to_string(),
823                "#4393c3".to_string(),
824                "#92c5de".to_string(),
825                "#d1e5f0".to_string(),
826                "#f7f7f7".to_string(),
827                "#fddbc7".to_string(),
828                "#f4a582".to_string(),
829                "#d6604d".to_string(),
830                "#b2182b".to_string(),
831                "#67001f".to_string(),
832            ],
833            Colormap::Custom(colors) => colors.clone(),
834        }
835    }
836}
837
838// Default implementations for configuration types
839impl Default for ActivationVisualizationOptions {
840    fn default() -> Self {
841        Self {
842            target_layers: Vec::new(),
843            visualization_type: ActivationVisualizationType::FeatureMaps,
844            normalization: ActivationNormalization::MinMax,
845            colormap: Colormap::Viridis,
846            aggregation: ChannelAggregation::Mean,
847        }
848    }
849}
850
851impl Default for FeatureMapInfo {
852    fn default() -> Self {
853        Self {
854            feature_index: 0,
855            spatial_dims: (1, 1),
856            channels: 1,
857            activation_range: (0.0, 1.0),
858        }
859    }
860}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Dense;
    use scirs2_core::random::SeedableRng;

    /// Builds a visualizer around a single 10 -> 5 "relu" dense layer.
    ///
    /// The layer is initialised from a seeded RNG, so results are reproducible.
    fn make_visualizer() -> ActivationVisualizer<f32> {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let mut model = Sequential::<f32>::new();
        model.add_layer(Dense::new(10, 5, Some("relu"), &mut rng).expect("Operation failed"));
        ActivationVisualizer::new(model, VisualizationConfig::default())
    }

    /// Asserts that no two entries of `items` compare equal.
    ///
    /// Catches accidental aliasing of enum variants, e.g. a hand-written `PartialEq`.
    fn assert_pairwise_distinct<T: PartialEq + std::fmt::Debug>(items: &[T]) {
        for (i, a) in items.iter().enumerate() {
            for b in &items[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }

    #[test]
    fn test_activation_visualizer_creation() {
        let visualizer = make_visualizer();
        assert!(visualizer.activation_cache.is_empty());
    }

    #[test]
    fn test_activation_visualization_options_default() {
        let options = ActivationVisualizationOptions::default();
        assert_eq!(
            options.visualization_type,
            ActivationVisualizationType::FeatureMaps
        );
        assert_eq!(options.normalization, ActivationNormalization::MinMax);
        assert_eq!(options.colormap, Colormap::Viridis);
        assert_eq!(options.aggregation, ChannelAggregation::Mean);
    }

    #[test]
    fn test_activation_visualization_types() {
        let types = [
            ActivationVisualizationType::FeatureMaps,
            ActivationVisualizationType::Histograms,
            ActivationVisualizationType::Statistics,
            ActivationVisualizationType::AttentionMaps,
            ActivationVisualizationType::ActivationFlow,
        ];
        assert_eq!(types.len(), 5);
        assert_eq!(types[0], ActivationVisualizationType::FeatureMaps);
        assert_pairwise_distinct(&types);
    }

    #[test]
    fn test_normalization_methods() {
        let none = ActivationNormalization::None;
        let minmax = ActivationNormalization::MinMax;
        let zscore = ActivationNormalization::ZScore;
        let percentile = ActivationNormalization::Percentile(5.0, 95.0);
        assert_eq!(none, ActivationNormalization::None);
        assert_eq!(minmax, ActivationNormalization::MinMax);
        assert_eq!(zscore, ActivationNormalization::ZScore);
        match percentile {
            ActivationNormalization::Percentile(low, high) => {
                assert_eq!(low, 5.0);
                assert_eq!(high, 95.0);
            }
            _ => unreachable!("Expected percentile normalization"),
        }
    }

    #[test]
    fn test_colormaps() {
        let colormaps = [
            Colormap::Viridis,
            Colormap::Plasma,
            Colormap::Inferno,
            Colormap::Jet,
            Colormap::Gray,
            Colormap::RdBu,
        ];
        assert_eq!(colormaps.len(), 6);
        assert_eq!(colormaps[0], Colormap::Viridis);
        assert_pairwise_distinct(&colormaps);
        let custom = Colormap::Custom(vec!["#ff0000".to_string(), "#00ff00".to_string()]);
        match custom {
            Colormap::Custom(colors) => assert_eq!(colors.len(), 2),
            _ => unreachable!("Expected custom colormap"),
        }
    }

    #[test]
    fn test_channel_aggregation() {
        let aggregations = [
            ChannelAggregation::None,
            ChannelAggregation::Mean,
            ChannelAggregation::Max,
            ChannelAggregation::Min,
            ChannelAggregation::Std,
            ChannelAggregation::Select(vec![0, 1, 2]),
        ];
        assert_eq!(aggregations.len(), 6);
        assert_eq!(aggregations[1], ChannelAggregation::Mean);
        assert_pairwise_distinct(&aggregations);
        match &aggregations[5] {
            ChannelAggregation::Select(channels) => assert_eq!(channels.len(), 3),
            _ => unreachable!("Expected select aggregation"),
        }
    }

    #[test]
    fn test_feature_map_info_default() {
        let info = FeatureMapInfo::default();
        assert_eq!(info.feature_index, 0);
        assert_eq!(info.spatial_dims, (1, 1));
        assert_eq!(info.channels, 1);
        assert_eq!(info.activation_range, (0.0, 1.0));
    }

    #[test]
    fn test_cache_operations() {
        let mut visualizer = make_visualizer();
        assert!(visualizer.get_cached_activations("test_layer").is_none());
        visualizer.clear_cache();
        // Clearing must leave the cache empty and lookups still missing.
        assert!(visualizer.activation_cache.is_empty());
        assert!(visualizer.get_cached_activations("test_layer").is_none());
    }
}
975}