tract_core/
transform.rs

1use crate::internal::*;
2#[cfg(feature = "blas")]
3use crate::ops::einsum::as_blas::AsBlas;
4use crate::ops::matmul::de_block_quant::BlockQuantTransform;
5use num_traits::Float;
6use std::fmt::Debug;
7
8use tract_data::TractResult;
9
10use crate::floats::FloatPrecisionTranslator;
11use crate::ops::nn::{Softmax, SoftmaxKind, SoftmaxExp, TypedModel};
12
13pub fn get_transform(name: &str) -> Option<Box<dyn ModelTransform>> {
14    match name {
15        #[cfg(feature = "blas")]
16        "as-blas" => Some(Box::<AsBlas>::default()),
17        name if name.starts_with("f32-to-f16") => {
18            build_float_translator::<f32, f16>(name.strip_prefix("f32-to-f16"))
19        }
20        name if name.starts_with("f16-to-f32") => {
21            build_float_translator::<f16, f32>(name.strip_prefix("f16-to-f32"))
22        }
23        "softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)),
24        "block-quant" => Some(Box::new(BlockQuantTransform)),
25        _ => None,
26    }
27}
28
29/// Build Float precision translator given a filter_predicate. If the filter_predicate is none or empty, all nodes will
30/// be translated during the transformation.
31///
32/// filter_predicate format:
33/// - `==node-name/layer,node-name-layer.1`: Only node which has a name that contains `node-name/layer` or `node-name-layer.1`
34/// - `!=node-name/layer,node-name-layer.1`: Only node which has a name that doesn't contain `node-name/layer` or `node-name-layer.1`
35pub fn build_float_translator<T1: Datum + Float, T2: Datum + Float>(
36    filter_predicate: Option<&str>,
37) -> Option<Box<dyn ModelTransform>> {
38    let Some(filter_predicate) = filter_predicate.filter(|f| !f.is_empty()) else {
39        return Some(Box::<FloatPrecisionTranslator<T1, T2>>::default());
40    };
41
42    if let Some(node_name_patterns) = filter_predicate.strip_prefix("!=") {
43        let patterns =
44            node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
45        Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
46            !patterns.iter().any(|p| node.name.contains(p))
47        })))
48    } else if let Some(node_name_patterns) = filter_predicate.strip_prefix("==") {
49        let patterns =
50            node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
51        Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
52            patterns.iter().any(|p| node.name.contains(p))
53        })))
54    } else {
55        None
56    }
57}
58
59pub trait ModelTransform: Debug {
60    fn name(&self) -> StaticName;
61    fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
62    fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
63        self.transform(&mut model)?;
64        Ok(model)
65    }
66}
67
68#[derive(Debug)]
69struct SoftmaxFastCompact;
70
71impl ModelTransform for SoftmaxFastCompact {
72    fn name(&self) -> StaticName {
73        "softmax-fast-compact".into()
74    }
75
76    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
77        for node in &mut model.nodes {
78            if let Some(softmax) = node.op_as_mut::<Softmax>() {
79                if let SoftmaxKind::Softmax(kind) = &mut softmax.kind { *kind = SoftmaxExp::FastCompact }
80            }
81        }
82        Ok(())
83    }
84}