1use torsh_core::dtype::TensorElement;
22use torsh_core::error::{Result, TorshError};
23use torsh_core::shape::Shape;
24
25#[derive(Debug)]
37pub struct TensorView<'a, T: TensorElement> {
38 data: &'a [T],
40
41 shape: Shape,
43
44 strides: Vec<usize>,
46
47 offset: usize,
49}
50
51#[derive(Debug)]
63pub struct TensorViewMut<'a, T: TensorElement> {
64 data: &'a mut [T],
66
67 shape: Shape,
69
70 strides: Vec<usize>,
72
73 offset: usize,
75}
76
77impl<'a, T: TensorElement> TensorView<'a, T> {
82 pub fn new(data: &'a [T], shape: Shape, strides: Vec<usize>, offset: usize) -> Self {
95 Self {
96 data,
97 shape,
98 strides,
99 offset,
100 }
101 }
102
103 #[inline]
105 pub fn shape(&self) -> &Shape {
106 &self.shape
107 }
108
109 #[inline]
111 pub fn strides(&self) -> &[usize] {
112 &self.strides
113 }
114
115 #[inline]
117 pub fn len(&self) -> usize {
118 self.shape.numel()
119 }
120
121 #[inline]
123 pub fn is_empty(&self) -> bool {
124 self.len() == 0
125 }
126
127 #[inline]
132 pub fn data(&self) -> &[T] {
133 &self.data[self.offset..]
134 }
135
136 pub fn is_contiguous(&self) -> bool {
140 if self.shape.dims().is_empty() {
141 return true;
142 }
143
144 let dims = self.shape.dims();
145 let mut expected_stride = 1;
146
147 for i in (0..dims.len()).rev() {
149 if self.strides[i] != expected_stride {
150 return false;
151 }
152 expected_stride *= dims[i];
153 }
154
155 true
156 }
157
158 pub fn get(&self, index: usize) -> Result<&T> {
169 if index >= self.len() {
170 return Err(TorshError::IndexError {
171 index,
172 size: self.len(),
173 });
174 }
175
176 Ok(&self.data[self.offset + index])
177 }
178
179 pub fn get_at(&self, indices: &[usize]) -> Result<&T> {
190 if indices.len() != self.shape.ndim() {
191 return Err(TorshError::InvalidArgument(format!(
192 "Expected {} indices, got {}",
193 self.shape.ndim(),
194 indices.len()
195 )));
196 }
197
198 let flat_index = self.compute_flat_index(indices)?;
199 Ok(&self.data[self.offset + flat_index])
200 }
201
202 fn compute_flat_index(&self, indices: &[usize]) -> Result<usize> {
204 let dims = self.shape.dims();
205 let mut flat_index = 0;
206
207 for (i, &idx) in indices.iter().enumerate() {
208 if idx >= dims[i] {
209 return Err(TorshError::IndexError {
210 index: idx,
211 size: dims[i],
212 });
213 }
214 flat_index += idx * self.strides[i];
215 }
216
217 Ok(flat_index)
218 }
219
220 pub fn iter(&self) -> TensorViewIter<'a, T> {
222 TensorViewIter {
223 data: self.data,
224 offset: self.offset,
225 len: self.len(),
226 current: 0,
227 }
228 }
229
230 pub fn to_vec(&self) -> Vec<T>
234 where
235 T: Copy,
236 {
237 self.data[self.offset..self.offset + self.len()].to_vec()
238 }
239}
240
241impl<'a, T: TensorElement> TensorViewMut<'a, T> {
246 pub fn new(data: &'a mut [T], shape: Shape, strides: Vec<usize>, offset: usize) -> Self {
254 Self {
255 data,
256 shape,
257 strides,
258 offset,
259 }
260 }
261
262 #[inline]
264 pub fn shape(&self) -> &Shape {
265 &self.shape
266 }
267
268 #[inline]
270 pub fn strides(&self) -> &[usize] {
271 &self.strides
272 }
273
274 #[inline]
276 pub fn len(&self) -> usize {
277 self.shape.numel()
278 }
279
280 #[inline]
282 pub fn is_empty(&self) -> bool {
283 self.len() == 0
284 }
285
286 #[inline]
288 pub fn data(&self) -> &[T] {
289 &self.data[self.offset..]
290 }
291
292 #[inline]
296 pub fn data_mut(&mut self) -> &mut [T] {
297 let len = self.len();
298 &mut self.data[self.offset..self.offset + len]
299 }
300
301 pub fn is_contiguous(&self) -> bool {
303 if self.shape.dims().is_empty() {
304 return true;
305 }
306
307 let dims = self.shape.dims();
308 let mut expected_stride = 1;
309
310 for i in (0..dims.len()).rev() {
311 if self.strides[i] != expected_stride {
312 return false;
313 }
314 expected_stride *= dims[i];
315 }
316
317 true
318 }
319
320 pub fn get(&self, index: usize) -> Result<&T> {
322 if index >= self.len() {
323 return Err(TorshError::IndexError {
324 index,
325 size: self.len(),
326 });
327 }
328
329 Ok(&self.data[self.offset + index])
330 }
331
332 pub fn get_mut(&mut self, index: usize) -> Result<&mut T> {
334 if index >= self.len() {
335 return Err(TorshError::IndexError {
336 index,
337 size: self.len(),
338 });
339 }
340
341 Ok(&mut self.data[self.offset + index])
342 }
343
344 pub fn fill(&mut self, value: T)
355 where
356 T: Copy,
357 {
358 let len = self.len();
359 self.data[self.offset..self.offset + len].fill(value);
360 }
361
362 pub fn iter(&self) -> TensorViewIter<'_, T> {
364 TensorViewIter {
365 data: self.data,
366 offset: self.offset,
367 len: self.len(),
368 current: 0,
369 }
370 }
371
372 pub fn iter_mut(&mut self) -> TensorViewIterMut<'_, T> {
374 let len = self.len();
375 TensorViewIterMut {
376 data: &mut self.data[self.offset..self.offset + len],
377 current: 0,
378 }
379 }
380}
381
382pub struct TensorViewIter<'a, T: TensorElement> {
388 data: &'a [T],
389 offset: usize,
390 len: usize,
391 current: usize,
392}
393
394impl<'a, T: TensorElement> Iterator for TensorViewIter<'a, T> {
395 type Item = &'a T;
396
397 fn next(&mut self) -> Option<Self::Item> {
398 if self.current >= self.len {
399 None
400 } else {
401 let item = &self.data[self.offset + self.current];
402 self.current += 1;
403 Some(item)
404 }
405 }
406
407 fn size_hint(&self) -> (usize, Option<usize>) {
408 let remaining = self.len - self.current;
409 (remaining, Some(remaining))
410 }
411}
412
413impl<'a, T: TensorElement> ExactSizeIterator for TensorViewIter<'a, T> {
414 fn len(&self) -> usize {
415 self.len - self.current
416 }
417}
418
419pub struct TensorViewIterMut<'a, T: TensorElement> {
421 data: &'a mut [T],
422 current: usize,
423}
424
425impl<'a, T: TensorElement> Iterator for TensorViewIterMut<'a, T> {
426 type Item = &'a mut T;
427
428 fn next(&mut self) -> Option<Self::Item> {
429 if self.current >= self.data.len() {
430 None
431 } else {
432 let item = unsafe {
433 let ptr = self.data.as_mut_ptr().add(self.current);
435 &mut *ptr
436 };
437 self.current += 1;
438 Some(item)
439 }
440 }
441
442 fn size_hint(&self) -> (usize, Option<usize>) {
443 let remaining = self.data.len() - self.current;
444 (remaining, Some(remaining))
445 }
446}
447
448impl<'a, T: TensorElement> ExactSizeIterator for TensorViewIterMut<'a, T> {
449 fn len(&self) -> usize {
450 self.data.len() - self.current
451 }
452}
453
#[cfg(test)]
mod tests {
    // Unit tests for the read-only and mutable tensor views and their iterators.
    use super::*;

    #[test]
    fn test_tensor_view_creation() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = Shape::new(vec![2, 2]);
        let strides = vec![2, 1];

        let view = TensorView::new(&data, shape, strides, 0);

        assert_eq!(view.len(), 4);
        assert!(!view.is_empty());
        assert_eq!(view.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_tensor_view_contiguous() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = Shape::new(vec![2, 2]);
        let strides = vec![2, 1];

        let view = TensorView::new(&data, shape, strides, 0);
        assert!(view.is_contiguous());
    }

    #[test]
    fn test_tensor_view_get() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = Shape::new(vec![4]);
        let strides = vec![1];

        let view = TensorView::new(&data, shape, strides, 0);

        assert_eq!(*view.get(0).expect("get should succeed"), 1.0);
        assert_eq!(*view.get(1).expect("get should succeed"), 2.0);
        assert_eq!(*view.get(2).expect("get should succeed"), 3.0);
        assert_eq!(*view.get(3).expect("get should succeed"), 4.0);

        assert!(view.get(4).is_err());
    }

    #[test]
    fn test_tensor_view_get_at_row_major() {
        // 2x3 matrix [[1, 2, 3], [4, 5, 6]] stored row-major.
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = TensorView::new(&data, Shape::new(vec![2, 3]), vec![3, 1], 0);

        assert!(view.is_contiguous());
        assert_eq!(*view.get_at(&[0, 1]).expect("get_at should succeed"), 2.0);
        assert_eq!(*view.get_at(&[1, 2]).expect("get_at should succeed"), 6.0);
    }

    #[test]
    fn test_tensor_view_get_at_strided() {
        // Transpose of the 2x3 matrix above: shape [3, 2] over the same storage.
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = TensorView::new(&data, Shape::new(vec![3, 2]), vec![1, 3], 0);

        assert!(!view.is_contiguous());
        assert_eq!(*view.get_at(&[0, 1]).expect("get_at should succeed"), 4.0);
        assert_eq!(*view.get_at(&[2, 1]).expect("get_at should succeed"), 6.0);
    }

    #[test]
    fn test_tensor_view_get_at_rejects_bad_indices() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&data, Shape::new(vec![2, 2]), vec![2, 1], 0);

        assert!(view.get_at(&[0]).is_err(), "too few indices");
        assert!(view.get_at(&[0, 0, 0]).is_err(), "too many indices");
        assert!(view.get_at(&[2, 0]).is_err(), "row out of range");
        assert!(view.get_at(&[0, 2]).is_err(), "column out of range");
    }

    #[test]
    fn test_tensor_view_iter() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = Shape::new(vec![4]);
        let strides = vec![1];

        let view = TensorView::new(&data, shape, strides, 0);
        let collected: Vec<_> = view.iter().copied().collect();

        assert_eq!(collected, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_tensor_view_iter_exact_size() {
        let data = vec![1.0, 2.0, 3.0];
        let view = TensorView::new(&data, Shape::new(vec![3]), vec![1], 0);
        let mut iter = view.iter();

        assert_eq!(iter.len(), 3);
        iter.next();
        assert_eq!(iter.len(), 2);
    }

    #[test]
    fn test_tensor_view_mut_creation() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = Shape::new(vec![2, 2]);
        let strides = vec![2, 1];

        let view = TensorViewMut::new(&mut data, shape, strides, 0);

        assert_eq!(view.len(), 4);
        assert!(!view.is_empty());
    }

    #[test]
    fn test_tensor_view_mut_fill() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = Shape::new(vec![4]);
        let strides = vec![1];

        let mut view = TensorViewMut::new(&mut data, shape, strides, 0);
        view.fill(0.0);

        assert_eq!(data, vec![0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_tensor_view_mut_get_mut() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = Shape::new(vec![4]);
        let strides = vec![1];

        let mut view = TensorViewMut::new(&mut data, shape, strides, 0);

        *view.get_mut(0).expect("get_mut should succeed") = 10.0;
        *view.get_mut(1).expect("get_mut should succeed") = 20.0;

        assert_eq!(data, vec![10.0, 20.0, 3.0, 4.0]);
    }

    #[test]
    fn test_tensor_view_mut_iter_mut() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = Shape::new(vec![4]);
        let strides = vec![1];

        let mut view = TensorViewMut::new(&mut data, shape, strides, 0);

        for elem in view.iter_mut() {
            *elem *= 2.0;
        }

        assert_eq!(data, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_tensor_view_mut_iter_mut_collect_all() {
        let mut data = vec![1.0, 2.0, 3.0];
        let mut view = TensorViewMut::new(&mut data, Shape::new(vec![3]), vec![1], 0);

        // Every yielded reference must stay usable while the others are alive.
        let refs: Vec<&mut f64> = view.iter_mut().collect();
        assert_eq!(refs.len(), 3);
        for elem in refs {
            *elem += 1.0;
        }

        assert_eq!(data, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_tensor_view_mut_with_offset() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut view = TensorViewMut::new(&mut data, Shape::new(vec![3]), vec![1], 1);

        assert_eq!(view.data_mut().to_vec(), vec![2.0, 3.0, 4.0]);
        view.data_mut()[0] = 20.0;
        assert_eq!(*view.get(0).expect("get should succeed"), 20.0);
        assert!(view.get(3).is_err());
        view.fill(7.0);

        assert_eq!(data, vec![1.0, 7.0, 7.0, 7.0, 5.0]);
    }

    #[test]
    fn test_tensor_view_to_vec() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = Shape::new(vec![4]);
        let strides = vec![1];

        let view = TensorView::new(&data, shape, strides, 0);
        let copied = view.to_vec();

        assert_eq!(copied, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_tensor_view_with_offset() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = Shape::new(vec![2]);
        let strides = vec![1];

        // The view starts at storage position 2 and covers two elements.
        let view = TensorView::new(&data, shape, strides, 2);

        assert_eq!(view.len(), 2);
        assert_eq!(*view.get(0).expect("get should succeed"), 3.0);
        assert_eq!(*view.get(1).expect("get should succeed"), 4.0);
    }
}