use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset, ConsolidatableTokens, ConsolidatedTokenIterator, TokenTrait};
use std::collections::HashMap;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;
use crate::bert::{BertForTokenClassification, BertModelResources, BertConfigResources, BertVocabResources};
use crate::roberta::RobertaForTokenClassification;
use crate::distilbert::DistilBertForTokenClassification;
use crate::common::resources::{Resource, RemoteResource, download_resource};
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
use crate::electra::ElectraForTokenClassification;
use itertools::Itertools;
use std::cmp::min;
use serde::{Serialize, Deserialize};
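/// Token returned by the token classification pipeline, carrying the predicted label,
/// its confidence score, and the position of the token in the original sentence
/// (sentence index, token index, word index, character offset and mask).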
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Token {
pub text: String,
pub score: f64,
pub label: String,
pub label_index: i64,
pub sentence: usize,
pub index: u16,
pub word_index: u16,
pub offset: Option<Offset>,
pub mask: Mask,
}
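// Implementation of rust_tokenizers' `TokenTrait`, allowing `Token` values to be grouped
// back into words by the consolidation iterator used below.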
impl TokenTrait for Token {
fn offset(&self) -> Option<Offset> {
self.offset
}
fn mask(&self) -> Mask {
self.mask
}
fn as_str(&self) -> &str {
self.text.as_str()
}
}
impl ConsolidatableTokens<Token> for Vec<Token> {
fn iter_consolidate_tokens(&self) -> ConsolidatedTokenIterator<Token> {
ConsolidatedTokenIterator::new(self)
}
}
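/// Strategy used to aggregate the labels of sub-tokens when they are consolidated back into words:
/// keep the first sub-token label, the last one, the most frequent one, or apply a custom function.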
pub enum LabelAggregationOption {
First,
Last,
Mode,
Custom(Box<dyn Fn(&[Token]) -> (i64, String)>),
}
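/// Configuration for the token classification pipeline: model architecture, resources pointing to the
/// pretrained weights, configuration and vocabulary files, tokenizer casing, target device, and the
/// label aggregation strategy used when consolidating sub-tokens.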
pub struct TokenClassificationConfig {
pub model_type: ModelType,
pub model_resource: Resource,
pub config_resource: Resource,
pub vocab_resource: Resource,
pub merges_resource: Option<Resource>,
pub lower_case: bool,
pub device: Device,
pub label_aggregation_function: LabelAggregationOption,
}
impl TokenClassificationConfig {
pub fn new(model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
lower_case: bool,
label_aggregation_function: LabelAggregationOption) -> TokenClassificationConfig {
TokenClassificationConfig {
model_type,
model_resource,
config_resource,
vocab_resource,
merges_resource,
lower_case,
device: Device::cuda_if_available(),
label_aggregation_function,
}
}
}
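// The default configuration points to the remote BERT NER resources and runs on CUDA if available.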
impl Default for TokenClassificationConfig {
fn default() -> TokenClassificationConfig {
TokenClassificationConfig {
model_type: ModelType::Bert,
model_resource: Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
merges_resource: None,
lower_case: false,
device: Device::cuda_if_available(),
label_aggregation_function: LabelAggregationOption::First,
}
}
}
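/// Abstraction over the supported token classification architectures (BERT, DistilBERT, RoBERTa and Electra).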
pub enum TokenClassificationOption {
Bert(BertForTokenClassification),
DistilBert(DistilBertForTokenClassification),
Roberta(RobertaForTokenClassification),
Electra(ElectraForTokenClassification),
}
impl TokenClassificationOption {
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
match model_type {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
TokenClassificationOption::Bert(BertForTokenClassification::new(p, config))
} else {
panic!("You can only supply a BertConfig for Bert!");
}
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
TokenClassificationOption::DistilBert(DistilBertForTokenClassification::new(p, config))
} else {
panic!("You can only supply a DistilBertConfig for DistilBert!");
}
}
ModelType::Roberta => {
if let ConfigOption::Bert(config) = config {
TokenClassificationOption::Roberta(RobertaForTokenClassification::new(p, config))
} else {
panic!("You can only supply a BertConfig for Roberta!");
}
}
ModelType::Electra => {
if let ConfigOption::Electra(config) = config {
TokenClassificationOption::Electra(ElectraForTokenClassification::new(p, config))
} else {
panic!("You can only supply a BertConfig for Roberta!");
}
}
}
}
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::DistilBert(_) => ModelType::DistilBert,
Self::Electra(_) => ModelType::Electra
}
}
fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
match *self {
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
Self::Roberta(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::Electra(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
}
}
}
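/// End-to-end token classification (e.g. named entity recognition) pipeline: it owns the tokenizer,
/// the underlying classification model, the index-to-label mapping and the variable store holding the weights.
///
/// A minimal usage sketch, assuming this module is exposed as `rust_bert::pipelines::token_classification`
/// and that the default remote resources are reachable:
///
/// ```no_run
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
///
/// let model = TokenClassificationModel::new(TokenClassificationConfig::default())?;
/// let input = ["My name is Amy. I live in Paris."];
/// let tokens = model.predict(&input, true, false);
/// # Ok(())
/// # }
/// ```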
pub struct TokenClassificationModel {
tokenizer: TokenizerOption,
token_sequence_classifier: TokenClassificationOption,
label_mapping: HashMap<i64, String>,
var_store: VarStore,
label_aggregation_function: LabelAggregationOption,
}
impl TokenClassificationModel {
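/// Builds a new `TokenClassificationModel` from a `TokenClassificationConfig`: resolves the configuration,
/// vocabulary, optional merges and weight resources, instantiates the tokenizer and model, and loads the
/// pretrained weights into the variable store.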
pub fn new(config: TokenClassificationConfig) -> failure::Fallible<TokenClassificationModel> {
let config_path = download_resource(&config.config_resource)?;
let vocab_path = download_resource(&config.vocab_resource)?;
let weights_path = download_resource(&config.model_resource)?;
let merges_path = match &config.merges_resource {
Some(merges_resource) => Some(download_resource(merges_resource)?),
None => None,
};
let device = config.device;
let label_aggregation_function = config.label_aggregation_function;
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let token_sequence_classifier = TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let label_mapping = model_config.get_label_mapping();
var_store.load(weights_path)?;
Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store, label_aggregation_function })
}
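// Tokenizes and encodes the input batch, pads every sequence to the longest sequence in the batch,
// and stacks the results into a single input tensor on the model's device.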
fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(
input,
128,
&TruncationStrategy::LongestFirst,
0,
);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&input))
.collect::<Vec<_>>();
(tokenized_input, Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()))
}
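/// Runs token classification on a batch of sentences and returns one `Token` per predicted position.
/// `consolidate_sub_tokens` merges word-piece sub-tokens back into full words using the configured
/// label aggregation strategy; `return_special` controls whether special tokens (e.g. [CLS], [SEP])
/// are kept in the output.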
pub fn predict(&self, input: &[&str], consolidate_sub_tokens: bool, return_special: bool) -> Vec<Token> {
let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
let (output, _, _) = no_grad(|| {
self.token_sequence_classifier
.forward_t(Some(input_tensor.copy()),
None,
None,
None,
None,
false)
});
let output = output.detach().to(Device::Cpu);
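// Normalize the logits into per-token class probabilities (softmax over the label dimension).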
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
let labels_idx = &score.argmax(-1, true);
let mut tokens: Vec<Token> = vec!();
for sentence_idx in 0..labels_idx.size()[0] {
let labels = labels_idx.get(sentence_idx);
let sentence_tokens = &tokenized_input[sentence_idx as usize];
let original_chars = input[sentence_idx as usize].chars().collect_vec();
let mut word_idx: u16 = 0;
for position_idx in 0..sentence_tokens.token_ids.len() {
let mask = sentence_tokens.mask[position_idx];
if (mask == Mask::Special) && !return_special {
continue;
}
if mask != Mask::Continuation {
word_idx += 1;
}
let token = self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx - 1);
tokens.push(token);
}
}
if consolidate_sub_tokens {
self.consolidate_tokens(&mut tokens, &self.label_aggregation_function);
}
tokens
}
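// Builds a single output `Token` from the model prediction at a given position: recovers the token text
// from the original sentence offsets when available (falling back to tokenizer decoding otherwise),
// and attaches the predicted label and its probability.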
fn decode_token(&self, original_sentence_chars: &[char], sentence_tokens: &TokenizedInput, input_tensor: &Tensor,
labels: &Tensor, score: &Tensor, sentence_idx: i64, position_idx: i64, word_index: u16) -> Token {
let label_id = labels.int64_value(&[position_idx]);
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
let offsets = &sentence_tokens.token_offsets[position_idx as usize];
let text = match offsets {
None => match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
},
Some(offsets) => {
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
let end_char = min(end_char, original_sentence_chars.len());
let text = original_sentence_chars[start_char..end_char].iter().collect();
text
}
};
Token {
text,
score: score.double_value(&[sentence_idx, position_idx, label_id]),
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
label_index: label_id,
sentence: sentence_idx as usize,
index: position_idx as u16,
word_index,
offset: offsets.to_owned(),
mask: sentence_tokens.mask[position_idx as usize],
}
}
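// Merges consecutive sub-tokens belonging to the same word into a single `Token`: the texts are
// concatenated, the offsets spliced together, the label chosen by the aggregation strategy, and the
// score computed as the product of the sub-token scores (using 1 - score for sub-tokens that disagree
// with the aggregated label).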
fn consolidate_tokens(&self, tokens: &mut Vec<Token>, label_aggregation_function: &LabelAggregationOption) {
let mut tokens_to_replace = vec!();
let mut token_iter = tokens.iter_consolidate_tokens();
let mut cursor = 0;
while let Some(sub_tokens) = token_iter.next() {
if sub_tokens.len() > 1 {
let (label_index, label) = self.consolidate_labels(sub_tokens, label_aggregation_function);
let sentence = sub_tokens[0].sentence;
let index = sub_tokens[0].index;
let word_index = sub_tokens[0].word_index;
let offset_start = sub_tokens.first().unwrap().offset.map(|offset| offset.begin);
let offset_end = sub_tokens.last().unwrap().offset.map(|offset| offset.end);
let offset = match (offset_start, offset_end) {
(Some(begin), Some(end)) => Some(Offset::new(begin, end)),
_ => None,
};
let mut text = String::new();
let mut score = 1f64;
for current_sub_token in sub_tokens.iter() {
text.push_str(current_sub_token.text.as_str());
score *= if current_sub_token.label_index == label_index {
current_sub_token.score
} else {
1.0 - current_sub_token.score
};
}
let token = Token {
text,
score,
label,
label_index,
sentence,
index,
word_index,
offset,
mask: Default::default(),
};
tokens_to_replace.push(((cursor, cursor + sub_tokens.len()), token));
}
cursor += sub_tokens.len();
}
for ((start, end), token) in tokens_to_replace.into_iter().rev() {
tokens.splice(start..end, std::iter::once(token));
}
}
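// Picks the (label index, label) pair for a consolidated word according to the aggregation strategy:
// first or last sub-token label, the most frequent label, or a user-provided function.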
fn consolidate_labels(&self, tokens: &[Token], aggregation: &LabelAggregationOption) -> (i64, String) {
match aggregation {
LabelAggregationOption::First => {
let token = tokens.first().unwrap();
(token.label_index, token.label.clone())
}
LabelAggregationOption::Last => {
let token = tokens.last().unwrap();
(token.label_index, token.label.clone())
}
LabelAggregationOption::Mode => {
let counts = tokens
.iter()
.fold(
HashMap::new(),
|mut m, c| {
*m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
m
},
);
counts
.into_iter()
.max_by(|a, b| a.1.cmp(&b.1))
.map(|((label_index, label), _)| (label_index, label.to_owned()))
.unwrap()
}
LabelAggregationOption::Custom(function) => function(tokens)
}
}
}