1use serde::{Deserialize, Serialize};
7use serde::{Deserializer, Serializer};
8use std::sync::Arc;
9
10use crate::buffer::{BufferHandle, CpuBuffer};
11use crate::dtype::DType;
12use crate::error::{Result, SapientError};
13use crate::shape::Shape;
14
15#[derive(Debug, Clone)]
19pub struct Tensor {
20 shape: Shape,
21 dtype: DType,
22 strides: Vec<usize>, buffer: BufferHandle,
24 offset: usize,
26}
27
28impl Tensor {
29 pub fn zeros(shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
33 let shape = shape.into();
34 shape.validate()?;
35 let numel = shape.numel();
36 let strides = shape.strides();
37 let buffer = BufferHandle::new(CpuBuffer::zeros(numel, dtype)?);
38 Ok(Self {
39 shape,
40 dtype,
41 strides,
42 buffer,
43 offset: 0,
44 })
45 }
46
47 pub fn from_f32(data: &[f32], shape: impl Into<Shape>) -> Result<Self> {
49 let shape = shape.into();
50 shape.validate()?;
51 if data.len() != shape.numel() {
52 return Err(SapientError::ShapeMismatch {
53 expected: shape.dims().to_vec(),
54 got: vec![data.len()],
55 });
56 }
57 let strides = shape.strides();
58 let buffer = BufferHandle::new(CpuBuffer::from_f32_slice(data)?);
59 Ok(Self {
60 shape,
61 dtype: DType::F32,
62 strides,
63 buffer,
64 offset: 0,
65 })
66 }
67
68 pub fn from_bf16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
71 let shape = shape.into();
72 shape.validate()?;
73 let expected_bytes = shape.numel() * 2;
74 if data.len() != expected_bytes {
75 return Err(SapientError::ShapeMismatch {
76 expected: shape.dims().to_vec(),
77 got: vec![data.len() / 2],
78 });
79 }
80 let strides = shape.strides();
81 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
82 Ok(Self {
83 shape,
84 dtype: DType::BF16,
85 strides,
86 buffer,
87 offset: 0,
88 })
89 }
90
91 pub fn from_f16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
93 let shape = shape.into();
94 shape.validate()?;
95 let expected_bytes = shape.numel() * 2;
96 if data.len() != expected_bytes {
97 return Err(SapientError::ShapeMismatch {
98 expected: shape.dims().to_vec(),
99 got: vec![data.len() / 2],
100 });
101 }
102 let strides = shape.strides();
103 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
104 Ok(Self {
105 shape,
106 dtype: DType::F16,
107 strides,
108 buffer,
109 offset: 0,
110 })
111 }
112
113 pub fn from_quant_bytes(data: &[u8], shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
119 if !dtype.is_quantized() {
120 return Err(SapientError::TypeMismatch {
121 expected: "a quantized dtype (Q4_0 or Q8_0)".into(),
122 got: dtype.to_string(),
123 });
124 }
125 let shape = shape.into();
126 shape.validate()?;
127 let numel = shape.numel();
128 let expected_bytes = dtype.byte_count(numel);
129 if data.len() != expected_bytes {
130 return Err(SapientError::ShapeMismatch {
131 expected: vec![expected_bytes],
132 got: vec![data.len()],
133 });
134 }
135 let strides = shape.strides();
136 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
137 Ok(Self {
138 shape,
139 dtype,
140 strides,
141 buffer,
142 offset: 0,
143 })
144 }
145
146 pub fn scalar_f32(v: f32) -> Result<Self> {
148 Self::from_f32(&[v], Shape::scalar())
149 }
150
151 pub fn from_buffer(
153 shape: impl Into<Shape>,
154 dtype: DType,
155 buffer: BufferHandle,
156 offset: usize,
157 ) -> Result<Self> {
158 let shape = shape.into();
159 shape.validate()?;
160 let required = dtype.byte_count(shape.numel());
161 if buffer.len() < offset + required {
162 return Err(SapientError::BufferSizeMismatch {
163 expected: offset + required,
164 got: buffer.len(),
165 });
166 }
167 let strides = shape.strides();
168 Ok(Self {
169 shape,
170 dtype,
171 strides,
172 buffer,
173 offset,
174 })
175 }
176
177 pub fn shape(&self) -> &Shape {
180 &self.shape
181 }
182 pub fn dtype(&self) -> DType {
183 self.dtype
184 }
185 pub fn ndim(&self) -> usize {
186 self.shape.ndim()
187 }
188 pub fn numel(&self) -> usize {
189 self.shape.numel()
190 }
191 pub fn strides(&self) -> &[usize] {
192 &self.strides
193 }
194 pub fn buffer(&self) -> &BufferHandle {
195 &self.buffer
196 }
197 pub fn offset(&self) -> usize {
198 self.offset
199 }
200
201 pub fn is_scalar(&self) -> bool {
203 self.shape.is_scalar() || self.numel() == 1
204 }
205
206 pub fn is_contiguous(&self) -> bool {
208 self.strides == self.shape.strides() && self.offset == 0
209 }
210
211 pub fn as_bytes(&self) -> &[u8] {
218 let bytes = self.buffer.as_bytes();
219 if self.dtype.is_quantized() {
220 let end = self.offset + self.dtype.byte_count(self.numel());
221 &bytes[self.offset..end]
222 } else {
223 &bytes[self.offset..]
224 }
225 }
226
227 pub fn as_quant_blocks(&self) -> &[u8] {
231 assert!(
232 self.dtype.is_quantized(),
233 "as_quant_blocks() called on non-quantized tensor (dtype = {})",
234 self.dtype
235 );
236 self.as_bytes()
237 }
238
239 pub fn as_f32_slice(&self) -> &[f32] {
241 assert_eq!(
242 self.dtype,
243 DType::F32,
244 "Tensor dtype is not F32 — call to_f32_vec() instead"
245 );
246 let bytes = self.as_bytes();
247 assert_eq!(bytes.len() % 4, 0);
248 unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) }
250 }
251
252 pub fn to_f32_cow(&self) -> std::borrow::Cow<'_, [f32]> {
254 if self.dtype == DType::F32 {
255 std::borrow::Cow::Borrowed(self.as_f32_slice())
256 } else {
257 std::borrow::Cow::Owned(self.to_f32_vec())
258 }
259 }
260
261 pub fn to_f32_vec(&self) -> Vec<f32> {
264 use crate::dtype::{Q4_0_BLOCK_BYTES, Q8_0_BLOCK_BYTES, QUANT_BLOCK_SIZE};
265 match self.dtype {
266 DType::F32 => self.as_f32_slice().to_vec(),
267 DType::BF16 => {
268 let bytes = self.as_bytes();
269 bytes
270 .chunks_exact(2)
271 .map(|c| f32::from(half::bf16::from_le_bytes(c.try_into().unwrap())))
272 .collect()
273 }
274 DType::F16 => {
275 let bytes = self.as_bytes();
276 bytes
277 .chunks_exact(2)
278 .map(|c| half::f16::from_le_bytes(c.try_into().unwrap()).to_f32())
279 .collect()
280 }
281 DType::Q4_0 => {
282 let numel = self.numel();
283 let bytes = self.as_bytes();
284 let mut out = vec![0.0f32; numel];
285 for (b, block) in bytes.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
286 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
287 for j in 0..QUANT_BLOCK_SIZE / 2 {
288 let byte = block[2 + j];
289 let lo = (byte & 0x0f) as i32 - 8;
290 let hi = (byte >> 4) as i32 - 8;
291 out[b * QUANT_BLOCK_SIZE + j] = lo as f32 * d;
292 out[b * QUANT_BLOCK_SIZE + j + QUANT_BLOCK_SIZE / 2] = hi as f32 * d;
293 }
294 }
295 out
296 }
297 DType::Q8_0 => {
298 let numel = self.numel();
299 let bytes = self.as_bytes();
300 let mut out = vec![0.0f32; numel];
301 for (b, block) in bytes.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
302 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
303 for j in 0..QUANT_BLOCK_SIZE {
304 out[b * QUANT_BLOCK_SIZE + j] = block[2 + j] as i8 as f32 * d;
305 }
306 }
307 out
308 }
309 _ => self.as_f32_slice().to_vec(), }
311 }
312
313 pub fn to_f32_tensor(&self) -> Result<Tensor> {
316 match self.dtype {
317 DType::F32 => Ok(self.clone()),
318 _ => Tensor::from_f32(&self.to_f32_vec(), self.shape.clone()),
319 }
320 }
321
322 pub fn as_f32_slice_mut(&mut self) -> Result<&mut [f32]> {
324 if self.dtype != DType::F32 {
325 return Err(SapientError::internal("Tensor dtype is not F32"));
326 }
327 let offset = self.offset;
328 let buf = Arc::get_mut(&mut self.buffer.0)
329 .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
330 let bytes = buf.as_bytes_mut();
331 let bytes = &mut bytes[offset..];
332 if bytes.len() % 4 != 0 {
333 return Err(SapientError::internal("Buffer length not a multiple of 4"));
334 }
335 Ok(unsafe {
337 std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut f32, bytes.len() / 4)
338 })
339 }
340
341 pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor> {
346 let new_shape = self.shape.reshape(new_shape.into().dims().to_vec())?;
347 let strides = new_shape.strides();
348 Ok(Tensor {
349 shape: new_shape,
350 dtype: self.dtype,
351 strides,
352 buffer: self.buffer.clone(),
353 offset: self.offset,
354 })
355 }
356
357 pub fn t(&self) -> Result<Tensor> {
359 if self.ndim() != 2 {
360 return Err(SapientError::internal("t() requires a 2-D tensor"));
361 }
362 let mut dims = self.shape.dims().to_vec();
363 let mut strides = self.strides.clone();
364 dims.swap(0, 1);
365 strides.swap(0, 1);
366 Ok(Tensor {
367 shape: Shape(dims),
368 dtype: self.dtype,
369 strides,
370 buffer: self.buffer.clone(),
371 offset: self.offset,
372 })
373 }
374
375 pub fn slice_axis(&self, axis: usize, start: usize, end: usize) -> Result<Tensor> {
377 let mut dims = self.shape.dims().to_vec();
378 if axis >= dims.len() {
379 return Err(SapientError::internal("slice axis out of bounds"));
380 }
381 if start > end || end > dims[axis] {
382 return Err(SapientError::internal("slice range out of bounds"));
383 }
384 dims[axis] = end - start;
385 let offset = self.offset + start * self.strides[axis] * self.dtype.element_size();
386 Ok(Tensor {
387 shape: Shape(dims),
388 dtype: self.dtype,
389 strides: self.strides.clone(),
390 buffer: self.buffer.clone(),
391 offset,
392 })
393 }
394
395 pub fn byte_size(&self) -> usize {
399 self.dtype.byte_count(self.numel())
400 }
401}
402
403impl std::fmt::Display for Tensor {
404 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405 write!(
406 f,
407 "Tensor(shape={}, dtype={}, device={})",
408 self.shape,
409 self.dtype,
410 self.buffer.0.device()
411 )
412 }
413}
414
415#[derive(Serialize, Deserialize)]
419struct TensorProxy {
420 shape: Shape,
421 dtype: DType,
422 data: Vec<f32>,
424}
425
426impl Serialize for Tensor {
427 fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
428 let data: Vec<f32> = if self.dtype == DType::F32 {
429 self.as_f32_slice().to_vec()
430 } else {
431 vec![] };
433 TensorProxy {
434 shape: self.shape.clone(),
435 dtype: self.dtype,
436 data,
437 }
438 .serialize(serializer)
439 }
440}
441
442impl<'de> Deserialize<'de> for Tensor {
443 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
444 let proxy = TensorProxy::deserialize(deserializer)?;
445 if proxy.data.is_empty() {
446 Tensor::zeros(proxy.shape, proxy.dtype).map_err(serde::de::Error::custom)
447 } else {
448 Tensor::from_f32(&proxy.data, proxy.shape).map_err(serde::de::Error::custom)
449 }
450 }
451}
452
453#[derive(Debug, Clone, Serialize, Deserialize)]
455pub struct TensorMeta {
456 pub shape: Shape,
457 pub dtype: DType,
458}
459
460impl From<&Tensor> for TensorMeta {
461 fn from(t: &Tensor) -> Self {
462 Self {
463 shape: t.shape.clone(),
464 dtype: t.dtype,
465 }
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Six consecutive values shared by the data-carrying tests.
    fn sample_values() -> Vec<f32> {
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
    }

    /// `zeros` reports the requested shape, dtype and element count.
    #[test]
    fn zeros_dtype_shape() {
        let tensor = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert_eq!(tensor.shape().dims(), &[2, 3]);
        assert_eq!(tensor.dtype(), DType::F32);
        assert_eq!(tensor.numel(), 6);
    }

    /// Data passed to `from_f32` is readable back through `as_f32_slice`.
    #[test]
    fn from_f32_roundtrip() {
        let values = sample_values();
        let tensor = Tensor::from_f32(&values, vec![2, 3]).unwrap();
        assert_eq!(tensor.as_f32_slice(), values.as_slice());
    }

    /// Reshaping changes the dimensions but not the underlying values.
    #[test]
    fn reshape_preserves_data() {
        let values = sample_values();
        let reshaped = Tensor::from_f32(&values, vec![2, 3])
            .unwrap()
            .reshape(vec![3, 2])
            .unwrap();
        assert_eq!(reshaped.shape().dims(), &[3, 2]);
        assert_eq!(reshaped.as_f32_slice(), values.as_slice());
    }

    /// Reshaping to a different element count is rejected.
    #[test]
    fn reshape_wrong_numel() {
        let tensor = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        let outcome = tensor.reshape(vec![5]);
        assert!(outcome.is_err());
    }

    /// `t()` swaps the two dimensions of a 2-D tensor.
    #[test]
    fn transpose_2d() {
        let transposed = Tensor::zeros(vec![3, 4], DType::F32)
            .unwrap()
            .t()
            .unwrap();
        assert_eq!(transposed.shape().dims(), &[4, 3]);
    }

    /// A 4x4 F32 tensor occupies 16 * 4 bytes.
    #[test]
    fn byte_size() {
        let tensor = Tensor::zeros(vec![4, 4], DType::F32).unwrap();
        assert_eq!(tensor.byte_size(), 64);
    }
}