rust_bert/models/albert/
albert_model.rs

1// Copyright 2018 Google AI and Google Brain team.
2// Copyright 2020-present, the HuggingFace Inc. team.
3// Copyright 2020 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::albert::encoder::AlbertTransformer;
15use crate::common::activations::Activation;
16use crate::common::dropout::Dropout;
17use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
18use crate::{albert::embeddings::AlbertEmbeddings, common::activations::TensorFunction};
19use crate::{Config, RustBertError};
20use serde::{Deserialize, Serialize};
21use std::{borrow::Borrow, collections::HashMap};
22use tch::nn::Module;
23use tch::{nn, Kind, Tensor};
24
/// # ALBERT pretrained weights
///
/// Namespace holding `(identifier, download URL)` pairs for ALBERT model weight files.
pub struct AlbertModelResources;

/// # ALBERT pretrained configurations
///
/// Namespace holding `(identifier, download URL)` pairs for ALBERT configuration files.
pub struct AlbertConfigResources;

/// # ALBERT pretrained vocabularies
///
/// Namespace holding `(identifier, download URL)` pairs for ALBERT SentencePiece vocabulary files.
pub struct AlbertVocabResources;
33
34impl AlbertModelResources {
35    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/ALBERT>. Modified with conversion to C-array format.
36    pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
37        "albert-base-v2/model",
38        "https://huggingface.co/albert-base-v2/resolve/main/rust_model.ot",
39    );
40    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
41    pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
42        "paraphrase-albert-small-v2/model",
43        "https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/rust_model.ot",
44    );
45}
46
47impl AlbertConfigResources {
48    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/ALBERT>. Modified with conversion to C-array format.
49    pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
50        "albert-base-v2/config",
51        "https://huggingface.co/albert-base-v2/resolve/main/config.json",
52    );
53    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
54    pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
55        "paraphrase-albert-small-v2/config",
56        "https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/config.json",
57    );
58}
59
60impl AlbertVocabResources {
61    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/ALBERT>. Modified with conversion to C-array format.
62    pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
63        "albert-base-v2/spiece",
64        "https://huggingface.co/albert-base-v2/resolve/main/spiece.model",
65    );
66    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
67    pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
68        "paraphrase-albert-small-v2/spiece",
69        "https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/spiece.model",
70    );
71}
72
73#[derive(Debug, Serialize, Deserialize, Clone)]
74/// # ALBERT model configuration
75/// Defines the ALBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
76pub struct AlbertConfig {
77    pub hidden_act: Activation,
78    pub attention_probs_dropout_prob: f64,
79    pub classifier_dropout_prob: Option<f64>,
80    pub bos_token_id: i64,
81    pub eos_token_id: i64,
82    pub embedding_size: i64,
83    pub hidden_dropout_prob: f64,
84    pub hidden_size: i64,
85    pub initializer_range: f32,
86    pub inner_group_num: i64,
87    pub intermediate_size: i64,
88    pub layer_norm_eps: Option<f64>,
89    pub max_position_embeddings: i64,
90    pub num_attention_heads: i64,
91    pub num_hidden_groups: i64,
92    pub num_hidden_layers: i64,
93    pub pad_token_id: i64,
94    pub type_vocab_size: i64,
95    pub vocab_size: i64,
96    pub output_attentions: Option<bool>,
97    pub output_hidden_states: Option<bool>,
98    pub is_decoder: Option<bool>,
99    pub id2label: Option<HashMap<i64, String>>,
100    pub label2id: Option<HashMap<String, i64>>,
101}
102
103impl Config for AlbertConfig {}
104
105impl Default for AlbertConfig {
106    fn default() -> Self {
107        AlbertConfig {
108            hidden_act: Activation::gelu_new,
109            attention_probs_dropout_prob: 0.0,
110            classifier_dropout_prob: Some(0.1),
111            bos_token_id: 2,
112            eos_token_id: 3,
113            embedding_size: 128,
114            hidden_dropout_prob: 0.0,
115            hidden_size: 4096,
116            initializer_range: 0.02,
117            inner_group_num: 1,
118            intermediate_size: 16384,
119            layer_norm_eps: Some(1e-12),
120            max_position_embeddings: 512,
121            num_attention_heads: 64,
122            num_hidden_groups: 1,
123            num_hidden_layers: 12,
124            pad_token_id: 0,
125            type_vocab_size: 2,
126            vocab_size: 30000,
127            output_attentions: None,
128            output_hidden_states: None,
129            is_decoder: None,
130            id2label: None,
131            label2id: None,
132        }
133    }
134}
135
136/// # ALBERT Base model
137/// Base architecture for ALBERT models. Task-specific models will be built from this common base model
138/// It is made of the following blocks:
139/// - `embeddings`: `token`, `position` and `segment_id` embeddings
140/// - `encoder`: Encoder (transformer) made of a vector of layers. Each layer is made of a self-attention layer, an intermediate (linear) and output (linear + layer norm) layers. Note that the weights are shared across layers, allowing for a reduction in the model memory footprint.
141/// - `pooler`: linear layer applied to the first element of the sequence (*MASK* token)
142/// - `pooler_activation`: Tanh activation function for the pooling layer
143pub struct AlbertModel {
144    embeddings: AlbertEmbeddings,
145    encoder: AlbertTransformer,
146    pooler: nn::Linear,
147    pooler_activation: TensorFunction,
148}
149
150impl AlbertModel {
151    /// Build a new `AlbertModel`
152    ///
153    /// # Arguments
154    ///
155    /// * `p` - Variable store path for the root of the ALBERT model
156    /// * `config` - `AlbertConfig` object defining the model architecture and decoder status
157    ///
158    /// # Example
159    ///
160    /// ```no_run
161    /// use rust_bert::albert::{AlbertConfig, AlbertModel};
162    /// use rust_bert::Config;
163    /// use std::path::Path;
164    /// use tch::{nn, Device};
165    ///
166    /// let config_path = Path::new("path/to/config.json");
167    /// let device = Device::Cpu;
168    /// let p = nn::VarStore::new(device);
169    /// let config = AlbertConfig::from_file(config_path);
170    /// let albert: AlbertModel = AlbertModel::new(&p.root() / "albert", &config);
171    /// ```
172    pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertModel
173    where
174        P: Borrow<nn::Path<'p>>,
175    {
176        let p = p.borrow();
177
178        let embeddings = AlbertEmbeddings::new(p / "embeddings", config);
179        let encoder = AlbertTransformer::new(p / "encoder", config);
180        let pooler = nn::linear(
181            p / "pooler",
182            config.hidden_size,
183            config.hidden_size,
184            Default::default(),
185        );
186        let pooler_activation = Activation::tanh.get_function();
187
188        AlbertModel {
189            embeddings,
190            encoder,
191            pooler,
192            pooler_activation,
193        }
194    }
195
196    /// Forward pass through the model
197    ///
198    /// # Arguments
199    ///
200    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
201    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
202    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
203    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
204    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
205    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
206    ///
207    /// # Returns
208    /// * `AlbertOutput` containing:
209    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
210    ///   - `pooled_output` - `Tensor` of shape (*batch size*, *hidden_size*)
211    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
212    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
213    ///
214    /// # Example
215    ///
216    /// ```no_run
217    /// # use tch::{nn, Device, Tensor, no_grad};
218    /// # use rust_bert::Config;
219    /// # use std::path::Path;
220    /// # use tch::kind::Kind::Int64;
221    /// use rust_bert::albert::{AlbertConfig, AlbertModel};
222    /// # let config_path = Path::new("path/to/config.json");
223    /// # let device = Device::Cpu;
224    /// # let vs = nn::VarStore::new(device);
225    /// # let config = AlbertConfig::from_file(config_path);
226    /// # let albert_model: AlbertModel = AlbertModel::new(&vs.root(), &config);
227    /// let (batch_size, sequence_length) = (64, 128);
228    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
229    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
230    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
231    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
232    ///     .expand(&[batch_size, sequence_length], true);
233    ///
234    /// let model_output = no_grad(|| {
235    ///     albert_model
236    ///         .forward_t(
237    ///             Some(&input_tensor),
238    ///             Some(&mask),
239    ///             Some(&token_type_ids),
240    ///             Some(&position_ids),
241    ///             None,
242    ///             false,
243    ///         )
244    ///         .unwrap()
245    /// });
246    /// ```
247    pub fn forward_t(
248        &self,
249        input_ids: Option<&Tensor>,
250        mask: Option<&Tensor>,
251        token_type_ids: Option<&Tensor>,
252        position_ids: Option<&Tensor>,
253        input_embeds: Option<&Tensor>,
254        train: bool,
255    ) -> Result<AlbertOutput, RustBertError> {
256        let (input_shape, device) =
257            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
258
259        let calc_mask = if mask.is_none() {
260            Some(Tensor::ones(input_shape, (Kind::Int64, device)))
261        } else {
262            None
263        };
264        let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
265
266        let embedding_output = self.embeddings.forward_t(
267            input_ids,
268            token_type_ids,
269            position_ids,
270            input_embeds,
271            train,
272        )?;
273
274        let extended_attention_mask = mask.unsqueeze(1).unsqueeze(2);
275        let extended_attention_mask: Tensor =
276            ((extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0)
277                .to_kind(embedding_output.kind());
278
279        let transformer_output =
280            self.encoder
281                .forward_t(&embedding_output, Some(extended_attention_mask), train);
282
283        let pooled_output = self
284            .pooler
285            .forward(&transformer_output.hidden_state.select(1, 0));
286        let pooled_output = (self.pooler_activation.get_fn())(&pooled_output);
287
288        Ok(AlbertOutput {
289            hidden_state: transformer_output.hidden_state,
290            pooled_output,
291            all_hidden_states: transformer_output.all_hidden_states,
292            all_attentions: transformer_output.all_attentions,
293        })
294    }
295}
296
297pub struct AlbertMLMHead {
298    layer_norm: nn::LayerNorm,
299    dense: nn::Linear,
300    decoder: nn::Linear,
301    activation: TensorFunction,
302}
303
304impl AlbertMLMHead {
305    pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertMLMHead
306    where
307        P: Borrow<nn::Path<'p>>,
308    {
309        let p = p.borrow();
310
311        let layer_norm_eps = config.layer_norm_eps.unwrap_or(1e-12);
312        let layer_norm_config = nn::LayerNormConfig {
313            eps: layer_norm_eps,
314            ..Default::default()
315        };
316        let layer_norm = nn::layer_norm(
317            p / "LayerNorm",
318            vec![config.embedding_size],
319            layer_norm_config,
320        );
321        let dense = nn::linear(
322            p / "dense",
323            config.hidden_size,
324            config.embedding_size,
325            Default::default(),
326        );
327        let decoder = nn::linear(
328            p / "decoder",
329            config.embedding_size,
330            config.vocab_size,
331            Default::default(),
332        );
333
334        let activation = config.hidden_act.get_function();
335
336        AlbertMLMHead {
337            layer_norm,
338            dense,
339            decoder,
340            activation,
341        }
342    }
343
344    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
345        let output: Tensor = (self.activation.get_fn())(&hidden_states.apply(&self.dense));
346        output.apply(&self.layer_norm).apply(&self.decoder)
347    }
348}
349
350/// # ALBERT for masked language model
351/// Base ALBERT model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
352/// It is made of the following blocks:
353/// - `albert`: Base AlbertModel
354/// - `predictions`: ALBERT MLM prediction head
355pub struct AlbertForMaskedLM {
356    albert: AlbertModel,
357    predictions: AlbertMLMHead,
358}
359
360impl AlbertForMaskedLM {
361    /// Build a new `AlbertForMaskedLM`
362    ///
363    /// # Arguments
364    ///
365    /// * `p` - Variable store path for the root of the ALBERT model
366    /// * `config` - `AlbertConfig` object defining the model architecture and decoder status
367    ///
368    /// # Example
369    ///
370    /// ```no_run
371    /// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
372    /// use rust_bert::Config;
373    /// use std::path::Path;
374    /// use tch::{nn, Device};
375    ///
376    /// let config_path = Path::new("path/to/config.json");
377    /// let device = Device::Cpu;
378    /// let p = nn::VarStore::new(device);
379    /// let config = AlbertConfig::from_file(config_path);
380    /// let albert: AlbertForMaskedLM = AlbertForMaskedLM::new(&p.root(), &config);
381    /// ```
382    pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertForMaskedLM
383    where
384        P: Borrow<nn::Path<'p>>,
385    {
386        let p = p.borrow();
387
388        let albert = AlbertModel::new(p / "albert", config);
389        let predictions = AlbertMLMHead::new(p / "predictions", config);
390
391        AlbertForMaskedLM {
392            albert,
393            predictions,
394        }
395    }
396
397    /// Forward pass through the model
398    ///
399    /// # Arguments
400    ///
401    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
402    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
403    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
404    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
405    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
406    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
407    ///
408    /// # Returns
409    ///
410    /// * `AlbertMaskedLMOutput` containing:
411    ///   - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
412    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
413    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
414    ///
415    /// # Example
416    ///
417    /// ```no_run
418    /// # use tch::{nn, Device, Tensor, no_grad};
419    /// # use rust_bert::Config;
420    /// # use std::path::Path;
421    /// # use tch::kind::Kind::Int64;
422    /// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
423    /// # let config_path = Path::new("path/to/config.json");
424    /// # let device = Device::Cpu;
425    /// # let vs = nn::VarStore::new(device);
426    /// # let config = AlbertConfig::from_file(config_path);
427    /// # let albert_model: AlbertForMaskedLM = AlbertForMaskedLM::new(&vs.root(), &config);
428    /// let (batch_size, sequence_length) = (64, 128);
429    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
430    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
431    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
432    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
433    ///     .expand(&[batch_size, sequence_length], true);
434    ///
435    /// let masked_lm_output = no_grad(|| {
436    ///     albert_model.forward_t(
437    ///         Some(&input_tensor),
438    ///         Some(&mask),
439    ///         Some(&token_type_ids),
440    ///         Some(&position_ids),
441    ///         None,
442    ///         false,
443    ///     )
444    /// });
445    /// ```
446    pub fn forward_t(
447        &self,
448        input_ids: Option<&Tensor>,
449        mask: Option<&Tensor>,
450        token_type_ids: Option<&Tensor>,
451        position_ids: Option<&Tensor>,
452        input_embeds: Option<&Tensor>,
453        train: bool,
454    ) -> AlbertMaskedLMOutput {
455        let base_model_output = self
456            .albert
457            .forward_t(
458                input_ids,
459                mask,
460                token_type_ids,
461                position_ids,
462                input_embeds,
463                train,
464            )
465            .unwrap();
466        let prediction_scores = self.predictions.forward(&base_model_output.hidden_state);
467        AlbertMaskedLMOutput {
468            prediction_scores,
469            all_hidden_states: base_model_output.all_hidden_states,
470            all_attentions: base_model_output.all_attentions,
471        }
472    }
473}
474
475/// # ALBERT for sequence classification
476/// Base ALBERT model with a classifier head to perform sentence or document-level classification
477/// It is made of the following blocks:
478/// - `albert`: Base AlbertModel
479/// - `dropout`: Dropout layer
480/// - `classifier`: linear layer for classification
481pub struct AlbertForSequenceClassification {
482    albert: AlbertModel,
483    dropout: Dropout,
484    classifier: nn::Linear,
485}
486
487impl AlbertForSequenceClassification {
488    /// Build a new `AlbertForSequenceClassification`
489    ///
490    /// # Arguments
491    ///
492    /// * `p` - Variable store path for the root of the ALBERT model
493    /// * `config` - `AlbertConfig` object defining the model architecture and decoder status
494    ///
495    /// # Example
496    ///
497    /// ```no_run
498    /// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
499    /// use rust_bert::Config;
500    /// use std::path::Path;
501    /// use tch::{nn, Device};
502    ///
503    /// let config_path = Path::new("path/to/config.json");
504    /// let device = Device::Cpu;
505    /// let p = nn::VarStore::new(device);
506    /// let config = AlbertConfig::from_file(config_path);
507    /// let albert: AlbertForSequenceClassification =
508    ///     AlbertForSequenceClassification::new(&p.root(), &config).unwrap();
509    /// ```
510    pub fn new<'p, P>(
511        p: P,
512        config: &AlbertConfig,
513    ) -> Result<AlbertForSequenceClassification, RustBertError>
514    where
515        P: Borrow<nn::Path<'p>>,
516    {
517        let p = p.borrow();
518
519        let albert = AlbertModel::new(p / "albert", config);
520        let classifier_dropout_prob = config.classifier_dropout_prob.unwrap_or(0.1);
521        let dropout = Dropout::new(classifier_dropout_prob);
522        let num_labels = config
523            .id2label
524            .as_ref()
525            .ok_or_else(|| {
526                RustBertError::InvalidConfigurationError(
527                    "num_labels not provided in configuration".to_string(),
528                )
529            })?
530            .len() as i64;
531        let classifier = nn::linear(
532            p / "classifier",
533            config.hidden_size,
534            num_labels,
535            Default::default(),
536        );
537
538        Ok(AlbertForSequenceClassification {
539            albert,
540            dropout,
541            classifier,
542        })
543    }
544
545    /// Forward pass through the model
546    ///
547    /// # Arguments
548    ///
549    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
550    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
551    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
552    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
553    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
554    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
555    ///
556    /// # Returns
557    ///
558    /// * `AlbertSequenceClassificationOutput` containing:
559    ///   - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
560    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
561    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
562    ///
563    /// # Example
564    ///
565    /// ```no_run
566    /// # use tch::{nn, Device, Tensor, no_grad};
567    /// # use rust_bert::Config;
568    /// # use std::path::Path;
569    /// # use tch::kind::Kind::Int64;
570    /// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
571    /// # let config_path = Path::new("path/to/config.json");
572    /// # let device = Device::Cpu;
573    /// # let vs = nn::VarStore::new(device);
574    /// # let config = AlbertConfig::from_file(config_path);
575    /// # let albert_model: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&vs.root(), &config).unwrap();
576    ///  let (batch_size, sequence_length) = (64, 128);
577    ///  let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
578    ///  let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
579    ///  let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
580    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
581    ///
582    ///  let classification_output = no_grad(|| {
583    ///    albert_model
584    ///         .forward_t(Some(&input_tensor),
585    ///                    Some(&mask),
586    ///                    Some(&token_type_ids),
587    ///                    Some(&position_ids),
588    ///                    None,
589    ///                    false)
590    ///    });
591    /// ```
592    pub fn forward_t(
593        &self,
594        input_ids: Option<&Tensor>,
595        mask: Option<&Tensor>,
596        token_type_ids: Option<&Tensor>,
597        position_ids: Option<&Tensor>,
598        input_embeds: Option<&Tensor>,
599        train: bool,
600    ) -> AlbertSequenceClassificationOutput {
601        let base_model_output = self
602            .albert
603            .forward_t(
604                input_ids,
605                mask,
606                token_type_ids,
607                position_ids,
608                input_embeds,
609                train,
610            )
611            .unwrap();
612        let logits = base_model_output
613            .pooled_output
614            .apply_t(&self.dropout, train)
615            .apply(&self.classifier);
616        AlbertSequenceClassificationOutput {
617            logits,
618            all_hidden_states: base_model_output.all_hidden_states,
619            all_attentions: base_model_output.all_attentions,
620        }
621    }
622}
623
624/// # ALBERT for token classification (e.g. NER, POS)
625/// Token-level classifier predicting a label for each token provided. Note that because of SentencePiece tokenization, the labels predicted are
626/// not necessarily aligned with words in the sentence.
627/// It is made of the following blocks:
628/// - `albert`: Base AlbertModel
629/// - `dropout`: Dropout to apply on the encoder last hidden states
630/// - `classifier`: Linear layer for token classification
631pub struct AlbertForTokenClassification {
632    albert: AlbertModel,
633    dropout: Dropout,
634    classifier: nn::Linear,
635}
636
637impl AlbertForTokenClassification {
638    /// Build a new `AlbertForTokenClassification`
639    ///
640    /// # Arguments
641    ///
642    /// * `p` - Variable store path for the root of the ALBERT model
643    /// * `config` - `AlbertConfig` object defining the model architecture and decoder status
644    ///
645    /// # Example
646    ///
647    /// ```no_run
648    /// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
649    /// use rust_bert::Config;
650    /// use std::path::Path;
651    /// use tch::{nn, Device};
652    ///
653    /// let config_path = Path::new("path/to/config.json");
654    /// let device = Device::Cpu;
655    /// let p = nn::VarStore::new(device);
656    /// let config = AlbertConfig::from_file(config_path);
657    /// let albert: AlbertForTokenClassification =
658    ///     AlbertForTokenClassification::new(&p.root(), &config).unwrap();
659    /// ```
660    pub fn new<'p, P>(
661        p: P,
662        config: &AlbertConfig,
663    ) -> Result<AlbertForTokenClassification, RustBertError>
664    where
665        P: Borrow<nn::Path<'p>>,
666    {
667        let p = p.borrow();
668
669        let albert = AlbertModel::new(p / "albert", config);
670        let dropout = Dropout::new(config.hidden_dropout_prob);
671        let num_labels = config
672            .id2label
673            .as_ref()
674            .ok_or_else(|| {
675                RustBertError::InvalidConfigurationError(
676                    "num_labels not provided in configuration".to_string(),
677                )
678            })?
679            .len() as i64;
680        let classifier = nn::linear(
681            p / "classifier",
682            config.hidden_size,
683            num_labels,
684            Default::default(),
685        );
686
687        Ok(AlbertForTokenClassification {
688            albert,
689            dropout,
690            classifier,
691        })
692    }
693
694    /// Forward pass through the model
695    ///
696    /// # Arguments
697    ///
698    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
699    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
700    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
701    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
702    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
703    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
704    ///
705    /// # Returns
706    ///
707    /// * `AlbertTokenClassificationOutput` containing:
708    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
709    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
710    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
711    ///
712    /// # Example
713    ///
714    /// ```no_run
715    /// # use tch::{nn, Device, Tensor, no_grad};
716    /// # use rust_bert::Config;
717    /// # use std::path::Path;
718    /// # use tch::kind::Kind::Int64;
719    /// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
720    /// # let config_path = Path::new("path/to/config.json");
721    /// # let device = Device::Cpu;
722    /// # let vs = nn::VarStore::new(device);
723    /// # let config = AlbertConfig::from_file(config_path);
724    /// # let albert_model: AlbertForTokenClassification = AlbertForTokenClassification::new(&vs.root(), &config).unwrap();
725    ///  let (batch_size, sequence_length) = (64, 128);
726    ///  let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
727    ///  let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
728    ///  let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
729    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
730    ///
731    ///  let model_output = no_grad(|| {
732    ///    albert_model
733    ///         .forward_t(Some(&input_tensor),
734    ///                    Some(&mask),
735    ///                    Some(&token_type_ids),
736    ///                    Some(&position_ids),
737    ///                    None,
738    ///                    false)
739    ///    });
740    /// ```
741    pub fn forward_t(
742        &self,
743        input_ids: Option<&Tensor>,
744        mask: Option<&Tensor>,
745        token_type_ids: Option<&Tensor>,
746        position_ids: Option<&Tensor>,
747        input_embeds: Option<&Tensor>,
748        train: bool,
749    ) -> AlbertTokenClassificationOutput {
750        let base_model_output = self
751            .albert
752            .forward_t(
753                input_ids,
754                mask,
755                token_type_ids,
756                position_ids,
757                input_embeds,
758                train,
759            )
760            .unwrap();
761        let logits = base_model_output
762            .hidden_state
763            .apply_t(&self.dropout, train)
764            .apply(&self.classifier);
765        AlbertTokenClassificationOutput {
766            logits,
767            all_hidden_states: base_model_output.all_hidden_states,
768            all_attentions: base_model_output.all_attentions,
769        }
770    }
771}
772
/// # ALBERT for question answering
/// Extractive question-answering model based on an ALBERT language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `albert`: Base AlbertModel
/// - `qa_outputs`: Linear layer producing the answer start and end logits for each token
pub struct AlbertForQuestionAnswering {
    /// Base ALBERT model producing the per-token hidden states.
    albert: AlbertModel,
    /// Linear layer mapping each token's hidden state to two scores (answer start, answer end).
    qa_outputs: nn::Linear,
}
784
785impl AlbertForQuestionAnswering {
786    /// Build a new `AlbertForQuestionAnswering`
787    ///
788    /// # Arguments
789    ///
790    /// * `p` - Variable store path for the root of the ALBERT model
791    /// * `config` - `AlbertConfig` object defining the model architecture and decoder status
792    ///
793    /// # Example
794    ///
795    /// ```no_run
796    /// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
797    /// use rust_bert::Config;
798    /// use std::path::Path;
799    /// use tch::{nn, Device};
800    ///
801    /// let config_path = Path::new("path/to/config.json");
802    /// let device = Device::Cpu;
803    /// let p = nn::VarStore::new(device);
804    /// let config = AlbertConfig::from_file(config_path);
805    /// let albert: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&p.root(), &config);
806    /// ```
807    pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertForQuestionAnswering
808    where
809        P: Borrow<nn::Path<'p>>,
810    {
811        let p = p.borrow();
812
813        let albert = AlbertModel::new(p / "albert", config);
814        let num_labels = 2;
815        let qa_outputs = nn::linear(
816            p / "qa_outputs",
817            config.hidden_size,
818            num_labels,
819            Default::default(),
820        );
821
822        AlbertForQuestionAnswering { albert, qa_outputs }
823    }
824
825    /// Forward pass through the model
826    ///
827    /// # Arguments
828    ///
829    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
830    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
831    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
832    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
833    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
834    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
835    ///
836    /// # Returns
837    ///
838    /// * `AlbertQuestionAnsweringOutput` containing:
839    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
840    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
841    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
842    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
843    ///
844    /// # Example
845    ///
846    /// ```no_run
847    /// # use tch::{nn, Device, Tensor, no_grad};
848    /// # use rust_bert::Config;
849    /// # use std::path::Path;
850    /// # use tch::kind::Kind::Int64;
851    /// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
852    /// # let config_path = Path::new("path/to/config.json");
853    /// # let device = Device::Cpu;
854    /// # let vs = nn::VarStore::new(device);
855    /// # let config = AlbertConfig::from_file(config_path);
856    /// # let albert_model: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&vs.root(), &config);
857    ///  let (batch_size, sequence_length) = (64, 128);
858    ///  let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
859    ///  let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
860    ///  let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
861    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
862    ///
863    ///  let model_output = no_grad(|| {
864    ///    albert_model
865    ///         .forward_t(Some(&input_tensor),
866    ///                    Some(&mask),
867    ///                    Some(&token_type_ids),
868    ///                    Some(&position_ids),
869    ///                    None,
870    ///                    false)
871    ///    });
872    /// ```
873    pub fn forward_t(
874        &self,
875        input_ids: Option<&Tensor>,
876        mask: Option<&Tensor>,
877        token_type_ids: Option<&Tensor>,
878        position_ids: Option<&Tensor>,
879        input_embeds: Option<&Tensor>,
880        train: bool,
881    ) -> AlbertQuestionAnsweringOutput {
882        let base_model_output = self
883            .albert
884            .forward_t(
885                input_ids,
886                mask,
887                token_type_ids,
888                position_ids,
889                input_embeds,
890                train,
891            )
892            .unwrap();
893        let logits = base_model_output
894            .hidden_state
895            .apply(&self.qa_outputs)
896            .split(1, -1);
897        let (start_logits, end_logits) = (&logits[0], &logits[1]);
898        let start_logits = start_logits.squeeze_dim(-1);
899        let end_logits = end_logits.squeeze_dim(-1);
900
901        AlbertQuestionAnsweringOutput {
902            start_logits,
903            end_logits,
904            all_hidden_states: base_model_output.all_hidden_states,
905            all_attentions: base_model_output.all_attentions,
906        }
907    }
908}
909
/// # ALBERT for multiple choices
/// Multiple choices model using an ALBERT base model and a linear classifier.
/// Input should be in the form `[CLS] Context [SEP] Possible choice [SEP]`. Inputs are expected with shape
/// (*batch size*, *num_choices*, *sequence_length*): the alternatives for a given context are laid out along
/// the second axis, and one logit is produced per alternative.
/// It is made of the following blocks:
/// - `albert`: Base AlbertModel
/// - `dropout`: Dropout for hidden states output
/// - `classifier`: Linear layer for multiple choices
pub struct AlbertForMultipleChoice {
    /// Base ALBERT model; its pooled output feeds the classifier.
    albert: AlbertModel,
    /// Dropout applied to the pooled output (only active in training mode).
    dropout: Dropout,
    /// Linear layer producing a single score per (context, choice) sequence.
    classifier: nn::Linear,
}
923
924impl AlbertForMultipleChoice {
925    /// Build a new `AlbertForMultipleChoice`
926    ///
927    /// # Arguments
928    ///
929    /// * `p` - Variable store path for the root of the ALBERT model
930    /// * `config` - `AlbertConfig` object defining the model architecture and decoder status
931    ///
932    /// # Example
933    ///
934    /// ```no_run
935    /// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
936    /// use rust_bert::Config;
937    /// use std::path::Path;
938    /// use tch::{nn, Device};
939    ///
940    /// let config_path = Path::new("path/to/config.json");
941    /// let device = Device::Cpu;
942    /// let p = nn::VarStore::new(device);
943    /// let config = AlbertConfig::from_file(config_path);
944    /// let albert: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&p.root(), &config);
945    /// ```
946    pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertForMultipleChoice
947    where
948        P: Borrow<nn::Path<'p>>,
949    {
950        let p = p.borrow();
951
952        let albert = AlbertModel::new(p / "albert", config);
953        let dropout = Dropout::new(config.hidden_dropout_prob);
954        let num_labels = 1;
955        let classifier = nn::linear(
956            p / "classifier",
957            config.hidden_size,
958            num_labels,
959            Default::default(),
960        );
961
962        AlbertForMultipleChoice {
963            albert,
964            dropout,
965            classifier,
966        }
967    }
968
969    /// Forward pass through the model
970    ///
971    /// # Arguments
972    ///
973    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
974    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
975    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
976    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
977    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
978    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
979    ///
980    /// # Returns
981    ///
982    /// * `AlbertSequenceClassificationOutput` containing:
983    ///   - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
984    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
985    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
986    ///
987    /// # Example
988    ///
989    /// ```no_run
990    /// # use tch::{nn, Device, Tensor, no_grad};
991    /// # use rust_bert::Config;
992    /// # use std::path::Path;
993    /// # use tch::kind::Kind::Int64;
994    /// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
995    /// # let config_path = Path::new("path/to/config.json");
996    /// # let device = Device::Cpu;
997    /// # let vs = nn::VarStore::new(device);
998    /// # let config = AlbertConfig::from_file(config_path);
999    /// # let albert_model: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&vs.root(), &config);
1000    ///  let (batch_size, sequence_length) = (64, 128);
1001    ///  let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
1002    ///  let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1003    ///  let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
1004    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
1005    ///
1006    ///  let model_output = no_grad(|| {
1007    ///    albert_model
1008    ///         .forward_t(Some(&input_tensor),
1009    ///                    Some(&mask),
1010    ///                    Some(&token_type_ids),
1011    ///                    Some(&position_ids),
1012    ///                    None,
1013    ///                    false).unwrap()
1014    ///    });
1015    /// ```
1016    pub fn forward_t(
1017        &self,
1018        input_ids: Option<&Tensor>,
1019        mask: Option<&Tensor>,
1020        token_type_ids: Option<&Tensor>,
1021        position_ids: Option<&Tensor>,
1022        input_embeds: Option<&Tensor>,
1023        train: bool,
1024    ) -> Result<AlbertSequenceClassificationOutput, RustBertError> {
1025        let (input_ids, input_embeds, num_choices) = match &input_ids {
1026            Some(input_value) => match &input_embeds {
1027                Some(_) => {
1028                    return Err(RustBertError::ValueError(
1029                        "Only one of input ids or input embeddings may be set".into(),
1030                    ));
1031                }
1032                None => (
1033                    Some(input_value.view((-1, *input_value.size().last().unwrap()))),
1034                    None,
1035                    input_value.size()[1],
1036                ),
1037            },
1038            None => match &input_embeds {
1039                Some(embeds) => (
1040                    None,
1041                    Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))),
1042                    embeds.size()[1],
1043                ),
1044                None => {
1045                    return Err(RustBertError::ValueError(
1046                        "At least one of input ids or input embeddings must be set".into(),
1047                    ));
1048                }
1049            },
1050        };
1051
1052        let mask = mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1053        let token_type_ids =
1054            token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1055        let position_ids =
1056            position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
1057
1058        let base_model_output = self.albert.forward_t(
1059            input_ids.as_ref(),
1060            mask.as_ref(),
1061            token_type_ids.as_ref(),
1062            position_ids.as_ref(),
1063            input_embeds.as_ref(),
1064            train,
1065        )?;
1066        let logits = base_model_output
1067            .pooled_output
1068            .apply_t(&self.dropout, train)
1069            .apply(&self.classifier)
1070            .view((-1, num_choices));
1071
1072        Ok(AlbertSequenceClassificationOutput {
1073            logits,
1074            all_hidden_states: base_model_output.all_hidden_states,
1075            all_attentions: base_model_output.all_attentions,
1076        })
1077    }
1078}
1079
/// # ALBERT for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
/// This is an alias of [`AlbertModel`]: the base model is used as-is, without an additional task head.
pub type AlbertForSentenceEmbeddings = AlbertModel;
1083
1084/// Container for the ALBERT model output.
1085pub struct AlbertOutput {
1086    /// Last hidden states from the model
1087    pub hidden_state: Tensor,
1088    /// Pooled output (hidden state for the first token)
1089    pub pooled_output: Tensor,
1090    /// Hidden states for all intermediate layers
1091    pub all_hidden_states: Option<Vec<Tensor>>,
1092    /// Attention weights for all intermediate layers
1093    pub all_attentions: Option<Vec<Vec<Tensor>>>,
1094}
1095
1096/// Container for the ALBERT masked LM model output.
1097pub struct AlbertMaskedLMOutput {
1098    /// Logits for the vocabulary items at each sequence position
1099    pub prediction_scores: Tensor,
1100    /// Hidden states for all intermediate layers
1101    pub all_hidden_states: Option<Vec<Tensor>>,
1102    /// Attention weights for all intermediate layers
1103    pub all_attentions: Option<Vec<Vec<Tensor>>>,
1104}
1105
1106/// Container for the ALBERT sequence classification model
1107pub struct AlbertSequenceClassificationOutput {
1108    /// Logits for each input (sequence) for each target class
1109    pub logits: Tensor,
1110    /// Hidden states for all intermediate layers
1111    pub all_hidden_states: Option<Vec<Tensor>>,
1112    /// Attention weights for all intermediate layers
1113    pub all_attentions: Option<Vec<Vec<Tensor>>>,
1114}
1115
1116/// Container for the ALBERT token classification model
1117pub struct AlbertTokenClassificationOutput {
1118    /// Logits for each sequence item (token) for each target class
1119    pub logits: Tensor,
1120    /// Hidden states for all intermediate layers
1121    pub all_hidden_states: Option<Vec<Tensor>>,
1122    /// Attention weights for all intermediate layers
1123    pub all_attentions: Option<Vec<Vec<Tensor>>>,
1124}
1125
1126/// Container for the ALBERT question answering model
1127pub struct AlbertQuestionAnsweringOutput {
1128    /// Logits for the start position for token of each input sequence
1129    pub start_logits: Tensor,
1130    /// Logits for the end position for token of each input sequence
1131    pub end_logits: Tensor,
1132    /// Hidden states for all intermediate layers
1133    pub all_hidden_states: Option<Vec<Tensor>>,
1134    /// Attention weights for all intermediate layers
1135    pub all_attentions: Option<Vec<Vec<Tensor>>>,
1136}