Skip to main content

tensors/
arrayn.rs

1//! Dynamic-rank row-major arrays and views.
2
3use crate::error::{Error, Result};
4use crate::numeric::Float;
5use crate::view2::validate_view;
6
/// Owned dynamic-rank row-major array.
///
/// Elements live in one flat `Vec`; `shape` and `strides` hold one entry per
/// axis and always have the same length.
#[derive(Clone, Debug, PartialEq)]
pub struct ArrayN<T> {
    /// Flat element storage.
    data: Vec<T>,
    /// Extent of each axis.
    shape: Vec<usize>,
    /// Stride of each axis, measured in elements.
    strides: Vec<isize>,
}
14
/// Immutable dynamic-rank strided view.
///
/// Borrows the element storage and owns its own copy of the shape and strides.
#[derive(Clone, Debug)]
pub struct ArrayViewN<'a, T> {
    /// Borrowed backing slice the view indexes into.
    data: &'a [T],
    /// Extent of each axis.
    shape: Vec<usize>,
    /// Stride of each axis, measured in elements.
    strides: Vec<isize>,
    /// Linear position in `data` that index computation starts from.
    offset: isize,
}
23
/// Mutable dynamic-rank strided view.
///
/// Exclusively borrows the element storage and owns its own copy of the shape
/// and strides.
#[derive(Debug)]
pub struct ArrayViewMutN<'a, T> {
    /// Exclusively borrowed backing slice the view indexes into.
    data: &'a mut [T],
    /// Extent of each axis.
    shape: Vec<usize>,
    /// Stride of each axis, measured in elements.
    strides: Vec<isize>,
    /// Linear position in `data` that index computation starts from.
    offset: isize,
}
32
33impl<T> ArrayN<T> {
34    /// Build an array from row-major data.
35    pub fn from_vec(shape: Vec<usize>, data: Vec<T>) -> Result<Self> {
36        let expected = checked_len(&shape)?;
37        if data.len() != expected {
38            return Err(Error::shape(vec![expected], vec![data.len()]));
39        }
40        let strides = row_major_strides(&shape);
41        Ok(Self {
42            data,
43            shape,
44            strides,
45        })
46    }
47
48    /// Shape.
49    pub fn shape(&self) -> &[usize] {
50        &self.shape
51    }
52
53    /// Strides in elements.
54    pub fn strides(&self) -> &[isize] {
55        &self.strides
56    }
57
58    /// Number of dimensions.
59    pub fn ndim(&self) -> usize {
60        self.shape.len()
61    }
62
63    /// Number of elements.
64    pub fn len(&self) -> usize {
65        self.data.len()
66    }
67
68    /// Whether the array is empty.
69    pub fn is_empty(&self) -> bool {
70        self.data.is_empty()
71    }
72
73    /// Borrow the backing slice.
74    pub fn as_slice(&self) -> &[T] {
75        &self.data
76    }
77
78    /// Borrow the backing slice mutably.
79    pub fn as_mut_slice(&mut self) -> &mut [T] {
80        &mut self.data
81    }
82
83    /// Immutable dynamic-rank view.
84    pub fn view(&self) -> ArrayViewN<'_, T> {
85        ArrayViewN {
86            data: &self.data,
87            shape: self.shape.clone(),
88            strides: self.strides.clone(),
89            offset: 0,
90        }
91    }
92
93    /// Mutable dynamic-rank view.
94    pub fn view_mut(&mut self) -> ArrayViewMutN<'_, T> {
95        ArrayViewMutN {
96            data: &mut self.data,
97            shape: self.shape.clone(),
98            strides: self.strides.clone(),
99            offset: 0,
100        }
101    }
102
103    /// Get an element reference.
104    pub fn get(&self, index: &[usize]) -> Option<&T> {
105        self.linear_index(index).map(|idx| &self.data[idx])
106    }
107
108    /// Fix one axis at `index`, returning a lower-rank immutable view.
109    pub fn slice_axis(&self, axis: usize, index: usize) -> Result<ArrayViewN<'_, T>> {
110        self.view().slice_axis(axis, index)
111    }
112
113    /// Fix one axis at `index`, returning a lower-rank mutable view.
114    pub fn slice_axis_mut(&mut self, axis: usize, index: usize) -> Result<ArrayViewMutN<'_, T>> {
115        if axis >= self.ndim() {
116            return Err(Error::AxisOutOfBounds {
117                axis,
118                ndim: self.ndim(),
119            });
120        }
121        if index >= self.shape[axis] {
122            return Err(Error::IndexOutOfBounds);
123        }
124        let mut shape = self.shape.clone();
125        let mut strides = self.strides.clone();
126        let offset = index as isize * strides[axis];
127        shape.remove(axis);
128        strides.remove(axis);
129        Ok(ArrayViewMutN {
130            data: &mut self.data,
131            shape,
132            strides,
133            offset,
134        })
135    }
136
137    /// Get a mutable element reference.
138    pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
139        self.linear_index(index).map(|idx| &mut self.data[idx])
140    }
141
142    fn linear_index(&self, index: &[usize]) -> Option<usize> {
143        if index.len() != self.ndim() {
144            return None;
145        }
146        let mut linear = 0usize;
147        for ((&idx, &dim), &stride) in index.iter().zip(&self.shape).zip(&self.strides) {
148            if idx >= dim {
149                return None;
150            }
151            linear += idx * stride as usize;
152        }
153        Some(linear)
154    }
155}
156
157impl<T: Clone> ArrayN<T> {
158    /// Fill a new array with `value`.
159    pub fn filled(shape: Vec<usize>, value: T) -> Self {
160        let len = shape.iter().product();
161        let strides = row_major_strides(&shape);
162        Self {
163            data: vec![value; len],
164            shape,
165            strides,
166        }
167    }
168
169    /// Fallibly fill a new array with `value`.
170    pub fn try_filled(shape: Vec<usize>, value: T) -> Result<Self> {
171        let len = checked_len(&shape)?;
172        let strides = row_major_strides(&shape);
173        let mut data = Vec::new();
174        data.try_reserve_exact(len)
175            .map_err(|_| Error::AllocationFailed)?;
176        data.resize(len, value);
177        Ok(Self {
178            data,
179            shape,
180            strides,
181        })
182    }
183}
184
impl<T: Float> ArrayN<T> {
    /// Array of the given `shape` filled with `T::zero()`.
    ///
    /// Delegates to `filled`, so it behaves like `filled` for shapes whose
    /// element count does not fit in `usize`.
    pub fn zeros(shape: Vec<usize>) -> Self {
        Self::filled(shape, T::zero())
    }

    /// Fallibly allocate an array of the given `shape` filled with `T::zero()`.
    ///
    /// # Errors
    ///
    /// Returns whatever `try_filled` returns for `shape`.
    pub fn try_zeros(shape: Vec<usize>) -> Result<Self> {
        Self::try_filled(shape, T::zero())
    }

    /// Array of the given `shape` filled with `T::one()`.
    ///
    /// Delegates to `filled`, so it behaves like `filled` for shapes whose
    /// element count does not fit in `usize`.
    pub fn ones(shape: Vec<usize>) -> Self {
        Self::filled(shape, T::one())
    }

    /// Fallibly allocate an array of the given `shape` filled with `T::one()`.
    ///
    /// # Errors
    ///
    /// Returns whatever `try_filled` returns for `shape`.
    pub fn try_ones(shape: Vec<usize>) -> Result<Self> {
        Self::try_filled(shape, T::one())
    }
}
206
207impl<'a, T> ArrayViewN<'a, T> {
208    /// Create a checked dynamic-rank view.
209    pub fn new(
210        data: &'a [T],
211        shape: &'a [usize],
212        strides: &'a [isize],
213        offset: isize,
214    ) -> Result<Self> {
215        validate_view(data.len(), shape, strides, offset)?;
216        Ok(Self {
217            data,
218            shape: shape.to_vec(),
219            strides: strides.to_vec(),
220            offset,
221        })
222    }
223
224    /// Shape.
225    pub fn shape(&self) -> &[usize] {
226        &self.shape
227    }
228
229    /// Strides in elements.
230    pub fn strides(&self) -> &[isize] {
231        &self.strides
232    }
233
234    /// Number of dimensions.
235    pub fn ndim(&self) -> usize {
236        self.shape.len()
237    }
238
239    /// Number of elements.
240    pub fn len(&self) -> usize {
241        self.shape.iter().product()
242    }
243
244    /// Whether the view is empty.
245    pub fn is_empty(&self) -> bool {
246        self.len() == 0
247    }
248
249    /// Whether the view is compact row-major contiguous.
250    pub fn is_contiguous(&self) -> bool {
251        is_compact_row_major(&self.shape, &self.strides)
252    }
253
254    /// Borrow the backing slice if this view covers it contiguously.
255    pub fn as_slice(&self) -> Option<&'a [T]> {
256        if !self.is_contiguous() {
257            return None;
258        }
259        let start = self.offset as usize;
260        let end = start + self.len();
261        Some(&self.data[start..end])
262    }
263
264    /// Get an element reference.
265    pub fn get(&self, index: &[usize]) -> Option<&'a T> {
266        self.linear_index(index).map(|idx| &self.data[idx])
267    }
268
269    /// Fix one axis at `index`, returning a lower-rank view.
270    pub fn slice_axis(&self, axis: usize, index: usize) -> Result<Self> {
271        if axis >= self.ndim() {
272            return Err(Error::AxisOutOfBounds {
273                axis,
274                ndim: self.ndim(),
275            });
276        }
277        if index >= self.shape[axis] {
278            return Err(Error::IndexOutOfBounds);
279        }
280        let mut shape = self.shape.clone();
281        let mut strides = self.strides.clone();
282        let offset = self.offset + index as isize * strides[axis];
283        shape.remove(axis);
284        strides.remove(axis);
285        Ok(Self {
286            data: self.data,
287            shape,
288            strides,
289            offset,
290        })
291    }
292
293    fn linear_index(&self, index: &[usize]) -> Option<usize> {
294        if index.len() != self.ndim() {
295            return None;
296        }
297        let mut linear = self.offset;
298        for ((&idx, &dim), &stride) in index.iter().zip(&self.shape).zip(&self.strides) {
299            if idx >= dim {
300                return None;
301            }
302            linear += idx as isize * stride;
303        }
304        (linear >= 0).then_some(linear as usize)
305    }
306}
307
308impl<'a, T> ArrayViewMutN<'a, T> {
309    /// Create a checked mutable dynamic-rank view.
310    pub fn new(
311        data: &'a mut [T],
312        shape: Vec<usize>,
313        strides: Vec<isize>,
314        offset: isize,
315    ) -> Result<Self> {
316        validate_view(data.len(), &shape, &strides, offset)?;
317        Ok(Self {
318            data,
319            shape,
320            strides,
321            offset,
322        })
323    }
324
325    /// Shape.
326    pub fn shape(&self) -> &[usize] {
327        &self.shape
328    }
329
330    /// Strides in elements.
331    pub fn strides(&self) -> &[isize] {
332        &self.strides
333    }
334
335    /// Number of dimensions.
336    pub fn ndim(&self) -> usize {
337        self.shape.len()
338    }
339
340    /// Number of elements.
341    pub fn len(&self) -> usize {
342        self.shape.iter().product()
343    }
344
345    /// Whether the view is empty.
346    pub fn is_empty(&self) -> bool {
347        self.len() == 0
348    }
349
350    /// Whether the view is compact row-major contiguous.
351    pub fn is_contiguous(&self) -> bool {
352        is_compact_row_major(&self.shape, &self.strides)
353    }
354
355    /// Borrow the backing slice if this view covers it contiguously.
356    pub fn as_mut_slice(&mut self) -> Option<&mut [T]> {
357        if !self.is_contiguous() {
358            return None;
359        }
360        let start = self.offset as usize;
361        let end = start + self.len();
362        Some(&mut self.data[start..end])
363    }
364
365    /// Immutable view over the same region.
366    pub fn as_view(&self) -> ArrayViewN<'_, T> {
367        ArrayViewN {
368            data: self.data,
369            shape: self.shape.clone(),
370            strides: self.strides.clone(),
371            offset: self.offset,
372        }
373    }
374
375    /// Get an element reference.
376    pub fn get(&self, index: &[usize]) -> Option<&T> {
377        self.linear_index(index).map(|idx| &self.data[idx])
378    }
379
380    /// Get a mutable element reference.
381    pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
382        let linear = self.linear_index(index)?;
383        Some(&mut self.data[linear])
384    }
385
386    /// Fix one axis at `index`, returning a lower-rank mutable view.
387    pub fn slice_axis_mut(&mut self, axis: usize, index: usize) -> Result<ArrayViewMutN<'_, T>> {
388        if axis >= self.ndim() {
389            return Err(Error::AxisOutOfBounds {
390                axis,
391                ndim: self.ndim(),
392            });
393        }
394        if index >= self.shape[axis] {
395            return Err(Error::IndexOutOfBounds);
396        }
397        let mut shape = self.shape.clone();
398        let mut strides = self.strides.clone();
399        let offset = self.offset + index as isize * strides[axis];
400        shape.remove(axis);
401        strides.remove(axis);
402        Ok(ArrayViewMutN {
403            data: &mut *self.data,
404            shape,
405            strides,
406            offset,
407        })
408    }
409
410    fn linear_index(&self, index: &[usize]) -> Option<usize> {
411        if index.len() != self.ndim() {
412            return None;
413        }
414        let mut linear = self.offset;
415        for ((&idx, &dim), &stride) in index.iter().zip(&self.shape).zip(&self.strides) {
416            if idx >= dim {
417                return None;
418            }
419            linear += idx as isize * stride;
420        }
421        (linear >= 0).then_some(linear as usize)
422    }
423}
424
425fn checked_len(shape: &[usize]) -> Result<usize> {
426    shape
427        .iter()
428        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
429        .ok_or(Error::DimensionTooLarge)
430}
431
/// Compute compact row-major strides (in elements) for `shape`.
///
/// The last axis gets stride `1`; each earlier axis gets the product of all
/// extents after it. The running product saturates at `isize::MAX` instead of
/// overflowing, which can only be reached by shapes that hold no elements
/// (for example `[0, isize::MAX, 4]`), so this function never panics.
fn row_major_strides(shape: &[usize]) -> Vec<isize> {
    let mut strides = vec![1isize; shape.len()];
    let mut acc = 1isize;
    for (stride, &dim) in strides.iter_mut().zip(shape).rev() {
        *stride = acc;
        // Clamp before converting so extents above `isize::MAX` do not turn negative.
        let extent = dim.min(isize::MAX as usize) as isize;
        acc = acc.saturating_mul(extent);
    }
    strides
}
441
/// Report whether `strides` describe a compact row-major layout for `shape`.
///
/// Any shape containing a zero extent counts as compact. Otherwise the axes
/// are walked from last to first, and every axis with extent greater than one
/// must have a stride equal to the product of the extents after it.
fn is_compact_row_major(shape: &[usize], strides: &[isize]) -> bool {
    if shape.iter().any(|&dim| dim == 0) {
        return true;
    }
    let mut want = 1isize;
    shape.iter().zip(strides).rev().all(|(&dim, &stride)| {
        // Axes of extent 0 or 1 never advance, so their stride is irrelevant.
        let fits = dim <= 1 || stride == want;
        if fits {
            want *= dim as isize;
        }
        fits
    })
}