1use std::borrow::Cow;
2
3use crate::internal::*;
4use crate::ops::matmul::de_block_quant::BlockQuantTransform;
5use std::fmt::Debug;
6
7use tract_data::TractResult;
8
9use crate::floats::FloatPrecisionTranslator;
10use crate::ops::nn::{Softmax, SoftmaxExp, SoftmaxKind, TypedModel};
11
/// Guard for rewrite rules: leaves the enclosing function with `Ok(None)`
/// when the condition is false.
///
/// The expansion contains a bare `return Ok(None)`, so the macro is only
/// usable inside a function whose return type accepts that value
/// (for instance a `Result<Option<_>, _>`).
#[macro_export]
macro_rules! rule_if {
    ($cond:expr) => {
        if !($cond) {
            return Ok(None);
        }
    };
}
20
/// Guard for rewrite rules based on a refutable pattern.
///
/// Binds `$pat` against `$expr` in the caller's scope; if the pattern does
/// not match, the enclosing function returns `Ok(None)`.
#[macro_export]
macro_rules! rule_if_let {
    ($pat:pat = $expr:expr) => {
        let $pat = $expr else { return Ok(None) };
    };
}
29
/// Guard for rewrite rules based on an `Option`.
///
/// Binds `$pat` to the content of `Some(..)` in the caller's scope; if
/// `$expr` is `None` (or the inner pattern does not match), the enclosing
/// function returns `Ok(None)`.
#[macro_export]
macro_rules! rule_if_some {
    ($pat:pat = $expr:expr) => {
        let Some($pat) = $expr else { return Ok(None) };
    };
}
38
/// Name-based node selector built from substring patterns.
///
/// A name is selected when it contains at least one `include` pattern (or
/// `include` is `None`) and contains none of the `exclude` patterns (or
/// `exclude` is `None`).
#[derive(Debug, Clone, Default)]
pub struct NodeFilter {
    /// When set, only names containing one of these substrings are selected.
    pub include: Option<Vec<String>>,
    /// When set, names containing any of these substrings are rejected.
    pub exclude: Option<Vec<String>>,
}

impl NodeFilter {
    /// Returns `true` when `name` passes both the include and the exclude lists.
    pub fn matches(&self, name: &str) -> bool {
        // Shared predicate: does `name` contain any of the given substrings?
        let any_hit = |patterns: &[String]| patterns.iter().any(|pat| name.contains(pat.as_str()));
        self.include.as_deref().map_or(true, any_hit)
            && !self.exclude.as_deref().map_or(false, any_hit)
    }

    /// Returns `true` when neither list is set, i.e. every name is selected.
    pub fn is_pass_through(&self) -> bool {
        matches!((&self.include, &self.exclude), (None, None))
    }
}
70
71pub fn parse_legacy_filter(filter: Option<&str>) -> TractResult<NodeFilter> {
73 let Some(filter) = filter.filter(|f| !f.is_empty()) else {
74 return Ok(NodeFilter::default());
75 };
76 if let Some(patterns) = filter.strip_prefix("!=") {
77 let patterns = patterns.split(',').map(|it| it.trim().to_string()).collect();
78 Ok(NodeFilter { exclude: Some(patterns), ..Default::default() })
79 } else if let Some(patterns) = filter.strip_prefix("==") {
80 let patterns = patterns.split(',').map(|it| it.trim().to_string()).collect();
81 Ok(NodeFilter { include: Some(patterns), ..Default::default() })
82 } else {
83 Ok(NodeFilter::default())
84 }
85}
86
87pub fn build_float_translator(
90 from_dt: DatumType,
91 to_dt: DatumType,
92 filter: NodeFilter,
93) -> Box<dyn ModelTransform> {
94 if filter.is_pass_through() {
95 return Box::new(FloatPrecisionTranslator::new(from_dt, to_dt));
96 }
97 Box::new(FloatPrecisionTranslator::with_filter(from_dt, to_dt, move |node| {
98 filter.matches(&node.name)
99 }))
100}
101
102pub trait ModelTransform: Debug {
103 fn name(&self) -> StaticName;
104 fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
105 fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
106 self.transform(&mut model)?;
107 Ok(model)
108 }
109}
110
111#[derive(Debug)]
112struct SoftmaxFastCompact;
113
114impl ModelTransform for SoftmaxFastCompact {
115 fn name(&self) -> StaticName {
116 "softmax_fast_compact".into()
117 }
118
119 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
120 for node in &mut model.nodes {
121 if let Some(softmax) = node.op_as_mut::<Softmax>()
122 && let SoftmaxKind::Softmax(kind) = &mut softmax.kind
123 {
124 *kind = SoftmaxExp::FastCompact
125 }
126 }
127 Ok(())
128 }
129}
130
131#[derive(Debug, Default, serde::Deserialize)]
133pub struct FloatTranslatorConfig {
134 #[serde(default)]
136 pub filter: Option<String>,
137 #[serde(default)]
139 pub include: Option<Vec<String>>,
140 #[serde(default)]
142 pub exclude: Option<Vec<String>>,
143}
144
145impl FloatTranslatorConfig {
146 pub fn into_node_filter(self) -> TractResult<NodeFilter> {
147 if self.include.is_some() || self.exclude.is_some() {
148 Ok(NodeFilter { include: self.include, exclude: self.exclude })
149 } else {
150 parse_legacy_filter(self.filter.as_deref())
151 }
152 }
153}
154
155#[derive(Debug, serde::Deserialize)]
157pub struct FloatPrecisionConfig {
158 pub from: String,
159 pub to: String,
160 #[serde(default)]
162 pub include: Option<Vec<String>>,
163 #[serde(default)]
165 pub exclude: Option<Vec<String>>,
166}
167
168pub struct ModelTransformFactory {
169 pub name: &'static str,
170 pub build_default: fn() -> TractResult<Box<dyn ModelTransform>>,
172 pub build: fn(&mut dyn erased_serde::Deserializer) -> TractResult<Box<dyn ModelTransform>>,
174}
175
176inventory::collect!(ModelTransformFactory);
177
/// Registers a parameter-less transform under `$name`.
///
/// `$type` is an expression producing the transform value; it is evaluated
/// each time the factory is invoked. Any deserializer handed to `build` is
/// ignored.
#[macro_export]
macro_rules! register_simple_model_transform {
    ($name: expr, $type: expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $name,
                build_default: || {
                    Ok(Box::new($type))
                },
                build: |_params| {
                    Ok(Box::new($type))
                },
            }
        }
    };
}
190
/// Registers a configurable transform under `$name`.
///
/// `$config` is the configuration type and `$builder` turns a configuration
/// into a boxed transform. `build_default` feeds the builder with
/// `<$config>::default()`, while `build` first deserializes a `$config` from
/// the erased deserializer.
#[macro_export]
macro_rules! register_model_transform {
    ($name:expr, $config:ty, $builder:expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $name,
                build_default: || {
                    let defaults = <$config>::default();
                    let make: fn(
                        $config,
                    ) -> $crate::prelude::TractResult<
                        Box<dyn $crate::transform::ModelTransform>,
                    > = $builder;
                    make(defaults)
                },
                build: |de: &mut dyn erased_serde::Deserializer| {
                    let parsed: $config = erased_serde::deserialize(de).map_err(|e| {
                        $crate::internal::anyhow!("deserializing transform config: {e}")
                    })?;
                    let make: fn(
                        $config,
                    ) -> $crate::prelude::TractResult<
                        Box<dyn $crate::transform::ModelTransform>,
                    > = $builder;
                    make(parsed)
                },
            }
        }
    };
}
212
/// Splits a transform spec into `(name, params)`.
///
/// - With an opening parenthesis, the name is everything before the first
///   `(` (returned untouched) and `params` is the rest, parenthesis included.
/// - Without parenthesis, `params` is empty and any `-` in the name is
///   replaced by `_` (allocating only in that case).
pub fn split_spec(spec: &str) -> (Cow<'_, str>, &str) {
    match spec.find('(') {
        Some(paren) => {
            let (name, params) = spec.split_at(paren);
            (Cow::Borrowed(name), params)
        }
        None if spec.contains('-') => (Cow::Owned(spec.replace('-', "_")), ""),
        None => (Cow::Borrowed(spec), ""),
    }
}
224
225pub fn get_transform(name: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
227 let (name, _) = split_spec(name);
228 for factory in inventory::iter::<ModelTransformFactory>() {
229 if factory.name == &*name {
230 return Ok(Some((factory.build_default)()?));
231 }
232 }
233 Ok(None)
234}
235
236pub fn get_transform_with_params(
238 name: &str,
239 de: &mut dyn erased_serde::Deserializer,
240) -> TractResult<Option<Box<dyn ModelTransform>>> {
241 for factory in inventory::iter::<ModelTransformFactory>() {
242 if factory.name == name {
243 return Ok(Some((factory.build)(de)?));
244 }
245 }
246 Ok(None)
247}
248
249#[derive(Debug, serde::Deserialize)]
252#[serde(untagged)]
253pub enum SymbolValueSpec {
254 Int(i64),
255 Expr(String),
256}
257
258#[derive(Debug, Default, serde::Deserialize)]
259pub struct SetSymbolsConfig {
260 pub values: std::collections::HashMap<String, SymbolValueSpec>,
261}
262
263#[derive(Debug)]
264struct SetSymbolsTransform(SetSymbolsConfig);
265
266impl ModelTransform for SetSymbolsTransform {
267 fn name(&self) -> StaticName {
268 "set_symbols".into()
269 }
270
271 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
272 let mut subs = std::collections::HashMap::new();
273 for (k, spec) in &self.0.values {
274 let sym = model.symbols.sym(k);
275 let dim = match spec {
276 SymbolValueSpec::Int(v) => TDim::Val(*v),
277 SymbolValueSpec::Expr(s) => model
278 .symbols
279 .parse_tdim(s)
280 .with_context(|| format!("Parsing TDim expression {s:?} for symbol {k}"))?,
281 };
282 subs.insert(sym, dim);
283 }
284 *model = model.set_symbols(&subs)?;
285 Ok(())
286 }
287}
288
289register_model_transform!("set_symbols", SetSymbolsConfig, |config| Ok(Box::new(
290 SetSymbolsTransform(config)
291)));
292
293#[derive(Debug)]
305struct ForceScanExternalState;
306
307impl ModelTransform for ForceScanExternalState {
308 fn name(&self) -> StaticName {
309 "force_scan_external_state".into()
310 }
311
312 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
313 use crate::ops::scan::{InputMapping, Scan};
314 let mut subs: HashMap<Symbol, TDim> = HashMap::new();
315 for node in &model.nodes {
316 let Some(scan) = node.op_as::<Scan>() else { continue };
317 for (slot, mapping) in scan.input_mapping.iter().enumerate() {
318 let InputMapping::Scan(info) = mapping else { continue };
319 let outer = node.inputs[slot];
320 let dim = &model.outlet_fact(outer)?.shape[info.axis];
321 if let TDim::Sym(s) = dim {
322 subs.insert(s.clone(), TDim::Val(1));
323 }
324 }
325 }
326 if !subs.is_empty() {
327 *model = model.set_symbols(&subs)?;
328 }
329 for node in &mut model.nodes {
330 if let Some(scan) = node.op_as_mut::<Scan>() {
331 scan.external_state = true;
332 }
333 }
334 Ok(())
335 }
336}
337
338register_simple_model_transform!("force_scan_external_state", ForceScanExternalState);
339
340register_simple_model_transform!("softmax_fast_compact", SoftmaxFastCompact);
341register_simple_model_transform!("block_quant", BlockQuantTransform);
342
343#[derive(Debug, serde::Deserialize, Default)]
344pub struct SelectOutputsConfig {
345 pub outputs: Vec<String>,
346}
347
348#[derive(Debug)]
349struct SelectOutputsTransform(SelectOutputsConfig);
350
351impl ModelTransform for SelectOutputsTransform {
352 fn name(&self) -> StaticName {
353 "select_outputs".into()
354 }
355
356 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
357 model.select_outputs_by_name(self.0.outputs.iter())
358 }
359}
360
361register_model_transform!("select_outputs", SelectOutputsConfig, |config| Ok(Box::new(
362 SelectOutputsTransform(config)
363)));
364
365#[derive(Debug, serde::Deserialize, Default)]
366pub struct SelectInputsConfig {
367 pub inputs: Vec<String>,
368}
369
370#[derive(Debug)]
371struct SelectInputsTransform(SelectInputsConfig);
372
373impl ModelTransform for SelectInputsTransform {
374 fn name(&self) -> StaticName {
375 "select_inputs".into()
376 }
377
378 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
379 model.select_inputs_by_name(self.0.inputs.iter())
380 }
381}
382
383register_model_transform!("select_inputs", SelectInputsConfig, |config| Ok(Box::new(
384 SelectInputsTransform(config)
385)));
386
387inventory::submit! {
388 ModelTransformFactory {
389 name: "f32_to_f16",
390 build_default: || Ok(build_float_translator(DatumType::F32, DatumType::F16, NodeFilter::default())),
391 build: |de| {
392 let config: FloatTranslatorConfig = erased_serde::deserialize(de)
393 .map_err(|e| anyhow::anyhow!("deserializing f32_to_f16 config: {e}"))?;
394 Ok(build_float_translator(DatumType::F32, DatumType::F16, config.into_node_filter()?))
395 },
396 }
397}
398
399inventory::submit! {
400 ModelTransformFactory {
401 name: "f16_to_f32",
402 build_default: || Ok(build_float_translator(DatumType::F16, DatumType::F32, NodeFilter::default())),
403 build: |de| {
404 let config: FloatTranslatorConfig = erased_serde::deserialize(de)
405 .map_err(|e| anyhow::anyhow!("deserializing f16_to_f32 config: {e}"))?;
406 Ok(build_float_translator(DatumType::F16, DatumType::F32, config.into_node_filter()?))
407 },
408 }
409}
410
411inventory::submit! {
412 ModelTransformFactory {
413 name: "float_precision",
414 build_default: || {
415 anyhow::bail!("float_precision transform requires 'from' and 'to' parameters")
416 },
417 build: |de| {
418 let config: FloatPrecisionConfig = erased_serde::deserialize(de)
419 .map_err(|e| anyhow::anyhow!("deserializing float_precision config: {e}"))?;
420 let from_dt: DatumType = config.from.parse()
421 .map_err(|e| anyhow::anyhow!("parsing 'from' datum type: {e}"))?;
422 let to_dt: DatumType = config.to.parse()
423 .map_err(|e| anyhow::anyhow!("parsing 'to' datum type: {e}"))?;
424 let filter = NodeFilter { include: config.include, exclude: config.exclude };
425 Ok(build_float_translator(from_dt, to_dt, filter))
426 },
427 }
428}