1use crate::internal::*;
2#[cfg(feature = "blas")]
3use crate::ops::einsum::as_blas::AsBlas;
4use crate::ops::matmul::de_block_quant::BlockQuantTransform;
5use num_traits::Float;
6use std::fmt::Debug;
7
8use tract_data::TractResult;
9
10use crate::floats::FloatPrecisionTranslator;
11use crate::ops::nn::{Softmax, SoftmaxKind, SoftmaxExp, TypedModel};
12
13pub fn get_transform(name: &str) -> Option<Box<dyn ModelTransform>> {
14 match name {
15 #[cfg(feature = "blas")]
16 "as-blas" => Some(Box::<AsBlas>::default()),
17 name if name.starts_with("f32-to-f16") => {
18 build_float_translator::<f32, f16>(name.strip_prefix("f32-to-f16"))
19 }
20 name if name.starts_with("f16-to-f32") => {
21 build_float_translator::<f16, f32>(name.strip_prefix("f16-to-f32"))
22 }
23 "softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)),
24 "block-quant" => Some(Box::new(BlockQuantTransform)),
25 _ => None,
26 }
27}
28
29pub fn build_float_translator<T1: Datum + Float, T2: Datum + Float>(
36 filter_predicate: Option<&str>,
37) -> Option<Box<dyn ModelTransform>> {
38 let Some(filter_predicate) = filter_predicate.filter(|f| !f.is_empty()) else {
39 return Some(Box::<FloatPrecisionTranslator<T1, T2>>::default());
40 };
41
42 if let Some(node_name_patterns) = filter_predicate.strip_prefix("!=") {
43 let patterns =
44 node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
45 Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
46 !patterns.iter().any(|p| node.name.contains(p))
47 })))
48 } else if let Some(node_name_patterns) = filter_predicate.strip_prefix("==") {
49 let patterns =
50 node_name_patterns.split(',').map(|it| it.trim().to_string()).collect::<Vec<_>>();
51 Some(Box::new(FloatPrecisionTranslator::<T1, T2>::with_filter(move |node| {
52 patterns.iter().any(|p| node.name.contains(p))
53 })))
54 } else {
55 None
56 }
57}
58
59pub trait ModelTransform: Debug {
60 fn name(&self) -> StaticName;
61 fn transform(&self, model: &mut TypedModel) -> TractResult<()>;
62 fn transform_into(&self, mut model: TypedModel) -> TractResult<TypedModel> {
63 self.transform(&mut model)?;
64 Ok(model)
65 }
66}
67
68#[derive(Debug)]
69struct SoftmaxFastCompact;
70
71impl ModelTransform for SoftmaxFastCompact {
72 fn name(&self) -> StaticName {
73 "softmax-fast-compact".into()
74 }
75
76 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
77 for node in &mut model.nodes {
78 if let Some(softmax) = node.op_as_mut::<Softmax>() {
79 if let SoftmaxKind::Softmax(kind) = &mut softmax.kind { *kind = SoftmaxExp::FastCompact }
80 }
81 }
82 Ok(())
83 }
84}