1use rustpix_core::clustering::ClusteringError;
4use rustpix_core::soa::HitBatch;
5
/// Tuning parameters for [`AbsClustering`].
#[derive(Clone, Debug)]
pub struct AbsConfig {
    /// Spatial tolerance in pixels: a hit may join a bucket when it lies within
    /// this distance of the bucket's bounding box on each axis. Rounded up to a
    /// whole number of pixels before use.
    pub radius: f64,
    /// Maximum time, in nanoseconds, between a bucket's first hit and a later
    /// hit that joins it. Converted to TOF units by dividing by 25 ns and
    /// rounding up.
    pub neutron_correlation_window_ns: f64,
    /// Clusters with fewer hits than this are relabelled `-1` once a batch has
    /// been clustered.
    pub min_cluster_size: u16,
    /// Number of hits between sweeps that close buckets whose correlation
    /// window has expired. Should be at least 1.
    pub scan_interval: usize,
}

impl Default for AbsConfig {
    /// 5-pixel radius, 75 ns correlation window, every cluster kept
    /// (`min_cluster_size = 1`), and a sweep every 100 hits.
    fn default() -> Self {
        Self {
            scan_interval: 100,
            min_cluster_size: 1,
            neutron_correlation_window_ns: 75.0,
            radius: 5.0,
        }
    }
}
29
/// Axis-aligned bounding box of the hits gathered into one open cluster.
///
/// Slots are pooled and recycled; [`Bucket::initialize`] overwrites every
/// field when a slot is reused.
struct Bucket {
    /// Smallest x coordinate absorbed so far.
    x_min: u16,
    /// Largest x coordinate absorbed so far.
    x_max: u16,
    /// Smallest y coordinate absorbed so far.
    y_min: u16,
    /// Largest y coordinate absorbed so far.
    y_max: u16,
    /// TOF of the hit that opened the bucket; the correlation window is
    /// measured from this value.
    start_tof: u32,
    /// Cluster id shared by every hit in the bucket (`-1` before first use).
    cluster_id: i32,
    /// Whether the bucket may still accept hits.
    is_active: bool,
    /// x of the opening hit; selects the grid cell the bucket is registered in.
    insertion_x: u16,
    /// y of the opening hit; selects the grid cell the bucket is registered in.
    insertion_y: u16,
}

impl Bucket {
    /// Returns an unused, inactive bucket with inverted bounds
    /// (`min = u16::MAX`, `max = 0`).
    fn new() -> Self {
        Self {
            x_min: u16::MAX,
            x_max: u16::MIN,
            y_min: u16::MAX,
            y_max: u16::MIN,
            start_tof: 0,
            cluster_id: -1,
            is_active: false,
            insertion_x: 0,
            insertion_y: 0,
        }
    }

    /// Reopens the bucket around a single hit at `(x, y)` with TOF `tof`,
    /// labelled `cluster_id`.
    fn initialize(&mut self, x: u16, y: u16, tof: u32, cluster_id: i32) {
        *self = Self {
            x_min: x,
            x_max: x,
            y_min: y,
            y_max: y,
            start_tof: tof,
            cluster_id,
            is_active: true,
            insertion_x: x,
            insertion_y: y,
        };
    }

    /// Grows the bounding box so that it contains `(x, y)`.
    fn add_hit(&mut self, x: u16, y: u16) {
        (self.x_min, self.x_max) = (self.x_min.min(x), self.x_max.max(x));
        (self.y_min, self.y_max) = (self.y_min.min(y), self.y_max.max(y));
    }
}
76
77pub struct AbsClustering {
79 config: AbsConfig,
80}
81
/// Per-call constants needed by the bucket search, bundled to keep helper
/// signatures short.
struct AbsSearchContext {
    /// Maximum TOF distance from a bucket's start for a hit to join it.
    window_tof: u32,
    /// Edge length, in pixels, of one spatial-index cell.
    cell_size: usize,
    /// Number of cells per row of the spatial index.
    grid_w: usize,
    /// Search radius in whole pixels (the configured radius rounded up).
    radius_i32: i32,
}
88
89pub struct AbsState {
91 buckets: Vec<Bucket>,
92 active_indices: Vec<usize>,
93 free_indices: Vec<usize>,
94 grid: Vec<Vec<usize>>, grid_w: usize,
96 next_cluster_id: i32,
97 cluster_sizes: Vec<u32>,
98}
99
100impl Default for AbsState {
101 fn default() -> Self {
102 Self {
103 buckets: Vec::new(),
104 active_indices: Vec::new(),
105 free_indices: Vec::new(),
106 grid: vec![Vec::new(); (256 / 32 + 1) * (256 / 32 + 1)], grid_w: 256 / 32 + 1,
108 next_cluster_id: 0,
109 cluster_sizes: Vec::new(),
110 }
111 }
112}
113
114impl AbsClustering {
115 #[must_use]
117 pub fn new(config: AbsConfig) -> Self {
118 Self { config }
119 }
120
121 pub fn cluster(
126 &self,
127 batch: &mut HitBatch,
128 state: &mut AbsState,
129 ) -> Result<usize, ClusteringError> {
130 if batch.is_empty() {
131 return Ok(0);
132 }
133
134 let n = batch.len();
141 batch.cluster_id.fill(-1);
144 state.cluster_sizes.clear();
145 state.next_cluster_id = 0;
146
147 let window_tof = self.window_tof();
148 let cell_size = 32;
149
150 let grid_w = Self::resize_grid(batch, state, cell_size);
151 let radius_i32 = self.radius_as_i32();
152 let search_ctx = AbsSearchContext {
153 window_tof,
154 cell_size,
155 grid_w,
156 radius_i32,
157 };
158
159 for i in 0..n {
160 let x = batch.x[i];
161 let y = batch.y[i];
162 let tof = batch.tof[i];
163
164 if i % self.config.scan_interval == 0 && i > 0 {
166 Self::scan_and_close(tof, state, window_tof, cell_size, grid_w);
167 }
168
169 let found = Self::find_bucket_for_hit(x, y, tof, state, &search_ctx);
170
171 if let Some(bidx) = found {
172 let cid = state.buckets[bidx].cluster_id;
173 if let Ok(idx) = usize::try_from(cid) {
174 if let Some(size) = state.cluster_sizes.get_mut(idx) {
175 *size += 1;
176 }
177 }
178 batch.cluster_id[i] = cid;
179 state.buckets[bidx].add_hit(x, y);
180 } else {
181 let bidx = Self::get_bucket(state)?;
182 let cid = Self::new_cluster_id(state)?;
183 state.buckets[bidx].initialize(x, y, tof, cid);
184 if let Ok(idx) = usize::try_from(cid) {
185 if let Some(size) = state.cluster_sizes.get_mut(idx) {
186 *size += 1;
187 }
188 }
189 batch.cluster_id[i] = cid;
190 state.active_indices.push(bidx);
191
192 let cell_col = usize::from(x) / cell_size;
194 let cell_row = usize::from(y) / cell_size;
195 let gidx = cell_row * grid_w + cell_col;
196 if gidx < state.grid.len() {
197 state.grid[gidx].push(bidx);
198 }
199 }
200 }
201
202 let last_tof = batch.tof.last().copied().unwrap_or(0);
219 let min_cluster_size = u32::from(self.config.min_cluster_size);
220 Ok(Self::finish_batch(
221 batch,
222 state,
223 window_tof,
224 cell_size,
225 grid_w,
226 last_tof,
227 min_cluster_size,
228 ))
229 }
230
231 fn window_tof(&self) -> u32 {
232 let window = (self.config.neutron_correlation_window_ns / 25.0).ceil();
233 if window <= 0.0 {
234 return 0;
235 }
236 if window >= f64::from(u32::MAX) {
237 return u32::MAX;
238 }
239 format!("{window:.0}").parse::<u32>().unwrap_or(u32::MAX)
240 }
241
242 fn radius_as_i32(&self) -> i32 {
243 let radius = self.config.radius.ceil();
244 if radius <= 0.0 {
245 return 0;
246 }
247 if radius >= f64::from(i32::MAX) {
248 return i32::MAX;
249 }
250 format!("{radius:.0}").parse::<i32>().unwrap_or(i32::MAX)
251 }
252
253 fn resize_grid(batch: &HitBatch, state: &mut AbsState, cell_size: usize) -> usize {
254 let mut max_x = 0usize;
255 let mut max_y = 0usize;
256 for i in 0..batch.len() {
257 let x = usize::from(batch.x[i]);
258 let y = usize::from(batch.y[i]);
259 if x > max_x {
260 max_x = x;
261 }
262 if y > max_y {
263 max_y = y;
264 }
265 }
266
267 let req_w = max_x + 32;
268 let req_h = max_y + 32;
269 let req_grid_w = req_w / cell_size + 1;
270 let req_grid_h = req_h / cell_size + 1;
271 let req_total = req_grid_w * req_grid_h;
272
273 if req_total > state.grid.len() || req_grid_w > state.grid_w {
274 state.grid = vec![Vec::new(); req_total];
275 state.grid_w = req_grid_w;
276 }
277
278 state.grid_w
279 }
280
281 fn finalize_clusters(
282 batch: &mut HitBatch,
283 state: &mut AbsState,
284 min_cluster_size: u32,
285 ) -> usize {
286 let mut remap = vec![-1i32; state.cluster_sizes.len()];
287 let mut next = 0i32;
288 for (cid, &count) in state.cluster_sizes.iter().enumerate() {
289 if count >= min_cluster_size {
290 remap[cid] = next;
291 next += 1;
292 }
293 }
294
295 for cid in &mut batch.cluster_id {
296 if let Ok(idx) = usize::try_from(*cid) {
297 if let Some(&new_id) = remap.get(idx) {
298 *cid = new_id;
299 }
300 }
301 }
302
303 usize::try_from(next).unwrap_or(0)
304 }
305
306 fn finish_batch(
307 batch: &mut HitBatch,
308 state: &mut AbsState,
309 window_tof: u32,
310 cell_size: usize,
311 grid_w: usize,
312 last_tof: u32,
313 min_cluster_size: u32,
314 ) -> usize {
315 Self::scan_and_close(
316 last_tof.wrapping_add(window_tof + 1),
317 state,
318 window_tof,
319 cell_size,
320 grid_w,
321 );
322
323 Self::close_active_buckets(state, cell_size, grid_w);
325 Self::finalize_clusters(batch, state, min_cluster_size)
326 }
327
328 fn find_bucket_for_hit(
329 x: u16,
330 y: u16,
331 tof: u32,
332 state: &AbsState,
333 ctx: &AbsSearchContext,
334 ) -> Option<usize> {
335 let cell_col = usize::from(x) / ctx.cell_size;
336 let cell_row = usize::from(y) / ctx.cell_size;
337 let cell_col_i32 = i32::try_from(cell_col).unwrap_or(i32::MAX);
338 let cell_row_i32 = i32::try_from(cell_row).unwrap_or(i32::MAX);
339 let ix = i32::from(x);
340 let iy = i32::from(y);
341
342 for dy in -1..=1 {
343 for dx in -1..=1 {
344 let ncx = cell_col_i32 + dx;
345 let ncy = cell_row_i32 + dy;
346 if ncx < 0 || ncy < 0 {
347 continue;
348 }
349 let (Ok(neighbor_x), Ok(neighbor_y)) = (usize::try_from(ncx), usize::try_from(ncy))
350 else {
351 continue;
352 };
353 let gidx = neighbor_y * ctx.grid_w + neighbor_x;
354 if let Some(cell) = state.grid.get(gidx) {
355 for &bidx in cell {
356 let bucket = &state.buckets[bidx];
357 if bucket.is_active {
358 let x_min_bound = i32::from(bucket.x_min) - ctx.radius_i32;
359 let x_max_bound = i32::from(bucket.x_max) + ctx.radius_i32;
360 let y_min_bound = i32::from(bucket.y_min) - ctx.radius_i32;
361 let y_max_bound = i32::from(bucket.y_max) + ctx.radius_i32;
362
363 if ix >= x_min_bound
364 && ix <= x_max_bound
365 && iy >= y_min_bound
366 && iy <= y_max_bound
367 {
368 let dt = tof.wrapping_sub(bucket.start_tof);
369 if dt <= ctx.window_tof {
370 return Some(bidx);
371 }
372 }
373 }
374 }
375 }
376 }
377 }
378 None
379 }
380
381 fn close_active_buckets(state: &mut AbsState, cell_size: usize, grid_w: usize) {
382 let active = std::mem::take(&mut state.active_indices);
383 for bidx in active {
384 state.buckets[bidx].is_active = false;
385 state.free_indices.push(bidx);
386 let b = &state.buckets[bidx];
387 let gx = usize::from(b.insertion_x) / cell_size;
388 let gy = usize::from(b.insertion_y) / cell_size;
389 let gidx = gy * grid_w + gx;
390 if let Some(cell) = state.grid.get_mut(gidx) {
391 if let Some(pos) = cell.iter().position(|&x| x == bidx) {
392 cell.swap_remove(pos);
393 }
394 }
395 }
396 }
397
398 fn get_bucket(state: &mut AbsState) -> Result<usize, ClusteringError> {
399 if let Some(idx) = state.free_indices.pop() {
400 Ok(idx)
401 } else {
402 if state.buckets.len() >= 1_000_000 {
403 return Err(ClusteringError::StateError(
404 "bucket pool size exceeds limit (1,000,000)".to_string(),
405 ));
406 }
407 let idx = state.buckets.len();
408 state.buckets.push(Bucket::new());
409 Ok(idx)
410 }
411 }
412
413 fn new_cluster_id(state: &mut AbsState) -> Result<i32, ClusteringError> {
414 if state.next_cluster_id == i32::MAX {
415 return Err(ClusteringError::StateError(
416 "cluster id overflow".to_string(),
417 ));
418 }
419 let cid = state.next_cluster_id;
420 state.next_cluster_id += 1;
421 state.cluster_sizes.push(0);
422 Ok(cid)
423 }
424
425 fn scan_and_close(
426 ref_tof: u32,
427 state: &mut AbsState,
428 window_tof: u32,
429 cell_size: usize,
430 grid_w: usize,
431 ) {
432 let mut keep = Vec::new();
433 let mut remove = Vec::new();
434
435 for &bidx in &state.active_indices {
436 let bucket = &state.buckets[bidx];
437 let dt = ref_tof.wrapping_sub(bucket.start_tof);
438 if dt > window_tof {
439 remove.push(bidx);
440 } else {
441 keep.push(bidx);
442 }
443 }
444 state.active_indices = keep;
445
446 for bidx in remove {
447 let b = &state.buckets[bidx];
449 let gx = usize::from(b.insertion_x) / cell_size;
450 let gy = usize::from(b.insertion_y) / cell_size;
451 let gidx = gy * grid_w + gx;
452 if let Some(cell) = state.grid.get_mut(gidx) {
453 if let Some(pos) = cell.iter().position(|&x| x == bidx) {
454 cell.swap_remove(pos);
455 }
456 }
457
458 state.buckets[bidx].is_active = false;
459 state.free_indices.push(bidx);
460 }
461 }
462}