Skip to main content

timsrust_utils/
ndarray.rs

1use std::ops::{AddAssign, Index, IndexMut};
2
3use crate::custom_error;
4
5custom_error!(pub NDArrayError);
6
/// A minimal N-dimensional array: a shape, one stride per axis, and a flat
/// `Vec<T>` holding all elements contiguously.
///
/// The number of axes `N` is part of the type, so indices are plain
/// `[usize; N]` arrays.
///
/// # Example
/// ```
/// use timsrust_utils::ndarray::NDArray;
/// let arr = NDArray::new([2, 3], vec![1, 2, 3, 4, 5, 6]).unwrap();
/// assert_eq!(arr[[1, 2]], 6);
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct NDArray<T, const N: usize> {
    /// Size of each axis.
    shape: [usize; N],
    /// Number of elements to step in `data` per unit increase of each axis index.
    strides: [usize; N],
    /// Flat element storage.
    data: Vec<T>,
}
21
22impl<T, const N: usize> NDArray<T, N> {
23    /// Creates a new NDArray with the given shape and data.
24    ///
25    /// # Arguments
26    ///
27    /// * `shape` - The shape of the array as an array of dimension sizes.
28    /// * `data` - The data to fill the array, must match the product of shape dimensions.
29    ///
30    /// # Errors
31    ///
32    /// Returns `TimsUtilsError` if the shape and data length are incompatible.
33    ///
34    /// # Example
35    /// ```
36    /// use timsrust_utils::ndarray::NDArray;
37    /// let arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
38    /// assert_eq!(arr.shape(), [2, 2]);
39    /// assert_eq!(arr[[1, 1]], 4);
40    /// ```
41    pub fn new(shape: [usize; N], data: Vec<T>) -> Result<Self, NDArrayError> {
42        if shape.iter().product::<usize>() != data.len() {
43            return Err(NDArrayError::new(format!(
44                "Incompatible shapes: {:?} and {:?}",
45                shape,
46                data.len()
47            )));
48        }
49        let mut strides = [0; N];
50        let mut stride = 1;
51        for (i, &dim) in shape.iter().rev().enumerate() {
52            strides[N - 1 - i] = stride;
53            stride *= dim;
54        }
55        let result = Self {
56            shape,
57            strides,
58            data,
59        };
60        Ok(result)
61    }
62
63    /// Returns the shape of the array.
64    ///
65    /// # Example
66    /// ```
67    /// use timsrust_utils::ndarray::NDArray;
68    /// let arr = NDArray::new([2, 3], vec![1, 2, 3, 4, 5, 6]).unwrap();
69    /// assert_eq!(arr.shape(), [2, 3]);
70    /// ```
71    pub fn shape(&self) -> [usize; N] {
72        self.shape
73    }
74
75    /// Computes the flat index in the data vector for the given N-dimensional indices.
76    ///
77    /// # Arguments
78    ///
79    /// * `indices` - The N-dimensional indices.
80    ///
81    /// # Example
82    /// ```
83    /// use timsrust_utils::ndarray::NDArray;
84    /// let arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
85    /// let idx = arr.index([1, 1]);
86    /// assert_eq!(idx, 3);
87    /// ```
88    pub fn index(&self, indices: [usize; N]) -> usize {
89        indices
90            .iter()
91            .zip(self.strides.iter())
92            .map(|(&idx, &stride)| idx * stride)
93            .sum()
94    }
95
96    /// Converts a flat index into N-dimensional indices.
97    ///
98    /// # Arguments
99    ///
100    /// * `idx` - The flat index.
101    ///
102    /// # Example
103    /// ```
104    /// use timsrust_utils::ndarray::NDArray;
105    /// let arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
106    /// let indices = arr.inverted_index(3);
107    /// assert_eq!(indices, [1, 1]);
108    /// ```
109    pub fn inverted_index(&self, mut idx: usize) -> [usize; N] {
110        let mut indices = [0; N];
111        // for d in 0..N {
112        for (d, index) in indices.iter_mut().enumerate().take(N) {
113            *index = idx / self.strides[d];
114            idx %= self.strides[d];
115        }
116        indices
117    }
118}
119
120impl<T: Default + Copy + AddAssign, const N: usize> NDArray<T, N> {
121    /// Projects the array along the specified axis, summing over all other axes.
122    ///
123    /// # Arguments
124    ///
125    /// * `axis` - The axis to project onto.
126    ///
127    /// # Returns
128    ///
129    /// A vector of values, one for each index along the specified axis.
130    ///
131    /// # Panics
132    ///
133    /// Panics if the axis is out of bounds.
134    ///
135    /// # Example
136    /// ```
137    /// use timsrust_utils::ndarray::NDArray;
138    /// let arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
139    /// let proj = arr.project_axis(0);
140    /// assert_eq!(proj, vec![1+2, 3+4]);
141    /// ```
142    pub fn project_axis(&self, axis: usize) -> Vec<T> {
143        assert!(axis < N, "Axis out of bounds");
144        let mut result = vec![T::default(); self.shape[axis]];
145        for (i, value) in self.data.iter().enumerate() {
146            let indices = self.inverted_index(i);
147            result[indices[axis]] += *value;
148        }
149        result
150    }
151}
152
153impl<T: AddAssign + Copy, const N: usize> AddAssign for NDArray<T, N> {
154    /// Adds another NDArray to this one, elementwise.
155    ///
156    /// # Panics
157    ///
158    /// Panics if the shapes do not match.
159    ///
160    /// # Example
161    /// ```
162    /// use timsrust_utils::ndarray::NDArray;
163    /// let mut a = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
164    /// let b = NDArray::new([2, 2], vec![10, 20, 30, 40]).unwrap();
165    /// a += b;
166    /// assert_eq!(a[[0, 0]], 11);
167    /// assert_eq!(a[[1, 1]], 44);
168    /// ```
169    fn add_assign(&mut self, other: Self) {
170        assert_eq!(self.shape(), other.shape());
171        other
172            .data
173            .into_iter()
174            .enumerate()
175            .for_each(|(i, value)| self.data[i] += value);
176    }
177}
178
179impl<T: Default, const N: usize> NDArray<T, N> {
180    /// Creates an empty NDArray with the given shape, filled with default values.
181    ///
182    /// # Arguments
183    ///
184    /// * `shape` - The shape of the array.
185    ///
186    /// # Example
187    /// ```
188    /// use timsrust_utils::ndarray::NDArray;
189    /// let arr: NDArray<i32, 2> = NDArray::empty([2, 2]);
190    /// assert_eq!(arr.shape(), [2, 2]);
191    /// assert_eq!(arr[[1,1]], 0);
192    /// ```
193    pub fn empty(shape: [usize; N]) -> Self {
194        let size = shape.iter().product();
195        let data = (0..size).map(|_| T::default()).collect();
196        Self::new(shape, data).expect("Failed to create empty NDArray")
197    }
198}
199
200impl<T, const N: usize> Index<[usize; N]> for NDArray<T, N> {
201    type Output = T;
202    /// Indexes the array using N-dimensional indices.
203    ///
204    /// # Example
205    /// ```
206    /// use timsrust_utils::ndarray::NDArray;
207    /// let arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
208    /// assert_eq!(arr[[1, 1]], 4);
209    /// ```
210    fn index(&self, indices: [usize; N]) -> &Self::Output {
211        let idx = self.index(indices);
212        &self.data[idx]
213    }
214}
215
216impl<T, const N: usize> IndexMut<[usize; N]> for NDArray<T, N> {
217    /// Mutable indexing using N-dimensional indices.
218    ///
219    /// # Example
220    /// ```
221    /// use timsrust_utils::ndarray::NDArray;
222    /// let mut arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
223    /// arr[[0, 0]] = 10;
224    /// assert_eq!(arr[[0, 0]], 10);
225    /// ```
226    fn index_mut(&mut self, indices: [usize; N]) -> &mut Self::Output {
227        let idx = self.index(indices);
228        &mut self.data[idx]
229    }
230}