rust_bert/models/roberta/embeddings.rs

// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

14use crate::bert::{BertConfig, BertEmbedding};
15use crate::common::dropout::Dropout;
16use crate::common::embeddings::process_ids_embeddings_pair;
17use crate::RustBertError;
18use std::borrow::Borrow;
19use tch::nn::{embedding, EmbeddingConfig};
20use tch::{nn, Kind, Tensor};
21
22#[derive(Debug)]
23/// # BertEmbeddings implementation for RoBERTa model
24/// Implementation of the `BertEmbedding` trait for RoBERTa models
25pub struct RobertaEmbeddings {
26    word_embeddings: nn::Embedding,
27    position_embeddings: nn::Embedding,
28    token_type_embeddings: nn::Embedding,
29    layer_norm: nn::LayerNorm,
30    dropout: Dropout,
31    padding_index: i64,
32}
33
34impl RobertaEmbeddings {
35    fn create_position_ids_from_input_ids(&self, x: &Tensor) -> Tensor {
36        let mask: Tensor = x.ne(self.padding_index).to_kind(Kind::Int64);
37        mask.cumsum(1, Kind::Int64) * mask + self.padding_index
38    }
39
40    fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
41        let input_shape = x.size();
42        let input_shape = vec![input_shape[0], input_shape[1]];
43        let position_ids: Tensor = Tensor::arange_start(
44            self.padding_index + 1,
45            input_shape[0],
46            (Kind::Int64, x.device()),
47        );
48        position_ids.unsqueeze(0).expand(&input_shape, true)
49    }
50}
51
52impl BertEmbedding for RobertaEmbeddings {
53    /// Build a new `RobertaEmbeddings`
54    ///
55    /// # Arguments
56    ///
57    /// * `p` - Variable store path for the root of the BertEmbeddings model
58    /// * `config` - `BertConfig` object defining the model architecture and vocab/hidden size
59    ///
60    /// # Example
61    ///
62    /// ```no_run
63    /// use rust_bert::bert::{BertConfig, BertEmbedding};
64    /// use rust_bert::roberta::RobertaEmbeddings;
65    /// use rust_bert::Config;
66    /// use std::path::Path;
67    /// use tch::{nn, Device};
68    ///
69    /// let config_path = Path::new("path/to/config.json");
70    /// let device = Device::Cpu;
71    /// let p = nn::VarStore::new(device);
72    /// let config = BertConfig::from_file(config_path);
73    /// let robert_embeddings = RobertaEmbeddings::new(&p.root() / "bert_embeddings", &config);
74    /// ```
75    fn new<'p, P>(p: P, config: &BertConfig) -> RobertaEmbeddings
76    where
77        P: Borrow<nn::Path<'p>>,
78    {
79        let p = p.borrow();
80
81        let embedding_config = EmbeddingConfig {
82            padding_idx: 1,
83            ..Default::default()
84        };
85
86        let word_embeddings: nn::Embedding = embedding(
87            p / "word_embeddings",
88            config.vocab_size,
89            config.hidden_size,
90            embedding_config,
91        );
92
93        let position_embeddings: nn::Embedding = embedding(
94            p / "position_embeddings",
95            config.max_position_embeddings,
96            config.hidden_size,
97            Default::default(),
98        );
99
100        let token_type_embeddings: nn::Embedding = embedding(
101            p / "token_type_embeddings",
102            config.type_vocab_size,
103            config.hidden_size,
104            Default::default(),
105        );
106
107        let layer_norm_config = nn::LayerNormConfig {
108            eps: 1e-12,
109            ..Default::default()
110        };
111        let layer_norm: nn::LayerNorm =
112            nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
113        let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
114        RobertaEmbeddings {
115            word_embeddings,
116            position_embeddings,
117            token_type_embeddings,
118            layer_norm,
119            dropout,
120            padding_index: 1,
121        }
122    }
123
124    /// Forward pass through the embedding layer.
125    /// This differs from the original BERT embeddings in how the position ids are calculated when not provided.
126    ///
127    /// # Arguments
128    ///
129    /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
130    /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
131    /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
132    /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
133    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
134    ///
135    /// # Returns
136    ///
137    /// * `embedded_output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
138    ///
139    /// # Example
140    ///
141    /// ```no_run
142    /// # use rust_bert::bert::{BertConfig, BertEmbedding};
143    /// # use tch::{nn, Device, Tensor, no_grad};
144    /// # use rust_bert::Config;
145    /// # use std::path::Path;
146    /// # use tch::kind::Kind::Int64;
147    /// use rust_bert::roberta::RobertaEmbeddings;
148    /// # let config_path = Path::new("path/to/config.json");
149    /// # let vocab_path = Path::new("path/to/vocab.txt");
150    /// # let device = Device::Cpu;
151    /// # let vs = nn::VarStore::new(device);
152    /// # let config = BertConfig::from_file(config_path);
153    /// # let roberta_embeddings = RobertaEmbeddings::new(&vs.root(), &config);
154    /// let (batch_size, sequence_length) = (64, 128);
155    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
156    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
157    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
158    ///     .expand(&[batch_size, sequence_length], true);
159    ///
160    /// let embedded_output = no_grad(|| {
161    ///     roberta_embeddings
162    ///         .forward_t(
163    ///             Some(&input_tensor),
164    ///             Some(&token_type_ids),
165    ///             Some(&position_ids),
166    ///             None,
167    ///             false,
168    ///         )
169    ///         .unwrap()
170    /// });
171    /// ```
172    fn forward_t(
173        &self,
174        input_ids: Option<&Tensor>,
175        token_type_ids: Option<&Tensor>,
176        position_ids: Option<&Tensor>,
177        input_embeds: Option<&Tensor>,
178        train: bool,
179    ) -> Result<Tensor, RustBertError> {
180        let (calc_input_embeddings, input_shape, _) =
181            process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
182
183        let input_embeddings =
184            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
185
186        let calc_position_ids = if position_ids.is_none() {
187            Some(match input_ids {
188                Some(value) => self.create_position_ids_from_input_ids(value),
189                None => self.create_position_ids_from_embeddings(input_embeds.unwrap()),
190            })
191        } else {
192            None
193        };
194
195        let calc_token_type_ids = if token_type_ids.is_none() {
196            Some(Tensor::zeros(
197                input_shape,
198                (Kind::Int64, input_embeddings.device()),
199            ))
200        } else {
201            None
202        };
203
204        let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
205        let token_type_ids =
206            token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
207
208        let position_embeddings = position_ids.apply(&self.position_embeddings);
209        let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
210
211        let input_embeddings: Tensor =
212            input_embeddings + position_embeddings + token_type_embeddings;
213        Ok(input_embeddings
214            .apply(&self.layer_norm)
215            .apply_t(&self.dropout, train))
216    }
217}