Skip to main content

rustpix_algorithms/
dbscan.rs

1//! SoA-optimized DBSCAN clustering.
2
3use rayon::prelude::*;
4use rustpix_core::clustering::ClusteringError;
5use rustpix_core::soa::HitBatch;
6
/// Tunable parameters for [`DbscanClustering`].
#[derive(Debug, Clone)]
pub struct DbscanConfig {
    /// Spatial neighbourhood radius, in pixels.
    pub epsilon: f64,
    /// Temporal correlation window, in nanoseconds.
    pub temporal_window_ns: f64,
    /// Minimum number of neighbouring hits (the hit itself is not counted)
    /// required for a hit to seed or extend a cluster.
    pub min_points: usize,
    /// Smallest cluster size kept after pruning; values of `0` or `1`
    /// disable pruning.
    pub min_cluster_size: u16,
}

impl Default for DbscanConfig {
    /// Defaults: a 5-pixel radius, a 75 ns window, 2 neighbours to seed a
    /// cluster, and no pruning of small clusters.
    fn default() -> Self {
        let epsilon = 5.0;
        let temporal_window_ns = 75.0;
        let min_points = 2;
        let min_cluster_size = 1;

        Self {
            epsilon,
            temporal_window_ns,
            min_points,
            min_cluster_size,
        }
    }
}
30
31/// DBSCAN clustering implementation.
32pub struct DbscanClustering {
33    config: DbscanConfig,
34}
35
/// Reusable DBSCAN clustering state buffers.
///
/// Every buffer is scratch space that the clustering pass resets before
/// use; keeping one instance alive across calls avoids reallocating them.
#[derive(Debug, Default)]
pub struct DbscanState {
    /// Spatial grid: one bucket of hit indices per cell, stored row-major.
    grid: Vec<Vec<usize>>,
    /// Per-hit flag: the hit has already been processed.
    visited: Vec<bool>,
    /// Per-hit flag: the hit is currently classified as noise.
    noise: Vec<bool>,
    /// Scratch buffer receiving the result of the latest region query.
    neighbors: Vec<usize>,
    /// Work queue of hit indices pending expansion for the current cluster.
    seeds: Vec<usize>,
    /// Hit count per cluster, used while pruning small clusters.
    cluster_sizes: Vec<usize>,
    /// Old-to-new cluster id mapping (`-1` for pruned clusters).
    id_map: Vec<i32>,
}
47
/// Read-only lookup data shared by all region queries of one clustering run.
struct DbscanContext<'a> {
    /// Row-major grid of cells; each cell lists the indices of the hits in it.
    grid: &'a [Vec<usize>],
    /// Edge length of one grid cell, in pixels.
    cell_size: usize,
    /// Number of cells per grid row (used as the row stride).
    grid_w: usize,
    /// Squared spatial radius, compared against squared pixel distances.
    eps_sq: f64,
    /// Largest time-of-flight difference (in the batch's TOF units) for two
    /// hits to count as temporally correlated.
    window_tof: u32,
}
55
/// Mutable per-hit bookkeeping threaded through cluster expansion.
struct TrackingState<'a> {
    /// `visited[i]` is set once hit `i` has been processed.
    visited: &'a mut [bool],
    /// `noise[i]` is set while hit `i` is classified as noise; it is cleared
    /// again if the hit is later absorbed into a cluster.
    noise: &'a mut [bool],
}
61
62impl DbscanClustering {
63    /// Create a DBSCAN clustering instance with the provided configuration.
64    #[must_use]
65    pub fn new(config: DbscanConfig) -> Self {
66        Self { config }
67    }
68
69    /// Create a fresh DBSCAN state container.
70    #[must_use]
71    pub fn create_state(&self) -> DbscanState {
72        DbscanState::default()
73    }
74
75    /// Cluster hits using DBSCAN.
76    ///
77    /// # Errors
78    /// Returns an error if clustering fails.
79    pub fn cluster(
80        &self,
81        batch: &mut HitBatch,
82        state: &mut DbscanState,
83    ) -> Result<usize, ClusteringError> {
84        let n = batch.len();
85        if batch.is_empty() {
86            return Ok(0);
87        }
88
89        // Reset cluster IDs
90        // SoA cluster_id is i32
91        batch.cluster_id.par_iter_mut().for_each(|id| *id = -1);
92
93        // We need spatial indexing.
94        // Reuse logic from SoAGridClustering or implement a simple grid?
95        // DBSCAN needs precise distance check, so Grid is just a broad phase.
96
97        let ctx = self.build_context(batch, &mut state.grid);
98
99        if state.visited.len() < n {
100            state.visited.resize(n, false);
101            state.noise.resize(n, false);
102        }
103        // Reset flags
104        state.visited[..n].fill(false);
105        state.noise[..n].fill(false);
106
107        let mut current_cluster_id = 0;
108
109        // Use slices for tracking to avoid split borrowing issues with state
110        // We'll pass slices to helper functions
111        // But we need to use the `visited` and `noise` from `state`.
112        // To avoid borrowing `state` while reading `ctx` (which borrows `state.grid`),
113        // we can split `state` or pass things differently.
114        // `ctx` borrows `state.grid`.
115        // `visited` and `noise` are separate fields.
116        // Rust might figure it out if we borrow fields separately.
117
118        // To make it safe and easier, let's extract the slices from state:
119        let visited_slice = &mut state.visited[..n];
120        let noise_slice = &mut state.noise[..n];
121        let neighbors_buffer = &mut state.neighbors;
122        let seeds_buffer = &mut state.seeds;
123
124        for i in 0..n {
125            if visited_slice[i] {
126                continue;
127            }
128            visited_slice[i] = true;
129
130            Self::region_query_into(&ctx, i, batch, neighbors_buffer);
131
132            if neighbors_buffer.len() < self.config.min_points {
133                noise_slice[i] = true;
134            } else {
135                batch.cluster_id[i] = current_cluster_id;
136                seeds_buffer.clear();
137                seeds_buffer.extend_from_slice(neighbors_buffer);
138                let mut tracking = TrackingState {
139                    visited: visited_slice,
140                    noise: noise_slice,
141                };
142                self.expand_cluster(
143                    &ctx,
144                    seeds_buffer,
145                    current_cluster_id,
146                    batch,
147                    &mut tracking,
148                    neighbors_buffer,
149                );
150                current_cluster_id += 1;
151            }
152        }
153
154        Ok(self.prune_small_clusters(batch, state, current_cluster_id))
155    }
156
157    fn build_context<'a>(
158        &self,
159        batch: &HitBatch,
160        grid: &'a mut Vec<Vec<usize>>,
161    ) -> DbscanContext<'a> {
162        let n = batch.len();
163        let cell_size = float_to_usize(self.config.epsilon.ceil()).max(32);
164
165        let mut max_x = 0usize;
166        let mut max_y = 0usize;
167        for i in 0..n {
168            let x = usize::from(batch.x[i]);
169            let y = usize::from(batch.y[i]);
170            if x > max_x {
171                max_x = x;
172            }
173            if y > max_y {
174                max_y = y;
175            }
176        }
177
178        let width = max_x + 32;
179        let height = max_y + 32;
180        let grid_w = width / cell_size + 1;
181        let grid_h = height / cell_size + 1;
182        let total_cells = grid_w * grid_h;
183
184        if grid.len() < total_cells {
185            grid.resize(total_cells, Vec::new());
186        } else {
187            for cell in grid.iter_mut() {
188                cell.clear();
189            }
190        }
191
192        for i in 0..n {
193            let cx = usize::from(batch.x[i]) / cell_size;
194            let cy = usize::from(batch.y[i]) / cell_size;
195            let idx = cy * grid_w + cx;
196            if idx < grid.len() {
197                grid[idx].push(i);
198            }
199        }
200
201        let epsilon_sq = self.config.epsilon * self.config.epsilon;
202        let window_tof = float_to_u32((self.config.temporal_window_ns / 25.0).ceil());
203
204        DbscanContext {
205            grid,
206            cell_size,
207            grid_w,
208            eps_sq: epsilon_sq,
209            window_tof,
210        }
211    }
212
213    fn prune_small_clusters(
214        &self,
215        batch: &mut HitBatch,
216        state: &mut DbscanState,
217        cluster_count: i32,
218    ) -> usize {
219        if self.config.min_cluster_size <= 1 || cluster_count <= 0 {
220            return usize::try_from(cluster_count).unwrap_or(0);
221        }
222
223        let current_cluster_len = usize::try_from(cluster_count).unwrap_or(0);
224        if state.cluster_sizes.len() < current_cluster_len {
225            state.cluster_sizes.resize(current_cluster_len, 0);
226        }
227        let sizes = &mut state.cluster_sizes[..current_cluster_len];
228        sizes.fill(0);
229        for &id in &batch.cluster_id {
230            if let Ok(idx) = usize::try_from(id) {
231                if let Some(size) = sizes.get_mut(idx) {
232                    *size += 1;
233                }
234            }
235        }
236
237        if state.id_map.len() < current_cluster_len {
238            state.id_map.resize(current_cluster_len, -1);
239        }
240        let id_map = &mut state.id_map[..current_cluster_len];
241        id_map.fill(-1);
242        let mut new_cluster_count = 0;
243        let min_size = usize::from(self.config.min_cluster_size);
244
245        for (old_id, &size) in sizes.iter().enumerate() {
246            if size >= min_size {
247                id_map[old_id] = new_cluster_count;
248                new_cluster_count += 1;
249            }
250        }
251
252        batch.cluster_id.par_iter_mut().for_each(|id| {
253            if let Ok(idx) = usize::try_from(*id) {
254                if let Some(&new_id) = id_map.get(idx) {
255                    *id = new_id;
256                }
257            }
258        });
259
260        usize::try_from(new_cluster_count).unwrap_or(0)
261    }
262
263    fn region_query_into(
264        ctx: &DbscanContext,
265        idx: usize,
266        batch: &HitBatch,
267        neighbors: &mut Vec<usize>,
268    ) {
269        let x = f64::from(batch.x[idx]);
270        let y = f64::from(batch.y[idx]);
271        let tof = batch.tof[idx];
272        let cx = usize::from(batch.x[idx]) / ctx.cell_size;
273        let cy = usize::from(batch.y[idx]) / ctx.cell_size;
274        let cell_col = i32::try_from(cx).unwrap_or(i32::MAX);
275        let cell_row = i32::try_from(cy).unwrap_or(i32::MAX);
276
277        neighbors.clear();
278
279        // Check neighboring cells
280        for dy in -1..=1 {
281            for dx in -1..=1 {
282                let ncx = cell_col + dx;
283                let ncy = cell_row + dy;
284                if ncx < 0 || ncy < 0 {
285                    continue;
286                }
287                let (Ok(neighbor_x), Ok(neighbor_y)) = (usize::try_from(ncx), usize::try_from(ncy))
288                else {
289                    continue;
290                };
291                let gidx = neighbor_y * ctx.grid_w + neighbor_x;
292                if let Some(cell) = ctx.grid.get(gidx) {
293                    for &j in cell {
294                        if j == idx {
295                            continue;
296                        }
297                        let val_x = f64::from(batch.x[j]);
298                        let val_y = f64::from(batch.y[j]);
299                        let val_tof = batch.tof[j];
300
301                        let dt = tof.abs_diff(val_tof);
302                        if dt <= ctx.window_tof {
303                            let dist_sq = (x - val_x).powi(2) + (y - val_y).powi(2);
304                            if dist_sq <= ctx.eps_sq {
305                                neighbors.push(j);
306                            }
307                        }
308                    }
309                }
310            }
311        }
312    }
313
314    fn expand_cluster(
315        &self,
316        ctx: &DbscanContext,
317        seeds: &mut Vec<usize>,
318        cluster_id: i32,
319        batch: &mut HitBatch,
320        tracking: &mut TrackingState,
321        neighbors: &mut Vec<usize>,
322    ) {
323        let mut i = 0;
324        while i < seeds.len() {
325            let current_p = seeds[i];
326            i += 1;
327
328            if tracking.noise[current_p] {
329                tracking.noise[current_p] = false;
330                batch.cluster_id[current_p] = cluster_id;
331            }
332
333            if !tracking.visited[current_p] {
334                tracking.visited[current_p] = true;
335                batch.cluster_id[current_p] = cluster_id;
336
337                Self::region_query_into(ctx, current_p, batch, neighbors);
338                if neighbors.len() >= self.config.min_points {
339                    seeds.extend_from_slice(neighbors);
340                }
341            } else if batch.cluster_id[current_p] == -1 {
342                batch.cluster_id[current_p] = cluster_id;
343            }
344        }
345    }
346}
347
/// Convert a floating-point value to `u32`, rounding to the nearest integer
/// and saturating at the bounds of the type.
///
/// Non-positive values and `NaN` map to `0`; values at or above `u32::MAX`
/// (including `+inf`) map to `u32::MAX`. Exact `.5` ties round away from
/// zero.
// The value is range-checked before the cast, so the cast cannot lose the
// sign or wrap.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn float_to_u32(value: f64) -> u32 {
    if value.is_nan() || value <= 0.0 {
        return 0;
    }
    if value >= f64::from(u32::MAX) {
        return u32::MAX;
    }
    // Here 0 < value < u32::MAX, so the rounded value always fits.
    value.round() as u32
}
357
/// Convert a floating-point value to `usize`, rounding to the nearest integer
/// and saturating at the bounds of the type.
///
/// Non-positive values and `NaN` map to `0`; values too large for `usize`
/// (including `+inf`) map to `usize::MAX`. Exact `.5` ties round away from
/// zero.
// Float-to-integer `as` casts saturate, which is exactly the clamping
// behaviour wanted here.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn float_to_usize(value: f64) -> usize {
    if value.is_nan() || value <= 0.0 {
        return 0;
    }
    value.round() as usize
}