syntaxdot_transformers/models/bert/encoder.rs

// Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
// Copyright (c) 2019 The sticker developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::Tensor;

use crate::cow::CowTensor;
use crate::error::TransformerError;
use crate::models::bert::{BertConfig, BertLayer};
use crate::models::layer_output::LayerOutput;
use crate::models::Encoder;
use crate::util::LogitsMask;
/// BERT encoder.
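///
/// The encoder applies a stack of [`BertLayer`]s to piece embeddings and
/// returns the output of each layer. A minimal usage sketch follows. It is
/// not compiled as a doctest (`ignore`); `config` (a `BertConfig`) and
/// `embeddings` (piece embeddings of shape `[batch_size, seq_len,
/// hidden_size]`, e.g. from `BertEmbeddings`) are assumed to come from
/// elsewhere.
///
/// ```ignore
/// use syntaxdot_tch_ext::RootExt;
/// use syntaxdot_transformers::models::bert::BertEncoder;
/// use syntaxdot_transformers::models::Encoder;
/// use tch::nn::VarStore;
/// use tch::Device;
///
/// let vs = VarStore::new(Device::Cpu);
/// let root = vs.root_ext(|_| 0);
/// let encoder = BertEncoder::new(root.sub("encoder"), &config)?;
///
/// // Encode without an attention mask, in evaluation mode.
/// let layer_outputs = encoder.encode(&embeddings, None, false)?;
///
/// // The embedding is returned first, followed by one output per layer.
/// assert_eq!(layer_outputs.len() as i64, encoder.n_layers());
/// ```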
#[derive(Debug)]
pub struct BertEncoder {
    layers: Vec<BertLayer>,
}

impl BertEncoder {
    /// Construct an encoder with `config.num_hidden_layers` layers.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let layers = (0..config.num_hidden_layers)
            .map(|layer| BertLayer::new(vs / format!("layer_{}", layer), config))
            .collect::<Result<_, _>>()?;

        Ok(BertEncoder { layers })
    }
}

impl Encoder for BertEncoder {
    fn encode(
        &self,
        input: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<Vec<LayerOutput>, TransformerError> {
        // Room for the embedding plus the output of every layer.
        let mut all_layer_outputs = Vec::with_capacity(self.layers.len() + 1);
        all_layer_outputs.push(LayerOutput::Embedding(input.shallow_clone()));

        // Convert the boolean attention mask into a logits mask that can be
        // added to the attention scores.
        let attention_mask = attention_mask.map(LogitsMask::from_bool_mask).transpose()?;

        let mut hidden_states = CowTensor::Borrowed(input);
        for layer in &self.layers {
            let layer_output = layer.forward_t(&hidden_states, attention_mask.as_ref(), train)?;

            hidden_states = CowTensor::Owned(layer_output.output().shallow_clone());
            all_layer_outputs.push(layer_output);
        }

        Ok(all_layer_outputs)
    }

    fn n_layers(&self) -> i64 {
        // The embedding output counts as a layer.
        self.layers.len() as i64 + 1
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use maplit::btreeset;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use crate::activations::Activation;
    use crate::models::bert::{BertConfig, BertEmbeddings, BertEncoder};
    use crate::models::Encoder;
    use crate::module::FallibleModuleT;

    const BERT_BASE_GERMAN_CASED: &str = env!("BERT_BASE_GERMAN_CASED");

    fn german_bert_config() -> BertConfig {
        BertConfig {
            attention_probs_dropout_prob: 0.1,
            hidden_act: Activation::Gelu,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            layer_norm_eps: 1e-12,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

    fn layer_variables() -> BTreeSet<String> {
        btreeset![
            "attention.output.dense.bias".to_string(),
            "attention.output.dense.weight".to_string(),
            "attention.output.layer_norm.bias".to_string(),
            "attention.output.layer_norm.weight".to_string(),
            "attention.self.key.bias".to_string(),
            "attention.self.key.weight".to_string(),
            "attention.self.query.bias".to_string(),
            "attention.self.query.weight".to_string(),
            "attention.self.value.bias".to_string(),
            "attention.self.value.weight".to_string(),
            "intermediate.dense.bias".to_string(),
            "intermediate.dense.weight".to_string(),
            "output.dense.bias".to_string(),
            "output.dense.weight".to_string(),
            "output.layer_norm.bias".to_string(),
            "output.layer_norm.weight".to_string()
        ]
    }

    fn seqlen_to_mask(seq_lens: Tensor, max_len: i64) -> Tensor {
        let batch_size = seq_lens.size()[0];
        Tensor::arange(max_len, (Kind::Int, Device::Cpu))
            // Construct a matrix [batch_size, max_len] where each row
            // is 0, 1, ..., max_len - 1.
            .repeat(&[batch_size])
            .view_(&[batch_size, max_len])
            // Time steps less than the sequence length in seq_lens are active.
            .lt_tensor(&seq_lens.unsqueeze(1))
    }
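
    // Sanity check for `seqlen_to_mask`, added as a sketch: the expected
    // values follow directly from the arange/lt_tensor construction above.
    // The boolean mask is converted to i64 for easy comparison with ndarray.
    #[test]
    fn seqlen_to_mask_marks_active_time_steps() {
        let mask = seqlen_to_mask(Tensor::of_slice(&[2i64, 3]), 4).to_kind(Kind::Int64);
        let mask: ArrayD<i64> = (&mask).try_into().unwrap();
        assert_eq!(
            mask,
            array![[1i64, 1, 0, 0], [1, 1, 1, 0]].into_dyn()
        );
    }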

    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables()
            .into_iter()
            .map(|(k, _)| k)
            .collect::<BTreeSet<_>>()
    }

    #[test]
    fn bert_encoder() {
        let config = german_bert_config();

        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = BertEmbeddings::new(root.sub("embeddings"), &config).unwrap();
        let encoder = BertEncoder::new(root.sub("encoder"), &config).unwrap();

        vs.load(BERT_BASE_GERMAN_CASED).unwrap();

        // Word pieces of: "Veruntreute die AWO Spendengeld ?"
        // ("Did the AWO embezzle donation money?")
        let pieces = Tensor::of_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2])
            .reshape(&[1, 10]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        let all_hidden_states = encoder.encode(&embeddings, None, false).unwrap();

        let summed_last_hidden =
            all_hidden_states
                .last()
                .unwrap()
                .output()
                .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_last_hidden).try_into().unwrap();

        assert_abs_diff_eq!(
            sums,
            (array![[
                -1.6283, 0.2473, -0.2388, -0.4124, -0.4058, 1.4587, -0.3182, -0.9507, -0.1781,
                0.3792
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn bert_encoder_attention_mask() {
        let config = german_bert_config();

        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = BertEmbeddings::new(root.sub("embeddings"), &config).unwrap();
        let encoder = BertEncoder::new(root.sub("encoder"), &config).unwrap();

        vs.load(BERT_BASE_GERMAN_CASED).unwrap();

        // Word pieces of: "Veruntreute die AWO Spendengeld ?"
        // ("Did the AWO embezzle donation money?")
        // Add some padding to simulate inactive time steps.
        let pieces = Tensor::of_slice(&[
            133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2, 0, 0, 0, 0, 0,
        ])
        .reshape(&[1, 15]);

        let attention_mask = seqlen_to_mask(Tensor::of_slice(&[10]), pieces.size()[1]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        let all_hidden_states = encoder
            .encode(&embeddings, Some(&attention_mask), false)
            .unwrap();

        // Only the active time steps should match the unpadded case.
        let summed_last_hidden = all_hidden_states
            .last()
            .unwrap()
            .output()
            .slice(-2, 0, 10, 1)
            .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_last_hidden).try_into().unwrap();

        assert_abs_diff_eq!(
            sums,
            (array![[
                -1.6283, 0.2473, -0.2388, -0.4124, -0.4058, 1.4587, -0.3182, -0.9507, -0.1781,
                0.3792
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn bert_encoder_names() {
        // Verify that the encoder's variable names are correct, so that
        // pretrained checkpoints can be loaded into newly-constructed models.
        let config = german_bert_config();

        let vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let _encoder = BertEncoder::new(root, &config).unwrap();

        let mut encoder_variables = BTreeSet::new();
        let layer_variables = layer_variables();
        for idx in 0..config.num_hidden_layers {
            for layer_variable in &layer_variables {
                encoder_variables.insert(format!("layer_{}.{}", idx, layer_variable));
            }
        }

        assert_eq!(varstore_variables(&vs), encoder_variables);
    }
}
263}