1use rayon::prelude::*;
14
15use crate::accel;
16use crate::constants::KNN_PAR_THRESHOLD;
17use crate::dataset::Dataset;
18use crate::distance::{
19 cosine_distance, euclidean_sq, manhattan, sparse_cosine, sparse_euclidean_sq, sparse_manhattan,
20};
21use crate::error::{Result, ScryLearnError};
22use crate::neighbors::kdtree::KdTree;
23use crate::sparse::{CsrMatrix, SparseRow};
24use crate::weights::{compute_sample_weights, ClassWeight};
25
/// Distance metric used to compare feature vectors during neighbor search.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DistanceMetric {
    /// Straight-line (L2) distance; ranking comparisons use the squared
    /// distance to skip the square root.
    Euclidean,
    /// Sum of absolute coordinate differences (L1).
    Manhattan,
    /// One minus cosine similarity — magnitude-invariant, direction only.
    Cosine,
}
46
/// How neighbor contributions are weighted when aggregating predictions.
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum WeightFunction {
    /// Every neighbor counts equally.
    #[default]
    Uniform,
    /// Neighbors are weighted by inverse distance; an exact match
    /// (distance below `f64::EPSILON`) takes over entirely.
    Distance,
}
72
/// Strategy used for the nearest-neighbor search.
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Algorithm {
    /// Choose automatically: a k-d tree for dense, low-dimensional
    /// Euclidean data (see `should_use_kdtree`), brute force otherwise.
    #[default]
    Auto,
    /// Always scan every training point.
    BruteForce,
    /// Use a k-d tree index (only honored for the Euclidean metric).
    KDTree,
}
101
/// k-nearest-neighbors classifier (majority / distance-weighted vote).
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct KnnClassifier {
    /// Number of neighbors consulted per prediction.
    k: usize,
    /// Distance metric used for neighbor search.
    metric: DistanceMetric,
    /// Uniform vs. inverse-distance weighting of neighbor votes.
    weight_fn: WeightFunction,
    /// Per-class sample weighting scheme applied at fit time.
    class_weight: ClassWeight,
    /// Neighbor-search algorithm preference.
    algorithm: Algorithm,
    /// Dense training rows (left empty when fitted on sparse data).
    train_features: Vec<Vec<f64>>,
    /// Training labels, one per sample (class indices stored as f64).
    train_target: Vec<f64>,
    /// Per-sample weights derived from `class_weight` at fit time.
    train_weights: Vec<f64>,
    /// Number of distinct classes observed during fit.
    n_classes: usize,
    /// Optional k-d tree index over `train_features`.
    kdtree: Option<KdTree>,
    /// Sparse training matrix, set when fitted on sparse data.
    train_sparse: Option<CsrMatrix>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
    /// Schema version stamped at construction; validated on predict.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
151
152impl KnnClassifier {
153 pub fn new() -> Self {
155 Self {
156 k: 5,
157 metric: DistanceMetric::Euclidean,
158 weight_fn: WeightFunction::Uniform,
159 class_weight: ClassWeight::Uniform,
160 algorithm: Algorithm::Auto,
161 train_features: Vec::new(),
162 train_target: Vec::new(),
163 train_weights: Vec::new(),
164 n_classes: 0,
165 kdtree: None,
166 train_sparse: None,
167 fitted: false,
168 _schema_version: crate::version::SCHEMA_VERSION,
169 }
170 }
171
172 pub fn k(mut self, k: usize) -> Self {
174 self.k = k;
175 self
176 }
177
178 pub fn metric(mut self, m: DistanceMetric) -> Self {
180 self.metric = m;
181 self
182 }
183
184 pub fn weights(mut self, w: WeightFunction) -> Self {
189 self.weight_fn = w;
190 self
191 }
192
193 pub fn class_weight(mut self, cw: ClassWeight) -> Self {
195 self.class_weight = cw;
196 self
197 }
198
199 pub fn algorithm(mut self, algo: Algorithm) -> Self {
207 self.algorithm = algo;
208 self
209 }
210
211 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
217 data.validate_finite()?;
218 if data.n_samples() == 0 {
219 return Err(ScryLearnError::EmptyDataset);
220 }
221
222 if let Some(csr) = data.sparse_csr() {
224 self.train_sparse = Some(csr);
225 self.train_features = Vec::new(); } else {
227 self.train_sparse = None;
228 self.train_features = data.feature_matrix();
229 }
230
231 self.train_target.clone_from(&data.target);
232 self.train_weights = compute_sample_weights(&data.target, &self.class_weight);
233 self.n_classes = data.n_classes();
234
235 self.kdtree = if self.train_sparse.is_none()
237 && should_use_kdtree(self.algorithm, self.metric, data.n_features())
238 {
239 Some(KdTree::build(&self.train_features))
240 } else {
241 None
242 };
243
244 self.fitted = true;
245 Ok(())
246 }
247
248 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
254 crate::version::check_schema_version(self._schema_version)?;
255 if !self.fitted {
256 return Err(ScryLearnError::NotFitted);
257 }
258 if self.train_features.is_empty() && self.train_sparse.is_some() {
259 return Err(ScryLearnError::InvalidParameter(
260 "model was trained on sparse data; use predict_sparse() instead".into(),
261 ));
262 }
263 let probas = self.compute_votes(features);
264 Ok(probas
265 .into_iter()
266 .map(|votes| {
267 votes
270 .iter()
271 .enumerate()
272 .fold((0usize, f64::NEG_INFINITY), |(best_i, best_v), (i, &v)| {
273 if v > best_v {
274 (i, v)
275 } else {
276 (best_i, best_v)
277 }
278 })
279 .0 as f64
280 })
281 .collect())
282 }
283
284 pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
311 if !self.fitted {
312 return Err(ScryLearnError::NotFitted);
313 }
314 if self.train_features.is_empty() && self.train_sparse.is_some() {
315 return Err(ScryLearnError::InvalidParameter(
316 "model was trained on sparse data; use predict_sparse() instead".into(),
317 ));
318 }
319 let votes = self.compute_votes(features);
320 Ok(votes
321 .into_iter()
322 .map(|v| {
323 let total: f64 = v.iter().sum();
324 if total > 0.0 {
325 v.iter().map(|&x| x / total).collect()
326 } else {
327 let n = v.len() as f64;
329 vec![1.0 / n; v.len()]
330 }
331 })
332 .collect())
333 }
334
335 #[allow(clippy::option_if_let_else)]
344 fn compute_votes(&self, features: &[Vec<f64>]) -> Vec<Vec<f64>> {
345 let k = self.k.min(self.train_features.len());
346 let use_actual_dist = matches!(self.weight_fn, WeightFunction::Distance);
347 let metric = self.metric;
348
349 let batched = if self.kdtree.is_none() && matches!(metric, DistanceMetric::Euclidean) {
351 batched_brute_force_neighbors(features, &self.train_features, k, use_actual_dist)
352 } else {
353 None
354 };
355
356 if let Some(all_neighbors) = batched {
357 all_neighbors
359 .into_iter()
360 .map(|neighbors| {
361 aggregate_votes(
362 &neighbors,
363 &self.train_target,
364 &self.train_weights,
365 self.n_classes,
366 use_actual_dist,
367 )
368 })
369 .collect()
370 } else {
371 let n_train = self.train_features.len();
373 let n_features = if n_train > 0 {
374 self.train_features[0].len()
375 } else {
376 0
377 };
378 let use_par =
379 self.kdtree.is_none() && features.len() * n_train * n_features >= KNN_PAR_THRESHOLD;
380
381 let vote_fn = |query: &Vec<f64>| {
382 let neighbors: Vec<(f64, usize)> = if let Some(ref tree) = self.kdtree {
383 let raw = tree.query_k_nearest(query, k, &self.train_features);
384 if use_actual_dist {
385 raw.into_iter().map(|(d2, i)| (d2.sqrt(), i)).collect()
386 } else {
387 raw
388 }
389 } else {
390 scalar_brute_force(query, &self.train_features, k, metric, use_actual_dist)
391 };
392
393 aggregate_votes(
394 &neighbors,
395 &self.train_target,
396 &self.train_weights,
397 self.n_classes,
398 use_actual_dist,
399 )
400 };
401
402 if use_par {
403 features.par_iter().map(vote_fn).collect()
404 } else {
405 features.iter().map(vote_fn).collect()
406 }
407 }
408 }
409}
410
411impl KnnClassifier {
412 pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
417 if !self.fitted {
418 return Err(ScryLearnError::NotFitted);
419 }
420 let n_train = self.train_target.len();
421 let k = self.k.min(n_train);
422 let use_actual_dist = matches!(self.weight_fn, WeightFunction::Distance);
423
424 Ok((0..features.n_rows())
425 .map(|i| {
426 let query = features.row(i);
427 let neighbors = if let Some(ref train_csr) = self.train_sparse {
428 sparse_brute_force(&query, train_csr, k, self.metric, use_actual_dist)
429 } else {
430 let dense = sparse_row_to_dense(&query, features.n_cols());
431 scalar_brute_force(
432 &dense,
433 &self.train_features,
434 k,
435 self.metric,
436 use_actual_dist,
437 )
438 };
439 let votes = aggregate_votes(
440 &neighbors,
441 &self.train_target,
442 &self.train_weights,
443 self.n_classes,
444 use_actual_dist,
445 );
446 votes
447 .iter()
448 .enumerate()
449 .fold((0usize, f64::NEG_INFINITY), |(best_i, best_v), (i, &v)| {
450 if v > best_v {
451 (i, v)
452 } else {
453 (best_i, best_v)
454 }
455 })
456 .0 as f64
457 })
458 .collect())
459 }
460}
461
impl Default for KnnClassifier {
    /// Equivalent to [`KnnClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
467
/// k-nearest-neighbors regressor (mean or inverse-distance-weighted mean
/// of neighbor targets).
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct KnnRegressor {
    /// Number of neighbors consulted per prediction.
    k: usize,
    /// Distance metric used for neighbor search.
    metric: DistanceMetric,
    /// Uniform vs. inverse-distance weighting of neighbor targets.
    weight_fn: WeightFunction,
    /// Neighbor-search algorithm preference.
    algorithm: Algorithm,
    /// Dense training rows (left empty when fitted on sparse data).
    train_features: Vec<Vec<f64>>,
    /// Training targets, one per sample.
    train_target: Vec<f64>,
    /// Optional k-d tree index over `train_features`.
    kdtree: Option<KdTree>,
    /// Sparse training matrix, set when fitted on sparse data.
    train_sparse: Option<CsrMatrix>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
    /// Schema version stamped at construction; validated on predict.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
513
impl KnnRegressor {
    /// Creates a regressor with default settings: `k = 5`, Euclidean
    /// distance, uniform neighbor weighting, automatic algorithm selection.
    pub fn new() -> Self {
        Self {
            k: 5,
            metric: DistanceMetric::Euclidean,
            weight_fn: WeightFunction::Uniform,
            algorithm: Algorithm::Auto,
            train_features: Vec::new(),
            train_target: Vec::new(),
            kdtree: None,
            train_sparse: None,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Sets the number of neighbors consulted per query (builder style).
    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Sets the distance metric (builder style).
    pub fn metric(mut self, m: DistanceMetric) -> Self {
        self.metric = m;
        self
    }

    /// Sets how neighbor targets are weighted (builder style).
    pub fn weights(mut self, w: WeightFunction) -> Self {
        self.weight_fn = w;
        self
    }

    /// Sets the neighbor-search algorithm preference (builder style).
    pub fn algorithm(mut self, algo: Algorithm) -> Self {
        self.algorithm = algo;
        self
    }

    /// Fits the regressor by memorizing the training data and, for dense
    /// data where `should_use_kdtree` approves, building a k-d tree index.
    ///
    /// # Errors
    /// Returns an error when the dataset contains non-finite values or has
    /// no samples.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if data.n_samples() == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }

        if let Some(csr) = data.sparse_csr() {
            self.train_sparse = Some(csr);
            self.train_features = Vec::new();
        } else {
            self.train_sparse = None;
            self.train_features = data.feature_matrix();
        }

        self.train_target.clone_from(&data.target);

        // Sparse training data never gets a tree; search stays brute-force.
        self.kdtree = if self.train_sparse.is_none()
            && should_use_kdtree(self.algorithm, self.metric, data.n_features())
        {
            Some(KdTree::build(&self.train_features))
        } else {
            None
        };

        self.fitted = true;
        Ok(())
    }

    /// Predicts a target value for every dense query row, picking the
    /// fastest available path: batched accelerated distances (brute force
    /// + Euclidean only), the k-d tree if one was built, or per-query
    /// scalar brute force (parallelized via rayon above a work threshold).
    ///
    /// # Errors
    /// Fails when the model is unfitted, was trained on sparse data (use
    /// [`Self::predict_sparse`]), or was saved under an incompatible
    /// schema version.
    #[allow(clippy::option_if_let_else)]
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        if self.train_features.is_empty() && self.train_sparse.is_some() {
            return Err(ScryLearnError::InvalidParameter(
                "model was trained on sparse data; use predict_sparse() instead".into(),
            ));
        }

        // k can never exceed the number of training samples.
        let k = self.k.min(self.train_features.len());
        let use_actual_dist = matches!(self.weight_fn, WeightFunction::Distance);
        let metric = self.metric;

        // Batched pairwise-distance path; `None` means fall back to
        // per-query search below.
        let batched = if self.kdtree.is_none() && matches!(metric, DistanceMetric::Euclidean) {
            batched_brute_force_neighbors(features, &self.train_features, k, use_actual_dist)
        } else {
            None
        };

        let get_neighbors = |query: &Vec<f64>| -> Vec<(f64, usize)> {
            if let Some(ref tree) = self.kdtree {
                let raw = tree.query_k_nearest(query, k, &self.train_features);
                // The tree reports squared distances; take roots only when
                // they feed inverse-distance weighting.
                if use_actual_dist {
                    raw.into_iter().map(|(d2, i)| (d2.sqrt(), i)).collect()
                } else {
                    raw
                }
            } else {
                scalar_brute_force(query, &self.train_features, k, metric, use_actual_dist)
            }
        };

        if let Some(ref all) = batched {
            Ok(features
                .iter()
                .enumerate()
                .map(|(qi, _query)| {
                    aggregate_regression(&all[qi], &self.train_target, use_actual_dist, k)
                })
                .collect())
        } else {
            let n_train = self.train_features.len();
            let n_features = if n_train > 0 {
                self.train_features[0].len()
            } else {
                0
            };
            // Parallelize only when the total distance work is large
            // enough to amortize rayon's scheduling overhead.
            let use_par =
                self.kdtree.is_none() && features.len() * n_train * n_features >= KNN_PAR_THRESHOLD;

            let predict_fn = |query: &Vec<f64>| {
                let neighbors = get_neighbors(query);
                aggregate_regression(&neighbors, &self.train_target, use_actual_dist, k)
            };

            if use_par {
                Ok(features.par_iter().map(predict_fn).collect())
            } else {
                Ok(features.iter().map(predict_fn).collect())
            }
        }
    }
}
667
668impl KnnRegressor {
669 pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
673 if !self.fitted {
674 return Err(ScryLearnError::NotFitted);
675 }
676 let n_train = self.train_target.len();
677 let k = self.k.min(n_train);
678 let use_actual_dist = matches!(self.weight_fn, WeightFunction::Distance);
679
680 Ok((0..features.n_rows())
681 .map(|i| {
682 let query = features.row(i);
683 let neighbors = if let Some(ref train_csr) = self.train_sparse {
684 sparse_brute_force(&query, train_csr, k, self.metric, use_actual_dist)
685 } else {
686 let dense = sparse_row_to_dense(&query, features.n_cols());
687 scalar_brute_force(
688 &dense,
689 &self.train_features,
690 k,
691 self.metric,
692 use_actual_dist,
693 )
694 };
695 aggregate_regression(&neighbors, &self.train_target, use_actual_dist, k)
696 })
697 .collect())
698 }
699}
700
impl Default for KnnRegressor {
    /// Equivalent to [`KnnRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
706
707fn scalar_brute_force(
715 query: &[f64],
716 train: &[Vec<f64>],
717 k: usize,
718 metric: DistanceMetric,
719 use_actual_dist: bool,
720) -> Vec<(f64, usize)> {
721 let mut dists: Vec<(f64, usize)> = train
722 .iter()
723 .enumerate()
724 .map(|(i, train_row)| {
725 let d = if use_actual_dist {
726 actual_distance(query, train_row, metric)
727 } else {
728 distance_for_compare(query, train_row, metric)
729 };
730 (d, i)
731 })
732 .collect();
733
734 if k < dists.len() {
735 dists.select_nth_unstable_by(k - 1, |a, b| {
736 a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
737 });
738 }
739 dists.truncate(k);
740 dists.sort_by(|a, b| {
743 a.0.partial_cmp(&b.0)
744 .unwrap_or(std::cmp::Ordering::Equal)
745 .then(a.1.cmp(&b.1))
746 });
747 dists
748}
749
750fn batched_brute_force_neighbors(
757 queries: &[Vec<f64>],
758 train: &[Vec<f64>],
759 k: usize,
760 use_actual_dist: bool,
761) -> Option<Vec<Vec<(f64, usize)>>> {
762 let n_q = queries.len();
763 let n_t = train.len();
764 if n_q == 0 || n_t == 0 {
765 return None;
766 }
767 let dim = queries[0].len();
768
769 if n_q * n_t < 256 {
772 return None;
773 }
774
775 let q_flat: Vec<f64> = queries.iter().flat_map(|r| r.iter().copied()).collect();
777 let t_flat: Vec<f64> = train.iter().flat_map(|r| r.iter().copied()).collect();
778
779 let backend = accel::auto();
780 let dist_matrix = backend.pairwise_distances_squared(&q_flat, &t_flat, n_q, n_t, dim);
781
782 let result: Vec<Vec<(f64, usize)>> = (0..n_q)
783 .map(|qi| {
784 let row = &dist_matrix[qi * n_t..(qi + 1) * n_t];
785 let mut indexed: Vec<(f64, usize)> = row
786 .iter()
787 .enumerate()
788 .map(|(j, &d2)| {
789 let d = if use_actual_dist { d2.sqrt() } else { d2 };
790 (d, j)
791 })
792 .collect();
793
794 let k_eff = k.min(indexed.len());
795 if k_eff < indexed.len() {
796 indexed.select_nth_unstable_by(k_eff - 1, |a, b| {
797 a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
798 });
799 }
800 indexed.truncate(k_eff);
801 indexed.sort_by(|a, b| {
803 a.0.partial_cmp(&b.0)
804 .unwrap_or(std::cmp::Ordering::Equal)
805 .then(a.1.cmp(&b.1))
806 });
807 indexed
808 })
809 .collect();
810
811 Some(result)
812}
813
/// Accumulates per-class vote mass from a neighbor list.
///
/// Uniform weighting adds each neighbor's sample weight to its class.
/// Distance weighting adds `weight / distance` instead — unless some
/// neighbor coincides with the query (distance below `f64::EPSILON`), in
/// which case only those exact matches vote, avoiding division by ~zero.
/// Labels outside `0..n_classes` are silently ignored; the result always
/// has at least one slot.
fn aggregate_votes(
    neighbors: &[(f64, usize)],
    target: &[f64],
    weights: &[f64],
    n_classes: usize,
    use_actual_dist: bool,
) -> Vec<f64> {
    let mut votes = vec![0.0_f64; n_classes.max(1)];
    // When distance-weighting and the query sits exactly on a training
    // point, restrict voting to those exact matches.
    let exact_only = use_actual_dist && neighbors.iter().any(|&(d, _)| d < f64::EPSILON);

    for &(d, idx) in neighbors {
        if exact_only && d >= f64::EPSILON {
            continue;
        }
        let class = target[idx] as usize;
        if class >= votes.len() {
            continue;
        }
        let contribution = if use_actual_dist && !exact_only {
            weights[idx] / d
        } else {
            weights[idx]
        };
        votes[class] += contribution;
    }

    votes
}
857
/// Combines neighbor target values into one regression prediction.
///
/// Uniform weighting returns the plain mean over `k` neighbors. Distance
/// weighting returns the inverse-distance-weighted mean, except when some
/// neighbor coincides with the query (distance below `f64::EPSILON`): then
/// only those exact matches are averaged, avoiding division by ~zero.
fn aggregate_regression(
    neighbors: &[(f64, usize)],
    target: &[f64],
    use_actual_dist: bool,
    k: usize,
) -> f64 {
    if !use_actual_dist {
        // Uniform: plain mean over the k neighbors.
        let sum: f64 = neighbors.iter().map(|&(_, idx)| target[idx]).sum();
        return sum / k as f64;
    }

    let exact: Vec<usize> = neighbors
        .iter()
        .filter(|&&(d, _)| d < f64::EPSILON)
        .map(|&(_, idx)| idx)
        .collect();

    if exact.is_empty() {
        // Inverse-distance weighted mean.
        let mut weighted_sum = 0.0;
        let mut total_w = 0.0;
        for &(d, idx) in neighbors {
            let w = 1.0 / d;
            weighted_sum += w * target[idx];
            total_w += w;
        }
        weighted_sum / total_w
    } else {
        // Query sits on training point(s): average only those targets.
        let sum: f64 = exact.iter().map(|&idx| target[idx]).sum();
        sum / exact.len() as f64
    }
}
889
/// Returns a value suitable only for *comparing* distances under `metric`.
///
/// For Euclidean this is the squared distance — monotonic in the true
/// distance, so neighbor ranking is unaffected while the sqrt is skipped;
/// Manhattan and cosine are already returned as their true values.
#[inline]
fn distance_for_compare(a: &[f64], b: &[f64], metric: DistanceMetric) -> f64 {
    match metric {
        DistanceMetric::Euclidean => euclidean_sq(a, b),
        DistanceMetric::Manhattan => manhattan(a, b),
        DistanceMetric::Cosine => cosine_distance(a, b),
    }
}
906
/// Returns the true distance under `metric`, suitable for use as an
/// inverse-distance weight (Euclidean takes the square root here).
#[inline]
fn actual_distance(a: &[f64], b: &[f64], metric: DistanceMetric) -> f64 {
    match metric {
        DistanceMetric::Euclidean => euclidean_sq(a, b).sqrt(),
        DistanceMetric::Manhattan => manhattan(a, b),
        DistanceMetric::Cosine => cosine_distance(a, b),
    }
}
919
920fn sparse_row_to_dense(row: &SparseRow<'_>, n_cols: usize) -> Vec<f64> {
922 let mut dense = vec![0.0; n_cols];
923 for (col, val) in row.iter() {
924 dense[col] = val;
925 }
926 dense
927}
928
/// Sparse counterpart of `distance_for_compare`: a comparison-only value
/// (squared distance for Euclidean; true values for the other metrics).
#[inline]
fn sparse_distance_for_compare(
    a: &SparseRow<'_>,
    b: &SparseRow<'_>,
    metric: DistanceMetric,
) -> f64 {
    match metric {
        DistanceMetric::Euclidean => sparse_euclidean_sq(a, b),
        DistanceMetric::Manhattan => sparse_manhattan(a, b),
        DistanceMetric::Cosine => sparse_cosine(a, b),
    }
}
942
/// Sparse counterpart of `actual_distance`: the true distance under
/// `metric`, usable as an inverse-distance weight.
#[inline]
fn sparse_actual_distance(a: &SparseRow<'_>, b: &SparseRow<'_>, metric: DistanceMetric) -> f64 {
    match metric {
        DistanceMetric::Euclidean => sparse_euclidean_sq(a, b).sqrt(),
        DistanceMetric::Manhattan => sparse_manhattan(a, b),
        DistanceMetric::Cosine => sparse_cosine(a, b),
    }
}
952
953fn sparse_brute_force(
955 query: &SparseRow<'_>,
956 train: &CsrMatrix,
957 k: usize,
958 metric: DistanceMetric,
959 use_actual_dist: bool,
960) -> Vec<(f64, usize)> {
961 let n = train.n_rows();
962 let mut dists: Vec<(f64, usize)> = (0..n)
963 .map(|i| {
964 let train_row = train.row(i);
965 let d = if use_actual_dist {
966 sparse_actual_distance(query, &train_row, metric)
967 } else {
968 sparse_distance_for_compare(query, &train_row, metric)
969 };
970 (d, i)
971 })
972 .collect();
973
974 if k < dists.len() {
975 dists.select_nth_unstable_by(k - 1, |a, b| {
976 a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
977 });
978 }
979 dists.truncate(k);
980 dists.sort_by(|a, b| {
981 a.0.partial_cmp(&b.0)
982 .unwrap_or(std::cmp::Ordering::Equal)
983 .then(a.1.cmp(&b.1))
984 });
985 dists
986}
987
988fn should_use_kdtree(algo: Algorithm, metric: DistanceMetric, n_features: usize) -> bool {
990 match algo {
991 Algorithm::BruteForce => false,
992 Algorithm::KDTree => matches!(metric, DistanceMetric::Euclidean),
993 Algorithm::Auto => matches!(metric, DistanceMetric::Euclidean) && n_features < 20,
994 }
995}
996
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these tests feed `Dataset::new` column-major data —
    // each inner `Vec` is one feature column and `target` has one entry
    // per sample (e.g. 2 columns of length 6 + 6 targets below). Confirm
    // against `Dataset::new`'s documented layout.

    #[test]
    fn test_knn_simple() {
        // Six samples in 2-D: three at the origin (class 0) and three at
        // (10, 10) (class 1).
        let features = vec![
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut knn = KnnClassifier::new().k(3);
        knn.fit(&data).unwrap();

        // Queries near each cluster should adopt that cluster's label.
        let preds = knn.predict(&[vec![1.0, 1.0], vec![9.0, 9.0]]).unwrap();
        assert!((preds[0] - 0.0).abs() < 1e-6);
        assert!((preds[1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_knn_distance_weights() {
        // 1-D data: class 0 has 3 far points, class 1 has 2 points right
        // next to the query. With k = 5 all points vote, so class 0 wins a
        // plain majority — only distance weighting can flip it to class 1.
        let features = vec![vec![5.0, 10.0, 10.0, 0.1, 0.2]];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into()], "class");

        let mut knn_dist = KnnClassifier::new().k(5).weights(WeightFunction::Distance);
        knn_dist.fit(&data).unwrap();
        let preds_d = knn_dist.predict(&[vec![0.15]]).unwrap();
        assert_eq!(
            preds_d[0] as usize, 1,
            "Distance-weighted should pick closer class 1"
        );
    }

    #[test]
    fn test_knn_predict_proba() {
        let features = vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]];
        let target = vec![0.0, 0.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut knn = KnnClassifier::new().k(4);
        knn.fit(&data).unwrap();

        let probas = knn
            .predict_proba(&[vec![1.0, 1.0], vec![5.0, 5.0]])
            .unwrap();
        // Every row must be a valid probability distribution.
        for p in &probas {
            let sum: f64 = p.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-9,
                "Probabilities must sum to 1.0, got {sum}"
            );
        }

        assert!(
            probas[0][0] > 0.4,
            "Expected high prob for class 0 at (1,1)"
        );
    }

    #[test]
    fn test_knn_cosine() {
        // Sanity-check the metric itself before exercising it through the
        // classifier.
        let d_same = cosine_distance(&[1.0, 0.0], &[100.0, 0.0]);
        let d_orth = cosine_distance(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(
            d_same < 1e-9,
            "Same direction should have ~0 distance, got {d_same}"
        );
        assert!(
            (d_orth - 1.0).abs() < 1e-9,
            "Orthogonal should have distance ~1, got {d_orth}"
        );

        // Class 0 points lie along the x-axis, class 1 along the y-axis;
        // magnitudes differ wildly on purpose.
        let features = vec![vec![1.0, 100.0, 0.0, 0.0], vec![0.0, 0.0, 1.0, 100.0]];
        let target = vec![0.0, 0.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut knn = KnnClassifier::new().k(2).metric(DistanceMetric::Cosine);
        knn.fit(&data).unwrap();

        // Under cosine only direction matters, not magnitude.
        let preds = knn.predict(&[vec![50.0, 0.0]]).unwrap();
        assert_eq!(
            preds[0] as usize, 0,
            "Cosine metric should match class 0 by direction"
        );
    }

    #[test]
    fn test_knn_regressor_simple() {
        // Linear data y = 10x; with k = 2 the mean of the two neighbors
        // around each query reproduces the line exactly.
        let features = vec![vec![1.0, 5.0, 9.0]];
        let target = vec![10.0, 50.0, 90.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut knn = KnnRegressor::new().k(2);
        knn.fit(&data).unwrap();

        let preds = knn.predict(&[vec![3.0]]).unwrap();
        assert!(
            (preds[0] - 30.0).abs() < 1e-9,
            "Expected 30.0, got {}",
            preds[0]
        );

        let preds2 = knn.predict(&[vec![7.0]]).unwrap();
        assert!(
            (preds2[0] - 70.0).abs() < 1e-9,
            "Expected 70.0, got {}",
            preds2[0]
        );
    }

    #[test]
    fn test_knn_regressor_distance_weights() {
        // Two points: query at x = 1 is 1 away from (0 -> 0.0) and 9 away
        // from (10 -> 100.0); uniform averages, distance weighting pulls
        // strongly toward the near point.
        let features = vec![vec![0.0, 10.0]];
        let target = vec![0.0, 100.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut knn_u = KnnRegressor::new().k(2);
        knn_u.fit(&data).unwrap();
        let pred_u = knn_u.predict(&[vec![1.0]]).unwrap()[0];
        assert!((pred_u - 50.0).abs() < 1e-9, "Uniform should give 50.0");

        let mut knn_d = KnnRegressor::new().k(2).weights(WeightFunction::Distance);
        knn_d.fit(&data).unwrap();
        let pred_d = knn_d.predict(&[vec![1.0]]).unwrap()[0];
        assert!(
            pred_d < 20.0,
            "Distance-weighted should favor x=0, got {pred_d}"
        );
    }

    #[test]
    fn test_knn_not_fitted() {
        // Predicting before `fit` must fail with an error, not panic.
        let knn = KnnClassifier::new();
        assert!(knn.predict(&[vec![1.0]]).is_err());
        assert!(knn.predict_proba(&[vec![1.0]]).is_err());

        let knn_r = KnnRegressor::new();
        assert!(knn_r.predict(&[vec![1.0]]).is_err());
    }

    #[test]
    fn test_knn_predict_sparse_matches_dense() {
        // Dense-fitted model: predict_sparse densifies each query, so the
        // two prediction paths must agree.
        let features = vec![
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut knn = KnnClassifier::new().k(3);
        knn.fit(&data).unwrap();

        let test = vec![vec![1.0, 1.0], vec![9.0, 9.0]];
        let preds_dense = knn.predict(&test).unwrap();
        let csr = CsrMatrix::from_dense(&test);
        let preds_sparse = knn.predict_sparse(&csr).unwrap();

        for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense={d} vs Sparse={s}");
        }
    }

    #[test]
    fn test_knn_regressor_predict_sparse() {
        // Same dense-vs-sparse agreement check for the regressor.
        let features = vec![vec![1.0, 5.0, 9.0]];
        let target = vec![10.0, 50.0, 90.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut knn = KnnRegressor::new().k(2);
        knn.fit(&data).unwrap();

        let test = vec![vec![3.0], vec![7.0]];
        let preds_dense = knn.predict(&test).unwrap();
        let csr = CsrMatrix::from_dense(&test);
        let preds_sparse = knn.predict_sparse(&csr).unwrap();

        for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense={d} vs Sparse={s}");
        }
    }

    #[test]
    fn test_sparse_euclidean_matches_dense() {
        // |(1,0,3) - (0,2,3)|^2 = 1 + 4 + 0 = 5.
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0, 3.0]]);
        let b = CsrMatrix::from_dense(&[vec![0.0, 2.0, 3.0]]);
        let d2 = sparse_euclidean_sq(&a.row(0), &b.row(0));
        assert!((d2 - 5.0).abs() < 1e-10, "Expected 5.0, got {d2}");
    }

    #[test]
    fn test_sparse_manhattan_matches_dense() {
        // |1-0| + |0-2| + |3-3| = 3.
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0, 3.0]]);
        let b = CsrMatrix::from_dense(&[vec![0.0, 2.0, 3.0]]);
        let d = sparse_manhattan(&a.row(0), &b.row(0));
        assert!((d - 3.0).abs() < 1e-10, "Expected 3.0, got {d}");
    }

    #[test]
    fn test_sparse_cosine_matches_dense() {
        // Parallel vectors -> distance ~0 regardless of magnitude.
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0]]);
        let b = CsrMatrix::from_dense(&[vec![100.0, 0.0]]);
        let d = sparse_cosine(&a.row(0), &b.row(0));
        assert!(d < 1e-9, "Same direction should be ~0, got {d}");

        // Orthogonal vectors -> distance ~1.
        let c = CsrMatrix::from_dense(&[vec![0.0, 1.0]]);
        let d_orth = sparse_cosine(&a.row(0), &c.row(0));
        assert!(
            (d_orth - 1.0).abs() < 1e-9,
            "Orthogonal should be ~1, got {d_orth}"
        );
    }

    #[test]
    fn test_sparse_knn_end_to_end() {
        use crate::sparse::CscMatrix;
        // Fit one model on dense data and another on the same data routed
        // through the sparse (CSC) constructor; predictions must agree.
        let features = vec![
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(
            features.clone(),
            target.clone(),
            vec!["x".into(), "y".into()],
            "class",
        );

        let mut knn_dense = KnnClassifier::new().k(3);
        knn_dense.fit(&data).unwrap();

        let csc = CscMatrix::from_dense(&features);
        let data_sparse = Dataset::from_sparse(csc, target, vec!["x".into(), "y".into()], "class");
        let mut knn_sparse = KnnClassifier::new().k(3);
        knn_sparse.fit(&data_sparse).unwrap();
        assert!(knn_sparse.train_sparse.is_some());

        let test = vec![vec![1.0, 1.0], vec![9.0, 9.0]];
        let preds_dense = knn_dense.predict(&test).unwrap();
        let csr = CsrMatrix::from_dense(&test);
        let preds_sparse = knn_sparse.predict_sparse(&csr).unwrap();

        for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense={d} vs Sparse={s}");
        }
    }

    #[test]
    fn test_high_dimensional_sparse_knn() {
        use crate::sparse::CscMatrix;
        // 100 samples x 5000 features at ~2% density: exercises the fully
        // sparse search path where densifying would be wasteful.
        let n_train = 100;
        let n_feat = 5000;
        let mut rng = crate::rng::FastRng::new(42);

        // Column-major construction: one Vec per feature column.
        let mut cols: Vec<Vec<f64>> = vec![vec![0.0; n_train]; n_feat];
        for col in &mut cols {
            for x in col.iter_mut() {
                if rng.f64() < 0.02 {
                    *x = rng.f64() * 10.0;
                }
            }
        }
        let target: Vec<f64> = (0..n_train).map(|i| (i % 3) as f64).collect();
        let csc = CscMatrix::from_dense(&cols);
        let names: Vec<String> = (0..n_feat).map(|j| format!("f{j}")).collect();
        let data = Dataset::from_sparse(csc, target, names, "class");

        let mut knn = KnnClassifier::new().k(5);
        knn.fit(&data).unwrap();
        assert!(knn.train_sparse.is_some());

        // Random sparse query; the prediction must be one of the 3 classes.
        let mut query_row = vec![0.0; n_feat];
        for x in &mut query_row {
            if rng.f64() < 0.02 {
                *x = rng.f64() * 10.0;
            }
        }
        let query_csr = CsrMatrix::from_dense(&[query_row]);
        let preds = knn.predict_sparse(&query_csr).unwrap();
        assert_eq!(preds.len(), 1);
        assert!(preds[0] >= 0.0 && preds[0] < 3.0);
    }
}
1311
#[cfg(all(test, feature = "scry-gpu"))]
mod gpu_tests {
    use super::*;

    #[test]
    fn gpu_knn_classifier_batched_matches_scalar() {
        // 100 samples x 5 features of deterministic pseudo-data: with 10
        // queries, n_q * n_t = 1000 >= 256, so the batched accelerated
        // path in `compute_votes` is exercised under BruteForce.
        let n_train = 100;
        let n_feat = 5;
        let mut features_col: Vec<Vec<f64>> = Vec::with_capacity(n_feat);
        for j in 0..n_feat {
            let col: Vec<f64> = (0..n_train)
                .map(|i| ((i * (j + 3)) % 37) as f64 * 0.5)
                .collect();
            features_col.push(col);
        }
        let target: Vec<f64> = (0..n_train).map(|i| (i % 3) as f64).collect();
        let names: Vec<String> = (0..n_feat).map(|j| format!("f{j}")).collect();
        let data = Dataset::new(features_col, target, names, "class");

        let mut knn = KnnClassifier::new().k(5).algorithm(Algorithm::BruteForce);
        knn.fit(&data).unwrap();

        let queries: Vec<Vec<f64>> = (0..10)
            .map(|i| (0..n_feat).map(|j| ((i + j) % 17) as f64 * 0.3).collect())
            .collect();

        // Only validity is asserted here: every prediction must be one of
        // the three known classes.
        let preds = knn.predict(&queries).unwrap();
        assert_eq!(preds.len(), 10);
        for p in &preds {
            assert!(
                *p >= 0.0 && *p < 3.0,
                "prediction must be a valid class: {p}"
            );
        }
    }

    #[test]
    fn gpu_knn_regressor_batched_matches_scalar() {
        // Same setup as the classifier test, but for the regressor's
        // batched path; predictions need only be finite.
        let n_train = 100;
        let n_feat = 5;
        let mut features_col: Vec<Vec<f64>> = Vec::with_capacity(n_feat);
        for j in 0..n_feat {
            let col: Vec<f64> = (0..n_train)
                .map(|i| ((i * (j + 2)) % 41) as f64 * 0.2)
                .collect();
            features_col.push(col);
        }
        let target: Vec<f64> = (0..n_train).map(|i| (i % 50) as f64).collect();
        let names: Vec<String> = (0..n_feat).map(|j| format!("f{j}")).collect();
        let data = Dataset::new(features_col, target, names, "y");

        let mut knn = KnnRegressor::new().k(5).algorithm(Algorithm::BruteForce);
        knn.fit(&data).unwrap();

        let queries: Vec<Vec<f64>> = (0..10)
            .map(|i| (0..n_feat).map(|j| ((i + j) % 19) as f64 * 0.4).collect())
            .collect();

        let preds = knn.predict(&queries).unwrap();
        assert_eq!(preds.len(), 10);
        for p in &preds {
            assert!(p.is_finite(), "prediction must be finite: {p}");
        }
    }
}