sklears_datasets/
viz.rs

1//! Dataset visualization utilities
2//!
3//! This module provides simple visualization functions for generated datasets
4//! when the `visualization` feature is enabled.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use std::path::Path;
8use thiserror::Error;
9
10/// Error types for visualization operations
11#[derive(Debug, Error)]
12pub enum VisualizationError {
13    #[error("IO error: {0}")]
14    Io(#[from] std::io::Error),
15    #[error("Plotting error: {0}")]
16    Plotting(String),
17    #[error("Invalid dimensions: {0}")]
18    InvalidDimensions(String),
19    #[error("Feature not enabled: {0}")]
20    FeatureNotEnabled(String),
21}
22
23pub type VisualizationResult<T> = Result<T, VisualizationError>;
24
/// Appearance settings consumed by the plotting functions in this module.
#[derive(Clone, Debug)]
pub struct PlotConfig {
    /// Output image width in pixels (passed to the bitmap backend).
    pub width: u32,
    /// Output image height in pixels (passed to the bitmap backend).
    pub height: u32,
    /// Caption drawn above the chart.
    pub title: String,
    /// Description of the x axis for the 2D scatter plots.
    pub xlabel: String,
    /// Description of the y axis for the 2D scatter plots.
    pub ylabel: String,
    /// Whether the classification plot draws its series legend.
    pub show_legend: bool,
    /// Size given to each scatter marker (`Circle`).
    pub marker_size: u32,
}

impl Default for PlotConfig {
    /// An 800x600 image with generic title/axis labels, legend enabled and size-3 markers.
    fn default() -> Self {
        let title = String::from("Dataset Visualization");
        let xlabel = String::from("Feature 1");
        let ylabel = String::from("Feature 2");
        PlotConfig {
            width: 800,
            height: 600,
            title,
            xlabel,
            ylabel,
            show_legend: true,
            marker_size: 3,
        }
    }
}
50
51#[cfg(feature = "visualization")]
52use plotters::prelude::*;
53
#[cfg(feature = "visualization")]
/// Plot a 2D classification dataset as a scatter plot, one color per class label.
///
/// The first two columns of `features` are used as x and y coordinates; any further
/// columns are ignored. Classes are drawn in ascending label order, so each label gets
/// the same color (and legend position) on every run. Colors cycle after six classes.
///
/// # Arguments
/// * `path` - Output image file for the bitmap backend.
/// * `features` - Sample matrix, one row per sample; needs at least 2 columns.
/// * `targets` - Class label for each row of `features`.
/// * `config` - Plot appearance; `None` uses `PlotConfig::default()`.
///
/// # Errors
/// * `VisualizationError::InvalidDimensions` if `features` has fewer than 2 columns
///   or `targets.len()` differs from `features.nrows()`.
/// * `VisualizationError::Plotting` if drawing or writing the image fails.
pub fn plot_2d_classification<P: AsRef<Path>>(
    path: P,
    features: &Array2<f64>,
    targets: &Array1<i32>,
    config: Option<PlotConfig>,
) -> VisualizationResult<()> {
    // Wrap any displayable plotters error as `VisualizationError::Plotting`.
    fn plot_err<E: std::fmt::Display>(e: E) -> VisualizationError {
        VisualizationError::Plotting(e.to_string())
    }

    // Axis range spanning the finite values, widened whenever it would be empty or
    // zero-width (a degenerate range cannot be mapped to pixels).
    fn axis_range(values: impl IntoIterator<Item = f64>) -> std::ops::Range<f64> {
        let (lo, hi) = values
            .into_iter()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        if lo > hi {
            // No finite values at all.
            0.0..1.0
        } else if lo == hi {
            let pad = lo.abs().max(1.0) * 0.5;
            (lo - pad)..(hi + pad)
        } else {
            lo..hi
        }
    }

    let config = config.unwrap_or_default();

    if features.ncols() < 2 {
        return Err(VisualizationError::InvalidDimensions(
            "Need at least 2 features for 2D plot".to_string(),
        ));
    }
    if targets.len() != features.nrows() {
        return Err(VisualizationError::InvalidDimensions(format!(
            "Expected {} targets (one per sample), got {}",
            features.nrows(),
            targets.len()
        )));
    }

    let root = BitMapBackend::new(path.as_ref(), (config.width, config.height)).into_drawing_area();
    root.fill(&WHITE).map_err(plot_err)?;

    let x_range = axis_range(features.column(0).iter().copied());
    let y_range = axis_range(features.column(1).iter().copied());

    let mut chart = ChartBuilder::on(&root)
        .caption(&config.title, ("sans-serif", 30).into_font())
        .margin(10)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(x_range, y_range)
        .map_err(plot_err)?;

    chart
        .configure_mesh()
        .x_desc(&config.xlabel)
        .y_desc(&config.ylabel)
        .draw()
        .map_err(plot_err)?;

    // Group points by class. BTreeMap iterates in label order (HashMap does not),
    // which keeps the label -> color assignment reproducible across runs.
    let mut class_points: std::collections::BTreeMap<i32, Vec<(f64, f64)>> =
        std::collections::BTreeMap::new();
    for (i, &class) in targets.iter().enumerate() {
        class_points
            .entry(class)
            .or_default()
            .push((features[[i, 0]], features[[i, 1]]));
    }

    let colors = [&RED, &BLUE, &GREEN, &YELLOW, &MAGENTA, &CYAN];
    let marker_size = config.marker_size;

    for (idx, (class, points)) in class_points.iter().enumerate() {
        let color = colors[idx % colors.len()];
        chart
            .draw_series(
                points
                    .iter()
                    .map(|&(x, y)| Circle::new((x, y), marker_size, color.filled())),
            )
            .map_err(plot_err)?
            .label(format!("Class {}", class))
            .legend(move |(x, y)| Circle::new((x, y), marker_size, color.filled()));
    }

    if config.show_legend {
        chart
            .configure_series_labels()
            .background_style(WHITE.mix(0.8))
            .border_style(BLACK)
            .draw()
            .map_err(plot_err)?;
    }

    root.present().map_err(plot_err)?;

    Ok(())
}
152
#[cfg(feature = "visualization")]
/// Plot a 2D regression dataset as a scatter plot colored by target value.
///
/// The first two columns of `features` are used as x and y coordinates. Each point's
/// color goes from blue (smallest finite target) to red (largest finite target).
/// When all finite targets are equal, every point gets the mid color.
///
/// # Arguments
/// * `path` - Output image file for the bitmap backend.
/// * `features` - Sample matrix, one row per sample; needs at least 2 columns.
/// * `targets` - Target value for each row of `features`.
/// * `config` - Plot appearance; `None` uses `PlotConfig::default()`.
///
/// # Errors
/// * `VisualizationError::InvalidDimensions` if `features` has fewer than 2 columns
///   or `targets.len()` differs from `features.nrows()`.
/// * `VisualizationError::Plotting` if drawing or writing the image fails.
pub fn plot_2d_regression<P: AsRef<Path>>(
    path: P,
    features: &Array2<f64>,
    targets: &Array1<f64>,
    config: Option<PlotConfig>,
) -> VisualizationResult<()> {
    // Wrap any displayable plotters error as `VisualizationError::Plotting`.
    fn plot_err<E: std::fmt::Display>(e: E) -> VisualizationError {
        VisualizationError::Plotting(e.to_string())
    }

    // Range spanning the finite values, widened whenever it would be empty or
    // zero-width, so both chart axes and color normalization never divide by zero.
    fn axis_range(values: impl IntoIterator<Item = f64>) -> std::ops::Range<f64> {
        let (lo, hi) = values
            .into_iter()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        if lo > hi {
            // No finite values at all.
            0.0..1.0
        } else if lo == hi {
            let pad = lo.abs().max(1.0) * 0.5;
            (lo - pad)..(hi + pad)
        } else {
            lo..hi
        }
    }

    let config = config.unwrap_or_default();

    if features.ncols() < 2 {
        return Err(VisualizationError::InvalidDimensions(
            "Need at least 2 features for 2D plot".to_string(),
        ));
    }
    if targets.len() != features.nrows() {
        return Err(VisualizationError::InvalidDimensions(format!(
            "Expected {} targets (one per sample), got {}",
            features.nrows(),
            targets.len()
        )));
    }

    let root = BitMapBackend::new(path.as_ref(), (config.width, config.height)).into_drawing_area();
    root.fill(&WHITE).map_err(plot_err)?;

    let x_range = axis_range(features.column(0).iter().copied());
    let y_range = axis_range(features.column(1).iter().copied());
    // A constant target yields a padded range centred on it, so it normalizes to 0.5.
    let t_range = axis_range(targets.iter().copied());
    let (t_min, t_span) = (t_range.start, t_range.end - t_range.start);

    let mut chart = ChartBuilder::on(&root)
        .caption(&config.title, ("sans-serif", 30).into_font())
        .margin(10)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(x_range, y_range)
        .map_err(plot_err)?;

    chart
        .configure_mesh()
        .x_desc(&config.xlabel)
        .y_desc(&config.ylabel)
        .draw()
        .map_err(plot_err)?;

    let marker_size = config.marker_size;

    chart
        .draw_series((0..features.nrows()).map(|i| {
            let x = features[[i, 0]];
            let y = features[[i, 1]];

            // Normalize to [0, 1]; clamping keeps infinite targets at the ends of
            // the scale (a NaN target stays NaN and casts to 0 below).
            let normalized = ((targets[i] - t_min) / t_span).clamp(0.0, 1.0);

            // Blue for low values, red for high values.
            let color = RGBColor(
                (normalized * 255.0) as u8,
                0,
                ((1.0 - normalized) * 255.0) as u8,
            );

            Circle::new((x, y), marker_size, color.filled())
        }))
        .map_err(plot_err)?;

    root.present().map_err(plot_err)?;

    Ok(())
}
242
#[cfg(feature = "visualization")]
/// Plot histograms of the first (up to four) feature columns in a grid.
///
/// Each histogram uses 20 equal-width bins over the column's finite values. NaN and
/// infinite values are skipped because they cannot be placed in a bin. A constant
/// column is drawn over a widened range, so its single bar is visible.
///
/// # Arguments
/// * `path` - Output image file for the bitmap backend.
/// * `features` - Sample matrix, one column per feature; needs at least 1 column.
/// * `feature_names` - Optional per-column captions; missing entries fall back to
///   `"Feature {idx}"`.
/// * `config` - Plot appearance (only `width`/`height` are used here); `None` uses
///   `PlotConfig::default()`.
///
/// # Errors
/// * `VisualizationError::InvalidDimensions` if `features` has no columns.
/// * `VisualizationError::Plotting` if drawing or writing the image fails.
pub fn plot_feature_distributions<P: AsRef<Path>>(
    path: P,
    features: &Array2<f64>,
    feature_names: Option<&[String]>,
    config: Option<PlotConfig>,
) -> VisualizationResult<()> {
    // Maximum number of feature columns drawn; extra columns are ignored.
    const MAX_FEATURES: usize = 4;
    // Number of equal-width histogram bins per feature.
    const N_BINS: usize = 20;

    // Wrap any displayable plotters error as `VisualizationError::Plotting`.
    fn plot_err<E: std::fmt::Display>(e: E) -> VisualizationError {
        VisualizationError::Plotting(e.to_string())
    }

    // Range spanning the finite values, widened whenever it would be empty or
    // zero-width, so the bin width is always positive.
    fn axis_range(values: impl IntoIterator<Item = f64>) -> std::ops::Range<f64> {
        let (lo, hi) = values
            .into_iter()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        if lo > hi {
            // No finite values at all.
            0.0..1.0
        } else if lo == hi {
            let pad = lo.abs().max(1.0) * 0.5;
            (lo - pad)..(hi + pad)
        } else {
            lo..hi
        }
    }

    let config = config.unwrap_or_default();
    let n_features = features.ncols().min(MAX_FEATURES);

    // Without this guard the grid computation below underflows / divides by zero.
    if n_features == 0 {
        return Err(VisualizationError::InvalidDimensions(
            "Need at least 1 feature for a distribution plot".to_string(),
        ));
    }

    let root = BitMapBackend::new(path.as_ref(), (config.width, config.height)).into_drawing_area();
    root.fill(&WHITE).map_err(plot_err)?;

    let grid_rows = (n_features as f64).sqrt().ceil() as usize;
    let grid_cols = (n_features + grid_rows - 1) / grid_rows;

    let areas = root.split_evenly((grid_rows, grid_cols));

    for (idx, area) in areas.iter().enumerate().take(n_features) {
        let column = features.column(idx);
        let default_name = format!("Feature {}", idx);
        let feature_name = feature_names
            .and_then(|names| names.get(idx))
            .map(|s| s.as_str())
            .unwrap_or(&default_name);

        let range = axis_range(column.iter().copied());
        let (lo, hi) = (range.start, range.end);
        let bin_width = (hi - lo) / N_BINS as f64;

        // Values equal to `hi` land exactly on the upper edge; fold them into the last bin.
        let mut bins = [0usize; N_BINS];
        for &val in column.iter().filter(|v| v.is_finite()) {
            let bin = (((val - lo) / bin_width).floor() as usize).min(N_BINS - 1);
            bins[bin] += 1;
        }

        let max_count = bins.iter().copied().max().unwrap_or(0);

        let mut chart = ChartBuilder::on(area)
            .caption(feature_name, ("sans-serif", 20).into_font())
            .margin(5)
            .x_label_area_size(20)
            .y_label_area_size(30)
            .build_cartesian_2d(lo..hi, 0usize..(max_count + 1))
            .map_err(plot_err)?;

        chart.configure_mesh().draw().map_err(plot_err)?;

        chart
            .draw_series(bins.iter().enumerate().map(|(i, &count)| {
                let x0 = lo + i as f64 * bin_width;
                let x1 = x0 + bin_width;
                Rectangle::new([(x0, 0), (x1, count)], BLUE.mix(0.5).filled())
            }))
            .map_err(plot_err)?;
    }

    root.present().map_err(plot_err)?;

    Ok(())
}
316
317// Placeholder functions when visualization feature is not enabled
318#[cfg(not(feature = "visualization"))]
319pub fn plot_2d_classification<P: AsRef<Path>>(
320    _path: P,
321    _features: &Array2<f64>,
322    _targets: &Array1<i32>,
323    _config: Option<PlotConfig>,
324) -> VisualizationResult<()> {
325    Err(VisualizationError::FeatureNotEnabled(
326        "visualization feature is not enabled. Enable with --features visualization".to_string(),
327    ))
328}
329
330#[cfg(not(feature = "visualization"))]
331pub fn plot_2d_regression<P: AsRef<Path>>(
332    _path: P,
333    _features: &Array2<f64>,
334    _targets: &Array1<f64>,
335    _config: Option<PlotConfig>,
336) -> VisualizationResult<()> {
337    Err(VisualizationError::FeatureNotEnabled(
338        "visualization feature is not enabled. Enable with --features visualization".to_string(),
339    ))
340}
341
342#[cfg(not(feature = "visualization"))]
343pub fn plot_feature_distributions<P: AsRef<Path>>(
344    _path: P,
345    _features: &Array2<f64>,
346    _feature_names: Option<&[String]>,
347    _config: Option<PlotConfig>,
348) -> VisualizationResult<()> {
349    Err(VisualizationError::FeatureNotEnabled(
350        "visualization feature is not enabled. Enable with --features visualization".to_string(),
351    ))
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_plot_config_default() {
        let config = PlotConfig::default();
        assert_eq!(config.width, 800);
        assert_eq!(config.height, 600);
        assert_eq!(config.title, "Dataset Visualization");
        assert_eq!(config.xlabel, "Feature 1");
        assert_eq!(config.ylabel, "Feature 2");
        assert!(config.show_legend);
        assert_eq!(config.marker_size, 3);
    }

    /// Every stub must report the missing feature rather than some other error.
    #[test]
    #[cfg(not(feature = "visualization"))]
    fn test_visualization_disabled() {
        // The stubs never touch the filesystem; a temp-dir path keeps the test portable.
        let path = std::env::temp_dir().join("sklears_viz_disabled.png");
        let features = Array2::<f64>::zeros((10, 2));
        let class_targets = Array1::<i32>::zeros(10);
        let reg_targets = Array1::<f64>::zeros(10);

        assert!(matches!(
            plot_2d_classification(&path, &features, &class_targets, None),
            Err(VisualizationError::FeatureNotEnabled(_))
        ));
        assert!(matches!(
            plot_2d_regression(&path, &features, &reg_targets, None),
            Err(VisualizationError::FeatureNotEnabled(_))
        ));
        assert!(matches!(
            plot_feature_distributions(&path, &features, None, None),
            Err(VisualizationError::FeatureNotEnabled(_))
        ));
    }

    /// 2D plots reject single-column input before anything is drawn or written.
    #[test]
    #[cfg(feature = "visualization")]
    fn test_2d_plots_reject_single_feature() {
        let path = std::env::temp_dir().join("sklears_viz_invalid.png");
        let features = Array2::<f64>::zeros((5, 1));
        let class_targets = Array1::<i32>::zeros(5);
        let reg_targets = Array1::<f64>::zeros(5);

        assert!(matches!(
            plot_2d_classification(&path, &features, &class_targets, None),
            Err(VisualizationError::InvalidDimensions(_))
        ));
        assert!(matches!(
            plot_2d_regression(&path, &features, &reg_targets, None),
            Err(VisualizationError::InvalidDimensions(_))
        ));
    }
}