Skip to main content

tract_core/
transform.rs

1use crate::internal::*;
2#[cfg(feature = "blas")]
3use crate::ops::einsum::as_blas::AsBlas;
4use crate::ops::matmul::de_block_quant::BlockQuantTransform;
5use num_traits::Float;
6use std::fmt::Debug;
7
8use tract_data::TractResult;
9
10use crate::floats::FloatPrecisionTranslator;
11use crate::ops::nn::{Softmax, SoftmaxExp, SoftmaxKind, TypedModel};
12
/// Early-return `Ok(None)` from the enclosing function unless `$cond` evaluates to `true`.
///
/// The expansion is `if !$cond { return Ok(None); }`, so it can only be used inside functions
/// whose return type is a `Result<Option<_>, _>` (such as `TractResult<Option<_>>`).
///
/// `Ok` and `None` are spelled with absolute `::std` paths: this macro is exported, and
/// bare names in a `macro_rules!` expansion resolve at the call site, where they may be
/// shadowed (e.g. by a glob-imported enum with `Ok`/`None` variants).
#[macro_export]
macro_rules! rule_if {
    ($cond:expr) => {
        if !$cond {
            return ::std::result::Result::Ok(::std::option::Option::None);
        }
    };
}
21
/// Bind `$pat` from `$expr` with `let ... else`, early-returning `Ok(None)` from the enclosing
/// function when the pattern does not match.
///
/// The bindings introduced by `$pat` stay in scope after the macro invocation. The enclosing
/// function must return a `Result<Option<_>, _>` (such as `TractResult<Option<_>>`).
///
/// `Ok` and `None` are spelled with absolute `::std` paths so the exported macro keeps working
/// when the call site shadows those names.
#[macro_export]
macro_rules! rule_if_let {
    ($pat:pat = $expr:expr) => {
        let $pat = $expr else {
            return ::std::result::Result::Ok(::std::option::Option::None);
        };
    };
}
30
/// Unwrap an `Option` into `$pat`, early-returning `Ok(None)` from the enclosing function when
/// `$expr` is `None` (or when the inner value does not match `$pat`).
///
/// Shorthand for `rule_if_let!(Some($pat) = $expr)`. The enclosing function must return a
/// `Result<Option<_>, _>` (such as `TractResult<Option<_>>`).
///
/// `Some`, `Ok` and `None` are spelled with absolute `::std` paths so the exported macro keeps
/// working when the call site shadows those names.
#[macro_export]
macro_rules! rule_if_some {
    ($pat:pat = $expr:expr) => {
        let ::std::option::Option::Some($pat) = $expr else {
            return ::std::result::Result::Ok(::std::option::Option::None);
        };
    };
}
39
40/// Build Float precision translator given a filter_predicate. If the filter_predicate is none or empty, all nodes will
41/// be translated during the transformation.
42///
43/// filter_predicate format:
44/// - `==node-name/layer,node-name-layer.1`: Only node which has a name that contains `node-name/layer` or `node-name-layer.1`
45/// - `!=node-name/layer,node-name-layer.1`: Only node which has a name that doesn't contain `node-name/layer` or `node-name-layer.1`
46pub fn build_float_translator<T1: Datum + Float, T2: Datum + Float>(
47    filter_predicate: Option<&str>,
48) -> Option<Box<dyn ModelTransform>> {
49    let Some(filter_predicate) = filter_predicate.filter(|f| !f.is_empty()) else {
50        return Some(Box::<FloatPrecisionTranslator<T1, T2>>::default());
51    };
52
53    if let Some(node_name_patterns) = filter_predicate.strip_prefix("!=") {
54        let patterns =
55            node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
56        Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
57            !patterns.iter().any(|p| node.name.contains(p))
58        })))
59    } else if let Some(node_name_patterns) = filter_predicate.strip_prefix("==") {
60        let patterns =
61            node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
62        Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
63            patterns.iter().any(|p| node.name.contains(p))
64        })))
65    } else {
66        None
67    }
68}
69
70pub trait ModelTransform: Debug {
71    fn name(&self) -> StaticName;
72    fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
73    fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
74        self.transform(&mut model)?;
75        Ok(model)
76    }
77}
78
79#[derive(Debug)]
80struct SoftmaxFastCompact;
81
82impl ModelTransform for SoftmaxFastCompact {
83    fn name(&self) -> StaticName {
84        "softmax-fast-compact".into()
85    }
86
87    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
88        for node in &mut model.nodes {
89            if let Some(softmax) = node.op_as_mut::<Softmax>() {
90                if let SoftmaxKind::Softmax(kind) = &mut softmax.kind {
91                    *kind = SoftmaxExp::FastCompact
92                }
93            }
94        }
95        Ok(())
96    }
97}
98
99#[allow(clippy::type_complexity)]
100pub struct ModelTransformFactory {
101    pub name: &'static str,
102    pub builder: fn(spec: &str) -> TractResult<Option<Box<dyn ModelTransform>>>,
103}
104
105inventory::collect!(ModelTransformFactory);
106
/// Register a transform that takes no parameters under the spec name `$name`.
///
/// `$transform` is an expression producing a value that implements `ModelTransform`; it is
/// evaluated each time the generated builder runs. The builder ignores the spec string, so any
/// spec routed to `$name` by `get_transform` yields this transform.
///
/// All paths in the expansion are absolute (`$crate::...`, `::std::...`), so the macro does not
/// depend on what the invoking module imports or shadows.
#[macro_export]
macro_rules! register_simple_model_transform {
    ($name: expr, $transform: expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $name,
                builder: |_| ::std::result::Result::Ok(::std::option::Option::Some(
                    ::std::boxed::Box::new($transform),
                )),
            }
        }
    };
}
118
119pub fn get_transform(spec: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
120    for factory in inventory::iter::<ModelTransformFactory>() {
121        if spec.starts_with(factory.name) {
122            return (factory.builder)(spec);
123        }
124    }
125    Ok(None)
126}
127
// Parameterless transforms: the builder generated by `register_simple_model_transform!`
// ignores the spec, so anything after the registered name is not interpreted.
register_simple_model_transform!("softmax-fast-compact", SoftmaxFastCompact);
// `AsBlas` is only compiled in, and therefore only registered, with the `blas` feature.
#[cfg(feature = "blas")]
register_simple_model_transform!("as-blas", AsBlas);
register_simple_model_transform!("block-quant", BlockQuantTransform);
132
133inventory::submit! {
134    ModelTransformFactory {
135        name: "f32-to-f16",
136        builder: |spec| Ok(build_float_translator::<f32,f16>(spec.strip_prefix("f32-to-f16")))
137    }
138}
139
140inventory::submit! {
141    ModelTransformFactory {
142        name: "f16-to-f32",
143        builder: |spec| Ok(build_float_translator::<f16,f32>(spec.strip_prefix("f16-to-f32")))
144    }
145}