rustpix_algorithms/
dbscan.rs1use rayon::prelude::*;
4use rustpix_core::clustering::ClusteringError;
5use rustpix_core::soa::HitBatch;
6
/// Tuning parameters for [`DbscanClustering`].
#[derive(Clone, Debug)]
pub struct DbscanConfig {
    /// Maximum spatial distance (inclusive), in hit-coordinate units, between two neighbors.
    pub epsilon: f64,
    /// Maximum time-of-flight separation (inclusive) between two neighbors, in nanoseconds.
    /// Converted to `tof` ticks at 25 ns per tick, rounding up.
    pub temporal_window_ns: f64,
    /// Minimum number of neighbors (not counting the hit itself) for a hit to be a core point.
    pub min_points: usize,
    /// Clusters with fewer hits than this are discarded and their hits relabelled `-1`.
    /// A value of `0` or `1` disables pruning.
    pub min_cluster_size: u16,
}
19
20impl Default for DbscanConfig {
21 fn default() -> Self {
22 Self {
23 epsilon: 5.0,
24 temporal_window_ns: 75.0,
25 min_points: 2,
26 min_cluster_size: 1,
27 }
28 }
29}
30
31pub struct DbscanClustering {
33 config: DbscanConfig,
34}
35
/// Reusable scratch buffers for `DbscanClustering::cluster`.
///
/// Buffers are only grown, never shrunk, so reusing one state across batches avoids
/// repeated allocation.
#[derive(Default)]
pub struct DbscanState {
    /// Spatial grid cells, each holding the indices of the hits that fall in it.
    grid: Vec<Vec<usize>>,
    /// Per-hit flag: the hit's neighborhood has already been examined.
    visited: Vec<bool>,
    /// Per-hit flag: the hit is currently classified as noise.
    noise: Vec<bool>,
    /// Output buffer for region queries.
    neighbors: Vec<usize>,
    /// Work list of hits waiting to be expanded into the current cluster.
    seeds: Vec<usize>,
    /// Hit count per raw cluster id, used when pruning small clusters.
    cluster_sizes: Vec<usize>,
    /// Raw cluster id -> compacted cluster id, or `-1` for a pruned cluster.
    id_map: Vec<i32>,
}
47
/// Read-only lookup data shared by every region query of one clustering run.
struct DbscanContext<'a> {
    /// Spatial buckets in row-major order: cell `(cx, cy)` lives at `cy * grid_w + cx`.
    grid: &'a [Vec<usize>],
    /// Edge length of a grid cell, in hit-coordinate units.
    cell_size: usize,
    /// Number of cells per grid row.
    grid_w: usize,
    /// Squared spatial radius (`epsilon * epsilon`).
    eps_sq: f64,
    /// Largest allowed `tof` difference between neighbors, in `tof` ticks.
    window_tof: u32,
}
55
/// Mutable per-hit flags threaded through cluster expansion.
struct TrackingState<'a> {
    /// `true` once the hit's neighborhood has been examined.
    visited: &'a mut [bool],
    /// `true` while the hit is classified as noise (not yet claimed by any cluster).
    noise: &'a mut [bool],
}
61
62impl DbscanClustering {
63 #[must_use]
65 pub fn new(config: DbscanConfig) -> Self {
66 Self { config }
67 }
68
69 #[must_use]
71 pub fn create_state(&self) -> DbscanState {
72 DbscanState::default()
73 }
74
75 pub fn cluster(
80 &self,
81 batch: &mut HitBatch,
82 state: &mut DbscanState,
83 ) -> Result<usize, ClusteringError> {
84 let n = batch.len();
85 if batch.is_empty() {
86 return Ok(0);
87 }
88
89 batch.cluster_id.par_iter_mut().for_each(|id| *id = -1);
92
93 let ctx = self.build_context(batch, &mut state.grid);
98
99 if state.visited.len() < n {
100 state.visited.resize(n, false);
101 state.noise.resize(n, false);
102 }
103 state.visited[..n].fill(false);
105 state.noise[..n].fill(false);
106
107 let mut current_cluster_id = 0;
108
109 let visited_slice = &mut state.visited[..n];
120 let noise_slice = &mut state.noise[..n];
121 let neighbors_buffer = &mut state.neighbors;
122 let seeds_buffer = &mut state.seeds;
123
124 for i in 0..n {
125 if visited_slice[i] {
126 continue;
127 }
128 visited_slice[i] = true;
129
130 Self::region_query_into(&ctx, i, batch, neighbors_buffer);
131
132 if neighbors_buffer.len() < self.config.min_points {
133 noise_slice[i] = true;
134 } else {
135 batch.cluster_id[i] = current_cluster_id;
136 seeds_buffer.clear();
137 seeds_buffer.extend_from_slice(neighbors_buffer);
138 let mut tracking = TrackingState {
139 visited: visited_slice,
140 noise: noise_slice,
141 };
142 self.expand_cluster(
143 &ctx,
144 seeds_buffer,
145 current_cluster_id,
146 batch,
147 &mut tracking,
148 neighbors_buffer,
149 );
150 current_cluster_id += 1;
151 }
152 }
153
154 Ok(self.prune_small_clusters(batch, state, current_cluster_id))
155 }
156
157 fn build_context<'a>(
158 &self,
159 batch: &HitBatch,
160 grid: &'a mut Vec<Vec<usize>>,
161 ) -> DbscanContext<'a> {
162 let n = batch.len();
163 let cell_size = float_to_usize(self.config.epsilon.ceil()).max(32);
164
165 let mut max_x = 0usize;
166 let mut max_y = 0usize;
167 for i in 0..n {
168 let x = usize::from(batch.x[i]);
169 let y = usize::from(batch.y[i]);
170 if x > max_x {
171 max_x = x;
172 }
173 if y > max_y {
174 max_y = y;
175 }
176 }
177
178 let width = max_x + 32;
179 let height = max_y + 32;
180 let grid_w = width / cell_size + 1;
181 let grid_h = height / cell_size + 1;
182 let total_cells = grid_w * grid_h;
183
184 if grid.len() < total_cells {
185 grid.resize(total_cells, Vec::new());
186 } else {
187 for cell in grid.iter_mut() {
188 cell.clear();
189 }
190 }
191
192 for i in 0..n {
193 let cx = usize::from(batch.x[i]) / cell_size;
194 let cy = usize::from(batch.y[i]) / cell_size;
195 let idx = cy * grid_w + cx;
196 if idx < grid.len() {
197 grid[idx].push(i);
198 }
199 }
200
201 let epsilon_sq = self.config.epsilon * self.config.epsilon;
202 let window_tof = float_to_u32((self.config.temporal_window_ns / 25.0).ceil());
203
204 DbscanContext {
205 grid,
206 cell_size,
207 grid_w,
208 eps_sq: epsilon_sq,
209 window_tof,
210 }
211 }
212
213 fn prune_small_clusters(
214 &self,
215 batch: &mut HitBatch,
216 state: &mut DbscanState,
217 cluster_count: i32,
218 ) -> usize {
219 if self.config.min_cluster_size <= 1 || cluster_count <= 0 {
220 return usize::try_from(cluster_count).unwrap_or(0);
221 }
222
223 let current_cluster_len = usize::try_from(cluster_count).unwrap_or(0);
224 if state.cluster_sizes.len() < current_cluster_len {
225 state.cluster_sizes.resize(current_cluster_len, 0);
226 }
227 let sizes = &mut state.cluster_sizes[..current_cluster_len];
228 sizes.fill(0);
229 for &id in &batch.cluster_id {
230 if let Ok(idx) = usize::try_from(id) {
231 if let Some(size) = sizes.get_mut(idx) {
232 *size += 1;
233 }
234 }
235 }
236
237 if state.id_map.len() < current_cluster_len {
238 state.id_map.resize(current_cluster_len, -1);
239 }
240 let id_map = &mut state.id_map[..current_cluster_len];
241 id_map.fill(-1);
242 let mut new_cluster_count = 0;
243 let min_size = usize::from(self.config.min_cluster_size);
244
245 for (old_id, &size) in sizes.iter().enumerate() {
246 if size >= min_size {
247 id_map[old_id] = new_cluster_count;
248 new_cluster_count += 1;
249 }
250 }
251
252 batch.cluster_id.par_iter_mut().for_each(|id| {
253 if let Ok(idx) = usize::try_from(*id) {
254 if let Some(&new_id) = id_map.get(idx) {
255 *id = new_id;
256 }
257 }
258 });
259
260 usize::try_from(new_cluster_count).unwrap_or(0)
261 }
262
263 fn region_query_into(
264 ctx: &DbscanContext,
265 idx: usize,
266 batch: &HitBatch,
267 neighbors: &mut Vec<usize>,
268 ) {
269 let x = f64::from(batch.x[idx]);
270 let y = f64::from(batch.y[idx]);
271 let tof = batch.tof[idx];
272 let cx = usize::from(batch.x[idx]) / ctx.cell_size;
273 let cy = usize::from(batch.y[idx]) / ctx.cell_size;
274 let cell_col = i32::try_from(cx).unwrap_or(i32::MAX);
275 let cell_row = i32::try_from(cy).unwrap_or(i32::MAX);
276
277 neighbors.clear();
278
279 for dy in -1..=1 {
281 for dx in -1..=1 {
282 let ncx = cell_col + dx;
283 let ncy = cell_row + dy;
284 if ncx < 0 || ncy < 0 {
285 continue;
286 }
287 let (Ok(neighbor_x), Ok(neighbor_y)) = (usize::try_from(ncx), usize::try_from(ncy))
288 else {
289 continue;
290 };
291 let gidx = neighbor_y * ctx.grid_w + neighbor_x;
292 if let Some(cell) = ctx.grid.get(gidx) {
293 for &j in cell {
294 if j == idx {
295 continue;
296 }
297 let val_x = f64::from(batch.x[j]);
298 let val_y = f64::from(batch.y[j]);
299 let val_tof = batch.tof[j];
300
301 let dt = tof.abs_diff(val_tof);
302 if dt <= ctx.window_tof {
303 let dist_sq = (x - val_x).powi(2) + (y - val_y).powi(2);
304 if dist_sq <= ctx.eps_sq {
305 neighbors.push(j);
306 }
307 }
308 }
309 }
310 }
311 }
312 }
313
314 fn expand_cluster(
315 &self,
316 ctx: &DbscanContext,
317 seeds: &mut Vec<usize>,
318 cluster_id: i32,
319 batch: &mut HitBatch,
320 tracking: &mut TrackingState,
321 neighbors: &mut Vec<usize>,
322 ) {
323 let mut i = 0;
324 while i < seeds.len() {
325 let current_p = seeds[i];
326 i += 1;
327
328 if tracking.noise[current_p] {
329 tracking.noise[current_p] = false;
330 batch.cluster_id[current_p] = cluster_id;
331 }
332
333 if !tracking.visited[current_p] {
334 tracking.visited[current_p] = true;
335 batch.cluster_id[current_p] = cluster_id;
336
337 Self::region_query_into(ctx, current_p, batch, neighbors);
338 if neighbors.len() >= self.config.min_points {
339 seeds.extend_from_slice(neighbors);
340 }
341 } else if batch.cluster_id[current_p] == -1 {
342 batch.cluster_id[current_p] = cluster_id;
343 }
344 }
345 }
346}
347
/// Converts `value` to `u32`, rounding as `format!("{:.0}")` does.
///
/// Non-positive inputs (including negative infinity) give `0`. Inputs at or above
/// `u32::MAX`, NaN, and anything else that does not parse give `u32::MAX`.
fn float_to_u32(value: f64) -> u32 {
    let ceiling = f64::from(u32::MAX);
    match value {
        v if v <= 0.0 => 0,
        v if v >= ceiling => u32::MAX,
        v => format!("{v:.0}").parse::<u32>().unwrap_or(u32::MAX),
    }
}
357
/// Converts `value` to `usize`, rounding as `format!("{:.0}")` does.
///
/// Non-positive inputs give `0`. Values that do not fit, NaN, and infinity give
/// `usize::MAX`.
fn float_to_usize(value: f64) -> usize {
    match value {
        v if v <= 0.0 => 0,
        v => format!("{v:.0}").parse::<usize>().unwrap_or(usize::MAX),
    }
}