Skip to main content

tensors/
view2.rs

1//! Two-dimensional strided array views.
2
3use core::ops::{Index, IndexMut};
4
5use crate::array2::Array2;
6use crate::error::{Error, Result};
7
/// Immutable 2D array view.
///
/// Interprets a borrowed flat slice as a `shape[0] x shape[1]` grid. Element
/// `(row, col)` is stored at `data[offset + row * strides[0] + col * strides[1]]`.
#[derive(Debug)]
pub struct ArrayView2<'a, T> {
    /// Borrowed backing storage.
    pub(crate) data: &'a [T],
    /// Logical extent as `[rows, cols]`.
    pub(crate) shape: [usize; 2],
    /// Element distance between consecutive rows and consecutive columns.
    pub(crate) strides: [isize; 2],
    /// Index in `data` of element `(0, 0)`.
    pub(crate) offset: isize,
}

// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add
// `T: Clone` / `T: Copy` bounds, yet copying a view only copies the slice
// reference and the layout, which is valid for every element type.
impl<T> Clone for ArrayView2<'_, T> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ArrayView2<'_, T> {}
16
/// Mutable 2D array view over an exclusively borrowed slice.
///
/// Element `(row, col)` is stored at
/// `data[offset + row * strides[0] + col * strides[1]]`.
#[derive(Debug)]
pub struct ArrayViewMut2<'a, T> {
    /// Exclusively borrowed backing storage.
    pub(crate) data: &'a mut [T],
    /// Logical extent, `[rows, cols]`.
    pub(crate) shape: [usize; 2],
    /// `[row_stride, col_stride]`, measured in elements.
    pub(crate) strides: [isize; 2],
    /// Position of element `(0, 0)` within `data`.
    pub(crate) offset: isize,
}
25
26impl<'a, T> ArrayView2<'a, T> {
27    /// Create a checked immutable view.
28    pub fn new(
29        data: &'a [T],
30        shape: [usize; 2],
31        strides: [isize; 2],
32        offset: isize,
33    ) -> Result<Self> {
34        validate_view(data.len(), &shape, &strides, offset)?;
35        Ok(Self {
36            data,
37            shape,
38            strides,
39            offset,
40        })
41    }
42
43    pub(crate) fn from_raw_parts(
44        data: &'a [T],
45        shape: [usize; 2],
46        strides: [isize; 2],
47        offset: isize,
48    ) -> Self {
49        Self {
50            data,
51            shape,
52            strides,
53            offset,
54        }
55    }
56
57    /// Shape as `[rows, cols]`.
58    #[inline]
59    pub fn shape(&self) -> [usize; 2] {
60        self.shape
61    }
62
63    /// Number of rows.
64    #[inline]
65    pub fn rows(&self) -> usize {
66        self.shape[0]
67    }
68
69    /// Number of columns.
70    #[inline]
71    pub fn cols(&self) -> usize {
72        self.shape[1]
73    }
74
75    /// Strides in elements.
76    #[inline]
77    pub fn strides(&self) -> [isize; 2] {
78        self.strides
79    }
80
81    /// Distance in elements between consecutive rows.
82    #[inline]
83    pub fn row_stride(&self) -> isize {
84        self.strides[0]
85    }
86
87    /// Distance in elements between consecutive columns.
88    #[inline]
89    pub fn col_stride(&self) -> isize {
90        self.strides[1]
91    }
92
93    /// Leading dimension for the current row-major-style view.
94    #[inline]
95    pub fn leading_dimension(&self) -> isize {
96        self.strides[0]
97    }
98
99    /// Number of elements in the logical view.
100    #[inline]
101    pub fn len(&self) -> usize {
102        self.shape[0] * self.shape[1]
103    }
104
105    /// Whether the logical view has no elements.
106    #[inline]
107    pub fn is_empty(&self) -> bool {
108        self.len() == 0
109    }
110
111    /// Whether the view is compact row-major contiguous.
112    #[inline]
113    pub fn is_contiguous(&self) -> bool {
114        is_compact_row_major(self.shape, self.strides)
115    }
116
117    /// Borrow the backing slice if this view covers it contiguously.
118    pub fn as_slice(&self) -> Option<&'a [T]> {
119        if !self.is_contiguous() {
120            return None;
121        }
122        let start = self.offset as usize;
123        let end = start + self.len();
124        Some(&self.data[start..end])
125    }
126
127    /// Get a copied reference at `(row, col)`.
128    #[inline]
129    pub fn get(&self, row: usize, col: usize) -> Option<&'a T> {
130        if row >= self.rows() || col >= self.cols() {
131            return None;
132        }
133        Some(&self.data[self.linear_index(row, col)])
134    }
135
136    /// Return a transpose view without copying.
137    #[inline]
138    pub fn transpose(self) -> Self {
139        Self {
140            data: self.data,
141            shape: [self.shape[1], self.shape[0]],
142            strides: [self.strides[1], self.strides[0]],
143            offset: self.offset,
144        }
145    }
146
147    /// Return a row view as a one-row matrix.
148    pub fn row(&self, row: usize) -> Result<Self> {
149        if row >= self.rows() {
150            return Err(Error::IndexOutOfBounds);
151        }
152        Ok(Self {
153            data: self.data,
154            shape: [1, self.cols()],
155            strides: self.strides,
156            offset: self.offset + row as isize * self.strides[0],
157        })
158    }
159
160    /// Borrow a row as a contiguous slice when the row layout permits it.
161    pub fn row_slice(&self, row: usize) -> Result<Option<&'a [T]>> {
162        if row >= self.rows() {
163            return Err(Error::IndexOutOfBounds);
164        }
165        if self.cols() == 0 {
166            return Ok(Some(&self.data[0..0]));
167        }
168        if self.strides[1] != 1 {
169            return Ok(None);
170        }
171        let start = self.linear_index(row, 0);
172        let end = start + self.cols();
173        Ok(Some(&self.data[start..end]))
174    }
175
176    /// Return a column view as an `rows x 1` matrix.
177    pub fn col(&self, col: usize) -> Result<Self> {
178        if col >= self.cols() {
179            return Err(Error::IndexOutOfBounds);
180        }
181        Ok(Self {
182            data: self.data,
183            shape: [self.rows(), 1],
184            strides: self.strides,
185            offset: self.offset + col as isize * self.strides[1],
186        })
187    }
188
189    /// Slice a half-open row range.
190    pub fn rows_range(&self, start: usize, end: usize) -> Result<Self> {
191        if start > end || end > self.rows() {
192            return Err(Error::IndexOutOfBounds);
193        }
194        Ok(Self {
195            data: self.data,
196            shape: [end - start, self.cols()],
197            strides: self.strides,
198            offset: self.offset + start as isize * self.strides[0],
199        })
200    }
201
202    /// Slice a half-open column range.
203    pub fn cols_range(&self, start: usize, end: usize) -> Result<Self> {
204        if start > end || end > self.cols() {
205            return Err(Error::IndexOutOfBounds);
206        }
207        Ok(Self {
208            data: self.data,
209            shape: [self.rows(), end - start],
210            strides: self.strides,
211            offset: self.offset + start as isize * self.strides[1],
212        })
213    }
214
215    #[inline]
216    pub(crate) fn linear_index(&self, row: usize, col: usize) -> usize {
217        (self.offset + row as isize * self.strides[0] + col as isize * self.strides[1]) as usize
218    }
219}
220
221impl<T: Clone> ArrayView2<'_, T> {
222    /// Copy this view into compact row-major storage.
223    pub fn to_row_major(&self) -> Array2<T> {
224        Array2::from_fn(self.shape, |i, j| self[(i, j)].clone())
225    }
226
227    /// Copy this view into a column-major vector.
228    pub fn to_col_major_vec(&self) -> Vec<T> {
229        let mut data = Vec::with_capacity(self.len());
230        for j in 0..self.cols() {
231            for i in 0..self.rows() {
232                data.push(self[(i, j)].clone());
233            }
234        }
235        data
236    }
237}
238
239impl<'a, T> ArrayViewMut2<'a, T> {
240    /// Create a checked mutable view.
241    pub fn new(
242        data: &'a mut [T],
243        shape: [usize; 2],
244        strides: [isize; 2],
245        offset: isize,
246    ) -> Result<Self> {
247        validate_view(data.len(), &shape, &strides, offset)?;
248        Ok(Self {
249            data,
250            shape,
251            strides,
252            offset,
253        })
254    }
255
256    pub(crate) fn from_raw_parts(
257        data: &'a mut [T],
258        shape: [usize; 2],
259        strides: [isize; 2],
260        offset: isize,
261    ) -> Self {
262        Self {
263            data,
264            shape,
265            strides,
266            offset,
267        }
268    }
269
270    /// Shape as `[rows, cols]`.
271    #[inline]
272    pub fn shape(&self) -> [usize; 2] {
273        self.shape
274    }
275
276    /// Number of rows.
277    #[inline]
278    pub fn rows(&self) -> usize {
279        self.shape[0]
280    }
281
282    /// Number of columns.
283    #[inline]
284    pub fn cols(&self) -> usize {
285        self.shape[1]
286    }
287
288    /// Strides in elements.
289    #[inline]
290    pub fn strides(&self) -> [isize; 2] {
291        self.strides
292    }
293
294    /// Distance in elements between consecutive rows.
295    #[inline]
296    pub fn row_stride(&self) -> isize {
297        self.strides[0]
298    }
299
300    /// Distance in elements between consecutive columns.
301    #[inline]
302    pub fn col_stride(&self) -> isize {
303        self.strides[1]
304    }
305
306    /// Leading dimension for the current row-major-style view.
307    #[inline]
308    pub fn leading_dimension(&self) -> isize {
309        self.strides[0]
310    }
311
312    /// Number of elements in the logical view.
313    #[inline]
314    pub fn len(&self) -> usize {
315        self.shape[0] * self.shape[1]
316    }
317
318    /// Whether the logical view has no elements.
319    #[inline]
320    pub fn is_empty(&self) -> bool {
321        self.len() == 0
322    }
323
324    /// Whether the view is compact row-major contiguous.
325    #[inline]
326    pub fn is_contiguous(&self) -> bool {
327        is_compact_row_major(self.shape, self.strides)
328    }
329
330    /// Immutable view over the same region.
331    pub fn as_view(&self) -> ArrayView2<'_, T> {
332        ArrayView2 {
333            data: self.data,
334            shape: self.shape,
335            strides: self.strides,
336            offset: self.offset,
337        }
338    }
339
340    /// Borrow the backing slice if this view covers it contiguously.
341    pub fn as_mut_slice(&mut self) -> Option<&mut [T]> {
342        if !self.is_contiguous() {
343            return None;
344        }
345        let start = self.offset as usize;
346        let end = start + self.len();
347        Some(&mut self.data[start..end])
348    }
349
350    /// Get an immutable element reference.
351    #[inline]
352    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
353        if row >= self.rows() || col >= self.cols() {
354            return None;
355        }
356        Some(&self.data[self.linear_index(row, col)])
357    }
358
359    /// Get a mutable element reference.
360    #[inline]
361    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
362        if row >= self.rows() || col >= self.cols() {
363            return None;
364        }
365        let index = self.linear_index(row, col);
366        Some(&mut self.data[index])
367    }
368
369    /// Return a transpose view without copying.
370    pub fn transpose(self) -> Self {
371        Self {
372            data: self.data,
373            shape: [self.shape[1], self.shape[0]],
374            strides: [self.strides[1], self.strides[0]],
375            offset: self.offset,
376        }
377    }
378
379    /// Return a mutable row view as a one-row matrix.
380    pub fn row_mut(&mut self, row: usize) -> Result<ArrayViewMut2<'_, T>> {
381        if row >= self.rows() {
382            return Err(Error::IndexOutOfBounds);
383        }
384        let cols = self.cols();
385        let strides = self.strides;
386        let offset = self.offset + row as isize * strides[0];
387        Ok(ArrayViewMut2 {
388            data: &mut *self.data,
389            shape: [1, cols],
390            strides,
391            offset,
392        })
393    }
394
395    /// Borrow a mutable row as a contiguous slice when the row layout permits it.
396    pub fn row_slice_mut(&mut self, row: usize) -> Result<Option<&mut [T]>> {
397        if row >= self.rows() {
398            return Err(Error::IndexOutOfBounds);
399        }
400        if self.cols() == 0 {
401            return Ok(Some(&mut self.data[0..0]));
402        }
403        if self.strides[1] != 1 {
404            return Ok(None);
405        }
406        let start = self.linear_index(row, 0);
407        let end = start + self.cols();
408        Ok(Some(&mut self.data[start..end]))
409    }
410
411    /// Return a mutable column view as an `rows x 1` matrix.
412    pub fn col_mut(&mut self, col: usize) -> Result<ArrayViewMut2<'_, T>> {
413        if col >= self.cols() {
414            return Err(Error::IndexOutOfBounds);
415        }
416        let rows = self.rows();
417        let strides = self.strides;
418        let offset = self.offset + col as isize * strides[1];
419        Ok(ArrayViewMut2 {
420            data: &mut *self.data,
421            shape: [rows, 1],
422            strides,
423            offset,
424        })
425    }
426
427    /// Slice a mutable half-open row range.
428    pub fn rows_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
429        if start > end || end > self.rows() {
430            return Err(Error::IndexOutOfBounds);
431        }
432        let cols = self.cols();
433        let strides = self.strides;
434        let offset = self.offset + start as isize * strides[0];
435        Ok(ArrayViewMut2 {
436            data: &mut *self.data,
437            shape: [end - start, cols],
438            strides,
439            offset,
440        })
441    }
442
443    /// Slice a mutable half-open column range.
444    pub fn cols_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
445        if start > end || end > self.cols() {
446            return Err(Error::IndexOutOfBounds);
447        }
448        let rows = self.rows();
449        let strides = self.strides;
450        let offset = self.offset + start as isize * strides[1];
451        Ok(ArrayViewMut2 {
452            data: &mut *self.data,
453            shape: [rows, end - start],
454            strides,
455            offset,
456        })
457    }
458
459    #[inline]
460    pub(crate) fn linear_index(&self, row: usize, col: usize) -> usize {
461        (self.offset + row as isize * self.strides[0] + col as isize * self.strides[1]) as usize
462    }
463}
464
465impl<T: Clone> ArrayViewMut2<'_, T> {
466    /// Copy this view into compact row-major storage.
467    pub fn to_row_major(&self) -> Array2<T> {
468        self.as_view().to_row_major()
469    }
470
471    /// Copy this view into a column-major vector.
472    pub fn to_col_major_vec(&self) -> Vec<T> {
473        self.as_view().to_col_major_vec()
474    }
475
476    /// Copy values from another view with the same shape.
477    pub fn copy_from_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
478        if self.shape() != other.shape() {
479            return Err(Error::shape(self.shape(), other.shape()));
480        }
481        for i in 0..self.rows() {
482            for j in 0..self.cols() {
483                self[(i, j)] = other[(i, j)].clone();
484            }
485        }
486        Ok(())
487    }
488}
489
/// Report whether a 2D layout is dense row-major storage.
///
/// A layout with no elements always qualifies. Otherwise columns must be
/// adjacent (`strides[1] == 1`), and when there is more than one row each row
/// must start exactly `cols` elements after the previous one.
#[inline]
pub(crate) fn is_compact_row_major(shape: [usize; 2], strides: [isize; 2]) -> bool {
    let [rows, cols] = shape;
    let [row_stride, col_stride] = strides;
    if rows == 0 || cols == 0 {
        return true;
    }
    let rows_packed = rows == 1 || row_stride == cols as isize;
    col_stride == 1 && rows_packed
}
496
497impl<T> Index<(usize, usize)> for ArrayView2<'_, T> {
498    type Output = T;
499
500    fn index(&self, index: (usize, usize)) -> &Self::Output {
501        self.get(index.0, index.1)
502            .expect("view index out of bounds")
503    }
504}
505
506impl<T> Index<(usize, usize)> for ArrayViewMut2<'_, T> {
507    type Output = T;
508
509    fn index(&self, index: (usize, usize)) -> &Self::Output {
510        self.get(index.0, index.1)
511            .expect("view index out of bounds")
512    }
513}
514
515impl<T> IndexMut<(usize, usize)> for ArrayViewMut2<'_, T> {
516    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
517        self.get_mut(index.0, index.1)
518            .expect("view index out of bounds")
519    }
520}
521
522pub(crate) fn validate_view(
523    len: usize,
524    shape: &[usize],
525    strides: &[isize],
526    offset: isize,
527) -> Result<()> {
528    if shape.len() != strides.len() || offset < 0 {
529        return Err(Error::InvalidStride);
530    }
531    if shape.contains(&0) {
532        return Ok(());
533    }
534    let mut min = offset;
535    let mut max = offset;
536    for (&dim, &stride) in shape.iter().zip(strides) {
537        let span = (dim - 1) as isize * stride;
538        if span >= 0 {
539            max += span;
540        } else {
541            min += span;
542        }
543    }
544    if min < 0 || max < 0 || max as usize >= len {
545        return Err(Error::InvalidStride);
546    }
547    Ok(())
548}