tract_hir/ops/nn/
softmax.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, new, Hash)]
4pub struct Softmax {
5    axis: isize,
6}
7
8impl Expansion for Softmax {
9    fn name(&self) -> StaticName {
10        "Softmax".into()
11    }
12    fn info(&self) -> TractResult<Vec<String>> {
13        Ok(vec![format!("axis: {:?}", self.axis)])
14    }
15
16    fn rules<'r, 'p: 'r, 's: 'r>(
17        &'s self,
18        s: &mut Solver<'r>,
19        inputs: &'p [TensorProxy],
20        outputs: &'p [TensorProxy],
21    ) -> InferenceResult {
22        check_input_arity(inputs, 1)?;
23        check_output_arity(outputs, 1)?;
24        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
25        s.equals(&inputs[0].shape, &outputs[0].shape)?;
26
27        Ok(())
28    }
29
30    fn wire(
31        &self,
32        name: &str,
33        target: &mut TypedModel,
34        inputs: &[OutletId],
35    ) -> TractResult<TVec<OutletId>> {
36        let axis = if self.axis < 0 {
37            (target.outlet_fact(inputs[0])?.rank() as isize + self.axis) as usize
38        } else {
39            self.axis as usize
40        };
41
42        let input = target.outlet_fact(inputs[0])?.clone();
43        let input_dt = input.datum_type;
44        let quant_output_dt = if input_dt.is_quantized() {
45            // Quantization parameters are not specified in ONNX (v13) so we set this value as default
46            // in order to maximize the precision of the output.
47            Some(DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 0.0078125 }))
48        } else {
49            None
50        };
51
52        target.wire_node(
53            name,
54            tract_core::ops::nn::Softmax {
55                axes: tvec![axis],
56                quant_output_dt,
57                ..tract_core::ops::nn::Softmax::default()
58            },
59            inputs,
60        )
61    }
62}