1use crate::error::{ClusteringError, Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::Float;
9use std::fmt::Debug;
10
11#[allow(dead_code)]
50pub fn leader_clustering<F, D>(
51 data: ArrayView2<F>,
52 threshold: F,
53 metric: D,
54) -> Result<(Array2<F>, Array1<usize>)>
55where
56 F: Float + Debug,
57 D: Fn(ArrayView1<F>, ArrayView1<F>) -> F,
58{
59 if data.is_empty() {
60 return Err(ClusteringError::InvalidInput(
61 "Input data is empty".to_string(),
62 ));
63 }
64
65 if threshold <= F::zero() {
66 return Err(ClusteringError::InvalidInput(
67 "Threshold must be positive".to_string(),
68 ));
69 }
70
71 let n_samples = data.nrows();
72 let n_features = data.ncols();
73
74 let mut leaders: Vec<Array1<F>> = Vec::new();
75 let mut labels = Array1::zeros(n_samples);
76
77 for (i, sample) in data.rows().into_iter().enumerate() {
79 let mut min_distance = F::infinity();
80 let mut closest_leader = 0;
81
82 for (j, leader) in leaders.iter().enumerate() {
84 let distance = metric(sample, leader.view());
85 if distance < min_distance {
86 min_distance = distance;
87 closest_leader = j;
88 }
89 }
90
91 if leaders.is_empty() || min_distance > threshold {
93 leaders.push(sample.to_owned());
95 let label_idx = leaders.len() - 1;
96 labels[i] = label_idx;
97 } else {
98 labels[i] = closest_leader;
100 }
101 }
102
103 let n_leaders = leaders.len();
105 let mut leaders_array = Array2::zeros((n_leaders, n_features));
106 for (i, leader) in leaders.iter().enumerate() {
107 leaders_array.row_mut(i).assign(leader);
108 }
109
110 Ok((leaders_array, labels))
111}
112
113#[allow(dead_code)]
115pub fn euclidean_distance<F: Float>(a: ArrayView1<F>, b: ArrayView1<F>) -> F {
116 a.iter()
117 .zip(b.iter())
118 .map(|(x, y)| (*x - *y) * (*x - *y))
119 .fold(F::zero(), |acc, x| acc + x)
120 .sqrt()
121}
122
123#[allow(dead_code)]
125pub fn manhattan_distance<F: Float>(a: ArrayView1<F>, b: ArrayView1<F>) -> F {
126 a.iter()
127 .zip(b.iter())
128 .map(|(x, y)| (*x - *y).abs())
129 .fold(F::zero(), |acc, x| acc + x)
130}
131
132pub struct LeaderClustering<F: Float> {
137 threshold: F,
138 leaders: Vec<Array1<F>>,
139}
140
141impl<F: Float + Debug> LeaderClustering<F> {
142 pub fn new(threshold: F) -> Result<Self> {
144 if threshold <= F::zero() {
145 return Err(ClusteringError::InvalidInput(
146 "Threshold must be positive".to_string(),
147 ));
148 }
149
150 Ok(Self {
151 threshold,
152 leaders: Vec::new(),
153 })
154 }
155
156 pub fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
158 self.leaders.clear();
159
160 for sample in data.rows() {
161 let mut min_distance = F::infinity();
162
163 for leader in &self.leaders {
165 let distance = euclidean_distance(sample, leader.view());
166 if distance < min_distance {
167 min_distance = distance;
168 }
169 }
170
171 if self.leaders.is_empty() || min_distance > self.threshold {
173 self.leaders.push(sample.to_owned());
174 }
175 }
176
177 Ok(())
178 }
179
180 pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<usize>> {
182 if self.leaders.is_empty() {
183 return Err(ClusteringError::InvalidState(
184 "Model has not been fitted yet".to_string(),
185 ));
186 }
187
188 let n_samples = data.nrows();
189 let mut labels = Array1::zeros(n_samples);
190
191 for (i, sample) in data.rows().into_iter().enumerate() {
192 let mut min_distance = F::infinity();
193 let mut closest_leader = 0;
194
195 for (j, leader) in self.leaders.iter().enumerate() {
196 let distance = euclidean_distance(sample, leader.view());
197 if distance < min_distance {
198 min_distance = distance;
199 closest_leader = j;
200 }
201 }
202
203 labels[i] = closest_leader;
204 }
205
206 Ok(labels)
207 }
208
209 pub fn fit_predict(&mut self, data: ArrayView2<F>) -> Result<Array1<usize>> {
211 self.fit(data)?;
212 self.predict(data)
213 }
214
215 pub fn get_leaders(&self) -> Array2<F> {
217 if self.leaders.is_empty() {
218 return Array2::zeros((0, 0));
219 }
220
221 let n_leaders = self.leaders.len();
222 let n_features = self.leaders[0].len();
223 let mut leaders_array = Array2::zeros((n_leaders, n_features));
224
225 for (i, leader) in self.leaders.iter().enumerate() {
226 leaders_array.row_mut(i).assign(leader);
227 }
228
229 leaders_array
230 }
231
232 pub fn n_clusters(&self) -> usize {
234 self.leaders.len()
235 }
236}
237
/// Multi-level hierarchy of leader clusters, as produced by
/// `LeaderTree::build_hierarchical`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LeaderTree<F: Float> {
    /// Top-level clusters, one per leader found with the first threshold.
    pub roots: Vec<LeaderNode<F>>,
    /// Threshold used for the top level (the first element of the thresholds slice passed
    /// to `build_hierarchical`); thresholds of deeper levels are not stored.
    pub threshold: F,
}
246
/// One cluster in a [`LeaderTree`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LeaderNode<F: Float> {
    /// The leader sample representing this cluster (an owned copy of an input row).
    pub leader: Array1<F>,
    /// Sub-clusters obtained by re-clustering this node's members with the next threshold.
    /// As built by `build_hierarchical`, this is empty at the last level and for nodes with
    /// at most one member.
    pub children: Vec<LeaderNode<F>>,
    /// Row indices, into the original input data, of the samples in this cluster.
    pub members: Vec<usize>,
}
257
258impl<F: Float + Debug> LeaderTree<F> {
259 pub fn build_hierarchical(data: ArrayView2<F>, thresholds: &[F]) -> Result<Self> {
261 if thresholds.is_empty() {
262 return Err(ClusteringError::InvalidInput(
263 "At least one threshold is required".to_string(),
264 ));
265 }
266
267 let current_threshold = thresholds[0];
269 let (leaders, labels) = leader_clustering(data, current_threshold, euclidean_distance)?;
270
271 let mut roots = Vec::new();
273 for i in 0..leaders.nrows() {
274 let mut members = Vec::new();
275 for (j, &label) in labels.iter().enumerate() {
276 if label == i {
277 members.push(j);
278 }
279 }
280
281 roots.push(LeaderNode {
282 leader: leaders.row(i).to_owned(),
283 children: Vec::new(),
284 members,
285 });
286 }
287
288 if thresholds.len() > 1 {
290 for root in &mut roots {
291 Self::build_subtree(data, root, &thresholds[1..])?;
292 }
293 }
294
295 Ok(LeaderTree {
296 roots,
297 threshold: current_threshold,
298 })
299 }
300
301 fn build_subtree(
302 data: ArrayView2<F>,
303 parent: &mut LeaderNode<F>,
304 thresholds: &[F],
305 ) -> Result<()> {
306 if thresholds.is_empty() || parent.members.len() <= 1 {
307 return Ok(());
308 }
309
310 let n_features = data.ncols();
312 let mut cluster_data = Array2::zeros((parent.members.len(), n_features));
313 for (i, &idx) in parent.members.iter().enumerate() {
314 cluster_data.row_mut(i).assign(&data.row(idx));
315 }
316
317 let (sub_leaders, sub_labels) =
319 leader_clustering(cluster_data.view(), thresholds[0], euclidean_distance)?;
320
321 for i in 0..sub_leaders.nrows() {
323 let mut members = Vec::new();
324 for (j, &label) in sub_labels.iter().enumerate() {
325 if label == i {
326 members.push(parent.members[j]);
327 }
328 }
329
330 let mut child = LeaderNode {
331 leader: sub_leaders.row(i).to_owned(),
332 children: Vec::new(),
333 members,
334 };
335
336 if thresholds.len() > 1 {
338 Self::build_subtree(data, &mut child, &thresholds[1..])?;
339 }
340
341 parent.children.push(child);
342 }
343
344 Ok(())
345 }
346
347 pub fn node_count(&self) -> usize {
349 self.roots.iter().map(|root| Self::count_nodes(root)).sum()
350 }
351
352 fn count_nodes(node: &LeaderNode<F>) -> usize {
353 1 + node
354 .children
355 .iter()
356 .map(|child| Self::count_nodes(child))
357 .sum::<usize>()
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Two well-separated pairs of nearby 2-D points.
    fn two_blobs() -> Array2<f64> {
        array![[1.0, 2.0], [1.2, 1.8], [5.0, 4.0], [5.2, 4.1]]
    }

    #[test]
    fn test_leader_clustering_basic() {
        let data = two_blobs();

        let (leaders, labels) = leader_clustering(data.view(), 1.0, euclidean_distance)
            .expect("valid input should cluster");

        assert_eq!(leaders.nrows(), 2);
        assert_eq!(labels.len(), 4);

        // Points within a blob share a label; the two blobs get different labels.
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    #[test]
    fn test_leader_clustering_single_cluster() {
        let data = array![[1.0, 2.0], [1.2, 1.8], [1.1, 2.1], [0.9, 1.9]];

        let (leaders, labels) = leader_clustering(data.view(), 2.0, euclidean_distance)
            .expect("valid input should cluster");

        assert_eq!(leaders.nrows(), 1);
        assert!(labels.iter().all(|&label| label == 0));
    }

    #[test]
    fn test_leader_class() {
        let data = two_blobs();

        let mut model = LeaderClustering::new(1.0).expect("positive threshold is valid");
        let labels = model
            .fit_predict(data.view())
            .expect("fit_predict on valid data");

        assert_eq!(model.n_clusters(), 2);
        assert_eq!(labels.len(), 4);

        // Each unseen point lies inside one of the blobs and should get that blob's label.
        let unseen = array![[1.1, 1.9], [5.1, 4.05]];
        let unseen_labels = model
            .predict(unseen.view())
            .expect("predict after fit");
        assert_eq!(unseen_labels[0], labels[0]);
        assert_eq!(unseen_labels[1], labels[2]);
    }

    #[test]
    fn test_hierarchical_leader_tree() {
        let data = array![
            [1.0, 2.0],
            [1.2, 1.8],
            [5.0, 4.0],
            [5.2, 4.1],
            [10.0, 10.0],
            [10.2, 9.8],
        ];
        let thresholds = vec![6.0, 1.0];

        let tree = LeaderTree::build_hierarchical(data.view(), &thresholds)
            .expect("valid thresholds should build a tree");

        assert!(tree.roots.len() <= 3);
        // The second threshold must have produced at least one child node.
        assert!(tree.node_count() > tree.roots.len());
    }

    #[test]
    fn test_invalid_threshold() {
        let data = array![[1.0, 2.0]];

        assert!(leader_clustering(data.view(), -1.0, euclidean_distance).is_err());
        assert!(LeaderClustering::<f64>::new(-1.0).is_err());
    }

    #[test]
    fn test_empty_data() {
        let data: Array2<f64> = Array2::zeros((0, 2));

        assert!(leader_clustering(data.view(), 1.0, euclidean_distance).is_err());
    }
}