1use scirs2_core::ndarray::ArrayView2;
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::collections::HashMap;
9use std::fmt::Debug;
10
11use super::types::*;
12use crate::error::{ClusteringError, Result};
13
14#[derive(Debug, Clone)]
16struct TreeNode<F: Float> {
17 id: usize,
19 height: F,
21 left: Option<Box<TreeNode<F>>>,
23 right: Option<Box<TreeNode<F>>>,
25 leaf_count: usize,
27}
28
29impl<F: Float + std::fmt::Display> TreeNode<F> {
30 fn new_leaf(id: usize) -> Self {
32 Self {
33 id,
34 height: F::zero(),
35 left: None,
36 right: None,
37 leaf_count: 1,
38 }
39 }
40
41 fn new_internal(id: usize, height: F, left: TreeNode<F>, right: TreeNode<F>) -> Self {
43 let leaf_count = left.leaf_count + right.leaf_count;
44 Self {
45 id,
46 height,
47 left: Some(Box::new(left)),
48 right: Some(Box::new(right)),
49 leaf_count,
50 }
51 }
52
53 fn is_leaf(&self) -> bool {
55 self.left.is_none() && self.right.is_none()
56 }
57}
58
59pub fn create_dendrogramplot<F: Float + FromPrimitive + PartialOrd + Debug + std::fmt::Display>(
87 linkage_matrix: ArrayView2<F>,
88 labels: Option<&[String]>,
89 config: DendrogramConfig<F>,
90) -> Result<DendrogramPlot<F>> {
91 let n_samples = linkage_matrix.shape()[0] + 1;
92 if n_samples < 2 {
93 return Err(ClusteringError::InvalidInput(
94 "Need at least 2 samples to create dendrogram".into(),
95 ));
96 }
97
98 let actual_threshold = if config.color_threshold.auto_threshold {
100 calculate_auto_threshold(linkage_matrix, config.color_threshold.target_clusters)?
101 } else {
102 config.color_threshold.threshold
103 };
104
105 let tree = build_dendrogram_tree(linkage_matrix)?;
107
108 let positions = calculate_node_positions(&tree, n_samples, config.orientation);
110
111 let branches = create_branches(&tree, &positions, actual_threshold, &config)?;
113
114 let leaves = create_leaves(&positions, labels, n_samples, config.orientation);
116
117 let colors = assign_branch_colors(&branches, &config);
119
120 let legend = create_legend(&config, actual_threshold);
122
123 let bounds = calculate_plot_bounds(&branches, &leaves);
125
126 Ok(DendrogramPlot {
127 branches,
128 leaves,
129 colors,
130 legend,
131 bounds,
132 config,
133 })
134}
135
136fn build_dendrogram_tree<F: Float + FromPrimitive + Debug + std::fmt::Display>(
138 linkage_matrix: ArrayView2<F>,
139) -> Result<TreeNode<F>> {
140 let n_samples = linkage_matrix.shape()[0] + 1;
141 let mut nodes: HashMap<usize, TreeNode<F>> = HashMap::new();
142
143 for i in 0..n_samples {
145 nodes.insert(i, TreeNode::new_leaf(i));
146 }
147
148 for (i, row) in linkage_matrix.outer_iter().enumerate() {
150 let left_id = row[0].to_usize().unwrap();
151 let right_id = row[1].to_usize().unwrap();
152 let distance = row[2];
153
154 let left_node = nodes.remove(&left_id).ok_or_else(|| {
155 ClusteringError::InvalidInput(format!("Invalid left cluster ID: {}", left_id))
156 })?;
157
158 let right_node = nodes.remove(&right_id).ok_or_else(|| {
159 ClusteringError::InvalidInput(format!("Invalid right cluster ID: {}", right_id))
160 })?;
161
162 let internal_id = n_samples + i;
163 let internal_node = TreeNode::new_internal(internal_id, distance, left_node, right_node);
164
165 nodes.insert(internal_id, internal_node);
166 }
167
168 let root_id = 2 * n_samples - 2;
170 nodes.remove(&root_id).ok_or_else(|| {
171 ClusteringError::ComputationError("Failed to construct dendrogram tree".to_string())
172 })
173}
174
175fn calculate_node_positions<F: Float + FromPrimitive + std::fmt::Display>(
177 root: &TreeNode<F>,
178 n_samples: usize,
179 orientation: DendrogramOrientation,
180) -> HashMap<usize, (F, F)> {
181 let mut positions = HashMap::new();
182 let mut leaf_counter = 0;
183
184 calculate_positions_recursive(root, &mut positions, &mut leaf_counter, orientation);
185 positions
186}
187
188fn calculate_positions_recursive<F: Float + FromPrimitive + std::fmt::Display>(
190 node: &TreeNode<F>,
191 positions: &mut HashMap<usize, (F, F)>,
192 leaf_counter: &mut usize,
193 orientation: DendrogramOrientation,
194) -> F {
195 if node.is_leaf() {
196 let x_pos = F::from(*leaf_counter).unwrap();
197 let y_pos = F::zero();
198
199 let pos = match orientation {
200 DendrogramOrientation::Top => (x_pos, y_pos),
201 DendrogramOrientation::Bottom => (x_pos, -y_pos),
202 DendrogramOrientation::Left => (y_pos, x_pos),
203 DendrogramOrientation::Right => (-y_pos, x_pos),
204 };
205
206 positions.insert(node.id, pos);
207 *leaf_counter += 1;
208 x_pos
209 } else {
210 let left = node.left.as_ref().unwrap();
211 let right = node.right.as_ref().unwrap();
212
213 let left_x = calculate_positions_recursive(left, positions, leaf_counter, orientation);
214 let right_x = calculate_positions_recursive(right, positions, leaf_counter, orientation);
215
216 let x_pos = (left_x + right_x) / F::from(2).unwrap();
217 let y_pos = node.height;
218
219 let pos = match orientation {
220 DendrogramOrientation::Top => (x_pos, y_pos),
221 DendrogramOrientation::Bottom => (x_pos, -y_pos),
222 DendrogramOrientation::Left => (y_pos, x_pos),
223 DendrogramOrientation::Right => (-y_pos, x_pos),
224 };
225
226 positions.insert(node.id, pos);
227 x_pos
228 }
229}
230
231fn create_branches<F: Float + FromPrimitive + PartialOrd + std::fmt::Display>(
233 root: &TreeNode<F>,
234 positions: &HashMap<usize, (F, F)>,
235 threshold: F,
236 config: &DendrogramConfig<F>,
237) -> Result<Vec<Branch<F>>> {
238 let mut branches = Vec::new();
239 create_branches_recursive(root, positions, threshold, config, &mut branches)?;
240 Ok(branches)
241}
242
243fn create_branches_recursive<F: Float + FromPrimitive + PartialOrd + std::fmt::Display>(
245 node: &TreeNode<F>,
246 positions: &HashMap<usize, (F, F)>,
247 threshold: F,
248 config: &DendrogramConfig<F>,
249 branches: &mut Vec<Branch<F>>,
250) -> Result<()> {
251 if !node.is_leaf() {
252 let left = node.left.as_ref().unwrap();
253 let right = node.right.as_ref().unwrap();
254
255 let node_pos = positions.get(&node.id).unwrap();
256 let left_pos = positions.get(&left.id).unwrap();
257 let right_pos = positions.get(&right.id).unwrap();
258
259 let color = if node.height > threshold {
261 config.color_threshold.above_color.clone()
262 } else {
263 config.color_threshold.below_color.clone()
264 };
265
266 let horizontal_branch = Branch {
268 start: *left_pos,
269 end: *right_pos,
270 distance: node.height,
271 cluster_id: Some(node.id),
272 color: color.clone(),
273 line_width: Some(config.line_width),
274 };
275 branches.push(horizontal_branch);
276
277 let mid_x = (left_pos.0 + right_pos.0) / F::from(2).unwrap();
279 let vertical_start = (mid_x, left_pos.1.max(right_pos.1));
280 let vertical_branch = Branch {
281 start: vertical_start,
282 end: *node_pos,
283 distance: node.height,
284 cluster_id: Some(node.id),
285 color,
286 line_width: Some(config.line_width),
287 };
288 branches.push(vertical_branch);
289
290 create_branches_recursive(left, positions, threshold, config, branches)?;
292 create_branches_recursive(right, positions, threshold, config, branches)?;
293 }
294
295 Ok(())
296}
297
298fn create_leaves<F: Float + FromPrimitive>(
300 positions: &HashMap<usize, (F, F)>,
301 labels: Option<&[String]>,
302 n_samples: usize,
303 orientation: DendrogramOrientation,
304) -> Vec<Leaf> {
305 let mut leaves = Vec::new();
306
307 for i in 0..n_samples {
308 if let Some(pos) = positions.get(&i) {
309 let label = if let Some(labels) = labels {
310 labels
311 .get(i)
312 .cloned()
313 .unwrap_or_else(|| format!("Sample {}", i))
314 } else {
315 format!("Sample {}", i)
316 };
317
318 let leaf = Leaf {
319 position: (pos.0.to_f64().unwrap(), pos.1.to_f64().unwrap()),
320 label,
321 color: "#333333".to_string(),
322 data_index: i,
323 };
324
325 leaves.push(leaf);
326 }
327 }
328
329 leaves
330}
331
332fn assign_branch_colors<F: Float>(
334 branches: &[Branch<F>],
335 config: &DendrogramConfig<F>,
336) -> Vec<String> {
337 branches.iter().map(|branch| branch.color.clone()).collect()
338}
339
340fn create_legend<F: Float + std::fmt::Display>(
342 config: &DendrogramConfig<F>,
343 threshold: F,
344) -> Vec<LegendEntry> {
345 if config.color_threshold.auto_threshold || config.color_threshold.threshold > F::zero() {
346 vec![
347 LegendEntry {
348 color: config.color_threshold.above_color.clone(),
349 label: format!("Distance > {}", threshold),
350 threshold: Some(threshold.to_f64().unwrap_or(0.0)),
351 },
352 LegendEntry {
353 color: config.color_threshold.below_color.clone(),
354 label: format!("Distance d {}", threshold),
355 threshold: Some(threshold.to_f64().unwrap_or(0.0)),
356 },
357 ]
358 } else {
359 Vec::new()
360 }
361}
362
363fn calculate_plot_bounds<F: Float>(branches: &[Branch<F>], leaves: &[Leaf]) -> (F, F, F, F) {
365 if branches.is_empty() && leaves.is_empty() {
366 return (F::zero(), F::zero(), F::zero(), F::zero());
367 }
368
369 let mut min_x = F::infinity();
370 let mut max_x = F::neg_infinity();
371 let mut min_y = F::infinity();
372 let mut max_y = F::neg_infinity();
373
374 for branch in branches {
376 min_x = min_x.min(branch.start.0).min(branch.end.0);
377 max_x = max_x.max(branch.start.0).max(branch.end.0);
378 min_y = min_y.min(branch.start.1).min(branch.end.1);
379 max_y = max_y.max(branch.start.1).max(branch.end.1);
380 }
381
382 for leaf in leaves {
384 let leaf_x = F::from(leaf.position.0).unwrap();
385 let leaf_y = F::from(leaf.position.1).unwrap();
386 min_x = min_x.min(leaf_x);
387 max_x = max_x.max(leaf_x);
388 min_y = min_y.min(leaf_y);
389 max_y = max_y.max(leaf_y);
390 }
391
392 (min_x, max_x, min_y, max_y)
393}
394
395fn calculate_auto_threshold<F: Float + FromPrimitive + PartialOrd>(
397 linkage_matrix: ArrayView2<F>,
398 target_clusters: Option<usize>,
399) -> Result<F> {
400 let target = target_clusters.unwrap_or(4);
401 let n_merges = linkage_matrix.shape()[0];
402
403 if target > n_merges {
404 return Ok(F::zero());
405 }
406
407 let merge_index = n_merges - target;
409 let threshold = linkage_matrix[[merge_index, 2]];
410
411 Ok(threshold)
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Four samples: {0,1} merge at 0.1 (id 4), {2,3} at 0.2 (id 5), then 4+5 at 0.3.
    fn four_sample_linkage() -> Array2<f64> {
        Array2::from_shape_vec(
            (3, 4),
            vec![0.0, 1.0, 0.1, 2.0, 2.0, 3.0, 0.2, 2.0, 4.0, 5.0, 0.3, 4.0],
        )
        .unwrap()
    }

    #[test]
    fn test_tree_node_creation() {
        let leaf = TreeNode::<f64>::new_leaf(0);
        assert!(leaf.is_leaf());
        assert_eq!((leaf.id, leaf.leaf_count), (0, 1));

        let internal = TreeNode::new_internal(2, 0.5, TreeNode::new_leaf(0), TreeNode::new_leaf(1));
        assert!(!internal.is_leaf());
        assert_eq!(internal.leaf_count, 2);
    }

    #[test]
    fn test_build_dendrogram_tree() {
        let linkage = four_sample_linkage();

        let root = build_dendrogram_tree(linkage.view()).unwrap();
        assert!(!root.is_leaf());
        assert_eq!(root.leaf_count, 4);
    }

    #[test]
    fn test_calculate_auto_threshold() {
        let linkage = four_sample_linkage();

        let threshold = calculate_auto_threshold(linkage.view(), Some(2)).unwrap();
        assert!((threshold - 0.2).abs() < 1e-10);
    }

    #[test]
    fn test_create_dendrogramplot() {
        let linkage = Array2::from_shape_vec((1, 4), vec![0.0, 1.0, 0.1, 2.0]).unwrap();
        let names = Some(vec!["A".to_string(), "B".to_string()]);

        let plot =
            create_dendrogramplot(linkage.view(), names.as_deref(), DendrogramConfig::default())
                .unwrap();
        assert!(!plot.branches.is_empty());
        assert_eq!(plot.leaves.len(), 2);
    }
}