1use crate::dataset::Dataset;
14use crate::distance::{cosine_distance, euclidean_sq, manhattan};
15use crate::error::{Result, ScryLearnError};
16use crate::neighbors::kdtree::KdTree;
17use crate::neighbors::DistanceMetric;
18
/// Largest feature count for which `Dbscan::fit` builds a KD-tree for its
/// Euclidean neighbor queries; datasets with more features (and all
/// non-Euclidean metrics) use a linear scan over every row instead.
const KDTREE_MAX_DIM: usize = 20;
21
/// DBSCAN clustering model: configuration plus the state produced by the
/// most recent successful `fit`.
///
/// Cluster labels are `i32`; `-1` marks a sample that belongs to no cluster.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Dbscan {
    /// Neighborhood radius, expressed in the units of `metric`.
    eps: f64,
    /// Minimum neighborhood size for a point to be treated as a core point.
    min_samples: usize,
    /// Distance metric used for all neighbor searches.
    metric: DistanceMetric,
    /// Label of every training sample from the last fit (`-1` = noise).
    labels: Vec<i32>,
    /// Number of clusters found by the last fit.
    n_clusters: usize,
    /// Feature rows of the core points, kept so `predict` can label new data.
    core_features: Vec<Vec<f64>>,
    /// Cluster label of each entry in `core_features` (same order).
    core_labels: Vec<i32>,
    /// Set to `true` once `fit` has completed successfully.
    fitted: bool,
    /// Schema version recorded at construction and checked by `predict`;
    /// falls back to the type's default when absent from serialized data.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
60
61impl Dbscan {
62 pub fn new(eps: f64, min_samples: usize) -> Self {
69 Self {
70 eps,
71 min_samples,
72 metric: DistanceMetric::Euclidean,
73 labels: Vec::new(),
74 n_clusters: 0,
75 core_features: Vec::new(),
76 core_labels: Vec::new(),
77 fitted: false,
78 _schema_version: crate::version::SCHEMA_VERSION,
79 }
80 }
81
82 pub fn metric(mut self, m: DistanceMetric) -> Self {
88 self.metric = m;
89 self
90 }
91
92 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
97 data.validate_finite()?;
98 let n = data.n_samples();
99 if n == 0 {
100 return Err(ScryLearnError::EmptyDataset);
101 }
102
103 let rows = data.feature_matrix();
104 let n_features = data.n_features();
105 let threshold = self.eps_threshold();
106
107 let use_kdtree =
108 matches!(self.metric, DistanceMetric::Euclidean) && n_features <= KDTREE_MAX_DIM;
109
110 let kdtree = if use_kdtree {
111 Some(KdTree::build(&rows))
112 } else {
113 None
114 };
115
116 let mut labels = vec![-1i32; n];
117 let mut cluster_id = 0i32;
118
119 for i in 0..n {
120 if labels[i] != -1 {
121 continue;
122 }
123
124 let neighbors = self.find_neighbors(i, &rows, threshold, kdtree.as_ref());
126
127 if neighbors.len() < self.min_samples {
128 continue; }
130
131 labels[i] = cluster_id;
133 let mut queue: Vec<usize> = neighbors.into_iter().filter(|&j| j != i).collect();
134 let mut qi = 0;
135
136 while qi < queue.len() {
137 let j = queue[qi];
138 qi += 1;
139
140 if labels[j] == -1 {
141 labels[j] = cluster_id;
142 }
143 if labels[j] != cluster_id {
144 continue;
145 }
146
147 let j_neighbors = self.find_neighbors(j, &rows, threshold, kdtree.as_ref());
149
150 if j_neighbors.len() >= self.min_samples {
151 for k in j_neighbors {
152 if labels[k] == -1 {
153 labels[k] = cluster_id;
154 queue.push(k);
155 }
156 }
157 }
158 }
159
160 cluster_id += 1;
161 }
162
163 let mut core_features = Vec::new();
165 let mut core_labels = Vec::new();
166 for i in 0..n {
167 if labels[i] >= 0 {
168 let neighbors = self.find_neighbors(i, &rows, threshold, kdtree.as_ref());
169 if neighbors.len() >= self.min_samples {
170 core_features.push(rows[i].clone());
171 core_labels.push(labels[i]);
172 }
173 }
174 }
175
176 self.labels = labels;
177 self.n_clusters = cluster_id as usize;
178 self.core_features = core_features;
179 self.core_labels = core_labels;
180 self.fitted = true;
181 Ok(())
182 }
183
184 fn find_neighbors(
186 &self,
187 idx: usize,
188 rows: &[Vec<f64>],
189 threshold: f64,
190 kdtree: Option<&KdTree>,
191 ) -> Vec<usize> {
192 kdtree.map_or_else(
193 || {
194 let n = rows.len();
196 (0..n)
197 .filter(|&j| self.distance(&rows[idx], &rows[j]) <= threshold)
198 .collect()
199 },
200 |tree| {
201 tree.query_radius(&rows[idx], threshold, rows)
203 },
204 )
205 }
206
207 #[inline]
212 fn distance(&self, a: &[f64], b: &[f64]) -> f64 {
213 match self.metric {
214 DistanceMetric::Euclidean => euclidean_sq(a, b),
215 DistanceMetric::Manhattan => manhattan(a, b),
216 DistanceMetric::Cosine => cosine_distance(a, b),
217 }
218 }
219
220 #[inline]
225 fn eps_threshold(&self) -> f64 {
226 match self.metric {
227 DistanceMetric::Euclidean => self.eps * self.eps,
228 DistanceMetric::Manhattan | DistanceMetric::Cosine => self.eps,
229 }
230 }
231
232 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<i32>> {
258 crate::version::check_schema_version(self._schema_version)?;
259 if !self.fitted {
260 return Err(ScryLearnError::NotFitted);
261 }
262
263 let threshold = self.eps_threshold();
264
265 Ok(features
266 .iter()
267 .map(|query| {
268 let mut best_dist = f64::INFINITY;
269 let mut best_label = -1i32;
270
271 for (i, core_pt) in self.core_features.iter().enumerate() {
272 let d = self.distance(query, core_pt);
273 if d <= threshold && d < best_dist {
274 best_dist = d;
275 best_label = self.core_labels[i];
276 }
277 }
278
279 best_label
280 })
281 .collect())
282 }
283
284 pub fn labels(&self) -> &[i32] {
286 &self.labels
287 }
288
289 pub fn n_clusters(&self) -> usize {
291 self.n_clusters
292 }
293
294 pub fn n_noise(&self) -> usize {
296 self.labels.iter().filter(|&&l| l == -1).count()
297 }
298
299 pub fn n_core_points(&self) -> usize {
301 self.core_features.len()
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Two well-separated 2-D blobs of ten points each: one inside
    /// `[0, 2) x [0, 2)` and one inside `[50, 52) x [50, 52)`.
    fn two_blob_dataset() -> Dataset {
        let mut rng = crate::rng::FastRng::new(0);
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        for &offset in &[0.0, 50.0] {
            for _ in 0..10 {
                // x is drawn before y for every point.
                f1.push(offset + rng.f64() * 2.0);
                f2.push(offset + rng.f64() * 2.0);
            }
        }

        Dataset::new(
            vec![f1, f2],
            vec![0.0; 20],
            vec!["x".into(), "y".into()],
            "label",
        )
    }

    #[test]
    fn test_dbscan_two_clusters() {
        let data = two_blob_dataset();

        let mut db = Dbscan::new(5.0, 3);
        db.fit(&data).unwrap();

        assert_eq!(db.n_clusters(), 2, "should find 2 clusters");
    }

    #[test]
    fn test_dbscan_noise() {
        // Three isolated points: none has a second point within eps.
        let data = Dataset::new(
            vec![vec![0.0, 100.0, 200.0], vec![0.0, 100.0, 200.0]],
            vec![0.0; 3],
            vec!["x".into(), "y".into()],
            "label",
        );

        let mut db = Dbscan::new(1.0, 2);
        db.fit(&data).unwrap();

        assert_eq!(db.n_noise(), 3, "all points should be noise");
    }

    #[test]
    fn test_dbscan_kdtree_parity() {
        let mut rng = crate::rng::FastRng::new(42);
        let n = 100;
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        // Two dense blobs ...
        for _ in 0..40 {
            f1.push(rng.f64() * 3.0);
            f2.push(rng.f64() * 3.0);
        }
        for _ in 0..40 {
            f1.push(20.0 + rng.f64() * 3.0);
            f2.push(20.0 + rng.f64() * 3.0);
        }
        // ... plus scattered background points.
        for _ in 0..20 {
            f1.push(rng.f64() * 100.0);
            f2.push(rng.f64() * 100.0);
        }

        // Appending constant-zero columns leaves every pairwise Euclidean
        // distance unchanged but pushes the feature count past
        // `KDTREE_MAX_DIM`, which makes `fit` fall back to the linear scan.
        let n_dims = KDTREE_MAX_DIM + 1;
        let mut padded_columns = vec![f1.clone(), f2.clone()];
        padded_columns.resize(n_dims, vec![0.0; n]);

        let data = Dataset::new(
            vec![f1, f2],
            vec![0.0; n],
            vec!["x".into(), "y".into()],
            "label",
        );
        let padded = Dataset::new(
            padded_columns,
            vec![0.0; n],
            (0..n_dims).map(|d| format!("f{}", d)).collect(),
            "label",
        );

        let mut db_kd = Dbscan::new(4.0, 3);
        db_kd.fit(&data).unwrap();

        let mut db_kd2 = Dbscan::new(4.0, 3);
        db_kd2.fit(&data).unwrap();

        let mut db_brute = Dbscan::new(4.0, 3);
        db_brute.fit(&padded).unwrap();

        assert_eq!(
            db_kd.labels(),
            db_kd2.labels(),
            "DBSCAN should be deterministic"
        );
        assert_eq!(
            db_kd.labels(),
            db_brute.labels(),
            "KD-tree and brute-force neighbor search should give identical labels"
        );
        assert_eq!(
            db_kd.n_core_points(),
            db_brute.n_core_points(),
            "KD-tree and brute-force neighbor search should find the same core points"
        );
        assert!(db_kd.n_clusters() >= 2, "should find at least 2 clusters");
    }

    #[test]
    fn test_dbscan_predict() {
        // Three coincident points at the origin and three at (10, 10).
        let data = Dataset::new(
            vec![
                vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
                vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
            ],
            vec![0.0; 6],
            vec!["x".into(), "y".into()],
            "label",
        );

        let mut db = Dbscan::new(5.0, 2);
        db.fit(&data).unwrap();

        assert_eq!(db.n_clusters(), 2);

        let near_a = db.predict(&[vec![0.5, 0.5]]).unwrap();
        assert!(near_a[0] >= 0, "Should be assigned to cluster A");

        let near_b = db.predict(&[vec![10.5, 10.5]]).unwrap();
        assert!(near_b[0] >= 0, "Should be assigned to cluster B");

        assert_ne!(near_a[0], near_b[0], "Different clusters");

        let far = db.predict(&[vec![500.0, 500.0]]).unwrap();
        assert_eq!(far[0], -1, "Far point should be noise");
    }

    #[test]
    fn test_dbscan_manhattan() {
        let data = two_blob_dataset();

        let mut db = Dbscan::new(5.0, 3).metric(DistanceMetric::Manhattan);
        db.fit(&data).unwrap();

        assert_eq!(
            db.n_clusters(),
            2,
            "Manhattan DBSCAN should find 2 clusters"
        );
    }
}