// rust_bert/models/electra/electra_model.rs
1// Copyright 2020 The Google Research Authors.
2// Copyright 2019-present, the HuggingFace Inc. team
3// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4// Copyright 2019 Guillaume Becquin
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8// http://www.apache.org/licenses/LICENSE-2.0
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::bert::BertConfig;
16use crate::common::activations::Activation;
17use crate::common::dropout::Dropout;
18use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
19use crate::electra::embeddings::ElectraEmbeddings;
20use crate::{bert::encoder::BertEncoder, common::activations::TensorFunction};
21use crate::{Config, RustBertError};
22use serde::{Deserialize, Serialize};
23use std::{borrow::Borrow, collections::HashMap};
24use tch::{nn, Kind, Tensor};
25
26/// # Electra Pretrained model weight files
27pub struct ElectraModelResources;
28
29/// # Electra Pretrained model config files
30pub struct ElectraConfigResources;
31
32/// # Electra Pretrained model vocab files
33pub struct ElectraVocabResources;
34
35impl ElectraModelResources {
36 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
37 pub const BASE_GENERATOR: (&'static str, &'static str) = (
38 "electra-base-generator/model",
39 "https://huggingface.co/google/electra-base-generator/resolve/main/rust_model.ot",
40 );
41 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
42 pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
43 "electra-base-discriminator/model",
44 "https://huggingface.co/google/electra-base-discriminator/resolve/main/rust_model.ot",
45 );
46}
47
48impl ElectraConfigResources {
49 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
50 pub const BASE_GENERATOR: (&'static str, &'static str) = (
51 "electra-base-generator/config",
52 "https://huggingface.co/google/electra-base-generator/resolve/main/config.json",
53 );
54 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
55 pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
56 "electra-base-discriminator/config",
57 "https://huggingface.co/google/electra-base-discriminator/resolve/main/config.json",
58 );
59}
60
61impl ElectraVocabResources {
62 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
63 pub const BASE_GENERATOR: (&'static str, &'static str) = (
64 "electra-base-generator/vocab",
65 "https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
66 );
67 /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
68 pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
69 "electra-base-discriminator/vocab",
70 "https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt",
71 );
72}
73
74#[derive(Debug, Serialize, Deserialize, Clone)]
75/// # Electra model configuration
76/// Defines the Electra model architecture (e.g. number of layers, hidden layer size, label mapping...)
77pub struct ElectraConfig {
78 pub hidden_act: Activation,
79 pub attention_probs_dropout_prob: f64,
80 pub embedding_size: i64,
81 pub hidden_dropout_prob: f64,
82 pub hidden_size: i64,
83 pub initializer_range: f32,
84 pub layer_norm_eps: Option<f64>,
85 pub intermediate_size: i64,
86 pub max_position_embeddings: i64,
87 pub num_attention_heads: i64,
88 pub num_hidden_layers: i64,
89 pub type_vocab_size: i64,
90 pub vocab_size: i64,
91 pub pad_token_id: i64,
92 pub output_past: Option<bool>,
93 pub output_attentions: Option<bool>,
94 pub output_hidden_states: Option<bool>,
95 pub id2label: Option<HashMap<i64, String>>,
96 pub label2id: Option<HashMap<String, i64>>,
97}
98
99impl Config for ElectraConfig {}
100
101impl Default for ElectraConfig {
102 fn default() -> Self {
103 ElectraConfig {
104 hidden_act: Activation::gelu,
105 attention_probs_dropout_prob: 0.1,
106 embedding_size: 128,
107 hidden_dropout_prob: 0.1,
108 hidden_size: 256,
109 initializer_range: 0.02,
110 layer_norm_eps: Some(1e-12),
111 intermediate_size: 1024,
112 max_position_embeddings: 512,
113 num_attention_heads: 4,
114 num_hidden_layers: 12,
115 type_vocab_size: 2,
116 vocab_size: 30522,
117 pad_token_id: 0,
118 output_past: None,
119 output_attentions: None,
120 output_hidden_states: None,
121 id2label: None,
122 label2id: None,
123 }
124 }
125}
126
127/// # Electra Base model
128/// Base architecture for Electra models.
129/// It is made of the following blocks:
130/// - `embeddings`: `token`, `position` and `segment_id` embeddings. Note that in contrast to BERT, the embeddings dimension is not necessarily equal to the hidden layer dimensions
131/// - `encoder`: BertEncoder (transformer) made of a vector of layers. Each layer is made of a self-attention layer, an intermediate (linear) and output (linear + layer norm) layers
132/// - `embeddings_project`: (optional) linear layer applied to project the embeddings space to the hidden layer dimension
133pub struct ElectraModel {
134 embeddings: ElectraEmbeddings,
135 embeddings_project: Option<nn::Linear>,
136 encoder: BertEncoder,
137}
138
139/// Defines the implementation of the ElectraModel.
140impl ElectraModel {
141 /// Build a new `ElectraModel`
142 ///
143 /// # Arguments
144 ///
145 /// * `p` - Variable store path for the root of the Electra model
146 /// * `config` - `ElectraConfig` object defining the model architecture
147 ///
148 /// # Example
149 ///
150 /// ```no_run
151 /// use rust_bert::electra::{ElectraConfig, ElectraModel};
152 /// use rust_bert::Config;
153 /// use std::path::Path;
154 /// use tch::{nn, Device};
155 ///
156 /// let config_path = Path::new("path/to/config.json");
157 /// let device = Device::Cpu;
158 /// let p = nn::VarStore::new(device);
159 /// let config = ElectraConfig::from_file(config_path);
160 /// let electra_model: ElectraModel = ElectraModel::new(&p.root() / "electra", &config);
161 /// ```
162 pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraModel
163 where
164 P: Borrow<nn::Path<'p>>,
165 {
166 let p = p.borrow();
167
168 let embeddings = ElectraEmbeddings::new(p / "embeddings", config);
169 let embeddings_project = if config.embedding_size != config.hidden_size {
170 Some(nn::linear(
171 p / "embeddings_project",
172 config.embedding_size,
173 config.hidden_size,
174 Default::default(),
175 ))
176 } else {
177 None
178 };
179 let bert_config = BertConfig {
180 hidden_act: config.hidden_act,
181 attention_probs_dropout_prob: config.attention_probs_dropout_prob,
182 hidden_dropout_prob: config.hidden_dropout_prob,
183 hidden_size: config.hidden_size,
184 initializer_range: config.initializer_range,
185 intermediate_size: config.intermediate_size,
186 max_position_embeddings: config.max_position_embeddings,
187 num_attention_heads: config.num_attention_heads,
188 num_hidden_layers: config.num_hidden_layers,
189 type_vocab_size: config.type_vocab_size,
190 vocab_size: config.vocab_size,
191 output_attentions: config.output_attentions,
192 output_hidden_states: config.output_hidden_states,
193 is_decoder: None,
194 id2label: config.id2label.clone(),
195 label2id: config.label2id.clone(),
196 };
197 let encoder = BertEncoder::new(p / "encoder", &bert_config);
198 ElectraModel {
199 embeddings,
200 embeddings_project,
201 encoder,
202 }
203 }
204
205 /// Forward pass through the model
206 ///
207 /// # Arguments
208 ///
209 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
210 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
211 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
212 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
213 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
214 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
215 ///
216 /// # Returns
217 ///
218 /// * `ElectraModelOutput` containing:
219 /// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
220 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
221 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
222 ///
223 /// # Example
224 ///
225 /// ```no_run
226 /// # use rust_bert::electra::{ElectraModel, ElectraConfig};
227 /// # use tch::{nn, Device, Tensor, no_grad};
228 /// # use rust_bert::Config;
229 /// # use std::path::Path;
230 /// # use tch::kind::Kind::Int64;
231 /// # let config_path = Path::new("path/to/config.json");
232 /// # let device = Device::Cpu;
233 /// # let vs = nn::VarStore::new(device);
234 /// # let config = ElectraConfig::from_file(config_path);
235 /// # let electra_model: ElectraModel = ElectraModel::new(&vs.root(), &config);
236 /// let (batch_size, sequence_length) = (64, 128);
237 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
238 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
239 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
240 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
241 /// .expand(&[batch_size, sequence_length], true);
242 ///
243 /// let model_output = no_grad(|| {
244 /// electra_model
245 /// .forward_t(
246 /// Some(&input_tensor),
247 /// Some(&mask),
248 /// Some(&token_type_ids),
249 /// Some(&position_ids),
250 /// None,
251 /// false,
252 /// )
253 /// .unwrap()
254 /// });
255 /// ```
256 pub fn forward_t(
257 &self,
258 input_ids: Option<&Tensor>,
259 mask: Option<&Tensor>,
260 token_type_ids: Option<&Tensor>,
261 position_ids: Option<&Tensor>,
262 input_embeds: Option<&Tensor>,
263 train: bool,
264 ) -> Result<ElectraModelOutput, RustBertError> {
265 let (input_shape, device) =
266 get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
267
268 let calc_mask = if mask.is_none() {
269 Some(Tensor::ones(input_shape, (Kind::Int64, device)))
270 } else {
271 None
272 };
273 let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
274
275 let extended_attention_mask = match mask.dim() {
276 3 => mask.unsqueeze(1),
277 2 => mask.unsqueeze(1).unsqueeze(1),
278 _ => {
279 return Err(RustBertError::ValueError(
280 "Invalid attention mask dimension, must be 2 or 3".into(),
281 ));
282 }
283 };
284
285 let hidden_states = self.embeddings.forward_t(
286 input_ids,
287 token_type_ids,
288 position_ids,
289 input_embeds,
290 train,
291 )?;
292
293 let hidden_states = match &self.embeddings_project {
294 Some(layer) => hidden_states.apply(layer),
295 None => hidden_states,
296 };
297
298 let encoder_output = self.encoder.forward_t(
299 &hidden_states,
300 Some(&extended_attention_mask),
301 None,
302 None,
303 train,
304 );
305
306 Ok(ElectraModelOutput {
307 hidden_state: encoder_output.hidden_state,
308 all_hidden_states: encoder_output.all_hidden_states,
309 all_attentions: encoder_output.all_attentions,
310 })
311 }
312}
313
314/// # Electra Discriminator head
315/// Discriminator head for Electra models
316/// It is made of the following blocks:
317/// - `dense`: linear layer of dimension (*hidden_size*, *hidden_size*)
318/// - `dense_prediction`: linear layer of dimension (*hidden_size*, *1*) mapping the model output to a 1-dimension space to identify original and generated tokens
319/// - `activation`: activation layer (one of GeLU, ReLU or Mish)
320pub struct ElectraDiscriminatorHead {
321 dense: nn::Linear,
322 dense_prediction: nn::Linear,
323 activation: TensorFunction,
324}
325
326/// Defines the implementation of the ElectraDiscriminatorHead.
327impl ElectraDiscriminatorHead {
328 /// Build a new `ElectraDiscriminatorHead`
329 ///
330 /// # Arguments
331 ///
332 /// * `p` - Variable store path for the root of the Electra model
333 /// * `config` - `ElectraConfig` object defining the model architecture
334 ///
335 /// # Example
336 ///
337 /// ```no_run
338 /// use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
339 /// use rust_bert::Config;
340 /// use std::path::Path;
341 /// use tch::{nn, Device};
342 ///
343 /// let config_path = Path::new("path/to/config.json");
344 /// let device = Device::Cpu;
345 /// let p = nn::VarStore::new(device);
346 /// let config = ElectraConfig::from_file(config_path);
347 /// let discriminator_head = ElectraDiscriminatorHead::new(&p.root() / "electra", &config);
348 /// ```
349 pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraDiscriminatorHead
350 where
351 P: Borrow<nn::Path<'p>>,
352 {
353 let p = p.borrow();
354
355 let dense = nn::linear(
356 p / "dense",
357 config.hidden_size,
358 config.hidden_size,
359 Default::default(),
360 );
361 let dense_prediction = nn::linear(
362 p / "dense_prediction",
363 config.hidden_size,
364 1,
365 Default::default(),
366 );
367 let activation = config.hidden_act.get_function();
368 ElectraDiscriminatorHead {
369 dense,
370 dense_prediction,
371 activation,
372 }
373 }
374
375 /// Forward pass through the discriminator head
376 ///
377 /// # Arguments
378 ///
379 /// * `encoder_hidden_states` - Reference to input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
380 ///
381 /// # Returns
382 ///
383 /// * `output` - `Tensor` of shape (*batch size*, *sequence_length*)
384 ///
385 /// # Example
386 ///
387 /// ```no_run
388 /// # use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
389 /// # use tch::{nn, Device, Tensor, no_grad};
390 /// # use rust_bert::Config;
391 /// # use std::path::Path;
392 /// # use tch::kind::Kind::Float;
393 /// # let config_path = Path::new("path/to/config.json");
394 /// # let device = Device::Cpu;
395 /// # let vs = nn::VarStore::new(device);
396 /// # let config = ElectraConfig::from_file(config_path);
397 /// # let discriminator_head = ElectraDiscriminatorHead::new(&vs.root(), &config);
398 /// let (batch_size, sequence_length) = (64, 128);
399 /// let input_tensor = Tensor::rand(
400 /// &[batch_size, sequence_length, config.hidden_size],
401 /// (Float, device),
402 /// );
403 ///
404 /// let output = no_grad(|| discriminator_head.forward(&input_tensor));
405 /// ```
406 pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
407 let output = encoder_hidden_states.apply(&self.dense);
408 let output = (self.activation.get_fn())(&output);
409 output.apply(&self.dense_prediction).squeeze()
410 }
411}
412
413/// # Electra Generator head
414/// Generator head for Electra models
415/// It is made of the following blocks:
416/// - `dense`: linear layer of dimension (*hidden_size*, *embeddings_size*) to project the model output dimension to the embeddings size
417/// - `layer_norm`: Layer normalization
418/// - `activation`: GeLU activation
419pub struct ElectraGeneratorHead {
420 dense: nn::Linear,
421 layer_norm: nn::LayerNorm,
422 activation: TensorFunction,
423}
424
425/// Defines the implementation of the ElectraGeneratorHead.
426impl ElectraGeneratorHead {
427 /// Build a new `ElectraGeneratorHead`
428 ///
429 /// # Arguments
430 ///
431 /// * `p` - Variable store path for the root of the Electra model
432 /// * `config` - `ElectraConfig` object defining the model architecture
433 ///
434 /// # Example
435 ///
436 /// ```no_run
437 /// use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
438 /// use rust_bert::Config;
439 /// use std::path::Path;
440 /// use tch::{nn, Device};
441 ///
442 /// let config_path = Path::new("path/to/config.json");
443 /// let device = Device::Cpu;
444 /// let p = nn::VarStore::new(device);
445 /// let config = ElectraConfig::from_file(config_path);
446 /// let generator_head = ElectraGeneratorHead::new(&p.root() / "electra", &config);
447 /// ```
448 pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraGeneratorHead
449 where
450 P: Borrow<nn::Path<'p>>,
451 {
452 let p = p.borrow();
453
454 let layer_norm = nn::layer_norm(
455 p / "LayerNorm",
456 vec![config.embedding_size],
457 Default::default(),
458 );
459 let dense = nn::linear(
460 p / "dense",
461 config.hidden_size,
462 config.embedding_size,
463 Default::default(),
464 );
465 let activation = Activation::gelu.get_function();
466
467 ElectraGeneratorHead {
468 dense,
469 layer_norm,
470 activation,
471 }
472 }
473
474 /// Forward pass through the generator head
475 ///
476 /// # Arguments
477 ///
478 /// * `encoder_hidden_states` - Reference to input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
479 ///
480 /// # Returns
481 ///
482 /// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *embeddings_size*)
483 ///
484 /// # Example
485 ///
486 /// ```no_run
487 /// # use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
488 /// # use tch::{nn, Device, Tensor, no_grad};
489 /// # use rust_bert::Config;
490 /// # use std::path::Path;
491 /// # use tch::kind::Kind::Float;
492 /// # let config_path = Path::new("path/to/config.json");
493 /// # let device = Device::Cpu;
494 /// # let vs = nn::VarStore::new(device);
495 /// # let config = ElectraConfig::from_file(config_path);
496 /// # let generator_head = ElectraGeneratorHead::new(&vs.root(), &config);
497 /// let (batch_size, sequence_length) = (64, 128);
498 /// let input_tensor = Tensor::rand(
499 /// &[batch_size, sequence_length, config.hidden_size],
500 /// (Float, device),
501 /// );
502 ///
503 /// let output = no_grad(|| generator_head.forward(&input_tensor));
504 /// ```
505 pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
506 let output = encoder_hidden_states.apply(&self.dense);
507 let output = (self.activation.get_fn())(&output);
508 output.apply(&self.layer_norm)
509 }
510}
511
512/// # Electra for Masked Language Modeling
513/// Masked Language modeling Electra model
514/// It is made of the following blocks:
515/// - `electra`: `ElectraModel` (based on a `BertEncoder` and custom embeddings)
516/// - `generator_head`: `ElectraGeneratorHead` to generate token predictions of dimension *embedding_size*
517/// - `lm_head`: linear layer of dimension (*embeddings_size*, *vocab_size*) to project the output to the vocab size
518pub struct ElectraForMaskedLM {
519 electra: ElectraModel,
520 generator_head: ElectraGeneratorHead,
521 lm_head: nn::Linear,
522}
523
524/// Defines the implementation of the ElectraForMaskedLM.
525impl ElectraForMaskedLM {
526 /// Build a new `ElectraForMaskedLM`
527 ///
528 /// # Arguments
529 ///
530 /// * `p` - Variable store path for the root of the Electra model
531 /// * `config` - `ElectraConfig` object defining the model architecture
532 ///
533 /// # Example
534 ///
535 /// ```no_run
536 /// use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
537 /// use rust_bert::Config;
538 /// use std::path::Path;
539 /// use tch::{nn, Device};
540 ///
541 /// let config_path = Path::new("path/to/config.json");
542 /// let device = Device::Cpu;
543 /// let p = nn::VarStore::new(device);
544 /// let config = ElectraConfig::from_file(config_path);
545 /// let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&p.root(), &config);
546 /// ```
547 pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraForMaskedLM
548 where
549 P: Borrow<nn::Path<'p>>,
550 {
551 let p = p.borrow();
552
553 let electra = ElectraModel::new(p / "electra", config);
554 let generator_head = ElectraGeneratorHead::new(p / "generator_predictions", config);
555 let lm_head = nn::linear(
556 p / "generator_lm_head",
557 config.embedding_size,
558 config.vocab_size,
559 Default::default(),
560 );
561
562 ElectraForMaskedLM {
563 electra,
564 generator_head,
565 lm_head,
566 }
567 }
568
569 /// Forward pass through the model
570 ///
571 /// # Arguments
572 ///
573 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
574 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
575 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
576 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
577 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
578 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
579 ///
580 /// # Returns
581 ///
582 /// * `ElectraMaskedLMOutput` containing:
583 /// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
584 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
585 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
586 ///
587 /// # Example
588 ///
589 /// ```no_run
590 /// # use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
591 /// # use tch::{nn, Device, Tensor, no_grad};
592 /// # use rust_bert::Config;
593 /// # use std::path::Path;
594 /// # use tch::kind::Kind::Int64;
595 /// # let config_path = Path::new("path/to/config.json");
596 /// # let device = Device::Cpu;
597 /// # let vs = nn::VarStore::new(device);
598 /// # let config = ElectraConfig::from_file(config_path);
599 /// # let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&vs.root(), &config);
600 /// let (batch_size, sequence_length) = (64, 128);
601 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
602 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
603 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
604 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
605 /// .expand(&[batch_size, sequence_length], true);
606 ///
607 /// let model_output = no_grad(|| {
608 /// electra_model.forward_t(
609 /// Some(&input_tensor),
610 /// Some(&mask),
611 /// Some(&token_type_ids),
612 /// Some(&position_ids),
613 /// None,
614 /// false,
615 /// )
616 /// });
617 /// ```
618 pub fn forward_t(
619 &self,
620 input_ids: Option<&Tensor>,
621 mask: Option<&Tensor>,
622 token_type_ids: Option<&Tensor>,
623 position_ids: Option<&Tensor>,
624 input_embeds: Option<&Tensor>,
625 train: bool,
626 ) -> ElectraMaskedLMOutput {
627 let base_model_output = self
628 .electra
629 .forward_t(
630 input_ids,
631 mask,
632 token_type_ids,
633 position_ids,
634 input_embeds,
635 train,
636 )
637 .unwrap();
638 let hidden_states = self.generator_head.forward(&base_model_output.hidden_state);
639 let prediction_scores = hidden_states.apply(&self.lm_head);
640 ElectraMaskedLMOutput {
641 prediction_scores,
642 all_hidden_states: base_model_output.all_hidden_states,
643 all_attentions: base_model_output.all_attentions,
644 }
645 }
646}
647
648/// # Electra Discriminator
649/// Electra discriminator model
650/// It is made of the following blocks:
651/// - `electra`: `ElectraModel` (based on a `BertEncoder` and custom embeddings)
652/// - `discriminator_head`: `ElectraDiscriminatorHead` to classify each token into either `original` or `generated`
653pub struct ElectraDiscriminator {
654 electra: ElectraModel,
655 discriminator_head: ElectraDiscriminatorHead,
656}
657
658/// Defines the implementation of the ElectraDiscriminator.
659impl ElectraDiscriminator {
660 /// Build a new `ElectraDiscriminator`
661 ///
662 /// # Arguments
663 ///
664 /// * `p` - Variable store path for the root of the Electra model
665 /// * `config` - `ElectraConfig` object defining the model architecture
666 ///
667 /// # Example
668 ///
669 /// ```no_run
670 /// use rust_bert::electra::{ElectraConfig, ElectraDiscriminator};
671 /// use rust_bert::Config;
672 /// use std::path::Path;
673 /// use tch::{nn, Device};
674 ///
675 /// let config_path = Path::new("path/to/config.json");
676 /// let device = Device::Cpu;
677 /// let p = nn::VarStore::new(device);
678 /// let config = ElectraConfig::from_file(config_path);
679 /// let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&p.root(), &config);
680 /// ```
681 pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraDiscriminator
682 where
683 P: Borrow<nn::Path<'p>>,
684 {
685 let p = p.borrow();
686
687 let electra = ElectraModel::new(p / "electra", config);
688 let discriminator_head =
689 ElectraDiscriminatorHead::new(p / "discriminator_predictions", config);
690
691 ElectraDiscriminator {
692 electra,
693 discriminator_head,
694 }
695 }
696
697 /// Forward pass through the model
698 ///
699 /// # Arguments
700 ///
701 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
702 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
703 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
704 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
705 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
706 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
707 ///
708 /// # Returns
709 ///
710 /// * `ElectraDiscriminatorOutput` containing:
711 /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the probability of each token to be generated by a language model
712 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
713 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
714 ///
715 /// # Example
716 ///
717 /// ```no_run
718 /// # use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
719 /// # use tch::{nn, Device, Tensor, no_grad};
720 /// # use rust_bert::Config;
721 /// # use std::path::Path;
722 /// # use tch::kind::Kind::Int64;
723 /// # let config_path = Path::new("path/to/config.json");
724 /// # let device = Device::Cpu;
725 /// # let vs = nn::VarStore::new(device);
726 /// # let config = ElectraConfig::from_file(config_path);
727 /// # let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&vs.root(), &config);
728 /// let (batch_size, sequence_length) = (64, 128);
729 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
730 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
731 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
732 /// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
733 ///
734 /// let model_output = no_grad(|| {
735 /// electra_model
736 /// .forward_t(Some(&input_tensor),
737 /// Some(&mask),
738 /// Some(&token_type_ids),
739 /// Some(&position_ids),
740 /// None,
741 /// false)
742 /// });
743 /// ```
744 pub fn forward_t(
745 &self,
746 input_ids: Option<&Tensor>,
747 mask: Option<&Tensor>,
748 token_type_ids: Option<&Tensor>,
749 position_ids: Option<&Tensor>,
750 input_embeds: Option<&Tensor>,
751 train: bool,
752 ) -> ElectraDiscriminatorOutput {
753 let base_model_output = self
754 .electra
755 .forward_t(
756 input_ids,
757 mask,
758 token_type_ids,
759 position_ids,
760 input_embeds,
761 train,
762 )
763 .unwrap();
764 let probabilities = self
765 .discriminator_head
766 .forward(&base_model_output.hidden_state)
767 .sigmoid();
768 ElectraDiscriminatorOutput {
769 probabilities,
770 all_hidden_states: base_model_output.all_hidden_states,
771 all_attentions: base_model_output.all_attentions,
772 }
773 }
774}
775
776/// # Electra for token classification (e.g. POS, NER)
777/// Electra model with a token tagging head
778/// It is made of the following blocks:
779/// - `electra`: `ElectraModel` (based on a `BertEncoder` and custom embeddings)
780/// - `dropout`: Dropout layer
781/// - `classifier`: linear layer of dimension (*hidden_size*, *num_classes*) to project the output to the target label space
782pub struct ElectraForTokenClassification {
783 electra: ElectraModel,
784 dropout: Dropout,
785 classifier: nn::Linear,
786}
787
788/// Defines the implementation of the ElectraForTokenClassification.
789impl ElectraForTokenClassification {
790 /// Build a new `ElectraForTokenClassification`
791 ///
792 /// # Arguments
793 ///
794 /// * `p` - Variable store path for the root of the Electra model
795 /// * `config` - `ElectraConfig` object defining the model architecture
796 ///
797 /// # Example
798 ///
799 /// ```no_run
800 /// use rust_bert::electra::{ElectraConfig, ElectraForTokenClassification};
801 /// use rust_bert::Config;
802 /// use std::path::Path;
803 /// use tch::{nn, Device};
804 /// let config_path = Path::new("path/to/config.json");
805 /// let device = Device::Cpu;
806 /// let p = nn::VarStore::new(device);
807 /// let config = ElectraConfig::from_file(config_path);
808 /// let electra_model: ElectraForTokenClassification =
809 /// ElectraForTokenClassification::new(&p.root(), &config).unwrap();
810 /// ```
811 pub fn new<'p, P>(
812 p: P,
813 config: &ElectraConfig,
814 ) -> Result<ElectraForTokenClassification, RustBertError>
815 where
816 P: Borrow<nn::Path<'p>>,
817 {
818 let p = p.borrow();
819
820 let electra = ElectraModel::new(p / "electra", config);
821 let dropout = Dropout::new(config.hidden_dropout_prob);
822 let num_labels = config
823 .id2label
824 .as_ref()
825 .ok_or_else(|| {
826 RustBertError::InvalidConfigurationError(
827 "id2label must be provided for classifiers".to_string(),
828 )
829 })?
830 .len() as i64;
831 let classifier = nn::linear(
832 p / "classifier",
833 config.hidden_size,
834 num_labels,
835 Default::default(),
836 );
837
838 Ok(ElectraForTokenClassification {
839 electra,
840 dropout,
841 classifier,
842 })
843 }
844
845 /// Forward pass through the model
846 ///
847 /// # Arguments
848 ///
849 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
850 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
851 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
852 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
853 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
854 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
855 ///
856 /// # Returns
857 ///
858 /// * `ElectraTokenClassificationOutput` containing:
859 /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
860 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
861 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
862 ///
863 /// # Example
864 ///
865 /// ```no_run
866 /// # use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
867 /// # use tch::{nn, Device, Tensor, no_grad};
868 /// # use rust_bert::Config;
869 /// # use std::path::Path;
870 /// # use tch::kind::Kind::Int64;
871 /// # let config_path = Path::new("path/to/config.json");
872 /// # let device = Device::Cpu;
873 /// # let vs = nn::VarStore::new(device);
874 /// # let config = ElectraConfig::from_file(config_path);
875 /// # let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&vs.root(), &config).unwrap();
876 /// let (batch_size, sequence_length) = (64, 128);
877 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
878 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
879 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
880 /// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
881 ///
882 /// let model_output = no_grad(|| {
883 /// electra_model
884 /// .forward_t(Some(&input_tensor),
885 /// Some(&mask),
886 /// Some(&token_type_ids),
887 /// Some(&position_ids),
888 /// None,
889 /// false)
890 /// });
891 /// ```
892 pub fn forward_t(
893 &self,
894 input_ids: Option<&Tensor>,
895 mask: Option<&Tensor>,
896 token_type_ids: Option<&Tensor>,
897 position_ids: Option<&Tensor>,
898 input_embeds: Option<&Tensor>,
899 train: bool,
900 ) -> ElectraTokenClassificationOutput {
901 let base_model_output = self
902 .electra
903 .forward_t(
904 input_ids,
905 mask,
906 token_type_ids,
907 position_ids,
908 input_embeds,
909 train,
910 )
911 .unwrap();
912 let logits = base_model_output
913 .hidden_state
914 .apply_t(&self.dropout, train)
915 .apply(&self.classifier);
916 ElectraTokenClassificationOutput {
917 logits,
918 all_hidden_states: base_model_output.all_hidden_states,
919 all_attentions: base_model_output.all_attentions,
920 }
921 }
922}
923
/// Container for the Electra model output.
pub struct ElectraModelOutput {
    /// Last hidden states from the model
    pub hidden_state: Tensor,
    /// Hidden states for all intermediate layers, when available (`None` otherwise)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when available (`None` otherwise)
    pub all_attentions: Option<Vec<Tensor>>,
}
933
/// Container for the Electra discriminator model output.
pub struct ElectraDiscriminatorOutput {
    /// Probabilities for each sequence item (token) to be generated by a language model
    pub probabilities: Tensor,
    /// Hidden states for all intermediate layers, when available (`None` otherwise)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when available (`None` otherwise)
    pub all_attentions: Option<Vec<Tensor>>,
}
943
/// Container for the Electra masked LM model output.
pub struct ElectraMaskedLMOutput {
    /// Logits for the vocabulary items at each sequence position
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers, when available (`None` otherwise)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when available (`None` otherwise)
    pub all_attentions: Option<Vec<Tensor>>,
}
953
/// Container for the Electra token classification model output.
pub struct ElectraTokenClassificationOutput {
    /// Logits for each sequence item (token) for each target class
    /// (shape: *batch size*, *sequence_length*, *num_labels*)
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, when available (`None` otherwise)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when available (`None` otherwise)
    pub all_attentions: Option<Vec<Tensor>>,
}