1use crate::SpatialGrid;
6use rustpix_core::clustering::ClusteringError;
7use rustpix_core::soa::HitBatch;
8
/// Tuning parameters for [`GridClustering`].
#[derive(Clone, Debug, PartialEq)]
pub struct GridConfig {
    /// Maximum Euclidean distance between two hits, in the same units as the
    /// hit `x`/`y` coordinates, for them to be linked into one cluster.
    pub radius: f64,
    /// Maximum time-of-flight difference, in nanoseconds, for two hits to be
    /// linked. It is converted to `tof` ticks before clustering.
    pub temporal_window_ns: f64,
    /// Clusters with fewer hits than this are labelled `-1`.
    pub min_cluster_size: u16,
    /// Optional upper bound on cluster size.
    ///
    /// NOTE(review): `GridClustering::cluster` and its helpers in this module
    /// never read this field, so it currently has no effect on clustering.
    pub max_cluster_size: Option<usize>,
    /// Edge length of one spatial-grid cell, in hit-coordinate units.
    pub cell_size: usize,
}
23
24impl Default for GridConfig {
25 fn default() -> Self {
26 Self {
27 radius: 5.0,
28 temporal_window_ns: 75.0,
29 min_cluster_size: 1,
30 max_cluster_size: None,
31 cell_size: 32,
32 }
33 }
34}
35
36#[derive(Default)]
37pub struct GridState {
39 pub hits_processed: usize,
41 pub clusters_found: usize,
43 grid: Option<SpatialGrid<usize>>,
44 parent: Vec<usize>,
45 rank: Vec<usize>,
46 roots: Vec<usize>,
47 cluster_sizes: Vec<usize>,
48 root_to_label: Vec<i32>,
49}
50
51pub struct GridClustering {
53 config: GridConfig,
54}
55
/// Per-call constants used while linking neighbouring hits.
struct GridUnionContext {
    /// Grid cell edge length, as a signed coordinate offset.
    cell_size: i32,
    /// Temporal linking window expressed in `tof` ticks.
    window_tof: u32,
    /// Squared linking radius, so distances are compared without a `sqrt`.
    radius_sq: f64,
}
61
62impl GridClustering {
63 #[must_use]
65 pub fn new(config: GridConfig) -> Self {
66 Self { config }
67 }
68
69 pub fn cluster(
76 &self,
77 batch: &mut HitBatch,
78 state: &mut GridState,
79 ) -> Result<usize, ClusteringError> {
80 if batch.is_empty() {
81 return Ok(0);
82 }
83
84 let n = batch.len();
85 let GridState {
86 hits_processed,
87 clusters_found,
88 grid,
89 parent,
90 rank,
91 roots,
92 cluster_sizes,
93 root_to_label,
94 } = state;
95
96 *hits_processed = 0;
97 *clusters_found = 0;
98 batch.cluster_id.fill(-1);
99
100 let (width, height) = Self::batch_dimensions(batch);
101 Self::init_union_find(parent, rank, roots, cluster_sizes, root_to_label, n);
102
103 let grid = Self::prepare_grid(grid, self.config.cell_size, width, height);
104 Self::fill_grid(grid, batch);
105
106 let union_ctx = GridUnionContext {
107 radius_sq: self.config.radius * self.config.radius,
108 window_tof: float_to_u32((self.config.temporal_window_ns / 25.0).ceil()),
109 cell_size: i32::try_from(self.config.cell_size).unwrap_or(i32::MAX),
110 };
111
112 Self::union_hits(batch, grid, parent, rank, n, &union_ctx);
113
114 let clusters = Self::assign_labels(
115 batch,
116 parent,
117 roots,
118 cluster_sizes,
119 root_to_label,
120 n,
121 usize::from(self.config.min_cluster_size),
122 );
123
124 *hits_processed = n;
125 *clusters_found = clusters;
126 Ok(clusters)
127 }
128}
129
/// Returns the representative (root) of the set containing `i`.
///
/// Every node on the path from `i` to the root is then re-pointed directly at
/// the root (full path compression), so later lookups are cheaper.
///
/// # Panics
///
/// Panics if `i`, or any parent pointer reached from it, is out of bounds.
fn find(parent: &mut [usize], i: usize) -> usize {
    // First pass: walk parent pointers until reaching a self-parented node.
    let root = std::iter::successors(Some(i), |&node| {
        let up = parent[node];
        if up == node {
            None
        } else {
            Some(up)
        }
    })
    .last()
    .unwrap_or(i);

    // Second pass: point every node on the walked path straight at the root.
    let mut node = i;
    while node != root {
        node = std::mem::replace(&mut parent[node], root);
    }
    root
}
143
144fn union_sets(parent: &mut [usize], rank: &mut [usize], i: usize, j: usize) {
145 let root_i = find(parent, i);
146 let root_j = find(parent, j);
147 if root_i != root_j {
148 if rank[root_i] < rank[root_j] {
149 parent[root_i] = root_j;
150 } else {
151 parent[root_j] = root_i;
152 if rank[root_i] == rank[root_j] {
153 rank[root_i] += 1;
154 }
155 }
156 }
157}
158
159impl GridClustering {
160 fn batch_dimensions(batch: &HitBatch) -> (usize, usize) {
161 let mut max_x = 0usize;
162 let mut max_y = 0usize;
163 for i in 0..batch.len() {
164 let x = usize::from(batch.x[i]);
165 let y = usize::from(batch.y[i]);
166 if x > max_x {
167 max_x = x;
168 }
169 if y > max_y {
170 max_y = y;
171 }
172 }
173 (max_x + 32, max_y + 32)
174 }
175
176 fn prepare_grid(
177 grid_slot: &mut Option<SpatialGrid<usize>>,
178 cell_size: usize,
179 width: usize,
180 height: usize,
181 ) -> &mut SpatialGrid<usize> {
182 let grid = grid_slot.get_or_insert_with(|| SpatialGrid::new(cell_size, width, height));
183 if grid.cell_size() == cell_size {
184 grid.ensure_dimensions(width, height);
185 grid.clear();
186 } else {
187 *grid = SpatialGrid::new(cell_size, width, height);
188 }
189 grid
190 }
191
192 fn fill_grid(grid: &mut SpatialGrid<usize>, batch: &HitBatch) {
193 for i in 0..batch.len() {
194 grid.insert(i32::from(batch.x[i]), i32::from(batch.y[i]), i);
195 }
196 }
197
198 fn init_union_find(
199 parent: &mut Vec<usize>,
200 rank: &mut Vec<usize>,
201 roots: &mut Vec<usize>,
202 cluster_sizes: &mut Vec<usize>,
203 root_to_label: &mut Vec<i32>,
204 n: usize,
205 ) {
206 if parent.len() < n {
207 parent.resize(n, 0);
208 }
209 if rank.len() < n {
210 rank.resize(n, 0);
211 }
212 if roots.len() < n {
213 roots.resize(n, 0);
214 }
215 if cluster_sizes.len() < n {
216 cluster_sizes.resize(n, 0);
217 }
218 if root_to_label.len() < n {
219 root_to_label.resize(n, -1);
220 }
221 for i in 0..n {
222 parent[i] = i;
223 rank[i] = 0;
224 }
225 }
226
227 fn union_hits(
228 batch: &HitBatch,
229 grid: &SpatialGrid<usize>,
230 parent: &mut [usize],
231 rank: &mut [usize],
232 n: usize,
233 ctx: &GridUnionContext,
234 ) {
235 for i in 0..n {
236 let x = i32::from(batch.x[i]);
237 let y = i32::from(batch.y[i]);
238
239 for dy in -1..=1 {
240 for dx in -1..=1 {
241 let px = x + dx * ctx.cell_size;
242 let py = y + dy * ctx.cell_size;
243
244 if let Some(cell) = grid.get_cell_slice(px, py) {
245 let start = cell.partition_point(|&idx| idx <= i);
246
247 for &j in &cell[start..] {
248 let dt = batch.tof[j].wrapping_sub(batch.tof[i]);
249 if dt > ctx.window_tof {
250 break;
251 }
252
253 let dx = f64::from(batch.x[i]) - f64::from(batch.x[j]);
254 let dy = f64::from(batch.y[i]) - f64::from(batch.y[j]);
255 let dist_sq = dx * dx + dy * dy;
256
257 if dist_sq <= ctx.radius_sq {
258 union_sets(parent, rank, i, j);
259 }
260 }
261 }
262 }
263 }
264 }
265 }
266
267 fn assign_labels(
268 batch: &mut HitBatch,
269 parent: &mut [usize],
270 roots: &mut [usize],
271 cluster_sizes: &mut [usize],
272 root_to_label: &mut [i32],
273 n: usize,
274 min_cluster_size: usize,
275 ) -> usize {
276 cluster_sizes[..n].fill(0);
277 for (i, root_slot) in roots.iter_mut().enumerate().take(n) {
278 let root = find(parent, i);
279 *root_slot = root;
280 cluster_sizes[root] += 1;
281 }
282
283 root_to_label[..n].fill(-1);
284 let mut next_label = 0;
285
286 for (i, &root) in roots.iter().enumerate().take(n) {
287 let size = cluster_sizes[root];
288
289 if size < min_cluster_size {
290 batch.cluster_id[i] = -1;
291 } else {
292 let label_slot = &mut root_to_label[root];
293 if *label_slot < 0 {
294 *label_slot = next_label;
295 next_label += 1;
296 }
297 batch.cluster_id[i] = *label_slot;
298 }
299 }
300
301 usize::try_from(next_label).unwrap_or(0)
302 }
303}
304
/// Converts `value` to the nearest `u32`, saturating at the type's bounds.
///
/// - Non-positive inputs (including negative infinity) map to `0`.
/// - Inputs at or above `u32::MAX` (including positive infinity) map to `u32::MAX`.
/// - NaN maps to `u32::MAX`.
/// - Other values are rounded to the nearest integer, with ties rounded away
///   from zero.
// The casts below are lossless: the range checks guarantee `0 < value < u32::MAX`.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn float_to_u32(value: f64) -> u32 {
    if value.is_nan() {
        return u32::MAX;
    }
    if value <= 0.0 {
        return 0;
    }
    if value >= f64::from(u32::MAX) {
        return u32::MAX;
    }
    value.round() as u32
}
314
315impl Default for GridClustering {
316 fn default() -> Self {
317 Self::new(GridConfig::default())
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use rustpix_core::soa::HitBatch;

    #[test]
    fn test_soa_clustering() {
        let mut batch = HitBatch::default();
        batch.push((10, 10, 100, 5, 0, 0));
        batch.push((11, 11, 102, 5, 0, 0));
        batch.push((50, 50, 100, 5, 0, 0));
        batch.push((100, 100, 10000, 5, 0, 0));
        let algo = GridClustering::default();
        let mut state = GridState::default();

        let count = algo.cluster(&mut batch, &mut state).unwrap();

        // Hits 0 and 1 are close in space and time; hits 2 and 3 are isolated.
        assert_eq!(count, 3);
        assert_eq!(batch.cluster_id[0], batch.cluster_id[1]);
        assert_ne!(batch.cluster_id[0], batch.cluster_id[2]);
    }

    #[test]
    fn test_grid_requires_tof_sorted_input() {
        let mut batch = HitBatch::default();
        batch.push((10, 10, 100, 5, 0, 0));
        batch.push((10, 10, 200, 5, 0, 0));
        batch.push((10, 10, 102, 5, 0, 0));
        let config = GridConfig {
            temporal_window_ns: 50.0,
            ..Default::default()
        };
        let algo = GridClustering::new(config);
        let mut state = GridState::default();

        algo.cluster(&mut batch, &mut state).unwrap();

        // Hit 2 is within hit 0's window, but the scan from hit 0 stops at
        // hit 1 (outside the window), so the unsorted pair is never compared.
        assert_ne!(
            batch.cluster_id[0], batch.cluster_id[2],
            "Pruning should prevent linking unsorted hits separated by future hits"
        );
    }

    #[test]
    fn test_grid_temporal_pruning() {
        let mut batch = HitBatch::default();
        batch.push((10, 10, 100, 5, 0, 0));
        batch.push((10, 10, 101, 5, 0, 0));
        batch.push((10, 10, 200, 5, 0, 0));
        let config = GridConfig {
            temporal_window_ns: 50.0,
            ..Default::default()
        };
        let algo = GridClustering::new(config);
        let mut state = GridState::default();

        algo.cluster(&mut batch, &mut state).unwrap();

        assert_eq!(batch.cluster_id[0], batch.cluster_id[1]);
        assert_ne!(batch.cluster_id[0], batch.cluster_id[2]);
    }

    #[test]
    fn test_min_cluster_size_marks_small_clusters_as_noise() {
        let mut batch = HitBatch::default();
        batch.push((10, 10, 100, 5, 0, 0));
        batch.push((11, 11, 101, 5, 0, 0));
        batch.push((200, 200, 102, 5, 0, 0));
        let config = GridConfig {
            min_cluster_size: 2,
            ..Default::default()
        };
        let algo = GridClustering::new(config);
        let mut state = GridState::default();

        let count = algo.cluster(&mut batch, &mut state).unwrap();

        // The isolated hit forms a singleton below the minimum and is unlabelled.
        assert_eq!(count, 1);
        assert_eq!(batch.cluster_id[0], 0);
        assert_eq!(batch.cluster_id[1], 0);
        assert_eq!(batch.cluster_id[2], -1);
        assert_eq!(state.hits_processed, 3);
        assert_eq!(state.clusters_found, 1);
    }

    #[test]
    fn test_state_is_reusable_across_batches() {
        let algo = GridClustering::default();
        let mut state = GridState::default();

        let mut first = HitBatch::default();
        first.push((10, 10, 100, 5, 0, 0));
        first.push((11, 11, 101, 5, 0, 0));
        first.push((200, 200, 102, 5, 0, 0));
        assert_eq!(algo.cluster(&mut first, &mut state).unwrap(), 2);
        assert_eq!(state.hits_processed, 3);
        assert_eq!(state.clusters_found, 2);

        // A smaller second batch must not see leftovers from the first.
        let mut second = HitBatch::default();
        second.push((5, 5, 50, 5, 0, 0));
        assert_eq!(algo.cluster(&mut second, &mut state).unwrap(), 1);
        assert_eq!(second.cluster_id[0], 0);
        assert_eq!(state.hits_processed, 1);
        assert_eq!(state.clusters_found, 1);
    }
}