1use core::ops::{Index, IndexMut};
4
5use crate::array2::Array2;
6use crate::error::{Error, Result};
7
/// Borrowed, read-only two-dimensional view over a slice of `T`.
///
/// Element `(row, col)` is read from
/// `data[offset + row * strides[0] + col * strides[1]]`.
#[derive(Clone, Copy, Debug)]
pub struct ArrayView2<'a, T> {
    /// Backing storage that the view indexes into.
    pub(crate) data: &'a [T],
    /// Extent of the view as `[rows, cols]`.
    pub(crate) shape: [usize; 2],
    /// Signed element step as `[row step, column step]`.
    pub(crate) strides: [isize; 2],
    /// Position in `data` of element `(0, 0)`.
    pub(crate) offset: isize,
}
16
/// Borrowed, mutable two-dimensional view over a slice of `T`.
///
/// Element `(row, col)` lives at
/// `data[offset + row * strides[0] + col * strides[1]]`.
#[derive(Debug)]
pub struct ArrayViewMut2<'a, T> {
    /// Backing storage that the view indexes into and may modify.
    pub(crate) data: &'a mut [T],
    /// Extent of the view as `[rows, cols]`.
    pub(crate) shape: [usize; 2],
    /// Signed element step as `[row step, column step]`.
    pub(crate) strides: [isize; 2],
    /// Position in `data` of element `(0, 0)`.
    pub(crate) offset: isize,
}
25
26impl<'a, T> ArrayView2<'a, T> {
27 pub fn new(
29 data: &'a [T],
30 shape: [usize; 2],
31 strides: [isize; 2],
32 offset: isize,
33 ) -> Result<Self> {
34 validate_view(data.len(), &shape, &strides, offset)?;
35 Ok(Self {
36 data,
37 shape,
38 strides,
39 offset,
40 })
41 }
42
43 pub(crate) fn from_raw_parts(
44 data: &'a [T],
45 shape: [usize; 2],
46 strides: [isize; 2],
47 offset: isize,
48 ) -> Self {
49 Self {
50 data,
51 shape,
52 strides,
53 offset,
54 }
55 }
56
57 #[inline]
59 pub fn shape(&self) -> [usize; 2] {
60 self.shape
61 }
62
63 #[inline]
65 pub fn rows(&self) -> usize {
66 self.shape[0]
67 }
68
69 #[inline]
71 pub fn cols(&self) -> usize {
72 self.shape[1]
73 }
74
75 #[inline]
77 pub fn strides(&self) -> [isize; 2] {
78 self.strides
79 }
80
81 #[inline]
83 pub fn row_stride(&self) -> isize {
84 self.strides[0]
85 }
86
87 #[inline]
89 pub fn col_stride(&self) -> isize {
90 self.strides[1]
91 }
92
93 #[inline]
95 pub fn leading_dimension(&self) -> isize {
96 self.strides[0]
97 }
98
99 #[inline]
101 pub fn len(&self) -> usize {
102 self.shape[0] * self.shape[1]
103 }
104
105 #[inline]
107 pub fn is_empty(&self) -> bool {
108 self.len() == 0
109 }
110
111 #[inline]
113 pub fn is_contiguous(&self) -> bool {
114 is_compact_row_major(self.shape, self.strides)
115 }
116
117 pub fn as_slice(&self) -> Option<&'a [T]> {
119 if !self.is_contiguous() {
120 return None;
121 }
122 let start = self.offset as usize;
123 let end = start + self.len();
124 Some(&self.data[start..end])
125 }
126
127 #[inline]
129 pub fn get(&self, row: usize, col: usize) -> Option<&'a T> {
130 if row >= self.rows() || col >= self.cols() {
131 return None;
132 }
133 Some(&self.data[self.linear_index(row, col)])
134 }
135
136 #[inline]
138 pub fn transpose(self) -> Self {
139 Self {
140 data: self.data,
141 shape: [self.shape[1], self.shape[0]],
142 strides: [self.strides[1], self.strides[0]],
143 offset: self.offset,
144 }
145 }
146
147 pub fn row(&self, row: usize) -> Result<Self> {
149 if row >= self.rows() {
150 return Err(Error::IndexOutOfBounds);
151 }
152 Ok(Self {
153 data: self.data,
154 shape: [1, self.cols()],
155 strides: self.strides,
156 offset: self.offset + row as isize * self.strides[0],
157 })
158 }
159
160 pub fn row_slice(&self, row: usize) -> Result<Option<&'a [T]>> {
162 if row >= self.rows() {
163 return Err(Error::IndexOutOfBounds);
164 }
165 if self.cols() == 0 {
166 return Ok(Some(&self.data[0..0]));
167 }
168 if self.strides[1] != 1 {
169 return Ok(None);
170 }
171 let start = self.linear_index(row, 0);
172 let end = start + self.cols();
173 Ok(Some(&self.data[start..end]))
174 }
175
176 pub fn col(&self, col: usize) -> Result<Self> {
178 if col >= self.cols() {
179 return Err(Error::IndexOutOfBounds);
180 }
181 Ok(Self {
182 data: self.data,
183 shape: [self.rows(), 1],
184 strides: self.strides,
185 offset: self.offset + col as isize * self.strides[1],
186 })
187 }
188
189 pub fn rows_range(&self, start: usize, end: usize) -> Result<Self> {
191 if start > end || end > self.rows() {
192 return Err(Error::IndexOutOfBounds);
193 }
194 Ok(Self {
195 data: self.data,
196 shape: [end - start, self.cols()],
197 strides: self.strides,
198 offset: self.offset + start as isize * self.strides[0],
199 })
200 }
201
202 pub fn cols_range(&self, start: usize, end: usize) -> Result<Self> {
204 if start > end || end > self.cols() {
205 return Err(Error::IndexOutOfBounds);
206 }
207 Ok(Self {
208 data: self.data,
209 shape: [self.rows(), end - start],
210 strides: self.strides,
211 offset: self.offset + start as isize * self.strides[1],
212 })
213 }
214
215 #[inline]
216 pub(crate) fn linear_index(&self, row: usize, col: usize) -> usize {
217 (self.offset + row as isize * self.strides[0] + col as isize * self.strides[1]) as usize
218 }
219}
220
221impl<T: Clone> ArrayView2<'_, T> {
222 pub fn to_row_major(&self) -> Array2<T> {
224 Array2::from_fn(self.shape, |i, j| self[(i, j)].clone())
225 }
226
227 pub fn to_col_major_vec(&self) -> Vec<T> {
229 let mut data = Vec::with_capacity(self.len());
230 for j in 0..self.cols() {
231 for i in 0..self.rows() {
232 data.push(self[(i, j)].clone());
233 }
234 }
235 data
236 }
237}
238
239impl<'a, T> ArrayViewMut2<'a, T> {
240 pub fn new(
242 data: &'a mut [T],
243 shape: [usize; 2],
244 strides: [isize; 2],
245 offset: isize,
246 ) -> Result<Self> {
247 validate_view(data.len(), &shape, &strides, offset)?;
248 Ok(Self {
249 data,
250 shape,
251 strides,
252 offset,
253 })
254 }
255
256 pub(crate) fn from_raw_parts(
257 data: &'a mut [T],
258 shape: [usize; 2],
259 strides: [isize; 2],
260 offset: isize,
261 ) -> Self {
262 Self {
263 data,
264 shape,
265 strides,
266 offset,
267 }
268 }
269
270 #[inline]
272 pub fn shape(&self) -> [usize; 2] {
273 self.shape
274 }
275
276 #[inline]
278 pub fn rows(&self) -> usize {
279 self.shape[0]
280 }
281
282 #[inline]
284 pub fn cols(&self) -> usize {
285 self.shape[1]
286 }
287
288 #[inline]
290 pub fn strides(&self) -> [isize; 2] {
291 self.strides
292 }
293
294 #[inline]
296 pub fn row_stride(&self) -> isize {
297 self.strides[0]
298 }
299
300 #[inline]
302 pub fn col_stride(&self) -> isize {
303 self.strides[1]
304 }
305
306 #[inline]
308 pub fn leading_dimension(&self) -> isize {
309 self.strides[0]
310 }
311
312 #[inline]
314 pub fn len(&self) -> usize {
315 self.shape[0] * self.shape[1]
316 }
317
318 #[inline]
320 pub fn is_empty(&self) -> bool {
321 self.len() == 0
322 }
323
324 #[inline]
326 pub fn is_contiguous(&self) -> bool {
327 is_compact_row_major(self.shape, self.strides)
328 }
329
330 pub fn as_view(&self) -> ArrayView2<'_, T> {
332 ArrayView2 {
333 data: self.data,
334 shape: self.shape,
335 strides: self.strides,
336 offset: self.offset,
337 }
338 }
339
340 pub fn as_mut_slice(&mut self) -> Option<&mut [T]> {
342 if !self.is_contiguous() {
343 return None;
344 }
345 let start = self.offset as usize;
346 let end = start + self.len();
347 Some(&mut self.data[start..end])
348 }
349
350 #[inline]
352 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
353 if row >= self.rows() || col >= self.cols() {
354 return None;
355 }
356 Some(&self.data[self.linear_index(row, col)])
357 }
358
359 #[inline]
361 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
362 if row >= self.rows() || col >= self.cols() {
363 return None;
364 }
365 let index = self.linear_index(row, col);
366 Some(&mut self.data[index])
367 }
368
369 pub fn transpose(self) -> Self {
371 Self {
372 data: self.data,
373 shape: [self.shape[1], self.shape[0]],
374 strides: [self.strides[1], self.strides[0]],
375 offset: self.offset,
376 }
377 }
378
379 pub fn row_mut(&mut self, row: usize) -> Result<ArrayViewMut2<'_, T>> {
381 if row >= self.rows() {
382 return Err(Error::IndexOutOfBounds);
383 }
384 let cols = self.cols();
385 let strides = self.strides;
386 let offset = self.offset + row as isize * strides[0];
387 Ok(ArrayViewMut2 {
388 data: &mut *self.data,
389 shape: [1, cols],
390 strides,
391 offset,
392 })
393 }
394
395 pub fn row_slice_mut(&mut self, row: usize) -> Result<Option<&mut [T]>> {
397 if row >= self.rows() {
398 return Err(Error::IndexOutOfBounds);
399 }
400 if self.cols() == 0 {
401 return Ok(Some(&mut self.data[0..0]));
402 }
403 if self.strides[1] != 1 {
404 return Ok(None);
405 }
406 let start = self.linear_index(row, 0);
407 let end = start + self.cols();
408 Ok(Some(&mut self.data[start..end]))
409 }
410
411 pub fn col_mut(&mut self, col: usize) -> Result<ArrayViewMut2<'_, T>> {
413 if col >= self.cols() {
414 return Err(Error::IndexOutOfBounds);
415 }
416 let rows = self.rows();
417 let strides = self.strides;
418 let offset = self.offset + col as isize * strides[1];
419 Ok(ArrayViewMut2 {
420 data: &mut *self.data,
421 shape: [rows, 1],
422 strides,
423 offset,
424 })
425 }
426
427 pub fn rows_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
429 if start > end || end > self.rows() {
430 return Err(Error::IndexOutOfBounds);
431 }
432 let cols = self.cols();
433 let strides = self.strides;
434 let offset = self.offset + start as isize * strides[0];
435 Ok(ArrayViewMut2 {
436 data: &mut *self.data,
437 shape: [end - start, cols],
438 strides,
439 offset,
440 })
441 }
442
443 pub fn cols_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
445 if start > end || end > self.cols() {
446 return Err(Error::IndexOutOfBounds);
447 }
448 let rows = self.rows();
449 let strides = self.strides;
450 let offset = self.offset + start as isize * strides[1];
451 Ok(ArrayViewMut2 {
452 data: &mut *self.data,
453 shape: [rows, end - start],
454 strides,
455 offset,
456 })
457 }
458
459 #[inline]
460 pub(crate) fn linear_index(&self, row: usize, col: usize) -> usize {
461 (self.offset + row as isize * self.strides[0] + col as isize * self.strides[1]) as usize
462 }
463}
464
465impl<T: Clone> ArrayViewMut2<'_, T> {
466 pub fn to_row_major(&self) -> Array2<T> {
468 self.as_view().to_row_major()
469 }
470
471 pub fn to_col_major_vec(&self) -> Vec<T> {
473 self.as_view().to_col_major_vec()
474 }
475
476 pub fn copy_from_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
478 if self.shape() != other.shape() {
479 return Err(Error::shape(self.shape(), other.shape()));
480 }
481 for i in 0..self.rows() {
482 for j in 0..self.cols() {
483 self[(i, j)] = other[(i, j)].clone();
484 }
485 }
486 Ok(())
487 }
488}
489
/// Reports whether `shape`/`strides` describe a gap-free row-major layout.
///
/// Any layout with a zero dimension is accepted. Otherwise the column stride
/// must be `1`, and the row stride must equal the column count unless there
/// is at most one row (in which case the row stride is ignored).
#[inline]
pub(crate) fn is_compact_row_major(shape: [usize; 2], strides: [isize; 2]) -> bool {
    let [rows, cols] = shape;
    let [row_step, col_step] = strides;
    if rows == 0 || cols == 0 {
        return true;
    }
    if col_step != 1 {
        return false;
    }
    rows <= 1 || row_step == cols as isize
}
496
497impl<T> Index<(usize, usize)> for ArrayView2<'_, T> {
498 type Output = T;
499
500 fn index(&self, index: (usize, usize)) -> &Self::Output {
501 self.get(index.0, index.1)
502 .expect("view index out of bounds")
503 }
504}
505
506impl<T> Index<(usize, usize)> for ArrayViewMut2<'_, T> {
507 type Output = T;
508
509 fn index(&self, index: (usize, usize)) -> &Self::Output {
510 self.get(index.0, index.1)
511 .expect("view index out of bounds")
512 }
513}
514
515impl<T> IndexMut<(usize, usize)> for ArrayViewMut2<'_, T> {
516 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
517 self.get_mut(index.0, index.1)
518 .expect("view index out of bounds")
519 }
520}
521
522pub(crate) fn validate_view(
523 len: usize,
524 shape: &[usize],
525 strides: &[isize],
526 offset: isize,
527) -> Result<()> {
528 if shape.len() != strides.len() || offset < 0 {
529 return Err(Error::InvalidStride);
530 }
531 if shape.contains(&0) {
532 return Ok(());
533 }
534 let mut min = offset;
535 let mut max = offset;
536 for (&dim, &stride) in shape.iter().zip(strides) {
537 let span = (dim - 1) as isize * stride;
538 if span >= 0 {
539 max += span;
540 } else {
541 min += span;
542 }
543 }
544 if min < 0 || max < 0 || max as usize >= len {
545 return Err(Error::InvalidStride);
546 }
547 Ok(())
548}