//! Self-attention layer for GPT-Neo, together with the per-layer cache of previously
//! computed attention keys and values (`LayerState`).
use crate::common::dropout::Dropout;
use crate::gpt_neo::gpt_neo_model::AttentionLayerType;
use crate::gpt_neo::GptNeoConfig;
use std::borrow::Borrow;
use tch::{nn, Kind, Tensor};
/// Cache holding the attention keys and values produced by a GPT-Neo attention layer.
///
/// `Clone` is implemented by hand for this type, so only `Debug` is derived.
#[derive(Debug)]
pub struct LayerState {
    /// Cached key tensor.
    pub prev_key: Tensor,
    /// Cached value tensor; may be absent.
    pub prev_value: Option<Tensor>,
}
impl Clone for LayerState {
fn clone(&self) -> Self {
LayerState {
prev_key: self.prev_key.copy(),
prev_value: self.prev_value.as_ref().map(|value| value.copy()),
}
}
}
impl LayerState {
    /// Re-selects the entries of the cached tensors along dimension 0.
    ///
    /// Both `prev_key` and (when present) `prev_value` are replaced by the result of
    /// `index_select(0, new_indices)`.
    ///
    /// NOTE(review): dimension 0 is presumably the batch/beam dimension - confirm with
    /// the callers of this method.
    pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
        self.prev_key = self.prev_key.index_select(0, new_indices);
        if let Some(cached_value) = self.prev_value.as_mut() {
            *cached_value = cached_value.index_select(0, new_indices);
        }
    }
}
/// Self-attention layer of the GPT-Neo model.
///
/// Holds the query/key/value/output projections, the dropout modules, and the
/// precomputed attention mask buffer that `attend` slices at run time.
pub struct GptNeoSelfAttention {
    /// Query projection (`hidden_size -> hidden_size`, created without a bias term).
    q_proj: nn::Linear,
    /// Key projection (`hidden_size -> hidden_size`, created without a bias term).
    k_proj: nn::Linear,
    /// Value projection (`hidden_size -> hidden_size`, created without a bias term).
    v_proj: nn::Linear,
    /// Output projection applied to the merged heads (default linear configuration).
    out_proj: nn::Linear,
    /// Dropout applied to the attention weights after the softmax.
    attention_dropout: Dropout,
    /// Dropout applied to the projected attention output.
    resid_dropout: Dropout,
    /// Attention mask buffer of shape `[1, 1, max_positions, max_positions]`, stored as
    /// `Kind::Uint8` and built once in the constructor.
    bias: Tensor,
    /// Number of attention heads.
    num_heads: i64,
    /// Size of each attention head (`hidden_size / num_heads`).
    head_dim: i64,
    /// Whether `forward_t` returns the attention weights.
    output_attentions: bool,
}
impl GptNeoSelfAttention {
    /// Builds a GPT-Neo self-attention layer.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path under which the `k_proj`, `v_proj`, `q_proj` and
    ///   `out_proj` linear layers are registered.
    /// * `config` - Model configuration. Supplies the hidden size, number of heads, maximum
    ///   number of positions, local window size, dropout settings and the
    ///   `output_attentions` flag (treated as `false` when unset).
    /// * `attention_type` - Selects the mask: `AttentionLayerType::Local` restricts each
    ///   position to a window of `config.window_size` positions; any other value keeps the
    ///   plain causal mask.
    pub fn new<'p, P>(
        p: P,
        config: &GptNeoConfig,
        attention_type: &AttentionLayerType,
    ) -> GptNeoSelfAttention
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let max_positions = config.max_position_embeddings;
        // Causal mask: entry (i, j) is 1 when j <= i, i.e. a position may only attend to
        // itself and to earlier positions.
        let mut bias = Tensor::ones(&[max_positions, max_positions], (Kind::Uint8, p.device()))
            .tril(0)
            .view([1, 1, max_positions, max_positions])
            .requires_grad_(false);

        if attention_type == &AttentionLayerType::Local {
            // `tril(-window_size)` marks the entries with j <= i - window_size. XOR-ing
            // them out of the causal mask leaves the band i - window_size < j <= i.
            // (An OR here would be a no-op because that triangle is already a subset of
            // the causal mask, which would silently turn local layers into global ones.)
            let _ = bias.bitwise_xor_tensor_(&bias.tril(-config.window_size));
        }

        let attention_dropout = Dropout::new(config.attention_dropout);
        let resid_dropout = Dropout::new(config.resid_dropout);

        let num_heads = config.num_heads;
        // Integer division: the hidden size is split evenly across the heads.
        let head_dim = config.hidden_size / config.num_heads;

        // The query/key/value projections carry no bias term; the output projection uses
        // the default linear configuration.
        let linear_config = nn::LinearConfig {
            bias: false,
            ..Default::default()
        };
        let k_proj = nn::linear(
            p / "k_proj",
            config.hidden_size,
            config.hidden_size,
            linear_config,
        );
        let v_proj = nn::linear(
            p / "v_proj",
            config.hidden_size,
            config.hidden_size,
            linear_config,
        );
        let q_proj = nn::linear(
            p / "q_proj",
            config.hidden_size,
            config.hidden_size,
            linear_config,
        );
        let out_proj = nn::linear(
            p / "out_proj",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );

        let output_attentions = config.output_attentions.unwrap_or(false);

        GptNeoSelfAttention {
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            attention_dropout,
            resid_dropout,
            bias,
            num_heads,
            head_dim,
            output_attentions,
        }
    }

    /// Splits the last dimension of `input_tensor` into `[num_heads, attention_head_size]`
    /// and swaps dimensions 1 and 2 of the resulting tensor.
    ///
    /// The `permute(&[0, 2, 1, 3])` call requires the reshaped tensor to have rank 4, so
    /// the input is expected to have rank 3.
    fn split_heads(input_tensor: &Tensor, num_heads: i64, attention_head_size: i64) -> Tensor {
        let mut new_shape = input_tensor.size();
        // Replace the trailing dimension by the (heads, head size) pair.
        let _ = new_shape.pop();
        new_shape.extend_from_slice(&[num_heads, attention_head_size]);
        let reshaped_tensor = input_tensor.view(new_shape.as_slice());
        reshaped_tensor.permute(&[0, 2, 1, 3])
    }

    /// Inverse of [`Self::split_heads`]: swaps dimensions 1 and 2 back and collapses the
    /// two trailing dimensions into a single one of size `num_heads * attention_head_size`.
    fn merge_heads(input_tensor: &Tensor, num_heads: i64, attention_head_size: i64) -> Tensor {
        // `view` needs contiguous memory after the permutation.
        let output_tensor = input_tensor.permute(&[0, 2, 1, 3]).contiguous();
        let mut new_shape = output_tensor.size();
        new_shape.truncate(new_shape.len() - 2);
        new_shape.push(num_heads * attention_head_size);
        output_tensor.view(new_shape.as_slice())
    }

    /// Computes masked attention over `value`.
    ///
    /// # Arguments
    ///
    /// * `query`, `key`, `value` - Per-head tensors; the second-to-last dimension of
    ///   `query` and `key` is read as the query and key length respectively.
    /// * `attention_mask` - Optional tensor added to the raw attention scores before the
    ///   softmax.
    /// * `train` - Forwarded to the attention dropout.
    ///
    /// # Returns
    ///
    /// `(attention_output, attention_weights)`, where the weights are the post-softmax,
    /// post-dropout values cast to the kind of `value`.
    fn attend(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Tensor) {
        // The scores are computed in single precision regardless of the input kind.
        let query = query.to_kind(Kind::Float);
        let key = key.to_kind(Kind::Float);

        // Raw scores; no scaling factor is applied to them in this implementation.
        let attention_weights = query.matmul(&key.transpose(-1, -2));
        let device = attention_weights.device();

        let query_dims = query.size();
        let key_dims = key.size();
        let query_length = query_dims[query_dims.len() - 2];
        let key_length = key_dims[key_dims.len() - 2];

        // Take the mask rows of the last `query_length` positions and the first
        // `key_length` columns, so the mask lines up with keys that include cached ones.
        let causal_mask = self
            .bias
            .slice(2, key_length - query_length, key_length, 1)
            .slice(3, 0, key_length, 1)
            .to_kind(Kind::Bool)
            .to_device(device);

        // Keep the score where the mask is set, otherwise substitute a large negative
        // value so the softmax assigns (almost) no weight to that position.
        let mut attention_weights = attention_weights.where_self(
            &causal_mask,
            &Tensor::of_slice(&[-1e9f32]).to_device(device),
        );

        if let Some(attention_mask_value) = attention_mask {
            attention_weights = attention_weights + attention_mask_value;
        }

        let attention_weights = attention_weights.softmax(-1, attention_weights.kind());
        let attention_weights = attention_weights
            .to_kind(value.kind())
            .apply_t(&self.attention_dropout, train);

        let attention_output = attention_weights.matmul(value);
        (attention_output, attention_weights)
    }

    /// Runs the attention layer on `hidden_states`.
    ///
    /// # Arguments
    ///
    /// * `hidden_states` - Input tensor; its last dimension is projected to queries, keys
    ///   and values and then split into heads.
    /// * `layer_state` - Optional cache from a previous call. When given, its keys and
    ///   values are concatenated (along the second-to-last dimension) in front of the
    ///   newly computed ones.
    /// * `attention_mask` - Optional tensor added to the attention scores.
    /// * `train` - Enables the dropout modules when `true`.
    ///
    /// # Returns
    ///
    /// A tuple of
    /// * the attention output after the output projection and residual dropout,
    /// * the attention weights, only when `output_attentions` was enabled in the
    ///   configuration,
    /// * the updated cache (always `Some`), holding copies of the full keys and values.
    ///
    /// # Panics
    ///
    /// Panics if `layer_state` is provided with `prev_value` set to `None`.
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        layer_state: Option<&LayerState>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Option<Tensor>, Option<LayerState>) {
        let query = hidden_states.apply(&self.q_proj);
        let key = hidden_states.apply(&self.k_proj);
        let value = hidden_states.apply(&self.v_proj);

        let query = Self::split_heads(&query, self.num_heads, self.head_dim);
        let mut key = Self::split_heads(&key, self.num_heads, self.head_dim);
        let mut value = Self::split_heads(&value, self.num_heads, self.head_dim);

        if let Some(layer_state_value) = layer_state {
            // Prepend the cached keys/values to the ones computed for the new positions.
            let prev_value = layer_state_value
                .prev_value
                .as_ref()
                .expect("a provided GPT-Neo layer state must hold cached values");
            key = Tensor::cat(&[&layer_state_value.prev_key, &key], -2);
            value = Tensor::cat(&[prev_value, &value], -2);
        }

        // The returned cache owns copies, so it stays valid independently of `key`/`value`.
        let layer_state = Some(LayerState {
            prev_key: key.copy(),
            prev_value: Some(value.copy()),
        });

        let (attention_output, attention_weights) =
            self.attend(&query, &key, &value, attention_mask, train);

        let attention_output = Self::merge_heads(&attention_output, self.num_heads, self.head_dim)
            .apply(&self.out_proj)
            .apply_t(&self.resid_dropout, train);

        let attention_weights = if self.output_attentions {
            Some(attention_weights)
        } else {
            None
        };

        (attention_output, attention_weights, layer_state)
    }
}