Skip to main content

sci_form/ani/
neighbor.rs

//! Cell-list spatial partitioning for efficient neighbor search.
//!
//! Bins atoms into a 3D grid of cubic cells whose side length is no smaller than
//! the cutoff radius, so any pair closer than the cutoff lies in the same or an
//! adjacent cell. Only those cells need distance evaluation, giving O(N)
//! scaling for uniformly distributed systems.
6
/// A neighbor pair found by [`CellList::find_neighbors`].
///
/// `find_neighbors` always reports `i < j`, so each unordered pair appears once.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NeighborPair {
    /// Index of the first atom (the smaller of the two indices).
    pub i: usize,
    /// Index of the second atom (the larger of the two indices).
    pub j: usize,
    /// Squared Euclidean distance between the two atoms (no square root taken).
    pub dist_sq: f64,
}
14
/// Spatial cell list for fast cutoff-radius neighbor search.
///
/// Build it once from a set of positions with [`CellList::new`], then query it with
/// [`CellList::find_neighbors`], passing the same positions.
#[derive(Debug, Clone)]
pub struct CellList {
    /// Cutoff radius passed to the constructor.
    pub cutoff: f64,
    /// Squared-distance threshold: a pair is reported when `dist_sq < cutoff_sq`
    /// (comparing squares avoids a square root per pair).
    cutoff_sq: f64,
    /// Side length of every (cubic) cell, as chosen by the constructor.
    _cell_size: f64,
    /// Number of cells along the x, y and z axes.
    n_cells: [usize; 3],
    /// Minimum corner of the padded bounding box; cell coordinates are measured from it.
    _origin: [f64; 3],
    /// Atom indices binned into each cell, addressed by the flat index
    /// `cx * n_cells[1] * n_cells[2] + cy * n_cells[2] + cz` (z varies fastest).
    cells: Vec<Vec<usize>>,
}
30
31impl CellList {
32    /// Build a cell list from atom positions.
33    ///
34    /// `positions` is a slice of `[x, y, z]` arrays.
35    pub fn new(positions: &[[f64; 3]], cutoff: f64) -> Self {
36        let padding = 0.01;
37        let mut min = [f64::INFINITY; 3];
38        let mut max = [f64::NEG_INFINITY; 3];
39
40        for pos in positions {
41            for d in 0..3 {
42                if pos[d] < min[d] {
43                    min[d] = pos[d];
44                }
45                if pos[d] > max[d] {
46                    max[d] = pos[d];
47                }
48            }
49        }
50
51        let origin = [min[0] - padding, min[1] - padding, min[2] - padding];
52        let cell_size = cutoff;
53        let n_cells = [
54            ((max[0] - origin[0] + padding) / cell_size).ceil().max(1.0) as usize,
55            ((max[1] - origin[1] + padding) / cell_size).ceil().max(1.0) as usize,
56            ((max[2] - origin[2] + padding) / cell_size).ceil().max(1.0) as usize,
57        ];
58
59        let total = n_cells[0] * n_cells[1] * n_cells[2];
60        let mut cells = vec![Vec::new(); total];
61
62        for (idx, pos) in positions.iter().enumerate() {
63            let ci = Self::cell_index_static(pos, &origin, cell_size, &n_cells);
64            cells[ci].push(idx);
65        }
66
67        CellList {
68            cutoff,
69            cutoff_sq: cutoff * cutoff,
70            _cell_size: cell_size,
71            n_cells,
72            _origin: origin,
73            cells,
74        }
75    }
76
77    fn cell_index_static(
78        pos: &[f64; 3],
79        origin: &[f64; 3],
80        cell_size: f64,
81        n_cells: &[usize; 3],
82    ) -> usize {
83        let cx = ((pos[0] - origin[0]) / cell_size) as usize;
84        let cy = ((pos[1] - origin[1]) / cell_size) as usize;
85        let cz = ((pos[2] - origin[2]) / cell_size) as usize;
86        let cx = cx.min(n_cells[0] - 1);
87        let cy = cy.min(n_cells[1] - 1);
88        let cz = cz.min(n_cells[2] - 1);
89        cx * n_cells[1] * n_cells[2] + cy * n_cells[2] + cz
90    }
91
92    /// Find all neighbor pairs within the cutoff radius.
93    ///
94    /// Returns pairs with i < j to avoid duplicates.
95    pub fn find_neighbors(&self, positions: &[[f64; 3]]) -> Vec<NeighborPair> {
96        let mut pairs = Vec::new();
97        let nc = &self.n_cells;
98
99        for cx in 0..nc[0] {
100            for cy in 0..nc[1] {
101                for cz in 0..nc[2] {
102                    let ci = cx * nc[1] * nc[2] + cy * nc[2] + cz;
103                    self.pairs_within_cell(&self.cells[ci], positions, &mut pairs);
104
105                    // Check 13 forward-neighbors (half of 26 neighbors)
106                    for &(dx, dy, dz) in &HALF_NEIGHBOR_OFFSETS {
107                        let nx = cx as isize + dx;
108                        let ny = cy as isize + dy;
109                        let nz = cz as isize + dz;
110                        if nx < 0 || ny < 0 || nz < 0 {
111                            continue;
112                        }
113                        let (nx, ny, nz) = (nx as usize, ny as usize, nz as usize);
114                        if nx >= nc[0] || ny >= nc[1] || nz >= nc[2] {
115                            continue;
116                        }
117                        let ni = nx * nc[1] * nc[2] + ny * nc[2] + nz;
118                        self.pairs_between_cells(
119                            &self.cells[ci],
120                            &self.cells[ni],
121                            positions,
122                            &mut pairs,
123                        );
124                    }
125                }
126            }
127        }
128        pairs
129    }
130
131    fn pairs_within_cell(
132        &self,
133        cell: &[usize],
134        positions: &[[f64; 3]],
135        pairs: &mut Vec<NeighborPair>,
136    ) {
137        for a in 0..cell.len() {
138            for b in (a + 1)..cell.len() {
139                let i = cell[a];
140                let j = cell[b];
141                let dsq = dist_sq(&positions[i], &positions[j]);
142                if dsq < self.cutoff_sq {
143                    pairs.push(NeighborPair { i, j, dist_sq: dsq });
144                }
145            }
146        }
147    }
148
149    fn pairs_between_cells(
150        &self,
151        cell_a: &[usize],
152        cell_b: &[usize],
153        positions: &[[f64; 3]],
154        pairs: &mut Vec<NeighborPair>,
155    ) {
156        for &i in cell_a {
157            for &j in cell_b {
158                let dsq = dist_sq(&positions[i], &positions[j]);
159                if dsq < self.cutoff_sq {
160                    let (lo, hi) = if i < j { (i, j) } else { (j, i) };
161                    pairs.push(NeighborPair {
162                        i: lo,
163                        j: hi,
164                        dist_sq: dsq,
165                    });
166                }
167            }
168        }
169    }
170}
171
/// Squared Euclidean distance between two 3D points (no square root taken).
#[inline]
fn dist_sq(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let axis_sq = |k: usize| {
        let delta = a[k] - b[k];
        delta * delta
    };
    axis_sq(0) + axis_sq(1) + axis_sq(2)
}
179
/// Offsets of the 13 "forward" neighbor cells: half of the 26 cells around a cell.
///
/// In every offset the first non-zero component is `+1`, and the mirrored offset
/// never appears. So for each pair of adjacent cells exactly one of them reaches
/// the other through this table, which prevents counting a pair twice.
const HALF_NEIGHBOR_OFFSETS: [(isize, isize, isize); 13] = [
    // Face-sharing neighbors.
    (1, 0, 0),
    (0, 1, 0),
    (0, 0, 1),
    // Edge-sharing neighbors.
    (1, 1, 0),
    (1, -1, 0),
    (1, 0, 1),
    (1, 0, -1),
    (0, 1, 1),
    (0, 1, -1),
    // Corner-sharing neighbors.
    (1, 1, 1),
    (1, 1, -1),
    (1, -1, 1),
    (1, -1, -1),
];
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random positions in a 12 Å cube. A simple LCG is used,
    /// so the tests need no RNG crate.
    fn scattered_positions(n: usize) -> Vec<[f64; 3]> {
        let mut state: u64 = 0x2545_F491_4F6C_DD1D;
        let mut next = move || {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            // Top 53 bits -> uniform value in [0, 1).
            (state >> 11) as f64 / (1u64 << 53) as f64
        };
        (0..n)
            .map(|_| [next() * 12.0, next() * 12.0, next() * 12.0])
            .collect()
    }

    /// O(N²) reference: all `(i, j)` with `i < j` and squared distance below `cutoff²`.
    fn brute_force_pairs(positions: &[[f64; 3]], cutoff: f64) -> Vec<(usize, usize)> {
        let cutoff_sq = cutoff * cutoff;
        let mut out = Vec::new();
        for i in 0..positions.len() {
            for j in (i + 1)..positions.len() {
                if dist_sq(&positions[i], &positions[j]) < cutoff_sq {
                    out.push((i, j));
                }
            }
        }
        out
    }

    /// Run the cell list and return its pairs sorted as `(i, j)`. Also checks
    /// ordering, the reported distance, and the absence of duplicates.
    fn cell_list_pairs(positions: &[[f64; 3]], cutoff: f64) -> Vec<(usize, usize)> {
        let cl = CellList::new(positions, cutoff);
        let mut found: Vec<(usize, usize)> = cl
            .find_neighbors(positions)
            .into_iter()
            .map(|p| {
                assert!(p.i < p.j, "pair ({}, {}) is not ordered", p.i, p.j);
                assert_eq!(p.dist_sq, dist_sq(&positions[p.i], &positions[p.j]));
                (p.i, p.j)
            })
            .collect();
        found.sort_unstable();
        let reported = found.len();
        found.dedup();
        assert_eq!(found.len(), reported, "duplicate pairs reported");
        found
    }

    #[test]
    fn test_water_neighbors() {
        // Water geometry: O at origin, two H at ~0.96 Å
        let positions = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
        let cl = CellList::new(&positions, 5.2);
        let pairs = cl.find_neighbors(&positions);
        // All 3 pairs should be within 5.2 Å
        assert_eq!(pairs.len(), 3, "Water should have 3 pairs");
    }

    #[test]
    fn test_distant_atoms_excluded() {
        let positions = [[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]];
        let cl = CellList::new(&positions, 5.2);
        let pairs = cl.find_neighbors(&positions);
        assert_eq!(pairs.len(), 0, "Atoms 10 Å apart should not be neighbors");
    }

    #[test]
    fn test_matches_brute_force_across_cells() {
        // 200 atoms in a 12 Å cube with a 2.5 Å cutoff spans many cells, so pairs
        // must be found through the forward-neighbor offsets, not just within a cell.
        let positions = scattered_positions(200);
        let pairs = cell_list_pairs(&positions, 2.5);
        assert!(!pairs.is_empty());
        assert_eq!(pairs, brute_force_pairs(&positions, 2.5));
    }

    #[test]
    fn test_lattice_boundaries_match_brute_force() {
        // Atoms on a unit lattice sit near cell boundaries. Some are exactly
        // `cutoff` apart, which must be excluded (strict `<`).
        let positions: Vec<[f64; 3]> = (0..64)
            .map(|k| [(k / 16) as f64, ((k / 4) % 4) as f64, (k % 4) as f64])
            .collect();
        assert_eq!(
            cell_list_pairs(&positions, 1.0),
            brute_force_pairs(&positions, 1.0)
        );
        assert_eq!(
            cell_list_pairs(&positions, 1.5),
            brute_force_pairs(&positions, 1.5)
        );
    }
}