use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::{nn, Tensor};

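/// # Cache for BART attention layers
/// Stores the cached key and value projections so that past time steps do not need to be
/// recomputed during incremental decoding.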
#[derive(Debug)]
pub struct LayerState {
    /// Cached keys
    pub prev_key: Tensor,
    /// Cached values
    pub prev_value: Tensor,
}

impl Clone for LayerState {
    fn clone(&self) -> Self {
        LayerState {
            prev_key: self.prev_key.copy(),
            prev_value: self.prev_value.copy(),
        }
    }
}

impl LayerState {
    /// Re-indexes the cached keys and values along the batch dimension following `new_indices`
    /// (used for example to reorder hypotheses during beam search).
    pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
        self.prev_key = self.prev_key.index_select(0, new_indices);
        self.prev_value = self.prev_value.index_select(0, new_indices);
    }
}

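/// # Multi-head attention layer for BART-based models
/// Used both as self-attention (keys and values computed from the input hidden states) and as
/// encoder-decoder cross-attention (keys and values computed from the encoder output).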
#[derive(Debug)]
pub struct BartAttention {
    num_heads: i64,
    head_dim: i64,
    dropout: Dropout,
    scaling: f64,
    encoder_decoder_attention: bool,
    output_attentions: bool,
    k_proj: nn::Linear,
    v_proj: nn::Linear,
    q_proj: nn::Linear,
    out_proj: nn::Linear,
    store_cache: bool,
}

impl BartAttention {
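    /// Creates a new `BartAttention` layer.
    ///
    /// * `p` - variable store path where the projection weights are created
    /// * `embed_dim` - embedding dimension of the inputs and outputs (expected to be divisible by `num_heads`)
    /// * `num_heads` - number of attention heads
    /// * `dropout` - dropout probability applied to the attention weights
    /// * `encoder_decoder_attention` - if `true`, keys and values are computed from the encoder output (cross-attention)
    /// * `store_cache` - if `true`, the computed keys and values are returned as a `LayerState` for reuse
    /// * `output_attentions` - if `true`, the attention weights are returned alongside the output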
    pub fn new<'p, P>(
        p: P,
        embed_dim: i64,
        num_heads: i64,
        dropout: f64,
        encoder_decoder_attention: bool,
        store_cache: bool,
        output_attentions: bool,
    ) -> BartAttention
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let k_proj = nn::linear(p / "k_proj", embed_dim, embed_dim, Default::default());
        let v_proj = nn::linear(p / "v_proj", embed_dim, embed_dim, Default::default());
        let q_proj = nn::linear(p / "q_proj", embed_dim, embed_dim, Default::default());
        let out_proj = nn::linear(p / "out_proj", embed_dim, embed_dim, Default::default());

        let head_dim = embed_dim / num_heads;
        let scaling = (head_dim as f64).powf(-0.5);
        let dropout = Dropout::new(dropout);

        BartAttention {
            num_heads,
            head_dim,
            dropout,
            scaling,
            encoder_decoder_attention,
            output_attentions,
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            store_cache,
        }
    }

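    /// Reshapes a (batch size, sequence length, embed dim) tensor into
    /// (batch size, number of heads, sequence length, head dim).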
    fn _shape(&self, x: Tensor, sequence_length: i64, batch_size: i64) -> Tensor {
        x.view((batch_size, sequence_length, self.num_heads, self.head_dim))
            .transpose(1, 2)
            .contiguous()
    }

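    /// Forward pass through the attention layer.
    ///
    /// * `hidden_states` - input of shape (batch size, target length, embed dim)
    /// * `key_value_states` - optional encoder output used to compute keys and values for cross-attention
    /// * `attention_mask` - optional additive mask added to the attention scores
    /// * `layer_state` - optional cached keys and values from previous decoding steps
    /// * `train` - whether dropout is active
    ///
    /// Returns the attention output, the (optional) attention weights and the (optional) updated cache.
    ///
    /// Minimal usage sketch (illustrative only; assumes `BartAttention` and `LayerState` are in
    /// scope and a `tch` version whose tensor constructors accept array shapes):
    ///
    /// ```ignore
    /// use tch::{nn, Device, Kind, Tensor};
    ///
    /// let vs = nn::VarStore::new(Device::Cpu);
    /// // embed_dim = 16, num_heads = 4, dropout = 0.1, self-attention, cache enabled
    /// let attention = BartAttention::new(vs.root(), 16, 4, 0.1, false, true, false);
    ///
    /// // (batch size, sequence length, embed dim)
    /// let hidden_states = Tensor::rand([2, 5, 16], (Kind::Float, Device::Cpu));
    /// let (output, attention_weights, cache) =
    ///     attention.forward_t(&hidden_states, None, None, None, false);
    /// ```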
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        key_value_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        layer_state: Option<LayerState>,
        train: bool,
    ) -> (Tensor, Option<Tensor>, Option<LayerState>) {
        let (bs, target_length, embed_dim) = hidden_states.size3().unwrap();

        // The queries are pre-scaled by 1 / sqrt(head_dim).
        let query_states = hidden_states.apply(&self.q_proj) * self.scaling;

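        // Keys and values depend on the attention type: for cross-attention they are computed
        // from the encoder output (`key_value_states`) and can be reused from the cache as-is;
        // for self-attention they are computed from `hidden_states` and, when a cache is
        // present, concatenated with the previously cached time steps.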
        let (key_states, value_states) = if self.encoder_decoder_attention {
            if let Some(layer_state_value) = layer_state {
                (layer_state_value.prev_key, layer_state_value.prev_value)
            } else {
                (
                    self._shape(key_value_states.unwrap().apply(&self.k_proj), -1, bs),
                    self._shape(key_value_states.unwrap().apply(&self.v_proj), -1, bs),
                )
            }
        } else if let Some(layer_state_value) = layer_state {
            let key_states = self._shape(hidden_states.apply(&self.k_proj), -1, bs);
            let value_states = self._shape(hidden_states.apply(&self.v_proj), -1, bs);
            (
                Tensor::cat(&[layer_state_value.prev_key, key_states], 2),
                Tensor::cat(&[layer_state_value.prev_value, value_states], 2),
            )
        } else {
            (
                self._shape(hidden_states.apply(&self.k_proj), -1, bs),
                self._shape(hidden_states.apply(&self.v_proj), -1, bs),
            )
        };

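        // Optionally snapshot the shaped keys and values so that subsequent decoding steps can
        // reuse them through `LayerState`.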
        let new_layer_state = if self.store_cache {
            Some(LayerState {
                prev_key: key_states.copy(),
                prev_value: value_states.copy(),
            })
        } else {
            None
        };

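        // Merge the batch and head dimensions so that attention can be computed with a single
        // batched matrix multiplication.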
        let proj_shape = [bs * self.num_heads, -1, self.head_dim];
        let query_states = self
            ._shape(query_states, target_length, bs)
            .view(proj_shape);
        let key_states = key_states.view(proj_shape);
        let value_states = value_states.view(proj_shape);

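        // Raw attention scores of shape (batch size * number of heads, target length, source length).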
        let source_length = key_states.size()[1];
        let mut attention_weights = query_states.bmm(&key_states.transpose(1, 2));

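        // If provided, the additive attention mask is broadcast over the head dimension and
        // added to the scores before the softmax.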
        if let Some(attention_mask_value) = attention_mask {
            attention_weights =
                attention_weights.view([bs, self.num_heads, target_length, source_length])
                    + attention_mask_value;
            attention_weights =
                attention_weights.view([bs * self.num_heads, target_length, source_length]);
        };

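        // Normalize the scores into attention probabilities.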
        attention_weights = attention_weights.softmax(-1, attention_weights.kind());

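        // Keep a copy of the attention weights (before dropout) if they were requested as output.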
        let saved_attention_weights = if self.output_attentions {
            Some(attention_weights.view((bs, self.num_heads, target_length, source_length)))
        } else {
            None
        };

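        // Apply dropout to the attention weights, weight the values, merge the heads back
        // together and project to the output dimension.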
        let attention_probas = attention_weights.apply_t(&self.dropout, train);
        let attention_output = attention_probas
            .bmm(&value_states)
            .view([bs, self.num_heads, target_length, self.head_dim])
            .transpose(1, 2)
            .reshape([bs, target_length, embed_dim])
            .apply(&self.out_proj);

        (attention_output, saved_attention_weights, new_layer_state)
    }
}