1use core::ops::{Index, IndexMut};
4
5use crate::error::{Error, Result};
6use crate::view2::{ArrayView2, ArrayViewMut2, validate_view};
7
/// Read-only strided view over three-dimensional data held in a borrowed slice.
///
/// Element `(i, j, k)` lives at `data[offset + i * strides[0] + j * strides[1] + k * strides[2]]`.
///
/// The view is `Copy` for every element type `T`: it only holds a shared slice
/// reference plus plain layout metadata, so copying it never copies elements.
#[derive(Debug)]
pub struct ArrayView3<'a, T> {
    /// Backing storage; the view may address only part of it.
    pub(crate) data: &'a [T],
    /// Extent of each of the three axes.
    pub(crate) shape: [usize; 3],
    /// Signed element step taken along each axis.
    pub(crate) strides: [isize; 3],
    /// Position in `data` of element `(0, 0, 0)`.
    pub(crate) offset: isize,
}

// Manual impls instead of `#[derive(Clone, Copy)]`: the derive would require
// `T: Clone` / `T: Copy`, although every field is `Copy` regardless of `T`.
impl<T> Clone for ArrayView3<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ArrayView3<'_, T> {}
16
/// Mutable strided view over three-dimensional data held in a borrowed slice.
///
/// Element `(i, j, k)` is addressed as
/// `offset + i * strides[0] + j * strides[1] + k * strides[2]` within `data`.
#[derive(Debug)]
pub struct ArrayViewMut3<'a, T> {
    /// Exclusively borrowed backing storage; the view may address only part of it.
    pub(crate) data: &'a mut [T],
    /// Number of elements along each axis.
    pub(crate) shape: [usize; 3],
    /// Signed distance in `data` between neighbours along each axis.
    pub(crate) strides: [isize; 3],
    /// Index in `data` of element `(0, 0, 0)`.
    pub(crate) offset: isize,
}
25
26impl<'a, T> ArrayView3<'a, T> {
27 pub fn new(
29 data: &'a [T],
30 shape: [usize; 3],
31 strides: [isize; 3],
32 offset: isize,
33 ) -> Result<Self> {
34 validate_view(data.len(), &shape, &strides, offset)?;
35 Ok(Self {
36 data,
37 shape,
38 strides,
39 offset,
40 })
41 }
42
43 pub(crate) fn from_raw_parts(
44 data: &'a [T],
45 shape: [usize; 3],
46 strides: [isize; 3],
47 offset: isize,
48 ) -> Self {
49 Self {
50 data,
51 shape,
52 strides,
53 offset,
54 }
55 }
56
57 pub fn shape(&self) -> [usize; 3] {
59 self.shape
60 }
61
62 pub fn strides(&self) -> [isize; 3] {
64 self.strides
65 }
66
67 pub fn len(&self) -> usize {
69 self.shape.iter().product()
70 }
71
72 pub fn is_empty(&self) -> bool {
74 self.len() == 0
75 }
76
77 pub fn is_contiguous(&self) -> bool {
79 self.shape.contains(&0)
80 || (self.offset == 0
81 && self.strides
82 == [
83 (self.shape[1] * self.shape[2]) as isize,
84 self.shape[2] as isize,
85 1,
86 ]
87 && self.len() == self.data.len())
88 }
89
90 pub fn as_slice(&self) -> Option<&'a [T]> {
92 self.is_contiguous().then_some(self.data)
93 }
94
95 pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&'a T> {
97 (i < self.shape[0] && j < self.shape[1] && k < self.shape[2])
98 .then(|| &self.data[self.linear_index(i, j, k)])
99 }
100
101 pub fn matrix_at(&self, axis: usize, index: usize) -> Result<ArrayView2<'a, T>> {
103 if axis >= 3 {
104 return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
105 }
106 if index >= self.shape[axis] {
107 return Err(Error::IndexOutOfBounds);
108 }
109 let axes: Vec<usize> = (0..3).filter(|&candidate| candidate != axis).collect();
110 Ok(ArrayView2::from_raw_parts(
111 self.data,
112 [self.shape[axes[0]], self.shape[axes[1]]],
113 [self.strides[axes[0]], self.strides[axes[1]]],
114 self.offset + index as isize * self.strides[axis],
115 ))
116 }
117
118 pub fn for_each_matrix(
120 &self,
121 axis: usize,
122 mut f: impl FnMut(usize, ArrayView2<'a, T>) -> Result<()>,
123 ) -> Result<()> {
124 if axis >= 3 {
125 return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
126 }
127 for index in 0..self.shape[axis] {
128 f(index, self.matrix_at(axis, index)?)?;
129 }
130 Ok(())
131 }
132
133 #[inline]
134 pub(crate) fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
135 (self.offset
136 + i as isize * self.strides[0]
137 + j as isize * self.strides[1]
138 + k as isize * self.strides[2]) as usize
139 }
140}
141
142impl<'a, T> ArrayViewMut3<'a, T> {
143 pub fn new(
145 data: &'a mut [T],
146 shape: [usize; 3],
147 strides: [isize; 3],
148 offset: isize,
149 ) -> Result<Self> {
150 validate_view(data.len(), &shape, &strides, offset)?;
151 Ok(Self {
152 data,
153 shape,
154 strides,
155 offset,
156 })
157 }
158
159 pub(crate) fn from_raw_parts(
160 data: &'a mut [T],
161 shape: [usize; 3],
162 strides: [isize; 3],
163 offset: isize,
164 ) -> Self {
165 Self {
166 data,
167 shape,
168 strides,
169 offset,
170 }
171 }
172
173 pub fn shape(&self) -> [usize; 3] {
175 self.shape
176 }
177
178 pub fn as_view(&self) -> ArrayView3<'_, T> {
180 ArrayView3 {
181 data: self.data,
182 shape: self.shape,
183 strides: self.strides,
184 offset: self.offset,
185 }
186 }
187
188 pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&T> {
190 (i < self.shape[0] && j < self.shape[1] && k < self.shape[2])
191 .then(|| &self.data[self.linear_index(i, j, k)])
192 }
193
194 pub fn get_mut(&mut self, i: usize, j: usize, k: usize) -> Option<&mut T> {
196 if i >= self.shape[0] || j >= self.shape[1] || k >= self.shape[2] {
197 return None;
198 }
199 let index = self.linear_index(i, j, k);
200 Some(&mut self.data[index])
201 }
202
203 pub fn matrix_at_mut(&mut self, axis: usize, index: usize) -> Result<ArrayViewMut2<'_, T>> {
205 if axis >= 3 {
206 return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
207 }
208 if index >= self.shape[axis] {
209 return Err(Error::IndexOutOfBounds);
210 }
211 let axes: Vec<usize> = (0..3).filter(|&candidate| candidate != axis).collect();
212 let offset = self.offset + index as isize * self.strides[axis];
213 ArrayViewMut2::new(
214 &mut *self.data,
215 [self.shape[axes[0]], self.shape[axes[1]]],
216 [self.strides[axes[0]], self.strides[axes[1]]],
217 offset,
218 )
219 }
220
221 pub fn for_each_matrix_mut(
223 &mut self,
224 axis: usize,
225 mut f: impl FnMut(usize, ArrayViewMut2<'_, T>) -> Result<()>,
226 ) -> Result<()> {
227 if axis >= 3 {
228 return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
229 }
230 for index in 0..self.shape[axis] {
231 f(index, self.matrix_at_mut(axis, index)?)?;
232 }
233 Ok(())
234 }
235
236 #[inline]
237 pub(crate) fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
238 (self.offset
239 + i as isize * self.strides[0]
240 + j as isize * self.strides[1]
241 + k as isize * self.strides[2]) as usize
242 }
243}
244
245impl<T> Index<(usize, usize, usize)> for ArrayView3<'_, T> {
246 type Output = T;
247
248 fn index(&self, index: (usize, usize, usize)) -> &Self::Output {
249 self.get(index.0, index.1, index.2)
250 .expect("view index out of bounds")
251 }
252}
253
254impl<T> Index<(usize, usize, usize)> for ArrayViewMut3<'_, T> {
255 type Output = T;
256
257 fn index(&self, index: (usize, usize, usize)) -> &Self::Output {
258 self.get(index.0, index.1, index.2)
259 .expect("view index out of bounds")
260 }
261}
262
263impl<T> IndexMut<(usize, usize, usize)> for ArrayViewMut3<'_, T> {
264 fn index_mut(&mut self, index: (usize, usize, usize)) -> &mut Self::Output {
265 self.get_mut(index.0, index.1, index.2)
266 .expect("view index out of bounds")
267 }
268}