Skip to main content

torsh_tensor/
tensor_view.rs

1//! Zero-Copy Tensor Views for ToRSh
2//!
3//! This module provides immutable and mutable views into tensor data without copying.
4//! Views enable efficient SIMD operations by providing direct access to underlying buffers.
5//!
6//! # Design Goals
7//! - Zero-copy operations (no allocations)
8//! - Direct buffer access for SIMD
9//! - PyTorch-compatible API
10//! - Memory-safe through Rust's borrow checker
11//!
12//! # Architecture
13//! - `TensorView<'a, T>`: Immutable view (multiple readers allowed)
14//! - `TensorViewMut<'a, T>`: Mutable view (exclusive access)
15//!
16//! # Performance Impact
17//! - Eliminates 4 memory copies for SIMD operations
18//! - Enables 2-4x SIMD speedup (per SciRS2 docs)
19//! - Reduces memory allocations by 90%
20
21use torsh_core::dtype::TensorElement;
22use torsh_core::error::{Result, TorshError};
23use torsh_core::shape::Shape;
24
25/// Immutable view into tensor data (zero-copy)
26///
27/// Provides read-only access to tensor data without copying.
28/// Multiple immutable views can coexist (shared borrowing).
29///
30/// # Examples
31/// ```ignore
32/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
33/// let view = tensor.view();
34/// assert_eq!(view.len(), 4);
35/// ```
36#[derive(Debug)]
37pub struct TensorView<'a, T: TensorElement> {
38    /// Direct reference to underlying buffer (zero-copy)
39    data: &'a [T],
40
41    /// Shape information
42    shape: Shape,
43
44    /// Strides for multi-dimensional indexing
45    strides: Vec<usize>,
46
47    /// Offset into the parent buffer
48    offset: usize,
49}
50
51/// Mutable view into tensor data (zero-copy)
52///
53/// Provides read-write access to tensor data without copying.
54/// Only one mutable view can exist at a time (exclusive borrowing).
55///
56/// # Examples
57/// ```ignore
58/// let mut tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
59/// let mut view = tensor.view_mut();
60/// view.fill(0.0);
61/// ```
62#[derive(Debug)]
63pub struct TensorViewMut<'a, T: TensorElement> {
64    /// Direct mutable reference to underlying buffer (zero-copy)
65    data: &'a mut [T],
66
67    /// Shape information
68    shape: Shape,
69
70    /// Strides for multi-dimensional indexing
71    strides: Vec<usize>,
72
73    /// Offset into the parent buffer
74    offset: usize,
75}
76
77// ============================================================================
78// TensorView Implementation (Immutable)
79// ============================================================================
80
81impl<'a, T: TensorElement> TensorView<'a, T> {
82    /// Create a new tensor view from raw parts
83    ///
84    /// # Arguments
85    /// * `data` - Reference to underlying buffer
86    /// * `shape` - Shape of the view
87    /// * `strides` - Strides for indexing
88    /// * `offset` - Offset into parent buffer
89    ///
90    /// # Safety
91    /// Caller must ensure:
92    /// - `data` is valid for the lifetime 'a
93    /// - `shape`, `strides`, and `offset` define valid indexing
94    pub fn new(data: &'a [T], shape: Shape, strides: Vec<usize>, offset: usize) -> Self {
95        Self {
96            data,
97            shape,
98            strides,
99            offset,
100        }
101    }
102
103    /// Get the shape of the view
104    #[inline]
105    pub fn shape(&self) -> &Shape {
106        &self.shape
107    }
108
109    /// Get the strides of the view
110    #[inline]
111    pub fn strides(&self) -> &[usize] {
112        &self.strides
113    }
114
115    /// Get the number of elements in the view
116    #[inline]
117    pub fn len(&self) -> usize {
118        self.shape.numel()
119    }
120
121    /// Check if the view is empty
122    #[inline]
123    pub fn is_empty(&self) -> bool {
124        self.len() == 0
125    }
126
127    /// Get a reference to the underlying data slice
128    ///
129    /// Returns the raw data slice starting from the offset.
130    /// For SIMD operations, this provides direct buffer access.
131    #[inline]
132    pub fn data(&self) -> &[T] {
133        &self.data[self.offset..]
134    }
135
136    /// Check if the view is contiguous in memory
137    ///
138    /// Contiguous views enable fast SIMD operations without gather/scatter.
139    pub fn is_contiguous(&self) -> bool {
140        if self.shape.dims().is_empty() {
141            return true;
142        }
143
144        let dims = self.shape.dims();
145        let mut expected_stride = 1;
146
147        // Check from innermost to outermost dimension
148        for i in (0..dims.len()).rev() {
149            if self.strides[i] != expected_stride {
150                return false;
151            }
152            expected_stride *= dims[i];
153        }
154
155        true
156    }
157
158    /// Get element at flat index (zero-copy)
159    ///
160    /// # Arguments
161    /// * `index` - Flat index into the view
162    ///
163    /// # Returns
164    /// Reference to element at index
165    ///
166    /// # Errors
167    /// Returns error if index is out of bounds
168    pub fn get(&self, index: usize) -> Result<&T> {
169        if index >= self.len() {
170            return Err(TorshError::IndexError {
171                index,
172                size: self.len(),
173            });
174        }
175
176        Ok(&self.data[self.offset + index])
177    }
178
179    /// Get element at multi-dimensional index (zero-copy)
180    ///
181    /// # Arguments
182    /// * `indices` - Multi-dimensional indices
183    ///
184    /// # Returns
185    /// Reference to element at indices
186    ///
187    /// # Errors
188    /// Returns error if indices are out of bounds
189    pub fn get_at(&self, indices: &[usize]) -> Result<&T> {
190        if indices.len() != self.shape.ndim() {
191            return Err(TorshError::InvalidArgument(format!(
192                "Expected {} indices, got {}",
193                self.shape.ndim(),
194                indices.len()
195            )));
196        }
197
198        let flat_index = self.compute_flat_index(indices)?;
199        Ok(&self.data[self.offset + flat_index])
200    }
201
202    /// Compute flat index from multi-dimensional indices
203    fn compute_flat_index(&self, indices: &[usize]) -> Result<usize> {
204        let dims = self.shape.dims();
205        let mut flat_index = 0;
206
207        for (i, &idx) in indices.iter().enumerate() {
208            if idx >= dims[i] {
209                return Err(TorshError::IndexError {
210                    index: idx,
211                    size: dims[i],
212                });
213            }
214            flat_index += idx * self.strides[i];
215        }
216
217        Ok(flat_index)
218    }
219
220    /// Create an iterator over the view's elements
221    pub fn iter(&self) -> TensorViewIter<'a, T> {
222        TensorViewIter {
223            data: self.data,
224            offset: self.offset,
225            len: self.len(),
226            current: 0,
227        }
228    }
229
230    /// Convert view to a Vec (copies data)
231    ///
232    /// Note: This creates a copy. For zero-copy operations, use the view directly.
233    pub fn to_vec(&self) -> Vec<T>
234    where
235        T: Copy,
236    {
237        self.data[self.offset..self.offset + self.len()].to_vec()
238    }
239}
240
241// ============================================================================
242// TensorViewMut Implementation (Mutable)
243// ============================================================================
244
245impl<'a, T: TensorElement> TensorViewMut<'a, T> {
246    /// Create a new mutable tensor view from raw parts
247    ///
248    /// # Arguments
249    /// * `data` - Mutable reference to underlying buffer
250    /// * `shape` - Shape of the view
251    /// * `strides` - Strides for indexing
252    /// * `offset` - Offset into parent buffer
253    pub fn new(data: &'a mut [T], shape: Shape, strides: Vec<usize>, offset: usize) -> Self {
254        Self {
255            data,
256            shape,
257            strides,
258            offset,
259        }
260    }
261
262    /// Get the shape of the view
263    #[inline]
264    pub fn shape(&self) -> &Shape {
265        &self.shape
266    }
267
268    /// Get the strides of the view
269    #[inline]
270    pub fn strides(&self) -> &[usize] {
271        &self.strides
272    }
273
274    /// Get the number of elements in the view
275    #[inline]
276    pub fn len(&self) -> usize {
277        self.shape.numel()
278    }
279
280    /// Check if the view is empty
281    #[inline]
282    pub fn is_empty(&self) -> bool {
283        self.len() == 0
284    }
285
286    /// Get a reference to the underlying data slice
287    #[inline]
288    pub fn data(&self) -> &[T] {
289        &self.data[self.offset..]
290    }
291
292    /// Get a mutable reference to the underlying data slice
293    ///
294    /// For in-place SIMD operations, this provides direct buffer access.
295    #[inline]
296    pub fn data_mut(&mut self) -> &mut [T] {
297        let len = self.len();
298        &mut self.data[self.offset..self.offset + len]
299    }
300
301    /// Check if the view is contiguous in memory
302    pub fn is_contiguous(&self) -> bool {
303        if self.shape.dims().is_empty() {
304            return true;
305        }
306
307        let dims = self.shape.dims();
308        let mut expected_stride = 1;
309
310        for i in (0..dims.len()).rev() {
311            if self.strides[i] != expected_stride {
312                return false;
313            }
314            expected_stride *= dims[i];
315        }
316
317        true
318    }
319
320    /// Get element at flat index (zero-copy)
321    pub fn get(&self, index: usize) -> Result<&T> {
322        if index >= self.len() {
323            return Err(TorshError::IndexError {
324                index,
325                size: self.len(),
326            });
327        }
328
329        Ok(&self.data[self.offset + index])
330    }
331
332    /// Get mutable element at flat index (zero-copy)
333    pub fn get_mut(&mut self, index: usize) -> Result<&mut T> {
334        if index >= self.len() {
335            return Err(TorshError::IndexError {
336                index,
337                size: self.len(),
338            });
339        }
340
341        Ok(&mut self.data[self.offset + index])
342    }
343
344    /// Fill the view with a value (in-place)
345    ///
346    /// # Arguments
347    /// * `value` - Value to fill with
348    ///
349    /// # Examples
350    /// ```ignore
351    /// let mut view = tensor.view_mut();
352    /// view.fill(0.0);
353    /// ```
354    pub fn fill(&mut self, value: T)
355    where
356        T: Copy,
357    {
358        let len = self.len();
359        self.data[self.offset..self.offset + len].fill(value);
360    }
361
362    /// Create an iterator over the view's elements
363    pub fn iter(&self) -> TensorViewIter<'_, T> {
364        TensorViewIter {
365            data: self.data,
366            offset: self.offset,
367            len: self.len(),
368            current: 0,
369        }
370    }
371
372    /// Create a mutable iterator over the view's elements
373    pub fn iter_mut(&mut self) -> TensorViewIterMut<'_, T> {
374        let len = self.len();
375        TensorViewIterMut {
376            data: &mut self.data[self.offset..self.offset + len],
377            current: 0,
378        }
379    }
380}
381
382// ============================================================================
383// Iterator Implementations
384// ============================================================================
385
386/// Iterator over immutable tensor view
387pub struct TensorViewIter<'a, T: TensorElement> {
388    data: &'a [T],
389    offset: usize,
390    len: usize,
391    current: usize,
392}
393
394impl<'a, T: TensorElement> Iterator for TensorViewIter<'a, T> {
395    type Item = &'a T;
396
397    fn next(&mut self) -> Option<Self::Item> {
398        if self.current >= self.len {
399            None
400        } else {
401            let item = &self.data[self.offset + self.current];
402            self.current += 1;
403            Some(item)
404        }
405    }
406
407    fn size_hint(&self) -> (usize, Option<usize>) {
408        let remaining = self.len - self.current;
409        (remaining, Some(remaining))
410    }
411}
412
413impl<'a, T: TensorElement> ExactSizeIterator for TensorViewIter<'a, T> {
414    fn len(&self) -> usize {
415        self.len - self.current
416    }
417}
418
419/// Mutable iterator over tensor view
420pub struct TensorViewIterMut<'a, T: TensorElement> {
421    data: &'a mut [T],
422    current: usize,
423}
424
425impl<'a, T: TensorElement> Iterator for TensorViewIterMut<'a, T> {
426    type Item = &'a mut T;
427
428    fn next(&mut self) -> Option<Self::Item> {
429        if self.current >= self.data.len() {
430            None
431        } else {
432            let item = unsafe {
433                // SAFETY: We ensure current < len and never return the same reference twice
434                let ptr = self.data.as_mut_ptr().add(self.current);
435                &mut *ptr
436            };
437            self.current += 1;
438            Some(item)
439        }
440    }
441
442    fn size_hint(&self) -> (usize, Option<usize>) {
443        let remaining = self.data.len() - self.current;
444        (remaining, Some(remaining))
445    }
446}
447
448impl<'a, T: TensorElement> ExactSizeIterator for TensorViewIterMut<'a, T> {
449    fn len(&self) -> usize {
450        self.data.len() - self.current
451    }
452}
453
454// ============================================================================
455// Tests
456// ============================================================================
457
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x2 immutable view reports its element count and shape.
    #[test]
    fn test_tensor_view_creation() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&data, Shape::new(vec![2, 2]), vec![2, 1], 0);

        assert_eq!(view.len(), 4);
        assert!(!view.is_empty());
        assert_eq!(view.shape().dims(), &[2, 2]);
    }

    /// Row-major strides on a 2x2 view are recognised as contiguous.
    #[test]
    fn test_tensor_view_contiguous() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&data, Shape::new(vec![2, 2]), vec![2, 1], 0);

        assert!(view.is_contiguous());
    }

    /// Flat indexing returns each element and rejects the first out-of-range index.
    #[test]
    fn test_tensor_view_get() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&data, Shape::new(vec![4]), vec![1], 0);

        let expected = [1.0, 2.0, 3.0, 4.0];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(*view.get(index).expect("get should succeed"), *want);
        }

        assert!(view.get(4).is_err());
    }

    /// Iterating a 1-D view visits the elements in order.
    #[test]
    fn test_tensor_view_iter() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&data, Shape::new(vec![4]), vec![1], 0);

        let collected: Vec<_> = view.iter().copied().collect();

        assert_eq!(collected, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// A 2x2 mutable view reports its element count.
    #[test]
    fn test_tensor_view_mut_creation() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorViewMut::new(&mut data, Shape::new(vec![2, 2]), vec![2, 1], 0);

        assert_eq!(view.len(), 4);
        assert!(!view.is_empty());
    }

    /// `fill` overwrites every element of the backing buffer of a full-size view.
    #[test]
    fn test_tensor_view_mut_fill() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];

        let mut view = TensorViewMut::new(&mut data, Shape::new(vec![4]), vec![1], 0);
        view.fill(0.0);

        assert_eq!(data, vec![0.0; 4]);
    }

    /// Writes through `get_mut` land in the backing buffer.
    #[test]
    fn test_tensor_view_mut_get_mut() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];

        let mut view = TensorViewMut::new(&mut data, Shape::new(vec![4]), vec![1], 0);
        for (index, replacement) in [10.0, 20.0].iter().enumerate() {
            *view.get_mut(index).expect("get_mut should succeed") = *replacement;
        }

        assert_eq!(data, vec![10.0, 20.0, 3.0, 4.0]);
    }

    /// Mutating through `iter_mut` updates every element in place.
    #[test]
    fn test_tensor_view_mut_iter_mut() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];

        let mut view = TensorViewMut::new(&mut data, Shape::new(vec![4]), vec![1], 0);
        view.iter_mut().for_each(|elem| *elem *= 2.0);

        assert_eq!(data, vec![2.0, 4.0, 6.0, 8.0]);
    }

    /// `to_vec` copies the view's elements into a fresh vector.
    #[test]
    fn test_tensor_view_to_vec() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&data, Shape::new(vec![4]), vec![1], 0);

        assert_eq!(view.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// A view created with an offset indexes relative to that offset.
    #[test]
    fn test_tensor_view_with_offset() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        // Two elements starting at buffer position 2.
        let view = TensorView::new(&data, Shape::new(vec![2]), vec![1], 2);

        assert_eq!(view.len(), 2);
        for (index, want) in [3.0, 4.0].iter().enumerate() {
            assert_eq!(*view.get(index).expect("get should succeed"), *want);
        }
    }
}