1use crate::internal::*;
2use tract_core::ops::logic::Comp;
3use tract_core::ops::math::*;
4
/// Implements [`Expansion`] for a single-input, single-output activation op `$op`.
///
/// `$wire` must be a non-capturing closure taking `(&$op, &str, &mut TypedModel, &[OutletId])`
/// and returning the outlets produced by the expansion. The op's name is the stringified
/// type name, and its inference rules are delegated to `simple_unary_rules`.
macro_rules! activation {
    ($op: ident, $wire:expr) => {
        impl Expansion for $op {
            fn name(&self) -> StaticName {
                stringify!($op).into()
            }

            fn rules<'r, 'p: 'r, 's: 'r>(
                &'s self,
                s: &mut Solver<'r>,
                inputs: &'p [TensorProxy],
                outputs: &'p [TensorProxy],
            ) -> InferenceResult {
                // One input, one output, same datum type and shape.
                simple_unary_rules(s, inputs, outputs)
            }

            fn wire(
                &self,
                name: &str,
                model: &mut TypedModel,
                inputs: &[OutletId],
            ) -> TractResult<TVec<OutletId>> {
                // Binding to an explicit fn-pointer type both forces `$wire` to be
                // non-capturing and lets the closure's untyped parameters be inferred.
                type WireFn =
                    fn(&$op, &str, &mut TypedModel, &[OutletId]) -> TractResult<TVec<OutletId>>;
                let build: WireFn = $wire;
                build(self, name, model, inputs)
            }
        }
    };
}
38
/// Binds `$id` to a constant node holding `$value`, broadcast to the rank and cast to the
/// datum type of `$inputs[0]` (see `broadcast_scalar`). The node is named `"<$name>.<$id>"`.
/// Errors are propagated with `?` from the enclosing function or closure.
macro_rules! cst {
    ($model: expr, $inputs: expr, $name: expr, $id:ident, $value: expr) => {
        let $id = {
            let scalar = broadcast_scalar($value, $model, $inputs)?;
            $model.add_const(format!("{}.{}", $name, stringify!($id)), scalar)?
        };
    };
}
45
46#[derive(Debug, Clone, new)]
47pub struct Clip(Option<f32>, Option<f32>);
48
49activation!(Clip, |op, name: &str, model: &mut TypedModel, inputs| {
50 let mut wire: TVec<OutletId> = inputs.into();
51 if let Some(low) = op.0 {
52 let low = broadcast_scalar(low, model, inputs)?;
53 let low = model.add_const(name.to_string() + ".low.cst", low)?;
54 wire = model.wire_node(name.to_string() + ".low", max(), &[wire[0], low])?;
55 }
56 if let Some(high) = op.1 {
57 let high = broadcast_scalar(high, model, inputs)?;
58 let high = model.add_const(name.to_string() + ".high.cst", high)?;
59 wire = model.wire_node(name.to_string() + ".high", min(), &[wire[0], high])?;
60 }
61 Ok(wire)
62});
63
64#[derive(Debug, Clone, new, Hash)]
65pub struct Softplus;
66
67activation!(Softplus, |_op, name: &str, model: &mut TypedModel, inputs| {
68 cst!(model, inputs, name, one, 1.0);
69 let wire = model.wire_node(name.to_string() + ".exp", exp(), inputs)?;
70 let wire = model.wire_node(name.to_string() + ".plus_one", add(), &[wire[0], one])?;
71 let wire = model.wire_node(name.to_string() + ".ln", ln(), &wire)?;
72 Ok(wire)
73});
74
75#[derive(Debug, Clone, new, Hash)]
76pub struct Softsign;
77
78activation!(Softsign, |_op, name: &str, model: &mut TypedModel, inputs| {
79 cst!(model, inputs, name, one, 1.0);
80 let x_abs = model.wire_node(name.to_string() + ".abs", abs(), inputs)?;
81 let denum = model.wire_node(name.to_string() + ".plus_one", add(), &[x_abs[0], one])?;
82 let wire = model.wire_node(name.to_string() + ".div", div(), &[inputs[0], denum[0]])?;
83 Ok(wire)
84});
85
86#[derive(Debug, Clone, new)]
87pub struct Celu(pub f32);
88
89activation!(Celu, |op, name: &str, model: &mut TypedModel, inputs| {
90 cst!(model, inputs, name, zero, 0.0);
91 cst!(model, inputs, name, one, 1.0);
92 cst!(model, inputs, name, alpha, op.0);
93 let x_over_alpha =
94 model.wire_node(name.to_string() + ".x_over_alpha", div(), &[inputs[0], alpha])?;
95 let x_over_alpha_exp = model.wire_node(name.to_string() + ".exp", exp(), &[x_over_alpha[0]])?;
96 let minus_one =
97 model.wire_node(name.to_string() + ".minus_one", sub(), &[x_over_alpha_exp[0], one])?;
98 let wire = model.wire_node(name.to_string() + ".sat-zero", min(), &[zero, minus_one[0]])?;
99 let relu = model.wire_node(name.to_string() + ".relu", max(), &[zero, inputs[0]])?;
100 let wire = model.wire_node(name.to_string(), add(), &[relu[0], wire[0]])?;
101 Ok(wire)
102});
103
104#[derive(Debug, Clone, new)]
105pub struct Elu(pub f32);
106
107activation!(Elu, |op, name: &str, model: &mut TypedModel, inputs| {
108 cst!(model, inputs, name, zero, 0.0);
109 cst!(model, inputs, name, one, 1.0);
110 cst!(model, inputs, name, alpha, op.0);
111 let x_exp = model.wire_node(name.to_string() + ".exp", exp(), inputs)?;
112 let minus_one = model.wire_node(name.to_string() + ".minus_one", sub(), &[x_exp[0], one])?;
113 let neg = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[alpha, minus_one[0]])?;
114 let test = model.wire_node(name.to_string() + ".test", Comp::LT, &[zero, inputs[0]])?;
115 let wire = model.wire_node(
116 name.to_string() + ".iff",
117 tract_core::ops::logic::Iff,
118 &[test[0], inputs[0], neg[0]],
119 )?;
120 Ok(wire)
121});
122
123#[derive(Debug, Clone, new)]
124pub struct HardSigmoid(pub f32, pub f32);
125
126activation!(HardSigmoid, |op, name: &str, model: &mut TypedModel, inputs| {
127 cst!(model, inputs, name, zero, 0.0);
128 cst!(model, inputs, name, one, 1.0);
129 cst!(model, inputs, name, alpha, op.0);
130 cst!(model, inputs, name, beta, op.1);
131 let wire = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[alpha, inputs[0]])?;
132 let wire = model.wire_node(name.to_string() + ".add_beta", add(), &[beta, wire[0]])?;
133 let wire = model.wire_node(name.to_string() + ".sat-one", min(), &[one, wire[0]])?;
134 let wire = model.wire_node(name.to_string() + ".sat-zero", max(), &[zero, wire[0]])?;
135 Ok(wire)
136});
137
138#[derive(Debug, Clone, new)]
139pub struct LeakyRelu(pub f32);
140
141activation!(LeakyRelu, |op, name: &str, model: &mut TypedModel, inputs| {
142 model.wire_node(name, tract_core::ops::nn::leaky_relu(op.0), inputs)
143});
144
145#[derive(Debug, Clone, new)]
146pub struct ParametricSoftplus(pub f32, pub f32);
147
148activation!(ParametricSoftplus, |op, name: &str, model: &mut TypedModel, inputs| {
149 cst!(model, inputs, name, one, 1.0);
150 cst!(model, inputs, name, alpha, op.0);
151 cst!(model, inputs, name, beta, op.1);
152 let wire = model.wire_node(name.to_string() + ".mul_beta", mul(), &[beta, inputs[0]])?;
153 let wire = model.wire_node(name.to_string() + ".exp", exp(), &wire)?;
154 let wire = model.wire_node(name.to_string() + ".plus_one", add(), &[one, wire[0]])?;
155 let wire = model.wire_node(name.to_string() + ".ln", ln(), &wire)?;
156 let wire = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[alpha, wire[0]])?;
157 Ok(wire)
158});
159
160#[derive(Debug, Clone, new)]
161pub struct ScaledTanh(pub f32, pub f32);
162
163activation!(ScaledTanh, |op, name: &str, model: &mut TypedModel, inputs| {
164 cst!(model, inputs, name, alpha, op.0);
165 cst!(model, inputs, name, beta, op.1);
166 let wire = model.wire_node(name.to_string() + ".mul_beta", mul(), &[beta, inputs[0]])?;
167 let wire = model.wire_node(name.to_string() + ".tanh", tanh(), &wire)?;
168 let wire = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[alpha, wire[0]])?;
169 Ok(wire)
170});
171
172#[derive(Debug, Clone, new)]
173pub struct Selu(pub f32, pub f32);
174
175activation!(Selu, |op, name: &str, model: &mut TypedModel, inputs| {
176 cst!(model, inputs, name, zero, 0.0);
177 cst!(model, inputs, name, alpha, op.0);
178 cst!(model, inputs, name, gamma, op.1);
179 let wire = model.wire_node(name.to_string() + ".exp", exp(), inputs)?;
180 let wire = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[wire[0], alpha])?;
181 let wire = model.wire_node(name.to_string() + ".sub_alpha", sub(), &[wire[0], alpha])?;
182 let test = model.wire_node(name.to_string() + ".test", Comp::LT, &[zero, inputs[0]])?;
183 let wire = model.wire_node(
184 name.to_string() + ".iff",
185 tract_core::ops::logic::Iff,
186 &[test[0], inputs[0], wire[0]],
187 )?;
188 let wire = model.wire_node(name.to_string() + ".mul_gamma", mul(), &[gamma, wire[0]])?;
189 Ok(wire)
190});
191
192#[derive(Debug, Clone, new)]
193pub struct Shrink(pub f32, pub f32);
194
195activation!(Shrink, |op, name: &str, model: &mut TypedModel, inputs| {
196 cst!(model, inputs, name, bias, op.0);
197 cst!(model, inputs, name, lambda, op.1);
198 cst!(model, inputs, name, minus_lambda, -op.1);
199 let zero = broadcast_scalar(0.0, model, inputs)?;
200 let zero = model.add_const(name.to_string() + ".zero", zero)?;
201 let test_pos =
202 model.wire_node(name.to_string() + ".test_pos", Comp::LT, &[lambda, inputs[0]])?;
203 let pos = model.wire_node(
204 name.to_string() + ".pos",
205 tract_core::ops::math::sub(),
206 &[inputs[0], bias],
207 )?;
208 let test_neg =
209 model.wire_node(name.to_string() + ".test_neg", Comp::GT, &[minus_lambda, inputs[0]])?;
210 let neg = model.wire_node(
211 name.to_string() + ".neg",
212 tract_core::ops::math::add(),
213 &[bias, inputs[0]],
214 )?;
215 let wire = model.wire_node(
216 name.to_string() + ".if_pos",
217 tract_core::ops::logic::Iff,
218 &[test_pos[0], pos[0], zero],
219 )?;
220 let wire = model.wire_node(
221 name.to_string() + ".if_neg",
222 tract_core::ops::logic::Iff,
223 &[test_neg[0], neg[0], wire[0]],
224 )?;
225 Ok(wire)
226});
227
228#[derive(Debug, Clone, new)]
229pub struct ThresholdRelu(pub f32);
230
231activation!(ThresholdRelu, |op, name: &str, model: &mut TypedModel, inputs| {
232 cst!(model, inputs, name, zero, 0.0);
233 cst!(model, inputs, name, alpha, op.0);
234 let test = model.wire_node(name.to_string() + ".test", Comp::LT, &[alpha, inputs[0]])?;
235 let wire = model.wire_node(
236 name.to_string() + ".iff",
237 tract_core::ops::logic::Iff,
238 &[test[0], inputs[0], zero],
239 )?;
240 Ok(wire)
241});
242
243fn simple_unary_rules<'r, 'p: 'r, 's: 'r>(
244 s: &mut Solver<'r>,
245 inputs: &'p [TensorProxy],
246 outputs: &'p [TensorProxy],
247) -> InferenceResult {
248 check_input_arity(inputs, 1)?;
249 check_output_arity(outputs, 1)?;
250 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
251 s.equals(&inputs[0].shape, &outputs[0].shape)?;
252 Ok(())
253}
254
255pub fn broadcast_scalar(
256 f: f32,
257 model: &TypedModel,
258 inputs: &[OutletId],
259) -> TractResult<Arc<Tensor>> {
260 let fact = model.outlet_fact(inputs[0])?;
261 let mut tensor = tensor0(f).cast_to_dt(fact.datum_type)?.into_owned();
262 while tensor.rank() < fact.rank() {
263 tensor.insert_axis(0)?;
264 }
265 Ok(tensor.into_arc_tensor())
266}