scirs2_cluster/visualization/animation.rs

//! Advanced animation capabilities for clustering visualization
//!
//! This module provides sophisticated animation features for clustering algorithms,
//! including 3D animations, convergence animations, real-time streaming visualizations,
//! and export capabilities for creating videos and interactive presentations.

use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::time::{Duration, Instant};

use serde::{Deserialize, Serialize};

use super::{EasingFunction, ScatterPlot2D, ScatterPlot3D, VisualizationConfig};
use crate::error::{ClusteringError, Result};

/// Configuration for iterative algorithm animations (like K-means convergence)
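///
/// # Examples
///
/// A minimal sketch of overriding a few fields via struct-update syntax
/// (the values here are illustrative, not recommendations):
///
/// ```ignore
/// let config = IterativeAnimationConfig {
///     capture_frequency: 5,
///     fps: 30.0,
///     ..Default::default()
/// };
/// ```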
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IterativeAnimationConfig {
    /// Capture frame every N iterations
    pub capture_frequency: usize,
    /// Interpolate between captured frames
    pub interpolate_frames: bool,
    /// Number of interpolation frames between captures
    pub interpolation_frames: usize,
    /// Animation speed (frames per second)
    pub fps: f32,
    /// Show convergence metrics overlay
    pub show_convergence_overlay: bool,
    /// Show iteration numbers
    pub show_iteration_numbers: bool,
    /// Highlight centroid movement
    pub highlight_centroid_movement: bool,
    /// Fade effect for old positions
    pub fade_effect: bool,
    /// Trail length for moving points
    pub trail_length: usize,
}

impl Default for IterativeAnimationConfig {
    fn default() -> Self {
        Self {
            capture_frequency: 1,
            interpolate_frames: true,
            interpolation_frames: 5,
            fps: 10.0,
            show_convergence_overlay: true,
            show_iteration_numbers: true,
            highlight_centroid_movement: true,
            fade_effect: true,
            trail_length: 3,
        }
    }
}

/// Configuration for streaming data visualizations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingConfig {
    /// Buffer size for streaming data
    pub buffer_size: usize,
    /// Update frequency for visualization
    pub update_frequency_ms: u64,
    /// Window size for rolling statistics
    pub rolling_window_size: usize,
    /// Show data arrival animation
    pub animate_new_data: bool,
    /// Animate cluster updates
    pub animate_cluster_updates: bool,
    /// Adaptive plot bounds
    pub adaptive_bounds: bool,
    /// Show streaming statistics
    pub show_streaming_stats: bool,
    /// Data point lifetime (for fading effect)
    pub point_lifetime_ms: u64,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            buffer_size: 1000,
            update_frequency_ms: 100,
            rolling_window_size: 50,
            animate_new_data: true,
            animate_cluster_updates: true,
            adaptive_bounds: true,
            show_streaming_stats: true,
            point_lifetime_ms: 10000,
        }
    }
}

/// Animation frame for iterative algorithms
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnimationFrame {
    /// Frame number
    pub frame_number: usize,
    /// Iteration number (for iterative algorithms)
    pub iteration: usize,
    /// Timestamp
    pub timestamp: f64,
    /// Data points for this frame
    pub points: Array2<f64>,
    /// Cluster labels
    pub labels: Array1<i32>,
    /// Centroids (if available)
    pub centroids: Option<Array2<f64>>,
    /// Previous centroids (for movement visualization)
    pub previous_centroids: Option<Array2<f64>>,
    /// Convergence metrics
    pub convergence_info: Option<ConvergenceInfo>,
    /// Custom annotations
    pub annotations: Vec<AnimationAnnotation>,
}

/// Convergence information for animation overlays
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceInfo {
    /// Current inertia/distortion
    pub inertia: f64,
    /// Change in inertia from previous iteration
    pub inertia_change: f64,
    /// Maximum centroid movement
    pub max_centroid_movement: f64,
    /// Number of points that changed clusters
    pub label_changes: usize,
    /// Whether algorithm has converged
    pub converged: bool,
}

/// Animation annotation for custom overlays
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnimationAnnotation {
    /// Annotation type
    pub annotation_type: String,
    /// Position (2D or 3D coordinates)
    pub position: Vec<f64>,
    /// Text content
    pub text: String,
    /// Color
    pub color: String,
    /// Font size
    pub font_size: f32,
}

/// Recorder for iterative algorithm animations
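///
/// # Examples
///
/// A minimal sketch of recording frames from an iterative clustering loop;
/// `data`, `labels`, `centroids`, `inertia`, and `max_iterations` are
/// placeholders for values produced by the caller's algorithm:
///
/// ```ignore
/// let mut recorder = IterativeAnimationRecorder::new(IterativeAnimationConfig::default());
/// for _ in 0..max_iterations {
///     // ... run one iteration, updating labels, centroids, and inertia ...
///     recorder.record_frame(data.view(), &labels, Some(&centroids), Some(inertia))?;
/// }
/// let json = recorder.export_to_json()?;
/// ```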
pub struct IterativeAnimationRecorder {
    frames: Vec<AnimationFrame>,
    config: IterativeAnimationConfig,
    start_time: Instant,
    current_iteration: usize,
    previous_centroids: Option<Array2<f64>>,
    previous_inertia: Option<f64>,
}

impl IterativeAnimationRecorder {
    /// Create a new animation recorder
    pub fn new(config: IterativeAnimationConfig) -> Self {
        Self {
            frames: Vec::new(),
            config,
            start_time: Instant::now(),
            current_iteration: 0,
            previous_centroids: None,
            previous_inertia: None,
        }
    }

    /// Record a frame during algorithm iteration
    pub fn record_frame<F: Float + FromPrimitive + Debug>(
        &mut self,
        data: ArrayView2<F>,
        labels: &Array1<i32>,
        centroids: Option<&Array2<F>>,
        inertia: Option<f64>,
    ) -> Result<()> {
        if !self
            .current_iteration
            .is_multiple_of(self.config.capture_frequency)
        {
            self.current_iteration += 1;
            return Ok(());
        }

        let timestamp = self.start_time.elapsed().as_secs_f64();

        // Convert data to f64
        let points = data.mapv(|x| x.to_f64().unwrap_or(0.0));

        // Convert centroids to f64
        let centroids_f64 = centroids.map(|c| c.mapv(|x| x.to_f64().unwrap_or(0.0)));

        // Calculate convergence info
        let convergence_info =
            if let (Some(current_centroids), Some(current_inertia)) = (&centroids_f64, inertia) {
                let centroid_movement = if let Some(prev_centroids) = &self.previous_centroids {
                    calculate_max_centroid_movement(prev_centroids, current_centroids)
                } else {
                    0.0
                };

                let inertia_change = if let Some(prev_inertia) = self.previous_inertia {
                    prev_inertia - current_inertia
                } else {
                    0.0
                };

                Some(ConvergenceInfo {
                    inertia: current_inertia,
                    inertia_change,
                    max_centroid_movement: centroid_movement,
                    label_changes: 0, // Would need previous labels to calculate
                    converged: centroid_movement < 1e-4, // Simple convergence check
                })
            } else {
                None
            };

        let frame = AnimationFrame {
            frame_number: self.frames.len(),
            iteration: self.current_iteration,
            timestamp,
            points,
            labels: labels.clone(),
            centroids: centroids_f64.clone(),
            previous_centroids: self.previous_centroids.clone(),
            convergence_info,
            annotations: Vec::new(),
        };

        self.frames.push(frame);

        // Update state for next iteration
        self.previous_centroids = centroids_f64;
        self.previous_inertia = inertia;
        self.current_iteration += 1;

        Ok(())
    }

    /// Add custom annotation to the current frame
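    ///
    /// # Examples
    ///
    /// A minimal sketch (assuming `recorder` already holds at least one
    /// recorded frame) of attaching a text label to the most recent frame;
    /// the field values are illustrative:
    ///
    /// ```ignore
    /// recorder.add_annotation(AnimationAnnotation {
    ///     annotation_type: "text".to_string(),
    ///     position: vec![1.0, 2.0],
    ///     text: "centroid shift".to_string(),
    ///     color: "#ff0000".to_string(),
    ///     font_size: 12.0,
    /// });
    /// ```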
    pub fn add_annotation(&mut self, annotation: AnimationAnnotation) {
        if let Some(frame) = self.frames.last_mut() {
            frame.annotations.push(annotation);
        }
    }

    /// Generate interpolated frames between recorded frames
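    ///
    /// When interpolation is enabled and at least two frames were recorded,
    /// `n` recorded frames with `interpolation_frames = k` expand to at most
    /// `n + k * (n - 1)` frames: every recorded frame plus up to `k` eased
    /// in-between frames per consecutive pair (interpolation errors are
    /// skipped). Otherwise the recorded frames are returned unchanged.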
    pub fn generate_interpolated_frames(&self) -> Vec<AnimationFrame> {
        if !self.config.interpolate_frames || self.frames.len() < 2 {
            return self.frames.clone();
        }

        let mut interpolated_frames = Vec::new();

        for i in 0..self.frames.len() - 1 {
            let current_frame = &self.frames[i];
            let next_frame = &self.frames[i + 1];

            // Add current frame
            interpolated_frames.push(current_frame.clone());

            // Add interpolated frames
            for j in 1..=self.config.interpolation_frames {
                let t = j as f64 / (self.config.interpolation_frames + 1) as f64;
                let interpolated_frame =
                    match interpolate_frames(current_frame, next_frame, t, &self.config) {
                        Ok(frame) => frame,
                        Err(_) => continue, // Skip interpolation on error
                    };
                interpolated_frames.push(interpolated_frame);
            }
        }

        // Add last frame
        if let Some(last_frame) = self.frames.last() {
            interpolated_frames.push(last_frame.clone());
        }

        interpolated_frames
    }

    /// Get all recorded frames
    pub fn get_frames(&self) -> &[AnimationFrame] {
        &self.frames
    }

    /// Export animation to JSON format
    pub fn export_to_json(&self) -> Result<String> {
        #[cfg(feature = "serde")]
        {
            let frames = if self.config.interpolate_frames {
                self.generate_interpolated_frames()
            } else {
                self.frames.clone()
            };

            return serde_json::to_string_pretty(&frames).map_err(|e| {
                ClusteringError::ComputationError(format!("JSON export failed: {}", e))
            });
        }

        #[cfg(not(feature = "serde"))]
        {
            Err(ClusteringError::ComputationError(
                "JSON export requires 'serde' feature".to_string(),
            ))
        }
    }
}

/// Streaming data visualizer for real-time clustering
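///
/// # Examples
///
/// A minimal sketch of feeding labelled points into the visualizer; in
/// practice the points and labels would come from an online clustering step:
///
/// ```ignore
/// let mut visualizer = StreamingVisualizer::new(StreamingConfig::default());
/// visualizer.add_data_point(Array1::from_vec(vec![0.5, 1.5]), 0);
/// ```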
pub struct StreamingVisualizer {
    data_buffer: VecDeque<(Array1<f64>, i32, Instant)>,
    config: StreamingConfig,
    last_update: Instant,
    bounds: Option<(f64, f64, f64, f64, f64, f64)>, // min_x, max_x, min_y, max_y, min_z, max_z
    streaming_stats: StreamingStats,
}

/// Statistics for streaming visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingStats {
    pub total_points_processed: usize,
    pub points_per_second: f64,
    pub cluster_counts: HashMap<i32, usize>,
    pub recent_cluster_changes: usize,
    pub data_arrival_rate: f64,
}

impl StreamingVisualizer {
    /// Create a new streaming visualizer
    pub fn new(config: StreamingConfig) -> Self {
        Self {
            data_buffer: VecDeque::new(),
            config,
            last_update: Instant::now(),
            bounds: None,
            streaming_stats: StreamingStats {
                total_points_processed: 0,
                points_per_second: 0.0,
                cluster_counts: HashMap::new(),
                recent_cluster_changes: 0,
                data_arrival_rate: 0.0,
            },
        }
    }

    /// Add new data point to the stream
    pub fn add_data_point(&mut self, point: Array1<f64>, label: i32) {
        let now = Instant::now();

        // Update bounds if adaptive (before moving point)
        if self.config.adaptive_bounds {
            self.update_bounds(&point);
        }

        // Add to buffer
        self.data_buffer.push_back((point, label, now));

        // Maintain buffer size
        while self.data_buffer.len() > self.config.buffer_size {
            self.data_buffer.pop_front();
        }

        // Update statistics
        self.streaming_stats.total_points_processed += 1;
        *self
            .streaming_stats
            .cluster_counts
            .entry(label)
            .or_insert(0) += 1;

        // Clean up old points
        self.cleanup_old_points(now);
    }

    /// Add batch of data points
    pub fn add_data_batch(&mut self, points: &Array2<f64>, labels: &Array1<i32>) -> Result<()> {
        if points.nrows() != labels.len() {
            return Err(ClusteringError::InvalidInput(
                "Number of points must match number of labels".to_string(),
            ));
        }

        for i in 0..points.nrows() {
            let point = points.row(i).to_owned();
            self.add_data_point(point, labels[i]);
        }

        Ok(())
    }

    /// Check if visualization should be updated
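    ///
    /// # Examples
    ///
    /// A minimal polling sketch; `visualizer` is a placeholder, and how the
    /// resulting frame is rendered is left to the caller:
    ///
    /// ```ignore
    /// if visualizer.should_update() {
    ///     let frame = visualizer.generate_frame()?;
    ///     // hand `frame` to the rendering backend
    /// }
    /// ```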
    pub fn should_update(&self) -> bool {
        self.last_update.elapsed().as_millis() >= self.config.update_frequency_ms as u128
    }

    /// Generate current visualization frame
    pub fn generate_frame(&mut self) -> Result<StreamingFrame> {
        let now = Instant::now();

        // Calculate statistics
        let time_since_last_update = now.duration_since(self.last_update).as_secs_f64();
        if time_since_last_update > 0.0 {
            let recent_points = self
                .data_buffer
                .iter()
                .filter(|(_, _, timestamp)| now.duration_since(*timestamp).as_secs_f64() < 1.0)
                .count();
            self.streaming_stats.points_per_second =
                recent_points as f64 / time_since_last_update.min(1.0);
        }

        // Extract current data
        let current_data: Vec<_> = self.data_buffer.iter().collect();

        if current_data.is_empty() {
            return Ok(StreamingFrame {
                timestamp: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs_f64(),
                points: Array2::zeros((0, 0)),
                labels: Array1::zeros(0),
                point_ages: Vec::new(),
                bounds: self.bounds,
                stats: self.streaming_stats.clone(),
                new_points_mask: Vec::new(),
            });
        }

        // Determine dimensionality
        let n_dims = current_data[0].0.len();
        let n_points = current_data.len();

        // Convert to arrays
        let mut points = Array2::zeros((n_points, n_dims));
        let mut labels = Array1::zeros(n_points);
        let mut point_ages = Vec::with_capacity(n_points);
        let mut new_points_mask = Vec::with_capacity(n_points);

        for (i, (point, label, timestamp)) in current_data.iter().enumerate() {
            for j in 0..n_dims {
                points[[i, j]] = point[j];
            }
            labels[i] = *label;

            let age = now.duration_since(*timestamp).as_millis() as f64;
            point_ages.push(age);

            // Mark as new if arrived recently
            new_points_mask.push(age < 500.0); // 500ms threshold for "new"
        }

        self.last_update = now;

        Ok(StreamingFrame {
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs_f64(),
            points,
            labels,
            point_ages,
            bounds: self.bounds,
            stats: self.streaming_stats.clone(),
            new_points_mask,
        })
    }

    /// Update adaptive bounds
    fn update_bounds(&mut self, point: &Array1<f64>) {
        let n_dims = point.len();

        if let Some(bounds) = &mut self.bounds {
            // Update existing bounds
            if n_dims >= 1 {
                bounds.0 = bounds.0.min(point[0]); // min_x
                bounds.1 = bounds.1.max(point[0]); // max_x
            }
            if n_dims >= 2 {
                bounds.2 = bounds.2.min(point[1]); // min_y
                bounds.3 = bounds.3.max(point[1]); // max_y
            }
            if n_dims >= 3 {
                bounds.4 = bounds.4.min(point[2]); // min_z
                bounds.5 = bounds.5.max(point[2]); // max_z
            }
        } else {
            // Initialize bounds
            self.bounds = Some(if n_dims >= 3 {
                (point[0], point[0], point[1], point[1], point[2], point[2])
            } else if n_dims >= 2 {
                (point[0], point[0], point[1], point[1], 0.0, 0.0)
            } else {
                (point[0], point[0], 0.0, 0.0, 0.0, 0.0)
            });
        }
    }

    /// Clean up old points based on lifetime
    fn cleanup_old_points(&mut self, now: Instant) {
        let lifetime = Duration::from_millis(self.config.point_lifetime_ms);

        while let Some((_, _, timestamp)) = self.data_buffer.front() {
            if now.duration_since(*timestamp) > lifetime {
                self.data_buffer.pop_front();
            } else {
                break;
            }
        }
    }

    /// Get current streaming statistics
    pub fn get_stats(&self) -> &StreamingStats {
        &self.streaming_stats
    }
}

/// Frame for streaming visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingFrame {
    pub timestamp: f64,
    pub points: Array2<f64>,
    pub labels: Array1<i32>,
    pub point_ages: Vec<f64>,
    pub bounds: Option<(f64, f64, f64, f64, f64, f64)>,
    pub stats: StreamingStats,
    pub new_points_mask: Vec<bool>,
}

/// Calculate maximum centroid movement between iterations
#[allow(dead_code)]
fn calculate_max_centroid_movement(
    prev_centroids: &Array2<f64>,
    current_centroids: &Array2<f64>,
) -> f64 {
    if prev_centroids.shape() != current_centroids.shape() {
        return f64::INFINITY;
    }

    let mut max_movement = 0.0;

    for i in 0..prev_centroids.nrows() {
        let mut movement = 0.0;
        for j in 0..prev_centroids.ncols() {
            let diff = current_centroids[[i, j]] - prev_centroids[[i, j]];
            movement += diff * diff;
        }
        movement = movement.sqrt();
        max_movement = max_movement.max(movement);
    }

    max_movement
}

/// Interpolate between two animation frames
#[allow(dead_code)]
fn interpolate_frames(
    frame1: &AnimationFrame,
    frame2: &AnimationFrame,
    t: f64,
    _config: &IterativeAnimationConfig,
) -> Result<AnimationFrame> {
    let t = apply_easing(t, EasingFunction::EaseInOut);

    // Interpolate centroids if both frames have them
    let centroids = if let (Some(c1), Some(c2)) = (&frame1.centroids, &frame2.centroids) {
        if c1.shape() == c2.shape() {
            Some(c1 * (1.0 - t) + c2 * t)
        } else {
            Some(c2.clone()) // Fall back to destination centroids
        }
    } else {
        frame2.centroids.clone()
    };

    // Interpolate convergence info
    let convergence_info =
        if let (Some(conv1), Some(conv2)) = (&frame1.convergence_info, &frame2.convergence_info) {
            Some(ConvergenceInfo {
                inertia: conv1.inertia * (1.0 - t) + conv2.inertia * t,
                inertia_change: conv1.inertia_change * (1.0 - t) + conv2.inertia_change * t,
                max_centroid_movement: conv1.max_centroid_movement * (1.0 - t)
                    + conv2.max_centroid_movement * t,
                label_changes: if t < 0.5 {
                    conv1.label_changes
                } else {
                    conv2.label_changes
                },
                converged: conv2.converged,
            })
        } else {
            frame2.convergence_info.clone()
        };

    Ok(AnimationFrame {
        frame_number: frame1.frame_number,
        iteration: frame1.iteration,
        timestamp: frame1.timestamp * (1.0 - t) + frame2.timestamp * t,
        points: frame2.points.clone(), // Don't interpolate data points
        labels: frame2.labels.clone(),
        centroids,
        previous_centroids: frame1.centroids.clone(),
        convergence_info,
        annotations: frame2.annotations.clone(),
    })
}

/// Apply easing function to interpolation parameter
#[allow(dead_code)]
fn apply_easing(t: f64, easing: EasingFunction) -> f64 {
    let t = t.clamp(0.0, 1.0);

    match easing {
        EasingFunction::Linear => t,
        EasingFunction::EaseIn => t * t,
        EasingFunction::EaseOut => 1.0 - (1.0 - t).powi(2),
        EasingFunction::EaseInOut => {
            if t < 0.5 {
                2.0 * t * t
            } else {
                1.0 - 2.0 * (1.0 - t).powi(2)
            }
        }
        EasingFunction::Bounce => {
            if t < 1.0 / 2.75 {
                7.5625 * t * t
            } else if t < 2.0 / 2.75 {
                let t = t - 1.5 / 2.75;
                7.5625 * t * t + 0.75
            } else if t < 2.5 / 2.75 {
                let t = t - 2.25 / 2.75;
                7.5625 * t * t + 0.9375
            } else {
                let t = t - 2.625 / 2.75;
                7.5625 * t * t + 0.984375
            }
        }
        EasingFunction::Elastic => {
            if t == 0.0 || t == 1.0 {
                t
            } else {
                let p = 0.3;
                let s = p / 4.0;
                -(2.0_f64.powf(10.0 * (t - 1.0))
                    * ((t - 1.0 - s) * (2.0 * std::f64::consts::PI) / p).sin())
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_animation_recorder() {
        let config = IterativeAnimationConfig::default();
        let mut recorder = IterativeAnimationRecorder::new(config);

        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);
        let centroids = Array2::from_shape_vec((2, 2), vec![2.0, 3.0, 6.0, 7.0]).unwrap();

        recorder
            .record_frame(data.view(), &labels, Some(&centroids), Some(10.0))
            .unwrap();

        assert_eq!(recorder.get_frames().len(), 1);
        assert_eq!(recorder.get_frames()[0].iteration, 0);
    }

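    // A small additional check (not in the original suite): with the default
    // configuration, interpolation inserts `interpolation_frames` eased frames
    // between each pair of recorded frames.
    #[test]
    fn test_interpolated_frame_count() {
        let config = IterativeAnimationConfig::default();
        let interpolation_frames = config.interpolation_frames;
        let mut recorder = IterativeAnimationRecorder::new(config);

        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);
        let centroids_a = Array2::from_shape_vec((2, 2), vec![2.0, 3.0, 6.0, 7.0]).unwrap();
        let centroids_b = Array2::from_shape_vec((2, 2), vec![2.5, 3.5, 6.5, 7.5]).unwrap();

        recorder
            .record_frame(data.view(), &labels, Some(&centroids_a), Some(10.0))
            .unwrap();
        recorder
            .record_frame(data.view(), &labels, Some(&centroids_b), Some(8.0))
            .unwrap();

        // Two recorded frames plus the eased frames generated between them.
        let frames = recorder.generate_interpolated_frames();
        assert_eq!(frames.len(), 2 + interpolation_frames);
    }
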
    #[test]
    fn test_streaming_visualizer() {
        let config = StreamingConfig::default();
        let mut visualizer = StreamingVisualizer::new(config);

        let point = Array1::from_vec(vec![1.0, 2.0]);
        visualizer.add_data_point(point, 0);

        let frame = visualizer.generate_frame().unwrap();
        assert_eq!(frame.points.nrows(), 1);
        assert_eq!(frame.labels[0], 0);
    }

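    // An additional sanity check (not in the original suite): the maximum
    // centroid displacement is the Euclidean norm of the largest per-centroid
    // move; here one centroid moves along a 3-4-5 triangle, the other stays put.
    #[test]
    fn test_max_centroid_movement() {
        let prev = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let current = Array2::from_shape_vec((2, 2), vec![3.0, 4.0, 1.0, 1.0]).unwrap();

        let movement = calculate_max_centroid_movement(&prev, &current);
        assert!((movement - 5.0).abs() < 1e-12);

        // Mismatched shapes are reported as unbounded movement.
        let mismatched = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).unwrap();
        assert!(calculate_max_centroid_movement(&prev, &mismatched).is_infinite());
    }
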
    #[test]
    fn test_easing_functions() {
        assert_eq!(apply_easing(0.0, EasingFunction::Linear), 0.0);
        assert_eq!(apply_easing(1.0, EasingFunction::Linear), 1.0);
        assert_eq!(apply_easing(0.5, EasingFunction::Linear), 0.5);

        assert!(apply_easing(0.5, EasingFunction::EaseIn) < 0.5);
        assert!(apply_easing(0.5, EasingFunction::EaseOut) > 0.5);
    }
}