Skip to main content

tract_core/ops/
quant.rs

1#![allow(clippy::unnecessary_cast)]
2
3use crate::internal::*;
4use crate::ops::element_wise::ElementWiseOp;
5use crate::ops::math::QScale;
6use num_traits::AsPrimitive;
7use tract_linalg::Scaler;
8use tract_linalg::lut::Lut;
9use tract_linalg::mmm::RoundingPolicy;
10
11use super::binary::TypedBinOp;
12use super::math::round_ties_to_even;
13
/// Quantizes `x` to `u8` as `round(x * scale) + zero_point`, clamped to `0..=255`.
///
/// `scale` is applied as a multiplier and rounding is `f32::round` (half away from zero).
/// The float-to-int cast saturates (NaN becomes 0). The addition of `zero_point` also
/// saturates, so huge inputs clamp to the correct end of the range instead of overflowing
/// (which panicked in debug builds and wrapped around in release builds).
pub fn quantize_linear_f32_u8(x: f32, scale: f32, zero_point: i32) -> u8 {
    let shifted = ((x * scale).round() as i32).saturating_add(zero_point);
    shifted.clamp(u8::MIN as i32, u8::MAX as i32) as u8
}
17
/// Quantizes `x` to `i8` as `round(x * scale) + zero_point`, clamped to `-128..=127`.
///
/// `scale` is applied as a multiplier and rounding is `f32::round` (half away from zero).
/// The float-to-int cast saturates (NaN becomes 0). The addition of `zero_point` also
/// saturates, so huge inputs clamp to the correct end of the range instead of overflowing
/// (which panicked in debug builds and wrapped around in release builds).
pub fn quantize_linear_f32_i8(x: f32, scale: f32, zero_point: i32) -> i8 {
    let shifted = ((x * scale).round() as i32).saturating_add(zero_point);
    shifted.clamp(i8::MIN as i32, i8::MAX as i32) as i8
}
21
22element_wise_oop!(quantize_linear_u8,
23 QuantizeLinearU8 {
24     scale: f32,
25     zero_point: u8
26 },
27 [f16] => u8 |op, xs, ys| {
28     xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
29                                           *y = quantize_linear_f32_u8(x.to_f32(), op.scale, op.zero_point as i32)
30                                          );
31     Ok(())
32 },
33 [f32,i32] => u8 |op, xs, ys| {
34     xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
35                                           *y = quantize_linear_f32_u8(*x as f32, op.scale, op.zero_point as i32)
36                                          );
37     Ok(())
38 };
39 info: info_quantize_linear_u8
40);
41
42fn info_quantize_linear_u8(q: &QuantizeLinearU8) -> TractResult<Vec<String>> {
43    Ok(vec![format!(
44        "scale: {} zero_point: {} 1/scale: {}",
45        q.scale,
46        q.zero_point,
47        q.scale.recip()
48    )])
49}
50
51element_wise_oop!(quantize_linear_i8,
52 QuantizeLinearI8 {
53     scale: f32,
54     zero_point: i8
55 },
56 [f32,i32] => i8 |op, xs, ys| {
57     xs.iter().zip(ys.iter_mut()).for_each(|(x,y)|
58                                           *y = quantize_linear_f32_i8(*x as f32, op.scale, op.zero_point as i32)
59                                          );
60     Ok(())
61 };
62 info: info_quantize_linear_i8
63);
64
65fn info_quantize_linear_i8(q: &QuantizeLinearI8) -> TractResult<Vec<String>> {
66    Ok(vec![format!(
67        "scale: {} zero_point: {} 1/scale: {}",
68        q.scale,
69        q.zero_point,
70        q.scale.recip()
71    )])
72}
73
74#[derive(Clone, Debug, new, PartialEq)]
75pub struct DequantizeLinearF32 {
76    pub scale: f32,
77    pub zero_point: i32,
78}
79
80impl Eq for DequantizeLinearF32 {}
81
82impl DequantizeLinearF32 {
83    fn eval_t<T: Datum + AsPrimitive<i32>>(&self, input: &Tensor) -> TractResult<Tensor> {
84        let mut output = unsafe { Tensor::uninitialized::<f32>(input.shape())? };
85        input
86            .try_as_plain()?
87            .as_slice::<T>()?
88            .iter()
89            .zip(output.try_as_plain_mut()?.as_slice_mut::<f32>()?.iter_mut())
90            .for_each(|(x, y)| *y = (x.as_() - self.zero_point) as f32 * self.scale);
91        Ok(output)
92    }
93}
94
95impl Op for DequantizeLinearF32 {
96    fn name(&self) -> StaticName {
97        "DequantizeLinearF32".into()
98    }
99
100    fn info(&self) -> TractResult<Vec<String>> {
101        Ok(vec![format!("scale: {} zero_point: {}", self.scale, self.zero_point)])
102    }
103
104    fn validation(&self) -> Validation {
105        Validation::Accurate
106    }
107
108    op_as_typed_op!();
109}
110
111impl EvalOp for DequantizeLinearF32 {
112    fn is_stateless(&self) -> bool {
113        true
114    }
115    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
116        let output = match inputs[0].datum_type() {
117            DatumType::I8 => self.eval_t::<i8>(&inputs[0])?,
118            DatumType::I32 => self.eval_t::<i32>(&inputs[0])?,
119            DatumType::U8 => self.eval_t::<u8>(&inputs[0])?,
120            dt => bail!("Unsupported type {:?}", dt),
121        };
122        Ok(tvec!(output.into_tvalue()))
123    }
124}
125
126impl TypedOp for DequantizeLinearF32 {
127    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
128        let mut fact = inputs[0].clone();
129        fact.datum_type = f32::datum_type();
130        Ok(tvec!(fact))
131    }
132
133    fn axes_mapping(
134        &self,
135        inputs: &[&TypedFact],
136        outputs: &[&TypedFact],
137    ) -> TractResult<AxesMapping> {
138        AxesMapping::natural(inputs, outputs)
139    }
140
141    fn change_axes(
142        &self,
143        model: &TypedModel,
144        node: &TypedNode,
145        _io: InOut,
146        change: &AxisOp,
147    ) -> TractResult<Option<AxisChangeConsequence>> {
148        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
149    }
150
151    fn declutter(
152        &self,
153        model: &TypedModel,
154        dequant: &TypedNode,
155    ) -> TractResult<Option<TypedModelPatch>> {
156        let mut current = dequant;
157        let incoming_dt = model.node_input_facts(dequant.id)?[0].datum_type;
158        while let Some(quant) = model.single_succ(current.id)? {
159            let q_params = if let Some(op) = quant.op_as::<ElementWiseOp>() {
160                if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
161                    Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
162                } else {
163                    op.0.downcast_ref::<QuantizeLinearI8>()
164                        .map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
165                }
166            } else {
167                None
168            };
169            if let Some((scale, zero_point, dt)) = q_params {
170                // first, try Op::quantize() on all ops in the chain
171                let mut patch = TypedModelPatch::default();
172                let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
173                let mut next = model.single_succ(dequant.id)?.unwrap();
174                loop {
175                    if let Some(op) = next
176                        .op
177                        .quantize(model, dequant, dt, scale, zero_point)
178                        .with_context(|| format!("Quantizing {next}"))?
179                    {
180                        wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
181                    } else {
182                        break;
183                    }
184                    if next.id == current.id {
185                        patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
186                        return Ok(Some(patch));
187                    } else {
188                        next = model.single_succ(next.id)?.unwrap();
189                    }
190                }
191                // or else make a lookup table
192                if incoming_dt == DatumType::I8 || incoming_dt == DatumType::U8 {
193                    let mut adhoc_model = TypedModel::default();
194                    let mut wire = adhoc_model.add_source("ad-hoc", dt.fact([256]))?;
195                    let mut next = model.single_succ(dequant.id)?.unwrap();
196                    let mut name = None;
197                    // plug in dequant
198                    wire = adhoc_model.wire_node(
199                        &*dequant.name,
200                        dequant.op.clone(),
201                        [wire].as_ref(),
202                    )?[0];
203                    while next.id != quant.id {
204                        name.get_or_insert(&*next.name);
205                        wire =
206                            adhoc_model.wire_node(&*next.name, next.op.clone(), [wire].as_ref())?
207                                [0];
208                        next = model.single_succ(next.id)?.unwrap();
209                    }
210                    // plug in quant
211                    wire =
212                        adhoc_model.wire_node(&*quant.name, quant.op.clone(), [wire].as_ref())?[0];
213                    adhoc_model.select_output_outlets(&[wire])?;
214                    let input = (0u8..=255).collect::<Vec<u8>>();
215                    let input = match dt {
216                        DatumType::I8 => unsafe {
217                            tensor1(std::mem::transmute::<&[u8], &[i8]>(&*input))
218                        },
219                        DatumType::U8 => tensor1(&input),
220                        _ => unreachable!(),
221                    };
222                    let output =
223                        SimplePlan::new(adhoc_model)?.run(tvec!(input.into_tvalue()))?.remove(0);
224                    let table: &[u8] = match dt {
225                        DatumType::I8 => unsafe {
226                            std::mem::transmute::<&[i8], &[u8]>(
227                                output.try_as_plain()?.as_slice::<i8>()?,
228                            )
229                        },
230                        DatumType::U8 => output.try_as_plain()?.as_slice::<u8>()?,
231                        _ => unreachable!(),
232                    };
233                    let op = lookup_table((tract_linalg::ops().lut_u8)(table));
234                    let mut patch = TypedModelPatch::default();
235                    let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
236
237                    wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
238                    patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
239                    return Ok(Some(patch));
240                }
241            }
242            let (input_facts, output_facts) = model.node_facts(quant.id)?;
243            let invariants = quant
244                .op
245                .axes_mapping(&input_facts, &output_facts)
246                .with_context(|| format!("Querying invariants for {quant}"))?;
247            if invariants.is_element_wise_unary() {
248                current = quant;
249            } else {
250                break;
251            }
252        }
253        Ok(None)
254    }
255
256    as_op!();
257}
258
259element_wise_oop!(lookup_table,
260 LookupTable {
261     table: Box<dyn Lut>
262 },
263 [i8] => i8 |op, xs, ys| {
264     ys.copy_from_slice(xs);
265     unsafe {
266         let casted = std::slice::from_raw_parts_mut(ys.as_mut_ptr() as *mut u8, ys.len());
267         op.table.run(casted);
268     }
269     Ok(())
270 },
271 [u8] => u8 |op, xs, ys| {
272     ys.copy_from_slice(xs);
273     op.table.run(ys);
274     Ok(())
275 }
276);
277
278#[derive(Debug, Clone, Hash, PartialEq, Eq)]
279pub struct Scale;
280
281impl crate::ops::binary::BinMiniOp for Scale {
282    fn name(&self) -> &'static str {
283        "Scale"
284    }
285    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
286        if !a.is_float() {
287            bail!("Scale left operand must be float, got {:?}", a);
288        }
289        Ok(b)
290    }
291
292    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
293        if !a.is_float() {
294            bail!("Scale left operand must be float, got {:?}", a);
295        }
296        Ok(b)
297    }
298
299    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
300        let a = a.cast_to::<f32>()?;
301        let a = a.to_plain_array_view::<f32>()?;
302        unsafe fn eval_out_of_place_t<T: Datum + AsPrimitive<f32>>(
303            c: &mut Tensor,
304            a: &ndarray::ArrayViewD<f32>,
305            b: &Tensor,
306        ) where
307            f32: AsPrimitive<T>,
308        {
309            let b = unsafe { b.to_array_view_unchecked::<T>() };
310            let mut c = unsafe { c.to_array_view_mut_unchecked::<T>() };
311            ndarray::Zip::from(&mut c)
312                .and_broadcast(a)
313                .and_broadcast(b)
314                .for_each(|c, a, b| *c = scale_by(*b, *a))
315        }
316        unsafe { dispatch_numbers!(eval_out_of_place_t(b.datum_type())(c, &a, b)) }
317        Ok(())
318    }
319
320    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
321        let mut a_plain = a.try_as_plain_mut()?;
322        let a = a_plain.to_array_view_mut::<f32>()?;
323        let b = b.to_plain_array_view::<f32>()?;
324        ndarray::Zip::from(a).and_broadcast(b).for_each(|a, b| *a = scale_by(*b, *a));
325        Ok(())
326    }
327
328    fn is_commutative(&self) -> bool {
329        false
330    }
331
332    fn declutter(
333        &self,
334        model: &TypedModel,
335        node: &TypedNode,
336    ) -> TractResult<Option<TypedModelPatch>> {
337        let a = model.outlet_fact(node.inputs[0])?;
338        if let Some(a) = &a.uniform {
339            if a.cast_to_scalar::<f32>()? == 1. {
340                return Ok(Some(TypedModelPatch::rewire(
341                    model,
342                    &node.inputs[1..2],
343                    &[node.id.into()],
344                    &|_p, x| Ok(x.into()),
345                )?));
346            } else if node.outputs[0].fact.datum_type == DatumType::I32 {
347                let factor = a.cast_to_scalar::<f32>()?;
348                let scaler = Scaler::new(factor, RoundingPolicy::Even);
349
350                let op = ElementWiseOp(Box::new(QScale { scaler }), None);
351                let patch =
352                    TypedModelPatch::replace_single_op(model, node, &node.inputs[1..2], op)?;
353
354                return Ok(Some(patch));
355            }
356        }
357        Ok(None)
358    }
359}
360
361#[inline]
362pub(crate) fn scale_by<T: Datum + AsPrimitive<f32>>(b: T, a: f32) -> T
363where
364    f32: AsPrimitive<T>,
365{
366    let b = b.as_();
367    (round_ties_to_even(b.abs() * a) * b.signum()).as_()
368}
369
370pub fn scale() -> TypedBinOp {
371    TypedBinOp(Box::new(Scale), None)
372}
373
/// Offsets i8 integers as u8 integers.
///
/// Shifts the value range by +128 (`-128..=127` onto `0..=255`), which is the same as
/// flipping the sign bit of the two's-complement byte.
pub(crate) fn offset_i8_as_u8_elementwise(x: i8) -> u8 {
    (x as u8) ^ 0x80
}
378
379#[derive(Debug, Clone, PartialEq, Eq)]
380pub struct OffsetI8asU8;
381impl ElementWiseMiniOp for OffsetI8asU8 {
382    fn name(&self) -> String {
383        format!("{}{}", self.prefix(), stringify!(OffsetI8asU8))
384    }
385    fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
386        Some(if let DatumType::QI8(qp) = input_type {
387            let (zp, scale) = qp.zp_scale();
388            DatumType::QU8(QParams::ZpScale { zero_point: zp + 128, scale })
389        } else if input_type == DatumType::I8 {
390            DatumType::U8
391        } else {
392            input_type
393        })
394    }
395    fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
396        let output_type = out_dt.unwrap_or(self.output_type(t.datum_type()).unwrap());
397        let mut dst = unsafe { Tensor::uninitialized_dt(output_type, t.shape())? };
398        if t.datum_type().unquantized() == i8::datum_type() {
399            t.try_as_plain()?
400                .as_slice::<i8>()?
401                .iter()
402                .zip(dst.try_as_plain_mut()?.as_slice_mut::<u8>()?.iter_mut())
403                .for_each(|(x, y)| *y = offset_i8_as_u8_elementwise(*x));
404            return Ok(dst);
405        }
406
407        bail!("{} does not support {:?}", self.name(), t.datum_type());
408    }
409}
410
411pub fn offset_i8_as_u8() -> ElementWiseOp {
412    ElementWiseOp(Box::new(OffsetI8asU8 {}), None)
413}
414
/// Offsets u8 integers as i8 integers.
///
/// Shifts the value range by -128 (`0..=255` onto `-128..=127`), which is the same as
/// flipping the sign bit of the reinterpreted byte.
pub(crate) fn offset_u8_as_i8_elementwise(x: u8) -> i8 {
    (x as i8) ^ i8::MIN
}
419
420#[derive(Debug, Clone, PartialEq, Eq)]
421pub struct OffsetU8asI8;
422impl ElementWiseMiniOp for OffsetU8asI8 {
423    fn name(&self) -> String {
424        format!("{}{}", self.prefix(), stringify!(OffsetU8asI8))
425    }
426    fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
427        Some(if let DatumType::QU8(qp) = input_type {
428            let (zp, scale) = qp.zp_scale();
429            DatumType::QI8(QParams::ZpScale { zero_point: zp - 128, scale })
430        } else if input_type == DatumType::U8 {
431            DatumType::I8
432        } else {
433            input_type
434        })
435    }
436    fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
437        let output_type = out_dt.unwrap_or(self.output_type(t.datum_type()).unwrap());
438        let mut dst = unsafe { Tensor::uninitialized_dt(output_type, t.shape())? };
439        if t.datum_type().unquantized() == u8::datum_type() {
440            t.try_as_plain()?
441                .as_slice::<u8>()?
442                .iter()
443                .zip(dst.try_as_plain_mut()?.as_slice_mut::<i8>()?.iter_mut())
444                .for_each(|(x, y)| *y = offset_u8_as_i8_elementwise(*x));
445            return Ok(dst);
446        }
447
448        bail!("{} does not support {:?}", self.name(), t.datum_type());
449    }
450}
451pub fn offset_u8_as_i8() -> ElementWiseOp {
452    ElementWiseOp(Box::new(OffsetU8asI8 {}), None)
453}
454
#[cfg(test)]
pub mod scale {
    use crate::internal::*;
    use crate::ops::einsum::EinSum;
    use crate::ops::math::round_ties_to_even;
    use proptest::prelude::*;

    /// Reference result for `a * b / scale`: rounded to nearest with ties to even (applied to
    /// the magnitude, then the sign restored), saturated to the `i8` range.
    fn reference(a: i8, b: i8, scale: f32) -> i8 {
        let exact = (((a as i32) * (b as i32)) as f32) / scale;
        let rounded = round_ties_to_even(exact.abs()) * exact.signum();
        (rounded as i32).clamp(-128, 127) as i8
    }

    /// Builds a 1x1 quantized `EinSum` matmul (all zero points 0, input scales 1, output scale
    /// `scale`) and checks both the plain and the optimized plans against `reference`.
    fn test_scale(a: i8, b: i8, scale: f32) {
        let expected = tensor2(&[[reference(a, b, scale)]]);
        let inputs = tvec!(tensor2(&[[b]]).into_tvalue());

        let mut model = TypedModel::default();
        let lhs = model.add_const("a", tensor2(&[[a]])).unwrap();
        let rhs = model.add_source("b", i8::fact([1, 1])).unwrap();
        let bias = model.add_const("bias", tensor0(0i32)).unwrap();
        let a0 = model.add_const("a0", tensor0(0i8)).unwrap();
        let a_scale = model.add_const("a_scale", tensor0(1f32)).unwrap();
        let b0 = model.add_const("b0", tensor0(0i8)).unwrap();
        let b_scale = model.add_const("b_scale", tensor0(1f32)).unwrap();
        let c0 = model.add_const("c0", tensor0(0i8)).unwrap();
        let c_scale = model.add_const("c_scale", tensor0(scale)).unwrap();
        let einsum = EinSum {
            axes: "mk,kn,,,,,,,->mn".parse().unwrap(),
            operating_dt: i32::datum_type(),
            q_params: Some(i8::datum_type()),
        };
        let outputs = model
            .wire_node("mmm", einsum, &[lhs, rhs, bias, a0, a_scale, b0, b_scale, c0, c_scale])
            .unwrap();
        model.select_output_outlets(&outputs).unwrap();

        let plain_run = model.clone().into_runnable().unwrap().run(inputs.clone()).unwrap();
        assert_eq!(*plain_run[0], expected);

        let optimized_run =
            model.into_optimized().unwrap().into_runnable().unwrap().run(inputs).unwrap();
        assert_eq!(*optimized_run[0], expected);
    }

    proptest! {
        #[test]
        fn prop(a in any::<i8>(), b in any::<i8>(), scale in 0.00001f32..1000.) {
            test_scale(a, b, scale);
        }
    }

    #[test]
    fn t1() {
        test_scale(-117, 15, 37.753822);
    }

    #[test]
    fn t2() {
        test_scale(-4, -60, 475.21674);
    }
}