scirs2_core/ndarray_ext/manipulation.rs
1//! Array manipulation operations similar to `NumPy`'s array manipulation routines
2//!
3//! This module provides functions for manipulating arrays, including flip, roll,
4//! tile, repeat, and other operations, designed to mirror `NumPy`'s functionality.
5
6use ndarray::{Array, ArrayView, Ix1, Ix2};
7use num_traits::Zero;
8
9/// Result type for gradient function
10pub type GradientResult<T> = Result<(Array<T, Ix2>, Array<T, Ix2>), &'static str>;
11
12/// Flip a 2D array along one or more axes
13///
14/// # Arguments
15///
16/// * `array` - The input 2D array
17/// * `flip_axis_0` - Whether to flip along axis 0 (rows)
18/// * `flip_axis_0` - Whether to flip along axis 1 (columns)
19///
20/// # Returns
21///
22/// A new array with axes flipped as specified
23///
24/// # Examples
25///
26/// ```
27/// use ndarray::array;
28/// use scirs2_core::ndarray_ext::manipulation::flip_2d;
29///
30/// let a = array![[1, 2, 3], [4, 5, 6]];
31///
32/// // Flip along rows
33/// let flipped_rows = flip_2d(a.view(), true, false);
34/// assert_eq!(flipped_rows, array![[4, 5, 6], [1, 2, 3]]);
35///
36/// // Flip along columns
37/// let flipped_cols = flip_2d(a.view(), false, true);
38/// assert_eq!(flipped_cols, array![[3, 2, 1], [6, 5, 4]]);
39///
40/// // Flip along both axes
41/// let flipped_both = flip_2d(a.view(), true, true);
42/// assert_eq!(flipped_both, array![[6, 5, 4], [3, 2, 1]]);
43/// ```
44#[allow(dead_code)]
45pub fn flip_2d<T>(array: ArrayView<T, Ix2>, flip_axis_0: bool, flipaxis_1: bool) -> Array<T, Ix2>
46where
47 T: Clone + Zero,
48{
49 let (rows, cols) = (array.shape()[0], array.shape()[1]);
50 let mut result = Array::<T, Ix2>::zeros((rows, cols));
51
52 for i in 0..rows {
53 for j in 0..cols {
54 let src_i = if flip_axis_0 { rows - 1 - i } else { i };
55 let src_j = if flipaxis_1 { cols - 1 - j } else { j };
56
57 result[[i, j]] = array[[src_i, src_j]].clone();
58 }
59 }
60
61 result
62}
63
64/// Roll array elements along one or both axes
65///
66/// # Arguments
67///
68/// * `array` - The input 2D array
69/// * `shift_axis_0` - Number of places to shift along axis 0 (can be negative)
70/// * `shift_axis_1` - Number of places to shift along axis 1 (can be negative)
71///
72/// # Returns
73///
74/// A new array with elements rolled as specified
75///
76/// # Examples
77///
78/// ```
79/// use ndarray::array;
80/// use scirs2_core::ndarray_ext::manipulation::roll_2d;
81///
82/// let a = array![[1, 2, 3], [4, 5, 6]];
83///
84/// // Roll along rows by 1
85/// let rolled_rows = roll_2d(a.view(), 1, 0);
86/// assert_eq!(rolled_rows, array![[4, 5, 6], [1, 2, 3]]);
87///
88/// // Roll along columns by -1
89/// let rolled_cols = roll_2d(a.view(), 0, -1);
90/// assert_eq!(rolled_cols, array![[2, 3, 1], [5, 6, 4]]);
91/// ```
92#[allow(dead_code)]
93pub fn roll_2d<T>(
94 array: ArrayView<T, Ix2>,
95 shift_axis_0: isize,
96 shift_axis_1: isize,
97) -> Array<T, Ix2>
98where
99 T: Clone + Zero,
100{
101 let (rows, cols) = (array.shape()[0], array.shape()[1]);
102
103 // Handle case where no shifting is needed
104 if shift_axis_0 == 0 && shift_axis_1 == 0 {
105 return array.to_owned();
106 }
107
108 // Calculate effective shifts (handle negative shifts and wrap around)
109 let effective_shift_0 = if rows == 0 {
110 0
111 } else {
112 ((shift_axis_0 % rows as isize) + rows as isize) % rows as isize
113 };
114 let effective_shift_1 = if cols == 0 {
115 0
116 } else {
117 ((shift_axis_1 % cols as isize) + cols as isize) % cols as isize
118 };
119
120 let mut result = Array::<T, Ix2>::zeros((rows, cols));
121
122 for i in 0..rows {
123 for j in 0..cols {
124 // Calculate source indices with wrapping
125 let src_i = (i as isize + rows as isize - effective_shift_0) % rows as isize;
126 let src_j = (j as isize + cols as isize - effective_shift_1) % cols as isize;
127
128 result[[i, j]] = array[[src_i as usize, src_j as usize]].clone();
129 }
130 }
131
132 result
133}
134
135/// Repeat an array by tiling it in multiple dimensions
136///
137/// # Arguments
138///
139/// * `array` - The input 2D array
140/// * `reps_axis_0` - Number of times to repeat the array along axis 0
141/// * `reps_axis_1` - Number of times to repeat the array along axis 1
142///
143/// # Returns
144///
145/// A new array formed by repeating the input array
146///
147/// # Examples
148///
149/// ```
150/// use ndarray::array;
151/// use scirs2_core::ndarray_ext::manipulation::tile_2d;
152///
153/// let a = array![[1, 2], [3, 4]];
154///
155/// // Tile array to repeat it 2 times along axis 0 and 3 times along axis 1
156/// let tiled = tile_2d(a.view(), 2, 3);
157/// assert_eq!(tiled.shape(), &[4, 6]);
158/// assert_eq!(tiled,
159/// array![
160/// [1, 2, 1, 2, 1, 2],
161/// [3, 4, 3, 4, 3, 4],
162/// [1, 2, 1, 2, 1, 2],
163/// [3, 4, 3, 4, 3, 4]
164/// ]
165/// );
166/// ```
167#[allow(dead_code)]
168pub fn tile_2d<T>(array: ArrayView<T, Ix2>, reps_axis_0: usize, repsaxis_1: usize) -> Array<T, Ix2>
169where
170 T: Clone + Default + Zero,
171{
172 let (rows, cols) = (array.shape()[0], array.shape()[1]);
173
174 // New dimensions after tiling
175 let new_rows = rows * reps_axis_0;
176 let new_cols = cols * repsaxis_1;
177
178 // Edge case - zero repetitions
179 if reps_axis_0 == 0 || repsaxis_1 == 0 {
180 return Array::<T, Ix2>::default((0, 0));
181 }
182
183 // Edge case - one repetition
184 if reps_axis_0 == 1 && repsaxis_1 == 1 {
185 return array.to_owned();
186 }
187
188 let mut result = Array::<T, Ix2>::zeros((new_rows, new_cols));
189
190 // Fill the result with repeated copies of the array
191 for i in 0..new_rows {
192 for j in 0..new_cols {
193 let src_i = i % rows;
194 let src_j = j % cols;
195
196 result[[i, j]] = array[[src_i, src_j]].clone();
197 }
198 }
199
200 result
201}
202
203/// Repeat array elements by duplicating values
204///
205/// # Arguments
206///
207/// * `array` - The input 2D array
208/// * `repeats_axis_0` - Number of times to repeat each element along axis 0
209/// * `repeats_axis_1` - Number of times to repeat each element along axis 1
210///
211/// # Returns
212///
213/// A new array with elements repeated as specified
214///
215/// # Examples
216///
217/// ```
218/// use ndarray::array;
219/// use scirs2_core::ndarray_ext::manipulation::repeat_2d;
220///
221/// let a = array![[1, 2], [3, 4]];
222///
223/// // Repeat array elements 2 times along axis 0 and 3 times along axis 1
224/// let repeated = repeat_2d(a.view(), 2, 3);
225/// assert_eq!(repeated.shape(), &[4, 6]);
226/// assert_eq!(repeated,
227/// array![
228/// [1, 1, 1, 2, 2, 2],
229/// [1, 1, 1, 2, 2, 2],
230/// [3, 3, 3, 4, 4, 4],
231/// [3, 3, 3, 4, 4, 4]
232/// ]
233/// );
234/// ```
235#[allow(dead_code)]
236pub fn repeat_2d<T>(
237 array: ArrayView<T, Ix2>,
238 repeats_axis_0: usize,
239 repeats_axis_1: usize,
240) -> Array<T, Ix2>
241where
242 T: Clone + Default + Zero,
243{
244 let (rows, cols) = (array.shape()[0], array.shape()[1]);
245
246 // New dimensions after repeating
247 let new_rows = rows * repeats_axis_0;
248 let new_cols = cols * repeats_axis_1;
249
250 // Edge case - zero repetitions
251 if repeats_axis_0 == 0 || repeats_axis_1 == 0 {
252 return Array::<T, Ix2>::default((0, 0));
253 }
254
255 // Edge case - one repetition
256 if repeats_axis_0 == 1 && repeats_axis_1 == 1 {
257 return array.to_owned();
258 }
259
260 let mut result = Array::<T, Ix2>::zeros((new_rows, new_cols));
261
262 // Fill the result with repeated elements
263 for i in 0..rows {
264 for j in 0..cols {
265 for i_rep in 0..repeats_axis_0 {
266 for j_rep in 0..repeats_axis_1 {
267 let dest_i = i * repeats_axis_0 + i_rep;
268 let dest_j = j * repeats_axis_1 + j_rep;
269
270 result[[dest_i, dest_j]] = array[[i, j]].clone();
271 }
272 }
273 }
274 }
275
276 result
277}
278
279/// Swap rows or columns in a 2D array
280///
281/// # Arguments
282///
283/// * `array` - The input 2D array
284/// * `index1` - First index to swap
285/// * `index2` - Second index to swap
286/// * `axis` - Axis along which to swap (0 for rows, 1 for columns)
287///
288/// # Returns
289///
290/// A new array with specified rows or columns swapped
291///
292/// # Examples
293///
294/// ```
295/// use ndarray::array;
296/// use scirs2_core::ndarray_ext::manipulation::swap_axes_2d;
297///
298/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
299///
300/// // Swap rows 0 and 2
301/// let swapped_rows = swap_axes_2d(a.view(), 0, 2, 0).unwrap();
302/// assert_eq!(swapped_rows, array![[7, 8, 9], [4, 5, 6], [1, 2, 3]]);
303///
304/// // Swap columns 0 and 1
305/// let swapped_cols = swap_axes_2d(a.view(), 0, 1, 1).unwrap();
306/// assert_eq!(swapped_cols, array![[2, 1, 3], [5, 4, 6], [8, 7, 9]]);
307/// ```
308#[allow(dead_code)]
309pub fn swap_axes_2d<T>(
310 array: ArrayView<T, Ix2>,
311 index1: usize,
312 index2: usize,
313 axis: usize,
314) -> Result<Array<T, Ix2>, &'static str>
315where
316 T: Clone,
317{
318 let (rows, cols) = (array.shape()[0], array.shape()[1]);
319
320 if axis > 1 {
321 return Err("Axis must be 0 or 1 for 2D arrays");
322 }
323
324 // Check indices are in bounds
325 let axis_len = if axis == 0 { rows } else { cols };
326 if index1 >= axis_len || index2 >= axis_len {
327 return Err("Indices out of bounds");
328 }
329
330 // If indices are the same, just clone the array
331 if index1 == index2 {
332 return Ok(array.to_owned());
333 }
334
335 let mut result = array.to_owned();
336
337 match axis {
338 0 => {
339 // Swap rows
340 for j in 0..cols {
341 let temp = result[[index1, j]].clone();
342 result[[index1, j]] = result[[index2, j]].clone();
343 result[[index2, j]] = temp;
344 }
345 }
346 1 => {
347 // Swap columns
348 for i in 0..rows {
349 let temp = result[[i, index1]].clone();
350 result[[i, index1]] = result[[i, index2]].clone();
351 result[[i, index2]] = temp;
352 }
353 }
354 _ => unreachable!(),
355 }
356
357 Ok(result)
358}
359
360/// Pad a 2D array with a constant value
361///
362/// # Arguments
363///
364/// * `array` - The input 2D array
365/// * `pad_width` - A tuple of tuples specifying the number of values padded
366/// to the edges of each axis: ((before_axis_0, after_axis_0), (before_axis_1, after_axis_1))
367/// * `pad_value` - The value to set the padded elements
368///
369/// # Returns
370///
371/// A new array with padded borders
372///
373/// # Examples
374///
375/// ```
376/// use ndarray::array;
377/// use scirs2_core::ndarray_ext::manipulation::pad_2d;
378///
379/// let a = array![[1, 2], [3, 4]];
380///
381/// // Pad with 1 row before, 2 rows after, 1 column before, and 0 columns after
382/// let padded = pad_2d(a.view(), ((1, 2), (1, 0)), 0);
383/// assert_eq!(padded.shape(), &[5, 3]);
384/// assert_eq!(padded,
385/// array![
386/// [0, 0, 0],
387/// [0, 1, 2],
388/// [0, 3, 4],
389/// [0, 0, 0],
390/// [0, 0, 0]
391/// ]
392/// );
393/// ```
394#[allow(dead_code)]
395pub fn pad_2d<T>(
396 array: ArrayView<T, Ix2>,
397 pad_width: ((usize, usize), (usize, usize)),
398 pad_value: T,
399) -> Array<T, Ix2>
400where
401 T: Clone,
402{
403 let (rows, cols) = (array.shape()[0], array.shape()[1]);
404 let ((before_0, after_0), (before_1, after_1)) = pad_width;
405
406 // Calculate new dimensions
407 let new_rows = rows + before_0 + after_0;
408 let new_cols = cols + before_1 + after_1;
409
410 // Create the result array filled with the padding value
411 let mut result = Array::<T, Ix2>::from_elem((new_rows, new_cols), pad_value);
412
413 // Copy the original array into the padded array
414 for i in 0..rows {
415 for j in 0..cols {
416 result[[i + before_0, j + before_1]] = array[[i, j]].clone();
417 }
418 }
419
420 result
421}
422
423/// Concatenate 2D arrays along a specified axis
424///
425/// # Arguments
426///
427/// * `arrays` - A slice of 2D arrays to concatenate
428/// * `axis` - The axis along which to concatenate (0 for rows, 1 for columns)
429///
430/// # Returns
431///
432/// A new array containing the concatenated arrays
433///
434/// # Examples
435///
436/// ```
437/// use ndarray::array;
438/// use scirs2_core::ndarray_ext::manipulation::concatenate_2d;
439///
440/// let a = array![[1, 2], [3, 4]];
441/// let b = array![[5, 6], [7, 8]];
442///
443/// // Concatenate along rows (vertically)
444/// let vertical = concatenate_2d(&[a.view(), b.view()], 0).unwrap();
445/// assert_eq!(vertical.shape(), &[4, 2]);
446/// assert_eq!(vertical, array![[1, 2], [3, 4], [5, 6], [7, 8]]);
447///
448/// // Concatenate along columns (horizontally)
449/// let horizontal = concatenate_2d(&[a.view(), b.view()], 1).unwrap();
450/// assert_eq!(horizontal.shape(), &[2, 4]);
451/// assert_eq!(horizontal, array![[1, 2, 5, 6], [3, 4, 7, 8]]);
452/// ```
453#[allow(dead_code)]
454pub fn concatenate_2d<T>(
455 arrays: &[ArrayView<T, Ix2>],
456 axis: usize,
457) -> Result<Array<T, Ix2>, &'static str>
458where
459 T: Clone + Zero,
460{
461 if arrays.is_empty() {
462 return Err("No arrays provided for concatenation");
463 }
464
465 if axis > 1 {
466 return Err("Axis must be 0 or 1 for 2D arrays");
467 }
468
469 // Get the shape of the first array as a reference
470 let firstshape = arrays[0].shape();
471
472 // Calculate the total shape after concatenation
473 let mut totalshape = [firstshape[0], firstshape[1]];
474 for array in arrays.iter().skip(1) {
475 let currentshape = array.shape();
476
477 // Ensure all arrays have compatible shapes
478 if axis == 0 && currentshape[1] != firstshape[1] {
479 return Err("All arrays must have the same number of columns for axis=0 concatenation");
480 } else if axis == 1 && currentshape[0] != firstshape[0] {
481 return Err("All arrays must have the same number of rows for axis=1 concatenation");
482 }
483
484 totalshape[axis] += currentshape[axis];
485 }
486
487 // Create the result array
488 let mut result = Array::<T, Ix2>::zeros((totalshape[0], totalshape[1]));
489
490 // Fill the result array with data from the input arrays
491 match axis {
492 0 => {
493 // Concatenate along axis 0 (vertically)
494 let mut row_offset = 0;
495 for array in arrays {
496 let rows = array.shape()[0];
497 let cols = array.shape()[1];
498
499 for i in 0..rows {
500 for j in 0..cols {
501 result[[row_offset + i, j]] = array[[i, j]].clone();
502 }
503 }
504
505 row_offset += rows;
506 }
507 }
508 1 => {
509 // Concatenate along axis 1 (horizontally)
510 let mut col_offset = 0;
511 for array in arrays {
512 let rows = array.shape()[0];
513 let cols = array.shape()[1];
514
515 for i in 0..rows {
516 for j in 0..cols {
517 result[[i, col_offset + j]] = array[[i, j]].clone();
518 }
519 }
520
521 col_offset += cols;
522 }
523 }
524 _ => unreachable!(),
525 }
526
527 Ok(result)
528}
529
530/// Stack a sequence of 1D arrays into a 2D array
531///
532/// # Arguments
533///
534/// * `arrays` - A slice of 1D arrays to stack
535///
536/// # Returns
537///
538/// A 2D array where each row contains an input array
539///
540/// # Examples
541///
542/// ```
543/// use ndarray::array;
544/// use scirs2_core::ndarray_ext::manipulation::vstack_1d;
545///
546/// let a = array![1, 2, 3];
547/// let b = array![4, 5, 6];
548/// let c = array![7, 8, 9];
549///
550/// let stacked = vstack_1d(&[a.view(), b.view(), c.view()]).unwrap();
551/// assert_eq!(stacked.shape(), &[3, 3]);
552/// assert_eq!(stacked, array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
553/// ```
554#[allow(dead_code)]
555pub fn vstack_1d<T>(arrays: &[ArrayView<T, Ix1>]) -> Result<Array<T, Ix2>, &'static str>
556where
557 T: Clone + Zero,
558{
559 if arrays.is_empty() {
560 return Err("No arrays provided for stacking");
561 }
562
563 // All arrays must have the same length
564 let expected_len = arrays[0].len();
565 for (_i, array) in arrays.iter().enumerate().skip(1) {
566 if array.len() != expected_len {
567 return Err("Arrays must have consistent lengths for stacking");
568 }
569 }
570
571 // Create the result array
572 let rows = arrays.len();
573 let cols = expected_len;
574 let mut result = Array::<T, Ix2>::zeros((rows, cols));
575
576 // Fill the result array
577 for (i, array) in arrays.iter().enumerate() {
578 for (j, val) in array.iter().enumerate() {
579 result[[i, j]] = val.clone();
580 }
581 }
582
583 Ok(result)
584}
585
586/// Stack a sequence of 1D arrays horizontally (as columns) into a 2D array
587///
588/// # Arguments
589///
590/// * `arrays` - A slice of 1D arrays to stack
591///
592/// # Returns
593///
594/// A 2D array where each column contains an input array
595///
596/// # Examples
597///
598/// ```
599/// use ndarray::array;
600/// use scirs2_core::ndarray_ext::manipulation::hstack_1d;
601///
602/// let a = array![1, 2, 3];
603/// let b = array![4, 5, 6];
604///
605/// let stacked = hstack_1d(&[a.view(), b.view()]).unwrap();
606/// assert_eq!(stacked.shape(), &[3, 2]);
607/// assert_eq!(stacked, array![[1, 4], [2, 5], [3, 6]]);
608/// ```
609#[allow(dead_code)]
610pub fn hstack_1d<T>(arrays: &[ArrayView<T, Ix1>]) -> Result<Array<T, Ix2>, &'static str>
611where
612 T: Clone + Zero,
613{
614 if arrays.is_empty() {
615 return Err("No arrays provided for stacking");
616 }
617
618 // All arrays must have the same length
619 let expected_len = arrays[0].len();
620 for (_i, array) in arrays.iter().enumerate().skip(1) {
621 if array.len() != expected_len {
622 return Err("Arrays must have consistent lengths for stacking");
623 }
624 }
625
626 // Create the result array
627 let rows = expected_len;
628 let cols = arrays.len();
629 let mut result = Array::<T, Ix2>::zeros((rows, cols));
630
631 // Fill the result array
632 for (j, array) in arrays.iter().enumerate() {
633 for (i, val) in array.iter().enumerate() {
634 result[[i, j]] = val.clone();
635 }
636 }
637
638 Ok(result)
639}
640
641/// Remove a dimension of size 1 from a 2D array, resulting in a 1D array
642///
643/// # Arguments
644///
645/// * `array` - The input 2D array
646/// * `axis` - The axis to squeeze (0 for rows, 1 for columns)
647///
648/// # Returns
649///
650/// A 1D array with the specified dimension removed
651///
652/// # Examples
653///
654/// ```
655/// use ndarray::array;
656/// use scirs2_core::ndarray_ext::manipulation::squeeze_2d;
657///
658/// let a = array![[1, 2, 3]]; // 1x3 array (1 row, 3 columns)
659/// let b = array![[1], [2], [3]]; // 3x1 array (3 rows, 1 column)
660///
661/// // Squeeze out the row dimension (axis 0) from a
662/// let squeezed_a = squeeze_2d(a.view(), 0).unwrap();
663/// assert_eq!(squeezed_a.shape(), &[3]);
664/// assert_eq!(squeezed_a, array![1, 2, 3]);
665///
666/// // Squeeze out the column dimension (axis 1) from b
667/// let squeezed_b = squeeze_2d(b.view(), 1).unwrap();
668/// assert_eq!(squeezed_b.shape(), &[3]);
669/// assert_eq!(squeezed_b, array![1, 2, 3]);
670/// ```
671#[allow(dead_code)]
672pub fn squeeze_2d<T>(array: ArrayView<T, Ix2>, axis: usize) -> Result<Array<T, Ix1>, &'static str>
673where
674 T: Clone + Zero,
675{
676 let (rows, cols) = (array.shape()[0], array.shape()[1]);
677
678 match axis {
679 0 => {
680 // Squeeze out row dimension
681 if rows != 1 {
682 return Err("Cannot squeeze array with more than 1 row along axis 0");
683 }
684
685 let mut result = Array::<T, Ix1>::zeros(cols);
686 for j in 0..cols {
687 result[j] = array[[0, j]].clone();
688 }
689
690 Ok(result)
691 }
692 1 => {
693 // Squeeze out column dimension
694 if cols != 1 {
695 return Err("Cannot squeeze array with more than 1 column along axis 1");
696 }
697
698 let mut result = Array::<T, Ix1>::zeros(rows);
699 for i in 0..rows {
700 result[i] = array[[i, 0]].clone();
701 }
702
703 Ok(result)
704 }
705 _ => Err("Axis must be 0 or 1 for 2D arrays"),
706 }
707}
708
709/// Create a meshgrid from 1D coordinate arrays
710///
711/// # Arguments
712///
713/// * `x` - 1D array of x coordinates
714/// * `y` - 1D array of y coordinates
715///
716/// # Returns
717///
718/// A tuple of two 2D arrays (X, Y) where X and Y are copies of the input arrays
719/// repeated to form a meshgrid
720///
721/// # Examples
722///
723/// ```
724/// use ndarray::array;
725/// use scirs2_core::ndarray_ext::manipulation::meshgrid;
726///
727/// let x = array![1, 2, 3];
728/// let y = array![4, 5];
729/// let (x_grid, y_grid) = meshgrid(x.view(), y.view()).unwrap();
730/// assert_eq!(x_grid.shape(), &[2, 3]);
731/// assert_eq!(y_grid.shape(), &[2, 3]);
732/// assert_eq!(x_grid, array![[1, 2, 3], [1, 2, 3]]);
733/// assert_eq!(y_grid, array![[4, 4, 4], [5, 5, 5]]);
734/// ```
735#[allow(dead_code)]
736pub fn meshgrid<T>(x: ArrayView<T, Ix1>, y: ArrayView<T, Ix1>) -> GradientResult<T>
737where
738 T: Clone + Zero,
739{
740 let nx = x.len();
741 let ny = y.len();
742
743 if nx == 0 || ny == 0 {
744 return Err("Input arrays must not be empty");
745 }
746
747 // Create output arrays
748 let mut x_grid = Array::<T, Ix2>::zeros((ny, nx));
749 let mut y_grid = Array::<T, Ix2>::zeros((ny, nx));
750
751 // Fill the meshgrid
752 for i in 0..ny {
753 for j in 0..nx {
754 x_grid[[i, j]] = x[j].clone();
755 y_grid[[i, j]] = y[i].clone();
756 }
757 }
758
759 Ok((x_grid, y_grid))
760}
761
762/// Find unique elements in an array
763///
764/// # Arguments
765///
766/// * `array` - The input 1D array
767///
768/// # Returns
769///
770/// A 1D array containing the unique elements of the input array, sorted
771///
772/// # Examples
773///
774/// ```
775/// use ndarray::array;
776/// use scirs2_core::ndarray_ext::manipulation::unique;
777///
778/// let a = array![3, 1, 2, 2, 3, 4, 1];
779/// let result = unique(a.view()).unwrap();
780/// assert_eq!(result, array![1, 2, 3, 4]);
781/// ```
782#[allow(dead_code)]
783pub fn unique<T>(array: ArrayView<T, Ix1>) -> Result<Array<T, Ix1>, &'static str>
784where
785 T: Clone + Ord,
786{
787 if array.is_empty() {
788 return Err("Input array must not be empty");
789 }
790
791 // Clone elements to a Vec and sort
792 let mut values: Vec<T> = array.iter().cloned().collect();
793 values.sort();
794
795 // Remove duplicates
796 values.dedup();
797
798 // Convert to ndarray
799 Ok(Array::from_vec(values))
800}
801
802/// Return the indices of the minimum values along the specified axis
803///
804/// # Arguments
805///
806/// * `array` - The input 2D array
807/// * `axis` - The axis along which to find the minimum values (0 for rows, 1 for columns, None for flattened array)
808///
809/// # Returns
810///
811/// A 1D array containing the indices of the minimum values
812///
813/// # Examples
814///
815/// ```
816/// use ndarray::array;
817/// use scirs2_core::ndarray_ext::manipulation::argmin;
818///
819/// let a = array![[5, 2, 3], [4, 1, 6]];
820///
821/// // Find indices of minimum values along axis 0 (columns)
822/// let result = argmin(a.view(), Some(0)).unwrap();
823/// assert_eq!(result, array![1, 1, 0]); // The indices of min values in each column
824///
825/// // Find indices of minimum values along axis 1 (rows)
826/// let result = argmin(a.view(), Some(1)).unwrap();
827/// assert_eq!(result, array![1, 1]); // The indices of min values in each row
828///
829/// // Find index of minimum value in flattened array
830/// let result = argmin(a.view(), None).unwrap();
831/// assert_eq!(result[0], 4); // The index of the minimum value in the flattened array (row 1, col 1)
832/// ```
833#[allow(dead_code)]
834pub fn argmin<T>(
835 array: ArrayView<T, Ix2>,
836 axis: Option<usize>,
837) -> Result<Array<usize, Ix1>, &'static str>
838where
839 T: Clone + PartialOrd,
840{
841 let (rows, cols) = (array.shape()[0], array.shape()[1]);
842
843 if rows == 0 || cols == 0 {
844 return Err("Input array must not be empty");
845 }
846
847 match axis {
848 Some(0) => {
849 // Find min indices along axis 0 (for each column)
850 let mut indices = Array::<usize, Ix1>::zeros(cols);
851
852 for j in 0..cols {
853 let mut min_idx = 0;
854 let mut min_val = &array[[0, j]];
855
856 for i in 1..rows {
857 if &array[[i, j]] < min_val {
858 min_idx = i;
859 min_val = &array[[i, j]];
860 }
861 }
862
863 indices[j] = min_idx;
864 }
865
866 Ok(indices)
867 }
868 Some(1) => {
869 // Find min indices along axis 1 (for each row)
870 let mut indices = Array::<usize, Ix1>::zeros(rows);
871
872 for i in 0..rows {
873 let mut min_idx = 0;
874 let mut min_val = &array[[i, 0]];
875
876 for j in 1..cols {
877 if &array[[i, j]] < min_val {
878 min_idx = j;
879 min_val = &array[[i, j]];
880 }
881 }
882
883 indices[i] = min_idx;
884 }
885
886 Ok(indices)
887 }
888 Some(_) => Err("Axis must be 0 or 1 for 2D arrays"),
889 None => {
890 // Find min index in flattened array
891 let mut min_idx = 0;
892 let mut min_val = &array[[0, 0]];
893
894 for i in 0..rows {
895 for j in 0..cols {
896 if &array[[i, j]] < min_val {
897 min_idx = i * cols + j;
898 min_val = &array[[i, j]];
899 }
900 }
901 }
902
903 Ok(Array::from_vec(vec![min_idx]))
904 }
905 }
906}
907
908/// Return the indices of the maximum values along the specified axis
909///
910/// # Arguments
911///
912/// * `array` - The input 2D array
913/// * `axis` - The axis along which to find the maximum values (0 for rows, 1 for columns, None for flattened array)
914///
915/// # Returns
916///
917/// A 1D array containing the indices of the maximum values
918///
919/// # Examples
920///
921/// ```
922/// use ndarray::array;
923/// use scirs2_core::ndarray_ext::manipulation::argmax;
924///
925/// let a = array![[5, 2, 3], [4, 1, 6]];
926///
927/// // Find indices of maximum values along axis 0 (columns)
928/// let result = argmax(a.view(), Some(0)).unwrap();
929/// assert_eq!(result, array![0, 0, 1]); // The indices of max values in each column
930///
931/// // Find indices of maximum values along axis 1 (rows)
932/// let result = argmax(a.view(), Some(1)).unwrap();
933/// assert_eq!(result, array![0, 2]); // The indices of max values in each row
934///
935/// // Find index of maximum value in flattened array
936/// let result = argmax(a.view(), None).unwrap();
937/// assert_eq!(result[0], 5); // The index of the maximum value in the flattened array (row 1, col 2)
938/// ```
939#[allow(dead_code)]
940pub fn argmax<T>(
941 array: ArrayView<T, Ix2>,
942 axis: Option<usize>,
943) -> Result<Array<usize, Ix1>, &'static str>
944where
945 T: Clone + PartialOrd,
946{
947 let (rows, cols) = (array.shape()[0], array.shape()[1]);
948
949 if rows == 0 || cols == 0 {
950 return Err("Input array must not be empty");
951 }
952
953 match axis {
954 Some(0) => {
955 // Find max indices along axis 0 (for each column)
956 let mut indices = Array::<usize, Ix1>::zeros(cols);
957
958 for j in 0..cols {
959 let mut max_idx = 0;
960 let mut max_val = &array[[0, j]];
961
962 for i in 1..rows {
963 if &array[[i, j]] > max_val {
964 max_idx = i;
965 max_val = &array[[i, j]];
966 }
967 }
968
969 indices[j] = max_idx;
970 }
971
972 Ok(indices)
973 }
974 Some(1) => {
975 // Find max indices along axis 1 (for each row)
976 let mut indices = Array::<usize, Ix1>::zeros(rows);
977
978 for i in 0..rows {
979 let mut max_idx = 0;
980 let mut max_val = &array[[i, 0]];
981
982 for j in 1..cols {
983 if &array[[i, j]] > max_val {
984 max_idx = j;
985 max_val = &array[[i, j]];
986 }
987 }
988
989 indices[i] = max_idx;
990 }
991
992 Ok(indices)
993 }
994 Some(_) => Err("Axis must be 0 or 1 for 2D arrays"),
995 None => {
996 // Find max index in flattened array
997 let mut max_idx = 0;
998 let mut max_val = &array[[0, 0]];
999
1000 for i in 0..rows {
1001 for j in 0..cols {
1002 if &array[[i, j]] > max_val {
1003 max_idx = i * cols + j;
1004 max_val = &array[[i, j]];
1005 }
1006 }
1007 }
1008
1009 Ok(Array::from_vec(vec![max_idx]))
1010 }
1011 }
1012}
1013
1014/// Calculate the gradient of an array
1015///
1016/// # Arguments
1017///
1018/// * `array` - The input 2D array
1019/// * `spacing` - Optional tuple of spacings for each axis
1020///
1021/// # Returns
1022///
1023/// A tuple of arrays (grad_y, grad_x) containing the gradient along each axis
1024///
1025/// # Examples
1026///
1027/// ```
1028/// use ndarray::array;
1029/// use scirs2_core::ndarray_ext::manipulation::gradient;
1030///
1031/// let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
1032///
1033/// // Calculate gradient with default spacing
1034/// let (grad_y, grad_x) = gradient(a.view(), None).unwrap();
1035/// // Vertical gradient (y-direction)
1036/// assert_eq!(grad_y.shape(), &[2, 3]);
1037/// // Horizontal gradient (x-direction)
1038/// assert_eq!(grad_x.shape(), &[2, 3]);
1039/// ```
1040#[allow(dead_code)]
1041pub fn gradient<T>(array: ArrayView<T, Ix2>, spacing: Option<(T, T)>) -> GradientResult<T>
1042where
1043 T: Clone + num_traits::Float,
1044{
1045 let (rows, cols) = (array.shape()[0], array.shape()[1]);
1046
1047 if rows == 0 || cols == 0 {
1048 return Err("Input array must not be empty");
1049 }
1050
1051 // Get spacing values (default to 1.0)
1052 let (dy, dx) = spacing.unwrap_or((T::one(), T::one()));
1053
1054 // Create output arrays for gradients
1055 let mut grad_y = Array::<T, Ix2>::zeros((rows, cols));
1056 let mut grad_x = Array::<T, Ix2>::zeros((rows, cols));
1057
1058 // Calculate gradient along y axis (rows)
1059 if rows == 1 {
1060 // Single row, gradient is zero
1061 // (already initialized with zeros)
1062 } else {
1063 // First row: forward difference
1064 for j in 0..cols {
1065 grad_y[[0, j]] = (array[[1, j]] - array[[0, j]]) / dy;
1066 }
1067
1068 // Middle rows: central difference
1069 for i in 1..rows - 1 {
1070 for j in 0..cols {
1071 grad_y[[i, j]] = (array[[i + 1, j]] - array[[i.saturating_sub(1), j]]) / (dy + dy);
1072 }
1073 }
1074
1075 // Last row: backward difference
1076 for j in 0..cols {
1077 grad_y[[rows - 1, j]] = (array[[rows - 1, j]] - array[[rows - 2, j]]) / dy;
1078 }
1079 }
1080
1081 // Calculate gradient along x axis (columns)
1082 if cols == 1 {
1083 // Single column, gradient is zero
1084 // (already initialized with zeros)
1085 } else {
1086 for i in 0..rows {
1087 // First column: forward difference
1088 grad_x[[i, 0]] = (array[[i, 1]] - array[[i, 0]]) / dx;
1089
1090 // Middle columns: central difference
1091 for j in 1..cols - 1 {
1092 grad_x[[i, j]] = (array[[i, j + 1]] - array[[i, j.saturating_sub(1)]]) / (dx + dx);
1093 }
1094
1095 // Last column: backward difference
1096 grad_x[[i, cols - 1]] = (array[[i, cols - 1]] - array[[i, cols - 2]]) / dx;
1097 }
1098 }
1099
1100 Ok((grad_y, grad_x))
1101}
1102
#[cfg(test)]
mod tests {
    // Unit tests for the 2D array manipulation helpers defined in this module.

    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// `flip_2d` reverses rows, columns, or both depending on its two flags.
    #[test]
    fn test_flip_2d() {
        let source = array![[1, 2, 3], [4, 5, 6]];

        // (flip rows?, flip columns?, expected output), checked in order.
        let cases = [
            (true, false, array![[4, 5, 6], [1, 2, 3]]),
            (false, true, array![[3, 2, 1], [6, 5, 4]]),
            (true, true, array![[6, 5, 4], [3, 2, 1]]),
        ];

        for (flip_rows, flip_cols, expected) in cases.iter() {
            let actual = flip_2d(source.view(), *flip_rows, *flip_cols);
            assert_eq!(&actual, expected);
        }
    }

    /// `roll_2d` cyclically shifts elements; negative shifts move backwards
    /// and a zero shift reproduces the input.
    #[test]
    fn test_roll_2d() {
        let source = array![[1, 2, 3], [4, 5, 6]];

        // (row shift, column shift, expected output), checked in order.
        let cases = [
            (1, 0, array![[4, 5, 6], [1, 2, 3]]),
            (0, 1, array![[3, 1, 2], [6, 4, 5]]),
            (0, -1, array![[2, 3, 1], [5, 6, 4]]),
            (0, 0, source.clone()),
        ];

        for (row_shift, col_shift, expected) in cases.iter() {
            let actual = roll_2d(source.view(), *row_shift, *col_shift);
            assert_eq!(&actual, expected);
        }
    }

    /// `tile_2d` lays whole copies of the input side by side.
    #[test]
    fn test_tile_2d() {
        let block = array![[1, 2], [3, 4]];

        // Two copies down, three copies across.
        let grid = tile_2d(block.view(), 2, 3);
        assert_eq!(grid.shape(), &[4, 6]);
        assert_eq!(
            grid,
            array![
                [1, 2, 1, 2, 1, 2],
                [3, 4, 3, 4, 3, 4],
                [1, 2, 1, 2, 1, 2],
                [3, 4, 3, 4, 3, 4]
            ]
        );

        // Copies stacked along axis 0 only.
        let column = tile_2d(block.view(), 2, 1);
        assert_eq!(column.shape(), &[4, 2]);
        assert_eq!(column, array![[1, 2], [3, 4], [1, 2], [3, 4]]);

        // A 1x1 input expands into a uniformly filled 2x2 array.
        let lone = array![[5]];
        let filled = tile_2d(lone.view(), 2, 2);
        assert_eq!(filled.shape(), &[2, 2]);
        assert_eq!(filled, array![[5, 5], [5, 5]]);
    }

    /// `repeat_2d` duplicates each individual element in place.
    #[test]
    fn test_repeat_2d() {
        let block = array![[1, 2], [3, 4]];

        // Each element is repeated twice down and three times across.
        let expanded = repeat_2d(block.view(), 2, 3);
        assert_eq!(expanded.shape(), &[4, 6]);
        assert_eq!(
            expanded,
            array![
                [1, 1, 1, 2, 2, 2],
                [1, 1, 1, 2, 2, 2],
                [3, 3, 3, 4, 4, 4],
                [3, 3, 3, 4, 4, 4]
            ]
        );

        // Repetition along axis 1 only.
        let widened = repeat_2d(block.view(), 1, 2);
        assert_eq!(widened.shape(), &[2, 4]);
        assert_eq!(widened, array![[1, 1, 2, 2], [3, 3, 4, 4]]);
    }

    /// `swap_axes_2d` exchanges two rows (axis 0) or two columns (axis 1),
    /// and rejects a bad axis or an out-of-range index.
    #[test]
    fn test_swap_axes_2d() {
        let grid = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        // First and last rows exchanged.
        assert_eq!(
            swap_axes_2d(grid.view(), 0, 2, 0).unwrap(),
            array![[7, 8, 9], [4, 5, 6], [1, 2, 3]]
        );

        // First and last columns exchanged.
        assert_eq!(
            swap_axes_2d(grid.view(), 0, 2, 1).unwrap(),
            array![[3, 2, 1], [6, 5, 4], [9, 8, 7]]
        );

        // Swapping an index with itself leaves the data unchanged.
        assert_eq!(swap_axes_2d(grid.view(), 1, 1, 0).unwrap(), grid);

        // Axis 2 does not exist on a 2D array.
        assert!(swap_axes_2d(grid.view(), 0, 1, 2).is_err());

        // Index 3 is past the end of a 3-row array.
        assert!(swap_axes_2d(grid.view(), 0, 3, 0).is_err());
    }

    /// `pad_2d` surrounds the input with a constant fill value.
    #[test]
    fn test_pad_2d() {
        let core = array![[1, 2], [3, 4]];

        // One layer of zeros on every side.
        let framed = pad_2d(core.view(), ((1, 1), (1, 1)), 0);
        assert_eq!(framed.shape(), &[4, 4]);
        assert_eq!(
            framed,
            array![[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]
        );

        // Two rows of 9s above and one column of 9s on the right.
        let lopsided = pad_2d(core.view(), ((2, 0), (0, 1)), 9);
        assert_eq!(lopsided.shape(), &[4, 3]);
        assert_eq!(
            lopsided,
            array![[9, 9, 9], [9, 9, 9], [1, 2, 9], [3, 4, 9]]
        );
    }

    /// `concatenate_2d` joins arrays vertically (axis 0) or horizontally
    /// (axis 1) and reports mismatched shapes, empty input and bad axes.
    #[test]
    fn test_concatenate_2d() {
        let top = array![[1, 2], [3, 4]];
        let bottom = array![[5, 6], [7, 8]];

        let tall = concatenate_2d(&[top.view(), bottom.view()], 0).unwrap();
        assert_eq!(tall.shape(), &[4, 2]);
        assert_eq!(tall, array![[1, 2], [3, 4], [5, 6], [7, 8]]);

        let wide = concatenate_2d(&[top.view(), bottom.view()], 1).unwrap();
        assert_eq!(wide.shape(), &[2, 4]);
        assert_eq!(wide, array![[1, 2, 5, 6], [3, 4, 7, 8]]);

        // Column counts 2 and 3 cannot be stacked vertically.
        let misfit = array![[9, 10, 11]];
        assert!(concatenate_2d(&[top.view(), misfit.view()], 0).is_err());

        // Nothing to concatenate.
        let none: [ArrayView<i32, Ix2>; 0] = [];
        assert!(concatenate_2d(&none, 0).is_err());

        // Axis 2 does not exist on 2D input.
        assert!(concatenate_2d(&[top.view(), bottom.view()], 2).is_err());
    }

    /// `vstack_1d` turns each 1D input into one row of the result.
    #[test]
    fn test_vstack_1d() {
        let first = array![1, 2, 3];
        let second = array![4, 5, 6];
        let third = array![7, 8, 9];

        let rows = vstack_1d(&[first.view(), second.view(), third.view()]).unwrap();
        assert_eq!(rows.shape(), &[3, 3]);
        assert_eq!(rows, array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]);

        let none: [ArrayView<i32, Ix1>; 0] = [];
        assert!(vstack_1d(&none).is_err());

        // Lengths 3 and 2 disagree.
        let short = array![10, 11];
        assert!(vstack_1d(&[first.view(), short.view()]).is_err());
    }

    /// `hstack_1d` turns each 1D input into one column of the result.
    #[test]
    fn test_hstack_1d() {
        let first = array![1, 2, 3];
        let second = array![4, 5, 6];

        let columns = hstack_1d(&[first.view(), second.view()]).unwrap();
        assert_eq!(columns.shape(), &[3, 2]);
        assert_eq!(columns, array![[1, 4], [2, 5], [3, 6]]);

        let none: [ArrayView<i32, Ix1>; 0] = [];
        assert!(hstack_1d(&none).is_err());

        // Lengths 3 and 2 disagree.
        let short = array![7, 8];
        assert!(hstack_1d(&[first.view(), short.view()]).is_err());
    }

    /// `squeeze_2d` drops a length-1 axis and rejects longer or missing axes.
    #[test]
    fn test_squeeze_2d() {
        let row_vec = array![[1, 2, 3]]; // shape 1x3
        let col_vec = array![[1], [2], [3]]; // shape 3x1

        let from_row = squeeze_2d(row_vec.view(), 0).unwrap();
        assert_eq!(from_row.shape(), &[3]);
        assert_eq!(from_row, array![1, 2, 3]);

        let from_col = squeeze_2d(col_vec.view(), 1).unwrap();
        assert_eq!(from_col.shape(), &[3]);
        assert_eq!(from_col, array![1, 2, 3]);

        // Neither axis of a 2x2 array has length 1.
        let square = array![[1, 2], [3, 4]];
        for axis in 0..2 {
            assert!(squeeze_2d(square.view(), axis).is_err());
        }

        // Axis 2 does not exist.
        assert!(squeeze_2d(row_vec.view(), 2).is_err());
    }

    /// `meshgrid` broadcasts x across rows and y down columns.
    #[test]
    fn test_meshgrid() {
        let xs = array![1, 2, 3];
        let ys = array![4, 5];

        let (x_grid, y_grid) = meshgrid(xs.view(), ys.view()).unwrap();
        assert_eq!(x_grid.shape(), &[2, 3]);
        assert_eq!(y_grid.shape(), &[2, 3]);
        assert_eq!(x_grid, array![[1, 2, 3], [1, 2, 3]]);
        assert_eq!(y_grid, array![[4, 4, 4], [5, 5, 5]]);

        // An empty coordinate vector on either side is rejected.
        let nothing = array![];
        assert!(meshgrid(xs.view(), nothing.view()).is_err());
        assert!(meshgrid(nothing.view(), ys.view()).is_err());
    }

    /// `unique` returns the distinct values in ascending order.
    #[test]
    fn test_unique() {
        let values = array![3, 1, 2, 2, 3, 4, 1];
        assert_eq!(unique(values.view()).unwrap(), array![1, 2, 3, 4]);

        let nothing: Array<i32, Ix1> = array![];
        assert!(unique(nothing.view()).is_err());
    }

    /// `argmin` reports per-axis or flattened (row-major) minimum positions.
    #[test]
    fn test_argmin() {
        let data = array![[5, 2, 3], [4, 1, 6]];

        // Per-column minima.
        assert_eq!(argmin(data.view(), Some(0)).unwrap(), array![1, 1, 0]);

        // Per-row minima.
        assert_eq!(argmin(data.view(), Some(1)).unwrap(), array![1, 1]);

        // Value 1 sits at row 1, column 1, i.e. flat index 4.
        let flat = argmin(data.view(), None).unwrap();
        assert_eq!(flat[0], 4);

        assert!(argmin(data.view(), Some(2)).is_err());

        let nothing: Array<i32, Ix2> = Array::zeros((0, 0));
        assert!(argmin(nothing.view(), Some(0)).is_err());
    }

    /// `argmax` reports per-axis or flattened (row-major) maximum positions.
    #[test]
    fn test_argmax() {
        let data = array![[5, 2, 3], [4, 1, 6]];

        // Per-column maxima.
        assert_eq!(argmax(data.view(), Some(0)).unwrap(), array![0, 0, 1]);

        // Per-row maxima.
        assert_eq!(argmax(data.view(), Some(1)).unwrap(), array![0, 2]);

        // Value 6 sits at row 1, column 2, i.e. flat index 5.
        let flat = argmax(data.view(), None).unwrap();
        assert_eq!(flat[0], 5);

        assert!(argmax(data.view(), Some(2)).is_err());

        let nothing: Array<i32, Ix2> = Array::zeros((0, 0));
        assert!(argmax(nothing.view(), Some(0)).is_err());
    }

    /// `gradient` returns (vertical, horizontal) finite differences, scaled by
    /// the optional (row, column) spacing.
    #[test]
    fn test_gradient() {
        let field = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let (grad_y, grad_x) = gradient(field.view(), None).unwrap();
        assert_eq!(grad_y.shape(), &[2, 3]);
        assert_eq!(grad_x.shape(), &[2, 3]);

        // Rows differ by 3 everywhere and columns by 1, so both fields are flat.
        for row in 0..2 {
            for col in 0..3 {
                assert_abs_diff_eq!(grad_y[[row, col]], 3.0, epsilon = 1e-10);
            }
        }
        for row in 0..2 {
            for col in 0..3 {
                assert_abs_diff_eq!(grad_x[[row, col]], 1.0, epsilon = 1e-10);
            }
        }

        let (grad_y, grad_x) = gradient(field.view(), Some((2.0, 0.5))).unwrap();

        // Row spacing 2.0 halves the vertical slope: 3.0 / 2.0.
        for col in 0..3 {
            assert_abs_diff_eq!(grad_y[[0, col]], 1.5, epsilon = 1e-10);
        }

        // Column spacing 0.5 doubles the horizontal slope: 1.0 / 0.5.
        for col in 0..3 {
            assert_abs_diff_eq!(grad_x[[0, col]], 2.0, epsilon = 1e-10);
        }

        let nothing: Array<f32, Ix2> = Array::zeros((0, 0));
        assert!(gradient(nothing.view(), None).is_err());
    }
}