1use crate::internal::translator::Translate;
2use crate::internal::*;
3use crate::ops::array::{Pad, PadMode};
4use crate::ops::binary::TypedBinOp;
5use crate::ops::cast::{Cast, cast};
6use crate::ops::einsum::EinSum;
7use crate::ops::element_wise::ElementWiseOp;
8use crate::ops::konst::Const;
9use crate::ops::scan::Scan;
10use crate::ops::source::TypedSource;
11use crate::transform::ModelTransform;
12
/// Model transform that rewrites a `TypedModel` so that tensors of `from_dt`
/// become `to_dt` (e.g. f32 -> f16), inserting `Cast` nodes where needed.
/// An optional predicate can exclude nodes from the conversion; excluded
/// nodes keep their original precision and are bridged with casts.
pub struct FloatPrecisionTranslator {
    // Float type being translated away from.
    from_dt: DatumType,
    // Float type being translated to.
    to_dt: DatumType,
    #[allow(clippy::type_complexity)]
    // When `Some`, only nodes for which the predicate returns `true` are
    // converted. `None` means convert every node.
    node_predicate: Option<Box<dyn Fn(&TypedNode) -> bool>>,
}
19
20impl FloatPrecisionTranslator {
21 pub fn new(from_dt: DatumType, to_dt: DatumType) -> Self {
22 Self { from_dt, to_dt, node_predicate: None }
23 }
24
25 pub fn with_filter(
26 from_dt: DatumType,
27 to_dt: DatumType,
28 node_predicate: impl Fn(&TypedNode) -> bool + 'static,
29 ) -> Self {
30 Self { from_dt, to_dt, node_predicate: Some(Box::new(node_predicate)) }
31 }
32
33 fn should_translate_node(&self, node: &TypedNode) -> bool {
34 self.node_predicate.as_ref().map(|it| (it)(node)).unwrap_or(true)
35 }
36
37 fn cast_inputs_if_required(
41 &self,
42 model: &mut TypedModel,
43 node: &TypedNode,
44 mapping: &HashMap<OutletId, OutletId>,
45 op_float_dt: DatumType,
46 ) -> TractResult<TVec<OutletId>> {
47 let original_op_float_dt =
48 if op_float_dt == self.from_dt { self.to_dt } else { self.from_dt };
49
50 let mut mapped_inputs = tvec![];
51 for (i_idx, i) in node.inputs.iter().enumerate() {
52 if model.outlet_fact(mapping[i])?.datum_type == original_op_float_dt {
53 let casted_mapped_input = model.wire_node(
54 format!("{}.cast-{i_idx}", node.name),
55 Cast { to: op_float_dt },
56 &[mapping[i]],
57 )?[0];
58 mapped_inputs.push(casted_mapped_input);
59 } else {
60 mapped_inputs.push(mapping[i])
61 }
62 }
63 Ok(mapped_inputs)
64 }
65
66 fn cast_model_outputs_if_required(
70 &self,
71 source: &TypedModel,
72 node: &TypedNode,
73 target: &mut TypedModel,
74 target_node_outlet_ids: TVec<OutletId>,
75 ) -> TractResult<TVec<OutletId>> {
76 let mut outputs = tvec![];
77 for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
78 let is_source_output = source.outputs.contains(&OutletId::new(node.id, o_idx));
80 if target.outlet_fact(o)?.datum_type == self.from_dt && is_source_output {
81 let casted_output = target.wire_node(
82 format!("{}.cast-out-{o_idx}", node.name),
83 Cast { to: self.to_dt },
84 &[o],
85 )?[0];
86 outputs.push(casted_output);
87 } else {
88 outputs.push(o)
89 }
90 }
91 Ok(outputs)
92 }
93}
94
95impl std::fmt::Debug for FloatPrecisionTranslator {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 f.debug_struct("FloatPrecisionTranslator")
98 .field("from", &self.from_dt)
99 .field("to", &self.to_dt)
100 .finish()
101 }
102}
103
104impl ModelTransform for FloatPrecisionTranslator {
105 fn name(&self) -> StaticName {
106 format!("{:?}-to-{:?}", self.from_dt, self.to_dt).into()
107 }
108
109 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
110 let new = self.translate_model(model)?;
111 *model = new;
112 Ok(())
113 }
114}
115
impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>>
    for FloatPrecisionTranslator
{
    /// Translates one node from `source` into `target`, converting float
    /// precision from `self.from_dt` to `self.to_dt` where applicable.
    fn translate_node(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let is_source = node.op_as::<TypedSource>().is_some();
        // Filtered-out nodes keep their original op and precision: cast their
        // inputs back to from_dt and cast any model output they produce to
        // to_dt. Sources are never filtered out, so the model's input facts
        // always switch precision.
        if !self.should_translate_node(node) && !is_source {
            let new_op = node.op.clone();

            let casted_inputs =
                self.cast_inputs_if_required(target, node, mapping, self.from_dt)?;
            let target_node_outlet_ids = target.wire_node(&node.name, new_op, &casted_inputs)?;

            self.cast_model_outputs_if_required(source, node, target, target_node_outlet_ids)
        } else {
            // Translated node: first make sure every float input is to_dt.
            let casted_inputs = self.cast_inputs_if_required(target, node, mapping, self.to_dt)?;

            // Per-op rewrite: retype the op itself when it carries a datum
            // type; ops with no float-specific state are cloned unchanged.
            let new_op = if let Some(source_op) = node.op_as::<TypedSource>() {
                // Retype the source fact so the model input changes precision.
                let mut fact = source_op.fact.clone();
                if fact.datum_type == self.from_dt {
                    fact.datum_type = self.to_dt;
                }
                Box::new(TypedSource::new(fact))
            } else if let Some(konst) = node.op_as::<Const>() {
                if konst.val().datum_type() == self.from_dt {
                    // The constant is kept in its original dtype and an
                    // explicit Cast is wired behind it — presumably so the
                    // stored tensor itself is not converted here.
                    let wire = target.add_const(
                        format!("{}.{:?}", node.name, self.from_dt),
                        konst.val().clone(),
                    )?;
                    // Early return: the Cast's outlet stands in for the node.
                    return target.wire_node(&node.name, cast(self.to_dt), &[wire]);
                } else {
                    node.op.clone()
                }
            } else if let Some(cast_op) = node.op_as::<Cast>() {
                // Casts targeting from_dt are redirected to to_dt.
                if cast_op.to == self.from_dt {
                    Box::new(Cast { to: self.to_dt })
                } else {
                    node.op.clone()
                }
            } else if let Some(ew) = node.op_as::<ElementWiseOp>() {
                // Field .1 appears to be an optional datum-type override —
                // retarget it when it matches from_dt (TODO confirm semantics).
                if ew.1 == Some(self.from_dt) {
                    Box::new(ElementWiseOp(ew.0.clone(), Some(self.to_dt)))
                } else {
                    node.op.clone()
                }
            } else if let Some(bin) = node.op_as::<TypedBinOp>() {
                // Same pattern as ElementWiseOp: optional dt override in .1.
                if bin.1 == Some(self.from_dt) {
                    Box::new(TypedBinOp(bin.0.clone(), Some(self.to_dt)))
                } else {
                    node.op.clone()
                }
            } else if let Some(op) = node.op_as::<Scan>() {
                // Recursively translate the scan's inner body model.
                let body = FloatPrecisionTranslator::new(self.from_dt, self.to_dt)
                    .translate_model(&op.body)?;
                Box::new(Scan { body, ..op.clone() })
            } else if let Some(op) = node.op_as::<EinSum>() {
                // Switch the accumulation/operating dtype if it matches.
                let operating_dt =
                    if op.operating_dt == self.from_dt { self.to_dt } else { op.operating_dt };
                Box::new(EinSum { operating_dt, ..op.clone() })
            } else if let Some(op) = node.op_as::<Pad>() {
                // Constant padding carries a tensor value that must be cast too.
                if let PadMode::Constant(t) = &op.mode {
                    let new_t = if t.datum_type() == self.from_dt {
                        t.cast_to_dt(self.to_dt)?.into_owned().into_arc_tensor()
                    } else {
                        Arc::clone(t)
                    };
                    Box::new(Pad { mode: PadMode::Constant(new_t), ..op.clone() })
                } else {
                    Box::new(op.clone())
                }
            } else {
                // Default: op has no explicit dtype state; rely on the casted
                // inputs to propagate the new precision through type inference.
                node.op.clone()
            };
            target.wire_node(&node.name, new_op, &casted_inputs)
        }
    }
}
198
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::math;
    use tract_data::prelude::f16;

    /// Builds a small f32 model computing ((x + x) * 1.0)^10 + (-inf).
    /// With x = 5 the pow yields 1e10: finite in f32 (result is -inf), but
    /// overflowing f16 (max ~65504) to +inf, so +inf + -inf gives NaN.
    /// The tests below use NaN-vs--inf to detect which layers ran in f16.
    fn build_f32_model() -> TractResult<TypedModel> {
        let mut model = TypedModel::default();
        let a = model.add_source("source", f32::fact([1])).unwrap();
        let multiplier = model.add_const("multiplier", tensor1(&[1.0f32]))?;
        let neg_infinity = model.add_const("neg_infinity", tensor1(&[f32::NEG_INFINITY]))?;
        let pow_factor = model.add_const("pow_factor", tensor1(&[10.0f32]))?;
        let add = model.wire_node("layer.0/add", math::add(), &[a, a]).unwrap()[0];
        let mul = model.wire_node("layer.0/mul", math::mul(), &[add, multiplier]).unwrap()[0];
        let pow = model.wire_node("layer.1/pow", math::pow(), &[mul, pow_factor]).unwrap()[0];
        let _output = model
            .wire_node("layer.1/add_neg_infinity", math::add(), &[pow, neg_infinity])
            .unwrap()[0];
        model.auto_outputs()?;
        Ok(model)
    }

    /// Exercises the registry-level entry points: the named "f32_to_f16"
    /// transform and `build_float_translator` with a `NodeFilter`.
    #[test]
    fn test_high_level_f16_transform_with_filter() -> TractResult<()> {
        let model = build_f32_model()?;

        // Baseline in f32: the pow stays finite, so the output is -inf.
        let runnable_model = model.clone().into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[5.0f32]).into()])?[0],
            tensor1(&[f32::NEG_INFINITY]).into()
        );

        // Full f16 conversion: pow overflows to +inf, +inf + -inf = NaN.
        let runnable_model = &crate::transform::get_transform("f32_to_f16")?
            .unwrap()
            .transform_into(model.clone())?
            .into_runnable()?;
        assert!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_dense()?
                .to_scalar::<f16>()?
                .is_nan()
        );

        // Excluding "layer.1" keeps pow/add in f32, so no overflow: -inf.
        let runnable_model = &crate::transform::build_float_translator(
            f32::datum_type(),
            f16::datum_type(),
            crate::transform::NodeFilter {
                exclude: Some(vec!["layer.1".into()]),
                ..Default::default()
            },
        )
        .transform_into(model.clone())?
        .into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
            tensor1(&[f16::NEG_INFINITY]).into()
        );

        // Excluding "layer.0" still runs pow in f16: overflow -> NaN.
        let runnable_model = &crate::transform::build_float_translator(
            f32::datum_type(),
            f16::datum_type(),
            crate::transform::NodeFilter {
                exclude: Some(vec!["layer.0".into()]),
                ..Default::default()
            },
        )
        .transform_into(model)?
        .into_runnable()?;
        assert!(
            runnable_model.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_dense()?
                .to_scalar::<f16>()?
                .is_nan()
        );

        Ok(())
    }

    /// Same scenarios as above, but driving `FloatPrecisionTranslator`
    /// directly via `Model::transform`.
    #[test]
    fn test_f16_transform_with_filter() -> TractResult<()> {
        let model = build_f32_model()?;

        // Baseline in f32: output is -inf.
        let runnable_model = model.clone().into_runnable()?;
        assert_eq!(
            runnable_model.run(tvec![tensor1(&[5.0f32]).into()])?[0],
            tensor1(&[f32::NEG_INFINITY]).into()
        );

        // Unfiltered conversion: everything runs in f16, pow overflows -> NaN.
        let mut model_f16 = model.clone();
        model_f16
            .transform(&FloatPrecisionTranslator::new(f32::datum_type(), f16::datum_type()))?;
        let runnable_model_f16 = model_f16.clone().into_runnable()?;
        assert!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_dense()?
                .to_scalar::<f16>()?
                .is_nan()
        );

        // Keep "layer.1" in f32: no overflow, output is f16 -inf.
        let mut model_f16_with_filter = model.clone();
        model_f16_with_filter.transform(&FloatPrecisionTranslator::with_filter(
            f32::datum_type(),
            f16::datum_type(),
            |node| !node.name.contains("layer.1"),
        ))?;
        let runnable_model_f16 = model_f16_with_filter.clone().into_runnable()?;
        assert_eq!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0],
            tensor1(&[f16::NEG_INFINITY]).into()
        );
        // Keep "layer.0" in f32: pow still in f16, overflow -> NaN.
        let mut model_f16_with_filter = model.clone();
        model_f16_with_filter.transform(&FloatPrecisionTranslator::with_filter(
            f32::datum_type(),
            f16::datum_type(),
            |node| !node.name.contains("layer.0"),
        ))?;
        let runnable_model_f16 = model_f16_with_filter.clone().into_runnable()?;
        assert!(
            runnable_model_f16.run(tvec![tensor1(&[f16::from_f32(5.0)]).into()])?[0]
                .try_as_dense()?
                .to_scalar::<f16>()?
                .is_nan()
        );
        Ok(())
    }
}