rust_bert/models/electra/
electra_model.rs

1// Copyright 2020 The Google Research Authors.
2// Copyright 2019-present, the HuggingFace Inc. team
3// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4// Copyright 2019 Guillaume Becquin
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//     http://www.apache.org/licenses/LICENSE-2.0
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::bert::BertConfig;
16use crate::common::activations::Activation;
17use crate::common::dropout::Dropout;
18use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
19use crate::electra::embeddings::ElectraEmbeddings;
20use crate::{bert::encoder::BertEncoder, common::activations::TensorFunction};
21use crate::{Config, RustBertError};
22use serde::{Deserialize, Serialize};
23use std::{borrow::Borrow, collections::HashMap};
24use tch::{nn, Kind, Tensor};
25
/// # Electra Pretrained model weight files
/// Namespace struct holding `(local alias, remote URL)` constants for pretrained model weights.
pub struct ElectraModelResources;

/// # Electra Pretrained model config files
/// Namespace struct holding `(local alias, remote URL)` constants for model configuration files.
pub struct ElectraConfigResources;

/// # Electra Pretrained model vocab files
/// Namespace struct holding `(local alias, remote URL)` constants for vocabulary files.
pub struct ElectraVocabResources;
34
impl ElectraModelResources {
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    // Tuple layout: (local resource alias, remote download URL)
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/model",
        "https://huggingface.co/google/electra-base-generator/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    // Tuple layout: (local resource alias, remote download URL)
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/model",
        "https://huggingface.co/google/electra-base-discriminator/resolve/main/rust_model.ot",
    );
}
47
impl ElectraConfigResources {
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    // Tuple layout: (local resource alias, remote download URL)
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/config",
        "https://huggingface.co/google/electra-base-generator/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    // Tuple layout: (local resource alias, remote download URL)
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/config",
        "https://huggingface.co/google/electra-base-discriminator/resolve/main/config.json",
    );
}
60
impl ElectraVocabResources {
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    // Tuple layout: (local resource alias, remote download URL)
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/vocab",
        "https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
    );
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
    // Tuple layout: (local resource alias, remote download URL)
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/vocab",
        "https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt",
    );
}
73
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # Electra model configuration
/// Defines the Electra model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct ElectraConfig {
    /// Activation function used by the encoder layers (and the discriminator head)
    pub hidden_act: Activation,
    /// Dropout probability applied to the attention probabilities
    pub attention_probs_dropout_prob: f64,
    /// Dimension of the token embeddings; unlike BERT, this may differ from `hidden_size`
    pub embedding_size: i64,
    /// Dropout probability applied to fully-connected layers
    pub hidden_dropout_prob: f64,
    /// Dimension of the encoder hidden states
    pub hidden_size: i64,
    /// Standard deviation used for weight initialization
    pub initializer_range: f32,
    /// Optional epsilon used for layer normalization
    pub layer_norm_eps: Option<f64>,
    /// Dimension of the encoder feed-forward (intermediate) layers
    pub intermediate_size: i64,
    /// Maximum sequence length supported by the position embeddings
    pub max_position_embeddings: i64,
    /// Number of attention heads per encoder layer
    pub num_attention_heads: i64,
    /// Number of encoder layers
    pub num_hidden_layers: i64,
    /// Number of token type (segment) embeddings
    pub type_vocab_size: i64,
    /// Vocabulary size
    pub vocab_size: i64,
    /// Index of the padding token
    pub pad_token_id: i64,
    // NOTE(review): not referenced when building the models in this file;
    // presumably kept for compatibility with upstream config files — confirm.
    pub output_past: Option<bool>,
    /// If `Some(true)`, the encoder returns the attention weights of all layers
    pub output_attentions: Option<bool>,
    /// If `Some(true)`, the encoder returns the hidden states of all layers
    pub output_hidden_states: Option<bool>,
    /// Optional mapping from label index to label string (used by classification heads)
    pub id2label: Option<HashMap<i64, String>>,
    /// Optional mapping from label string to label index (used by classification heads)
    pub label2id: Option<HashMap<String, i64>>,
}
98
// Marker impl: enables loading an `ElectraConfig` from a JSON file via `Config::from_file`.
impl Config for ElectraConfig {}
100
impl Default for ElectraConfig {
    /// Provides a default Electra configuration.
    // NOTE(review): these values (128-dim embeddings, 256-dim hidden states,
    // 12 layers, 4 heads, 30522 vocab) look like the ELECTRA-small
    // architecture — confirm against the published checkpoints.
    fn default() -> Self {
        ElectraConfig {
            hidden_act: Activation::gelu,
            attention_probs_dropout_prob: 0.1,
            embedding_size: 128,
            hidden_dropout_prob: 0.1,
            hidden_size: 256,
            initializer_range: 0.02,
            layer_norm_eps: Some(1e-12),
            intermediate_size: 1024,
            max_position_embeddings: 512,
            num_attention_heads: 4,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30522,
            pad_token_id: 0,
            output_past: None,
            output_attentions: None,
            output_hidden_states: None,
            id2label: None,
            label2id: None,
        }
    }
}
126
/// # Electra Base model
/// Base architecture for Electra models.
/// It is made of the following blocks:
/// - `embeddings`: `token`, `position` and `segment_id` embeddings. Note that in contrast to BERT, the embeddings dimension is not necessarily equal to the hidden layer dimensions
/// - `encoder`: BertEncoder (transformer) made of a vector of layers. Each layer is made of a self-attention layer, an intermediate (linear) and output (linear + layer norm) layers
/// - `embeddings_project`: (optional) linear layer applied to project the embeddings space to the hidden layer dimension
pub struct ElectraModel {
    // Token / position / segment embeddings (dimension: embedding_size)
    embeddings: ElectraEmbeddings,
    // Only present when embedding_size != hidden_size (see `ElectraModel::new`)
    embeddings_project: Option<nn::Linear>,
    // Standard BERT transformer encoder operating at hidden_size
    encoder: BertEncoder,
}
138
/// Defines the implementation of the ElectraModel.
impl ElectraModel {
    /// Build a new `ElectraModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the Electra model
    /// * `config` - `ElectraConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::electra::{ElectraConfig, ElectraModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = ElectraConfig::from_file(config_path);
    /// let electra_model: ElectraModel = ElectraModel::new(&p.root() / "electra", &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let embeddings = ElectraEmbeddings::new(p / "embeddings", config);
        // Unlike BERT, Electra allows the embedding dimension to differ from
        // the hidden dimension; a projection layer is only created when needed.
        let embeddings_project = if config.embedding_size != config.hidden_size {
            Some(nn::linear(
                p / "embeddings_project",
                config.embedding_size,
                config.hidden_size,
                Default::default(),
            ))
        } else {
            None
        };
        // The Electra encoder is a standard BERT encoder: translate the Electra
        // configuration into an equivalent BertConfig to build it.
        let bert_config = BertConfig {
            hidden_act: config.hidden_act,
            attention_probs_dropout_prob: config.attention_probs_dropout_prob,
            hidden_dropout_prob: config.hidden_dropout_prob,
            hidden_size: config.hidden_size,
            initializer_range: config.initializer_range,
            intermediate_size: config.intermediate_size,
            max_position_embeddings: config.max_position_embeddings,
            num_attention_heads: config.num_attention_heads,
            num_hidden_layers: config.num_hidden_layers,
            type_vocab_size: config.type_vocab_size,
            vocab_size: config.vocab_size,
            output_attentions: config.output_attentions,
            output_hidden_states: config.output_hidden_states,
            // Electra is encoder-only: never built as a decoder.
            is_decoder: None,
            id2label: config.id2label.clone(),
            label2id: config.label2id.clone(),
        };
        let encoder = BertEncoder::new(p / "encoder", &bert_config);
        ElectraModel {
            embeddings,
            embeddings_project,
            encoder,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `ElectraModelOutput` containing:
    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///
    /// # Errors
    ///
    /// Returns a `RustBertError::ValueError` if both/neither of `input_ids` and
    /// `input_embeds` are provided, or if the attention mask is not 2- or 3-dimensional.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::electra::{ElectraModel, ElectraConfig};
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = ElectraConfig::from_file(config_path);
    /// # let electra_model: ElectraModel = ElectraModel::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let model_output = no_grad(|| {
    ///     electra_model
    ///         .forward_t(
    ///             Some(&input_tensor),
    ///             Some(&mask),
    ///             Some(&token_type_ids),
    ///             Some(&position_ids),
    ///             None,
    ///             false,
    ///         )
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<ElectraModelOutput, RustBertError> {
        // Validates that exactly one of input_ids / input_embeds is provided.
        let (input_shape, device) =
            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;

        // Default to an all-ones mask (no padding) when none is supplied. The
        // owned tensor is kept in `calc_mask` so the borrow below stays valid.
        let calc_mask = if mask.is_none() {
            Some(Tensor::ones(input_shape, (Kind::Int64, device)))
        } else {
            None
        };
        let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());

        // Insert singleton dimensions so the mask can broadcast against the
        // per-head attention scores: (batch, seq) -> (batch, 1, 1, seq) and
        // (batch, seq, seq) -> (batch, 1, seq, seq).
        let extended_attention_mask = match mask.dim() {
            3 => mask.unsqueeze(1),
            2 => mask.unsqueeze(1).unsqueeze(1),
            _ => {
                return Err(RustBertError::ValueError(
                    "Invalid attention mask dimension, must be 2 or 3".into(),
                ));
            }
        };

        let hidden_states = self.embeddings.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;

        // Project embeddings up to the hidden dimension when the sizes differ
        // (the projection layer only exists in that case — see `new`).
        let hidden_states = match &self.embeddings_project {
            Some(layer) => hidden_states.apply(layer),
            None => hidden_states,
        };

        // The two `None` arguments are the encoder hidden states / mask used
        // only in cross-attention setups; Electra is encoder-only.
        let encoder_output = self.encoder.forward_t(
            &hidden_states,
            Some(&extended_attention_mask),
            None,
            None,
            train,
        );

        Ok(ElectraModelOutput {
            hidden_state: encoder_output.hidden_state,
            all_hidden_states: encoder_output.all_hidden_states,
            all_attentions: encoder_output.all_attentions,
        })
    }
}
313
/// # Electra Discriminator head
/// Discriminator head for Electra models
/// It is made of the following blocks:
/// - `dense`: linear layer of dimension (*hidden_size*, *hidden_size*)
/// - `dense_prediction`: linear layer of dimension (*hidden_size*, *1*) mapping the model output to a 1-dimension space to identify original and generated tokens
/// - `activation`: activation layer (one of GeLU, ReLU or Mish)
pub struct ElectraDiscriminatorHead {
    // Hidden-to-hidden projection applied before the activation
    dense: nn::Linear,
    // Maps each token representation to a single "is generated" logit
    dense_prediction: nn::Linear,
    // Taken from the configuration's `hidden_act`
    activation: TensorFunction,
}
325
326/// Defines the implementation of the ElectraDiscriminatorHead.
327impl ElectraDiscriminatorHead {
328    /// Build a new `ElectraDiscriminatorHead`
329    ///
330    /// # Arguments
331    ///
332    /// * `p` - Variable store path for the root of the Electra model
333    /// * `config` - `ElectraConfig` object defining the model architecture
334    ///
335    /// # Example
336    ///
337    /// ```no_run
338    /// use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
339    /// use rust_bert::Config;
340    /// use std::path::Path;
341    /// use tch::{nn, Device};
342    ///
343    /// let config_path = Path::new("path/to/config.json");
344    /// let device = Device::Cpu;
345    /// let p = nn::VarStore::new(device);
346    /// let config = ElectraConfig::from_file(config_path);
347    /// let discriminator_head = ElectraDiscriminatorHead::new(&p.root() / "electra", &config);
348    /// ```
349    pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraDiscriminatorHead
350    where
351        P: Borrow<nn::Path<'p>>,
352    {
353        let p = p.borrow();
354
355        let dense = nn::linear(
356            p / "dense",
357            config.hidden_size,
358            config.hidden_size,
359            Default::default(),
360        );
361        let dense_prediction = nn::linear(
362            p / "dense_prediction",
363            config.hidden_size,
364            1,
365            Default::default(),
366        );
367        let activation = config.hidden_act.get_function();
368        ElectraDiscriminatorHead {
369            dense,
370            dense_prediction,
371            activation,
372        }
373    }
374
375    /// Forward pass through the discriminator head
376    ///
377    /// # Arguments
378    ///
379    /// * `encoder_hidden_states` - Reference to input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
380    ///
381    /// # Returns
382    ///
383    /// * `output` - `Tensor` of shape (*batch size*, *sequence_length*)
384    ///
385    /// # Example
386    ///
387    /// ```no_run
388    /// # use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
389    /// # use tch::{nn, Device, Tensor, no_grad};
390    /// # use rust_bert::Config;
391    /// # use std::path::Path;
392    /// # use tch::kind::Kind::Float;
393    /// # let config_path = Path::new("path/to/config.json");
394    /// # let device = Device::Cpu;
395    /// # let vs = nn::VarStore::new(device);
396    /// # let config = ElectraConfig::from_file(config_path);
397    /// # let discriminator_head = ElectraDiscriminatorHead::new(&vs.root(), &config);
398    /// let (batch_size, sequence_length) = (64, 128);
399    /// let input_tensor = Tensor::rand(
400    ///     &[batch_size, sequence_length, config.hidden_size],
401    ///     (Float, device),
402    /// );
403    ///
404    /// let output = no_grad(|| discriminator_head.forward(&input_tensor));
405    /// ```
406    pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
407        let output = encoder_hidden_states.apply(&self.dense);
408        let output = (self.activation.get_fn())(&output);
409        output.apply(&self.dense_prediction).squeeze()
410    }
411}
412
/// # Electra Generator head
/// Generator head for Electra models
/// It is made of the following blocks:
/// - `dense`: linear layer of dimension (*hidden_size*, *embeddings_size*) to project the model output dimension  to the embeddings size
/// - `layer_norm`: Layer normalization
/// - `activation`: GeLU activation
pub struct ElectraGeneratorHead {
    // Projects encoder output (hidden_size) down to the embedding dimension
    dense: nn::Linear,
    // Layer normalization applied over the embedding dimension
    layer_norm: nn::LayerNorm,
    // Always GeLU for the generator head (see `ElectraGeneratorHead::new`)
    activation: TensorFunction,
}
424
425/// Defines the implementation of the ElectraGeneratorHead.
426impl ElectraGeneratorHead {
427    /// Build a new `ElectraGeneratorHead`
428    ///
429    /// # Arguments
430    ///
431    /// * `p` - Variable store path for the root of the Electra model
432    /// * `config` - `ElectraConfig` object defining the model architecture
433    ///
434    /// # Example
435    ///
436    /// ```no_run
437    /// use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
438    /// use rust_bert::Config;
439    /// use std::path::Path;
440    /// use tch::{nn, Device};
441    ///
442    /// let config_path = Path::new("path/to/config.json");
443    /// let device = Device::Cpu;
444    /// let p = nn::VarStore::new(device);
445    /// let config = ElectraConfig::from_file(config_path);
446    /// let generator_head = ElectraGeneratorHead::new(&p.root() / "electra", &config);
447    /// ```
448    pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraGeneratorHead
449    where
450        P: Borrow<nn::Path<'p>>,
451    {
452        let p = p.borrow();
453
454        let layer_norm = nn::layer_norm(
455            p / "LayerNorm",
456            vec![config.embedding_size],
457            Default::default(),
458        );
459        let dense = nn::linear(
460            p / "dense",
461            config.hidden_size,
462            config.embedding_size,
463            Default::default(),
464        );
465        let activation = Activation::gelu.get_function();
466
467        ElectraGeneratorHead {
468            dense,
469            layer_norm,
470            activation,
471        }
472    }
473
474    /// Forward pass through the generator head
475    ///
476    /// # Arguments
477    ///
478    /// * `encoder_hidden_states` - Reference to input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
479    ///
480    /// # Returns
481    ///
482    /// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *embeddings_size*)
483    ///
484    /// # Example
485    ///
486    /// ```no_run
487    /// # use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
488    /// # use tch::{nn, Device, Tensor, no_grad};
489    /// # use rust_bert::Config;
490    /// # use std::path::Path;
491    /// # use tch::kind::Kind::Float;
492    /// # let config_path = Path::new("path/to/config.json");
493    /// # let device = Device::Cpu;
494    /// # let vs = nn::VarStore::new(device);
495    /// # let config = ElectraConfig::from_file(config_path);
496    /// # let generator_head = ElectraGeneratorHead::new(&vs.root(), &config);
497    /// let (batch_size, sequence_length) = (64, 128);
498    /// let input_tensor = Tensor::rand(
499    ///     &[batch_size, sequence_length, config.hidden_size],
500    ///     (Float, device),
501    /// );
502    ///
503    /// let output = no_grad(|| generator_head.forward(&input_tensor));
504    /// ```
505    pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
506        let output = encoder_hidden_states.apply(&self.dense);
507        let output = (self.activation.get_fn())(&output);
508        output.apply(&self.layer_norm)
509    }
510}
511
/// # Electra for Masked Language Modeling
/// Masked Language modeling Electra model
/// It is made of the following blocks:
/// - `electra`: `ElectraModel` (based on a `BertEncoder` and custom embeddings)
/// - `generator_head`: `ElectraGeneratorHead` to generate token predictions of dimension *embedding_size*
/// - `lm_head`: linear layer of dimension (*embeddings_size*, *vocab_size*) to project the output to the vocab size
pub struct ElectraForMaskedLM {
    // Shared base transformer
    electra: ElectraModel,
    // Maps encoder output (hidden_size) to the embedding dimension
    generator_head: ElectraGeneratorHead,
    // Projects the generator head output onto the vocabulary
    lm_head: nn::Linear,
}
523
524/// Defines the implementation of the ElectraForMaskedLM.
525impl ElectraForMaskedLM {
526    /// Build a new `ElectraForMaskedLM`
527    ///
528    /// # Arguments
529    ///
530    /// * `p` - Variable store path for the root of the Electra model
531    /// * `config` - `ElectraConfig` object defining the model architecture
532    ///
533    /// # Example
534    ///
535    /// ```no_run
536    /// use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
537    /// use rust_bert::Config;
538    /// use std::path::Path;
539    /// use tch::{nn, Device};
540    ///
541    /// let config_path = Path::new("path/to/config.json");
542    /// let device = Device::Cpu;
543    /// let p = nn::VarStore::new(device);
544    /// let config = ElectraConfig::from_file(config_path);
545    /// let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&p.root(), &config);
546    /// ```
547    pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraForMaskedLM
548    where
549        P: Borrow<nn::Path<'p>>,
550    {
551        let p = p.borrow();
552
553        let electra = ElectraModel::new(p / "electra", config);
554        let generator_head = ElectraGeneratorHead::new(p / "generator_predictions", config);
555        let lm_head = nn::linear(
556            p / "generator_lm_head",
557            config.embedding_size,
558            config.vocab_size,
559            Default::default(),
560        );
561
562        ElectraForMaskedLM {
563            electra,
564            generator_head,
565            lm_head,
566        }
567    }
568
569    /// Forward pass through the model
570    ///
571    /// # Arguments
572    ///
573    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
574    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
575    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
576    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
577    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
578    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
579    ///
580    /// # Returns
581    ///
582    /// * `ElectraMaskedLMOutput` containing:
583    ///   - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
584    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
585    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
586    ///
587    /// # Example
588    ///
589    /// ```no_run
590    /// # use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
591    /// # use tch::{nn, Device, Tensor, no_grad};
592    /// # use rust_bert::Config;
593    /// # use std::path::Path;
594    /// # use tch::kind::Kind::Int64;
595    /// # let config_path = Path::new("path/to/config.json");
596    /// # let device = Device::Cpu;
597    /// # let vs = nn::VarStore::new(device);
598    /// # let config = ElectraConfig::from_file(config_path);
599    /// # let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&vs.root(), &config);
600    /// let (batch_size, sequence_length) = (64, 128);
601    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
602    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
603    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
604    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
605    ///     .expand(&[batch_size, sequence_length], true);
606    ///
607    /// let model_output = no_grad(|| {
608    ///     electra_model.forward_t(
609    ///         Some(&input_tensor),
610    ///         Some(&mask),
611    ///         Some(&token_type_ids),
612    ///         Some(&position_ids),
613    ///         None,
614    ///         false,
615    ///     )
616    /// });
617    /// ```
618    pub fn forward_t(
619        &self,
620        input_ids: Option<&Tensor>,
621        mask: Option<&Tensor>,
622        token_type_ids: Option<&Tensor>,
623        position_ids: Option<&Tensor>,
624        input_embeds: Option<&Tensor>,
625        train: bool,
626    ) -> ElectraMaskedLMOutput {
627        let base_model_output = self
628            .electra
629            .forward_t(
630                input_ids,
631                mask,
632                token_type_ids,
633                position_ids,
634                input_embeds,
635                train,
636            )
637            .unwrap();
638        let hidden_states = self.generator_head.forward(&base_model_output.hidden_state);
639        let prediction_scores = hidden_states.apply(&self.lm_head);
640        ElectraMaskedLMOutput {
641            prediction_scores,
642            all_hidden_states: base_model_output.all_hidden_states,
643            all_attentions: base_model_output.all_attentions,
644        }
645    }
646}
647
/// # Electra Discriminator
/// Electra discriminator model
/// It is made of the following blocks:
/// - `electra`: `ElectraModel` (based on a `BertEncoder` and custom embeddings)
/// - `discriminator_head`: `ElectraDiscriminatorHead` to classify each token into either `original` or `generated`
pub struct ElectraDiscriminator {
    // Shared base transformer
    electra: ElectraModel,
    // Per-token original-vs-generated classification head
    discriminator_head: ElectraDiscriminatorHead,
}
657
658/// Defines the implementation of the ElectraDiscriminator.
659impl ElectraDiscriminator {
660    /// Build a new `ElectraDiscriminator`
661    ///
662    /// # Arguments
663    ///
664    /// * `p` - Variable store path for the root of the Electra model
665    /// * `config` - `ElectraConfig` object defining the model architecture
666    ///
667    /// # Example
668    ///
669    /// ```no_run
670    /// use rust_bert::electra::{ElectraConfig, ElectraDiscriminator};
671    /// use rust_bert::Config;
672    /// use std::path::Path;
673    /// use tch::{nn, Device};
674    ///
675    /// let config_path = Path::new("path/to/config.json");
676    /// let device = Device::Cpu;
677    /// let p = nn::VarStore::new(device);
678    /// let config = ElectraConfig::from_file(config_path);
679    /// let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&p.root(), &config);
680    /// ```
681    pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraDiscriminator
682    where
683        P: Borrow<nn::Path<'p>>,
684    {
685        let p = p.borrow();
686
687        let electra = ElectraModel::new(p / "electra", config);
688        let discriminator_head =
689            ElectraDiscriminatorHead::new(p / "discriminator_predictions", config);
690
691        ElectraDiscriminator {
692            electra,
693            discriminator_head,
694        }
695    }
696
697    /// Forward pass through the model
698    ///
699    /// # Arguments
700    ///
701    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
702    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `ElectraDiscriminatorOutput` containing:
    ///   - `probabilities` - `Tensor` of shape (*batch size*, *sequence_length*) containing the probability of each token to be generated by a language model (sigmoid output of the discriminator head)
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* containing the attention weights for each layer
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = ElectraConfig::from_file(config_path);
    /// # let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&vs.root(), &config);
    ///  let (batch_size, sequence_length) = (64, 128);
    ///  let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    ///  let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///  let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
    ///
    ///  let model_output = no_grad(|| {
    ///    electra_model
    ///         .forward_t(Some(&input_tensor),
    ///                    Some(&mask),
    ///                    Some(&token_type_ids),
    ///                    Some(&position_ids),
    ///                    None,
    ///                    false)
    ///    });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> ElectraDiscriminatorOutput {
        // Run the base Electra encoder to get contextual token representations.
        let base_model_output = self
            .electra
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            )
            // NOTE(review): this unwrap panics if the base model returns an
            // error (presumably on an invalid `input_ids`/`input_embeds`
            // combination — confirm against `ElectraModel::forward_t`).
            .unwrap();
        // The discriminator head emits one raw score per token; sigmoid maps
        // those scores to per-token probabilities.
        let probabilities = self
            .discriminator_head
            .forward(&base_model_output.hidden_state)
            .sigmoid();
        ElectraDiscriminatorOutput {
            probabilities,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        }
    }
774}
775
/// # Electra for token classification (e.g. POS, NER)
/// Electra model with a token tagging head
/// It is made of the following blocks:
/// - `electra`: `ElectraModel` (based on a `BertEncoder` and custom embeddings)
/// - `dropout`: Dropout layer
/// - `classifier`: linear layer of dimension (*hidden_size*, *num_classes*) to project the output to the target label space
pub struct ElectraForTokenClassification {
    /// Base Electra model producing the contextual hidden states
    electra: ElectraModel,
    /// Dropout applied to the hidden states before classification
    dropout: Dropout,
    /// Linear projection from *hidden_size* to the number of labels
    classifier: nn::Linear,
}
787
788/// Defines the implementation of the ElectraForTokenClassification.
789impl ElectraForTokenClassification {
790    /// Build a new `ElectraForTokenClassification`
791    ///
792    /// # Arguments
793    ///
794    /// * `p` - Variable store path for the root of the Electra model
795    /// * `config` - `ElectraConfig` object defining the model architecture
796    ///
797    /// # Example
798    ///
799    /// ```no_run
800    /// use rust_bert::electra::{ElectraConfig, ElectraForTokenClassification};
801    /// use rust_bert::Config;
802    /// use std::path::Path;
803    /// use tch::{nn, Device};
804    /// let config_path = Path::new("path/to/config.json");
805    /// let device = Device::Cpu;
806    /// let p = nn::VarStore::new(device);
807    /// let config = ElectraConfig::from_file(config_path);
808    /// let electra_model: ElectraForTokenClassification =
809    ///     ElectraForTokenClassification::new(&p.root(), &config).unwrap();
810    /// ```
811    pub fn new<'p, P>(
812        p: P,
813        config: &ElectraConfig,
814    ) -> Result<ElectraForTokenClassification, RustBertError>
815    where
816        P: Borrow<nn::Path<'p>>,
817    {
818        let p = p.borrow();
819
820        let electra = ElectraModel::new(p / "electra", config);
821        let dropout = Dropout::new(config.hidden_dropout_prob);
822        let num_labels = config
823            .id2label
824            .as_ref()
825            .ok_or_else(|| {
826                RustBertError::InvalidConfigurationError(
827                    "id2label must be provided for classifiers".to_string(),
828                )
829            })?
830            .len() as i64;
831        let classifier = nn::linear(
832            p / "classifier",
833            config.hidden_size,
834            num_labels,
835            Default::default(),
836        );
837
838        Ok(ElectraForTokenClassification {
839            electra,
840            dropout,
841            classifier,
842        })
843    }
844
845    /// Forward pass through the model
846    ///
847    /// # Arguments
848    ///
849    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
850    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
851    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
852    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
853    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
854    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
855    ///
856    /// # Returns
857    ///
858    /// * `ElectraTokenClassificationOutput` containing:
859    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
860    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
861    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
862    ///
863    /// # Example
864    ///
865    /// ```no_run
866    /// # use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
867    /// # use tch::{nn, Device, Tensor, no_grad};
868    /// # use rust_bert::Config;
869    /// # use std::path::Path;
870    /// # use tch::kind::Kind::Int64;
871    /// # let config_path = Path::new("path/to/config.json");
872    /// # let device = Device::Cpu;
873    /// # let vs = nn::VarStore::new(device);
874    /// # let config = ElectraConfig::from_file(config_path);
875    /// # let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&vs.root(), &config).unwrap();
876    ///  let (batch_size, sequence_length) = (64, 128);
877    ///  let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
878    ///  let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
879    ///  let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
880    ///  let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
881    ///
882    ///  let model_output = no_grad(|| {
883    ///    electra_model
884    ///         .forward_t(Some(&input_tensor),
885    ///                    Some(&mask),
886    ///                    Some(&token_type_ids),
887    ///                    Some(&position_ids),
888    ///                    None,
889    ///                    false)
890    ///    });
891    /// ```
892    pub fn forward_t(
893        &self,
894        input_ids: Option<&Tensor>,
895        mask: Option<&Tensor>,
896        token_type_ids: Option<&Tensor>,
897        position_ids: Option<&Tensor>,
898        input_embeds: Option<&Tensor>,
899        train: bool,
900    ) -> ElectraTokenClassificationOutput {
901        let base_model_output = self
902            .electra
903            .forward_t(
904                input_ids,
905                mask,
906                token_type_ids,
907                position_ids,
908                input_embeds,
909                train,
910            )
911            .unwrap();
912        let logits = base_model_output
913            .hidden_state
914            .apply_t(&self.dropout, train)
915            .apply(&self.classifier);
916        ElectraTokenClassificationOutput {
917            logits,
918            all_hidden_states: base_model_output.all_hidden_states,
919            all_attentions: base_model_output.all_attentions,
920        }
921    }
922}
923
/// Container for the Electra model output.
pub struct ElectraModelOutput {
    /// Last hidden states from the model, of shape (*batch size*, *sequence_length*, *hidden_size*)
    pub hidden_state: Tensor,
    /// Hidden states for all intermediate layers, if the model was set up to collect them
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if the model was set up to collect them
    pub all_attentions: Option<Vec<Tensor>>,
}
933
/// Container for the Electra discriminator model output.
pub struct ElectraDiscriminatorOutput {
    /// Probabilities for each sequence item (token) to be generated by a language model,
    /// of shape (*batch size*, *sequence_length*); sigmoid output of the discriminator head
    pub probabilities: Tensor,
    /// Hidden states for all intermediate layers, if the model was set up to collect them
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if the model was set up to collect them
    pub all_attentions: Option<Vec<Tensor>>,
}
943
/// Container for the Electra masked LM model output.
pub struct ElectraMaskedLMOutput {
    /// Logits for the vocabulary items at each sequence position
    /// (presumably of shape (*batch size*, *sequence_length*, *vocab_size*) — confirm against the generator head)
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers, if the model was set up to collect them
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if the model was set up to collect them
    pub all_attentions: Option<Vec<Tensor>>,
}
953
/// Container for the Electra token classification model output.
pub struct ElectraTokenClassificationOutput {
    /// Logits for each sequence item (token) for each target class,
    /// of shape (*batch size*, *sequence_length*, *num_labels*)
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, if the model was set up to collect them
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, if the model was set up to collect them
    pub all_attentions: Option<Vec<Tensor>>,
}