Skip to main content

sapient_core/
tensor.rs

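//! `Tensor` — the central multi-dimensional array type in SAPIENT.
//!
//! A `Tensor` owns its shape, dtype and stride metadata, and holds a
//! reference-counted `BufferHandle` for the raw bytes. Newly constructed
//! tensors are row-major (C order); views produced by `t()` and `slice_axis()`
//! share the buffer and may carry non-default strides or a non-zero offset.
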
6use serde::{Deserialize, Serialize};
7use serde::{Deserializer, Serializer};
8use std::sync::Arc;
9
10use crate::buffer::{BufferHandle, CpuBuffer};
11use crate::dtype::DType;
12use crate::error::{Result, SapientError};
13use crate::shape::Shape;
14
15// ── Tensor ────────────────────────────────────────────────────────────────────
16
17/// A multi-dimensional tensor with reference-counted buffer ownership.
18#[derive(Debug, Clone)]
19pub struct Tensor {
20    shape: Shape,
21    dtype: DType,
22    strides: Vec<usize>, // row-major by default
23    buffer: BufferHandle,
24    // Byte offset into the buffer where element [0,0,...,0] lives.
25    offset: usize,
26}
27
28impl Tensor {
29    // ── Constructors ─────────────────────────────────────────────────────────
30
31    /// Create a zero-filled tensor on the CPU.
32    pub fn zeros(shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
33        let shape = shape.into();
34        shape.validate()?;
35        let numel = shape.numel();
36        let strides = shape.strides();
37        let buffer = BufferHandle::new(CpuBuffer::zeros(numel, dtype)?);
38        Ok(Self {
39            shape,
40            dtype,
41            strides,
42            buffer,
43            offset: 0,
44        })
45    }
46
47    /// Create a tensor from a flat `f32` slice (CPU, row-major).
48    pub fn from_f32(data: &[f32], shape: impl Into<Shape>) -> Result<Self> {
49        let shape = shape.into();
50        shape.validate()?;
51        if data.len() != shape.numel() {
52            return Err(SapientError::ShapeMismatch {
53                expected: shape.dims().to_vec(),
54                got: vec![data.len()],
55            });
56        }
57        let strides = shape.strides();
58        let buffer = BufferHandle::new(CpuBuffer::from_f32_slice(data)?);
59        Ok(Self {
60            shape,
61            dtype: DType::F32,
62            strides,
63            buffer,
64            offset: 0,
65        })
66    }
67
68    /// Create a tensor from raw BF16 bytes, storing them natively without conversion.
69    /// Use `to_f32_vec()` or `to_f32_tensor()` to convert for computation.
70    pub fn from_bf16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
71        let shape = shape.into();
72        shape.validate()?;
73        let expected_bytes = shape.numel() * 2;
74        if data.len() != expected_bytes {
75            return Err(SapientError::ShapeMismatch {
76                expected: shape.dims().to_vec(),
77                got: vec![data.len() / 2],
78            });
79        }
80        let strides = shape.strides();
81        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
82        Ok(Self {
83            shape,
84            dtype: DType::BF16,
85            strides,
86            buffer,
87            offset: 0,
88        })
89    }
90
91    /// Create a tensor from raw F16 bytes, storing them natively without conversion.
92    pub fn from_f16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
93        let shape = shape.into();
94        shape.validate()?;
95        let expected_bytes = shape.numel() * 2;
96        if data.len() != expected_bytes {
97            return Err(SapientError::ShapeMismatch {
98                expected: shape.dims().to_vec(),
99                got: vec![data.len() / 2],
100            });
101        }
102        let strides = shape.strides();
103        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
104        Ok(Self {
105            shape,
106            dtype: DType::F16,
107            strides,
108            buffer,
109            offset: 0,
110        })
111    }
112
113    /// Create a scalar tensor from a single `f32`.
114    pub fn scalar_f32(v: f32) -> Result<Self> {
115        Self::from_f32(&[v], Shape::scalar())
116    }
117
118    /// Create from a pre-built `BufferHandle` (used by backends).
119    pub fn from_buffer(
120        shape: impl Into<Shape>,
121        dtype: DType,
122        buffer: BufferHandle,
123        offset: usize,
124    ) -> Result<Self> {
125        let shape = shape.into();
126        shape.validate()?;
127        let required = dtype.byte_count(shape.numel());
128        if buffer.len() < offset + required {
129            return Err(SapientError::BufferSizeMismatch {
130                expected: offset + required,
131                got: buffer.len(),
132            });
133        }
134        let strides = shape.strides();
135        Ok(Self {
136            shape,
137            dtype,
138            strides,
139            buffer,
140            offset,
141        })
142    }
143
144    // ── Accessors ────────────────────────────────────────────────────────────
145
146    pub fn shape(&self) -> &Shape {
147        &self.shape
148    }
149    pub fn dtype(&self) -> DType {
150        self.dtype
151    }
152    pub fn ndim(&self) -> usize {
153        self.shape.ndim()
154    }
155    pub fn numel(&self) -> usize {
156        self.shape.numel()
157    }
158    pub fn strides(&self) -> &[usize] {
159        &self.strides
160    }
161    pub fn buffer(&self) -> &BufferHandle {
162        &self.buffer
163    }
164    pub fn offset(&self) -> usize {
165        self.offset
166    }
167
168    /// True if the tensor has a single element.
169    pub fn is_scalar(&self) -> bool {
170        self.shape.is_scalar() || self.numel() == 1
171    }
172
173    /// True if the buffer is row-major contiguous (normal case).
174    pub fn is_contiguous(&self) -> bool {
175        self.strides == self.shape.strides() && self.offset == 0
176    }
177
178    // ── Typed data access (CPU only) ─────────────────────────────────────────
179
180    /// Raw byte view (always works).
181    pub fn as_bytes(&self) -> &[u8] {
182        let bytes = self.buffer.as_bytes();
183        &bytes[self.offset..]
184    }
185
186    /// Typed `f32` view — panics if dtype is not F32.
187    pub fn as_f32_slice(&self) -> &[f32] {
188        assert_eq!(
189            self.dtype,
190            DType::F32,
191            "Tensor dtype is not F32 — call to_f32_vec() instead"
192        );
193        let bytes = self.as_bytes();
194        assert_eq!(bytes.len() % 4, 0);
195        // SAFETY: alignment ensured by CpuBuffer, dtype checked above.
196        unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) }
197    }
198
199    /// Returns a `Cow<[f32]>`. Borrows if the tensor is already F32, otherwise allocates a new `Vec<f32>`.
200    pub fn to_f32_cow(&self) -> std::borrow::Cow<'_, [f32]> {
201        if self.dtype == DType::F32 {
202            std::borrow::Cow::Borrowed(self.as_f32_slice())
203        } else {
204            std::borrow::Cow::Owned(self.to_f32_vec())
205        }
206    }
207
208    /// Convert this tensor to a `Vec<f32>`, handling BF16/F16 → F32 on the fly.
209    /// For F32 tensors this is a cheap slice-to-vec copy. For BF16/F16 it converts.
210    pub fn to_f32_vec(&self) -> Vec<f32> {
211        match self.dtype {
212            DType::F32 => self.as_f32_slice().to_vec(),
213            DType::BF16 => {
214                let bytes = self.as_bytes();
215                bytes
216                    .chunks_exact(2)
217                    .map(|c| f32::from(half::bf16::from_le_bytes(c.try_into().unwrap())))
218                    .collect()
219            }
220            DType::F16 => {
221                let bytes = self.as_bytes();
222                bytes
223                    .chunks_exact(2)
224                    .map(|c| half::f16::from_le_bytes(c.try_into().unwrap()).to_f32())
225                    .collect()
226            }
227            _ => self.as_f32_slice().to_vec(), // fallback for other dtypes
228        }
229    }
230
231    /// Returns an F32 tensor, converting BF16/F16 if necessary.
232    /// For already-F32 tensors, clones the buffer. For native types, converts.
233    pub fn to_f32_tensor(&self) -> Result<Tensor> {
234        match self.dtype {
235            DType::F32 => Ok(self.clone()),
236            _ => Tensor::from_f32(&self.to_f32_vec(), self.shape.clone()),
237        }
238    }
239
240    /// Mutable typed `f32` view — fails if buffer is shared or not F32.
241    pub fn as_f32_slice_mut(&mut self) -> Result<&mut [f32]> {
242        if self.dtype != DType::F32 {
243            return Err(SapientError::internal("Tensor dtype is not F32"));
244        }
245        let offset = self.offset;
246        let buf = Arc::get_mut(&mut self.buffer.0)
247            .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
248        let bytes = buf.as_bytes_mut();
249        let bytes = &mut bytes[offset..];
250        if bytes.len() % 4 != 0 {
251            return Err(SapientError::internal("Buffer length not a multiple of 4"));
252        }
253        // SAFETY: alignment ensured by CpuBuffer, dtype checked above.
254        Ok(unsafe {
255            std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut f32, bytes.len() / 4)
256        })
257    }
258
259    // ── Shape manipulation ───────────────────────────────────────────────────
260
261    /// Returns a new tensor with a different shape but the same buffer.
262    /// The total number of elements must be unchanged.
263    pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor> {
264        let new_shape = self.shape.reshape(new_shape.into().dims().to_vec())?;
265        let strides = new_shape.strides();
266        Ok(Tensor {
267            shape: new_shape,
268            dtype: self.dtype,
269            strides,
270            buffer: self.buffer.clone(),
271            offset: self.offset,
272        })
273    }
274
275    /// Transpose a 2-D tensor (swap axes 0 and 1).
276    pub fn t(&self) -> Result<Tensor> {
277        if self.ndim() != 2 {
278            return Err(SapientError::internal("t() requires a 2-D tensor"));
279        }
280        let mut dims = self.shape.dims().to_vec();
281        let mut strides = self.strides.clone();
282        dims.swap(0, 1);
283        strides.swap(0, 1);
284        Ok(Tensor {
285            shape: Shape(dims),
286            dtype: self.dtype,
287            strides,
288            buffer: self.buffer.clone(),
289            offset: self.offset,
290        })
291    }
292
293    /// Return a view of the tensor sliced along the given axis.
294    pub fn slice_axis(&self, axis: usize, start: usize, end: usize) -> Result<Tensor> {
295        let mut dims = self.shape.dims().to_vec();
296        if axis >= dims.len() {
297            return Err(SapientError::internal("slice axis out of bounds"));
298        }
299        if start > end || end > dims[axis] {
300            return Err(SapientError::internal("slice range out of bounds"));
301        }
302        dims[axis] = end - start;
303        let offset = self.offset + start * self.strides[axis] * self.dtype.element_size();
304        Ok(Tensor {
305            shape: Shape(dims),
306            dtype: self.dtype,
307            strides: self.strides.clone(),
308            buffer: self.buffer.clone(),
309            offset,
310        })
311    }
312
313    // ── Metadata convenience ─────────────────────────────────────────────────
314
315    /// Byte count for all elements.
316    pub fn byte_size(&self) -> usize {
317        self.dtype.byte_count(self.numel())
318    }
319}
320
321impl std::fmt::Display for Tensor {
322    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        write!(
324            f,
325            "Tensor(shape={}, dtype={}, device={})",
326            self.shape,
327            self.dtype,
328            self.buffer.0.device()
329        )
330    }
331}
332
333// ── Serde support for Tensor ─────────────────────────────────────────────────
334
335/// Serialisable proxy — stores raw f32 data alongside shape/dtype.
336#[derive(Serialize, Deserialize)]
337struct TensorProxy {
338    shape: Shape,
339    dtype: DType,
340    /// Raw bytes as base64-encoded (for JSON), or raw for binary.
341    data: Vec<f32>,
342}
343
344impl Serialize for Tensor {
345    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
346        let data: Vec<f32> = if self.dtype == DType::F32 {
347            self.as_f32_slice().to_vec()
348        } else {
349            vec![] // non-f32 tensors: zero data (future work)
350        };
351        TensorProxy {
352            shape: self.shape.clone(),
353            dtype: self.dtype,
354            data,
355        }
356        .serialize(serializer)
357    }
358}
359
360impl<'de> Deserialize<'de> for Tensor {
361    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
362        let proxy = TensorProxy::deserialize(deserializer)?;
363        if proxy.data.is_empty() {
364            Tensor::zeros(proxy.shape, proxy.dtype).map_err(serde::de::Error::custom)
365        } else {
366            Tensor::from_f32(&proxy.data, proxy.shape).map_err(serde::de::Error::custom)
367        }
368    }
369}
370
371/// A serializable descriptor for a tensor — shape and dtype only (no data).
372#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct TensorMeta {
374    pub shape: Shape,
375    pub dtype: DType,
376}
377
378impl From<&Tensor> for TensorMeta {
379    fn from(t: &Tensor) -> Self {
380        Self {
381            shape: t.shape.clone(),
382            dtype: t.dtype,
383        }
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros_dtype_shape() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert_eq!(t.shape().dims(), &[2, 3]);
        assert_eq!(t.dtype(), DType::F32);
        assert_eq!(t.numel(), 6);
    }

    #[test]
    fn from_f32_roundtrip() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        assert_eq!(t.as_f32_slice(), data.as_slice());
    }

    #[test]
    fn from_f32_rejects_wrong_length() {
        assert!(Tensor::from_f32(&[1.0, 2.0, 3.0], vec![2, 2]).is_err());
    }

    #[test]
    fn reshape_preserves_data() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        let r = t.reshape(vec![3, 2]).unwrap();
        assert_eq!(r.shape().dims(), &[3, 2]);
        assert_eq!(r.as_f32_slice(), data.as_slice());
    }

    #[test]
    fn reshape_wrong_numel() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert!(t.reshape(vec![5]).is_err());
    }

    #[test]
    fn transpose_2d() {
        let t = Tensor::zeros(vec![3, 4], DType::F32).unwrap();
        let t2 = t.t().unwrap();
        assert_eq!(t2.shape().dims(), &[4, 3]);
    }

    #[test]
    fn transpose_swaps_strides_and_rejects_non_2d() {
        let t = Tensor::zeros(vec![3, 4], DType::F32).unwrap();
        assert_eq!(t.strides(), &[4, 1]);
        let tt = t.t().unwrap();
        assert_eq!(tt.strides(), &[1, 4]);
        assert!(!tt.is_contiguous());
        assert!(Tensor::zeros(vec![3], DType::F32).unwrap().t().is_err());
    }

    #[test]
    fn slice_axis_rows_share_buffer() {
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        let row = t.slice_axis(0, 1, 2).unwrap();
        assert_eq!(row.shape().dims(), &[1, 3]);
        assert_eq!(row.offset(), 12);
        assert_eq!(row.as_f32_slice(), &data[3..]);
    }

    #[test]
    fn slice_axis_rejects_bad_ranges() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert!(t.slice_axis(2, 0, 1).is_err());
        assert!(t.slice_axis(1, 2, 1).is_err());
        assert!(t.slice_axis(1, 0, 4).is_err());
    }

    #[test]
    fn bf16_bytes_roundtrip_through_f32() {
        let values = [1.0f32, -2.5, 0.0, 3.0];
        let bytes: Vec<u8> = values
            .iter()
            .flat_map(|&v| half::bf16::from_f32(v).to_le_bytes())
            .collect();
        let t = Tensor::from_bf16_bytes(&bytes, vec![2, 2]).unwrap();
        assert_eq!(t.dtype(), DType::BF16);
        assert_eq!(t.to_f32_vec(), values.to_vec());

        let converted = t.to_f32_tensor().unwrap();
        assert_eq!(converted.dtype(), DType::F32);
        assert_eq!(converted.shape().dims(), &[2, 2]);
        assert_eq!(converted.as_f32_slice(), &values[..]);
    }

    #[test]
    fn f16_bytes_roundtrip_through_f32() {
        let values = [0.5f32, -1.0, 4.0];
        let bytes: Vec<u8> = values
            .iter()
            .flat_map(|&v| half::f16::from_f32(v).to_le_bytes())
            .collect();
        let t = Tensor::from_f16_bytes(&bytes, vec![3]).unwrap();
        assert_eq!(t.dtype(), DType::F16);
        assert_eq!(t.to_f32_vec(), values.to_vec());
    }

    #[test]
    fn half_bytes_length_mismatch_is_rejected() {
        assert!(Tensor::from_bf16_bytes(&[0u8; 3], vec![2]).is_err());
        assert!(Tensor::from_f16_bytes(&[0u8; 6], vec![2]).is_err());
    }

    #[test]
    fn scalar_has_one_element() {
        let s = Tensor::scalar_f32(2.5).unwrap();
        assert!(s.is_scalar());
        assert_eq!(s.numel(), 1);
        assert_eq!(s.as_f32_slice(), &[2.5f32][..]);
    }

    #[test]
    fn byte_size() {
        let t = Tensor::zeros(vec![4, 4], DType::F32).unwrap();
        assert_eq!(t.byte_size(), 64);
    }
}
433}