//! `yscv_tensor/tensor.rs` — core `Tensor` type: contiguous multi-dtype
//! storage with inline shape/strides and FP16/BF16 bit-conversion helpers.

1use super::aligned::AlignedVec;
2use super::error::{DType, TensorError};
3use super::shape::{compute_strides, shape_element_count};
4
// ── Inline shape/strides storage (no heap alloc for ≤6D tensors) ─────────

// WHY 6: covers all common tensor ranks (scalar(0)..conv weight(5)) without heap allocation.
const INLINE_CAP: usize = 6;

/// Stack-allocated small vector for tensor shape/strides.
/// Stores up to 6 dimensions inline; falls back to heap for higher ranks.
#[derive(Clone)]
pub(crate) enum DimsVec {
    Inline { buf: [usize; INLINE_CAP], len: u8 },
    Heap(Vec<usize>),
}

impl DimsVec {
    /// Empty (rank-0) dims, stored inline.
    #[inline]
    fn new() -> Self {
        DimsVec::Inline {
            buf: [0; INLINE_CAP],
            len: 0,
        }
    }

    /// Copies `src` into an inline buffer. Caller guarantees
    /// `src.len() <= INLINE_CAP`.
    #[inline]
    fn inline_from(src: &[usize]) -> Self {
        let mut buf = [0usize; INLINE_CAP];
        buf[..src.len()].copy_from_slice(src);
        DimsVec::Inline {
            buf,
            len: src.len() as u8,
        }
    }

    /// Borrows the live dims as a slice, regardless of storage variant.
    #[inline]
    fn as_slice(&self) -> &[usize] {
        match self {
            DimsVec::Inline { buf, len } => &buf[..usize::from(*len)],
            DimsVec::Heap(dims) => dims.as_slice(),
        }
    }

    /// Returns an owned copy of the dims.
    #[inline]
    fn to_vec(&self) -> Vec<usize> {
        self.as_slice().to_vec()
    }
}

impl std::ops::Deref for DimsVec {
    type Target = [usize];
    #[inline]
    fn deref(&self) -> &[usize] {
        self.as_slice()
    }
}

impl From<Vec<usize>> for DimsVec {
    /// Moves the vector into heap storage when it exceeds the inline
    /// capacity (no copy); otherwise copies it into the inline buffer.
    #[inline]
    fn from(v: Vec<usize>) -> Self {
        if v.len() > INLINE_CAP {
            DimsVec::Heap(v)
        } else {
            DimsVec::inline_from(&v)
        }
    }
}

impl From<&[usize]> for DimsVec {
    #[inline]
    fn from(s: &[usize]) -> Self {
        if s.len() > INLINE_CAP {
            DimsVec::Heap(s.to_vec())
        } else {
            DimsVec::inline_from(s)
        }
    }
}

impl PartialEq for DimsVec {
    /// Compares logical contents: an `Inline` and a `Heap` value holding the
    /// same dims are equal.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.as_slice() == other.as_slice()
    }
}

impl std::fmt::Debug for DimsVec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.as_slice().fmt(f)
    }
}
93
/// Logical device where a tensor resides.
///
/// This is currently a metadata tag only — no actual data transfer occurs.
/// GPU data movement is handled externally (e.g. via `GpuSession` in yscv-kernels).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Device {
    /// CPU (host) memory — the default.
    #[default]
    Cpu,
    /// GPU device with the given index (not validated here — presumably
    /// bounds-checked by the GPU layer; confirm in yscv-kernels).
    Gpu(usize),
}
106
/// Internal typed storage for tensor data.
///
/// F16/BF16 hold raw bit patterns as `u16`; F32 uses `AlignedVec`
/// (presumably for SIMD-friendly alignment — see `aligned.rs`).
#[derive(Debug, Clone)]
pub(crate) enum Storage {
    F32(AlignedVec<f32>),
    F16(Vec<u16>),
    BF16(Vec<u16>),
}
114
115impl PartialEq for Storage {
116    fn eq(&self, other: &Self) -> bool {
117        match (self, other) {
118            (Storage::F32(a), Storage::F32(b)) => a == b,
119            (Storage::F16(a), Storage::F16(b)) => a == b,
120            (Storage::BF16(a), Storage::BF16(b)) => a == b,
121            _ => false,
122        }
123    }
124}
125
126impl Storage {
127    fn len(&self) -> usize {
128        match self {
129            Storage::F32(v) => v.len(),
130            Storage::F16(v) => v.len(),
131            Storage::BF16(v) => v.len(),
132        }
133    }
134
135    fn dtype(&self) -> DType {
136        match self {
137            Storage::F32(_) => DType::F32,
138            Storage::F16(_) => DType::F16,
139            Storage::BF16(_) => DType::BF16,
140        }
141    }
142}
143
/// A compact, contiguous multi-dtype tensor representation.
///
/// Default dtype is `F32`. FP16 and BF16 dtypes are stored natively as `u16` bit patterns
/// and can be created via `from_f16` / `from_bf16` or converted via `to_dtype`.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Dimension sizes (inline for rank ≤ 6).
    shape: DimsVec,
    /// Per-axis strides in *elements*, not bytes (see `offset_from_indices`).
    strides: DimsVec,
    /// Typed element buffer; its variant determines `dtype()`.
    storage: Storage,
    /// Metadata-only device tag; no data movement happens in this crate.
    device: Device,
}
155
impl Tensor {
    /// Builds a scalar tensor from one value.
    ///
    /// Rank is 0 (empty shape and strides); storage holds exactly one element.
    pub fn scalar(value: f32) -> Self {
        Self {
            shape: DimsVec::new(),
            strides: DimsVec::new(),
            storage: Storage::F32(AlignedVec::filled(1, value)),
            device: Device::Cpu,
        }
    }

    /// Creates a tensor from pre-validated shape, strides, and data.
    /// No validation, no heap allocation (shape/strides stored inline for ≤6D).
    ///
    /// The shape/data length invariant is checked in debug builds only;
    /// strides are trusted as given.
    #[inline]
    pub fn from_raw_parts(shape: &[usize], strides: &[usize], data: AlignedVec<f32>) -> Self {
        debug_assert_eq!(
            shape.iter().copied().product::<usize>(),
            data.len(),
            "from_raw_parts: shape product != data.len()"
        );
        Self {
            shape: DimsVec::from(shape),
            strides: DimsVec::from(strides),
            storage: Storage::F32(data),
            device: Device::Cpu,
        }
    }

    /// Builds a tensor from `shape` and a pre-filled [`AlignedVec`].
    ///
    /// This avoids the extra copy that [`from_vec`](Self::from_vec) performs when
    /// converting a `Vec<f32>` into aligned storage.  Use this when the output
    /// buffer was already allocated as an `AlignedVec` (e.g. via
    /// [`AlignedVec::uninitialized`] filled by a BLAS/SIMD kernel).
    ///
    /// # Errors
    /// `SizeOverflow` if the element count overflows `usize`;
    /// `SizeMismatch` if `data.len()` differs from the shape's element count.
    pub fn from_aligned(shape: Vec<usize>, data: AlignedVec<f32>) -> Result<Self, TensorError> {
        let expected = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;
        if expected != data.len() {
            return Err(TensorError::SizeMismatch {
                shape,
                data_len: data.len(),
            });
        }

        let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;

        Ok(Self {
            shape: DimsVec::from(shape),
            strides: DimsVec::from(strides),
            storage: Storage::F32(data),
            device: Device::Cpu,
        })
    }

    /// Builds a tensor from `shape` and raw contiguous `f32` data.
    ///
    /// The data is copied into aligned storage.
    ///
    /// # Errors
    /// `SizeOverflow` if the element count overflows `usize`;
    /// `SizeMismatch` if `data.len()` differs from the shape's element count.
    pub fn from_vec(shape: Vec<usize>, data: Vec<f32>) -> Result<Self, TensorError> {
        let expected = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;
        if expected != data.len() {
            return Err(TensorError::SizeMismatch {
                shape,
                data_len: data.len(),
            });
        }

        let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;

        Ok(Self {
            shape: DimsVec::from(shape),
            strides: DimsVec::from(strides),
            storage: Storage::F32(AlignedVec::from_vec(data)),
            device: Device::Cpu,
        })
    }

    /// Builds a tensor from `shape` and raw FP16 bit patterns (`u16`).
    ///
    /// # Errors
    /// `SizeOverflow` if the element count overflows `usize`;
    /// `SizeMismatch` if `data.len()` differs from the shape's element count.
    pub fn from_f16(shape: Vec<usize>, data: Vec<u16>) -> Result<Self, TensorError> {
        let expected = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;
        if expected != data.len() {
            return Err(TensorError::SizeMismatch {
                shape,
                data_len: data.len(),
            });
        }
        let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;
        Ok(Self {
            shape: DimsVec::from(shape),
            strides: DimsVec::from(strides),
            storage: Storage::F16(data),
            device: Device::Cpu,
        })
    }

    /// Builds a tensor from `shape` and raw BF16 bit patterns (`u16`).
    ///
    /// # Errors
    /// `SizeOverflow` if the element count overflows `usize`;
    /// `SizeMismatch` if `data.len()` differs from the shape's element count.
    pub fn from_bf16(shape: Vec<usize>, data: Vec<u16>) -> Result<Self, TensorError> {
        let expected = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;
        if expected != data.len() {
            return Err(TensorError::SizeMismatch {
                shape,
                data_len: data.len(),
            });
        }
        let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;
        Ok(Self {
            shape: DimsVec::from(shape),
            strides: DimsVec::from(strides),
            storage: Storage::BF16(data),
            device: Device::Cpu,
        })
    }

    /// Builds a 1-D tensor from a slice. Equivalent to `from_vec(vec![data.len()], data.to_vec())`.
    ///
    /// Infallible: a 1-D shape cannot overflow and the length always matches.
    pub fn from_slice(data: &[f32]) -> Self {
        let n = data.len();
        Self {
            shape: DimsVec::from(vec![n]),
            strides: DimsVec::from(vec![1usize]),
            storage: Storage::F32(AlignedVec::from_vec(data.to_vec())),
            device: Device::Cpu,
        }
    }

    /// Builds a value-initialized tensor for a given shape.
    ///
    /// Alias: [`full`](Self::full) is provided as a more familiar name.
    ///
    /// # Errors
    /// `SizeOverflow` if the element count overflows `usize`.
    pub fn filled(shape: Vec<usize>, value: f32) -> Result<Self, TensorError> {
        let count = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;
        let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;

        Ok(Self {
            shape: DimsVec::from(shape),
            strides: DimsVec::from(strides),
            storage: Storage::F32(AlignedVec::filled(count, value)),
            device: Device::Cpu,
        })
    }

    /// Builds a zero-initialized tensor for a given shape.
    ///
    /// Uses `alloc_zeroed` under the hood so the OS can provide pre-zeroed pages
    /// without writing every byte — significantly faster for large tensors.
    ///
    /// # Errors
    /// `SizeOverflow` if the element count overflows `usize`.
    pub fn zeros(shape: Vec<usize>) -> Result<Self, TensorError> {
        let count = shape_element_count(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;
        let strides = compute_strides(&shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: shape.clone(),
        })?;

        Ok(Self {
            shape: DimsVec::from(shape),
            strides: DimsVec::from(strides),
            storage: Storage::F32(AlignedVec::calloc(count)),
            device: Device::Cpu,
        })
    }

    /// Builds a one-initialized tensor for a given shape.
    pub fn ones(shape: Vec<usize>) -> Result<Self, TensorError> {
        Self::filled(shape, 1.0)
    }

    /// Builds a tensor filled with `value`. Alias for [`filled`](Self::filled).
    pub fn full(shape: Vec<usize>, value: f32) -> Result<Self, TensorError> {
        Self::filled(shape, value)
    }

    /// Returns the tensor shape.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the tensor strides (in elements).
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Returns tensor rank (number of axes).
    pub fn rank(&self) -> usize {
        self.shape.len()
    }

    /// Returns element count (a rank-0 scalar reports 1).
    pub fn len(&self) -> usize {
        self.storage.len()
    }

    /// Returns `true` if tensor contains zero elements.
    pub fn is_empty(&self) -> bool {
        self.storage.len() == 0
    }

    /// Returns the element data type.
    pub fn dtype(&self) -> DType {
        self.storage.dtype()
    }

    /// Returns the logical device this tensor is associated with.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Returns a copy of this tensor tagged with the given device.
    ///
    /// This is currently a metadata-only operation — no actual data transfer
    /// occurs. GPU data movement is handled by `GpuSession` in yscv-kernels.
    pub fn to_device(&self, device: Device) -> Self {
        Self {
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            storage: self.storage.clone(),
            device,
        }
    }

    /// Returns `true` if the tensor stores f32 data.
    pub fn is_f32(&self) -> bool {
        matches!(self.storage, Storage::F32(_))
    }

    /// Fallible version of `data()` — returns an error for non-F32 tensors.
    ///
    /// # Errors
    /// `DTypeMismatch` if the tensor is not F32.
    pub fn try_data(&self) -> Result<&[f32], TensorError> {
        match &self.storage {
            Storage::F32(v) => Ok(v),
            _ => Err(TensorError::DTypeMismatch {
                expected: DType::F32,
                got: self.dtype(),
            }),
        }
    }

    /// Fallible version of `data_mut()` — returns an error for non-F32 tensors.
    ///
    /// # Errors
    /// `DTypeMismatch` if the tensor is not F32.
    pub fn try_data_mut(&mut self) -> Result<&mut [f32], TensorError> {
        // Capture dtype up front: the mutable borrow of storage below would
        // otherwise conflict with calling `self.dtype()` in the error arm.
        let dt = self.storage.dtype();
        match &mut self.storage {
            Storage::F32(v) => Ok(v),
            _ => Err(TensorError::DTypeMismatch {
                expected: DType::F32,
                got: dt,
            }),
        }
    }

    /// Returns an immutable view over contiguous f32 storage.
    ///
    /// # Panics
    /// Panics if the tensor dtype is not F32. Use `try_data()` for a fallible version.
    pub fn data(&self) -> &[f32] {
        self.try_data().expect("tensor is not F32")
    }

    /// Returns a mutable view over contiguous f32 storage.
    ///
    /// # Panics
    /// Panics if the tensor dtype is not F32. Use `try_data_mut()` for a fallible version.
    pub fn data_mut(&mut self) -> &mut [f32] {
        self.try_data_mut().expect("tensor is not F32")
    }

    /// Alias for `try_data()` for backward compatibility.
    pub fn try_data_f32(&self) -> Result<&[f32], TensorError> {
        self.try_data()
    }

    /// Returns raw FP16 bit-pattern data if dtype is F16.
    ///
    /// # Errors
    /// `DTypeMismatch` if the tensor is not F16.
    pub fn data_f16(&self) -> Result<&[u16], TensorError> {
        match &self.storage {
            Storage::F16(v) => Ok(v),
            _ => Err(TensorError::DTypeMismatch {
                expected: DType::F16,
                got: self.dtype(),
            }),
        }
    }

    /// Returns raw BF16 bit-pattern data if dtype is BF16.
    ///
    /// # Errors
    /// `DTypeMismatch` if the tensor is not BF16.
    pub fn data_bf16(&self) -> Result<&[u16], TensorError> {
        match &self.storage {
            Storage::BF16(v) => Ok(v),
            _ => Err(TensorError::DTypeMismatch {
                expected: DType::BF16,
                got: self.dtype(),
            }),
        }
    }

    /// Converts this tensor to the specified dtype, returning a new tensor.
    /// Converting to the same dtype is a no-op clone.
    ///
    /// The conversion goes through f32, so narrowing to F16/BF16 is lossy.
    pub fn to_dtype(&self, target: DType) -> Self {
        if self.dtype() == target {
            return self.clone();
        }
        let f32_data = self.to_f32_vec();
        let storage = match target {
            DType::F32 => Storage::F32(AlignedVec::from_vec(f32_data)),
            DType::F16 => Storage::F16(f32_data.iter().map(|&v| f32_to_fp16_bits(v)).collect()),
            DType::BF16 => Storage::BF16(f32_data.iter().map(|&v| f32_to_bf16_bits(v)).collect()),
        };
        Self {
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            storage,
            device: self.device,
        }
    }

    /// Returns f32 data regardless of internal dtype (converts if necessary).
    /// Always allocates a fresh `Vec`, even for F32 storage.
    pub(crate) fn to_f32_vec(&self) -> Vec<f32> {
        match &self.storage {
            Storage::F32(v) => v.as_slice().to_vec(),
            Storage::F16(v) => v.iter().map(|&bits| fp16_bits_to_f32(bits)).collect(),
            Storage::BF16(v) => v.iter().map(|&bits| bf16_bits_to_f32(bits)).collect(),
        }
    }

    /// Reads one element by multi-dimensional index (always returns f32).
    ///
    /// # Errors
    /// `InvalidIndexRank`, `IndexOutOfBounds`, or `SizeOverflow` as produced
    /// by `offset_from_indices`.
    pub fn get(&self, indices: &[usize]) -> Result<f32, TensorError> {
        let offset = self.offset_from_indices(indices)?;
        Ok(match &self.storage {
            Storage::F32(v) => v[offset],
            Storage::F16(v) => fp16_bits_to_f32(v[offset]),
            Storage::BF16(v) => bf16_bits_to_f32(v[offset]),
        })
    }

    /// Writes one element by multi-dimensional index (stores as native dtype).
    ///
    /// # Errors
    /// `InvalidIndexRank`, `IndexOutOfBounds`, or `SizeOverflow` as produced
    /// by `offset_from_indices`.
    pub fn set(&mut self, indices: &[usize], value: f32) -> Result<(), TensorError> {
        let offset = self.offset_from_indices(indices)?;
        match &mut self.storage {
            Storage::F32(v) => v[offset] = value,
            Storage::F16(v) => v[offset] = f32_to_fp16_bits(value),
            Storage::BF16(v) => v[offset] = f32_to_bf16_bits(value),
        }
        Ok(())
    }

    /// Returns a reshaped tensor view with copied metadata and data.
    ///
    /// # Errors
    /// `SizeOverflow` if the new element count overflows `usize`;
    /// `ReshapeSizeMismatch` if it differs from the current element count.
    pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self, TensorError> {
        let new_count =
            shape_element_count(&new_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: new_shape.clone(),
            })?;
        if new_count != self.len() {
            return Err(TensorError::ReshapeSizeMismatch {
                from: self.shape.to_vec(),
                to: new_shape,
            });
        }

        let new_strides = compute_strides(&new_shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: new_shape.clone(),
        })?;

        Ok(Self {
            shape: DimsVec::from(new_shape),
            strides: DimsVec::from(new_strides),
            storage: self.storage.clone(),
            device: self.device,
        })
    }

    /// Consumes the tensor and returns a reshaped version without copying data.
    ///
    /// # Errors
    /// Same as [`reshape`](Self::reshape).
    pub fn into_reshape(self, new_shape: Vec<usize>) -> Result<Self, TensorError> {
        let new_count =
            shape_element_count(&new_shape).ok_or_else(|| TensorError::SizeOverflow {
                shape: new_shape.clone(),
            })?;
        if new_count != self.len() {
            return Err(TensorError::ReshapeSizeMismatch {
                from: self.shape.to_vec(),
                to: new_shape,
            });
        }

        let new_strides = compute_strides(&new_shape).ok_or_else(|| TensorError::SizeOverflow {
            shape: new_shape.clone(),
        })?;

        Ok(Self {
            shape: DimsVec::from(new_shape),
            strides: DimsVec::from(new_strides),
            storage: self.storage,
            device: self.device,
        })
    }

    /// Maps multi-dimensional `indices` to a flat storage offset via the
    /// overflow-checked dot product `Σ indices[axis] * strides[axis]`.
    ///
    /// # Errors
    /// `InvalidIndexRank` if `indices.len() != rank()`; `IndexOutOfBounds`
    /// for an index past its axis extent; `SizeOverflow` if the offset
    /// arithmetic overflows.
    pub(crate) fn offset_from_indices(&self, indices: &[usize]) -> Result<usize, TensorError> {
        if indices.len() != self.rank() {
            return Err(TensorError::InvalidIndexRank {
                expected: self.rank(),
                got: indices.len(),
            });
        }

        let mut offset = 0usize;
        for (axis, (index, dim)) in indices.iter().zip(self.shape.iter()).enumerate() {
            if *index >= *dim {
                return Err(TensorError::IndexOutOfBounds {
                    axis,
                    index: *index,
                    dim: *dim,
                });
            }
            offset = offset
                .checked_add(index.checked_mul(self.strides[axis]).ok_or_else(|| {
                    TensorError::SizeOverflow {
                        shape: self.shape.to_vec(),
                    }
                })?)
                .ok_or_else(|| TensorError::SizeOverflow {
                    shape: self.shape.to_vec(),
                })?;
        }
        Ok(offset)
    }
}
590
591// ── FP16/BF16 bit conversion primitives ────────────────────────────
592
/// Converts an `f32` to an IEEE 754 binary16 (half) bit pattern.
///
/// Mantissa bits are truncated (round-toward-zero), matching the normal-path
/// behavior used throughout this file (note: the BF16 path rounds to nearest
/// even instead). NaN/Inf map to half NaN/Inf, magnitudes above the largest
/// finite half saturate to ±Inf, and values below the smallest half
/// subnormal flush to signed zero.
fn f32_to_fp16_bits(val: f32) -> u16 {
    let bits = val.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007FFFFF;

    // Inf/NaN (exponent all ones). Force a mantissa bit for NaN so the
    // payload cannot collapse to the Inf pattern.
    if exponent == 0xFF {
        return sign | 0x7C00 | if mantissa != 0 { 0x0200 } else { 0 };
    }
    let unbiased = exponent - 127;
    // Below the smallest half subnormal (2^-24): underflow to signed zero.
    if unbiased < -24 {
        return sign;
    }
    // Subnormal range [2^-24, 2^-14): with the implicit bit restored the
    // value is M * 2^(unbiased-23), and the half subnormal ULP is 2^-24,
    // so the bit pattern is M >> (-unbiased - 1). The previous code shifted
    // an extra 13 bits, flushing every subnormal to zero (and shifts >= 32
    // would panic in debug builds).
    if unbiased < -14 {
        let shift = -1 - unbiased; // 14..=23
        let subnormal = ((mantissa | 0x00800000) >> shift) as u16;
        return sign | subnormal;
    }
    // Overflow past the largest finite half: saturate to ±Inf.
    if unbiased > 15 {
        return sign | 0x7C00;
    }
    // Normal: re-bias the exponent (+15) and keep the top 10 mantissa bits.
    let fp16_exp = ((unbiased + 15) as u16) << 10;
    let fp16_man = (mantissa >> 13) as u16;
    sign | fp16_exp | fp16_man
}
618
/// Converts an IEEE 754 binary16 (half) bit pattern to `f32`.
///
/// The conversion is exact for every finite half value: signed zeros keep
/// their sign, subnormals are renormalized into f32 normals, Inf maps to
/// Inf, and NaN is quieted by forcing the high mantissa bit.
fn fp16_bits_to_f32(half: u16) -> f32 {
    let sign = ((half & 0x8000) as u32) << 16;
    let exponent = (half >> 10) & 0x1F;
    let mantissa = (half & 0x03FF) as u32;
    if exponent == 0 {
        if mantissa == 0 {
            // ±0.0 — preserve the sign bit.
            return f32::from_bits(sign);
        }
        // Subnormal: renormalize. After `e` left shifts bit 10 is set and
        // the value equals 1.f * 2^(-14 - e), so the f32 biased exponent is
        // 127 - 14 - e = 113 - e. (The previous `112 - e` was off by one and
        // halved every subnormal.)
        let mut e = 0i32;
        let mut m = mantissa;
        while m & 0x0400 == 0 {
            m <<= 1;
            e += 1;
        }
        let f32_exp = ((113 - e) as u32) << 23;
        let f32_man = (m & 0x03FF) << 13;
        return f32::from_bits(sign | f32_exp | f32_man);
    }
    if exponent == 31 {
        // Inf (mantissa 0) or NaN (mantissa != 0; quiet bit forced,
        // payload dropped).
        let f32_bits = sign | 0x7F800000 | if mantissa != 0 { 0x00400000 } else { 0 };
        return f32::from_bits(f32_bits);
    }
    // Normal: re-bias the exponent (E - 15 + 127 = E + 112), widen mantissa.
    let f32_exp = ((exponent as u32) + 112) << 23;
    let f32_man = mantissa << 13;
    f32::from_bits(sign | f32_exp | f32_man)
}
645
/// Converts an `f32` to a bfloat16 bit pattern with round-to-nearest-even.
///
/// NaN must be handled before rounding: for a NaN whose mantissa bits all
/// sit in the low 16 bits (e.g. `0x7F800001`), the rounding-bias addition
/// would carry into the exponent and yield the infinity pattern `0x7F80`.
fn f32_to_bf16_bits(val: f32) -> u16 {
    let bits = val.to_bits();
    if val.is_nan() {
        // Keep NaN a NaN: take the top half and force a mantissa bit.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Round to nearest even
    let rounding_bias = 0x7FFF + ((bits >> 16) & 1);
    ((bits.wrapping_add(rounding_bias)) >> 16) as u16
}
652
/// Expands a bfloat16 bit pattern to `f32` by placing it in the high 16 bits.
/// Exact for all inputs: BF16 is the truncated top half of an f32.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    let widened = u32::from(bits) << 16;
    f32::from_bits(widened)
}
656
657// ── Display impl ────────────────────────────────────────────────────
658
659impl std::fmt::Display for Tensor {
660    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
661        let shape = self.shape();
662        let dtype = self.dtype();
663        let n = self.len();
664
665        write!(f, "Tensor({shape:?}, {dtype:?}")?;
666
667        // Show first few and last few values for a compact preview.
668        const MAX_SHOW: usize = 6;
669        if n == 0 {
670            write!(f, ", []")?;
671        } else {
672            let vals = self.to_f32_vec();
673            write!(f, ", [")?;
674            if n <= MAX_SHOW {
675                for (i, v) in vals.iter().enumerate() {
676                    if i > 0 {
677                        write!(f, ", ")?;
678                    }
679                    write!(f, "{v}")?;
680                }
681            } else {
682                let head = MAX_SHOW / 2;
683                let tail = MAX_SHOW - head;
684                for (i, v) in vals[..head].iter().enumerate() {
685                    if i > 0 {
686                        write!(f, ", ")?;
687                    }
688                    write!(f, "{v}")?;
689                }
690                write!(f, ", ...")?;
691                for v in &vals[n - tail..] {
692                    write!(f, ", {v}")?;
693                }
694            }
695            write!(f, "]")?;
696        }
697        write!(f, ")")
698    }
699}
700
701// ── Operator overloading (std::ops) ─────────────────────────────────
702
impl std::ops::Add for &Tensor {
    type Output = Tensor;
    /// Element-wise addition. Panics on shape mismatch.
    fn add(self, rhs: Self) -> Tensor {
        Tensor::add(self, rhs).expect("Tensor + Tensor: shape mismatch")
    }
}

impl std::ops::Add for Tensor {
    type Output = Tensor;
    /// Element-wise addition (by value). Panics on shape mismatch.
    fn add(self, rhs: Self) -> Tensor {
        Tensor::add(&self, &rhs).expect("Tensor + Tensor: shape mismatch")
    }
}

impl std::ops::Sub for &Tensor {
    type Output = Tensor;
    /// Element-wise subtraction. Panics on shape mismatch.
    fn sub(self, rhs: Self) -> Tensor {
        Tensor::sub(self, rhs).expect("Tensor - Tensor: shape mismatch")
    }
}

impl std::ops::Sub for Tensor {
    type Output = Tensor;
    /// Element-wise subtraction (by value). Panics on shape mismatch.
    fn sub(self, rhs: Self) -> Tensor {
        Tensor::sub(&self, &rhs).expect("Tensor - Tensor: shape mismatch")
    }
}

impl std::ops::Mul for &Tensor {
    type Output = Tensor;
    /// Element-wise multiplication. Panics on shape mismatch.
    fn mul(self, rhs: Self) -> Tensor {
        Tensor::mul(self, rhs).expect("Tensor * Tensor: shape mismatch")
    }
}

impl std::ops::Mul for Tensor {
    type Output = Tensor;
    /// Element-wise multiplication (by value). Panics on shape mismatch.
    fn mul(self, rhs: Self) -> Tensor {
        Tensor::mul(&self, &rhs).expect("Tensor * Tensor: shape mismatch")
    }
}

impl std::ops::Mul<f32> for &Tensor {
    type Output = Tensor;
    /// Scalar multiplication.
    fn mul(self, rhs: f32) -> Tensor {
        Tensor::scale(self, rhs)
    }
}

impl std::ops::Mul<f32> for Tensor {
    type Output = Tensor;
    /// Scalar multiplication (by value); infallible.
    fn mul(self, rhs: f32) -> Tensor {
        Tensor::scale(&self, rhs)
    }
}

impl std::ops::Div for &Tensor {
    type Output = Tensor;
    /// Element-wise division. Panics on shape mismatch.
    fn div(self, rhs: Self) -> Tensor {
        Tensor::div(self, rhs).expect("Tensor / Tensor: shape mismatch")
    }
}

impl std::ops::Div for Tensor {
    type Output = Tensor;
    /// Element-wise division (by value). Panics on shape mismatch.
    fn div(self, rhs: Self) -> Tensor {
        Tensor::div(&self, &rhs).expect("Tensor / Tensor: shape mismatch")
    }
}

impl std::ops::Neg for &Tensor {
    type Output = Tensor;
    /// Element-wise negation; infallible.
    fn neg(self) -> Tensor {
        Tensor::neg(self)
    }
}

impl std::ops::Neg for Tensor {
    type Output = Tensor;
    /// Element-wise negation (by value); infallible.
    fn neg(self) -> Tensor {
        Tensor::neg(&self)
    }
}