// scirs2_cluster/serialization/models.rs

1//! Clustering model implementations for serialization
2//!
3//! This module contains serializable model structures for different
4//! clustering algorithms including K-means, DBSCAN, hierarchical clustering, etc.
5
6use crate::error::{ClusteringError, Result};
7use crate::leader::{LeaderNode, LeaderTree};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
9use scirs2_core::numeric::Float;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13use super::core::SerializableModel;
14
/// K-means clustering model that can be serialized
///
/// Holds the fitted centroids plus fit metadata so a trained model can be
/// persisted and later used for nearest-centroid prediction.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KMeansModel {
    /// Cluster centroids, one row per cluster
    pub centroids: Array2<f64>,
    /// Number of clusters
    pub n_clusters: usize,
    /// Number of iterations performed
    pub n_iter: usize,
    /// Sum of squared distances
    pub inertia: f64,
    /// Cluster labels for training data (optional)
    pub labels: Option<Array1<usize>>,
}

// Empty impl: serialization behavior comes from SerializableModel's
// provided methods (see super::core).
impl SerializableModel for KMeansModel {}
31
32impl KMeansModel {
33    /// Create a new K-means model
34    pub fn new(
35        centroids: Array2<f64>,
36        n_clusters: usize,
37        n_iter: usize,
38        inertia: f64,
39        labels: Option<Array1<usize>>,
40    ) -> Self {
41        Self {
42            centroids,
43            n_clusters,
44            n_iter,
45            inertia,
46            labels,
47        }
48    }
49
50    /// Predict cluster labels for new data
51    pub fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
52        let n_samples = data.nrows();
53        let mut labels = Array1::zeros(n_samples);
54
55        for (i, sample) in data.rows().into_iter().enumerate() {
56            let mut min_distance = f64::INFINITY;
57            let mut closest_cluster = 0;
58
59            for (j, centroid) in self.centroids.rows().into_iter().enumerate() {
60                let distance = sample
61                    .iter()
62                    .zip(centroid.iter())
63                    .map(|(a, b)| (a - b).powi(2))
64                    .sum::<f64>()
65                    .sqrt();
66
67                if distance < min_distance {
68                    min_distance = distance;
69                    closest_cluster = j;
70                }
71            }
72
73            labels[i] = closest_cluster;
74        }
75
76        Ok(labels)
77    }
78
79    /// Get the closest cluster center for a single point
80    pub fn predict_single(&self, point: &[f64]) -> Result<usize> {
81        if point.len() != self.centroids.ncols() {
82            return Err(ClusteringError::InvalidInput(
83                "Point dimensions must match centroid dimensions".to_string(),
84            ));
85        }
86
87        let mut min_distance = f64::INFINITY;
88        let mut closest_cluster = 0;
89
90        for (j, centroid) in self.centroids.rows().into_iter().enumerate() {
91            let distance = point
92                .iter()
93                .zip(centroid.iter())
94                .map(|(a, b)| (a - b).powi(2))
95                .sum::<f64>()
96                .sqrt();
97
98            if distance < min_distance {
99                min_distance = distance;
100                closest_cluster = j;
101            }
102        }
103
104        Ok(closest_cluster)
105    }
106}
107
/// Hierarchical clustering result that can be serialized
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HierarchicalModel {
    /// Linkage matrix, one row per merge; columns 0-2 are read as
    /// (left node index, right node index, merge distance)
    pub linkage: Array2<f64>,
    /// Number of original observations (leaves of the dendrogram)
    pub n_observations: usize,
    /// Method used for linkage (e.g. "ward")
    pub method: String,
    /// Dendrogram leaf labels (optional)
    pub labels: Option<Vec<String>>,
}

// Empty impl: uses SerializableModel's provided methods (see super::core).
impl SerializableModel for HierarchicalModel {}
122
123impl HierarchicalModel {
124    /// Create a new hierarchical clustering model
125    pub fn new(
126        linkage: Array2<f64>,
127        n_observations: usize,
128        method: String,
129        labels: Option<Vec<String>>,
130    ) -> Self {
131        Self {
132            linkage,
133            n_observations,
134            method,
135            labels,
136        }
137    }
138
139    /// Export dendrogram to Newick format
140    pub fn to_newick(&self) -> Result<String> {
141        let mut newick = String::new();
142        let nnodes = self.linkage.nrows();
143
144        if nnodes == 0 {
145            return Ok("();".to_string());
146        }
147
148        self.validate_linkage_matrix()?;
149        self.build_newick_recursive(nnodes + self.n_observations - 1, &mut newick)?;
150
151        newick.push(';');
152        Ok(newick)
153    }
154
155    /// Validate linkage matrix for consistency
156    fn validate_linkage_matrix(&self) -> Result<()> {
157        let nnodes = self.linkage.nrows();
158
159        for i in 0..nnodes {
160            let left = self.linkage[[i, 0]] as usize;
161            let right = self.linkage[[i, 1]] as usize;
162            let distance = self.linkage[[i, 2]];
163
164            if left >= self.n_observations + i || right >= self.n_observations + i {
165                return Err(ClusteringError::InvalidInput(format!(
166                    "Invalid node indices in linkage matrix at row {}: left={}, right={}",
167                    i, left, right
168                )));
169            }
170
171            if distance < 0.0 {
172                return Err(ClusteringError::InvalidInput(format!(
173                    "Negative distance in linkage matrix at row {}: {}",
174                    i, distance
175                )));
176            }
177        }
178
179        Ok(())
180    }
181
182    /// Build Newick string recursively
183    fn build_newick_recursive(&self, nodeidx: usize, newick: &mut String) -> Result<()> {
184        if nodeidx < self.n_observations {
185            if let Some(ref labels) = self.labels {
186                newick.push_str(&labels[nodeidx]);
187            } else {
188                newick.push_str(&nodeidx.to_string());
189            }
190        } else {
191            let row_idx = nodeidx - self.n_observations;
192            if row_idx >= self.linkage.nrows() {
193                return Err(ClusteringError::InvalidInput(
194                    "Invalid node index".to_string(),
195                ));
196            }
197
198            let left = self.linkage[[row_idx, 0]] as usize;
199            let right = self.linkage[[row_idx, 1]] as usize;
200            let distance = self.linkage[[row_idx, 2]];
201
202            newick.push('(');
203            self.build_newick_recursive(left, newick)?;
204            newick.push(':');
205            newick.push_str(&format!("{:.6}", distance / 2.0));
206            newick.push(',');
207            self.build_newick_recursive(right, newick)?;
208            newick.push(':');
209            newick.push_str(&format!("{:.6}", distance / 2.0));
210            newick.push(')');
211        }
212
213        Ok(())
214    }
215
216    /// Export dendrogram to JSON format
217    pub fn to_json_tree(&self) -> Result<serde_json::Value> {
218        use serde_json::json;
219
220        let nnodes = self.linkage.nrows();
221        if nnodes == 0 {
222            return Ok(json!({}));
223        }
224
225        self.build_json_recursive(nnodes + self.n_observations - 1)
226    }
227
228    fn build_json_recursive(&self, nodeidx: usize) -> Result<serde_json::Value> {
229        use serde_json::json;
230
231        if nodeidx < self.n_observations {
232            let name = if let Some(ref labels) = self.labels {
233                labels[nodeidx].clone()
234            } else {
235                nodeidx.to_string()
236            };
237
238            Ok(json!({
239                "name": name,
240                "type": "leaf",
241                "index": nodeidx
242            }))
243        } else {
244            let row_idx = nodeidx - self.n_observations;
245            if row_idx >= self.linkage.nrows() {
246                return Err(ClusteringError::InvalidInput(
247                    "Invalid node index".to_string(),
248                ));
249            }
250
251            let left = self.linkage[[row_idx, 0]] as usize;
252            let right = self.linkage[[row_idx, 1]] as usize;
253            let distance = self.linkage[[row_idx, 2]];
254
255            let left_child = self.build_json_recursive(left)?;
256            let right_child = self.build_json_recursive(right)?;
257
258            Ok(json!({
259                "type": "internal",
260                "distance": distance,
261                "children": [left_child, right_child]
262            }))
263        }
264    }
265}
266
/// DBSCAN model that can be serialized
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DBSCANModel {
    /// Core sample indices
    pub core_sample_indices: Array1<usize>,
    /// Cluster labels; -1 marks noise points
    pub labels: Array1<i32>,
    /// Epsilon parameter
    pub eps: f64,
    /// Min samples parameter
    pub min_samples: usize,
}

// Empty impl: uses SerializableModel's provided methods (see super::core).
impl SerializableModel for DBSCANModel {}
281
282impl DBSCANModel {
283    /// Create a new DBSCAN model
284    pub fn new(
285        core_sample_indices: Array1<usize>,
286        labels: Array1<i32>,
287        eps: f64,
288        min_samples: usize,
289    ) -> Self {
290        Self {
291            core_sample_indices,
292            labels,
293            eps,
294            min_samples,
295        }
296    }
297
298    /// Get number of clusters (excluding noise)
299    pub fn n_clusters(&self) -> usize {
300        self.labels.iter().filter(|&&label| label >= 0).count()
301    }
302
303    /// Get noise point indices
304    pub fn noise_indices(&self) -> Vec<usize> {
305        self.labels
306            .iter()
307            .enumerate()
308            .filter_map(|(i, &label)| if label == -1 { Some(i) } else { None })
309            .collect()
310    }
311}
312
/// Mean Shift model that can be serialized
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MeanShiftModel {
    /// Cluster centers, one row per cluster
    pub cluster_centers: Array2<f64>,
    /// Bandwidth parameter
    pub bandwidth: f64,
    /// Cluster labels (optional)
    pub labels: Option<Array1<usize>>,
}

// Empty impl: uses SerializableModel's provided methods (see super::core).
impl SerializableModel for MeanShiftModel {}
325
/// Spectral clustering model that can be serialized
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SpectralModel {
    /// Eigenvectors used for clustering
    pub eigenvectors: Array2<f64>,
    /// Eigenvalues
    pub eigenvalues: Array1<f64>,
    /// Cluster labels
    pub labels: Array1<usize>,
    /// Number of clusters
    pub n_clusters: usize,
    /// Affinity matrix parameters (name of the affinity used)
    pub affinity: String,
    /// Gamma parameter for RBF kernel
    pub gamma: Option<f64>,
}

// Empty impl: uses SerializableModel's provided methods (see super::core).
impl SerializableModel for SpectralModel {}
344
345impl SpectralModel {
346    /// Create a new spectral clustering model
347    pub fn new(
348        eigenvectors: Array2<f64>,
349        eigenvalues: Array1<f64>,
350        labels: Array1<usize>,
351        n_clusters: usize,
352        affinity: String,
353        gamma: Option<f64>,
354    ) -> Self {
355        Self {
356            eigenvectors,
357            eigenvalues,
358            labels,
359            n_clusters,
360            affinity,
361            gamma,
362        }
363    }
364
365    /// Predict cluster labels for new data
366    pub fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
367        // Simple prediction based on closest eigenvector projection
368        let n_samples = data.nrows();
369        let mut labels = Array1::zeros(n_samples);
370
371        for (i, sample) in data.rows().into_iter().enumerate() {
372            let mut best_distance = f64::INFINITY;
373            let mut best_cluster = 0;
374
375            for cluster_id in 0..self.n_clusters {
376                // Simple distance to cluster center in eigenspace
377                let distance = sample
378                    .iter()
379                    .zip(
380                        self.eigenvectors
381                            .row(cluster_id % self.eigenvectors.nrows())
382                            .iter(),
383                    )
384                    .map(|(a, b)| (a - b).powi(2))
385                    .sum::<f64>()
386                    .sqrt();
387
388                if distance < best_distance {
389                    best_distance = distance;
390                    best_cluster = cluster_id;
391                }
392            }
393
394            labels[i] = best_cluster;
395        }
396
397        Ok(labels)
398    }
399}
400
/// Generic clustering model trait
///
/// Common interface over the serializable models: cluster counting,
/// prediction on new data, and a JSON summary of fit metadata.
pub trait ClusteringModel: SerializableModel {
    /// Get the number of clusters
    fn n_clusters(&self) -> usize;

    /// Predict cluster labels for new data
    ///
    /// Implementations may return an error when prediction on unseen data is
    /// not supported (e.g. DBSCAN).
    fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>>;

    /// Get model summary as JSON
    fn summary(&self) -> Result<serde_json::Value>;
}
412
413impl ClusteringModel for KMeansModel {
414    fn n_clusters(&self) -> usize {
415        self.n_clusters
416    }
417
418    fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
419        // Find nearest centroid for each data point
420        let n_samples = data.nrows();
421        let mut labels = Array1::zeros(n_samples);
422
423        for (i, sample) in data.axis_iter(scirs2_core::ndarray::Axis(0)).enumerate() {
424            let mut min_dist = f64::INFINITY;
425            let mut best_cluster = 0;
426
427            for (j, centroid) in self
428                .centroids
429                .axis_iter(scirs2_core::ndarray::Axis(0))
430                .enumerate()
431            {
432                let dist: f64 = sample
433                    .iter()
434                    .zip(centroid.iter())
435                    .map(|(a, b)| (a - b).powi(2))
436                    .sum::<f64>()
437                    .sqrt();
438
439                if dist < min_dist {
440                    min_dist = dist;
441                    best_cluster = j;
442                }
443            }
444
445            labels[i] = best_cluster;
446        }
447
448        Ok(labels)
449    }
450
451    fn summary(&self) -> Result<serde_json::Value> {
452        Ok(serde_json::json!({
453            "algorithm": "K-Means",
454            "n_clusters": self.n_clusters,
455            "n_features": self.centroids.ncols(),
456            "n_iterations": self.n_iter,
457            "inertia": self.inertia,
458            "has_training_labels": self.labels.is_some()
459        }))
460    }
461}
462
463impl ClusteringModel for DBSCANModel {
464    fn n_clusters(&self) -> usize {
465        self.labels
466            .iter()
467            .filter(|&&x| x >= 0)
468            .map(|&x| x as usize)
469            .max()
470            .map(|x| x + 1)
471            .unwrap_or(0)
472    }
473
474    fn predict(&self, _data: ArrayView2<f64>) -> Result<Array1<usize>> {
475        // DBSCAN doesn't support prediction on new data without re-running the algorithm
476        Err(ClusteringError::InvalidInput(
477            "DBSCAN does not support prediction on new data. Use fit() instead.".to_string(),
478        ))
479    }
480
481    fn summary(&self) -> Result<serde_json::Value> {
482        let n_clusters = self.n_clusters();
483        let n_noise = self.labels.iter().filter(|&&x| x == -1).count();
484
485        Ok(serde_json::json!({
486            "algorithm": "DBSCAN",
487            "n_clusters": n_clusters,
488            "n_core_samples": self.core_sample_indices.len(),
489            "n_noise_points": n_noise,
490            "eps": self.eps,
491            "min_samples": self.min_samples
492        }))
493    }
494}
495
496impl ClusteringModel for SpectralModel {
497    fn n_clusters(&self) -> usize {
498        self.n_clusters
499    }
500
501    fn predict(&self, data: ArrayView2<f64>) -> Result<Array1<usize>> {
502        self.predict(data)
503    }
504
505    fn summary(&self) -> Result<serde_json::Value> {
506        Ok(serde_json::json!({
507            "algorithm": "Spectral Clustering",
508            "n_clusters": self.n_clusters,
509            "n_eigenvectors": self.eigenvectors.ncols(),
510            "affinity": self.affinity,
511            "gamma": self.gamma
512        }))
513    }
514}
515
516impl MeanShiftModel {
517    /// Create a new Mean Shift model
518    pub fn new(
519        cluster_centers: Array2<f64>,
520        bandwidth: f64,
521        labels: Option<Array1<usize>>,
522    ) -> Self {
523        Self {
524            cluster_centers,
525            bandwidth,
526            labels,
527        }
528    }
529
530    /// Get number of clusters
531    pub fn n_clusters(&self) -> usize {
532        self.cluster_centers.nrows()
533    }
534}
535
/// Leader clustering model that can be serialized
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LeaderModel {
    /// Leader nodes, one per cluster
    pub leaders: Vec<LeaderNode<f64>>,
    /// Distance threshold below which a point joins a leader
    pub threshold: f64,
    /// Distance metric used ("euclidean" or "manhattan")
    pub metric: String,
}

// Empty impl: uses SerializableModel's provided methods (see super::core).
impl SerializableModel for LeaderModel {}
548
549impl LeaderModel {
550    /// Create a new Leader model
551    pub fn new(leaders: Vec<LeaderNode<f64>>, threshold: f64, metric: String) -> Self {
552        Self {
553            leaders,
554            threshold,
555            metric,
556        }
557    }
558
559    /// Get number of clusters
560    pub fn n_clusters(&self) -> usize {
561        self.leaders.len()
562    }
563
564    /// Predict cluster for a new point
565    pub fn predict_single(&self, point: &[f64]) -> Result<Option<usize>> {
566        let mut best_leader = None;
567        let mut min_distance = self.threshold;
568
569        for (i, leader) in self.leaders.iter().enumerate() {
570            let distance = match self.metric.as_str() {
571                "euclidean" => point
572                    .iter()
573                    .zip(leader.leader.iter())
574                    .map(|(a, b)| (a - b).powi(2))
575                    .sum::<f64>()
576                    .sqrt(),
577                "manhattan" => point
578                    .iter()
579                    .zip(leader.leader.iter())
580                    .map(|(a, b)| (a - b).abs())
581                    .sum::<f64>(),
582                _ => return Err(ClusteringError::InvalidInput("Unknown metric".to_string())),
583            };
584
585            if distance < min_distance {
586                min_distance = distance;
587                best_leader = Some(i);
588            }
589        }
590
591        Ok(best_leader)
592    }
593}
594
/// Leader Tree clustering model that can be serialized
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LeaderTreeModel<F: Float> {
    /// Root of the leader tree
    pub tree: LeaderTree<F>,
    /// Threshold parameter
    pub threshold: F,
    /// Distance metric used
    pub metric: String,
}

// The element type must round-trip through serde for the model itself to be
// serializable.
impl<F: Float + Serialize + for<'de> Deserialize<'de>> SerializableModel for LeaderTreeModel<F> {}
607
/// Affinity Propagation model that can be serialized
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AffinityPropagationModel {
    /// Cluster centers (exemplars), one row per cluster
    pub cluster_centers: Array2<f64>,
    /// Cluster labels
    pub labels: Array1<i32>,
    /// Affinity matrix
    pub affinity_matrix: Array2<f64>,
    /// Whether the algorithm converged
    pub converged: bool,
    /// Number of iterations
    pub n_iter: usize,
}

// Empty impl: uses SerializableModel's provided methods (see super::core).
impl SerializableModel for AffinityPropagationModel {}
624
625impl AffinityPropagationModel {
626    /// Create a new Affinity Propagation model
627    pub fn new(
628        cluster_centers: Array2<f64>,
629        labels: Array1<i32>,
630        affinity_matrix: Array2<f64>,
631        converged: bool,
632        n_iter: usize,
633    ) -> Self {
634        Self {
635            cluster_centers,
636            labels,
637            affinity_matrix,
638            converged,
639            n_iter,
640        }
641    }
642
643    /// Get number of clusters
644    pub fn n_clusters(&self) -> usize {
645        self.cluster_centers.nrows()
646    }
647}
648
/// BIRCH clustering model that can be serialized
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct BirchModel {
    /// Cluster centroids, one row per cluster
    pub centroids: Array2<f64>,
    /// Threshold parameter
    pub threshold: f64,
    /// Branching factor
    pub branching_factor: usize,
    /// Number of subclusters
    pub n_subclusters: usize,
}

// Empty impl: uses SerializableModel's provided methods (see super::core).
impl SerializableModel for BirchModel {}
663
664impl BirchModel {
665    /// Create a new BIRCH model
666    pub fn new(
667        centroids: Array2<f64>,
668        threshold: f64,
669        branching_factor: usize,
670        n_subclusters: usize,
671    ) -> Self {
672        Self {
673            centroids,
674            threshold,
675            branching_factor,
676            n_subclusters,
677        }
678    }
679
680    /// Get number of clusters
681    pub fn n_clusters(&self) -> usize {
682        self.centroids.nrows()
683    }
684}
685
/// Gaussian Mixture Model that can be serialized
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GMMModel {
    /// Mixture weights
    pub weights: Array1<f64>,
    /// Component means, one row per component
    pub means: Array2<f64>,
    /// Component covariances (presumably one matrix per component —
    /// not consumed by `predict_proba`, so confirm with the fitting code)
    pub covariances: Vec<Array2<f64>>,
    /// Number of components
    pub n_components: usize,
    /// Covariance type
    pub covariance_type: String,
    /// Log-likelihood
    pub log_likelihood: f64,
    /// Converged flag
    pub converged: bool,
    /// Number of iterations
    pub n_iter: usize,
}

// Empty impl: uses SerializableModel's provided methods (see super::core).
impl SerializableModel for GMMModel {}
708
709impl GMMModel {
710    /// Create a new GMM model
711    pub fn new(
712        weights: Array1<f64>,
713        means: Array2<f64>,
714        covariances: Vec<Array2<f64>>,
715        n_components: usize,
716        covariance_type: String,
717        log_likelihood: f64,
718        converged: bool,
719        n_iter: usize,
720    ) -> Self {
721        Self {
722            weights,
723            means,
724            covariances,
725            n_components,
726            covariance_type,
727            log_likelihood,
728            converged,
729            n_iter,
730        }
731    }
732
733    /// Predict cluster probabilities for new data
734    pub fn predict_proba(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
735        let n_samples = data.nrows();
736        let mut probabilities = Array2::zeros((n_samples, self.n_components));
737
738        for (i, sample) in data.rows().into_iter().enumerate() {
739            for j in 0..self.n_components {
740                let mean = self.means.row(j);
741                let diff: Vec<f64> = sample.iter().zip(mean.iter()).map(|(a, b)| a - b).collect();
742
743                // Simplified probability calculation (would need proper multivariate normal)
744                let distance = diff.iter().map(|x| x * x).sum::<f64>().sqrt();
745                probabilities[[i, j]] = self.weights[j] * (-distance / 2.0).exp();
746            }
747        }
748
749        // Normalize probabilities
750        for i in 0..n_samples {
751            let sum: f64 = probabilities.row(i).sum();
752            if sum > 0.0 {
753                for j in 0..self.n_components {
754                    probabilities[[i, j]] /= sum;
755                }
756            }
757        }
758
759        Ok(probabilities)
760    }
761}
762
/// Spectral clustering model that can be serialized
///
/// NOTE(review): overlaps with `SpectralModel` above but additionally stores
/// the affinity matrix; confirm whether both types are still needed.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SpectralClusteringModel {
    /// Cluster labels
    pub labels: Array1<usize>,
    /// Affinity matrix
    pub affinity_matrix: Array2<f64>,
    /// Eigenvalues
    pub eigenvalues: Array1<f64>,
    /// Eigenvectors
    pub eigenvectors: Array2<f64>,
    /// Number of clusters
    pub n_clusters: usize,
}

// Empty impl: uses SerializableModel's provided methods (see super::core).
impl SerializableModel for SpectralClusteringModel {}
779
780impl SpectralClusteringModel {
781    /// Create a new Spectral clustering model
782    pub fn new(
783        labels: Array1<usize>,
784        affinity_matrix: Array2<f64>,
785        eigenvalues: Array1<f64>,
786        eigenvectors: Array2<f64>,
787        n_clusters: usize,
788    ) -> Self {
789        Self {
790            labels,
791            affinity_matrix,
792            eigenvalues,
793            eigenvectors,
794            n_clusters,
795        }
796    }
797}
798
799// Conversion functions for creating models from algorithm results
800
801/// Convert K-means results to serializable model
802pub fn kmeans_to_model(
803    centroids: Array2<f64>,
804    labels: Option<Array1<usize>>,
805    n_iter: usize,
806    inertia: f64,
807) -> KMeansModel {
808    let n_clusters = centroids.nrows();
809    KMeansModel::new(centroids, n_clusters, n_iter, inertia, labels)
810}
811
812/// Convert DBSCAN results to serializable model
813pub fn dbscan_to_model(
814    core_sample_indices: Vec<usize>,
815    components: Array2<f64>,
816    labels: Array1<i32>,
817    eps: f64,
818    min_samples: usize,
819) -> DBSCANModel {
820    DBSCANModel::new(
821        Array1::from_vec(core_sample_indices),
822        labels,
823        eps,
824        min_samples,
825    )
826}
827
828/// Convert hierarchical clustering results to serializable model
829pub fn hierarchy_to_model(
830    n_clusters: usize,
831    labels: Array1<usize>,
832    linkage_matrix: Array2<f64>,
833    distances: Vec<f64>,
834) -> HierarchicalModel {
835    HierarchicalModel::new(linkage_matrix, n_clusters, "ward".to_string(), None)
836}
837
838/// Convert GMM results to serializable model
839pub fn gmm_to_model(
840    weights: Array1<f64>,
841    means: Array2<f64>,
842    covariances: Vec<Array2<f64>>,
843    n_components: usize,
844    covariance_type: String,
845    log_likelihood: f64,
846    converged: bool,
847    n_iter: usize,
848) -> GMMModel {
849    GMMModel::new(
850        weights,
851        means,
852        covariances,
853        n_components,
854        covariance_type,
855        log_likelihood,
856        converged,
857        n_iter,
858    )
859}
860
861/// Convert Mean Shift results to serializable model
862pub fn meanshift_to_model(
863    cluster_centers: Array2<f64>,
864    labels: Array1<usize>,
865    bandwidth: f64,
866    n_iter: usize,
867) -> MeanShiftModel {
868    MeanShiftModel::new(cluster_centers, bandwidth, Some(labels))
869}
870
871/// Convert Affinity Propagation results to serializable model
872pub fn affinity_propagation_to_model(
873    exemplars: Vec<usize>,
874    labels: Array1<i32>,
875    damping: f64,
876    preference: f64,
877    n_iter: usize,
878) -> AffinityPropagationModel {
879    // Extract cluster centers from exemplars
880    let n_clusters = exemplars.len();
881    let n_features = if n_clusters > 0 { 2 } else { 0 }; // Default assumption
882    let cluster_centers = Array2::zeros((n_clusters, n_features));
883    let affinity_matrix = Array2::zeros((labels.len(), labels.len()));
884
885    AffinityPropagationModel::new(cluster_centers, labels, affinity_matrix, true, n_iter)
886}
887
888/// Convert BIRCH results to serializable model
889pub fn birch_to_model(
890    centroids: Array2<f64>,
891    threshold: f64,
892    branching_factor: usize,
893    n_subclusters: usize,
894) -> BirchModel {
895    BirchModel::new(centroids, threshold, branching_factor, n_subclusters)
896}
897
898/// Convert Leader clustering results to serializable model
899pub fn leader_to_model(
900    leaders: Vec<LeaderNode<f64>>,
901    threshold: f64,
902    distance_metric: String,
903) -> LeaderModel {
904    // Convert LeaderNode to LeaderNode<f64> if needed, or use directly
905
906    LeaderModel {
907        leaders,
908        threshold,
909        metric: distance_metric,
910    }
911}
912
913/// Convert Leader Tree results to serializable model
914pub fn leadertree_to_model(
915    tree: Option<LeaderTree<f64>>,
916    threshold: f64,
917    max_depth: usize,
918) -> LeaderTreeModel<f64> {
919    LeaderTreeModel {
920        tree: tree.unwrap_or_else(|| LeaderTree {
921            roots: Vec::new(),
922            threshold,
923        }),
924        threshold,
925        metric: "euclidean".to_string(),
926    }
927}
928
929/// Convert Spectral clustering results to serializable model
930pub fn spectral_clustering_to_model(
931    labels: Array1<usize>,
932    affinity_matrix: Array2<f64>,
933    eigenvalues: Array1<f64>,
934    eigenvectors: Array2<f64>,
935    n_clusters: usize,
936) -> SpectralClusteringModel {
937    SpectralClusteringModel::new(
938        labels,
939        affinity_matrix,
940        eigenvalues,
941        eigenvectors,
942        n_clusters,
943    )
944}
945
946// Save functions for convenience
947
/// Save K-means model to file
///
/// # Errors
/// Propagates any error from `save_to_file`.
pub fn save_kmeans<P: AsRef<std::path::Path>>(model: &KMeansModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Save DBSCAN model to file
///
/// # Errors
/// Propagates any error from `save_to_file`.
pub fn save_dbscan<P: AsRef<std::path::Path>>(model: &DBSCANModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Save hierarchical clustering model to file
///
/// # Errors
/// Propagates any error from `save_to_file`.
pub fn save_hierarchy<P: AsRef<std::path::Path>>(model: &HierarchicalModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Save GMM model to file
///
/// # Errors
/// Propagates any error from `save_to_file`.
pub fn save_gmm<P: AsRef<std::path::Path>>(model: &GMMModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Save Mean Shift model to file
///
/// # Errors
/// Propagates any error from `save_to_file`.
pub fn save_meanshift<P: AsRef<std::path::Path>>(model: &MeanShiftModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Save Affinity Propagation model to file
///
/// Builds the model from raw results (see `affinity_propagation_to_model`)
/// and then persists it.
///
/// # Errors
/// Propagates any error from `save_to_file`.
pub fn save_affinity_propagation<P: AsRef<std::path::Path>>(
    exemplars: Vec<usize>,
    labels: Array1<i32>,
    damping: f64,
    preference: f64,
    n_iter: usize,
    path: P,
) -> Result<()> {
    let model = affinity_propagation_to_model(exemplars, labels, damping, preference, n_iter);
    model.save_to_file(path)
}

/// Save BIRCH model to file
///
/// # Errors
/// Propagates any error from `save_to_file`.
pub fn save_birch<P: AsRef<std::path::Path>>(model: &BirchModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Save Leader clustering model to file
///
/// # Errors
/// Propagates any error from `save_to_file`.
pub fn save_leader<P: AsRef<std::path::Path>>(model: &LeaderModel, path: P) -> Result<()> {
    model.save_to_file(path)
}

/// Save Leader Tree model to file
///
/// # Errors
/// Propagates any error from `save_to_file`.
pub fn save_leadertree<
    F: Float + Serialize + for<'de> serde::Deserialize<'de>,
    P: AsRef<std::path::Path>,
>(
    model: &LeaderTreeModel<F>,
    path: P,
) -> Result<()> {
    model.save_to_file(path)
}

/// Save Spectral clustering model to file
///
/// # Errors
/// Propagates any error from `save_to_file`.
pub fn save_spectral_clustering<P: AsRef<std::path::Path>>(
    model: &SpectralClusteringModel,
    path: P,
) -> Result<()> {
    model.save_to_file(path)
}
1014
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // Nearest-centroid assignment: two well-separated points should map to
    // their respective centroids.
    #[test]
    fn test_kmeans_model_predict() {
        let centroids = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let model = KMeansModel::new(centroids, 2, 10, 0.5, None);

        let data = Array2::from_shape_vec((2, 2), vec![0.1, 0.1, 0.9, 0.9]).unwrap();
        let labels = model.predict(data.view()).unwrap();

        assert_eq!(labels[0], 0); // Closer to first centroid
        assert_eq!(labels[1], 1); // Closer to second centroid
    }

    // NOTE(review): this pins the inherent `DBSCANModel::n_clusters`, which
    // counts non-noise *points* (3 here), not distinct clusters (which would
    // be 2). The `ClusteringModel` trait impl reports max-label + 1 instead —
    // confirm the intended semantics before relying on either.
    #[test]
    fn test_dbscan_model_clusters() {
        let core_indices = Array1::from_vec(vec![0, 1, 2]);
        let labels = Array1::from_vec(vec![0, 0, 1, -1]);
        let model = DBSCANModel::new(core_indices, labels, 0.5, 2);

        assert_eq!(model.n_clusters(), 3); // Points with labels 0, 0, 1 (excluding -1)
        assert_eq!(model.noise_indices(), vec![3]); // Point with label -1
    }

    // A single merge of two leaves must yield a well-formed Newick string.
    #[test]
    fn test_hierarchical_model_newick() {
        let linkage = Array2::from_shape_vec((1, 3), vec![0.0, 1.0, 0.5]).unwrap();
        let model = HierarchicalModel::new(linkage, 2, "ward".to_string(), None);

        let newick = model.to_newick().unwrap();
        assert!(newick.contains("("));
        assert!(newick.contains(")"));
        assert!(newick.ends_with(";"));
    }
}