1use crate::error::{ClusteringError, Result};
8use crate::hierarchy::visualization::{create_dendrogramplot, DendrogramConfig, DendrogramPlot};
9use crate::visualization::{ScatterPlot2D, ScatterPlot3D, VisualizationConfig};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
11use std::path::Path;
12
13#[cfg(feature = "egui")]
14use egui::*;
15#[cfg(feature = "plotters")]
16use plotters::prelude::*;
17
/// File format produced by the plot-saving functions.
///
/// The plotters-based renderers in this module handle `PNG` and `SVG`; `PDF` and `HTML`
/// are rejected by them with an "Unsupported output format" error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlotFormat {
    /// Raster image (Portable Network Graphics).
    PNG,
    /// Vector image (Scalable Vector Graphics).
    SVG,
    /// Portable Document Format.
    PDF,
    /// HTML document.
    HTML,
}
30
31#[derive(Debug, Clone)]
33pub struct PlotOutput {
34 pub format: PlotFormat,
36 pub dimensions: (u32, u32),
38 pub dpi: u32,
40 pub background_color: String,
42 pub show_grid: bool,
44 pub show_axes: bool,
46 pub title: Option<String>,
48 pub axis_labels: (Option<String>, Option<String>, Option<String>),
50}
51
52impl Default for PlotOutput {
53 fn default() -> Self {
54 Self {
55 format: PlotFormat::PNG,
56 dimensions: (800, 600),
57 dpi: 300,
58 background_color: "#FFFFFF".to_string(),
59 show_grid: true,
60 show_axes: true,
61 title: None,
62 axis_labels: (None, None, None),
63 }
64 }
65}
66
/// Renders a dendrogram to `output_path` in the format chosen by `output_config.format`.
///
/// # Errors
///
/// Returns `ClusteringError::ComputationError` for `PDF`/`HTML` (not supported by the
/// plotters backend) or when rendering fails.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
pub fn plot_dendrogram<P: AsRef<Path>>(
    dendrogram_plot: &DendrogramPlot<f64>,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let target: &Path = output_path.as_ref();

    // Exhaustive on purpose: adding a PlotFormat variant must force a decision here.
    match output_config.format {
        PlotFormat::PNG => plot_dendrogram_png(dendrogram_plot, target, output_config),
        PlotFormat::SVG => plot_dendrogram_svg(dendrogram_plot, target, output_config),
        PlotFormat::PDF | PlotFormat::HTML => Err(ClusteringError::ComputationError(
            "Unsupported output format for plotters dendrogram".to_string(),
        )),
    }
}
85
/// Renders `dendrogram_plot` to a PNG file at `output_path` using the plotters bitmap backend.
///
/// From `output_config` it uses:
/// - `title` (default "Dendrogram");
/// - the x/y entries of `axis_labels` (defaults "Sample Index" / "Distance");
/// - `dimensions`, `background_color`, `show_grid` and `show_axes`.
///
/// `dpi` and the z-axis label are not used. A branch with no entry in `colors`, or with an
/// unparsable color string, is drawn in black.
///
/// # Errors
///
/// Returns `ClusteringError::ComputationError` when a plotters drawing step fails or the
/// image cannot be written.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_dendrogram_png<P: AsRef<Path>>(
    dendrogram_plot: &DendrogramPlot<f64>,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    // Fraction of each axis span added as padding on both sides.
    const MARGIN: f64 = 0.1;

    // Pad an axis by MARGIN of its span. A zero-width axis gets a unit pad instead, so
    // plotters is never handed an empty coordinate range.
    let padded = |lo: f64, hi: f64| {
        let span = hi - lo;
        let pad = if span > 0.0 { MARGIN * span } else { 1.0 };
        (lo - pad)..(hi + pad)
    };

    let root = BitMapBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    // Honor the configured background; fall back to white if the string does not parse.
    let background = parsehex_color_plotters(&output_config.background_color).unwrap_or(WHITE);
    root.fill(&background).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to initialize plot: {}", e))
    })?;

    let bounds = dendrogram_plot.bounds;
    let mut chart = ChartBuilder::on(&root)
        .caption(
            output_config.title.as_deref().unwrap_or("Dendrogram"),
            ("sans-serif", 30),
        )
        .margin(20)
        .x_label_area_size(40)
        .y_label_area_size(50)
        .build_cartesian_2d(padded(bounds.0, bounds.1), padded(bounds.2, bounds.3))
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to build chart: {}", e)))?;

    {
        let mut mesh = chart.configure_mesh();
        mesh.x_desc(
            output_config
                .axis_labels
                .0
                .as_deref()
                .unwrap_or("Sample Index"),
        )
        .y_desc(output_config.axis_labels.1.as_deref().unwrap_or("Distance"));
        if !output_config.show_grid {
            mesh.disable_mesh();
        }
        if !output_config.show_axes {
            mesh.disable_axes();
        }
        mesh.draw().map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to draw mesh: {}", e))
        })?;
    }

    // One series for all branches; a missing color entry falls back to black instead of
    // panicking on an out-of-bounds index.
    chart
        .draw_series(
            dendrogram_plot
                .branches
                .iter()
                .enumerate()
                .map(|(i, branch)| {
                    let color = dendrogram_plot
                        .colors
                        .get(i)
                        .and_then(|hex| parsehex_color_plotters(hex))
                        .unwrap_or(BLACK);
                    PathElement::new(
                        vec![
                            (branch.start.0, branch.start.1),
                            (branch.end.0, branch.end.1),
                        ],
                        color.stroke_width(2),
                    )
                }),
        )
        .map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to draw branches: {}", e))
        })?;

    chart
        .draw_series(dendrogram_plot.leaves.iter().map(|leaf| {
            Text::new(
                leaf.label.clone(),
                (leaf.position.0, leaf.position.1),
                ("sans-serif", 12).into_font().color(&BLACK),
            )
        }))
        .map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to draw labels: {}", e))
        })?;

    root.present()
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to save plot: {}", e)))?;

    Ok(())
}
169
/// Renders `dendrogram_plot` to an SVG file at `output_path` using the plotters SVG backend.
///
/// From `output_config` it uses:
/// - `title` (default "Dendrogram");
/// - the x/y entries of `axis_labels` (defaults "Sample Index" / "Distance");
/// - `dimensions`, `background_color`, `show_grid` and `show_axes`.
///
/// `dpi` and the z-axis label are not used. A branch with no entry in `colors`, or with an
/// unparsable color string, is drawn in black.
///
/// # Errors
///
/// Returns `ClusteringError::ComputationError` when a plotters drawing step fails or the
/// file cannot be written.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_dendrogram_svg<P: AsRef<Path>>(
    dendrogram_plot: &DendrogramPlot<f64>,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    // Fraction of each axis span added as padding on both sides.
    const MARGIN: f64 = 0.1;

    // Pad an axis by MARGIN of its span. A zero-width axis gets a unit pad instead, so
    // plotters is never handed an empty coordinate range.
    let padded = |lo: f64, hi: f64| {
        let span = hi - lo;
        let pad = if span > 0.0 { MARGIN * span } else { 1.0 };
        (lo - pad)..(hi + pad)
    };

    let root = SVGBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    // Honor the configured background; fall back to white if the string does not parse.
    let background = parsehex_color_plotters(&output_config.background_color).unwrap_or(WHITE);
    root.fill(&background).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to initialize plot: {}", e))
    })?;

    let bounds = dendrogram_plot.bounds;
    let mut chart = ChartBuilder::on(&root)
        .caption(
            output_config.title.as_deref().unwrap_or("Dendrogram"),
            ("sans-serif", 30),
        )
        .margin(20)
        .x_label_area_size(40)
        .y_label_area_size(50)
        .build_cartesian_2d(padded(bounds.0, bounds.1), padded(bounds.2, bounds.3))
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to build chart: {}", e)))?;

    {
        let mut mesh = chart.configure_mesh();
        mesh.x_desc(
            output_config
                .axis_labels
                .0
                .as_deref()
                .unwrap_or("Sample Index"),
        )
        .y_desc(output_config.axis_labels.1.as_deref().unwrap_or("Distance"));
        if !output_config.show_grid {
            mesh.disable_mesh();
        }
        if !output_config.show_axes {
            mesh.disable_axes();
        }
        mesh.draw().map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to draw mesh: {}", e))
        })?;
    }

    // One series for all branches; a missing color entry falls back to black instead of
    // panicking on an out-of-bounds index.
    chart
        .draw_series(
            dendrogram_plot
                .branches
                .iter()
                .enumerate()
                .map(|(i, branch)| {
                    let color = dendrogram_plot
                        .colors
                        .get(i)
                        .and_then(|hex| parsehex_color_plotters(hex))
                        .unwrap_or(BLACK);
                    PathElement::new(
                        vec![
                            (branch.start.0, branch.start.1),
                            (branch.end.0, branch.end.1),
                        ],
                        color.stroke_width(2),
                    )
                }),
        )
        .map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to draw branches: {}", e))
        })?;

    chart
        .draw_series(dendrogram_plot.leaves.iter().map(|leaf| {
            Text::new(
                leaf.label.clone(),
                (leaf.position.0, leaf.position.1),
                ("sans-serif", 12).into_font().color(&BLACK),
            )
        }))
        .map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to draw labels: {}", e))
        })?;

    root.present()
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to save plot: {}", e)))?;

    Ok(())
}
253
/// Renders a 2D cluster scatter plot to `output_path` in the format chosen by
/// `output_config.format`.
///
/// # Errors
///
/// Returns `ClusteringError::ComputationError` for `PDF`/`HTML` (not supported by the
/// plotters backend) or when rendering fails.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
pub fn plot_scatter_2d<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot2D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let target: &Path = output_path.as_ref();

    // Exhaustive on purpose: adding a PlotFormat variant must force a decision here.
    match output_config.format {
        PlotFormat::PNG => plot_scatter_2d_png(scatter_plot, target, output_config),
        PlotFormat::SVG => plot_scatter_2d_svg(scatter_plot, target, output_config),
        PlotFormat::PDF | PlotFormat::HTML => Err(ClusteringError::ComputationError(
            "Unsupported output format for plotters backend".to_string(),
        )),
    }
}
272
/// Renders a 2D cluster scatter plot to a PNG file at `output_path` (plotters bitmap backend).
///
/// From `output_config` it uses:
/// - `title` (default "Cluster Visualization");
/// - the x/y entries of `axis_labels` (defaults "X" / "Y");
/// - `dimensions`, `background_color`, `show_grid` and `show_axes`.
///
/// Each point uses its own color (red if unparsable) and size, truncated to whole pixels.
/// Centroids, when present, are drawn as black crosses.
///
/// # Errors
///
/// Returns `ClusteringError::ComputationError` when a plotters drawing step fails or the
/// image cannot be written.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_scatter_2d_png<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot2D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    // Fraction of each axis span added as padding on both sides.
    const MARGIN: f64 = 0.1;

    // Pad an axis by MARGIN of its span. A zero-width axis (a single point, or all points
    // sharing a coordinate) gets a unit pad instead, so plotters never sees an empty range.
    let padded = |lo: f64, hi: f64| {
        let span = hi - lo;
        let pad = if span > 0.0 { MARGIN * span } else { 1.0 };
        (lo - pad)..(hi + pad)
    };

    let root = BitMapBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    // Honor the configured background; fall back to white if the string does not parse.
    let background = parsehex_color_plotters(&output_config.background_color).unwrap_or(WHITE);
    root.fill(&background).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to initialize plot: {}", e))
    })?;

    let (min_x, max_x, min_y, max_y) = scatter_plot.bounds;
    let mut chart = ChartBuilder::on(&root)
        .caption(
            output_config
                .title
                .as_deref()
                .unwrap_or("Cluster Visualization"),
            ("sans-serif", 30),
        )
        .margin(20)
        .x_label_area_size(40)
        .y_label_area_size(50)
        .build_cartesian_2d(padded(min_x, max_x), padded(min_y, max_y))
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to build chart: {}", e)))?;

    {
        let mut mesh = chart.configure_mesh();
        mesh.x_desc(output_config.axis_labels.0.as_deref().unwrap_or("X"))
            .y_desc(output_config.axis_labels.1.as_deref().unwrap_or("Y"));
        if !output_config.show_grid {
            mesh.disable_mesh();
        }
        if !output_config.show_axes {
            mesh.disable_axes();
        }
        mesh.draw().map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to draw mesh: {}", e))
        })?;
    }

    // Points, colors and sizes are parallel sequences. Zipping them draws every fully
    // described point in a single series and stops at the shortest sequence rather than
    // panicking on an out-of-bounds index.
    let markers = scatter_plot
        .points
        .rows()
        .into_iter()
        .zip(scatter_plot.colors.iter())
        .zip(scatter_plot.sizes.iter())
        .map(|((point, hex), size)| {
            let color = parsehex_color_plotters(hex).unwrap_or(RED);
            Circle::new((point[0], point[1]), *size as i32, color.filled())
        });
    chart.draw_series(markers).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to draw points: {}", e))
    })?;

    if let Some(centroids) = &scatter_plot.centroids {
        chart
            .draw_series(
                centroids
                    .rows()
                    .into_iter()
                    .map(|c| Cross::new((c[0], c[1]), 8, BLACK.stroke_width(3))),
            )
            .map_err(|e| {
                ClusteringError::ComputationError(format!("Failed to draw centroids: {}", e))
            })?;
    }

    root.present()
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to save plot: {}", e)))?;

    Ok(())
}
355
/// Renders a 2D cluster scatter plot to an SVG file at `output_path` (plotters SVG backend).
///
/// From `output_config` it uses:
/// - `title` (default "Cluster Visualization");
/// - the x/y entries of `axis_labels` (defaults "X" / "Y");
/// - `dimensions`, `background_color`, `show_grid` and `show_axes`.
///
/// Each point uses its own color (red if unparsable) and size, truncated to whole pixels.
/// Centroids, when present, are drawn as black crosses.
///
/// # Errors
///
/// Returns `ClusteringError::ComputationError` when a plotters drawing step fails or the
/// file cannot be written.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_scatter_2d_svg<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot2D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    // Fraction of each axis span added as padding on both sides.
    const MARGIN: f64 = 0.1;

    // Pad an axis by MARGIN of its span. A zero-width axis (a single point, or all points
    // sharing a coordinate) gets a unit pad instead, so plotters never sees an empty range.
    let padded = |lo: f64, hi: f64| {
        let span = hi - lo;
        let pad = if span > 0.0 { MARGIN * span } else { 1.0 };
        (lo - pad)..(hi + pad)
    };

    let root = SVGBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    // Honor the configured background; fall back to white if the string does not parse.
    let background = parsehex_color_plotters(&output_config.background_color).unwrap_or(WHITE);
    root.fill(&background).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to initialize plot: {}", e))
    })?;

    let (min_x, max_x, min_y, max_y) = scatter_plot.bounds;
    let mut chart = ChartBuilder::on(&root)
        .caption(
            output_config
                .title
                .as_deref()
                .unwrap_or("Cluster Visualization"),
            ("sans-serif", 30),
        )
        .margin(20)
        .x_label_area_size(40)
        .y_label_area_size(50)
        .build_cartesian_2d(padded(min_x, max_x), padded(min_y, max_y))
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to build chart: {}", e)))?;

    {
        let mut mesh = chart.configure_mesh();
        mesh.x_desc(output_config.axis_labels.0.as_deref().unwrap_or("X"))
            .y_desc(output_config.axis_labels.1.as_deref().unwrap_or("Y"));
        if !output_config.show_grid {
            mesh.disable_mesh();
        }
        if !output_config.show_axes {
            mesh.disable_axes();
        }
        mesh.draw().map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to draw mesh: {}", e))
        })?;
    }

    // Points, colors and sizes are parallel sequences. Zipping them draws every fully
    // described point in a single series and stops at the shortest sequence rather than
    // panicking on an out-of-bounds index.
    let markers = scatter_plot
        .points
        .rows()
        .into_iter()
        .zip(scatter_plot.colors.iter())
        .zip(scatter_plot.sizes.iter())
        .map(|((point, hex), size)| {
            let color = parsehex_color_plotters(hex).unwrap_or(RED);
            Circle::new((point[0], point[1]), *size as i32, color.filled())
        });
    chart.draw_series(markers).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to draw points: {}", e))
    })?;

    if let Some(centroids) = &scatter_plot.centroids {
        chart
            .draw_series(
                centroids
                    .rows()
                    .into_iter()
                    .map(|c| Cross::new((c[0], c[1]), 8, BLACK.stroke_width(3))),
            )
            .map_err(|e| {
                ClusteringError::ComputationError(format!("Failed to draw centroids: {}", e))
            })?;
    }

    root.present()
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to save plot: {}", e)))?;

    Ok(())
}
438
/// State for the interactive (egui/eframe) clustering viewer.
#[cfg(feature = "egui")]
pub struct InteractiveClusteringApp {
    /// Plot being displayed; `None` shows a "No clustering data available" placeholder.
    pub scatter_plot_2d: Option<ScatterPlot2D>,
    /// Visualization settings.
    /// NOTE(review): not read by the drawing code in this file.
    pub config: VisualizationConfig,
    /// Whether centroid markers are painted.
    pub show_centroids: bool,
    /// Toggled from the side panel.
    /// NOTE(review): no drawing code in this file reads it.
    pub show_boundaries: bool,
    /// Toggled from the side panel.
    /// NOTE(review): no drawing code in this file reads it.
    pub show_legend: bool,
    /// Zoom factor applied to positions and marker sizes (slider range 0.1..=5.0).
    pub zoom: f32,
    /// Screen-space pan offset, accumulated from drag gestures.
    pub pan_offset: (f32, f32),
    /// Cluster highlighted via the legend; points of other clusters are drawn faded.
    pub selected_cluster: Option<i32>,
}
456
#[cfg(feature = "egui")]
impl Default for InteractiveClusteringApp {
    /// An empty viewer: no data, unit zoom, no pan and no selection. Centroids and legend
    /// are on; boundaries are off.
    fn default() -> Self {
        Self {
            config: VisualizationConfig::default(),
            scatter_plot_2d: None,
            selected_cluster: None,
            zoom: 1.0,
            pan_offset: (0.0, 0.0),
            show_centroids: true,
            show_legend: true,
            show_boundaries: false,
        }
    }
}
472
#[cfg(feature = "egui")]
impl InteractiveClusteringApp {
    /// Creates a viewer that starts out showing `scatterplot`; all other settings come
    /// from `Default`.
    pub fn new(scatterplot: ScatterPlot2D) -> Self {
        Self {
            scatter_plot_2d: Some(scatterplot),
            ..Self::default()
        }
    }

    /// Replaces the displayed plot with `scatterplot`, keeping the current view settings.
    pub fn set_data(&mut self, scatterplot: ScatterPlot2D) {
        let plot = Some(scatterplot);
        self.scatter_plot_2d = plot;
    }
}
488
#[cfg(feature = "egui")]
impl eframe::App for InteractiveClusteringApp {
    /// Builds one frame.
    ///
    /// The left panel holds display toggles, a zoom slider, a view reset and a per-cluster
    /// legend where clicking an entry toggles the selection. The central panel paints the
    /// scatter plot, or a placeholder when there is no data.
    fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
        egui::SidePanel::left("controls").show(ctx, |ui| {
            ui.heading("Clustering Visualization");
            ui.separator();

            ui.checkbox(&mut self.show_centroids, "Show Centroids");
            ui.checkbox(&mut self.show_boundaries, "Show Boundaries");
            ui.checkbox(&mut self.show_legend, "Show Legend");

            ui.separator();
            ui.label("Zoom:");
            ui.add(egui::Slider::new(&mut self.zoom, 0.1..=5.0));

            if ui.button("Reset View").clicked() {
                self.zoom = 1.0;
                self.pan_offset = (0.0, 0.0);
            }

            ui.separator();
            if let Some(ref plot) = self.scatter_plot_2d {
                ui.label("Cluster Information:");
                for legend_entry in &plot.legend {
                    let color = parsehex_color(&legend_entry.color).unwrap_or([255, 0, 0]);
                    let color32 = Color32::from_rgb(color[0], color[1], color[2]);

                    ui.horizontal(|ui| {
                        ui.colored_label(color32, "●");
                        let is_selected = self.selected_cluster == Some(legend_entry.cluster_id);
                        let text = format!(
                            "Cluster {} ({} points)",
                            legend_entry.cluster_id, legend_entry.count
                        );
                        if ui.selectable_label(is_selected, text).clicked() {
                            // Clicking the already-selected cluster clears the selection.
                            self.selected_cluster = if is_selected {
                                None
                            } else {
                                Some(legend_entry.cluster_id)
                            };
                        }
                    });
                }
            }
        });

        egui::CentralPanel::default().show(ctx, |ui| {
            // Move the plot out of `self` while drawing: `draw_scatterplot` needs `&mut self`
            // for the view state, and this avoids deep-cloning the point arrays every frame.
            // `draw_scatterplot` does not read `scatter_plot_2d`, so the temporary `None` is
            // never observed.
            match self.scatter_plot_2d.take() {
                Some(plot) => {
                    self.draw_scatterplot(ui, &plot);
                    self.scatter_plot_2d = Some(plot);
                }
                None => {
                    ui.centered_and_justified(|ui| {
                        ui.label("No clustering data available");
                    });
                }
            }
        });
    }
}
551
#[cfg(feature = "egui")]
impl InteractiveClusteringApp {
    /// Paints `plot` into the remaining space of `ui` using the current zoom, pan and
    /// cluster selection. Dragging inside the plot area adds to `pan_offset`.
    ///
    /// Data coordinates are normalized against `plot.bounds`. Points that land outside the
    /// allocated rectangle are skipped. When a cluster is selected, points of the other
    /// clusters are drawn semi-transparent.
    fn draw_scatterplot(&mut self, ui: &mut Ui, plot: &ScatterPlot2D) {
        let (response, painter) = ui.allocate_painter(ui.available_size(), Sense::drag());
        let rect = response.rect;
        let (min_x, max_x, min_y, max_y) = plot.bounds;

        if response.dragged() {
            let delta = response.drag_delta();
            self.pan_offset.0 += delta.x;
            self.pan_offset.1 += delta.y;
        }

        // Map `value` into [0, 1] along one axis. A zero-width axis (e.g. a single point)
        // maps to the centre instead of dividing by zero, which produced NaN positions and
        // hid every point.
        fn normalize(value: f64, lo: f64, hi: f64) -> f32 {
            let span = hi - lo;
            if span > 0.0 {
                ((value - lo) / span) as f32
            } else {
                0.5
            }
        }

        let zoom = self.zoom;
        let pan = self.pan_offset;
        // Screen y grows downwards, so data y is measured up from the bottom edge.
        let to_screen = |x: f64, y: f64| -> Pos2 {
            Pos2::new(
                rect.left() + normalize(x, min_x, max_x) * rect.width() * zoom + pan.0,
                rect.bottom() - normalize(y, min_y, max_y) * rect.height() * zoom + pan.1,
            )
        };

        for (i, point) in plot.points.rows().into_iter().enumerate() {
            let screen_pos = to_screen(point[0], point[1]);
            if !rect.contains(screen_pos) {
                continue;
            }

            let [r, g, b] = parsehex_color(&plot.colors[i]).unwrap_or([255, 0, 0]);
            let cluster_id = plot.labels[i];

            // Fade points outside the selected cluster. The channels are straight
            // (unmultiplied) values; the premultiplied constructor requires r, g, b <= alpha
            // and renders full-intensity channels with alpha 100 as over-bright colors.
            let point_color = match self.selected_cluster {
                Some(selected) if cluster_id != selected => {
                    Color32::from_rgba_unmultiplied(r, g, b, 100)
                }
                _ => Color32::from_rgb(r, g, b),
            };

            painter.circle_filled(screen_pos, plot.sizes[i] * zoom, point_color);
        }

        if self.show_centroids {
            if let Some(ref centroids) = plot.centroids {
                let stroke = Stroke::new(3.0, Color32::BLACK);
                let arm = 6.0 * zoom;
                for centroid in centroids.rows() {
                    let pos = to_screen(centroid[0], centroid[1]);
                    if !rect.contains(pos) {
                        continue;
                    }
                    // Ring with a plus sign through its centre.
                    painter.circle_stroke(pos, 8.0 * zoom, stroke);
                    painter.line_segment(
                        [Pos2::new(pos.x - arm, pos.y), Pos2::new(pos.x + arm, pos.y)],
                        stroke,
                    );
                    painter.line_segment(
                        [Pos2::new(pos.x, pos.y - arm), Pos2::new(pos.x, pos.y + arm)],
                        stroke,
                    );
                }
            }
        }
    }
}
644
/// Parses a `#RRGGBB` hex color string (either letter case) into `[r, g, b]`.
///
/// Returns `None` in any of these cases:
/// - the `#` prefix is missing;
/// - the input is not exactly six characters after the `#`;
/// - any of those characters is not an ASCII hex digit.
///
/// Never panics, including on non-ASCII input.
// Only called from feature-gated (plotters/egui) code and tests, so it is dead in a
// build with neither feature enabled.
#[allow(dead_code)]
fn parsehex_color(hex: &str) -> Option<[u8; 3]> {
    let digits = hex.strip_prefix('#')?;

    // Validate every byte up front. This keeps the byte-offset slicing below on char
    // boundaries (slicing multi-byte input such as "#aé123" used to panic). It also
    // rejects the leading '+' that `u8::from_str_radix` would accept ("#+F+F+F").
    if digits.len() != 6 || !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }

    let channel = |start: usize| u8::from_str_radix(&digits[start..start + 2], 16).ok();
    Some([channel(0)?, channel(2)?, channel(4)?])
}
658
/// Parses a `#RRGGBB` hex string into a plotters `RGBColor`, or `None` if it is malformed.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn parsehex_color_plotters(hex: &str) -> Option<RGBColor> {
    parsehex_color(hex).map(|[red, green, blue]| RGBColor(red, green, blue))
}
665
666#[allow(dead_code)]
668pub fn save_dendrogram_plot<P: AsRef<Path>>(
669 linkage_matrix: ArrayView2<f64>,
670 labels: Option<&[String]>,
671 output_path: P,
672 dendrogram_config: Option<&DendrogramConfig<f64>>,
673 output_config: Option<&PlotOutput>,
674) -> Result<()> {
675 let default_dend_config = DendrogramConfig::default();
676 let default_out_config = PlotOutput::default();
677 let dend_config = dendrogram_config.unwrap_or(&default_dend_config);
678 let out_config = output_config.unwrap_or(&default_out_config);
679
680 let dendrogram_plot = create_dendrogramplot(linkage_matrix, labels, dend_config.clone())?;
682
683 #[cfg(feature = "plotters")]
684 {
685 plot_dendrogram(&dendrogram_plot, output_path, out_config)?;
686 }
687
688 #[cfg(not(feature = "plotters"))]
689 {
690 return Err(ClusteringError::ComputationError(
691 "Plotters feature not enabled. Enable with --features plotters".to_string(),
692 ));
693 }
694
695 Ok(())
696}
697
/// Renders a 3D cluster scatter plot to `output_path` in the format chosen by
/// `output_config.format`.
///
/// # Errors
///
/// Returns `ClusteringError::ComputationError` for `PDF`/`HTML` (not supported by the
/// plotters backend) or when rendering fails.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
pub fn plot_scatter_3d<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot3D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let target: &Path = output_path.as_ref();

    // Exhaustive on purpose: adding a PlotFormat variant must force a decision here.
    match output_config.format {
        PlotFormat::PNG => plot_scatter_3d_png(scatter_plot, target, output_config),
        PlotFormat::SVG => plot_scatter_3d_svg(scatter_plot, target, output_config),
        PlotFormat::PDF | PlotFormat::HTML => Err(ClusteringError::ComputationError(
            "Unsupported output format for 3D plotters backend".to_string(),
        )),
    }
}
716
/// Renders a 3D cluster scatter plot to a PNG file at `output_path` (plotters bitmap backend).
///
/// From `output_config` it uses `title` (default "3D Cluster Visualization"), `dimensions`,
/// `background_color` and `show_axes`. In 3D the axes and their grid panels are drawn by a
/// single call, so `show_axes = false` hides both; `show_grid` and `axis_labels` are not
/// used.
///
/// Each point uses its own color (red if unparsable) and size, truncated to whole pixels.
/// Centroids are drawn as black outlined circles.
///
/// # Errors
///
/// Returns `ClusteringError::ComputationError` when a plotters drawing step fails or the
/// image cannot be written.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_scatter_3d_png<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot3D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    // Fraction of each axis span added as padding on both sides.
    const MARGIN: f64 = 0.1;

    // Pad an axis by MARGIN of its span. A zero-width axis (a single point, or all points
    // sharing a coordinate) gets a unit pad instead, so plotters never sees an empty range.
    let padded = |lo: f64, hi: f64| {
        let span = hi - lo;
        let pad = if span > 0.0 { MARGIN * span } else { 1.0 };
        (lo - pad)..(hi + pad)
    };

    let root = BitMapBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    // Honor the configured background; fall back to white if the string does not parse.
    let background = parsehex_color_plotters(&output_config.background_color).unwrap_or(WHITE);
    root.fill(&background).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to initialize plot: {}", e))
    })?;

    let (min_x, max_x, min_y, max_y, min_z, max_z) = scatter_plot.bounds;
    let mut chart = ChartBuilder::on(&root)
        .caption(
            output_config
                .title
                .as_deref()
                .unwrap_or("3D Cluster Visualization"),
            ("sans-serif", 30),
        )
        .margin(20)
        .build_cartesian_3d(
            padded(min_x, max_x),
            padded(min_y, max_y),
            padded(min_z, max_z),
        )
        .map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to build 3D chart: {}", e))
        })?;

    if output_config.show_axes {
        chart
            .configure_axes()
            .light_grid_style(BLUE.mix(0.15))
            .max_light_lines(4)
            .draw()
            .map_err(|e| {
                ClusteringError::ComputationError(format!("Failed to draw axes: {}", e))
            })?;
    }

    // Points, colors and sizes are parallel sequences. Zipping them draws every fully
    // described point in a single series and stops at the shortest sequence rather than
    // panicking on an out-of-bounds index.
    let markers = scatter_plot
        .points
        .rows()
        .into_iter()
        .zip(scatter_plot.colors.iter())
        .zip(scatter_plot.sizes.iter())
        .map(|((point, hex), size)| {
            let color = parsehex_color_plotters(hex).unwrap_or(RED);
            Circle::new((point[0], point[1], point[2]), *size as i32, color.filled())
        });
    chart.draw_series(markers).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to draw 3D points: {}", e))
    })?;

    if let Some(centroids) = &scatter_plot.centroids {
        chart
            .draw_series(
                centroids
                    .rows()
                    .into_iter()
                    .map(|c| Circle::new((c[0], c[1], c[2]), 8, BLACK.stroke_width(3))),
            )
            .map_err(|e| {
                ClusteringError::ComputationError(format!("Failed to draw 3D centroids: {}", e))
            })?;
    }

    root.present()
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to save 3D plot: {}", e)))?;

    Ok(())
}
812
/// Renders a 3D cluster scatter plot to an SVG file at `output_path` (plotters SVG backend).
///
/// From `output_config` it uses `title` (default "3D Cluster Visualization"), `dimensions`,
/// `background_color` and `show_axes`. In 3D the axes and their grid panels are drawn by a
/// single call, so `show_axes = false` hides both; `show_grid` and `axis_labels` are not
/// used.
///
/// Each point uses its own color (red if unparsable) and size, truncated to whole pixels.
/// Centroids are drawn as black outlined circles.
///
/// # Errors
///
/// Returns `ClusteringError::ComputationError` when a plotters drawing step fails or the
/// file cannot be written.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_scatter_3d_svg<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot3D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    // Fraction of each axis span added as padding on both sides.
    const MARGIN: f64 = 0.1;

    // Pad an axis by MARGIN of its span. A zero-width axis (a single point, or all points
    // sharing a coordinate) gets a unit pad instead, so plotters never sees an empty range.
    let padded = |lo: f64, hi: f64| {
        let span = hi - lo;
        let pad = if span > 0.0 { MARGIN * span } else { 1.0 };
        (lo - pad)..(hi + pad)
    };

    let root = SVGBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    // Honor the configured background; fall back to white if the string does not parse.
    let background = parsehex_color_plotters(&output_config.background_color).unwrap_or(WHITE);
    root.fill(&background).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to initialize plot: {}", e))
    })?;

    let (min_x, max_x, min_y, max_y, min_z, max_z) = scatter_plot.bounds;
    let mut chart = ChartBuilder::on(&root)
        .caption(
            output_config
                .title
                .as_deref()
                .unwrap_or("3D Cluster Visualization"),
            ("sans-serif", 30),
        )
        .margin(20)
        .build_cartesian_3d(
            padded(min_x, max_x),
            padded(min_y, max_y),
            padded(min_z, max_z),
        )
        .map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to build 3D chart: {}", e))
        })?;

    if output_config.show_axes {
        chart
            .configure_axes()
            .light_grid_style(BLUE.mix(0.15))
            .max_light_lines(4)
            .draw()
            .map_err(|e| {
                ClusteringError::ComputationError(format!("Failed to draw axes: {}", e))
            })?;
    }

    // Points, colors and sizes are parallel sequences. Zipping them draws every fully
    // described point in a single series and stops at the shortest sequence rather than
    // panicking on an out-of-bounds index.
    let markers = scatter_plot
        .points
        .rows()
        .into_iter()
        .zip(scatter_plot.colors.iter())
        .zip(scatter_plot.sizes.iter())
        .map(|((point, hex), size)| {
            let color = parsehex_color_plotters(hex).unwrap_or(RED);
            Circle::new((point[0], point[1], point[2]), *size as i32, color.filled())
        });
    chart.draw_series(markers).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to draw 3D points: {}", e))
    })?;

    if let Some(centroids) = &scatter_plot.centroids {
        chart
            .draw_series(
                centroids
                    .rows()
                    .into_iter()
                    .map(|c| Circle::new((c[0], c[1], c[2]), 8, BLACK.stroke_width(3))),
            )
            .map_err(|e| {
                ClusteringError::ComputationError(format!("Failed to draw 3D centroids: {}", e))
            })?;
    }

    root.present()
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to save 3D plot: {}", e)))?;

    Ok(())
}
908
909#[allow(dead_code)]
911pub fn save_clustering_plot<P: AsRef<Path>>(
912 data: ArrayView2<f64>,
913 labels: &Array1<i32>,
914 centroids: Option<&Array2<f64>>,
915 output_path: P,
916 config: Option<&VisualizationConfig>,
917 output_config: Option<&PlotOutput>,
918) -> Result<()> {
919 let default_vis_config = VisualizationConfig::default();
920 let default_out_config = PlotOutput::default();
921 let vis_config = config.unwrap_or(&default_vis_config);
922 let out_config = output_config.unwrap_or(&default_out_config);
923
924 let scatter_plot =
926 crate::visualization::create_scatter_plot_2d(data, labels, centroids, vis_config)?;
927
928 #[cfg(feature = "plotters")]
929 {
930 plot_scatter_2d(&scatter_plot, output_path, out_config)?;
931 }
932
933 #[cfg(not(feature = "plotters"))]
934 {
935 return Err(ClusteringError::ComputationError(
936 "Plotters feature not enabled. Enable with --features plotters".to_string(),
937 ));
938 }
939
940 Ok(())
941}
942
943#[allow(dead_code)]
945pub fn save_clustering_plot_3d<P: AsRef<Path>>(
946 data: ArrayView2<f64>,
947 labels: &Array1<i32>,
948 centroids: Option<&Array2<f64>>,
949 output_path: P,
950 config: Option<&VisualizationConfig>,
951 output_config: Option<&PlotOutput>,
952) -> Result<()> {
953 let default_vis_config = VisualizationConfig::default();
954 let default_out_config = PlotOutput::default();
955 let vis_config = config.unwrap_or(&default_vis_config);
956 let out_config = output_config.unwrap_or(&default_out_config);
957
958 let scatter_plot =
960 crate::visualization::create_scatter_plot_3d(data, labels, centroids, vis_config)?;
961
962 #[cfg(feature = "plotters")]
963 {
964 plot_scatter_3d(&scatter_plot, output_path, out_config)?;
965 }
966
967 #[cfg(not(feature = "plotters"))]
968 {
969 return Err(ClusteringError::ComputationError(
970 "Plotters feature not enabled. Enable with --features plotters".to_string(),
971 ));
972 }
973
974 Ok(())
975}
976
/// Opens a native window (1200x800) showing an interactive 2D scatter plot of the clustering.
///
/// `None` for `config` selects `VisualizationConfig::default()`. The call returns only
/// after `eframe::run_native` returns.
///
/// # Errors
///
/// Propagates errors from `create_scatter_plot_2d`. Returns
/// `ClusteringError::ComputationError` if eframe fails to start.
#[cfg(feature = "egui")]
#[allow(dead_code)]
pub fn launch_interactive_visualization(
    data: ArrayView2<f64>,
    labels: &Array1<i32>,
    centroids: Option<&Array2<f64>>,
    config: Option<&VisualizationConfig>,
) -> Result<()> {
    let fallback_config = VisualizationConfig::default();
    let scatter_plot = crate::visualization::create_scatter_plot_2d(
        data,
        labels,
        centroids,
        config.unwrap_or(&fallback_config),
    )?;

    let app = InteractiveClusteringApp::new(scatter_plot);

    let viewport = egui::ViewportBuilder::default()
        .with_inner_size([1200.0, 800.0])
        .with_title("Clustering Visualization");
    let options = eframe::NativeOptions {
        viewport,
        ..Default::default()
    };

    eframe::run_native(
        "Clustering Visualization",
        options,
        Box::new(|_| Ok::<_, Box<dyn std::error::Error + Send + Sync>>(Box::new(app))),
    )
    .map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to launch visualization: {}", e))
    })
}
1013
1014#[cfg(not(feature = "egui"))]
1015#[allow(dead_code)]
1016pub fn launch_interactive_visualization(
1017 _data: ArrayView2<f64>,
1018 _labels: &Array1<i32>,
1019 _centroids: Option<&Array2<f64>>,
1020 _config: Option<&VisualizationConfig>,
1021) -> Result<()> {
1022 Err(ClusteringError::ComputationError(
1023 "Interactive visualization requires egui feature. Enable with --features egui".to_string(),
1024 ))
1025}
1026
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn testhex_color_parsing() {
        assert_eq!(parsehex_color("#FF0000"), Some([255, 0, 0]));
        assert_eq!(parsehex_color("#00FF00"), Some([0, 255, 0]));
        assert_eq!(parsehex_color("#0000FF"), Some([0, 0, 255]));
        // Lower-case digits are accepted too.
        assert_eq!(parsehex_color("#ff8000"), Some([255, 128, 0]));
        // A missing '#' prefix is rejected.
        assert_eq!(parsehex_color("FF0000"), None);
        // A non-hex digit is rejected.
        assert_eq!(parsehex_color("#FG0000"), None);
    }

    #[test]
    fn test_plot_output_default() {
        let output = PlotOutput::default();
        assert_eq!(output.format, PlotFormat::PNG);
        assert_eq!(output.dimensions, (800, 600));
        assert_eq!(output.dpi, 300);
        assert_eq!(output.background_color, "#FFFFFF");
        assert!(output.show_grid);
        assert!(output.show_axes);
        assert!(output.title.is_none());
        assert_eq!(output.axis_labels, (None, None, None));
    }
}