// Source file: scirs2_interpolate/resampling/grid_spec.rs

//! N-dimensional grid specification and scattered-to-grid resampling.
//!
//! Provides `GridSpec`, `Aggregator`, `ResampleStrategy`, and
//! `resample_scattered_to_grid` for mapping scattered (x, y) data onto a
//! regular N-D grid, returning an `ArrayD<f64>`.
//!
//! # Design
//!
//! - `Rasterize(Aggregator)` bins each scattered point into its nearest grid
//!   cell and aggregates the values collected per cell; empty cells receive
//!   `f64::NAN` (`0.0` for `Aggregator::Count`).
//! - `Conservative` is an axis-aligned approximation of area-weighted
//!   resampling: it currently yields the per-cell mean, exactly like
//!   `Rasterize(Aggregator::Mean)`, except that empty cells receive `0.0`
//!   instead of `f64::NAN`.
//!
//! # Example
//!
//! ```rust
//! use scirs2_interpolate::resampling::{GridSpec, Aggregator, ResampleStrategy, resample_scattered_to_grid};
//! use scirs2_core::ndarray::Array2;
//!
//! // 4 scattered points in 2-D, values = x + y
//! let pts = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap();
//! let vals = scirs2_core::ndarray::Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);
//! let grid = GridSpec::uniform(2, &[(0.0, 1.0, 2), (0.0, 1.0, 2)]);
//! let out = resample_scattered_to_grid(&pts, &vals, &grid, ResampleStrategy::Rasterize(Aggregator::Mean))
//!     .expect("resample");
//! assert_eq!(out.shape(), &[2, 2]);
//! ```
28
29use scirs2_core::ndarray::{Array1, Array2, ArrayD, IxDyn};
30
31use crate::error::{InterpolateError, InterpolateResult};
32
33// ─────────────────────────────────────────────────────────────────────────────
34// GridSpec
35// ─────────────────────────────────────────────────────────────────────────────
36
37/// Description of a regular N-dimensional target grid.
38///
39/// Each axis is defined by an `Array1<f64>` of strictly increasing
40/// grid-cell centre coordinates.
41#[derive(Debug, Clone)]
42pub struct GridSpec {
43    /// Grid-cell centre coordinates for each axis.
44    pub axes: Vec<Array1<f64>>,
45}
46
47impl GridSpec {
48    /// Create a `GridSpec` from per-axis arrays.
49    ///
50    /// Each `axis` slice must contain at least 1 value and be strictly
51    /// increasing.
52    pub fn new(axes: Vec<Array1<f64>>) -> InterpolateResult<Self> {
53        for (dim, ax) in axes.iter().enumerate() {
54            if ax.is_empty() {
55                return Err(InterpolateError::InvalidInput {
56                    message: format!("axis {dim} is empty"),
57                });
58            }
59            for i in 1..ax.len() {
60                if ax[i] <= ax[i - 1] {
61                    return Err(InterpolateError::InvalidInput {
62                        message: format!("axis {dim} is not strictly increasing"),
63                    });
64                }
65            }
66        }
67        Ok(Self { axes })
68    }
69
70    /// Build a uniform grid from `(min, max, n_cells)` per dimension.
71    ///
72    /// This is a convenience constructor; panics only if `n_cells == 0`.
73    pub fn uniform(dim: usize, specs: &[(f64, f64, usize)]) -> Self {
74        assert_eq!(specs.len(), dim, "specs length must equal dim");
75        let axes: Vec<Array1<f64>> = specs
76            .iter()
77            .map(|&(lo, hi, n)| {
78                let n_pts = n.max(1);
79                if n_pts == 1 {
80                    Array1::from_vec(vec![lo])
81                } else {
82                    let step = (hi - lo) / (n_pts as f64 - 1.0);
83                    Array1::from_iter((0..n_pts).map(|i| lo + i as f64 * step))
84                }
85            })
86            .collect();
87        Self { axes }
88    }
89
90    /// Dimensionality of the grid.
91    pub fn ndim(&self) -> usize {
92        self.axes.len()
93    }
94
95    /// Shape of the output array (one entry per axis).
96    pub fn shape(&self) -> Vec<usize> {
97        self.axes.iter().map(|ax| ax.len()).collect()
98    }
99
100    /// Total number of grid cells.
101    pub fn n_cells(&self) -> usize {
102        self.axes.iter().map(|ax| ax.len()).product()
103    }
104
105    /// Find the nearest cell index along `dim` for coordinate `val`.
106    ///
107    /// Returns the index of the closest grid-centre value.
108    pub fn nearest_index(&self, dim: usize, val: f64) -> usize {
109        let ax = &self.axes[dim];
110        let n = ax.len();
111        // Binary search for the insertion point, then compare neighbours.
112        // Use a manual binary search to avoid `unwrap()` on `as_slice()`.
113        let pos = {
114            let mut lo_idx = 0_usize;
115            let mut hi_idx = n;
116            while lo_idx < hi_idx {
117                let mid = lo_idx + (hi_idx - lo_idx) / 2;
118                if ax[mid] < val {
119                    lo_idx = mid + 1;
120                } else {
121                    hi_idx = mid;
122                }
123            }
124            lo_idx
125        };
126        if pos == 0 {
127            0
128        } else if pos == n {
129            n - 1
130        } else {
131            let lo = ax[pos - 1];
132            let hi = ax[pos];
133            if (val - lo).abs() <= (val - hi).abs() {
134                pos - 1
135            } else {
136                pos
137            }
138        }
139    }
140}
141
142// ─────────────────────────────────────────────────────────────────────────────
143// Aggregator
144// ─────────────────────────────────────────────────────────────────────────────
145
/// Value-aggregation strategy for cells that receive multiple scattered points.
///
/// The type is a plain field-less enum, so it is `Copy`, comparable and
/// hashable (usable as a key in hash-based collections).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Aggregator {
    /// Arithmetic mean of the values in the cell.
    Mean,
    /// Sample median (O(k log k) per non-empty cell).
    Median,
    /// Maximum value in the cell.
    Max,
    /// Minimum value in the cell.
    Min,
    /// Number of points that fell into the cell.
    Count,
}
160
161// ─────────────────────────────────────────────────────────────────────────────
162// ResampleStrategy
163// ─────────────────────────────────────────────────────────────────────────────
164
165/// How to map scattered data to grid cells.
166#[derive(Debug, Clone, Copy, PartialEq)]
167pub enum ResampleStrategy {
168    /// Assign each point to its nearest grid cell, then aggregate.
169    Rasterize(Aggregator),
170    /// Area-weighted conservative resampling (axis-aligned approximation).
171    ///
172    /// For axis-aligned grids this reduces to count-weighted rasterization,
173    /// preserving the total integral of the scattered data.
174    Conservative,
175}
176
177// ─────────────────────────────────────────────────────────────────────────────
178// resample_scattered_to_grid
179// ─────────────────────────────────────────────────────────────────────────────
180
181/// Resample scattered N-D data onto a regular grid.
182///
183/// # Arguments
184///
185/// * `points` – `n × d` array of sample coordinates.
186/// * `values` – length-`n` array of scalar values at those coordinates.
187/// * `grid`   – target `GridSpec` describing the output grid.
188/// * `strategy` – how to handle multiple points per cell.
189///
190/// # Returns
191///
192/// A `d`-dimensional `ArrayD<f64>` whose shape matches `grid.shape()`.
193/// Cells that received no points contain `f64::NAN` (or `0.0` for `Count`).
194pub fn resample_scattered_to_grid(
195    points: &Array2<f64>,
196    values: &Array1<f64>,
197    grid: &GridSpec,
198    strategy: ResampleStrategy,
199) -> InterpolateResult<ArrayD<f64>> {
200    let n = points.nrows();
201    let d = points.ncols();
202
203    if d != grid.ndim() {
204        return Err(InterpolateError::DimensionMismatch(format!(
205            "points has {d} columns but grid has {} axes",
206            grid.ndim()
207        )));
208    }
209    if values.len() != n {
210        return Err(InterpolateError::DimensionMismatch(format!(
211            "points has {n} rows but values has {} elements",
212            values.len()
213        )));
214    }
215
216    let shape = grid.shape();
217    let total_cells = grid.n_cells();
218
219    // Build a flat bucket list: each cell collects the values of points
220    // that land in it.
221    let mut buckets: Vec<Vec<f64>> = vec![Vec::new(); total_cells];
222
223    for row in 0..n {
224        // Compute N-D cell index then convert to a flat C-order offset.
225        let mut multi_idx = vec![0_usize; d];
226        for dim in 0..d {
227            let coord = points[[row, dim]];
228            multi_idx[dim] = grid.nearest_index(dim, coord);
229        }
230
231        let mut flat_idx = 0_usize;
232        for dim in 0..d {
233            flat_idx += multi_idx[dim] * stride_for(&shape, dim);
234        }
235
236        if flat_idx < total_cells {
237            buckets[flat_idx].push(values[row]);
238        }
239    }
240
241    // Aggregate buckets.
242    let raw_data: Vec<f64> = buckets
243        .into_iter()
244        .map(|mut bucket| aggregate(bucket.as_mut_slice(), strategy))
245        .collect();
246
247    let out = ArrayD::from_shape_vec(IxDyn(&shape), raw_data).map_err(|e| {
248        InterpolateError::ShapeError(format!("failed to construct output array: {e}"))
249    })?;
250
251    Ok(out)
252}
253
/// Row-major (C-order) stride of axis `dim` for an array of the given
/// `shape`: the product of the extents of all axes that follow `dim`.
///
/// The last axis therefore has stride `1`. Panics if `dim >= shape.len()`.
fn stride_for(shape: &[usize], dim: usize) -> usize {
    let trailing_extents = &shape[dim + 1..];
    trailing_extents.iter().copied().product()
}
258
259/// Aggregate a mutable slice of values according to `strategy`.
260fn aggregate(bucket: &mut [f64], strategy: ResampleStrategy) -> f64 {
261    if bucket.is_empty() {
262        return match strategy {
263            ResampleStrategy::Rasterize(Aggregator::Count) | ResampleStrategy::Conservative => 0.0,
264            _ => f64::NAN,
265        };
266    }
267
268    match strategy {
269        ResampleStrategy::Rasterize(Aggregator::Mean) | ResampleStrategy::Conservative => {
270            bucket.iter().sum::<f64>() / bucket.len() as f64
271        }
272        ResampleStrategy::Rasterize(Aggregator::Median) => {
273            bucket.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
274            let mid = bucket.len() / 2;
275            if bucket.len() % 2 == 0 {
276                (bucket[mid - 1] + bucket[mid]) * 0.5
277            } else {
278                bucket[mid]
279            }
280        }
281        ResampleStrategy::Rasterize(Aggregator::Max) => {
282            bucket.iter().copied().fold(f64::NEG_INFINITY, f64::max)
283        }
284        ResampleStrategy::Rasterize(Aggregator::Min) => {
285            bucket.iter().copied().fold(f64::INFINITY, f64::min)
286        }
287        ResampleStrategy::Rasterize(Aggregator::Count) => bucket.len() as f64,
288    }
289}