rust_bert/models/bert/attention.rs

// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::bert::bert_model::BertConfig;
use crate::common::activations::TensorFunction;
use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::{nn, Tensor};

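/// Multi-head self-attention layer of the BERT encoder. Holds the query, key and
/// value projections and splits the hidden dimension into `num_attention_heads`
/// heads of size `attention_head_size`.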
#[derive(Debug)]
pub struct BertSelfAttention {
    num_attention_heads: i64,
    attention_head_size: i64,
    dropout: Dropout,
    output_attentions: bool,
    query: nn::Linear,
    key: nn::Linear,
    value: nn::Linear,
}

impl BertSelfAttention {
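    /// Build a new `BertSelfAttention` from a variable-store path and a `BertConfig`.
    ///
    /// Panics if `config.hidden_size` is not a multiple of `config.num_attention_heads`.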
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertSelfAttention
    where
        P: Borrow<nn::Path<'p>>,
    {
        assert_eq!(
            config.hidden_size % config.num_attention_heads,
            0,
            "Hidden size not a multiple of the number of attention heads"
        );
        let p = p.borrow();

        let query = nn::linear(
            p / "query",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );
        let key = nn::linear(
            p / "key",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );
        let value = nn::linear(
            p / "value",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );

        let dropout = Dropout::new(config.attention_probs_dropout_prob);
        let attention_head_size = config.hidden_size / config.num_attention_heads;
        let output_attentions = config.output_attentions.unwrap_or(false);

        BertSelfAttention {
            num_attention_heads: config.num_attention_heads,
            attention_head_size,
            dropout,
            output_attentions,
            query,
            key,
            value,
        }
    }

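    /// Reshapes `(batch, seq, hidden)` into `(batch, heads, seq, head_size)` so that
    /// attention can be computed independently per head.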
    fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
        x.view((bs, -1, self.num_attention_heads, dim_per_head))
            .transpose(1, 2)
    }

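    /// Inverse of `split_heads`: merges the per-head outputs back into a single
    /// `(batch, seq, hidden)` tensor.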
    fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
        x.transpose(1, 2)
            .contiguous()
            .view((bs, -1, self.num_attention_heads * dim_per_head))
    }

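    /// Computes scaled dot-product attention:
    /// `softmax(Q * K^T / sqrt(head_size) + mask) * V`.
    ///
    /// When `encoder_hidden_states` is provided, keys and values are projected from
    /// the encoder states (cross-attention) and `encoder_mask` is used in place of
    /// `mask`. Returns the attention output and, if `output_attentions` is set, the
    /// attention weights.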
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        mask: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_mask: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Option<Tensor>) {
        let (key_layer, value_layer, mask) = match encoder_hidden_states {
            Some(encoder_hidden_state_values) => (
                encoder_hidden_state_values.apply(&self.key),
                encoder_hidden_state_values.apply(&self.value),
                encoder_mask,
            ),
            None => (
                hidden_states.apply(&self.key),
                hidden_states.apply(&self.value),
                mask,
            ),
        };

        let bs = hidden_states.size()[0];

        let query_layer = self.split_heads(
            hidden_states.apply(&self.query),
            bs,
            self.attention_head_size,
        );
        let key_layer = self.split_heads(key_layer, bs, self.attention_head_size);
        let value_layer = self.split_heads(value_layer, bs, self.attention_head_size);
        let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();

        let scores = if let Some(mask) = mask {
            query_layer.matmul(&key_layer.transpose(-1, -2)) + mask
        } else {
            query_layer.matmul(&key_layer.transpose(-1, -2))
        };

        let weights = scores
            .softmax(-1, scores.kind())
            .apply_t(&self.dropout, train);
        let context = self.flatten(weights.matmul(&value_layer), bs, self.attention_head_size);

        if !self.output_attentions {
            (context, None)
        } else {
            (context, Some(weights))
        }
    }
}

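/// Output projection applied after self-attention: a linear layer, dropout, a
/// residual connection to the attention input, and layer normalization.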
#[derive(Debug)]
pub struct BertSelfOutput {
    linear: nn::Linear,
    layer_norm: nn::LayerNorm,
    dropout: Dropout,
}

impl BertSelfOutput {
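    /// Build a new `BertSelfOutput` from a variable-store path and a `BertConfig`.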
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertSelfOutput
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let linear = nn::linear(
            p / "dense",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );
        let layer_norm_config = nn::LayerNormConfig {
            eps: 1e-12,
            ..Default::default()
        };
        let layer_norm =
            nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
        let dropout = Dropout::new(config.hidden_dropout_prob);

        BertSelfOutput {
            linear,
            layer_norm,
            dropout,
        }
    }

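    /// Applies the projection and dropout to `hidden_states`, adds the residual
    /// `input_tensor`, and layer-normalizes the sum.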
    pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
        let hidden_states: Tensor = input_tensor
            + hidden_states
                .apply(&self.linear)
                .apply_t(&self.dropout, train);
        hidden_states.apply(&self.layer_norm)
    }
}

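/// Complete BERT attention block, combining `BertSelfAttention` with its
/// `BertSelfOutput` projection and residual connection.
///
/// # Example
///
/// A minimal usage sketch, not taken from the crate's own docs. It assumes these
/// attention types are re-exported under `rust_bert::bert` and that a BERT
/// `config.json` is available locally; adjust the imports and path to your setup.
///
/// ```no_run
/// use rust_bert::bert::{BertAttention, BertConfig};
/// use rust_bert::Config;
/// use tch::{nn, Device, Kind, Tensor};
///
/// let vs = nn::VarStore::new(Device::Cpu);
/// let config = BertConfig::from_file("path/to/config.json");
/// let attention = BertAttention::new(vs.root() / "attention", &config);
///
/// // Random batch of 2 sequences of length 8 (placeholder input).
/// let hidden_states = Tensor::rand(&[2, 8, config.hidden_size], (Kind::Float, Device::Cpu));
/// let (_context, _attention_weights) =
///     attention.forward_t(&hidden_states, None, None, None, false);
/// ```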
#[derive(Debug)]
pub struct BertAttention {
    _self: BertSelfAttention,
    output: BertSelfOutput,
}

impl BertAttention {
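    /// Build a new `BertAttention` from a variable-store path and a `BertConfig`,
    /// registering the sub-modules under the "self" and "output" sub-paths.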
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertAttention
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let _self = BertSelfAttention::new(p / "self", config);
        let output = BertSelfOutput::new(p / "output", config);
        BertAttention { _self, output }
    }

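    /// Runs self-attention (or cross-attention, if `encoder_hidden_states` is given)
    /// followed by the output projection, residual connection, and layer norm.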
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        mask: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_mask: Option<&Tensor>,
        train: bool,
    ) -> (Tensor, Option<Tensor>) {
        let (self_output, attention_weights) = self._self.forward_t(
            hidden_states,
            mask,
            encoder_hidden_states,
            encoder_mask,
            train,
        );

        let self_output = self.output.forward_t(&self_output, hidden_states, train);
        (self_output, attention_weights)
    }
}

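/// Feed-forward expansion of the BERT layer: projects from `hidden_size` up to
/// `intermediate_size` and applies the configured activation function.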
pub struct BertIntermediate {
    lin: nn::Linear,
    activation: TensorFunction,
}

impl BertIntermediate {
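    /// Build a new `BertIntermediate` from a variable-store path and a `BertConfig`.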
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertIntermediate
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let lin = nn::linear(
            p / "dense",
            config.hidden_size,
            config.intermediate_size,
            Default::default(),
        );
        let activation = config.hidden_act.get_function();
        BertIntermediate { lin, activation }
    }

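    /// Applies the up-projection followed by the activation function.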
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        (self.activation.get_fn())(&hidden_states.apply(&self.lin))
    }
}

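/// Final projection of the BERT layer: maps `intermediate_size` back down to
/// `hidden_size`, with dropout, a residual connection, and layer normalization.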
pub struct BertOutput {
    lin: nn::Linear,
    layer_norm: nn::LayerNorm,
    dropout: Dropout,
}

impl BertOutput {
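    /// Build a new `BertOutput` from a variable-store path and a `BertConfig`.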
    pub fn new<'p, P>(p: P, config: &BertConfig) -> BertOutput
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let lin = nn::linear(
            p / "dense",
            config.intermediate_size,
            config.hidden_size,
            Default::default(),
        );
        let layer_norm_config = nn::LayerNormConfig {
            eps: 1e-12,
            ..Default::default()
        };
        let layer_norm =
            nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
        let dropout = Dropout::new(config.hidden_dropout_prob);

        BertOutput {
            lin,
            layer_norm,
            dropout,
        }
    }

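    /// Applies the down-projection and dropout to `hidden_states`, adds the residual
    /// `input_tensor`, and layer-normalizes the sum.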
    pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
        let hidden_states: Tensor =
            input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
        hidden_states.apply(&self.layer_norm)
    }
}