tract_data/
datum.rs

//! Element types for tract tensors: `DatumType`, its quantization parameters `QParams`, and the `Datum` trait
2use crate::dim::TDim;
3use crate::internal::*;
4use crate::tensor::Tensor;
5use crate::TVec;
6use half::f16;
7#[cfg(feature = "complex")]
8use num_complex::Complex;
9use scan_fmt::scan_fmt;
10use std::fmt;
11use std::hash::Hash;
12
13use num_traits::AsPrimitive;
14
/// Quantization parameters attached to the quantized `DatumType` variants.
///
/// Equality is implemented by hand rather than derived: floats are compared
/// bit for bit so that `==` agrees with the `Ord` implementation (built on
/// `f32::total_cmp`) and with the `Hash` implementation (built on
/// `f32::to_bits`). With the derived, IEEE-based comparison a `NaN` field made
/// a value unequal to itself (which `Eq` forbids) and `0.0`/`-0.0` compared
/// equal while hashing differently.
#[derive(Copy, Clone)]
pub enum QParams {
    /// Range of real values; `zp_scale` converts it to a zero point and a scale.
    MinMax { min: f32, max: f32 },
    /// Explicit integer zero point and real scale.
    ZpScale { zero_point: i32, scale: f32 },
}

impl PartialEq for QParams {
    fn eq(&self, other: &Self) -> bool {
        match (*self, *other) {
            (
                QParams::MinMax { min: lhs_min, max: lhs_max },
                QParams::MinMax { min: rhs_min, max: rhs_max },
            ) => lhs_min.to_bits() == rhs_min.to_bits() && lhs_max.to_bits() == rhs_max.to_bits(),
            (
                QParams::ZpScale { zero_point: lhs_zp, scale: lhs_scale },
                QParams::ZpScale { zero_point: rhs_zp, scale: rhs_scale },
            ) => lhs_zp == rhs_zp && lhs_scale.to_bits() == rhs_scale.to_bits(),
            (QParams::MinMax { .. }, QParams::ZpScale { .. })
            | (QParams::ZpScale { .. }, QParams::MinMax { .. }) => false,
        }
    }
}

// Sound now that `eq` is reflexive for every bit pattern, NaN included.
impl Eq for QParams {}
22
23impl Ord for QParams {
24    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
25        use QParams::*;
26        match (self, other) {
27            (MinMax { .. }, ZpScale { .. }) => std::cmp::Ordering::Less,
28            (ZpScale { .. }, MinMax { .. }) => std::cmp::Ordering::Greater,
29            (MinMax { min: min1, max: max1 }, MinMax { min: min2, max: max2 }) => {
30                min1.total_cmp(min2).then_with(|| max1.total_cmp(max2))
31            }
32            (
33                Self::ZpScale { zero_point: zp1, scale: s1 },
34                Self::ZpScale { zero_point: zp2, scale: s2 },
35            ) => zp1.cmp(zp2).then_with(|| s1.total_cmp(s2)),
36        }
37    }
38}
39
40impl PartialOrd for QParams {
41    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
42        Some(self.cmp(other))
43    }
44}
45
46impl Default for QParams {
47    fn default() -> Self {
48        QParams::ZpScale { zero_point: 0, scale: 1. }
49    }
50}
51
52#[allow(clippy::derived_hash_with_manual_eq)]
53impl Hash for QParams {
54    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
55        match self {
56            QParams::MinMax { min, max } => {
57                0.hash(state);
58                min.to_bits().hash(state);
59                max.to_bits().hash(state);
60            }
61            QParams::ZpScale { zero_point, scale } => {
62                1.hash(state);
63                zero_point.hash(state);
64                scale.to_bits().hash(state);
65            }
66        }
67    }
68}
69
70impl QParams {
71    pub fn zp_scale(&self) -> (i32, f32) {
72        match self {
73            QParams::MinMax { min, max } => {
74                let scale = (max - min) / 255.;
75                ((-(min + max) / 2. / scale) as i32, scale)
76            }
77            QParams::ZpScale { zero_point, scale } => (*zero_point, *scale),
78        }
79    }
80
81    pub fn q(&self, f: f32) -> i32 {
82        let (zp, scale) = self.zp_scale();
83        (f / scale) as i32 + zp
84    }
85
86    pub fn dq(&self, i: i32) -> f32 {
87        let (zp, scale) = self.zp_scale();
88        (i - zp) as f32 * scale
89    }
90}
91
92impl std::fmt::Debug for QParams {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        let (zp, scale) = self.zp_scale();
95        write!(f, "Z:{zp} S:{scale}")
96    }
97}
98
99#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
100pub enum DatumType {
101    Bool,
102    U8,
103    U16,
104    U32,
105    U64,
106    I8,
107    I16,
108    I32,
109    I64,
110    F16,
111    F32,
112    F64,
113    TDim,
114    Blob,
115    String,
116    QI8(QParams),
117    QU8(QParams),
118    QI32(QParams),
119    #[cfg(feature = "complex")]
120    ComplexI16,
121    #[cfg(feature = "complex")]
122    ComplexI32,
123    #[cfg(feature = "complex")]
124    ComplexI64,
125    #[cfg(feature = "complex")]
126    ComplexF16,
127    #[cfg(feature = "complex")]
128    ComplexF32,
129    #[cfg(feature = "complex")]
130    ComplexF64,
131    Opaque,
132}
133
134impl DatumType {
135    pub fn super_types(&self) -> TVec<DatumType> {
136        use DatumType::*;
137        if *self == String || *self == TDim || *self == Blob || *self == Bool || self.is_quantized()
138        {
139            return tvec!(*self);
140        }
141        #[cfg(feature = "complex")]
142        if self.is_complex_float() {
143            return [ComplexF16, ComplexF32, ComplexF64]
144                .iter()
145                .filter(|s| s.size_of() >= self.size_of())
146                .copied()
147                .collect();
148        } else if self.is_complex_signed() {
149            return [ComplexI16, ComplexI32, ComplexI64]
150                .iter()
151                .filter(|s| s.size_of() >= self.size_of())
152                .copied()
153                .collect();
154        }
155        if self.is_float() {
156            [F16, F32, F64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
157        } else if self.is_signed() {
158            [I8, I16, I32, I64, TDim]
159                .iter()
160                .filter(|s| s.size_of() >= self.size_of())
161                .copied()
162                .collect()
163        } else {
164            [U8, U16, U32, U64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
165        }
166    }
167
168    pub fn super_type_for(
169        i: impl IntoIterator<Item = impl std::borrow::Borrow<DatumType>>,
170    ) -> Option<DatumType> {
171        let mut iter = i.into_iter();
172        let mut current = match iter.next() {
173            None => return None,
174            Some(it) => *it.borrow(),
175        };
176        for n in iter {
177            match current.common_super_type(*n.borrow()) {
178                None => return None,
179                Some(it) => current = it,
180            }
181        }
182        Some(current)
183    }
184
185    pub fn common_super_type(&self, rhs: DatumType) -> Option<DatumType> {
186        for mine in self.super_types() {
187            for theirs in rhs.super_types() {
188                if mine == theirs {
189                    return Some(mine);
190                }
191            }
192        }
193        None
194    }
195
196    pub fn is_unsigned(&self) -> bool {
197        matches!(
198            self.unquantized(),
199            DatumType::U8 | DatumType::U16 | DatumType::U32 | DatumType::U64
200        )
201    }
202
203    pub fn is_signed(&self) -> bool {
204        matches!(
205            self.unquantized(),
206            DatumType::I8 | DatumType::I16 | DatumType::I32 | DatumType::I64
207        )
208    }
209
210    pub fn is_float(&self) -> bool {
211        matches!(self, DatumType::F16 | DatumType::F32 | DatumType::F64)
212    }
213
214    pub fn is_number(&self) -> bool {
215        self.is_signed() | self.is_unsigned() | self.is_float() | self.is_quantized()
216    }
217
218    pub fn is_tdim(&self) -> bool {
219        *self == DatumType::TDim
220    }
221
222    pub fn is_opaque(&self) -> bool {
223        *self == DatumType::Opaque
224    }
225
226    #[cfg(feature = "complex")]
227    pub fn is_complex(&self) -> bool {
228        self.is_complex_float() || self.is_complex_signed()
229    }
230
231    #[cfg(feature = "complex")]
232    pub fn is_complex_float(&self) -> bool {
233        matches!(self, DatumType::ComplexF16 | DatumType::ComplexF32 | DatumType::ComplexF64)
234    }
235
236    #[cfg(feature = "complex")]
237    pub fn is_complex_signed(&self) -> bool {
238        matches!(self, DatumType::ComplexI16 | DatumType::ComplexI32 | DatumType::ComplexI64)
239    }
240
241    #[cfg(feature = "complex")]
242    pub fn complexify(&self) -> TractResult<DatumType> {
243        match *self {
244            DatumType::I16 => Ok(DatumType::ComplexI16),
245            DatumType::I32 => Ok(DatumType::ComplexI32),
246            DatumType::I64 => Ok(DatumType::ComplexI64),
247            DatumType::F16 => Ok(DatumType::ComplexF16),
248            DatumType::F32 => Ok(DatumType::ComplexF32),
249            DatumType::F64 => Ok(DatumType::ComplexF64),
250            _ => bail!("No complex datum type formed on {:?}", self),
251        }
252    }
253
254    #[cfg(feature = "complex")]
255    pub fn decomplexify(&self) -> TractResult<DatumType> {
256        match *self {
257            DatumType::ComplexI16 => Ok(DatumType::I16),
258            DatumType::ComplexI32 => Ok(DatumType::I32),
259            DatumType::ComplexI64 => Ok(DatumType::I64),
260            DatumType::ComplexF16 => Ok(DatumType::F16),
261            DatumType::ComplexF32 => Ok(DatumType::F32),
262            DatumType::ComplexF64 => Ok(DatumType::F64),
263            _ => bail!("{:?} is not a complex type", self),
264        }
265    }
266
267    pub fn is_copy(&self) -> bool {
268        #[cfg(feature = "complex")]
269        if self.is_complex() {
270            return true;
271        }
272        *self == DatumType::Bool || self.is_unsigned() || self.is_signed() || self.is_float()
273    }
274
275    pub fn is_quantized(&self) -> bool {
276        self.qparams().is_some()
277    }
278
279    pub fn qparams(&self) -> Option<QParams> {
280        match self {
281            DatumType::QI8(qparams) | DatumType::QU8(qparams) | DatumType::QI32(qparams) => {
282                Some(*qparams)
283            }
284            _ => None,
285        }
286    }
287
288    pub fn with_qparams(&self, qparams: QParams) -> DatumType {
289        match self {
290            DatumType::QI8(_) => DatumType::QI8(qparams),
291            DatumType::QU8(_) => DatumType::QI8(qparams),
292            DatumType::QI32(_) => DatumType::QI32(qparams),
293            _ => *self,
294        }
295    }
296
297    pub fn quantize(&self, qparams: QParams) -> DatumType {
298        match self {
299            DatumType::I8 => DatumType::QI8(qparams),
300            DatumType::U8 => DatumType::QU8(qparams),
301            DatumType::I32 => DatumType::QI32(qparams),
302            DatumType::QI8(_) => DatumType::QI8(qparams),
303            DatumType::QU8(_) => DatumType::QU8(qparams),
304            DatumType::QI32(_) => DatumType::QI32(qparams),
305            _ => panic!("Can't quantize {self:?}"),
306        }
307    }
308
309    #[inline(always)]
310    pub fn zp_scale(&self) -> (i32, f32) {
311        self.qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.))
312    }
313
314    #[inline(always)]
315    pub fn with_zp_scale(&self, zero_point: i32, scale: f32) -> DatumType {
316        self.quantize(QParams::ZpScale { zero_point, scale })
317    }
318
319    pub fn unquantized(&self) -> DatumType {
320        match self {
321            DatumType::QI8(_) => DatumType::I8,
322            DatumType::QU8(_) => DatumType::U8,
323            DatumType::QI32(_) => DatumType::I32,
324            _ => *self,
325        }
326    }
327
328    pub fn integer(signed: bool, size: usize) -> Self {
329        use DatumType::*;
330        match (signed, size) {
331            (false, 8) => U8,
332            (false, 16) => U16,
333            (false, 32) => U32,
334            (false, 64) => U64,
335            (true, 8) => U8,
336            (true, 16) => U16,
337            (true, 32) => U32,
338            (true, 64) => U64,
339            _ => panic!("No integer for signed:{signed} size:{size}"),
340        }
341    }
342
343    pub fn is_integer(&self) -> bool {
344        self.is_signed() || self.is_unsigned()
345    }
346
347    #[inline]
348    pub fn size_of(&self) -> usize {
349        dispatch_datum!(std::mem::size_of(self)())
350    }
351
352    #[inline]
353    pub fn alignment(&self) -> usize {
354        if self.is_copy() {
355            self.size_of()
356        } else {
357            std::mem::size_of::<usize>()
358        }
359    }
360
361    pub fn min_value(&self) -> Tensor {
362        match self {
363            DatumType::QU8(_)
364            | DatumType::U8
365            | DatumType::U16
366            | DatumType::U32
367            | DatumType::U64 => Tensor::zero_dt(*self, &[1]).unwrap(),
368            DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MIN),
369            DatumType::QI32(_) => tensor0(i32::MIN),
370            DatumType::I16 => tensor0(i16::MIN),
371            DatumType::I32 => tensor0(i32::MIN),
372            DatumType::I64 => tensor0(i64::MIN),
373            DatumType::F16 => tensor0(f16::MIN),
374            DatumType::F32 => tensor0(f32::MIN),
375            DatumType::F64 => tensor0(f64::MIN),
376            _ => panic!("No min value for datum type {self:?}"),
377        }
378    }
379    pub fn max_value(&self) -> Tensor {
380        match self {
381            DatumType::U8 | DatumType::QU8(_) => tensor0(u8::MAX),
382            DatumType::U16 => tensor0(u16::MAX),
383            DatumType::U32 => tensor0(u32::MAX),
384            DatumType::U64 => tensor0(u64::MAX),
385            DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MAX),
386            DatumType::I16 => tensor0(i16::MAX),
387            DatumType::I32 => tensor0(i32::MAX),
388            DatumType::I64 => tensor0(i64::MAX),
389            DatumType::QI32(_) => tensor0(i32::MAX),
390            DatumType::F16 => tensor0(f16::MAX),
391            DatumType::F32 => tensor0(f32::MAX),
392            DatumType::F64 => tensor0(f64::MAX),
393            _ => panic!("No max value for datum type {self:?}"),
394        }
395    }
396}
397
398impl std::str::FromStr for DatumType {
399    type Err = TractError;
400
401    fn from_str(s: &str) -> Result<Self, Self::Err> {
402        if let Ok((z, s)) = scan_fmt!(s, "QU8(Z:{d} S:{f})", i32, f32) {
403            Ok(DatumType::QU8(QParams::ZpScale { zero_point: z, scale: s }))
404        } else if let Ok((z, s)) = scan_fmt!(s, "QI8(Z:{d} S:{f})", i32, f32) {
405            Ok(DatumType::QI8(QParams::ZpScale { zero_point: z, scale: s }))
406        } else if let Ok((z, s)) = scan_fmt!(s, "QI32(Z:{d} S:{f})", i32, f32) {
407            Ok(DatumType::QI32(QParams::ZpScale { zero_point: z, scale: s }))
408        } else {
409            match s {
410                "I8" | "i8" => Ok(DatumType::I8),
411                "I16" | "i16" => Ok(DatumType::I16),
412                "I32" | "i32" => Ok(DatumType::I32),
413                "I64" | "i64" => Ok(DatumType::I64),
414                "U8" | "u8" => Ok(DatumType::U8),
415                "U16" | "u16" => Ok(DatumType::U16),
416                "U32" | "u32" => Ok(DatumType::U32),
417                "U64" | "u64" => Ok(DatumType::U64),
418                "F16" | "f16" => Ok(DatumType::F16),
419                "F32" | "f32" => Ok(DatumType::F32),
420                "F64" | "f64" => Ok(DatumType::F64),
421                "Bool" | "bool" => Ok(DatumType::Bool),
422                "Blob" | "blob" => Ok(DatumType::Blob),
423                "String" | "string" => Ok(DatumType::String),
424                "TDim" | "tdim" => Ok(DatumType::TDim),
425                #[cfg(feature = "complex")]
426                "ComplexI16" | "complexi16" => Ok(DatumType::ComplexI16),
427                #[cfg(feature = "complex")]
428                "ComplexI32" | "complexi32" => Ok(DatumType::ComplexI32),
429                #[cfg(feature = "complex")]
430                "ComplexI64" | "complexi64" => Ok(DatumType::ComplexI64),
431                #[cfg(feature = "complex")]
432                "ComplexF16" | "complexf16" => Ok(DatumType::ComplexF16),
433                #[cfg(feature = "complex")]
434                "ComplexF32" | "complexf32" => Ok(DatumType::ComplexF32),
435                #[cfg(feature = "complex")]
436                "ComplexF64" | "complexf64" => Ok(DatumType::ComplexF64),
437                _ => bail!("Unknown type {}", s),
438            }
439        }
440    }
441}
442
/// `1 / f32::EPSILON`, i.e. 2^23: adding and then subtracting it forces a
/// smaller-magnitude `f32` to be rounded to an integer by the FPU.
const TOINT: f32 = 1.0f32 / f32::EPSILON;

/// Rounds `x` to the nearest integer, ties going to the even neighbour.
///
/// Values whose biased exponent is at least `0x7f + 23` (which includes
/// infinities and NaN) are returned unchanged. The sign of `x` is kept when
/// the result is zero.
pub fn round_ties_to_even(x: f32) -> f32 {
    let bits = x.to_bits();
    let biased_exponent = (bits >> 23) & 0xff;
    if biased_exponent >= 0x7f + 23 {
        return x;
    }
    let negative = bits >> 31 == 1;
    // Shift away from zero and back so the intermediate sum does the rounding.
    let rounded = if negative { x - TOINT + TOINT } else { x + TOINT - TOINT };
    if rounded != 0.0 {
        rounded
    } else if negative {
        -0f32
    } else {
        0f32
    }
}
463
464#[inline]
465pub fn scale_by<T: Datum + AsPrimitive<f32>>(b: T, a: f32) -> T
466where
467    f32: AsPrimitive<T>,
468{
469    let b = b.as_();
470    (round_ties_to_even(b.abs() * a) * b.signum()).as_()
471}
472
473pub trait ClampCast: PartialOrd + Copy + 'static {
474    #[inline(always)]
475    fn clamp_cast<O>(self) -> O
476    where
477        Self: AsPrimitive<O> + Datum,
478        O: AsPrimitive<Self> + num_traits::Bounded + Datum,
479    {
480        // this fails if we're upcasting, in which case clamping is useless
481        if O::min_value().as_() < O::max_value().as_() {
482            num_traits::clamp(self, O::min_value().as_(), O::max_value().as_()).as_()
483        } else {
484            self.as_()
485        }
486    }
487}
488impl<T: PartialOrd + Copy + 'static> ClampCast for T {}
489
490pub trait Datum:
491    Clone + Send + Sync + fmt::Debug + fmt::Display + Default + 'static + PartialEq
492{
493    fn name() -> &'static str;
494    fn datum_type() -> DatumType;
495    fn is<D: Datum>() -> bool;
496}
497
/// Wires the Rust type `$rust_ty` to the `DatumType::$variant` tag.
///
/// Generates the `Datum` implementation (whose `name()` is the type exactly
/// as written at the call site) and a `From<$rust_ty> for Tensor` conversion
/// that goes through `tensor0`.
macro_rules! datum {
    ($rust_ty:ty, $variant:ident) => {
        impl Datum for $rust_ty {
            fn name() -> &'static str {
                stringify!($rust_ty)
            }

            fn datum_type() -> DatumType {
                DatumType::$variant
            }

            fn is<D: Datum>() -> bool {
                Self::datum_type() == D::datum_type()
            }
        }

        impl From<$rust_ty> for Tensor {
            fn from(it: $rust_ty) -> Tensor {
                tensor0(it)
            }
        }
    };
}
521
522datum!(bool, Bool);
523datum!(f16, F16);
524datum!(f32, F32);
525datum!(f64, F64);
526datum!(i8, I8);
527datum!(i16, I16);
528datum!(i32, I32);
529datum!(i64, I64);
530datum!(u8, U8);
531datum!(u16, U16);
532datum!(u32, U32);
533datum!(u64, U64);
534datum!(TDim, TDim);
535datum!(String, String);
536datum!(crate::blob::Blob, Blob);
537datum!(crate::opaque::Opaque, Opaque);
538#[cfg(feature = "complex")]
539datum!(Complex<i16>, ComplexI16);
540#[cfg(feature = "complex")]
541datum!(Complex<i32>, ComplexI32);
542#[cfg(feature = "complex")]
543datum!(Complex<i64>, ComplexI64);
544#[cfg(feature = "complex")]
545datum!(Complex<f16>, ComplexF16);
546#[cfg(feature = "complex")]
547datum!(Complex<f32>, ComplexF32);
548#[cfg(feature = "complex")]
549datum!(Complex<f64>, ComplexF64);
550
#[cfg(test)]
mod tests {
    use crate::internal::*;
    use ndarray::arr1;

    /// An ndarray turned into a `Tensor` can be viewed back as an equal array.
    #[test]
    fn test_array_to_tensor_to_array() {
        let source = arr1(&[12i32, 42]);
        let as_tensor = Tensor::from(source.clone());
        let typed_view = as_tensor.to_array_view::<i32>().unwrap();
        assert_eq!(source, typed_view.into_dimensionality().unwrap());
    }

    /// `TDim` -> `i32` -> `TDim` gives back the original tensor.
    #[test]
    fn test_cast_dim_to_dim() {
        let dims: Tensor = tensor1(&[12isize.to_dim(), 42isize.to_dim()]);
        let as_i32 = dims.cast_to::<i32>().unwrap();
        let back_to_dims = as_i32.cast_to::<TDim>().unwrap().into_owned();
        assert_eq!(dims, back_to_dims);
    }

    /// Casting `i32` values to `TDim` succeeds.
    #[test]
    fn test_cast_i32_to_dim() {
        let ints: Tensor = tensor1(&[0i32, 12]);
        ints.cast_to::<TDim>().unwrap();
    }

    /// Casting `i64` values to `bool` succeeds.
    #[test]
    fn test_cast_i64_to_bool() {
        let longs: Tensor = tensor1(&[0i64]);
        longs.cast_to::<bool>().unwrap();
    }

    /// The textual form of a quantized `u8` type parses to `QU8` with the
    /// given zero point and scale.
    #[test]
    fn test_parse_qu8() {
        let expected = DatumType::QU8(QParams::ZpScale { zero_point: 128, scale: 0.01 });
        let parsed = "QU8(Z:128 S:0.01)".parse::<DatumType>().unwrap();
        assert_eq!(parsed, expected);
    }
}