1use pyo3::exceptions::{PyRuntimeError, PyValueError};
7use pyo3::prelude::*;
8use pyo3::types::PyDict;
9
10use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13use scirs2_core::{Array1, Array2};
15
16use scirs2_cluster::kmeans;
18use scirs2_cluster::{calinski_harabasz_score, davies_bouldin_score, silhouette_score};
19use scirs2_cluster::{normalize, standardize, NormType};
20
/// K-Means clustering estimator, exposed to Python under the name `KMeans`.
///
/// The first group of fields holds the hyper-parameters supplied at
/// construction (or later through `set_params`). The trailing-underscore
/// fields hold the results of the most recent successful `fit` call and are
/// `None` until then.
#[pyclass(name = "KMeans")]
pub struct PyKMeans {
    /// Number of clusters requested from `kmeans`.
    n_clusters: usize,
    /// Iteration limit forwarded to `kmeans` (wrapped in `Some`).
    max_iter: usize,
    /// Tolerance forwarded to `kmeans` (wrapped in `Some`).
    tol: f64,
    /// Optional seed forwarded to `kmeans` unchanged.
    random_state: Option<u64>,
    /// Stored and reported by `get_params`; `fit` does not read it.
    n_init: usize,
    /// Initialisation method name. Stored and reported by `get_params`;
    /// `fit` does not read it.
    init: String,
    /// Fitted centroids, one inner `Vec` per centroid row.
    cluster_centers_: Option<Vec<Vec<f64>>>,
    /// Index of the nearest centroid for each training sample.
    labels_: Option<Vec<usize>>,
    /// Second value returned by `kmeans` for the fitted model.
    inertia_: Option<f64>,
}
43
44#[pymethods]
45impl PyKMeans {
46 #[new]
48 #[pyo3(signature = (n_clusters=8, *, init="k-means++", n_init=10, max_iter=300, tol=1e-4, random_state=None))]
49 fn new(
50 n_clusters: usize,
51 init: &str,
52 n_init: usize,
53 max_iter: usize,
54 tol: f64,
55 random_state: Option<u64>,
56 ) -> Self {
57 Self {
58 n_clusters,
59 max_iter,
60 tol,
61 random_state,
62 n_init,
63 init: init.to_string(),
64 cluster_centers_: None,
65 labels_: None,
66 inertia_: None,
67 }
68 }
69
70 fn fit(&mut self, _py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<()> {
72 let binding = x.readonly();
73 let data = binding.as_array();
74
75 let (centroids, inertia) = kmeans(
77 data,
78 self.n_clusters,
79 Some(self.max_iter),
80 Some(self.tol),
81 Some(true), self.random_state,
83 )
84 .map_err(|e| PyRuntimeError::new_err(format!("K-means fitting failed: {}", e)))?;
85
86 let n_samples = data.nrows();
88 let mut labels = Vec::with_capacity(n_samples);
89
90 for sample in data.rows() {
91 let mut min_dist = f64::INFINITY;
92 let mut best_cluster = 0;
93
94 for (j, centroid) in centroids.rows().into_iter().enumerate() {
95 let dist: f64 = sample
96 .iter()
97 .zip(centroid.iter())
98 .map(|(a, b)| (a - b).powi(2))
99 .sum::<f64>()
100 .sqrt();
101
102 if dist < min_dist {
103 min_dist = dist;
104 best_cluster = j;
105 }
106 }
107 labels.push(best_cluster);
108 }
109
110 self.cluster_centers_ = Some(
112 centroids
113 .rows()
114 .into_iter()
115 .map(|row| row.to_vec())
116 .collect(),
117 );
118 self.labels_ = Some(labels);
119 self.inertia_ = Some(inertia);
120
121 Ok(())
122 }
123
124 fn fit_predict(
126 &mut self,
127 py: Python,
128 x: &Bound<'_, PyArray2<f64>>,
129 ) -> PyResult<Py<PyArray1<i32>>> {
130 self.fit(py, x)?;
131 self.labels(py)
132 }
133
134 fn predict(&self, py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray1<i32>>> {
136 if self.cluster_centers_.is_none() {
137 return Err(PyRuntimeError::new_err("Model not fitted yet"));
138 }
139
140 let binding = x.readonly();
141 let data = binding.as_array();
142 let centers = self.cluster_centers_.as_ref().expect("Operation failed");
143
144 let n_samples = data.nrows();
145 let mut labels = Vec::with_capacity(n_samples);
146
147 for sample in data.rows() {
148 let mut min_dist = f64::INFINITY;
149 let mut best_cluster = 0;
150
151 for (j, center) in centers.iter().enumerate() {
152 let dist: f64 = sample
153 .iter()
154 .zip(center.iter())
155 .map(|(a, b)| (a - b).powi(2))
156 .sum::<f64>()
157 .sqrt();
158
159 if dist < min_dist {
160 min_dist = dist;
161 best_cluster = j;
162 }
163 }
164 labels.push(best_cluster as i32);
165 }
166
167 let labels_array = Array1::from_vec(labels);
168 Ok(labels_array.into_pyarray(py).unbind())
169 }
170
171 #[getter]
173 fn cluster_centers_(&self, py: Python) -> PyResult<Option<Py<PyArray2<f64>>>> {
174 match &self.cluster_centers_ {
175 Some(centers) => {
176 let n_clusters = centers.len();
177 let n_features = centers.first().map(|c| c.len()).unwrap_or(0);
178 let flat: Vec<f64> = centers.iter().flatten().copied().collect();
179 let array = Array2::from_shape_vec((n_clusters, n_features), flat)
180 .map_err(|e| PyRuntimeError::new_err(format!("Array reshape error: {}", e)))?;
181 Ok(Some(array.into_pyarray(py).unbind()))
182 }
183 None => Ok(None),
184 }
185 }
186
187 #[getter]
189 fn labels(&self, py: Python) -> PyResult<Py<PyArray1<i32>>> {
190 match &self.labels_ {
191 Some(labels) => {
192 let labels_i32: Vec<i32> = labels.iter().map(|&x| x as i32).collect();
193 let array = Array1::from_vec(labels_i32);
194 Ok(array.into_pyarray(py).unbind())
195 }
196 None => Err(PyRuntimeError::new_err("Model not fitted yet")),
197 }
198 }
199
200 #[getter]
202 fn inertia_(&self) -> Option<f64> {
203 self.inertia_
204 }
205
206 fn set_params(&mut self, params: &Bound<'_, PyDict>) -> PyResult<()> {
208 for (key, value) in params.iter() {
209 let key_str: String = key.extract()?;
210 match key_str.as_str() {
211 "n_clusters" => self.n_clusters = value.extract()?,
212 "max_iter" => self.max_iter = value.extract()?,
213 "tol" => self.tol = value.extract()?,
214 "random_state" => self.random_state = value.extract()?,
215 "n_init" => self.n_init = value.extract()?,
216 "init" => self.init = value.extract()?,
217 _ => {
218 return Err(PyValueError::new_err(format!(
219 "Unknown parameter: {}",
220 key_str
221 )))
222 }
223 }
224 }
225 Ok(())
226 }
227
228 fn get_params(&self, py: Python, _deep: Option<bool>) -> PyResult<Py<PyAny>> {
230 let dict = PyDict::new(py);
231 dict.set_item("n_clusters", self.n_clusters)?;
232 dict.set_item("max_iter", self.max_iter)?;
233 dict.set_item("tol", self.tol)?;
234 dict.set_item("random_state", self.random_state)?;
235 dict.set_item("n_init", self.n_init)?;
236 dict.set_item("init", &self.init)?;
237 Ok(dict.into_any().unbind())
238 }
239}
240
241#[pyfunction]
243fn silhouette_score_py(
244 x: &Bound<'_, PyArray2<f64>>,
245 labels: &Bound<'_, PyArray1<i32>>,
246) -> PyResult<f64> {
247 let binding = x.readonly();
248 let data = binding.as_array();
249 let labels_binding = labels.readonly();
250 let labels_arr = labels_binding.as_array();
251
252 let score = silhouette_score(data, labels_arr)
253 .map_err(|e| PyRuntimeError::new_err(format!("Silhouette score failed: {}", e)))?;
254
255 Ok(score)
256}
257
258#[pyfunction]
260fn davies_bouldin_score_py(
261 x: &Bound<'_, PyArray2<f64>>,
262 labels: &Bound<'_, PyArray1<i32>>,
263) -> PyResult<f64> {
264 let binding = x.readonly();
265 let data = binding.as_array();
266 let labels_binding = labels.readonly();
267 let labels_arr = labels_binding.as_array();
268
269 let score = davies_bouldin_score(data, labels_arr)
270 .map_err(|e| PyRuntimeError::new_err(format!("Davies-Bouldin score failed: {}", e)))?;
271
272 Ok(score)
273}
274
275#[pyfunction]
277fn calinski_harabasz_score_py(
278 x: &Bound<'_, PyArray2<f64>>,
279 labels: &Bound<'_, PyArray1<i32>>,
280) -> PyResult<f64> {
281 let binding = x.readonly();
282 let data = binding.as_array();
283 let labels_binding = labels.readonly();
284 let labels_arr = labels_binding.as_array();
285
286 let score = calinski_harabasz_score(data, labels_arr)
287 .map_err(|e| PyRuntimeError::new_err(format!("Calinski-Harabasz score failed: {}", e)))?;
288
289 Ok(score)
290}
291
292#[pyfunction]
294fn standardize_py(py: Python, x: &Bound<'_, PyArray2<f64>>) -> PyResult<Py<PyArray2<f64>>> {
295 let binding = x.readonly();
296 let data = binding.as_array();
297
298 let result = standardize(data, true) .map_err(|e| PyRuntimeError::new_err(format!("Standardization failed: {}", e)))?;
300
301 Ok(result.into_pyarray(py).unbind())
302}
303
304#[pyfunction]
306fn normalize_py(
307 py: Python,
308 x: &Bound<'_, PyArray2<f64>>,
309 norm: Option<&str>,
310) -> PyResult<Py<PyArray2<f64>>> {
311 let binding = x.readonly();
312 let data = binding.as_array();
313
314 let norm_type = match norm.unwrap_or("l2") {
315 "l1" => NormType::L1,
316 "l2" => NormType::L2,
317 "max" => NormType::Max,
318 other => {
319 return Err(PyValueError::new_err(format!(
320 "Unknown norm type: {}",
321 other
322 )))
323 }
324 };
325
326 let result =
327 normalize(data, norm_type, true) .map_err(|e| PyRuntimeError::new_err(format!("Normalization failed: {}", e)))?;
329
330 Ok(result.into_pyarray(py).unbind())
331}
332
333pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
335 m.add_class::<PyKMeans>()?;
337
338 m.add_function(wrap_pyfunction!(silhouette_score_py, m)?)?;
340 m.add_function(wrap_pyfunction!(davies_bouldin_score_py, m)?)?;
341 m.add_function(wrap_pyfunction!(calinski_harabasz_score_py, m)?)?;
342
343 m.add_function(wrap_pyfunction!(standardize_py, m)?)?;
345 m.add_function(wrap_pyfunction!(normalize_py, m)?)?;
346
347 Ok(())
348}