1use crate::{AbsClustering, AbsConfig, AbsState, DbscanClustering, DbscanConfig, DbscanState};
4use crate::{GridClustering, GridConfig, GridState};
5use rustpix_core::clustering::ClusteringConfig;
6use rustpix_core::error::Result;
7use rustpix_core::extraction::{ExtractionConfig, NeutronExtraction, SimpleCentroidExtraction};
8use rustpix_core::neutron::{Neutron, NeutronBatch};
9use rustpix_core::soa::HitBatch;
10
/// Clustering algorithm used to group hits before neutron extraction.
///
/// Each variant selects one of this crate's clustering implementations. The
/// shared [`ClusteringConfig`] is mapped onto that algorithm's own config, plus
/// the one [`AlgorithmParams`] field relevant to it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ClusteringAlgorithm {
    /// Uses [`AbsClustering`]; additionally reads [`AlgorithmParams::abs_scan_interval`].
    Abs,
    /// Uses [`DbscanClustering`]; additionally reads [`AlgorithmParams::dbscan_min_points`].
    Dbscan,
    /// Uses [`GridClustering`]; additionally reads [`AlgorithmParams::grid_cell_size`].
    Grid,
}
21
/// Algorithm-specific tuning values that are not part of [`ClusteringConfig`].
///
/// Only the field belonging to the selected [`ClusteringAlgorithm`] is read
/// during a run; the other fields are ignored.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AlgorithmParams {
    /// Forwarded as `AbsConfig::scan_interval` for [`ClusteringAlgorithm::Abs`].
    pub abs_scan_interval: usize,
    /// Forwarded as `DbscanConfig::min_points` for [`ClusteringAlgorithm::Dbscan`].
    pub dbscan_min_points: usize,
    /// Forwarded as `GridConfig::cell_size` for [`ClusteringAlgorithm::Grid`].
    pub grid_cell_size: usize,
}

impl Default for AlgorithmParams {
    /// Returns `abs_scan_interval = 100`, `dbscan_min_points = 2` and
    /// `grid_cell_size = 32`.
    fn default() -> Self {
        Self {
            abs_scan_interval: 100,
            dbscan_min_points: 2,
            grid_cell_size: 32,
        }
    }
}
42
43pub struct ClusterAndExtractStream<I>
45where
46 I: Iterator<Item = HitBatch>,
47{
48 batches: I,
49 algorithm: ClusteringAlgorithm,
50 clustering: ClusteringConfig,
51 extraction: ExtractionConfig,
52 params: AlgorithmParams,
53}
54
55impl<I> Iterator for ClusterAndExtractStream<I>
56where
57 I: Iterator<Item = HitBatch>,
58{
59 type Item = Result<NeutronBatch>;
60
61 fn next(&mut self) -> Option<Self::Item> {
62 self.batches.next().map(|mut batch| {
63 cluster_and_extract_batch(
64 &mut batch,
65 self.algorithm,
66 &self.clustering,
67 &self.extraction,
68 &self.params,
69 )
70 })
71 }
72}
73
74pub fn cluster_and_extract_stream_iter<I>(
76 batches: I,
77 algorithm: ClusteringAlgorithm,
78 clustering: ClusteringConfig,
79 extraction: ExtractionConfig,
80 params: AlgorithmParams,
81) -> ClusterAndExtractStream<I::IntoIter>
82where
83 I: IntoIterator<Item = HitBatch>,
84{
85 ClusterAndExtractStream {
86 batches: batches.into_iter(),
87 algorithm,
88 clustering,
89 extraction,
90 params,
91 }
92}
93
94pub fn cluster_and_extract(
99 batch: &mut HitBatch,
100 algorithm: ClusteringAlgorithm,
101 clustering: &ClusteringConfig,
102 extraction: &ExtractionConfig,
103 params: &AlgorithmParams,
104) -> Result<Vec<Neutron>> {
105 let num_clusters = match algorithm {
106 ClusteringAlgorithm::Abs => {
107 let algo = AbsClustering::new(AbsConfig {
108 radius: clustering.radius,
109 neutron_correlation_window_ns: clustering.temporal_window_ns,
110 min_cluster_size: clustering.min_cluster_size,
111 scan_interval: params.abs_scan_interval,
112 });
113 let mut state = AbsState::default();
114 algo.cluster(batch, &mut state)?
115 }
116 ClusteringAlgorithm::Dbscan => {
117 let algo = DbscanClustering::new(DbscanConfig {
118 epsilon: clustering.radius,
119 temporal_window_ns: clustering.temporal_window_ns,
120 min_points: params.dbscan_min_points,
121 min_cluster_size: clustering.min_cluster_size,
122 });
123 let mut state = DbscanState::default();
124 algo.cluster(batch, &mut state)?
125 }
126 ClusteringAlgorithm::Grid => {
127 let algo = GridClustering::new(GridConfig {
128 radius: clustering.radius,
129 temporal_window_ns: clustering.temporal_window_ns,
130 min_cluster_size: clustering.min_cluster_size,
131 cell_size: params.grid_cell_size,
132 max_cluster_size: clustering.max_cluster_size.map(|value| value as usize),
133 });
134 let mut state = GridState::default();
135 algo.cluster(batch, &mut state)?
136 }
137 };
138
139 let mut extractor = SimpleCentroidExtraction::new();
140 extractor.configure(extraction.clone());
141 extractor
142 .extract_soa(batch, num_clusters)
143 .map_err(Into::into)
144}
145
146pub fn cluster_and_extract_batch(
151 batch: &mut HitBatch,
152 algorithm: ClusteringAlgorithm,
153 clustering: &ClusteringConfig,
154 extraction: &ExtractionConfig,
155 params: &AlgorithmParams,
156) -> Result<NeutronBatch> {
157 let num_clusters = match algorithm {
158 ClusteringAlgorithm::Abs => {
159 let algo = AbsClustering::new(AbsConfig {
160 radius: clustering.radius,
161 neutron_correlation_window_ns: clustering.temporal_window_ns,
162 min_cluster_size: clustering.min_cluster_size,
163 scan_interval: params.abs_scan_interval,
164 });
165 let mut state = AbsState::default();
166 algo.cluster(batch, &mut state)?
167 }
168 ClusteringAlgorithm::Dbscan => {
169 let algo = DbscanClustering::new(DbscanConfig {
170 epsilon: clustering.radius,
171 temporal_window_ns: clustering.temporal_window_ns,
172 min_points: params.dbscan_min_points,
173 min_cluster_size: clustering.min_cluster_size,
174 });
175 let mut state = DbscanState::default();
176 algo.cluster(batch, &mut state)?
177 }
178 ClusteringAlgorithm::Grid => {
179 let algo = GridClustering::new(GridConfig {
180 radius: clustering.radius,
181 temporal_window_ns: clustering.temporal_window_ns,
182 min_cluster_size: clustering.min_cluster_size,
183 cell_size: params.grid_cell_size,
184 max_cluster_size: clustering.max_cluster_size.map(|value| value as usize),
185 });
186 let mut state = GridState::default();
187 algo.cluster(batch, &mut state)?
188 }
189 };
190
191 let mut extractor = SimpleCentroidExtraction::new();
192 extractor.configure(extraction.clone());
193 extractor
194 .extract_soa_batch(batch, num_clusters)
195 .map_err(Into::into)
196}
197
198pub fn cluster_and_extract_stream<I>(
203 batches: I,
204 algorithm: ClusteringAlgorithm,
205 clustering: &ClusteringConfig,
206 extraction: &ExtractionConfig,
207 params: &AlgorithmParams,
208) -> Result<NeutronBatch>
209where
210 I: IntoIterator<Item = HitBatch>,
211{
212 let mut all_neutrons = NeutronBatch::default();
213 let iter = cluster_and_extract_stream_iter(
214 batches,
215 algorithm,
216 clustering.clone(),
217 extraction.clone(),
218 params.clone(),
219 );
220 for neutrons in iter {
221 let neutrons = neutrons?;
222 all_neutrons.append(&neutrons);
223 }
224 Ok(all_neutrons)
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two neutron batches hold identical per-neutron columns.
    fn assert_same_neutrons(actual: &NeutronBatch, expected: &NeutronBatch) {
        assert_eq!(actual.x, expected.x);
        assert_eq!(actual.y, expected.y);
        assert_eq!(actual.tof, expected.tof);
        assert_eq!(actual.tot, expected.tot);
        assert_eq!(actual.n_hits, expected.n_hits);
        assert_eq!(actual.chip_id, expected.chip_id);
    }

    #[test]
    fn test_stream_iter_matches_batch_results() {
        let mut batch1 = HitBatch::with_capacity(2);
        batch1.push((10, 10, 100, 5, 1_000, 0));
        batch1.push((11, 10, 102, 6, 1_002, 0));

        let mut batch2 = HitBatch::with_capacity(2);
        batch2.push((20, 20, 200, 7, 2_000, 1));
        batch2.push((21, 20, 202, 8, 2_002, 1));

        let algorithm = ClusteringAlgorithm::Abs;
        let clustering = ClusteringConfig::default();
        let extraction = ExtractionConfig::default();
        let params = AlgorithmParams::default();

        // Process clones directly so the originals can still be fed to the stream.
        let mut expected1 = batch1.clone();
        let expected1 =
            cluster_and_extract_batch(&mut expected1, algorithm, &clustering, &extraction, &params)
                .expect("batch 1 should cluster and extract");

        let mut expected2 = batch2.clone();
        let expected2 =
            cluster_and_extract_batch(&mut expected2, algorithm, &clustering, &extraction, &params)
                .expect("batch 2 should cluster and extract");

        let mut iter = cluster_and_extract_stream_iter(
            vec![batch1, batch2],
            algorithm,
            clustering,
            extraction,
            params,
        );

        let batch_out1 = iter.next().expect("first item").expect("first batch ok");
        assert_same_neutrons(&batch_out1, &expected1);

        let batch_out2 = iter.next().expect("second item").expect("second batch ok");
        assert_same_neutrons(&batch_out2, &expected2);

        assert!(iter.next().is_none());
    }

    #[test]
    fn test_stream_iter_empty_input_yields_nothing() {
        let mut iter = cluster_and_extract_stream_iter(
            Vec::<HitBatch>::new(),
            ClusteringAlgorithm::Abs,
            ClusteringConfig::default(),
            ExtractionConfig::default(),
            AlgorithmParams::default(),
        );
        assert!(iter.next().is_none());
    }
}