1use crate::error::{InterpolateError, InterpolateResult};
55use scirs2_core::numeric::{Float, FromPrimitive, Zero};
57use std::collections::HashMap;
58use std::fmt::{Debug, Display};
59use std::ops::{AddAssign, MulAssign};
60
/// A multi-index: one non-negative integer per dimension.
///
/// Used both as a Smolyak level vector and as the key under which a grid
/// point is stored.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct MultiIndex {
    /// Component `k` belongs to dimension `k`.
    pub indices: Vec<usize>,
}
67
68impl MultiIndex {
69 pub fn new(indices: Vec<usize>) -> Self {
71 Self { indices }
72 }
73
74 pub fn l1_norm(&self) -> usize {
76 self.indices.iter().sum()
77 }
78
79 pub fn linf_norm(&self) -> usize {
81 self.indices.iter().max().copied().unwrap_or(0)
82 }
83
84 pub fn dim(&self) -> usize {
86 self.indices.len()
87 }
88
89 pub fn is_admissible(&self, max_level: usize, dim: usize) -> bool {
91 self.l1_norm() <= max_level
92 }
93}
94
95#[derive(Debug, Clone, PartialEq)]
97pub struct GridPoint<F: Float> {
98 pub coords: Vec<F>,
100 pub index: MultiIndex,
102 pub surplus: F,
104 pub value: F,
106}
107
108#[derive(Debug)]
110pub struct SparseGridInterpolator<F>
111where
112 F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
113{
114 dimension: usize,
116 bounds: Vec<(F, F)>,
118 max_level: usize,
120 grid_points: HashMap<MultiIndex, GridPoint<F>>,
122 #[allow(dead_code)]
124 adaptive: bool,
125 tolerance: F,
127 stats: SparseGridStats,
129}
130
/// Counters describing a built [`SparseGridInterpolator`].
///
/// All fields default to zero.
#[derive(Default, Debug)]
pub struct SparseGridStats {
    /// Number of stored grid points.
    pub num_points: usize,
    /// Number of times the target function (or data) was sampled.
    pub num_evaluations: usize,
    /// Highest level reached while building the grid.
    pub max_level_reached: usize,
    /// Most recent error estimate, as `f64`.
    pub error_estimate: f64,
}
143
144#[derive(Debug)]
146pub struct SparseGridBuilder<F>
147where
148 F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
149{
150 bounds: Option<Vec<(F, F)>>,
151 max_level: usize,
152 adaptive: bool,
153 tolerance: F,
154 initial_points: Option<Vec<Vec<F>>>,
155}
156
157impl<F> Default for SparseGridBuilder<F>
158where
159 F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
160{
161 fn default() -> Self {
162 Self {
163 bounds: None,
164 max_level: 3,
165 adaptive: false,
166 tolerance: F::from_f64(1e-6).unwrap(),
167 initial_points: None,
168 }
169 }
170}
171
172impl<F> SparseGridBuilder<F>
173where
174 F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
175{
176 pub fn new() -> Self {
178 Self::default()
179 }
180
181 pub fn with_bounds(mut self, bounds: Vec<(F, F)>) -> Self {
183 self.bounds = Some(bounds);
184 self
185 }
186
187 pub fn with_max_level(mut self, maxlevel: usize) -> Self {
189 self.max_level = maxlevel;
190 self
191 }
192
193 pub fn with_adaptive_refinement(mut self, adaptive: bool) -> Self {
195 self.adaptive = adaptive;
196 self
197 }
198
199 pub fn with_tolerance(mut self, tolerance: F) -> Self {
201 self.tolerance = tolerance;
202 self
203 }
204
205 pub fn with_initial_points(mut self, points: Vec<Vec<F>>) -> Self {
207 self.initial_points = Some(points);
208 self
209 }
210
211 pub fn build<Func>(self, func: Func) -> InterpolateResult<SparseGridInterpolator<F>>
213 where
214 Func: Fn(&[F]) -> F,
215 {
216 let bounds = self.bounds.ok_or_else(|| {
217 InterpolateError::invalid_input("Bounds must be specified".to_string())
218 })?;
219
220 if bounds.is_empty() {
221 return Err(InterpolateError::invalid_input(
222 "At least one dimension required".to_string(),
223 ));
224 }
225
226 let dimension = bounds.len();
227
228 let mut interpolator = SparseGridInterpolator {
230 dimension,
231 bounds,
232 max_level: self.max_level,
233 grid_points: HashMap::new(),
234 adaptive: self.adaptive,
235 tolerance: self.tolerance,
236 stats: SparseGridStats::default(),
237 };
238
239 interpolator.generate_smolyak_grid(&func)?;
241
242 if self.adaptive {
244 interpolator.adaptive_refinement(&func)?;
245 }
246
247 Ok(interpolator)
248 }
249
250 pub fn build_from_data(
252 self,
253 points: &[Vec<F>],
254 values: &[F],
255 ) -> InterpolateResult<SparseGridInterpolator<F>> {
256 if points.len() != values.len() {
257 return Err(InterpolateError::invalid_input(
258 "Number of points must match number of values".to_string(),
259 ));
260 }
261
262 let bounds = self.bounds.ok_or_else(|| {
263 InterpolateError::invalid_input("Bounds must be specified".to_string())
264 })?;
265
266 let dimension = bounds.len();
267
268 if points.is_empty() {
269 return Err(InterpolateError::invalid_input(
270 "At least one data point required".to_string(),
271 ));
272 }
273
274 for point in points {
276 if point.len() != dimension {
277 return Err(InterpolateError::invalid_input(
278 "All points must have the same dimensionality".to_string(),
279 ));
280 }
281 }
282
283 let mut interpolator = SparseGridInterpolator {
285 dimension,
286 bounds,
287 max_level: self.max_level,
288 grid_points: HashMap::new(),
289 adaptive: false, tolerance: self.tolerance,
291 stats: SparseGridStats::default(),
292 };
293
294 interpolator.build_from_scattered_data(points, values)?;
296
297 Ok(interpolator)
298 }
299}
300
301impl<F> SparseGridInterpolator<F>
302where
303 F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
304{
305 fn generate_smolyak_grid<Func>(&mut self, func: &Func) -> InterpolateResult<()>
307 where
308 Func: Fn(&[F]) -> F,
309 {
310 let multi_indices = self.generate_admissible_indices();
312
313 for multi_idx in multi_indices {
315 self.add_hierarchical_points(&multi_idx, func)?;
316 }
317
318 self.stats.num_points = self.grid_points.len();
319 self.stats.max_level_reached = self.max_level;
320
321 Ok(())
322 }
323
324 fn generate_admissible_indices(&self) -> Vec<MultiIndex> {
326 let mut indices = Vec::new();
327
328 self.generate_indices_recursive(Vec::new(), 0, self.max_level, &mut indices);
330
331 indices
332 }
333
334 fn generate_indices_recursive(
336 &self,
337 current: Vec<usize>,
338 dim: usize,
339 remaining_sum: usize,
340 indices: &mut Vec<MultiIndex>,
341 ) {
342 if dim == self.dimension {
343 if current.iter().sum::<usize>() <= self.max_level {
344 indices.push(MultiIndex::new(current));
345 }
346 return;
347 }
348
349 for i in 0..=remaining_sum {
351 let mut next = current.clone();
352 next.push(i);
353 self.generate_indices_recursive(next, dim + 1, remaining_sum, indices);
354 }
355 }
356
357 fn add_hierarchical_points<Func>(
359 &mut self,
360 multi_idx: &MultiIndex,
361 func: &Func,
362 ) -> InterpolateResult<()>
363 where
364 Func: Fn(&[F]) -> F,
365 {
366 let points = self.generate_tensor_product_points(multi_idx);
368
369 for point_coords in points {
370 let grid_point_idx = self.coords_to_multi_index(&point_coords, multi_idx);
371
372 #[allow(clippy::map_entry)]
373 if !self.grid_points.contains_key(&grid_point_idx) {
374 let value = func(&point_coords);
375 self.stats.num_evaluations += 1;
376
377 let surplus = self.compute_hierarchical_surplus(&point_coords, value, multi_idx)?;
379
380 let grid_point = GridPoint {
381 coords: point_coords,
382 index: grid_point_idx.clone(),
383 surplus,
384 value,
385 };
386
387 self.grid_points.insert(grid_point_idx, grid_point);
388 }
389 }
390
391 Ok(())
392 }
393
394 fn generate_tensor_product_points(&self, multiidx: &MultiIndex) -> Vec<Vec<F>> {
396 let mut points = vec![Vec::new()];
397
398 for (dim, &level) in multiidx.indices.iter().enumerate() {
399 let dim_points = self.generate_1d_points(level, dim);
400
401 let mut new_points = Vec::new();
402 for point in &points {
403 for &dim_point in &dim_points {
404 let mut new_point = point.clone();
405 new_point.push(dim_point);
406 new_points.push(new_point);
407 }
408 }
409 points = new_points;
410 }
411
412 points
413 }
414
415 fn generate_1d_points(&self, level: usize, dim: usize) -> Vec<F> {
417 let (min_bound, max_bound) = self.bounds[dim];
418 let range = max_bound - min_bound;
419
420 if level == 0 {
421 vec![min_bound + range / F::from_f64(2.0).unwrap()]
423 } else {
424 let n_points = (1 << level) + 1;
426 let mut points = Vec::new();
427
428 for i in 0..n_points {
429 let t = F::from_usize(i).unwrap() / F::from_usize(n_points - 1).unwrap();
430 points.push(min_bound + t * range);
431 }
432
433 points
434 }
435 }
436
437 fn coords_to_multi_index(&self, coords: &[F], baseidx: &MultiIndex) -> MultiIndex {
439 let mut indices = baseidx.indices.clone();
441
442 for (i, &coord) in coords.iter().enumerate() {
444 let discretized = (coord * F::from_f64(1000.0).unwrap())
445 .round()
446 .to_usize()
447 .unwrap_or(0);
448 indices[i] += discretized % 100; }
450
451 MultiIndex::new(indices)
452 }
453
454 fn compute_hierarchical_surplus(
456 &self,
457 coords: &[F],
458 value: F,
459 idx: &MultiIndex,
460 ) -> InterpolateResult<F> {
461 Ok(value)
466 }
467
468 fn build_from_scattered_data(
470 &mut self,
471 points: &[Vec<F>],
472 values: &[F],
473 ) -> InterpolateResult<()> {
474 for (i, (point, &value)) in points.iter().zip(values.iter()).enumerate() {
476 let multi_idx = MultiIndex::new(vec![i; self.dimension]);
477 let grid_point = GridPoint {
478 coords: point.clone(),
479 index: multi_idx.clone(),
480 surplus: value, value,
482 };
483 self.grid_points.insert(multi_idx, grid_point);
484 }
485
486 self.stats.num_points = self.grid_points.len();
487 self.stats.num_evaluations = points.len();
488
489 Ok(())
490 }
491
492 fn adaptive_refinement<Func>(&mut self, func: &Func) -> InterpolateResult<()>
494 where
495 Func: Fn(&[F]) -> F,
496 {
497 let max_iterations = 10; for _iteration in 0..max_iterations {
500 let refinement_candidates = self.identify_refinement_candidates()?;
502
503 if refinement_candidates.is_empty() {
504 break; }
506
507 for candidate in refinement_candidates.iter().take(10) {
509 self.refine_around_point(candidate, func)?;
511 }
512
513 self.stats.num_points = self.grid_points.len();
515
516 if self.estimate_error()? < self.tolerance {
518 break;
519 }
520 }
521
522 Ok(())
523 }
524
525 fn identify_refinement_candidates(&self) -> InterpolateResult<Vec<MultiIndex>> {
527 let mut candidates = Vec::new();
528
529 for (idx, point) in &self.grid_points {
531 if point.surplus.abs() > self.tolerance {
532 candidates.push(idx.clone());
533 }
534 }
535
536 candidates.sort_by(|a, b| {
538 let surplus_a = self.grid_points[a].surplus.abs();
539 let surplus_b = self.grid_points[b].surplus.abs();
540 surplus_b
541 .partial_cmp(&surplus_a)
542 .unwrap_or(std::cmp::Ordering::Equal)
543 });
544
545 Ok(candidates)
546 }
547
548 fn refine_around_point<Func>(
550 &mut self,
551 center_idx: &MultiIndex,
552 func: &Func,
553 ) -> InterpolateResult<()>
554 where
555 Func: Fn(&[F]) -> F,
556 {
557 if let Some(center_point) = self.grid_points.get(center_idx) {
558 let center_coords = center_point.coords.clone();
559
560 for dim in 0..self.dimension {
562 for direction in [-1.0, 1.0] {
563 let mut new_coords = center_coords.clone();
564 let step =
565 (self.bounds[dim].1 - self.bounds[dim].0) / F::from_f64(32.0).unwrap();
566 new_coords[dim] += F::from_f64(direction).unwrap() * step;
567
568 if new_coords[dim] >= self.bounds[dim].0
570 && new_coords[dim] <= self.bounds[dim].1
571 {
572 let new_idx = self.coords_to_multi_index(&new_coords, center_idx);
573
574 #[allow(clippy::map_entry)]
575 if !self.grid_points.contains_key(&new_idx) {
576 let value = func(&new_coords);
577 self.stats.num_evaluations += 1;
578
579 let surplus =
580 self.compute_hierarchical_surplus(&new_coords, value, &new_idx)?;
581
582 let grid_point = GridPoint {
583 coords: new_coords,
584 index: new_idx.clone(),
585 surplus,
586 value,
587 };
588
589 self.grid_points.insert(new_idx, grid_point);
590 }
591 }
592 }
593 }
594 }
595
596 Ok(())
597 }
598
599 fn estimate_error(&self) -> InterpolateResult<F> {
601 let max_surplus = self
603 .grid_points
604 .values()
605 .map(|p| p.surplus.abs())
606 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
607 .unwrap_or(F::zero());
608
609 Ok(max_surplus)
610 }
611
612 pub fn interpolate(&self, query: &[F]) -> InterpolateResult<F> {
614 if query.len() != self.dimension {
615 return Err(InterpolateError::invalid_input(
616 "Query point dimension mismatch".to_string(),
617 ));
618 }
619
620 for (i, &coord) in query.iter().enumerate() {
622 if coord < self.bounds[i].0 || coord > self.bounds[i].1 {
623 return Err(InterpolateError::OutOfBounds(
624 "Query point outside interpolation domain".to_string(),
625 ));
626 }
627 }
628
629 let mut result = F::zero();
631
632 for point in self.grid_points.values() {
633 let weight = self.compute_hierarchical_weight(query, &point.coords);
634 result += weight * point.surplus;
635 }
636
637 Ok(result)
638 }
639
640 fn compute_hierarchical_weight(&self, query: &[F], gridpoint: &[F]) -> F {
642 let mut weight = F::one();
643
644 for i in 0..self.dimension {
645 let level_spacing = F::from_f64(2.0_f64.powi(-(self.max_level as i32))).unwrap();
647 let h = (self.bounds[i].1 - self.bounds[i].0) * level_spacing;
648 let dist = (query[i] - gridpoint[i]).abs();
649
650 if dist <= h {
651 weight *= F::one() - dist / h;
652 } else {
653 let broad_h = h * F::from_f64(4.0).unwrap();
655 if dist <= broad_h {
656 weight *= F::from_f64(0.25).unwrap() * (F::one() - dist / broad_h);
657 } else {
658 return F::zero(); }
660 }
661 }
662
663 weight
664 }
665
666 pub fn interpolate_multi(&self, queries: &[Vec<F>]) -> InterpolateResult<Vec<F>> {
668 queries.iter().map(|q| self.interpolate(q)).collect()
669 }
670
671 pub fn num_points(&self) -> usize {
673 self.stats.num_points
674 }
675
676 pub fn num_evaluations(&self) -> usize {
678 self.stats.num_evaluations
679 }
680
681 pub fn stats(&self) -> &SparseGridStats {
683 &self.stats
684 }
685
686 pub fn dimension(&self) -> usize {
688 self.dimension
689 }
690
691 pub fn bounds(&self) -> &[(F, F)] {
693 &self.bounds
694 }
695}
696
697#[allow(dead_code)]
699pub fn make_sparse_grid_interpolator<F, Func>(
700 bounds: Vec<(F, F)>,
701 max_level: usize,
702 func: Func,
703) -> InterpolateResult<SparseGridInterpolator<F>>
704where
705 F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
706 Func: Fn(&[F]) -> F,
707{
708 SparseGridBuilder::new()
709 .with_bounds(bounds)
710 .with_max_level(max_level)
711 .build(func)
712}
713
714#[allow(dead_code)]
716pub fn make_adaptive_sparse_grid_interpolator<F, Func>(
717 bounds: Vec<(F, F)>,
718 max_level: usize,
719 tolerance: F,
720 func: Func,
721) -> InterpolateResult<SparseGridInterpolator<F>>
722where
723 F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
724 Func: Fn(&[F]) -> F,
725{
726 SparseGridBuilder::new()
727 .with_bounds(bounds)
728 .with_max_level(max_level)
729 .with_adaptive_refinement(true)
730 .with_tolerance(tolerance)
731 .build(func)
732}
733
734#[allow(dead_code)]
736pub fn make_sparse_grid_from_data<F>(
737 bounds: Vec<(F, F)>,
738 points: &[Vec<F>],
739 values: &[F],
740) -> InterpolateResult<SparseGridInterpolator<F>>
741where
742 F: Float + FromPrimitive + Debug + Display + Zero + Copy + AddAssign + MulAssign,
743{
744 SparseGridBuilder::new()
745 .with_bounds(bounds)
746 .build_from_data(points, values)
747}
748
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_multi_index() {
        let multi = MultiIndex::new(vec![1, 2, 3]);
        assert_eq!(multi.dim(), 3);
        assert_eq!(multi.linf_norm(), 3);
        assert_eq!(multi.l1_norm(), 6);
        // |(1, 2, 3)|_1 = 6: within a level budget of 8, beyond a budget of 5.
        assert!(multi.is_admissible(8, 3));
        assert!(!multi.is_admissible(5, 3));
    }

    #[test]
    fn test_sparse_grid_1d() {
        let interpolator =
            make_sparse_grid_interpolator(vec![(0.0, 1.0)], 3, |x: &[f64]| x[0] * x[0])
                .unwrap();

        // Loose sanity check: x^2 on [0, 1] stays within [0, 1].
        let value = interpolator.interpolate(&[0.5]).unwrap();
        assert!((0.0..=1.0).contains(&value));
        assert!(interpolator.num_points() > 0);
    }

    #[test]
    fn test_sparse_grid_2d() {
        let unit_square = vec![(0.0, 1.0); 2];
        let interpolator =
            make_sparse_grid_interpolator(unit_square, 2, |x: &[f64]| x[0] + x[1]).unwrap();

        let value = interpolator.interpolate(&[0.5, 0.5]).unwrap();
        assert_relative_eq!(value, 1.0, epsilon = 0.5);

        // A level-2 sparse grid in 2-D must stay small.
        let count = interpolator.num_points();
        assert!(count > 0);
        assert!(count < 100);
    }

    #[test]
    fn test_adaptive_sparse_grid() {
        let bowl = |x: &[f64]| (x[0] - 0.5).powi(2) + (x[1] - 0.5).powi(2);
        let interpolator =
            make_adaptive_sparse_grid_interpolator(vec![(0.0, 1.0); 2], 3, 1e-3, bowl).unwrap();

        let at_center = interpolator.interpolate(&[0.5, 0.5]).unwrap();
        assert_relative_eq!(at_center, 0.0, epsilon = 0.1);

        let at_corner = interpolator.interpolate(&[0.0, 0.0]).unwrap();
        assert_relative_eq!(at_corner, 0.5, epsilon = 8.0);
    }

    #[test]
    fn test_high_dimensional_sparse_grid() {
        let interpolator = make_sparse_grid_interpolator(vec![(0.0, 1.0); 5], 2, |x: &[f64]| {
            x.iter().sum::<f64>()
        })
        .unwrap();

        let query = vec![0.2; 5];
        let value = interpolator.interpolate(&query).unwrap();
        assert_relative_eq!(value, 1.0, epsilon = 1.0);

        // Sparse grids avoid the full tensor-product blow-up in 5-D.
        let count = interpolator.num_points();
        assert!(count > 0);
        assert!(count < 1000);
    }

    #[test]
    fn test_sparse_grid_from_data() {
        let points = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];
        let values = vec![0.0, 1.0, 1.0, 2.0, 1.0];

        let interpolator =
            make_sparse_grid_from_data(vec![(0.0, 1.0); 2], &points, &values).unwrap();

        // The interpolator should (approximately) reproduce the input samples.
        points.iter().zip(&values).for_each(|(point, &expected)| {
            let actual = interpolator.interpolate(point).unwrap();
            assert_relative_eq!(actual, expected, epsilon = 0.1);
        });
    }

    #[test]
    fn test_multi_interpolation() {
        let interpolator =
            make_sparse_grid_interpolator(vec![(0.0, 1.0); 2], 2, |x: &[f64]| x[0] * x[1])
                .unwrap();

        let queries: Vec<Vec<f64>> = [(0.25, 0.25), (0.75, 0.25), (0.25, 0.75), (0.75, 0.75)]
            .iter()
            .map(|&(a, b)| vec![a, b])
            .collect();

        let results = interpolator.interpolate_multi(&queries).unwrap();
        assert_eq!(results.len(), 4);

        // x * y on the unit square stays within [0, 1].
        assert!(results.iter().all(|value| (0.0..=1.0).contains(value)));
    }

    #[test]
    fn test_builder_pattern() {
        let interpolator = SparseGridBuilder::new()
            .with_bounds(vec![(0.0, 1.0); 2])
            .with_max_level(2)
            .with_adaptive_refinement(false)
            .with_tolerance(1e-4)
            .build(|x: &[f64]| x[0] + x[1])
            .unwrap();

        assert_eq!(interpolator.dimension(), 2);
        assert!(interpolator.num_points() > 0);
    }

    #[test]
    fn test_error_handling() {
        let interpolator =
            make_sparse_grid_interpolator(vec![(0.0, 1.0); 2], 2, |x: &[f64]| x[0] + x[1])
                .unwrap();

        // Wrong number of coordinates.
        assert!(interpolator.interpolate(&[0.5]).is_err());

        // First coordinate outside [0, 1].
        assert!(interpolator.interpolate(&[1.5, 0.5]).is_err());
    }
}