rust_bert/models/mobilebert/mobilebert_model.rs
1// Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
2// Copyright 2020 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6// http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::common::activations::{Activation, TensorFunction};
14use crate::common::dropout::Dropout;
15use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
16use crate::mobilebert::embeddings::MobileBertEmbeddings;
17use crate::mobilebert::encoder::{MobileBertEncoder, MobileBertPooler};
18use crate::{Config, RustBertError};
19use serde::{Deserialize, Serialize};
20use std::borrow::Borrow;
21use std::collections::HashMap;
22use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
23use tch::nn::{Init, LayerNormConfig, Module};
24use tch::{nn, Kind, Tensor};
25
/// # MobileBERT Pretrained model weight files
///
/// Field-less marker type; its associated constants list the available weight files.
pub struct MobileBertModelResources;
28
/// # MobileBERT Pretrained model config files
///
/// Field-less marker type; its associated constants list the available configuration files.
pub struct MobileBertConfigResources;
31
/// # MobileBERT Pretrained model vocab files
///
/// Field-less marker type; its associated constants list the available vocabulary files.
pub struct MobileBertVocabResources;
34
35impl MobileBertModelResources {
36 /// Shared under Apache 2.0 license by the Google team at <https://huggingface.co/google/mobilebert-uncased>. Modified with conversion to C-array format.
37 pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
38 "mobilebert-uncased/model",
39 "https://huggingface.co/google/mobilebert-uncased/resolve/main/rust_model.ot",
40 );
41 /// Shared under MIT license at <https://huggingface.co/mrm8488/mobilebert-finetuned-pos>. Modified with conversion to C-array format.
42 pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
43 "mobilebert-finetuned-pos/model",
44 "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/rust_model.ot",
45 );
46}
47
48impl MobileBertConfigResources {
49 /// Shared under Apache 2.0 license by the Google team at <https://huggingface.co/google/mobilebert-uncased>. Modified with conversion to C-array format.
50 pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
51 "mobilebert-uncased/config",
52 "https://huggingface.co/google/mobilebert-uncased/resolve/main/config.json",
53 );
54 /// Shared under MIT license at <https://huggingface.co/mrm8488/mobilebert-finetuned-pos>. Modified with conversion to C-array format.
55 pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
56 "mobilebert-finetuned-pos/config",
57 "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/config.json",
58 );
59}
60
61impl MobileBertVocabResources {
62 /// Shared under Apache 2.0 license by the Google team at <https://huggingface.co/google/mobilebert-uncased>. Modified with conversion to C-array format.
63 pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
64 "mobilebert-uncased/vocab",
65 "https://huggingface.co/google/mobilebert-uncased/resolve/main/vocab.txt",
66 );
67 /// Shared under MIT license at <https://huggingface.co/mrm8488/mobilebert-finetuned-pos>. Modified with conversion to C-array format.
68 pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
69 "mobilebert-finetuned-pos/vocab",
70 "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/vocab.txt",
71 );
72}
73
74#[allow(non_camel_case_types)]
75#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
76/// # Normalization type to use for the MobileBERT model.
77/// `no_norm` uses a matrix multiplication with a set of learned weights, while `layer_norm` uses a
78/// build-in layer normalization module.
79pub enum NormalizationType {
80 layer_norm,
81 no_norm,
82}
83
84#[derive(Debug)]
85/// # No-normalization option for MobileBERT
86/// Basic module performing a linear multiplication using trained coefficients and bias
87pub struct NoNorm {
88 weight: Tensor,
89 bias: Tensor,
90}
91
92impl NoNorm {
93 /// Creates a new `NoNorm` layer of given hidden size.
94 ///
95 /// # Arguments:
96 ///
97 /// * hidden_size - input tensor's hidden size
98 ///
99 /// # Example
100 ///
101 /// ```no_run
102 /// use rust_bert::mobilebert::NoNorm;
103 /// use tch::{nn, Device};
104 /// let device = Device::Cpu;
105 /// let p = nn::VarStore::new(device);
106 /// let hidden_size = 512;
107 /// let no_norm = NoNorm::new(&p.root(), hidden_size);
108 /// ```
109 pub fn new<'p, P>(p: P, hidden_size: i64) -> NoNorm
110 where
111 P: Borrow<nn::Path<'p>>,
112 {
113 let p = p.borrow();
114 let weight = p.var("weight", &[hidden_size], Init::Const(1.0));
115 let bias = p.var("bias", &[hidden_size], Init::Const(0.0));
116 NoNorm { weight, bias }
117 }
118}
119
120impl Module for NoNorm {
121 fn forward(&self, xs: &Tensor) -> Tensor {
122 xs * &self.weight + &self.bias
123 }
124}
125
126pub enum NormalizationLayer {
127 LayerNorm(nn::LayerNorm),
128 NoNorm(NoNorm),
129}
130
131impl NormalizationLayer {
132 pub fn new<'p, P>(
133 p: P,
134 normalization_type: NormalizationType,
135 hidden_size: i64,
136 eps: Option<f64>,
137 ) -> NormalizationLayer
138 where
139 P: Borrow<nn::Path<'p>>,
140 {
141 let p = p.borrow();
142 match normalization_type {
143 NormalizationType::layer_norm => {
144 let layer_norm_config = LayerNormConfig {
145 eps: eps.unwrap_or(1e-12),
146 ..Default::default()
147 };
148 let layer_norm = nn::layer_norm(p, vec![hidden_size], layer_norm_config);
149 NormalizationLayer::LayerNorm(layer_norm)
150 }
151 NormalizationType::no_norm => {
152 let layer_norm = NoNorm::new(p, hidden_size);
153 NormalizationLayer::NoNorm(layer_norm)
154 }
155 }
156 }
157
158 pub fn forward(&self, input: &Tensor) -> Tensor {
159 match self {
160 NormalizationLayer::LayerNorm(ref layer_norm) => input.apply(layer_norm),
161 NormalizationLayer::NoNorm(ref layer_norm) => input.apply(layer_norm),
162 }
163 }
164}
165
166#[derive(Debug, Serialize, Deserialize)]
167/// # MobileBERT model configuration
168/// Defines the MobileBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
169pub struct MobileBertConfig {
170 pub hidden_act: Activation,
171 pub attention_probs_dropout_prob: f64,
172 pub hidden_dropout_prob: f64,
173 pub hidden_size: i64,
174 pub initializer_range: f64,
175 pub intermediate_size: i64,
176 pub max_position_embeddings: i64,
177 pub num_attention_heads: i64,
178 pub num_hidden_layers: i64,
179 pub type_vocab_size: i64,
180 pub vocab_size: i64,
181 pub embedding_size: i64,
182 pub layer_norm_eps: Option<f64>,
183 pub pad_token_idx: Option<i64>,
184 pub trigram_input: Option<bool>,
185 pub use_bottleneck: Option<bool>,
186 pub use_bottleneck_attention: Option<bool>,
187 pub intra_bottleneck_size: Option<i64>,
188 pub key_query_shared_bottleneck: Option<bool>,
189 pub num_feedforward_networks: Option<i64>,
190 pub normalization_type: Option<NormalizationType>,
191 pub output_attentions: Option<bool>,
192 pub output_hidden_states: Option<bool>,
193 pub classifier_activation: Option<bool>,
194 pub is_decoder: Option<bool>,
195 pub id2label: Option<HashMap<i64, String>>,
196 pub label2id: Option<HashMap<String, i64>>,
197}
198
199impl Config for MobileBertConfig {}
200
201impl Default for MobileBertConfig {
202 fn default() -> Self {
203 MobileBertConfig {
204 hidden_act: Activation::relu,
205 attention_probs_dropout_prob: 0.1,
206 hidden_dropout_prob: 0.0,
207 hidden_size: 512,
208 initializer_range: 0.02,
209 intermediate_size: 512,
210 max_position_embeddings: 512,
211 num_attention_heads: 4,
212 num_hidden_layers: 24,
213 type_vocab_size: 2,
214 vocab_size: 30522,
215 embedding_size: 128,
216 layer_norm_eps: Some(1e-12),
217 pad_token_idx: Some(0),
218 trigram_input: Some(true),
219 use_bottleneck: Some(true),
220 use_bottleneck_attention: Some(false),
221 intra_bottleneck_size: Some(128),
222 key_query_shared_bottleneck: Some(true),
223 num_feedforward_networks: Some(4),
224 normalization_type: Some(NormalizationType::no_norm),
225 output_attentions: None,
226 output_hidden_states: None,
227 classifier_activation: None,
228 is_decoder: None,
229 id2label: None,
230 label2id: None,
231 }
232 }
233}
234
235pub struct MobileBertPredictionHeadTransform {
236 dense: nn::Linear,
237 activation_function: TensorFunction,
238 layer_norm: NormalizationLayer,
239}
240
241impl MobileBertPredictionHeadTransform {
242 pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertPredictionHeadTransform
243 where
244 P: Borrow<nn::Path<'p>>,
245 {
246 let p = p.borrow();
247 let dense = nn::linear(
248 p / "dense",
249 config.hidden_size,
250 config.hidden_size,
251 Default::default(),
252 );
253 let activation_function = config.hidden_act.get_function();
254 let layer_norm = NormalizationLayer::new(
255 p / "LayerNorm",
256 NormalizationType::layer_norm,
257 config.hidden_size,
258 config.layer_norm_eps,
259 );
260 MobileBertPredictionHeadTransform {
261 dense,
262 activation_function,
263 layer_norm,
264 }
265 }
266
267 pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
268 let hidden_states = hidden_states.apply(&self.dense);
269 let hidden_states = self.activation_function.get_fn()(&hidden_states);
270 self.layer_norm.forward(&hidden_states)
271 }
272}
273
274pub struct MobileBertLMPredictionHead {
275 transform: MobileBertPredictionHeadTransform,
276 dense_weight: Tensor,
277 bias: Tensor,
278}
279
280impl MobileBertLMPredictionHead {
281 pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertLMPredictionHead
282 where
283 P: Borrow<nn::Path<'p>>,
284 {
285 let p = p.borrow();
286
287 let transform = MobileBertPredictionHeadTransform::new(p / "transform", config);
288
289 let dense_p = p / "dense";
290 let dense_weight = dense_p.var(
291 "weight",
292 &[
293 config.hidden_size - config.embedding_size,
294 config.vocab_size,
295 ],
296 DEFAULT_KAIMING_UNIFORM,
297 );
298 let bias = p.var("bias", &[config.vocab_size], Init::Const(0.0));
299
300 MobileBertLMPredictionHead {
301 transform,
302 dense_weight,
303 bias,
304 }
305 }
306
307 pub fn forward(&self, hidden_states: &Tensor, embeddings: &Tensor) -> Tensor {
308 let hidden_states = self.transform.forward(hidden_states);
309 let hidden_states = hidden_states.matmul(&Tensor::cat(
310 &[&embeddings.transpose(0, 1), &self.dense_weight],
311 0,
312 ));
313 hidden_states + &self.bias
314 }
315}
316
317pub struct MobileBertOnlyMLMHead {
318 predictions: MobileBertLMPredictionHead,
319}
320
321impl MobileBertOnlyMLMHead {
322 pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertOnlyMLMHead
323 where
324 P: Borrow<nn::Path<'p>>,
325 {
326 let p = p.borrow();
327
328 let predictions = MobileBertLMPredictionHead::new(p / "predictions", config);
329 MobileBertOnlyMLMHead { predictions }
330 }
331
332 pub fn forward(&self, hidden_states: &Tensor, embeddings: &Tensor) -> Tensor {
333 self.predictions.forward(hidden_states, embeddings)
334 }
335}
336
337/// # MobileBertModel Base model
338/// Base architecture for MobileBERT models. Task-specific models will be built from this common base model
339/// It is made of the following blocks:
340/// - `embeddings`: Word, token type and position embeddings
341/// - `encoder`: `MobileBertEncoder` made of a stack of `MobileBertLayer`
342/// - `pooler`: Optional `MobileBertPooler` taking the first sequence element hidden state for sequence-level tasks
343/// - `position_ids` preset position ids tensor used in case they are not provided by the user
344pub struct MobileBertModel {
345 embeddings: MobileBertEmbeddings,
346 encoder: MobileBertEncoder,
347 pooler: Option<MobileBertPooler>,
348 position_ids: Tensor,
349}
350
351impl MobileBertModel {
352 /// Build a new `MobileBertModel`
353 ///
354 /// # Arguments
355 ///
356 /// * `p` - Variable store path for the root of the MobileBERT model
357 /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
358 /// * `add_pooling_layer` - boolean flag indicating if a pooling layer should be added after the encoder
359 ///
360 /// # Example
361 ///
362 /// ```no_run
363 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertModel};
364 /// use rust_bert::Config;
365 /// use std::path::Path;
366 /// use tch::{nn, Device};
367 ///
368 /// let config_path = Path::new("path/to/config.json");
369 /// let device = Device::Cpu;
370 /// let p = nn::VarStore::new(device);
371 /// let config = MobileBertConfig::from_file(config_path);
372 /// let add_pooling_layer = true;
373 /// let mobilebert = MobileBertModel::new(&p.root() / "mobilebert", &config, add_pooling_layer);
374 /// ```
375 pub fn new<'p, P>(p: P, config: &MobileBertConfig, add_pooling_layer: bool) -> MobileBertModel
376 where
377 P: Borrow<nn::Path<'p>>,
378 {
379 let p = p.borrow();
380
381 let embeddings = MobileBertEmbeddings::new(p / "embeddings", config);
382 let encoder = MobileBertEncoder::new(p / "encoder", config);
383 let pooler = if add_pooling_layer {
384 Some(MobileBertPooler::new(p / "pooler", config))
385 } else {
386 None
387 };
388 let position_ids =
389 Tensor::arange(config.max_position_embeddings, (Kind::Int64, p.device()))
390 .expand([1, -1], true);
391 MobileBertModel {
392 embeddings,
393 encoder,
394 pooler,
395 position_ids,
396 }
397 }
398
399 /// Forward pass through the model
400 ///
401 /// # Arguments
402 ///
403 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
404 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
405 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
406 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
407 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
408 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
409 ///
410 /// # Returns
411 ///
412 /// * `MobileBertOutput` containing:
413 /// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
414 /// - `pooled_output` - Optional `Tensor` of shape (*batch size*, *hidden_size*) if the model was created with an optional pooling layer
415 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
416 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
417 ///
418 /// # Example
419 ///
420 /// ```no_run
421 /// # use tch::{nn, Device, Tensor, no_grad};
422 /// # use rust_bert::Config;
423 /// # use std::path::Path;
424 /// # use tch::kind::Kind::Int64;
425 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertModel};
426 /// # let config_path = Path::new("path/to/config.json");
427 /// # let device = Device::Cpu;
428 /// # let vs = nn::VarStore::new(device);
429 /// # let config = MobileBertConfig::from_file(config_path);
430 /// let add_pooling_layer = true;
431 /// let model = MobileBertModel::new(&vs.root(), &config, add_pooling_layer);
432 /// let (batch_size, sequence_length) = (64, 128);
433 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
434 /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
435 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
436 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
437 /// .expand(&[batch_size, sequence_length], true);
438 ///
439 /// let model_output = no_grad(|| {
440 /// model
441 /// .forward_t(
442 /// Some(&input_tensor),
443 /// Some(&token_type_ids),
444 /// Some(&position_ids),
445 /// None,
446 /// Some(&attention_mask),
447 /// false,
448 /// )
449 /// .unwrap()
450 /// });
451 /// ```
452 pub fn forward_t(
453 &self,
454 input_ids: Option<&Tensor>,
455 token_type_ids: Option<&Tensor>,
456 position_ids: Option<&Tensor>,
457 input_embeds: Option<&Tensor>,
458 attention_mask: Option<&Tensor>,
459 train: bool,
460 ) -> Result<MobileBertOutput, RustBertError> {
461 let (input_shape, device) =
462 get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
463
464 let calc_attention_mask = if attention_mask.is_none() {
465 Some(Tensor::ones(input_shape.as_slice(), (Kind::Int64, device)))
466 } else {
467 None
468 };
469
470 let calc_token_type_ids = if token_type_ids.is_none() {
471 Some(Tensor::zeros(input_shape.as_slice(), (Kind::Int64, device)))
472 } else {
473 None
474 };
475
476 let calc_position_ids = if position_ids.is_none() {
477 Some(self.position_ids.slice(1, 0, input_shape[1], 1))
478 } else {
479 None
480 };
481
482 let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
483
484 let attention_mask =
485 attention_mask.unwrap_or_else(|| calc_attention_mask.as_ref().unwrap());
486 let attention_mask = match attention_mask.dim() {
487 3 => attention_mask.unsqueeze(1),
488 2 => attention_mask.unsqueeze(1).unsqueeze(1),
489 _ => {
490 return Err(RustBertError::ValueError(
491 "Invalid attention mask dimension, must be 2 or 3".into(),
492 ));
493 }
494 };
495
496 let token_type_ids =
497 token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
498
499 let embedding_output = self.embeddings.forward_t(
500 input_ids,
501 token_type_ids,
502 position_ids,
503 input_embeds,
504 train,
505 )?;
506
507 let attention_mask: Tensor = ((attention_mask.ones_like() - attention_mask) * -10000.0)
508 .to_kind(embedding_output.kind());
509
510 let encoder_output =
511 self.encoder
512 .forward_t(&embedding_output, Some(&attention_mask), train);
513
514 let pooled_output = if let Some(pooler) = &self.pooler {
515 Some(pooler.forward(&encoder_output.hidden_state))
516 } else {
517 None
518 };
519
520 Ok(MobileBertOutput {
521 hidden_state: encoder_output.hidden_state,
522 pooled_output,
523 all_hidden_states: encoder_output.all_hidden_states,
524 all_attentions: encoder_output.all_attentions,
525 })
526 }
527
528 fn get_embeddings(&self) -> &Tensor {
529 &self.embeddings.word_embeddings.ws
530 }
531}
532
533/// # MobileBERT for masked language model
534/// Base MobileBERT model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
535/// It is made of the following blocks:
536/// - `mobilebert`: Base MobileBertModel
537/// - `classifier`: MobileBERT LM prediction head
538pub struct MobileBertForMaskedLM {
539 mobilebert: MobileBertModel,
540 classifier: MobileBertOnlyMLMHead,
541}
542
543impl MobileBertForMaskedLM {
544 /// Build a new `MobileBertForMaskedLM`
545 ///
546 /// # Arguments
547 ///
548 /// * `p` - Variable store path for the root of the MobileBERT model
549 /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
550 ///
551 /// # Example
552 ///
553 /// ```no_run
554 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForMaskedLM};
555 /// use rust_bert::Config;
556 /// use std::path::Path;
557 /// use tch::{nn, Device};
558 ///
559 /// let config_path = Path::new("path/to/config.json");
560 /// let device = Device::Cpu;
561 /// let p = nn::VarStore::new(device);
562 /// let config = MobileBertConfig::from_file(config_path);
563 /// let mobilebert = MobileBertForMaskedLM::new(&p.root(), &config);
564 /// ```
565 pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForMaskedLM
566 where
567 P: Borrow<nn::Path<'p>>,
568 {
569 let p = p.borrow();
570
571 let mobilebert = MobileBertModel::new(p / "mobilebert", config, false);
572 let classifier = MobileBertOnlyMLMHead::new(p / "cls", config);
573 MobileBertForMaskedLM {
574 mobilebert,
575 classifier,
576 }
577 }
578
579 /// Forward pass through the model
580 ///
581 /// # Arguments
582 ///
583 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
584 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
585 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
586 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
587 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
588 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
589 ///
590 /// # Returns
591 ///
592 /// * `MobileBertMaskedLMOutput` containing:
593 /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
594 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
595 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
596 ///
597 /// # Example
598 ///
599 /// ```no_run
600 /// # use tch::{nn, Device, Tensor, no_grad};
601 /// # use rust_bert::Config;
602 /// # use std::path::Path;
603 /// # use tch::kind::Kind::Int64;
604 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForMaskedLM};
605 /// # let config_path = Path::new("path/to/config.json");
606 /// # let device = Device::Cpu;
607 /// # let vs = nn::VarStore::new(device);
608 /// # let config = MobileBertConfig::from_file(config_path);
609 /// let model = MobileBertForMaskedLM::new(&vs.root(), &config);
610 /// let (batch_size, sequence_length) = (64, 128);
611 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
612 /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
613 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
614 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
615 /// .expand(&[batch_size, sequence_length], true);
616 ///
617 /// let model_output = no_grad(|| {
618 /// model
619 /// .forward_t(
620 /// Some(&input_tensor),
621 /// Some(&token_type_ids),
622 /// Some(&position_ids),
623 /// None,
624 /// Some(&attention_mask),
625 /// false,
626 /// )
627 /// .unwrap()
628 /// });
629 /// ```
630 pub fn forward_t(
631 &self,
632 input_ids: Option<&Tensor>,
633 token_type_ids: Option<&Tensor>,
634 position_ids: Option<&Tensor>,
635 input_embeds: Option<&Tensor>,
636 attention_mask: Option<&Tensor>,
637 train: bool,
638 ) -> Result<MobileBertMaskedLMOutput, RustBertError> {
639 let mobilebert_output = self.mobilebert.forward_t(
640 input_ids,
641 token_type_ids,
642 position_ids,
643 input_embeds,
644 attention_mask,
645 train,
646 )?;
647
648 let logits = self.classifier.forward(
649 &mobilebert_output.hidden_state,
650 self.mobilebert.get_embeddings(),
651 );
652
653 Ok(MobileBertMaskedLMOutput {
654 logits,
655 all_hidden_states: mobilebert_output.all_hidden_states,
656 all_attentions: mobilebert_output.all_attentions,
657 })
658 }
659}
660
661/// # MobileBERT for sequence classification
662/// Base MobileBERT model with a classifier head to perform sentence or document-level classification
663/// It is made of the following blocks:
664/// - `mobilebert`: Base MobileBertModel
665/// - `dropout`: Dropout layer before the last linear layer
666/// - `classifier`: linear layer mapping from hidden to the number of classes to predict
667pub struct MobileBertForSequenceClassification {
668 mobilebert: MobileBertModel,
669 dropout: Dropout,
670 classifier: nn::Linear,
671}
672
673impl MobileBertForSequenceClassification {
674 /// Build a new `MobileBertForSequenceClassification`
675 ///
676 /// # Arguments
677 ///
678 /// * `p` - Variable store path for the root of the MobileBERT model
679 /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
680 ///
681 /// # Example
682 ///
683 /// ```no_run
684 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForSequenceClassification};
685 /// use rust_bert::Config;
686 /// use std::path::Path;
687 /// use tch::{nn, Device};
688 ///
689 /// let config_path = Path::new("path/to/config.json");
690 /// let device = Device::Cpu;
691 /// let p = nn::VarStore::new(device);
692 /// let config = MobileBertConfig::from_file(config_path);
693 /// let mobilebert = MobileBertForSequenceClassification::new(&p.root(), &config).unwrap();
694 /// ```
695 pub fn new<'p, P>(
696 p: P,
697 config: &MobileBertConfig,
698 ) -> Result<MobileBertForSequenceClassification, RustBertError>
699 where
700 P: Borrow<nn::Path<'p>>,
701 {
702 let p = p.borrow();
703
704 let mobilebert = MobileBertModel::new(p / "mobilebert", config, true);
705 let dropout = Dropout::new(config.hidden_dropout_prob);
706 let num_labels = config
707 .id2label
708 .as_ref()
709 .ok_or_else(|| {
710 RustBertError::InvalidConfigurationError(
711 "num_labels not provided in configuration".to_string(),
712 )
713 })?
714 .len() as i64;
715 let classifier = nn::linear(
716 p / "classifier",
717 config.hidden_size,
718 num_labels,
719 Default::default(),
720 );
721 Ok(MobileBertForSequenceClassification {
722 mobilebert,
723 dropout,
724 classifier,
725 })
726 }
727
728 /// Forward pass through the model
729 ///
730 /// # Arguments
731 ///
732 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
733 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
734 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
735 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
736 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
737 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
738 ///
739 /// # Returns
740 ///
741 /// * `MobileBertSequenceClassificationOutput` containing:
742 /// - `logits` - `Tensor` of shape (*batch size*, *num_classes*)
743 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
744 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
745 ///
746 /// # Example
747 ///
748 /// ```no_run
749 /// # use tch::{nn, Device, Tensor, no_grad};
750 /// # use rust_bert::Config;
751 /// # use std::path::Path;
752 /// # use tch::kind::Kind::Int64;
753 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForSequenceClassification};
754 /// # let config_path = Path::new("path/to/config.json");
755 /// # let device = Device::Cpu;
756 /// # let vs = nn::VarStore::new(device);
757 /// # let config = MobileBertConfig::from_file(config_path);
758 /// let model = MobileBertForSequenceClassification::new(&vs.root(), &config).unwrap();
759 /// let (batch_size, sequence_length) = (64, 128);
760 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
761 /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
762 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
763 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
764 /// .expand(&[batch_size, sequence_length], true);
765 ///
766 /// let model_output = no_grad(|| {
767 /// model
768 /// .forward_t(
769 /// Some(&input_tensor),
770 /// Some(&token_type_ids),
771 /// Some(&position_ids),
772 /// None,
773 /// Some(&attention_mask),
774 /// false,
775 /// )
776 /// .unwrap()
777 /// });
778 /// ```
779 pub fn forward_t(
780 &self,
781 input_ids: Option<&Tensor>,
782 token_type_ids: Option<&Tensor>,
783 position_ids: Option<&Tensor>,
784 input_embeds: Option<&Tensor>,
785 attention_mask: Option<&Tensor>,
786 train: bool,
787 ) -> Result<MobileBertSequenceClassificationOutput, RustBertError> {
788 let mobilebert_output = self.mobilebert.forward_t(
789 input_ids,
790 token_type_ids,
791 position_ids,
792 input_embeds,
793 attention_mask,
794 train,
795 )?;
796
797 let logits = mobilebert_output
798 .pooled_output
799 .unwrap()
800 .apply_t(&self.dropout, train)
801 .apply(&self.classifier);
802
803 Ok(MobileBertSequenceClassificationOutput {
804 logits,
805 all_hidden_states: mobilebert_output.all_hidden_states,
806 all_attentions: mobilebert_output.all_attentions,
807 })
808 }
809}
810
811/// # MobileBERT for question answering
812/// Extractive question-answering model based on a MobileBERT language model. Identifies the segment of a context that answers a provided question.
813/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
814/// See the question answering pipeline (also provided in this crate) for more details.
815/// It is made of the following blocks:
816/// - `mobilebert`: Base MobileBertModel
817/// - `qa_outputs`: Linear layer for question answering
818pub struct MobileBertForQuestionAnswering {
819 mobilebert: MobileBertModel,
820 qa_outputs: nn::Linear,
821}
822
823impl MobileBertForQuestionAnswering {
824 /// Build a new `MobileBertForQuestionAnswering`
825 ///
826 /// # Arguments
827 ///
828 /// * `p` - Variable store path for the root of the MobileBERT model
829 /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
830 ///
831 /// # Example
832 ///
833 /// ```no_run
834 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForQuestionAnswering};
835 /// use rust_bert::Config;
836 /// use std::path::Path;
837 /// use tch::{nn, Device};
838 ///
839 /// let config_path = Path::new("path/to/config.json");
840 /// let device = Device::Cpu;
841 /// let p = nn::VarStore::new(device);
842 /// let config = MobileBertConfig::from_file(config_path);
843 /// let mobilebert = MobileBertForQuestionAnswering::new(&p.root(), &config);
844 /// ```
845 pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForQuestionAnswering
846 where
847 P: Borrow<nn::Path<'p>>,
848 {
849 let p = p.borrow();
850
851 let mobilebert = MobileBertModel::new(p / "mobilebert", config, false);
852 let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, 2, Default::default());
853
854 MobileBertForQuestionAnswering {
855 mobilebert,
856 qa_outputs,
857 }
858 }
859
860 /// Forward pass through the model
861 ///
862 /// # Arguments
863 ///
864 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
865 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
866 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
867 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
868 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
869 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
870 ///
871 /// # Returns
872 ///
873 /// * `MobileBertQuestionAnsweringOutput` containing:
874 /// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
875 /// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
876 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
877 /// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
878 ///
879 /// # Example
880 ///
881 /// ```no_run
882 /// # use tch::{nn, Device, Tensor, no_grad};
883 /// # use rust_bert::Config;
884 /// # use std::path::Path;
885 /// # use tch::kind::Kind::Int64;
886 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForQuestionAnswering};
887 /// # let config_path = Path::new("path/to/config.json");
888 /// # let device = Device::Cpu;
889 /// # let vs = nn::VarStore::new(device);
890 /// # let config = MobileBertConfig::from_file(config_path);
891 /// let model = MobileBertForQuestionAnswering::new(&vs.root(), &config);
892 /// let (batch_size, sequence_length) = (64, 128);
893 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
894 /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
895 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
896 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
897 /// .expand(&[batch_size, sequence_length], true);
898 ///
899 /// let model_output = no_grad(|| {
900 /// model
901 /// .forward_t(
902 /// Some(&input_tensor),
903 /// Some(&token_type_ids),
904 /// Some(&position_ids),
905 /// None,
906 /// Some(&attention_mask),
907 /// false,
908 /// )
909 /// .unwrap()
910 /// });
911 /// ```
912 pub fn forward_t(
913 &self,
914 input_ids: Option<&Tensor>,
915 token_type_ids: Option<&Tensor>,
916 position_ids: Option<&Tensor>,
917 input_embeds: Option<&Tensor>,
918 attention_mask: Option<&Tensor>,
919 train: bool,
920 ) -> Result<MobileBertQuestionAnsweringOutput, RustBertError> {
921 let mobilebert_output = self.mobilebert.forward_t(
922 input_ids,
923 token_type_ids,
924 position_ids,
925 input_embeds,
926 attention_mask,
927 train,
928 )?;
929
930 let sequence_output = mobilebert_output.hidden_state.apply(&self.qa_outputs);
931 let logits = sequence_output.split(1, -1);
932 let (start_logits, end_logits) = (&logits[0], &logits[1]);
933 let start_logits = start_logits.squeeze_dim(-1);
934 let end_logits = end_logits.squeeze_dim(-1);
935
936 Ok(MobileBertQuestionAnsweringOutput {
937 start_logits,
938 end_logits,
939 all_hidden_states: mobilebert_output.all_hidden_states,
940 all_attentions: mobilebert_output.all_attentions,
941 })
942 }
943}
944
/// # MobileBERT for multiple choices
/// Multiple choices model using a MobileBERT base model and a linear classifier.
/// Input should be in the form `[CLS] Context [SEP] Possible choice [SEP]`. The alternatives for a given context
/// are stacked along the second dimension of the inputs: the forward pass reads the number of choices from dimension 1.
/// It is made of the following blocks:
/// - `mobilebert`: Base MobileBertModel
/// - `dropout`: Dropout layer applied to the pooled output before the classifier
/// - `classifier`: Linear layer for multiple choices
pub struct MobileBertForMultipleChoice {
    /// Base MobileBERT model; its pooled output feeds the classifier
    mobilebert: MobileBertModel,
    /// Dropout applied to the pooled output
    dropout: Dropout,
    /// Linear layer producing a single score per alternative
    classifier: nn::Linear,
}
958
959impl MobileBertForMultipleChoice {
960 /// Build a new `MobileBertForMultipleChoice`
961 ///
962 /// # Arguments
963 ///
964 /// * `p` - Variable store path for the root of the MobileBERT model
965 /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
966 ///
967 /// # Example
968 ///
969 /// ```no_run
970 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForMultipleChoice};
971 /// use rust_bert::Config;
972 /// use std::path::Path;
973 /// use tch::{nn, Device};
974 ///
975 /// let config_path = Path::new("path/to/config.json");
976 /// let device = Device::Cpu;
977 /// let p = nn::VarStore::new(device);
978 /// let config = MobileBertConfig::from_file(config_path);
979 /// let mobilebert = MobileBertForMultipleChoice::new(&p.root(), &config);
980 /// ```
981 pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForMultipleChoice
982 where
983 P: Borrow<nn::Path<'p>>,
984 {
985 let p = p.borrow();
986
987 let mobilebert = MobileBertModel::new(p / "mobilebert", config, true);
988 let dropout = Dropout::new(config.hidden_dropout_prob);
989 let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
990 MobileBertForMultipleChoice {
991 mobilebert,
992 dropout,
993 classifier,
994 }
995 }
996
997 /// Forward pass through the model
998 ///
999 /// # Arguments
1000 ///
1001 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
1002 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1003 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1004 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1005 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
1006 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1007 ///
1008 /// # Returns
1009 ///
1010 /// * `MobileBertSequenceClassificationOutput` containing:
1011 /// - `logits` - `Tensor` of shape (*1*, *batch_size*) containing the logits for each of the alternatives given
1012 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1013 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1014 ///
1015 /// # Example
1016 ///
1017 /// ```no_run
1018 /// # use tch::{nn, Device, Tensor, no_grad};
1019 /// # use rust_bert::Config;
1020 /// # use std::path::Path;
1021 /// # use tch::kind::Kind::Int64;
1022 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForMultipleChoice};
1023 /// # let config_path = Path::new("path/to/config.json");
1024 /// # let device = Device::Cpu;
1025 /// # let vs = nn::VarStore::new(device);
1026 /// # let config = MobileBertConfig::from_file(config_path);
1027 /// let model = MobileBertForMultipleChoice::new(&vs.root(), &config);
1028 /// let (batch_size, sequence_length) = (64, 128);
1029 /// let (num_choices, sequence_length) = (3, 128);
1030 /// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
1031 /// let attention_mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
1032 /// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
1033 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1034 /// .expand(&[num_choices, sequence_length], true);
1035 ///
1036 /// let model_output = no_grad(|| {
1037 /// model
1038 /// .forward_t(
1039 /// Some(&input_tensor),
1040 /// Some(&token_type_ids),
1041 /// Some(&position_ids),
1042 /// None,
1043 /// Some(&attention_mask),
1044 /// false,
1045 /// )
1046 /// .unwrap()
1047 /// });
1048 /// ```
1049 pub fn forward_t(
1050 &self,
1051 input_ids: Option<&Tensor>,
1052 token_type_ids: Option<&Tensor>,
1053 position_ids: Option<&Tensor>,
1054 input_embeds: Option<&Tensor>,
1055 attention_mask: Option<&Tensor>,
1056 train: bool,
1057 ) -> Result<MobileBertSequenceClassificationOutput, RustBertError> {
1058 let (input_ids, num_choices) = match input_ids {
1059 Some(value) => (
1060 Some(value.view((-1, *value.size().last().unwrap()))),
1061 value.size()[1],
1062 ),
1063 None => (
1064 None,
1065 input_embeds
1066 .as_ref()
1067 .expect("At least one of input ids or input_embeds must be provided")
1068 .size()[1],
1069 ),
1070 };
1071 let attention_mask =
1072 attention_mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1073 let token_type_ids =
1074 token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1075 let input_embeds =
1076 input_embeds.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1077 let position_ids =
1078 position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1079
1080 let mobilebert_output = self.mobilebert.forward_t(
1081 input_ids.as_ref(),
1082 token_type_ids.as_ref(),
1083 position_ids.as_ref(),
1084 input_embeds.as_ref(),
1085 attention_mask.as_ref(),
1086 train,
1087 )?;
1088
1089 let logits = mobilebert_output
1090 .pooled_output
1091 .unwrap()
1092 .apply_t(&self.dropout, train)
1093 .apply(&self.classifier)
1094 .view([-1, num_choices]);
1095
1096 Ok(MobileBertSequenceClassificationOutput {
1097 logits,
1098 all_hidden_states: mobilebert_output.all_hidden_states,
1099 all_attentions: mobilebert_output.all_attentions,
1100 })
1101 }
1102}
1103
/// # MobileBERT for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `mobilebert`: Base MobileBertModel
/// - `dropout`: Dropout layer before the last token-level predictions layer
/// - `classifier`: Linear layer for token classification
pub struct MobileBertForTokenClassification {
    /// Base MobileBERT model providing the per-token hidden states
    mobilebert: MobileBertModel,
    /// Dropout applied to the hidden states before classification
    dropout: Dropout,
    /// Linear layer mapping each hidden state to one logit per label
    classifier: nn::Linear,
}
1116
1117impl MobileBertForTokenClassification {
1118 /// Build a new `MobileBertForMultipleChoice`
1119 ///
1120 /// # Arguments
1121 ///
1122 /// * `p` - Variable store path for the root of the MobileBERT model
1123 /// * `config` - `MobileBertConfig` object defining the model architecture and decoder status
1124 ///
1125 /// # Example
1126 ///
1127 /// ```no_run
1128 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForTokenClassification};
1129 /// use rust_bert::Config;
1130 /// use std::path::Path;
1131 /// use tch::{nn, Device};
1132 ///
1133 /// let config_path = Path::new("path/to/config.json");
1134 /// let device = Device::Cpu;
1135 /// let p = nn::VarStore::new(device);
1136 /// let config = MobileBertConfig::from_file(config_path);
1137 /// let mobilebert = MobileBertForTokenClassification::new(&p.root(), &config).unwrap();
1138 /// ```
1139 pub fn new<'p, P>(
1140 p: P,
1141 config: &MobileBertConfig,
1142 ) -> Result<MobileBertForTokenClassification, RustBertError>
1143 where
1144 P: Borrow<nn::Path<'p>>,
1145 {
1146 let p = p.borrow();
1147
1148 let mobilebert = MobileBertModel::new(p / "mobilebert", config, false);
1149 let dropout = Dropout::new(config.hidden_dropout_prob);
1150 let num_labels = config
1151 .id2label
1152 .as_ref()
1153 .ok_or_else(|| {
1154 RustBertError::InvalidConfigurationError(
1155 "num_labels not provided in configuration".to_string(),
1156 )
1157 })?
1158 .len() as i64;
1159 let classifier = nn::linear(
1160 p / "classifier",
1161 config.hidden_size,
1162 num_labels,
1163 Default::default(),
1164 );
1165
1166 Ok(MobileBertForTokenClassification {
1167 mobilebert,
1168 dropout,
1169 classifier,
1170 })
1171 }
1172
1173 /// Forward pass through the model
1174 ///
1175 /// # Arguments
1176 ///
1177 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
1178 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
1179 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
1180 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
1181 /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
1182 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
1183 ///
1184 /// # Returns
1185 ///
1186 /// * `MobileBertTokenClassificationOutput` containing:
1187 /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
1188 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1189 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
1190 ///
1191 /// # Example
1192 ///
1193 /// ```no_run
1194 /// # use tch::{nn, Device, Tensor, no_grad};
1195 /// # use rust_bert::Config;
1196 /// # use std::path::Path;
1197 /// # use tch::kind::Kind::Int64;
1198 /// use rust_bert::mobilebert::{MobileBertConfig, MobileBertForTokenClassification};
1199 /// # let config_path = Path::new("path/to/config.json");
1200 /// # let device = Device::Cpu;
1201 /// # let vs = nn::VarStore::new(device);
1202 /// # let config = MobileBertConfig::from_file(config_path);
1203 /// let model = MobileBertForTokenClassification::new(&vs.root(), &config).unwrap();
1204 /// let (batch_size, sequence_length) = (64, 128);
1205 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1206 /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1207 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1208 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
1209 /// .expand(&[batch_size, sequence_length], true);
1210 ///
1211 /// let model_output = no_grad(|| {
1212 /// model
1213 /// .forward_t(
1214 /// Some(&input_tensor),
1215 /// Some(&token_type_ids),
1216 /// Some(&position_ids),
1217 /// None,
1218 /// Some(&attention_mask),
1219 /// false,
1220 /// )
1221 /// .unwrap()
1222 /// });
1223 /// ```
1224 pub fn forward_t(
1225 &self,
1226 input_ids: Option<&Tensor>,
1227 token_type_ids: Option<&Tensor>,
1228 position_ids: Option<&Tensor>,
1229 input_embeds: Option<&Tensor>,
1230 attention_mask: Option<&Tensor>,
1231 train: bool,
1232 ) -> Result<MobileBertTokenClassificationOutput, RustBertError> {
1233 let mobilebert_output = self.mobilebert.forward_t(
1234 input_ids,
1235 token_type_ids,
1236 position_ids,
1237 input_embeds,
1238 attention_mask,
1239 train,
1240 )?;
1241
1242 let logits = mobilebert_output
1243 .hidden_state
1244 .apply_t(&self.dropout, train)
1245 .apply(&self.classifier);
1246
1247 Ok(MobileBertTokenClassificationOutput {
1248 logits,
1249 all_hidden_states: mobilebert_output.all_hidden_states,
1250 all_attentions: mobilebert_output.all_attentions,
1251 })
1252 }
1253}
1254
/// Container for the MobileBert output.
///
/// The task-specific heads in this file read `hidden_state` (token-level heads) or
/// `pooled_output` (multiple choice head) and pass the remaining fields through unchanged.
pub struct MobileBertOutput {
    /// Last hidden states from the model
    pub hidden_state: Tensor,
    /// Pooled output (optional: `None` when the model does not produce one)
    pub pooled_output: Option<Tensor>,
    /// Hidden states for all intermediate layers (optional)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (optional)
    pub all_attentions: Option<Vec<Tensor>>,
}
1266
/// Container for the MobileBert masked LM model output.
pub struct MobileBertMaskedLMOutput {
    /// Logits for the vocabulary items at each sequence position
    pub logits: Tensor,
    /// Hidden states for all intermediate layers (optional)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (optional)
    pub all_attentions: Option<Vec<Tensor>>,
}
1276
/// Container for the MobileBert sequence classification model output.
///
/// Also returned by `MobileBertForMultipleChoice::forward_t`, in which case `logits` holds one
/// score per alternative.
pub struct MobileBertSequenceClassificationOutput {
    /// Logits for each input (sequence) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers (optional)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (optional)
    pub all_attentions: Option<Vec<Tensor>>,
}
1286
/// Container for the MobileBert token classification model output.
pub struct MobileBertTokenClassificationOutput {
    /// Logits for each sequence item (token) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers (optional)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (optional)
    pub all_attentions: Option<Vec<Tensor>>,
}
1296
/// Container for the MobileBert question answering model output.
pub struct MobileBertQuestionAnsweringOutput {
    /// Logits for the start position for token of each input sequence
    pub start_logits: Tensor,
    /// Logits for the end position for token of each input sequence
    pub end_logits: Tensor,
    /// Hidden states for all intermediate layers (optional)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers (optional)
    pub all_attentions: Option<Vec<Tensor>>,
}