tract_core/
transform.rs

1use crate::internal::*;
2#[cfg(feature = "blas")]
3use crate::ops::einsum::as_blas::AsBlas;
4use crate::ops::matmul::de_block_quant::BlockQuantTransform;
5use num_traits::Float;
6use std::borrow::Cow;
7use std::fmt::Debug;
8
9use tract_data::TractResult;
10
11use crate::floats::FloatPrecisionTranslator;
12use crate::ops::nn::{Softmax, SoftmaxExp, TypedModel};
13
14pub fn get_transform(name: &str) -> Option<Box<dyn ModelTransform>> {
15    match name {
16        #[cfg(feature = "blas")]
17        "as-blas" => Some(Box::<AsBlas>::default()),
18        name if name.starts_with("f32-to-f16") => {
19            build_float_translator::<f32, f16>(name.strip_prefix("f32-to-f16"))
20        }
21        name if name.starts_with("f16-to-f32") => {
22            build_float_translator::<f16, f32>(name.strip_prefix("f16-to-f32"))
23        }
24        "softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)),
25        "block-quant" => Some(Box::new(BlockQuantTransform)),
26        _ => None,
27    }
28}
29
30/// Build Float precision translator given a filter_predicate. If the filter_predicate is none or empty, all nodes will
31/// be translated during the transformation.
32///
33/// filter_predicate format:
34/// - `==node-name/layer,node-name-layer.1`: Only node which has a name that contains `node-name/layer` or `node-name-layer.1`
35/// - `!=node-name/layer,node-name-layer.1`: Only node which has a name that doesn't contain `node-name/layer` or `node-name-layer.1`
36pub fn build_float_translator<T1: Datum + Float, T2: Datum + Float>(
37    filter_predicate: Option<&str>,
38) -> Option<Box<dyn ModelTransform>> {
39    let Some(filter_predicate) = filter_predicate.filter(|f| !f.is_empty()) else {
40        return Some(Box::<FloatPrecisionTranslator<T1, T2>>::default());
41    };
42
43    if let Some(node_name_patterns) = filter_predicate.strip_prefix("!=") {
44        let patterns =
45            node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
46        Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
47            !patterns.iter().any(|p| node.name.contains(p))
48        })))
49    } else if let Some(node_name_patterns) = filter_predicate.strip_prefix("==") {
50        let patterns =
51            node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
52        Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
53            patterns.iter().any(|p| node.name.contains(p))
54        })))
55    } else {
56        None
57    }
58}
59
60pub trait ModelTransform: Debug {
61    fn name(&self) -> Cow<str>;
62    fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
63    fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
64        self.transform(&mut model)?;
65        Ok(model)
66    }
67}
68
69#[derive(Debug)]
70struct SoftmaxFastCompact;
71
72impl ModelTransform for SoftmaxFastCompact {
73    fn name(&self) -> Cow<str> {
74        "softmax-fast-compact".into()
75    }
76
77    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
78        for node in &mut model.nodes {
79            if let Some(softmax) = node.op_as_mut::<Softmax>() {
80                softmax.exp = SoftmaxExp::FastCompact;
81            }
82        }
83        Ok(())
84    }
85}