Skip to main content

scirs2/
cluster.rs

1//! Python bindings for scirs2-cluster
2//!
3//! This module provides Python bindings that make scirs2-cluster algorithms
4//! accessible from Python with scikit-learn compatible APIs.
5
6use pyo3::exceptions::{PyRuntimeError, PyValueError};
7use pyo3::prelude::*;
8use pyo3::types::PyDict;
9
10// NumPy types for Python array interface (scirs2-numpy with native ndarray 0.17)
11use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13// ndarray types from scirs2-core
14use scirs2_core::{Array1, Array2};
15
16// Direct imports from scirs2-cluster (native ndarray 0.17 support)
17use scirs2_cluster::kmeans;
18use scirs2_cluster::{calinski_harabasz_score, davies_bouldin_score, silhouette_score};
19use scirs2_cluster::{normalize, standardize, NormType};
20
/// Python-compatible K-means clustering implementation
///
/// Exposed to Python under the class name `KMeans`. The struct stores the
/// hyper-parameters given to the constructor together with the results of the
/// most recent successful `fit` call. The three trailing-underscore fields are
/// `None` on a freshly constructed instance and are filled in by `fit`.
#[pyclass(name = "KMeans")]
pub struct PyKMeans {
    /// Number of clusters
    n_clusters: usize,
    /// Maximum iterations
    max_iter: usize,
    /// Convergence tolerance
    tol: f64,
    /// Random seed
    random_state: Option<u64>,
    /// Number of initializations
    // NOTE(review): stored and round-tripped by get_params/set_params, but
    // `fit` in this file does not pass it to `kmeans`.
    n_init: usize,
    /// Initialization method
    // NOTE(review): kept as a free-form string; `fit` in this file does not
    // pass it to `kmeans`.
    init: String,
    /// Fitted cluster centers
    // One inner Vec per cluster, each holding that center's coordinates.
    cluster_centers_: Option<Vec<Vec<f64>>>,
    /// Labels of each point
    // Index into `cluster_centers_` for every row seen by `fit`.
    labels_: Option<Vec<usize>>,
    /// Sum of squared distances to centroids
    // Holds the second value returned by `kmeans` in `fit`.
    inertia_: Option<f64>,
}
43
44#[pymethods]
45impl PyKMeans {
46    /// Create new K-means clustering instance
47    #[new]
48    #[pyo3(signature = (n_clusters=8, *, init="k-means++", n_init=10, max_iter=300, tol=1e-4, random_state=None))]
49    fn new(
50        n_clusters: usize,
51        init: &str,
52        n_init: usize,
53        max_iter: usize,
54        tol: f64,
55        random_state: Option<u64>,
56    ) -> Self {
57        Self {
58            n_clusters,
59            max_iter,
60            tol,
61            random_state,
62            n_init,
63            init: init.to_string(),
64            cluster_centers_: None,
65            labels_: None,
66            inertia_: None,
67        }
68    }
69
70    /// Fit K-means clustering to data
71    fn fit(&mut self, _py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<()> {
72        let binding = x.readonly();
73        let data = binding.as_array();
74
75        // Run K-means using scirs2_cluster directly
76        let (centroids, inertia) = kmeans(
77            data,
78            self.n_clusters,
79            Some(self.max_iter),
80            Some(self.tol),
81            Some(true), // check_finite
82            self.random_state,
83        )
84        .map_err(|e| PyRuntimeError::new_err(format!("K-means fitting failed: {}", e)))?;
85
86        // Assign labels by finding nearest centroid for each point
87        let n_samples = data.nrows();
88        let mut labels = Vec::with_capacity(n_samples);
89
90        for sample in data.rows() {
91            let mut min_dist = f64::INFINITY;
92            let mut best_cluster = 0;
93
94            for (j, centroid) in centroids.rows().into_iter().enumerate() {
95                let dist: f64 = sample
96                    .iter()
97                    .zip(centroid.iter())
98                    .map(|(a, b)| (a - b).powi(2))
99                    .sum::<f64>()
100                    .sqrt();
101
102                if dist < min_dist {
103                    min_dist = dist;
104                    best_cluster = j;
105                }
106            }
107            labels.push(best_cluster);
108        }
109
110        // Store results
111        self.cluster_centers_ = Some(
112            centroids
113                .rows()
114                .into_iter()
115                .map(|row| row.to_vec())
116                .collect(),
117        );
118        self.labels_ = Some(labels);
119        self.inertia_ = Some(inertia);
120
121        Ok(())
122    }
123
124    /// Fit and predict cluster labels
125    fn fit_predict(
126        &mut self,
127        py: Python,
128        x: &Bound<'_, PyArray2<f64>>,
129    ) -> PyResult<Py<PyArray1<i32>>> {
130        self.fit(py, x)?;
131        self.labels(py)
132    }
133
134    /// Predict cluster labels for new data
135    fn predict(&self, py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray1<i32>>> {
136        if self.cluster_centers_.is_none() {
137            return Err(PyRuntimeError::new_err("Model not fitted yet"));
138        }
139
140        let binding = x.readonly();
141        let data = binding.as_array();
142        let centers = self.cluster_centers_.as_ref().ok_or_else(|| {
143            pyo3::exceptions::PyRuntimeError::new_err("Model not fitted: call fit() first")
144        })?;
145
146        let n_samples = data.nrows();
147        let mut labels = Vec::with_capacity(n_samples);
148
149        for sample in data.rows() {
150            let mut min_dist = f64::INFINITY;
151            let mut best_cluster = 0;
152
153            for (j, center) in centers.iter().enumerate() {
154                let dist: f64 = sample
155                    .iter()
156                    .zip(center.iter())
157                    .map(|(a, b)| (a - b).powi(2))
158                    .sum::<f64>()
159                    .sqrt();
160
161                if dist < min_dist {
162                    min_dist = dist;
163                    best_cluster = j;
164                }
165            }
166            labels.push(best_cluster as i32);
167        }
168
169        let labels_array = Array1::from_vec(labels);
170        Ok(labels_array.into_pyarray(py).unbind())
171    }
172
173    /// Get cluster centers
174    #[getter]
175    fn cluster_centers_(&self, py: Python) -> PyResult<Option<Py<PyArray2<f64>>>> {
176        match &self.cluster_centers_ {
177            Some(centers) => {
178                let n_clusters = centers.len();
179                let n_features = centers.first().map(|c| c.len()).unwrap_or(0);
180                let flat: Vec<f64> = centers.iter().flatten().copied().collect();
181                let array = Array2::from_shape_vec((n_clusters, n_features), flat)
182                    .map_err(|e| PyRuntimeError::new_err(format!("Array reshape error: {}", e)))?;
183                Ok(Some(array.into_pyarray(py).unbind()))
184            }
185            None => Ok(None),
186        }
187    }
188
189    /// Get labels
190    #[getter]
191    fn labels(&self, py: Python) -> PyResult<Py<PyArray1<i32>>> {
192        match &self.labels_ {
193            Some(labels) => {
194                let labels_i32: Vec<i32> = labels.iter().map(|&x| x as i32).collect();
195                let array = Array1::from_vec(labels_i32);
196                Ok(array.into_pyarray(py).unbind())
197            }
198            None => Err(PyRuntimeError::new_err("Model not fitted yet")),
199        }
200    }
201
    /// Get inertia (sum of squared distances to centroids)
    ///
    /// Returns `None` until `fit` has succeeded. The stored value is the
    /// second element of the tuple returned by `kmeans` during `fit`.
    #[getter]
    fn inertia_(&self) -> Option<f64> {
        self.inertia_
    }
207
208    /// Set parameters
209    fn set_params(&mut self, params: &Bound<'_, PyDict>) -> PyResult<()> {
210        for (key, value) in params.iter() {
211            let key_str: String = key.extract()?;
212            match key_str.as_str() {
213                "n_clusters" => self.n_clusters = value.extract()?,
214                "max_iter" => self.max_iter = value.extract()?,
215                "tol" => self.tol = value.extract()?,
216                "random_state" => self.random_state = value.extract()?,
217                "n_init" => self.n_init = value.extract()?,
218                "init" => self.init = value.extract()?,
219                _ => {
220                    return Err(PyValueError::new_err(format!(
221                        "Unknown parameter: {}",
222                        key_str
223                    )))
224                }
225            }
226        }
227        Ok(())
228    }
229
230    /// Get parameters
231    fn get_params(&self, py: Python, _deep: Option<bool>) -> PyResult<Py<PyAny>> {
232        let dict = PyDict::new(py);
233        dict.set_item("n_clusters", self.n_clusters)?;
234        dict.set_item("max_iter", self.max_iter)?;
235        dict.set_item("tol", self.tol)?;
236        dict.set_item("random_state", self.random_state)?;
237        dict.set_item("n_init", self.n_init)?;
238        dict.set_item("init", &self.init)?;
239        Ok(dict.into_any().unbind())
240    }
241}
242
243/// Calculate silhouette score
244#[pyfunction]
245fn silhouette_score_py(
246    x: &Bound<'_, PyArray2<f64>>,
247    labels: &Bound<'_, PyArray1<i32>>,
248) -> PyResult<f64> {
249    let binding = x.readonly();
250    let data = binding.as_array();
251    let labels_binding = labels.readonly();
252    let labels_arr = labels_binding.as_array();
253
254    let score = silhouette_score(data, labels_arr)
255        .map_err(|e| PyRuntimeError::new_err(format!("Silhouette score failed: {}", e)))?;
256
257    Ok(score)
258}
259
260/// Calculate Davies-Bouldin score
261#[pyfunction]
262fn davies_bouldin_score_py(
263    x: &Bound<'_, PyArray2<f64>>,
264    labels: &Bound<'_, PyArray1<i32>>,
265) -> PyResult<f64> {
266    let binding = x.readonly();
267    let data = binding.as_array();
268    let labels_binding = labels.readonly();
269    let labels_arr = labels_binding.as_array();
270
271    let score = davies_bouldin_score(data, labels_arr)
272        .map_err(|e| PyRuntimeError::new_err(format!("Davies-Bouldin score failed: {}", e)))?;
273
274    Ok(score)
275}
276
277/// Calculate Calinski-Harabasz score
278#[pyfunction]
279fn calinski_harabasz_score_py(
280    x: &Bound<'_, PyArray2<f64>>,
281    labels: &Bound<'_, PyArray1<i32>>,
282) -> PyResult<f64> {
283    let binding = x.readonly();
284    let data = binding.as_array();
285    let labels_binding = labels.readonly();
286    let labels_arr = labels_binding.as_array();
287
288    let score = calinski_harabasz_score(data, labels_arr)
289        .map_err(|e| PyRuntimeError::new_err(format!("Calinski-Harabasz score failed: {}", e)))?;
290
291    Ok(score)
292}
293
294/// Standardize data to zero mean and unit variance
295#[pyfunction]
296fn standardize_py(py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
297    let binding = x.readonly();
298    let data = binding.as_array();
299
300    let result = standardize(data, true)  // check_finite=true
301        .map_err(|e| PyRuntimeError::new_err(format!("Standardization failed: {}", e)))?;
302
303    Ok(result.into_pyarray(py).unbind())
304}
305
306/// Normalize data to unit norm
307#[pyfunction]
308fn normalize_py(
309    py: Python,
310    x: &Bound<'_, PyArray2<f64>>,
311    norm: Option<&str>,
312) -> PyResult<Py<PyArray2<f64>>> {
313    let binding = x.readonly();
314    let data = binding.as_array();
315
316    let norm_type = match norm.unwrap_or("l2") {
317        "l1" => NormType::L1,
318        "l2" => NormType::L2,
319        "max" => NormType::Max,
320        other => {
321            return Err(PyValueError::new_err(format!(
322                "Unknown norm type: {}",
323                other
324            )))
325        }
326    };
327
328    let result =
329        normalize(data, norm_type, true) // check_finite=true
330            .map_err(|e| PyRuntimeError::new_err(format!("Normalization failed: {}", e)))?;
331
332    Ok(result.into_pyarray(py).unbind())
333}
334
335/// Python module registration
336pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
337    // Classes
338    m.add_class::<PyKMeans>()?;
339
340    // Metrics
341    m.add_function(wrap_pyfunction!(silhouette_score_py, m)?)?;
342    m.add_function(wrap_pyfunction!(davies_bouldin_score_py, m)?)?;
343    m.add_function(wrap_pyfunction!(calinski_harabasz_score_py, m)?)?;
344
345    // Preprocessing
346    m.add_function(wrap_pyfunction!(standardize_py, m)?)?;
347    m.add_function(wrap_pyfunction!(normalize_py, m)?)?;
348
349    Ok(())
350}