Skip to main content

sapient_core/
tensor.rs

1//! `Tensor` — the central multi-dimensional array type in SAPIENT.
2//!
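//! A `Tensor` owns its shape, dtype and element-stride metadata, and holds a
//! reference-counted `BufferHandle` for the raw bytes.  Constructors always
//! produce a row-major (C order) layout; views created by `t()` and
//! `slice_axis()` share the buffer but may carry permuted or inherited strides
//! and a non-zero byte offset.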
5
6use serde::{Deserialize, Serialize};
7use serde::{Deserializer, Serializer};
8use std::sync::Arc;
9
10use crate::buffer::{BufferHandle, CpuBuffer};
11use crate::dtype::DType;
12use crate::error::{Result, SapientError};
13use crate::shape::Shape;
14
15// ── Tensor ────────────────────────────────────────────────────────────────────
16
/// A multi-dimensional tensor with reference-counted buffer ownership.
///
/// `reshape`, `t` and `slice_axis` build their results from a clone of the
/// buffer handle; since the handle is reference-counted, those views share
/// the underlying bytes rather than copying them.
#[derive(Debug, Clone)]
pub struct Tensor {
    shape: Shape,
    dtype: DType,
    // Per-axis strides measured in *elements*, not bytes (`slice_axis`
    // multiplies by `element_size()` to get a byte offset).  Row-major for
    // freshly constructed tensors; `t()` swaps them and `slice_axis()` keeps
    // the parent's, so views may be non-contiguous.
    strides: Vec<usize>,
    buffer: BufferHandle,
    // Byte offset into the buffer where element [0,0,...,0] lives.
    offset: usize,
}
27
28impl Tensor {
29    // ── Constructors ─────────────────────────────────────────────────────────
30
31    /// Create a zero-filled tensor on the CPU.
32    pub fn zeros(shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
33        let shape = shape.into();
34        shape.validate()?;
35        let numel = shape.numel();
36        let strides = shape.strides();
37        let buffer = BufferHandle::new(CpuBuffer::zeros(numel, dtype)?);
38        Ok(Self {
39            shape,
40            dtype,
41            strides,
42            buffer,
43            offset: 0,
44        })
45    }
46
47    /// Create a tensor from a flat `f32` slice (CPU, row-major).
48    /// Take ownership of a `Vec<f32>` without copying.
49    /// Use instead of `from_f32` in hot paths to avoid the allocation + memcpy.
50    pub fn from_f32_vec(data: Vec<f32>, shape: impl Into<Shape>) -> Result<Self> {
51        let shape = shape.into();
52        shape.validate()?;
53        if data.len() != shape.numel() {
54            return Err(SapientError::ShapeMismatch {
55                expected: shape.dims().to_vec(),
56                got: vec![data.len()],
57            });
58        }
59        let strides = shape.strides();
60        let buffer = BufferHandle::new(CpuBuffer::from_f32_vec(data)?);
61        Ok(Self {
62            shape,
63            dtype: DType::F32,
64            strides,
65            buffer,
66            offset: 0,
67        })
68    }
69
70    pub fn from_f32(data: &[f32], shape: impl Into<Shape>) -> Result<Self> {
71        let shape = shape.into();
72        shape.validate()?;
73        if data.len() != shape.numel() {
74            return Err(SapientError::ShapeMismatch {
75                expected: shape.dims().to_vec(),
76                got: vec![data.len()],
77            });
78        }
79        let strides = shape.strides();
80        let buffer = BufferHandle::new(CpuBuffer::from_f32_slice(data)?);
81        Ok(Self {
82            shape,
83            dtype: DType::F32,
84            strides,
85            buffer,
86            offset: 0,
87        })
88    }
89
90    /// Create a tensor from raw BF16 bytes, storing them natively without conversion.
91    /// Use `to_f32_vec()` or `to_f32_tensor()` to convert for computation.
92    pub fn from_bf16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
93        let shape = shape.into();
94        shape.validate()?;
95        let expected_bytes = shape.numel() * 2;
96        if data.len() != expected_bytes {
97            return Err(SapientError::ShapeMismatch {
98                expected: shape.dims().to_vec(),
99                got: vec![data.len() / 2],
100            });
101        }
102        let strides = shape.strides();
103        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
104        Ok(Self {
105            shape,
106            dtype: DType::BF16,
107            strides,
108            buffer,
109            offset: 0,
110        })
111    }
112
113    /// Create a tensor from raw F16 bytes, storing them natively without conversion.
114    pub fn from_f16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
115        let shape = shape.into();
116        shape.validate()?;
117        let expected_bytes = shape.numel() * 2;
118        if data.len() != expected_bytes {
119            return Err(SapientError::ShapeMismatch {
120                expected: shape.dims().to_vec(),
121                got: vec![data.len() / 2],
122            });
123        }
124        let strides = shape.strides();
125        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
126        Ok(Self {
127            shape,
128            dtype: DType::F16,
129            strides,
130            buffer,
131            offset: 0,
132        })
133    }
134
135    /// Create a quantized tensor from raw block bytes (Q4_0 / Q8_0).
136    ///
137    /// `data` must contain exactly `dtype.byte_count(shape.numel())` bytes, i.e.
138    /// the packed ggml block bytes with no expansion.  The shape describes the
139    /// *logical* element count; `shape.numel()` must be a multiple of 32.
140    pub fn from_quant_bytes(data: &[u8], shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
141        if !dtype.is_quantized() {
142            return Err(SapientError::TypeMismatch {
143                expected: "a quantized dtype (Q4_0, Q8_0, Q4_K, Q5_K, Q6_K)".into(),
144                got: dtype.to_string(),
145            });
146        }
147        let shape = shape.into();
148        shape.validate()?;
149        let numel = shape.numel();
150        let expected_bytes = dtype.byte_count(numel);
151        if data.len() != expected_bytes {
152            return Err(SapientError::ShapeMismatch {
153                expected: vec![expected_bytes],
154                got: vec![data.len()],
155            });
156        }
157        let strides = shape.strides();
158        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
159        Ok(Self {
160            shape,
161            dtype,
162            strides,
163            buffer,
164            offset: 0,
165        })
166    }
167
168    /// Create a scalar tensor from a single `f32`.
169    pub fn scalar_f32(v: f32) -> Result<Self> {
170        Self::from_f32(&[v], Shape::scalar())
171    }
172
173    /// Create from a pre-built `BufferHandle` (used by backends).
174    pub fn from_buffer(
175        shape: impl Into<Shape>,
176        dtype: DType,
177        buffer: BufferHandle,
178        offset: usize,
179    ) -> Result<Self> {
180        let shape = shape.into();
181        shape.validate()?;
182        let required = dtype.byte_count(shape.numel());
183        if buffer.len() < offset + required {
184            return Err(SapientError::BufferSizeMismatch {
185                expected: offset + required,
186                got: buffer.len(),
187            });
188        }
189        let strides = shape.strides();
190        Ok(Self {
191            shape,
192            dtype,
193            strides,
194            buffer,
195            offset,
196        })
197    }
198
199    // ── Accessors ────────────────────────────────────────────────────────────
200
201    pub fn shape(&self) -> &Shape {
202        &self.shape
203    }
204    pub fn dtype(&self) -> DType {
205        self.dtype
206    }
207    pub fn ndim(&self) -> usize {
208        self.shape.ndim()
209    }
210    pub fn numel(&self) -> usize {
211        self.shape.numel()
212    }
213    pub fn strides(&self) -> &[usize] {
214        &self.strides
215    }
216    pub fn buffer(&self) -> &BufferHandle {
217        &self.buffer
218    }
219    pub fn offset(&self) -> usize {
220        self.offset
221    }
222
223    /// True if the tensor has a single element.
224    pub fn is_scalar(&self) -> bool {
225        self.shape.is_scalar() || self.numel() == 1
226    }
227
228    /// True if the buffer is row-major contiguous (normal case).
229    pub fn is_contiguous(&self) -> bool {
230        self.strides == self.shape.strides() && self.offset == 0
231    }
232
233    // ── Typed data access (CPU only) ─────────────────────────────────────────
234
235    /// Raw byte view. For non-quantized tensors returns the full buffer slice from
236    /// `offset` onwards (preserving the original behavior that stride-based kernels
237    /// rely on). For quantized tensors (Q4_0/Q8_0) returns exactly the packed block
238    /// bytes for this tensor's logical shape.
239    pub fn as_bytes(&self) -> &[u8] {
240        let bytes = self.buffer.as_bytes();
241        if self.dtype.is_quantized() {
242            let end = self.offset + self.dtype.byte_count(self.numel());
243            &bytes[self.offset..end]
244        } else {
245            &bytes[self.offset..]
246        }
247    }
248
249    /// For quantized tensors (Q4_0, Q8_0): returns the packed block bytes as a
250    /// row-major slice where each logical row of `k` elements occupies
251    /// `dtype.byte_count(k)` bytes.  Panics if the tensor is not quantized.
252    pub fn as_quant_blocks(&self) -> &[u8] {
253        assert!(
254            self.dtype.is_quantized(),
255            "as_quant_blocks() called on non-quantized tensor (dtype = {})",
256            self.dtype
257        );
258        self.as_bytes()
259    }
260
261    /// Typed `f32` view — panics if dtype is not F32.
262    pub fn as_f32_slice(&self) -> &[f32] {
263        assert_eq!(
264            self.dtype,
265            DType::F32,
266            "Tensor dtype is not F32 — call to_f32_vec() instead"
267        );
268        let bytes = self.as_bytes();
269        assert_eq!(bytes.len() % 4, 0);
270        // SAFETY: alignment ensured by CpuBuffer, dtype checked above.
271        unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) }
272    }
273
274    /// Convert this tensor to a contiguous `Vec<f32>` that matches `shape().numel()`
275    /// exactly, even when the tensor is a **non-contiguous view** (e.g. a KV-cache
276    /// slice from `slice_axis`).
277    ///
278    /// For contiguous tensors this is a fast bounded copy.  For non-contiguous
279    /// tensors (strides don't match the natural row-major strides, or offset ≠ 0)
280    /// it uses stride-based indexing to extract only the logically-reachable elements
281    /// in row-major order — the approach that `as_f32_slice` / `to_f32_cow` cannot
282    /// do because they return the full backing buffer.
283    pub fn to_contiguous_f32_vec(&self) -> Vec<f32> {
284        let numel = self.numel();
285        if self.is_contiguous() {
286            // Fast path: elements are dense starting at offset.
287            // Limit to numel to avoid reading past the logical tensor.
288            match self.dtype {
289                DType::F32 => self.as_f32_slice()[..numel].to_vec(),
290                _ => {
291                    let v = self.to_f32_vec();
292                    v[..numel.min(v.len())].to_vec()
293                }
294            }
295        } else {
296            // Slow path: stride-based copy.
297            // `raw` gives us the full backing buffer from `self.offset` as f32 units.
298            let raw: Vec<f32> = match self.dtype {
299                DType::F32 => self.as_f32_slice().to_vec(),
300                _ => self.to_f32_vec(),
301            };
302            let dims = self.shape.dims();
303            let strides = &self.strides; // element strides (not byte strides)
304            let mut out = vec![0.0f32; numel];
305            for (flat, dst) in out.iter_mut().enumerate() {
306                // Convert flat (row-major) index to per-dimension indices, then
307                // compute the element offset using the tensor's actual strides.
308                let mut rem = flat;
309                let mut src = 0usize;
310                for d in (0..dims.len()).rev() {
311                    let idx_d = rem % dims[d];
312                    rem /= dims[d];
313                    src += idx_d * strides[d];
314                }
315                *dst = *raw.get(src).unwrap_or(&0.0);
316            }
317            out
318        }
319    }
320
321    /// Returns a `Cow<[f32]>`. Borrows if the tensor is already F32, otherwise allocates a new `Vec<f32>`.
322    pub fn to_f32_cow(&self) -> std::borrow::Cow<'_, [f32]> {
323        if self.dtype == DType::F32 {
324            std::borrow::Cow::Borrowed(self.as_f32_slice())
325        } else {
326            std::borrow::Cow::Owned(self.to_f32_vec())
327        }
328    }
329
330    /// Convert this tensor to a `Vec<f32>`, handling all dtypes including quantized.
331    /// For F32: cheap copy. For F16/BF16: convert. For quantized: dequantize all blocks.
332    pub fn to_f32_vec(&self) -> Vec<f32> {
333        use crate::dtype::{
334            K_QUANT_BLOCK_SIZE, Q4_0_BLOCK_BYTES, Q4_K_BLOCK_BYTES, Q5_K_BLOCK_BYTES,
335            Q6_K_BLOCK_BYTES, Q8_0_BLOCK_BYTES, QUANT_BLOCK_SIZE,
336        };
337        match self.dtype {
338            DType::F32 => self.as_f32_slice().to_vec(),
339            DType::BF16 => {
340                let bytes = self.as_bytes();
341                bytes
342                    .chunks_exact(2)
343                    .map(|c| f32::from(half::bf16::from_le_bytes(c.try_into().unwrap())))
344                    .collect()
345            }
346            DType::F16 => {
347                let bytes = self.as_bytes();
348                bytes
349                    .chunks_exact(2)
350                    .map(|c| half::f16::from_le_bytes(c.try_into().unwrap()).to_f32())
351                    .collect()
352            }
353            DType::Q4_0 => {
354                let numel = self.numel();
355                let bytes = self.as_bytes();
356                let mut out = vec![0.0f32; numel];
357                for (b, block) in bytes.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
358                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
359                    for j in 0..QUANT_BLOCK_SIZE / 2 {
360                        let byte = block[2 + j];
361                        let lo = (byte & 0x0f) as i32 - 8;
362                        let hi = (byte >> 4) as i32 - 8;
363                        out[b * QUANT_BLOCK_SIZE + j] = lo as f32 * d;
364                        out[b * QUANT_BLOCK_SIZE + j + QUANT_BLOCK_SIZE / 2] = hi as f32 * d;
365                    }
366                }
367                out
368            }
369            DType::Q8_0 => {
370                let numel = self.numel();
371                let bytes = self.as_bytes();
372                let mut out = vec![0.0f32; numel];
373                for (b, block) in bytes.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
374                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
375                    for j in 0..QUANT_BLOCK_SIZE {
376                        out[b * QUANT_BLOCK_SIZE + j] = block[2 + j] as i8 as f32 * d;
377                    }
378                }
379                out
380            }
381            DType::Q4_K => {
382                let numel = self.numel();
383                let bytes = self.as_bytes();
384                let mut out = vec![0.0f32; numel];
385                let mut out_idx = 0usize;
386                for block in bytes.chunks_exact(Q4_K_BLOCK_BYTES) {
387                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
388                    let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
389                    let scales = &block[4..16];
390                    let qs = &block[16..Q4_K_BLOCK_BYTES];
391                    let mut q_off = 0usize;
392                    let mut is = 0usize;
393                    for _ in 0..(K_QUANT_BLOCK_SIZE / 64) {
394                        let (sc1, m1) = Self::get_scale_min_k4(is, scales);
395                        let d1 = d * sc1 as f32;
396                        let m1v = dmin * m1 as f32;
397                        let (sc2, m2) = Self::get_scale_min_k4(is + 1, scales);
398                        let d2 = d * sc2 as f32;
399                        let m2v = dmin * m2 as f32;
400                        for l in 0..32 {
401                            out[out_idx + l] = d1 * (qs[q_off + l] & 0x0F) as f32 - m1v;
402                            out[out_idx + l + 32] = d2 * (qs[q_off + l] >> 4) as f32 - m2v;
403                        }
404                        out_idx += 64;
405                        q_off += 32;
406                        is += 2;
407                    }
408                }
409                out
410            }
411            DType::Q5_K => {
412                let numel = self.numel();
413                let bytes = self.as_bytes();
414                let mut out = vec![0.0f32; numel];
415                let mut out_idx = 0usize;
416                for block in bytes.chunks_exact(Q5_K_BLOCK_BYTES) {
417                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
418                    let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
419                    let scales = &block[4..16];
420                    let qh = &block[16..48];
421                    let ql = &block[48..Q5_K_BLOCK_BYTES];
422                    let mut ql_off = 0usize;
423                    let mut is = 0usize;
424                    let mut u1: u8 = 1;
425                    let mut u2: u8 = 2;
426                    for _ in 0..(K_QUANT_BLOCK_SIZE / 64) {
427                        let (sc1, m1) = Self::get_scale_min_k4(is, scales);
428                        let d1 = d * sc1 as f32;
429                        let m1v = dmin * m1 as f32;
430                        let (sc2, m2) = Self::get_scale_min_k4(is + 1, scales);
431                        let d2 = d * sc2 as f32;
432                        let m2v = dmin * m2 as f32;
433                        let qh_byte = qh[is / 8];
434                        for l in 0..32usize {
435                            let hi = if qh_byte & u1 != 0 { 16.0f32 } else { 0.0 };
436                            out[out_idx + l] = d1 * ((ql[ql_off + l] & 0x0F) as f32 + hi) - m1v;
437                            let hi2 = if qh_byte & u2 != 0 { 16.0f32 } else { 0.0 };
438                            out[out_idx + l + 32] = d2 * ((ql[ql_off + l] >> 4) as f32 + hi2) - m2v;
439                        }
440                        out_idx += 64;
441                        ql_off += 32;
442                        is += 2;
443                        if is % 8 == 0 {
444                            u1 = 1;
445                            u2 = 2;
446                        } else {
447                            u1 <<= 2;
448                            u2 <<= 2;
449                        }
450                    }
451                }
452                out
453            }
454            DType::Q6_K => {
455                let numel = self.numel();
456                let bytes = self.as_bytes();
457                let mut out = vec![0.0f32; numel];
458                let mut out_idx = 0usize;
459                for block in bytes.chunks_exact(Q6_K_BLOCK_BYTES) {
460                    let ql = &block[0..128];
461                    let qh = &block[128..192];
462                    let sc = &block[192..208];
463                    let d = half::f16::from_le_bytes([block[208], block[209]]).to_f32();
464                    let mut ql_off = 0usize;
465                    let mut qh_off = 0usize;
466                    let mut ib = 0usize;
467                    for _ in 0..(K_QUANT_BLOCK_SIZE / 128) {
468                        for l in 0..32usize {
469                            let q1 = (((ql[ql_off + l] & 0x0F) | ((qh[qh_off + l] & 3) << 4))
470                                as i32
471                                - 32) as f32;
472                            let q2 = (((ql[ql_off + l + 32] & 0x0F)
473                                | (((qh[qh_off + l] >> 2) & 3) << 4))
474                                as i32
475                                - 32) as f32;
476                            let q3 = (((ql[ql_off + l] >> 4) | (((qh[qh_off + l] >> 4) & 3) << 4))
477                                as i32
478                                - 32) as f32;
479                            let q4 = (((ql[ql_off + l + 32] >> 4)
480                                | (((qh[qh_off + l] >> 6) & 3) << 4))
481                                as i32
482                                - 32) as f32;
483                            out[out_idx + l] = d * sc[ib] as i8 as f32 * q1;
484                            out[out_idx + l + 32] = d * sc[ib + 1] as i8 as f32 * q2;
485                            out[out_idx + l + 64] = d * sc[ib + 2] as i8 as f32 * q3;
486                            out[out_idx + l + 96] = d * sc[ib + 3] as i8 as f32 * q4;
487                        }
488                        out_idx += 128;
489                        ql_off += 64;
490                        qh_off += 32;
491                        ib += 4;
492                    }
493                }
494                out
495            }
496            _ => self.as_f32_slice().to_vec(), // fallback for integer dtypes
497        }
498    }
499
500    /// Extract scale and min for a K-quant sub-block (used in Q4_K/Q5_K dequantization).
501    #[inline]
502    fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
503        if j < 4 {
504            (scales[j] & 63, scales[j + 4] & 63)
505        } else {
506            (
507                (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4),
508                (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4),
509            )
510        }
511    }
512
513    /// Returns an F32 tensor, converting BF16/F16 if necessary.
514    /// For already-F32 tensors, clones the buffer. For native types, converts.
515    pub fn to_f32_tensor(&self) -> Result<Tensor> {
516        match self.dtype {
517            DType::F32 => Ok(self.clone()),
518            _ => Tensor::from_f32(&self.to_f32_vec(), self.shape.clone()),
519        }
520    }
521
522    /// Mutable typed `f32` view — fails if buffer is shared or not F32.
523    /// Mutable byte access for quantized tensors — **in-place** update with zero copy.
524    /// Returns an error if the buffer is shared (Arc strong_count > 1).
525    pub fn as_bytes_mut(&mut self) -> Result<&mut [u8]> {
526        let offset = self.offset;
527        let end = offset + self.dtype.byte_count(self.numel());
528        let buf = Arc::get_mut(&mut self.buffer.0)
529            .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
530        let bytes = buf.as_bytes_mut();
531        Ok(&mut bytes[offset..end])
532    }
533
534    pub fn as_f32_slice_mut(&mut self) -> Result<&mut [f32]> {
535        if self.dtype != DType::F32 {
536            return Err(SapientError::internal("Tensor dtype is not F32"));
537        }
538        let offset = self.offset;
539        let buf = Arc::get_mut(&mut self.buffer.0)
540            .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
541        let bytes = buf.as_bytes_mut();
542        let bytes = &mut bytes[offset..];
543        if bytes.len() % 4 != 0 {
544            return Err(SapientError::internal("Buffer length not a multiple of 4"));
545        }
546        // SAFETY: alignment ensured by CpuBuffer, dtype checked above.
547        Ok(unsafe {
548            std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut f32, bytes.len() / 4)
549        })
550    }
551
552    // ── Shape manipulation ───────────────────────────────────────────────────
553
554    /// Returns a new tensor with a different shape but the same buffer.
555    /// The total number of elements must be unchanged.
556    pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor> {
557        let new_shape = self.shape.reshape(new_shape.into().dims().to_vec())?;
558        let strides = new_shape.strides();
559        Ok(Tensor {
560            shape: new_shape,
561            dtype: self.dtype,
562            strides,
563            buffer: self.buffer.clone(),
564            offset: self.offset,
565        })
566    }
567
568    /// Transpose a 2-D tensor (swap axes 0 and 1).
569    pub fn t(&self) -> Result<Tensor> {
570        if self.ndim() != 2 {
571            return Err(SapientError::internal("t() requires a 2-D tensor"));
572        }
573        let mut dims = self.shape.dims().to_vec();
574        let mut strides = self.strides.clone();
575        dims.swap(0, 1);
576        strides.swap(0, 1);
577        Ok(Tensor {
578            shape: Shape(dims),
579            dtype: self.dtype,
580            strides,
581            buffer: self.buffer.clone(),
582            offset: self.offset,
583        })
584    }
585
586    /// Return a view of the tensor sliced along the given axis.
587    pub fn slice_axis(&self, axis: usize, start: usize, end: usize) -> Result<Tensor> {
588        let mut dims = self.shape.dims().to_vec();
589        if axis >= dims.len() {
590            return Err(SapientError::internal("slice axis out of bounds"));
591        }
592        if start > end || end > dims[axis] {
593            return Err(SapientError::internal("slice range out of bounds"));
594        }
595        dims[axis] = end - start;
596        let offset = self.offset + start * self.strides[axis] * self.dtype.element_size();
597        Ok(Tensor {
598            shape: Shape(dims),
599            dtype: self.dtype,
600            strides: self.strides.clone(),
601            buffer: self.buffer.clone(),
602            offset,
603        })
604    }
605
606    // ── Metadata convenience ─────────────────────────────────────────────────
607
608    /// Byte count for all elements.
609    pub fn byte_size(&self) -> usize {
610        self.dtype.byte_count(self.numel())
611    }
612}
613
614impl std::fmt::Display for Tensor {
615    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616        write!(
617            f,
618            "Tensor(shape={}, dtype={}, device={})",
619            self.shape,
620            self.dtype,
621            self.buffer.0.device()
622        )
623    }
624}
625
626// ── Serde support for Tensor ─────────────────────────────────────────────────
627
/// Serialisable proxy for `Tensor`: stores element values as `f32` alongside
/// shape and dtype.
#[derive(Serialize, Deserialize)]
struct TensorProxy {
    shape: Shape,
    dtype: DType,
    /// Element values as a plain sequence of `f32` (not base64 or raw bytes).
    /// `Serialize for Tensor` fills this only for F32 tensors; other dtypes
    /// are written with an empty vector.
    data: Vec<f32>,
}
636
637impl Serialize for Tensor {
638    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
639        let data: Vec<f32> = if self.dtype == DType::F32 {
640            self.as_f32_slice().to_vec()
641        } else {
642            vec![] // non-f32 tensors: zero data (future work)
643        };
644        TensorProxy {
645            shape: self.shape.clone(),
646            dtype: self.dtype,
647            data,
648        }
649        .serialize(serializer)
650    }
651}
652
653impl<'de> Deserialize<'de> for Tensor {
654    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
655        let proxy = TensorProxy::deserialize(deserializer)?;
656        if proxy.data.is_empty() {
657            Tensor::zeros(proxy.shape, proxy.dtype).map_err(serde::de::Error::custom)
658        } else {
659            Tensor::from_f32(&proxy.data, proxy.shape).map_err(serde::de::Error::custom)
660        }
661    }
662}
663
/// A serializable descriptor for a tensor: shape and dtype only, no element data.
///
/// Build one from an existing tensor with `TensorMeta::from(&tensor)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMeta {
    /// Logical shape of the described tensor.
    pub shape: Shape,
    /// Element dtype of the described tensor.
    pub dtype: DType,
}
670
671impl From<&Tensor> for TensorMeta {
672    fn from(t: &Tensor) -> Self {
673        Self {
674            shape: t.shape.clone(),
675            dtype: t.dtype,
676        }
677    }
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    /// Six distinct values shared by the data-carrying tests.
    const SAMPLE: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

    #[test]
    fn zeros_dtype_shape() {
        let zeros = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert_eq!(zeros.shape().dims(), &[2, 3]);
        assert_eq!(zeros.dtype(), DType::F32);
        assert_eq!(zeros.numel(), 6);
    }

    #[test]
    fn from_f32_roundtrip() {
        let tensor = Tensor::from_f32(&SAMPLE, vec![2, 3]).unwrap();
        assert_eq!(tensor.as_f32_slice(), &SAMPLE[..]);
    }

    #[test]
    fn reshape_preserves_data() {
        let reshaped = Tensor::from_f32(&SAMPLE, vec![2, 3])
            .unwrap()
            .reshape(vec![3, 2])
            .unwrap();
        assert_eq!(reshaped.shape().dims(), &[3, 2]);
        assert_eq!(reshaped.as_f32_slice(), &SAMPLE[..]);
    }

    #[test]
    fn reshape_wrong_numel() {
        let tensor = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        let attempt = tensor.reshape(vec![5]);
        assert!(attempt.is_err());
    }

    #[test]
    fn transpose_2d() {
        let transposed = Tensor::zeros(vec![3, 4], DType::F32).unwrap().t().unwrap();
        assert_eq!(transposed.shape().dims(), &[4, 3]);
    }

    #[test]
    fn byte_size() {
        // 4 × 4 elements × 4 bytes per f32.
        let tensor = Tensor::zeros(vec![4, 4], DType::F32).unwrap();
        assert_eq!(tensor.byte_size(), 64);
    }
}