syntaxdot_transformers/models/albert/
encoder.rs1use std::borrow::Borrow;
2
3use syntaxdot_tch_ext::PathExt;
4use tch::nn::Module;
5use tch::Tensor;
6
7use crate::error::TransformerError;
8use crate::models::albert::{AlbertConfig, AlbertEmbeddingProjection};
9use crate::models::bert::BertLayer;
10use crate::models::layer_output::LayerOutput;
11use crate::models::Encoder;
12use crate::util::LogitsMask;
13
14#[derive(Debug)]
21pub struct AlbertEncoder {
22 groups: Vec<BertLayer>,
23 n_layers: i64,
24 projection: AlbertEmbeddingProjection,
25}
26
27impl AlbertEncoder {
28 pub fn new<'a>(
29 vs: impl Borrow<PathExt<'a>>,
30 config: &AlbertConfig,
31 ) -> Result<Self, TransformerError> {
32 assert!(
33 config.num_hidden_groups > 0,
34 "Need at least 1 hidden group, got: {}",
35 config.num_hidden_groups
36 );
37
38 let vs = vs.borrow();
39
40 let mut groups = Vec::with_capacity(config.num_hidden_groups as usize);
41 for group_idx in 0..config.num_hidden_groups {
42 groups.push(BertLayer::new(
43 vs.sub(format!("group_{}", group_idx)).sub("inner_group_0"),
44 &config.into(),
45 )?);
46 }
47 let projection = AlbertEmbeddingProjection::new(vs, config)?;
48
49 Ok(AlbertEncoder {
50 groups,
51 n_layers: config.num_hidden_layers,
52 projection,
53 })
54 }
55}
56
57impl Encoder for AlbertEncoder {
58 fn encode(
59 &self,
60 input: &Tensor,
61 attention_mask: Option<&Tensor>,
62 train: bool,
63 ) -> Result<Vec<LayerOutput>, TransformerError> {
64 let mut all_layer_outputs = Vec::with_capacity(self.n_layers as usize + 1);
65
66 let input = self.projection.forward(input);
67
68 all_layer_outputs.push(LayerOutput::Embedding(input.shallow_clone()));
69
70 let attention_mask = attention_mask.map(LogitsMask::from_bool_mask).transpose()?;
71
72 let layers_per_group = self.n_layers as usize / self.groups.len();
73
74 let mut hidden_states = input;
75 for idx in 0..self.n_layers {
76 let layer_output = self.groups[idx as usize / layers_per_group].forward_t(
77 &hidden_states,
78 attention_mask.as_ref(),
79 train,
80 )?;
81
82 hidden_states = layer_output.output().shallow_clone();
83
84 all_layer_outputs.push(layer_output);
85 }
86
87 Ok(all_layer_outputs)
88 }
89
90 fn n_layers(&self) -> i64 {
91 self.n_layers + 1
92 }
93}
94
#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use maplit::btreeset;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use super::AlbertEncoder;
    use crate::activations::Activation;
    use crate::models::albert::{AlbertConfig, AlbertEmbeddings};
    use crate::models::Encoder;
    use crate::module::FallibleModuleT;

    /// Path of the ALBERT base v2 parameters, taken from the environment
    /// at compile time.
    const ALBERT_BASE_V2: &str = env!("ALBERT_BASE_V2");

    /// Configuration used by all tests in this module. Dropout is disabled.
    fn albert_config() -> AlbertConfig {
        AlbertConfig {
            attention_probs_dropout_prob: 0.,
            embedding_size: 128,
            hidden_act: Activation::GeluNew,
            hidden_dropout_prob: 0.,
            hidden_size: 768,
            initializer_range: 0.02,
            inner_group_num: 1,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_groups: 1,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

    /// Names of the variables of a single layer, relative to the layer.
    fn layer_variables() -> BTreeSet<String> {
        btreeset![
            "attention.output.dense.bias",
            "attention.output.dense.weight",
            "attention.output.layer_norm.bias",
            "attention.output.layer_norm.weight",
            "attention.self.key.bias",
            "attention.self.key.weight",
            "attention.self.query.bias",
            "attention.self.query.weight",
            "attention.self.value.bias",
            "attention.self.value.weight",
            "intermediate.dense.bias",
            "intermediate.dense.weight",
            "output.dense.bias",
            "output.dense.weight",
            "output.layer_norm.bias",
            "output.layer_norm.weight"
        ]
        .into_iter()
        .map(String::from)
        .collect()
    }

    /// Convert sequence lengths into a `[batch_size, max_len]` mask that
    /// is true for every position smaller than the sequence's length.
    fn seqlen_to_mask(seq_lens: Tensor, max_len: i64) -> Tensor {
        let batch_size = seq_lens.size()[0];

        // One row of positions `0..max_len` per sequence.
        let positions = Tensor::arange(max_len, (Kind::Int, Device::Cpu))
            .repeat(&[batch_size])
            .view_(&[batch_size, max_len]);

        // Compare every position against the length of its sequence.
        positions.lt_tensor(&seq_lens.unsqueeze(1))
    }

    /// Collect the names of all variables in a variable store.
    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables().into_iter().map(|(name, _)| name).collect()
    }

    /// Load the pretrained model, encode `pieces` and return the output of
    /// the last layer summed over the last dimension.
    fn last_layer_sums(pieces: &Tensor, attention_mask: Option<&Tensor>) -> ArrayD<f32> {
        let config = albert_config();

        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = AlbertEmbeddings::new(root.sub("embeddings"), &config).unwrap();
        let encoder = AlbertEncoder::new(root.sub("encoder"), &config).unwrap();

        vs.load(ALBERT_BASE_V2).unwrap();

        let embedded = embeddings.forward_t(pieces, false).unwrap();
        let layer_outputs = encoder.encode(&embedded, attention_mask, false).unwrap();

        let summed = layer_outputs
            .last()
            .unwrap()
            .output()
            .sum_dim(-1, false, Kind::Float);

        (&summed).try_into().unwrap()
    }

    #[test]
    fn albert_encoder() {
        let pieces = Tensor::of_slice(&[
            5399i64, 9730, 2853, 15, 6784, 122, 315, 15, 129, 1865, 14, 686, 9,
        ])
        .reshape(&[1, 13]);

        let sums = last_layer_sums(&pieces, None);

        let expected: ArrayD<f32> = array![[
            -19.8755, -22.0879, -22.1255, -22.1221, -22.1466, -21.9200, -21.7490, -22.4941,
            -21.7783, -21.9916, -21.5745, -22.1786, -21.9594
        ]]
        .into_dyn();

        assert_abs_diff_eq!(sums, expected, epsilon = 1e-3);
    }

    #[test]
    fn albert_encoder_attention_mask() {
        // Same pieces as in `albert_encoder`, followed by two padding pieces
        // that are excluded through the attention mask.
        let pieces = Tensor::of_slice(&[
            5399i64, 9730, 2853, 15, 6784, 122, 315, 15, 129, 1865, 14, 686, 9, 0, 0,
        ])
        .reshape(&[1, 15]);

        let attention_mask = seqlen_to_mask(Tensor::of_slice(&[13]), pieces.size()[1]);

        let sums = last_layer_sums(&pieces, Some(&attention_mask));

        let expected: ArrayD<f32> = array![[
            -19.8755, -22.0879, -22.1255, -22.1221, -22.1466, -21.9200, -21.7490, -22.4941,
            -21.7783, -21.9916, -21.5745, -22.1786, -21.9594, -21.7832, -21.7523
        ]]
        .into_dyn();

        assert_abs_diff_eq!(sums, expected, epsilon = 1e-3);
    }

    #[test]
    fn albert_encoder_names() {
        let config = albert_config();

        let vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let _encoder = AlbertEncoder::new(root, &config).unwrap();

        // The single group's layer variables plus the embedding projection.
        let mut encoder_variables: BTreeSet<String> = layer_variables()
            .iter()
            .map(|layer_variable| format!("group_0.inner_group_0.{}", layer_variable))
            .collect();
        encoder_variables.insert("embedding_projection.weight".to_string());
        encoder_variables.insert("embedding_projection.bias".to_string());

        assert_eq!(encoder_variables, varstore_variables(&vs));
    }
}