tract_data/
tensor.rs

1//! `Tensor`, tract main data object of interest.
2use crate::blob::Blob;
3use crate::datum::{round_ties_to_even, scale_by, ClampCast, Datum, DatumType, QParams};
4use crate::dim::TDim;
5use crate::internal::*;
6use crate::opaque::Opaque;
7use crate::TVec;
8use half::f16;
9use itertools::Itertools;
10use ndarray::prelude::*;
11#[cfg(feature = "complex")]
12use num_complex::Complex;
13use num_traits::Zero;
14use std::borrow::Cow;
15use std::fmt;
16use std::hash::Hash;
17use std::ops::Range;
18use std::sync::Arc;
19
20pub mod litteral;
21pub mod view;
22
/// Tolerance preset used when comparing two tensors.
///
/// Each preset is translated into an (absolute tolerance, relative tolerance,
/// tolerated outlier ratio) triple, which may depend on the datum type being compared.
#[derive(Debug, Default, Clone, Copy)]
pub enum Approximation {
    /// No tolerance at all: all three values are zero.
    Exact,
    /// Default preset.
    #[default]
    Close,
    /// Looser than `Close`.
    Approximate,
    /// Looser than `Approximate`.
    VeryApproximate,
    /// Looser than `VeryApproximate`, and tolerates a small ratio of outliers.
    SuperApproximate,
    /// Loosest built-in preset.
    UltraApproximate,
    /// User-provided (absolute tolerance, relative tolerance, outlier ratio).
    Custom(f32, f32, f32),
}
34
35impl PartialEq for Approximation {
36    fn eq(&self, other: &Self) -> bool {
37        use Approximation::Custom;
38        if let (Custom(aa, ar, ao), Custom(ba, br, bo)) = (self, other) {
39            aa == ba && ar == br && bo == ao
40        } else {
41            std::mem::discriminant(self) == std::mem::discriminant(other)
42        }
43    }
44}
45
46impl Eq for Approximation {}
47
48impl From<bool> for Approximation {
49    fn from(b: bool) -> Self {
50        if b {
51            Self::Approximate
52        } else {
53            Self::Exact
54        }
55    }
56}
57
58impl Approximation {
59    fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
60        use Approximation::*;
61        match (self, dt) {
62            (Exact, _) => (0.0, 0.0, 0.0),
63            (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
64            (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
65            (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
66            (Close, _) => (1e-7, 1e-7, 0.0),
67            (Approximate, _) => (1e-4, 5e-4, 0.0),
68            (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
69            (SuperApproximate, _) => (0.1, 0.05, 0.0001),
70            (UltraApproximate, _) => (0.2, 0.1, 0.0005),
71            (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
72        }
73    }
74}
75
/// Tensor is a concrete tensor in tract.
///
/// It couples a datum type and a shape with a raw byte buffer holding the elements.
#[derive(Eq)]
pub struct Tensor {
    /// Datum type of the elements stored in `data`.
    dt: DatumType,
    /// Size of each axis; empty for a scalar.
    shape: TVec<usize>,
    /// Stride of each axis, counted in elements (not bytes).
    strides: TVec<isize>,
    /// Number of elements; 1 for a scalar.
    len: usize,
    /// Backing byte buffer.
    data: Blob,
}

// NOTE(review): soundness of these impls depends on `Blob` and on every datum
// type that can be stored being thread-safe — confirm.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
88
/// Hashes the datum type, the shape, the alignment of the backing buffer and
/// every element.
///
/// Floating-point elements do not implement `Hash`, so they are hashed through
/// a slice of integers (`i16`, `i32`, `i64`) over the same data. Quantized
/// types are hashed through their integer storage type.
impl Hash for Tensor {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use DatumType::*;
        self.dt.hash(state);
        self.shape.hash(state);
        self.data.layout().align().hash(state);
        // Each arm views the buffer as the element type matching `self.dt`.
        unsafe {
            match self.dt {
                Bool => self.as_slice_unchecked::<bool>().hash(state),
                I8 => self.as_slice_unchecked::<i8>().hash(state),
                I16 => self.as_slice_unchecked::<i16>().hash(state),
                I32 => self.as_slice_unchecked::<i32>().hash(state),
                I64 => self.as_slice_unchecked::<i64>().hash(state),
                U8 => self.as_slice_unchecked::<u8>().hash(state),
                U16 => self.as_slice_unchecked::<u16>().hash(state),
                U32 => self.as_slice_unchecked::<u32>().hash(state),
                U64 => self.as_slice_unchecked::<u64>().hash(state),
                // Floats: hashed as integers.
                F16 => self.as_slice_unchecked::<i16>().hash(state),
                F32 => self.as_slice_unchecked::<i32>().hash(state),
                F64 => self.as_slice_unchecked::<i64>().hash(state),
                TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
                String => self.as_slice_unchecked::<std::string::String>().hash(state),
                Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
                Opaque => self.as_slice_unchecked::<crate::opaque::Opaque>().hash(state),
                // Quantized: hashed as their integer storage type.
                QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
                QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
                QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
                // Complex floats: hashed as complex integers.
                #[cfg(feature = "complex")]
                ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
            }
        }
    }
}
132
133impl Clone for Tensor {
134    fn clone(&self) -> Tensor {
135        self.deep_clone()
136    }
137}
138
139impl Default for Tensor {
140    fn default() -> Tensor {
141        litteral::tensor0(0f32)
142    }
143}
144
/// Runs the destructor of every element when the tensor holds one of the
/// datum types handled below (`Blob`, `String`, `TDim`, `Opaque`).
///
/// Other datum types get no per-element work here.
impl Drop for Tensor {
    fn drop(&mut self) {
        // Drops every element in place if the tensor's datum type is `$t`.
        macro_rules! drop_in_place {
            ($t: ty) => {
                if self.dt == <$t>::datum_type() {
                    unsafe {
                        self.as_slice_mut::<$t>()
                            .unwrap()
                            .iter_mut()
                            .for_each(|s| std::ptr::drop_in_place(s as *mut $t));
                    }
                }
            };
        }
        drop_in_place!(Blob);
        drop_in_place!(String);
        drop_in_place!(TDim);
        drop_in_place!(Opaque);
    }
}
165
/// Vector width assumed for the running CPU, in bytes.
///
/// On x86_64 this is 64 when AVX-512F is detected at runtime and 32 otherwise.
/// Every other architecture gets 16.
#[allow(unreachable_code)]
pub fn vector_size() -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        let bits = if is_x86_feature_detected!("avx512f") { 512 } else { 256 };
        return bits / 8;
    }
    // Only reached when the x86_64 block above is compiled out.
    128 / 8
}
174
175impl Tensor {
176    #[allow(unreachable_code, unexpected_cfgs)]
177    #[inline]
178    pub fn default_alignment(dt: DatumType, shape: &[usize]) -> usize {
179        if shape.len() == 0 {
180            dt.alignment()
181        } else {
182            vector_size()
183        }
184    }
185
186    /// Create an uninitialized tensor (dt as type paramater).
187    #[inline]
188    pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
189        Self::uninitialized_dt(T::datum_type(), shape)
190    }
191
192    /// Create an uninitialized tensor (dt as regular parameter).
193    #[inline]
194    pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
195        Self::uninitialized_aligned_dt(dt, shape, Self::default_alignment(dt, shape))
196    }
197
198    /// Create an uninitialized tensor with a given alignment (in bytes).
199    #[inline]
200    pub unsafe fn uninitialized_aligned<T: Datum>(
201        shape: &[usize],
202        alignment: usize,
203    ) -> TractResult<Tensor> {
204        Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment)
205    }
206
    /// Create an uninitialized tensor with a given alignment (in bytes).
    ///
    /// The buffer holds `product(shape) * dt.size_of()` bytes. `String` and
    /// `Blob` tensors are zero-filled, `TDim` and `Opaque` tensors get one
    /// freshly written value per element. For the remaining types the content
    /// is only filled in builds with debug assertions (NaN for `F32`, `0xFF`
    /// bytes otherwise).
    ///
    /// # Safety
    ///
    /// Outside of debug builds, elements of the remaining types are not
    /// initialized: the caller must write them before reading them.
    pub unsafe fn uninitialized_aligned_dt(
        dt: DatumType,
        shape: &[usize],
        alignment: usize,
    ) -> TractResult<Tensor> {
        let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
        let data = Blob::new_for_size_and_align(bytes, alignment);
        let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), data, len: 0 };
        if tensor.shape.len() == 0 {
            // Scalar: one element, no stride.
            tensor.len = 1;
        } else {
            tensor.update_strides_and_len();
        }
        if !tensor.data.is_empty() {
            if dt == String::datum_type() || dt == Blob::datum_type() {
                // assumes zero-initialized string and blob are valid
                tensor.data.fill(0);
            } else if dt == TDim::datum_type() {
                // `ptr::write` does not drop the previous (garbage) content.
                tensor
                    .as_slice_mut_unchecked::<TDim>()
                    .iter_mut()
                    .for_each(|dim| std::ptr::write(dim, TDim::zero()))
            } else if dt == Opaque::datum_type() {
                tensor.as_slice_mut_unchecked::<Opaque>().iter_mut().for_each(|p| {
                    std::ptr::write(p, Opaque::default());
                });
            } else if cfg!(debug_assertions) {
                // Debug builds only: fill with a recognizable pattern.
                assert!(dt.is_copy());
                if dt == DatumType::F32 {
                    tensor.fill_t(f32::NAN).unwrap();
                } else {
                    // safe, non copy types have been dealt with
                    tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
                }
            }
        }
        Ok(tensor)
    }
246
247    pub fn stack_tensors(
248        axis: usize,
249        tensors: &[impl std::borrow::Borrow<Tensor>],
250    ) -> TractResult<Tensor> {
251        ensure!(tensors.len() > 0);
252        let rank = tensors[0].borrow().rank();
253        ensure!(axis < rank);
254        ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
255        let dt = tensors[0].borrow().datum_type();
256        ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
257        let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
258        for ax in 0..rank {
259            if ax != axis {
260                ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
261            }
262        }
263        shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
264        unsafe {
265            let mut result = Tensor::uninitialized_dt(dt, &shape)?;
266            if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
267                let mut offset = 0isize;
268                for v in tensors {
269                    let v = v.borrow();
270                    let len = v.data.len();
271                    std::ptr::copy_nonoverlapping(
272                        v.data.as_ptr(),
273                        result.data.as_mut_ptr().offset(offset),
274                        len,
275                    );
276                    offset += len as isize;
277                }
278            } else {
279                let mut offset = 0;
280                for t in tensors {
281                    let t = t.borrow();
282                    let len = t.shape()[axis];
283                    result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
284                    offset += len;
285                }
286            }
287
288            Ok(result)
289        }
290    }
291
292    pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
293        self.fill_t(T::zero())
294    }
295
296    pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
297        unsafe {
298            let mut t = Tensor::uninitialized::<T>(shape)?;
299            t.clear::<T>()?;
300            Ok(t)
301        }
302    }
303
304    pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
305        Tensor::zero::<T>(&[])
306    }
307
308    pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
309        Tensor::zero_dt(dt, &[])
310    }
311
312    pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
313        Tensor::zero_aligned_dt(dt, shape, Self::default_alignment(dt, shape))
314    }
315
316    pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
317        self.as_slice_mut::<T>()?.iter_mut().for_each(|item| *item = value.clone());
318        Ok(())
319    }
320
321    pub fn zero_aligned_dt(
322        dt: DatumType,
323        shape: &[usize],
324        alignment: usize,
325    ) -> TractResult<Tensor> {
326        if shape.iter().product::<usize>() == 0 {
327            unsafe { return Tensor::uninitialized_dt(dt, shape) };
328        }
329        if dt.is_quantized() {
330            unsafe {
331                let mut t = Tensor::uninitialized_dt(dt, shape)?;
332                let zp = dt.zp_scale().0;
333                match dt.unquantized() {
334                    DatumType::I8 => {
335                        t.as_slice_mut::<i8>()?.iter_mut().for_each(|item| *item = zp as _)
336                    }
337                    DatumType::U8 => {
338                        t.as_slice_mut::<u8>()?.iter_mut().for_each(|item| *item = zp as _)
339                    }
340                    DatumType::I32 => {
341                        t.as_slice_mut::<i32>()?.iter_mut().for_each(|item| *item = zp as _)
342                    }
343                    _ => unreachable!(),
344                }
345                Ok(t)
346            }
347        } else {
348            dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
349        }
350    }
351
352    pub fn zero_aligned<T: Datum + num_traits::Zero>(
353        shape: &[usize],
354        alignment: usize,
355    ) -> TractResult<Tensor> {
356        unsafe {
357            let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
358            tensor.clear::<T>()?;
359            Ok(tensor)
360        }
361    }
362
363    /// Create a tensor with a given shape and a slice of elements.
364    /// The data is copied and aligned to size of T.
365    pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
366        let dt = T::datum_type();
367        Self::from_shape_align(shape, data, dt.alignment())
368    }
369
370    /// Create a tensor with a given shape and a slice of elements.
371    /// The data is copied and aligned to given alignment.
372    pub fn from_shape_align<T: Datum + Copy>(
373        shape: &[usize],
374        data: &[T],
375        align: usize,
376    ) -> TractResult<Tensor> {
377        ensure!(
378            data.len() == shape.iter().product::<usize>(),
379            "Shape product must be equal to data length"
380        );
381        unsafe {
382            let bytes = std::slice::from_raw_parts(
383                data.as_ptr() as *const u8,
384                data.len() * T::datum_type().size_of(),
385            );
386            let dt = T::datum_type();
387            Self::from_raw_dt_align(dt, shape, bytes, align)
388        }
389    }
390
391    /// Create a tensor from raw data.
392    ///
393    /// It copies the data, aligning it to the size of T.
394    pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
395        Tensor::from_raw_dt(T::datum_type(), shape, content)
396    }
397
398    pub unsafe fn from_raw_aligned<T: Datum>(
399        shape: &[usize],
400        content: &[u8],
401        align: usize,
402    ) -> TractResult<Tensor> {
403        Tensor::from_raw_dt_align(T::datum_type(), shape, content, align)
404    }
405
406    pub unsafe fn from_raw_dt(
407        dt: DatumType,
408        shape: &[usize],
409        content: &[u8],
410    ) -> TractResult<Tensor> {
411        Self::from_raw_dt_align(dt, shape, content, dt.alignment())
412    }
413
414    pub unsafe fn from_raw_dt_align(
415        dt: DatumType,
416        shape: &[usize],
417        content: &[u8],
418        align: usize,
419    ) -> TractResult<Tensor> {
420        let mut tensor = Tensor::uninitialized_aligned_dt(dt, shape, align)?;
421        tensor.as_bytes_mut().copy_from_slice(content);
422        Ok(tensor)
423    }
424
425    pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
426        let bytes = if content.len() == 0 {
427            &[]
428        } else {
429            std::slice::from_raw_parts(
430                content.as_ptr() as *const u8,
431                content.len() * T::datum_type().size_of(),
432            )
433        };
434        Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align)
435    }
436
    /// Get the number of dimensions (or axes) of the tensor.
    ///
    /// This is the length of the shape: 0 for a scalar.
    #[inline]
    pub fn rank(&self) -> usize {
        self.shape.len()
    }
442
    /// Get the shape of the tensor: the size of each axis.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
448
    /// Get the number of values in the tensor.
    ///
    /// Returns the same value as `volume()`.
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.len
    }
455
    /// Get the number of values in the tensor.
    ///
    /// Returns the same value as `len()`.
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn volume(&self) -> usize {
        self.len
    }
462
    /// Get the strides of the tensor, one per axis, counted in elements.
    #[inline]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }
468
469    fn update_strides_and_len(&mut self) {
470        self.strides.clear();
471        if self.shape.len() == 0 {
472            self.len = 1;
473            return;
474        }
475        compute_natural_stride_to(&mut self.strides, &self.shape);
476        self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
477    }
478
479    /// Force the tensor shape, no consistency check.
480    pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
481        if shape != &*self.shape {
482            self.shape.clear();
483            self.shape.extend_from_slice(shape);
484            self.update_strides_and_len();
485        }
486    }
487
488    /// Force the tensor shape and strides, no consistency check.
489    pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
490        self.shape.clear();
491        self.shape.extend_from_slice(shape);
492        self.strides.clear();
493        self.strides.extend_from_slice(strides);
494    }
495
496    /// Force the tensor shape.
497    pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
498        if self.len() != shape.iter().product::<usize>() {
499            bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
500        }
501        unsafe { self.set_shape_unchecked(shape) }
502        Ok(())
503    }
504
505    pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
506        ensure!(axes.iter().duplicates().next().is_none());
507        ensure!(axes.iter().all(|a| *a < self.rank()));
508        unsafe {
509            #[inline]
510            unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
511                input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor()
512            }
513            let dt = self.datum_type();
514            let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
515            t.set_datum_type(dt);
516            Ok(t)
517        }
518    }
519
520    pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
521        let mut permutation: Vec<usize> = (0..self.rank()).collect();
522        permutation.remove(from);
523        permutation.insert(to, from);
524        self.permute_axes(&permutation)
525    }
526
527    pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
528        let removed = self.shape.remove(axis + 1);
529        self.shape[axis] *= removed;
530        self.update_strides_and_len();
531        self
532    }
533
534    pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
535        if self.shape[axis] % outer_dim != 0 {
536            bail!(
537                "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
538                self.shape,
539                axis,
540                outer_dim
541            );
542        }
543        self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
544        self.shape[axis] = outer_dim;
545        self.update_strides_and_len();
546        Ok(self)
547    }
548
549    /// Reshape the tensor to `shape`.
550    pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
551        self.set_shape(shape)?;
552        Ok(self)
553    }
554
555    pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
556        self.shape.insert(axis, 1);
557        self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
558        Ok(())
559    }
560
561    pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
562        ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
563        self.shape.remove(axis);
564        self.strides.remove(axis);
565        Ok(())
566    }
567
568    pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
569        self.broadcast_to_rank(rank)?;
570        self.update_strides_and_len();
571        Ok(self)
572    }
573
574    pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
575        if rank < self.rank() {
576            bail!("Can only broadcast to higher rank")
577        }
578        while self.shape.len() < rank {
579            self.shape.insert(0, 1)
580        }
581        self.update_strides_and_len();
582        Ok(())
583    }
584
585    pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
586        if self.rank() > 0 {
587            bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
588        }
589        unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
590            let value: &T = src.to_scalar_unchecked::<T>();
591            dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone());
592        }
593        unsafe {
594            let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
595            dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
596            Ok(t)
597        }
598    }
599
600    fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
601        unsafe {
602            let view = self.to_array_view_unchecked::<T>();
603            let mut output = view
604                .broadcast(shape)
605                .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
606                .into_owned()
607                .into_tensor();
608            output.set_datum_type(self.datum_type());
609            Ok(output)
610        }
611    }
612
    /// Broadcast the tensor to `shape`, returning a new tensor.
    ///
    /// Dispatches on the datum type to `broadcast_to_shape_t`.
    pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
        dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
    }
616
617    pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
618        ensure!(self.rank() == 1);
619        ensure!(shape[axis] == self.len());
620        if !self.datum_type().is_copy() {
621            let mut vec_shape = vec![1; shape.len()];
622            vec_shape[axis] = self.len();
623            return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
624        }
625        unsafe {
626            let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
627            if output.len() == 0 {
628                return Ok(output);
629            }
630            let inner_len = shape[axis + 1..].iter().product::<usize>();
631
632            unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
633            where
634                T: Datum + Copy,
635            {
636                for ix in 0..input.len() {
637                    let value: T = input.as_slice_unchecked()[ix];
638                    output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
639                        .iter_mut()
640                        .for_each(|item| *item = value);
641                }
642            }
643            dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));
644
645            let outer_len = shape[0..axis].iter().product::<usize>();
646            let repeat_bytes_len = inner_len * self.as_bytes().len();
647            let bytes = output.as_bytes_mut();
648            for ix in 1..outer_len {
649                bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
650            }
651
652            Ok(output)
653        }
654    }
655
656    fn clip_range_bounds(
657        &self,
658        axis: usize,
659        range: impl std::ops::RangeBounds<usize>,
660    ) -> Range<usize> {
661        use std::ops::Bound;
662        let start = match range.start_bound() {
663            Bound::Included(ix) => *ix,
664            Bound::Excluded(ix) => ix + 1,
665            Bound::Unbounded => 0,
666        };
667        let end = match range.end_bound() {
668            Bound::Included(ix) => *ix + 1,
669            Bound::Excluded(ix) => *ix,
670            Bound::Unbounded => self.shape()[axis],
671        };
672        start..end
673    }
674
675    pub fn assign_slice(
676        &mut self,
677        range: impl std::ops::RangeBounds<usize>,
678        src: &Tensor,
679        src_range: impl std::ops::RangeBounds<usize>,
680        axis: usize,
681    ) -> TractResult<()> {
682        let range = self.clip_range_bounds(axis, range);
683        let src_range = src.clip_range_bounds(axis, src_range);
684        ensure!(
685            src.datum_type() == self.datum_type(),
686            "Attempt to assign into {:?} from {:?}, datum type mismatch",
687            self.datum_type(),
688            src.datum_type()
689        );
690        ensure!(
691            src_range.len() == range.len(),
692            "Attempt to assign a range of {:?} from a range of {:?}",
693            range,
694            src_range,
695        );
696        ensure!(
697            self.rank() == src.rank()
698                && itertools::izip!(0.., self.shape(), src.shape())
699                    .all(|(ix, dst, src)| ix == axis || src == dst),
700            "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
701            axis,
702            self,
703            src
704        );
705        ensure!(
706            src_range.end <= src.shape()[axis],
707            "Assigning from invalid slice (axis {}, {:?}) of {:?}",
708            axis,
709            src_range,
710            src
711        );
712        ensure!(
713            range.end <= self.shape()[axis],
714            "Assigning to invalid slice (axis {}, {:?}) of {:?}",
715            axis,
716            range,
717            self
718        );
719        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
720        Ok(())
721    }
722
723    pub unsafe fn assign_slice_unchecked(
724        &mut self,
725        range: impl std::ops::RangeBounds<usize>,
726        src: &Tensor,
727        src_range: impl std::ops::RangeBounds<usize>,
728        axis: usize,
729    ) {
730        let range = self.clip_range_bounds(axis, range);
731        let src_range = src.clip_range_bounds(axis, src_range);
732        self.assign_slice_from_resolved(range, src, src_range, axis);
733    }
734
    /// Copy the slice `src_range` of `src` along `axis` into the slice `range`
    /// of this tensor, with no check.
    ///
    /// With a copy datum type and only 1-sized axes before `axis`, both slices
    /// are contiguous and the copy is a single byte copy. Otherwise it goes
    /// through ndarray views.
    ///
    /// # Safety
    ///
    /// Nothing is validated here: datum types, shapes and ranges must have
    /// been checked by the caller.
    unsafe fn assign_slice_from_resolved(
        &mut self,
        range: std::ops::Range<usize>,
        src: &Tensor,
        src_range: std::ops::Range<usize>,
        axis: usize,
    ) {
        use ndarray::Slice;
        // Generic path: element-wise assignment between two array views.
        unsafe fn assign_slice_t<T: Datum>(
            to: &mut Tensor,
            to_range: Range<usize>,
            from: &Tensor,
            from_range: Range<usize>,
            axis: usize,
        ) {
            to.to_array_view_mut_unchecked::<T>()
                .slice_axis_mut(Axis(axis), Slice::from(to_range))
                .assign(
                    &from
                        .to_array_view_unchecked::<T>()
                        .slice_axis(Axis(axis), Slice::from(from_range)),
                )
        }
        if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
            // Size in bytes of one index along `axis`. The destination stride
            // is used for the source too.
            // NOTE(review): relies on the caller guaranteeing identical shapes
            // on the axes after `axis`, so that both strides match — confirm.
            let stride = self.strides[axis] as usize * self.datum_type().size_of();
            let dst_start = (stride * range.start) as isize;
            let src_start = (stride * src_range.start) as isize;
            let len = stride * range.len();
            if len > 0 {
                if self.data.as_ptr() != src.data.as_ptr() {
                    std::ptr::copy_nonoverlapping(
                        src.data.as_ptr().offset(src_start),
                        self.data.as_mut_ptr().offset(dst_start),
                        len,
                    );
                } else {
                    // Same buffer: the regions may overlap.
                    std::ptr::copy(
                        src.data.as_ptr().offset(src_start),
                        self.data.as_mut_ptr().offset(dst_start),
                        len,
                    );
                }
            }
        } else {
            dispatch_datum!(assign_slice_t(self.datum_type())(self, range, src, src_range, axis));
        }
    }
782
    /// Get the datum type of the tensor, i.e. the type of its elements.
    #[inline]
    pub fn datum_type(&self) -> DatumType {
        self.dt
    }
788
    /// Set the datum type of the tensor, leaving data, shape and strides untouched.
    ///
    /// # Safety
    ///
    /// Nothing is checked or converted: the existing data must be valid for
    /// the new datum type.
    #[inline]
    pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
        self.dt = dt
    }
794
795    /// Dump the tensor in a human readable form.
796    ///
797    /// `force_full` will force the tensor to be dump in full even if it is big.
798    pub fn dump(&self, force_full: bool) -> TractResult<String> {
799        unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
800            if let Some(qp) = tensor.datum_type().qparams() {
801                let integers = tensor.cast_to::<i32>().unwrap();
802                integers.as_slice_unchecked::<i32>()[0..n]
803                    .iter()
804                    .map(|x| format!("[{}]({})", x, qp.dq(*x)))
805                    .join(", ")
806            } else {
807                tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
808            }
809        }
810        unsafe {
811            let trunc = self.len() > 12 && !force_full;
812            let data = dispatch_datum!(dump_t(self.datum_type())(
813                self,
814                if trunc { 12 } else { self.len() }
815            ));
816            Ok(format!(
817                "{},{:?} {}{}",
818                self.shape.iter().join(","),
819                self.dt,
820                data,
821                if trunc { "..." } else { "" }
822            ))
823        }
824    }
825
826    /// Compare two tensors, allowing for rounding errors.
827    pub fn close_enough(
828        &self,
829        other: &Self,
830        approx: impl Into<Approximation> + std::fmt::Debug,
831    ) -> TractResult<()> {
832        let approx = approx.into();
833        if self.shape() != other.shape() {
834            bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
835        }
836        let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
837        let ma = self.cast_to::<f32>()?;
838        let ma = ma.to_array_view::<f32>()?;
839        let mb = other.cast_to::<f32>()?;
840        let mb = mb.to_array_view::<f32>()?;
841        let mut first_outlier = None;
842        let mut outliers_count = 0;
843        ndarray::indices_of(&ma).into_iter().for_each(|indices| {
844            let a = ma[&indices];
845            let b = mb[&indices];
846            if !((a.is_nan() && b.is_nan())
847                || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
848                || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
849            {
850                if outliers_count == 0 {
851                    first_outlier = Some(indices.as_array_view().to_vec());
852                }
853                outliers_count += 1;
854            }
855        });
856        if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
857            let indices = first_outlier.unwrap();
858            let a = ma[&*indices];
859            let b = mb[&*indices];
860            bail!(
861                "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
862                approx,
863                self.datum_type(),
864                indices,
865                a,
866                b,
867                outliers_count,
868                self.volume(),
869                outliers_count as f64 / self.volume() as f64,
870                outliers
871            );
872        }
873        Ok(())
874    }
875
876    /// Transform the tensor into a `ndarray::Array`.
877    pub fn into_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
878        Ok(self.to_array_view::<D>()?.to_owned())
879    }
880
881    /// Transform the tensor into a `ndarray::Array`.
882    pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
883        self.to_array_view_unchecked::<D>().to_owned()
884    }
885
886    fn check_for_access<D: Datum>(&self) -> TractResult<()> {
887        ensure!(
888            self.datum_type().unquantized() == D::datum_type().unquantized(),
889            "Tensor datum type error: tensor is {:?}, accessed as {:?}",
890            self.datum_type(),
891            D::datum_type(),
892        );
893        Ok(())
894    }
895
896    /// Transform the data as a `ndarray::Array`.
897    pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<D>> {
898        self.check_for_access::<D>()?;
899        unsafe { Ok(self.to_array_view_unchecked()) }
900    }
901
902    /// Transform the data as a mutable `ndarray::Array`.
903    pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<D>> {
904        self.check_for_access::<D>()?;
905        unsafe { Ok(self.to_array_view_mut_unchecked()) }
906    }
907
908    /// Transform the data as a `ndarray::Array`.
909    pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<D> {
910        if self.len() != 0 {
911            ArrayViewD::from_shape_ptr(&*self.shape, self.data.as_ptr() as *const D)
912        } else {
913            ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
914        }
915    }
916
917    /// Transform the data as a mutable `ndarray::Array`.
918    pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<D> {
919        if self.len() != 0 {
920            ArrayViewMutD::from_shape_ptr(&*self.shape, self.data.as_mut_ptr() as *mut D)
921        } else {
922            ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
923        }
924    }
925
926    /// Access the data as a pointer.
927    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
928        self.check_for_access::<D>()?;
929        Ok(self.data.as_ptr() as *const D)
930    }
931
932    /// Access the data as a pointer.
933    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
934        self.data.as_ptr() as *const D
935    }
936
937    /// Access the data as a pointer.
938    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
939        self.data.as_mut_ptr() as *mut D
940    }
941
942    /// Access the data as a mutable pointer.
943    pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
944        self.as_ptr::<D>().map(|p| p as *mut D)
945    }
946
947    /// Access the data as a slice.
948    pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
949        let ptr: *const D = self.as_ptr()?;
950        if self.data.len() == 0 {
951            Ok(&[])
952        } else {
953            unsafe { Ok(std::slice::from_raw_parts::<D>(ptr, self.len())) }
954        }
955    }
956
957    /// Access the data as a mutable slice.
958    pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
959        let ptr: *mut D = self.as_ptr_mut()?;
960        if self.data.len() == 0 {
961            Ok(&mut [])
962        } else {
963            unsafe { Ok(std::slice::from_raw_parts_mut::<D>(ptr, self.len())) }
964        }
965    }
966
967    /// Access the data as a slice.
968    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
969        if self.data.len() == 0 {
970            &[]
971        } else {
972            std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len())
973        }
974    }
975
976    /// Access the data as a mutable slice.
977    pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
978        if self.data.len() == 0 {
979            &mut []
980        } else {
981            std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len())
982        }
983    }
984
985    /// Access the data as a scalar.
986    pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
987        self.check_for_access::<D>()?;
988        if self.len() == 0 {
989            bail!("to_scalar called on empty tensor ({:?})", self)
990        }
991        if self.len() > 1 {
992            bail!("to_scalar called on a tensor with multiple values ({:?})", self)
993        }
994        unsafe { Ok(self.to_scalar_unchecked()) }
995    }
996
997    /// Make the tensor a scalar tensor (assumes it contains a single value).
998    pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
999        fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
1000            Ok(litteral::tensor0(t.to_scalar::<D>()?.clone()))
1001        }
1002        dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
1003    }
1004
1005    /// Access the data as a scalar.
1006    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
1007        &*(self.data.as_ptr() as *const D)
1008    }
1009
1010    /// Mutable access the data as a scalar.
1011    pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
1012        self.check_for_access::<D>()?;
1013        if self.len() == 0 {
1014            bail!("to_scalar_mut called on empty tensor ({:?})", self)
1015        }
1016        if self.len() > 1 {
1017            bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1018        }
1019        unsafe { Ok(self.to_scalar_mut_unchecked()) }
1020    }
1021
1022    /// Mutable access the data as a scalar.
1023    pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
1024        &mut *(self.data.as_mut_ptr() as *mut D)
1025    }
1026
    /// Access the tensor backing storage as raw bytes, regardless of the
    /// datum type.
    pub fn as_bytes(&self) -> &[u8] {
        self.data.as_bytes()
    }
1030
    /// Mutable access to the tensor backing storage as raw bytes, regardless
    /// of the datum type.
    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
        self.data.as_bytes_mut()
    }
1034
1035    unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
1036        let slice = self.as_slice_unchecked::<T>();
1037        slice[1..].iter().all(|x| x == &slice[0])
1038    }
1039
1040    pub fn is_uniform(&self) -> bool {
1041        if self.len() <= 1 {
1042            return true;
1043        }
1044        unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
1045    }
1046
1047    unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
1048        let v: T = self.as_slice_unchecked::<T>()[0].clone();
1049        litteral::tensor0(v)
1050    }
1051
1052    pub fn as_uniform(&self) -> Option<Tensor> {
1053        if self.len() >= 1 && self.is_uniform() {
1054            unsafe {
1055                let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
1056                t.set_datum_type(self.datum_type());
1057                Some(t)
1058            }
1059        } else {
1060            None
1061        }
1062    }
1063
1064    pub fn is_all_zero(&self) -> TractResult<bool> {
1065        Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1066    }
1067
1068    pub fn is_zero(&self) -> TractResult<bool> {
1069        Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
1070    }
1071
1072    unsafe fn natural_cast<
1073        Source: Datum + num_traits::AsPrimitive<Target>,
1074        Target: Datum + Copy,
1075    >(
1076        &self,
1077        other: &mut Tensor,
1078    ) {
1079        self.as_slice_unchecked::<Source>()
1080            .iter()
1081            .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1082            .for_each(|(s, d)| *d = s.as_());
1083    }
1084
1085    unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1086        self.as_slice_unchecked::<Source>()
1087            .iter()
1088            .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1089            .for_each(|(s, d)| *d = !s.is_zero());
1090    }
1091
1092    unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
1093        &self,
1094        other: &mut Tensor,
1095    ) -> TractResult<()> {
1096        for (s, d) in self
1097            .as_slice_unchecked::<String>()
1098            .iter()
1099            .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1100        {
1101            *d = s
1102                .parse()
1103                .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
1104        }
1105        Ok(())
1106    }
1107
1108    unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1109        for (s, d) in self
1110            .as_slice_unchecked::<Source>()
1111            .iter()
1112            .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1113        {
1114            *d = s.to_string()
1115        }
1116    }
1117
    /// Optionally convert data to a tensor for a new DatumType.
    ///
    /// Shorthand for `cast_to_dt` with the datum type of `D`.
    pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<Tensor>> {
        self.cast_to_dt(D::datum_type())
    }
1122
1123    /// Optionnaly convert data to a tensor for a new DatumType.
1124    #[allow(clippy::redundant_closure_call)]
1125    pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<Tensor>> {
1126        unsafe {
1127            if self.dt == dst_dt {
1128                return Ok(Cow::Borrowed(self));
1129            }
1130            if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
1131                let slice = self.as_slice_unchecked::<TDim>();
1132                let mut ints = Self::uninitialized::<i64>(&self.shape)?;
1133                let ints_slice = ints.as_slice_mut_unchecked::<i64>();
1134                for i in 0..self.len() {
1135                    ints_slice[i] = slice[i].to_i64()?;
1136                }
1137                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1138            }
1139            if self.dt == bool::datum_type()
1140                && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
1141            {
1142                let slice = self.as_slice_unchecked::<bool>();
1143                let mut ints = Self::uninitialized::<i8>(&self.shape)?;
1144                let ints_slice = ints.as_slice_mut_unchecked::<i8>();
1145                for i in 0..self.len() {
1146                    ints_slice[i] = slice[i] as usize as i8;
1147                }
1148                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1149            }
1150            let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
1151            if self.dt == DatumType::String {
1152                dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
1153                return Ok(Cow::Owned(result));
1154            }
1155            if dst_dt == DatumType::String {
1156                dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
1157                return Ok(Cow::Owned(result));
1158            }
1159            macro_rules! n {
1160                ($source:ty) => {
1161                    if <$source>::datum_type() == self.datum_type() {
1162                        match dst_dt {
1163                            DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
1164                            DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
1165                            DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
1166                            DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
1167                            DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
1168                            DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
1169                            DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
1170                            DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
1171                            DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
1172                            DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
1173                            DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
1174                            DatumType::TDim => {
1175                                let ints = self.cast_to::<i32>()?;
1176                                let slice = ints.as_slice_unchecked::<i32>();
1177                                let result = result.as_slice_mut_unchecked::<TDim>();
1178                                for i in 0..self.len() {
1179                                    result[i] = slice[i].into();
1180                                }
1181                            }
1182                            DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
1183                            _ => todo!(),
1184                        }
1185                        return Ok(Cow::Owned(result));
1186                    };
1187                };
1188            }
1189            //If there is no quantization
1190            if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
1191                n!(u8);
1192                n!(u16);
1193                n!(u32);
1194                n!(u64);
1195                n!(i8);
1196                n!(i16);
1197                n!(i32);
1198                n!(i64);
1199                n!(f16);
1200                n!(f32);
1201                n!(f64);
1202            } else {
1203                let (s_zp, s_scale) = self.datum_type().zp_scale();
1204                let (d_zp, d_scale) = dst_dt.zp_scale();
1205                if self.datum_type().is_quantized() && dst_dt.is_float() {
1206                    macro_rules! q_to_fp {
1207                        ($source:ty, $dest:ty) => {
1208                            if <$source>::datum_type().unquantized()
1209                                == self.datum_type().unquantized()
1210                                && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1211                            {
1212                                self.as_slice_unchecked::<$source>()
1213                                    .iter()
1214                                    .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1215                                    .for_each(|(&s, d)| {
1216                                        *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
1217                                    });
1218                                return Ok(Cow::Owned(result));
1219                            }
1220                        };
1221                    }
1222                    q_to_fp!(i8, f64);
1223                    q_to_fp!(i8, f32);
1224                    q_to_fp!(u8, f64);
1225                    q_to_fp!(u8, f32);
1226                }
1227                //TODO: optimize scale_by
1228                macro_rules! q8_to_q8 {
1229                    ($typ:ty) => {
1230                        if dst_dt.unquantized() == <$typ>::datum_type() {
1231                            self.as_slice_unchecked::<$typ>()
1232                                .iter()
1233                                .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
1234                                .for_each(|(&s, d)| {
1235                                    *d = (d_zp as i32
1236                                        + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
1237                                    .clamp_cast()
1238                                });
1239                            return Ok(Cow::Owned(result));
1240                        }
1241                    };
1242                }
1243
1244                macro_rules! q_via_f32 {
1245                    ($source:ty, $dest:ty, $round:expr) => {
1246                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1247                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1248                        {
1249                            self.as_slice_unchecked::<$source>()
1250                                .iter()
1251                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1252                                .for_each(|(&s, d)| {
1253                                    let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
1254                                    let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
1255                                    *d = $round(d_float);
1256                                });
1257                            return Ok(Cow::Owned(result));
1258                        }
1259                    };
1260                }
1261
1262                macro_rules! q_n {
1263                    (clamp $source:ty, $dest:ty) => {{
1264                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1265                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1266                        {
1267                            self.as_slice_unchecked::<$source>()
1268                                .iter()
1269                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1270                                .for_each(|(&s, d)| {
1271                                    *d = s.clamp_cast();
1272                                });
1273                            return Ok(Cow::Owned(result));
1274                        }
1275                    }};
1276                    ($source:ty, $dest:ty) => {{
1277                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1278                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1279                        {
1280                            self.as_slice_unchecked::<$source>()
1281                                .iter()
1282                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1283                                .for_each(|(&s, d)| {
1284                                    *d = s as $dest;
1285                                });
1286                            return Ok(Cow::Owned(result));
1287                        }
1288                    }};
1289                }
1290
1291                if dst_dt.unquantized() == self.datum_type().unquantized()
1292                    && dst_dt.is_quantized()
1293                    && self.datum_type().is_quantized()
1294                {
1295                    q8_to_q8!(i8);
1296                    q8_to_q8!(u8);
1297                }
1298
1299                q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
1300                q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
1301                q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
1302                q_via_f32!(i8, f32, |f| f);
1303                q_via_f32!(u8, f32, |f| f);
1304                q_via_f32!(i32, f32, |f| f);
1305
1306                if dst_dt.is_quantized() && self.datum_type().is_quantized() {
1307                    q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
1308                    q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
1309                    q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
1310                    q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
1311                    q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
1312                    q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());
1313
1314                    // ensure cast to different scale offset work
1315                    q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
1316                    q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
1317                }
1318
1319                q_n!(i8, i32);
1320                q_n!(i8, u32);
1321                q_n!(u8, i32);
1322                q_n!(u8, u32);
1323                q_n!(clamp i32, i8);
1324                q_n!(clamp i32, u8);
1325                q_n!(clamp u32, i8);
1326                q_n!(clamp u32, u8);
1327                q_n!(i8, i8);
1328                q_n!(u8, u8);
1329                q_n!(i32, i32);
1330                q_n!(u32, u32);
1331            }
1332
1333            bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
1334        }
1335    }
1336
1337    /// Access the data as a scalar, after a cast.
1338    pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1339        let casted = self.cast_to::<D>()?;
1340        casted.to_scalar::<D>().copied()
1341    }
1342
1343    /// Access the nth element of the tensor, returned as a 0-rank Tensor
1344    pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1345        if nth >= self.len() {
1346            bail!(
1347                "nth called with {}th element on a tensor of len {} ({:?}",
1348                nth,
1349                self.len(),
1350                self
1351            );
1352        }
1353        unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1354            let value = me.as_slice_unchecked::<T>()[nth].clone();
1355            output.as_slice_mut_unchecked::<T>()[0] = value;
1356        }
1357        unsafe {
1358            let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1359            dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1360            Ok(output)
1361        }
1362    }
1363
1364    /// Strict equality test on tensors.
1365    fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1366        unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> bool {
1367            me.as_slice_unchecked::<D>() == other.as_slice_unchecked::<D>()
1368        }
1369
1370        unsafe {
1371            Ok(self.datum_type() == other.datum_type()
1372                && self.shape() == other.shape()
1373                && dispatch_datum!(eq_t(self.dt)(self, other)))
1374        }
1375    }
1376
    /// Build a tensor from an owned `ndarray` array.
    ///
    /// Three strategies are tried in order: a bulk copy (or element-wise
    /// move) when the array exposes its data as a slice, a strided copy when
    /// all strides are strictly positive, and finally a generic element-wise
    /// move through the array iterator.
    fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
        unsafe {
            let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
            // Fast path: the array can be seen as a single slice.
            if let Some(slice) = it.as_slice_mut() {
                if t.datum_type().is_copy() {
                    // Copy types: one raw byte copy of the whole buffer.
                    std::ptr::copy_nonoverlapping(
                        slice.as_ptr() as *const i8,
                        t.as_ptr_mut_unchecked(),
                        t.data.layout().size(),
                    );
                } else {
                    // Other types: move each value out of the array, leaving
                    // a default value behind.
                    t.as_slice_mut_unchecked::<T>()
                        .iter_mut()
                        .zip(slice.iter_mut())
                        .for_each(|(t, s)| *t = std::mem::take(s));
                }
                return t;
            }
            if it.strides().iter().all(|&s| s > 0) {
                // Pair each axis length with the matching stride of the
                // destination tensor, visiting axes by increasing source
                // stride.
                let mut len_and_strides: TVec<(usize, usize)> = tvec!();
                for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
                    .sorted_by_key(|(_, src, _)| *src)
                    .map(|(l, _, dst)| (*l as isize, *dst))
                {
                    // Merge this axis into the previous entry when the
                    // previous entry spans exactly up to this stride.
                    if !len_and_strides.is_empty()
                        && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
                            == stride as usize
                    {
                        len_and_strides.last_mut().unwrap().0 *= len as usize;
                    } else {
                        len_and_strides.push((len as usize, stride as usize));
                    }
                }
                len_and_strides.reverse();
                crate::scatter::scatter_contig_data(
                    it.as_ptr(),
                    t.as_ptr_mut_unchecked(),
                    &len_and_strides,
                );
                return t;
            }
            // finally use ndarray into_iter()
            t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
            t
        }
    }
1423
1424    pub fn deep_clone(&self) -> Tensor {
1425        unsafe {
1426            let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1427            if self.len() > 0 {
1428                if self.dt.is_copy() {
1429                    self.data.as_ptr().copy_to_nonoverlapping(
1430                        tensor.as_bytes_mut().as_mut_ptr(),
1431                        self.data.layout().size(),
1432                    )
1433                } else if self.dt == DatumType::String {
1434                    tensor
1435                        .as_slice_mut_unchecked::<String>()
1436                        .clone_from_slice(self.as_slice_unchecked());
1437                } else if self.dt == DatumType::Blob {
1438                    tensor
1439                        .as_slice_mut_unchecked::<Blob>()
1440                        .clone_from_slice(self.as_slice_unchecked());
1441                } else if self.dt == DatumType::Opaque {
1442                    tensor
1443                        .as_slice_mut_unchecked::<Opaque>()
1444                        .clone_from_slice(self.as_slice_unchecked());
1445                } else if self.dt == DatumType::TDim {
1446                    tensor
1447                        .as_slice_mut_unchecked::<TDim>()
1448                        .clone_from_slice(self.as_slice_unchecked());
1449                }
1450            }
1451            tensor
1452        }
1453    }
1454
1455    pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1456        if axis >= self.rank() {
1457            bail!("Can not slice at axis {} tensor {:?}", axis, self);
1458        }
1459        if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1460            bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1461        }
1462        fn slice_t<T: Datum>(
1463            t: &Tensor,
1464            axis: usize,
1465            start: usize,
1466            end: usize,
1467        ) -> TractResult<Tensor> {
1468            Ok(t.to_array_view::<T>()?
1469                .slice_axis(ndarray::Axis(axis), (start..end).into())
1470                .into_owned()
1471                .into_tensor())
1472        }
1473        dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1474    }
1475
    /// Build a `TensorView` over this tensor, delegating to `TensorView::view`.
    #[inline]
    pub fn view(&self) -> view::TensorView {
        unsafe { view::TensorView::view(self) }
    }
1480
    /// Build a `TensorView` of this tensor for the given `prefix`, delegating
    /// to `TensorView::at_prefix`.
    // NOTE(review): `prefix` presumably fixes the coordinates of the leading axes; confirm in `view.rs`.
    #[inline]
    pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView> {
        view::TensorView::at_prefix(self, prefix)
    }
1485
    /// Build a `TensorView` of this tensor for the given `coords`, delegating
    /// to `TensorView::offsetting`.
    // NOTE(review): `coords` presumably designates the first element of the view; confirm in `view.rs`.
    #[inline]
    pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView> {
        view::TensorView::offsetting(self, coords)
    }
1490
    /// Infallible counterpart of `view_offsetting`, delegating to
    /// `TensorView::offsetting_unchecked`.
    ///
    /// # Safety
    ///
    /// The exact contract is defined by `TensorView::offsetting_unchecked`.
    // NOTE(review): presumably `coords` must be valid coordinates for this tensor; confirm in `view.rs`.
    #[inline]
    pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView {
        view::TensorView::offsetting_unchecked(self, coords)
    }
1495
    /// Same as `view`, but borrows the tensor mutably for the lifetime of the
    /// returned `TensorView`.
    #[inline]
    pub fn view_mut(&mut self) -> view::TensorView {
        unsafe { view::TensorView::view(self) }
    }
1500
    /// Same as `view_at_prefix`, but borrows the tensor mutably for the
    /// lifetime of the returned `TensorView`.
    #[inline]
    pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView> {
        view::TensorView::at_prefix(self, prefix)
    }
1505
    /// Same as `view_offsetting`, but borrows the tensor mutably for the
    /// lifetime of the returned `TensorView`.
    #[inline]
    pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView> {
        view::TensorView::offsetting(self, coords)
    }
1510
1511    /// Offsets the tensor as an i8 type if it's an u8 type, otherwise passes it unchanged.
1512    pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1513        let mut t = if let DatumType::U8 = self.dt.unquantized() {
1514            self.to_array_view::<u8>().unwrap().mapv(|v| v.wrapping_sub(128) as i8).into_tensor()
1515        } else {
1516            return self.clone();
1517        };
1518
1519        if let DatumType::QU8(qp) = self.dt {
1520            if let QParams::ZpScale { zero_point, scale } = qp {
1521                t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1522            } else {
1523                t.dt = DatumType::QI8(qp);
1524            }
1525        }
1526
1527        t.into_arc_tensor()
1528    }
1529
1530    /// Offsets the tensor as an u8 type if it's an i8 type, otherwise passes it unchanged.
1531    pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1532        let mut t = if let DatumType::I8 = self.dt.unquantized() {
1533            self.to_array_view::<i8>().unwrap().mapv(|v| (v as u8).wrapping_add(128)).into_tensor()
1534        } else {
1535            return self.clone();
1536        };
1537
1538        if let DatumType::QI8(qp) = self.dt {
1539            if let QParams::ZpScale { zero_point, scale } = qp {
1540                t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1541            } else {
1542                t.dt = DatumType::QU8(qp);
1543            }
1544        }
1545        t.into_arc_tensor()
1546    }
1547
1548    pub fn to_aligned_default(&self) -> TractResult<Self> {
1549        if self.dt.is_copy() {
1550            unsafe {
1551                let mut t = Self::uninitialized_aligned_dt(
1552                    self.dt,
1553                    &self.shape,
1554                    Self::default_alignment(self.dt, &self.shape),
1555                )?;
1556                t.as_bytes_mut().copy_from_slice(self.as_bytes());
1557                Ok(t)
1558            }
1559        } else {
1560            let mut t = Self::zero_aligned_dt(
1561                self.dt,
1562                &self.shape,
1563                Self::default_alignment(self.dt, &self.shape),
1564            )?;
1565            if self.dt == String::datum_type() {
1566                t.as_slice_mut::<String>()?.clone_from_slice(self.as_slice()?);
1567            } else if self.dt == Blob::datum_type() {
1568                t.as_slice_mut::<Blob>()?.clone_from_slice(self.as_slice()?);
1569            } else if self.dt == TDim::datum_type() {
1570                t.as_slice_mut::<TDim>()?.clone_from_slice(self.as_slice()?);
1571            }
1572            Ok(t)
1573        }
1574    }
1575
1576    pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1577        let mut strides = tvec!();
1578        compute_natural_stride_to(&mut strides, shape);
1579        strides
1580    }
1581
1582    pub fn into_blob(mut self) -> TractResult<Blob> {
1583        ensure!(self.dt.is_copy());
1584        Ok(std::mem::take(&mut self.data))
1585    }
1586}
1587
1588impl PartialEq for Tensor {
1589    fn eq(&self, other: &Tensor) -> bool {
1590        if self.dt != other.dt || self.shape != other.shape {
1591            return false;
1592        }
1593        self.eq_dt(other).unwrap_or(false)
1594    }
1595}
1596
1597impl fmt::Debug for Tensor {
1598    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1599        let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1600        write!(formatter, "{content}")
1601    }
1602}
1603
/// Reinterprets a tensor whose innermost dimension is 2 as a tensor of
/// complex values with that dimension removed.
///
/// The shape loses its last axis, the datum type becomes whatever
/// `complexify` returns for the current one, and strides and length are
/// recomputed.
///
/// # Errors
///
/// Fails when the last dimension is not 2 (including rank-0 tensors) or when
/// `complexify` rejects the datum type.
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    ensure!(
        matches!(t.shape().last(), Some(&2)),
        "The last dimension in the tensor shape {:?} must be 2",
        t.shape()
    );
    // SAFETY (review note): relies on the complex datum type spanning exactly the
    // two scalars of the dropped axis — confirm against `complexify`.
    unsafe {
        t.shape.pop();
        let complex_dt = t.datum_type().complexify()?;
        t.set_datum_type(complex_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1618
/// Reinterprets a complex tensor as a tensor of its component type with an
/// extra innermost dimension of size 2.
///
/// The shape gains a trailing axis of 2, the datum type becomes whatever
/// `decomplexify` returns for the current one, and strides and length are
/// recomputed.
///
/// # Errors
///
/// Fails when `decomplexify` rejects the tensor's datum type.
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    // SAFETY (review note): relies on each complex element being laid out as two
    // consecutive scalars of the decomplexified type — confirm against `decomplexify`.
    unsafe {
        t.shape.push(2);
        let scalar_dt = t.datum_type().decomplexify()?;
        t.set_datum_type(scalar_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1628
1629pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1630    let mut strides = tvec!();
1631    compute_natural_stride_to(&mut strides, shape);
1632    strides
1633}
1634
1635fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1636    match shape.len() {
1637        0 => (),
1638        1 => strides.push(1),
1639        2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1640        3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1641        4 => strides.extend_from_slice(&[
1642            (shape[1] * shape[2] * shape[3]) as isize,
1643            (shape[2] * shape[3]) as _,
1644            shape[3] as _,
1645            1,
1646        ]),
1647        _ => {
1648            strides.push(1);
1649            for dim in shape.as_ref().iter().skip(1).rev() {
1650                let previous = *strides.last().unwrap();
1651                strides.push(previous * *dim as isize)
1652            }
1653            strides.reverse();
1654        }
1655    }
1656}
1657
1658impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
1659    fn from(it: Array<T, D>) -> Tensor {
1660        Tensor::from_datum(it.into_dyn())
1661    }
1662}
1663
/// Convenient conversion to `Tensor`.
///
/// Implemented in this file for ndarray `Array`s, for `Tensor` itself and for
/// `Arc<Tensor>`.
pub trait IntoTensor: Sized {
    /// Converts `self` into a `Tensor`, consuming it.
    ///
    /// May perform a copy.
    fn into_tensor(self) -> Tensor;
}
1671
/// Convenient conversion to `Arc<Tensor>`.
///
/// Implemented in this file for ndarray `Array`s, for `Tensor` and for
/// `Arc<Tensor>` itself.
pub trait IntoArcTensor: Sized {
    /// Converts `self` into an `Arc<Tensor>`, consuming it.
    ///
    /// May perform a copy.
    fn into_arc_tensor(self) -> Arc<Tensor>;
}
1679
1680impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
1681    fn into_tensor(self) -> Tensor {
1682        Tensor::from(self)
1683    }
1684}
1685
1686impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
1687    fn into_arc_tensor(self) -> Arc<Tensor> {
1688        Arc::new(Tensor::from(self))
1689    }
1690}
1691
impl IntoTensor for Tensor {
    /// Identity conversion: returns the tensor as is.
    fn into_tensor(self) -> Tensor {
        self
    }
}
1697
1698impl IntoTensor for Arc<Tensor> {
1699    fn into_tensor(self) -> Tensor {
1700        Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
1701    }
1702}
1703
1704impl IntoArcTensor for Tensor {
1705    fn into_arc_tensor(self) -> Arc<Tensor> {
1706        Arc::new(self)
1707    }
1708}
1709
impl IntoArcTensor for Arc<Tensor> {
    /// Identity conversion: returns the same `Arc` without touching the tensor.
    fn into_arc_tensor(self) -> Arc<Tensor> {
        self
    }
}
1715
// Unit and property tests for tensor construction from ndarray, vector
// broadcasting, complex reinterpretation and cloning of `TDim` tensors.
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    use proptest::collection::vec;
    use proptest::prelude::*;

    // Test case: an array of the given `shape` whose axes are reordered by
    // `permutation` before being converted to a `Tensor`.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        // Picks a rank in 0..8, then `rank` dimensions in 1..5 together with a
        // shuffled copy of the identity permutation `0..rank`.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            (0..8usize)
                .prop_flat_map(|rank| {
                    let permute: Vec<usize> = (0..rank).collect();
                    (proptest::collection::vec(1..5usize, rank), Just(permute).prop_shuffle())
                })
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        // Array filled with 1, 2, 3, … by `from_shape_simple_fn`, then viewed
        // with its axes permuted.
        fn input(&self) -> ArrayD<i32> {
            let mut i = 0;
            ArrayD::from_shape_simple_fn(&*self.shape, || {
                i += 1;
                i
            })
            .permuted_axes(&*self.permutation)
        }

        // Expected tensor: the permuted array's elements, as yielded by its
        // iterator, collected into a 1-D tensor and reshaped to the permuted shape.
        fn reference(&self) -> Tensor {
            let values: Vec<i32> = self.input().iter().copied().collect();
            let shape = self.permutation.iter().map(|ix| self.shape[*ix]).collect::<TVec<usize>>();
            super::litteral::tensor1(&values).into_shape(&shape).unwrap()
        }

        // Tensor under test: the direct `From<Array>` conversion of the permuted array.
        fn tract(&self) -> Tensor {
            Tensor::from(self.input())
        }

        // The direct conversion must equal the reference tensor.
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    proptest::proptest! {
        // Property: conversion from a permuted array matches the reference.
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    // Fixed case: shape [2, 1] with its two axes swapped.
    #[test]
    fn t_1_2() {
        PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] }.check().unwrap();
    }

    // Fixed case: shape [2, 2] with its two axes swapped.
    #[test]
    fn t_2_2() {
        PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
    }

    // Test case: broadcast the 1-D tensor `vec`, placed along `axis`, to `shape`.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        // `broadcast_vector_to_shape` must agree with reshaping the vector to an
        // all-ones shape (except `axis`) and calling `broadcast_to_shape`.
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reference = input
                .clone()
                .into_shape(&intermediate)
                .unwrap()
                .broadcast_to_shape(&self.shape)
                .unwrap();
            prop_assert_eq!(
                reference,
                input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap()
            );
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        // Draws a base shape (length in 0..4, dimensions in 0..5), a vector of
        // length in 0..5 and an insertion position, then inserts the vector's
        // length into the shape at that position.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1)
                })
                .prop_map(|(vec, mut shape, axis)| {
                    shape.insert(axis, vec.len());
                    BroadcastVecToShape { vec, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    proptest::proptest! {
        // Property: both broadcasting routes produce the same tensor.
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    // A [3, 2] f32 tensor becomes a length-3 tensor of complex values.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let input = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // A [3, 2, 2] i32 tensor becomes a [3, 2] tensor of complex values.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let input =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Smoke test: cloning a scalar tensor holding a symbolic `TDim` must not panic.
    #[test]
    fn clone_tdim_tensor() {
        let symbols = SymbolScope::default();
        let a = symbols.sym("a");
        let t = tensor0(TDim::from(a));
        let _ = t.clone();
    }
}