rust_bert/models/bert/bert_model.rs
1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3// Copyright 2019 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::bert::encoder::{BertEncoder, BertPooler};
15use crate::common::activations::Activation;
16use crate::common::dropout::Dropout;
17use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
18use crate::common::linear::{linear_no_bias, LinearNoBias};
19use crate::{
20 bert::embeddings::{BertEmbedding, BertEmbeddings},
21 common::activations::TensorFunction,
22};
23use crate::{Config, RustBertError};
24use serde::{Deserialize, Serialize};
25use std::borrow::Borrow;
26use std::collections::HashMap;
27use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
28use tch::{nn, Kind, Tensor};
29
/// # BERT pretrained model weight files
///
/// Zero-sized namespace whose associated constants are `(name, URL)` string pairs
/// pointing at pretrained weight files.
pub struct BertModelResources;

/// # BERT pretrained model configuration files
///
/// Zero-sized namespace whose associated constants are `(name, URL)` string pairs
/// pointing at model configuration files.
pub struct BertConfigResources;

/// # BERT pretrained model vocabulary files
///
/// Zero-sized namespace whose associated constants are `(name, URL)` string pairs
/// pointing at vocabulary files.
pub struct BertVocabResources;
38
39impl BertModelResources {
40 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
41 pub const BERT: (&'static str, &'static str) = (
42 "bert/model",
43 "https://huggingface.co/bert-base-uncased/resolve/main/rust_model.ot",
44 );
45 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
46 pub const BERT_LARGE: (&'static str, &'static str) = (
47 "bert-large/model",
48 "https://huggingface.co/bert-large-uncased/resolve/main/rust_model.ot",
49 );
50 /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
51 pub const BERT_NER: (&'static str, &'static str) = (
52 "bert-ner/model",
53 "https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/rust_model.ot",
54 );
55 /// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
56 pub const BERT_QA: (&'static str, &'static str) = (
57 "bert-qa/model",
58 "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/rust_model.ot",
59 );
60 /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
61 pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
62 "bert-base-nli-mean-tokens/model",
63 "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/rust_model.ot",
64 );
65 /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
66 pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
67 "all-mini-lm-l12-v2/model",
68 "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/rust_model.ot",
69 );
70 /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
71 pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
72 "all-mini-lm-l6-v2/model",
73 "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/rust_model.ot",
74 );
75}
76
77impl BertConfigResources {
78 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
79 pub const BERT: (&'static str, &'static str) = (
80 "bert/config",
81 "https://huggingface.co/bert-base-uncased/resolve/main/config.json",
82 );
83 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
84 pub const BERT_LARGE: (&'static str, &'static str) = (
85 "bert-large/config",
86 "https://huggingface.co/bert-large-uncased/resolve/main/config.json",
87 );
88 /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
89 pub const BERT_NER: (&'static str, &'static str) = (
90 "bert-ner/config",
91 "https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/config.json",
92 );
93 /// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
94 pub const BERT_QA: (&'static str, &'static str) = (
95 "bert-qa/config",
96 "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
97 );
98 /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
99 pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
100 "bert-base-nli-mean-tokens/config",
101 "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/config.json",
102 );
103 /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
104 pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
105 "all-mini-lm-l12-v2/config",
106 "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/config.json",
107 );
108 /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
109 pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
110 "all-mini-lm-l6-v2/config",
111 "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/config.json",
112 );
113}
114
115impl BertVocabResources {
116 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
117 pub const BERT: (&'static str, &'static str) = (
118 "bert/vocab",
119 "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
120 );
121 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
122 pub const BERT_LARGE: (&'static str, &'static str) = (
123 "bert-large/vocab",
124 "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
125 );
126 /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
127 pub const BERT_NER: (&'static str, &'static str) = (
128 "bert-ner/vocab",
129 "https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/vocab.txt",
130 );
131 /// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
132 pub const BERT_QA: (&'static str, &'static str) = (
133 "bert-qa/vocab",
134 "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
135 );
136 /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
137 pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
138 "bert-base-nli-mean-tokens/vocab",
139 "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/vocab.txt",
140 );
141 /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
142 pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
143 "all-mini-lm-l12-v2/vocab",
144 "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/vocab.txt",
145 );
146 /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
147 pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
148 "all-mini-lm-l6-v2/vocab",
149 "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/vocab.txt",
150 );
151}
152
153#[derive(Debug, Serialize, Deserialize, Clone)]
154/// # BERT model configuration
155/// Defines the BERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
156pub struct BertConfig {
157 pub hidden_act: Activation,
158 pub attention_probs_dropout_prob: f64,
159 pub hidden_dropout_prob: f64,
160 pub hidden_size: i64,
161 pub initializer_range: f32,
162 pub intermediate_size: i64,
163 pub max_position_embeddings: i64,
164 pub num_attention_heads: i64,
165 pub num_hidden_layers: i64,
166 pub type_vocab_size: i64,
167 pub vocab_size: i64,
168 pub output_attentions: Option<bool>,
169 pub output_hidden_states: Option<bool>,
170 pub is_decoder: Option<bool>,
171 pub id2label: Option<HashMap<i64, String>>,
172 pub label2id: Option<HashMap<String, i64>>,
173}
174
175impl Config for BertConfig {}
176
177impl Default for BertConfig {
178 fn default() -> Self {
179 BertConfig {
180 hidden_act: Activation::gelu,
181 attention_probs_dropout_prob: 0.1,
182 hidden_dropout_prob: 0.1,
183 hidden_size: 768,
184 initializer_range: 0.02,
185 intermediate_size: 3072,
186 max_position_embeddings: 512,
187 num_attention_heads: 12,
188 num_hidden_layers: 12,
189 type_vocab_size: 2,
190 vocab_size: 30522,
191 output_attentions: None,
192 output_hidden_states: None,
193 is_decoder: None,
194 id2label: None,
195 label2id: None,
196 }
197 }
198}
199
200/// # BERT Base model
201/// Base architecture for BERT models. Task-specific models will be built from this common base model
202/// It is made of the following blocks:
203/// - `embeddings`: `token`, `position` and `segment_id` embeddings
204/// - `encoder`: Encoder (transformer) made of a vector of layers. Each layer is made of a self-attention layer, an intermediate (linear) and output (linear + layer norm) layers
205/// - `pooler`: linear layer applied to the first element of the sequence (*MASK* token)
206/// - `is_decoder`: Flag indicating if the model is used as a decoder. If set to true, a causal mask will be applied to hide future positions that should not be attended to.
207pub struct BertModel<T: BertEmbedding> {
208 embeddings: T,
209 encoder: BertEncoder,
210 pooler: Option<BertPooler>,
211 is_decoder: bool,
212}
213
214/// Defines the implementation of the BertModel. The BERT model shares many similarities with RoBERTa, main difference being the embeddings.
215/// Therefore the forward pass of the model is shared and the type of embedding used is abstracted away. This allows to create
216/// `BertModel<RobertaEmbeddings>` or `BertModel<BertEmbeddings>` for each model type.
217impl<T: BertEmbedding> BertModel<T> {
218 /// Build a new `BertModel`
219 ///
220 /// # Arguments
221 ///
222 /// * `p` - Variable store path for the root of the BERT model
223 /// * `config` - `BertConfig` object defining the model architecture and decoder status
224 ///
225 /// # Example
226 ///
227 /// ```no_run
228 /// use rust_bert::bert::{BertConfig, BertEmbeddings, BertModel};
229 /// use rust_bert::Config;
230 /// use std::path::Path;
231 /// use tch::{nn, Device};
232 ///
233 /// let config_path = Path::new("path/to/config.json");
234 /// let device = Device::Cpu;
235 /// let p = nn::VarStore::new(device);
236 /// let config = BertConfig::from_file(config_path);
237 /// let bert: BertModel<BertEmbeddings> = BertModel::new(&p.root() / "bert", &config);
238 /// ```
239 pub fn new<'p, P>(p: P, config: &BertConfig) -> BertModel<T>
240 where
241 P: Borrow<nn::Path<'p>>,
242 {
243 let p = p.borrow();
244
245 let is_decoder = config.is_decoder.unwrap_or(false);
246 let embeddings = T::new(p / "embeddings", config);
247 let encoder = BertEncoder::new(p / "encoder", config);
248 let pooler = Some(BertPooler::new(p / "pooler", config));
249
250 BertModel {
251 embeddings,
252 encoder,
253 pooler,
254 is_decoder,
255 }
256 }
257
258 /// Build a new `BertModel` with an optional Pooling layer
259 ///
260 /// # Arguments
261 ///
262 /// * `p` - Variable store path for the root of the BERT model
263 /// * `config` - `BertConfig` object defining the model architecture and decoder status
264 /// * `add_pooling_layer` - Enable/Disable an optional pooling layer at the end of the model
265 ///
266 /// # Example
267 ///
268 /// ```no_run
269 /// use rust_bert::bert::{BertConfig, BertEmbeddings, BertModel};
270 /// use rust_bert::Config;
271 /// use std::path::Path;
272 /// use tch::{nn, Device};
273 ///
274 /// let config_path = Path::new("path/to/config.json");
275 /// let device = Device::Cpu;
276 /// let p = nn::VarStore::new(device);
277 /// let config = BertConfig::from_file(config_path);
278 /// let bert: BertModel<BertEmbeddings> =
279 /// BertModel::new_with_optional_pooler(&p.root() / "bert", &config, false);
280 /// ```
281 pub fn new_with_optional_pooler<'p, P>(
282 p: P,
283 config: &BertConfig,
284 add_pooling_layer: bool,
285 ) -> BertModel<T>
286 where
287 P: Borrow<nn::Path<'p>>,
288 {
289 let p = p.borrow();
290
291 let is_decoder = config.is_decoder.unwrap_or(false);
292 let embeddings = T::new(p / "embeddings", config);
293 let encoder = BertEncoder::new(p / "encoder", config);
294
295 let pooler = {
296 if add_pooling_layer {
297 Some(BertPooler::new(p / "pooler", config))
298 } else {
299 None
300 }
301 };
302
303 BertModel {
304 embeddings,
305 encoder,
306 pooler,
307 is_decoder,
308 }
309 }
310
311 /// Forward pass through the model
312 ///
313 /// # Arguments
314 ///
315 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
316 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
317 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
318 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
319 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
320 /// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
321 /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used to mask encoder values. Positions with value 0 will be masked.
322 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
323 ///
324 /// # Returns
325 ///
326 /// * `BertOutput` containing:
327 /// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
328 /// - `pooled_output` - `Tensor` of shape (*batch size*, *hidden_size*)
329 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
330 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
331 ///
332 /// # Example
333 ///
334 /// ```no_run
335 /// # use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings};
336 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
337 /// # use rust_bert::Config;
338 /// # use std::path::Path;
339 /// # let config_path = Path::new("path/to/config.json");
340 /// # let device = Device::Cpu;
341 /// # let vs = nn::VarStore::new(device);
342 /// # let config = BertConfig::from_file(config_path);
343 /// # let bert_model: BertModel<BertEmbeddings> = BertModel::new(&vs.root(), &config);
344 /// let (batch_size, sequence_length) = (64, 128);
345 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
346 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
347 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
348 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
349 /// .expand(&[batch_size, sequence_length], true);
350 ///
351 /// let model_output = no_grad(|| {
352 /// bert_model
353 /// .forward_t(
354 /// Some(&input_tensor),
355 /// Some(&mask),
356 /// Some(&token_type_ids),
357 /// Some(&position_ids),
358 /// None,
359 /// None,
360 /// None,
361 /// false,
362 /// )
363 /// .unwrap()
364 /// });
365 /// ```
366 pub fn forward_t(
367 &self,
368 input_ids: Option<&Tensor>,
369 mask: Option<&Tensor>,
370 token_type_ids: Option<&Tensor>,
371 position_ids: Option<&Tensor>,
372 input_embeds: Option<&Tensor>,
373 encoder_hidden_states: Option<&Tensor>,
374 encoder_mask: Option<&Tensor>,
375 train: bool,
376 ) -> Result<BertModelOutput, RustBertError> {
377 let (input_shape, device) =
378 get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
379
380 let calc_mask = Tensor::ones(&input_shape, (Kind::Int8, device));
381 let mask = mask.unwrap_or(&calc_mask);
382
383 let extended_attention_mask = match mask.dim() {
384 3 => mask.unsqueeze(1),
385 2 => {
386 if self.is_decoder {
387 let seq_ids = Tensor::arange(input_shape[1], (Kind::Int8, device));
388 let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat([
389 input_shape[0],
390 input_shape[1],
391 1,
392 ]);
393 let causal_mask = causal_mask.le_tensor(&seq_ids.unsqueeze(0).unsqueeze(-1));
394 causal_mask * mask.unsqueeze(1).unsqueeze(1)
395 } else {
396 mask.unsqueeze(1).unsqueeze(1)
397 }
398 }
399 _ => {
400 return Err(RustBertError::ValueError(
401 "Invalid attention mask dimension, must be 2 or 3".into(),
402 ));
403 }
404 };
405
406 let embedding_output = self.embeddings.forward_t(
407 input_ids,
408 token_type_ids,
409 position_ids,
410 input_embeds,
411 train,
412 )?;
413
414 let extended_attention_mask: Tensor = ((extended_attention_mask
415 .ones_like()
416 .bitwise_xor_tensor(&extended_attention_mask))
417 * -10000.0)
418 .to_kind(embedding_output.kind());
419
420 let encoder_extended_attention_mask: Option<Tensor> =
421 if self.is_decoder & encoder_hidden_states.is_some() {
422 let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
423 let encoder_hidden_states_shape = encoder_hidden_states.size();
424 let encoder_mask = match encoder_mask {
425 Some(value) => value.copy(),
426 None => Tensor::ones(
427 [
428 encoder_hidden_states_shape[0],
429 encoder_hidden_states_shape[1],
430 ],
431 (Kind::Int8, device),
432 ),
433 };
434 match encoder_mask.dim() {
435 2 => Some(encoder_mask.unsqueeze(1).unsqueeze(1)),
436 3 => Some(encoder_mask.unsqueeze(1)),
437 _ => {
438 return Err(RustBertError::ValueError(
439 "Invalid attention mask dimension, must be 2 or 3".into(),
440 ));
441 }
442 }
443 } else {
444 None
445 };
446
447 let encoder_output = self.encoder.forward_t(
448 &embedding_output,
449 Some(&extended_attention_mask),
450 encoder_hidden_states,
451 encoder_extended_attention_mask.as_ref(),
452 train,
453 );
454
455 let pooled_output = self
456 .pooler
457 .as_ref()
458 .map(|pooler| pooler.forward(&encoder_output.hidden_state));
459
460 Ok(BertModelOutput {
461 hidden_state: encoder_output.hidden_state,
462 pooled_output,
463 all_hidden_states: encoder_output.all_hidden_states,
464 all_attentions: encoder_output.all_attentions,
465 })
466 }
467}
468
469pub struct BertPredictionHeadTransform {
470 dense: nn::Linear,
471 activation: TensorFunction,
472 layer_norm: nn::LayerNorm,
473}
474
475impl BertPredictionHeadTransform {
476 pub fn new<'p, P>(p: P, config: &BertConfig) -> BertPredictionHeadTransform
477 where
478 P: Borrow<nn::Path<'p>>,
479 {
480 let p = p.borrow();
481
482 let dense = nn::linear(
483 p / "dense",
484 config.hidden_size,
485 config.hidden_size,
486 Default::default(),
487 );
488 let activation = config.hidden_act.get_function();
489 let layer_norm_config = nn::LayerNormConfig {
490 eps: 1e-12,
491 ..Default::default()
492 };
493 let layer_norm =
494 nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
495
496 BertPredictionHeadTransform {
497 dense,
498 activation,
499 layer_norm,
500 }
501 }
502
503 pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
504 self.activation.get_fn()(&hidden_states.apply(&self.dense)).apply(&self.layer_norm)
505 }
506}
507
508pub struct BertLMPredictionHead {
509 transform: BertPredictionHeadTransform,
510 decoder: LinearNoBias,
511 bias: Tensor,
512}
513
514impl BertLMPredictionHead {
515 pub fn new<'p, P>(p: P, config: &BertConfig) -> BertLMPredictionHead
516 where
517 P: Borrow<nn::Path<'p>>,
518 {
519 let p = p.borrow() / "predictions";
520 let transform = BertPredictionHeadTransform::new(&p / "transform", config);
521 let decoder = linear_no_bias(
522 &p / "decoder",
523 config.hidden_size,
524 config.vocab_size,
525 Default::default(),
526 );
527 let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
528
529 BertLMPredictionHead {
530 transform,
531 decoder,
532 bias,
533 }
534 }
535
536 pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
537 self.transform.forward(hidden_states).apply(&self.decoder) + &self.bias
538 }
539}
540
541/// # BERT for masked language model
542/// Base BERT model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
543/// It is made of the following blocks:
544/// - `bert`: Base BertModel
545/// - `cls`: BERT LM prediction head
546pub struct BertForMaskedLM {
547 bert: BertModel<BertEmbeddings>,
548 cls: BertLMPredictionHead,
549}
550
551impl BertForMaskedLM {
552 /// Build a new `BertForMaskedLM`
553 ///
554 /// # Arguments
555 ///
556 /// * `p` - Variable store path for the root of the BertForMaskedLM model
557 /// * `config` - `BertConfig` object defining the model architecture and vocab size
558 ///
559 /// # Example
560 ///
561 /// ```no_run
562 /// use rust_bert::bert::{BertConfig, BertForMaskedLM};
563 /// use rust_bert::Config;
564 /// use std::path::Path;
565 /// use tch::{nn, Device};
566 ///
567 /// let config_path = Path::new("path/to/config.json");
568 /// let device = Device::Cpu;
569 /// let p = nn::VarStore::new(device);
570 /// let config = BertConfig::from_file(config_path);
571 /// let bert = BertForMaskedLM::new(&p.root() / "bert", &config);
572 /// ```
573 pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMaskedLM
574 where
575 P: Borrow<nn::Path<'p>>,
576 {
577 let p = p.borrow();
578
579 let bert = BertModel::new_with_optional_pooler(p / "bert", config, false);
580 let cls = BertLMPredictionHead::new(p / "cls", config);
581
582 BertForMaskedLM { bert, cls }
583 }
584
585 /// Forward pass through the model
586 ///
587 /// # Arguments
588 ///
589 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
590 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
592 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
593 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
594 /// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used in the cross-attention layer as keys and values (query from the decoder).
595 /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used to mask encoder values. Positions with value 0 will be masked.
596 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
597 ///
598 /// # Returns
599 ///
600 /// * `BertMaskedLMOutput` containing:
601 /// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
602 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
603 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
604 ///
605 /// # Example
606 ///
607 /// ```no_run
608 /// # use rust_bert::bert::{BertForMaskedLM, BertConfig};
609 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
610 /// # use rust_bert::Config;
611 /// # use std::path::Path;
612 /// # let config_path = Path::new("path/to/config.json");
613 /// # let device = Device::Cpu;
614 /// # let vs = nn::VarStore::new(device);
615 /// # let config = BertConfig::from_file(config_path);
616 /// # let bert_model = BertForMaskedLM::new(&vs.root(), &config);
617 /// let (batch_size, sequence_length) = (64, 128);
618 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
619 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
620 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
621 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
622 /// .expand(&[batch_size, sequence_length], true);
623 ///
624 /// let model_output = no_grad(|| {
625 /// bert_model.forward_t(
626 /// Some(&input_tensor),
627 /// Some(&mask),
628 /// Some(&token_type_ids),
629 /// Some(&position_ids),
630 /// None,
631 /// None,
632 /// None,
633 /// false,
634 /// )
635 /// });
636 /// ```
637 pub fn forward_t(
638 &self,
639 input_ids: Option<&Tensor>,
640 mask: Option<&Tensor>,
641 token_type_ids: Option<&Tensor>,
642 position_ids: Option<&Tensor>,
643 input_embeds: Option<&Tensor>,
644 encoder_hidden_states: Option<&Tensor>,
645 encoder_mask: Option<&Tensor>,
646 train: bool,
647 ) -> BertMaskedLMOutput {
648 let base_model_output = self
649 .bert
650 .forward_t(
651 input_ids,
652 mask,
653 token_type_ids,
654 position_ids,
655 input_embeds,
656 encoder_hidden_states,
657 encoder_mask,
658 train,
659 )
660 .unwrap();
661
662 let prediction_scores = self.cls.forward(&base_model_output.hidden_state);
663 BertMaskedLMOutput {
664 prediction_scores,
665 all_hidden_states: base_model_output.all_hidden_states,
666 all_attentions: base_model_output.all_attentions,
667 }
668 }
669}
670
671/// # BERT for sequence classification
672/// Base BERT model with a classifier head to perform sentence or document-level classification
673/// It is made of the following blocks:
674/// - `bert`: Base BertModel
675/// - `classifier`: BERT linear layer for classification
676pub struct BertForSequenceClassification {
677 bert: BertModel<BertEmbeddings>,
678 dropout: Dropout,
679 classifier: nn::Linear,
680}
681
682impl BertForSequenceClassification {
683 /// Build a new `BertForSequenceClassification`
684 ///
685 /// # Arguments
686 ///
687 /// * `p` - Variable store path for the root of the BertForSequenceClassification model
688 /// * `config` - `BertConfig` object defining the model architecture and number of classes
689 ///
690 /// # Example
691 ///
692 /// ```no_run
693 /// use rust_bert::bert::{BertConfig, BertForSequenceClassification};
694 /// use rust_bert::Config;
695 /// use std::path::Path;
696 /// use tch::{nn, Device};
697 ///
698 /// let config_path = Path::new("path/to/config.json");
699 /// let device = Device::Cpu;
700 /// let p = nn::VarStore::new(device);
701 /// let config = BertConfig::from_file(config_path);
702 /// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config).unwrap();
703 /// ```
704 pub fn new<'p, P>(
705 p: P,
706 config: &BertConfig,
707 ) -> Result<BertForSequenceClassification, RustBertError>
708 where
709 P: Borrow<nn::Path<'p>>,
710 {
711 let p = p.borrow();
712
713 let bert = BertModel::new(p / "bert", config);
714 let dropout = Dropout::new(config.hidden_dropout_prob);
715 let num_labels = config
716 .id2label
717 .as_ref()
718 .ok_or_else(|| {
719 RustBertError::InvalidConfigurationError(
720 "num_labels not provided in configuration".to_string(),
721 )
722 })?
723 .len() as i64;
724 let classifier = nn::linear(
725 p / "classifier",
726 config.hidden_size,
727 num_labels,
728 Default::default(),
729 );
730
731 Ok(BertForSequenceClassification {
732 bert,
733 dropout,
734 classifier,
735 })
736 }
737
738 /// Forward pass through the model
739 ///
740 /// # Arguments
741 ///
742 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
743 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
745 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
746 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
747 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
748 ///
749 /// # Returns
750 ///
751 /// * `BertSequenceClassificationOutput` containing:
752 /// - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
753 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
754 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
755 ///
756 /// # Example
757 ///
758 /// ```no_run
759 /// # use rust_bert::bert::{BertForSequenceClassification, BertConfig};
760 /// # use tch::{nn, Device, Tensor, no_grad, Kind};
761 /// # use rust_bert::Config;
762 /// # use std::path::Path;
763 /// # let config_path = Path::new("path/to/config.json");
764 /// # let device = Device::Cpu;
765 /// # let vs = nn::VarStore::new(device);
766 /// # let config = BertConfig::from_file(config_path);
767 /// # let bert_model = BertForSequenceClassification::new(&vs.root(), &config).unwrap();;
768 /// let (batch_size, sequence_length) = (64, 128);
769 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
770 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
771 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
772 /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
773 /// .expand(&[batch_size, sequence_length], true);
774 ///
775 /// let model_output = no_grad(|| {
776 /// bert_model.forward_t(
777 /// Some(&input_tensor),
778 /// Some(&mask),
779 /// Some(&token_type_ids),
780 /// Some(&position_ids),
781 /// None,
782 /// false,
783 /// )
784 /// });
785 /// ```
786 pub fn forward_t(
787 &self,
788 input_ids: Option<&Tensor>,
789 mask: Option<&Tensor>,
790 token_type_ids: Option<&Tensor>,
791 position_ids: Option<&Tensor>,
792 input_embeds: Option<&Tensor>,
793 train: bool,
794 ) -> BertSequenceClassificationOutput {
795 let base_model_output = self
796 .bert
797 .forward_t(
798 input_ids,
799 mask,
800 token_type_ids,
801 position_ids,
802 input_embeds,
803 None,
804 None,
805 train,
806 )
807 .unwrap();
808
809 let logits = base_model_output
810 .pooled_output
811 .unwrap()
812 .apply_t(&self.dropout, train)
813 .apply(&self.classifier);
814 BertSequenceClassificationOutput {
815 logits,
816 all_hidden_states: base_model_output.all_hidden_states,
817 all_attentions: base_model_output.all_attentions,
818 }
819 }
820}
821
822/// # BERT for multiple choices
823/// Multiple choices model using a BERT base model and a linear classifier.
824/// Input should be in the form `[CLS] Context [SEP] Possible choice [SEP]`. The choice is made along the batch axis,
825/// assuming all elements of the batch are alternatives to be chosen from for a given context.
826/// It is made of the following blocks:
827/// - `bert`: Base BertModel
828/// - `classifier`: Linear layer for multiple choices
829pub struct BertForMultipleChoice {
830 bert: BertModel<BertEmbeddings>,
831 dropout: Dropout,
832 classifier: nn::Linear,
833}
834
835impl BertForMultipleChoice {
836 /// Build a new `BertForMultipleChoice`
837 ///
838 /// # Arguments
839 ///
840 /// * `p` - Variable store path for the root of the BertForMultipleChoice model
841 /// * `config` - `BertConfig` object defining the model architecture
842 ///
843 /// # Example
844 ///
845 /// ```no_run
846 /// use rust_bert::bert::{BertConfig, BertForMultipleChoice};
847 /// use rust_bert::Config;
848 /// use std::path::Path;
849 /// use tch::{nn, Device};
850 ///
851 /// let config_path = Path::new("path/to/config.json");
852 /// let device = Device::Cpu;
853 /// let p = nn::VarStore::new(device);
854 /// let config = BertConfig::from_file(config_path);
855 /// let bert = BertForMultipleChoice::new(&p.root() / "bert", &config);
856 /// ```
857 pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMultipleChoice
858 where
859 P: Borrow<nn::Path<'p>>,
860 {
861 let p = p.borrow();
862
863 let bert = BertModel::new(p / "bert", config);
864 let dropout = Dropout::new(config.hidden_dropout_prob);
865 let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
866
867 BertForMultipleChoice {
868 bert,
869 dropout,
870 classifier,
871 }
872 }
873
874 /// Forward pass through the model
875 ///
876 /// # Arguments
877 ///
878 /// * `input_ids` - Input tensor of shape (*batch size*, *sequence_length*).
879 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
880 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
881 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
882 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
883 ///
884 /// # Returns
885 ///
886 /// * `BertSequenceClassificationOutput` containing:
887 /// - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
888 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
889 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
890 ///
891 /// # Example
892 ///
893 /// ```no_run
894 /// # use rust_bert::bert::{BertForMultipleChoice, BertConfig};
895 /// # use tch::{nn, Device, Tensor, no_grad};
896 /// # use rust_bert::Config;
897 /// # use std::path::Path;
898 /// # use tch::kind::Kind::Int64;
899 /// # let config_path = Path::new("path/to/config.json");
900 /// # let device = Device::Cpu;
901 /// # let vs = nn::VarStore::new(device);
902 /// # let config = BertConfig::from_file(config_path);
903 /// # let bert_model = BertForMultipleChoice::new(&vs.root(), &config);
904 /// let (num_choices, sequence_length) = (3, 128);
905 /// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
906 /// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
907 /// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
908 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
909 /// .expand(&[num_choices, sequence_length], true);
910 ///
911 /// let model_output = no_grad(|| {
912 /// bert_model.forward_t(
913 /// &input_tensor,
914 /// Some(&mask),
915 /// Some(&token_type_ids),
916 /// Some(&position_ids),
917 /// false,
918 /// )
919 /// });
920 /// ```
921 pub fn forward_t(
922 &self,
923 input_ids: &Tensor,
924 mask: Option<&Tensor>,
925 token_type_ids: Option<&Tensor>,
926 position_ids: Option<&Tensor>,
927 train: bool,
928 ) -> BertSequenceClassificationOutput {
929 let num_choices = input_ids.size()[1];
930
931 let input_ids = input_ids.view((-1, *input_ids.size().last().unwrap()));
932 let mask = mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
933 let token_type_ids =
934 token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
935 let position_ids =
936 position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
937
938 let base_model_output = self
939 .bert
940 .forward_t(
941 Some(&input_ids),
942 mask.as_ref(),
943 token_type_ids.as_ref(),
944 position_ids.as_ref(),
945 None,
946 None,
947 None,
948 train,
949 )
950 .unwrap();
951
952 let logits = base_model_output
953 .pooled_output
954 .unwrap()
955 .apply_t(&self.dropout, train)
956 .apply(&self.classifier)
957 .view((-1, num_choices));
958 BertSequenceClassificationOutput {
959 logits,
960 all_hidden_states: base_model_output.all_hidden_states,
961 all_attentions: base_model_output.all_attentions,
962 }
963 }
964}
965
966/// # BERT for token classification (e.g. NER, POS)
967/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
968/// not necessarily aligned with words in the sentence.
969/// It is made of the following blocks:
970/// - `bert`: Base BertModel
971/// - `classifier`: Linear layer for token classification
972pub struct BertForTokenClassification {
973 bert: BertModel<BertEmbeddings>,
974 dropout: Dropout,
975 classifier: nn::Linear,
976}
977
978impl BertForTokenClassification {
979 /// Build a new `BertForTokenClassification`
980 ///
981 /// # Arguments
982 ///
983 /// * `p` - Variable store path for the root of the BertForTokenClassification model
984 /// * `config` - `BertConfig` object defining the model architecture, number of output labels and label mapping
985 ///
986 /// # Example
987 ///
988 /// ```no_run
989 /// use rust_bert::bert::{BertConfig, BertForTokenClassification};
990 /// use rust_bert::Config;
991 /// use std::path::Path;
992 /// use tch::{nn, Device};
993 ///
994 /// let config_path = Path::new("path/to/config.json");
995 /// let device = Device::Cpu;
996 /// let p = nn::VarStore::new(device);
997 /// let config = BertConfig::from_file(config_path);
998 /// let bert = BertForTokenClassification::new(&p.root() / "bert", &config).unwrap();
999 /// ```
1000 pub fn new<'p, P>(
1001 p: P,
1002 config: &BertConfig,
1003 ) -> Result<BertForTokenClassification, RustBertError>
1004 where
1005 P: Borrow<nn::Path<'p>>,
1006 {
1007 let p = p.borrow();
1008
1009 let bert = BertModel::new_with_optional_pooler(p / "bert", config, false);
1010 let dropout = Dropout::new(config.hidden_dropout_prob);
1011 let num_labels = config
1012 .id2label
1013 .as_ref()
1014 .ok_or_else(|| {
1015 RustBertError::InvalidConfigurationError(
1016 "num_labels not provided in configuration".to_string(),
1017 )
1018 })?
1019 .len() as i64;
1020 let classifier = nn::linear(
1021 p / "classifier",
1022 config.hidden_size,
1023 num_labels,
1024 Default::default(),
1025 );
1026
1027 Ok(BertForTokenClassification {
1028 bert,
1029 dropout,
1030 classifier,
1031 })
1032 }
1033
1034 /// Forward pass through the model
1035 ///
1036 /// # Arguments
1037 ///
1038 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
1039 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
1040 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1041 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1042 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1043 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1044 ///
1045 /// # Returns
1046 ///
1047 /// * `BertTokenClassificationOutput` containing:
1048 /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
1049 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1050 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1051 ///
1052 /// # Example
1053 ///
1054 /// ```no_run
1055 /// # use rust_bert::bert::{BertForTokenClassification, BertConfig};
1056 /// # use tch::{nn, Device, Tensor, no_grad};
1057 /// # use rust_bert::Config;
1058 /// # use std::path::Path;
1059 /// # use tch::kind::Kind::Int64;
1060 /// # let config_path = Path::new("path/to/config.json");
1061 /// # let device = Device::Cpu;
1062 /// # let vs = nn::VarStore::new(device);
1063 /// # let config = BertConfig::from_file(config_path);
1064 /// # let bert_model = BertForTokenClassification::new(&vs.root(), &config).unwrap();
1065 /// let (batch_size, sequence_length) = (64, 128);
1066 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1067 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1068 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1069 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1070 /// .expand(&[batch_size, sequence_length], true);
1071 ///
1072 /// let model_output = no_grad(|| {
1073 /// bert_model.forward_t(
1074 /// Some(&input_tensor),
1075 /// Some(&mask),
1076 /// Some(&token_type_ids),
1077 /// Some(&position_ids),
1078 /// None,
1079 /// false,
1080 /// )
1081 /// });
1082 /// ```
1083 pub fn forward_t(
1084 &self,
1085 input_ids: Option<&Tensor>,
1086 mask: Option<&Tensor>,
1087 token_type_ids: Option<&Tensor>,
1088 position_ids: Option<&Tensor>,
1089 input_embeds: Option<&Tensor>,
1090 train: bool,
1091 ) -> BertTokenClassificationOutput {
1092 let base_model_output = self
1093 .bert
1094 .forward_t(
1095 input_ids,
1096 mask,
1097 token_type_ids,
1098 position_ids,
1099 input_embeds,
1100 None,
1101 None,
1102 train,
1103 )
1104 .unwrap();
1105
1106 let logits = base_model_output
1107 .hidden_state
1108 .apply_t(&self.dropout, train)
1109 .apply(&self.classifier);
1110 BertTokenClassificationOutput {
1111 logits,
1112 all_hidden_states: base_model_output.all_hidden_states,
1113 all_attentions: base_model_output.all_attentions,
1114 }
1115 }
1116}
1117
1118/// # BERT for question answering
1119/// Extractive question-answering model based on a BERT language model. Identifies the segment of a context that answers a provided question.
1120/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
1121/// See the question answering pipeline (also provided in this crate) for more details.
1122/// It is made of the following blocks:
1123/// - `bert`: Base BertModel
1124/// - `qa_outputs`: Linear layer for question answering
1125pub struct BertForQuestionAnswering {
1126 bert: BertModel<BertEmbeddings>,
1127 qa_outputs: nn::Linear,
1128}
1129
1130impl BertForQuestionAnswering {
1131 /// Build a new `BertForQuestionAnswering`
1132 ///
1133 /// # Arguments
1134 ///
1135 /// * `p` - Variable store path for the root of the BertForQuestionAnswering model
1136 /// * `config` - `BertConfig` object defining the model architecture
1137 ///
1138 /// # Example
1139 ///
1140 /// ```no_run
1141 /// use rust_bert::bert::{BertConfig, BertForQuestionAnswering};
1142 /// use rust_bert::Config;
1143 /// use std::path::Path;
1144 /// use tch::{nn, Device};
1145 ///
1146 /// let config_path = Path::new("path/to/config.json");
1147 /// let device = Device::Cpu;
1148 /// let p = nn::VarStore::new(device);
1149 /// let config = BertConfig::from_file(config_path);
1150 /// let bert = BertForQuestionAnswering::new(&p.root() / "bert", &config);
1151 /// ```
1152 pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForQuestionAnswering
1153 where
1154 P: Borrow<nn::Path<'p>>,
1155 {
1156 let p = p.borrow();
1157
1158 let bert = BertModel::new(p / "bert", config);
1159 let num_labels = 2;
1160 let qa_outputs = nn::linear(
1161 p / "qa_outputs",
1162 config.hidden_size,
1163 num_labels,
1164 Default::default(),
1165 );
1166
1167 BertForQuestionAnswering { bert, qa_outputs }
1168 }
1169
1170 /// Forward pass through the model
1171 ///
1172 /// # Arguments
1173 ///
1174 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
1175 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
1176 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1177 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1178 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1179 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1180 ///
1181 /// # Returns
1182 ///
1183 /// * `BertQuestionAnsweringOutput` containing:
1184 /// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
1185 /// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
1186 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1187 /// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1188 ///
1189 /// # Example
1190 ///
1191 /// ```no_run
1192 /// # use rust_bert::bert::{BertForQuestionAnswering, BertConfig};
1193 /// # use tch::{nn, Device, Tensor, no_grad};
1194 /// # use rust_bert::Config;
1195 /// # use std::path::Path;
1196 /// # use tch::kind::Kind::Int64;
1197 /// # let config_path = Path::new("path/to/config.json");
1198 /// # let device = Device::Cpu;
1199 /// # let vs = nn::VarStore::new(device);
1200 /// # let config = BertConfig::from_file(config_path);
1201 /// # let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);
1202 /// let (batch_size, sequence_length) = (64, 128);
1203 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1204 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1205 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1206 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1207 /// .expand(&[batch_size, sequence_length], true);
1208 ///
1209 /// let model_output = no_grad(|| {
1210 /// bert_model.forward_t(
1211 /// Some(&input_tensor),
1212 /// Some(&mask),
1213 /// Some(&token_type_ids),
1214 /// Some(&position_ids),
1215 /// None,
1216 /// false,
1217 /// )
1218 /// });
1219 /// ```
1220 pub fn forward_t(
1221 &self,
1222 input_ids: Option<&Tensor>,
1223 mask: Option<&Tensor>,
1224 token_type_ids: Option<&Tensor>,
1225 position_ids: Option<&Tensor>,
1226 input_embeds: Option<&Tensor>,
1227 train: bool,
1228 ) -> BertQuestionAnsweringOutput {
1229 let base_model_output = self
1230 .bert
1231 .forward_t(
1232 input_ids,
1233 mask,
1234 token_type_ids,
1235 position_ids,
1236 input_embeds,
1237 None,
1238 None,
1239 train,
1240 )
1241 .unwrap();
1242
1243 let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
1244 let logits = sequence_output.split(1, -1);
1245 let (start_logits, end_logits) = (&logits[0], &logits[1]);
1246 let start_logits = start_logits.squeeze_dim(-1);
1247 let end_logits = end_logits.squeeze_dim(-1);
1248
1249 BertQuestionAnsweringOutput {
1250 start_logits,
1251 end_logits,
1252 all_hidden_states: base_model_output.all_hidden_states,
1253 all_attentions: base_model_output.all_attentions,
1254 }
1255 }
1256}
1257
1258/// # BERT for sentence embeddings
1259/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
1260pub type BertForSentenceEmbeddings = BertModel<BertEmbeddings>;
1261
1262/// Container for the BERT model output.
1263pub struct BertModelOutput {
1264 /// Last hidden states from the model
1265 pub hidden_state: Tensor,
1266 /// Pooled output (hidden state for the first token)
1267 pub pooled_output: Option<Tensor>,
1268 /// Hidden states for all intermediate layers
1269 pub all_hidden_states: Option<Vec<Tensor>>,
1270 /// Attention weights for all intermediate layers
1271 pub all_attentions: Option<Vec<Tensor>>,
1272}
1273
1274/// Container for the BERT masked LM model output.
1275pub struct BertMaskedLMOutput {
1276 /// Logits for the vocabulary items at each sequence position
1277 pub prediction_scores: Tensor,
1278 /// Hidden states for all intermediate layers
1279 pub all_hidden_states: Option<Vec<Tensor>>,
1280 /// Attention weights for all intermediate layers
1281 pub all_attentions: Option<Vec<Tensor>>,
1282}
1283
1284/// Container for the BERT sequence classification model output.
1285pub struct BertSequenceClassificationOutput {
1286 /// Logits for each input (sequence) for each target class
1287 pub logits: Tensor,
1288 /// Hidden states for all intermediate layers
1289 pub all_hidden_states: Option<Vec<Tensor>>,
1290 /// Attention weights for all intermediate layers
1291 pub all_attentions: Option<Vec<Tensor>>,
1292}
1293
1294/// Container for the BERT token classification model output.
1295pub struct BertTokenClassificationOutput {
1296 /// Logits for each sequence item (token) for each target class
1297 pub logits: Tensor,
1298 /// Hidden states for all intermediate layers
1299 pub all_hidden_states: Option<Vec<Tensor>>,
1300 /// Attention weights for all intermediate layers
1301 pub all_attentions: Option<Vec<Tensor>>,
1302}
1303
1304/// Container for the BERT question answering model output.
1305pub struct BertQuestionAnsweringOutput {
1306 /// Logits for the start position for token of each input sequence
1307 pub start_logits: Tensor,
1308 /// Logits for the end position for token of each input sequence
1309 pub end_logits: Tensor,
1310 /// Hidden states for all intermediate layers
1311 pub all_hidden_states: Option<Vec<Tensor>>,
1312 /// Attention weights for all intermediate layers
1313 pub all_attentions: Option<Vec<Tensor>>,
1314}
1315
#[cfg(test)]
mod test {
    use tch::Device;

    use crate::{
        resources::{RemoteResource, ResourceProvider},
        Config,
    };

    use super::*;

    /// Compile-time guarantee that the base BERT model can be moved across threads.
    #[test]
    #[ignore] // compilation is enough, no need to run
    fn bert_model_send() {
        // Fetch the configuration of the pretrained BERT model.
        let remote_config = Box::new(RemoteResource::from_pretrained(BertConfigResources::BERT));
        let local_config_path = remote_config.get_local_path().expect("");
        let config = BertConfig::from_file(local_config_path);

        // Build the base model in a fresh variable store.
        let device = Device::cuda_if_available();
        let vs = nn::VarStore::new(device);
        let model = BertModel::<BertEmbeddings>::new(vs.root(), &config);

        // The coercion below only type-checks if the model is `Send`.
        let _: Box<dyn Send> = Box::new(model);
    }
}
1340}