Module token_classification

Source
Expand description

§Token classification pipeline (Named Entity Recognition, Part-of-Speech tagging)

A more generic token classification pipeline that works with multiple model architectures (e.g. BERT, RoBERTa).

use rust_bert::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
use rust_bert::resources::RemoteResource;
use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
use rust_bert::pipelines::common::ModelType;

use rust_bert::pipelines::common::ModelResource;
//Load a configuration
use rust_bert::pipelines::token_classification::LabelAggregationOption;
let config = TokenClassificationConfig::new(
   ModelType::Bert,
   ModelResource::Torch(Box::new(RemoteResource::from_pretrained(BertModelResources::BERT_NER))),
   RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
   RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
   None, //merges resource only relevant with ModelType::Roberta
   false, //lowercase
   None, //strip_accents
   None, //add_prefix_space
   LabelAggregationOption::Mode
);

//Create the model
let token_classification_model = TokenClassificationModel::new(config)?;

let input = [
    "My name is Amélie. I live in Paris.",
    "Paris is a city in France."
];
let output = token_classification_model.predict(&input, true, true); //consolidate_sub_tokens = true (sub-tokens are merged back into words), return_special = true (special tokens such as [CLS] are included in the output)

Output:

use rust_tokenizers::{Mask, Offset};
[
    Token {
        text: String::from("[CLS]"),
        score: 0.9995001554489136,
        label: String::from("O"),
        label_index: 0,
        sentence: 0,
        index: 0,
        word_index: 0,
        offset: None,
        mask: Mask::Special,
    },
    Token {
        text: String::from("My"),
        score: 0.9980450868606567,
        label: String::from("O"),
        label_index: 0,
        sentence: 0,
        index: 1,
        word_index: 1,
        offset: Some(Offset { begin: 0, end: 2 }),
        mask: Mask::None,
    },
    Token {
        text: String::from("name"),
        score: 0.9995062351226807,
        label: String::from("O"),
        label_index: 0,
        sentence: 0,
        index: 2,
        word_index: 2,
        offset: Some(Offset { begin: 3, end: 7 }),
        mask: Mask::None,
    },
    Token {
        text: String::from("is"),
        score: 0.9997343420982361,
        label: String::from("O"),
        label_index: 0,
        sentence: 0,
        index: 3,
        word_index: 3,
        offset: Some(Offset { begin: 8, end: 10 }),
        mask: Mask::None,
    },
    Token {
        text: String::from("Amélie"),
        score: 0.9913727683112525,
        label: String::from("I-PER"),
        label_index: 4,
        sentence: 0,
        index: 4,
        word_index: 4,
        offset: Some(Offset { begin: 11, end: 17 }),
        mask: Mask::None,
    }, // ...
]

Structs§

Token
Token generated by a TokenClassificationModel
TokenClassificationConfig
Configuration for TokenClassificationModel
TokenClassificationModel
TokenClassificationModel for Named Entity Recognition or Part-of-Speech tagging

Enums§

LabelAggregationOption
Enum defining the label aggregation method for sub tokens
TokenClassificationOption
Abstraction that holds one particular token sequence classifier model, for any of the supported models