1use super::aligned::AlignedVec;
2use super::error::{DType, TensorError};
3use super::shape::{compute_strides, shape_element_count};
4
// Maximum rank that `DimsVec` stores inline before spilling to the heap.
const INLINE_CAP: usize = 6;
9
/// Small-vector for dimension lists (shapes and strides).
///
/// Ranks up to `INLINE_CAP` live inline on the stack; anything longer
/// falls back to a heap-allocated `Vec`.
#[derive(Clone)]
pub(crate) enum DimsVec {
    /// Inline storage: only `buf[..len]` is meaningful.
    Inline { buf: [usize; INLINE_CAP], len: u8 },
    /// Heap fallback for ranks above `INLINE_CAP`.
    Heap(Vec<usize>),
}
17
18impl DimsVec {
19 #[inline]
20 fn new() -> Self {
21 DimsVec::Inline {
22 buf: [0; INLINE_CAP],
23 len: 0,
24 }
25 }
26
27 #[inline]
28 fn as_slice(&self) -> &[usize] {
29 match self {
30 DimsVec::Inline { buf, len } => &buf[..*len as usize],
31 DimsVec::Heap(v) => v,
32 }
33 }
34
35 #[inline]
36 fn to_vec(&self) -> Vec<usize> {
37 self.as_slice().to_vec()
38 }
39}
40
41impl std::ops::Deref for DimsVec {
42 type Target = [usize];
43 #[inline]
44 fn deref(&self) -> &[usize] {
45 self.as_slice()
46 }
47}
48
49impl From<Vec<usize>> for DimsVec {
50 #[inline]
51 fn from(v: Vec<usize>) -> Self {
52 if v.len() <= INLINE_CAP {
53 let mut buf = [0usize; INLINE_CAP];
54 buf[..v.len()].copy_from_slice(&v);
55 DimsVec::Inline {
56 buf,
57 len: v.len() as u8,
58 }
59 } else {
60 DimsVec::Heap(v)
61 }
62 }
63}
64
65impl From<&[usize]> for DimsVec {
66 #[inline]
67 fn from(s: &[usize]) -> Self {
68 if s.len() <= INLINE_CAP {
69 let mut buf = [0usize; INLINE_CAP];
70 buf[..s.len()].copy_from_slice(s);
71 DimsVec::Inline {
72 buf,
73 len: s.len() as u8,
74 }
75 } else {
76 DimsVec::Heap(s.to_vec())
77 }
78 }
79}
80
81impl PartialEq for DimsVec {
82 #[inline]
83 fn eq(&self, other: &Self) -> bool {
84 self.as_slice() == other.as_slice()
85 }
86}
87
88impl std::fmt::Debug for DimsVec {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 self.as_slice().fmt(f)
91 }
92}
93
/// Device tag recording where a tensor's data nominally lives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Device {
    /// Host memory (the default).
    #[default]
    Cpu,
    /// GPU device, identified by its ordinal index.
    Gpu(usize),
}
106
/// Backing element buffer, tagged by dtype.
///
/// F16 and BF16 hold raw bit patterns as `u16`; only F32 uses the
/// aligned buffer type (`AlignedVec`).
#[derive(Debug, Clone)]
pub(crate) enum Storage {
    /// 32-bit floats in an aligned buffer.
    F32(AlignedVec<f32>),
    /// IEEE-754 binary16 values stored as raw bits.
    F16(Vec<u16>),
    /// bfloat16 values stored as raw bits.
    BF16(Vec<u16>),
}
114
115impl PartialEq for Storage {
116 fn eq(&self, other: &Self) -> bool {
117 match (self, other) {
118 (Storage::F32(a), Storage::F32(b)) => a == b,
119 (Storage::F16(a), Storage::F16(b)) => a == b,
120 (Storage::BF16(a), Storage::BF16(b)) => a == b,
121 _ => false,
122 }
123 }
124}
125
126impl Storage {
127 fn len(&self) -> usize {
128 match self {
129 Storage::F32(v) => v.len(),
130 Storage::F16(v) => v.len(),
131 Storage::BF16(v) => v.len(),
132 }
133 }
134
135 fn dtype(&self) -> DType {
136 match self {
137 Storage::F32(_) => DType::F32,
138 Storage::F16(_) => DType::F16,
139 Storage::BF16(_) => DType::BF16,
140 }
141 }
142}
143
/// A dense n-dimensional array of floating-point values.
///
/// Strides are produced by `compute_strides` at construction time
/// (presumably contiguous row-major — see the shape module).
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    // Dimension sizes; empty for rank-0 scalars.
    shape: DimsVec,
    // Per-axis element strides, same rank as `shape`.
    strides: DimsVec,
    // Typed element buffer.
    storage: Storage,
    // Device tag; no data transfer is performed in this module.
    device: Device,
}
155
156impl Tensor {
157 pub fn scalar(value: f32) -> Self {
159 Self {
160 shape: DimsVec::new(),
161 strides: DimsVec::new(),
162 storage: Storage::F32(AlignedVec::filled(1, value)),
163 device: Device::Cpu,
164 }
165 }
166
167 #[inline]
170 pub fn from_raw_parts(shape: &[usize], strides: &[usize], data: AlignedVec<f32>) -> Self {
171 debug_assert_eq!(
172 shape.iter().copied().product::<usize>(),
173 data.len(),
174 "from_raw_parts: shape product != data.len()"
175 );
176 Self {
177 shape: DimsVec::from(shape),
178 strides: DimsVec::from(strides),
179 storage: Storage::F32(data),
180 device: Device::Cpu,
181 }
182 }
183
184 pub fn from_aligned(shape: Vec<usize>, data: AlignedVec<f32>) -> Result<Self, TensorError> {
191 let expected = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
192 shape: shape.clone(),
193 })?;
194 if expected != data.len() {
195 return Err(TensorError::SizeMismatch {
196 shape,
197 data_len: data.len(),
198 });
199 }
200
201 let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
202 shape: shape.clone(),
203 })?;
204
205 Ok(Self {
206 shape: DimsVec::from(shape),
207 strides: DimsVec::from(strides),
208 storage: Storage::F32(data),
209 device: Device::Cpu,
210 })
211 }
212
213 pub fn from_vec(shape: Vec<usize>, data: Vec<f32>) -> Result<Self, TensorError> {
215 let expected = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
216 shape: shape.clone(),
217 })?;
218 if expected != data.len() {
219 return Err(TensorError::SizeMismatch {
220 shape,
221 data_len: data.len(),
222 });
223 }
224
225 let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
226 shape: shape.clone(),
227 })?;
228
229 Ok(Self {
230 shape: DimsVec::from(shape),
231 strides: DimsVec::from(strides),
232 storage: Storage::F32(AlignedVec::from_vec(data)),
233 device: Device::Cpu,
234 })
235 }
236
237 pub fn from_f16(shape: Vec<usize>, data: Vec<u16>) -> Result<Self, TensorError> {
239 let expected = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
240 shape: shape.clone(),
241 })?;
242 if expected != data.len() {
243 return Err(TensorError::SizeMismatch {
244 shape,
245 data_len: data.len(),
246 });
247 }
248 let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
249 shape: shape.clone(),
250 })?;
251 Ok(Self {
252 shape: DimsVec::from(shape),
253 strides: DimsVec::from(strides),
254 storage: Storage::F16(data),
255 device: Device::Cpu,
256 })
257 }
258
259 pub fn from_bf16(shape: Vec<usize>, data: Vec<u16>) -> Result<Self, TensorError> {
261 let expected = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
262 shape: shape.clone(),
263 })?;
264 if expected != data.len() {
265 return Err(TensorError::SizeMismatch {
266 shape,
267 data_len: data.len(),
268 });
269 }
270 let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
271 shape: shape.clone(),
272 })?;
273 Ok(Self {
274 shape: DimsVec::from(shape),
275 strides: DimsVec::from(strides),
276 storage: Storage::BF16(data),
277 device: Device::Cpu,
278 })
279 }
280
281 pub fn from_slice(data: &[f32]) -> Self {
283 let n = data.len();
284 Self {
285 shape: DimsVec::from(vec![n]),
286 strides: DimsVec::from(vec![1usize]),
287 storage: Storage::F32(AlignedVec::from_vec(data.to_vec())),
288 device: Device::Cpu,
289 }
290 }
291
292 pub fn filled(shape: Vec<usize>, value: f32) -> Result<Self, TensorError> {
296 let count = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
297 shape: shape.clone(),
298 })?;
299 let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
300 shape: shape.clone(),
301 })?;
302
303 Ok(Self {
304 shape: DimsVec::from(shape),
305 strides: DimsVec::from(strides),
306 storage: Storage::F32(AlignedVec::filled(count, value)),
307 device: Device::Cpu,
308 })
309 }
310
311 pub fn zeros(shape: Vec<usize>) -> Result<Self, TensorError> {
316 let count = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
317 shape: shape.clone(),
318 })?;
319 let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
320 shape: shape.clone(),
321 })?;
322
323 Ok(Self {
324 shape: DimsVec::from(shape),
325 strides: DimsVec::from(strides),
326 storage: Storage::F32(AlignedVec::calloc(count)),
327 device: Device::Cpu,
328 })
329 }
330
331 pub fn ones(shape: Vec<usize>) -> Result<Self, TensorError> {
333 Self::filled(shape, 1.0)
334 }
335
336 pub fn full(shape: Vec<usize>, value: f32) -> Result<Self, TensorError> {
338 Self::filled(shape, value)
339 }
340
341 pub fn shape(&self) -> &[usize] {
343 &self.shape
344 }
345
346 pub fn strides(&self) -> &[usize] {
348 &self.strides
349 }
350
351 pub fn rank(&self) -> usize {
353 self.shape.len()
354 }
355
356 pub fn len(&self) -> usize {
358 self.storage.len()
359 }
360
361 pub fn is_empty(&self) -> bool {
363 self.storage.len() == 0
364 }
365
366 pub fn dtype(&self) -> DType {
368 self.storage.dtype()
369 }
370
371 pub fn device(&self) -> Device {
373 self.device
374 }
375
376 pub fn to_device(&self, device: Device) -> Self {
381 Self {
382 shape: self.shape.clone(),
383 strides: self.strides.clone(),
384 storage: self.storage.clone(),
385 device,
386 }
387 }
388
389 pub fn is_f32(&self) -> bool {
391 matches!(self.storage, Storage::F32(_))
392 }
393
394 pub fn try_data(&self) -> Result<&[f32], TensorError> {
396 match &self.storage {
397 Storage::F32(v) => Ok(v),
398 _ => Err(TensorError::DTypeMismatch {
399 expected: DType::F32,
400 got: self.dtype(),
401 }),
402 }
403 }
404
405 pub fn try_data_mut(&mut self) -> Result<&mut [f32], TensorError> {
407 let dt = self.storage.dtype();
408 match &mut self.storage {
409 Storage::F32(v) => Ok(v),
410 _ => Err(TensorError::DTypeMismatch {
411 expected: DType::F32,
412 got: dt,
413 }),
414 }
415 }
416
417 pub fn data(&self) -> &[f32] {
422 self.try_data().expect("tensor is not F32")
423 }
424
425 pub fn data_mut(&mut self) -> &mut [f32] {
430 self.try_data_mut().expect("tensor is not F32")
431 }
432
433 pub fn try_data_f32(&self) -> Result<&[f32], TensorError> {
435 self.try_data()
436 }
437
438 pub fn data_f16(&self) -> Result<&[u16], TensorError> {
440 match &self.storage {
441 Storage::F16(v) => Ok(v),
442 _ => Err(TensorError::DTypeMismatch {
443 expected: DType::F16,
444 got: self.dtype(),
445 }),
446 }
447 }
448
449 pub fn data_bf16(&self) -> Result<&[u16], TensorError> {
451 match &self.storage {
452 Storage::BF16(v) => Ok(v),
453 _ => Err(TensorError::DTypeMismatch {
454 expected: DType::BF16,
455 got: self.dtype(),
456 }),
457 }
458 }
459
460 pub fn to_dtype(&self, target: DType) -> Self {
463 if self.dtype() == target {
464 return self.clone();
465 }
466 let f32_data = self.to_f32_vec();
467 let storage = match target {
468 DType::F32 => Storage::F32(AlignedVec::from_vec(f32_data)),
469 DType::F16 => Storage::F16(f32_data.iter().map(|&v| f32_to_fp16_bits(v)).collect()),
470 DType::BF16 => Storage::BF16(f32_data.iter().map(|&v| f32_to_bf16_bits(v)).collect()),
471 };
472 Self {
473 shape: self.shape.clone(),
474 strides: self.strides.clone(),
475 storage,
476 device: self.device,
477 }
478 }
479
480 pub(crate) fn to_f32_vec(&self) -> Vec<f32> {
482 match &self.storage {
483 Storage::F32(v) => v.as_slice().to_vec(),
484 Storage::F16(v) => v.iter().map(|&bits| fp16_bits_to_f32(bits)).collect(),
485 Storage::BF16(v) => v.iter().map(|&bits| bf16_bits_to_f32(bits)).collect(),
486 }
487 }
488
489 pub fn get(&self, indices: &[usize]) -> Result<f32, TensorError> {
491 let offset = self.offset_from_indices(indices)?;
492 Ok(match &self.storage {
493 Storage::F32(v) => v[offset],
494 Storage::F16(v) => fp16_bits_to_f32(v[offset]),
495 Storage::BF16(v) => bf16_bits_to_f32(v[offset]),
496 })
497 }
498
499 pub fn set(&mut self, indices: &[usize], value: f32) -> Result<(), TensorError> {
501 let offset = self.offset_from_indices(indices)?;
502 match &mut self.storage {
503 Storage::F32(v) => v[offset] = value,
504 Storage::F16(v) => v[offset] = f32_to_fp16_bits(value),
505 Storage::BF16(v) => v[offset] = f32_to_bf16_bits(value),
506 }
507 Ok(())
508 }
509
510 pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self, TensorError> {
512 let new_count =
513 shape_element_count(&new_shape).ok_or_else(|| TensorError::SizeOverflow {
514 shape: new_shape.clone(),
515 })?;
516 if new_count != self.len() {
517 return Err(TensorError::ReshapeSizeMismatch {
518 from: self.shape.to_vec(),
519 to: new_shape,
520 });
521 }
522
523 let new_strides = compute_strides(&new_shape).ok_or_else(|| TensorError::SizeOverflow {
524 shape: new_shape.clone(),
525 })?;
526
527 Ok(Self {
528 shape: DimsVec::from(new_shape),
529 strides: DimsVec::from(new_strides),
530 storage: self.storage.clone(),
531 device: self.device,
532 })
533 }
534
535 pub fn into_reshape(self, new_shape: Vec<usize>) -> Result<Self, TensorError> {
537 let new_count =
538 shape_element_count(&new_shape).ok_or_else(|| TensorError::SizeOverflow {
539 shape: new_shape.clone(),
540 })?;
541 if new_count != self.len() {
542 return Err(TensorError::ReshapeSizeMismatch {
543 from: self.shape.to_vec(),
544 to: new_shape,
545 });
546 }
547
548 let new_strides = compute_strides(&new_shape).ok_or_else(|| TensorError::SizeOverflow {
549 shape: new_shape.clone(),
550 })?;
551
552 Ok(Self {
553 shape: DimsVec::from(new_shape),
554 strides: DimsVec::from(new_strides),
555 storage: self.storage,
556 device: self.device,
557 })
558 }
559
560 pub(crate) fn offset_from_indices(&self, indices: &[usize]) -> Result<usize, TensorError> {
561 if indices.len() != self.rank() {
562 return Err(TensorError::InvalidIndexRank {
563 expected: self.rank(),
564 got: indices.len(),
565 });
566 }
567
568 let mut offset = 0usize;
569 for (axis, (index, dim)) in indices.iter().zip(self.shape.iter()).enumerate() {
570 if *index >= *dim {
571 return Err(TensorError::IndexOutOfBounds {
572 axis,
573 index: *index,
574 dim: *dim,
575 });
576 }
577 offset = offset
578 .checked_add(index.checked_mul(self.strides[axis]).ok_or_else(|| {
579 TensorError::SizeOverflow {
580 shape: self.shape.to_vec(),
581 }
582 })?)
583 .ok_or_else(|| TensorError::SizeOverflow {
584 shape: self.shape.to_vec(),
585 })?;
586 }
587 Ok(offset)
588 }
589}
590
/// Converts an f32 to IEEE-754 binary16 bits, truncating extra mantissa bits.
///
/// Values below the fp16 subnormal range flush to signed zero; values above
/// the fp16 range become infinity; NaN maps to a quiet fp16 NaN.
///
/// Fix: the subnormal branch previously shifted by `shift + 13`, which
/// flushed representable subnormals to zero and, for inputs with exponent
/// <= 2^-20, used a shift amount >= 32 (a panic in debug builds). The
/// correct shift of the 24-bit significand is `-unbiased - 1` (14..=23).
fn f32_to_fp16_bits(val: f32) -> u16 {
    let bits = val.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    // Infinity / NaN: force the quiet bit for NaN payloads.
    if exponent == 0xFF {
        return sign | 0x7C00 | if mantissa != 0 { 0x0200 } else { 0 };
    }
    let unbiased = exponent - 127;
    // Smaller than the smallest fp16 subnormal (2^-24): signed zero.
    if unbiased < -24 {
        return sign;
    }
    // fp16 subnormal range: value = (mantissa | 2^23) * 2^(unbiased - 23),
    // and fp16 subnormals encode m * 2^-24, so the half mantissa is the
    // 24-bit significand shifted right by (-unbiased - 1) bits.
    if unbiased < -14 {
        let shift = (-unbiased - 1) as u32; // 14..=23, always < 32
        let subnormal = ((mantissa | 0x0080_0000) >> shift) as u16;
        return sign | subnormal;
    }
    // Too large for fp16: overflow to infinity.
    if unbiased > 15 {
        return sign | 0x7C00;
    }
    // Normal range: rebias the exponent and truncate the low 13 mantissa bits.
    let fp16_exp = ((unbiased + 15) as u16) << 10;
    let fp16_man = (mantissa >> 13) as u16;
    sign | fp16_exp | fp16_man
}
618
/// Decodes IEEE-754 binary16 bits to an f32 (every fp16 value is exactly
/// representable in f32, so this is lossless).
///
/// Fix: the subnormal branch used biased exponent `127 - 15 - e`, which is
/// off by one and decoded every fp16 subnormal to half its value (e.g.
/// 0x0001 became 2^-25 instead of 2^-24). After normalizing the mantissa so
/// bit 10 is the implicit leading 1, the value is `1.frac * 2^(-14 - e)`,
/// giving a biased f32 exponent of `127 - 14 - e = 113 - e`.
fn fp16_bits_to_f32(half: u16) -> f32 {
    let sign = ((half & 0x8000) as u32) << 16;
    let exponent = (half >> 10) & 0x1F;
    let mantissa = (half & 0x03FF) as u32;
    if exponent == 0 {
        if mantissa == 0 {
            // Signed zero.
            return f32::from_bits(sign);
        }
        // Subnormal: value = mantissa * 2^-24. Normalize until bit 10 holds
        // the implicit leading 1, counting the shifts in `e`.
        let mut e = 0i32;
        let mut m = mantissa;
        while m & 0x0400 == 0 {
            m <<= 1;
            e += 1;
        }
        let f32_exp = ((113 - e) as u32) << 23;
        let f32_man = (m & 0x03FF) << 13;
        return f32::from_bits(sign | f32_exp | f32_man);
    }
    if exponent == 31 {
        // Infinity / NaN; the quiet bit is forced for NaN payloads.
        let f32_bits = sign | 0x7F80_0000 | if mantissa != 0 { 0x0040_0000 } else { 0 };
        return f32::from_bits(f32_bits);
    }
    // Normal: rebias exponent (15 -> 127, i.e. +112) and widen the mantissa.
    let f32_exp = ((exponent as u32) + 112) << 23;
    let f32_man = mantissa << 13;
    f32::from_bits(sign | f32_exp | f32_man)
}
645
/// Converts an f32 to bfloat16 bits with round-to-nearest-even.
///
/// Fix: NaN is now handled before rounding. Applying the rounding bias to a
/// NaN with a near-full mantissa carried into the exponent field — e.g. bits
/// 0x7FFF_FFFF + 0x8000 = 0x8000_7FFF, which truncated to 0x8000 (-0.0),
/// silently turning NaN into zero. NaN keeps its sign and gets the quiet
/// bit forced so the payload never truncates to all-zero mantissa.
fn f32_to_bf16_bits(val: f32) -> u16 {
    let bits = val.to_bits();
    if val.is_nan() {
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Round to nearest, ties to even: bias by 0x7FFF plus the parity of the
    // bit that will become the bf16 LSB.
    let rounding_bias = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(rounding_bias) >> 16) as u16
}
652
/// Decodes bfloat16 bits to f32: bf16 is the top 16 bits of an f32, so
/// widening and shifting left reconstructs the value exactly.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    let widened = u32::from(bits) << 16;
    f32::from_bits(widened)
}
656
657impl std::fmt::Display for Tensor {
660 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
661 let shape = self.shape();
662 let dtype = self.dtype();
663 let n = self.len();
664
665 write!(f, "Tensor({shape:?}, {dtype:?}")?;
666
667 const MAX_SHOW: usize = 6;
669 if n == 0 {
670 write!(f, ", []")?;
671 } else {
672 let vals = self.to_f32_vec();
673 write!(f, ", [")?;
674 if n <= MAX_SHOW {
675 for (i, v) in vals.iter().enumerate() {
676 if i > 0 {
677 write!(f, ", ")?;
678 }
679 write!(f, "{v}")?;
680 }
681 } else {
682 let head = MAX_SHOW / 2;
683 let tail = MAX_SHOW - head;
684 for (i, v) in vals[..head].iter().enumerate() {
685 if i > 0 {
686 write!(f, ", ")?;
687 }
688 write!(f, "{v}")?;
689 }
690 write!(f, ", ...")?;
691 for v in &vals[n - tail..] {
692 write!(f, ", {v}")?;
693 }
694 }
695 write!(f, "]")?;
696 }
697 write!(f, ")")
698 }
699}
700
// Operator sugar for `Tensor`. Each impl delegates to the same-named
// inherent method (defined elsewhere in the crate). The binary element-wise
// ops (`+`, `-`, `*`, `/`) panic via `expect` on shape mismatch, so prefer
// the fallible named methods when the operand shapes are not known to agree.

impl std::ops::Add for &Tensor {
    type Output = Tensor;
    // `&a + &b`: neither operand is consumed.
    fn add(self, rhs: Self) -> Tensor {
        Tensor::add(self, rhs).expect("Tensor + Tensor: shape mismatch")
    }
}

impl std::ops::Add for Tensor {
    type Output = Tensor;
    // `a + b`: by-value form; both operands are moved and dropped.
    fn add(self, rhs: Self) -> Tensor {
        Tensor::add(&self, &rhs).expect("Tensor + Tensor: shape mismatch")
    }
}

impl std::ops::Sub for &Tensor {
    type Output = Tensor;
    fn sub(self, rhs: Self) -> Tensor {
        Tensor::sub(self, rhs).expect("Tensor - Tensor: shape mismatch")
    }
}

impl std::ops::Sub for Tensor {
    type Output = Tensor;
    fn sub(self, rhs: Self) -> Tensor {
        Tensor::sub(&self, &rhs).expect("Tensor - Tensor: shape mismatch")
    }
}

impl std::ops::Mul for &Tensor {
    type Output = Tensor;
    fn mul(self, rhs: Self) -> Tensor {
        Tensor::mul(self, rhs).expect("Tensor * Tensor: shape mismatch")
    }
}

impl std::ops::Mul for Tensor {
    type Output = Tensor;
    fn mul(self, rhs: Self) -> Tensor {
        Tensor::mul(&self, &rhs).expect("Tensor * Tensor: shape mismatch")
    }
}

// Scalar scaling (`tensor * f32`) delegates to `Tensor::scale` and cannot
// fail, so no `expect` is needed.
impl std::ops::Mul<f32> for &Tensor {
    type Output = Tensor;
    fn mul(self, rhs: f32) -> Tensor {
        Tensor::scale(self, rhs)
    }
}

impl std::ops::Mul<f32> for Tensor {
    type Output = Tensor;
    fn mul(self, rhs: f32) -> Tensor {
        Tensor::scale(&self, rhs)
    }
}

impl std::ops::Div for &Tensor {
    type Output = Tensor;
    fn div(self, rhs: Self) -> Tensor {
        Tensor::div(self, rhs).expect("Tensor / Tensor: shape mismatch")
    }
}

impl std::ops::Div for Tensor {
    type Output = Tensor;
    fn div(self, rhs: Self) -> Tensor {
        Tensor::div(&self, &rhs).expect("Tensor / Tensor: shape mismatch")
    }
}

// Unary negation delegates to `Tensor::neg`; infallible.
impl std::ops::Neg for &Tensor {
    type Output = Tensor;
    fn neg(self) -> Tensor {
        Tensor::neg(self)
    }
}

impl std::ops::Neg for Tensor {
    type Output = Tensor;
    fn neg(self) -> Tensor {
        Tensor::neg(&self)
    }
}