use crate::internal::*;
use std::borrow::Cow;
use std::fmt::Debug;
use tract_data::TractResult;
use crate::floats::FloatPrecisionTranslator;
use crate::ops::nn::{Softmax, SoftmaxExp, TypedModel};
/// Looks up a built-in [`ModelTransformer`] by name.
///
/// Recognized names:
/// - `"f32-to-f16"`: float precision translation from `f32` to `f16`;
/// - `"f16-to-f32"`: float precision translation from `f16` to `f32`;
/// - `"softmax-fast-compact"`: switches softmax ops to the fast-compact exp.
///
/// Returns `None` for any other name.
pub fn get_transformer(name: &str) -> Option<Box<dyn ModelTransformer>> {
    match name {
        "f32-to-f16" => Some(Box::<FloatPrecisionTranslator<f32, f16>>::default()),
        // The type parameters must be reversed relative to "f32-to-f16";
        // previously this arm reused <f32, f16> and silently did the opposite
        // conversion.
        "f16-to-f32" => Some(Box::<FloatPrecisionTranslator<f16, f32>>::default()),
        "softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)),
        _ => None,
    }
}
/// A named rewrite pass that can be applied to a [`TypedModel`].
pub trait ModelTransformer: Debug {
    /// Human-readable identifier of this transformer.
    fn name(&self) -> Cow<str>;

    /// Applies the transformation to `model`, mutating it in place.
    fn transform(&self, model: &mut TypedModel) -> TractResult<()>;

    /// Applies the transformation to a clone of `model` and returns that clone,
    /// leaving the caller's model untouched.
    ///
    /// # Errors
    /// Propagates any error returned by [`ModelTransformer::transform`].
    fn transform_into(&self, model: &TypedModel) -> TractResult<TypedModel> {
        let mut transformed = model.clone();
        self.transform(&mut transformed).map(|()| transformed)
    }
}
/// Transformer that switches every [`Softmax`] node of a model to the
/// `SoftmaxExp::FastCompact` exponential variant.
#[derive(Debug)]
struct SoftmaxFastCompact;

impl ModelTransformer for SoftmaxFastCompact {
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("softmax-fast-compact")
    }

    /// Sets `exp` to `SoftmaxExp::FastCompact` on each node whose op is a
    /// `Softmax`; nodes holding any other op are left as they are. Always
    /// returns `Ok(())`.
    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        model
            .nodes
            .iter_mut()
            .filter_map(|node| node.op_as_mut::<Softmax>())
            .for_each(|softmax| softmax.exp = SoftmaxExp::FastCompact);
        Ok(())
    }
}