Skip to main content

tract_data/
tensor.rs

1//! `Tensor`, tract main data object of interest.
2use crate::TVec;
3use crate::blob::Blob;
4use crate::datum::{ClampCast, Datum, DatumType, QParams, round_ties_to_even, scale_by};
5use crate::dim::TDim;
6use crate::internal::*;
7use crate::opaque::Opaque;
8use half::f16;
9use itertools::{Itertools, izip};
10use ndarray::prelude::*;
11#[cfg(feature = "complex")]
12use num_complex::Complex;
13use num_traits::{Float, Zero};
14use std::borrow::Cow;
15use std::fmt;
16use std::hash::Hash;
17use std::ops::Range;
18use std::sync::Arc;
19
20pub mod litteral;
21pub mod view;
22
/// Tolerance level used when comparing two tensors for approximate equality.
///
/// Each level is resolved by `atol_rtol_outliers` into an absolute tolerance, a relative
/// tolerance and an accepted ratio of outliers, depending on the datum type compared.
#[derive(Copy, Clone, Default, Debug)]
pub enum Approximation {
    /// All three tolerances are zero.
    Exact,
    /// The default level.
    #[default]
    Close,
    Approximate,
    VeryApproximate,
    SuperApproximate,
    UltraApproximate,
    /// Explicit `(absolute tolerance, relative tolerance, outlier ratio)`.
    Custom(f32, f32, f32),
}
34
35impl PartialEq for Approximation {
36    fn eq(&self, other: &Self) -> bool {
37        use Approximation::Custom;
38        if let (Custom(aa, ar, ao), Custom(ba, br, bo)) = (self, other) {
39            aa == ba && ar == br && bo == ao
40        } else {
41            std::mem::discriminant(self) == std::mem::discriminant(other)
42        }
43    }
44}
45
// NOTE(review): `Custom` carries `f32` fields compared with `==`, so a `Custom` holding a NaN
// is not equal to itself, which `Eq` is supposed to guarantee — confirm this is acceptable.
impl Eq for Approximation {}
47
48impl From<bool> for Approximation {
49    fn from(b: bool) -> Self {
50        if b { Self::Approximate } else { Self::Exact }
51    }
52}
53
impl Approximation {
    /// Resolve this level, for datum type `dt`, into a
    /// `(absolute tolerance, relative tolerance, accepted outlier ratio)` triple.
    fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
        use Approximation::*;
        // Arms are tried top to bottom: the F16 and quantized special cases must stay above
        // the catch-all arm of the same level.
        match (self, dt) {
            (Exact, _) => (0.0, 0.0, 0.0),
            (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
            (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
            // Quantized types: the absolute tolerance is the second component of `zp_scale()`
            // (the scale, going by the name), with no relative tolerance.
            (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
            (Close, _) => (1e-7, 1e-7, 0.0),
            (Approximate, _) => (1e-4, 5e-4, 0.0),
            (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
            // Only the two loosest presets accept a non-zero ratio of outliers.
            (SuperApproximate, _) => (0.1, 0.05, 0.0001),
            (UltraApproximate, _) => (0.2, 0.1, 0.0005),
            (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
        }
    }
}
71
/// Tensor is a concrete tensor in tract.
///
/// It couples a datum type and a geometry (shape, strides, element count) with a raw byte
/// buffer holding the elements.
#[derive(Eq)]
pub struct Tensor {
    // Type of the elements stored in `data`.
    dt: DatumType,
    // Size of each axis; empty for a scalar.
    shape: TVec<usize>,
    // Stride of each axis, counted in elements.
    strides: TVec<isize>,
    // Number of elements (1 for a scalar).
    len: usize,
    // Raw storage for the elements.
    data: Blob,
}
81
// NOTE(review): these manual impls assert that a tensor can be sent to and shared between
// threads whatever it holds, including the raw `Blob` storage and non-copy payloads such as
// `Opaque` — confirm every storable element type upholds this.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
84
impl Hash for Tensor {
    /// Hash the datum type, the shape, the alignment of the storage, then every element.
    ///
    /// Floating point elements are hashed through the integer type of the same width (floats
    /// do not implement `Hash`), and quantized elements through their raw integer type.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use DatumType::*;
        self.dt.hash(state);
        self.shape.hash(state);
        // NOTE(review): the storage alignment takes part in the hash; verify that equality
        // also distinguishes tensors on alignment, or equal tensors may hash differently.
        self.data.layout().align().hash(state);
        unsafe {
            match self.dt {
                Bool => self.as_slice_unchecked::<bool>().hash(state),
                I8 => self.as_slice_unchecked::<i8>().hash(state),
                I16 => self.as_slice_unchecked::<i16>().hash(state),
                I32 => self.as_slice_unchecked::<i32>().hash(state),
                I64 => self.as_slice_unchecked::<i64>().hash(state),
                U8 => self.as_slice_unchecked::<u8>().hash(state),
                U16 => self.as_slice_unchecked::<u16>().hash(state),
                U32 => self.as_slice_unchecked::<u32>().hash(state),
                U64 => self.as_slice_unchecked::<u64>().hash(state),
                // Floats: reinterpret as same-size integers.
                F16 => self.as_slice_unchecked::<i16>().hash(state),
                F32 => self.as_slice_unchecked::<i32>().hash(state),
                F64 => self.as_slice_unchecked::<i64>().hash(state),
                TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
                String => self.as_slice_unchecked::<std::string::String>().hash(state),
                Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
                Opaque => self.as_slice_unchecked::<crate::opaque::Opaque>().hash(state),
                // Quantized: hash the stored integers (the parameters are covered by `dt`).
                QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
                QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
                QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
                // Complex floats: same reinterpretation as the real float types.
                #[cfg(feature = "complex")]
                ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
            }
        }
    }
}
128
impl Clone for Tensor {
    /// Cloning a tensor delegates to `deep_clone`.
    fn clone(&self) -> Tensor {
        self.deep_clone()
    }
}
134
impl Default for Tensor {
    /// The default tensor is the one built by `litteral::tensor0(0f32)`.
    fn default() -> Tensor {
        litteral::tensor0(0f32)
    }
}
140
impl Drop for Tensor {
    /// Run the destructor of every element when the tensor holds one of the datum types
    /// listed below (`Blob`, `String`, `TDim`, `Opaque`). Other types need no per-element work.
    fn drop(&mut self) {
        // Expands to: if the tensor holds `$t` elements, drop all of them in place.
        macro_rules! drop_in_place {
            ($t: ty) => {
                if self.dt == <$t>::datum_type() {
                    unsafe {
                        // The datum type was just checked, so the typed access cannot fail.
                        let slice = self.as_slice_mut::<$t>().unwrap();
                        std::ptr::drop_in_place(slice as *mut [$t]);
                    }
                }
            };
        }
        drop_in_place!(Blob);
        drop_in_place!(String);
        drop_in_place!(TDim);
        drop_in_place!(Opaque);
    }
}
159
/// Default alignment, in bytes, used for tensor storage.
///
/// On x86_64 this is 64 when AVX-512F is detected at runtime and 32 otherwise; on every
/// other architecture it is 16.
// The trailing expression is unreachable on x86_64, where the cfg'd block always returns.
#[allow(unreachable_code)]
pub fn vector_size() -> usize {
    const BITS_PER_BYTE: usize = 8;
    #[cfg(target_arch = "x86_64")]
    {
        let bits = if is_x86_feature_detected!("avx512f") { 512 } else { 256 };
        return bits / BITS_PER_BYTE;
    }
    128 / BITS_PER_BYTE
}
168
169impl Tensor {
    /// Create an uninitialized tensor (dt as type parameter).
    ///
    /// The storage is aligned on `vector_size()` bytes.
    ///
    /// # Safety
    ///
    /// NOTE(review): elements of plain-data types are presumably not initialized and must be
    /// written before being read — confirm against `Blob::new_for_size_and_align`.
    #[inline]
    pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
        unsafe { Self::uninitialized_dt(T::datum_type(), shape) }
    }
175
    /// Create an uninitialized tensor (dt as regular parameter).
    ///
    /// The storage is aligned on `vector_size()` bytes.
    ///
    /// # Safety
    ///
    /// Same contract as `uninitialized_aligned_dt`, which this forwards to.
    #[inline]
    pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
        unsafe { Self::uninitialized_aligned_dt(dt, shape, vector_size()) }
    }
181
    /// Create an uninitialized tensor with a given alignment (in bytes), the datum type
    /// being given as a type parameter.
    ///
    /// # Safety
    ///
    /// Same contract as `uninitialized_aligned_dt`, which this forwards to.
    #[inline]
    pub unsafe fn uninitialized_aligned<T: Datum>(
        shape: &[usize],
        alignment: usize,
    ) -> TractResult<Tensor> {
        unsafe { Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment) }
    }
190
    /// Create an uninitialized tensor with a given alignment (in bytes).
    ///
    /// Despite the name, some types are given valid values: `String` and `Blob` storage is
    /// zero-filled, `TDim` elements are set to zero and `Opaque` elements to their default.
    /// For all other types, debug builds fill the storage with a recognizable pattern (NaN
    /// for `F32`, `0xFF` bytes otherwise) while release builds leave it as
    /// `Blob::new_for_size_and_align` returned it.
    ///
    /// # Safety
    ///
    /// NOTE(review): for plain-data types the content is presumably uninitialized in release
    /// builds; callers must write every element before reading it.
    pub unsafe fn uninitialized_aligned_dt(
        dt: DatumType,
        shape: &[usize],
        alignment: usize,
    ) -> TractResult<Tensor> {
        // Storage size in bytes: number of elements times element size.
        let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
        let data = unsafe { Blob::new_for_size_and_align(bytes, alignment) };
        let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), data, len: 0 };
        if tensor.shape.len() == 0 {
            // Scalar: one element, no strides.
            tensor.len = 1;
        } else {
            tensor.update_strides_and_len();
        }
        if !tensor.data.is_empty() {
            if dt == String::datum_type() || dt == Blob::datum_type() {
                // assumes zero-initialized string and blob are valid
                tensor.data.fill(0);
            } else if dt == TDim::datum_type() {
                // `ptr::write` does not drop the previous (garbage) content of the slot.
                unsafe {
                    tensor
                        .as_slice_mut_unchecked::<TDim>()
                        .iter_mut()
                        .for_each(|dim| std::ptr::write(dim, TDim::zero()))
                }
            } else if dt == Opaque::datum_type() {
                unsafe {
                    tensor.as_slice_mut_unchecked::<Opaque>().iter_mut().for_each(|p| {
                        std::ptr::write(p, Opaque::default());
                    })
                };
            } else if cfg!(debug_assertions) {
                // Debug builds only: poison the storage so that reads of unwritten elements
                // stand out.
                assert!(dt.is_copy());
                if dt == DatumType::F32 {
                    tensor.fill_t(f32::NAN).unwrap();
                } else {
                    // safe, non copy types have been dealt with
                    tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
                }
            }
        }
        Ok(tensor)
    }
234
    /// Concatenate `tensors` along the existing axis `axis`.
    ///
    /// All tensors must have the same rank and datum type, and the same dimensions on every
    /// axis but `axis`. The result's dimension on `axis` is the sum of the inputs' dimensions
    /// on that axis.
    ///
    /// # Errors
    ///
    /// Fails if `tensors` is empty, if `axis` is not lower than the rank, or if the inputs
    /// disagree on rank, datum type or any dimension other than `axis`.
    pub fn stack_tensors(
        axis: usize,
        tensors: &[impl std::borrow::Borrow<Tensor>],
    ) -> TractResult<Tensor> {
        ensure!(tensors.len() > 0);
        let rank = tensors[0].borrow().rank();
        ensure!(axis < rank);
        ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
        let dt = tensors[0].borrow().datum_type();
        ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
        let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
        for ax in 0..rank {
            if ax != axis {
                ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
            }
        }
        shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
        unsafe {
            let mut result = Tensor::uninitialized_dt(dt, &shape)?;
            if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
                // Fast path: with only 1-sized axes before `axis`, the inputs' storage is
                // copied into the output one after another, as raw bytes.
                let mut offset = 0isize;
                for v in tensors {
                    let v = v.borrow();
                    let len = v.data.len();
                    std::ptr::copy_nonoverlapping(
                        v.data.as_ptr(),
                        result.data.as_mut_ptr().offset(offset),
                        len,
                    );
                    offset += len as isize;
                }
            } else {
                // General path: assign each input to its slice of the output along `axis`.
                let mut offset = 0;
                for t in tensors {
                    let t = t.borrow();
                    let len = t.shape()[axis];
                    result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
                    offset += len;
                }
            }

            Ok(result)
        }
    }
279
    /// Set every element of the tensor to `T::zero()`.
    ///
    /// Fails if `fill_t` rejects `T` as the element type of this tensor.
    pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
        self.fill_t(T::zero())
    }
283
284    pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
285        unsafe {
286            let mut t = Tensor::uninitialized::<T>(shape)?;
287            t.clear::<T>()?;
288            Ok(t)
289        }
290    }
291
    /// Create a rank-0 tensor holding `T::zero()`.
    pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
        Tensor::zero::<T>(&[])
    }
295
    /// Create a rank-0 zero tensor, the datum type being given as a regular parameter.
    pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
        Tensor::zero_dt(dt, &[])
    }
299
    /// Create a zero tensor of the given datum type and shape, aligned on `vector_size()`
    /// bytes.
    pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
        Tensor::zero_aligned_dt(dt, shape, vector_size())
    }
303
304    pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
305        self.as_slice_mut::<T>()?.iter_mut().for_each(|item| *item = value.clone());
306        Ok(())
307    }
308
309    pub fn zero_aligned_dt(
310        dt: DatumType,
311        shape: &[usize],
312        alignment: usize,
313    ) -> TractResult<Tensor> {
314        if shape.iter().product::<usize>() == 0 {
315            unsafe { return Tensor::uninitialized_dt(dt, shape) };
316        }
317        if dt.is_quantized() {
318            unsafe {
319                let mut t = Tensor::uninitialized_dt(dt, shape)?;
320                let zp = dt.zp_scale().0;
321                match dt.unquantized() {
322                    DatumType::I8 => {
323                        t.as_slice_mut::<i8>()?.iter_mut().for_each(|item| *item = zp as _)
324                    }
325                    DatumType::U8 => {
326                        t.as_slice_mut::<u8>()?.iter_mut().for_each(|item| *item = zp as _)
327                    }
328                    DatumType::I32 => {
329                        t.as_slice_mut::<i32>()?.iter_mut().for_each(|item| *item = zp as _)
330                    }
331                    _ => unreachable!(),
332                }
333                Ok(t)
334            }
335        } else {
336            dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
337        }
338    }
339
340    pub fn zero_aligned<T: Datum + num_traits::Zero>(
341        shape: &[usize],
342        alignment: usize,
343    ) -> TractResult<Tensor> {
344        unsafe {
345            let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
346            tensor.clear::<T>()?;
347            Ok(tensor)
348        }
349    }
350
    /// Create a tensor with a given shape and a slice of elements.
    /// The data is copied, and the storage is aligned on `vector_size()` bytes.
    ///
    /// Fails if the product of `shape` differs from `data.len()`.
    pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
        Self::from_shape_align(shape, data, vector_size())
    }
356
357    /// Create a tensor with a given shape and a slice of elements.
358    /// The data is copied and aligned to given alignment.
359    pub fn from_shape_align<T: Datum + Copy>(
360        shape: &[usize],
361        data: &[T],
362        align: usize,
363    ) -> TractResult<Tensor> {
364        ensure!(
365            data.len() == shape.iter().product::<usize>(),
366            "Shape product must be equal to data length"
367        );
368        unsafe {
369            let bytes = std::slice::from_raw_parts(
370                data.as_ptr() as *const u8,
371                data.len() * T::datum_type().size_of(),
372            );
373            let dt = T::datum_type();
374            Self::from_raw_dt_align(dt, shape, bytes, align)
375        }
376    }
377
    /// Create a tensor from raw data.
    ///
    /// It copies the data into storage aligned on `vector_size()` bytes.
    ///
    /// # Safety
    ///
    /// `content` is copied verbatim as the storage of a tensor of `T`: the caller must make
    /// sure these bytes are valid `T` values.
    pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
        unsafe { Tensor::from_raw_dt(T::datum_type(), shape, content) }
    }
384
    /// Create a tensor from raw data, copying it into storage aligned on `align` bytes.
    ///
    /// # Safety
    ///
    /// `content` is copied verbatim as the storage of a tensor of `T`: the caller must make
    /// sure these bytes are valid `T` values.
    pub unsafe fn from_raw_aligned<T: Datum>(
        shape: &[usize],
        content: &[u8],
        align: usize,
    ) -> TractResult<Tensor> {
        unsafe { Tensor::from_raw_dt_align(T::datum_type(), shape, content, align) }
    }
392
    /// Create a tensor from raw data, the datum type being given as a regular parameter.
    /// The data is copied into storage aligned on `vector_size()` bytes.
    ///
    /// # Safety
    ///
    /// `content` is copied verbatim as the storage of a tensor of type `dt`: the caller must
    /// make sure these bytes are valid values for `dt`.
    pub unsafe fn from_raw_dt(
        dt: DatumType,
        shape: &[usize],
        content: &[u8],
    ) -> TractResult<Tensor> {
        unsafe { Self::from_raw_dt_align(dt, shape, content, vector_size()) }
    }
400
401    pub unsafe fn from_raw_dt_align(
402        dt: DatumType,
403        shape: &[usize],
404        content: &[u8],
405        align: usize,
406    ) -> TractResult<Tensor> {
407        let mut tensor = unsafe { Tensor::uninitialized_aligned_dt(dt, shape, align) }?;
408        tensor.as_bytes_mut().copy_from_slice(content);
409        Ok(tensor)
410    }
411
412    pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
413        let bytes = if content.len() == 0 {
414            &[]
415        } else {
416            unsafe {
417                std::slice::from_raw_parts(
418                    content.as_ptr() as *const u8,
419                    content.len() * T::datum_type().size_of(),
420                )
421            }
422        };
423        unsafe { Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align) }
424    }
425
    /// Get the number of dimensions (or axes) of the tensor.
    ///
    /// A scalar tensor has rank 0.
    #[inline]
    pub fn rank(&self) -> usize {
        self.shape.len()
    }
431
    /// Get the shape of the tensor: the size of each axis, outermost first.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
437
    /// Get the number of values in the tensor.
    ///
    /// This is 1 for a scalar (rank 0) tensor.
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.len
    }
444
    /// Get the number of values in the tensor (same value as `len()`).
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn volume(&self) -> usize {
        self.len
    }
451
    /// Get the strides of the tensor, one per axis.
    #[inline]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }
457
    /// Recompute `strides` and `len` from the current `shape`.
    ///
    /// A scalar gets no strides and a length of 1. Otherwise strides come from
    /// `compute_natural_stride_to` and the length is the outermost stride times the
    /// outermost dimension.
    fn update_strides_and_len(&mut self) {
        self.strides.clear();
        if self.shape.len() == 0 {
            self.len = 1;
            return;
        }
        compute_natural_stride_to(&mut self.strides, &self.shape);
        // Index 0 exists in `shape`, as the empty case returned above.
        // NOTE(review): relies on `compute_natural_stride_to` producing one stride per axis.
        self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
    }
467
    /// Force the tensor shape, no consistency check.
    ///
    /// Strides and length are recomputed from the new shape. Nothing happens if `shape` is
    /// already the shape of the tensor.
    ///
    /// # Safety
    ///
    /// The storage is left untouched: the caller must make sure it is large enough for the
    /// new shape.
    pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
        if shape != &*self.shape {
            self.shape.clear();
            self.shape.extend_from_slice(shape);
            self.update_strides_and_len();
        }
    }
476
    /// Force the tensor shape and strides, no consistency check.
    ///
    /// Unlike `set_shape_unchecked`, the length of the tensor is not recomputed here.
    ///
    /// # Safety
    ///
    /// The caller must make sure `shape` and `strides` are consistent with each other, with
    /// the current length and with the storage.
    pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
        self.shape.clear();
        self.shape.extend_from_slice(shape);
        self.strides.clear();
        self.strides.extend_from_slice(strides);
    }
484
485    /// Force the tensor shape.
486    pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
487        if self.len() != shape.iter().product::<usize>() {
488            bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
489        }
490        unsafe { self.set_shape_unchecked(shape) }
491        Ok(())
492    }
493
494    pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
495        ensure!(axes.iter().duplicates().next().is_none());
496        ensure!(axes.iter().all(|a| *a < self.rank()));
497        unsafe {
498            #[inline]
499            unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
500                unsafe { input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor() }
501            }
502            let dt = self.datum_type();
503            let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
504            t.set_datum_type(dt);
505            Ok(t)
506        }
507    }
508
509    pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
510        let mut permutation: Vec<usize> = (0..self.rank()).collect();
511        permutation.remove(from);
512        permutation.insert(to, from);
513        self.permute_axes(&permutation)
514    }
515
    /// Merge axis `axis` with the one right after it: the two dimensions are replaced by a
    /// single one holding their product.
    ///
    /// # Panics
    ///
    /// NOTE(review): presumably panics if `axis + 1` is not a valid axis — confirm `TVec`
    /// behavior on out-of-range `remove`.
    pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
        let removed = self.shape.remove(axis + 1);
        self.shape[axis] *= removed;
        self.update_strides_and_len();
        self
    }
522
523    pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
524        if self.shape[axis] % outer_dim != 0 {
525            bail!(
526                "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
527                self.shape,
528                axis,
529                outer_dim
530            );
531        }
532        self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
533        self.shape[axis] = outer_dim;
534        self.update_strides_and_len();
535        Ok(self)
536    }
537
    /// Reshape the tensor to `shape`.
    ///
    /// Fails, as `set_shape` does, if `shape` does not hold the same number of elements.
    pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
        self.set_shape(shape)?;
        Ok(self)
    }
543
544    pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
545        self.shape.insert(axis, 1);
546        self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
547        Ok(())
548    }
549
550    pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
551        ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
552        self.shape.remove(axis);
553        self.strides.remove(axis);
554        Ok(())
555    }
556
557    pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
558        self.broadcast_to_rank(rank)?;
559        self.update_strides_and_len();
560        Ok(self)
561    }
562
563    pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
564        if rank < self.rank() {
565            bail!("Can only broadcast to higher rank")
566        }
567        while self.shape.len() < rank {
568            self.shape.insert(0, 1)
569        }
570        self.update_strides_and_len();
571        Ok(())
572    }
573
574    pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
575        if self.rank() > 0 {
576            bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
577        }
578        unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
579            unsafe {
580                let value: &T = src.to_scalar_unchecked::<T>();
581                dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone())
582            };
583        }
584        unsafe {
585            let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
586            dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
587            Ok(t)
588        }
589    }
590
    /// Typed implementation of `broadcast_to_shape`: broadcast through an ndarray view,
    /// then copy the result into a new tensor.
    ///
    /// Fails if ndarray can not broadcast the tensor to `shape`.
    fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
        unsafe {
            let view = self.to_array_view_unchecked::<T>();
            let mut output = view
                .broadcast(shape)
                .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
                .into_owned()
                .into_tensor();
            // Put back the exact datum type of the source on the result.
            output.set_datum_type(self.datum_type());
            Ok(output)
        }
    }
603
    /// Create a new tensor by broadcasting this one to `shape`.
    ///
    /// Fails if the tensor can not be broadcast to `shape`.
    pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
        dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
    }
607
608    pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
609        ensure!(self.rank() == 1);
610        ensure!(shape[axis] == self.len());
611        if !self.datum_type().is_copy() {
612            let mut vec_shape = vec![1; shape.len()];
613            vec_shape[axis] = self.len();
614            return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
615        }
616        unsafe {
617            let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
618            if output.len() == 0 {
619                return Ok(output);
620            }
621            let inner_len = shape[axis + 1..].iter().product::<usize>();
622
623            unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
624            where
625                T: Datum + Copy,
626            {
627                unsafe {
628                    for ix in 0..input.len() {
629                        let value: T = input.as_slice_unchecked()[ix];
630                        output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
631                            .iter_mut()
632                            .for_each(|item| *item = value);
633                    }
634                }
635            }
636            dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));
637
638            let outer_len = shape[0..axis].iter().product::<usize>();
639            let repeat_bytes_len = inner_len * self.as_bytes().len();
640            let bytes = output.as_bytes_mut();
641            for ix in 1..outer_len {
642                bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
643            }
644
645            Ok(output)
646        }
647    }
648
    /// Copy the slice `src_range` of `src` along `axis` into the slice `range` of this
    /// tensor along the same axis.
    ///
    /// Both ranges are first resolved by `clip_range_bounds` against the dimension of
    /// `axis` in their respective tensor.
    ///
    /// # Errors
    ///
    /// Fails if the tensors differ in rank or datum type, if `axis` is not a valid axis, if
    /// the two ranges differ in length or reach past the end of their tensor, or if the
    /// tensors disagree on a dimension other than `axis`.
    pub fn assign_slice(
        &mut self,
        range: impl std::ops::RangeBounds<usize>,
        src: &Tensor,
        src_range: impl std::ops::RangeBounds<usize>,
        axis: usize,
    ) -> TractResult<()> {
        ensure!(self.rank() == src.rank());
        ensure!(axis < self.rank());
        let range = clip_range_bounds(self.shape[axis], range);
        let src_range = clip_range_bounds(src.shape[axis], src_range);
        ensure!(
            src.datum_type() == self.datum_type(),
            "Attempt to assign into {:?} from {:?}, datum type mismatch",
            self.datum_type(),
            src.datum_type()
        );
        ensure!(
            src_range.len() == range.len(),
            "Attempt to assign a range of {:?} from a range of {:?}",
            range,
            src_range,
        );
        // Every axis but `axis` must have the same dimension on both sides.
        ensure!(
            itertools::izip!(0.., self.shape(), src.shape())
                .all(|(ix, dst, src)| ix == axis || src == dst),
            "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
            axis,
            self,
            src
        );
        ensure!(
            src_range.end <= src.shape()[axis],
            "Assigning from invalid slice (axis {}, {:?}) of {:?}",
            axis,
            src_range,
            src
        );
        ensure!(
            range.end <= self.shape()[axis],
            "Assigning to invalid slice (axis {}, {:?}) of {:?}",
            axis,
            range,
            self
        );
        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
        Ok(())
    }
697
    /// Same as `assign_slice`, without any of its consistency checks.
    ///
    /// # Safety
    ///
    /// The caller must uphold what `assign_slice` checks: same rank and datum type, valid
    /// `axis`, ranges of the same length within bounds, same dimensions on all other axes.
    pub unsafe fn assign_slice_unchecked(
        &mut self,
        range: impl std::ops::RangeBounds<usize>,
        src: &Tensor,
        src_range: impl std::ops::RangeBounds<usize>,
        axis: usize,
    ) {
        let range = clip_range_bounds(self.shape[axis], range);
        let src_range = clip_range_bounds(src.shape[axis], src_range);
        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
    }
709
    /// Copy `src_range` of `src` along `axis` into `range` of this tensor, both ranges
    /// being already resolved to plain `Range`s.
    ///
    /// # Safety
    ///
    /// No check is performed here: callers are responsible for the consistency of datum
    /// types, ranges and shapes (see `assign_slice`).
    #[allow(clippy::ptr_eq)]
    unsafe fn assign_slice_from_resolved(
        &mut self,
        range: std::ops::Range<usize>,
        src: &Tensor,
        src_range: std::ops::Range<usize>,
        axis: usize,
    ) {
        unsafe {
            use ndarray::Slice;
            // General path: element-wise assignment through ndarray views.
            unsafe fn assign_slice_t<T: Datum>(
                to: &mut Tensor,
                to_range: Range<usize>,
                from: &Tensor,
                from_range: Range<usize>,
                axis: usize,
            ) {
                unsafe {
                    to.to_array_view_mut_unchecked::<T>()
                        .slice_axis_mut(Axis(axis), Slice::from(to_range))
                        .assign(
                            &from
                                .to_array_view_unchecked::<T>()
                                .slice_axis(Axis(axis), Slice::from(from_range)),
                        )
                }
            }
            if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
                // Fast path: copy elements with only 1-sized axes before `axis`, done as a
                // single byte copy.
                // NOTE(review): the source offset is computed with this tensor's stride,
                // which relies on both tensors having the same dimensions after `axis`.
                let stride = self.strides[axis] as usize * self.datum_type().size_of();
                let dst_start = (stride * range.start) as isize;
                let src_start = (stride * src_range.start) as isize;
                let len = stride * range.len();
                if len > 0 {
                    if self.data.as_ptr() != src.data.as_ptr() {
                        std::ptr::copy_nonoverlapping(
                            src.data.as_ptr().offset(src_start),
                            self.data.as_mut_ptr().offset(dst_start),
                            len,
                        );
                    } else {
                        // Same storage on both sides: the ranges may overlap, so use the
                        // overlap-tolerant copy.
                        std::ptr::copy(
                            src.data.as_ptr().offset(src_start),
                            self.data.as_mut_ptr().offset(dst_start),
                            len,
                        );
                    }
                }
            } else {
                dispatch_datum!(assign_slice_t(self.datum_type())(
                    self, range, src, src_range, axis
                ));
            }
        }
    }
764
    /// Get the datum type of the tensor, that is the type of its elements.
    #[inline]
    pub fn datum_type(&self) -> DatumType {
        self.dt
    }
770
    /// Set the datum type of the tensor.
    ///
    /// Only the type tag changes: shape, strides, length and storage are left as they are.
    ///
    /// # Safety
    ///
    /// The caller must make sure the existing storage is valid for the new datum type.
    #[inline]
    pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
        self.dt = dt
    }
776
    /// Dump the tensor in a human readable form.
    ///
    /// The output starts with the comma-separated shape and the datum type, followed by the
    /// values. Only the first 12 values are shown, followed by `...`, when the tensor holds
    /// more than 12.
    ///
    /// `force_full` will force the tensor to be dumped in full even if it is big.
    pub fn dump(&self, force_full: bool) -> TractResult<String> {
        // Format the first `n` values of `tensor`, comma-separated.
        unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
            unsafe {
                if let Some(qp) = tensor.datum_type().qparams() {
                    // Quantized: show each stored integer in brackets, then the result of
                    // `qp.dq` on it (presumably the dequantized value) in parentheses.
                    let integers = tensor.cast_to::<i32>().unwrap();
                    integers.as_slice_unchecked::<i32>()[0..n]
                        .iter()
                        .map(|x| format!("[{}]({})", x, qp.dq(*x)))
                        .join(", ")
                } else {
                    tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
                }
            }
        }
        unsafe {
            let trunc = self.len() > 12 && !force_full;
            let data = dispatch_datum!(dump_t(self.datum_type())(
                self,
                if trunc { 12 } else { self.len() }
            ));
            Ok(format!(
                "{},{:?} {}{}",
                self.shape.iter().join(","),
                self.dt,
                data,
                if trunc { "..." } else { "" }
            ))
        }
    }
809
    /// Compare two tensors, allowing for rounding errors.
    ///
    /// Both tensors are cast to `f32` and compared element by element. A pair `(a, b)` is
    /// an outlier unless both are NaN, both are infinite with the same sign, or
    /// `|a - b| <= atol + rtol * |b|`. The tolerances come from `approx` and the datum type
    /// of `self`.
    ///
    /// # Errors
    ///
    /// Fails if the shapes differ, if a tensor can not be cast to `f32`, or if the ratio of
    /// outliers is strictly above the ratio accepted by `approx`.
    pub fn close_enough(
        &self,
        other: &Self,
        approx: impl Into<Approximation> + std::fmt::Debug,
    ) -> TractResult<()> {
        let approx = approx.into();
        if self.shape() != other.shape() {
            bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
        }
        let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
        let ma = self.cast_to::<f32>()?;
        let ma = ma.to_array_view::<f32>()?;
        let mb = other.cast_to::<f32>()?;
        let mb = mb.to_array_view::<f32>()?;
        // Coordinates of the first outlier, kept for the error message.
        let mut first_outlier = None;
        let mut outliers_count = 0;
        ndarray::indices_of(&ma).into_iter().for_each(|indices| {
            let a = ma[&indices];
            let b = mb[&indices];
            if !((a.is_nan() && b.is_nan())
                || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
                || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
            {
                if outliers_count == 0 {
                    first_outlier = Some(indices.as_array_view().to_vec());
                }
                outliers_count += 1;
            }
        });
        // With an accepted ratio of zero, a single outlier is enough to fail.
        if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
            // At least one outlier was counted, so its coordinates were recorded.
            let indices = first_outlier.unwrap();
            let a = ma[&*indices];
            let b = mb[&*indices];
            bail!(
                "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
                approx,
                self.datum_type(),
                indices,
                a,
                b,
                outliers_count,
                self.volume(),
                outliers_count as f64 / self.volume() as f64,
                outliers
            );
        }
        Ok(())
    }
859
860    /// Transform the tensor into a `ndarray::Array`.
861    pub fn into_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
862        Ok(self.to_array_view::<D>()?.to_owned())
863    }
864
865    /// Transform the tensor into a `ndarray::Array`.
866    pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
867        unsafe { self.to_array_view_unchecked::<D>().to_owned() }
868    }
869
870    fn check_for_access<D: Datum>(&self) -> TractResult<()> {
871        ensure!(
872            self.datum_type().unquantized() == D::datum_type().unquantized(),
873            "Tensor datum type error: tensor is {:?}, accessed as {:?}",
874            self.datum_type(),
875            D::datum_type(),
876        );
877        Ok(())
878    }
879
880    /// Transform the data as a `ndarray::Array`.
881    pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<'_, D>> {
882        self.check_for_access::<D>()?;
883        unsafe { Ok(self.to_array_view_unchecked()) }
884    }
885
886    /// Transform the data as a mutable `ndarray::Array`.
887    pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<'_, D>> {
888        self.check_for_access::<D>()?;
889        unsafe { Ok(self.to_array_view_mut_unchecked()) }
890    }
891
892    /// Transform the data as a `ndarray::Array`.
893    pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<'_, D> {
894        if self.len() != 0 {
895            unsafe { ArrayViewD::from_shape_ptr(&*self.shape, self.data.as_ptr() as *const D) }
896        } else {
897            ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
898        }
899    }
900
901    /// Transform the data as a mutable `ndarray::Array`.
902    pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<'_, D> {
903        if self.len() != 0 {
904            unsafe { ArrayViewMutD::from_shape_ptr(&*self.shape, self.data.as_mut_ptr() as *mut D) }
905        } else {
906            ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
907        }
908    }
909
910    /// Access the data as a pointer.
911    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
912        self.check_for_access::<D>()?;
913        Ok(self.data.as_ptr() as *const D)
914    }
915
916    /// Access the data as a pointer.
917    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
918        self.data.as_ptr() as *const D
919    }
920
921    /// Access the data as a pointer.
922    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
923        self.data.as_mut_ptr() as *mut D
924    }
925
926    /// Access the data as a mutable pointer.
927    pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
928        self.as_ptr::<D>().map(|p| p as *mut D)
929    }
930
931    /// Access the data as a slice.
932    pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
933        let ptr: *const D = self.as_ptr()?;
934        if self.data.len() == 0 {
935            Ok(&[])
936        } else {
937            unsafe { Ok(std::slice::from_raw_parts::<D>(ptr, self.len())) }
938        }
939    }
940
941    /// Access the data as a mutable slice.
942    pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
943        let ptr: *mut D = self.as_ptr_mut()?;
944        if self.data.len() == 0 {
945            Ok(&mut [])
946        } else {
947            unsafe { Ok(std::slice::from_raw_parts_mut::<D>(ptr, self.len())) }
948        }
949    }
950
951    /// Access the data as a slice.
952    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
953        if self.data.len() == 0 {
954            &[]
955        } else {
956            unsafe { std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len()) }
957        }
958    }
959
960    /// Access the data as a mutable slice.
961    pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
962        if self.data.len() == 0 {
963            &mut []
964        } else {
965            unsafe { std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len()) }
966        }
967    }
968
969    /// Access the data as a scalar.
970    pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
971        self.check_for_access::<D>()?;
972        if self.len() == 0 {
973            bail!("to_scalar called on empty tensor ({:?})", self)
974        }
975        if self.len() > 1 {
976            bail!("to_scalar called on a tensor with multiple values ({:?})", self)
977        }
978        unsafe { Ok(self.to_scalar_unchecked()) }
979    }
980
981    /// Make the tensor a scalar tensor (assumes it contains a single value).
982    pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
983        fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
984            Ok(litteral::tensor0(t.to_scalar::<D>()?.clone()))
985        }
986        dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
987    }
988
989    /// Access the data as a scalar.
990    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
991        unsafe { &*(self.data.as_ptr() as *const D) }
992    }
993
994    /// Mutable access the data as a scalar.
995    pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
996        self.check_for_access::<D>()?;
997        if self.len() == 0 {
998            bail!("to_scalar_mut called on empty tensor ({:?})", self)
999        }
1000        if self.len() > 1 {
1001            bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1002        }
1003        unsafe { Ok(self.to_scalar_mut_unchecked()) }
1004    }
1005
1006    /// Mutable access the data as a scalar.
1007    pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
1008        unsafe { &mut *(self.data.as_mut_ptr() as *mut D) }
1009    }
1010
1011    pub fn as_bytes(&self) -> &[u8] {
1012        self.data.as_bytes()
1013    }
1014
1015    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
1016        self.data.as_bytes_mut()
1017    }
1018
1019    unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
1020        let slice = unsafe { self.as_slice_unchecked::<T>() };
1021        slice[1..].iter().all(|x| x == &slice[0])
1022    }
1023
1024    pub fn is_uniform(&self) -> bool {
1025        if self.len() <= 1 {
1026            return true;
1027        }
1028        unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
1029    }
1030
1031    unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
1032        let v: T = unsafe { self.as_slice_unchecked::<T>() }[0].clone();
1033        litteral::tensor0(v)
1034    }
1035
1036    pub fn as_uniform(&self) -> Option<Tensor> {
1037        if self.len() >= 1 && self.is_uniform() {
1038            unsafe {
1039                let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
1040                t.set_datum_type(self.datum_type());
1041                Some(t)
1042            }
1043        } else {
1044            None
1045        }
1046    }
1047
1048    pub fn is_all_zero(&self) -> TractResult<bool> {
1049        Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1050    }
1051
1052    pub fn is_zero(&self) -> TractResult<bool> {
1053        Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
1054    }
1055
1056    unsafe fn natural_cast<
1057        Source: Datum + num_traits::AsPrimitive<Target>,
1058        Target: Datum + Copy,
1059    >(
1060        &self,
1061        other: &mut Tensor,
1062    ) {
1063        unsafe {
1064            self.as_slice_unchecked::<Source>()
1065                .iter()
1066                .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1067                .for_each(|(s, d)| *d = s.as_())
1068        };
1069    }
1070
1071    unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1072        unsafe {
1073            self.as_slice_unchecked::<Source>()
1074                .iter()
1075                .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1076                .for_each(|(s, d)| *d = !s.is_zero());
1077        }
1078    }
1079
1080    unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
1081        &self,
1082        other: &mut Tensor,
1083    ) -> TractResult<()> {
1084        unsafe {
1085            for (s, d) in self
1086                .as_slice_unchecked::<String>()
1087                .iter()
1088                .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1089            {
1090                *d = s
1091                    .parse()
1092                    .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
1093            }
1094            Ok(())
1095        }
1096    }
1097
1098    unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1099        unsafe {
1100            for (s, d) in self
1101                .as_slice_unchecked::<Source>()
1102                .iter()
1103                .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1104            {
1105                *d = s.to_string()
1106            }
1107        }
1108    }
1109
1110    /// Optionnaly convert data to a tensor for a new DatumType.
1111    pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<'_, Tensor>> {
1112        self.cast_to_dt(D::datum_type())
1113    }
1114
1115    /// Optionnaly convert data to a tensor for a new DatumType.
1116    #[allow(clippy::redundant_closure_call)]
1117    pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<'_, Tensor>> {
1118        unsafe {
1119            if self.dt == dst_dt {
1120                return Ok(Cow::Borrowed(self));
1121            }
1122            if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
1123                let slice = self.as_slice_unchecked::<TDim>();
1124                let mut ints = Self::uninitialized::<i64>(&self.shape)?;
1125                let ints_slice = ints.as_slice_mut_unchecked::<i64>();
1126                for i in 0..self.len() {
1127                    ints_slice[i] = slice[i].to_i64()?;
1128                }
1129                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1130            }
1131            if self.dt == bool::datum_type()
1132                && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
1133            {
1134                let slice = self.as_slice_unchecked::<bool>();
1135                let mut ints = Self::uninitialized::<i8>(&self.shape)?;
1136                let ints_slice = ints.as_slice_mut_unchecked::<i8>();
1137                for i in 0..self.len() {
1138                    ints_slice[i] = slice[i] as usize as i8;
1139                }
1140                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1141            }
1142            let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
1143            if self.dt == DatumType::String {
1144                dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
1145                return Ok(Cow::Owned(result));
1146            }
1147            if dst_dt == DatumType::String {
1148                dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
1149                return Ok(Cow::Owned(result));
1150            }
1151            macro_rules! n {
1152                ($source:ty) => {
1153                    if <$source>::datum_type() == self.datum_type() {
1154                        match dst_dt {
1155                            DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
1156                            DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
1157                            DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
1158                            DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
1159                            DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
1160                            DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
1161                            DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
1162                            DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
1163                            DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
1164                            DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
1165                            DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
1166                            DatumType::TDim => {
1167                                let ints = self.cast_to::<i32>()?;
1168                                let slice = ints.as_slice_unchecked::<i32>();
1169                                let result = result.as_slice_mut_unchecked::<TDim>();
1170                                for i in 0..self.len() {
1171                                    result[i] = slice[i].into();
1172                                }
1173                            }
1174                            DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
1175                            _ => todo!(),
1176                        }
1177                        return Ok(Cow::Owned(result));
1178                    };
1179                };
1180            }
1181            //If there is no quantization
1182            if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
1183                n!(u8);
1184                n!(u16);
1185                n!(u32);
1186                n!(u64);
1187                n!(i8);
1188                n!(i16);
1189                n!(i32);
1190                n!(i64);
1191                n!(f16);
1192                n!(f32);
1193                n!(f64);
1194            } else {
1195                let (s_zp, s_scale) = self.datum_type().zp_scale();
1196                let (d_zp, d_scale) = dst_dt.zp_scale();
1197                if self.datum_type().is_quantized() && dst_dt.is_float() {
1198                    macro_rules! q_to_fp {
1199                        ($source:ty, $dest:ty) => {
1200                            if <$source>::datum_type().unquantized()
1201                                == self.datum_type().unquantized()
1202                                && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1203                            {
1204                                self.as_slice_unchecked::<$source>()
1205                                    .iter()
1206                                    .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1207                                    .for_each(|(&s, d)| {
1208                                        *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
1209                                    });
1210                                return Ok(Cow::Owned(result));
1211                            }
1212                        };
1213                    }
1214                    q_to_fp!(i8, f64);
1215                    q_to_fp!(i8, f32);
1216                    q_to_fp!(u8, f64);
1217                    q_to_fp!(u8, f32);
1218                }
1219                //TODO: optimize scale_by
1220                macro_rules! q8_to_q8 {
1221                    ($typ:ty) => {
1222                        if dst_dt.unquantized() == <$typ>::datum_type() {
1223                            self.as_slice_unchecked::<$typ>()
1224                                .iter()
1225                                .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
1226                                .for_each(|(&s, d)| {
1227                                    *d = (d_zp as i32
1228                                        + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
1229                                    .clamp_cast()
1230                                });
1231                            return Ok(Cow::Owned(result));
1232                        }
1233                    };
1234                }
1235
1236                macro_rules! q_via_f32 {
1237                    ($source:ty, $dest:ty, $round:expr) => {
1238                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1239                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1240                        {
1241                            self.as_slice_unchecked::<$source>()
1242                                .iter()
1243                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1244                                .for_each(|(&s, d)| {
1245                                    let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
1246                                    let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
1247                                    *d = $round(d_float);
1248                                });
1249                            return Ok(Cow::Owned(result));
1250                        }
1251                    };
1252                }
1253
1254                macro_rules! q_n {
1255                    (clamp $source:ty, $dest:ty) => {{
1256                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1257                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1258                        {
1259                            self.as_slice_unchecked::<$source>()
1260                                .iter()
1261                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1262                                .for_each(|(&s, d)| {
1263                                    *d = s.clamp_cast();
1264                                });
1265                            return Ok(Cow::Owned(result));
1266                        }
1267                    }};
1268                    ($source:ty, $dest:ty) => {{
1269                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1270                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1271                        {
1272                            self.as_slice_unchecked::<$source>()
1273                                .iter()
1274                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1275                                .for_each(|(&s, d)| {
1276                                    *d = s as $dest;
1277                                });
1278                            return Ok(Cow::Owned(result));
1279                        }
1280                    }};
1281                }
1282
1283                if dst_dt.unquantized() == self.datum_type().unquantized()
1284                    && dst_dt.is_quantized()
1285                    && self.datum_type().is_quantized()
1286                {
1287                    q8_to_q8!(i8);
1288                    q8_to_q8!(u8);
1289                }
1290
1291                q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
1292                q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
1293                q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
1294                q_via_f32!(i8, f32, |f| f);
1295                q_via_f32!(u8, f32, |f| f);
1296                q_via_f32!(i32, f32, |f| f);
1297
1298                if dst_dt.is_quantized() && self.datum_type().is_quantized() {
1299                    q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
1300                    q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
1301                    q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
1302                    q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
1303                    q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
1304                    q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());
1305
1306                    // ensure cast to different scale offset work
1307                    q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
1308                    q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
1309                }
1310
1311                q_n!(i8, i32);
1312                q_n!(i8, u32);
1313                q_n!(u8, i32);
1314                q_n!(u8, u32);
1315                q_n!(clamp i32, i8);
1316                q_n!(clamp i32, u8);
1317                q_n!(clamp u32, i8);
1318                q_n!(clamp u32, u8);
1319                q_n!(i8, i8);
1320                q_n!(u8, u8);
1321                q_n!(i32, i32);
1322                q_n!(u32, u32);
1323            }
1324
1325            bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
1326        }
1327    }
1328
1329    /// Access the data as a scalar, after a cast.
1330    pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1331        let casted = self.cast_to::<D>()?;
1332        casted.to_scalar::<D>().copied()
1333    }
1334
1335    /// Access the nth element of the tensor, returned as a 0-rank Tensor
1336    pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1337        if nth >= self.len() {
1338            bail!(
1339                "nth called with {}th element on a tensor of len {} ({:?}",
1340                nth,
1341                self.len(),
1342                self
1343            );
1344        }
1345        unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1346            unsafe {
1347                let value = me.as_slice_unchecked::<T>()[nth].clone();
1348                output.as_slice_mut_unchecked::<T>()[0] = value;
1349            }
1350        }
1351        unsafe {
1352            let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1353            dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1354            Ok(output)
1355        }
1356    }
1357
1358    /// Strict equality test on tensors.
1359    fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1360        unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1361            unsafe {
1362                if D::datum_type().is_float() {
1363                    return dispatch_floatlike!(float_eq_t(D::datum_type())(me, other));
1364                }
1365                Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1366                    .all(|(a, b)| a == b))
1367            }
1368        }
1369
1370        unsafe fn float_eq_t<D: Datum + Float>(me: &Tensor, other: &Tensor) -> TractResult<bool> {
1371            unsafe {
1372                Ok(izip!(me.as_slice_unchecked::<D>(), other.as_slice_unchecked::<D>())
1373                    .all(|(a, b)| (a.is_nan() && b.is_nan()) || a == b))
1374            }
1375        }
1376
1377        unsafe {
1378            Ok(self.datum_type() == other.datum_type()
1379                && self.shape() == other.shape()
1380                && dispatch_datum!(eq_t(self.dt)(self, other))?)
1381        }
1382    }
1383
1384    fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
1385        unsafe {
1386            let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
1387            if let Some(slice) = it.as_slice_mut() {
1388                if t.datum_type().is_copy() {
1389                    std::ptr::copy_nonoverlapping(
1390                        slice.as_ptr() as *const i8,
1391                        t.as_ptr_mut_unchecked(),
1392                        t.data.layout().size(),
1393                    );
1394                } else {
1395                    t.as_slice_mut_unchecked::<T>()
1396                        .iter_mut()
1397                        .zip(slice.iter_mut())
1398                        .for_each(|(t, s)| *t = std::mem::take(s));
1399                }
1400                return t;
1401            }
1402            if it.strides().iter().all(|&s| s > 0) && it.as_slice_memory_order().is_some() {
1403                let mut len_and_strides: TVec<(usize, usize)> = tvec!();
1404                for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
1405                    .sorted_by_key(|(_, src, _)| *src)
1406                    .map(|(l, _, dst)| (*l as isize, *dst))
1407                {
1408                    if !len_and_strides.is_empty()
1409                        && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
1410                            == stride as usize
1411                    {
1412                        len_and_strides.last_mut().unwrap().0 *= len as usize;
1413                    } else {
1414                        len_and_strides.push((len as usize, stride as usize));
1415                    }
1416                }
1417                len_and_strides.reverse();
1418                crate::scatter::scatter_contig_data(
1419                    it.as_ptr(),
1420                    t.as_ptr_mut_unchecked(),
1421                    &len_and_strides,
1422                );
1423                return t;
1424            }
1425            // finally use ndarray into_iter()
1426            t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
1427            t
1428        }
1429    }
1430
1431    pub fn deep_clone(&self) -> Tensor {
1432        unsafe {
1433            let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1434            if self.len() > 0 {
1435                if self.dt.is_copy() {
1436                    self.data.as_ptr().copy_to_nonoverlapping(
1437                        tensor.as_bytes_mut().as_mut_ptr(),
1438                        self.data.layout().size(),
1439                    )
1440                } else if self.dt == DatumType::String {
1441                    tensor
1442                        .as_slice_mut_unchecked::<String>()
1443                        .clone_from_slice(self.as_slice_unchecked());
1444                } else if self.dt == DatumType::Blob {
1445                    tensor
1446                        .as_slice_mut_unchecked::<Blob>()
1447                        .clone_from_slice(self.as_slice_unchecked());
1448                } else if self.dt == DatumType::Opaque {
1449                    tensor
1450                        .as_slice_mut_unchecked::<Opaque>()
1451                        .clone_from_slice(self.as_slice_unchecked());
1452                } else if self.dt == DatumType::TDim {
1453                    tensor
1454                        .as_slice_mut_unchecked::<TDim>()
1455                        .clone_from_slice(self.as_slice_unchecked());
1456                }
1457            }
1458            tensor
1459        }
1460    }
1461
1462    pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1463        if axis >= self.rank() {
1464            bail!("Can not slice at axis {} tensor {:?}", axis, self);
1465        }
1466        if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1467            bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1468        }
1469        fn slice_t<T: Datum>(
1470            t: &Tensor,
1471            axis: usize,
1472            start: usize,
1473            end: usize,
1474        ) -> TractResult<Tensor> {
1475            Ok(t.to_array_view::<T>()?
1476                .slice_axis(ndarray::Axis(axis), (start..end).into())
1477                .into_owned()
1478                .into_tensor())
1479        }
1480        dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1481    }
1482
1483    #[inline]
1484    pub fn view(&self) -> view::TensorView<'_> {
1485        unsafe { view::TensorView::view(self) }
1486    }
1487
1488    #[inline]
1489    pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
1490        view::TensorView::at_prefix(self, prefix)
1491    }
1492
1493    #[inline]
1494    pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
1495        view::TensorView::offsetting(self, coords)
1496    }
1497
1498    #[inline]
1499    pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView<'_> {
1500        unsafe { view::TensorView::offsetting_unchecked(self, coords) }
1501    }
1502
1503    #[inline]
1504    pub fn view_mut(&mut self) -> view::TensorView<'_> {
1505        unsafe { view::TensorView::view(self) }
1506    }
1507
1508    #[inline]
1509    pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView<'_>> {
1510        view::TensorView::at_prefix(self, prefix)
1511    }
1512
1513    #[inline]
1514    pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView<'_>> {
1515        view::TensorView::offsetting(self, coords)
1516    }
1517
1518    /// Offsets the tensor as an i8 type if it's an u8 type, otherwise passes it unchanged.
1519    pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1520        let mut t = if let DatumType::U8 = self.dt.unquantized() {
1521            self.to_array_view::<u8>().unwrap().mapv(|v| v.wrapping_sub(128) as i8).into_tensor()
1522        } else {
1523            return self.clone();
1524        };
1525
1526        if let DatumType::QU8(qp) = self.dt {
1527            if let QParams::ZpScale { zero_point, scale } = qp {
1528                t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1529            } else {
1530                t.dt = DatumType::QI8(qp);
1531            }
1532        }
1533
1534        t.into_arc_tensor()
1535    }
1536
1537    /// Offsets the tensor as an u8 type if it's an i8 type, otherwise passes it unchanged.
1538    pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1539        let mut t = if let DatumType::I8 = self.dt.unquantized() {
1540            self.to_array_view::<i8>().unwrap().mapv(|v| (v as u8).wrapping_add(128)).into_tensor()
1541        } else {
1542            return self.clone();
1543        };
1544
1545        if let DatumType::QI8(qp) = self.dt {
1546            if let QParams::ZpScale { zero_point, scale } = qp {
1547                t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1548            } else {
1549                t.dt = DatumType::QU8(qp);
1550            }
1551        }
1552        t.into_arc_tensor()
1553    }
1554
1555    pub fn to_aligned_default(&self) -> TractResult<Self> {
1556        if self.dt.is_copy() {
1557            unsafe {
1558                let mut t = Self::uninitialized_dt(self.dt, &self.shape)?;
1559                t.as_bytes_mut().copy_from_slice(self.as_bytes());
1560                Ok(t)
1561            }
1562        } else {
1563            let mut t = Self::zero_dt(self.dt, &self.shape)?;
1564            if self.dt == String::datum_type() {
1565                t.as_slice_mut::<String>()?.clone_from_slice(self.as_slice()?);
1566            } else if self.dt == Blob::datum_type() {
1567                t.as_slice_mut::<Blob>()?.clone_from_slice(self.as_slice()?);
1568            } else if self.dt == TDim::datum_type() {
1569                t.as_slice_mut::<TDim>()?.clone_from_slice(self.as_slice()?);
1570            }
1571            Ok(t)
1572        }
1573    }
1574
1575    pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1576        let mut strides = tvec!();
1577        compute_natural_stride_to(&mut strides, shape);
1578        strides
1579    }
1580
1581    pub fn into_blob(mut self) -> TractResult<Blob> {
1582        ensure!(self.dt.is_copy());
1583        Ok(std::mem::take(&mut self.data))
1584    }
1585}
1586
1587impl PartialEq for Tensor {
1588    fn eq(&self, other: &Tensor) -> bool {
1589        if self.dt != other.dt || self.shape != other.shape {
1590            return false;
1591        }
1592        self.eq_dt(other).unwrap_or(false)
1593    }
1594}
1595
1596impl fmt::Debug for Tensor {
1597    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1598        let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1599        write!(formatter, "{content}")
1600    }
1601}
1602
/// Reinterprets a tensor whose innermost dimension is 2 as a tensor of complex values.
///
/// The last axis is removed from the shape and the datum type is replaced by its
/// complex counterpart; the data buffer itself is not rewritten.
///
/// # Errors
///
/// Fails when the last dimension is not 2, or when the datum type cannot be complexified.
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    let innermost = t.shape().last().copied();
    ensure!(
        innermost == Some(2),
        "The last dimension in the tensor shape {:?} must be 2",
        t.shape()
    );
    unsafe {
        t.shape.pop();
        let complex_dt = t.datum_type().complexify()?;
        t.set_datum_type(complex_dt);
        // Shape changed: strides and length have to be recomputed.
        t.update_strides_and_len();
    }
    Ok(t)
}
1617
/// Reinterprets a complex tensor as a tensor of its scalar type with an extra innermost
/// dimension of size 2.
///
/// A trailing axis of 2 is appended to the shape and the datum type is replaced by its
/// decomplexified counterpart; the data buffer itself is not rewritten.
///
/// # Errors
///
/// Fails when the datum type cannot be decomplexified.
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    unsafe {
        t.shape.push(2);
        let scalar_dt = t.datum_type().decomplexify()?;
        t.set_datum_type(scalar_dt);
        // Shape changed: strides and length have to be recomputed.
        t.update_strides_and_len();
    }
    Ok(t)
}
1627
/// Converts any `RangeBounds<usize>` into a concrete half-open `Range<usize>`.
///
/// * An unbounded start becomes `0`, an unbounded end becomes `len`.
/// * An excluded start and an included end are shifted by one to obtain half-open bounds.
///
/// Explicit bounds are returned as given: they are **not** clamped to `len`, so callers
/// that need an in-bounds range must still validate the result.
///
/// The `+ 1` adjustments saturate at `usize::MAX` instead of overflowing, so a bound such as
/// `..=usize::MAX` yields an (out-of-bounds) end of `usize::MAX` rather than panicking in debug
/// builds or silently wrapping to `0` in release builds.
pub fn clip_range_bounds(len: usize, range: impl std::ops::RangeBounds<usize>) -> Range<usize> {
    use std::ops::Bound;
    let start = match range.start_bound() {
        Bound::Included(&ix) => ix,
        Bound::Excluded(&ix) => ix.saturating_add(1),
        Bound::Unbounded => 0,
    };
    let end = match range.end_bound() {
        Bound::Included(&ix) => ix.saturating_add(1),
        Bound::Excluded(&ix) => ix,
        Bound::Unbounded => len,
    };
    start..end
}
1642
1643pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1644    let mut strides = tvec!();
1645    compute_natural_stride_to(&mut strides, shape);
1646    strides
1647}
1648
1649fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1650    match shape.len() {
1651        0 => (),
1652        1 => strides.push(1),
1653        2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1654        3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1655        4 => strides.extend_from_slice(&[
1656            (shape[1] * shape[2] * shape[3]) as isize,
1657            (shape[2] * shape[3]) as _,
1658            shape[3] as _,
1659            1,
1660        ]),
1661        _ => {
1662            strides.push(1);
1663            for dim in shape.as_ref().iter().skip(1).rev() {
1664                let previous = *strides.last().unwrap();
1665                strides.push(previous * *dim as isize)
1666            }
1667            strides.reverse();
1668        }
1669    }
1670}
1671
1672impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
1673    fn from(it: Array<T, D>) -> Tensor {
1674        Tensor::from_datum(it.into_dyn())
1675    }
1676}
1677
/// Convenient conversion to [`Tensor`].
///
/// Implemented in this file for owned ndarray `Array`s, for `Tensor` itself and for
/// `Arc<Tensor>`.
pub trait IntoTensor: Sized {
    /// Converts `self` into a [`Tensor`], consuming it.
    ///
    /// May perform a copy, depending on the implementor.
    fn into_tensor(self) -> Tensor;
}
1685
/// Convenient conversion to `Arc<Tensor>`.
///
/// Implemented in this file for owned ndarray `Array`s, for `Tensor` and for
/// `Arc<Tensor>` itself.
pub trait IntoArcTensor: Sized {
    /// Converts `self` into an `Arc<Tensor>`, consuming it.
    ///
    /// May perform a copy, depending on the implementor.
    fn into_arc_tensor(self) -> Arc<Tensor>;
}
1693
1694impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
1695    fn into_tensor(self) -> Tensor {
1696        Tensor::from(self)
1697    }
1698}
1699
1700impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
1701    fn into_arc_tensor(self) -> Arc<Tensor> {
1702        Arc::new(Tensor::from(self))
1703    }
1704}
1705
/// Identity conversion: a `Tensor` is already a `Tensor`.
impl IntoTensor for Tensor {
    fn into_tensor(self) -> Tensor {
        // Nothing to convert, the value is returned as is.
        self
    }
}
1711
1712impl IntoTensor for Arc<Tensor> {
1713    fn into_tensor(self) -> Tensor {
1714        Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
1715    }
1716}
1717
/// Wraps an owned `Tensor` in a new `Arc`.
impl IntoArcTensor for Tensor {
    fn into_arc_tensor(self) -> Arc<Tensor> {
        // The tensor is moved into the `Arc`.
        Arc::new(self)
    }
}
1723
/// Identity conversion: an `Arc<Tensor>` is already an `Arc<Tensor>`.
impl IntoArcTensor for Arc<Tensor> {
    fn into_arc_tensor(self) -> Arc<Tensor> {
        // Nothing to convert, the same `Arc` is returned.
        self
    }
}
1729
// Unit and property tests for tensor construction, broadcasting and reinterpretation.
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    use proptest::collection::vec;
    use proptest::prelude::*;

    /// Test case: an array of the given `shape` whose axes are permuted by `permutation`
    /// before being converted to a `Tensor`.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        /// Generates a rank in `0..8`, dimensions in `1..5` and a shuffled axis permutation.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            (0..8usize)
                .prop_flat_map(|rank| {
                    let permute: Vec<usize> = (0..rank).collect();
                    (proptest::collection::vec(1..5usize, rank), Just(permute).prop_shuffle())
                })
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        /// Array filled with 1, 2, 3, ... in fill order, then viewed with permuted axes.
        fn input(&self) -> ArrayD<i32> {
            let mut i = 0;
            ArrayD::from_shape_simple_fn(&*self.shape, || {
                i += 1;
                i
            })
            .permuted_axes(&*self.permutation)
        }

        /// Expected tensor: the permuted array's values collected in iteration order into a
        /// 1-D tensor, then reshaped to the permuted shape.
        fn reference(&self) -> Tensor {
            let values: Vec<i32> = self.input().iter().copied().collect();
            let shape = self.permutation.iter().map(|ix| self.shape[*ix]).collect::<TVec<usize>>();
            super::litteral::tensor1(&values).into_shape(&shape).unwrap()
        }

        /// Tensor under test: direct conversion of the permuted array.
        fn tract(&self) -> Tensor {
            Tensor::from(self.input())
        }

        /// Asserts that the direct conversion matches the reference construction.
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    proptest::proptest! {
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    // Fixed case: a 2x1 array with its two axes swapped.
    #[test]
    fn t_1_2() {
        PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] }.check().unwrap();
    }

    // Fixed case: a 2x2 array with its two axes swapped.
    #[test]
    fn t_2_2() {
        PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
    }

    /// Test case: broadcasting the 1-D vector `vec` along `axis` of `shape`.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        /// Compares `broadcast_vector_to_shape` with a reshape to an all-ones shape (except
        /// at `axis`) followed by `broadcast_to_shape`.
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            // Same rank as the target, 1 everywhere except the vector length at `axis`.
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reference = input
                .clone()
                .into_shape(&intermediate)
                .unwrap()
                .broadcast_to_shape(&self.shape)
                .unwrap();
            prop_assert_eq!(
                reference,
                input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap()
            );
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        /// Generates up to 3 dimensions in `0..5`, a vector of up to 4 values, and an
        /// insertion axis; the vector length is inserted into the shape at that axis.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1)
                })
                .prop_map(|(vec, mut shape, axis)| {
                    shape.insert(axis, vec.len());
                    BroadcastVecToShape { vec, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    // A 3x2 f32 tensor becomes a length-3 tensor of complex values.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let input = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // A 3x2x2 i32 tensor becomes a 3x2 tensor of complex values.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let input =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Smoke test: cloning a scalar tensor holding a symbolic TDim must not panic.
    #[test]
    fn clone_tdim_tensor() {
        let symbols = SymbolScope::default();
        let a = symbols.sym("a");
        let t = tensor0(TDim::from(a));
        let _ = t.clone();
    }
}