rust_bert/models/bert/encoder.rs

// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
use crate::bert::bert_model::BertConfig;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};

/// # BERT Layer
/// Layer used in BERT encoders.
/// It is made of the following blocks:
/// - `attention`: self-attention `BertAttention` layer
/// - `cross_attention`: (optional) cross-attention `BertAttention` layer (if the model is used as a decoder)
/// - `is_decoder`: flag indicating if the model is used as a decoder
/// - `intermediate`: `BertIntermediate` intermediate layer
/// - `output`: `BertOutput` output layer
pub struct BertLayer {
    attention: BertAttention,
    is_decoder: bool,
    cross_attention: Option<BertAttention>,
    intermediate: BertIntermediate,
    output: BertOutput,
}

impl BertLayer {
    /// Build a new `BertLayer`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the BERT model
    /// * `config` - `BertConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::bert::{BertConfig, BertLayer};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = BertConfig::from_file(config_path);
    /// let layer: BertLayer = BertLayer::new(&p.root(), &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertLayer
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let attention = BertAttention::new(p / "attention", config);
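        // Cross-attention is only instantiated when the configuration marks the
        // model as a decoder; encoder-only models leave it as `None`.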
        let (is_decoder, cross_attention) = match config.is_decoder {
            Some(value) => {
                if value {
                    (
                        value,
                        Some(BertAttention::new(p / "cross_attention", config)),
                    )
                } else {
                    (value, None)
                }
            }
            None => (false, None),
        };

        let intermediate = BertIntermediate::new(p / "intermediate", config);
        let output = BertOutput::new(p / "output", config);

        BertLayer {
            attention,
            is_decoder,
            cross_attention,
            intermediate,
            output,
        }
    }

    /// Forward pass through the layer
    ///
    /// # Arguments
    ///
    /// * `hidden_states` - input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions value 1. If None, all positions are treated as non-masked.
    /// * `encoder_hidden_states` - Optional encoder hidden states of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
    /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and `encoder_hidden_states` is not None, used to mask encoder values. Positions with value 0 will be masked.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `BertLayerOutput` containing:
    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `attention_weights` - `Option<Tensor>` of shape (*batch size*, *num_attention_heads*, *sequence_length*, *sequence_length*)
    ///   - `cross_attention_weights` - `Option<Tensor>` of shape (*batch size*, *num_attention_heads*, *sequence_length*, *encoder_sequence_length*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::bert::{BertConfig, BertLayer};
    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// let layer: BertLayer = BertLayer::new(&vs.root(), &config);
    /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
    /// let input_tensor = Tensor::rand(
    ///     &[batch_size, sequence_length, hidden_size],
    ///     (Kind::Float, device),
    /// );
    /// let mask = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));
    ///
    /// let layer_output = no_grad(|| layer.forward_t(&input_tensor, Some(&mask), None, None, false));
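    /// let next_hidden = layer_output.hidden_state; // feed into the next layer in the stack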
    /// ```
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        mask: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_mask: Option<&Tensor>,
        train: bool,
    ) -> BertLayerOutput {
        let (attention_output, attention_weights) =
            self.attention
                .forward_t(hidden_states, mask, None, None, train);

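        // When this layer acts as a decoder and encoder states are available, run
        // cross-attention: queries come from the decoder states computed above,
        // keys and values from the encoder output.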
        let (attention_output, attention_scores, cross_attention_scores) =
            if self.is_decoder && encoder_hidden_states.is_some() {
                let (attention_output, cross_attention_weights) =
                    self.cross_attention.as_ref().unwrap().forward_t(
                        &attention_output,
                        mask,
                        encoder_hidden_states,
                        encoder_mask,
                        train,
                    );
                (attention_output, attention_weights, cross_attention_weights)
            } else {
                (attention_output, attention_weights, None)
            };

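        // Position-wise feed-forward block: the intermediate layer expands the
        // representation, the output layer projects back, adds the residual and
        // applies layer normalization.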
        let output = self.intermediate.forward(&attention_output);
        let output = self.output.forward_t(&output, &attention_output, train);

        BertLayerOutput {
            hidden_state: output,
            attention_weights: attention_scores,
            cross_attention_weights: cross_attention_scores,
        }
    }
}

/// # BERT Encoder
/// Encoder used in BERT models.
/// It is made of a vector of `BertLayer` through which hidden states will be passed. The encoder can also be
/// used as a decoder (with cross-attention) if `encoder_hidden_states` are provided.
pub struct BertEncoder {
    output_attentions: bool,
    output_hidden_states: bool,
    layers: Vec<BertLayer>,
}

impl BertEncoder {
    /// Build a new `BertEncoder`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the BERT model
    /// * `config` - `BertConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::bert::{BertConfig, BertEncoder};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = BertConfig::from_file(config_path);
    /// let encoder: BertEncoder = BertEncoder::new(&p.root(), &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertEncoder
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow() / "layer";
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);

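        // Each layer is registered under "layer/{index}" so that parameter names
        // line up with pretrained BERT checkpoints.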
        let mut layers: Vec<BertLayer> = vec![];
        for layer_index in 0..config.num_hidden_layers {
            layers.push(BertLayer::new(&p / layer_index, config));
        }

        BertEncoder {
            output_attentions,
            output_hidden_states,
            layers,
        }
    }

    /// Forward pass through the encoder
    ///
    /// # Arguments
    ///
    /// * `input` - input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions value 1. If None, all positions are treated as non-masked.
    /// * `encoder_hidden_states` - Optional encoder hidden states of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
    /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and `encoder_hidden_states` is not None, used to mask encoder values. Positions with value 0 will be masked.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `BertEncoderOutput` containing:
    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *num_attention_heads*, *sequence_length*, *sequence_length*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::bert::{BertConfig, BertEncoder};
    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// let encoder: BertEncoder = BertEncoder::new(&vs.root(), &config);
    /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
    /// let input_tensor = Tensor::rand(
    ///     &[batch_size, sequence_length, hidden_size],
    ///     (Kind::Float, device),
    /// );
    /// let mask = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));
    ///
    /// let encoder_output =
    ///     no_grad(|| encoder.forward_t(&input_tensor, Some(&mask), None, None, false));
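    /// let last_hidden_state = encoder_output.hidden_state;
    /// let all_hidden_states = encoder_output.all_hidden_states; // None unless enabled in the config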
    /// ```
    pub fn forward_t(
        &self,
        input: &Tensor,
        mask: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_mask: Option<&Tensor>,
        train: bool,
    ) -> BertEncoderOutput {
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };

        let mut hidden_state = None::<Tensor>;
        let mut attention_weights: Option<Tensor>;

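        // Thread the hidden state through the stack: the first layer consumes the
        // raw input tensor, every subsequent layer consumes the previous output.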
        for layer in &self.layers {
            let layer_output = if let Some(hidden_state) = &hidden_state {
                layer.forward_t(
                    hidden_state,
                    mask,
                    encoder_hidden_states,
                    encoder_mask,
                    train,
                )
            } else {
                layer.forward_t(input, mask, encoder_hidden_states, encoder_mask, train)
            };

            hidden_state = Some(layer_output.hidden_state);
            attention_weights = layer_output.attention_weights;
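            // Collect per-layer attention weights and hidden states when requested.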
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(attention_weights.take().unwrap());
            };
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(hidden_state.as_ref().unwrap().copy());
            };
        }

        BertEncoderOutput {
            hidden_state: hidden_state.unwrap(),
            all_hidden_states,
            all_attentions,
        }
    }
}

/// # BERT Pooler
/// Pooler used in BERT models.
/// It is made of a fully connected layer which is applied to the first sequence element.
pub struct BertPooler {
    lin: nn::Linear,
}

impl BertPooler {
    /// Build a new `BertPooler`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the BERT model
    /// * `config` - `BertConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::bert::{BertConfig, BertPooler};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = BertConfig::from_file(config_path);
    /// let pooler: BertPooler = BertPooler::new(&p.root(), &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertPooler
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let lin = nn::linear(
            p / "dense",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );
        BertPooler { lin }
    }

    /// Forward pass through the pooler
    ///
    /// # Arguments
    ///
    /// * `hidden_states` - input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
    ///
    /// # Returns
    ///
    /// * `Tensor` of shape (*batch size*, *hidden_size*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::bert::{BertConfig, BertPooler};
    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// let pooler: BertPooler = BertPooler::new(&vs.root(), &config);
    /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
    /// let input_tensor = Tensor::rand(
    ///     &[batch_size, sequence_length, hidden_size],
    ///     (Kind::Float, device),
    /// );
    ///
    /// let pooler_output = no_grad(|| pooler.forward(&input_tensor));
    /// ```
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
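        // Take the hidden state of the first token (the [CLS] position) and apply
        // a dense layer followed by a tanh activation.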
        hidden_states.select(1, 0).apply(&self.lin).tanh()
    }
}

/// Container for the BERT layer output.
pub struct BertLayerOutput {
    /// Hidden states
    pub hidden_state: Tensor,
    /// Self-attention scores
    pub attention_weights: Option<Tensor>,
    /// Cross-attention scores
    pub cross_attention_weights: Option<Tensor>,
}

/// Container for the BERT encoder output.
pub struct BertEncoderOutput {
    /// Last hidden states from the model
    pub hidden_state: Tensor,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}