// sapient_core/tensor.rs

//! `Tensor` — the central multi-dimensional array type in SAPIENT.
//!
//! A `Tensor` owns its shape and dtype metadata, and holds a reference-counted
//! `BufferHandle` for the raw bytes. Newly constructed tensors use row-major
//! (C order) strides; views produced by `t()` and `slice_axis()` may not.
5
6use serde::{Deserialize, Serialize};
7use serde::{Deserializer, Serializer};
8use std::sync::Arc;
9
10use crate::buffer::{BufferHandle, CpuBuffer};
11use crate::dtype::DType;
12use crate::error::{Result, SapientError};
13use crate::shape::Shape;
14
// ── Tensor ────────────────────────────────────────────────────────────────────
16
17/// A multi-dimensional tensor with reference-counted buffer ownership.
18#[derive(Debug, Clone)]
19pub struct Tensor {
20    shape: Shape,
21    dtype: DType,
22    strides: Vec<usize>, // row-major by default
23    buffer: BufferHandle,
24    // Byte offset into the buffer where element [0,0,...,0] lives.
25    offset: usize,
26}
27
28impl Tensor {
29    // ── Constructors ─────────────────────────────────────────────────────────
30
31    /// Create a zero-filled tensor on the CPU.
32    pub fn zeros(shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
33        let shape = shape.into();
34        shape.validate()?;
35        let numel = shape.numel();
36        let strides = shape.strides();
37        let buffer = BufferHandle::new(CpuBuffer::zeros(numel, dtype)?);
38        Ok(Self {
39            shape,
40            dtype,
41            strides,
42            buffer,
43            offset: 0,
44        })
45    }
46
47    /// Create a tensor from a flat `f32` slice (CPU, row-major).
48    pub fn from_f32(data: &[f32], shape: impl Into<Shape>) -> Result<Self> {
49        let shape = shape.into();
50        shape.validate()?;
51        if data.len() != shape.numel() {
52            return Err(SapientError::ShapeMismatch {
53                expected: shape.dims().to_vec(),
54                got: vec![data.len()],
55            });
56        }
57        let strides = shape.strides();
58        let buffer = BufferHandle::new(CpuBuffer::from_f32_slice(data)?);
59        Ok(Self {
60            shape,
61            dtype: DType::F32,
62            strides,
63            buffer,
64            offset: 0,
65        })
66    }
67
68    /// Create a tensor from raw BF16 bytes, storing them natively without conversion.
69    /// Use `to_f32_vec()` or `to_f32_tensor()` to convert for computation.
70    pub fn from_bf16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
71        let shape = shape.into();
72        shape.validate()?;
73        let expected_bytes = shape.numel() * 2;
74        if data.len() != expected_bytes {
75            return Err(SapientError::ShapeMismatch {
76                expected: shape.dims().to_vec(),
77                got: vec![data.len() / 2],
78            });
79        }
80        let strides = shape.strides();
81        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
82        Ok(Self {
83            shape,
84            dtype: DType::BF16,
85            strides,
86            buffer,
87            offset: 0,
88        })
89    }
90
91    /// Create a tensor from raw F16 bytes, storing them natively without conversion.
92    pub fn from_f16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
93        let shape = shape.into();
94        shape.validate()?;
95        let expected_bytes = shape.numel() * 2;
96        if data.len() != expected_bytes {
97            return Err(SapientError::ShapeMismatch {
98                expected: shape.dims().to_vec(),
99                got: vec![data.len() / 2],
100            });
101        }
102        let strides = shape.strides();
103        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
104        Ok(Self {
105            shape,
106            dtype: DType::F16,
107            strides,
108            buffer,
109            offset: 0,
110        })
111    }
112
113    /// Create a quantized tensor from raw block bytes (Q4_0 / Q8_0).
114    ///
115    /// `data` must contain exactly `dtype.byte_count(shape.numel())` bytes, i.e.
116    /// the packed ggml block bytes with no expansion.  The shape describes the
117    /// *logical* element count; `shape.numel()` must be a multiple of 32.
118    pub fn from_quant_bytes(data: &[u8], shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
119        if !dtype.is_quantized() {
120            return Err(SapientError::TypeMismatch {
121                expected: "a quantized dtype (Q4_0, Q8_0, Q4_K, Q5_K, Q6_K)".into(),
122                got: dtype.to_string(),
123            });
124        }
125        let shape = shape.into();
126        shape.validate()?;
127        let numel = shape.numel();
128        let expected_bytes = dtype.byte_count(numel);
129        if data.len() != expected_bytes {
130            return Err(SapientError::ShapeMismatch {
131                expected: vec![expected_bytes],
132                got: vec![data.len()],
133            });
134        }
135        let strides = shape.strides();
136        let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
137        Ok(Self {
138            shape,
139            dtype,
140            strides,
141            buffer,
142            offset: 0,
143        })
144    }
145
146    /// Create a scalar tensor from a single `f32`.
147    pub fn scalar_f32(v: f32) -> Result<Self> {
148        Self::from_f32(&[v], Shape::scalar())
149    }
150
151    /// Create from a pre-built `BufferHandle` (used by backends).
152    pub fn from_buffer(
153        shape: impl Into<Shape>,
154        dtype: DType,
155        buffer: BufferHandle,
156        offset: usize,
157    ) -> Result<Self> {
158        let shape = shape.into();
159        shape.validate()?;
160        let required = dtype.byte_count(shape.numel());
161        if buffer.len() < offset + required {
162            return Err(SapientError::BufferSizeMismatch {
163                expected: offset + required,
164                got: buffer.len(),
165            });
166        }
167        let strides = shape.strides();
168        Ok(Self {
169            shape,
170            dtype,
171            strides,
172            buffer,
173            offset,
174        })
175    }
176
177    // ── Accessors ────────────────────────────────────────────────────────────
178
179    pub fn shape(&self) -> &Shape {
180        &self.shape
181    }
182    pub fn dtype(&self) -> DType {
183        self.dtype
184    }
185    pub fn ndim(&self) -> usize {
186        self.shape.ndim()
187    }
188    pub fn numel(&self) -> usize {
189        self.shape.numel()
190    }
191    pub fn strides(&self) -> &[usize] {
192        &self.strides
193    }
194    pub fn buffer(&self) -> &BufferHandle {
195        &self.buffer
196    }
197    pub fn offset(&self) -> usize {
198        self.offset
199    }
200
201    /// True if the tensor has a single element.
202    pub fn is_scalar(&self) -> bool {
203        self.shape.is_scalar() || self.numel() == 1
204    }
205
206    /// True if the buffer is row-major contiguous (normal case).
207    pub fn is_contiguous(&self) -> bool {
208        self.strides == self.shape.strides() && self.offset == 0
209    }
210
211    // ── Typed data access (CPU only) ─────────────────────────────────────────
212
213    /// Raw byte view. For non-quantized tensors returns the full buffer slice from
214    /// `offset` onwards (preserving the original behavior that stride-based kernels
215    /// rely on). For quantized tensors (Q4_0/Q8_0) returns exactly the packed block
216    /// bytes for this tensor's logical shape.
217    pub fn as_bytes(&self) -> &[u8] {
218        let bytes = self.buffer.as_bytes();
219        if self.dtype.is_quantized() {
220            let end = self.offset + self.dtype.byte_count(self.numel());
221            &bytes[self.offset..end]
222        } else {
223            &bytes[self.offset..]
224        }
225    }
226
227    /// For quantized tensors (Q4_0, Q8_0): returns the packed block bytes as a
228    /// row-major slice where each logical row of `k` elements occupies
229    /// `dtype.byte_count(k)` bytes.  Panics if the tensor is not quantized.
230    pub fn as_quant_blocks(&self) -> &[u8] {
231        assert!(
232            self.dtype.is_quantized(),
233            "as_quant_blocks() called on non-quantized tensor (dtype = {})",
234            self.dtype
235        );
236        self.as_bytes()
237    }
238
239    /// Typed `f32` view — panics if dtype is not F32.
240    pub fn as_f32_slice(&self) -> &[f32] {
241        assert_eq!(
242            self.dtype,
243            DType::F32,
244            "Tensor dtype is not F32 — call to_f32_vec() instead"
245        );
246        let bytes = self.as_bytes();
247        assert_eq!(bytes.len() % 4, 0);
248        // SAFETY: alignment ensured by CpuBuffer, dtype checked above.
249        unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) }
250    }
251
252    /// Convert this tensor to a contiguous `Vec<f32>` that matches `shape().numel()`
253    /// exactly, even when the tensor is a **non-contiguous view** (e.g. a KV-cache
254    /// slice from `slice_axis`).
255    ///
256    /// For contiguous tensors this is a fast bounded copy.  For non-contiguous
257    /// tensors (strides don't match the natural row-major strides, or offset ≠ 0)
258    /// it uses stride-based indexing to extract only the logically-reachable elements
259    /// in row-major order — the approach that `as_f32_slice` / `to_f32_cow` cannot
260    /// do because they return the full backing buffer.
261    pub fn to_contiguous_f32_vec(&self) -> Vec<f32> {
262        let numel = self.numel();
263        if self.is_contiguous() {
264            // Fast path: elements are dense starting at offset.
265            // Limit to numel to avoid reading past the logical tensor.
266            match self.dtype {
267                DType::F32 => self.as_f32_slice()[..numel].to_vec(),
268                _ => {
269                    let v = self.to_f32_vec();
270                    v[..numel.min(v.len())].to_vec()
271                }
272            }
273        } else {
274            // Slow path: stride-based copy.
275            // `raw` gives us the full backing buffer from `self.offset` as f32 units.
276            let raw: Vec<f32> = match self.dtype {
277                DType::F32 => self.as_f32_slice().to_vec(),
278                _ => self.to_f32_vec(),
279            };
280            let dims = self.shape.dims();
281            let strides = &self.strides; // element strides (not byte strides)
282            let mut out = vec![0.0f32; numel];
283            for (flat, dst) in out.iter_mut().enumerate() {
284                // Convert flat (row-major) index to per-dimension indices, then
285                // compute the element offset using the tensor's actual strides.
286                let mut rem = flat;
287                let mut src = 0usize;
288                for d in (0..dims.len()).rev() {
289                    let idx_d = rem % dims[d];
290                    rem /= dims[d];
291                    src += idx_d * strides[d];
292                }
293                *dst = *raw.get(src).unwrap_or(&0.0);
294            }
295            out
296        }
297    }
298
299    /// Returns a `Cow<[f32]>`. Borrows if the tensor is already F32, otherwise allocates a new `Vec<f32>`.
300    pub fn to_f32_cow(&self) -> std::borrow::Cow<'_, [f32]> {
301        if self.dtype == DType::F32 {
302            std::borrow::Cow::Borrowed(self.as_f32_slice())
303        } else {
304            std::borrow::Cow::Owned(self.to_f32_vec())
305        }
306    }
307
308    /// Convert this tensor to a `Vec<f32>`, handling all dtypes including quantized.
309    /// For F32: cheap copy. For F16/BF16: convert. For quantized: dequantize all blocks.
310    pub fn to_f32_vec(&self) -> Vec<f32> {
311        use crate::dtype::{
312            K_QUANT_BLOCK_SIZE, Q4_0_BLOCK_BYTES, Q4_K_BLOCK_BYTES, Q5_K_BLOCK_BYTES,
313            Q6_K_BLOCK_BYTES, Q8_0_BLOCK_BYTES, QUANT_BLOCK_SIZE,
314        };
315        match self.dtype {
316            DType::F32 => self.as_f32_slice().to_vec(),
317            DType::BF16 => {
318                let bytes = self.as_bytes();
319                bytes
320                    .chunks_exact(2)
321                    .map(|c| f32::from(half::bf16::from_le_bytes(c.try_into().unwrap())))
322                    .collect()
323            }
324            DType::F16 => {
325                let bytes = self.as_bytes();
326                bytes
327                    .chunks_exact(2)
328                    .map(|c| half::f16::from_le_bytes(c.try_into().unwrap()).to_f32())
329                    .collect()
330            }
331            DType::Q4_0 => {
332                let numel = self.numel();
333                let bytes = self.as_bytes();
334                let mut out = vec![0.0f32; numel];
335                for (b, block) in bytes.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
336                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
337                    for j in 0..QUANT_BLOCK_SIZE / 2 {
338                        let byte = block[2 + j];
339                        let lo = (byte & 0x0f) as i32 - 8;
340                        let hi = (byte >> 4) as i32 - 8;
341                        out[b * QUANT_BLOCK_SIZE + j] = lo as f32 * d;
342                        out[b * QUANT_BLOCK_SIZE + j + QUANT_BLOCK_SIZE / 2] = hi as f32 * d;
343                    }
344                }
345                out
346            }
347            DType::Q8_0 => {
348                let numel = self.numel();
349                let bytes = self.as_bytes();
350                let mut out = vec![0.0f32; numel];
351                for (b, block) in bytes.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
352                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
353                    for j in 0..QUANT_BLOCK_SIZE {
354                        out[b * QUANT_BLOCK_SIZE + j] = block[2 + j] as i8 as f32 * d;
355                    }
356                }
357                out
358            }
359            DType::Q4_K => {
360                let numel = self.numel();
361                let bytes = self.as_bytes();
362                let mut out = vec![0.0f32; numel];
363                let mut out_idx = 0usize;
364                for block in bytes.chunks_exact(Q4_K_BLOCK_BYTES) {
365                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
366                    let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
367                    let scales = &block[4..16];
368                    let qs = &block[16..Q4_K_BLOCK_BYTES];
369                    let mut q_off = 0usize;
370                    let mut is = 0usize;
371                    for _ in 0..(K_QUANT_BLOCK_SIZE / 64) {
372                        let (sc1, m1) = Self::get_scale_min_k4(is, scales);
373                        let d1 = d * sc1 as f32;
374                        let m1v = dmin * m1 as f32;
375                        let (sc2, m2) = Self::get_scale_min_k4(is + 1, scales);
376                        let d2 = d * sc2 as f32;
377                        let m2v = dmin * m2 as f32;
378                        for l in 0..32 {
379                            out[out_idx + l] = d1 * (qs[q_off + l] & 0x0F) as f32 - m1v;
380                            out[out_idx + l + 32] = d2 * (qs[q_off + l] >> 4) as f32 - m2v;
381                        }
382                        out_idx += 64;
383                        q_off += 32;
384                        is += 2;
385                    }
386                }
387                out
388            }
389            DType::Q5_K => {
390                let numel = self.numel();
391                let bytes = self.as_bytes();
392                let mut out = vec![0.0f32; numel];
393                let mut out_idx = 0usize;
394                for block in bytes.chunks_exact(Q5_K_BLOCK_BYTES) {
395                    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
396                    let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
397                    let scales = &block[4..16];
398                    let qh = &block[16..48];
399                    let ql = &block[48..Q5_K_BLOCK_BYTES];
400                    let mut ql_off = 0usize;
401                    let mut is = 0usize;
402                    let mut u1: u8 = 1;
403                    let mut u2: u8 = 2;
404                    for _ in 0..(K_QUANT_BLOCK_SIZE / 64) {
405                        let (sc1, m1) = Self::get_scale_min_k4(is, scales);
406                        let d1 = d * sc1 as f32;
407                        let m1v = dmin * m1 as f32;
408                        let (sc2, m2) = Self::get_scale_min_k4(is + 1, scales);
409                        let d2 = d * sc2 as f32;
410                        let m2v = dmin * m2 as f32;
411                        let qh_byte = qh[is / 8];
412                        for l in 0..32usize {
413                            let hi = if qh_byte & u1 != 0 { 16.0f32 } else { 0.0 };
414                            out[out_idx + l] = d1 * ((ql[ql_off + l] & 0x0F) as f32 + hi) - m1v;
415                            let hi2 = if qh_byte & u2 != 0 { 16.0f32 } else { 0.0 };
416                            out[out_idx + l + 32] = d2 * ((ql[ql_off + l] >> 4) as f32 + hi2) - m2v;
417                        }
418                        out_idx += 64;
419                        ql_off += 32;
420                        is += 2;
421                        if is % 8 == 0 {
422                            u1 = 1;
423                            u2 = 2;
424                        } else {
425                            u1 <<= 2;
426                            u2 <<= 2;
427                        }
428                    }
429                }
430                out
431            }
432            DType::Q6_K => {
433                let numel = self.numel();
434                let bytes = self.as_bytes();
435                let mut out = vec![0.0f32; numel];
436                let mut out_idx = 0usize;
437                for block in bytes.chunks_exact(Q6_K_BLOCK_BYTES) {
438                    let ql = &block[0..128];
439                    let qh = &block[128..192];
440                    let sc = &block[192..208];
441                    let d = half::f16::from_le_bytes([block[208], block[209]]).to_f32();
442                    let mut ql_off = 0usize;
443                    let mut qh_off = 0usize;
444                    let mut ib = 0usize;
445                    for _ in 0..(K_QUANT_BLOCK_SIZE / 128) {
446                        for l in 0..32usize {
447                            let q1 = (((ql[ql_off + l] & 0x0F) | ((qh[qh_off + l] & 3) << 4))
448                                as i32
449                                - 32) as f32;
450                            let q2 = (((ql[ql_off + l + 32] & 0x0F)
451                                | (((qh[qh_off + l] >> 2) & 3) << 4))
452                                as i32
453                                - 32) as f32;
454                            let q3 = (((ql[ql_off + l] >> 4) | (((qh[qh_off + l] >> 4) & 3) << 4))
455                                as i32
456                                - 32) as f32;
457                            let q4 = (((ql[ql_off + l + 32] >> 4)
458                                | (((qh[qh_off + l] >> 6) & 3) << 4))
459                                as i32
460                                - 32) as f32;
461                            out[out_idx + l] = d * sc[ib] as i8 as f32 * q1;
462                            out[out_idx + l + 32] = d * sc[ib + 1] as i8 as f32 * q2;
463                            out[out_idx + l + 64] = d * sc[ib + 2] as i8 as f32 * q3;
464                            out[out_idx + l + 96] = d * sc[ib + 3] as i8 as f32 * q4;
465                        }
466                        out_idx += 128;
467                        ql_off += 64;
468                        qh_off += 32;
469                        ib += 4;
470                    }
471                }
472                out
473            }
474            _ => self.as_f32_slice().to_vec(), // fallback for integer dtypes
475        }
476    }
477
478    /// Extract scale and min for a K-quant sub-block (used in Q4_K/Q5_K dequantization).
479    #[inline]
480    fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
481        if j < 4 {
482            (scales[j] & 63, scales[j + 4] & 63)
483        } else {
484            (
485                (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4),
486                (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4),
487            )
488        }
489    }
490
491    /// Returns an F32 tensor, converting BF16/F16 if necessary.
492    /// For already-F32 tensors, clones the buffer. For native types, converts.
493    pub fn to_f32_tensor(&self) -> Result<Tensor> {
494        match self.dtype {
495            DType::F32 => Ok(self.clone()),
496            _ => Tensor::from_f32(&self.to_f32_vec(), self.shape.clone()),
497        }
498    }
499
500    /// Mutable typed `f32` view — fails if buffer is shared or not F32.
501    pub fn as_f32_slice_mut(&mut self) -> Result<&mut [f32]> {
502        if self.dtype != DType::F32 {
503            return Err(SapientError::internal("Tensor dtype is not F32"));
504        }
505        let offset = self.offset;
506        let buf = Arc::get_mut(&mut self.buffer.0)
507            .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
508        let bytes = buf.as_bytes_mut();
509        let bytes = &mut bytes[offset..];
510        if bytes.len() % 4 != 0 {
511            return Err(SapientError::internal("Buffer length not a multiple of 4"));
512        }
513        // SAFETY: alignment ensured by CpuBuffer, dtype checked above.
514        Ok(unsafe {
515            std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut f32, bytes.len() / 4)
516        })
517    }
518
519    // ── Shape manipulation ───────────────────────────────────────────────────
520
521    /// Returns a new tensor with a different shape but the same buffer.
522    /// The total number of elements must be unchanged.
523    pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor> {
524        let new_shape = self.shape.reshape(new_shape.into().dims().to_vec())?;
525        let strides = new_shape.strides();
526        Ok(Tensor {
527            shape: new_shape,
528            dtype: self.dtype,
529            strides,
530            buffer: self.buffer.clone(),
531            offset: self.offset,
532        })
533    }
534
535    /// Transpose a 2-D tensor (swap axes 0 and 1).
536    pub fn t(&self) -> Result<Tensor> {
537        if self.ndim() != 2 {
538            return Err(SapientError::internal("t() requires a 2-D tensor"));
539        }
540        let mut dims = self.shape.dims().to_vec();
541        let mut strides = self.strides.clone();
542        dims.swap(0, 1);
543        strides.swap(0, 1);
544        Ok(Tensor {
545            shape: Shape(dims),
546            dtype: self.dtype,
547            strides,
548            buffer: self.buffer.clone(),
549            offset: self.offset,
550        })
551    }
552
553    /// Return a view of the tensor sliced along the given axis.
554    pub fn slice_axis(&self, axis: usize, start: usize, end: usize) -> Result<Tensor> {
555        let mut dims = self.shape.dims().to_vec();
556        if axis >= dims.len() {
557            return Err(SapientError::internal("slice axis out of bounds"));
558        }
559        if start > end || end > dims[axis] {
560            return Err(SapientError::internal("slice range out of bounds"));
561        }
562        dims[axis] = end - start;
563        let offset = self.offset + start * self.strides[axis] * self.dtype.element_size();
564        Ok(Tensor {
565            shape: Shape(dims),
566            dtype: self.dtype,
567            strides: self.strides.clone(),
568            buffer: self.buffer.clone(),
569            offset,
570        })
571    }
572
573    // ── Metadata convenience ─────────────────────────────────────────────────
574
575    /// Byte count for all elements.
576    pub fn byte_size(&self) -> usize {
577        self.dtype.byte_count(self.numel())
578    }
579}
580
581impl std::fmt::Display for Tensor {
582    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583        write!(
584            f,
585            "Tensor(shape={}, dtype={}, device={})",
586            self.shape,
587            self.dtype,
588            self.buffer.0.device()
589        )
590    }
591}
592
// ── Serde support for Tensor ─────────────────────────────────────────────────
594
595/// Serialisable proxy — stores raw f32 data alongside shape/dtype.
596#[derive(Serialize, Deserialize)]
597struct TensorProxy {
598    shape: Shape,
599    dtype: DType,
600    /// Raw bytes as base64-encoded (for JSON), or raw for binary.
601    data: Vec<f32>,
602}
603
604impl Serialize for Tensor {
605    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
606        let data: Vec<f32> = if self.dtype == DType::F32 {
607            self.as_f32_slice().to_vec()
608        } else {
609            vec![] // non-f32 tensors: zero data (future work)
610        };
611        TensorProxy {
612            shape: self.shape.clone(),
613            dtype: self.dtype,
614            data,
615        }
616        .serialize(serializer)
617    }
618}
619
620impl<'de> Deserialize<'de> for Tensor {
621    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
622        let proxy = TensorProxy::deserialize(deserializer)?;
623        if proxy.data.is_empty() {
624            Tensor::zeros(proxy.shape, proxy.dtype).map_err(serde::de::Error::custom)
625        } else {
626            Tensor::from_f32(&proxy.data, proxy.shape).map_err(serde::de::Error::custom)
627        }
628    }
629}
630
631/// A serializable descriptor for a tensor — shape and dtype only (no data).
632#[derive(Debug, Clone, Serialize, Deserialize)]
633pub struct TensorMeta {
634    pub shape: Shape,
635    pub dtype: DType,
636}
637
638impl From<&Tensor> for TensorMeta {
639    fn from(t: &Tensor) -> Self {
640        Self {
641            shape: t.shape.clone(),
642            dtype: t.dtype,
643        }
644    }
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros_dtype_shape() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert_eq!(t.shape().dims(), &[2, 3]);
        assert_eq!(t.dtype(), DType::F32);
        assert_eq!(t.numel(), 6);
    }

    #[test]
    fn from_f32_roundtrip() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        assert_eq!(t.as_f32_slice(), data.as_slice());
    }

    #[test]
    fn from_f32_rejects_wrong_length() {
        assert!(Tensor::from_f32(&[1.0, 2.0, 3.0], vec![2, 2]).is_err());
    }

    #[test]
    fn reshape_preserves_data() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        let r = t.reshape(vec![3, 2]).unwrap();
        assert_eq!(r.shape().dims(), &[3, 2]);
        assert_eq!(r.as_f32_slice(), data.as_slice());
    }

    #[test]
    fn reshape_wrong_numel() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert!(t.reshape(vec![5]).is_err());
    }

    #[test]
    fn reshape_of_leading_axis_slice() {
        // Rows 1..3 of a [3, 2] tensor are still densely packed.
        let data: Vec<f32> = (1..=6).map(|v| v as f32).collect();
        let t = Tensor::from_f32(&data, vec![3, 2]).unwrap();
        let rows = t.slice_axis(0, 1, 3).unwrap().reshape(vec![4]).unwrap();
        assert_eq!(rows.to_contiguous_f32_vec(), vec![3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn transpose_2d() {
        let t = Tensor::zeros(vec![3, 4], DType::F32).unwrap();
        let t2 = t.t().unwrap();
        assert_eq!(t2.shape().dims(), &[4, 3]);
    }

    #[test]
    fn transpose_contiguous_copy() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        assert_eq!(
            t.t().unwrap().to_contiguous_f32_vec(),
            vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
        );
    }

    #[test]
    fn slice_inner_axis_contiguous_copy() {
        let data: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        let s = t.slice_axis(1, 1, 3).unwrap();
        assert_eq!(s.shape().dims(), &[2, 2]);
        assert_eq!(s.to_contiguous_f32_vec(), vec![1.0, 2.0, 4.0, 5.0]);
    }

    #[test]
    fn slice_axis_rejects_bad_ranges() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert!(t.slice_axis(2, 0, 1).is_err());
        assert!(t.slice_axis(1, 2, 1).is_err());
        assert!(t.slice_axis(1, 0, 4).is_err());
    }

    #[test]
    fn half_precision_decoding() {
        // bf16 1.0 = 0x3F80, bf16 -2.5 = 0xC020 (little-endian byte order).
        let bf = Tensor::from_bf16_bytes(&[0x80, 0x3F, 0x20, 0xC0], vec![2]).unwrap();
        assert_eq!(bf.dtype(), DType::BF16);
        assert_eq!(bf.to_f32_vec(), vec![1.0, -2.5]);
        // f16 0.5 = 0x3800.
        let h = Tensor::from_f16_bytes(&[0x00, 0x38], vec![1]).unwrap();
        assert_eq!(h.to_f32_vec(), vec![0.5]);
        assert!(Tensor::from_bf16_bytes(&[0x80], vec![1]).is_err());
    }

    #[test]
    fn q8_0_block_decoding() {
        let mut bytes = vec![0x00, 0x3C]; // f16 scale 1.0
        bytes.extend((0..32).map(|j| (j as i8 - 16) as u8));
        let t = Tensor::from_quant_bytes(&bytes, vec![32], DType::Q8_0).unwrap();
        let expected: Vec<f32> = (0..32).map(|j| (j - 16) as f32).collect();
        assert_eq!(t.to_f32_vec(), expected);
    }

    #[test]
    fn q4_0_block_decoding() {
        let mut bytes = vec![0x00, 0x3C]; // f16 scale 1.0
        bytes.extend((0..16u8).map(|j| j | ((15 - j) << 4)));
        let t = Tensor::from_quant_bytes(&bytes, vec![32], DType::Q4_0).unwrap();
        let expected: Vec<f32> = (0..16)
            .map(|j| (j - 8) as f32)
            .chain((0..16).map(|j| (7 - j) as f32))
            .collect();
        assert_eq!(t.to_f32_vec(), expected);
    }

    #[test]
    fn from_quant_bytes_rejects_float_dtype() {
        assert!(Tensor::from_quant_bytes(&[0u8; 4], vec![1], DType::F32).is_err());
    }

    #[test]
    fn byte_size() {
        let t = Tensor::zeros(vec![4, 4], DType::F32).unwrap();
        assert_eq!(t.byte_size(), 64);
    }
}