1use std::borrow::Cow;
2
3use crate::internal::*;
4#[cfg(feature = "blas")]
5use crate::ops::einsum::as_blas::AsBlas;
6use crate::ops::matmul::de_block_quant::BlockQuantTransform;
7use std::fmt::Debug;
8
9use tract_data::TractResult;
10
11use crate::floats::FloatPrecisionTranslator;
12use crate::ops::nn::{Softmax, SoftmaxExp, SoftmaxKind, TypedModel};
13
/// Early-exit guard: returns `Ok(None)` from the enclosing function unless
/// the given expression evaluates to `true`.
///
/// The expansion contains a `return`, so the macro can only be used where
/// `Ok(None)` is a valid return value.
#[macro_export]
macro_rules! rule_if {
    ($guard:expr) => {
        if !$guard {
            return Ok(None);
        }
    };
}
22
/// Early-exit destructuring: binds the pattern against the expression with
/// `let ... else`, and returns `Ok(None)` from the enclosing function when the
/// pattern does not match.
///
/// Bindings introduced by the pattern stay in scope after the macro call.
#[macro_export]
macro_rules! rule_if_let {
    ($shape:pat = $value:expr) => {
        let $shape = $value else {
            return Ok(None);
        };
    };
}
31
/// Early-exit unwrapping: binds the pattern to the payload of a `Some(..)`
/// value, and returns `Ok(None)` from the enclosing function when the
/// expression is not `Some`.
///
/// Bindings introduced by the pattern stay in scope after the macro call.
#[macro_export]
macro_rules! rule_if_some {
    ($inner:pat = $maybe:expr) => {
        let Some($inner) = $maybe else {
            return Ok(None);
        };
    };
}
40
/// Substring-based selection of nodes by name.
///
/// Both lists are optional; a filter with neither list set accepts every name.
#[derive(Debug, Clone, Default)]
pub struct NodeFilter {
    /// When set, a name must contain at least one of these substrings.
    pub include: Option<Vec<String>>,
    /// When set, a name containing any of these substrings is rejected.
    pub exclude: Option<Vec<String>>,
}

impl NodeFilter {
    /// Returns `true` when `name` passes the include list (if any) and is not
    /// caught by the exclude list (if any).
    ///
    /// Patterns are matched with `str::contains`, i.e. as plain substrings.
    pub fn matches(&self, name: &str) -> bool {
        let any_hit = |patterns: &[String]| patterns.iter().any(|p| name.contains(p.as_str()));
        // A missing include list accepts everything; a missing exclude list rejects nothing.
        let included = self.include.as_deref().map_or(true, &any_hit);
        let excluded = self.exclude.as_deref().map_or(false, &any_hit);
        included && !excluded
    }

    /// Returns `true` when neither an include nor an exclude list is set.
    pub fn is_pass_through(&self) -> bool {
        matches!((&self.include, &self.exclude), (None, None))
    }
}
72
73pub fn parse_legacy_filter(filter: Option<&str>) -> TractResult<NodeFilter> {
75 let Some(filter) = filter.filter(|f| !f.is_empty()) else {
76 return Ok(NodeFilter::default());
77 };
78 if let Some(patterns) = filter.strip_prefix("!=") {
79 let patterns = patterns.split(',').map(|it| it.trim().to_string()).collect();
80 Ok(NodeFilter { exclude: Some(patterns), ..Default::default() })
81 } else if let Some(patterns) = filter.strip_prefix("==") {
82 let patterns = patterns.split(',').map(|it| it.trim().to_string()).collect();
83 Ok(NodeFilter { include: Some(patterns), ..Default::default() })
84 } else {
85 Ok(NodeFilter::default())
86 }
87}
88
89pub fn build_float_translator(
92 from_dt: DatumType,
93 to_dt: DatumType,
94 filter: NodeFilter,
95) -> Box<dyn ModelTransform> {
96 if filter.is_pass_through() {
97 return Box::new(FloatPrecisionTranslator::new(from_dt, to_dt));
98 }
99 Box::new(FloatPrecisionTranslator::with_filter(from_dt, to_dt, move |node| {
100 filter.matches(&node.name)
101 }))
102}
103
104pub trait ModelTransform: Debug {
105 fn name(&self) -> StaticName;
106 fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
107 fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
108 self.transform(&mut model)?;
109 Ok(model)
110 }
111}
112
113#[derive(Debug)]
114struct SoftmaxFastCompact;
115
116impl ModelTransform for SoftmaxFastCompact {
117 fn name(&self) -> StaticName {
118 "softmax_fast_compact".into()
119 }
120
121 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
122 for node in &mut model.nodes {
123 if let Some(softmax) = node.op_as_mut::<Softmax>() {
124 if let SoftmaxKind::Softmax(kind) = &mut softmax.kind {
125 *kind = SoftmaxExp::FastCompact
126 }
127 }
128 }
129 Ok(())
130 }
131}
132
133#[derive(Debug, Default, serde::Deserialize)]
135pub struct FloatTranslatorConfig {
136 #[serde(default)]
138 pub filter: Option<String>,
139 #[serde(default)]
141 pub include: Option<Vec<String>>,
142 #[serde(default)]
144 pub exclude: Option<Vec<String>>,
145}
146
147impl FloatTranslatorConfig {
148 pub fn into_node_filter(self) -> TractResult<NodeFilter> {
149 if self.include.is_some() || self.exclude.is_some() {
150 Ok(NodeFilter { include: self.include, exclude: self.exclude })
151 } else {
152 parse_legacy_filter(self.filter.as_deref())
153 }
154 }
155}
156
157#[derive(Debug, serde::Deserialize)]
159pub struct FloatPrecisionConfig {
160 pub from: String,
161 pub to: String,
162 #[serde(default)]
164 pub include: Option<Vec<String>>,
165 #[serde(default)]
167 pub exclude: Option<Vec<String>>,
168}
169
170pub struct ModelTransformFactory {
171 pub name: &'static str,
172 pub build_default: fn() -> TractResult<Box<dyn ModelTransform>>,
174 pub build: fn(&mut dyn erased_serde::Deserializer) -> TractResult<Box<dyn ModelTransform>>,
176}
177
178inventory::collect!(ModelTransformFactory);
179
/// Registers a parameterless transform under `$label`.
///
/// `$instance` is an expression producing the transform; it is boxed both
/// for the default build and for the parameterized build, where the
/// deserializer is ignored.
#[macro_export]
macro_rules! register_simple_model_transform {
    ($label:expr, $instance:expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $label,
                build_default: || Ok(Box::new($instance)),
                build: |_de| Ok(Box::new($instance)),
            }
        }
    };
}
192
/// Registers a configurable transform under `$label`.
///
/// * `$cfg` — configuration type; it must implement `Default` (used by the
///   default build) and be deserializable through `erased_serde`.
/// * `$make` — expression coercible to
///   `fn($cfg) -> TractResult<Box<dyn ModelTransform>>`, turning a
///   configuration into the transform.
///
/// NOTE(review): the expansion names `erased_serde` by bare path, so that
/// crate must be resolvable wherever the macro is invoked.
#[macro_export]
macro_rules! register_model_transform {
    ($label:expr, $cfg:ty, $make:expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $label,
                build_default: || {
                    let settings = <$cfg>::default();
                    let make: fn($cfg) -> $crate::prelude::TractResult<Box<dyn $crate::transform::ModelTransform>> = $make;
                    make(settings)
                },
                build: |de: &mut dyn erased_serde::Deserializer| {
                    // Deserialization failures are reported before the builder is touched.
                    let settings: $cfg = erased_serde::deserialize(de)
                        .map_err(|e| $crate::internal::anyhow!("deserializing transform config: {e}"))?;
                    let make: fn($cfg) -> $crate::prelude::TractResult<Box<dyn $crate::transform::ModelTransform>> = $make;
                    make(settings)
                },
            }
        }
    };
}
214
/// Splits a transform spec into `(name, params)`.
///
/// * If the spec contains `(`, the name is everything before the first `(`
///   and `params` is the remainder, opening parenthesis included. The name
///   is returned untouched in this case.
/// * Otherwise `params` is empty, and any `-` in the name is replaced by `_`
///   (allocating only when a replacement is needed).
pub fn split_spec(spec: &str) -> (Cow<'_, str>, &str) {
    match spec.find('(') {
        Some(open) => {
            let (name, params) = spec.split_at(open);
            (Cow::Borrowed(name), params)
        }
        None if spec.contains('-') => (Cow::Owned(spec.replace('-', "_")), ""),
        None => (Cow::Borrowed(spec), ""),
    }
}
226
227pub fn get_transform(name: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
229 let (name, _) = split_spec(name);
230 for factory in inventory::iter::<ModelTransformFactory>() {
231 if factory.name == &*name {
232 return Ok(Some((factory.build_default)()?));
233 }
234 }
235 Ok(None)
236}
237
238pub fn get_transform_with_params(
240 name: &str,
241 de: &mut dyn erased_serde::Deserializer,
242) -> TractResult<Option<Box<dyn ModelTransform>>> {
243 for factory in inventory::iter::<ModelTransformFactory>() {
244 if factory.name == name {
245 return Ok(Some((factory.build)(de)?));
246 }
247 }
248 Ok(None)
249}
250
251#[derive(Debug, Default, serde::Deserialize)]
252pub struct ConcretizeSymbolsConfig {
253 pub values: std::collections::HashMap<String, i64>,
254}
255
256#[derive(Debug)]
257struct ConcretizeSymbolsTransform(ConcretizeSymbolsConfig);
258
259impl ModelTransform for ConcretizeSymbolsTransform {
260 fn name(&self) -> StaticName {
261 "concretize_symbols".into()
262 }
263
264 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
265 let mut table = SymbolValues::default();
266 for (k, v) in &self.0.values {
267 table = table.with(&model.symbols.sym(k), *v);
268 }
269 *model = model.concretize_dims(&table)?;
270 Ok(())
271 }
272}
273
274register_model_transform!("concretize_symbols", ConcretizeSymbolsConfig, |config| Ok(Box::new(
275 ConcretizeSymbolsTransform(config)
276)));
277
278register_simple_model_transform!("softmax_fast_compact", SoftmaxFastCompact);
279#[cfg(feature = "blas")]
280register_simple_model_transform!("as_blas", AsBlas);
281register_simple_model_transform!("block_quant", BlockQuantTransform);
282
283inventory::submit! {
284 ModelTransformFactory {
285 name: "f32_to_f16",
286 build_default: || Ok(build_float_translator(DatumType::F32, DatumType::F16, NodeFilter::default())),
287 build: |de| {
288 let config: FloatTranslatorConfig = erased_serde::deserialize(de)
289 .map_err(|e| anyhow::anyhow!("deserializing f32_to_f16 config: {e}"))?;
290 Ok(build_float_translator(DatumType::F32, DatumType::F16, config.into_node_filter()?))
291 },
292 }
293}
294
295inventory::submit! {
296 ModelTransformFactory {
297 name: "f16_to_f32",
298 build_default: || Ok(build_float_translator(DatumType::F16, DatumType::F32, NodeFilter::default())),
299 build: |de| {
300 let config: FloatTranslatorConfig = erased_serde::deserialize(de)
301 .map_err(|e| anyhow::anyhow!("deserializing f16_to_f32 config: {e}"))?;
302 Ok(build_float_translator(DatumType::F16, DatumType::F32, config.into_node_filter()?))
303 },
304 }
305}
306
307inventory::submit! {
308 ModelTransformFactory {
309 name: "float_precision",
310 build_default: || {
311 anyhow::bail!("float_precision transform requires 'from' and 'to' parameters")
312 },
313 build: |de| {
314 let config: FloatPrecisionConfig = erased_serde::deserialize(de)
315 .map_err(|e| anyhow::anyhow!("deserializing float_precision config: {e}"))?;
316 let from_dt: DatumType = config.from.parse()
317 .map_err(|e| anyhow::anyhow!("parsing 'from' datum type: {e}"))?;
318 let to_dt: DatumType = config.to.parse()
319 .map_err(|e| anyhow::anyhow!("parsing 'to' datum type: {e}"))?;
320 let filter = NodeFilter { include: config.include, exclude: config.exclude };
321 Ok(build_float_translator(from_dt, to_dt, filter))
322 },
323 }
324}