1use crate::{Atom, NeighborData};
3
4pub struct SpatialGrid {
5 atom_indices: Vec<u32>,
7
8 positions_x: Vec<f32>,
10 positions_y: Vec<f32>,
11 positions_z: Vec<f32>,
12
13 radii: Vec<f32>,
15
16 cell_starts: Vec<u32>,
18
19 grid_dims: [u32; 3],
21 num_cells: usize,
22
23 half_shell_offsets: Vec<(i32, i32, i32)>,
25}
26
27impl SpatialGrid {
28 pub fn new(
29 atoms: &[Atom],
30 active_indices: &[usize],
31 cell_size: f32,
32 max_search_radius: f32,
33 ) -> Self {
34 let (min_bounds, max_bounds) = Self::calculate_bounds(atoms, active_indices, cell_size);
36 let inv_cell_size = 1.0 / cell_size;
37
38 let grid_dims = [
40 ((max_bounds[0] - min_bounds[0]) * inv_cell_size).ceil() as u32 + 1,
41 ((max_bounds[1] - min_bounds[1]) * inv_cell_size).ceil() as u32 + 1,
42 ((max_bounds[2] - min_bounds[2]) * inv_cell_size).ceil() as u32 + 1,
43 ];
44 let num_cells = (grid_dims[0] * grid_dims[1] * grid_dims[2]) as usize;
45
46 let search_extent = (max_search_radius / cell_size).ceil() as i32;
48
49 let half_shell_offsets = Self::compute_half_shell_offsets(search_extent);
51
52 let mut cell_counts = vec![0u32; num_cells];
54 for &idx in active_indices {
55 let cell = Self::get_cell_index_static(
56 &atoms[idx].position,
57 &min_bounds,
58 inv_cell_size,
59 &grid_dims,
60 );
61 cell_counts[cell] += 1;
62 }
63
64 let mut cell_starts = vec![0u32; num_cells + 1];
66 for i in 0..num_cells {
67 cell_starts[i + 1] = cell_starts[i] + cell_counts[i];
68 }
69
70 let n_active = active_indices.len();
72 let mut atom_indices = vec![0u32; n_active];
73 let mut positions_x = vec![0.0f32; n_active];
74 let mut positions_y = vec![0.0f32; n_active];
75 let mut positions_z = vec![0.0f32; n_active];
76 let mut radii = vec![0.0f32; n_active];
77
78 let mut write_pos = cell_starts[..num_cells].to_vec();
79
80 for &orig_idx in active_indices {
81 let atom = &atoms[orig_idx];
82 let pos = &atom.position;
83 let cell = Self::get_cell_index_static(pos, &min_bounds, inv_cell_size, &grid_dims);
84
85 let wp = write_pos[cell] as usize;
86 atom_indices[wp] = orig_idx as u32;
87 positions_x[wp] = pos[0];
88 positions_y[wp] = pos[1];
89 positions_z[wp] = pos[2];
90 radii[wp] = atom.radius;
91
92 write_pos[cell] += 1;
93 }
94
95 SpatialGrid {
96 atom_indices,
97 positions_x,
98 positions_y,
99 positions_z,
100 radii,
101 cell_starts,
102 grid_dims,
103 num_cells,
104 half_shell_offsets,
105 }
106 }
107
108 fn calculate_bounds(
109 atoms: &[Atom],
110 active_indices: &[usize],
111 padding: f32,
112 ) -> ([f32; 3], [f32; 3]) {
113 let mut min_b = [f32::INFINITY; 3];
114 let mut max_b = [f32::NEG_INFINITY; 3];
115
116 for &idx in active_indices {
117 let pos = &atoms[idx].position;
118 for i in 0..3 {
119 min_b[i] = min_b[i].min(pos[i]);
120 max_b[i] = max_b[i].max(pos[i]);
121 }
122 }
123
124 for i in 0..3 {
125 min_b[i] -= padding;
126 max_b[i] += padding;
127 }
128
129 (min_b, max_b)
130 }
131
132 #[inline(always)]
133 fn get_cell_index_static(
134 pos: &[f32; 3],
135 min_bounds: &[f32; 3],
136 inv_cell_size: f32,
137 grid_dims: &[u32; 3],
138 ) -> usize {
139 let x = ((pos[0] - min_bounds[0]) * inv_cell_size) as u32;
140 let y = ((pos[1] - min_bounds[1]) * inv_cell_size) as u32;
141 let z = ((pos[2] - min_bounds[2]) * inv_cell_size) as u32;
142 (x + y * grid_dims[0] + z * grid_dims[0] * grid_dims[1]) as usize
143 }
144
145 #[inline(always)]
146 fn cell_coords_to_index(&self, cx: i32, cy: i32, cz: i32) -> Option<usize> {
147 if cx < 0 || cy < 0 || cz < 0 {
148 return None;
149 }
150 let cx = cx as u32;
151 let cy = cy as u32;
152 let cz = cz as u32;
153 if cx >= self.grid_dims[0] || cy >= self.grid_dims[1] || cz >= self.grid_dims[2] {
154 return None;
155 }
156 Some((cx + cy * self.grid_dims[0] + cz * self.grid_dims[0] * self.grid_dims[1]) as usize)
157 }
158
159 #[inline(always)]
160 fn index_to_cell_coords(&self, idx: usize) -> (i32, i32, i32) {
161 let idx = idx as u32;
162 let cz = idx / (self.grid_dims[0] * self.grid_dims[1]);
163 let remainder = idx % (self.grid_dims[0] * self.grid_dims[1]);
164 let cy = remainder / self.grid_dims[0];
165 let cx = remainder % self.grid_dims[0];
166 (cx as i32, cy as i32, cz as i32)
167 }
168
169 fn compute_half_shell_offsets(extent: i32) -> Vec<(i32, i32, i32)> {
175 let mut offsets = Vec::new();
176
177 for dz in -extent..=extent {
178 for dy in -extent..=extent {
179 for dx in -extent..=extent {
180 let include =
182 (dz > 0) || (dz == 0 && dy > 0) || (dz == 0 && dy == 0 && dx >= 0);
183
184 if include {
185 offsets.push((dx, dy, dz));
186 }
187 }
188 }
189 }
190
191 offsets
192 }
193
194 pub fn build_all_neighbor_lists(
196 &self,
197 atoms: &[Atom],
198 active_indices: &[usize],
199 probe_radius: f32,
200 max_radius: f32,
201 ) -> Vec<Vec<NeighborData>> {
202 let n_atoms = atoms.len();
203 let n_active = active_indices.len();
204
205 let mut orig_to_active = vec![u32::MAX; n_atoms];
207 for (active_idx, &orig_idx) in active_indices.iter().enumerate() {
208 orig_to_active[orig_idx] = active_idx as u32;
209 }
210
211 let mut neighbors: Vec<Vec<NeighborData>> =
213 (0..n_active).map(|_| Vec::with_capacity(80)).collect();
214
215 let max_search_radius = max_radius + max_radius + 2.0 * probe_radius;
217 let max_search_radius_sq = max_search_radius * max_search_radius;
218
219 for cell_a in 0..self.num_cells {
221 let start_a = self.cell_starts[cell_a] as usize;
222 let end_a = self.cell_starts[cell_a + 1] as usize;
223
224 if start_a == end_a {
225 continue;
226 }
227
228 let (cx, cy, cz) = self.index_to_cell_coords(cell_a);
229
230 for &(dx, dy, dz) in &self.half_shell_offsets {
232 let cell_b = match self.cell_coords_to_index(cx + dx, cy + dy, cz + dz) {
233 Some(c) => c,
234 None => continue,
235 };
236
237 let start_b = self.cell_starts[cell_b] as usize;
238 let end_b = self.cell_starts[cell_b + 1] as usize;
239
240 if start_b == end_b {
241 continue;
242 }
243
244 let is_self = dx == 0 && dy == 0 && dz == 0;
245
246 if is_self {
247 self.process_self_cell(
248 atoms,
249 &orig_to_active,
250 start_a,
251 end_a,
252 probe_radius,
253 max_radius,
254 max_search_radius_sq,
255 &mut neighbors,
256 );
257 } else {
258 self.process_neighbor_cells(
259 atoms,
260 &orig_to_active,
261 start_a,
262 end_a,
263 start_b,
264 end_b,
265 probe_radius,
266 max_radius,
267 max_search_radius_sq,
268 &mut neighbors,
269 );
270 }
271 }
272 }
273
274 self.sort_neighbors_by_distance(atoms, active_indices, &mut neighbors);
276
277 neighbors
278 }
279
280 #[inline(always)]
281 #[allow(clippy::too_many_arguments)]
282 fn process_self_cell(
283 &self,
284 atoms: &[Atom],
285 orig_to_active: &[u32],
286 start: usize,
287 end: usize,
288 probe_radius: f32,
289 max_radius: f32,
290 max_search_radius_sq: f32,
291 neighbors: &mut [Vec<NeighborData>],
292 ) {
293 for i in start..end {
294 let orig_i = self.atom_indices[i] as usize;
295 let active_i = orig_to_active[orig_i];
296 if active_i == u32::MAX {
297 continue;
298 }
299
300 let xi = self.positions_x[i];
301 let yi = self.positions_y[i];
302 let zi = self.positions_z[i];
303 let ri = self.radii[i];
304 let id_i = atoms[orig_i].id;
305
306 let sr_i = ri + max_radius + 2.0 * probe_radius;
308 let sr_i_sq = sr_i * sr_i;
309
310 for j in (i + 1)..end {
311 let orig_j = self.atom_indices[j] as usize;
312
313 if atoms[orig_j].id == id_i {
315 continue;
316 }
317
318 let dx = xi - self.positions_x[j];
319 let dy = yi - self.positions_y[j];
320 let dz = zi - self.positions_z[j];
321 let dist_sq = dx * dx + dy * dy + dz * dz;
322
323 if dist_sq > max_search_radius_sq {
325 continue;
326 }
327
328 let rj = self.radii[j];
329
330 let sr_j = rj + max_radius + 2.0 * probe_radius;
332 let sr_j_sq = sr_j * sr_j;
333
334 if dist_sq <= sr_i_sq {
336 let thresh_j = rj + probe_radius;
337 neighbors[active_i as usize].push(NeighborData {
338 idx: orig_j as u32,
339 threshold_squared: thresh_j * thresh_j,
340 });
341 }
342
343 if dist_sq <= sr_j_sq {
345 let active_j = orig_to_active[orig_j];
346 if active_j != u32::MAX {
347 let thresh_i = ri + probe_radius;
348 neighbors[active_j as usize].push(NeighborData {
349 idx: orig_i as u32,
350 threshold_squared: thresh_i * thresh_i,
351 });
352 }
353 }
354 }
355 }
356 }
357
358 #[inline(always)]
359 #[allow(clippy::too_many_arguments)]
360 fn process_neighbor_cells(
361 &self,
362 atoms: &[Atom],
363 orig_to_active: &[u32],
364 start_a: usize,
365 end_a: usize,
366 start_b: usize,
367 end_b: usize,
368 probe_radius: f32,
369 max_radius: f32,
370 max_search_radius_sq: f32,
371 neighbors: &mut [Vec<NeighborData>],
372 ) {
373 for i in start_a..end_a {
374 let orig_i = self.atom_indices[i] as usize;
375 let active_i = orig_to_active[orig_i];
376 if active_i == u32::MAX {
377 continue;
378 }
379
380 let xi = self.positions_x[i];
381 let yi = self.positions_y[i];
382 let zi = self.positions_z[i];
383 let ri = self.radii[i];
384 let id_i = atoms[orig_i].id;
385
386 let sr_i = ri + max_radius + 2.0 * probe_radius;
388 let sr_i_sq = sr_i * sr_i;
389
390 for j in start_b..end_b {
391 let orig_j = self.atom_indices[j] as usize;
392
393 if atoms[orig_j].id == id_i {
395 continue;
396 }
397
398 let dx = xi - self.positions_x[j];
399 let dy = yi - self.positions_y[j];
400 let dz = zi - self.positions_z[j];
401 let dist_sq = dx * dx + dy * dy + dz * dz;
402
403 if dist_sq > max_search_radius_sq {
405 continue;
406 }
407
408 let rj = self.radii[j];
409
410 let sr_j = rj + max_radius + 2.0 * probe_radius;
412 let sr_j_sq = sr_j * sr_j;
413
414 if dist_sq <= sr_i_sq {
416 let thresh_j = rj + probe_radius;
417 neighbors[active_i as usize].push(NeighborData {
418 idx: orig_j as u32,
419 threshold_squared: thresh_j * thresh_j,
420 });
421 }
422
423 if dist_sq <= sr_j_sq {
425 let active_j = orig_to_active[orig_j];
426 if active_j != u32::MAX {
427 let thresh_i = ri + probe_radius;
428 neighbors[active_j as usize].push(NeighborData {
429 idx: orig_i as u32,
430 threshold_squared: thresh_i * thresh_i,
431 });
432 }
433 }
434 }
435 }
436 }
437
438 fn sort_neighbors_by_distance(
439 &self,
440 atoms: &[Atom],
441 active_indices: &[usize],
442 neighbors: &mut [Vec<NeighborData>],
443 ) {
444 for (active_idx, neighbor_list) in neighbors.iter_mut().enumerate() {
445 if neighbor_list.len() <= 1 {
446 continue;
447 }
448
449 let center = atoms[active_indices[active_idx]].position;
450
451 neighbor_list.sort_unstable_by(|a, b| {
452 let pa = atoms[a.idx as usize].position;
453 let pb = atoms[b.idx as usize].position;
454
455 let da = (center[0] - pa[0]).powi(2)
456 + (center[1] - pa[1]).powi(2)
457 + (center[2] - pa[2]).powi(2);
458 let db = (center[0] - pb[0]).powi(2)
459 + (center[1] - pb[1]).powi(2)
460 + (center[2] - pb[2]).powi(2);
461
462 da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
463 });
464 }
465 }
466}