Struct rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel[][src]

pub struct ZeroShotClassificationModel { /* fields omitted */ }

Implementations

impl ZeroShotClassificationModel[src]

pub fn new(
    config: ZeroShotClassificationConfig
) -> Result<ZeroShotClassificationModel, RustBertError>
[src]

Build a new ZeroShotClassificationModel

Arguments

  • config - ZeroShotClassificationConfig object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)

Example

use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;

let model = ZeroShotClassificationModel::new(Default::default())?;

pub fn predict<'a, S, T>(
    &self,
    inputs: S,
    labels: T,
    template: Option<Box<dyn Fn(&str) -> String>>,
    max_length: usize
) -> Vec<Label> where
    S: AsRef<[&'a str]>,
    T: AsRef<[&'a str]>, 
[src]

Zero shot classification with 1 (and exactly 1) true label.

Arguments

  • inputs - &[&str] Array of texts to classify.
  • labels - &[&str] Possible labels for the inputs.
  • template - Option<Box<dyn Fn(&str) -> String>> closure to build label propositions. If None, will default to "This example is {}.".
  • max_length - usize Maximum sequence length for the inputs. If needed, the input sequence will be truncated before the label template.

Returns

  • Vec<Label> containing the most likely label for each input sentence.

Example

use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;

let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;

let input_sentence = "Who are you voting for in 2020?";
let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
let candidate_labels = &["politics", "public health", "economics", "sports"];

let output = sequence_classification_model.predict(
    &[input_sentence, input_sequence_2],
    candidate_labels,
    None,
    128,
);

outputs:

let output = [
    Label {
        text: "politics".to_string(),
        score: 0.959,
        id: 0,
        sentence: 0,
    },
    Label {
        text: "economics".to_string(),
        score: 0.642,
        id: 2,
        sentence: 1,
    },
]
.to_vec();

pub fn predict_multilabel<'a, S, T>(
    &self,
    inputs: S,
    labels: T,
    template: Option<Box<dyn Fn(&str) -> String>>,
    max_length: usize
) -> Vec<Vec<Label>> where
    S: AsRef<[&'a str]>,
    T: AsRef<[&'a str]>, 
[src]

Zero shot multi-label classification with 0, 1 or more true labels.

Arguments

  • inputs - &[&str] Array of texts to classify.
  • labels - &[&str] Possible labels for the inputs.
  • template - Option<Box<dyn Fn(&str) -> String>> closure to build label propositions. If None, will default to "This example is about {}.".
  • max_length - usize Maximum sequence length for the inputs. If needed, the input sequence will be truncated before the label template.

Returns

  • Vec<Vec<Label>> containing a vector of labels and their probability for each input text

Example

use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;

let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;

let input_sentence = "Who are you voting for in 2020?";
let input_sequence_2 = "The central bank is meeting today to discuss monetary policy.";
let candidate_labels = &["politics", "public health", "economics", "sports"];

let output = sequence_classification_model.predict_multilabel(
    &[input_sentence, input_sequence_2],
    candidate_labels,
    None,
    128,
);

outputs:

let output = [
    [
        Label {
            text: "politics".to_string(),
            score: 0.972,
            id: 0,
            sentence: 0,
        },
        Label {
            text: "public health".to_string(),
            score: 0.032,
            id: 1,
            sentence: 0,
        },
        Label {
            text: "economics".to_string(),
            score: 0.006,
            id: 2,
            sentence: 0,
        },
        Label {
            text: "sports".to_string(),
            score: 0.004,
            id: 3,
            sentence: 0,
        },
    ],
    [
        Label {
            text: "politics".to_string(),
            score: 0.975,
            id: 0,
            sentence: 1,
        },
        Label {
            text: "economics".to_string(),
            score: 0.852,
            id: 2,
            sentence: 1,
        },
        Label {
            text: "public health".to_string(),
            score: 0.0818,
            id: 1,
            sentence: 1,
        },
        Label {
            text: "sports".to_string(),
            score: 0.001,
            id: 3,
            sentence: 1,
        },
    ],
]
.to_vec();

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T where
    T: 'static + ?Sized
[src]

impl<T> Borrow<T> for T where
    T: ?Sized
[src]

impl<T> BorrowMut<T> for T where
    T: ?Sized
[src]

impl<T> From<T> for T[src]

impl<T> Instrument for T[src]

impl<T, U> Into<U> for T where
    U: From<T>, 
[src]

impl<T> Pointable for T

type Init = T

The type for initializers.

impl<T> Same<T> for T

type Output = T

Should always be Self

impl<T, U> TryFrom<U> for T where
    U: Into<T>, 
[src]

type Error = Infallible

The type returned in the event of a conversion error.

impl<T, U> TryInto<U> for T where
    U: TryFrom<T>, 
[src]

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

impl<V, T> VZip<V> for T where
    V: MultiLane<T>,