scirs2_cluster/
leader.rs

1//! Leader algorithm implementation for clustering
2//!
3//! The Leader algorithm is a simple, single-pass clustering algorithm that
4//! processes data points sequentially, creating clusters on-the-fly.
5
6use crate::error::{ClusteringError, Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::Float;
9use std::fmt::Debug;
10
11/// Leader algorithm clustering
12///
13/// The Leader algorithm is a simple single-pass clustering method that:
14/// 1. Takes the first data point as the first cluster leader
15/// 2. For each subsequent point:
16///    - If it's within threshold distance of an existing leader, assign it to that cluster
17///    - Otherwise, make it a new leader
18///
19/// # Type Parameters
20///
21/// * `F` - Floating point type (f32 or f64)
22///
23/// # Arguments
24///
25/// * `data` - Input data matrix of shape (n_samples, n_features)
26/// * `threshold` - Distance threshold for creating new clusters
27/// * `metric` - Distance metric function
28///
29/// # Returns
30///
31/// * `leaders` - Array of cluster leaders
32/// * `labels` - Cluster assignments for each data point
33///
34/// # Example
35///
36/// ```
37/// use scirs2_core::ndarray::array;
38/// use scirs2_cluster::leader::{leader_clustering, euclidean_distance};
39///
40/// let data = array![
41///     [1.0, 2.0],
42///     [1.2, 1.8],
43///     [5.0, 4.0],
44///     [5.2, 4.1],
45/// ];
46///
47/// let (leaders, labels) = leader_clustering(data.view(), 1.0, euclidean_distance).unwrap();
48/// ```
49#[allow(dead_code)]
50pub fn leader_clustering<F, D>(
51    data: ArrayView2<F>,
52    threshold: F,
53    metric: D,
54) -> Result<(Array2<F>, Array1<usize>)>
55where
56    F: Float + Debug,
57    D: Fn(ArrayView1<F>, ArrayView1<F>) -> F,
58{
59    if data.is_empty() {
60        return Err(ClusteringError::InvalidInput(
61            "Input data is empty".to_string(),
62        ));
63    }
64
65    if threshold <= F::zero() {
66        return Err(ClusteringError::InvalidInput(
67            "Threshold must be positive".to_string(),
68        ));
69    }
70
71    let n_samples = data.nrows();
72    let n_features = data.ncols();
73
74    let mut leaders: Vec<Array1<F>> = Vec::new();
75    let mut labels = Array1::zeros(n_samples);
76
77    // Process each data point
78    for (i, sample) in data.rows().into_iter().enumerate() {
79        let mut min_distance = F::infinity();
80        let mut closest_leader = 0;
81
82        // Find the closest leader
83        for (j, leader) in leaders.iter().enumerate() {
84            let distance = metric(sample, leader.view());
85            if distance < min_distance {
86                min_distance = distance;
87                closest_leader = j;
88            }
89        }
90
91        // Assign to existing cluster or create new one
92        if leaders.is_empty() || min_distance > threshold {
93            // Create new cluster
94            leaders.push(sample.to_owned());
95            let label_idx = leaders.len() - 1;
96            labels[i] = label_idx;
97        } else {
98            // Assign to existing cluster
99            labels[i] = closest_leader;
100        }
101    }
102
103    // Convert leaders to Array2
104    let n_leaders = leaders.len();
105    let mut leaders_array = Array2::zeros((n_leaders, n_features));
106    for (i, leader) in leaders.iter().enumerate() {
107        leaders_array.row_mut(i).assign(leader);
108    }
109
110    Ok((leaders_array, labels))
111}
112
113/// Euclidean distance function
114#[allow(dead_code)]
115pub fn euclidean_distance<F: Float>(a: ArrayView1<F>, b: ArrayView1<F>) -> F {
116    a.iter()
117        .zip(b.iter())
118        .map(|(x, y)| (*x - *y) * (*x - *y))
119        .fold(F::zero(), |acc, x| acc + x)
120        .sqrt()
121}
122
123/// Manhattan distance function
124#[allow(dead_code)]
125pub fn manhattan_distance<F: Float>(a: ArrayView1<F>, b: ArrayView1<F>) -> F {
126    a.iter()
127        .zip(b.iter())
128        .map(|(x, y)| (*x - *y).abs())
129        .fold(F::zero(), |acc, x| acc + x)
130}
131
132/// Leader algorithm with order-dependent results
133///
134/// This variant processes points in the order they appear, which can lead
135/// to different results based on data ordering.
136pub struct LeaderClustering<F: Float> {
137    threshold: F,
138    leaders: Vec<Array1<F>>,
139}
140
141impl<F: Float + Debug> LeaderClustering<F> {
142    /// Create a new Leader clustering instance
143    pub fn new(threshold: F) -> Result<Self> {
144        if threshold <= F::zero() {
145            return Err(ClusteringError::InvalidInput(
146                "Threshold must be positive".to_string(),
147            ));
148        }
149
150        Ok(Self {
151            threshold,
152            leaders: Vec::new(),
153        })
154    }
155
156    /// Fit the model to data
157    pub fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
158        self.leaders.clear();
159
160        for sample in data.rows() {
161            let mut min_distance = F::infinity();
162
163            // Find the closest leader
164            for leader in &self.leaders {
165                let distance = euclidean_distance(sample, leader.view());
166                if distance < min_distance {
167                    min_distance = distance;
168                }
169            }
170
171            // Create new cluster if needed
172            if self.leaders.is_empty() || min_distance > self.threshold {
173                self.leaders.push(sample.to_owned());
174            }
175        }
176
177        Ok(())
178    }
179
180    /// Predict cluster labels for data
181    pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<usize>> {
182        if self.leaders.is_empty() {
183            return Err(ClusteringError::InvalidState(
184                "Model has not been fitted yet".to_string(),
185            ));
186        }
187
188        let n_samples = data.nrows();
189        let mut labels = Array1::zeros(n_samples);
190
191        for (i, sample) in data.rows().into_iter().enumerate() {
192            let mut min_distance = F::infinity();
193            let mut closest_leader = 0;
194
195            for (j, leader) in self.leaders.iter().enumerate() {
196                let distance = euclidean_distance(sample, leader.view());
197                if distance < min_distance {
198                    min_distance = distance;
199                    closest_leader = j;
200                }
201            }
202
203            labels[i] = closest_leader;
204        }
205
206        Ok(labels)
207    }
208
209    /// Fit the model and return predictions
210    pub fn fit_predict(&mut self, data: ArrayView2<F>) -> Result<Array1<usize>> {
211        self.fit(data)?;
212        self.predict(data)
213    }
214
215    /// Get the cluster leaders
216    pub fn get_leaders(&self) -> Array2<F> {
217        if self.leaders.is_empty() {
218            return Array2::zeros((0, 0));
219        }
220
221        let n_leaders = self.leaders.len();
222        let n_features = self.leaders[0].len();
223        let mut leaders_array = Array2::zeros((n_leaders, n_features));
224
225        for (i, leader) in self.leaders.iter().enumerate() {
226            leaders_array.row_mut(i).assign(leader);
227        }
228
229        leaders_array
230    }
231
232    /// Get the number of clusters
233    pub fn n_clusters(&self) -> usize {
234        self.leaders.len()
235    }
236}
237
/// Tree representation for hierarchical organization of leaders.
///
/// [`LeaderTree::build_hierarchical`] builds one of these from a list of
/// thresholds. The roots are the clusters found with the first threshold, and
/// each deeper level re-clusters a node's members with the next threshold.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LeaderTree<F: Float> {
    /// Root nodes of the tree (the top-level clusters)
    pub roots: Vec<LeaderNode<F>>,
    /// Distance threshold used to build the root level. Thresholds used for
    /// deeper levels are not stored here.
    pub threshold: F,
}
246
/// Node in the leader tree structure: one cluster at some level of the hierarchy.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LeaderNode<F: Float> {
    /// The leader vector (a copy of the data row that started this cluster)
    pub leader: Array1<F>,
    /// Child nodes: this cluster's members re-clustered with the next
    /// threshold in the list (empty at the deepest level or for singletons)
    pub children: Vec<LeaderNode<F>>,
    /// Indices of data points in this cluster, as row indices into the
    /// original data matrix passed to `build_hierarchical`
    pub members: Vec<usize>,
}
257
258impl<F: Float + Debug> LeaderTree<F> {
259    /// Build a hierarchical leader tree with multiple threshold levels
260    pub fn build_hierarchical(data: ArrayView2<F>, thresholds: &[F]) -> Result<Self> {
261        if thresholds.is_empty() {
262            return Err(ClusteringError::InvalidInput(
263                "At least one threshold is required".to_string(),
264            ));
265        }
266
267        // Start with the largest threshold
268        let current_threshold = thresholds[0];
269        let (leaders, labels) = leader_clustering(data, current_threshold, euclidean_distance)?;
270
271        // Build root nodes
272        let mut roots = Vec::new();
273        for i in 0..leaders.nrows() {
274            let mut members = Vec::new();
275            for (j, &label) in labels.iter().enumerate() {
276                if label == i {
277                    members.push(j);
278                }
279            }
280
281            roots.push(LeaderNode {
282                leader: leaders.row(i).to_owned(),
283                children: Vec::new(),
284                members,
285            });
286        }
287
288        // Build lower levels if more thresholds provided
289        if thresholds.len() > 1 {
290            for root in &mut roots {
291                Self::build_subtree(data, root, &thresholds[1..])?;
292            }
293        }
294
295        Ok(LeaderTree {
296            roots,
297            threshold: current_threshold,
298        })
299    }
300
301    fn build_subtree(
302        data: ArrayView2<F>,
303        parent: &mut LeaderNode<F>,
304        thresholds: &[F],
305    ) -> Result<()> {
306        if thresholds.is_empty() || parent.members.len() <= 1 {
307            return Ok(());
308        }
309
310        // Extract data for this cluster
311        let n_features = data.ncols();
312        let mut cluster_data = Array2::zeros((parent.members.len(), n_features));
313        for (i, &idx) in parent.members.iter().enumerate() {
314            cluster_data.row_mut(i).assign(&data.row(idx));
315        }
316
317        // Cluster with smaller threshold
318        let (sub_leaders, sub_labels) =
319            leader_clustering(cluster_data.view(), thresholds[0], euclidean_distance)?;
320
321        // Build child nodes
322        for i in 0..sub_leaders.nrows() {
323            let mut members = Vec::new();
324            for (j, &label) in sub_labels.iter().enumerate() {
325                if label == i {
326                    members.push(parent.members[j]);
327                }
328            }
329
330            let mut child = LeaderNode {
331                leader: sub_leaders.row(i).to_owned(),
332                children: Vec::new(),
333                members,
334            };
335
336            // Recursively build subtree
337            if thresholds.len() > 1 {
338                Self::build_subtree(data, &mut child, &thresholds[1..])?;
339            }
340
341            parent.children.push(child);
342        }
343
344        Ok(())
345    }
346
347    /// Get the total number of nodes in the tree
348    pub fn node_count(&self) -> usize {
349        self.roots.iter().map(|root| Self::count_nodes(root)).sum()
350    }
351
352    fn count_nodes(node: &LeaderNode<F>) -> usize {
353        1 + node
354            .children
355            .iter()
356            .map(|child| Self::count_nodes(child))
357            .sum::<usize>()
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_leader_clustering_basic() {
        let data = array![[1.0, 2.0], [1.2, 1.8], [5.0, 4.0], [5.2, 4.1],];

        let (leaders, labels) = leader_clustering(data.view(), 1.0, euclidean_distance).unwrap();

        // Should create 2 clusters
        assert_eq!(leaders.nrows(), 2);
        assert_eq!(labels.len(), 4);

        // Points 0,1 should be in one cluster, points 2,3 in another
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    #[test]
    fn test_leader_clustering_single_cluster() {
        let data = array![[1.0, 2.0], [1.2, 1.8], [1.1, 2.1], [0.9, 1.9],];

        let (leaders, labels) = leader_clustering(data.view(), 2.0, euclidean_distance).unwrap();

        // Should create 1 cluster with large threshold
        assert_eq!(leaders.nrows(), 1);
        assert!(labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_leader_clustering_custom_metric() {
        let data = array![[0.0, 0.0], [0.6, 0.6], [3.0, 3.0]];

        // (0,0)-(0.6,0.6): Euclidean ~0.85 fits within 1.0 ...
        let (leaders, _) = leader_clustering(data.view(), 1.0, euclidean_distance).unwrap();
        assert_eq!(leaders.nrows(), 2);

        // ... but the Manhattan distance 1.2 does not.
        let (leaders, labels) = leader_clustering(data.view(), 1.0, manhattan_distance).unwrap();
        assert_eq!(leaders.nrows(), 3);
        assert_eq!(labels.to_vec(), vec![0, 1, 2]);
    }

    #[test]
    fn test_distance_functions() {
        let a = array![0.0, 0.0];
        let b = array![3.0, 4.0];

        assert!((euclidean_distance(a.view(), b.view()) - 5.0_f64).abs() < 1e-12);
        assert!((manhattan_distance(a.view(), b.view()) - 7.0_f64).abs() < 1e-12);
    }

    #[test]
    fn test_leader_class() {
        let data = array![[1.0, 2.0], [1.2, 1.8], [5.0, 4.0], [5.2, 4.1],];

        let mut leader = LeaderClustering::new(1.0).unwrap();
        let labels = leader.fit_predict(data.view()).unwrap();

        assert_eq!(leader.n_clusters(), 2);
        assert_eq!(labels.len(), 4);
        assert_eq!(leader.get_leaders().dim(), (2, 2));

        // Test prediction on new data
        let new_data = array![[1.1, 1.9], [5.1, 4.05]];
        let new_labels = leader.predict(new_data.view()).unwrap();
        assert_eq!(new_labels[0], labels[0]); // Close to first cluster
        assert_eq!(new_labels[1], labels[2]); // Close to second cluster
    }

    #[test]
    fn test_predict_before_fit() {
        let data = array![[1.0, 2.0]];
        let leader = LeaderClustering::<f64>::new(1.0).unwrap();

        assert!(leader.predict(data.view()).is_err());
        assert_eq!(leader.get_leaders().dim(), (0, 0));
    }

    #[test]
    fn test_hierarchical_leader_tree() {
        let data = array![
            [1.0, 2.0],
            [1.2, 1.8],
            [5.0, 4.0],
            [5.2, 4.1],
            [10.0, 10.0],
            [10.2, 9.8],
        ];

        let thresholds = vec![6.0, 1.0];
        let tree = LeaderTree::build_hierarchical(data.view(), &thresholds).unwrap();

        // At threshold 6.0 the first four points are all within ~4.7 of (1, 2)
        // and form one root; the two points near (10, 10) form the other.
        assert_eq!(tree.roots.len(), 2);
        assert_eq!(tree.roots[0].members, vec![0, 1, 2, 3]);
        assert_eq!(tree.roots[1].members, vec![4, 5]);

        // At threshold 1.0 the first root splits into {0, 1} and {2, 3};
        // the second root keeps a single child.
        assert_eq!(tree.roots[0].children.len(), 2);
        assert_eq!(tree.roots[0].children[0].members, vec![0, 1]);
        assert_eq!(tree.roots[0].children[1].members, vec![2, 3]);
        assert_eq!(tree.roots[1].children.len(), 1);
        assert_eq!(tree.node_count(), 5);
    }

    #[test]
    fn test_invalid_threshold() {
        let data = array![[1.0, 2.0]];

        let result = leader_clustering(data.view(), -1.0, euclidean_distance);
        assert!(result.is_err());

        let result = LeaderClustering::<f64>::new(-1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_data() {
        let data: Array2<f64> = Array2::zeros((0, 2));

        let result = leader_clustering(data.view(), 1.0, euclidean_distance);
        assert!(result.is_err());
    }
}