Skip to main content

tensors/
array2.rs

1//! Owned two-dimensional row-major arrays.
2
3use core::ops::{Index, IndexMut};
4
5use crate::error::{Error, Result};
6use crate::numeric::Float;
7use crate::rand::SmallRng;
8use crate::view2::{ArrayView2, ArrayViewMut2};
9
/// Owned, dense two-dimensional array stored in row-major order.
///
/// Element `(row, col)` is stored at `data[row * cols + col]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Array2<T> {
    /// Row-major element buffer.
    data: Vec<T>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
17
18impl<T> Array2<T> {
19    /// Build an array from row-major data.
20    pub fn from_vec(shape: [usize; 2], data: Vec<T>) -> Result<Self> {
21        let expected = shape[0]
22            .checked_mul(shape[1])
23            .ok_or(Error::DimensionTooLarge)?;
24        if data.len() != expected {
25            return Err(Error::shape(vec![expected], vec![data.len()]));
26        }
27        Ok(Self {
28            data,
29            rows: shape[0],
30            cols: shape[1],
31        })
32    }
33
34    /// Build an array from a function over `(row, col)`.
35    pub fn from_fn(shape: [usize; 2], mut f: impl FnMut(usize, usize) -> T) -> Self {
36        let len = shape[0] * shape[1];
37        let mut data = Vec::with_capacity(len);
38        for i in 0..shape[0] {
39            for j in 0..shape[1] {
40                data.push(f(i, j));
41            }
42        }
43        Self {
44            data,
45            rows: shape[0],
46            cols: shape[1],
47        }
48    }
49
50    /// Fallibly build an array from a function over `(row, col)`.
51    pub fn try_from_fn(shape: [usize; 2], mut f: impl FnMut(usize, usize) -> T) -> Result<Self> {
52        let len = checked_len(shape)?;
53        let mut data = Vec::new();
54        data.try_reserve_exact(len)
55            .map_err(|_| Error::AllocationFailed)?;
56        for i in 0..shape[0] {
57            for j in 0..shape[1] {
58                data.push(f(i, j));
59            }
60        }
61        Ok(Self {
62            data,
63            rows: shape[0],
64            cols: shape[1],
65        })
66    }
67
68    /// Shape as `[rows, cols]`.
69    #[inline]
70    pub fn shape(&self) -> [usize; 2] {
71        [self.rows, self.cols]
72    }
73
74    /// Number of rows.
75    #[inline]
76    pub fn rows(&self) -> usize {
77        self.rows
78    }
79
80    /// Number of columns.
81    #[inline]
82    pub fn cols(&self) -> usize {
83        self.cols
84    }
85
86    /// Row-major strides in elements.
87    #[inline]
88    pub fn strides(&self) -> [isize; 2] {
89        [self.cols as isize, 1]
90    }
91
92    /// Distance in elements between consecutive rows.
93    #[inline]
94    pub fn row_stride(&self) -> isize {
95        self.cols as isize
96    }
97
98    /// Distance in elements between consecutive columns.
99    #[inline]
100    pub fn col_stride(&self) -> isize {
101        1
102    }
103
104    /// Leading dimension for row-major storage.
105    #[inline]
106    pub fn leading_dimension(&self) -> isize {
107        self.cols as isize
108    }
109
110    /// Number of elements.
111    #[inline]
112    pub fn len(&self) -> usize {
113        self.data.len()
114    }
115
116    /// Whether the array has zero elements.
117    #[inline]
118    pub fn is_empty(&self) -> bool {
119        self.data.is_empty()
120    }
121
122    /// Owned arrays are compact row-major contiguous.
123    #[inline]
124    pub fn is_contiguous(&self) -> bool {
125        true
126    }
127
128    /// Borrow the row-major backing slice.
129    #[inline]
130    pub fn as_slice(&self) -> &[T] {
131        &self.data
132    }
133
134    /// Borrow the row-major backing slice mutably.
135    #[inline]
136    pub fn as_mut_slice(&mut self) -> &mut [T] {
137        &mut self.data
138    }
139
140    /// Consume into the backing vector.
141    pub fn into_vec(self) -> Vec<T> {
142        self.data
143    }
144
145    /// Immutable strided view.
146    pub fn view(&self) -> ArrayView2<'_, T> {
147        ArrayView2::from_raw_parts(&self.data, self.shape(), self.strides(), 0)
148    }
149
150    /// Mutable strided view.
151    pub fn view_mut(&mut self) -> ArrayViewMut2<'_, T> {
152        ArrayViewMut2::from_raw_parts(
153            &mut self.data,
154            [self.rows, self.cols],
155            [self.cols as isize, 1],
156            0,
157        )
158    }
159
160    /// Transpose view without copying.
161    pub fn transpose_view(&self) -> ArrayView2<'_, T> {
162        self.view().transpose()
163    }
164
165    /// Get an element reference.
166    #[inline]
167    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
168        (row < self.rows && col < self.cols).then(|| &self.data[row * self.cols + col])
169    }
170
171    /// Get a mutable element reference.
172    #[inline]
173    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
174        (row < self.rows && col < self.cols).then(|| &mut self.data[row * self.cols + col])
175    }
176
177    /// Row as a one-row matrix view.
178    pub fn row(&self, row: usize) -> Result<ArrayView2<'_, T>> {
179        self.view().row(row)
180    }
181
182    /// Borrow a row as a contiguous slice.
183    pub fn row_slice(&self, row: usize) -> Result<&[T]> {
184        if row >= self.rows {
185            return Err(Error::IndexOutOfBounds);
186        }
187        let start = row * self.cols;
188        Ok(&self.data[start..start + self.cols])
189    }
190
191    /// Mutable row as a one-row matrix view.
192    pub fn row_mut(&mut self, row: usize) -> Result<ArrayViewMut2<'_, T>> {
193        if row >= self.rows {
194            return Err(Error::IndexOutOfBounds);
195        }
196        Ok(ArrayViewMut2::from_raw_parts(
197            &mut self.data,
198            [1, self.cols],
199            [self.cols as isize, 1],
200            (row * self.cols) as isize,
201        ))
202    }
203
204    /// Borrow a row as a mutable contiguous slice.
205    pub fn row_slice_mut(&mut self, row: usize) -> Result<&mut [T]> {
206        if row >= self.rows {
207            return Err(Error::IndexOutOfBounds);
208        }
209        let start = row * self.cols;
210        Ok(&mut self.data[start..start + self.cols])
211    }
212
213    /// Column as an `rows x 1` matrix view.
214    pub fn col(&self, col: usize) -> Result<ArrayView2<'_, T>> {
215        self.view().col(col)
216    }
217
218    /// Mutable column as an `rows x 1` matrix view.
219    pub fn col_mut(&mut self, col: usize) -> Result<ArrayViewMut2<'_, T>> {
220        if col >= self.cols {
221            return Err(Error::IndexOutOfBounds);
222        }
223        Ok(ArrayViewMut2::from_raw_parts(
224            &mut self.data,
225            [self.rows, 1],
226            [self.cols as isize, 1],
227            col as isize,
228        ))
229    }
230
231    /// Half-open row slice.
232    pub fn rows_range(&self, start: usize, end: usize) -> Result<ArrayView2<'_, T>> {
233        self.view().rows_range(start, end)
234    }
235
236    /// Mutable half-open row slice.
237    pub fn rows_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
238        if start > end || end > self.rows {
239            return Err(Error::IndexOutOfBounds);
240        }
241        Ok(ArrayViewMut2::from_raw_parts(
242            &mut self.data,
243            [end - start, self.cols],
244            [self.cols as isize, 1],
245            (start * self.cols) as isize,
246        ))
247    }
248
249    /// Half-open column slice.
250    pub fn cols_range(&self, start: usize, end: usize) -> Result<ArrayView2<'_, T>> {
251        self.view().cols_range(start, end)
252    }
253
254    /// Mutable half-open column slice.
255    pub fn cols_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
256        if start > end || end > self.cols {
257            return Err(Error::IndexOutOfBounds);
258        }
259        Ok(ArrayViewMut2::from_raw_parts(
260            &mut self.data,
261            [self.rows, end - start],
262            [self.cols as isize, 1],
263            start as isize,
264        ))
265    }
266
267    /// Reshape without changing storage order.
268    pub fn reshape(mut self, shape: [usize; 2]) -> Result<Self> {
269        let expected = shape[0]
270            .checked_mul(shape[1])
271            .ok_or(Error::DimensionTooLarge)?;
272        if expected != self.data.len() {
273            return Err(Error::shape(vec![self.data.len()], vec![expected]));
274        }
275        self.rows = shape[0];
276        self.cols = shape[1];
277        Ok(self)
278    }
279}
280
281impl<T: Clone> Array2<T> {
282    /// Fill a new array with `value`.
283    pub fn filled(shape: [usize; 2], value: T) -> Self {
284        Self {
285            data: vec![value; shape[0] * shape[1]],
286            rows: shape[0],
287            cols: shape[1],
288        }
289    }
290
291    /// Fallibly fill a new array with `value`.
292    pub fn try_filled(shape: [usize; 2], value: T) -> Result<Self> {
293        let len = checked_len(shape)?;
294        let mut data = Vec::new();
295        data.try_reserve_exact(len)
296            .map_err(|_| Error::AllocationFailed)?;
297        data.resize(len, value);
298        Ok(Self {
299            data,
300            rows: shape[0],
301            cols: shape[1],
302        })
303    }
304
305    /// Clone into compact row-major storage.
306    pub fn clone_contiguous(view: ArrayView2<'_, T>) -> Self {
307        Self::from_fn(view.shape(), |i, j| view[(i, j)].clone())
308    }
309
310    /// Copy this array into compact row-major storage.
311    pub fn to_row_major(&self) -> Self {
312        self.clone()
313    }
314
315    /// Copy this array into a column-major vector.
316    pub fn to_col_major_vec(&self) -> Vec<T> {
317        self.view().to_col_major_vec()
318    }
319
320    /// Copy values from another view with the same shape.
321    pub fn copy_from_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
322        if self.shape() != other.shape() {
323            return Err(Error::shape(self.shape(), other.shape()));
324        }
325        for i in 0..self.rows {
326            for j in 0..self.cols {
327                self[(i, j)] = other[(i, j)].clone();
328            }
329        }
330        Ok(())
331    }
332}
333
334impl<T: Float> Array2<T> {
335    /// Array filled with zeros.
336    pub fn zeros(shape: [usize; 2]) -> Self {
337        Self::filled(shape, T::zero())
338    }
339
340    /// Fallibly allocate an array filled with zeros.
341    pub fn try_zeros(shape: [usize; 2]) -> Result<Self> {
342        Self::try_filled(shape, T::zero())
343    }
344
345    /// Array filled with ones.
346    pub fn ones(shape: [usize; 2]) -> Self {
347        Self::filled(shape, T::one())
348    }
349
350    /// Fallibly allocate an array filled with ones.
351    pub fn try_ones(shape: [usize; 2]) -> Result<Self> {
352        Self::try_filled(shape, T::one())
353    }
354
355    /// Allocate another array with the same shape, filled with zeros.
356    pub fn zeros_like(&self) -> Self {
357        Self::zeros(self.shape())
358    }
359
360    /// Scale in place.
361    pub fn scale(&mut self, alpha: T) {
362        for value in &mut self.data {
363            *value *= alpha;
364        }
365    }
366
367    /// Return `self * alpha` without modifying `self`.
368    pub fn scaled(&self, alpha: T) -> Self {
369        Self::from_fn(self.shape(), |i, j| self[(i, j)] * alpha)
370    }
371
372    /// Write `self * alpha` into `out` without modifying `self`.
373    pub fn scaled_into(&self, alpha: T, mut out: ArrayViewMut2<'_, T>) -> Result<()> {
374        if self.shape() != out.shape() {
375            return Err(Error::shape(self.shape(), out.shape()));
376        }
377        for i in 0..self.rows {
378            for j in 0..self.cols {
379                out[(i, j)] = self[(i, j)] * alpha;
380            }
381        }
382        Ok(())
383    }
384
385    /// Return `self + other` without modifying either input.
386    pub fn add(&self, other: ArrayView2<'_, T>) -> Result<Self> {
387        self.zip_map(other, |left, right| left + right)
388    }
389
390    /// Write `self + other` into `out` without modifying either input.
391    pub fn add_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
392        self.zip_map_into(other, out, |left, right| left + right)
393    }
394
395    /// Return `self - other` without modifying either input.
396    pub fn sub(&self, other: ArrayView2<'_, T>) -> Result<Self> {
397        self.zip_map(other, |left, right| left - right)
398    }
399
400    /// Write `self - other` into `out` without modifying either input.
401    pub fn sub_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
402        self.zip_map_into(other, out, |left, right| left - right)
403    }
404
405    /// Return the elementwise product `self * other` without modifying either input.
406    pub fn mul(&self, other: ArrayView2<'_, T>) -> Result<Self> {
407        self.zip_map(other, |left, right| left * right)
408    }
409
410    /// Write the elementwise product into `out` without modifying either input.
411    pub fn mul_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
412        self.zip_map_into(other, out, |left, right| left * right)
413    }
414
415    /// Return the Hadamard product without modifying either input.
416    pub fn hadamard(&self, other: ArrayView2<'_, T>) -> Result<Self> {
417        self.mul(other)
418    }
419
420    /// Write the Hadamard product into `out` without modifying either input.
421    pub fn hadamard_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
422        self.mul_into(other, out)
423    }
424
425    /// Return the elementwise quotient `self / other` without modifying either input.
426    pub fn div(&self, other: ArrayView2<'_, T>) -> Result<Self> {
427        self.zip_map(other, |left, right| left / right)
428    }
429
430    /// Write the elementwise quotient into `out` without modifying either input.
431    pub fn div_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
432        self.zip_map_into(other, out, |left, right| left / right)
433    }
434
435    /// Return `self + alpha * x` without modifying either input.
436    pub fn axpy_result(&self, alpha: T, x: ArrayView2<'_, T>) -> Result<Self> {
437        self.zip_map(x, |left, right| left + alpha * right)
438    }
439
440    /// Write `self + alpha * x` into `out` without modifying either input.
441    pub fn axpy_into(
442        &self,
443        alpha: T,
444        x: ArrayView2<'_, T>,
445        out: ArrayViewMut2<'_, T>,
446    ) -> Result<()> {
447        self.zip_map_into(x, out, |left, right| left + alpha * right)
448    }
449
450    /// Return a matrix product `self @ other` without modifying either input.
451    pub fn matmul(&self, other: ArrayView2<'_, T>) -> Result<Self> {
452        crate::linalg::matmul(self.view(), other)
453    }
454
455    /// Write matrix product `self @ other` into `out` without modifying either input.
456    pub fn matmul_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
457        crate::linalg::gemm(T::one(), self.view(), false, other, false, T::zero(), out)
458    }
459
460    /// In-place addition.
461    pub fn add_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
462        if self.shape() != other.shape() {
463            return Err(Error::shape(self.shape(), other.shape()));
464        }
465        for i in 0..self.rows {
466            for j in 0..self.cols {
467                self[(i, j)] += other[(i, j)];
468            }
469        }
470        Ok(())
471    }
472
473    /// In-place subtraction.
474    pub fn sub_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
475        if self.shape() != other.shape() {
476            return Err(Error::shape(self.shape(), other.shape()));
477        }
478        for i in 0..self.rows {
479            for j in 0..self.cols {
480                self[(i, j)] -= other[(i, j)];
481            }
482        }
483        Ok(())
484    }
485
486    /// Hadamard product in place.
487    pub fn mul_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
488        if self.shape() != other.shape() {
489            return Err(Error::shape(self.shape(), other.shape()));
490        }
491        for i in 0..self.rows {
492            for j in 0..self.cols {
493                self[(i, j)] *= other[(i, j)];
494            }
495        }
496        Ok(())
497    }
498
499    /// In-place elementwise division.
500    pub fn div_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
501        if self.shape() != other.shape() {
502            return Err(Error::shape(self.shape(), other.shape()));
503        }
504        for i in 0..self.rows {
505            for j in 0..self.cols {
506                self[(i, j)] /= other[(i, j)];
507            }
508        }
509        Ok(())
510    }
511
512    /// Compute `self += alpha * x`.
513    pub fn axpy(&mut self, alpha: T, x: ArrayView2<'_, T>) -> Result<()> {
514        if self.shape() != x.shape() {
515            return Err(Error::shape(self.shape(), x.shape()));
516        }
517        for i in 0..self.rows {
518            for j in 0..self.cols {
519                self[(i, j)] += alpha * x[(i, j)];
520            }
521        }
522        Ok(())
523    }
524
525    /// Map values in place.
526    pub fn map_inplace(&mut self, mut f: impl FnMut(T) -> T) {
527        for value in &mut self.data {
528            *value = f(*value);
529        }
530    }
531
532    /// Zip-map another view into this array in place.
533    pub fn zip_map_inplace(
534        &mut self,
535        other: ArrayView2<'_, T>,
536        mut f: impl FnMut(T, T) -> T,
537    ) -> Result<()> {
538        if self.shape() != other.shape() {
539            return Err(Error::shape(self.shape(), other.shape()));
540        }
541        for i in 0..self.rows {
542            for j in 0..self.cols {
543                self[(i, j)] = f(self[(i, j)], other[(i, j)]);
544            }
545        }
546        Ok(())
547    }
548
549    /// Return an elementwise zip-map without modifying either input.
550    pub fn zip_map(&self, other: ArrayView2<'_, T>, mut f: impl FnMut(T, T) -> T) -> Result<Self> {
551        if self.shape() != other.shape() {
552            return Err(Error::shape(self.shape(), other.shape()));
553        }
554        Ok(Self::from_fn(self.shape(), |i, j| {
555            f(self[(i, j)], other[(i, j)])
556        }))
557    }
558
559    /// Write an elementwise zip-map into `out` without modifying either input.
560    pub fn zip_map_into(
561        &self,
562        other: ArrayView2<'_, T>,
563        mut out: ArrayViewMut2<'_, T>,
564        mut f: impl FnMut(T, T) -> T,
565    ) -> Result<()> {
566        if self.shape() != other.shape() {
567            return Err(Error::shape(self.shape(), other.shape()));
568        }
569        if self.shape() != out.shape() {
570            return Err(Error::shape(self.shape(), out.shape()));
571        }
572        for i in 0..self.rows {
573            for j in 0..self.cols {
574                out[(i, j)] = f(self[(i, j)], other[(i, j)]);
575            }
576        }
577        Ok(())
578    }
579
580    /// Fill with deterministic uniform random values in `[0, 1)`.
581    pub fn fill_uniform(&mut self, seed: u64) {
582        let mut rng = SmallRng::new(seed);
583        for value in &mut self.data {
584            *value = rng.uniform();
585        }
586    }
587
588    /// Fill with deterministic standard-normal random values.
589    pub fn fill_randn(&mut self, seed: u64) {
590        let mut rng = SmallRng::new(seed);
591        for value in &mut self.data {
592            *value = rng.normal();
593        }
594    }
595
596    /// Sum all elements.
597    pub fn sum(&self) -> T {
598        self.data.iter().copied().sum()
599    }
600
601    /// Mean of all elements, or zero for an empty array.
602    pub fn mean(&self) -> T {
603        if self.is_empty() {
604            T::zero()
605        } else {
606            self.sum() / T::from_f64(self.len() as f64)
607        }
608    }
609
610    /// Frobenius norm.
611    pub fn norm_frobenius(&self) -> T {
612        self.data
613            .iter()
614            .copied()
615            .map(|value| value * value)
616            .sum::<T>()
617            .sqrt()
618    }
619
620    /// Maximum absolute element, or zero for an empty array.
621    pub fn max_abs(&self) -> T {
622        self.data
623            .iter()
624            .copied()
625            .map(T::abs)
626            .fold(
627                T::zero(),
628                |best, value| if value > best { value } else { best },
629            )
630    }
631
632    /// Dot product of two arrays treated as flattened row-major buffers.
633    pub fn dot(&self, other: ArrayView2<'_, T>) -> Result<T> {
634        if self.shape() != other.shape() {
635            return Err(Error::shape(self.shape(), other.shape()));
636        }
637        let mut sum = T::zero();
638        for i in 0..self.rows {
639            for j in 0..self.cols {
640                sum += self[(i, j)] * other[(i, j)];
641            }
642        }
643        Ok(sum)
644    }
645}
646
647fn checked_len(shape: [usize; 2]) -> Result<usize> {
648    shape[0]
649        .checked_mul(shape[1])
650        .ok_or(Error::DimensionTooLarge)
651}
652
653impl<T> Index<(usize, usize)> for Array2<T> {
654    type Output = T;
655
656    fn index(&self, index: (usize, usize)) -> &Self::Output {
657        self.get(index.0, index.1)
658            .expect("array index out of bounds")
659    }
660}
661
662impl<T> IndexMut<(usize, usize)> for Array2<T> {
663    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
664        self.get_mut(index.0, index.1)
665            .expect("array index out of bounds")
666    }
667}