Skip to main content

rustpix_algorithms/
abs.rs

//! SoA-optimized ABS (Age-Based Spatial) clustering.
2
3use rustpix_core::clustering::ClusteringError;
4use rustpix_core::soa::HitBatch;
5
/// Tunable parameters for ABS (Age-Based Spatial) clustering.
#[derive(Debug, Clone)]
pub struct AbsConfig {
    /// Spatial radius, in pixels, used when looking for neighboring hits.
    pub radius: f64,
    /// Width of the temporal correlation window, in nanoseconds.
    pub neutron_correlation_window_ns: f64,
    /// Smallest cluster size that is kept.
    pub min_cluster_size: u16,
    /// Number of hits processed between two aging scans.
    pub scan_interval: usize,
}

impl Default for AbsConfig {
    /// Defaults: radius of 5 pixels, 75 ns correlation window, minimum cluster
    /// size of 1, and an aging scan every 100 hits.
    fn default() -> Self {
        let radius = 5.0;
        let neutron_correlation_window_ns = 75.0;
        let min_cluster_size = 1;
        let scan_interval = 100;
        AbsConfig {
            radius,
            neutron_correlation_window_ns,
            min_cluster_size,
            scan_interval,
        }
    }
}
29
/// Bookkeeping record for one cluster while it is being assembled.
///
/// Holds the bounding box of the hits added so far, the TOF of the hit that
/// opened the bucket, the cluster id, and the coordinates of that first hit.
struct Bucket {
    /// Smallest x coordinate seen so far.
    x_min: u16,
    /// Largest x coordinate seen so far.
    x_max: u16,
    /// Smallest y coordinate seen so far.
    y_min: u16,
    /// Largest y coordinate seen so far.
    y_max: u16,
    /// TOF of the hit that initialized the bucket.
    start_tof: u32,
    /// Cluster id assigned to this bucket (`-1` until initialized).
    cluster_id: i32,
    /// Whether the bucket is currently in use.
    is_active: bool,
    /// x coordinate of the hit that initialized the bucket.
    insertion_x: u16,
    /// y coordinate of the hit that initialized the bucket.
    insertion_y: u16,
}

impl Bucket {
    /// Returns an inactive placeholder bucket: inverted (empty) bounding box,
    /// cluster id `-1`, everything else zeroed.
    fn new() -> Self {
        Bucket {
            is_active: false,
            cluster_id: -1,
            start_tof: 0,
            x_min: u16::MAX,
            y_min: u16::MAX,
            x_max: 0,
            y_max: 0,
            insertion_x: 0,
            insertion_y: 0,
        }
    }

    /// Re-arms the bucket for a new cluster whose first hit is `(x, y)` at `tof`.
    fn initialize(&mut self, x: u16, y: u16, tof: u32, cluster_id: i32) {
        // Every field is overwritten, so a whole-struct assignment is equivalent
        // to setting them one by one.
        *self = Bucket {
            x_min: x,
            x_max: x,
            y_min: y,
            y_max: y,
            start_tof: tof,
            cluster_id,
            is_active: true,
            insertion_x: x,
            insertion_y: y,
        };
    }

    /// Grows the bounding box so that it contains `(x, y)`.
    fn add_hit(&mut self, x: u16, y: u16) {
        if x < self.x_min {
            self.x_min = x;
        }
        if x > self.x_max {
            self.x_max = x;
        }
        if y < self.y_min {
            self.y_min = y;
        }
        if y > self.y_max {
            self.y_max = y;
        }
    }
}
76
77/// ABS clustering implementation.
78pub struct AbsClustering {
79    config: AbsConfig,
80}
81
/// Per-call lookup parameters handed to the bucket search.
struct AbsSearchContext {
    /// Largest TOF difference (hit TOF minus bucket start TOF) that still matches.
    window_tof: u32,
    /// Edge length of one spatial-grid cell, in pixels.
    cell_size: usize,
    /// Number of cells in one grid row.
    grid_w: usize,
    /// Margin, in pixels, added around a bucket's bounding box when matching.
    radius_i32: i32,
}
88
89/// Reusable ABS clustering state for streaming or repeated runs.
90pub struct AbsState {
91    buckets: Vec<Bucket>,
92    active_indices: Vec<usize>,
93    free_indices: Vec<usize>,
94    grid: Vec<Vec<usize>>, // Spatial index
95    grid_w: usize,
96    next_cluster_id: i32,
97    cluster_sizes: Vec<u32>,
98}
99
100impl Default for AbsState {
101    fn default() -> Self {
102        Self {
103            buckets: Vec::new(),
104            active_indices: Vec::new(),
105            free_indices: Vec::new(),
106            grid: vec![Vec::new(); (256 / 32 + 1) * (256 / 32 + 1)], // 32 is cell size
107            grid_w: 256 / 32 + 1,
108            next_cluster_id: 0,
109            cluster_sizes: Vec::new(),
110        }
111    }
112}
113
114impl AbsClustering {
115    /// Create an ABS clustering instance with the provided configuration.
116    #[must_use]
117    pub fn new(config: AbsConfig) -> Self {
118        Self { config }
119    }
120
121    /// Cluster hits using the ABS algorithm.
122    ///
123    /// # Errors
124    /// Returns an error if internal state limits are exceeded.
125    pub fn cluster(
126        &self,
127        batch: &mut HitBatch,
128        state: &mut AbsState,
129    ) -> Result<usize, ClusteringError> {
130        if batch.is_empty() {
131            return Ok(0);
132        }
133
134        // Initialize state if needed (or assume persistent state for streaming?)
135        // If streaming, we keep state.
136        // But users might want to cluster a single batch.
137        // Let's assume persistent state passed in `state`.
138        // We only reset `cluster_id` in batch.
139
140        let n = batch.len();
141        // Since batch.cluster_id stores per-hit result, we write to it eventually.
142        // ABS writes cluster ID when assigning hits to buckets.
143        batch.cluster_id.fill(-1);
144        state.cluster_sizes.clear();
145        state.next_cluster_id = 0;
146
147        let window_tof = self.window_tof();
148        let cell_size = 32;
149
150        let grid_w = Self::resize_grid(batch, state, cell_size);
151        let radius_i32 = self.radius_as_i32();
152        let search_ctx = AbsSearchContext {
153            window_tof,
154            cell_size,
155            grid_w,
156            radius_i32,
157        };
158
159        for i in 0..n {
160            let x = batch.x[i];
161            let y = batch.y[i];
162            let tof = batch.tof[i];
163
164            // Aging
165            if i % self.config.scan_interval == 0 && i > 0 {
166                Self::scan_and_close(tof, state, window_tof, cell_size, grid_w);
167            }
168
169            let found = Self::find_bucket_for_hit(x, y, tof, state, &search_ctx);
170
171            if let Some(bidx) = found {
172                let cid = state.buckets[bidx].cluster_id;
173                if let Ok(idx) = usize::try_from(cid) {
174                    if let Some(size) = state.cluster_sizes.get_mut(idx) {
175                        *size += 1;
176                    }
177                }
178                batch.cluster_id[i] = cid;
179                state.buckets[bidx].add_hit(x, y);
180            } else {
181                let bidx = Self::get_bucket(state)?;
182                let cid = Self::new_cluster_id(state)?;
183                state.buckets[bidx].initialize(x, y, tof, cid);
184                if let Ok(idx) = usize::try_from(cid) {
185                    if let Some(size) = state.cluster_sizes.get_mut(idx) {
186                        *size += 1;
187                    }
188                }
189                batch.cluster_id[i] = cid;
190                state.active_indices.push(bidx);
191
192                // Insert into grid
193                let cell_col = usize::from(x) / cell_size;
194                let cell_row = usize::from(y) / cell_size;
195                let gidx = cell_row * grid_w + cell_col;
196                if gidx < state.grid.len() {
197                    state.grid[gidx].push(bidx);
198                }
199            }
200        }
201
202        // Final cleanup?
203        // If this is streaming, we DON'T close active buckets at end of batch unless strictly required.
204        // But user expects clustering to finish for batch?
205        // If we keep state, we might return partial hits?
206        // The `cluster` function usually assumes a closed batch.
207        // If streaming, we should probably close everything to yield results,
208        // OR we yield only closed clusters?
209        // The `HitBatch` needs to be fully labeled if we want to extract from it.
210        // If we leave buckets open, those hits have cluster_id = -1.
211
212        // For now, force close everything at end of batch to match existing behavior on distinct files.
213        // Stream users might want persistence.
214        // But `process_section_into_batch` creates isolated batches per chunk.
215        // If we want cross-chunk clustering, we need persistent state.
216        // I'll close all for now.
217
218        let last_tof = batch.tof.last().copied().unwrap_or(0);
219        let min_cluster_size = u32::from(self.config.min_cluster_size);
220        Ok(Self::finish_batch(
221            batch,
222            state,
223            window_tof,
224            cell_size,
225            grid_w,
226            last_tof,
227            min_cluster_size,
228        ))
229    }
230
231    fn window_tof(&self) -> u32 {
232        let window = (self.config.neutron_correlation_window_ns / 25.0).ceil();
233        if window <= 0.0 {
234            return 0;
235        }
236        if window >= f64::from(u32::MAX) {
237            return u32::MAX;
238        }
239        format!("{window:.0}").parse::<u32>().unwrap_or(u32::MAX)
240    }
241
242    fn radius_as_i32(&self) -> i32 {
243        let radius = self.config.radius.ceil();
244        if radius <= 0.0 {
245            return 0;
246        }
247        if radius >= f64::from(i32::MAX) {
248            return i32::MAX;
249        }
250        format!("{radius:.0}").parse::<i32>().unwrap_or(i32::MAX)
251    }
252
253    fn resize_grid(batch: &HitBatch, state: &mut AbsState, cell_size: usize) -> usize {
254        let mut max_x = 0usize;
255        let mut max_y = 0usize;
256        for i in 0..batch.len() {
257            let x = usize::from(batch.x[i]);
258            let y = usize::from(batch.y[i]);
259            if x > max_x {
260                max_x = x;
261            }
262            if y > max_y {
263                max_y = y;
264            }
265        }
266
267        let req_w = max_x + 32;
268        let req_h = max_y + 32;
269        let req_grid_w = req_w / cell_size + 1;
270        let req_grid_h = req_h / cell_size + 1;
271        let req_total = req_grid_w * req_grid_h;
272
273        if req_total > state.grid.len() || req_grid_w > state.grid_w {
274            state.grid = vec![Vec::new(); req_total];
275            state.grid_w = req_grid_w;
276        }
277
278        state.grid_w
279    }
280
281    fn finalize_clusters(
282        batch: &mut HitBatch,
283        state: &mut AbsState,
284        min_cluster_size: u32,
285    ) -> usize {
286        let mut remap = vec![-1i32; state.cluster_sizes.len()];
287        let mut next = 0i32;
288        for (cid, &count) in state.cluster_sizes.iter().enumerate() {
289            if count >= min_cluster_size {
290                remap[cid] = next;
291                next += 1;
292            }
293        }
294
295        for cid in &mut batch.cluster_id {
296            if let Ok(idx) = usize::try_from(*cid) {
297                if let Some(&new_id) = remap.get(idx) {
298                    *cid = new_id;
299                }
300            }
301        }
302
303        usize::try_from(next).unwrap_or(0)
304    }
305
306    fn finish_batch(
307        batch: &mut HitBatch,
308        state: &mut AbsState,
309        window_tof: u32,
310        cell_size: usize,
311        grid_w: usize,
312        last_tof: u32,
313        min_cluster_size: u32,
314    ) -> usize {
315        Self::scan_and_close(
316            last_tof.wrapping_add(window_tof + 1),
317            state,
318            window_tof,
319            cell_size,
320            grid_w,
321        );
322
323        // Force close remaining active
324        Self::close_active_buckets(state, cell_size, grid_w);
325        Self::finalize_clusters(batch, state, min_cluster_size)
326    }
327
328    fn find_bucket_for_hit(
329        x: u16,
330        y: u16,
331        tof: u32,
332        state: &AbsState,
333        ctx: &AbsSearchContext,
334    ) -> Option<usize> {
335        let cell_col = usize::from(x) / ctx.cell_size;
336        let cell_row = usize::from(y) / ctx.cell_size;
337        let cell_col_i32 = i32::try_from(cell_col).unwrap_or(i32::MAX);
338        let cell_row_i32 = i32::try_from(cell_row).unwrap_or(i32::MAX);
339        let ix = i32::from(x);
340        let iy = i32::from(y);
341
342        for dy in -1..=1 {
343            for dx in -1..=1 {
344                let ncx = cell_col_i32 + dx;
345                let ncy = cell_row_i32 + dy;
346                if ncx < 0 || ncy < 0 {
347                    continue;
348                }
349                let (Ok(neighbor_x), Ok(neighbor_y)) = (usize::try_from(ncx), usize::try_from(ncy))
350                else {
351                    continue;
352                };
353                let gidx = neighbor_y * ctx.grid_w + neighbor_x;
354                if let Some(cell) = state.grid.get(gidx) {
355                    for &bidx in cell {
356                        let bucket = &state.buckets[bidx];
357                        if bucket.is_active {
358                            let x_min_bound = i32::from(bucket.x_min) - ctx.radius_i32;
359                            let x_max_bound = i32::from(bucket.x_max) + ctx.radius_i32;
360                            let y_min_bound = i32::from(bucket.y_min) - ctx.radius_i32;
361                            let y_max_bound = i32::from(bucket.y_max) + ctx.radius_i32;
362
363                            if ix >= x_min_bound
364                                && ix <= x_max_bound
365                                && iy >= y_min_bound
366                                && iy <= y_max_bound
367                            {
368                                let dt = tof.wrapping_sub(bucket.start_tof);
369                                if dt <= ctx.window_tof {
370                                    return Some(bidx);
371                                }
372                            }
373                        }
374                    }
375                }
376            }
377        }
378        None
379    }
380
381    fn close_active_buckets(state: &mut AbsState, cell_size: usize, grid_w: usize) {
382        let active = std::mem::take(&mut state.active_indices);
383        for bidx in active {
384            state.buckets[bidx].is_active = false;
385            state.free_indices.push(bidx);
386            let b = &state.buckets[bidx];
387            let gx = usize::from(b.insertion_x) / cell_size;
388            let gy = usize::from(b.insertion_y) / cell_size;
389            let gidx = gy * grid_w + gx;
390            if let Some(cell) = state.grid.get_mut(gidx) {
391                if let Some(pos) = cell.iter().position(|&x| x == bidx) {
392                    cell.swap_remove(pos);
393                }
394            }
395        }
396    }
397
398    fn get_bucket(state: &mut AbsState) -> Result<usize, ClusteringError> {
399        if let Some(idx) = state.free_indices.pop() {
400            Ok(idx)
401        } else {
402            if state.buckets.len() >= 1_000_000 {
403                return Err(ClusteringError::StateError(
404                    "bucket pool size exceeds limit (1,000,000)".to_string(),
405                ));
406            }
407            let idx = state.buckets.len();
408            state.buckets.push(Bucket::new());
409            Ok(idx)
410        }
411    }
412
413    fn new_cluster_id(state: &mut AbsState) -> Result<i32, ClusteringError> {
414        if state.next_cluster_id == i32::MAX {
415            return Err(ClusteringError::StateError(
416                "cluster id overflow".to_string(),
417            ));
418        }
419        let cid = state.next_cluster_id;
420        state.next_cluster_id += 1;
421        state.cluster_sizes.push(0);
422        Ok(cid)
423    }
424
425    fn scan_and_close(
426        ref_tof: u32,
427        state: &mut AbsState,
428        window_tof: u32,
429        cell_size: usize,
430        grid_w: usize,
431    ) {
432        let mut keep = Vec::new();
433        let mut remove = Vec::new();
434
435        for &bidx in &state.active_indices {
436            let bucket = &state.buckets[bidx];
437            let dt = ref_tof.wrapping_sub(bucket.start_tof);
438            if dt > window_tof {
439                remove.push(bidx);
440            } else {
441                keep.push(bidx);
442            }
443        }
444        state.active_indices = keep;
445
446        for bidx in remove {
447            // Remove from grid
448            let b = &state.buckets[bidx];
449            let gx = usize::from(b.insertion_x) / cell_size;
450            let gy = usize::from(b.insertion_y) / cell_size;
451            let gidx = gy * grid_w + gx;
452            if let Some(cell) = state.grid.get_mut(gidx) {
453                if let Some(pos) = cell.iter().position(|&x| x == bidx) {
454                    cell.swap_remove(pos);
455                }
456            }
457
458            state.buckets[bidx].is_active = false;
459            state.free_indices.push(bidx);
460        }
461    }
462}