Skip to main content

sapient_core/
tensor.rs

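//! `Tensor` — the central multi-dimensional array type in SAPIENT.
//!
//! A `Tensor` owns its shape and dtype metadata, and holds a reference-counted
//! `BufferHandle` for the raw bytes.  Newly constructed tensors are row-major
//! (C order); views produced by `t()` and `slice_axis()` share the buffer and may
//! carry non-row-major strides and a non-zero byte offset.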
5
6use serde::{Deserialize, Serialize};
7use serde::{Deserializer, Serializer};
8use std::sync::Arc;
9
10use crate::buffer::{BufferHandle, CpuBuffer};
11use crate::dtype::DType;
12use crate::error::{Result, SapientError};
13use crate::shape::Shape;
14
15// ── Tensor ────────────────────────────────────────────────────────────────────
16
/// A multi-dimensional tensor with reference-counted buffer ownership.
///
/// The element data lives in a shared `BufferHandle`; views returned by `reshape`,
/// `t` and `slice_axis` clone the handle rather than copying bytes, and describe
/// their layout through `strides` and `offset`.
#[derive(Debug, Clone)]
pub struct Tensor {
    // Logical dimensions of the tensor.
    shape: Shape,
    // Element type of the stored data.
    dtype: DType,
    // Per-axis step between consecutive elements, measured in elements (not bytes):
    // `slice_axis` multiplies a stride by `dtype.element_size()` to get a byte offset.
    strides: Vec<usize>, // row-major by default
    buffer: BufferHandle,
    // Byte offset into the buffer where element [0,0,...,0] lives.
    offset: usize,
}
27
28impl Tensor {
29    // ── Constructors ─────────────────────────────────────────────────────────
30
31    /// Create a zero-filled tensor on the CPU.
32    pub fn zeros(shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
33        let shape = shape.into();
34        shape.validate()?;
35        let numel = shape.numel();
36        let strides = shape.strides();
37        let buffer = BufferHandle::new(CpuBuffer::zeros(numel, dtype)?);
38        Ok(Self {
39            shape,
40            dtype,
41            strides,
42            buffer,
43            offset: 0,
44        })
45    }
46
47    /// Create a tensor from a flat `f32` slice (CPU, row-major).
48    pub fn from_f32(data: &[f32], shape: impl Into<Shape>) -> Result<Self> {
49        let shape = shape.into();
50        shape.validate()?;
51        if data.len() != shape.numel() {
52            return Err(SapientError::ShapeMismatch {
53                expected: shape.dims().to_vec(),
54                got: vec![data.len()],
55            });
56        }
57        let strides = shape.strides();
58        let buffer = BufferHandle::new(CpuBuffer::from_f32_slice(data)?);
59        Ok(Self {
60            shape,
61            dtype: DType::F32,
62            strides,
63            buffer,
64            offset: 0,
65        })
66    }
67
68    /// Create a tensor from raw BF16 bytes, storing them natively without conversion.
69    /// Use `to_f32_vec()` or `to_f32_tensor()` to convert for computation.
70    pub fn from_bf16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
71        let shape = shape.into();
72        shape.validate()?;
73        let expected_bytes = shape.numel() * 2;
74        if data.len() != expected_bytes {
75            return Err(SapientError::ShapeMismatch {
76                expected: shape.dims().to_vec(),
77                got: vec![data.len() / 2],
78            });
79        }
80        let strides = shape.strides();
81        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
82        Ok(Self {
83            shape,
84            dtype: DType::BF16,
85            strides,
86            buffer,
87            offset: 0,
88        })
89    }
90
91    /// Create a tensor from raw F16 bytes, storing them natively without conversion.
92    pub fn from_f16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
93        let shape = shape.into();
94        shape.validate()?;
95        let expected_bytes = shape.numel() * 2;
96        if data.len() != expected_bytes {
97            return Err(SapientError::ShapeMismatch {
98                expected: shape.dims().to_vec(),
99                got: vec![data.len() / 2],
100            });
101        }
102        let strides = shape.strides();
103        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
104        Ok(Self {
105            shape,
106            dtype: DType::F16,
107            strides,
108            buffer,
109            offset: 0,
110        })
111    }
112
113    /// Create a quantized tensor from raw block bytes (Q4_0 / Q8_0).
114    ///
115    /// `data` must contain exactly `dtype.byte_count(shape.numel())` bytes, i.e.
116    /// the packed ggml block bytes with no expansion.  The shape describes the
117    /// *logical* element count; `shape.numel()` must be a multiple of 32.
118    pub fn from_quant_bytes(data: &[u8], shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
119        if !dtype.is_quantized() {
120            return Err(SapientError::TypeMismatch {
121                expected: "a quantized dtype (Q4_0 or Q8_0)".into(),
122                got: dtype.to_string(),
123            });
124        }
125        let shape = shape.into();
126        shape.validate()?;
127        let numel = shape.numel();
128        let expected_bytes = dtype.byte_count(numel);
129        if data.len() != expected_bytes {
130            return Err(SapientError::ShapeMismatch {
131                expected: vec![expected_bytes],
132                got: vec![data.len()],
133            });
134        }
135        let strides = shape.strides();
136        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
137        Ok(Self {
138            shape,
139            dtype,
140            strides,
141            buffer,
142            offset: 0,
143        })
144    }
145
146    /// Create a scalar tensor from a single `f32`.
147    pub fn scalar_f32(v: f32) -> Result<Self> {
148        Self::from_f32(&[v], Shape::scalar())
149    }
150
151    /// Create from a pre-built `BufferHandle` (used by backends).
152    pub fn from_buffer(
153        shape: impl Into<Shape>,
154        dtype: DType,
155        buffer: BufferHandle,
156        offset: usize,
157    ) -> Result<Self> {
158        let shape = shape.into();
159        shape.validate()?;
160        let required = dtype.byte_count(shape.numel());
161        if buffer.len() < offset + required {
162            return Err(SapientError::BufferSizeMismatch {
163                expected: offset + required,
164                got: buffer.len(),
165            });
166        }
167        let strides = shape.strides();
168        Ok(Self {
169            shape,
170            dtype,
171            strides,
172            buffer,
173            offset,
174        })
175    }
176
177    // ── Accessors ────────────────────────────────────────────────────────────
178
179    pub fn shape(&self) -> &Shape {
180        &self.shape
181    }
182    pub fn dtype(&self) -> DType {
183        self.dtype
184    }
185    pub fn ndim(&self) -> usize {
186        self.shape.ndim()
187    }
188    pub fn numel(&self) -> usize {
189        self.shape.numel()
190    }
191    pub fn strides(&self) -> &[usize] {
192        &self.strides
193    }
194    pub fn buffer(&self) -> &BufferHandle {
195        &self.buffer
196    }
197    pub fn offset(&self) -> usize {
198        self.offset
199    }
200
201    /// True if the tensor has a single element.
202    pub fn is_scalar(&self) -> bool {
203        self.shape.is_scalar() || self.numel() == 1
204    }
205
206    /// True if the buffer is row-major contiguous (normal case).
207    pub fn is_contiguous(&self) -> bool {
208        self.strides == self.shape.strides() && self.offset == 0
209    }
210
211    // ── Typed data access (CPU only) ─────────────────────────────────────────
212
213    /// Raw byte view. For non-quantized tensors returns the full buffer slice from
214    /// `offset` onwards (preserving the original behavior that stride-based kernels
215    /// rely on). For quantized tensors (Q4_0/Q8_0) returns exactly the packed block
216    /// bytes for this tensor's logical shape.
217    pub fn as_bytes(&self) -> &[u8] {
218        let bytes = self.buffer.as_bytes();
219        if self.dtype.is_quantized() {
220            let end = self.offset + self.dtype.byte_count(self.numel());
221            &bytes[self.offset..end]
222        } else {
223            &bytes[self.offset..]
224        }
225    }
226
227    /// For quantized tensors (Q4_0, Q8_0): returns the packed block bytes as a
228    /// row-major slice where each logical row of `k` elements occupies
229    /// `dtype.byte_count(k)` bytes.  Panics if the tensor is not quantized.
230    pub fn as_quant_blocks(&self) -> &[u8] {
231        assert!(
232            self.dtype.is_quantized(),
233            "as_quant_blocks() called on non-quantized tensor (dtype = {})",
234            self.dtype
235        );
236        self.as_bytes()
237    }
238
239    /// Typed `f32` view — panics if dtype is not F32.
240    pub fn as_f32_slice(&self) -> &[f32] {
241        assert_eq!(
242            self.dtype,
243            DType::F32,
244            "Tensor dtype is not F32 — call to_f32_vec() instead"
245        );
246        let bytes = self.as_bytes();
247        assert_eq!(bytes.len() % 4, 0);
248        // SAFETY: alignment ensured by CpuBuffer, dtype checked above.
249        unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) }
250    }
251
252    /// Returns a `Cow<[f32]>`. Borrows if the tensor is already F32, otherwise allocates a new `Vec<f32>`.
253    pub fn to_f32_cow(&self) -> std::borrow::Cow<'_, [f32]> {
254        if self.dtype == DType::F32 {
255            std::borrow::Cow::Borrowed(self.as_f32_slice())
256        } else {
257            std::borrow::Cow::Owned(self.to_f32_vec())
258        }
259    }
260
261    /// Convert this tensor to a `Vec<f32>`, handling all dtypes including quantized.
262    /// For F32: cheap copy. For F16/BF16: convert. For Q4_0/Q8_0: dequantize all blocks.
263    pub fn to_f32_vec(&self) -> Vec<f32> {
264        use crate::dtype::{Q4_0_BLOCK_BYTES, Q8_0_BLOCK_BYTES, QUANT_BLOCK_SIZE};
265        match self.dtype {
266            DType::F32 => self.as_f32_slice().to_vec(),
267            DType::BF16 => {
268                let bytes = self.as_bytes();
269                bytes
270                    .chunks_exact(2)
271                    .map(|c| f32::from(half::bf16::from_le_bytes(c.try_into().unwrap())))
272                    .collect()
273            }
274            DType::F16 => {
275                let bytes = self.as_bytes();
276                bytes
277                    .chunks_exact(2)
278                    .map(|c| half::f16::from_le_bytes(c.try_into().unwrap()).to_f32())
279                    .collect()
280            }
281            DType::Q4_0 => {
282                let numel = self.numel();
283                let bytes = self.as_bytes();
284                let mut out = vec![0.0f32; numel];
285                for (b, block) in bytes.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
286                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
287                    for j in 0..QUANT_BLOCK_SIZE / 2 {
288                        let byte = block[2 + j];
289                        let lo = (byte & 0x0f) as i32 - 8;
290                        let hi = (byte >> 4) as i32 - 8;
291                        out[b * QUANT_BLOCK_SIZE + j] = lo as f32 * d;
292                        out[b * QUANT_BLOCK_SIZE + j + QUANT_BLOCK_SIZE / 2] = hi as f32 * d;
293                    }
294                }
295                out
296            }
297            DType::Q8_0 => {
298                let numel = self.numel();
299                let bytes = self.as_bytes();
300                let mut out = vec![0.0f32; numel];
301                for (b, block) in bytes.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
302                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
303                    for j in 0..QUANT_BLOCK_SIZE {
304                        out[b * QUANT_BLOCK_SIZE + j] = block[2 + j] as i8 as f32 * d;
305                    }
306                }
307                out
308            }
309            _ => self.as_f32_slice().to_vec(), // fallback for integer dtypes
310        }
311    }
312
313    /// Returns an F32 tensor, converting BF16/F16 if necessary.
314    /// For already-F32 tensors, clones the buffer. For native types, converts.
315    pub fn to_f32_tensor(&self) -> Result<Tensor> {
316        match self.dtype {
317            DType::F32 => Ok(self.clone()),
318            _ => Tensor::from_f32(&self.to_f32_vec(), self.shape.clone()),
319        }
320    }
321
322    /// Mutable typed `f32` view — fails if buffer is shared or not F32.
323    pub fn as_f32_slice_mut(&mut self) -> Result<&mut [f32]> {
324        if self.dtype != DType::F32 {
325            return Err(SapientError::internal("Tensor dtype is not F32"));
326        }
327        let offset = self.offset;
328        let buf = Arc::get_mut(&mut self.buffer.0)
329            .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
330        let bytes = buf.as_bytes_mut();
331        let bytes = &mut bytes[offset..];
332        if bytes.len() % 4 != 0 {
333            return Err(SapientError::internal("Buffer length not a multiple of 4"));
334        }
335        // SAFETY: alignment ensured by CpuBuffer, dtype checked above.
336        Ok(unsafe {
337            std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut f32, bytes.len() / 4)
338        })
339    }
340
341    // ── Shape manipulation ───────────────────────────────────────────────────
342
343    /// Returns a new tensor with a different shape but the same buffer.
344    /// The total number of elements must be unchanged.
345    pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor> {
346        let new_shape = self.shape.reshape(new_shape.into().dims().to_vec())?;
347        let strides = new_shape.strides();
348        Ok(Tensor {
349            shape: new_shape,
350            dtype: self.dtype,
351            strides,
352            buffer: self.buffer.clone(),
353            offset: self.offset,
354        })
355    }
356
357    /// Transpose a 2-D tensor (swap axes 0 and 1).
358    pub fn t(&self) -> Result<Tensor> {
359        if self.ndim() != 2 {
360            return Err(SapientError::internal("t() requires a 2-D tensor"));
361        }
362        let mut dims = self.shape.dims().to_vec();
363        let mut strides = self.strides.clone();
364        dims.swap(0, 1);
365        strides.swap(0, 1);
366        Ok(Tensor {
367            shape: Shape(dims),
368            dtype: self.dtype,
369            strides,
370            buffer: self.buffer.clone(),
371            offset: self.offset,
372        })
373    }
374
375    /// Return a view of the tensor sliced along the given axis.
376    pub fn slice_axis(&self, axis: usize, start: usize, end: usize) -> Result<Tensor> {
377        let mut dims = self.shape.dims().to_vec();
378        if axis >= dims.len() {
379            return Err(SapientError::internal("slice axis out of bounds"));
380        }
381        if start > end || end > dims[axis] {
382            return Err(SapientError::internal("slice range out of bounds"));
383        }
384        dims[axis] = end - start;
385        let offset = self.offset + start * self.strides[axis] * self.dtype.element_size();
386        Ok(Tensor {
387            shape: Shape(dims),
388            dtype: self.dtype,
389            strides: self.strides.clone(),
390            buffer: self.buffer.clone(),
391            offset,
392        })
393    }
394
395    // ── Metadata convenience ─────────────────────────────────────────────────
396
397    /// Byte count for all elements.
398    pub fn byte_size(&self) -> usize {
399        self.dtype.byte_count(self.numel())
400    }
401}
402
403impl std::fmt::Display for Tensor {
404    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405        write!(
406            f,
407            "Tensor(shape={}, dtype={}, device={})",
408            self.shape,
409            self.dtype,
410            self.buffer.0.device()
411        )
412    }
413}
414
415// ── Serde support for Tensor ─────────────────────────────────────────────────
416
/// Serialisable proxy for `Tensor` — stores shape/dtype alongside the element
/// values as `f32`.
///
/// The `Serialize` impl for `Tensor` leaves `data` empty for non-F32 dtypes, and the
/// `Deserialize` impl turns an empty `data` back into a zero-filled tensor of `dtype`.
#[derive(Serialize, Deserialize)]
struct TensorProxy {
    // Logical dimensions of the tensor.
    shape: Shape,
    // Element type recorded for the tensor.
    dtype: DType,
    /// Element values as a plain `Vec<f32>`, which serde writes as a sequence of
    /// floats (no base64 or raw-byte encoding is applied by this type).
    data: Vec<f32>,
}
425
426impl Serialize for Tensor {
427    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
428        let data: Vec<f32> = if self.dtype == DType::F32 {
429            self.as_f32_slice().to_vec()
430        } else {
431            vec![] // non-f32 tensors: zero data (future work)
432        };
433        TensorProxy {
434            shape: self.shape.clone(),
435            dtype: self.dtype,
436            data,
437        }
438        .serialize(serializer)
439    }
440}
441
442impl<'de> Deserialize<'de> for Tensor {
443    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
444        let proxy = TensorProxy::deserialize(deserializer)?;
445        if proxy.data.is_empty() {
446            Tensor::zeros(proxy.shape, proxy.dtype).map_err(serde::de::Error::custom)
447        } else {
448            Tensor::from_f32(&proxy.data, proxy.shape).map_err(serde::de::Error::custom)
449        }
450    }
451}
452
/// A serializable descriptor for a tensor — shape and dtype only (no data).
///
/// Build one from an existing tensor with `TensorMeta::from(&tensor)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMeta {
    /// Logical dimensions of the described tensor.
    pub shape: Shape,
    /// Element type of the described tensor.
    pub dtype: DType,
}
459
460impl From<&Tensor> for TensorMeta {
461    fn from(t: &Tensor) -> Self {
462        Self {
463            shape: t.shape.clone(),
464            dtype: t.dtype,
465        }
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros_dtype_shape() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert_eq!(t.shape().dims(), &[2, 3]);
        assert_eq!(t.dtype(), DType::F32);
        assert_eq!(t.numel(), 6);
    }

    #[test]
    fn from_f32_roundtrip() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        assert_eq!(t.as_f32_slice(), data.as_slice());
    }

    #[test]
    fn from_f32_rejects_wrong_length() {
        assert!(Tensor::from_f32(&[1.0, 2.0, 3.0], vec![2, 2]).is_err());
    }

    #[test]
    fn scalar_has_one_element() {
        let t = Tensor::scalar_f32(2.5).unwrap();
        assert!(t.is_scalar());
        assert_eq!(t.numel(), 1);
        assert_eq!(t.to_f32_vec(), vec![2.5f32]);
    }

    #[test]
    fn bf16_bytes_convert_to_f32() {
        // 1.0 = 0x3F80 and -2.0 = 0xC000, stored little-endian.
        let bytes = [0x80u8, 0x3F, 0x00, 0xC0];
        let t = Tensor::from_bf16_bytes(&bytes, vec![2]).unwrap();
        assert_eq!(t.dtype(), DType::BF16);
        assert_eq!(t.to_f32_vec(), vec![1.0f32, -2.0]);
        assert!(Tensor::from_bf16_bytes(&bytes[..3], vec![2]).is_err());
    }

    #[test]
    fn f16_bytes_convert_to_f32() {
        // 1.0 = 0x3C00 and -2.0 = 0xC000, stored little-endian.
        let bytes = [0x00u8, 0x3C, 0x00, 0xC0];
        let t = Tensor::from_f16_bytes(&bytes, vec![2]).unwrap();
        assert_eq!(t.dtype(), DType::F16);
        assert_eq!(t.to_f32_vec(), vec![1.0f32, -2.0]);
        let f = t.to_f32_tensor().unwrap();
        assert_eq!(f.dtype(), DType::F32);
        assert_eq!(f.as_f32_slice(), [1.0f32, -2.0].as_slice());
    }

    #[test]
    fn q8_0_dequantizes_with_block_scale() {
        use crate::dtype::{Q8_0_BLOCK_BYTES, QUANT_BLOCK_SIZE};
        // Scale 0.5 as a little-endian f16 (0x3800), then quants -16, -15, ...
        let mut block = vec![0u8; Q8_0_BLOCK_BYTES];
        block[0] = 0x00;
        block[1] = 0x38;
        for j in 0..QUANT_BLOCK_SIZE {
            block[2 + j] = (j as i8 - 16) as u8;
        }
        let t = Tensor::from_quant_bytes(&block, vec![QUANT_BLOCK_SIZE], DType::Q8_0).unwrap();
        let expected: Vec<f32> = (0..QUANT_BLOCK_SIZE)
            .map(|j| (j as f32 - 16.0) * 0.5)
            .collect();
        assert_eq!(t.to_f32_vec(), expected);
    }

    #[test]
    fn from_quant_bytes_rejects_float_dtype() {
        assert!(Tensor::from_quant_bytes(&[0u8; 4], vec![1], DType::F32).is_err());
    }

    #[test]
    fn reshape_preserves_data() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        let r = t.reshape(vec![3, 2]).unwrap();
        assert_eq!(r.shape().dims(), &[3, 2]);
        assert_eq!(r.as_f32_slice(), data.as_slice());
    }

    #[test]
    fn reshape_wrong_numel() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert!(t.reshape(vec![5]).is_err());
    }

    #[test]
    fn transpose_2d() {
        let t = Tensor::zeros(vec![3, 4], DType::F32).unwrap();
        let t2 = t.t().unwrap();
        assert_eq!(t2.shape().dims(), &[4, 3]);
    }

    #[test]
    fn transpose_swaps_strides() {
        let t = Tensor::zeros(vec![3, 4], DType::F32).unwrap();
        let tt = t.t().unwrap();
        assert_eq!(tt.strides(), &[1, 4]);
        assert!(!tt.is_contiguous());
        assert!(Tensor::zeros(vec![2, 3, 4], DType::F32).unwrap().t().is_err());
    }

    #[test]
    fn slice_axis_shape_and_bounds() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
        let s = t.slice_axis(0, 1, 3).unwrap();
        assert_eq!(s.shape().dims(), &[2, 2]);
        assert_eq!(s.offset(), 2 * 4);
        assert_eq!(&s.as_f32_slice()[..4], &[3.0f32, 4.0, 5.0, 6.0][..]);
        assert!(t.slice_axis(2, 0, 1).is_err());
        assert!(t.slice_axis(0, 2, 1).is_err());
        assert!(t.slice_axis(0, 0, 4).is_err());
    }

    #[test]
    fn from_buffer_checks_capacity() {
        let t = Tensor::zeros(vec![4], DType::F32).unwrap();
        let view = Tensor::from_buffer(vec![2], DType::F32, t.buffer().clone(), 8).unwrap();
        assert_eq!(view.offset(), 8);
        assert!(Tensor::from_buffer(vec![4], DType::F32, t.buffer().clone(), 4).is_err());
    }

    #[test]
    fn byte_size() {
        let t = Tensor::zeros(vec![4, 4], DType::F32).unwrap();
        assert_eq!(t.byte_size(), 64);
    }
}