Skip to main content

rustpix_algorithms/
processing.rs

1//! High-level processing helpers that combine clustering and extraction.
2
3use crate::{AbsClustering, AbsConfig, AbsState, DbscanClustering, DbscanConfig, DbscanState};
4use crate::{GridClustering, GridConfig, GridState};
5use rustpix_core::clustering::ClusteringConfig;
6use rustpix_core::error::Result;
7use rustpix_core::extraction::{ExtractionConfig, NeutronExtraction, SimpleCentroidExtraction};
8use rustpix_core::neutron::{Neutron, NeutronBatch};
9use rustpix_core::soa::HitBatch;
10
/// Supported clustering algorithms.
///
/// Selects which clustering implementation the processing helpers in this
/// module run before neutron extraction. The type is a plain fieldless enum,
/// so it is cheap to copy and can be compared or used as a map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ClusteringAlgorithm {
    /// Age-Based Spatial clustering.
    Abs,
    /// DBSCAN clustering.
    Dbscan,
    /// Grid-based clustering.
    Grid,
}
21
/// Algorithm-specific tuning parameters.
///
/// Each field is only read by the algorithm named in its prefix; the other
/// algorithms ignore it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AlgorithmParams {
    /// ABS scan interval (hits between aging scans). Defaults to 100.
    pub abs_scan_interval: usize,
    /// DBSCAN minimum points for a seed cluster. Defaults to 2.
    pub dbscan_min_points: usize,
    /// Grid cell size (pixels). Defaults to 32.
    pub grid_cell_size: usize,
}

impl Default for AlgorithmParams {
    fn default() -> Self {
        Self {
            abs_scan_interval: 100,
            dbscan_min_points: 2,
            grid_cell_size: 32,
        }
    }
}
42
43/// Iterator that clusters and extracts each incoming batch.
44pub struct ClusterAndExtractStream<I>
45where
46    I: Iterator<Item = HitBatch>,
47{
48    batches: I,
49    algorithm: ClusteringAlgorithm,
50    clustering: ClusteringConfig,
51    extraction: ExtractionConfig,
52    params: AlgorithmParams,
53}
54
55impl<I> Iterator for ClusterAndExtractStream<I>
56where
57    I: Iterator<Item = HitBatch>,
58{
59    type Item = Result<NeutronBatch>;
60
61    fn next(&mut self) -> Option<Self::Item> {
62        self.batches.next().map(|mut batch| {
63            cluster_and_extract_batch(
64                &mut batch,
65                self.algorithm,
66                &self.clustering,
67                &self.extraction,
68                &self.params,
69            )
70        })
71    }
72}
73
74/// Create a streaming cluster-and-extract iterator.
75pub fn cluster_and_extract_stream_iter<I>(
76    batches: I,
77    algorithm: ClusteringAlgorithm,
78    clustering: ClusteringConfig,
79    extraction: ExtractionConfig,
80    params: AlgorithmParams,
81) -> ClusterAndExtractStream<I::IntoIter>
82where
83    I: IntoIterator<Item = HitBatch>,
84{
85    ClusterAndExtractStream {
86        batches: batches.into_iter(),
87        algorithm,
88        clustering,
89        extraction,
90        params,
91    }
92}
93
94/// Cluster hits in-place, then extract neutrons using the configured algorithm.
95///
96/// # Errors
97/// Returns an error if clustering or extraction fails.
98pub fn cluster_and_extract(
99    batch: &mut HitBatch,
100    algorithm: ClusteringAlgorithm,
101    clustering: &ClusteringConfig,
102    extraction: &ExtractionConfig,
103    params: &AlgorithmParams,
104) -> Result<Vec<Neutron>> {
105    let num_clusters = match algorithm {
106        ClusteringAlgorithm::Abs => {
107            let algo = AbsClustering::new(AbsConfig {
108                radius: clustering.radius,
109                neutron_correlation_window_ns: clustering.temporal_window_ns,
110                min_cluster_size: clustering.min_cluster_size,
111                scan_interval: params.abs_scan_interval,
112            });
113            let mut state = AbsState::default();
114            algo.cluster(batch, &mut state)?
115        }
116        ClusteringAlgorithm::Dbscan => {
117            let algo = DbscanClustering::new(DbscanConfig {
118                epsilon: clustering.radius,
119                temporal_window_ns: clustering.temporal_window_ns,
120                min_points: params.dbscan_min_points,
121                min_cluster_size: clustering.min_cluster_size,
122            });
123            let mut state = DbscanState::default();
124            algo.cluster(batch, &mut state)?
125        }
126        ClusteringAlgorithm::Grid => {
127            let algo = GridClustering::new(GridConfig {
128                radius: clustering.radius,
129                temporal_window_ns: clustering.temporal_window_ns,
130                min_cluster_size: clustering.min_cluster_size,
131                cell_size: params.grid_cell_size,
132                max_cluster_size: clustering.max_cluster_size.map(|value| value as usize),
133            });
134            let mut state = GridState::default();
135            algo.cluster(batch, &mut state)?
136        }
137    };
138
139    let mut extractor = SimpleCentroidExtraction::new();
140    extractor.configure(extraction.clone());
141    extractor
142        .extract_soa(batch, num_clusters)
143        .map_err(Into::into)
144}
145
146/// Cluster hits in-place, then extract neutrons into a `NeutronBatch`.
147///
148/// # Errors
149/// Returns an error if clustering or extraction fails.
150pub fn cluster_and_extract_batch(
151    batch: &mut HitBatch,
152    algorithm: ClusteringAlgorithm,
153    clustering: &ClusteringConfig,
154    extraction: &ExtractionConfig,
155    params: &AlgorithmParams,
156) -> Result<NeutronBatch> {
157    let num_clusters = match algorithm {
158        ClusteringAlgorithm::Abs => {
159            let algo = AbsClustering::new(AbsConfig {
160                radius: clustering.radius,
161                neutron_correlation_window_ns: clustering.temporal_window_ns,
162                min_cluster_size: clustering.min_cluster_size,
163                scan_interval: params.abs_scan_interval,
164            });
165            let mut state = AbsState::default();
166            algo.cluster(batch, &mut state)?
167        }
168        ClusteringAlgorithm::Dbscan => {
169            let algo = DbscanClustering::new(DbscanConfig {
170                epsilon: clustering.radius,
171                temporal_window_ns: clustering.temporal_window_ns,
172                min_points: params.dbscan_min_points,
173                min_cluster_size: clustering.min_cluster_size,
174            });
175            let mut state = DbscanState::default();
176            algo.cluster(batch, &mut state)?
177        }
178        ClusteringAlgorithm::Grid => {
179            let algo = GridClustering::new(GridConfig {
180                radius: clustering.radius,
181                temporal_window_ns: clustering.temporal_window_ns,
182                min_cluster_size: clustering.min_cluster_size,
183                cell_size: params.grid_cell_size,
184                max_cluster_size: clustering.max_cluster_size.map(|value| value as usize),
185            });
186            let mut state = GridState::default();
187            algo.cluster(batch, &mut state)?
188        }
189    };
190
191    let mut extractor = SimpleCentroidExtraction::new();
192    extractor.configure(extraction.clone());
193    extractor
194        .extract_soa_batch(batch, num_clusters)
195        .map_err(Into::into)
196}
197
198/// Cluster hits in batches, then extract and append neutrons into a single batch.
199///
200/// # Errors
201/// Returns an error if clustering or extraction fails for any batch.
202pub fn cluster_and_extract_stream<I>(
203    batches: I,
204    algorithm: ClusteringAlgorithm,
205    clustering: &ClusteringConfig,
206    extraction: &ExtractionConfig,
207    params: &AlgorithmParams,
208) -> Result<NeutronBatch>
209where
210    I: IntoIterator<Item = HitBatch>,
211{
212    let mut all_neutrons = NeutronBatch::default();
213    let iter = cluster_and_extract_stream_iter(
214        batches,
215        algorithm,
216        clustering.clone(),
217        extraction.clone(),
218        params.clone(),
219    );
220    for neutrons in iter {
221        let neutrons = neutrons?;
222        all_neutrons.append(&neutrons);
223    }
224    Ok(all_neutrons)
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Every clustering algorithm, so each test can cover all of them.
    const ALL_ALGORITHMS: [ClusteringAlgorithm; 3] = [
        ClusteringAlgorithm::Abs,
        ClusteringAlgorithm::Dbscan,
        ClusteringAlgorithm::Grid,
    ];

    /// Two small input batches of two hits each.
    fn sample_batches() -> (HitBatch, HitBatch) {
        let mut first = HitBatch::with_capacity(2);
        first.push((10, 10, 100, 5, 1_000, 0));
        first.push((11, 10, 102, 6, 1_002, 0));

        let mut second = HitBatch::with_capacity(2);
        second.push((20, 20, 200, 7, 2_000, 1));
        second.push((21, 20, 202, 8, 2_002, 1));

        (first, second)
    }

    /// Assert that the x, y, tof, tot, n_hits and chip_id columns match.
    fn assert_same_neutrons(actual: &NeutronBatch, expected: &NeutronBatch) {
        assert_eq!(actual.x, expected.x);
        assert_eq!(actual.y, expected.y);
        assert_eq!(actual.tof, expected.tof);
        assert_eq!(actual.tot, expected.tot);
        assert_eq!(actual.n_hits, expected.n_hits);
        assert_eq!(actual.chip_id, expected.chip_id);
    }

    #[test]
    fn test_stream_iter_matches_batch_results() {
        let clustering = ClusteringConfig::default();
        let extraction = ExtractionConfig::default();
        let params = AlgorithmParams::default();

        for algorithm in ALL_ALGORITHMS {
            let (batch1, batch2) = sample_batches();

            let expected1 = cluster_and_extract_batch(
                &mut batch1.clone(),
                algorithm,
                &clustering,
                &extraction,
                &params,
            )
            .unwrap();
            let expected2 = cluster_and_extract_batch(
                &mut batch2.clone(),
                algorithm,
                &clustering,
                &extraction,
                &params,
            )
            .unwrap();

            let mut iter = cluster_and_extract_stream_iter(
                vec![batch1, batch2],
                algorithm,
                clustering.clone(),
                extraction.clone(),
                params.clone(),
            );

            assert_same_neutrons(&iter.next().unwrap().unwrap(), &expected1);
            assert_same_neutrons(&iter.next().unwrap().unwrap(), &expected2);
            assert!(iter.next().is_none());
        }
    }

    #[test]
    fn test_stream_iter_empty_input_yields_nothing() {
        for algorithm in ALL_ALGORITHMS {
            let mut iter = cluster_and_extract_stream_iter(
                Vec::<HitBatch>::new(),
                algorithm,
                ClusteringConfig::default(),
                ExtractionConfig::default(),
                AlgorithmParams::default(),
            );
            assert!(iter.next().is_none());
        }
    }

    #[test]
    fn test_stream_empty_input_returns_default_batch() {
        for algorithm in ALL_ALGORITHMS {
            let all = cluster_and_extract_stream(
                Vec::<HitBatch>::new(),
                algorithm,
                &ClusteringConfig::default(),
                &ExtractionConfig::default(),
                &AlgorithmParams::default(),
            )
            .unwrap();
            assert_same_neutrons(&all, &NeutronBatch::default());
        }
    }
}