1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
7use std::error::Error;
8
9use crate::visualization::interactive::InteractiveOptions;
10use crate::visualization::{ColorMap, PlotType, VisualizationData, VisualizationMetadata};
11
12#[allow(dead_code)]
24pub fn visualize_confusion_matrix<A>(
25 confusion_matrix: ArrayView2<A>,
26 class_names: Option<Vec<String>>,
27 normalize: bool,
28) -> Box<dyn crate::visualization::MetricVisualizer>
29where
30 A: Clone + Into<f64>,
31{
32 let cm_f64 = Array2::from_shape_fn(confusion_matrix.dim(), |(i, j)| {
34 confusion_matrix[[i, j]].clone().into()
35 });
36
37 crate::visualization::confusion_matrix::confusion_matrix_visualization(
38 cm_f64,
39 class_names,
40 normalize,
41 )
42}
43
44#[allow(dead_code)]
57pub fn visualize_roc_curve<A>(
58 fpr: ArrayView1<A>,
59 tpr: ArrayView1<A>,
60 thresholds: Option<ArrayView1<A>>,
61 auc: Option<f64>,
62) -> Box<dyn crate::visualization::MetricVisualizer>
63where
64 A: Clone + Into<f64>,
65{
66 let fpr_vec = fpr.iter().map(|x| x.clone().into()).collect::<Vec<f64>>();
68 let tpr_vec = tpr.iter().map(|x| x.clone().into()).collect::<Vec<f64>>();
69 let thresholds_vec =
70 thresholds.map(|t| t.iter().map(|x| x.clone().into()).collect::<Vec<f64>>());
71
72 Box::new(crate::visualization::roc_curve::roc_curve_visualization(
73 fpr_vec,
74 tpr_vec,
75 thresholds_vec,
76 auc,
77 ))
78}
79
80#[allow(dead_code)]
94pub fn visualize_interactive_roc_curve<A>(
95 fpr: ArrayView1<A>,
96 tpr: ArrayView1<A>,
97 thresholds: Option<ArrayView1<A>>,
98 auc: Option<f64>,
99 interactive_options: Option<InteractiveOptions>,
100) -> Box<dyn crate::visualization::MetricVisualizer>
101where
102 A: Clone + Into<f64>,
103{
104 let fpr_vec = fpr.iter().map(|x| x.clone().into()).collect::<Vec<f64>>();
106 let tpr_vec = tpr.iter().map(|x| x.clone().into()).collect::<Vec<f64>>();
107 let thresholds_vec =
108 thresholds.map(|t| t.iter().map(|x| x.clone().into()).collect::<Vec<f64>>());
109
110 let mut visualizer = crate::visualization::interactive::interactive_roc_curve_visualization(
111 fpr_vec,
112 tpr_vec,
113 thresholds_vec,
114 auc,
115 );
116
117 if let Some(_options) = interactive_options {
118 visualizer = visualizer.with_interactive_options(_options);
119 }
120
121 Box::new(visualizer)
122}
123
124#[allow(dead_code)]
137pub fn visualize_interactive_roc_from_labels<A, B>(
138 y_true: ArrayView1<A>,
139 y_score: ArrayView1<B>,
140 _pos_label: Option<A>,
141 interactive_options: Option<InteractiveOptions>,
142) -> Result<Box<dyn crate::visualization::MetricVisualizer>, Box<dyn Error>>
143where
144 A: Clone + PartialOrd + 'static,
145 B: Clone + PartialOrd + 'static,
146 f64: From<A> + From<B>,
147{
148 let (fpr, tpr, _thresholds) = crate::classification::curves::roc_curve(&y_true, &y_score)
150 .map_err(|e| Box::new(e) as Box<dyn Error>)?;
151
152 let auc = {
154 let n = fpr.len();
155 let mut area = 0.0;
156 for i in 1..n {
157 area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2.0;
158 }
159 area
160 };
161
162 let mut visualizer = crate::visualization::interactive::roc_curve::InteractiveROCVisualizer::<
164 f64,
165 scirs2_core::ndarray::OwnedRepr<f64>,
166 >::new(fpr.to_vec(), tpr.to_vec(), None, Some(auc));
167
168 if let Some(_options) = interactive_options {
169 visualizer = visualizer.with_interactive_options(_options);
170 }
171
172 Ok(Box::new(visualizer))
173}
174
175#[allow(dead_code)]
188pub fn visualize_precision_recall_curve<A>(
189 precision: ArrayView1<A>,
190 recall: ArrayView1<A>,
191 thresholds: Option<ArrayView1<A>>,
192 average_precision: Option<f64>,
193) -> Box<dyn crate::visualization::MetricVisualizer>
194where
195 A: Clone + Into<f64>,
196{
197 let precision_vec = precision
199 .iter()
200 .map(|x| x.clone().into())
201 .collect::<Vec<f64>>();
202 let recall_vec = recall
203 .iter()
204 .map(|x| x.clone().into())
205 .collect::<Vec<f64>>();
206 let thresholds_vec =
207 thresholds.map(|t| t.iter().map(|x| x.clone().into()).collect::<Vec<f64>>());
208
209 Box::new(
210 crate::visualization::precision_recall::precision_recall_visualization(
211 precision_vec,
212 recall_vec,
213 thresholds_vec,
214 average_precision,
215 ),
216 )
217}
218
219#[allow(dead_code)]
232pub fn visualize_calibration_curve<A>(
233 prob_true: ArrayView1<A>,
234 prob_pred: ArrayView1<A>,
235 n_bins: usize,
236 strategy: impl Into<String>,
237) -> Box<dyn crate::visualization::MetricVisualizer>
238where
239 A: Clone + Into<f64>,
240{
241 let prob_true_vec = prob_true
243 .iter()
244 .map(|x| x.clone().into())
245 .collect::<Vec<f64>>();
246 let prob_pred_vec = prob_pred
247 .iter()
248 .map(|x| x.clone().into())
249 .collect::<Vec<f64>>();
250
251 Box::new(
252 crate::visualization::calibration::calibration_visualization(
253 prob_true_vec,
254 prob_pred_vec,
255 n_bins,
256 strategy.into(),
257 ),
258 )
259}
260
261#[allow(dead_code)]
275pub fn visualize_learning_curve(
276 train_sizes: Vec<usize>,
277 train_scores: Vec<Vec<f64>>,
278 val_scores: Vec<Vec<f64>>,
279 score_name: impl Into<String>,
280) -> Result<Box<dyn crate::visualization::MetricVisualizer>, Box<dyn Error>> {
281 let visualizer = crate::visualization::learning_curve::learning_curve_visualization(
282 train_sizes,
283 train_scores,
284 val_scores,
285 score_name,
286 )?;
287
288 Ok(Box::new(visualizer))
289}
290
291#[allow(dead_code)]
309pub fn visualize_metric<A, B>(
310 x_values: ArrayView1<A>,
311 y_values: ArrayView1<B>,
312 title: impl Into<String>,
313 x_label: impl Into<String>,
314 y_label: impl Into<String>,
315 plot_type: PlotType,
316) -> Box<dyn crate::visualization::MetricVisualizer>
317where
318 A: Clone + Into<f64>,
319 B: Clone + Into<f64>,
320{
321 let x_vec = x_values
322 .iter()
323 .map(|x| x.clone().into())
324 .collect::<Vec<f64>>();
325 let y_vec = y_values
326 .iter()
327 .map(|y| y.clone().into())
328 .collect::<Vec<f64>>();
329
330 Box::new(GenericMetricVisualizer::new(
331 x_vec,
332 y_vec,
333 title.into(),
334 x_label.into(),
335 y_label.into(),
336 plot_type,
337 ))
338}
339
/// Visualizer for a single x/y data series with a configurable plot type.
pub struct GenericMetricVisualizer {
    /// X-axis values.
    pub x: Vec<f64>,
    /// Y-axis values.
    pub y: Vec<f64>,
    /// Plot title.
    pub title: String,
    /// X-axis label.
    pub x_label: String,
    /// Y-axis label.
    pub y_label: String,
    /// Plot type reported in the metadata.
    pub plot_type: PlotType,
    /// Optional series names; `None` until set via `with_series_names`.
    pub series_names: Option<Vec<String>>,
}
357
358impl GenericMetricVisualizer {
359 pub fn new(
361 x: Vec<f64>,
362 y: Vec<f64>,
363 title: impl Into<String>,
364 x_label: impl Into<String>,
365 y_label: impl Into<String>,
366 plot_type: PlotType,
367 ) -> Self {
368 Self {
369 x,
370 y,
371 title: title.into(),
372 x_label: x_label.into(),
373 y_label: y_label.into(),
374 plot_type,
375 series_names: None,
376 }
377 }
378
379 pub fn with_series_names(mut self, seriesnames: Vec<String>) -> Self {
381 self.series_names = Some(seriesnames);
382 self
383 }
384}
385
386impl crate::visualization::MetricVisualizer for GenericMetricVisualizer {
387 fn prepare_data(&self) -> Result<VisualizationData, Box<dyn Error>> {
388 let mut data = VisualizationData::new();
389
390 data.x = self.x.clone();
392 data.y = self.y.clone();
393
394 if let Some(series_names) = &self.series_names {
396 data.series_names = Some(series_names.clone());
397 }
398
399 Ok(data)
400 }
401
402 fn get_metadata(&self) -> VisualizationMetadata {
403 let mut metadata = VisualizationMetadata::new(self.title.clone());
404 metadata.set_plot_type(self.plot_type.clone());
405 metadata.set_x_label(self.x_label.clone());
406 metadata.set_y_label(self.y_label.clone());
407 metadata
408 }
409}
410
411#[allow(dead_code)]
429pub fn visualize_multi_curve<A, B>(
430 x_values: ArrayView1<A>,
431 y_values_list: Vec<ArrayView1<B>>,
432 series_names: Vec<String>,
433 title: impl Into<String>,
434 x_label: impl Into<String>,
435 y_label: impl Into<String>,
436) -> Box<dyn crate::visualization::MetricVisualizer>
437where
438 A: Clone + Into<f64>,
439 B: Clone + Into<f64>,
440{
441 let x_vec = x_values
442 .iter()
443 .map(|x| x.clone().into())
444 .collect::<Vec<f64>>();
445
446 let y_vec = if !y_values_list.is_empty() {
448 y_values_list[0]
449 .iter()
450 .map(|y| y.clone().into())
451 .collect::<Vec<f64>>()
452 } else {
453 Vec::new()
454 };
455
456 let mut visualizer =
458 MultiCurveVisualizer::new(x_vec, y_vec, title.into(), x_label.into(), y_label.into());
459
460 for (i, y_values) in y_values_list.iter().enumerate() {
462 if i == 0 {
463 continue;
465 }
466
467 let name = if i < series_names.len() {
468 series_names[i].clone()
469 } else {
470 format!("Series {}", i + 1)
471 };
472
473 let y_vec = y_values
474 .iter()
475 .map(|y| y.clone().into())
476 .collect::<Vec<f64>>();
477 visualizer.add_series(name, y_vec);
478 }
479
480 visualizer.set_series_names(series_names);
482
483 Box::new(visualizer)
484}
485
/// Visualizer for several y-series sharing one x-axis, always reported as a line plot.
pub struct MultiCurveVisualizer {
    /// Shared x-axis values.
    pub x: Vec<f64>,
    /// Primary y-series.
    pub y: Vec<f64>,
    /// Additional `(name, values)` series, in insertion order.
    pub secondary_y: Vec<(String, Vec<f64>)>,
    /// Plot title.
    pub title: String,
    /// X-axis label.
    pub x_label: String,
    /// Y-axis label.
    pub y_label: String,
    /// Series names; exported with the data only when non-empty.
    pub series_names: Vec<String>,
}
503
504impl MultiCurveVisualizer {
505 pub fn new(
507 x: Vec<f64>,
508 y: Vec<f64>,
509 title: impl Into<String>,
510 x_label: impl Into<String>,
511 y_label: impl Into<String>,
512 ) -> Self {
513 Self {
514 x,
515 y,
516 secondary_y: Vec::new(),
517 title: title.into(),
518 x_label: x_label.into(),
519 y_label: y_label.into(),
520 series_names: Vec::new(),
521 }
522 }
523
524 pub fn add_series(&mut self, name: impl Into<String>, y: Vec<f64>) {
526 self.secondary_y.push((name.into(), y));
527 }
528
529 pub fn set_series_names(&mut self, names: Vec<String>) {
531 self.series_names = names;
532 }
533}
534
535impl crate::visualization::MetricVisualizer for MultiCurveVisualizer {
536 fn prepare_data(&self) -> Result<VisualizationData, Box<dyn Error>> {
537 let mut data = VisualizationData::new();
538
539 data.x = self.x.clone();
541 data.y = self.y.clone();
542
543 for (name, y) in &self.secondary_y {
545 data.series.insert(name.clone(), y.clone());
546 }
547
548 if !self.series_names.is_empty() {
550 data.series_names = Some(self.series_names.clone());
551 }
552
553 Ok(data)
554 }
555
556 fn get_metadata(&self) -> VisualizationMetadata {
557 let mut metadata = VisualizationMetadata::new(self.title.clone());
558 metadata.set_plot_type(PlotType::Line);
559 metadata.set_x_label(self.x_label.clone());
560 metadata.set_y_label(self.y_label.clone());
561 metadata
562 }
563}
564
565#[allow(dead_code)]
582pub fn visualize_heatmap<A>(
583 matrix: ArrayView2<A>,
584 x_labels: Option<Vec<String>>,
585 y_labels: Option<Vec<String>>,
586 title: impl Into<String>,
587 color_map: Option<ColorMap>,
588) -> Box<dyn crate::visualization::MetricVisualizer>
589where
590 A: Clone + Into<f64>,
591{
592 let z = Array2::from_shape_fn(matrix.dim(), |(i, j)| matrix[[i, j]].clone().into());
594
595 let z_vec = (0..z.shape()[0])
596 .map(|i| (0..z.shape()[1]).map(|j| z[[i, j]]).collect::<Vec<f64>>())
597 .collect::<Vec<Vec<f64>>>();
598
599 let x = (0..z.shape()[1]).map(|i| i as f64).collect::<Vec<f64>>();
601 let y = (0..z.shape()[0]).map(|i| i as f64).collect::<Vec<f64>>();
602
603 Box::new(HeatmapVisualizer::new(
604 x,
605 y,
606 z_vec,
607 title.into(),
608 x_labels,
609 y_labels,
610 color_map,
611 ))
612}
613
/// Visualizer for 2-D data rendered as a heatmap.
pub struct HeatmapVisualizer {
    /// X-axis coordinates (`visualize_heatmap` fills this with column indices).
    pub x: Vec<f64>,
    /// Y-axis coordinates (`visualize_heatmap` fills this with row indices).
    pub y: Vec<f64>,
    /// Cell values as nested vectors (`visualize_heatmap` stores one inner vector per row).
    pub z: Vec<Vec<f64>>,
    /// Plot title.
    pub title: String,
    /// Optional x-axis labels.
    pub x_labels: Option<Vec<String>>,
    /// Optional y-axis labels.
    pub y_labels: Option<Vec<String>>,
    /// Optional color map.
    pub color_map: Option<ColorMap>,
}
631
632impl HeatmapVisualizer {
633 pub fn new(
635 x: Vec<f64>,
636 y: Vec<f64>,
637 z: Vec<Vec<f64>>,
638 title: impl Into<String>,
639 x_labels: Option<Vec<String>>,
640 y_labels: Option<Vec<String>>,
641 color_map: Option<ColorMap>,
642 ) -> Self {
643 Self {
644 x,
645 y,
646 z,
647 title: title.into(),
648 x_labels,
649 y_labels,
650 color_map,
651 }
652 }
653}
654
655impl crate::visualization::MetricVisualizer for HeatmapVisualizer {
656 fn prepare_data(&self) -> Result<VisualizationData, Box<dyn Error>> {
657 let mut data = VisualizationData::new();
658
659 data.x = self.x.clone();
661 data.y = self.y.clone();
662 data.z = Some(self.z.clone());
663
664 if let Some(x_labels) = &self.x_labels {
666 data.x_labels = Some(x_labels.clone());
667 }
668
669 if let Some(y_labels) = &self.y_labels {
670 data.y_labels = Some(y_labels.clone());
671 }
672
673 Ok(data)
674 }
675
676 fn get_metadata(&self) -> VisualizationMetadata {
677 let mut metadata = VisualizationMetadata::new(self.title.clone());
678 metadata.set_plot_type(PlotType::Heatmap);
679
680 if self.x_labels.is_none() {
682 metadata.set_x_label("X");
683 } else {
684 metadata.set_x_label(""); }
686
687 if self.y_labels.is_none() {
688 metadata.set_y_label("Y");
689 } else {
690 metadata.set_y_label(""); }
692
693 metadata
694 }
695}
696
697#[allow(dead_code)]
713pub fn visualize_histogram<A>(
714 values: ArrayView1<A>,
715 bins: usize,
716 title: impl Into<String>,
717 x_label: impl Into<String>,
718 y_label: Option<String>,
719) -> Box<dyn crate::visualization::MetricVisualizer>
720where
721 A: Clone + Into<f64>,
722{
723 let values_vec = values
725 .iter()
726 .map(|x| x.clone().into())
727 .collect::<Vec<f64>>();
728
729 let (bin_edges, bin_counts) = create_histogram_bins(&values_vec, bins);
731
732 Box::new(HistogramVisualizer::new(
733 bin_edges,
734 bin_counts,
735 title.into(),
736 x_label.into(),
737 y_label.unwrap_or_else(|| "Frequency".to_string()),
738 ))
739}
740
/// Bins `values` into `bins` equal-width intervals spanning the range of the finite samples.
///
/// Non-finite samples (NaN, ±infinity) are ignored both when determining the range and
/// when counting. When every finite sample has the same value, the range is widened by
/// `0.5` on each side so the bins keep a positive width.
///
/// # Arguments
///
/// * `values` - Samples to bin.
/// * `bins` - Number of bins.
///
/// # Returns
///
/// `(bin_edges, bin_counts)` where `bin_edges` has `bins + 1` entries and `bin_counts` has
/// `bins` entries. Both are empty when `bins` is zero or `values` holds no finite sample
/// (including when it is empty). A sample equal to the upper bound is counted in the last
/// bin.
#[allow(dead_code)]
fn create_histogram_bins(values: &[f64], bins: usize) -> (Vec<f64>, Vec<f64>) {
    if bins == 0 {
        return (Vec::new(), Vec::new());
    }

    // Range over finite samples only: an infinity would make the bin width infinite and
    // turn the edges into NaN.
    let (lowest, highest) = values
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
            (lo.min(v), hi.max(v))
        });

    // The fold seeds are still in place when nothing finite was seen (covers empty input).
    if lowest > highest {
        return (Vec::new(), Vec::new());
    }

    // A zero-width range would give a zero bin width (all edges identical and a 0/0 bin
    // index), so widen it symmetrically.
    let (min_val, max_val) = if lowest == highest {
        (lowest - 0.5, highest + 0.5)
    } else {
        (lowest, highest)
    };

    let bin_width = (max_val - min_val) / bins as f64;
    let bin_edges: Vec<f64> = (0..=bins)
        .map(|i| min_val + i as f64 * bin_width)
        .collect();

    let mut bin_counts = vec![0.0; bins];
    for &val in values.iter().filter(|v| v.is_finite()) {
        // Clamp so that a sample equal to the upper bound lands in the last bin.
        let bin_idx = (((val - min_val) / bin_width).floor() as usize).min(bins - 1);
        bin_counts[bin_idx] += 1.0;
    }

    (bin_edges, bin_counts)
}
784
/// Visualizer for pre-binned histogram data.
pub struct HistogramVisualizer {
    /// Bin boundaries; consecutive pairs delimit one bin.
    pub bin_edges: Vec<f64>,
    /// Count stored for each bin.
    pub bin_counts: Vec<f64>,
    /// Plot title.
    pub title: String,
    /// X-axis label.
    pub x_label: String,
    /// Y-axis label.
    pub y_label: String,
}
798
799impl HistogramVisualizer {
800 pub fn new(
802 bin_edges: Vec<f64>,
803 bin_counts: Vec<f64>,
804 title: impl Into<String>,
805 x_label: impl Into<String>,
806 y_label: impl Into<String>,
807 ) -> Self {
808 Self {
809 bin_edges,
810 bin_counts,
811 title: title.into(),
812 x_label: x_label.into(),
813 y_label: y_label.into(),
814 }
815 }
816}
817
818impl crate::visualization::MetricVisualizer for HistogramVisualizer {
819 fn prepare_data(&self) -> Result<VisualizationData, Box<dyn Error>> {
820 let mut data = VisualizationData::new();
821
822 if self.bin_edges.len() > 1 {
824 let bin_centers = self
825 .bin_edges
826 .windows(2)
827 .map(|w| (w[0] + w[1]) / 2.0)
828 .collect::<Vec<f64>>();
829
830 data.x = bin_centers;
831 } else {
832 data.x = Vec::new();
833 }
834
835 data.y = self.bin_counts.clone();
837
838 data.add_auxiliary_data("bin_edges", self.bin_edges.clone());
840
841 Ok(data)
842 }
843
844 fn get_metadata(&self) -> VisualizationMetadata {
845 let mut metadata = VisualizationMetadata::new(self.title.clone());
846 metadata.set_plot_type(PlotType::Histogram);
847 metadata.set_x_label(self.x_label.clone());
848 metadata.set_y_label(self.y_label.clone());
849 metadata
850 }
851}