//! tract_core/floats.rs — float precision translation for typed models
//! (e.g. rewriting an f32 network to run in f16).
1use crate::internal::translator::Translate;
2use crate::internal::*;
3use crate::ops::array::{Pad, PadMode};
4use crate::ops::binary::TypedBinOp;
5use crate::ops::cast::{Cast, cast};
6use crate::ops::einsum::EinSum;
7use crate::ops::element_wise::ElementWiseOp;
8use crate::ops::konst::Const;
9use crate::ops::scan::Scan;
10use crate::ops::source::TypedSource;
11use crate::transform::ModelTransform;
12
/// Model transform that rewrites a typed model from one float precision to
/// another (e.g. f32 -> f16), inserting `Cast` operations where needed so
/// that model inputs/outputs keep a consistent precision.
pub struct FloatPrecisionTranslator {
    /// Float datum type to be replaced throughout the model.
    from_dt: DatumType,
    /// Float datum type the model is translated to.
    to_dt: DatumType,
    #[allow(clippy::type_complexity)]
    /// Optional node filter: nodes for which it returns `false` keep their
    /// original op and precision (bridged with casts); `None` translates all.
    node_predicate: Option<Box<dyn Fn(&TypedNode) -> bool>>,
}
19
20impl FloatPrecisionTranslator {
21    pub fn new(from_dt: DatumType, to_dt: DatumType) -> Self {
22        Self { from_dt, to_dt, node_predicate: None }
23    }
24
25    pub fn with_filter(
26        from_dt: DatumType,
27        to_dt: DatumType,
28        node_predicate: impl Fn(&TypedNode) -> bool + 'static,
29    ) -> Self {
30        Self { from_dt, to_dt, node_predicate: Some(Box::new(node_predicate)) }
31    }
32
33    fn should_translate_node(&self, node: &TypedNode) -> bool {
34        self.node_predicate.as_ref().map(|it| (it)(node)).unwrap_or(true)
35    }
36
37    /// Cast node inputs to the working float precision for the operator
38    /// Only input using float datumtype are impacted. This will add cast operations
39    /// in the model. The function return the new input outlet ids.
40    fn cast_inputs_if_required(
41        &self,
42        model: &mut TypedModel,
43        node: &TypedNode,
44        mapping: &HashMap<OutletId, OutletId>,
45        op_float_dt: DatumType,
46    ) -> TractResult<TVec<OutletId>> {
47        let original_op_float_dt =
48            if op_float_dt == self.from_dt { self.to_dt } else { self.from_dt };
49
50        let mut mapped_inputs = tvec![];
51        for (i_idx, i) in node.inputs.iter().enumerate() {
52            let fact = model.outlet_fact(mapping[i])?;
53            if fact.datum_type == original_op_float_dt && fact.is_plain() {
54                let casted_mapped_input = model.wire_node(
55                    format!("{}.cast-{i_idx}", node.name),
56                    Cast { to: op_float_dt },
57                    &[mapping[i]],
58                )?[0];
59                mapped_inputs.push(casted_mapped_input);
60            } else {
61                mapped_inputs.push(mapping[i])
62            }
63        }
64        Ok(mapped_inputs)
65    }
66
67    /// Cast node output outlet ids to the destination float precision,
68    /// after insertion in the target mode. This preserves the model output float
69    /// precision.
70    fn cast_model_outputs_if_required(
71        &self,
72        source: &TypedModel,
73        node: &TypedNode,
74        target: &mut TypedModel,
75        target_node_outlet_ids: TVec<OutletId>,
76    ) -> TractResult<TVec<OutletId>> {
77        let mut outputs = tvec![];
78        for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
79            // Add Cast op for model output
80            let is_source_output = source.outputs.contains(&OutletId::new(node.id, o_idx));
81            let fact = target.outlet_fact(o)?;
82            if fact.datum_type == self.from_dt && fact.is_plain() && is_source_output {
83                let casted_output = target.wire_node(
84                    format!("{}.cast-out-{o_idx}", node.name),
85                    Cast { to: self.to_dt },
86                    &[o],
87                )?[0];
88                outputs.push(casted_output);
89            } else {
90                outputs.push(o)
91            }
92        }
93        Ok(outputs)
94    }
95}
96
97impl std::fmt::Debug for FloatPrecisionTranslator {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        f.debug_struct("FloatPrecisionTranslator")
100            .field("from", &self.from_dt)
101            .field("to", &self.to_dt)
102            .finish()
103    }
104}
105
106impl ModelTransform for FloatPrecisionTranslator {
107    fn name(&self) -> StaticName {
108        format!("{:?}-to-{:?}", self.from_dt, self.to_dt).into()
109    }
110
111    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
112        let new = self.translate_model(model)?;
113        *model = new;
114        Ok(())
115    }
116}
117
impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>>
    for FloatPrecisionTranslator
{
    /// Translate one node of `source` into `target`.
    ///
    /// Nodes rejected by the predicate are copied as-is, with their inputs
    /// cast back to `from_dt` and their model-output outlets re-cast to
    /// `to_dt`. All other nodes (sources included, regardless of the
    /// predicate) are rewritten to operate in `to_dt`. `mapping` maps source
    /// outlets to already-wired target outlets.
    fn translate_node(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let is_source = node.op_as::<TypedSource>().is_some();
        if !self.should_translate_node(node) && !is_source {
            // Filtered-out node: keep the original op and precision.
            let new_op = node.op.clone();

            // Inputs may already be in to_dt: bridge them back to from_dt.
            let casted_inputs =
                self.cast_inputs_if_required(target, node, mapping, self.from_dt)?;
            let target_node_outlet_ids = target.wire_node(&node.name, new_op, &casted_inputs)?;

            // Keep model outputs in the destination precision.
            self.cast_model_outputs_if_required(source, node, target, target_node_outlet_ids)
        } else {
            // Inputs may still be in from_dt (e.g. produced by a filtered-out
            // node): cast them to to_dt before wiring.
            let casted_inputs = self.cast_inputs_if_required(target, node, mapping, self.to_dt)?;

            let new_op = if let Some(source_op) = node.op_as::<TypedSource>() {
                // Sources are retyped directly: no cast, the fact changes type.
                let mut fact = source_op.fact.clone();
                if fact.datum_type == self.from_dt {
                    fact.datum_type = self.to_dt;
                }
                Box::new(TypedSource::new(fact))
            } else if let Some(konst) = node.op_as::<Const>() {
                if konst.val().datum_type() == self.from_dt && konst.val().is_plain() {
                    // Keep the original constant under a suffixed name and wire
                    // an explicit cast after it, so `node.name` designates the
                    // to_dt value. Note the early return: no extra wire_node.
                    let wire = target.add_const(
                        format!("{}.{:?}", node.name, self.from_dt),
                        konst.val().clone(),
                    )?;
                    return target.wire_node(&node.name, cast(self.to_dt), &[wire]);
                } else {
                    node.op.clone()
                }
            } else if let Some(cast_op) = node.op_as::<Cast>() {
                // A cast targeting from_dt must now target to_dt.
                if cast_op.to == self.from_dt {
                    Box::new(Cast { to: self.to_dt })
                } else {
                    node.op.clone()
                }
            } else if let Some(ew) = node.op_as::<ElementWiseOp>() {
                // Retarget the element-wise op's datum-type field, if set.
                if ew.1 == Some(self.from_dt) {
                    Box::new(ElementWiseOp(ew.0.clone(), Some(self.to_dt)))
                } else {
                    node.op.clone()
                }
            } else if let Some(bin) = node.op_as::<TypedBinOp>() {
                // Same for binary ops carrying an explicit datum type.
                if bin.1 == Some(self.from_dt) {
                    Box::new(TypedBinOp(bin.0.clone(), Some(self.to_dt)))
                } else {
                    node.op.clone()
                }
            } else if let Some(op) = node.op_as::<Scan>() {
                // Recursively translate the scan body with a fresh, unfiltered
                // translator (the predicate is not propagated into the body).
                let body = FloatPrecisionTranslator::new(self.from_dt, self.to_dt)
                    .translate_model(&op.body)?;
                Box::new(Scan { body, ..op.clone() })
            } else if let Some(op) = node.op_as::<EinSum>() {
                // Switch the EinSum operating type along with the data type.
                let operating_dt =
                    if op.operating_dt == self.from_dt { self.to_dt } else { op.operating_dt };
                Box::new(EinSum { operating_dt, ..op.clone() })
            } else if let Some(op) = node.op_as::<Pad>() {
                // Constant padding values must be cast to the new precision.
                if let PadMode::Constant(t) = &op.mode {
                    let new_t = if t.datum_type() == self.from_dt {
                        t.cast_to_dt(self.to_dt)?.into_owned().into_arc_tensor()
                    } else {
                        Arc::clone(t)
                    };
                    Box::new(Pad { mode: PadMode::Constant(new_t), ..op.clone() })
                } else {
                    Box::new(op.clone())
                }
            } else {
                // Any other op is cloned unchanged; its inputs were already
                // cast to to_dt above.
                node.op.clone()
            };
            target.wire_node(&node.name, new_op, &casted_inputs)
        }
    }
}
200
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::math;
    use tract_data::prelude::f16;

    /// Builds a small f32 model computing ((x + x) * 1.0) ^ 10 + (-inf).
    /// With x = 5 the power term is 10^10: fine in f32 (result -inf), but it
    /// overflows f16 to +inf, and inf + (-inf) yields NaN — presumably the
    /// intended way to detect which layers actually ran in f16.
    fn build_f32_model() -> TractResult<TypedModel> {
        // F32 model definition
        let mut model = TypedModel::default();
        let a = model.add_source("source", f32::fact([1])).unwrap();
        let multiplier = model.add_const("multiplier", tensor1(&[1.0f32]))?;
        let neg_infinity = model.add_const("neg_infinity", tensor1(&[f32::NEG_INFINITY]))?;
        let pow_factor = model.add_const("pow_factor", tensor1(&[10.0f32]))?;
        let add = model.wire_node("layer.0/add", math::add(), &[a, a]).unwrap()[0];
        let mul = model.wire_node("layer.0/mul", math::mul(), &[add, multiplier]).unwrap()[0];
        let pow = model.wire_node("layer.1/pow", math::pow(), &[mul, pow_factor]).unwrap()[0];
        let _output = model
            .wire_node("layer.1/add_neg_infinity", math::add(), &[pow, neg_infinity])
            .unwrap()[0];
        model.auto_outputs()?;
        Ok(model)
    }

    /// Exercises the registered "f32_to_f16" transform and the name-based
    /// NodeFilter front-end around `build_float_translator`.
    #[test]
    fn test_high_level_f16_transform_with_filter() -> TractResult<()> {
        // F32 model definition
        let model = build_f32_model()?;

        // Execution in F32: exact, yields -inf.
        let runnable_model = model.clone().into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[5.0f32]).into()])?[0],
            tensor1(&[f32::NEG_INFINITY]).into()
        );

        // Unfiltered execution in F16 returns NaN (f16 overflow, see above).
        let runnable_model = &crate::transform::get_transform("f32_to_f16")?
            .unwrap()
            .transform_into(model.clone())?
            .into_runnable()?;
        assert!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_plain()?
                .to_scalar::<f16>()?
                .is_nan()
        );

        // Excluding "layer.1" keeps the overflowing pow in f32: good output.
        let runnable_model = &crate::transform::build_float_translator(
            f32::datum_type(),
            f16::datum_type(),
            crate::transform::NodeFilter {
                exclude: Some(vec!["layer.1".into()]),
                ..Default::default()
            },
        )
        .transform_into(model.clone())?
        .into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
            tensor1(&[f16::NEG_INFINITY]).into()
        );

        // Excluding "layer.0" still leaves the pow in f16: NaN despite filter.
        let runnable_model = &crate::transform::build_float_translator(
            f32::datum_type(),
            f16::datum_type(),
            crate::transform::NodeFilter {
                exclude: Some(vec!["layer.0".into()]),
                ..Default::default()
            },
        )
        .transform_into(model)?
        .into_runnable()?;
        assert!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_plain()?
                .to_scalar::<f16>()?
                .is_nan()
        );

        Ok(())
    }

    /// Same scenarios driven directly through `FloatPrecisionTranslator`,
    /// using closure predicates instead of the NodeFilter front-end.
    #[test]
    fn test_f16_transform_with_filter() -> TractResult<()> {
        // F32 model definition
        let model = build_f32_model()?;

        // Execution in F32: exact, yields -inf.
        let runnable_model = model.clone().into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[5.0f32]).into()])?[0],
            tensor1(&[f32::NEG_INFINITY]).into()
        );

        // Unfiltered execution in F16 returns NaN (f16 overflow, see above).
        let mut model_f16 = model.clone();
        model_f16
            .transform(&FloatPrecisionTranslator::new(f32::datum_type(), f16::datum_type()))?;
        let runnable_model_f16 = model_f16.clone().into_runnable()?;
        assert!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_plain()?
                .to_scalar::<f16>()?
                .is_nan()
        );

        // Filtering out "layer.1" keeps the pow in f32: good output.
        let mut model_f16_with_filter = model.clone();
        model_f16_with_filter.transform(&FloatPrecisionTranslator::with_filter(
            f32::datum_type(),
            f16::datum_type(),
            |node| !node.name.contains("layer.1"),
        ))?;
        let runnable_model_f16 = model_f16_with_filter.clone().into_runnable()?;
        assert_eq!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
            tensor1(&[f16::NEG_INFINITY]).into()
        );
        // Filtering out "layer.0" still leaves the pow in f16: NaN again.
        let mut model_f16_with_filter = model.clone();
        model_f16_with_filter.transform(&FloatPrecisionTranslator::with_filter(
            f32::datum_type(),
            f16::datum_type(),
            |node| !node.name.contains("layer.0"),
        ))?;
        let runnable_model_f16 = model_f16_with_filter.clone().into_runnable()?;
        assert!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_plain()?
                .to_scalar::<f16>()?
                .is_nan()
        );
        Ok(())
    }
}