/// A pair of atoms that lie closer together than the cell-list cutoff.
///
/// Pairs produced by `CellList` always have `i < j`, so each unordered pair is
/// reported exactly once.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NeighborPair {
    /// Index (into the positions slice) of the atom with the smaller index.
    pub i: usize,
    /// Index (into the positions slice) of the atom with the larger index.
    pub j: usize,
    /// Squared Euclidean distance between the two atoms.
    pub dist_sq: f64,
}
14
15pub struct CellList {
17 pub cutoff: f64,
19 cutoff_sq: f64,
21 _cell_size: f64,
23 n_cells: [usize; 3],
25 _origin: [f64; 3],
27 cells: Vec<Vec<usize>>,
29}
30
31impl CellList {
32 pub fn new(positions: &[[f64; 3]], cutoff: f64) -> Self {
36 let padding = 0.01;
37 let mut min = [f64::INFINITY; 3];
38 let mut max = [f64::NEG_INFINITY; 3];
39
40 for pos in positions {
41 for d in 0..3 {
42 if pos[d] < min[d] {
43 min[d] = pos[d];
44 }
45 if pos[d] > max[d] {
46 max[d] = pos[d];
47 }
48 }
49 }
50
51 let origin = [min[0] - padding, min[1] - padding, min[2] - padding];
52 let cell_size = cutoff;
53 let n_cells = [
54 ((max[0] - origin[0] + padding) / cell_size).ceil().max(1.0) as usize,
55 ((max[1] - origin[1] + padding) / cell_size).ceil().max(1.0) as usize,
56 ((max[2] - origin[2] + padding) / cell_size).ceil().max(1.0) as usize,
57 ];
58
59 let total = n_cells[0] * n_cells[1] * n_cells[2];
60 let mut cells = vec![Vec::new(); total];
61
62 for (idx, pos) in positions.iter().enumerate() {
63 let ci = Self::cell_index_static(pos, &origin, cell_size, &n_cells);
64 cells[ci].push(idx);
65 }
66
67 CellList {
68 cutoff,
69 cutoff_sq: cutoff * cutoff,
70 _cell_size: cell_size,
71 n_cells,
72 _origin: origin,
73 cells,
74 }
75 }
76
77 fn cell_index_static(
78 pos: &[f64; 3],
79 origin: &[f64; 3],
80 cell_size: f64,
81 n_cells: &[usize; 3],
82 ) -> usize {
83 let cx = ((pos[0] - origin[0]) / cell_size) as usize;
84 let cy = ((pos[1] - origin[1]) / cell_size) as usize;
85 let cz = ((pos[2] - origin[2]) / cell_size) as usize;
86 let cx = cx.min(n_cells[0] - 1);
87 let cy = cy.min(n_cells[1] - 1);
88 let cz = cz.min(n_cells[2] - 1);
89 cx * n_cells[1] * n_cells[2] + cy * n_cells[2] + cz
90 }
91
92 pub fn find_neighbors(&self, positions: &[[f64; 3]]) -> Vec<NeighborPair> {
96 let mut pairs = Vec::new();
97 let nc = &self.n_cells;
98
99 for cx in 0..nc[0] {
100 for cy in 0..nc[1] {
101 for cz in 0..nc[2] {
102 let ci = cx * nc[1] * nc[2] + cy * nc[2] + cz;
103 self.pairs_within_cell(&self.cells[ci], positions, &mut pairs);
104
105 for &(dx, dy, dz) in &HALF_NEIGHBOR_OFFSETS {
107 let nx = cx as isize + dx;
108 let ny = cy as isize + dy;
109 let nz = cz as isize + dz;
110 if nx < 0 || ny < 0 || nz < 0 {
111 continue;
112 }
113 let (nx, ny, nz) = (nx as usize, ny as usize, nz as usize);
114 if nx >= nc[0] || ny >= nc[1] || nz >= nc[2] {
115 continue;
116 }
117 let ni = nx * nc[1] * nc[2] + ny * nc[2] + nz;
118 self.pairs_between_cells(
119 &self.cells[ci],
120 &self.cells[ni],
121 positions,
122 &mut pairs,
123 );
124 }
125 }
126 }
127 }
128 pairs
129 }
130
131 fn pairs_within_cell(
132 &self,
133 cell: &[usize],
134 positions: &[[f64; 3]],
135 pairs: &mut Vec<NeighborPair>,
136 ) {
137 for a in 0..cell.len() {
138 for b in (a + 1)..cell.len() {
139 let i = cell[a];
140 let j = cell[b];
141 let dsq = dist_sq(&positions[i], &positions[j]);
142 if dsq < self.cutoff_sq {
143 pairs.push(NeighborPair { i, j, dist_sq: dsq });
144 }
145 }
146 }
147 }
148
149 fn pairs_between_cells(
150 &self,
151 cell_a: &[usize],
152 cell_b: &[usize],
153 positions: &[[f64; 3]],
154 pairs: &mut Vec<NeighborPair>,
155 ) {
156 for &i in cell_a {
157 for &j in cell_b {
158 let dsq = dist_sq(&positions[i], &positions[j]);
159 if dsq < self.cutoff_sq {
160 let (lo, hi) = if i < j { (i, j) } else { (j, i) };
161 pairs.push(NeighborPair {
162 i: lo,
163 j: hi,
164 dist_sq: dsq,
165 });
166 }
167 }
168 }
169 }
170}
171
/// Squared Euclidean distance between two 3-D points.
///
/// Returning the square lets callers compare against a squared cutoff without
/// taking a square root. Terms are accumulated x, then y, then z.
#[inline]
fn dist_sq(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    a.iter()
        .zip(b.iter())
        .map(|(&p, &q)| p - q)
        .fold(0.0, |acc, delta| acc + delta * delta)
}
179
/// Forward half of the 26 offsets from a grid cell to its neighbours.
///
/// Exactly one offset from each `{v, -v}` pair is listed, so visiting these
/// offsets from every cell reaches each pair of adjacent cells exactly once.
#[rustfmt::skip]
const HALF_NEIGHBOR_OFFSETS: [(isize, isize, isize); 13] = [
    // Face neighbours.
    (1, 0, 0), (0, 1, 0), (0, 0, 1),
    // Edge neighbours.
    (1, 1, 0), (1, -1, 0), (1, 0, 1), (1, 0, -1), (0, 1, 1), (0, 1, -1),
    // Corner neighbours.
    (1, 1, 1), (1, 1, -1), (1, -1, 1), (1, -1, -1),
];
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random coordinates in `[0, box_len)^3` from a 64-bit LCG.
    fn lcg_positions(n: usize, box_len: f64, seed: u64) -> Vec<[f64; 3]> {
        let mut state = seed;
        let mut next = move || {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (state >> 11) as f64 / (1u64 << 53) as f64 * box_len
        };
        (0..n).map(|_| [next(), next(), next()]).collect()
    }

    /// All pairs `(i, j)` with `i < j` strictly closer than `cutoff`, by exhaustive search.
    fn brute_force_pairs(positions: &[[f64; 3]], cutoff: f64) -> Vec<(usize, usize)> {
        let mut out = Vec::new();
        for i in 0..positions.len() {
            for j in (i + 1)..positions.len() {
                let (a, b) = (positions[i], positions[j]);
                let dsq = (a[0] - b[0]) * (a[0] - b[0])
                    + (a[1] - b[1]) * (a[1] - b[1])
                    + (a[2] - b[2]) * (a[2] - b[2]);
                if dsq < cutoff * cutoff {
                    out.push((i, j));
                }
            }
        }
        out
    }

    /// Neighbor pairs reported by the cell list, as sorted `(i, j)` tuples.
    fn cell_list_pairs(positions: &[[f64; 3]], cutoff: f64) -> Vec<(usize, usize)> {
        let cl = CellList::new(positions, cutoff);
        let mut pairs: Vec<(usize, usize)> = cl
            .find_neighbors(positions)
            .iter()
            .map(|p| (p.i, p.j))
            .collect();
        pairs.sort_unstable();
        pairs
    }

    #[test]
    fn test_water_neighbors() {
        let positions = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
        let cl = CellList::new(&positions, 5.2);
        let pairs = cl.find_neighbors(&positions);
        assert_eq!(pairs.len(), 3, "Water should have 3 pairs");
    }

    #[test]
    fn test_distant_atoms_excluded() {
        let positions = [[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]];
        let cl = CellList::new(&positions, 5.2);
        let pairs = cl.find_neighbors(&positions);
        assert_eq!(pairs.len(), 0, "Atoms 10 Å apart should not be neighbors");
    }

    #[test]
    fn test_matches_brute_force_across_cells() {
        // A 20 Å box with a 3 Å cutoff spans several cells per axis, so pairs that
        // straddle cell boundaries (the neighbour-offset path) are exercised.
        let positions = lcg_positions(300, 20.0, 42);
        let expected = brute_force_pairs(&positions, 3.0);
        assert!(!expected.is_empty());
        assert_eq!(cell_list_pairs(&positions, 3.0), expected);
    }

    #[test]
    fn test_pairs_ordered_with_consistent_distances() {
        let positions = lcg_positions(100, 10.0, 7);
        let cl = CellList::new(&positions, 2.5);
        for p in cl.find_neighbors(&positions) {
            assert!(p.i < p.j, "pairs must be reported with i < j");
            assert_eq!(p.dist_sq, dist_sq(&positions[p.i], &positions[p.j]));
            assert!(p.dist_sq < 2.5 * 2.5);
        }
    }

    #[test]
    fn test_empty_input() {
        let positions: [[f64; 3]; 0] = [];
        let cl = CellList::new(&positions, 5.2);
        assert!(cl.find_neighbors(&positions).is_empty());
    }
}