Skip to main content

tract_core/
transform.rs

1use std::borrow::Cow;
2
3use crate::internal::*;
4#[cfg(feature = "blas")]
5use crate::ops::einsum::as_blas::AsBlas;
6use crate::ops::matmul::de_block_quant::BlockQuantTransform;
7use std::fmt::Debug;
8
9use tract_data::TractResult;
10
11use crate::floats::FloatPrecisionTranslator;
12use crate::ops::nn::{Softmax, SoftmaxExp, SoftmaxKind, TypedModel};
13
/// Guard for rewrite rules: when the boolean expression is false, the enclosing
/// function returns `Ok(None)`; otherwise execution continues after the macro.
#[macro_export]
macro_rules! rule_if {
    ($guard:expr) => {
        if !$guard {
            return Ok(None);
        }
    };
}
22
/// Guard for rewrite rules: binds `pattern = expression` with `let ... else`, and makes
/// the enclosing function return `Ok(None)` when the pattern does not match.
#[macro_export]
macro_rules! rule_if_let {
    ($binding:pat = $value:expr) => {
        let $binding = $value else {
            return Ok(None);
        };
    };
}
31
/// Guard for rewrite rules: unwraps an `Option` into the given pattern, and makes the
/// enclosing function return `Ok(None)` when the value is `None` (or the inner pattern
/// does not match).
#[macro_export]
macro_rules! rule_if_some {
    ($binding:pat = $value:expr) => {
        let Some($binding) = $value else {
            return Ok(None);
        };
    };
}
40
/// Name-based selection of model nodes, made of an include list and an exclude list.
///
/// Patterns are plain substrings, tested against the node name with `str::contains`.
///
/// * `include`: `None` lets every node through; `Some(list)` keeps only the names that
///   contain at least one entry of `list`.
/// * `exclude`: applied afterwards; a name containing any of its entries is rejected.
#[derive(Debug, Clone, Default)]
pub struct NodeFilter {
    pub include: Option<Vec<String>>,
    pub exclude: Option<Vec<String>>,
}

impl NodeFilter {
    /// Tells whether the node called `name` is selected by this filter.
    pub fn matches(&self, name: &str) -> bool {
        // Shared substring test; the closure only captures `name` by reference, so it
        // can be handed to both lookups below.
        let any_match =
            |patterns: &[String]| patterns.iter().any(|pat| name.contains(pat.as_str()));
        let included = self.include.as_deref().map_or(true, any_match);
        included && !self.exclude.as_deref().map_or(false, any_match)
    }

    /// Tells whether the filter is unconstrained, i.e. both lists are `None`.
    pub fn is_pass_through(&self) -> bool {
        matches!((&self.include, &self.exclude), (None, None))
    }
}
72
73/// Parse a legacy filter string (`"!=..."` / `"==..."`) into a `NodeFilter`.
74pub fn parse_legacy_filter(filter: Option<&str>) -> TractResult<NodeFilter> {
75    let Some(filter) = filter.filter(|f| !f.is_empty()) else {
76        return Ok(NodeFilter::default());
77    };
78    if let Some(patterns) = filter.strip_prefix("!=") {
79        let patterns = patterns.split(',').map(|it| it.trim().to_string()).collect();
80        Ok(NodeFilter { exclude: Some(patterns), ..Default::default() })
81    } else if let Some(patterns) = filter.strip_prefix("==") {
82        let patterns = patterns.split(',').map(|it| it.trim().to_string()).collect();
83        Ok(NodeFilter { include: Some(patterns), ..Default::default() })
84    } else {
85        Ok(NodeFilter::default())
86    }
87}
88
89/// Build Float precision translator given a `NodeFilter`. If the filter is pass-through,
90/// all nodes will be translated during the transformation.
91pub fn build_float_translator(
92    from_dt: DatumType,
93    to_dt: DatumType,
94    filter: NodeFilter,
95) -> Box<dyn ModelTransform> {
96    if filter.is_pass_through() {
97        return Box::new(FloatPrecisionTranslator::new(from_dt, to_dt));
98    }
99    Box::new(FloatPrecisionTranslator::with_filter(from_dt, to_dt, move |node| {
100        filter.matches(&node.name)
101    }))
102}
103
104pub trait ModelTransform: Debug {
105    fn name(&self) -> StaticName;
106    fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
107    fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
108        self.transform(&mut model)?;
109        Ok(model)
110    }
111}
112
113#[derive(Debug)]
114struct SoftmaxFastCompact;
115
116impl ModelTransform for SoftmaxFastCompact {
117    fn name(&self) -> StaticName {
118        "softmax_fast_compact".into()
119    }
120
121    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
122        for node in &mut model.nodes {
123            if let Some(softmax) = node.op_as_mut::<Softmax>() {
124                if let SoftmaxKind::Softmax(kind) = &mut softmax.kind {
125                    *kind = SoftmaxExp::FastCompact
126                }
127            }
128        }
129        Ok(())
130    }
131}
132
133/// Config for float precision transforms (f32_to_f16, f16_to_f32).
134#[derive(Debug, Default, serde::Deserialize)]
135pub struct FloatTranslatorConfig {
136    /// Legacy filter string (`"!=..."` / `"==..."`).
137    #[serde(default)]
138    pub filter: Option<String>,
139    /// Include patterns — only nodes matching at least one pattern are translated.
140    #[serde(default)]
141    pub include: Option<Vec<String>>,
142    /// Exclude patterns — matching nodes are excluded from translation.
143    #[serde(default)]
144    pub exclude: Option<Vec<String>>,
145}
146
147impl FloatTranslatorConfig {
148    pub fn into_node_filter(self) -> TractResult<NodeFilter> {
149        if self.include.is_some() || self.exclude.is_some() {
150            Ok(NodeFilter { include: self.include, exclude: self.exclude })
151        } else {
152            parse_legacy_filter(self.filter.as_deref())
153        }
154    }
155}
156
157/// Config for the `float_precision` transform.
158#[derive(Debug, serde::Deserialize)]
159pub struct FloatPrecisionConfig {
160    pub from: String,
161    pub to: String,
162    /// Include patterns — only nodes matching at least one pattern are translated.
163    #[serde(default)]
164    pub include: Option<Vec<String>>,
165    /// Exclude patterns — matching nodes are excluded from translation.
166    #[serde(default)]
167    pub exclude: Option<Vec<String>>,
168}
169
170pub struct ModelTransformFactory {
171    pub name: &'static str,
172    /// Build with default config (no params).
173    pub build_default: fn() -> TractResult<Box<dyn ModelTransform>>,
174    /// Build from a type-erased deserializer.
175    pub build: fn(&mut dyn erased_serde::Deserializer) -> TractResult<Box<dyn ModelTransform>>,
176}
177
178inventory::collect!(ModelTransformFactory);
179
/// Register a transform that takes no configuration.
///
/// `$transform` is an expression producing the transform; it is evaluated on every
/// build, and any deserializer handed to `build` is ignored.
#[macro_export]
macro_rules! register_simple_model_transform {
    ($id: expr, $transform: expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $id,
                build_default: || Ok(Box::new($transform)),
                build: |_de| Ok(Box::new($transform)),
            }
        }
    };
}
192
/// Register a transform built from a config type.
///
/// * `$id`: name of the transform.
/// * `$cfg`: config type; it is built with `default()` for the no-params path and
///   deserialized with `erased_serde::deserialize` otherwise.
/// * `$make`: expression coercible to `fn($cfg) -> TractResult<Box<dyn ModelTransform>>`.
#[macro_export]
macro_rules! register_model_transform {
    ($id:expr, $cfg:ty, $make:expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $id,
                build_default: || {
                    let config = <$cfg>::default();
                    // Typed binding so `$make` is checked against the expected signature.
                    let make: fn($cfg) -> $crate::prelude::TractResult<Box<dyn $crate::transform::ModelTransform>> = $make;
                    make(config)
                },
                build: |de: &mut dyn erased_serde::Deserializer| {
                    let config: $cfg = erased_serde::deserialize(de)
                        .map_err(|e| $crate::internal::anyhow!("deserializing transform config: {e}"))?;
                    let make: fn($cfg) -> $crate::prelude::TractResult<Box<dyn $crate::transform::ModelTransform>> = $make;
                    make(config)
                },
            }
        }
    };
}
214
/// Split a transform spec like `"f32_to_f16(filter: \"!=layer.norm\")"` into name and params.
///
/// The name is everything before the first `'('`. The params are the rest of the spec,
/// opening parenthesis included, or `""` when the spec has no parameters.
///
/// Backward compat: a kebab-case name is converted to snake_case (`f32-to-f16` becomes
/// `f32_to_f16`). This applies whether or not parameters follow, so `"f32-to-f16(...)"`
/// yields the same name as `"f32-to-f16"`. Only the name is rewritten; hyphens inside
/// the params are left alone. The name is borrowed from `spec` unless a conversion is
/// needed.
pub fn split_spec(spec: &str) -> (Cow<'_, str>, &str) {
    let (name, params) = match spec.find('(') {
        // `pos` comes from `find`, so it is always a valid char boundary.
        Some(pos) => spec.split_at(pos),
        None => (spec, ""),
    };
    let name = if name.contains('-') {
        Cow::Owned(name.replace('-', "_"))
    } else {
        Cow::Borrowed(name)
    };
    (name, params)
}
226
227/// Look up a transform by name, using default config.
228pub fn get_transform(name: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
229    let (name, _) = split_spec(name);
230    for factory in inventory::iter::<ModelTransformFactory>() {
231        if factory.name == &*name {
232            return Ok(Some((factory.build_default)()?));
233        }
234    }
235    Ok(None)
236}
237
238/// Look up a transform by name, deserializing config from the given deserializer.
239pub fn get_transform_with_params(
240    name: &str,
241    de: &mut dyn erased_serde::Deserializer,
242) -> TractResult<Option<Box<dyn ModelTransform>>> {
243    for factory in inventory::iter::<ModelTransformFactory>() {
244        if factory.name == name {
245            return Ok(Some((factory.build)(de)?));
246        }
247    }
248    Ok(None)
249}
250
251#[derive(Debug, Default, serde::Deserialize)]
252pub struct ConcretizeSymbolsConfig {
253    pub values: std::collections::HashMap<String, i64>,
254}
255
256#[derive(Debug)]
257struct ConcretizeSymbolsTransform(ConcretizeSymbolsConfig);
258
259impl ModelTransform for ConcretizeSymbolsTransform {
260    fn name(&self) -> StaticName {
261        "concretize_symbols".into()
262    }
263
264    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
265        let mut table = SymbolValues::default();
266        for (k, v) in &self.0.values {
267            table = table.with(&model.symbols.sym(k), *v);
268        }
269        *model = model.concretize_dims(&table)?;
270        Ok(())
271    }
272}
273
274register_model_transform!("concretize_symbols", ConcretizeSymbolsConfig, |config| Ok(Box::new(
275    ConcretizeSymbolsTransform(config)
276)));
277
278register_simple_model_transform!("softmax_fast_compact", SoftmaxFastCompact);
279#[cfg(feature = "blas")]
280register_simple_model_transform!("as_blas", AsBlas);
281register_simple_model_transform!("block_quant", BlockQuantTransform);
282
283inventory::submit! {
284    ModelTransformFactory {
285        name: "f32_to_f16",
286        build_default: || Ok(build_float_translator(DatumType::F32, DatumType::F16, NodeFilter::default())),
287        build: |de| {
288            let config: FloatTranslatorConfig = erased_serde::deserialize(de)
289                .map_err(|e| anyhow::anyhow!("deserializing f32_to_f16 config: {e}"))?;
290            Ok(build_float_translator(DatumType::F32, DatumType::F16, config.into_node_filter()?))
291        },
292    }
293}
294
295inventory::submit! {
296    ModelTransformFactory {
297        name: "f16_to_f32",
298        build_default: || Ok(build_float_translator(DatumType::F16, DatumType::F32, NodeFilter::default())),
299        build: |de| {
300            let config: FloatTranslatorConfig = erased_serde::deserialize(de)
301                .map_err(|e| anyhow::anyhow!("deserializing f16_to_f32 config: {e}"))?;
302            Ok(build_float_translator(DatumType::F16, DatumType::F32, config.into_node_filter()?))
303        },
304    }
305}
306
307inventory::submit! {
308    ModelTransformFactory {
309        name: "float_precision",
310        build_default: || {
311            anyhow::bail!("float_precision transform requires 'from' and 'to' parameters")
312        },
313        build: |de| {
314            let config: FloatPrecisionConfig = erased_serde::deserialize(de)
315                .map_err(|e| anyhow::anyhow!("deserializing float_precision config: {e}"))?;
316            let from_dt: DatumType = config.from.parse()
317                .map_err(|e| anyhow::anyhow!("parsing 'from' datum type: {e}"))?;
318            let to_dt: DatumType = config.to.parse()
319                .map_err(|e| anyhow::anyhow!("parsing 'to' datum type: {e}"))?;
320            let filter = NodeFilter { include: config.include, exclude: config.exclude };
321            Ok(build_float_translator(from_dt, to_dt, filter))
322        },
323    }
324}