rust_bert/models/bart/attention.rs

// Copyright 2020 The Facebook AI Research Team Authors
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::{nn, Tensor};

#[derive(Debug)]
/// # Cache for BART attention layers
/// Stores cached keys and values to avoid recalculating them (e.g. at each generation step).
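///
/// A minimal sketch of how a decoding loop might reuse the cache (hypothetical
/// usage, assuming the `forward_t` signature defined below):
/// ```ignore
/// let mut cache: Option<LayerState> = None;
/// for _step in 0..max_length {
///     let (_output, _attn, new_cache) =
///         attention.forward_t(&decoder_hidden_states, None, None, cache, false);
///     cache = new_cache;
/// }
/// ```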
pub struct LayerState {
    /// Cached keys
    pub prev_key: Tensor,
    /// Cached values
    pub prev_value: Tensor,
}

// `Tensor` does not implement `Clone`, so cloning a `LayerState` deep-copies the cached tensors.
impl Clone for LayerState {
    fn clone(&self) -> Self {
        LayerState {
            prev_key: self.prev_key.copy(),
            prev_value: self.prev_value.copy(),
        }
    }
}

impl LayerState {
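    /// Reorders the cached keys and values along the batch dimension (dim 0),
    /// e.g. to follow hypothesis re-ranking between beam search steps.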
    pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
        self.prev_key = self.prev_key.index_select(0, new_indices);
        self.prev_value = self.prev_value.index_select(0, new_indices);
    }
}

#[derive(Debug)]
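/// # Multi-headed attention layer for BART
/// Scaled dot-product attention with learned query, key, value and output projections,
/// used for both self-attention and encoder-decoder (cross) attention.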
pub struct BartAttention {
    num_heads: i64,
    head_dim: i64,
    dropout: Dropout,
    scaling: f64,
    encoder_decoder_attention: bool,
    output_attentions: bool,
    k_proj: nn::Linear,
    v_proj: nn::Linear,
    q_proj: nn::Linear,
    out_proj: nn::Linear,
    store_cache: bool,
}

impl BartAttention {
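    /// Creates a new `BartAttention` layer.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path where the projection weights are created
    /// * `embed_dim` - Hidden size of the model (assumed divisible by `num_heads`)
    /// * `num_heads` - Number of attention heads
    /// * `dropout` - Dropout probability applied to the attention weights
    /// * `encoder_decoder_attention` - If `true`, keys and values are computed from the encoder states
    /// * `store_cache` - If `true`, computed keys and values are returned for reuse at the next step
    /// * `output_attentions` - If `true`, attention weights are returned alongside the output
    ///
    /// # Example
    ///
    /// A minimal construction sketch (dimensions are illustrative, matching a
    /// BART-large configuration, and are not taken from this file):
    ///
    /// ```ignore
    /// use tch::{nn, Device};
    /// let vs = nn::VarStore::new(Device::Cpu);
    /// let attention = BartAttention::new(vs.root() / "attn", 1024, 16, 0.1, false, true, false);
    /// ```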
    pub fn new<'p, P>(
        p: P,
        embed_dim: i64,
        num_heads: i64,
        dropout: f64,
        encoder_decoder_attention: bool,
        store_cache: bool,
        output_attentions: bool,
    ) -> BartAttention
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let k_proj = nn::linear(p / "k_proj", embed_dim, embed_dim, Default::default());
        let v_proj = nn::linear(p / "v_proj", embed_dim, embed_dim, Default::default());
        let q_proj = nn::linear(p / "q_proj", embed_dim, embed_dim, Default::default());
        let out_proj = nn::linear(p / "out_proj", embed_dim, embed_dim, Default::default());

        // Queries are scaled by 1/sqrt(head_dim), the standard attention scaling factor.
        let head_dim = embed_dim / num_heads;
        let scaling = (head_dim as f64).powf(-0.5);
        let dropout = Dropout::new(dropout);

        BartAttention {
            num_heads,
            head_dim,
            dropout,
            scaling,
            encoder_decoder_attention,
            output_attentions,
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            store_cache,
        }
    }

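    /// Splits the hidden dimension into attention heads and moves the head axis forward:
    /// (batch size, sequence length, embed_dim) -> (batch size, num_heads, sequence length, head_dim).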
    fn _shape(&self, x: Tensor, sequence_length: i64, batch_size: i64) -> Tensor {
        x.view((batch_size, sequence_length, self.num_heads, self.head_dim))
            .transpose(1, 2)
            .contiguous()
    }

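    /// Forward pass through the attention layer.
    ///
    /// # Arguments
    ///
    /// * `hidden_states` - Input tensor of shape (batch size, target length, embed_dim)
    /// * `key_value_states` - Optional encoder hidden states (required for encoder-decoder attention when no cache is provided)
    /// * `attention_mask` - Optional additive attention mask broadcast over the scores
    /// * `layer_state` - Optional cached keys and values from previous decoding steps
    /// * `train` - If `true`, applies dropout to the attention weights
    ///
    /// # Returns
    ///
    /// A tuple of:
    /// * Attention output of shape (batch size, target length, embed_dim)
    /// * Attention weights, if `output_attentions` is set
    /// * Updated cache, if `store_cache` is set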
    pub fn forward_t(
        &self,
        hidden_states: &Tensor,
        key_value_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        layer_state: Option<LayerState>,
        train: bool,
    ) -> (Tensor, Option<Tensor>, Option<LayerState>) {
        let (bs, target_length, embed_dim) = hidden_states.size3().unwrap();

        // Queries are pre-scaled by 1/sqrt(head_dim) so the subsequent matmul
        // directly yields scaled dot-product attention scores.
        let query_states = hidden_states.apply(&self.q_proj) * self.scaling;

        let (key_states, value_states) = if self.encoder_decoder_attention {
            // Cross-attention: keys and values come from the encoder states and do not
            // change across decoding steps, so a cached value can be reused as-is.
            if let Some(layer_state_value) = layer_state {
                (layer_state_value.prev_key, layer_state_value.prev_value)
            } else {
                (
                    self._shape(key_value_states.unwrap().apply(&self.k_proj), -1, bs),
                    self._shape(key_value_states.unwrap().apply(&self.v_proj), -1, bs),
                )
            }
        } else if let Some(layer_state_value) = layer_state {
            // Cached self-attention: project the new positions only, then append them
            // to the cached keys and values along the time dimension (dim 2).
            let key_states = self._shape(hidden_states.apply(&self.k_proj), -1, bs);
            let value_states = self._shape(hidden_states.apply(&self.v_proj), -1, bs);
            (
                Tensor::cat(&[layer_state_value.prev_key, key_states], 2),
                Tensor::cat(&[layer_state_value.prev_value, value_states], 2),
            )
        } else {
            // Self-attention without a cache: project the full input sequence.
            (
                self._shape(hidden_states.apply(&self.k_proj), -1, bs),
                self._shape(hidden_states.apply(&self.v_proj), -1, bs),
            )
        };

        // Snapshot the (batch, heads, seq, head_dim) keys and values before flattening,
        // so they can be fed back in at the next generation step.
        let new_layer_state = if self.store_cache {
            Some(LayerState {
                prev_key: key_states.copy(),
                prev_value: value_states.copy(),
            })
        } else {
            None
        };

        // Merge the batch and head dimensions for a single batched matmul.
        let proj_shape = [bs * self.num_heads, -1, self.head_dim];
        let query_states = self
            ._shape(query_states, target_length, bs)
            .view(proj_shape);
        let key_states = key_states.view(proj_shape);
        let value_states = value_states.view(proj_shape);

        let source_length = key_states.size()[1];
        let mut attention_weights = query_states.bmm(&key_states.transpose(1, 2));

        // The attention mask is additive (large negative values at masked positions)
        // and is broadcast over the per-head scores.
        if let Some(attention_mask_value) = attention_mask {
            attention_weights =
                attention_weights.view([bs, self.num_heads, target_length, source_length])
                    + attention_mask_value;
            attention_weights =
                attention_weights.view([bs * self.num_heads, target_length, source_length]);
        };

        attention_weights = attention_weights.softmax(-1, attention_weights.kind());

        let saved_attention_weights = if self.output_attentions {
            Some(attention_weights.view((bs, self.num_heads, target_length, source_length)))
        } else {
            None
        };

        let attention_probas = attention_weights.apply_t(&self.dropout, train);
        // Weighted sum of the values, then restore (batch, seq, embed) and apply the output projection.
        let attention_output = attention_probas
            .bmm(&value_states)
            .view([bs, self.num_heads, target_length, self.head_dim])
            .transpose(1, 2)
            .reshape([bs, target_length, embed_dim])
            .apply(&self.out_proj);

        (attention_output, saved_attention_weights, new_layer_state)
    }
}
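
#[cfg(test)]
mod tests {
    use super::*;
    use tch::{nn, Device, Kind, Tensor};

    // Hypothetical smoke-test sketch (not part of the original file): checks that a
    // self-attention forward pass preserves the input shape and returns a cache.
    #[test]
    fn self_attention_output_shape() {
        let vs = nn::VarStore::new(Device::Cpu);
        // Illustrative dimensions: embed_dim = 16 split across 4 heads, caching enabled.
        let attention = BartAttention::new(vs.root() / "attn", 16, 4, 0.1, false, true, false);
        let hidden_states = Tensor::rand(&[2, 5, 16], (Kind::Float, Device::Cpu));
        let (output, _attn, cache) = attention.forward_t(&hidden_states, None, None, None, false);
        assert_eq!(output.size(), vec![2, 5, 16]);
        assert!(cache.is_some());
    }
}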