use std::borrow::Borrow;
use std::iter;

use syntaxdot_tch_ext::PathExt;
use tch::nn::{Init, Linear, Module};
use tch::{Kind, Tensor};

use crate::activations::Activation;
use crate::error::TransformerError;
use crate::layers::{Dropout, LayerNorm};
use crate::models::bert::config::BertConfig;
use crate::models::layer_output::{HiddenLayer, LayerOutput};
use crate::module::{FallibleModule, FallibleModuleT};
use crate::util::LogitsMask;

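/// Intermediate (position-wise feed-forward) projection of a BERT layer.
///
/// Applies a linear projection from `hidden_size` to `intermediate_size`,
/// followed by the configured activation.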
#[derive(Debug)]
pub struct BertIntermediate {
    dense: Linear,
    activation: Activation,
}

impl BertIntermediate {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        Ok(BertIntermediate {
            activation: config.hidden_act,
            dense: bert_linear(
                vs / "dense",
                config,
                config.hidden_size,
                config.intermediate_size,
                "weight",
                "bias",
            )?,
        })
    }
}

impl FallibleModule for BertIntermediate {
    type Error = TransformerError;

    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error> {
        let hidden_states = self.dense.forward(input);
        self.activation.forward(&hidden_states)
    }
}

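/// A single BERT encoder layer.
///
/// The layer applies multi-head self-attention followed by a position-wise
/// feed-forward network; both sub-layers add a residual connection and
/// layer normalization through `BertSelfOutput` and `BertOutput`.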
#[derive(Debug)]
pub struct BertLayer {
    attention: BertSelfAttention,
    post_attention: BertSelfOutput,
    intermediate: BertIntermediate,
    output: BertOutput,
}

impl BertLayer {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();
        let vs_attention = vs / "attention";

        Ok(BertLayer {
            attention: BertSelfAttention::new(vs_attention.borrow() / "self", config)?,
            post_attention: BertSelfOutput::new(vs_attention.borrow() / "output", config)?,
            intermediate: BertIntermediate::new(vs / "intermediate", config)?,
            output: BertOutput::new(vs / "output", config)?,
        })
    }

    pub(crate) fn forward_t(
        &self,
        input: &Tensor,
        attention_mask: Option<&LogitsMask>,
        train: bool,
    ) -> Result<LayerOutput, TransformerError> {
        let (attention_output, attention) =
            self.attention.forward_t(input, attention_mask, train)?;
        let post_attention_output =
            self.post_attention
                .forward_t(&attention_output, input, train)?;
        let intermediate_output = self.intermediate.forward(&post_attention_output)?;
        let output = self
            .output
            .forward_t(&intermediate_output, &post_attention_output, train)?;

        Ok(LayerOutput::EncoderWithAttention(HiddenLayer {
            output,
            attention,
        }))
    }
}

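/// Output sub-layer of the BERT feed-forward block.
///
/// Projects the intermediate representation back to `hidden_size`, applies
/// dropout, adds the residual input, and layer-normalizes the result.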
#[derive(Debug)]
pub struct BertOutput {
    dense: Linear,
    dropout: Dropout,
    layer_norm: LayerNorm,
}

impl BertOutput {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let dense = bert_linear(
            vs / "dense",
            config,
            config.intermediate_size,
            config.hidden_size,
            "weight",
            "bias",
        )?;
        let dropout = Dropout::new(config.hidden_dropout_prob);
        let layer_norm = LayerNorm::new(
            vs / "layer_norm",
            vec![config.hidden_size],
            config.layer_norm_eps,
            true,
        );

        Ok(BertOutput {
            dense,
            dropout,
            layer_norm,
        })
    }

    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        input: &Tensor,
        train: bool,
    ) -> Result<Tensor, TransformerError> {
        let hidden_states = self.dense.forward(hidden_states);
        let mut hidden_states = self.dropout.forward_t(&hidden_states, train)?;
        let _ = hidden_states.f_add_(input)?;
        self.layer_norm.forward(&hidden_states)
    }
}

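/// Multi-head scaled dot-product self-attention.
///
/// The attention head size is `hidden_size / num_attention_heads`; the
/// query, key, and value projections all map from `hidden_size` to
/// `num_attention_heads * attention_head_size`.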
#[derive(Debug)]
pub struct BertSelfAttention {
    all_head_size: i64,
    attention_head_size: i64,
    num_attention_heads: i64,

    dropout: Dropout,
    key: Linear,
    query: Linear,
    value: Linear,
}

impl BertSelfAttention {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let attention_head_size = config.hidden_size / config.num_attention_heads;
        let all_head_size = config.num_attention_heads * attention_head_size;

        let key = bert_linear(
            vs / "key",
            config,
            config.hidden_size,
            all_head_size,
            "weight",
            "bias",
        )?;
        let query = bert_linear(
            vs / "query",
            config,
            config.hidden_size,
            all_head_size,
            "weight",
            "bias",
        )?;
        let value = bert_linear(
            vs / "value",
            config,
            config.hidden_size,
            all_head_size,
            "weight",
            "bias",
        )?;

        Ok(BertSelfAttention {
            all_head_size,
            attention_head_size,
            num_attention_heads: config.num_attention_heads,

            dropout: Dropout::new(config.attention_probs_dropout_prob),
            key,
            query,
            value,
        })
    }

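    /// Apply scaled dot-product self-attention to `hidden_states`.
    ///
    /// Returns the per-token context vectors together with the unnormalized
    /// (scaled and masked, pre-softmax) attention scores.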
    fn forward_t(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&LogitsMask>,
        train: bool,
    ) -> Result<(Tensor, Tensor), TransformerError> {
        let mixed_key_layer = self.key.forward(hidden_states);
        let mixed_query_layer = self.query.forward(hidden_states);
        let mixed_value_layer = self.value.forward(hidden_states);

        let query_layer = self.transpose_for_scores(&mixed_query_layer)?;
        let key_layer = self.transpose_for_scores(&mixed_key_layer)?;
        let value_layer = self.transpose_for_scores(&mixed_value_layer)?;

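        // Compute raw attention scores as query-key dot products, divided
        // by the square root of the attention head size.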
        let mut attention_scores = query_layer.f_matmul(&key_layer.transpose(-1, -2))?;
        let _ = attention_scores.f_div_scalar_((self.attention_head_size as f64).sqrt())?;

        if let Some(mask) = attention_mask {
            let _ = attention_scores.f_add_(mask)?;
        }

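        // Normalize the attention scores to probabilities.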
        let attention_probs = attention_scores.f_softmax(-1, Kind::Float)?;

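        // Apply dropout to the attention probabilities.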
        let attention_probs = self.dropout.forward_t(&attention_probs, train)?;

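        // The context vectors are attention-weighted sums of the value vectors.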
        let context_layer = attention_probs.f_matmul(&value_layer)?;

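        // Concatenate the attention heads back into a single hidden
        // dimension: [batch, heads, seq, head_size] -> [batch, seq, hidden].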
        let context_layer = context_layer.f_permute(&[0, 2, 1, 3])?.f_contiguous()?;
        let mut new_context_layer_shape = context_layer.size();
        new_context_layer_shape.splice(
            new_context_layer_shape.len() - 2..,
            iter::once(self.all_head_size),
        );
        let context_layer = context_layer.f_view_(&new_context_layer_shape)?;

        Ok((context_layer, attention_scores))
    }

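    /// Split the hidden dimension into attention heads.
    ///
    /// Reshapes `[batch, seq, hidden]` into `[batch, heads, seq, head_size]`.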
    fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor, TransformerError> {
        let mut new_x_shape = x.size();
        new_x_shape.pop();
        new_x_shape.extend(&[self.num_attention_heads, self.attention_head_size]);

        Ok(x.f_view_(&new_x_shape)?.f_permute(&[0, 2, 1, 3])?)
    }
}

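/// Output sub-layer of the BERT self-attention block.
///
/// Projects the attention output, applies dropout, adds the residual input,
/// and layer-normalizes the result.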
#[derive(Debug)]
pub struct BertSelfOutput {
    dense: Linear,
    dropout: Dropout,
    layer_norm: LayerNorm,
}

impl BertSelfOutput {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<BertSelfOutput, TransformerError> {
        let vs = vs.borrow();

        let dense = bert_linear(
            vs / "dense",
            config,
            config.hidden_size,
            config.hidden_size,
            "weight",
            "bias",
        )?;
        let dropout = Dropout::new(config.hidden_dropout_prob);
        let layer_norm = LayerNorm::new(
            vs / "layer_norm",
            vec![config.hidden_size],
            config.layer_norm_eps,
            true,
        );

        Ok(BertSelfOutput {
            dense,
            dropout,
            layer_norm,
        })
    }

    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        input: &Tensor,
        train: bool,
    ) -> Result<Tensor, TransformerError> {
        let hidden_states = self.dense.forward(hidden_states);
        let mut hidden_states = self.dropout.forward_t(&hidden_states, train)?;
        let _ = hidden_states.f_add_(input)?;
        self.layer_norm.forward(&hidden_states)
    }
}

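/// Construct a linear layer with BERT's parameter initialization.
///
/// The weight is drawn from a normal distribution with mean 0 and standard
/// deviation `config.initializer_range`; the bias is initialized to zero.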
pub(crate) fn bert_linear<'a>(
    vs: impl Borrow<PathExt<'a>>,
    config: &BertConfig,
    in_features: i64,
    out_features: i64,
    weight_name: &str,
    bias_name: &str,
) -> Result<Linear, TransformerError> {
    let vs = vs.borrow();

    Ok(Linear {
        ws: vs.var(
            weight_name,
            &[out_features, in_features],
            Init::Randn {
                mean: 0.,
                stdev: config.initializer_range,
            },
        )?,
        bs: Some(vs.var(bias_name, &[out_features], Init::Const(0.))?),
    })
}