1use crate::TVec;
3use crate::dim::TDim;
4use crate::internal::*;
5use crate::tensor::Tensor;
6use half::f16;
7#[cfg(feature = "complex")]
8use num_complex::Complex;
9use scan_fmt::scan_fmt;
10use std::fmt;
11use std::hash::Hash;
12
13use num_traits::AsPrimitive;
14
/// Quantization parameters attached to the quantized [`DatumType`] variants.
///
/// Either an explicit `(zero_point, scale)` pair, or a `[min, max]` float range
/// from which [`QParams::zp_scale`] derives a zero point and a scale.
#[derive(Copy, Clone)]
pub enum QParams {
    MinMax { min: f32, max: f32 },
    ZpScale { zero_point: i32, scale: f32 },
}

// Float fields are compared by bit pattern rather than with `==`. This keeps `Eq`
// lawful (a NaN field equals itself) and makes equality agree with the
// `total_cmp`-based `Ord` and the `to_bits`-based `Hash` below. With `==`,
// `0.0` and `-0.0` compared equal but hashed and ordered differently.
impl PartialEq for QParams {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (
                QParams::MinMax { min: min1, max: max1 },
                QParams::MinMax { min: min2, max: max2 },
            ) => min1.to_bits() == min2.to_bits() && max1.to_bits() == max2.to_bits(),
            (
                QParams::ZpScale { zero_point: zp1, scale: s1 },
                QParams::ZpScale { zero_point: zp2, scale: s2 },
            ) => zp1 == zp2 && s1.to_bits() == s2.to_bits(),
            _ => false,
        }
    }
}

impl Eq for QParams {}

/// Total order: every `MinMax` sorts before every `ZpScale`. Within a variant,
/// fields are compared in declaration order, floats with [`f32::total_cmp`].
impl Ord for QParams {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (self, other) {
            (QParams::MinMax { .. }, QParams::ZpScale { .. }) => Ordering::Less,
            (QParams::ZpScale { .. }, QParams::MinMax { .. }) => Ordering::Greater,
            (
                QParams::MinMax { min: min1, max: max1 },
                QParams::MinMax { min: min2, max: max2 },
            ) => min1.total_cmp(min2).then_with(|| max1.total_cmp(max2)),
            (
                QParams::ZpScale { zero_point: zp1, scale: s1 },
                QParams::ZpScale { zero_point: zp2, scale: s2 },
            ) => zp1.cmp(zp2).then_with(|| s1.total_cmp(s2)),
        }
    }
}

impl PartialOrd for QParams {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Identity quantization: zero point 0, scale 1.
impl Default for QParams {
    fn default() -> Self {
        QParams::ZpScale { zero_point: 0, scale: 1. }
    }
}

// A per-variant tag is hashed first to tell the variants apart. Floats are hashed
// by bit pattern, consistently with `PartialEq` and `Ord`.
#[allow(clippy::derived_hash_with_manual_eq)]
impl Hash for QParams {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            QParams::MinMax { min, max } => {
                0.hash(state);
                min.to_bits().hash(state);
                max.to_bits().hash(state);
            }
            QParams::ZpScale { zero_point, scale } => {
                1.hash(state);
                zero_point.hash(state);
                scale.to_bits().hash(state);
            }
        }
    }
}

impl QParams {
    /// Returns `(zero_point, scale)`.
    ///
    /// For `MinMax`, the scale spreads `max - min` over 255 steps, and the zero point
    /// is `-(min + max) / 2 / scale` converted with `as i32`, which truncates toward zero.
    pub fn zp_scale(&self) -> (i32, f32) {
        match *self {
            QParams::MinMax { min, max } => {
                let scale = (max - min) / 255.;
                ((-(min + max) / 2. / scale) as i32, scale)
            }
            QParams::ZpScale { zero_point, scale } => (zero_point, scale),
        }
    }

    /// Quantizes `f`: `f / scale` converted with `as i32`, then offset by the zero point.
    ///
    /// NOTE(review): this truncates toward zero instead of rounding; confirm intended.
    pub fn q(&self, f: f32) -> i32 {
        let (zp, scale) = self.zp_scale();
        (f / scale) as i32 + zp
    }

    /// Dequantizes `i` as `(i - zero_point) * scale`.
    pub fn dq(&self, i: i32) -> f32 {
        let (zp, scale) = self.zp_scale();
        (i - zp) as f32 * scale
    }
}

/// Formats as `Z:<zero_point> S:<scale>`, which is the form `DatumType`'s `FromStr` parses.
/// `MinMax` values are printed in their derived zero-point/scale form.
impl std::fmt::Debug for QParams {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (zp, scale) = self.zp_scale();
        write!(f, "Z:{zp} S:{scale}")
    }
}
98
/// Runtime tag describing the element type of a tensor.
///
/// Quantized variants (`QI8`, `QU8`, `QI32`) carry their [`QParams`]. The complex
/// variants only exist when the `complex` cargo feature is enabled.
///
/// The derived `Debug` output of a quantized variant, e.g. `QU8(Z:128 S:0.01)`, is
/// the form accepted by this type's `FromStr` implementation.
///
/// `Ord`/`PartialOrd` are derived, so the declaration order of the variants below
/// is the sort order of datum types. Reordering them changes comparisons.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum DatumType {
    Bool,
    // Unsigned integers.
    U8,
    U16,
    U32,
    U64,
    // Signed integers.
    I8,
    I16,
    I32,
    I64,
    // Floating point.
    F16,
    F32,
    F64,
    // Non-numeric element types.
    TDim,
    Blob,
    String,
    // Quantized integers; see `DatumType::unquantized` for their storage types.
    QI8(QParams),
    QU8(QParams),
    QI32(QParams),
    #[cfg(feature = "complex")]
    ComplexI16,
    #[cfg(feature = "complex")]
    ComplexI32,
    #[cfg(feature = "complex")]
    ComplexI64,
    #[cfg(feature = "complex")]
    ComplexF16,
    #[cfg(feature = "complex")]
    ComplexF32,
    #[cfg(feature = "complex")]
    ComplexF64,
    // Tag for `crate::opaque::Opaque` values (registered through `datum!` in this file).
    Opaque,
}
133
134impl DatumType {
135 pub fn super_types(&self) -> TVec<DatumType> {
136 use DatumType::*;
137 if *self == String || *self == TDim || *self == Blob || *self == Bool || self.is_quantized()
138 {
139 return tvec!(*self);
140 }
141 #[cfg(feature = "complex")]
142 if self.is_complex_float() {
143 return [ComplexF16, ComplexF32, ComplexF64]
144 .iter()
145 .filter(|s| s.size_of() >= self.size_of())
146 .copied()
147 .collect();
148 } else if self.is_complex_signed() {
149 return [ComplexI16, ComplexI32, ComplexI64]
150 .iter()
151 .filter(|s| s.size_of() >= self.size_of())
152 .copied()
153 .collect();
154 }
155 if self.is_float() {
156 [F16, F32, F64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
157 } else if self.is_signed() {
158 [I8, I16, I32, I64, TDim]
159 .iter()
160 .filter(|s| s.size_of() >= self.size_of())
161 .copied()
162 .collect()
163 } else {
164 [U8, U16, U32, U64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
165 }
166 }
167
168 pub fn super_type_for(
169 i: impl IntoIterator<Item = impl std::borrow::Borrow<DatumType>>,
170 ) -> Option<DatumType> {
171 let mut iter = i.into_iter();
172 let mut current = match iter.next() {
173 None => return None,
174 Some(it) => *it.borrow(),
175 };
176 for n in iter {
177 match current.common_super_type(*n.borrow()) {
178 None => return None,
179 Some(it) => current = it,
180 }
181 }
182 Some(current)
183 }
184
185 pub fn common_super_type(&self, rhs: DatumType) -> Option<DatumType> {
186 for mine in self.super_types() {
187 for theirs in rhs.super_types() {
188 if mine == theirs {
189 return Some(mine);
190 }
191 }
192 }
193 None
194 }
195
196 pub fn is_unsigned(&self) -> bool {
197 matches!(
198 self.unquantized(),
199 DatumType::U8 | DatumType::U16 | DatumType::U32 | DatumType::U64
200 )
201 }
202
203 pub fn is_signed(&self) -> bool {
204 matches!(
205 self.unquantized(),
206 DatumType::I8 | DatumType::I16 | DatumType::I32 | DatumType::I64
207 )
208 }
209
210 pub fn is_float(&self) -> bool {
211 matches!(self, DatumType::F16 | DatumType::F32 | DatumType::F64)
212 }
213
214 pub fn is_number(&self) -> bool {
215 self.is_signed() | self.is_unsigned() | self.is_float() | self.is_quantized()
216 }
217
218 pub fn is_tdim(&self) -> bool {
219 *self == DatumType::TDim
220 }
221
222 pub fn is_opaque(&self) -> bool {
223 *self == DatumType::Opaque
224 }
225
226 #[cfg(feature = "complex")]
227 pub fn is_complex(&self) -> bool {
228 self.is_complex_float() || self.is_complex_signed()
229 }
230
231 #[cfg(feature = "complex")]
232 pub fn is_complex_float(&self) -> bool {
233 matches!(self, DatumType::ComplexF16 | DatumType::ComplexF32 | DatumType::ComplexF64)
234 }
235
236 #[cfg(feature = "complex")]
237 pub fn is_complex_signed(&self) -> bool {
238 matches!(self, DatumType::ComplexI16 | DatumType::ComplexI32 | DatumType::ComplexI64)
239 }
240
241 #[cfg(feature = "complex")]
242 pub fn complexify(&self) -> TractResult<DatumType> {
243 match *self {
244 DatumType::I16 => Ok(DatumType::ComplexI16),
245 DatumType::I32 => Ok(DatumType::ComplexI32),
246 DatumType::I64 => Ok(DatumType::ComplexI64),
247 DatumType::F16 => Ok(DatumType::ComplexF16),
248 DatumType::F32 => Ok(DatumType::ComplexF32),
249 DatumType::F64 => Ok(DatumType::ComplexF64),
250 _ => bail!("No complex datum type formed on {:?}", self),
251 }
252 }
253
254 #[cfg(feature = "complex")]
255 pub fn decomplexify(&self) -> TractResult<DatumType> {
256 match *self {
257 DatumType::ComplexI16 => Ok(DatumType::I16),
258 DatumType::ComplexI32 => Ok(DatumType::I32),
259 DatumType::ComplexI64 => Ok(DatumType::I64),
260 DatumType::ComplexF16 => Ok(DatumType::F16),
261 DatumType::ComplexF32 => Ok(DatumType::F32),
262 DatumType::ComplexF64 => Ok(DatumType::F64),
263 _ => bail!("{:?} is not a complex type", self),
264 }
265 }
266
267 pub fn is_copy(&self) -> bool {
268 #[cfg(feature = "complex")]
269 if self.is_complex() {
270 return true;
271 }
272 *self == DatumType::Bool || self.is_unsigned() || self.is_signed() || self.is_float()
273 }
274
275 pub fn is_quantized(&self) -> bool {
276 self.qparams().is_some()
277 }
278
279 pub fn qparams(&self) -> Option<QParams> {
280 match self {
281 DatumType::QI8(qparams) | DatumType::QU8(qparams) | DatumType::QI32(qparams) => {
282 Some(*qparams)
283 }
284 _ => None,
285 }
286 }
287
288 pub fn with_qparams(&self, qparams: QParams) -> DatumType {
289 match self {
290 DatumType::QI8(_) => DatumType::QI8(qparams),
291 DatumType::QU8(_) => DatumType::QI8(qparams),
292 DatumType::QI32(_) => DatumType::QI32(qparams),
293 _ => *self,
294 }
295 }
296
297 pub fn quantize(&self, qparams: QParams) -> DatumType {
298 match self {
299 DatumType::I8 => DatumType::QI8(qparams),
300 DatumType::U8 => DatumType::QU8(qparams),
301 DatumType::I32 => DatumType::QI32(qparams),
302 DatumType::QI8(_) => DatumType::QI8(qparams),
303 DatumType::QU8(_) => DatumType::QU8(qparams),
304 DatumType::QI32(_) => DatumType::QI32(qparams),
305 _ => panic!("Can't quantize {self:?}"),
306 }
307 }
308
309 #[inline(always)]
310 pub fn zp_scale(&self) -> (i32, f32) {
311 self.qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.))
312 }
313
314 #[inline(always)]
315 pub fn with_zp_scale(&self, zero_point: i32, scale: f32) -> DatumType {
316 self.quantize(QParams::ZpScale { zero_point, scale })
317 }
318
319 pub fn unquantized(&self) -> DatumType {
320 match self {
321 DatumType::QI8(_) => DatumType::I8,
322 DatumType::QU8(_) => DatumType::U8,
323 DatumType::QI32(_) => DatumType::I32,
324 _ => *self,
325 }
326 }
327
328 pub fn integer(signed: bool, size: usize) -> Self {
329 use DatumType::*;
330 match (signed, size) {
331 (false, 8) => U8,
332 (false, 16) => U16,
333 (false, 32) => U32,
334 (false, 64) => U64,
335 (true, 8) => U8,
336 (true, 16) => U16,
337 (true, 32) => U32,
338 (true, 64) => U64,
339 _ => panic!("No integer for signed:{signed} size:{size}"),
340 }
341 }
342
343 pub fn is_integer(&self) -> bool {
344 self.is_signed() || self.is_unsigned()
345 }
346
347 #[inline]
348 pub fn size_of(&self) -> usize {
349 dispatch_datum!(std::mem::size_of(self)())
350 }
351
352 pub fn min_value(&self) -> Tensor {
353 match self {
354 DatumType::QU8(_)
355 | DatumType::U8
356 | DatumType::U16
357 | DatumType::U32
358 | DatumType::U64 => Tensor::zero_dt(*self, &[1]).unwrap(),
359 DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MIN),
360 DatumType::QI32(_) => tensor0(i32::MIN),
361 DatumType::I16 => tensor0(i16::MIN),
362 DatumType::I32 => tensor0(i32::MIN),
363 DatumType::I64 => tensor0(i64::MIN),
364 DatumType::F16 => tensor0(f16::MIN),
365 DatumType::F32 => tensor0(f32::MIN),
366 DatumType::F64 => tensor0(f64::MIN),
367 _ => panic!("No min value for datum type {self:?}"),
368 }
369 }
370 pub fn max_value(&self) -> Tensor {
371 match self {
372 DatumType::U8 | DatumType::QU8(_) => tensor0(u8::MAX),
373 DatumType::U16 => tensor0(u16::MAX),
374 DatumType::U32 => tensor0(u32::MAX),
375 DatumType::U64 => tensor0(u64::MAX),
376 DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MAX),
377 DatumType::I16 => tensor0(i16::MAX),
378 DatumType::I32 => tensor0(i32::MAX),
379 DatumType::I64 => tensor0(i64::MAX),
380 DatumType::QI32(_) => tensor0(i32::MAX),
381 DatumType::F16 => tensor0(f16::MAX),
382 DatumType::F32 => tensor0(f32::MAX),
383 DatumType::F64 => tensor0(f64::MAX),
384 _ => panic!("No max value for datum type {self:?}"),
385 }
386 }
387
388 pub fn is<D: Datum>(&self) -> bool {
389 *self == D::datum_type()
390 }
391}
392
393impl std::str::FromStr for DatumType {
394 type Err = TractError;
395
396 fn from_str(s: &str) -> Result<Self, Self::Err> {
397 if let Ok((z, s)) = scan_fmt!(s, "QU8(Z:{d} S:{f})", i32, f32) {
398 Ok(DatumType::QU8(QParams::ZpScale { zero_point: z, scale: s }))
399 } else if let Ok((z, s)) = scan_fmt!(s, "QI8(Z:{d} S:{f})", i32, f32) {
400 Ok(DatumType::QI8(QParams::ZpScale { zero_point: z, scale: s }))
401 } else if let Ok((z, s)) = scan_fmt!(s, "QI32(Z:{d} S:{f})", i32, f32) {
402 Ok(DatumType::QI32(QParams::ZpScale { zero_point: z, scale: s }))
403 } else {
404 match s {
405 "I8" | "i8" => Ok(DatumType::I8),
406 "I16" | "i16" => Ok(DatumType::I16),
407 "I32" | "i32" => Ok(DatumType::I32),
408 "I64" | "i64" => Ok(DatumType::I64),
409 "U8" | "u8" => Ok(DatumType::U8),
410 "U16" | "u16" => Ok(DatumType::U16),
411 "U32" | "u32" => Ok(DatumType::U32),
412 "U64" | "u64" => Ok(DatumType::U64),
413 "F16" | "f16" => Ok(DatumType::F16),
414 "F32" | "f32" => Ok(DatumType::F32),
415 "F64" | "f64" => Ok(DatumType::F64),
416 "Bool" | "bool" => Ok(DatumType::Bool),
417 "Blob" | "blob" => Ok(DatumType::Blob),
418 "String" | "string" => Ok(DatumType::String),
419 "TDim" | "tdim" => Ok(DatumType::TDim),
420 #[cfg(feature = "complex")]
421 "ComplexI16" | "complexi16" => Ok(DatumType::ComplexI16),
422 #[cfg(feature = "complex")]
423 "ComplexI32" | "complexi32" => Ok(DatumType::ComplexI32),
424 #[cfg(feature = "complex")]
425 "ComplexI64" | "complexi64" => Ok(DatumType::ComplexI64),
426 #[cfg(feature = "complex")]
427 "ComplexF16" | "complexf16" => Ok(DatumType::ComplexF16),
428 #[cfg(feature = "complex")]
429 "ComplexF32" | "complexf32" => Ok(DatumType::ComplexF32),
430 #[cfg(feature = "complex")]
431 "ComplexF64" | "complexf64" => Ok(DatumType::ComplexF64),
432 _ => bail!("Unknown type {}", s),
433 }
434 }
435 }
436}
437
/// `2^23`. For `|x| < 2^23`, adding then subtracting this constant makes the
/// IEEE-754 default rounding mode (round-to-nearest, ties-to-even) drop the
/// fractional bits of `x`.
const TOINT: f32 = 1.0f32 / f32::EPSILON;

/// Rounds `x` to the nearest integer, breaking ties toward the even neighbour
/// (e.g. `2.5 -> 2.0`, `3.5 -> 4.0`, `-2.5 -> -2.0`).
///
/// Values that are already integral (biased exponent at least `127 + 23`),
/// infinities and NaN are returned unchanged. A zero result keeps the sign of `x`.
pub fn round_ties_to_even(x: f32) -> f32 {
    let bits = x.to_bits();
    let biased_exponent = (bits >> 23) & 0xff;
    // From 2^23 upward every f32 is an integer. This also covers inf and NaN,
    // whose exponent field is all ones.
    if biased_exponent >= 0x7f + 23 {
        return x;
    }
    let negative = bits >> 31 == 1;
    // Operation order matters: the first step rounds the fraction away, and
    // the second restores the magnitude exactly.
    let rounded = if negative { (x - TOINT) + TOINT } else { (x + TOINT) - TOINT };
    match (rounded == 0.0, negative) {
        (true, true) => -0f32,
        (true, false) => 0f32,
        (false, _) => rounded,
    }
}
450
451#[inline]
452pub fn scale_by<T: Datum + AsPrimitive<f32>>(b: T, a: f32) -> T
453where
454 f32: AsPrimitive<T>,
455{
456 let b = b.as_();
457 (round_ties_to_even(b.abs() * a) * b.signum()).as_()
458}
459
460pub trait ClampCast: PartialOrd + Copy + 'static {
461 #[inline(always)]
462 fn clamp_cast<O>(self) -> O
463 where
464 Self: AsPrimitive<O> + Datum,
465 O: AsPrimitive<Self> + num_traits::Bounded + Datum,
466 {
467 if O::min_value().as_() < O::max_value().as_() {
469 num_traits::clamp(self, O::min_value().as_(), O::max_value().as_()).as_()
470 } else {
471 self.as_()
472 }
473 }
474}
475impl<T: PartialOrd + Copy + 'static> ClampCast for T {}
476
477pub trait Datum:
478 Clone + Send + Sync + fmt::Debug + fmt::Display + Default + 'static + PartialEq
479{
480 fn name() -> &'static str;
481 fn datum_type() -> DatumType;
482 fn is<D: Datum>() -> bool;
483}
484
/// Implements [`Datum`] for `$t` with tag `DatumType::$v`, and lets a single
/// `$t` value be converted into a [`Tensor`] through `tensor0`.
macro_rules! datum {
    ($t:ty, $v:ident) => {
        impl From<$t> for Tensor {
            fn from(value: $t) -> Tensor {
                tensor0(value)
            }
        }

        impl Datum for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn datum_type() -> DatumType {
                DatumType::$v
            }

            fn is<D: Datum>() -> bool {
                D::datum_type() == <$t as Datum>::datum_type()
            }
        }
    };
}
508
509datum!(bool, Bool);
510datum!(f16, F16);
511datum!(f32, F32);
512datum!(f64, F64);
513datum!(i8, I8);
514datum!(i16, I16);
515datum!(i32, I32);
516datum!(i64, I64);
517datum!(u8, U8);
518datum!(u16, U16);
519datum!(u32, U32);
520datum!(u64, U64);
521datum!(TDim, TDim);
522datum!(String, String);
523datum!(crate::blob::Blob, Blob);
524datum!(crate::opaque::Opaque, Opaque);
525#[cfg(feature = "complex")]
526datum!(Complex<i16>, ComplexI16);
527#[cfg(feature = "complex")]
528datum!(Complex<i32>, ComplexI32);
529#[cfg(feature = "complex")]
530datum!(Complex<i64>, ComplexI64);
531#[cfg(feature = "complex")]
532datum!(Complex<f16>, ComplexF16);
533#[cfg(feature = "complex")]
534datum!(Complex<f32>, ComplexF32);
535#[cfg(feature = "complex")]
536datum!(Complex<f64>, ComplexF64);
537
538#[cfg(test)]
539mod tests {
540 use crate::internal::*;
541 use ndarray::arr1;
542
543 #[test]
544 fn test_array_to_tensor_to_array() {
545 let array = arr1(&[12i32, 42]);
546 let tensor = Tensor::from(array.clone());
547 let view = tensor.to_array_view::<i32>().unwrap();
548 assert_eq!(array, view.into_dimensionality().unwrap());
549 }
550
551 #[test]
552 fn test_cast_dim_to_dim() {
553 let t_dim: Tensor = tensor1(&[12isize.to_dim(), 42isize.to_dim()]);
554 let t_i32 = t_dim.cast_to::<i32>().unwrap();
555 let t_dim_2 = t_i32.cast_to::<TDim>().unwrap().into_owned();
556 assert_eq!(t_dim, t_dim_2);
557 }
558
559 #[test]
560 fn test_cast_i32_to_dim() {
561 let t_i32: Tensor = tensor1(&[0i32, 12]);
562 t_i32.cast_to::<TDim>().unwrap();
563 }
564
565 #[test]
566 fn test_cast_i64_to_bool() {
567 let t_i64: Tensor = tensor1(&[0i64]);
568 t_i64.cast_to::<bool>().unwrap();
569 }
570
571 #[test]
572 fn test_parse_qu8() {
573 assert_eq!(
574 "QU8(Z:128 S:0.01)".parse::<DatumType>().unwrap(),
575 DatumType::QU8(QParams::ZpScale { zero_point: 128, scale: 0.01 })
576 );
577 }
578}