// syntaxdot_transformers/models/bert/encoder.rs
use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::Tensor;

use crate::cow::CowTensor;
use crate::error::TransformerError;
use crate::models::bert::{BertConfig, BertLayer};
use crate::models::layer_output::LayerOutput;
use crate::models::Encoder;
use crate::util::LogitsMask;

/// BERT encoder: a stack of [`BertLayer`] transformer layers.
///
/// The layers are applied in sequence by the [`Encoder`] implementation;
/// one layer is constructed per `config.num_hidden_layers` (see `new`).
#[derive(Debug)]
pub struct BertEncoder {
    // Transformer layers, in application order (`layer_0`, `layer_1`, ...).
    layers: Vec<BertLayer>,
}
34
35impl BertEncoder {
36 pub fn new<'a>(
37 vs: impl Borrow<PathExt<'a>>,
38 config: &BertConfig,
39 ) -> Result<Self, TransformerError> {
40 let vs = vs.borrow();
41
42 let layers = (0..config.num_hidden_layers)
43 .map(|layer| BertLayer::new(vs / format!("layer_{}", layer), config))
44 .collect::<Result<_, _>>()?;
45
46 Ok(BertEncoder { layers })
47 }
48}
49
50impl Encoder for BertEncoder {
51 fn encode(
52 &self,
53 input: &Tensor,
54 attention_mask: Option<&Tensor>,
55 train: bool,
56 ) -> Result<Vec<LayerOutput>, TransformerError> {
57 let mut all_layer_outputs = Vec::with_capacity(self.layers.len() + 1);
58 all_layer_outputs.push(LayerOutput::Embedding(input.shallow_clone()));
59
60 let attention_mask = attention_mask.map(LogitsMask::from_bool_mask).transpose()?;
61
62 let mut hidden_states = CowTensor::Borrowed(input);
63 for layer in &self.layers {
64 let layer_output = layer.forward_t(&hidden_states, attention_mask.as_ref(), train)?;
65
66 hidden_states = CowTensor::Owned(layer_output.output().shallow_clone());
67 all_layer_outputs.push(layer_output);
68 }
69
70 Ok(all_layer_outputs)
71 }
72
73 fn n_layers(&self) -> i64 {
74 self.layers.len() as i64 + 1
75 }
76}
77
// Model tests: these require the `model-tests` feature and a pretrained
// German BERT checkpoint whose filesystem path is baked in at compile time
// via the BERT_BASE_GERMAN_CASED environment variable (see `env!` below).
#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use maplit::btreeset;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use crate::activations::Activation;
    use crate::models::bert::{BertConfig, BertEmbeddings, BertEncoder};
    use crate::models::Encoder;
    use crate::module::FallibleModuleT;

    // Path to the pretrained weights, resolved when the test is compiled.
    const BERT_BASE_GERMAN_CASED: &str = env!("BERT_BASE_GERMAN_CASED");

    /// Configuration used by these tests — presumably matching the
    /// bert-base-german-cased checkpoint loaded above (TODO confirm against
    /// the checkpoint's own config).
    fn german_bert_config() -> BertConfig {
        BertConfig {
            attention_probs_dropout_prob: 0.1,
            hidden_act: Activation::Gelu,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            layer_norm_eps: 1e-12,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

    /// Variable names that a single BERT layer registers, relative to the
    /// layer's own variable path (used by `bert_encoder_names`).
    fn layer_variables() -> BTreeSet<String> {
        btreeset![
            "attention.output.dense.bias".to_string(),
            "attention.output.dense.weight".to_string(),
            "attention.output.layer_norm.bias".to_string(),
            "attention.output.layer_norm.weight".to_string(),
            "attention.self.key.bias".to_string(),
            "attention.self.key.weight".to_string(),
            "attention.self.query.bias".to_string(),
            "attention.self.query.weight".to_string(),
            "attention.self.value.bias".to_string(),
            "attention.self.value.weight".to_string(),
            "intermediate.dense.bias".to_string(),
            "intermediate.dense.weight".to_string(),
            "output.dense.bias".to_string(),
            "output.dense.weight".to_string(),
            "output.layer_norm.bias".to_string(),
            "output.layer_norm.weight".to_string()
        ]
    }

    /// Turn per-sequence lengths into a boolean mask of shape
    /// `[batch_size, max_len]`: position `i` of a row is true iff
    /// `i < seq_len` for that row.
    fn seqlen_to_mask(seq_lens: Tensor, max_len: i64) -> Tensor {
        let batch_size = seq_lens.size()[0];
        // [0, 1, ..., max_len-1] repeated per batch row, then compared
        // element-wise against the broadcast sequence lengths.
        Tensor::arange(max_len, (Kind::Int, Device::Cpu))
            .repeat(&[batch_size])
            .view_(&[batch_size, max_len])
            .lt_tensor(&seq_lens.unsqueeze(1))
    }

    /// Collect the names of all variables registered in a var store.
    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables()
            .into_iter()
            .map(|(k, _)| k)
            .collect::<BTreeSet<_>>()
    }

    #[test]
    fn bert_encoder() {
        let config = german_bert_config();

        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = BertEmbeddings::new(root.sub("embeddings"), &config).unwrap();
        let encoder = BertEncoder::new(root.sub("encoder"), &config).unwrap();

        // Load pretrained weights into the variables created above.
        vs.load(BERT_BASE_GERMAN_CASED).unwrap();

        // A single sequence of 10 piece ids (batch size 1).
        let pieces = Tensor::of_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2])
            .reshape(&[1, 10]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        // No attention mask; train = false so dropout is disabled and the
        // output is deterministic.
        let all_hidden_states = encoder.encode(&embeddings, None, false).unwrap();

        // Sum over the hidden dimension of the last layer's output, giving
        // one scalar per piece.
        let summed_last_hidden =
            all_hidden_states
                .last()
                .unwrap()
                .output()
                .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_last_hidden).try_into().unwrap();

        // Golden values — presumably produced by a reference implementation
        // of this checkpoint; not derivable from this file alone.
        assert_abs_diff_eq!(
            sums,
            (array![[
                -1.6283, 0.2473, -0.2388, -0.4124, -0.4058, 1.4587, -0.3182, -0.9507, -0.1781,
                0.3792
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn bert_encoder_attention_mask() {
        let config = german_bert_config();

        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = BertEmbeddings::new(root.sub("embeddings"), &config).unwrap();
        let encoder = BertEncoder::new(root.sub("encoder"), &config).unwrap();

        vs.load(BERT_BASE_GERMAN_CASED).unwrap();

        // Same 10 pieces as in `bert_encoder`, padded with five zeros to
        // length 15.
        let pieces = Tensor::of_slice(&[
            133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2, 0, 0, 0, 0, 0,
        ])
        .reshape(&[1, 15]);

        // Mask marks only the first 10 positions as valid.
        let attention_mask = seqlen_to_mask(Tensor::of_slice(&[10]), pieces.size()[1]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        let all_hidden_states = encoder
            .encode(&embeddings, Some(&attention_mask), false)
            .unwrap();

        // Drop the padded positions before summing, so the result can be
        // compared against the unpadded run.
        let summed_last_hidden = all_hidden_states
            .last()
            .unwrap()
            .output()
            .slice(-2, 0, 10, 1)
            .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_last_hidden).try_into().unwrap();

        // Same golden values as `bert_encoder`: with a correct attention
        // mask, padding must not change the outputs of the real positions.
        assert_abs_diff_eq!(
            sums,
            (array![[
                -1.6283, 0.2473, -0.2388, -0.4124, -0.4058, 1.4587, -0.3182, -0.9507, -0.1781,
                0.3792
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn bert_encoder_names() {
        let config = german_bert_config();

        let vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        // Constructing the encoder registers all of its variables in `vs`.
        let _encoder = BertEncoder::new(root, &config).unwrap();

        // Expected set: every per-layer variable name, prefixed with
        // `layer_{idx}.` for each hidden layer.
        let mut encoder_variables = BTreeSet::new();
        let layer_variables = layer_variables();
        for idx in 0..config.num_hidden_layers {
            for layer_variable in &layer_variables {
                encoder_variables.insert(format!("layer_{}.{}", idx, layer_variable));
            }
        }

        assert_eq!(varstore_variables(&vs), encoder_variables);
    }
}