// tract_tensorflow/ops/quant.rs
use tract_hir::internal::*;
use tract_hir::ops;
use tract_hir::ops::math::round_ties_to_even;

use crate::model::ParsingContext;
use crate::model::TfOpRegister;
use crate::tfpb::tensorflow::NodeDef;
9pub fn register_all_ops(reg: &mut TfOpRegister) {
10 reg.insert("FakeQuantWithMinMaxVars", fake_quant_with_min_max_vars);
11}
12
13fn fake_quant_with_min_max_vars(
14 _ctx: &ParsingContext,
15 node: &NodeDef,
16) -> TractResult<Box<dyn InferenceOp>> {
17 let narrow_range = node.get_attr_bool("narrow_range")?;
18 let num_bits = node.get_attr_int("num_bits")?;
19 Ok(expand(FakeQuantWithMinMaxVars::new(narrow_range, num_bits)))
20}
21
22#[derive(Clone, Debug, new, Hash)]
23struct FakeQuantWithMinMaxVars {
24 narrow_range: bool,
25 num_bits: usize,
26}
27
28impl FakeQuantWithMinMaxVars {
29 fn step(&self, min: &Tensor, max: &Tensor) -> TractResult<f32> {
30 let min = min.to_scalar::<f32>()?;
31 let max = max.to_scalar::<f32>()?;
32 let amplitude = max - min;
33 let scale_len = 2_usize.pow(self.num_bits as u32) - 1 - self.narrow_range as usize;
34 Ok(amplitude / scale_len as f32)
35 }
36}
37
38impl Expansion for FakeQuantWithMinMaxVars {
39 fn name(&self) -> StaticName {
40 "FakeQuantWithMinMaxVars".into()
41 }
42
43 fn rules<'r, 'p: 'r, 's: 'r>(
44 &'s self,
45 s: &mut Solver<'r>,
46 inputs: &'p [TensorProxy],
47 outputs: &'p [TensorProxy],
48 ) -> InferenceResult {
49 check_input_arity(inputs, 3)?;
50 check_output_arity(outputs, 1)?;
51 s.equals(&inputs[0].datum_type, &inputs[1].datum_type)?;
52 s.equals(&inputs[0].datum_type, &inputs[2].datum_type)?;
53 s.equals(&inputs[1].shape, shapefactoid!())?;
54 s.equals(&inputs[2].shape, shapefactoid!())?;
55 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
56 s.equals(&inputs[0].shape, &outputs[0].shape)?;
57 Ok(())
58 }
59
60 fn wire(
61 &self,
62 prefix: &str,
63 target: &mut TypedModel,
64 inputs: &[OutletId],
65 ) -> TractResult<TVec<OutletId>> {
66 if let (Some(min), Some(max)) = (
67 target.outlet_fact(inputs[1])?.konst.as_ref(),
68 target.outlet_fact(inputs[2])?.konst.as_ref(),
69 ) {
70 let rank = target.outlet_fact(inputs[0])?.rank();
71 macro_rules! cst {
72 ($id:ident, $value: expr) => {
73 let $id = tensor0($value).broadcast_into_rank(rank)?;
74 let $id = target.add_const(prefix.to_string() + "." + stringify!($id), $id)?;
75 };
76 }
77 let step = self.step(min, max)?;
78 let min = *min.to_scalar::<f32>()?;
79 let max = *max.to_scalar::<f32>()?;
80 let min_adj = step * round_ties_to_even(min / step);
81 let max_adj = max - min + min_adj;
82 let wire = inputs[0];
83 cst!(min_adj, min_adj);
84 cst!(max_adj, max_adj);
85 cst!(step, step);
86 let wire = target.wire_node(
87 format!("{prefix}.clamp_min"),
88 ops::math::max(),
89 &[wire, min_adj],
90 )?[0];
91 let wire = target.wire_node(
92 format!("{prefix}.clamp_max"),
93 ops::math::min(),
94 &[max_adj, wire],
95 )?[0];
96 let wire = target.wire_node(
97 format!("{prefix}.sub-min"),
98 ops::math::sub(),
99 &[wire, min_adj],
100 )?[0];
101 let wire =
102 target.wire_node(format!("{prefix}.div-step"), ops::math::div(), &[wire, step])?[0];
103 let wire = target.wire_node(
104 format!("{prefix}.round"),
105 ops::math::round_half_to_even(),
106 &[wire],
107 )?[0];
108 let wire =
109 target.wire_node(format!("{prefix}.mul-step"), ops::math::mul(), &[wire, step])?[0];
110 target.wire_node(format!("{prefix}.add-min"), ops::math::add(), &[wire, min_adj])
111 } else {
112 bail!("Operator can not be made a TypedOp.")
113 }
114 }
115}