scirs2_cluster/hierarchy/visualization/
plotting.rs

1//! Dendrogram plotting functionality
2//!
3//! This module contains the core functionality for creating and positioning
4//! dendrogram plots from linkage matrices.
5
6use scirs2_core::ndarray::ArrayView2;
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::collections::HashMap;
9use std::fmt::Debug;
10
11use super::types::*;
12use crate::error::{ClusteringError, Result};
13
14/// Internal tree node structure for dendrogram construction
15#[derive(Debug, Clone)]
16struct TreeNode<F: Float> {
17    /// Node ID (sample index for leaves, or cluster index for internal nodes)
18    id: usize,
19    /// Height of this node
20    height: F,
21    /// Left child (None for leaves)
22    left: Option<Box<TreeNode<F>>>,
23    /// Right child (None for leaves)
24    right: Option<Box<TreeNode<F>>>,
25    /// Number of leaves under this node
26    leaf_count: usize,
27}
28
29impl<F: Float + std::fmt::Display> TreeNode<F> {
30    /// Create a new leaf node
31    fn new_leaf(id: usize) -> Self {
32        Self {
33            id,
34            height: F::zero(),
35            left: None,
36            right: None,
37            leaf_count: 1,
38        }
39    }
40
41    /// Create a new internal node
42    fn new_internal(id: usize, height: F, left: TreeNode<F>, right: TreeNode<F>) -> Self {
43        let leaf_count = left.leaf_count + right.leaf_count;
44        Self {
45            id,
46            height,
47            left: Some(Box::new(left)),
48            right: Some(Box::new(right)),
49            leaf_count,
50        }
51    }
52
53    /// Check if this node is a leaf
54    fn is_leaf(&self) -> bool {
55        self.left.is_none() && self.right.is_none()
56    }
57}
58
59/// Create an enhanced dendrogram plot from a linkage matrix
60///
61/// This function takes a linkage matrix (as produced by hierarchical clustering)
62/// and creates a comprehensive dendrogram visualization with advanced styling options.
63///
64/// # Arguments
65/// * `linkage_matrix` - The linkage matrix from hierarchical clustering
66/// * `labels` - Optional labels for the leaf nodes
67/// * `config` - Configuration options for the dendrogram
68///
69/// # Returns
70/// * `Result<DendrogramPlot<F>>` - The complete dendrogram plot structure
71///
72/// # Example
73/// ```rust
74/// use scirs2_core::ndarray::Array2;
75/// use scirs2_cluster::hierarchy::visualization::{create_dendrogramplot, DendrogramConfig};
76///
77/// let linkage = Array2::from_shape_vec((3, 4), vec![
78///     0.0, 1.0, 0.1, 2.0,
79///     2.0, 3.0, 0.2, 2.0,
80///     4.0, 5.0, 0.3, 4.0,
81/// ]).unwrap();
82/// let labels = Some(vec!["A".to_string(), "B".to_string(), "C".to_string(), "D".to_string()]);
83/// let config = DendrogramConfig::default();
84/// let plot = create_dendrogramplot(linkage.view(), labels.as_deref(), config).unwrap();
85/// ```
86pub fn create_dendrogramplot<F: Float + FromPrimitive + PartialOrd + Debug + std::fmt::Display>(
87    linkage_matrix: ArrayView2<F>,
88    labels: Option<&[String]>,
89    config: DendrogramConfig<F>,
90) -> Result<DendrogramPlot<F>> {
91    let n_samples = linkage_matrix.shape()[0] + 1;
92    if n_samples < 2 {
93        return Err(ClusteringError::InvalidInput(
94            "Need at least 2 samples to create dendrogram".into(),
95        ));
96    }
97
98    // Calculate color threshold if using automatic mode
99    let actual_threshold = if config.color_threshold.auto_threshold {
100        calculate_auto_threshold(linkage_matrix, config.color_threshold.target_clusters)?
101    } else {
102        config.color_threshold.threshold
103    };
104
105    // Build the dendrogram tree structure
106    let tree = build_dendrogram_tree(linkage_matrix)?;
107
108    // Calculate positions for nodes
109    let positions = calculate_node_positions(&tree, n_samples, config.orientation);
110
111    // Create branches
112    let branches = create_branches(&tree, &positions, actual_threshold, &config)?;
113
114    // Create leaves
115    let leaves = create_leaves(&positions, labels, n_samples, config.orientation);
116
117    // Assign colors to branches
118    let colors = assign_branch_colors(&branches, &config);
119
120    // Create legend
121    let legend = create_legend(&config, actual_threshold);
122
123    // Calculate plot bounds
124    let bounds = calculate_plot_bounds(&branches, &leaves);
125
126    Ok(DendrogramPlot {
127        branches,
128        leaves,
129        colors,
130        legend,
131        bounds,
132        config,
133    })
134}
135
136/// Build the tree structure from a linkage matrix
137fn build_dendrogram_tree<F: Float + FromPrimitive + Debug + std::fmt::Display>(
138    linkage_matrix: ArrayView2<F>,
139) -> Result<TreeNode<F>> {
140    let n_samples = linkage_matrix.shape()[0] + 1;
141    let mut nodes: HashMap<usize, TreeNode<F>> = HashMap::new();
142
143    // Create leaf nodes
144    for i in 0..n_samples {
145        nodes.insert(i, TreeNode::new_leaf(i));
146    }
147
148    // Create internal nodes from linkage matrix
149    for (i, row) in linkage_matrix.outer_iter().enumerate() {
150        let left_id = row[0].to_usize().unwrap();
151        let right_id = row[1].to_usize().unwrap();
152        let distance = row[2];
153
154        let left_node = nodes.remove(&left_id).ok_or_else(|| {
155            ClusteringError::InvalidInput(format!("Invalid left cluster ID: {}", left_id))
156        })?;
157
158        let right_node = nodes.remove(&right_id).ok_or_else(|| {
159            ClusteringError::InvalidInput(format!("Invalid right cluster ID: {}", right_id))
160        })?;
161
162        let internal_id = n_samples + i;
163        let internal_node = TreeNode::new_internal(internal_id, distance, left_node, right_node);
164
165        nodes.insert(internal_id, internal_node);
166    }
167
168    // Return the root node (should be the only remaining node)
169    let root_id = 2 * n_samples - 2;
170    nodes.remove(&root_id).ok_or_else(|| {
171        ClusteringError::ComputationError("Failed to construct dendrogram tree".to_string())
172    })
173}
174
175/// Calculate positions for all nodes in the dendrogram
176fn calculate_node_positions<F: Float + FromPrimitive + std::fmt::Display>(
177    root: &TreeNode<F>,
178    n_samples: usize,
179    orientation: DendrogramOrientation,
180) -> HashMap<usize, (F, F)> {
181    let mut positions = HashMap::new();
182    let mut leaf_counter = 0;
183
184    calculate_positions_recursive(root, &mut positions, &mut leaf_counter, orientation);
185    positions
186}
187
188/// Recursively calculate positions for nodes
189fn calculate_positions_recursive<F: Float + FromPrimitive + std::fmt::Display>(
190    node: &TreeNode<F>,
191    positions: &mut HashMap<usize, (F, F)>,
192    leaf_counter: &mut usize,
193    orientation: DendrogramOrientation,
194) -> F {
195    if node.is_leaf() {
196        let x_pos = F::from(*leaf_counter).unwrap();
197        let y_pos = F::zero();
198
199        let pos = match orientation {
200            DendrogramOrientation::Top => (x_pos, y_pos),
201            DendrogramOrientation::Bottom => (x_pos, -y_pos),
202            DendrogramOrientation::Left => (y_pos, x_pos),
203            DendrogramOrientation::Right => (-y_pos, x_pos),
204        };
205
206        positions.insert(node.id, pos);
207        *leaf_counter += 1;
208        x_pos
209    } else {
210        let left = node.left.as_ref().unwrap();
211        let right = node.right.as_ref().unwrap();
212
213        let left_x = calculate_positions_recursive(left, positions, leaf_counter, orientation);
214        let right_x = calculate_positions_recursive(right, positions, leaf_counter, orientation);
215
216        let x_pos = (left_x + right_x) / F::from(2).unwrap();
217        let y_pos = node.height;
218
219        let pos = match orientation {
220            DendrogramOrientation::Top => (x_pos, y_pos),
221            DendrogramOrientation::Bottom => (x_pos, -y_pos),
222            DendrogramOrientation::Left => (y_pos, x_pos),
223            DendrogramOrientation::Right => (-y_pos, x_pos),
224        };
225
226        positions.insert(node.id, pos);
227        x_pos
228    }
229}
230
231/// Create branch structures for visualization
232fn create_branches<F: Float + FromPrimitive + PartialOrd + std::fmt::Display>(
233    root: &TreeNode<F>,
234    positions: &HashMap<usize, (F, F)>,
235    threshold: F,
236    config: &DendrogramConfig<F>,
237) -> Result<Vec<Branch<F>>> {
238    let mut branches = Vec::new();
239    create_branches_recursive(root, positions, threshold, config, &mut branches)?;
240    Ok(branches)
241}
242
243/// Recursively create branches
244fn create_branches_recursive<F: Float + FromPrimitive + PartialOrd + std::fmt::Display>(
245    node: &TreeNode<F>,
246    positions: &HashMap<usize, (F, F)>,
247    threshold: F,
248    config: &DendrogramConfig<F>,
249    branches: &mut Vec<Branch<F>>,
250) -> Result<()> {
251    if !node.is_leaf() {
252        let left = node.left.as_ref().unwrap();
253        let right = node.right.as_ref().unwrap();
254
255        let node_pos = positions.get(&node.id).unwrap();
256        let left_pos = positions.get(&left.id).unwrap();
257        let right_pos = positions.get(&right.id).unwrap();
258
259        // Determine color based on threshold
260        let color = if node.height > threshold {
261            config.color_threshold.above_color.clone()
262        } else {
263            config.color_threshold.below_color.clone()
264        };
265
266        // Create horizontal line from left child to right child
267        let horizontal_branch = Branch {
268            start: *left_pos,
269            end: *right_pos,
270            distance: node.height,
271            cluster_id: Some(node.id),
272            color: color.clone(),
273            line_width: Some(config.line_width),
274        };
275        branches.push(horizontal_branch);
276
277        // Create vertical line from horizontal line to node
278        let mid_x = (left_pos.0 + right_pos.0) / F::from(2).unwrap();
279        let vertical_start = (mid_x, left_pos.1.max(right_pos.1));
280        let vertical_branch = Branch {
281            start: vertical_start,
282            end: *node_pos,
283            distance: node.height,
284            cluster_id: Some(node.id),
285            color,
286            line_width: Some(config.line_width),
287        };
288        branches.push(vertical_branch);
289
290        // Recursively process children
291        create_branches_recursive(left, positions, threshold, config, branches)?;
292        create_branches_recursive(right, positions, threshold, config, branches)?;
293    }
294
295    Ok(())
296}
297
298/// Create leaf representations
299fn create_leaves<F: Float + FromPrimitive>(
300    positions: &HashMap<usize, (F, F)>,
301    labels: Option<&[String]>,
302    n_samples: usize,
303    orientation: DendrogramOrientation,
304) -> Vec<Leaf> {
305    let mut leaves = Vec::new();
306
307    for i in 0..n_samples {
308        if let Some(pos) = positions.get(&i) {
309            let label = if let Some(labels) = labels {
310                labels
311                    .get(i)
312                    .cloned()
313                    .unwrap_or_else(|| format!("Sample {}", i))
314            } else {
315                format!("Sample {}", i)
316            };
317
318            let leaf = Leaf {
319                position: (pos.0.to_f64().unwrap(), pos.1.to_f64().unwrap()),
320                label,
321                color: "#333333".to_string(),
322                data_index: i,
323            };
324
325            leaves.push(leaf);
326        }
327    }
328
329    leaves
330}
331
332/// Assign colors to branches
333fn assign_branch_colors<F: Float>(
334    branches: &[Branch<F>],
335    config: &DendrogramConfig<F>,
336) -> Vec<String> {
337    branches.iter().map(|branch| branch.color.clone()).collect()
338}
339
340/// Create legend for the plot
341fn create_legend<F: Float + std::fmt::Display>(
342    config: &DendrogramConfig<F>,
343    threshold: F,
344) -> Vec<LegendEntry> {
345    if config.color_threshold.auto_threshold || config.color_threshold.threshold > F::zero() {
346        vec![
347            LegendEntry {
348                color: config.color_threshold.above_color.clone(),
349                label: format!("Distance > {}", threshold),
350                threshold: Some(threshold.to_f64().unwrap_or(0.0)),
351            },
352            LegendEntry {
353                color: config.color_threshold.below_color.clone(),
354                label: format!("Distance d {}", threshold),
355                threshold: Some(threshold.to_f64().unwrap_or(0.0)),
356            },
357        ]
358    } else {
359        Vec::new()
360    }
361}
362
363/// Calculate plot bounds
364fn calculate_plot_bounds<F: Float>(branches: &[Branch<F>], leaves: &[Leaf]) -> (F, F, F, F) {
365    if branches.is_empty() && leaves.is_empty() {
366        return (F::zero(), F::zero(), F::zero(), F::zero());
367    }
368
369    let mut min_x = F::infinity();
370    let mut max_x = F::neg_infinity();
371    let mut min_y = F::infinity();
372    let mut max_y = F::neg_infinity();
373
374    // Consider branch bounds
375    for branch in branches {
376        min_x = min_x.min(branch.start.0).min(branch.end.0);
377        max_x = max_x.max(branch.start.0).max(branch.end.0);
378        min_y = min_y.min(branch.start.1).min(branch.end.1);
379        max_y = max_y.max(branch.start.1).max(branch.end.1);
380    }
381
382    // Consider leaf bounds
383    for leaf in leaves {
384        let leaf_x = F::from(leaf.position.0).unwrap();
385        let leaf_y = F::from(leaf.position.1).unwrap();
386        min_x = min_x.min(leaf_x);
387        max_x = max_x.max(leaf_x);
388        min_y = min_y.min(leaf_y);
389        max_y = max_y.max(leaf_y);
390    }
391
392    (min_x, max_x, min_y, max_y)
393}
394
395/// Calculate automatic threshold based on desired number of clusters
396fn calculate_auto_threshold<F: Float + FromPrimitive + PartialOrd>(
397    linkage_matrix: ArrayView2<F>,
398    target_clusters: Option<usize>,
399) -> Result<F> {
400    let target = target_clusters.unwrap_or(4);
401    let n_merges = linkage_matrix.shape()[0];
402
403    if target > n_merges {
404        return Ok(F::zero());
405    }
406
407    // Get the distance at which we have the target number of clusters
408    let merge_index = n_merges - target;
409    let threshold = linkage_matrix[[merge_index, 2]];
410
411    Ok(threshold)
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Linkage for four samples: (0, 1) merge at 0.1, (2, 3) merge at 0.2,
    /// and the two resulting clusters (4, 5) merge at 0.3.
    fn four_sample_linkage() -> Array2<f64> {
        let values = vec![0.0, 1.0, 0.1, 2.0, 2.0, 3.0, 0.2, 2.0, 4.0, 5.0, 0.3, 4.0];
        Array2::from_shape_vec((3, 4), values).unwrap()
    }

    /// Leaves and internal nodes report the expected ids and leaf counts.
    #[test]
    fn test_tree_node_creation() {
        let single = TreeNode::<f64>::new_leaf(0);
        assert!(single.is_leaf());
        assert_eq!(single.id, 0);
        assert_eq!(single.leaf_count, 1);

        let parent = TreeNode::new_internal(2, 0.5, TreeNode::new_leaf(0), TreeNode::new_leaf(1));
        assert!(!parent.is_leaf());
        assert_eq!(parent.leaf_count, 2);
    }

    /// A three-merge linkage yields an internal root spanning four leaves.
    #[test]
    fn test_build_dendrogram_tree() {
        let linkage = four_sample_linkage();

        let root = build_dendrogram_tree(linkage.view()).unwrap();
        assert!(!root.is_leaf());
        assert_eq!(root.leaf_count, 4);
    }

    /// Asking for two clusters selects the distance of the second merge.
    #[test]
    fn test_calculate_auto_threshold() {
        let linkage = four_sample_linkage();

        let threshold = calculate_auto_threshold(linkage.view(), Some(2)).unwrap();
        assert!((threshold - 0.2).abs() < 1e-10);
    }

    /// The smallest valid input (one merge, two samples) produces a plot.
    #[test]
    fn test_create_dendrogramplot() {
        let linkage = Array2::from_shape_vec((1, 4), vec![0.0, 1.0, 0.1, 2.0]).unwrap();
        let labels = Some(vec!["A".to_string(), "B".to_string()]);

        let plot =
            create_dendrogramplot(linkage.view(), labels.as_deref(), DendrogramConfig::default())
                .unwrap();
        assert!(!plot.branches.is_empty());
        assert_eq!(plot.leaves.len(), 2);
    }
}