Skip to main content

tract_core/
transform.rs

1use std::borrow::Cow;
2
3use crate::internal::*;
4use crate::ops::matmul::de_block_quant::BlockQuantTransform;
5use std::fmt::Debug;
6
7use tract_data::TractResult;
8
9use crate::floats::FloatPrecisionTranslator;
10use crate::ops::nn::{Softmax, SoftmaxExp, SoftmaxKind, TypedModel};
11
/// Guard clause for rule-style functions: unless the boolean expression holds,
/// leave the enclosing function early with `Ok(None)`.
///
/// The expansion contains a `return`, so the macro is only usable where
/// `Ok(None)` is a valid return value.
#[macro_export]
macro_rules! rule_if {
    ($guard:expr) => {
        if !$guard {
            return Ok(None);
        }
    };
}
20
/// Guard clause binding a refutable pattern: `rule_if_let!(PAT = EXPR)` binds
/// `PAT` in the enclosing scope when `EXPR` matches, and otherwise leaves the
/// enclosing function early with `Ok(None)`.
#[macro_export]
macro_rules! rule_if_let {
    ($binding:pat = $value:expr) => {
        let $binding = $value else {
            return Ok(None);
        };
    };
}
29
/// Guard clause unwrapping an `Option`: `rule_if_some!(PAT = EXPR)` binds `PAT`
/// to the payload of `Some(..)`, and leaves the enclosing function early with
/// `Ok(None)` when `EXPR` does not match `Some(PAT)`.
#[macro_export]
macro_rules! rule_if_some {
    ($binding:pat = $value:expr) => {
        let Some($binding) = $value else {
            return Ok(None);
        };
    };
}
38
/// Include/exclude filter over node names, based on substring matching.
///
/// * `include == None`: every node is a candidate. `include == Some(list)`: a node
///   is a candidate only if its name contains at least one entry of `list`.
/// * `exclude` is applied afterwards and drops every candidate whose name contains
///   at least one of its entries.
#[derive(Debug, Clone, Default)]
pub struct NodeFilter {
    pub include: Option<Vec<String>>,
    pub exclude: Option<Vec<String>>,
}

impl NodeFilter {
    /// Tells whether `name` is selected by this filter.
    pub fn matches(&self, name: &str) -> bool {
        // True when `name` contains at least one of the given patterns as a substring.
        fn hits_any(patterns: &[String], name: &str) -> bool {
            patterns.iter().any(|pattern| name.contains(pattern.as_str()))
        }

        let included = match self.include.as_deref() {
            None => true,
            Some(patterns) => hits_any(patterns, name),
        };
        let excluded = match self.exclude.as_deref() {
            None => false,
            Some(patterns) => hits_any(patterns, name),
        };
        included && !excluded
    }

    /// Tells whether the filter is unconstrained, i.e. both `include` and `exclude`
    /// are `None`.
    pub fn is_pass_through(&self) -> bool {
        matches!((&self.include, &self.exclude), (None, None))
    }
}
70
71/// Parse a legacy filter string (`"!=..."` / `"==..."`) into a `NodeFilter`.
72pub fn parse_legacy_filter(filter: Option<&str>) -> TractResult<NodeFilter> {
73    let Some(filter) = filter.filter(|f| !f.is_empty()) else {
74        return Ok(NodeFilter::default());
75    };
76    if let Some(patterns) = filter.strip_prefix("!=") {
77        let patterns = patterns.split(',').map(|it| it.trim().to_string()).collect();
78        Ok(NodeFilter { exclude: Some(patterns), ..Default::default() })
79    } else if let Some(patterns) = filter.strip_prefix("==") {
80        let patterns = patterns.split(',').map(|it| it.trim().to_string()).collect();
81        Ok(NodeFilter { include: Some(patterns), ..Default::default() })
82    } else {
83        Ok(NodeFilter::default())
84    }
85}
86
87/// Build Float precision translator given a `NodeFilter`. If the filter is pass-through,
88/// all nodes will be translated during the transformation.
89pub fn build_float_translator(
90    from_dt: DatumType,
91    to_dt: DatumType,
92    filter: NodeFilter,
93) -> Box<dyn ModelTransform> {
94    if filter.is_pass_through() {
95        return Box::new(FloatPrecisionTranslator::new(from_dt, to_dt));
96    }
97    Box::new(FloatPrecisionTranslator::with_filter(from_dt, to_dt, move |node| {
98        filter.matches(&node.name)
99    }))
100}
101
102pub trait ModelTransform: Debug {
103    fn name(&self) -> StaticName;
104    fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
105    fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
106        self.transform(&mut model)?;
107        Ok(model)
108    }
109}
110
111#[derive(Debug)]
112struct SoftmaxFastCompact;
113
114impl ModelTransform for SoftmaxFastCompact {
115    fn name(&self) -> StaticName {
116        "softmax_fast_compact".into()
117    }
118
119    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
120        for node in &mut model.nodes {
121            if let Some(softmax) = node.op_as_mut::<Softmax>()
122                && let SoftmaxKind::Softmax(kind) = &mut softmax.kind
123            {
124                *kind = SoftmaxExp::FastCompact
125            }
126        }
127        Ok(())
128    }
129}
130
131/// Config for float precision transforms (f32_to_f16, f16_to_f32).
132#[derive(Debug, Default, serde::Deserialize)]
133pub struct FloatTranslatorConfig {
134    /// Legacy filter string (`"!=..."` / `"==..."`).
135    #[serde(default)]
136    pub filter: Option<String>,
137    /// Include patterns — only nodes matching at least one pattern are translated.
138    #[serde(default)]
139    pub include: Option<Vec<String>>,
140    /// Exclude patterns — matching nodes are excluded from translation.
141    #[serde(default)]
142    pub exclude: Option<Vec<String>>,
143}
144
145impl FloatTranslatorConfig {
146    pub fn into_node_filter(self) -> TractResult<NodeFilter> {
147        if self.include.is_some() || self.exclude.is_some() {
148            Ok(NodeFilter { include: self.include, exclude: self.exclude })
149        } else {
150            parse_legacy_filter(self.filter.as_deref())
151        }
152    }
153}
154
/// Config for the `float_precision` transform.
///
/// Both endpoint types must be given explicitly: the struct has no `Default`,
/// and the `float_precision` factory refuses to build without parameters.
#[derive(Debug, serde::Deserialize)]
pub struct FloatPrecisionConfig {
    /// Name of the datum type to translate from; the factory parses it into a `DatumType`.
    pub from: String,
    /// Name of the datum type to translate to; the factory parses it into a `DatumType`.
    pub to: String,
    /// Include patterns — only nodes matching at least one pattern are translated.
    #[serde(default)]
    pub include: Option<Vec<String>>,
    /// Exclude patterns — matching nodes are excluded from translation.
    #[serde(default)]
    pub exclude: Option<Vec<String>>,
}
167
/// Registry entry describing how to build a [`ModelTransform`] by name.
///
/// Entries are gathered through `inventory` and looked up by `get_transform` /
/// `get_transform_with_params`.
pub struct ModelTransformFactory {
    /// Name under which the transform is looked up.
    pub name: &'static str,
    /// Build with default config (no params).
    pub build_default: fn() -> TractResult<Box<dyn ModelTransform>>,
    /// Build from a type-erased deserializer.
    pub build: fn(&mut dyn erased_serde::Deserializer) -> TractResult<Box<dyn ModelTransform>>,
}

// Declare `ModelTransformFactory` as an `inventory`-collected type, so that
// `inventory::submit!` entries can be enumerated with `inventory::iter`.
inventory::collect!(ModelTransformFactory);
177
/// Register a transform that takes no configuration.
///
/// `$name` is the lookup name; `$type` is an expression evaluated on each build
/// to produce the transform value. Both `build_default` and `build` box that
/// value; `build` ignores the deserializer it receives.
#[macro_export]
macro_rules! register_simple_model_transform {
    ($name: expr, $type: expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $name,
                build_default: || Ok(Box::new($type)),
                build: |_de| Ok(Box::new($type)),
            }
        }
    };
}
190
/// Register a transform built from a typed configuration.
///
/// * `$name`: lookup name.
/// * `$config`: config type; `build_default` calls `<$config>::default()` and
///   `build` obtains it through `erased_serde::deserialize`, so the type has to
///   support both.
/// * `$builder`: expression assigned to a
///   `fn($config) -> TractResult<Box<dyn ModelTransform>>`, turning a config into the transform.
///
/// A deserialization failure is reported as `"deserializing transform config: {e}"`.
#[macro_export]
macro_rules! register_model_transform {
    ($name:expr, $config:ty, $builder:expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $name,
                build_default: || {
                    let config = <$config>::default();
                    // Pin the builder to a fn pointer type so closure parameter types are known.
                    let builder: fn($config) -> $crate::prelude::TractResult<Box<dyn $crate::transform::ModelTransform>> = $builder;
                    builder(config)
                },
                build: |de: &mut dyn erased_serde::Deserializer| {
                    let config: $config = erased_serde::deserialize(de)
                        .map_err(|e| $crate::internal::anyhow!("deserializing transform config: {e}"))?;
                    let builder: fn($config) -> $crate::prelude::TractResult<Box<dyn $crate::transform::ModelTransform>> = $builder;
                    builder(config)
                },
            }
        }
    };
}
212
/// Split a transform spec like `"f32_to_f16(filter: \"!=layer.norm\")"` into name and params.
///
/// * With an opening parenthesis, the name is everything before it and the params
///   are the rest of the spec, parenthesis included. The name is returned untouched.
/// * Without one, the params are empty. For backward compatibility a kebab-case
///   name is converted to snake_case (the only case that allocates).
pub fn split_spec(spec: &str) -> (Cow<'_, str>, &str) {
    match spec.find('(') {
        Some(open) => {
            let (name, params) = spec.split_at(open);
            (Cow::Borrowed(name), params)
        }
        // Backward compat: simple name with no params, convert kebab→snake
        None if spec.contains('-') => (Cow::Owned(spec.replace('-', "_")), ""),
        None => (Cow::Borrowed(spec), ""),
    }
}
224
225/// Look up a transform by name, using default config.
226pub fn get_transform(name: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
227    let (name, _) = split_spec(name);
228    for factory in inventory::iter::<ModelTransformFactory>() {
229        if factory.name == &*name {
230            return Ok(Some((factory.build_default)()?));
231        }
232    }
233    Ok(None)
234}
235
236/// Look up a transform by name, deserializing config from the given deserializer.
237pub fn get_transform_with_params(
238    name: &str,
239    de: &mut dyn erased_serde::Deserializer,
240) -> TractResult<Option<Box<dyn ModelTransform>>> {
241    for factory in inventory::iter::<ModelTransformFactory>() {
242        if factory.name == name {
243            return Ok(Some((factory.build)(de)?));
244        }
245    }
246    Ok(None)
247}
248
/// Per-symbol substitution: either a concrete integer or a TDim
/// expression string parsed against the model's symbol scope.
///
/// Deserialized untagged: the variant is picked from the shape of the value,
/// not from an explicit tag.
#[derive(Debug, serde::Deserialize)]
#[serde(untagged)]
pub enum SymbolValueSpec {
    /// Concrete value, used as `TDim::Val`.
    Int(i64),
    /// Expression text, parsed with the model's `symbols.parse_tdim`.
    Expr(String),
}
257
/// Config for the `set_symbols` transform.
#[derive(Debug, Default, serde::Deserialize)]
pub struct SetSymbolsConfig {
    /// Substitutions to apply, keyed by symbol name.
    pub values: std::collections::HashMap<String, SymbolValueSpec>,
}
262
263#[derive(Debug)]
264struct SetSymbolsTransform(SetSymbolsConfig);
265
266impl ModelTransform for SetSymbolsTransform {
267    fn name(&self) -> StaticName {
268        "set_symbols".into()
269    }
270
271    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
272        let mut subs = std::collections::HashMap::new();
273        for (k, spec) in &self.0.values {
274            let sym = model.symbols.sym(k);
275            let dim = match spec {
276                SymbolValueSpec::Int(v) => TDim::Val(*v),
277                SymbolValueSpec::Expr(s) => model
278                    .symbols
279                    .parse_tdim(s)
280                    .with_context(|| format!("Parsing TDim expression {s:?} for symbol {k}"))?,
281            };
282            subs.insert(sym, dim);
283        }
284        *model = model.set_symbols(&subs)?;
285        Ok(())
286    }
287}
288
289register_model_transform!("set_symbols", SetSymbolsConfig, |config| Ok(Box::new(
290    SetSymbolsTransform(config)
291)));
292
293/// Ad-hoc fix-up for NNEF artifacts exported before Scan grew the
294/// `external_state` flag (issue #2157). For every Scan in the model:
295/// 1. Substitute the scan-axis symbol on the Scan input with 1 across the
296///    whole model (caller is bound by the per-call seq=1 contract that
297///    external state management implies).
298/// 2. Set `external_state = true`.
299///
300/// After this transform, the standard declutter pipeline sees `iters == 1`
301/// on each Scan and `declutter_single_loop` inlines the body. Apply only
302/// when the loaded model is known to use external state management, e.g.
303/// the parakeet decoder. Cheaper than re-exporting cached NNEF.
304#[derive(Debug)]
305struct ForceScanExternalState;
306
307impl ModelTransform for ForceScanExternalState {
308    fn name(&self) -> StaticName {
309        "force_scan_external_state".into()
310    }
311
312    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
313        use crate::ops::scan::{InputMapping, Scan};
314        let mut subs: HashMap<Symbol, TDim> = HashMap::new();
315        for node in &model.nodes {
316            let Some(scan) = node.op_as::<Scan>() else { continue };
317            for (slot, mapping) in scan.input_mapping.iter().enumerate() {
318                let InputMapping::Scan(info) = mapping else { continue };
319                let outer = node.inputs[slot];
320                let dim = &model.outlet_fact(outer)?.shape[info.axis];
321                if let TDim::Sym(s) = dim {
322                    subs.insert(s.clone(), TDim::Val(1));
323                }
324            }
325        }
326        if !subs.is_empty() {
327            *model = model.set_symbols(&subs)?;
328        }
329        for node in &mut model.nodes {
330            if let Some(scan) = node.op_as_mut::<Scan>() {
331                scan.external_state = true;
332            }
333        }
334        Ok(())
335    }
336}
337
338register_simple_model_transform!("force_scan_external_state", ForceScanExternalState);
339
// Parameter-less transforms: registered by name, any supplied config is ignored.
register_simple_model_transform!("softmax_fast_compact", SoftmaxFastCompact);
register_simple_model_transform!("block_quant", BlockQuantTransform);
342
343#[derive(Debug, serde::Deserialize, Default)]
344pub struct SelectOutputsConfig {
345    pub outputs: Vec<String>,
346}
347
348#[derive(Debug)]
349struct SelectOutputsTransform(SelectOutputsConfig);
350
351impl ModelTransform for SelectOutputsTransform {
352    fn name(&self) -> StaticName {
353        "select_outputs".into()
354    }
355
356    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
357        model.select_outputs_by_name(self.0.outputs.iter())
358    }
359}
360
361register_model_transform!("select_outputs", SelectOutputsConfig, |config| Ok(Box::new(
362    SelectOutputsTransform(config)
363)));
364
365#[derive(Debug, serde::Deserialize, Default)]
366pub struct SelectInputsConfig {
367    pub inputs: Vec<String>,
368}
369
370#[derive(Debug)]
371struct SelectInputsTransform(SelectInputsConfig);
372
373impl ModelTransform for SelectInputsTransform {
374    fn name(&self) -> StaticName {
375        "select_inputs".into()
376    }
377
378    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
379        model.select_inputs_by_name(self.0.inputs.iter())
380    }
381}
382
383register_model_transform!("select_inputs", SelectInputsConfig, |config| Ok(Box::new(
384    SelectInputsTransform(config)
385)));
386
// `f32_to_f16`: translate F32 to F16. Without params every node is translated;
// with params, nodes are selected through `FloatTranslatorConfig`.
inventory::submit! {
    ModelTransformFactory {
        name: "f32_to_f16",
        build_default: || Ok(build_float_translator(DatumType::F32, DatumType::F16, NodeFilter::default())),
        build: |de| {
            let config: FloatTranslatorConfig = erased_serde::deserialize(de)
                .map_err(|e| anyhow::anyhow!("deserializing f32_to_f16 config: {e}"))?;
            Ok(build_float_translator(DatumType::F32, DatumType::F16, config.into_node_filter()?))
        },
    }
}
398
// `f16_to_f32`: translate F16 to F32. Without params every node is translated;
// with params, nodes are selected through `FloatTranslatorConfig`.
inventory::submit! {
    ModelTransformFactory {
        name: "f16_to_f32",
        build_default: || Ok(build_float_translator(DatumType::F16, DatumType::F32, NodeFilter::default())),
        build: |de| {
            let config: FloatTranslatorConfig = erased_serde::deserialize(de)
                .map_err(|e| anyhow::anyhow!("deserializing f16_to_f32 config: {e}"))?;
            Ok(build_float_translator(DatumType::F16, DatumType::F32, config.into_node_filter()?))
        },
    }
}
410
// `float_precision`: translation between two datum types named in the config.
inventory::submit! {
    ModelTransformFactory {
        name: "float_precision",
        // No sensible default exists: the endpoint types have to be provided.
        build_default: || {
            anyhow::bail!("float_precision transform requires 'from' and 'to' parameters")
        },
        build: |de| {
            let config: FloatPrecisionConfig = erased_serde::deserialize(de)
                .map_err(|e| anyhow::anyhow!("deserializing float_precision config: {e}"))?;
            // Endpoint types come in as strings and are parsed into `DatumType`.
            let from_dt: DatumType = config.from.parse()
                .map_err(|e| anyhow::anyhow!("parsing 'from' datum type: {e}"))?;
            let to_dt: DatumType = config.to.parse()
                .map_err(|e| anyhow::anyhow!("parsing 'to' datum type: {e}"))?;
            // Only the structured include/exclude form is supported here (no legacy filter string).
            let filter = NodeFilter { include: config.include, exclude: config.exclude };
            Ok(build_float_translator(from_dt, to_dt, filter))
        },
    }
}