rust_sasa/structures/
spatial_grid.rs

1// Copyright (c) 2024 Maxwell Campbell. Licensed under the MIT License.
2use crate::{Atom, NeighborData};
3
/// Uniform cell-list grid over a set of atoms, used to build neighbour lists
/// without an all-pairs scan.
///
/// Atoms are stored sorted by cell, so the members of cell `c` occupy the
/// contiguous slot range `cell_starts[c]..cell_starts[c + 1]` of every
/// per-atom array below.
#[derive(Debug, Clone)]
pub struct SpatialGrid {
    /// Index into the `atoms` slice the grid was built from, one per sorted slot,
    /// grouped by cell.
    atom_indices: Vec<u32>,

    /// Atom coordinates in structure-of-arrays layout, parallel to `atom_indices`.
    positions_x: Vec<f32>,
    positions_y: Vec<f32>,
    positions_z: Vec<f32>,

    /// Atom radius for each sorted slot, parallel to `atom_indices`.
    radii: Vec<f32>,

    /// Slot offsets per cell: `num_cells + 1` entries, where cell `c` spans
    /// `cell_starts[c]..cell_starts[c + 1]` and the last entry is the atom count.
    cell_starts: Vec<u32>,

    /// Number of cells along the x, y and z axes.
    grid_dims: [u32; 3],

    /// Total number of cells (product of `grid_dims`).
    num_cells: usize,

    /// Precomputed half-shell cell offsets `(dx, dy, dz)` for the search extent.
    half_shell_offsets: Vec<(i32, i32, i32)>,
}
26
27impl SpatialGrid {
28    pub fn new(
29        atoms: &[Atom],
30        active_indices: &[usize],
31        cell_size: f32,
32        max_search_radius: f32,
33    ) -> Self {
34        // Calculate bounds
35        let (min_bounds, max_bounds) = Self::calculate_bounds(atoms, active_indices, cell_size);
36        let inv_cell_size = 1.0 / cell_size;
37
38        // Grid dimensions
39        let grid_dims = [
40            ((max_bounds[0] - min_bounds[0]) * inv_cell_size).ceil() as u32 + 1,
41            ((max_bounds[1] - min_bounds[1]) * inv_cell_size).ceil() as u32 + 1,
42            ((max_bounds[2] - min_bounds[2]) * inv_cell_size).ceil() as u32 + 1,
43        ];
44        let num_cells = (grid_dims[0] * grid_dims[1] * grid_dims[2]) as usize;
45
46        // Calculate search extent
47        let search_extent = (max_search_radius / cell_size).ceil() as i32;
48
49        // Precompute half-shell offsets for this search extent
50        let half_shell_offsets = Self::compute_half_shell_offsets(search_extent);
51
52        // Count atoms per cell
53        let mut cell_counts = vec![0u32; num_cells];
54        for &idx in active_indices {
55            let cell = Self::get_cell_index_static(
56                &atoms[idx].position,
57                &min_bounds,
58                inv_cell_size,
59                &grid_dims,
60            );
61            cell_counts[cell] += 1;
62        }
63
64        // Build cell_starts (exclusive prefix sum)
65        let mut cell_starts = vec![0u32; num_cells + 1];
66        for i in 0..num_cells {
67            cell_starts[i + 1] = cell_starts[i] + cell_counts[i];
68        }
69
70        // Allocate and fill sorted arrays
71        let n_active = active_indices.len();
72        let mut atom_indices = vec![0u32; n_active];
73        let mut positions_x = vec![0.0f32; n_active];
74        let mut positions_y = vec![0.0f32; n_active];
75        let mut positions_z = vec![0.0f32; n_active];
76        let mut radii = vec![0.0f32; n_active];
77
78        let mut write_pos = cell_starts[..num_cells].to_vec();
79
80        for &orig_idx in active_indices {
81            let atom = &atoms[orig_idx];
82            let pos = &atom.position;
83            let cell = Self::get_cell_index_static(pos, &min_bounds, inv_cell_size, &grid_dims);
84
85            let wp = write_pos[cell] as usize;
86            atom_indices[wp] = orig_idx as u32;
87            positions_x[wp] = pos[0];
88            positions_y[wp] = pos[1];
89            positions_z[wp] = pos[2];
90            radii[wp] = atom.radius;
91
92            write_pos[cell] += 1;
93        }
94
95        SpatialGrid {
96            atom_indices,
97            positions_x,
98            positions_y,
99            positions_z,
100            radii,
101            cell_starts,
102            grid_dims,
103            num_cells,
104            half_shell_offsets,
105        }
106    }
107
108    fn calculate_bounds(
109        atoms: &[Atom],
110        active_indices: &[usize],
111        padding: f32,
112    ) -> ([f32; 3], [f32; 3]) {
113        let mut min_b = [f32::INFINITY; 3];
114        let mut max_b = [f32::NEG_INFINITY; 3];
115
116        for &idx in active_indices {
117            let pos = &atoms[idx].position;
118            for i in 0..3 {
119                min_b[i] = min_b[i].min(pos[i]);
120                max_b[i] = max_b[i].max(pos[i]);
121            }
122        }
123
124        for i in 0..3 {
125            min_b[i] -= padding;
126            max_b[i] += padding;
127        }
128
129        (min_b, max_b)
130    }
131
132    #[inline(always)]
133    fn get_cell_index_static(
134        pos: &[f32; 3],
135        min_bounds: &[f32; 3],
136        inv_cell_size: f32,
137        grid_dims: &[u32; 3],
138    ) -> usize {
139        let x = ((pos[0] - min_bounds[0]) * inv_cell_size) as u32;
140        let y = ((pos[1] - min_bounds[1]) * inv_cell_size) as u32;
141        let z = ((pos[2] - min_bounds[2]) * inv_cell_size) as u32;
142        (x + y * grid_dims[0] + z * grid_dims[0] * grid_dims[1]) as usize
143    }
144
145    #[inline(always)]
146    fn cell_coords_to_index(&self, cx: i32, cy: i32, cz: i32) -> Option<usize> {
147        if cx < 0 || cy < 0 || cz < 0 {
148            return None;
149        }
150        let cx = cx as u32;
151        let cy = cy as u32;
152        let cz = cz as u32;
153        if cx >= self.grid_dims[0] || cy >= self.grid_dims[1] || cz >= self.grid_dims[2] {
154            return None;
155        }
156        Some((cx + cy * self.grid_dims[0] + cz * self.grid_dims[0] * self.grid_dims[1]) as usize)
157    }
158
159    #[inline(always)]
160    fn index_to_cell_coords(&self, idx: usize) -> (i32, i32, i32) {
161        let idx = idx as u32;
162        let cz = idx / (self.grid_dims[0] * self.grid_dims[1]);
163        let remainder = idx % (self.grid_dims[0] * self.grid_dims[1]);
164        let cy = remainder / self.grid_dims[0];
165        let cx = remainder % self.grid_dims[0];
166        (cx as i32, cy as i32, cz as i32)
167    }
168
169    /// Compute half-shell offsets for a given search extent [See http://doi.acm.org/10.1145/1862648.1862653]
170    ///
171    /// Half-shell means: for each pair of cells, only one cell "owns" the check.
172    /// We select cells where (dz > 0) OR (dz == 0 && dy > 0) OR (dz == 0 && dy == 0 && dx >= 0)
173    /// Note: dx >= 0 includes self (0,0,0) which we handle specially
174    fn compute_half_shell_offsets(extent: i32) -> Vec<(i32, i32, i32)> {
175        let mut offsets = Vec::new();
176
177        for dz in -extent..=extent {
178            for dy in -extent..=extent {
179                for dx in -extent..=extent {
180                    // Half-shell condition
181                    let include =
182                        (dz > 0) || (dz == 0 && dy > 0) || (dz == 0 && dy == 0 && dx >= 0);
183
184                    if include {
185                        offsets.push((dx, dy, dz));
186                    }
187                }
188            }
189        }
190
191        offsets
192    }
193
194    /// Build neighbor lists for all active atoms
195    pub fn build_all_neighbor_lists(
196        &self,
197        atoms: &[Atom],
198        active_indices: &[usize],
199        probe_radius: f32,
200        max_radius: f32,
201    ) -> Vec<Vec<NeighborData>> {
202        let n_atoms = atoms.len();
203        let n_active = active_indices.len();
204
205        // Map original index -> active index
206        let mut orig_to_active = vec![u32::MAX; n_atoms];
207        for (active_idx, &orig_idx) in active_indices.iter().enumerate() {
208            orig_to_active[orig_idx] = active_idx as u32;
209        }
210
211        // Preallocate neighbor lists
212        let mut neighbors: Vec<Vec<NeighborData>> =
213            (0..n_active).map(|_| Vec::with_capacity(80)).collect();
214
215        // Maximum possible search radius (for quick rejection)
216        let max_search_radius = max_radius + max_radius + 2.0 * probe_radius;
217        let max_search_radius_sq = max_search_radius * max_search_radius;
218
219        // Iterate through all cells using half-shell pattern
220        for cell_a in 0..self.num_cells {
221            let start_a = self.cell_starts[cell_a] as usize;
222            let end_a = self.cell_starts[cell_a + 1] as usize;
223
224            if start_a == end_a {
225                continue;
226            }
227
228            let (cx, cy, cz) = self.index_to_cell_coords(cell_a);
229
230            // Process this cell against all cells in half-shell
231            for &(dx, dy, dz) in &self.half_shell_offsets {
232                let cell_b = match self.cell_coords_to_index(cx + dx, cy + dy, cz + dz) {
233                    Some(c) => c,
234                    None => continue,
235                };
236
237                let start_b = self.cell_starts[cell_b] as usize;
238                let end_b = self.cell_starts[cell_b + 1] as usize;
239
240                if start_b == end_b {
241                    continue;
242                }
243
244                let is_self = dx == 0 && dy == 0 && dz == 0;
245
246                if is_self {
247                    self.process_self_cell(
248                        atoms,
249                        &orig_to_active,
250                        start_a,
251                        end_a,
252                        probe_radius,
253                        max_radius,
254                        max_search_radius_sq,
255                        &mut neighbors,
256                    );
257                } else {
258                    self.process_neighbor_cells(
259                        atoms,
260                        &orig_to_active,
261                        start_a,
262                        end_a,
263                        start_b,
264                        end_b,
265                        probe_radius,
266                        max_radius,
267                        max_search_radius_sq,
268                        &mut neighbors,
269                    );
270                }
271            }
272        }
273
274        // Sort neighbors by distance for early-exit optimization
275        self.sort_neighbors_by_distance(atoms, active_indices, &mut neighbors);
276
277        neighbors
278    }
279
280    #[inline(always)]
281    #[allow(clippy::too_many_arguments)]
282    fn process_self_cell(
283        &self,
284        atoms: &[Atom],
285        orig_to_active: &[u32],
286        start: usize,
287        end: usize,
288        probe_radius: f32,
289        max_radius: f32,
290        max_search_radius_sq: f32,
291        neighbors: &mut [Vec<NeighborData>],
292    ) {
293        for i in start..end {
294            let orig_i = self.atom_indices[i] as usize;
295            let active_i = orig_to_active[orig_i];
296            if active_i == u32::MAX {
297                continue;
298            }
299
300            let xi = self.positions_x[i];
301            let yi = self.positions_y[i];
302            let zi = self.positions_z[i];
303            let ri = self.radii[i];
304            let id_i = atoms[orig_i].id;
305
306            // Search radius for atom i
307            let sr_i = ri + max_radius + 2.0 * probe_radius;
308            let sr_i_sq = sr_i * sr_i;
309
310            for j in (i + 1)..end {
311                let orig_j = self.atom_indices[j] as usize;
312
313                // Skip if same atom ID
314                if atoms[orig_j].id == id_i {
315                    continue;
316                }
317
318                let dx = xi - self.positions_x[j];
319                let dy = yi - self.positions_y[j];
320                let dz = zi - self.positions_z[j];
321                let dist_sq = dx * dx + dy * dy + dz * dz;
322
323                // Quick rejection
324                if dist_sq > max_search_radius_sq {
325                    continue;
326                }
327
328                let rj = self.radii[j];
329
330                // Search radius for atom j
331                let sr_j = rj + max_radius + 2.0 * probe_radius;
332                let sr_j_sq = sr_j * sr_j;
333
334                // Check if i finds j
335                if dist_sq <= sr_i_sq {
336                    let thresh_j = rj + probe_radius;
337                    neighbors[active_i as usize].push(NeighborData {
338                        idx: orig_j as u32,
339                        threshold_squared: thresh_j * thresh_j,
340                    });
341                }
342
343                // Check if j finds i
344                if dist_sq <= sr_j_sq {
345                    let active_j = orig_to_active[orig_j];
346                    if active_j != u32::MAX {
347                        let thresh_i = ri + probe_radius;
348                        neighbors[active_j as usize].push(NeighborData {
349                            idx: orig_i as u32,
350                            threshold_squared: thresh_i * thresh_i,
351                        });
352                    }
353                }
354            }
355        }
356    }
357
358    #[inline(always)]
359    #[allow(clippy::too_many_arguments)]
360    fn process_neighbor_cells(
361        &self,
362        atoms: &[Atom],
363        orig_to_active: &[u32],
364        start_a: usize,
365        end_a: usize,
366        start_b: usize,
367        end_b: usize,
368        probe_radius: f32,
369        max_radius: f32,
370        max_search_radius_sq: f32,
371        neighbors: &mut [Vec<NeighborData>],
372    ) {
373        for i in start_a..end_a {
374            let orig_i = self.atom_indices[i] as usize;
375            let active_i = orig_to_active[orig_i];
376            if active_i == u32::MAX {
377                continue;
378            }
379
380            let xi = self.positions_x[i];
381            let yi = self.positions_y[i];
382            let zi = self.positions_z[i];
383            let ri = self.radii[i];
384            let id_i = atoms[orig_i].id;
385
386            // Search radius for atom i
387            let sr_i = ri + max_radius + 2.0 * probe_radius;
388            let sr_i_sq = sr_i * sr_i;
389
390            for j in start_b..end_b {
391                let orig_j = self.atom_indices[j] as usize;
392
393                // Skip if same atom ID
394                if atoms[orig_j].id == id_i {
395                    continue;
396                }
397
398                let dx = xi - self.positions_x[j];
399                let dy = yi - self.positions_y[j];
400                let dz = zi - self.positions_z[j];
401                let dist_sq = dx * dx + dy * dy + dz * dz;
402
403                // Quick rejection
404                if dist_sq > max_search_radius_sq {
405                    continue;
406                }
407
408                let rj = self.radii[j];
409
410                // Search radius for atom j
411                let sr_j = rj + max_radius + 2.0 * probe_radius;
412                let sr_j_sq = sr_j * sr_j;
413
414                // Check if i finds j
415                if dist_sq <= sr_i_sq {
416                    let thresh_j = rj + probe_radius;
417                    neighbors[active_i as usize].push(NeighborData {
418                        idx: orig_j as u32,
419                        threshold_squared: thresh_j * thresh_j,
420                    });
421                }
422
423                // Check if j finds i
424                if dist_sq <= sr_j_sq {
425                    let active_j = orig_to_active[orig_j];
426                    if active_j != u32::MAX {
427                        let thresh_i = ri + probe_radius;
428                        neighbors[active_j as usize].push(NeighborData {
429                            idx: orig_i as u32,
430                            threshold_squared: thresh_i * thresh_i,
431                        });
432                    }
433                }
434            }
435        }
436    }
437
438    fn sort_neighbors_by_distance(
439        &self,
440        atoms: &[Atom],
441        active_indices: &[usize],
442        neighbors: &mut [Vec<NeighborData>],
443    ) {
444        for (active_idx, neighbor_list) in neighbors.iter_mut().enumerate() {
445            if neighbor_list.len() <= 1 {
446                continue;
447            }
448
449            let center = atoms[active_indices[active_idx]].position;
450
451            neighbor_list.sort_unstable_by(|a, b| {
452                let pa = atoms[a.idx as usize].position;
453                let pb = atoms[b.idx as usize].position;
454
455                let da = (center[0] - pa[0]).powi(2)
456                    + (center[1] - pa[1]).powi(2)
457                    + (center[2] - pa[2]).powi(2);
458                let db = (center[0] - pb[0]).powi(2)
459                    + (center[1] - pb[1]).powi(2)
460                    + (center[2] - pb[2]).powi(2);
461
462                da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
463            });
464        }
465    }
466}