Skip to main content

tract_data/
tensor.rs

1//! `Tensor`, tract main data object of interest.
2use crate::TVec;
3use crate::blob::Blob;
4use crate::datum::{ClampCast, Datum, DatumType, QParams, round_ties_to_even, scale_by};
5use crate::dim::TDim;
6use crate::internal::*;
7use half::f16;
8use itertools::{Itertools, izip};
9use ndarray::prelude::*;
10#[cfg(feature = "complex")]
11use num_complex::Complex;
12use num_traits::{Float, Zero};
13use std::borrow::Cow;
14use std::fmt;
15use std::hash::Hash;
16use std::ops::Range;
17use std::sync::Arc;
18
19pub mod litteral;
20pub mod plain_view;
21pub mod storage;
22pub mod view;
23
24pub use plain_view::{PlainView, PlainViewMut};
25use storage::{PlainStorage, StorageKind, TensorStorage};
26
/// Tolerance preset, resolved into `(atol, rtol, outliers)` thresholds per datum
/// type by `atol_rtol_outliers`.
#[derive(Copy, Clone, Default, Debug)]
pub enum Approximation {
    /// All tolerances are zero.
    Exact,
    /// Default preset.
    #[default]
    Close,
    Approximate,
    VeryApproximate,
    SuperApproximate,
    UltraApproximate,
    /// Explicit `(atol, rtol, outliers)` tolerances.
    Custom(f32, f32, f32),
}

impl PartialEq for Approximation {
    /// Presets compare by variant. `Custom` additionally compares its three
    /// tolerances, treating two NaN tolerances as equal so that equality stays
    /// reflexive, as `Eq` requires.
    fn eq(&self, other: &Self) -> bool {
        use Approximation::Custom;
        // Plain `==` on f32 is false for NaN, which would break `Eq` reflexivity.
        fn same(a: f32, b: f32) -> bool {
            a == b || (a.is_nan() && b.is_nan())
        }
        if let (Custom(aa, ar, ao), Custom(ba, br, bo)) = (self, other) {
            same(*aa, *ba) && same(*ar, *br) && same(*ao, *bo)
        } else {
            std::mem::discriminant(self) == std::mem::discriminant(other)
        }
    }
}

// Sound: `PartialEq` above is reflexive, including for NaN `Custom` tolerances.
impl Eq for Approximation {}
51
52impl From<bool> for Approximation {
53    fn from(b: bool) -> Self {
54        if b { Self::Approximate } else { Self::Exact }
55    }
56}
57
58impl Approximation {
59    fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
60        use Approximation::*;
61        match (self, dt) {
62            (Exact, _) => (0.0, 0.0, 0.0),
63            (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
64            (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
65            (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
66            (Close, _) => (1e-7, 1e-7, 0.0),
67            (Approximate, _) => (1e-4, 5e-4, 0.0),
68            (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
69            (SuperApproximate, _) => (0.1, 0.05, 0.0001),
70            (UltraApproximate, _) => (0.2, 0.1, 0.0005),
71            (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
72        }
73    }
74}
75
76/// Tensor is a concrete tensor in tract.
77pub struct Tensor {
78    dt: DatumType,
79    shape: TVec<usize>,
80    strides: TVec<isize>,
81    len: usize,
82    storage: StorageKind,
83}
84
85unsafe impl Send for Tensor {}
86unsafe impl Sync for Tensor {}
87
88impl Hash for Tensor {
89    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
90        use DatumType::*;
91        self.dt.hash(state);
92        self.shape.hash(state);
93        if let Some(plain) = self.storage.as_plain() {
94            plain.layout().align().hash(state);
95            unsafe {
96                match self.dt {
97                    Bool => self.as_slice_unchecked::<bool>().hash(state),
98                    I8 => self.as_slice_unchecked::<i8>().hash(state),
99                    I16 => self.as_slice_unchecked::<i16>().hash(state),
100                    I32 => self.as_slice_unchecked::<i32>().hash(state),
101                    I64 => self.as_slice_unchecked::<i64>().hash(state),
102                    U8 => self.as_slice_unchecked::<u8>().hash(state),
103                    U16 => self.as_slice_unchecked::<u16>().hash(state),
104                    U32 => self.as_slice_unchecked::<u32>().hash(state),
105                    U64 => self.as_slice_unchecked::<u64>().hash(state),
106                    F16 => self.as_slice_unchecked::<i16>().hash(state),
107                    F32 => self.as_slice_unchecked::<i32>().hash(state),
108                    F64 => self.as_slice_unchecked::<i64>().hash(state),
109                    TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
110                    String => self.as_slice_unchecked::<std::string::String>().hash(state),
111                    Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
112                    QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
113                    QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
114                    QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
115                    #[cfg(feature = "complex")]
116                    ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
117                    #[cfg(feature = "complex")]
118                    ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
119                    #[cfg(feature = "complex")]
120                    ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
121                    #[cfg(feature = "complex")]
122                    ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
123                    #[cfg(feature = "complex")]
124                    ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
125                    #[cfg(feature = "complex")]
126                    ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
127                }
128            }
129        } else {
130            self.storage.dyn_hash(state);
131        }
132    }
133}
134
135impl Clone for Tensor {
136    fn clone(&self) -> Tensor {
137        self.deep_clone()
138    }
139}
140
141impl Default for Tensor {
142    fn default() -> Tensor {
143        litteral::tensor0(0f32)
144    }
145}
146
147impl Drop for Tensor {
148    fn drop(&mut self) {
149        if self.is_plain() {
150            macro_rules! drop_in_place {
151                ($t: ty) => {
152                    if self.dt == <$t>::datum_type() {
153                        unsafe {
154                            let slice = self.as_slice_mut_unchecked::<$t>();
155                            std::ptr::drop_in_place(slice as *mut [$t]);
156                        }
157                    }
158                };
159            }
160            drop_in_place!(Blob);
161            drop_in_place!(String);
162            drop_in_place!(TDim);
163        }
164        // StorageKind::Exotic drops via Box<dyn TensorStorage> automatically
165    }
166}
167
/// Preferred SIMD vector width in bytes, used as the default buffer alignment.
///
/// On x86_64 this is 64 when AVX-512F is detected at runtime and 32 otherwise;
/// every other architecture gets 16.
pub fn vector_size() -> usize {
    #[cfg(target_arch = "x86_64")]
    let bits: usize = if is_x86_feature_detected!("avx512f") { 512 } else { 256 };
    #[cfg(not(target_arch = "x86_64"))]
    let bits: usize = 128;
    bits / 8
}
176
177impl Tensor {
178    #[inline]
179    fn plain_storage(&self) -> &PlainStorage {
180        self.storage.as_plain().expect("Non-plain storage")
181    }
182
183    #[inline]
184    fn plain_storage_mut(&mut self) -> &mut PlainStorage {
185        self.storage.as_plain_mut().expect("Non-plain storage")
186    }
187
188    pub fn storage_as<T: TensorStorage>(&self) -> Option<&T> {
189        self.storage.as_storage().downcast_ref::<T>()
190    }
191
192    pub fn try_storage_as<T: TensorStorage>(&self) -> TractResult<&T> {
193        self.storage_as::<T>().context("Unexpected tensor storage type")
194    }
195
196    pub fn from_storage(
197        dt: DatumType,
198        shape: &[usize],
199        storage: impl TensorStorage + 'static,
200    ) -> Tensor {
201        let len = shape.iter().product::<usize>();
202        let strides = Self::natural_strides(shape);
203        Tensor {
204            dt,
205            shape: shape.into(),
206            strides,
207            len,
208            storage: StorageKind::Exotic(Box::new(storage)),
209        }
210    }
211
212    /// Returns an immutable [`PlainView`] if this tensor has plain storage.
213    #[inline]
214    pub fn as_plain(&self) -> Option<PlainView<'_>> {
215        let storage = self.storage.as_plain()?;
216        Some(PlainView::new(self, storage))
217    }
218
219    /// Returns an immutable [`PlainView`], or an error if storage is not plain.
220    #[inline]
221    pub fn try_as_plain(&self) -> TractResult<PlainView<'_>> {
222        self.as_plain().context("Tensor storage is not plain")
223    }
224
225    /// Returns `true` if this tensor uses plain (contiguous) storage.
226    #[inline]
227    pub fn is_plain(&self) -> bool {
228        self.storage.as_plain().is_some()
229    }
230
231    /// Returns `true` if this tensor uses exotic (non-plain) storage.
232    #[inline]
233    pub fn is_exotic(&self) -> bool {
234        !self.is_plain()
235    }
236
237    /// Build the `ExoticFact` matching this tensor's storage, or `None` for plain tensors.
238    pub fn exotic_fact(&self) -> TractResult<Option<Box<dyn crate::exotic::ExoticFact>>> {
239        self.storage.as_storage().exotic_fact(&self.shape)
240    }
241
242    /// Returns a mutable [`PlainViewMut`] if this tensor has plain storage.
243    #[inline]
244    pub fn as_plain_mut(&mut self) -> Option<PlainViewMut<'_>> {
245        let storage = self.storage.as_plain_mut()?;
246        Some(PlainViewMut::new(self.dt, &self.shape, &self.strides, self.len, storage))
247    }
248
249    /// Returns a mutable [`PlainViewMut`], or an error if storage is not plain.
250    #[inline]
251    pub fn try_as_plain_mut(&mut self) -> TractResult<PlainViewMut<'_>> {
252        self.as_plain_mut().context("Tensor storage is not plain")
253    }
254
255    /// Create an uninitialized tensor (dt as type paramater).
256    #[inline]
257    pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
258        unsafe { Self::uninitialized_dt(T::datum_type(), shape) }
259    }
260
261    /// Create an uninitialized tensor (dt as regular parameter).
262    #[inline]
263    pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
264        unsafe { Self::uninitialized_aligned_dt(dt, shape, vector_size()) }
265    }
266
267    /// Create an uninitialized tensor with a given alignment (in bytes).
268    #[inline]
269    pub unsafe fn uninitialized_aligned<T: Datum>(
270        shape: &[usize],
271        alignment: usize,
272    ) -> TractResult<Tensor> {
273        unsafe { Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment) }
274    }
275
276    /// Create an uninitialized tensor with a given alignment (in bytes).
277    pub unsafe fn uninitialized_aligned_dt(
278        dt: DatumType,
279        shape: &[usize],
280        alignment: usize,
281    ) -> TractResult<Tensor> {
282        let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
283        let storage = StorageKind::Plain(PlainStorage::from(unsafe {
284            Blob::new_for_size_and_align(bytes, alignment)
285        }));
286        let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), storage, len: 0 };
287        if tensor.shape.len() == 0 {
288            tensor.len = 1;
289        } else {
290            tensor.update_strides_and_len();
291        }
292        if !tensor.storage.is_empty() {
293            if dt == String::datum_type() || dt == Blob::datum_type() {
294                // assumes zero-initialized string and blob are valid
295                tensor.plain_storage_mut().as_bytes_mut().fill(0);
296            } else if dt == TDim::datum_type() {
297                unsafe {
298                    tensor
299                        .as_slice_mut_unchecked::<TDim>()
300                        .iter_mut()
301                        .for_each(|dim| std::ptr::write(dim, TDim::zero()))
302                }
303            } else if cfg!(debug_assertions) {
304                assert!(dt.is_copy());
305                if dt == DatumType::F32 {
306                    tensor.fill_t(f32::NAN).unwrap();
307                } else {
308                    // safe, non copy types have been dealt with
309                    tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
310                }
311            }
312        }
313        Ok(tensor)
314    }
315
316    pub fn stack_tensors(
317        axis: usize,
318        tensors: &[impl std::borrow::Borrow<Tensor>],
319    ) -> TractResult<Tensor> {
320        ensure!(tensors.len() > 0);
321        let rank = tensors[0].borrow().rank();
322        ensure!(axis < rank);
323        ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
324        let dt = tensors[0].borrow().datum_type();
325        ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
326        let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
327        for ax in 0..rank {
328            if ax != axis {
329                ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
330            }
331        }
332        shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
333        unsafe {
334            let mut result = Tensor::uninitialized_dt(dt, &shape)?;
335            if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
336                let mut offset = 0isize;
337                for v in tensors {
338                    let v = v.borrow();
339                    let len = v.storage.byte_len();
340                    std::ptr::copy_nonoverlapping(
341                        v.plain_storage().as_ptr(),
342                        result.plain_storage_mut().as_mut_ptr().offset(offset),
343                        len,
344                    );
345                    offset += len as isize;
346                }
347            } else {
348                let mut offset = 0;
349                for t in tensors {
350                    let t = t.borrow();
351                    let len = t.shape()[axis];
352                    result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
353                    offset += len;
354                }
355            }
356
357            Ok(result)
358        }
359    }
360
361    pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
362        self.fill_t(T::zero())
363    }
364
365    pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
366        unsafe {
367            let mut t = Tensor::uninitialized::<T>(shape)?;
368            t.clear::<T>()?;
369            Ok(t)
370        }
371    }
372
373    pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
374        Tensor::zero::<T>(&[])
375    }
376
377    pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
378        Tensor::zero_dt(dt, &[])
379    }
380
381    pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
382        Tensor::zero_aligned_dt(dt, shape, vector_size())
383    }
384
385    pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
386        self.try_as_plain_mut()?
387            .as_slice_mut::<T>()?
388            .iter_mut()
389            .for_each(|item| *item = value.clone());
390        Ok(())
391    }
392
393    pub fn zero_aligned_dt(
394        dt: DatumType,
395        shape: &[usize],
396        alignment: usize,
397    ) -> TractResult<Tensor> {
398        if shape.iter().product::<usize>() == 0 {
399            unsafe { return Tensor::uninitialized_dt(dt, shape) };
400        }
401        if dt.is_quantized() {
402            unsafe {
403                let mut t = Tensor::uninitialized_dt(dt, shape)?;
404                let zp = dt.zp_scale().0;
405                match dt.unquantized() {
406                    DatumType::I8 => t
407                        .try_as_plain_mut()?
408                        .as_slice_mut::<i8>()?
409                        .iter_mut()
410                        .for_each(|item| *item = zp as _),
411                    DatumType::U8 => t
412                        .try_as_plain_mut()?
413                        .as_slice_mut::<u8>()?
414                        .iter_mut()
415                        .for_each(|item| *item = zp as _),
416                    DatumType::I32 => t
417                        .try_as_plain_mut()?
418                        .as_slice_mut::<i32>()?
419                        .iter_mut()
420                        .for_each(|item| *item = zp as _),
421                    _ => unreachable!(),
422                }
423                Ok(t)
424            }
425        } else if dt == DatumType::Bool {
426            let mut t = unsafe { Tensor::uninitialized_dt(dt, shape)? };
427            t.fill_t::<bool>(false)?;
428            Ok(t)
429        } else {
430            dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
431        }
432    }
433
434    pub fn zero_aligned<T: Datum + num_traits::Zero>(
435        shape: &[usize],
436        alignment: usize,
437    ) -> TractResult<Tensor> {
438        unsafe {
439            let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
440            tensor.clear::<T>()?;
441            Ok(tensor)
442        }
443    }
444
445    /// Create a tensor with a given shape and a slice of elements.
446    /// The data is copied and aligned to size of T.
447    pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
448        Self::from_shape_align(shape, data, vector_size())
449    }
450
451    /// Create a tensor with a given shape and a slice of elements.
452    /// The data is copied and aligned to given alignment.
453    pub fn from_shape_align<T: Datum + Copy>(
454        shape: &[usize],
455        data: &[T],
456        align: usize,
457    ) -> TractResult<Tensor> {
458        ensure!(
459            data.len() == shape.iter().product::<usize>(),
460            "Shape product must be equal to data length"
461        );
462        unsafe {
463            let bytes = std::slice::from_raw_parts(
464                data.as_ptr() as *const u8,
465                data.len() * T::datum_type().size_of(),
466            );
467            let dt = T::datum_type();
468            Self::from_raw_dt_align(dt, shape, bytes, align)
469        }
470    }
471
472    /// Create a tensor from raw data.
473    ///
474    /// It copies the data, aligning it to the size of T.
475    pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
476        unsafe { Tensor::from_raw_dt(T::datum_type(), shape, content) }
477    }
478
479    pub unsafe fn from_raw_aligned<T: Datum>(
480        shape: &[usize],
481        content: &[u8],
482        align: usize,
483    ) -> TractResult<Tensor> {
484        unsafe { Tensor::from_raw_dt_align(T::datum_type(), shape, content, align) }
485    }
486
487    pub unsafe fn from_raw_dt(
488        dt: DatumType,
489        shape: &[usize],
490        content: &[u8],
491    ) -> TractResult<Tensor> {
492        unsafe { Self::from_raw_dt_align(dt, shape, content, vector_size()) }
493    }
494
495    pub unsafe fn from_raw_dt_align(
496        dt: DatumType,
497        shape: &[usize],
498        content: &[u8],
499        align: usize,
500    ) -> TractResult<Tensor> {
501        let mut tensor = unsafe { Tensor::uninitialized_aligned_dt(dt, shape, align) }?;
502        tensor.as_bytes_mut().copy_from_slice(content);
503        Ok(tensor)
504    }
505
506    pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
507        let bytes = if content.len() == 0 {
508            &[]
509        } else {
510            unsafe {
511                std::slice::from_raw_parts(
512                    content.as_ptr() as *const u8,
513                    content.len() * T::datum_type().size_of(),
514                )
515            }
516        };
517        unsafe { Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align) }
518    }
519
520    /// Get the number of dimensions (or axes) of the tensor.
521    #[inline]
522    pub fn rank(&self) -> usize {
523        self.shape.len()
524    }
525
526    /// Get the shape of the tensor.
527    #[inline]
528    pub fn shape(&self) -> &[usize] {
529        &self.shape
530    }
531
532    /// Get the number of values in the tensor.
533    #[inline]
534    #[allow(clippy::len_without_is_empty)]
535    pub fn len(&self) -> usize {
536        self.len
537    }
538
539    /// Get the number of valeus in the tensor.
540    #[inline]
541    #[allow(clippy::len_without_is_empty)]
542    pub fn volume(&self) -> usize {
543        self.len
544    }
545
546    /// Get the shape of the tensor.
547    #[inline]
548    pub fn strides(&self) -> &[isize] {
549        &self.strides
550    }
551
552    fn update_strides_and_len(&mut self) {
553        self.strides.clear();
554        if self.shape.len() == 0 {
555            self.len = 1;
556            return;
557        }
558        compute_natural_stride_to(&mut self.strides, &self.shape);
559        self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
560    }
561
562    /// Force the tensor shape, no consistency check.
563    pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
564        if shape != &*self.shape {
565            self.shape.clear();
566            self.shape.extend_from_slice(shape);
567            self.update_strides_and_len();
568        }
569    }
570
571    /// Force the tensor shape and strides, no consistency check.
572    pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
573        self.shape.clear();
574        self.shape.extend_from_slice(shape);
575        self.strides.clear();
576        self.strides.extend_from_slice(strides);
577    }
578
579    /// Force the tensor shape.
580    pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
581        if self.len() != shape.iter().product::<usize>() {
582            bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
583        }
584        unsafe { self.set_shape_unchecked(shape) }
585        Ok(())
586    }
587
588    pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
589        ensure!(axes.iter().duplicates().next().is_none());
590        ensure!(axes.iter().all(|a| *a < self.rank()));
591        unsafe {
592            #[inline]
593            unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
594                unsafe { input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor() }
595            }
596            let dt = self.datum_type();
597            let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
598            t.set_datum_type(dt);
599            Ok(t)
600        }
601    }
602
603    pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
604        let mut permutation: Vec<usize> = (0..self.rank()).collect();
605        permutation.remove(from);
606        permutation.insert(to, from);
607        self.permute_axes(&permutation)
608    }
609
610    pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
611        let removed = self.shape.remove(axis + 1);
612        self.shape[axis] *= removed;
613        self.update_strides_and_len();
614        self
615    }
616
617    pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
618        if self.shape[axis] % outer_dim != 0 {
619            bail!(
620                "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
621                self.shape,
622                axis,
623                outer_dim
624            );
625        }
626        self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
627        self.shape[axis] = outer_dim;
628        self.update_strides_and_len();
629        Ok(self)
630    }
631
632    /// Reshape the tensor to `shape`.
633    pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
634        self.set_shape(shape)?;
635        Ok(self)
636    }
637
638    pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
639        self.shape.insert(axis, 1);
640        self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
641        Ok(())
642    }
643
644    pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
645        ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
646        self.shape.remove(axis);
647        self.strides.remove(axis);
648        Ok(())
649    }
650
651    pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
652        self.broadcast_to_rank(rank)?;
653        self.update_strides_and_len();
654        Ok(self)
655    }
656
657    pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
658        if rank < self.rank() {
659            bail!("Can only broadcast to higher rank")
660        }
661        while self.shape.len() < rank {
662            self.shape.insert(0, 1)
663        }
664        self.update_strides_and_len();
665        Ok(())
666    }
667
668    pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
669        if self.rank() > 0 {
670            bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
671        }
672        unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
673            unsafe {
674                let value: &T = src.to_scalar_unchecked::<T>();
675                dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone())
676            };
677        }
678        unsafe {
679            let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
680            dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
681            Ok(t)
682        }
683    }
684
685    fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
686        unsafe {
687            let view = self.to_array_view_unchecked::<T>();
688            let mut output = view
689                .broadcast(shape)
690                .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
691                .into_owned()
692                .into_tensor();
693            output.set_datum_type(self.datum_type());
694            Ok(output)
695        }
696    }
697
698    pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
699        dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
700    }
701
702    pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
703        ensure!(self.rank() == 1);
704        ensure!(shape[axis] == self.len());
705        if !self.datum_type().is_copy() {
706            let mut vec_shape = vec![1; shape.len()];
707            vec_shape[axis] = self.len();
708            return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
709        }
710        unsafe {
711            let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
712            if output.len() == 0 {
713                return Ok(output);
714            }
715            let inner_len = shape[axis + 1..].iter().product::<usize>();
716
717            unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
718            where
719                T: Datum + Copy,
720            {
721                unsafe {
722                    for ix in 0..input.len() {
723                        let value: T = input.as_slice_unchecked()[ix];
724                        output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
725                            .iter_mut()
726                            .for_each(|item| *item = value);
727                    }
728                }
729            }
730            dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));
731
732            let outer_len = shape[0..axis].iter().product::<usize>();
733            let repeat_bytes_len = inner_len * self.as_bytes().len();
734            let bytes = output.as_bytes_mut();
735            for ix in 1..outer_len {
736                bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
737            }
738
739            Ok(output)
740        }
741    }
742
743    pub fn assign_slice(
744        &mut self,
745        range: impl std::ops::RangeBounds<usize>,
746        src: &Tensor,
747        src_range: impl std::ops::RangeBounds<usize>,
748        axis: usize,
749    ) -> TractResult<()> {
750        ensure!(self.rank() == src.rank());
751        ensure!(axis < self.rank());
752        let range = clip_range_bounds(self.shape[axis], range);
753        let src_range = clip_range_bounds(src.shape[axis], src_range);
754        ensure!(
755            src.datum_type() == self.datum_type(),
756            "Attempt to assign into {:?} from {:?}, datum type mismatch",
757            self.datum_type(),
758            src.datum_type()
759        );
760        ensure!(
761            src_range.len() == range.len(),
762            "Attempt to assign a range of {:?} from a range of {:?}",
763            range,
764            src_range,
765        );
766        ensure!(
767            itertools::izip!(0.., self.shape(), src.shape())
768                .all(|(ix, dst, src)| ix == axis || src == dst),
769            "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
770            axis,
771            self,
772            src
773        );
774        ensure!(
775            src_range.end <= src.shape()[axis],
776            "Assigning from invalid slice (axis {}, {:?}) of {:?}",
777            axis,
778            src_range,
779            src
780        );
781        ensure!(
782            range.end <= self.shape()[axis],
783            "Assigning to invalid slice (axis {}, {:?}) of {:?}",
784            axis,
785            range,
786            self
787        );
788        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
789        Ok(())
790    }
791
792    pub unsafe fn assign_slice_unchecked(
793        &mut self,
794        range: impl std::ops::RangeBounds<usize>,
795        src: &Tensor,
796        src_range: impl std::ops::RangeBounds<usize>,
797        axis: usize,
798    ) {
799        let range = clip_range_bounds(self.shape[axis], range);
800        let src_range = clip_range_bounds(src.shape[axis], src_range);
801        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
802    }
803
    /// Copies `src` restricted to `src_range` along `axis` into `self` restricted to `range`
    /// along `axis`.
    ///
    /// Ranges are used as-is: they are expected to have been resolved against the axis
    /// lengths by the caller.
    ///
    /// # Safety
    ///
    /// No check is performed here. NOTE(review): the caller presumably must guarantee that
    /// `src` has the same datum type as `self`, that both ranges are in bounds and of equal
    /// length, and that both tensors agree on every other axis — confirm.
    // `ptr_eq` is allowed because raw storage pointers are compared with `!=` below.
    #[allow(clippy::ptr_eq)]
    unsafe fn assign_slice_from_resolved(
        &mut self,
        range: std::ops::Range<usize>,
        src: &Tensor,
        src_range: std::ops::Range<usize>,
        axis: usize,
    ) {
        unsafe {
            use ndarray::Slice;
            // Generic fallback: element-wise assignment through ndarray views of type `T`,
            // letting ndarray deal with strides.
            unsafe fn assign_slice_t<T: Datum>(
                to: &mut Tensor,
                to_range: Range<usize>,
                from: &Tensor,
                from_range: Range<usize>,
                axis: usize,
            ) {
                unsafe {
                    to.to_array_view_mut_unchecked::<T>()
                        .slice_axis_mut(Axis(axis), Slice::from(to_range))
                        .assign(
                            &from
                                .to_array_view_unchecked::<T>()
                                .slice_axis(Axis(axis), Slice::from(from_range)),
                        )
                }
            }
            // Fast path for `Copy` data when every axis before `axis` has length 1: the slice
            // along `axis` is then treated as a single run of `stride * range.len()` bytes and
            // moved with one raw copy.
            // NOTE(review): only `self`'s leading dims and strides are inspected; `src` is
            // assumed to share the same layout from `axis` onwards.
            if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
                // Bytes per step along `axis` (computed from `self`, reused for `src`).
                let stride = self.strides[axis] as usize * self.datum_type().size_of();
                let dst_start = (stride * range.start) as isize;
                let src_start = (stride * src_range.start) as isize;
                // Byte count taken from the destination range; `src_range` is assumed to
                // have the same length.
                let len = stride * range.len();
                if len > 0 {
                    if self.plain_storage().as_ptr() != src.plain_storage().as_ptr() {
                        // Distinct buffers: a non-overlapping copy is sufficient.
                        std::ptr::copy_nonoverlapping(
                            src.plain_storage().as_ptr().offset(src_start),
                            self.plain_storage_mut().as_mut_ptr().offset(dst_start),
                            len,
                        );
                    } else {
                        // Same buffer: source and destination ranges may overlap, so use
                        // `ptr::copy` (memmove semantics).
                        std::ptr::copy(
                            src.plain_storage().as_ptr().offset(src_start),
                            self.plain_storage_mut().as_mut_ptr().offset(dst_start),
                            len,
                        );
                    }
                }
            } else {
                // General case: dispatch on the datum type and use the ndarray-based copy.
                dispatch_datum!(assign_slice_t(self.datum_type())(
                    self, range, src, src_range, axis
                ));
            }
        }
    }
858
859    /// Get the datum type of the tensor.
860    #[inline]
861    pub fn datum_type(&self) -> DatumType {
862        self.dt
863    }
864
865    /// Set the datum type of the tensor.
866    #[inline]
867    pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
868        self.dt = dt
869    }
870
871    /// Dump the tensor in a human readable form.
872    ///
873    /// `force_full` will force the tensor to be dump in full even if it is big.
874    pub fn dump(&self, force_full: bool) -> TractResult<String> {
875        if self.is_exotic() {
876            return Ok(format!(
877                "{},{:?} (non-plain storage)",
878                self.shape.iter().join(","),
879                self.dt,
880            ));
881        }
882        unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
883            unsafe {
884                if let Some(qp) = tensor.datum_type().qparams() {
885                    let integers = tensor.cast_to::<i32>().unwrap();
886                    integers.as_slice_unchecked::<i32>()[0..n]
887                        .iter()
888                        .map(|x| format!("[{}]({})", x, qp.dq(*x)))
889                        .join(", ")
890                } else {
891                    tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
892                }
893            }
894        }
895        unsafe {
896            let trunc = self.len() > 12 && !force_full;
897            let data = dispatch_datum!(dump_t(self.datum_type())(
898                self,
899                if trunc { 12 } else { self.len() }
900            ));
901            Ok(format!(
902                "{},{:?} {}{}",
903                self.shape.iter().join(","),
904                self.dt,
905                data,
906                if trunc { "..." } else { "" }
907            ))
908        }
909    }
910
911    /// Compare two tensors, allowing for rounding errors.
912    pub fn close_enough(
913        &self,
914        other: &Self,
915        approx: impl Into<Approximation> + std::fmt::Debug,
916    ) -> TractResult<()> {
917        let approx = approx.into();
918        if self.shape() != other.shape() {
919            bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
920        }
921        let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
922        let ma = self.cast_to::<f32>()?;
923        let ma = ma.to_plain_array_view::<f32>()?;
924        let mb = other.cast_to::<f32>()?;
925        let mb = mb.to_plain_array_view::<f32>()?;
926        let mut first_outlier = None;
927        let mut outliers_count = 0;
928        ndarray::indices_of(&ma).into_iter().for_each(|indices| {
929            let a = ma[&indices];
930            let b = mb[&indices];
931            if !((a.is_nan() && b.is_nan())
932                || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
933                || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
934            {
935                if outliers_count == 0 {
936                    first_outlier = Some(indices.as_array_view().to_vec());
937                }
938                outliers_count += 1;
939            }
940        });
941        if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
942            let indices = first_outlier.unwrap();
943            let a = ma[&*indices];
944            let b = mb[&*indices];
945            bail!(
946                "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
947                approx,
948                self.datum_type(),
949                indices,
950                a,
951                b,
952                outliers_count,
953                self.volume(),
954                outliers_count as f64 / self.volume() as f64,
955                outliers
956            );
957        }
958        Ok(())
959    }
960
961    /// Transform the tensor into a `ndarray::Array`.
962    pub fn into_plain_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
963        Ok(self.to_plain_array_view::<D>()?.to_owned())
964    }
965
966    /// Transform the tensor into a `ndarray::Array`.
967    pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
968        unsafe { self.to_array_view_unchecked::<D>().to_owned() }
969    }
970
971    /// Returns a plain array view of the tensor.
972    ///
973    /// Errors if the storage is not plain or the datum type does not match `D`.
974    #[inline]
975    pub fn to_plain_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
976        self.try_as_plain()?.to_array_view::<D>()
977    }
978
979    /// Returns a mutable plain array view of the tensor.
980    ///
981    /// Errors if the storage is not plain or the datum type does not match `D`.
982    #[inline]
983    pub fn to_plain_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
984        self.check_for_access::<D>()?;
985        ensure!(self.storage.as_plain_mut().is_some(), "Tensor storage is not plain");
986        unsafe { Ok(self.to_array_view_mut_unchecked()) }
987    }
988
989    fn check_for_access<D: Datum>(&self) -> TractResult<()> {
990        ensure!(
991            self.datum_type().unquantized() == D::datum_type().unquantized(),
992            "Tensor datum type error: tensor is {:?}, accessed as {:?}",
993            self.datum_type(),
994            D::datum_type(),
995        );
996        Ok(())
997    }
998
999    /// Transform the data as a `ndarray::Array`.
1000    pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
1001        if self.len() != 0 {
1002            unsafe {
1003                ArrayViewD::from_shape_ptr(&*self.shape, self.plain_storage().as_ptr() as *const D)
1004            }
1005        } else {
1006            ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
1007        }
1008    }
1009
1010    /// Transform the data as a mutable `ndarray::Array`.
1011    pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
1012        if self.len() != 0 {
1013            unsafe {
1014                let ptr = self.plain_storage_mut().as_mut_ptr() as *mut D;
1015                ArrayViewMutD::from_shape_ptr(&*self.shape, ptr)
1016            }
1017        } else {
1018            ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
1019        }
1020    }
1021
1022    /// Access the data as a pointer.
1023    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
1024        self.check_for_access::<D>()?;
1025        Ok(self.plain_storage().as_ptr() as *const D)
1026    }
1027
1028    /// Access the data as a pointer.
1029    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
1030        self.plain_storage().as_ptr() as *const D
1031    }
1032
1033    /// Access the data as a pointer.
1034    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
1035        self.plain_storage_mut().as_mut_ptr() as *mut D
1036    }
1037
1038    /// Access the data as a mutable pointer.
1039    pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
1040        self.as_ptr::<D>().map(|p| p as *mut D)
1041    }
1042
1043    /// Access the data as a slice.
1044    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
1045        if self.storage.byte_len() == 0 {
1046            &[]
1047        } else {
1048            unsafe { std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len()) }
1049        }
1050    }
1051
1052    /// Access the data as a mutable slice.
1053    pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
1054        if self.storage.byte_len() == 0 {
1055            &mut []
1056        } else {
1057            unsafe { std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len()) }
1058        }
1059    }
1060
1061    /// Make the tensor a scalar tensor (assumes it contains a single value).
1062    pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
1063        fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
1064            Ok(litteral::tensor0(t.try_as_plain()?.to_scalar::<D>()?.clone()))
1065        }
1066        dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
1067    }
1068
1069    /// Access the data as a scalar.
1070    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
1071        unsafe { &*(self.plain_storage().as_ptr() as *const D) }
1072    }
1073
1074    /// Mutable access the data as a scalar.
1075    pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
1076        self.check_for_access::<D>()?;
1077        if self.len() == 0 {
1078            bail!("to_scalar_mut called on empty tensor ({:?})", self)
1079        }
1080        if self.len() > 1 {
1081            bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1082        }
1083        unsafe { Ok(self.to_scalar_mut_unchecked()) }
1084    }
1085
1086    /// Mutable access the data as a scalar.
1087    pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
1088        unsafe { &mut *(self.plain_storage_mut().as_mut_ptr() as *mut D) }
1089    }
1090
1091    pub fn as_bytes(&self) -> &[u8] {
1092        self.plain_storage().as_bytes()
1093    }
1094
1095    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
1096        self.plain_storage_mut().as_bytes_mut()
1097    }
1098
1099    unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
1100        let slice = unsafe { self.as_slice_unchecked::<T>() };
1101        slice[1..].iter().all(|x| x == &slice[0])
1102    }
1103
1104    pub fn is_uniform(&self) -> bool {
1105        if self.is_exotic() {
1106            return false;
1107        }
1108        if self.len() <= 1 {
1109            return true;
1110        }
1111        unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
1112    }
1113
1114    unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
1115        let v: T = unsafe { self.as_slice_unchecked::<T>() }[0].clone();
1116        litteral::tensor0(v)
1117    }
1118
1119    pub fn as_uniform(&self) -> Option<Tensor> {
1120        if self.len() >= 1 && self.is_uniform() {
1121            unsafe {
1122                let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
1123                t.set_datum_type(self.datum_type());
1124                Some(t)
1125            }
1126        } else {
1127            None
1128        }
1129    }
1130
1131    pub fn is_all_zero(&self) -> TractResult<bool> {
1132        Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1133    }
1134
1135    pub fn is_zero(&self) -> TractResult<bool> {
1136        Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
1137    }
1138
1139    unsafe fn natural_cast<
1140        Source: Datum + num_traits::AsPrimitive<Target>,
1141        Target: Datum + Copy,
1142    >(
1143        &self,
1144        other: &mut Tensor,
1145    ) {
1146        unsafe {
1147            self.as_slice_unchecked::<Source>()
1148                .iter()
1149                .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1150                .for_each(|(s, d)| *d = s.as_())
1151        };
1152    }
1153
1154    unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1155        unsafe {
1156            self.as_slice_unchecked::<Source>()
1157                .iter()
1158                .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1159                .for_each(|(s, d)| *d = !s.is_zero());
1160        }
1161    }
1162
1163    unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
1164        &self,
1165        other: &mut Tensor,
1166    ) -> TractResult<()> {
1167        unsafe {
1168            for (s, d) in self
1169                .as_slice_unchecked::<String>()
1170                .iter()
1171                .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1172            {
1173                *d = s
1174                    .parse()
1175                    .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
1176            }
1177            Ok(())
1178        }
1179    }
1180
1181    unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1182        unsafe {
1183            for (s, d) in self
1184                .as_slice_unchecked::<Source>()
1185                .iter()
1186                .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1187            {
1188                *d = s.to_string()
1189            }
1190        }
1191    }
1192
1193    /// Optionnaly convert data to a tensor for a new DatumType.
1194    pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<'_, Tensor>> {
1195        self.cast_to_dt(D::datum_type())
1196    }
1197
1198    /// Optionnaly convert data to a tensor for a new DatumType.
1199    #[allow(clippy::redundant_closure_call)]
1200    pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<'_, Tensor>> {
1201        unsafe {
1202            if self.dt == dst_dt {
1203                return Ok(Cow::Borrowed(self));
1204            }
1205            if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
1206                let slice = self.as_slice_unchecked::<TDim>();
1207                let mut ints = Self::uninitialized::<i64>(&self.shape)?;
1208                let ints_slice = ints.as_slice_mut_unchecked::<i64>();
1209                for i in 0..self.len() {
1210                    ints_slice[i] = slice[i].to_i64()?;
1211                }
1212                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1213            }
1214            if self.dt == bool::datum_type()
1215                && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
1216            {
1217                let slice = self.as_slice_unchecked::<bool>();
1218                let mut ints = Self::uninitialized::<i8>(&self.shape)?;
1219                let ints_slice = ints.as_slice_mut_unchecked::<i8>();
1220                for i in 0..self.len() {
1221                    ints_slice[i] = slice[i] as usize as i8;
1222                }
1223                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1224            }
1225            let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
1226            if self.dt == DatumType::String {
1227                dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
1228                return Ok(Cow::Owned(result));
1229            }
1230            if dst_dt == DatumType::String {
1231                dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
1232                return Ok(Cow::Owned(result));
1233            }
1234            macro_rules! n {
1235                ($source:ty) => {
1236                    if <$source>::datum_type() == self.datum_type() {
1237                        match dst_dt {
1238                            DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
1239                            DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
1240                            DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
1241                            DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
1242                            DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
1243                            DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
1244                            DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
1245                            DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
1246                            DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
1247                            DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
1248                            DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
1249                            DatumType::TDim => {
1250                                let ints = self.cast_to::<i32>()?;
1251                                let slice = ints.as_slice_unchecked::<i32>();
1252                                let result = result.as_slice_mut_unchecked::<TDim>();
1253                                for i in 0..self.len() {
1254                                    result[i] = slice[i].into();
1255                                }
1256                            }
1257                            DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
1258                            _ => todo!(),
1259                        }
1260                        return Ok(Cow::Owned(result));
1261                    };
1262                };
1263            }
1264            //If there is no quantization
1265            if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
1266                n!(u8);
1267                n!(u16);
1268                n!(u32);
1269                n!(u64);
1270                n!(i8);
1271                n!(i16);
1272                n!(i32);
1273                n!(i64);
1274                n!(f16);
1275                n!(f32);
1276                n!(f64);
1277            } else {
1278                let (s_zp, s_scale) = self.datum_type().zp_scale();
1279                let (d_zp, d_scale) = dst_dt.zp_scale();
1280                if self.datum_type().is_quantized() && dst_dt.is_float() {
1281                    macro_rules! q_to_fp {
1282                        ($source:ty, $dest:ty) => {
1283                            if <$source>::datum_type().unquantized()
1284                                == self.datum_type().unquantized()
1285                                && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1286                            {
1287                                self.as_slice_unchecked::<$source>()
1288                                    .iter()
1289                                    .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1290                                    .for_each(|(&s, d)| {
1291                                        *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
1292                                    });
1293                                return Ok(Cow::Owned(result));
1294                            }
1295                        };
1296                    }
1297                    q_to_fp!(i8, f64);
1298                    q_to_fp!(i8, f32);
1299                    q_to_fp!(u8, f64);
1300                    q_to_fp!(u8, f32);
1301                }
1302                //TODO: optimize scale_by
1303                macro_rules! q8_to_q8 {
1304                    ($typ:ty) => {
1305                        if dst_dt.unquantized() == <$typ>::datum_type() {
1306                            self.as_slice_unchecked::<$typ>()
1307                                .iter()
1308                                .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
1309                                .for_each(|(&s, d)| {
1310                                    *d = (d_zp as i32
1311                                        + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
1312                                    .clamp_cast()
1313                                });
1314                            return Ok(Cow::Owned(result));
1315                        }
1316                    };
1317                }
1318
1319                macro_rules! q_via_f32 {
1320                    ($source:ty, $dest:ty, $round:expr) => {
1321                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1322                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1323                        {
1324                            self.as_slice_unchecked::<$source>()
1325                                .iter()
1326                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1327                                .for_each(|(&s, d)| {
1328                                    let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
1329                                    let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
1330                                    *d = $round(d_float);
1331                                });
1332                            return Ok(Cow::Owned(result));
1333                        }
1334                    };
1335                }
1336
1337                macro_rules! q_n {
1338                    (clamp $source:ty, $dest:ty) => {{
1339                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1340                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1341                        {
1342                            self.as_slice_unchecked::<$source>()
1343                                .iter()
1344                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1345                                .for_each(|(&s, d)| {
1346                                    *d = s.clamp_cast();
1347                                });
1348                            return Ok(Cow::Owned(result));
1349                        }
1350                    }};
1351                    ($source:ty, $dest:ty) => {{
1352                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1353                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1354                        {
1355                            self.as_slice_unchecked::<$source>()
1356                                .iter()
1357                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1358                                .for_each(|(&s, d)| {
1359                                    *d = s as $dest;
1360                                });
1361                            return Ok(Cow::Owned(result));
1362                        }
1363                    }};
1364                }
1365
1366                if dst_dt.unquantized() == self.datum_type().unquantized()
1367                    && dst_dt.is_quantized()
1368                    && self.datum_type().is_quantized()
1369                {
1370                    q8_to_q8!(i8);
1371                    q8_to_q8!(u8);
1372                }
1373
1374                q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
1375                q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
1376                q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
1377                q_via_f32!(i8, f32, |f| f);
1378                q_via_f32!(u8, f32, |f| f);
1379                q_via_f32!(i32, f32, |f| f);
1380
1381                if dst_dt.is_quantized() && self.datum_type().is_quantized() {
1382                    q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
1383                    q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
1384                    q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
1385                    q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
1386                    q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
1387                    q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());
1388
1389                    // ensure cast to different scale offset work
1390                    q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
1391                    q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
1392                }
1393
1394                q_n!(i8, i32);
1395                q_n!(i8, u32);
1396                q_n!(u8, i32);
1397                q_n!(u8, u32);
1398                q_n!(clamp i32, i8);
1399                q_n!(clamp i32, u8);
1400                q_n!(clamp u32, i8);
1401                q_n!(clamp u32, u8);
1402                q_n!(i8, i8);
1403                q_n!(u8, u8);
1404                q_n!(i32, i32);
1405                q_n!(u32, u32);
1406            }
1407
1408            bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
1409        }
1410    }
1411
1412    /// Access the data as a scalar, after a cast.
1413    pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1414        let casted = self.cast_to::<D>()?;
1415        casted.try_as_plain()?.to_scalar::<D>().copied()
1416    }
1417
1418    /// Access the nth element of the tensor, returned as a 0-rank Tensor
1419    pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1420        if nth >= self.len() {
1421            bail!(
1422                "nth called with {}th element on a tensor of len {} ({:?}",
1423                nth,
1424                self.len(),
1425                self
1426            );
1427        }
1428        unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1429            unsafe {
1430                let value = me.as_slice_unchecked::<T>()[nth].clone();
1431                output.as_slice_mut_unchecked::<T>()[0] = value;
1432            }
1433        }
1434        unsafe {
1435            let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1436            dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1437            Ok(output)
1438        }
1439    }
1440
1441    /// Strict equality test on tensors.
1442    fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1443        unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1444            unsafe {
1445                if D::datum_type().is_float() {
1446                    return dispatch_floatlike!(float_eq_t(D::datum_type())(me, other));
1447                }
1448                Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1449                    .all(|(a, b)| a == b))
1450            }
1451        }
1452
1453        unsafe fn float_eq_t<D: Datum + Float>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1454            unsafe {
1455                Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1456                    .all(|(a, b)| (a.is_nan() && b.is_nan()) || a == b))
1457            }
1458        }
1459
1460        unsafe {
1461            Ok(self.datum_type() == other.datum_type()
1462                && self.shape() == other.shape()
1463                && dispatch_datum!(eq_t(self.dt)(self, other))?)
1464        }
1465    }
1466
1467    fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
1468        unsafe {
1469            let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
1470            if let Some(slice) = it.as_slice_mut() {
1471                if t.datum_type().is_copy() {
1472                    std::ptr::copy_nonoverlapping(
1473                        slice.as_ptr() as *const i8,
1474                        t.as_ptr_mut_unchecked(),
1475                        t.plain_storage().layout().size(),
1476                    );
1477                } else {
1478                    t.as_slice_mut_unchecked::<T>()
1479                        .iter_mut()
1480                        .zip(slice.iter_mut())
1481                        .for_each(|(t, s)| *t = std::mem::take(s));
1482                }
1483                return t;
1484            }
1485            if it.strides().iter().all(|&s| s > 0) && it.as_slice_memory_order().is_some() {
1486                let mut len_and_strides: TVec<(usize, usize)> = tvec!();
1487                for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
1488                    .sorted_by_key(|(_, src, _)| *src)
1489                    .map(|(l, _, dst)| (*l as isize, *dst))
1490                {
1491                    if !len_and_strides.is_empty()
1492                        && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
1493                            == stride as usize
1494                    {
1495                        len_and_strides.last_mut().unwrap().0 *= len as usize;
1496                    } else {
1497                        len_and_strides.push((len as usize, stride as usize));
1498                    }
1499                }
1500                len_and_strides.reverse();
1501                crate::scatter::scatter_contig_data(
1502                    it.as_ptr(),
1503                    t.as_ptr_mut_unchecked(),
1504                    &len_and_strides,
1505                );
1506                return t;
1507            }
1508            // finally use ndarray into_iter()
1509            t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
1510            t
1511        }
1512    }
1513
1514    pub fn deep_clone(&self) -> Tensor {
1515        if self.is_exotic() {
1516            return Tensor {
1517                dt: self.dt,
1518                shape: self.shape.clone(),
1519                strides: self.strides.clone(),
1520                len: self.len,
1521                storage: self.storage.deep_clone(),
1522            };
1523        }
1524        unsafe {
1525            let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1526            if self.len() > 0 {
1527                if self.dt.is_copy() {
1528                    self.plain_storage().as_ptr().copy_to_nonoverlapping(
1529                        tensor.as_bytes_mut().as_mut_ptr(),
1530                        self.plain_storage().layout().size(),
1531                    )
1532                } else if self.dt == DatumType::String {
1533                    tensor
1534                        .as_slice_mut_unchecked::<String>()
1535                        .clone_from_slice(self.as_slice_unchecked());
1536                } else if self.dt == DatumType::Blob {
1537                    tensor
1538                        .as_slice_mut_unchecked::<Blob>()
1539                        .clone_from_slice(self.as_slice_unchecked());
1540                } else if self.dt == DatumType::TDim {
1541                    tensor
1542                        .as_slice_mut_unchecked::<TDim>()
1543                        .clone_from_slice(self.as_slice_unchecked());
1544                }
1545            }
1546            tensor
1547        }
1548    }
1549
1550    pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1551        if axis >= self.rank() {
1552            bail!("Can not slice at axis {} tensor {:?}", axis, self);
1553        }
1554        if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1555            bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1556        }
1557        fn slice_t<T: Datum>(
1558            t: &Tensor,
1559            axis: usize,
1560            start: usize,
1561            end: usize,
1562        ) -> TractResult<Tensor> {
1563            Ok(t.to_plain_array_view::<T>()?
1564                .slice_axis(ndarray::Axis(axis), (start..end).into())
1565                .into_owned()
1566                .into_tensor())
1567        }
1568        dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1569    }
1570
    /// Returns a read-only [`view::TensorView`] covering the whole tensor.
    #[inline]
    pub fn view(&self) -> view::TensorView<'_> {
        // NOTE(review): `TensorView::view` is an `unsafe` constructor whose safety contract lives
        // in the `view` module (not visible here); this wrapper assumes any `&Tensor` satisfies
        // it — confirm there.
        unsafe { view::TensorView::view(self) }
    }

    /// Returns a view selected by `prefix`, as defined by `TensorView::at_prefix`
    /// (presumably fixing the leading coordinates — TODO confirm in the `view` module).
    ///
    /// # Errors
    /// Propagates any error reported by `TensorView::at_prefix`.
    #[inline]
    pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::at_prefix(self, prefix)
    }

    /// Returns a view offset to `coords`, as defined by `TensorView::offsetting`.
    ///
    /// # Errors
    /// Propagates any error reported by `TensorView::offsetting`.
    #[inline]
    pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::offsetting(self, coords)
    }

    /// Unchecked counterpart of [`Tensor::view_offsetting`]: returns the view directly instead of
    /// a `TractResult`.
    ///
    /// # Safety
    /// Forwards to `TensorView::offsetting_unchecked`, an `unsafe` function. Presumably the
    /// caller must guarantee that `coords` are valid for this tensor's shape, which the checked
    /// variant would otherwise verify — TODO confirm against the `view` module.
    #[inline]
    pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView<'_> {
        unsafe { view::TensorView::offsetting_unchecked(self, coords) }
    }
1590
    /// Same as [`Tensor::view`], but borrows the tensor exclusively for the lifetime of the
    /// returned view.
    ///
    /// NOTE(review): the return type is the same `TensorView` as the shared variant; whether it
    /// exposes mutation depends on `TensorView`'s API (not visible here).
    #[inline]
    pub fn view_mut(&mut self) -> view::TensorView<'_> {
        // NOTE(review): relies on the same `unsafe` contract as `view` (defined in `view`).
        unsafe { view::TensorView::view(self) }
    }

    /// Exclusive-borrow counterpart of [`Tensor::view_at_prefix`].
    ///
    /// # Errors
    /// Propagates any error reported by `TensorView::at_prefix`.
    #[inline]
    pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::at_prefix(self, prefix)
    }

    /// Exclusive-borrow counterpart of [`Tensor::view_offsetting`].
    ///
    /// # Errors
    /// Propagates any error reported by `TensorView::offsetting`.
    #[inline]
    pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
        view::TensorView::offsetting(self, coords)
    }
1605
1606    /// Offsets the tensor as an i8 type if it's an u8 type, otherwise passes it unchanged.
1607    pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1608        let mut t = if let DatumType::U8 = self.dt.unquantized() {
1609            self.try_as_plain()
1610                .unwrap()
1611                .to_array_view::<u8>()
1612                .unwrap()
1613                .mapv(|v| v.wrapping_sub(128) as i8)
1614                .into_tensor()
1615        } else {
1616            return self.clone();
1617        };
1618
1619        if let DatumType::QU8(qp) = self.dt {
1620            if let QParams::ZpScale { zero_point, scale } = qp {
1621                t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1622            } else {
1623                t.dt = DatumType::QI8(qp);
1624            }
1625        }
1626
1627        t.into_arc_tensor()
1628    }
1629
1630    /// Offsets the tensor as an u8 type if it's an i8 type, otherwise passes it unchanged.
1631    pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1632        let mut t = if let DatumType::I8 = self.dt.unquantized() {
1633            self.try_as_plain()
1634                .unwrap()
1635                .to_array_view::<i8>()
1636                .unwrap()
1637                .mapv(|v| (v as u8).wrapping_add(128))
1638                .into_tensor()
1639        } else {
1640            return self.clone();
1641        };
1642
1643        if let DatumType::QI8(qp) = self.dt {
1644            if let QParams::ZpScale { zero_point, scale } = qp {
1645                t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1646            } else {
1647                t.dt = DatumType::QU8(qp);
1648            }
1649        }
1650        t.into_arc_tensor()
1651    }
1652
1653    pub fn to_aligned_default(&self) -> TractResult<Self> {
1654        if self.dt.is_copy() {
1655            unsafe {
1656                let mut t = Self::uninitialized_dt(self.dt, &self.shape)?;
1657                t.as_bytes_mut().copy_from_slice(self.as_bytes());
1658                Ok(t)
1659            }
1660        } else {
1661            let mut t = Self::zero_dt(self.dt, &self.shape)?;
1662            if self.dt == String::datum_type() {
1663                t.try_as_plain_mut()?
1664                    .as_slice_mut::<String>()?
1665                    .clone_from_slice(self.try_as_plain()?.as_slice()?);
1666            } else if self.dt == Blob::datum_type() {
1667                t.try_as_plain_mut()?
1668                    .as_slice_mut::<Blob>()?
1669                    .clone_from_slice(self.try_as_plain()?.as_slice()?);
1670            } else if self.dt == TDim::datum_type() {
1671                t.try_as_plain_mut()?
1672                    .as_slice_mut::<TDim>()?
1673                    .clone_from_slice(self.try_as_plain()?.as_slice()?);
1674            }
1675            Ok(t)
1676        }
1677    }
1678
1679    pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1680        let mut strides = tvec!();
1681        compute_natural_stride_to(&mut strides, shape);
1682        strides
1683    }
1684
1685    pub fn into_blob(mut self) -> TractResult<Blob> {
1686        ensure!(self.dt.is_copy());
1687        let storage =
1688            std::mem::replace(&mut self.storage, StorageKind::Plain(PlainStorage::default()));
1689        Ok(storage.into_plain().context("Storage is not plain")?.into_blob())
1690    }
1691}
1692
1693impl PartialEq for Tensor {
1694    fn eq(&self, other: &Tensor) -> bool {
1695        if self.dt != other.dt || self.shape != other.shape {
1696            return false;
1697        }
1698        match (self.storage.as_plain(), other.storage.as_plain()) {
1699            (Some(_), Some(_)) => self.eq_dt(other).unwrap_or(false),
1700            (None, None) => self.storage == other.storage,
1701            _ => false,
1702        }
1703    }
1704}
1705
1706impl Eq for Tensor {}
1707
1708impl fmt::Debug for Tensor {
1709    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1710        let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1711        write!(formatter, "{content}")
1712    }
1713}
1714
/// Reinterprets a tensor whose last axis has length 2 as a tensor of complex numbers and
/// drops that axis: each `[re, im]` pair becomes one complex element.
///
/// Only the tensor's metadata is rewritten (shape, datum type, strides/len); the element
/// storage is not touched here.
///
/// # Errors
/// Fails if the last dimension is not 2, or if the datum type has no complex counterpart.
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    ensure!(
        t.shape().last() == Some(&2),
        "The last dimension in the tensor shape {:?} must be 2",
        t.shape()
    );
    // Resolve the target type before touching the tensor, so a failure leaves it intact.
    let complex_dt = t.datum_type().complexify()?;
    unsafe {
        t.shape.pop();
        t.set_datum_type(complex_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1729
/// Inverse of `reinterpret_inner_dim_as_complex`: views a complex tensor as its real
/// component type, with a new trailing axis of length 2 holding `[re, im]`.
///
/// Only the tensor's metadata is rewritten (shape, datum type, strides/len).
///
/// # Errors
/// Fails if the datum type has no real counterpart (`decomplexify` error).
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    // Resolve the target type before touching the tensor, so a failure leaves it intact.
    let real_dt = t.datum_type().decomplexify()?;
    unsafe {
        t.shape.push(2);
        t.set_datum_type(real_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1739
/// Resolves any `RangeBounds<usize>` into a concrete half-open `Range<usize>`. An unbounded
/// start becomes `0` and an unbounded end becomes `len`.
///
/// Explicit bounds are only converted: an inclusive end `..=i` becomes `..i + 1`, and an
/// exclusive start becomes `i + 1`. They are *not* clamped to `len`, so the result can extend
/// past `len` if the input does.
pub fn clip_range_bounds(len: usize, range: impl std::ops::RangeBounds<usize>) -> Range<usize> {
    use std::ops::Bound;
    let first = match range.start_bound() {
        Bound::Unbounded => 0,
        Bound::Included(&ix) => ix,
        Bound::Excluded(&ix) => ix + 1,
    };
    let past_last = match range.end_bound() {
        Bound::Unbounded => len,
        Bound::Included(&ix) => ix + 1,
        Bound::Excluded(&ix) => ix,
    };
    first..past_last
}
1754
1755pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1756    let mut strides = tvec!();
1757    compute_natural_stride_to(&mut strides, shape);
1758    strides
1759}
1760
1761fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1762    match shape.len() {
1763        0 => (),
1764        1 => strides.push(1),
1765        2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1766        3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1767        4 => strides.extend_from_slice(&[
1768            (shape[1] * shape[2] * shape[3]) as isize,
1769            (shape[2] * shape[3]) as _,
1770            shape[3] as _,
1771            1,
1772        ]),
1773        _ => {
1774            strides.push(1);
1775            for dim in shape.as_ref().iter().skip(1).rev() {
1776                let previous = *strides.last().unwrap();
1777                strides.push(previous * *dim as isize)
1778            }
1779            strides.reverse();
1780        }
1781    }
1782}
1783
1784impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
1785    fn from(it: Array<T, D>) -> Tensor {
1786        Tensor::from_datum(it.into_dyn())
1787    }
1788}
1789
/// Convenient conversion of a value into an owned [`Tensor`].
pub trait IntoTensor: Sized {
    /// Converts `self` into a `Tensor`.
    ///
    /// May perform a copy. For example, the `Arc<Tensor>` implementation clones the tensor
    /// when the `Arc` is shared and cannot be unwrapped.
    fn into_tensor(self) -> Tensor;
}

/// Convenient conversion of a value into a shared `Arc<Tensor>`.
pub trait IntoArcTensor: Sized {
    /// Converts `self` into an `Arc<Tensor>`.
    ///
    /// May perform a copy, depending on the implementing type.
    fn into_arc_tensor(self) -> Arc<Tensor>;
}
1805
1806impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
1807    fn into_tensor(self) -> Tensor {
1808        Tensor::from(self)
1809    }
1810}
1811
1812impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
1813    fn into_arc_tensor(self) -> Arc<Tensor> {
1814        Arc::new(Tensor::from(self))
1815    }
1816}
1817
1818impl IntoTensor for Tensor {
1819    fn into_tensor(self) -> Tensor {
1820        self
1821    }
1822}
1823
1824impl IntoTensor for Arc<Tensor> {
1825    fn into_tensor(self) -> Tensor {
1826        Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
1827    }
1828}
1829
1830impl IntoArcTensor for Tensor {
1831    fn into_arc_tensor(self) -> Arc<Tensor> {
1832        Arc::new(self)
1833    }
1834}
1835
1836impl IntoArcTensor for Arc<Tensor> {
1837    fn into_arc_tensor(self) -> Arc<Tensor> {
1838        self
1839    }
1840}
1841
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    use proptest::collection::vec;
    use proptest::prelude::*;

    // A random shape together with a random permutation of its axes. It checks that building a
    // `Tensor` from an ndarray whose axes were permuted (so its memory order is not the logical
    // order) yields the elements in logical order.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        // Rank in 0..8, every dimension in 1..5, and a shuffled identity permutation of the
        // same length as the shape.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            (0..8usize)
                .prop_flat_map(|rank| {
                    let permute: Vec<usize> = (0..rank).collect();
                    (proptest::collection::vec(1..5usize, rank), Just(permute).prop_shuffle())
                })
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        // Array of `shape` filled with successive integers from 1 (all distinct, so any
        // misordering is detectable), then with its axes permuted by `permutation`.
        fn input(&self) -> ArrayD<i32> {
            let mut i = 0;
            ArrayD::from_shape_simple_fn(&*self.shape, || {
                i += 1;
                i
            })
            .permuted_axes(&*self.permutation)
        }

        // Expected result: the permuted array's elements collected in ndarray iteration
        // order, rebuilt as a tensor of the permuted shape.
        fn reference(&self) -> Tensor {
            let values: Vec<i32> = self.input().iter().copied().collect();
            let shape = self.permutation.iter().map(|ix| self.shape[*ix]).collect::<TVec<usize>>();
            super::litteral::tensor1(&values).into_shape(&shape).unwrap()
        }

        // Conversion under test: `Tensor::from` applied directly to the permuted array.
        fn tract(&self) -> Tensor {
            Tensor::from(self.input())
        }

        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    proptest::proptest! {
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    // Hand-picked 2-D transposes kept as fixed cases alongside the property test.
    #[test]
    fn t_1_2() {
        PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] }.check().unwrap();
    }

    #[test]
    fn t_2_2() {
        PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
    }

    // A vector, a target shape with `shape[axis] == vec.len()`, and the axis along which the
    // vector is laid out. The check compares `broadcast_vector_to_shape` against a reference
    // built by reshaping the vector to all-ones dimensions (except at `axis`) and then calling
    // `broadcast_to_shape`.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reference = input
                .clone()
                .into_shape(&intermediate)
                .unwrap()
                .broadcast_to_shape(&self.shape)
                .unwrap();
            prop_assert_eq!(
                reference,
                input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap()
            );
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        // Draw a base shape of rank 0..4 with dims in 0..5 (zero-sized dims included), then a
        // vector of length 0..5. The vector's length is inserted into the shape at a random
        // `axis`, so the final shape has `shape[axis] == vec.len()`.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1)
                })
                .prop_map(|(vec, mut shape, axis)| {
                    shape.insert(axis, vec.len());
                    BroadcastVecToShape { vec, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    // The trailing length-2 axis is read as `[re, im]` pairs.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let input = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Same as above on a rank-3 integer tensor: only the last axis is folded away.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let input =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Smoke test: cloning a scalar tensor holding a symbolic `TDim` (a non-copy datum type)
    // must not panic.
    #[test]
    fn clone_tdim_tensor() {
        let symbols = SymbolScope::default();
        let a = symbols.sym("a");
        let t = tensor0(TDim::from(a));
        let _ = t.clone();
    }
}
1998        let a = symbols.sym("a");
1999        let t = tensor0(TDim::from(a));
2000        let _ = t.clone();
2001    }
2002}