use tract_num_traits::Float;

use crate::internal::translator::Translate;
use crate::internal::*;
use crate::ops::array::{Pad, PadMode};
use crate::ops::binary::TypedBinOp;
use crate::ops::cast::{cast, Cast};
use crate::ops::einsum::EinSum;
use crate::ops::element_wise::ElementWiseOp;
use crate::ops::konst::Const;
use crate::ops::scan::Scan;
use crate::ops::source::TypedSource;
use crate::transform::ModelTransform;

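/// Translates a model from one float precision to another, replacing every
/// use of `T1` with `T2` (for instance f32 with f16). Nodes rejected by the
/// optional predicate keep their original precision, with `Cast` nodes
/// inserted around them as needed.
///
/// A minimal usage sketch (assuming an existing f32 `TypedModel` bound to
/// `model`):
///
/// ```ignore
/// use tract_data::prelude::f16;
///
/// // Convert the whole model to half precision...
/// model.transform(&FloatPrecisionTranslator::<f32, f16>::default())?;
/// // ...or convert everything except the nodes of "layer.1".
/// model.transform(&FloatPrecisionTranslator::<f32, f16>::with_filter(
///     |node| !node.name.contains("layer.1"),
/// ))?;
/// ```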
#[derive(Default)]
pub struct FloatPrecisionTranslator<T1: Datum + Float, T2: Datum + Float> {
    #[allow(clippy::type_complexity)]
    node_predicate: Option<Box<dyn Fn(&TypedNode) -> bool>>,
    _phantom: PhantomData<(T1, T2)>,
}

impl<T1: Datum + Float, T2: Datum + Float> FloatPrecisionTranslator<T1, T2> {
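    /// Builds a translator that only converts the nodes for which
    /// `node_predicate` returns `true`; other nodes keep their original
    /// float type, and casts are inserted at their boundaries.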
    pub fn with_filter(node_predicate: impl Fn(&TypedNode) -> bool + 'static) -> Self {
        Self { node_predicate: Some(Box::new(node_predicate)), _phantom: PhantomData }
    }

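    /// Returns `true` when `node` should be converted; with no predicate
    /// installed, every node is translated.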
    fn should_translate_node(&self, node: &TypedNode) -> bool {
        self.node_predicate.as_ref().map(|it| (it)(node)).unwrap_or(true)
    }

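    /// Casts the mapped inputs of `node` to `op_float_dt` whenever they
    /// currently carry the opposite float type, so the op being wired sees
    /// a uniform precision.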
    fn cast_inputs_if_required(
        &self,
        model: &mut TypedModel,
        node: &TypedNode,
        mapping: &HashMap<OutletId, OutletId>,
        op_float_dt: DatumType,
    ) -> TractResult<TVec<OutletId>> {
        let original_op_float_dt =
            if op_float_dt == T1::datum_type() { T2::datum_type() } else { T1::datum_type() };

        let mut mapped_inputs = tvec![];
        for (i_idx, i) in node.inputs.iter().enumerate() {
            if model.outlet_fact(mapping[i])?.datum_type == original_op_float_dt {
                let casted_mapped_input = model.wire_node(
                    format!("{}.cast-{i_idx}", node.name),
                    Cast { to: op_float_dt },
                    &[mapping[i]],
                )?[0];
                mapped_inputs.push(casted_mapped_input);
            } else {
                mapped_inputs.push(mapping[i])
            }
        }
        Ok(mapped_inputs)
    }

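    /// Casts an untranslated node's outputs back to `T2` when they are also
    /// model outputs, keeping the translated model's interface consistent.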
    fn cast_model_outputs_if_required(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        target_node_outlet_ids: TVec<OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let mut outputs = tvec![];
        for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
            let is_source_output = source.outputs.contains(&OutletId::new(node.id, o_idx));
            if target.outlet_fact(o)?.datum_type == T1::datum_type() && is_source_output {
                let casted_output = target.wire_node(
                    format!("{}.cast-out-{o_idx}", node.name),
                    Cast { to: T2::datum_type() },
                    &[o],
                )?[0];
                outputs.push(casted_output);
            } else {
                outputs.push(o)
            }
        }
        Ok(outputs)
    }
}

impl<T1: Datum + Float, T2: Datum + Float> std::fmt::Debug for FloatPrecisionTranslator<T1, T2> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FloatPrecisionTranslator").field("_phantom", &self._phantom).finish()
    }
}

impl<T1: Datum + Float, T2: Datum + Float> ModelTransform for FloatPrecisionTranslator<T1, T2> {
    fn name(&self) -> Cow<str> {
        format!("{:?}-to-{:?}", T1::datum_type(), T2::datum_type()).into()
    }

    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        let new = self.translate_model(model)?;
        *model = new;
        Ok(())
    }
}

impl<T1: Datum + Float, T2: Datum + Float>
    Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>>
    for FloatPrecisionTranslator<T1, T2>
{
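    // Wires one node from `source` into `target`. A filtered-out node is
    // kept at its original precision, surrounded by casts; any other node
    // is rebuilt with `T2` in place of `T1`.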
    fn translate_node(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let is_source = node.op_as::<TypedSource>().is_some();
        if !self.should_translate_node(node) && !is_source {
            let new_op = node.op.clone();

            let casted_inputs =
                self.cast_inputs_if_required(target, node, mapping, T1::datum_type())?;
            let target_node_outlet_ids = target.wire_node(&node.name, new_op, &casted_inputs)?;

            self.cast_model_outputs_if_required(source, node, target, target_node_outlet_ids)
        } else {
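            // Translated path: bring the inputs to `T2`, then rebuild the
            // op itself wherever it carries an explicit float type.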
            let casted_inputs =
                self.cast_inputs_if_required(target, node, mapping, T2::datum_type())?;

            let new_op = if let Some(source) = node.op_as::<TypedSource>() {
                Box::new(TypedSource::new(fact_float_precision_conversion::<T1, T2>(&source.fact)))
            } else if let Some(konst) = node.op_as::<Const>() {
                if konst.0.datum_type() == T1::datum_type() {
                    let wire = target.add_const(
                        format!("{}.{:?}", node.name, T1::datum_type()),
                        konst.0.clone(),
                    )?;
                    return target.wire_node(&node.name, cast(T2::datum_type()), &[wire]);
                } else {
                    node.op.clone()
                }
            } else if let Some(cast) = node.op_as::<Cast>() {
                if cast.to == T1::datum_type() {
                    Box::new(Cast { to: T2::datum_type() })
                } else {
                    node.op.clone()
                }
            } else if let Some(ew) = node.op_as::<ElementWiseOp>() {
                if ew.1 == Some(T1::datum_type()) {
                    Box::new(ElementWiseOp(ew.0.clone(), Some(T2::datum_type())))
                } else {
                    node.op.clone()
                }
            } else if let Some(bin) = node.op_as::<TypedBinOp>() {
                if bin.1 == Some(T1::datum_type()) {
                    Box::new(TypedBinOp(bin.0.clone(), Some(T2::datum_type())))
                } else {
                    node.op.clone()
                }
            } else if let Some(op) = node.op_as::<Scan>() {
                let body =
                    FloatPrecisionTranslator::<T1, T2>::default().translate_model(&op.body)?;
                Box::new(Scan { body, ..op.clone() })
            } else if let Some(op) = node.op_as::<EinSum>() {
                Box::new(EinSum {
                    operating_dt: dt_float_precision_conversion::<T1, T2>(op.operating_dt),
                    ..op.clone()
                })
            } else if let Some(op) = node.op_as::<Pad>() {
                if let PadMode::Constant(t) = &op.mode {
                    Box::new(Pad {
                        mode: PadMode::Constant(tensor_float_precision_conversion::<T1, T2>(t)),
                        ..op.clone()
                    })
                } else {
                    Box::new(op.clone())
                }
            } else {
                node.op.clone()
            };
            target.wire_node(&node.name, new_op, &casted_inputs)
        }
    }
}

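/// Maps `T1` to `T2`, leaving any other datum type untouched.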
fn dt_float_precision_conversion<T1: Datum + Float, T2: Datum + Float>(dt: DatumType) -> DatumType {
    if dt == T1::datum_type() {
        T2::datum_type()
    } else {
        dt
    }
}

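/// Rewrites a fact's datum type from `T1` to `T2`, returning other facts
/// unchanged.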
fn fact_float_precision_conversion<T1: Datum + Float, T2: Datum + Float>(
    t: &TypedFact,
) -> TypedFact {
    if t.datum_type == T1::datum_type() {
        let mut t = t.clone();
        t.datum_type = T2::datum_type();
        t
    } else {
        t.clone()
    }
}

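/// Casts a `T1` tensor to `T2`, returning any other tensor as-is.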
fn tensor_float_precision_conversion<T1: Datum + Float, T2: Datum + Float>(
    t: &Arc<Tensor>,
) -> Arc<Tensor> {
    if t.datum_type() == T1::datum_type() {
        t.cast_to::<T2>().unwrap().into_owned().into_arc_tensor()
    } else {
        Arc::clone(t)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::math;
    use tract_data::prelude::f16;

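    /// Builds a small two-layer f32 model computing
    /// `((x + x) * 1) ^ 10 + (-inf)`: the result is -inf in f32, but NaN in
    /// f16 because the power overflows to +inf before the final addition.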
    fn build_f32_model() -> TractResult<TypedModel> {
        let mut model = TypedModel::default();
        let a = model.add_source("source", f32::fact([1])).unwrap();
        let multiplier = model.add_const("multiplier", tensor1(&[1.0f32]))?;
        let neg_infinity = model.add_const("neg_infinity", tensor1(&[f32::NEG_INFINITY]))?;
        let pow_factor = model.add_const("pow_factor", tensor1(&[10.0f32]))?;
        let add = model.wire_node("layer.0/add", math::add(), &[a, a]).unwrap()[0];
        let mul = model.wire_node("layer.0/mul", math::mul(), &[add, multiplier]).unwrap()[0];
        let pow = model.wire_node("layer.1/pow", math::pow(), &[mul, pow_factor]).unwrap()[0];
        let _output = model
            .wire_node("layer.1/add_neg_infinity", math::add(), &[pow, neg_infinity])
            .unwrap()[0];
        model.auto_outputs()?;
        Ok(model)
    }

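    /// Exercises the registered "f32-to-f16" transform, with and without a
    /// "!=" node-name filter.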
    #[test]
    fn test_high_level_f16_transform_with_filter() -> TractResult<()> {
        let model = build_f32_model()?;

        let runnable_model = model.clone().into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[5.0f32]).into()])?[0],
            tensor1(&[f32::NEG_INFINITY]).into()
        );

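        // Full f16 conversion: the pow overflows, so the final addition
        // yields NaN instead of -inf.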
        let runnable_model = &crate::transform::get_transform("f32-to-f16")
            .unwrap()
            .transform_into(model.clone())?
            .into_runnable()?;
        assert!(runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
            .to_scalar::<f16>()?
            .is_nan());

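        // Excluding "layer.1" keeps the pow and final addition in f32, so
        // the reference result survives the conversion.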
        let runnable_model = &crate::transform::get_transform("f32-to-f16!=layer.1")
            .unwrap()
            .transform_into(model.clone())?
            .into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
            tensor1(&[f16::NEG_INFINITY]).into()
        );

        let runnable_model = &crate::transform::get_transform("f32-to-f16!=layer.0")
            .unwrap()
            .transform_into(model)?
            .into_runnable()?;
        assert!(runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
            .to_scalar::<f16>()?
            .is_nan());

        Ok(())
    }

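    /// Runs the same scenarios as above, driving `FloatPrecisionTranslator`
    /// directly instead of going through the transform registry.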
    #[test]
    fn test_f16_transform_with_filter() -> TractResult<()> {
        let model = build_f32_model()?;

        let runnable_model = model.clone().into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[5.0f32]).into()])?[0],
            tensor1(&[f32::NEG_INFINITY]).into()
        );

        let mut model_f16 = model.clone();
        model_f16.transform(&FloatPrecisionTranslator::<f32, f16>::default())?;
        let runnable_model_f16 = model_f16.clone().into_runnable()?;
        assert!(runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
            .to_scalar::<f16>()?
            .is_nan());

        let mut model_f16_with_filter = model.clone();
        model_f16_with_filter.transform(&FloatPrecisionTranslator::<f32, f16>::with_filter(
            |node| !node.name.contains("layer.1"),
        ))?;
        let runnable_model_f16 = model_f16_with_filter.clone().into_runnable()?;
        assert_eq!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
            tensor1(&[f16::NEG_INFINITY]).into()
        );
        let mut model_f16_with_filter = model.clone();
        model_f16_with_filter.transform(&FloatPrecisionTranslator::<f32, f16>::with_filter(
            |node| !node.name.contains("layer.0"),
        ))?;
        let runnable_model_f16 = model_f16_with_filter.clone().into_runnable()?;
        assert!(runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
            .to_scalar::<f16>()?
            .is_nan());
        Ok(())
    }
}