1use crate::internal::*;
3use crate::dim::TDim;
4use crate::tensor::Tensor;
5use crate::TVec;
6use half::f16;
7#[cfg(feature = "complex")]
8use num_complex::Complex;
9use scan_fmt::scan_fmt;
10use std::fmt;
11use std::hash::Hash;
12
13use num_traits::AsPrimitive;
14
/// Quantization parameters carried by the quantized [`DatumType`] variants.
///
/// Two representations are supported; [`QParams::zp_scale`] converts either
/// of them to a `(zero_point, scale)` pair.
#[derive(Copy, Clone, PartialEq)]
pub enum QParams {
    /// Parameters given as two real bounds. `zp_scale` spreads the
    /// `max - min` interval over 255 steps to derive the scale.
    MinMax { min: f32, max: f32 },
    /// Parameters given directly as an integer zero point and a real scale.
    ZpScale { zero_point: i32, scale: f32 },
}
20
// Marker impl layered on the derived `PartialEq`. The fields are `f32`, so a
// NaN `min`, `max` or `scale` compares unequal to itself under `==`.
// NOTE(review): `Ord` and `Hash` for this type work on the float bits
// (`total_cmp`, `to_bits`) while `==` uses IEEE comparison, so they can
// disagree on NaN and on 0.0 vs -0.0 — confirm such values never occur.
impl Eq for QParams {}
22
23impl Ord for QParams {
24 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
25 use QParams::*;
26 match (self, other) {
27 (MinMax { .. }, ZpScale { .. }) => std::cmp::Ordering::Less,
28 (ZpScale { .. }, MinMax { .. }) => std::cmp::Ordering::Greater,
29 (MinMax { min: min1, max: max1 }, MinMax { min: min2, max: max2 }) => {
30 min1.total_cmp(min2).then_with(|| max1.total_cmp(max2))
31 }
32 (
33 Self::ZpScale { zero_point: zp1, scale: s1 },
34 Self::ZpScale { zero_point: zp2, scale: s2 },
35 ) => zp1.cmp(zp2).then_with(|| s1.total_cmp(s2)),
36 }
37 }
38}
39
40impl PartialOrd for QParams {
41 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
42 Some(self.cmp(other))
43 }
44}
45
46impl Default for QParams {
47 fn default() -> Self {
48 QParams::ZpScale { zero_point: 0, scale: 1. }
49 }
50}
51
52#[allow(clippy::derived_hash_with_manual_eq)]
53impl Hash for QParams {
54 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
55 match self {
56 QParams::MinMax { min, max } => {
57 0.hash(state);
58 min.to_bits().hash(state);
59 max.to_bits().hash(state);
60 }
61 QParams::ZpScale { zero_point, scale } => {
62 1.hash(state);
63 zero_point.hash(state);
64 scale.to_bits().hash(state);
65 }
66 }
67 }
68}
69
70impl QParams {
71 pub fn zp_scale(&self) -> (i32, f32) {
72 match self {
73 QParams::MinMax { min, max } => {
74 let scale = (max - min) / 255.;
75 ((-(min + max) / 2. / scale) as i32, scale)
76 }
77 QParams::ZpScale { zero_point, scale } => (*zero_point, *scale),
78 }
79 }
80
81 pub fn q(&self, f: f32) -> i32 {
82 let (zp, scale) = self.zp_scale();
83 (f / scale) as i32 + zp
84 }
85
86 pub fn dq(&self, i: i32) -> f32 {
87 let (zp, scale) = self.zp_scale();
88 (i - zp) as f32 * scale
89 }
90}
91
92impl std::fmt::Debug for QParams {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 let (zp, scale) = self.zp_scale();
95 write!(f, "Z:{zp} S:{scale}")
96 }
97}
98
/// Identifier for the element types handled by this module.
///
/// The derived `Ord`/`PartialOrd` follow the declaration order of the
/// variants below. The `Complex*` variants only exist when the `complex`
/// feature is enabled.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum DatumType {
    Bool,
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F16,
    F32,
    F64,
    TDim,
    Blob,
    String,
    // Quantized variants: each carries its `QParams`; `unquantized()` maps
    // them to `I8`, `U8` and `I32` respectively.
    QI8(QParams),
    QU8(QParams),
    QI32(QParams),
    #[cfg(feature = "complex")]
    ComplexI16,
    #[cfg(feature = "complex")]
    ComplexI32,
    #[cfg(feature = "complex")]
    ComplexI64,
    #[cfg(feature = "complex")]
    ComplexF16,
    #[cfg(feature = "complex")]
    ComplexF32,
    #[cfg(feature = "complex")]
    ComplexF64,
    Opaque,
}
133
134impl DatumType {
135 pub fn super_types(&self) -> TVec<DatumType> {
136 use DatumType::*;
137 if *self == String || *self == TDim || *self == Blob || *self == Bool || self.is_quantized()
138 {
139 return tvec!(*self);
140 }
141 #[cfg(feature = "complex")]
142 if self.is_complex_float() {
143 return [ComplexF16, ComplexF32, ComplexF64]
144 .iter()
145 .filter(|s| s.size_of() >= self.size_of())
146 .copied()
147 .collect();
148 } else if self.is_complex_signed() {
149 return [ComplexI16, ComplexI32, ComplexI64]
150 .iter()
151 .filter(|s| s.size_of() >= self.size_of())
152 .copied()
153 .collect();
154 }
155 if self.is_float() {
156 [F16, F32, F64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
157 } else if self.is_signed() {
158 [I8, I16, I32, I64, TDim]
159 .iter()
160 .filter(|s| s.size_of() >= self.size_of())
161 .copied()
162 .collect()
163 } else {
164 [U8, U16, U32, U64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
165 }
166 }
167
168 pub fn super_type_for(
169 i: impl IntoIterator<Item = impl std::borrow::Borrow<DatumType>>,
170 ) -> Option<DatumType> {
171 let mut iter = i.into_iter();
172 let mut current = match iter.next() {
173 None => return None,
174 Some(it) => *it.borrow(),
175 };
176 for n in iter {
177 match current.common_super_type(*n.borrow()) {
178 None => return None,
179 Some(it) => current = it,
180 }
181 }
182 Some(current)
183 }
184
185 pub fn common_super_type(&self, rhs: DatumType) -> Option<DatumType> {
186 for mine in self.super_types() {
187 for theirs in rhs.super_types() {
188 if mine == theirs {
189 return Some(mine);
190 }
191 }
192 }
193 None
194 }
195
196 pub fn is_unsigned(&self) -> bool {
197 matches!(
198 self.unquantized(),
199 DatumType::U8 | DatumType::U16 | DatumType::U32 | DatumType::U64
200 )
201 }
202
203 pub fn is_signed(&self) -> bool {
204 matches!(
205 self.unquantized(),
206 DatumType::I8 | DatumType::I16 | DatumType::I32 | DatumType::I64
207 )
208 }
209
210 pub fn is_float(&self) -> bool {
211 matches!(self, DatumType::F16 | DatumType::F32 | DatumType::F64)
212 }
213
214 pub fn is_number(&self) -> bool {
215 self.is_signed() | self.is_unsigned() | self.is_float() | self.is_quantized()
216 }
217
218 pub fn is_tdim(&self) -> bool {
219 *self == DatumType::TDim
220 }
221
222 pub fn is_opaque(&self) -> bool {
223 *self == DatumType::Opaque
224 }
225
226 #[cfg(feature = "complex")]
227 pub fn is_complex(&self) -> bool {
228 self.is_complex_float() || self.is_complex_signed()
229 }
230
231 #[cfg(feature = "complex")]
232 pub fn is_complex_float(&self) -> bool {
233 matches!(self, DatumType::ComplexF16 | DatumType::ComplexF32 | DatumType::ComplexF64)
234 }
235
236 #[cfg(feature = "complex")]
237 pub fn is_complex_signed(&self) -> bool {
238 matches!(self, DatumType::ComplexI16 | DatumType::ComplexI32 | DatumType::ComplexI64)
239 }
240
241 #[cfg(feature = "complex")]
242 pub fn complexify(&self) -> TractResult<DatumType> {
243 match *self {
244 DatumType::I16 => Ok(DatumType::ComplexI16),
245 DatumType::I32 => Ok(DatumType::ComplexI32),
246 DatumType::I64 => Ok(DatumType::ComplexI64),
247 DatumType::F16 => Ok(DatumType::ComplexF16),
248 DatumType::F32 => Ok(DatumType::ComplexF32),
249 DatumType::F64 => Ok(DatumType::ComplexF64),
250 _ => bail!("No complex datum type formed on {:?}", self),
251 }
252 }
253
254 #[cfg(feature = "complex")]
255 pub fn decomplexify(&self) -> TractResult<DatumType> {
256 match *self {
257 DatumType::ComplexI16 => Ok(DatumType::I16),
258 DatumType::ComplexI32 => Ok(DatumType::I32),
259 DatumType::ComplexI64 => Ok(DatumType::I64),
260 DatumType::ComplexF16 => Ok(DatumType::F16),
261 DatumType::ComplexF32 => Ok(DatumType::F32),
262 DatumType::ComplexF64 => Ok(DatumType::F64),
263 _ => bail!("{:?} is not a complex type", self),
264 }
265 }
266
267 pub fn is_copy(&self) -> bool {
268 #[cfg(feature = "complex")]
269 if self.is_complex() {
270 return true;
271 }
272 *self == DatumType::Bool || self.is_unsigned() || self.is_signed() || self.is_float()
273 }
274
275 pub fn is_quantized(&self) -> bool {
276 self.qparams().is_some()
277 }
278
279 pub fn qparams(&self) -> Option<QParams> {
280 match self {
281 DatumType::QI8(qparams) | DatumType::QU8(qparams) | DatumType::QI32(qparams) => {
282 Some(*qparams)
283 }
284 _ => None,
285 }
286 }
287
288 pub fn with_qparams(&self, qparams: QParams) -> DatumType {
289 match self {
290 DatumType::QI8(_) => DatumType::QI8(qparams),
291 DatumType::QU8(_) => DatumType::QI8(qparams),
292 DatumType::QI32(_) => DatumType::QI32(qparams),
293 _ => *self,
294 }
295 }
296
297 pub fn quantize(&self, qparams: QParams) -> DatumType {
298 match self {
299 DatumType::I8 => DatumType::QI8(qparams),
300 DatumType::U8 => DatumType::QU8(qparams),
301 DatumType::I32 => DatumType::QI32(qparams),
302 DatumType::QI8(_) => DatumType::QI8(qparams),
303 DatumType::QU8(_) => DatumType::QU8(qparams),
304 DatumType::QI32(_) => DatumType::QI32(qparams),
305 _ => panic!("Can't quantize {self:?}"),
306 }
307 }
308
309 #[inline(always)]
310 pub fn zp_scale(&self) -> (i32, f32) {
311 self.qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.))
312 }
313
314 #[inline(always)]
315 pub fn with_zp_scale(&self, zero_point: i32, scale: f32) -> DatumType {
316 self.quantize(QParams::ZpScale { zero_point, scale })
317 }
318
319 pub fn unquantized(&self) -> DatumType {
320 match self {
321 DatumType::QI8(_) => DatumType::I8,
322 DatumType::QU8(_) => DatumType::U8,
323 DatumType::QI32(_) => DatumType::I32,
324 _ => *self,
325 }
326 }
327
328 pub fn integer(signed: bool, size: usize) -> Self {
329 use DatumType::*;
330 match (signed, size) {
331 (false, 8) => U8,
332 (false, 16) => U16,
333 (false, 32) => U32,
334 (false, 64) => U64,
335 (true, 8) => U8,
336 (true, 16) => U16,
337 (true, 32) => U32,
338 (true, 64) => U64,
339 _ => panic!("No integer for signed:{signed} size:{size}"),
340 }
341 }
342
343 pub fn is_integer(&self) -> bool {
344 self.is_signed() || self.is_unsigned()
345 }
346
347 #[inline]
348 pub fn size_of(&self) -> usize {
349 dispatch_datum!(std::mem::size_of(self)())
350 }
351
352 #[inline]
353 pub fn alignment(&self) -> usize {
354 if self.is_copy() {
355 self.size_of()
356 } else {
357 std::mem::size_of::<usize>()
358 }
359 }
360
361 pub fn min_value(&self) -> Tensor {
362 match self {
363 DatumType::QU8(_)
364 | DatumType::U8
365 | DatumType::U16
366 | DatumType::U32
367 | DatumType::U64 => Tensor::zero_dt(*self, &[1]).unwrap(),
368 DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MIN),
369 DatumType::QI32(_) => tensor0(i32::MIN),
370 DatumType::I16 => tensor0(i16::MIN),
371 DatumType::I32 => tensor0(i32::MIN),
372 DatumType::I64 => tensor0(i64::MIN),
373 DatumType::F16 => tensor0(f16::MIN),
374 DatumType::F32 => tensor0(f32::MIN),
375 DatumType::F64 => tensor0(f64::MIN),
376 _ => panic!("No min value for datum type {self:?}"),
377 }
378 }
379 pub fn max_value(&self) -> Tensor {
380 match self {
381 DatumType::U8 | DatumType::QU8(_) => tensor0(u8::MAX),
382 DatumType::U16 => tensor0(u16::MAX),
383 DatumType::U32 => tensor0(u32::MAX),
384 DatumType::U64 => tensor0(u64::MAX),
385 DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MAX),
386 DatumType::I16 => tensor0(i16::MAX),
387 DatumType::I32 => tensor0(i32::MAX),
388 DatumType::I64 => tensor0(i64::MAX),
389 DatumType::QI32(_) => tensor0(i32::MAX),
390 DatumType::F16 => tensor0(f16::MAX),
391 DatumType::F32 => tensor0(f32::MAX),
392 DatumType::F64 => tensor0(f64::MAX),
393 _ => panic!("No max value for datum type {self:?}"),
394 }
395 }
396}
397
398impl std::str::FromStr for DatumType {
399 type Err = TractError;
400
401 fn from_str(s: &str) -> Result<Self, Self::Err> {
402 if let Ok((z, s)) = scan_fmt!(s, "QU8(Z:{d} S:{f})", i32, f32) {
403 Ok(DatumType::QU8(QParams::ZpScale { zero_point: z, scale: s }))
404 } else if let Ok((z, s)) = scan_fmt!(s, "QI8(Z:{d} S:{f})", i32, f32) {
405 Ok(DatumType::QI8(QParams::ZpScale { zero_point: z, scale: s }))
406 } else if let Ok((z, s)) = scan_fmt!(s, "QI32(Z:{d} S:{f})", i32, f32) {
407 Ok(DatumType::QI32(QParams::ZpScale { zero_point: z, scale: s }))
408 } else {
409 match s {
410 "I8" | "i8" => Ok(DatumType::I8),
411 "I16" | "i16" => Ok(DatumType::I16),
412 "I32" | "i32" => Ok(DatumType::I32),
413 "I64" | "i64" => Ok(DatumType::I64),
414 "U8" | "u8" => Ok(DatumType::U8),
415 "U16" | "u16" => Ok(DatumType::U16),
416 "U32" | "u32" => Ok(DatumType::U32),
417 "U64" | "u64" => Ok(DatumType::U64),
418 "F16" | "f16" => Ok(DatumType::F16),
419 "F32" | "f32" => Ok(DatumType::F32),
420 "F64" | "f64" => Ok(DatumType::F64),
421 "Bool" | "bool" => Ok(DatumType::Bool),
422 "Blob" | "blob" => Ok(DatumType::Blob),
423 "String" | "string" => Ok(DatumType::String),
424 "TDim" | "tdim" => Ok(DatumType::TDim),
425 #[cfg(feature = "complex")]
426 "ComplexI16" | "complexi16" => Ok(DatumType::ComplexI16),
427 #[cfg(feature = "complex")]
428 "ComplexI32" | "complexi32" => Ok(DatumType::ComplexI32),
429 #[cfg(feature = "complex")]
430 "ComplexI64" | "complexi64" => Ok(DatumType::ComplexI64),
431 #[cfg(feature = "complex")]
432 "ComplexF16" | "complexf16" => Ok(DatumType::ComplexF16),
433 #[cfg(feature = "complex")]
434 "ComplexF32" | "complexf32" => Ok(DatumType::ComplexF32),
435 #[cfg(feature = "complex")]
436 "ComplexF64" | "complexf64" => Ok(DatumType::ComplexF64),
437 _ => bail!("Unknown type {}", s),
438 }
439 }
440 }
441}
442
/// `1 / f32::EPSILON`, i.e. 2^23: from this magnitude on, adjacent `f32`
/// values are at least 1 apart.
const TOINT: f32 = 1.0f32 / f32::EPSILON;

/// Rounds `x` to the nearest integer value, ties going to the even one.
///
/// Inputs whose exponent already rules out a fractional part (including
/// infinities and NaN) are returned unchanged, and the sign of `x` is kept
/// when the result is zero.
pub fn round_ties_to_even(x: f32) -> f32 {
    let bits = x.to_bits();
    let biased_exponent = (bits >> 23) & 0xff;
    if biased_exponent >= 0x7f + 23 {
        return x;
    }
    let negative = bits >> 31 == 1;
    // Adding then removing TOINT (away from zero first) drops the fraction
    // bits using the FPU's own rounding.
    let rounded = if negative { x - TOINT + TOINT } else { x + TOINT - TOINT };
    match (rounded == 0.0, negative) {
        (true, true) => -0f32,
        (true, false) => 0f32,
        (false, _) => rounded,
    }
}
463
464#[inline]
465pub fn scale_by<T: Datum + AsPrimitive<f32>>(b: T, a: f32) -> T
466where
467 f32: AsPrimitive<T>,
468{
469 let b = b.as_();
470 (round_ties_to_even(b.abs() * a) * b.signum()).as_()
471}
472
473pub trait ClampCast: PartialOrd + Copy + 'static {
474 #[inline(always)]
475 fn clamp_cast<O>(self) -> O
476 where
477 Self: AsPrimitive<O> + Datum,
478 O: AsPrimitive<Self> + num_traits::Bounded + Datum,
479 {
480 if O::min_value().as_() < O::max_value().as_() {
482 num_traits::clamp(self, O::min_value().as_(), O::max_value().as_()).as_()
483 } else {
484 self.as_()
485 }
486 }
487}
// Blanket impl: every `PartialOrd + Copy + 'static` type gets `clamp_cast`.
impl<T: PartialOrd + Copy + 'static> ClampCast for T {}
489
/// Rust types that have a matching [`DatumType`] variant.
pub trait Datum:
    Clone + Send + Sync + fmt::Debug + fmt::Display + Default + 'static + PartialEq
{
    /// Name of the type; the `datum!` implementations return the Rust type
    /// spelled as a string.
    fn name() -> &'static str;
    /// The [`DatumType`] variant associated with this Rust type.
    fn datum_type() -> DatumType;
    /// Whether `D` is this type; the `datum!` implementations compare the
    /// two `datum_type()` values.
    fn is<D: Datum>() -> bool;
}
497
// Wires the Rust type `$t` to the `DatumType::$v` variant: generates
// `From<$t> for Tensor` (through `tensor0`) and the `Datum` impl for `$t`.
macro_rules! datum {
    ($t:ty, $v:ident) => {
        impl From<$t> for Tensor {
            fn from(it: $t) -> Tensor {
                tensor0(it)
            }
        }

        impl Datum for $t {
            // The Rust type as written at the macro call site.
            fn name() -> &'static str {
                stringify!($t)
            }

            fn datum_type() -> DatumType {
                DatumType::$v
            }

            // Two types are the same datum when their variants are equal.
            fn is<D: Datum>() -> bool {
                Self::datum_type() == D::datum_type()
            }
        }
    };
}
521
// One `datum!` invocation per supported Rust type. The quantized variants
// (`QI8`, `QU8`, `QI32`) have no invocation here; the complex types are
// only registered when the `complex` feature is enabled.
datum!(bool, Bool);
datum!(f16, F16);
datum!(f32, F32);
datum!(f64, F64);
datum!(i8, I8);
datum!(i16, I16);
datum!(i32, I32);
datum!(i64, I64);
datum!(u8, U8);
datum!(u16, U16);
datum!(u32, U32);
datum!(u64, U64);
datum!(TDim, TDim);
datum!(String, String);
datum!(crate::blob::Blob, Blob);
datum!(crate::opaque::Opaque, Opaque);
#[cfg(feature = "complex")]
datum!(Complex<i16>, ComplexI16);
#[cfg(feature = "complex")]
datum!(Complex<i32>, ComplexI32);
#[cfg(feature = "complex")]
datum!(Complex<i64>, ComplexI64);
#[cfg(feature = "complex")]
datum!(Complex<f16>, ComplexF16);
#[cfg(feature = "complex")]
datum!(Complex<f32>, ComplexF32);
#[cfg(feature = "complex")]
datum!(Complex<f64>, ComplexF64);
550
#[cfg(test)]
mod tests {
    use crate::internal::*;
    use ndarray::arr1;

    // An ndarray turned into a tensor can be viewed back as an equal array.
    #[test]
    fn test_array_to_tensor_to_array() {
        let array = arr1(&[12i32, 42]);
        let tensor = Tensor::from(array.clone());
        let view = tensor.to_array_view::<i32>().unwrap();
        assert_eq!(array, view.into_dimensionality().unwrap());
    }

    // Despite its name, this casts TDim -> i32 -> TDim and expects the
    // round trip to give back the original tensor.
    #[test]
    fn test_cast_dim_to_dim() {
        let t_dim: Tensor = tensor1(&[12isize.to_dim(), 42isize.to_dim()]);
        let t_i32 = t_dim.cast_to::<i32>().unwrap();
        let t_dim_2 = t_i32.cast_to::<TDim>().unwrap().into_owned();
        assert_eq!(t_dim, t_dim_2);
    }

    // Only checks that casting i32 to TDim succeeds.
    #[test]
    fn test_cast_i32_to_dim() {
        let t_i32: Tensor = tensor1(&[0i32, 12]);
        t_i32.cast_to::<TDim>().unwrap();
    }

    // Only checks that casting i64 to bool succeeds.
    #[test]
    fn test_cast_i64_to_bool() {
        let t_i64: Tensor = tensor1(&[0i64]);
        t_i64.cast_to::<bool>().unwrap();
    }

    // The quantized textual form parses into the matching QU8 datum type.
    #[test]
    fn test_parse_qu8() {
        assert_eq!(
            "QU8(Z:128 S:0.01)".parse::<DatumType>().unwrap(),
            DatumType::QU8(QParams::ZpScale { zero_point: 128, scale: 0.01 })
        );
    }
}