// rust_bert/models/roberta/roberta_model.rs
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14use crate::bert::{BertConfig, BertModel};
15use crate::common::activations::_gelu;
16use crate::common::dropout::Dropout;
17use crate::common::linear::{linear_no_bias, LinearNoBias};
18use crate::roberta::embeddings::RobertaEmbeddings;
19use crate::RustBertError;
20use std::borrow::Borrow;
21use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
22use tch::{nn, Tensor};
23
/// # RoBERTa Pretrained model weight files
///
/// Zero-sized marker type; its associated constants (see `impl` below) name
/// downloadable pretrained weight resources as `(local name, remote URL)` pairs.
pub struct RobertaModelResources;

/// # RoBERTa Pretrained model config files
///
/// Zero-sized marker type listing pretrained configuration resources.
pub struct RobertaConfigResources;

/// # RoBERTa Pretrained model vocab files
///
/// Zero-sized marker type listing pretrained vocabulary resources.
pub struct RobertaVocabResources;

/// # RoBERTa Pretrained model merges files
///
/// Zero-sized marker type listing pretrained BPE merges resources.
pub struct RobertaMergesResources;
35
36impl RobertaModelResources {
37 /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
38 pub const ROBERTA: (&'static str, &'static str) = (
39 "roberta/model",
40 "https://huggingface.co/roberta-base/resolve/main/rust_model.ot",
41 );
42 /// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
43 pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
44 "distilroberta-base/model",
45 "https://huggingface.co/distilroberta-base/resolve/main/rust_model.ot",
46 );
47 /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
48 pub const ROBERTA_QA: (&'static str, &'static str) = (
49 "roberta-qa/model",
50 "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/rust_model.ot",
51 );
52 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
53 pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
54 "xlm-roberta-ner-en/model",
55 "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/rust_model.ot",
56 );
57 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
58 pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
59 "xlm-roberta-ner-de/model",
60 "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/rust_model.ot",
61 );
62 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
63 pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
64 "xlm-roberta-ner-nl/model",
65 "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/rust_model.ot",
66 );
67 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
68 pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
69 "xlm-roberta-ner-es/model",
70 "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/rust_model.ot",
71 );
72 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
73 pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
74 "all-distilroberta-v1/model",
75 "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/rust_model.ot",
76 );
77 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
78 pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
79 "codeberta-language-id/model",
80 "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/rust_model.ot",
81 );
82 /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
83 pub const CODEBERT_MLM: (&'static str, &'static str) = (
84 "codebert-mlm/model",
85 "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/rust_model.ot",
86 );
87}
88
89impl RobertaConfigResources {
90 /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
91 pub const ROBERTA: (&'static str, &'static str) = (
92 "roberta/config",
93 "https://huggingface.co/roberta-base/resolve/main/config.json",
94 );
95 /// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
96 pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
97 "distilroberta-base/config",
98 "https://cdn.huggingface.co/distilroberta-base-config.json",
99 );
100 /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
101 pub const ROBERTA_QA: (&'static str, &'static str) = (
102 "roberta-qa/config",
103 "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/config.json",
104 );
105 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
106 pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
107 "xlm-roberta-ner-en/config",
108 "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json",
109 );
110 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
111 pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
112 "xlm-roberta-ner-de/config",
113 "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json",
114 );
115 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
116 pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
117 "xlm-roberta-ner-nl/config",
118 "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json",
119 );
120 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
121 pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
122 "xlm-roberta-ner-es/config",
123 "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json",
124 );
125 /// Shared under Apache 2.0 licenseat <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
126 pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
127 "all-distilroberta-v1/config",
128 "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/config.json",
129 );
130 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
131 pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
132 "codeberta-language-id/config",
133 "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/config.json",
134 );
135 /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
136 pub const CODEBERT_MLM: (&'static str, &'static str) = (
137 "codebert-mlm/config",
138 "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/config.json",
139 );
140}
141
142impl RobertaVocabResources {
143 /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
144 pub const ROBERTA: (&'static str, &'static str) = (
145 "roberta/vocab",
146 "https://huggingface.co/roberta-base/resolve/main/vocab.json",
147 );
148 /// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
149 pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
150 "distilroberta-base/vocab",
151 "https://cdn.huggingface.co/distilroberta-base-vocab.json",
152 );
153 /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
154 pub const ROBERTA_QA: (&'static str, &'static str) = (
155 "roberta-qa/vocab",
156 "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/vocab.json",
157 );
158 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
159 pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
160 "xlm-roberta-ner-en/spiece",
161 "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model",
162 );
163 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
164 pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
165 "xlm-roberta-ner-de/spiece",
166 "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model",
167 );
168 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
169 pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
170 "xlm-roberta-ner-nl/spiece",
171 "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model",
172 );
173 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
174 pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
175 "xlm-roberta-ner-es/spiece",
176 "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model",
177 );
178 /// Shared under Apache 2.0 licenseat <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
179 pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
180 "all-distilroberta-v1/vocab",
181 "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/vocab.json",
182 );
183 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
184 pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
185 "codeberta-language-id/vocab",
186 "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/vocab.json",
187 );
188 /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
189 pub const CODEBERT_MLM: (&'static str, &'static str) = (
190 "codebert-mlm/vocab",
191 "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/vocab.json",
192 );
193}
194
195impl RobertaMergesResources {
196 /// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
197 pub const ROBERTA: (&'static str, &'static str) = (
198 "roberta/merges",
199 "https://huggingface.co/roberta-base/resolve/main/merges.txt",
200 );
201 /// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
202 pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
203 "distilroberta-base/merges",
204 "https://cdn.huggingface.co/distilroberta-base-merges.txt",
205 );
206 /// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
207 pub const ROBERTA_QA: (&'static str, &'static str) = (
208 "roberta-qa/merges",
209 "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/merges.txt",
210 );
211 /// Shared under Apache 2.0 licenseat <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
212 pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
213 "all-distilroberta-v1/merges",
214 "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/merges.txt",
215 );
216 /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
217 pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
218 "codeberta-language-id/merges",
219 "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/merges.txt",
220 );
221 /// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
222 pub const CODEBERT_MLM: (&'static str, &'static str) = (
223 "codebert-mlm/merges",
224 "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/merges.txt",
225 );
226}
227
228pub struct RobertaLMHead {
229 dense: nn::Linear,
230 decoder: LinearNoBias,
231 layer_norm: nn::LayerNorm,
232 bias: Tensor,
233}
234
235impl RobertaLMHead {
236 pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaLMHead
237 where
238 P: Borrow<nn::Path<'p>>,
239 {
240 let p = p.borrow();
241 let dense = nn::linear(
242 p / "dense",
243 config.hidden_size,
244 config.hidden_size,
245 Default::default(),
246 );
247 let layer_norm_config = nn::LayerNormConfig {
248 eps: 1e-12,
249 ..Default::default()
250 };
251 let layer_norm = nn::layer_norm(
252 p / "layer_norm",
253 vec![config.hidden_size],
254 layer_norm_config,
255 );
256 let decoder = linear_no_bias(
257 p / "decoder",
258 config.hidden_size,
259 config.vocab_size,
260 Default::default(),
261 );
262 let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
263
264 RobertaLMHead {
265 dense,
266 decoder,
267 layer_norm,
268 bias,
269 }
270 }
271
272 pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
273 (_gelu(&hidden_states.apply(&self.dense)))
274 .apply(&self.layer_norm)
275 .apply(&self.decoder)
276 + &self.bias
277 }
278}
279
280/// # RoBERTa model configuration
281/// Defines the RoBERTa model architecture (e.g. number of layers, hidden layer size, label mapping...)
282pub type RobertaConfig = BertConfig;
283
284/// # RoBERTa for masked language model
285/// Base RoBERTa model with a RoBERTa masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
286/// It is made of the following blocks:
287/// - `roberta`: Base BertModel with RoBERTa embeddings
288/// - `lm_head`: RoBERTa LM prediction head
289pub struct RobertaForMaskedLM {
290 roberta: BertModel<RobertaEmbeddings>,
291 lm_head: RobertaLMHead,
292}
293
294impl RobertaForMaskedLM {
295 /// Build a new `RobertaForMaskedLM`
296 ///
297 /// # Arguments
298 ///
299 /// * `p` - Variable store path for the root of the RobertaForMaskedLM model
300 /// * `config` - `RobertaConfig` object defining the model architecture and vocab size
301 ///
302 /// # Example
303 ///
304 /// ```no_run
305 /// use rust_bert::roberta::{RobertaConfig, RobertaForMaskedLM};
306 /// use rust_bert::Config;
307 /// use std::path::Path;
308 /// use tch::{nn, Device};
309 ///
310 /// let config_path = Path::new("path/to/config.json");
311 /// let device = Device::Cpu;
312 /// let p = nn::VarStore::new(device);
313 /// let config = RobertaConfig::from_file(config_path);
314 /// let roberta = RobertaForMaskedLM::new(&p.root() / "roberta", &config);
315 /// ```
316 pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMaskedLM
317 where
318 P: Borrow<nn::Path<'p>>,
319 {
320 let p = p.borrow();
321
322 let roberta =
323 BertModel::<RobertaEmbeddings>::new_with_optional_pooler(p / "roberta", config, false);
324 let lm_head = RobertaLMHead::new(p / "lm_head", config);
325
326 RobertaForMaskedLM { roberta, lm_head }
327 }
328
329 #[allow(rustdoc::invalid_html_tags)]
330 /// Forward pass through the model
331 ///
332 /// # Arguments
333 ///
334 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
335 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
336 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
337 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
338 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
339 /// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used in the cross-attention layer as keys and values (query from the decoder).
340 /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used to mask encoder values. Positions with value 0 will be masked.
341 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
342 ///
343 /// # Returns
344 ///
345 /// * `output` - `Tensor` of shape (*batch size*, *num_labels*, *vocab_size*)
346 /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
347 /// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
348 ///
349 /// # Example
350 ///
351 /// ```no_run
352 /// # use rust_bert::bert::BertConfig;
353 /// # use tch::{nn, Device, Tensor, no_grad};
354 /// # use rust_bert::Config;
355 /// # use std::path::Path;
356 /// # use tch::kind::Kind::Int64;
357 /// use rust_bert::roberta::RobertaForMaskedLM;
358 /// # let config_path = Path::new("path/to/config.json");
359 /// # let vocab_path = Path::new("path/to/vocab.txt");
360 /// # let device = Device::Cpu;
361 /// # let vs = nn::VarStore::new(device);
362 /// # let config = BertConfig::from_file(config_path);
363 /// # let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
364 /// let (batch_size, sequence_length) = (64, 128);
365 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
366 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
367 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
368 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
369 /// .expand(&[batch_size, sequence_length], true);
370 ///
371 /// let model_output = no_grad(|| {
372 /// roberta_model.forward_t(
373 /// Some(&input_tensor),
374 /// Some(&mask),
375 /// Some(&token_type_ids),
376 /// Some(&position_ids),
377 /// None,
378 /// None,
379 /// None,
380 /// false,
381 /// )
382 /// });
383 /// ```
384 pub fn forward_t(
385 &self,
386 input_ids: Option<&Tensor>,
387 mask: Option<&Tensor>,
388 token_type_ids: Option<&Tensor>,
389 position_ids: Option<&Tensor>,
390 input_embeds: Option<&Tensor>,
391 encoder_hidden_states: Option<&Tensor>,
392 encoder_mask: Option<&Tensor>,
393 train: bool,
394 ) -> RobertaMaskedLMOutput {
395 let base_model_output = self
396 .roberta
397 .forward_t(
398 input_ids,
399 mask,
400 token_type_ids,
401 position_ids,
402 input_embeds,
403 encoder_hidden_states,
404 encoder_mask,
405 train,
406 )
407 .unwrap();
408
409 let prediction_scores = self.lm_head.forward(&base_model_output.hidden_state);
410 RobertaMaskedLMOutput {
411 prediction_scores,
412 all_hidden_states: base_model_output.all_hidden_states,
413 all_attentions: base_model_output.all_attentions,
414 }
415 }
416}
417
418pub struct RobertaClassificationHead {
419 dense: nn::Linear,
420 dropout: Dropout,
421 out_proj: nn::Linear,
422}
423
424impl RobertaClassificationHead {
425 pub fn new<'p, P>(p: P, config: &BertConfig) -> Result<RobertaClassificationHead, RustBertError>
426 where
427 P: Borrow<nn::Path<'p>>,
428 {
429 let p = p.borrow();
430 let dense = nn::linear(
431 p / "dense",
432 config.hidden_size,
433 config.hidden_size,
434 Default::default(),
435 );
436 let num_labels = config
437 .id2label
438 .as_ref()
439 .ok_or_else(|| {
440 RustBertError::InvalidConfigurationError(
441 "num_labels not provided in configuration".to_string(),
442 )
443 })?
444 .len() as i64;
445 let out_proj = nn::linear(
446 p / "out_proj",
447 config.hidden_size,
448 num_labels,
449 Default::default(),
450 );
451 let dropout = Dropout::new(config.hidden_dropout_prob);
452
453 Ok(RobertaClassificationHead {
454 dense,
455 dropout,
456 out_proj,
457 })
458 }
459
460 pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
461 hidden_states
462 .select(1, 0)
463 .apply_t(&self.dropout, train)
464 .apply(&self.dense)
465 .tanh()
466 .apply_t(&self.dropout, train)
467 .apply(&self.out_proj)
468 }
469}
470
471/// # RoBERTa for sequence classification
472/// Base RoBERTa model with a classifier head to perform sentence or document-level classification
473/// It is made of the following blocks:
474/// - `roberta`: Base RoBERTa model
475/// - `classifier`: RoBERTa classification head made of 2 linear layers
476pub struct RobertaForSequenceClassification {
477 roberta: BertModel<RobertaEmbeddings>,
478 classifier: RobertaClassificationHead,
479}
480
481impl RobertaForSequenceClassification {
482 /// Build a new `RobertaForSequenceClassification`
483 ///
484 /// # Arguments
485 ///
486 /// * `p` - Variable store path for the root of the RobertaForMaskedLM model
487 /// * `config` - `RobertaConfig` object defining the model architecture and vocab size
488 ///
489 /// # Example
490 ///
491 /// ```no_run
492 /// use rust_bert::roberta::{RobertaConfig, RobertaForSequenceClassification};
493 /// use rust_bert::Config;
494 /// use std::path::Path;
495 /// use tch::{nn, Device};
496 ///
497 /// let config_path = Path::new("path/to/config.json");
498 /// let device = Device::Cpu;
499 /// let p = nn::VarStore::new(device);
500 /// let config = RobertaConfig::from_file(config_path);
501 /// let roberta = RobertaForSequenceClassification::new(&p.root() / "roberta", &config).unwrap();
502 /// ```
503 pub fn new<'p, P>(
504 p: P,
505 config: &BertConfig,
506 ) -> Result<RobertaForSequenceClassification, RustBertError>
507 where
508 P: Borrow<nn::Path<'p>>,
509 {
510 let p = p.borrow();
511 let roberta =
512 BertModel::<RobertaEmbeddings>::new_with_optional_pooler(p / "roberta", config, false);
513 let classifier = RobertaClassificationHead::new(p / "classifier", config)?;
514
515 Ok(RobertaForSequenceClassification {
516 roberta,
517 classifier,
518 })
519 }
520
521 #[allow(rustdoc::invalid_html_tags)]
522 /// Forward pass through the model
523 ///
524 /// # Arguments
525 ///
526 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
527 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
528 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
529 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
530 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
531 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
532 ///
533 /// # Returns
534 ///
535 /// * `RobertaSequenceClassificationOutput` containing:
536 /// - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
537 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
538 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
539 ///
540 /// # Example
541 ///
542 /// ```no_run
543 /// # use rust_bert::bert::BertConfig;
544 /// # use tch::{nn, Device, Tensor, no_grad};
545 /// # use rust_bert::Config;
546 /// # use std::path::Path;
547 /// # use tch::kind::Kind::Int64;
548 /// use rust_bert::roberta::RobertaForSequenceClassification;
549 /// # let config_path = Path::new("path/to/config.json");
550 /// # let vocab_path = Path::new("path/to/vocab.txt");
551 /// # let device = Device::Cpu;
552 /// # let vs = nn::VarStore::new(device);
553 /// # let config = BertConfig::from_file(config_path);
554 /// # let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config).unwrap();;
555 /// let (batch_size, sequence_length) = (64, 128);
556 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
557 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
558 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
559 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
560 /// .expand(&[batch_size, sequence_length], true);
561 ///
562 /// let model_output = no_grad(|| {
563 /// roberta_model.forward_t(
564 /// Some(&input_tensor),
565 /// Some(&mask),
566 /// Some(&token_type_ids),
567 /// Some(&position_ids),
568 /// None,
569 /// false,
570 /// )
571 /// });
572 /// ```
573 pub fn forward_t(
574 &self,
575 input_ids: Option<&Tensor>,
576 mask: Option<&Tensor>,
577 token_type_ids: Option<&Tensor>,
578 position_ids: Option<&Tensor>,
579 input_embeds: Option<&Tensor>,
580 train: bool,
581 ) -> RobertaSequenceClassificationOutput {
582 let base_model_output = self
583 .roberta
584 .forward_t(
585 input_ids,
586 mask,
587 token_type_ids,
588 position_ids,
589 input_embeds,
590 None,
591 None,
592 train,
593 )
594 .unwrap();
595
596 let logits = self
597 .classifier
598 .forward_t(&base_model_output.hidden_state, train);
599 RobertaSequenceClassificationOutput {
600 logits,
601 all_hidden_states: base_model_output.all_hidden_states,
602 all_attentions: base_model_output.all_attentions,
603 }
604 }
605}
606
607#[allow(rustdoc::invalid_html_tags)]
608/// # RoBERTa for multiple choices
609/// Multiple choices model using a RoBERTa base model and a linear classifier.
610/// Input should be in the form `<s> Context </s> Possible choice </s>`. The choice is made along the batch axis,
611/// assuming all elements of the batch are alternatives to be chosen from for a given context.
612/// It is made of the following blocks:
613/// - `roberta`: Base RoBERTa model
614/// - `classifier`: Linear layer for multiple choices
615pub struct RobertaForMultipleChoice {
616 roberta: BertModel<RobertaEmbeddings>,
617 dropout: Dropout,
618 classifier: nn::Linear,
619}
620
621impl RobertaForMultipleChoice {
622 /// Build a new `RobertaForMultipleChoice`
623 ///
624 /// # Arguments
625 ///
626 /// * `p` - Variable store path for the root of the RobertaForMaskedLM model
627 /// * `config` - `RobertaConfig` object defining the model architecture and vocab size
628 ///
629 /// # Example
630 ///
631 /// ```no_run
632 /// use rust_bert::roberta::{RobertaConfig, RobertaForMultipleChoice};
633 /// use rust_bert::Config;
634 /// use std::path::Path;
635 /// use tch::{nn, Device};
636 ///
637 /// let config_path = Path::new("path/to/config.json");
638 /// let device = Device::Cpu;
639 /// let p = nn::VarStore::new(device);
640 /// let config = RobertaConfig::from_file(config_path);
641 /// let roberta = RobertaForMultipleChoice::new(&p.root() / "roberta", &config);
642 /// ```
643 pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMultipleChoice
644 where
645 P: Borrow<nn::Path<'p>>,
646 {
647 let p = p.borrow();
648 let roberta = BertModel::<RobertaEmbeddings>::new(p / "roberta", config);
649 let dropout = Dropout::new(config.hidden_dropout_prob);
650 let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
651
652 RobertaForMultipleChoice {
653 roberta,
654 dropout,
655 classifier,
656 }
657 }
658
659 #[allow(rustdoc::invalid_html_tags)]
660 /// Forward pass through the model
661 ///
662 /// # Arguments
663 ///
664 /// * `input_ids` - Input tensor of shape (*batch size*, *sequence_length*).
665 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
666 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
667 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
668 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
669 ///
670 /// # Returns
671 ///
672 /// * `RobertaSequenceClassificationOutput` containing:
673 /// - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
674 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
675 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
676 ///
677 /// # Example
678 ///
679 /// ```no_run
680 /// # use rust_bert::bert::BertConfig;
681 /// # use tch::{nn, Device, Tensor, no_grad};
682 /// # use rust_bert::Config;
683 /// # use std::path::Path;
684 /// # use tch::kind::Kind::Int64;
685 /// use rust_bert::roberta::RobertaForMultipleChoice;
686 /// # let config_path = Path::new("path/to/config.json");
687 /// # let vocab_path = Path::new("path/to/vocab.txt");
688 /// # let device = Device::Cpu;
689 /// # let vs = nn::VarStore::new(device);
690 /// # let config = BertConfig::from_file(config_path);
691 /// # let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
692 /// let (num_choices, sequence_length) = (3, 128);
693 /// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
694 /// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
695 /// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
696 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
697 /// .expand(&[num_choices, sequence_length], true);
698 ///
699 /// let model_output = no_grad(|| {
700 /// roberta_model.forward_t(
701 /// &input_tensor,
702 /// Some(&mask),
703 /// Some(&token_type_ids),
704 /// Some(&position_ids),
705 /// false,
706 /// )
707 /// });
708 /// ```
709 pub fn forward_t(
710 &self,
711 input_ids: &Tensor,
712 mask: Option<&Tensor>,
713 token_type_ids: Option<&Tensor>,
714 position_ids: Option<&Tensor>,
715 train: bool,
716 ) -> RobertaSequenceClassificationOutput {
717 let num_choices = input_ids.size()[1];
718
719 let input_ids = Some(input_ids.view((-1, *input_ids.size().last().unwrap())));
720 let mask = mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
721 let token_type_ids =
722 token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
723 let position_ids =
724 position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
725
726 let base_model_output = self
727 .roberta
728 .forward_t(
729 input_ids.as_ref(),
730 mask.as_ref(),
731 token_type_ids.as_ref(),
732 position_ids.as_ref(),
733 None,
734 None,
735 None,
736 train,
737 )
738 .unwrap();
739
740 let logits = base_model_output
741 .pooled_output
742 .unwrap()
743 .apply_t(&self.dropout, train)
744 .apply(&self.classifier)
745 .view((-1, num_choices));
746 RobertaSequenceClassificationOutput {
747 logits,
748 all_hidden_states: base_model_output.all_hidden_states,
749 all_attentions: base_model_output.all_attentions,
750 }
751 }
752}
753
/// # RoBERTa for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of bpe tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `roberta`: Base RoBERTa model
/// - `classifier`: Linear layer for token classification
pub struct RobertaForTokenClassification {
    /// Base RoBERTa encoder (BERT architecture with RoBERTa embeddings); built with the pooler disabled in `new`
    roberta: BertModel<RobertaEmbeddings>,
    /// Dropout applied to the per-token hidden states before classification (rate from `hidden_dropout_prob`)
    dropout: Dropout,
    /// Linear projection from `hidden_size` to the number of labels (length of `config.id2label`)
    classifier: nn::Linear,
}
765
impl RobertaForTokenClassification {
    /// Build a new `RobertaForTokenClassification`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the RobertaForTokenClassification model
    /// * `config` - `RobertaConfig` object defining the model architecture and vocab size
    ///
    /// # Errors
    ///
    /// Returns a `RustBertError::InvalidConfigurationError` if `config.id2label` is not set,
    /// since the number of output labels is derived from that mapping.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::roberta::{RobertaConfig, RobertaForTokenClassification};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = RobertaConfig::from_file(config_path);
    /// let roberta = RobertaForTokenClassification::new(&p.root() / "roberta", &config).unwrap();
    /// ```
    pub fn new<'p, P>(
        p: P,
        config: &BertConfig,
    ) -> Result<RobertaForTokenClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        // Token classification works on per-token hidden states only, so the base
        // model is built without a pooling layer (third argument `false`).
        let roberta =
            BertModel::<RobertaEmbeddings>::new_with_optional_pooler(p / "roberta", config, false);
        let dropout = Dropout::new(config.hidden_dropout_prob);
        // The number of target classes is taken from the label mapping of the configuration.
        let num_labels = config
            .id2label
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                )
            })?
            .len() as i64;
        let classifier = nn::linear(
            p / "classifier",
            config.hidden_size,
            num_labels,
            Default::default(),
        );

        Ok(RobertaForTokenClassification {
            roberta,
            dropout,
            classifier,
        })
    }

    #[allow(rustdoc::invalid_html_tags)]
    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked value 1. If None set to 1
    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `RobertaTokenClassificationOutput` containing:
    /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
    /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///
    /// # Panics
    ///
    /// Panics if the base model forward pass returns an error
    /// (NOTE(review): presumably when an invalid combination of `input_ids` / `input_embeds`
    /// is supplied — confirm in `BertModel::forward_t`).
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::bert::BertConfig;
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForTokenClassification;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// # let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config).unwrap();
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let model_output = no_grad(|| {
    ///     roberta_model.forward_t(
    ///         Some(&input_tensor),
    ///         Some(&mask),
    ///         Some(&token_type_ids),
    ///         Some(&position_ids),
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> RobertaTokenClassificationOutput {
        // Encode the sequences with the base RoBERTa model.
        let base_model_output = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                None,
                None,
                train,
            )
            .unwrap();

        // Project every token's hidden state to label space (dropout only active when `train`).
        let logits = base_model_output
            .hidden_state
            .apply_t(&self.dropout, train)
            .apply(&self.classifier);

        RobertaTokenClassificationOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        }
    }
}
909
/// # RoBERTa for question answering
/// Extractive question-answering model based on a RoBERTa language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `roberta`: Base RoBERTa model
/// - `qa_outputs`: Linear layer for question answering
pub struct RobertaForQuestionAnswering {
    /// Base RoBERTa encoder (BERT architecture with RoBERTa embeddings); built with the pooler disabled in `new`
    roberta: BertModel<RobertaEmbeddings>,
    /// Linear layer projecting each token's hidden state to 2 logits (answer start and answer end)
    qa_outputs: nn::Linear,
}
921
impl RobertaForQuestionAnswering {
    /// Build a new `RobertaForQuestionAnswering`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the RobertaForQuestionAnswering model
    /// * `config` - `RobertaConfig` object defining the model architecture and vocab size
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::roberta::{RobertaConfig, RobertaForQuestionAnswering};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = RobertaConfig::from_file(config_path);
    /// let roberta = RobertaForQuestionAnswering::new(&p.root() / "roberta", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForQuestionAnswering
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        // Span extraction works on per-token hidden states only, so the base
        // model is built without a pooling layer (third argument `false`).
        let roberta =
            BertModel::<RobertaEmbeddings>::new_with_optional_pooler(p / "roberta", config, false);
        // Two outputs per token: one logit for the answer start, one for the answer end.
        let num_labels = 2;
        let qa_outputs = nn::linear(
            p / "qa_outputs",
            config.hidden_size,
            num_labels,
            Default::default(),
        );

        RobertaForQuestionAnswering {
            roberta,
            qa_outputs,
        }
    }

    #[allow(rustdoc::invalid_html_tags)]
    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked value 1. If None set to 1
    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `RobertaQuestionAnsweringOutput` containing:
    /// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
    /// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
    /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///
    /// # Panics
    ///
    /// Panics if the base model forward pass returns an error
    /// (NOTE(review): presumably when an invalid combination of `input_ids` / `input_embeds`
    /// is supplied — confirm in `BertModel::forward_t`).
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::bert::BertConfig;
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForQuestionAnswering;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// # let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let model_output = no_grad(|| {
    ///     roberta_model.forward_t(
    ///         Some(&input_tensor),
    ///         Some(&mask),
    ///         Some(&token_type_ids),
    ///         Some(&position_ids),
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> RobertaQuestionAnsweringOutput {
        // Encode the question + context sequences with the base RoBERTa model.
        let base_model_output = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                None,
                None,
                train,
            )
            .unwrap();

        // Project each token to 2 channels, split them into start/end scores, and
        // drop the trailing singleton dimension, yielding (batch, sequence_length) each.
        let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
        let logits = sequence_output.split(1, -1);
        let (start_logits, end_logits) = (&logits[0], &logits[1]);
        let start_logits = start_logits.squeeze_dim(-1);
        let end_logits = end_logits.squeeze_dim(-1);

        RobertaQuestionAnsweringOutput {
            start_logits,
            end_logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        }
    }
}
1055
/// # RoBERTa for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
/// Alias for the base [`BertModel`] parameterized with [`RobertaEmbeddings`]; no task-specific head is attached.
pub type RobertaForSentenceEmbeddings = BertModel<RobertaEmbeddings>;
1059
/// Container for the RoBERTa masked LM model output.
pub struct RobertaMaskedLMOutput {
    /// Logits over the vocabulary items at each sequence position
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers, when collected by the model (otherwise `None`)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when collected by the model (otherwise `None`)
    pub all_attentions: Option<Vec<Tensor>>,
}
1069
/// Container for the RoBERTa sequence classification model output.
/// Also returned by the multiple-choice head, where `logits` holds one score per choice.
pub struct RobertaSequenceClassificationOutput {
    /// Logits for each input (sequence) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, when collected by the model (otherwise `None`)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when collected by the model (otherwise `None`)
    pub all_attentions: Option<Vec<Tensor>>,
}
1079
/// Container for the RoBERTa token classification model output.
pub struct RobertaTokenClassificationOutput {
    /// Logits for each sequence item (token) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, when collected by the model (otherwise `None`)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when collected by the model (otherwise `None`)
    pub all_attentions: Option<Vec<Tensor>>,
}
1089
/// Container for the RoBERTa question answering model output.
pub struct RobertaQuestionAnsweringOutput {
    /// Per-token logits for the answer start position, for each input sequence
    pub start_logits: Tensor,
    /// Per-token logits for the answer end position, for each input sequence
    pub end_logits: Tensor,
    /// Hidden states for all intermediate layers, when collected by the model (otherwise `None`)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when collected by the model (otherwise `None`)
    pub all_attentions: Option<Vec<Tensor>>,
}