scirs2_interpolate/voronoi/
gradient.rs

//! Gradient estimation for Voronoi-based interpolation methods.
//!
//! This module provides gradient computation for functions interpolated with
//! natural neighbor interpolation.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use super::natural::{InterpolationMethod, NaturalNeighborInterpolator};
11use crate::error::{InterpolateError, InterpolateResult};
12
13/// Trait for interpolators that can calculate values at query points
14pub trait Interpolator<F: Float + FromPrimitive + Debug + ordered_float::FloatCore> {
15    /// Interpolate at a single query point
16    fn interpolate(&self, query: &ArrayView1<F>) -> InterpolateResult<F>;
17}
18
19impl<
20        F: Float
21            + FromPrimitive
22            + Debug
23            + scirs2_core::ndarray::ScalarOperand
24            + 'static
25            + std::cmp::PartialOrd
26            + ordered_float::FloatCore,
27    > Interpolator<F> for NaturalNeighborInterpolator<F>
28{
29    fn interpolate(&self, query: &ArrayView1<F>) -> InterpolateResult<F> {
30        // Simply forward to the NaturalNeighborInterpolator's interpolate method
31        NaturalNeighborInterpolator::interpolate(self, query)
32    }
33}
34
35/// Trait for interpolators that can compute gradients
36pub trait GradientEstimation<F: Float + FromPrimitive + Debug + ordered_float::FloatCore> {
37    /// Computes the gradient of the interpolated function at a query point
38    ///
39    /// # Arguments
40    /// * `query` - The point at which to compute the gradient
41    ///
42    /// # Returns
43    /// A vector of partial derivatives with respect to each coordinate
44    fn gradient(&self, query: &ArrayView1<F>) -> InterpolateResult<Array1<F>>;
45
46    /// Computes the gradients of the interpolated function at multiple query points
47    ///
48    /// # Arguments
49    /// * `queries` - The points at which to compute gradients
50    ///
51    /// # Returns
52    /// A matrix where each row is the gradient at the corresponding query point
53    fn gradient_multi(&self, queries: &ArrayView2<F>) -> InterpolateResult<Array2<F>>;
54}
55
56/// Extends NaturalNeighborInterpolator with gradient estimation
57impl<
58        F: Float
59            + FromPrimitive
60            + Debug
61            + scirs2_core::ndarray::ScalarOperand
62            + 'static
63            + for<'a> std::iter::Sum<&'a F>
64            + std::cmp::PartialOrd
65            + ordered_float::FloatCore,
66    > GradientEstimation<F> for NaturalNeighborInterpolator<F>
67{
68    fn gradient(&self, query: &ArrayView1<F>) -> InterpolateResult<Array1<F>> {
69        let dim = query.len();
70
71        if dim != self.points.ncols() {
72            return Err(InterpolateError::DimensionMismatch(format!(
73                "Query point dimension ({}) does not match data dimension ({})",
74                dim,
75                self.points.ncols()
76            )));
77        }
78
79        // For natural neighbor interpolation, we can compute the gradient directly
80        // from the interpolation weights and the values at data points
81
82        // Get the natural neighbor weights
83        let neighbor_weights = self.voronoi_diagram().natural_neighbors(query)?;
84
85        if neighbor_weights.is_empty() {
86            // If no natural neighbors found, use finite difference approximation
87            return finite_difference_gradient(self, query);
88        }
89
90        // Compute the gradient based on the interpolation method
91        match self.method() {
92            InterpolationMethod::Sibson => {
93                // For Sibson's method, the gradient is computed as a weighted sum of
94                // value differences and point differences
95                let mut gradient = Array1::zeros(dim);
96
97                for (idx, weight) in neighbor_weights.iter() {
98                    let neighbor_point = self.points.row(*idx);
99                    let neighbor_value = self.values[*idx];
100
101                    // Compute contribution to gradient from this neighbor
102                    for d in 0..dim {
103                        let coordinate_diff = neighbor_point[d] - query[d];
104                        gradient[d] = gradient[d] + *weight * neighbor_value * coordinate_diff;
105                    }
106                }
107
108                // Normalize the gradient if necessary
109                let weight_sum: F = neighbor_weights.values().sum();
110                if weight_sum > F::zero() {
111                    gradient = gradient / weight_sum;
112                }
113
114                Ok(gradient)
115            }
116            InterpolationMethod::Laplace => {
117                // For Laplace's method, the gradient is approximated by a finite difference
118                // approach using the natural neighbors
119                let mut gradient = Array1::zeros(dim);
120
121                let center_value = self.interpolate(query)?;
122
123                // Compute a weighted average of finite differences
124                let mut total_weight = F::zero();
125
126                for (idx, weight) in neighbor_weights.iter() {
127                    let neighbor_point = self.points.row(*idx);
128                    let neighbor_value = self.values[*idx];
129
130                    // Compute distance from query to neighbor
131                    let mut distance = F::zero();
132                    for d in 0..dim {
133                        distance = distance
134                            + scirs2_core::numeric::Float::powi(neighbor_point[d] - query[d], 2);
135                    }
136                    distance = distance.sqrt();
137
138                    // Skip very close points to avoid numerical issues
139                    if distance < <F as scirs2_core::numeric::Float>::epsilon() {
140                        continue;
141                    }
142
143                    // Compute value difference
144                    let value_diff = neighbor_value - center_value;
145
146                    // Contribute to gradient
147                    for d in 0..dim {
148                        let coordinate_diff = neighbor_point[d] - query[d];
149                        // Directional derivative along the vector from query to neighbor
150                        let dir_deriv = value_diff / distance;
151                        // Project onto coordinate axis
152                        gradient[d] =
153                            gradient[d] + *weight * dir_deriv * coordinate_diff / distance;
154                    }
155
156                    total_weight = total_weight + *weight;
157                }
158
159                // Normalize the gradient
160                if total_weight > F::zero() {
161                    gradient = gradient / total_weight;
162                }
163
164                Ok(gradient)
165            }
166        }
167    }
168
169    fn gradient_multi(&self, queries: &ArrayView2<F>) -> InterpolateResult<Array2<F>> {
170        let n_queries = queries.nrows();
171        let dim = queries.ncols();
172
173        if dim != self.points.ncols() {
174            return Err(InterpolateError::DimensionMismatch(format!(
175                "Query points dimension ({}) does not match data dimension ({})",
176                dim,
177                self.points.ncols()
178            )));
179        }
180
181        let mut gradients = Array2::zeros((n_queries, dim));
182
183        for i in 0..n_queries {
184            let query = queries.row(i);
185            let gradient = self.gradient(&query)?;
186
187            gradients.row_mut(i).assign(&gradient);
188        }
189
190        Ok(gradients)
191    }
192}
193
194/// Computes a gradient using finite difference approximation
195///
196/// This is a fallback method when natural neighbor weights are not available.
197///
198/// # Arguments
199/// * `interpolator` - The interpolator to use for function evaluations
200/// * `query` - The point at which to compute the gradient
201///
202/// # Returns
203/// The estimated gradient vector
204#[allow(dead_code)]
205fn finite_difference_gradient<F, T>(
206    interpolator: &T,
207    query: &ArrayView1<F>,
208) -> InterpolateResult<Array1<F>>
209where
210    F: Float + FromPrimitive + Debug + ordered_float::FloatCore,
211    T: GradientEstimation<F> + Interpolator<F>,
212{
213    let dim = query.len();
214    let mut gradient = Array1::zeros(dim);
215
216    // Use central differences for better accuracy
217    let h = F::from(1e-6).unwrap(); // Step size
218
219    // Compute the center value
220    let center_value = match interpolator.interpolate(query) {
221        Ok(v) => v,
222        Err(_) => {
223            // If interpolation at the center fails, use a one-sided difference
224            // by evaluating at nearby points only
225            for d in 0..dim {
226                let mut forward_query = query.to_owned();
227                forward_query[d] = forward_query[d] + h;
228
229                if let Ok(forward_value) = interpolator.interpolate(&forward_query.view()) {
230                    gradient[d] = forward_value / h; // Approximate slope
231                }
232            }
233            return Ok(gradient);
234        }
235    };
236
237    // Use central differences for each dimension
238    for d in 0..dim {
239        let mut forward_query = query.to_owned();
240        forward_query[d] = forward_query[d] + h;
241
242        let mut backward_query = query.to_owned();
243        backward_query[d] = backward_query[d] - h;
244
245        // Try to compute the forward and backward values
246        let forward_result = interpolator.interpolate(&forward_query.view());
247        let backward_result = interpolator.interpolate(&backward_query.view());
248
249        match (forward_result, backward_result) {
250            (Ok(forward_value), Ok(backward_value)) => {
251                // Central difference
252                gradient[d] = (forward_value - backward_value) / (h + h);
253            }
254            (Ok(forward_value), Err(_)) => {
255                // Forward difference
256                gradient[d] = (forward_value - center_value) / h;
257            }
258            (Err(_), Ok(backward_value)) => {
259                // Backward difference
260                gradient[d] = (center_value - backward_value) / h;
261            }
262            (Err(_), Err(_)) => {
263                // Can't compute gradient in this direction
264                gradient[d] = F::zero();
265            }
266        }
267    }
268
269    Ok(gradient)
270}
271
272/// Information returned by interpolation with gradient
273pub struct InterpolateWithGradientResult<
274    F: Float + FromPrimitive + Debug + ordered_float::FloatCore,
275> {
276    /// The interpolated value
277    pub value: F,
278
279    /// The gradient vector
280    pub gradient: Array1<F>,
281}
282
283/// Extension trait for interpolators to compute interpolated values with gradients
284pub trait InterpolateWithGradient<F: Float + FromPrimitive + Debug + ordered_float::FloatCore> {
285    /// Interpolates a value and computes its gradient at a query point
286    ///
287    /// # Arguments
288    /// * `query` - The point at which to interpolate and compute the gradient
289    ///
290    /// # Returns
291    /// A struct containing the interpolated value and gradient
292    fn interpolate_with_gradient(
293        &self,
294        query: &ArrayView1<F>,
295    ) -> InterpolateResult<InterpolateWithGradientResult<F>>;
296
297    /// Interpolates values and computes gradients at multiple query points
298    ///
299    /// # Arguments
300    /// * `queries` - The points at which to interpolate and compute gradients
301    ///
302    /// # Returns
303    /// A vector of structs containing the interpolated values and gradients
304    fn interpolate_with_gradient_multi(
305        &self,
306        queries: &ArrayView2<F>,
307    ) -> InterpolateResult<Vec<InterpolateWithGradientResult<F>>>;
308}
309
310impl<
311        F: Float
312            + FromPrimitive
313            + Debug
314            + scirs2_core::ndarray::ScalarOperand
315            + 'static
316            + for<'a> std::iter::Sum<&'a F>
317            + std::cmp::PartialOrd
318            + ordered_float::FloatCore,
319    > InterpolateWithGradient<F> for NaturalNeighborInterpolator<F>
320{
321    fn interpolate_with_gradient(
322        &self,
323        query: &ArrayView1<F>,
324    ) -> InterpolateResult<InterpolateWithGradientResult<F>> {
325        let value = self.interpolate(query)?;
326        let gradient = self.gradient(query)?;
327
328        Ok(InterpolateWithGradientResult { value, gradient })
329    }
330
331    fn interpolate_with_gradient_multi(
332        &self,
333        queries: &ArrayView2<F>,
334    ) -> InterpolateResult<Vec<InterpolateWithGradientResult<F>>> {
335        let n_queries = queries.nrows();
336        let mut results = Vec::with_capacity(n_queries);
337
338        let values = self.interpolate_multi(queries)?;
339        let gradients = self.gradient_multi(queries)?;
340
341        for i in 0..n_queries {
342            results.push(InterpolateWithGradientResult {
343                value: values[i],
344                gradient: gradients.row(i).to_owned(),
345            });
346        }
347
348        Ok(results)
349    }
350}