1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use tract_hir::internal::*;
use tract_hir::ops;
use tract_hir::ops::math::round_ties_to_even;

use crate::model::ParsingContext;
use crate::model::TfOpRegister;
use crate::tfpb::tensorflow::NodeDef;

pub fn register_all_ops(reg: &mut TfOpRegister) {
    reg.insert("FakeQuantWithMinMaxVars", fake_quant_with_min_max_vars);
}

fn fake_quant_with_min_max_vars(
    _ctx: &ParsingContext,
    node: &NodeDef,
) -> TractResult<Box<dyn InferenceOp>> {
    let narrow_range = node.get_attr_bool("narrow_range")?;
    let num_bits = node.get_attr_int("num_bits")?;
    Ok(expand(FakeQuantWithMinMaxVars::new(narrow_range, num_bits)))
}

#[derive(Clone, Debug, new, Hash)]
struct FakeQuantWithMinMaxVars {
    narrow_range: bool,
    num_bits: usize,
}



impl FakeQuantWithMinMaxVars {
    fn step(&self, min: &Tensor, max: &Tensor) -> TractResult<f32> {
        let min = min.to_scalar::<f32>()?;
        let max = max.to_scalar::<f32>()?;
        let amplitude = max - min;
        let scale_len = 2_usize.pow(self.num_bits as u32) - 1 - self.narrow_range as usize;
        Ok(amplitude / scale_len as f32)
    }
}

impl Expansion for FakeQuantWithMinMaxVars {
    fn name(&self) -> Cow<str> {
        "FakeQuantWithMinMaxVars".into()
    }

    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(inputs, 3)?;
        check_output_arity(outputs, 1)?;
        s.equals(&inputs[0].datum_type, &inputs[1].datum_type)?;
        s.equals(&inputs[0].datum_type, &inputs[2].datum_type)?;
        s.equals(&inputs[1].shape, shapefactoid!())?;
        s.equals(&inputs[2].shape, shapefactoid!())?;
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        s.equals(&inputs[0].shape, &outputs[0].shape)?;
        Ok(())
    }

    fn wire(
        &self,
        prefix: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        if let (Some(min), Some(max)) = (
            target.outlet_fact(inputs[1])?.konst.as_ref(),
            target.outlet_fact(inputs[2])?.konst.as_ref(),
        ) {
            let rank = target.outlet_fact(inputs[0])?.rank();
            macro_rules! cst {
                ($id:ident, $value: expr) => {
                    let $id = tensor0($value).broadcast_into_rank(rank)?;
                    let $id = target.add_const(prefix.to_string() + "." + stringify!($id), $id)?;
                };
            }
            let step = self.step(min, max)?;
            let min = *min.to_scalar::<f32>()?;
            let max = *max.to_scalar::<f32>()?;
            let min_adj = step * round_ties_to_even(min / step);
            let max_adj = max - min + min_adj;
            let wire = inputs[0];
            cst!(min_adj, min_adj);
            cst!(max_adj, max_adj);
            cst!(step, step);
            let wire = target.wire_node(
                format!("{prefix}.clamp_min"),
                ops::math::max(),
                &[wire, min_adj],
            )?[0];
            let wire = target.wire_node(
                format!("{prefix}.clamp_max"),
                ops::math::min(),
                &[max_adj, wire],
            )?[0];
            let wire = target.wire_node(
                format!("{prefix}.sub-min"),
                ops::math::sub(),
                &[wire, min_adj],
            )?[0];
            let wire = target.wire_node(
                format!("{prefix}.div-step"),
                ops::math::div(),
                &[wire, step],
            )?[0];
            let wire = target.wire_node(
                format!("{prefix}.round"),
                ops::math::round_half_to_even(),
                &[wire],
            )?[0];
            let wire = target.wire_node(
                format!("{prefix}.mul-step"),
                ops::math::mul(),
                &[wire, step],
            )?[0];
            target.wire_node(format!("{prefix}.add-min"), ops::math::add(), &[wire, min_adj])
        } else {
            bail!("Operator can not be made a TypedOp.")
        }
    }
}