Skip to main content

tensors/
view3.rs

1//! Three-dimensional strided array views.
2
3use core::ops::{Index, IndexMut};
4
5use crate::error::{Error, Result};
6use crate::view2::{ArrayView2, ArrayViewMut2, validate_view};
7
/// Immutable 3D array view.
///
/// Element `(i, j, k)` is read from
/// `data[offset + i * strides[0] + j * strides[1] + k * strides[2]]`.
///
/// `Clone` and `Copy` are implemented by hand rather than derived: the view
/// only holds a shared borrow of its storage, so copying it must not require
/// `T: Clone` / `T: Copy` (bounds that `#[derive]` would add).
#[derive(Debug)]
pub struct ArrayView3<'a, T> {
    /// Borrowed backing storage; the view may cover only part of it.
    pub(crate) data: &'a [T],
    /// Extent of each of the three axes.
    pub(crate) shape: [usize; 3],
    /// Signed per-axis step, in elements, between consecutive indices.
    pub(crate) strides: [isize; 3],
    /// Position in `data` of element `(0, 0, 0)`.
    pub(crate) offset: isize,
}

impl<T> Clone for ArrayView3<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ArrayView3<'_, T> {}
16
/// Mutable 3D array view.
///
/// Uses the same `offset` + `strides` addressing as the immutable view, but
/// holds an exclusive borrow of its storage so elements can be written.
#[derive(Debug)]
pub struct ArrayViewMut3<'a, T> {
    /// Exclusively borrowed backing storage.
    pub(crate) data: &'a mut [T],
    /// Extent of each of the three axes.
    pub(crate) shape: [usize; 3],
    /// Signed per-axis step, in elements.
    pub(crate) strides: [isize; 3],
    /// Position in `data` of element `(0, 0, 0)`.
    pub(crate) offset: isize,
}
25
26impl<'a, T> ArrayView3<'a, T> {
27    /// Create a checked immutable view.
28    pub fn new(
29        data: &'a [T],
30        shape: [usize; 3],
31        strides: [isize; 3],
32        offset: isize,
33    ) -> Result<Self> {
34        validate_view(data.len(), &shape, &strides, offset)?;
35        Ok(Self {
36            data,
37            shape,
38            strides,
39            offset,
40        })
41    }
42
43    pub(crate) fn from_raw_parts(
44        data: &'a [T],
45        shape: [usize; 3],
46        strides: [isize; 3],
47        offset: isize,
48    ) -> Self {
49        Self {
50            data,
51            shape,
52            strides,
53            offset,
54        }
55    }
56
57    /// Shape as `[dim0, dim1, dim2]`.
58    pub fn shape(&self) -> [usize; 3] {
59        self.shape
60    }
61
62    /// Strides in elements.
63    pub fn strides(&self) -> [isize; 3] {
64        self.strides
65    }
66
67    /// Number of logical elements.
68    pub fn len(&self) -> usize {
69        self.shape.iter().product()
70    }
71
72    /// Whether the view is empty.
73    pub fn is_empty(&self) -> bool {
74        self.len() == 0
75    }
76
77    /// Whether the view is compact row-major contiguous.
78    pub fn is_contiguous(&self) -> bool {
79        self.shape.contains(&0)
80            || (self.offset == 0
81                && self.strides
82                    == [
83                        (self.shape[1] * self.shape[2]) as isize,
84                        self.shape[2] as isize,
85                        1,
86                    ]
87                && self.len() == self.data.len())
88    }
89
90    /// Borrow the backing slice if this view covers it contiguously.
91    pub fn as_slice(&self) -> Option<&'a [T]> {
92        self.is_contiguous().then_some(self.data)
93    }
94
95    /// Get an element reference.
96    pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&'a T> {
97        (i < self.shape[0] && j < self.shape[1] && k < self.shape[2])
98            .then(|| &self.data[self.linear_index(i, j, k)])
99    }
100
101    /// Extract a 2D matrix view by fixing one axis.
102    pub fn matrix_at(&self, axis: usize, index: usize) -> Result<ArrayView2<'a, T>> {
103        if axis >= 3 {
104            return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
105        }
106        if index >= self.shape[axis] {
107            return Err(Error::IndexOutOfBounds);
108        }
109        let axes: Vec<usize> = (0..3).filter(|&candidate| candidate != axis).collect();
110        Ok(ArrayView2::from_raw_parts(
111            self.data,
112            [self.shape[axes[0]], self.shape[axes[1]]],
113            [self.strides[axes[0]], self.strides[axes[1]]],
114            self.offset + index as isize * self.strides[axis],
115        ))
116    }
117
118    /// Visit each 2D matrix slice along `axis` in order.
119    pub fn for_each_matrix(
120        &self,
121        axis: usize,
122        mut f: impl FnMut(usize, ArrayView2<'a, T>) -> Result<()>,
123    ) -> Result<()> {
124        if axis >= 3 {
125            return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
126        }
127        for index in 0..self.shape[axis] {
128            f(index, self.matrix_at(axis, index)?)?;
129        }
130        Ok(())
131    }
132
133    #[inline]
134    pub(crate) fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
135        (self.offset
136            + i as isize * self.strides[0]
137            + j as isize * self.strides[1]
138            + k as isize * self.strides[2]) as usize
139    }
140}
141
142impl<'a, T> ArrayViewMut3<'a, T> {
143    /// Create a checked mutable view.
144    pub fn new(
145        data: &'a mut [T],
146        shape: [usize; 3],
147        strides: [isize; 3],
148        offset: isize,
149    ) -> Result<Self> {
150        validate_view(data.len(), &shape, &strides, offset)?;
151        Ok(Self {
152            data,
153            shape,
154            strides,
155            offset,
156        })
157    }
158
159    pub(crate) fn from_raw_parts(
160        data: &'a mut [T],
161        shape: [usize; 3],
162        strides: [isize; 3],
163        offset: isize,
164    ) -> Self {
165        Self {
166            data,
167            shape,
168            strides,
169            offset,
170        }
171    }
172
173    /// Shape as `[dim0, dim1, dim2]`.
174    pub fn shape(&self) -> [usize; 3] {
175        self.shape
176    }
177
178    /// Immutable view over the same region.
179    pub fn as_view(&self) -> ArrayView3<'_, T> {
180        ArrayView3 {
181            data: self.data,
182            shape: self.shape,
183            strides: self.strides,
184            offset: self.offset,
185        }
186    }
187
188    /// Get an element reference.
189    pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&T> {
190        (i < self.shape[0] && j < self.shape[1] && k < self.shape[2])
191            .then(|| &self.data[self.linear_index(i, j, k)])
192    }
193
194    /// Get a mutable element reference.
195    pub fn get_mut(&mut self, i: usize, j: usize, k: usize) -> Option<&mut T> {
196        if i >= self.shape[0] || j >= self.shape[1] || k >= self.shape[2] {
197            return None;
198        }
199        let index = self.linear_index(i, j, k);
200        Some(&mut self.data[index])
201    }
202
203    /// Extract a mutable 2D matrix view by fixing one axis.
204    pub fn matrix_at_mut(&mut self, axis: usize, index: usize) -> Result<ArrayViewMut2<'_, T>> {
205        if axis >= 3 {
206            return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
207        }
208        if index >= self.shape[axis] {
209            return Err(Error::IndexOutOfBounds);
210        }
211        let axes: Vec<usize> = (0..3).filter(|&candidate| candidate != axis).collect();
212        let offset = self.offset + index as isize * self.strides[axis];
213        ArrayViewMut2::new(
214            &mut *self.data,
215            [self.shape[axes[0]], self.shape[axes[1]]],
216            [self.strides[axes[0]], self.strides[axes[1]]],
217            offset,
218        )
219    }
220
221    /// Visit each mutable 2D matrix slice along `axis` in order.
222    pub fn for_each_matrix_mut(
223        &mut self,
224        axis: usize,
225        mut f: impl FnMut(usize, ArrayViewMut2<'_, T>) -> Result<()>,
226    ) -> Result<()> {
227        if axis >= 3 {
228            return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
229        }
230        for index in 0..self.shape[axis] {
231            f(index, self.matrix_at_mut(axis, index)?)?;
232        }
233        Ok(())
234    }
235
236    #[inline]
237    pub(crate) fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
238        (self.offset
239            + i as isize * self.strides[0]
240            + j as isize * self.strides[1]
241            + k as isize * self.strides[2]) as usize
242    }
243}
244
245impl<T> Index<(usize, usize, usize)> for ArrayView3<'_, T> {
246    type Output = T;
247
248    fn index(&self, index: (usize, usize, usize)) -> &Self::Output {
249        self.get(index.0, index.1, index.2)
250            .expect("view index out of bounds")
251    }
252}
253
254impl<T> Index<(usize, usize, usize)> for ArrayViewMut3<'_, T> {
255    type Output = T;
256
257    fn index(&self, index: (usize, usize, usize)) -> &Self::Output {
258        self.get(index.0, index.1, index.2)
259            .expect("view index out of bounds")
260    }
261}
262
263impl<T> IndexMut<(usize, usize, usize)> for ArrayViewMut3<'_, T> {
264    fn index_mut(&mut self, index: (usize, usize, usize)) -> &mut Self::Output {
265        self.get_mut(index.0, index.1, index.2)
266            .expect("view index out of bounds")
267    }
268}