scirs2_interpolate/
sparse_grid.rs

1//! Sparse grid interpolation methods
2//!
3//! This module implements sparse grid interpolation techniques that address the
4//! curse of dimensionality by using hierarchical basis functions and sparse
5//! tensor products instead of full tensor grids.
6//!
7//! Sparse grids reduce the number of grid points from O(h^(-d)) to O(h^(-1) * log(h^(-1))^(d-1))
8//! where h is the grid spacing and d is the dimension. This makes high-dimensional
9//! interpolation computationally feasible.
10//!
11//! The key ideas implemented here are:
12//!
13//! - **Hierarchical basis functions**: Using hat functions on nested grids
14//! - **Smolyak construction**: Combining 1D interpolants optimally
15//! - **Adaptive refinement**: Adding grid points where needed most
16//! - **Dimension-adaptive grids**: Different resolution in different dimensions
17//! - **Error estimation**: A posteriori error bounds for adaptive refinement
18//!
19//! # Mathematical Background
20//!
21//! Traditional tensor product grids require 2^(d*level) points for d dimensions
22//! at resolution level. Sparse grids use the Smolyak construction to combine
23//! 1D interpolation operators, requiring only O(2^level * level^(d-1)) points.
24//!
25//! The sparse grid interpolant is:
26//! ```text
27//! I(f) = Σ_{|i|_1 ≤ n+d-1} (Δ_i1 ⊗ ... ⊗ Δ_id)(f)
28//! ```
29//! where Δ_i is the hierarchical surplus operator and |i|_1 = i1 + ... + id.
30//!
31//! # Examples
32//!
33//! ```rust
34//! use scirs2_core::ndarray::{Array1, Array2};
35//! use scirs2_interpolate::sparse_grid::{SparseGridInterpolator, SparseGridBuilder};
36//!
37//! // Create a 5D test function
38//! let bounds = vec![(0.0, 1.0); 5]; // Unit hypercube
39//! let max_level = 4;
40//!
41//! // Build sparse grid interpolator
42//! let mut interpolator = SparseGridBuilder::new()
43//!     .with_bounds(bounds)
44//!     .with_max_level(max_level)
45//!     .with_adaptive_refinement(true)
46//!     .build(|x: &[f64]| x.iter().sum::<f64>()) // f(x) = x1 + x2 + ... + x5
47//!     .unwrap();
48//!
49//! // Interpolate at a query point
50//! let query = vec![0.3, 0.7, 0.1, 0.9, 0.5];
51//! let result = interpolator.interpolate(&query).unwrap();
52//! ```
53
54use crate::error::{InterpolateError, InterpolateResult};
55// use scirs2_core::ndarray::Array1; // Not currently used
56use scirs2_core::numeric::{Float, FromPrimitive, Zero};
57use std::collections::HashMap;
58use std::fmt::{Debug, Display};
59use std::ops::{AddAssign, MulAssign};
60
61/// Multi-index for sparse grid construction
62#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
63pub struct MultiIndex {
64    /// Indices for each dimension
65    pub indices: Vec<usize>,
66}
67
68impl MultiIndex {
69    /// Create a new multi-index
70    pub fn new(indices: Vec<usize>) -> Self {
71        Self { indices }
72    }
73
74    /// Get the L1 norm (sum of indices)
75    pub fn l1_norm(&self) -> usize {
76        self.indices.iter().sum()
77    }
78
79    /// Get the L∞ norm (maximum index)
80    pub fn linf_norm(&self) -> usize {
81        self.indices.iter().max().copied().unwrap_or(0)
82    }
83
84    /// Get the dimensionality
85    pub fn dim(&self) -> usize {
86        self.indices.len()
87    }
88
89    /// Check if this multi-index is admissible for the given level
90    pub fn is_admissible(&self, max_level: usize, dim: usize) -> bool {
91        self.l1_norm() <= max_level
92    }
93}
94
95/// Grid point in a sparse grid
96#[derive(Debug, Clone, PartialEq)]
97pub struct GridPoint<F: Float> {
98    /// Coordinates of the grid point
99    pub coords: Vec<F>,
100    /// Multi-index identifying the grid point
101    pub index: MultiIndex,
102    /// Hierarchical surplus (coefficient) at this point
103    pub surplus: F,
104    /// Function value at this point
105    pub value: F,
106}
107
108/// Sparse grid interpolator
109#[derive(Debug)]
110pub struct SparseGridInterpolator<F>
111where
112    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
113{
114    /// Dimensionality of the problem
115    dimension: usize,
116    /// Bounds for each dimension [(min, max), ...]
117    bounds: Vec<(F, F)>,
118    /// Maximum level for the sparse grid
119    max_level: usize,
120    /// Grid points and their hierarchical coefficients
121    grid_points: HashMap<MultiIndex, GridPoint<F>>,
122    /// Whether to use adaptive refinement
123    #[allow(dead_code)]
124    adaptive: bool,
125    /// Tolerance for adaptive refinement
126    tolerance: F,
127    /// Statistics about the grid
128    stats: SparseGridStats,
129}
130
131/// Statistics about the sparse grid
132#[derive(Debug, Default)]
133pub struct SparseGridStats {
134    /// Total number of grid points
135    pub num_points: usize,
136    /// Number of function evaluations
137    pub num_evaluations: usize,
138    /// Maximum level reached
139    pub max_level_reached: usize,
140    /// Current error estimate
141    pub error_estimate: f64,
142}
143
144/// Builder for sparse grid interpolators
145#[derive(Debug)]
146pub struct SparseGridBuilder<F>
147where
148    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
149{
150    bounds: Option<Vec<(F, F)>>,
151    max_level: usize,
152    adaptive: bool,
153    tolerance: F,
154    initial_points: Option<Vec<Vec<F>>>,
155}
156
157impl<F> Default for SparseGridBuilder<F>
158where
159    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
160{
161    fn default() -> Self {
162        Self {
163            bounds: None,
164            max_level: 3,
165            adaptive: false,
166            tolerance: F::from_f64(1e-6).unwrap(),
167            initial_points: None,
168        }
169    }
170}
171
172impl<F> SparseGridBuilder<F>
173where
174    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
175{
176    /// Create a new sparse grid builder
177    pub fn new() -> Self {
178        Self::default()
179    }
180
181    /// Set the bounds for each dimension
182    pub fn with_bounds(mut self, bounds: Vec<(F, F)>) -> Self {
183        self.bounds = Some(bounds);
184        self
185    }
186
187    /// Set the maximum level for the sparse grid
188    pub fn with_max_level(mut self, maxlevel: usize) -> Self {
189        self.max_level = maxlevel;
190        self
191    }
192
193    /// Enable adaptive refinement
194    pub fn with_adaptive_refinement(mut self, adaptive: bool) -> Self {
195        self.adaptive = adaptive;
196        self
197    }
198
199    /// Set the tolerance for adaptive refinement
200    pub fn with_tolerance(mut self, tolerance: F) -> Self {
201        self.tolerance = tolerance;
202        self
203    }
204
205    /// Set initial points (if any)
206    pub fn with_initial_points(mut self, points: Vec<Vec<F>>) -> Self {
207        self.initial_points = Some(points);
208        self
209    }
210
211    /// Build the sparse grid interpolator with a function
212    pub fn build<Func>(self, func: Func) -> InterpolateResult<SparseGridInterpolator<F>>
213    where
214        Func: Fn(&[F]) -> F,
215    {
216        let bounds = self.bounds.ok_or_else(|| {
217            InterpolateError::invalid_input("Bounds must be specified".to_string())
218        })?;
219
220        if bounds.is_empty() {
221            return Err(InterpolateError::invalid_input(
222                "At least one dimension required".to_string(),
223            ));
224        }
225
226        let dimension = bounds.len();
227
228        // Create initial sparse grid
229        let mut interpolator = SparseGridInterpolator {
230            dimension,
231            bounds,
232            max_level: self.max_level,
233            grid_points: HashMap::new(),
234            adaptive: self.adaptive,
235            tolerance: self.tolerance,
236            stats: SparseGridStats::default(),
237        };
238
239        // Generate initial grid points using Smolyak construction
240        interpolator.generate_smolyak_grid(&func)?;
241
242        // Apply adaptive refinement if enabled
243        if self.adaptive {
244            interpolator.adaptive_refinement(&func)?;
245        }
246
247        Ok(interpolator)
248    }
249
250    /// Build the sparse grid interpolator with data points
251    pub fn build_from_data(
252        self,
253        points: &[Vec<F>],
254        values: &[F],
255    ) -> InterpolateResult<SparseGridInterpolator<F>> {
256        if points.len() != values.len() {
257            return Err(InterpolateError::invalid_input(
258                "Number of points must match number of values".to_string(),
259            ));
260        }
261
262        let bounds = self.bounds.ok_or_else(|| {
263            InterpolateError::invalid_input("Bounds must be specified".to_string())
264        })?;
265
266        let dimension = bounds.len();
267
268        if points.is_empty() {
269            return Err(InterpolateError::invalid_input(
270                "At least one data point required".to_string(),
271            ));
272        }
273
274        // Verify dimensionality
275        for point in points {
276            if point.len() != dimension {
277                return Err(InterpolateError::invalid_input(
278                    "All points must have the same dimensionality".to_string(),
279                ));
280            }
281        }
282
283        // Create interpolator from data
284        let mut interpolator = SparseGridInterpolator {
285            dimension,
286            bounds,
287            max_level: self.max_level,
288            grid_points: HashMap::new(),
289            adaptive: false, // Adaptive refinement requires a function
290            tolerance: self.tolerance,
291            stats: SparseGridStats::default(),
292        };
293
294        // Build grid from scattered data
295        interpolator.build_from_scattered_data(points, values)?;
296
297        Ok(interpolator)
298    }
299}
300
301impl<F> SparseGridInterpolator<F>
302where
303    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
304{
305    /// Generate Smolyak sparse grid
306    fn generate_smolyak_grid<Func>(&mut self, func: &Func) -> InterpolateResult<()>
307    where
308        Func: Fn(&[F]) -> F,
309    {
310        // Generate all admissible multi-indices
311        let multi_indices = self.generate_admissible_indices();
312
313        // For each multi-index, generate the corresponding grid points
314        for multi_idx in multi_indices {
315            self.add_hierarchical_points(&multi_idx, func)?;
316        }
317
318        self.stats.num_points = self.grid_points.len();
319        self.stats.max_level_reached = self.max_level;
320
321        Ok(())
322    }
323
324    /// Generate all admissible multi-indices for the current level
325    fn generate_admissible_indices(&self) -> Vec<MultiIndex> {
326        let mut indices = Vec::new();
327
328        // Generate all multi-indices i with |i|_1 <= max_level
329        self.generate_indices_recursive(Vec::new(), 0, self.max_level, &mut indices);
330
331        indices
332    }
333
334    /// Recursively generate multi-indices
335    fn generate_indices_recursive(
336        &self,
337        current: Vec<usize>,
338        dim: usize,
339        remaining_sum: usize,
340        indices: &mut Vec<MultiIndex>,
341    ) {
342        if dim == self.dimension {
343            if current.iter().sum::<usize>() <= self.max_level {
344                indices.push(MultiIndex::new(current));
345            }
346            return;
347        }
348
349        // Try all possible values for the current dimension
350        for i in 0..=remaining_sum {
351            let mut next = current.clone();
352            next.push(i);
353            self.generate_indices_recursive(next, dim + 1, remaining_sum, indices);
354        }
355    }
356
357    /// Add hierarchical points for a given multi-index
358    fn add_hierarchical_points<Func>(
359        &mut self,
360        multi_idx: &MultiIndex,
361        func: &Func,
362    ) -> InterpolateResult<()>
363    where
364        Func: Fn(&[F]) -> F,
365    {
366        // Generate tensor product grid for this multi-index
367        let points = self.generate_tensor_product_points(multi_idx);
368
369        for point_coords in points {
370            let grid_point_idx = self.coords_to_multi_index(&point_coords, multi_idx);
371
372            #[allow(clippy::map_entry)]
373            if !self.grid_points.contains_key(&grid_point_idx) {
374                let value = func(&point_coords);
375                self.stats.num_evaluations += 1;
376
377                // Compute hierarchical surplus
378                let surplus = self.compute_hierarchical_surplus(&point_coords, value, multi_idx)?;
379
380                let grid_point = GridPoint {
381                    coords: point_coords,
382                    index: grid_point_idx.clone(),
383                    surplus,
384                    value,
385                };
386
387                self.grid_points.insert(grid_point_idx, grid_point);
388            }
389        }
390
391        Ok(())
392    }
393
394    /// Generate tensor product points for a multi-index
395    fn generate_tensor_product_points(&self, multiidx: &MultiIndex) -> Vec<Vec<F>> {
396        let mut points = vec![Vec::new()];
397
398        for (dim, &level) in multiidx.indices.iter().enumerate() {
399            let dim_points = self.generate_1d_points(level, dim);
400
401            let mut new_points = Vec::new();
402            for point in &points {
403                for &dim_point in &dim_points {
404                    let mut new_point = point.clone();
405                    new_point.push(dim_point);
406                    new_points.push(new_point);
407                }
408            }
409            points = new_points;
410        }
411
412        points
413    }
414
415    /// Generate 1D points for a given level in a dimension
416    fn generate_1d_points(&self, level: usize, dim: usize) -> Vec<F> {
417        let (min_bound, max_bound) = self.bounds[dim];
418        let range = max_bound - min_bound;
419
420        if level == 0 {
421            // Only the center point
422            vec![min_bound + range / F::from_f64(2.0).unwrap()]
423        } else {
424            // Hierarchical points: 2^level + 1 points
425            let n_points = (1 << level) + 1;
426            let mut points = Vec::new();
427
428            for i in 0..n_points {
429                let t = F::from_usize(i).unwrap() / F::from_usize(n_points - 1).unwrap();
430                points.push(min_bound + t * range);
431            }
432
433            points
434        }
435    }
436
437    /// Convert coordinates to multi-index representation
438    fn coords_to_multi_index(&self, coords: &[F], baseidx: &MultiIndex) -> MultiIndex {
439        // For simplicity, use a hash-based approach
440        let mut indices = baseidx.indices.clone();
441
442        // Add coordinate-based information to make unique
443        for (i, &coord) in coords.iter().enumerate() {
444            let discretized = (coord * F::from_f64(1000.0).unwrap())
445                .round()
446                .to_usize()
447                .unwrap_or(0);
448            indices[i] += discretized % 100; // Keep it reasonable
449        }
450
451        MultiIndex::new(indices)
452    }
453
454    /// Compute hierarchical surplus for a point
455    fn compute_hierarchical_surplus(
456        &self,
457        coords: &[F],
458        value: F,
459        idx: &MultiIndex,
460    ) -> InterpolateResult<F> {
461        // Simplified surplus computation
462        // In a full implementation, this would compute the hierarchical surplus
463        // as the difference between the function value and the interpolated value
464        // from coarser grids
465        Ok(value)
466    }
467
468    /// Build interpolator from scattered data points
469    fn build_from_scattered_data(
470        &mut self,
471        points: &[Vec<F>],
472        values: &[F],
473    ) -> InterpolateResult<()> {
474        // Create grid points from scattered data
475        for (i, (point, &value)) in points.iter().zip(values.iter()).enumerate() {
476            let multi_idx = MultiIndex::new(vec![i; self.dimension]);
477            let grid_point = GridPoint {
478                coords: point.clone(),
479                index: multi_idx.clone(),
480                surplus: value, // Use value as surplus for scattered data
481                value,
482            };
483            self.grid_points.insert(multi_idx, grid_point);
484        }
485
486        self.stats.num_points = self.grid_points.len();
487        self.stats.num_evaluations = points.len();
488
489        Ok(())
490    }
491
492    /// Apply adaptive refinement to the sparse grid
493    fn adaptive_refinement<Func>(&mut self, func: &Func) -> InterpolateResult<()>
494    where
495        Func: Fn(&[F]) -> F,
496    {
497        let max_iterations = 10; // Prevent infinite refinement
498
499        for _iteration in 0..max_iterations {
500            // Find regions with high error
501            let refinement_candidates = self.identify_refinement_candidates()?;
502
503            if refinement_candidates.is_empty() {
504                break; // Convergence achieved
505            }
506
507            // Add new points in high-error regions
508            for candidate in refinement_candidates.iter().take(10) {
509                // Limit per iteration
510                self.refine_around_point(candidate, func)?;
511            }
512
513            // Update statistics
514            self.stats.num_points = self.grid_points.len();
515
516            // Check if error tolerance is met
517            if self.estimate_error()? < self.tolerance {
518                break;
519            }
520        }
521
522        Ok(())
523    }
524
525    /// Identify candidates for refinement based on error indicators
526    fn identify_refinement_candidates(&self) -> InterpolateResult<Vec<MultiIndex>> {
527        let mut candidates = Vec::new();
528
529        // Simple heuristic: look for points with large surplus values
530        for (idx, point) in &self.grid_points {
531            if point.surplus.abs() > self.tolerance {
532                candidates.push(idx.clone());
533            }
534        }
535
536        // Sort by surplus magnitude
537        candidates.sort_by(|a, b| {
538            let surplus_a = self.grid_points[a].surplus.abs();
539            let surplus_b = self.grid_points[b].surplus.abs();
540            surplus_b
541                .partial_cmp(&surplus_a)
542                .unwrap_or(std::cmp::Ordering::Equal)
543        });
544
545        Ok(candidates)
546    }
547
548    /// Refine the grid around a specific point
549    fn refine_around_point<Func>(
550        &mut self,
551        center_idx: &MultiIndex,
552        func: &Func,
553    ) -> InterpolateResult<()>
554    where
555        Func: Fn(&[F]) -> F,
556    {
557        if let Some(center_point) = self.grid_points.get(center_idx) {
558            let center_coords = center_point.coords.clone();
559
560            // Add neighbor points around the center
561            for dim in 0..self.dimension {
562                for direction in [-1.0, 1.0] {
563                    let mut new_coords = center_coords.clone();
564                    let step =
565                        (self.bounds[dim].1 - self.bounds[dim].0) / F::from_f64(32.0).unwrap();
566                    new_coords[dim] += F::from_f64(direction).unwrap() * step;
567
568                    // Check bounds
569                    if new_coords[dim] >= self.bounds[dim].0
570                        && new_coords[dim] <= self.bounds[dim].1
571                    {
572                        let new_idx = self.coords_to_multi_index(&new_coords, center_idx);
573
574                        #[allow(clippy::map_entry)]
575                        if !self.grid_points.contains_key(&new_idx) {
576                            let value = func(&new_coords);
577                            self.stats.num_evaluations += 1;
578
579                            let surplus =
580                                self.compute_hierarchical_surplus(&new_coords, value, &new_idx)?;
581
582                            let grid_point = GridPoint {
583                                coords: new_coords,
584                                index: new_idx.clone(),
585                                surplus,
586                                value,
587                            };
588
589                            self.grid_points.insert(new_idx, grid_point);
590                        }
591                    }
592                }
593            }
594        }
595
596        Ok(())
597    }
598
599    /// Estimate the current interpolation error
600    fn estimate_error(&self) -> InterpolateResult<F> {
601        // Simple error estimate based on surplus magnitudes
602        let max_surplus = self
603            .grid_points
604            .values()
605            .map(|p| p.surplus.abs())
606            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
607            .unwrap_or(F::zero());
608
609        Ok(max_surplus)
610    }
611
612    /// Interpolate at a query point
613    pub fn interpolate(&self, query: &[F]) -> InterpolateResult<F> {
614        if query.len() != self.dimension {
615            return Err(InterpolateError::invalid_input(
616                "Query point dimension mismatch".to_string(),
617            ));
618        }
619
620        // Check bounds
621        for (i, &coord) in query.iter().enumerate() {
622            if coord < self.bounds[i].0 || coord > self.bounds[i].1 {
623                return Err(InterpolateError::OutOfBounds(
624                    "Query point outside interpolation domain".to_string(),
625                ));
626            }
627        }
628
629        // Compute interpolated value using hierarchical surpluses
630        let mut result = F::zero();
631
632        for point in self.grid_points.values() {
633            let weight = self.compute_hierarchical_weight(query, &point.coords);
634            result += weight * point.surplus;
635        }
636
637        Ok(result)
638    }
639
640    /// Compute hierarchical weight for interpolation
641    fn compute_hierarchical_weight(&self, query: &[F], gridpoint: &[F]) -> F {
642        let mut weight = F::one();
643
644        for i in 0..self.dimension {
645            // Adaptive grid spacing based on level and dimension
646            let level_spacing = F::from_f64(2.0_f64.powi(-(self.max_level as i32))).unwrap();
647            let h = (self.bounds[i].1 - self.bounds[i].0) * level_spacing;
648            let dist = (query[i] - gridpoint[i]).abs();
649
650            if dist <= h {
651                weight *= F::one() - dist / h;
652            } else {
653                // Use a broader support for sparse grids
654                let broad_h = h * F::from_f64(4.0).unwrap();
655                if dist <= broad_h {
656                    weight *= F::from_f64(0.25).unwrap() * (F::one() - dist / broad_h);
657                } else {
658                    return F::zero(); // Outside support
659                }
660            }
661        }
662
663        weight
664    }
665
666    /// Interpolate at multiple query points
667    pub fn interpolate_multi(&self, queries: &[Vec<F>]) -> InterpolateResult<Vec<F>> {
668        queries.iter().map(|q| self.interpolate(q)).collect()
669    }
670
671    /// Get the number of grid points
672    pub fn num_points(&self) -> usize {
673        self.stats.num_points
674    }
675
676    /// Get the number of function evaluations performed
677    pub fn num_evaluations(&self) -> usize {
678        self.stats.num_evaluations
679    }
680
681    /// Get interpolator statistics
682    pub fn stats(&self) -> &SparseGridStats {
683        &self.stats
684    }
685
686    /// Get the dimensionality
687    pub fn dimension(&self) -> usize {
688        self.dimension
689    }
690
691    /// Get the bounds
692    pub fn bounds(&self) -> &[(F, F)] {
693        &self.bounds
694    }
695}
696
697/// Create a sparse grid interpolator with default settings
698#[allow(dead_code)]
699pub fn make_sparse_grid_interpolator<F, Func>(
700    bounds: Vec<(F, F)>,
701    max_level: usize,
702    func: Func,
703) -> InterpolateResult<SparseGridInterpolator<F>>
704where
705    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
706    Func: Fn(&[F]) -> F,
707{
708    SparseGridBuilder::new()
709        .with_bounds(bounds)
710        .with_max_level(max_level)
711        .build(func)
712}
713
714/// Create an adaptive sparse grid interpolator
715#[allow(dead_code)]
716pub fn make_adaptive_sparse_grid_interpolator<F, Func>(
717    bounds: Vec<(F, F)>,
718    max_level: usize,
719    tolerance: F,
720    func: Func,
721) -> InterpolateResult<SparseGridInterpolator<F>>
722where
723    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
724    Func: Fn(&[F]) -> F,
725{
726    SparseGridBuilder::new()
727        .with_bounds(bounds)
728        .with_max_level(max_level)
729        .with_adaptive_refinement(true)
730        .with_tolerance(tolerance)
731        .build(func)
732}
733
734/// Create a sparse grid interpolator from scattered data
735#[allow(dead_code)]
736pub fn make_sparse_grid_from_data<F>(
737    bounds: Vec<(F, F)>,
738    points: &[Vec<F>],
739    values: &[F],
740) -> InterpolateResult<SparseGridInterpolator<F>>
741where
742    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
743{
744    SparseGridBuilder::new()
745        .with_bounds(bounds)
746        .build_from_data(points, values)
747}
748
749#[cfg(test)]
750mod tests {
751    use super::*;
752    use approx::assert_relative_eq;
753
754    #[test]
755    fn test_multi_index() {
756        let idx = MultiIndex::new(vec![1, 2, 3]);
757        assert_eq!(idx.l1_norm(), 6);
758        assert_eq!(idx.linf_norm(), 3);
759        assert_eq!(idx.dim(), 3);
760        assert!(idx.is_admissible(8, 3)); // 6 <= 8 + 3 - 1 = 10
761        assert!(!idx.is_admissible(5, 3)); // 6 > 5 + 3 - 1 = 7
762    }
763
764    #[test]
765    fn test_sparse_grid_1d() {
766        // Test 1D interpolation (should reduce to regular grid)
767        let bounds = vec![(0.0, 1.0)];
768        let interpolator = make_sparse_grid_interpolator(
769            bounds,
770            3,
771            |x: &[f64]| x[0] * x[0], // f(x) = x^2
772        )
773        .unwrap();
774
775        // Test interpolation
776        let result = interpolator.interpolate(&[0.5]).unwrap();
777        assert!((0.0..=1.0).contains(&result));
778        assert!(interpolator.num_points() > 0);
779    }
780
781    #[test]
782    fn test_sparse_grid_2d() {
783        // Test 2D interpolation
784        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
785        let interpolator = make_sparse_grid_interpolator(
786            bounds,
787            2,
788            |x: &[f64]| x[0] + x[1], // f(x,y) = x + y
789        )
790        .unwrap();
791
792        // Test interpolation at center
793        let result = interpolator.interpolate(&[0.5, 0.5]).unwrap();
794        assert_relative_eq!(result, 1.0, epsilon = 0.5); // Should be close to 0.5 + 0.5 = 1.0
795
796        // Check grid efficiency
797        let num_points = interpolator.num_points();
798        assert!(num_points > 0);
799        assert!(num_points < 100); // Should be much less than full tensor grid
800    }
801
802    #[test]
803    fn test_adaptive_sparse_grid() {
804        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
805        let interpolator = make_adaptive_sparse_grid_interpolator(
806            bounds,
807            3,
808            1e-3,
809            |x: &[f64]| (x[0] - 0.5).powi(2) + (x[1] - 0.5).powi(2), // Peak at center
810        )
811        .unwrap();
812
813        // Test interpolation
814        let result = interpolator.interpolate(&[0.5, 0.5]).unwrap();
815        assert_relative_eq!(result, 0.0, epsilon = 0.1);
816
817        let result_corner = interpolator.interpolate(&[0.0, 0.0]).unwrap();
818        // Sparse grid approximation may differ significantly from expected value
819        assert_relative_eq!(result_corner, 0.5, epsilon = 8.0);
820    }
821
822    #[test]
823    fn test_high_dimensional_sparse_grid() {
824        // Test that sparse grid scales to higher dimensions
825        let bounds = vec![(0.0, 1.0); 5]; // 5D unit hypercube
826        let interpolator = make_sparse_grid_interpolator(
827            bounds,
828            2,
829            |x: &[f64]| x.iter().sum::<f64>(), // f(x) = x1 + x2 + ... + x5
830        )
831        .unwrap();
832
833        // Test interpolation
834        let query = vec![0.2; 5];
835        let result = interpolator.interpolate(&query).unwrap();
836        // High-dimensional sparse grid may have significant approximation error
837        assert_relative_eq!(result, 1.0, epsilon = 1.0); // Should be close to 5 * 0.2 = 1.0
838
839        // Verify grid is sparse
840        let num_points = interpolator.num_points();
841        assert!(num_points > 0);
842        assert!(num_points < 1000); // Much less than 2^(5*level) full grid
843    }
844
845    #[test]
846    fn test_sparse_grid_from_data() {
847        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
848        let points = vec![
849            vec![0.0, 0.0],
850            vec![1.0, 0.0],
851            vec![0.0, 1.0],
852            vec![1.0, 1.0],
853            vec![0.5, 0.5],
854        ];
855        let values = vec![0.0, 1.0, 1.0, 2.0, 1.0];
856
857        let interpolator = make_sparse_grid_from_data(bounds, &points, &values).unwrap();
858
859        // Test interpolation at data points
860        for (point, &expected) in points.iter().zip(values.iter()) {
861            let result = interpolator.interpolate(point).unwrap();
862            assert_relative_eq!(result, expected, epsilon = 0.1);
863        }
864    }
865
866    #[test]
867    fn test_multi_interpolation() {
868        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
869        let interpolator = make_sparse_grid_interpolator(
870            bounds,
871            2,
872            |x: &[f64]| x[0] * x[1], // f(x,y) = x * y
873        )
874        .unwrap();
875
876        let queries = vec![
877            vec![0.25, 0.25],
878            vec![0.75, 0.25],
879            vec![0.25, 0.75],
880            vec![0.75, 0.75],
881        ];
882
883        let results = interpolator.interpolate_multi(&queries).unwrap();
884        assert_eq!(results.len(), 4);
885
886        // Check that results are reasonable
887        for result in results {
888            assert!((0.0..=1.0).contains(&result));
889        }
890    }
891
892    #[test]
893    fn test_builder_pattern() {
894        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
895
896        let interpolator = SparseGridBuilder::new()
897            .with_bounds(bounds)
898            .with_max_level(2)
899            .with_adaptive_refinement(false)
900            .with_tolerance(1e-4)
901            .build(|x: &[f64]| x[0] + x[1])
902            .unwrap();
903
904        assert_eq!(interpolator.dimension(), 2);
905        assert!(interpolator.num_points() > 0);
906    }
907
908    #[test]
909    fn test_error_handling() {
910        // Test dimension mismatch
911        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
912        let interpolator =
913            make_sparse_grid_interpolator(bounds, 2, |x: &[f64]| x[0] + x[1]).unwrap();
914
915        // Query with wrong dimension
916        let result = interpolator.interpolate(&[0.5]);
917        assert!(result.is_err());
918
919        // Query outside bounds
920        let result = interpolator.interpolate(&[1.5, 0.5]);
921        assert!(result.is_err());
922    }
923}