Skip to main content

tract_data/
tensor.rs

1//! `Tensor`, tract main data object of interest.
2use crate::TVec;
3use crate::blob::Blob;
4use crate::datum::{ClampCast, Datum, DatumType, QParams, round_ties_to_even, scale_by};
5use crate::dim::TDim;
6use crate::internal::*;
7use crate::opaque::Opaque;
8use half::f16;
9use itertools::{Itertools, izip};
10use ndarray::prelude::*;
11#[cfg(feature = "complex")]
12use num_complex::Complex;
13use num_traits::{Float, Zero};
14use std::borrow::Cow;
15use std::fmt;
16use std::hash::Hash;
17use std::ops::Range;
18use std::sync::Arc;
19
20pub mod dense_view;
21pub mod litteral;
22pub mod storage;
23pub mod view;
24
25pub use dense_view::{DenseView, DenseViewMut};
26use storage::{DenseStorage, StorageKind};
27
/// Tolerance level used when comparing tensors.
///
/// Each level is turned into concrete `(atol, rtol, outliers)` values per datum type by
/// `Approximation::atol_rtol_outliers`.
#[derive(Copy, Clone, Default, Debug)]
pub enum Approximation {
    /// Zero tolerance: values must match exactly.
    Exact,
    /// Tight tolerance; the default level.
    #[default]
    Close,
    /// Looser tolerance (for quantized types: one quantization step).
    Approximate,
    VeryApproximate,
    SuperApproximate,
    UltraApproximate,
    /// Explicit `(atol, rtol, outliers)` triple.
    Custom(f32, f32, f32),
}

impl PartialEq for Approximation {
    /// Two `Custom` levels are equal when all three values compare equal; any other pair is
    /// equal exactly when both sides are the same variant.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (
                Approximation::Custom(lhs_atol, lhs_rtol, lhs_out),
                Approximation::Custom(rhs_atol, rhs_rtol, rhs_out),
            ) => lhs_atol == rhs_atol && lhs_rtol == rhs_rtol && lhs_out == rhs_out,
            _ => std::mem::discriminant(self) == std::mem::discriminant(other),
        }
    }
}

// NOTE(review): `Custom` compares `f32`s, so a `Custom` holding NaN is not equal to itself.
impl Eq for Approximation {}

impl From<bool> for Approximation {
    /// `true` maps to `Approximate`, `false` to `Exact`.
    fn from(approximate: bool) -> Self {
        match approximate {
            true => Self::Approximate,
            false => Self::Exact,
        }
    }
}
58
59impl Approximation {
60    fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
61        use Approximation::*;
62        match (self, dt) {
63            (Exact, _) => (0.0, 0.0, 0.0),
64            (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
65            (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
66            (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
67            (Close, _) => (1e-7, 1e-7, 0.0),
68            (Approximate, _) => (1e-4, 5e-4, 0.0),
69            (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
70            (SuperApproximate, _) => (0.1, 0.05, 0.0001),
71            (UltraApproximate, _) => (0.2, 0.1, 0.0005),
72            (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
73        }
74    }
75}
76
/// Tensor is a concrete tensor in tract.
///
/// Elements of datum type `dt` live in `storage`; `shape`, `strides` and `len` describe
/// their geometry.
pub struct Tensor {
    /// Datum type of every element.
    dt: DatumType,
    /// Size of each axis; empty for a scalar.
    shape: TVec<usize>,
    /// Per-axis strides, counted in elements (not bytes); empty for a scalar.
    strides: TVec<isize>,
    /// Number of elements (1 for a scalar).
    len: usize,
    /// Backing buffer; may be dense or another storage kind (see `as_dense`).
    storage: StorageKind,
}
85
// SAFETY (review): these impls assert that a `Tensor` can be sent to and shared between
// threads. `Tensor` holds its `storage` by value; soundness relies on every storage kind and
// element type (notably `Opaque`) being thread-safe — TODO confirm.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
88
89impl Hash for Tensor {
90    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
91        use DatumType::*;
92        self.dt.hash(state);
93        self.shape.hash(state);
94        self.dense_storage().layout().align().hash(state);
95        unsafe {
96            match self.dt {
97                Bool => self.as_slice_unchecked::<bool>().hash(state),
98                I8 => self.as_slice_unchecked::<i8>().hash(state),
99                I16 => self.as_slice_unchecked::<i16>().hash(state),
100                I32 => self.as_slice_unchecked::<i32>().hash(state),
101                I64 => self.as_slice_unchecked::<i64>().hash(state),
102                U8 => self.as_slice_unchecked::<u8>().hash(state),
103                U16 => self.as_slice_unchecked::<u16>().hash(state),
104                U32 => self.as_slice_unchecked::<u32>().hash(state),
105                U64 => self.as_slice_unchecked::<u64>().hash(state),
106                F16 => self.as_slice_unchecked::<i16>().hash(state),
107                F32 => self.as_slice_unchecked::<i32>().hash(state),
108                F64 => self.as_slice_unchecked::<i64>().hash(state),
109                TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
110                String => self.as_slice_unchecked::<std::string::String>().hash(state),
111                Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
112                Opaque => self.as_slice_unchecked::<crate::opaque::Opaque>().hash(state),
113                QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
114                QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
115                QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
116                #[cfg(feature = "complex")]
117                ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
118                #[cfg(feature = "complex")]
119                ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
120                #[cfg(feature = "complex")]
121                ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
122                #[cfg(feature = "complex")]
123                ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
124                #[cfg(feature = "complex")]
125                ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
126                #[cfg(feature = "complex")]
127                ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
128            }
129        }
130    }
131}
132
133impl Clone for Tensor {
134    fn clone(&self) -> Tensor {
135        self.deep_clone()
136    }
137}
138
139impl Default for Tensor {
140    fn default() -> Tensor {
141        litteral::tensor0(0f32)
142    }
143}
144
145impl Drop for Tensor {
146    fn drop(&mut self) {
147        macro_rules! drop_in_place {
148            ($t: ty) => {
149                if self.dt == <$t>::datum_type() {
150                    unsafe {
151                        let slice = self.as_slice_mut_unchecked::<$t>();
152                        std::ptr::drop_in_place(slice as *mut [$t]);
153                    }
154                }
155            };
156        }
157        drop_in_place!(Blob);
158        drop_in_place!(String);
159        drop_in_place!(TDim);
160        drop_in_place!(Opaque);
161    }
162}
163
/// Default buffer alignment in bytes, sized after the widest SIMD register expected.
///
/// On `x86_64` the result is 64 when AVX-512F is detected at runtime and 32 otherwise.
/// On every other architecture it is 16.
pub fn vector_size() -> usize {
    #[cfg(target_arch = "x86_64")]
    let register_bits: usize = if is_x86_feature_detected!("avx512f") { 512 } else { 256 };
    #[cfg(not(target_arch = "x86_64"))]
    let register_bits: usize = 128;
    register_bits / 8
}
172
173impl Tensor {
174    #[inline]
175    fn dense_storage(&self) -> &DenseStorage {
176        self.storage.as_dense().expect("Non-dense storage")
177    }
178
179    #[inline]
180    fn dense_storage_mut(&mut self) -> &mut DenseStorage {
181        self.storage.as_dense_mut().expect("Non-dense storage")
182    }
183
184    /// Returns an immutable [`DenseView`] if this tensor has dense storage.
185    #[inline]
186    pub fn as_dense(&self) -> Option<DenseView<'_>> {
187        let storage = self.storage.as_dense()?;
188        Some(DenseView::new(self, storage))
189    }
190
191    /// Returns an immutable [`DenseView`], or an error if storage is not dense.
192    #[inline]
193    pub fn try_as_dense(&self) -> TractResult<DenseView<'_>> {
194        self.as_dense().context("Tensor storage is not dense")
195    }
196
197    /// Returns a mutable [`DenseViewMut`] if this tensor has dense storage.
198    #[inline]
199    pub fn as_dense_mut(&mut self) -> Option<DenseViewMut<'_>> {
200        let storage = self.storage.as_dense_mut()?;
201        Some(DenseViewMut::new(self.dt, &self.shape, &self.strides, self.len, storage))
202    }
203
204    /// Returns a mutable [`DenseViewMut`], or an error if storage is not dense.
205    #[inline]
206    pub fn try_as_dense_mut(&mut self) -> TractResult<DenseViewMut<'_>> {
207        self.as_dense_mut().context("Tensor storage is not dense")
208    }
209
210    /// Create an uninitialized tensor (dt as type paramater).
211    #[inline]
212    pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
213        unsafe { Self::uninitialized_dt(T::datum_type(), shape) }
214    }
215
216    /// Create an uninitialized tensor (dt as regular parameter).
217    #[inline]
218    pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
219        unsafe { Self::uninitialized_aligned_dt(dt, shape, vector_size()) }
220    }
221
222    /// Create an uninitialized tensor with a given alignment (in bytes).
223    #[inline]
224    pub unsafe fn uninitialized_aligned<T: Datum>(
225        shape: &[usize],
226        alignment: usize,
227    ) -> TractResult<Tensor> {
228        unsafe { Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment) }
229    }
230
231    /// Create an uninitialized tensor with a given alignment (in bytes).
232    pub unsafe fn uninitialized_aligned_dt(
233        dt: DatumType,
234        shape: &[usize],
235        alignment: usize,
236    ) -> TractResult<Tensor> {
237        let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
238        let storage = StorageKind::Dense(DenseStorage::from(unsafe {
239            Blob::new_for_size_and_align(bytes, alignment)
240        }));
241        let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), storage, len: 0 };
242        if tensor.shape.len() == 0 {
243            tensor.len = 1;
244        } else {
245            tensor.update_strides_and_len();
246        }
247        if !tensor.storage.is_empty() {
248            if dt == String::datum_type() || dt == Blob::datum_type() {
249                // assumes zero-initialized string and blob are valid
250                tensor.dense_storage_mut().as_bytes_mut().fill(0);
251            } else if dt == TDim::datum_type() {
252                unsafe {
253                    tensor
254                        .as_slice_mut_unchecked::<TDim>()
255                        .iter_mut()
256                        .for_each(|dim| std::ptr::write(dim, TDim::zero()))
257                }
258            } else if dt == Opaque::datum_type() {
259                unsafe {
260                    tensor.as_slice_mut_unchecked::<Opaque>().iter_mut().for_each(|p| {
261                        std::ptr::write(p, Opaque::default());
262                    })
263                };
264            } else if cfg!(debug_assertions) {
265                assert!(dt.is_copy());
266                if dt == DatumType::F32 {
267                    tensor.fill_t(f32::NAN).unwrap();
268                } else {
269                    // safe, non copy types have been dealt with
270                    tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
271                }
272            }
273        }
274        Ok(tensor)
275    }
276
    /// Concatenate `tensors` along `axis` into a new tensor.
    ///
    /// All inputs must share rank, datum type, and size on every axis except `axis`. The
    /// output's `axis` dimension is the sum of the inputs' `axis` dimensions.
    ///
    /// # Errors
    /// Fails on an empty list, an out-of-range `axis`, or any rank, datum type or
    /// non-`axis` dimension mismatch.
    pub fn stack_tensors(
        axis: usize,
        tensors: &[impl std::borrow::Borrow<Tensor>],
    ) -> TractResult<Tensor> {
        ensure!(tensors.len() > 0);
        let rank = tensors[0].borrow().rank();
        ensure!(axis < rank);
        ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
        let dt = tensors[0].borrow().datum_type();
        ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
        let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
        for ax in 0..rank {
            if ax != axis {
                ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
            }
        }
        shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
        unsafe {
            let mut result = Tensor::uninitialized_dt(dt, &shape)?;
            // Fast path: with a copy type and every axis before `axis` of size 1, each input
            // maps to one contiguous byte run of the output, so inputs are copied back to
            // back. NOTE(review): assumes inputs have natural (contiguous) strides.
            if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
                let mut offset = 0isize;
                for v in tensors {
                    let v = v.borrow();
                    let len = v.storage.byte_len();
                    std::ptr::copy_nonoverlapping(
                        v.dense_storage().as_ptr(),
                        result.dense_storage_mut().as_mut_ptr().offset(offset),
                        len,
                    );
                    offset += len as isize;
                }
            } else {
                // General path: copy each input into its own slice of the output along `axis`.
                let mut offset = 0;
                for t in tensors {
                    let t = t.borrow();
                    let len = t.shape()[axis];
                    result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
                    offset += len;
                }
            }

            Ok(result)
        }
    }
321
322    pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
323        self.fill_t(T::zero())
324    }
325
326    pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
327        unsafe {
328            let mut t = Tensor::uninitialized::<T>(shape)?;
329            t.clear::<T>()?;
330            Ok(t)
331        }
332    }
333
334    pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
335        Tensor::zero::<T>(&[])
336    }
337
338    pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
339        Tensor::zero_dt(dt, &[])
340    }
341
342    pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
343        Tensor::zero_aligned_dt(dt, shape, vector_size())
344    }
345
346    pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
347        self.try_as_dense_mut()?
348            .as_slice_mut::<T>()?
349            .iter_mut()
350            .for_each(|item| *item = value.clone());
351        Ok(())
352    }
353
354    pub fn zero_aligned_dt(
355        dt: DatumType,
356        shape: &[usize],
357        alignment: usize,
358    ) -> TractResult<Tensor> {
359        if shape.iter().product::<usize>() == 0 {
360            unsafe { return Tensor::uninitialized_dt(dt, shape) };
361        }
362        if dt.is_quantized() {
363            unsafe {
364                let mut t = Tensor::uninitialized_dt(dt, shape)?;
365                let zp = dt.zp_scale().0;
366                match dt.unquantized() {
367                    DatumType::I8 => t
368                        .try_as_dense_mut()?
369                        .as_slice_mut::<i8>()?
370                        .iter_mut()
371                        .for_each(|item| *item = zp as _),
372                    DatumType::U8 => t
373                        .try_as_dense_mut()?
374                        .as_slice_mut::<u8>()?
375                        .iter_mut()
376                        .for_each(|item| *item = zp as _),
377                    DatumType::I32 => t
378                        .try_as_dense_mut()?
379                        .as_slice_mut::<i32>()?
380                        .iter_mut()
381                        .for_each(|item| *item = zp as _),
382                    _ => unreachable!(),
383                }
384                Ok(t)
385            }
386        } else {
387            dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
388        }
389    }
390
391    pub fn zero_aligned<T: Datum + num_traits::Zero>(
392        shape: &[usize],
393        alignment: usize,
394    ) -> TractResult<Tensor> {
395        unsafe {
396            let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
397            tensor.clear::<T>()?;
398            Ok(tensor)
399        }
400    }
401
402    /// Create a tensor with a given shape and a slice of elements.
403    /// The data is copied and aligned to size of T.
404    pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
405        Self::from_shape_align(shape, data, vector_size())
406    }
407
408    /// Create a tensor with a given shape and a slice of elements.
409    /// The data is copied and aligned to given alignment.
410    pub fn from_shape_align<T: Datum + Copy>(
411        shape: &[usize],
412        data: &[T],
413        align: usize,
414    ) -> TractResult<Tensor> {
415        ensure!(
416            data.len() == shape.iter().product::<usize>(),
417            "Shape product must be equal to data length"
418        );
419        unsafe {
420            let bytes = std::slice::from_raw_parts(
421                data.as_ptr() as *const u8,
422                data.len() * T::datum_type().size_of(),
423            );
424            let dt = T::datum_type();
425            Self::from_raw_dt_align(dt, shape, bytes, align)
426        }
427    }
428
429    /// Create a tensor from raw data.
430    ///
431    /// It copies the data, aligning it to the size of T.
432    pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
433        unsafe { Tensor::from_raw_dt(T::datum_type(), shape, content) }
434    }
435
436    pub unsafe fn from_raw_aligned<T: Datum>(
437        shape: &[usize],
438        content: &[u8],
439        align: usize,
440    ) -> TractResult<Tensor> {
441        unsafe { Tensor::from_raw_dt_align(T::datum_type(), shape, content, align) }
442    }
443
444    pub unsafe fn from_raw_dt(
445        dt: DatumType,
446        shape: &[usize],
447        content: &[u8],
448    ) -> TractResult<Tensor> {
449        unsafe { Self::from_raw_dt_align(dt, shape, content, vector_size()) }
450    }
451
452    pub unsafe fn from_raw_dt_align(
453        dt: DatumType,
454        shape: &[usize],
455        content: &[u8],
456        align: usize,
457    ) -> TractResult<Tensor> {
458        let mut tensor = unsafe { Tensor::uninitialized_aligned_dt(dt, shape, align) }?;
459        tensor.as_bytes_mut().copy_from_slice(content);
460        Ok(tensor)
461    }
462
463    pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
464        let bytes = if content.len() == 0 {
465            &[]
466        } else {
467            unsafe {
468                std::slice::from_raw_parts(
469                    content.as_ptr() as *const u8,
470                    content.len() * T::datum_type().size_of(),
471                )
472            }
473        };
474        unsafe { Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align) }
475    }
476
477    /// Get the number of dimensions (or axes) of the tensor.
478    #[inline]
479    pub fn rank(&self) -> usize {
480        self.shape.len()
481    }
482
483    /// Get the shape of the tensor.
484    #[inline]
485    pub fn shape(&self) -> &[usize] {
486        &self.shape
487    }
488
489    /// Get the number of values in the tensor.
490    #[inline]
491    #[allow(clippy::len_without_is_empty)]
492    pub fn len(&self) -> usize {
493        self.len
494    }
495
496    /// Get the number of valeus in the tensor.
497    #[inline]
498    #[allow(clippy::len_without_is_empty)]
499    pub fn volume(&self) -> usize {
500        self.len
501    }
502
503    /// Get the shape of the tensor.
504    #[inline]
505    pub fn strides(&self) -> &[isize] {
506        &self.strides
507    }
508
509    fn update_strides_and_len(&mut self) {
510        self.strides.clear();
511        if self.shape.len() == 0 {
512            self.len = 1;
513            return;
514        }
515        compute_natural_stride_to(&mut self.strides, &self.shape);
516        self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
517    }
518
519    /// Force the tensor shape, no consistency check.
520    pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
521        if shape != &*self.shape {
522            self.shape.clear();
523            self.shape.extend_from_slice(shape);
524            self.update_strides_and_len();
525        }
526    }
527
528    /// Force the tensor shape and strides, no consistency check.
529    pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
530        self.shape.clear();
531        self.shape.extend_from_slice(shape);
532        self.strides.clear();
533        self.strides.extend_from_slice(strides);
534    }
535
536    /// Force the tensor shape.
537    pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
538        if self.len() != shape.iter().product::<usize>() {
539            bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
540        }
541        unsafe { self.set_shape_unchecked(shape) }
542        Ok(())
543    }
544
545    pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
546        ensure!(axes.iter().duplicates().next().is_none());
547        ensure!(axes.iter().all(|a| *a < self.rank()));
548        unsafe {
549            #[inline]
550            unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
551                unsafe { input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor() }
552            }
553            let dt = self.datum_type();
554            let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
555            t.set_datum_type(dt);
556            Ok(t)
557        }
558    }
559
560    pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
561        let mut permutation: Vec<usize> = (0..self.rank()).collect();
562        permutation.remove(from);
563        permutation.insert(to, from);
564        self.permute_axes(&permutation)
565    }
566
567    pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
568        let removed = self.shape.remove(axis + 1);
569        self.shape[axis] *= removed;
570        self.update_strides_and_len();
571        self
572    }
573
574    pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
575        if self.shape[axis] % outer_dim != 0 {
576            bail!(
577                "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
578                self.shape,
579                axis,
580                outer_dim
581            );
582        }
583        self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
584        self.shape[axis] = outer_dim;
585        self.update_strides_and_len();
586        Ok(self)
587    }
588
589    /// Reshape the tensor to `shape`.
590    pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
591        self.set_shape(shape)?;
592        Ok(self)
593    }
594
595    pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
596        self.shape.insert(axis, 1);
597        self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
598        Ok(())
599    }
600
601    pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
602        ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
603        self.shape.remove(axis);
604        self.strides.remove(axis);
605        Ok(())
606    }
607
608    pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
609        self.broadcast_to_rank(rank)?;
610        self.update_strides_and_len();
611        Ok(self)
612    }
613
614    pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
615        if rank < self.rank() {
616            bail!("Can only broadcast to higher rank")
617        }
618        while self.shape.len() < rank {
619            self.shape.insert(0, 1)
620        }
621        self.update_strides_and_len();
622        Ok(())
623    }
624
625    pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
626        if self.rank() > 0 {
627            bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
628        }
629        unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
630            unsafe {
631                let value: &T = src.to_scalar_unchecked::<T>();
632                dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone())
633            };
634        }
635        unsafe {
636            let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
637            dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
638            Ok(t)
639        }
640    }
641
642    fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
643        unsafe {
644            let view = self.to_array_view_unchecked::<T>();
645            let mut output = view
646                .broadcast(shape)
647                .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
648                .into_owned()
649                .into_tensor();
650            output.set_datum_type(self.datum_type());
651            Ok(output)
652        }
653    }
654
655    pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
656        dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
657    }
658
659    pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
660        ensure!(self.rank() == 1);
661        ensure!(shape[axis] == self.len());
662        if !self.datum_type().is_copy() {
663            let mut vec_shape = vec![1; shape.len()];
664            vec_shape[axis] = self.len();
665            return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
666        }
667        unsafe {
668            let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
669            if output.len() == 0 {
670                return Ok(output);
671            }
672            let inner_len = shape[axis + 1..].iter().product::<usize>();
673
674            unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
675            where
676                T: Datum + Copy,
677            {
678                unsafe {
679                    for ix in 0..input.len() {
680                        let value: T = input.as_slice_unchecked()[ix];
681                        output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
682                            .iter_mut()
683                            .for_each(|item| *item = value);
684                    }
685                }
686            }
687            dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));
688
689            let outer_len = shape[0..axis].iter().product::<usize>();
690            let repeat_bytes_len = inner_len * self.as_bytes().len();
691            let bytes = output.as_bytes_mut();
692            for ix in 1..outer_len {
693                bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
694            }
695
696            Ok(output)
697        }
698    }
699
700    pub fn assign_slice(
701        &mut self,
702        range: impl std::ops::RangeBounds<usize>,
703        src: &Tensor,
704        src_range: impl std::ops::RangeBounds<usize>,
705        axis: usize,
706    ) -> TractResult<()> {
707        ensure!(self.rank() == src.rank());
708        ensure!(axis < self.rank());
709        let range = clip_range_bounds(self.shape[axis], range);
710        let src_range = clip_range_bounds(src.shape[axis], src_range);
711        ensure!(
712            src.datum_type() == self.datum_type(),
713            "Attempt to assign into {:?} from {:?}, datum type mismatch",
714            self.datum_type(),
715            src.datum_type()
716        );
717        ensure!(
718            src_range.len() == range.len(),
719            "Attempt to assign a range of {:?} from a range of {:?}",
720            range,
721            src_range,
722        );
723        ensure!(
724            itertools::izip!(0.., self.shape(), src.shape())
725                .all(|(ix, dst, src)| ix == axis || src == dst),
726            "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
727            axis,
728            self,
729            src
730        );
731        ensure!(
732            src_range.end <= src.shape()[axis],
733            "Assigning from invalid slice (axis {}, {:?}) of {:?}",
734            axis,
735            src_range,
736            src
737        );
738        ensure!(
739            range.end <= self.shape()[axis],
740            "Assigning to invalid slice (axis {}, {:?}) of {:?}",
741            axis,
742            range,
743            self
744        );
745        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
746        Ok(())
747    }
748
749    pub unsafe fn assign_slice_unchecked(
750        &mut self,
751        range: impl std::ops::RangeBounds<usize>,
752        src: &Tensor,
753        src_range: impl std::ops::RangeBounds<usize>,
754        axis: usize,
755    ) {
756        let range = clip_range_bounds(self.shape[axis], range);
757        let src_range = clip_range_bounds(src.shape[axis], src_range);
758        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
759    }
760
    /// Copy `src[.., src_range, ..]` into `self[.., range, ..]` along `axis`, with the ranges
    /// already resolved.
    ///
    /// # Safety
    /// The caller must ensure matching datum types, `range.len() == src_range.len()`,
    /// identical sizes on every other axis, and in-bounds ranges.
    #[allow(clippy::ptr_eq)]
    unsafe fn assign_slice_from_resolved(
        &mut self,
        range: std::ops::Range<usize>,
        src: &Tensor,
        src_range: std::ops::Range<usize>,
        axis: usize,
    ) {
        unsafe {
            use ndarray::Slice;
            // Generic path: ndarray slice assignment, for any datum type.
            unsafe fn assign_slice_t<T: Datum>(
                to: &mut Tensor,
                to_range: Range<usize>,
                from: &Tensor,
                from_range: Range<usize>,
                axis: usize,
            ) {
                unsafe {
                    to.to_array_view_mut_unchecked::<T>()
                        .slice_axis_mut(Axis(axis), Slice::from(to_range))
                        .assign(
                            &from
                                .to_array_view_unchecked::<T>()
                                .slice_axis(Axis(axis), Slice::from(from_range)),
                        )
                }
            }
            // Fast path: for copy types with every axis before `axis` of size 1, both slices
            // are single contiguous byte runs. The byte size of one step along `axis` is taken
            // from `self` and reused for `src`.
            // NOTE(review): this assumes both tensors have natural strides and equal inner dims.
            if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
                let stride = self.strides[axis] as usize * self.datum_type().size_of();
                let dst_start = (stride * range.start) as isize;
                let src_start = (stride * src_range.start) as isize;
                let len = stride * range.len();
                if len > 0 {
                    // Distinct base pointers: non-overlapping memcpy. Same base pointer: fall
                    // back to `ptr::copy`, which tolerates overlap.
                    if self.dense_storage().as_ptr() != src.dense_storage().as_ptr() {
                        std::ptr::copy_nonoverlapping(
                            src.dense_storage().as_ptr().offset(src_start),
                            self.dense_storage_mut().as_mut_ptr().offset(dst_start),
                            len,
                        );
                    } else {
                        std::ptr::copy(
                            src.dense_storage().as_ptr().offset(src_start),
                            self.dense_storage_mut().as_mut_ptr().offset(dst_start),
                            len,
                        );
                    }
                }
            } else {
                dispatch_datum!(assign_slice_t(self.datum_type())(
                    self, range, src, src_range, axis
                ));
            }
        }
    }
815
816    /// Get the datum type of the tensor.
817    #[inline]
818    pub fn datum_type(&self) -> DatumType {
819        self.dt
820    }
821
822    /// Set the datum type of the tensor.
823    #[inline]
824    pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
825        self.dt = dt
826    }
827
828    /// Dump the tensor in a human readable form.
829    ///
830    /// `force_full` will force the tensor to be dump in full even if it is big.
831    pub fn dump(&self, force_full: bool) -> TractResult<String> {
832        unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
833            unsafe {
834                if let Some(qp) = tensor.datum_type().qparams() {
835                    let integers = tensor.cast_to::<i32>().unwrap();
836                    integers.as_slice_unchecked::<i32>()[0..n]
837                        .iter()
838                        .map(|x| format!("[{}]({})", x, qp.dq(*x)))
839                        .join(", ")
840                } else {
841                    tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
842                }
843            }
844        }
845        unsafe {
846            let trunc = self.len() > 12 && !force_full;
847            let data = dispatch_datum!(dump_t(self.datum_type())(
848                self,
849                if trunc { 12 } else { self.len() }
850            ));
851            Ok(format!(
852                "{},{:?} {}{}",
853                self.shape.iter().join(","),
854                self.dt,
855                data,
856                if trunc { "..." } else { "" }
857            ))
858        }
859    }
860
861    /// Compare two tensors, allowing for rounding errors.
862    pub fn close_enough(
863        &self,
864        other: &Self,
865        approx: impl Into<Approximation> + std::fmt::Debug,
866    ) -> TractResult<()> {
867        let approx = approx.into();
868        if self.shape() != other.shape() {
869            bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
870        }
871        let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
872        let ma = self.cast_to::<f32>()?;
873        let ma = ma.to_dense_array_view::<f32>()?;
874        let mb = other.cast_to::<f32>()?;
875        let mb = mb.to_dense_array_view::<f32>()?;
876        let mut first_outlier = None;
877        let mut outliers_count = 0;
878        ndarray::indices_of(&ma).into_iter().for_each(|indices| {
879            let a = ma[&indices];
880            let b = mb[&indices];
881            if !((a.is_nan() && b.is_nan())
882                || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
883                || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
884            {
885                if outliers_count == 0 {
886                    first_outlier = Some(indices.as_array_view().to_vec());
887                }
888                outliers_count += 1;
889            }
890        });
891        if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
892            let indices = first_outlier.unwrap();
893            let a = ma[&*indices];
894            let b = mb[&*indices];
895            bail!(
896                "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
897                approx,
898                self.datum_type(),
899                indices,
900                a,
901                b,
902                outliers_count,
903                self.volume(),
904                outliers_count as f64 / self.volume() as f64,
905                outliers
906            );
907        }
908        Ok(())
909    }
910
911    /// Transform the tensor into a `ndarray::Array`.
912    pub fn into_dense_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
913        Ok(self.to_dense_array_view::<D>()?.to_owned())
914    }
915
916    /// Transform the tensor into a `ndarray::Array`.
917    pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
918        unsafe { self.to_array_view_unchecked::<D>().to_owned() }
919    }
920
921    /// Returns a dense array view of the tensor.
922    ///
923    /// Errors if the storage is not dense or the datum type does not match `D`.
924    #[inline]
925    pub fn to_dense_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
926        self.try_as_dense()?.to_array_view::<D>()
927    }
928
929    /// Returns a mutable dense array view of the tensor.
930    ///
931    /// Errors if the storage is not dense or the datum type does not match `D`.
932    #[inline]
933    pub fn to_dense_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
934        self.check_for_access::<D>()?;
935        ensure!(self.storage.as_dense_mut().is_some(), "Tensor storage is not dense");
936        unsafe { Ok(self.to_array_view_mut_unchecked()) }
937    }
938
939    fn check_for_access<D: Datum>(&self) -> TractResult<()> {
940        ensure!(
941            self.datum_type().unquantized() == D::datum_type().unquantized(),
942            "Tensor datum type error: tensor is {:?}, accessed as {:?}",
943            self.datum_type(),
944            D::datum_type(),
945        );
946        Ok(())
947    }
948
949    /// Transform the data as a `ndarray::Array`.
950    pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
951        if self.len() != 0 {
952            unsafe {
953                ArrayViewD::from_shape_ptr(&*self.shape, self.dense_storage().as_ptr() as *const D)
954            }
955        } else {
956            ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
957        }
958    }
959
960    /// Transform the data as a mutable `ndarray::Array`.
961    pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
962        if self.len() != 0 {
963            unsafe {
964                let ptr = self.dense_storage_mut().as_mut_ptr() as *mut D;
965                ArrayViewMutD::from_shape_ptr(&*self.shape, ptr)
966            }
967        } else {
968            ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
969        }
970    }
971
972    /// Access the data as a pointer.
973    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
974        self.check_for_access::<D>()?;
975        Ok(self.dense_storage().as_ptr() as *const D)
976    }
977
978    /// Access the data as a pointer.
979    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
980        self.dense_storage().as_ptr() as *const D
981    }
982
983    /// Access the data as a pointer.
984    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
985        self.dense_storage_mut().as_mut_ptr() as *mut D
986    }
987
988    /// Access the data as a mutable pointer.
989    pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
990        self.as_ptr::<D>().map(|p| p as *mut D)
991    }
992
993    /// Access the data as a slice.
994    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
995        if self.storage.byte_len() == 0 {
996            &[]
997        } else {
998            unsafe { std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len()) }
999        }
1000    }
1001
1002    /// Access the data as a mutable slice.
1003    pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
1004        if self.storage.byte_len() == 0 {
1005            &mut []
1006        } else {
1007            unsafe { std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len()) }
1008        }
1009    }
1010
1011    /// Make the tensor a scalar tensor (assumes it contains a single value).
1012    pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
1013        fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
1014            Ok(litteral::tensor0(t.try_as_dense()?.to_scalar::<D>()?.clone()))
1015        }
1016        dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
1017    }
1018
1019    /// Access the data as a scalar.
1020    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
1021        unsafe { &*(self.dense_storage().as_ptr() as *const D) }
1022    }
1023
1024    /// Mutable access the data as a scalar.
1025    pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
1026        self.check_for_access::<D>()?;
1027        if self.len() == 0 {
1028            bail!("to_scalar_mut called on empty tensor ({:?})", self)
1029        }
1030        if self.len() > 1 {
1031            bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1032        }
1033        unsafe { Ok(self.to_scalar_mut_unchecked()) }
1034    }
1035
1036    /// Mutable access the data as a scalar.
1037    pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
1038        unsafe { &mut *(self.dense_storage_mut().as_mut_ptr() as *mut D) }
1039    }
1040
1041    pub fn as_bytes(&self) -> &[u8] {
1042        self.dense_storage().as_bytes()
1043    }
1044
1045    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
1046        self.dense_storage_mut().as_bytes_mut()
1047    }
1048
1049    unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
1050        let slice = unsafe { self.as_slice_unchecked::<T>() };
1051        slice[1..].iter().all(|x| x == &slice[0])
1052    }
1053
1054    pub fn is_uniform(&self) -> bool {
1055        if self.len() <= 1 {
1056            return true;
1057        }
1058        unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
1059    }
1060
1061    unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
1062        let v: T = unsafe { self.as_slice_unchecked::<T>() }[0].clone();
1063        litteral::tensor0(v)
1064    }
1065
1066    pub fn as_uniform(&self) -> Option<Tensor> {
1067        if self.len() >= 1 && self.is_uniform() {
1068            unsafe {
1069                let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
1070                t.set_datum_type(self.datum_type());
1071                Some(t)
1072            }
1073        } else {
1074            None
1075        }
1076    }
1077
1078    pub fn is_all_zero(&self) -> TractResult<bool> {
1079        Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1080    }
1081
1082    pub fn is_zero(&self) -> TractResult<bool> {
1083        Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
1084    }
1085
1086    unsafe fn natural_cast<
1087        Source: Datum + num_traits::AsPrimitive<Target>,
1088        Target: Datum + Copy,
1089    >(
1090        &self,
1091        other: &mut Tensor,
1092    ) {
1093        unsafe {
1094            self.as_slice_unchecked::<Source>()
1095                .iter()
1096                .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1097                .for_each(|(s, d)| *d = s.as_())
1098        };
1099    }
1100
1101    unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1102        unsafe {
1103            self.as_slice_unchecked::<Source>()
1104                .iter()
1105                .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1106                .for_each(|(s, d)| *d = !s.is_zero());
1107        }
1108    }
1109
1110    unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
1111        &self,
1112        other: &mut Tensor,
1113    ) -> TractResult<()> {
1114        unsafe {
1115            for (s, d) in self
1116                .as_slice_unchecked::<String>()
1117                .iter()
1118                .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1119            {
1120                *d = s
1121                    .parse()
1122                    .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
1123            }
1124            Ok(())
1125        }
1126    }
1127
1128    unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1129        unsafe {
1130            for (s, d) in self
1131                .as_slice_unchecked::<Source>()
1132                .iter()
1133                .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1134            {
1135                *d = s.to_string()
1136            }
1137        }
1138    }
1139
1140    /// Optionnaly convert data to a tensor for a new DatumType.
1141    pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<'_, Tensor>> {
1142        self.cast_to_dt(D::datum_type())
1143    }
1144
1145    /// Optionnaly convert data to a tensor for a new DatumType.
1146    #[allow(clippy::redundant_closure_call)]
1147    pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<'_, Tensor>> {
1148        unsafe {
1149            if self.dt == dst_dt {
1150                return Ok(Cow::Borrowed(self));
1151            }
1152            if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
1153                let slice = self.as_slice_unchecked::<TDim>();
1154                let mut ints = Self::uninitialized::<i64>(&self.shape)?;
1155                let ints_slice = ints.as_slice_mut_unchecked::<i64>();
1156                for i in 0..self.len() {
1157                    ints_slice[i] = slice[i].to_i64()?;
1158                }
1159                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1160            }
1161            if self.dt == bool::datum_type()
1162                && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
1163            {
1164                let slice = self.as_slice_unchecked::<bool>();
1165                let mut ints = Self::uninitialized::<i8>(&self.shape)?;
1166                let ints_slice = ints.as_slice_mut_unchecked::<i8>();
1167                for i in 0..self.len() {
1168                    ints_slice[i] = slice[i] as usize as i8;
1169                }
1170                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1171            }
1172            let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
1173            if self.dt == DatumType::String {
1174                dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
1175                return Ok(Cow::Owned(result));
1176            }
1177            if dst_dt == DatumType::String {
1178                dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
1179                return Ok(Cow::Owned(result));
1180            }
1181            macro_rules! n {
1182                ($source:ty) => {
1183                    if <$source>::datum_type() == self.datum_type() {
1184                        match dst_dt {
1185                            DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
1186                            DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
1187                            DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
1188                            DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
1189                            DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
1190                            DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
1191                            DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
1192                            DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
1193                            DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
1194                            DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
1195                            DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
1196                            DatumType::TDim => {
1197                                let ints = self.cast_to::<i32>()?;
1198                                let slice = ints.as_slice_unchecked::<i32>();
1199                                let result = result.as_slice_mut_unchecked::<TDim>();
1200                                for i in 0..self.len() {
1201                                    result[i] = slice[i].into();
1202                                }
1203                            }
1204                            DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
1205                            _ => todo!(),
1206                        }
1207                        return Ok(Cow::Owned(result));
1208                    };
1209                };
1210            }
1211            //If there is no quantization
1212            if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
1213                n!(u8);
1214                n!(u16);
1215                n!(u32);
1216                n!(u64);
1217                n!(i8);
1218                n!(i16);
1219                n!(i32);
1220                n!(i64);
1221                n!(f16);
1222                n!(f32);
1223                n!(f64);
1224            } else {
1225                let (s_zp, s_scale) = self.datum_type().zp_scale();
1226                let (d_zp, d_scale) = dst_dt.zp_scale();
1227                if self.datum_type().is_quantized() && dst_dt.is_float() {
1228                    macro_rules! q_to_fp {
1229                        ($source:ty, $dest:ty) => {
1230                            if <$source>::datum_type().unquantized()
1231                                == self.datum_type().unquantized()
1232                                && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1233                            {
1234                                self.as_slice_unchecked::<$source>()
1235                                    .iter()
1236                                    .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1237                                    .for_each(|(&s, d)| {
1238                                        *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
1239                                    });
1240                                return Ok(Cow::Owned(result));
1241                            }
1242                        };
1243                    }
1244                    q_to_fp!(i8, f64);
1245                    q_to_fp!(i8, f32);
1246                    q_to_fp!(u8, f64);
1247                    q_to_fp!(u8, f32);
1248                }
1249                //TODO: optimize scale_by
1250                macro_rules! q8_to_q8 {
1251                    ($typ:ty) => {
1252                        if dst_dt.unquantized() == <$typ>::datum_type() {
1253                            self.as_slice_unchecked::<$typ>()
1254                                .iter()
1255                                .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
1256                                .for_each(|(&s, d)| {
1257                                    *d = (d_zp as i32
1258                                        + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
1259                                    .clamp_cast()
1260                                });
1261                            return Ok(Cow::Owned(result));
1262                        }
1263                    };
1264                }
1265
1266                macro_rules! q_via_f32 {
1267                    ($source:ty, $dest:ty, $round:expr) => {
1268                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1269                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1270                        {
1271                            self.as_slice_unchecked::<$source>()
1272                                .iter()
1273                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1274                                .for_each(|(&s, d)| {
1275                                    let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
1276                                    let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
1277                                    *d = $round(d_float);
1278                                });
1279                            return Ok(Cow::Owned(result));
1280                        }
1281                    };
1282                }
1283
1284                macro_rules! q_n {
1285                    (clamp $source:ty, $dest:ty) => {{
1286                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1287                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1288                        {
1289                            self.as_slice_unchecked::<$source>()
1290                                .iter()
1291                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1292                                .for_each(|(&s, d)| {
1293                                    *d = s.clamp_cast();
1294                                });
1295                            return Ok(Cow::Owned(result));
1296                        }
1297                    }};
1298                    ($source:ty, $dest:ty) => {{
1299                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1300                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1301                        {
1302                            self.as_slice_unchecked::<$source>()
1303                                .iter()
1304                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1305                                .for_each(|(&s, d)| {
1306                                    *d = s as $dest;
1307                                });
1308                            return Ok(Cow::Owned(result));
1309                        }
1310                    }};
1311                }
1312
1313                if dst_dt.unquantized() == self.datum_type().unquantized()
1314                    && dst_dt.is_quantized()
1315                    && self.datum_type().is_quantized()
1316                {
1317                    q8_to_q8!(i8);
1318                    q8_to_q8!(u8);
1319                }
1320
1321                q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
1322                q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
1323                q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
1324                q_via_f32!(i8, f32, |f| f);
1325                q_via_f32!(u8, f32, |f| f);
1326                q_via_f32!(i32, f32, |f| f);
1327
1328                if dst_dt.is_quantized() && self.datum_type().is_quantized() {
1329                    q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
1330                    q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
1331                    q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
1332                    q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
1333                    q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
1334                    q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());
1335
1336                    // ensure cast to different scale offset work
1337                    q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
1338                    q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
1339                }
1340
1341                q_n!(i8, i32);
1342                q_n!(i8, u32);
1343                q_n!(u8, i32);
1344                q_n!(u8, u32);
1345                q_n!(clamp i32, i8);
1346                q_n!(clamp i32, u8);
1347                q_n!(clamp u32, i8);
1348                q_n!(clamp u32, u8);
1349                q_n!(i8, i8);
1350                q_n!(u8, u8);
1351                q_n!(i32, i32);
1352                q_n!(u32, u32);
1353            }
1354
1355            bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
1356        }
1357    }
1358
1359    /// Access the data as a scalar, after a cast.
1360    pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1361        let casted = self.cast_to::<D>()?;
1362        casted.try_as_dense()?.to_scalar::<D>().copied()
1363    }
1364
1365    /// Access the nth element of the tensor, returned as a 0-rank Tensor
1366    pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1367        if nth >= self.len() {
1368            bail!(
1369                "nth called with {}th element on a tensor of len {} ({:?}",
1370                nth,
1371                self.len(),
1372                self
1373            );
1374        }
1375        unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1376            unsafe {
1377                let value = me.as_slice_unchecked::<T>()[nth].clone();
1378                output.as_slice_mut_unchecked::<T>()[0] = value;
1379            }
1380        }
1381        unsafe {
1382            let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1383            dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1384            Ok(output)
1385        }
1386    }
1387
1388    /// Strict equality test on tensors.
1389    fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1390        unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1391            unsafe {
1392                if D::datum_type().is_float() {
1393                    return dispatch_floatlike!(float_eq_t(D::datum_type())(me, other));
1394                }
1395                Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1396                    .all(|(a, b)| a == b))
1397            }
1398        }
1399
1400        unsafe fn float_eq_t<D: Datum + Float>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1401            unsafe {
1402                Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1403                    .all(|(a, b)| (a.is_nan() && b.is_nan()) || a == b))
1404            }
1405        }
1406
1407        unsafe {
1408            Ok(self.datum_type() == other.datum_type()
1409                && self.shape() == other.shape()
1410                && dispatch_datum!(eq_t(self.dt)(self, other))?)
1411        }
1412    }
1413
    /// Build a tensor from an owned `ndarray` array, moving or copying its elements into
    /// freshly allocated tensor storage.
    ///
    /// Three strategies are tried in order:
    /// 1. the array exposes a logical-order slice (`as_slice_mut`): bulk byte copy for
    ///    `is_copy` datum types, element-wise `mem::take` otherwise;
    /// 2. all strides are positive and the array is contiguous in memory order
    ///    (`as_slice_memory_order`): contiguous runs are scattered into the destination with
    ///    `scatter_contig_data`;
    /// 3. fallback: consume the array through its logical-order iterator.
    ///
    /// # Panics
    /// Panics if the destination tensor cannot be allocated.
    fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
        unsafe {
            let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
            if let Some(slice) = it.as_slice_mut() {
                if t.datum_type().is_copy() {
                    // Raw byte copy of the whole buffer (`as_ptr_mut_unchecked` is inferred
                    // as `*mut i8` from the source pointer type).
                    std::ptr::copy_nonoverlapping(
                        slice.as_ptr() as *const i8,
                        t.as_ptr_mut_unchecked(),
                        t.dense_storage().layout().size(),
                    );
                } else {
                    // Move non-`Copy` elements out, leaving `Default` values behind in `it`.
                    t.as_slice_mut_unchecked::<T>()
                        .iter_mut()
                        .zip(slice.iter_mut())
                        .for_each(|(t, s)| *t = std::mem::take(s));
                }
                return t;
            }
            if it.strides().iter().all(|&s| s > 0) && it.as_slice_memory_order().is_some() {
                // (run length, destination stride) pairs, built from the innermost source axis
                // (smallest source stride) outwards.
                let mut len_and_strides: TVec<(usize, usize)> = tvec!();
                for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
                    .sorted_by_key(|(_, src, _)| *src)
                    .map(|(l, _, dst)| (*l as isize, *dst))
                {
                    // Merge with the previous run when this axis is also contiguous with it in
                    // the destination (the source side is contiguous by construction).
                    if !len_and_strides.is_empty()
                        && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
                            == stride as usize
                    {
                        len_and_strides.last_mut().unwrap().0 *= len as usize;
                    } else {
                        len_and_strides.push((len as usize, stride as usize));
                    }
                }
                // Outermost run first.
                len_and_strides.reverse();
                // NOTE(review): for non-`Copy` datum types this relies on `scatter_contig_data`
                // cloning (not bitwise-copying) elements, since `it` is still dropped afterwards.
                // Confirm.
                crate::scatter::scatter_contig_data(
                    it.as_ptr(),
                    t.as_ptr_mut_unchecked(),
                    &len_and_strides,
                );
                return t;
            }
            // finally use ndarray into_iter()
            t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
            t
        }
    }
1460
1461    pub fn deep_clone(&self) -> Tensor {
1462        unsafe {
1463            let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1464            if self.len() > 0 {
1465                if self.dt.is_copy() {
1466                    self.dense_storage().as_ptr().copy_to_nonoverlapping(
1467                        tensor.as_bytes_mut().as_mut_ptr(),
1468                        self.dense_storage().layout().size(),
1469                    )
1470                } else if self.dt == DatumType::String {
1471                    tensor
1472                        .as_slice_mut_unchecked::<String>()
1473                        .clone_from_slice(self.as_slice_unchecked());
1474                } else if self.dt == DatumType::Blob {
1475                    tensor
1476                        .as_slice_mut_unchecked::<Blob>()
1477                        .clone_from_slice(self.as_slice_unchecked());
1478                } else if self.dt == DatumType::Opaque {
1479                    tensor
1480                        .as_slice_mut_unchecked::<Opaque>()
1481                        .clone_from_slice(self.as_slice_unchecked());
1482                } else if self.dt == DatumType::TDim {
1483                    tensor
1484                        .as_slice_mut_unchecked::<TDim>()
1485                        .clone_from_slice(self.as_slice_unchecked());
1486                }
1487            }
1488            tensor
1489        }
1490    }
1491
1492    pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1493        if axis >= self.rank() {
1494            bail!("Can not slice at axis {} tensor {:?}", axis, self);
1495        }
1496        if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1497            bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1498        }
1499        fn slice_t<T: Datum>(
1500            t: &Tensor,
1501            axis: usize,
1502            start: usize,
1503            end: usize,
1504        ) -> TractResult<Tensor> {
1505            Ok(t.to_dense_array_view::<T>()?
1506                .slice_axis(ndarray::Axis(axis), (start..end).into())
1507                .into_owned()
1508                .into_tensor())
1509        }
1510        dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1511    }
1512
1513    #[inline]
1514    pub fn view(&self) -> view::TensorView<'_> {
1515        unsafe { view::TensorView::view(self) }
1516    }
1517
1518    #[inline]
1519    pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
1520        view::TensorView::at_prefix(self, prefix)
1521    }
1522
1523    #[inline]
1524    pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
1525        view::TensorView::offsetting(self, coords)
1526    }
1527
1528    #[inline]
1529    pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView<'_> {
1530        unsafe { view::TensorView::offsetting_unchecked(self, coords) }
1531    }
1532
1533    #[inline]
1534    pub fn view_mut(&mut self) -> view::TensorView<'_> {
1535        unsafe { view::TensorView::view(self) }
1536    }
1537
1538    #[inline]
1539    pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
1540        view::TensorView::at_prefix(self, prefix)
1541    }
1542
1543    #[inline]
1544    pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
1545        view::TensorView::offsetting(self, coords)
1546    }
1547
1548    /// Offsets the tensor as an i8 type if it's an u8 type, otherwise passes it unchanged.
1549    pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1550        let mut t = if let DatumType::U8 = self.dt.unquantized() {
1551            self.try_as_dense()
1552                .unwrap()
1553                .to_array_view::<u8>()
1554                .unwrap()
1555                .mapv(|v| v.wrapping_sub(128) as i8)
1556                .into_tensor()
1557        } else {
1558            return self.clone();
1559        };
1560
1561        if let DatumType::QU8(qp) = self.dt {
1562            if let QParams::ZpScale { zero_point, scale } = qp {
1563                t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1564            } else {
1565                t.dt = DatumType::QI8(qp);
1566            }
1567        }
1568
1569        t.into_arc_tensor()
1570    }
1571
1572    /// Offsets the tensor as an u8 type if it's an i8 type, otherwise passes it unchanged.
1573    pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1574        let mut t = if let DatumType::I8 = self.dt.unquantized() {
1575            self.try_as_dense()
1576                .unwrap()
1577                .to_array_view::<i8>()
1578                .unwrap()
1579                .mapv(|v| (v as u8).wrapping_add(128))
1580                .into_tensor()
1581        } else {
1582            return self.clone();
1583        };
1584
1585        if let DatumType::QI8(qp) = self.dt {
1586            if let QParams::ZpScale { zero_point, scale } = qp {
1587                t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1588            } else {
1589                t.dt = DatumType::QU8(qp);
1590            }
1591        }
1592        t.into_arc_tensor()
1593    }
1594
1595    pub fn to_aligned_default(&self) -> TractResult<Self> {
1596        if self.dt.is_copy() {
1597            unsafe {
1598                let mut t = Self::uninitialized_dt(self.dt, &self.shape)?;
1599                t.as_bytes_mut().copy_from_slice(self.as_bytes());
1600                Ok(t)
1601            }
1602        } else {
1603            let mut t = Self::zero_dt(self.dt, &self.shape)?;
1604            if self.dt == String::datum_type() {
1605                t.try_as_dense_mut()?
1606                    .as_slice_mut::<String>()?
1607                    .clone_from_slice(self.try_as_dense()?.as_slice()?);
1608            } else if self.dt == Blob::datum_type() {
1609                t.try_as_dense_mut()?
1610                    .as_slice_mut::<Blob>()?
1611                    .clone_from_slice(self.try_as_dense()?.as_slice()?);
1612            } else if self.dt == TDim::datum_type() {
1613                t.try_as_dense_mut()?
1614                    .as_slice_mut::<TDim>()?
1615                    .clone_from_slice(self.try_as_dense()?.as_slice()?);
1616            }
1617            Ok(t)
1618        }
1619    }
1620
1621    pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1622        let mut strides = tvec!();
1623        compute_natural_stride_to(&mut strides, shape);
1624        strides
1625    }
1626
1627    pub fn into_blob(mut self) -> TractResult<Blob> {
1628        ensure!(self.dt.is_copy());
1629        let storage =
1630            std::mem::replace(&mut self.storage, StorageKind::Dense(DenseStorage::default()));
1631        Ok(storage.into_dense().context("Storage is not dense")?.into_blob())
1632    }
1633}
1634
1635impl PartialEq for Tensor {
1636    fn eq(&self, other: &Tensor) -> bool {
1637        if self.dt != other.dt || self.shape != other.shape {
1638            return false;
1639        }
1640        self.eq_dt(other).unwrap_or(false)
1641    }
1642}
1643
1644impl Eq for Tensor {}
1645
1646impl fmt::Debug for Tensor {
1647    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1648        let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1649        write!(formatter, "{content}")
1650    }
1651}
1652
/// Reinterprets a tensor whose last axis has length 2 as a tensor of complex numbers.
///
/// The last axis is dropped. Each `[a, b]` pair becomes the complex value `a + b i`. Only the
/// shape and datum type are rewritten; the storage is left as it is.
///
/// # Errors
///
/// Fails if the last axis is missing or not of length 2, or if the datum type has no complex
/// counterpart.
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    ensure!(
        t.shape().last() == Some(&2),
        "The last dimension in the tensor shape {:?} must be 2",
        t.shape()
    );
    // Resolve the target datum type before touching the shape.
    let complex_dt = t.datum_type().complexify()?;
    t.shape.pop();
    unsafe {
        t.set_datum_type(complex_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1667
/// Inverse of `reinterpret_inner_dim_as_complex`: views a complex tensor as its real
/// component type, with a new trailing axis of length 2.
///
/// Only the shape and datum type are rewritten; the storage is left as it is.
///
/// # Errors
///
/// Fails if the datum type has no real counterpart (`decomplexify`).
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    // Resolve the target datum type before touching the shape.
    let real_dt = t.datum_type().decomplexify()?;
    t.shape.push(2);
    unsafe {
        t.set_datum_type(real_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1677
/// Resolves any `RangeBounds<usize>` into a concrete half-open `Range` over a sequence of
/// length `len`.
///
/// An unbounded start becomes `0` and an unbounded end becomes `len`. Exclusive starts and
/// inclusive ends are converted by adding one.
///
/// NOTE(review): despite the name, the result is not clamped to `0..len`. For example,
/// `clip_range_bounds(3, ..=7)` yields `0..8`. `start > end` is not rejected, and
/// `..=usize::MAX` overflows. Callers presumably validate the range themselves.
pub fn clip_range_bounds(len: usize, range: impl std::ops::RangeBounds<usize>) -> Range<usize> {
    use std::ops::Bound::{Excluded, Included, Unbounded};

    let first = match range.start_bound() {
        Unbounded => 0,
        Included(&lo) => lo,
        Excluded(&lo) => lo + 1,
    };
    let past_last = match range.end_bound() {
        Unbounded => len,
        Included(&hi) => hi + 1,
        Excluded(&hi) => hi,
    };
    first..past_last
}
1692
1693pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1694    let mut strides = tvec!();
1695    compute_natural_stride_to(&mut strides, shape);
1696    strides
1697}
1698
1699fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1700    match shape.len() {
1701        0 => (),
1702        1 => strides.push(1),
1703        2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1704        3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1705        4 => strides.extend_from_slice(&[
1706            (shape[1] * shape[2] * shape[3]) as isize,
1707            (shape[2] * shape[3]) as _,
1708            shape[3] as _,
1709            1,
1710        ]),
1711        _ => {
1712            strides.push(1);
1713            for dim in shape.as_ref().iter().skip(1).rev() {
1714                let previous = *strides.last().unwrap();
1715                strides.push(previous * *dim as isize)
1716            }
1717            strides.reverse();
1718        }
1719    }
1720}
1721
1722impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
1723    fn from(it: Array<T, D>) -> Tensor {
1724        Tensor::from_datum(it.into_dyn())
1725    }
1726}
1727
/// Conversion of an owned value into a `Tensor`.
///
/// Implementations in this file cover ndarray `Array`s, `Tensor` itself (identity) and
/// `Arc<Tensor>`.
pub trait IntoTensor: Sized {
    /// Converts `self` into a `Tensor`.
    ///
    /// Implementations may copy the data. For example, the `Arc<Tensor>` implementation
    /// clones the tensor when the `Arc` is shared.
    fn into_tensor(self) -> Tensor;
}
1735
/// Conversion of an owned value into a shared `Arc<Tensor>`.
///
/// Implementations in this file cover ndarray `Array`s, `Tensor` and `Arc<Tensor>` itself
/// (identity).
pub trait IntoArcTensor: Sized {
    /// Converts `self` into an `Arc<Tensor>`.
    ///
    /// Implementations may copy the data while building the tensor.
    fn into_arc_tensor(self) -> Arc<Tensor>;
}
1743
1744impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
1745    fn into_tensor(self) -> Tensor {
1746        Tensor::from(self)
1747    }
1748}
1749
1750impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
1751    fn into_arc_tensor(self) -> Arc<Tensor> {
1752        Arc::new(Tensor::from(self))
1753    }
1754}
1755
/// Identity conversion: a `Tensor` is returned unchanged, without copying.
impl IntoTensor for Tensor {
    fn into_tensor(self) -> Tensor {
        self
    }
}
1761
1762impl IntoTensor for Arc<Tensor> {
1763    fn into_tensor(self) -> Tensor {
1764        Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
1765    }
1766}
1767
1768impl IntoArcTensor for Tensor {
1769    fn into_arc_tensor(self) -> Arc<Tensor> {
1770        Arc::new(self)
1771    }
1772}
1773
/// Identity conversion: an `Arc<Tensor>` is returned as-is (no new allocation, no copy).
impl IntoArcTensor for Arc<Tensor> {
    fn into_arc_tensor(self) -> Arc<Tensor> {
        self
    }
}
1779
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    // Note: this brings the `vec` *strategy function* into scope; `vec![..]` below is still
    // the std macro (macros live in a separate namespace).
    use proptest::collection::vec;
    use proptest::prelude::*;

    /// A random tensor shape together with a random permutation of its axes.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        // Rank in 0..8. Every dimension is in 1..5. The permutation is a shuffled identity
        // of matching length.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            (0..8usize)
                .prop_flat_map(|rank| {
                    let permute: Vec<usize> = (0..rank).collect();
                    (proptest::collection::vec(1..5usize, rank), Just(permute).prop_shuffle())
                })
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        /// An array of `shape` filled with 1, 2, 3, ..., with its axes then permuted.
        /// `permuted_axes` only reorders strides, so memory order no longer matches logical
        /// order; this is the layout the conversion under test must handle.
        fn input(&self) -> ArrayD<i32> {
            let mut i = 0;
            ArrayD::from_shape_simple_fn(&*self.shape, || {
                i += 1;
                i
            })
            .permuted_axes(&*self.permutation)
        }

        /// Expected tensor: the permuted array's elements in logical iteration order, packed
        /// into a fresh tensor of the permuted shape.
        fn reference(&self) -> Tensor {
            let values: Vec<i32> = self.input().iter().copied().collect();
            let shape = self.permutation.iter().map(|ix| self.shape[*ix]).collect::<TVec<usize>>();
            super::litteral::tensor1(&values).into_shape(&shape).unwrap()
        }

        /// Tensor under test: converted directly from the permuted array.
        fn tract(&self) -> Tensor {
            Tensor::from(self.input())
        }

        /// Checks that `Tensor::from` preserves logical element order for permuted arrays.
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    // Property test over random shapes and permutations.
    proptest::proptest! {
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    /// Fixed 2x1 case with transposed axes.
    #[test]
    fn t_1_2() {
        PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] }.check().unwrap();
    }

    /// Fixed 2x2 transpose.
    #[test]
    fn t_2_2() {
        PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
    }

    /// A vector and a target `shape` whose axis `axis` has length `vec.len()`.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        /// Compares `broadcast_vector_to_shape` against a reference. The reference reshapes
        /// the vector to all-ones except along `axis`, then applies the generic
        /// `broadcast_to_shape`.
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reference = input
                .clone()
                .into_shape(&intermediate)
                .unwrap()
                .broadcast_to_shape(&self.shape)
                .unwrap();
            prop_assert_eq!(
                reference,
                input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap()
            );
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        // Draws three things: a base shape of 0 to 3 axes with dimensions in 0..5 (so empty
        // tensors are covered), a vector of 0 to 4 values, and an insertion position in
        // 0..=rank. The vector's length is then inserted into the shape at that position.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1)
                })
                .prop_map(|(vec, mut shape, axis)| {
                    shape.insert(axis, vec.len());
                    BroadcastVecToShape { vec, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    // Property test for `broadcast_vector_to_shape`.
    proptest::proptest! {
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    /// Rank-2 f32 tensor with a trailing axis of 2 becomes a rank-1 complex tensor, with
    /// each pair read as (real, imaginary).
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let input = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    /// Same as above for a rank-3 i32 tensor, producing a rank-2 complex tensor.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let input =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    /// Smoke test: cloning a scalar tensor holding a symbolic `TDim` (a non-copy datum)
    /// must succeed.
    #[test]
    fn clone_tdim_tensor() {
        let symbols = SymbolScope::default();
        let a = symbols.sym("a");
        let t = tensor0(TDim::from(a));
        let _ = t.clone();
    }
}