Skip to main content

scirs2_ndimage/interpolation/
utils.rs

1//! Utility functions for interpolation
2
3use scirs2_core::ndarray::Array;
4use scirs2_core::numeric::{Float, FromPrimitive};
5use std::fmt::Debug;
6
7use super::BoundaryMode;
8use crate::error::{NdimageError, NdimageResult};
9
10/// Helper function for safe conversion from usize to float
11#[allow(dead_code)]
12fn safe_usize_to_float<T: Float + FromPrimitive>(value: usize) -> NdimageResult<T> {
13    T::from_usize(value).ok_or_else(|| {
14        NdimageError::ComputationError(format!("Failed to convert usize {} to float type", value))
15    })
16}
17
18/// Helper function for safe conversion from float to usize
19#[allow(dead_code)]
20fn safe_float_to_usize<T: Float>(value: T) -> NdimageResult<usize> {
21    value.to_usize().ok_or_else(|| {
22        NdimageError::ComputationError("Failed to convert float to usize".to_string())
23    })
24}
25
26/// Handle out-of-bounds coordinates according to the boundary mode
27///
28/// # Arguments
29///
30/// * `coord` - Coordinate to process
31/// * `size` - Size of the array dimension
32/// * `mode` - Boundary handling mode
33///
34/// # Returns
35///
36/// * `Result<T>` - Processed coordinate
37#[allow(dead_code)]
38pub fn handle_boundary<T>(coord: T, size: usize, mode: BoundaryMode) -> NdimageResult<T>
39where
40    T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
41{
42    // Convert size to T for calculations
43    let size_t = safe_usize_to_float(size)?;
44
45    // Handle within-bounds case
46    if coord >= T::zero() && coord < size_t {
47        return Ok(coord);
48    }
49
50    // Handle out-of-bounds according to mode
51    match mode {
52        BoundaryMode::Constant => {
53            // For constant mode, return an out-of-bounds indicator
54            // The actual handling would be done by the caller
55            Err(NdimageError::InterpolationError(format!(
56                "Coordinate {:?} out of bounds for size {} with constant mode",
57                coord, size
58            )))
59        }
60        BoundaryMode::Nearest => {
61            if coord < T::zero() {
62                Ok(T::zero())
63            } else {
64                Ok(size_t - T::one())
65            }
66        }
67        BoundaryMode::Reflect => {
68            // Placeholder for reflect mode
69            // Would implement proper reflection calculation
70            Ok(T::zero())
71        }
72        BoundaryMode::Mirror => {
73            // Placeholder for mirror mode
74            // Would implement proper mirroring calculation
75            Ok(T::zero())
76        }
77        BoundaryMode::Wrap => {
78            // Placeholder for wrap mode
79            // Would implement proper wrapping calculation
80            Ok(T::zero())
81        }
82    }
83}
84
85/// Get the weights for linear interpolation
86///
87/// # Arguments
88///
89/// * `x` - Position for interpolation
90///
91/// # Returns
92///
93/// * `(usize, usize, T)` - (left index, right index, right weight)
94#[allow(dead_code)]
95pub fn linear_weights<T>(x: T) -> (usize, usize, T)
96where
97    T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
98{
99    let x_floor = x.floor();
100    let x_int = safe_float_to_usize(x_floor).unwrap_or(0); // Use 0 as fallback for interpolation
101    let t = x - x_floor;
102
103    (x_int, x_int + 1, t)
104}
105
106/// Get the weights for cubic interpolation
107///
108/// # Arguments
109///
110/// * `x` - Position for interpolation
111///
112/// # Returns
113///
114/// * `(usize, [T; 4])` - (starting index, weights for 4 points)
115#[allow(dead_code)]
116pub fn cubic_weights<T>(x: T) -> (usize, [T; 4])
117where
118    T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
119{
120    let x_floor = x.floor();
121    let x_int = safe_float_to_usize(x_floor).unwrap_or(0); // Use 0 as fallback for interpolation
122    let t = x - x_floor;
123
124    // Catmull-Rom cubic interpolation weights
125    let t2 = t * t;
126    let t3 = t2 * t;
127
128    // Pre-calculate constants with safe conversions
129    let half = T::from_f64(0.5).unwrap_or_else(|| T::one() / (T::one() + T::one()));
130    let two = T::from_f64(2.0).unwrap_or_else(|| T::one() + T::one());
131    let three = T::from_f64(3.0).unwrap_or_else(|| two + T::one());
132    let four = T::from_f64(4.0).unwrap_or_else(|| two + two);
133    let five = T::from_f64(5.0).unwrap_or_else(|| four + T::one());
134
135    let w0 = half * (-t3 + two * t2 - t);
136    let w1 = half * (three * t3 - five * t2 + two);
137    let w2 = half * (-three * t3 + four * t2 + t);
138    let w3 = half * (t3 - t2);
139
140    let weights = [w0, w1, w2, w3];
141
142    // Starting index is one less than floor because cubic uses 4 points
143    let start_idx = if x_int > 0 { x_int - 1 } else { 0 };
144
145    (start_idx, weights)
146}
147
148/// Helper function for nearest neighbor interpolation
149#[allow(dead_code)]
150pub fn interpolate_nearest<T>(
151    input: &Array<T, scirs2_core::ndarray::IxDyn>,
152    coords: &[T],
153    boundary: &BoundaryMode,
154    const_val: T,
155) -> T
156where
157    T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
158{
159    // Round coordinates to nearest integers
160    let int_coords: Vec<isize> = coords
161        .iter()
162        .map(|&coord| coord.round().to_isize().unwrap_or(0))
163        .collect();
164
165    // Apply boundary conditions and check bounds
166    let inputshape = input.shape();
167    let bounded_coords: Vec<usize> = int_coords
168        .iter()
169        .enumerate()
170        .map(|(i, &coord)| {
171            let dim_size = inputshape[i] as isize;
172            apply_boundary_condition(coord, dim_size, boundary)
173        })
174        .collect();
175
176    // Check if coordinates are valid (within bounds after boundary handling)
177    for (i, &coord) in bounded_coords.iter().enumerate() {
178        if coord >= inputshape[i] {
179            return const_val; // Out of bounds, return constant value
180        }
181    }
182
183    // Get value at the bounded coordinates
184    input
185        .get(bounded_coords.as_slice())
186        .copied()
187        .unwrap_or(const_val)
188}
189
190/// Helper function for linear interpolation  
191#[allow(dead_code)]
192pub fn interpolate_linear<T>(
193    input: &Array<T, scirs2_core::ndarray::IxDyn>,
194    coords: &[T],
195    boundary: &BoundaryMode,
196    const_val: T,
197) -> T
198where
199    T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
200{
201    let ndim = coords.len();
202    if ndim == 0 {
203        return const_val;
204    }
205
206    // Handle 1D linear interpolation
207    if ndim == 1 {
208        let x = coords[0];
209        let x0 = x.floor();
210        let x1 = x0 + T::one();
211        let dx = x - x0;
212
213        let i0 = x0.to_isize().unwrap_or(0);
214        let i1 = x1.to_isize().unwrap_or(0);
215
216        let dim_size = input.shape()[0] as isize;
217        let idx0 = apply_boundary_condition(i0, dim_size, boundary);
218        let idx1 = apply_boundary_condition(i1, dim_size, boundary);
219
220        // Check bounds for constant mode
221        if matches!(boundary, BoundaryMode::Constant)
222            && (i0 < 0 || i0 >= dim_size || i1 < 0 || i1 >= dim_size)
223        {
224            return const_val;
225        }
226
227        let v0 = input.get([idx0]).copied().unwrap_or(const_val);
228        let v1 = input.get([idx1]).copied().unwrap_or(const_val);
229
230        return v0 * (T::one() - dx) + v1 * dx;
231    }
232
233    // For 2D and higher, use separable linear interpolation
234    if ndim == 2 {
235        let x = coords[0];
236        let y = coords[1];
237
238        let x0 = x.floor();
239        let x1 = x0 + T::one();
240        let y0 = y.floor();
241        let y1 = y0 + T::one();
242
243        let dx = x - x0;
244        let dy = y - y0;
245
246        let i0 = x0.to_isize().unwrap_or(0);
247        let i1 = x1.to_isize().unwrap_or(0);
248        let j0 = y0.to_isize().unwrap_or(0);
249        let j1 = y1.to_isize().unwrap_or(0);
250
251        let dim_size_x = input.shape()[0] as isize;
252        let dim_size_y = input.shape()[1] as isize;
253
254        // For Constant mode: check if the primary coordinate (floor) is out of bounds.
255        // Out-of-bounds corners are handled per-lookup via apply_boundary_condition +
256        // unwrap_or(const_val), so we only need to short-circuit when the primary
257        // (floor) coordinate is entirely outside the array.
258        if matches!(boundary, BoundaryMode::Constant)
259            && (i0 < 0 || i0 >= dim_size_x || j0 < 0 || j0 >= dim_size_y)
260        {
261            return const_val;
262        }
263
264        let idx0 = apply_boundary_condition(i0, dim_size_x, boundary);
265        let idx1 = apply_boundary_condition(i1, dim_size_x, boundary);
266        let jdx0 = apply_boundary_condition(j0, dim_size_y, boundary);
267        let jdx1 = apply_boundary_condition(j1, dim_size_y, boundary);
268
269        let v00 = input.get([idx0, jdx0]).copied().unwrap_or(const_val);
270        let v01 = input.get([idx0, jdx1]).copied().unwrap_or(const_val);
271        let v10 = input.get([idx1, jdx0]).copied().unwrap_or(const_val);
272        let v11 = input.get([idx1, jdx1]).copied().unwrap_or(const_val);
273
274        // Bilinear interpolation
275        let v0 = v00 * (T::one() - dy) + v01 * dy;
276        let v1 = v10 * (T::one() - dy) + v11 * dy;
277
278        return v0 * (T::one() - dx) + v1 * dx;
279    }
280
281    // For higher dimensions, fall back to nearest neighbor
282    interpolate_nearest(input, coords, boundary, const_val)
283}
284
285/// Apply boundary condition to a coordinate
286#[allow(dead_code)]
287pub fn apply_boundary_condition(_coord: isize, dimsize: isize, mode: &BoundaryMode) -> usize {
288    match mode {
289        BoundaryMode::Constant => {
290            if _coord < 0 || _coord >= dimsize {
291                // Return a value that will be caught as out of bounds
292                dimsize as usize
293            } else {
294                _coord as usize
295            }
296        }
297        BoundaryMode::Nearest => {
298            if _coord < 0 {
299                0
300            } else if _coord >= dimsize {
301                (dimsize - 1) as usize
302            } else {
303                _coord as usize
304            }
305        }
306        BoundaryMode::Wrap => {
307            if dimsize == 0 {
308                0
309            } else {
310                let wrapped = ((_coord % dimsize) + dimsize) % dimsize;
311                wrapped as usize
312            }
313        }
314        BoundaryMode::Reflect => {
315            if dimsize <= 1 {
316                0
317            } else {
318                let reflected = if _coord < 0 {
319                    (-_coord - 1) % dimsize
320                } else if _coord >= dimsize {
321                    (2 * dimsize - _coord - 1) % dimsize
322                } else {
323                    _coord
324                };
325                reflected as usize
326            }
327        }
328        BoundaryMode::Mirror => {
329            if dimsize <= 1 {
330                0
331            } else {
332                let period = 2 * (dimsize - 1);
333                let mirrored = if _coord < 0 {
334                    (-_coord) % period
335                } else if _coord >= dimsize {
336                    period - ((_coord - dimsize + 1) % period) - 1
337                } else {
338                    _coord
339                };
340                (mirrored.min(dimsize - 1)) as usize
341            }
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for the floating-point comparisons below.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_handle_boundary_within_bounds() {
        let mapped = handle_boundary(1.5, 10, BoundaryMode::Nearest)
            .expect("handle_boundary should succeed for test");
        assert_eq!(mapped, 1.5);
    }

    #[test]
    fn test_handle_boundary_nearest() {
        // (input coordinate, expected clamped coordinate) for a dimension of size 10.
        for &(coord, expected) in &[(-2.0_f64, 0.0_f64), (15.0, 9.0)] {
            let mapped = handle_boundary(coord, 10, BoundaryMode::Nearest)
                .expect("handle_boundary should succeed for test");
            assert_eq!(mapped, expected);
        }
    }

    #[test]
    fn test_linear_weights() {
        let (left, right, frac) = linear_weights(1.3_f64);
        assert_eq!((left, right), (1, 2));
        assert!((frac - 0.3).abs() < EPS);
    }

    #[test]
    fn test_cubic_weights() {
        let (start, weights) = cubic_weights(1.3_f64);
        assert!(start <= 1);
        assert_eq!(weights.len(), 4);

        // Catmull-Rom weights form a partition of unity.
        let total = weights.iter().fold(0.0_f64, |acc, &w| acc + w);
        assert!((total - 1.0).abs() < EPS);
    }
}