// tenflowers_core/tensor/core.rs
use crate::{Device, Result, Shape};
8use scirs2_core::ndarray::ArrayD;
9use std::sync::Arc;
10
/// A multi-dimensional array with shape, device placement, and optional
/// autograd bookkeeping.
///
/// Field visibility: `storage` is fully public; the remaining fields are
/// visible only inside `crate::tensor` and are read/written through the
/// accessor methods on the `impl` block.
#[derive(Debug, Clone)]
pub struct Tensor<T> {
    /// Backing data — CPU `ndarray` or (feature-gated) GPU buffer.
    pub storage: TensorStorage<T>,
    // Logical dimensions of the tensor.
    pub(in crate::tensor) shape: Shape,
    // Device the storage lives on (must agree with the `storage` variant).
    pub(in crate::tensor) device: Device,
    // Whether gradients should be tracked for this tensor.
    pub(in crate::tensor) requires_grad: bool,
    // Accumulated gradient, if any. Arc allows cheap sharing; set via
    // `set_grad`, read via `grad`.
    pub(in crate::tensor) grad: Option<Arc<Tensor<T>>>,
}
20
/// Backing storage for a [`Tensor`], selected by device.
#[derive(Debug, Clone)]
pub enum TensorStorage<T> {
    /// Host memory: a dynamically-dimensioned `ndarray` array.
    Cpu(ArrayD<T>),
    /// Device memory, available only with the `gpu` feature.
    #[cfg(feature = "gpu")]
    Gpu(crate::gpu::buffer::GpuBuffer<T>),
}
28
29impl<T> Tensor<T> {
31 pub fn shape(&self) -> &Shape {
33 &self.shape
34 }
35
36 pub fn device(&self) -> &Device {
38 &self.device
39 }
40
41 pub fn dtype(&self) -> crate::DType
43 where
44 T: 'static,
45 {
46 crate::dtype_from_type::<T>()
47 }
48
49 pub fn requires_grad(&self) -> bool {
51 self.requires_grad
52 }
53
54 pub fn set_requires_grad(&mut self, requires_grad: bool) {
56 self.requires_grad = requires_grad;
57 }
58
59 pub fn grad(&self) -> Option<&Tensor<T>> {
61 self.grad.as_ref().map(|g| g.as_ref())
62 }
63
64 pub fn set_grad(&mut self, grad: Option<Tensor<T>>) {
66 self.grad = grad.map(Arc::new);
67 }
68
69 pub fn data(&self) -> &[T] {
71 match &self.storage {
72 TensorStorage::Cpu(arr) => {
73 arr.as_slice().unwrap_or_else(|| {
74 panic!("Tensor data is not contiguous. Use to_owned() or iter() for non-contiguous access.")
75 })
76 }
77 #[cfg(feature = "gpu")]
78 TensorStorage::Gpu(_) => {
79 panic!("Cannot access GPU tensor data directly. Use to_cpu() first.")
80 }
81 }
82 }
83
84 pub fn get(&self, index: &[usize]) -> Option<T>
86 where
87 T: Clone,
88 {
89 match &self.storage {
90 TensorStorage::Cpu(arr) => {
91 if index.len() != arr.ndim() {
92 return None;
93 }
94 arr.get(index).cloned()
95 }
96 #[cfg(feature = "gpu")]
97 _ => None,
98 }
99 }
100
101 pub fn as_slice(&self) -> Option<&[T]> {
103 match &self.storage {
104 TensorStorage::Cpu(array) => array.as_slice(),
105 #[cfg(feature = "gpu")]
106 TensorStorage::Gpu(_) => None,
107 }
108 }
109
110 pub fn is_empty(&self) -> bool {
112 self.shape.elements() == 0
113 }
114
115 pub fn memory_usage(&self) -> usize {
117 let element_size = std::mem::size_of::<T>();
118 self.shape.elements() * element_size
119 }
120
121 pub fn same_shape(&self, other: &Self) -> bool {
123 self.shape == other.shape
124 }
125
126 pub fn is_broadcastable_with(&self, other: &Self) -> bool {
128 let dims1 = self.shape.dims();
129 let dims2 = other.shape.dims();
130
131 let max_dims = dims1.len().max(dims2.len());
132
133 for i in 0..max_dims {
134 let dim1 = dims1
135 .get(dims1.len().saturating_sub(i + 1))
136 .copied()
137 .unwrap_or(1);
138 let dim2 = dims2
139 .get(dims2.len().saturating_sub(i + 1))
140 .copied()
141 .unwrap_or(1);
142
143 if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
144 return false;
145 }
146 }
147
148 true
149 }
150
151 pub fn summary(&self) -> String
153 where
154 T: std::fmt::Display + Clone,
155 {
156 format!(
157 "Tensor<{}>: shape={:?}, device={:?}, numel={}, memory={}B, requires_grad={}",
158 std::any::type_name::<T>(),
159 self.shape.dims(),
160 self.device,
161 self.shape.elements(),
162 self.memory_usage(),
163 self.requires_grad
164 )
165 }
166
167 pub fn size(&self) -> usize {
169 self.shape.size()
170 }
171
172 pub fn numel(&self) -> usize {
174 self.shape.size()
175 }
176
177 pub fn rank(&self) -> usize {
179 self.shape.rank()
180 }
181
182 pub fn ndim(&self) -> usize {
184 self.shape.rank()
185 }
186
187 pub fn is_scalar(&self) -> bool {
189 self.shape.rank() == 0
190 }
191
192 pub fn is_vector(&self) -> bool {
194 self.shape.rank() == 1
195 }
196
197 pub fn is_matrix(&self) -> bool {
199 self.shape.rank() == 2
200 }
201
202 pub fn is_contiguous(&self) -> bool {
204 match &self.storage {
205 TensorStorage::Cpu(arr) => arr.is_standard_layout(),
206 #[cfg(feature = "gpu")]
207 TensorStorage::Gpu(_) => true, }
209 }
210}
211
impl<T> Tensor<T>
where
    T: Clone + bytemuck::Pod + bytemuck::Zeroable + Send + Sync + 'static,
{
    /// Applies `f` to every element of the tensor in place.
    ///
    /// CPU storage is mutated directly via `mapv_inplace`. GPU storage is
    /// handled by a CPU round-trip: download to a host array, apply `f`,
    /// upload as a fresh buffer on the same device. The GPU path is only
    /// supported for `f32`/`f64` elements.
    ///
    /// # Errors
    /// Returns an error if the GPU download/upload fails, if a GPU-stored
    /// tensor's `device` field is not `Device::Gpu` (inconsistent state),
    /// or if `T` is neither `f32` nor `f64` on the GPU path.
    pub fn map_inplace<F>(&mut self, f: F) -> Result<()>
    where
        F: Fn(&T) -> T,
    {
        match &mut self.storage {
            TensorStorage::Cpu(arr) => {
                // mapv_inplace passes elements by value; wrap so `f` is
                // called by reference as its signature requires.
                arr.mapv_inplace(|x| f(&x));
                Ok(())
            }
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(buffer) => {
                // No GPU kernel exists for an arbitrary Rust closure, so fall
                // back to a host round-trip. Restricted to f32/f64 —
                // presumably the only element types the buffer transfer path
                // supports (TODO: confirm against GpuBuffer).
                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
                    || std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>()
                {
                    let mut cpu_array = buffer.to_cpu_array()?;
                    cpu_array.mapv_inplace(|x| f(&x));
                    // Recreate the buffer on the same GPU the tensor claims
                    // to live on; any other device value is an invariant
                    // violation for Gpu-variant storage.
                    let device_id = match self.device {
                        crate::Device::Gpu(id) => id,
                        _ => {
                            return Err(crate::TensorError::device_error_simple(
                                "Expected GPU device".to_string(),
                            ))
                        }
                    };
                    let new_gpu_buffer =
                        crate::gpu::buffer::GpuBuffer::from_cpu_array(&cpu_array, device_id)?;
                    *buffer = new_gpu_buffer;
                    Ok(())
                } else {
                    Err(crate::TensorError::unsupported_operation_simple(format!(
                        "GPU map_inplace not supported for type {}",
                        std::any::type_name::<T>()
                    )))
                }
            }
        }
    }
}