1use serde::{Deserialize, Serialize};
7use serde::{Deserializer, Serializer};
8use std::sync::Arc;
9
10use crate::buffer::{BufferHandle, CpuBuffer};
11use crate::dtype::DType;
12use crate::error::{Result, SapientError};
13use crate::shape::Shape;
14
/// A typed, strided view over a shared byte buffer.
///
/// Views created by `reshape`, `t` and `slice_axis` clone the buffer handle
/// instead of copying the data.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Logical dimensions of the tensor.
    shape: Shape,
    /// Element type of the data stored in `buffer`.
    dtype: DType,
    /// Per-axis strides, counted in elements (`slice_axis` multiplies them by
    /// the element size to obtain a byte offset).
    strides: Vec<usize>,
    /// Handle to the backing storage.
    buffer: BufferHandle,
    /// Byte offset of this tensor's first element inside `buffer`.
    offset: usize,
}
27
28impl Tensor {
29 pub fn zeros(shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
33 let shape = shape.into();
34 shape.validate()?;
35 let numel = shape.numel();
36 let strides = shape.strides();
37 let buffer = BufferHandle::new(CpuBuffer::zeros(numel, dtype)?);
38 Ok(Self {
39 shape,
40 dtype,
41 strides,
42 buffer,
43 offset: 0,
44 })
45 }
46
47 pub fn from_f32_vec(data: Vec<f32>, shape: impl Into<Shape>) -> Result<Self> {
51 let shape = shape.into();
52 shape.validate()?;
53 if data.len() != shape.numel() {
54 return Err(SapientError::ShapeMismatch {
55 expected: shape.dims().to_vec(),
56 got: vec![data.len()],
57 });
58 }
59 let strides = shape.strides();
60 let buffer = BufferHandle::new(CpuBuffer::from_f32_vec(data)?);
61 Ok(Self {
62 shape,
63 dtype: DType::F32,
64 strides,
65 buffer,
66 offset: 0,
67 })
68 }
69
70 pub fn from_f32(data: &[f32], shape: impl Into<Shape>) -> Result<Self> {
71 let shape = shape.into();
72 shape.validate()?;
73 if data.len() != shape.numel() {
74 return Err(SapientError::ShapeMismatch {
75 expected: shape.dims().to_vec(),
76 got: vec![data.len()],
77 });
78 }
79 let strides = shape.strides();
80 let buffer = BufferHandle::new(CpuBuffer::from_f32_slice(data)?);
81 Ok(Self {
82 shape,
83 dtype: DType::F32,
84 strides,
85 buffer,
86 offset: 0,
87 })
88 }
89
90 pub fn from_bf16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
93 let shape = shape.into();
94 shape.validate()?;
95 let expected_bytes = shape.numel() * 2;
96 if data.len() != expected_bytes {
97 return Err(SapientError::ShapeMismatch {
98 expected: shape.dims().to_vec(),
99 got: vec![data.len() / 2],
100 });
101 }
102 let strides = shape.strides();
103 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
104 Ok(Self {
105 shape,
106 dtype: DType::BF16,
107 strides,
108 buffer,
109 offset: 0,
110 })
111 }
112
113 pub fn from_f16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
115 let shape = shape.into();
116 shape.validate()?;
117 let expected_bytes = shape.numel() * 2;
118 if data.len() != expected_bytes {
119 return Err(SapientError::ShapeMismatch {
120 expected: shape.dims().to_vec(),
121 got: vec![data.len() / 2],
122 });
123 }
124 let strides = shape.strides();
125 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
126 Ok(Self {
127 shape,
128 dtype: DType::F16,
129 strides,
130 buffer,
131 offset: 0,
132 })
133 }
134
135 pub fn from_quant_bytes(data: &[u8], shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
141 if !dtype.is_quantized() {
142 return Err(SapientError::TypeMismatch {
143 expected: "a quantized dtype (Q4_0, Q8_0, Q4_K, Q5_K, Q6_K)".into(),
144 got: dtype.to_string(),
145 });
146 }
147 let shape = shape.into();
148 shape.validate()?;
149 let numel = shape.numel();
150 let expected_bytes = dtype.byte_count(numel);
151 if data.len() != expected_bytes {
152 return Err(SapientError::ShapeMismatch {
153 expected: vec![expected_bytes],
154 got: vec![data.len()],
155 });
156 }
157 let strides = shape.strides();
158 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
159 Ok(Self {
160 shape,
161 dtype,
162 strides,
163 buffer,
164 offset: 0,
165 })
166 }
167
168 pub fn scalar_f32(v: f32) -> Result<Self> {
170 Self::from_f32(&[v], Shape::scalar())
171 }
172
173 pub fn from_buffer(
175 shape: impl Into<Shape>,
176 dtype: DType,
177 buffer: BufferHandle,
178 offset: usize,
179 ) -> Result<Self> {
180 let shape = shape.into();
181 shape.validate()?;
182 let required = dtype.byte_count(shape.numel());
183 if buffer.len() < offset + required {
184 return Err(SapientError::BufferSizeMismatch {
185 expected: offset + required,
186 got: buffer.len(),
187 });
188 }
189 let strides = shape.strides();
190 Ok(Self {
191 shape,
192 dtype,
193 strides,
194 buffer,
195 offset,
196 })
197 }
198
    /// Returns the tensor's shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }
    /// Returns the tensor's element type.
    pub fn dtype(&self) -> DType {
        self.dtype
    }
    /// Returns the number of dimensions, as reported by the shape.
    pub fn ndim(&self) -> usize {
        self.shape.ndim()
    }
    /// Returns the number of elements, as reported by the shape.
    pub fn numel(&self) -> usize {
        self.shape.numel()
    }
    /// Returns the per-axis strides, counted in elements.
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }
    /// Returns the handle to the backing buffer.
    pub fn buffer(&self) -> &BufferHandle {
        &self.buffer
    }
    /// Returns the byte offset of this tensor's data inside the buffer.
    pub fn offset(&self) -> usize {
        self.offset
    }
222
223 pub fn is_scalar(&self) -> bool {
225 self.shape.is_scalar() || self.numel() == 1
226 }
227
228 pub fn is_contiguous(&self) -> bool {
230 self.strides == self.shape.strides() && self.offset == 0
231 }
232
233 pub fn as_bytes(&self) -> &[u8] {
240 let bytes = self.buffer.as_bytes();
241 if self.dtype.is_quantized() {
242 let end = self.offset + self.dtype.byte_count(self.numel());
243 &bytes[self.offset..end]
244 } else {
245 &bytes[self.offset..]
246 }
247 }
248
249 pub fn as_quant_blocks(&self) -> &[u8] {
253 assert!(
254 self.dtype.is_quantized(),
255 "as_quant_blocks() called on non-quantized tensor (dtype = {})",
256 self.dtype
257 );
258 self.as_bytes()
259 }
260
261 pub fn as_f32_slice(&self) -> &[f32] {
263 assert_eq!(
264 self.dtype,
265 DType::F32,
266 "Tensor dtype is not F32 — call to_f32_vec() instead"
267 );
268 let bytes = self.as_bytes();
269 assert_eq!(bytes.len() % 4, 0);
270 unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) }
272 }
273
274 pub fn to_contiguous_f32_vec(&self) -> Vec<f32> {
284 let numel = self.numel();
285 if self.is_contiguous() {
286 match self.dtype {
289 DType::F32 => self.as_f32_slice()[..numel].to_vec(),
290 _ => {
291 let v = self.to_f32_vec();
292 v[..numel.min(v.len())].to_vec()
293 }
294 }
295 } else {
296 let raw: Vec<f32> = match self.dtype {
299 DType::F32 => self.as_f32_slice().to_vec(),
300 _ => self.to_f32_vec(),
301 };
302 let dims = self.shape.dims();
303 let strides = &self.strides; let mut out = vec![0.0f32; numel];
305 for (flat, dst) in out.iter_mut().enumerate() {
306 let mut rem = flat;
309 let mut src = 0usize;
310 for d in (0..dims.len()).rev() {
311 let idx_d = rem % dims[d];
312 rem /= dims[d];
313 src += idx_d * strides[d];
314 }
315 *dst = *raw.get(src).unwrap_or(&0.0);
316 }
317 out
318 }
319 }
320
321 pub fn to_f32_cow(&self) -> std::borrow::Cow<'_, [f32]> {
323 if self.dtype == DType::F32 {
324 std::borrow::Cow::Borrowed(self.as_f32_slice())
325 } else {
326 std::borrow::Cow::Owned(self.to_f32_vec())
327 }
328 }
329
330 pub fn to_f32_vec(&self) -> Vec<f32> {
333 use crate::dtype::{
334 K_QUANT_BLOCK_SIZE, Q4_0_BLOCK_BYTES, Q4_K_BLOCK_BYTES, Q5_K_BLOCK_BYTES,
335 Q6_K_BLOCK_BYTES, Q8_0_BLOCK_BYTES, QUANT_BLOCK_SIZE,
336 };
337 match self.dtype {
338 DType::F32 => self.as_f32_slice().to_vec(),
339 DType::BF16 => {
340 let bytes = self.as_bytes();
341 bytes
342 .chunks_exact(2)
343 .map(|c| f32::from(half::bf16::from_le_bytes(c.try_into().unwrap())))
344 .collect()
345 }
346 DType::F16 => {
347 let bytes = self.as_bytes();
348 bytes
349 .chunks_exact(2)
350 .map(|c| half::f16::from_le_bytes(c.try_into().unwrap()).to_f32())
351 .collect()
352 }
353 DType::Q4_0 => {
354 let numel = self.numel();
355 let bytes = self.as_bytes();
356 let mut out = vec![0.0f32; numel];
357 for (b, block) in bytes.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
358 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
359 for j in 0..QUANT_BLOCK_SIZE / 2 {
360 let byte = block[2 + j];
361 let lo = (byte & 0x0f) as i32 - 8;
362 let hi = (byte >> 4) as i32 - 8;
363 out[b * QUANT_BLOCK_SIZE + j] = lo as f32 * d;
364 out[b * QUANT_BLOCK_SIZE + j + QUANT_BLOCK_SIZE / 2] = hi as f32 * d;
365 }
366 }
367 out
368 }
369 DType::Q8_0 => {
370 let numel = self.numel();
371 let bytes = self.as_bytes();
372 let mut out = vec![0.0f32; numel];
373 for (b, block) in bytes.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
374 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
375 for j in 0..QUANT_BLOCK_SIZE {
376 out[b * QUANT_BLOCK_SIZE + j] = block[2 + j] as i8 as f32 * d;
377 }
378 }
379 out
380 }
381 DType::Q4_K => {
382 let numel = self.numel();
383 let bytes = self.as_bytes();
384 let mut out = vec![0.0f32; numel];
385 let mut out_idx = 0usize;
386 for block in bytes.chunks_exact(Q4_K_BLOCK_BYTES) {
387 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
388 let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
389 let scales = &block[4..16];
390 let qs = &block[16..Q4_K_BLOCK_BYTES];
391 let mut q_off = 0usize;
392 let mut is = 0usize;
393 for _ in 0..(K_QUANT_BLOCK_SIZE / 64) {
394 let (sc1, m1) = Self::get_scale_min_k4(is, scales);
395 let d1 = d * sc1 as f32;
396 let m1v = dmin * m1 as f32;
397 let (sc2, m2) = Self::get_scale_min_k4(is + 1, scales);
398 let d2 = d * sc2 as f32;
399 let m2v = dmin * m2 as f32;
400 for l in 0..32 {
401 out[out_idx + l] = d1 * (qs[q_off + l] & 0x0F) as f32 - m1v;
402 out[out_idx + l + 32] = d2 * (qs[q_off + l] >> 4) as f32 - m2v;
403 }
404 out_idx += 64;
405 q_off += 32;
406 is += 2;
407 }
408 }
409 out
410 }
411 DType::Q5_K => {
412 let numel = self.numel();
413 let bytes = self.as_bytes();
414 let mut out = vec![0.0f32; numel];
415 let mut out_idx = 0usize;
416 for block in bytes.chunks_exact(Q5_K_BLOCK_BYTES) {
417 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
418 let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
419 let scales = &block[4..16];
420 let qh = &block[16..48];
421 let ql = &block[48..Q5_K_BLOCK_BYTES];
422 let mut ql_off = 0usize;
423 let mut is = 0usize;
424 let mut u1: u8 = 1;
425 let mut u2: u8 = 2;
426 for _ in 0..(K_QUANT_BLOCK_SIZE / 64) {
427 let (sc1, m1) = Self::get_scale_min_k4(is, scales);
428 let d1 = d * sc1 as f32;
429 let m1v = dmin * m1 as f32;
430 let (sc2, m2) = Self::get_scale_min_k4(is + 1, scales);
431 let d2 = d * sc2 as f32;
432 let m2v = dmin * m2 as f32;
433 let qh_byte = qh[is / 8];
434 for l in 0..32usize {
435 let hi = if qh_byte & u1 != 0 { 16.0f32 } else { 0.0 };
436 out[out_idx + l] = d1 * ((ql[ql_off + l] & 0x0F) as f32 + hi) - m1v;
437 let hi2 = if qh_byte & u2 != 0 { 16.0f32 } else { 0.0 };
438 out[out_idx + l + 32] = d2 * ((ql[ql_off + l] >> 4) as f32 + hi2) - m2v;
439 }
440 out_idx += 64;
441 ql_off += 32;
442 is += 2;
443 if is % 8 == 0 {
444 u1 = 1;
445 u2 = 2;
446 } else {
447 u1 <<= 2;
448 u2 <<= 2;
449 }
450 }
451 }
452 out
453 }
454 DType::Q6_K => {
455 let numel = self.numel();
456 let bytes = self.as_bytes();
457 let mut out = vec![0.0f32; numel];
458 let mut out_idx = 0usize;
459 for block in bytes.chunks_exact(Q6_K_BLOCK_BYTES) {
460 let ql = &block[0..128];
461 let qh = &block[128..192];
462 let sc = &block[192..208];
463 let d = half::f16::from_le_bytes([block[208], block[209]]).to_f32();
464 let mut ql_off = 0usize;
465 let mut qh_off = 0usize;
466 let mut ib = 0usize;
467 for _ in 0..(K_QUANT_BLOCK_SIZE / 128) {
468 for l in 0..32usize {
469 let q1 = (((ql[ql_off + l] & 0x0F) | ((qh[qh_off + l] & 3) << 4))
470 as i32
471 - 32) as f32;
472 let q2 = (((ql[ql_off + l + 32] & 0x0F)
473 | (((qh[qh_off + l] >> 2) & 3) << 4))
474 as i32
475 - 32) as f32;
476 let q3 = (((ql[ql_off + l] >> 4) | (((qh[qh_off + l] >> 4) & 3) << 4))
477 as i32
478 - 32) as f32;
479 let q4 = (((ql[ql_off + l + 32] >> 4)
480 | (((qh[qh_off + l] >> 6) & 3) << 4))
481 as i32
482 - 32) as f32;
483 out[out_idx + l] = d * sc[ib] as i8 as f32 * q1;
484 out[out_idx + l + 32] = d * sc[ib + 1] as i8 as f32 * q2;
485 out[out_idx + l + 64] = d * sc[ib + 2] as i8 as f32 * q3;
486 out[out_idx + l + 96] = d * sc[ib + 3] as i8 as f32 * q4;
487 }
488 out_idx += 128;
489 ql_off += 64;
490 qh_off += 32;
491 ib += 4;
492 }
493 }
494 out
495 }
496 _ => self.as_f32_slice().to_vec(), }
498 }
499
500 #[inline]
502 fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
503 if j < 4 {
504 (scales[j] & 63, scales[j + 4] & 63)
505 } else {
506 (
507 (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4),
508 (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4),
509 )
510 }
511 }
512
513 pub fn to_f32_tensor(&self) -> Result<Tensor> {
516 match self.dtype {
517 DType::F32 => Ok(self.clone()),
518 _ => Tensor::from_f32(&self.to_f32_vec(), self.shape.clone()),
519 }
520 }
521
522 pub fn as_bytes_mut(&mut self) -> Result<&mut [u8]> {
526 let offset = self.offset;
527 let end = offset + self.dtype.byte_count(self.numel());
528 let buf = Arc::get_mut(&mut self.buffer.0)
529 .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
530 let bytes = buf.as_bytes_mut();
531 Ok(&mut bytes[offset..end])
532 }
533
534 pub fn as_f32_slice_mut(&mut self) -> Result<&mut [f32]> {
535 if self.dtype != DType::F32 {
536 return Err(SapientError::internal("Tensor dtype is not F32"));
537 }
538 let offset = self.offset;
539 let buf = Arc::get_mut(&mut self.buffer.0)
540 .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
541 let bytes = buf.as_bytes_mut();
542 let bytes = &mut bytes[offset..];
543 if bytes.len() % 4 != 0 {
544 return Err(SapientError::internal("Buffer length not a multiple of 4"));
545 }
546 Ok(unsafe {
548 std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut f32, bytes.len() / 4)
549 })
550 }
551
552 pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor> {
557 let new_shape = self.shape.reshape(new_shape.into().dims().to_vec())?;
558 let strides = new_shape.strides();
559 Ok(Tensor {
560 shape: new_shape,
561 dtype: self.dtype,
562 strides,
563 buffer: self.buffer.clone(),
564 offset: self.offset,
565 })
566 }
567
568 pub fn t(&self) -> Result<Tensor> {
570 if self.ndim() != 2 {
571 return Err(SapientError::internal("t() requires a 2-D tensor"));
572 }
573 let mut dims = self.shape.dims().to_vec();
574 let mut strides = self.strides.clone();
575 dims.swap(0, 1);
576 strides.swap(0, 1);
577 Ok(Tensor {
578 shape: Shape(dims),
579 dtype: self.dtype,
580 strides,
581 buffer: self.buffer.clone(),
582 offset: self.offset,
583 })
584 }
585
586 pub fn slice_axis(&self, axis: usize, start: usize, end: usize) -> Result<Tensor> {
588 let mut dims = self.shape.dims().to_vec();
589 if axis >= dims.len() {
590 return Err(SapientError::internal("slice axis out of bounds"));
591 }
592 if start > end || end > dims[axis] {
593 return Err(SapientError::internal("slice range out of bounds"));
594 }
595 dims[axis] = end - start;
596 let offset = self.offset + start * self.strides[axis] * self.dtype.element_size();
597 Ok(Tensor {
598 shape: Shape(dims),
599 dtype: self.dtype,
600 strides: self.strides.clone(),
601 buffer: self.buffer.clone(),
602 offset,
603 })
604 }
605
    /// Returns the number of bytes needed to store `numel()` elements of this
    /// tensor's dtype, as computed by `DType::byte_count`.
    pub fn byte_size(&self) -> usize {
        self.dtype.byte_count(self.numel())
    }
612}
613
614impl std::fmt::Display for Tensor {
615 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616 write!(
617 f,
618 "Tensor(shape={}, dtype={}, device={})",
619 self.shape,
620 self.dtype,
621 self.buffer.0.device()
622 )
623 }
624}
625
/// Serde wire representation of a `Tensor`.
#[derive(Serialize, Deserialize)]
struct TensorProxy {
    /// Shape of the serialized tensor.
    shape: Shape,
    /// Dtype of the serialized tensor.
    dtype: DType,
    /// `f32` payload. The `Serialize` impl leaves this empty for non-F32
    /// tensors, and the `Deserialize` impl turns an empty payload into a
    /// zero tensor.
    data: Vec<f32>,
}
636
637impl Serialize for Tensor {
638 fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
639 let data: Vec<f32> = if self.dtype == DType::F32 {
640 self.as_f32_slice().to_vec()
641 } else {
642 vec![] };
644 TensorProxy {
645 shape: self.shape.clone(),
646 dtype: self.dtype,
647 data,
648 }
649 .serialize(serializer)
650 }
651}
652
653impl<'de> Deserialize<'de> for Tensor {
654 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
655 let proxy = TensorProxy::deserialize(deserializer)?;
656 if proxy.data.is_empty() {
657 Tensor::zeros(proxy.shape, proxy.dtype).map_err(serde::de::Error::custom)
658 } else {
659 Tensor::from_f32(&proxy.data, proxy.shape).map_err(serde::de::Error::custom)
660 }
661 }
662}
663
/// Shape and dtype of a tensor, without any data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMeta {
    /// Shape of the described tensor.
    pub shape: Shape,
    /// Dtype of the described tensor.
    pub dtype: DType,
}
670
671impl From<&Tensor> for TensorMeta {
672 fn from(t: &Tensor) -> Self {
673 Self {
674 shape: t.shape.clone(),
675 dtype: t.dtype,
676 }
677 }
678}
679
// Unit tests for basic construction, reshaping, transposition and sizing.
#[cfg(test)]
mod tests {
    use super::*;

    // `zeros` reports the requested shape, dtype and element count.
    #[test]
    fn zeros_dtype_shape() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert_eq!(t.shape().dims(), &[2, 3]);
        assert_eq!(t.dtype(), DType::F32);
        assert_eq!(t.numel(), 6);
    }

    // Data passed to `from_f32` is readable back unchanged.
    #[test]
    fn from_f32_roundtrip() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        assert_eq!(t.as_f32_slice(), data.as_slice());
    }

    // Reshaping a contiguous tensor changes the dims but not the data.
    #[test]
    fn reshape_preserves_data() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        let r = t.reshape(vec![3, 2]).unwrap();
        assert_eq!(r.shape().dims(), &[3, 2]);
        assert_eq!(r.as_f32_slice(), data.as_slice());
    }

    // A reshape that changes the element count is rejected.
    #[test]
    fn reshape_wrong_numel() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert!(t.reshape(vec![5]).is_err());
    }

    // `t()` swaps the two dimensions.
    #[test]
    fn transpose_2d() {
        let t = Tensor::zeros(vec![3, 4], DType::F32).unwrap();
        let t2 = t.t().unwrap();
        assert_eq!(t2.shape().dims(), &[4, 3]);
    }

    // 16 F32 elements occupy 64 bytes.
    #[test]
    fn byte_size() {
        let t = Tensor::zeros(vec![4, 4], DType::F32).unwrap();
        assert_eq!(t.byte_size(), 64);
    }
}