use crate::albert::AlbertForQuestionAnswering;
use crate::bert::BertForQuestionAnswering;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::{
DistilBertConfigResources, DistilBertForQuestionAnswering, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::roberta::RobertaForQuestionAnswering;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;
use rust_tokenizers::tokenization_utils::truncate_sequences;
use rust_tokenizers::{TokenizedInput, TruncationStrategy};
use std::borrow::Borrow;
use std::cmp::min;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Tensor};
/// A question paired with the context passage the answer should be extracted from.
pub struct QaInput {
    /// Question to answer
    pub question: String,
    /// Passage expected to contain the answer
    pub context: String,
}
/// Pre-processed question/context pair: the context is split into
/// whitespace-delimited tokens, together with a per-character map from the
/// raw context back to the token each character belongs to.
#[derive(Debug)]
struct QaExample {
    /// Raw question text
    pub question: String,
    /// Raw context text
    pub context: String,
    /// Context split on whitespace
    pub doc_tokens: Vec<String>,
    /// For every char of `context`, the index in `doc_tokens` of its word
    pub char_to_word_offset: Vec<i64>,
}
/// One model-ready window over a `QaExample` (long contexts are encoded as
/// several overlapping windows, each becoming one feature).
#[derive(Debug)]
struct QaFeature {
    /// Token ids, padded to the maximum sequence length
    pub input_ids: Vec<i64>,
    /// 1 for real tokens, 0 for padding
    pub attention_mask: Vec<i64>,
    /// Maps a position in `input_ids` back to the original word index in the example
    pub token_to_orig_map: HashMap<i64, i64>,
    /// 1 for positions that cannot be part of an answer (question and separator tokens)
    pub p_mask: Vec<i8>,
    /// Index of the originating `QaExample` in the input batch
    pub example_index: i64,
}
/// An extracted answer with its score and character span in the source context.
#[derive(Debug, Clone)]
pub struct Answer {
    /// Confidence score for this answer
    pub score: f64,
    /// Start character position of the answer in the context
    pub start: usize,
    /// End character position of the answer in the context
    pub end: usize,
    /// Extracted answer text
    pub answer: String,
}
impl PartialEq for Answer {
    /// Equality deliberately ignores `score`: two answers are the same when
    /// they cover the same span and carry the same text, so near-duplicate
    /// candidates from overlapping windows compare equal.
    fn eq(&self, other: &Self) -> bool {
        let same_span = self.start == other.start && self.end == other.end;
        same_span && self.answer == other.answer
    }
}
/// Removes duplicate entries from `vector` in place, keeping the first
/// occurrence of each value and preserving relative order. Only `PartialEq`
/// is required of `T`, so the scan is quadratic in the worst case.
fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
    let mut seen: Vec<T> = Vec::new();
    vector.retain(|candidate| {
        let is_new = !seen.iter().any(|kept| kept == candidate);
        if is_new {
            seen.push(candidate.clone());
        }
        is_new
    });
    vector
}
impl QaExample {
    /// Builds a `QaExample` from raw question and context strings, splitting
    /// the context into whitespace-delimited tokens and recording, for every
    /// character of the context, the token it belongs to.
    pub fn new(question: &str, context: &str) -> QaExample {
        let (doc_tokens, char_to_word_offset) = QaExample::split_context(context);
        QaExample {
            question: question.to_owned(),
            context: context.to_owned(),
            doc_tokens,
            char_to_word_offset,
        }
    }

    /// Splits `context` on whitespace. Returns the list of tokens and, for
    /// each character of `context` (whitespace included), the index of the
    /// token in effect at that character.
    fn split_context(context: &str) -> (Vec<String>, Vec<i64>) {
        let mut tokens: Vec<String> = Vec::new();
        let mut offsets: Vec<i64> = Vec::new();
        let mut word = String::new();
        for ch in context.chars() {
            // The offset is recorded before the current word is (possibly)
            // flushed, so a whitespace char maps to the word preceding it.
            offsets.push(tokens.len() as i64);
            if QaExample::is_whitespace(&ch) {
                if !word.is_empty() {
                    tokens.push(std::mem::take(&mut word));
                }
            } else {
                word.push(ch);
            }
        }
        // Flush a trailing word that was not terminated by whitespace.
        if !word.is_empty() {
            tokens.push(word);
        }
        (tokens, offsets)
    }

    /// Returns true for the whitespace characters used as token boundaries,
    /// including the narrow no-break space (U+202F).
    fn is_whitespace(character: &char) -> bool {
        matches!(*character, ' ' | '\t' | '\r' | '\n' | '\u{202F}')
    }
}
/// Configuration for a `QuestionAnsweringModel`: resource locations, target
/// device and tokenizer behavior.
pub struct QuestionAnsweringConfig {
    /// Model weights resource (local or remote)
    pub model_resource: Resource,
    /// Model configuration resource (local or remote)
    pub config_resource: Resource,
    /// Tokenizer vocabulary resource (local or remote)
    pub vocab_resource: Resource,
    /// Optional merges resource (required by byte-level BPE tokenizers such as Roberta)
    pub merges_resource: Option<Resource>,
    /// Device to run the model on
    pub device: Device,
    /// Model architecture to instantiate
    pub model_type: ModelType,
    /// Whether the tokenizer should lower-case its input
    pub lower_case: bool,
}
impl QuestionAnsweringConfig {
    /// Creates a new `QuestionAnsweringConfig`, targeting CUDA when available
    /// and the CPU otherwise.
    ///
    /// # Arguments
    /// * `model_type` - architecture of the model to load
    /// * `model_resource` - weights resource
    /// * `config_resource` - model configuration resource
    /// * `vocab_resource` - tokenizer vocabulary resource
    /// * `merges_resource` - optional merges resource (byte-level BPE tokenizers only)
    /// * `lower_case` - whether the tokenizer lower-cases its input
    pub fn new(
        model_type: ModelType,
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        merges_resource: Option<Resource>,
        lower_case: bool,
    ) -> QuestionAnsweringConfig {
        let device = Device::cuda_if_available();
        QuestionAnsweringConfig {
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            device,
            model_type,
            lower_case,
        }
    }
}
impl Default for QuestionAnsweringConfig {
    /// Default configuration: a remote DistilBERT model fine-tuned on SQuAD,
    /// running on CUDA when available, without lower-casing.
    fn default() -> QuestionAnsweringConfig {
        let model_resource = Resource::Remote(RemoteResource::from_pretrained(
            DistilBertModelResources::DISTIL_BERT_SQUAD,
        ));
        let config_resource = Resource::Remote(RemoteResource::from_pretrained(
            DistilBertConfigResources::DISTIL_BERT_SQUAD,
        ));
        let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
            DistilBertVocabResources::DISTIL_BERT_SQUAD,
        ));
        QuestionAnsweringConfig {
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource: None,
            device: Device::cuda_if_available(),
            model_type: ModelType::DistilBert,
            lower_case: false,
        }
    }
}
/// Abstraction over the supported question-answering model architectures.
/// XLM-Roberta reuses the Roberta model body with a different tokenizer.
pub enum QuestionAnsweringOption {
    /// BERT-based question answering model
    Bert(BertForQuestionAnswering),
    /// DistilBERT-based question answering model
    DistilBert(DistilBertForQuestionAnswering),
    /// RoBERTa-based question answering model
    Roberta(RobertaForQuestionAnswering),
    /// XLM-RoBERTa-based question answering model (RoBERTa architecture)
    XLMRoberta(RobertaForQuestionAnswering),
    /// ALBERT-based question answering model
    Albert(AlbertForQuestionAnswering),
}
impl QuestionAnsweringOption {
    /// Instantiates the model variant matching `model_type` under the variable
    /// store path `p`, checking that `config` carries the matching configuration
    /// variant.
    ///
    /// # Panics
    /// Panics when `config` does not match `model_type`, or when the requested
    /// architecture has no question-answering head implemented (Electra,
    /// Marian, T5).
    pub fn new<'p, P>(model_type: ModelType, p: P, config: &ConfigOption) -> Self
    where
        P: Borrow<nn::Path<'p>>,
    {
        match model_type {
            ModelType::Bert => {
                if let ConfigOption::Bert(config) = config {
                    QuestionAnsweringOption::Bert(BertForQuestionAnswering::new(p, config))
                } else {
                    panic!("You can only supply a BertConfig for Bert!");
                }
            }
            ModelType::DistilBert => {
                if let ConfigOption::DistilBert(config) = config {
                    QuestionAnsweringOption::DistilBert(DistilBertForQuestionAnswering::new(
                        p, config,
                    ))
                } else {
                    panic!("You can only supply a DistilBertConfig for DistilBert!");
                }
            }
            // Roberta and XLM-Roberta share the BERT configuration format.
            ModelType::Roberta => {
                if let ConfigOption::Bert(config) = config {
                    QuestionAnsweringOption::Roberta(RobertaForQuestionAnswering::new(p, config))
                } else {
                    panic!("You can only supply a BertConfig for Roberta!");
                }
            }
            ModelType::XLMRoberta => {
                if let ConfigOption::Bert(config) = config {
                    QuestionAnsweringOption::XLMRoberta(RobertaForQuestionAnswering::new(p, config))
                } else {
                    // Fixed copy-paste: this arm previously reported "for Roberta!".
                    panic!("You can only supply a BertConfig for XLMRoberta!");
                }
            }
            ModelType::Albert => {
                if let ConfigOption::Albert(config) = config {
                    QuestionAnsweringOption::Albert(AlbertForQuestionAnswering::new(p, config))
                } else {
                    panic!("You can only supply an AlbertConfig for Albert!");
                }
            }
            ModelType::Electra => {
                panic!("QuestionAnswering not implemented for Electra!");
            }
            ModelType::Marian => {
                panic!("QuestionAnswering not implemented for Marian!");
            }
            ModelType::T5 => {
                panic!("QuestionAnswering not implemented for T5!");
            }
        }
    }

    /// Returns the `ModelType` of the wrapped model.
    pub fn model_type(&self) -> ModelType {
        match *self {
            Self::Bert(_) => ModelType::Bert,
            Self::Roberta(_) => ModelType::Roberta,
            Self::XLMRoberta(_) => ModelType::XLMRoberta,
            Self::DistilBert(_) => ModelType::DistilBert,
            Self::Albert(_) => ModelType::Albert,
        }
    }

    /// Forward pass dispatching to the wrapped model.
    ///
    /// Returns the `(start_logits, end_logits)` pair produced by the underlying
    /// model's question-answering head. `train` enables dropout.
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Tensor) {
        match *self {
            Self::Bert(ref model) => {
                let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);
                (outputs.0, outputs.1)
            }
            Self::DistilBert(ref model) => {
                let outputs = model
                    .forward_t(input_ids, mask, input_embeds, train)
                    .expect("Error in distilbert forward_t");
                (outputs.0, outputs.1)
            }
            Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
                let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);
                (outputs.0, outputs.1)
            }
            Self::Albert(ref model) => {
                let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);
                (outputs.0, outputs.1)
            }
        }
    }
}
/// Extractive question answering pipeline: tokenization, sliding-window
/// encoding, batched inference and answer span decoding.
pub struct QuestionAnsweringModel {
    /// Tokenizer matching the loaded model architecture
    tokenizer: TokenizerOption,
    /// Padding token id used to pad encoded windows
    pad_idx: i64,
    /// Separator token id, used to build the answer-position mask
    sep_idx: i64,
    /// Maximum encoded sequence length (question + context + special tokens)
    max_seq_len: usize,
    /// Overlap stride between successive context windows
    doc_stride: usize,
    /// Maximum number of question tokens kept after truncation
    max_query_length: usize,
    /// Maximum answer span length (in tokens) considered during decoding
    max_answer_len: usize,
    // NOTE(review): despite the name, this holds any supported architecture
    // (Bert, DistilBert, Roberta, XLMRoberta, Albert), not only DistilBert.
    distilbert_qa: QuestionAnsweringOption,
    /// Variable store holding the model weights; also determines the device
    var_store: VarStore,
}
impl QuestionAnsweringModel {
    /// Builds a `QuestionAnsweringModel` from a `QuestionAnsweringConfig`.
    ///
    /// Downloads (or resolves from the local cache) the configuration,
    /// vocabulary, optional merges and weight resources, creates the tokenizer
    /// and loads the weights into a `VarStore` on the configured device.
    ///
    /// # Errors
    /// Fails if a resource cannot be retrieved or the weights cannot be loaded.
    pub fn new(
        question_answering_config: QuestionAnsweringConfig,
    ) -> failure::Fallible<QuestionAnsweringModel> {
        let config_path = download_resource(&question_answering_config.config_resource)?;
        let vocab_path = download_resource(&question_answering_config.vocab_resource)?;
        let weights_path = download_resource(&question_answering_config.model_resource)?;
        // Merges are only provided for byte-level BPE tokenizers (e.g. Roberta).
        let merges_path = if let Some(merges_resource) = &question_answering_config.merges_resource
        {
            Some(download_resource(merges_resource).expect("Failure downloading resource"))
        } else {
            None
        };
        let device = question_answering_config.device;
        let tokenizer = TokenizerOption::from_file(
            question_answering_config.model_type,
            vocab_path.to_str().unwrap(),
            merges_path.map(|path| path.to_str().unwrap()),
            question_answering_config.lower_case,
        );
        // PAD and SEP ids are mandatory for padding and answer-mask generation.
        let pad_idx = tokenizer
            .get_pad_id()
            .expect("The Tokenizer used for Question Answering should contain a PAD id");
        let sep_idx = tokenizer
            .get_sep_id()
            .expect("The Tokenizer used for Question Answering should contain a SEP id");
        let mut var_store = VarStore::new(device);
        let mut model_config =
            ConfigOption::from_file(question_answering_config.model_type, config_path);
        // Force learned position embeddings for DistilBert: the SQuAD checkpoint
        // loaded here stores trained position embeddings, not sinusoidal ones.
        match model_config {
            ConfigOption::DistilBert(ref mut config) => {
                config.sinusoidal_pos_embds = false;
            }
            _ => (),
        };
        let qa_model = QuestionAnsweringOption::new(
            question_answering_config.model_type,
            &var_store.root(),
            &model_config,
        );
        var_store.load(weights_path)?;
        Ok(QuestionAnsweringModel {
            tokenizer,
            pad_idx,
            sep_idx,
            // Hard-coded encoding defaults (sequence length, window stride,
            // query budget and maximum answer span, all in tokens).
            max_seq_len: 384,
            doc_stride: 128,
            max_query_length: 64,
            max_answer_len: 15,
            distilbert_qa: qa_model,
            var_store,
        })
    }

    /// Extracts answers for each input from its context.
    ///
    /// Returns, for every input, up to `top_k` answers. Long contexts are split
    /// into overlapping windows; inference runs in batches of `batch_size`
    /// windows with gradients disabled.
    pub fn predict(
        &self,
        qa_inputs: &[QaInput],
        top_k: i64,
        batch_size: usize,
    ) -> Vec<Vec<Answer>> {
        let examples: Vec<QaExample> = qa_inputs
            .iter()
            .map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context))
            .collect();
        // One example may yield several features (overlapping context windows).
        let features: Vec<QaFeature> = examples
            .iter()
            .enumerate()
            .map(|(example_index, qa_example)| {
                self.generate_features(
                    &qa_example,
                    self.max_seq_len,
                    self.doc_stride,
                    self.max_query_length,
                    example_index as i64,
                )
            })
            .flatten()
            .collect();
        let mut example_top_k_answers_map: HashMap<usize, Vec<Answer>> = HashMap::new();
        let mut start = 0usize;
        let len_features = features.len();
        // Process features in batches of at most `batch_size`.
        while start < len_features {
            let end = start + min(len_features - start, batch_size);
            let batch_features = &features[start..end];
            let mut input_ids = Vec::with_capacity(batch_features.len());
            let mut attention_masks = Vec::with_capacity(batch_features.len());
            no_grad(|| {
                for feature in batch_features {
                    input_ids.push(Tensor::of_slice(&feature.input_ids));
                    attention_masks.push(Tensor::of_slice(&feature.attention_mask));
                }
                let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
                let attention_masks =
                    Tensor::stack(&attention_masks, 0).to(self.var_store.device());
                let (start_logits, end_logits) = self.distilbert_qa.forward_t(
                    Some(input_ids),
                    Some(attention_masks),
                    None,
                    false,
                );
                let start_logits = start_logits.detach();
                let end_logits = end_logits.detach();
                // One entry per feature: (example id, exclusive upper bound of
                // that feature's index in the batch). Iterating these groups
                // consecutive features belonging to the same example.
                let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
                    .iter()
                    .enumerate()
                    .map(|(feature_index, feature)| {
                        (feature.example_index as usize, feature_index as i64 + 1)
                    })
                    .collect();
                let mut feature_id_start = 0;
                for (example_id, max_feature_id) in example_index_to_feature_end_position {
                    let mut answers: Vec<Answer> = vec![];
                    let example = &examples[example_id];
                    for feature_idx in feature_id_start..max_feature_id {
                        let feature = &batch_features[feature_idx as usize];
                        let start = start_logits.get(feature_idx);
                        let end = end_logits.get(feature_idx);
                        // Invert p_mask: 1 where an answer may start/end, 0 elsewhere.
                        let p_mask = (Tensor::of_slice(&feature.p_mask) - 1)
                            .abs()
                            .to_device(start.device());
                        // Softmax over the sequence, zeroing forbidden positions.
                        let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
                        let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
                        let (starts, ends, scores) = self.decode(&start, &end, top_k);
                        for idx in 0..starts.len() {
                            // Map token positions back to word indices, then to
                            // character offsets in the original context.
                            let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
                            let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
                            let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
                            let start = example
                                .char_to_word_offset
                                .iter()
                                .position(|&v| v as usize == start_pos)
                                .unwrap();
                            let end = example
                                .char_to_word_offset
                                .iter()
                                .rposition(|&v| v as usize == end_pos)
                                .unwrap();
                            answers.push(Answer {
                                score: scores[idx],
                                start,
                                end,
                                answer,
                            });
                        }
                    }
                    feature_id_start = max_feature_id;
                    let example_answers = example_top_k_answers_map
                        .entry(example_id)
                        .or_insert(vec![]);
                    example_answers.extend(answers);
                }
            });
            start = end;
        }
        // Deduplicate, sort by descending score, and keep top_k per example.
        let mut all_answers = vec![];
        for example_id in 0..examples.len() {
            if let Some(answers) = example_top_k_answers_map.get_mut(&example_id) {
                remove_duplicates(answers).sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
                all_answers.push(answers[..min(answers.len(), top_k as usize)].to_vec());
            } else {
                all_answers.push(vec![]);
            }
        }
        all_answers
    }

    /// Selects the `top_k` highest-scoring (start, end) spans from start/end
    /// probability vectors, restricted to spans with start <= end and length
    /// at most `max_answer_len`.
    fn decode(&self, start: &Tensor, end: &Tensor, top_k: i64) -> (Vec<i64>, Vec<i64>, Vec<f64>) {
        // Outer product: one candidate score per (start, end) combination.
        let outer = start.unsqueeze(-1).matmul(&end.unsqueeze(0));
        let start_dim = start.size()[0];
        let end_dim = end.size()[0];
        // Keep the band start <= end <= start + max_answer_len - 1, flattened.
        let candidates = outer
            .triu(0)
            .tril(self.max_answer_len as i64 - 1)
            .flatten(0, -1);
        let idx_sort = if top_k == 1 {
            candidates.argmax(0, true)
        } else if candidates.size()[0] < top_k {
            candidates.argsort(0, true)
        } else {
            candidates.argsort(0, true).slice(0, 0, top_k, 1)
        };
        let mut start: Vec<i64> = vec![];
        let mut end: Vec<i64> = vec![];
        let mut scores: Vec<f64> = vec![];
        for flat_index_position in 0..idx_sort.size()[0] {
            let flat_index = idx_sort.int64_value(&[flat_index_position]);
            scores.push(candidates.double_value(&[flat_index]));
            // NOTE(review): for a row-major (start_dim, end_dim) matrix the row
            // is formally flat_index / end_dim; dividing by start_dim only works
            // because both logit vectors cover the same sequence length here.
            start.push(flat_index / start_dim);
            end.push(flat_index % end_dim);
        }
        (start, end, scores)
    }

    /// Converts a `QaExample` into one or more model-ready `QaFeature`s using a
    /// sliding window of stride `doc_stride` over the sub-tokenized context.
    fn generate_features(
        &self,
        qa_example: &QaExample,
        max_seq_length: usize,
        doc_stride: usize,
        max_query_length: usize,
        example_index: i64,
    ) -> Vec<QaFeature> {
        // Sub-tokenize each whitespace word, keeping each sub-token's word index.
        let mut tok_to_orig_index: Vec<i64> = vec![];
        let mut all_doc_tokens: Vec<String> = vec![];
        for (idx, token) in qa_example.doc_tokens.iter().enumerate() {
            let sub_tokens = self.tokenizer.tokenize(token);
            for sub_token in sub_tokens.into_iter() {
                all_doc_tokens.push(sub_token);
                tok_to_orig_index.push(idx as i64);
            }
        }
        let truncated_query = self.prepare_query(&qa_example.question, max_query_length);
        // Number of special tokens the tokenizer adds around a single sequence,
        // measured by encoding an empty input.
        // NOTE(review): the Roberta arm adds 1 on top of the empty-input length —
        // presumably to account for the extra separator in its pair encoding;
        // confirm against the tokenizer implementation.
        let sequence_added_tokens = match self.tokenizer {
            TokenizerOption::Roberta(_) => {
                self.tokenizer
                    .build_input_with_special_tokens(
                        vec![],
                        None,
                        vec![],
                        None,
                        vec![],
                        None,
                        vec![],
                        None,
                    )
                    .0
                    .len()
                    + 1
            }
            _ => self
                .tokenizer
                .build_input_with_special_tokens(
                    vec![],
                    None,
                    vec![],
                    None,
                    vec![],
                    None,
                    vec![],
                    None,
                )
                .0
                .len(),
        };
        // Number of special tokens added around a sequence pair.
        let sequence_pair_added_tokens = self
            .tokenizer
            .build_input_with_special_tokens(
                vec![],
                Some(vec![]),
                vec![],
                Some(vec![]),
                vec![],
                Some(vec![]),
                vec![],
                Some(vec![]),
            )
            .0
            .len();
        let mut spans: Vec<QaFeature> = vec![];
        let mut remaining_tokens = self.tokenizer.convert_tokens_to_ids(&all_doc_tokens);
        // Slide the window forward by `doc_stride` until the context is covered.
        while (spans.len() * doc_stride as usize) < all_doc_tokens.len() {
            let (encoded_span, attention_mask) = self.encode_qa_pair(
                &truncated_query,
                &remaining_tokens,
                max_seq_length,
                doc_stride,
                sequence_pair_added_tokens,
            );
            // Number of context tokens that fit into this window.
            let paragraph_len = min(
                all_doc_tokens.len() - spans.len() * doc_stride,
                max_seq_length - truncated_query.len() - sequence_pair_added_tokens,
            );
            let mut token_to_orig_map = HashMap::new();
            for i in 0..paragraph_len {
                // Position in the encoded input -> original word index.
                let index = truncated_query.len() + sequence_added_tokens + i;
                token_to_orig_map.insert(
                    index as i64,
                    tok_to_orig_index[spans.len() * doc_stride + i] as i64,
                );
            }
            let p_mask = self.get_mask(&encoded_span);
            let qa_feature = QaFeature {
                input_ids: encoded_span.token_ids,
                attention_mask,
                token_to_orig_map,
                p_mask,
                example_index,
            };
            spans.push(qa_feature);
            // Stop once the remaining context fit without truncation.
            if encoded_span.num_truncated_tokens == 0 {
                break;
            }
            remaining_tokens = encoded_span.overflowing_tokens
        }
        spans
    }

    /// Tokenizes the question and truncates it to at most `max_query_length`
    /// token ids (truncating from the end).
    fn prepare_query(&self, query: &str, max_query_length: usize) -> Vec<i64> {
        let truncated_query = self
            .tokenizer
            .convert_tokens_to_ids(&self.tokenizer.tokenize(&query));
        let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize {
            truncated_query.len() - max_query_length
        } else {
            0
        };
        let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(
            truncated_query,
            None,
            vec![],
            None,
            vec![],
            None,
            vec![],
            None,
            num_query_tokens_to_remove,
            &TruncationStrategy::OnlyFirst,
            0,
        )
        .unwrap();
        truncated_query
    }

    /// Encodes one (query, context window) pair: truncates the context to fit
    /// `max_seq_length`, adds special tokens and pads to the full length.
    /// Returns the encoded input together with its attention mask.
    fn encode_qa_pair(
        &self,
        truncated_query: &Vec<i64>,
        spans_token_ids: &Vec<i64>,
        max_seq_length: usize,
        doc_stride: usize,
        sequence_pair_added_tokens: usize,
    ) -> (TokenizedInput, Vec<i64>) {
        let len_1 = truncated_query.len();
        let len_2 = spans_token_ids.len();
        let total_len = len_1 + len_2 + sequence_pair_added_tokens;
        let num_truncated_tokens = if total_len > max_seq_length {
            total_len - max_seq_length
        } else {
            0
        };
        // Truncate only the context; the final argument controls how many of the
        // truncated tokens are carried over into `overflowing_tokens` so that
        // the next window overlaps the current one by `doc_stride`.
        let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) =
            truncate_sequences(
                truncated_query.clone(),
                Some(spans_token_ids.clone()),
                vec![],
                None,
                vec![],
                None,
                vec![],
                None,
                num_truncated_tokens,
                &TruncationStrategy::OnlySecond,
                max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens,
            )
            .unwrap();
        let (
            mut token_ids,
            mut segment_ids,
            special_tokens_mask,
            mut token_offsets,
            mut reference_offsets,
            mut mask,
        ) = self.tokenizer.build_input_with_special_tokens(
            truncated_query,
            truncated_context,
            vec![],
            None,
            vec![],
            None,
            vec![],
            None,
        );
        // Attention mask: 1 for real tokens, 0 for the padding appended below.
        let mut attention_mask = vec![1; token_ids.len()];
        if token_ids.len() < max_seq_length {
            token_ids.append(&mut vec![self.pad_idx; max_seq_length - token_ids.len()]);
            segment_ids.append(&mut vec![0; max_seq_length - segment_ids.len()]);
            attention_mask.append(&mut vec![0; max_seq_length - attention_mask.len()]);
            token_offsets.append(&mut vec![None; max_seq_length - token_offsets.len()]);
            // NOTE(review): `token_offsets` was already padded to max_seq_length
            // on the previous line, so this computes a length of 0 and pads
            // `reference_offsets` not at all — likely intended to use
            // `reference_offsets.len()` instead; confirm before changing.
            reference_offsets.append(&mut vec![vec!(); max_seq_length - token_offsets.len()]);
            mask.append(&mut vec![Mask::Special; max_seq_length - mask.len()]);
        }
        (
            TokenizedInput {
                token_ids,
                segment_ids,
                special_tokens_mask,
                overflowing_tokens,
                num_truncated_tokens,
                token_offsets,
                reference_offsets,
                mask,
            },
            attention_mask,
        )
    }

    /// Builds the answer-position mask for an encoded span: 1 marks positions
    /// that cannot be part of an answer (question tokens, via segment id 0,
    /// and every SEP token), 0 marks allowed context positions.
    fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {
        let sep_indices: Vec<usize> = encoded_span
            .token_ids
            .iter()
            .enumerate()
            .filter(|(_, &value)| value == self.sep_idx)
            .map(|(position, _)| position)
            .collect();
        // segment id 0 (question) -> 1 (masked); segment id >= 1 (context) -> 0.
        let mut p_mask: Vec<i8> = encoded_span
            .segment_ids
            .iter()
            .map(|v| min(v, &1i8))
            .map(|&v| 1i8 - v)
            .collect();
        // Separator tokens are never valid answer positions.
        for sep_position in sep_indices {
            p_mask[sep_position] = 1;
        }
        p_mask
    }
}
/// Parses a SQuAD-formatted JSON file into a flat list of question/context pairs.
///
/// Every question in every paragraph is paired with its paragraph's context,
/// so a context is repeated once per question attached to it.
///
/// # Panics
/// Panics with a descriptive message if the file cannot be opened or does not
/// follow the SQuAD schema (previously several lookups failed with a bare
/// `unwrap`, hiding which field was missing).
pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
    let file = fs::File::open(file_path).expect("unable to open file");
    let json: serde_json::Value =
        serde_json::from_reader(file).expect("JSON not properly formatted");
    let data = json
        .get("data")
        .expect("SQuAD file does not contain data field")
        .as_array()
        .expect("Data array not properly formatted");
    let mut qa_inputs: Vec<QaInput> = Vec::with_capacity(data.len());
    for qa_input in data.iter() {
        let qa_input = qa_input.as_object().expect("Data entry is not an object");
        let paragraphs = qa_input
            .get("paragraphs")
            .expect("Data entry does not contain a paragraphs field")
            .as_array()
            .expect("paragraphs is not an array");
        for paragraph in paragraphs.iter() {
            let paragraph = paragraph.as_object().expect("Paragraph is not an object");
            let context = paragraph
                .get("context")
                .expect("Paragraph does not contain a context field")
                .as_str()
                .expect("context is not a string");
            let qas = paragraph
                .get("qas")
                .expect("Paragraph does not contain a qas field")
                .as_array()
                .expect("qas is not an array");
            for qa in qas.iter() {
                let question = qa
                    .as_object()
                    .expect("QA entry is not an object")
                    .get("question")
                    .expect("QA entry does not contain a question field")
                    .as_str()
                    .expect("question is not a string");
                qa_inputs.push(QaInput {
                    question: question.to_owned(),
                    context: context.to_owned(),
                });
            }
        }
    }
    qa_inputs
}