1use core::ops::{Index, IndexMut};
4
5use crate::error::{Error, Result};
6use crate::numeric::Float;
7use crate::rand::SmallRng;
8use crate::view2::{ArrayView2, ArrayViewMut2};
9
/// A dense, owned, two-dimensional array stored in row-major order.
///
/// Element `(row, col)` is stored at `data[row * cols + col]`; the accessors
/// in this type index `data` on the assumption that
/// `data.len() == rows * cols`.
#[derive(Debug, Clone, PartialEq)]
pub struct Array2<T> {
    /// Backing storage: row 0, then row 1, and so on.
    data: Vec<T>,
    /// Number of rows.
    rows: usize,
    /// Number of columns, which is also the distance between consecutive rows.
    cols: usize,
}
17
18impl<T> Array2<T> {
19 pub fn from_vec(shape: [usize; 2], data: Vec<T>) -> Result<Self> {
21 let expected = shape[0]
22 .checked_mul(shape[1])
23 .ok_or(Error::DimensionTooLarge)?;
24 if data.len() != expected {
25 return Err(Error::shape(vec![expected], vec![data.len()]));
26 }
27 Ok(Self {
28 data,
29 rows: shape[0],
30 cols: shape[1],
31 })
32 }
33
34 pub fn from_fn(shape: [usize; 2], mut f: impl FnMut(usize, usize) -> T) -> Self {
36 let len = shape[0] * shape[1];
37 let mut data = Vec::with_capacity(len);
38 for i in 0..shape[0] {
39 for j in 0..shape[1] {
40 data.push(f(i, j));
41 }
42 }
43 Self {
44 data,
45 rows: shape[0],
46 cols: shape[1],
47 }
48 }
49
50 pub fn try_from_fn(shape: [usize; 2], mut f: impl FnMut(usize, usize) -> T) -> Result<Self> {
52 let len = checked_len(shape)?;
53 let mut data = Vec::new();
54 data.try_reserve_exact(len)
55 .map_err(|_| Error::AllocationFailed)?;
56 for i in 0..shape[0] {
57 for j in 0..shape[1] {
58 data.push(f(i, j));
59 }
60 }
61 Ok(Self {
62 data,
63 rows: shape[0],
64 cols: shape[1],
65 })
66 }
67
68 #[inline]
70 pub fn shape(&self) -> [usize; 2] {
71 [self.rows, self.cols]
72 }
73
74 #[inline]
76 pub fn rows(&self) -> usize {
77 self.rows
78 }
79
80 #[inline]
82 pub fn cols(&self) -> usize {
83 self.cols
84 }
85
86 #[inline]
88 pub fn strides(&self) -> [isize; 2] {
89 [self.cols as isize, 1]
90 }
91
92 #[inline]
94 pub fn row_stride(&self) -> isize {
95 self.cols as isize
96 }
97
98 #[inline]
100 pub fn col_stride(&self) -> isize {
101 1
102 }
103
104 #[inline]
106 pub fn leading_dimension(&self) -> isize {
107 self.cols as isize
108 }
109
110 #[inline]
112 pub fn len(&self) -> usize {
113 self.data.len()
114 }
115
116 #[inline]
118 pub fn is_empty(&self) -> bool {
119 self.data.is_empty()
120 }
121
122 #[inline]
124 pub fn is_contiguous(&self) -> bool {
125 true
126 }
127
128 #[inline]
130 pub fn as_slice(&self) -> &[T] {
131 &self.data
132 }
133
134 #[inline]
136 pub fn as_mut_slice(&mut self) -> &mut [T] {
137 &mut self.data
138 }
139
140 pub fn into_vec(self) -> Vec<T> {
142 self.data
143 }
144
145 pub fn view(&self) -> ArrayView2<'_, T> {
147 ArrayView2::from_raw_parts(&self.data, self.shape(), self.strides(), 0)
148 }
149
150 pub fn view_mut(&mut self) -> ArrayViewMut2<'_, T> {
152 ArrayViewMut2::from_raw_parts(
153 &mut self.data,
154 [self.rows, self.cols],
155 [self.cols as isize, 1],
156 0,
157 )
158 }
159
160 pub fn transpose_view(&self) -> ArrayView2<'_, T> {
162 self.view().transpose()
163 }
164
165 #[inline]
167 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
168 (row < self.rows && col < self.cols).then(|| &self.data[row * self.cols + col])
169 }
170
171 #[inline]
173 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
174 (row < self.rows && col < self.cols).then(|| &mut self.data[row * self.cols + col])
175 }
176
177 pub fn row(&self, row: usize) -> Result<ArrayView2<'_, T>> {
179 self.view().row(row)
180 }
181
182 pub fn row_slice(&self, row: usize) -> Result<&[T]> {
184 if row >= self.rows {
185 return Err(Error::IndexOutOfBounds);
186 }
187 let start = row * self.cols;
188 Ok(&self.data[start..start + self.cols])
189 }
190
191 pub fn row_mut(&mut self, row: usize) -> Result<ArrayViewMut2<'_, T>> {
193 if row >= self.rows {
194 return Err(Error::IndexOutOfBounds);
195 }
196 Ok(ArrayViewMut2::from_raw_parts(
197 &mut self.data,
198 [1, self.cols],
199 [self.cols as isize, 1],
200 (row * self.cols) as isize,
201 ))
202 }
203
204 pub fn row_slice_mut(&mut self, row: usize) -> Result<&mut [T]> {
206 if row >= self.rows {
207 return Err(Error::IndexOutOfBounds);
208 }
209 let start = row * self.cols;
210 Ok(&mut self.data[start..start + self.cols])
211 }
212
213 pub fn col(&self, col: usize) -> Result<ArrayView2<'_, T>> {
215 self.view().col(col)
216 }
217
218 pub fn col_mut(&mut self, col: usize) -> Result<ArrayViewMut2<'_, T>> {
220 if col >= self.cols {
221 return Err(Error::IndexOutOfBounds);
222 }
223 Ok(ArrayViewMut2::from_raw_parts(
224 &mut self.data,
225 [self.rows, 1],
226 [self.cols as isize, 1],
227 col as isize,
228 ))
229 }
230
231 pub fn rows_range(&self, start: usize, end: usize) -> Result<ArrayView2<'_, T>> {
233 self.view().rows_range(start, end)
234 }
235
236 pub fn rows_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
238 if start > end || end > self.rows {
239 return Err(Error::IndexOutOfBounds);
240 }
241 Ok(ArrayViewMut2::from_raw_parts(
242 &mut self.data,
243 [end - start, self.cols],
244 [self.cols as isize, 1],
245 (start * self.cols) as isize,
246 ))
247 }
248
249 pub fn cols_range(&self, start: usize, end: usize) -> Result<ArrayView2<'_, T>> {
251 self.view().cols_range(start, end)
252 }
253
254 pub fn cols_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
256 if start > end || end > self.cols {
257 return Err(Error::IndexOutOfBounds);
258 }
259 Ok(ArrayViewMut2::from_raw_parts(
260 &mut self.data,
261 [self.rows, end - start],
262 [self.cols as isize, 1],
263 start as isize,
264 ))
265 }
266
267 pub fn reshape(mut self, shape: [usize; 2]) -> Result<Self> {
269 let expected = shape[0]
270 .checked_mul(shape[1])
271 .ok_or(Error::DimensionTooLarge)?;
272 if expected != self.data.len() {
273 return Err(Error::shape(vec![self.data.len()], vec![expected]));
274 }
275 self.rows = shape[0];
276 self.cols = shape[1];
277 Ok(self)
278 }
279}
280
281impl<T: Clone> Array2<T> {
282 pub fn filled(shape: [usize; 2], value: T) -> Self {
284 Self {
285 data: vec![value; shape[0] * shape[1]],
286 rows: shape[0],
287 cols: shape[1],
288 }
289 }
290
291 pub fn try_filled(shape: [usize; 2], value: T) -> Result<Self> {
293 let len = checked_len(shape)?;
294 let mut data = Vec::new();
295 data.try_reserve_exact(len)
296 .map_err(|_| Error::AllocationFailed)?;
297 data.resize(len, value);
298 Ok(Self {
299 data,
300 rows: shape[0],
301 cols: shape[1],
302 })
303 }
304
305 pub fn clone_contiguous(view: ArrayView2<'_, T>) -> Self {
307 Self::from_fn(view.shape(), |i, j| view[(i, j)].clone())
308 }
309
310 pub fn to_row_major(&self) -> Self {
312 self.clone()
313 }
314
315 pub fn to_col_major_vec(&self) -> Vec<T> {
317 self.view().to_col_major_vec()
318 }
319
320 pub fn copy_from_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
322 if self.shape() != other.shape() {
323 return Err(Error::shape(self.shape(), other.shape()));
324 }
325 for i in 0..self.rows {
326 for j in 0..self.cols {
327 self[(i, j)] = other[(i, j)].clone();
328 }
329 }
330 Ok(())
331 }
332}
333
334impl<T: Float> Array2<T> {
335 pub fn zeros(shape: [usize; 2]) -> Self {
337 Self::filled(shape, T::zero())
338 }
339
340 pub fn try_zeros(shape: [usize; 2]) -> Result<Self> {
342 Self::try_filled(shape, T::zero())
343 }
344
345 pub fn ones(shape: [usize; 2]) -> Self {
347 Self::filled(shape, T::one())
348 }
349
350 pub fn try_ones(shape: [usize; 2]) -> Result<Self> {
352 Self::try_filled(shape, T::one())
353 }
354
355 pub fn zeros_like(&self) -> Self {
357 Self::zeros(self.shape())
358 }
359
360 pub fn scale(&mut self, alpha: T) {
362 for value in &mut self.data {
363 *value *= alpha;
364 }
365 }
366
367 pub fn scaled(&self, alpha: T) -> Self {
369 Self::from_fn(self.shape(), |i, j| self[(i, j)] * alpha)
370 }
371
372 pub fn scaled_into(&self, alpha: T, mut out: ArrayViewMut2<'_, T>) -> Result<()> {
374 if self.shape() != out.shape() {
375 return Err(Error::shape(self.shape(), out.shape()));
376 }
377 for i in 0..self.rows {
378 for j in 0..self.cols {
379 out[(i, j)] = self[(i, j)] * alpha;
380 }
381 }
382 Ok(())
383 }
384
385 pub fn add(&self, other: ArrayView2<'_, T>) -> Result<Self> {
387 self.zip_map(other, |left, right| left + right)
388 }
389
390 pub fn add_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
392 self.zip_map_into(other, out, |left, right| left + right)
393 }
394
395 pub fn sub(&self, other: ArrayView2<'_, T>) -> Result<Self> {
397 self.zip_map(other, |left, right| left - right)
398 }
399
400 pub fn sub_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
402 self.zip_map_into(other, out, |left, right| left - right)
403 }
404
405 pub fn mul(&self, other: ArrayView2<'_, T>) -> Result<Self> {
407 self.zip_map(other, |left, right| left * right)
408 }
409
410 pub fn mul_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
412 self.zip_map_into(other, out, |left, right| left * right)
413 }
414
415 pub fn hadamard(&self, other: ArrayView2<'_, T>) -> Result<Self> {
417 self.mul(other)
418 }
419
420 pub fn hadamard_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
422 self.mul_into(other, out)
423 }
424
425 pub fn div(&self, other: ArrayView2<'_, T>) -> Result<Self> {
427 self.zip_map(other, |left, right| left / right)
428 }
429
430 pub fn div_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
432 self.zip_map_into(other, out, |left, right| left / right)
433 }
434
435 pub fn axpy_result(&self, alpha: T, x: ArrayView2<'_, T>) -> Result<Self> {
437 self.zip_map(x, |left, right| left + alpha * right)
438 }
439
440 pub fn axpy_into(
442 &self,
443 alpha: T,
444 x: ArrayView2<'_, T>,
445 out: ArrayViewMut2<'_, T>,
446 ) -> Result<()> {
447 self.zip_map_into(x, out, |left, right| left + alpha * right)
448 }
449
450 pub fn matmul(&self, other: ArrayView2<'_, T>) -> Result<Self> {
452 crate::linalg::matmul(self.view(), other)
453 }
454
455 pub fn matmul_into(&self, other: ArrayView2<'_, T>, out: ArrayViewMut2<'_, T>) -> Result<()> {
457 crate::linalg::gemm(T::one(), self.view(), false, other, false, T::zero(), out)
458 }
459
460 pub fn add_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
462 if self.shape() != other.shape() {
463 return Err(Error::shape(self.shape(), other.shape()));
464 }
465 for i in 0..self.rows {
466 for j in 0..self.cols {
467 self[(i, j)] += other[(i, j)];
468 }
469 }
470 Ok(())
471 }
472
473 pub fn sub_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
475 if self.shape() != other.shape() {
476 return Err(Error::shape(self.shape(), other.shape()));
477 }
478 for i in 0..self.rows {
479 for j in 0..self.cols {
480 self[(i, j)] -= other[(i, j)];
481 }
482 }
483 Ok(())
484 }
485
486 pub fn mul_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
488 if self.shape() != other.shape() {
489 return Err(Error::shape(self.shape(), other.shape()));
490 }
491 for i in 0..self.rows {
492 for j in 0..self.cols {
493 self[(i, j)] *= other[(i, j)];
494 }
495 }
496 Ok(())
497 }
498
499 pub fn div_assign_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
501 if self.shape() != other.shape() {
502 return Err(Error::shape(self.shape(), other.shape()));
503 }
504 for i in 0..self.rows {
505 for j in 0..self.cols {
506 self[(i, j)] /= other[(i, j)];
507 }
508 }
509 Ok(())
510 }
511
512 pub fn axpy(&mut self, alpha: T, x: ArrayView2<'_, T>) -> Result<()> {
514 if self.shape() != x.shape() {
515 return Err(Error::shape(self.shape(), x.shape()));
516 }
517 for i in 0..self.rows {
518 for j in 0..self.cols {
519 self[(i, j)] += alpha * x[(i, j)];
520 }
521 }
522 Ok(())
523 }
524
525 pub fn map_inplace(&mut self, mut f: impl FnMut(T) -> T) {
527 for value in &mut self.data {
528 *value = f(*value);
529 }
530 }
531
532 pub fn zip_map_inplace(
534 &mut self,
535 other: ArrayView2<'_, T>,
536 mut f: impl FnMut(T, T) -> T,
537 ) -> Result<()> {
538 if self.shape() != other.shape() {
539 return Err(Error::shape(self.shape(), other.shape()));
540 }
541 for i in 0..self.rows {
542 for j in 0..self.cols {
543 self[(i, j)] = f(self[(i, j)], other[(i, j)]);
544 }
545 }
546 Ok(())
547 }
548
549 pub fn zip_map(&self, other: ArrayView2<'_, T>, mut f: impl FnMut(T, T) -> T) -> Result<Self> {
551 if self.shape() != other.shape() {
552 return Err(Error::shape(self.shape(), other.shape()));
553 }
554 Ok(Self::from_fn(self.shape(), |i, j| {
555 f(self[(i, j)], other[(i, j)])
556 }))
557 }
558
559 pub fn zip_map_into(
561 &self,
562 other: ArrayView2<'_, T>,
563 mut out: ArrayViewMut2<'_, T>,
564 mut f: impl FnMut(T, T) -> T,
565 ) -> Result<()> {
566 if self.shape() != other.shape() {
567 return Err(Error::shape(self.shape(), other.shape()));
568 }
569 if self.shape() != out.shape() {
570 return Err(Error::shape(self.shape(), out.shape()));
571 }
572 for i in 0..self.rows {
573 for j in 0..self.cols {
574 out[(i, j)] = f(self[(i, j)], other[(i, j)]);
575 }
576 }
577 Ok(())
578 }
579
580 pub fn fill_uniform(&mut self, seed: u64) {
582 let mut rng = SmallRng::new(seed);
583 for value in &mut self.data {
584 *value = rng.uniform();
585 }
586 }
587
588 pub fn fill_randn(&mut self, seed: u64) {
590 let mut rng = SmallRng::new(seed);
591 for value in &mut self.data {
592 *value = rng.normal();
593 }
594 }
595
596 pub fn sum(&self) -> T {
598 self.data.iter().copied().sum()
599 }
600
601 pub fn mean(&self) -> T {
603 if self.is_empty() {
604 T::zero()
605 } else {
606 self.sum() / T::from_f64(self.len() as f64)
607 }
608 }
609
610 pub fn norm_frobenius(&self) -> T {
612 self.data
613 .iter()
614 .copied()
615 .map(|value| value * value)
616 .sum::<T>()
617 .sqrt()
618 }
619
620 pub fn max_abs(&self) -> T {
622 self.data
623 .iter()
624 .copied()
625 .map(T::abs)
626 .fold(
627 T::zero(),
628 |best, value| if value > best { value } else { best },
629 )
630 }
631
632 pub fn dot(&self, other: ArrayView2<'_, T>) -> Result<T> {
634 if self.shape() != other.shape() {
635 return Err(Error::shape(self.shape(), other.shape()));
636 }
637 let mut sum = T::zero();
638 for i in 0..self.rows {
639 for j in 0..self.cols {
640 sum += self[(i, j)] * other[(i, j)];
641 }
642 }
643 Ok(sum)
644 }
645}
646
647fn checked_len(shape: [usize; 2]) -> Result<usize> {
648 shape[0]
649 .checked_mul(shape[1])
650 .ok_or(Error::DimensionTooLarge)
651}
652
653impl<T> Index<(usize, usize)> for Array2<T> {
654 type Output = T;
655
656 fn index(&self, index: (usize, usize)) -> &Self::Output {
657 self.get(index.0, index.1)
658 .expect("array index out of bounds")
659 }
660}
661
662impl<T> IndexMut<(usize, usize)> for Array2<T> {
663 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
664 self.get_mut(index.0, index.1)
665 .expect("array index out of bounds")
666 }
667}