// tenflowers_core/tensor/core.rs
1//! Core Tensor Structure and Properties
2//!
3//! This module contains the fundamental tensor structure, storage definition,
4//! and basic property access methods. It provides the foundation for all
5//! tensor operations while maintaining clean separation of concerns.
6
7use crate::{Device, Result, Shape};
8use scirs2_core::ndarray::ArrayD;
9use std::sync::Arc;
10
/// Core tensor structure that holds data and metadata
///
/// Combines element storage (CPU or GPU, see [`TensorStorage`]) with the
/// logical shape, device placement, and autograd bookkeeping
/// (`requires_grad` flag plus an optional accumulated gradient).
#[derive(Debug, Clone)]
pub struct Tensor<T> {
    /// Element storage; CPU-backed ndarray or (feature-gated) GPU buffer.
    pub storage: TensorStorage<T>,
    /// Logical shape of the tensor; visible only within `crate::tensor`.
    pub(in crate::tensor) shape: Shape,
    /// Device where the tensor's data resides.
    pub(in crate::tensor) device: Device,
    /// Whether gradients should be computed for this tensor.
    pub(in crate::tensor) requires_grad: bool,
    /// Accumulated gradient, if any. `Arc` allows cheap sharing of the
    /// gradient tensor without cloning its data.
    pub(in crate::tensor) grad: Option<Arc<Tensor<T>>>,
}
20
/// Storage abstraction for different device types
#[derive(Debug, Clone)]
pub enum TensorStorage<T> {
    /// Host-resident data backed by a dynamic-dimensional ndarray.
    Cpu(ArrayD<T>),
    /// Device-resident data; only compiled in with the `gpu` feature.
    #[cfg(feature = "gpu")]
    Gpu(crate::gpu::buffer::GpuBuffer<T>),
}
28
29// Core implementation block for all tensor types
30impl<T> Tensor<T> {
31    /// Get the shape of the tensor
32    pub fn shape(&self) -> &Shape {
33        &self.shape
34    }
35
36    /// Get the device where the tensor is located
37    pub fn device(&self) -> &Device {
38        &self.device
39    }
40
41    /// Get the data type of the tensor
42    pub fn dtype(&self) -> crate::DType
43    where
44        T: 'static,
45    {
46        crate::dtype_from_type::<T>()
47    }
48
49    /// Check if tensor requires gradient computation
50    pub fn requires_grad(&self) -> bool {
51        self.requires_grad
52    }
53
54    /// Set whether tensor requires gradient computation
55    pub fn set_requires_grad(&mut self, requires_grad: bool) {
56        self.requires_grad = requires_grad;
57    }
58
59    /// Get the gradient tensor if it exists
60    pub fn grad(&self) -> Option<&Tensor<T>> {
61        self.grad.as_ref().map(|g| g.as_ref())
62    }
63
64    /// Set the gradient tensor
65    pub fn set_grad(&mut self, grad: Option<Tensor<T>>) {
66        self.grad = grad.map(Arc::new);
67    }
68
69    /// Get a reference to the underlying data (CPU only)
70    pub fn data(&self) -> &[T] {
71        match &self.storage {
72            TensorStorage::Cpu(arr) => {
73                arr.as_slice().unwrap_or_else(|| {
74                    panic!("Tensor data is not contiguous. Use to_owned() or iter() for non-contiguous access.")
75                })
76            }
77            #[cfg(feature = "gpu")]
78            TensorStorage::Gpu(_) => {
79                panic!("Cannot access GPU tensor data directly. Use to_cpu() first.")
80            }
81        }
82    }
83
84    /// Get the value at a specific index (for CPU tensors only)
85    pub fn get(&self, index: &[usize]) -> Option<T>
86    where
87        T: Clone,
88    {
89        match &self.storage {
90            TensorStorage::Cpu(arr) => {
91                if index.len() != arr.ndim() {
92                    return None;
93                }
94                arr.get(index).cloned()
95            }
96            #[cfg(feature = "gpu")]
97            _ => None,
98        }
99    }
100
101    /// Get the underlying data as a slice (CPU tensors only)
102    pub fn as_slice(&self) -> Option<&[T]> {
103        match &self.storage {
104            TensorStorage::Cpu(array) => array.as_slice(),
105            #[cfg(feature = "gpu")]
106            TensorStorage::Gpu(_) => None,
107        }
108    }
109
110    /// Check if tensor is empty (has no elements)
111    pub fn is_empty(&self) -> bool {
112        self.shape.elements() == 0
113    }
114
115    /// Get memory usage in bytes
116    pub fn memory_usage(&self) -> usize {
117        let element_size = std::mem::size_of::<T>();
118        self.shape.elements() * element_size
119    }
120
121    /// Check if two tensors have the same shape
122    pub fn same_shape(&self, other: &Self) -> bool {
123        self.shape == other.shape
124    }
125
126    /// Check if tensors are broadcastable
127    pub fn is_broadcastable_with(&self, other: &Self) -> bool {
128        let dims1 = self.shape.dims();
129        let dims2 = other.shape.dims();
130
131        let max_dims = dims1.len().max(dims2.len());
132
133        for i in 0..max_dims {
134            let dim1 = dims1
135                .get(dims1.len().saturating_sub(i + 1))
136                .copied()
137                .unwrap_or(1);
138            let dim2 = dims2
139                .get(dims2.len().saturating_sub(i + 1))
140                .copied()
141                .unwrap_or(1);
142
143            if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
144                return false;
145            }
146        }
147
148        true
149    }
150
151    /// Get tensor summary statistics as a formatted string
152    pub fn summary(&self) -> String
153    where
154        T: std::fmt::Display + Clone,
155    {
156        format!(
157            "Tensor<{}>: shape={:?}, device={:?}, numel={}, memory={}B, requires_grad={}",
158            std::any::type_name::<T>(),
159            self.shape.dims(),
160            self.device,
161            self.shape.elements(),
162            self.memory_usage(),
163            self.requires_grad
164        )
165    }
166
167    /// Get the total number of elements (alias for size)
168    pub fn size(&self) -> usize {
169        self.shape.size()
170    }
171
172    /// Get the total number of elements
173    pub fn numel(&self) -> usize {
174        self.shape.size()
175    }
176
177    /// Get the number of dimensions (rank)
178    pub fn rank(&self) -> usize {
179        self.shape.rank()
180    }
181
182    /// Get the number of dimensions (alias for rank)
183    pub fn ndim(&self) -> usize {
184        self.shape.rank()
185    }
186
187    /// Check if tensor is a scalar (0-dimensional)
188    pub fn is_scalar(&self) -> bool {
189        self.shape.rank() == 0
190    }
191
192    /// Check if tensor is a vector (1-dimensional)
193    pub fn is_vector(&self) -> bool {
194        self.shape.rank() == 1
195    }
196
197    /// Check if tensor is a matrix (2-dimensional)
198    pub fn is_matrix(&self) -> bool {
199        self.shape.rank() == 2
200    }
201
202    /// Check if tensor data is contiguous in memory
203    pub fn is_contiguous(&self) -> bool {
204        match &self.storage {
205            TensorStorage::Cpu(arr) => arr.is_standard_layout(),
206            #[cfg(feature = "gpu")]
207            TensorStorage::Gpu(_) => true, // GPU buffers are always contiguous
208        }
209    }
210}
211
// Separate impl block for methods requiring Pod bounds
impl<T> Tensor<T>
where
    T: Clone + bytemuck::Pod + bytemuck::Zeroable + Send + Sync + 'static,
{
    /// Apply a function to each element of the tensor, in place.
    ///
    /// CPU tensors are mapped directly over the backing ndarray and always
    /// return `Ok(())`. GPU tensors (only compiled with the `gpu` feature)
    /// are round-tripped through host memory: the buffer is downloaded,
    /// mapped on the CPU, then re-uploaded into a fresh GPU buffer. The GPU
    /// path only supports `f32`/`f64` element types (checked via `TypeId`).
    ///
    /// # Errors
    /// On the GPU path: if the stored device is not `Device::Gpu`, if the
    /// download/upload fails, or if `T` is neither `f32` nor `f64`
    /// (unsupported-operation error).
    pub fn map_inplace<F>(&mut self, f: F) -> Result<()>
    where
        F: Fn(&T) -> T,
    {
        match &mut self.storage {
            TensorStorage::Cpu(arr) => {
                // mapv_inplace passes elements by value; adapt to the
                // by-reference closure signature.
                arr.mapv_inplace(|x| f(&x));
                Ok(())
            }
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(buffer) => {
                // Handle GPU case manually to avoid double borrow
                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
                    || std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>()
                {
                    // For f32/f64, convert to CPU, apply operation, and convert back
                    let mut cpu_array = buffer.to_cpu_array()?;
                    cpu_array.mapv_inplace(|x| f(&x));
                    // A GPU-backed storage should always carry a GPU device;
                    // error defensively rather than panic if it does not.
                    let device_id = match self.device {
                        crate::Device::Gpu(id) => id,
                        _ => {
                            return Err(crate::TensorError::device_error_simple(
                                "Expected GPU device".to_string(),
                            ))
                        }
                    };
                    let new_gpu_buffer =
                        crate::gpu::buffer::GpuBuffer::from_cpu_array(&cpu_array, device_id)?;
                    // Replace the old buffer wholesale; writing through the
                    // existing `&mut` borrow avoids re-borrowing self.storage.
                    *buffer = new_gpu_buffer;
                    Ok(())
                } else {
                    // Fallback: not supported for this type
                    Err(crate::TensorError::unsupported_operation_simple(format!(
                        "GPU map_inplace not supported for type {}",
                        std::any::type_name::<T>()
                    )))
                }
            }
        }
    }
}