scirs2_interpolate/resampling/grid_spec.rs
1//! N-dimensional grid specification and scattered-to-grid resampling.
2//!
3//! Provides `GridSpec`, `Aggregator`, `ResampleStrategy`, and
4//! `resample_scattered_to_grid` for mapping scattered (x, y) data onto a
5//! regular N-D grid, returning an `ArrayD<f64>`.
6//!
7//! # Design
8//!
//! - `Rasterize(Aggregator)` bins each scattered point into its nearest grid
//!   cell and reduces each cell's values with the chosen `Aggregator`; empty
//!   cells receive `f64::NAN` (or `0.0` for `Aggregator::Count`).
//! - `Conservative` is an axis-aligned approximation of area-weighted
//!   accumulation: points are binned the same way and each cell stores the
//!   mean of its values, but empty cells receive `0.0` instead of `NaN`.
13//!
14//! # Example
15//!
16//! ```rust
17//! use scirs2_interpolate::resampling::{GridSpec, Aggregator, ResampleStrategy, resample_scattered_to_grid};
18//! use scirs2_core::ndarray::Array2;
19//!
20//! // 4 scattered points in 2-D, values = x + y
21//! let pts = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap();
22//! let vals = scirs2_core::ndarray::Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);
23//! let grid = GridSpec::uniform(2, &[(0.0, 1.0, 2), (0.0, 1.0, 2)]);
24//! let out = resample_scattered_to_grid(&pts, &vals, &grid, ResampleStrategy::Rasterize(Aggregator::Mean))
25//! .expect("resample");
26//! assert_eq!(out.shape(), &[2, 2]);
27//! ```
28
29use scirs2_core::ndarray::{Array1, Array2, ArrayD, IxDyn};
30
31use crate::error::{InterpolateError, InterpolateResult};
32
33// ─────────────────────────────────────────────────────────────────────────────
34// GridSpec
35// ─────────────────────────────────────────────────────────────────────────────
36
37/// Description of a regular N-dimensional target grid.
38///
39/// Each axis is defined by an `Array1<f64>` of strictly increasing
40/// grid-cell centre coordinates.
41#[derive(Debug, Clone)]
42pub struct GridSpec {
43 /// Grid-cell centre coordinates for each axis.
44 pub axes: Vec<Array1<f64>>,
45}
46
47impl GridSpec {
48 /// Create a `GridSpec` from per-axis arrays.
49 ///
50 /// Each `axis` slice must contain at least 1 value and be strictly
51 /// increasing.
52 pub fn new(axes: Vec<Array1<f64>>) -> InterpolateResult<Self> {
53 for (dim, ax) in axes.iter().enumerate() {
54 if ax.is_empty() {
55 return Err(InterpolateError::InvalidInput {
56 message: format!("axis {dim} is empty"),
57 });
58 }
59 for i in 1..ax.len() {
60 if ax[i] <= ax[i - 1] {
61 return Err(InterpolateError::InvalidInput {
62 message: format!("axis {dim} is not strictly increasing"),
63 });
64 }
65 }
66 }
67 Ok(Self { axes })
68 }
69
70 /// Build a uniform grid from `(min, max, n_cells)` per dimension.
71 ///
72 /// This is a convenience constructor; panics only if `n_cells == 0`.
73 pub fn uniform(dim: usize, specs: &[(f64, f64, usize)]) -> Self {
74 assert_eq!(specs.len(), dim, "specs length must equal dim");
75 let axes: Vec<Array1<f64>> = specs
76 .iter()
77 .map(|&(lo, hi, n)| {
78 let n_pts = n.max(1);
79 if n_pts == 1 {
80 Array1::from_vec(vec![lo])
81 } else {
82 let step = (hi - lo) / (n_pts as f64 - 1.0);
83 Array1::from_iter((0..n_pts).map(|i| lo + i as f64 * step))
84 }
85 })
86 .collect();
87 Self { axes }
88 }
89
90 /// Dimensionality of the grid.
91 pub fn ndim(&self) -> usize {
92 self.axes.len()
93 }
94
95 /// Shape of the output array (one entry per axis).
96 pub fn shape(&self) -> Vec<usize> {
97 self.axes.iter().map(|ax| ax.len()).collect()
98 }
99
100 /// Total number of grid cells.
101 pub fn n_cells(&self) -> usize {
102 self.axes.iter().map(|ax| ax.len()).product()
103 }
104
105 /// Find the nearest cell index along `dim` for coordinate `val`.
106 ///
107 /// Returns the index of the closest grid-centre value.
108 pub fn nearest_index(&self, dim: usize, val: f64) -> usize {
109 let ax = &self.axes[dim];
110 let n = ax.len();
111 // Binary search for the insertion point, then compare neighbours.
112 // Use a manual binary search to avoid `unwrap()` on `as_slice()`.
113 let pos = {
114 let mut lo_idx = 0_usize;
115 let mut hi_idx = n;
116 while lo_idx < hi_idx {
117 let mid = lo_idx + (hi_idx - lo_idx) / 2;
118 if ax[mid] < val {
119 lo_idx = mid + 1;
120 } else {
121 hi_idx = mid;
122 }
123 }
124 lo_idx
125 };
126 if pos == 0 {
127 0
128 } else if pos == n {
129 n - 1
130 } else {
131 let lo = ax[pos - 1];
132 let hi = ax[pos];
133 if (val - lo).abs() <= (val - hi).abs() {
134 pos - 1
135 } else {
136 pos
137 }
138 }
139 }
140}
141
142// ─────────────────────────────────────────────────────────────────────────────
143// Aggregator
144// ─────────────────────────────────────────────────────────────────────────────
145
/// How the values of several points that share one grid cell are reduced
/// to a single number.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Aggregator {
    /// Arithmetic mean of the cell's values.
    Mean,
    /// Middle value after sorting (the average of the two middle values for
    /// an even count); costs O(k log k) for a cell holding k points.
    Median,
    /// Largest value in the cell.
    Max,
    /// Smallest value in the cell.
    Min,
    /// How many points landed in the cell (the values themselves are ignored).
    Count,
}
160
161// ─────────────────────────────────────────────────────────────────────────────
162// ResampleStrategy
163// ─────────────────────────────────────────────────────────────────────────────
164
165/// How to map scattered data to grid cells.
166#[derive(Debug, Clone, Copy, PartialEq)]
167pub enum ResampleStrategy {
168 /// Assign each point to its nearest grid cell, then aggregate.
169 Rasterize(Aggregator),
170 /// Area-weighted conservative resampling (axis-aligned approximation).
171 ///
172 /// For axis-aligned grids this reduces to count-weighted rasterization,
173 /// preserving the total integral of the scattered data.
174 Conservative,
175}
176
177// ─────────────────────────────────────────────────────────────────────────────
178// resample_scattered_to_grid
179// ─────────────────────────────────────────────────────────────────────────────
180
181/// Resample scattered N-D data onto a regular grid.
182///
183/// # Arguments
184///
185/// * `points` – `n × d` array of sample coordinates.
186/// * `values` – length-`n` array of scalar values at those coordinates.
187/// * `grid` – target `GridSpec` describing the output grid.
188/// * `strategy` – how to handle multiple points per cell.
189///
190/// # Returns
191///
192/// A `d`-dimensional `ArrayD<f64>` whose shape matches `grid.shape()`.
193/// Cells that received no points contain `f64::NAN` (or `0.0` for `Count`).
194pub fn resample_scattered_to_grid(
195 points: &Array2<f64>,
196 values: &Array1<f64>,
197 grid: &GridSpec,
198 strategy: ResampleStrategy,
199) -> InterpolateResult<ArrayD<f64>> {
200 let n = points.nrows();
201 let d = points.ncols();
202
203 if d != grid.ndim() {
204 return Err(InterpolateError::DimensionMismatch(format!(
205 "points has {d} columns but grid has {} axes",
206 grid.ndim()
207 )));
208 }
209 if values.len() != n {
210 return Err(InterpolateError::DimensionMismatch(format!(
211 "points has {n} rows but values has {} elements",
212 values.len()
213 )));
214 }
215
216 let shape = grid.shape();
217 let total_cells = grid.n_cells();
218
219 // Build a flat bucket list: each cell collects the values of points
220 // that land in it.
221 let mut buckets: Vec<Vec<f64>> = vec![Vec::new(); total_cells];
222
223 for row in 0..n {
224 // Compute N-D cell index then convert to a flat C-order offset.
225 let mut multi_idx = vec![0_usize; d];
226 for dim in 0..d {
227 let coord = points[[row, dim]];
228 multi_idx[dim] = grid.nearest_index(dim, coord);
229 }
230
231 let mut flat_idx = 0_usize;
232 for dim in 0..d {
233 flat_idx += multi_idx[dim] * stride_for(&shape, dim);
234 }
235
236 if flat_idx < total_cells {
237 buckets[flat_idx].push(values[row]);
238 }
239 }
240
241 // Aggregate buckets.
242 let raw_data: Vec<f64> = buckets
243 .into_iter()
244 .map(|mut bucket| aggregate(bucket.as_mut_slice(), strategy))
245 .collect();
246
247 let out = ArrayD::from_shape_vec(IxDyn(&shape), raw_data).map_err(|e| {
248 InterpolateError::ShapeError(format!("failed to construct output array: {e}"))
249 })?;
250
251 Ok(out)
252}
253
/// Row-major (C-order) stride of axis `dim` within an array of `shape`.
///
/// The stride is the product of every extent that follows `dim`, so it is
/// `1` for the last axis.
///
/// Panics if `dim >= shape.len()`.
fn stride_for(shape: &[usize], dim: usize) -> usize {
    let trailing = &shape[dim + 1..];
    trailing.iter().fold(1, |acc, &extent| acc * extent)
}
258
259/// Aggregate a mutable slice of values according to `strategy`.
260fn aggregate(bucket: &mut [f64], strategy: ResampleStrategy) -> f64 {
261 if bucket.is_empty() {
262 return match strategy {
263 ResampleStrategy::Rasterize(Aggregator::Count) | ResampleStrategy::Conservative => 0.0,
264 _ => f64::NAN,
265 };
266 }
267
268 match strategy {
269 ResampleStrategy::Rasterize(Aggregator::Mean) | ResampleStrategy::Conservative => {
270 bucket.iter().sum::<f64>() / bucket.len() as f64
271 }
272 ResampleStrategy::Rasterize(Aggregator::Median) => {
273 bucket.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
274 let mid = bucket.len() / 2;
275 if bucket.len() % 2 == 0 {
276 (bucket[mid - 1] + bucket[mid]) * 0.5
277 } else {
278 bucket[mid]
279 }
280 }
281 ResampleStrategy::Rasterize(Aggregator::Max) => {
282 bucket.iter().copied().fold(f64::NEG_INFINITY, f64::max)
283 }
284 ResampleStrategy::Rasterize(Aggregator::Min) => {
285 bucket.iter().copied().fold(f64::INFINITY, f64::min)
286 }
287 ResampleStrategy::Rasterize(Aggregator::Count) => bucket.len() as f64,
288 }
289}