use std::borrow::Borrow;
use std::convert::{TryFrom, TryInto};

use rust_tokenizers::tokenizer::TruncationStrategy;
use tch::{nn, Tensor};

use crate::albert::AlbertForSentenceEmbeddings;
use crate::bert::BertForSentenceEmbeddings;
use crate::distilbert::DistilBertForSentenceEmbeddings;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::sentence_embeddings::layers::{Dense, DenseConfig, Pooling, PoolingConfig};
use crate::pipelines::sentence_embeddings::{
    AttentionHead, AttentionLayer, AttentionOutput, Embedding, SentenceEmbeddingsConfig,
    SentenceEmbeddingsModulesConfig, SentenceEmbeddingsSentenceBertConfig,
    SentenceEmbeddingsTokenizerConfig,
};
use crate::roberta::RobertaForSentenceEmbeddings;
use crate::t5::T5ForSentenceEmbeddings;
use crate::{Config, RustBertError};

/// Abstraction over the transformer backbones supported by the sentence embeddings pipeline.
pub enum SentenceEmbeddingsOption {
    /// BERT model for sentence embeddings
    Bert(BertForSentenceEmbeddings),
    /// DistilBERT model for sentence embeddings
    DistilBert(DistilBertForSentenceEmbeddings),
    /// RoBERTa model for sentence embeddings
    Roberta(RobertaForSentenceEmbeddings),
    /// ALBERT model for sentence embeddings
    Albert(AlbertForSentenceEmbeddings),
    /// T5 model for sentence embeddings
    T5(T5ForSentenceEmbeddings),
}

impl SentenceEmbeddingsOption {
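    /// Instantiates the selected transformer backbone under the given variable store path.
    ///
    /// # Example
    ///
    /// A minimal construction sketch. The config path is illustrative, and it assumes
    /// `SentenceEmbeddingsOption` is re-exported from the sentence embeddings pipeline module:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::common::{ConfigOption, ModelType};
    /// use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsOption;
    /// use tch::{nn, Device};
    ///
    /// fn main() -> Result<(), rust_bert::RustBertError> {
    ///     let var_store = nn::VarStore::new(Device::Cpu);
    ///     let config = ConfigOption::from_file(ModelType::Bert, "path/to/config.json");
    ///     let _transformer =
    ///         SentenceEmbeddingsOption::new(ModelType::Bert, var_store.root(), &config)?;
    ///     Ok(())
    /// }
    /// ```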
    pub fn new<'p, P>(
        transformer_type: ModelType,
        p: P,
        config: &ConfigOption,
    ) -> Result<Self, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        use SentenceEmbeddingsOption::*;

        let option = match transformer_type {
            ModelType::Bert => Bert(BertForSentenceEmbeddings::new(p, &(config.try_into()?))),
            ModelType::DistilBert => DistilBert(DistilBertForSentenceEmbeddings::new(
                p,
                &(config.try_into()?),
            )),
            ModelType::Roberta => Roberta(RobertaForSentenceEmbeddings::new_with_optional_pooler(
                p,
                &(config.try_into()?),
                false,
            )),
            ModelType::Albert => Albert(AlbertForSentenceEmbeddings::new(p, &(config.try_into()?))),
            ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))),
            _ => {
                return Err(RustBertError::InvalidConfigurationError(format!(
                    "Unsupported transformer model {transformer_type:?} for Sentence Embeddings"
                )));
            }
        };

        Ok(option)
    }
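
    /// Forward pass through the backbone, returning the last hidden state and, when the
    /// underlying model exposes them, the per-layer attention weights. `tokens_ids` and
    /// `tokens_masks` are expected to share the shape (batch size, sequence length).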
    pub fn forward(
        &self,
        tokens_ids: &Tensor,
        tokens_masks: &Tensor,
    ) -> Result<(Tensor, Option<Vec<Tensor>>), RustBertError> {
        match self {
            Self::Bert(transformer) => transformer
                .forward_t(
                    Some(tokens_ids),
                    Some(tokens_masks),
                    None,
                    None,
                    None,
                    None,
                    None,
                    false,
                )
                .map(|transformer_output| {
                    (
                        transformer_output.hidden_state,
                        transformer_output.all_attentions,
                    )
                }),
            Self::DistilBert(transformer) => transformer
                .forward_t(Some(tokens_ids), Some(tokens_masks), None, false)
                .map(|transformer_output| {
                    (
                        transformer_output.hidden_state,
                        transformer_output.all_attentions,
                    )
                }),
            Self::Roberta(transformer) => transformer
                .forward_t(
                    Some(tokens_ids),
                    Some(tokens_masks),
                    None,
                    None,
                    None,
                    None,
                    None,
                    false,
                )
                .map(|transformer_output| {
                    (
                        transformer_output.hidden_state,
                        transformer_output.all_attentions,
                    )
                }),
            Self::Albert(transformer) => transformer
                .forward_t(
                    Some(tokens_ids),
                    Some(tokens_masks),
                    None,
                    None,
                    None,
                    false,
                )
                .map(|transformer_output| {
                    (
                        transformer_output.hidden_state,
                        // ALBERT reports attentions once per inner group; average them so
                        // that each layer contributes a single attention tensor, matching
                        // the other backbones.
                        transformer_output.all_attentions.map(|attentions| {
                            attentions
                                .into_iter()
                                .map(|tensors| {
                                    let num_inner_groups = tensors.len() as f64;
                                    tensors.into_iter().sum::<Tensor>() / num_inner_groups
                                })
                                .collect()
                        }),
                    )
                }),
            Self::T5(transformer) => transformer.forward(tokens_ids, tokens_masks),
        }
    }
}

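/// Ready-to-use sentence embeddings pipeline: transformer backbone, pooling,
/// optional dense projection and optional L2 normalization.
///
/// # Example
///
/// A minimal usage sketch, assuming the `SentenceEmbeddingsBuilder` helper and the
/// `SentenceEmbeddingsModelType` enum provided elsewhere in this pipeline module:
///
/// ```no_run
/// use rust_bert::pipelines::sentence_embeddings::{
///     SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
/// };
///
/// fn main() -> Result<(), rust_bert::RustBertError> {
///     let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2)
///         .create_model()?;
///     let embeddings = model.encode(&["This is a sentence", "Each sentence becomes a vector"])?;
///     assert_eq!(embeddings.len(), 2);
///     Ok(())
/// }
/// ```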
pub struct SentenceEmbeddingsModel {
    sentence_bert_config: SentenceEmbeddingsSentenceBertConfig,
    tokenizer: TokenizerOption,
    tokenizer_truncation_strategy: TruncationStrategy,
    var_store: nn::VarStore,
    transformer: SentenceEmbeddingsOption,
    transformer_config: ConfigOption,
    pooling_layer: Pooling,
    dense_layer: Option<Dense>,
    normalize_embeddings: bool,
    embeddings_dim: i64,
}

impl SentenceEmbeddingsModel {
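    /// Loads the tokenizer from the resources referenced by `config`, then delegates to
    /// [`Self::new_with_tokenizer`].
    ///
    /// # Example
    ///
    /// A construction sketch, assuming `SentenceEmbeddingsConfig` can be built from a
    /// `SentenceEmbeddingsModelType` (illustrative; any fully populated config works):
    ///
    /// ```no_run
    /// use rust_bert::pipelines::sentence_embeddings::{
    ///     SentenceEmbeddingsConfig, SentenceEmbeddingsModel, SentenceEmbeddingsModelType,
    /// };
    ///
    /// fn main() -> Result<(), rust_bert::RustBertError> {
    ///     let config = SentenceEmbeddingsConfig::from(SentenceEmbeddingsModelType::AllMiniLmL12V2);
    ///     let _model = SentenceEmbeddingsModel::new(config)?;
    ///     Ok(())
    /// }
    /// ```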
    pub fn new(config: SentenceEmbeddingsConfig) -> Result<Self, RustBertError> {
        let transformer_type = config.transformer_type;
        let tokenizer_vocab_resource = &config.tokenizer_vocab_resource;
        let tokenizer_merges_resource = &config.tokenizer_merges_resource;
        let tokenizer_config_resource = &config.tokenizer_config_resource;
        let sentence_bert_config_resource = &config.sentence_bert_config_resource;

        // Load the tokenizer and sentence-BERT configurations from their local paths.
        let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
            tokenizer_config_resource.get_local_path()?,
        );
        let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
            sentence_bert_config_resource.get_local_path()?,
        );

        // Build the tokenizer from the vocabulary (and optional merges) resources, preferring
        // casing options from the tokenizer config over the sentence-BERT defaults.
        let tokenizer = TokenizerOption::from_file(
            transformer_type,
            tokenizer_vocab_resource
                .get_local_path()?
                .to_string_lossy()
                .as_ref(),
            tokenizer_merges_resource
                .as_ref()
                .map(|resource| resource.get_local_path())
                .transpose()?
                .map(|path| path.to_string_lossy().into_owned())
                .as_deref(),
            tokenizer_config
                .do_lower_case
                .unwrap_or(sentence_bert_config.do_lower_case),
            tokenizer_config.strip_accents,
            tokenizer_config.add_prefix_space,
        )?;

        Self::new_with_tokenizer(config, tokenizer)
    }

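    /// Builds the full pipeline (transformer, pooling, optional dense projection and
    /// normalization) from `config`, using an already constructed `tokenizer`.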
    pub fn new_with_tokenizer(
        config: SentenceEmbeddingsConfig,
        tokenizer: TokenizerOption,
    ) -> Result<Self, RustBertError> {
        let SentenceEmbeddingsConfig {
            modules_config_resource,
            sentence_bert_config_resource,
            tokenizer_config_resource: _,
            tokenizer_vocab_resource: _,
            tokenizer_merges_resource: _,
            transformer_type,
            transformer_config_resource,
            transformer_weights_resource,
            pooling_config_resource,
            dense_config_resource,
            dense_weights_resource,
            device,
            kind,
        } = config;

        let modules =
            SentenceEmbeddingsModulesConfig::from_file(modules_config_resource.get_local_path()?)
                .validate()?;

        let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
            sentence_bert_config_resource.get_local_path()?,
        );

        // Instantiate the transformer backbone and load its pretrained weights.
        let mut var_store = nn::VarStore::new(device);
        let transformer_config = ConfigOption::from_file(
            transformer_type,
            transformer_config_resource.get_local_path()?,
        );
        let transformer =
            SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
        crate::resources::load_weights(
            &transformer_weights_resource,
            &mut var_store,
            kind,
            device,
        )?;

        // The pooling layer defines the default embedding dimension; a dense module,
        // if declared, overrides it with its own output dimension.
        let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);
        let mut embeddings_dim = pooling_config.word_embedding_dimension;
        let pooling_layer = Pooling::new(pooling_config);

        let dense_layer = if modules.dense_module().is_some() {
            // The modules config declares a dense layer; its resources are expected to be set.
            let dense_config =
                DenseConfig::from_file(dense_config_resource.unwrap().get_local_path()?);
            embeddings_dim = dense_config.out_features;
            Some(Dense::new(
                dense_config,
                dense_weights_resource.unwrap().get_local_path()?,
                device,
            )?)
        } else {
            None
        };

        let normalize_embeddings = modules.has_normalization();

        Ok(Self {
            tokenizer,
            sentence_bert_config,
            tokenizer_truncation_strategy: TruncationStrategy::LongestFirst,
            var_store,
            transformer,
            transformer_config,
            pooling_layer,
            dense_layer,
            normalize_embeddings,
            embeddings_dim,
        })
    }

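    /// Returns a shared reference to the pipeline tokenizer.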
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }

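    /// Returns a mutable reference to the pipeline tokenizer.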
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }

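    /// Sets the truncation strategy applied when inputs exceed the maximum sequence
    /// length (defaults to `TruncationStrategy::LongestFirst`).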
    pub fn set_tokenizer_truncation(&mut self, truncation_strategy: TruncationStrategy) {
        self.tokenizer_truncation_strategy = truncation_strategy;
    }

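    /// Returns the dimension of the produced embeddings (the dense layer's output size
    /// when present, otherwise the pooling layer's word embedding dimension).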
    pub fn get_embedding_dim(&self) -> Result<i64, RustBertError> {
        Ok(self.embeddings_dim)
    }

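    /// Tokenizes the inputs, padding every sequence to the length of the longest one,
    /// and builds the matching attention masks (1 for real tokens, 0 for padding).
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a model built with the `SentenceEmbeddingsBuilder` helper:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::sentence_embeddings::{
    ///     SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
    /// };
    ///
    /// fn main() -> Result<(), rust_bert::RustBertError> {
    ///     let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2)
    ///         .create_model()?;
    ///     let output = model.tokenize(&["a short sentence", "a somewhat longer sentence"]);
    ///     // One ids tensor and one mask tensor per input sentence.
    ///     assert_eq!(output.tokens_ids.len(), 2);
    ///     assert_eq!(output.tokens_masks.len(), 2);
    ///     Ok(())
    /// }
    /// ```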
    pub fn tokenize<S>(&self, inputs: &[S]) -> SentenceEmbeddingsTokenizerOutput
    where
        S: AsRef<str> + Send + Sync,
    {
        let tokenized_input = self.tokenizer.encode_list(
            inputs,
            self.sentence_bert_config.max_seq_length,
            &self.tokenizer_truncation_strategy,
            0,
        );

        let max_len = tokenized_input
            .iter()
            .map(|input| input.token_ids.len())
            .max()
            .unwrap_or(0);

        // Right-pad every sequence to the batch maximum length.
        let pad_token_id = self.tokenizer.get_pad_id().unwrap_or(0);
        let tokens_ids = tokenized_input
            .into_iter()
            .map(|input| {
                let mut token_ids = input.token_ids;
                token_ids.extend(vec![pad_token_id; max_len - token_ids.len()]);
                token_ids
            })
            .collect::<Vec<_>>();

        // Attention masks: 1 for real tokens, 0 for padding.
        let tokens_masks = tokens_ids
            .iter()
            .map(|input| {
                Tensor::from_slice(
                    &input
                        .iter()
                        .map(|&e| i64::from(e != pad_token_id))
                        .collect::<Vec<_>>(),
                )
            })
            .collect::<Vec<_>>();

        let tokens_ids = tokens_ids
            .into_iter()
            .map(|input| Tensor::from_slice(&input))
            .collect::<Vec<_>>();

        SentenceEmbeddingsTokenizerOutput {
            tokens_ids,
            tokens_masks,
        }
    }

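    /// Tokenizes the inputs and runs the full pipeline (transformer, pooling, optional
    /// dense projection and normalization) under `tch::no_grad`, returning the embeddings
    /// as a single tensor of shape (batch size, embedding dimension) together with the
    /// attention weights when the backbone exposes them.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a model built with the `SentenceEmbeddingsBuilder` helper:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::sentence_embeddings::{
    ///     SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
    /// };
    ///
    /// fn main() -> Result<(), rust_bert::RustBertError> {
    ///     let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2)
    ///         .create_model()?;
    ///     let output = model.encode_as_tensor(&["hello world"])?;
    ///     // One row per input sentence.
    ///     assert_eq!(output.embeddings.size()[0], 1);
    ///     Ok(())
    /// }
    /// ```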
    pub fn encode_as_tensor<S>(
        &self,
        inputs: &[S],
    ) -> Result<SentenceEmbeddingsModelOutput, RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
        let SentenceEmbeddingsTokenizerOutput {
            tokens_ids,
            tokens_masks,
        } = self.tokenize(inputs);
        if tokens_ids.is_empty() {
            return Err(RustBertError::ValueError(
                "Tokenization produced no output: the input list is empty or all inputs \
                 were filtered out (e.g. no n-gram found for keyword extraction)."
                    .to_string(),
            ));
        }
        let tokens_ids = Tensor::stack(&tokens_ids, 0).to(self.var_store.device());
        let tokens_masks = Tensor::stack(&tokens_masks, 0).to(self.var_store.device());

        let (tokens_embeddings, all_attentions) =
            tch::no_grad(|| self.transformer.forward(&tokens_ids, &tokens_masks))?;

        // Pool token embeddings into one vector per sentence.
        let mean_pool =
            tch::no_grad(|| self.pooling_layer.forward(tokens_embeddings, &tokens_masks));
        // Apply the optional dense projection.
        let maybe_linear = if let Some(dense_layer) = &self.dense_layer {
            tch::no_grad(|| dense_layer.forward(&mean_pool))
        } else {
            mean_pool
        };
        // Optionally L2-normalize the embeddings; the clamp guards against division by zero.
        let maybe_normalized = if self.normalize_embeddings {
            let norm = &maybe_linear
                .norm_scalaropt_dim(2, [1], true)
                .clamp_min(1e-12)
                .expand_as(&maybe_linear);
            maybe_linear / norm
        } else {
            maybe_linear
        };

        Ok(SentenceEmbeddingsModelOutput {
            embeddings: maybe_normalized,
            all_attentions,
        })
    }

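    /// Encodes the inputs and returns the embeddings as plain vectors.
    ///
    /// # Example
    ///
    /// A sketch computing cosine similarity from the returned vectors, assuming a model
    /// built with the `SentenceEmbeddingsBuilder` helper and `Embedding` being a `Vec<f32>`:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::sentence_embeddings::{
    ///     SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
    /// };
    ///
    /// fn main() -> Result<(), rust_bert::RustBertError> {
    ///     let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2)
    ///         .create_model()?;
    ///     let embeddings = model.encode(&["a cat sat on the mat", "a feline rested on the rug"])?;
    ///     let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    ///     // This model family L2-normalizes its outputs, so the dot product is the cosine similarity.
    ///     let cosine = dot(&embeddings[0], &embeddings[1]);
    ///     println!("cosine similarity: {cosine}");
    ///     Ok(())
    /// }
    /// ```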
    pub fn encode<S>(&self, inputs: &[S]) -> Result<Vec<Embedding>, RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
        let SentenceEmbeddingsModelOutput { embeddings, .. } = self.encode_as_tensor(inputs)?;
        Ok(Vec::try_from(embeddings)?)
    }

    /// Number of layers in the transformer backbone, read from its configuration.
    fn nb_layers(&self) -> usize {
        use SentenceEmbeddingsOption::*;
        match (&self.transformer, &self.transformer_config) {
            (Bert(_), ConfigOption::Bert(conf)) => conf.num_hidden_layers as usize,
            (Bert(_), _) => unreachable!(),
            (DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_layers as usize,
            (DistilBert(_), _) => unreachable!(),
            (Roberta(_), ConfigOption::Roberta(conf)) => conf.num_hidden_layers as usize,
            (Roberta(_), _) => unreachable!(),
            (Albert(_), ConfigOption::Albert(conf)) => conf.num_hidden_layers as usize,
            (Albert(_), _) => unreachable!(),
            (T5(_), ConfigOption::T5(conf)) => conf.num_layers as usize,
            (T5(_), _) => unreachable!(),
        }
    }

    /// Number of attention heads in the transformer backbone, read from its configuration.
    fn nb_heads(&self) -> usize {
        use SentenceEmbeddingsOption::*;
        match (&self.transformer, &self.transformer_config) {
            (Bert(_), ConfigOption::Bert(conf)) => conf.num_attention_heads as usize,
            (Bert(_), _) => unreachable!(),
            (DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_heads as usize,
            (DistilBert(_), _) => unreachable!(),
            (Roberta(_), ConfigOption::Roberta(conf)) => conf.num_attention_heads as usize,
            (Roberta(_), _) => unreachable!(),
            (Albert(_), ConfigOption::Albert(conf)) => conf.num_attention_heads as usize,
            (Albert(_), _) => unreachable!(),
            (T5(_), ConfigOption::T5(conf)) => conf.num_heads as usize,
            (T5(_), _) => unreachable!(),
        }
    }

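    /// Encodes the inputs and additionally returns the attention weights, organized per
    /// input as layers of per-head attention matrices.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a model built with the `SentenceEmbeddingsBuilder` helper
    /// and a backbone configured to output attentions:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::sentence_embeddings::{
    ///     SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
    /// };
    ///
    /// fn main() -> Result<(), rust_bert::RustBertError> {
    ///     let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2)
    ///         .create_model()?;
    ///     let (embeddings, attentions) = model.encode_with_attention(&["hello world"])?;
    ///     assert_eq!(embeddings.len(), 1);
    ///     // One AttentionOutput per input, holding one AttentionLayer per transformer layer.
    ///     assert_eq!(attentions.len(), 1);
    ///     Ok(())
    /// }
    /// ```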
    pub fn encode_with_attention<S>(
        &self,
        inputs: &[S],
    ) -> Result<(Vec<Embedding>, Vec<AttentionOutput>), RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
        let SentenceEmbeddingsModelOutput {
            embeddings,
            all_attentions,
        } = self.encode_as_tensor(inputs)?;

        let embeddings = Vec::try_from(embeddings)?;
        let all_attentions = all_attentions.ok_or_else(|| {
            RustBertError::InvalidConfigurationError(
                "No attention weights were returned by the transformer".into(),
            )
        })?;

        // Re-organize the stacked attention tensors into per-input, per-layer, per-head
        // attention matrices.
        let attention_outputs = (0..inputs.len() as i64)
            .map(|i| {
                let mut attention_output = AttentionOutput::with_capacity(self.nb_layers());
                for layer in all_attentions.iter() {
                    let mut attention_layer = AttentionLayer::with_capacity(self.nb_heads());
                    for head in 0..self.nb_heads() {
                        let attention_slice = layer
                            .slice(0, i, i + 1, 1)
                            .slice(1, head as i64, head as i64 + 1, 1)
                            .squeeze();
                        let attention_head = AttentionHead::try_from(attention_slice).unwrap();
                        attention_layer.push(attention_head);
                    }
                    attention_output.push(attention_layer);
                }
                attention_output
            })
            .collect::<Vec<AttentionOutput>>();

        Ok((embeddings, attention_outputs))
    }
}

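/// Tokenizer output for the sentence embeddings pipeline: one ids tensor and one
/// attention-mask tensor per input sentence, all padded to the same length.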
pub struct SentenceEmbeddingsTokenizerOutput {
    pub tokens_ids: Vec<Tensor>,
    pub tokens_masks: Vec<Tensor>,
}

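/// Model output: embeddings of shape (batch size, embedding dimension) and, when
/// available, the attention weights of the transformer backbone.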
pub struct SentenceEmbeddingsModelOutput {
    pub embeddings: Tensor,
    pub all_attentions: Option<Vec<Tensor>>,
}