syntaxdot_transformers/models/squeeze_bert/encoder.rs

use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::Tensor;

use crate::error::TransformerError;
use crate::models::layer_output::LayerOutput;
use crate::models::squeeze_bert::{SqueezeBertConfig, SqueezeBertLayer};
use crate::models::Encoder;
use crate::util::LogitsMask;

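/// SqueezeBERT encoder.
///
/// In contrast to a regular BERT encoder, SqueezeBERT layers use grouped
/// 1-D convolutions and process hidden states in
/// `[batch_size, hidden_size, seq_len]` order. `encode` permutes its input
/// into this layout and permutes every layer output back to
/// `[batch_size, seq_len, hidden_size]`.
///
/// A minimal usage sketch, assuming a variable store root and a loaded
/// configuration (the tests below show a concrete setup):
///
/// ```ignore
/// let encoder = SqueezeBertEncoder::new(root.sub("encoder"), &config)?;
/// let layer_outputs = encoder.encode(&embeddings, None, false)?;
/// let last_hidden = layer_outputs.last().unwrap().output();
/// ```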
#[derive(Debug)]
pub struct SqueezeBertEncoder {
    layers: Vec<SqueezeBertLayer>,
}

impl SqueezeBertEncoder {
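    /// Construct an encoder from the given configuration, creating
    /// `config.num_hidden_layers` SqueezeBERT layers with parameters
    /// stored under `layer_{idx}` in the variable store path.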
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &SqueezeBertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let layers = (0..config.num_hidden_layers)
            .map(|layer| SqueezeBertLayer::new(vs / format!("layer_{}", layer), config))
            .collect::<Result<_, _>>()?;

        Ok(SqueezeBertEncoder { layers })
    }
}

impl Encoder for SqueezeBertEncoder {
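    /// Encode the input representations. `input` must be in
    /// `[batch_size, seq_len, hidden_size]` order. The returned outputs
    /// are the embedding layer output followed by one output per hidden
    /// layer, in the same order.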
    fn encode(
        &self,
        input: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<Vec<LayerOutput>, TransformerError> {
        // Convert the boolean attention mask into a logits mask.
        let attention_mask = attention_mask.map(LogitsMask::from_bool_mask).transpose()?;

        // [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size, seq_len]
        let mut hidden_states = input.f_permute(&[0, 2, 1])?;

        let mut all_layer_outputs = Vec::with_capacity(self.layers.len() + 1);
        all_layer_outputs.push(LayerOutput::Embedding(hidden_states.shallow_clone()));

        for layer in &self.layers {
            let layer_output = layer.forward_t(&hidden_states, attention_mask.as_ref(), train)?;

            hidden_states = layer_output.output().shallow_clone();
            all_layer_outputs.push(layer_output);
        }

        // Convert all layer outputs back to [batch_size, seq_len, hidden_size].
        for layer_output in &mut all_layer_outputs {
            *layer_output.output_mut() = layer_output.output().f_permute(&[0, 2, 1])?;
        }

        Ok(all_layer_outputs)
    }

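    /// The number of layers: the hidden layers plus the embedding layer.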
    fn n_layers(&self) -> i64 {
        self.layers.len() as i64 + 1
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use maplit::btreeset;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use super::SqueezeBertEncoder;
    use crate::activations::Activation;
    use crate::models::bert::{BertConfig, BertEmbeddings};
    use crate::models::squeeze_bert::SqueezeBertConfig;
    use crate::models::Encoder;
    use crate::module::FallibleModuleT;

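    // Path to the pretrained squeezebert-uncased parameters, provided
    // through the SQUEEZEBERT_UNCASED environment variable at compile time.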
    const SQUEEZEBERT_UNCASED: &str = env!("SQUEEZEBERT_UNCASED");

    fn squeezebert_uncased_config() -> SqueezeBertConfig {
        SqueezeBertConfig {
            attention_probs_dropout_prob: 0.1,
            embedding_size: 768,
            hidden_act: Activation::Gelu,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            layer_norm_eps: 1e-12,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30528,
            q_groups: 4,
            k_groups: 4,
            v_groups: 4,
            post_attention_groups: 1,
            intermediate_groups: 4,
            output_groups: 4,
        }
    }

    fn layer_variables() -> BTreeSet<String> {
        btreeset![
            "post_attention.conv1d.bias".to_string(),
            "post_attention.conv1d.weight".to_string(),
            "post_attention.layer_norm.bias".to_string(),
            "post_attention.layer_norm.weight".to_string(),
            "attention.key.bias".to_string(),
            "attention.key.weight".to_string(),
            "attention.query.bias".to_string(),
            "attention.query.weight".to_string(),
            "attention.value.bias".to_string(),
            "attention.value.weight".to_string(),
            "intermediate.conv1d.bias".to_string(),
            "intermediate.conv1d.weight".to_string(),
            "output.conv1d.bias".to_string(),
            "output.conv1d.weight".to_string(),
            "output.layer_norm.bias".to_string(),
            "output.layer_norm.weight".to_string()
        ]
    }

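    // Convert per-sequence lengths into a boolean mask of shape
    // [batch_size, max_len]. For example (hypothetical values):
    // seqlen_to_mask(Tensor::of_slice(&[2, 3]), 4) gives
    // [[true, true, false, false], [true, true, true, false]].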
    fn seqlen_to_mask(seq_lens: Tensor, max_len: i64) -> Tensor {
        let batch_size = seq_lens.size()[0];
        Tensor::arange(max_len, (Kind::Int, Device::Cpu))
            // Construct a matrix [batch_size, max_len] where each row
            // is the range 0..max_len.
            .repeat(&[batch_size])
            .view_(&[batch_size, max_len])
            // Time steps before the sequence length are active.
            .lt_tensor(&seq_lens.unsqueeze(1))
    }

    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables()
            .into_iter()
            .map(|(k, _)| k)
            .collect::<BTreeSet<_>>()
    }

    #[test]
    fn squeeze_bert_encoder() {
        let config = squeezebert_uncased_config();
        let bert_config: BertConfig = (&config).into();

        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = BertEmbeddings::new(root.sub("embeddings"), &bert_config).unwrap();
        let encoder = SqueezeBertEncoder::new(root.sub("encoder"), &config).unwrap();

        vs.load(SQUEEZEBERT_UNCASED).unwrap();

        // Word piece ids of a tokenized sentence (BERT uncased vocabulary).
        let pieces =
            Tensor::of_slice(&[2106i64, 1996, 22091, 2080, 7861, 4783, 17644, 11440, 1029])
                .reshape(&[1, 9]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        let all_hidden_states = encoder.encode(&embeddings, None, false).unwrap();

        // Compare the per-position sums of the last hidden layer.
        let summed_last_hidden = all_hidden_states
            .last()
            .unwrap()
            .output()
            .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_last_hidden).try_into().unwrap();

        assert_abs_diff_eq!(
            sums,
            (array![[
                -0.3894, -0.4608, -0.4127, -0.1656, -0.3927, -0.1952, -0.4998, -0.2477, -0.1676
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn squeeze_bert_encoder_attention_mask() {
        let config = squeezebert_uncased_config();
        let bert_config: BertConfig = (&config).into();

        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = BertEmbeddings::new(root.sub("embeddings"), &bert_config).unwrap();
        let encoder = SqueezeBertEncoder::new(root.sub("encoder"), &config).unwrap();

        vs.load(SQUEEZEBERT_UNCASED).unwrap();

        // The same pieces as in squeeze_bert_encoder, zero-padded to
        // length 14.
        let pieces = Tensor::of_slice(&[
            2106i64, 1996, 22091, 2080, 7861, 4783, 17644, 11440, 1029, 0, 0, 0, 0, 0,
        ])
        .reshape(&[1, 14]);

        let attention_mask = seqlen_to_mask(Tensor::of_slice(&[9]), pieces.size()[1]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        let all_hidden_states = encoder
            .encode(&embeddings, Some(&attention_mask), false)
            .unwrap();

        // Only compare the unmasked time steps; with a correct attention
        // mask, their representations should match the unpadded run.
        let summed_last_hidden = all_hidden_states
            .last()
            .unwrap()
            .output()
            .slice(-2, 0, 9, 1)
            .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_last_hidden).try_into().unwrap();

        assert_abs_diff_eq!(
            sums,
            (array![[
                -0.3894, -0.4608, -0.4127, -0.1656, -0.3927, -0.1952, -0.4998, -0.2477, -0.1676
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn squeeze_bert_encoder_names_and_shapes() {
        let config = squeezebert_uncased_config();

        let vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let _encoder = SqueezeBertEncoder::new(root, &config).unwrap();

        let variables = varstore_variables(&vs);

        // Each layer should contribute the full set of layer variables
        // under its layer_{idx} prefix.
        let mut encoder_variables = BTreeSet::new();
        let layer_variables = layer_variables();
        for idx in 0..config.num_hidden_layers {
            for layer_variable in &layer_variables {
                encoder_variables.insert(format!("layer_{}.{}", idx, layer_variable));
            }
        }

        assert_eq!(variables, encoder_variables);
    }
}