Skip to main content

tenflowers_dataset/visualization/
visualizer.rs

1//! Core dataset visualization implementation
2//!
3//! This module contains the main DatasetVisualizer struct and its methods
4//! for analyzing and visualizing dataset properties.
5
6use crate::{transforms::Transform, Dataset};
7use tenflowers_core::{Result, TensorError};
8
9use super::types::*;
10
11/// Visualization utilities for datasets
12pub struct DatasetVisualizer;
13
14impl DatasetVisualizer {
15    /// Create a sample preview showing basic statistics and examples
16    pub fn sample_preview<T, D>(dataset: &D, num_samples: usize) -> Result<SamplePreview>
17    where
18        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
19        D: Dataset<T>,
20    {
21        if dataset.is_empty() {
22            return Err(TensorError::invalid_argument(
23                "Dataset is empty".to_string(),
24            ));
25        }
26
27        let total_samples = dataset.len();
28        let samples_to_show = num_samples.min(total_samples);
29
30        // Get sample data
31        let mut samples = Vec::new();
32        let step = if samples_to_show == 1 {
33            0
34        } else {
35            total_samples / samples_to_show
36        };
37
38        for i in 0..samples_to_show {
39            let index = if step == 0 { 0 } else { i * step };
40            let index = index.min(total_samples - 1);
41
42            if let Ok((features, labels)) = dataset.get(index) {
43                samples.push(SampleInfo {
44                    index,
45                    feature_shape: features.shape().dims().to_vec(),
46                    label_shape: labels.shape().dims().to_vec(),
47                });
48            }
49        }
50
51        Ok(SamplePreview {
52            total_samples,
53            samples_shown: samples.len(),
54            samples,
55        })
56    }
57
58    /// Generate distribution information for dataset features and labels
59    pub fn feature_distribution<T, D>(
60        dataset: &D,
61        max_samples: Option<usize>,
62    ) -> Result<DistributionInfo<T>>
63    where
64        T: Clone
65            + Default
66            + scirs2_core::numeric::Zero
67            + Send
68            + Sync
69            + 'static
70            + scirs2_core::numeric::Float,
71        D: Dataset<T>,
72    {
73        if dataset.is_empty() {
74            return Err(TensorError::invalid_argument(
75                "Dataset is empty".to_string(),
76            ));
77        }
78
79        let samples_to_analyze = max_samples.unwrap_or(dataset.len()).min(dataset.len());
80        let mut feature_stats = Vec::new();
81        let mut label_stats = Vec::new();
82
83        // Get first sample to determine shapes
84        let (first_features, first_labels) = dataset.get(0)?;
85        let feature_dims = first_features.numel();
86        let label_dims = first_labels.numel();
87
88        // Initialize accumulators
89        let mut feature_sums = vec![T::zero(); feature_dims];
90        let mut feature_squared_sums = vec![T::zero(); feature_dims];
91        let mut label_sums = vec![T::zero(); label_dims];
92        let mut label_squared_sums = vec![T::zero(); label_dims];
93
94        let mut valid_samples = 0;
95
96        // Accumulate statistics
97        for i in 0..samples_to_analyze {
98            if let Ok((features, labels)) = dataset.get(i) {
99                // Process features
100                if let Some(feature_data) = features.as_slice() {
101                    for (j, &value) in feature_data.iter().enumerate() {
102                        if j < feature_dims {
103                            feature_sums[j] = feature_sums[j] + value;
104                            feature_squared_sums[j] = feature_squared_sums[j] + value * value;
105                        }
106                    }
107                }
108
109                // Process labels
110                if let Some(label_data) = labels.as_slice() {
111                    for (j, &value) in label_data.iter().enumerate() {
112                        if j < label_dims {
113                            label_sums[j] = label_sums[j] + value;
114                            label_squared_sums[j] = label_squared_sums[j] + value * value;
115                        }
116                    }
117                }
118
119                valid_samples += 1;
120            }
121        }
122
123        if valid_samples == 0 {
124            return Err(TensorError::invalid_argument(
125                "No valid samples found".to_string(),
126            ));
127        }
128
129        let n = T::from(valid_samples).expect("sample count should convert to float");
130
131        // Calculate feature statistics
132        for i in 0..feature_dims {
133            let mean = feature_sums[i] / n;
134            let variance = (feature_squared_sums[i] / n) - (mean * mean);
135            let std_dev = variance.sqrt();
136
137            feature_stats.push(FeatureStats {
138                dimension: i,
139                mean,
140                std_dev,
141                min: T::zero(), // Would need to track min/max separately
142                max: T::zero(),
143            });
144        }
145
146        // Calculate label statistics
147        for i in 0..label_dims {
148            let mean = label_sums[i] / n;
149            let variance = (label_squared_sums[i] / n) - (mean * mean);
150            let std_dev = variance.sqrt();
151
152            label_stats.push(FeatureStats {
153                dimension: i,
154                mean,
155                std_dev,
156                min: T::zero(),
157                max: T::zero(),
158            });
159        }
160
161        Ok(DistributionInfo {
162            samples_analyzed: valid_samples,
163            feature_stats,
164            label_stats,
165        })
166    }
167
168    /// Generate class distribution for classification datasets
169    pub fn class_distribution<T, D>(dataset: &D) -> Result<ClassDistribution>
170    where
171        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
172        D: Dataset<T>,
173    {
174        let mut class_counts = std::collections::HashMap::new();
175        let mut total_samples = 0;
176
177        for i in 0..dataset.len() {
178            if let Ok((_, labels)) = dataset.get(i) {
179                // For simplicity, convert labels to string representation
180                let class_key = format!("{:?}", labels.shape());
181                *class_counts.entry(class_key).or_insert(0) += 1;
182                total_samples += 1;
183            }
184        }
185
186        Ok(ClassDistribution {
187            total_samples,
188            class_counts,
189        })
190    }
191
192    /// Generate a simple text-based histogram for a single feature dimension
193    pub fn feature_histogram<T, D>(
194        dataset: &D,
195        feature_index: usize,
196        bins: usize,
197    ) -> Result<FeatureHistogram<T>>
198    where
199        T: Clone
200            + Default
201            + scirs2_core::numeric::Zero
202            + Send
203            + Sync
204            + 'static
205            + scirs2_core::numeric::Float
206            + PartialOrd,
207        D: Dataset<T>,
208    {
209        let mut values = Vec::new();
210
211        // Collect all values for the specified feature
212        for i in 0..dataset.len() {
213            if let Ok((features, _)) = dataset.get(i) {
214                if let Some(feature_data) = features.as_slice() {
215                    if feature_index < feature_data.len() {
216                        values.push(feature_data[feature_index]);
217                    }
218                }
219            }
220        }
221
222        if values.is_empty() {
223            return Err(TensorError::invalid_argument(
224                "No valid feature values found".to_string(),
225            ));
226        }
227
228        // Find min and max
229        let mut min_val = values[0];
230        let mut max_val = values[0];
231
232        for &val in &values {
233            if val < min_val {
234                min_val = val;
235            }
236            if val > max_val {
237                max_val = val;
238            }
239        }
240
241        // Create bins
242        let range = max_val - min_val;
243        let bin_width = if range > T::zero() {
244            range / T::from(bins).expect("bin count should convert to float")
245        } else {
246            T::from(1.0).expect("constant 1.0 should convert to float")
247        };
248
249        let mut bin_counts = vec![0; bins];
250
251        // Assign values to bins
252        for val in values {
253            if range > T::zero() {
254                let bin_index = ((val - min_val) / bin_width).to_usize().unwrap_or(0);
255                let bin_index = bin_index.min(bins - 1);
256                bin_counts[bin_index] += 1;
257            } else {
258                bin_counts[0] += 1;
259            }
260        }
261
262        Ok(FeatureHistogram {
263            feature_index,
264            min_value: min_val,
265            max_value: max_val,
266            bin_width,
267            bin_counts,
268        })
269    }
270
271    /// Analyze the effects of a transform on dataset samples
272    pub fn analyze_augmentation_effects<T, D, Tr>(
273        dataset: &D,
274        transform: &Tr,
275        num_samples: usize,
276    ) -> Result<AugmentationEffects<T>>
277    where
278        T: Clone
279            + Default
280            + scirs2_core::numeric::Zero
281            + Send
282            + Sync
283            + 'static
284            + scirs2_core::numeric::Float
285            + PartialOrd,
286        D: Dataset<T>,
287        Tr: Transform<T>,
288    {
289        if dataset.is_empty() {
290            return Err(TensorError::invalid_argument(
291                "Dataset is empty".to_string(),
292            ));
293        }
294
295        let samples_to_analyze = num_samples.min(dataset.len());
296        let mut before_after_pairs = Vec::new();
297        let mut transform_success_count = 0;
298
299        // Collect before/after pairs
300        for i in 0..samples_to_analyze {
301            if let Ok(original_sample) = dataset.get(i) {
302                match transform.apply(original_sample.clone()) {
303                    Ok(transformed_sample) => {
304                        before_after_pairs.push(BeforeAfterPair {
305                            index: i,
306                            original: original_sample,
307                            transformed: transformed_sample,
308                        });
309                        transform_success_count += 1;
310                    }
311                    Err(_) => {
312                        // Transform failed, skip this sample
313                        continue;
314                    }
315                }
316            }
317        }
318
319        if before_after_pairs.is_empty() {
320            return Err(TensorError::invalid_argument(
321                "No successful transforms".to_string(),
322            ));
323        }
324
325        // Analyze feature changes
326        let feature_changes = Self::analyze_feature_changes(&before_after_pairs)?;
327
328        // Analyze distribution changes
329        let distribution_changes = Self::analyze_distribution_changes(&before_after_pairs)?;
330
331        Ok(AugmentationEffects {
332            samples_analyzed: before_after_pairs.len(),
333            transform_success_rate: transform_success_count as f64 / samples_to_analyze as f64,
334            feature_changes,
335            distribution_changes,
336            sample_pairs: before_after_pairs,
337        })
338    }
339
340    /// Compare before/after samples for a specific transform
341    pub fn compare_samples<T, Tr>(
342        samples: &[(tenflowers_core::Tensor<T>, tenflowers_core::Tensor<T>)],
343        transform: &Tr,
344        comparison_count: usize,
345    ) -> Result<Vec<SampleComparison<T>>>
346    where
347        T: Clone
348            + Default
349            + scirs2_core::numeric::Zero
350            + Send
351            + Sync
352            + 'static
353            + scirs2_core::numeric::Float,
354        Tr: Transform<T>,
355    {
356        let mut comparisons = Vec::new();
357        let samples_to_compare = comparison_count.min(samples.len());
358
359        for (i, original) in samples.iter().enumerate().take(samples_to_compare) {
360            let original = original.clone();
361
362            match transform.apply(original.clone()) {
363                Ok(transformed) => {
364                    // Calculate basic statistics
365                    let original_stats = Self::calculate_tensor_stats(&original.0)?;
366                    let transformed_stats = Self::calculate_tensor_stats(&transformed.0)?;
367
368                    comparisons.push(SampleComparison {
369                        sample_index: i,
370                        original_stats,
371                        transformed_stats,
372                        change_magnitude: Self::calculate_change_magnitude(
373                            &original.0,
374                            &transformed.0,
375                        )?,
376                    });
377                }
378                Err(_) => {
379                    // Skip failed transforms
380                    continue;
381                }
382            }
383        }
384
385        Ok(comparisons)
386    }
387
388    // Helper method to analyze feature changes across all samples
389    pub fn analyze_feature_changes<T>(
390        pairs: &[BeforeAfterPair<T>],
391    ) -> Result<FeatureChangeAnalysis<T>>
392    where
393        T: Clone
394            + Default
395            + scirs2_core::numeric::Zero
396            + Send
397            + Sync
398            + 'static
399            + scirs2_core::numeric::Float,
400    {
401        if pairs.is_empty() {
402            return Err(TensorError::invalid_argument(
403                "No sample pairs provided".to_string(),
404            ));
405        }
406
407        // Get feature dimensions from first sample
408        let first_features = &pairs[0].original.0;
409        let feature_count = first_features.numel();
410
411        let mut total_change = T::zero();
412        let mut max_change = T::zero();
413        let mut min_change = T::from(f64::INFINITY).unwrap_or(T::zero());
414        let mut change_count = 0;
415
416        // Calculate changes across all samples
417        for pair in pairs {
418            if let (Some(orig_data), Some(trans_data)) =
419                (pair.original.0.as_slice(), pair.transformed.0.as_slice())
420            {
421                for (orig, trans) in orig_data.iter().zip(trans_data.iter()) {
422                    let change = (*trans - *orig).abs();
423                    total_change = total_change + change;
424
425                    if change > max_change {
426                        max_change = change;
427                    }
428                    if change < min_change {
429                        min_change = change;
430                    }
431                    change_count += 1;
432                }
433            }
434        }
435
436        let avg_change = if change_count > 0 {
437            total_change
438                / T::from(change_count)
439                    .unwrap_or(T::from(1.0).expect("constant 1.0 should convert to float"))
440        } else {
441            T::zero()
442        };
443
444        Ok(FeatureChangeAnalysis {
445            feature_count,
446            average_change: avg_change,
447            max_change,
448            min_change,
449            samples_with_changes: pairs.len(),
450        })
451    }
452
453    // Helper method to analyze distribution changes
454    pub fn analyze_distribution_changes<T>(
455        pairs: &[BeforeAfterPair<T>],
456    ) -> Result<DistributionChangeAnalysis<T>>
457    where
458        T: Clone
459            + Default
460            + scirs2_core::numeric::Zero
461            + Send
462            + Sync
463            + 'static
464            + scirs2_core::numeric::Float,
465    {
466        // Calculate mean and std before and after transformation
467        let mut original_sum = T::zero();
468        let mut transformed_sum = T::zero();
469        let mut original_squared_sum = T::zero();
470        let mut transformed_squared_sum = T::zero();
471        let mut total_elements = 0;
472
473        for pair in pairs {
474            if let (Some(orig_data), Some(trans_data)) =
475                (pair.original.0.as_slice(), pair.transformed.0.as_slice())
476            {
477                for (&orig, &trans) in orig_data.iter().zip(trans_data.iter()) {
478                    original_sum = original_sum + orig;
479                    transformed_sum = transformed_sum + trans;
480                    original_squared_sum = original_squared_sum + orig * orig;
481                    transformed_squared_sum = transformed_squared_sum + trans * trans;
482                    total_elements += 1;
483                }
484            }
485        }
486
487        if total_elements == 0 {
488            return Err(TensorError::invalid_argument(
489                "No valid data found".to_string(),
490            ));
491        }
492
493        let n = T::from(total_elements)
494            .unwrap_or(T::from(1.0).expect("constant 1.0 should convert to float"));
495
496        let original_mean = original_sum / n;
497        let transformed_mean = transformed_sum / n;
498
499        let original_variance = (original_squared_sum / n) - (original_mean * original_mean);
500        let transformed_variance =
501            (transformed_squared_sum / n) - (transformed_mean * transformed_mean);
502
503        let original_std = original_variance.sqrt();
504        let transformed_std = transformed_variance.sqrt();
505
506        Ok(DistributionChangeAnalysis {
507            original_mean,
508            transformed_mean,
509            original_std,
510            transformed_std,
511            mean_change: (transformed_mean - original_mean).abs(),
512            std_change: (transformed_std - original_std).abs(),
513        })
514    }
515
516    // Helper method to calculate basic tensor statistics
517    pub fn calculate_tensor_stats<T>(tensor: &tenflowers_core::Tensor<T>) -> Result<TensorStats<T>>
518    where
519        T: Clone
520            + Default
521            + scirs2_core::numeric::Zero
522            + Send
523            + Sync
524            + 'static
525            + scirs2_core::numeric::Float,
526    {
527        if let Some(data) = tensor.as_slice() {
528            if data.is_empty() {
529                return Ok(TensorStats {
530                    mean: T::zero(),
531                    std: T::zero(),
532                    min: T::zero(),
533                    max: T::zero(),
534                    element_count: 0,
535                });
536            }
537
538            let mut sum = T::zero();
539            let mut squared_sum = T::zero();
540            let mut min_val = data[0];
541            let mut max_val = data[0];
542
543            for &value in data {
544                sum = sum + value;
545                squared_sum = squared_sum + value * value;
546                if value < min_val {
547                    min_val = value;
548                }
549                if value > max_val {
550                    max_val = value;
551                }
552            }
553
554            let n = T::from(data.len())
555                .unwrap_or(T::from(1.0).expect("constant 1.0 should convert to float"));
556            let mean = sum / n;
557            let variance = (squared_sum / n) - (mean * mean);
558            let std = variance.sqrt();
559
560            Ok(TensorStats {
561                mean,
562                std,
563                min: min_val,
564                max: max_val,
565                element_count: data.len(),
566            })
567        } else {
568            Err(TensorError::device_error_simple(
569                "Cannot access tensor data".to_string(),
570            ))
571        }
572    }
573
574    // Helper method to calculate change magnitude between tensors
575    pub fn calculate_change_magnitude<T>(
576        original: &tenflowers_core::Tensor<T>,
577        transformed: &tenflowers_core::Tensor<T>,
578    ) -> Result<T>
579    where
580        T: Clone
581            + Default
582            + scirs2_core::numeric::Zero
583            + Send
584            + Sync
585            + 'static
586            + scirs2_core::numeric::Float,
587    {
588        if let (Some(orig_data), Some(trans_data)) = (original.as_slice(), transformed.as_slice()) {
589            if orig_data.len() != trans_data.len() {
590                return Err(TensorError::invalid_argument(
591                    "Tensor size mismatch".to_string(),
592                ));
593            }
594
595            let mut total_change = T::zero();
596            for (orig, trans) in orig_data.iter().zip(trans_data.iter()) {
597                let diff = *trans - *orig;
598                total_change = total_change + diff * diff;
599            }
600
601            let n = T::from(orig_data.len())
602                .unwrap_or(T::from(1.0).expect("constant 1.0 should convert to float"));
603            Ok((total_change / n).sqrt()) // RMS change
604        } else {
605            Err(TensorError::device_error_simple(
606                "Cannot access tensor data".to_string(),
607            ))
608        }
609    }
610}