1use serde::{Deserialize, Serialize};
7use serde::{Deserializer, Serializer};
8use std::sync::Arc;
9
10use crate::buffer::{BufferHandle, CpuBuffer};
11use crate::dtype::DType;
12use crate::error::{Result, SapientError};
13use crate::shape::Shape;
14
/// An n-dimensional tensor: shape, dtype and stride metadata over a byte buffer.
///
/// Several tensors may reference the same [`BufferHandle`]: `reshape`, `t` and
/// `slice_axis` clone the handle rather than copying the data.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Logical dimensions.
    shape: Shape,
    /// Element type of the stored data.
    dtype: DType,
    /// Per-dimension strides measured in elements (not bytes); `slice_axis` multiplies
    /// them by `element_size()` to move the byte offset.
    strides: Vec<usize>,
    /// Backing storage; its inner `Arc` is shared between views.
    buffer: BufferHandle,
    /// Byte offset of the first element within `buffer`.
    offset: usize,
}
27
28impl Tensor {
29 pub fn zeros(shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
33 let shape = shape.into();
34 shape.validate()?;
35 let numel = shape.numel();
36 let strides = shape.strides();
37 let buffer = BufferHandle::new(CpuBuffer::zeros(numel, dtype)?);
38 Ok(Self {
39 shape,
40 dtype,
41 strides,
42 buffer,
43 offset: 0,
44 })
45 }
46
47 pub fn from_f32(data: &[f32], shape: impl Into<Shape>) -> Result<Self> {
49 let shape = shape.into();
50 shape.validate()?;
51 if data.len() != shape.numel() {
52 return Err(SapientError::ShapeMismatch {
53 expected: shape.dims().to_vec(),
54 got: vec![data.len()],
55 });
56 }
57 let strides = shape.strides();
58 let buffer = BufferHandle::new(CpuBuffer::from_f32_slice(data)?);
59 Ok(Self {
60 shape,
61 dtype: DType::F32,
62 strides,
63 buffer,
64 offset: 0,
65 })
66 }
67
68 pub fn from_bf16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
71 let shape = shape.into();
72 shape.validate()?;
73 let expected_bytes = shape.numel() * 2;
74 if data.len() != expected_bytes {
75 return Err(SapientError::ShapeMismatch {
76 expected: shape.dims().to_vec(),
77 got: vec![data.len() / 2],
78 });
79 }
80 let strides = shape.strides();
81 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
82 Ok(Self {
83 shape,
84 dtype: DType::BF16,
85 strides,
86 buffer,
87 offset: 0,
88 })
89 }
90
91 pub fn from_f16_bytes(data: &[u8], shape: impl Into<Shape>) -> Result<Self> {
93 let shape = shape.into();
94 shape.validate()?;
95 let expected_bytes = shape.numel() * 2;
96 if data.len() != expected_bytes {
97 return Err(SapientError::ShapeMismatch {
98 expected: shape.dims().to_vec(),
99 got: vec![data.len() / 2],
100 });
101 }
102 let strides = shape.strides();
103 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
104 Ok(Self {
105 shape,
106 dtype: DType::F16,
107 strides,
108 buffer,
109 offset: 0,
110 })
111 }
112
113 pub fn from_quant_bytes(data: &[u8], shape: impl Into<Shape>, dtype: DType) -> Result<Self> {
119 if !dtype.is_quantized() {
120 return Err(SapientError::TypeMismatch {
121 expected: "a quantized dtype (Q4_0, Q8_0, Q4_K, Q5_K, Q6_K)".into(),
122 got: dtype.to_string(),
123 });
124 }
125 let shape = shape.into();
126 shape.validate()?;
127 let numel = shape.numel();
128 let expected_bytes = dtype.byte_count(numel);
129 if data.len() != expected_bytes {
130 return Err(SapientError::ShapeMismatch {
131 expected: vec![expected_bytes],
132 got: vec![data.len()],
133 });
134 }
135 let strides = shape.strides();
136 let buffer = BufferHandle::new(CpuBuffer::from_bytes_slice(data)?);
137 Ok(Self {
138 shape,
139 dtype,
140 strides,
141 buffer,
142 offset: 0,
143 })
144 }
145
146 pub fn scalar_f32(v: f32) -> Result<Self> {
148 Self::from_f32(&[v], Shape::scalar())
149 }
150
151 pub fn from_buffer(
153 shape: impl Into<Shape>,
154 dtype: DType,
155 buffer: BufferHandle,
156 offset: usize,
157 ) -> Result<Self> {
158 let shape = shape.into();
159 shape.validate()?;
160 let required = dtype.byte_count(shape.numel());
161 if buffer.len() < offset + required {
162 return Err(SapientError::BufferSizeMismatch {
163 expected: offset + required,
164 got: buffer.len(),
165 });
166 }
167 let strides = shape.strides();
168 Ok(Self {
169 shape,
170 dtype,
171 strides,
172 buffer,
173 offset,
174 })
175 }
176
177 pub fn shape(&self) -> &Shape {
180 &self.shape
181 }
182 pub fn dtype(&self) -> DType {
183 self.dtype
184 }
185 pub fn ndim(&self) -> usize {
186 self.shape.ndim()
187 }
188 pub fn numel(&self) -> usize {
189 self.shape.numel()
190 }
191 pub fn strides(&self) -> &[usize] {
192 &self.strides
193 }
194 pub fn buffer(&self) -> &BufferHandle {
195 &self.buffer
196 }
197 pub fn offset(&self) -> usize {
198 self.offset
199 }
200
201 pub fn is_scalar(&self) -> bool {
203 self.shape.is_scalar() || self.numel() == 1
204 }
205
206 pub fn is_contiguous(&self) -> bool {
208 self.strides == self.shape.strides() && self.offset == 0
209 }
210
211 pub fn as_bytes(&self) -> &[u8] {
218 let bytes = self.buffer.as_bytes();
219 if self.dtype.is_quantized() {
220 let end = self.offset + self.dtype.byte_count(self.numel());
221 &bytes[self.offset..end]
222 } else {
223 &bytes[self.offset..]
224 }
225 }
226
227 pub fn as_quant_blocks(&self) -> &[u8] {
231 assert!(
232 self.dtype.is_quantized(),
233 "as_quant_blocks() called on non-quantized tensor (dtype = {})",
234 self.dtype
235 );
236 self.as_bytes()
237 }
238
239 pub fn as_f32_slice(&self) -> &[f32] {
241 assert_eq!(
242 self.dtype,
243 DType::F32,
244 "Tensor dtype is not F32 — call to_f32_vec() instead"
245 );
246 let bytes = self.as_bytes();
247 assert_eq!(bytes.len() % 4, 0);
248 unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) }
250 }
251
252 pub fn to_contiguous_f32_vec(&self) -> Vec<f32> {
262 let numel = self.numel();
263 if self.is_contiguous() {
264 match self.dtype {
267 DType::F32 => self.as_f32_slice()[..numel].to_vec(),
268 _ => {
269 let v = self.to_f32_vec();
270 v[..numel.min(v.len())].to_vec()
271 }
272 }
273 } else {
274 let raw: Vec<f32> = match self.dtype {
277 DType::F32 => self.as_f32_slice().to_vec(),
278 _ => self.to_f32_vec(),
279 };
280 let dims = self.shape.dims();
281 let strides = &self.strides; let mut out = vec![0.0f32; numel];
283 for (flat, dst) in out.iter_mut().enumerate() {
284 let mut rem = flat;
287 let mut src = 0usize;
288 for d in (0..dims.len()).rev() {
289 let idx_d = rem % dims[d];
290 rem /= dims[d];
291 src += idx_d * strides[d];
292 }
293 *dst = *raw.get(src).unwrap_or(&0.0);
294 }
295 out
296 }
297 }
298
299 pub fn to_f32_cow(&self) -> std::borrow::Cow<'_, [f32]> {
301 if self.dtype == DType::F32 {
302 std::borrow::Cow::Borrowed(self.as_f32_slice())
303 } else {
304 std::borrow::Cow::Owned(self.to_f32_vec())
305 }
306 }
307
308 pub fn to_f32_vec(&self) -> Vec<f32> {
311 use crate::dtype::{
312 K_QUANT_BLOCK_SIZE, Q4_0_BLOCK_BYTES, Q4_K_BLOCK_BYTES, Q5_K_BLOCK_BYTES,
313 Q6_K_BLOCK_BYTES, Q8_0_BLOCK_BYTES, QUANT_BLOCK_SIZE,
314 };
315 match self.dtype {
316 DType::F32 => self.as_f32_slice().to_vec(),
317 DType::BF16 => {
318 let bytes = self.as_bytes();
319 bytes
320 .chunks_exact(2)
321 .map(|c| f32::from(half::bf16::from_le_bytes(c.try_into().unwrap())))
322 .collect()
323 }
324 DType::F16 => {
325 let bytes = self.as_bytes();
326 bytes
327 .chunks_exact(2)
328 .map(|c| half::f16::from_le_bytes(c.try_into().unwrap()).to_f32())
329 .collect()
330 }
331 DType::Q4_0 => {
332 let numel = self.numel();
333 let bytes = self.as_bytes();
334 let mut out = vec![0.0f32; numel];
335 for (b, block) in bytes.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
336 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
337 for j in 0..QUANT_BLOCK_SIZE / 2 {
338 let byte = block[2 + j];
339 let lo = (byte & 0x0f) as i32 - 8;
340 let hi = (byte >> 4) as i32 - 8;
341 out[b * QUANT_BLOCK_SIZE + j] = lo as f32 * d;
342 out[b * QUANT_BLOCK_SIZE + j + QUANT_BLOCK_SIZE / 2] = hi as f32 * d;
343 }
344 }
345 out
346 }
347 DType::Q8_0 => {
348 let numel = self.numel();
349 let bytes = self.as_bytes();
350 let mut out = vec![0.0f32; numel];
351 for (b, block) in bytes.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
352 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
353 for j in 0..QUANT_BLOCK_SIZE {
354 out[b * QUANT_BLOCK_SIZE + j] = block[2 + j] as i8 as f32 * d;
355 }
356 }
357 out
358 }
359 DType::Q4_K => {
360 let numel = self.numel();
361 let bytes = self.as_bytes();
362 let mut out = vec![0.0f32; numel];
363 let mut out_idx = 0usize;
364 for block in bytes.chunks_exact(Q4_K_BLOCK_BYTES) {
365 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
366 let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
367 let scales = &block[4..16];
368 let qs = &block[16..Q4_K_BLOCK_BYTES];
369 let mut q_off = 0usize;
370 let mut is = 0usize;
371 for _ in 0..(K_QUANT_BLOCK_SIZE / 64) {
372 let (sc1, m1) = Self::get_scale_min_k4(is, scales);
373 let d1 = d * sc1 as f32;
374 let m1v = dmin * m1 as f32;
375 let (sc2, m2) = Self::get_scale_min_k4(is + 1, scales);
376 let d2 = d * sc2 as f32;
377 let m2v = dmin * m2 as f32;
378 for l in 0..32 {
379 out[out_idx + l] = d1 * (qs[q_off + l] & 0x0F) as f32 - m1v;
380 out[out_idx + l + 32] = d2 * (qs[q_off + l] >> 4) as f32 - m2v;
381 }
382 out_idx += 64;
383 q_off += 32;
384 is += 2;
385 }
386 }
387 out
388 }
389 DType::Q5_K => {
390 let numel = self.numel();
391 let bytes = self.as_bytes();
392 let mut out = vec![0.0f32; numel];
393 let mut out_idx = 0usize;
394 for block in bytes.chunks_exact(Q5_K_BLOCK_BYTES) {
395 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
396 let dmin = half::f16::from_le_bytes([block[2], block[3]]).to_f32();
397 let scales = &block[4..16];
398 let qh = &block[16..48];
399 let ql = &block[48..Q5_K_BLOCK_BYTES];
400 let mut ql_off = 0usize;
401 let mut is = 0usize;
402 let mut u1: u8 = 1;
403 let mut u2: u8 = 2;
404 for _ in 0..(K_QUANT_BLOCK_SIZE / 64) {
405 let (sc1, m1) = Self::get_scale_min_k4(is, scales);
406 let d1 = d * sc1 as f32;
407 let m1v = dmin * m1 as f32;
408 let (sc2, m2) = Self::get_scale_min_k4(is + 1, scales);
409 let d2 = d * sc2 as f32;
410 let m2v = dmin * m2 as f32;
411 let qh_byte = qh[is / 8];
412 for l in 0..32usize {
413 let hi = if qh_byte & u1 != 0 { 16.0f32 } else { 0.0 };
414 out[out_idx + l] = d1 * ((ql[ql_off + l] & 0x0F) as f32 + hi) - m1v;
415 let hi2 = if qh_byte & u2 != 0 { 16.0f32 } else { 0.0 };
416 out[out_idx + l + 32] = d2 * ((ql[ql_off + l] >> 4) as f32 + hi2) - m2v;
417 }
418 out_idx += 64;
419 ql_off += 32;
420 is += 2;
421 if is % 8 == 0 {
422 u1 = 1;
423 u2 = 2;
424 } else {
425 u1 <<= 2;
426 u2 <<= 2;
427 }
428 }
429 }
430 out
431 }
432 DType::Q6_K => {
433 let numel = self.numel();
434 let bytes = self.as_bytes();
435 let mut out = vec![0.0f32; numel];
436 let mut out_idx = 0usize;
437 for block in bytes.chunks_exact(Q6_K_BLOCK_BYTES) {
438 let ql = &block[0..128];
439 let qh = &block[128..192];
440 let sc = &block[192..208];
441 let d = half::f16::from_le_bytes([block[208], block[209]]).to_f32();
442 let mut ql_off = 0usize;
443 let mut qh_off = 0usize;
444 let mut ib = 0usize;
445 for _ in 0..(K_QUANT_BLOCK_SIZE / 128) {
446 for l in 0..32usize {
447 let q1 = (((ql[ql_off + l] & 0x0F) | ((qh[qh_off + l] & 3) << 4))
448 as i32
449 - 32) as f32;
450 let q2 = (((ql[ql_off + l + 32] & 0x0F)
451 | (((qh[qh_off + l] >> 2) & 3) << 4))
452 as i32
453 - 32) as f32;
454 let q3 = (((ql[ql_off + l] >> 4) | (((qh[qh_off + l] >> 4) & 3) << 4))
455 as i32
456 - 32) as f32;
457 let q4 = (((ql[ql_off + l + 32] >> 4)
458 | (((qh[qh_off + l] >> 6) & 3) << 4))
459 as i32
460 - 32) as f32;
461 out[out_idx + l] = d * sc[ib] as i8 as f32 * q1;
462 out[out_idx + l + 32] = d * sc[ib + 1] as i8 as f32 * q2;
463 out[out_idx + l + 64] = d * sc[ib + 2] as i8 as f32 * q3;
464 out[out_idx + l + 96] = d * sc[ib + 3] as i8 as f32 * q4;
465 }
466 out_idx += 128;
467 ql_off += 64;
468 qh_off += 32;
469 ib += 4;
470 }
471 }
472 out
473 }
474 _ => self.as_f32_slice().to_vec(), }
476 }
477
478 #[inline]
480 fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
481 if j < 4 {
482 (scales[j] & 63, scales[j + 4] & 63)
483 } else {
484 (
485 (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4),
486 (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4),
487 )
488 }
489 }
490
491 pub fn to_f32_tensor(&self) -> Result<Tensor> {
494 match self.dtype {
495 DType::F32 => Ok(self.clone()),
496 _ => Tensor::from_f32(&self.to_f32_vec(), self.shape.clone()),
497 }
498 }
499
500 pub fn as_f32_slice_mut(&mut self) -> Result<&mut [f32]> {
502 if self.dtype != DType::F32 {
503 return Err(SapientError::internal("Tensor dtype is not F32"));
504 }
505 let offset = self.offset;
506 let buf = Arc::get_mut(&mut self.buffer.0)
507 .ok_or_else(|| SapientError::internal("Cannot mutate shared tensor buffer"))?;
508 let bytes = buf.as_bytes_mut();
509 let bytes = &mut bytes[offset..];
510 if bytes.len() % 4 != 0 {
511 return Err(SapientError::internal("Buffer length not a multiple of 4"));
512 }
513 Ok(unsafe {
515 std::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut f32, bytes.len() / 4)
516 })
517 }
518
519 pub fn reshape(&self, new_shape: impl Into<Shape>) -> Result<Tensor> {
524 let new_shape = self.shape.reshape(new_shape.into().dims().to_vec())?;
525 let strides = new_shape.strides();
526 Ok(Tensor {
527 shape: new_shape,
528 dtype: self.dtype,
529 strides,
530 buffer: self.buffer.clone(),
531 offset: self.offset,
532 })
533 }
534
535 pub fn t(&self) -> Result<Tensor> {
537 if self.ndim() != 2 {
538 return Err(SapientError::internal("t() requires a 2-D tensor"));
539 }
540 let mut dims = self.shape.dims().to_vec();
541 let mut strides = self.strides.clone();
542 dims.swap(0, 1);
543 strides.swap(0, 1);
544 Ok(Tensor {
545 shape: Shape(dims),
546 dtype: self.dtype,
547 strides,
548 buffer: self.buffer.clone(),
549 offset: self.offset,
550 })
551 }
552
553 pub fn slice_axis(&self, axis: usize, start: usize, end: usize) -> Result<Tensor> {
555 let mut dims = self.shape.dims().to_vec();
556 if axis >= dims.len() {
557 return Err(SapientError::internal("slice axis out of bounds"));
558 }
559 if start > end || end > dims[axis] {
560 return Err(SapientError::internal("slice range out of bounds"));
561 }
562 dims[axis] = end - start;
563 let offset = self.offset + start * self.strides[axis] * self.dtype.element_size();
564 Ok(Tensor {
565 shape: Shape(dims),
566 dtype: self.dtype,
567 strides: self.strides.clone(),
568 buffer: self.buffer.clone(),
569 offset,
570 })
571 }
572
573 pub fn byte_size(&self) -> usize {
577 self.dtype.byte_count(self.numel())
578 }
579}
580
581impl std::fmt::Display for Tensor {
582 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583 write!(
584 f,
585 "Tensor(shape={}, dtype={}, device={})",
586 self.shape,
587 self.dtype,
588 self.buffer.0.device()
589 )
590 }
591}
592
/// Serde wire representation of a [`Tensor`].
///
/// The field names are part of the serialized format; renaming them breaks existing data.
#[derive(Serialize, Deserialize)]
struct TensorProxy {
    /// Logical dimensions.
    shape: Shape,
    /// Recorded element type.
    dtype: DType,
    /// Element values as `f32`. The `Serialize` impl leaves this empty for non-F32
    /// tensors; the `Deserialize` impl turns empty data into `Tensor::zeros(shape, dtype)`
    /// and non-empty data into an F32 tensor via `Tensor::from_f32`.
    data: Vec<f32>,
}
603
604impl Serialize for Tensor {
605 fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
606 let data: Vec<f32> = if self.dtype == DType::F32 {
607 self.as_f32_slice().to_vec()
608 } else {
609 vec![] };
611 TensorProxy {
612 shape: self.shape.clone(),
613 dtype: self.dtype,
614 data,
615 }
616 .serialize(serializer)
617 }
618}
619
620impl<'de> Deserialize<'de> for Tensor {
621 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
622 let proxy = TensorProxy::deserialize(deserializer)?;
623 if proxy.data.is_empty() {
624 Tensor::zeros(proxy.shape, proxy.dtype).map_err(serde::de::Error::custom)
625 } else {
626 Tensor::from_f32(&proxy.data, proxy.shape).map_err(serde::de::Error::custom)
627 }
628 }
629}
630
/// Shape and dtype of a [`Tensor`], without its data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMeta {
    /// Logical dimensions.
    pub shape: Shape,
    /// Element type.
    pub dtype: DType,
}
637
638impl From<&Tensor> for TensorMeta {
639 fn from(t: &Tensor) -> Self {
640 Self {
641 shape: t.shape.clone(),
642 dtype: t.dtype,
643 }
644 }
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros_dtype_shape() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert_eq!(t.shape().dims(), &[2, 3]);
        assert_eq!(t.dtype(), DType::F32);
        assert_eq!(t.numel(), 6);
    }

    #[test]
    fn from_f32_roundtrip() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        assert_eq!(t.as_f32_slice(), data.as_slice());
    }

    #[test]
    fn from_f32_length_mismatch_is_error() {
        assert!(Tensor::from_f32(&[1.0, 2.0], vec![3]).is_err());
    }

    #[test]
    fn scalar_f32_is_scalar() {
        let s = Tensor::scalar_f32(2.5).unwrap();
        assert!(s.is_scalar());
        assert_eq!(s.to_contiguous_f32_vec(), vec![2.5]);
    }

    #[test]
    fn reshape_preserves_data() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();
        let r = t.reshape(vec![3, 2]).unwrap();
        assert_eq!(r.shape().dims(), &[3, 2]);
        assert_eq!(r.as_f32_slice(), data.as_slice());
    }

    #[test]
    fn reshape_wrong_numel() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert!(t.reshape(vec![5]).is_err());
    }

    #[test]
    fn transpose_2d() {
        let t = Tensor::zeros(vec![3, 4], DType::F32).unwrap();
        let t2 = t.t().unwrap();
        assert_eq!(t2.shape().dims(), &[4, 3]);
    }

    #[test]
    fn transpose_values_in_logical_order() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let tt = t.t().unwrap();
        assert!(!tt.is_contiguous());
        assert_eq!(tt.to_contiguous_f32_vec(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn transpose_rejects_non_2d() {
        let t = Tensor::zeros(vec![2, 3, 4], DType::F32).unwrap();
        assert!(t.t().is_err());
    }

    #[test]
    fn slice_axis_rows_and_cols() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        let row = t.slice_axis(0, 1, 2).unwrap();
        assert_eq!(row.shape().dims(), &[1, 3]);
        assert_eq!(row.to_contiguous_f32_vec(), vec![4.0, 5.0, 6.0]);

        let cols = t.slice_axis(1, 1, 3).unwrap();
        assert_eq!(cols.shape().dims(), &[2, 2]);
        assert_eq!(cols.to_contiguous_f32_vec(), vec![2.0, 3.0, 5.0, 6.0]);
    }

    #[test]
    fn slice_axis_bounds() {
        let t = Tensor::zeros(vec![2, 3], DType::F32).unwrap();
        assert!(t.slice_axis(2, 0, 1).is_err());
        assert!(t.slice_axis(0, 1, 3).is_err());
        assert!(t.slice_axis(1, 2, 1).is_err());
    }

    #[test]
    fn byte_size() {
        let t = Tensor::zeros(vec![4, 4], DType::F32).unwrap();
        assert_eq!(t.byte_size(), 64);
    }
}