rust_bert/pipelines/
zero_shot_classification.rs

1// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
2// Copyright 2019-2020 Guillaume Becquin
3// Copyright 2020 Maarten van Gompel
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! # Zero-shot classification pipeline
15//! Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference.
//! The default model is a BART model fine-tuned on MNLI. From a list of input sequences to classify and a list of target labels,
17//! single-class or multi-label classification is performed, translating the classification task to an inference task.
18//! The default template for translation to inference task is `This example is about {}.`. This template can be updated to a more specific
19//! value that may match better the use case, for example `This review is about a {product_class}`.
20//!
21//! - `predict` performs single-class classification (one and exactly one label must be true for each provided input)
22//! - `predict_multilabel` performs multi-label classification (zero, one or more labels may be true for each provided input)
23//!
24//! ```no_run
25//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
26//! # fn main() -> anyhow::Result<()> {
27//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
28//! let input_sentence = "Who are you voting for in 2020?";
29//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
30//! let candidate_labels = &["politics", "public health", "economics", "sports"];
31//! let output = sequence_classification_model.predict_multilabel(
32//!     &[input_sentence, input_sequence_2],
33//!     candidate_labels,
34//!     None,
35//!     128,
36//! );
37//! # Ok(())
38//! # }
39//! ```
40//!
41//! outputs:
42//! ```no_run
43//! # use rust_bert::pipelines::sequence_classification::Label;
44//! let output = [
45//!     [
46//!         Label {
47//!             text: "politics".to_string(),
48//!             score: 0.972,
49//!             id: 0,
50//!             sentence: 0,
51//!         },
52//!         Label {
53//!             text: "public health".to_string(),
54//!             score: 0.032,
55//!             id: 1,
56//!             sentence: 0,
57//!         },
58//!         Label {
//!             text: "economics".to_string(),
60//!             score: 0.006,
61//!             id: 2,
62//!             sentence: 0,
63//!         },
64//!         Label {
65//!             text: "sports".to_string(),
66//!             score: 0.004,
67//!             id: 3,
68//!             sentence: 0,
69//!         },
70//!     ],
71//!     [
72//!         Label {
73//!             text: "politics".to_string(),
74//!             score: 0.943,
75//!             id: 0,
76//!             sentence: 1,
77//!         },
78//!         Label {
//!             text: "economics".to_string(),
80//!             score: 0.985,
81//!             id: 2,
82//!             sentence: 1,
83//!         },
84//!         Label {
85//!             text: "public health".to_string(),
86//!             score: 0.0818,
87//!             id: 1,
88//!             sentence: 1,
89//!         },
90//!         Label {
91//!             text: "sports".to_string(),
92//!             score: 0.001,
93//!             id: 3,
94//!             sentence: 1,
95//!         },
96//!     ],
97//! ]
98//! .to_vec();
99//! ```
100
101use crate::albert::AlbertForSequenceClassification;
102use crate::bart::BartForSequenceClassification;
103use crate::bert::BertForSequenceClassification;
104use crate::deberta::DebertaForSequenceClassification;
105use crate::deberta_v2::DebertaV2ForSequenceClassification;
106use crate::distilbert::DistilBertModelClassifier;
107use crate::longformer::LongformerForSequenceClassification;
108use crate::mobilebert::MobileBertForSequenceClassification;
109use crate::pipelines::common::{
110    cast_var_store, ConfigOption, ModelResource, ModelType, TokenizerOption,
111};
112use crate::pipelines::sequence_classification::Label;
113use crate::resources::ResourceProvider;
114use crate::roberta::RobertaForSequenceClassification;
115use crate::xlnet::XLNetForSequenceClassification;
116use crate::RustBertError;
117use rust_tokenizers::tokenizer::TruncationStrategy;
118use rust_tokenizers::TokenizedInput;
119
120#[cfg(feature = "onnx")]
121use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
122#[cfg(feature = "remote")]
123use crate::{
124    bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
125    resources::RemoteResource,
126};
127use tch::kind::Kind::{Bool, Float};
128use tch::nn::VarStore;
129use tch::{no_grad, Device, Kind, Tensor};
130
/// # Configuration for ZeroShotClassificationModel
/// Contains information regarding the model to load and device to place the model on.
pub struct ZeroShotClassificationConfig {
    /// Model type
    pub model_type: ModelType,
    /// Model weights resource (default: pretrained BART model fine-tuned on MNLI)
    pub model_resource: ModelResource,
    /// Config resource (default: pretrained BART model fine-tuned on MNLI)
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocab resource (default: pretrained BART model fine-tuned on MNLI)
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource, only needed for BPE-based tokenizers (default: BART-MNLI merges)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Automatically lower case all input upon tokenization (assumes a lower-cased model)
    pub lower_case: bool,
    /// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
    pub strip_accents: Option<bool>,
    /// Flag indicating if the tokenizer should add a white space before each tokenized input (needed for some Roberta models)
    pub add_prefix_space: Option<bool>,
    /// Device to place the model on (default: CUDA/GPU when available)
    pub device: Device,
    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
    pub kind: Option<Kind>,
}
155
156impl ZeroShotClassificationConfig {
157    /// Instantiate a new zero shot classification configuration of the supplied type.
158    ///
159    /// # Arguments
160    ///
161    /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
162    /// * model - The `ResourceProvider` pointing to the model to load (e.g.  model.ot)
163    /// * config - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
164    /// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g.  vocab.txt/vocab.json)
165    /// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g.  merges.txt), needed only for Roberta.
166    /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
167    pub fn new<RC, RV>(
168        model_type: ModelType,
169        model_resource: ModelResource,
170        config_resource: RC,
171        vocab_resource: RV,
172        merges_resource: Option<RV>,
173        lower_case: bool,
174        strip_accents: impl Into<Option<bool>>,
175        add_prefix_space: impl Into<Option<bool>>,
176    ) -> ZeroShotClassificationConfig
177    where
178        RC: ResourceProvider + Send + 'static,
179        RV: ResourceProvider + Send + 'static,
180    {
181        ZeroShotClassificationConfig {
182            model_type,
183            model_resource,
184            config_resource: Box::new(config_resource),
185            vocab_resource: Box::new(vocab_resource),
186            merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
187            lower_case,
188            strip_accents: strip_accents.into(),
189            add_prefix_space: add_prefix_space.into(),
190            device: Device::cuda_if_available(),
191            kind: None,
192        }
193    }
194}
195
#[cfg(feature = "remote")]
impl Default for ZeroShotClassificationConfig {
    /// Provides a default zero-shot classification model (English BART fine-tuned on MNLI)
    fn default() -> ZeroShotClassificationConfig {
        // All resources point at the remotely hosted BART-MNLI checkpoint.
        let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            BartModelResources::BART_MNLI,
        )));
        let config_resource: Box<dyn ResourceProvider + Send> = Box::new(
            RemoteResource::from_pretrained(BartConfigResources::BART_MNLI),
        );
        let vocab_resource: Box<dyn ResourceProvider + Send> = Box::new(
            RemoteResource::from_pretrained(BartVocabResources::BART_MNLI),
        );
        let merges_resource: Option<Box<dyn ResourceProvider + Send>> = Some(Box::new(
            RemoteResource::from_pretrained(BartMergesResources::BART_MNLI),
        ));
        ZeroShotClassificationConfig {
            model_type: ModelType::Bart,
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            lower_case: false,
            strip_accents: None,
            add_prefix_space: None,
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}
222
/// # Abstraction that holds one particular zero shot classification model, for any of the supported models
/// The models are using a classification architecture that should be trained on Natural Language Inference.
/// The models should output a Tensor of size > 2 in the label dimension, with the first logit corresponding
/// to contradiction and the last logit corresponding to entailment.
#[allow(clippy::large_enum_variant)]
pub enum ZeroShotClassificationOption {
    /// Bart for Sequence Classification
    Bart(BartForSequenceClassification),
    /// DeBERTa for Sequence Classification
    Deberta(DebertaForSequenceClassification),
    /// DeBERTaV2 for Sequence Classification
    DebertaV2(DebertaV2ForSequenceClassification),
    /// Bert for Sequence Classification
    Bert(BertForSequenceClassification),
    /// DistilBert for Sequence Classification
    DistilBert(DistilBertModelClassifier),
    /// MobileBert for Sequence Classification
    MobileBert(MobileBertForSequenceClassification),
    /// Roberta for Sequence Classification
    Roberta(RobertaForSequenceClassification),
    /// XLMRoberta for Sequence Classification (shares the Roberta classification architecture)
    XLMRoberta(RobertaForSequenceClassification),
    /// Albert for Sequence Classification
    Albert(AlbertForSequenceClassification),
    /// XLNet for Sequence Classification
    XLNet(XLNetForSequenceClassification),
    /// Longformer for Sequence Classification
    Longformer(LongformerForSequenceClassification),
    /// ONNX model for Sequence Classification (requires the `onnx` feature)
    #[cfg(feature = "onnx")]
    ONNX(ONNXEncoder),
}
255
256impl ZeroShotClassificationOption {
257    /// Instantiate a new zer-shot classification model of the supplied type.
258    ///
259    /// # Arguments
260    ///
261    /// * `ZeroShotClassificationConfig` - Zero-shot classification pipeline configuration. The type of model created will be inferred from the
262    ///     `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
263    pub fn new(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
264        match config.model_resource {
265            ModelResource::Torch(_) => Self::new_torch(config),
266            #[cfg(feature = "onnx")]
267            ModelResource::ONNX(_) => Self::new_onnx(config),
268        }
269    }
270
271    fn new_torch(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
272        let device = config.device;
273        let weights_path = config.model_resource.get_torch_local_path()?;
274        let mut var_store = VarStore::new(device);
275        let model_config =
276            &ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
277        let model_type = config.model_type;
278        let model = match model_type {
279            ModelType::Bart => {
280                if let ConfigOption::Bart(config) = model_config {
281                    Ok(Self::Bart(
282                        BartForSequenceClassification::new(var_store.root(), config)?,
283                    ))
284                } else {
285                    Err(RustBertError::InvalidConfigurationError(
286                        "You can only supply a BartConfig for Bart!".to_string(),
287                    ))
288                }
289            }
290            ModelType::Deberta => {
291                if let ConfigOption::Deberta(config) = model_config {
292                    Ok(Self::Deberta(
293                        DebertaForSequenceClassification::new(var_store.root(), config)?,
294                    ))
295                } else {
296                    Err(RustBertError::InvalidConfigurationError(
297                        "You can only supply a DebertaConfig for DeBERTa!".to_string(),
298                    ))
299                }
300            }
301            ModelType::DebertaV2 => {
302                if let ConfigOption::DebertaV2(config) = model_config {
303                    Ok(Self::DebertaV2(
304                        DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
305                    ))
306                } else {
307                    Err(RustBertError::InvalidConfigurationError(
308                        "You can only supply a DebertaConfig for DeBERTaV2!".to_string(),
309                    ))
310                }
311            }
312            ModelType::Bert => {
313                if let ConfigOption::Bert(config) = model_config {
314                    Ok(Self::Bert(
315                        BertForSequenceClassification::new(var_store.root(), config)?,
316                    ))
317                } else {
318                    Err(RustBertError::InvalidConfigurationError(
319                        "You can only supply a BertConfig for Bert!".to_string(),
320                    ))
321                }
322            }
323            ModelType::DistilBert => {
324                if let ConfigOption::DistilBert(config) = model_config {
325                    Ok(Self::DistilBert(
326                        DistilBertModelClassifier::new(var_store.root(), config)?,
327                    ))
328                } else {
329                    Err(RustBertError::InvalidConfigurationError(
330                        "You can only supply a DistilBertConfig for DistilBert!".to_string(),
331                    ))
332                }
333            }
334            ModelType::MobileBert => {
335                if let ConfigOption::MobileBert(config) = model_config {
336                    Ok(Self::MobileBert(
337                        MobileBertForSequenceClassification::new(var_store.root(), config)?,
338                    ))
339                } else {
340                    Err(RustBertError::InvalidConfigurationError(
341                        "You can only supply a MobileBertConfig for MobileBert!".to_string(),
342                    ))
343                }
344            }
345            ModelType::Roberta => {
346                if let ConfigOption::Roberta(config) = model_config {
347                    Ok(Self::Roberta(
348                        RobertaForSequenceClassification::new(var_store.root(), config)?,
349                    ))
350                } else {
351                    Err(RustBertError::InvalidConfigurationError(
352                        "You can only supply a RobertaConfig for Roberta!".to_string(),
353                    ))
354                }
355            }
356            ModelType::XLMRoberta => {
357                if let ConfigOption::Bert(config) = model_config {
358                    Ok(Self::XLMRoberta(
359                        RobertaForSequenceClassification::new(var_store.root(), config)?,
360                    ))
361                } else {
362                    Err(RustBertError::InvalidConfigurationError(
363                        "You can only supply a BertConfig for Roberta!".to_string(),
364                    ))
365                }
366            }
367            ModelType::Albert => {
368                if let ConfigOption::Albert(config) = model_config {
369                    Ok(Self::Albert(
370                        AlbertForSequenceClassification::new(var_store.root(), config)?,
371                    ))
372                } else {
373                    Err(RustBertError::InvalidConfigurationError(
374                        "You can only supply an AlbertConfig for Albert!".to_string(),
375                    ))
376                }
377            }
378            ModelType::XLNet => {
379                if let ConfigOption::XLNet(config) = model_config {
380                    Ok(Self::XLNet(
381                        XLNetForSequenceClassification::new(var_store.root(), config)?,
382                    ))
383                } else {
384                    Err(RustBertError::InvalidConfigurationError(
385                        "You can only supply an AlbertConfig for Albert!".to_string(),
386                    ))
387                }
388            }
389            ModelType::Longformer => {
390                if let ConfigOption::Longformer(config) = model_config {
391                    Ok(Self::Longformer(
392                        LongformerForSequenceClassification::new(var_store.root(), config)?,
393                    ))
394                } else {
395                    Err(RustBertError::InvalidConfigurationError(
396                        "You can only supply a LongformerConfig for Longformer!".to_string(),
397                    ))
398                }
399            }
400            #[cfg(feature = "onnx")]
401            ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
402                "A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
403            )),
404            _ => Err(RustBertError::InvalidConfigurationError(format!(
405                "Zero shot classification not implemented for {model_type:?}!",
406            ))),
407        }?;
408        var_store.load(weights_path)?;
409        cast_var_store(&mut var_store, config.kind, device);
410        Ok(model)
411    }
412
413    #[cfg(feature = "onnx")]
414    pub fn new_onnx(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
415        let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
416        let environment = onnx_config.get_environment()?;
417        let encoder_file = config
418            .model_resource
419            .get_onnx_local_paths()?
420            .encoder_path
421            .ok_or(RustBertError::InvalidConfigurationError(
422                "An encoder file must be provided for zero-shot classification ONNX models."
423                    .to_string(),
424            ))?;
425
426        Ok(Self::ONNX(ONNXEncoder::new(
427            encoder_file,
428            &environment,
429            &onnx_config,
430        )?))
431    }
432
433    /// Returns the `ModelType` for this SequenceClassificationOption
434    pub fn model_type(&self) -> ModelType {
435        match *self {
436            Self::Bart(_) => ModelType::Bart,
437            Self::Deberta(_) => ModelType::Deberta,
438            Self::DebertaV2(_) => ModelType::DebertaV2,
439            Self::Bert(_) => ModelType::Bert,
440            Self::Roberta(_) => ModelType::Roberta,
441            Self::XLMRoberta(_) => ModelType::Roberta,
442            Self::DistilBert(_) => ModelType::DistilBert,
443            Self::MobileBert(_) => ModelType::MobileBert,
444            Self::Albert(_) => ModelType::Albert,
445            Self::XLNet(_) => ModelType::XLNet,
446            Self::Longformer(_) => ModelType::Longformer,
447            #[cfg(feature = "onnx")]
448            Self::ONNX(_) => ModelType::ONNX,
449        }
450    }
451
452    /// Interface method to forward_t() of the particular models.
453    pub fn forward_t(
454        &self,
455        input_ids: Option<&Tensor>,
456        mask: Option<&Tensor>,
457        token_type_ids: Option<&Tensor>,
458        position_ids: Option<&Tensor>,
459        input_embeds: Option<&Tensor>,
460        train: bool,
461    ) -> Tensor {
462        match *self {
463            Self::Bart(ref model) => {
464                model
465                    .forward_t(
466                        input_ids.expect("`input_ids` must be provided for BART models"),
467                        mask,
468                        None,
469                        None,
470                        None,
471                        train,
472                    )
473                    .decoder_output
474            }
475            Self::Bert(ref model) => {
476                model
477                    .forward_t(
478                        input_ids,
479                        mask,
480                        token_type_ids,
481                        position_ids,
482                        input_embeds,
483                        train,
484                    )
485                    .logits
486            }
487            Self::Deberta(ref model) => {
488                model
489                    .forward_t(
490                        input_ids,
491                        mask,
492                        token_type_ids,
493                        position_ids,
494                        input_embeds,
495                        train,
496                    )
497                    .expect("Error in DeBERTa forward_t")
498                    .logits
499            }
500            Self::DebertaV2(ref model) => {
501                model
502                    .forward_t(
503                        input_ids,
504                        mask,
505                        token_type_ids,
506                        position_ids,
507                        input_embeds,
508                        train,
509                    )
510                    .expect("Error in DeBERTaV2 forward_t")
511                    .logits
512            }
513            Self::DistilBert(ref model) => {
514                model
515                    .forward_t(input_ids, mask, input_embeds, train)
516                    .expect("Error in distilbert forward_t")
517                    .logits
518            }
519            Self::MobileBert(ref model) => {
520                model
521                    .forward_t(input_ids, None, None, input_embeds, mask, train)
522                    .expect("Error in mobilebert forward_t")
523                    .logits
524            }
525            Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
526                model
527                    .forward_t(
528                        input_ids,
529                        mask,
530                        token_type_ids,
531                        position_ids,
532                        input_embeds,
533                        train,
534                    )
535                    .logits
536            }
537            Self::Albert(ref model) => {
538                model
539                    .forward_t(
540                        input_ids,
541                        mask,
542                        token_type_ids,
543                        position_ids,
544                        input_embeds,
545                        train,
546                    )
547                    .logits
548            }
549            Self::XLNet(ref model) => {
550                model
551                    .forward_t(
552                        input_ids,
553                        mask,
554                        None,
555                        None,
556                        None,
557                        token_type_ids,
558                        input_embeds,
559                        train,
560                    )
561                    .logits
562            }
563            Self::Longformer(ref model) => {
564                model
565                    .forward_t(
566                        input_ids,
567                        mask,
568                        None,
569                        token_type_ids,
570                        position_ids,
571                        input_embeds,
572                        train,
573                    )
574                    .expect("Error in Longformer forward pass.")
575                    .logits
576            }
577            #[cfg(feature = "onnx")]
578            Self::ONNX(ref model) => model
579                .forward(
580                    input_ids,
581                    mask.map(|tensor| tensor.to_kind(Kind::Int64)).as_ref(),
582                    token_type_ids,
583                    position_ids,
584                    input_embeds,
585                )
586                .expect("Error in ONNX forward pass.")
587                .logits
588                .unwrap(),
589        }
590    }
591}
592
/// Template used to transform the zero-shot classification labels into a set of
/// natural language hypotheses for natural language inference.
///
/// For example, transform `[positive, negative]` into
/// `[This is a positive review, This is a negative review]`
///
/// The function should take a `&str` as an input and return the formatted String.
///
/// This transformation has a strong impact on the resulting classification accuracy.
/// If no function is provided for zero-shot classification, the default templating
/// function will be used:
///
/// ```rust
/// fn default_template(label: &str) -> String {
///     format!("This example is about {}.", label)
/// }
/// ```
pub type ZeroShotTemplate = Box<dyn Fn(&str) -> String>;
611
/// # ZeroShotClassificationModel for Zero Shot Classification
pub struct ZeroShotClassificationModel {
    // Tokenizer matching the loaded model type.
    tokenizer: TokenizerOption,
    // Underlying NLI-trained classification model used for zero-shot inference.
    zero_shot_classifier: ZeroShotClassificationOption,
    // Device the input tensors are moved to before the forward pass.
    device: Device,
}
618
619impl ZeroShotClassificationModel {
    /// Build a new `ZeroShotClassificationModel`
    ///
    /// # Arguments
    ///
    /// * `config` - `ZeroShotClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
    ///
    /// let model = ZeroShotClassificationModel::new(Default::default())?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(
        config: ZeroShotClassificationConfig,
    ) -> Result<ZeroShotClassificationModel, RustBertError> {
        // Resolve the tokenizer resource files locally before building the tokenizer.
        let vocab_path = config.vocab_resource.get_local_path()?;
        let merges_path = config
            .merges_resource
            .as_ref()
            .map(|resource| resource.get_local_path())
            .transpose()?;

        let tokenizer = TokenizerOption::from_file(
            config.model_type,
            vocab_path.to_str().unwrap(),
            merges_path.as_deref().map(|path| path.to_str().unwrap()),
            config.lower_case,
            config.strip_accents,
            config.add_prefix_space,
        )?;
        Self::new_with_tokenizer(config, tokenizer)
    }
656
    /// Build a new `ZeroShotClassificationModel` with a provided tokenizer.
    ///
    /// # Arguments
    ///
    /// * `config` - `ZeroShotClassificationConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` tokenizer to use for zero-shot classification.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    /// use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
    /// let tokenizer = TokenizerOption::from_file(
    ///     ModelType::Bert,
    ///     "path/to/vocab.txt",
    ///     None,
    ///     false,
    ///     None,
    ///     None,
    /// )?;
    /// let model =
    ///     ZeroShotClassificationModel::new_with_tokenizer(Default::default(), tokenizer)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new_with_tokenizer(
        config: ZeroShotClassificationConfig,
        tokenizer: TokenizerOption,
    ) -> Result<ZeroShotClassificationModel, RustBertError> {
        let device = config.device;
        let zero_shot_classifier = ZeroShotClassificationOption::new(&config)?;

        Ok(ZeroShotClassificationModel {
            tokenizer,
            zero_shot_classifier,
            device,
        })
    }
695
    /// Get a reference to the model tokenizer, e.g. to inspect how inputs are tokenized.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
700
    /// Get a mutable reference to the model tokenizer, e.g. to register additional special tokens.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
705
706    fn prepare_for_model<'a, S, T>(
707        &self,
708        inputs: S,
709        labels: T,
710        template: Option<ZeroShotTemplate>,
711        max_len: usize,
712    ) -> Result<(Tensor, Tensor, Tensor), RustBertError>
713    where
714        S: AsRef<[&'a str]>,
715        T: AsRef<[&'a str]>,
716    {
717        let label_sentences: Vec<String> = match template {
718            Some(function) => labels
719                .as_ref()
720                .iter()
721                .map(|label| function(label))
722                .collect(),
723            None => labels
724                .as_ref()
725                .iter()
726                .map(|label| format!("This example is about {label}."))
727                .collect(),
728        };
729
730        let text_pair_list = inputs
731            .as_ref()
732            .iter()
733            .flat_map(|input| {
734                label_sentences
735                    .iter()
736                    .map(move |label_sentence| (*input, label_sentence.as_str()))
737            })
738            .collect::<Vec<(&str, &str)>>();
739
740        let mut tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_pair_list(
741            text_pair_list.as_ref(),
742            max_len,
743            &TruncationStrategy::LongestFirst,
744            0,
745        );
746        let max_len = tokenized_input
747            .iter()
748            .map(|input| input.token_ids.len())
749            .max()
750            .ok_or_else(|| RustBertError::ValueError("Got empty iterator as input".to_string()))?;
751
752        let pad_id = self
753            .tokenizer
754            .get_pad_id()
755            .expect("The Tokenizer used for sequence classification should contain a PAD id");
756        let input_ids = tokenized_input
757            .iter_mut()
758            .map(|input| {
759                input.token_ids.resize(max_len, pad_id);
760                Tensor::from_slice(&(input.token_ids))
761            })
762            .collect::<Vec<_>>();
763        let token_type_ids = tokenized_input
764            .iter_mut()
765            .map(|input| {
766                input
767                    .segment_ids
768                    .resize(max_len, *input.segment_ids.last().unwrap_or(&0));
769                Tensor::from_slice(&(input.segment_ids))
770            })
771            .collect::<Vec<_>>();
772
773        let input_ids = Tensor::stack(input_ids.as_slice(), 0).to(self.device);
774        let token_type_ids = Tensor::stack(token_type_ids.as_slice(), 0)
775            .to(self.device)
776            .to_kind(Kind::Int64);
777        let mask = input_ids
778            .ne(self
779                .tokenizer
780                .get_pad_id()
781                .expect("The Tokenizer used for zero shot classification should contain a PAD id"))
782            .to_kind(Bool);
783
784        Ok((input_ids, mask, token_type_ids))
785    }
786
787    /// Zero shot classification with 1 (and exactly 1) true label.
788    ///
789    /// # Arguments
790    ///
791    /// * `input` - `&[&str]` Array of texts to classify.
792    /// * `labels` - `&[&str]` Possible labels for the inputs.
793    /// * `template` - `Option<Box<dyn Fn(&str) -> String>>` closure to build label propositions. If None, will default to `"This example is {}."`.
794    /// * `max_length` -`usize` Maximum sequence length for the inputs. If needed, the input sequence will be truncated before the label template.
795    ///
796    /// # Returns
797    ///
798    /// * `Result<Vec<Label>, RustBertError>` containing the most likely label for each input sentence or error, if any.
799    ///
800    /// # Example
801    ///
802    /// ```no_run
803    /// # fn main() -> anyhow::Result<()> {
804    /// use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
805    ///
806    /// let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
807    ///
808    /// let input_sentence = "Who are you voting for in 2020?";
809    /// let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
810    /// let candidate_labels = &["politics", "public health", "economics", "sports"];
811    ///
812    /// let output = sequence_classification_model.predict(
813    ///     &[input_sentence, input_sequence_2],
814    ///     candidate_labels,
815    ///     None,
816    ///     128,
817    /// );
818    /// # Ok(())
819    /// # }
820    /// ```
821    ///
822    /// outputs:
823    /// ```no_run
824    /// # use rust_bert::pipelines::sequence_classification::Label;
825    /// let output = [
826    ///     Label {
827    ///         text: "politics".to_string(),
828    ///         score: 0.959,
829    ///         id: 0,
830    ///         sentence: 0,
831    ///     },
832    ///     Label {
833    ///         text: "economy".to_string(),
834    ///         score: 0.642,
835    ///         id: 2,
836    ///         sentence: 1,
837    ///     },
838    /// ]
839    /// .to_vec();
840    /// ```
841    pub fn predict<'a, S, T>(
842        &self,
843        inputs: S,
844        labels: T,
845        template: Option<ZeroShotTemplate>,
846        max_length: usize,
847    ) -> Result<Vec<Label>, RustBertError>
848    where
849        S: AsRef<[&'a str]>,
850        T: AsRef<[&'a str]>,
851    {
852        let num_inputs = inputs.as_ref().len();
853        let (input_tensor, mask, token_type_ids) =
854            self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
855
856        let output = no_grad(|| {
857            let output = self.zero_shot_classifier.forward_t(
858                Some(&input_tensor),
859                Some(&mask),
860                Some(&token_type_ids),
861                None,
862                None,
863                false,
864            );
865            output.view((num_inputs as i64, labels.as_ref().len() as i64, -1i64))
866        });
867
868        let scores = output.softmax(1, Float).select(-1, -1);
869        let label_indices = scores.as_ref().argmax(-1, true).squeeze_dim(1);
870        let scores = scores
871            .gather(1, &label_indices.unsqueeze(-1), false)
872            .squeeze_dim(1);
873        let label_indices = label_indices.iter::<i64>()?.collect::<Vec<i64>>();
874        let scores = scores.iter::<f64>()?.collect::<Vec<f64>>();
875
876        let mut output_labels: Vec<Label> = vec![];
877        for sentence_idx in 0..label_indices.len() {
878            let label_string = labels.as_ref()[label_indices[sentence_idx] as usize].to_string();
879            let label = Label {
880                text: label_string,
881                score: scores[sentence_idx],
882                id: label_indices[sentence_idx],
883                sentence: sentence_idx,
884            };
885            output_labels.push(label)
886        }
887        Ok(output_labels)
888    }
889
890    /// Zero shot multi-label classification with 0, 1 or no true label.
891    ///
892    /// # Arguments
893    ///
894    /// * `input` - `&[&str]` Array of texts to classify.
895    /// * `labels` - `&[&str]` Possible labels for the inputs.
896    /// * `template` - `Option<Box<dyn Fn(&str) -> String>>` closure to build label propositions. If None, will default to `"This example is about {}."`.
897    /// * `max_length` -`usize` Maximum sequence length for the inputs. If needed, the input sequence will be truncated before the label template.
898    ///
899    /// # Returns
900    ///
901    /// * `Result<Vec<Vec<Label>>, RustBertError>` containing a vector of labels and their probability for each input text, or error, if any.
902    ///
903    /// # Example
904    ///
905    /// ```no_run
906    /// # fn main() -> anyhow::Result<()> {
907    /// use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
908    ///
909    /// let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
910    ///
911    /// let input_sentence = "Who are you voting for in 2020?";
912    /// let input_sequence_2 = "The central bank is meeting today to discuss monetary policy.";
913    /// let candidate_labels = &["politics", "public health", "economics", "sports"];
914    ///
915    /// let output = sequence_classification_model.predict_multilabel(
916    ///     &[input_sentence, input_sequence_2],
917    ///     candidate_labels,
918    ///     None,
919    ///     128,
920    /// );
921    /// # Ok(())
922    /// # }
923    /// ```
924    /// outputs:
925    /// ```no_run
926    /// # use rust_bert::pipelines::sequence_classification::Label;
927    /// let output = [
928    ///     [
929    ///         Label {
930    ///             text: "politics".to_string(),
931    ///             score: 0.972,
932    ///             id: 0,
933    ///             sentence: 0,
934    ///         },
935    ///         Label {
936    ///             text: "public health".to_string(),
937    ///             score: 0.032,
938    ///             id: 1,
939    ///             sentence: 0,
940    ///         },
941    ///         Label {
942    ///             text: "economy".to_string(),
943    ///             score: 0.006,
944    ///             id: 2,
945    ///             sentence: 0,
946    ///         },
947    ///         Label {
948    ///             text: "sports".to_string(),
949    ///             score: 0.004,
950    ///             id: 3,
951    ///             sentence: 0,
952    ///         },
953    ///     ],
954    ///     [
955    ///         Label {
956    ///             text: "politics".to_string(),
957    ///             score: 0.975,
958    ///             id: 0,
959    ///             sentence: 1,
960    ///         },
961    ///         Label {
962    ///             text: "economy".to_string(),
963    ///             score: 0.852,
964    ///             id: 2,
965    ///             sentence: 1,
966    ///         },
967    ///         Label {
968    ///             text: "public health".to_string(),
969    ///             score: 0.0818,
970    ///             id: 1,
971    ///             sentence: 1,
972    ///         },
973    ///         Label {
974    ///             text: "sports".to_string(),
975    ///             score: 0.001,
976    ///             id: 3,
977    ///             sentence: 1,
978    ///         },
979    ///     ],
980    /// ]
981    /// .to_vec();
982    /// ```
983    pub fn predict_multilabel<'a, S, T>(
984        &self,
985        inputs: S,
986        labels: T,
987        template: Option<ZeroShotTemplate>,
988        max_length: usize,
989    ) -> Result<Vec<Vec<Label>>, RustBertError>
990    where
991        S: AsRef<[&'a str]>,
992        T: AsRef<[&'a str]>,
993    {
994        let num_inputs = inputs.as_ref().len();
995        let (input_tensor, mask, token_type_ids) =
996            self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
997
998        let output = no_grad(|| {
999            let output = self.zero_shot_classifier.forward_t(
1000                Some(&input_tensor),
1001                Some(&mask),
1002                Some(&token_type_ids),
1003                None,
1004                None,
1005                false,
1006            );
1007            output.view((num_inputs as i64, labels.as_ref().len() as i64, -1i64))
1008        });
1009        let scores = output.slice(-1, 0, 3, 2).softmax(-1, Float).select(-1, -1);
1010
1011        let mut output_labels = vec![];
1012        for sentence_idx in 0..num_inputs {
1013            let mut sentence_labels = vec![];
1014
1015            for (label_index, score) in scores
1016                .select(0, sentence_idx as i64)
1017                .iter::<f64>()?
1018                .enumerate()
1019            {
1020                let label_string = labels.as_ref()[label_index].to_string();
1021                let label = Label {
1022                    text: label_string,
1023                    score,
1024                    id: label_index as i64,
1025                    sentence: sentence_idx,
1026                };
1027                sentence_labels.push(label);
1028            }
1029            output_labels.push(sentence_labels);
1030        }
1031        Ok(output_labels)
1032    }
1033}
#[cfg(test)]
mod test {
    use super::*;

    /// Compile-time helper: accepting `&T` where `T: Send` proves the type is Send.
    fn assert_send<T: Send>(_value: &T) {}

    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        let model = ZeroShotClassificationModel::new(ZeroShotClassificationConfig::default());
        assert_send(&model);
    }
}