scirs2_ndimage/interpolation/
utils.rs1use scirs2_core::ndarray::Array;
4use scirs2_core::numeric::{Float, FromPrimitive};
5use std::fmt::Debug;
6
7use super::BoundaryMode;
8use crate::error::{NdimageError, NdimageResult};
9
10#[allow(dead_code)]
12fn safe_usize_to_float<T: Float + FromPrimitive>(value: usize) -> NdimageResult<T> {
13 T::from_usize(value).ok_or_else(|| {
14 NdimageError::ComputationError(format!("Failed to convert usize {} to float type", value))
15 })
16}
17
18#[allow(dead_code)]
20fn safe_float_to_usize<T: Float>(value: T) -> NdimageResult<usize> {
21 value.to_usize().ok_or_else(|| {
22 NdimageError::ComputationError("Failed to convert float to usize".to_string())
23 })
24}
25
26#[allow(dead_code)]
38pub fn handle_boundary<T>(coord: T, size: usize, mode: BoundaryMode) -> NdimageResult<T>
39where
40 T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
41{
42 let size_t = safe_usize_to_float(size)?;
44
45 if coord >= T::zero() && coord < size_t {
47 return Ok(coord);
48 }
49
50 match mode {
52 BoundaryMode::Constant => {
53 Err(NdimageError::InterpolationError(format!(
56 "Coordinate {:?} out of bounds for size {} with constant mode",
57 coord, size
58 )))
59 }
60 BoundaryMode::Nearest => {
61 if coord < T::zero() {
62 Ok(T::zero())
63 } else {
64 Ok(size_t - T::one())
65 }
66 }
67 BoundaryMode::Reflect => {
68 Ok(T::zero())
71 }
72 BoundaryMode::Mirror => {
73 Ok(T::zero())
76 }
77 BoundaryMode::Wrap => {
78 Ok(T::zero())
81 }
82 }
83}
84
85#[allow(dead_code)]
95pub fn linear_weights<T>(x: T) -> (usize, usize, T)
96where
97 T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
98{
99 let x_floor = x.floor();
100 let x_int = safe_float_to_usize(x_floor).unwrap_or(0); let t = x - x_floor;
102
103 (x_int, x_int + 1, t)
104}
105
106#[allow(dead_code)]
116pub fn cubic_weights<T>(x: T) -> (usize, [T; 4])
117where
118 T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
119{
120 let x_floor = x.floor();
121 let x_int = safe_float_to_usize(x_floor).unwrap_or(0); let t = x - x_floor;
123
124 let t2 = t * t;
126 let t3 = t2 * t;
127
128 let half = T::from_f64(0.5).unwrap_or_else(|| T::one() / (T::one() + T::one()));
130 let two = T::from_f64(2.0).unwrap_or_else(|| T::one() + T::one());
131 let three = T::from_f64(3.0).unwrap_or_else(|| two + T::one());
132 let four = T::from_f64(4.0).unwrap_or_else(|| two + two);
133 let five = T::from_f64(5.0).unwrap_or_else(|| four + T::one());
134
135 let w0 = half * (-t3 + two * t2 - t);
136 let w1 = half * (three * t3 - five * t2 + two);
137 let w2 = half * (-three * t3 + four * t2 + t);
138 let w3 = half * (t3 - t2);
139
140 let weights = [w0, w1, w2, w3];
141
142 let start_idx = if x_int > 0 { x_int - 1 } else { 0 };
144
145 (start_idx, weights)
146}
147
148#[allow(dead_code)]
150pub fn interpolate_nearest<T>(
151 input: &Array<T, scirs2_core::ndarray::IxDyn>,
152 coords: &[T],
153 boundary: &BoundaryMode,
154 const_val: T,
155) -> T
156where
157 T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
158{
159 let int_coords: Vec<isize> = coords
161 .iter()
162 .map(|&coord| coord.round().to_isize().unwrap_or(0))
163 .collect();
164
165 let inputshape = input.shape();
167 let bounded_coords: Vec<usize> = int_coords
168 .iter()
169 .enumerate()
170 .map(|(i, &coord)| {
171 let dim_size = inputshape[i] as isize;
172 apply_boundary_condition(coord, dim_size, boundary)
173 })
174 .collect();
175
176 for (i, &coord) in bounded_coords.iter().enumerate() {
178 if coord >= inputshape[i] {
179 return const_val; }
181 }
182
183 input
185 .get(bounded_coords.as_slice())
186 .copied()
187 .unwrap_or(const_val)
188}
189
190#[allow(dead_code)]
192pub fn interpolate_linear<T>(
193 input: &Array<T, scirs2_core::ndarray::IxDyn>,
194 coords: &[T],
195 boundary: &BoundaryMode,
196 const_val: T,
197) -> T
198where
199 T: Float + FromPrimitive + Debug + std::ops::AddAssign + std::ops::DivAssign + 'static,
200{
201 let ndim = coords.len();
202 if ndim == 0 {
203 return const_val;
204 }
205
206 if ndim == 1 {
208 let x = coords[0];
209 let x0 = x.floor();
210 let x1 = x0 + T::one();
211 let dx = x - x0;
212
213 let i0 = x0.to_isize().unwrap_or(0);
214 let i1 = x1.to_isize().unwrap_or(0);
215
216 let dim_size = input.shape()[0] as isize;
217 let idx0 = apply_boundary_condition(i0, dim_size, boundary);
218 let idx1 = apply_boundary_condition(i1, dim_size, boundary);
219
220 if matches!(boundary, BoundaryMode::Constant)
222 && (i0 < 0 || i0 >= dim_size || i1 < 0 || i1 >= dim_size)
223 {
224 return const_val;
225 }
226
227 let v0 = input.get([idx0]).copied().unwrap_or(const_val);
228 let v1 = input.get([idx1]).copied().unwrap_or(const_val);
229
230 return v0 * (T::one() - dx) + v1 * dx;
231 }
232
233 if ndim == 2 {
235 let x = coords[0];
236 let y = coords[1];
237
238 let x0 = x.floor();
239 let x1 = x0 + T::one();
240 let y0 = y.floor();
241 let y1 = y0 + T::one();
242
243 let dx = x - x0;
244 let dy = y - y0;
245
246 let i0 = x0.to_isize().unwrap_or(0);
247 let i1 = x1.to_isize().unwrap_or(0);
248 let j0 = y0.to_isize().unwrap_or(0);
249 let j1 = y1.to_isize().unwrap_or(0);
250
251 let dim_size_x = input.shape()[0] as isize;
252 let dim_size_y = input.shape()[1] as isize;
253
254 if matches!(boundary, BoundaryMode::Constant)
259 && (i0 < 0 || i0 >= dim_size_x || j0 < 0 || j0 >= dim_size_y)
260 {
261 return const_val;
262 }
263
264 let idx0 = apply_boundary_condition(i0, dim_size_x, boundary);
265 let idx1 = apply_boundary_condition(i1, dim_size_x, boundary);
266 let jdx0 = apply_boundary_condition(j0, dim_size_y, boundary);
267 let jdx1 = apply_boundary_condition(j1, dim_size_y, boundary);
268
269 let v00 = input.get([idx0, jdx0]).copied().unwrap_or(const_val);
270 let v01 = input.get([idx0, jdx1]).copied().unwrap_or(const_val);
271 let v10 = input.get([idx1, jdx0]).copied().unwrap_or(const_val);
272 let v11 = input.get([idx1, jdx1]).copied().unwrap_or(const_val);
273
274 let v0 = v00 * (T::one() - dy) + v01 * dy;
276 let v1 = v10 * (T::one() - dy) + v11 * dy;
277
278 return v0 * (T::one() - dx) + v1 * dx;
279 }
280
281 interpolate_nearest(input, coords, boundary, const_val)
283}
284
285#[allow(dead_code)]
287pub fn apply_boundary_condition(_coord: isize, dimsize: isize, mode: &BoundaryMode) -> usize {
288 match mode {
289 BoundaryMode::Constant => {
290 if _coord < 0 || _coord >= dimsize {
291 dimsize as usize
293 } else {
294 _coord as usize
295 }
296 }
297 BoundaryMode::Nearest => {
298 if _coord < 0 {
299 0
300 } else if _coord >= dimsize {
301 (dimsize - 1) as usize
302 } else {
303 _coord as usize
304 }
305 }
306 BoundaryMode::Wrap => {
307 if dimsize == 0 {
308 0
309 } else {
310 let wrapped = ((_coord % dimsize) + dimsize) % dimsize;
311 wrapped as usize
312 }
313 }
314 BoundaryMode::Reflect => {
315 if dimsize <= 1 {
316 0
317 } else {
318 let reflected = if _coord < 0 {
319 (-_coord - 1) % dimsize
320 } else if _coord >= dimsize {
321 (2 * dimsize - _coord - 1) % dimsize
322 } else {
323 _coord
324 };
325 reflected as usize
326 }
327 }
328 BoundaryMode::Mirror => {
329 if dimsize <= 1 {
330 0
331 } else {
332 let period = 2 * (dimsize - 1);
333 let mirrored = if _coord < 0 {
334 (-_coord) % period
335 } else if _coord >= dimsize {
336 period - ((_coord - dimsize + 1) % period) - 1
337 } else {
338 _coord
339 };
340 (mirrored.min(dimsize - 1)) as usize
341 }
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_handle_boundary_within_bounds() {
        let result = handle_boundary(1.5, 10, BoundaryMode::Nearest)
            .expect("handle_boundary should succeed for test");
        assert_eq!(result, 1.5);
    }

    #[test]
    fn test_handle_boundary_nearest() {
        let result = handle_boundary(-2.0, 10, BoundaryMode::Nearest)
            .expect("handle_boundary should succeed for test");
        assert_eq!(result, 0.0);

        let result = handle_boundary(15.0, 10, BoundaryMode::Nearest)
            .expect("handle_boundary should succeed for test");
        assert_eq!(result, 9.0);
    }

    #[test]
    fn test_handle_boundary_constant_rejects_out_of_bounds() {
        assert!(handle_boundary(-0.5, 10, BoundaryMode::Constant).is_err());
        assert!(handle_boundary(10.0, 10, BoundaryMode::Constant).is_err());
    }

    #[test]
    fn test_linear_weights() {
        let (i0, i1, t) = linear_weights(1.3);
        assert_eq!(i0, 1);
        assert_eq!(i1, 2);
        assert!((t - 0.3).abs() < 1e-10);
    }

    #[test]
    fn test_cubic_weights() {
        let (start_idx, weights) = cubic_weights(1.3);
        // floor(1.3) = 1, so the four taps start one sample to the left.
        assert_eq!(start_idx, 0);

        // Cubic convolution (a = -0.5) evaluated at t = 0.3.
        let expected = [-0.0735, 0.8155, 0.2895, -0.0315];
        for (w, e) in weights.iter().zip(expected.iter()) {
            assert!((w - e).abs() < 1e-10);
        }

        let sum: f64 = weights.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_apply_boundary_condition_single_step() {
        assert_eq!(apply_boundary_condition(2, 5, &BoundaryMode::Constant), 2);
        assert_eq!(apply_boundary_condition(-1, 5, &BoundaryMode::Constant), 5);
        assert_eq!(apply_boundary_condition(-1, 5, &BoundaryMode::Nearest), 0);
        assert_eq!(apply_boundary_condition(7, 5, &BoundaryMode::Nearest), 4);
        assert_eq!(apply_boundary_condition(-1, 5, &BoundaryMode::Wrap), 4);
        assert_eq!(apply_boundary_condition(6, 5, &BoundaryMode::Wrap), 1);
        assert_eq!(apply_boundary_condition(-1, 5, &BoundaryMode::Reflect), 0);
        assert_eq!(apply_boundary_condition(5, 5, &BoundaryMode::Reflect), 4);
        assert_eq!(apply_boundary_condition(-1, 5, &BoundaryMode::Mirror), 1);
    }

    #[test]
    fn test_interpolation_inside_1d_array() {
        let a: Array<f64, scirs2_core::ndarray::IxDyn> =
            Array::from_shape_vec(scirs2_core::ndarray::IxDyn(&[3]), vec![0.0, 10.0, 20.0])
                .expect("shape matches data length");

        assert_eq!(interpolate_linear(&a, &[0.5], &BoundaryMode::Nearest, -1.0), 5.0);
        assert_eq!(interpolate_nearest(&a, &[1.4], &BoundaryMode::Nearest, -1.0), 10.0);
        assert_eq!(interpolate_nearest(&a, &[-2.0], &BoundaryMode::Constant, -1.0), -1.0);
    }
}
386}