tract_tensorflow/ops/
quant.rs

1use tract_hir::internal::*;
2use tract_hir::ops;
3use tract_hir::ops::math::round_ties_to_even;
4
5use crate::model::ParsingContext;
6use crate::model::TfOpRegister;
7use crate::tfpb::tensorflow::NodeDef;
8
9pub fn register_all_ops(reg: &mut TfOpRegister) {
10    reg.insert("FakeQuantWithMinMaxVars", fake_quant_with_min_max_vars);
11}
12
13fn fake_quant_with_min_max_vars(
14    _ctx: &ParsingContext,
15    node: &NodeDef,
16) -> TractResult<Box<dyn InferenceOp>> {
17    let narrow_range = node.get_attr_bool("narrow_range")?;
18    let num_bits = node.get_attr_int("num_bits")?;
19    Ok(expand(FakeQuantWithMinMaxVars::new(narrow_range, num_bits)))
20}
21
22#[derive(Clone, Debug, new, Hash)]
23struct FakeQuantWithMinMaxVars {
24    narrow_range: bool,
25    num_bits: usize,
26}
27
28
29
30impl FakeQuantWithMinMaxVars {
31    fn step(&self, min: &Tensor, max: &Tensor) -> TractResult<f32> {
32        let min = min.to_scalar::<f32>()?;
33        let max = max.to_scalar::<f32>()?;
34        let amplitude = max - min;
35        let scale_len = 2_usize.pow(self.num_bits as u32) - 1 - self.narrow_range as usize;
36        Ok(amplitude / scale_len as f32)
37    }
38}
39
40impl Expansion for FakeQuantWithMinMaxVars {
41    fn name(&self) -> StaticName {
42        "FakeQuantWithMinMaxVars".into()
43    }
44
45    fn rules<'r, 'p: 'r, 's: 'r>(
46        &'s self,
47        s: &mut Solver<'r>,
48        inputs: &'p [TensorProxy],
49        outputs: &'p [TensorProxy],
50    ) -> InferenceResult {
51        check_input_arity(inputs, 3)?;
52        check_output_arity(outputs, 1)?;
53        s.equals(&inputs[0].datum_type, &inputs[1].datum_type)?;
54        s.equals(&inputs[0].datum_type, &inputs[2].datum_type)?;
55        s.equals(&inputs[1].shape, shapefactoid!())?;
56        s.equals(&inputs[2].shape, shapefactoid!())?;
57        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
58        s.equals(&inputs[0].shape, &outputs[0].shape)?;
59        Ok(())
60    }
61
62    fn wire(
63        &self,
64        prefix: &str,
65        target: &mut TypedModel,
66        inputs: &[OutletId],
67    ) -> TractResult<TVec<OutletId>> {
68        if let (Some(min), Some(max)) = (
69            target.outlet_fact(inputs[1])?.konst.as_ref(),
70            target.outlet_fact(inputs[2])?.konst.as_ref(),
71        ) {
72            let rank = target.outlet_fact(inputs[0])?.rank();
73            macro_rules! cst {
74                ($id:ident, $value: expr) => {
75                    let $id = tensor0($value).broadcast_into_rank(rank)?;
76                    let $id = target.add_const(prefix.to_string() + "." + stringify!($id), $id)?;
77                };
78            }
79            let step = self.step(min, max)?;
80            let min = *min.to_scalar::<f32>()?;
81            let max = *max.to_scalar::<f32>()?;
82            let min_adj = step * round_ties_to_even(min / step);
83            let max_adj = max - min + min_adj;
84            let wire = inputs[0];
85            cst!(min_adj, min_adj);
86            cst!(max_adj, max_adj);
87            cst!(step, step);
88            let wire = target.wire_node(
89                format!("{prefix}.clamp_min"),
90                ops::math::max(),
91                &[wire, min_adj],
92            )?[0];
93            let wire = target.wire_node(
94                format!("{prefix}.clamp_max"),
95                ops::math::min(),
96                &[max_adj, wire],
97            )?[0];
98            let wire = target.wire_node(
99                format!("{prefix}.sub-min"),
100                ops::math::sub(),
101                &[wire, min_adj],
102            )?[0];
103            let wire = target.wire_node(
104                format!("{prefix}.div-step"),
105                ops::math::div(),
106                &[wire, step],
107            )?[0];
108            let wire = target.wire_node(
109                format!("{prefix}.round"),
110                ops::math::round_half_to_even(),
111                &[wire],
112            )?[0];
113            let wire = target.wire_node(
114                format!("{prefix}.mul-step"),
115                ops::math::mul(),
116                &[wire, step],
117            )?[0];
118            target.wire_node(format!("{prefix}.add-min"), ops::math::add(), &[wire, min_adj])
119        } else {
120            bail!("Operator can not be made a TypedOp.")
121        }
122    }
123}