sfs_core/
array.rs

1//! N-dimensional array.
2
3use std::{
4    fmt, io,
5    ops::{Index, IndexMut},
6};
7
8pub mod iter;
9use iter::{AxisIter, IndicesIter};
10
11pub mod npy;
12
13pub(crate) mod shape;
14use shape::Strides;
15pub use shape::{Axis, Shape};
16
17pub mod view;
18use view::View;
19
20/// An N-dimensional strided array.
21#[derive(Clone, Debug, PartialEq)]
22pub struct Array<T> {
23    data: Vec<T>,
24    shape: Shape,
25    strides: Strides,
26}
27
28impl<T> Array<T> {
29    /// Returns a mutable reference to the underlying data as a flat slice in row-major order.
30    pub fn as_mut_slice(&mut self) -> &mut [T] {
31        self.data.as_mut_slice()
32    }
33
34    /// Returns the underlying data as a flat slice in row-major order.
35    pub fn as_slice(&self) -> &[T] {
36        self.data.as_slice()
37    }
38
39    /// Returns the number of dimensions of the array.
40    pub fn dimensions(&self) -> usize {
41        self.shape.len()
42    }
43
44    /// Returns the number of elements in the array.
45    pub fn elements(&self) -> usize {
46        self.data.len()
47    }
48
49    /// Creates a new array by repeating a single element to a shape.
50    pub fn from_element<S>(element: T, shape: S) -> Self
51    where
52        T: Clone,
53        S: Into<Shape>,
54    {
55        let shape = shape.into();
56        let elements = shape.elements();
57
58        Self::new_unchecked(vec![element; elements], shape)
59    }
60
61    /// Creates a new array from an iterator an its shape.
62    ///
63    /// # Errors
64    ///
65    /// If the number of items in the iterator does not match the provided shape.
66    pub fn from_iter<I, S>(iter: I, shape: S) -> Result<Self, ShapeError>
67    where
68        I: IntoIterator<Item = T>,
69        S: Into<Shape>,
70    {
71        Self::new(Vec::from_iter(iter), shape)
72    }
73
74    /// Returns the element at the provided index if in bounds, and `None` otherwise,
75    pub fn get<I>(&self, index: I) -> Option<&T>
76    where
77        I: AsRef<[usize]>,
78    {
79        let index = index.as_ref();
80
81        if index.len() == self.dimensions() {
82            self.strides
83                .flat_index(&self.shape, index)
84                .and_then(|flat| self.data.get(flat))
85        } else {
86            None
87        }
88    }
89
90    /// Returns a view of the array along the provided axis at the provided index if in bounds, and
91    /// `None` otherwise.
92    ///
93    /// See [`Array::index_axis`] for a panicking version.
94    pub fn get_axis(&self, axis: Axis, index: usize) -> Option<View<'_, T>> {
95        if axis.0 > self.dimensions() || index >= self.shape[axis.0] {
96            None
97        } else {
98            let offset = index * self.strides[axis.0];
99            let data = &self.data[offset..];
100            let shape = self.shape.remove_axis(axis);
101            let strides = self.strides.remove_axis(axis);
102
103            Some(View::new_unchecked(data, shape, strides))
104        }
105    }
106
107    /// Returns a mutable reference to the element at the provided index if in bounds, and `None`
108    /// otherwise,
109    pub fn get_mut<I>(&mut self, index: I) -> Option<&mut T>
110    where
111        I: AsRef<[usize]>,
112    {
113        let index = index.as_ref();
114
115        if index.len() == self.dimensions() {
116            self.strides
117                .flat_index(&self.shape, index)
118                .and_then(|flat| self.data.get_mut(flat))
119        } else {
120            None
121        }
122    }
123
124    /// Returns a view of the array along the provided axis at the provided index if in bounds.
125    ///
126    /// # Panics
127    ///
128    /// If the axis or the index is not in bounds, see [`Array::get_axis`] for a fallible version.
129    pub fn index_axis(&self, axis: Axis, index: usize) -> View<'_, T> {
130        self.get_axis(axis, index)
131            .expect("axis or index out of bounds")
132    }
133
134    /// Returns an iterator over the underlying data in row-major order.
135    pub fn iter(&self) -> std::slice::Iter<'_, T> {
136        self.data.iter()
137    }
138
139    /// Returns an iterator over views of the array along the provided axis.
140    pub fn iter_axis(&self, axis: Axis) -> AxisIter<'_, T> {
141        AxisIter::new(self, axis)
142    }
143
144    /// Returns an iterator over indices of the array in row-major order.
145    pub fn iter_indices(&self) -> IndicesIter<'_> {
146        IndicesIter::new(self)
147    }
148
149    /// Returns an iterator over mutable references to the underlying data in row-major order.
150    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
151        self.data.iter_mut()
152    }
153
154    /// Creates a new array from data in row-major order and a shape.
155    ///
156    /// # Errors
157    ///
158    /// If the number of items in the data does not match the provided shape.
159    pub fn new<D, S>(data: D, shape: S) -> Result<Self, ShapeError>
160    where
161        D: Into<Vec<T>>,
162        S: Into<Shape>,
163    {
164        let data = data.into();
165        let shape = shape.into();
166
167        if data.len() == shape.elements() {
168            Ok(Array::new_unchecked(data, shape))
169        } else {
170            Err(ShapeError {
171                shape,
172                n: data.len(),
173            })
174        }
175    }
176
177    /// Creates a new array from data in row-major order and a shape.
178    ///
179    /// Prefer using [`Array::new`] to ensure the data fits the provided shape.
180    /// It is a logic error where this is not true, though it can not trigger unsafe behaviour.
181    pub fn new_unchecked<D, S>(data: D, shape: S) -> Self
182    where
183        D: Into<Vec<T>>,
184        S: Into<Shape>,
185    {
186        let data = data.into();
187        let shape = shape.into();
188
189        Self {
190            data,
191            strides: shape.strides(),
192            shape,
193        }
194    }
195
196    /// Returns the shape of the array.
197    pub fn shape(&self) -> &Shape {
198        &self.shape
199    }
200}
201
202impl Array<f64> {
203    /// Creates a new array filled with zeros to a shape.
204    pub fn from_zeros<S>(shape: S) -> Self
205    where
206        S: Into<Shape>,
207    {
208        Self::from_element(0.0, shape)
209    }
210
211    /// Reads an array from the [`npy`] format.
212    ///
213    /// See the [format docs](https://numpy.org/devdocs/reference/generated/numpy.lib.format.html)
214    /// for details.
215    pub fn read_npy<R>(mut reader: R) -> io::Result<Self>
216    where
217        R: io::BufRead,
218    {
219        npy::read_array(&mut reader)
220    }
221
222    /// Returns the sum of the elements in the array.
223    pub fn sum(&self, axis: Axis) -> Self {
224        let smaller_shape = self.shape.remove_axis(axis).into_shape();
225
226        self.iter_axis(axis)
227            .fold(Array::from_zeros(smaller_shape), |mut array, view| {
228                array.iter_mut().zip(view.iter()).for_each(|(x, y)| *x += y);
229                array
230            })
231    }
232
233    /// Writes the in the [`npy`] format.
234    ///
235    /// See the [format docs](https://numpy.org/devdocs/reference/generated/numpy.lib.format.html)
236    /// for details.
237    pub fn write_npy<W>(&self, mut writer: W) -> io::Result<()>
238    where
239        W: io::Write,
240    {
241        npy::write_array(&mut writer, self)
242    }
243}
244
245impl<T, I> Index<I> for Array<T>
246where
247    I: AsRef<[usize]>,
248{
249    type Output = T;
250
251    fn index(&self, index: I) -> &Self::Output {
252        self.get(index)
253            .expect("index invalid dimension or out of bounds")
254    }
255}
256
257impl<T, I> IndexMut<I> for Array<T>
258where
259    I: AsRef<[usize]>,
260{
261    fn index_mut(&mut self, index: I) -> &mut Self::Output {
262        self.get_mut(index)
263            .expect("index invalid dimension or out of bounds")
264    }
265}
266
267/// An error associated with a shape mismatch on construction of an [`Array`].
268#[derive(Debug)]
269pub struct ShapeError {
270    shape: Shape,
271    n: usize,
272}
273
274impl fmt::Display for ShapeError {
275    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276        let ShapeError { shape, n } = self;
277        write!(
278            f,
279            "cannot construct array with shape {shape} from {n} elements"
280        )
281    }
282}
283
284impl std::error::Error for ShapeError {}
285
#[cfg(test)]
mod tests {
    use super::*;

    use crate::approx::ApproxEq;

    /// Approximate equality for arrays, for use in tests.
    ///
    /// The flat data is compared through its own `ApproxEq` implementation under `epsilon`. The
    /// shape and strides must then match exactly.
    impl<T> ApproxEq for Array<T>
    where
        T: ApproxEq,
    {
        type Epsilon = T::Epsilon;

        const DEFAULT_EPSILON: Self::Epsilon = T::DEFAULT_EPSILON;

        fn approx_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
            let Array {
                data,
                shape,
                strides,
            } = self;

            data.approx_eq(&other.data, epsilon)
                && *shape == other.shape
                && *strides == other.strides
        }
    }
}