// rust_bert/pipelines/zero_shot_classification.rs
1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright 2019-2020 Guillaume Becquin
3// Copyright 2020 Maarten van Gompel
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! # Zero-shot classification pipeline
15//! Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference.
//! The default model is a BART model fine-tuned on MNLI. From a list of input sequences to classify and a list of target labels,
17//! single-class or multi-label classification is performed, translating the classification task to an inference task.
18//! The default template for translation to inference task is `This example is about {}.`. This template can be updated to a more specific
//! value that may better match the use case, for example `This review is about a {product_class}`.
20//!
21//! - `predict` performs single-class classification (one and exactly one label must be true for each provided input)
22//! - `predict_multilabel` performs multi-label classification (zero, one or more labels may be true for each provided input)
23//!
24//! ```no_run
25//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
26//! # fn main() -> anyhow::Result<()> {
27//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
28//! let input_sentence = "Who are you voting for in 2020?";
29//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
30//! let candidate_labels = &["politics", "public health", "economics", "sports"];
31//! let output = sequence_classification_model.predict_multilabel(
32//! &[input_sentence, input_sequence_2],
33//! candidate_labels,
34//! None,
35//! 128,
36//! );
37//! # Ok(())
38//! # }
39//! ```
40//!
41//! outputs:
42//! ```no_run
43//! # use rust_bert::pipelines::sequence_classification::Label;
44//! let output = [
45//! [
46//! Label {
47//! text: "politics".to_string(),
48//! score: 0.972,
49//! id: 0,
50//! sentence: 0,
51//! },
52//! Label {
53//! text: "public health".to_string(),
54//! score: 0.032,
55//! id: 1,
56//! sentence: 0,
57//! },
58//! Label {
//! text: "economics".to_string(),
60//! score: 0.006,
61//! id: 2,
62//! sentence: 0,
63//! },
64//! Label {
65//! text: "sports".to_string(),
66//! score: 0.004,
67//! id: 3,
68//! sentence: 0,
69//! },
70//! ],
71//! [
72//! Label {
73//! text: "politics".to_string(),
74//! score: 0.943,
75//! id: 0,
76//! sentence: 1,
77//! },
78//! Label {
//! text: "economics".to_string(),
80//! score: 0.985,
81//! id: 2,
82//! sentence: 1,
83//! },
84//! Label {
85//! text: "public health".to_string(),
86//! score: 0.0818,
87//! id: 1,
88//! sentence: 1,
89//! },
90//! Label {
91//! text: "sports".to_string(),
92//! score: 0.001,
93//! id: 3,
94//! sentence: 1,
95//! },
96//! ],
97//! ]
98//! .to_vec();
99//! ```
100
101use crate::albert::AlbertForSequenceClassification;
102use crate::bart::BartForSequenceClassification;
103use crate::bert::BertForSequenceClassification;
104use crate::deberta::DebertaForSequenceClassification;
105use crate::deberta_v2::DebertaV2ForSequenceClassification;
106use crate::distilbert::DistilBertModelClassifier;
107use crate::longformer::LongformerForSequenceClassification;
108use crate::mobilebert::MobileBertForSequenceClassification;
109use crate::pipelines::common::{
110 cast_var_store, ConfigOption, ModelResource, ModelType, TokenizerOption,
111};
112use crate::pipelines::sequence_classification::Label;
113use crate::resources::ResourceProvider;
114use crate::roberta::RobertaForSequenceClassification;
115use crate::xlnet::XLNetForSequenceClassification;
116use crate::RustBertError;
117use rust_tokenizers::tokenizer::TruncationStrategy;
118use rust_tokenizers::TokenizedInput;
119
120#[cfg(feature = "onnx")]
121use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
122#[cfg(feature = "remote")]
123use crate::{
124 bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
125 resources::RemoteResource,
126};
127use tch::kind::Kind::{Bool, Float};
128use tch::nn::VarStore;
129use tch::{no_grad, Device, Kind, Tensor};
130
/// # Configuration for ZeroShotClassificationModel
/// Contains information regarding the model to load and device to place the model on.
pub struct ZeroShotClassificationConfig {
    /// Model type
    pub model_type: ModelType,
    /// Model weights resource (default: BART model fine-tuned on MNLI)
    pub model_resource: ModelResource,
    /// Config resource (default: BART model fine-tuned on MNLI)
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocab resource (default: BART model fine-tuned on MNLI)
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource (default: BART MNLI merges; only needed for BPE-based tokenizers)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Automatically lower case all input upon tokenization (assumes a lower-cased model)
    pub lower_case: bool,
    /// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
    pub strip_accents: Option<bool>,
    /// Flag indicating if the tokenizer should add a white space before each tokenized input (needed for some Roberta models)
    pub add_prefix_space: Option<bool>,
    /// Device to place the model on (default: CUDA/GPU when available)
    pub device: Device,
    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
    pub kind: Option<Kind>,
}
155
156impl ZeroShotClassificationConfig {
157 /// Instantiate a new zero shot classification configuration of the supplied type.
158 ///
159 /// # Arguments
160 ///
161 /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
162 /// * model - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
163 /// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
164 /// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
165 /// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
166 /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
167 pub fn new<RC, RV>(
168 model_type: ModelType,
169 model_resource: ModelResource,
170 config_resource: RC,
171 vocab_resource: RV,
172 merges_resource: Option<RV>,
173 lower_case: bool,
174 strip_accents: impl Into<Option<bool>>,
175 add_prefix_space: impl Into<Option<bool>>,
176 ) -> ZeroShotClassificationConfig
177 where
178 RC: ResourceProvider + Send + 'static,
179 RV: ResourceProvider + Send + 'static,
180 {
181 ZeroShotClassificationConfig {
182 model_type,
183 model_resource,
184 config_resource: Box::new(config_resource),
185 vocab_resource: Box::new(vocab_resource),
186 merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
187 lower_case,
188 strip_accents: strip_accents.into(),
189 add_prefix_space: add_prefix_space.into(),
190 device: Device::cuda_if_available(),
191 kind: None,
192 }
193 }
194}
195
#[cfg(feature = "remote")]
impl Default for ZeroShotClassificationConfig {
    /// Provides a default zero-shot classification model (English)
    fn default() -> ZeroShotClassificationConfig {
        // All default resources point to a BART model fine-tuned on MNLI.
        let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            BartModelResources::BART_MNLI,
        )));
        let config_resource: Box<dyn ResourceProvider + Send> = Box::new(
            RemoteResource::from_pretrained(BartConfigResources::BART_MNLI),
        );
        let vocab_resource: Box<dyn ResourceProvider + Send> = Box::new(
            RemoteResource::from_pretrained(BartVocabResources::BART_MNLI),
        );
        let merges_resource: Option<Box<dyn ResourceProvider + Send>> = Some(Box::new(
            RemoteResource::from_pretrained(BartMergesResources::BART_MNLI),
        ));
        ZeroShotClassificationConfig {
            model_type: ModelType::Bart,
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            lower_case: false,
            strip_accents: None,
            add_prefix_space: None,
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}
222
/// # Abstraction that holds one particular zero shot classification model, for any of the supported models
/// The models are using a classification architecture that should be trained on Natural Language Inference.
/// The models should output a Tensor of size > 2 in the label dimension, with the first logit corresponding
/// to contradiction and the last logit corresponding to entailment.
#[allow(clippy::large_enum_variant)]
pub enum ZeroShotClassificationOption {
    /// Bart for Sequence Classification
    Bart(BartForSequenceClassification),
    /// DeBERTa for Sequence Classification
    Deberta(DebertaForSequenceClassification),
    /// DeBERTaV2 for Sequence Classification
    DebertaV2(DebertaV2ForSequenceClassification),
    /// Bert for Sequence Classification
    Bert(BertForSequenceClassification),
    /// DistilBert for Sequence Classification
    DistilBert(DistilBertModelClassifier),
    /// MobileBert for Sequence Classification
    MobileBert(MobileBertForSequenceClassification),
    /// Roberta for Sequence Classification
    Roberta(RobertaForSequenceClassification),
    /// XLMRoberta for Sequence Classification (shares the Roberta classification head implementation)
    XLMRoberta(RobertaForSequenceClassification),
    /// Albert for Sequence Classification
    Albert(AlbertForSequenceClassification),
    /// XLNet for Sequence Classification
    XLNet(XLNetForSequenceClassification),
    /// Longformer for Sequence Classification
    Longformer(LongformerForSequenceClassification),
    /// ONNX model for Sequence Classification
    #[cfg(feature = "onnx")]
    ONNX(ONNXEncoder),
}
255
256impl ZeroShotClassificationOption {
257 /// Instantiate a new zer-shot classification model of the supplied type.
258 ///
259 /// # Arguments
260 ///
261 /// * `ZeroShotClassificationConfig` - Zero-shot classification pipeline configuration. The type of model created will be inferred from the
262 /// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
263 pub fn new(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
264 match config.model_resource {
265 ModelResource::Torch(_) => Self::new_torch(config),
266 #[cfg(feature = "onnx")]
267 ModelResource::ONNX(_) => Self::new_onnx(config),
268 }
269 }
270
271 fn new_torch(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
272 let device = config.device;
273 let weights_path = config.model_resource.get_torch_local_path()?;
274 let mut var_store = VarStore::new(device);
275 let model_config =
276 &ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
277 let model_type = config.model_type;
278 let model = match model_type {
279 ModelType::Bart => {
280 if let ConfigOption::Bart(config) = model_config {
281 Ok(Self::Bart(
282 BartForSequenceClassification::new(var_store.root(), config)?,
283 ))
284 } else {
285 Err(RustBertError::InvalidConfigurationError(
286 "You can only supply a BartConfig for Bart!".to_string(),
287 ))
288 }
289 }
290 ModelType::Deberta => {
291 if let ConfigOption::Deberta(config) = model_config {
292 Ok(Self::Deberta(
293 DebertaForSequenceClassification::new(var_store.root(), config)?,
294 ))
295 } else {
296 Err(RustBertError::InvalidConfigurationError(
297 "You can only supply a DebertaConfig for DeBERTa!".to_string(),
298 ))
299 }
300 }
301 ModelType::DebertaV2 => {
302 if let ConfigOption::DebertaV2(config) = model_config {
303 Ok(Self::DebertaV2(
304 DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
305 ))
306 } else {
307 Err(RustBertError::InvalidConfigurationError(
308 "You can only supply a DebertaConfig for DeBERTaV2!".to_string(),
309 ))
310 }
311 }
312 ModelType::Bert => {
313 if let ConfigOption::Bert(config) = model_config {
314 Ok(Self::Bert(
315 BertForSequenceClassification::new(var_store.root(), config)?,
316 ))
317 } else {
318 Err(RustBertError::InvalidConfigurationError(
319 "You can only supply a BertConfig for Bert!".to_string(),
320 ))
321 }
322 }
323 ModelType::DistilBert => {
324 if let ConfigOption::DistilBert(config) = model_config {
325 Ok(Self::DistilBert(
326 DistilBertModelClassifier::new(var_store.root(), config)?,
327 ))
328 } else {
329 Err(RustBertError::InvalidConfigurationError(
330 "You can only supply a DistilBertConfig for DistilBert!".to_string(),
331 ))
332 }
333 }
334 ModelType::MobileBert => {
335 if let ConfigOption::MobileBert(config) = model_config {
336 Ok(Self::MobileBert(
337 MobileBertForSequenceClassification::new(var_store.root(), config)?,
338 ))
339 } else {
340 Err(RustBertError::InvalidConfigurationError(
341 "You can only supply a MobileBertConfig for MobileBert!".to_string(),
342 ))
343 }
344 }
345 ModelType::Roberta => {
346 if let ConfigOption::Roberta(config) = model_config {
347 Ok(Self::Roberta(
348 RobertaForSequenceClassification::new(var_store.root(), config)?,
349 ))
350 } else {
351 Err(RustBertError::InvalidConfigurationError(
352 "You can only supply a RobertaConfig for Roberta!".to_string(),
353 ))
354 }
355 }
356 ModelType::XLMRoberta => {
357 if let ConfigOption::Bert(config) = model_config {
358 Ok(Self::XLMRoberta(
359 RobertaForSequenceClassification::new(var_store.root(), config)?,
360 ))
361 } else {
362 Err(RustBertError::InvalidConfigurationError(
363 "You can only supply a BertConfig for Roberta!".to_string(),
364 ))
365 }
366 }
367 ModelType::Albert => {
368 if let ConfigOption::Albert(config) = model_config {
369 Ok(Self::Albert(
370 AlbertForSequenceClassification::new(var_store.root(), config)?,
371 ))
372 } else {
373 Err(RustBertError::InvalidConfigurationError(
374 "You can only supply an AlbertConfig for Albert!".to_string(),
375 ))
376 }
377 }
378 ModelType::XLNet => {
379 if let ConfigOption::XLNet(config) = model_config {
380 Ok(Self::XLNet(
381 XLNetForSequenceClassification::new(var_store.root(), config)?,
382 ))
383 } else {
384 Err(RustBertError::InvalidConfigurationError(
385 "You can only supply an AlbertConfig for Albert!".to_string(),
386 ))
387 }
388 }
389 ModelType::Longformer => {
390 if let ConfigOption::Longformer(config) = model_config {
391 Ok(Self::Longformer(
392 LongformerForSequenceClassification::new(var_store.root(), config)?,
393 ))
394 } else {
395 Err(RustBertError::InvalidConfigurationError(
396 "You can only supply a LongformerConfig for Longformer!".to_string(),
397 ))
398 }
399 }
400 #[cfg(feature = "onnx")]
401 ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
402 "A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
403 )),
404 _ => Err(RustBertError::InvalidConfigurationError(format!(
405 "Zero shot classification not implemented for {model_type:?}!",
406 ))),
407 }?;
408 var_store.load(weights_path)?;
409 cast_var_store(&mut var_store, config.kind, device);
410 Ok(model)
411 }
412
413 #[cfg(feature = "onnx")]
414 pub fn new_onnx(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
415 let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
416 let environment = onnx_config.get_environment()?;
417 let encoder_file = config
418 .model_resource
419 .get_onnx_local_paths()?
420 .encoder_path
421 .ok_or(RustBertError::InvalidConfigurationError(
422 "An encoder file must be provided for zero-shot classification ONNX models."
423 .to_string(),
424 ))?;
425
426 Ok(Self::ONNX(ONNXEncoder::new(
427 encoder_file,
428 &environment,
429 &onnx_config,
430 )?))
431 }
432
433 /// Returns the `ModelType` for this SequenceClassificationOption
434 pub fn model_type(&self) -> ModelType {
435 match *self {
436 Self::Bart(_) => ModelType::Bart,
437 Self::Deberta(_) => ModelType::Deberta,
438 Self::DebertaV2(_) => ModelType::DebertaV2,
439 Self::Bert(_) => ModelType::Bert,
440 Self::Roberta(_) => ModelType::Roberta,
441 Self::XLMRoberta(_) => ModelType::Roberta,
442 Self::DistilBert(_) => ModelType::DistilBert,
443 Self::MobileBert(_) => ModelType::MobileBert,
444 Self::Albert(_) => ModelType::Albert,
445 Self::XLNet(_) => ModelType::XLNet,
446 Self::Longformer(_) => ModelType::Longformer,
447 #[cfg(feature = "onnx")]
448 Self::ONNX(_) => ModelType::ONNX,
449 }
450 }
451
452 /// Interface method to forward_t() of the particular models.
453 pub fn forward_t(
454 &self,
455 input_ids: Option<&Tensor>,
456 mask: Option<&Tensor>,
457 token_type_ids: Option<&Tensor>,
458 position_ids: Option<&Tensor>,
459 input_embeds: Option<&Tensor>,
460 train: bool,
461 ) -> Tensor {
462 match *self {
463 Self::Bart(ref model) => {
464 model
465 .forward_t(
466 input_ids.expect("`input_ids` must be provided for BART models"),
467 mask,
468 None,
469 None,
470 None,
471 train,
472 )
473 .decoder_output
474 }
475 Self::Bert(ref model) => {
476 model
477 .forward_t(
478 input_ids,
479 mask,
480 token_type_ids,
481 position_ids,
482 input_embeds,
483 train,
484 )
485 .logits
486 }
487 Self::Deberta(ref model) => {
488 model
489 .forward_t(
490 input_ids,
491 mask,
492 token_type_ids,
493 position_ids,
494 input_embeds,
495 train,
496 )
497 .expect("Error in DeBERTa forward_t")
498 .logits
499 }
500 Self::DebertaV2(ref model) => {
501 model
502 .forward_t(
503 input_ids,
504 mask,
505 token_type_ids,
506 position_ids,
507 input_embeds,
508 train,
509 )
510 .expect("Error in DeBERTaV2 forward_t")
511 .logits
512 }
513 Self::DistilBert(ref model) => {
514 model
515 .forward_t(input_ids, mask, input_embeds, train)
516 .expect("Error in distilbert forward_t")
517 .logits
518 }
519 Self::MobileBert(ref model) => {
520 model
521 .forward_t(input_ids, None, None, input_embeds, mask, train)
522 .expect("Error in mobilebert forward_t")
523 .logits
524 }
525 Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
526 model
527 .forward_t(
528 input_ids,
529 mask,
530 token_type_ids,
531 position_ids,
532 input_embeds,
533 train,
534 )
535 .logits
536 }
537 Self::Albert(ref model) => {
538 model
539 .forward_t(
540 input_ids,
541 mask,
542 token_type_ids,
543 position_ids,
544 input_embeds,
545 train,
546 )
547 .logits
548 }
549 Self::XLNet(ref model) => {
550 model
551 .forward_t(
552 input_ids,
553 mask,
554 None,
555 None,
556 None,
557 token_type_ids,
558 input_embeds,
559 train,
560 )
561 .logits
562 }
563 Self::Longformer(ref model) => {
564 model
565 .forward_t(
566 input_ids,
567 mask,
568 None,
569 token_type_ids,
570 position_ids,
571 input_embeds,
572 train,
573 )
574 .expect("Error in Longformer forward pass.")
575 .logits
576 }
577 #[cfg(feature = "onnx")]
578 Self::ONNX(ref model) => model
579 .forward(
580 input_ids,
581 mask.map(|tensor| tensor.to_kind(Kind::Int64)).as_ref(),
582 token_type_ids,
583 position_ids,
584 input_embeds,
585 )
586 .expect("Error in ONNX forward pass.")
587 .logits
588 .unwrap(),
589 }
590 }
591}
592
/// Template used to transform the zero-shot classification labels into a set of
/// natural language hypotheses for natural language inference.
///
/// For example, transform `[positive, negative]` into
/// `[This is a positive review, This is a negative review]`
///
/// The function should take a `&str` as an input and return the formatted String.
///
/// This transformation has a strong impact on the resulting classification accuracy.
/// If no function is provided for zero-shot classification, the default templating
/// function will be used:
///
/// ```rust
/// fn default_template(label: &str) -> String {
///     format!("This example is about {}.", label)
/// }
/// ```
// Note: these doc comments previously appeared *after* the alias, which made
// rustdoc attach them to the following struct instead of the alias itself.
pub type ZeroShotTemplate = Box<dyn Fn(&str) -> String>;

/// # ZeroShotClassificationModel for Zero Shot Classification
pub struct ZeroShotClassificationModel {
    // Tokenizer used to encode the (input, hypothesis) sequence pairs
    tokenizer: TokenizerOption,
    // NLI-style sequence classification model performing the zero-shot inference
    zero_shot_classifier: ZeroShotClassificationOption,
    // Device the input tensors are moved to before the forward pass
    device: Device,
}
618
impl ZeroShotClassificationModel {
    /// Build a new `ZeroShotClassificationModel`
    ///
    /// # Arguments
    ///
    /// * `config` - `ZeroShotClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
    ///
    /// let model = ZeroShotClassificationModel::new(Default::default())?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(
        config: ZeroShotClassificationConfig,
    ) -> Result<ZeroShotClassificationModel, RustBertError> {
        let vocab_path = config.vocab_resource.get_local_path()?;
        let merges_path = config
            .merges_resource
            .as_ref()
            .map(|resource| resource.get_local_path())
            .transpose()?;

        let tokenizer = TokenizerOption::from_file(
            config.model_type,
            vocab_path.to_str().unwrap(),
            merges_path.as_deref().map(|path| path.to_str().unwrap()),
            config.lower_case,
            config.strip_accents,
            config.add_prefix_space,
        )?;
        Self::new_with_tokenizer(config, tokenizer)
    }

    /// Build a new `ZeroShotClassificationModel` with a provided tokenizer.
    ///
    /// # Arguments
    ///
    /// * `config` - `ZeroShotClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` tokenizer to use for zero-shot classification.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    /// use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
    /// let tokenizer = TokenizerOption::from_file(
    ///     ModelType::Bert,
    ///     "path/to/vocab.txt",
    ///     None,
    ///     false,
    ///     None,
    ///     None,
    /// )?;
    /// let model = ZeroShotClassificationModel::new_with_tokenizer(Default::default(), tokenizer)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new_with_tokenizer(
        config: ZeroShotClassificationConfig,
        tokenizer: TokenizerOption,
    ) -> Result<ZeroShotClassificationModel, RustBertError> {
        let device = config.device;
        let zero_shot_classifier = ZeroShotClassificationOption::new(&config)?;

        Ok(ZeroShotClassificationModel {
            tokenizer,
            zero_shot_classifier,
            device,
        })
    }

    /// Get a reference to the model tokenizer.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }

    /// Get a mutable reference to the model tokenizer.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }

    /// Encodes every (input, label hypothesis) pair and builds the padded model inputs.
    ///
    /// Each input text is paired with every label turned into a hypothesis sentence
    /// (via `template`, or the default `"This example is about {}."` template), producing
    /// `inputs.len() * labels.len()` sequence pairs.
    ///
    /// Returns `(input_ids, attention_mask, token_type_ids)` tensors placed on the model
    /// device, or a `ValueError` if the resulting pair list is empty.
    fn prepare_for_model<'a, S, T>(
        &self,
        inputs: S,
        labels: T,
        template: Option<ZeroShotTemplate>,
        max_len: usize,
    ) -> Result<(Tensor, Tensor, Tensor), RustBertError>
    where
        S: AsRef<[&'a str]>,
        T: AsRef<[&'a str]>,
    {
        // Turn each candidate label into a natural-language hypothesis sentence.
        let label_sentences: Vec<String> = match template {
            Some(function) => labels
                .as_ref()
                .iter()
                .map(|label| function(label))
                .collect(),
            None => labels
                .as_ref()
                .iter()
                .map(|label| format!("This example is about {label}."))
                .collect(),
        };

        // Cartesian product: every input is paired with every hypothesis.
        let text_pair_list = inputs
            .as_ref()
            .iter()
            .flat_map(|input| {
                label_sentences
                    .iter()
                    .map(move |label_sentence| (*input, label_sentence.as_str()))
            })
            .collect::<Vec<(&str, &str)>>();

        let mut tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_pair_list(
            text_pair_list.as_ref(),
            max_len,
            &TruncationStrategy::LongestFirst,
            0,
        );
        // Longest encoded sequence; all sequences below are padded up to this length.
        let max_len = tokenized_input
            .iter()
            .map(|input| input.token_ids.len())
            .max()
            .ok_or_else(|| RustBertError::ValueError("Got empty iterator as input".to_string()))?;

        let pad_id = self
            .tokenizer
            .get_pad_id()
            .expect("The Tokenizer used for sequence classification should contain a PAD id");
        let input_ids = tokenized_input
            .iter_mut()
            .map(|input| {
                input.token_ids.resize(max_len, pad_id);
                Tensor::from_slice(&(input.token_ids))
            })
            .collect::<Vec<_>>();
        let token_type_ids = tokenized_input
            .iter_mut()
            .map(|input| {
                // Pad segment ids with the last segment id (second sequence), or 0 if empty.
                input
                    .segment_ids
                    .resize(max_len, *input.segment_ids.last().unwrap_or(&0));
                Tensor::from_slice(&(input.segment_ids))
            })
            .collect::<Vec<_>>();

        let input_ids = Tensor::stack(input_ids.as_slice(), 0).to(self.device);
        let token_type_ids = Tensor::stack(token_type_ids.as_slice(), 0)
            .to(self.device)
            .to_kind(Kind::Int64);
        // Attention mask: true for real tokens, false for padding.
        // NOTE(review): `pad_id` was already fetched above; the second lookup could be reused.
        let mask = input_ids
            .ne(self
                .tokenizer
                .get_pad_id()
                .expect("The Tokenizer used for zero shot classification should contain a PAD id"))
            .to_kind(Bool);

        Ok((input_ids, mask, token_type_ids))
    }

    /// Zero shot classification with 1 (and exactly 1) true label.
    ///
    /// # Arguments
    ///
    /// * `input` - `&[&str]` Array of texts to classify.
    /// * `labels` - `&[&str]` Possible labels for the inputs.
    /// * `template` - `Option<Box<dyn Fn(&str) -> String>>` closure to build label propositions. If None, will default to `"This example is about {}."`.
    /// * `max_length` -`usize` Maximum sequence length for the inputs. If needed, the input sequence will be truncated before the label template.
    ///
    /// # Returns
    ///
    /// * `Result<Vec<Label>, RustBertError>` containing the most likely label for each input sentence or error, if any.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
    ///
    /// let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
    ///
    /// let input_sentence = "Who are you voting for in 2020?";
    /// let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
    /// let candidate_labels = &["politics", "public health", "economics", "sports"];
    ///
    /// let output = sequence_classification_model.predict(
    ///     &[input_sentence, input_sequence_2],
    ///     candidate_labels,
    ///     None,
    ///     128,
    /// );
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// outputs:
    /// ```no_run
    /// # use rust_bert::pipelines::sequence_classification::Label;
    /// let output = [
    ///     Label {
    ///         text: "politics".to_string(),
    ///         score: 0.959,
    ///         id: 0,
    ///         sentence: 0,
    ///     },
    ///     Label {
    ///         text: "economics".to_string(),
    ///         score: 0.642,
    ///         id: 2,
    ///         sentence: 1,
    ///     },
    /// ]
    /// .to_vec();
    /// ```
    pub fn predict<'a, S, T>(
        &self,
        inputs: S,
        labels: T,
        template: Option<ZeroShotTemplate>,
        max_length: usize,
    ) -> Result<Vec<Label>, RustBertError>
    where
        S: AsRef<[&'a str]>,
        T: AsRef<[&'a str]>,
    {
        let num_inputs = inputs.as_ref().len();
        let (input_tensor, mask, token_type_ids) =
            self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;

        let output = no_grad(|| {
            let output = self.zero_shot_classifier.forward_t(
                Some(&input_tensor),
                Some(&mask),
                Some(&token_type_ids),
                None,
                None,
                false,
            );
            // Reshape flat (input x label) batch to (num_inputs, num_labels, num_classes).
            output.view((num_inputs as i64, labels.as_ref().len() as i64, -1i64))
        });

        // Normalize across candidate labels (dim 1), then keep the entailment component
        // (last entry of the class dimension, per the NLI convention used by these models).
        let scores = output.softmax(1, Float).select(-1, -1);
        let label_indices = scores.as_ref().argmax(-1, true).squeeze_dim(1);
        let scores = scores
            .gather(1, &label_indices.unsqueeze(-1), false)
            .squeeze_dim(1);
        let label_indices = label_indices.iter::<i64>()?.collect::<Vec<i64>>();
        let scores = scores.iter::<f64>()?.collect::<Vec<f64>>();

        let mut output_labels: Vec<Label> = vec![];
        for sentence_idx in 0..label_indices.len() {
            let label_string = labels.as_ref()[label_indices[sentence_idx] as usize].to_string();
            let label = Label {
                text: label_string,
                score: scores[sentence_idx],
                id: label_indices[sentence_idx],
                sentence: sentence_idx,
            };
            output_labels.push(label)
        }
        Ok(output_labels)
    }

    /// Zero shot multi-label classification with zero, one or more true labels.
    ///
    /// # Arguments
    ///
    /// * `input` - `&[&str]` Array of texts to classify.
    /// * `labels` - `&[&str]` Possible labels for the inputs.
    /// * `template` - `Option<Box<dyn Fn(&str) -> String>>` closure to build label propositions. If None, will default to `"This example is about {}."`.
    /// * `max_length` -`usize` Maximum sequence length for the inputs. If needed, the input sequence will be truncated before the label template.
    ///
    /// # Returns
    ///
    /// * `Result<Vec<Vec<Label>>, RustBertError>` containing a vector of labels and their probability for each input text, or error, if any.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
    ///
    /// let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
    ///
    /// let input_sentence = "Who are you voting for in 2020?";
    /// let input_sequence_2 = "The central bank is meeting today to discuss monetary policy.";
    /// let candidate_labels = &["politics", "public health", "economics", "sports"];
    ///
    /// let output = sequence_classification_model.predict_multilabel(
    ///     &[input_sentence, input_sequence_2],
    ///     candidate_labels,
    ///     None,
    ///     128,
    /// );
    /// # Ok(())
    /// # }
    /// ```
    /// outputs:
    /// ```no_run
    /// # use rust_bert::pipelines::sequence_classification::Label;
    /// let output = [
    ///     [
    ///         Label {
    ///             text: "politics".to_string(),
    ///             score: 0.972,
    ///             id: 0,
    ///             sentence: 0,
    ///         },
    ///         Label {
    ///             text: "public health".to_string(),
    ///             score: 0.032,
    ///             id: 1,
    ///             sentence: 0,
    ///         },
    ///         Label {
    ///             text: "economics".to_string(),
    ///             score: 0.006,
    ///             id: 2,
    ///             sentence: 0,
    ///         },
    ///         Label {
    ///             text: "sports".to_string(),
    ///             score: 0.004,
    ///             id: 3,
    ///             sentence: 0,
    ///         },
    ///     ],
    ///     [
    ///         Label {
    ///             text: "politics".to_string(),
    ///             score: 0.975,
    ///             id: 0,
    ///             sentence: 1,
    ///         },
    ///         Label {
    ///             text: "economics".to_string(),
    ///             score: 0.852,
    ///             id: 2,
    ///             sentence: 1,
    ///         },
    ///         Label {
    ///             text: "public health".to_string(),
    ///             score: 0.0818,
    ///             id: 1,
    ///             sentence: 1,
    ///         },
    ///         Label {
    ///             text: "sports".to_string(),
    ///             score: 0.001,
    ///             id: 3,
    ///             sentence: 1,
    ///         },
    ///     ],
    /// ]
    /// .to_vec();
    /// ```
    pub fn predict_multilabel<'a, S, T>(
        &self,
        inputs: S,
        labels: T,
        template: Option<ZeroShotTemplate>,
        max_length: usize,
    ) -> Result<Vec<Vec<Label>>, RustBertError>
    where
        S: AsRef<[&'a str]>,
        T: AsRef<[&'a str]>,
    {
        let num_inputs = inputs.as_ref().len();
        let (input_tensor, mask, token_type_ids) =
            self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;

        let output = no_grad(|| {
            let output = self.zero_shot_classifier.forward_t(
                Some(&input_tensor),
                Some(&mask),
                Some(&token_type_ids),
                None,
                None,
                false,
            );
            // Reshape flat (input x label) batch to (num_inputs, num_labels, num_classes).
            output.view((num_inputs as i64, labels.as_ref().len() as i64, -1i64))
        });
        // Keep logits 0 and 2 (contradiction / entailment, dropping the neutral class),
        // softmax over that pair per label, then keep the entailment probability.
        let scores = output.slice(-1, 0, 3, 2).softmax(-1, Float).select(-1, -1);

        let mut output_labels = vec![];
        for sentence_idx in 0..num_inputs {
            let mut sentence_labels = vec![];

            for (label_index, score) in scores
                .select(0, sentence_idx as i64)
                .iter::<f64>()?
                .enumerate()
            {
                let label_string = labels.as_ref()[label_index].to_string();
                let label = Label {
                    text: label_string,
                    score,
                    id: label_index as i64,
                    sentence: sentence_idx,
                };
                sentence_labels.push(label);
            }
            output_labels.push(sentence_labels);
        }
        Ok(output_labels)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        // Compile-time check that the model future/result is `Send` (safe to move across threads).
        let config = ZeroShotClassificationConfig::default();
        let _: Box<dyn Send> = Box::new(ZeroShotClassificationModel::new(config));
    }
}