Skip to main content

tensors/
array3.rs

1//! Owned three-dimensional row-major arrays.
2
3use core::ops::{Index, IndexMut};
4
5use crate::array2::Array2;
6use crate::error::{Error, Result};
7use crate::numeric::Float;
8use crate::rand::SmallRng;
9use crate::view2::{ArrayView2, ArrayViewMut2};
10use crate::view3::{ArrayView3, ArrayViewMut3};
11
/// Axis selector for 3D arrays.
///
/// Each variant's discriminant is the position of that axis within a
/// `[usize; 3]` shape or stride array, which is exactly what
/// [`Axis3::index`] reports.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Axis3 {
    /// First axis (`shape[0]`).
    Axis0 = 0,
    /// Second axis (`shape[1]`).
    Axis1 = 1,
    /// Third axis (`shape[2]`).
    Axis2 = 2,
}

impl Axis3 {
    /// Position of this axis within a `[usize; 3]` shape or stride array.
    pub(crate) fn index(self) -> usize {
        // The discriminants are pinned explicitly on the enum, so this cast is
        // exact and stays correct even if the variants are ever reordered.
        self as usize
    }
}
32
33/// Owned 3D row-major array.
34#[derive(Clone, Debug, PartialEq)]
35pub struct Array3<T> {
36    data: Vec<T>,
37    shape: [usize; 3],
38}
39
40impl<T> Array3<T> {
41    /// Build an array from row-major data.
42    pub fn from_vec(shape: [usize; 3], data: Vec<T>) -> Result<Self> {
43        let expected = shape
44            .iter()
45            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
46            .ok_or(Error::DimensionTooLarge)?;
47        if data.len() != expected {
48            return Err(Error::shape(vec![expected], vec![data.len()]));
49        }
50        Ok(Self { data, shape })
51    }
52
53    /// Build an array from a function over `(i, j, k)`.
54    pub fn from_fn(shape: [usize; 3], mut f: impl FnMut(usize, usize, usize) -> T) -> Self {
55        let mut data = Vec::with_capacity(shape.iter().product());
56        for i in 0..shape[0] {
57            for j in 0..shape[1] {
58                for k in 0..shape[2] {
59                    data.push(f(i, j, k));
60                }
61            }
62        }
63        Self { data, shape }
64    }
65
66    /// Fallibly build an array from a function over `(i, j, k)`.
67    pub fn try_from_fn(
68        shape: [usize; 3],
69        mut f: impl FnMut(usize, usize, usize) -> T,
70    ) -> Result<Self> {
71        let len = checked_len(shape)?;
72        let mut data = Vec::new();
73        data.try_reserve_exact(len)
74            .map_err(|_| Error::AllocationFailed)?;
75        for i in 0..shape[0] {
76            for j in 0..shape[1] {
77                for k in 0..shape[2] {
78                    data.push(f(i, j, k));
79                }
80            }
81        }
82        Ok(Self { data, shape })
83    }
84
85    /// Shape as `[dim0, dim1, dim2]`.
86    pub fn shape(&self) -> [usize; 3] {
87        self.shape
88    }
89
90    /// Row-major strides in elements.
91    pub fn strides(&self) -> [isize; 3] {
92        [
93            (self.shape[1] * self.shape[2]) as isize,
94            self.shape[2] as isize,
95            1,
96        ]
97    }
98
99    /// Number of elements.
100    pub fn len(&self) -> usize {
101        self.data.len()
102    }
103
104    /// Whether the array has zero elements.
105    pub fn is_empty(&self) -> bool {
106        self.data.is_empty()
107    }
108
109    /// Borrow the backing slice.
110    pub fn as_slice(&self) -> &[T] {
111        &self.data
112    }
113
114    /// Borrow the backing slice mutably.
115    pub fn as_mut_slice(&mut self) -> &mut [T] {
116        &mut self.data
117    }
118
119    /// Immutable strided view.
120    pub fn view(&self) -> ArrayView3<'_, T> {
121        ArrayView3::from_raw_parts(&self.data, self.shape, self.strides(), 0)
122    }
123
124    /// Mutable strided view.
125    pub fn view_mut(&mut self) -> ArrayViewMut3<'_, T> {
126        ArrayViewMut3::from_raw_parts(
127            &mut self.data,
128            self.shape,
129            [
130                (self.shape[1] * self.shape[2]) as isize,
131                self.shape[2] as isize,
132                1,
133            ],
134            0,
135        )
136    }
137
138    /// Get an element reference.
139    pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&T> {
140        (i < self.shape[0] && j < self.shape[1] && k < self.shape[2])
141            .then(|| &self.data[self.linear_index(i, j, k)])
142    }
143
144    /// Get a mutable element reference.
145    pub fn get_mut(&mut self, i: usize, j: usize, k: usize) -> Option<&mut T> {
146        if i >= self.shape[0] || j >= self.shape[1] || k >= self.shape[2] {
147            return None;
148        }
149        let idx = self.linear_index(i, j, k);
150        Some(&mut self.data[idx])
151    }
152
153    /// Extract a 2D matrix view by fixing one axis.
154    pub fn matrix_at(&self, axis: Axis3, index: usize) -> Result<ArrayView2<'_, T>> {
155        self.view().matrix_at(axis.index(), index)
156    }
157
158    /// Visit each 2D matrix slice along `axis` in order.
159    pub fn for_each_matrix(
160        &self,
161        axis: Axis3,
162        f: impl FnMut(usize, ArrayView2<'_, T>) -> Result<()>,
163    ) -> Result<()> {
164        self.view().for_each_matrix(axis.index(), f)
165    }
166
167    /// Extract a mutable 2D matrix view by fixing one axis.
168    pub fn matrix_at_mut(&mut self, axis: Axis3, index: usize) -> Result<ArrayViewMut2<'_, T>> {
169        let axis = axis.index();
170        if index >= self.shape[axis] {
171            return Err(Error::IndexOutOfBounds);
172        }
173        let strides = self.strides();
174        let axes: Vec<usize> = (0..3).filter(|&candidate| candidate != axis).collect();
175        ArrayViewMut2::new(
176            &mut self.data,
177            [self.shape[axes[0]], self.shape[axes[1]]],
178            [strides[axes[0]], strides[axes[1]]],
179            index as isize * strides[axis],
180        )
181    }
182
183    /// Visit each mutable 2D matrix slice along `axis` in order.
184    pub fn for_each_matrix_mut(
185        &mut self,
186        axis: Axis3,
187        mut f: impl FnMut(usize, ArrayViewMut2<'_, T>) -> Result<()>,
188    ) -> Result<()> {
189        let axis = axis.index();
190        if axis >= 3 {
191            return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
192        }
193        for index in 0..self.shape[axis] {
194            f(
195                index,
196                self.matrix_at_mut(
197                    match axis {
198                        0 => Axis3::Axis0,
199                        1 => Axis3::Axis1,
200                        _ => Axis3::Axis2,
201                    },
202                    index,
203                )?,
204            )?;
205        }
206        Ok(())
207    }
208
209    fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
210        (i * self.shape[1] + j) * self.shape[2] + k
211    }
212}
213
214impl<T: Clone> Array3<T> {
215    /// Fill a new array with `value`.
216    pub fn filled(shape: [usize; 3], value: T) -> Self {
217        Self {
218            data: vec![value; shape.iter().product()],
219            shape,
220        }
221    }
222
223    /// Fallibly fill a new array with `value`.
224    pub fn try_filled(shape: [usize; 3], value: T) -> Result<Self> {
225        let len = checked_len(shape)?;
226        let mut data = Vec::new();
227        data.try_reserve_exact(len)
228            .map_err(|_| Error::AllocationFailed)?;
229        data.resize(len, value);
230        Ok(Self { data, shape })
231    }
232
233    /// Clone into compact row-major storage.
234    pub fn clone_contiguous(view: ArrayView3<'_, T>) -> Self {
235        Self::from_fn(view.shape(), |i, j, k| view[(i, j, k)].clone())
236    }
237
238    /// Unfold the tensor along `axis` into a row-major matrix.
239    ///
240    /// For shape `[i, j, k]`, mode-0 unfolding has shape `[i, j * k]`,
241    /// mode-1 has shape `[j, i * k]`, and mode-2 has shape `[k, i * j]`.
242    pub fn unfold(&self, axis: Axis3) -> Array2<T> {
243        unfold_view(self.view(), axis)
244    }
245
246    /// Fold a mode-unfolded matrix back into a tensor with `shape`.
247    pub fn fold(axis: Axis3, shape: [usize; 3], matrix: ArrayView2<'_, T>) -> Result<Self> {
248        fold_view(axis, shape, matrix)
249    }
250}
251
252/// Unfold a 3D tensor view along `axis` into a row-major matrix.
253pub fn unfold_view<T: Clone>(a: ArrayView3<'_, T>, axis: Axis3) -> Array2<T> {
254    let shape = a.shape();
255    match axis {
256        Axis3::Axis0 => Array2::from_fn([shape[0], shape[1] * shape[2]], |row, col| {
257            let j = col / shape[2];
258            let k = col % shape[2];
259            a[(row, j, k)].clone()
260        }),
261        Axis3::Axis1 => Array2::from_fn([shape[1], shape[0] * shape[2]], |row, col| {
262            let i = col / shape[2];
263            let k = col % shape[2];
264            a[(i, row, k)].clone()
265        }),
266        Axis3::Axis2 => Array2::from_fn([shape[2], shape[0] * shape[1]], |row, col| {
267            let i = col / shape[1];
268            let j = col % shape[1];
269            a[(i, j, row)].clone()
270        }),
271    }
272}
273
274/// Fold a mode-unfolded matrix back into a 3D tensor with `shape`.
275pub fn fold_view<T: Clone>(
276    axis: Axis3,
277    shape: [usize; 3],
278    matrix: ArrayView2<'_, T>,
279) -> Result<Array3<T>> {
280    let expected = match axis {
281        Axis3::Axis0 => [shape[0], shape[1] * shape[2]],
282        Axis3::Axis1 => [shape[1], shape[0] * shape[2]],
283        Axis3::Axis2 => [shape[2], shape[0] * shape[1]],
284    };
285    if matrix.shape() != expected {
286        return Err(Error::shape(expected, matrix.shape()));
287    }
288    Ok(Array3::from_fn(shape, |i, j, k| match axis {
289        Axis3::Axis0 => matrix[(i, j * shape[2] + k)].clone(),
290        Axis3::Axis1 => matrix[(j, i * shape[2] + k)].clone(),
291        Axis3::Axis2 => matrix[(k, i * shape[1] + j)].clone(),
292    }))
293}
294
295impl<T: Float> Array3<T> {
296    /// Array filled with zeros.
297    pub fn zeros(shape: [usize; 3]) -> Self {
298        Self::filled(shape, T::zero())
299    }
300
301    /// Fallibly allocate an array filled with zeros.
302    pub fn try_zeros(shape: [usize; 3]) -> Result<Self> {
303        Self::try_filled(shape, T::zero())
304    }
305
306    /// Array filled with ones.
307    pub fn ones(shape: [usize; 3]) -> Self {
308        Self::filled(shape, T::one())
309    }
310
311    /// Fallibly allocate an array filled with ones.
312    pub fn try_ones(shape: [usize; 3]) -> Result<Self> {
313        Self::try_filled(shape, T::one())
314    }
315
316    /// Allocate another array with the same shape, filled with zeros.
317    pub fn zeros_like(&self) -> Self {
318        Self::zeros(self.shape)
319    }
320
321    /// Scale in place.
322    pub fn scale(&mut self, alpha: T) {
323        for value in &mut self.data {
324            *value *= alpha;
325        }
326    }
327
328    /// Return `self * alpha` without modifying `self`.
329    pub fn scaled(&self, alpha: T) -> Self {
330        Self::from_fn(self.shape, |i, j, k| self[(i, j, k)] * alpha)
331    }
332
333    /// Write `self * alpha` into `out` without modifying `self`.
334    pub fn scaled_into(&self, alpha: T, mut out: ArrayViewMut3<'_, T>) -> Result<()> {
335        if self.shape != out.shape() {
336            return Err(Error::shape(self.shape, out.shape()));
337        }
338        for i in 0..self.shape[0] {
339            for j in 0..self.shape[1] {
340                for k in 0..self.shape[2] {
341                    out[(i, j, k)] = self[(i, j, k)] * alpha;
342                }
343            }
344        }
345        Ok(())
346    }
347
348    /// Return `self + other` without modifying either input.
349    pub fn add(&self, other: ArrayView3<'_, T>) -> Result<Self> {
350        self.zip_map(other, |left, right| left + right)
351    }
352
353    /// Write `self + other` into `out` without modifying either input.
354    pub fn add_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
355        self.zip_map_into(other, out, |left, right| left + right)
356    }
357
358    /// Return `self - other` without modifying either input.
359    pub fn sub(&self, other: ArrayView3<'_, T>) -> Result<Self> {
360        self.zip_map(other, |left, right| left - right)
361    }
362
363    /// Write `self - other` into `out` without modifying either input.
364    pub fn sub_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
365        self.zip_map_into(other, out, |left, right| left - right)
366    }
367
368    /// Return the elementwise product `self * other` without modifying either input.
369    pub fn mul(&self, other: ArrayView3<'_, T>) -> Result<Self> {
370        self.zip_map(other, |left, right| left * right)
371    }
372
373    /// Write the elementwise product into `out` without modifying either input.
374    pub fn mul_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
375        self.zip_map_into(other, out, |left, right| left * right)
376    }
377
378    /// Return the Hadamard product without modifying either input.
379    pub fn hadamard(&self, other: ArrayView3<'_, T>) -> Result<Self> {
380        self.mul(other)
381    }
382
383    /// Write the Hadamard product into `out` without modifying either input.
384    pub fn hadamard_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
385        self.mul_into(other, out)
386    }
387
388    /// Return the elementwise quotient `self / other` without modifying either input.
389    pub fn div(&self, other: ArrayView3<'_, T>) -> Result<Self> {
390        self.zip_map(other, |left, right| left / right)
391    }
392
393    /// Write the elementwise quotient into `out` without modifying either input.
394    pub fn div_into(&self, other: ArrayView3<'_, T>, out: ArrayViewMut3<'_, T>) -> Result<()> {
395        self.zip_map_into(other, out, |left, right| left / right)
396    }
397
398    /// Return `self + alpha * x` without modifying either input.
399    pub fn axpy_result(&self, alpha: T, x: ArrayView3<'_, T>) -> Result<Self> {
400        self.zip_map(x, |left, right| left + alpha * right)
401    }
402
403    /// Write `self + alpha * x` into `out` without modifying either input.
404    pub fn axpy_into(
405        &self,
406        alpha: T,
407        x: ArrayView3<'_, T>,
408        out: ArrayViewMut3<'_, T>,
409    ) -> Result<()> {
410        self.zip_map_into(x, out, |left, right| left + alpha * right)
411    }
412
413    /// In-place addition.
414    pub fn add_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
415        self.zip_map_inplace(other, |left, right| left + right)
416    }
417
418    /// In-place subtraction.
419    pub fn sub_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
420        self.zip_map_inplace(other, |left, right| left - right)
421    }
422
423    /// Hadamard product in place.
424    pub fn mul_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
425        self.zip_map_inplace(other, |left, right| left * right)
426    }
427
428    /// In-place elementwise division.
429    pub fn div_assign_view(&mut self, other: ArrayView3<'_, T>) -> Result<()> {
430        self.zip_map_inplace(other, |left, right| left / right)
431    }
432
433    /// Compute `self += alpha * x`.
434    pub fn axpy(&mut self, alpha: T, x: ArrayView3<'_, T>) -> Result<()> {
435        if self.shape != x.shape() {
436            return Err(Error::shape(self.shape, x.shape()));
437        }
438        for i in 0..self.shape[0] {
439            for j in 0..self.shape[1] {
440                for k in 0..self.shape[2] {
441                    self[(i, j, k)] += alpha * x[(i, j, k)];
442                }
443            }
444        }
445        Ok(())
446    }
447
448    /// Return an elementwise zip-map without modifying either input.
449    pub fn zip_map(&self, other: ArrayView3<'_, T>, mut f: impl FnMut(T, T) -> T) -> Result<Self> {
450        if self.shape != other.shape() {
451            return Err(Error::shape(self.shape, other.shape()));
452        }
453        Ok(Self::from_fn(self.shape, |i, j, k| {
454            f(self[(i, j, k)], other[(i, j, k)])
455        }))
456    }
457
458    /// Write an elementwise zip-map into `out` without modifying either input.
459    pub fn zip_map_into(
460        &self,
461        other: ArrayView3<'_, T>,
462        mut out: ArrayViewMut3<'_, T>,
463        mut f: impl FnMut(T, T) -> T,
464    ) -> Result<()> {
465        if self.shape != other.shape() {
466            return Err(Error::shape(self.shape, other.shape()));
467        }
468        if self.shape != out.shape() {
469            return Err(Error::shape(self.shape, out.shape()));
470        }
471        for i in 0..self.shape[0] {
472            for j in 0..self.shape[1] {
473                for k in 0..self.shape[2] {
474                    out[(i, j, k)] = f(self[(i, j, k)], other[(i, j, k)]);
475                }
476            }
477        }
478        Ok(())
479    }
480
481    /// Map values in place.
482    pub fn map_inplace(&mut self, mut f: impl FnMut(T) -> T) {
483        for value in &mut self.data {
484            *value = f(*value);
485        }
486    }
487
488    /// Zip-map another view into this array in place.
489    pub fn zip_map_inplace(
490        &mut self,
491        other: ArrayView3<'_, T>,
492        mut f: impl FnMut(T, T) -> T,
493    ) -> Result<()> {
494        if self.shape != other.shape() {
495            return Err(Error::shape(self.shape, other.shape()));
496        }
497        for i in 0..self.shape[0] {
498            for j in 0..self.shape[1] {
499                for k in 0..self.shape[2] {
500                    self[(i, j, k)] = f(self[(i, j, k)], other[(i, j, k)]);
501                }
502            }
503        }
504        Ok(())
505    }
506
507    /// Fill with deterministic uniform random values in `[0, 1)`.
508    pub fn fill_uniform(&mut self, seed: u64) {
509        let mut rng = SmallRng::new(seed);
510        for value in &mut self.data {
511            *value = rng.uniform();
512        }
513    }
514
515    /// Fill with deterministic standard-normal random values.
516    pub fn fill_randn(&mut self, seed: u64) {
517        let mut rng = SmallRng::new(seed);
518        for value in &mut self.data {
519            *value = rng.normal();
520        }
521    }
522
523    /// Sum all elements.
524    pub fn sum(&self) -> T {
525        self.data.iter().copied().sum()
526    }
527
528    /// Mean of all elements, or zero for an empty array.
529    pub fn mean(&self) -> T {
530        if self.is_empty() {
531            T::zero()
532        } else {
533            self.sum() / T::from_f64(self.len() as f64)
534        }
535    }
536
537    /// Frobenius norm.
538    pub fn norm_frobenius(&self) -> T {
539        self.data
540            .iter()
541            .copied()
542            .map(|value| value * value)
543            .sum::<T>()
544            .sqrt()
545    }
546
547    /// Maximum absolute element, or zero for an empty array.
548    pub fn max_abs(&self) -> T {
549        self.data
550            .iter()
551            .copied()
552            .map(T::abs)
553            .fold(
554                T::zero(),
555                |best, value| if value > best { value } else { best },
556            )
557    }
558
559    /// Dot product of two tensors treated as flattened row-major buffers.
560    pub fn dot(&self, other: ArrayView3<'_, T>) -> Result<T> {
561        if self.shape != other.shape() {
562            return Err(Error::shape(self.shape, other.shape()));
563        }
564        let mut sum = T::zero();
565        for i in 0..self.shape[0] {
566            for j in 0..self.shape[1] {
567                for k in 0..self.shape[2] {
568                    sum += self[(i, j, k)] * other[(i, j, k)];
569                }
570            }
571        }
572        Ok(sum)
573    }
574}
575
576impl<T> Index<(usize, usize, usize)> for Array3<T> {
577    type Output = T;
578
579    fn index(&self, index: (usize, usize, usize)) -> &Self::Output {
580        self.get(index.0, index.1, index.2)
581            .expect("array index out of bounds")
582    }
583}
584
585fn checked_len(shape: [usize; 3]) -> Result<usize> {
586    shape
587        .iter()
588        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
589        .ok_or(Error::DimensionTooLarge)
590}
591
592impl<T> IndexMut<(usize, usize, usize)> for Array3<T> {
593    fn index_mut(&mut self, index: (usize, usize, usize)) -> &mut Self::Output {
594        self.get_mut(index.0, index.1, index.2)
595            .expect("array index out of bounds")
596    }
597}