// syntaxdot_transformers/activations.rs

//! Activation functions

use std::convert::TryFrom;
use std::f64;

use serde::Deserialize;
use tch::Tensor;

use crate::module::FallibleModule;
use crate::TransformerError;
/// Activation functions supported in transformer feed-forward layers.
///
/// Deserialized from configuration strings ("gelu", "gelu_new", "relu")
/// via the `TryFrom<String>` implementation below.
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq)]
#[serde(try_from = "String")]
pub enum Activation {
    /// GELU activation function.
    ///
    /// GELU(x) = x Φ(x)
    ///
    /// where Φ(x) is the CDF for the Gaussian distribution.
    Gelu,

    /// GELU activation function (Google/OpenAI flavor).
    ///
    /// GELU(x) = x Φ(x)
    ///
    /// where Φ(x) is the CDF for the Gaussian distribution. In the
    /// `FallibleModule` implementation this flavor is computed with the
    /// tanh-based approximation rather than the exact erf form.
    GeluNew,

    /// ReLU activation function.
    ///
    /// ReLU(x) = max(0, x)
    Relu,
}
34
35impl TryFrom<&str> for Activation {
36    type Error = TransformerError;
37
38    fn try_from(value: &str) -> Result<Self, Self::Error> {
39        match value {
40            "gelu" => Ok(Activation::Gelu),
41            "gelu_new" => Ok(Activation::GeluNew),
42            "relu" => Ok(Activation::Relu),
43            unknown => Err(TransformerError::UnknownActivationFunction {
44                activation: unknown.to_string(),
45            }),
46        }
47    }
48}
49
50impl TryFrom<String> for Activation {
51    type Error = TransformerError;
52
53    fn try_from(value: String) -> Result<Self, Self::Error> {
54        Self::try_from(value.as_str())
55    }
56}
57
impl FallibleModule for Activation {
    type Error = TransformerError;

    /// Apply the activation function element-wise to `input`.
    ///
    /// Returns an error when the underlying Torch operation fails
    /// (the fallible `f_*` tch bindings are used throughout).
    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error> {
        match self {
            // "none" selects the exact (erf-based) GELU kernel rather than
            // the tanh approximation — NOTE(review): per the PyTorch
            // `gelu` `approximate` argument semantics; confirm against the
            // tch version pinned by this crate.
            Self::Gelu => Ok(input.f_gelu("none")?),
            // Tanh approximation used by Google BERT / OpenAI GPT:
            // 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
            Self::GeluNew => Ok(0.5
                * input
                * (1.0
                    + Tensor::f_tanh(
                        &((2. / f64::consts::PI).sqrt()
                            * (input + 0.044715 * input.pow_tensor_scalar(3.0))),
                    )?)),
            Self::Relu => Ok(input.f_relu()?),
        }
    }
}
75
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayD};
    use tch::Tensor;

    use crate::activations::Activation;
    use crate::module::FallibleModule;

    #[test]
    fn gelu_new_returns_correct_values() {
        // Reference values for the tanh-approximated GELU at a few points.
        let inputs = Tensor::of_slice(&[-1., -0.5, 0., 0.5, 1.]);
        let output = Activation::GeluNew.forward(&inputs).unwrap();
        let activations: ArrayD<f32> = (&output).try_into().unwrap();
        let expected = array![-0.1588, -0.1543, 0.0000, 0.3457, 0.8412].into_dyn();
        assert_abs_diff_eq!(activations, expected, epsilon = 1e-4);
    }
}
101}