syntaxdot_transformers/models/albert/
encoder.rs

1use std::borrow::Borrow;
2
3use syntaxdot_tch_ext::PathExt;
4use tch::nn::Module;
5use tch::Tensor;
6
7use crate::error::TransformerError;
8use crate::models::albert::{AlbertConfig, AlbertEmbeddingProjection};
9use crate::models::bert::BertLayer;
10use crate::models::layer_output::LayerOutput;
11use crate::models::Encoder;
12use crate::util::LogitsMask;
13
14/// ALBERT encoder.
15///
16/// This encoder uses the BERT encoder with two modifications:
17///
18/// 1. The embeddings are projected to fit the hidden layer size.
19/// 2. Weights are shared between layers.
20#[derive(Debug)]
21pub struct AlbertEncoder {
22    groups: Vec<BertLayer>,
23    n_layers: i64,
24    projection: AlbertEmbeddingProjection,
25}
26
27impl AlbertEncoder {
28    pub fn new<'a>(
29        vs: impl Borrow<PathExt<'a>>,
30        config: &AlbertConfig,
31    ) -> Result<Self, TransformerError> {
32        assert!(
33            config.num_hidden_groups > 0,
34            "Need at least 1 hidden group, got: {}",
35            config.num_hidden_groups
36        );
37
38        let vs = vs.borrow();
39
40        let mut groups = Vec::with_capacity(config.num_hidden_groups as usize);
41        for group_idx in 0..config.num_hidden_groups {
42            groups.push(BertLayer::new(
43                vs.sub(format!("group_{}", group_idx)).sub("inner_group_0"),
44                &config.into(),
45            )?);
46        }
47        let projection = AlbertEmbeddingProjection::new(vs, config)?;
48
49        Ok(AlbertEncoder {
50            groups,
51            n_layers: config.num_hidden_layers,
52            projection,
53        })
54    }
55}
56
57impl Encoder for AlbertEncoder {
58    fn encode(
59        &self,
60        input: &Tensor,
61        attention_mask: Option<&Tensor>,
62        train: bool,
63    ) -> Result<Vec<LayerOutput>, TransformerError> {
64        let mut all_layer_outputs = Vec::with_capacity(self.n_layers as usize + 1);
65
66        let input = self.projection.forward(input);
67
68        all_layer_outputs.push(LayerOutput::Embedding(input.shallow_clone()));
69
70        let attention_mask = attention_mask.map(LogitsMask::from_bool_mask).transpose()?;
71
72        let layers_per_group = self.n_layers as usize / self.groups.len();
73
74        let mut hidden_states = input;
75        for idx in 0..self.n_layers {
76            let layer_output = self.groups[idx as usize / layers_per_group].forward_t(
77                &hidden_states,
78                attention_mask.as_ref(),
79                train,
80            )?;
81
82            hidden_states = layer_output.output().shallow_clone();
83
84            all_layer_outputs.push(layer_output);
85        }
86
87        Ok(all_layer_outputs)
88    }
89
90    fn n_layers(&self) -> i64 {
91        self.n_layers + 1
92    }
93}
94
#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use maplit::btreeset;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use super::AlbertEncoder;
    use crate::activations::Activation;
    use crate::models::albert::{AlbertConfig, AlbertEmbeddings};
    use crate::models::layer_output::LayerOutput;
    use crate::models::Encoder;
    use crate::module::FallibleModuleT;

    const ALBERT_BASE_V2: &str = env!("ALBERT_BASE_V2");

    fn albert_config() -> AlbertConfig {
        AlbertConfig {
            attention_probs_dropout_prob: 0.,
            embedding_size: 128,
            hidden_act: Activation::GeluNew,
            hidden_dropout_prob: 0.,
            hidden_size: 768,
            initializer_range: 0.02,
            inner_group_num: 1,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_groups: 1,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

    fn layer_variables() -> BTreeSet<String> {
        btreeset![
            "attention.output.dense.bias".to_string(),
            "attention.output.dense.weight".to_string(),
            "attention.output.layer_norm.bias".to_string(),
            "attention.output.layer_norm.weight".to_string(),
            "attention.self.key.bias".to_string(),
            "attention.self.key.weight".to_string(),
            "attention.self.query.bias".to_string(),
            "attention.self.query.weight".to_string(),
            "attention.self.value.bias".to_string(),
            "attention.self.value.weight".to_string(),
            "intermediate.dense.bias".to_string(),
            "intermediate.dense.weight".to_string(),
            "output.dense.bias".to_string(),
            "output.dense.weight".to_string(),
            "output.layer_norm.bias".to_string(),
            "output.layer_norm.weight".to_string()
        ]
    }

    fn seqlen_to_mask(seq_lens: Tensor, max_len: i64) -> Tensor {
        let batch_size = seq_lens.size()[0];
        Tensor::arange(max_len, (Kind::Int, Device::Cpu))
            // Construct a matrix [batch_size, max_len] where each row
            // is 0..(max_len - 1).
            .repeat(&[batch_size])
            .view_(&[batch_size, max_len])
            // Time steps less than the length in seq_lens are active.
            .lt_tensor(&seq_lens.unsqueeze(1))
    }

    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables()
            .into_iter()
            .map(|(k, _)| k)
            .collect::<BTreeSet<_>>()
    }

    /// Construct the embeddings and encoder, then load the pretrained
    /// ALBERT base v2 parameters into them.
    ///
    /// The var store is returned as well, so that callers keep it alive for
    /// as long as the models are used.
    fn pretrained_albert(config: &AlbertConfig) -> (VarStore, AlbertEmbeddings, AlbertEncoder) {
        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = AlbertEmbeddings::new(root.sub("embeddings"), config).unwrap();
        let encoder = AlbertEncoder::new(root.sub("encoder"), config).unwrap();

        vs.load(ALBERT_BASE_V2).unwrap();

        (vs, embeddings, encoder)
    }

    /// Sum the last layer's hidden representations over the hidden dimension.
    fn summed_last_hidden(layer_outputs: &[LayerOutput]) -> ArrayD<f32> {
        let summed = layer_outputs
            .last()
            .expect("encoder returned no layer outputs")
            .output()
            .sum_dim(-1, false, Kind::Float);

        (&summed).try_into().unwrap()
    }

    #[test]
    fn albert_encoder() {
        let config = albert_config();
        let (_vs, embeddings, encoder) = pretrained_albert(&config);

        // Pierre Vinken [...]
        let pieces = Tensor::of_slice(&[
            5399i64, 9730, 2853, 15, 6784, 122, 315, 15, 129, 1865, 14, 686, 9,
        ])
        .reshape(&[1, 13]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        let all_hidden_states = encoder.encode(&embeddings, None, false).unwrap();

        assert_abs_diff_eq!(
            summed_last_hidden(&all_hidden_states),
            (array![[
                -19.8755, -22.0879, -22.1255, -22.1221, -22.1466, -21.9200, -21.7490, -22.4941,
                -21.7783, -21.9916, -21.5745, -22.1786, -21.9594
            ]])
            .into_dyn(),
            epsilon = 1e-3
        );
    }

    #[test]
    fn albert_encoder_attention_mask() {
        let config = albert_config();
        let (_vs, embeddings, encoder) = pretrained_albert(&config);

        // Pierre Vinken [...]
        let pieces = Tensor::of_slice(&[
            5399i64, 9730, 2853, 15, 6784, 122, 315, 15, 129, 1865, 14, 686, 9, 0, 0,
        ])
        .reshape(&[1, 15]);

        let attention_mask = seqlen_to_mask(Tensor::of_slice(&[13]), pieces.size()[1]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        let all_hidden_states = encoder
            .encode(&embeddings, Some(&attention_mask), false)
            .unwrap();

        assert_abs_diff_eq!(
            summed_last_hidden(&all_hidden_states),
            (array![[
                -19.8755, -22.0879, -22.1255, -22.1221, -22.1466, -21.9200, -21.7490, -22.4941,
                -21.7783, -21.9916, -21.5745, -22.1786, -21.9594, -21.7832, -21.7523
            ]])
            .into_dyn(),
            epsilon = 1e-3
        );
    }

    #[test]
    fn albert_encoder_names() {
        // Verify that the encoder's variable names are correct.
        let config = albert_config();

        let vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let _encoder = AlbertEncoder::new(root, &config).unwrap();

        let mut encoder_variables = BTreeSet::new();
        let layer_variables = layer_variables();
        for layer_variable in &layer_variables {
            encoder_variables.insert(format!("group_0.inner_group_0.{}", layer_variable));
        }
        encoder_variables.insert("embedding_projection.weight".to_string());
        encoder_variables.insert("embedding_projection.bias".to_string());

        assert_eq!(encoder_variables, varstore_variables(&vs));
    }
}