syntaxdot_transformers/models/bert/embeddings.rs

// Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
// Copyright (c) 2019 The sticker developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::nn::Init;
use tch::{Kind, Tensor};

use crate::cow::CowTensor;
use crate::layers::{Dropout, Embedding, LayerNorm};
use crate::models::bert::config::BertConfig;
use crate::module::{FallibleModule, FallibleModuleT};
use crate::TransformerError;

/// Construct the embeddings from word, position and token_type embeddings.
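///
/// Each piece embedding is computed as
/// `Dropout(LayerNorm(word + position + token_type))` over the three
/// embedding lookups, following the original BERT formulation.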
#[derive(Debug)]
pub struct BertEmbeddings {
    position_embeddings: Embedding,
    token_type_embeddings: Embedding,
    word_embeddings: Embedding,

    layer_norm: LayerNorm,
    dropout: Dropout,
}

impl BertEmbeddings {
    /// Construct new Bert embeddings with the given variable store
    /// and Bert configuration.
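    ///
    /// Relative to the given path, the parameters are created as
    /// `word_embeddings.embeddings`, `position_embeddings.embeddings`,
    /// `token_type_embeddings.embeddings`, `layer_norm.weight`, and
    /// `layer_norm.bias`. The embedding matrices are initialized from a
    /// normal distribution with mean 0 and standard deviation
    /// `config.initializer_range`.
    ///
    /// A construction sketch; `load_config` and the checkpoint path are
    /// placeholders, not part of this crate:
    ///
    /// ```ignore
    /// use syntaxdot_tch_ext::RootExt;
    /// use syntaxdot_transformers::models::bert::{BertConfig, BertEmbeddings};
    /// use tch::nn::VarStore;
    /// use tch::Device;
    ///
    /// let config: BertConfig = load_config()?; // e.g. deserialized from JSON
    /// let mut vs = VarStore::new(Device::Cpu);
    /// let root = vs.root_ext(|_| 0);
    /// let embeddings = BertEmbeddings::new(root.sub("embeddings"), &config)?;
    ///
    /// // Optionally load pretrained parameters into the variable store.
    /// vs.load("bert-base-german-cased.pt")?;
    /// ```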
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let normal_init = Init::Randn {
            mean: 0.,
            stdev: config.initializer_range,
        };

        let word_embeddings = Embedding::new(
            vs / "word_embeddings",
            "embeddings",
            config.vocab_size,
            config.hidden_size,
            normal_init,
        )?;

        let position_embeddings = Embedding::new(
            vs / "position_embeddings",
            "embeddings",
            config.max_position_embeddings,
            config.hidden_size,
            normal_init,
        )?;

        let token_type_embeddings = Embedding::new(
            vs / "token_type_embeddings",
            "embeddings",
            config.type_vocab_size,
            config.hidden_size,
            normal_init,
        )?;

        let layer_norm = LayerNorm::new(
            vs / "layer_norm",
            vec![config.hidden_size],
            config.layer_norm_eps,
            true,
        );

        let dropout = Dropout::new(config.hidden_dropout_prob);

        Ok(BertEmbeddings {
            position_embeddings,
            token_type_embeddings,
            word_embeddings,
            layer_norm,
            dropout,
        })
    }

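    /// Apply the embeddings to the given piece identifiers.
    ///
    /// `input_ids` must have shape `[batch_size, seq_len]`. When
    /// `position_ids` is `None`, the positions `0..seq_len` are used; when
    /// `token_type_ids` is `None`, all pieces get token type `0`. The
    /// result has shape `[batch_size, seq_len, hidden_size]`.
    ///
    /// A call sketch, where `pieces` is a `[1, 10]` tensor of piece
    /// identifiers as in the tests below:
    ///
    /// ```ignore
    /// let output = embeddings.forward(&pieces, None, None, false)?;
    /// ```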
    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, TransformerError> {
        let input_shape = input_ids.size();

        let seq_length = input_shape[1];
        let device = input_ids.device();

        let position_ids = match position_ids {
            Some(position_ids) => CowTensor::Borrowed(position_ids),
            None => CowTensor::Owned(
                Tensor::f_arange(seq_length, (Kind::Int64, device))?
                    .f_unsqueeze(0)?
                    // The second argument is ATen's `implicit` flag, which
                    // only marks expands that broadcasting inserted
                    // automatically; `false` is correct for an explicit
                    // expand like this one.
                    .f_expand(&input_shape, false)?,
            ),
        };

        let token_type_ids = match token_type_ids {
            Some(token_type_ids) => CowTensor::Borrowed(token_type_ids),
            None => CowTensor::Owned(Tensor::f_zeros(&input_shape, (Kind::Int64, device))?),
        };

        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let position_embeddings = self.position_embeddings.forward(&position_ids)?;
        let token_type_embeddings = self.token_type_embeddings.forward(&token_type_ids)?;

        let embeddings = input_embeddings
            .f_add(&position_embeddings)?
            .f_add(&token_type_embeddings)?;
        let embeddings = self.layer_norm.forward(&embeddings)?;
        self.dropout.forward_t(&embeddings, train)
    }
}

impl FallibleModuleT for BertEmbeddings {
    type Error = TransformerError;

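    /// Apply the embeddings with the default position and token type
    /// identifiers (see [`BertEmbeddings::forward`]).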
    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        self.forward(input, None, None, train)
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use maplit::btreeset;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use crate::activations::Activation;
    use crate::models::bert::{BertConfig, BertEmbeddings};
    use crate::module::FallibleModuleT;

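    // Path of the pretrained German BERT parameters, supplied through the
    // BERT_BASE_GERMAN_CASED environment variable at compile time.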
    const BERT_BASE_GERMAN_CASED: &str = env!("BERT_BASE_GERMAN_CASED");

    fn german_bert_config() -> BertConfig {
        BertConfig {
            attention_probs_dropout_prob: 0.1,
            hidden_act: Activation::Gelu,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            layer_norm_eps: 1e-12,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables()
            .into_iter()
            .map(|(k, _)| k)
            .collect::<BTreeSet<_>>()
    }

    #[test]
    fn bert_embeddings() {
        let config = german_bert_config();
        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = BertEmbeddings::new(root.sub("embeddings"), &config).unwrap();

        vs.load(BERT_BASE_GERMAN_CASED).unwrap();

        // Word pieces of: Veruntreute die AWO Spendengeld ?
        let pieces = Tensor::of_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2])
            .reshape(&[1, 10]);

        let summed_embeddings =
            embeddings
                .forward_t(&pieces, false)
                .unwrap()
                .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_embeddings).try_into().unwrap();

        // Verify output against Hugging Face transformers Python
        // implementation.
        assert_abs_diff_eq!(
            sums,
            (array![[
                -8.0342, -7.3383, -10.1286, 7.7298, 2.3506, -2.3831, -0.5961, -4.6270, -6.5415,
                2.1995
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn bert_embeddings_names() {
        let config = german_bert_config();

        let vs = VarStore::new(Device::Cpu);
        let _ = BertEmbeddings::new(vs.root_ext(|_| 0), &config);

        let variables = varstore_variables(&vs);

        assert_eq!(
            variables,
            btreeset![
                "layer_norm.bias".to_string(),
                "layer_norm.weight".to_string(),
                "position_embeddings.embeddings".to_string(),
                "token_type_embeddings.embeddings".to_string(),
                "word_embeddings.embeddings".to_string()
            ]
        );
    }
}