1use crate::internal::*;
2#[cfg(feature = "blas")]
3use crate::ops::einsum::as_blas::AsBlas;
4use crate::ops::matmul::de_block_quant::BlockQuantTransform;
5use num_traits::Float;
6use std::fmt::Debug;
7
8use tract_data::TractResult;
9
10use crate::floats::FloatPrecisionTranslator;
11use crate::ops::nn::{Softmax, SoftmaxExp, SoftmaxKind, TypedModel};
12
/// Early-exit guard: makes the enclosing function return `Ok(None)` unless the
/// given boolean expression holds.
///
/// The enclosing function (or closure) must therefore return a
/// `Result<Option<_>, _>`-shaped value for the expansion to type-check.
#[macro_export]
macro_rules! rule_if {
    ($condition:expr) => {
        if !$condition {
            return Ok(None);
        }
    };
}
21
/// Early-exit pattern guard: binds `pattern = value` in the caller's scope, or
/// makes the enclosing function return `Ok(None)` when the pattern does not match.
///
/// Expands to a `let … else` statement, so the names bound by the pattern stay
/// available to the statements that follow the macro invocation.
#[macro_export]
macro_rules! rule_if_let {
    ($pattern:pat = $value:expr) => {
        let $pattern = $value else {
            return Ok(None);
        };
    };
}
30
/// Early-exit `Option` guard: binds the payload of `value` to `pattern` in the
/// caller's scope, or makes the enclosing function return `Ok(None)` when
/// `value` is `None`.
///
/// Expands to `let Some(pattern) = value else { … }`, so the bound names stay
/// available to the statements that follow the macro invocation.
#[macro_export]
macro_rules! rule_if_some {
    ($pattern:pat = $value:expr) => {
        let Some($pattern) = $value else {
            return Ok(None);
        };
    };
}
39
40pub fn build_float_translator<T1: Datum + Float, T2: Datum + Float>(
47 filter_predicate: Option<&str>,
48) -> Option<Box<dyn ModelTransform>> {
49 let Some(filter_predicate) = filter_predicate.filter(|f| !f.is_empty()) else {
50 return Some(Box::<FloatPrecisionTranslator<T1, T2>>::default());
51 };
52
53 if let Some(node_name_patterns) = filter_predicate.strip_prefix("!=") {
54 let patterns =
55 node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
56 Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
57 !patterns.iter().any(|p| node.name.contains(p))
58 })))
59 } else if let Some(node_name_patterns) = filter_predicate.strip_prefix("==") {
60 let patterns =
61 node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
62 Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
63 patterns.iter().any(|p| node.name.contains(p))
64 })))
65 } else {
66 None
67 }
68}
69
70pub trait ModelTransform: Debug {
71 fn name(&self) -> StaticName;
72 fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
73 fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
74 self.transform(&mut model)?;
75 Ok(model)
76 }
77}
78
79#[derive(Debug)]
80struct SoftmaxFastCompact;
81
82impl ModelTransform for SoftmaxFastCompact {
83 fn name(&self) -> StaticName {
84 "softmax-fast-compact".into()
85 }
86
87 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
88 for node in &mut model.nodes {
89 if let Some(softmax) = node.op_as_mut::<Softmax>() {
90 if let SoftmaxKind::Softmax(kind) = &mut softmax.kind {
91 *kind = SoftmaxExp::FastCompact
92 }
93 }
94 }
95 Ok(())
96 }
97}
98
/// Registry entry describing how to build a [`ModelTransform`] from a textual spec.
///
/// Entries are gathered through `inventory` and looked up by `get_transform`.
// The `builder` function-pointer type is long but only spelled out here, so the
// clippy complexity lint is silenced rather than introducing a type alias.
#[allow(clippy::type_complexity)]
pub struct ModelTransformFactory {
    /// Name of the transform; `get_transform` selects a factory when the
    /// requested spec starts with this string.
    pub name: &'static str,
    /// Builds the transform from the spec string handed to `get_transform`.
    /// May return `Ok(None)` when no transform can be built from that spec.
    pub builder: fn(spec: &str) -> TractResult<Option<Box<dyn ModelTransform>>>,
}

// Declares `ModelTransformFactory` as an `inventory`-collected type so that
// `inventory::submit!` entries can be iterated at runtime.
inventory::collect!(ModelTransformFactory);
106
/// Registers a parameter-less transform under a fixed name.
///
/// `$name` is the factory name and `$transform` is an *expression* evaluating
/// to the transform value (not a type). The generated builder ignores the spec
/// string and always returns that value boxed.
#[macro_export]
macro_rules! register_simple_model_transform {
    ($name: expr, $transform: expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $name,
                builder: |_spec| Ok(Some(Box::new($transform)))
            }
        }
    };
}
118
119pub fn get_transform(spec: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
120 for factory in inventory::iter::<ModelTransformFactory>() {
121 if spec.starts_with(factory.name) {
122 return (factory.builder)(spec);
123 }
124 }
125 Ok(None)
126}
127
// Parameter-less transforms: each one is registered under a fixed name and its
// builder ignores the spec string.
register_simple_model_transform!("softmax-fast-compact", SoftmaxFastCompact);
// Only registered when the crate is built with the `blas` feature.
#[cfg(feature = "blas")]
register_simple_model_transform!("as-blas", AsBlas);
register_simple_model_transform!("block-quant", BlockQuantTransform);
132
133inventory::submit! {
134 ModelTransformFactory {
135 name: "f32-to-f16",
136 builder: |spec| Ok(build_float_translator::<f32,f16>(spec.strip_prefix("f32-to-f16")))
137 }
138}
139
140inventory::submit! {
141 ModelTransformFactory {
142 name: "f16-to-f32",
143 builder: |spec| Ok(build_float_translator::<f16,f32>(spec.strip_prefix("f16-to-f32")))
144 }
145}