1use scirs2_core::ndarray::{Array1, Array2};
7use std::path::Path;
8use thiserror::Error;
9
/// Errors produced by the plotting functions in this module.
#[derive(Debug, Error)]
pub enum VisualizationError {
    /// Wraps a [`std::io::Error`]; `#[from]` lets `?` convert I/O errors into this variant.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// A plotters drawing operation failed; holds the formatted plotters error message.
    #[error("Plotting error: {0}")]
    Plotting(String),
    /// The input arrays do not have a shape the requested plot can handle
    /// (for example, fewer than two feature columns for a 2D scatter plot).
    #[error("Invalid dimensions: {0}")]
    InvalidDimensions(String),
    /// Returned by every plotting function when the crate is built without the
    /// `visualization` feature.
    #[error("Feature not enabled: {0}")]
    FeatureNotEnabled(String),
}

/// Result alias used by all plotting functions in this module.
pub type VisualizationResult<T> = Result<T, VisualizationError>;
24
/// Appearance settings accepted by the plotting functions in this module.
///
/// Every plotting function takes an `Option<PlotConfig>`; passing `None` uses
/// [`PlotConfig::default`].
#[derive(Debug, Clone)]
pub struct PlotConfig {
    /// Output image width in pixels, handed to plotters' `BitMapBackend`.
    pub width: u32,
    /// Output image height in pixels, handed to plotters' `BitMapBackend`.
    pub height: u32,
    /// Chart caption drawn by the 2D scatter plots.
    pub title: String,
    /// X-axis description drawn by the 2D scatter plots.
    pub xlabel: String,
    /// Y-axis description drawn by the 2D scatter plots.
    pub ylabel: String,
    /// Whether the 2D classification plot draws its per-class legend.
    pub show_legend: bool,
    /// Size passed to plotters' `Circle::new` for each scatter marker.
    pub marker_size: u32,
}

impl Default for PlotConfig {
    /// An 800x600 image captioned "Dataset Visualization", axes labelled
    /// "Feature 1" / "Feature 2", legend enabled and markers of size 3.
    fn default() -> Self {
        let (width, height) = (800, 600);
        PlotConfig {
            title: String::from("Dataset Visualization"),
            xlabel: String::from("Feature 1"),
            ylabel: String::from("Feature 2"),
            show_legend: true,
            marker_size: 3,
            width,
            height,
        }
    }
}
50
51#[cfg(feature = "visualization")]
52use plotters::prelude::*;
53
/// Renders a 2D scatter plot of a classification dataset to the image at `path`,
/// using plotters' `BitMapBackend`.
///
/// Column 0 of `features` is drawn on the x axis and column 1 on the y axis; further
/// columns are ignored. `targets[i]` is the class label of row `i`. Classes are coloured
/// from a six-colour palette in ascending label order, so the same labels always get the
/// same colours and legend order. The legend is drawn when `config.show_legend` is set.
///
/// Rows whose x or y value is not finite are not drawn. Axis ranges are taken from the
/// finite values and widened when they would otherwise be empty or zero-width.
///
/// # Errors
///
/// * [`VisualizationError::InvalidDimensions`] if `features` has fewer than two columns
///   or `targets.len()` differs from `features.nrows()`.
/// * [`VisualizationError::Plotting`] if any plotters drawing operation fails.
#[cfg(feature = "visualization")]
pub fn plot_2d_classification<P: AsRef<Path>>(
    path: P,
    features: &Array2<f64>,
    targets: &Array1<i32>,
    config: Option<PlotConfig>,
) -> VisualizationResult<()> {
    // Wraps a plotters error as `VisualizationError::Plotting`, keeping its message.
    fn plotting_err<E: std::fmt::Display>(err: E) -> VisualizationError {
        VisualizationError::Plotting(err.to_string())
    }

    // `(min, max)` over the finite values, widened so plotters always gets a non-empty
    // axis: `(0.0, 1.0)` when nothing is finite, `value ± 0.5` when all values are equal.
    fn finite_range(values: impl IntoIterator<Item = f64>) -> (f64, f64) {
        let (lo, hi) = values
            .into_iter()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        if lo > hi {
            (0.0, 1.0)
        } else if hi > lo {
            (lo, hi)
        } else {
            (lo - 0.5, hi + 0.5)
        }
    }

    let config = config.unwrap_or_default();

    if features.ncols() < 2 {
        return Err(VisualizationError::InvalidDimensions(
            "Need at least 2 features for 2D plot".to_string(),
        ));
    }
    // A shorter `targets` would otherwise panic on indexing (or silently drop rows).
    if targets.len() != features.nrows() {
        return Err(VisualizationError::InvalidDimensions(format!(
            "targets has {} entries but features has {} rows",
            targets.len(),
            features.nrows()
        )));
    }

    let xs = features.column(0);
    let ys = features.column(1);

    // BTreeMap rather than HashMap: classes are visited in ascending label order, so the
    // class -> colour assignment and legend order are stable across runs.
    let mut class_points: std::collections::BTreeMap<i32, Vec<(f64, f64)>> =
        std::collections::BTreeMap::new();
    for ((&x, &y), &class) in xs.iter().zip(ys.iter()).zip(targets.iter()) {
        // Register the class even if none of its points are drawable, so it keeps its
        // colour slot and legend entry.
        let bucket = class_points.entry(class).or_default();
        if x.is_finite() && y.is_finite() {
            bucket.push((x, y));
        }
    }

    let (x_min, x_max) = finite_range(xs.iter().copied());
    let (y_min, y_max) = finite_range(ys.iter().copied());

    let root =
        BitMapBackend::new(path.as_ref(), (config.width, config.height)).into_drawing_area();
    root.fill(&WHITE).map_err(plotting_err)?;

    let mut chart = ChartBuilder::on(&root)
        .caption(&config.title, ("sans-serif", 30).into_font())
        .margin(10)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(x_min..x_max, y_min..y_max)
        .map_err(plotting_err)?;

    chart
        .configure_mesh()
        .x_desc(&config.xlabel)
        .y_desc(&config.ylabel)
        .draw()
        .map_err(plotting_err)?;

    let palette = [&RED, &BLUE, &GREEN, &YELLOW, &MAGENTA, &CYAN];
    // Copied out so the `move` legend closure captures a `u32` instead of `config`.
    let marker_size = config.marker_size;

    for (idx, (class, points)) in class_points.iter().enumerate() {
        let color = palette[idx % palette.len()];
        chart
            .draw_series(
                points
                    .iter()
                    .map(|&(x, y)| Circle::new((x, y), marker_size, color.filled())),
            )
            .map_err(plotting_err)?
            .label(format!("Class {}", class))
            .legend(move |(x, y)| Circle::new((x, y), marker_size, color.filled()));
    }

    if config.show_legend {
        chart
            .configure_series_labels()
            .background_style(WHITE.mix(0.8))
            .border_style(BLACK)
            .draw()
            .map_err(plotting_err)?;
    }

    root.present().map_err(plotting_err)?;

    Ok(())
}
152
/// Renders a 2D scatter plot of a regression dataset to the image at `path`, using
/// plotters' `BitMapBackend`.
///
/// Column 0 of `features` is drawn on the x axis and column 1 on the y axis; further
/// columns are ignored. Each point is coloured by its target, interpolated from blue
/// (smallest finite target) to red (largest finite target). When the finite targets span
/// 1e-10 or less, every point uses the midpoint colour.
///
/// Rows with a non-finite x or y value, or a NaN target, are not drawn. Infinite targets
/// are clamped to the ends of the colour scale. Axis ranges are taken from the finite
/// values and widened when they would otherwise be empty or zero-width.
///
/// # Errors
///
/// * [`VisualizationError::InvalidDimensions`] if `features` has fewer than two columns
///   or `targets.len()` differs from `features.nrows()`.
/// * [`VisualizationError::Plotting`] if any plotters drawing operation fails.
#[cfg(feature = "visualization")]
pub fn plot_2d_regression<P: AsRef<Path>>(
    path: P,
    features: &Array2<f64>,
    targets: &Array1<f64>,
    config: Option<PlotConfig>,
) -> VisualizationResult<()> {
    // Wraps a plotters error as `VisualizationError::Plotting`, keeping its message.
    fn plotting_err<E: std::fmt::Display>(err: E) -> VisualizationError {
        VisualizationError::Plotting(err.to_string())
    }

    // `(min, max)` over the finite values, widened so plotters always gets a non-empty
    // axis: `(0.0, 1.0)` when nothing is finite, `value ± 0.5` when all values are equal.
    fn finite_range(values: impl IntoIterator<Item = f64>) -> (f64, f64) {
        let (lo, hi) = values
            .into_iter()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        if lo > hi {
            (0.0, 1.0)
        } else if hi > lo {
            (lo, hi)
        } else {
            (lo - 0.5, hi + 0.5)
        }
    }

    let config = config.unwrap_or_default();

    if features.ncols() < 2 {
        return Err(VisualizationError::InvalidDimensions(
            "Need at least 2 features for 2D plot".to_string(),
        ));
    }
    // A shorter `targets` would otherwise panic on indexing (or silently drop rows).
    if targets.len() != features.nrows() {
        return Err(VisualizationError::InvalidDimensions(format!(
            "targets has {} entries but features has {} rows",
            targets.len(),
            features.nrows()
        )));
    }

    let xs = features.column(0);
    let ys = features.column(1);
    let (x_min, x_max) = finite_range(xs.iter().copied());
    let (y_min, y_max) = finite_range(ys.iter().copied());

    // Colour-scale bounds use finite targets only, so a stray infinity cannot squash
    // every other point onto one end of the scale.
    let (t_min, t_max) = targets
        .iter()
        .copied()
        .filter(|t| t.is_finite())
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), t| {
            (lo.min(t), hi.max(t))
        });
    // -inf when there are no finite targets, which falls into the midpoint branch below.
    let t_span = t_max - t_min;

    let root =
        BitMapBackend::new(path.as_ref(), (config.width, config.height)).into_drawing_area();
    root.fill(&WHITE).map_err(plotting_err)?;

    let mut chart = ChartBuilder::on(&root)
        .caption(&config.title, ("sans-serif", 30).into_font())
        .margin(10)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(x_min..x_max, y_min..y_max)
        .map_err(plotting_err)?;

    chart
        .configure_mesh()
        .x_desc(&config.xlabel)
        .y_desc(&config.ylabel)
        .draw()
        .map_err(plotting_err)?;

    let marker_size = config.marker_size;
    let points = xs
        .iter()
        .zip(ys.iter())
        .zip(targets.iter())
        .filter(|((x, y), t)| x.is_finite() && y.is_finite() && !t.is_nan())
        .map(|((&x, &y), &t)| {
            let normalized = if t_span > 1e-10 {
                ((t - t_min) / t_span).clamp(0.0, 1.0)
            } else {
                0.5
            };
            // Linear blue -> red ramp; the `as u8` casts truncate, as before.
            let color = RGBColor(
                (normalized * 255.0) as u8,
                0,
                ((1.0 - normalized) * 255.0) as u8,
            );
            Circle::new((x, y), marker_size, color.filled())
        });
    chart.draw_series(points).map_err(plotting_err)?;

    root.present().map_err(plotting_err)?;

    Ok(())
}
242
/// Renders histograms of (at most) the first four feature columns to the image at
/// `path`, using plotters' `BitMapBackend`.
///
/// The image is split into a grid of `ceil(sqrt(n))` rows with one panel per plotted
/// feature. Each panel shows a 20-bin, equal-width histogram spanning that column's
/// finite values. A panel is captioned with `feature_names[idx]` when present, and with
/// `"Feature {idx}"` otherwise. Non-finite values are left out of the histogram. A
/// constant column is binned over `value ± 0.5` so its bar is visible.
///
/// Only `config.width` and `config.height` are used from `config`.
///
/// # Errors
///
/// * [`VisualizationError::InvalidDimensions`] if `features` has no columns.
/// * [`VisualizationError::Plotting`] if any plotters drawing operation fails.
#[cfg(feature = "visualization")]
pub fn plot_feature_distributions<P: AsRef<Path>>(
    path: P,
    features: &Array2<f64>,
    feature_names: Option<&[String]>,
    config: Option<PlotConfig>,
) -> VisualizationResult<()> {
    // Maximum number of feature columns plotted; extra columns are ignored.
    const MAX_FEATURES: usize = 4;
    // Number of equal-width histogram bins per feature.
    const N_BINS: usize = 20;

    // Wraps a plotters error as `VisualizationError::Plotting`, keeping its message.
    fn plotting_err<E: std::fmt::Display>(err: E) -> VisualizationError {
        VisualizationError::Plotting(err.to_string())
    }

    // `(min, max)` over the finite values, widened so the histogram always has a
    // non-empty range: `(0.0, 1.0)` when nothing is finite, `value ± 0.5` when all
    // values are equal.
    fn finite_range(values: impl IntoIterator<Item = f64>) -> (f64, f64) {
        let (lo, hi) = values
            .into_iter()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        if lo > hi {
            (0.0, 1.0)
        } else if hi > lo {
            (lo, hi)
        } else {
            (lo - 0.5, hi + 0.5)
        }
    }

    let config = config.unwrap_or_default();
    let n_features = features.ncols().min(MAX_FEATURES);
    // With zero columns the grid computation below would underflow and divide by zero;
    // reject it before the output image is created.
    if n_features == 0 {
        return Err(VisualizationError::InvalidDimensions(
            "Need at least 1 feature for distribution plot".to_string(),
        ));
    }

    let root =
        BitMapBackend::new(path.as_ref(), (config.width, config.height)).into_drawing_area();
    root.fill(&WHITE).map_err(plotting_err)?;

    // `grid_rows >= 1` here, so the ceiling division is safe.
    let grid_rows = (n_features as f64).sqrt().ceil() as usize;
    let grid_cols = (n_features + grid_rows - 1) / grid_rows;
    let areas = root.split_evenly((grid_rows, grid_cols));

    for (idx, area) in areas.iter().take(n_features).enumerate() {
        let column = features.column(idx);
        let default_name = format!("Feature {}", idx);
        let feature_name = feature_names
            .and_then(|names| names.get(idx))
            .map(String::as_str)
            .unwrap_or(&default_name);

        let (min_val, max_val) = finite_range(column.iter().copied());
        let bin_width = (max_val - min_val) / N_BINS as f64;

        let mut bins = vec![0usize; N_BINS];
        for val in column.iter().copied().filter(|v| v.is_finite()) {
            // The maximum value lands exactly on the upper edge; `min` folds it into the
            // last bin.
            let bin_idx = (((val - min_val) / bin_width).floor() as usize).min(N_BINS - 1);
            bins[bin_idx] += 1;
        }
        let max_count = bins.iter().copied().max().unwrap_or(0);

        let mut chart = ChartBuilder::on(area)
            .caption(feature_name, ("sans-serif", 20).into_font())
            .margin(5)
            .x_label_area_size(20)
            .y_label_area_size(30)
            .build_cartesian_2d(min_val..max_val, 0usize..(max_count + 1))
            .map_err(plotting_err)?;

        chart.configure_mesh().draw().map_err(plotting_err)?;

        chart
            .draw_series(bins.iter().enumerate().map(|(i, &count)| {
                let x0 = min_val + i as f64 * bin_width;
                let x1 = x0 + bin_width;
                Rectangle::new([(x0, 0), (x1, count)], BLUE.mix(0.5).filled())
            }))
            .map_err(plotting_err)?;
    }

    root.present().map_err(plotting_err)?;

    Ok(())
}
316
317#[cfg(not(feature = "visualization"))]
319pub fn plot_2d_classification<P: AsRef<Path>>(
320 _path: P,
321 _features: &Array2<f64>,
322 _targets: &Array1<i32>,
323 _config: Option<PlotConfig>,
324) -> VisualizationResult<()> {
325 Err(VisualizationError::FeatureNotEnabled(
326 "visualization feature is not enabled. Enable with --features visualization".to_string(),
327 ))
328}
329
330#[cfg(not(feature = "visualization"))]
331pub fn plot_2d_regression<P: AsRef<Path>>(
332 _path: P,
333 _features: &Array2<f64>,
334 _targets: &Array1<f64>,
335 _config: Option<PlotConfig>,
336) -> VisualizationResult<()> {
337 Err(VisualizationError::FeatureNotEnabled(
338 "visualization feature is not enabled. Enable with --features visualization".to_string(),
339 ))
340}
341
342#[cfg(not(feature = "visualization"))]
343pub fn plot_feature_distributions<P: AsRef<Path>>(
344 _path: P,
345 _features: &Array2<f64>,
346 _feature_names: Option<&[String]>,
347 _config: Option<PlotConfig>,
348) -> VisualizationResult<()> {
349 Err(VisualizationError::FeatureNotEnabled(
350 "visualization feature is not enabled. Enable with --features visualization".to_string(),
351 ))
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every field of `PlotConfig::default()`.
    #[test]
    fn test_plot_config_default() {
        let config = PlotConfig::default();
        assert_eq!(config.width, 800);
        assert_eq!(config.height, 600);
        assert_eq!(config.title, "Dataset Visualization");
        assert_eq!(config.xlabel, "Feature 1");
        assert_eq!(config.ylabel, "Feature 2");
        assert!(config.show_legend);
        assert_eq!(config.marker_size, 3);
    }

    /// Without the `visualization` feature, every plotting entry point must report
    /// `FeatureNotEnabled` (not just "some error").
    #[test]
    #[cfg(not(feature = "visualization"))]
    fn test_visualization_disabled() {
        let features = Array2::<f64>::zeros((10, 2));
        let int_targets = Array1::<i32>::zeros(10);
        let float_targets = Array1::<f64>::zeros(10);
        let path = std::env::temp_dir().join("scirs2_visualization_disabled.png");

        let result = plot_2d_classification(&path, &features, &int_targets, None);
        assert!(matches!(result, Err(VisualizationError::FeatureNotEnabled(_))));

        let result = plot_2d_regression(&path, &features, &float_targets, None);
        assert!(matches!(result, Err(VisualizationError::FeatureNotEnabled(_))));

        let result = plot_feature_distributions(&path, &features, None, None);
        assert!(matches!(result, Err(VisualizationError::FeatureNotEnabled(_))));
    }

    /// With the feature enabled, the 2D plots reject single-column input before
    /// touching the output path.
    #[test]
    #[cfg(feature = "visualization")]
    fn test_2d_plots_reject_single_feature() {
        let features = Array2::<f64>::zeros((5, 1));
        let path = std::env::temp_dir().join("scirs2_visualization_single_feature.png");

        let result = plot_2d_classification(&path, &features, &Array1::<i32>::zeros(5), None);
        assert!(matches!(result, Err(VisualizationError::InvalidDimensions(_))));

        let result = plot_2d_regression(&path, &features, &Array1::<f64>::zeros(5), None);
        assert!(matches!(result, Err(VisualizationError::InvalidDimensions(_))));
    }
}