1use crate::{Tensor, TensorStorage};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock, Weak};
6use torsh_core::{
7 device::DeviceType,
8 dtype::TensorElement,
9 error::{Result, TorshError},
10 shape::Shape,
11};
12
/// A strided view over a tensor's backing buffer.
///
/// A view reinterprets the underlying data through its own `shape`,
/// `strides`, and base `offset` without copying elements (but see
/// `Tensor::create_view_with_strides`: non-`InMemory` storage is
/// materialized into a fresh buffer first).
#[derive(Debug, Clone)]
pub struct TensorView<T: TensorElement> {
    // Shared handle to the view's storage bookkeeping (backing data,
    // view cache, live-view count).
    storage: Arc<RwLock<ViewStorage<T>>>,
    // Logical shape of the view.
    shape: Shape,
    // Per-dimension stride, measured in elements (not bytes).
    strides: Vec<usize>,
    // Offset (in elements) into the backing buffer where the view starts.
    offset: usize,
    // Device of the parent tensor at view-creation time.
    device: DeviceType,
}
27
/// Internal bookkeeping shared by views over the same backing buffer.
#[derive(Debug)]
struct ViewStorage<T: TensorElement> {
    // Weak back-reference to the parent buffer; kept so the parent could be
    // observed without extending its lifetime, though nothing reads it yet.
    #[allow(dead_code)]
    parent: Weak<RwLock<Vec<T>>>,
    // Strong handle keeping the viewed data alive; `None` means released.
    data_ref: Option<Arc<RwLock<Vec<T>>>>,
    // Cache of derived views keyed by (shape, strides, offset).
    // NOTE(review): never populated anywhere in this file — confirm whether
    // caching is still planned or this field can be removed.
    view_cache: HashMap<ViewKey, Arc<TensorView<T>>>,
    // Number of live views sharing this storage (set to 1 at creation).
    view_count: usize,
}
41
/// Hashable identity of a view: the (shape, strides, offset) triple,
/// used as the key type for `ViewStorage::view_cache`.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
struct ViewKey {
    shape: Vec<usize>,
    strides: Vec<usize>,
    offset: usize,
}
49
50impl<T: TensorElement + Copy> Tensor<T> {
51 pub fn calculate_strides(&self) -> Vec<usize> {
53 let shape_binding = self.shape();
54 let dims = shape_binding.dims();
55 let mut strides = vec![1; dims.len()];
56 for i in (0..dims.len().saturating_sub(1)).rev() {
57 strides[i] = strides[i + 1] * dims[i + 1];
58 }
59 strides
60 }
61 pub fn create_view(&self, new_shape: &[usize]) -> Result<TensorView<T>> {
63 let new_numel = new_shape.iter().product::<usize>();
64 if new_numel != self.numel() {
65 return Err(TorshError::InvalidOperation(format!(
66 "View shape {:?} has {} elements, but tensor has {} elements",
67 new_shape,
68 new_numel,
69 self.numel()
70 )));
71 }
72
73 let mut strides = vec![1; new_shape.len()];
75 for i in (0..new_shape.len().saturating_sub(1)).rev() {
76 strides[i] = strides[i + 1] * new_shape[i + 1];
77 }
78
79 self.create_view_with_strides(new_shape, &strides, 0)
80 }
81
82 pub fn view_with_strides(
84 &self,
85 new_shape: &[usize],
86 strides: &[usize],
87 ) -> Result<TensorView<T>> {
88 if new_shape.len() != strides.len() {
89 return Err(TorshError::InvalidOperation(
90 "Shape and strides must have same length".to_string(),
91 ));
92 }
93
94 self.create_view_with_strides(new_shape, strides, 0)
95 }
96
97 pub fn slice(&self, dim: usize, start: usize, end: usize) -> Result<TensorView<T>> {
99 let shape_binding = self.shape();
100 let dims = shape_binding.dims();
101 if dim >= dims.len() {
102 return Err(TorshError::InvalidOperation(format!(
103 "Dimension {} out of bounds for tensor with {} dimensions",
104 dim,
105 dims.len()
106 )));
107 }
108
109 if start >= end || end > dims[dim] {
110 return Err(TorshError::InvalidOperation(format!(
111 "Invalid slice range [{}:{}] for dimension of size {}",
112 start, end, dims[dim]
113 )));
114 }
115
116 let mut new_shape = dims.to_vec();
118 new_shape[dim] = end - start;
119
120 let strides = self.calculate_strides();
122 let offset = start * strides[dim];
123
124 self.create_view_with_strides(&new_shape, &strides, offset)
125 }
126
127 fn create_view_with_strides(
129 &self,
130 shape: &[usize],
131 strides: &[usize],
132 offset: usize,
133 ) -> Result<TensorView<T>> {
134 let data_ref = match &self.storage {
136 TensorStorage::InMemory(data) => data.clone(),
137 TensorStorage::MemoryMapped(_) => {
138 let data = self.to_vec()?;
140 Arc::new(RwLock::new(data))
141 }
142 #[cfg(feature = "simd")]
143 TensorStorage::Aligned(data) => {
144 let aligned_data = data.read().expect("lock should not be poisoned");
146 let vec_data = aligned_data.as_slice().to_vec();
147 Arc::new(RwLock::new(vec_data))
148 }
149 #[cfg(feature = "simd")]
150 TensorStorage::SimdOptimized(storage) => {
151 let vec_data = storage.as_slice().to_vec();
153 Arc::new(RwLock::new(vec_data))
154 }
155 };
156
157 let view_storage = ViewStorage {
159 parent: Arc::downgrade(&data_ref),
160 data_ref: Some(data_ref),
161 view_cache: HashMap::new(),
162 view_count: 1,
163 };
164
165 Ok(TensorView {
166 storage: Arc::new(RwLock::new(view_storage)),
167 shape: Shape::new(shape.to_vec()),
168 strides: strides.to_vec(),
169 offset,
170 device: self.device,
171 })
172 }
173
174 pub fn alias(&self) -> TensorAlias<T> {
176 TensorAlias {
177 tensor: self.clone(),
178 is_mutable: false,
179 }
180 }
181
182 pub fn alias_mut(&mut self) -> TensorAlias<T> {
184 TensorAlias {
185 tensor: self.clone(),
186 is_mutable: true,
187 }
188 }
189}
190
191impl<T: TensorElement + Copy> TensorView<T> {
192 pub fn shape(&self) -> &Shape {
194 &self.shape
195 }
196
197 pub fn strides(&self) -> &[usize] {
199 &self.strides
200 }
201
202 pub fn offset(&self) -> usize {
204 self.offset
205 }
206
207 pub fn to_tensor(&self) -> Result<Tensor<T>> {
209 let data = self.to_vec()?;
210 Tensor::from_data(data, self.shape.dims().to_vec(), self.device)
211 }
212
213 pub fn to_vec(&self) -> Result<Vec<T>> {
215 let storage = self.storage.read().expect("lock should not be poisoned");
216 if let Some(data_ref) = &storage.data_ref {
217 let data = data_ref.read().expect("lock should not be poisoned");
218 let mut result = Vec::with_capacity(self.shape.numel());
219
220 self.extract_view_data(&data, &mut result, &mut vec![0; self.shape.ndim()], 0)?;
222
223 Ok(result)
224 } else {
225 Err(TorshError::InvalidOperation(
226 "View data no longer available".to_string(),
227 ))
228 }
229 }
230
231 fn extract_view_data(
233 &self,
234 data: &[T],
235 result: &mut Vec<T>,
236 indices: &mut [usize],
237 dim: usize,
238 ) -> Result<()> {
239 if dim == self.shape.ndim() {
240 let flat_index = self.offset
242 + indices
243 .iter()
244 .zip(self.strides.iter())
245 .map(|(&idx, &stride)| idx * stride)
246 .sum::<usize>();
247
248 if flat_index < data.len() {
249 result.push(data[flat_index]);
250 } else {
251 return Err(TorshError::InvalidOperation(
252 "View index out of bounds".to_string(),
253 ));
254 }
255 } else {
256 for i in 0..self.shape.dims()[dim] {
257 indices[dim] = i;
258 self.extract_view_data(data, result, indices, dim + 1)?;
259 }
260 }
261 Ok(())
262 }
263
264 pub fn is_contiguous(&self) -> bool {
266 let dims = self.shape.dims();
268 let mut expected_strides = vec![1; dims.len()];
269 for i in (0..dims.len().saturating_sub(1)).rev() {
270 expected_strides[i] = expected_strides[i + 1] * dims[i + 1];
271 }
272 self.strides == expected_strides
273 }
274
275 pub fn is_view(&self) -> bool {
277 true
278 }
279
280 pub fn get(&self, indices: &[usize]) -> Result<T> {
282 if indices.len() != self.shape.ndim() {
283 return Err(TorshError::InvalidOperation(format!(
284 "Expected {} indices, got {}",
285 self.shape.ndim(),
286 indices.len()
287 )));
288 }
289
290 for (i, &idx) in indices.iter().enumerate() {
291 if idx >= self.shape.dims()[i] {
292 return Err(TorshError::InvalidOperation(format!(
293 "Index {} out of bounds for dimension {} (size {})",
294 idx,
295 i,
296 self.shape.dims()[i]
297 )));
298 }
299 }
300
301 let storage = self.storage.read().expect("lock should not be poisoned");
302 if let Some(data_ref) = &storage.data_ref {
303 let data = data_ref.read().expect("lock should not be poisoned");
304
305 let flat_index = self.offset
307 + indices
308 .iter()
309 .zip(self.strides.iter())
310 .map(|(&idx, &stride)| idx * stride)
311 .sum::<usize>();
312
313 if flat_index < data.len() {
314 Ok(data[flat_index])
315 } else {
316 Err(TorshError::InvalidOperation(
317 "View index out of bounds".to_string(),
318 ))
319 }
320 } else {
321 Err(TorshError::InvalidOperation(
322 "View data no longer available".to_string(),
323 ))
324 }
325 }
326
327 pub fn view_memory_usage(&self) -> ViewMemoryUsage {
329 let storage = self.storage.read().expect("lock should not be poisoned");
330 ViewMemoryUsage {
331 view_elements: self.shape.numel(),
332 total_elements: storage
333 .data_ref
334 .as_ref()
335 .map(|data| data.read().expect("lock should not be poisoned").len())
336 .unwrap_or(0),
337 active_views: storage.view_count,
338 is_contiguous: self.is_contiguous(),
339 memory_efficiency: self.calculate_memory_efficiency(),
340 }
341 }
342
343 fn calculate_memory_efficiency(&self) -> f64 {
345 let view_size = self.shape.numel();
346 let storage = self.storage.read().expect("lock should not be poisoned");
347 let total_size = storage
348 .data_ref
349 .as_ref()
350 .map(|data| data.read().expect("lock should not be poisoned").len())
351 .unwrap_or(1);
352
353 view_size as f64 / total_size as f64
354 }
355}
356
/// A lightweight named handle to a tensor, tagging whether the holder
/// intends mutable access.
#[derive(Debug, Clone)]
pub struct TensorAlias<T: TensorElement> {
    // Cloned tensor handle the alias refers to.
    tensor: Tensor<T>,
    // Advisory mutability flag set by `alias` / `alias_mut`; nothing in this
    // file enforces exclusive access based on it.
    is_mutable: bool,
}
363
impl<T: TensorElement + Copy> TensorAlias<T> {
    /// Borrows the aliased tensor.
    pub fn tensor(&self) -> &Tensor<T> {
        &self.tensor
    }

    /// Whether this alias was created via `alias_mut` (advisory flag only).
    pub fn is_mutable(&self) -> bool {
        self.is_mutable
    }

    /// Returns a clone of the aliased tensor.
    ///
    /// NOTE(review): this clones the `Tensor` handle; if `Tensor::clone` is
    /// shallow (Arc-shared storage, as `ref_count` below suggests), the
    /// result still shares data with the original despite the `to_owned`
    /// name — confirm against `Tensor`'s `Clone` impl.
    pub fn to_owned(&self) -> Result<Tensor<T>> {
        Ok(self.tensor.clone())
    }

    /// Number of strong references to the aliased tensor's storage Arc.
    pub fn ref_count(&self) -> usize {
        match &self.tensor.storage {
            TensorStorage::InMemory(data) => Arc::strong_count(data),
            TensorStorage::MemoryMapped(storage) => Arc::strong_count(storage),
            #[cfg(feature = "simd")]
            TensorStorage::Aligned(data) => Arc::strong_count(data),
            #[cfg(feature = "simd")]
            TensorStorage::SimdOptimized(storage) => Arc::strong_count(storage),
        }
    }

    /// True when this alias's tensor is the sole owner of its storage.
    pub fn is_unique(&self) -> bool {
        self.ref_count() == 1
    }
}
397
/// Snapshot of a view's memory characteristics, produced by
/// `TensorView::view_memory_usage`.
#[derive(Debug, Clone)]
pub struct ViewMemoryUsage {
    /// Number of elements addressed by the view.
    pub view_elements: usize,
    /// Number of elements in the backing buffer (0 if released).
    pub total_elements: usize,
    /// Live-view count recorded in the view's storage bookkeeping.
    pub active_views: usize,
    /// Whether the view's strides are row-major contiguous.
    pub is_contiguous: bool,
    /// `view_elements / total_elements` as a ratio.
    pub memory_efficiency: f64,
}
412
413impl<T: TensorElement + Copy> Drop for ViewStorage<T> {
414 fn drop(&mut self) {
415 self.view_cache.clear();
417 self.view_count = 0;
418 }
419}
420
#[cfg(test)]
mod tests {
    use crate::creation::*;

    #[test]
    fn test_tensor_view() {
        // A 2x3x4 tensor (24 elements) can be reshaped into a 6x4 view.
        let base = ones::<f32>(&[2, 3, 4]).expect("ones creation should succeed");
        let reshaped = base
            .create_view(&[6, 4])
            .expect("create_view should succeed");
        assert_eq!(reshaped.shape().dims(), &[6, 4]);
        assert_eq!(reshaped.shape().numel(), 24);
    }

    #[test]
    fn test_tensor_slice() {
        // A 12-element range tensor is viewable as a 3x4 matrix.
        let base = arange(0.0f32, 12.0, 1.0).expect("arange should succeed");
        let _reshaped = base
            .create_view(&[3, 4])
            .expect("create_view should succeed");
    }

    #[test]
    fn test_tensor_squeeze_unsqueeze() {
        let base = ones::<f32>(&[1, 3, 1, 4]).expect("ones creation should succeed");

        // squeeze(0) removes only the leading singleton dimension.
        let squeezed = base.squeeze(0).expect("squeeze should succeed");
        assert_eq!(squeezed.shape().dims(), &[3, 1, 4]);

        // squeeze_all drops every singleton dimension at once.
        let squeezed_all = base.squeeze_all().expect("squeeze_all should succeed");
        assert_eq!(squeezed_all.shape().dims(), &[3, 4]);

        // unsqueeze inserts a new singleton dimension at position 2.
        let unsqueezed = base.unsqueeze(2).expect("unsqueeze should succeed");
        assert_eq!(unsqueezed.shape().dims(), &[1, 3, 1, 1, 4]);
    }

    #[test]
    fn test_tensor_permute() {
        // Permuting axes [2, 0, 1] turns shape [2, 3, 4] into [4, 2, 3].
        let base = ones::<f32>(&[2, 3, 4]).expect("ones creation should succeed");
        let permuted = base.permute(&[2, 0, 1]).expect("permute should succeed");
        assert_eq!(permuted.shape().dims(), &[4, 2, 3]);
    }

    #[test]
    fn test_tensor_alias() {
        let base = ones::<f32>(&[10, 10]).expect("ones creation should succeed");
        let alias = base.alias();
        assert!(!alias.is_mutable());
        // The alias holds a clone, so storage has at least two strong refs.
        assert!(alias.ref_count() >= 2);
    }

    #[test]
    fn test_view_memory_usage() {
        let base = ones::<f32>(&[100, 100]).expect("ones creation should succeed");
        let view = base
            .create_view(&[1000, 10])
            .expect("create_view should succeed");
        let usage = view.view_memory_usage();
        assert_eq!(usage.view_elements, 10000);
        // A view covering the whole buffer reports full efficiency.
        assert_eq!(usage.memory_efficiency, 1.0);
    }
}