tract_hir/ops/activations.rs

1use crate::internal::*;
2use tract_core::ops::logic::Comp;
3use tract_core::ops::math::*;
4
/// Implements [`Expansion`] for a unary activation operator.
///
/// `$op` is the operator type. `$wire` is a closure, coerced to a plain `fn` pointer, so it
/// cannot capture its environment. It receives the operator, the node name, the model under
/// construction and the input outlets, and wires the typed sub-graph computing the activation.
/// Type inference for every generated expansion is delegated to `simple_unary_rules`.
macro_rules! activation {
    ($op:ident, $wire:expr) => {
        impl Expansion for $op {
            fn name(&self) -> StaticName {
                let op_name = stringify!($op);
                op_name.into()
            }

            fn rules<'r, 'p: 'r, 's: 'r>(
                &'s self,
                s: &mut Solver<'r>,
                inputs: &'p [TensorProxy],
                outputs: &'p [TensorProxy],
            ) -> InferenceResult {
                simple_unary_rules(s, inputs, outputs)
            }

            fn wire(
                &self,
                name: &str,
                model: &mut TypedModel,
                inputs: &[OutletId],
            ) -> TractResult<TVec<OutletId>> {
                // Giving the closure an explicit fn-pointer type lets the call sites leave the
                // operator and input parameter types implicit.
                type Builder =
                    fn(&$op, &str, &mut TypedModel, &[OutletId]) -> TractResult<TVec<OutletId>>;
                let build: Builder = $wire;
                build(self, name, model, inputs)
            }
        }
    };
}
38
/// Declares a local `$id` bound to a freshly added constant node holding `$value`.
///
/// The value is turned into a tensor by `broadcast_scalar`, which casts it to the datum type of
/// `$inputs[0]` and gives it that outlet's rank. The node is named `"<$name>.<$id>"`.
macro_rules! cst {
    ($model: expr, $inputs: expr, $name: expr, $id:ident, $value: expr) => {
        let $id = {
            let scalar = broadcast_scalar($value, $model, $inputs)?;
            let const_name = format!("{}.{}", $name, stringify!($id));
            $model.add_const(const_name, scalar)?
        };
    };
}
45
46#[derive(Debug, Clone, new)]
47pub struct Clip(Option<f32>, Option<f32>);
48
49activation!(Clip, |op, name: &str, model: &mut TypedModel, inputs| {
50    let mut wire: TVec<OutletId> = inputs.into();
51    if let Some(low) = op.0 {
52        let low = broadcast_scalar(low, model, inputs)?;
53        let low = model.add_const(name.to_string() + ".low.cst", low)?;
54        wire = model.wire_node(name.to_string() + ".low", max(), &[wire[0], low])?;
55    }
56    if let Some(high) = op.1 {
57        let high = broadcast_scalar(high, model, inputs)?;
58        let high = model.add_const(name.to_string() + ".high.cst", high)?;
59        wire = model.wire_node(name.to_string() + ".high", min(), &[wire[0], high])?;
60    }
61    Ok(wire)
62});
63
64#[derive(Debug, Clone, new, Hash)]
65pub struct Softplus;
66
67activation!(Softplus, |_op, name: &str, model: &mut TypedModel, inputs| {
68    cst!(model, inputs, name, one, 1.0);
69    let wire = model.wire_node(name.to_string() + ".exp", exp(), inputs)?;
70    let wire = model.wire_node(name.to_string() + ".plus_one", add(), &[wire[0], one])?;
71    let wire = model.wire_node(name.to_string() + ".ln", ln(), &wire)?;
72    Ok(wire)
73});
74
75#[derive(Debug, Clone, new, Hash)]
76pub struct Softsign;
77
78activation!(Softsign, |_op, name: &str, model: &mut TypedModel, inputs| {
79    cst!(model, inputs, name, one, 1.0);
80    let x_abs = model.wire_node(name.to_string() + ".abs", abs(), inputs)?;
81    let denum = model.wire_node(name.to_string() + ".plus_one", add(), &[x_abs[0], one])?;
82    let wire = model.wire_node(name.to_string() + ".div", div(), &[inputs[0], denum[0]])?;
83    Ok(wire)
84});
85
86#[derive(Debug, Clone, new)]
87pub struct Celu(pub f32);
88
89activation!(Celu, |op, name: &str, model: &mut TypedModel, inputs| {
90    cst!(model, inputs, name, zero, 0.0);
91    cst!(model, inputs, name, one, 1.0);
92    cst!(model, inputs, name, alpha, op.0);
93    let x_over_alpha =
94        model.wire_node(name.to_string() + ".x_over_alpha", div(), &[inputs[0], alpha])?;
95    let x_over_alpha_exp = model.wire_node(name.to_string() + ".exp", exp(), &[x_over_alpha[0]])?;
96    let minus_one =
97        model.wire_node(name.to_string() + ".minus_one", sub(), &[x_over_alpha_exp[0], one])?;
98    let wire = model.wire_node(name.to_string() + ".sat-zero", min(), &[zero, minus_one[0]])?;
99    let relu = model.wire_node(name.to_string() + ".relu", max(), &[zero, inputs[0]])?;
100    let wire = model.wire_node(name.to_string(), add(), &[relu[0], wire[0]])?;
101    Ok(wire)
102});
103
104#[derive(Debug, Clone, new)]
105pub struct Elu(pub f32);
106
107activation!(Elu, |op, name: &str, model: &mut TypedModel, inputs| {
108    cst!(model, inputs, name, zero, 0.0);
109    cst!(model, inputs, name, one, 1.0);
110    cst!(model, inputs, name, alpha, op.0);
111    let x_exp = model.wire_node(name.to_string() + ".exp", exp(), inputs)?;
112    let minus_one = model.wire_node(name.to_string() + ".minus_one", sub(), &[x_exp[0], one])?;
113    let neg = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[alpha, minus_one[0]])?;
114    let test = model.wire_node(name.to_string() + ".test", Comp::LT, &[zero, inputs[0]])?;
115    let wire = model.wire_node(
116        name.to_string() + ".iff",
117        tract_core::ops::logic::Iff,
118        &[test[0], inputs[0], neg[0]],
119    )?;
120    Ok(wire)
121});
122
123#[derive(Debug, Clone, new)]
124pub struct HardSigmoid(pub f32, pub f32);
125
126activation!(HardSigmoid, |op, name: &str, model: &mut TypedModel, inputs| {
127    cst!(model, inputs, name, zero, 0.0);
128    cst!(model, inputs, name, one, 1.0);
129    cst!(model, inputs, name, alpha, op.0);
130    cst!(model, inputs, name, beta, op.1);
131    let wire = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[alpha, inputs[0]])?;
132    let wire = model.wire_node(name.to_string() + ".add_beta", add(), &[beta, wire[0]])?;
133    let wire = model.wire_node(name.to_string() + ".sat-one", min(), &[one, wire[0]])?;
134    let wire = model.wire_node(name.to_string() + ".sat-zero", max(), &[zero, wire[0]])?;
135    Ok(wire)
136});
137
138#[derive(Debug, Clone, new)]
139pub struct LeakyRelu(pub f32);
140
141activation!(LeakyRelu, |op, name: &str, model: &mut TypedModel, inputs| {
142    model.wire_node(name, tract_core::ops::nn::leaky_relu(op.0), inputs)
143});
144
145#[derive(Debug, Clone, new)]
146pub struct ParametricSoftplus(pub f32, pub f32);
147
148activation!(ParametricSoftplus, |op, name: &str, model: &mut TypedModel, inputs| {
149    cst!(model, inputs, name, one, 1.0);
150    cst!(model, inputs, name, alpha, op.0);
151    cst!(model, inputs, name, beta, op.1);
152    let wire = model.wire_node(name.to_string() + ".mul_beta", mul(), &[beta, inputs[0]])?;
153    let wire = model.wire_node(name.to_string() + ".exp", exp(), &wire)?;
154    let wire = model.wire_node(name.to_string() + ".plus_one", add(), &[one, wire[0]])?;
155    let wire = model.wire_node(name.to_string() + ".ln", ln(), &wire)?;
156    let wire = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[alpha, wire[0]])?;
157    Ok(wire)
158});
159
160#[derive(Debug, Clone, new)]
161pub struct ScaledTanh(pub f32, pub f32);
162
163activation!(ScaledTanh, |op, name: &str, model: &mut TypedModel, inputs| {
164    cst!(model, inputs, name, alpha, op.0);
165    cst!(model, inputs, name, beta, op.1);
166    let wire = model.wire_node(name.to_string() + ".mul_beta", mul(), &[beta, inputs[0]])?;
167    let wire = model.wire_node(name.to_string() + ".tanh", tanh(), &wire)?;
168    let wire = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[alpha, wire[0]])?;
169    Ok(wire)
170});
171
172#[derive(Debug, Clone, new)]
173pub struct Selu(pub f32, pub f32);
174
175activation!(Selu, |op, name: &str, model: &mut TypedModel, inputs| {
176    cst!(model, inputs, name, zero, 0.0);
177    cst!(model, inputs, name, alpha, op.0);
178    cst!(model, inputs, name, gamma, op.1);
179    let wire = model.wire_node(name.to_string() + ".exp", exp(), inputs)?;
180    let wire = model.wire_node(name.to_string() + ".mul_alpha", mul(), &[wire[0], alpha])?;
181    let wire = model.wire_node(name.to_string() + ".sub_alpha", sub(), &[wire[0], alpha])?;
182    let test = model.wire_node(name.to_string() + ".test", Comp::LT, &[zero, inputs[0]])?;
183    let wire = model.wire_node(
184        name.to_string() + ".iff",
185        tract_core::ops::logic::Iff,
186        &[test[0], inputs[0], wire[0]],
187    )?;
188    let wire = model.wire_node(name.to_string() + ".mul_gamma", mul(), &[gamma, wire[0]])?;
189    Ok(wire)
190});
191
192#[derive(Debug, Clone, new)]
193pub struct Shrink(pub f32, pub f32);
194
195activation!(Shrink, |op, name: &str, model: &mut TypedModel, inputs| {
196    cst!(model, inputs, name, bias, op.0);
197    cst!(model, inputs, name, lambda, op.1);
198    cst!(model, inputs, name, minus_lambda, -op.1);
199    let zero = broadcast_scalar(0.0, model, inputs)?;
200    let zero = model.add_const(name.to_string() + ".zero", zero)?;
201    let test_pos =
202        model.wire_node(name.to_string() + ".test_pos", Comp::LT, &[lambda, inputs[0]])?;
203    let pos = model.wire_node(
204        name.to_string() + ".pos",
205        tract_core::ops::math::sub(),
206        &[inputs[0], bias],
207    )?;
208    let test_neg =
209        model.wire_node(name.to_string() + ".test_neg", Comp::GT, &[minus_lambda, inputs[0]])?;
210    let neg = model.wire_node(
211        name.to_string() + ".neg",
212        tract_core::ops::math::add(),
213        &[bias, inputs[0]],
214    )?;
215    let wire = model.wire_node(
216        name.to_string() + ".if_pos",
217        tract_core::ops::logic::Iff,
218        &[test_pos[0], pos[0], zero],
219    )?;
220    let wire = model.wire_node(
221        name.to_string() + ".if_neg",
222        tract_core::ops::logic::Iff,
223        &[test_neg[0], neg[0], wire[0]],
224    )?;
225    Ok(wire)
226});
227
228#[derive(Debug, Clone, new)]
229pub struct ThresholdRelu(pub f32);
230
231activation!(ThresholdRelu, |op, name: &str, model: &mut TypedModel, inputs| {
232    cst!(model, inputs, name, zero, 0.0);
233    cst!(model, inputs, name, alpha, op.0);
234    let test = model.wire_node(name.to_string() + ".test", Comp::LT, &[alpha, inputs[0]])?;
235    let wire = model.wire_node(
236        name.to_string() + ".iff",
237        tract_core::ops::logic::Iff,
238        &[test[0], inputs[0], zero],
239    )?;
240    Ok(wire)
241});
242
243fn simple_unary_rules<'r, 'p: 'r, 's: 'r>(
244    s: &mut Solver<'r>,
245    inputs: &'p [TensorProxy],
246    outputs: &'p [TensorProxy],
247) -> InferenceResult {
248    check_input_arity(inputs, 1)?;
249    check_output_arity(outputs, 1)?;
250    s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
251    s.equals(&inputs[0].shape, &outputs[0].shape)?;
252    Ok(())
253}
254
255pub fn broadcast_scalar(
256    f: f32,
257    model: &TypedModel,
258    inputs: &[OutletId],
259) -> TractResult<Arc<Tensor>> {
260    let fact = model.outlet_fact(inputs[0])?;
261    let mut tensor = tensor0(f).cast_to_dt(fact.datum_type)?.into_owned();
262    while tensor.rank() < fact.rank() {
263        tensor.insert_axis(0)?;
264    }
265    Ok(tensor.into_arc_tensor())
266}