// rust_bert — src/distilbert/distilbert_model.rs

1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright 2019 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//     http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13extern crate tch;
14
15use self::tch::{nn, Tensor};
16use crate::common::activations::Activation;
17use crate::common::dropout::Dropout;
18use crate::distilbert::embeddings::DistilBertEmbedding;
19use crate::distilbert::transformer::{DistilBertTransformerOutput, Transformer};
20use crate::{Config, RustBertError};
21use serde::{Deserialize, Serialize};
22use std::{borrow::Borrow, collections::HashMap};
23
/// # DistilBERT Pretrained model weight files
/// Each associated constant is a `(cache alias, download URL)` tuple pointing to pretrained weights.
pub struct DistilBertModelResources;

/// # DistilBERT Pretrained model config files
/// Each associated constant is a `(cache alias, download URL)` tuple pointing to a model configuration.
pub struct DistilBertConfigResources;

/// # DistilBERT Pretrained model vocab files
/// Each associated constant is a `(cache alias, download URL)` tuple pointing to a vocabulary file.
pub struct DistilBertVocabResources;
32
impl DistilBertModelResources {
    // Each constant maps a local cache alias to the remote URL of a converted weight file.
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
        "distilbert-sst2/model",
        "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT: (&'static str, &'static str) = (
        "distilbert/model",
        "https://huggingface.co/distilbert-base-uncased/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
        "distilbert-qa/model",
        "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
    pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
        "distiluse-base-multilingual-cased/model",
        "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/rust_model.ot",
    );
}
55
impl DistilBertConfigResources {
    // Each constant maps a local cache alias to the remote URL of a model configuration JSON.
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
        "distilbert-sst2/config",
        "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT: (&'static str, &'static str) = (
        "distilbert/config",
        "https://huggingface.co/distilbert-base-uncased/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
        "distilbert-qa/config",
        "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
    pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
        "distiluse-base-multilingual-cased/config",
        "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/config.json",
    );
}
78
impl DistilBertVocabResources {
    // Each constant maps a local cache alias to the remote URL of a vocabulary file.
    // Note: some DistilBERT checkpoints reuse the vocabulary of their BERT teacher model.
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
        "distilbert-sst2/vocab",
        "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT: (&'static str, &'static str) = (
        "distilbert/vocab",
        "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
        "distilbert-qa/vocab",
        "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
    pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
        "distiluse-base-multilingual-cased/vocab",
        "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/vocab.txt",
    );
}
101
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # DistilBERT model configuration
/// Defines the DistilBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
/// Field names mirror the upstream HuggingFace `config.json` format so the struct can be
/// deserialized directly from those files. Default values below refer to the `Default` impl.
pub struct DistilBertConfig {
    /// Activation function used in the model (default: gelu)
    pub activation: Activation,
    /// Dropout rate for the attention layers, per upstream config naming (default: 0.1)
    pub attention_dropout: f64,
    /// Dimension of the hidden representations (default: 768); used to size all head layers
    pub dim: i64,
    /// General dropout rate (default: 0.1)
    pub dropout: f64,
    /// Inner dimension of the feed-forward layers, per upstream config naming (default: 3072)
    pub hidden_dim: i64,
    /// Optional mapping from label index to label name; required by the classification heads,
    /// whose output size is derived from its length
    pub id2label: Option<HashMap<i64, String>>,
    /// Standard deviation for weight initialization (default: 0.02)
    pub initializer_range: f32,
    /// Optional decoder flag (not used by the models in this file)
    pub is_decoder: Option<bool>,
    /// Optional reverse mapping from label name to label index
    pub label2id: Option<HashMap<String, i64>>,
    /// Maximum sequence length supported by the position embeddings (default: 512)
    pub max_position_embeddings: i64,
    /// Number of attention heads (default: 12)
    pub n_heads: i64,
    /// Number of transformer layers (default: 6)
    pub n_layers: i64,
    /// If set to true, the model returns attention weights for all layers
    pub output_attentions: Option<bool>,
    /// If set to true, the model returns hidden states for all layers
    pub output_hidden_states: Option<bool>,
    /// Optional flag kept for compatibility with upstream configuration files
    pub output_past: Option<bool>,
    /// Dropout rate for the question-answering head (default: 0.1)
    pub qa_dropout: f64,
    /// Dropout rate for the sequence-classification head (default: 0.2)
    pub seq_classif_dropout: f64,
    /// If true, use sinusoidal position embeddings, per upstream config naming (default: false)
    pub sinusoidal_pos_embds: bool,
    /// Weight-tying flag from the upstream config (default: false); not read in this file
    pub tie_weights_: bool,
    /// Vocabulary size (default: 30522); sizes the masked-LM projection layer
    pub vocab_size: i64,
}
127
// Marker implementation enabling the shared `Config` utilities
// (e.g. `DistilBertConfig::from_file`, as used in the doc examples below).
impl Config for DistilBertConfig {}
129
impl Default for DistilBertConfig {
    /// Returns a configuration with the standard DistilBERT hyper-parameters
    /// (6 layers, 12 heads, hidden size 768, 30522-token vocabulary — presumably
    /// matching the `distilbert-base-uncased` checkpoint).
    fn default() -> Self {
        DistilBertConfig {
            activation: Activation::gelu,
            attention_dropout: 0.1,
            dim: 768,
            dropout: 0.1,
            hidden_dim: 3072,
            id2label: None,
            initializer_range: 0.02,
            is_decoder: None,
            label2id: None,
            max_position_embeddings: 512,
            n_heads: 12,
            n_layers: 6,
            output_attentions: None,
            output_hidden_states: None,
            output_past: None,
            qa_dropout: 0.1,
            seq_classif_dropout: 0.2,
            sinusoidal_pos_embds: false,
            tie_weights_: false,
            vocab_size: 30522,
        }
    }
}
156
/// # DistilBERT Base model
/// Base architecture for DistilBERT models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `embeddings`: `token`, `position` embeddings
/// - `transformer`: Transformer made of a vector of layers. Each layer is made of a multi-head self-attention layer, layer norm and linear layers.
pub struct DistilBertModel {
    /// Token and position embeddings
    embeddings: DistilBertEmbedding,
    /// Stack of transformer layers
    transformer: Transformer,
}
166
167/// Defines the implementation of the DistilBertModel.
168impl DistilBertModel {
169    /// Build a new `DistilBertModel`
170    ///
171    /// # Arguments
172    ///
173    /// * `p` - Variable store path for the root of the DistilBERT model
174    /// * `config` - `DistilBertConfig` object defining the model architecture
175    ///
176    /// # Example
177    ///
178    /// ```no_run
179    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
180    /// use rust_bert::Config;
181    /// use std::path::Path;
182    /// use tch::{nn, Device};
183    ///
184    /// let config_path = Path::new("path/to/config.json");
185    /// let device = Device::Cpu;
186    /// let p = nn::VarStore::new(device);
187    /// let config = DistilBertConfig::from_file(config_path);
188    /// let distil_bert: DistilBertModel = DistilBertModel::new(&p.root() / "distilbert", &config);
189    /// ```
190    pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModel
191    where
192        P: Borrow<nn::Path<'p>>,
193    {
194        let p = p.borrow() / "distilbert";
195        let embeddings = DistilBertEmbedding::new(&p / "embeddings", config);
196        let transformer = Transformer::new(p / "transformer", config);
197        DistilBertModel {
198            embeddings,
199            transformer,
200        }
201    }
202
203    /// Forward pass through the model
204    ///
205    /// # Arguments
206    ///
207    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
208    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
209    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
210    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
211    ///
212    /// # Returns
213    ///
214    /// * `DistilBertTransformerOutput` containing:
215    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
216    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
217    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
218    ///
219    /// # Example
220    ///
221    /// ```no_run
222    /// # use tch::{nn, Device, Tensor, no_grad};
223    /// # use rust_bert::Config;
224    /// # use std::path::Path;
225    /// # use tch::kind::Kind::Int64;
226    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
227    /// # let config_path = Path::new("path/to/config.json");
228    /// # let vocab_path = Path::new("path/to/vocab.txt");
229    /// # let device = Device::Cpu;
230    /// # let vs = nn::VarStore::new(device);
231    /// # let config = DistilBertConfig::from_file(config_path);
232    /// # let distilbert_model: DistilBertModel = DistilBertModel::new(&vs.root(), &config);
233    /// let (batch_size, sequence_length) = (64, 128);
234    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
235    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
236    ///
237    /// let model_output = no_grad(|| {
238    ///     distilbert_model
239    ///         .forward_t(Some(&input_tensor), Some(&mask), None, false)
240    ///         .unwrap()
241    /// });
242    /// ```
243    pub fn forward_t(
244        &self,
245        input: Option<&Tensor>,
246        mask: Option<&Tensor>,
247        input_embeds: Option<&Tensor>,
248        train: bool,
249    ) -> Result<DistilBertTransformerOutput, RustBertError> {
250        let input_embeddings = self.embeddings.forward_t(input, input_embeds, train)?;
251        let transformer_output = self.transformer.forward_t(&input_embeddings, mask, train);
252        Ok(transformer_output)
253    }
254}
255
/// # DistilBERT for sequence classification
/// Base DistilBERT model with a pre-classifier and classifier heads to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `pre_classifier`: linear layer of size (*dim*, *dim*) applied to the first token's hidden state
/// - `classifier`: linear layer of size (*dim*, *num_labels*) producing the classification logits
pub struct DistilBertModelClassifier {
    distil_bert_model: DistilBertModel,
    pre_classifier: nn::Linear,
    classifier: nn::Linear,
    // Uses `seq_classif_dropout` from the configuration
    dropout: Dropout,
}
268
269impl DistilBertModelClassifier {
270    /// Build a new `DistilBertModelClassifier` for sequence classification
271    ///
272    /// # Arguments
273    ///
274    /// * `p` - Variable store path for the root of the DistilBertModelClassifier model
275    /// * `config` - `DistilBertConfig` object defining the model architecture
276    ///
277    /// # Example
278    ///
279    /// ```no_run
280    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
281    /// use rust_bert::Config;
282    /// use std::path::Path;
283    /// use tch::{nn, Device};
284    ///
285    /// let config_path = Path::new("path/to/config.json");
286    /// let device = Device::Cpu;
287    /// let p = nn::VarStore::new(device);
288    /// let config = DistilBertConfig::from_file(config_path);
289    /// let distil_bert: DistilBertModelClassifier =
290    ///     DistilBertModelClassifier::new(&p.root() / "distilbert", &config).unwrap();
291    /// ```
292    pub fn new<'p, P>(
293        p: P,
294        config: &DistilBertConfig,
295    ) -> Result<DistilBertModelClassifier, RustBertError>
296    where
297        P: Borrow<nn::Path<'p>>,
298    {
299        let p = p.borrow();
300
301        let distil_bert_model = DistilBertModel::new(p, config);
302
303        let num_labels = config
304            .id2label
305            .as_ref()
306            .ok_or_else(|| {
307                RustBertError::InvalidConfigurationError(
308                    "num_labels not provided in configuration".to_string(),
309                )
310            })?
311            .len() as i64;
312
313        let pre_classifier = nn::linear(
314            p / "pre_classifier",
315            config.dim,
316            config.dim,
317            Default::default(),
318        );
319        let classifier = nn::linear(p / "classifier", config.dim, num_labels, Default::default());
320        let dropout = Dropout::new(config.seq_classif_dropout);
321
322        Ok(DistilBertModelClassifier {
323            distil_bert_model,
324            pre_classifier,
325            classifier,
326            dropout,
327        })
328    }
329
330    /// Forward pass through the model
331    ///
332    /// # Arguments
333    ///
334    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
335    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
336    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
337    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
338    ///
339    /// # Returns
340    ///
341    /// * `DistilBertSequenceClassificationOutput` containing:
342    ///   - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
343    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
344    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
345    ///
346    /// # Example
347    ///
348    /// ```no_run
349    /// # use tch::{nn, Device, Tensor, no_grad};
350    /// # use rust_bert::Config;
351    /// # use std::path::Path;
352    /// # use tch::kind::Kind::Int64;
353    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
354    /// # let config_path = Path::new("path/to/config.json");
355    /// # let vocab_path = Path::new("path/to/vocab.txt");
356    /// # let device = Device::Cpu;
357    /// # let vs = nn::VarStore::new(device);
358    /// # let config = DistilBertConfig::from_file(config_path);
359    /// # let distilbert_model: DistilBertModelClassifier = DistilBertModelClassifier::new(&vs.root(), &config).unwrap();;
360    ///  let (batch_size, sequence_length) = (64, 128);
361    ///  let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
362    ///  let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
363    ///
364    ///  let model_output = no_grad(|| {
365    ///    distilbert_model
366    ///         .forward_t(Some(&input_tensor),
367    ///                    Some(&mask),
368    ///                    None,
369    ///                    false).unwrap()
370    ///    });
371    /// ```
372    pub fn forward_t(
373        &self,
374        input: Option<&Tensor>,
375        mask: Option<&Tensor>,
376        input_embeds: Option<&Tensor>,
377        train: bool,
378    ) -> Result<DistilBertSequenceClassificationOutput, RustBertError> {
379        let base_model_output =
380            self.distil_bert_model
381                .forward_t(input, mask, input_embeds, train)?;
382
383        let logits = base_model_output
384            .hidden_state
385            .select(1, 0)
386            .apply(&self.pre_classifier)
387            .relu()
388            .apply_t(&self.dropout, train)
389            .apply(&self.classifier);
390
391        Ok(DistilBertSequenceClassificationOutput {
392            logits,
393            all_hidden_states: base_model_output.all_hidden_states,
394            all_attentions: base_model_output.all_attentions,
395        })
396    }
397}
398
/// # DistilBERT for masked language model
/// Base DistilBERT model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `vocab_transform`: linear layer of size (*dim*, *dim*)
/// - `vocab_layer_norm`: layer normalization applied after the transform
/// - `vocab_projector`: linear layer of size (*dim*, *vocab_size*) projecting onto the vocabulary
///   (weights tied to the token embeddings per the original docs — tying is not visible in this file; presumably handled at weight-loading time)
pub struct DistilBertModelMaskedLM {
    distil_bert_model: DistilBertModel,
    vocab_transform: nn::Linear,
    vocab_layer_norm: nn::LayerNorm,
    vocab_projector: nn::Linear,
}
412
413impl DistilBertModelMaskedLM {
414    /// Build a new `DistilBertModelMaskedLM` for sequence classification
415    ///
416    /// # Arguments
417    ///
418    /// * `p` - Variable store path for the root of the DistilBertModelMaskedLM model
419    /// * `config` - `DistilBertConfig` object defining the model architecture
420    ///
421    /// # Example
422    ///
423    /// ```no_run
424    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
425    /// use rust_bert::Config;
426    /// use std::path::Path;
427    /// use tch::{nn, Device};
428    ///
429    /// let config_path = Path::new("path/to/config.json");
430    /// let device = Device::Cpu;
431    /// let p = nn::VarStore::new(device);
432    /// let config = DistilBertConfig::from_file(config_path);
433    /// let distil_bert = DistilBertModelMaskedLM::new(&p.root() / "distilbert", &config);
434    /// ```
435    pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelMaskedLM
436    where
437        P: Borrow<nn::Path<'p>>,
438    {
439        let p = p.borrow();
440
441        let distil_bert_model = DistilBertModel::new(p, config);
442        let vocab_transform = nn::linear(
443            p / "vocab_transform",
444            config.dim,
445            config.dim,
446            Default::default(),
447        );
448        let layer_norm_config = nn::LayerNormConfig {
449            eps: 1e-12,
450            ..Default::default()
451        };
452        let vocab_layer_norm =
453            nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
454        let vocab_projector = nn::linear(
455            p / "vocab_projector",
456            config.dim,
457            config.vocab_size,
458            Default::default(),
459        );
460
461        DistilBertModelMaskedLM {
462            distil_bert_model,
463            vocab_transform,
464            vocab_layer_norm,
465            vocab_projector,
466        }
467    }
468
469    /// Forward pass through the model
470    ///
471    /// # Arguments
472    ///
473    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
474    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
475    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
476    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
477    ///
478    /// # Returns
479    ///
480    /// * `DistilBertMaskedLMOutput` containing:
481    ///   - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
482    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
483    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
484    ///
485    /// # Example
486    ///
487    /// ```no_run
488    /// # use tch::{nn, Device, Tensor, no_grad};
489    /// # use rust_bert::Config;
490    /// # use std::path::Path;
491    /// # use tch::kind::Kind::Int64;
492    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
493    /// # let config_path = Path::new("path/to/config.json");
494    /// # let vocab_path = Path::new("path/to/vocab.txt");
495    /// # let device = Device::Cpu;
496    /// # let vs = nn::VarStore::new(device);
497    /// # let config = DistilBertConfig::from_file(config_path);
498    /// # let distilbert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
499    /// let (batch_size, sequence_length) = (64, 128);
500    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
501    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
502    ///
503    /// let model_output = no_grad(|| {
504    ///     distilbert_model
505    ///         .forward_t(Some(&input_tensor), Some(&mask), None, false)
506    ///         .unwrap()
507    /// });
508    /// ```
509    pub fn forward_t(
510        &self,
511        input: Option<&Tensor>,
512        mask: Option<&Tensor>,
513        input_embeds: Option<&Tensor>,
514        train: bool,
515    ) -> Result<DistilBertMaskedLMOutput, RustBertError> {
516        let base_model_output =
517            self.distil_bert_model
518                .forward_t(input, mask, input_embeds, train)?;
519
520        let prediction_scores = base_model_output
521            .hidden_state
522            .apply(&self.vocab_transform)
523            .gelu("none")
524            .apply(&self.vocab_layer_norm)
525            .apply(&self.vocab_projector);
526
527        Ok(DistilBertMaskedLMOutput {
528            prediction_scores,
529            all_hidden_states: base_model_output.all_hidden_states,
530            all_attentions: base_model_output.all_attentions,
531        })
532    }
533}
534
/// # DistilBERT for question answering
/// Extractive question-answering model based on a DistilBERT language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `qa_outputs`: linear layer of size (*dim*, 2) producing per-token start and end logits
pub struct DistilBertForQuestionAnswering {
    distil_bert_model: DistilBertModel,
    qa_outputs: nn::Linear,
    // Uses `qa_dropout` from the configuration
    dropout: Dropout,
}
547
548impl DistilBertForQuestionAnswering {
549    /// Build a new `DistilBertForQuestionAnswering` for sequence classification
550    ///
551    /// # Arguments
552    ///
553    /// * `p` - Variable store path for the root of the DistilBertForQuestionAnswering model
554    /// * `config` - `DistilBertConfig` object defining the model architecture
555    ///
556    /// # Example
557    ///
558    /// ```no_run
559    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
560    /// use rust_bert::Config;
561    /// use std::path::Path;
562    /// use tch::{nn, Device};
563    ///
564    /// let config_path = Path::new("path/to/config.json");
565    /// let device = Device::Cpu;
566    /// let p = nn::VarStore::new(device);
567    /// let config = DistilBertConfig::from_file(config_path);
568    /// let distil_bert = DistilBertForQuestionAnswering::new(&p.root() / "distilbert", &config);
569    /// ```
570    pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForQuestionAnswering
571    where
572        P: Borrow<nn::Path<'p>>,
573    {
574        let p = p.borrow();
575
576        let distil_bert_model = DistilBertModel::new(p, config);
577        let qa_outputs = nn::linear(p / "qa_outputs", config.dim, 2, Default::default());
578        let dropout = Dropout::new(config.qa_dropout);
579
580        DistilBertForQuestionAnswering {
581            distil_bert_model,
582            qa_outputs,
583            dropout,
584        }
585    }
586
587    /// Forward pass through the model
588    ///
589    /// # Arguments
590    ///
591    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
592    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
593    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
594    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
595    ///
596    /// # Returns
597    ///
598    /// * `DistilBertQuestionAnsweringOutput` containing:
599    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
600    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
601    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
602    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
603    ///
604    /// # Example
605    ///
606    /// ```no_run
607    /// # use tch::{nn, Device, Tensor, no_grad};
608    /// # use rust_bert::Config;
609    /// # use std::path::Path;
610    /// # use tch::kind::Kind::Int64;
611    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
612    /// # let config_path = Path::new("path/to/config.json");
613    /// # let vocab_path = Path::new("path/to/vocab.txt");
614    /// # let device = Device::Cpu;
615    /// # let vs = nn::VarStore::new(device);
616    /// # let config = DistilBertConfig::from_file(config_path);
617    /// # let distilbert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
618    /// let (batch_size, sequence_length) = (64, 128);
619    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
620    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
621    ///
622    /// let model_output = no_grad(|| {
623    ///     distilbert_model
624    ///         .forward_t(Some(&input_tensor), Some(&mask), None, false)
625    ///         .unwrap()
626    /// });
627    /// ```
628    pub fn forward_t(
629        &self,
630        input: Option<&Tensor>,
631        mask: Option<&Tensor>,
632        input_embeds: Option<&Tensor>,
633        train: bool,
634    ) -> Result<DistilBertQuestionAnsweringOutput, RustBertError> {
635        let base_model_output =
636            self.distil_bert_model
637                .forward_t(input, mask, input_embeds, train)?;
638
639        let output = base_model_output
640            .hidden_state
641            .apply_t(&self.dropout, train)
642            .apply(&self.qa_outputs);
643
644        let logits = output.split(1, -1);
645        let (start_logits, end_logits) = (&logits[0], &logits[1]);
646        let start_logits = start_logits.squeeze_dim(-1);
647        let end_logits = end_logits.squeeze_dim(-1);
648
649        Ok(DistilBertQuestionAnsweringOutput {
650            start_logits,
651            end_logits,
652            all_hidden_states: base_model_output.all_hidden_states,
653            all_attentions: base_model_output.all_attentions,
654        })
655    }
656}
657
/// # DistilBERT for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `classifier`: linear layer of size (*dim*, *num_labels*) for token classification
pub struct DistilBertForTokenClassification {
    distil_bert_model: DistilBertModel,
    classifier: nn::Linear,
    // Uses `seq_classif_dropout` from the configuration
    dropout: Dropout,
}
669
670impl DistilBertForTokenClassification {
671    /// Build a new `DistilBertForTokenClassification` for sequence classification
672    ///
673    /// # Arguments
674    ///
675    /// * `p` - Variable store path for the root of the DistilBertForTokenClassification model
676    /// * `config` - `DistilBertConfig` object defining the model architecture
677    ///
678    /// # Example
679    ///
680    /// ```no_run
681    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
682    /// use rust_bert::Config;
683    /// use std::path::Path;
684    /// use tch::{nn, Device};
685    ///
686    /// let config_path = Path::new("path/to/config.json");
687    /// let device = Device::Cpu;
688    /// let p = nn::VarStore::new(device);
689    /// let config = DistilBertConfig::from_file(config_path);
690    /// let distil_bert =
691    ///     DistilBertForTokenClassification::new(&p.root() / "distilbert", &config).unwrap();
692    /// ```
693    pub fn new<'p, P>(
694        p: P,
695        config: &DistilBertConfig,
696    ) -> Result<DistilBertForTokenClassification, RustBertError>
697    where
698        P: Borrow<nn::Path<'p>>,
699    {
700        let p = p.borrow();
701
702        let distil_bert_model = DistilBertModel::new(p, config);
703
704        let num_labels = config
705            .id2label
706            .as_ref()
707            .ok_or_else(|| {
708                RustBertError::InvalidConfigurationError(
709                    "id2label must be provided for classifiers".to_string(),
710                )
711            })?
712            .len() as i64;
713
714        let classifier = nn::linear(p / "classifier", config.dim, num_labels, Default::default());
715        let dropout = Dropout::new(config.seq_classif_dropout);
716
717        Ok(DistilBertForTokenClassification {
718            distil_bert_model,
719            classifier,
720            dropout,
721        })
722    }
723
724    /// Forward pass through the model
725    ///
726    /// # Arguments
727    ///
728    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
729    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
730    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
731    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
732    ///
733    /// # Returns
734    ///
735    /// * `DistilBertTokenClassificationOutput` containing:
736    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
737    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
738    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
739    ///
740    /// # Example
741    ///
742    /// ```no_run
743    /// # use tch::{nn, Device, Tensor, no_grad};
744    /// # use rust_bert::Config;
745    /// # use std::path::Path;
746    /// # use tch::kind::Kind::Int64;
747    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
748    /// # let config_path = Path::new("path/to/config.json");
749    /// # let vocab_path = Path::new("path/to/vocab.txt");
750    /// # let device = Device::Cpu;
751    /// # let vs = nn::VarStore::new(device);
752    /// # let config = DistilBertConfig::from_file(config_path);
753    /// # let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config).unwrap();
754    /// let (batch_size, sequence_length) = (64, 128);
755    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
756    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
757    ///
758    /// let model_output = no_grad(|| {
759    ///     distilbert_model
760    ///         .forward_t(Some(&input_tensor), Some(&mask), None, false)
761    ///         .unwrap()
762    /// });
763    /// ```
764    pub fn forward_t(
765        &self,
766        input: Option<&Tensor>,
767        mask: Option<&Tensor>,
768        input_embeds: Option<&Tensor>,
769        train: bool,
770    ) -> Result<DistilBertTokenClassificationOutput, RustBertError> {
771        let base_model_output =
772            self.distil_bert_model
773                .forward_t(input, mask, input_embeds, train)?;
774
775        let logits = base_model_output
776            .hidden_state
777            .apply_t(&self.dropout, train)
778            .apply(&self.classifier);
779
780        Ok(DistilBertTokenClassificationOutput {
781            logits,
782            all_hidden_states: base_model_output.all_hidden_states,
783            all_attentions: base_model_output.all_attentions,
784        })
785    }
786}
787
/// # DistilBERT for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
/// Alias for the base [`DistilBertModel`]: the sentence-embeddings pipeline consumes the
/// transformer's raw hidden states directly, so no additional head is required.
pub type DistilBertForSentenceEmbeddings = DistilBertModel;
791
/// Container for the DistilBERT masked LM model output.
pub struct DistilBertMaskedLMOutput {
    /// Logits over the vocabulary items at each sequence position
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers, when available
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when available
    pub all_attentions: Option<Vec<Tensor>>,
}
801
/// Container for the DistilBERT sequence classification model output
pub struct DistilBertSequenceClassificationOutput {
    /// Logits for each input sequence over the target classes
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, when available
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when available
    pub all_attentions: Option<Vec<Tensor>>,
}
811
/// Container for the DistilBERT token classification model output
pub struct DistilBertTokenClassificationOutput {
    /// Logits for each sequence item (token) over the target classes,
    /// shape (*batch size*, *sequence_length*, *num_labels*)
    pub logits: Tensor,
    /// Hidden states for all intermediate layers, when available
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when available
    pub all_attentions: Option<Vec<Tensor>>,
}
821
/// Container for the DistilBERT question answering model output
pub struct DistilBertQuestionAnsweringOutput {
    /// Logits for the answer span start position, one score per token of each input sequence
    pub start_logits: Tensor,
    /// Logits for the answer span end position, one score per token of each input sequence
    pub end_logits: Tensor,
    /// Hidden states for all intermediate layers, when available
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers, when available
    pub all_attentions: Option<Vec<Tensor>>,
}