syntaxdot_transformers/models/sinusoidal/mod.rs

//! Word embeddings with sinusoidal position embeddings.

use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::nn::Init;
use tch::{Kind, Tensor};

use crate::layers::{Dropout, Embedding, LayerNorm};
use crate::models::traits::WordEmbeddingsConfig;
use crate::module::{FallibleModule, FallibleModuleT};
use crate::util::SinusoidalPositions;
use crate::TransformerError;

/// Embeddings layer that uses word embeddings with sinusoidal positions.
///
/// The word embeddings in this layer can be optimized, but the sinusoidal
/// positions are generated on-the-fly.
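///
/// The positions are assumed to follow the usual transformer sinusoidal
/// construction (Vaswani et al., 2017), pairing `sin(pos / 10000^(2i/d))`
/// and `cos(pos / 10000^(2i/d))` terms; see [`SinusoidalPositions`] for the
/// exact variant generated here. For input ids of shape `[batch, seq_len]`,
/// `forward_t` returns embeddings of shape `[batch, seq_len, dims]`.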
#[derive(Debug)]
pub struct SinusoidalEmbeddings {
    dropout: Dropout,
    layer_norm: LayerNorm,
    p_norm: Option<f64>,
    word_embeddings: Embedding,
}

impl SinusoidalEmbeddings {
    /// Create piece embeddings with sinusoidal position embeddings.
    ///
    /// If a `p_norm` is specified, position embeddings are normalized
    /// using this norm.
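    ///
    /// # Example
    ///
    /// A minimal construction sketch; `config` is assumed to be any
    /// [`WordEmbeddingsConfig`] value, for instance a `BertConfig` like the
    /// one built in the tests below:
    ///
    /// ```ignore
    /// use syntaxdot_tch_ext::RootExt;
    /// use tch::nn::VarStore;
    /// use tch::Device;
    ///
    /// // `config` is provided elsewhere (e.g. a BERT configuration).
    /// let vs = VarStore::new(Device::Cpu);
    /// let root = vs.root_ext(|_| 0);
    /// let embeddings =
    ///     SinusoidalEmbeddings::new(root.sub("embeddings"), &config, Some(2.0)).unwrap();
    /// ```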
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &impl WordEmbeddingsConfig,
        p_norm: Option<f64>,
    ) -> Result<SinusoidalEmbeddings, TransformerError> {
        let vs = vs.borrow();

        let normal_init = Init::Randn {
            mean: 0.,
            stdev: config.initializer_range(),
        };

        let word_embeddings = Embedding::new(
            vs / "word_embeddings",
            "embeddings",
            config.vocab_size(),
            config.dims(),
            normal_init,
        )?;

        let layer_norm = LayerNorm::new(
            vs / "layer_norm",
            vec![config.dims()],
            config.layer_norm_eps(),
            true,
        );

        let dropout = Dropout::new(config.dropout());

        Ok(SinusoidalEmbeddings {
            dropout,
            layer_norm,
            p_norm,
            word_embeddings,
        })
    }
}

impl FallibleModuleT for SinusoidalEmbeddings {
    type Error = TransformerError;

    fn forward_t(&self, input_ids: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        // Look up the piece embeddings: [batch, seq_len, dims].
        let word_embeddings = self.word_embeddings.forward(input_ids)?;

        let (_, seq_length, embedding_dim) = word_embeddings.size3()?;

        // Generate position embeddings on-the-fly for the current sequence
        // length, on the same device as the piece embeddings.
        let position_embeddings: Tensor = SinusoidalPositions::sinusoidal_positions(
            seq_length,
            embedding_dim,
            self.p_norm,
            (Kind::Float, word_embeddings.device()),
        )?;

        // Broadcast-add the position embeddings over the batch dimension,
        // then apply layer normalization and dropout.
        let mut embeddings = tch::no_grad::<Result<_, TransformerError>, _>(|| {
            Ok(word_embeddings.f_add(&position_embeddings.f_unsqueeze(0)?)?)
        })?;
        embeddings = self.layer_norm.forward(&embeddings)?;
        self.dropout.forward_t(&embeddings, train)
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use crate::activations::Activation;
    use crate::models::bert::BertConfig;
    use crate::models::sinusoidal::SinusoidalEmbeddings;
    use crate::module::FallibleModuleT;

    // BERT is not trained with sinusoidal embeddings, but we will just use
    // its piece embeddings to verify that the output of the
    // SinusoidalEmbeddings module hasn't changed.
    const BERT_BASE_GERMAN_CASED: &str = env!("BERT_BASE_GERMAN_CASED");

    fn german_bert_config() -> BertConfig {
        BertConfig {
            attention_probs_dropout_prob: 0.1,
            hidden_act: Activation::Gelu,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            layer_norm_eps: 1e-12,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

    #[test]
    fn sinusoidal_embeddings_are_unchanged_without_norm() {
        let sums: ArrayD<f32> = get_and_sum_test_embeddings(None);

        // Verify output against known output (to avoid future breakage).
        assert_abs_diff_eq!(
            sums,
            (array![[
                -7.433159, -7.3248596, -6.981781, -5.287575, -5.657837, -6.173279, -6.0414734,
                -6.0355415, -5.6972923, -4.800411
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn sinusoidal_embeddings_are_unchanged_with_norm() {
        let sums: ArrayD<f32> = get_and_sum_test_embeddings(Some(2.0));

        // Verify output against known output (to avoid future breakage).
        assert_abs_diff_eq!(
            sums,
            (array![[
                -5.801262, -7.803936, -9.95359, 5.575783, 0.79592514, -3.6844482, -2.3470383,
                -5.6341896, -6.2476273, 1.965559
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    fn get_and_sum_test_embeddings(p_norm: Option<f64>) -> ArrayD<f32> {
        let config = german_bert_config();
        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings =
            SinusoidalEmbeddings::new(root.sub("embeddings"), &config, p_norm).unwrap();

        vs.load(BERT_BASE_GERMAN_CASED).unwrap();

        // Word pieces of: Veruntreute die AWO Spendengeld ?
        let pieces = Tensor::of_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2])
            .reshape(&[1, 10]);

        let summed_embeddings =
            embeddings
                .forward_t(&pieces, false)
                .unwrap()
                .sum_dim(-1, false, Kind::Float);

        (&summed_embeddings).try_into().unwrap()
    }
}