Expand description
§Token classification pipeline (Named Entity Recognition, Part-of-Speech tagging)
More generic token classification pipeline, works with multiple models (Bert, Roberta)
use rust_bert::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
use rust_bert::resources::RemoteResource;
use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::ModelResource;
//Load a configuration
use rust_bert::pipelines::token_classification::LabelAggregationOption;
let config = TokenClassificationConfig::new(
ModelType::Bert,
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(BertModelResources::BERT_NER))),
RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
None, //strip_accents
None, //add_prefix_space
LabelAggregationOption::Mode
);
//Create the model
let token_classification_model = TokenClassificationModel::new(config)?;
let input = [
"My name is Amy. I live in Paris.",
"Paris is a city in France."
];
let output = token_classification_model.predict(&input, true, true); //consolidate_sub_tokens = true, return_special = true (special tokens such as [CLS] are included in the output)
Output:
use rust_tokenizers::{Mask, Offset};
[
Token {
text: String::from("[CLS]"),
score: 0.9995001554489136,
label: String::from("O"),
label_index: 0,
sentence: 0,
index: 0,
word_index: 0,
offset: None,
mask: Mask::Special,
},
Token {
text: String::from("My"),
score: 0.9980450868606567,
label: String::from("O"),
label_index: 0,
sentence: 0,
index: 1,
word_index: 1,
offset: Some(Offset { begin: 0, end: 2 }),
mask: Mask::None,
},
Token {
text: String::from("name"),
score: 0.9995062351226807,
label: String::from("O"),
label_index: 0,
sentence: 0,
index: 2,
word_index: 2,
offset: Some(Offset { begin: 3, end: 7 }),
mask: Mask::None,
},
Token {
text: String::from("is"),
score: 0.9997343420982361,
label: String::from("O"),
label_index: 0,
sentence: 0,
index: 3,
word_index: 3,
offset: Some(Offset { begin: 8, end: 10 }),
mask: Mask::None,
},
Token {
text: String::from("Amy"),
score: 0.9913727683112525,
label: String::from("I-PER"),
label_index: 4,
sentence: 0,
index: 4,
word_index: 4,
offset: Some(Offset { begin: 11, end: 14 }),
mask: Mask::None,
}, // ...
]
Structs§
- Token - Token generated by a TokenClassificationModel
- TokenClassificationConfig - Configuration for TokenClassificationModel
- TokenClassificationModel - TokenClassificationModel for Named Entity Recognition or Part-of-Speech tagging
Enums§
- LabelAggregationOption - Enum defining the label aggregation method for sub tokens
- TokenClassificationOption - Abstraction that holds one particular token sequence classifier model, for any of the supported models