syntaxdot_transformers/models/squeeze_bert/encoder.rs

// Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team.
// Copyright (c) 2020 TensorDot.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::Tensor;

use crate::error::TransformerError;
use crate::models::layer_output::LayerOutput;
use crate::models::squeeze_bert::{SqueezeBertConfig, SqueezeBertLayer};
use crate::models::Encoder;
use crate::util::LogitsMask;

/// SqueezeBERT encoder.
///
/// Even though SqueezeBERT uses *[batch_size, hidden_size, seq_len]*
/// format internally, the encoder accepts the regular *[batch_size,
/// seq_len, hidden_size]* format.
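///
/// A minimal usage sketch (marked `ignore`: the variable store path
/// `vs`, the `config`, and an `embeddings` tensor of shape
/// *[batch_size, seq_len, hidden_size]* are assumed to be set up by
/// the caller):
///
/// ```ignore
/// let encoder = SqueezeBertEncoder::new(vs.sub("encoder"), &config)?;
/// let layer_outputs = encoder.encode(&embeddings, None, false)?;
/// // Hidden states of the last layer, in [batch_size, seq_len,
/// // hidden_size] format.
/// let last_hidden = layer_outputs.last().unwrap().output();
/// ```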
#[derive(Debug)]
pub struct SqueezeBertEncoder {
    layers: Vec<SqueezeBertLayer>,
}

impl SqueezeBertEncoder {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &SqueezeBertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let layers = (0..config.num_hidden_layers)
            .map(|layer| SqueezeBertLayer::new(vs / format!("layer_{}", layer), config))
            .collect::<Result<_, _>>()?;

        Ok(SqueezeBertEncoder { layers })
    }
}

impl Encoder for SqueezeBertEncoder {
    fn encode(
        &self,
        input: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<Vec<LayerOutput>, TransformerError> {
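        // Convert the boolean attention mask into a logits mask for use
        // in the attention layers.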
        let attention_mask = attention_mask.map(LogitsMask::from_bool_mask).transpose()?;

        // [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size, seq_len]
        let mut hidden_states = input.f_permute(&[0, 2, 1])?;

        let mut all_layer_outputs = Vec::with_capacity(self.layers.len() + 1);
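        // The embedding output is stored as the first layer output.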
        all_layer_outputs.push(LayerOutput::Embedding(hidden_states.shallow_clone()));

        for layer in &self.layers {
            let layer_output = layer.forward_t(&hidden_states, attention_mask.as_ref(), train)?;

            hidden_states = layer_output.output().shallow_clone();
            all_layer_outputs.push(layer_output);
        }

        // Convert hidden states to [batch_size, seq_len, hidden_size].
        for layer_output in &mut all_layer_outputs {
            *layer_output.output_mut() = layer_output.output().f_permute(&[0, 2, 1])?;
        }

        Ok(all_layer_outputs)
    }

    fn n_layers(&self) -> i64 {
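        // Add one to account for the embedding output.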
        self.layers.len() as i64 + 1
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use maplit::btreeset;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use super::SqueezeBertEncoder;
    use crate::activations::Activation;
    use crate::models::bert::{BertConfig, BertEmbeddings};
    use crate::models::squeeze_bert::SqueezeBertConfig;
    use crate::models::Encoder;
    use crate::module::FallibleModuleT;

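    // Path of the pretrained squeezebert-uncased parameters, provided
    // through a compile-time environment variable.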
    const SQUEEZEBERT_UNCASED: &str = env!("SQUEEZEBERT_UNCASED");

    fn squeezebert_uncased_config() -> SqueezeBertConfig {
        SqueezeBertConfig {
            attention_probs_dropout_prob: 0.1,
            embedding_size: 768,
            hidden_act: Activation::Gelu,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            layer_norm_eps: 1e-12,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30528,
            q_groups: 4,
            k_groups: 4,
            v_groups: 4,
            post_attention_groups: 1,
            intermediate_groups: 4,
            output_groups: 4,
        }
    }

    fn layer_variables() -> BTreeSet<String> {
        btreeset![
            "post_attention.conv1d.bias".to_string(),
            "post_attention.conv1d.weight".to_string(),
            "post_attention.layer_norm.bias".to_string(),
            "post_attention.layer_norm.weight".to_string(),
            "attention.key.bias".to_string(),
            "attention.key.weight".to_string(),
            "attention.query.bias".to_string(),
            "attention.query.weight".to_string(),
            "attention.value.bias".to_string(),
            "attention.value.weight".to_string(),
            "intermediate.conv1d.bias".to_string(),
            "intermediate.conv1d.weight".to_string(),
            "output.conv1d.bias".to_string(),
            "output.conv1d.weight".to_string(),
            "output.layer_norm.bias".to_string(),
            "output.layer_norm.weight".to_string()
        ]
    }

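    // Convert a tensor of sequence lengths with shape [batch_size] into
    // a boolean mask with shape [batch_size, max_len], in which active
    // time steps are marked as true.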
    fn seqlen_to_mask(seq_lens: Tensor, max_len: i64) -> Tensor {
        let batch_size = seq_lens.size()[0];
        Tensor::arange(max_len, (Kind::Int, Device::Cpu))
            // Construct a matrix [batch_size, max_len] where each row
            // is 0..max_len.
            .repeat(&[batch_size])
            .view_(&[batch_size, max_len])
            // Time steps less than the length in seq_lens are active.
            .lt_tensor(&seq_lens.unsqueeze(1))
    }

    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables()
            .into_iter()
            .map(|(k, _)| k)
            .collect::<BTreeSet<_>>()
    }

    #[test]
    fn squeeze_bert_encoder() {
        let config = squeezebert_uncased_config();
        let bert_config: BertConfig = (&config).into();

        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = BertEmbeddings::new(root.sub("embeddings"), &bert_config).unwrap();
        let encoder = SqueezeBertEncoder::new(root.sub("encoder"), &config).unwrap();

        vs.load(SQUEEZEBERT_UNCASED).unwrap();

        // Word pieces of: Did the AWO embezzle donations ?
        let pieces =
            Tensor::of_slice(&[2106i64, 1996, 22091, 2080, 7861, 4783, 17644, 11440, 1029])
                .reshape(&[1, 9]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        let all_hidden_states = encoder.encode(&embeddings, None, false).unwrap();

        let summed_last_hidden =
            all_hidden_states
                .last()
                .unwrap()
                .output()
                .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_last_hidden).try_into().unwrap();

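        // Compare the per-piece sums of the last layer's hidden units
        // against reference values.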
        assert_abs_diff_eq!(
            sums,
            (array![[
                -0.3894, -0.4608, -0.4127, -0.1656, -0.3927, -0.1952, -0.4998, -0.2477, -0.1676
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn squeeze_bert_encoder_attention_mask() {
        let config = squeezebert_uncased_config();
        let bert_config: BertConfig = (&config).into();

        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = BertEmbeddings::new(root.sub("embeddings"), &bert_config).unwrap();
        let encoder = SqueezeBertEncoder::new(root.sub("encoder"), &config).unwrap();

        vs.load(SQUEEZEBERT_UNCASED).unwrap();

        // Word pieces of: Did the AWO embezzle donations ?
        // Add some padding to simulate inactive time steps.
        let pieces = Tensor::of_slice(&[
            2106i64, 1996, 22091, 2080, 7861, 4783, 17644, 11440, 1029, 0, 0, 0, 0, 0,
        ])
        .reshape(&[1, 14]);

        let attention_mask = seqlen_to_mask(Tensor::of_slice(&[9]), pieces.size()[1]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        let all_hidden_states = encoder
            .encode(&embeddings, Some(&attention_mask), false)
            .unwrap();

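        // If the attention mask is applied correctly, the activations of
        // the real pieces are unaffected by the padding, so slicing off
        // the padding time steps should reproduce the reference values of
        // the unpadded test above.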
        let summed_last_hidden = all_hidden_states
            .last()
            .unwrap()
            .output()
            .slice(-2, 0, 9, 1)
            .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_last_hidden).try_into().unwrap();

        assert_abs_diff_eq!(
            sums,
            (array![[
                -0.3894, -0.4608, -0.4127, -0.1656, -0.3927, -0.1952, -0.4998, -0.2477, -0.1676
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn squeeze_bert_encoder_names_and_shapes() {
        // Verify that the encoder's names and shapes are correct.
        let config = squeezebert_uncased_config();

        let vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let _encoder = SqueezeBertEncoder::new(root, &config).unwrap();

        let variables = varstore_variables(&vs);

        let mut encoder_variables = BTreeSet::new();
        let layer_variables = layer_variables();
        for idx in 0..config.num_hidden_layers {
            for layer_variable in &layer_variables {
                encoder_variables.insert(format!("layer_{}.{}", idx, layer_variable));
            }
        }

        assert_eq!(variables, encoder_variables);
    }
}