// rust_bert/models/roberta/roberta_model.rs

1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
3// Copyright 2019 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::bert::{BertConfig, BertModel};
15use crate::common::activations::_gelu;
16use crate::common::dropout::Dropout;
17use crate::common::linear::{linear_no_bias, LinearNoBias};
18use crate::roberta::embeddings::RobertaEmbeddings;
19use crate::RustBertError;
20use std::borrow::Borrow;
21use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
22use tch::{nn, Tensor};
23
/// # RoBERTa Pretrained model weight files
/// Marker type whose associated constants list `(identifier, URL)` pairs for pretrained weights.
pub struct RobertaModelResources;

/// # RoBERTa Pretrained model config files
/// Marker type whose associated constants list `(identifier, URL)` pairs for model configurations.
pub struct RobertaConfigResources;

/// # RoBERTa Pretrained model vocab files
/// Marker type whose associated constants list `(identifier, URL)` pairs for vocabulary files.
pub struct RobertaVocabResources;

/// # RoBERTa Pretrained model merges files
/// Marker type whose associated constants list `(identifier, URL)` pairs for BPE merges files.
pub struct RobertaMergesResources;
35
// Each constant is a `(resource identifier, download URL)` pair; the identifier is the
// local name under which the remote file is referenced (presumably used as a cache
// key by the resource download utilities — confirm against the resources module).
impl RobertaModelResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
    pub const ROBERTA: (&'static str, &'static str) = (
        "roberta/model",
        "https://huggingface.co/roberta-base/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
    pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
        "distilroberta-base/model",
        "https://huggingface.co/distilroberta-base/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
    pub const ROBERTA_QA: (&'static str, &'static str) = (
        "roberta-qa/model",
        "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
        "xlm-roberta-ner-en/model",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
        "xlm-roberta-ner-de/model",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
        "xlm-roberta-ner-nl/model",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
        "xlm-roberta-ner-es/model",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
    pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
        "all-distilroberta-v1/model",
        "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
    pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
        "codeberta-language-id/model",
        "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
    pub const CODEBERT_MLM: (&'static str, &'static str) = (
        "codebert-mlm/model",
        "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/rust_model.ot",
    );
}
88
89impl RobertaConfigResources {
90    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
91    pub const ROBERTA: (&'static str, &'static str) = (
92        "roberta/config",
93        "https://huggingface.co/roberta-base/resolve/main/config.json",
94    );
95    /// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
96    pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
97        "distilroberta-base/config",
98        "https://cdn.huggingface.co/distilroberta-base-config.json",
99    );
100    /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
101    pub const ROBERTA_QA: (&'static str, &'static str) = (
102        "roberta-qa/config",
103        "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/config.json",
104    );
105    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
106    pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
107        "xlm-roberta-ner-en/config",
108        "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json",
109    );
110    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
111    pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
112        "xlm-roberta-ner-de/config",
113        "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json",
114    );
115    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
116    pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
117        "xlm-roberta-ner-nl/config",
118        "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json",
119    );
120    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
121    pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
122        "xlm-roberta-ner-es/config",
123        "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json",
124    );
125    /// Shared under Apache 2.0 licenseat <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
126    pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
127        "all-distilroberta-v1/config",
128        "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/config.json",
129    );
130    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
131    pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
132        "codeberta-language-id/config",
133        "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/config.json",
134    );
135    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
136    pub const CODEBERT_MLM: (&'static str, &'static str) = (
137        "codebert-mlm/config",
138        "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/config.json",
139    );
140}
141
142impl RobertaVocabResources {
143    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
144    pub const ROBERTA: (&'static str, &'static str) = (
145        "roberta/vocab",
146        "https://huggingface.co/roberta-base/resolve/main/vocab.json",
147    );
148    /// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
149    pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
150        "distilroberta-base/vocab",
151        "https://cdn.huggingface.co/distilroberta-base-vocab.json",
152    );
153    /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
154    pub const ROBERTA_QA: (&'static str, &'static str) = (
155        "roberta-qa/vocab",
156        "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/vocab.json",
157    );
158    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
159    pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
160        "xlm-roberta-ner-en/spiece",
161        "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model",
162    );
163    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
164    pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
165        "xlm-roberta-ner-de/spiece",
166        "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model",
167    );
168    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
169    pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
170        "xlm-roberta-ner-nl/spiece",
171        "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model",
172    );
173    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
174    pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
175        "xlm-roberta-ner-es/spiece",
176        "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model",
177    );
178    /// Shared under Apache 2.0 licenseat <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
179    pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
180        "all-distilroberta-v1/vocab",
181        "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/vocab.json",
182    );
183    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
184    pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
185        "codeberta-language-id/vocab",
186        "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/vocab.json",
187    );
188    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
189    pub const CODEBERT_MLM: (&'static str, &'static str) = (
190        "codebert-mlm/vocab",
191        "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/vocab.json",
192    );
193}
194
195impl RobertaMergesResources {
196    /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
197    pub const ROBERTA: (&'static str, &'static str) = (
198        "roberta/merges",
199        "https://huggingface.co/roberta-base/resolve/main/merges.txt",
200    );
201    /// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
202    pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
203        "distilroberta-base/merges",
204        "https://cdn.huggingface.co/distilroberta-base-merges.txt",
205    );
206    /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
207    pub const ROBERTA_QA: (&'static str, &'static str) = (
208        "roberta-qa/merges",
209        "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/merges.txt",
210    );
211    /// Shared under Apache 2.0 licenseat <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
212    pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
213        "all-distilroberta-v1/merges",
214        "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/merges.txt",
215    );
216    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
217    pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
218        "codeberta-language-id/merges",
219        "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/merges.txt",
220    );
221    /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
222    pub const CODEBERT_MLM: (&'static str, &'static str) = (
223        "codebert-mlm/merges",
224        "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/merges.txt",
225    );
226}
227
/// # RoBERTa language model head
/// Maps encoder hidden states to vocabulary logits (dense -> GELU -> layer norm -> projection + bias).
pub struct RobertaLMHead {
    /// Dense transformation applied to the hidden states before the vocabulary projection.
    dense: nn::Linear,
    /// Projection to vocabulary size, without its own bias (a separate `bias` tensor is added instead).
    decoder: LinearNoBias,
    /// Layer normalization applied after the dense + GELU step.
    layer_norm: nn::LayerNorm,
    /// Output bias added to the decoder projection, one value per vocabulary entry.
    bias: Tensor,
}
234
235impl RobertaLMHead {
236    pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaLMHead
237    where
238        P: Borrow<nn::Path<'p>>,
239    {
240        let p = p.borrow();
241        let dense = nn::linear(
242            p / "dense",
243            config.hidden_size,
244            config.hidden_size,
245            Default::default(),
246        );
247        let layer_norm_config = nn::LayerNormConfig {
248            eps: 1e-12,
249            ..Default::default()
250        };
251        let layer_norm = nn::layer_norm(
252            p / "layer_norm",
253            vec![config.hidden_size],
254            layer_norm_config,
255        );
256        let decoder = linear_no_bias(
257            p / "decoder",
258            config.hidden_size,
259            config.vocab_size,
260            Default::default(),
261        );
262        let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
263
264        RobertaLMHead {
265            dense,
266            decoder,
267            layer_norm,
268            bias,
269        }
270    }
271
272    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
273        (_gelu(&hidden_states.apply(&self.dense)))
274            .apply(&self.layer_norm)
275            .apply(&self.decoder)
276            + &self.bias
277    }
278}
279
/// # RoBERTa model configuration
/// Defines the RoBERTa model architecture (e.g. number of layers, hidden layer size, label mapping...)
/// RoBERTa shares its configuration format with BERT, hence the type alias.
pub type RobertaConfig = BertConfig;
283
/// # RoBERTa for masked language model
/// Base RoBERTa model with a RoBERTa masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
/// It is made of the following blocks:
/// - `roberta`: Base BertModel with RoBERTa embeddings
/// - `lm_head`: RoBERTa LM prediction head
pub struct RobertaForMaskedLM {
    /// Base BERT encoder parameterized with RoBERTa embeddings (built without a pooler).
    roberta: BertModel<RobertaEmbeddings>,
    /// Prediction head mapping hidden states to vocabulary logits.
    lm_head: RobertaLMHead,
}
293
294impl RobertaForMaskedLM {
295    /// Build a new `RobertaForMaskedLM`
296    ///
297    /// # Arguments
298    ///
299    /// * `p` - Variable store path for the root of the RobertaForMaskedLM model
300    /// * `config` - `RobertaConfig` object defining the model architecture and vocab size
301    ///
302    /// # Example
303    ///
304    /// ```no_run
305    /// use rust_bert::roberta::{RobertaConfig, RobertaForMaskedLM};
306    /// use rust_bert::Config;
307    /// use std::path::Path;
308    /// use tch::{nn, Device};
309    ///
310    /// let config_path = Path::new("path/to/config.json");
311    /// let device = Device::Cpu;
312    /// let p = nn::VarStore::new(device);
313    /// let config = RobertaConfig::from_file(config_path);
314    /// let roberta = RobertaForMaskedLM::new(&p.root() / "roberta", &config);
315    /// ```
316    pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMaskedLM
317    where
318        P: Borrow<nn::Path<'p>>,
319    {
320        let p = p.borrow();
321
322        let roberta =
323            BertModel::<RobertaEmbeddings>::new_with_optional_pooler(p / "roberta", config, false);
324        let lm_head = RobertaLMHead::new(p / "lm_head", config);
325
326        RobertaForMaskedLM { roberta, lm_head }
327    }
328
329    #[allow(rustdoc::invalid_html_tags)]
330    /// Forward pass through the model
331    ///
332    /// # Arguments
333    ///
334    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
335    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
336    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
337    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
338    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
339    /// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used in the cross-attention layer as keys and values (query from the decoder).
340    /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used to mask encoder values. Positions with value 0 will be masked.
341    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
342    ///
343    /// # Returns
344    ///
345    /// * `output` - `Tensor` of shape (*batch size*, *num_labels*, *vocab_size*)
346    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
347    /// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
348    ///
349    /// # Example
350    ///
351    /// ```no_run
352    /// # use rust_bert::bert::BertConfig;
353    /// # use tch::{nn, Device, Tensor, no_grad};
354    /// # use rust_bert::Config;
355    /// # use std::path::Path;
356    /// # use tch::kind::Kind::Int64;
357    /// use rust_bert::roberta::RobertaForMaskedLM;
358    /// # let config_path = Path::new("path/to/config.json");
359    /// # let vocab_path = Path::new("path/to/vocab.txt");
360    /// # let device = Device::Cpu;
361    /// # let vs = nn::VarStore::new(device);
362    /// # let config = BertConfig::from_file(config_path);
363    /// # let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
364    /// let (batch_size, sequence_length) = (64, 128);
365    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
366    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
367    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
368    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
369    ///     .expand(&[batch_size, sequence_length], true);
370    ///
371    /// let model_output = no_grad(|| {
372    ///     roberta_model.forward_t(
373    ///         Some(&input_tensor),
374    ///         Some(&mask),
375    ///         Some(&token_type_ids),
376    ///         Some(&position_ids),
377    ///         None,
378    ///         None,
379    ///         None,
380    ///         false,
381    ///     )
382    /// });
383    /// ```
384    pub fn forward_t(
385        &self,
386        input_ids: Option<&Tensor>,
387        mask: Option<&Tensor>,
388        token_type_ids: Option<&Tensor>,
389        position_ids: Option<&Tensor>,
390        input_embeds: Option<&Tensor>,
391        encoder_hidden_states: Option<&Tensor>,
392        encoder_mask: Option<&Tensor>,
393        train: bool,
394    ) -> RobertaMaskedLMOutput {
395        let base_model_output = self
396            .roberta
397            .forward_t(
398                input_ids,
399                mask,
400                token_type_ids,
401                position_ids,
402                input_embeds,
403                encoder_hidden_states,
404                encoder_mask,
405                train,
406            )
407            .unwrap();
408
409        let prediction_scores = self.lm_head.forward(&base_model_output.hidden_state);
410        RobertaMaskedLMOutput {
411            prediction_scores,
412            all_hidden_states: base_model_output.all_hidden_states,
413            all_attentions: base_model_output.all_attentions,
414        }
415    }
416}
417
/// # RoBERTa classification head
/// Two-layer classification head operating on the first-token representation
/// (dropout -> dense -> tanh -> dropout -> output projection).
pub struct RobertaClassificationHead {
    /// Intermediate dense layer (hidden_size -> hidden_size).
    dense: nn::Linear,
    /// Dropout applied before the dense layer and before the output projection.
    dropout: Dropout,
    /// Final projection to the number of labels.
    out_proj: nn::Linear,
}
423
424impl RobertaClassificationHead {
425    pub fn new<'p, P>(p: P, config: &BertConfig) -> Result<RobertaClassificationHead, RustBertError>
426    where
427        P: Borrow<nn::Path<'p>>,
428    {
429        let p = p.borrow();
430        let dense = nn::linear(
431            p / "dense",
432            config.hidden_size,
433            config.hidden_size,
434            Default::default(),
435        );
436        let num_labels = config
437            .id2label
438            .as_ref()
439            .ok_or_else(|| {
440                RustBertError::InvalidConfigurationError(
441                    "num_labels not provided in configuration".to_string(),
442                )
443            })?
444            .len() as i64;
445        let out_proj = nn::linear(
446            p / "out_proj",
447            config.hidden_size,
448            num_labels,
449            Default::default(),
450        );
451        let dropout = Dropout::new(config.hidden_dropout_prob);
452
453        Ok(RobertaClassificationHead {
454            dense,
455            dropout,
456            out_proj,
457        })
458    }
459
460    pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
461        hidden_states
462            .select(1, 0)
463            .apply_t(&self.dropout, train)
464            .apply(&self.dense)
465            .tanh()
466            .apply_t(&self.dropout, train)
467            .apply(&self.out_proj)
468    }
469}
470
/// # RoBERTa for sequence classification
/// Base RoBERTa model with a classifier head to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `roberta`: Base RoBERTa model
/// - `classifier`: RoBERTa classification head made of 2 linear layers
pub struct RobertaForSequenceClassification {
    /// Base BERT encoder parameterized with RoBERTa embeddings (built without a pooler).
    roberta: BertModel<RobertaEmbeddings>,
    /// Two-layer classification head operating on the first-token representation.
    classifier: RobertaClassificationHead,
}
480
481impl RobertaForSequenceClassification {
482    /// Build a new `RobertaForSequenceClassification`
483    ///
484    /// # Arguments
485    ///
486    /// * `p` - Variable store path for the root of the RobertaForMaskedLM model
487    /// * `config` - `RobertaConfig` object defining the model architecture and vocab size
488    ///
489    /// # Example
490    ///
491    /// ```no_run
492    /// use rust_bert::roberta::{RobertaConfig, RobertaForSequenceClassification};
493    /// use rust_bert::Config;
494    /// use std::path::Path;
495    /// use tch::{nn, Device};
496    ///
497    /// let config_path = Path::new("path/to/config.json");
498    /// let device = Device::Cpu;
499    /// let p = nn::VarStore::new(device);
500    /// let config = RobertaConfig::from_file(config_path);
501    /// let roberta = RobertaForSequenceClassification::new(&p.root() / "roberta", &config).unwrap();
502    /// ```
503    pub fn new<'p, P>(
504        p: P,
505        config: &BertConfig,
506    ) -> Result<RobertaForSequenceClassification, RustBertError>
507    where
508        P: Borrow<nn::Path<'p>>,
509    {
510        let p = p.borrow();
511        let roberta =
512            BertModel::<RobertaEmbeddings>::new_with_optional_pooler(p / "roberta", config, false);
513        let classifier = RobertaClassificationHead::new(p / "classifier", config)?;
514
515        Ok(RobertaForSequenceClassification {
516            roberta,
517            classifier,
518        })
519    }
520
521    #[allow(rustdoc::invalid_html_tags)]
522    /// Forward pass through the model
523    ///
524    /// # Arguments
525    ///
526    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
527    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
528    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
529    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
530    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
531    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
532    ///
533    /// # Returns
534    ///
535    /// * `RobertaSequenceClassificationOutput` containing:
536    ///   - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
537    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
538    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
539    ///
540    /// # Example
541    ///
542    /// ```no_run
543    /// # use rust_bert::bert::BertConfig;
544    /// # use tch::{nn, Device, Tensor, no_grad};
545    /// # use rust_bert::Config;
546    /// # use std::path::Path;
547    /// # use tch::kind::Kind::Int64;
548    /// use rust_bert::roberta::RobertaForSequenceClassification;
549    /// # let config_path = Path::new("path/to/config.json");
550    /// # let vocab_path = Path::new("path/to/vocab.txt");
551    /// # let device = Device::Cpu;
552    /// # let vs = nn::VarStore::new(device);
553    /// # let config = BertConfig::from_file(config_path);
554    /// # let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config).unwrap();;
555    /// let (batch_size, sequence_length) = (64, 128);
556    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
557    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
558    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
559    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
560    ///     .expand(&[batch_size, sequence_length], true);
561    ///
562    /// let model_output = no_grad(|| {
563    ///     roberta_model.forward_t(
564    ///         Some(&input_tensor),
565    ///         Some(&mask),
566    ///         Some(&token_type_ids),
567    ///         Some(&position_ids),
568    ///         None,
569    ///         false,
570    ///     )
571    /// });
572    /// ```
573    pub fn forward_t(
574        &self,
575        input_ids: Option<&Tensor>,
576        mask: Option<&Tensor>,
577        token_type_ids: Option<&Tensor>,
578        position_ids: Option<&Tensor>,
579        input_embeds: Option<&Tensor>,
580        train: bool,
581    ) -> RobertaSequenceClassificationOutput {
582        let base_model_output = self
583            .roberta
584            .forward_t(
585                input_ids,
586                mask,
587                token_type_ids,
588                position_ids,
589                input_embeds,
590                None,
591                None,
592                train,
593            )
594            .unwrap();
595
596        let logits = self
597            .classifier
598            .forward_t(&base_model_output.hidden_state, train);
599        RobertaSequenceClassificationOutput {
600            logits,
601            all_hidden_states: base_model_output.all_hidden_states,
602            all_attentions: base_model_output.all_attentions,
603        }
604    }
605}
606
#[allow(rustdoc::invalid_html_tags)]
/// # RoBERTa for multiple choices
/// Multiple choices model using a RoBERTa base model and a linear classifier.
/// Input should be in the form `<s> Context </s> Possible choice </s>`. The choice is made along the batch axis,
/// assuming all elements of the batch are alternatives to be chosen from for a given context.
/// It is made of the following blocks:
/// - `roberta`: Base RoBERTa model
/// - `classifier`: Linear layer for multiple choices
pub struct RobertaForMultipleChoice {
    /// Base BERT encoder parameterized with RoBERTa embeddings.
    roberta: BertModel<RobertaEmbeddings>,
    /// Dropout applied before the final classifier.
    dropout: Dropout,
    /// Linear layer projecting to a single score per alternative.
    classifier: nn::Linear,
}
620
621impl RobertaForMultipleChoice {
622    /// Build a new `RobertaForMultipleChoice`
623    ///
624    /// # Arguments
625    ///
626    /// * `p` - Variable store path for the root of the RobertaForMaskedLM model
627    /// * `config` - `RobertaConfig` object defining the model architecture and vocab size
628    ///
629    /// # Example
630    ///
631    /// ```no_run
632    /// use rust_bert::roberta::{RobertaConfig, RobertaForMultipleChoice};
633    /// use rust_bert::Config;
634    /// use std::path::Path;
635    /// use tch::{nn, Device};
636    ///
637    /// let config_path = Path::new("path/to/config.json");
638    /// let device = Device::Cpu;
639    /// let p = nn::VarStore::new(device);
640    /// let config = RobertaConfig::from_file(config_path);
641    /// let roberta = RobertaForMultipleChoice::new(&p.root() / "roberta", &config);
642    /// ```
643    pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMultipleChoice
644    where
645        P: Borrow<nn::Path<'p>>,
646    {
647        let p = p.borrow();
648        let roberta = BertModel::<RobertaEmbeddings>::new(p / "roberta", config);
649        let dropout = Dropout::new(config.hidden_dropout_prob);
650        let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
651
652        RobertaForMultipleChoice {
653            roberta,
654            dropout,
655            classifier,
656        }
657    }
658
659    #[allow(rustdoc::invalid_html_tags)]
660    /// Forward pass through the model
661    ///
662    /// # Arguments
663    ///
664    /// * `input_ids` - Input tensor of shape (*batch size*, *sequence_length*).
665    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
666    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
667    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
668    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
669    ///
670    /// # Returns
671    ///
672    /// * `RobertaSequenceClassificationOutput` containing:
673    ///   - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
674    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
675    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
676    ///
677    /// # Example
678    ///
679    /// ```no_run
680    /// # use rust_bert::bert::BertConfig;
681    /// # use tch::{nn, Device, Tensor, no_grad};
682    /// # use rust_bert::Config;
683    /// # use std::path::Path;
684    /// # use tch::kind::Kind::Int64;
685    /// use rust_bert::roberta::RobertaForMultipleChoice;
686    /// # let config_path = Path::new("path/to/config.json");
687    /// # let vocab_path = Path::new("path/to/vocab.txt");
688    /// # let device = Device::Cpu;
689    /// # let vs = nn::VarStore::new(device);
690    /// # let config = BertConfig::from_file(config_path);
691    /// # let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
692    /// let (num_choices, sequence_length) = (3, 128);
693    /// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
694    /// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
695    /// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
696    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
697    ///     .expand(&[num_choices, sequence_length], true);
698    ///
699    /// let model_output = no_grad(|| {
700    ///     roberta_model.forward_t(
701    ///         &input_tensor,
702    ///         Some(&mask),
703    ///         Some(&token_type_ids),
704    ///         Some(&position_ids),
705    ///         false,
706    ///     )
707    /// });
708    /// ```
709    pub fn forward_t(
710        &self,
711        input_ids: &Tensor,
712        mask: Option<&Tensor>,
713        token_type_ids: Option<&Tensor>,
714        position_ids: Option<&Tensor>,
715        train: bool,
716    ) -> RobertaSequenceClassificationOutput {
717        let num_choices = input_ids.size()[1];
718
719        let input_ids = Some(input_ids.view((-1, *input_ids.size().last().unwrap())));
720        let mask = mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
721        let token_type_ids =
722            token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
723        let position_ids =
724            position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
725
726        let base_model_output = self
727            .roberta
728            .forward_t(
729                input_ids.as_ref(),
730                mask.as_ref(),
731                token_type_ids.as_ref(),
732                position_ids.as_ref(),
733                None,
734                None,
735                None,
736                train,
737            )
738            .unwrap();
739
740        let logits = base_model_output
741            .pooled_output
742            .unwrap()
743            .apply_t(&self.dropout, train)
744            .apply(&self.classifier)
745            .view((-1, num_choices));
746        RobertaSequenceClassificationOutput {
747            logits,
748            all_hidden_states: base_model_output.all_hidden_states,
749            all_attentions: base_model_output.all_attentions,
750        }
751    }
752}
753
/// # RoBERTa for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of bpe tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `roberta`: Base RoBERTa model
/// - `classifier`: Linear layer for token classification
pub struct RobertaForTokenClassification {
    /// Base RoBERTa encoder, built without a pooler (only token-level states are used)
    roberta: BertModel<RobertaEmbeddings>,
    /// Dropout applied to the sequence output before classification
    dropout: Dropout,
    /// Linear layer mapping each token's hidden state to per-label logits
    classifier: nn::Linear,
}
765
766impl RobertaForTokenClassification {
767    /// Build a new `RobertaForTokenClassification`
768    ///
769    /// # Arguments
770    ///
771    /// * `p` - Variable store path for the root of the RobertaForMaskedLM model
772    /// * `config` - `RobertaConfig` object defining the model architecture and vocab size
773    ///
774    /// # Example
775    ///
776    /// ```no_run
777    /// use rust_bert::roberta::{RobertaConfig, RobertaForMultipleChoice};
778    /// use rust_bert::Config;
779    /// use std::path::Path;
780    /// use tch::{nn, Device};
781    ///
782    /// let config_path = Path::new("path/to/config.json");
783    /// let device = Device::Cpu;
784    /// let p = nn::VarStore::new(device);
785    /// let config = RobertaConfig::from_file(config_path);
786    /// let roberta = RobertaForMultipleChoice::new(&p.root() / "roberta", &config);
787    /// ```
788    pub fn new<'p, P>(
789        p: P,
790        config: &BertConfig,
791    ) -> Result<RobertaForTokenClassification, RustBertError>
792    where
793        P: Borrow<nn::Path<'p>>,
794    {
795        let p = p.borrow();
796        let roberta =
797            BertModel::<RobertaEmbeddings>::new_with_optional_pooler(p / "roberta", config, false);
798        let dropout = Dropout::new(config.hidden_dropout_prob);
799        let num_labels = config
800            .id2label
801            .as_ref()
802            .ok_or_else(|| {
803                RustBertError::InvalidConfigurationError(
804                    "num_labels not provided in configuration".to_string(),
805                )
806            })?
807            .len() as i64;
808        let classifier = nn::linear(
809            p / "classifier",
810            config.hidden_size,
811            num_labels,
812            Default::default(),
813        );
814
815        Ok(RobertaForTokenClassification {
816            roberta,
817            dropout,
818            classifier,
819        })
820    }
821
822    #[allow(rustdoc::invalid_html_tags)]
823    /// Forward pass through the model
824    ///
825    /// # Arguments
826    ///
827    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
828    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
829    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
830    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
831    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
832    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
833    ///
834    /// # Returns
835    ///
836    /// * `RobertaTokenClassificationOutput` containing:
837    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
838    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
839    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
840    ///
841    /// # Example
842    ///
843    /// ```no_run
844    /// # use rust_bert::bert::BertConfig;
845    /// # use tch::{nn, Device, Tensor, no_grad};
846    /// # use rust_bert::Config;
847    /// # use std::path::Path;
848    /// # use tch::kind::Kind::Int64;
849    /// use rust_bert::roberta::RobertaForTokenClassification;
850    /// # let config_path = Path::new("path/to/config.json");
851    /// # let vocab_path = Path::new("path/to/vocab.txt");
852    /// # let device = Device::Cpu;
853    /// # let vs = nn::VarStore::new(device);
854    /// # let config = BertConfig::from_file(config_path);
855    /// # let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config).unwrap();
856    /// let (batch_size, sequence_length) = (64, 128);
857    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
858    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
859    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
860    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
861    ///     .expand(&[batch_size, sequence_length], true);
862    ///
863    /// let model_output = no_grad(|| {
864    ///     roberta_model.forward_t(
865    ///         Some(&input_tensor),
866    ///         Some(&mask),
867    ///         Some(&token_type_ids),
868    ///         Some(&position_ids),
869    ///         None,
870    ///         false,
871    ///     )
872    /// });
873    /// ```
874    pub fn forward_t(
875        &self,
876        input_ids: Option<&Tensor>,
877        mask: Option<&Tensor>,
878        token_type_ids: Option<&Tensor>,
879        position_ids: Option<&Tensor>,
880        input_embeds: Option<&Tensor>,
881        train: bool,
882    ) -> RobertaTokenClassificationOutput {
883        let base_model_output = self
884            .roberta
885            .forward_t(
886                input_ids,
887                mask,
888                token_type_ids,
889                position_ids,
890                input_embeds,
891                None,
892                None,
893                train,
894            )
895            .unwrap();
896
897        let logits = base_model_output
898            .hidden_state
899            .apply_t(&self.dropout, train)
900            .apply(&self.classifier);
901
902        RobertaTokenClassificationOutput {
903            logits,
904            all_hidden_states: base_model_output.all_hidden_states,
905            all_attentions: base_model_output.all_attentions,
906        }
907    }
908}
909
/// # RoBERTa for question answering
/// Extractive question-answering model based on a RoBERTa language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `roberta`: Base RoBERTa model
/// - `qa_outputs`: Linear layer for question answering
pub struct RobertaForQuestionAnswering {
    /// Base RoBERTa encoder, built without a pooler (only token-level states are used)
    roberta: BertModel<RobertaEmbeddings>,
    /// Linear layer projecting each token's hidden state to a (start, end) logit pair
    qa_outputs: nn::Linear,
}
921
922impl RobertaForQuestionAnswering {
923    /// Build a new `RobertaForQuestionAnswering`
924    ///
925    /// # Arguments
926    ///
927    /// * `p` - Variable store path for the root of the RobertaForMaskedLM model
928    /// * `config` - `RobertaConfig` object defining the model architecture and vocab size
929    ///
930    /// # Example
931    ///
932    /// ```no_run
933    /// use rust_bert::roberta::{RobertaConfig, RobertaForQuestionAnswering};
934    /// use rust_bert::Config;
935    /// use std::path::Path;
936    /// use tch::{nn, Device};
937    ///
938    /// let config_path = Path::new("path/to/config.json");
939    /// let device = Device::Cpu;
940    /// let p = nn::VarStore::new(device);
941    /// let config = RobertaConfig::from_file(config_path);
942    /// let roberta = RobertaForQuestionAnswering::new(&p.root() / "roberta", &config);
943    /// ```
944    pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForQuestionAnswering
945    where
946        P: Borrow<nn::Path<'p>>,
947    {
948        let p = p.borrow();
949        let roberta =
950            BertModel::<RobertaEmbeddings>::new_with_optional_pooler(p / "roberta", config, false);
951        let num_labels = 2;
952        let qa_outputs = nn::linear(
953            p / "qa_outputs",
954            config.hidden_size,
955            num_labels,
956            Default::default(),
957        );
958
959        RobertaForQuestionAnswering {
960            roberta,
961            qa_outputs,
962        }
963    }
964
965    #[allow(rustdoc::invalid_html_tags)]
966    /// Forward pass through the model
967    ///
968    /// # Arguments
969    ///
970    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
971    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
972    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
973    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
974    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
975    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
976    ///
977    /// # Returns
978    ///
979    /// * `RobertaQuestionAnsweringOutput` containing:
980    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
981    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
982    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
983    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
984    ///
985    /// # Example
986    ///
987    /// ```no_run
988    /// # use rust_bert::bert::BertConfig;
989    /// # use tch::{nn, Device, Tensor, no_grad};
990    /// # use rust_bert::Config;
991    /// # use std::path::Path;
992    /// # use tch::kind::Kind::Int64;
993    /// use rust_bert::roberta::RobertaForQuestionAnswering;
994    /// # let config_path = Path::new("path/to/config.json");
995    /// # let vocab_path = Path::new("path/to/vocab.txt");
996    /// # let device = Device::Cpu;
997    /// # let vs = nn::VarStore::new(device);
998    /// # let config = BertConfig::from_file(config_path);
999    /// # let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
1000    /// let (batch_size, sequence_length) = (64, 128);
1001    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1002    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1003    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1004    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1005    ///     .expand(&[batch_size, sequence_length], true);
1006    ///
1007    /// let model_output = no_grad(|| {
1008    ///     roberta_model.forward_t(
1009    ///         Some(&input_tensor),
1010    ///         Some(&mask),
1011    ///         Some(&token_type_ids),
1012    ///         Some(&position_ids),
1013    ///         None,
1014    ///         false,
1015    ///     )
1016    /// });
1017    /// ```
1018    pub fn forward_t(
1019        &self,
1020        input_ids: Option<&Tensor>,
1021        mask: Option<&Tensor>,
1022        token_type_ids: Option<&Tensor>,
1023        position_ids: Option<&Tensor>,
1024        input_embeds: Option<&Tensor>,
1025        train: bool,
1026    ) -> RobertaQuestionAnsweringOutput {
1027        let base_model_output = self
1028            .roberta
1029            .forward_t(
1030                input_ids,
1031                mask,
1032                token_type_ids,
1033                position_ids,
1034                input_embeds,
1035                None,
1036                None,
1037                train,
1038            )
1039            .unwrap();
1040
1041        let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
1042        let logits = sequence_output.split(1, -1);
1043        let (start_logits, end_logits) = (&logits[0], &logits[1]);
1044        let start_logits = start_logits.squeeze_dim(-1);
1045        let end_logits = end_logits.squeeze_dim(-1);
1046
1047        RobertaQuestionAnsweringOutput {
1048            start_logits,
1049            end_logits,
1050            all_hidden_states: base_model_output.all_hidden_states,
1051            all_attentions: base_model_output.all_attentions,
1052        }
1053    }
1054}
1055
/// # RoBERTa for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
/// This is the base `BertModel` parameterized with RoBERTa embeddings; no task head is added.
pub type RobertaForSentenceEmbeddings = BertModel<RobertaEmbeddings>;
1059
/// Container for the RoBERTa masked LM model output.
pub struct RobertaMaskedLMOutput {
    /// Logits for each vocabulary item at each sequence position
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers, if requested by the caller
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if requested by the caller
    pub all_attentions: Option<Vec<Tensor>>,
}
1069
/// Container for the RoBERTa sequence classification model output.
/// Also reused by the multiple-choice model, where the logits hold one score per alternative.
pub struct RobertaSequenceClassificationOutput {
    /// Logits for each input (sequence) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, if requested by the caller
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if requested by the caller
    pub all_attentions: Option<Vec<Tensor>>,
}
1079
/// Container for the RoBERTa token classification model output.
pub struct RobertaTokenClassificationOutput {
    /// Logits for each sequence item (token) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, if requested by the caller
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if requested by the caller
    pub all_attentions: Option<Vec<Tensor>>,
}
1089
/// Container for the RoBERTa question answering model output.
pub struct RobertaQuestionAnsweringOutput {
    /// Logits for the start position of the answer, one per token of each input sequence
    pub start_logits: Tensor,
    /// Logits for the end position of the answer, one per token of each input sequence
    pub end_logits: Tensor,
    /// Hidden states for all intermediate layers, if requested by the caller
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if requested by the caller
    pub all_attentions: Option<Vec<Tensor>>,
}