use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use tch::kind::Kind::Float;
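
/// Cached state for a single attention layer, holding the key / value projections
/// (and the key padding mask) computed at previous decoding steps so that they can
/// be re-used during incremental (auto-regressive) generation.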
#[derive(Debug)]
pub struct LayerState {
    pub prev_key: Option<Tensor>,
    pub prev_value: Option<Tensor>,
    pub prev_key_padding_mask: Option<Tensor>,
}
impl LayerState {
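    /// Re-orders the cached tensors along the batch dimension according to `new_indices`,
    /// as required when beam search hypotheses are re-ranked between decoding steps.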
    pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
        self.prev_key = self.prev_key.as_ref().map(|k| k.index_select(0, new_indices));
        self.prev_value = self.prev_value.as_ref().map(|v| v.index_select(0, new_indices));
        self.prev_key_padding_mask = self
            .prev_key_padding_mask
            .as_ref()
            .map(|mask| mask.index_select(0, new_indices));
    }
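
    /// Clears the cached tensors, forcing the next forward pass to recompute them.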
    pub(crate) fn reset_cache(&mut self) {
        self.prev_key = None;
        self.prev_value = None;
        self.prev_key_padding_mask = None;
    }
}
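
/// Multi-head attention layer used both for self-attention and for encoder-decoder
/// (cross) attention. When `prev_state` is populated, the key / value projections of
/// previous steps are cached there and re-used for incremental decoding.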
#[derive(Debug)]
pub struct SelfAttention {
    num_heads: i64,
    head_dim: i64,
    dropout: Dropout,
    scaling: f64,
    encoder_decoder_attention: bool,
    output_attentions: bool,
    pub(crate) prev_state: Option<LayerState>,
    k_proj: nn::Linear,
    v_proj: nn::Linear,
    q_proj: nn::Linear,
    out_proj: nn::Linear,
}
impl SelfAttention {
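    /// Creates a new multi-head attention layer under the variable store path `p`.
    ///
    /// * `embed_dim` - hidden size of the model; expected to be divisible by `num_heads`.
    /// * `dropout` - dropout probability applied to the attention weights.
    /// * `encoder_decoder_attention` - if `true`, keys and values are projected from the
    ///   encoder output passed as `key` to `forward_t`.
    /// * `store_cache` - if `true`, key / value projections are cached between calls.
    /// * `output_attentions` - if `true`, `forward_t` also returns the attention weights.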
    pub fn new(
        p: nn::Path,
        embed_dim: i64,
        num_heads: i64,
        dropout: f64,
        encoder_decoder_attention: bool,
        store_cache: bool,
        output_attentions: bool,
    ) -> SelfAttention {
        let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default());
        let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default());
        let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default());
        let out_proj = nn::linear(&p / "out_proj", embed_dim, embed_dim, Default::default());
        let head_dim = embed_dim / num_heads;
        let scaling = (head_dim as f64).powf(-0.5);
        let dropout = Dropout::new(dropout);
        let prev_state = if store_cache {
            Some(LayerState {
                prev_key: None,
                prev_value: None,
                prev_key_padding_mask: None,
            })
        } else {
            None
        };
        SelfAttention {
            num_heads,
            head_dim,
            dropout,
            scaling,
            encoder_decoder_attention,
            output_attentions,
            prev_state,
            k_proj,
            v_proj,
            q_proj,
            out_proj,
        }
    }
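
    /// Reshapes a `(sequence_length, batch_size, num_heads * head_dim)` tensor into the
    /// `(batch_size * num_heads, sequence_length, head_dim)` layout expected by `bmm`
    /// (`dim_0` is the sequence length, or -1 to infer it).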
    fn flatten(&self, x: Tensor, dim_0: i64, bs: i64) -> Tensor {
        x.contiguous()
            .view((dim_0, bs * self.num_heads, self.head_dim))
            .transpose(0, 1)
    }
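
    /// Forward pass through the attention layer.
    ///
    /// `query` has shape `(target_sequence_length, batch_size, embed_dim)`; `key` is only
    /// used for encoder-decoder attention (self-attention derives keys and values from
    /// `query`). Returns the attention output with the same shape as `query`, together
    /// with the attention weights when `output_attentions` is set.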
    pub fn forward_t(
        &mut self,
        query: &Tensor,
        key: Option<&Tensor>,
        key_padding_mask: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Option<Tensor>) {
        let query_size = query.size();
        let (target_sequence_length, bs) = (query_size[0], query_size[1]);
        // Project and scale the queries, then reshape to (batch * num_heads, target_len, head_dim).
        let q: Tensor =
            self.flatten(query.apply(&self.q_proj) * self.scaling, target_sequence_length, bs);
        // For cached encoder-decoder attention the key / value projections were computed on the
        // first pass; setting `key` to `None` here re-uses them from the cache below.
        let key = match &self.prev_state {
            Some(prev_state) => {
                if prev_state.prev_key.is_some() && self.encoder_decoder_attention {
                    None
                } else {
                    key
                }
            }
            None => key,
        };
        let (k, v) = if self.encoder_decoder_attention {
            match key {
                Some(key) => (
                    Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
                    Some(self.flatten(key.apply(&self.v_proj), -1, bs)),
                ),
                None => (None, None),
            }
        } else {
            (
                Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
                Some(self.flatten(query.apply(&self.v_proj), -1, bs)),
            )
        };
        let (k, v, key_padding_mask) = self.use_saved_state(k, v, key_padding_mask, bs);
        // Update the cache (when enabled) with the keys / values used at this step.
        if self.prev_state.is_some() {
            self.prev_state = Some(LayerState {
                prev_key: Some(k.view((bs, self.num_heads, -1, self.head_dim))),
                prev_value: Some(v.view((bs, self.num_heads, -1, self.head_dim))),
                prev_key_padding_mask: key_padding_mask.as_ref().map(|mask| mask.copy()),
            });
        }
        let source_sequence_length = k.size()[1];
        let attention_weights = q.bmm(&k.transpose(1, 2));
        // Additive attention mask (e.g. a causal mask), applied before the softmax.
        let attention_weights = match attention_mask {
            Some(mask) => {
                let attention_weights = attention_weights
                    .view((bs, self.num_heads, target_sequence_length, source_sequence_length))
                    + mask;
                attention_weights
                    .view((bs * self.num_heads, target_sequence_length, source_sequence_length))
            }
            None => attention_weights,
        };
        // Positions flagged in the key padding mask are set to -inf so they receive no
        // attention after the softmax.
        let attention_weights = match key_padding_mask.as_ref() {
            Some(mask) => attention_weights
                .view((bs, self.num_heads, target_sequence_length, source_sequence_length))
                .masked_fill(&mask.unsqueeze(1).unsqueeze(2), f64::NEG_INFINITY)
                .view((bs * self.num_heads, target_sequence_length, source_sequence_length)),
            None => attention_weights,
        };
        let attention_weights = attention_weights.softmax(-1, Float);
        let attention_probabilities = attention_weights.apply_t(&self.dropout, train);
        let output = attention_probabilities
            .bmm(&v)
            .transpose(0, 1)
            .contiguous()
            .view((target_sequence_length, bs, self.num_heads * self.head_dim))
            .apply(&self.out_proj);
        let attention_weights = if self.output_attentions {
            Some(attention_weights.view((
                bs,
                self.num_heads,
                target_sequence_length,
                source_sequence_length,
            )))
        } else {
            None
        };
        (output, attention_weights)
    }
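
    /// Combines the freshly projected keys / values with the cached state: cached
    /// encoder-decoder projections are re-used as-is, while cached self-attention
    /// projections are concatenated with the current step along the sequence dimension.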
    fn use_saved_state(
        &self,
        k: Option<Tensor>,
        v: Option<Tensor>,
        key_padding_mask: Option<&Tensor>,
        bs: i64,
    ) -> (Tensor, Tensor, Option<Tensor>) {
        match &self.prev_state {
            Some(prev_state) => {
                let k = match &prev_state.prev_key {
                    Some(prev_key) => {
                        let prev_key = prev_key.view((bs * self.num_heads, -1, self.head_dim));
                        if self.encoder_decoder_attention {
                            prev_key
                        } else {
                            Tensor::cat(&[prev_key, k.unwrap()], 1)
                        }
                    }
                    None => k.unwrap(),
                };
                let v = match &prev_state.prev_value {
                    Some(prev_value) => {
                        let prev_value = prev_value.view((bs * self.num_heads, -1, self.head_dim));
                        if self.encoder_decoder_attention {
                            prev_value
                        } else {
                            Tensor::cat(&[prev_value, v.unwrap()], 1)
                        }
                    }
                    None => v.unwrap(),
                };
                let key_padding_mask = self.use_saved_key_padding_mask(
                    key_padding_mask,
                    &prev_state.prev_key_padding_mask,
                    bs,
                    k.size()[1],
                );
                (k, v, key_padding_mask)
            }
            None => {
                let key_padding_mask = key_padding_mask.map(|mask| mask.copy());
                (k.unwrap(), v.unwrap(), key_padding_mask)
            }
        }
    }
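
    /// Updates the key padding mask so that it matches the (possibly extended) key sequence:
    /// the cached mask is re-used for encoder-decoder attention, concatenated with the new
    /// mask for self-attention, and left-padded with zeros when no cached mask exists.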
    fn use_saved_key_padding_mask(
        &self,
        key_padding_mask: Option<&Tensor>,
        prev_key_padding_mask: &Option<Tensor>,
        bs: i64,
        sequence_length: i64,
    ) -> Option<Tensor> {
        match prev_key_padding_mask {
            Some(prev_key_padding_mask) => {
                if self.encoder_decoder_attention {
                    Some(prev_key_padding_mask.copy())
                } else {
                    Some(Tensor::cat(
                        &[prev_key_padding_mask, key_padding_mask.as_ref().unwrap()],
                        1,
                    ))
                }
            }
            None => match key_padding_mask {
                Some(key_padding_mask) => {
                    let filler = Tensor::zeros(
                        &[bs, sequence_length - key_padding_mask.size()[1]],
                        (key_padding_mask.kind(), key_padding_mask.device()),
                    );
                    Some(Tensor::cat(&[filler, key_padding_mask.copy()], 1))
                }
                None => None,
            },
        }
    }
}