tract_core/ops/quant.rs

1#![allow(clippy::unnecessary_cast)]
2
3use crate::internal::*;
4use crate::ops::element_wise::ElementWiseOp;
5use crate::ops::math::QScale;
6use num_traits::AsPrimitive;
7use tract_linalg::lut::Lut;
8use tract_linalg::mmm::RoundingPolicy;
9use tract_linalg::Scaler;
10
11use super::binary::TypedBinOp;
12use super::math::round_ties_to_even;
13
/// Quantizes `x` to u8 as `round(x * scale) + zero_point`, saturated to `[0, 255]`.
///
/// Rounding uses `f32::round` (ties away from zero). NaN maps to `zero_point` (clamped),
/// and huge or infinite inputs saturate to the nearest bound instead of overflowing.
pub fn quantize_linear_f32_u8(x: f32, scale: f32, zero_point: i32) -> u8 {
    // `as i32` already saturates out-of-range floats; `saturating_add` keeps the
    // zero-point shift from overflowing i32 so +/-inf still clamp to the right bound.
    ((x * scale).round() as i32)
        .saturating_add(zero_point)
        .clamp(u8::MIN as i32, u8::MAX as i32) as u8
}
18
/// Quantizes `x` to i8 as `round(x * scale) + zero_point`, saturated to `[-128, 127]`.
///
/// Rounding uses `f32::round` (ties away from zero). NaN maps to `zero_point` (clamped),
/// and huge or infinite inputs saturate to the nearest bound instead of overflowing.
pub fn quantize_linear_f32_i8(x: f32, scale: f32, zero_point: i32) -> i8 {
    // `as i32` already saturates out-of-range floats; `saturating_add` keeps a positive
    // or negative zero point from overflowing i32 at the extremes.
    ((x * scale).round() as i32)
        .saturating_add(zero_point)
        .clamp(i8::MIN as i32, i8::MAX as i32) as i8
}
23
24element_wise_oop!(quantize_linear_u8,
25 QuantizeLinearU8 {
26     scale: f32,
27     zero_point: u8
28 },
29 [f16] => u8 |op, xs, ys| {
30     xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
31                                           *y = quantize_linear_f32_u8(x.to_f32(), op.scale, op.zero_point as i32)
32                                          );
33     Ok(())
34 },
35 [f32,i32] => u8 |op, xs, ys| {
36     xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
37                                           *y = quantize_linear_f32_u8(*x as f32, op.scale, op.zero_point as i32)
38                                          );
39     Ok(())
40 };
41 info: info_quantize_linear_u8
42);
43
44fn info_quantize_linear_u8(q: &QuantizeLinearU8) -> TractResult<Vec<String>> {
45    Ok(vec![format!(
46        "scale: {} zero_point: {} 1/scale: {}",
47        q.scale,
48        q.zero_point,
49        q.scale.recip()
50    )])
51}
52
53element_wise_oop!(quantize_linear_i8,
54 QuantizeLinearI8 {
55     scale: f32,
56     zero_point: i8
57 },
58 [f32,i32] => i8 |op, xs, ys| {
59     xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
60                                           *y = quantize_linear_f32_i8(*x as f32, op.scale, op.zero_point as i32)
61                                          );
62     Ok(())
63 };
64 info: info_quantize_linear_i8
65);
66
67fn info_quantize_linear_i8(q: &QuantizeLinearI8) -> TractResult<Vec<String>> {
68    Ok(vec![format!(
69        "scale: {} zero_point: {} 1/scale: {}",
70        q.scale,
71        q.zero_point,
72        q.scale.recip()
73    )])
74}
75
/// Converts integer tensors to f32, computing `(x - zero_point) * scale` per element.
#[derive(Clone, Debug, new)]
pub struct DequantizeLinearF32 {
    /// Multiplier applied once the zero point has been subtracted.
    pub scale: f32,
    /// Integer value that dequantizes to 0.0.
    pub zero_point: i32,
}
81
82impl DequantizeLinearF32 {
83    fn eval_t<T: Datum + AsPrimitive<i32>>(&self, input: &Tensor) -> TractResult<Tensor> {
84        let mut output = unsafe { Tensor::uninitialized::<f32>(input.shape())? };
85        input
86            .as_slice::<T>()?
87            .iter()
88            .zip(output.as_slice_mut::<f32>()?.iter_mut())
89            .for_each(|(x, y)| *y = (x.as_() - self.zero_point) as f32 * self.scale);
90        Ok(output)
91    }
92}
93
94impl Op for DequantizeLinearF32 {
95    fn name(&self) -> Cow<str> {
96        "DequantizeLinearF32".into()
97    }
98
99    fn info(&self) -> TractResult<Vec<String>> {
100        Ok(vec![format!("scale: {} zero_point: {}", self.scale, self.zero_point)])
101    }
102
103    fn validation(&self) -> Validation {
104        Validation::Accurate
105    }
106
107    op_as_typed_op!();
108}
109
110impl EvalOp for DequantizeLinearF32 {
111    fn is_stateless(&self) -> bool {
112        true
113    }
114    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
115        let output = match inputs[0].datum_type() {
116            DatumType::I8 => self.eval_t::<i8>(&inputs[0])?,
117            DatumType::I32 => self.eval_t::<i32>(&inputs[0])?,
118            DatumType::U8 => self.eval_t::<u8>(&inputs[0])?,
119            dt => bail!("Unsupported type {:?}", dt),
120        };
121        Ok(tvec!(output.into_tvalue()))
122    }
123}
124
125impl TypedOp for DequantizeLinearF32 {
126    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
127        let mut fact = inputs[0].clone();
128        fact.datum_type = f32::datum_type();
129        Ok(tvec!(fact))
130    }
131
132    fn axes_mapping(
133        &self,
134        inputs: &[&TypedFact],
135        outputs: &[&TypedFact],
136    ) -> TractResult<AxesMapping> {
137        AxesMapping::natural(inputs, outputs)
138    }
139
140    fn change_axes(
141        &self,
142        model: &TypedModel,
143        node: &TypedNode,
144        _io: InOut,
145        change: &AxisOp,
146    ) -> TractResult<Option<AxisChangeConsequence>> {
147        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
148    }
149
150    fn declutter(
151        &self,
152        model: &TypedModel,
153        dequant: &TypedNode,
154    ) -> TractResult<Option<TypedModelPatch>> {
155        let mut current = dequant;
156        let incoming_dt = model.node_input_facts(dequant.id)?[0].datum_type;
157        while let Some(quant) = model.single_succ(current.id)? {
158            let q_params = if let Some(op) = quant.op_as::<ElementWiseOp>() {
159                if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
160                    Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
161                } else {
162                    op.0.downcast_ref::<QuantizeLinearI8>()
163                        .map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
164                }
165            } else {
166                None
167            };
168            if let Some((scale, zero_point, dt)) = q_params {
169                // first, try Op::quantize() on all ops in the chain
170                let mut patch = TypedModelPatch::default();
171                let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
172                let mut next = model.single_succ(dequant.id)?.unwrap();
173                loop {
174                    if let Some(op) = next
175                        .op
176                        .quantize(model, dequant, dt, scale, zero_point)
177                        .with_context(|| format!("Quantizing {next}"))?
178                    {
179                        wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
180                    } else {
181                        break;
182                    }
183                    if next.id == current.id {
184                        patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
185                        return Ok(Some(patch));
186                    } else {
187                        next = model.single_succ(next.id)?.unwrap();
188                    }
189                }
190                // or else make a lookup table
191                if incoming_dt == DatumType::I8 || incoming_dt == DatumType::U8 {
192                    let mut adhoc_model = TypedModel::default();
193                    let mut wire = adhoc_model.add_source("ad-hoc", dt.fact([256]))?;
194                    let mut next = model.single_succ(dequant.id)?.unwrap();
195                    let mut name = None;
196                    // plug in dequant
197                    wire = adhoc_model.wire_node(
198                        &*dequant.name,
199                        dequant.op.clone(),
200                        [wire].as_ref(),
201                    )?[0];
202                    while next.id != quant.id {
203                        name.get_or_insert(&*next.name);
204                        wire =
205                            adhoc_model.wire_node(&*next.name, next.op.clone(), [wire].as_ref())?
206                                [0];
207                        next = model.single_succ(next.id)?.unwrap();
208                    }
209                    // plug in quant
210                    wire =
211                        adhoc_model.wire_node(&*quant.name, quant.op.clone(), [wire].as_ref())?[0];
212                    adhoc_model.set_output_outlets(&[wire])?;
213                    let input = (0u8..=255).collect::<Vec<u8>>();
214                    let input = match dt {
215                        DatumType::I8 => unsafe {
216                            tensor1(std::mem::transmute::<&[u8], &[i8]>(&*input))
217                        },
218                        DatumType::U8 => tensor1(&input),
219                        _ => unreachable!(),
220                    };
221                    let output =
222                        SimplePlan::new(adhoc_model)?.run(tvec!(input.into_tvalue()))?.remove(0);
223                    let table: &[u8] = match dt {
224                        DatumType::I8 => unsafe { std::mem::transmute::<&[i8], &[u8]>(output.as_slice::<i8>()?) },
225                        DatumType::U8 => output.as_slice::<u8>()?,
226                        _ => unreachable!(),
227                    };
228                    let op = lookup_table((tract_linalg::ops().lut_u8)(table));
229                    let mut patch = TypedModelPatch::default();
230                    let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
231
232                    wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
233                    patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
234                    return Ok(Some(patch));
235                }
236            }
237            let (input_facts, output_facts) = model.node_facts(quant.id)?;
238            let invariants = quant
239                .op
240                .axes_mapping(&input_facts, &output_facts)
241                .with_context(|| format!("Querying invariants for {quant}"))?;
242            if invariants.is_element_wise_unary() {
243                current = quant;
244            } else {
245                break;
246            }
247        }
248        Ok(None)
249    }
250
251    as_op!();
252}
253
// Element-wise op applying a precomputed byte lookup table (`Lut`) to 8-bit data.
// The output has the same datum type as the input: the input is copied into the output
// buffer, which `Lut::run` then rewrites in place.
element_wise_oop!(lookup_table,
 LookupTable {
     table: Box<dyn Lut>
 },
 [i8] => i8 |op, xs, ys| {
     ys.copy_from_slice(xs);
     // SAFETY: i8 and u8 have identical size and alignment, and `ys` is a valid,
     // exclusively borrowed slice of `ys.len()` elements, so viewing the same memory
     // as `&mut [u8]` for the duration of `run` is sound.
     unsafe {
         let casted = std::slice::from_raw_parts_mut(ys.as_mut_ptr() as *mut u8, ys.len());
         op.table.run(casted);
     }
     Ok(())
 },
 [u8] => u8 |op, xs, ys| {
     ys.copy_from_slice(xs);
     op.table.run(ys);
     Ok(())
 }
);
272
/// Binary mini-op multiplying the right operand `b` by the float factor `a` (see
/// `scale_by`): the magnitude is rounded with `round_ties_to_even`, and the result keeps
/// `b`'s datum type.
#[derive(Debug, Clone, Hash)]
pub struct Scale;
275
276impl crate::ops::binary::BinMiniOp for Scale {
277    fn name(&self) -> &'static str {
278        "Scale"
279    }
280    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
281        if !a.is_float() {
282            bail!("Scale left operand must be float, got {:?}", a);
283        }
284        Ok(b)
285    }
286
287    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
288        if !a.is_float() {
289            bail!("Scale left operand must be float, got {:?}", a);
290        }
291        Ok(b)
292    }
293
294    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
295        let a = a.cast_to::<f32>()?;
296        let a = a.to_array_view::<f32>()?;
297        unsafe fn eval_out_of_place_t<T: Datum + AsPrimitive<f32>>(
298            c: &mut Tensor,
299            a: &ndarray::ArrayViewD<f32>,
300            b: &Tensor,
301        ) where
302            f32: AsPrimitive<T>,
303        {
304            let b = b.to_array_view_unchecked::<T>();
305            let mut c = c.to_array_view_mut_unchecked::<T>();
306            ndarray::Zip::from(&mut c)
307                .and_broadcast(a)
308                .and_broadcast(b)
309                .for_each(|c, a, b| *c = scale_by(*b, *a))
310        }
311        unsafe { dispatch_numbers!(eval_out_of_place_t(b.datum_type())(c, &a, b)) }
312        Ok(())
313    }
314
315    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
316        let a = a.to_array_view_mut::<f32>()?;
317        let b = b.to_array_view::<f32>()?;
318        ndarray::Zip::from(a).and_broadcast(b).for_each(|a, b| *a = scale_by(*b, *a));
319        Ok(())
320    }
321
322    fn is_commutative(&self) -> bool {
323        false
324    }
325
326    fn declutter(
327        &self,
328        model: &TypedModel,
329        node: &TypedNode,
330    ) -> TractResult<Option<TypedModelPatch>> {
331        let a = model.outlet_fact(node.inputs[0])?;
332        if let Some(a) = &a.uniform {
333            if a.cast_to_scalar::<f32>()? == 1. {
334                return Ok(Some(TypedModelPatch::rewire(
335                    model,
336                    &node.inputs[1..2],
337                    &[node.id.into()],
338                    &|_p, x| Ok(x.into()),
339                )?));
340            } else if node.outputs[0].fact.datum_type == DatumType::I32 {
341                let factor = a.cast_to_scalar::<f32>()?;
342                let scaler = Scaler::new(factor, RoundingPolicy::Even);
343
344                let op = ElementWiseOp(Box::new(QScale { scaler }), None);
345                let patch =
346                    TypedModelPatch::replace_single_op(model, node, &node.inputs[1..2], op)?;
347
348                return Ok(Some(patch));
349            }
350        }
351        Ok(None)
352    }
353}
354
355#[inline]
356pub(crate) fn scale_by<T: Datum + AsPrimitive<f32>>(b: T, a: f32) -> T
357where
358    f32: AsPrimitive<T>,
359{
360    let b = b.as_();
361    (round_ties_to_even(b.abs() * a) * b.signum()).as_()
362}
363
364pub fn scale() -> TypedBinOp {
365    TypedBinOp(Box::new(Scale), None)
366}
367
/// Offsets i8 integers as u8 integers: maps `x` to `x + 128`, so -128 becomes 0 and
/// 127 becomes 255.
pub(crate) fn offset_i8_as_u8_elementwise(x: i8) -> u8 {
    // Adding 128 modulo 256 is the same as flipping the top bit of the raw byte.
    (x as u8) ^ 0x80
}
372
373#[derive(Debug, Clone)]
374pub struct OffsetI8asU8 {}
375impl ElementWiseMiniOp for OffsetI8asU8 {
376    fn name(&self) -> String {
377        format!("{}{}", self.prefix(), stringify!(OffsetI8asU8))
378    }
379    fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
380        Some(if let DatumType::QI8(qp) = input_type {
381            let (zp, scale) = qp.zp_scale();
382            DatumType::QU8(QParams::ZpScale { zero_point: zp + 128, scale })
383        } else if input_type == DatumType::I8 {
384            DatumType::U8
385        } else {
386            input_type
387        })
388    }
389    fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
390        let output_type = out_dt.unwrap_or(self.output_type(t.datum_type()).unwrap());
391        let mut dst = unsafe { Tensor::uninitialized_dt(output_type, t.shape())? };
392        if t.datum_type().unquantized() == i8::datum_type() {
393            t.as_slice::<i8>()?
394                .iter()
395                .zip(dst.as_slice_mut::<u8>()?.iter_mut())
396                .for_each(|(x, y)| *y = offset_i8_as_u8_elementwise(*x));
397            return Ok(dst);
398        }
399
400        bail!("{} does not support {:?}", self.name(), t.datum_type());
401    }
402}
403
404pub fn offset_i8_as_u8() -> ElementWiseOp {
405    ElementWiseOp(Box::new(OffsetI8asU8 {}), None)
406}
407
/// Offsets u8 integers as i8 integers: maps `x` to `x - 128`, so 0 becomes -128 and
/// 255 becomes 127.
pub(crate) fn offset_u8_as_i8_elementwise(x: u8) -> i8 {
    // Subtracting 128 modulo 256 is the same as flipping the top bit of the raw byte.
    (x ^ 0x80) as i8
}
412
413#[derive(Debug, Clone)]
414pub struct OffsetU8asI8 {}
415impl ElementWiseMiniOp for OffsetU8asI8 {
416    fn name(&self) -> String {
417        format!("{}{}", self.prefix(), stringify!(OffsetU8asI8))
418    }
419    fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
420        Some(if let DatumType::QU8(qp) = input_type {
421            let (zp, scale) = qp.zp_scale();
422            DatumType::QI8(QParams::ZpScale { zero_point: zp - 128, scale })
423        } else if input_type == DatumType::U8 {
424            DatumType::I8
425        } else {
426            input_type
427        })
428    }
429    fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
430        let output_type = out_dt.unwrap_or(self.output_type(t.datum_type()).unwrap());
431        let mut dst = unsafe { Tensor::uninitialized_dt(output_type, t.shape())? };
432        if t.datum_type().unquantized() == u8::datum_type() {
433            t.as_slice::<u8>()?
434                .iter()
435                .zip(dst.as_slice_mut::<i8>()?.iter_mut())
436                .for_each(|(x, y)| *y = offset_u8_as_i8_elementwise(*x));
437            return Ok(dst);
438        }
439
440        bail!("{} does not support {:?}", self.name(), t.datum_type());
441    }
442}
443pub fn offset_u8_as_i8() -> ElementWiseOp {
444    ElementWiseOp(Box::new(OffsetU8asI8 {}), None)
445}
446
#[cfg(test)]
pub mod scale {
    // Tests for the output rescaling of a quantized i8 x i8 EinSum, checked against a
    // scalar reference both before and after model optimization.
    use crate::internal::*;
    use crate::ops::einsum::EinSum;
    use crate::ops::math::round_ties_to_even;
    use proptest::prelude::*;

    /// Builds a 1x1 quantized matmul of `a` and `b` (zero points 0, input scales 1, output
    /// scale `scale`, zero bias) and checks that the plain and optimized runs both return
    /// `round_ties_to_even(|a * b / scale|)`, with the sign restored and clamped to i8.
    fn test_scale(a: i8, b: i8, scale: f32) {
        // Reference result, computed in f32 and then saturated to the i8 range.
        let expected = (((a as i32) * (b as i32)) as f32) / scale;
        let expected = round_ties_to_even(expected.abs()) * expected.signum();
        let expected = (expected as i32).clamp(-128, 127);
        let expected = tensor2(&[[expected as i8]]);

        let input = tvec!(tensor2(&[[b]]).into_tvalue());
        let mut model = TypedModel::default();
        let a = model.add_const("a", tensor2(&[[a]])).unwrap();
        let b = model.add_source("b", i8::fact([1, 1])).unwrap();
        let bias = model.add_const("bias", tensor0(0i32)).unwrap();
        let a0 = model.add_const("a0", tensor0(0i8)).unwrap();
        let a_scale = model.add_const("a_scale", tensor0(1f32)).unwrap();
        let b0 = model.add_const("b0", tensor0(0i8)).unwrap();
        let b_scale = model.add_const("b_scale", tensor0(1f32)).unwrap();
        let c0 = model.add_const("c0", tensor0(0i8)).unwrap();
        let c_scale = model.add_const("c_scale", tensor0(scale)).unwrap();
        // Matmul "mk,kn->mn"; the seven empty axis groups are the scalar bias and
        // quantization-parameter inputs wired below.
        let op = EinSum {
            axes: "mk,kn,,,,,,,->mn".parse().unwrap(),
            operating_dt: i32::datum_type(),
            q_params: Some(i8::datum_type()),
        };
        let output = model
            .wire_node("mmm", op, &[a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale])
            .unwrap();
        model.set_output_outlets(&output).unwrap();

        let plain = model.clone().into_runnable().unwrap().run(input.clone()).unwrap();
        assert_eq!(*plain[0], expected);

        let optim = model.into_optimized().unwrap().into_runnable().unwrap().run(input).unwrap();
        assert_eq!(*optim[0], expected);
    }

    proptest! {
        // Random operands over the full i8 range, with output scales spanning eight
        // orders of magnitude.
        #[test]
        fn prop(a in any::<i8>(), b in any::<i8>(), scale in 0.00001f32..1000.) {
            test_scale(a, b, scale);
        }
    }

    // Fixed cases, kept as explicit regressions.
    #[test]
    fn t1() {
        test_scale(-117, 15, 37.753822);
    }

    #[test]
    fn t2() {
        test_scale(-4, -60, 475.21674);
    }
}
504}