1use crate::error::ExtractionError;
5use crate::neutron::{Neutron, NeutronBatch};
6
/// Tuning parameters for turning clustered detector hits into neutron events.
///
/// The field semantics below describe how [`SimpleCentroidExtraction`] applies them;
/// other extractors may interpret them differently.
#[derive(Debug, Clone)]
pub struct ExtractionConfig {
    /// Multiplier applied to each centroid coordinate (output = pixel centroid × factor).
    pub super_resolution_factor: f64,
    /// `true` → ToT-weighted centroid; `false` → plain arithmetic mean of hit positions.
    pub weighted_by_tot: bool,
    /// Hits whose ToT is strictly below this value are ignored; `0` keeps every hit.
    pub min_tot_threshold: u16,
}

impl Default for ExtractionConfig {
    /// Scale ×8, ToT weighting on, ToT threshold 10.
    fn default() -> Self {
        Self {
            min_tot_threshold: 10,
            weighted_by_tot: true,
            super_resolution_factor: 8.0,
        }
    }
}

impl ExtractionConfig {
    /// Named preset; currently identical to [`ExtractionConfig::default`].
    #[must_use]
    pub fn venus_defaults() -> Self {
        <Self as Default>::default()
    }

    /// Returns a copy of `self` with `super_resolution_factor` replaced by `factor`.
    #[must_use]
    pub fn with_super_resolution(self, factor: f64) -> Self {
        Self {
            super_resolution_factor: factor,
            ..self
        }
    }

    /// Returns a copy of `self` with `weighted_by_tot` replaced by `weighted`.
    #[must_use]
    pub fn with_weighted_by_tot(self, weighted: bool) -> Self {
        Self {
            weighted_by_tot: weighted,
            ..self
        }
    }

    /// Returns a copy of `self` with `min_tot_threshold` replaced by `threshold`.
    #[must_use]
    pub fn with_min_tot_threshold(self, threshold: u16) -> Self {
        Self {
            min_tot_threshold: threshold,
            ..self
        }
    }
}
56
57pub trait NeutronExtraction: Send + Sync {
61 fn name(&self) -> &'static str;
63
64 fn configure(&mut self, config: ExtractionConfig);
66
67 fn config(&self) -> &ExtractionConfig;
69
70 fn extract_soa(
78 &self,
79 batch: &crate::soa::HitBatch,
80 num_clusters: usize,
81 ) -> Result<Vec<Neutron>, ExtractionError>;
82}
83
/// Running per-cluster sums gathered while scanning hits; all zero by default.
#[derive(Debug, Default, Clone)]
struct ClusterAccumulator {
    /// Σ x·ToT over accepted hits (maintained only on the weighted path).
    sum_x: f64,
    /// Σ y·ToT over accepted hits (maintained only on the weighted path).
    sum_y: f64,
    /// Σ x over accepted hits (unweighted).
    raw_sum_x: f64,
    /// Σ y over accepted hits (unweighted).
    raw_sum_y: f64,
    /// Σ ToT over accepted hits; also the weight normaliser on the weighted path.
    sum_tot: u64,
    /// Number of accepted hits; `0` means the cluster produces no neutron.
    count: u32,
    /// Largest ToT seen so far among accepted hits.
    max_tot: u16,
    /// Time-of-flight of the hit holding `max_tot` (latest hit wins ties).
    rep_tof: u32,
    /// Chip id of the hit holding `max_tot` (latest hit wins ties).
    rep_chip: u8,
}
102
103#[derive(Clone, Debug, Default)]
110pub struct SimpleCentroidExtraction {
111 config: ExtractionConfig,
112}
113
114impl SimpleCentroidExtraction {
115 #[must_use]
117 pub fn new() -> Self {
118 Self {
119 config: ExtractionConfig::default(),
120 }
121 }
122
123 #[must_use]
125 pub fn with_config(config: ExtractionConfig) -> Self {
126 Self { config }
127 }
128}
129
130impl NeutronExtraction for SimpleCentroidExtraction {
131 fn name(&self) -> &'static str {
132 "SimpleCentroid"
133 }
134
135 fn configure(&mut self, config: ExtractionConfig) {
136 self.config = config;
137 }
138
139 fn config(&self) -> &ExtractionConfig {
140 &self.config
141 }
142
143 fn extract_soa(
144 &self,
145 batch: &crate::soa::HitBatch,
146 num_clusters: usize,
147 ) -> Result<Vec<Neutron>, ExtractionError> {
148 let mut accumulators = vec![ClusterAccumulator::default(); num_clusters];
149 if self.config.weighted_by_tot {
150 accumulate_weighted(
151 &mut accumulators,
152 batch,
153 num_clusters,
154 self.config.min_tot_threshold,
155 );
156 Ok(build_neutrons_weighted(
157 accumulators,
158 self.config.super_resolution_factor,
159 ))
160 } else {
161 accumulate_unweighted(
162 &mut accumulators,
163 batch,
164 num_clusters,
165 self.config.min_tot_threshold,
166 );
167 Ok(build_neutrons_unweighted(
168 accumulators,
169 self.config.super_resolution_factor,
170 ))
171 }
172 }
173}
174
175impl SimpleCentroidExtraction {
176 pub fn extract_soa_batch(
181 &self,
182 batch: &crate::soa::HitBatch,
183 num_clusters: usize,
184 ) -> Result<NeutronBatch, ExtractionError> {
185 let mut accumulators = vec![ClusterAccumulator::default(); num_clusters];
186 if self.config.weighted_by_tot {
187 accumulate_weighted(
188 &mut accumulators,
189 batch,
190 num_clusters,
191 self.config.min_tot_threshold,
192 );
193 Ok(build_neutron_batch_weighted(
194 accumulators,
195 self.config.super_resolution_factor,
196 ))
197 } else {
198 accumulate_unweighted(
199 &mut accumulators,
200 batch,
201 num_clusters,
202 self.config.min_tot_threshold,
203 );
204 Ok(build_neutron_batch_unweighted(
205 accumulators,
206 self.config.super_resolution_factor,
207 ))
208 }
209 }
210}
211
/// Maps a raw cluster label to an accumulator index.
///
/// Returns `None` for negative labels (unclustered/noise) and for labels
/// `>= num_clusters`.
#[inline]
fn cluster_index(label: i32, num_clusters: usize) -> Option<usize> {
    // `try_from` already rejects negative labels, so no separate sign check is needed.
    usize::try_from(label)
        .ok()
        .filter(|&idx| idx < num_clusters)
}
224
225fn accumulate_weighted(
226 accumulators: &mut [ClusterAccumulator],
227 batch: &crate::soa::HitBatch,
228 num_clusters: usize,
229 min_tot: u16,
230) {
231 let labels = &batch.cluster_id;
232 let x_values = &batch.x;
233 let y_values = &batch.y;
234 let time_over_threshold = &batch.tot;
235 let time_of_flight = &batch.tof;
236 let chip_ids = &batch.chip_id;
237
238 if min_tot > 0 {
239 for i in 0..labels.len() {
240 let Some(cluster_idx) = cluster_index(labels[i], num_clusters) else {
241 continue;
242 };
243 let tot = time_over_threshold[i];
244 if tot < min_tot {
245 continue;
246 }
247
248 let acc = &mut accumulators[cluster_idx];
249 let x = f64::from(x_values[i]);
250 let y = f64::from(y_values[i]);
251 let weight = f64::from(tot);
252
253 acc.count += 1;
254 acc.sum_tot += u64::from(tot);
255 acc.raw_sum_x += x;
256 acc.raw_sum_y += y;
257 acc.sum_x += x * weight;
258 acc.sum_y += y * weight;
259
260 if tot >= acc.max_tot {
261 acc.max_tot = tot;
262 acc.rep_tof = time_of_flight[i];
263 acc.rep_chip = chip_ids[i];
264 }
265 }
266 } else {
267 for i in 0..labels.len() {
268 let Some(cluster_idx) = cluster_index(labels[i], num_clusters) else {
269 continue;
270 };
271 let tot = time_over_threshold[i];
272 let acc = &mut accumulators[cluster_idx];
273 let x = f64::from(x_values[i]);
274 let y = f64::from(y_values[i]);
275 let weight = f64::from(tot);
276
277 acc.count += 1;
278 acc.sum_tot += u64::from(tot);
279 acc.raw_sum_x += x;
280 acc.raw_sum_y += y;
281 acc.sum_x += x * weight;
282 acc.sum_y += y * weight;
283
284 if tot >= acc.max_tot {
285 acc.max_tot = tot;
286 acc.rep_tof = time_of_flight[i];
287 acc.rep_chip = chip_ids[i];
288 }
289 }
290 }
291}
292
293fn accumulate_unweighted(
294 accumulators: &mut [ClusterAccumulator],
295 batch: &crate::soa::HitBatch,
296 num_clusters: usize,
297 min_tot: u16,
298) {
299 let labels = &batch.cluster_id;
300 let x_values = &batch.x;
301 let y_values = &batch.y;
302 let time_over_threshold = &batch.tot;
303 let time_of_flight = &batch.tof;
304 let chip_ids = &batch.chip_id;
305
306 if min_tot > 0 {
307 for i in 0..labels.len() {
308 let Some(cluster_idx) = cluster_index(labels[i], num_clusters) else {
309 continue;
310 };
311 let tot = time_over_threshold[i];
312 if tot < min_tot {
313 continue;
314 }
315
316 let acc = &mut accumulators[cluster_idx];
317 let x = f64::from(x_values[i]);
318 let y = f64::from(y_values[i]);
319
320 acc.count += 1;
321 acc.sum_tot += u64::from(tot);
322 acc.raw_sum_x += x;
323 acc.raw_sum_y += y;
324
325 if tot >= acc.max_tot {
326 acc.max_tot = tot;
327 acc.rep_tof = time_of_flight[i];
328 acc.rep_chip = chip_ids[i];
329 }
330 }
331 } else {
332 for i in 0..labels.len() {
333 let Some(cluster_idx) = cluster_index(labels[i], num_clusters) else {
334 continue;
335 };
336 let tot = time_over_threshold[i];
337
338 let acc = &mut accumulators[cluster_idx];
339 let x = f64::from(x_values[i]);
340 let y = f64::from(y_values[i]);
341
342 acc.count += 1;
343 acc.sum_tot += u64::from(tot);
344 acc.raw_sum_x += x;
345 acc.raw_sum_y += y;
346
347 if tot >= acc.max_tot {
348 acc.max_tot = tot;
349 acc.rep_tof = time_of_flight[i];
350 acc.rep_chip = chip_ids[i];
351 }
352 }
353 }
354}
355
/// Converts an accumulated ToT sum to `f64` for use as the weight normaliser.
///
/// The weighted position sums are accumulated without any clamp. The divisor must
/// therefore be the true weight sum: clamping it (e.g. at `u32::MAX`) would push the
/// centroid outside the cluster once the sum exceeds that bound. The value is split
/// into two 32-bit halves, each converted losslessly with `f64::from`. The
/// power-of-two scaling is exact, and the final addition rounds only when
/// `sum_tot > 2^53`, matching a correctly-rounded integer→float conversion.
fn sum_tot_as_f64(sum_tot: u64) -> f64 {
    const TWO_POW_32: f64 = 4_294_967_296.0;
    // Both halves fit in u32 by construction; `unwrap_or` is never taken.
    let high = u32::try_from(sum_tot >> 32).unwrap_or(u32::MAX);
    let low = u32::try_from(sum_tot & u64::from(u32::MAX)).unwrap_or(u32::MAX);
    f64::from(high) * TWO_POW_32 + f64::from(low)
}
360
361fn build_neutrons_weighted(accumulators: Vec<ClusterAccumulator>, scale: f64) -> Vec<Neutron> {
362 let mut neutrons = Vec::with_capacity(accumulators.len());
363 for acc in accumulators {
364 if acc.count == 0 {
365 continue;
366 }
367
368 let (centroid_x, centroid_y) = if acc.sum_tot > 0 {
369 let sum_weight = sum_tot_as_f64(acc.sum_tot);
370 (acc.sum_x / sum_weight, acc.sum_y / sum_weight)
371 } else {
372 (
373 acc.raw_sum_x / f64::from(acc.count),
374 acc.raw_sum_y / f64::from(acc.count),
375 )
376 };
377
378 let scaled_x = centroid_x * scale;
379 let scaled_y = centroid_y * scale;
380
381 neutrons.push(Neutron::new(
382 scaled_x,
383 scaled_y,
384 acc.rep_tof,
385 u16::try_from(acc.sum_tot.min(u64::from(u16::MAX))).unwrap_or(u16::MAX),
386 u16::try_from(acc.count).unwrap_or(u16::MAX),
387 acc.rep_chip,
388 ));
389 }
390 neutrons
391}
392
393fn build_neutrons_unweighted(accumulators: Vec<ClusterAccumulator>, scale: f64) -> Vec<Neutron> {
394 let mut neutrons = Vec::with_capacity(accumulators.len());
395 for acc in accumulators {
396 if acc.count == 0 {
397 continue;
398 }
399
400 let centroid_x = acc.raw_sum_x / f64::from(acc.count);
401 let centroid_y = acc.raw_sum_y / f64::from(acc.count);
402
403 let scaled_x = centroid_x * scale;
404 let scaled_y = centroid_y * scale;
405
406 neutrons.push(Neutron::new(
407 scaled_x,
408 scaled_y,
409 acc.rep_tof,
410 u16::try_from(acc.sum_tot.min(u64::from(u16::MAX))).unwrap_or(u16::MAX),
411 u16::try_from(acc.count).unwrap_or(u16::MAX),
412 acc.rep_chip,
413 ));
414 }
415 neutrons
416}
417
418fn build_neutron_batch_weighted(accumulators: Vec<ClusterAccumulator>, scale: f64) -> NeutronBatch {
419 let mut batch = NeutronBatch::with_capacity(accumulators.len());
420 for acc in accumulators {
421 if acc.count == 0 {
422 continue;
423 }
424
425 let (centroid_x, centroid_y) = if acc.sum_tot > 0 {
426 let sum_weight = sum_tot_as_f64(acc.sum_tot);
427 (acc.sum_x / sum_weight, acc.sum_y / sum_weight)
428 } else {
429 (
430 acc.raw_sum_x / f64::from(acc.count),
431 acc.raw_sum_y / f64::from(acc.count),
432 )
433 };
434
435 batch.push(Neutron::new(
436 centroid_x * scale,
437 centroid_y * scale,
438 acc.rep_tof,
439 u16::try_from(acc.sum_tot.min(u64::from(u16::MAX))).unwrap_or(u16::MAX),
440 u16::try_from(acc.count).unwrap_or(u16::MAX),
441 acc.rep_chip,
442 ));
443 }
444 batch
445}
446
447fn build_neutron_batch_unweighted(
448 accumulators: Vec<ClusterAccumulator>,
449 scale: f64,
450) -> NeutronBatch {
451 let mut batch = NeutronBatch::with_capacity(accumulators.len());
452 for acc in accumulators {
453 if acc.count == 0 {
454 continue;
455 }
456
457 let centroid_x = acc.raw_sum_x / f64::from(acc.count);
458 let centroid_y = acc.raw_sum_y / f64::from(acc.count);
459
460 batch.push(Neutron::new(
461 centroid_x * scale,
462 centroid_y * scale,
463 acc.rep_tof,
464 u16::try_from(acc.sum_tot.min(u64::from(u16::MAX))).unwrap_or(u16::MAX),
465 u16::try_from(acc.count).unwrap_or(u16::MAX),
466 acc.rep_chip,
467 ));
468 }
469 batch
470}
471
#[cfg(test)]
mod tests {
    use super::*;
    use crate::soa::HitBatch;

    /// Builds a `HitBatch` from `(tof, x, y, timestamp, tot, chip_id, cluster_id)` tuples.
    fn make_batch(hits: &[(u32, u16, u16, u32, u16, u8, i32)]) -> HitBatch {
        let mut batch = HitBatch::with_capacity(hits.len());
        for (i, (tof, x, y, timestamp, tot, chip_id, cluster_id)) in hits.iter().enumerate() {
            batch.push((*x, *y, *tof, *tot, *timestamp, *chip_id));
            batch.cluster_id[i] = *cluster_id;
        }
        batch
    }

    #[test]
    fn test_single_hit_extraction() {
        let batch = make_batch(&[(1000, 100, 200, 500, 50, 0, 0)]);

        let extractor = SimpleCentroidExtraction::new();
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 1);
        assert!((neutrons[0].x - 800.0).abs() < f64::EPSILON);
        assert!((neutrons[0].y - 1600.0).abs() < f64::EPSILON);
        assert_eq!(neutrons[0].tof, 1000);
        assert_eq!(neutrons[0].n_hits, 1);
    }

    #[test]
    fn test_weighted_centroid() {
        // (0·30 + 2·10) / 40 = 0.5 px, × 8 = 4.0.
        let batch = make_batch(&[
            (1000, 0, 0, 500, 30, 0, 0),
            (1000, 2, 0, 500, 10, 0, 0),
        ]);

        let extractor = SimpleCentroidExtraction::new();
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 1);
        assert!((neutrons[0].x - 4.0).abs() < 0.01);
        assert_eq!(neutrons[0].n_hits, 2);
        assert_eq!(neutrons[0].tot, 40);
    }

    #[test]
    fn test_unweighted_centroid_ignores_tot() {
        // Plain mean: (0 + 2) / 2 = 1 px, × 8 = 8.0, regardless of ToT.
        let batch = make_batch(&[
            (1000, 0, 0, 500, 30, 0, 0),
            (1000, 2, 0, 500, 10, 0, 0),
        ]);

        let extractor = SimpleCentroidExtraction::with_config(
            ExtractionConfig::default().with_weighted_by_tot(false),
        );
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 1);
        assert!((neutrons[0].x - 8.0).abs() < f64::EPSILON);
        assert!(neutrons[0].y.abs() < f64::EPSILON);
        assert_eq!(neutrons[0].n_hits, 2);
        assert_eq!(neutrons[0].tot, 40);
    }

    #[test]
    fn test_unweighted_applies_tot_threshold() {
        // Default threshold 10 drops the ToT-5 hit; mean of x = 10, 20 is 15 px, × 8 = 120.
        let batch = make_batch(&[
            (1000, 0, 0, 500, 5, 0, 0),
            (1000, 10, 0, 500, 15, 0, 0),
            (1000, 20, 0, 500, 20, 0, 0),
        ]);

        let extractor = SimpleCentroidExtraction::with_config(
            ExtractionConfig::default().with_weighted_by_tot(false),
        );
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 1);
        assert_eq!(neutrons[0].n_hits, 2);
        assert_eq!(neutrons[0].tot, 35);
        assert!((neutrons[0].x - 120.0).abs() < 1e-9);
    }

    #[test]
    fn test_multiple_clusters() {
        let batch = make_batch(&[
            (1000, 10, 10, 500, 50, 0, 0),
            (1000, 11, 10, 500, 50, 0, 0),
            (2000, 100, 100, 500, 50, 1, 1),
        ]);

        let extractor = SimpleCentroidExtraction::new();
        let neutrons = extractor.extract_soa(&batch, 2).unwrap();

        assert_eq!(neutrons.len(), 2);
        assert_eq!(neutrons[0].n_hits, 2);
        assert_eq!(neutrons[1].n_hits, 1);
    }

    #[test]
    fn test_out_of_range_labels_are_ignored() {
        // Label -1 (unclustered) and label 5 (>= num_clusters) must not contribute.
        let batch = make_batch(&[
            (1000, 4, 4, 500, 20, 0, 0),
            (2000, 100, 100, 500, 50, 0, -1),
            (3000, 200, 200, 500, 50, 0, 5),
        ]);

        let extractor = SimpleCentroidExtraction::new();
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 1);
        assert_eq!(neutrons[0].n_hits, 1);
        assert_eq!(neutrons[0].tof, 1000);
        assert_eq!(neutrons[0].tot, 20);
        assert!((neutrons[0].x - 32.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_empty_batch_yields_no_neutrons() {
        let batch = make_batch(&[]);

        let extractor = SimpleCentroidExtraction::new();
        let neutrons = extractor.extract_soa(&batch, 0).unwrap();

        assert!(neutrons.is_empty());
    }

    #[test]
    fn test_tot_threshold_filters_low_tot_hits() {
        // Default threshold 10 drops the ToT-5 hit;
        // (10·15 + 20·20) / 35 ≈ 15.714 px, × 8 ≈ 125.71.
        let batch = make_batch(&[
            (1000, 0, 0, 500, 5, 0, 0),
            (1000, 10, 0, 500, 15, 0, 0),
            (1000, 20, 0, 500, 20, 0, 0),
        ]);

        let extractor = SimpleCentroidExtraction::new();
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 1);
        assert_eq!(neutrons[0].n_hits, 2);
        assert_eq!(neutrons[0].tot, 35);
        assert!((neutrons[0].x - 125.71).abs() < 0.1);
    }

    #[test]
    fn test_tot_threshold_skips_empty_clusters_after_filtering() {
        // Both hits are below the default threshold of 10, so the cluster is empty.
        let batch = make_batch(&[
            (1000, 0, 0, 500, 5, 0, 0),
            (1000, 1, 0, 500, 8, 0, 0),
        ]);

        let extractor = SimpleCentroidExtraction::new();
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 0);
    }

    #[test]
    fn test_tot_threshold_disabled_when_zero() {
        let batch = make_batch(&[
            (1000, 0, 0, 500, 5, 0, 0),
            (1000, 10, 0, 500, 3, 0, 0),
        ]);

        let mut extractor = SimpleCentroidExtraction::new();
        extractor.configure(ExtractionConfig::default().with_min_tot_threshold(0));

        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 1);
        assert_eq!(neutrons[0].n_hits, 2);
        assert_eq!(neutrons[0].tot, 8);
    }

    #[test]
    fn test_representative_tof_from_max_tot_after_filtering() {
        let batch = make_batch(&[
            (1000, 0, 0, 500, 5, 0, 0),
            (2000, 10, 0, 500, 15, 0, 0),
            (3000, 20, 0, 500, 25, 0, 0),
        ]);

        let extractor = SimpleCentroidExtraction::new();
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 1);
        assert_eq!(neutrons[0].tof, 3000);
        // The filtered-out ToT-5 hit must never become the representative.
        assert_ne!(neutrons[0].tof, 1000);
    }

    #[test]
    fn test_representative_hit_tie_prefers_later_hit() {
        // Equal ToT: the later hit supplies tof and chip_id.
        let batch = make_batch(&[
            (1000, 0, 0, 500, 20, 0, 0),
            (2000, 2, 0, 500, 20, 1, 0),
        ]);

        let extractor = SimpleCentroidExtraction::new();
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 1);
        assert_eq!(neutrons[0].tof, 2000);
        assert_eq!(neutrons[0].chip_id, 1);
    }

    #[test]
    fn test_zero_tot_weighted_centroid() {
        // All weights are zero, so the weighted path must fall back to the plain mean:
        // x = (10 + 30) / 2 × 8 = 160, y = (20 + 40) / 2 × 8 = 240.
        let batch = make_batch(&[
            (1000, 10, 20, 500, 0, 0, 0),
            (1000, 30, 40, 500, 0, 0, 0),
        ]);

        let mut extractor = SimpleCentroidExtraction::new();
        extractor.configure(ExtractionConfig::default().with_min_tot_threshold(0));

        let neutrons = extractor.extract_soa(&batch, 1).unwrap();

        assert_eq!(neutrons.len(), 1);
        assert!((neutrons[0].x - 160.0).abs() < 0.01);
        assert!((neutrons[0].y - 240.0).abs() < 0.01);
        assert_eq!(neutrons[0].tot, 0);
        assert_eq!(neutrons[0].n_hits, 2);
        assert!(!neutrons[0].x.is_nan());
        assert!(!neutrons[0].y.is_nan());
    }

    #[test]
    fn test_extract_soa_expected_values() {
        let mut batch = HitBatch::with_capacity(3);
        batch.push((10, 10, 1000, 20, 500, 0));
        batch.push((20, 10, 1500, 10, 500, 0));
        batch.push((5, 7, 2000, 15, 500, 1));

        batch.cluster_id[0] = 0;
        batch.cluster_id[1] = 0;
        batch.cluster_id[2] = 1;

        let extractor = SimpleCentroidExtraction::new();
        let neutrons = extractor.extract_soa(&batch, 2).unwrap();

        assert_eq!(neutrons.len(), 2);

        let n0 = &neutrons[0];
        let expected_x = (10.0 * 20.0 + 20.0 * 10.0) / 30.0 * 8.0;
        let expected_y = 10.0 * 8.0;
        assert!((n0.x - expected_x).abs() < 1e-6);
        assert!((n0.y - expected_y).abs() < 1e-6);
        assert_eq!(n0.tof, 1000);
        assert_eq!(n0.tot, 30);
        assert_eq!(n0.n_hits, 2);
        assert_eq!(n0.chip_id, 0);

        let n1 = &neutrons[1];
        assert!((n1.x - 40.0).abs() < 1e-6);
        assert!((n1.y - 56.0).abs() < 1e-6);
        assert_eq!(n1.tof, 2000);
        assert_eq!(n1.tot, 15);
        assert_eq!(n1.n_hits, 1);
        assert_eq!(n1.chip_id, 1);
    }

    #[test]
    fn test_super_resolution_factor_affects_output() {
        let batch = make_batch(&[(1000, 2, 3, 500, 20, 0, 0)]);

        let mut extractor = SimpleCentroidExtraction::new();
        extractor.configure(
            ExtractionConfig::default()
                .with_super_resolution(1.0)
                .with_min_tot_threshold(0),
        );
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();
        assert_eq!(neutrons.len(), 1);
        assert!((neutrons[0].x - 2.0).abs() < f64::EPSILON);
        assert!((neutrons[0].y - 3.0).abs() < f64::EPSILON);

        extractor.configure(
            ExtractionConfig::default()
                .with_super_resolution(4.0)
                .with_min_tot_threshold(0),
        );
        let neutrons = extractor.extract_soa(&batch, 1).unwrap();
        assert_eq!(neutrons.len(), 1);
        assert!((neutrons[0].x - 8.0).abs() < f64::EPSILON);
        assert!((neutrons[0].y - 12.0).abs() < f64::EPSILON);
    }
}