syntaxdot_transformers/models/bert/layer.rs

// Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
// Copyright (c) 2019 The sticker developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Borrow;
use std::iter;

use syntaxdot_tch_ext::PathExt;
use tch::nn::{Init, Linear, Module};
use tch::{Kind, Tensor};

use crate::activations::Activation;
use crate::error::TransformerError;
use crate::layers::{Dropout, LayerNorm};
use crate::models::bert::config::BertConfig;
use crate::models::layer_output::{HiddenLayer, LayerOutput};
use crate::module::{FallibleModule, FallibleModuleT};
use crate::util::LogitsMask;

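/// Intermediate (feed-forward) layer of a BERT encoder block: a linear
/// projection to the intermediate size followed by the configured activation.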
#[derive(Debug)]
pub struct BertIntermediate {
    dense: Linear,
    activation: Activation,
}

impl BertIntermediate {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        Ok(BertIntermediate {
            activation: config.hidden_act,
            dense: bert_linear(
                vs / "dense",
                config,
                config.hidden_size,
                config.intermediate_size,
                "weight",
                "bias",
            )?,
        })
    }
}

impl FallibleModule for BertIntermediate {
    type Error = TransformerError;

    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error> {
        let hidden_states = self.dense.forward(input);
        self.activation.forward(&hidden_states)
    }
}

/// BERT layer.
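///
/// A single encoder block: self-attention, the attention output projection
/// (`BertSelfOutput`), the intermediate feed-forward layer, and the final
/// output projection (`BertOutput`).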
#[derive(Debug)]
pub struct BertLayer {
    attention: BertSelfAttention,
    post_attention: BertSelfOutput,
    intermediate: BertIntermediate,
    output: BertOutput,
}

impl BertLayer {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();
        let vs_attention = vs / "attention";

        Ok(BertLayer {
            attention: BertSelfAttention::new(vs_attention.borrow() / "self", config)?,
            post_attention: BertSelfOutput::new(vs_attention.borrow() / "output", config)?,
            intermediate: BertIntermediate::new(vs / "intermediate", config)?,
            output: BertOutput::new(vs / "output", config)?,
        })
    }

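    /// Apply the layer to the input representations.
    ///
    /// Returns the output of the layer together with the attention of its
    /// self-attention sub-layer.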
    pub(crate) fn forward_t(
        &self,
        input: &Tensor,
        attention_mask: Option<&LogitsMask>,
        train: bool,
    ) -> Result<LayerOutput, TransformerError> {
        let (attention_output, attention) =
            self.attention.forward_t(input, attention_mask, train)?;
        let post_attention_output =
            self.post_attention
                .forward_t(&attention_output, input, train)?;
        let intermediate_output = self.intermediate.forward(&post_attention_output)?;
        let output = self
            .output
            .forward_t(&intermediate_output, &post_attention_output, train)?;

        Ok(LayerOutput::EncoderWithAttention(HiddenLayer {
            output,
            attention,
        }))
    }
}

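/// Output layer of the feed-forward sub-block: projects the intermediate
/// representation back to the hidden size, applies dropout, adds the residual,
/// and layer-normalizes the result.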
#[derive(Debug)]
pub struct BertOutput {
    dense: Linear,
    dropout: Dropout,
    layer_norm: LayerNorm,
}

impl BertOutput {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let dense = bert_linear(
            vs / "dense",
            config,
            config.intermediate_size,
            config.hidden_size,
            "weight",
            "bias",
        )?;
        let dropout = Dropout::new(config.hidden_dropout_prob);
        let layer_norm = LayerNorm::new(
            vs / "layer_norm",
            vec![config.hidden_size],
            config.layer_norm_eps,
            true,
        );

        Ok(BertOutput {
            dense,
            dropout,
            layer_norm,
        })
    }

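    /// Apply the dense projection and dropout to `hidden_states`, add the
    /// residual `input`, and layer-normalize.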
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        input: &Tensor,
        train: bool,
    ) -> Result<Tensor, TransformerError> {
        let hidden_states = self.dense.forward(hidden_states);
        let mut hidden_states = self.dropout.forward_t(&hidden_states, train)?;
        let _ = hidden_states.f_add_(input)?;
        self.layer_norm.forward(&hidden_states)
    }
}

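/// Multi-head self-attention.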
#[derive(Debug)]
pub struct BertSelfAttention {
    all_head_size: i64,
    attention_head_size: i64,
    num_attention_heads: i64,

    dropout: Dropout,
    key: Linear,
    query: Linear,
    value: Linear,
}

impl BertSelfAttention {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let attention_head_size = config.hidden_size / config.num_attention_heads;
        let all_head_size = config.num_attention_heads * attention_head_size;

        let key = bert_linear(
            vs / "key",
            config,
            config.hidden_size,
            all_head_size,
            "weight",
            "bias",
        )?;
        let query = bert_linear(
            vs / "query",
            config,
            config.hidden_size,
            all_head_size,
            "weight",
            "bias",
        )?;
        let value = bert_linear(
            vs / "value",
            config,
            config.hidden_size,
            all_head_size,
            "weight",
            "bias",
        )?;

        Ok(BertSelfAttention {
            all_head_size,
            attention_head_size,
            num_attention_heads: config.num_attention_heads,

            dropout: Dropout::new(config.attention_probs_dropout_prob),
            key,
            query,
            value,
        })
    }

    /// Apply self-attention.
    ///
    /// Returns the contextualized representations and the (scaled, masked)
    /// attention scores.
    fn forward_t(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&LogitsMask>,
        train: bool,
    ) -> Result<(Tensor, Tensor), TransformerError> {
        let mixed_key_layer = self.key.forward(hidden_states);
        let mixed_query_layer = self.query.forward(hidden_states);
        let mixed_value_layer = self.value.forward(hidden_states);

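        // Split up the attention heads:
        // [batch, seq_len, all_head_size] -> [batch, heads, seq_len, head_size].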
        let query_layer = self.transpose_for_scores(&mixed_query_layer)?;
        let key_layer = self.transpose_for_scores(&mixed_key_layer)?;
        let value_layer = self.transpose_for_scores(&mixed_value_layer)?;

        // Get the raw attention scores.
        let mut attention_scores = query_layer.f_matmul(&key_layer.transpose(-1, -2))?;
        let _ = attention_scores.f_div_scalar_((self.attention_head_size as f64).sqrt())?;

        if let Some(mask) = attention_mask {
            let _ = attention_scores.f_add_(mask)?;
        }

        // Convert the raw attention scores into a probability distribution.
        let attention_probs = attention_scores.f_softmax(-1, Kind::Float)?;

        // Drop out entire tokens to attend to, following the original
        // transformer paper.
        let attention_probs = self.dropout.forward_t(&attention_probs, train)?;

        let context_layer = attention_probs.f_matmul(&value_layer)?;

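        // Merge the attention heads back into a single dimension:
        // [batch, heads, seq_len, head_size] -> [batch, seq_len, all_head_size].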
        let context_layer = context_layer.f_permute(&[0, 2, 1, 3])?.f_contiguous()?;
        let mut new_context_layer_shape = context_layer.size();
        new_context_layer_shape.splice(
            new_context_layer_shape.len() - 2..,
            iter::once(self.all_head_size),
        );
        let context_layer = context_layer.f_view_(&new_context_layer_shape)?;

        Ok((context_layer, attention_scores))
    }

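    /// Reshape for multi-head attention:
    /// [batch, seq_len, all_head_size] -> [batch, heads, seq_len, head_size].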
    fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor, TransformerError> {
        let mut new_x_shape = x.size();
        new_x_shape.pop();
        new_x_shape.extend(&[self.num_attention_heads, self.attention_head_size]);

        Ok(x.f_view_(&new_x_shape)?.f_permute(&[0, 2, 1, 3])?)
    }
}

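/// Output layer of the self-attention sub-block: dense projection, dropout,
/// residual connection with the attention input, and layer normalization.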
#[derive(Debug)]
pub struct BertSelfOutput {
    dense: Linear,
    dropout: Dropout,
    layer_norm: LayerNorm,
}

impl BertSelfOutput {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<BertSelfOutput, TransformerError> {
        let vs = vs.borrow();

        let dense = bert_linear(
            vs / "dense",
            config,
            config.hidden_size,
            config.hidden_size,
            "weight",
            "bias",
        )?;
        let dropout = Dropout::new(config.hidden_dropout_prob);
        let layer_norm = LayerNorm::new(
            vs / "layer_norm",
            vec![config.hidden_size],
            config.layer_norm_eps,
            true,
        );

        Ok(BertSelfOutput {
            dense,
            dropout,
            layer_norm,
        })
    }

    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        input: &Tensor,
        train: bool,
    ) -> Result<Tensor, TransformerError> {
        let hidden_states = self.dense.forward(hidden_states);
        let mut hidden_states = self.dropout.forward_t(&hidden_states, train)?;
        let _ = hidden_states.f_add_(input)?;
        self.layer_norm.forward(&hidden_states)
    }
}

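/// Construct a linear layer with BERT-style initialization: the weight is
/// drawn from a normal distribution with standard deviation
/// `config.initializer_range`, the bias is initialized to zero.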
pub(crate) fn bert_linear<'a>(
    vs: impl Borrow<PathExt<'a>>,
    config: &BertConfig,
    in_features: i64,
    out_features: i64,
    weight_name: &str,
    bias_name: &str,
) -> Result<Linear, TransformerError> {
    let vs = vs.borrow();

    Ok(Linear {
        ws: vs.var(
            weight_name,
            &[out_features, in_features],
            Init::Randn {
                mean: 0.,
                stdev: config.initializer_range,
            },
        )?,
        bs: Some(vs.var(bias_name, &[out_features], Init::Const(0.))?),
    })
}