// syntaxdot_transformers/activations.rs

use std::convert::TryFrom;
use std::f64;

use serde::Deserialize;
use tch::Tensor;

use crate::module::FallibleModule;
use crate::TransformerError;

/// Activation functions.
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq)]
#[serde(try_from = "String")]
pub enum Activation {
    /// GELU activation: GELU(x) = x · Φ(x), where Φ is the CDF of the
    /// standard normal distribution.
    Gelu,

    /// Tanh approximation of GELU (the `gelu_new` flavor used by
    /// GPT-2-style models):
    /// GELU(x) ≈ 0.5x · (1 + tanh(√(2/π) · (x + 0.044715x³)))
    GeluNew,

    /// Rectified linear unit: ReLU(x) = max(0, x).
    Relu,
}

impl TryFrom<&str> for Activation {
    type Error = TransformerError;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value {
            "gelu" => Ok(Activation::Gelu),
            "gelu_new" => Ok(Activation::GeluNew),
            "relu" => Ok(Activation::Relu),
            unknown => Err(TransformerError::UnknownActivationFunction {
                activation: unknown.to_string(),
            }),
        }
    }
}

impl TryFrom<String> for Activation {
    type Error = TransformerError;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        Self::try_from(value.as_str())
    }
}
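
// The `#[serde(try_from = "String")]` attribute above routes serde
// deserialization through the `TryFrom<String>` impl, so an activation can be
// read straight out of a model configuration file. A minimal sketch, assuming
// `serde_json` is available (it is not a dependency of this file):
//
//     let activation: Activation = serde_json::from_str("\"gelu_new\"")?;
//     assert_eq!(activation, Activation::GeluNew);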

impl FallibleModule for Activation {
    type Error = TransformerError;

    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error> {
        match self {
            Self::Gelu => Ok(input.f_gelu("none")?),
            // Tanh approximation of GELU:
            // 0.5x * (1 + tanh(sqrt(2/pi) * (x + 0.044715x^3)))
            Self::GeluNew => Ok(0.5
                * input
                * (1.0
                    + Tensor::f_tanh(
                        &((2. / f64::consts::PI).sqrt()
                            * (input + 0.044715 * input.pow_tensor_scalar(3.0))),
                    )?)),
            Self::Relu => Ok(input.f_relu()?),
        }
    }
}
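
// Hedged usage sketch: `FallibleModule::forward` applies the activation
// elementwise and surfaces tch errors as `TransformerError` instead of
// panicking. Assumes only the impl above; `Tensor::of_slice` is the same
// constructor the tests below use:
//
//     let relu = Activation::Relu;
//     let output = relu.forward(&Tensor::of_slice(&[-1.0f32, 2.0]))?;
//     // output is [0.0, 2.0]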

#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayD};
    use tch::Tensor;

    use crate::activations::Activation;
    use crate::module::FallibleModule;

    #[test]
    fn gelu_new_returns_correct_values() {
        let gelu_new = Activation::GeluNew;
        let activations: ArrayD<f32> = (&gelu_new
            .forward(&Tensor::of_slice(&[-1., -0.5, 0., 0.5, 1.]))
            .unwrap())
            .try_into()
            .unwrap();
        assert_abs_diff_eq!(
            activations,
            array![-0.1588, -0.1543, 0.0000, 0.3457, 0.8412].into_dyn(),
            epsilon = 1e-4
        );
    }
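
    // Hedged addition: exercises the `TryFrom` conversions above; the
    // expected values follow directly from the match arms in
    // `TryFrom<&str>`, and any unlisted name is an error.
    #[test]
    fn activations_can_be_parsed_from_strings() {
        use std::convert::TryFrom;

        assert_eq!(Activation::try_from("gelu").unwrap(), Activation::Gelu);
        assert_eq!(
            Activation::try_from("gelu_new").unwrap(),
            Activation::GeluNew
        );
        assert_eq!(Activation::try_from("relu").unwrap(), Activation::Relu);
        assert!(Activation::try_from("swish").is_err());
    }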
}