use crate::common::dropout::Dropout;
use crate::common::kind::get_min;
use crate::gpt_j::gpt_j_model::GptJConfig;
use std::borrow::Borrow;
use tch::nn::Linear;
use tch::{nn, IndexOp, Kind, NewAxis, Tensor};

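/// Cache for GPT-J attention layers.
///
/// Stores the key and value tensors computed for previous positions so that they can be
/// reused during incremental decoding.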
#[derive(Debug)]
pub struct LayerState {
    /// Cached keys
    pub prev_key: Tensor,
    /// Cached values
    pub prev_value: Tensor,
}

impl Clone for LayerState {
    fn clone(&self) -> Self {
        LayerState {
            prev_key: self.prev_key.copy(),
            prev_value: self.prev_value.copy(),
        }
    }
}

impl LayerState {
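    /// Re-orders the cached key and value tensors along the batch dimension, e.g. when beam
    /// search permutes the active hypotheses.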
    pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
        self.prev_key = self.prev_key.index_select(0, new_indices);
        self.prev_value = self.prev_value.index_select(0, new_indices);
    }
}

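/// Multi-head self-attention layer for GPT-J, with rotary position embeddings and an optional
/// key/value cache.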
pub struct GptJAttention {
    bias: Tensor,
    attn_dropout: Dropout,
    resid_dropout: Dropout,
    scale_attn: f32,
    k_proj: Linear,
    v_proj: Linear,
    q_proj: Linear,
    out_proj: Linear,
    output_attentions: bool,
    dim_per_head: i64,
    n_head: i64,
    rotary_dim: Option<i64>,
    scale: bool,
    use_cache: bool,
}

impl GptJAttention {
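    /// Builds the attention layer from the model configuration: the causal mask buffer, the
    /// dropout modules and the key/query/value/output projections (all without bias).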
    pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJAttention
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let max_positions = config.n_positions;
        // The causal mask is kept in a non-trainable `bias` variable registered on the
        // variable store; its contents are overwritten below with the lower-triangular mask.
        let bias_value = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device()))
            .tril(0)
            .view([1, 1, max_positions, max_positions])
            .requires_grad_(false);
        let mut bias = p
            .f_ones_no_train("bias", &[1, 1, max_positions, max_positions])
            .unwrap()
            .to_kind(Kind::Uint8)
            .to_device(p.device());
        bias.copy_(&bias_value);

        let attn_pdrop = config.attn_pdrop.unwrap_or(0.1);
        let resid_pdrop = config.resid_pdrop.unwrap_or(0.1);
        let output_attentions = config.output_attentions.unwrap_or(false);

        let attn_dropout = Dropout::new(attn_pdrop);
        let resid_dropout = Dropout::new(resid_pdrop);

        assert_eq!(
            config.n_embd % config.n_head,
            0,
            "Attention hidden size is not a multiple of the number of attention heads"
        );
        let dim_per_head = config.n_embd / config.n_head;

        let scale_attn = (dim_per_head as f32).sqrt();

        let linear_config = nn::LinearConfig {
            bias: false,
            ..Default::default()
        };
        let k_proj = nn::linear(p / "k_proj", config.n_embd, config.n_embd, linear_config);
        let v_proj = nn::linear(p / "v_proj", config.n_embd, config.n_embd, linear_config);
        let q_proj = nn::linear(p / "q_proj", config.n_embd, config.n_embd, linear_config);
        let out_proj = nn::linear(p / "out_proj", config.n_embd, config.n_embd, linear_config);

        GptJAttention {
            bias,
            attn_dropout,
            resid_dropout,
            output_attentions,
            scale_attn,
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            dim_per_head,
            n_head: config.n_head,
            rotary_dim: config.rotary_dim,
            scale: config.scale_attn_weights.unwrap_or(true),
            use_cache: config.use_cache.unwrap_or(true),
        }
    }

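    /// Splits the last (hidden) dimension into `(num_heads, head_dim)`. When `rotary` is true
    /// the heads axis is left in place (shape `[batch, seq, heads, head_dim]`) so that rotary
    /// embeddings can be applied; otherwise the heads axis is moved next to the batch dimension.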
    fn split_heads(
        tensor: &Tensor,
        num_heads: i64,
        attention_head_size: i64,
        rotary: bool,
    ) -> Tensor {
        let mut new_shape = tensor.size();
        let _ = new_shape.pop();
        new_shape.extend_from_slice(&[num_heads, attention_head_size]);
        let tensor = tensor.view(new_shape.as_slice());
        if rotary {
            tensor
        } else if tensor.size().len() == 5 {
            tensor.permute([0, 1, 3, 2, 4])
        } else if tensor.size().len() == 4 {
            tensor.permute([0, 2, 1, 3])
        } else {
            panic!(
                "Input tensor should either be a rotary head, or have a rank of 4 or 5, but its rank is {}",
                tensor.size().len()
            )
        }
    }

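    /// Inverse of `split_heads`: moves the heads axis back next to the head dimension and folds
    /// `(num_heads, head_dim)` into a single hidden dimension.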
    fn merge_heads(tensor: &Tensor, num_heads: i64, attention_head_size: i64) -> Tensor {
        let tensor = if tensor.size().len() == 5 {
            tensor.permute([0, 1, 3, 2, 4]).contiguous()
        } else if tensor.size().len() == 4 {
            tensor.permute([0, 2, 1, 3]).contiguous()
        } else {
            panic!(
                "Input tensor rank should be one of [4, 5], but is: {}",
                tensor.size().len()
            )
        };
        let mut new_shape = tensor.size();
        new_shape.truncate(new_shape.len() - 2);
        new_shape.push(num_heads * attention_head_size);
        tensor.view(new_shape.as_slice())
    }

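    /// Computes scaled dot-product attention with the causal mask (and optional attention mask)
    /// applied, returning the attention output and the attention weights.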
    fn attention(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Tensor) {
        // Queries and keys are cast to float so the attention scores are computed in full precision.
        let query = query.to_kind(Kind::Float);
        let key = key.to_kind(Kind::Float);

        let attention_weights = query.matmul(&key.transpose(-1, -2));

        let query_dims = query.size();
        let key_dims = key.size();
        let query_length = query_dims[query_dims.len() - 2];
        let key_length = key_dims[key_dims.len() - 2];

        let causal_mask = &self
            .bias
            .slice(2, key_length - query_length, key_length, 1)
            .slice(3, 0, key_length, 1)
            .to_kind(Kind::Bool)
            .to_device(attention_weights.device());

        // Masked positions are set to the lowest representable value before the softmax.
        let mask_value = get_min(attention_weights.kind()).unwrap();
        let mask_value = Tensor::full(
            attention_weights.size(),
            mask_value,
            (attention_weights.kind(), attention_weights.device()),
        );

        let mut attention_weights = attention_weights.where_self(causal_mask, &mask_value);
        if self.scale {
            attention_weights /= self.scale_attn;
        }
        if let Some(attention_mask_value) = attention_mask {
            attention_weights += attention_mask_value;
        };
        let attention_weights = attention_weights.softmax(-1, attention_weights.kind());
        let attention_weights = attention_weights
            .to_kind(value.kind())
            .apply_t(&self.attn_dropout, train);

        let attention_output = attention_weights.matmul(value);

        (attention_output, attention_weights)
    }

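    /// Forward pass: projects the hidden states to queries, keys and values, applies rotary
    /// position embeddings, optionally prepends the cached keys/values, and returns the
    /// attention output together with the updated cache and (optionally) the attention weights.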
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        layer_past: Option<&LayerState>,
        train: bool,
    ) -> (Tensor, Option<LayerState>, Option<Tensor>) {
        let query = hidden_states.apply(&self.q_proj);
        let key = hidden_states.apply(&self.k_proj);
        let value = hidden_states.apply(&self.v_proj);

        let mut query = Self::split_heads(&query, self.n_head, self.dim_per_head, true);
        let mut key = Self::split_heads(&key, self.n_head, self.dim_per_head, true);
        let mut value = Self::split_heads(&value, self.n_head, self.dim_per_head, false);

        let mut seq_len = key.size()[1];
        let mut offset = 0;

        if let Some(layer_past) = layer_past {
            offset = layer_past.prev_key.size()[layer_past.prev_key.size().len() - 2];
            seq_len += offset;
        }

        // Rotary embeddings are applied either to the first `rotary_dim` features of each head
        // or, if no rotary dimension is configured, to the full head dimension.
        if let Some(rotary_dim) = self.rotary_dim {
            let k_rot = key.slice(3, 0, rotary_dim, 1);
            let k_pass = key.slice(3, rotary_dim, key.size()[3], 1);

            let q_rot = query.slice(3, 0, rotary_dim, 1);
            let q_pass = query.slice(3, rotary_dim, query.size()[3], 1);

            let sincos = fixed_pos_embedding(&k_rot, seq_len);
            let k_rot = apply_rotary_pos_emb(&k_rot, &sincos, offset);
            let q_rot = apply_rotary_pos_emb(&q_rot, &sincos, offset);

            key = Tensor::cat(&[k_rot, k_pass], -1);
            query = Tensor::cat(&[q_rot, q_pass], -1);
        } else {
            let sincos = fixed_pos_embedding(&key, seq_len);
            key = apply_rotary_pos_emb(&key, &sincos, offset);
            query = apply_rotary_pos_emb(&query, &sincos, offset);
        }

        key = key.permute([0, 2, 1, 3]);
        query = query.permute([0, 2, 1, 3]);

        if let Some(layer_past) = layer_past {
            key = Tensor::cat(&[&layer_past.prev_key, &key], -2);
            value = Tensor::cat(&[&layer_past.prev_value, &value], -2);
        }

        let present = self.use_cache.then(|| LayerState {
            prev_key: key.copy(),
            prev_value: value.copy(),
        });

        let (attn_output, attn_weights) =
            self.attention(&query, &key, &value, attention_mask, train);

        let attn_output = Self::merge_heads(&attn_output, self.n_head, self.dim_per_head)
            .apply(&self.out_proj)
            .apply_t(&self.resid_dropout, train);

        let attn_weights = self.output_attentions.then_some(attn_weights);

        (attn_output, present, attn_weights)
    }
}

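/// Computes the sinusoidal tables used for rotary position embeddings: a `(sin, cos)` pair of
/// shape `[seq_len, dim / 2]`, where `dim` is the last dimension of `x`.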
fn fixed_pos_embedding(x: &Tensor, seq_len: i64) -> (Tensor, Tensor) {
    let dim = x.size()[x.size().len() - 1];
    let inv_freq = 1.0
        / Tensor::pow_scalar(
            10_000,
            &(Tensor::arange_start_step(0, dim, 2, (x.kind(), x.device())) / dim),
        );
    let sinusoid_inp = Tensor::einsum(
        "i , j -> i j",
        &[Tensor::arange(seq_len, (x.kind(), x.device())), inv_freq],
        None::<i64>,
    );
    (sinusoid_inp.sin(), sinusoid_inp.cos())
}

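/// Applies the rotary position embedding to `x`, starting at position `offset` (non-zero when
/// cached keys/values from previous decoding steps are present).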
fn apply_rotary_pos_emb(x: &Tensor, (sin, cos): &(Tensor, Tensor), offset: i64) -> Tensor {
    let sin = duplicate_interleave(sin).i((NewAxis, offset..x.size()[1] + offset, NewAxis, ..));
    let cos = duplicate_interleave(cos).i((NewAxis, offset..x.size()[1] + offset, NewAxis, ..));
    (x * cos) + (rotate_every_two(x) * sin)
}

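/// A simplified version of `torch.repeat_interleave`: duplicates every element of the matrix
/// along the second dimension, interleaving the copies (an `[n, d]` matrix becomes `[n, 2 * d]`).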
fn duplicate_interleave(m: &Tensor) -> Tensor {
    let dim0 = m.size()[0];
    m.view([-1, 1]) // flatten the matrix
        .repeat([1, 2]) // repeat all elements into the 2nd dimension
        .view([dim0, -1]) // rearrange into a matrix, interleaving the copy
}

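/// Rotates pairs of features in the last dimension: each pair `(x1, x2)` becomes `(-x2, x1)`,
/// the rotation used by rotary position embeddings.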
fn rotate_every_two(x: &Tensor) -> Tensor {
    let x1 = x.slice(3, 0, x.size()[3], 2);
    let x2 = x.slice(3, 1, x.size()[3], 2);
    Tensor::stack(&[-x2, x1], -1).flatten(-2, -1)
}