1use pyo3::exceptions::{PyRuntimeError, PyValueError};
7use pyo3::prelude::*;
8use pyo3::types::PyDict;
9
10use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13use scirs2_core::{Array1, Array2};
15
16use scirs2_cluster::kmeans;
18use scirs2_cluster::{calinski_harabasz_score, davies_bouldin_score, silhouette_score};
19use scirs2_cluster::{normalize, standardize, NormType};
20
21#[pyclass(name = "KMeans")]
23pub struct PyKMeans {
24 n_clusters: usize,
26 max_iter: usize,
28 tol: f64,
30 random_state: Option<u64>,
32 n_init: usize,
34 init: String,
36 cluster_centers_: Option<Vec<Vec<f64>>>,
38 labels_: Option<Vec<usize>>,
40 inertia_: Option<f64>,
42}
43
44#[pymethods]
45impl PyKMeans {
46 #[new]
48 #[pyo3(signature = (n_clusters=8, *, init="k-means++", n_init=10, max_iter=300, tol=1e-4, random_state=None))]
49 fn new(
50 n_clusters: usize,
51 init: &str,
52 n_init: usize,
53 max_iter: usize,
54 tol: f64,
55 random_state: Option<u64>,
56 ) -> Self {
57 Self {
58 n_clusters,
59 max_iter,
60 tol,
61 random_state,
62 n_init,
63 init: init.to_string(),
64 cluster_centers_: None,
65 labels_: None,
66 inertia_: None,
67 }
68 }
69
70 fn fit(&mut self, _py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<()> {
72 let binding = x.readonly();
73 let data = binding.as_array();
74
75 let (centroids, inertia) = kmeans(
77 data,
78 self.n_clusters,
79 Some(self.max_iter),
80 Some(self.tol),
81 Some(true), self.random_state,
83 )
84 .map_err(|e| PyRuntimeError::new_err(format!("K-means fitting failed: {}", e)))?;
85
86 let n_samples = data.nrows();
88 let mut labels = Vec::with_capacity(n_samples);
89
90 for sample in data.rows() {
91 let mut min_dist = f64::INFINITY;
92 let mut best_cluster = 0;
93
94 for (j, centroid) in centroids.rows().into_iter().enumerate() {
95 let dist: f64 = sample
96 .iter()
97 .zip(centroid.iter())
98 .map(|(a, b)| (a - b).powi(2))
99 .sum::<f64>()
100 .sqrt();
101
102 if dist < min_dist {
103 min_dist = dist;
104 best_cluster = j;
105 }
106 }
107 labels.push(best_cluster);
108 }
109
110 self.cluster_centers_ = Some(
112 centroids
113 .rows()
114 .into_iter()
115 .map(|row| row.to_vec())
116 .collect(),
117 );
118 self.labels_ = Some(labels);
119 self.inertia_ = Some(inertia);
120
121 Ok(())
122 }
123
124 fn fit_predict(
126 &mut self,
127 py: Python,
128 x: &Bound<'_, PyArray2<f64>>,
129 ) -> PyResult<Py<PyArray1<i32>>> {
130 self.fit(py, x)?;
131 self.labels(py)
132 }
133
134 fn predict(&self, py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray1<i32>>> {
136 if self.cluster_centers_.is_none() {
137 return Err(PyRuntimeError::new_err("Model not fitted yet"));
138 }
139
140 let binding = x.readonly();
141 let data = binding.as_array();
142 let centers = self.cluster_centers_.as_ref().ok_or_else(|| {
143 pyo3::exceptions::PyRuntimeError::new_err("Model not fitted: call fit() first")
144 })?;
145
146 let n_samples = data.nrows();
147 let mut labels = Vec::with_capacity(n_samples);
148
149 for sample in data.rows() {
150 let mut min_dist = f64::INFINITY;
151 let mut best_cluster = 0;
152
153 for (j, center) in centers.iter().enumerate() {
154 let dist: f64 = sample
155 .iter()
156 .zip(center.iter())
157 .map(|(a, b)| (a - b).powi(2))
158 .sum::<f64>()
159 .sqrt();
160
161 if dist < min_dist {
162 min_dist = dist;
163 best_cluster = j;
164 }
165 }
166 labels.push(best_cluster as i32);
167 }
168
169 let labels_array = Array1::from_vec(labels);
170 Ok(labels_array.into_pyarray(py).unbind())
171 }
172
173 #[getter]
175 fn cluster_centers_(&self, py: Python) -> PyResult<Option<Py<PyArray2<f64>>>> {
176 match &self.cluster_centers_ {
177 Some(centers) => {
178 let n_clusters = centers.len();
179 let n_features = centers.first().map(|c| c.len()).unwrap_or(0);
180 let flat: Vec<f64> = centers.iter().flatten().copied().collect();
181 let array = Array2::from_shape_vec((n_clusters, n_features), flat)
182 .map_err(|e| PyRuntimeError::new_err(format!("Array reshape error: {}", e)))?;
183 Ok(Some(array.into_pyarray(py).unbind()))
184 }
185 None => Ok(None),
186 }
187 }
188
189 #[getter]
191 fn labels(&self, py: Python) -> PyResult<Py<PyArray1<i32>>> {
192 match &self.labels_ {
193 Some(labels) => {
194 let labels_i32: Vec<i32> = labels.iter().map(|&x| x as i32).collect();
195 let array = Array1::from_vec(labels_i32);
196 Ok(array.into_pyarray(py).unbind())
197 }
198 None => Err(PyRuntimeError::new_err("Model not fitted yet")),
199 }
200 }
201
202 #[getter]
204 fn inertia_(&self) -> Option<f64> {
205 self.inertia_
206 }
207
208 fn set_params(&mut self, params: &Bound<'_, PyDict>) -> PyResult<()> {
210 for (key, value) in params.iter() {
211 let key_str: String = key.extract()?;
212 match key_str.as_str() {
213 "n_clusters" => self.n_clusters = value.extract()?,
214 "max_iter" => self.max_iter = value.extract()?,
215 "tol" => self.tol = value.extract()?,
216 "random_state" => self.random_state = value.extract()?,
217 "n_init" => self.n_init = value.extract()?,
218 "init" => self.init = value.extract()?,
219 _ => {
220 return Err(PyValueError::new_err(format!(
221 "Unknown parameter: {}",
222 key_str
223 )))
224 }
225 }
226 }
227 Ok(())
228 }
229
230 fn get_params(&self, py: Python, _deep: Option<bool>) -> PyResult<Py<PyAny>> {
232 let dict = PyDict::new(py);
233 dict.set_item("n_clusters", self.n_clusters)?;
234 dict.set_item("max_iter", self.max_iter)?;
235 dict.set_item("tol", self.tol)?;
236 dict.set_item("random_state", self.random_state)?;
237 dict.set_item("n_init", self.n_init)?;
238 dict.set_item("init", &self.init)?;
239 Ok(dict.into_any().unbind())
240 }
241}
242
243#[pyfunction]
245fn silhouette_score_py(
246 x: &Bound<'_, PyArray2<f64>>,
247 labels: &Bound<'_, PyArray1<i32>>,
248) -> PyResult<f64> {
249 let binding = x.readonly();
250 let data = binding.as_array();
251 let labels_binding = labels.readonly();
252 let labels_arr = labels_binding.as_array();
253
254 let score = silhouette_score(data, labels_arr)
255 .map_err(|e| PyRuntimeError::new_err(format!("Silhouette score failed: {}", e)))?;
256
257 Ok(score)
258}
259
260#[pyfunction]
262fn davies_bouldin_score_py(
263 x: &Bound<'_, PyArray2<f64>>,
264 labels: &Bound<'_, PyArray1<i32>>,
265) -> PyResult<f64> {
266 let binding = x.readonly();
267 let data = binding.as_array();
268 let labels_binding = labels.readonly();
269 let labels_arr = labels_binding.as_array();
270
271 let score = davies_bouldin_score(data, labels_arr)
272 .map_err(|e| PyRuntimeError::new_err(format!("Davies-Bouldin score failed: {}", e)))?;
273
274 Ok(score)
275}
276
277#[pyfunction]
279fn calinski_harabasz_score_py(
280 x: &Bound<'_, PyArray2<f64>>,
281 labels: &Bound<'_, PyArray1<i32>>,
282) -> PyResult<f64> {
283 let binding = x.readonly();
284 let data = binding.as_array();
285 let labels_binding = labels.readonly();
286 let labels_arr = labels_binding.as_array();
287
288 let score = calinski_harabasz_score(data, labels_arr)
289 .map_err(|e| PyRuntimeError::new_err(format!("Calinski-Harabasz score failed: {}", e)))?;
290
291 Ok(score)
292}
293
294#[pyfunction]
296fn standardize_py(py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
297 let binding = x.readonly();
298 let data = binding.as_array();
299
300 let result = standardize(data, true) .map_err(|e| PyRuntimeError::new_err(format!("Standardization failed: {}", e)))?;
302
303 Ok(result.into_pyarray(py).unbind())
304}
305
306#[pyfunction]
308fn normalize_py(
309 py: Python,
310 x: &Bound<'_, PyArray2<f64>>,
311 norm: Option<&str>,
312) -> PyResult<Py<PyArray2<f64>>> {
313 let binding = x.readonly();
314 let data = binding.as_array();
315
316 let norm_type = match norm.unwrap_or("l2") {
317 "l1" => NormType::L1,
318 "l2" => NormType::L2,
319 "max" => NormType::Max,
320 other => {
321 return Err(PyValueError::new_err(format!(
322 "Unknown norm type: {}",
323 other
324 )))
325 }
326 };
327
328 let result =
329 normalize(data, norm_type, true) .map_err(|e| PyRuntimeError::new_err(format!("Normalization failed: {}", e)))?;
331
332 Ok(result.into_pyarray(py).unbind())
333}
334
335pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
337 m.add_class::<PyKMeans>()?;
339
340 m.add_function(wrap_pyfunction!(silhouette_score_py, m)?)?;
342 m.add_function(wrap_pyfunction!(davies_bouldin_score_py, m)?)?;
343 m.add_function(wrap_pyfunction!(calinski_harabasz_score_py, m)?)?;
344
345 m.add_function(wrap_pyfunction!(standardize_py, m)?)?;
347 m.add_function(wrap_pyfunction!(normalize_py, m)?)?;
348
349 Ok(())
350}