tract_hir/ops/nn/layer_max.rs

1use tract_core::ops::nn::Softmax;
2
3use crate::infer::*;
4use crate::internal::*;
5
6// TODO tricky to re-express in "core" because of the multiple hot point... do
7// we need one more reduce ?
/// ONNX-style Hardmax: 1 at the position of the maximum value along the
/// working axes, 0 everywhere else (expanded to core ops in `wire`).
#[derive(Debug, Clone, new, Default, Hash)]
pub struct LayerHardmax {
    // Working axis; may be negative (counted from the end), resolved in `wire`.
    axis: isize,
    // Legacy semantics: operate on the flattened suffix [axis..rank)
    // instead of the single axis (see `wire`).
    coerce_to_2d: bool,
}
13
impl Expansion for LayerHardmax {
    fn name(&self) -> StaticName {
        "LayerHardmax".into()
    }

    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        solver: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        rules(solver, inputs, outputs)
    }

    /// Expands Hardmax into core ops: (optional flattening reshape) ->
    /// ArgMax over `axis` -> drop the reduced axis -> OneHot re-encoding of
    /// the winning index -> (optional reshape back to the original shape).
    fn wire(
        &self,
        name: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        use tract_core::ops::{array, change_axes, nn};
        let input = inputs[0];
        let input_fact = target.outlet_fact(input)?.clone();
        let input_dt = input_fact.datum_type;
        let rank = input_fact.rank();
        // Resolve a negative (from-the-end) axis against the input rank.
        let axis = if self.axis < 0 { rank as isize + self.axis } else { self.axis } as usize;
        // Product of all dimensions from `axis` to the end: the size of the
        // flattened working suffix used in coerce_to_2d mode.
        let suffix_dim: TDim = input_fact.shape[axis..].iter().product();
        // OneHot needs a concrete depth, so the working dimension (single
        // axis, or the whole suffix) must be statically known.
        let dim = if self.coerce_to_2d {
            suffix_dim.to_usize()
        } else {
            input_fact.shape[axis].to_usize()
        }
        .context("Assumes known dimension on working axes suffix.")?;
        // OneHot off/on values (0 and 1) cast to the input's datum type.
        let off = tensor0(0f32).cast_to_dt(input_dt)?.into_owned().into_arc_tensor();
        let on = tensor0(1f32).cast_to_dt(input_dt)?.into_owned().into_arc_tensor();
        let mut wires = inputs.into();
        if self.coerce_to_2d {
            // Collapse the suffix [axis..rank) into a single dimension at `axis`.
            wires = target.wire_node(
                format!("{name}.reshaped"),
                AxisOp::Reshape(axis, input_fact.shape[axis..].into(), tvec!(suffix_dim.clone())),
                &[input],
            )?;
        }
        // Index of the maximum along `axis` (ArgMax(false): first max wins
        // — TODO confirm tie-breaking semantics against nn::Reducer docs).
        wires = target.wire_node(
            format!("{name}.argmax"),
            nn::Reduce::new(tvec!(axis), nn::Reducer::ArgMax(false)),
            &wires,
        )?;
        // Drop the now size-1 reduced axis so OneHot can re-insert it.
        wires =
            target.wire_node(format!("{name}.rm_axis"), change_axes::AxisOp::Rm(axis), &wires)?;
        // Re-expand the index into a 0/1 vector of length `dim` at `axis`.
        wires = target.wire_node(
            format!("{name}.hardmax"),
            array::OneHot { axis, dim, off, on },
            &wires,
        )?;
        if self.coerce_to_2d {
            // Undo the initial flattening: restore the original suffix shape.
            wires = target.wire_node(
                format!("{name}.hardmax_reshaped"),
                AxisOp::Reshape(axis, tvec!(suffix_dim), input_fact.shape[axis..].into()),
                &wires,
            )?;
        }
        Ok(wires)
    }
}
79
/// ONNX-style LogSoftmax: ln(softmax(x)), expanded via `LayerSoftmax` in `wire`.
#[derive(Debug, Clone, new, Default, Hash)]
pub struct LayerLogSoftmax {
    // Working axis; may be negative (counted from the end).
    pub axis: isize,
    // Legacy semantics: softmax over the flattened suffix [axis..rank)
    // instead of the single axis (see `LayerSoftmax::wire`).
    pub coerce_to_2d: bool,
}
85
86impl Expansion for LayerLogSoftmax {
87    fn name(&self) -> StaticName {
88        "LayerLogSoftmax".into()
89    }
90
91    fn rules<'r, 'p: 'r, 's: 'r>(
92        &'s self,
93        solver: &mut Solver<'r>,
94        inputs: &'p [TensorProxy],
95        outputs: &'p [TensorProxy],
96    ) -> InferenceResult {
97        rules(solver, inputs, outputs)
98    }
99
100    fn wire(
101        &self,
102        name: &str,
103        target: &mut TypedModel,
104        inputs: &[OutletId],
105    ) -> TractResult<TVec<OutletId>> {
106        let softmax = LayerSoftmax { axis: self.axis, coerce_to_2d: self.coerce_to_2d }
107            .wire(name, target, inputs)?;
108        target.wire_node(format!("{name}.logsoftmax"), tract_core::ops::math::ln(), &softmax)
109    }
110}
111
/// ONNX-style Softmax layer, expanded to the core `Softmax` op in `wire`.
#[derive(Debug, Clone, new, Default, Hash)]
pub struct LayerSoftmax {
    // Working axis; may be negative (counted from the end).
    axis: isize,
    // Legacy semantics: softmax over all axes [axis..rank) instead of the
    // single axis (see `wire`).
    coerce_to_2d: bool,
}
117
118impl Expansion for LayerSoftmax {
119    fn name(&self) -> StaticName {
120        "LayerSoftmax".into()
121    }
122
123    fn rules<'r, 'p: 'r, 's: 'r>(
124        &'s self,
125        solver: &mut Solver<'r>,
126        inputs: &'p [TensorProxy],
127        outputs: &'p [TensorProxy],
128    ) -> InferenceResult {
129        rules(solver, inputs, outputs)
130    }
131
132    fn wire(
133        &self,
134        name: &str,
135        target: &mut TypedModel,
136        inputs: &[OutletId],
137    ) -> TractResult<TVec<OutletId>> {
138        let input = inputs[0];
139        let rank = target.outlet_fact(input)?.rank();
140        let dt = target.outlet_fact(input)?.datum_type;
141        let axis = if self.axis < 0 { rank as isize + self.axis } else { self.axis } as usize;
142        let axes =
143            if self.coerce_to_2d { (axis..rank).collect::<TVec<usize>>() } else { tvec!(axis) };
144        let quant_output_dt = if dt.is_float() { None } else { Some(dt) };
145        target.wire_node(name, Softmax { axes, quant_output_dt, ..Softmax::default() }, inputs)
146    }
147}
148
149fn rules<'r, 'p: 'r, 's: 'r>(
150    s: &mut Solver<'r>,
151    inputs: &'p [TensorProxy],
152    outputs: &'p [TensorProxy],
153) -> InferenceResult {
154    check_output_arity(outputs, 1)?;
155    s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
156    s.equals(&outputs[0].rank, &inputs[0].rank)?;
157    s.equals(&outputs[0].shape, &inputs[0].shape)?;
158    Ok(())
159}