Skip to main content

scirs2_cluster/
plotting.rs

1//! Native plotting capabilities for clustering results
2//!
3//! This module provides native plotting implementations using popular Rust visualization
4//! libraries like plotters and egui. It bridges the visualization data structures with
5//! actual plotting backends to create publication-ready plots.
6
7use crate::error::{ClusteringError, Result};
8use crate::hierarchy::visualization::{create_dendrogramplot, DendrogramConfig, DendrogramPlot};
9use crate::visualization::{ScatterPlot2D, ScatterPlot3D, VisualizationConfig};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
11use std::path::Path;
12
13#[cfg(feature = "egui")]
14use egui::*;
15#[cfg(feature = "plotters")]
16use plotters::prelude::*;
17
/// File format produced when a plot is rendered.
///
/// The plotters entry points (`plot_dendrogram`, `plot_scatter_2d`,
/// `plot_scatter_3d`) accept only `PNG` and `SVG` and reject the other variants.
// The variant spellings are public API, so they keep their acronym form even
// though clippy's `upper_case_acronyms` style lint would prefer `Png`, `Svg`, ...
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlotFormat {
    /// Raster PNG image.
    PNG,
    /// Scalable SVG vector image.
    SVG,
    /// PDF document (if supported).
    PDF,
    /// Interactive HTML page.
    HTML,
}
30
31/// Plot output configuration
32#[derive(Debug, Clone)]
33pub struct PlotOutput {
34    /// Output format
35    pub format: PlotFormat,
36    /// Output dimensions (width, height) in pixels
37    pub dimensions: (u32, u32),
38    /// DPI for raster formats
39    pub dpi: u32,
40    /// Background color (hex format)
41    pub background_color: String,
42    /// Whether to show grid
43    pub show_grid: bool,
44    /// Whether to show axes
45    pub show_axes: bool,
46    /// Plot title
47    pub title: Option<String>,
48    /// Axis labels (x, y, z)
49    pub axis_labels: (Option<String>, Option<String>, Option<String>),
50}
51
52impl Default for PlotOutput {
53    fn default() -> Self {
54        Self {
55            format: PlotFormat::PNG,
56            dimensions: (800, 600),
57            dpi: 300,
58            background_color: "#FFFFFF".to_string(),
59            show_grid: true,
60            show_axes: true,
61            title: None,
62            axis_labels: (None, None, None),
63        }
64    }
65}
66
/// Native dendrogram plot using plotters.
///
/// Dispatches on `output_config.format`: `PNG` is rendered through plotters'
/// `BitMapBackend` and `SVG` through `SVGBackend`. Other formats are rejected.
///
/// # Errors
///
/// Returns [`ClusteringError::ComputationError`] for an unsupported format or
/// when plotters fails to draw or write the plot.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
pub fn plot_dendrogram<P: AsRef<Path>>(
    dendrogram_plot: &DendrogramPlot<f64>,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let path = output_path.as_ref();

    match output_config.format {
        PlotFormat::PNG => plot_dendrogram_png(dendrogram_plot, path, output_config),
        PlotFormat::SVG => plot_dendrogram_svg(dendrogram_plot, path, output_config),
        // Listed explicitly (no `_`) so a new `PlotFormat` variant forces a decision here.
        PlotFormat::PDF | PlotFormat::HTML => Err(ClusteringError::ComputationError(
            "Unsupported output format for plotters dendrogram".to_string(),
        )),
    }
}

/// Renders a dendrogram to `output_path` using plotters' bitmap backend.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_dendrogram_png<P: AsRef<Path>>(
    dendrogram_plot: &DendrogramPlot<f64>,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let root = BitMapBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    render_dendrogram(root, dendrogram_plot, output_config)
}

/// Renders a dendrogram to `output_path` using plotters' SVG backend.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_dendrogram_svg<P: AsRef<Path>>(
    dendrogram_plot: &DendrogramPlot<f64>,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let root = SVGBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    render_dendrogram(root, dendrogram_plot, output_config)
}

/// Draws a dendrogram onto `root` and presents it. This is shared by the PNG and SVG paths.
///
/// Uses `background_color` from `output_config`, falling back to white when it does
/// not parse. Also applies `show_grid`, `show_axes`, `title` and the x/y `axis_labels`.
/// Branch colours are taken from `dendrogram_plot.colors` by branch index. Missing
/// or unparsable entries are drawn in black.
#[cfg(feature = "plotters")]
fn render_dendrogram<DB: DrawingBackend>(
    root: DrawingArea<DB, plotters::coord::Shift>,
    dendrogram_plot: &DendrogramPlot<f64>,
    output_config: &PlotOutput,
) -> Result<()> {
    // Fraction of each axis span added as padding on both sides.
    const MARGIN: f64 = 0.1;
    // A zero-width span (a single leaf, or every merge at the same height) cannot
    // be mapped onto pixels. Degenerate spans therefore get a fixed unit of padding.
    let padded = |lo: f64, hi: f64| {
        let span = hi - lo;
        let pad = if span > 0.0 { MARGIN * span } else { 1.0 };
        (lo - pad)..(hi + pad)
    };

    let background = parsehex_color_plotters(&output_config.background_color).unwrap_or(WHITE);
    root.fill(&background).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to initialize plot: {}", e))
    })?;

    // (min_x, max_x, min_y, max_y)
    let bounds = dendrogram_plot.bounds;
    let mut chart = ChartBuilder::on(&root)
        .caption(
            output_config.title.as_deref().unwrap_or("Dendrogram"),
            ("sans-serif", 30),
        )
        .margin(20)
        .x_label_area_size(40)
        .y_label_area_size(50)
        .build_cartesian_2d(padded(bounds.0, bounds.1), padded(bounds.2, bounds.3))
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to build chart: {}", e)))?;

    // Scoped so the mesh's mutable borrow of `chart` ends before series are drawn.
    {
        let mut mesh = chart.configure_mesh();
        mesh.x_desc(
            output_config
                .axis_labels
                .0
                .as_deref()
                .unwrap_or("Sample Index"),
        )
        .y_desc(output_config.axis_labels.1.as_deref().unwrap_or("Distance"));
        if !output_config.show_grid {
            mesh.disable_mesh();
        }
        if !output_config.show_axes {
            mesh.disable_axes();
        }
        mesh.draw()
            .map_err(|e| ClusteringError::ComputationError(format!("Failed to draw mesh: {}", e)))?;
    }

    // Branches: one coloured line segment each.
    chart
        .draw_series(dendrogram_plot.branches.iter().enumerate().map(|(i, branch)| {
            let color = dendrogram_plot
                .colors
                .get(i)
                .and_then(|hex| parsehex_color_plotters(hex))
                .unwrap_or(BLACK);
            PathElement::new(
                vec![
                    (branch.start.0, branch.start.1),
                    (branch.end.0, branch.end.1),
                ],
                color.stroke_width(2),
            )
        }))
        .map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to draw branches: {}", e))
        })?;

    // Leaf labels.
    chart
        .draw_series(dendrogram_plot.leaves.iter().map(|leaf| {
            Text::new(
                leaf.label.clone(),
                (leaf.position.0, leaf.position.1),
                ("sans-serif", 12).into_font().color(&BLACK),
            )
        }))
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to draw labels: {}", e)))?;

    root.present()
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to save plot: {}", e)))?;

    Ok(())
}
253
/// Native 2D scatter plot using plotters.
///
/// Dispatches on `output_config.format`: `PNG` is rendered through plotters'
/// `BitMapBackend` and `SVG` through `SVGBackend`. Other formats are rejected.
///
/// # Errors
///
/// Returns [`ClusteringError::ComputationError`] in three cases: an unsupported
/// format; inconsistent plot data (too few columns, or fewer colours/sizes than
/// points); or a plotters drawing or writing failure.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
pub fn plot_scatter_2d<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot2D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let path = output_path.as_ref();

    match output_config.format {
        PlotFormat::PNG => plot_scatter_2d_png(scatter_plot, path, output_config),
        PlotFormat::SVG => plot_scatter_2d_svg(scatter_plot, path, output_config),
        // Listed explicitly (no `_`) so a new `PlotFormat` variant forces a decision here.
        PlotFormat::PDF | PlotFormat::HTML => Err(ClusteringError::ComputationError(
            "Unsupported output format for plotters backend".to_string(),
        )),
    }
}

/// Renders a 2D scatter plot to `output_path` using plotters' bitmap backend.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_scatter_2d_png<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot2D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let root = BitMapBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    render_scatter_2d(root, scatter_plot, output_config)
}

/// Renders a 2D scatter plot to `output_path` using plotters' SVG backend.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_scatter_2d_svg<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot2D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let root = SVGBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    render_scatter_2d(root, scatter_plot, output_config)
}

/// Draws a 2D scatter plot onto `root` and presents it. This is shared by the PNG and SVG paths.
///
/// Each point is a filled circle, coloured from `colors[i]` (red when the hex
/// does not parse) and sized from `sizes[i]`. Centroids, when present, are black
/// crosses. Uses `background_color` (white fallback), `show_grid`, `show_axes`,
/// `title` and the x/y `axis_labels` from `output_config`.
#[cfg(feature = "plotters")]
fn render_scatter_2d<DB: DrawingBackend>(
    root: DrawingArea<DB, plotters::coord::Shift>,
    scatter_plot: &ScatterPlot2D,
    output_config: &PlotOutput,
) -> Result<()> {
    // Fraction of each axis span added as padding on both sides.
    const MARGIN: f64 = 0.1;
    // A zero-width span (e.g. one point, or all points on a line) cannot be
    // mapped onto pixels. Degenerate spans therefore get a fixed unit of padding.
    let padded = |lo: f64, hi: f64| {
        let span = hi - lo;
        let pad = if span > 0.0 { MARGIN * span } else { 1.0 };
        (lo - pad)..(hi + pad)
    };

    // Validate up front so inconsistent input yields an error rather than an index panic.
    let n_points = scatter_plot.points.nrows();
    if n_points > 0 && scatter_plot.points.ncols() < 2 {
        return Err(ClusteringError::ComputationError(
            "2D scatter plot points must have at least 2 columns".to_string(),
        ));
    }
    if scatter_plot.colors.len() < n_points || scatter_plot.sizes.len() < n_points {
        return Err(ClusteringError::ComputationError(format!(
            "2D scatter plot has {} points but {} colors and {} sizes",
            n_points,
            scatter_plot.colors.len(),
            scatter_plot.sizes.len()
        )));
    }
    if let Some(centroids) = &scatter_plot.centroids {
        if centroids.nrows() > 0 && centroids.ncols() < 2 {
            return Err(ClusteringError::ComputationError(
                "2D scatter plot centroids must have at least 2 columns".to_string(),
            ));
        }
    }

    let background = parsehex_color_plotters(&output_config.background_color).unwrap_or(WHITE);
    root.fill(&background).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to initialize plot: {}", e))
    })?;

    let (min_x, max_x, min_y, max_y) = scatter_plot.bounds;
    let mut chart = ChartBuilder::on(&root)
        .caption(
            output_config
                .title
                .as_deref()
                .unwrap_or("Cluster Visualization"),
            ("sans-serif", 30),
        )
        .margin(20)
        .x_label_area_size(40)
        .y_label_area_size(50)
        .build_cartesian_2d(padded(min_x, max_x), padded(min_y, max_y))
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to build chart: {}", e)))?;

    // Scoped so the mesh's mutable borrow of `chart` ends before series are drawn.
    {
        let mut mesh = chart.configure_mesh();
        mesh.x_desc(output_config.axis_labels.0.as_deref().unwrap_or("X"))
            .y_desc(output_config.axis_labels.1.as_deref().unwrap_or("Y"));
        if !output_config.show_grid {
            mesh.disable_mesh();
        }
        if !output_config.show_axes {
            mesh.disable_axes();
        }
        mesh.draw()
            .map_err(|e| ClusteringError::ComputationError(format!("Failed to draw mesh: {}", e)))?;
    }

    // Points (indexing below is safe after the length validation above).
    chart
        .draw_series(
            scatter_plot
                .points
                .rows()
                .into_iter()
                .enumerate()
                .map(|(i, point)| {
                    let color = parsehex_color_plotters(&scatter_plot.colors[i]).unwrap_or(RED);
                    Circle::new(
                        (point[0], point[1]),
                        scatter_plot.sizes[i] as i32,
                        color.filled(),
                    )
                }),
        )
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to draw points: {}", e)))?;

    // Centroids, if available.
    if let Some(centroids) = &scatter_plot.centroids {
        chart
            .draw_series(centroids.rows().into_iter().map(|centroid| {
                Cross::new((centroid[0], centroid[1]), 8, BLACK.stroke_width(3))
            }))
            .map_err(|e| {
                ClusteringError::ComputationError(format!("Failed to draw centroids: {}", e))
            })?;
    }

    root.present()
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to save plot: {}", e)))?;

    Ok(())
}
438
/// Interactive clustering visualization using egui.
///
/// Holds the data shown by the `eframe::App` implementation together with the
/// view state: display toggles, zoom/pan and the highlighted cluster.
#[cfg(feature = "egui")]
pub struct InteractiveClusteringApp {
    /// Scatter plot currently displayed, if any.
    pub scatter_plot_2d: Option<ScatterPlot2D>,
    /// Visualization configuration.
    pub config: VisualizationConfig,
    /// Whether centroid markers are drawn.
    pub show_centroids: bool,
    /// State of the "Show Boundaries" toggle.
    pub show_boundaries: bool,
    /// State of the "Show Legend" toggle.
    pub show_legend: bool,
    /// Zoom factor applied to projected coordinates and marker sizes.
    pub zoom: f32,
    /// Screen-space pan offset (egui units), accumulated from drag gestures.
    pub pan_offset: (f32, f32),
    /// Cluster to highlight; points of other clusters are dimmed while set.
    pub selected_cluster: Option<i32>,
}

#[cfg(feature = "egui")]
impl Default for InteractiveClusteringApp {
    /// An app with no data and a default visualization config. Centroids and the
    /// legend are shown, boundaries are hidden, zoom is 1, there is no pan and no
    /// cluster is selected.
    fn default() -> Self {
        InteractiveClusteringApp {
            scatter_plot_2d: None,
            config: VisualizationConfig::default(),
            show_centroids: true,
            show_boundaries: false,
            show_legend: true,
            zoom: 1.0,
            pan_offset: (0.0, 0.0),
            selected_cluster: None,
        }
    }
}

#[cfg(feature = "egui")]
impl InteractiveClusteringApp {
    /// Creates an app that displays `scatter_plot`, with all other state at its defaults.
    pub fn new(scatter_plot: ScatterPlot2D) -> Self {
        let mut app = Self::default();
        app.set_data(scatter_plot);
        app
    }

    /// Replaces the displayed scatter plot, leaving zoom, pan and selection untouched.
    pub fn set_data(&mut self, plot: ScatterPlot2D) {
        self.scatter_plot_2d = Some(plot);
    }
}
488
#[cfg(feature = "egui")]
impl eframe::App for InteractiveClusteringApp {
    /// Builds one frame from two panels. The left panel holds the controls: the
    /// display toggles, zoom slider, view reset and the cluster legend. The
    /// central panel holds the scatter plot.
    fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
        egui::SidePanel::left("controls").show(ctx, |ui| {
            ui.heading("Clustering Visualization");
            ui.separator();

            ui.checkbox(&mut self.show_centroids, "Show Centroids");
            ui.checkbox(&mut self.show_boundaries, "Show Boundaries");
            ui.checkbox(&mut self.show_legend, "Show Legend");

            ui.separator();
            ui.label("Zoom:");
            ui.add(egui::Slider::new(&mut self.zoom, 0.1..=5.0));

            if ui.button("Reset View").clicked() {
                self.zoom = 1.0;
                self.pan_offset = (0.0, 0.0);
            }

            ui.separator();
            // The per-cluster list is the legend, so it follows the "Show Legend" toggle.
            if self.show_legend {
                if let Some(ref plot) = self.scatter_plot_2d {
                    ui.label("Cluster Information:");
                    for legend_entry in &plot.legend {
                        let color = parsehex_color(&legend_entry.color).unwrap_or([255, 0, 0]);
                        let color32 = Color32::from_rgb(color[0], color[1], color[2]);

                        ui.horizontal(|ui| {
                            ui.colored_label(color32, "●");
                            if ui
                                .selectable_label(
                                    self.selected_cluster == Some(legend_entry.cluster_id),
                                    format!(
                                        "Cluster {} ({} points)",
                                        legend_entry.cluster_id, legend_entry.count
                                    ),
                                )
                                .clicked()
                            {
                                // Clicking the selected cluster again clears the selection.
                                self.selected_cluster =
                                    if self.selected_cluster == Some(legend_entry.cluster_id) {
                                        None
                                    } else {
                                        Some(legend_entry.cluster_id)
                                    };
                            }
                        });
                    }
                }
            }
        });

        egui::CentralPanel::default().show(ctx, |ui| {
            // Move the plot out temporarily so `draw_scatterplot` can borrow `self`
            // mutably, instead of cloning the whole dataset on every frame.
            match self.scatter_plot_2d.take() {
                Some(plot) => {
                    self.draw_scatterplot(ui, &plot);
                    self.scatter_plot_2d = Some(plot);
                }
                None => {
                    ui.centered_and_justified(|ui| {
                        ui.label("No clustering data available");
                    });
                }
            }
        });
    }
}
551
#[cfg(feature = "egui")]
impl InteractiveClusteringApp {
    /// Paints `plot` into all remaining space of `ui` and applies drag-to-pan.
    ///
    /// Data coordinates are first normalised against `plot.bounds`. They are then
    /// scaled by `self.zoom`, shifted by `self.pan_offset` and flipped vertically,
    /// so that larger `y` values appear higher on screen. Points and centroids
    /// projected outside the allocated rectangle are skipped. While a cluster is
    /// selected, points of all other clusters are drawn semi-transparent.
    fn draw_scatterplot(&mut self, ui: &mut Ui, plot: &ScatterPlot2D) {
        let (response, painter) = ui.allocate_painter(ui.available_size(), Sense::drag());
        let rect = response.rect;
        let (min_x, max_x, min_y, max_y) = plot.bounds;

        // Dragging pans the view.
        if response.dragged() {
            let delta = response.drag_delta();
            self.pan_offset.0 += delta.x;
            self.pan_offset.1 += delta.y;
        }

        // Snapshot the view state so the projection closures do not borrow `self`.
        let zoom = self.zoom;
        let pan = self.pan_offset;

        // Maps `v` into [0, 1] over [lo, hi] (expects lo <= hi). A zero-width range,
        // with all data on one line, would divide by zero. The resulting NaN positions
        // fail `rect.contains` and every point would disappear, so such values are
        // centred instead.
        let normalize = |v: f64, lo: f64, hi: f64| -> f32 {
            if hi > lo {
                ((v - lo) / (hi - lo)) as f32
            } else {
                0.5
            }
        };

        // Convert data coordinates to screen coordinates.
        let to_screen = |x: f64, y: f64| -> Pos2 {
            Pos2::new(
                rect.left() + normalize(x, min_x, max_x) * rect.width() * zoom + pan.0,
                rect.bottom() - normalize(y, min_y, max_y) * rect.height() * zoom + pan.1,
            )
        };

        // Draw points.
        for (i, point) in plot.points.rows().into_iter().enumerate() {
            let screen_pos = to_screen(point[0], point[1]);
            if !rect.contains(screen_pos) {
                continue; // Outside the visible area.
            }

            // Skip points with missing per-point metadata rather than panicking
            // inside the UI loop on an inconsistent plot.
            let (hex, size, cluster_id) =
                match (plot.colors.get(i), plot.sizes.get(i), plot.labels.get(i)) {
                    (Some(hex), Some(&size), Some(&cluster_id)) => (hex, size, cluster_id),
                    _ => continue,
                };

            let [r, g, b] = parsehex_color(hex).unwrap_or([255, 0, 0]);
            // Dim every cluster except the selected one.
            let alpha = match self.selected_cluster {
                Some(selected) if selected != cluster_id => 100,
                _ => 255,
            };
            // Unmultiplied: egui scales RGB by alpha itself, so the dimmed colour is a
            // faded version of the cluster colour. Full-intensity RGB given to the
            // premultiplied constructor with a low alpha renders over-bright instead.
            let fill = Color32::from_rgba_unmultiplied(r, g, b, alpha);

            painter.circle_filled(screen_pos, size * zoom, fill);
        }

        // Draw centroids as a black ring with a cross.
        if self.show_centroids {
            if let Some(centroids) = &plot.centroids {
                let stroke = Stroke::new(3.0, Color32::BLACK);
                let arm = 6.0 * zoom;
                for centroid in centroids.rows() {
                    let center = to_screen(centroid[0], centroid[1]);
                    if !rect.contains(center) {
                        continue;
                    }
                    painter.circle_stroke(center, 8.0 * zoom, stroke);
                    painter.line_segment(
                        [
                            Pos2::new(center.x - arm, center.y),
                            Pos2::new(center.x + arm, center.y),
                        ],
                        stroke,
                    );
                    painter.line_segment(
                        [
                            Pos2::new(center.x, center.y - arm),
                            Pos2::new(center.x, center.y + arm),
                        ],
                        stroke,
                    );
                }
            }
        }
    }
}
644
/// Parses a `#RRGGBB` hex colour string into `[r, g, b]` bytes.
///
/// Returns `None` unless the input is exactly `#` followed by six ASCII hex
/// digits (either case). Non-ASCII input and sign characters are rejected,
/// so the function neither panics nor mis-parses them.
// The callers visible in this file are all behind the `plotters`/`egui`
// features, so this is dead code when neither feature is enabled.
#[allow(dead_code)]
fn parsehex_color(hex: &str) -> Option<[u8; 3]> {
    let digits = hex.strip_prefix('#')?;
    // Check every byte first. This guarantees that the byte-offset slices below
    // land on char boundaries, where multi-byte input would otherwise panic. It
    // also rejects the leading `+` that `from_str_radix` would otherwise accept.
    if digits.len() != 6 || !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }

    let channel = |start: usize| u8::from_str_radix(&digits[start..start + 2], 16).ok();
    Some([channel(0)?, channel(2)?, channel(4)?])
}
658
/// Converts a `#RRGGBB` hex string into a plotters `RGBColor`.
///
/// Yields `None` whenever `parsehex_color` rejects the input.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn parsehex_color_plotters(hex: &str) -> Option<RGBColor> {
    parsehex_color(hex).map(|[red, green, blue]| RGBColor(red, green, blue))
}
665
666/// High-level function to create and save a dendrogram plot
667#[allow(dead_code)]
668pub fn save_dendrogram_plot<P: AsRef<Path>>(
669    linkage_matrix: ArrayView2<f64>,
670    labels: Option<&[String]>,
671    output_path: P,
672    dendrogram_config: Option<&DendrogramConfig<f64>>,
673    output_config: Option<&PlotOutput>,
674) -> Result<()> {
675    let default_dend_config = DendrogramConfig::default();
676    let default_out_config = PlotOutput::default();
677    let dend_config = dendrogram_config.unwrap_or(&default_dend_config);
678    let out_config = output_config.unwrap_or(&default_out_config);
679
680    // Create dendrogram plot data
681    let dendrogram_plot = create_dendrogramplot(linkage_matrix, labels, dend_config.clone())?;
682
683    #[cfg(feature = "plotters")]
684    {
685        plot_dendrogram(&dendrogram_plot, output_path, out_config)?;
686    }
687
688    #[cfg(not(feature = "plotters"))]
689    {
690        return Err(ClusteringError::ComputationError(
691            "Plotters feature not enabled. Enable with --features plotters".to_string(),
692        ));
693    }
694
695    Ok(())
696}
697
/// Native 3D scatter plot using plotters.
///
/// Dispatches on `output_config.format`: `PNG` is rendered through plotters'
/// `BitMapBackend` and `SVG` through `SVGBackend`. Other formats are rejected.
///
/// # Errors
///
/// Returns [`ClusteringError::ComputationError`] in three cases: an unsupported
/// format; inconsistent plot data (too few columns, or fewer colours/sizes than
/// points); or a plotters drawing or writing failure.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
pub fn plot_scatter_3d<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot3D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let path = output_path.as_ref();

    match output_config.format {
        PlotFormat::PNG => plot_scatter_3d_png(scatter_plot, path, output_config),
        PlotFormat::SVG => plot_scatter_3d_svg(scatter_plot, path, output_config),
        // Listed explicitly (no `_`) so a new `PlotFormat` variant forces a decision here.
        PlotFormat::PDF | PlotFormat::HTML => Err(ClusteringError::ComputationError(
            "Unsupported output format for 3D plotters backend".to_string(),
        )),
    }
}

/// Renders a 3D scatter plot to `output_path` using plotters' bitmap backend.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_scatter_3d_png<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot3D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let root = BitMapBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    render_scatter_3d(root, scatter_plot, output_config)
}

/// Renders a 3D scatter plot to `output_path` using plotters' SVG backend.
#[cfg(feature = "plotters")]
#[allow(dead_code)]
fn plot_scatter_3d_svg<P: AsRef<Path>>(
    scatter_plot: &ScatterPlot3D,
    output_path: P,
    output_config: &PlotOutput,
) -> Result<()> {
    let root = SVGBackend::new(&output_path, output_config.dimensions).into_drawing_area();
    render_scatter_3d(root, scatter_plot, output_config)
}

/// Draws a 3D scatter plot onto `root` and presents it. This is shared by the PNG and SVG paths.
///
/// Each point is a filled circle, coloured from `colors[i]` (red when the hex
/// does not parse) and sized from `sizes[i]`. Centroids, when present, are black
/// outlined circles. Uses `background_color` (white fallback), `title` and
/// `show_axes` from `output_config`.
#[cfg(feature = "plotters")]
fn render_scatter_3d<DB: DrawingBackend>(
    root: DrawingArea<DB, plotters::coord::Shift>,
    scatter_plot: &ScatterPlot3D,
    output_config: &PlotOutput,
) -> Result<()> {
    // Fraction of each axis span added as padding on both sides.
    const MARGIN: f64 = 0.1;
    // A zero-width span (e.g. one point, or coplanar data) cannot be mapped onto
    // pixels. Degenerate spans therefore get a fixed unit of padding.
    let padded = |lo: f64, hi: f64| {
        let span = hi - lo;
        let pad = if span > 0.0 { MARGIN * span } else { 1.0 };
        (lo - pad)..(hi + pad)
    };

    // Validate up front so inconsistent input yields an error rather than an index panic.
    let n_points = scatter_plot.points.nrows();
    if n_points > 0 && scatter_plot.points.ncols() < 3 {
        return Err(ClusteringError::ComputationError(
            "3D scatter plot points must have at least 3 columns".to_string(),
        ));
    }
    if scatter_plot.colors.len() < n_points || scatter_plot.sizes.len() < n_points {
        return Err(ClusteringError::ComputationError(format!(
            "3D scatter plot has {} points but {} colors and {} sizes",
            n_points,
            scatter_plot.colors.len(),
            scatter_plot.sizes.len()
        )));
    }
    if let Some(centroids) = &scatter_plot.centroids {
        if centroids.nrows() > 0 && centroids.ncols() < 3 {
            return Err(ClusteringError::ComputationError(
                "3D scatter plot centroids must have at least 3 columns".to_string(),
            ));
        }
    }

    let background = parsehex_color_plotters(&output_config.background_color).unwrap_or(WHITE);
    root.fill(&background).map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to initialize plot: {}", e))
    })?;

    let (min_x, max_x, min_y, max_y, min_z, max_z) = scatter_plot.bounds;
    let mut chart = ChartBuilder::on(&root)
        .caption(
            output_config
                .title
                .as_deref()
                .unwrap_or("3D Cluster Visualization"),
            ("sans-serif", 30),
        )
        .margin(20)
        .build_cartesian_3d(
            padded(min_x, max_x),
            padded(min_y, max_y),
            padded(min_z, max_z),
        )
        .map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to build 3D chart: {}", e))
        })?;

    // The 3D grid is configured as part of the axis panel, so `show_axes` gates both.
    // NOTE(review): `show_grid` is not consulted separately for 3D plots.
    if output_config.show_axes {
        chart
            .configure_axes()
            .light_grid_style(BLUE.mix(0.15))
            .max_light_lines(4)
            .draw()
            .map_err(|e| {
                ClusteringError::ComputationError(format!("Failed to draw axes: {}", e))
            })?;
    }

    // Points (indexing below is safe after the length validation above).
    chart
        .draw_series(
            scatter_plot
                .points
                .rows()
                .into_iter()
                .enumerate()
                .map(|(i, point)| {
                    let color = parsehex_color_plotters(&scatter_plot.colors[i]).unwrap_or(RED);
                    Circle::new(
                        (point[0], point[1], point[2]),
                        scatter_plot.sizes[i] as i32,
                        color.filled(),
                    )
                }),
        )
        .map_err(|e| {
            ClusteringError::ComputationError(format!("Failed to draw 3D points: {}", e))
        })?;

    // Centroids, if available.
    if let Some(centroids) = &scatter_plot.centroids {
        chart
            .draw_series(centroids.rows().into_iter().map(|centroid| {
                Circle::new(
                    (centroid[0], centroid[1], centroid[2]),
                    8,
                    BLACK.stroke_width(3),
                )
            }))
            .map_err(|e| {
                ClusteringError::ComputationError(format!("Failed to draw 3D centroids: {}", e))
            })?;
    }

    root.present()
        .map_err(|e| ClusteringError::ComputationError(format!("Failed to save 3D plot: {}", e)))?;

    Ok(())
}
908
909/// High-level function to create and save a clustering plot
910#[allow(dead_code)]
911pub fn save_clustering_plot<P: AsRef<Path>>(
912    data: ArrayView2<f64>,
913    labels: &Array1<i32>,
914    centroids: Option<&Array2<f64>>,
915    output_path: P,
916    config: Option<&VisualizationConfig>,
917    output_config: Option<&PlotOutput>,
918) -> Result<()> {
919    let default_vis_config = VisualizationConfig::default();
920    let default_out_config = PlotOutput::default();
921    let vis_config = config.unwrap_or(&default_vis_config);
922    let out_config = output_config.unwrap_or(&default_out_config);
923
924    // Create scatter plot data
925    let scatter_plot =
926        crate::visualization::create_scatter_plot_2d(data, labels, centroids, vis_config)?;
927
928    #[cfg(feature = "plotters")]
929    {
930        plot_scatter_2d(&scatter_plot, output_path, out_config)?;
931    }
932
933    #[cfg(not(feature = "plotters"))]
934    {
935        return Err(ClusteringError::ComputationError(
936            "Plotters feature not enabled. Enable with --features plotters".to_string(),
937        ));
938    }
939
940    Ok(())
941}
942
943/// High-level function to create and save a 3D clustering plot
944#[allow(dead_code)]
945pub fn save_clustering_plot_3d<P: AsRef<Path>>(
946    data: ArrayView2<f64>,
947    labels: &Array1<i32>,
948    centroids: Option<&Array2<f64>>,
949    output_path: P,
950    config: Option<&VisualizationConfig>,
951    output_config: Option<&PlotOutput>,
952) -> Result<()> {
953    let default_vis_config = VisualizationConfig::default();
954    let default_out_config = PlotOutput::default();
955    let vis_config = config.unwrap_or(&default_vis_config);
956    let out_config = output_config.unwrap_or(&default_out_config);
957
958    // Create 3D scatter plot data
959    let scatter_plot =
960        crate::visualization::create_scatter_plot_3d(data, labels, centroids, vis_config)?;
961
962    #[cfg(feature = "plotters")]
963    {
964        plot_scatter_3d(&scatter_plot, output_path, out_config)?;
965    }
966
967    #[cfg(not(feature = "plotters"))]
968    {
969        return Err(ClusteringError::ComputationError(
970            "Plotters feature not enabled. Enable with --features plotters".to_string(),
971        ));
972    }
973
974    Ok(())
975}
976
/// Launch interactive clustering visualization.
///
/// Turns the inputs into 2D scatter-plot data (via
/// `crate::visualization::create_scatter_plot_2d`) and hands it to an
/// `InteractiveClusteringApp` running in a native eframe window of 1200x800
/// titled "Clustering Visualization". Compiled only with the `egui` feature.
///
/// # Errors
///
/// * Propagates any error from `create_scatter_plot_2d`.
/// * Maps an eframe start-up failure to `ClusteringError::ComputationError`.
#[cfg(feature = "egui")]
#[allow(dead_code)]
pub fn launch_interactive_visualization(
    data: ArrayView2<f64>,
    labels: &Array1<i32>,
    centroids: Option<&Array2<f64>>,
    config: Option<&VisualizationConfig>,
) -> Result<()> {
    // Fall back to the stock visualization settings when the caller passes none.
    let fallback_config = VisualizationConfig::default();
    let effective_config = match config {
        Some(provided) => provided,
        None => &fallback_config,
    };

    let plot_data =
        crate::visualization::create_scatter_plot_2d(data, labels, centroids, effective_config)?;

    let window = egui::ViewportBuilder::default()
        .with_inner_size([1200.0, 800.0])
        .with_title("Clustering Visualization");
    let native_options = eframe::NativeOptions {
        viewport: window,
        ..Default::default()
    };

    let app = InteractiveClusteringApp::new(plot_data);

    let launch_result = eframe::run_native(
        "Clustering Visualization",
        native_options,
        Box::new(|_| Ok::<_, Box<dyn std::error::Error + Send + Sync>>(Box::new(app))),
    );

    launch_result.map_err(|e| {
        ClusteringError::ComputationError(format!("Failed to launch visualization: {}", e))
    })?;

    Ok(())
}
1013
1014#[cfg(not(feature = "egui"))]
1015#[allow(dead_code)]
1016pub fn launch_interactive_visualization(
1017    _data: ArrayView2<f64>,
1018    _labels: &Array1<i32>,
1019    _centroids: Option<&Array2<f64>>,
1020    _config: Option<&VisualizationConfig>,
1021) -> Result<()> {
1022    Err(ClusteringError::ComputationError(
1023        "Interactive visualization requires egui feature. Enable with --features egui".to_string(),
1024    ))
1025}
1026
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn testhex_color_parsing() {
        // Well-formed "#RRGGBB" strings decode into [r, g, b] components.
        assert_eq!(parsehex_color("#FF0000"), Some([255, 0, 0]));
        assert_eq!(parsehex_color("#00FF00"), Some([0, 255, 0]));
        assert_eq!(parsehex_color("#0000FF"), Some([0, 0, 255]));
        // Malformed strings yield None.
        assert_eq!(parsehex_color("FF0000"), None); // Missing #
        assert_eq!(parsehex_color("#FG0000"), None); // Invalid hex
    }

    #[test]
    fn test_plot_output_default() {
        let output = PlotOutput::default();
        assert_eq!(output.format, PlotFormat::PNG);
        assert_eq!(output.dimensions, (800, 600));
        assert_eq!(output.dpi, 300);
        // Pin the remaining defaults so accidental changes are caught.
        assert_eq!(output.background_color, "#FFFFFF");
        assert!(output.show_grid);
        assert!(output.show_axes);
        assert!(output.title.is_none());
        assert_eq!(output.axis_labels, (None, None, None));
    }
}