1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use tch::{nn, Tensor};
use crate::common::linear::{linear_no_bias, LinearNoBias};
use tch::nn::Init;
use crate::common::activations::_gelu;
use crate::roberta::embeddings::RobertaEmbeddings;
use crate::common::dropout::Dropout;
use crate::bert::{BertConfig, BertModel};

/// # RoBERTa Pretrained model weight files
pub struct RobertaModelResources;

/// # RoBERTa Pretrained model config files
pub struct RobertaConfigResources;

/// # RoBERTa Pretrained model vocab files
pub struct RobertaVocabResources;

/// # RoBERTa Pretrained model merges files
pub struct RobertaMergesResources;

// Every `ROBERTA` constant below is a `(relative local path, remote URL)` pair.

impl RobertaModelResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
    pub const ROBERTA: (&'static str, &'static str) = ("roberta/model.ot", "https://cdn.huggingface.co/roberta-base-rust_model.ot");
}

impl RobertaConfigResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq.
    pub const ROBERTA: (&'static str, &'static str) = ("roberta/config.json", "https://cdn.huggingface.co/roberta-base-config.json");
}

impl RobertaVocabResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq.
    // NOTE(review): the remote file is JSON (`roberta-base-vocab.json`) but it is stored locally as
    // `roberta/vocab.txt`. The local name is kept as-is so existing caches stay valid — confirm before renaming.
    pub const ROBERTA: (&'static str, &'static str) = ("roberta/vocab.txt", "https://cdn.huggingface.co/roberta-base-vocab.json");
}

impl RobertaMergesResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq.
    pub const ROBERTA: (&'static str, &'static str) = ("roberta/merges.txt", "https://cdn.huggingface.co/roberta-base-merges.txt");
}

/// # RoBERTa language model prediction head
/// Maps per-token hidden states to vocabulary scores:
/// dense projection -> `_gelu` -> layer normalization -> bias-free decoder, plus a separately stored output bias.
pub struct RobertaLMHead {
    dense: nn::Linear,
    decoder: LinearNoBias,
    layer_norm: nn::LayerNorm,
    bias: Tensor,
}

impl RobertaLMHead {
    /// Build a new `RobertaLMHead`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path under which `dense`, `layer_norm`, `decoder` and `bias` are registered
    /// * `config` - `BertConfig` providing `hidden_size` and `vocab_size`
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaLMHead {
        let hidden_size = config.hidden_size;
        let vocab_size = config.vocab_size;

        let dense = nn::linear(p / "dense", hidden_size, hidden_size, Default::default());
        // NOTE(review): epsilon is fixed at 1e-12 instead of being read from the configuration;
        // the reference RoBERTa checkpoints may use a different value — confirm.
        let layer_norm = nn::layer_norm(
            p / "layer_norm",
            vec![hidden_size],
            nn::LayerNormConfig { eps: 1e-12, ..Default::default() },
        );
        let decoder = linear_no_bias(&(p / "decoder"), hidden_size, vocab_size, Default::default());
        // The output bias lives outside the decoder layer, as its own variable.
        let bias = p.var("bias", &[vocab_size], Init::KaimingUniform);

        RobertaLMHead { dense, decoder, layer_norm, bias }
    }

    /// Compute vocabulary scores for every position of `hidden_states`.
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        let projected = hidden_states.apply(&self.dense);
        let activated = _gelu(&projected);
        let normalized = activated.apply(&self.layer_norm);
        normalized.apply(&self.decoder) + &self.bias
    }
}

/// # RoBERTa for masked language model
/// Base RoBERTa model with a RoBERTa masked language model head to predict missing tokens, for example `"Looks like one <mask> is missing" -> "person"`
/// It is made of the following blocks:
/// - `roberta`: Base BertModel with RoBERTa embeddings
/// - `lm_head`: RoBERTa LM prediction head
pub struct RobertaForMaskedLM {
    roberta: BertModel<RobertaEmbeddings>,
    lm_head: RobertaLMHead,
}

impl RobertaForMaskedLM {
    /// Build a new `RobertaForMaskedLM`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the RobertaForMaskedLM model. The encoder is registered under `roberta` and the prediction head under `lm_head`
    /// * `config` - `BertConfig` object defining the model architecture and vocab size
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::bert::BertConfig;
    /// use tch::{nn, Device};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use rust_bert::roberta::RobertaForMaskedLM;
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = BertConfig::from_file(config_path);
    /// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config);
    /// ```
    ///
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMaskedLM {
        let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
        let lm_head = RobertaLMHead::new(&(p / "lm_head"), config);

        RobertaForMaskedLM { roberta, lm_head }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
    /// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used in the cross-attention layer as keys and values (query from the decoder).
    /// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used to mask encoder values. Positions with value 0 will be masked.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) containing the prediction scores over the vocabulary for every input position
    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    /// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* holding the per-layer attention tensors returned by the base model
    ///
    /// # Example
    ///
    /// ```no_run
    ///# use rust_bert::bert::BertConfig;
    ///# use tch::{nn, Device, Tensor, no_grad};
    ///# use rust_bert::Config;
    ///# use std::path::Path;
    ///# use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForMaskedLM;
    ///# let config_path = Path::new("path/to/config.json");
    ///# let vocab_path = Path::new("path/to/vocab.txt");
    ///# let device = Device::Cpu;
    ///# let vs = nn::VarStore::new(device);
    ///# let config = BertConfig::from_file(config_path);
    ///# let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
    ///  let (batch_size, sequence_length) = (64, 128);
    ///  let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    ///  let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///  let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
    ///
    ///  let (output, all_hidden_states, all_attentions) = no_grad(|| {
    ///    roberta_model
    ///         .forward_t(Some(input_tensor),
    ///                    Some(mask),
    ///                    Some(token_type_ids),
    ///                    Some(position_ids),
    ///                    None,
    ///                    &None,
    ///                    &None,
    ///                    false)
    ///    });
    ///
    /// ```
    ///
    pub fn forward_t(&self,
                     input_ids: Option<Tensor>,
                     mask: Option<Tensor>,
                     token_type_ids: Option<Tensor>,
                     position_ids: Option<Tensor>,
                     input_embeds: Option<Tensor>,
                     encoder_hidden_states: &Option<Tensor>,
                     encoder_mask: &Option<Tensor>,
                     train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        // The pooled output (second element) is not used for token-level predictions.
        let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
                                                                                          input_embeds, encoder_hidden_states, encoder_mask, train).unwrap();

        // The LM head runs on every position, giving one score vector per token.
        let prediction_scores = self.lm_head.forward(&hidden_state);
        (prediction_scores, all_hidden_states, all_attentions)
    }
}

/// # RoBERTa sequence classification head
/// Uses the hidden state of the first token of each sequence and maps it to one logit per label:
/// dropout -> dense -> tanh -> dropout -> output projection.
pub struct RobertaClassificationHead {
    dense: nn::Linear,
    dropout: Dropout,
    out_proj: nn::Linear,
}

impl RobertaClassificationHead {
    /// Build a new `RobertaClassificationHead`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path under which `dense` and `out_proj` are registered
    /// * `config` - `BertConfig` providing `hidden_size`, `hidden_dropout_prob` and `id2label`
    ///
    /// # Panics
    ///
    /// Panics if `config.id2label` is `None`; its length gives the number of output labels.
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaClassificationHead {
        let hidden_size = config.hidden_size;
        let dense = nn::linear(p / "dense", hidden_size, hidden_size, Default::default());
        let id2label = config.id2label.as_ref().expect("num_labels not provided in configuration");
        let out_proj = nn::linear(p / "out_proj", hidden_size, id2label.len() as i64, Default::default());

        RobertaClassificationHead {
            dense,
            dropout: Dropout::new(config.hidden_dropout_prob),
            out_proj,
        }
    }

    /// Compute label logits from encoder hidden states.
    /// Dropout is only active when `train` is true.
    pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
        // Index 0 along dim 1 picks the first token of every sequence (presumably `<s>`).
        let first_token = hidden_states.select(1, 0);
        let pooled = first_token
            .apply_t(&self.dropout, train)
            .apply(&self.dense)
            .tanh();
        pooled.apply_t(&self.dropout, train).apply(&self.out_proj)
    }
}

/// # RoBERTa for sequence classification
/// Sentence- or document-level classifier built on top of a RoBERTa encoder.
/// Components:
/// - `roberta`: Base RoBERTa model (a `BertModel` using RoBERTa embeddings)
/// - `classifier`: `RobertaClassificationHead` (two linear layers) applied to the first token's hidden state
pub struct RobertaForSequenceClassification {
    roberta: BertModel<RobertaEmbeddings>,
    classifier: RobertaClassificationHead,
}

impl RobertaForSequenceClassification {
    /// Build a new `RobertaForSequenceClassification`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path acting as the root of the model. The encoder is registered under `roberta` and the head under `classifier`
    /// * `config` - `BertConfig` describing the architecture. `id2label` must be set; its length gives the number of labels
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::bert::BertConfig;
    /// use tch::{nn, Device};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use rust_bert::roberta::RobertaForSequenceClassification;
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = BertConfig::from_file(config_path);
    /// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config);
    /// ```
    ///
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForSequenceClassification {
        RobertaForSequenceClassification {
            roberta: BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config),
            classifier: RobertaClassificationHead::new(&(p / "classifier"), config),
        }
    }

    /// Run the encoder followed by the classification head
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional token ids of shape (*batch size*, *sequence_length*). When `None`, `input_embeds` must be given instead
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). 1 marks positions to keep and 0 masked positions. Treated as all ones when `None`
    /// * `token_type_ids` - Optional segment ids of shape (*batch size*, *sequence_length*): 0 for the first sentence (including *</s>*) and 1 for the second. Treated as all zeros when `None`
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). Generated incrementally from 0 when `None`
    /// * `input_embeds` - Optional pre-computed embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). When `None`, `input_ids` must be given instead
    /// * `train` - enables the dropout layers when true; pass false for inference
    ///
    /// # Returns
    ///
    /// * `labels` - `Tensor` of shape (*batch size*, *num_labels*) holding one logit per label
    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers*, each of shape (*batch size*, *sequence_length*, *hidden_size*)
    /// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* holding the per-layer attention tensors returned by the base model
    ///
    /// # Example
    ///
    /// ```no_run
    ///# use rust_bert::bert::BertConfig;
    ///# use tch::{nn, Device, Tensor, no_grad};
    ///# use rust_bert::Config;
    ///# use std::path::Path;
    ///# use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForSequenceClassification;
    ///# let config_path = Path::new("path/to/config.json");
    ///# let vocab_path = Path::new("path/to/vocab.txt");
    ///# let device = Device::Cpu;
    ///# let vs = nn::VarStore::new(device);
    ///# let config = BertConfig::from_file(config_path);
    ///# let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
    ///  let (batch_size, sequence_length) = (64, 128);
    ///  let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    ///  let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///  let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
    ///
    ///  let (labels, all_hidden_states, all_attentions) = no_grad(|| {
    ///    roberta_model
    ///         .forward_t(Some(input_tensor),
    ///                    Some(mask),
    ///                    Some(token_type_ids),
    ///                    Some(position_ids),
    ///                    None,
    ///                    false)
    ///    });
    ///
    /// ```
    ///
    pub fn forward_t(&self,
                     input_ids: Option<Tensor>,
                     mask: Option<Tensor>,
                     token_type_ids: Option<Tensor>,
                     position_ids: Option<Tensor>,
                     input_embeds: Option<Tensor>,
                     train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        // No encoder hidden states or encoder mask are passed, so no cross-attention inputs are supplied.
        let (hidden_state, _, all_hidden_states, all_attentions) = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                &None,
                &None,
                train,
            )
            .unwrap();

        let logits = self.classifier.forward_t(&hidden_state, train);
        (logits, all_hidden_states, all_attentions)
    }
}

/// # RoBERTa for multiple choices
/// Multiple choices model using a RoBERTa base model and a linear classifier.
/// Input should be in the form `<s> Context </s> Possible choice </s>`. The choice is made along the batch axis,
/// assuming all elements of the batch are alternatives to be chosen from for a given context.
/// A 3-D input of shape (*batch size*, *num_choices*, *sequence_length*) is also accepted. In that case the choice
/// is made along the second axis, independently for every element of the batch.
/// It is made of the following blocks:
/// - `roberta`: Base RoBERTa model
/// - `classifier`: Linear layer for multiple choices
pub struct RobertaForMultipleChoice {
    roberta: BertModel<RobertaEmbeddings>,
    dropout: Dropout,
    classifier: nn::Linear,
}

impl RobertaForMultipleChoice {
    /// Build a new `RobertaForMultipleChoice`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the RobertaForMultipleChoice model
    /// * `config` - `BertConfig` object defining the model architecture and vocab size
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::bert::BertConfig;
    /// use tch::{nn, Device};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use rust_bert::roberta::RobertaForMultipleChoice;
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = BertConfig::from_file(config_path);
    /// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config);
    /// ```
    ///
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMultipleChoice {
        let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
        let dropout = Dropout::new(config.hidden_dropout_prob);
        // A single score per (context, choice) pair; scores are compared across choices.
        let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());

        RobertaForMultipleChoice { roberta, dropout, classifier }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Input tensor of shape (*num_choices*, *sequence_length*), where every row is one alternative for the same context, or of shape (*batch size*, *num_choices*, *sequence_length*).
    /// * `mask` - Optional mask with the same shape as `input_ids`. Masked position have value 0, non-masked value 1. If None set to 1
    /// * `token_type_ids` -Optional segment id with the same shape as `input_ids`. Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
    /// * `position_ids` - Optional position ids with the same shape as `input_ids`. If None, will be incremented from 0.
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `output` - `Tensor` containing the logits for each of the alternatives given. Its shape is (*1*, *num_choices*) for a 2-D input and (*batch size*, *num_choices*) for a 3-D input
    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers*, each computed on the flattened (*, *sequence_length*) input
    /// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* holding the per-layer attention tensors returned by the base model
    ///
    /// # Example
    ///
    /// ```no_run
    ///# use rust_bert::bert::BertConfig;
    ///# use tch::{nn, Device, Tensor, no_grad};
    ///# use rust_bert::Config;
    ///# use std::path::Path;
    ///# use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForMultipleChoice;
    ///# let config_path = Path::new("path/to/config.json");
    ///# let vocab_path = Path::new("path/to/vocab.txt");
    ///# let device = Device::Cpu;
    ///# let vs = nn::VarStore::new(device);
    ///# let config = BertConfig::from_file(config_path);
    ///# let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
    ///  let (num_choices, sequence_length) = (3, 128);
    ///  let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
    ///  let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
    ///  let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[num_choices, sequence_length], true);
    ///
    ///  let (choices, all_hidden_states, all_attentions) = no_grad(|| {
    ///    roberta_model
    ///         .forward_t(input_tensor,
    ///                    Some(mask),
    ///                    Some(token_type_ids),
    ///                    Some(position_ids),
    ///                    false)
    ///    });
    ///
    /// ```
    ///
    pub fn forward_t(&self,
                     input_ids: Tensor,
                     mask: Option<Tensor>,
                     token_type_ids: Option<Tensor>,
                     position_ids: Option<Tensor>,
                     train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let input_shape = input_ids.size();
        // 2-D input (num_choices, seq): the alternatives lie along dim 0.
        // 3-D input (batch, num_choices, seq): the alternatives lie along dim 1.
        // Using dim 1 for 2-D input would pick the sequence length, and the final `view` would then fail.
        let num_choices = if input_shape.len() == 2 { input_shape[0] } else { input_shape[1] };

        // Collapse all leading dimensions so the base model always receives (*, sequence_length) tensors.
        let flatten = |tensor: &Tensor| tensor.view((-1i64, *tensor.size().last().unwrap()));
        let flat_input_ids = Some(flatten(&input_ids));
        let flat_position_ids = position_ids.as_ref().map(flatten);
        let flat_token_type_ids = token_type_ids.as_ref().map(flatten);
        let flat_mask = mask.as_ref().map(flatten);

        let (_, pooled_output, all_hidden_states, all_attentions) = self.roberta.forward_t(flat_input_ids, flat_mask, flat_token_type_ids, flat_position_ids,
                                                                                           None, &None, &None, train).unwrap();

        // One logit per flattened row, regrouped so that each output row holds the scores of one context's alternatives.
        let output = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
        (output, all_hidden_states, all_attentions)
    }
}

/// # RoBERTa for token classification (e.g. NER, POS)
/// Predicts one label per input token. Because inputs are BPE-tokenized, the predicted labels
/// are per sub-word token and do not necessarily line up with the words of the sentence.
/// Components:
/// - `roberta`: Base RoBERTa model
/// - `classifier`: Linear layer producing one logit per label for every token
pub struct RobertaForTokenClassification {
    roberta: BertModel<RobertaEmbeddings>,
    dropout: Dropout,
    classifier: nn::Linear,
}

impl RobertaForTokenClassification {
    /// Build a new `RobertaForTokenClassification`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path acting as the root of the model. The encoder is registered under `roberta` and the linear layer under `classifier`
    /// * `config` - `BertConfig` describing the architecture. `id2label` must be set; its length gives the number of labels
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::bert::BertConfig;
    /// use tch::{nn, Device};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use rust_bert::roberta::RobertaForTokenClassification;
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = BertConfig::from_file(config_path);
    /// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config);
    /// ```
    ///
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForTokenClassification {
        RobertaForTokenClassification {
            roberta: BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config),
            dropout: Dropout::new(config.hidden_dropout_prob),
            classifier: nn::linear(
                p / "classifier",
                config.hidden_size,
                config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64,
                Default::default(),
            ),
        }
    }

    /// Run the encoder and classify every token
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional token ids of shape (*batch size*, *sequence_length*). When `None`, `input_embeds` must be given instead
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). 1 marks positions to keep and 0 masked positions. Treated as all ones when `None`
    /// * `token_type_ids` - Optional segment ids of shape (*batch size*, *sequence_length*): 0 for the first sentence (including *</s>*) and 1 for the second. Treated as all zeros when `None`
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). Generated incrementally from 0 when `None`
    /// * `input_embeds` - Optional pre-computed embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). When `None`, `input_ids` must be given instead
    /// * `train` - enables the dropout layers when true; pass false for inference
    ///
    /// # Returns
    ///
    /// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) with per-token, per-class logits
    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers*, each of shape (*batch size*, *sequence_length*, *hidden_size*)
    /// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* holding the per-layer attention tensors returned by the base model
    ///
    /// # Example
    ///
    /// ```no_run
    ///# use rust_bert::bert::BertConfig;
    ///# use tch::{nn, Device, Tensor, no_grad};
    ///# use rust_bert::Config;
    ///# use std::path::Path;
    ///# use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForTokenClassification;
    ///# let config_path = Path::new("path/to/config.json");
    ///# let vocab_path = Path::new("path/to/vocab.txt");
    ///# let device = Device::Cpu;
    ///# let vs = nn::VarStore::new(device);
    ///# let config = BertConfig::from_file(config_path);
    ///# let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
    ///  let (batch_size, sequence_length) = (64, 128);
    ///  let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    ///  let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///  let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
    ///
    ///  let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
    ///    roberta_model
    ///         .forward_t(Some(input_tensor),
    ///                    Some(mask),
    ///                    Some(token_type_ids),
    ///                    Some(position_ids),
    ///                    None,
    ///                    false)
    ///    });
    ///
    /// ```
    ///
    pub fn forward_t(&self,
                     input_ids: Option<Tensor>,
                     mask: Option<Tensor>,
                     token_type_ids: Option<Tensor>,
                     position_ids: Option<Tensor>,
                     input_embeds: Option<Tensor>,
                     train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        // Encoder-only call: no encoder hidden states or encoder mask are passed.
        let (hidden_state, _, all_hidden_states, all_attentions) = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                &None,
                &None,
                train,
            )
            .unwrap();

        let regularized = hidden_state.apply_t(&self.dropout, train);
        let token_logits = regularized.apply(&self.classifier);
        (token_logits, all_hidden_states, all_attentions)
    }
}

/// # RoBERTa for question answering
/// Extractive question answering on top of a RoBERTa encoder: locates the span of a context that answers a given question.
/// End-to-end question answering needs substantial pre- and post-processing around this model;
/// the question answering pipeline provided in this crate shows how it is done.
/// Components:
/// - `roberta`: Base RoBERTa model
/// - `qa_outputs`: Linear layer producing the question answering logits
pub struct RobertaForQuestionAnswering {
    roberta: BertModel<RobertaEmbeddings>,
    qa_outputs: nn::Linear,
}

impl RobertaForQuestionAnswering {
    /// Build a new `RobertaForQuestionAnswering`
    ///
    /// The base encoder is registered under `p / "roberta"` and the span-prediction head under
    /// `p / "qa_outputs"`, so `p` should be the root of the model's variable store.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the RobertaForQuestionAnswering model
    /// * `config` - `BertConfig` object defining the model architecture and vocab size
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::bert::BertConfig;
    /// use tch::{nn, Device};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use rust_bert::roberta::RobertaForQuestionAnswering;
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = BertConfig::from_file(config_path);
    /// let roberta = RobertaForQuestionAnswering::new(&p.root(), &config);
    /// ```
    ///
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForQuestionAnswering {
        let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
        // One output channel for the answer start and one for the answer end.
        let num_labels = 2;
        let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, num_labels, Default::default());

        RobertaForQuestionAnswering { roberta, qa_outputs }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *</s>*) and 1 for the second sentence. If None set to 0.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `start_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
    /// * `end_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    /// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* containing the attention weights of each layer
    ///   (presumably of shape (*batch size*, *num_heads*, *sequence_length*, *sequence_length*), as in standard multi-head self-attention)
    ///
    /// # Panics
    ///
    /// Panics if the underlying RoBERTa encoder returns an error, for example when neither
    /// `input_ids` nor `input_embeds` is provided.
    ///
    /// # Example
    ///
    /// ```no_run
    ///# use rust_bert::bert::BertConfig;
    ///# use tch::{nn, Device, Tensor, no_grad};
    ///# use rust_bert::Config;
    ///# use std::path::Path;
    ///# use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForQuestionAnswering;
    ///# let config_path = Path::new("path/to/config.json");
    ///# let device = Device::Cpu;
    ///# let vs = nn::VarStore::new(device);
    ///# let config = BertConfig::from_file(config_path);
    ///# let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
    ///  let (batch_size, sequence_length) = (64, 128);
    ///  let input_tensor = Tensor::randint(config.vocab_size, &[batch_size, sequence_length], (Int64, device));
    ///  let mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
    ///  let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
    ///
    ///  let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
    ///    roberta_model
    ///         .forward_t(Some(input_tensor),
    ///                    Some(mask),
    ///                    Some(token_type_ids),
    ///                    Some(position_ids),
    ///                    None,
    ///                    false)
    ///    });
    ///
    /// ```
    ///
    pub fn forward_t(&self,
                     input_ids: Option<Tensor>,
                     mask: Option<Tensor>,
                     token_type_ids: Option<Tensor>,
                     position_ids: Option<Tensor>,
                     input_embeds: Option<Tensor>,
                     train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let (hidden_state, _, all_hidden_states, all_attentions) = self
            .roberta
            .forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, &None, &None, train)
            .expect("RoBERTa encoder forward pass failed");

        // `qa_outputs` produces 2 channels on the last axis; split them into the
        // start-logit and end-logit tensors and drop the now-singleton last dimension.
        let logits = hidden_state.apply(&self.qa_outputs).split(1, -1);
        let start_logits = logits[0].squeeze1(-1);
        let end_logits = logits[1].squeeze1(-1);

        (start_logits, end_logits, all_hidden_states, all_attentions)
    }
}