1use std::{
4 fmt, io,
5 ops::{Index, IndexMut},
6};
7
8pub mod iter;
9use iter::{AxisIter, IndicesIter};
10
11pub mod npy;
12
13pub(crate) mod shape;
14use shape::Strides;
15pub use shape::{Axis, Shape};
16
17pub mod view;
18use view::View;
19
20#[derive(Clone, Debug, PartialEq)]
22pub struct Array<T> {
23 data: Vec<T>,
24 shape: Shape,
25 strides: Strides,
26}
27
28impl<T> Array<T> {
29 pub fn as_mut_slice(&mut self) -> &mut [T] {
31 self.data.as_mut_slice()
32 }
33
34 pub fn as_slice(&self) -> &[T] {
36 self.data.as_slice()
37 }
38
39 pub fn dimensions(&self) -> usize {
41 self.shape.len()
42 }
43
44 pub fn elements(&self) -> usize {
46 self.data.len()
47 }
48
49 pub fn from_element<S>(element: T, shape: S) -> Self
51 where
52 T: Clone,
53 S: Into<Shape>,
54 {
55 let shape = shape.into();
56 let elements = shape.elements();
57
58 Self::new_unchecked(vec![element; elements], shape)
59 }
60
61 pub fn from_iter<I, S>(iter: I, shape: S) -> Result<Self, ShapeError>
67 where
68 I: IntoIterator<Item = T>,
69 S: Into<Shape>,
70 {
71 Self::new(Vec::from_iter(iter), shape)
72 }
73
74 pub fn get<I>(&self, index: I) -> Option<&T>
76 where
77 I: AsRef<[usize]>,
78 {
79 let index = index.as_ref();
80
81 if index.len() == self.dimensions() {
82 self.strides
83 .flat_index(&self.shape, index)
84 .and_then(|flat| self.data.get(flat))
85 } else {
86 None
87 }
88 }
89
90 pub fn get_axis(&self, axis: Axis, index: usize) -> Option<View<'_, T>> {
95 if axis.0 > self.dimensions() || index >= self.shape[axis.0] {
96 None
97 } else {
98 let offset = index * self.strides[axis.0];
99 let data = &self.data[offset..];
100 let shape = self.shape.remove_axis(axis);
101 let strides = self.strides.remove_axis(axis);
102
103 Some(View::new_unchecked(data, shape, strides))
104 }
105 }
106
107 pub fn get_mut<I>(&mut self, index: I) -> Option<&mut T>
110 where
111 I: AsRef<[usize]>,
112 {
113 let index = index.as_ref();
114
115 if index.len() == self.dimensions() {
116 self.strides
117 .flat_index(&self.shape, index)
118 .and_then(|flat| self.data.get_mut(flat))
119 } else {
120 None
121 }
122 }
123
124 pub fn index_axis(&self, axis: Axis, index: usize) -> View<'_, T> {
130 self.get_axis(axis, index)
131 .expect("axis or index out of bounds")
132 }
133
134 pub fn iter(&self) -> std::slice::Iter<'_, T> {
136 self.data.iter()
137 }
138
139 pub fn iter_axis(&self, axis: Axis) -> AxisIter<'_, T> {
141 AxisIter::new(self, axis)
142 }
143
144 pub fn iter_indices(&self) -> IndicesIter<'_> {
146 IndicesIter::new(self)
147 }
148
149 pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
151 self.data.iter_mut()
152 }
153
154 pub fn new<D, S>(data: D, shape: S) -> Result<Self, ShapeError>
160 where
161 D: Into<Vec<T>>,
162 S: Into<Shape>,
163 {
164 let data = data.into();
165 let shape = shape.into();
166
167 if data.len() == shape.elements() {
168 Ok(Array::new_unchecked(data, shape))
169 } else {
170 Err(ShapeError {
171 shape,
172 n: data.len(),
173 })
174 }
175 }
176
177 pub fn new_unchecked<D, S>(data: D, shape: S) -> Self
182 where
183 D: Into<Vec<T>>,
184 S: Into<Shape>,
185 {
186 let data = data.into();
187 let shape = shape.into();
188
189 Self {
190 data,
191 strides: shape.strides(),
192 shape,
193 }
194 }
195
196 pub fn shape(&self) -> &Shape {
198 &self.shape
199 }
200}
201
202impl Array<f64> {
203 pub fn from_zeros<S>(shape: S) -> Self
205 where
206 S: Into<Shape>,
207 {
208 Self::from_element(0.0, shape)
209 }
210
211 pub fn read_npy<R>(mut reader: R) -> io::Result<Self>
216 where
217 R: io::BufRead,
218 {
219 npy::read_array(&mut reader)
220 }
221
222 pub fn sum(&self, axis: Axis) -> Self {
224 let smaller_shape = self.shape.remove_axis(axis).into_shape();
225
226 self.iter_axis(axis)
227 .fold(Array::from_zeros(smaller_shape), |mut array, view| {
228 array.iter_mut().zip(view.iter()).for_each(|(x, y)| *x += y);
229 array
230 })
231 }
232
233 pub fn write_npy<W>(&self, mut writer: W) -> io::Result<()>
238 where
239 W: io::Write,
240 {
241 npy::write_array(&mut writer, self)
242 }
243}
244
245impl<T, I> Index<I> for Array<T>
246where
247 I: AsRef<[usize]>,
248{
249 type Output = T;
250
251 fn index(&self, index: I) -> &Self::Output {
252 self.get(index)
253 .expect("index invalid dimension or out of bounds")
254 }
255}
256
257impl<T, I> IndexMut<I> for Array<T>
258where
259 I: AsRef<[usize]>,
260{
261 fn index_mut(&mut self, index: I) -> &mut Self::Output {
262 self.get_mut(index)
263 .expect("index invalid dimension or out of bounds")
264 }
265}
266
267#[derive(Debug)]
269pub struct ShapeError {
270 shape: Shape,
271 n: usize,
272}
273
274impl fmt::Display for ShapeError {
275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276 let ShapeError { shape, n } = self;
277 write!(
278 f,
279 "cannot construct array with shape {shape} from {n} elements"
280 )
281 }
282}
283
284impl std::error::Error for ShapeError {}
285
#[cfg(test)]
mod tests {
    use super::*;

    use crate::approx::ApproxEq;

    /// Test-only approximate equality for arrays.
    ///
    /// Two arrays are approximately equal when their element buffers are
    /// approximately equal under `epsilon` and their shapes and strides are
    /// exactly equal. The buffer comparison runs first and short-circuits.
    impl<T> ApproxEq for Array<T>
    where
        T: ApproxEq,
    {
        type Epsilon = T::Epsilon;

        const DEFAULT_EPSILON: Self::Epsilon = T::DEFAULT_EPSILON;

        fn approx_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
            let Array {
                data,
                shape,
                strides,
            } = self;

            let elements_match = data.approx_eq(&other.data, epsilon);
            elements_match && *shape == other.shape && *strides == other.strides
        }
    }
}