Skip to main content

scirs2/
cluster.rs

//! Python bindings for scirs2-cluster.
//!
//! This module provides Python bindings that make scirs2-cluster algorithms
//! accessible from Python through scikit-learn-compatible APIs.

6use pyo3::exceptions::{PyRuntimeError, PyValueError};
7use pyo3::prelude::*;
8use pyo3::types::PyDict;
9
10// NumPy types for Python array interface (scirs2-numpy with native ndarray 0.17)
11use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13// ndarray types from scirs2-core
14use scirs2_core::{Array1, Array2};
15
16// Direct imports from scirs2-cluster (native ndarray 0.17 support)
17use scirs2_cluster::kmeans;
18use scirs2_cluster::{calinski_harabasz_score, davies_bouldin_score, silhouette_score};
19use scirs2_cluster::{normalize, standardize, NormType};
20
21/// Python-compatible K-means clustering implementation
22#[pyclass(name = "KMeans")]
23pub struct PyKMeans {
24    /// Number of clusters
25    n_clusters: usize,
26    /// Maximum iterations
27    max_iter: usize,
28    /// Convergence tolerance
29    tol: f64,
30    /// Random seed
31    random_state: Option<u64>,
32    /// Number of initializations
33    n_init: usize,
34    /// Initialization method
35    init: String,
36    /// Fitted cluster centers
37    cluster_centers_: Option<Vec<Vec<f64>>>,
38    /// Labels of each point
39    labels_: Option<Vec<usize>>,
40    /// Sum of squared distances to centroids
41    inertia_: Option<f64>,
42}
43
44#[pymethods]
45impl PyKMeans {
46    /// Create new K-means clustering instance
47    #[new]
48    #[pyo3(signature = (n_clusters=8, *, init="k-means++", n_init=10, max_iter=300, tol=1e-4, random_state=None))]
49    fn new(
50        n_clusters: usize,
51        init: &str,
52        n_init: usize,
53        max_iter: usize,
54        tol: f64,
55        random_state: Option<u64>,
56    ) -> Self {
57        Self {
58            n_clusters,
59            max_iter,
60            tol,
61            random_state,
62            n_init,
63            init: init.to_string(),
64            cluster_centers_: None,
65            labels_: None,
66            inertia_: None,
67        }
68    }
69
70    /// Fit K-means clustering to data
71    fn fit(&mut self, _py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<()> {
72        let binding = x.readonly();
73        let data = binding.as_array();
74
75        // Run K-means using scirs2_cluster directly
76        let (centroids, inertia) = kmeans(
77            data,
78            self.n_clusters,
79            Some(self.max_iter),
80            Some(self.tol),
81            Some(true), // check_finite
82            self.random_state,
83        )
84        .map_err(|e| PyRuntimeError::new_err(format!("K-means fitting failed: {}", e)))?;
85
86        // Assign labels by finding nearest centroid for each point
87        let n_samples = data.nrows();
88        let mut labels = Vec::with_capacity(n_samples);
89
90        for sample in data.rows() {
91            let mut min_dist = f64::INFINITY;
92            let mut best_cluster = 0;
93
94            for (j, centroid) in centroids.rows().into_iter().enumerate() {
95                let dist: f64 = sample
96                    .iter()
97                    .zip(centroid.iter())
98                    .map(|(a, b)| (a - b).powi(2))
99                    .sum::<f64>()
100                    .sqrt();
101
102                if dist < min_dist {
103                    min_dist = dist;
104                    best_cluster = j;
105                }
106            }
107            labels.push(best_cluster);
108        }
109
110        // Store results
111        self.cluster_centers_ = Some(
112            centroids
113                .rows()
114                .into_iter()
115                .map(|row| row.to_vec())
116                .collect(),
117        );
118        self.labels_ = Some(labels);
119        self.inertia_ = Some(inertia);
120
121        Ok(())
122    }
123
124    /// Fit and predict cluster labels
125    fn fit_predict(
126        &mut self,
127        py: Python,
128        x: &Bound<'_, PyArray2<f64>>,
129    ) -> PyResult<Py<PyArray1<i32>>> {
130        self.fit(py, x)?;
131        self.labels(py)
132    }
133
134    /// Predict cluster labels for new data
135    fn predict(&self, py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray1<i32>>> {
136        if self.cluster_centers_.is_none() {
137            return Err(PyRuntimeError::new_err("Model not fitted yet"));
138        }
139
140        let binding = x.readonly();
141        let data = binding.as_array();
142        let centers = self.cluster_centers_.as_ref().expect("Operation failed");
143
144        let n_samples = data.nrows();
145        let mut labels = Vec::with_capacity(n_samples);
146
147        for sample in data.rows() {
148            let mut min_dist = f64::INFINITY;
149            let mut best_cluster = 0;
150
151            for (j, center) in centers.iter().enumerate() {
152                let dist: f64 = sample
153                    .iter()
154                    .zip(center.iter())
155                    .map(|(a, b)| (a - b).powi(2))
156                    .sum::<f64>()
157                    .sqrt();
158
159                if dist < min_dist {
160                    min_dist = dist;
161                    best_cluster = j;
162                }
163            }
164            labels.push(best_cluster as i32);
165        }
166
167        let labels_array = Array1::from_vec(labels);
168        Ok(labels_array.into_pyarray(py).unbind())
169    }
170
171    /// Get cluster centers
172    #[getter]
173    fn cluster_centers_(&self, py: Python) -> PyResult<Option<Py<PyArray2<f64>>>> {
174        match &self.cluster_centers_ {
175            Some(centers) => {
176                let n_clusters = centers.len();
177                let n_features = centers.first().map(|c| c.len()).unwrap_or(0);
178                let flat: Vec<f64> = centers.iter().flatten().copied().collect();
179                let array = Array2::from_shape_vec((n_clusters, n_features), flat)
180                    .map_err(|e| PyRuntimeError::new_err(format!("Array reshape error: {}", e)))?;
181                Ok(Some(array.into_pyarray(py).unbind()))
182            }
183            None => Ok(None),
184        }
185    }
186
187    /// Get labels
188    #[getter]
189    fn labels(&self, py: Python) -> PyResult<Py<PyArray1<i32>>> {
190        match &self.labels_ {
191            Some(labels) => {
192                let labels_i32: Vec<i32> = labels.iter().map(|&x| x as i32).collect();
193                let array = Array1::from_vec(labels_i32);
194                Ok(array.into_pyarray(py).unbind())
195            }
196            None => Err(PyRuntimeError::new_err("Model not fitted yet")),
197        }
198    }
199
200    /// Get inertia (sum of squared distances to centroids)
201    #[getter]
202    fn inertia_(&self) -> Option<f64> {
203        self.inertia_
204    }
205
206    /// Set parameters
207    fn set_params(&mut self, params: &Bound<'_, PyDict>) -> PyResult<()> {
208        for (key, value) in params.iter() {
209            let key_str: String = key.extract()?;
210            match key_str.as_str() {
211                "n_clusters" => self.n_clusters = value.extract()?,
212                "max_iter" => self.max_iter = value.extract()?,
213                "tol" => self.tol = value.extract()?,
214                "random_state" => self.random_state = value.extract()?,
215                "n_init" => self.n_init = value.extract()?,
216                "init" => self.init = value.extract()?,
217                _ => {
218                    return Err(PyValueError::new_err(format!(
219                        "Unknown parameter: {}",
220                        key_str
221                    )))
222                }
223            }
224        }
225        Ok(())
226    }
227
228    /// Get parameters
229    fn get_params(&self, py: Python, _deep: Option<bool>) -> PyResult<Py<PyAny>> {
230        let dict = PyDict::new(py);
231        dict.set_item("n_clusters", self.n_clusters)?;
232        dict.set_item("max_iter", self.max_iter)?;
233        dict.set_item("tol", self.tol)?;
234        dict.set_item("random_state", self.random_state)?;
235        dict.set_item("n_init", self.n_init)?;
236        dict.set_item("init", &self.init)?;
237        Ok(dict.into_any().unbind())
238    }
239}
240
241/// Calculate silhouette score
242#[pyfunction]
243fn silhouette_score_py(
244    x: &Bound<'_, PyArray2<f64>>,
245    labels: &Bound<'_, PyArray1<i32>>,
246) -> PyResult<f64> {
247    let binding = x.readonly();
248    let data = binding.as_array();
249    let labels_binding = labels.readonly();
250    let labels_arr = labels_binding.as_array();
251
252    let score = silhouette_score(data, labels_arr)
253        .map_err(|e| PyRuntimeError::new_err(format!("Silhouette score failed: {}", e)))?;
254
255    Ok(score)
256}
257
258/// Calculate Davies-Bouldin score
259#[pyfunction]
260fn davies_bouldin_score_py(
261    x: &Bound<'_, PyArray2<f64>>,
262    labels: &Bound<'_, PyArray1<i32>>,
263) -> PyResult<f64> {
264    let binding = x.readonly();
265    let data = binding.as_array();
266    let labels_binding = labels.readonly();
267    let labels_arr = labels_binding.as_array();
268
269    let score = davies_bouldin_score(data, labels_arr)
270        .map_err(|e| PyRuntimeError::new_err(format!("Davies-Bouldin score failed: {}", e)))?;
271
272    Ok(score)
273}
274
275/// Calculate Calinski-Harabasz score
276#[pyfunction]
277fn calinski_harabasz_score_py(
278    x: &Bound<'_, PyArray2<f64>>,
279    labels: &Bound<'_, PyArray1<i32>>,
280) -> PyResult<f64> {
281    let binding = x.readonly();
282    let data = binding.as_array();
283    let labels_binding = labels.readonly();
284    let labels_arr = labels_binding.as_array();
285
286    let score = calinski_harabasz_score(data, labels_arr)
287        .map_err(|e| PyRuntimeError::new_err(format!("Calinski-Harabasz score failed: {}", e)))?;
288
289    Ok(score)
290}
291
292/// Standardize data to zero mean and unit variance
293#[pyfunction]
294fn standardize_py(py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
295    let binding = x.readonly();
296    let data = binding.as_array();
297
298    let result = standardize(data, true)  // check_finite=true
299        .map_err(|e| PyRuntimeError::new_err(format!("Standardization failed: {}", e)))?;
300
301    Ok(result.into_pyarray(py).unbind())
302}
303
304/// Normalize data to unit norm
305#[pyfunction]
306fn normalize_py(
307    py: Python,
308    x: &Bound<'_, PyArray2<f64>>,
309    norm: Option<&str>,
310) -> PyResult<Py<PyArray2<f64>>> {
311    let binding = x.readonly();
312    let data = binding.as_array();
313
314    let norm_type = match norm.unwrap_or("l2") {
315        "l1" => NormType::L1,
316        "l2" => NormType::L2,
317        "max" => NormType::Max,
318        other => {
319            return Err(PyValueError::new_err(format!(
320                "Unknown norm type: {}",
321                other
322            )))
323        }
324    };
325
326    let result =
327        normalize(data, norm_type, true) // check_finite=true
328            .map_err(|e| PyRuntimeError::new_err(format!("Normalization failed: {}", e)))?;
329
330    Ok(result.into_pyarray(py).unbind())
331}
332
333/// Python module registration
334pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
335    // Classes
336    m.add_class::<PyKMeans>()?;
337
338    // Metrics
339    m.add_function(wrap_pyfunction!(silhouette_score_py, m)?)?;
340    m.add_function(wrap_pyfunction!(davies_bouldin_score_py, m)?)?;
341    m.add_function(wrap_pyfunction!(calinski_harabasz_score_py, m)?)?;
342
343    // Preprocessing
344    m.add_function(wrap_pyfunction!(standardize_py, m)?)?;
345    m.add_function(wrap_pyfunction!(normalize_py, m)?)?;
346
347    Ok(())
348}