scirs2_metrics/visualization/
helpers.rs

1//! Helper functions for visualization
2//!
3//! This module provides helper functions for creating visualizations for common
4//! metrics result types.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
7use std::error::Error;
8
9use crate::visualization::interactive::InteractiveOptions;
10use crate::visualization::{ColorMap, PlotType, VisualizationData, VisualizationMetadata};
11
12/// Create a confusion matrix visualization from a confusion matrix array
13///
14/// # Arguments
15///
16/// * `confusion_matrix` - The confusion matrix as a 2D array
17/// * `class_names` - Optional class names
18/// * `normalize` - Whether to normalize the confusion matrix
19///
20/// # Returns
21///
22/// * `Box<dyn crate::visualization::MetricVisualizer>` - A visualizer for the confusion matrix
23#[allow(dead_code)]
24pub fn visualize_confusion_matrix<A>(
25    confusion_matrix: ArrayView2<A>,
26    class_names: Option<Vec<String>>,
27    normalize: bool,
28) -> Box<dyn crate::visualization::MetricVisualizer>
29where
30    A: Clone + Into<f64>,
31{
32    // Convert the confusion _matrix to f64
33    let cm_f64 = Array2::from_shape_fn(confusion_matrix.dim(), |(i, j)| {
34        confusion_matrix[[i, j]].clone().into()
35    });
36
37    crate::visualization::confusion_matrix::confusion_matrix_visualization(
38        cm_f64,
39        class_names,
40        normalize,
41    )
42}
43
44/// Create a ROC curve visualization
45///
46/// # Arguments
47///
48/// * `fpr` - False positive rates
49/// * `tpr` - True positive rates
50/// * `thresholds` - Optional thresholds
51/// * `auc` - Optional area under the curve
52///
53/// # Returns
54///
55/// * `Box<dyn crate::visualization::MetricVisualizer>` - A visualizer for the ROC curve
56#[allow(dead_code)]
57pub fn visualize_roc_curve<A>(
58    fpr: ArrayView1<A>,
59    tpr: ArrayView1<A>,
60    thresholds: Option<ArrayView1<A>>,
61    auc: Option<f64>,
62) -> Box<dyn crate::visualization::MetricVisualizer>
63where
64    A: Clone + Into<f64>,
65{
66    // Convert the arrays to f64 vectors
67    let fpr_vec = fpr.iter().map(|x| x.clone().into()).collect::<Vec<f64>>();
68    let tpr_vec = tpr.iter().map(|x| x.clone().into()).collect::<Vec<f64>>();
69    let thresholds_vec =
70        thresholds.map(|t| t.iter().map(|x| x.clone().into()).collect::<Vec<f64>>());
71
72    Box::new(crate::visualization::roc_curve::roc_curve_visualization(
73        fpr_vec,
74        tpr_vec,
75        thresholds_vec,
76        auc,
77    ))
78}
79
80/// Create an interactive ROC curve visualization
81///
82/// # Arguments
83///
84/// * `fpr` - False positive rates
85/// * `tpr` - True positive rates
86/// * `thresholds` - Optional thresholds
87/// * `auc` - Optional area under the curve
88/// * `interactive_options` - Optional interactive options
89///
90/// # Returns
91///
92/// * `Box<dyn crate::visualization::MetricVisualizer>` - An interactive visualizer for the ROC curve
93#[allow(dead_code)]
94pub fn visualize_interactive_roc_curve<A>(
95    fpr: ArrayView1<A>,
96    tpr: ArrayView1<A>,
97    thresholds: Option<ArrayView1<A>>,
98    auc: Option<f64>,
99    interactive_options: Option<InteractiveOptions>,
100) -> Box<dyn crate::visualization::MetricVisualizer>
101where
102    A: Clone + Into<f64>,
103{
104    // Convert the arrays to f64 vectors
105    let fpr_vec = fpr.iter().map(|x| x.clone().into()).collect::<Vec<f64>>();
106    let tpr_vec = tpr.iter().map(|x| x.clone().into()).collect::<Vec<f64>>();
107    let thresholds_vec =
108        thresholds.map(|t| t.iter().map(|x| x.clone().into()).collect::<Vec<f64>>());
109
110    let mut visualizer = crate::visualization::interactive::interactive_roc_curve_visualization(
111        fpr_vec,
112        tpr_vec,
113        thresholds_vec,
114        auc,
115    );
116
117    if let Some(_options) = interactive_options {
118        visualizer = visualizer.with_interactive_options(_options);
119    }
120
121    Box::new(visualizer)
122}
123
124/// Create an interactive ROC curve visualization from labels and scores
125///
126/// # Arguments
127///
128/// * `y_true` - True binary labels
129/// * `y_score` - Target scores (probabilities or decision function output)
130/// * `pos_label` - Optional positive class label
131/// * `interactive_options` - Optional interactive options
132///
133/// # Returns
134///
135/// * `Result<Box<dyn crate::visualization::MetricVisualizer>, Box<dyn Error>>` - An interactive visualizer for the ROC curve
136#[allow(dead_code)]
137pub fn visualize_interactive_roc_from_labels<A, B>(
138    y_true: ArrayView1<A>,
139    y_score: ArrayView1<B>,
140    _pos_label: Option<A>,
141    interactive_options: Option<InteractiveOptions>,
142) -> Result<Box<dyn crate::visualization::MetricVisualizer>, Box<dyn Error>>
143where
144    A: Clone + PartialOrd + 'static,
145    B: Clone + PartialOrd + 'static,
146    f64: From<A> + From<B>,
147{
148    // Compute ROC curve
149    let (fpr, tpr, _thresholds) = crate::classification::curves::roc_curve(&y_true, &y_score)
150        .map_err(|e| Box::new(e) as Box<dyn Error>)?;
151
152    // Calculate AUC - simplified version
153    let auc = {
154        let n = fpr.len();
155        let mut area = 0.0;
156        for i in 1..n {
157            area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2.0;
158        }
159        area
160    };
161
162    // Create an owned ROC curve visualizer using raw data
163    let mut visualizer = crate::visualization::interactive::roc_curve::InteractiveROCVisualizer::<
164        f64,
165        scirs2_core::ndarray::OwnedRepr<f64>,
166    >::new(fpr.to_vec(), tpr.to_vec(), None, Some(auc));
167
168    if let Some(_options) = interactive_options {
169        visualizer = visualizer.with_interactive_options(_options);
170    }
171
172    Ok(Box::new(visualizer))
173}
174
175/// Create a precision-recall curve visualization
176///
177/// # Arguments
178///
179/// * `precision` - Precision values
180/// * `recall` - Recall values
181/// * `thresholds` - Optional thresholds
182/// * `average_precision` - Optional average precision
183///
184/// # Returns
185///
186/// * `Box<dyn crate::visualization::MetricVisualizer>` - A visualizer for the precision-recall curve
187#[allow(dead_code)]
188pub fn visualize_precision_recall_curve<A>(
189    precision: ArrayView1<A>,
190    recall: ArrayView1<A>,
191    thresholds: Option<ArrayView1<A>>,
192    average_precision: Option<f64>,
193) -> Box<dyn crate::visualization::MetricVisualizer>
194where
195    A: Clone + Into<f64>,
196{
197    // Convert the arrays to f64 vectors
198    let precision_vec = precision
199        .iter()
200        .map(|x| x.clone().into())
201        .collect::<Vec<f64>>();
202    let recall_vec = recall
203        .iter()
204        .map(|x| x.clone().into())
205        .collect::<Vec<f64>>();
206    let thresholds_vec =
207        thresholds.map(|t| t.iter().map(|x| x.clone().into()).collect::<Vec<f64>>());
208
209    Box::new(
210        crate::visualization::precision_recall::precision_recall_visualization(
211            precision_vec,
212            recall_vec,
213            thresholds_vec,
214            average_precision,
215        ),
216    )
217}
218
219/// Create a calibration curve visualization
220///
221/// # Arguments
222///
223/// * `prob_true` - True probabilities
224/// * `prob_pred` - Predicted probabilities
225/// * `n_bins` - Number of bins
226/// * `strategy` - Binning strategy ("uniform" or "quantile")
227///
228/// # Returns
229///
230/// * `Box<dyn crate::visualization::MetricVisualizer>` - A visualizer for the calibration curve
231#[allow(dead_code)]
232pub fn visualize_calibration_curve<A>(
233    prob_true: ArrayView1<A>,
234    prob_pred: ArrayView1<A>,
235    n_bins: usize,
236    strategy: impl Into<String>,
237) -> Box<dyn crate::visualization::MetricVisualizer>
238where
239    A: Clone + Into<f64>,
240{
241    // Convert the arrays to f64 vectors
242    let prob_true_vec = prob_true
243        .iter()
244        .map(|x| x.clone().into())
245        .collect::<Vec<f64>>();
246    let prob_pred_vec = prob_pred
247        .iter()
248        .map(|x| x.clone().into())
249        .collect::<Vec<f64>>();
250
251    Box::new(
252        crate::visualization::calibration::calibration_visualization(
253            prob_true_vec,
254            prob_pred_vec,
255            n_bins,
256            strategy.into(),
257        ),
258    )
259}
260
261/// Create a learning curve visualization
262///
263/// # Arguments
264///
265/// * `train_sizes` - Training set sizes
266/// * `train_scores` - Training scores (multiple runs for each size)
267/// * `val_scores` - Validation scores (multiple runs for each size)
268/// * `score_name` - Name of the score (e.g., "Accuracy")
269///
270/// # Returns
271///
272/// * `Box<dyn crate::visualization::MetricVisualizer>` - A visualizer for the learning curve
273/// * `Result<Box<dyn crate::visualization::MetricVisualizer>, Box<dyn Error>>` - A visualizer for the learning curve, or an error
274#[allow(dead_code)]
275pub fn visualize_learning_curve(
276    train_sizes: Vec<usize>,
277    train_scores: Vec<Vec<f64>>,
278    val_scores: Vec<Vec<f64>>,
279    score_name: impl Into<String>,
280) -> Result<Box<dyn crate::visualization::MetricVisualizer>, Box<dyn Error>> {
281    let visualizer = crate::visualization::learning_curve::learning_curve_visualization(
282        train_sizes,
283        train_scores,
284        val_scores,
285        score_name,
286    )?;
287
288    Ok(Box::new(visualizer))
289}
290
291/// Create a generic metric visualization
292///
293/// This function creates a visualization for generic metric data,
294/// such as performance over time, hyperparameter tuning results, etc.
295///
296/// # Arguments
297///
298/// * `x_values` - X-axis values
299/// * `y_values` - Y-axis values
300/// * `title` - Plot title
301/// * `x_label` - X-axis label
302/// * `y_label` - Y-axis label
303/// * `plot_type` - Plot type
304///
305/// # Returns
306///
307/// * `Box<dyn crate::visualization::MetricVisualizer>` - A visualizer for the generic metric
308#[allow(dead_code)]
309pub fn visualize_metric<A, B>(
310    x_values: ArrayView1<A>,
311    y_values: ArrayView1<B>,
312    title: impl Into<String>,
313    x_label: impl Into<String>,
314    y_label: impl Into<String>,
315    plot_type: PlotType,
316) -> Box<dyn crate::visualization::MetricVisualizer>
317where
318    A: Clone + Into<f64>,
319    B: Clone + Into<f64>,
320{
321    let x_vec = x_values
322        .iter()
323        .map(|x| x.clone().into())
324        .collect::<Vec<f64>>();
325    let y_vec = y_values
326        .iter()
327        .map(|y| y.clone().into())
328        .collect::<Vec<f64>>();
329
330    Box::new(GenericMetricVisualizer::new(
331        x_vec,
332        y_vec,
333        title.into(),
334        x_label.into(),
335        y_label.into(),
336        plot_type,
337    ))
338}
339
340/// A generic visualizer for metric data
341pub struct GenericMetricVisualizer {
342    /// X-axis values
343    pub x: Vec<f64>,
344    /// Y-axis values
345    pub y: Vec<f64>,
346    /// Title
347    pub title: String,
348    /// X-axis label
349    pub x_label: String,
350    /// Y-axis label
351    pub y_label: String,
352    /// Plot type
353    pub plot_type: PlotType,
354    /// Optional series names
355    pub series_names: Option<Vec<String>>,
356}
357
358impl GenericMetricVisualizer {
359    /// Create a new generic metric visualizer
360    pub fn new(
361        x: Vec<f64>,
362        y: Vec<f64>,
363        title: impl Into<String>,
364        x_label: impl Into<String>,
365        y_label: impl Into<String>,
366        plot_type: PlotType,
367    ) -> Self {
368        Self {
369            x,
370            y,
371            title: title.into(),
372            x_label: x_label.into(),
373            y_label: y_label.into(),
374            plot_type,
375            series_names: None,
376        }
377    }
378
379    /// Add series names
380    pub fn with_series_names(mut self, seriesnames: Vec<String>) -> Self {
381        self.series_names = Some(seriesnames);
382        self
383    }
384}
385
386impl crate::visualization::MetricVisualizer for GenericMetricVisualizer {
387    fn prepare_data(&self) -> Result<VisualizationData, Box<dyn Error>> {
388        let mut data = VisualizationData::new();
389
390        // Set x and y data
391        data.x = self.x.clone();
392        data.y = self.y.clone();
393
394        // Add series names if available
395        if let Some(series_names) = &self.series_names {
396            data.series_names = Some(series_names.clone());
397        }
398
399        Ok(data)
400    }
401
402    fn get_metadata(&self) -> VisualizationMetadata {
403        let mut metadata = VisualizationMetadata::new(self.title.clone());
404        metadata.set_plot_type(self.plot_type.clone());
405        metadata.set_x_label(self.x_label.clone());
406        metadata.set_y_label(self.y_label.clone());
407        metadata
408    }
409}
410
411/// Create a multi-curve visualization
412///
413/// This function creates a visualization with multiple curves,
414/// such as performance comparisons between different models.
415///
416/// # Arguments
417///
418/// * `x_values` - X-axis values (common for all curves)
419/// * `y_values_list` - List of Y-axis values, one for each curve
420/// * `series_names` - Names for each curve
421/// * `title` - Plot title
422/// * `x_label` - X-axis label
423/// * `y_label` - Y-axis label
424///
425/// # Returns
426///
427/// * `Box<dyn crate::visualization::MetricVisualizer>` - A visualizer for the multi-curve plot
428#[allow(dead_code)]
429pub fn visualize_multi_curve<A, B>(
430    x_values: ArrayView1<A>,
431    y_values_list: Vec<ArrayView1<B>>,
432    series_names: Vec<String>,
433    title: impl Into<String>,
434    x_label: impl Into<String>,
435    y_label: impl Into<String>,
436) -> Box<dyn crate::visualization::MetricVisualizer>
437where
438    A: Clone + Into<f64>,
439    B: Clone + Into<f64>,
440{
441    let x_vec = x_values
442        .iter()
443        .map(|x| x.clone().into())
444        .collect::<Vec<f64>>();
445
446    // Set the first y-_values as the main y-axis data
447    let y_vec = if !y_values_list.is_empty() {
448        y_values_list[0]
449            .iter()
450            .map(|y| y.clone().into())
451            .collect::<Vec<f64>>()
452    } else {
453        Vec::new()
454    };
455
456    // Create a visualizer
457    let mut visualizer =
458        MultiCurveVisualizer::new(x_vec, y_vec, title.into(), x_label.into(), y_label.into());
459
460    // Add all series
461    for (i, y_values) in y_values_list.iter().enumerate() {
462        if i == 0 {
463            // Skip the first one, already added as main y-axis
464            continue;
465        }
466
467        let name = if i < series_names.len() {
468            series_names[i].clone()
469        } else {
470            format!("Series {}", i + 1)
471        };
472
473        let y_vec = y_values
474            .iter()
475            .map(|y| y.clone().into())
476            .collect::<Vec<f64>>();
477        visualizer.add_series(name, y_vec);
478    }
479
480    // Set all series _names
481    visualizer.set_series_names(series_names);
482
483    Box::new(visualizer)
484}
485
486/// A visualizer for multi-curve plots
487pub struct MultiCurveVisualizer {
488    /// X-axis values
489    pub x: Vec<f64>,
490    /// Y-axis values for the main curve
491    pub y: Vec<f64>,
492    /// Additional Y-axis values for secondary curves
493    pub secondary_y: Vec<(String, Vec<f64>)>,
494    /// Title
495    pub title: String,
496    /// X-axis label
497    pub x_label: String,
498    /// Y-axis label
499    pub y_label: String,
500    /// Series names
501    pub series_names: Vec<String>,
502}
503
504impl MultiCurveVisualizer {
505    /// Create a new multi-curve visualizer
506    pub fn new(
507        x: Vec<f64>,
508        y: Vec<f64>,
509        title: impl Into<String>,
510        x_label: impl Into<String>,
511        y_label: impl Into<String>,
512    ) -> Self {
513        Self {
514            x,
515            y,
516            secondary_y: Vec::new(),
517            title: title.into(),
518            x_label: x_label.into(),
519            y_label: y_label.into(),
520            series_names: Vec::new(),
521        }
522    }
523
524    /// Add a secondary curve
525    pub fn add_series(&mut self, name: impl Into<String>, y: Vec<f64>) {
526        self.secondary_y.push((name.into(), y));
527    }
528
529    /// Set series names
530    pub fn set_series_names(&mut self, names: Vec<String>) {
531        self.series_names = names;
532    }
533}
534
535impl crate::visualization::MetricVisualizer for MultiCurveVisualizer {
536    fn prepare_data(&self) -> Result<VisualizationData, Box<dyn Error>> {
537        let mut data = VisualizationData::new();
538
539        // Set main x and y data
540        data.x = self.x.clone();
541        data.y = self.y.clone();
542
543        // Add secondary curves
544        for (name, y) in &self.secondary_y {
545            data.series.insert(name.clone(), y.clone());
546        }
547
548        // Add series names
549        if !self.series_names.is_empty() {
550            data.series_names = Some(self.series_names.clone());
551        }
552
553        Ok(data)
554    }
555
556    fn get_metadata(&self) -> VisualizationMetadata {
557        let mut metadata = VisualizationMetadata::new(self.title.clone());
558        metadata.set_plot_type(PlotType::Line);
559        metadata.set_x_label(self.x_label.clone());
560        metadata.set_y_label(self.y_label.clone());
561        metadata
562    }
563}
564
565/// Create a heatmap visualization
566///
567/// This function creates a heatmap visualization for 2D data,
568/// such as correlation matrices, distance matrices, etc.
569///
570/// # Arguments
571///
572/// * `matrix` - 2D data matrix
573/// * `x_labels` - Optional labels for x-axis
574/// * `y_labels` - Optional labels for y-axis
575/// * `title` - Plot title
576/// * `color_map` - Optional color map
577///
578/// # Returns
579///
580/// * `Box<dyn crate::visualization::MetricVisualizer>` - A visualizer for the heatmap
581#[allow(dead_code)]
582pub fn visualize_heatmap<A>(
583    matrix: ArrayView2<A>,
584    x_labels: Option<Vec<String>>,
585    y_labels: Option<Vec<String>>,
586    title: impl Into<String>,
587    color_map: Option<ColorMap>,
588) -> Box<dyn crate::visualization::MetricVisualizer>
589where
590    A: Clone + Into<f64>,
591{
592    // Convert matrix to Vec<Vec<f64>>
593    let z = Array2::from_shape_fn(matrix.dim(), |(i, j)| matrix[[i, j]].clone().into());
594
595    let z_vec = (0..z.shape()[0])
596        .map(|i| (0..z.shape()[1]).map(|j| z[[i, j]]).collect::<Vec<f64>>())
597        .collect::<Vec<Vec<f64>>>();
598
599    // Create x and y coordinates for the heatmap
600    let x = (0..z.shape()[1]).map(|i| i as f64).collect::<Vec<f64>>();
601    let y = (0..z.shape()[0]).map(|i| i as f64).collect::<Vec<f64>>();
602
603    Box::new(HeatmapVisualizer::new(
604        x,
605        y,
606        z_vec,
607        title.into(),
608        x_labels,
609        y_labels,
610        color_map,
611    ))
612}
613
614/// A visualizer for heatmaps
615pub struct HeatmapVisualizer {
616    /// X-axis values
617    pub x: Vec<f64>,
618    /// Y-axis values
619    pub y: Vec<f64>,
620    /// Z-axis values (2D matrix)
621    pub z: Vec<Vec<f64>>,
622    /// Title
623    pub title: String,
624    /// X-axis labels
625    pub x_labels: Option<Vec<String>>,
626    /// Y-axis labels
627    pub y_labels: Option<Vec<String>>,
628    /// Color map
629    pub color_map: Option<ColorMap>,
630}
631
632impl HeatmapVisualizer {
633    /// Create a new heatmap visualizer
634    pub fn new(
635        x: Vec<f64>,
636        y: Vec<f64>,
637        z: Vec<Vec<f64>>,
638        title: impl Into<String>,
639        x_labels: Option<Vec<String>>,
640        y_labels: Option<Vec<String>>,
641        color_map: Option<ColorMap>,
642    ) -> Self {
643        Self {
644            x,
645            y,
646            z,
647            title: title.into(),
648            x_labels,
649            y_labels,
650            color_map,
651        }
652    }
653}
654
655impl crate::visualization::MetricVisualizer for HeatmapVisualizer {
656    fn prepare_data(&self) -> Result<VisualizationData, Box<dyn Error>> {
657        let mut data = VisualizationData::new();
658
659        // Set x, y, and z data
660        data.x = self.x.clone();
661        data.y = self.y.clone();
662        data.z = Some(self.z.clone());
663
664        // Add axis labels if available
665        if let Some(x_labels) = &self.x_labels {
666            data.x_labels = Some(x_labels.clone());
667        }
668
669        if let Some(y_labels) = &self.y_labels {
670            data.y_labels = Some(y_labels.clone());
671        }
672
673        Ok(data)
674    }
675
676    fn get_metadata(&self) -> VisualizationMetadata {
677        let mut metadata = VisualizationMetadata::new(self.title.clone());
678        metadata.set_plot_type(PlotType::Heatmap);
679
680        // Set default axis labels if none are provided
681        if self.x_labels.is_none() {
682            metadata.set_x_label("X");
683        } else {
684            metadata.set_x_label(""); // Labels are provided directly
685        }
686
687        if self.y_labels.is_none() {
688            metadata.set_y_label("Y");
689        } else {
690            metadata.set_y_label(""); // Labels are provided directly
691        }
692
693        metadata
694    }
695}
696
697/// Create a histogram visualization
698///
699/// This function creates a histogram visualization for 1D data.
700///
701/// # Arguments
702///
703/// * `values` - Data values
704/// * `bins` - Number of bins
705/// * `title` - Plot title
706/// * `x_label` - X-axis label
707/// * `y_label` - Y-axis label (defaults to "Frequency")
708///
709/// # Returns
710///
711/// * `Box<dyn crate::visualization::MetricVisualizer>` - A visualizer for the histogram
712#[allow(dead_code)]
713pub fn visualize_histogram<A>(
714    values: ArrayView1<A>,
715    bins: usize,
716    title: impl Into<String>,
717    x_label: impl Into<String>,
718    y_label: Option<String>,
719) -> Box<dyn crate::visualization::MetricVisualizer>
720where
721    A: Clone + Into<f64>,
722{
723    // Convert values to f64 vector
724    let values_vec = values
725        .iter()
726        .map(|x| x.clone().into())
727        .collect::<Vec<f64>>();
728
729    // Create histogram bins
730    let (bin_edges, bin_counts) = create_histogram_bins(&values_vec, bins);
731
732    Box::new(HistogramVisualizer::new(
733        bin_edges,
734        bin_counts,
735        title.into(),
736        x_label.into(),
737        y_label.unwrap_or_else(|| "Frequency".to_string()),
738    ))
739}
740
/// Create equal-width histogram bins from data values
///
/// Only finite values are binned. NaN and ±infinity are skipped, because they cannot be
/// placed on a bounded axis and would otherwise poison the computed range.
/// The range spans the smallest to the largest finite value. If all finite values are
/// equal, the range is widened to `[v - 0.5, v + 0.5]` so the bins have non-zero width
/// (the same convention `numpy.histogram` uses). Each bin is half-open `[lo, hi)`,
/// except the last, which also includes the maximum value.
///
/// # Arguments
///
/// * `values` - Data values
/// * `bins` - Number of bins
///
/// # Returns
///
/// * `(Vec<f64>, Vec<f64>)` - Bin edges (`bins + 1` entries) and bin counts (`bins` entries).
///   Both are empty when `bins == 0` or when `values` contains no finite value.
#[allow(dead_code)]
fn create_histogram_bins(values: &[f64], bins: usize) -> (Vec<f64>, Vec<f64>) {
    if bins == 0 {
        return (Vec::new(), Vec::new());
    }

    // Range over finite values only.
    let (min_val, max_val) = values
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
            (lo.min(v), hi.max(v))
        });

    // No finite values at all (including empty input): nothing to bin.
    if min_val > max_val {
        return (Vec::new(), Vec::new());
    }

    // A zero-width range would make `bin_width` zero and every index NaN.
    let (min_val, max_val) = if min_val == max_val {
        (min_val - 0.5, max_val + 0.5)
    } else {
        (min_val, max_val)
    };

    let bin_width = (max_val - min_val) / bins as f64;
    let bin_edges: Vec<f64> = (0..=bins)
        .map(|i| min_val + i as f64 * bin_width)
        .collect();

    let mut bin_counts = vec![0.0; bins];
    for val in values.iter().copied().filter(|v| v.is_finite()) {
        let bin_idx = ((val - min_val) / bin_width).floor() as usize;
        // `val == max_val` lands one past the end; fold it into the last bin.
        bin_counts[bin_idx.min(bins - 1)] += 1.0;
    }

    (bin_edges, bin_counts)
}
784
785/// A visualizer for histograms
786pub struct HistogramVisualizer {
787    /// Bin edges
788    pub bin_edges: Vec<f64>,
789    /// Bin counts
790    pub bin_counts: Vec<f64>,
791    /// Title
792    pub title: String,
793    /// X-axis label
794    pub x_label: String,
795    /// Y-axis label
796    pub y_label: String,
797}
798
799impl HistogramVisualizer {
800    /// Create a new histogram visualizer
801    pub fn new(
802        bin_edges: Vec<f64>,
803        bin_counts: Vec<f64>,
804        title: impl Into<String>,
805        x_label: impl Into<String>,
806        y_label: impl Into<String>,
807    ) -> Self {
808        Self {
809            bin_edges,
810            bin_counts,
811            title: title.into(),
812            x_label: x_label.into(),
813            y_label: y_label.into(),
814        }
815    }
816}
817
818impl crate::visualization::MetricVisualizer for HistogramVisualizer {
819    fn prepare_data(&self) -> Result<VisualizationData, Box<dyn Error>> {
820        let mut data = VisualizationData::new();
821
822        // Use bin centers as x values
823        if self.bin_edges.len() > 1 {
824            let bin_centers = self
825                .bin_edges
826                .windows(2)
827                .map(|w| (w[0] + w[1]) / 2.0)
828                .collect::<Vec<f64>>();
829
830            data.x = bin_centers;
831        } else {
832            data.x = Vec::new();
833        }
834
835        // Use bin counts as y values
836        data.y = self.bin_counts.clone();
837
838        // Store bin edges in auxiliary data
839        data.add_auxiliary_data("bin_edges", self.bin_edges.clone());
840
841        Ok(data)
842    }
843
844    fn get_metadata(&self) -> VisualizationMetadata {
845        let mut metadata = VisualizationMetadata::new(self.title.clone());
846        metadata.set_plot_type(PlotType::Histogram);
847        metadata.set_x_label(self.x_label.clone());
848        metadata.set_y_label(self.y_label.clone());
849        metadata
850    }
851}