//! FNet model implementation (rust_bert/models/fnet/fnet_model.rs).
1// Copyright 2021 Google Research
2// Copyright 2020-present, the HuggingFace Inc. team.
3// Copyright 2021 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::common::activations::{TensorFunction, _tanh};
15use crate::common::dropout::Dropout;
16use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
17use crate::fnet::embeddings::FNetEmbeddings;
18use crate::fnet::encoder::FNetEncoder;
19use crate::{Activation, Config, RustBertError};
20use serde::{Deserialize, Serialize};
21use std::borrow::Borrow;
22use std::collections::HashMap;
23use tch::nn::LayerNormConfig;
24use tch::{nn, Tensor};
25
/// # FNet Pretrained model weight files
/// Registry of (cache identifier, download URL) pairs for pretrained FNet weights.
pub struct FNetModelResources;

/// # FNet Pretrained model config files
/// Registry of (cache identifier, download URL) pairs for FNet model configurations.
pub struct FNetConfigResources;

/// # FNet Pretrained model vocab files
/// Registry of (cache identifier, download URL) pairs for FNet SentencePiece vocabulary files.
pub struct FNetVocabResources;
34
impl FNetModelResources {
    // Each constant is a (local cache identifier, remote URL) pair
    // NOTE(review): presumably consumed by the crate's remote-resource loader — confirm against callers.
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/google-research/tree/master/f_net>. Modified with conversion to C-array format.
    pub const BASE: (&'static str, &'static str) = (
        "fnet-base/model",
        "https://huggingface.co/google/fnet-base/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format.
    pub const BASE_SST2: (&'static str, &'static str) = (
        "fnet-base-sst2/model",
        "https://huggingface.co/gchhablani/fnet-base-finetuned-sst2/resolve/main/rust_model.ot",
    );
}
47
impl FNetConfigResources {
    // Each constant is a (local cache identifier, remote URL) pair for a config.json file.
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/google-research/tree/master/f_net>. Modified with conversion to C-array format.
    pub const BASE: (&'static str, &'static str) = (
        "fnet-base/config",
        "https://huggingface.co/google/fnet-base/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format.
    pub const BASE_SST2: (&'static str, &'static str) = (
        "fnet-base-sst2/config",
        "https://huggingface.co/gchhablani/fnet-base-finetuned-sst2/resolve/main/config.json",
    );
}
60
impl FNetVocabResources {
    // Each constant is a (local cache identifier, remote URL) pair for a SentencePiece model.
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/google-research/tree/master/f_net>. Modified with conversion to C-array format.
    pub const BASE: (&'static str, &'static str) = (
        "fnet-base/spiece",
        "https://huggingface.co/google/fnet-base/resolve/main/spiece.model",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format.
    // The SST2 checkpoint reuses the base model's SentencePiece file (URL points at
    // google/fnet-base) — presumably fine-tuning leaves the vocabulary unchanged.
    pub const BASE_SST2: (&'static str, &'static str) = (
        "fnet-base-sst2/spiece",
        "https://huggingface.co/google/fnet-base/resolve/main/spiece.model",
    );
}
73
#[derive(Debug, Serialize, Deserialize)]
/// # FNet model configuration
/// Defines the FNet model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct FNetConfig {
    /// Size of the token vocabulary (output dimension of the LM prediction head)
    pub vocab_size: i64,
    /// Dimension of the encoder hidden states
    pub hidden_size: i64,
    /// Number of encoder layers
    pub num_hidden_layers: i64,
    /// Dimension of the intermediate (feed-forward) layer
    pub intermediate_size: i64,
    /// Activation function (used e.g. by the LM prediction head transform)
    pub hidden_act: Activation,
    /// Dropout probability applied before the classification heads
    pub hidden_dropout_prob: f64,
    /// Maximum sequence length supported by the position embeddings
    pub max_position_embeddings: i64,
    /// Number of token type (segment) ids
    pub type_vocab_size: i64,
    /// Standard deviation for weight initialization
    /// NOTE(review): not read in this file — presumably consumed at embedding/weight init; confirm.
    pub initializer_range: f64,
    /// Layer normalization epsilon; defaults to 1e-12 when `None`
    pub layer_norm_eps: Option<f64>,
    /// Padding token id
    pub pad_token_id: Option<i64>,
    /// Beginning-of-sequence token id
    pub bos_token_id: Option<i64>,
    /// End-of-sequence token id
    pub eos_token_id: Option<i64>,
    /// Decoder start token id (not read by the blocks in this file)
    pub decoder_start_token_id: Option<i64>,
    /// Mapping from class id to label; required by the classification heads
    pub id2label: Option<HashMap<i64, String>>,
    /// Mapping from label to class id
    pub label2id: Option<HashMap<String, i64>>,
    /// Flag kept for configuration compatibility (not read by the blocks in this file)
    pub output_attentions: Option<bool>,
    /// If `true`, return all intermediate hidden states
    /// NOTE(review): consumed by the encoder — confirm in `FNetEncoder`.
    pub output_hidden_states: Option<bool>,
}
97
// Marker implementation: gives `FNetConfig` the `Config` trait's default behavior
// (e.g. `FNetConfig::from_file`, used throughout the examples in this file).
impl Config for FNetConfig {}
99
100impl Default for FNetConfig {
101 fn default() -> Self {
102 FNetConfig {
103 vocab_size: 32000,
104 hidden_size: 768,
105 num_hidden_layers: 12,
106 intermediate_size: 3072,
107 hidden_act: Activation::gelu_new,
108 hidden_dropout_prob: 0.1,
109 max_position_embeddings: 512,
110 type_vocab_size: 4,
111 initializer_range: 0.02,
112 layer_norm_eps: Some(1e-12),
113 pad_token_id: Some(3),
114 bos_token_id: Some(1),
115 eos_token_id: Some(2),
116 decoder_start_token_id: None,
117 id2label: None,
118 label2id: None,
119 output_attentions: None,
120 output_hidden_states: None,
121 }
122 }
123}
124
/// Pooling head: projects the first-token hidden state through a dense layer
/// followed by a tanh activation (see `FNetPooler::forward`).
struct FNetPooler {
    /// Dense projection (hidden_size -> hidden_size)
    dense: nn::Linear,
    /// Tanh activation applied after the projection
    activation: TensorFunction,
}
129
130impl FNetPooler {
131 pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetPooler
132 where
133 P: Borrow<nn::Path<'p>>,
134 {
135 let dense = nn::linear(
136 p.borrow() / "dense",
137 config.hidden_size,
138 config.hidden_size,
139 Default::default(),
140 );
141 let activation = TensorFunction::new(Box::new(_tanh));
142
143 FNetPooler { dense, activation }
144 }
145
146 pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
147 self.activation.get_fn()(&hidden_states.select(1, 0).apply(&self.dense))
148 }
149}
150
/// Transform applied before the LM decoder: dense projection, configured
/// activation, then layer normalization (see `forward`).
struct FNetPredictionHeadTransform {
    /// Dense projection (hidden_size -> hidden_size)
    dense: nn::Linear,
    /// Activation taken from `FNetConfig::hidden_act`
    activation: TensorFunction,
    /// Layer normalization applied after the activation
    layer_norm: nn::LayerNorm,
}
156
157impl FNetPredictionHeadTransform {
158 pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetPredictionHeadTransform
159 where
160 P: Borrow<nn::Path<'p>>,
161 {
162 let p = p.borrow();
163
164 let dense = nn::linear(
165 p / "dense",
166 config.hidden_size,
167 config.hidden_size,
168 Default::default(),
169 );
170 let activation = config.hidden_act.get_function();
171 let layer_norm_config = LayerNormConfig {
172 eps: config.layer_norm_eps.unwrap_or(1e-12),
173 ..Default::default()
174 };
175 let layer_norm =
176 nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
177
178 FNetPredictionHeadTransform {
179 dense,
180 activation,
181 layer_norm,
182 }
183 }
184
185 pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
186 let hidden_states = hidden_states.apply(&self.dense);
187 let hidden_states: Tensor = self.activation.get_fn()(&hidden_states);
188 hidden_states.apply(&self.layer_norm)
189 }
190}
191
/// Language-model head: prediction head transform followed by a decoder
/// projecting hidden states to vocabulary logits.
struct FNetLMPredictionHead {
    /// Dense/activation/layer-norm transform applied before decoding
    transform: FNetPredictionHeadTransform,
    /// Decoder projection (hidden_size -> vocab_size)
    decoder: nn::Linear,
}
196
197impl FNetLMPredictionHead {
198 pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetLMPredictionHead
199 where
200 P: Borrow<nn::Path<'p>>,
201 {
202 let p = p.borrow();
203
204 let transform = FNetPredictionHeadTransform::new(p / "transform", config);
205 let decoder = nn::linear(
206 p / "decoder",
207 config.hidden_size,
208 config.vocab_size,
209 Default::default(),
210 );
211
212 FNetLMPredictionHead { transform, decoder }
213 }
214
215 pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
216 self.transform.forward(hidden_states).apply(&self.decoder)
217 }
218}
219
/// # FNet Base model
/// Base architecture for FNet models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `embeddings`: FNetEmbeddings combining word, position and segment embeddings
/// - `encoder`: `FNetEncoder` made of a stack of `FNetLayer`
/// - `pooler`: Optional `FNetPooler` taking the first sequence element hidden state for sequence-level tasks
pub struct FNetModel {
    /// Word, position and segment embeddings
    embeddings: FNetEmbeddings,
    /// Stack of encoder layers
    encoder: FNetEncoder,
    /// Present only when the model was built with `add_pooling_layer = true`
    pooler: Option<FNetPooler>,
}
231
232impl FNetModel {
233 /// Build a new `FNetModel`
234 ///
235 /// # Arguments
236 ///
237 /// * `p` - Variable store path for the root of the FNet model
238 /// * `config` - `FNetConfig` object defining the model architecture
239 /// * `add_pooling_layer` - boolean flag indicating if a pooling layer should be added after the encoder
240 ///
241 /// # Example
242 ///
243 /// ```no_run
244 /// use rust_bert::fnet::{FNetConfig, FNetModel};
245 /// use rust_bert::Config;
246 /// use std::path::Path;
247 /// use tch::{nn, Device};
248 ///
249 /// let config_path = Path::new("path/to/config.json");
250 /// let device = Device::Cpu;
251 /// let p = nn::VarStore::new(device);
252 /// let config = FNetConfig::from_file(config_path);
253 /// let add_pooling_layer = true;
254 /// let fnet = FNetModel::new(&p.root() / "fnet", &config, add_pooling_layer);
255 /// ```
256 pub fn new<'p, P>(p: P, config: &FNetConfig, add_pooling_layer: bool) -> FNetModel
257 where
258 P: Borrow<nn::Path<'p>>,
259 {
260 let p = p.borrow();
261
262 let embeddings = FNetEmbeddings::new(p / "embeddings", config);
263 let encoder = FNetEncoder::new(p / "encoder", config);
264 let pooler = if add_pooling_layer {
265 Some(FNetPooler::new(p / "pooler", config))
266 } else {
267 None
268 };
269
270 FNetModel {
271 embeddings,
272 encoder,
273 pooler,
274 }
275 }
276
277 /// Forward pass through the model
278 ///
279 /// # Arguments
280 ///
281 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
282 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
283 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
284 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
285 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
286 ///
287 /// # Returns
288 ///
289 /// * `FNetModelOutput` containing:
290 /// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
291 /// - `pooled_output` - Optional `Tensor` of shape (*batch size*, *hidden_size*) if the model was created with an optional pooling layer
292 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
293 ///
294 /// # Example
295 ///
296 /// ```no_run
297 /// # use tch::{nn, Device, Tensor, no_grad};
298 /// # use rust_bert::Config;
299 /// # use std::path::Path;
300 /// # use tch::kind::Kind::Int64;
301 /// use rust_bert::fnet::{FNetConfig, FNetModel};
302 /// # let config_path = Path::new("path/to/config.json");
303 /// # let device = Device::Cpu;
304 /// # let vs = nn::VarStore::new(device);
305 /// # let config = FNetConfig::from_file(config_path);
306 /// let add_pooling_layer = true;
307 /// let model = FNetModel::new(&vs.root(), &config, add_pooling_layer);
308 /// let (batch_size, sequence_length) = (64, 128);
309 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
310 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
311 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
312 /// .expand(&[batch_size, sequence_length], true);
313 ///
314 /// let model_output = no_grad(|| {
315 /// model
316 /// .forward_t(
317 /// Some(&input_tensor),
318 /// Some(&token_type_ids),
319 /// Some(&position_ids),
320 /// None,
321 /// false,
322 /// )
323 /// .unwrap()
324 /// });
325 /// ```
326 pub fn forward_t(
327 &self,
328 input_ids: Option<&Tensor>,
329 token_type_ids: Option<&Tensor>,
330 position_ids: Option<&Tensor>,
331 input_embeddings: Option<&Tensor>,
332 train: bool,
333 ) -> Result<FNetModelOutput, RustBertError> {
334 let hidden_states = self.embeddings.forward_t(
335 input_ids,
336 token_type_ids,
337 position_ids,
338 input_embeddings,
339 train,
340 )?;
341
342 let encoder_output = self.encoder.forward_t(&hidden_states, train);
343 let pooled_output = if let Some(pooler) = &self.pooler {
344 Some(pooler.forward(&encoder_output.hidden_states))
345 } else {
346 None
347 };
348 Ok(FNetModelOutput {
349 hidden_states: encoder_output.hidden_states,
350 pooled_output,
351 all_hidden_states: encoder_output.all_hidden_states,
352 })
353 }
354}
355
/// # FNet for masked language model
/// Base FNet model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
/// It is made of the following blocks:
/// - `fnet`: Base FNet model
/// - `lm_head`: FNet LM prediction head
pub struct FNetForMaskedLM {
    /// Base model (built without a pooling layer)
    fnet: FNetModel,
    /// LM prediction head mapping hidden states to vocabulary logits
    lm_head: FNetLMPredictionHead,
}
365
366impl FNetForMaskedLM {
367 /// Build a new `FNetForMaskedLM`
368 ///
369 /// # Arguments
370 ///
371 /// * `p` - Variable store path for the root of the FNet model
372 /// * `config` - `FNetConfig` object defining the model architecture
373 ///
374 /// # Example
375 ///
376 /// ```no_run
377 /// use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
378 /// use rust_bert::Config;
379 /// use std::path::Path;
380 /// use tch::{nn, Device};
381 ///
382 /// let config_path = Path::new("path/to/config.json");
383 /// let device = Device::Cpu;
384 /// let p = nn::VarStore::new(device);
385 /// let config = FNetConfig::from_file(config_path);
386 /// let fnet = FNetForMaskedLM::new(&p.root() / "fnet", &config);
387 /// ```
388 pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForMaskedLM
389 where
390 P: Borrow<nn::Path<'p>>,
391 {
392 let p = p.borrow();
393
394 let fnet = FNetModel::new(p / "fnet", config, false);
395 let lm_head = FNetLMPredictionHead::new(p.sub("cls").sub("predictions"), config);
396
397 FNetForMaskedLM { fnet, lm_head }
398 }
399
400 /// Forward pass through the model
401 ///
402 /// # Arguments
403 ///
404 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
405 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
406 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
407 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
408 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
409 ///
410 /// # Returns
411 ///
412 /// * `FNetMaskedLMOutput` containing:
413 /// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
414 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
415 ///
416 /// # Example
417 ///
418 /// ```no_run
419 /// # use tch::{nn, Device, Tensor, no_grad};
420 /// # use rust_bert::Config;
421 /// # use std::path::Path;
422 /// # use tch::kind::Kind::Int64;
423 /// use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
424 /// # let config_path = Path::new("path/to/config.json");
425 /// # let device = Device::Cpu;
426 /// # let vs = nn::VarStore::new(device);
427 /// # let config = FNetConfig::from_file(config_path);
428 /// let model = FNetForMaskedLM::new(&vs.root(), &config);
429 /// let (batch_size, sequence_length) = (64, 128);
430 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
431 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
432 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
433 /// .expand(&[batch_size, sequence_length], true);
434 ///
435 /// let model_output = no_grad(|| {
436 /// model
437 /// .forward_t(
438 /// Some(&input_tensor),
439 /// Some(&token_type_ids),
440 /// Some(&position_ids),
441 /// None,
442 /// false,
443 /// )
444 /// .unwrap()
445 /// });
446 /// ```
447 pub fn forward_t(
448 &self,
449 input_ids: Option<&Tensor>,
450 token_type_ids: Option<&Tensor>,
451 position_ids: Option<&Tensor>,
452 input_embeddings: Option<&Tensor>,
453 train: bool,
454 ) -> Result<FNetMaskedLMOutput, RustBertError> {
455 let model_output = self.fnet.forward_t(
456 input_ids,
457 token_type_ids,
458 position_ids,
459 input_embeddings,
460 train,
461 )?;
462
463 let prediction_scores = self.lm_head.forward(&model_output.hidden_states);
464
465 Ok(FNetMaskedLMOutput {
466 prediction_scores,
467 all_hidden_states: model_output.all_hidden_states,
468 })
469 }
470}
471
/// # FNet for sequence classification
/// Base FNet model with a classifier head to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `fnet`: Base FNet model
/// - `dropout`: Dropout layer before the last linear layer
/// - `classifier`: linear layer mapping from hidden to the number of classes to predict
pub struct FNetForSequenceClassification {
    /// Base model (built with a pooling layer)
    fnet: FNetModel,
    /// Dropout applied to the pooled output before classification
    dropout: Dropout,
    /// Linear classifier (hidden_size -> num_labels)
    classifier: nn::Linear,
}
483
484impl FNetForSequenceClassification {
485 /// Build a new `FNetForSequenceClassification`
486 ///
487 /// # Arguments
488 ///
489 /// * `p` - Variable store path for the root of the FNet model
490 /// * `config` - `FNetConfig` object defining the model architecture
491 ///
492 /// # Example
493 ///
494 /// ```no_run
495 /// use rust_bert::fnet::{FNetConfig, FNetForSequenceClassification};
496 /// use rust_bert::Config;
497 /// use std::path::Path;
498 /// use tch::{nn, Device};
499 ///
500 /// let config_path = Path::new("path/to/config.json");
501 /// let device = Device::Cpu;
502 /// let p = nn::VarStore::new(device);
503 /// let config = FNetConfig::from_file(config_path);
504 /// let fnet = FNetForSequenceClassification::new(&p.root() / "fnet", &config).unwrap();
505 /// ```
506 pub fn new<'p, P>(
507 p: P,
508 config: &FNetConfig,
509 ) -> Result<FNetForSequenceClassification, RustBertError>
510 where
511 P: Borrow<nn::Path<'p>>,
512 {
513 let p = p.borrow();
514
515 let fnet = FNetModel::new(p / "fnet", config, true);
516 let dropout = Dropout::new(config.hidden_dropout_prob);
517 let num_labels = config
518 .id2label
519 .as_ref()
520 .ok_or_else(|| {
521 RustBertError::InvalidConfigurationError(
522 "num_labels not provided in configuration".to_string(),
523 )
524 })?
525 .len() as i64;
526 let classifier = nn::linear(
527 p / "classifier",
528 config.hidden_size,
529 num_labels,
530 Default::default(),
531 );
532
533 Ok(FNetForSequenceClassification {
534 fnet,
535 dropout,
536 classifier,
537 })
538 }
539
540 /// Forward pass through the model
541 ///
542 /// # Arguments
543 ///
544 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
545 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
546 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
547 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
548 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
549 ///
550 /// # Returns
551 ///
552 /// * `FNetSequenceClassificationOutput` containing:
553 /// - `logits` - `Tensor` of shape (*batch size*, *num_classes*)
554 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
555 ///
556 /// # Example
557 ///
558 /// ```no_run
559 /// # use tch::{nn, Device, Tensor, no_grad};
560 /// # use rust_bert::Config;
561 /// # use std::path::Path;
562 /// # use tch::kind::Kind::Int64;
563 /// use rust_bert::fnet::{FNetConfig, FNetForSequenceClassification};
564 /// # let config_path = Path::new("path/to/config.json");
565 /// # let device = Device::Cpu;
566 /// # let vs = nn::VarStore::new(device);
567 /// # let config = FNetConfig::from_file(config_path);
568 /// let model = FNetForSequenceClassification::new(&vs.root(), &config).unwrap();
569 /// let (batch_size, sequence_length) = (64, 128);
570 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
571 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
572 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
573 /// .expand(&[batch_size, sequence_length], true);
574 ///
575 /// let model_output = no_grad(|| {
576 /// model
577 /// .forward_t(
578 /// Some(&input_tensor),
579 /// Some(&token_type_ids),
580 /// Some(&position_ids),
581 /// None,
582 /// false,
583 /// )
584 /// .unwrap()
585 /// });
586 /// ```
587 pub fn forward_t(
588 &self,
589 input_ids: Option<&Tensor>,
590 token_type_ids: Option<&Tensor>,
591 position_ids: Option<&Tensor>,
592 input_embeddings: Option<&Tensor>,
593 train: bool,
594 ) -> Result<FNetSequenceClassificationOutput, RustBertError> {
595 let base_model_output = self.fnet.forward_t(
596 input_ids,
597 token_type_ids,
598 position_ids,
599 input_embeddings,
600 train,
601 )?;
602
603 let logits = base_model_output
604 .pooled_output
605 .unwrap()
606 .apply_t(&self.dropout, train)
607 .apply(&self.classifier);
608
609 Ok(FNetSequenceClassificationOutput {
610 logits,
611 all_hidden_states: base_model_output.all_hidden_states,
612 })
613 }
614}
615
/// # FNet for multiple choices
/// Multiple choices model using a FNet base model and a linear classifier.
/// Input should be in the form `[CLS] Context [SEP] Possible choice [SEP]`. The choice is made along the batch axis,
/// assuming all elements of the batch are alternatives to be chosen from for a given context.
/// It is made of the following blocks:
/// - `fnet`: Base FNet model
/// - `dropout`: Dropout layer before the last start/end logits prediction
/// - `classifier`: Linear layer for multiple choices
pub struct FNetForMultipleChoice {
    /// Base model (built with a pooling layer)
    fnet: FNetModel,
    /// Dropout applied to the pooled output before scoring
    dropout: Dropout,
    /// Linear layer producing one score per choice (hidden_size -> 1)
    classifier: nn::Linear,
}
629
630impl FNetForMultipleChoice {
631 /// Build a new `FNetForMultipleChoice`
632 ///
633 /// # Arguments
634 ///
635 /// * `p` - Variable store path for the root of the FNet model
636 /// * `config` - `FNetConfig` object defining the model architecture
637 ///
638 /// # Example
639 ///
640 /// ```no_run
641 /// use rust_bert::fnet::{FNetConfig, FNetForMultipleChoice};
642 /// use rust_bert::Config;
643 /// use std::path::Path;
644 /// use tch::{nn, Device};
645 ///
646 /// let config_path = Path::new("path/to/config.json");
647 /// let device = Device::Cpu;
648 /// let p = nn::VarStore::new(device);
649 /// let config = FNetConfig::from_file(config_path);
650 /// let fnet = FNetForMultipleChoice::new(&p.root() / "fnet", &config);
651 /// ```
652 pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForMultipleChoice
653 where
654 P: Borrow<nn::Path<'p>>,
655 {
656 let p = p.borrow();
657
658 let fnet = FNetModel::new(p / "fnet", config, true);
659 let dropout = Dropout::new(config.hidden_dropout_prob);
660 let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
661
662 FNetForMultipleChoice {
663 fnet,
664 dropout,
665 classifier,
666 }
667 }
668
669 /// Forward pass through the model
670 ///
671 /// # Arguments
672 ///
673 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
674 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
675 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
676 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
677 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
678 ///
679 /// # Returns
680 ///
681 /// * `FNetSequenceClassificationOutput` containing:
682 /// - `logits` - `Tensor` of shape (*1*, *batch_size*) containing the logits for each of the alternatives given
683 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
684 ///
685 /// # Example
686 ///
687 /// ```no_run
688 /// # use tch::{nn, Device, Tensor, no_grad};
689 /// # use rust_bert::Config;
690 /// # use std::path::Path;
691 /// # use tch::kind::Kind::Int64;
692 /// use rust_bert::fnet::{FNetConfig, FNetForMultipleChoice};
693 /// # let config_path = Path::new("path/to/config.json");
694 /// # let device = Device::Cpu;
695 /// # let vs = nn::VarStore::new(device);
696 /// # let config = FNetConfig::from_file(config_path);
697 /// let model = FNetForMultipleChoice::new(&vs.root(), &config);
698 /// let (batch_size, sequence_length) = (64, 128);
699 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
700 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
701 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
702 /// .expand(&[batch_size, sequence_length], true);
703 ///
704 /// let model_output = no_grad(|| {
705 /// model
706 /// .forward_t(
707 /// Some(&input_tensor),
708 /// Some(&token_type_ids),
709 /// Some(&position_ids),
710 /// None,
711 /// false,
712 /// )
713 /// .unwrap()
714 /// });
715 /// ```
716 pub fn forward_t(
717 &self,
718 input_ids: Option<&Tensor>,
719 token_type_ids: Option<&Tensor>,
720 position_ids: Option<&Tensor>,
721 input_embeddings: Option<&Tensor>,
722 train: bool,
723 ) -> Result<FNetSequenceClassificationOutput, RustBertError> {
724 let (input_shape, _) =
725 get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeddings)?;
726 let num_choices = input_shape[1];
727
728 let input_ids = input_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
729 let token_type_ids =
730 token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
731 let position_ids =
732 position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
733 let input_embeddings =
734 input_embeddings.map(|tensor| tensor.view((-1, tensor.size()[2], tensor.size()[3])));
735
736 let base_model_output = self.fnet.forward_t(
737 input_ids.as_ref(),
738 token_type_ids.as_ref(),
739 position_ids.as_ref(),
740 input_embeddings.as_ref(),
741 train,
742 )?;
743
744 let logits = base_model_output
745 .pooled_output
746 .unwrap()
747 .apply_t(&self.dropout, train)
748 .apply(&self.classifier)
749 .view((-1, num_choices));
750
751 Ok(FNetSequenceClassificationOutput {
752 logits,
753 all_hidden_states: base_model_output.all_hidden_states,
754 })
755 }
756}
757
/// # FNet for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `fnet`: Base FNet
/// - `dropout`: Dropout layer before the last token-level predictions layer
/// - `classifier`: Linear layer for token classification
pub struct FNetForTokenClassification {
    /// Base model (built without a pooling layer)
    fnet: FNetModel,
    /// Dropout applied to the token hidden states before classification
    dropout: Dropout,
    /// Linear classifier applied per token (hidden_size -> num_labels)
    classifier: nn::Linear,
}
770
771impl FNetForTokenClassification {
772 /// Build a new `FNetForTokenClassification`
773 ///
774 /// # Arguments
775 ///
776 /// * `p` - Variable store path for the root of the FNet model
777 /// * `config` - `FNetConfig` object defining the model architecture
778 ///
779 /// # Example
780 ///
781 /// ```no_run
782 /// use rust_bert::fnet::{FNetConfig, FNetForTokenClassification};
783 /// use rust_bert::Config;
784 /// use std::path::Path;
785 /// use tch::{nn, Device};
786 ///
787 /// let config_path = Path::new("path/to/config.json");
788 /// let device = Device::Cpu;
789 /// let p = nn::VarStore::new(device);
790 /// let config = FNetConfig::from_file(config_path);
791 /// let fnet = FNetForTokenClassification::new(&p.root() / "fnet", &config).unwrap();
792 /// ```
793 pub fn new<'p, P>(
794 p: P,
795 config: &FNetConfig,
796 ) -> Result<FNetForTokenClassification, RustBertError>
797 where
798 P: Borrow<nn::Path<'p>>,
799 {
800 let p = p.borrow();
801
802 let fnet = FNetModel::new(p / "fnet", config, false);
803 let dropout = Dropout::new(config.hidden_dropout_prob);
804 let num_labels = config
805 .id2label
806 .as_ref()
807 .ok_or_else(|| {
808 RustBertError::InvalidConfigurationError(
809 "num_labels not provided in configuration".to_string(),
810 )
811 })?
812 .len() as i64;
813 let classifier = nn::linear(
814 p / "classifier",
815 config.hidden_size,
816 num_labels,
817 Default::default(),
818 );
819
820 Ok(FNetForTokenClassification {
821 fnet,
822 dropout,
823 classifier,
824 })
825 }
826
827 /// Forward pass through the model
828 ///
829 /// # Arguments
830 ///
831 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
832 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
833 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
834 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
835 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
836 ///
837 /// # Returns
838 ///
839 /// * `FNetTokenClassificationOutput` containing:
840 /// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
841 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
842 ///
843 /// # Example
844 ///
845 /// ```no_run
846 /// # use tch::{nn, Device, Tensor, no_grad};
847 /// # use rust_bert::Config;
848 /// # use std::path::Path;
849 /// # use tch::kind::Kind::Int64;
850 /// use rust_bert::fnet::{FNetConfig, FNetForTokenClassification};
851 /// # let config_path = Path::new("path/to/config.json");
852 /// # let device = Device::Cpu;
853 /// # let vs = nn::VarStore::new(device);
854 /// # let config = FNetConfig::from_file(config_path);
855 /// let model = FNetForTokenClassification::new(&vs.root(), &config).unwrap();
856 /// let (batch_size, sequence_length) = (64, 128);
857 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
858 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
859 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
860 /// .expand(&[batch_size, sequence_length], true);
861 ///
862 /// let model_output = no_grad(|| {
863 /// model
864 /// .forward_t(
865 /// Some(&input_tensor),
866 /// Some(&token_type_ids),
867 /// Some(&position_ids),
868 /// None,
869 /// false,
870 /// )
871 /// .unwrap()
872 /// });
873 /// ```
874 pub fn forward_t(
875 &self,
876 input_ids: Option<&Tensor>,
877 token_type_ids: Option<&Tensor>,
878 position_ids: Option<&Tensor>,
879 input_embeddings: Option<&Tensor>,
880 train: bool,
881 ) -> Result<FNetTokenClassificationOutput, RustBertError> {
882 let base_model_output = self.fnet.forward_t(
883 input_ids,
884 token_type_ids,
885 position_ids,
886 input_embeddings,
887 train,
888 )?;
889
890 let logits = base_model_output
891 .hidden_states
892 .apply_t(&self.dropout, train)
893 .apply(&self.classifier);
894
895 Ok(FNetTokenClassificationOutput {
896 logits,
897 all_hidden_states: base_model_output.all_hidden_states,
898 })
899 }
900}
901
/// # FNet for question answering
/// Extractive question-answering model based on a FNet language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `fnet`: Base FNet
/// - `qa_outputs`: Linear layer for question answering
pub struct FNetForQuestionAnswering {
    // Base FNet model producing per-token hidden states.
    fnet: FNetModel,
    // Linear head mapping each token's hidden state to two logits
    // (answer-span start and end).
    qa_outputs: nn::Linear,
}
913
914impl FNetForQuestionAnswering {
915 /// Build a new `FNetForQuestionAnswering`
916 ///
917 /// # Arguments
918 ///
919 /// * `p` - Variable store path for the root of the FNet model
920 /// * `config` - `FNetConfig` object defining the model architecture
921 ///
922 /// # Example
923 ///
924 /// ```no_run
925 /// use rust_bert::fnet::{FNetConfig, FNetForQuestionAnswering};
926 /// use rust_bert::Config;
927 /// use std::path::Path;
928 /// use tch::{nn, Device};
929 ///
930 /// let config_path = Path::new("path/to/config.json");
931 /// let device = Device::Cpu;
932 /// let p = nn::VarStore::new(device);
933 /// let config = FNetConfig::from_file(config_path);
934 /// let fnet = FNetForQuestionAnswering::new(&p.root() / "fnet", &config);
935 /// ```
936 pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForQuestionAnswering
937 where
938 P: Borrow<nn::Path<'p>>,
939 {
940 let p = p.borrow();
941
942 let fnet = FNetModel::new(p / "fnet", config, false);
943 let qa_outputs = nn::linear(p / "classifier", config.hidden_size, 2, Default::default());
944
945 FNetForQuestionAnswering { fnet, qa_outputs }
946 }
947
948 /// Forward pass through the model
949 ///
950 /// # Arguments
951 ///
952 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
953 /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
954 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
955 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
956 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
957 ///
958 /// # Returns
959 ///
960 /// * `FNetQuestionAnsweringOutput` containing:
961 /// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
962 /// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
963 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
964 ///
965 /// # Example
966 ///
967 /// ```no_run
968 /// # use tch::{nn, Device, Tensor, no_grad};
969 /// # use rust_bert::Config;
970 /// # use std::path::Path;
971 /// # use tch::kind::Kind::Int64;
972 /// use rust_bert::fnet::{FNetConfig, FNetForQuestionAnswering};
973 /// # let config_path = Path::new("path/to/config.json");
974 /// # let device = Device::Cpu;
975 /// # let vs = nn::VarStore::new(device);
976 /// # let config = FNetConfig::from_file(config_path);
977 /// let model = FNetForQuestionAnswering::new(&vs.root(), &config);
978 /// let (batch_size, sequence_length) = (64, 128);
979 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
980 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
981 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
982 /// .expand(&[batch_size, sequence_length], true);
983 ///
984 /// let model_output = no_grad(|| {
985 /// model
986 /// .forward_t(
987 /// Some(&input_tensor),
988 /// Some(&token_type_ids),
989 /// Some(&position_ids),
990 /// None,
991 /// false,
992 /// )
993 /// .unwrap()
994 /// });
995 /// ```
996 pub fn forward_t(
997 &self,
998 input_ids: Option<&Tensor>,
999 token_type_ids: Option<&Tensor>,
1000 position_ids: Option<&Tensor>,
1001 input_embeddings: Option<&Tensor>,
1002 train: bool,
1003 ) -> Result<FNetQuestionAnsweringOutput, RustBertError> {
1004 let base_model_output = self.fnet.forward_t(
1005 input_ids,
1006 token_type_ids,
1007 position_ids,
1008 input_embeddings,
1009 train,
1010 )?;
1011
1012 let logits = base_model_output
1013 .hidden_states
1014 .apply(&self.qa_outputs)
1015 .split(1, -1);
1016 let (start_logits, end_logits) = (&logits[0], &logits[1]);
1017 let start_logits = start_logits.squeeze_dim(-1);
1018 let end_logits = end_logits.squeeze_dim(-1);
1019
1020 Ok(FNetQuestionAnsweringOutput {
1021 start_logits,
1022 end_logits,
1023 all_hidden_states: base_model_output.all_hidden_states,
1024 })
1025 }
1026}
1027
/// Container for the base FNet model output.
pub struct FNetModelOutput {
    /// Last hidden states from the model, shape (*batch size*, *sequence_length*, *hidden_size*)
    pub hidden_states: Tensor,
    /// Pooled output (hidden state for the first token). `None` when the model
    /// was built without a pooling layer.
    pub pooled_output: Option<Tensor>,
    /// Hidden states for all intermediate layers, populated only when the model
    /// is configured to output them.
    pub all_hidden_states: Option<Vec<Tensor>>,
}
1037
/// Container for the FNet masked LM model output.
pub struct FNetMaskedLMOutput {
    /// Logits for the vocabulary items at each sequence position,
    /// shape (*batch size*, *sequence_length*, *vocab_size*)
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers, populated only when the model
    /// is configured to output them.
    pub all_hidden_states: Option<Vec<Tensor>>,
}
1045
/// Container for the FNet sequence classification model output.
pub struct FNetSequenceClassificationOutput {
    /// Logits for each input (sequence) for each target class,
    /// shape (*batch size*, *num_labels*)
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, populated only when the model
    /// is configured to output them.
    pub all_hidden_states: Option<Vec<Tensor>>,
}

/// Container for the FNet token classification model output.
/// Shares the layout of the sequence classification output; here `logits`
/// has shape (*batch size*, *sequence_length*, *num_labels*).
pub type FNetTokenClassificationOutput = FNetSequenceClassificationOutput;
1056
/// Container for the FNet question answering model output.
pub struct FNetQuestionAnsweringOutput {
    /// Logits for the start position for token of each input sequence,
    /// shape (*batch size*, *sequence_length*)
    pub start_logits: Tensor,
    /// Logits for the end position for token of each input sequence,
    /// shape (*batch size*, *sequence_length*)
    pub end_logits: Tensor,
    /// Hidden states for all intermediate layers, populated only when the model
    /// is configured to output them.
    pub all_hidden_states: Option<Vec<Tensor>>,
}
1066
#[cfg(test)]
mod test {
    use tch::Device;

    use crate::{
        resources::{RemoteResource, ResourceProvider},
        Config,
    };

    use super::*;

    #[test]
    #[ignore] // compilation is enough, no need to run
    fn fnet_model_send() {
        let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
        // State the failure reason instead of an empty `expect("")` message,
        // so a panic here points at the resource download/caching step.
        let config_path = config_resource
            .get_local_path()
            .expect("Failed to resolve the local path of the FNet configuration resource");

        // Set-up masked LM model
        let device = Device::cuda_if_available();
        let vs = nn::VarStore::new(device);
        let config = FNetConfig::from_file(config_path);

        // Compile-time check that FNetModel is `Send` (can move across threads).
        let _: Box<dyn Send> = Box::new(FNetModel::new(vs.root(), &config, true));
    }
}