1use core::ops::{Index, IndexMut};
4
5use crate::array2::Array2;
6use crate::error::{Error, Result};
7use crate::numeric::Float;
8use crate::rand::SmallRng;
9use crate::view2::{ArrayView2, ArrayViewMut2};
10use crate::view3::{ArrayView3, ArrayViewMut3};
11
/// Selects one of the three axes of an [`Array3`].
///
/// Each variant's discriminant equals its axis number, so `Axis3::Axis1 as usize == 1`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Axis3 {
    /// Axis number 0 (the `i` index).
    Axis0 = 0,
    /// Axis number 1 (the `j` index).
    Axis1 = 1,
    /// Axis number 2 (the `k` index).
    Axis2 = 2,
}
22
23impl Axis3 {
24 pub(crate) fn index(self) -> usize {
25 match self {
26 Self::Axis0 => 0,
27 Self::Axis1 => 1,
28 Self::Axis2 => 2,
29 }
30 }
31}
32
/// Owned three-dimensional array stored contiguously in row-major order.
///
/// Element `(i, j, k)` lives at `data[(i * shape[1] + j) * shape[2] + k]`, so the last
/// axis is the contiguous one.
#[derive(Clone, Debug, PartialEq)]
pub struct Array3<T> {
    // Elements in row-major order. Expected invariant: `data.len()` equals the product of
    // `shape`; `get`/`get_mut` check indices only against `shape` before indexing `data`.
    data: Vec<T>,
    // Extent of each of the three axes.
    shape: [usize; 3],
}
39
40impl<T> Array3<T> {
41 pub fn from_vec(shape: [usize; 3], data: Vec<T>) -> Result<Self> {
43 let expected = shape
44 .iter()
45 .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
46 .ok_or(Error::DimensionTooLarge)?;
47 if data.len() != expected {
48 return Err(Error::shape(vec![expected], vec![data.len()]));
49 }
50 Ok(Self { data, shape })
51 }
52
53 pub fn from_fn(shape: [usize; 3], mut f: impl FnMut(usize, usize, usize) -> T) -> Self {
55 let mut data = Vec::with_capacity(shape.iter().product());
56 for i in 0..shape[0] {
57 for j in 0..shape[1] {
58 for k in 0..shape[2] {
59 data.push(f(i, j, k));
60 }
61 }
62 }
63 Self { data, shape }
64 }
65
66 pub fn try_from_fn(
68 shape: [usize; 3],
69 mut f: impl FnMut(usize, usize, usize) -> T,
70 ) -> Result<Self> {
71 let len = checked_len(shape)?;
72 let mut data = Vec::new();
73 data.try_reserve_exact(len)
74 .map_err(|_| Error::AllocationFailed)?;
75 for i in 0..shape[0] {
76 for j in 0..shape[1] {
77 for k in 0..shape[2] {
78 data.push(f(i, j, k));
79 }
80 }
81 }
82 Ok(Self { data, shape })
83 }
84
85 pub fn shape(&self) -> [usize; 3] {
87 self.shape
88 }
89
90 pub fn strides(&self) -> [isize; 3] {
92 [
93 (self.shape[1] * self.shape[2]) as isize,
94 self.shape[2] as isize,
95 1,
96 ]
97 }
98
99 pub fn len(&self) -> usize {
101 self.data.len()
102 }
103
104 pub fn is_empty(&self) -> bool {
106 self.data.is_empty()
107 }
108
109 pub fn as_slice(&self) -> &[T] {
111 &self.data
112 }
113
114 pub fn as_mut_slice(&mut self) -> &mut [T] {
116 &mut self.data
117 }
118
119 pub fn view(&self) -> ArrayView3<'_, T> {
121 ArrayView3::from_raw_parts(&self.data, self.shape, self.strides(), 0)
122 }
123
124 pub fn view_mut(&mut self) -> ArrayViewMut3<'_, T> {
126 ArrayViewMut3::from_raw_parts(
127 &mut self.data,
128 self.shape,
129 [
130 (self.shape[1] * self.shape[2]) as isize,
131 self.shape[2] as isize,
132 1,
133 ],
134 0,
135 )
136 }
137
138 pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&T> {
140 (i < self.shape[0] && j < self.shape[1] && k < self.shape[2])
141 .then(|| &self.data[self.linear_index(i, j, k)])
142 }
143
144 pub fn get_mut(&mut self, i: usize, j: usize, k: usize) -> Option<&mut T> {
146 if i >= self.shape[0] || j >= self.shape[1] || k >= self.shape[2] {
147 return None;
148 }
149 let idx = self.linear_index(i, j, k);
150 Some(&mut self.data[idx])
151 }
152
153 pub fn matrix_at(&self, axis: Axis3, index: usize) -> Result<ArrayView2<'_, T>> {
155 self.view().matrix_at(axis.index(), index)
156 }
157
158 pub fn for_each_matrix(
160 &self,
161 axis: Axis3,
162 f: impl FnMut(usize, ArrayView2<'_, T>) -> Result<()>,
163 ) -> Result<()> {
164 self.view().for_each_matrix(axis.index(), f)
165 }
166
167 pub fn matrix_at_mut(&mut self, axis: Axis3, index: usize) -> Result<ArrayViewMut2<'_, T>> {
169 let axis = axis.index();
170 if index >= self.shape[axis] {
171 return Err(Error::IndexOutOfBounds);
172 }
173 let strides = self.strides();
174 let axes: Vec<usize> = (0..3).filter(|&candidate| candidate != axis).collect();
175 ArrayViewMut2::new(
176 &mut self.data,
177 [self.shape[axes[0]], self.shape[axes[1]]],
178 [strides[axes[0]], strides[axes[1]]],
179 index as isize * strides[axis],
180 )
181 }
182
183 pub fn for_each_matrix_mut(
185 &mut self,
186 axis: Axis3,
187 mut f: impl FnMut(usize, ArrayViewMut2<'_, T>) -> Result<()>,
188 ) -> Result<()> {
189 let axis = axis.index();
190 if axis >= 3 {
191 return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
192 }
193 for index in 0..self.shape[axis] {
194 f(
195 index,
196 self.matrix_at_mut(
197 match axis {
198 0 => Axis3::Axis0,
199 1 => Axis3::Axis1,
200 _ => Axis3::Axis2,
201 },
202 index,
203 )?,
204 )?;
205 }
206 Ok(())
207 }
208
209 fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
210 (i * self.shape[1] + j) * self.shape[2] + k
211 }
212}
213
214impl<T: Clone> Array3<T> {
215 pub fn filled(shape: [usize; 3], value: T) -> Self {
217 Self {
218 data: vec![value; shape.iter().product()],
219 shape,
220 }
221 }
222
223 pub fn try_filled(shape: [usize; 3], value: T) -> Result<Self> {
225 let len = checked_len(shape)?;
226 let mut data = Vec::new();
227 data.try_reserve_exact(len)
228 .map_err(|_| Error::AllocationFailed)?;
229 data.resize(len, value);
230 Ok(Self { data, shape })
231 }
232
233 pub fn clone_contiguous(view: ArrayView3<'_, T>) -> Self {
235 Self::from_fn(view.shape(), |i, j, k| view[(i, j, k)].clone())
236 }
237
238 pub fn unfold(&self, axis: Axis3) -> Array2<T> {
243 unfold_view(self.view(), axis)
244 }
245
246 pub fn fold(axis: Axis3, shape: [usize; 3], matrix: ArrayView2<'_, T>) -> Result<Self> {
248 fold_view(axis, shape, matrix)
249 }
250}
251
252pub fn unfold_view<T: Clone>(a: ArrayView3<'_, T>, axis: Axis3) -> Array2<T> {
254 let shape = a.shape();
255 match axis {
256 Axis3::Axis0 => Array2::from_fn([shape[0], shape[1] * shape[2]], |row, col| {
257 let j = col / shape[2];
258 let k = col % shape[2];
259 a[(row, j, k)].clone()
260 }),
261 Axis3::Axis1 => Array2::from_fn([shape[1], shape[0] * shape[2]], |row, col| {
262 let i = col / shape[2];
263 let k = col % shape[2];
264 a[(i, row, k)].clone()
265 }),
266 Axis3::Axis2 => Array2::from_fn([shape[2], shape[0] * shape[1]], |row, col| {
267 let i = col / shape[1];
268 let j = col % shape[1];
269 a[(i, j, row)].clone()
270 }),
271 }
272}
273
274pub fn fold_view<T: Clone>(
276 axis: Axis3,
277 shape: [usize; 3],
278 matrix: ArrayView2<'_, T>,
279) -> Result<Array3<T>> {
280 let expected = match axis {
281 Axis3::Axis0 => [shape[0], shape[1] * shape[2]],
282 Axis3::Axis1 => [shape[1], shape[0] * shape[2]],
283 Axis3::Axis2 => [shape[2], shape[0] * shape[1]],
284 };
285 if matrix.shape() != expected {
286 return Err(Error::shape(expected, matrix.shape()));
287 }
288 Ok(Array3::from_fn(shape, |i, j, k| match axis {
289 Axis3::Axis0 => matrix[(i, j * shape[2] + k)].clone(),
290 Axis3::Axis1 => matrix[(j, i * shape[2] + k)].clone(),
291 Axis3::Axis2 => matrix[(k, i * shape[1] + j)].clone(),
292 }))
293}
294
295impl<T: Float> Array3<T> {
296 pub fn zeros(shape: [usize; 3]) -> Self {
298 Self::filled(shape, T::zero())
299 }
300
301 pub fn try_zeros(shape: [usize; 3]) -> Result<Self> {
303 Self::try_filled(shape, T::zero())
304 }
305
306 pub fn ones(shape: [usize; 3]) -> Self {
308 Self::filled(shape, T::one())
309 }
310
311 pub fn try_ones(shape: [usize; 3]) -> Result<Self> {
313 Self::try_filled(shape, T::one())
314 }
315
316 pub fn zeros_like(&self) -> Self {
318 Self::zeros(self.shape)
319 }
320
321 pub fn scale(&mut self, alpha: T) {
323 for value in &mut self.data {
324 *value *= alpha;
325 }
326 }
327
328 pub fn scaled(&self, alpha: T) -> Self {
330 Self::from_fn(self.shape, |i, j, k| self[(i, j, k)] * alpha)
331 }
332
333 pub fn scaled_into(&self, alpha: T, mut out: ArrayViewMut3<'_, T>) -> Result<()> {
335 if self.shape != out.shape() {
336 return Err(Error::shape(self.shape, out.shape()));
337 }
338 for i in 0..self.shape[0] {
339 for j in 0..self.shape[1] {
340 for k in 0..self.shape[2] {
341 out[(i, j, k)] = self[(i, j, k)] * alpha;
342 }
343 }
344 }
345 Ok(())
346 }
347
348 pub fn add(&self, other: ArrayView3<'_, T>) -> Result<Self> {
350 self.zip_map(other, |left, right| left + right)
351 }
352
353 pub fn add_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
355 self.zip_map_into(other, out, |left, right| left + right)
356 }
357
358 pub fn sub(&self, other: ArrayView3<'_, T>) -> Result<Self> {
360 self.zip_map(other, |left, right| left - right)
361 }
362
363 pub fn sub_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
365 self.zip_map_into(other, out, |left, right| left - right)
366 }
367
368 pub fn mul(&self, other: ArrayView3<'_, T>) -> Result<Self> {
370 self.zip_map(other, |left, right| left * right)
371 }
372
373 pub fn mul_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
375 self.zip_map_into(other, out, |left, right| left * right)
376 }
377
378 pub fn hadamard(&self, other: ArrayView3<'_, T>) -> Result<Self> {
380 self.mul(other)
381 }
382
383 pub fn hadamard_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
385 self.mul_into(other, out)
386 }
387
388 pub fn div(&self, other: ArrayView3<'_, T>) -> Result<Self> {
390 self.zip_map(other, |left, right| left / right)
391 }
392
393 pub fn div_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
395 self.zip_map_into(other, out, |left, right| left / right)
396 }
397
398 pub fn axpy_result(&self, alpha: T, x: ArrayView3<'_, T>) -> Result<Self> {
400 self.zip_map(x, |left, right| left + alpha * right)
401 }
402
403 pub fn axpy_into(
405 &self,
406 alpha: T,
407 x: ArrayView3<'_, T>,
408 out: ArrayViewMut3<'_, T>,
409 ) -> Result<()> {
410 self.zip_map_into(x, out, |left, right| left + alpha * right)
411 }
412
413 pub fn add_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
415 self.zip_map_inplace(other, |left, right| left + right)
416 }
417
418 pub fn sub_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
420 self.zip_map_inplace(other, |left, right| left - right)
421 }
422
423 pub fn mul_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
425 self.zip_map_inplace(other, |left, right| left * right)
426 }
427
428 pub fn div_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
430 self.zip_map_inplace(other, |left, right| left / right)
431 }
432
433 pub fn axpy(&mut self, alpha: T, x: ArrayView3<'_, T>) -> Result<()> {
435 if self.shape != x.shape() {
436 return Err(Error::shape(self.shape, x.shape()));
437 }
438 for i in 0..self.shape[0] {
439 for j in 0..self.shape[1] {
440 for k in 0..self.shape[2] {
441 self[(i, j, k)] += alpha * x[(i, j, k)];
442 }
443 }
444 }
445 Ok(())
446 }
447
448 pub fn zip_map(&self, other: ArrayView3<'_, T>, mut f: impl FnMut(T, T) -> T) -> Result<Self> {
450 if self.shape != other.shape() {
451 return Err(Error::shape(self.shape, other.shape()));
452 }
453 Ok(Self::from_fn(self.shape, |i, j, k| {
454 f(self[(i, j, k)], other[(i, j, k)])
455 }))
456 }
457
458 pub fn zip_map_into(
460 &self,
461 other: ArrayView3<'_, T>,
462 mut out: ArrayViewMut3<'_, T>,
463 mut f: impl FnMut(T, T) -> T,
464 ) -> Result<()> {
465 if self.shape != other.shape() {
466 return Err(Error::shape(self.shape, other.shape()));
467 }
468 if self.shape != out.shape() {
469 return Err(Error::shape(self.shape, out.shape()));
470 }
471 for i in 0..self.shape[0] {
472 for j in 0..self.shape[1] {
473 for k in 0..self.shape[2] {
474 out[(i, j, k)] = f(self[(i, j, k)], other[(i, j, k)]);
475 }
476 }
477 }
478 Ok(())
479 }
480
481 pub fn map_inplace(&mut self, mut f: impl FnMut(T) -> T) {
483 for value in &mut self.data {
484 *value = f(*value);
485 }
486 }
487
488 pub fn zip_map_inplace(
490 &mut self,
491 other: ArrayView3<'_, T>,
492 mut f: impl FnMut(T, T) -> T,
493 ) -> Result<()> {
494 if self.shape != other.shape() {
495 return Err(Error::shape(self.shape, other.shape()));
496 }
497 for i in 0..self.shape[0] {
498 for j in 0..self.shape[1] {
499 for k in 0..self.shape[2] {
500 self[(i, j, k)] = f(self[(i, j, k)], other[(i, j, k)]);
501 }
502 }
503 }
504 Ok(())
505 }
506
507 pub fn fill_uniform(&mut self, seed: u64) {
509 let mut rng = SmallRng::new(seed);
510 for value in &mut self.data {
511 *value = rng.uniform();
512 }
513 }
514
515 pub fn fill_randn(&mut self, seed: u64) {
517 let mut rng = SmallRng::new(seed);
518 for value in &mut self.data {
519 *value = rng.normal();
520 }
521 }
522
523 pub fn sum(&self) -> T {
525 self.data.iter().copied().sum()
526 }
527
528 pub fn mean(&self) -> T {
530 if self.is_empty() {
531 T::zero()
532 } else {
533 self.sum() / T::from_f64(self.len() as f64)
534 }
535 }
536
537 pub fn norm_frobenius(&self) -> T {
539 self.data
540 .iter()
541 .copied()
542 .map(|value| value * value)
543 .sum::<T>()
544 .sqrt()
545 }
546
547 pub fn max_abs(&self) -> T {
549 self.data
550 .iter()
551 .copied()
552 .map(T::abs)
553 .fold(
554 T::zero(),
555 |best, value| if value > best { value } else { best },
556 )
557 }
558
559 pub fn dot(&self, other: ArrayView3<'_, T>) -> Result<T> {
561 if self.shape != other.shape() {
562 return Err(Error::shape(self.shape, other.shape()));
563 }
564 let mut sum = T::zero();
565 for i in 0..self.shape[0] {
566 for j in 0..self.shape[1] {
567 for k in 0..self.shape[2] {
568 sum += self[(i, j, k)] * other[(i, j, k)];
569 }
570 }
571 }
572 Ok(sum)
573 }
574}
575
576impl<T> Index<(usize, usize, usize)> for Array3<T> {
577 type Output = T;
578
579 fn index(&self, index: (usize, usize, usize)) -> &Self::Output {
580 self.get(index.0, index.1, index.2)
581 .expect("array index out of bounds")
582 }
583}
584
585fn checked_len(shape: [usize; 3]) -> Result<usize> {
586 shape
587 .iter()
588 .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
589 .ok_or(Error::DimensionTooLarge)
590}
591
592impl<T> IndexMut<(usize, usize, usize)> for Array3<T> {
593 fn index_mut(&mut self, index: (usize, usize, usize)) -> &mut Self::Output {
594 self.get_mut(index.0, index.1, index.2)
595 .expect("array index out of bounds")
596 }
597}