1use crate::internal::translator::Translate;
2use crate::internal::*;
3use crate::ops::array::{Pad, PadMode};
4use crate::ops::binary::TypedBinOp;
5use crate::ops::cast::{Cast, cast};
6use crate::ops::einsum::EinSum;
7use crate::ops::element_wise::ElementWiseOp;
8use crate::ops::konst::Const;
9use crate::ops::scan::Scan;
10use crate::ops::source::TypedSource;
11use crate::transform::ModelTransform;
12
/// Model transform rewriting one float precision into another (e.g. f32 -> f16).
///
/// Walks a `TypedModel`, replacing `from_dt` by `to_dt` in sources, constants,
/// casts, element-wise/binary ops, scans, einsums and pads, inserting explicit
/// `Cast` nodes at the boundaries where precisions meet. An optional predicate
/// can exclude nodes from translation (their inputs/outputs get casts instead).
pub struct FloatPrecisionTranslator {
    // Precision to translate away from (e.g. F32).
    from_dt: DatumType,
    // Precision to translate into (e.g. F16).
    to_dt: DatumType,
    // When Some, only nodes for which the predicate returns true are
    // translated; None means "translate every node".
    #[allow(clippy::type_complexity)]
    node_predicate: Option<Box<dyn Fn(&TypedNode) -> bool>>,
}
19
20impl FloatPrecisionTranslator {
21 pub fn new(from_dt: DatumType, to_dt: DatumType) -> Self {
22 Self { from_dt, to_dt, node_predicate: None }
23 }
24
25 pub fn with_filter(
26 from_dt: DatumType,
27 to_dt: DatumType,
28 node_predicate: impl Fn(&TypedNode) -> bool + 'static,
29 ) -> Self {
30 Self { from_dt, to_dt, node_predicate: Some(Box::new(node_predicate)) }
31 }
32
33 fn should_translate_node(&self, node: &TypedNode) -> bool {
34 self.node_predicate.as_ref().map(|it| (it)(node)).unwrap_or(true)
35 }
36
37 fn cast_inputs_if_required(
41 &self,
42 model: &mut TypedModel,
43 node: &TypedNode,
44 mapping: &HashMap<OutletId, OutletId>,
45 op_float_dt: DatumType,
46 ) -> TractResult<TVec<OutletId>> {
47 let original_op_float_dt =
48 if op_float_dt == self.from_dt { self.to_dt } else { self.from_dt };
49
50 let mut mapped_inputs = tvec![];
51 for (i_idx, i) in node.inputs.iter().enumerate() {
52 let fact = model.outlet_fact(mapping[i])?;
53 if fact.datum_type == original_op_float_dt && fact.is_plain() {
54 let casted_mapped_input = model.wire_node(
55 format!("{}.cast-{i_idx}", node.name),
56 Cast { to: op_float_dt },
57 &[mapping[i]],
58 )?[0];
59 mapped_inputs.push(casted_mapped_input);
60 } else {
61 mapped_inputs.push(mapping[i])
62 }
63 }
64 Ok(mapped_inputs)
65 }
66
67 fn cast_model_outputs_if_required(
71 &self,
72 source: &TypedModel,
73 node: &TypedNode,
74 target: &mut TypedModel,
75 target_node_outlet_ids: TVec<OutletId>,
76 ) -> TractResult<TVec<OutletId>> {
77 let mut outputs = tvec![];
78 for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
79 let is_source_output = source.outputs.contains(&OutletId::new(node.id, o_idx));
81 let fact = target.outlet_fact(o)?;
82 if fact.datum_type == self.from_dt && fact.is_plain() && is_source_output {
83 let casted_output = target.wire_node(
84 format!("{}.cast-out-{o_idx}", node.name),
85 Cast { to: self.to_dt },
86 &[o],
87 )?[0];
88 outputs.push(casted_output);
89 } else {
90 outputs.push(o)
91 }
92 }
93 Ok(outputs)
94 }
95}
96
97impl std::fmt::Debug for FloatPrecisionTranslator {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("FloatPrecisionTranslator")
100 .field("from", &self.from_dt)
101 .field("to", &self.to_dt)
102 .finish()
103 }
104}
105
106impl ModelTransform for FloatPrecisionTranslator {
107 fn name(&self) -> StaticName {
108 format!("{:?}-to-{:?}", self.from_dt, self.to_dt).into()
109 }
110
111 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
112 let new = self.translate_model(model)?;
113 *model = new;
114 Ok(())
115 }
116}
117
impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>>
    for FloatPrecisionTranslator
{
    /// Translates one `node` of `source` into `target`, rewriting `from_dt`
    /// into `to_dt` where applicable and wiring explicit casts at precision
    /// boundaries. `mapping` maps source outlets to already-built target
    /// outlets. Returns the outlets of the node in `target`.
    fn translate_node(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let is_source = node.op_as::<TypedSource>().is_some();
        // Sources are always translated, even when a predicate would exclude
        // them, so the model's input precision is rewritten consistently.
        if !self.should_translate_node(node) && !is_source {
            // Excluded node: keep the op untouched, cast its inputs back to
            // the original precision (from_dt) so it runs as before...
            let new_op = node.op.clone();

            let casted_inputs =
                self.cast_inputs_if_required(target, node, mapping, self.from_dt)?;
            let target_node_outlet_ids = target.wire_node(&node.name, new_op, &casted_inputs)?;

            // ...and, if any of its outlets is a model output, cast it forward
            // to to_dt so the translated model's outputs stay uniform.
            self.cast_model_outputs_if_required(source, node, target, target_node_outlet_ids)
        } else {
            // Translated node: make sure all inputs arrive in to_dt.
            let casted_inputs = self.cast_inputs_if_required(target, node, mapping, self.to_dt)?;

            // Rewrite the op itself on a per-type basis.
            let new_op = if let Some(source_op) = node.op_as::<TypedSource>() {
                // Source: rewrite its fact's datum type (other dtypes pass through).
                let mut fact = source_op.fact.clone();
                if fact.datum_type == self.from_dt {
                    fact.datum_type = self.to_dt;
                }
                Box::new(TypedSource::new(fact))
            } else if let Some(konst) = node.op_as::<Const>() {
                if konst.val().datum_type() == self.from_dt && konst.val().is_plain() {
                    // Constant in from_dt: re-emit it unchanged under a
                    // dt-suffixed name, then wire an explicit cast carrying
                    // the node's own name so downstream mapping still resolves.
                    let wire = target.add_const(
                        format!("{}.{:?}", node.name, self.from_dt),
                        konst.val().clone(),
                    )?;
                    // Early return: the cast node stands in for this node.
                    return target.wire_node(&node.name, cast(self.to_dt), &[wire]);
                } else {
                    node.op.clone()
                }
            } else if let Some(cast_op) = node.op_as::<Cast>() {
                // Casts targeting from_dt are retargeted to to_dt.
                if cast_op.to == self.from_dt {
                    Box::new(Cast { to: self.to_dt })
                } else {
                    node.op.clone()
                }
            } else if let Some(ew) = node.op_as::<ElementWiseOp>() {
                // Element-wise op with a pinned output dt: retarget it.
                if ew.1 == Some(self.from_dt) {
                    Box::new(ElementWiseOp(ew.0.clone(), Some(self.to_dt)))
                } else {
                    node.op.clone()
                }
            } else if let Some(bin) = node.op_as::<TypedBinOp>() {
                // Same treatment for binary ops with a pinned output dt.
                if bin.1 == Some(self.from_dt) {
                    Box::new(TypedBinOp(bin.0.clone(), Some(self.to_dt)))
                } else {
                    node.op.clone()
                }
            } else if let Some(op) = node.op_as::<Scan>() {
                // Scan: recursively translate the inner body model.
                // NOTE(review): the recursion uses an unfiltered translator —
                // the node predicate is not propagated into the scan body.
                let body = FloatPrecisionTranslator::new(self.from_dt, self.to_dt)
                    .translate_model(&op.body)?;
                Box::new(Scan { body, ..op.clone() })
            } else if let Some(op) = node.op_as::<EinSum>() {
                // EinSum: rewrite its internal accumulation/operating dt.
                let operating_dt =
                    if op.operating_dt == self.from_dt { self.to_dt } else { op.operating_dt };
                Box::new(EinSum { operating_dt, ..op.clone() })
            } else if let Some(op) = node.op_as::<Pad>() {
                // Pad: a constant padding value in from_dt must be cast too,
                // or the op would mix precisions.
                if let PadMode::Constant(t) = &op.mode {
                    let new_t = if t.datum_type() == self.from_dt {
                        t.cast_to_dt(self.to_dt)?.into_owned().into_arc_tensor()
                    } else {
                        Arc::clone(t)
                    };
                    Box::new(Pad { mode: PadMode::Constant(new_t), ..op.clone() })
                } else {
                    Box::new(op.clone())
                }
            } else {
                // Any other op is assumed dt-agnostic and copied verbatim.
                node.op.clone()
            };
            target.wire_node(&node.name, new_op, &casted_inputs)
        }
    }
}
200
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::math;
    use tract_data::prelude::f16;

    /// Builds a small f32 model: ((x + x) * 1) ^ 10 + (-inf).
    /// With input 5.0 it yields -inf in f32, but the pow overflows f16 to
    /// +inf, so the final add produces NaN (+inf + -inf) once translated —
    /// which lets the tests detect whether "layer.1" ran in f16 or f32.
    fn build_f32_model() -> TractResult<TypedModel> {
        let mut model = TypedModel::default();
        let a = model.add_source("source", f32::fact([1])).unwrap();
        let multiplier = model.add_const("multiplier", tensor1(&[1.0f32]))?;
        let neg_infinity = model.add_const("neg_infinity", tensor1(&[f32::NEG_INFINITY]))?;
        let pow_factor = model.add_const("pow_factor", tensor1(&[10.0f32]))?;
        let add = model.wire_node("layer.0/add", math::add(), &[a, a]).unwrap()[0];
        let mul = model.wire_node("layer.0/mul", math::mul(), &[add, multiplier]).unwrap()[0];
        let pow = model.wire_node("layer.1/pow", math::pow(), &[mul, pow_factor]).unwrap()[0];
        let _output = model
            .wire_node("layer.1/add_neg_infinity", math::add(), &[pow, neg_infinity])
            .unwrap()[0];
        model.auto_outputs()?;
        Ok(model)
    }

    /// Exercises the registered "f32_to_f16" transform and the filtered
    /// builder (`build_float_translator` + `NodeFilter`) end to end.
    #[test]
    fn test_high_level_f16_transform_with_filter() -> TractResult<()> {
        let model = build_f32_model()?;

        // Baseline: pure f32 run yields -inf.
        let runnable_model = model.clone().into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[5.0f32]).into()])?[0],
            tensor1(&[f32::NEG_INFINITY]).into()
        );

        // Full f16 translation: pow overflows f16, result is NaN.
        let runnable_model = &crate::transform::get_transform("f32_to_f16")?
            .unwrap()
            .transform_into(model.clone())?
            .into_runnable()?;
        assert!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_plain()?
                .to_scalar::<f16>()?
                .is_nan()
        );

        // Exclude "layer.1": pow stays in f32, so the f32 -inf survives
        // and only the final output is cast back to f16.
        let runnable_model = &crate::transform::build_float_translator(
            f32::datum_type(),
            f16::datum_type(),
            crate::transform::NodeFilter {
                exclude: Some(vec!["layer.1".into()]),
                ..Default::default()
            },
        )
        .transform_into(model.clone())?
        .into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
            tensor1(&[f16::NEG_INFINITY]).into()
        );

        // Exclude "layer.0" instead: pow still runs in f16 and overflows,
        // so the NaN reappears.
        let runnable_model = &crate::transform::build_float_translator(
            f32::datum_type(),
            f16::datum_type(),
            crate::transform::NodeFilter {
                exclude: Some(vec!["layer.0".into()]),
                ..Default::default()
            },
        )
        .transform_into(model)?
        .into_runnable()?;
        assert!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_plain()?
                .to_scalar::<f16>()?
                .is_nan()
        );

        Ok(())
    }

    /// Same scenarios as above, driving `FloatPrecisionTranslator` directly
    /// (unfiltered and with a closure predicate) instead of the registry.
    #[test]
    fn test_f16_transform_with_filter() -> TractResult<()> {
        let model = build_f32_model()?;

        // Baseline: pure f32 run yields -inf.
        let runnable_model = model.clone().into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[5.0f32]).into()])?[0],
            tensor1(&[f32::NEG_INFINITY]).into()
        );

        // Unfiltered translation: all-f16 run overflows in pow -> NaN.
        let mut model_f16 = model.clone();
        model_f16
            .transform(&FloatPrecisionTranslator::new(f32::datum_type(), f16::datum_type()))?;
        let runnable_model_f16 = model_f16.clone().into_runnable()?;
        assert!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_plain()?
                .to_scalar::<f16>()?
                .is_nan()
        );

        // Keep "layer.1" in f32: -inf survives and is cast out to f16.
        let mut model_f16_with_filter = model.clone();
        model_f16_with_filter.transform(&FloatPrecisionTranslator::with_filter(
            f32::datum_type(),
            f16::datum_type(),
            |node| !node.name.contains("layer.1"),
        ))?;
        let runnable_model_f16 = model_f16_with_filter.clone().into_runnable()?;
        assert_eq!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
            tensor1(&[f16::NEG_INFINITY]).into()
        );
        // Keep "layer.0" in f32 instead: pow runs in f16 -> NaN again.
        let mut model_f16_with_filter = model.clone();
        model_f16_with_filter.transform(&FloatPrecisionTranslator::with_filter(
            f32::datum_type(),
            f16::datum_type(),
            |node| !node.name.contains("layer.0"),
        ))?;
        let runnable_model_f16 = model_f16_with_filter.clone().into_runnable()?;
        assert!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_plain()?
                .to_scalar::<f16>()?
                .is_nan()
        );
        Ok(())
    }
}