Module zero_shot_classification

§Zero-shot classification pipeline

Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference (NLI). The default model is a BART model fine-tuned on MNLI. Given a list of input sequences to classify and a list of candidate labels, the pipeline performs single-class or multi-label classification by recasting the classification task as an inference task: each label is turned into a hypothesis that is tested against the input. The default template used to build these hypotheses is "This example is about {}." It can be replaced with a more specific template that better matches the use case, for example "This review is about a {product_class}."
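The translation from labels to hypotheses can be sketched in plain Rust. This is a minimal illustration and not the crate's implementation: the `Template` alias, `default_template` and `build_pairs` are hypothetical names, and the closure shape (a label in, a hypothesis sentence out) is an assumption made for the example.

```rust
// Hypothetical stand-in for a template type: a boxed closure that maps a
// label to a natural language hypothesis. The shape is an assumption.
type Template = Box<dyn Fn(&str) -> String>;

/// Builds the default template quoted in the description above.
fn default_template() -> Template {
    Box::new(|label: &str| format!("This example is about {}.", label))
}

/// Pairs every input (the premise) with one hypothesis per candidate label.
/// An NLI model would then score each (premise, hypothesis) pair.
fn build_pairs(inputs: &[&str], labels: &[&str], template: &Template) -> Vec<(String, String)> {
    let mut pairs = Vec::with_capacity(inputs.len() * labels.len());
    for input in inputs {
        for label in labels {
            pairs.push((input.to_string(), template(*label)));
        }
    }
    pairs
}

fn main() {
    let template = default_template();
    let pairs = build_pairs(
        &["Who are you voting for in 2020?"],
        &["politics", "sports"],
        &template,
    );
    for (premise, hypothesis) in &pairs {
        println!("{} => {}", premise, hypothesis);
    }
}
```

A more specific template only changes the closure, for instance one that formats "This review is about a {}." with the product class in place of the braces.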

  • predict performs single-class classification (one and exactly one label must be true for each provided input)
  • predict_multilabel performs multi-label classification (zero, one or more labels may be true for each provided input)
let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
let input_sentence = "Who are you voting for in 2020?";
let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
let candidate_labels = &["politics", "public health", "economics", "sports"];
let output = sequence_classification_model.predict_multilabel(
    &[input_sentence, input_sequence_2],
    candidate_labels,
    None,
    128,
);

outputs:

let output = [
    [
        Label {
            text: "politics".to_string(),
            score: 0.972,
            id: 0,
            sentence: 0,
        },
        Label {
            text: "public health".to_string(),
            score: 0.032,
            id: 1,
            sentence: 0,
        },
        Label {
            text: "economics".to_string(),
            score: 0.006,
            id: 2,
            sentence: 0,
        },
        Label {
            text: "sports".to_string(),
            score: 0.004,
            id: 3,
            sentence: 0,
        },
    ],
    [
        Label {
            text: "politics".to_string(),
            score: 0.943,
            id: 0,
            sentence: 1,
        },
        Label {
            text: "public health".to_string(),
            score: 0.0818,
            id: 1,
            sentence: 1,
        },
        Label {
            text: "economics".to_string(),
            score: 0.985,
            id: 2,
            sentence: 1,
        },
        Label {
            text: "sports".to_string(),
            score: 0.001,
            id: 3,
            sentence: 1,
        },
    ],
]
.to_vec();
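In the multi-label output, the scores for one sentence do not have to sum to one: several labels can score high at once. The sketch below shows one common way NLI-based zero-shot pipelines derive the two kinds of scores. It is an assumption about the general technique, not a confirmed description of this crate's internals; the function names are hypothetical and the logit values in `main` are made up.

```rust
/// Softmax over a slice of logits, stabilised by subtracting the maximum.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Single-class (assumed scheme): normalise the entailment logits across all
/// labels, so the scores compete with each other and sum to one.
fn single_class_scores(entailment_logits: &[f64]) -> Vec<f64> {
    softmax(entailment_logits)
}

/// Multi-label (assumed scheme): score each label independently by
/// normalising its (contradiction, entailment) logits and keeping the
/// entailment probability.
fn multi_label_scores(logit_pairs: &[(f64, f64)]) -> Vec<f64> {
    logit_pairs
        .iter()
        .map(|&(contradiction, entailment)| softmax(&[contradiction, entailment])[1])
        .collect()
}

fn main() {
    // Made-up logits for two labels that both fit the input well.
    let pairs = [(-3.0, 3.0), (-4.0, 4.0)];
    let entailment: Vec<f64> = pairs.iter().map(|p| p.1).collect();
    println!("single-class: {:?}", single_class_scores(&entailment));
    println!("multi-label:  {:?}", multi_label_scores(&pairs));
}
```

Under this scheme the single-class scores always form a probability distribution over the labels, while each multi-label score is an independent probability for its own label.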

Structs§

ZeroShotClassificationConfig
Configuration for ZeroShotClassificationModel
ZeroShotClassificationModel
Model that performs zero-shot classification through natural language inference

Enums§

ZeroShotClassificationOption
Abstraction that holds one particular zero-shot classification model, for any of the supported models

Type Aliases§

ZeroShotTemplate
Template used to transform the zero-shot classification labels into a set of natural language hypotheses for natural language inference.