//! tract_core — floats.rs
//!
//! Float precision translation: rewrites a typed model from one float datum
//! type (T1) to another (T2), optionally filtered per node.
1use tract_num_traits::Float;
2
3use crate::internal::translator::Translate;
4use crate::internal::*;
5use crate::ops::array::{Pad, PadMode};
6use crate::ops::binary::TypedBinOp;
7use crate::ops::cast::{Cast, cast};
8use crate::ops::einsum::EinSum;
9use crate::ops::element_wise::ElementWiseOp;
10use crate::ops::konst::Const;
11use crate::ops::scan::Scan;
12use crate::ops::source::TypedSource;
13use crate::transform::ModelTransform;
14
/// Model transform that rewrites a typed model from float precision T1 to T2.
///
/// An optional predicate restricts which nodes are converted; when it is
/// `None`, every node is translated. Nodes rejected by the predicate keep
/// their original precision and have cast ops wired around them.
#[derive(Default)]
pub struct FloatPrecisionTranslator<T1: Datum + Float, T2: Datum + Float> {
    #[allow(clippy::type_complexity)]
    // `None` means "translate every node".
    node_predicate: Option<Box<dyn Fn(&TypedNode) -> bool>>,
    // Carries the (T1, T2) precision pair; no runtime data.
    _phantom: PhantomData<(T1, T2)>,
}
21
22impl<T1: Datum + Float, T2: Datum + Float> FloatPrecisionTranslator<T1, T2> {
23    pub fn with_filter(node_predicate: impl Fn(&TypedNode) -> bool + 'static) -> Self {
24        Self { node_predicate: Some(Box::new(node_predicate)), _phantom: PhantomData }
25    }
26
27    fn should_translate_node(&self, node: &TypedNode) -> bool {
28        self.node_predicate.as_ref().map(|it| (it)(node)).unwrap_or(true)
29    }
30
31    /// Cast node inputs to the working float precision for the operator
32    /// Only input using float datumtype are impacted. This will add cast operations
33    /// in the model. The function return the new input outlet ids.
34    fn cast_inputs_if_required(
35        &self,
36        model: &mut TypedModel,
37        node: &TypedNode,
38        mapping: &HashMap<OutletId, OutletId>,
39        op_float_dt: DatumType,
40    ) -> TractResult<TVec<OutletId>> {
41        let original_op_float_dt =
42            if op_float_dt == T1::datum_type() { T2::datum_type() } else { T1::datum_type() };
43
44        let mut mapped_inputs = tvec![];
45        for (i_idx, i) in node.inputs.iter().enumerate() {
46            if model.outlet_fact(mapping[i])?.datum_type == original_op_float_dt {
47                let casted_mapped_input = model.wire_node(
48                    format!("{}.cast-{i_idx}", node.name),
49                    Cast { to: op_float_dt },
50                    &[mapping[i]],
51                )?[0];
52                mapped_inputs.push(casted_mapped_input);
53            } else {
54                mapped_inputs.push(mapping[i])
55            }
56        }
57        Ok(mapped_inputs)
58    }
59
60    /// Cast node output outlet ids to the destination float precision,
61    /// after insertion in the target mode. This preserves the model output float
62    /// precision.
63    fn cast_model_outputs_if_required(
64        &self,
65        source: &TypedModel,
66        node: &TypedNode,
67        target: &mut TypedModel,
68        target_node_outlet_ids: TVec<OutletId>,
69    ) -> TractResult<TVec<OutletId>> {
70        let mut outputs = tvec![];
71        for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
72            // Add Cast op for model output
73            let is_source_output = source.outputs.contains(&OutletId::new(node.id, o_idx));
74            if target.outlet_fact(o)?.datum_type == T1::datum_type() && is_source_output {
75                let casted_output = target.wire_node(
76                    format!("{}.cast-out-{o_idx}", node.name),
77                    Cast { to: T2::datum_type() },
78                    &[o],
79                )?[0];
80                outputs.push(casted_output);
81            } else {
82                outputs.push(o)
83            }
84        }
85        Ok(outputs)
86    }
87}
88
89impl<T1: Datum + Float, T2: Datum + Float> std::fmt::Debug for FloatPrecisionTranslator<T1, T2> {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        f.debug_struct("FloatPrecisionTranslator").field("_phantom", &self._phantom).finish()
92    }
93}
94
95impl<T1: Datum + Float, T2: Datum + Float> ModelTransform for FloatPrecisionTranslator<T1, T2> {
96    fn name(&self) -> StaticName {
97        format!("{:?}-to-{:?}", T1::datum_type(), T2::datum_type()).into()
98    }
99
100    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
101        let new = self.translate_model(model)?;
102        *model = new;
103        Ok(())
104    }
105}
106
impl<T1: Datum + Float, T2: Datum + Float>
    Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>>
    for FloatPrecisionTranslator<T1, T2>
{
    /// Translate one node of the source model into the target model.
    ///
    /// Nodes rejected by the filter (sources excepted) keep their original op:
    /// their inputs are cast back to T1, and outlets feeding a model output
    /// are cast to T2 so the translated model's outputs are uniformly in the
    /// new precision. All other nodes are converted to T2, with op-specific
    /// handling for the ops that embed a datum type.
    fn translate_node(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let is_source = node.op_as::<TypedSource>().is_some();
        if !self.should_translate_node(node) && !is_source {
            // Filtered-out node: reuse the op unchanged and feed it T1 inputs.
            let new_op = node.op.clone();

            let casted_inputs =
                self.cast_inputs_if_required(target, node, mapping, T1::datum_type())?;
            let target_node_outlet_ids = target.wire_node(&node.name, new_op, &casted_inputs)?;

            self.cast_model_outputs_if_required(source, node, target, target_node_outlet_ids)
        } else {
            // Translated node: make sure all float inputs arrive in T2.
            let casted_inputs =
                self.cast_inputs_if_required(target, node, mapping, T2::datum_type())?;

            let new_op = if let Some(source) = node.op_as::<TypedSource>() {
                // Sources are always retyped to T2, even when filtered out.
                Box::new(TypedSource::new(fact_float_precision_conversion::<T1, T2>(&source.fact)))
            } else if let Some(konst) = node.op_as::<Const>() {
                if konst.val().datum_type() == T1::datum_type() {
                    // Keep the constant payload in its original precision and
                    // wire a runtime cast to T2 after it (early return: the
                    // cast node takes over the original node's name).
                    let wire = target.add_const(
                        format!("{}.{:?}", node.name, T1::datum_type()),
                        konst.val().clone(),
                    )?;
                    return target.wire_node(&node.name, cast(T2::datum_type()), &[wire]);
                } else {
                    node.op.clone()
                }
            } else if let Some(cast) = node.op_as::<Cast>() {
                // A cast targeting T1 now targets T2 instead.
                if cast.to == T1::datum_type() {
                    Box::new(Cast { to: T2::datum_type() })
                } else {
                    node.op.clone()
                }
            } else if let Some(ew) = node.op_as::<ElementWiseOp>() {
                // Rewrite an explicit T1 output datum type, when present.
                if ew.1 == Some(T1::datum_type()) {
                    Box::new(ElementWiseOp(ew.0.clone(), Some(T2::datum_type())))
                } else {
                    node.op.clone()
                }
            } else if let Some(bin) = node.op_as::<TypedBinOp>() {
                // Same rewrite for binary ops carrying an explicit T1 dt.
                if bin.1 == Some(T1::datum_type()) {
                    Box::new(TypedBinOp(bin.0.clone(), Some(T2::datum_type())))
                } else {
                    node.op.clone()
                }
            } else if let Some(op) = node.op_as::<Scan>() {
                // Recursively translate the scan body. NOTE(review): a default
                // (unfiltered) translator is used here, so any node filter does
                // not propagate inside scan bodies — confirm this is intended.
                let body =
                    FloatPrecisionTranslator::<T1, T2>::default().translate_model(&op.body)?;
                Box::new(Scan { body, ..op.clone() })
            } else if let Some(op) = node.op_as::<EinSum>() {
                // EinSum computes in an explicit dt; map T1 -> T2, keep others.
                Box::new(EinSum {
                    operating_dt: dt_float_precision_conversion::<T1, T2>(op.operating_dt),
                    ..op.clone()
                })
            } else if let Some(op) = node.op_as::<Pad>() {
                // The constant padding value tensor must follow the precision
                // change; other pad modes carry no tensor.
                if let PadMode::Constant(t) = &op.mode {
                    Box::new(Pad {
                        mode: PadMode::Constant(tensor_float_precision_conversion::<T1, T2>(t)),
                        ..op.clone()
                    })
                } else {
                    Box::new(op.clone())
                }
            } else {
                // Any other op is reused as-is; only its inputs were cast.
                node.op.clone()
            };
            target.wire_node(&node.name, new_op, &casted_inputs)
        }
    }
}
186
187fn dt_float_precision_conversion<T1: Datum + Float, T2: Datum + Float>(dt: DatumType) -> DatumType {
188    if dt == T1::datum_type() { T2::datum_type() } else { dt }
189}
190
191fn fact_float_precision_conversion<T1: Datum + Float, T2: Datum + Float>(
192    t: &TypedFact,
193) -> TypedFact {
194    if t.datum_type == T1::datum_type() {
195        let mut t = t.clone();
196        t.datum_type = T2::datum_type();
197        t
198    } else {
199        t.clone()
200    }
201}
202
203fn tensor_float_precision_conversion<T1: Datum + Float, T2: Datum + Float>(
204    t: &Arc<Tensor>,
205) -> Arc<Tensor> {
206    if t.datum_type() == T1::datum_type() {
207        t.cast_to::<T2>().unwrap().into_owned().into_arc_tensor()
208    } else {
209        Arc::clone(t)
210    }
211}
212
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::math;
    use tract_data::prelude::f16;

    /// Build a small f32 model computing ((a + a) * 1) ^ 10 + (-inf).
    ///
    /// In f32 the result for input 5.0 is -inf; in f16 the pow result is
    /// presumably large enough to overflow to +inf, making the final add
    /// inf + (-inf) = NaN — which is what the tests below assert.
    fn build_f32_model() -> TractResult<TypedModel> {
        // F32 model definition
        let mut model = TypedModel::default();
        let a = model.add_source("source", f32::fact([1])).unwrap();
        let multiplier = model.add_const("multiplier", tensor1(&[1.0f32]))?;
        let neg_infinity = model.add_const("neg_infinity", tensor1(&[f32::NEG_INFINITY]))?;
        let pow_factor = model.add_const("pow_factor", tensor1(&[10.0f32]))?;
        // Node names are prefixed "layer.0"/"layer.1" so the filter tests can
        // select them by substring.
        let add = model.wire_node("layer.0/add", math::add(), &[a, a]).unwrap()[0];
        let mul = model.wire_node("layer.0/mul", math::mul(), &[add, multiplier]).unwrap()[0];
        let pow = model.wire_node("layer.1/pow", math::pow(), &[mul, pow_factor]).unwrap()[0];
        let _output = model
            .wire_node("layer.1/add_neg_infinity", math::add(), &[pow, neg_infinity])
            .unwrap()[0];
        model.auto_outputs()?;
        Ok(model)
    }

    /// Exercises the registered "f32-to-f16" transform, including the string
    /// filter syntax ("!=substring" excludes nodes whose name matches).
    #[test]
    fn test_high_level_f16_transform_with_filter() -> TractResult<()> {
        // F32 model definition
        let model = build_f32_model()?;

        // Execution in F32
        let runnable_model = model.clone().into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[5.0f32]).into()])?[0],
            tensor1(&[f32::NEG_INFINITY]).into()
        );

        // Execution in F16 with returns NaN
        let runnable_model = &crate::transform::get_transform("f32-to-f16")?
            .unwrap()
            .transform_into(model.clone())?
            .into_runnable()?;
        assert!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .to_scalar::<f16>()?
                .is_nan()
        );

        // Execution in F16 with filter that returns the good output.
        // Excluding "layer.1" keeps pow and the final add in f32.
        let runnable_model = &crate::transform::get_transform("f32-to-f16!=layer.1")?
            .unwrap()
            .transform_into(model.clone())?
            .into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
            tensor1(&[f16::NEG_INFINITY]).into()
        );

        // Execution in F16 with returns NaN despite the filter.
        // Excluding only "layer.0" still runs pow in f16, so NaN remains.
        let runnable_model = &crate::transform::get_transform("f32-to-f16!=layer.0")?
            .unwrap()
            .transform_into(model)?
            .into_runnable()?;
        assert!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .to_scalar::<f16>()?
                .is_nan()
        );

        Ok(())
    }

    /// Same scenarios as above, but driving FloatPrecisionTranslator directly
    /// through ModelTransform::transform with a closure filter.
    #[test]
    fn test_f16_transform_with_filter() -> TractResult<()> {
        // F32 model definition
        let model = build_f32_model()?;

        // Execution in F32
        let runnable_model = model.clone().into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[5.0f32]).into()])?[0],
            tensor1(&[f32::NEG_INFINITY]).into()
        );

        // Execution in F16 with returns NaN
        let mut model_f16 = model.clone();
        model_f16.transform(&FloatPrecisionTranslator::<f32, f16>::default())?;
        let runnable_model_f16 = model_f16.clone().into_runnable()?;
        assert!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .to_scalar::<f16>()?
                .is_nan()
        );

        // Execution in F16 with filter that returns the good output.
        let mut model_f16_with_filter = model.clone();
        model_f16_with_filter.transform(&FloatPrecisionTranslator::<f32, f16>::with_filter(
            |node| !node.name.contains("layer.1"),
        ))?;
        let runnable_model_f16 = model_f16_with_filter.clone().into_runnable()?;
        assert_eq!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
            tensor1(&[f16::NEG_INFINITY]).into()
        );
        // Filtering out only layer.0 keeps pow in f16: still NaN.
        let mut model_f16_with_filter = model.clone();
        model_f16_with_filter.transform(&FloatPrecisionTranslator::<f32, f16>::with_filter(
            |node| !node.name.contains("layer.0"),
        ))?;
        let runnable_model_f16 = model_f16_with_filter.clone().into_runnable()?;
        assert!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .to_scalar::<f16>()?
                .is_nan()
        );
        Ok(())
    }
}