rust_bert/pipelines/sentence_embeddings/
pipeline.rs

use std::borrow::Borrow;
use std::convert::{TryFrom, TryInto};

use rust_tokenizers::tokenizer::TruncationStrategy;
use tch::{nn, Tensor};

use crate::albert::AlbertForSentenceEmbeddings;
use crate::bert::BertForSentenceEmbeddings;
use crate::distilbert::DistilBertForSentenceEmbeddings;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::sentence_embeddings::layers::{Dense, DenseConfig, Pooling, PoolingConfig};
use crate::pipelines::sentence_embeddings::{
    AttentionHead, AttentionLayer, AttentionOutput, Embedding, SentenceEmbeddingsConfig,
    SentenceEmbeddingsModulesConfig, SentenceEmbeddingsSentenceBertConfig,
    SentenceEmbeddingsTokenizerConfig,
};
use crate::roberta::RobertaForSentenceEmbeddings;
use crate::t5::T5ForSentenceEmbeddings;
use crate::{Config, RustBertError};

/// # Abstraction that holds one particular sentence embeddings model, for any of the supported models
pub enum SentenceEmbeddingsOption {
    /// Bert for Sentence Embeddings
    Bert(BertForSentenceEmbeddings),
    /// DistilBert for Sentence Embeddings
    DistilBert(DistilBertForSentenceEmbeddings),
    /// Roberta for Sentence Embeddings
    Roberta(RobertaForSentenceEmbeddings),
    /// Albert for Sentence Embeddings
    Albert(AlbertForSentenceEmbeddings),
    /// T5 for Sentence Embeddings
    T5(T5ForSentenceEmbeddings),
}

impl SentenceEmbeddingsOption {
    /// Instantiate a new sentence embeddings transformer of the supplied type.
    ///
    /// # Arguments
    ///
    /// * `transformer_type` - `ModelType` indicating the transformer model type to load (must match the actual weights to be loaded)
    /// * `p` - `tch::nn::Path` variable store path for the model parameters (the weights themselves, e.g. rust_model.ot, are loaded separately)
    /// * `config` - A configuration (the transformer model type of the configuration must be compatible with the value for `transformer_type`)
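    ///
    /// # Example (sketch)
    ///
    /// A minimal sketch of instantiating a BERT-based variant; the `config.json` path is
    /// hypothetical, and the model weights still have to be loaded into the variable store
    /// afterwards (as done in `SentenceEmbeddingsModel::new_with_tokenizer`):
    ///
    /// ```no_run
    /// use rust_bert::pipelines::common::{ConfigOption, ModelType};
    /// use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsOption;
    /// use std::path::PathBuf;
    /// use tch::{nn, Device};
    ///
    /// fn example() -> Result<(), rust_bert::RustBertError> {
    ///     let var_store = nn::VarStore::new(Device::Cpu);
    ///     let config = ConfigOption::from_file(ModelType::Bert, PathBuf::from("path/to/config.json"));
    ///     let _transformer =
    ///         SentenceEmbeddingsOption::new(ModelType::Bert, var_store.root(), &config)?;
    ///     Ok(())
    /// }
    /// ```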
    pub fn new<'p, P>(
        transformer_type: ModelType,
        p: P,
        config: &ConfigOption,
    ) -> Result<Self, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        use SentenceEmbeddingsOption::*;

        let option = match transformer_type {
            ModelType::Bert => Bert(BertForSentenceEmbeddings::new(p, &(config.try_into()?))),
            ModelType::DistilBert => DistilBert(DistilBertForSentenceEmbeddings::new(
                p,
                &(config.try_into()?),
            )),
            ModelType::Roberta => Roberta(RobertaForSentenceEmbeddings::new_with_optional_pooler(
                p,
                &(config.try_into()?),
                false,
            )),
            ModelType::Albert => Albert(AlbertForSentenceEmbeddings::new(p, &(config.try_into()?))),
            ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))),
            _ => {
                return Err(RustBertError::InvalidConfigurationError(format!(
                    "Unsupported transformer model {transformer_type:?} for Sentence Embeddings"
                )));
            }
        };

        Ok(option)
    }

    /// Interface method to forward() of the particular transformer models.
    pub fn forward(
        &self,
        tokens_ids: &Tensor,
        tokens_masks: &Tensor,
    ) -> Result<(Tensor, Option<Vec<Tensor>>), RustBertError> {
        match self {
            Self::Bert(transformer) => transformer
                .forward_t(
                    Some(tokens_ids),
                    Some(tokens_masks),
                    None,
                    None,
                    None,
                    None,
                    None,
                    false,
                )
                .map(|transformer_output| {
                    (
                        transformer_output.hidden_state,
                        transformer_output.all_attentions,
                    )
                }),
            Self::DistilBert(transformer) => transformer
                .forward_t(Some(tokens_ids), Some(tokens_masks), None, false)
                .map(|transformer_output| {
                    (
                        transformer_output.hidden_state,
                        transformer_output.all_attentions,
                    )
                }),
            Self::Roberta(transformer) => transformer
                .forward_t(
                    Some(tokens_ids),
                    Some(tokens_masks),
                    None,
                    None,
                    None,
                    None,
                    None,
                    false,
                )
                .map(|transformer_output| {
                    (
                        transformer_output.hidden_state,
                        transformer_output.all_attentions,
                    )
                }),
            Self::Albert(transformer) => transformer
                .forward_t(
                    Some(tokens_ids),
                    Some(tokens_masks),
                    None,
                    None,
                    None,
                    false,
                )
                .map(|transformer_output| {
                    (
                        transformer_output.hidden_state,
                        transformer_output.all_attentions.map(|attentions| {
                            attentions
                                .into_iter()
                                .map(|tensors| {
                                    let num_inner_groups = tensors.len() as f64;
                                    tensors.into_iter().sum::<Tensor>() / num_inner_groups
                                })
                                .collect()
                        }),
                    )
                }),
            Self::T5(transformer) => transformer.forward(tokens_ids, tokens_masks),
        }
    }
}

/// # SentenceEmbeddingsModel to perform sentence embeddings
///
/// It is made of the following blocks:
/// - `transformer`: Base transformer model
/// - `pooling`: Pooling layer
/// - `dense` _(optional)_: Linear (feed forward) layer
/// - `normalization` _(optional)_: Embeddings normalization
pub struct SentenceEmbeddingsModel {
    sentence_bert_config: SentenceEmbeddingsSentenceBertConfig,
    tokenizer: TokenizerOption,
    tokenizer_truncation_strategy: TruncationStrategy,
    var_store: nn::VarStore,
    transformer: SentenceEmbeddingsOption,
    transformer_config: ConfigOption,
    pooling_layer: Pooling,
    dense_layer: Option<Dense>,
    normalize_embeddings: bool,
    embeddings_dim: i64,
}

impl SentenceEmbeddingsModel {
    /// Build a new `SentenceEmbeddingsModel`
    ///
    /// # Arguments
    ///
    /// * `config` - `SentenceEmbeddingsConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
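    ///
    /// # Example (sketch)
    ///
    /// A minimal usage sketch, assuming the remote resources for one of the bundled
    /// sentence embeddings models can be used to build the configuration (the
    /// `SentenceEmbeddingsModelType` conversion shown here is an assumption of this
    /// sketch; a fully local `SentenceEmbeddingsConfig` works the same way):
    ///
    /// ```no_run
    /// use rust_bert::pipelines::sentence_embeddings::{
    ///     SentenceEmbeddingsConfig, SentenceEmbeddingsModel, SentenceEmbeddingsModelType,
    /// };
    ///
    /// fn example() -> Result<(), rust_bert::RustBertError> {
    ///     let config = SentenceEmbeddingsConfig::from(SentenceEmbeddingsModelType::AllMiniLmL12V2);
    ///     let model = SentenceEmbeddingsModel::new(config)?;
    ///     let embeddings = model.encode(&["This is a sentence.", "This is another one."])?;
    ///     println!("{} embeddings of dimension {}", embeddings.len(), model.get_embedding_dim()?);
    ///     Ok(())
    /// }
    /// ```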
    pub fn new(config: SentenceEmbeddingsConfig) -> Result<Self, RustBertError> {
        let transformer_type = config.transformer_type;
        let tokenizer_vocab_resource = &config.tokenizer_vocab_resource;
        let tokenizer_merges_resource = &config.tokenizer_merges_resource;
        let tokenizer_config_resource = &config.tokenizer_config_resource;
        let sentence_bert_config_resource = &config.sentence_bert_config_resource;
        let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
            tokenizer_config_resource.get_local_path()?,
        );
        let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
            sentence_bert_config_resource.get_local_path()?,
        );

        let tokenizer = TokenizerOption::from_file(
            transformer_type,
            tokenizer_vocab_resource
                .get_local_path()?
                .to_string_lossy()
                .as_ref(),
            tokenizer_merges_resource
                .as_ref()
                .map(|resource| resource.get_local_path())
                .transpose()?
                .map(|path| path.to_string_lossy().into_owned())
                .as_deref(),
            tokenizer_config
                .do_lower_case
                .unwrap_or(sentence_bert_config.do_lower_case),
            tokenizer_config.strip_accents,
            tokenizer_config.add_prefix_space,
        )?;

        Self::new_with_tokenizer(config, tokenizer)
    }

    /// Build a new `SentenceEmbeddingsModel` from a `SentenceEmbeddingsConfig` and `TokenizerOption`.
    ///
    /// A tokenizer must be provided by the user and can be customized to use non-default settings.
    ///
    /// # Arguments
    ///
    /// * `config` - `SentenceEmbeddingsConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` tokenizer to use for sentence embeddings.
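    ///
    /// # Example (sketch)
    ///
    /// A minimal sketch with a caller-provided tokenizer; the vocabulary path is
    /// hypothetical, and the arguments mirror the tokenizer setup performed by
    /// `SentenceEmbeddingsModel::new`:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    /// use rust_bert::pipelines::sentence_embeddings::{
    ///     SentenceEmbeddingsConfig, SentenceEmbeddingsModel, SentenceEmbeddingsModelType,
    /// };
    ///
    /// fn example() -> Result<(), rust_bert::RustBertError> {
    ///     let config = SentenceEmbeddingsConfig::from(SentenceEmbeddingsModelType::AllMiniLmL12V2);
    ///     let tokenizer = TokenizerOption::from_file(
    ///         ModelType::Bert,
    ///         "path/to/vocab.txt",
    ///         None, // merges resource (only needed for BPE-based tokenizers)
    ///         true, // lower-case inputs
    ///         None, // strip accents (tokenizer default)
    ///         None, // add prefix space (tokenizer default)
    ///     )?;
    ///     let model = SentenceEmbeddingsModel::new_with_tokenizer(config, tokenizer)?;
    ///     let _embeddings = model.encode(&["This is a sentence."])?;
    ///     Ok(())
    /// }
    /// ```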
    pub fn new_with_tokenizer(
        config: SentenceEmbeddingsConfig,
        tokenizer: TokenizerOption,
    ) -> Result<Self, RustBertError> {
        let SentenceEmbeddingsConfig {
            modules_config_resource,
            sentence_bert_config_resource,
            tokenizer_config_resource: _,
            tokenizer_vocab_resource: _,
            tokenizer_merges_resource: _,
            transformer_type,
            transformer_config_resource,
            transformer_weights_resource,
            pooling_config_resource,
            dense_config_resource,
            dense_weights_resource,
            device,
            kind,
        } = config;

        let modules =
            SentenceEmbeddingsModulesConfig::from_file(modules_config_resource.get_local_path()?)
                .validate()?;

        let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
            sentence_bert_config_resource.get_local_path()?,
        );

        // Setup transformer
        let mut var_store = nn::VarStore::new(device);
        let transformer_config = ConfigOption::from_file(
            transformer_type,
            transformer_config_resource.get_local_path()?,
        );
        let transformer =
            SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
        crate::resources::load_weights(
            &transformer_weights_resource,
            &mut var_store,
            kind,
            device,
        )?;

        // Setup pooling layer
        let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);
        let mut embeddings_dim = pooling_config.word_embedding_dimension;
        let pooling_layer = Pooling::new(pooling_config);

        // Setup dense layer
        let dense_layer = if modules.dense_module().is_some() {
            let dense_config =
                DenseConfig::from_file(dense_config_resource.unwrap().get_local_path()?);
            embeddings_dim = dense_config.out_features;
            Some(Dense::new(
                dense_config,
                dense_weights_resource.unwrap().get_local_path()?,
                device,
            )?)
        } else {
            None
        };

        let normalize_embeddings = modules.has_normalization();

        Ok(Self {
            tokenizer,
            sentence_bert_config,
            tokenizer_truncation_strategy: TruncationStrategy::LongestFirst,
            var_store,
            transformer,
            transformer_config,
            pooling_layer,
            dense_layer,
            normalize_embeddings,
            embeddings_dim,
        })
    }

    /// Get a reference to the model tokenizer.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }

    /// Get a mutable reference to the model tokenizer.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }

    /// Sets the tokenizer's truncation strategy
    pub fn set_tokenizer_truncation(&mut self, truncation_strategy: TruncationStrategy) {
        self.tokenizer_truncation_strategy = truncation_strategy;
    }

    /// Return the embedding output dimension
    pub fn get_embedding_dim(&self) -> Result<i64, RustBertError> {
        Ok(self.embeddings_dim)
    }

    /// Tokenizes the inputs
    pub fn tokenize<S>(&self, inputs: &[S]) -> SentenceEmbeddingsTokenizerOutput
    where
        S: AsRef<str> + Send + Sync,
    {
        let tokenized_input = self.tokenizer.encode_list(
            inputs,
            self.sentence_bert_config.max_seq_length,
            &self.tokenizer_truncation_strategy,
            0,
        );

        let max_len = tokenized_input
            .iter()
            .map(|input| input.token_ids.len())
            .max()
            .unwrap_or(0);

        let pad_token_id = self.tokenizer.get_pad_id().unwrap_or(0);
        let tokens_ids = tokenized_input
            .into_iter()
            .map(|input| {
                let mut token_ids = input.token_ids;
                token_ids.extend(vec![pad_token_id; max_len - token_ids.len()]);
                token_ids
            })
            .collect::<Vec<_>>();

        let tokens_masks = tokens_ids
            .iter()
            .map(|input| {
                Tensor::from_slice(
                    &input
                        .iter()
                        .map(|&e| i64::from(e != pad_token_id))
                        .collect::<Vec<_>>(),
                )
            })
            .collect::<Vec<_>>();

        let tokens_ids = tokens_ids
            .into_iter()
            .map(|input| Tensor::from_slice(&input))
            .collect::<Vec<_>>();

        SentenceEmbeddingsTokenizerOutput {
            tokens_ids,
            tokens_masks,
        }
    }

    /// Computes sentence embeddings, outputs `Tensor`.
    pub fn encode_as_tensor<S>(
        &self,
        inputs: &[S],
    ) -> Result<SentenceEmbeddingsModelOutput, RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
        let SentenceEmbeddingsTokenizerOutput {
            tokens_ids,
            tokens_masks,
        } = self.tokenize(inputs);
        if tokens_ids.is_empty() {
            return Err(RustBertError::ValueError(
                "No n-gram found in the document. \
                Try allowing smaller n-gram sizes or relax stopword/forbidden characters criteria."
                    .to_string(),
            ));
        }
        let tokens_ids = Tensor::stack(&tokens_ids, 0).to(self.var_store.device());
        let tokens_masks = Tensor::stack(&tokens_masks, 0).to(self.var_store.device());

        let (tokens_embeddings, all_attentions) =
            tch::no_grad(|| self.transformer.forward(&tokens_ids, &tokens_masks))?;

        let mean_pool =
            tch::no_grad(|| self.pooling_layer.forward(tokens_embeddings, &tokens_masks));
        let maybe_linear = if let Some(dense_layer) = &self.dense_layer {
            tch::no_grad(|| dense_layer.forward(&mean_pool))
        } else {
            mean_pool
        };
        let maybe_normalized = if self.normalize_embeddings {
            let norm = &maybe_linear
                .norm_scalaropt_dim(2, [1], true)
                .clamp_min(1e-12)
                .expand_as(&maybe_linear);
            maybe_linear / norm
        } else {
            maybe_linear
        };

        Ok(SentenceEmbeddingsModelOutput {
            embeddings: maybe_normalized,
            all_attentions,
        })
    }

    /// Computes sentence embeddings.
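    ///
    /// # Example (sketch)
    ///
    /// Embeddings are returned as plain vectors of floats, so similarity measures can be
    /// computed on them directly. A minimal cosine-similarity sketch, assuming a `model`
    /// built as in [`SentenceEmbeddingsModel::new`]:
    ///
    /// ```no_run
    /// # use rust_bert::pipelines::sentence_embeddings::{
    /// #     SentenceEmbeddingsConfig, SentenceEmbeddingsModel, SentenceEmbeddingsModelType,
    /// # };
    /// # fn main() -> Result<(), rust_bert::RustBertError> {
    /// # let model = SentenceEmbeddingsModel::new(SentenceEmbeddingsConfig::from(
    /// #     SentenceEmbeddingsModelType::AllMiniLmL12V2,
    /// # ))?;
    /// let embeddings = model.encode(&["A cat sits on the mat.", "A feline rests on a rug."])?;
    /// let (a, b) = (&embeddings[0], &embeddings[1]);
    /// // Cosine similarity between the two sentence embeddings.
    /// let dot = a.iter().zip(b.iter()).fold(0.0, |acc, (x, y)| acc + x * y);
    /// let norm_a = a.iter().fold(0.0, |acc, x| acc + x * x).sqrt();
    /// let norm_b = b.iter().fold(0.0, |acc, x| acc + x * x).sqrt();
    /// let cosine = dot / (norm_a * norm_b);
    /// println!("cosine similarity: {cosine}");
    /// # Ok(())
    /// # }
    /// ```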
    pub fn encode<S>(&self, inputs: &[S]) -> Result<Vec<Embedding>, RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
        let SentenceEmbeddingsModelOutput { embeddings, .. } = self.encode_as_tensor(inputs)?;
        Ok(Vec::try_from(embeddings)?)
    }

    fn nb_layers(&self) -> usize {
        use SentenceEmbeddingsOption::*;
        match (&self.transformer, &self.transformer_config) {
            (Bert(_), ConfigOption::Bert(conf)) => conf.num_hidden_layers as usize,
            (Bert(_), _) => unreachable!(),
            (DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_layers as usize,
            (DistilBert(_), _) => unreachable!(),
            (Roberta(_), ConfigOption::Roberta(conf)) => conf.num_hidden_layers as usize,
            (Roberta(_), ConfigOption::Bert(conf)) => conf.num_hidden_layers as usize,
            (Roberta(_), _) => unreachable!(),
            (Albert(_), ConfigOption::Albert(conf)) => conf.num_hidden_layers as usize,
            (Albert(_), _) => unreachable!(),
            (T5(_), ConfigOption::T5(conf)) => conf.num_layers as usize,
            (T5(_), _) => unreachable!(),
        }
    }

    fn nb_heads(&self) -> usize {
        use SentenceEmbeddingsOption::*;
        match (&self.transformer, &self.transformer_config) {
            (Bert(_), ConfigOption::Bert(conf)) => conf.num_attention_heads as usize,
            (Bert(_), _) => unreachable!(),
            (DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_heads as usize,
            (DistilBert(_), _) => unreachable!(),
            (Roberta(_), ConfigOption::Roberta(conf)) => conf.num_attention_heads as usize,
            (Roberta(_), ConfigOption::Bert(conf)) => conf.num_attention_heads as usize,
            (Roberta(_), _) => unreachable!(),
            (Albert(_), ConfigOption::Albert(conf)) => conf.num_attention_heads as usize,
            (Albert(_), _) => unreachable!(),
            (T5(_), ConfigOption::T5(conf)) => conf.num_heads as usize,
            (T5(_), _) => unreachable!(),
        }
    }

    /// Computes sentence embeddings, also outputs `AttentionOutput`s.
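    ///
    /// # Example (sketch)
    ///
    /// A minimal sketch, assuming a `model` built as in [`SentenceEmbeddingsModel::new`];
    /// one `AttentionOutput` (one `AttentionLayer` per transformer layer, one
    /// `AttentionHead` per head) is returned for each input sentence:
    ///
    /// ```no_run
    /// # use rust_bert::pipelines::sentence_embeddings::{
    /// #     SentenceEmbeddingsConfig, SentenceEmbeddingsModel, SentenceEmbeddingsModelType,
    /// # };
    /// # fn main() -> Result<(), rust_bert::RustBertError> {
    /// # let model = SentenceEmbeddingsModel::new(SentenceEmbeddingsConfig::from(
    /// #     SentenceEmbeddingsModelType::AllMiniLmL12V2,
    /// # ))?;
    /// let (embeddings, attentions) = model.encode_with_attention(&["This is a sentence."])?;
    /// assert_eq!(embeddings.len(), attentions.len());
    /// println!(
    ///     "layers: {}, heads per layer: {}",
    ///     attentions[0].len(),
    ///     attentions[0][0].len()
    /// );
    /// # Ok(())
    /// # }
    /// ```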
    pub fn encode_with_attention<S>(
        &self,
        inputs: &[S],
    ) -> Result<(Vec<Embedding>, Vec<AttentionOutput>), RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
        let SentenceEmbeddingsModelOutput {
            embeddings,
            all_attentions,
        } = self.encode_as_tensor(inputs)?;

        let embeddings = Vec::try_from(embeddings)?;
        let all_attentions = all_attentions.ok_or_else(|| {
            RustBertError::InvalidConfigurationError("No attention outputted".into())
        })?;

        let attention_outputs = (0..inputs.len() as i64)
            .map(|i| {
                let mut attention_output = AttentionOutput::with_capacity(self.nb_layers());
                for layer in all_attentions.iter() {
                    let mut attention_layer = AttentionLayer::with_capacity(self.nb_heads());
                    for head in 0..self.nb_heads() {
                        let attention_slice = layer
                            .slice(0, i, i + 1, 1)
                            .slice(1, head as i64, head as i64 + 1, 1)
                            .squeeze();
                        let attention_head = AttentionHead::try_from(attention_slice).unwrap();
                        attention_layer.push(attention_head);
                    }
                    attention_output.push(attention_layer);
                }
                attention_output
            })
            .collect::<Vec<AttentionOutput>>();

        Ok((embeddings, attention_outputs))
    }
}

/// Container for the SentenceEmbeddings tokenizer output.
pub struct SentenceEmbeddingsTokenizerOutput {
    pub tokens_ids: Vec<Tensor>,
    pub tokens_masks: Vec<Tensor>,
}

/// Container for the SentenceEmbeddings model output.
pub struct SentenceEmbeddingsModelOutput {
    pub embeddings: Tensor,
    pub all_attentions: Option<Vec<Tensor>>,
}