// svod_tensor/data.rs

use bon::bon;
use snafu::ResultExt;
use std::sync::Arc;

use svod_device::{Buffer, registry};
use svod_dtype::DType;
use svod_dtype::ext::HasDType;
use svod_ir::{DeviceSpec, SInt, UOp, shape::Shape};

use crate::Tensor;
use crate::error::*;
use crate::tensor_registry;

14#[bon]
15impl Tensor {
16    /// Create tensor from slice on CPU (default device).
17    ///
18    /// # Examples
19    /// ```
20    /// # use svod_tensor::Tensor;
21    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
22    /// ```
23    pub fn from_slice<T: HasDType, C: AsRef<[T]>>(source: C) -> Self {
24        let source = source.as_ref();
25        Self::from_bytes_shaped(
26            unsafe { std::slice::from_raw_parts(source.as_ptr() as *const u8, source.len() * T::DTYPE.bytes()) },
27            &[source.len()],
28            T::DTYPE,
29            DeviceSpec::Cpu,
30        )
31    }
32
33    /// Create tensor from slice with explicit device specification using builder pattern.
34    #[builder]
35    pub fn from_slice_with<T: HasDType, C: AsRef<[T]>>(
36        source: C,
37        #[builder(default = DeviceSpec::Cpu)] device: DeviceSpec,
38    ) -> Self {
39        let source = source.as_ref();
40        Self::from_bytes_shaped(
41            unsafe { std::slice::from_raw_parts(source.as_ptr() as *const u8, source.len() * T::DTYPE.bytes()) },
42            &[source.len()],
43            T::DTYPE,
44            device,
45        )
46    }
47}
48
49impl Tensor {
50    /// Core: create a tensor from raw bytes with a known shape.
51    ///
52    /// Builds the buffer UOp with the target shape directly — no reshape,
53    /// so the returned tensor retains its buffer for zero-copy `array_view`.
54    fn from_bytes_shaped(bytes: &[u8], shape: &[usize], dtype: DType, device: DeviceSpec) -> Self {
55        let numel: usize = shape.iter().product();
56        let ir_shape = Shape::from_iter(shape.iter().map(|&d| SInt::Const(d)));
57
58        let buffer_uop = UOp::new_buffer(device.clone(), numel, dtype.clone());
59        let buffer_uop_id = buffer_uop.id;
60
61        let allocator = match &device {
62            DeviceSpec::Cpu => registry::cpu().expect("CPU always should be accessible"),
63            _ => registry::cpu().expect("CPU fallback for unsupported device"),
64        };
65
66        let mut buffer = Buffer::new(allocator, dtype.clone(), shape.to_vec(), Default::default());
67        buffer.copyin(bytes).expect("Buffer write always successful");
68
69        let buffer_arc = Arc::new(buffer);
70        let uop = buffer_uop.try_reshape(&ir_shape).expect("shape matches element count");
71
72        let entry = tensor_registry::register_tensor_with_buffer(uop, buffer_arc.clone(), buffer_uop_id);
73        Self::with_buffer(entry, buffer_arc)
74    }
75
76    /// Create tensor from raw bytes with explicit dtype and shape.
77    ///
78    /// The bytes are interpreted as little-endian values of the given dtype.
79    /// Length must equal `product(shape) * dtype.bytes()`.
80    /// Used for types without a native Rust representation (Float16, BFloat16, FP8).
81    pub fn from_raw_bytes(data: &[u8], shape: &[usize], dtype: DType) -> Result<Self> {
82        let numel: usize = shape.iter().product();
83        let expected_bytes = numel * dtype.bytes();
84        if data.len() != expected_bytes {
85            return Err(Error::IrConstruction {
86                details: format!(
87                    "from_raw_bytes: data length {} != expected {} ({} elements * {} bytes)",
88                    data.len(),
89                    expected_bytes,
90                    numel,
91                    dtype.bytes()
92                ),
93            });
94        }
95        Ok(Self::from_bytes_shaped(data, shape, dtype, DeviceSpec::Cpu))
96    }
97
98    /// Create tensor from an ndarray (owned `Array` or `ArrayView`).
99    ///
100    /// When the array is already C-contiguous, uses the backing slice directly
101    /// (no intermediate allocation). Falls back to `.iter().cloned().collect()`
102    /// for Fortran-order or non-contiguous layouts.
103    ///
104    /// # Examples
105    /// ```
106    /// # use svod_tensor::Tensor;
107    /// # use ndarray::array;
108    /// let t = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
109    /// let view = t.array_view::<f32>().unwrap();
110    /// assert_eq!(view[[1, 2]], 6.0);
111    /// ```
112    pub fn from_ndarray<T, S, D>(array: &ndarray::ArrayBase<S, D>) -> Self
113    where
114        T: HasDType + Clone,
115        S: ndarray::Data<Elem = T>,
116        D: ndarray::Dimension,
117    {
118        let shape: Vec<usize> = array.shape().to_vec();
119        if array.is_empty() {
120            let t = Self::empty_zero(T::DTYPE);
121            if shape.len() <= 1 {
122                return t;
123            }
124            let isize_shape: Vec<isize> = shape.iter().map(|&d| d as isize).collect();
125            return t.try_reshape(&isize_shape).expect("empty reshape matches");
126        }
127        // Fast path: C-contiguous — use backing slice directly, no intermediate Vec
128        if let Some(slice) = array.as_slice() {
129            let bytes =
130                unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, slice.len() * T::DTYPE.bytes()) };
131            Self::from_bytes_shaped(bytes, &shape, T::DTYPE, DeviceSpec::Cpu)
132        } else {
133            // Slow path: Fortran-order or non-contiguous — collect in logical order
134            let data: Vec<T> = array.iter().cloned().collect();
135            let bytes =
136                unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * T::DTYPE.bytes()) };
137            Self::from_bytes_shaped(bytes, &shape, T::DTYPE, DeviceSpec::Cpu)
138        }
139    }
140
141    /// Get a reference to the underlying buffer.
142    ///
143    /// Returns `None` for lazy tensors that haven't been realized yet.
144    /// Returns `Some(buffer)` for input tensors and realized tensors.
145    pub fn buffer(&self) -> Option<Buffer> {
146        // Check local field first, then entry, then global registry by base UOp ID.
147        if let Some(buf) = self.buffer.as_ref().or_else(|| self.entry.buffer()) {
148            return Some((**buf).clone());
149        }
150        crate::tensor_registry::get_buffer_arc(self.uop().base().id).map(|arc| (*arc).clone())
151    }
152
153    /// Read realized tensor data as an ndarray.
154    ///
155    /// The tensor must have a buffer (from `from_slice`, `realize()`, etc.).
156    /// Returns error if the tensor has not been realized.
157    ///
158    /// # Examples
159    /// ```
160    /// # use svod_tensor::Tensor;
161    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
162    /// let result = t.as_ndarray::<f32>().unwrap();
163    /// assert_eq!(result.shape(), &[3]);
164    /// ```
165    pub fn as_ndarray<T: HasDType + Default + Clone>(&self) -> Result<ndarray::ArrayD<T>> {
166        use ndarray::{ArrayD, IxDyn};
167
168        let uop = self.uop();
169        let shape = uop.shape().context(UOpSnafu)?.ok_or(Error::NoShape)?;
170
171        // Refuse symbolic shapes — matches Tinygrad: assert all_int(self.shape)
172        if shape.iter().any(|dim| dim.as_const().is_none()) {
173            return SymbolicShapeSnafu.fail();
174        }
175
176        let dims: Vec<usize> = shape.iter().map(|dim| dim.as_const().unwrap()).collect();
177
178        if dims.contains(&0) {
179            let arr = ArrayD::from_shape_vec(IxDyn(&dims), vec![]).context(NdarrayShapeSnafu)?;
180            return Ok(arr);
181        }
182
183        let buffer = self.buffer().ok_or(Error::NoBuffer)?;
184
185        if buffer.dtype() != T::DTYPE {
186            return TypeMismatchSnafu { expected: T::DTYPE, actual: buffer.dtype() }.fail();
187        }
188
189        let count = buffer.size() / T::DTYPE.bytes();
190        let mut data = vec![T::default(); count];
191        buffer
192            .copyout(unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, count * T::DTYPE.bytes()) })
193            .context(DeviceSnafu)?;
194
195        let arr = ArrayD::from_shape_vec(IxDyn(&dims), data).context(NdarrayShapeSnafu)?;
196        Ok(arr)
197    }
198
199    /// Read realized tensor data as a flat `Vec<T>`.
200    ///
201    /// The tensor must have a buffer (from `from_slice`, `realize()`, etc.).
202    /// Returns error if the tensor has not been realized.
203    ///
204    /// # Examples
205    /// ```
206    /// # use svod_tensor::Tensor;
207    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
208    /// let v = t.as_vec::<f32>().unwrap();
209    /// assert_eq!(v, vec![1.0, 2.0, 3.0]);
210    /// ```
211    pub fn as_vec<T: HasDType + Default + Clone>(&self) -> Result<Vec<T>> {
212        let uop = self.uop();
213        if let Ok(Some(shape)) = uop.shape() {
214            // Refuse symbolic shapes — matches Tinygrad: assert all_int(self.shape)
215            if shape.iter().any(|dim| dim.as_const().is_none()) {
216                return SymbolicShapeSnafu.fail();
217            }
218            if shape.iter().any(|dim| dim.as_const() == Some(0)) {
219                return Ok(vec![]);
220            }
221        }
222
223        let buffer = self.buffer().ok_or(Error::NoBuffer)?;
224
225        if buffer.dtype() != T::DTYPE {
226            return TypeMismatchSnafu { expected: T::DTYPE, actual: buffer.dtype() }.fail();
227        }
228
229        let count = buffer.size() / T::DTYPE.bytes();
230        let mut data = vec![T::default(); count];
231        buffer
232            .copyout(unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, count * T::DTYPE.bytes()) })
233            .context(DeviceSnafu)?;
234
235        Ok(data)
236    }
237
238    /// Typed immutable view into the buffer, shaped by the tensor's logical shape.
239    ///
240    /// Uses the tensor's concrete shape for multidimensional indexing.
241    /// Falls back to the buffer's flat shape for symbolic tensors.
242    ///
243    /// # Examples
244    /// ```
245    /// # use svod_tensor::Tensor;
246    /// # use ndarray::array;
247    /// let t = Tensor::from_ndarray(&array![[1.0f32, 2.0], [3.0, 4.0]]);
248    /// let view = t.array_view::<f32>().unwrap();
249    /// assert_eq!(view[[0, 1]], 2.0);
250    /// ```
251    pub fn array_view<T: HasDType>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
252        let buffer_arc = self.buffer.as_ref().or_else(|| self.entry.buffer()).ok_or(Error::NoBuffer)?;
253        let flat = buffer_arc.as_array::<T>().context(DeviceSnafu)?;
254        // Reshape to tensor's logical shape if concrete
255        if let Ok(shape) = self.shape() {
256            let dims: Vec<usize> = shape.iter().filter_map(|d| d.as_const()).collect();
257            if dims.len() == shape.len() {
258                return flat.into_shape_with_order(ndarray::IxDyn(&dims)).context(NdarrayShapeSnafu);
259            }
260        }
261        Ok(flat)
262    }
263
264    /// Typed mutable view into the buffer, shaped by the tensor's logical shape.
265    ///
266    /// # Examples
267    /// ```
268    /// # use svod_tensor::Tensor;
269    /// # use ndarray::array;
270    /// let t = Tensor::from_ndarray(&array![[0.0f32, 0.0, 0.0], [0.0, 0.0, 0.0]]);
271    /// t.array_view_mut::<f32>().unwrap()[[1, 2]] = 42.0;
272    /// assert_eq!(t.array_view::<f32>().unwrap()[[1, 2]], 42.0);
273    /// ```
274    pub fn array_view_mut<T: HasDType>(&self) -> Result<ndarray::ArrayViewMutD<'_, T>> {
275        let buffer_arc = self.buffer.as_ref().or_else(|| self.entry.buffer()).ok_or(Error::NoBuffer)?;
276        let flat = buffer_arc.as_array_mut::<T>().context(DeviceSnafu)?;
277        if let Ok(shape) = self.shape() {
278            let dims: Vec<usize> = shape.iter().filter_map(|d| d.as_const()).collect();
279            if dims.len() == shape.len() {
280                return flat.into_shape_with_order(ndarray::IxDyn(&dims)).context(NdarrayShapeSnafu);
281            }
282        }
283        Ok(flat)
284    }
285}