Skip to main content

scirs2/
spatial.rs

1//! Python bindings for scirs2-spatial
2//!
3//! Provides spatial algorithms similar to scipy.spatial
4
5use pyo3::prelude::*;
6use pyo3::types::PyDict;
7use scirs2_core::ndarray::Array2 as Array2_17;
8use scirs2_core::python::numpy_compat::{
9    scirs_to_numpy_array1, scirs_to_numpy_array2, Array1, Array2,
10};
11use scirs2_numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2};
12
13// Import KDTree
14use scirs2_spatial::distance::EuclideanDistance;
15use scirs2_spatial::KDTree;
16
17// Import ConvexHull
18use scirs2_spatial::convex_hull::ConvexHull;
19
20// =============================================================================
21// Distance Functions
22// =============================================================================
23
24/// Euclidean distance between two points
25#[pyfunction]
26fn euclidean_py(u: PyReadonlyArray1<f64>, v: PyReadonlyArray1<f64>) -> PyResult<f64> {
27    let u_arr = u.as_array();
28    let v_arr = v.as_array();
29
30    if u_arr.len() != v_arr.len() {
31        return Err(pyo3::exceptions::PyValueError::new_err(
32            "Arrays must have same length",
33        ));
34    }
35
36    let dist: f64 = u_arr
37        .iter()
38        .zip(v_arr.iter())
39        .map(|(a, b)| (a - b).powi(2))
40        .sum::<f64>()
41        .sqrt();
42
43    Ok(dist)
44}
45
46/// Manhattan (city block) distance between two points
47#[pyfunction]
48fn cityblock_py(u: PyReadonlyArray1<f64>, v: PyReadonlyArray1<f64>) -> PyResult<f64> {
49    let u_arr = u.as_array();
50    let v_arr = v.as_array();
51
52    if u_arr.len() != v_arr.len() {
53        return Err(pyo3::exceptions::PyValueError::new_err(
54            "Arrays must have same length",
55        ));
56    }
57
58    let dist: f64 = u_arr
59        .iter()
60        .zip(v_arr.iter())
61        .map(|(a, b)| (a - b).abs())
62        .sum();
63
64    Ok(dist)
65}
66
67/// Chebyshev distance between two points
68#[pyfunction]
69fn chebyshev_py(u: PyReadonlyArray1<f64>, v: PyReadonlyArray1<f64>) -> PyResult<f64> {
70    let u_arr = u.as_array();
71    let v_arr = v.as_array();
72
73    if u_arr.len() != v_arr.len() {
74        return Err(pyo3::exceptions::PyValueError::new_err(
75            "Arrays must have same length",
76        ));
77    }
78
79    let dist: f64 = u_arr
80        .iter()
81        .zip(v_arr.iter())
82        .map(|(a, b)| (a - b).abs())
83        .fold(0.0, f64::max);
84
85    Ok(dist)
86}
87
88/// Minkowski distance between two points
89#[pyfunction]
90fn minkowski_py(u: PyReadonlyArray1<f64>, v: PyReadonlyArray1<f64>, p: f64) -> PyResult<f64> {
91    let u_arr = u.as_array();
92    let v_arr = v.as_array();
93
94    if u_arr.len() != v_arr.len() {
95        return Err(pyo3::exceptions::PyValueError::new_err(
96            "Arrays must have same length",
97        ));
98    }
99
100    let dist: f64 = u_arr
101        .iter()
102        .zip(v_arr.iter())
103        .map(|(a, b)| (a - b).abs().powf(p))
104        .sum::<f64>()
105        .powf(1.0 / p);
106
107    Ok(dist)
108}
109
110/// Cosine distance between two points
111#[pyfunction]
112fn cosine_py(u: PyReadonlyArray1<f64>, v: PyReadonlyArray1<f64>) -> PyResult<f64> {
113    let u_arr = u.as_array();
114    let v_arr = v.as_array();
115
116    if u_arr.len() != v_arr.len() {
117        return Err(pyo3::exceptions::PyValueError::new_err(
118            "Arrays must have same length",
119        ));
120    }
121
122    let dot: f64 = u_arr.iter().zip(v_arr.iter()).map(|(a, b)| a * b).sum();
123    let norm_u: f64 = u_arr.iter().map(|a| a.powi(2)).sum::<f64>().sqrt();
124    let norm_v: f64 = v_arr.iter().map(|a| a.powi(2)).sum::<f64>().sqrt();
125
126    if norm_u == 0.0 || norm_v == 0.0 {
127        return Err(pyo3::exceptions::PyValueError::new_err("Zero vector"));
128    }
129
130    Ok(1.0 - dot / (norm_u * norm_v))
131}
132
133// =============================================================================
134// Pairwise Distance Matrix
135// =============================================================================
136
137/// Compute pairwise distances between observations
138#[pyfunction]
139#[pyo3(signature = (x, metric="euclidean"))]
140fn pdist_py(py: Python, x: PyReadonlyArray2<f64>, metric: &str) -> PyResult<Py<PyArray1<f64>>> {
141    let x_arr = x.as_array();
142    let n = x_arr.nrows();
143
144    // Number of pairwise distances
145    let n_dist = n * (n - 1) / 2;
146    let mut result = Vec::with_capacity(n_dist);
147
148    for i in 0..n {
149        for j in (i + 1)..n {
150            let dist = match metric {
151                "euclidean" => x_arr
152                    .row(i)
153                    .iter()
154                    .zip(x_arr.row(j).iter())
155                    .map(|(a, b)| (a - b).powi(2))
156                    .sum::<f64>()
157                    .sqrt(),
158                "cityblock" | "manhattan" => x_arr
159                    .row(i)
160                    .iter()
161                    .zip(x_arr.row(j).iter())
162                    .map(|(a, b)| (a - b).abs())
163                    .sum(),
164                "chebyshev" => x_arr
165                    .row(i)
166                    .iter()
167                    .zip(x_arr.row(j).iter())
168                    .map(|(a, b)| (a - b).abs())
169                    .fold(0.0, f64::max),
170                _ => x_arr
171                    .row(i)
172                    .iter()
173                    .zip(x_arr.row(j).iter())
174                    .map(|(a, b)| (a - b).powi(2))
175                    .sum::<f64>()
176                    .sqrt(),
177            };
178            result.push(dist);
179        }
180    }
181
182    scirs_to_numpy_array1(Array1::from_vec(result), py)
183}
184
185/// Compute pairwise distances between two sets of observations
186#[pyfunction]
187#[pyo3(signature = (xa, xb, metric="euclidean"))]
188fn cdist_py(
189    py: Python,
190    xa: PyReadonlyArray2<f64>,
191    xb: PyReadonlyArray2<f64>,
192    metric: &str,
193) -> PyResult<Py<PyArray2<f64>>> {
194    let xa_arr = xa.as_array();
195    let xb_arr = xb.as_array();
196    let na = xa_arr.nrows();
197    let nb = xb_arr.nrows();
198
199    if xa_arr.ncols() != xb_arr.ncols() {
200        return Err(pyo3::exceptions::PyValueError::new_err(
201            "Arrays must have same number of columns",
202        ));
203    }
204
205    let mut result = Vec::with_capacity(na * nb);
206
207    for i in 0..na {
208        for j in 0..nb {
209            let dist = match metric {
210                "euclidean" => xa_arr
211                    .row(i)
212                    .iter()
213                    .zip(xb_arr.row(j).iter())
214                    .map(|(a, b)| (a - b).powi(2))
215                    .sum::<f64>()
216                    .sqrt(),
217                "cityblock" | "manhattan" => xa_arr
218                    .row(i)
219                    .iter()
220                    .zip(xb_arr.row(j).iter())
221                    .map(|(a, b)| (a - b).abs())
222                    .sum(),
223                "chebyshev" => xa_arr
224                    .row(i)
225                    .iter()
226                    .zip(xb_arr.row(j).iter())
227                    .map(|(a, b)| (a - b).abs())
228                    .fold(0.0, f64::max),
229                _ => xa_arr
230                    .row(i)
231                    .iter()
232                    .zip(xb_arr.row(j).iter())
233                    .map(|(a, b)| (a - b).powi(2))
234                    .sum::<f64>()
235                    .sqrt(),
236            };
237            result.push(dist);
238        }
239    }
240
241    // Reshape to 2D array
242    let arr = Array2::from_shape_vec((na, nb), result)
243        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
244
245    scirs_to_numpy_array2(arr, py)
246}
247
248/// Convert condensed distance matrix to square form
249#[pyfunction]
250fn squareform_py(py: Python, x: PyReadonlyArray1<f64>) -> PyResult<Py<PyArray2<f64>>> {
251    let x_arr = x.as_array();
252    let n_dist = x_arr.len();
253
254    // Solve n*(n-1)/2 = n_dist for n
255    let n = ((1.0 + (1.0 + 8.0 * n_dist as f64).sqrt()) / 2.0) as usize;
256
257    let mut result = Array2::zeros((n, n));
258
259    let mut idx = 0;
260    for i in 0..n {
261        for j in (i + 1)..n {
262            result[[i, j]] = x_arr[idx];
263            result[[j, i]] = x_arr[idx];
264            idx += 1;
265        }
266    }
267
268    scirs_to_numpy_array2(result, py)
269}
270
271// =============================================================================
272// Convex Hull
273// =============================================================================
274
275/// Compute the convex hull of a set of points
276///
277/// Returns indices of points that form the convex hull vertices
278#[pyfunction]
279fn convex_hull_py(py: Python, points: PyReadonlyArray2<f64>) -> PyResult<Py<PyAny>> {
280    let points_arr = points.as_array();
281    let n = points_arr.nrows();
282    let k = points_arr.ncols();
283
284    // Convert to Array2_17
285    let mut pts = Vec::with_capacity(n * k);
286    for row in points_arr.rows() {
287        for &val in row.iter() {
288            pts.push(val);
289        }
290    }
291    let arr = Array2_17::from_shape_vec((n, k), pts)
292        .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
293
294    let hull = ConvexHull::new(&arr.view())
295        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
296
297    // Get vertices and simplices
298    let vertices: Vec<i64> = hull.vertex_indices().iter().map(|&i| i as i64).collect();
299    let simplices: Vec<Vec<i64>> = hull
300        .simplices()
301        .iter()
302        .map(|s| s.iter().map(|&i| i as i64).collect())
303        .collect();
304
305    // Calculate volume and area
306    let volume = hull.volume().unwrap_or(0.0);
307    let area = hull.area().unwrap_or(0.0);
308
309    let dict = PyDict::new(py);
310    dict.set_item(
311        "vertices",
312        scirs_to_numpy_array1(Array1::from_vec(vertices), py)?,
313    )?;
314
315    // Convert simplices to a flat representation for Python
316    let simplices_py: Vec<Vec<i64>> = simplices;
317    dict.set_item("simplices", simplices_py)?;
318    dict.set_item("volume", volume)?;
319    dict.set_item("area", area)?;
320
321    Ok(dict.into())
322}
323
324/// ConvexHull class for working with convex hulls
325#[pyclass(name = "ConvexHullPy", unsendable)]
326pub struct PyConvexHull {
327    hull: ConvexHull,
328}
329
330#[pymethods]
331impl PyConvexHull {
332    /// Create a new ConvexHull from a 2D array of points
333    ///
334    /// Parameters:
335    /// - points: Array of shape (n, k) containing n points in k dimensions
336    #[new]
337    fn new(points: PyReadonlyArray2<f64>) -> PyResult<Self> {
338        let points_arr = points.as_array();
339        let n = points_arr.nrows();
340        let k = points_arr.ncols();
341
342        // Convert to Array2_17
343        let mut pts = Vec::with_capacity(n * k);
344        for row in points_arr.rows() {
345            for &val in row.iter() {
346                pts.push(val);
347            }
348        }
349        let arr = Array2_17::from_shape_vec((n, k), pts)
350            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
351
352        let hull = ConvexHull::new(&arr.view())
353            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
354
355        Ok(PyConvexHull { hull })
356    }
357
358    /// Get the indices of vertices that form the convex hull
359    fn vertices(&self, py: Python) -> PyResult<Py<PyArray1<i64>>> {
360        let vertices: Vec<i64> = self
361            .hull
362            .vertex_indices()
363            .iter()
364            .map(|&i| i as i64)
365            .collect();
366        scirs_to_numpy_array1(Array1::from_vec(vertices), py)
367    }
368
369    /// Get the simplices (facets) of the convex hull
370    fn simplices(&self) -> Vec<Vec<i64>> {
371        self.hull
372            .simplices()
373            .iter()
374            .map(|s| s.iter().map(|&i| i as i64).collect())
375            .collect()
376    }
377
378    /// Calculate the volume of the convex hull
379    fn volume(&self) -> PyResult<f64> {
380        self.hull
381            .volume()
382            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
383    }
384
385    /// Calculate the surface area of the convex hull
386    fn area(&self) -> PyResult<f64> {
387        self.hull
388            .area()
389            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
390    }
391
392    /// Check if a point is inside the convex hull
393    fn contains(&self, point: PyReadonlyArray1<f64>) -> PyResult<bool> {
394        let point_vec: Vec<f64> = point.as_array().to_vec();
395        self.hull
396            .contains(&point_vec)
397            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
398    }
399}
400
401// =============================================================================
402// KD-Tree
403// =============================================================================
404
405/// KD-Tree for efficient nearest neighbor searches
406#[pyclass(name = "KDTree")]
407pub struct PyKDTree {
408    tree: KDTree<f64, EuclideanDistance<f64>>,
409}
410
411#[pymethods]
412impl PyKDTree {
413    /// Create a new KD-Tree from a 2D array of points
414    ///
415    /// Parameters:
416    /// - data: Array of shape (n, k) containing n points in k dimensions
417    #[new]
418    fn new(data: PyReadonlyArray2<f64>) -> PyResult<Self> {
419        let data_arr = data.as_array();
420        let n = data_arr.nrows();
421        let k = data_arr.ncols();
422
423        // Convert to Array2_17
424        let mut points = Vec::with_capacity(n * k);
425        for row in data_arr.rows() {
426            for &val in row.iter() {
427                points.push(val);
428            }
429        }
430        let arr = Array2_17::from_shape_vec((n, k), points)
431            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
432
433        let tree = KDTree::new(&arr)
434            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
435
436        Ok(PyKDTree { tree })
437    }
438
439    /// Query the tree for the k nearest neighbors to a point
440    ///
441    /// Parameters:
442    /// - point: Query point
443    /// - k: Number of nearest neighbors to find
444    ///
445    /// Returns:
446    /// - Tuple of (indices, distances) arrays
447    fn query(&self, py: Python, point: PyReadonlyArray1<f64>, k: usize) -> PyResult<Py<PyAny>> {
448        let point_vec: Vec<f64> = point.as_array().to_vec();
449
450        let (indices, distances) = self
451            .tree
452            .query(&point_vec, k)
453            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
454
455        let dict = PyDict::new(py);
456        dict.set_item(
457            "indices",
458            scirs_to_numpy_array1(
459                Array1::from_vec(indices.iter().map(|&i| i as i64).collect()),
460                py,
461            )?,
462        )?;
463        dict.set_item(
464            "distances",
465            scirs_to_numpy_array1(Array1::from_vec(distances), py)?,
466        )?;
467
468        Ok(dict.into())
469    }
470
471    /// Query the tree for all points within a given radius
472    ///
473    /// Parameters:
474    /// - point: Query point
475    /// - r: Radius
476    ///
477    /// Returns:
478    /// - Tuple of (indices, distances) arrays
479    fn query_radius(
480        &self,
481        py: Python,
482        point: PyReadonlyArray1<f64>,
483        r: f64,
484    ) -> PyResult<Py<PyAny>> {
485        let point_vec: Vec<f64> = point.as_array().to_vec();
486
487        let (indices, distances) = self
488            .tree
489            .query_radius(&point_vec, r)
490            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
491
492        let dict = PyDict::new(py);
493        dict.set_item(
494            "indices",
495            scirs_to_numpy_array1(
496                Array1::from_vec(indices.iter().map(|&i| i as i64).collect()),
497                py,
498            )?,
499        )?;
500        dict.set_item(
501            "distances",
502            scirs_to_numpy_array1(Array1::from_vec(distances), py)?,
503        )?;
504
505        Ok(dict.into())
506    }
507}
508
509/// Python module registration
510pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
511    // Distance functions
512    m.add_function(wrap_pyfunction!(euclidean_py, m)?)?;
513    m.add_function(wrap_pyfunction!(cityblock_py, m)?)?;
514    m.add_function(wrap_pyfunction!(chebyshev_py, m)?)?;
515    m.add_function(wrap_pyfunction!(minkowski_py, m)?)?;
516    m.add_function(wrap_pyfunction!(cosine_py, m)?)?;
517
518    // Pairwise distances
519    m.add_function(wrap_pyfunction!(pdist_py, m)?)?;
520    m.add_function(wrap_pyfunction!(cdist_py, m)?)?;
521    m.add_function(wrap_pyfunction!(squareform_py, m)?)?;
522
523    // Convex hull
524    m.add_function(wrap_pyfunction!(convex_hull_py, m)?)?;
525    m.add_class::<PyConvexHull>()?;
526
527    // Spatial data structures
528    m.add_class::<PyKDTree>()?;
529
530    Ok(())
531}