use crate::common::dropout::Dropout;
use crate::xlnet::XLNetConfig;
use std::borrow::Borrow;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Kind, Tensor};
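
/// Cached hidden states (`mems`) for a single XLNet layer, re-used as memory
/// by the relative attention module on subsequent forward passes.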
#[derive(Debug)]
pub struct LayerState {
    /// Cached content stream from previous forward passes, with shape
    /// (memory_length, batch_size, hidden_size).
    pub prev_content: Tensor,
}

impl Clone for LayerState {
    fn clone(&self) -> Self {
        LayerState {
            prev_content: self.prev_content.copy(),
        }
    }
}

impl LayerState {
    /// Re-orders the cached memory along the batch dimension (dimension 1),
    /// e.g. when beams are re-ordered during generation.
    pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
        self.prev_content = self.prev_content.index_select(1, new_indices);
    }
}

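/// Multi-head attention with relative positional encoding, as used by each
/// XLNet transformer layer for the content and (optionally) query streams.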
#[derive(Debug)]
pub struct XLNetRelativeAttention {
    dropout: Dropout,
    output_attentions: bool,
    // Query, key, value and output projection weights, each of shape
    // (d_model, n_head, d_head).
    query: Tensor,
    key: Tensor,
    value: Tensor,
    output: Tensor,
    // Projection weights for the relative positional encodings.
    pos: Tensor,
    // Learned biases for the position-, segment- and content-based attention terms.
    r_r_bias: Tensor,
    r_s_bias: Tensor,
    r_w_bias: Tensor,
    // Segment embedding used for the segment-based attention term.
    seg_embed: Tensor,
    layer_norm: nn::LayerNorm,
    scale: f64,
}

impl XLNetRelativeAttention {
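    /// Builds the relative attention module from an `XLNetConfig`, creating the
    /// projection weights, relative attention biases, segment embedding and
    /// layer normalization under the provided variable store path.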
    pub fn new<'p, P>(p: P, config: &XLNetConfig) -> XLNetRelativeAttention
    where
        P: Borrow<nn::Path<'p>>,
    {
        assert_eq!(
            config.d_model % config.d_head,
            0,
            "Hidden size (d_model) is not a multiple of the attention head dimension (d_head)"
        );
        let p = p.borrow();

        let query = p.var(
            "q",
            &[config.d_model, config.n_head, config.d_head],
            DEFAULT_KAIMING_UNIFORM,
        );
        let key = p.var(
            "k",
            &[config.d_model, config.n_head, config.d_head],
            DEFAULT_KAIMING_UNIFORM,
        );
        let value = p.var(
            "v",
            &[config.d_model, config.n_head, config.d_head],
            DEFAULT_KAIMING_UNIFORM,
        );
        let output = p.var(
            "o",
            &[config.d_model, config.n_head, config.d_head],
            DEFAULT_KAIMING_UNIFORM,
        );
        let pos = p.var(
            "r",
            &[config.d_model, config.n_head, config.d_head],
            DEFAULT_KAIMING_UNIFORM,
        );
        let r_r_bias = p.var(
            "r_r_bias",
            &[config.n_head, config.d_head],
            DEFAULT_KAIMING_UNIFORM,
        );
        let r_s_bias = p.var(
            "r_s_bias",
            &[config.n_head, config.d_head],
            DEFAULT_KAIMING_UNIFORM,
        );
        let r_w_bias = p.var(
            "r_w_bias",
            &[config.n_head, config.d_head],
            DEFAULT_KAIMING_UNIFORM,
        );
        let seg_embed = p.var(
            "seg_embed",
            &[2, config.n_head, config.d_head],
            DEFAULT_KAIMING_UNIFORM,
        );

        let dropout = Dropout::new(config.dropout);
        let output_attentions = config.output_attentions.unwrap_or(false);
        let layer_norm_eps = config.layer_norm_eps.unwrap_or(1e-12);
        let layer_norm_config = nn::LayerNormConfig {
            eps: layer_norm_eps,
            ..Default::default()
        };
        let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], layer_norm_config);

        // Attention scores are scaled by 1 / sqrt(d_head).
        let scale = 1f64 / ((config.d_head as f64).powf(0.5f64));

        XLNetRelativeAttention {
            dropout,
            output_attentions,
            query,
            key,
            value,
            output,
            pos,
            r_r_bias,
            r_s_bias,
            r_w_bias,
            seg_embed,
            layer_norm,
            scale,
        }
    }
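
    /// Relative shift of the position-based attention scores (Transformer-XL
    /// trick): reshapes the (batch, n_head, qlen, pos_len) score tensor, drops
    /// the first row along the shifted axis, and keeps the first `klen`
    /// positions so that each query attends to its correct relative offsets.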
    fn rel_shift_bnij(&self, x: &Tensor, klen: i64) -> Tensor {
        let shape = x.size();
        x.reshape([shape[0], shape[1], shape[3], shape[2]])
            .narrow(2, 1, shape[3] - 1)
            .reshape([shape[0], shape[1], shape[2], shape[3] - 1])
            .index_select(3, &Tensor::arange(klen, (Kind::Int64, x.device())))
    }
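
    /// Core relative attention: computes the content-based (`ac`),
    /// position-based (`bd`) and optional segment-based (`ef`) attention
    /// scores, applies masking, softmax and dropout, and returns the attention
    /// output of shape (qlen, batch_size, n_head, d_head) together with the
    /// attention probabilities when `output_attentions` is set.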
    fn rel_attention_core(
        &self,
        q_head: &Tensor,
        k_head_h: &Tensor,
        v_head_h: &Tensor,
        k_head_r: &Tensor,
        seg_mat: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Option<Tensor>) {
        // Content-based attention score.
        let ac = Tensor::einsum(
            "ibnd,jbnd->bnij",
            &[&(q_head + &self.r_w_bias), k_head_h],
            None::<i64>,
        );
        // Position-based attention score, aligned with the relative shift.
        let bd = self.rel_shift_bnij(
            &Tensor::einsum(
                "ibnd,jbnd->bnij",
                &[&(q_head + &self.r_r_bias), k_head_r],
                None::<i64>,
            ),
            ac.size()[3],
        );
        // Segment-based attention score (zero when no segment information is provided).
        let ef = match seg_mat {
            Some(seg_mat) => {
                let ef = Tensor::einsum(
                    "ibnd,snd->ibns",
                    &[&(q_head + &self.r_s_bias), &self.seg_embed],
                    None::<i64>,
                );
                Tensor::einsum("ijbs,ibns->bnij", &[seg_mat, &ef], None::<i64>)
            }
            None => Tensor::zeros([1], (ac.kind(), ac.device())),
        };
        let mut attention_score = (ac + bd + ef) * self.scale;
        // Masked positions are pushed to a large negative value before the softmax.
        if let Some(value) = attention_mask {
            let target_kind = attention_score.kind();
            attention_score =
                (attention_score - value.permute([2, 3, 0, 1]) * 1e30).to_kind(target_kind);
        };
        let attention_probas = attention_score
            .softmax(3, attention_score.kind())
            .apply_t(&self.dropout, train);
        let attention_vector = Tensor::einsum(
            "bnij,jbnd->ibnd",
            &[&attention_probas, v_head_h],
            None::<i64>,
        );
        if self.output_attentions {
            (
                attention_vector,
                Some(attention_probas.permute([2, 3, 0, 1])),
            )
        } else {
            (attention_vector, None)
        }
    }
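
    /// Projects the attention output back to the hidden size, applies dropout,
    /// optionally adds the residual connection to `h`, and applies layer
    /// normalization.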
    fn post_attention(
        &self,
        h: &Tensor,
        attention_vector: &Tensor,
        residual: bool,
        train: bool,
    ) -> Tensor {
        let mut attention_out = Tensor::einsum(
            "ibnd,hnd->ibh",
            &[attention_vector, &self.output],
            None::<i64>,
        )
        .apply_t(&self.dropout, train);
        if residual {
            attention_out = attention_out + h;
        };
        attention_out.apply(&self.layer_norm)
    }
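
    /// Forward pass through the relative attention layer.
    ///
    /// Computes the content stream attention over `h` (concatenated with the
    /// cached memory when available) and, when a query stream `g` is provided
    /// (pre-training with target mapping), the query stream attention as well.
    /// Returns the content output, the optional query output and their
    /// respective attention probabilities when `output_attentions` is set.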
    pub fn forward_t(
        &self,
        h: &Tensor,
        g: Option<&Tensor>,
        attn_mask_h: Option<&Tensor>,
        attn_mask_g: Option<&Tensor>,
        r: &Tensor,
        seg_mat: Option<&Tensor>,
        layer_state: Option<LayerState>,
        target_mapping: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<Tensor>) {
        // Prepend the cached memory (if any) to the current hidden states for
        // the key and value projections.
        let cat_value = if let Some(mems) = &layer_state {
            if mems.prev_content.size().len() > 1 {
                Some(Tensor::cat(&[&mems.prev_content, h], 0))
            } else {
                None
            }
        } else {
            None
        };
        let cat = match &cat_value {
            Some(value) => value,
            None => h,
        };

        // Content stream: queries come from `h`, keys/values from `cat`, and
        // `k_head_r` projects the relative positional encodings `r`.
        let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query], None::<i64>);
        let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key], None::<i64>);
        let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value], None::<i64>);
        let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos], None::<i64>);

        let (attention_vec_h, attention_probas_h) = self.rel_attention_core(
            &q_head_h,
            &k_head_h,
            &v_head_h,
            &k_head_r,
            seg_mat,
            attn_mask_h,
            train,
        );
        let output_h = self.post_attention(h, &attention_vec_h, true, train);

        // Query stream (two-stream attention), only used when `g` is provided.
        let (output_g, attention_probas_g) = if let Some(g) = g {
            let q_head_g = Tensor::einsum("ibh,hnd->ibnd", &[g, &self.query], None::<i64>);
            let (attention_vec_g, attention_probas_g) = match target_mapping {
                Some(target_mapping) => {
                    // Project the queries onto the target positions before the
                    // attention core, and map the output back afterwards.
                    let q_head_g =
                        Tensor::einsum("mbnd,mlb->lbnd", &[&q_head_g, target_mapping], None::<i64>);
                    let (attention_vec_g, attention_probas_g) = self.rel_attention_core(
                        &q_head_g,
                        &k_head_h,
                        &v_head_h,
                        &k_head_r,
                        seg_mat,
                        attn_mask_g,
                        train,
                    );
                    let attention_vec_g = Tensor::einsum(
                        "lbnd,mlb->mbnd",
                        &[&attention_vec_g, target_mapping],
                        None::<i64>,
                    );
                    (attention_vec_g, attention_probas_g)
                }
                None => self.rel_attention_core(
                    &q_head_g,
                    &k_head_h,
                    &v_head_h,
                    &k_head_r,
                    seg_mat,
                    attn_mask_g,
                    train,
                ),
            };
            let output_g = self.post_attention(g, &attention_vec_g, true, train);
            (Some(output_g), attention_probas_g)
        } else {
            (None, None)
        };

        (output_h, output_g, attention_probas_h, attention_probas_g)
    }
}
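
// A minimal illustrative sketch (not part of the upstream test suite): it
// re-applies the reshape / narrow / index_select sequence used by
// `rel_shift_bnij` to a plain tensor, to show that a (batch, n_head, qlen,
// pos_len) score tensor ends up in the (batch, n_head, qlen, klen) layout
// expected by the attention core. The module name, sizes and values are
// arbitrary assumptions for illustration.
#[cfg(test)]
mod relative_shift_sketch {
    use tch::{Device, Kind, Tensor};

    #[test]
    fn relative_shift_keeps_bnij_layout() {
        let (batch, heads, qlen, pos_len, klen): (i64, i64, i64, i64, i64) = (2, 4, 3, 6, 3);
        let x = Tensor::rand([batch, heads, qlen, pos_len], (Kind::Float, Device::Cpu));
        let shape = x.size();
        let shifted = x
            .reshape([shape[0], shape[1], shape[3], shape[2]])
            .narrow(2, 1, shape[3] - 1)
            .reshape([shape[0], shape[1], shape[2], shape[3] - 1])
            .index_select(3, &Tensor::arange(klen, (Kind::Int64, Device::Cpu)));
        assert_eq!(shifted.size(), vec![batch, heads, qlen, klen]);
    }
}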