Skip to main content

scirs2_interpolate/sparse_grid/
core.rs

1//! Sparse grid interpolation methods
2//!
3//! This module implements sparse grid interpolation techniques that address the
4//! curse of dimensionality by using hierarchical basis functions and sparse
5//! tensor products instead of full tensor grids.
6//!
7//! Sparse grids reduce the number of grid points from O(h^(-d)) to O(h^(-1) * log(h^(-1))^(d-1))
8//! where h is the grid spacing and d is the dimension. This makes high-dimensional
9//! interpolation computationally feasible.
10//!
11//! The key ideas implemented here are:
12//!
13//! - **Hierarchical basis functions**: Using hat functions on nested grids
14//! - **Smolyak construction**: Combining 1D interpolants optimally
15//! - **Adaptive refinement**: Adding grid points where needed most
16//! - **Dimension-adaptive grids**: Different resolution in different dimensions
17//! - **Error estimation**: A posteriori error bounds for adaptive refinement
18//!
19//! # Mathematical Background
20//!
21//! Traditional tensor product grids require 2^(d*level) points for d dimensions
22//! at resolution level. Sparse grids use the Smolyak construction to combine
23//! 1D interpolation operators, requiring only O(2^level * level^(d-1)) points.
24//!
25//! The sparse grid interpolant is:
26//! ```text
27//! I(f) = Σ_{|i|_1 ≤ n+d-1} (Δ_i1 ⊗ ... ⊗ Δ_id)(f)
28//! ```
29//! where Δ_i is the hierarchical surplus operator and |i|_1 = i1 + ... + id.
30//!
31//! # Examples
32//!
33//! ```rust
34//! use scirs2_core::ndarray::{Array1, Array2};
35//! use scirs2_interpolate::sparse_grid::{SparseGridInterpolator, SparseGridBuilder};
36//!
37//! // Create a 5D test function
38//! let bounds = vec![(0.0, 1.0); 5]; // Unit hypercube
39//! let max_level = 4;
40//!
41//! // Build sparse grid interpolator
42//! let mut interpolator = SparseGridBuilder::new()
43//!     .with_bounds(bounds)
44//!     .with_max_level(max_level)
45//!     .with_adaptive_refinement(true)
46//!     .build(|x: &[f64]| x.iter().sum::<f64>()) // f(x) = x1 + x2 + ... + x5
47//!     .expect("Operation failed");
48//!
49//! // Interpolate at a query point
50//! let query = vec![0.3, 0.7, 0.1, 0.9, 0.5];
51//! let result = interpolator.interpolate(&query).expect("Operation failed");
52//! ```
53
54use crate::error::{InterpolateError, InterpolateResult};
55// use scirs2_core::ndarray::Array1; // Not currently used
56use scirs2_core::numeric::{Float, FromPrimitive, Zero};
57use std::collections::HashMap;
58use std::fmt::{Debug, Display};
59use std::ops::{AddAssign, MulAssign};
60
/// Multi-index for sparse grid construction.
///
/// Holds one non-negative level per dimension. As used by `SparseGridInterpolator`,
/// levels start at 0 (level 0 is the single midpoint of a dimension).
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct MultiIndex {
    /// Indices for each dimension
    pub indices: Vec<usize>,
}

impl MultiIndex {
    /// Creates a multi-index from its per-dimension indices.
    pub fn new(indices: Vec<usize>) -> Self {
        Self { indices }
    }

    /// Returns the L1 norm (sum of all indices); `0` for an empty multi-index.
    pub fn l1_norm(&self) -> usize {
        self.indices.iter().sum()
    }

    /// Returns the L∞ norm (largest index); `0` for an empty multi-index.
    pub fn linf_norm(&self) -> usize {
        self.indices.iter().copied().max().unwrap_or(0)
    }

    /// Returns the number of dimensions (entries) in this multi-index.
    pub fn dim(&self) -> usize {
        self.indices.len()
    }

    /// Returns `true` when this multi-index lies inside the sparse-grid index set,
    /// i.e. when `|i|_1 <= max_level`.
    ///
    /// With 0-based levels the classical Smolyak condition `|i|_1 <= n + d - 1`
    /// (1-based levels, `n = max_level + 1`) reduces to a bound that does not depend
    /// on the dimension. The `_dim` argument is therefore unused; it is kept only so
    /// existing callers continue to compile.
    pub fn is_admissible(&self, max_level: usize, _dim: usize) -> bool {
        self.l1_norm() <= max_level
    }
}
94
/// A single sampled point of a sparse grid.
///
/// Instances are created by `SparseGridInterpolator` while building the grid and
/// are stored in its map keyed by `index`.
#[derive(Debug, Clone, PartialEq)]
pub struct GridPoint<F: Float> {
    /// Coordinates of the grid point, in the original (unscaled) domain
    pub coords: Vec<F>,
    /// Multi-index identifying the grid point (also its key in the interpolator's map)
    pub index: MultiIndex,
    /// Hierarchical surplus (coefficient) at this point.
    ///
    /// NOTE(review): the current construction code stores the raw function/data
    /// value here rather than a true hierarchical surplus.
    pub surplus: F,
    /// Function value at this point
    pub value: F,
}
107
/// Sparse grid interpolator over an axis-aligned box.
///
/// Build one with `SparseGridBuilder` (or the `make_*` helper functions) and
/// evaluate it with `interpolate` / `interpolate_multi`.
#[derive(Debug)]
pub struct SparseGridInterpolator<F>
where
    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
{
    /// Dimensionality of the problem (the number of `bounds` entries)
    dimension: usize,
    /// Bounds for each dimension [(min, max), ...]
    bounds: Vec<(F, F)>,
    /// Maximum level for the sparse grid; also sets the support width used when
    /// evaluating the interpolant
    max_level: usize,
    /// Grid points and their hierarchical coefficients, keyed by multi-index
    grid_points: HashMap<MultiIndex, GridPoint<F>>,
    /// Whether to use adaptive refinement.
    ///
    /// Recorded for reference only: the builder decides whether to refine using its
    /// own flag and no method reads this field, hence `dead_code` is allowed.
    #[allow(dead_code)]
    adaptive: bool,
    /// Tolerance for adaptive refinement (surplus threshold and stopping criterion)
    tolerance: F,
    /// Statistics about the grid
    stats: SparseGridStats,
}
130
/// Statistics about a sparse grid, as reported by `SparseGridInterpolator::stats`.
///
/// `Clone` and `PartialEq` are derived so callers can snapshot and compare the
/// statistics they receive by reference.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct SparseGridStats {
    /// Total number of grid points
    pub num_points: usize,
    /// Number of function evaluations (for grids built from data: the number of data points)
    pub num_evaluations: usize,
    /// Maximum level reached
    pub max_level_reached: usize,
    /// Current error estimate; stays at the default `0.0` when no estimate has been recorded
    pub error_estimate: f64,
}
143
/// Builder for `SparseGridInterpolator`.
///
/// Bounds are mandatory; every other setting has a default (see the `Default` impl).
#[derive(Debug)]
pub struct SparseGridBuilder<F>
where
    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
{
    /// Per-dimension `(min, max)` bounds; `None` until `with_bounds` is called
    bounds: Option<Vec<(F, F)>>,
    /// Maximum Smolyak level
    max_level: usize,
    /// Whether `build` runs adaptive refinement after the initial grid
    adaptive: bool,
    /// Tolerance for adaptive refinement
    tolerance: F,
    /// Optional initial points.
    ///
    /// NOTE(review): stored by `with_initial_points` but not read by `build` or
    /// `build_from_data` in this module.
    initial_points: Option<Vec<Vec<F>>>,
}
156
157impl<F> Default for SparseGridBuilder<F>
158where
159    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
160{
161    fn default() -> Self {
162        Self {
163            bounds: None,
164            max_level: 3,
165            adaptive: false,
166            tolerance: F::from_f64(1e-6).expect("Operation failed"),
167            initial_points: None,
168        }
169    }
170}
171
172impl<F> SparseGridBuilder<F>
173where
174    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
175{
176    /// Create a new sparse grid builder
177    pub fn new() -> Self {
178        Self::default()
179    }
180
181    /// Set the bounds for each dimension
182    pub fn with_bounds(mut self, bounds: Vec<(F, F)>) -> Self {
183        self.bounds = Some(bounds);
184        self
185    }
186
187    /// Set the maximum level for the sparse grid
188    pub fn with_max_level(mut self, maxlevel: usize) -> Self {
189        self.max_level = maxlevel;
190        self
191    }
192
193    /// Enable adaptive refinement
194    pub fn with_adaptive_refinement(mut self, adaptive: bool) -> Self {
195        self.adaptive = adaptive;
196        self
197    }
198
199    /// Set the tolerance for adaptive refinement
200    pub fn with_tolerance(mut self, tolerance: F) -> Self {
201        self.tolerance = tolerance;
202        self
203    }
204
205    /// Set initial points (if any)
206    pub fn with_initial_points(mut self, points: Vec<Vec<F>>) -> Self {
207        self.initial_points = Some(points);
208        self
209    }
210
211    /// Build the sparse grid interpolator with a function
212    pub fn build<Func>(self, func: Func) -> InterpolateResult<SparseGridInterpolator<F>>
213    where
214        Func: Fn(&[F]) -> F,
215    {
216        let bounds = self.bounds.ok_or_else(|| {
217            InterpolateError::invalid_input("Bounds must be specified".to_string())
218        })?;
219
220        if bounds.is_empty() {
221            return Err(InterpolateError::invalid_input(
222                "At least one dimension required".to_string(),
223            ));
224        }
225
226        let dimension = bounds.len();
227
228        // Create initial sparse grid
229        let mut interpolator = SparseGridInterpolator {
230            dimension,
231            bounds,
232            max_level: self.max_level,
233            grid_points: HashMap::new(),
234            adaptive: self.adaptive,
235            tolerance: self.tolerance,
236            stats: SparseGridStats::default(),
237        };
238
239        // Generate initial grid points using Smolyak construction
240        interpolator.generate_smolyak_grid(&func)?;
241
242        // Apply adaptive refinement if enabled
243        if self.adaptive {
244            interpolator.adaptive_refinement(&func)?;
245        }
246
247        Ok(interpolator)
248    }
249
250    /// Build the sparse grid interpolator with data points
251    pub fn build_from_data(
252        self,
253        points: &[Vec<F>],
254        values: &[F],
255    ) -> InterpolateResult<SparseGridInterpolator<F>> {
256        if points.len() != values.len() {
257            return Err(InterpolateError::invalid_input(
258                "Number of points must match number of values".to_string(),
259            ));
260        }
261
262        let bounds = self.bounds.ok_or_else(|| {
263            InterpolateError::invalid_input("Bounds must be specified".to_string())
264        })?;
265
266        let dimension = bounds.len();
267
268        if points.is_empty() {
269            return Err(InterpolateError::invalid_input(
270                "At least one data point required".to_string(),
271            ));
272        }
273
274        // Verify dimensionality
275        for point in points {
276            if point.len() != dimension {
277                return Err(InterpolateError::invalid_input(
278                    "All points must have the same dimensionality".to_string(),
279                ));
280            }
281        }
282
283        // Create interpolator from data
284        let mut interpolator = SparseGridInterpolator {
285            dimension,
286            bounds,
287            max_level: self.max_level,
288            grid_points: HashMap::new(),
289            adaptive: false, // Adaptive refinement requires a function
290            tolerance: self.tolerance,
291            stats: SparseGridStats::default(),
292        };
293
294        // Build grid from scattered data
295        interpolator.build_from_scattered_data(points, values)?;
296
297        Ok(interpolator)
298    }
299}
300
301impl<F> SparseGridInterpolator<F>
302where
303    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
304{
305    /// Generate Smolyak sparse grid
306    fn generate_smolyak_grid<Func>(&mut self, func: &Func) -> InterpolateResult<()>
307    where
308        Func: Fn(&[F]) -> F,
309    {
310        // Generate all admissible multi-indices
311        let multi_indices = self.generate_admissible_indices();
312
313        // For each multi-index, generate the corresponding grid points
314        for multi_idx in multi_indices {
315            self.add_hierarchical_points(&multi_idx, func)?;
316        }
317
318        self.stats.num_points = self.grid_points.len();
319        self.stats.max_level_reached = self.max_level;
320
321        Ok(())
322    }
323
324    /// Generate all admissible multi-indices for the current level
325    fn generate_admissible_indices(&self) -> Vec<MultiIndex> {
326        let mut indices = Vec::new();
327
328        // Generate all multi-indices i with |i|_1 <= max_level
329        self.generate_indices_recursive(Vec::new(), 0, self.max_level, &mut indices);
330
331        indices
332    }
333
334    /// Recursively generate multi-indices
335    fn generate_indices_recursive(
336        &self,
337        current: Vec<usize>,
338        dim: usize,
339        remaining_sum: usize,
340        indices: &mut Vec<MultiIndex>,
341    ) {
342        if dim == self.dimension {
343            if current.iter().sum::<usize>() <= self.max_level {
344                indices.push(MultiIndex::new(current));
345            }
346            return;
347        }
348
349        // Try all possible values for the current dimension
350        for i in 0..=remaining_sum {
351            let mut next = current.clone();
352            next.push(i);
353            self.generate_indices_recursive(next, dim + 1, remaining_sum, indices);
354        }
355    }
356
357    /// Add hierarchical points for a given multi-index
358    fn add_hierarchical_points<Func>(
359        &mut self,
360        multi_idx: &MultiIndex,
361        func: &Func,
362    ) -> InterpolateResult<()>
363    where
364        Func: Fn(&[F]) -> F,
365    {
366        // Generate tensor product grid for this multi-index
367        let points = self.generate_tensor_product_points(multi_idx);
368
369        for point_coords in points {
370            let grid_point_idx = self.coords_to_multi_index(&point_coords, multi_idx);
371
372            #[allow(clippy::map_entry)]
373            if !self.grid_points.contains_key(&grid_point_idx) {
374                let value = func(&point_coords);
375                self.stats.num_evaluations += 1;
376
377                // Compute hierarchical surplus
378                let surplus = self.compute_hierarchical_surplus(&point_coords, value, multi_idx)?;
379
380                let grid_point = GridPoint {
381                    coords: point_coords,
382                    index: grid_point_idx.clone(),
383                    surplus,
384                    value,
385                };
386
387                self.grid_points.insert(grid_point_idx, grid_point);
388            }
389        }
390
391        Ok(())
392    }
393
394    /// Generate tensor product points for a multi-index
395    fn generate_tensor_product_points(&self, multiidx: &MultiIndex) -> Vec<Vec<F>> {
396        let mut points = vec![Vec::new()];
397
398        for (dim, &level) in multiidx.indices.iter().enumerate() {
399            let dim_points = self.generate_1d_points(level, dim);
400
401            let mut new_points = Vec::new();
402            for point in &points {
403                for &dim_point in &dim_points {
404                    let mut new_point = point.clone();
405                    new_point.push(dim_point);
406                    new_points.push(new_point);
407                }
408            }
409            points = new_points;
410        }
411
412        points
413    }
414
415    /// Generate 1D points for a given level in a dimension
416    fn generate_1d_points(&self, level: usize, dim: usize) -> Vec<F> {
417        let (min_bound, max_bound) = self.bounds[dim];
418        let range = max_bound - min_bound;
419
420        if level == 0 {
421            // Only the center point
422            vec![min_bound + range / F::from_f64(2.0).expect("Operation failed")]
423        } else {
424            // Hierarchical points: 2^level + 1 points
425            let n_points = (1 << level) + 1;
426            let mut points = Vec::new();
427
428            for i in 0..n_points {
429                let t = F::from_usize(i).expect("Operation failed")
430                    / F::from_usize(n_points - 1).expect("Operation failed");
431                points.push(min_bound + t * range);
432            }
433
434            points
435        }
436    }
437
438    /// Convert coordinates to multi-index representation
439    fn coords_to_multi_index(&self, coords: &[F], baseidx: &MultiIndex) -> MultiIndex {
440        // For simplicity, use a hash-based approach
441        let mut indices = baseidx.indices.clone();
442
443        // Add coordinate-based information to make unique
444        for (i, &coord) in coords.iter().enumerate() {
445            let discretized = (coord * F::from_f64(1000.0).expect("Operation failed"))
446                .round()
447                .to_usize()
448                .unwrap_or(0);
449            indices[i] += discretized % 100; // Keep it reasonable
450        }
451
452        MultiIndex::new(indices)
453    }
454
455    /// Compute hierarchical surplus for a point
456    fn compute_hierarchical_surplus(
457        &self,
458        coords: &[F],
459        value: F,
460        idx: &MultiIndex,
461    ) -> InterpolateResult<F> {
462        // Simplified surplus computation
463        // In a full implementation, this would compute the hierarchical surplus
464        // as the difference between the function value and the interpolated value
465        // from coarser grids
466        Ok(value)
467    }
468
469    /// Build interpolator from scattered data points
470    fn build_from_scattered_data(
471        &mut self,
472        points: &[Vec<F>],
473        values: &[F],
474    ) -> InterpolateResult<()> {
475        // Create grid points from scattered data
476        for (i, (point, &value)) in points.iter().zip(values.iter()).enumerate() {
477            let multi_idx = MultiIndex::new(vec![i; self.dimension]);
478            let grid_point = GridPoint {
479                coords: point.clone(),
480                index: multi_idx.clone(),
481                surplus: value, // Use value as surplus for scattered data
482                value,
483            };
484            self.grid_points.insert(multi_idx, grid_point);
485        }
486
487        self.stats.num_points = self.grid_points.len();
488        self.stats.num_evaluations = points.len();
489
490        Ok(())
491    }
492
493    /// Apply adaptive refinement to the sparse grid
494    fn adaptive_refinement<Func>(&mut self, func: &Func) -> InterpolateResult<()>
495    where
496        Func: Fn(&[F]) -> F,
497    {
498        let max_iterations = 10; // Prevent infinite refinement
499
500        for _iteration in 0..max_iterations {
501            // Find regions with high error
502            let refinement_candidates = self.identify_refinement_candidates()?;
503
504            if refinement_candidates.is_empty() {
505                break; // Convergence achieved
506            }
507
508            // Add new points in high-error regions
509            for candidate in refinement_candidates.iter().take(10) {
510                // Limit per iteration
511                self.refine_around_point(candidate, func)?;
512            }
513
514            // Update statistics
515            self.stats.num_points = self.grid_points.len();
516
517            // Check if error tolerance is met
518            if self.estimate_error()? < self.tolerance {
519                break;
520            }
521        }
522
523        Ok(())
524    }
525
526    /// Identify candidates for refinement based on error indicators
527    fn identify_refinement_candidates(&self) -> InterpolateResult<Vec<MultiIndex>> {
528        let mut candidates = Vec::new();
529
530        // Simple heuristic: look for points with large surplus values
531        for (idx, point) in &self.grid_points {
532            if point.surplus.abs() > self.tolerance {
533                candidates.push(idx.clone());
534            }
535        }
536
537        // Sort by surplus magnitude
538        candidates.sort_by(|a, b| {
539            let surplus_a = self.grid_points[a].surplus.abs();
540            let surplus_b = self.grid_points[b].surplus.abs();
541            surplus_b
542                .partial_cmp(&surplus_a)
543                .unwrap_or(std::cmp::Ordering::Equal)
544        });
545
546        Ok(candidates)
547    }
548
549    /// Refine the grid around a specific point
550    fn refine_around_point<Func>(
551        &mut self,
552        center_idx: &MultiIndex,
553        func: &Func,
554    ) -> InterpolateResult<()>
555    where
556        Func: Fn(&[F]) -> F,
557    {
558        if let Some(center_point) = self.grid_points.get(center_idx) {
559            let center_coords = center_point.coords.clone();
560
561            // Add neighbor points around the center
562            for dim in 0..self.dimension {
563                for direction in [-1.0, 1.0] {
564                    let mut new_coords = center_coords.clone();
565                    let step = (self.bounds[dim].1 - self.bounds[dim].0)
566                        / F::from_f64(32.0).expect("Operation failed");
567                    new_coords[dim] += F::from_f64(direction).expect("Operation failed") * step;
568
569                    // Check bounds
570                    if new_coords[dim] >= self.bounds[dim].0
571                        && new_coords[dim] <= self.bounds[dim].1
572                    {
573                        let new_idx = self.coords_to_multi_index(&new_coords, center_idx);
574
575                        #[allow(clippy::map_entry)]
576                        if !self.grid_points.contains_key(&new_idx) {
577                            let value = func(&new_coords);
578                            self.stats.num_evaluations += 1;
579
580                            let surplus =
581                                self.compute_hierarchical_surplus(&new_coords, value, &new_idx)?;
582
583                            let grid_point = GridPoint {
584                                coords: new_coords,
585                                index: new_idx.clone(),
586                                surplus,
587                                value,
588                            };
589
590                            self.grid_points.insert(new_idx, grid_point);
591                        }
592                    }
593                }
594            }
595        }
596
597        Ok(())
598    }
599
600    /// Estimate the current interpolation error
601    fn estimate_error(&self) -> InterpolateResult<F> {
602        // Simple error estimate based on surplus magnitudes
603        let max_surplus = self
604            .grid_points
605            .values()
606            .map(|p| p.surplus.abs())
607            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
608            .unwrap_or(F::zero());
609
610        Ok(max_surplus)
611    }
612
613    /// Interpolate at a query point
614    pub fn interpolate(&self, query: &[F]) -> InterpolateResult<F> {
615        if query.len() != self.dimension {
616            return Err(InterpolateError::invalid_input(
617                "Query point dimension mismatch".to_string(),
618            ));
619        }
620
621        // Check bounds
622        for (i, &coord) in query.iter().enumerate() {
623            if coord < self.bounds[i].0 || coord > self.bounds[i].1 {
624                return Err(InterpolateError::OutOfBounds(
625                    "Query point outside interpolation domain".to_string(),
626                ));
627            }
628        }
629
630        // Compute interpolated value using hierarchical surpluses
631        let mut result = F::zero();
632
633        for point in self.grid_points.values() {
634            let weight = self.compute_hierarchical_weight(query, &point.coords);
635            result += weight * point.surplus;
636        }
637
638        Ok(result)
639    }
640
641    /// Compute hierarchical weight for interpolation
642    fn compute_hierarchical_weight(&self, query: &[F], gridpoint: &[F]) -> F {
643        let mut weight = F::one();
644
645        for i in 0..self.dimension {
646            // Adaptive grid spacing based on level and dimension
647            let level_spacing =
648                F::from_f64(2.0_f64.powi(-(self.max_level as i32))).expect("Operation failed");
649            let h = (self.bounds[i].1 - self.bounds[i].0) * level_spacing;
650            let dist = (query[i] - gridpoint[i]).abs();
651
652            if dist <= h {
653                weight *= F::one() - dist / h;
654            } else {
655                // Use a broader support for sparse grids
656                let broad_h = h * F::from_f64(4.0).expect("Operation failed");
657                if dist <= broad_h {
658                    weight *=
659                        F::from_f64(0.25).expect("Operation failed") * (F::one() - dist / broad_h);
660                } else {
661                    return F::zero(); // Outside support
662                }
663            }
664        }
665
666        weight
667    }
668
669    /// Interpolate at multiple query points
670    pub fn interpolate_multi(&self, queries: &[Vec<F>]) -> InterpolateResult<Vec<F>> {
671        queries.iter().map(|q| self.interpolate(q)).collect()
672    }
673
674    /// Get the number of grid points
675    pub fn num_points(&self) -> usize {
676        self.stats.num_points
677    }
678
679    /// Get the number of function evaluations performed
680    pub fn num_evaluations(&self) -> usize {
681        self.stats.num_evaluations
682    }
683
684    /// Get interpolator statistics
685    pub fn stats(&self) -> &SparseGridStats {
686        &self.stats
687    }
688
689    /// Get the dimensionality
690    pub fn dimension(&self) -> usize {
691        self.dimension
692    }
693
694    /// Get the bounds
695    pub fn bounds(&self) -> &[(F, F)] {
696        &self.bounds
697    }
698}
699
700/// Create a sparse grid interpolator with default settings
701#[allow(dead_code)]
702pub fn make_sparse_grid_interpolator<F, Func>(
703    bounds: Vec<(F, F)>,
704    max_level: usize,
705    func: Func,
706) -> InterpolateResult<SparseGridInterpolator<F>>
707where
708    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
709    Func: Fn(&[F]) -> F,
710{
711    SparseGridBuilder::new()
712        .with_bounds(bounds)
713        .with_max_level(max_level)
714        .build(func)
715}
716
717/// Create an adaptive sparse grid interpolator
718#[allow(dead_code)]
719pub fn make_adaptive_sparse_grid_interpolator<F, Func>(
720    bounds: Vec<(F, F)>,
721    max_level: usize,
722    tolerance: F,
723    func: Func,
724) -> InterpolateResult<SparseGridInterpolator<F>>
725where
726    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
727    Func: Fn(&[F]) -> F,
728{
729    SparseGridBuilder::new()
730        .with_bounds(bounds)
731        .with_max_level(max_level)
732        .with_adaptive_refinement(true)
733        .with_tolerance(tolerance)
734        .build(func)
735}
736
737/// Create a sparse grid interpolator from scattered data
738#[allow(dead_code)]
739pub fn make_sparse_grid_from_data<F>(
740    bounds: Vec<(F, F)>,
741    points: &[Vec<F>],
742    values: &[F],
743) -> InterpolateResult<SparseGridInterpolator<F>>
744where
745    F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
746{
747    SparseGridBuilder::new()
748        .with_bounds(bounds)
749        .build_from_data(points, values)
750}
751
// Unit tests for the sparse grid module. Tolerances on interpolated values are
// deliberately loose because the interpolant is an approximation (see the
// assertions' own comments).
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_multi_index() {
        let idx = MultiIndex::new(vec![1, 2, 3]);
        assert_eq!(idx.l1_norm(), 6);
        assert_eq!(idx.linf_norm(), 3);
        assert_eq!(idx.dim(), 3);
        // `is_admissible` checks `|i|_1 <= max_level` (levels are 0-based; the
        // dimension argument does not enter the bound).
        assert!(idx.is_admissible(8, 3)); // |i|_1 = 6 <= max_level = 8
        assert!(!idx.is_admissible(5, 3)); // |i|_1 = 6 > max_level = 5
    }

    #[test]
    fn test_sparse_grid_1d() {
        // Test 1D interpolation (should reduce to regular grid)
        let bounds = vec![(0.0, 1.0)];
        let interpolator = make_sparse_grid_interpolator(
            bounds,
            3,
            |x: &[f64]| x[0] * x[0], // f(x) = x^2
        )
        .expect("Operation failed");

        // Only a sanity range is checked, not the exact value 0.25.
        let result = interpolator.interpolate(&[0.5]).expect("Operation failed");
        assert!((0.0..=1.0).contains(&result));
        assert!(interpolator.num_points() > 0);
    }

    #[test]
    fn test_sparse_grid_2d() {
        // Test 2D interpolation
        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
        let interpolator = make_sparse_grid_interpolator(
            bounds,
            2,
            |x: &[f64]| x[0] + x[1], // f(x,y) = x + y
        )
        .expect("Operation failed");

        // Test interpolation at center
        let result = interpolator
            .interpolate(&[0.5, 0.5])
            .expect("Operation failed");
        assert_relative_eq!(result, 1.0, epsilon = 0.5); // Should be close to 0.5 + 0.5 = 1.0

        // Check grid efficiency
        let num_points = interpolator.num_points();
        assert!(num_points > 0);
        assert!(num_points < 100); // Should be much less than full tensor grid
    }

    #[test]
    fn test_adaptive_sparse_grid() {
        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
        let interpolator = make_adaptive_sparse_grid_interpolator(
            bounds,
            3,
            1e-3,
            |x: &[f64]| (x[0] - 0.5).powi(2) + (x[1] - 0.5).powi(2), // Minimum (0) at center
        )
        .expect("Operation failed");

        // Test interpolation
        let result = interpolator
            .interpolate(&[0.5, 0.5])
            .expect("Operation failed");
        assert_relative_eq!(result, 0.0, epsilon = 0.1);

        let result_corner = interpolator
            .interpolate(&[0.0, 0.0])
            .expect("Operation failed");
        // Sparse grid approximation may differ significantly from expected value
        assert_relative_eq!(result_corner, 0.5, epsilon = 8.0);
    }

    #[test]
    fn test_high_dimensional_sparse_grid() {
        // Test that sparse grid scales to higher dimensions
        let bounds = vec![(0.0, 1.0); 5]; // 5D unit hypercube
        let interpolator = make_sparse_grid_interpolator(
            bounds,
            2,
            |x: &[f64]| x.iter().sum::<f64>(), // f(x) = x1 + x2 + ... + x5
        )
        .expect("Operation failed");

        // Test interpolation
        let query = vec![0.2; 5];
        let result = interpolator.interpolate(&query).expect("Operation failed");
        // High-dimensional sparse grid may have significant approximation error
        assert_relative_eq!(result, 1.0, epsilon = 1.0); // Should be close to 5 * 0.2 = 1.0

        // Verify grid is sparse
        let num_points = interpolator.num_points();
        assert!(num_points > 0);
        assert!(num_points < 1000); // Much less than 2^(5*level) full grid
    }

    #[test]
    fn test_sparse_grid_from_data() {
        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
        let points = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];
        let values = vec![0.0, 1.0, 1.0, 2.0, 1.0];

        let interpolator =
            make_sparse_grid_from_data(bounds, &points, &values).expect("Operation failed");

        // Test interpolation at data points
        for (point, &expected) in points.iter().zip(values.iter()) {
            let result = interpolator.interpolate(point).expect("Operation failed");
            assert_relative_eq!(result, expected, epsilon = 0.1);
        }
    }

    #[test]
    fn test_multi_interpolation() {
        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
        let interpolator = make_sparse_grid_interpolator(
            bounds,
            2,
            |x: &[f64]| x[0] * x[1], // f(x,y) = x * y
        )
        .expect("Operation failed");

        let queries = vec![
            vec![0.25, 0.25],
            vec![0.75, 0.25],
            vec![0.25, 0.75],
            vec![0.75, 0.75],
        ];

        let results = interpolator
            .interpolate_multi(&queries)
            .expect("Operation failed");
        assert_eq!(results.len(), 4);

        // Check that results are reasonable
        for result in results {
            assert!((0.0..=1.0).contains(&result));
        }
    }

    #[test]
    fn test_builder_pattern() {
        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];

        let interpolator = SparseGridBuilder::new()
            .with_bounds(bounds)
            .with_max_level(2)
            .with_adaptive_refinement(false)
            .with_tolerance(1e-4)
            .build(|x: &[f64]| x[0] + x[1])
            .expect("Operation failed");

        assert_eq!(interpolator.dimension(), 2);
        assert!(interpolator.num_points() > 0);
    }

    #[test]
    fn test_error_handling() {
        // Test dimension mismatch
        let bounds = vec![(0.0, 1.0), (0.0, 1.0)];
        let interpolator = make_sparse_grid_interpolator(bounds, 2, |x: &[f64]| x[0] + x[1])
            .expect("Operation failed");

        // Query with wrong dimension
        let result = interpolator.interpolate(&[0.5]);
        assert!(result.is_err());

        // Query outside bounds
        let result = interpolator.interpolate(&[1.5, 0.5]);
        assert!(result.is_err());
    }
}