Skip to main content

scirs2_interpolate/gpu_kdtree/
dispatch.rs

//! Auto-dispatch layer for GPU vs CPU k-d tree queries.
//!
//! [`knn_auto_dispatch`] inspects the size of the tree and the query batch
//! and, when both meet or exceed the configured thresholds **and** the
//! `gpu_kdtree` cargo feature is active, routes the request to the wgpu
//! linear-scan backend.  On failure, or when thresholds are not met, it falls
//! through to the parallel CPU path.
8
9use super::tree::{GpuKdTree, KdQueryResult};
10use crate::error::InterpolateResult;
11
12// ---------------------------------------------------------------------------
13// Public types
14// ---------------------------------------------------------------------------
15
/// Configuration for [`knn_auto_dispatch`].
///
/// Both thresholds must be met or exceeded (the comparison is `>=`) for the
/// GPU path to be attempted.
///
/// The type derives `PartialEq`/`Eq` so configurations can be compared
/// directly, e.g. in tests or when checking against [`Default::default`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KdTreeConfig {
    /// Minimum number of points in the tree before GPU dispatch is tried.
    /// Default: 100 000.
    pub gpu_threshold_points: usize,
    /// Minimum number of query points before GPU dispatch is tried.
    /// Default: 1 000.
    pub gpu_threshold_queries: usize,
}
28
29impl Default for KdTreeConfig {
30    fn default() -> Self {
31        Self {
32            gpu_threshold_points: 100_000,
33            gpu_threshold_queries: 1_000,
34        }
35    }
36}
37
38// ---------------------------------------------------------------------------
39// Auto-dispatch
40// ---------------------------------------------------------------------------
41
42/// Compute batch k-NN, automatically choosing GPU or CPU path.
43///
44/// GPU dispatch requires:
45/// 1. The `gpu_kdtree` cargo feature to be active.
46/// 2. `tree.n_points() >= config.gpu_threshold_points`.
47/// 3. `queries.len() >= config.gpu_threshold_queries`.
48///
49/// Any failure in the GPU path causes a silent fallback to the CPU path.
50///
51/// # Arguments
52///
53/// * `tree`    – A pre-built [`GpuKdTree`].
54/// * `queries` – Batch of query points; each must match `tree.dim()`.
55/// * `k`       – Number of nearest neighbors to find per query.
56/// * `config`  – Dispatch thresholds.
57///
58/// # Errors
59///
60/// Returns an error only when the CPU path fails (e.g. dimension mismatch).
61pub fn knn_auto_dispatch(
62    tree: &GpuKdTree,
63    queries: &[Vec<f64>],
64    k: usize,
65    config: &KdTreeConfig,
66) -> InterpolateResult<Vec<KdQueryResult>> {
67    // Attempt GPU path when conditions are met.
68    #[cfg(feature = "gpu_kdtree")]
69    if tree.n_points() >= config.gpu_threshold_points
70        && queries.len() >= config.gpu_threshold_queries
71    {
72        match super::wgpu_linear_scan::knn_wgpu(tree, queries, k) {
73            Ok(results) => return Ok(results),
74            // Fall through to CPU on any GPU error.
75            Err(_) => {}
76        }
77    }
78
79    // Suppress "unused variable" warning for `config` in non-gpu builds.
80    let _ = config;
81
82    // CPU path: always available.
83    tree.knn_batch(queries, k)
84}
85
86// ---------------------------------------------------------------------------
87// Tests
88// ---------------------------------------------------------------------------
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu_kdtree::GpuKdTree;

    /// A three-point tree is far below the default thresholds, so the query
    /// goes through the CPU path and must return the true nearest neighbor.
    #[test]
    fn test_knn_auto_dispatch_cpu_below_threshold() {
        let points: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.5]];
        let tree = GpuKdTree::new(points).expect("build");
        let config = KdTreeConfig::default();
        let probe = vec![vec![0.4_f64, 0.4]];

        let hits = knn_auto_dispatch(&tree, &probe, 1, &config).expect("dispatch");

        assert_eq!(hits.len(), 1);
        // (0.5, 0.5) is nearest to (0.4, 0.4): dist² = 0.1² + 0.1² = 0.02.
        assert_eq!(hits[0].indices[0], 2);
    }

    /// With unreachable thresholds, each of the 5 queries must yield exactly
    /// `k = 3` indices and 3 squared distances.
    #[test]
    fn test_knn_auto_dispatch_returns_correct_count() {
        let points: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64, 0.0]).collect();
        let probes: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64 + 0.1, 0.0]).collect();
        let tree = GpuKdTree::new(points).expect("build");
        // Thresholds that can never be reached.
        let config = KdTreeConfig {
            gpu_threshold_points: usize::MAX,
            gpu_threshold_queries: usize::MAX,
        };

        let hits = knn_auto_dispatch(&tree, &probes, 3, &config).expect("dispatch");

        assert_eq!(hits.len(), 5);
        for hit in hits.iter() {
            assert_eq!(hit.indices.len(), 3);
            assert_eq!(hit.distances_sq.len(), 3);
        }
    }

    /// An empty query batch yields an empty result set rather than an error.
    #[test]
    fn test_knn_auto_dispatch_empty_queries() {
        let tree = GpuKdTree::new(vec![vec![1.0_f64, 2.0]]).expect("build");
        let no_queries: Vec<Vec<f64>> = Vec::new();
        let config = KdTreeConfig::default();

        let hits = knn_auto_dispatch(&tree, &no_queries, 1, &config)
            .expect("dispatch empty queries");

        assert!(hits.is_empty());
    }

    /// Pins the documented default thresholds.
    #[test]
    fn test_kdtree_config_default_thresholds() {
        let KdTreeConfig {
            gpu_threshold_points,
            gpu_threshold_queries,
        } = KdTreeConfig::default();

        assert_eq!(gpu_threshold_points, 100_000);
        assert_eq!(gpu_threshold_queries, 1_000);
    }
}