rust_bert/models/bert/embeddings.rs
1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3// Copyright 2019 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::bert::bert_model::BertConfig;
15use crate::common::dropout::Dropout;
16use crate::common::embeddings::process_ids_embeddings_pair;
17use crate::RustBertError;
18use std::borrow::Borrow;
19use tch::nn::{embedding, EmbeddingConfig};
20use tch::{nn, Kind, Tensor};
21
22/// # BertEmbedding trait (for use in BertModel or RoBERTaModel)
23/// Defines an interface for the embedding layers in BERT-based models
24pub trait BertEmbedding {
25 fn new<'p, P>(p: P, config: &BertConfig) -> Self
26 where
27 P: Borrow<nn::Path<'p>>;
28
29 fn forward_t(
30 &self,
31 input_ids: Option<&Tensor>,
32 token_type_ids: Option<&Tensor>,
33 position_ids: Option<&Tensor>,
34 input_embeds: Option<&Tensor>,
35 train: bool,
36 ) -> Result<Tensor, RustBertError>;
37}
38
39#[derive(Debug)]
40/// # BertEmbeddings implementation for BERT model
41/// Implementation of the `BertEmbedding` trait for BERT models
42pub struct BertEmbeddings {
43 word_embeddings: nn::Embedding,
44 position_embeddings: nn::Embedding,
45 token_type_embeddings: nn::Embedding,
46 layer_norm: nn::LayerNorm,
47 dropout: Dropout,
48}
49
50impl BertEmbedding for BertEmbeddings {
51 /// Build a new `BertEmbeddings`
52 ///
53 /// # Arguments
54 ///
55 /// * `p` - Variable store path for the root of the BertEmbeddings model
56 /// * `config` - `BertConfig` object defining the model architecture and vocab/hidden size
57 ///
58 /// # Example
59 ///
60 /// ```no_run
61 /// use rust_bert::bert::{BertConfig, BertEmbedding, BertEmbeddings};
62 /// use rust_bert::Config;
63 /// use std::path::Path;
64 /// use tch::{nn, Device};
65 ///
66 /// let config_path = Path::new("path/to/config.json");
67 /// let device = Device::Cpu;
68 /// let p = nn::VarStore::new(device);
69 /// let config = BertConfig::from_file(config_path);
70 /// let bert_embeddings = BertEmbeddings::new(&p.root() / "bert_embeddings", &config);
71 /// ```
72 fn new<'p, P>(p: P, config: &BertConfig) -> BertEmbeddings
73 where
74 P: Borrow<nn::Path<'p>>,
75 {
76 let p = p.borrow();
77
78 let embedding_config = EmbeddingConfig {
79 padding_idx: 0,
80 ..Default::default()
81 };
82
83 let word_embeddings: nn::Embedding = embedding(
84 p / "word_embeddings",
85 config.vocab_size,
86 config.hidden_size,
87 embedding_config,
88 );
89
90 let position_embeddings: nn::Embedding = embedding(
91 p / "position_embeddings",
92 config.max_position_embeddings,
93 config.hidden_size,
94 Default::default(),
95 );
96
97 let token_type_embeddings: nn::Embedding = embedding(
98 p / "token_type_embeddings",
99 config.type_vocab_size,
100 config.hidden_size,
101 Default::default(),
102 );
103
104 let layer_norm_config = nn::LayerNormConfig {
105 eps: 1e-12,
106 ..Default::default()
107 };
108 let layer_norm: nn::LayerNorm =
109 nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
110 let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
111 BertEmbeddings {
112 word_embeddings,
113 position_embeddings,
114 token_type_embeddings,
115 layer_norm,
116 dropout,
117 }
118 }
119
120 /// Forward pass through the embedding layer
121 ///
122 /// # Arguments
123 ///
124 /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
125 /// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
126 /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
127 /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
128 /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
129 ///
130 /// # Returns
131 ///
132 /// * `embedded_output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
133 ///
134 /// # Example
135 ///
136 /// ```no_run
137 /// # use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding};
138 /// # use tch::{nn, Device, Tensor, no_grad};
139 /// # use rust_bert::Config;
140 /// # use std::path::Path;
141 /// # use tch::kind::Kind::Int64;
142 /// # let config_path = Path::new("path/to/config.json");
143 /// # let vocab_path = Path::new("path/to/vocab.txt");
144 /// # let device = Device::Cpu;
145 /// # let vs = nn::VarStore::new(device);
146 /// # let config = BertConfig::from_file(config_path);
147 /// # let bert_embeddings = BertEmbeddings::new(&vs.root(), &config);
148 /// let (batch_size, sequence_length) = (64, 128);
149 /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
150 /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
151 /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
152 /// .expand(&[batch_size, sequence_length], true);
153 ///
154 /// let embedded_output = no_grad(|| {
155 /// bert_embeddings
156 /// .forward_t(
157 /// Some(&input_tensor),
158 /// Some(&token_type_ids),
159 /// Some(&position_ids),
160 /// None,
161 /// false,
162 /// )
163 /// .unwrap()
164 /// });
165 /// ```
166 fn forward_t(
167 &self,
168 input_ids: Option<&Tensor>,
169 token_type_ids: Option<&Tensor>,
170 position_ids: Option<&Tensor>,
171 input_embeds: Option<&Tensor>,
172 train: bool,
173 ) -> Result<Tensor, RustBertError> {
174 let (calc_input_embeddings, input_shape, _) =
175 process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
176
177 let input_embeddings =
178 input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
179 let seq_length = input_embeddings.size()[1];
180
181 let calc_position_ids = if position_ids.is_none() {
182 Some(
183 Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
184 .unsqueeze(0)
185 .expand(&input_shape, true),
186 )
187 } else {
188 None
189 };
190
191 let calc_token_type_ids = if token_type_ids.is_none() {
192 Some(Tensor::zeros(
193 &input_shape,
194 (Kind::Int64, input_embeddings.device()),
195 ))
196 } else {
197 None
198 };
199
200 let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
201 let token_type_ids =
202 token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
203
204 let position_embeddings = position_ids.apply(&self.position_embeddings);
205 let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
206
207 let input_embeddings: Tensor =
208 input_embeddings + position_embeddings + token_type_embeddings;
209 Ok(input_embeddings
210 .apply(&self.layer_norm)
211 .apply_t(&self.dropout, train))
212 }
213}