1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
use crate::internal::*;
use std::borrow::Cow;
use std::fmt::Debug;

use tract_data::TractResult;

use crate::floats::FloatPrecisionTranslator;
use crate::ops::nn::{Softmax, SoftmaxExp, TypedModel};

/// Looks up a built-in [`ModelTransformer`] by name.
///
/// Recognised names:
/// - `"f32-to-f16"`: a `FloatPrecisionTranslator<f32, f16>`,
/// - `"f16-to-f32"`: a `FloatPrecisionTranslator<f16, f32>`,
/// - `"softmax-fast-compact"`: [`SoftmaxFastCompact`], which switches every
///   `Softmax` node to `SoftmaxExp::FastCompact`.
///
/// Returns `None` for any other name.
pub fn get_transformer(name: &str) -> Option<Box<dyn ModelTransformer>> {
    match name {
        "f32-to-f16" => Some(Box::<FloatPrecisionTranslator<f32, f16>>::default()),
        // Following the "f32-to-f16" arm, the type parameters read <from, to>:
        // this arm must translate f16 back to f32, not repeat the translator above.
        "f16-to-f32" => Some(Box::<FloatPrecisionTranslator<f16, f32>>::default()),
        "softmax-fast-compact" => Some(Box::new(SoftmaxFastCompact)),
        _ => None,
    }
}

/// A named rewrite pass applied to a whole [`TypedModel`].
///
/// Implementors supply [`ModelTransformer::name`] and the in-place
/// [`ModelTransformer::transform`]. [`ModelTransformer::transform_into`] is
/// derived from them.
pub trait ModelTransformer: Debug {
    /// Identifier of this transformer.
    fn name(&self) -> Cow<str>;

    /// Rewrites `model` in place.
    ///
    /// # Errors
    /// Returns whatever error the implementation raises while transforming.
    fn transform(&self, model: &mut TypedModel) -> TractResult<()>;

    /// Returns a transformed copy of `model`; the input itself is not modified.
    ///
    /// # Errors
    /// Propagates any error from [`ModelTransformer::transform`]. In that case
    /// the partially transformed copy is dropped.
    fn transform_into(&self, model: &TypedModel) -> TractResult<TypedModel> {
        let mut transformed = model.clone();
        self.transform(&mut transformed).map(|()| transformed)
    }
}

/// Transformer that switches every [`Softmax`] node of a model to the
/// [`SoftmaxExp::FastCompact`] exponential variant.
#[derive(Debug)]
struct SoftmaxFastCompact;

impl ModelTransformer for SoftmaxFastCompact {
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("softmax-fast-compact")
    }

    /// Sets `exp` to `SoftmaxExp::FastCompact` on each node whose op is a
    /// `Softmax`. All other nodes are left unchanged. This never returns an error.
    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        model
            .nodes
            .iter_mut()
            .filter_map(|node| node.op_as_mut::<Softmax>())
            .for_each(|softmax_op| softmax_op.exp = SoftmaxExp::FastCompact);
        Ok(())
    }
}