// torsh_tensor/tensor_views.rs
1// Tensor views and aliasing module for efficient memory management
2
3use crate::{Tensor, TensorStorage};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock, Weak};
6use torsh_core::{
7    device::DeviceType,
8    dtype::TensorElement,
9    error::{Result, TorshError},
10    shape::Shape,
11};
12
/// A view into a tensor that shares memory but may have different shape/strides
///
/// A `TensorView` does not own its elements: it holds an `Arc` to a shared
/// `ViewStorage` plus the (shape, strides, offset) triple that maps view
/// indices onto the flat underlying buffer.
#[derive(Debug, Clone)]
pub struct TensorView<T: TensorElement> {
    /// Reference to the underlying tensor storage (shared by sibling views)
    storage: Arc<RwLock<ViewStorage<T>>>,
    /// Shape of this view
    shape: Shape,
    /// Strides for this view, in elements (not bytes), one per dimension
    strides: Vec<usize>,
    /// Offset into the underlying data, in elements
    offset: usize,
    /// Device the parent tensor lives on
    device: DeviceType,
}
27
/// Storage for tensor views with reference counting
#[derive(Debug)]
struct ViewStorage<T: TensorElement> {
    /// Weak reference to the parent data to avoid an `Arc` reference cycle
    #[allow(dead_code)]
    parent: Weak<RwLock<Vec<T>>>,
    /// Strong reference that keeps the data alive while this view exists;
    /// `None` means the backing data is no longer available
    data_ref: Option<Arc<RwLock<Vec<T>>>>,
    /// Cache for computed views, keyed by (shape, strides, offset)
    // NOTE(review): never populated anywhere in this module — confirm intended use
    view_cache: HashMap<ViewKey, Arc<TensorView<T>>>,
    /// Number of active views on this storage (reported in memory-usage stats)
    view_count: usize,
}
41
/// Key for view caching: uniquely identifies a view layout over one buffer
/// by its shape, strides, and element offset.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
struct ViewKey {
    shape: Vec<usize>,
    strides: Vec<usize>,
    offset: usize,
}
49
50impl<T: TensorElement + Copy> Tensor<T> {
51    /// Calculate strides for current tensor shape
52    pub fn calculate_strides(&self) -> Vec<usize> {
53        let shape_binding = self.shape();
54        let dims = shape_binding.dims();
55        let mut strides = vec![1; dims.len()];
56        for i in (0..dims.len().saturating_sub(1)).rev() {
57            strides[i] = strides[i + 1] * dims[i + 1];
58        }
59        strides
60    }
61    /// Create a view of this tensor with a new shape (must have same number of elements)
62    pub fn create_view(&self, new_shape: &[usize]) -> Result<TensorView<T>> {
63        let new_numel = new_shape.iter().product::<usize>();
64        if new_numel != self.numel() {
65            return Err(TorshError::InvalidOperation(format!(
66                "View shape {:?} has {} elements, but tensor has {} elements",
67                new_shape,
68                new_numel,
69                self.numel()
70            )));
71        }
72
73        // Calculate strides for the new shape (row-major)
74        let mut strides = vec![1; new_shape.len()];
75        for i in (0..new_shape.len().saturating_sub(1)).rev() {
76            strides[i] = strides[i + 1] * new_shape[i + 1];
77        }
78
79        self.create_view_with_strides(new_shape, &strides, 0)
80    }
81
82    /// Create a view with custom strides (advanced usage)
83    pub fn view_with_strides(
84        &self,
85        new_shape: &[usize],
86        strides: &[usize],
87    ) -> Result<TensorView<T>> {
88        if new_shape.len() != strides.len() {
89            return Err(TorshError::InvalidOperation(
90                "Shape and strides must have same length".to_string(),
91            ));
92        }
93
94        self.create_view_with_strides(new_shape, strides, 0)
95    }
96
97    /// Create a slice view of the tensor along a specific dimension
98    pub fn slice(&self, dim: usize, start: usize, end: usize) -> Result<TensorView<T>> {
99        let shape_binding = self.shape();
100        let dims = shape_binding.dims();
101        if dim >= dims.len() {
102            return Err(TorshError::InvalidOperation(format!(
103                "Dimension {} out of bounds for tensor with {} dimensions",
104                dim,
105                dims.len()
106            )));
107        }
108
109        if start >= end || end > dims[dim] {
110            return Err(TorshError::InvalidOperation(format!(
111                "Invalid slice range [{}:{}] for dimension of size {}",
112                start, end, dims[dim]
113            )));
114        }
115
116        // Calculate new shape and offset
117        let mut new_shape = dims.to_vec();
118        new_shape[dim] = end - start;
119
120        // Calculate offset for the slice
121        let strides = self.calculate_strides();
122        let offset = start * strides[dim];
123
124        self.create_view_with_strides(&new_shape, &strides, offset)
125    }
126
127    /// Internal method to create views with custom strides and offset
128    fn create_view_with_strides(
129        &self,
130        shape: &[usize],
131        strides: &[usize],
132        offset: usize,
133    ) -> Result<TensorView<T>> {
134        // Get reference to underlying data
135        let data_ref = match &self.storage {
136            TensorStorage::InMemory(data) => data.clone(),
137            TensorStorage::MemoryMapped(_) => {
138                // For memory-mapped storage, convert to in-memory for views
139                let data = self.to_vec()?;
140                Arc::new(RwLock::new(data))
141            }
142            #[cfg(feature = "simd")]
143            TensorStorage::Aligned(data) => {
144                // Convert AlignedVec to Vec for standard view handling
145                let aligned_data = data.read().expect("lock should not be poisoned");
146                let vec_data = aligned_data.as_slice().to_vec();
147                Arc::new(RwLock::new(vec_data))
148            }
149            #[cfg(feature = "simd")]
150            TensorStorage::SimdOptimized(storage) => {
151                // Lock-free access - convert to Vec for view handling
152                let vec_data = storage.as_slice().to_vec();
153                Arc::new(RwLock::new(vec_data))
154            }
155        };
156
157        // Create view storage
158        let view_storage = ViewStorage {
159            parent: Arc::downgrade(&data_ref),
160            data_ref: Some(data_ref),
161            view_cache: HashMap::new(),
162            view_count: 1,
163        };
164
165        Ok(TensorView {
166            storage: Arc::new(RwLock::new(view_storage)),
167            shape: Shape::new(shape.to_vec()),
168            strides: strides.to_vec(),
169            offset,
170            device: self.device,
171        })
172    }
173
174    /// Create an alias (shared reference) to this tensor
175    pub fn alias(&self) -> TensorAlias<T> {
176        TensorAlias {
177            tensor: self.clone(),
178            is_mutable: false,
179        }
180    }
181
182    /// Create a mutable alias to this tensor
183    pub fn alias_mut(&mut self) -> TensorAlias<T> {
184        TensorAlias {
185            tensor: self.clone(),
186            is_mutable: true,
187        }
188    }
189}
190
191impl<T: TensorElement + Copy> TensorView<T> {
192    /// Get the shape of this view
193    pub fn shape(&self) -> &Shape {
194        &self.shape
195    }
196
197    /// Get the strides of this view
198    pub fn strides(&self) -> &[usize] {
199        &self.strides
200    }
201
202    /// Get the offset of this view
203    pub fn offset(&self) -> usize {
204        self.offset
205    }
206
207    /// Convert view to a contiguous tensor
208    pub fn to_tensor(&self) -> Result<Tensor<T>> {
209        let data = self.to_vec()?;
210        Tensor::from_data(data, self.shape.dims().to_vec(), self.device)
211    }
212
213    /// Get data as vector (materializes the view)
214    pub fn to_vec(&self) -> Result<Vec<T>> {
215        let storage = self.storage.read().expect("lock should not be poisoned");
216        if let Some(data_ref) = &storage.data_ref {
217            let data = data_ref.read().expect("lock should not be poisoned");
218            let mut result = Vec::with_capacity(self.shape.numel());
219
220            // Extract data according to view's shape, strides, and offset
221            self.extract_view_data(&data, &mut result, &mut vec![0; self.shape.ndim()], 0)?;
222
223            Ok(result)
224        } else {
225            Err(TorshError::InvalidOperation(
226                "View data no longer available".to_string(),
227            ))
228        }
229    }
230
231    /// Recursively extract data for the view
232    fn extract_view_data(
233        &self,
234        data: &[T],
235        result: &mut Vec<T>,
236        indices: &mut [usize],
237        dim: usize,
238    ) -> Result<()> {
239        if dim == self.shape.ndim() {
240            // Calculate flat index from view indices
241            let flat_index = self.offset
242                + indices
243                    .iter()
244                    .zip(self.strides.iter())
245                    .map(|(&idx, &stride)| idx * stride)
246                    .sum::<usize>();
247
248            if flat_index < data.len() {
249                result.push(data[flat_index]);
250            } else {
251                return Err(TorshError::InvalidOperation(
252                    "View index out of bounds".to_string(),
253                ));
254            }
255        } else {
256            for i in 0..self.shape.dims()[dim] {
257                indices[dim] = i;
258                self.extract_view_data(data, result, indices, dim + 1)?;
259            }
260        }
261        Ok(())
262    }
263
264    /// Check if this view is contiguous in memory
265    pub fn is_contiguous(&self) -> bool {
266        // A view is contiguous if its strides match row-major layout
267        let dims = self.shape.dims();
268        let mut expected_strides = vec![1; dims.len()];
269        for i in (0..dims.len().saturating_sub(1)).rev() {
270            expected_strides[i] = expected_strides[i + 1] * dims[i + 1];
271        }
272        self.strides == expected_strides
273    }
274
275    /// Check if this is a view (always true for TensorView)
276    pub fn is_view(&self) -> bool {
277        true
278    }
279
280    /// Get element at specific indices
281    pub fn get(&self, indices: &[usize]) -> Result<T> {
282        if indices.len() != self.shape.ndim() {
283            return Err(TorshError::InvalidOperation(format!(
284                "Expected {} indices, got {}",
285                self.shape.ndim(),
286                indices.len()
287            )));
288        }
289
290        for (i, &idx) in indices.iter().enumerate() {
291            if idx >= self.shape.dims()[i] {
292                return Err(TorshError::InvalidOperation(format!(
293                    "Index {} out of bounds for dimension {} (size {})",
294                    idx,
295                    i,
296                    self.shape.dims()[i]
297                )));
298            }
299        }
300
301        let storage = self.storage.read().expect("lock should not be poisoned");
302        if let Some(data_ref) = &storage.data_ref {
303            let data = data_ref.read().expect("lock should not be poisoned");
304
305            // Calculate flat index from view indices
306            let flat_index = self.offset
307                + indices
308                    .iter()
309                    .zip(self.strides.iter())
310                    .map(|(&idx, &stride)| idx * stride)
311                    .sum::<usize>();
312
313            if flat_index < data.len() {
314                Ok(data[flat_index])
315            } else {
316                Err(TorshError::InvalidOperation(
317                    "View index out of bounds".to_string(),
318                ))
319            }
320        } else {
321            Err(TorshError::InvalidOperation(
322                "View data no longer available".to_string(),
323            ))
324        }
325    }
326
327    /// Get memory usage of this view
328    pub fn view_memory_usage(&self) -> ViewMemoryUsage {
329        let storage = self.storage.read().expect("lock should not be poisoned");
330        ViewMemoryUsage {
331            view_elements: self.shape.numel(),
332            total_elements: storage
333                .data_ref
334                .as_ref()
335                .map(|data| data.read().expect("lock should not be poisoned").len())
336                .unwrap_or(0),
337            active_views: storage.view_count,
338            is_contiguous: self.is_contiguous(),
339            memory_efficiency: self.calculate_memory_efficiency(),
340        }
341    }
342
343    /// Calculate memory efficiency of this view
344    fn calculate_memory_efficiency(&self) -> f64 {
345        let view_size = self.shape.numel();
346        let storage = self.storage.read().expect("lock should not be poisoned");
347        let total_size = storage
348            .data_ref
349            .as_ref()
350            .map(|data| data.read().expect("lock should not be poisoned").len())
351            .unwrap_or(1);
352
353        view_size as f64 / total_size as f64
354    }
355}
356
/// An alias to a tensor that shares memory
///
/// Wraps a cloned tensor handle together with a flag recording whether the
/// alias was created with mutable intent (`alias_mut`) or not (`alias`).
#[derive(Debug, Clone)]
pub struct TensorAlias<T: TensorElement> {
    // Cloned handle to the aliased tensor (sharing semantics depend on
    // Tensor's Clone impl, defined elsewhere).
    tensor: Tensor<T>,
    // True when created via `alias_mut`; advisory only.
    is_mutable: bool,
}
363
364impl<T: TensorElement + Copy> TensorAlias<T> {
365    /// Get reference to the underlying tensor
366    pub fn tensor(&self) -> &Tensor<T> {
367        &self.tensor
368    }
369
370    /// Check if this alias allows mutation
371    pub fn is_mutable(&self) -> bool {
372        self.is_mutable
373    }
374
375    /// Convert to owned tensor (creates copy if shared)
376    pub fn to_owned(&self) -> Result<Tensor<T>> {
377        Ok(self.tensor.clone())
378    }
379
380    /// Get the reference count of the underlying data
381    pub fn ref_count(&self) -> usize {
382        match &self.tensor.storage {
383            TensorStorage::InMemory(data) => Arc::strong_count(data),
384            TensorStorage::MemoryMapped(storage) => Arc::strong_count(storage),
385            #[cfg(feature = "simd")]
386            TensorStorage::Aligned(data) => Arc::strong_count(data),
387            #[cfg(feature = "simd")]
388            TensorStorage::SimdOptimized(storage) => Arc::strong_count(storage),
389        }
390    }
391
392    /// Check if this alias has exclusive access to the data
393    pub fn is_unique(&self) -> bool {
394        self.ref_count() == 1
395    }
396}
397
/// Memory usage information for tensor views
#[derive(Debug, Clone)]
pub struct ViewMemoryUsage {
    /// Number of elements in this view
    pub view_elements: usize,
    /// Total elements in underlying storage (0 if the data was dropped)
    pub total_elements: usize,
    /// Number of active views on this storage
    pub active_views: usize,
    /// Whether the view is contiguous (row-major strides) in memory
    pub is_contiguous: bool,
    /// Memory efficiency (view_size / total_size)
    pub memory_efficiency: f64,
}
412
413impl<T: TensorElement + Copy> Drop for ViewStorage<T> {
414    fn drop(&mut self) {
415        // Clean up view cache and decrement reference counts
416        self.view_cache.clear();
417        self.view_count = 0;
418    }
419}
420
#[cfg(test)]
mod tests {
    use crate::creation::*;

    /// Reshaping view keeps the element count and reports the new dims.
    #[test]
    fn test_tensor_view() {
        let base = ones::<f32>(&[2, 3, 4]).expect("ones creation should succeed");
        let flat = base
            .create_view(&[6, 4])
            .expect("create_view should succeed");
        assert_eq!(flat.shape().dims(), &[6, 4]);
        assert_eq!(flat.shape().numel(), 24);
    }

    /// A 1-D range tensor can be viewed as a 2-D matrix.
    #[test]
    fn test_tensor_slice() {
        let ramp = arange(0.0f32, 12.0, 1.0).expect("arange should succeed");
        let _matrix = ramp
            .create_view(&[3, 4])
            .expect("create_view should succeed");
        // This would work in a full implementation
        // let slice = reshaped.slice(0, 1, 3).unwrap();
        // assert_eq!(slice.shape().dims(), &[2, 4]);
    }

    /// Squeeze removes size-1 dims; unsqueeze inserts one.
    #[test]
    fn test_tensor_squeeze_unsqueeze() {
        let padded = ones::<f32>(&[1, 3, 1, 4]).expect("ones creation should succeed");

        let dropped_first = padded.squeeze(0).expect("squeeze should succeed");
        assert_eq!(dropped_first.shape().dims(), &[3, 1, 4]);

        let compact = padded.squeeze_all().expect("squeeze_all should succeed");
        assert_eq!(compact.shape().dims(), &[3, 4]);

        let widened = padded.unsqueeze(2).expect("unsqueeze should succeed");
        assert_eq!(widened.shape().dims(), &[1, 3, 1, 1, 4]);
    }

    /// Permute reorders dimensions according to the given axes.
    #[test]
    fn test_tensor_permute() {
        let cube = ones::<f32>(&[2, 3, 4]).expect("ones creation should succeed");
        let rotated = cube.permute(&[2, 0, 1]).expect("permute should succeed");
        assert_eq!(rotated.shape().dims(), &[4, 2, 3]);
    }

    /// Aliasing bumps the strong count and defaults to read intent.
    #[test]
    fn test_tensor_alias() {
        let shared = ones::<f32>(&[10, 10]).expect("ones creation should succeed");
        let handle = shared.alias();
        assert!(!handle.is_mutable());
        assert!(handle.ref_count() >= 2); // Original + alias
    }

    /// A full-tensor view reports 100% memory efficiency.
    #[test]
    fn test_view_memory_usage() {
        let big = ones::<f32>(&[100, 100]).expect("ones creation should succeed");
        let reshaped = big
            .create_view(&[1000, 10])
            .expect("create_view should succeed");
        let stats = reshaped.view_memory_usage();
        assert_eq!(stats.view_elements, 10000);
        assert_eq!(stats.memory_efficiency, 1.0); // Full tensor view
    }
}