timsrust_utils/ndarray.rs
1use std::ops::{AddAssign, Index, IndexMut};
2
3use crate::custom_error;
4
// Declares the error type used by this module via the crate-local `custom_error!` macro.
// NOTE(review): the macro expansion is not visible here; presumably it generates a public
// `NDArrayError` type with a `new` constructor accepting a message `String`, which is how
// `NDArray::new` uses it — confirm against the macro definition.
custom_error!(pub NDArrayError);
6
/// A fixed-rank, N-dimensional array that keeps its elements in one flat `Vec`.
///
/// The rank `N` is a compile-time constant. Alongside the element storage the
/// array records the size of every axis and a per-axis stride that is used to
/// translate N-dimensional indices into offsets in the flat storage.
///
/// # Example
/// ```
/// use timsrust_utils::ndarray::NDArray;
/// let arr = NDArray::new([2, 3], vec![1, 2, 3, 4, 5, 6]).unwrap();
/// assert_eq!(arr[[1, 2]], 6);
/// ```
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct NDArray<T, const N: usize> {
    /// Number of elements along each axis.
    shape: [usize; N],
    /// Offset step in `data` for a unit move along each axis.
    strides: [usize; N],
    /// Flat element storage addressed through `strides`.
    data: Vec<T>,
}
21
22impl<T, const N: usize> NDArray<T, N> {
23 /// Creates a new NDArray with the given shape and data.
24 ///
25 /// # Arguments
26 ///
27 /// * `shape` - The shape of the array as an array of dimension sizes.
28 /// * `data` - The data to fill the array, must match the product of shape dimensions.
29 ///
30 /// # Errors
31 ///
32 /// Returns `TimsUtilsError` if the shape and data length are incompatible.
33 ///
34 /// # Example
35 /// ```
36 /// use timsrust_utils::ndarray::NDArray;
37 /// let arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
38 /// assert_eq!(arr.shape(), [2, 2]);
39 /// assert_eq!(arr[[1, 1]], 4);
40 /// ```
41 pub fn new(shape: [usize; N], data: Vec<T>) -> Result<Self, NDArrayError> {
42 if shape.iter().product::<usize>() != data.len() {
43 return Err(NDArrayError::new(format!(
44 "Incompatible shapes: {:?} and {:?}",
45 shape,
46 data.len()
47 )));
48 }
49 let mut strides = [0; N];
50 let mut stride = 1;
51 for (i, &dim) in shape.iter().rev().enumerate() {
52 strides[N - 1 - i] = stride;
53 stride *= dim;
54 }
55 let result = Self {
56 shape,
57 strides,
58 data,
59 };
60 Ok(result)
61 }
62
63 /// Returns the shape of the array.
64 ///
65 /// # Example
66 /// ```
67 /// use timsrust_utils::ndarray::NDArray;
68 /// let arr = NDArray::new([2, 3], vec![1, 2, 3, 4, 5, 6]).unwrap();
69 /// assert_eq!(arr.shape(), [2, 3]);
70 /// ```
71 pub fn shape(&self) -> [usize; N] {
72 self.shape
73 }
74
75 /// Computes the flat index in the data vector for the given N-dimensional indices.
76 ///
77 /// # Arguments
78 ///
79 /// * `indices` - The N-dimensional indices.
80 ///
81 /// # Example
82 /// ```
83 /// use timsrust_utils::ndarray::NDArray;
84 /// let arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
85 /// let idx = arr.index([1, 1]);
86 /// assert_eq!(idx, 3);
87 /// ```
88 pub fn index(&self, indices: [usize; N]) -> usize {
89 indices
90 .iter()
91 .zip(self.strides.iter())
92 .map(|(&idx, &stride)| idx * stride)
93 .sum()
94 }
95
96 /// Converts a flat index into N-dimensional indices.
97 ///
98 /// # Arguments
99 ///
100 /// * `idx` - The flat index.
101 ///
102 /// # Example
103 /// ```
104 /// use timsrust_utils::ndarray::NDArray;
105 /// let arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
106 /// let indices = arr.inverted_index(3);
107 /// assert_eq!(indices, [1, 1]);
108 /// ```
109 pub fn inverted_index(&self, mut idx: usize) -> [usize; N] {
110 let mut indices = [0; N];
111 // for d in 0..N {
112 for (d, index) in indices.iter_mut().enumerate().take(N) {
113 *index = idx / self.strides[d];
114 idx %= self.strides[d];
115 }
116 indices
117 }
118}
119
120impl<T: Default + Copy + AddAssign, const N: usize> NDArray<T, N> {
121 /// Projects the array along the specified axis, summing over all other axes.
122 ///
123 /// # Arguments
124 ///
125 /// * `axis` - The axis to project onto.
126 ///
127 /// # Returns
128 ///
129 /// A vector of values, one for each index along the specified axis.
130 ///
131 /// # Panics
132 ///
133 /// Panics if the axis is out of bounds.
134 ///
135 /// # Example
136 /// ```
137 /// use timsrust_utils::ndarray::NDArray;
138 /// let arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
139 /// let proj = arr.project_axis(0);
140 /// assert_eq!(proj, vec![1+2, 3+4]);
141 /// ```
142 pub fn project_axis(&self, axis: usize) -> Vec<T> {
143 assert!(axis < N, "Axis out of bounds");
144 let mut result = vec![T::default(); self.shape[axis]];
145 for (i, value) in self.data.iter().enumerate() {
146 let indices = self.inverted_index(i);
147 result[indices[axis]] += *value;
148 }
149 result
150 }
151}
152
153impl<T: AddAssign + Copy, const N: usize> AddAssign for NDArray<T, N> {
154 /// Adds another NDArray to this one, elementwise.
155 ///
156 /// # Panics
157 ///
158 /// Panics if the shapes do not match.
159 ///
160 /// # Example
161 /// ```
162 /// use timsrust_utils::ndarray::NDArray;
163 /// let mut a = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
164 /// let b = NDArray::new([2, 2], vec![10, 20, 30, 40]).unwrap();
165 /// a += b;
166 /// assert_eq!(a[[0, 0]], 11);
167 /// assert_eq!(a[[1, 1]], 44);
168 /// ```
169 fn add_assign(&mut self, other: Self) {
170 assert_eq!(self.shape(), other.shape());
171 other
172 .data
173 .into_iter()
174 .enumerate()
175 .for_each(|(i, value)| self.data[i] += value);
176 }
177}
178
179impl<T: Default, const N: usize> NDArray<T, N> {
180 /// Creates an empty NDArray with the given shape, filled with default values.
181 ///
182 /// # Arguments
183 ///
184 /// * `shape` - The shape of the array.
185 ///
186 /// # Example
187 /// ```
188 /// use timsrust_utils::ndarray::NDArray;
189 /// let arr: NDArray<i32, 2> = NDArray::empty([2, 2]);
190 /// assert_eq!(arr.shape(), [2, 2]);
191 /// assert_eq!(arr[[1,1]], 0);
192 /// ```
193 pub fn empty(shape: [usize; N]) -> Self {
194 let size = shape.iter().product();
195 let data = (0..size).map(|_| T::default()).collect();
196 Self::new(shape, data).expect("Failed to create empty NDArray")
197 }
198}
199
200impl<T, const N: usize> Index<[usize; N]> for NDArray<T, N> {
201 type Output = T;
202 /// Indexes the array using N-dimensional indices.
203 ///
204 /// # Example
205 /// ```
206 /// use timsrust_utils::ndarray::NDArray;
207 /// let arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
208 /// assert_eq!(arr[[1, 1]], 4);
209 /// ```
210 fn index(&self, indices: [usize; N]) -> &Self::Output {
211 let idx = self.index(indices);
212 &self.data[idx]
213 }
214}
215
216impl<T, const N: usize> IndexMut<[usize; N]> for NDArray<T, N> {
217 /// Mutable indexing using N-dimensional indices.
218 ///
219 /// # Example
220 /// ```
221 /// use timsrust_utils::ndarray::NDArray;
222 /// let mut arr = NDArray::new([2, 2], vec![1, 2, 3, 4]).unwrap();
223 /// arr[[0, 0]] = 10;
224 /// assert_eq!(arr[[0, 0]], 10);
225 /// ```
226 fn index_mut(&mut self, indices: [usize; N]) -> &mut Self::Output {
227 let idx = self.index(indices);
228 &mut self.data[idx]
229 }
230}