// rust_bert/models/fnet/fnet_model.rs

1// Copyright 2021 Google Research
2// Copyright 2020-present, the HuggingFace Inc. team.
3// Copyright 2021 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::common::activations::{TensorFunction, _tanh};
15use crate::common::dropout::Dropout;
16use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
17use crate::fnet::embeddings::FNetEmbeddings;
18use crate::fnet::encoder::FNetEncoder;
19use crate::{Activation, Config, RustBertError};
20use serde::{Deserialize, Serialize};
21use std::borrow::Borrow;
22use std::collections::HashMap;
23use tch::nn::LayerNormConfig;
24use tch::{nn, Tensor};
25
/// # FNet Pretrained model weight files
/// Namespace for `(cache alias, download URL)` pairs pointing at converted model weights.
pub struct FNetModelResources;
28
/// # FNet Pretrained model config files
/// Namespace for `(cache alias, download URL)` pairs pointing at model configuration files.
pub struct FNetConfigResources;
31
/// # FNet Pretrained model vocab files
/// Namespace for `(cache alias, download URL)` pairs pointing at SentencePiece vocabulary files.
pub struct FNetVocabResources;
34
impl FNetModelResources {
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/google-research/tree/master/f_net>. Modified with conversion to C-array format.
    /// Tuple of (local cache subdirectory, remote URL) for the base FNet weights.
    pub const BASE: (&'static str, &'static str) = (
        "fnet-base/model",
        "https://huggingface.co/google/fnet-base/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format.
    /// Tuple of (local cache subdirectory, remote URL) for FNet fine-tuned on SST-2 sentiment classification.
    pub const BASE_SST2: (&'static str, &'static str) = (
        "fnet-base-sst2/model",
        "https://huggingface.co/gchhablani/fnet-base-finetuned-sst2/resolve/main/rust_model.ot",
    );
}
47
impl FNetConfigResources {
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/google-research/tree/master/f_net>. Modified with conversion to C-array format.
    /// Tuple of (local cache subdirectory, remote URL) for the base FNet configuration.
    pub const BASE: (&'static str, &'static str) = (
        "fnet-base/config",
        "https://huggingface.co/google/fnet-base/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format.
    /// Tuple of (local cache subdirectory, remote URL) for the SST-2 fine-tuned configuration.
    pub const BASE_SST2: (&'static str, &'static str) = (
        "fnet-base-sst2/config",
        "https://huggingface.co/gchhablani/fnet-base-finetuned-sst2/resolve/main/config.json",
    );
}
60
impl FNetVocabResources {
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/google-research/tree/master/f_net>. Modified with conversion to C-array format.
    /// Tuple of (local cache subdirectory, remote URL) for the base FNet SentencePiece model.
    pub const BASE: (&'static str, &'static str) = (
        "fnet-base/spiece",
        "https://huggingface.co/google/fnet-base/resolve/main/spiece.model",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/gchhablani/fnet-base-finetuned-sst2>. Modified with conversion to C-array format.
    /// NOTE(review): the URL points at the *base* model repository rather than the SST-2 one —
    /// presumably the fine-tuned checkpoint reuses the base vocabulary; confirm this is intentional.
    pub const BASE_SST2: (&'static str, &'static str) = (
        "fnet-base-sst2/spiece",
        "https://huggingface.co/google/fnet-base/resolve/main/spiece.model",
    );
}
73
#[derive(Debug, Serialize, Deserialize)]
/// # FNet model configuration
/// Defines the FNet model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct FNetConfig {
    /// Size of the token vocabulary (output dimension of the LM head decoder)
    pub vocab_size: i64,
    /// Dimension of the hidden representations
    pub hidden_size: i64,
    /// Number of stacked encoder layers
    pub num_hidden_layers: i64,
    /// Inner dimension of the feed-forward blocks
    pub intermediate_size: i64,
    /// Activation function used in the prediction head transform (and, presumably, the encoder feed-forward blocks)
    pub hidden_act: Activation,
    /// Dropout probability applied before the classification heads
    pub hidden_dropout_prob: f64,
    /// Maximum sequence length supported by the position embeddings
    pub max_position_embeddings: i64,
    /// Number of token type (segment) embeddings
    pub type_vocab_size: i64,
    /// Weight initialization standard deviation (declared for parity with the Python config; not read in this file)
    pub initializer_range: f64,
    /// Epsilon for layer normalization; defaults to 1e-12 when absent
    pub layer_norm_eps: Option<f64>,
    /// Padding token id
    pub pad_token_id: Option<i64>,
    /// Beginning-of-sequence token id
    pub bos_token_id: Option<i64>,
    /// End-of-sequence token id
    pub eos_token_id: Option<i64>,
    /// Decoder start token id (unused by the encoder-only models in this file)
    pub decoder_start_token_id: Option<i64>,
    /// Label id to label name mapping; required by the classification heads to size their output layer
    pub id2label: Option<HashMap<i64, String>>,
    /// Label name to label id mapping
    pub label2id: Option<HashMap<String, i64>>,
    /// Flag kept for configuration-file compatibility (FNet has no attention to output)
    pub output_attentions: Option<bool>,
    /// If true, intermediate hidden states are returned by the encoder
    pub output_hidden_states: Option<bool>,
}
97
// Marker implementation enabling configuration loading (e.g. `FNetConfig::from_file`).
impl Config for FNetConfig {}
99
100impl Default for FNetConfig {
101    fn default() -> Self {
102        FNetConfig {
103            vocab_size: 32000,
104            hidden_size: 768,
105            num_hidden_layers: 12,
106            intermediate_size: 3072,
107            hidden_act: Activation::gelu_new,
108            hidden_dropout_prob: 0.1,
109            max_position_embeddings: 512,
110            type_vocab_size: 4,
111            initializer_range: 0.02,
112            layer_norm_eps: Some(1e-12),
113            pad_token_id: Some(3),
114            bos_token_id: Some(1),
115            eos_token_id: Some(2),
116            decoder_start_token_id: None,
117            id2label: None,
118            label2id: None,
119            output_attentions: None,
120            output_hidden_states: None,
121        }
122    }
123}
124
/// Pooler producing a sequence-level representation from the first token's hidden state.
struct FNetPooler {
    /// Linear projection (hidden_size -> hidden_size)
    dense: nn::Linear,
    /// tanh activation applied after the projection
    activation: TensorFunction,
}
129
130impl FNetPooler {
131    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetPooler
132    where
133        P: Borrow<nn::Path<'p>>,
134    {
135        let dense = nn::linear(
136            p.borrow() / "dense",
137            config.hidden_size,
138            config.hidden_size,
139            Default::default(),
140        );
141        let activation = TensorFunction::new(Box::new(_tanh));
142
143        FNetPooler { dense, activation }
144    }
145
146    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
147        self.activation.get_fn()(&hidden_states.select(1, 0).apply(&self.dense))
148    }
149}
150
/// Transformation applied to hidden states before the LM decoder projection.
struct FNetPredictionHeadTransform {
    /// Linear projection (hidden_size -> hidden_size)
    dense: nn::Linear,
    /// Activation taken from `config.hidden_act`
    activation: TensorFunction,
    /// Layer normalization applied after the activation
    layer_norm: nn::LayerNorm,
}
156
157impl FNetPredictionHeadTransform {
158    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetPredictionHeadTransform
159    where
160        P: Borrow<nn::Path<'p>>,
161    {
162        let p = p.borrow();
163
164        let dense = nn::linear(
165            p / "dense",
166            config.hidden_size,
167            config.hidden_size,
168            Default::default(),
169        );
170        let activation = config.hidden_act.get_function();
171        let layer_norm_config = LayerNormConfig {
172            eps: config.layer_norm_eps.unwrap_or(1e-12),
173            ..Default::default()
174        };
175        let layer_norm =
176            nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
177
178        FNetPredictionHeadTransform {
179            dense,
180            activation,
181            layer_norm,
182        }
183    }
184
185    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
186        let hidden_states = hidden_states.apply(&self.dense);
187        let hidden_states: Tensor = self.activation.get_fn()(&hidden_states);
188        hidden_states.apply(&self.layer_norm)
189    }
190}
191
/// Language model prediction head mapping hidden states to vocabulary logits.
struct FNetLMPredictionHead {
    /// Pre-decoder transform (dense + activation + layer norm)
    transform: FNetPredictionHeadTransform,
    /// Final projection (hidden_size -> vocab_size)
    decoder: nn::Linear,
}
196
197impl FNetLMPredictionHead {
198    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetLMPredictionHead
199    where
200        P: Borrow<nn::Path<'p>>,
201    {
202        let p = p.borrow();
203
204        let transform = FNetPredictionHeadTransform::new(p / "transform", config);
205        let decoder = nn::linear(
206            p / "decoder",
207            config.hidden_size,
208            config.vocab_size,
209            Default::default(),
210        );
211
212        FNetLMPredictionHead { transform, decoder }
213    }
214
215    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
216        self.transform.forward(hidden_states).apply(&self.decoder)
217    }
218}
219
/// # FNet Base model
/// Base architecture for FNet models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `embeddings`: FNetEmbeddings combining word, position and segment embeddings
/// - `encoder`: `FNetEncoder` made of a stack of `FNetLayer`
/// - `pooler`: Optional `FNetPooler` taking the first sequence element hidden state for sequence-level tasks
pub struct FNetModel {
    /// Word, position and token-type embeddings
    embeddings: FNetEmbeddings,
    /// Stack of FNet encoder layers
    encoder: FNetEncoder,
    /// Present only when the model was built with `add_pooling_layer = true`
    pooler: Option<FNetPooler>,
}
231
232impl FNetModel {
233    /// Build a new `FNetModel`
234    ///
235    /// # Arguments
236    ///
237    /// * `p` - Variable store path for the root of the FNet model
238    /// * `config` - `FNetConfig` object defining the model architecture
239    /// * `add_pooling_layer` - boolean flag indicating if a pooling layer should be added after the encoder
240    ///
241    /// # Example
242    ///
243    /// ```no_run
244    /// use rust_bert::fnet::{FNetConfig, FNetModel};
245    /// use rust_bert::Config;
246    /// use std::path::Path;
247    /// use tch::{nn, Device};
248    ///
249    /// let config_path = Path::new("path/to/config.json");
250    /// let device = Device::Cpu;
251    /// let p = nn::VarStore::new(device);
252    /// let config = FNetConfig::from_file(config_path);
253    /// let add_pooling_layer = true;
254    /// let fnet = FNetModel::new(&p.root() / "fnet", &config, add_pooling_layer);
255    /// ```
256    pub fn new<'p, P>(p: P, config: &FNetConfig, add_pooling_layer: bool) -> FNetModel
257    where
258        P: Borrow<nn::Path<'p>>,
259    {
260        let p = p.borrow();
261
262        let embeddings = FNetEmbeddings::new(p / "embeddings", config);
263        let encoder = FNetEncoder::new(p / "encoder", config);
264        let pooler = if add_pooling_layer {
265            Some(FNetPooler::new(p / "pooler", config))
266        } else {
267            None
268        };
269
270        FNetModel {
271            embeddings,
272            encoder,
273            pooler,
274        }
275    }
276
277    /// Forward pass through the model
278    ///
279    /// # Arguments
280    ///
281    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
282    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
283    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
284    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
285    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
286    ///
287    /// # Returns
288    ///
289    /// * `FNetModelOutput` containing:
290    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
291    ///   - `pooled_output` - Optional `Tensor` of shape (*batch size*, *hidden_size*) if the model was created with an optional pooling layer
292    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
293    ///
294    /// # Example
295    ///
296    /// ```no_run
297    /// # use tch::{nn, Device, Tensor, no_grad};
298    /// # use rust_bert::Config;
299    /// # use std::path::Path;
300    /// # use tch::kind::Kind::Int64;
301    /// use rust_bert::fnet::{FNetConfig, FNetModel};
302    /// # let config_path = Path::new("path/to/config.json");
303    /// # let device = Device::Cpu;
304    /// # let vs = nn::VarStore::new(device);
305    /// # let config = FNetConfig::from_file(config_path);
306    /// let add_pooling_layer = true;
307    /// let model = FNetModel::new(&vs.root(), &config, add_pooling_layer);
308    /// let (batch_size, sequence_length) = (64, 128);
309    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
310    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
311    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
312    ///     .expand(&[batch_size, sequence_length], true);
313    ///
314    /// let model_output = no_grad(|| {
315    ///     model
316    ///         .forward_t(
317    ///             Some(&input_tensor),
318    ///             Some(&token_type_ids),
319    ///             Some(&position_ids),
320    ///             None,
321    ///             false,
322    ///         )
323    ///         .unwrap()
324    /// });
325    /// ```
326    pub fn forward_t(
327        &self,
328        input_ids: Option<&Tensor>,
329        token_type_ids: Option<&Tensor>,
330        position_ids: Option<&Tensor>,
331        input_embeddings: Option<&Tensor>,
332        train: bool,
333    ) -> Result<FNetModelOutput, RustBertError> {
334        let hidden_states = self.embeddings.forward_t(
335            input_ids,
336            token_type_ids,
337            position_ids,
338            input_embeddings,
339            train,
340        )?;
341
342        let encoder_output = self.encoder.forward_t(&hidden_states, train);
343        let pooled_output = if let Some(pooler) = &self.pooler {
344            Some(pooler.forward(&encoder_output.hidden_states))
345        } else {
346            None
347        };
348        Ok(FNetModelOutput {
349            hidden_states: encoder_output.hidden_states,
350            pooled_output,
351            all_hidden_states: encoder_output.all_hidden_states,
352        })
353    }
354}
355
/// # FNet for masked language model
/// Base FNet model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
/// It is made of the following blocks:
/// - `fnet`: Base FNet model
/// - `lm_head`: FNet LM prediction head
pub struct FNetForMaskedLM {
    /// Base model, built without a pooler (token-level task)
    fnet: FNetModel,
    /// Prediction head mapping hidden states to vocabulary logits
    lm_head: FNetLMPredictionHead,
}
365
366impl FNetForMaskedLM {
367    /// Build a new `FNetForMaskedLM`
368    ///
369    /// # Arguments
370    ///
371    /// * `p` - Variable store path for the root of the FNet model
372    /// * `config` - `FNetConfig` object defining the model architecture
373    ///
374    /// # Example
375    ///
376    /// ```no_run
377    /// use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
378    /// use rust_bert::Config;
379    /// use std::path::Path;
380    /// use tch::{nn, Device};
381    ///
382    /// let config_path = Path::new("path/to/config.json");
383    /// let device = Device::Cpu;
384    /// let p = nn::VarStore::new(device);
385    /// let config = FNetConfig::from_file(config_path);
386    /// let fnet = FNetForMaskedLM::new(&p.root() / "fnet", &config);
387    /// ```
388    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForMaskedLM
389    where
390        P: Borrow<nn::Path<'p>>,
391    {
392        let p = p.borrow();
393
394        let fnet = FNetModel::new(p / "fnet", config, false);
395        let lm_head = FNetLMPredictionHead::new(p.sub("cls").sub("predictions"), config);
396
397        FNetForMaskedLM { fnet, lm_head }
398    }
399
400    /// Forward pass through the model
401    ///
402    /// # Arguments
403    ///
404    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
405    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
406    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
407    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
408    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
409    ///
410    /// # Returns
411    ///
412    /// * `FNetMaskedLMOutput` containing:
413    ///   - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
414    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
415    ///
416    /// # Example
417    ///
418    /// ```no_run
419    /// # use tch::{nn, Device, Tensor, no_grad};
420    /// # use rust_bert::Config;
421    /// # use std::path::Path;
422    /// # use tch::kind::Kind::Int64;
423    /// use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
424    /// # let config_path = Path::new("path/to/config.json");
425    /// # let device = Device::Cpu;
426    /// # let vs = nn::VarStore::new(device);
427    /// # let config = FNetConfig::from_file(config_path);
428    /// let model = FNetForMaskedLM::new(&vs.root(), &config);
429    /// let (batch_size, sequence_length) = (64, 128);
430    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
431    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
432    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
433    ///     .expand(&[batch_size, sequence_length], true);
434    ///
435    /// let model_output = no_grad(|| {
436    ///     model
437    ///         .forward_t(
438    ///             Some(&input_tensor),
439    ///             Some(&token_type_ids),
440    ///             Some(&position_ids),
441    ///             None,
442    ///             false,
443    ///         )
444    ///         .unwrap()
445    /// });
446    /// ```
447    pub fn forward_t(
448        &self,
449        input_ids: Option<&Tensor>,
450        token_type_ids: Option<&Tensor>,
451        position_ids: Option<&Tensor>,
452        input_embeddings: Option<&Tensor>,
453        train: bool,
454    ) -> Result<FNetMaskedLMOutput, RustBertError> {
455        let model_output = self.fnet.forward_t(
456            input_ids,
457            token_type_ids,
458            position_ids,
459            input_embeddings,
460            train,
461        )?;
462
463        let prediction_scores = self.lm_head.forward(&model_output.hidden_states);
464
465        Ok(FNetMaskedLMOutput {
466            prediction_scores,
467            all_hidden_states: model_output.all_hidden_states,
468        })
469    }
470}
471
/// # FNet for sequence classification
/// Base FNet model with a classifier head to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `fnet`: Base FNet model
/// - `dropout`: Dropout layer before the last linear layer
/// - `classifier`: linear layer mapping from hidden to the number of classes to predict
pub struct FNetForSequenceClassification {
    /// Base model, built with a pooler (sequence-level task)
    fnet: FNetModel,
    /// Dropout applied to the pooled output before classification
    dropout: Dropout,
    /// Linear layer (hidden_size -> num_labels)
    classifier: nn::Linear,
}
483
484impl FNetForSequenceClassification {
485    /// Build a new `FNetForSequenceClassification`
486    ///
487    /// # Arguments
488    ///
489    /// * `p` - Variable store path for the root of the FNet model
490    /// * `config` - `FNetConfig` object defining the model architecture
491    ///
492    /// # Example
493    ///
494    /// ```no_run
495    /// use rust_bert::fnet::{FNetConfig, FNetForSequenceClassification};
496    /// use rust_bert::Config;
497    /// use std::path::Path;
498    /// use tch::{nn, Device};
499    ///
500    /// let config_path = Path::new("path/to/config.json");
501    /// let device = Device::Cpu;
502    /// let p = nn::VarStore::new(device);
503    /// let config = FNetConfig::from_file(config_path);
504    /// let fnet = FNetForSequenceClassification::new(&p.root() / "fnet", &config).unwrap();
505    /// ```
506    pub fn new<'p, P>(
507        p: P,
508        config: &FNetConfig,
509    ) -> Result<FNetForSequenceClassification, RustBertError>
510    where
511        P: Borrow<nn::Path<'p>>,
512    {
513        let p = p.borrow();
514
515        let fnet = FNetModel::new(p / "fnet", config, true);
516        let dropout = Dropout::new(config.hidden_dropout_prob);
517        let num_labels = config
518            .id2label
519            .as_ref()
520            .ok_or_else(|| {
521                RustBertError::InvalidConfigurationError(
522                    "num_labels not provided in configuration".to_string(),
523                )
524            })?
525            .len() as i64;
526        let classifier = nn::linear(
527            p / "classifier",
528            config.hidden_size,
529            num_labels,
530            Default::default(),
531        );
532
533        Ok(FNetForSequenceClassification {
534            fnet,
535            dropout,
536            classifier,
537        })
538    }
539
540    /// Forward pass through the model
541    ///
542    /// # Arguments
543    ///
544    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
545    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
546    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
547    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
548    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
549    ///
550    /// # Returns
551    ///
552    /// * `FNetSequenceClassificationOutput` containing:
553    ///   - `logits` - `Tensor` of shape (*batch size*, *num_classes*)
554    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
555    ///
556    /// # Example
557    ///
558    /// ```no_run
559    /// # use tch::{nn, Device, Tensor, no_grad};
560    /// # use rust_bert::Config;
561    /// # use std::path::Path;
562    /// # use tch::kind::Kind::Int64;
563    /// use rust_bert::fnet::{FNetConfig, FNetForSequenceClassification};
564    /// # let config_path = Path::new("path/to/config.json");
565    /// # let device = Device::Cpu;
566    /// # let vs = nn::VarStore::new(device);
567    /// # let config = FNetConfig::from_file(config_path);
568    /// let model = FNetForSequenceClassification::new(&vs.root(), &config).unwrap();
569    /// let (batch_size, sequence_length) = (64, 128);
570    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
571    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
572    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
573    ///     .expand(&[batch_size, sequence_length], true);
574    ///
575    /// let model_output = no_grad(|| {
576    ///     model
577    ///         .forward_t(
578    ///             Some(&input_tensor),
579    ///             Some(&token_type_ids),
580    ///             Some(&position_ids),
581    ///             None,
582    ///             false,
583    ///         )
584    ///         .unwrap()
585    /// });
586    /// ```
587    pub fn forward_t(
588        &self,
589        input_ids: Option<&Tensor>,
590        token_type_ids: Option<&Tensor>,
591        position_ids: Option<&Tensor>,
592        input_embeddings: Option<&Tensor>,
593        train: bool,
594    ) -> Result<FNetSequenceClassificationOutput, RustBertError> {
595        let base_model_output = self.fnet.forward_t(
596            input_ids,
597            token_type_ids,
598            position_ids,
599            input_embeddings,
600            train,
601        )?;
602
603        let logits = base_model_output
604            .pooled_output
605            .unwrap()
606            .apply_t(&self.dropout, train)
607            .apply(&self.classifier);
608
609        Ok(FNetSequenceClassificationOutput {
610            logits,
611            all_hidden_states: base_model_output.all_hidden_states,
612        })
613    }
614}
615
/// # FNet for multiple choices
/// Multiple choices model using a FNet base model and a linear classifier.
/// Input should be in the form `[CLS] Context [SEP] Possible choice [SEP]`. The choice is made along the batch axis,
/// assuming all elements of the batch are alternatives to be chosen from for a given context.
/// It is made of the following blocks:
/// - `fnet`: Base FNet model
/// - `dropout`: Dropout layer before the last start/end logits prediction
/// - `classifier`: Linear layer for multiple choices
pub struct FNetForMultipleChoice {
    /// Base model, built with a pooler (sequence-level task)
    fnet: FNetModel,
    /// Dropout applied to the pooled output before scoring
    dropout: Dropout,
    /// Linear layer (hidden_size -> 1) scoring each choice
    classifier: nn::Linear,
}
629
630impl FNetForMultipleChoice {
631    /// Build a new `FNetForMultipleChoice`
632    ///
633    /// # Arguments
634    ///
635    /// * `p` - Variable store path for the root of the FNet model
636    /// * `config` - `FNetConfig` object defining the model architecture
637    ///
638    /// # Example
639    ///
640    /// ```no_run
641    /// use rust_bert::fnet::{FNetConfig, FNetForMultipleChoice};
642    /// use rust_bert::Config;
643    /// use std::path::Path;
644    /// use tch::{nn, Device};
645    ///
646    /// let config_path = Path::new("path/to/config.json");
647    /// let device = Device::Cpu;
648    /// let p = nn::VarStore::new(device);
649    /// let config = FNetConfig::from_file(config_path);
650    /// let fnet = FNetForMultipleChoice::new(&p.root() / "fnet", &config);
651    /// ```
652    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForMultipleChoice
653    where
654        P: Borrow<nn::Path<'p>>,
655    {
656        let p = p.borrow();
657
658        let fnet = FNetModel::new(p / "fnet", config, true);
659        let dropout = Dropout::new(config.hidden_dropout_prob);
660        let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
661
662        FNetForMultipleChoice {
663            fnet,
664            dropout,
665            classifier,
666        }
667    }
668
669    /// Forward pass through the model
670    ///
671    /// # Arguments
672    ///
673    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
674    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
675    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
676    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
677    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
678    ///
679    /// # Returns
680    ///
681    /// * `FNetSequenceClassificationOutput` containing:
682    ///   - `logits` - `Tensor` of shape (*1*, *batch_size*) containing the logits for each of the alternatives given
683    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
684    ///
685    /// # Example
686    ///
687    /// ```no_run
688    /// # use tch::{nn, Device, Tensor, no_grad};
689    /// # use rust_bert::Config;
690    /// # use std::path::Path;
691    /// # use tch::kind::Kind::Int64;
692    /// use rust_bert::fnet::{FNetConfig, FNetForMultipleChoice};
693    /// # let config_path = Path::new("path/to/config.json");
694    /// # let device = Device::Cpu;
695    /// # let vs = nn::VarStore::new(device);
696    /// # let config = FNetConfig::from_file(config_path);
697    /// let model = FNetForMultipleChoice::new(&vs.root(), &config);
698    /// let (batch_size, sequence_length) = (64, 128);
699    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
700    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
701    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
702    ///     .expand(&[batch_size, sequence_length], true);
703    ///
704    /// let model_output = no_grad(|| {
705    ///     model
706    ///         .forward_t(
707    ///             Some(&input_tensor),
708    ///             Some(&token_type_ids),
709    ///             Some(&position_ids),
710    ///             None,
711    ///             false,
712    ///         )
713    ///         .unwrap()
714    /// });
715    /// ```
716    pub fn forward_t(
717        &self,
718        input_ids: Option<&Tensor>,
719        token_type_ids: Option<&Tensor>,
720        position_ids: Option<&Tensor>,
721        input_embeddings: Option<&Tensor>,
722        train: bool,
723    ) -> Result<FNetSequenceClassificationOutput, RustBertError> {
724        let (input_shape, _) =
725            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeddings)?;
726        let num_choices = input_shape[1];
727
728        let input_ids = input_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
729        let token_type_ids =
730            token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
731        let position_ids =
732            position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
733        let input_embeddings =
734            input_embeddings.map(|tensor| tensor.view((-1, tensor.size()[2], tensor.size()[3])));
735
736        let base_model_output = self.fnet.forward_t(
737            input_ids.as_ref(),
738            token_type_ids.as_ref(),
739            position_ids.as_ref(),
740            input_embeddings.as_ref(),
741            train,
742        )?;
743
744        let logits = base_model_output
745            .pooled_output
746            .unwrap()
747            .apply_t(&self.dropout, train)
748            .apply(&self.classifier)
749            .view((-1, num_choices));
750
751        Ok(FNetSequenceClassificationOutput {
752            logits,
753            all_hidden_states: base_model_output.all_hidden_states,
754        })
755    }
756}
757
/// # FNet for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `fnet`: Base FNet
/// - `dropout`: Dropout layer before the last token-level predictions layer
/// - `classifier`: Linear layer for token classification
pub struct FNetForTokenClassification {
    /// Base model, built without a pooler (token-level task)
    fnet: FNetModel,
    /// Dropout applied to the hidden states before classification
    dropout: Dropout,
    /// Linear layer (hidden_size -> num_labels)
    classifier: nn::Linear,
}
770
771impl FNetForTokenClassification {
772    /// Build a new `FNetForTokenClassification`
773    ///
774    /// # Arguments
775    ///
776    /// * `p` - Variable store path for the root of the FNet model
777    /// * `config` - `FNetConfig` object defining the model architecture
778    ///
779    /// # Example
780    ///
781    /// ```no_run
782    /// use rust_bert::fnet::{FNetConfig, FNetForTokenClassification};
783    /// use rust_bert::Config;
784    /// use std::path::Path;
785    /// use tch::{nn, Device};
786    ///
787    /// let config_path = Path::new("path/to/config.json");
788    /// let device = Device::Cpu;
789    /// let p = nn::VarStore::new(device);
790    /// let config = FNetConfig::from_file(config_path);
791    /// let fnet = FNetForTokenClassification::new(&p.root() / "fnet", &config).unwrap();
792    /// ```
793    pub fn new<'p, P>(
794        p: P,
795        config: &FNetConfig,
796    ) -> Result<FNetForTokenClassification, RustBertError>
797    where
798        P: Borrow<nn::Path<'p>>,
799    {
800        let p = p.borrow();
801
802        let fnet = FNetModel::new(p / "fnet", config, false);
803        let dropout = Dropout::new(config.hidden_dropout_prob);
804        let num_labels = config
805            .id2label
806            .as_ref()
807            .ok_or_else(|| {
808                RustBertError::InvalidConfigurationError(
809                    "num_labels not provided in configuration".to_string(),
810                )
811            })?
812            .len() as i64;
813        let classifier = nn::linear(
814            p / "classifier",
815            config.hidden_size,
816            num_labels,
817            Default::default(),
818        );
819
820        Ok(FNetForTokenClassification {
821            fnet,
822            dropout,
823            classifier,
824        })
825    }
826
827    /// Forward pass through the model
828    ///
829    /// # Arguments
830    ///
831    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
832    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
833    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
834    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
835    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
836    ///
837    /// # Returns
838    ///
839    /// * `FNetTokenClassificationOutput` containing:
840    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
841    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
842    ///
843    /// # Example
844    ///
845    /// ```no_run
846    /// # use tch::{nn, Device, Tensor, no_grad};
847    /// # use rust_bert::Config;
848    /// # use std::path::Path;
849    /// # use tch::kind::Kind::Int64;
850    /// use rust_bert::fnet::{FNetConfig, FNetForTokenClassification};
851    /// # let config_path = Path::new("path/to/config.json");
852    /// # let device = Device::Cpu;
853    /// # let vs = nn::VarStore::new(device);
854    /// # let config = FNetConfig::from_file(config_path);
855    /// let model = FNetForTokenClassification::new(&vs.root(), &config).unwrap();
856    /// let (batch_size, sequence_length) = (64, 128);
857    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
858    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
859    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
860    ///     .expand(&[batch_size, sequence_length], true);
861    ///
862    /// let model_output = no_grad(|| {
863    ///     model
864    ///         .forward_t(
865    ///             Some(&input_tensor),
866    ///             Some(&token_type_ids),
867    ///             Some(&position_ids),
868    ///             None,
869    ///             false,
870    ///         )
871    ///         .unwrap()
872    /// });
873    /// ```
874    pub fn forward_t(
875        &self,
876        input_ids: Option<&Tensor>,
877        token_type_ids: Option<&Tensor>,
878        position_ids: Option<&Tensor>,
879        input_embeddings: Option<&Tensor>,
880        train: bool,
881    ) -> Result<FNetTokenClassificationOutput, RustBertError> {
882        let base_model_output = self.fnet.forward_t(
883            input_ids,
884            token_type_ids,
885            position_ids,
886            input_embeddings,
887            train,
888        )?;
889
890        let logits = base_model_output
891            .hidden_states
892            .apply_t(&self.dropout, train)
893            .apply(&self.classifier);
894
895        Ok(FNetTokenClassificationOutput {
896            logits,
897            all_hidden_states: base_model_output.all_hidden_states,
898        })
899    }
900}
901
/// # FNet for question answering
/// Extractive question-answering model based on a FNet language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `fnet`: Base FNet
/// - `qa_outputs`: Linear layer for question answering
pub struct FNetForQuestionAnswering {
    // Base FNet model producing one hidden state per input token
    fnet: FNetModel,
    // Linear projection from `hidden_size` to 2 logits per token (answer start / answer end)
    qa_outputs: nn::Linear,
}
913
914impl FNetForQuestionAnswering {
915    /// Build a new `FNetForQuestionAnswering`
916    ///
917    /// # Arguments
918    ///
919    /// * `p` - Variable store path for the root of the FNet model
920    /// * `config` - `FNetConfig` object defining the model architecture
921    ///
922    /// # Example
923    ///
924    /// ```no_run
925    /// use rust_bert::fnet::{FNetConfig, FNetForQuestionAnswering};
926    /// use rust_bert::Config;
927    /// use std::path::Path;
928    /// use tch::{nn, Device};
929    ///
930    /// let config_path = Path::new("path/to/config.json");
931    /// let device = Device::Cpu;
932    /// let p = nn::VarStore::new(device);
933    /// let config = FNetConfig::from_file(config_path);
934    /// let fnet = FNetForQuestionAnswering::new(&p.root() / "fnet", &config);
935    /// ```
936    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForQuestionAnswering
937    where
938        P: Borrow<nn::Path<'p>>,
939    {
940        let p = p.borrow();
941
942        let fnet = FNetModel::new(p / "fnet", config, false);
943        let qa_outputs = nn::linear(p / "classifier", config.hidden_size, 2, Default::default());
944
945        FNetForQuestionAnswering { fnet, qa_outputs }
946    }
947
948    /// Forward pass through the model
949    ///
950    /// # Arguments
951    ///
952    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
953    /// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
954    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
955    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
956    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
957    ///
958    /// # Returns
959    ///
960    /// * `FNetQuestionAnsweringOutput` containing:
961    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
962    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
963    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
964    ///
965    /// # Example
966    ///
967    /// ```no_run
968    /// # use tch::{nn, Device, Tensor, no_grad};
969    /// # use rust_bert::Config;
970    /// # use std::path::Path;
971    /// # use tch::kind::Kind::Int64;
972    /// use rust_bert::fnet::{FNetConfig, FNetForQuestionAnswering};
973    /// # let config_path = Path::new("path/to/config.json");
974    /// # let device = Device::Cpu;
975    /// # let vs = nn::VarStore::new(device);
976    /// # let config = FNetConfig::from_file(config_path);
977    /// let model = FNetForQuestionAnswering::new(&vs.root(), &config);
978    /// let (batch_size, sequence_length) = (64, 128);
979    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
980    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
981    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
982    ///     .expand(&[batch_size, sequence_length], true);
983    ///
984    /// let model_output = no_grad(|| {
985    ///     model
986    ///         .forward_t(
987    ///             Some(&input_tensor),
988    ///             Some(&token_type_ids),
989    ///             Some(&position_ids),
990    ///             None,
991    ///             false,
992    ///         )
993    ///         .unwrap()
994    /// });
995    /// ```
996    pub fn forward_t(
997        &self,
998        input_ids: Option<&Tensor>,
999        token_type_ids: Option<&Tensor>,
1000        position_ids: Option<&Tensor>,
1001        input_embeddings: Option<&Tensor>,
1002        train: bool,
1003    ) -> Result<FNetQuestionAnsweringOutput, RustBertError> {
1004        let base_model_output = self.fnet.forward_t(
1005            input_ids,
1006            token_type_ids,
1007            position_ids,
1008            input_embeddings,
1009            train,
1010        )?;
1011
1012        let logits = base_model_output
1013            .hidden_states
1014            .apply(&self.qa_outputs)
1015            .split(1, -1);
1016        let (start_logits, end_logits) = (&logits[0], &logits[1]);
1017        let start_logits = start_logits.squeeze_dim(-1);
1018        let end_logits = end_logits.squeeze_dim(-1);
1019
1020        Ok(FNetQuestionAnsweringOutput {
1021            start_logits,
1022            end_logits,
1023            all_hidden_states: base_model_output.all_hidden_states,
1024        })
1025    }
1026}
1027
/// Container for the FNet model output.
pub struct FNetModelOutput {
    /// Last hidden states from the model (one hidden vector per input token)
    pub hidden_states: Tensor,
    /// Pooled output (hidden state for the first token).
    /// Presumably `None` when the model was built without a pooling layer — confirm against `FNetModel::new`
    pub pooled_output: Option<Tensor>,
    /// Hidden states for all intermediate layers, when the model is configured to return them
    pub all_hidden_states: Option<Vec<Tensor>>,
}
1037
/// Container for the FNet masked LM model output.
pub struct FNetMaskedLMOutput {
    /// Logits for the vocabulary items at each sequence position
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers, when the model is configured to return them
    pub all_hidden_states: Option<Vec<Tensor>>,
}
1045
/// Container for the FNet sequence classification model output.
pub struct FNetSequenceClassificationOutput {
    /// Logits for each input (sequence) for each target class
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, when the model is configured to return them
    pub all_hidden_states: Option<Vec<Tensor>>,
}
1053
/// Container for the FNet token classification model output.
/// Alias of [`FNetSequenceClassificationOutput`]: both heads expose `logits` plus optional intermediate hidden states.
pub type FNetTokenClassificationOutput = FNetSequenceClassificationOutput;
1056
/// Container for the FNet question answering model output.
pub struct FNetQuestionAnsweringOutput {
    /// Logits for the start position for token of each input sequence
    pub start_logits: Tensor,
    /// Logits for the end position for token of each input sequence
    pub end_logits: Tensor,
    /// Hidden states for all intermediate layers, when the model is configured to return them
    pub all_hidden_states: Option<Vec<Tensor>>,
}
1066
#[cfg(test)]
mod test {
    use tch::Device;

    use crate::{
        resources::{RemoteResource, ResourceProvider},
        Config,
    };

    use super::*;

    #[test]
    #[ignore] // compilation is enough, no need to run
    fn fnet_model_send() {
        // Compile-time check that `FNetModel` is `Send` and can therefore be moved
        // across threads; the coercion to `Box<dyn Send>` fails to build otherwise.
        let config_resource = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
        let config_path = config_resource
            .get_local_path()
            .expect("Failed to retrieve local path for the FNet configuration resource");

        //    Set-up base FNet model
        let device = Device::cuda_if_available();
        let vs = nn::VarStore::new(device);
        let config = FNetConfig::from_file(config_path);

        let _: Box<dyn Send> = Box::new(FNetModel::new(vs.root(), &config, true));
    }
}