1use crate::internal::*;
2#[cfg(feature = "blas")]
3use crate::ops::einsum::as_blas::AsBlas;
4use crate::ops::matmul::de_block_quant::BlockQuantTransform;
5use num_traits::Float;
6use std::borrow::Cow;
7use std::fmt::Debug;
8
9use tract_data::TractResult;
10
11use crate::floats::FloatPrecisionTranslator;
12use crate::ops::nn::{Softmax, SoftmaxExp, TypedModel};
13
14pub fn get_transform(name: &str) -> Option<Box<dyn ModelTransform>> {
15 match name {
16 #[cfg(feature = "blas")]
17 "as-blas" => Some(Box::<AsBlas>::default()),
18 name if name.starts_with("f32-to-f16") => {
19 build_float_translator::<f32, f16>(name.strip_prefix("f32-to-f16"))
20 }
21 name if name.starts_with("f16-to-f32") => {
22 build_float_translator::<f16, f32>(name.strip_prefix("f16-to-f32"))
23 }
24 "softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)),
25 "block-quant" => Some(Box::new(BlockQuantTransform)),
26 _ => None,
27 }
28}
29
30pub fn build_float_translator<T1: Datum + Float, T2: Datum + Float>(
37 filter_predicate: Option<&str>,
38) -> Option<Box<dyn ModelTransform>> {
39 let Some(filter_predicate) = filter_predicate.filter(|f| !f.is_empty()) else {
40 return Some(Box::<FloatPrecisionTranslator<T1, T2>>::default());
41 };
42
43 if let Some(node_name_patterns) = filter_predicate.strip_prefix("!=") {
44 let patterns =
45 node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
46 Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
47 !patterns.iter().any(|p| node.name.contains(p))
48 })))
49 } else if let Some(node_name_patterns) = filter_predicate.strip_prefix("==") {
50 let patterns =
51 node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
52 Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
53 patterns.iter().any(|p| node.name.contains(p))
54 })))
55 } else {
56 None
57 }
58}
59
60pub trait ModelTransform: Debug {
61 fn name(&self) -> Cow<str>;
62 fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
63 fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
64 self.transform(&mut model)?;
65 Ok(model)
66 }
67}
68
69#[derive(Debug)]
70struct SoftmaxFastCompact;
71
72impl ModelTransform for SoftmaxFastCompact {
73 fn name(&self) -> Cow<str> {
74 "softmax-fast-compact".into()
75 }
76
77 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
78 for node in &mut model.nodes {
79 if let Some(softmax) = node.op_as_mut::<Softmax>() {
80 softmax.exp = SoftmaxExp::FastCompact;
81 }
82 }
83 Ok(())
84 }
85}