// syntaxdot_transformers/models/bert/embeddings.rs

use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::nn::Init;
use tch::{Kind, Tensor};

use crate::cow::CowTensor;
use crate::layers::{Dropout, Embedding, LayerNorm};
use crate::models::bert::config::BertConfig;
use crate::module::{FallibleModule, FallibleModuleT};
use crate::TransformerError;

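/// BERT input embeddings.
///
/// The embedding of a piece is the sum of three embeddings: its word
/// embedding, the embedding of its position, and the embedding of its
/// token type (segment). The summed embedding is layer-normalized and
/// dropout is applied to it.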
#[derive(Debug)]
pub struct BertEmbeddings {
    position_embeddings: Embedding,
    token_type_embeddings: Embedding,
    word_embeddings: Embedding,

    layer_norm: LayerNorm,
    dropout: Dropout,
}

impl BertEmbeddings {
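    /// Construct new BERT embeddings with the given configuration.
    ///
    /// The embedding matrices are initialized from a normal distribution
    /// with mean 0 and the standard deviation given by
    /// `initializer_range` in the configuration. Pretrained weights can
    /// be loaded into the enclosing variable store afterwards.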
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let normal_init = Init::Randn {
            mean: 0.,
            stdev: config.initializer_range,
        };

        let word_embeddings = Embedding::new(
            vs / "word_embeddings",
            "embeddings",
            config.vocab_size,
            config.hidden_size,
            normal_init,
        )?;

        let position_embeddings = Embedding::new(
            vs / "position_embeddings",
            "embeddings",
            config.max_position_embeddings,
            config.hidden_size,
            normal_init,
        )?;

        let token_type_embeddings = Embedding::new(
            vs / "token_type_embeddings",
            "embeddings",
            config.type_vocab_size,
            config.hidden_size,
            normal_init,
        )?;

        let layer_norm = LayerNorm::new(
            vs / "layer_norm",
            vec![config.hidden_size],
            config.layer_norm_eps,
            true,
        );

        let dropout = Dropout::new(config.hidden_dropout_prob);

        Ok(BertEmbeddings {
            position_embeddings,
            token_type_embeddings,
            word_embeddings,
            layer_norm,
            dropout,
        })
    }

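    /// Compute the embeddings of a batch of pieces.
    ///
    /// `input_ids` must have the shape `[batch_size, seq_len]`. When
    /// `token_type_ids` is not given, every piece gets token type `0`;
    /// when `position_ids` is not given, the pieces of each sequence are
    /// numbered `0..seq_len`.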
    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, TransformerError> {
        let input_shape = input_ids.size();

        let seq_length = input_shape[1];
        let device = input_ids.device();

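        // Without explicit position IDs, number the pieces of each
        // sequence in the batch 0..seq_length.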
        let position_ids = match position_ids {
            Some(position_ids) => CowTensor::Borrowed(position_ids),
            None => CowTensor::Owned(
                Tensor::f_arange(seq_length, (Kind::Int64, device))?
                    .f_unsqueeze(0)?
                    .f_expand(&input_shape, false)?,
            ),
        };

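        // Without explicit token type IDs, treat the whole input as a
        // single segment (token type 0).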
        let token_type_ids = match token_type_ids {
            Some(token_type_ids) => CowTensor::Borrowed(token_type_ids),
            None => CowTensor::Owned(Tensor::f_zeros(&input_shape, (Kind::Int64, device))?),
        };

        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let position_embeddings = self.position_embeddings.forward(&position_ids)?;
        let token_type_embeddings = self.token_type_embeddings.forward(&token_type_ids)?;

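        // Sum the three embeddings, then apply layer normalization and
        // dropout (the latter is only active during training).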
        let embeddings = input_embeddings
            .f_add(&position_embeddings)?
            .f_add(&token_type_embeddings)?;
        let embeddings = self.layer_norm.forward(&embeddings)?;
        self.dropout.forward_t(&embeddings, train)
    }
}

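// When the embeddings are used as a module, the token type and position
// IDs fall back to their defaults (see `BertEmbeddings::forward`).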
impl FallibleModuleT for BertEmbeddings {
    type Error = TransformerError;

    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        self.forward(input, None, None, train)
    }
}

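// A minimal usage sketch (hypothetical setup: `model.pt` stands in for a
// file with pretrained weights, and `config` mirrors `german_bert_config`
// in the tests below):
//
//     let mut vs = VarStore::new(Device::Cpu);
//     let embeddings = BertEmbeddings::new(vs.root_ext(|_| 0), &config)?;
//     vs.load("model.pt")?;
//     let pieces = Tensor::of_slice(&[133i64, 1937]).reshape(&[1, 2]);
//     let hidden = embeddings.forward_t(&pieces, false)?;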
#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use maplit::btreeset;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use crate::activations::Activation;
    use crate::models::bert::{BertConfig, BertEmbeddings};
    use crate::module::FallibleModuleT;

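    // The path of a pretrained German BERT variable store, provided
    // through the `BERT_BASE_GERMAN_CASED` environment variable at
    // compile time.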
    const BERT_BASE_GERMAN_CASED: &str = env!("BERT_BASE_GERMAN_CASED");

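    // A BERT base configuration (here: German BERT, cased).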
    fn german_bert_config() -> BertConfig {
        BertConfig {
            attention_probs_dropout_prob: 0.1,
            hidden_act: Activation::Gelu,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            layer_norm_eps: 1e-12,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

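    // Collect the names of all variables in a variable store.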
    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables()
            .into_iter()
            .map(|(k, _)| k)
            .collect::<BTreeSet<_>>()
    }

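    // Check the embeddings against pretrained weights: embed a fixed
    // sequence of pieces and compare the per-piece sums over the hidden
    // dimension with reference values.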
    #[test]
    fn bert_embeddings() {
        let config = german_bert_config();
        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = BertEmbeddings::new(root.sub("embeddings"), &config).unwrap();

        vs.load(BERT_BASE_GERMAN_CASED).unwrap();

        let pieces = Tensor::of_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2])
            .reshape(&[1, 10]);

        let summed_embeddings =
            embeddings
                .forward_t(&pieces, false)
                .unwrap()
                .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_embeddings).try_into().unwrap();

        assert_abs_diff_eq!(
            sums,
            (array![[
                -8.0342, -7.3383, -10.1286, 7.7298, 2.3506, -2.3831, -0.5961, -4.6270, -6.5415,
                2.1995
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

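    // Check that the embeddings create their variables under the
    // expected names, so that pretrained checkpoints can be loaded.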
    #[test]
    fn bert_embeddings_names() {
        let config = german_bert_config();

        let vs = VarStore::new(Device::Cpu);
        let _ = BertEmbeddings::new(vs.root_ext(|_| 0), &config);

        let variables = varstore_variables(&vs);

        assert_eq!(
            variables,
            btreeset![
                "layer_norm.bias".to_string(),
                "layer_norm.weight".to_string(),
                "position_embeddings.embeddings".to_string(),
                "token_type_embeddings.embeddings".to_string(),
                "word_embeddings.embeddings".to_string()
            ]
        );
    }
}