rust_bert/models/mobilebert/
mobilebert_model.rs

// Copyright (c) 2020  The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12
13use crate::common::activations::{Activation, TensorFunction};
14use crate::common::dropout::Dropout;
15use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
16use crate::mobilebert::embeddings::MobileBertEmbeddings;
17use crate::mobilebert::encoder::{MobileBertEncoder, MobileBertPooler};
18use crate::{Config, RustBertError};
19use serde::{Deserialize, Serialize};
20use std::borrow::Borrow;
21use std::collections::HashMap;
22use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
23use tch::nn::{Init, LayerNormConfig, Module};
24use tch::{nn, Kind, Tensor};
25
/// # MobileBERT Pretrained model weight files
pub struct MobileBertModelResources;

/// # MobileBERT Pretrained model config files
pub struct MobileBertConfigResources;

/// # MobileBERT Pretrained model vocab files
pub struct MobileBertVocabResources;

// Every resource below is an `(identifier, download URL)` pair.

impl MobileBertModelResources {
    /// Shared under Apache 2.0 license by the Google team at <https://huggingface.co/google/mobilebert-uncased>. Modified with conversion to C-array format.
    pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
        "mobilebert-uncased/model",
        "https://huggingface.co/google/mobilebert-uncased/resolve/main/rust_model.ot",
    );
    /// Shared under MIT license at <https://huggingface.co/mrm8488/mobilebert-finetuned-pos>. Modified with conversion to C-array format.
    pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
        "mobilebert-finetuned-pos/model",
        "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/rust_model.ot",
    );
}

impl MobileBertConfigResources {
    /// Shared under Apache 2.0 license by the Google team at <https://huggingface.co/google/mobilebert-uncased>.
    pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
        "mobilebert-uncased/config",
        "https://huggingface.co/google/mobilebert-uncased/resolve/main/config.json",
    );
    /// Shared under MIT license at <https://huggingface.co/mrm8488/mobilebert-finetuned-pos>.
    pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
        "mobilebert-finetuned-pos/config",
        "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/config.json",
    );
}

impl MobileBertVocabResources {
    /// Shared under Apache 2.0 license by the Google team at <https://huggingface.co/google/mobilebert-uncased>.
    pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
        "mobilebert-uncased/vocab",
        "https://huggingface.co/google/mobilebert-uncased/resolve/main/vocab.txt",
    );
    /// Shared under MIT license at <https://huggingface.co/mrm8488/mobilebert-finetuned-pos>.
    pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
        "mobilebert-finetuned-pos/vocab",
        "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/vocab.txt",
    );
}
73
74#[allow(non_camel_case_types)]
75#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
76/// # Normalization type to use for the MobileBERT model.
77/// `no_norm` uses a matrix multiplication with a set of learned weights, while `layer_norm` uses a
78/// build-in layer normalization module.
79pub enum NormalizationType {
80    layer_norm,
81    no_norm,
82}
83
84#[derive(Debug)]
85/// # No-normalization option for MobileBERT
86/// Basic module performing a linear multiplication using trained coefficients and bias
87pub struct NoNorm {
88    weight: Tensor,
89    bias: Tensor,
90}
91
92impl NoNorm {
93    /// Creates a new `NoNorm` layer of given hidden size.
94    ///
95    /// # Arguments:
96    ///
97    /// * hidden_size - input tensor's hidden size
98    ///
99    /// # Example
100    ///
101    /// ```no_run
102    /// use rust_bert::mobilebert::NoNorm;
103    /// use tch::{nn, Device};
104    /// let device = Device::Cpu;
105    /// let p = nn::VarStore::new(device);
106    /// let hidden_size = 512;
107    /// let no_norm = NoNorm::new(&p.root(), hidden_size);
108    /// ```
109    pub fn new<'p, P>(p: P, hidden_size: i64) -> NoNorm
110    where
111        P: Borrow<nn::Path<'p>>,
112    {
113        let p = p.borrow();
114        let weight = p.var("weight", &[hidden_size], Init::Const(1.0));
115        let bias = p.var("bias", &[hidden_size], Init::Const(0.0));
116        NoNorm { weight, bias }
117    }
118}
119
120impl Module for NoNorm {
121    fn forward(&self, xs: &Tensor) -> Tensor {
122        xs * &self.weight + &self.bias
123    }
124}
125
126pub enum NormalizationLayer {
127    LayerNorm(nn::LayerNorm),
128    NoNorm(NoNorm),
129}
130
131impl NormalizationLayer {
132    pub fn new<'p, P>(
133        p: P,
134        normalization_type: NormalizationType,
135        hidden_size: i64,
136        eps: Option<f64>,
137    ) -> NormalizationLayer
138    where
139        P: Borrow<nn::Path<'p>>,
140    {
141        let p = p.borrow();
142        match normalization_type {
143            NormalizationType::layer_norm => {
144                let layer_norm_config = LayerNormConfig {
145                    eps: eps.unwrap_or(1e-12),
146                    ..Default::default()
147                };
148                let layer_norm = nn::layer_norm(p, vec![hidden_size], layer_norm_config);
149                NormalizationLayer::LayerNorm(layer_norm)
150            }
151            NormalizationType::no_norm => {
152                let layer_norm = NoNorm::new(p, hidden_size);
153                NormalizationLayer::NoNorm(layer_norm)
154            }
155        }
156    }
157
158    pub fn forward(&self, input: &Tensor) -> Tensor {
159        match self {
160            NormalizationLayer::LayerNorm(ref layer_norm) => input.apply(layer_norm),
161            NormalizationLayer::NoNorm(ref layer_norm) => input.apply(layer_norm),
162        }
163    }
164}
165
166#[derive(Debug, Serialize, Deserialize)]
167/// # MobileBERT model configuration
168/// Defines the MobileBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
169pub struct MobileBertConfig {
170    pub hidden_act: Activation,
171    pub attention_probs_dropout_prob: f64,
172    pub hidden_dropout_prob: f64,
173    pub hidden_size: i64,
174    pub initializer_range: f64,
175    pub intermediate_size: i64,
176    pub max_position_embeddings: i64,
177    pub num_attention_heads: i64,
178    pub num_hidden_layers: i64,
179    pub type_vocab_size: i64,
180    pub vocab_size: i64,
181    pub embedding_size: i64,
182    pub layer_norm_eps: Option<f64>,
183    pub pad_token_idx: Option<i64>,
184    pub trigram_input: Option<bool>,
185    pub use_bottleneck: Option<bool>,
186    pub use_bottleneck_attention: Option<bool>,
187    pub intra_bottleneck_size: Option<i64>,
188    pub key_query_shared_bottleneck: Option<bool>,
189    pub num_feedforward_networks: Option<i64>,
190    pub normalization_type: Option<NormalizationType>,
191    pub output_attentions: Option<bool>,
192    pub output_hidden_states: Option<bool>,
193    pub classifier_activation: Option<bool>,
194    pub is_decoder: Option<bool>,
195    pub id2label: Option<HashMap<i64, String>>,
196    pub label2id: Option<HashMap<String, i64>>,
197}
198
199impl Config for MobileBertConfig {}
200
201impl Default for MobileBertConfig {
202    fn default() -> Self {
203        MobileBertConfig {
204            hidden_act: Activation::relu,
205            attention_probs_dropout_prob: 0.1,
206            hidden_dropout_prob: 0.0,
207            hidden_size: 512,
208            initializer_range: 0.02,
209            intermediate_size: 512,
210            max_position_embeddings: 512,
211            num_attention_heads: 4,
212            num_hidden_layers: 24,
213            type_vocab_size: 2,
214            vocab_size: 30522,
215            embedding_size: 128,
216            layer_norm_eps: Some(1e-12),
217            pad_token_idx: Some(0),
218            trigram_input: Some(true),
219            use_bottleneck: Some(true),
220            use_bottleneck_attention: Some(false),
221            intra_bottleneck_size: Some(128),
222            key_query_shared_bottleneck: Some(true),
223            num_feedforward_networks: Some(4),
224            normalization_type: Some(NormalizationType::no_norm),
225            output_attentions: None,
226            output_hidden_states: None,
227            classifier_activation: None,
228            is_decoder: None,
229            id2label: None,
230            label2id: None,
231        }
232    }
233}
234
235pub struct MobileBertPredictionHeadTransform {
236    dense: nn::Linear,
237    activation_function: TensorFunction,
238    layer_norm: NormalizationLayer,
239}
240
241impl MobileBertPredictionHeadTransform {
242    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertPredictionHeadTransform
243    where
244        P: Borrow<nn::Path<'p>>,
245    {
246        let p = p.borrow();
247        let dense = nn::linear(
248            p / "dense",
249            config.hidden_size,
250            config.hidden_size,
251            Default::default(),
252        );
253        let activation_function = config.hidden_act.get_function();
254        let layer_norm = NormalizationLayer::new(
255            p / "LayerNorm",
256            NormalizationType::layer_norm,
257            config.hidden_size,
258            config.layer_norm_eps,
259        );
260        MobileBertPredictionHeadTransform {
261            dense,
262            activation_function,
263            layer_norm,
264        }
265    }
266
267    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
268        let hidden_states = hidden_states.apply(&self.dense);
269        let hidden_states = self.activation_function.get_fn()(&hidden_states);
270        self.layer_norm.forward(&hidden_states)
271    }
272}
273
274pub struct MobileBertLMPredictionHead {
275    transform: MobileBertPredictionHeadTransform,
276    dense_weight: Tensor,
277    bias: Tensor,
278}
279
280impl MobileBertLMPredictionHead {
281    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertLMPredictionHead
282    where
283        P: Borrow<nn::Path<'p>>,
284    {
285        let p = p.borrow();
286
287        let transform = MobileBertPredictionHeadTransform::new(p / "transform", config);
288
289        let dense_p = p / "dense";
290        let dense_weight = dense_p.var(
291            "weight",
292            &[
293                config.hidden_size - config.embedding_size,
294                config.vocab_size,
295            ],
296            DEFAULT_KAIMING_UNIFORM,
297        );
298        let bias = p.var("bias", &[config.vocab_size], Init::Const(0.0));
299
300        MobileBertLMPredictionHead {
301            transform,
302            dense_weight,
303            bias,
304        }
305    }
306
307    pub fn forward(&self, hidden_states: &Tensor, embeddings: &Tensor) -> Tensor {
308        let hidden_states = self.transform.forward(hidden_states);
309        let hidden_states = hidden_states.matmul(&Tensor::cat(
310            &[&embeddings.transpose(0, 1), &self.dense_weight],
311            0,
312        ));
313        hidden_states + &self.bias
314    }
315}
316
317pub struct MobileBertOnlyMLMHead {
318    predictions: MobileBertLMPredictionHead,
319}
320
321impl MobileBertOnlyMLMHead {
322    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertOnlyMLMHead
323    where
324        P: Borrow<nn::Path<'p>>,
325    {
326        let p = p.borrow();
327
328        let predictions = MobileBertLMPredictionHead::new(p / "predictions", config);
329        MobileBertOnlyMLMHead { predictions }
330    }
331
332    pub fn forward(&self, hidden_states: &Tensor, embeddings: &Tensor) -> Tensor {
333        self.predictions.forward(hidden_states, embeddings)
334    }
335}
336
337/// # MobileBertModel Base model
338/// Base architecture for MobileBERT models. Task-specific models will be built from this common base model
339/// It is made of the following blocks:
340/// - `embeddings`: Word, token type and position embeddings
341/// - `encoder`: `MobileBertEncoder` made of a stack of `MobileBertLayer`
342/// - `pooler`: Optional `MobileBertPooler` taking the first sequence element hidden state for sequence-level tasks
343/// - `position_ids` preset position ids tensor used in case they are not provided by the user
344pub struct MobileBertModel {
345    embeddings: MobileBertEmbeddings,
346    encoder: MobileBertEncoder,
347    pooler: Option<MobileBertPooler>,
348    position_ids: Tensor,
349}
350
351impl MobileBertModel {
352    /// Build a new `MobileBertModel`
353    ///
354    /// # Arguments
355    ///
356    /// * `p` - Variable store path for the root of the MobileBERT model
357    /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
358    /// * `add_pooling_layer` - boolean flag indicating if a pooling layer should be added after the encoder
359    ///
360    /// # Example
361    ///
362    /// ```no_run
363    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertModel};
364    /// use rust_bert::Config;
365    /// use std::path::Path;
366    /// use tch::{nn, Device};
367    ///
368    /// let config_path = Path::new("path/to/config.json");
369    /// let device = Device::Cpu;
370    /// let p = nn::VarStore::new(device);
371    /// let config = MobileBertConfig::from_file(config_path);
372    /// let add_pooling_layer = true;
373    /// let mobilebert = MobileBertModel::new(&p.root() / "mobilebert", &config, add_pooling_layer);
374    /// ```
375    pub fn new<'p, P>(p: P, config: &MobileBertConfig, add_pooling_layer: bool) -> MobileBertModel
376    where
377        P: Borrow<nn::Path<'p>>,
378    {
379        let p = p.borrow();
380
381        let embeddings = MobileBertEmbeddings::new(p / "embeddings", config);
382        let encoder = MobileBertEncoder::new(p / "encoder", config);
383        let pooler = if add_pooling_layer {
384            Some(MobileBertPooler::new(p / "pooler", config))
385        } else {
386            None
387        };
388        let position_ids =
389            Tensor::arange(config.max_position_embeddings, (Kind::Int64, p.device()))
390                .expand([1, -1], true);
391        MobileBertModel {
392            embeddings,
393            encoder,
394            pooler,
395            position_ids,
396        }
397    }
398
399    /// Forward pass through the model
400    ///
401    /// # Arguments
402    ///
403    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
404    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
405    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
406    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
407    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
408    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
409    ///
410    /// # Returns
411    ///
412    /// * `MobileBertOutput` containing:
413    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
414    ///   - `pooled_output` - Optional `Tensor` of shape (*batch size*, *hidden_size*) if the model was created with an optional pooling layer
415    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
416    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
417    ///
418    /// # Example
419    ///
420    /// ```no_run
421    /// # use tch::{nn, Device, Tensor, no_grad};
422    /// # use rust_bert::Config;
423    /// # use std::path::Path;
424    /// # use tch::kind::Kind::Int64;
425    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertModel};
426    /// # let config_path = Path::new("path/to/config.json");
427    /// # let device = Device::Cpu;
428    /// # let vs = nn::VarStore::new(device);
429    /// # let config = MobileBertConfig::from_file(config_path);
430    /// let add_pooling_layer = true;
431    /// let model = MobileBertModel::new(&vs.root(), &config, add_pooling_layer);
432    /// let (batch_size, sequence_length) = (64, 128);
433    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
434    /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
435    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
436    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
437    ///     .expand(&[batch_size, sequence_length], true);
438    ///
439    /// let model_output = no_grad(|| {
440    ///     model
441    ///         .forward_t(
442    ///             Some(&input_tensor),
443    ///             Some(&token_type_ids),
444    ///             Some(&position_ids),
445    ///             None,
446    ///             Some(&attention_mask),
447    ///             false,
448    ///         )
449    ///         .unwrap()
450    /// });
451    /// ```
452    pub fn forward_t(
453        &self,
454        input_ids: Option<&Tensor>,
455        token_type_ids: Option<&Tensor>,
456        position_ids: Option<&Tensor>,
457        input_embeds: Option<&Tensor>,
458        attention_mask: Option<&Tensor>,
459        train: bool,
460    ) -> Result<MobileBertOutput, RustBertError> {
461        let (input_shape, device) =
462            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
463
464        let calc_attention_mask = if attention_mask.is_none() {
465            Some(Tensor::ones(input_shape.as_slice(), (Kind::Int64, device)))
466        } else {
467            None
468        };
469
470        let calc_token_type_ids = if token_type_ids.is_none() {
471            Some(Tensor::zeros(input_shape.as_slice(), (Kind::Int64, device)))
472        } else {
473            None
474        };
475
476        let calc_position_ids = if position_ids.is_none() {
477            Some(self.position_ids.slice(1, 0, input_shape[1], 1))
478        } else {
479            None
480        };
481
482        let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
483
484        let attention_mask =
485            attention_mask.unwrap_or_else(|| calc_attention_mask.as_ref().unwrap());
486        let attention_mask = match attention_mask.dim() {
487            3 => attention_mask.unsqueeze(1),
488            2 => attention_mask.unsqueeze(1).unsqueeze(1),
489            _ => {
490                return Err(RustBertError::ValueError(
491                    "Invalid attention mask dimension, must be 2 or 3".into(),
492                ));
493            }
494        };
495
496        let token_type_ids =
497            token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
498
499        let embedding_output = self.embeddings.forward_t(
500            input_ids,
501            token_type_ids,
502            position_ids,
503            input_embeds,
504            train,
505        )?;
506
507        let attention_mask: Tensor = ((attention_mask.ones_like() - attention_mask) * -10000.0)
508            .to_kind(embedding_output.kind());
509
510        let encoder_output =
511            self.encoder
512                .forward_t(&embedding_output, Some(&attention_mask), train);
513
514        let pooled_output = if let Some(pooler) = &self.pooler {
515            Some(pooler.forward(&encoder_output.hidden_state))
516        } else {
517            None
518        };
519
520        Ok(MobileBertOutput {
521            hidden_state: encoder_output.hidden_state,
522            pooled_output,
523            all_hidden_states: encoder_output.all_hidden_states,
524            all_attentions: encoder_output.all_attentions,
525        })
526    }
527
528    fn get_embeddings(&self) -> &Tensor {
529        &self.embeddings.word_embeddings.ws
530    }
531}
532
533/// # MobileBERT for masked language model
534/// Base MobileBERT model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
535/// It is made of the following blocks:
536/// - `mobilebert`: Base MobileBertModel
537/// - `classifier`: MobileBERT LM prediction head
538pub struct MobileBertForMaskedLM {
539    mobilebert: MobileBertModel,
540    classifier: MobileBertOnlyMLMHead,
541}
542
543impl MobileBertForMaskedLM {
544    /// Build a new `MobileBertForMaskedLM`
545    ///
546    /// # Arguments
547    ///
548    /// * `p` - Variable store path for the root of the MobileBERT model
549    /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
550    ///
551    /// # Example
552    ///
553    /// ```no_run
554    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForMaskedLM};
555    /// use rust_bert::Config;
556    /// use std::path::Path;
557    /// use tch::{nn, Device};
558    ///
559    /// let config_path = Path::new("path/to/config.json");
560    /// let device = Device::Cpu;
561    /// let p = nn::VarStore::new(device);
562    /// let config = MobileBertConfig::from_file(config_path);
563    /// let mobilebert = MobileBertForMaskedLM::new(&p.root(), &config);
564    /// ```
565    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForMaskedLM
566    where
567        P: Borrow<nn::Path<'p>>,
568    {
569        let p = p.borrow();
570
571        let mobilebert = MobileBertModel::new(p / "mobilebert", config, false);
572        let classifier = MobileBertOnlyMLMHead::new(p / "cls", config);
573        MobileBertForMaskedLM {
574            mobilebert,
575            classifier,
576        }
577    }
578
579    /// Forward pass through the model
580    ///
581    /// # Arguments
582    ///
583    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
584    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
585    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
586    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
587    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
588    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
589    ///
590    /// # Returns
591    ///
592    /// * `MobileBertMaskedLMOutput` containing:
593    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
594    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
595    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
596    ///
597    /// # Example
598    ///
599    /// ```no_run
600    /// # use tch::{nn, Device, Tensor, no_grad};
601    /// # use rust_bert::Config;
602    /// # use std::path::Path;
603    /// # use tch::kind::Kind::Int64;
604    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForMaskedLM};
605    /// # let config_path = Path::new("path/to/config.json");
606    /// # let device = Device::Cpu;
607    /// # let vs = nn::VarStore::new(device);
608    /// # let config = MobileBertConfig::from_file(config_path);
609    /// let model = MobileBertForMaskedLM::new(&vs.root(), &config);
610    /// let (batch_size, sequence_length) = (64, 128);
611    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
612    /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
613    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
614    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
615    ///     .expand(&[batch_size, sequence_length], true);
616    ///
617    /// let model_output = no_grad(|| {
618    ///     model
619    ///         .forward_t(
620    ///             Some(&input_tensor),
621    ///             Some(&token_type_ids),
622    ///             Some(&position_ids),
623    ///             None,
624    ///             Some(&attention_mask),
625    ///             false,
626    ///         )
627    ///         .unwrap()
628    /// });
629    /// ```
630    pub fn forward_t(
631        &self,
632        input_ids: Option<&Tensor>,
633        token_type_ids: Option<&Tensor>,
634        position_ids: Option<&Tensor>,
635        input_embeds: Option<&Tensor>,
636        attention_mask: Option<&Tensor>,
637        train: bool,
638    ) -> Result<MobileBertMaskedLMOutput, RustBertError> {
639        let mobilebert_output = self.mobilebert.forward_t(
640            input_ids,
641            token_type_ids,
642            position_ids,
643            input_embeds,
644            attention_mask,
645            train,
646        )?;
647
648        let logits = self.classifier.forward(
649            &mobilebert_output.hidden_state,
650            self.mobilebert.get_embeddings(),
651        );
652
653        Ok(MobileBertMaskedLMOutput {
654            logits,
655            all_hidden_states: mobilebert_output.all_hidden_states,
656            all_attentions: mobilebert_output.all_attentions,
657        })
658    }
659}
660
661/// # MobileBERT for sequence classification
662/// Base MobileBERT model with a classifier head to perform sentence or document-level classification
663/// It is made of the following blocks:
664/// - `mobilebert`: Base MobileBertModel
665/// - `dropout`: Dropout layer before the last linear layer
666/// - `classifier`: linear layer mapping from hidden to the number of classes to predict
667pub struct MobileBertForSequenceClassification {
668    mobilebert: MobileBertModel,
669    dropout: Dropout,
670    classifier: nn::Linear,
671}
672
673impl MobileBertForSequenceClassification {
674    /// Build a new `MobileBertForSequenceClassification`
675    ///
676    /// # Arguments
677    ///
678    /// * `p` - Variable store path for the root of the MobileBERT model
679    /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
680    ///
681    /// # Example
682    ///
683    /// ```no_run
684    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForSequenceClassification};
685    /// use rust_bert::Config;
686    /// use std::path::Path;
687    /// use tch::{nn, Device};
688    ///
689    /// let config_path = Path::new("path/to/config.json");
690    /// let device = Device::Cpu;
691    /// let p = nn::VarStore::new(device);
692    /// let config = MobileBertConfig::from_file(config_path);
693    /// let mobilebert = MobileBertForSequenceClassification::new(&p.root(), &config).unwrap();
694    /// ```
695    pub fn new<'p, P>(
696        p: P,
697        config: &MobileBertConfig,
698    ) -> Result<MobileBertForSequenceClassification, RustBertError>
699    where
700        P: Borrow<nn::Path<'p>>,
701    {
702        let p = p.borrow();
703
704        let mobilebert = MobileBertModel::new(p / "mobilebert", config, true);
705        let dropout = Dropout::new(config.hidden_dropout_prob);
706        let num_labels = config
707            .id2label
708            .as_ref()
709            .ok_or_else(|| {
710                RustBertError::InvalidConfigurationError(
711                    "num_labels not provided in configuration".to_string(),
712                )
713            })?
714            .len() as i64;
715        let classifier = nn::linear(
716            p / "classifier",
717            config.hidden_size,
718            num_labels,
719            Default::default(),
720        );
721        Ok(MobileBertForSequenceClassification {
722            mobilebert,
723            dropout,
724            classifier,
725        })
726    }
727
728    /// Forward pass through the model
729    ///
730    /// # Arguments
731    ///
732    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
733    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
734    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
735    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
736    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
737    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
738    ///
739    /// # Returns
740    ///
741    /// * `MobileBertSequenceClassificationOutput` containing:
742    ///   - `logits` - `Tensor` of shape (*batch size*, *num_classes*)
743    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
744    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
745    ///
746    /// # Example
747    ///
748    /// ```no_run
749    /// # use tch::{nn, Device, Tensor, no_grad};
750    /// # use rust_bert::Config;
751    /// # use std::path::Path;
752    /// # use tch::kind::Kind::Int64;
753    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForSequenceClassification};
754    /// # let config_path = Path::new("path/to/config.json");
755    /// # let device = Device::Cpu;
756    /// # let vs = nn::VarStore::new(device);
757    /// # let config = MobileBertConfig::from_file(config_path);
758    /// let model = MobileBertForSequenceClassification::new(&vs.root(), &config).unwrap();
759    /// let (batch_size, sequence_length) = (64, 128);
760    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
761    /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
762    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
763    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
764    ///     .expand(&[batch_size, sequence_length], true);
765    ///
766    /// let model_output = no_grad(|| {
767    ///     model
768    ///         .forward_t(
769    ///             Some(&input_tensor),
770    ///             Some(&token_type_ids),
771    ///             Some(&position_ids),
772    ///             None,
773    ///             Some(&attention_mask),
774    ///             false,
775    ///         )
776    ///         .unwrap()
777    /// });
778    /// ```
779    pub fn forward_t(
780        &self,
781        input_ids: Option<&Tensor>,
782        token_type_ids: Option<&Tensor>,
783        position_ids: Option<&Tensor>,
784        input_embeds: Option<&Tensor>,
785        attention_mask: Option<&Tensor>,
786        train: bool,
787    ) -> Result<MobileBertSequenceClassificationOutput, RustBertError> {
788        let mobilebert_output = self.mobilebert.forward_t(
789            input_ids,
790            token_type_ids,
791            position_ids,
792            input_embeds,
793            attention_mask,
794            train,
795        )?;
796
797        let logits = mobilebert_output
798            .pooled_output
799            .unwrap()
800            .apply_t(&self.dropout, train)
801            .apply(&self.classifier);
802
803        Ok(MobileBertSequenceClassificationOutput {
804            logits,
805            all_hidden_states: mobilebert_output.all_hidden_states,
806            all_attentions: mobilebert_output.all_attentions,
807        })
808    }
809}
810
811/// # MobileBERT for question answering
812/// Extractive question-answering model based on a MobileBERT language model. Identifies the segment of a context that answers a provided question.
813/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
814/// See the question answering pipeline (also provided in this crate) for more details.
815/// It is made of the following blocks:
816/// - `mobilebert`: Base MobileBertModel
817/// - `qa_outputs`: Linear layer for question answering
818pub struct MobileBertForQuestionAnswering {
819    mobilebert: MobileBertModel,
820    qa_outputs: nn::Linear,
821}
822
823impl MobileBertForQuestionAnswering {
824    /// Build a new `MobileBertForQuestionAnswering`
825    ///
826    /// # Arguments
827    ///
828    /// * `p` - Variable store path for the root of the MobileBERT model
829    /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
830    ///
831    /// # Example
832    ///
833    /// ```no_run
834    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForQuestionAnswering};
835    /// use rust_bert::Config;
836    /// use std::path::Path;
837    /// use tch::{nn, Device};
838    ///
839    /// let config_path = Path::new("path/to/config.json");
840    /// let device = Device::Cpu;
841    /// let p = nn::VarStore::new(device);
842    /// let config = MobileBertConfig::from_file(config_path);
843    /// let mobilebert = MobileBertForQuestionAnswering::new(&p.root(), &config);
844    /// ```
845    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForQuestionAnswering
846    where
847        P: Borrow<nn::Path<'p>>,
848    {
849        let p = p.borrow();
850
851        let mobilebert = MobileBertModel::new(p / "mobilebert", config, false);
852        let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, 2, Default::default());
853
854        MobileBertForQuestionAnswering {
855            mobilebert,
856            qa_outputs,
857        }
858    }
859
860    /// Forward pass through the model
861    ///
862    /// # Arguments
863    ///
864    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
865    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
866    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
867    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
868    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
869    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
870    ///
871    /// # Returns
872    ///
873    /// * `MobileBertQuestionAnsweringOutput` containing:
874    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
875    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
876    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
877    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
878    ///
879    /// # Example
880    ///
881    /// ```no_run
882    /// # use tch::{nn, Device, Tensor, no_grad};
883    /// # use rust_bert::Config;
884    /// # use std::path::Path;
885    /// # use tch::kind::Kind::Int64;
886    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForQuestionAnswering};
887    /// # let config_path = Path::new("path/to/config.json");
888    /// # let device = Device::Cpu;
889    /// # let vs = nn::VarStore::new(device);
890    /// # let config = MobileBertConfig::from_file(config_path);
891    /// let model = MobileBertForQuestionAnswering::new(&vs.root(), &config);
892    /// let (batch_size, sequence_length) = (64, 128);
893    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
894    /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
895    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
896    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
897    ///     .expand(&[batch_size, sequence_length], true);
898    ///
899    /// let model_output = no_grad(|| {
900    ///     model
901    ///         .forward_t(
902    ///             Some(&input_tensor),
903    ///             Some(&token_type_ids),
904    ///             Some(&position_ids),
905    ///             None,
906    ///             Some(&attention_mask),
907    ///             false,
908    ///         )
909    ///         .unwrap()
910    /// });
911    /// ```
912    pub fn forward_t(
913        &self,
914        input_ids: Option<&Tensor>,
915        token_type_ids: Option<&Tensor>,
916        position_ids: Option<&Tensor>,
917        input_embeds: Option<&Tensor>,
918        attention_mask: Option<&Tensor>,
919        train: bool,
920    ) -> Result<MobileBertQuestionAnsweringOutput, RustBertError> {
921        let mobilebert_output = self.mobilebert.forward_t(
922            input_ids,
923            token_type_ids,
924            position_ids,
925            input_embeds,
926            attention_mask,
927            train,
928        )?;
929
930        let sequence_output = mobilebert_output.hidden_state.apply(&self.qa_outputs);
931        let logits = sequence_output.split(1, -1);
932        let (start_logits, end_logits) = (&logits[0], &logits[1]);
933        let start_logits = start_logits.squeeze_dim(-1);
934        let end_logits = end_logits.squeeze_dim(-1);
935
936        Ok(MobileBertQuestionAnsweringOutput {
937            start_logits,
938            end_logits,
939            all_hidden_states: mobilebert_output.all_hidden_states,
940            all_attentions: mobilebert_output.all_attentions,
941        })
942    }
943}
944
945/// # MobileBERT for multiple choices
946/// Multiple choices model using a MobileBERT base model and a linear classifier.
947/// Input should be in the form `[CLS] Context [SEP] Possible choice [SEP]`. The choice is made along the batch axis,
948/// assuming all elements of the batch are alternatives to be chosen from for a given context.
949/// It is made of the following blocks:
950/// - `mobilebert`: Base MobileBertModel
951/// - `dropout`: Dropout layer before the last start/end logits prediction
952/// - `classifier`: Linear layer for multiple choices
953pub struct MobileBertForMultipleChoice {
954    mobilebert: MobileBertModel,
955    dropout: Dropout,
956    classifier: nn::Linear,
957}
958
959impl MobileBertForMultipleChoice {
960    /// Build a new `MobileBertForMultipleChoice`
961    ///
962    /// # Arguments
963    ///
964    /// * `p` - Variable store path for the root of the MobileBERT model
965    /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
966    ///
967    /// # Example
968    ///
969    /// ```no_run
970    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForMultipleChoice};
971    /// use rust_bert::Config;
972    /// use std::path::Path;
973    /// use tch::{nn, Device};
974    ///
975    /// let config_path = Path::new("path/to/config.json");
976    /// let device = Device::Cpu;
977    /// let p = nn::VarStore::new(device);
978    /// let config = MobileBertConfig::from_file(config_path);
979    /// let mobilebert = MobileBertForMultipleChoice::new(&p.root(), &config);
980    /// ```
981    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForMultipleChoice
982    where
983        P: Borrow<nn::Path<'p>>,
984    {
985        let p = p.borrow();
986
987        let mobilebert = MobileBertModel::new(p / "mobilebert", config, true);
988        let dropout = Dropout::new(config.hidden_dropout_prob);
989        let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
990        MobileBertForMultipleChoice {
991            mobilebert,
992            dropout,
993            classifier,
994        }
995    }
996
997    /// Forward pass through the model
998    ///
999    /// # Arguments
1000    ///
1001    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
1002    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1003    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1004    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1005    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
1006    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1007    ///
1008    /// # Returns
1009    ///
1010    /// * `MobileBertSequenceClassificationOutput` containing:
1011    ///   - `logits` - `Tensor` of shape (*1*, *batch_size*) containing the logits for each of the alternatives given
1012    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1013    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1014    ///
1015    /// # Example
1016    ///
1017    /// ```no_run
1018    /// # use tch::{nn, Device, Tensor, no_grad};
1019    /// # use rust_bert::Config;
1020    /// # use std::path::Path;
1021    /// # use tch::kind::Kind::Int64;
1022    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForMultipleChoice};
1023    /// # let config_path = Path::new("path/to/config.json");
1024    /// # let device = Device::Cpu;
1025    /// # let vs = nn::VarStore::new(device);
1026    /// # let config = MobileBertConfig::from_file(config_path);
1027    /// let model = MobileBertForMultipleChoice::new(&vs.root(), &config);
1028    /// let (batch_size, sequence_length) = (64, 128);
1029    /// let (num_choices, sequence_length) = (3, 128);
1030    /// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
1031    /// let attention_mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
1032    /// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
1033    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1034    ///     .expand(&[num_choices, sequence_length], true);
1035    ///
1036    /// let model_output = no_grad(|| {
1037    ///     model
1038    ///         .forward_t(
1039    ///             Some(&input_tensor),
1040    ///             Some(&token_type_ids),
1041    ///             Some(&position_ids),
1042    ///             None,
1043    ///             Some(&attention_mask),
1044    ///             false,
1045    ///         )
1046    ///         .unwrap()
1047    /// });
1048    /// ```
1049    pub fn forward_t(
1050        &self,
1051        input_ids: Option<&Tensor>,
1052        token_type_ids: Option<&Tensor>,
1053        position_ids: Option<&Tensor>,
1054        input_embeds: Option<&Tensor>,
1055        attention_mask: Option<&Tensor>,
1056        train: bool,
1057    ) -> Result<MobileBertSequenceClassificationOutput, RustBertError> {
1058        let (input_ids, num_choices) = match input_ids {
1059            Some(value) => (
1060                Some(value.view((-1, *value.size().last().unwrap()))),
1061                value.size()[1],
1062            ),
1063            None => (
1064                None,
1065                input_embeds
1066                    .as_ref()
1067                    .expect("At least one of input ids or input_embeds must be provided")
1068                    .size()[1],
1069            ),
1070        };
1071        let attention_mask =
1072            attention_mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1073        let token_type_ids =
1074            token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1075        let input_embeds =
1076            input_embeds.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1077        let position_ids =
1078            position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1079
1080        let mobilebert_output = self.mobilebert.forward_t(
1081            input_ids.as_ref(),
1082            token_type_ids.as_ref(),
1083            position_ids.as_ref(),
1084            input_embeds.as_ref(),
1085            attention_mask.as_ref(),
1086            train,
1087        )?;
1088
1089        let logits = mobilebert_output
1090            .pooled_output
1091            .unwrap()
1092            .apply_t(&self.dropout, train)
1093            .apply(&self.classifier)
1094            .view([-1, num_choices]);
1095
1096        Ok(MobileBertSequenceClassificationOutput {
1097            logits,
1098            all_hidden_states: mobilebert_output.all_hidden_states,
1099            all_attentions: mobilebert_output.all_attentions,
1100        })
1101    }
1102}
1103
1104/// # MobileBERT for token classification (e.g. NER, POS)
1105/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
1106/// not necessarily aligned with words in the sentence.
1107/// It is made of the following blocks:
1108/// - `mobilebert`: Base MobileBertModel
1109/// - `dropout`: Dropout layer before the last token-level predictions layer
1110/// - `classifier`: Linear layer for token classification
1111pub struct MobileBertForTokenClassification {
1112    mobilebert: MobileBertModel,
1113    dropout: Dropout,
1114    classifier: nn::Linear,
1115}
1116
1117impl MobileBertForTokenClassification {
1118    /// Build a new `MobileBertForMultipleChoice`
1119    ///
1120    /// # Arguments
1121    ///
1122    /// * `p` - Variable store path for the root of the MobileBERT model
1123    /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
1124    ///
1125    /// # Example
1126    ///
1127    /// ```no_run
1128    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForTokenClassification};
1129    /// use rust_bert::Config;
1130    /// use std::path::Path;
1131    /// use tch::{nn, Device};
1132    ///
1133    /// let config_path = Path::new("path/to/config.json");
1134    /// let device = Device::Cpu;
1135    /// let p = nn::VarStore::new(device);
1136    /// let config = MobileBertConfig::from_file(config_path);
1137    /// let mobilebert = MobileBertForTokenClassification::new(&p.root(), &config).unwrap();
1138    /// ```
1139    pub fn new<'p, P>(
1140        p: P,
1141        config: &MobileBertConfig,
1142    ) -> Result<MobileBertForTokenClassification, RustBertError>
1143    where
1144        P: Borrow<nn::Path<'p>>,
1145    {
1146        let p = p.borrow();
1147
1148        let mobilebert = MobileBertModel::new(p / "mobilebert", config, false);
1149        let dropout = Dropout::new(config.hidden_dropout_prob);
1150        let num_labels = config
1151            .id2label
1152            .as_ref()
1153            .ok_or_else(|| {
1154                RustBertError::InvalidConfigurationError(
1155                    "num_labels not provided in configuration".to_string(),
1156                )
1157            })?
1158            .len() as i64;
1159        let classifier = nn::linear(
1160            p / "classifier",
1161            config.hidden_size,
1162            num_labels,
1163            Default::default(),
1164        );
1165
1166        Ok(MobileBertForTokenClassification {
1167            mobilebert,
1168            dropout,
1169            classifier,
1170        })
1171    }
1172
1173    /// Forward pass through the model
1174    ///
1175    /// # Arguments
1176    ///
1177    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
1178    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1179    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1180    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1181    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
1182    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1183    ///
1184    /// # Returns
1185    ///
1186    /// * `MobileBertTokenClassificationOutput` containing:
1187    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
1188    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1189    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1190    ///
1191    /// # Example
1192    ///
1193    /// ```no_run
1194    /// # use tch::{nn, Device, Tensor, no_grad};
1195    /// # use rust_bert::Config;
1196    /// # use std::path::Path;
1197    /// # use tch::kind::Kind::Int64;
1198    /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForTokenClassification};
1199    /// # let config_path = Path::new("path/to/config.json");
1200    /// # let device = Device::Cpu;
1201    /// # let vs = nn::VarStore::new(device);
1202    /// # let config = MobileBertConfig::from_file(config_path);
1203    /// let model = MobileBertForTokenClassification::new(&vs.root(), &config).unwrap();
1204    /// let (batch_size, sequence_length) = (64, 128);
1205    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1206    /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1207    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1208    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1209    ///     .expand(&[batch_size, sequence_length], true);
1210    ///
1211    /// let model_output = no_grad(|| {
1212    ///     model
1213    ///         .forward_t(
1214    ///             Some(&input_tensor),
1215    ///             Some(&token_type_ids),
1216    ///             Some(&position_ids),
1217    ///             None,
1218    ///             Some(&attention_mask),
1219    ///             false,
1220    ///         )
1221    ///         .unwrap()
1222    /// });
1223    /// ```
1224    pub fn forward_t(
1225        &self,
1226        input_ids: Option<&Tensor>,
1227        token_type_ids: Option<&Tensor>,
1228        position_ids: Option<&Tensor>,
1229        input_embeds: Option<&Tensor>,
1230        attention_mask: Option<&Tensor>,
1231        train: bool,
1232    ) -> Result<MobileBertTokenClassificationOutput, RustBertError> {
1233        let mobilebert_output = self.mobilebert.forward_t(
1234            input_ids,
1235            token_type_ids,
1236            position_ids,
1237            input_embeds,
1238            attention_mask,
1239            train,
1240        )?;
1241
1242        let logits = mobilebert_output
1243            .hidden_state
1244            .apply_t(&self.dropout, train)
1245            .apply(&self.classifier);
1246
1247        Ok(MobileBertTokenClassificationOutput {
1248            logits,
1249            all_hidden_states: mobilebert_output.all_hidden_states,
1250            all_attentions: mobilebert_output.all_attentions,
1251        })
1252    }
1253}
1254
1255/// Container for the MobileBert output.
1256pub struct MobileBertOutput {
1257    /// Last hidden states from the model
1258    pub hidden_state: Tensor,
1259    /// Pooled output
1260    pub pooled_output: Option<Tensor>,
1261    /// Hidden states for all intermediate layers
1262    pub all_hidden_states: Option<Vec<Tensor>>,
1263    /// Attention weights for all intermediate layers
1264    pub all_attentions: Option<Vec<Tensor>>,
1265}
1266
1267/// Container for the MobileBert masked LM model output.
1268pub struct MobileBertMaskedLMOutput {
1269    /// Logits for the vocabulary items at each sequence position
1270    pub logits: Tensor,
1271    /// Hidden states for all intermediate layers
1272    pub all_hidden_states: Option<Vec<Tensor>>,
1273    /// Attention weights for all intermediate layers
1274    pub all_attentions: Option<Vec<Tensor>>,
1275}
1276
1277/// Container for the MobileBert sequence classification model output.
1278pub struct MobileBertSequenceClassificationOutput {
1279    /// Logits for each input (sequence) for each target class
1280    pub logits: Tensor,
1281    /// Hidden states for all intermediate layers
1282    pub all_hidden_states: Option<Vec<Tensor>>,
1283    /// Attention weights for all intermediate layers
1284    pub all_attentions: Option<Vec<Tensor>>,
1285}
1286
1287/// Container for the MobileBert token classification model output.
1288pub struct MobileBertTokenClassificationOutput {
1289    /// Logits for each sequence item (token) for each target class
1290    pub logits: Tensor,
1291    /// Hidden states for all intermediate layers
1292    pub all_hidden_states: Option<Vec<Tensor>>,
1293    /// Attention weights for all intermediate layers
1294    pub all_attentions: Option<Vec<Tensor>>,
1295}
1296
1297/// Container for the MobileBert question answering model output.
1298pub struct MobileBertQuestionAnsweringOutput {
1299    /// Logits for the start position for token of each input sequence
1300    pub start_logits: Tensor,
1301    /// Logits for the end position for token of each input sequence
1302    pub end_logits: Tensor,
1303    /// Hidden states for all intermediate layers
1304    pub all_hidden_states: Option<Vec<Tensor>>,
1305    /// Attention weights for all intermediate layers
1306    pub all_attentions: Option<Vec<Tensor>>,
1307}