scirs2_interpolate/gpu_kdtree/dispatch.rs
1//! Auto-dispatch layer for GPU vs CPU k-d tree queries.
2//!
//! [`knn_auto_dispatch`] inspects the size of the tree and the query batch
//! and, when both reach (are greater than or equal to) the configured
//! thresholds **and** the `gpu_kdtree` cargo feature is active, routes the
//! request to the wgpu linear-scan backend. If the GPU path fails, or when a
//! threshold is not reached, it falls through to the parallel CPU path.
8
9use super::tree::{GpuKdTree, KdQueryResult};
10use crate::error::InterpolateResult;
11
12// ---------------------------------------------------------------------------
13// Public types
14// ---------------------------------------------------------------------------
15
/// Configuration for [`knn_auto_dispatch`].
///
/// The GPU path is attempted only when **both** thresholds are reached. Each
/// comparison is inclusive (`>=`). Setting either field to `usize::MAX`
/// therefore effectively pins dispatch to the CPU path.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KdTreeConfig {
    /// Minimum number of points in the tree (inclusive) before GPU dispatch
    /// is tried. Default: 100 000.
    pub gpu_threshold_points: usize,
    /// Minimum number of query points (inclusive) before GPU dispatch is
    /// tried. Default: 1 000.
    pub gpu_threshold_queries: usize,
}
28
29impl Default for KdTreeConfig {
30 fn default() -> Self {
31 Self {
32 gpu_threshold_points: 100_000,
33 gpu_threshold_queries: 1_000,
34 }
35 }
36}
37
38// ---------------------------------------------------------------------------
39// Auto-dispatch
40// ---------------------------------------------------------------------------
41
42/// Compute batch k-NN, automatically choosing GPU or CPU path.
43///
44/// GPU dispatch requires:
45/// 1. The `gpu_kdtree` cargo feature to be active.
46/// 2. `tree.n_points() >= config.gpu_threshold_points`.
47/// 3. `queries.len() >= config.gpu_threshold_queries`.
48///
49/// Any failure in the GPU path causes a silent fallback to the CPU path.
50///
51/// # Arguments
52///
53/// * `tree` – A pre-built [`GpuKdTree`].
54/// * `queries` – Batch of query points; each must match `tree.dim()`.
55/// * `k` – Number of nearest neighbors to find per query.
56/// * `config` – Dispatch thresholds.
57///
58/// # Errors
59///
60/// Returns an error only when the CPU path fails (e.g. dimension mismatch).
61pub fn knn_auto_dispatch(
62 tree: &GpuKdTree,
63 queries: &[Vec<f64>],
64 k: usize,
65 config: &KdTreeConfig,
66) -> InterpolateResult<Vec<KdQueryResult>> {
67 // Attempt GPU path when conditions are met.
68 #[cfg(feature = "gpu_kdtree")]
69 if tree.n_points() >= config.gpu_threshold_points
70 && queries.len() >= config.gpu_threshold_queries
71 {
72 match super::wgpu_linear_scan::knn_wgpu(tree, queries, k) {
73 Ok(results) => return Ok(results),
74 // Fall through to CPU on any GPU error.
75 Err(_) => {}
76 }
77 }
78
79 // Suppress "unused variable" warning for `config` in non-gpu builds.
80 let _ = config;
81
82 // CPU path: always available.
83 tree.knn_batch(queries, k)
84}
85
86// ---------------------------------------------------------------------------
87// Tests
88// ---------------------------------------------------------------------------
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu_kdtree::GpuKdTree;

    #[test]
    fn test_knn_auto_dispatch_cpu_below_threshold() {
        // Three points are far below the default thresholds, so only the CPU
        // path can serve this request.
        let points: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.5]];
        let tree = GpuKdTree::new(points).expect("build");
        let probe = [vec![0.4_f64, 0.4]];

        let found = knn_auto_dispatch(&tree, &probe, 1, &KdTreeConfig::default())
            .expect("dispatch");

        assert_eq!(found.len(), 1);
        // (0.5, 0.5) is nearest to (0.4, 0.4), with squared distance 0.02.
        assert_eq!(found[0].indices[0], 2);
    }

    #[test]
    fn test_knn_auto_dispatch_returns_correct_count() {
        let points: Vec<Vec<f64>> = (0_u32..10).map(|x| vec![f64::from(x), 0.0]).collect();
        let tree = GpuKdTree::new(points).expect("build");
        let probes: Vec<Vec<f64>> = (0_u32..5)
            .map(|x| vec![f64::from(x) + 0.1, 0.0])
            .collect();
        // Thresholds of usize::MAX are unreachable, which pins the CPU path.
        let cpu_only = KdTreeConfig {
            gpu_threshold_points: usize::MAX,
            gpu_threshold_queries: usize::MAX,
        };

        let found = knn_auto_dispatch(&tree, &probes, 3, &cpu_only).expect("dispatch");

        assert_eq!(found.len(), 5);
        for hit in found.iter() {
            assert_eq!(hit.indices.len(), 3);
            assert_eq!(hit.distances_sq.len(), 3);
        }
    }

    #[test]
    fn test_knn_auto_dispatch_empty_queries() {
        let tree = GpuKdTree::new(vec![vec![1.0_f64, 2.0]]).expect("build");
        let no_queries: &[Vec<f64>] = &[];

        let found = knn_auto_dispatch(&tree, no_queries, 1, &KdTreeConfig::default())
            .expect("dispatch empty queries");

        assert!(found.is_empty());
    }

    #[test]
    fn test_kdtree_config_default_thresholds() {
        let KdTreeConfig {
            gpu_threshold_points,
            gpu_threshold_queries,
        } = KdTreeConfig::default();
        assert_eq!(gpu_threshold_points, 100_000);
        assert_eq!(gpu_threshold_queries, 1_000);
    }
}