scry_learn/neighbors/knn.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! K-Nearest Neighbors classifier and regressor.
//!
//! Supports two search algorithms:
//! - **Brute force**: O(n) per query. Always correct, any metric.
//! - **KD-tree**: O(log n) average per query. Euclidean only, best for < 20 features.
//!
//! Optimizations:
//! - Uses squared Euclidean distance (avoids sqrt — monotonic, same ordering).
//! - Uses `select_nth_unstable` for partial sort (O(n) vs O(n·log n)).
//! - Fixed-size vote array avoids HashMap overhead.
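//!
//! A quick sketch of why skipping the sqrt is safe whenever only the neighbor
//! *ordering* matters: `sqrt` is monotonic on non-negative reals, so squared
//! distances compare the same way true distances do.
//!
//! ```
//! // Illustration only (not part of this crate's API).
//! let d2_ab = 2.0_f64; // squared distance a→b
//! let d2_ac = 8.0_f64; // squared distance a→c
//! assert_eq!(d2_ab < d2_ac, d2_ab.sqrt() < d2_ac.sqrt());
//! ```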

use rayon::prelude::*;

use crate::accel;
use crate::constants::KNN_PAR_THRESHOLD;
use crate::dataset::Dataset;
use crate::distance::{
    cosine_distance, euclidean_sq, manhattan, sparse_cosine, sparse_euclidean_sq, sparse_manhattan,
};
use crate::error::{Result, ScryLearnError};
use crate::neighbors::kdtree::KdTree;
use crate::sparse::{CsrMatrix, SparseRow};
use crate::weights::{compute_sample_weights, ClassWeight};

/// Distance metric for KNN.
///
/// # Example
///
/// ```
/// use scry_learn::neighbors::DistanceMetric;
///
/// let metric = DistanceMetric::Cosine;
/// ```
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DistanceMetric {
    /// Euclidean distance (L2).
    Euclidean,
    /// Manhattan distance (L1).
    Manhattan,
    /// Cosine distance: `1 − cos(θ)`, range `[0, 2]`.
    Cosine,
}

/// Weighting function for neighbor votes.
///
/// Controls how the k-nearest neighbors contribute to predictions.
///
/// # Example
///
/// ```
/// use scry_learn::neighbors::{KnnClassifier, WeightFunction};
///
/// let knn = KnnClassifier::new()
///     .k(5)
///     .weights(WeightFunction::Distance);
/// ```
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum WeightFunction {
    /// All neighbors have equal vote weight.
    #[default]
    Uniform,
    /// Closer neighbors contribute more: weight = `1 / distance`.
    ///
    /// When distance is zero (exact match), that neighbor gets all the weight.
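    ///
    /// A small sketch of the weighting (numbers illustrative):
    ///
    /// ```text
    /// distances [0.5, 1.0, 2.0] → weights [1/0.5, 1/1.0, 1/2.0] = [2.0, 1.0, 0.5]
    /// ```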
    Distance,
}

/// Algorithm used for nearest-neighbor search.
///
/// # Example
///
/// ```
/// use scry_learn::neighbors::{KnnClassifier, Algorithm};
///
/// let knn = KnnClassifier::new()
///     .k(5)
///     .algorithm(Algorithm::KDTree);
/// ```
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Algorithm {
    /// Automatically choose the best algorithm based on data and metric.
    ///
    /// Uses KD-tree for Euclidean distance with < 20 features,
    /// brute-force otherwise.
    #[default]
    Auto,
    /// Brute-force O(n) search. Works with all distance metrics.
    BruteForce,
    /// KD-tree O(log n) average-case search. Euclidean distance only.
    ///
    /// Falls back to brute-force if a non-Euclidean metric is selected.
    KDTree,
}

// ─────────────────────────────────────────────────────────────────
// KNN Classifier
// ─────────────────────────────────────────────────────────────────

/// K-Nearest Neighbors classifier.
///
/// Supports brute-force and KD-tree search (see [`Algorithm`]); brute-force
/// distance computation is fast enough for datasets up to ~100k samples.
///
/// # Example
///
/// ```
/// use scry_learn::dataset::Dataset;
/// use scry_learn::neighbors::{KnnClassifier, WeightFunction};
///
/// let features = vec![
///     vec![0.0, 0.0, 10.0, 10.0],
///     vec![0.0, 0.0, 10.0, 10.0],
/// ];
/// let target = vec![0.0, 0.0, 1.0, 1.0];
/// let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");
///
/// let mut knn = KnnClassifier::new()
///     .k(3)
///     .weights(WeightFunction::Distance);
/// knn.fit(&data).unwrap();
///
/// let preds = knn.predict(&[vec![1.0, 1.0]]).unwrap();
/// assert_eq!(preds[0] as usize, 0);
/// ```
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct KnnClassifier {
    k: usize,
    metric: DistanceMetric,
    weight_fn: WeightFunction,
    class_weight: ClassWeight,
    algorithm: Algorithm,
    train_features: Vec<Vec<f64>>, // [n_samples][n_features]
    train_target: Vec<f64>,
    train_weights: Vec<f64>,
    n_classes: usize,
    kdtree: Option<KdTree>,
    /// Sparse training data (CSR) for sparse-native distance computation.
    train_sparse: Option<CsrMatrix>,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl KnnClassifier {
    /// Create a new KNN classifier with k=5.
    pub fn new() -> Self {
        Self {
            k: 5,
            metric: DistanceMetric::Euclidean,
            weight_fn: WeightFunction::Uniform,
            class_weight: ClassWeight::Uniform,
            algorithm: Algorithm::Auto,
            train_features: Vec::new(),
            train_target: Vec::new(),
            train_weights: Vec::new(),
            n_classes: 0,
            kdtree: None,
            train_sparse: None,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Set the number of neighbors.
    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Set the distance metric.
    pub fn metric(mut self, m: DistanceMetric) -> Self {
        self.metric = m;
        self
    }

    /// Set the neighbor weighting function.
    ///
    /// - [`WeightFunction::Uniform`]: every neighbor's vote counts equally.
    /// - [`WeightFunction::Distance`]: closer neighbors contribute more (weight = `1/d`).
    pub fn weights(mut self, w: WeightFunction) -> Self {
        self.weight_fn = w;
        self
    }

    /// Set class weighting strategy for imbalanced datasets.
    pub fn class_weight(mut self, cw: ClassWeight) -> Self {
        self.class_weight = cw;
        self
    }

    /// Set the nearest-neighbor search algorithm.
    ///
    /// - [`Algorithm::Auto`] (default): uses KD-tree for Euclidean distance
    ///   with fewer than 20 features, brute-force otherwise.
    /// - [`Algorithm::BruteForce`]: always O(n) brute-force scan.
    /// - [`Algorithm::KDTree`]: builds a KD-tree for O(log n) queries;
    ///   falls back to brute-force if a non-Euclidean metric is set.
    pub fn algorithm(mut self, algo: Algorithm) -> Self {
        self.algorithm = algo;
        self
    }

    /// Store training data. Builds a KD-tree if the selected algorithm
    /// (or `Auto` heuristic) calls for it.
    ///
    /// If the dataset uses sparse storage, the CSR representation is stored
    /// for efficient sparse distance computation in [`KnnClassifier::predict_sparse`].
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if data.n_samples() == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        // Guard against k == 0, which would underflow the partial-select
        // index in the k-nearest helpers below.
        if self.k == 0 {
            return Err(ScryLearnError::InvalidParameter(
                "k must be at least 1".into(),
            ));
        }

        // Store sparse training data if available.
        if let Some(csr) = data.sparse_csr() {
            self.train_sparse = Some(csr);
            self.train_features = Vec::new(); // no dense copy needed
        } else {
            self.train_sparse = None;
            self.train_features = data.feature_matrix();
        }

        self.train_target.clone_from(&data.target);
        self.train_weights = compute_sample_weights(&data.target, &self.class_weight);
        self.n_classes = data.n_classes();

        // Build KD-tree if appropriate (only for dense data).
        self.kdtree = if self.train_sparse.is_none()
            && should_use_kdtree(self.algorithm, self.metric, data.n_features())
        {
            Some(KdTree::build(&self.train_features))
        } else {
            None
        };

        self.fitted = true;
        Ok(())
    }

    /// Predict class labels.
    ///
    /// Uses partial sort (`select_nth_unstable`) to find k nearest neighbors
    /// in O(n) instead of full O(n·log n) sort. Euclidean distances skip sqrt
    /// since we only need relative ordering (unless distance weighting is on).
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        if self.train_features.is_empty() && self.train_sparse.is_some() {
            return Err(ScryLearnError::InvalidParameter(
                "model was trained on sparse data; use predict_sparse() instead".into(),
            ));
        }
        let probas = self.compute_votes(features);
        Ok(probas
            .into_iter()
            .map(|votes| {
                // Fold to keep the *first* class with max votes on ties
                // (sklearn picks lowest class index).
                votes
                    .iter()
                    .enumerate()
                    .fold((0usize, f64::NEG_INFINITY), |(best_i, best_v), (i, &v)| {
                        if v > best_v {
                            (i, v)
                        } else {
                            (best_i, best_v)
                        }
                    })
                    .0 as f64
            })
            .collect())
    }

    /// Predict class probability distribution for each sample.
    ///
    /// Returns a `Vec<Vec<f64>>` where `result[i][c]` is the estimated
    /// probability that sample `i` belongs to class `c`. Probabilities
    /// sum to 1.0 for each sample.
    ///
    /// # Example
    ///
    /// ```
    /// use scry_learn::dataset::Dataset;
    /// use scry_learn::neighbors::KnnClassifier;
    ///
    /// let features = vec![
    ///     vec![0.0, 0.0, 10.0, 10.0],
    ///     vec![0.0, 0.0, 10.0, 10.0],
    /// ];
    /// let target = vec![0.0, 0.0, 1.0, 1.0];
    /// let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");
    ///
    /// let mut knn = KnnClassifier::new().k(3);
    /// knn.fit(&data).unwrap();
    ///
    /// let probas = knn.predict_proba(&[vec![1.0, 1.0]]).unwrap();
    /// let sum: f64 = probas[0].iter().sum();
    /// assert!((sum - 1.0).abs() < 1e-9);
    /// ```
    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        if self.train_features.is_empty() && self.train_sparse.is_some() {
            return Err(ScryLearnError::InvalidParameter(
                "model was trained on sparse data; use predict_sparse() instead".into(),
            ));
        }
        let votes = self.compute_votes(features);
        Ok(votes
            .into_iter()
            .map(|v| {
                let total: f64 = v.iter().sum();
                if total > 0.0 {
                    v.iter().map(|&x| x / total).collect()
                } else {
                    // Fallback: uniform distribution.
                    let n = v.len() as f64;
                    vec![1.0 / n; v.len()]
                }
            })
            .collect())
    }

    /// Core voting logic shared by `predict` and `predict_proba`.
    ///
    /// Returns raw weighted vote counts per class for each query sample.
    ///
    /// When the metric is Euclidean and no KD-tree is in use, distances
    /// are computed in a single batch via [`ComputeBackend`], which
    /// uses GPU compute shaders when the `scry-gpu` feature is enabled and
    /// the dataset is large enough.
    #[allow(clippy::option_if_let_else)]
    fn compute_votes(&self, features: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let k = self.k.min(self.train_features.len());
        let use_actual_dist = matches!(self.weight_fn, WeightFunction::Distance);
        let metric = self.metric;

        // Try batched backend path for Euclidean brute-force.
        let batched = if self.kdtree.is_none() && matches!(metric, DistanceMetric::Euclidean) {
            batched_brute_force_neighbors(features, &self.train_features, k, use_actual_dist)
        } else {
            None
        };

        if let Some(all_neighbors) = batched {
            // Batched path — distances already computed.
            all_neighbors
                .into_iter()
                .map(|neighbors| {
                    aggregate_votes(
                        &neighbors,
                        &self.train_target,
                        &self.train_weights,
                        self.n_classes,
                        use_actual_dist,
                    )
                })
                .collect()
        } else {
            // Per-sample path (KD-tree or non-Euclidean metric).
            let n_train = self.train_features.len();
            let n_features = if n_train > 0 {
                self.train_features[0].len()
            } else {
                0
            };
            let use_par =
                self.kdtree.is_none() && features.len() * n_train * n_features >= KNN_PAR_THRESHOLD;

            let vote_fn = |query: &Vec<f64>| {
                let neighbors: Vec<(f64, usize)> = if let Some(ref tree) = self.kdtree {
                    let raw = tree.query_k_nearest(query, k, &self.train_features);
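                    // `query_k_nearest` returns squared Euclidean distances;
                    // take the sqrt only when true distances are needed for
                    // the 1/d vote weighting.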
                    if use_actual_dist {
                        raw.into_iter().map(|(d2, i)| (d2.sqrt(), i)).collect()
                    } else {
                        raw
                    }
                } else {
                    scalar_brute_force(query, &self.train_features, k, metric, use_actual_dist)
                };

                aggregate_votes(
                    &neighbors,
                    &self.train_target,
                    &self.train_weights,
                    self.n_classes,
                    use_actual_dist,
                )
            };

            if use_par {
                features.par_iter().map(vote_fn).collect()
            } else {
                features.iter().map(vote_fn).collect()
            }
        }
    }
}

impl KnnClassifier {
    /// Predict class labels from sparse features (CSR format).
    ///
    /// Uses true sparse distance computation via merge-join on sorted indices,
    /// avoiding densification. Supports Euclidean, Manhattan, and Cosine metrics.
    pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n_train = self.train_target.len();
        let k = self.k.min(n_train);
        let use_actual_dist = matches!(self.weight_fn, WeightFunction::Distance);

        Ok((0..features.n_rows())
            .map(|i| {
                let query = features.row(i);
                let neighbors = if let Some(ref train_csr) = self.train_sparse {
                    sparse_brute_force(&query, train_csr, k, self.metric, use_actual_dist)
                } else {
                    let dense = sparse_row_to_dense(&query, features.n_cols());
                    scalar_brute_force(
                        &dense,
                        &self.train_features,
                        k,
                        self.metric,
                        use_actual_dist,
                    )
                };
                let votes = aggregate_votes(
                    &neighbors,
                    &self.train_target,
                    &self.train_weights,
                    self.n_classes,
                    use_actual_dist,
                );
                votes
                    .iter()
                    .enumerate()
                    .fold((0usize, f64::NEG_INFINITY), |(best_i, best_v), (i, &v)| {
                        if v > best_v {
                            (i, v)
                        } else {
                            (best_i, best_v)
                        }
                    })
                    .0 as f64
            })
            .collect())
    }
}

impl Default for KnnClassifier {
    fn default() -> Self {
        Self::new()
    }
}

// ─────────────────────────────────────────────────────────────────
// KNN Regressor
// ─────────────────────────────────────────────────────────────────

/// K-Nearest Neighbors regressor.
///
/// Predicts the (optionally distance-weighted) mean of the k-nearest
/// training targets for each query point.
///
/// # Example
///
/// ```
/// use scry_learn::dataset::Dataset;
/// use scry_learn::neighbors::{KnnRegressor, WeightFunction};
///
/// let features = vec![vec![1.0, 2.0, 3.0]];
/// let target = vec![10.0, 20.0, 30.0];
/// let data = Dataset::new(features, target, vec!["x".into()], "y");
///
/// let mut knn = KnnRegressor::new()
///     .k(2)
///     .weights(WeightFunction::Uniform);
/// knn.fit(&data).unwrap();
///
/// let preds = knn.predict(&[vec![2.5]]).unwrap();
/// // Nearest neighbors are x=2 (y=20) and x=3 (y=30) → mean = 25
/// assert!((preds[0] - 25.0).abs() < 1e-9);
/// ```
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct KnnRegressor {
    k: usize,
    metric: DistanceMetric,
    weight_fn: WeightFunction,
    algorithm: Algorithm,
    train_features: Vec<Vec<f64>>, // [n_samples][n_features]
    train_target: Vec<f64>,
    kdtree: Option<KdTree>,
    /// Sparse training data (CSR) for sparse-native distance computation.
    train_sparse: Option<CsrMatrix>,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl KnnRegressor {
    /// Create a new KNN regressor with k=5.
    pub fn new() -> Self {
        Self {
            k: 5,
            metric: DistanceMetric::Euclidean,
            weight_fn: WeightFunction::Uniform,
            algorithm: Algorithm::Auto,
            train_features: Vec::new(),
            train_target: Vec::new(),
            kdtree: None,
            train_sparse: None,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Set the number of neighbors.
    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Set the distance metric.
    pub fn metric(mut self, m: DistanceMetric) -> Self {
        self.metric = m;
        self
    }

    /// Set the neighbor weighting function.
    ///
    /// - [`WeightFunction::Uniform`]: all k neighbors contribute equally to the mean.
    /// - [`WeightFunction::Distance`]: closer neighbors are weighted by `1/distance`.
    pub fn weights(mut self, w: WeightFunction) -> Self {
        self.weight_fn = w;
        self
    }

    /// Set the nearest-neighbor search algorithm.
    ///
    /// See [`Algorithm`] for details.
    pub fn algorithm(mut self, algo: Algorithm) -> Self {
        self.algorithm = algo;
        self
    }

    /// Store training data. Builds KD-tree if appropriate.
    ///
    /// If the dataset uses sparse storage, the CSR representation is stored
    /// for efficient sparse distance computation in [`KnnRegressor::predict_sparse`].
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if data.n_samples() == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        // Guard against k == 0, which would underflow the partial-select
        // index in the k-nearest helpers below.
        if self.k == 0 {
            return Err(ScryLearnError::InvalidParameter(
                "k must be at least 1".into(),
            ));
        }

        if let Some(csr) = data.sparse_csr() {
            self.train_sparse = Some(csr);
            self.train_features = Vec::new();
        } else {
            self.train_sparse = None;
            self.train_features = data.feature_matrix();
        }

        self.train_target.clone_from(&data.target);

        self.kdtree = if self.train_sparse.is_none()
            && should_use_kdtree(self.algorithm, self.metric, data.n_features())
        {
            Some(KdTree::build(&self.train_features))
        } else {
            None
        };

        self.fitted = true;
        Ok(())
    }

    /// Predict continuous target values.
    ///
    /// For each query point, finds the k nearest training samples and returns
    /// their mean (or distance-weighted mean) target value.
    ///
    /// When the metric is Euclidean and no KD-tree is in use, distances
    /// are computed in a single batch via `ComputeBackend`.
    #[allow(clippy::option_if_let_else)]
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        if self.train_features.is_empty() && self.train_sparse.is_some() {
            return Err(ScryLearnError::InvalidParameter(
                "model was trained on sparse data; use predict_sparse() instead".into(),
            ));
        }

        let k = self.k.min(self.train_features.len());
        let use_actual_dist = matches!(self.weight_fn, WeightFunction::Distance);
        let metric = self.metric;

        // Try batched backend path for Euclidean brute-force.
        let batched = if self.kdtree.is_none() && matches!(metric, DistanceMetric::Euclidean) {
            batched_brute_force_neighbors(features, &self.train_features, k, use_actual_dist)
        } else {
            None
        };

        let get_neighbors = |query: &Vec<f64>| -> Vec<(f64, usize)> {
            if let Some(ref tree) = self.kdtree {
                let raw = tree.query_k_nearest(query, k, &self.train_features);
                if use_actual_dist {
                    raw.into_iter().map(|(d2, i)| (d2.sqrt(), i)).collect()
                } else {
                    raw
                }
            } else {
                scalar_brute_force(query, &self.train_features, k, metric, use_actual_dist)
            }
        };

        if let Some(ref all) = batched {
            // Batched path — already computed.
            Ok(features
                .iter()
                .enumerate()
                .map(|(qi, _query)| {
                    aggregate_regression(&all[qi], &self.train_target, use_actual_dist, k)
                })
                .collect())
        } else {
            let n_train = self.train_features.len();
            let n_features = if n_train > 0 {
                self.train_features[0].len()
            } else {
                0
            };
            let use_par =
                self.kdtree.is_none() && features.len() * n_train * n_features >= KNN_PAR_THRESHOLD;

            let predict_fn = |query: &Vec<f64>| {
                let neighbors = get_neighbors(query);
                aggregate_regression(&neighbors, &self.train_target, use_actual_dist, k)
            };

            if use_par {
                Ok(features.par_iter().map(predict_fn).collect())
            } else {
                Ok(features.iter().map(predict_fn).collect())
            }
        }
    }
}

impl KnnRegressor {
    /// Predict from sparse features (CSR format).
    ///
    /// Uses true sparse distance computation via merge-join on sorted indices.
    pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        let n_train = self.train_target.len();
        let k = self.k.min(n_train);
        let use_actual_dist = matches!(self.weight_fn, WeightFunction::Distance);

        Ok((0..features.n_rows())
            .map(|i| {
                let query = features.row(i);
                let neighbors = if let Some(ref train_csr) = self.train_sparse {
                    sparse_brute_force(&query, train_csr, k, self.metric, use_actual_dist)
                } else {
                    let dense = sparse_row_to_dense(&query, features.n_cols());
                    scalar_brute_force(
                        &dense,
                        &self.train_features,
                        k,
                        self.metric,
                        use_actual_dist,
                    )
                };
                aggregate_regression(&neighbors, &self.train_target, use_actual_dist, k)
            })
            .collect())
    }
}

impl Default for KnnRegressor {
    fn default() -> Self {
        Self::new()
    }
}

// ─────────────────────────────────────────────────────────────────
// Shared helpers
// ─────────────────────────────────────────────────────────────────

/// Per-sample brute-force distance computation.
///
/// Used when the batched backend path is not applicable (non-Euclidean
/// metric, or the query×train product is below the batch threshold).
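///
/// Sketch of the selection strategy (matching the body below):
///
/// ```text
/// 1. compute (distance, index) for all n training rows         O(n)
/// 2. select_nth_unstable at k-1 → the k smallest, unordered    O(n) average
/// 3. sort only the k survivors by (distance, index)            O(k·log k)
/// ```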
fn scalar_brute_force(
    query: &[f64],
    train: &[Vec<f64>],
    k: usize,
    metric: DistanceMetric,
    use_actual_dist: bool,
) -> Vec<(f64, usize)> {
    let mut dists: Vec<(f64, usize)> = train
        .iter()
        .enumerate()
        .map(|(i, train_row)| {
            let d = if use_actual_dist {
                actual_distance(query, train_row, metric)
            } else {
                distance_for_compare(query, train_row, metric)
            };
            (d, i)
        })
        .collect();

    if k < dists.len() {
        dists.select_nth_unstable_by(k - 1, |a, b| {
            a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
        });
    }
    dists.truncate(k);
    // Stable sort by (distance, index) to deterministically prefer
    // lower-index training samples on ties (matches sklearn behavior).
    dists.sort_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a.1.cmp(&b.1))
    });
    dists
}

/// Batched brute-force using `ComputeBackend::pairwise_distances_squared()`.
///
/// Returns `Some(neighbors)` where `neighbors[i]` is a `Vec<(dist, idx)>` of
/// k-nearest for query `i`, or `None` if batch threshold isn't met.
///
/// Only valid for Euclidean distance (squared distances preserve ordering).
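///
/// Layout sketch: the backend returns the distance matrix flattened row-major,
///
/// ```text
/// dist_matrix[qi * n_t + j] = d²(queries[qi], train[j])
/// ```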
fn batched_brute_force_neighbors(
    queries: &[Vec<f64>],
    train: &[Vec<f64>],
    k: usize,
    use_actual_dist: bool,
) -> Option<Vec<Vec<(f64, usize)>>> {
    let n_q = queries.len();
    let n_t = train.len();
    if n_q == 0 || n_t == 0 {
        return None;
    }
    let dim = queries[0].len();

    // Only worth batching for reasonably sized problems.
    // The backend has its own internal thresholds too.
    if n_q * n_t < 256 {
        return None;
    }

    // Flatten row-major: queries[n_q][dim] → flat[n_q * dim]
    let q_flat: Vec<f64> = queries.iter().flat_map(|r| r.iter().copied()).collect();
    let t_flat: Vec<f64> = train.iter().flat_map(|r| r.iter().copied()).collect();

    let backend = accel::auto();
    let dist_matrix = backend.pairwise_distances_squared(&q_flat, &t_flat, n_q, n_t, dim);

    let result: Vec<Vec<(f64, usize)>> = (0..n_q)
        .map(|qi| {
            let row = &dist_matrix[qi * n_t..(qi + 1) * n_t];
            let mut indexed: Vec<(f64, usize)> = row
                .iter()
                .enumerate()
                .map(|(j, &d2)| {
                    let d = if use_actual_dist { d2.sqrt() } else { d2 };
                    (d, j)
                })
                .collect();

            let k_eff = k.min(indexed.len());
            if k_eff < indexed.len() {
                indexed.select_nth_unstable_by(k_eff - 1, |a, b| {
                    a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
                });
            }
            indexed.truncate(k_eff);
            // Stable sort by (distance, index) — matches sklearn tie-breaking.
            indexed.sort_by(|a, b| {
                a.0.partial_cmp(&b.0)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then(a.1.cmp(&b.1))
            });
            indexed
        })
        .collect();

    Some(result)
}

/// Aggregate weighted votes for classification.
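///
/// Worked sketch of the distance-weighted branch (values illustrative,
/// unit sample weights):
///
/// ```text
/// neighbors = [(d = 0.5, class 0), (d = 2.0, class 1)]
/// votes     = [1/0.5, 1/2.0] = [2.0, 0.5]   → class 0 wins
/// ```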
fn aggregate_votes(
    neighbors: &[(f64, usize)],
    target: &[f64],
    weights: &[f64],
    n_classes: usize,
    use_actual_dist: bool,
) -> Vec<f64> {
    let mut votes = vec![0.0_f64; n_classes.max(1)];

    if use_actual_dist {
        let has_exact = neighbors.iter().any(|&(d, _)| d < f64::EPSILON);
        if has_exact {
            for &(d, idx) in neighbors {
                if d < f64::EPSILON {
                    let class = target[idx] as usize;
                    let w = weights[idx];
                    if class < votes.len() {
                        votes[class] += w;
                    }
                }
            }
        } else {
            for &(d, idx) in neighbors {
                let class = target[idx] as usize;
                let w = weights[idx];
                if class < votes.len() {
                    votes[class] += w / d;
                }
            }
        }
    } else {
        for &(_, idx) in neighbors {
            let class = target[idx] as usize;
            let w = weights[idx];
            if class < votes.len() {
                votes[class] += w;
            }
        }
    }

    votes
}

/// Aggregate predictions for regression.
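///
/// Worked sketch of the distance-weighted mean (same numbers as the
/// regressor test below):
///
/// ```text
/// neighbors  = [(d = 1.0, y = 0.0), (d = 9.0, y = 100.0)]
/// prediction = (0/1 + 100/9) / (1/1 + 1/9) = (100/9) / (10/9) = 10.0
/// ```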
fn aggregate_regression(
    neighbors: &[(f64, usize)],
    target: &[f64],
    use_actual_dist: bool,
    k: usize,
) -> f64 {
    if use_actual_dist {
        let has_exact = neighbors.iter().any(|&(d, _)| d < f64::EPSILON);
        if has_exact {
            let (sum, count) = neighbors.iter().fold((0.0, 0usize), |(s, c), &(d, idx)| {
                if d < f64::EPSILON {
                    (s + target[idx], c + 1)
                } else {
                    (s, c)
                }
            });
            sum / count as f64
        } else {
            let (weighted_sum, total_w) =
                neighbors.iter().fold((0.0, 0.0), |(ws, tw), &(d, idx)| {
                    let w = 1.0 / d;
                    (ws + w * target[idx], tw + w)
                });
            weighted_sum / total_w
        }
    } else {
        let sum: f64 = neighbors.iter().map(|&(_, idx)| target[idx]).sum();
        sum / k as f64
    }
}

// ─────────────────────────────────────────────────────────────────
// Distance functions
// ─────────────────────────────────────────────────────────────────

/// Compute distance for comparison purposes (skips sqrt for Euclidean).
///
/// For Euclidean, returns squared distance (monotonic — preserves ordering).
/// For Manhattan and Cosine, returns the actual distance.
#[inline]
fn distance_for_compare(a: &[f64], b: &[f64], metric: DistanceMetric) -> f64 {
    match metric {
        DistanceMetric::Euclidean => euclidean_sq(a, b),
        DistanceMetric::Manhattan => manhattan(a, b),
        DistanceMetric::Cosine => cosine_distance(a, b),
    }
}

/// Compute the actual distance (with sqrt for Euclidean).
///
/// Used when `WeightFunction::Distance` is active, since we need true
/// distances for the `1/d` weighting.
#[inline]
fn actual_distance(a: &[f64], b: &[f64], metric: DistanceMetric) -> f64 {
    match metric {
        DistanceMetric::Euclidean => euclidean_sq(a, b).sqrt(),
        DistanceMetric::Manhattan => manhattan(a, b),
        DistanceMetric::Cosine => cosine_distance(a, b),
    }
}

/// Convert a sparse row view to a dense vector.
fn sparse_row_to_dense(row: &SparseRow<'_>, n_cols: usize) -> Vec<f64> {
    let mut dense = vec![0.0; n_cols];
    for (col, val) in row.iter() {
        dense[col] = val;
    }
    dense
}

/// Compute sparse distance for comparison (skips sqrt for Euclidean).
#[inline]
fn sparse_distance_for_compare(
    a: &SparseRow<'_>,
    b: &SparseRow<'_>,
    metric: DistanceMetric,
) -> f64 {
    match metric {
        DistanceMetric::Euclidean => sparse_euclidean_sq(a, b),
        DistanceMetric::Manhattan => sparse_manhattan(a, b),
        DistanceMetric::Cosine => sparse_cosine(a, b),
    }
}

/// Compute actual sparse distance (with sqrt for Euclidean).
#[inline]
fn sparse_actual_distance(a: &SparseRow<'_>, b: &SparseRow<'_>, metric: DistanceMetric) -> f64 {
    match metric {
        DistanceMetric::Euclidean => sparse_euclidean_sq(a, b).sqrt(),
        DistanceMetric::Manhattan => sparse_manhattan(a, b),
        DistanceMetric::Cosine => sparse_cosine(a, b),
    }
}

/// Brute-force k-nearest on sparse training data.
fn sparse_brute_force(
    query: &SparseRow<'_>,
    train: &CsrMatrix,
    k: usize,
    metric: DistanceMetric,
    use_actual_dist: bool,
) -> Vec<(f64, usize)> {
    let n = train.n_rows();
    let mut dists: Vec<(f64, usize)> = (0..n)
        .map(|i| {
            let train_row = train.row(i);
            let d = if use_actual_dist {
                sparse_actual_distance(query, &train_row, metric)
            } else {
                sparse_distance_for_compare(query, &train_row, metric)
            };
            (d, i)
        })
        .collect();

    if k < dists.len() {
        dists.select_nth_unstable_by(k - 1, |a, b| {
            a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
        });
    }
    dists.truncate(k);
    dists.sort_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a.1.cmp(&b.1))
    });
    dists
}

/// Decide whether to use the KD-tree based on algorithm selection, metric, and dimensionality.
fn should_use_kdtree(algo: Algorithm, metric: DistanceMetric, n_features: usize) -> bool {
    match algo {
        Algorithm::BruteForce => false,
        Algorithm::KDTree => matches!(metric, DistanceMetric::Euclidean),
        Algorithm::Auto => matches!(metric, DistanceMetric::Euclidean) && n_features < 20,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_knn_simple() {
        // Two clusters: class 0 near origin, class 1 near (10, 10).
        let features = vec![
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut knn = KnnClassifier::new().k(3);
        knn.fit(&data).unwrap();

        let preds = knn.predict(&[vec![1.0, 1.0], vec![9.0, 9.0]]).unwrap();
        assert!((preds[0] - 0.0).abs() < 1e-6);
        assert!((preds[1] - 1.0).abs() < 1e-6);
    }
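
    // A direct check of the `Algorithm::Auto` heuristic: KD-tree only for
    // Euclidean distance with fewer than 20 features (see `should_use_kdtree`).
    #[test]
    fn test_should_use_kdtree_auto_heuristic() {
        assert!(should_use_kdtree(Algorithm::Auto, DistanceMetric::Euclidean, 5));
        assert!(!should_use_kdtree(Algorithm::Auto, DistanceMetric::Euclidean, 20));
        assert!(!should_use_kdtree(Algorithm::Auto, DistanceMetric::Cosine, 5));
        assert!(!should_use_kdtree(Algorithm::BruteForce, DistanceMetric::Euclidean, 5));
        assert!(should_use_kdtree(Algorithm::KDTree, DistanceMetric::Euclidean, 100));
    }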

    #[test]
    fn test_knn_distance_weights() {
        // 3 class-0 samples far away, 2 class-1 samples very close to query.
        // With distance weights, class 1 should win (closer neighbors dominate).
        // Query at x=0.15: class-1 at 0.1 (d=0.05), 0.2 (d=0.05); class-0 at 5, 10, 10 (far).
        let features = vec![vec![5.0, 10.0, 10.0, 0.1, 0.2]];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into()], "class");

        let mut knn_dist = KnnClassifier::new().k(5).weights(WeightFunction::Distance);
        knn_dist.fit(&data).unwrap();
        let preds_d = knn_dist.predict(&[vec![0.15]]).unwrap();
        assert_eq!(
            preds_d[0] as usize, 1,
            "Distance-weighted should pick closer class 1"
        );
    }

    #[test]
    fn test_knn_predict_proba() {
        let features = vec![vec![0.0, 0.0, 10.0, 10.0], vec![0.0, 0.0, 10.0, 10.0]];
        let target = vec![0.0, 0.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut knn = KnnClassifier::new().k(4);
        knn.fit(&data).unwrap();

        let probas = knn
            .predict_proba(&[vec![1.0, 1.0], vec![5.0, 5.0]])
            .unwrap();
        for p in &probas {
            let sum: f64 = p.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-9,
                "Probabilities must sum to 1.0, got {sum}"
            );
        }

        // Point near class 0 should have higher probability for class 0.
        assert!(
            probas[0][0] > 0.4,
            "Expected high prob for class 0 at (1,1)"
        );
    }

    #[test]
    fn test_knn_cosine() {
        // Cosine distance ignores magnitude — direction matters.
        // [1, 0] and [100, 0] have same direction → distance ≈ 0.
        // [1, 0] and [0, 1] are orthogonal → distance ≈ 1.
        let d_same = cosine_distance(&[1.0, 0.0], &[100.0, 0.0]);
        let d_orth = cosine_distance(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(
            d_same < 1e-9,
            "Same direction should have ~0 distance, got {d_same}"
        );
        assert!(
            (d_orth - 1.0).abs() < 1e-9,
            "Orthogonal should have distance ~1, got {d_orth}"
        );

        // Use cosine metric in classifier.
        let features = vec![vec![1.0, 100.0, 0.0, 0.0], vec![0.0, 0.0, 1.0, 100.0]];
        let target = vec![0.0, 0.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut knn = KnnClassifier::new().k(2).metric(DistanceMetric::Cosine);
        knn.fit(&data).unwrap();

        // Query [50, 0] has same direction as class 0.
        let preds = knn.predict(&[vec![50.0, 0.0]]).unwrap();
        assert_eq!(
            preds[0] as usize, 0,
            "Cosine metric should match class 0 by direction"
        );
    }

    #[test]
    fn test_knn_regressor_simple() {
        // 3 points: x=1→y=10, x=5→y=50, x=9→y=90
        let features = vec![vec![1.0, 5.0, 9.0]];
        let target = vec![10.0, 50.0, 90.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut knn = KnnRegressor::new().k(2);
        knn.fit(&data).unwrap();

        // Query x=3: nearest are x=1(y=10) and x=5(y=50) → mean=30
        let preds = knn.predict(&[vec![3.0]]).unwrap();
        assert!(
            (preds[0] - 30.0).abs() < 1e-9,
            "Expected 30.0, got {}",
            preds[0]
        );

        // Query x=7: nearest are x=5(y=50) and x=9(y=90) → mean=70
        let preds2 = knn.predict(&[vec![7.0]]).unwrap();
        assert!(
            (preds2[0] - 70.0).abs() < 1e-9,
            "Expected 70.0, got {}",
            preds2[0]
        );
    }

    #[test]
    fn test_knn_regressor_distance_weights() {
        // x=0→y=0, x=10→y=100. Query at x=1 (much closer to x=0).
        // Uniform: mean(0, 100) = 50.
        // Distance: weighted toward x=0 → should be << 50.
        let features = vec![vec![0.0, 10.0]];
        let target = vec![0.0, 100.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut knn_u = KnnRegressor::new().k(2);
        knn_u.fit(&data).unwrap();
        let pred_u = knn_u.predict(&[vec![1.0]]).unwrap()[0];
        assert!((pred_u - 50.0).abs() < 1e-9, "Uniform should give 50.0");

        let mut knn_d = KnnRegressor::new().k(2).weights(WeightFunction::Distance);
        knn_d.fit(&data).unwrap();
        let pred_d = knn_d.predict(&[vec![1.0]]).unwrap()[0];
        // Weighted mean: (1·0 + (1/9)·100) / (1 + 1/9) = (100/9) / (10/9) = 10.0
        assert!(
            pred_d < 20.0,
            "Distance-weighted should favor x=0, got {pred_d}"
        );
    }

    #[test]
    fn test_knn_not_fitted() {
        let knn = KnnClassifier::new();
        assert!(knn.predict(&[vec![1.0]]).is_err());
        assert!(knn.predict_proba(&[vec![1.0]]).is_err());

        let knn_r = KnnRegressor::new();
        assert!(knn_r.predict(&[vec![1.0]]).is_err());
    }

    #[test]
    fn test_knn_predict_sparse_matches_dense() {
        let features = vec![
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(features, target, vec!["x".into(), "y".into()], "class");

        let mut knn = KnnClassifier::new().k(3);
        knn.fit(&data).unwrap();

        let test = vec![vec![1.0, 1.0], vec![9.0, 9.0]];
        let preds_dense = knn.predict(&test).unwrap();
        let csr = CsrMatrix::from_dense(&test);
        let preds_sparse = knn.predict_sparse(&csr).unwrap();

        for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense={d} vs Sparse={s}");
        }
    }

    #[test]
    fn test_knn_regressor_predict_sparse() {
        let features = vec![vec![1.0, 5.0, 9.0]];
        let target = vec![10.0, 50.0, 90.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut knn = KnnRegressor::new().k(2);
        knn.fit(&data).unwrap();

        let test = vec![vec![3.0], vec![7.0]];
        let preds_dense = knn.predict(&test).unwrap();
        let csr = CsrMatrix::from_dense(&test);
        let preds_sparse = knn.predict_sparse(&csr).unwrap();

        for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense={d} vs Sparse={s}");
        }
    }

    #[test]
    fn test_sparse_euclidean_matches_dense() {
        // Dense: d²([1,0,3], [0,2,3]) = 1 + 4 + 0 = 5
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0, 3.0]]);
        let b = CsrMatrix::from_dense(&[vec![0.0, 2.0, 3.0]]);
        let d2 = sparse_euclidean_sq(&a.row(0), &b.row(0));
        assert!((d2 - 5.0).abs() < 1e-10, "Expected 5.0, got {d2}");
    }

    #[test]
    fn test_sparse_manhattan_matches_dense() {
        // Dense: d([1,0,3], [0,2,3]) = 1 + 2 + 0 = 3
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0, 3.0]]);
        let b = CsrMatrix::from_dense(&[vec![0.0, 2.0, 3.0]]);
        let d = sparse_manhattan(&a.row(0), &b.row(0));
        assert!((d - 3.0).abs() < 1e-10, "Expected 3.0, got {d}");
    }

    #[test]
    fn test_sparse_cosine_matches_dense() {
        // Same direction → distance ≈ 0
        let a = CsrMatrix::from_dense(&[vec![1.0, 0.0]]);
        let b = CsrMatrix::from_dense(&[vec![100.0, 0.0]]);
        let d = sparse_cosine(&a.row(0), &b.row(0));
        assert!(d < 1e-9, "Same direction should be ~0, got {d}");

        // Orthogonal → distance ≈ 1
        let c = CsrMatrix::from_dense(&[vec![0.0, 1.0]]);
        let d_orth = sparse_cosine(&a.row(0), &c.row(0));
        assert!(
            (d_orth - 1.0).abs() < 1e-9,
            "Orthogonal should be ~1, got {d_orth}"
        );
    }

    #[test]
    fn test_sparse_knn_end_to_end() {
        // Train on dense, predict_sparse with CSR — results should match.
        use crate::sparse::CscMatrix;
        let features = vec![
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
            vec![0.0, 0.0, 0.0, 10.0, 10.0, 10.0],
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(
            features.clone(),
            target.clone(),
            vec!["x".into(), "y".into()],
            "class",
        );

        // Fit on dense.
        let mut knn_dense = KnnClassifier::new().k(3);
        knn_dense.fit(&data).unwrap();

        // Fit on sparse.
        let csc = CscMatrix::from_dense(&features);
        let data_sparse = Dataset::from_sparse(csc, target, vec!["x".into(), "y".into()], "class");
        let mut knn_sparse = KnnClassifier::new().k(3);
        knn_sparse.fit(&data_sparse).unwrap();
        assert!(knn_sparse.train_sparse.is_some());

        // Query with sparse input.
        let test = vec![vec![1.0, 1.0], vec![9.0, 9.0]];
        let preds_dense = knn_dense.predict(&test).unwrap();
        let csr = CsrMatrix::from_dense(&test);
        let preds_sparse = knn_sparse.predict_sparse(&csr).unwrap();

        for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense={d} vs Sparse={s}");
        }
    }

    #[test]
    fn test_high_dimensional_sparse_knn() {
        // 100×5000 matrix with ~2% density: exercises the sparse path
        // without densifying. (Dense storage would cost 5000 × 8 B = 40 KB
        // per row, ~4 MB for the full training matrix.)
        use crate::sparse::CscMatrix;
        let n_train = 100;
        let n_feat = 5000;
        let mut rng = crate::rng::FastRng::new(42);

        // Build sparse training data as column-major.
        let mut cols: Vec<Vec<f64>> = vec![vec![0.0; n_train]; n_feat];
        for col in &mut cols {
            for x in col.iter_mut() {
                if rng.f64() < 0.02 {
                    *x = rng.f64() * 10.0;
                }
            }
        }
        let target: Vec<f64> = (0..n_train).map(|i| (i % 3) as f64).collect();
        let csc = CscMatrix::from_dense(&cols);
        let names: Vec<String> = (0..n_feat).map(|j| format!("f{j}")).collect();
        let data = Dataset::from_sparse(csc, target, names, "class");

        let mut knn = KnnClassifier::new().k(5);
        knn.fit(&data).unwrap();
        assert!(knn.train_sparse.is_some());

        // Build sparse query.
        let mut query_row = vec![0.0; n_feat];
        for x in &mut query_row {
            if rng.f64() < 0.02 {
                *x = rng.f64() * 10.0;
            }
        }
        let query_csr = CsrMatrix::from_dense(&[query_row]);
        let preds = knn.predict_sparse(&query_csr).unwrap();
        assert_eq!(preds.len(), 1);
        assert!(preds[0] >= 0.0 && preds[0] < 3.0);
    }
}

#[cfg(all(test, feature = "scry-gpu"))]
mod gpu_tests {
    use super::*;

    #[test]
    fn gpu_knn_classifier_batched_predictions_valid() {
        // 100 training samples × 5 features, 10 queries → 1000 pairs (above 256 threshold)
        let n_train = 100;
        let n_feat = 5;
        let mut features_col: Vec<Vec<f64>> = Vec::with_capacity(n_feat);
        for j in 0..n_feat {
            let col: Vec<f64> = (0..n_train)
                .map(|i| ((i * (j + 3)) % 37) as f64 * 0.5)
                .collect();
            features_col.push(col);
        }
        let target: Vec<f64> = (0..n_train).map(|i| (i % 3) as f64).collect();
        let names: Vec<String> = (0..n_feat).map(|j| format!("f{j}")).collect();
        let data = Dataset::new(features_col, target, names, "class");

        let mut knn = KnnClassifier::new().k(5).algorithm(Algorithm::BruteForce);
        knn.fit(&data).unwrap();

        // 10 queries — enough to trigger batched path
        let queries: Vec<Vec<f64>> = (0..10)
            .map(|i| (0..n_feat).map(|j| ((i + j) % 17) as f64 * 0.3).collect())
            .collect();

        let preds = knn.predict(&queries).unwrap();
        assert_eq!(preds.len(), 10);
        for p in &preds {
            assert!(
                *p >= 0.0 && *p < 3.0,
                "prediction must be a valid class: {p}"
            );
        }
    }

    #[test]
    fn gpu_knn_regressor_batched_predictions_finite() {
        let n_train = 100;
        let n_feat = 5;
        let mut features_col: Vec<Vec<f64>> = Vec::with_capacity(n_feat);
        for j in 0..n_feat {
            let col: Vec<f64> = (0..n_train)
                .map(|i| ((i * (j + 2)) % 41) as f64 * 0.2)
                .collect();
            features_col.push(col);
        }
        let target: Vec<f64> = (0..n_train).map(|i| (i % 50) as f64).collect();
        let names: Vec<String> = (0..n_feat).map(|j| format!("f{j}")).collect();
        let data = Dataset::new(features_col, target, names, "y");

        let mut knn = KnnRegressor::new().k(5).algorithm(Algorithm::BruteForce);
        knn.fit(&data).unwrap();

        let queries: Vec<Vec<f64>> = (0..10)
            .map(|i| (0..n_feat).map(|j| ((i + j) % 19) as f64 * 0.4).collect())
            .collect();

        let preds = knn.predict(&queries).unwrap();
        assert_eq!(preds.len(), 10);
        for p in &preds {
            assert!(p.is_finite(), "prediction must be finite: {p}");
        }
    }
}