1use crate::error::{ClusteringError, Result};
7use crate::leader::{LeaderNode, LeaderTree};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
9use scirs2_core::numeric::Float;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13use super::core::SerializableModel;
14
/// Serialized K-means clustering result: fitted centroids plus fit metadata.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KMeansModel {
    /// Cluster centroids, shape (n_clusters, n_features).
    pub centroids: Array2<f64>,
    /// Number of clusters (expected to equal `centroids.nrows()`).
    pub n_clusters: usize,
    /// Number of iterations the fit ran for.
    pub n_iter: usize,
    /// Final within-cluster sum of squared distances reported by the fit.
    pub inertia: f64,
    /// Training labels, if they were retained at fit time.
    pub labels: Option<Array1<usize>>,
}

/// Enables save/load via the shared serialization helpers.
impl SerializableModel for KMeansModel {}
31
32impl KMeansModel {
33 pub fn new(
35 centroids: Array2<f64>,
36 n_clusters: usize,
37 n_iter: usize,
38 inertia: f64,
39 labels: Option<Array1<usize>>,
40 ) -> Self {
41 Self {
42 centroids,
43 n_clusters,
44 n_iter,
45 inertia,
46 labels,
47 }
48 }
49
50 pub fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
52 let n_samples = data.nrows();
53 let mut labels = Array1::zeros(n_samples);
54
55 for (i, sample) in data.rows().into_iter().enumerate() {
56 let mut min_distance = f64::INFINITY;
57 let mut closest_cluster = 0;
58
59 for (j, centroid) in self.centroids.rows().into_iter().enumerate() {
60 let distance = sample
61 .iter()
62 .zip(centroid.iter())
63 .map(|(a, b)| (a - b).powi(2))
64 .sum::<f64>()
65 .sqrt();
66
67 if distance < min_distance {
68 min_distance = distance;
69 closest_cluster = j;
70 }
71 }
72
73 labels[i] = closest_cluster;
74 }
75
76 Ok(labels)
77 }
78
79 pub fn predict_single(&self, point: &[f64]) -> Result<usize> {
81 if point.len() != self.centroids.ncols() {
82 return Err(ClusteringError::InvalidInput(
83 "Point dimensions must match centroid dimensions".to_string(),
84 ));
85 }
86
87 let mut min_distance = f64::INFINITY;
88 let mut closest_cluster = 0;
89
90 for (j, centroid) in self.centroids.rows().into_iter().enumerate() {
91 let distance = point
92 .iter()
93 .zip(centroid.iter())
94 .map(|(a, b)| (a - b).powi(2))
95 .sum::<f64>()
96 .sqrt();
97
98 if distance < min_distance {
99 min_distance = distance;
100 closest_cluster = j;
101 }
102 }
103
104 Ok(closest_cluster)
105 }
106}
107
/// Serialized hierarchical clustering result: a linkage matrix plus the
/// metadata needed to render dendrograms.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HierarchicalModel {
    /// Linkage matrix; this module reads columns 0 and 1 as the merged child
    /// node indices (stored as f64) and column 2 as the merge distance.
    pub linkage: Array2<f64>,
    /// Number of original (leaf) observations.
    pub n_observations: usize,
    /// Linkage method name (e.g. "ward").
    pub method: String,
    /// Optional leaf labels, indexed by observation.
    pub labels: Option<Vec<String>>,
}

/// Enables save/load via the shared serialization helpers.
impl SerializableModel for HierarchicalModel {}
122
123impl HierarchicalModel {
124 pub fn new(
126 linkage: Array2<f64>,
127 n_observations: usize,
128 method: String,
129 labels: Option<Vec<String>>,
130 ) -> Self {
131 Self {
132 linkage,
133 n_observations,
134 method,
135 labels,
136 }
137 }
138
139 pub fn to_newick(&self) -> Result<String> {
141 let mut newick = String::new();
142 let nnodes = self.linkage.nrows();
143
144 if nnodes == 0 {
145 return Ok("();".to_string());
146 }
147
148 self.validate_linkage_matrix()?;
149 self.build_newick_recursive(nnodes + self.n_observations - 1, &mut newick)?;
150
151 newick.push(';');
152 Ok(newick)
153 }
154
155 fn validate_linkage_matrix(&self) -> Result<()> {
157 let nnodes = self.linkage.nrows();
158
159 for i in 0..nnodes {
160 let left = self.linkage[[i, 0]] as usize;
161 let right = self.linkage[[i, 1]] as usize;
162 let distance = self.linkage[[i, 2]];
163
164 if left >= self.n_observations + i || right >= self.n_observations + i {
165 return Err(ClusteringError::InvalidInput(format!(
166 "Invalid node indices in linkage matrix at row {}: left={}, right={}",
167 i, left, right
168 )));
169 }
170
171 if distance < 0.0 {
172 return Err(ClusteringError::InvalidInput(format!(
173 "Negative distance in linkage matrix at row {}: {}",
174 i, distance
175 )));
176 }
177 }
178
179 Ok(())
180 }
181
182 fn build_newick_recursive(&self, nodeidx: usize, newick: &mut String) -> Result<()> {
184 if nodeidx < self.n_observations {
185 if let Some(ref labels) = self.labels {
186 newick.push_str(&labels[nodeidx]);
187 } else {
188 newick.push_str(&nodeidx.to_string());
189 }
190 } else {
191 let row_idx = nodeidx - self.n_observations;
192 if row_idx >= self.linkage.nrows() {
193 return Err(ClusteringError::InvalidInput(
194 "Invalid node index".to_string(),
195 ));
196 }
197
198 let left = self.linkage[[row_idx, 0]] as usize;
199 let right = self.linkage[[row_idx, 1]] as usize;
200 let distance = self.linkage[[row_idx, 2]];
201
202 newick.push('(');
203 self.build_newick_recursive(left, newick)?;
204 newick.push(':');
205 newick.push_str(&format!("{:.6}", distance / 2.0));
206 newick.push(',');
207 self.build_newick_recursive(right, newick)?;
208 newick.push(':');
209 newick.push_str(&format!("{:.6}", distance / 2.0));
210 newick.push(')');
211 }
212
213 Ok(())
214 }
215
216 pub fn to_json_tree(&self) -> Result<serde_json::Value> {
218 use serde_json::json;
219
220 let nnodes = self.linkage.nrows();
221 if nnodes == 0 {
222 return Ok(json!({}));
223 }
224
225 self.build_json_recursive(nnodes + self.n_observations - 1)
226 }
227
228 fn build_json_recursive(&self, nodeidx: usize) -> Result<serde_json::Value> {
229 use serde_json::json;
230
231 if nodeidx < self.n_observations {
232 let name = if let Some(ref labels) = self.labels {
233 labels[nodeidx].clone()
234 } else {
235 nodeidx.to_string()
236 };
237
238 Ok(json!({
239 "name": name,
240 "type": "leaf",
241 "index": nodeidx
242 }))
243 } else {
244 let row_idx = nodeidx - self.n_observations;
245 if row_idx >= self.linkage.nrows() {
246 return Err(ClusteringError::InvalidInput(
247 "Invalid node index".to_string(),
248 ));
249 }
250
251 let left = self.linkage[[row_idx, 0]] as usize;
252 let right = self.linkage[[row_idx, 1]] as usize;
253 let distance = self.linkage[[row_idx, 2]];
254
255 let left_child = self.build_json_recursive(left)?;
256 let right_child = self.build_json_recursive(right)?;
257
258 Ok(json!({
259 "type": "internal",
260 "distance": distance,
261 "children": [left_child, right_child]
262 }))
263 }
264 }
265}
266
/// Serialized DBSCAN clustering result.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DBSCANModel {
    /// Indices of the core samples found during fitting.
    pub core_sample_indices: Array1<usize>,
    /// Per-sample cluster labels; -1 marks noise.
    pub labels: Array1<i32>,
    /// Neighborhood radius used by the fit.
    pub eps: f64,
    /// Minimum neighborhood size for a point to be a core point.
    pub min_samples: usize,
}

/// Enables save/load via the shared serialization helpers.
impl SerializableModel for DBSCANModel {}
281
impl DBSCANModel {
    /// Builds a model from the outputs of a completed DBSCAN fit.
    pub fn new(
        core_sample_indices: Array1<usize>,
        labels: Array1<i32>,
        eps: f64,
        min_samples: usize,
    ) -> Self {
        Self {
            core_sample_indices,
            labels,
            eps,
            min_samples,
        }
    }

    /// NOTE(review): despite the name, this returns the number of NON-NOISE
    /// POINTS (count of labels >= 0), not the number of distinct clusters.
    /// The `ClusteringModel::n_clusters` impl for this type instead returns
    /// `max(label) + 1`. The unit test `test_dbscan_model_clusters` pins the
    /// point-count behavior, so it is documented here rather than changed —
    /// confirm which semantics callers rely on before fixing.
    pub fn n_clusters(&self) -> usize {
        self.labels.iter().filter(|&&label| label >= 0).count()
    }

    /// Indices of samples labeled as noise (label == -1).
    pub fn noise_indices(&self) -> Vec<usize> {
        self.labels
            .iter()
            .enumerate()
            .filter_map(|(i, &label)| if label == -1 { Some(i) } else { None })
            .collect()
    }
}
312
/// Serialized mean-shift clustering result.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MeanShiftModel {
    /// Converged cluster centers, shape (n_clusters, n_features).
    pub cluster_centers: Array2<f64>,
    /// Kernel bandwidth used by the fit.
    pub bandwidth: f64,
    /// Training labels, if they were retained at fit time.
    pub labels: Option<Array1<usize>>,
}

/// Enables save/load via the shared serialization helpers.
impl SerializableModel for MeanShiftModel {}
325
/// Serialized spectral clustering result (spectral embedding plus labels).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SpectralModel {
    /// Eigenvectors from the spectral embedding.
    pub eigenvectors: Array2<f64>,
    /// Corresponding eigenvalues.
    pub eigenvalues: Array1<f64>,
    /// Cluster label per training sample.
    pub labels: Array1<usize>,
    /// Number of clusters.
    pub n_clusters: usize,
    /// Affinity type used to build the similarity graph (e.g. "rbf").
    pub affinity: String,
    /// Kernel coefficient, when the affinity uses one.
    pub gamma: Option<f64>,
}

/// Enables save/load via the shared serialization helpers.
impl SerializableModel for SpectralModel {}
344
345impl SpectralModel {
346 pub fn new(
348 eigenvectors: Array2<f64>,
349 eigenvalues: Array1<f64>,
350 labels: Array1<usize>,
351 n_clusters: usize,
352 affinity: String,
353 gamma: Option<f64>,
354 ) -> Self {
355 Self {
356 eigenvectors,
357 eigenvalues,
358 labels,
359 n_clusters,
360 affinity,
361 gamma,
362 }
363 }
364
365 pub fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
367 let n_samples = data.nrows();
369 let mut labels = Array1::zeros(n_samples);
370
371 for (i, sample) in data.rows().into_iter().enumerate() {
372 let mut best_distance = f64::INFINITY;
373 let mut best_cluster = 0;
374
375 for cluster_id in 0..self.n_clusters {
376 let distance = sample
378 .iter()
379 .zip(
380 self.eigenvectors
381 .row(cluster_id % self.eigenvectors.nrows())
382 .iter(),
383 )
384 .map(|(a, b)| (a - b).powi(2))
385 .sum::<f64>()
386 .sqrt();
387
388 if distance < best_distance {
389 best_distance = distance;
390 best_cluster = cluster_id;
391 }
392 }
393
394 labels[i] = best_cluster;
395 }
396
397 Ok(labels)
398 }
399}
400
/// Common interface shared by the serializable clustering models.
pub trait ClusteringModel: SerializableModel {
    /// Number of clusters the model produces.
    fn n_clusters(&self) -> usize;

    /// Assigns each row of `data` to a cluster.
    fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>>;

    /// JSON summary of the model's key parameters.
    fn summary(&self) -> Result<serde_json::Value>;
}
412
413impl ClusteringModel for KMeansModel {
414 fn n_clusters(&self) -> usize {
415 self.n_clusters
416 }
417
418 fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
419 let n_samples = data.nrows();
421 let mut labels = Array1::zeros(n_samples);
422
423 for (i, sample) in data.axis_iter(scirs2_core::ndarray::Axis(0)).enumerate() {
424 let mut min_dist = f64::INFINITY;
425 let mut best_cluster = 0;
426
427 for (j, centroid) in self
428 .centroids
429 .axis_iter(scirs2_core::ndarray::Axis(0))
430 .enumerate()
431 {
432 let dist: f64 = sample
433 .iter()
434 .zip(centroid.iter())
435 .map(|(a, b)| (a - b).powi(2))
436 .sum::<f64>()
437 .sqrt();
438
439 if dist < min_dist {
440 min_dist = dist;
441 best_cluster = j;
442 }
443 }
444
445 labels[i] = best_cluster;
446 }
447
448 Ok(labels)
449 }
450
451 fn summary(&self) -> Result<serde_json::Value> {
452 Ok(serde_json::json!({
453 "algorithm": "K-Means",
454 "n_clusters": self.n_clusters,
455 "n_features": self.centroids.ncols(),
456 "n_iterations": self.n_iter,
457 "inertia": self.inertia,
458 "has_training_labels": self.labels.is_some()
459 }))
460 }
461}
462
impl ClusteringModel for DBSCANModel {
    /// Number of distinct clusters: highest non-negative label + 1, or 0 when
    /// every sample is noise.
    ///
    /// NOTE(review): this differs from the inherent `DBSCANModel::n_clusters`,
    /// which counts non-noise POINTS rather than distinct clusters — confirm
    /// which semantics callers expect before unifying.
    fn n_clusters(&self) -> usize {
        self.labels
            .iter()
            .filter(|&&x| x >= 0)
            .map(|&x| x as usize)
            .max()
            .map(|x| x + 1)
            .unwrap_or(0)
    }

    /// DBSCAN is transductive: labels exist only for the fitted data, so
    /// prediction on new samples is always an error.
    fn predict(&self, _data: ArrayView2<f64>) -> Result<Array1<usize>> {
        Err(ClusteringError::InvalidInput(
            "DBSCAN does not support prediction on new data. Use fit() instead.".to_string(),
        ))
    }

    /// JSON summary of the fit (cluster/noise counts and parameters).
    fn summary(&self) -> Result<serde_json::Value> {
        let n_clusters = self.n_clusters();
        let n_noise = self.labels.iter().filter(|&&x| x == -1).count();

        Ok(serde_json::json!({
            "algorithm": "DBSCAN",
            "n_clusters": n_clusters,
            "n_core_samples": self.core_sample_indices.len(),
            "n_noise_points": n_noise,
            "eps": self.eps,
            "min_samples": self.min_samples
        }))
    }
}
495
impl ClusteringModel for SpectralModel {
    fn n_clusters(&self) -> usize {
        self.n_clusters
    }

    /// Delegates to the inherent `SpectralModel::predict`. Inherent methods
    /// take precedence over trait methods during resolution, so this is a
    /// delegation, not unconditional recursion.
    fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
        self.predict(data)
    }

    /// JSON summary of the fit (cluster count, embedding size, affinity).
    fn summary(&self) -> Result<serde_json::Value> {
        Ok(serde_json::json!({
            "algorithm": "Spectral Clustering",
            "n_clusters": self.n_clusters,
            "n_eigenvectors": self.eigenvectors.ncols(),
            "affinity": self.affinity,
            "gamma": self.gamma
        }))
    }
}
515
impl MeanShiftModel {
    /// Builds a model from the outputs of a completed mean-shift fit.
    pub fn new(
        cluster_centers: Array2<f64>,
        bandwidth: f64,
        labels: Option<Array1<usize>>,
    ) -> Self {
        Self {
            cluster_centers,
            bandwidth,
            labels,
        }
    }

    /// Number of clusters, i.e. rows of the cluster-center matrix.
    pub fn n_clusters(&self) -> usize {
        self.cluster_centers.nrows()
    }
}
535
/// Serialized leader-algorithm clustering result.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LeaderModel {
    /// One leader node per cluster.
    pub leaders: Vec<LeaderNode<f64>>,
    /// Distance threshold for assigning a point to a leader.
    pub threshold: f64,
    /// Distance metric name; `predict_single` accepts "euclidean" or "manhattan".
    pub metric: String,
}

/// Enables save/load via the shared serialization helpers.
impl SerializableModel for LeaderModel {}
548
549impl LeaderModel {
550 pub fn new(leaders: Vec<LeaderNode<f64>>, threshold: f64, metric: String) -> Self {
552 Self {
553 leaders,
554 threshold,
555 metric,
556 }
557 }
558
559 pub fn n_clusters(&self) -> usize {
561 self.leaders.len()
562 }
563
564 pub fn predict_single(&self, point: &[f64]) -> Result<Option<usize>> {
566 let mut best_leader = None;
567 let mut min_distance = self.threshold;
568
569 for (i, leader) in self.leaders.iter().enumerate() {
570 let distance = match self.metric.as_str() {
571 "euclidean" => point
572 .iter()
573 .zip(leader.leader.iter())
574 .map(|(a, b)| (a - b).powi(2))
575 .sum::<f64>()
576 .sqrt(),
577 "manhattan" => point
578 .iter()
579 .zip(leader.leader.iter())
580 .map(|(a, b)| (a - b).abs())
581 .sum::<f64>(),
582 _ => return Err(ClusteringError::InvalidInput("Unknown metric".to_string())),
583 };
584
585 if distance < min_distance {
586 min_distance = distance;
587 best_leader = Some(i);
588 }
589 }
590
591 Ok(best_leader)
592 }
593}
594
/// Serialized hierarchical leader-tree clustering result, generic over the
/// floating-point element type.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LeaderTreeModel<F: Float> {
    /// The fitted leader tree.
    pub tree: LeaderTree<F>,
    /// Distance threshold used when building the tree.
    pub threshold: F,
    /// Distance metric name.
    pub metric: String,
}

/// Enables save/load; the float type must itself be (de)serializable.
impl<F: Float + Serialize + for<'de> Deserialize<'de>> SerializableModel for LeaderTreeModel<F> {}
607
/// Serialized affinity propagation clustering result.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AffinityPropagationModel {
    /// Exemplar (cluster center) coordinates, one row per cluster.
    pub cluster_centers: Array2<f64>,
    /// Per-sample cluster labels.
    pub labels: Array1<i32>,
    /// Pairwise affinity matrix from the fit.
    pub affinity_matrix: Array2<f64>,
    /// Whether the fit converged.
    pub converged: bool,
    /// Number of iterations run.
    pub n_iter: usize,
}

/// Enables save/load via the shared serialization helpers.
impl SerializableModel for AffinityPropagationModel {}
624
impl AffinityPropagationModel {
    /// Builds a model from the outputs of a completed affinity-propagation fit.
    pub fn new(
        cluster_centers: Array2<f64>,
        labels: Array1<i32>,
        affinity_matrix: Array2<f64>,
        converged: bool,
        n_iter: usize,
    ) -> Self {
        Self {
            cluster_centers,
            labels,
            affinity_matrix,
            converged,
            n_iter,
        }
    }

    /// Number of clusters, i.e. rows of the cluster-center matrix.
    pub fn n_clusters(&self) -> usize {
        self.cluster_centers.nrows()
    }
}
648
/// Serialized BIRCH clustering result.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct BirchModel {
    /// Subcluster centroids, one row each.
    pub centroids: Array2<f64>,
    /// Subcluster radius threshold used by the fit.
    pub threshold: f64,
    /// Maximum number of CF subclusters per tree node.
    pub branching_factor: usize,
    /// Number of subclusters produced.
    pub n_subclusters: usize,
}

/// Enables save/load via the shared serialization helpers.
impl SerializableModel for BirchModel {}
663
impl BirchModel {
    /// Builds a model from the outputs of a completed BIRCH fit.
    pub fn new(
        centroids: Array2<f64>,
        threshold: f64,
        branching_factor: usize,
        n_subclusters: usize,
    ) -> Self {
        Self {
            centroids,
            threshold,
            branching_factor,
            n_subclusters,
        }
    }

    /// Number of clusters, i.e. rows of the centroid matrix.
    pub fn n_clusters(&self) -> usize {
        self.centroids.nrows()
    }
}
685
/// Serialized Gaussian mixture model result.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GMMModel {
    /// Mixture weight per component.
    pub weights: Array1<f64>,
    /// Component means, shape (n_components, n_features).
    pub means: Array2<f64>,
    /// Per-component covariance matrices.
    pub covariances: Vec<Array2<f64>>,
    /// Number of mixture components.
    pub n_components: usize,
    /// Covariance parameterization name (e.g. "full", "diag").
    pub covariance_type: String,
    /// Final log-likelihood reached by the fit.
    pub log_likelihood: f64,
    /// Whether the fit converged.
    pub converged: bool,
    /// Number of iterations run.
    pub n_iter: usize,
}

/// Enables save/load via the shared serialization helpers.
impl SerializableModel for GMMModel {}
708
impl GMMModel {
    /// Builds a model from the outputs of a completed Gaussian-mixture fit.
    pub fn new(
        weights: Array1<f64>,
        means: Array2<f64>,
        covariances: Vec<Array2<f64>>,
        n_components: usize,
        covariance_type: String,
        log_likelihood: f64,
        converged: bool,
        n_iter: usize,
    ) -> Self {
        Self {
            weights,
            means,
            covariances,
            n_components,
            covariance_type,
            log_likelihood,
            converged,
            n_iter,
        }
    }

    /// Soft cluster assignment: returns an (n_samples, n_components) matrix
    /// where each row is normalized to sum to 1 (rows are left all-zero when
    /// every raw score is 0).
    ///
    /// NOTE(review): this is a heuristic, not the true GMM posterior — each
    /// component is scored as `weight * exp(-||x - mean|| / 2)` using plain
    /// Euclidean distance, and `covariances` / `covariance_type` are ignored
    /// entirely. Confirm whether the full Gaussian density is required.
    pub fn predict_proba(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
        let n_samples = data.nrows();
        let mut probabilities = Array2::zeros((n_samples, self.n_components));

        // Raw per-component scores.
        for (i, sample) in data.rows().into_iter().enumerate() {
            for j in 0..self.n_components {
                let mean = self.means.row(j);
                let diff: Vec<f64> = sample.iter().zip(mean.iter()).map(|(a, b)| a - b).collect();

                let distance = diff.iter().map(|x| x * x).sum::<f64>().sqrt();
                probabilities[[i, j]] = self.weights[j] * (-distance / 2.0).exp();
            }
        }

        // Normalize each row into a probability distribution.
        for i in 0..n_samples {
            let sum: f64 = probabilities.row(i).sum();
            if sum > 0.0 {
                for j in 0..self.n_components {
                    probabilities[[i, j]] /= sum;
                }
            }
        }

        Ok(probabilities)
    }
}
762
/// Serialized spectral clustering result including the full affinity matrix
/// (heavier than `SpectralModel`, which stores only the embedding).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SpectralClusteringModel {
    /// Cluster label per training sample.
    pub labels: Array1<usize>,
    /// Pairwise affinity matrix used by the fit.
    pub affinity_matrix: Array2<f64>,
    /// Eigenvalues from the spectral embedding.
    pub eigenvalues: Array1<f64>,
    /// Eigenvectors from the spectral embedding.
    pub eigenvectors: Array2<f64>,
    /// Number of clusters.
    pub n_clusters: usize,
}

/// Enables save/load via the shared serialization helpers.
impl SerializableModel for SpectralClusteringModel {}
779
impl SpectralClusteringModel {
    /// Builds a model from the outputs of a completed spectral clustering fit.
    pub fn new(
        labels: Array1<usize>,
        affinity_matrix: Array2<f64>,
        eigenvalues: Array1<f64>,
        eigenvectors: Array2<f64>,
        n_clusters: usize,
    ) -> Self {
        Self {
            labels,
            affinity_matrix,
            eigenvalues,
            eigenvectors,
            n_clusters,
        }
    }
}
798
799pub fn kmeans_to_model(
803 centroids: Array2<f64>,
804 labels: Option<Array1<usize>>,
805 n_iter: usize,
806 inertia: f64,
807) -> KMeansModel {
808 let n_clusters = centroids.nrows();
809 KMeansModel::new(centroids, n_clusters, n_iter, inertia, labels)
810}
811
812pub fn dbscan_to_model(
814 core_sample_indices: Vec<usize>,
815 components: Array2<f64>,
816 labels: Array1<i32>,
817 eps: f64,
818 min_samples: usize,
819) -> DBSCANModel {
820 DBSCANModel::new(
821 Array1::from_vec(core_sample_indices),
822 labels,
823 eps,
824 min_samples,
825 )
826}
827
828pub fn hierarchy_to_model(
830 n_clusters: usize,
831 labels: Array1<usize>,
832 linkage_matrix: Array2<f64>,
833 distances: Vec<f64>,
834) -> HierarchicalModel {
835 HierarchicalModel::new(linkage_matrix, n_clusters, "ward".to_string(), None)
836}
837
838pub fn gmm_to_model(
840 weights: Array1<f64>,
841 means: Array2<f64>,
842 covariances: Vec<Array2<f64>>,
843 n_components: usize,
844 covariance_type: String,
845 log_likelihood: f64,
846 converged: bool,
847 n_iter: usize,
848) -> GMMModel {
849 GMMModel::new(
850 weights,
851 means,
852 covariances,
853 n_components,
854 covariance_type,
855 log_likelihood,
856 converged,
857 n_iter,
858 )
859}
860
861pub fn meanshift_to_model(
863 cluster_centers: Array2<f64>,
864 labels: Array1<usize>,
865 bandwidth: f64,
866 n_iter: usize,
867) -> MeanShiftModel {
868 MeanShiftModel::new(cluster_centers, bandwidth, Some(labels))
869}
870
871pub fn affinity_propagation_to_model(
873 exemplars: Vec<usize>,
874 labels: Array1<i32>,
875 damping: f64,
876 preference: f64,
877 n_iter: usize,
878) -> AffinityPropagationModel {
879 let n_clusters = exemplars.len();
881 let n_features = if n_clusters > 0 { 2 } else { 0 }; let cluster_centers = Array2::zeros((n_clusters, n_features));
883 let affinity_matrix = Array2::zeros((labels.len(), labels.len()));
884
885 AffinityPropagationModel::new(cluster_centers, labels, affinity_matrix, true, n_iter)
886}
887
888pub fn birch_to_model(
890 centroids: Array2<f64>,
891 threshold: f64,
892 branching_factor: usize,
893 n_subclusters: usize,
894) -> BirchModel {
895 BirchModel::new(centroids, threshold, branching_factor, n_subclusters)
896}
897
898pub fn leader_to_model(
900 leaders: Vec<LeaderNode<f64>>,
901 threshold: f64,
902 distance_metric: String,
903) -> LeaderModel {
904 LeaderModel {
907 leaders,
908 threshold,
909 metric: distance_metric,
910 }
911}
912
913pub fn leadertree_to_model(
915 tree: Option<LeaderTree<f64>>,
916 threshold: f64,
917 max_depth: usize,
918) -> LeaderTreeModel<f64> {
919 LeaderTreeModel {
920 tree: tree.unwrap_or_else(|| LeaderTree {
921 roots: Vec::new(),
922 threshold,
923 }),
924 threshold,
925 metric: "euclidean".to_string(),
926 }
927}
928
929pub fn spectral_clustering_to_model(
931 labels: Array1<usize>,
932 affinity_matrix: Array2<f64>,
933 eigenvalues: Array1<f64>,
934 eigenvectors: Array2<f64>,
935 n_clusters: usize,
936) -> SpectralClusteringModel {
937 SpectralClusteringModel::new(
938 labels,
939 affinity_matrix,
940 eigenvalues,
941 eigenvectors,
942 n_clusters,
943 )
944}
945
/// Serializes a `KMeansModel` to `path` via `SerializableModel::save_to_file`.
pub fn save_kmeans<P: AsRef<std::path::Path>>(model: &KMeansModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Serializes a `DBSCANModel` to `path` via `SerializableModel::save_to_file`.
pub fn save_dbscan<P: AsRef<std::path::Path>>(model: &DBSCANModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Serializes a `HierarchicalModel` to `path` via `SerializableModel::save_to_file`.
pub fn save_hierarchy<P: AsRef<std::path::Path>>(model: &HierarchicalModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Serializes a `GMMModel` to `path` via `SerializableModel::save_to_file`.
pub fn save_gmm<P: AsRef<std::path::Path>>(model: &GMMModel, path: P) -> Result<()> {
    model.save_to_file(path)
}
967
/// Serializes a `MeanShiftModel` to `path` via `SerializableModel::save_to_file`.
pub fn save_meanshift<P: AsRef<std::path::Path>>(model: &MeanShiftModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Builds an `AffinityPropagationModel` from raw fit outputs (see
/// `affinity_propagation_to_model` for the placeholder fields it fills in)
/// and serializes it to `path`.
pub fn save_affinity_propagation<P: AsRef<std::path::Path>>(
    exemplars: Vec<usize>,
    labels: Array1<i32>,
    damping: f64,
    preference: f64,
    n_iter: usize,
    path: P,
) -> Result<()> {
    let model = affinity_propagation_to_model(exemplars, labels, damping, preference, n_iter);
    model.save_to_file(path)
}

/// Serializes a `BirchModel` to `path` via `SerializableModel::save_to_file`.
pub fn save_birch<P: AsRef<std::path::Path>>(model: &BirchModel, path: P) -> Result<()> {
    model.save_to_file(path)
}
990
/// Serializes a `LeaderModel` to `path` via `SerializableModel::save_to_file`.
pub fn save_leader<P: AsRef<std::path::Path>>(model: &LeaderModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Serializes a `LeaderTreeModel` to `path`; the float type must itself be
/// (de)serializable for the model to implement `SerializableModel`.
pub fn save_leadertree<
    F: Float + Serialize + for<'de> serde::Deserialize<'de>,
    P: AsRef<std::path::Path>,
>(
    model: &LeaderTreeModel<F>,
    path: P,
) -> Result<()> {
    model.save_to_file(path)
}

/// Serializes a `SpectralClusteringModel` to `path` via
/// `SerializableModel::save_to_file`.
pub fn save_spectral_clustering<P: AsRef<std::path::Path>>(
    model: &SpectralClusteringModel,
    path: P,
) -> Result<()> {
    model.save_to_file(path)
}
1014
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Each sample should be assigned to its nearest centroid.
    #[test]
    fn test_kmeans_model_predict() {
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let model = KMeansModel::new(centroids, 2, 10, 0.5, None);

        let data = Array2::from_shape_vec((2, 2), vec![0.1, 0.1, 0.9, 0.9]).unwrap();
        let labels = model.predict(data.view()).unwrap();

        // (0.1, 0.1) is nearest centroid 0; (0.9, 0.9) is nearest centroid 1.
        assert_eq!(labels[0], 0);
        assert_eq!(labels[1], 1);
    }

    #[test]
    fn test_dbscan_model_clusters() {
        let core_indices = Array1::from_vec(vec![0, 1, 2]);
        let labels = Array1::from_vec(vec![0, 0, 1, -1]);
        let model = DBSCANModel::new(core_indices, labels, 0.5, 2);

        // NOTE(review): the inherent `n_clusters` counts non-noise POINTS
        // (three labels >= 0 here), not distinct clusters (which would be 2);
        // this assertion pins that behavior.
        assert_eq!(model.n_clusters(), 3);
        assert_eq!(model.noise_indices(), vec![3]);
    }

    /// A single merge of leaves 0 and 1 yields a well-formed rooted Newick string.
    #[test]
    fn test_hierarchical_model_newick() {
        let linkage = Array2::from_shape_vec((1, 3), vec![0.0, 1.0, 0.5]).unwrap();
        let model = HierarchicalModel::new(linkage, 2, "ward".to_string(), None);

        let newick = model.to_newick().unwrap();
        assert!(newick.contains("("));
        assert!(newick.contains(")"));
        assert!(newick.ends_with(";"));
    }
}