Skip to main content

rustpix_algorithms/
grid.rs

1//! SoA-based Grid Clustering.
2//!
3//! Adapted from generic `GridClustering` to work directly on `HitBatch` (`SoA`).
4
5use crate::SpatialGrid;
6use rustpix_core::clustering::ClusteringError;
7use rustpix_core::soa::HitBatch;
8
/// Configuration for grid-based clustering.
///
/// Consumed by [`GridClustering`]; see the individual fields for how each
/// value enters the algorithm.
#[derive(Clone, Debug)]
pub struct GridConfig {
    /// Spatial radius for neighbor detection (pixels).
    ///
    /// Two hits are spatially linked when their squared Euclidean pixel
    /// distance is at most `radius * radius`.
    pub radius: f64,
    /// Temporal correlation window (nanoseconds).
    ///
    /// Divided by 25 ns and rounded up to a whole number of TOF ticks before
    /// it is compared against hit TOF differences.
    pub temporal_window_ns: f64,
    /// Minimum cluster size to keep.
    ///
    /// Hits whose cluster has fewer members receive the label `-1`.
    pub min_cluster_size: u16,
    /// Maximum cluster size (None = unlimited).
    // NOTE(review): nothing in this file reads this field; confirm whether the
    // limit is enforced elsewhere or still needs to be implemented.
    pub max_cluster_size: Option<usize>,
    /// Grid cell size (pixels).
    ///
    /// Edge length of one spatial-index cell used for the neighbor lookup.
    pub cell_size: usize,
}
23
24impl Default for GridConfig {
25    fn default() -> Self {
26        Self {
27            radius: 5.0,
28            temporal_window_ns: 75.0,
29            min_cluster_size: 1,
30            max_cluster_size: None,
31            cell_size: 32,
32        }
33    }
34}
35
/// Reusable grid clustering state.
///
/// Pass the same instance to successive `GridClustering::cluster` calls so the
/// spatial grid and the scratch vectors below are reused; the vectors are only
/// ever grown, never shrunk.
#[derive(Default)]
pub struct GridState {
    /// Number of hits processed.
    ///
    /// Set by `GridClustering::cluster` to the batch length once a non-empty
    /// batch has been clustered.
    pub hits_processed: usize,
    /// Number of clusters found.
    ///
    /// Set by `GridClustering::cluster` to the number of labels it handed out.
    pub clusters_found: usize,
    /// Spatial index over hit indices; created on first use and reused while
    /// the configured cell size stays the same.
    grid: Option<SpatialGrid<usize>>,
    /// Union-find parent index for each hit (`parent[i] == i` marks a root).
    parent: Vec<usize>,
    /// Union-find rank per index, used to decide which root survives a merge.
    rank: Vec<usize>,
    /// Root index resolved for each hit during label assignment.
    roots: Vec<usize>,
    /// Number of hits attached to each root index.
    cluster_sizes: Vec<usize>,
    /// Label handed to each root index; a negative value means "none yet".
    root_to_label: Vec<i32>,
}
50
/// SoA-optimized grid clustering implementation.
///
/// Holds only the configuration; all per-call scratch data lives in the
/// `GridState` passed to `cluster`.
pub struct GridClustering {
    /// Parameters applied by every `cluster` call.
    config: GridConfig,
}
55
/// Pre-computed, loop-invariant parameters for the pairwise linking pass.
struct GridUnionContext {
    /// Squared spatial radius (`radius * radius`), compared against squared
    /// pixel distances.
    radius_sq: f64,
    /// Temporal window expressed in TOF ticks (the nanosecond window divided by
    /// 25 and rounded up).
    window_tof: u32,
    /// Grid cell edge length in pixels, as `i32` so it can offset signed
    /// coordinates; saturates at `i32::MAX` if the configured size does not fit.
    cell_size: i32,
}
61
62impl GridClustering {
63    /// Create with custom configuration.
64    #[must_use]
65    pub fn new(config: GridConfig) -> Self {
66        Self { config }
67    }
68
69    /// Cluster a batch of hits in-place.
70    ///
71    /// Updates `cluster_id` field in `batch`.
72    ///
73    /// # Errors
74    /// Returns an error if clustering fails.
75    pub fn cluster(
76        &self,
77        batch: &mut HitBatch,
78        state: &mut GridState,
79    ) -> Result<usize, ClusteringError> {
80        if batch.is_empty() {
81            return Ok(0);
82        }
83
84        let n = batch.len();
85        let GridState {
86            hits_processed,
87            clusters_found,
88            grid,
89            parent,
90            rank,
91            roots,
92            cluster_sizes,
93            root_to_label,
94        } = state;
95
96        *hits_processed = 0;
97        *clusters_found = 0;
98        batch.cluster_id.fill(-1);
99
100        let (width, height) = Self::batch_dimensions(batch);
101        Self::init_union_find(parent, rank, roots, cluster_sizes, root_to_label, n);
102
103        let grid = Self::prepare_grid(grid, self.config.cell_size, width, height);
104        Self::fill_grid(grid, batch);
105
106        let union_ctx = GridUnionContext {
107            radius_sq: self.config.radius * self.config.radius,
108            window_tof: float_to_u32((self.config.temporal_window_ns / 25.0).ceil()),
109            cell_size: i32::try_from(self.config.cell_size).unwrap_or(i32::MAX),
110        };
111
112        Self::union_hits(batch, grid, parent, rank, n, &union_ctx);
113
114        let clusters = Self::assign_labels(
115            batch,
116            parent,
117            roots,
118            cluster_sizes,
119            root_to_label,
120            n,
121            usize::from(self.config.min_cluster_size),
122        );
123
124        *hits_processed = n;
125        *clusters_found = clusters;
126        Ok(clusters)
127    }
128}
129
/// Returns the representative (root) of the set containing `i`.
///
/// Works in two passes without recursion: first walk up to the root, then walk
/// the same path again and point every visited node directly at that root
/// (path compression).
fn find(parent: &mut [usize], i: usize) -> usize {
    // Pass 1: climb until a node is its own parent.
    let mut representative = i;
    loop {
        let above = parent[representative];
        if above == representative {
            break;
        }
        representative = above;
    }

    // Pass 2: re-attach every node on the path straight to the root.
    let mut node = i;
    while node != representative {
        node = std::mem::replace(&mut parent[node], representative);
    }

    representative
}
143
144fn union_sets(parent: &mut [usize], rank: &mut [usize], i: usize, j: usize) {
145    let root_i = find(parent, i);
146    let root_j = find(parent, j);
147    if root_i != root_j {
148        if rank[root_i] < rank[root_j] {
149            parent[root_i] = root_j;
150        } else {
151            parent[root_j] = root_i;
152            if rank[root_i] == rank[root_j] {
153                rank[root_i] += 1;
154            }
155        }
156    }
157}
158
159impl GridClustering {
160    fn batch_dimensions(batch: &HitBatch) -> (usize, usize) {
161        let mut max_x = 0usize;
162        let mut max_y = 0usize;
163        for i in 0..batch.len() {
164            let x = usize::from(batch.x[i]);
165            let y = usize::from(batch.y[i]);
166            if x > max_x {
167                max_x = x;
168            }
169            if y > max_y {
170                max_y = y;
171            }
172        }
173        (max_x + 32, max_y + 32)
174    }
175
176    fn prepare_grid(
177        grid_slot: &mut Option<SpatialGrid<usize>>,
178        cell_size: usize,
179        width: usize,
180        height: usize,
181    ) -> &mut SpatialGrid<usize> {
182        let grid = grid_slot.get_or_insert_with(|| SpatialGrid::new(cell_size, width, height));
183        if grid.cell_size() == cell_size {
184            grid.ensure_dimensions(width, height);
185            grid.clear();
186        } else {
187            *grid = SpatialGrid::new(cell_size, width, height);
188        }
189        grid
190    }
191
192    fn fill_grid(grid: &mut SpatialGrid<usize>, batch: &HitBatch) {
193        for i in 0..batch.len() {
194            grid.insert(i32::from(batch.x[i]), i32::from(batch.y[i]), i);
195        }
196    }
197
198    fn init_union_find(
199        parent: &mut Vec<usize>,
200        rank: &mut Vec<usize>,
201        roots: &mut Vec<usize>,
202        cluster_sizes: &mut Vec<usize>,
203        root_to_label: &mut Vec<i32>,
204        n: usize,
205    ) {
206        if parent.len() < n {
207            parent.resize(n, 0);
208        }
209        if rank.len() < n {
210            rank.resize(n, 0);
211        }
212        if roots.len() < n {
213            roots.resize(n, 0);
214        }
215        if cluster_sizes.len() < n {
216            cluster_sizes.resize(n, 0);
217        }
218        if root_to_label.len() < n {
219            root_to_label.resize(n, -1);
220        }
221        for i in 0..n {
222            parent[i] = i;
223            rank[i] = 0;
224        }
225    }
226
227    fn union_hits(
228        batch: &HitBatch,
229        grid: &SpatialGrid<usize>,
230        parent: &mut [usize],
231        rank: &mut [usize],
232        n: usize,
233        ctx: &GridUnionContext,
234    ) {
235        for i in 0..n {
236            let x = i32::from(batch.x[i]);
237            let y = i32::from(batch.y[i]);
238
239            for dy in -1..=1 {
240                for dx in -1..=1 {
241                    let px = x + dx * ctx.cell_size;
242                    let py = y + dy * ctx.cell_size;
243
244                    if let Some(cell) = grid.get_cell_slice(px, py) {
245                        let start = cell.partition_point(|&idx| idx <= i);
246
247                        for &j in &cell[start..] {
248                            let dt = batch.tof[j].wrapping_sub(batch.tof[i]);
249                            if dt > ctx.window_tof {
250                                break;
251                            }
252
253                            let dx = f64::from(batch.x[i]) - f64::from(batch.x[j]);
254                            let dy = f64::from(batch.y[i]) - f64::from(batch.y[j]);
255                            let dist_sq = dx * dx + dy * dy;
256
257                            if dist_sq <= ctx.radius_sq {
258                                union_sets(parent, rank, i, j);
259                            }
260                        }
261                    }
262                }
263            }
264        }
265    }
266
267    fn assign_labels(
268        batch: &mut HitBatch,
269        parent: &mut [usize],
270        roots: &mut [usize],
271        cluster_sizes: &mut [usize],
272        root_to_label: &mut [i32],
273        n: usize,
274        min_cluster_size: usize,
275    ) -> usize {
276        cluster_sizes[..n].fill(0);
277        for (i, root_slot) in roots.iter_mut().enumerate().take(n) {
278            let root = find(parent, i);
279            *root_slot = root;
280            cluster_sizes[root] += 1;
281        }
282
283        root_to_label[..n].fill(-1);
284        let mut next_label = 0;
285
286        for (i, &root) in roots.iter().enumerate().take(n) {
287            let size = cluster_sizes[root];
288
289            if size < min_cluster_size {
290                batch.cluster_id[i] = -1;
291            } else {
292                let label_slot = &mut root_to_label[root];
293                if *label_slot < 0 {
294                    *label_slot = next_label;
295                    next_label += 1;
296                }
297                batch.cluster_id[i] = *label_slot;
298            }
299        }
300
301        usize::try_from(next_label).unwrap_or(0)
302    }
303}
304
/// Converts a float to `u32` without an `as` cast, saturating at both ends.
///
/// Values at or below zero become `0`; values at or above `u32::MAX` become
/// `u32::MAX`. Anything in between is rendered with zero decimals and parsed
/// back. If that parse fails (which is what happens for NaN), the result is
/// `u32::MAX`.
fn float_to_u32(value: f64) -> u32 {
    let ceiling = f64::from(u32::MAX);

    if value >= ceiling {
        u32::MAX
    } else if value <= 0.0 {
        0
    } else {
        // NaN fails both comparisons above and ends up here as well.
        let rendered = format!("{value:.0}");
        rendered.parse().unwrap_or(u32::MAX)
    }
}
314
315impl Default for GridClustering {
316    fn default() -> Self {
317        Self::new(GridConfig::default())
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use rustpix_core::soa::HitBatch;

    /// Clusterer whose temporal window is 50 ns, i.e. 50 / 25 = 2 TOF ticks.
    fn two_tick_clusterer() -> GridClustering {
        let config = GridConfig {
            temporal_window_ns: 50.0,
            ..Default::default()
        };
        GridClustering::new(config)
    }

    /// Two adjacent hits share a label; spatially or temporally isolated hits
    /// each form their own cluster under the default `min_cluster_size` of 1.
    #[test]
    fn test_soa_clustering() {
        let mut hits = HitBatch::default();
        // Pair that is close in both space and time.
        hits.push((10, 10, 100, 5, 0, 0));
        hits.push((11, 11, 102, 5, 0, 0));
        // Same time as the pair, but far away in space.
        hits.push((50, 50, 100, 5, 0, 0));
        // Far away in both space and time.
        hits.push((100, 100, 10000, 5, 0, 0));

        let clusterer = GridClustering::default();
        let mut scratch = GridState::default();

        let found = clusterer.cluster(&mut hits, &mut scratch).unwrap();

        // The pair plus two single-hit clusters.
        assert_eq!(found, 3);
        assert_eq!(hits.cluster_id[0], hits.cluster_id[1]);
        assert_ne!(hits.cluster_id[0], hits.cluster_id[2]);
    }

    /// Documents the sorted-input requirement that comes with temporal pruning.
    ///
    /// Stored order is A (TOF 100), B (TOF 200), C (TOF 102). While scanning
    /// from A, B is already outside the 2-tick window, so the scan stops and C
    /// is never compared with A even though they are only 2 ticks apart. An
    /// exhaustive pairwise check would have linked A and C.
    #[test]
    fn test_grid_requires_tof_sorted_input() {
        let mut hits = HitBatch::default();
        hits.push((10, 10, 100, 5, 0, 0)); // A
        hits.push((10, 10, 200, 5, 0, 0)); // B, far in the future
        hits.push((10, 10, 102, 5, 0, 0)); // C, close to A but stored after B

        let clusterer = two_tick_clusterer();
        let mut scratch = GridState::default();

        clusterer.cluster(&mut hits, &mut scratch).unwrap();

        assert_ne!(
            hits.cluster_id[0], hits.cluster_id[2],
            "Pruning should prevent linking unsorted hits separated by future hits"
        );
    }

    /// With sorted input, hits inside the window are linked and the first hit
    /// beyond it is not.
    #[test]
    fn test_grid_temporal_pruning() {
        let mut hits = HitBatch::default();
        hits.push((10, 10, 100, 5, 0, 0));
        hits.push((10, 10, 101, 5, 0, 0)); // 1 tick later: inside the 2-tick window
        hits.push((10, 10, 200, 5, 0, 0)); // 100 ticks later: outside the window

        let clusterer = two_tick_clusterer();
        let mut scratch = GridState::default();

        clusterer.cluster(&mut hits, &mut scratch).unwrap();

        assert_eq!(hits.cluster_id[0], hits.cluster_id[1]);
        assert_ne!(hits.cluster_id[0], hits.cluster_id[2]);
    }
}
414}