// rust_bert/models/distilbert/distilbert_model.rs
1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright 2019 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6// http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13extern crate tch;
14
15use self::tch::{nn, Tensor};
16use crate::common::activations::Activation;
17use crate::common::dropout::Dropout;
18use crate::distilbert::embeddings::DistilBertEmbedding;
19use crate::distilbert::transformer::{DistilBertTransformerOutput, Transformer};
20use crate::{Config, RustBertError};
21use serde::{Deserialize, Serialize};
22use std::{borrow::Borrow, collections::HashMap};
23
// Unit structs below carry no data: they exist purely as namespaces for the
// pretrained resource constants defined in the impl blocks that follow.

/// # DistilBERT Pretrained model weight files
pub struct DistilBertModelResources;

/// # DistilBERT Pretrained model config files
pub struct DistilBertConfigResources;

/// # DistilBERT Pretrained model vocab files
pub struct DistilBertVocabResources;
impl DistilBertModelResources {
    // Each constant is a (resource alias, download URL) pair pointing at pretrained weights.
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
        "distilbert-sst2/model",
        "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT: (&'static str, &'static str) = (
        "distilbert/model",
        "https://huggingface.co/distilbert-base-uncased/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
        "distilbert-qa/model",
        "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/rust_model.ot",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
    pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
        "distiluse-base-multilingual-cased/model",
        "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/rust_model.ot",
    );
}
55
impl DistilBertConfigResources {
    // Each constant is a (resource alias, download URL) pair pointing at a model configuration file.
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
        "distilbert-sst2/config",
        "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT: (&'static str, &'static str) = (
        "distilbert/config",
        "https://huggingface.co/distilbert-base-uncased/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
        "distilbert-qa/config",
        "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
    pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
        "distiluse-base-multilingual-cased/config",
        "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/config.json",
    );
}
78
impl DistilBertVocabResources {
    // Each constant is a (resource alias, download URL) pair pointing at a vocabulary file.
    // Note: the base and QA vocabularies are shared with the corresponding BERT models.
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
        "distilbert-sst2/vocab",
        "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT: (&'static str, &'static str) = (
        "distilbert/vocab",
        "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
    );
    /// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
    pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
        "distilbert-qa/vocab",
        "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
    );
    /// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
    pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
        "distiluse-base-multilingual-cased/vocab",
        "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/vocab.txt",
    );
}
101
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # DistilBERT model configuration
/// Defines the DistilBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct DistilBertConfig {
    /// Activation function (consumed by the transformer layers; see transformer.rs)
    pub activation: Activation,
    /// Dropout probability for the attention layers — presumably applied to attention weights; see transformer.rs
    pub attention_dropout: f64,
    /// Hidden size of the model; used in this file as the input/output size of all head layers
    pub dim: i64,
    /// General dropout probability (consumed by the embedding/transformer modules)
    pub dropout: f64,
    /// Inner dimension of the transformer feed-forward layers (consumed by transformer.rs)
    pub hidden_dim: i64,
    /// Mapping from class index to label; required by the classification heads, whose
    /// number of output classes is `id2label.len()`
    pub id2label: Option<HashMap<i64, String>>,
    /// Weight initialization range — not used by the blocks in this file; TODO confirm usage
    pub initializer_range: f32,
    /// Optional decoder flag — unused by the models defined in this file
    pub is_decoder: Option<bool>,
    /// Mapping from label to class index (inverse of `id2label`)
    pub label2id: Option<HashMap<String, i64>>,
    /// Maximum sequence length supported by the position embeddings (consumed by embeddings.rs)
    pub max_position_embeddings: i64,
    /// Number of attention heads per transformer layer
    pub n_heads: i64,
    /// Number of transformer layers
    pub n_layers: i64,
    /// If `Some(true)`, attention weights are returned by the forward pass (consumed by transformer.rs)
    pub output_attentions: Option<bool>,
    /// If `Some(true)`, hidden states of all layers are returned by the forward pass (consumed by transformer.rs)
    pub output_hidden_states: Option<bool>,
    /// Legacy flag kept for configuration-file compatibility — unused in this file
    pub output_past: Option<bool>,
    /// Dropout probability for the question-answering head (see `DistilBertForQuestionAnswering::new`)
    pub qa_dropout: f64,
    /// Dropout probability for the classification heads (see `DistilBertModelClassifier::new`)
    pub seq_classif_dropout: f64,
    /// If true, sinusoidal (fixed) position embeddings are used — consumed by embeddings.rs
    pub sinusoidal_pos_embds: bool,
    /// Weight-tying flag for the LM projection — not read in this file; TODO confirm consumer
    pub tie_weights_: bool,
    /// Vocabulary size; output size of the masked-LM projection layer
    pub vocab_size: i64,
}
127
128impl Config for DistilBertConfig {}
129
130impl Default for DistilBertConfig {
131 fn default() -> Self {
132 DistilBertConfig {
133 activation: Activation::gelu,
134 attention_dropout: 0.1,
135 dim: 768,
136 dropout: 0.1,
137 hidden_dim: 3072,
138 id2label: None,
139 initializer_range: 0.02,
140 is_decoder: None,
141 label2id: None,
142 max_position_embeddings: 512,
143 n_heads: 12,
144 n_layers: 6,
145 output_attentions: None,
146 output_hidden_states: None,
147 output_past: None,
148 qa_dropout: 0.1,
149 seq_classif_dropout: 0.2,
150 sinusoidal_pos_embds: false,
151 tie_weights_: false,
152 vocab_size: 30522,
153 }
154 }
155}
156
/// # DistilBERT Base model
/// Base architecture for DistilBERT models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `embeddings`: `token`, `position` embeddings
/// - `transformer`: Transformer made of a vector of layers. Each layer is made of a multi-head self-attention layer, layer norm and linear layers.
pub struct DistilBertModel {
    // Token/position embedding module (defined in distilbert::embeddings)
    embeddings: DistilBertEmbedding,
    // Stack of transformer layers (defined in distilbert::transformer)
    transformer: Transformer,
}
166
/// Defines the implementation of the DistilBertModel.
impl DistilBertModel {
    /// Build a new `DistilBertModel`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the DistilBERT model. A `"distilbert"`
    ///   sub-path is appended internally so variable names match the pretrained weight files.
    /// * `config` - `DistilBertConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = DistilBertConfig::from_file(config_path);
    /// // Note: `new` nests variables under "distilbert" itself, so the plain
    /// // variable-store root is passed here.
    /// let distil_bert: DistilBertModel = DistilBertModel::new(&p.root(), &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        // Nest all variables under "distilbert" to line up with the naming used
        // in the pretrained weight files.
        let p = p.borrow() / "distilbert";
        let embeddings = DistilBertEmbedding::new(&p / "embeddings", config);
        let transformer = Transformer::new(p / "transformer", config);
        DistilBertModel {
            embeddings,
            transformer,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `DistilBertTransformerOutput` containing:
    ///   - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = DistilBertConfig::from_file(config_path);
    /// # let distilbert_model: DistilBertModel = DistilBertModel::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     distilbert_model
    ///         .forward_t(Some(&input_tensor), Some(&mask), None, false)
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input: Option<&Tensor>,
        mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DistilBertTransformerOutput, RustBertError> {
        // Embedding lookup may fail (e.g. on an invalid input/input_embeds combination —
        // NOTE(review): validation presumably happens in DistilBertEmbedding::forward_t;
        // confirm in embeddings.rs). Errors are propagated to the caller.
        let input_embeddings = self.embeddings.forward_t(input, input_embeds, train)?;
        let transformer_output = self.transformer.forward_t(&input_embeddings, mask, train);
        Ok(transformer_output)
    }
}
255
/// # DistilBERT for sequence classification
/// Base DistilBERT model with a pre-classifier and classifier heads to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `pre_classifier`: DistilBERT linear layer for classification
/// - `classifier`: DistilBERT linear layer for classification
pub struct DistilBertModelClassifier {
    // Shared DistilBERT encoder
    distil_bert_model: DistilBertModel,
    // Linear layer of size (dim, dim) applied before the final classifier
    pre_classifier: nn::Linear,
    // Final linear layer of size (dim, num_labels)
    classifier: nn::Linear,
    // Dropout applied between pre-classifier and classifier (rate: seq_classif_dropout)
    dropout: Dropout,
}
268
269impl DistilBertModelClassifier {
270 /// Build a new `DistilBertModelClassifier` for sequence classification
271 ///
272 /// # Arguments
273 ///
274 /// * `p` - Variable store path for the root of the DistilBertModelClassifier model
275 /// * `config` - `DistilBertConfig` object defining the model architecture
276 ///
277 /// # Example
278 ///
279 /// ```no_run
280 /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
281 /// use rust_bert::Config;
282 /// use std::path::Path;
283 /// use tch::{nn, Device};
284 ///
285 /// let config_path = Path::new("path/to/config.json");
286 /// let device = Device::Cpu;
287 /// let p = nn::VarStore::new(device);
288 /// let config = DistilBertConfig::from_file(config_path);
289 /// let distil_bert: DistilBertModelClassifier =
290 /// DistilBertModelClassifier::new(&p.root() / "distilbert", &config).unwrap();
291 /// ```
292 pub fn new<'p, P>(
293 p: P,
294 config: &DistilBertConfig,
295 ) -> Result<DistilBertModelClassifier, RustBertError>
296 where
297 P: Borrow<nn::Path<'p>>,
298 {
299 let p = p.borrow();
300
301 let distil_bert_model = DistilBertModel::new(p, config);
302
303 let num_labels = config
304 .id2label
305 .as_ref()
306 .ok_or_else(|| {
307 RustBertError::InvalidConfigurationError(
308 "num_labels not provided in configuration".to_string(),
309 )
310 })?
311 .len() as i64;
312
313 let pre_classifier = nn::linear(
314 p / "pre_classifier",
315 config.dim,
316 config.dim,
317 Default::default(),
318 );
319 let classifier = nn::linear(p / "classifier", config.dim, num_labels, Default::default());
320 let dropout = Dropout::new(config.seq_classif_dropout);
321
322 Ok(DistilBertModelClassifier {
323 distil_bert_model,
324 pre_classifier,
325 classifier,
326 dropout,
327 })
328 }
329
330 /// Forward pass through the model
331 ///
332 /// # Arguments
333 ///
334 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
335 /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
336 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
337 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
338 ///
339 /// # Returns
340 ///
341 /// * `DistilBertSequenceClassificationOutput` containing:
342 /// - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
343 /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
344 /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
345 ///
346 /// # Example
347 ///
348 /// ```no_run
349 /// # use tch::{nn, Device, Tensor, no_grad};
350 /// # use rust_bert::Config;
351 /// # use std::path::Path;
352 /// # use tch::kind::Kind::Int64;
353 /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
354 /// # let config_path = Path::new("path/to/config.json");
355 /// # let vocab_path = Path::new("path/to/vocab.txt");
356 /// # let device = Device::Cpu;
357 /// # let vs = nn::VarStore::new(device);
358 /// # let config = DistilBertConfig::from_file(config_path);
359 /// # let distilbert_model: DistilBertModelClassifier = DistilBertModelClassifier::new(&vs.root(), &config).unwrap();;
360 /// let (batch_size, sequence_length) = (64, 128);
361 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
362 /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
363 ///
364 /// let model_output = no_grad(|| {
365 /// distilbert_model
366 /// .forward_t(Some(&input_tensor),
367 /// Some(&mask),
368 /// None,
369 /// false).unwrap()
370 /// });
371 /// ```
372 pub fn forward_t(
373 &self,
374 input: Option<&Tensor>,
375 mask: Option<&Tensor>,
376 input_embeds: Option<&Tensor>,
377 train: bool,
378 ) -> Result<DistilBertSequenceClassificationOutput, RustBertError> {
379 let base_model_output =
380 self.distil_bert_model
381 .forward_t(input, mask, input_embeds, train)?;
382
383 let logits = base_model_output
384 .hidden_state
385 .select(1, 0)
386 .apply(&self.pre_classifier)
387 .relu()
388 .apply_t(&self.dropout, train)
389 .apply(&self.classifier);
390
391 Ok(DistilBertSequenceClassificationOutput {
392 logits,
393 all_hidden_states: base_model_output.all_hidden_states,
394 all_attentions: base_model_output.all_attentions,
395 })
396 }
397}
398
/// # DistilBERT for masked language model
/// Base DistilBERT model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `vocab_transform`:linear layer for classification of size (*hidden_dim*, *hidden_dim*)
/// - `vocab_layer_norm`: layer normalization
/// - `vocab_projector`: linear layer for classification of size (*hidden_dim*, *vocab_size*) with weights tied to the token embeddings
pub struct DistilBertModelMaskedLM {
    // Shared DistilBERT encoder
    distil_bert_model: DistilBertModel,
    // Linear layer of size (dim, dim) applied before the LM projection
    vocab_transform: nn::Linear,
    // Layer normalization applied after the GELU activation
    vocab_layer_norm: nn::LayerNorm,
    // Projection of size (dim, vocab_size) producing per-token vocabulary logits
    vocab_projector: nn::Linear,
}
412
impl DistilBertModelMaskedLM {
    /// Build a new `DistilBertModelMaskedLM` for masked language modeling
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the DistilBertModelMaskedLM model
    /// * `config` - `DistilBertConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = DistilBertConfig::from_file(config_path);
    /// // The inner DistilBertModel nests its variables under "distilbert" itself,
    /// // so the plain variable-store root is passed here.
    /// let distil_bert = DistilBertModelMaskedLM::new(&p.root(), &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelMaskedLM
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let distil_bert_model = DistilBertModel::new(p, config);
        let vocab_transform = nn::linear(
            p / "vocab_transform",
            config.dim,
            config.dim,
            Default::default(),
        );
        // epsilon of 1e-12 matches the BERT-family layer-norm convention
        let layer_norm_config = nn::LayerNormConfig {
            eps: 1e-12,
            ..Default::default()
        };
        let vocab_layer_norm =
            nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
        let vocab_projector = nn::linear(
            p / "vocab_projector",
            config.dim,
            config.vocab_size,
            Default::default(),
        );

        DistilBertModelMaskedLM {
            distil_bert_model,
            vocab_transform,
            vocab_layer_norm,
            vocab_projector,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `DistilBertMaskedLMOutput` containing:
    ///   - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = DistilBertConfig::from_file(config_path);
    /// # let distilbert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     distilbert_model
    ///         .forward_t(Some(&input_tensor), Some(&mask), None, false)
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input: Option<&Tensor>,
        mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DistilBertMaskedLMOutput, RustBertError> {
        let base_model_output =
            self.distil_bert_model
                .forward_t(input, mask, input_embeds, train)?;

        // LM head: transform -> GELU -> layer norm -> vocabulary projection.
        // Note the activation here is fixed to GELU, independent of config.activation.
        let prediction_scores = base_model_output
            .hidden_state
            .apply(&self.vocab_transform)
            .gelu("none")
            .apply(&self.vocab_layer_norm)
            .apply(&self.vocab_projector);

        Ok(DistilBertMaskedLMOutput {
            prediction_scores,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
534
/// # DistilBERT for question answering
/// Extractive question-answering model based on a DistilBERT language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `qa_outputs`: Linear layer for question answering
pub struct DistilBertForQuestionAnswering {
    // Shared DistilBERT encoder
    distil_bert_model: DistilBertModel,
    // Linear layer of size (dim, 2) producing per-token start/end logits
    qa_outputs: nn::Linear,
    // Dropout applied before the QA output layer (rate: qa_dropout)
    dropout: Dropout,
}
547
impl DistilBertForQuestionAnswering {
    /// Build a new `DistilBertForQuestionAnswering` for extractive question answering
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the DistilBertForQuestionAnswering model
    /// * `config` - `DistilBertConfig` object defining the model architecture
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = DistilBertConfig::from_file(config_path);
    /// // The inner DistilBertModel nests its variables under "distilbert" itself,
    /// // so the plain variable-store root is passed here.
    /// let distil_bert = DistilBertForQuestionAnswering::new(&p.root(), &config);
    /// ```
    pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForQuestionAnswering
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let distil_bert_model = DistilBertModel::new(p, config);
        // 2 outputs per token: start logit and end logit
        let qa_outputs = nn::linear(p / "qa_outputs", config.dim, 2, Default::default());
        let dropout = Dropout::new(config.qa_dropout);

        DistilBertForQuestionAnswering {
            distil_bert_model,
            qa_outputs,
            dropout,
        }
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `DistilBertQuestionAnsweringOutput` containing:
    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = DistilBertConfig::from_file(config_path);
    /// # let distilbert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     distilbert_model
    ///         .forward_t(Some(&input_tensor), Some(&mask), None, false)
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input: Option<&Tensor>,
        mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DistilBertQuestionAnsweringOutput, RustBertError> {
        let base_model_output =
            self.distil_bert_model
                .forward_t(input, mask, input_embeds, train)?;

        // Project every position to 2 logits: (batch, seq, 2)
        let output = base_model_output
            .hidden_state
            .apply_t(&self.dropout, train)
            .apply(&self.qa_outputs);

        // Split the last dimension into start and end logits, then drop the
        // trailing singleton dimension -> two tensors of shape (batch, seq).
        let logits = output.split(1, -1);
        let (start_logits, end_logits) = (&logits[0], &logits[1]);
        let start_logits = start_logits.squeeze_dim(-1);
        let end_logits = end_logits.squeeze_dim(-1);

        Ok(DistilBertQuestionAnsweringOutput {
            start_logits,
            end_logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
657
/// # DistilBERT for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `classifier`: Linear layer for token classification
pub struct DistilBertForTokenClassification {
    // Shared DistilBERT encoder
    distil_bert_model: DistilBertModel,
    // Linear layer of size (dim, num_labels) applied to every token position
    classifier: nn::Linear,
    // Dropout applied before the classifier
    dropout: Dropout,
}
669
impl DistilBertForTokenClassification {
    /// Build a new `DistilBertForTokenClassification` for token-level classification
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the DistilBertForTokenClassification model
    /// * `config` - `DistilBertConfig` object defining the model architecture. `id2label`
    ///   must be provided: its length defines the number of target classes.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let p = nn::VarStore::new(device);
    /// let config = DistilBertConfig::from_file(config_path);
    /// // The inner DistilBertModel nests its variables under "distilbert" itself,
    /// // so the plain variable-store root is passed here.
    /// let distil_bert = DistilBertForTokenClassification::new(&p.root(), &config).unwrap();
    /// ```
    pub fn new<'p, P>(
        p: P,
        config: &DistilBertConfig,
    ) -> Result<DistilBertForTokenClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let distil_bert_model = DistilBertModel::new(p, config);

        // The number of target classes is derived from the id2label mapping.
        let num_labels = config
            .id2label
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "id2label must be provided for classifiers".to_string(),
                )
            })?
            .len() as i64;

        let classifier = nn::linear(p / "classifier", config.dim, num_labels, Default::default());
        // NOTE(review): reuses seq_classif_dropout rather than a dedicated
        // token-classification dropout rate — confirm this is intentional.
        let dropout = Dropout::new(config.seq_classif_dropout);

        Ok(DistilBertForTokenClassification {
            distil_bert_model,
            classifier,
            dropout,
        })
    }

    /// Forward pass through the model
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
    /// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
    ///
    /// # Returns
    ///
    /// * `DistilBertTokenClassificationOutput` containing:
    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///   - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = DistilBertConfig::from_file(config_path);
    /// # let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config).unwrap();
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     distilbert_model
    ///         .forward_t(Some(&input_tensor), Some(&mask), None, false)
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input: Option<&Tensor>,
        mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DistilBertTokenClassificationOutput, RustBertError> {
        let base_model_output =
            self.distil_bert_model
                .forward_t(input, mask, input_embeds, train)?;

        // Per-token classification: dropout then linear head on every position.
        let logits = base_model_output
            .hidden_state
            .apply_t(&self.dropout, train)
            .apply(&self.classifier);

        Ok(DistilBertTokenClassificationOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
787
/// # DistilBERT for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
/// Plain alias of [`DistilBertModel`]: sentence embeddings only require the raw hidden states.
pub type DistilBertForSentenceEmbeddings = DistilBertModel;
791
/// Container for the DistilBERT masked LM model output.
pub struct DistilBertMaskedLMOutput {
    /// Logits for the vocabulary items at each sequence position, shape (*batch size*, *sequence_length*, *vocab_size*)
    pub prediction_scores: Tensor,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}
801
/// Container for the DistilBERT sequence classification model output
pub struct DistilBertSequenceClassificationOutput {
    /// Logits for each input (sequence) for each target class, shape (*batch size*, *num_labels*)
    pub logits: Tensor,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}
811
/// Container for the DistilBERT token classification model output
pub struct DistilBertTokenClassificationOutput {
    /// Logits for each sequence item (token) for each target class, shape (*batch size*, *sequence_length*, *num_labels*)
    pub logits: Tensor,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}
821
/// Container for the DistilBERT question answering model output
pub struct DistilBertQuestionAnsweringOutput {
    /// Logits for the start position of the answer, shape (*batch size*, *sequence_length*)
    pub start_logits: Tensor,
    /// Logits for the end position of the answer, shape (*batch size*, *sequence_length*)
    pub end_logits: Tensor,
    /// Hidden states for all intermediate layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate layers
    pub all_attentions: Option<Vec<Tensor>>,
}