rust_bert/models/bert/encoder.rs
1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3// Copyright 2019 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
15use crate::bert::bert_model::BertConfig;
16use std::borrow::{Borrow, BorrowMut};
17use tch::{nn, Tensor};
18
19/// # BERT Layer
20/// Layer used in BERT encoders.
21/// It is made of the following blocks:
22/// - `attention`: self-attention `BertAttention` layer
23/// - `cross_attention`: (optional) cross-attention `BertAttention` layer (if the model is used as a decoder)
24/// - `is_decoder`: flag indicating if the model is used as a decoder
25/// - `intermediate`: `BertIntermediate` intermediate layer
26/// - `output`: `BertOutput` output layer
pub struct BertLayer {
    /// Self-attention block applied to the layer input
    attention: BertAttention,
    /// True when the layer is part of a decoder (enables the cross-attention path)
    is_decoder: bool,
    /// Cross-attention block; only created when `is_decoder` is true
    cross_attention: Option<BertAttention>,
    /// Intermediate (feed-forward expansion) block
    intermediate: BertIntermediate,
    /// Output projection block (combined with `attention_output` in `forward_t`)
    output: BertOutput,
}
34
35impl BertLayer {
36 /// Build a new `BertLayer`
37 ///
38 /// # Arguments
39 ///
40 /// * `p` - Variable store path for the root of the BERT model
41 /// * `config` - `BertConfig` object defining the model architecture
42 ///
43 /// # Example
44 ///
45 /// ```no_run
46 /// use rust_bert::bert::{BertConfig, BertLayer};
47 /// use rust_bert::Config;
48 /// use std::path::Path;
49 /// use tch::{nn, Device};
50 ///
51 /// let config_path = Path::new("path/to/config.json");
52 /// let device = Device::Cpu;
53 /// let p = nn::VarStore::new(device);
54 /// let config = BertConfig::from_file(config_path);
55 /// let layer: BertLayer = BertLayer::new(&p.root(), &config);
56 /// ```
57 pub fn new<'p, P>(p: P, config: &BertConfig) -> BertLayer
58 where
59 P: Borrow<nn::Path<'p>>,
60 {
61 let p = p.borrow();
62
63 let attention = BertAttention::new(p / "attention", config);
64 let (is_decoder, cross_attention) = match config.is_decoder {
65 Some(value) => {
66 if value {
67 (
68 value,
69 Some(BertAttention::new(p / "cross_attention", config)),
70 )
71 } else {
72 (value, None)
73 }
74 }
75 None => (false, None),
76 };
77
78 let intermediate = BertIntermediate::new(p / "intermediate", config);
79 let output = BertOutput::new(p / "output", config);
80
81 BertLayer {
82 attention,
83 is_decoder,
84 cross_attention,
85 intermediate,
86 output,
87 }
88 }
89
90 /// Forward pass through the layer
91 ///
92 /// # Arguments
93 ///
94 /// * `hidden_states` - input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
95 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
96 /// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
97 /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used to mask encoder values. Positions with value 0 will be masked.
98 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
99 ///
100 /// # Returns
101 ///
102 /// * `BertLayerOutput` containing:
103 /// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
104 /// - `attention_scores` - `Option<Tensor>` of shape (*batch size*, *sequence_length*, *hidden_size*)
105 /// - `cross_attention_scores` - `Option<Tensor>` of shape (*batch size*, *sequence_length*, *hidden_size*)
106 ///
107 /// # Example
108 ///
109 /// ```no_run
110 /// # use rust_bert::bert::{BertConfig, BertLayer};
111 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
112 /// # use rust_bert::Config;
113 /// # use std::path::Path;
114 /// # let config_path = Path::new("path/to/config.json");
115 /// # let device = Device::Cpu;
116 /// # let vs = nn::VarStore::new(device);
117 /// # let config = BertConfig::from_file(config_path);
118 /// let layer: BertLayer = BertLayer::new(&vs.root(), &config);
119 /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
120 /// let input_tensor = Tensor::rand(
121 /// &[batch_size, sequence_length, hidden_size],
122 /// (Kind::Float, device),
123 /// );
124 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
125 ///
126 /// let layer_output = no_grad(|| layer.forward_t(&input_tensor, Some(&mask), None, None, false));
127 /// ```
128 pub fn forward_t(
129 &self,
130 hidden_states: &Tensor,
131 mask: Option<&Tensor>,
132 encoder_hidden_states: Option<&Tensor>,
133 encoder_mask: Option<&Tensor>,
134 train: bool,
135 ) -> BertLayerOutput {
136 let (attention_output, attention_weights) =
137 self.attention
138 .forward_t(hidden_states, mask, None, None, train);
139
140 let (attention_output, attention_scores, cross_attention_scores) =
141 if self.is_decoder & encoder_hidden_states.is_some() {
142 let (attention_output, cross_attention_weights) =
143 self.cross_attention.as_ref().unwrap().forward_t(
144 &attention_output,
145 mask,
146 encoder_hidden_states,
147 encoder_mask,
148 train,
149 );
150 (attention_output, attention_weights, cross_attention_weights)
151 } else {
152 (attention_output, attention_weights, None)
153 };
154
155 let output = self.intermediate.forward(&attention_output);
156 let output = self.output.forward_t(&output, &attention_output, train);
157
158 BertLayerOutput {
159 hidden_state: output,
160 attention_weights: attention_scores,
161 cross_attention_weights: cross_attention_scores,
162 }
163 }
164}
165
166/// # BERT Encoder
167/// Encoder used in BERT models.
168/// It is made of a Vector of `BertLayer` through which hidden states will be passed. The encoder can also be
169/// used as a decoder (with cross-attention) if `encoder_hidden_states` are provided.
pub struct BertEncoder {
    /// If true, collect the attention weights of every layer during the forward pass
    output_attentions: bool,
    /// If true, collect the hidden states produced by every layer during the forward pass
    output_hidden_states: bool,
    /// Stack of `BertLayer`s applied in sequence (`num_hidden_layers` of them)
    layers: Vec<BertLayer>,
}
175
176impl BertEncoder {
177 /// Build a new `BertEncoder`
178 ///
179 /// # Arguments
180 ///
181 /// * `p` - Variable store path for the root of the BERT model
182 /// * `config` - `BertConfig` object defining the model architecture
183 ///
184 /// # Example
185 ///
186 /// ```no_run
187 /// use rust_bert::bert::{BertConfig, BertEncoder};
188 /// use rust_bert::Config;
189 /// use std::path::Path;
190 /// use tch::{nn, Device};
191 ///
192 /// let config_path = Path::new("path/to/config.json");
193 /// let device = Device::Cpu;
194 /// let p = nn::VarStore::new(device);
195 /// let config = BertConfig::from_file(config_path);
196 /// let encoder: BertEncoder = BertEncoder::new(&p.root(), &config);
197 /// ```
198 pub fn new<'p, P>(p: P, config: &BertConfig) -> BertEncoder
199 where
200 P: Borrow<nn::Path<'p>>,
201 {
202 let p = p.borrow() / "layer";
203 let output_attentions = config.output_attentions.unwrap_or(false);
204 let output_hidden_states = config.output_hidden_states.unwrap_or(false);
205
206 let mut layers: Vec<BertLayer> = vec![];
207 for layer_index in 0..config.num_hidden_layers {
208 layers.push(BertLayer::new(&p / layer_index, config));
209 }
210
211 BertEncoder {
212 output_attentions,
213 output_hidden_states,
214 layers,
215 }
216 }
217
218 /// Forward pass through the encoder
219 ///
220 /// # Arguments
221 ///
222 /// * `hidden_states` - input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
223 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
224 /// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
225 /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used to mask encoder values. Positions with value 0 will be masked.
226 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
227 ///
228 /// # Returns
229 ///
230 /// * `BertEncoderOutput` containing:
231 /// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
232 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
233 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
234 ///
235 /// # Example
236 ///
237 /// ```no_run
238 /// # use rust_bert::bert::{BertConfig, BertEncoder};
239 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
240 /// # use rust_bert::Config;
241 /// # use std::path::Path;
242 /// # let config_path = Path::new("path/to/config.json");
243 /// # let device = Device::Cpu;
244 /// # let vs = nn::VarStore::new(device);
245 /// # let config = BertConfig::from_file(config_path);
246 /// let encoder: BertEncoder = BertEncoder::new(&vs.root(), &config);
247 /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
248 /// let input_tensor = Tensor::rand(
249 /// &[batch_size, sequence_length, hidden_size],
250 /// (Kind::Float, device),
251 /// );
252 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int8, device));
253 ///
254 /// let encoder_output =
255 /// no_grad(|| encoder.forward_t(&input_tensor, Some(&mask), None, None, false));
256 /// ```
257 pub fn forward_t(
258 &self,
259 input: &Tensor,
260 mask: Option<&Tensor>,
261 encoder_hidden_states: Option<&Tensor>,
262 encoder_mask: Option<&Tensor>,
263 train: bool,
264 ) -> BertEncoderOutput {
265 let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
266 Some(vec![])
267 } else {
268 None
269 };
270 let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
271 Some(vec![])
272 } else {
273 None
274 };
275
276 let mut hidden_state = None::<Tensor>;
277 let mut attention_weights: Option<Tensor>;
278
279 for layer in &self.layers {
280 let layer_output = if let Some(hidden_state) = &hidden_state {
281 layer.forward_t(
282 hidden_state,
283 mask,
284 encoder_hidden_states,
285 encoder_mask,
286 train,
287 )
288 } else {
289 layer.forward_t(input, mask, encoder_hidden_states, encoder_mask, train)
290 };
291
292 hidden_state = Some(layer_output.hidden_state);
293 attention_weights = layer_output.attention_weights;
294 if let Some(attentions) = all_attentions.borrow_mut() {
295 attentions.push(std::mem::take(&mut attention_weights.unwrap()));
296 };
297 if let Some(hidden_states) = all_hidden_states.borrow_mut() {
298 hidden_states.push(hidden_state.as_ref().unwrap().copy());
299 };
300 }
301
302 BertEncoderOutput {
303 hidden_state: hidden_state.unwrap(),
304 all_hidden_states,
305 all_attentions,
306 }
307 }
308}
309
310/// # BERT Pooler
311/// Pooler used in BERT models.
312/// It is made of a fully connected layer which is applied to the first sequence element.
pub struct BertPooler {
    /// Dense projection applied to the first sequence token ([CLS]-style pooling)
    lin: nn::Linear,
}
316
317impl BertPooler {
318 /// Build a new `BertPooler`
319 ///
320 /// # Arguments
321 ///
322 /// * `p` - Variable store path for the root of the BERT model
323 /// * `config` - `BertConfig` object defining the model architecture
324 ///
325 /// # Example
326 ///
327 /// ```no_run
328 /// use rust_bert::bert::{BertConfig, BertPooler};
329 /// use rust_bert::Config;
330 /// use std::path::Path;
331 /// use tch::{nn, Device};
332 ///
333 /// let config_path = Path::new("path/to/config.json");
334 /// let device = Device::Cpu;
335 /// let p = nn::VarStore::new(device);
336 /// let config = BertConfig::from_file(config_path);
337 /// let pooler: BertPooler = BertPooler::new(&p.root(), &config);
338 /// ```
339 pub fn new<'p, P>(p: P, config: &BertConfig) -> BertPooler
340 where
341 P: Borrow<nn::Path<'p>>,
342 {
343 let p = p.borrow();
344
345 let lin = nn::linear(
346 p / "dense",
347 config.hidden_size,
348 config.hidden_size,
349 Default::default(),
350 );
351 BertPooler { lin }
352 }
353
354 /// Forward pass through the pooler
355 ///
356 /// # Arguments
357 ///
358 /// * `hidden_states` - input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
359 ///
360 /// # Returns
361 ///
362 /// * `Tensor` of shape (*batch size*, *hidden_size*)
363 ///
364 /// # Example
365 ///
366 /// ```no_run
367 /// # use rust_bert::bert::{BertConfig, BertPooler};
368 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
369 /// # use rust_bert::Config;
370 /// # use std::path::Path;
371 /// # let config_path = Path::new("path/to/config.json");
372 /// # let device = Device::Cpu;
373 /// # let vs = nn::VarStore::new(device);
374 /// # let config = BertConfig::from_file(config_path);
375 /// let pooler: BertPooler = BertPooler::new(&vs.root(), &config);
376 /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
377 /// let input_tensor = Tensor::rand(
378 /// &[batch_size, sequence_length, hidden_size],
379 /// (Kind::Float, device),
380 /// );
381 ///
382 /// let pooler_output = no_grad(|| pooler.forward(&input_tensor));
383 /// ```
384 pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
385 hidden_states.select(1, 0).apply(&self.lin).tanh()
386 }
387}
388
/// Container for the BERT layer output.
pub struct BertLayerOutput {
    /// Hidden states produced by the layer, shape (*batch size*, *sequence_length*, *hidden_size*)
    pub hidden_state: Tensor,
    /// Self-attention scores (presence depends on the attention configuration)
    pub attention_weights: Option<Tensor>,
    /// Cross-attention scores; only populated for decoder layers given encoder hidden states
    pub cross_attention_weights: Option<Tensor>,
}
398
/// Container for the BERT encoder output.
pub struct BertEncoderOutput {
    /// Last hidden states from the model (output of the final layer)
    pub hidden_state: Tensor,
    /// Hidden states for every layer; `Some` only when `output_hidden_states` is enabled
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for every layer; `Some` only when `output_attentions` is enabled
    pub all_attentions: Option<Vec<Tensor>>,
}