pub struct ZeroShotClassificationModel { /* private fields */ }
Template used to transform the zero-shot classification labels into a set of natural language hypotheses for natural language inference. For example, it transforms [positive, negative] into [This is a positive review, This is a negative review]. The function should take a &str as input and return the formatted String.

This transformation has a strong impact on the resulting classification accuracy. If no function is provided for zero-shot classification, the default templating function will be used:

fn default_template(label: &str) -> String {
    format!("This example is about {}.", label)
}
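As a sketch, a custom template closure producing the review-style hypotheses mentioned above could look like the following. The function name `review_template` is a hypothetical example; any function or closure matching `Fn(&str) -> String` can be boxed into a ZeroShotTemplate.

```rust
// Hypothetical custom template: turns a candidate label into a
// review-style hypothesis sentence for natural language inference.
fn review_template(label: &str) -> String {
    format!("This is a {} review.", label)
}

fn main() {
    let labels = ["positive", "negative"];
    for label in labels {
        // "positive" -> "This is a positive review."
        println!("{}", review_template(label));
    }
}
```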
ZeroShotClassificationModel for Zero Shot Classification
Implementations

impl ZeroShotClassificationModel
pub fn new(
    config: ZeroShotClassificationConfig,
) -> Result<ZeroShotClassificationModel, RustBertError>
Build a new ZeroShotClassificationModel.

Arguments

- config - ZeroShotClassificationConfig object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)

Example

use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
let model = ZeroShotClassificationModel::new(Default::default())?;
pub fn new_with_tokenizer(
    config: ZeroShotClassificationConfig,
    tokenizer: TokenizerOption,
) -> Result<ZeroShotClassificationModel, RustBertError>
Build a new ZeroShotClassificationModel with a provided tokenizer.

Arguments

- config - ZeroShotClassificationConfig object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
- tokenizer - TokenizerOption tokenizer to use for zero-shot classification.
Example

use rust_bert::pipelines::common::{ModelType, TokenizerOption};
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
let tokenizer = TokenizerOption::from_file(
    ModelType::Bert,
    "path/to/vocab.txt",
    None,
    false,
    None,
    None,
)?;
let model = ZeroShotClassificationModel::new_with_tokenizer(Default::default(), tokenizer)?;
pub fn get_tokenizer(&self) -> &TokenizerOption

Get a reference to the model tokenizer.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption

Get a mutable reference to the model tokenizer.
pub fn predict<'a, S, T>(
    &self,
    inputs: S,
    labels: T,
    template: Option<ZeroShotTemplate>,
    max_length: usize,
) -> Result<Vec<Label>, RustBertError>
Zero-shot classification with exactly one true label per input.
Arguments

- inputs - &[&str] Array of texts to classify.
- labels - &[&str] Possible labels for the inputs.
- template - Option<ZeroShotTemplate> (a Box<dyn Fn(&str) -> String>) closure to build label propositions. If None, defaults to "This example is about {}.".
- max_length - usize Maximum sequence length for the inputs. If needed, the input sequence will be truncated before the label template.
Returns

- Result<Vec<Label>, RustBertError> containing the most likely label for each input sentence, or an error, if any.
Example
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
let input_sentence = "Who are you voting for in 2020?";
let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
let candidate_labels = &["politics", "public health", "economics", "sports"];
let output = sequence_classification_model.predict(
&[input_sentence, input_sequence_2],
candidate_labels,
None,
128,
);
outputs:
let output = [
Label {
text: "politics".to_string(),
score: 0.959,
id: 0,
sentence: 0,
},
Label {
text: "economics".to_string(),
score: 0.642,
id: 2,
sentence: 1,
},
]
.to_vec();
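The returned Label values can be post-processed with ordinary iterator code. A minimal sketch, using a local stand-in struct that reproduces only the fields shown in the output above (not the pipeline's actual Label type):

```rust
// Stand-in mirroring the fields of the pipeline's `Label` shown above.
#[derive(Debug, Clone)]
struct Label {
    text: String,
    score: f64,
    id: i64,
    sentence: usize,
}

// Pair each input sentence index with its predicted label text and score.
fn summarize(labels: &[Label]) -> Vec<String> {
    labels
        .iter()
        .map(|l| format!("sentence {}: {} ({:.3})", l.sentence, l.text, l.score))
        .collect()
}

fn main() {
    let output = vec![
        Label { text: "politics".to_string(), score: 0.959, id: 0, sentence: 0 },
        Label { text: "economics".to_string(), score: 0.642, id: 2, sentence: 1 },
    ];
    for line in summarize(&output) {
        // sentence 0: politics (0.959)
        // sentence 1: economics (0.642)
        println!("{}", line);
    }
}
```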
pub fn predict_multilabel<'a, S, T>(
    &self,
    inputs: S,
    labels: T,
    template: Option<ZeroShotTemplate>,
    max_length: usize,
) -> Result<Vec<Vec<Label>>, RustBertError>
Zero-shot multi-label classification, where each input may have zero, one, or several true labels.
Arguments

- inputs - &[&str] Array of texts to classify.
- labels - &[&str] Possible labels for the inputs.
- template - Option<ZeroShotTemplate> (a Box<dyn Fn(&str) -> String>) closure to build label propositions. If None, defaults to "This example is about {}.".
- max_length - usize Maximum sequence length for the inputs. If needed, the input sequence will be truncated before the label template.
Returns

- Result<Vec<Vec<Label>>, RustBertError> containing a vector of labels with their probability for each input text, or an error, if any.
Example
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
let input_sentence = "Who are you voting for in 2020?";
let input_sequence_2 = "The central bank is meeting today to discuss monetary policy.";
let candidate_labels = &["politics", "public health", "economics", "sports"];
let output = sequence_classification_model.predict_multilabel(
&[input_sentence, input_sequence_2],
candidate_labels,
None,
128,
);
outputs:
let output = [
[
Label {
text: "politics".to_string(),
score: 0.972,
id: 0,
sentence: 0,
},
Label {
text: "public health".to_string(),
score: 0.032,
id: 1,
sentence: 0,
},
Label {
text: "economics".to_string(),
score: 0.006,
id: 2,
sentence: 0,
},
Label {
text: "sports".to_string(),
score: 0.004,
id: 3,
sentence: 0,
},
],
[
Label {
text: "politics".to_string(),
score: 0.975,
id: 0,
sentence: 1,
},
Label {
text: "economics".to_string(),
score: 0.852,
id: 2,
sentence: 1,
},
Label {
text: "public health".to_string(),
score: 0.0818,
id: 1,
sentence: 1,
},
Label {
text: "sports".to_string(),
score: 0.001,
id: 3,
sentence: 1,
},
],
]
.to_vec();
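Since predict_multilabel scores every label independently, a common follow-up is to keep only the labels above a confidence threshold. A minimal sketch with a local stand-in struct (the 0.5 threshold is an arbitrary illustrative choice, not a library default):

```rust
// Stand-in with only the fields needed for thresholding.
#[derive(Debug, Clone)]
struct Label {
    text: String,
    score: f64,
}

// Keep only the labels whose independent score clears the threshold.
fn above_threshold(labels: &[Label], threshold: f64) -> Vec<String> {
    labels
        .iter()
        .filter(|l| l.score > threshold)
        .map(|l| l.text.clone())
        .collect()
}

fn main() {
    // Scores for the second input sentence from the example above.
    let sentence_labels = vec![
        Label { text: "politics".to_string(), score: 0.975 },
        Label { text: "economics".to_string(), score: 0.852 },
        Label { text: "public health".to_string(), score: 0.0818 },
        Label { text: "sports".to_string(), score: 0.001 },
    ];
    // With a 0.5 threshold, "politics" and "economics" are retained.
    println!("{:?}", above_threshold(&sentence_labels, 0.5));
}
```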
Auto Trait Implementations
impl !Freeze for ZeroShotClassificationModel
impl RefUnwindSafe for ZeroShotClassificationModel
impl Send for ZeroShotClassificationModel
impl !Sync for ZeroShotClassificationModel
impl Unpin for ZeroShotClassificationModel
impl UnwindSafe for ZeroShotClassificationModel
Blanket Implementations

impl<T> BorrowMut<T> for T
where
    T: ?Sized,

fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>

fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.