vmd_rs/utils/
array.rs

1use ndarray::{s, ArrayBase, ArrayView1, Axis, Data, Ix1, OwnedRepr};
2
3pub trait Flip<A, S> {
4    fn flip(&self) -> ArrayView1<A>;
5}
6
7impl<A, S> Flip<A, S> for ArrayBase<S, Ix1>
8where
9    A: Clone,
10    S: Data<Elem = A>,
11{
12    fn flip(&self) -> ArrayView1<A>
13    where
14        A: Clone,
15        S: Data<Elem = A>,
16    {
17        self.slice(s![0..; -1])
18    }
19}
20
21pub fn fftshift1d<A, S>(arr: ArrayBase<S, Ix1>) -> ArrayBase<OwnedRepr<A>, Ix1>
22where
23    A: Clone,
24    S: Data<Elem = A>,
25{
26    let half = if arr.len() % 2 == 0 {
27        arr.len() / 2
28    } else {
29        (arr.len() / 2) + 1
30    };
31    let first_half = arr.slice(s![..half]);
32    let second_half = arr.slice(s![half..]);
33    ndarray::concatenate![Axis(0), second_half, first_half]
34}
35
36pub fn ifftshift1d<A, S>(arr: ArrayBase<S, Ix1>) -> ArrayBase<OwnedRepr<A>, Ix1>
37where
38    A: Clone,
39    S: Data<Elem = A>,
40{
41    let even = arr.len() % 2 == 0;
42    if even {
43        return fftshift1d(arr);
44    }
45    let half = arr.len() / 2;
46    let first_half = arr.slice(s![..half]);
47    let second_half = arr.slice(s![half..]);
48    ndarray::concatenate![Axis(0), second_half, first_half]
49}
50
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, Array1};

    /// Flipping a borrowed view yields the elements back to front.
    #[test]
    fn test_flip_ok() {
        let input = arr1(&[1, 2, 3, 4, 5]);
        let input_view = input.view();

        assert_eq!(input_view.flip(), arr1(&[5, 4, 3, 2, 1]));
    }

    /// Even length: the two equal halves trade places.
    #[test]
    fn test_fftshift_even() {
        let output = fftshift1d(arr1(&[1, 2, 3, 4, 5, 6]));

        assert_eq!(output, arr1(&[4, 5, 6, 1, 2, 3]));
    }

    /// Match python's behaviour: for odd lengths the leading part keeps the
    /// extra element and ends up at the back.
    #[test]
    fn test_fftshift_odd() {
        let output = fftshift1d(arr1(&[1, 2, 3, 4, 5, 6, 7]));

        assert_eq!(output, arr1(&[5, 6, 7, 1, 2, 3, 4]));
    }

    /// Even length: the inverse shift restores the original ordering.
    #[test]
    fn test_ifftshift_even() {
        let output = ifftshift1d(arr1(&[4, 5, 6, 1, 2, 3]));

        assert_eq!(output, arr1(&[1, 2, 3, 4, 5, 6]));
    }

    /// Odd lengths: the inverse shift splits one element earlier than the
    /// forward shift does.
    #[test]
    fn test_ifftshift_odd() {
        let output = ifftshift1d(arr1(&[5, 6, 7, 1, 2, 3, 4]));
        assert_eq!(output, arr1(&[1, 2, 3, 4, 5, 6, 7]));

        let output = ifftshift1d(arr1(&[6, 7, 8, 9, 1, 2, 3, 4, 5]));
        assert_eq!(output, arr1(&[1, 2, 3, 4, 5, 6, 7, 8, 9]));
    }

    /// A single element has nothing to swap with and is returned unchanged.
    #[test]
    fn test_shift_single_element() {
        assert_eq!(fftshift1d(arr1(&[42])), arr1(&[42]));
        assert_eq!(ifftshift1d(arr1(&[42])), arr1(&[42]));
    }

    /// `ifftshift1d` undoes `fftshift1d` for both even and odd lengths.
    #[test]
    fn test_shift_roundtrip() {
        for len in 1..=9 {
            let original = Array1::from_iter(1..=len);
            let restored = ifftshift1d(fftshift1d(original.clone()));

            assert_eq!(restored, original, "round trip failed for length {len}");
        }
    }
}