rust_bert/models/gpt_j/attention.rs

// Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
// Copyright 2022 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::common::dropout::Dropout;
use crate::common::kind::get_min;
use crate::gpt_j::gpt_j_model::GptJConfig;
use std::borrow::Borrow;
use tch::nn::Linear;
use tch::{nn, IndexOp, Kind, NewAxis, Tensor};

#[derive(Debug)]
/// # Cache for GPT-J attention layers
/// Stores the cached keys and values from previous decoding steps
pub struct LayerState {
    /// Cached keys
    pub prev_key: Tensor,
    /// Cached values
    pub prev_value: Tensor,
}

impl Clone for LayerState {
    fn clone(&self) -> Self {
        LayerState {
            prev_key: self.prev_key.copy(),
            prev_value: self.prev_value.copy(),
        }
    }
}

impl LayerState {
    /// Reorders the cached keys and values along the batch dimension (used e.g. for beam search).
    pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
        self.prev_key = self.prev_key.index_select(0, new_indices);
        self.prev_value = self.prev_value.index_select(0, new_indices);
    }
}

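/// # GPT-J self-attention layer
/// Multi-head causal self-attention with rotary position embeddings, holding the
/// query/key/value/output projections, the registered causal mask (`bias`) and the
/// attention and residual dropout layers.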
pub struct GptJAttention {
    bias: Tensor,
    attn_dropout: Dropout,
    resid_dropout: Dropout,
    scale_attn: f32,
    k_proj: Linear,
    v_proj: Linear,
    q_proj: Linear,
    out_proj: Linear,
    output_attentions: bool,
    dim_per_head: i64,
    n_head: i64,
    rotary_dim: Option<i64>,
    scale: bool,
    use_cache: bool,
}

impl GptJAttention {
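    /// Builds a new `GptJAttention` layer from a `GptJConfig` under the given variable store path.
    /// Registers the non-trainable lower-triangular causal mask as `bias` and creates the
    /// bias-free `q_proj`, `k_proj`, `v_proj` and `out_proj` linear projections.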
    pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJAttention
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let max_positions = config.n_positions;
        let bias_value = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device()))
            .tril(0)
            .view([1, 1, max_positions, max_positions])
            .requires_grad_(false);
        let mut bias = p
            .f_ones_no_train("bias", &[1, 1, max_positions, max_positions])
            .unwrap()
            .to_kind(Kind::Uint8)
            .to_device(p.device());
        bias.copy_(&bias_value);

        let attn_pdrop = config.attn_pdrop.unwrap_or(0.1);
        let resid_pdrop = config.resid_pdrop.unwrap_or(0.1);
        let output_attentions = config.output_attentions.unwrap_or(false);

        let attn_dropout = Dropout::new(attn_pdrop);
        let resid_dropout = Dropout::new(resid_pdrop);

        assert_eq!(
            config.n_embd % config.n_head,
            0,
            "Attention hidden size (n_embd) is not a multiple of the number of attention heads"
        );
        let dim_per_head = config.n_embd / config.n_head;

        let scale_attn = (dim_per_head as f32).sqrt();

        let linear_config = nn::LinearConfig {
            bias: false,
            ..Default::default()
        };
        let k_proj = nn::linear(p / "k_proj", config.n_embd, config.n_embd, linear_config);
        let v_proj = nn::linear(p / "v_proj", config.n_embd, config.n_embd, linear_config);
        let q_proj = nn::linear(p / "q_proj", config.n_embd, config.n_embd, linear_config);
        let out_proj = nn::linear(p / "out_proj", config.n_embd, config.n_embd, linear_config);

        GptJAttention {
            bias,
            attn_dropout,
            resid_dropout,
            output_attentions,
            scale_attn,
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            dim_per_head,
            n_head: config.n_head,
            rotary_dim: config.rotary_dim,
            scale: config.scale_attn_weights.unwrap_or(true),
            use_cache: config.use_cache.unwrap_or(true),
        }
    }

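    /// Reshapes the last (hidden) dimension into `(num_heads, attention_head_size)`.
    /// When `rotary` is true the tensor is left as `(batch, seq_length, head, head_features)`
    /// so that rotary embeddings can be applied along the sequence dimension; otherwise the
    /// head dimension is moved in front of the sequence dimension.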
    fn split_heads(
        tensor: &Tensor,
        num_heads: i64,
        attention_head_size: i64,
        rotary: bool,
    ) -> Tensor {
        let mut new_shape = tensor.size();
        let _ = new_shape.pop();
        new_shape.extend_from_slice(&[num_heads, attention_head_size]);
        let tensor = tensor.view(new_shape.as_slice());
        if rotary {
            tensor
        } else if tensor.size().len() == 5 {
            tensor.permute([0, 1, 3, 2, 4]) // (batch, blocks, head, block_length, head_features)
        } else if tensor.size().len() == 4 {
            tensor.permute([0, 2, 1, 3]) // (batch, head, seq_length, head_features)
        } else {
            panic!(
                "Input tensor should either be a rotary head, or have a rank of 4 or 5, but its rank is: {}",
                tensor.size().len()
            )
        }
    }

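    /// Inverse of `split_heads`: moves the head dimension back next to the feature dimension
    /// and flattens `(num_heads, attention_head_size)` into a single hidden dimension.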
    fn merge_heads(tensor: &Tensor, num_heads: i64, attention_head_size: i64) -> Tensor {
        let tensor = if tensor.size().len() == 5 {
            tensor.permute([0, 1, 3, 2, 4]).contiguous()
        } else if tensor.size().len() == 4 {
            tensor.permute([0, 2, 1, 3]).contiguous()
        } else {
            panic!(
                "Input tensor rank should be one of [4, 5], but is: {}",
                tensor.size().len()
            )
        };
        let mut new_shape = tensor.size();
        new_shape.truncate(new_shape.len() - 2);
        new_shape.push(num_heads * attention_head_size);
        tensor.view(new_shape.as_slice())
    }

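    /// Scaled dot-product attention, computed in `Kind::Float` for numerical stability.
    /// Applies the causal mask (masked positions are set to the dtype minimum), optionally
    /// scales by `sqrt(head_dim)` and adds the attention mask, then applies softmax and
    /// dropout before weighting the values. Returns `(context, attention_probabilities)`.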
    fn attention(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Tensor) {
        let query = query.to_kind(Kind::Float);
        let key = key.to_kind(Kind::Float);

        let attention_weights = query.matmul(&key.transpose(-1, -2));

        let query_dims = query.size();
        let key_dims = key.size();
        let query_length = query_dims[query_dims.len() - 2];
        let key_length = key_dims[key_dims.len() - 2];

        let causal_mask = &self
            .bias
            .slice(2, key_length - query_length, key_length, 1)
            .slice(3, 0, key_length, 1)
            .to_kind(Kind::Bool)
            .to_device(attention_weights.device());

        let mask_value = get_min(attention_weights.kind()).unwrap();
        let mask_value = Tensor::full(
            attention_weights.size(),
            mask_value,
            (attention_weights.kind(), attention_weights.device()),
        );

        let mut attention_weights = attention_weights.where_self(causal_mask, &mask_value);
        if self.scale {
            attention_weights /= self.scale_attn;
        }
        if let Some(attention_mask_value) = attention_mask {
            attention_weights += attention_mask_value;
        };
        let attention_weights = attention_weights.softmax(-1, attention_weights.kind());
        let attention_weights = attention_weights
            .to_kind(value.kind())
            .apply_t(&self.attn_dropout, train);

        let attention_output = attention_weights.matmul(value);

        (attention_output, attention_weights)
    }

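    /// Forward pass through the attention layer.
    ///
    /// Projects the hidden states to queries, keys and values, applies rotary position
    /// embeddings (to the first `rotary_dim` features when configured), prepends the cached
    /// keys and values from `layer_past` if provided, and returns a tuple of
    /// `(attention output, new cache if use_cache, attention weights if output_attentions)`.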
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        layer_past: Option<&LayerState>,
        train: bool,
    ) -> (Tensor, Option<LayerState>, Option<Tensor>) {
        let query = hidden_states.apply(&self.q_proj);
        let key = hidden_states.apply(&self.k_proj);
        let value = hidden_states.apply(&self.v_proj);

        let mut query = Self::split_heads(&query, self.n_head, self.dim_per_head, true);
        let mut key = Self::split_heads(&key, self.n_head, self.dim_per_head, true);
        let mut value = Self::split_heads(&value, self.n_head, self.dim_per_head, false);

        let mut seq_len = key.size()[1];
        let mut offset = 0;

        if let Some(layer_past) = layer_past {
            offset = layer_past.prev_key.size()[layer_past.prev_key.size().len() - 2];
            seq_len += offset
        };

        if let Some(rotary_dim) = self.rotary_dim {
            let k_rot = key.slice(3, 0, rotary_dim, 1);
            let k_pass = key.slice(3, rotary_dim, key.size()[3], 1);

            let q_rot = query.slice(3, 0, rotary_dim, 1);
            let q_pass = query.slice(3, rotary_dim, query.size()[3], 1);

            let sincos = fixed_pos_embedding(&k_rot, seq_len);
            let k_rot = apply_rotary_pos_emb(&k_rot, &sincos, offset);
            let q_rot = apply_rotary_pos_emb(&q_rot, &sincos, offset);

            key = Tensor::cat(&[k_rot, k_pass], -1);
            query = Tensor::cat(&[q_rot, q_pass], -1);
        } else {
            let sincos = fixed_pos_embedding(&key, seq_len);
            key = apply_rotary_pos_emb(&key, &sincos, offset);
            query = apply_rotary_pos_emb(&query, &sincos, offset);
        }

        key = key.permute([0, 2, 1, 3]);
        query = query.permute([0, 2, 1, 3]);

        if let Some(layer_past) = layer_past {
            key = Tensor::cat(&[&layer_past.prev_key, &key], -2);
            value = Tensor::cat(&[&layer_past.prev_value, &value], -2);
        }

        let present = self.use_cache.then(|| LayerState {
            prev_key: key.copy(),
            prev_value: value.copy(),
        });

        let (attn_output, attn_weights) =
            self.attention(&query, &key, &value, attention_mask, train);

        let attn_output = Self::merge_heads(&attn_output, self.n_head, self.dim_per_head)
            .apply(&self.out_proj)
            .apply_t(&self.resid_dropout, train);

        let attn_weights = self.output_attentions.then_some(attn_weights);

        (attn_output, present, attn_weights)
    }
}

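/// Builds the sinusoidal tables used by the rotary position embeddings: inverse frequencies
/// `1 / 10000^(2i / dim)` combined with positions `0..seq_len`, returned as `(sin, cos)` of
/// shape `(seq_len, dim / 2)`.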
fn fixed_pos_embedding(x: &Tensor, seq_len: i64) -> (Tensor, Tensor) {
    let dim = x.size()[x.size().len() - 1];
    let inv_freq = 1.0
        / Tensor::pow_scalar(
            10_000,
            &(Tensor::arange_start_step(0, dim, 2, (x.kind(), x.device())) / dim),
        );
    let sinusoid_inp = Tensor::einsum(
        "i , j -> i j",
        &[Tensor::arange(seq_len, (x.kind(), x.device())), inv_freq],
        None::<i64>,
    );
    (sinusoid_inp.sin(), sinusoid_inp.cos())
}

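/// Applies the rotary position embedding `x * cos + rotate_every_two(x) * sin`, with the
/// sin/cos tables duplicated-interleaved to cover feature pairs and shifted by `offset`
/// (the number of positions already present in the cache).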
fn apply_rotary_pos_emb(x: &Tensor, (sin, cos): &(Tensor, Tensor), offset: i64) -> Tensor {
    let sin = duplicate_interleave(sin).i((NewAxis, offset..x.size()[1] + offset, NewAxis, ..));
    let cos = duplicate_interleave(cos).i((NewAxis, offset..x.size()[1] + offset, NewAxis, ..));
    (x * cos) + (rotate_every_two(x) * sin)
}

/// A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
fn duplicate_interleave(m: &Tensor) -> Tensor {
    let dim0 = m.size()[0];
    m.view([-1, 1]) // flatten the matrix
        .repeat([1, 2]) // repeat all elements into the 2nd dimension
        .view([dim0, -1]) // reshape into a matrix, interleaving the copy
}

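/// Rotates consecutive feature pairs by 90 degrees: `(x0, x1, x2, x3, ...)` becomes
/// `(-x1, x0, -x3, x2, ...)` along the last dimension.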
fn rotate_every_two(x: &Tensor) -> Tensor {
    let x1 = x.slice(3, 0, x.size()[3], 2);
    let x2 = x.slice(3, 1, x.size()[3], 2);
    Tensor::stack(&[-x2, x1], -1).flatten(-2, -1)
}