tenflowers_dataset/visualization/
visualizer.rs1use crate::{transforms::Transform, Dataset};
7use tenflowers_core::{Result, TensorError};
8
9use super::types::*;
10
11pub struct DatasetVisualizer;
13
14impl DatasetVisualizer {
15 pub fn sample_preview<T, D>(dataset: &D, num_samples: usize) -> Result<SamplePreview>
17 where
18 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
19 D: Dataset<T>,
20 {
21 if dataset.is_empty() {
22 return Err(TensorError::invalid_argument(
23 "Dataset is empty".to_string(),
24 ));
25 }
26
27 let total_samples = dataset.len();
28 let samples_to_show = num_samples.min(total_samples);
29
30 let mut samples = Vec::new();
32 let step = if samples_to_show == 1 {
33 0
34 } else {
35 total_samples / samples_to_show
36 };
37
38 for i in 0..samples_to_show {
39 let index = if step == 0 { 0 } else { i * step };
40 let index = index.min(total_samples - 1);
41
42 if let Ok((features, labels)) = dataset.get(index) {
43 samples.push(SampleInfo {
44 index,
45 feature_shape: features.shape().dims().to_vec(),
46 label_shape: labels.shape().dims().to_vec(),
47 });
48 }
49 }
50
51 Ok(SamplePreview {
52 total_samples,
53 samples_shown: samples.len(),
54 samples,
55 })
56 }
57
58 pub fn feature_distribution<T, D>(
60 dataset: &D,
61 max_samples: Option<usize>,
62 ) -> Result<DistributionInfo<T>>
63 where
64 T: Clone
65 + Default
66 + scirs2_core::numeric::Zero
67 + Send
68 + Sync
69 + 'static
70 + scirs2_core::numeric::Float,
71 D: Dataset<T>,
72 {
73 if dataset.is_empty() {
74 return Err(TensorError::invalid_argument(
75 "Dataset is empty".to_string(),
76 ));
77 }
78
79 let samples_to_analyze = max_samples.unwrap_or(dataset.len()).min(dataset.len());
80 let mut feature_stats = Vec::new();
81 let mut label_stats = Vec::new();
82
83 let (first_features, first_labels) = dataset.get(0)?;
85 let feature_dims = first_features.numel();
86 let label_dims = first_labels.numel();
87
88 let mut feature_sums = vec![T::zero(); feature_dims];
90 let mut feature_squared_sums = vec![T::zero(); feature_dims];
91 let mut label_sums = vec![T::zero(); label_dims];
92 let mut label_squared_sums = vec![T::zero(); label_dims];
93
94 let mut valid_samples = 0;
95
96 for i in 0..samples_to_analyze {
98 if let Ok((features, labels)) = dataset.get(i) {
99 if let Some(feature_data) = features.as_slice() {
101 for (j, &value) in feature_data.iter().enumerate() {
102 if j < feature_dims {
103 feature_sums[j] = feature_sums[j] + value;
104 feature_squared_sums[j] = feature_squared_sums[j] + value * value;
105 }
106 }
107 }
108
109 if let Some(label_data) = labels.as_slice() {
111 for (j, &value) in label_data.iter().enumerate() {
112 if j < label_dims {
113 label_sums[j] = label_sums[j] + value;
114 label_squared_sums[j] = label_squared_sums[j] + value * value;
115 }
116 }
117 }
118
119 valid_samples += 1;
120 }
121 }
122
123 if valid_samples == 0 {
124 return Err(TensorError::invalid_argument(
125 "No valid samples found".to_string(),
126 ));
127 }
128
129 let n = T::from(valid_samples).expect("sample count should convert to float");
130
131 for i in 0..feature_dims {
133 let mean = feature_sums[i] / n;
134 let variance = (feature_squared_sums[i] / n) - (mean * mean);
135 let std_dev = variance.sqrt();
136
137 feature_stats.push(FeatureStats {
138 dimension: i,
139 mean,
140 std_dev,
141 min: T::zero(), max: T::zero(),
143 });
144 }
145
146 for i in 0..label_dims {
148 let mean = label_sums[i] / n;
149 let variance = (label_squared_sums[i] / n) - (mean * mean);
150 let std_dev = variance.sqrt();
151
152 label_stats.push(FeatureStats {
153 dimension: i,
154 mean,
155 std_dev,
156 min: T::zero(),
157 max: T::zero(),
158 });
159 }
160
161 Ok(DistributionInfo {
162 samples_analyzed: valid_samples,
163 feature_stats,
164 label_stats,
165 })
166 }
167
168 pub fn class_distribution<T, D>(dataset: &D) -> Result<ClassDistribution>
170 where
171 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
172 D: Dataset<T>,
173 {
174 let mut class_counts = std::collections::HashMap::new();
175 let mut total_samples = 0;
176
177 for i in 0..dataset.len() {
178 if let Ok((_, labels)) = dataset.get(i) {
179 let class_key = format!("{:?}", labels.shape());
181 *class_counts.entry(class_key).or_insert(0) += 1;
182 total_samples += 1;
183 }
184 }
185
186 Ok(ClassDistribution {
187 total_samples,
188 class_counts,
189 })
190 }
191
192 pub fn feature_histogram<T, D>(
194 dataset: &D,
195 feature_index: usize,
196 bins: usize,
197 ) -> Result<FeatureHistogram<T>>
198 where
199 T: Clone
200 + Default
201 + scirs2_core::numeric::Zero
202 + Send
203 + Sync
204 + 'static
205 + scirs2_core::numeric::Float
206 + PartialOrd,
207 D: Dataset<T>,
208 {
209 let mut values = Vec::new();
210
211 for i in 0..dataset.len() {
213 if let Ok((features, _)) = dataset.get(i) {
214 if let Some(feature_data) = features.as_slice() {
215 if feature_index < feature_data.len() {
216 values.push(feature_data[feature_index]);
217 }
218 }
219 }
220 }
221
222 if values.is_empty() {
223 return Err(TensorError::invalid_argument(
224 "No valid feature values found".to_string(),
225 ));
226 }
227
228 let mut min_val = values[0];
230 let mut max_val = values[0];
231
232 for &val in &values {
233 if val < min_val {
234 min_val = val;
235 }
236 if val > max_val {
237 max_val = val;
238 }
239 }
240
241 let range = max_val - min_val;
243 let bin_width = if range > T::zero() {
244 range / T::from(bins).expect("bin count should convert to float")
245 } else {
246 T::from(1.0).expect("constant 1.0 should convert to float")
247 };
248
249 let mut bin_counts = vec![0; bins];
250
251 for val in values {
253 if range > T::zero() {
254 let bin_index = ((val - min_val) / bin_width).to_usize().unwrap_or(0);
255 let bin_index = bin_index.min(bins - 1);
256 bin_counts[bin_index] += 1;
257 } else {
258 bin_counts[0] += 1;
259 }
260 }
261
262 Ok(FeatureHistogram {
263 feature_index,
264 min_value: min_val,
265 max_value: max_val,
266 bin_width,
267 bin_counts,
268 })
269 }
270
271 pub fn analyze_augmentation_effects<T, D, Tr>(
273 dataset: &D,
274 transform: &Tr,
275 num_samples: usize,
276 ) -> Result<AugmentationEffects<T>>
277 where
278 T: Clone
279 + Default
280 + scirs2_core::numeric::Zero
281 + Send
282 + Sync
283 + 'static
284 + scirs2_core::numeric::Float
285 + PartialOrd,
286 D: Dataset<T>,
287 Tr: Transform<T>,
288 {
289 if dataset.is_empty() {
290 return Err(TensorError::invalid_argument(
291 "Dataset is empty".to_string(),
292 ));
293 }
294
295 let samples_to_analyze = num_samples.min(dataset.len());
296 let mut before_after_pairs = Vec::new();
297 let mut transform_success_count = 0;
298
299 for i in 0..samples_to_analyze {
301 if let Ok(original_sample) = dataset.get(i) {
302 match transform.apply(original_sample.clone()) {
303 Ok(transformed_sample) => {
304 before_after_pairs.push(BeforeAfterPair {
305 index: i,
306 original: original_sample,
307 transformed: transformed_sample,
308 });
309 transform_success_count += 1;
310 }
311 Err(_) => {
312 continue;
314 }
315 }
316 }
317 }
318
319 if before_after_pairs.is_empty() {
320 return Err(TensorError::invalid_argument(
321 "No successful transforms".to_string(),
322 ));
323 }
324
325 let feature_changes = Self::analyze_feature_changes(&before_after_pairs)?;
327
328 let distribution_changes = Self::analyze_distribution_changes(&before_after_pairs)?;
330
331 Ok(AugmentationEffects {
332 samples_analyzed: before_after_pairs.len(),
333 transform_success_rate: transform_success_count as f64 / samples_to_analyze as f64,
334 feature_changes,
335 distribution_changes,
336 sample_pairs: before_after_pairs,
337 })
338 }
339
340 pub fn compare_samples<T, Tr>(
342 samples: &[(tenflowers_core::Tensor<T>, tenflowers_core::Tensor<T>)],
343 transform: &Tr,
344 comparison_count: usize,
345 ) -> Result<Vec<SampleComparison<T>>>
346 where
347 T: Clone
348 + Default
349 + scirs2_core::numeric::Zero
350 + Send
351 + Sync
352 + 'static
353 + scirs2_core::numeric::Float,
354 Tr: Transform<T>,
355 {
356 let mut comparisons = Vec::new();
357 let samples_to_compare = comparison_count.min(samples.len());
358
359 for (i, original) in samples.iter().enumerate().take(samples_to_compare) {
360 let original = original.clone();
361
362 match transform.apply(original.clone()) {
363 Ok(transformed) => {
364 let original_stats = Self::calculate_tensor_stats(&original.0)?;
366 let transformed_stats = Self::calculate_tensor_stats(&transformed.0)?;
367
368 comparisons.push(SampleComparison {
369 sample_index: i,
370 original_stats,
371 transformed_stats,
372 change_magnitude: Self::calculate_change_magnitude(
373 &original.0,
374 &transformed.0,
375 )?,
376 });
377 }
378 Err(_) => {
379 continue;
381 }
382 }
383 }
384
385 Ok(comparisons)
386 }
387
388 pub fn analyze_feature_changes<T>(
390 pairs: &[BeforeAfterPair<T>],
391 ) -> Result<FeatureChangeAnalysis<T>>
392 where
393 T: Clone
394 + Default
395 + scirs2_core::numeric::Zero
396 + Send
397 + Sync
398 + 'static
399 + scirs2_core::numeric::Float,
400 {
401 if pairs.is_empty() {
402 return Err(TensorError::invalid_argument(
403 "No sample pairs provided".to_string(),
404 ));
405 }
406
407 let first_features = &pairs[0].original.0;
409 let feature_count = first_features.numel();
410
411 let mut total_change = T::zero();
412 let mut max_change = T::zero();
413 let mut min_change = T::from(f64::INFINITY).unwrap_or(T::zero());
414 let mut change_count = 0;
415
416 for pair in pairs {
418 if let (Some(orig_data), Some(trans_data)) =
419 (pair.original.0.as_slice(), pair.transformed.0.as_slice())
420 {
421 for (orig, trans) in orig_data.iter().zip(trans_data.iter()) {
422 let change = (*trans - *orig).abs();
423 total_change = total_change + change;
424
425 if change > max_change {
426 max_change = change;
427 }
428 if change < min_change {
429 min_change = change;
430 }
431 change_count += 1;
432 }
433 }
434 }
435
436 let avg_change = if change_count > 0 {
437 total_change
438 / T::from(change_count)
439 .unwrap_or(T::from(1.0).expect("constant 1.0 should convert to float"))
440 } else {
441 T::zero()
442 };
443
444 Ok(FeatureChangeAnalysis {
445 feature_count,
446 average_change: avg_change,
447 max_change,
448 min_change,
449 samples_with_changes: pairs.len(),
450 })
451 }
452
453 pub fn analyze_distribution_changes<T>(
455 pairs: &[BeforeAfterPair<T>],
456 ) -> Result<DistributionChangeAnalysis<T>>
457 where
458 T: Clone
459 + Default
460 + scirs2_core::numeric::Zero
461 + Send
462 + Sync
463 + 'static
464 + scirs2_core::numeric::Float,
465 {
466 let mut original_sum = T::zero();
468 let mut transformed_sum = T::zero();
469 let mut original_squared_sum = T::zero();
470 let mut transformed_squared_sum = T::zero();
471 let mut total_elements = 0;
472
473 for pair in pairs {
474 if let (Some(orig_data), Some(trans_data)) =
475 (pair.original.0.as_slice(), pair.transformed.0.as_slice())
476 {
477 for (&orig, &trans) in orig_data.iter().zip(trans_data.iter()) {
478 original_sum = original_sum + orig;
479 transformed_sum = transformed_sum + trans;
480 original_squared_sum = original_squared_sum + orig * orig;
481 transformed_squared_sum = transformed_squared_sum + trans * trans;
482 total_elements += 1;
483 }
484 }
485 }
486
487 if total_elements == 0 {
488 return Err(TensorError::invalid_argument(
489 "No valid data found".to_string(),
490 ));
491 }
492
493 let n = T::from(total_elements)
494 .unwrap_or(T::from(1.0).expect("constant 1.0 should convert to float"));
495
496 let original_mean = original_sum / n;
497 let transformed_mean = transformed_sum / n;
498
499 let original_variance = (original_squared_sum / n) - (original_mean * original_mean);
500 let transformed_variance =
501 (transformed_squared_sum / n) - (transformed_mean * transformed_mean);
502
503 let original_std = original_variance.sqrt();
504 let transformed_std = transformed_variance.sqrt();
505
506 Ok(DistributionChangeAnalysis {
507 original_mean,
508 transformed_mean,
509 original_std,
510 transformed_std,
511 mean_change: (transformed_mean - original_mean).abs(),
512 std_change: (transformed_std - original_std).abs(),
513 })
514 }
515
516 pub fn calculate_tensor_stats<T>(tensor: &tenflowers_core::Tensor<T>) -> Result<TensorStats<T>>
518 where
519 T: Clone
520 + Default
521 + scirs2_core::numeric::Zero
522 + Send
523 + Sync
524 + 'static
525 + scirs2_core::numeric::Float,
526 {
527 if let Some(data) = tensor.as_slice() {
528 if data.is_empty() {
529 return Ok(TensorStats {
530 mean: T::zero(),
531 std: T::zero(),
532 min: T::zero(),
533 max: T::zero(),
534 element_count: 0,
535 });
536 }
537
538 let mut sum = T::zero();
539 let mut squared_sum = T::zero();
540 let mut min_val = data[0];
541 let mut max_val = data[0];
542
543 for &value in data {
544 sum = sum + value;
545 squared_sum = squared_sum + value * value;
546 if value < min_val {
547 min_val = value;
548 }
549 if value > max_val {
550 max_val = value;
551 }
552 }
553
554 let n = T::from(data.len())
555 .unwrap_or(T::from(1.0).expect("constant 1.0 should convert to float"));
556 let mean = sum / n;
557 let variance = (squared_sum / n) - (mean * mean);
558 let std = variance.sqrt();
559
560 Ok(TensorStats {
561 mean,
562 std,
563 min: min_val,
564 max: max_val,
565 element_count: data.len(),
566 })
567 } else {
568 Err(TensorError::device_error_simple(
569 "Cannot access tensor data".to_string(),
570 ))
571 }
572 }
573
574 pub fn calculate_change_magnitude<T>(
576 original: &tenflowers_core::Tensor<T>,
577 transformed: &tenflowers_core::Tensor<T>,
578 ) -> Result<T>
579 where
580 T: Clone
581 + Default
582 + scirs2_core::numeric::Zero
583 + Send
584 + Sync
585 + 'static
586 + scirs2_core::numeric::Float,
587 {
588 if let (Some(orig_data), Some(trans_data)) = (original.as_slice(), transformed.as_slice()) {
589 if orig_data.len() != trans_data.len() {
590 return Err(TensorError::invalid_argument(
591 "Tensor size mismatch".to_string(),
592 ));
593 }
594
595 let mut total_change = T::zero();
596 for (orig, trans) in orig_data.iter().zip(trans_data.iter()) {
597 let diff = *trans - *orig;
598 total_change = total_change + diff * diff;
599 }
600
601 let n = T::from(orig_data.len())
602 .unwrap_or(T::from(1.0).expect("constant 1.0 should convert to float"));
603 Ok((total_change / n).sqrt()) } else {
605 Err(TensorError::device_error_simple(
606 "Cannot access tensor data".to_string(),
607 ))
608 }
609 }
610}