rust_bert/models/bert/
bert_model.rs

1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
3// Copyright 2019 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::bert::encoder::{BertEncoder, BertPooler};
15use crate::common::activations::Activation;
16use crate::common::dropout::Dropout;
17use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
18use crate::common::linear::{linear_no_bias, LinearNoBias};
19use crate::{
20    bert::embeddings::{BertEmbedding, BertEmbeddings},
21    common::activations::TensorFunction,
22};
23use crate::{Config, RustBertError};
24use serde::{Deserialize, Serialize};
25use std::borrow::Borrow;
26use std::collections::HashMap;
27use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
28use tch::{nn, Kind, Tensor};
29
/// # BERT Pretrained model weight files
/// Marker type whose associated constants identify pretrained weight files.
/// Each constant is a pair of strings: a short local identifier followed by
/// the remote URL of the `rust_model.ot` file.
pub struct BertModelResources;

/// # BERT Pretrained model config files
/// Marker type whose associated constants identify pretrained configuration
/// files. Each constant is a pair of strings: a short local identifier
/// followed by the remote URL of the `config.json` file.
pub struct BertConfigResources;

/// # BERT Pretrained model vocab files
/// Marker type whose associated constants identify pretrained vocabulary
/// files. Each constant is a pair of strings: a short local identifier
/// followed by the remote URL of the `vocab.txt` file.
pub struct BertVocabResources;
38
39impl BertModelResources {
40    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
41    pub const BERT: (&'static str, &'static str) = (
42        "bert/model",
43        "https://huggingface.co/bert-base-uncased/resolve/main/rust_model.ot",
44    );
45    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
46    pub const BERT_LARGE: (&'static str, &'static str) = (
47        "bert-large/model",
48        "https://huggingface.co/bert-large-uncased/resolve/main/rust_model.ot",
49    );
50    /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
51    pub const BERT_NER: (&'static str, &'static str) = (
52        "bert-ner/model",
53        "https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/rust_model.ot",
54    );
55    /// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
56    pub const BERT_QA: (&'static str, &'static str) = (
57        "bert-qa/model",
58        "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/rust_model.ot",
59    );
60    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
61    pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
62        "bert-base-nli-mean-tokens/model",
63        "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/rust_model.ot",
64    );
65    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
66    pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
67        "all-mini-lm-l12-v2/model",
68        "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/rust_model.ot",
69    );
70    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
71    pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
72        "all-mini-lm-l6-v2/model",
73        "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/rust_model.ot",
74    );
75}
76
77impl BertConfigResources {
78    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
79    pub const BERT: (&'static str, &'static str) = (
80        "bert/config",
81        "https://huggingface.co/bert-base-uncased/resolve/main/config.json",
82    );
83    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
84    pub const BERT_LARGE: (&'static str, &'static str) = (
85        "bert-large/config",
86        "https://huggingface.co/bert-large-uncased/resolve/main/config.json",
87    );
88    /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
89    pub const BERT_NER: (&'static str, &'static str) = (
90        "bert-ner/config",
91        "https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/config.json",
92    );
93    /// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
94    pub const BERT_QA: (&'static str, &'static str) = (
95        "bert-qa/config",
96        "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
97    );
98    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
99    pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
100        "bert-base-nli-mean-tokens/config",
101        "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/config.json",
102    );
103    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
104    pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
105        "all-mini-lm-l12-v2/config",
106        "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/config.json",
107    );
108    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
109    pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
110        "all-mini-lm-l6-v2/config",
111        "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/config.json",
112    );
113}
114
115impl BertVocabResources {
116    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
117    pub const BERT: (&'static str, &'static str) = (
118        "bert/vocab",
119        "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
120    );
121    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
122    pub const BERT_LARGE: (&'static str, &'static str) = (
123        "bert-large/vocab",
124        "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
125    );
126    /// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
127    pub const BERT_NER: (&'static str, &'static str) = (
128        "bert-ner/vocab",
129        "https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/vocab.txt",
130    );
131    /// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
132    pub const BERT_QA: (&'static str, &'static str) = (
133        "bert-qa/vocab",
134        "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
135    );
136    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
137    pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
138        "bert-base-nli-mean-tokens/vocab",
139        "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/vocab.txt",
140    );
141    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
142    pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
143        "all-mini-lm-l12-v2/vocab",
144        "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/vocab.txt",
145    );
146    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
147    pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
148        "all-mini-lm-l6-v2/vocab",
149        "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/vocab.txt",
150    );
151}
152
153#[derive(Debug, Serialize, Deserialize, Clone)]
154/// # BERT model configuration
155/// Defines the BERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
156pub struct BertConfig {
157    pub hidden_act: Activation,
158    pub attention_probs_dropout_prob: f64,
159    pub hidden_dropout_prob: f64,
160    pub hidden_size: i64,
161    pub initializer_range: f32,
162    pub intermediate_size: i64,
163    pub max_position_embeddings: i64,
164    pub num_attention_heads: i64,
165    pub num_hidden_layers: i64,
166    pub type_vocab_size: i64,
167    pub vocab_size: i64,
168    pub output_attentions: Option<bool>,
169    pub output_hidden_states: Option<bool>,
170    pub is_decoder: Option<bool>,
171    pub id2label: Option<HashMap<i64, String>>,
172    pub label2id: Option<HashMap<String, i64>>,
173}
174
175impl Config for BertConfig {}
176
177impl Default for BertConfig {
178    fn default() -> Self {
179        BertConfig {
180            hidden_act: Activation::gelu,
181            attention_probs_dropout_prob: 0.1,
182            hidden_dropout_prob: 0.1,
183            hidden_size: 768,
184            initializer_range: 0.02,
185            intermediate_size: 3072,
186            max_position_embeddings: 512,
187            num_attention_heads: 12,
188            num_hidden_layers: 12,
189            type_vocab_size: 2,
190            vocab_size: 30522,
191            output_attentions: None,
192            output_hidden_states: None,
193            is_decoder: None,
194            id2label: None,
195            label2id: None,
196        }
197    }
198}
199
200/// # BERT Base model
201/// Base architecture for BERT models. Task-specific models will be built from this common base model
202/// It is made of the following blocks:
203/// - `embeddings`: `token`, `position` and `segment_id` embeddings
204/// - `encoder`: Encoder (transformer) made of a vector of layers. Each layer is made of a self-attention layer, an intermediate (linear) and output (linear + layer norm) layers
205/// - `pooler`: linear layer applied to the first element of the sequence (*MASK* token)
206/// - `is_decoder`: Flag indicating if the model is used as a decoder. If set to true, a causal mask will be applied to hide future positions that should not be attended to.
207pub struct BertModel<T: BertEmbedding> {
208    embeddings: T,
209    encoder: BertEncoder,
210    pooler: Option<BertPooler>,
211    is_decoder: bool,
212}
213
214/// Defines the implementation of the BertModel. The BERT model shares many similarities with RoBERTa, main difference being the embeddings.
215/// Therefore the forward pass of the model is shared and the type of embedding used is abstracted away. This allows to create
216/// `BertModel<RobertaEmbeddings>` or `BertModel<BertEmbeddings>` for each model type.
217impl<T: BertEmbedding> BertModel<T> {
218    /// Build a new `BertModel`
219    ///
220    /// # Arguments
221    ///
222    /// * `p` - Variable store path for the root of the BERT model
223    /// * `config` - `BertConfig` object defining the model architecture and decoder status
224    ///
225    /// # Example
226    ///
227    /// ```no_run
228    /// use rust_bert::bert::{BertConfig, BertEmbeddings, BertModel};
229    /// use rust_bert::Config;
230    /// use std::path::Path;
231    /// use tch::{nn, Device};
232    ///
233    /// let config_path = Path::new("path/to/config.json");
234    /// let device = Device::Cpu;
235    /// let p = nn::VarStore::new(device);
236    /// let config = BertConfig::from_file(config_path);
237    /// let bert: BertModel<BertEmbeddings> = BertModel::new(&p.root() / "bert", &config);
238    /// ```
239    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertModel<T>
240    where
241        P: Borrow<nn::Path<'p>>,
242    {
243        let p = p.borrow();
244
245        let is_decoder = config.is_decoder.unwrap_or(false);
246        let embeddings = T::new(p / "embeddings", config);
247        let encoder = BertEncoder::new(p / "encoder", config);
248        let pooler = Some(BertPooler::new(p / "pooler", config));
249
250        BertModel {
251            embeddings,
252            encoder,
253            pooler,
254            is_decoder,
255        }
256    }
257
258    /// Build a new `BertModel` with an optional Pooling layer
259    ///
260    /// # Arguments
261    ///
262    /// * `p` - Variable store path for the root of the BERT model
263    /// * `config` - `BertConfig` object defining the model architecture and decoder status
264    /// * `add_pooling_layer` - Enable/Disable an optional pooling layer at the end of the model
265    ///
266    /// # Example
267    ///
268    /// ```no_run
269    /// use rust_bert::bert::{BertConfig, BertEmbeddings, BertModel};
270    /// use rust_bert::Config;
271    /// use std::path::Path;
272    /// use tch::{nn, Device};
273    ///
274    /// let config_path = Path::new("path/to/config.json");
275    /// let device = Device::Cpu;
276    /// let p = nn::VarStore::new(device);
277    /// let config = BertConfig::from_file(config_path);
278    /// let bert: BertModel<BertEmbeddings> =
279    ///     BertModel::new_with_optional_pooler(&p.root() / "bert", &config, false);
280    /// ```
281    pub fn new_with_optional_pooler<'p, P>(
282        p: P,
283        config: &BertConfig,
284        add_pooling_layer: bool,
285    ) -> BertModel<T>
286    where
287        P: Borrow<nn::Path<'p>>,
288    {
289        let p = p.borrow();
290
291        let is_decoder = config.is_decoder.unwrap_or(false);
292        let embeddings = T::new(p / "embeddings", config);
293        let encoder = BertEncoder::new(p / "encoder", config);
294
295        let pooler = {
296            if add_pooling_layer {
297                Some(BertPooler::new(p / "pooler", config))
298            } else {
299                None
300            }
301        };
302
303        BertModel {
304            embeddings,
305            encoder,
306            pooler,
307            is_decoder,
308        }
309    }
310
311    /// Forward pass through the model
312    ///
313    /// # Arguments
314    ///
315    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
316    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
317    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
318    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
319    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
320    /// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
321    /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used to mask encoder values. Positions with value 0 will be masked.
322    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
323    ///
324    /// # Returns
325    ///
326    /// * `BertOutput` containing:
327    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
328    ///   - `pooled_output` - `Tensor` of shape (*batch size*, *hidden_size*)
329    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
330    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
331    ///
332    /// # Example
333    ///
334    /// ```no_run
335    /// # use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings};
336    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
337    /// # use rust_bert::Config;
338    /// # use std::path::Path;
339    /// # let config_path = Path::new("path/to/config.json");
340    /// # let device = Device::Cpu;
341    /// # let vs = nn::VarStore::new(device);
342    /// # let config = BertConfig::from_file(config_path);
343    /// # let bert_model: BertModel<BertEmbeddings> = BertModel::new(&vs.root(), &config);
344    /// let (batch_size, sequence_length) = (64, 128);
345    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
346    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
347    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
348    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
349    ///     .expand(&[batch_size, sequence_length], true);
350    ///
351    /// let model_output = no_grad(|| {
352    ///     bert_model
353    ///         .forward_t(
354    ///             Some(&input_tensor),
355    ///             Some(&mask),
356    ///             Some(&token_type_ids),
357    ///             Some(&position_ids),
358    ///             None,
359    ///             None,
360    ///             None,
361    ///             false,
362    ///         )
363    ///         .unwrap()
364    /// });
365    /// ```
366    pub fn forward_t(
367        &self,
368        input_ids: Option<&Tensor>,
369        mask: Option<&Tensor>,
370        token_type_ids: Option<&Tensor>,
371        position_ids: Option<&Tensor>,
372        input_embeds: Option<&Tensor>,
373        encoder_hidden_states: Option<&Tensor>,
374        encoder_mask: Option<&Tensor>,
375        train: bool,
376    ) -> Result<BertModelOutput, RustBertError> {
377        let (input_shape, device) =
378            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
379
380        let calc_mask = Tensor::ones(&input_shape, (Kind::Int8, device));
381        let mask = mask.unwrap_or(&calc_mask);
382
383        let extended_attention_mask = match mask.dim() {
384            3 => mask.unsqueeze(1),
385            2 => {
386                if self.is_decoder {
387                    let seq_ids = Tensor::arange(input_shape[1], (Kind::Int8, device));
388                    let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat([
389                        input_shape[0],
390                        input_shape[1],
391                        1,
392                    ]);
393                    let causal_mask = causal_mask.le_tensor(&seq_ids.unsqueeze(0).unsqueeze(-1));
394                    causal_mask * mask.unsqueeze(1).unsqueeze(1)
395                } else {
396                    mask.unsqueeze(1).unsqueeze(1)
397                }
398            }
399            _ => {
400                return Err(RustBertError::ValueError(
401                    "Invalid attention mask dimension, must be 2 or 3".into(),
402                ));
403            }
404        };
405
406        let embedding_output = self.embeddings.forward_t(
407            input_ids,
408            token_type_ids,
409            position_ids,
410            input_embeds,
411            train,
412        )?;
413
414        let extended_attention_mask: Tensor = ((extended_attention_mask
415            .ones_like()
416            .bitwise_xor_tensor(&extended_attention_mask))
417            * -10000.0)
418            .to_kind(embedding_output.kind());
419
420        let encoder_extended_attention_mask: Option<Tensor> =
421            if self.is_decoder & encoder_hidden_states.is_some() {
422                let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
423                let encoder_hidden_states_shape = encoder_hidden_states.size();
424                let encoder_mask = match encoder_mask {
425                    Some(value) => value.copy(),
426                    None => Tensor::ones(
427                        [
428                            encoder_hidden_states_shape[0],
429                            encoder_hidden_states_shape[1],
430                        ],
431                        (Kind::Int8, device),
432                    ),
433                };
434                match encoder_mask.dim() {
435                    2 => Some(encoder_mask.unsqueeze(1).unsqueeze(1)),
436                    3 => Some(encoder_mask.unsqueeze(1)),
437                    _ => {
438                        return Err(RustBertError::ValueError(
439                            "Invalid attention mask dimension, must be 2 or 3".into(),
440                        ));
441                    }
442                }
443            } else {
444                None
445            };
446
447        let encoder_output = self.encoder.forward_t(
448            &embedding_output,
449            Some(&extended_attention_mask),
450            encoder_hidden_states,
451            encoder_extended_attention_mask.as_ref(),
452            train,
453        );
454
455        let pooled_output = self
456            .pooler
457            .as_ref()
458            .map(|pooler| pooler.forward(&encoder_output.hidden_state));
459
460        Ok(BertModelOutput {
461            hidden_state: encoder_output.hidden_state,
462            pooled_output,
463            all_hidden_states: encoder_output.all_hidden_states,
464            all_attentions: encoder_output.all_attentions,
465        })
466    }
467}
468
469pub struct BertPredictionHeadTransform {
470    dense: nn::Linear,
471    activation: TensorFunction,
472    layer_norm: nn::LayerNorm,
473}
474
475impl BertPredictionHeadTransform {
476    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertPredictionHeadTransform
477    where
478        P: Borrow<nn::Path<'p>>,
479    {
480        let p = p.borrow();
481
482        let dense = nn::linear(
483            p / "dense",
484            config.hidden_size,
485            config.hidden_size,
486            Default::default(),
487        );
488        let activation = config.hidden_act.get_function();
489        let layer_norm_config = nn::LayerNormConfig {
490            eps: 1e-12,
491            ..Default::default()
492        };
493        let layer_norm =
494            nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
495
496        BertPredictionHeadTransform {
497            dense,
498            activation,
499            layer_norm,
500        }
501    }
502
503    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
504        self.activation.get_fn()(&hidden_states.apply(&self.dense)).apply(&self.layer_norm)
505    }
506}
507
508pub struct BertLMPredictionHead {
509    transform: BertPredictionHeadTransform,
510    decoder: LinearNoBias,
511    bias: Tensor,
512}
513
514impl BertLMPredictionHead {
515    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertLMPredictionHead
516    where
517        P: Borrow<nn::Path<'p>>,
518    {
519        let p = p.borrow() / "predictions";
520        let transform = BertPredictionHeadTransform::new(&p / "transform", config);
521        let decoder = linear_no_bias(
522            &p / "decoder",
523            config.hidden_size,
524            config.vocab_size,
525            Default::default(),
526        );
527        let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
528
529        BertLMPredictionHead {
530            transform,
531            decoder,
532            bias,
533        }
534    }
535
536    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
537        self.transform.forward(hidden_states).apply(&self.decoder) + &self.bias
538    }
539}
540
541/// # BERT for masked language model
542/// Base BERT model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
543/// It is made of the following blocks:
544/// - `bert`: Base BertModel
545/// - `cls`: BERT LM prediction head
546pub struct BertForMaskedLM {
547    bert: BertModel<BertEmbeddings>,
548    cls: BertLMPredictionHead,
549}
550
551impl BertForMaskedLM {
552    /// Build a new `BertForMaskedLM`
553    ///
554    /// # Arguments
555    ///
556    /// * `p` - Variable store path for the root of the BertForMaskedLM model
557    /// * `config` - `BertConfig` object defining the model architecture and vocab size
558    ///
559    /// # Example
560    ///
561    /// ```no_run
562    /// use rust_bert::bert::{BertConfig, BertForMaskedLM};
563    /// use rust_bert::Config;
564    /// use std::path::Path;
565    /// use tch::{nn, Device};
566    ///
567    /// let config_path = Path::new("path/to/config.json");
568    /// let device = Device::Cpu;
569    /// let p = nn::VarStore::new(device);
570    /// let config = BertConfig::from_file(config_path);
571    /// let bert = BertForMaskedLM::new(&p.root() / "bert", &config);
572    /// ```
573    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMaskedLM
574    where
575        P: Borrow<nn::Path<'p>>,
576    {
577        let p = p.borrow();
578
579        let bert = BertModel::new_with_optional_pooler(p / "bert", config, false);
580        let cls = BertLMPredictionHead::new(p / "cls", config);
581
582        BertForMaskedLM { bert, cls }
583    }
584
585    /// Forward pass through the model
586    ///
587    /// # Arguments
588    ///
589    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
590    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
591    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
592    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
593    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
594    /// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used in the cross-attention layer as keys and values (query from the decoder).
595    /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used to mask encoder values. Positions with value 0 will be masked.
596    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
597    ///
598    /// # Returns
599    ///
600    /// * `BertMaskedLMOutput` containing:
601    ///   - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
602    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
603    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
604    ///
605    /// # Example
606    ///
607    /// ```no_run
608    /// # use rust_bert::bert::{BertForMaskedLM, BertConfig};
609    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
610    /// # use rust_bert::Config;
611    /// # use std::path::Path;
612    /// # let config_path = Path::new("path/to/config.json");
613    /// # let device = Device::Cpu;
614    /// # let vs = nn::VarStore::new(device);
615    /// # let config = BertConfig::from_file(config_path);
616    /// # let bert_model = BertForMaskedLM::new(&vs.root(), &config);
617    /// let (batch_size, sequence_length) = (64, 128);
618    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
619    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
620    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
621    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
622    ///     .expand(&[batch_size, sequence_length], true);
623    ///
624    /// let model_output = no_grad(|| {
625    ///     bert_model.forward_t(
626    ///         Some(&input_tensor),
627    ///         Some(&mask),
628    ///         Some(&token_type_ids),
629    ///         Some(&position_ids),
630    ///         None,
631    ///         None,
632    ///         None,
633    ///         false,
634    ///     )
635    /// });
636    /// ```
637    pub fn forward_t(
638        &self,
639        input_ids: Option<&Tensor>,
640        mask: Option<&Tensor>,
641        token_type_ids: Option<&Tensor>,
642        position_ids: Option<&Tensor>,
643        input_embeds: Option<&Tensor>,
644        encoder_hidden_states: Option<&Tensor>,
645        encoder_mask: Option<&Tensor>,
646        train: bool,
647    ) -> BertMaskedLMOutput {
648        let base_model_output = self
649            .bert
650            .forward_t(
651                input_ids,
652                mask,
653                token_type_ids,
654                position_ids,
655                input_embeds,
656                encoder_hidden_states,
657                encoder_mask,
658                train,
659            )
660            .unwrap();
661
662        let prediction_scores = self.cls.forward(&base_model_output.hidden_state);
663        BertMaskedLMOutput {
664            prediction_scores,
665            all_hidden_states: base_model_output.all_hidden_states,
666            all_attentions: base_model_output.all_attentions,
667        }
668    }
669}
670
/// # BERT for sequence classification
/// Base BERT model with a classifier head to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `bert`: Base BertModel
/// - `dropout`: Dropout applied to the pooled output before the classifier
/// - `classifier`: BERT linear layer for classification
pub struct BertForSequenceClassification {
    /// Base model; `forward_t` reads its pooled output.
    bert: BertModel<BertEmbeddings>,
    /// Dropout layer built from `config.hidden_dropout_prob`, applied with the `train` flag.
    dropout: Dropout,
    /// Linear projection from `config.hidden_size` to the number of labels in `config.id2label`.
    classifier: nn::Linear,
}
681
682impl BertForSequenceClassification {
683    /// Build a new `BertForSequenceClassification`
684    ///
685    /// # Arguments
686    ///
687    /// * `p` - Variable store path for the root of the BertForSequenceClassification model
688    /// * `config` - `BertConfig` object defining the model architecture and number of classes
689    ///
690    /// # Example
691    ///
692    /// ```no_run
693    /// use rust_bert::bert::{BertConfig, BertForSequenceClassification};
694    /// use rust_bert::Config;
695    /// use std::path::Path;
696    /// use tch::{nn, Device};
697    ///
698    /// let config_path = Path::new("path/to/config.json");
699    /// let device = Device::Cpu;
700    /// let p = nn::VarStore::new(device);
701    /// let config = BertConfig::from_file(config_path);
702    /// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config).unwrap();
703    /// ```
704    pub fn new<'p, P>(
705        p: P,
706        config: &BertConfig,
707    ) -> Result<BertForSequenceClassification, RustBertError>
708    where
709        P: Borrow<nn::Path<'p>>,
710    {
711        let p = p.borrow();
712
713        let bert = BertModel::new(p / "bert", config);
714        let dropout = Dropout::new(config.hidden_dropout_prob);
715        let num_labels = config
716            .id2label
717            .as_ref()
718            .ok_or_else(|| {
719                RustBertError::InvalidConfigurationError(
720                    "num_labels not provided in configuration".to_string(),
721                )
722            })?
723            .len() as i64;
724        let classifier = nn::linear(
725            p / "classifier",
726            config.hidden_size,
727            num_labels,
728            Default::default(),
729        );
730
731        Ok(BertForSequenceClassification {
732            bert,
733            dropout,
734            classifier,
735        })
736    }
737
738    /// Forward pass through the model
739    ///
740    /// # Arguments
741    ///
742    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1.
    /// * `token_type_ids` - Optional segment ids of shape (*batch size*, *sequence_length*). Convention is a value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None, set to 0.
745    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
746    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
747    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
748    ///
749    /// # Returns
750    ///
751    /// * `BertSequenceClassificationOutput` containing:
752    ///   - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
753    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
754    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
755    ///
756    /// # Example
757    ///
758    /// ```no_run
759    /// # use rust_bert::bert::{BertForSequenceClassification, BertConfig};
760    /// # use tch::{nn, Device, Tensor, no_grad, Kind};
761    /// # use rust_bert::Config;
762    /// # use std::path::Path;
763    /// # let config_path = Path::new("path/to/config.json");
764    /// # let device = Device::Cpu;
765    /// # let vs = nn::VarStore::new(device);
766    /// # let config = BertConfig::from_file(config_path);
    /// # let bert_model = BertForSequenceClassification::new(&vs.root(), &config).unwrap();
768    /// let (batch_size, sequence_length) = (64, 128);
769    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
770    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
771    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
772    /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
773    ///     .expand(&[batch_size, sequence_length], true);
774    ///
775    /// let model_output = no_grad(|| {
776    ///     bert_model.forward_t(
777    ///         Some(&input_tensor),
778    ///         Some(&mask),
779    ///         Some(&token_type_ids),
780    ///         Some(&position_ids),
781    ///         None,
782    ///         false,
783    ///     )
784    /// });
785    /// ```
786    pub fn forward_t(
787        &self,
788        input_ids: Option<&Tensor>,
789        mask: Option<&Tensor>,
790        token_type_ids: Option<&Tensor>,
791        position_ids: Option<&Tensor>,
792        input_embeds: Option<&Tensor>,
793        train: bool,
794    ) -> BertSequenceClassificationOutput {
795        let base_model_output = self
796            .bert
797            .forward_t(
798                input_ids,
799                mask,
800                token_type_ids,
801                position_ids,
802                input_embeds,
803                None,
804                None,
805                train,
806            )
807            .unwrap();
808
809        let logits = base_model_output
810            .pooled_output
811            .unwrap()
812            .apply_t(&self.dropout, train)
813            .apply(&self.classifier);
814        BertSequenceClassificationOutput {
815            logits,
816            all_hidden_states: base_model_output.all_hidden_states,
817            all_attentions: base_model_output.all_attentions,
818        }
819    }
820}
821
/// # BERT for multiple choices
/// Multiple choices model using a BERT base model and a linear classifier.
/// Input should be in the form `[CLS] Context [SEP] Possible choice [SEP]`. The choice is made along the batch axis,
/// assuming all elements of the batch are alternatives to be chosen from for a given context.
/// It is made of the following blocks:
/// - `bert`: Base BertModel
/// - `dropout`: Dropout applied to the pooled output before the classifier
/// - `classifier`: Linear layer for multiple choices
pub struct BertForMultipleChoice {
    /// Base model; `forward_t` reads its pooled output.
    bert: BertModel<BertEmbeddings>,
    /// Dropout layer built from `config.hidden_dropout_prob`, applied with the `train` flag.
    dropout: Dropout,
    /// Linear projection from `config.hidden_size` to a single score per alternative.
    classifier: nn::Linear,
}
834
835impl BertForMultipleChoice {
836    /// Build a new `BertForMultipleChoice`
837    ///
838    /// # Arguments
839    ///
840    /// * `p` - Variable store path for the root of the BertForMultipleChoice model
841    /// * `config` - `BertConfig` object defining the model architecture
842    ///
843    /// # Example
844    ///
845    /// ```no_run
846    /// use rust_bert::bert::{BertConfig, BertForMultipleChoice};
847    /// use rust_bert::Config;
848    /// use std::path::Path;
849    /// use tch::{nn, Device};
850    ///
851    /// let config_path = Path::new("path/to/config.json");
852    /// let device = Device::Cpu;
853    /// let p = nn::VarStore::new(device);
854    /// let config = BertConfig::from_file(config_path);
855    /// let bert = BertForMultipleChoice::new(&p.root() / "bert", &config);
856    /// ```
857    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMultipleChoice
858    where
859        P: Borrow<nn::Path<'p>>,
860    {
861        let p = p.borrow();
862
863        let bert = BertModel::new(p / "bert", config);
864        let dropout = Dropout::new(config.hidden_dropout_prob);
865        let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
866
867        BertForMultipleChoice {
868            bert,
869            dropout,
870            classifier,
871        }
872    }
873
874    /// Forward pass through the model
875    ///
876    /// # Arguments
877    ///
878    /// * `input_ids` - Input tensor of shape (*batch size*, *sequence_length*).
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1.
    /// * `token_type_ids` - Optional segment ids of shape (*batch size*, *sequence_length*). Convention is a value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None, set to 0.
881    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
882    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
883    ///
884    /// # Returns
885    ///
886    /// * `BertSequenceClassificationOutput` containing:
887    ///   - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
888    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
889    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
890    ///
891    /// # Example
892    ///
893    /// ```no_run
894    /// # use rust_bert::bert::{BertForMultipleChoice, BertConfig};
895    /// # use tch::{nn, Device, Tensor, no_grad};
896    /// # use rust_bert::Config;
897    /// # use std::path::Path;
898    /// # use tch::kind::Kind::Int64;
899    /// # let config_path = Path::new("path/to/config.json");
900    /// # let device = Device::Cpu;
901    /// # let vs = nn::VarStore::new(device);
902    /// # let config = BertConfig::from_file(config_path);
903    /// # let bert_model = BertForMultipleChoice::new(&vs.root(), &config);
904    /// let (num_choices, sequence_length) = (3, 128);
905    /// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
906    /// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
907    /// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
908    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
909    ///     .expand(&[num_choices, sequence_length], true);
910    ///
911    /// let model_output = no_grad(|| {
912    ///     bert_model.forward_t(
913    ///         &input_tensor,
914    ///         Some(&mask),
915    ///         Some(&token_type_ids),
916    ///         Some(&position_ids),
917    ///         false,
918    ///     )
919    /// });
920    /// ```
921    pub fn forward_t(
922        &self,
923        input_ids: &Tensor,
924        mask: Option<&Tensor>,
925        token_type_ids: Option<&Tensor>,
926        position_ids: Option<&Tensor>,
927        train: bool,
928    ) -> BertSequenceClassificationOutput {
929        let num_choices = input_ids.size()[1];
930
931        let input_ids = input_ids.view((-1, *input_ids.size().last().unwrap()));
932        let mask = mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
933        let token_type_ids =
934            token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
935        let position_ids =
936            position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
937
938        let base_model_output = self
939            .bert
940            .forward_t(
941                Some(&input_ids),
942                mask.as_ref(),
943                token_type_ids.as_ref(),
944                position_ids.as_ref(),
945                None,
946                None,
947                None,
948                train,
949            )
950            .unwrap();
951
952        let logits = base_model_output
953            .pooled_output
954            .unwrap()
955            .apply_t(&self.dropout, train)
956            .apply(&self.classifier)
957            .view((-1, num_choices));
958        BertSequenceClassificationOutput {
959            logits,
960            all_hidden_states: base_model_output.all_hidden_states,
961            all_attentions: base_model_output.all_attentions,
962        }
963    }
964}
965
/// # BERT for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `bert`: Base BertModel
/// - `dropout`: Dropout applied to the last hidden state before the classifier
/// - `classifier`: Linear layer for token classification
pub struct BertForTokenClassification {
    /// Base model, built through `BertModel::new_with_optional_pooler(.., false)`; `forward_t` reads its last hidden state.
    bert: BertModel<BertEmbeddings>,
    /// Dropout layer built from `config.hidden_dropout_prob`, applied with the `train` flag.
    dropout: Dropout,
    /// Linear projection from `config.hidden_size` to the number of labels in `config.id2label`.
    classifier: nn::Linear,
}
977
978impl BertForTokenClassification {
979    /// Build a new `BertForTokenClassification`
980    ///
981    /// # Arguments
982    ///
983    /// * `p` - Variable store path for the root of the BertForTokenClassification model
984    /// * `config` - `BertConfig` object defining the model architecture, number of output labels and label mapping
985    ///
986    /// # Example
987    ///
988    /// ```no_run
989    /// use rust_bert::bert::{BertConfig, BertForTokenClassification};
990    /// use rust_bert::Config;
991    /// use std::path::Path;
992    /// use tch::{nn, Device};
993    ///
994    /// let config_path = Path::new("path/to/config.json");
995    /// let device = Device::Cpu;
996    /// let p = nn::VarStore::new(device);
997    /// let config = BertConfig::from_file(config_path);
998    /// let bert = BertForTokenClassification::new(&p.root() / "bert", &config).unwrap();
999    /// ```
1000    pub fn new<'p, P>(
1001        p: P,
1002        config: &BertConfig,
1003    ) -> Result<BertForTokenClassification, RustBertError>
1004    where
1005        P: Borrow<nn::Path<'p>>,
1006    {
1007        let p = p.borrow();
1008
1009        let bert = BertModel::new_with_optional_pooler(p / "bert", config, false);
1010        let dropout = Dropout::new(config.hidden_dropout_prob);
1011        let num_labels = config
1012            .id2label
1013            .as_ref()
1014            .ok_or_else(|| {
1015                RustBertError::InvalidConfigurationError(
1016                    "num_labels not provided in configuration".to_string(),
1017                )
1018            })?
1019            .len() as i64;
1020        let classifier = nn::linear(
1021            p / "classifier",
1022            config.hidden_size,
1023            num_labels,
1024            Default::default(),
1025        );
1026
1027        Ok(BertForTokenClassification {
1028            bert,
1029            dropout,
1030            classifier,
1031        })
1032    }
1033
1034    /// Forward pass through the model
1035    ///
1036    /// # Arguments
1037    ///
1038    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1.
    /// * `token_type_ids` - Optional segment ids of shape (*batch size*, *sequence_length*). Convention is a value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None, set to 0.
1041    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1042    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1043    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1044    ///
1045    /// # Returns
1046    ///
1047    /// * `BertTokenClassificationOutput` containing:
1048    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
1049    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1050    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1051    ///
1052    /// # Example
1053    ///
1054    /// ```no_run
1055    /// # use rust_bert::bert::{BertForTokenClassification, BertConfig};
1056    /// # use tch::{nn, Device, Tensor, no_grad};
1057    /// # use rust_bert::Config;
1058    /// # use std::path::Path;
1059    /// # use tch::kind::Kind::Int64;
1060    /// # let config_path = Path::new("path/to/config.json");
1061    /// # let device = Device::Cpu;
1062    /// # let vs = nn::VarStore::new(device);
1063    /// # let config = BertConfig::from_file(config_path);
1064    /// # let bert_model = BertForTokenClassification::new(&vs.root(), &config).unwrap();
1065    /// let (batch_size, sequence_length) = (64, 128);
1066    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1067    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1068    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1069    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1070    ///     .expand(&[batch_size, sequence_length], true);
1071    ///
1072    /// let model_output = no_grad(|| {
1073    ///     bert_model.forward_t(
1074    ///         Some(&input_tensor),
1075    ///         Some(&mask),
1076    ///         Some(&token_type_ids),
1077    ///         Some(&position_ids),
1078    ///         None,
1079    ///         false,
1080    ///     )
1081    /// });
1082    /// ```
1083    pub fn forward_t(
1084        &self,
1085        input_ids: Option<&Tensor>,
1086        mask: Option<&Tensor>,
1087        token_type_ids: Option<&Tensor>,
1088        position_ids: Option<&Tensor>,
1089        input_embeds: Option<&Tensor>,
1090        train: bool,
1091    ) -> BertTokenClassificationOutput {
1092        let base_model_output = self
1093            .bert
1094            .forward_t(
1095                input_ids,
1096                mask,
1097                token_type_ids,
1098                position_ids,
1099                input_embeds,
1100                None,
1101                None,
1102                train,
1103            )
1104            .unwrap();
1105
1106        let logits = base_model_output
1107            .hidden_state
1108            .apply_t(&self.dropout, train)
1109            .apply(&self.classifier);
1110        BertTokenClassificationOutput {
1111            logits,
1112            all_hidden_states: base_model_output.all_hidden_states,
1113            all_attentions: base_model_output.all_attentions,
1114        }
1115    }
1116}
1117
/// # BERT for question answering
/// Extractive question-answering model based on a BERT language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `bert`: Base BertModel
/// - `qa_outputs`: Linear layer for question answering
pub struct BertForQuestionAnswering {
    /// Base model; `forward_t` reads its last hidden state.
    bert: BertModel<BertEmbeddings>,
    /// Linear projection from `config.hidden_size` to 2 values per token (start and end logits).
    qa_outputs: nn::Linear,
}
1129
1130impl BertForQuestionAnswering {
1131    /// Build a new `BertForQuestionAnswering`
1132    ///
1133    /// # Arguments
1134    ///
1135    /// * `p` - Variable store path for the root of the BertForQuestionAnswering model
1136    /// * `config` - `BertConfig` object defining the model architecture
1137    ///
1138    /// # Example
1139    ///
1140    /// ```no_run
1141    /// use rust_bert::bert::{BertConfig, BertForQuestionAnswering};
1142    /// use rust_bert::Config;
1143    /// use std::path::Path;
1144    /// use tch::{nn, Device};
1145    ///
1146    /// let config_path = Path::new("path/to/config.json");
1147    /// let device = Device::Cpu;
1148    /// let p = nn::VarStore::new(device);
1149    /// let config = BertConfig::from_file(config_path);
1150    /// let bert = BertForQuestionAnswering::new(&p.root() / "bert", &config);
1151    /// ```
1152    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForQuestionAnswering
1153    where
1154        P: Borrow<nn::Path<'p>>,
1155    {
1156        let p = p.borrow();
1157
1158        let bert = BertModel::new(p / "bert", config);
1159        let num_labels = 2;
1160        let qa_outputs = nn::linear(
1161            p / "qa_outputs",
1162            config.hidden_size,
1163            num_labels,
1164            Default::default(),
1165        );
1166
1167        BertForQuestionAnswering { bert, qa_outputs }
1168    }
1169
1170    /// Forward pass through the model
1171    ///
1172    /// # Arguments
1173    ///
1174    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1.
    /// * `token_type_ids` - Optional segment ids of shape (*batch size*, *sequence_length*). Convention is a value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None, set to 0.
1177    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1178    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1179    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1180    ///
1181    /// # Returns
1182    ///
1183    /// * `BertQuestionAnsweringOutput` containing:
1184    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
1185    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
1186    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1188    ///
1189    /// # Example
1190    ///
1191    /// ```no_run
1192    /// # use rust_bert::bert::{BertForQuestionAnswering, BertConfig};
1193    /// # use tch::{nn, Device, Tensor, no_grad};
1194    /// # use rust_bert::Config;
1195    /// # use std::path::Path;
1196    /// # use tch::kind::Kind::Int64;
1197    /// # let config_path = Path::new("path/to/config.json");
1198    /// # let device = Device::Cpu;
1199    /// # let vs = nn::VarStore::new(device);
1200    /// # let config = BertConfig::from_file(config_path);
1201    /// # let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);
1202    /// let (batch_size, sequence_length) = (64, 128);
1203    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1204    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1205    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1206    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1207    ///     .expand(&[batch_size, sequence_length], true);
1208    ///
1209    /// let model_output = no_grad(|| {
1210    ///     bert_model.forward_t(
1211    ///         Some(&input_tensor),
1212    ///         Some(&mask),
1213    ///         Some(&token_type_ids),
1214    ///         Some(&position_ids),
1215    ///         None,
1216    ///         false,
1217    ///     )
1218    /// });
1219    /// ```
1220    pub fn forward_t(
1221        &self,
1222        input_ids: Option<&Tensor>,
1223        mask: Option<&Tensor>,
1224        token_type_ids: Option<&Tensor>,
1225        position_ids: Option<&Tensor>,
1226        input_embeds: Option<&Tensor>,
1227        train: bool,
1228    ) -> BertQuestionAnsweringOutput {
1229        let base_model_output = self
1230            .bert
1231            .forward_t(
1232                input_ids,
1233                mask,
1234                token_type_ids,
1235                position_ids,
1236                input_embeds,
1237                None,
1238                None,
1239                train,
1240            )
1241            .unwrap();
1242
1243        let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
1244        let logits = sequence_output.split(1, -1);
1245        let (start_logits, end_logits) = (&logits[0], &logits[1]);
1246        let start_logits = start_logits.squeeze_dim(-1);
1247        let end_logits = end_logits.squeeze_dim(-1);
1248
1249        BertQuestionAnsweringOutput {
1250            start_logits,
1251            end_logits,
1252            all_hidden_states: base_model_output.all_hidden_states,
1253            all_attentions: base_model_output.all_attentions,
1254        }
1255    }
1256}
1257
/// # BERT for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
///
/// This is a plain type alias for the base `BertModel<BertEmbeddings>`; no additional head is defined here.
pub type BertForSentenceEmbeddings = BertModel<BertEmbeddings>;
1261
/// Container for the BERT model output.
///
/// The task-specific heads in this file read fields with these names from the base model's
/// `forward_t` result: `hidden_state` for masked LM, token classification and question
/// answering; `pooled_output` for sequence classification and multiple choice.
pub struct BertModelOutput {
    /// Last hidden states from the model
    pub hidden_state: Tensor,
    /// Pooled output (hidden state for the first token). Optional: the sequence
    /// classification and multiple choice heads unwrap it and panic when it is `None`.
    pub pooled_output: Option<Tensor>,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}
1273
/// Container for the BERT masked LM model output.
///
/// `all_hidden_states` and `all_attentions` are passed through unchanged from the base model output.
pub struct BertMaskedLMOutput {
    /// Logits for the vocabulary items at each sequence position
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}
1283
/// Container for the BERT sequence classification model output.
///
/// Returned by both `BertForSequenceClassification::forward_t` and `BertForMultipleChoice::forward_t`.
/// `all_hidden_states` and `all_attentions` are passed through unchanged from the base model output.
pub struct BertSequenceClassificationOutput {
    /// Logits for each input (sequence) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}
1293
/// Container for the BERT token classification model output.
///
/// Returned by `BertForTokenClassification::forward_t`.
/// `all_hidden_states` and `all_attentions` are passed through unchanged from the base model output.
pub struct BertTokenClassificationOutput {
    /// Logits for each sequence item (token) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}
1303
/// Container for the BERT question answering model output.
///
/// Returned by `BertForQuestionAnswering::forward_t`, which splits the two-unit output of its
/// linear layer along the last axis into `start_logits` and `end_logits`.
/// `all_hidden_states` and `all_attentions` are passed through unchanged from the base model output.
pub struct BertQuestionAnsweringOutput {
    /// Logits for the start position for token of each input sequence
    pub start_logits: Tensor,
    /// Logits for the end position for token of each input sequence
    pub end_logits: Tensor,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}
1315
#[cfg(test)]
mod test {
    use super::*;

    use crate::{
        resources::{RemoteResource, ResourceProvider},
        Config,
    };
    use tch::Device;

    /// Compile-time check that the base BERT model can be boxed as `dyn Send`.
    #[test]
    #[ignore] // compilation is enough, no need to run
    fn bert_model_send() {
        // Resolve the pretrained configuration file and load it.
        let remote_config = Box::new(RemoteResource::from_pretrained(BertConfigResources::BERT));
        let local_config_path = remote_config.get_local_path().expect("");

        // Set up the base model on a fresh variable store.
        let var_store = nn::VarStore::new(Device::cuda_if_available());
        let bert_config = BertConfig::from_file(local_config_path);
        let model = BertModel::<BertEmbeddings>::new(var_store.root(), &bert_config);

        // The coercion below only type-checks when the model is `Send`.
        let _: Box<dyn Send> = Box::new(model);
    }
}