tract_hir/ops/nn/
softmax.rs1use crate::internal::*;
2
3#[derive(Debug, Clone, new, Hash)]
4pub struct Softmax {
5 axis: isize,
6}
7
8impl Expansion for Softmax {
9 fn name(&self) -> StaticName {
10 "Softmax".into()
11 }
12 fn info(&self) -> TractResult<Vec<String>> {
13 Ok(vec![format!("axis: {:?}", self.axis)])
14 }
15
16 fn rules<'r, 'p: 'r, 's: 'r>(
17 &'s self,
18 s: &mut Solver<'r>,
19 inputs: &'p [TensorProxy],
20 outputs: &'p [TensorProxy],
21 ) -> InferenceResult {
22 check_input_arity(inputs, 1)?;
23 check_output_arity(outputs, 1)?;
24 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
25 s.equals(&inputs[0].shape, &outputs[0].shape)?;
26
27 Ok(())
28 }
29
30 fn wire(
31 &self,
32 name: &str,
33 target: &mut TypedModel,
34 inputs: &[OutletId],
35 ) -> TractResult<TVec<OutletId>> {
36 let axis = if self.axis < 0 {
37 (target.outlet_fact(inputs[0])?.rank() as isize + self.axis) as usize
38 } else {
39 self.axis as usize
40 };
41
42 let input = target.outlet_fact(inputs[0])?.clone();
43 let input_dt = input.datum_type;
44 let quant_output_dt = if input_dt.is_quantized() {
45 Some(DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 0.0078125 }))
48 } else {
49 None
50 };
51
52 target.wire_node(
53 name,
54 tract_core::ops::nn::Softmax {
55 axes: tvec![axis],
56 quant_output_dt,
57 ..tract_core::ops::nn::Softmax::default()
58 },
59 inputs,
60 )
61 }
62}