use tract_core::ops::nn::Softmax;

use crate::infer::*;
use crate::internal::*;

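/// Expansion of ONNX Hardmax: along the working axis, the maximum element maps
/// to 1 and every other element to 0 (ties resolve to the first maximum, per
/// the ONNX spec), e.g. Hardmax([0.3, 0.9, 0.2]) is [0, 1, 0]. With
/// `coerce_to_2d`, all dimensions from `axis` onward are flattened into a
/// single working axis first.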
#[derive(Debug, Clone, new, Default, Hash)]
pub struct LayerHardmax {
    axis: isize,
    coerce_to_2d: bool,
}

impl Expansion for LayerHardmax {
    fn name(&self) -> StaticName {
        "LayerHardmax".into()
    }

    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        solver: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        rules(solver, inputs, outputs)
    }

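    // Hardmax is decomposed as: optional flatten onto the working axis, ArgMax
    // over that axis, removal of the reduced axis, OneHot to rebuild a 0/1
    // tensor, then an optional reshape back to the original shape.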
    fn wire(
        &self,
        name: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        use tract_core::ops::{array, change_axes, nn};
        let input = inputs[0];
        let input_fact = target.outlet_fact(input)?.clone();
        let input_dt = input_fact.datum_type;
        let rank = input_fact.rank();
        let axis = if self.axis < 0 { rank as isize + self.axis } else { self.axis } as usize;
        let suffix_dim: TDim = input_fact.shape[axis..].iter().product();
        let dim = if self.coerce_to_2d {
            suffix_dim.to_usize()
        } else {
            input_fact.shape[axis].to_usize()
        }
        .context("Assumes known dimension on working axes suffix.")?;
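        // Scalar 0 and 1 constants in the input datum type, used as the OneHot outputs.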
        let off = tensor0(0f32).cast_to_dt(input_dt)?.into_owned().into_arc_tensor();
        let on = tensor0(1f32).cast_to_dt(input_dt)?.into_owned().into_arc_tensor();
        let mut wires = inputs.into();
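        // Coercion to 2D: flatten every dimension from `axis` onward into one axis.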
        if self.coerce_to_2d {
            wires = target.wire_node(
                format!("{name}.reshaped"),
                AxisOp::Reshape(axis, input_fact.shape[axis..].into(), tvec!(suffix_dim.clone())),
                &[input],
            )?;
        }
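        // ArgMax reduces the working axis to the index of its maximum, and the
        // now trivial axis is removed.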
        wires = target.wire_node(
            format!("{name}.argmax"),
            nn::Reduce::new(tvec!(axis), nn::Reducer::ArgMax(false)),
            &wires,
        )?;
        wires =
            target.wire_node(format!("{name}.rm_axis"), change_axes::AxisOp::Rm(axis), &wires)?;
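        // OneHot turns each index back into a vector of `dim` elements: `on` at
        // the index, `off` everywhere else.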
        wires = target.wire_node(
            format!("{name}.hardmax"),
            array::OneHot { axis, dim, off, on },
            &wires,
        )?;
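        // Undo the 2D coercion to restore the original shape.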
        if self.coerce_to_2d {
            wires = target.wire_node(
                format!("{name}.hardmax_reshaped"),
                AxisOp::Reshape(axis, tvec!(suffix_dim), input_fact.shape[axis..].into()),
                &wires,
            )?;
        }
        Ok(wires)
    }
}

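/// Expansion of ONNX LogSoftmax: wired as a Softmax over the working axes
/// followed by a natural logarithm.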
#[derive(Debug, Clone, new, Default, Hash)]
pub struct LayerLogSoftmax {
    pub axis: isize,
    pub coerce_to_2d: bool,
}

impl Expansion for LayerLogSoftmax {
    fn name(&self) -> StaticName {
        "LayerLogSoftmax".into()
    }

    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        solver: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        rules(solver, inputs, outputs)
    }

    fn wire(
        &self,
        name: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let softmax = LayerSoftmax { axis: self.axis, coerce_to_2d: self.coerce_to_2d }
            .wire(name, target, inputs)?;
        target.wire_node(format!("{name}.logsoftmax"), tract_core::ops::math::ln(), &softmax)
    }
}

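/// Expansion of ONNX Softmax, mapped onto tract-core's Softmax. With
/// `coerce_to_2d`, every axis from `axis` onward is treated as a working axis,
/// mirroring the older ONNX behaviour of coercing the input to 2D before the
/// softmax.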
#[derive(Debug, Clone, new, Default, Hash)]
pub struct LayerSoftmax {
    axis: isize,
    coerce_to_2d: bool,
}

impl Expansion for LayerSoftmax {
    fn name(&self) -> StaticName {
        "LayerSoftmax".into()
    }

    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        solver: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        rules(solver, inputs, outputs)
    }

    fn wire(
        &self,
        name: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let input = inputs[0];
        let rank = target.outlet_fact(input)?.rank();
        let dt = target.outlet_fact(input)?.datum_type;
        let axis = if self.axis < 0 { rank as isize + self.axis } else { self.axis } as usize;
        let axes =
            if self.coerce_to_2d { (axis..rank).collect::<TVec<usize>>() } else { tvec!(axis) };
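        // For quantized (non-float) inputs, keep the input datum type as the
        // softmax output type.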
        let quant_output_dt = if dt.is_float() { None } else { Some(dt) };
        target.wire_node(name, Softmax { axes, quant_output_dt, ..Softmax::default() }, inputs)
    }
}

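/// Shared inference rules: a single output with the same datum type, rank and
/// shape as the input.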
fn rules<'r, 'p: 'r, 's: 'r>(
    s: &mut Solver<'r>,
    inputs: &'p [TensorProxy],
    outputs: &'p [TensorProxy],
) -> InferenceResult {
    check_output_arity(outputs, 1)?;
    s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
    s.equals(&outputs[0].rank, &inputs[0].rank)?;
    s.equals(&outputs[0].shape, &inputs[0].shape)?;
    Ok(())
}