rust_bert/models/longt5/
attention.rs

1// Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
2// Copyright 2022 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//     http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::common::dropout::Dropout;
14use crate::longt5::layer_norm::LongT5LayerNorm;
15use crate::longt5::LongT5Config;
16use crate::t5::{
17    get_relative_position_bucket, LayerState as T5layerState, T5Attention, T5LayerCrossAttention,
18};
19use std::borrow::Borrow;
20use tch::nn::LinearConfig;
21use tch::{nn, Device, IndexOp, Kind, Tensor};
22
23pub type LongT5Attention = T5Attention;
24pub type LongT5LayerCrossAttention = T5LayerCrossAttention;
25pub type LayerState = T5layerState;
26
27fn pad_to_multiple(x: &Tensor, block_length: i64, dim: usize, pad_value: f64) -> Tensor {
28    let mut x_size = x.size();
29    let pad_length = (-x_size[dim]).rem_euclid(block_length);
30
31    if x_size.iter().any(|&el| el == 0) {
32        x_size[dim] += pad_length;
33        Tensor::zeros(x_size.as_slice(), (x.kind(), x.device()))
34    } else {
35        let mut pad = vec![0i64; 2 * x.dim()];
36        pad[2 * dim] = pad_length;
37        pad.reverse();
38        x.pad(pad.as_slice(), "constant", pad_value)
39    }
40}
41
42fn split_into_blocks(x: &Tensor, block_length: i64, dim: usize) -> Tensor {
43    let x_size = x.size();
44    let padded_x = if x_size[dim] % block_length != 0 {
45        Some(pad_to_multiple(x, block_length, dim, 0f64))
46    } else {
47        None
48    };
49    let x = padded_x.as_ref().unwrap_or(x);
50    let mut x_size = x.size();
51    let num_blocks = x_size[dim] / block_length;
52    x_size.remove(dim);
53    x_size.insert(dim, block_length);
54    x_size.insert(dim, num_blocks);
55    if x_size.iter().any(|&el| el == 0) {
56        Tensor::empty(x_size.as_slice(), (x.kind(), x.device()))
57    } else {
58        x.reshape(x_size.as_slice())
59    }
60}
61
62fn concatenate_3_blocks(
63    x: &Tensor,
64    block_dim: usize,
65    sequence_dim: i64,
66    pad_value: Option<f64>,
67) -> Tensor {
68    let x_size = x.size();
69    let num_blocks = x_size[block_dim];
70    let mut pad = vec![0i64; 2 * x.dim()];
71    pad[2 * block_dim] = 1;
72    pad[2 * block_dim + 1] = 1;
73    pad.reverse();
74    let x = x.pad(pad.as_slice(), "constant", pad_value.unwrap_or(0f64));
75    let mut block_list: Vec<Tensor> = Vec::with_capacity(3);
76    for i in 0..3 {
77        block_list.push(x.narrow(block_dim as i64, i, num_blocks));
78    }
79    Tensor::cat(block_list.as_slice(), sequence_dim)
80}
81
82fn make_3blocks_relative_position_ids(block_length: i64, device: Device) -> Tensor {
83    let position_ids = Tensor::arange(3 * block_length, (Kind::Int, device));
84    let center_position_ids = position_ids.i(block_length..2 * block_length);
85    position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1)
86}
87
88fn mask_local_attention_mask(local_attention_mask: &Tensor, block_length: i64) -> Tensor {
89    let relative_position_ids =
90        make_3blocks_relative_position_ids(block_length, local_attention_mask.device());
91    let locality_mask = relative_position_ids
92        .abs()
93        .lt(block_length)
94        .unsqueeze(0)
95        .unsqueeze(0);
96    local_attention_mask.logical_and(&locality_mask)
97}
98
99pub(crate) fn get_local_attention_mask(attention_mask: &Tensor, block_length: i64) -> Tensor {
100    let blocked_attention_mask = split_into_blocks(attention_mask, block_length, 1);
101    let three_blocked_attention_mask = concatenate_3_blocks(&blocked_attention_mask, 1, 2, None);
102
103    let blocked_attention_mask = blocked_attention_mask.unsqueeze(-1);
104    let three_blocked_attention_mask = three_blocked_attention_mask.unsqueeze(-2);
105
106    let local_attention_mask = mask_local_attention_mask(
107        &blocked_attention_mask.logical_and(&three_blocked_attention_mask),
108        block_length,
109    );
110    local_attention_mask.unsqueeze(1)
111}
112
113fn make_global_fixed_block_ids(
114    attention_mask: &Tensor,
115    global_block_size: i64,
116) -> (Tensor, Tensor) {
117    let &[batch_size, seq_length, ..] = attention_mask.size().as_slice() else {
118        unreachable!()
119    };
120
121    let handle_orphan_tokens = |block_ids: Tensor| -> Tensor {
122        let block_ends = Tensor::arange(seq_length, (Kind::Int64, block_ids.device()))
123            .remainder(global_block_size)
124            .eq(global_block_size - 1);
125        let true_block_ends = block_ends.logical_and(&block_ids.ge(0));
126        let full_blocks = true_block_ends
127            .sum_dim_intlist([-1].as_slice(), false, block_ids.kind())
128            .unsqueeze(-1)
129            - 1;
130        block_ids.where_self(&block_ids.lt_tensor(&full_blocks), &full_blocks)
131    };
132
133    let fixed_block_mask = attention_mask.ones_like() / global_block_size;
134    let fixed_block_mask = fixed_block_mask.cumsum(1, fixed_block_mask.kind()) - fixed_block_mask;
135    let mask = attention_mask
136        .ones_like()
137        .where_scalarother(&attention_mask.not_equal(0.0), -1000.0);
138
139    let mut global_block_ids = (mask + fixed_block_mask - 1.0).floor();
140    global_block_ids = global_block_ids.where_scalarother(&global_block_ids.gt(-1.0), -1.0);
141    global_block_ids = global_block_ids * attention_mask + attention_mask - 1;
142    global_block_ids = handle_orphan_tokens(global_block_ids);
143    let num_globals = seq_length / global_block_size;
144    let sequence_block_ids_max = if num_globals > 0 {
145        global_block_ids
146            .max_dim(-1, false)
147            .0
148            .repeat([num_globals, 1])
149            .transpose(0, 1)
150    } else {
151        Tensor::zeros(
152            [batch_size, 0],
153            (global_block_ids.kind(), global_block_ids.device()),
154        )
155    };
156    let global_segment_ids = Tensor::ones(
157        [batch_size, num_globals],
158        (attention_mask.kind(), attention_mask.device()),
159    )
160    .cumsum(-1, attention_mask.kind())
161        - 1;
162    let global_segment_ids = global_segment_ids
163        .ones_like()
164        .where_scalarother(&global_segment_ids.le_tensor(&sequence_block_ids_max), 0.0);
165    (
166        global_block_ids.to_kind(Kind::Int),
167        global_segment_ids.to_kind(Kind::Int),
168    )
169}
170
171fn make_side_relative_position_ids(attention_mask: &Tensor, global_block_size: i64) -> Tensor {
172    let (block_ids, global_segment_ids) =
173        make_global_fixed_block_ids(attention_mask, global_block_size);
174    let global_seq_length = *global_segment_ids.size().last().unwrap();
175    let global_positions = Tensor::arange(global_seq_length, (Kind::Int64, block_ids.device()));
176    global_positions - block_ids.unsqueeze(-1)
177}
178
179fn create_global_aggregates(
180    hidden_states: &Tensor,
181    block_ids: &Tensor,
182    global_seq_length: i64,
183) -> Tensor {
184    let block_ids = block_ids.where_scalarother(&block_ids.ge(0), global_seq_length);
185    let one_hot_block_ids = block_ids
186        .to_kind(Kind::Int64)
187        .one_hot(global_seq_length + 1);
188    let one_hot_block_ids = one_hot_block_ids.narrow(2, 0, one_hot_block_ids.size()[2] - 1);
189    Tensor::einsum(
190        "...nd,...ng->...gd",
191        &[
192            hidden_states,
193            &one_hot_block_ids.to_kind(hidden_states.kind()),
194        ],
195        None::<i64>,
196    )
197}
198
199fn compute_bias(
200    block_length: i64,
201    relative_attention_bias: &nn::Embedding,
202    is_decoder: bool,
203    relative_attention_num_buckets: i64,
204    relative_attention_max_distance: i64,
205) -> Tensor {
206    let device = relative_attention_bias.ws.device();
207    let memory_position = Tensor::arange(3 * block_length, (Kind::Int64, device));
208    let context_position = memory_position.narrow(0, block_length, block_length);
209    let relative_position = memory_position.unsqueeze(0) - context_position.unsqueeze(-1);
210
211    let rp_bucket = get_relative_position_bucket(
212        &relative_position,
213        !is_decoder,
214        relative_attention_num_buckets,
215        relative_attention_max_distance,
216    );
217    rp_bucket
218        .apply(relative_attention_bias)
219        .permute([2, 0, 1])
220        .unsqueeze(0)
221        .unsqueeze(0)
222}
223
224pub struct LongT5LocalAttention {
225    is_decoder: bool,
226    has_relative_attention_bias: bool,
227    relative_attention_num_buckets: i64,
228    relative_attention_max_distance: i64,
229    key_value_proj_dim: i64,
230    n_heads: i64,
231    block_length: i64,
232    dropout: Dropout,
233    inner_dim: i64,
234    output_attentions: bool,
235    query: nn::Linear,
236    key: nn::Linear,
237    value: nn::Linear,
238    output: nn::Linear,
239    relative_attention_bias: Option<nn::Embedding>,
240}
241
242impl LongT5LocalAttention {
243    pub fn new<'p, P>(
244        p: P,
245        config: &LongT5Config,
246        is_decoder: bool,
247        has_relative_attention_bias: bool,
248    ) -> LongT5LocalAttention
249    where
250        P: Borrow<nn::Path<'p>>,
251    {
252        let p = p.borrow();
253
254        let linear_config = LinearConfig {
255            bias: false,
256            ..Default::default()
257        };
258
259        let block_length = config.local_radius + 1;
260        let key_value_proj_dim = config.d_kv;
261
262        let inner_dim = config.num_heads * config.d_kv;
263        let key = nn::linear(p / "k", config.d_model, inner_dim, linear_config);
264        let value = nn::linear(p / "v", config.d_model, inner_dim, linear_config);
265        let query = nn::linear(p / "q", config.d_model, inner_dim, linear_config);
266        let output = nn::linear(p / "o", inner_dim, config.d_model, linear_config);
267
268        let dropout = Dropout::new(config.dropout_rate);
269        let relative_attention_bias = if has_relative_attention_bias {
270            Some(nn::embedding(
271                p / "relative_attention_bias",
272                config.relative_attention_num_buckets,
273                config.num_heads,
274                Default::default(),
275            ))
276        } else {
277            None
278        };
279
280        LongT5LocalAttention {
281            is_decoder,
282            has_relative_attention_bias,
283            relative_attention_num_buckets: config.relative_attention_num_buckets,
284            relative_attention_max_distance: config.relative_attention_max_distance.unwrap_or(128),
285            key_value_proj_dim,
286            n_heads: config.num_heads,
287            block_length,
288            dropout,
289            inner_dim,
290            output_attentions: config.output_attentions.unwrap_or(false),
291            query,
292            key,
293            value,
294            output,
295            relative_attention_bias,
296        }
297    }
298
299    pub fn forward_t(
300        &self,
301        hidden_states: &Tensor,
302        mask: Option<&Tensor>,
303        position_bias: Option<&Tensor>,
304        train: bool,
305    ) -> (Tensor, Option<Tensor>, Option<Tensor>) {
306        let input_size = hidden_states.size();
307        let (batch_size, seq_length) = (input_size[0], input_size[1]);
308
309        let shape = |states: &Tensor| -> Tensor {
310            states.view([batch_size, -1, self.n_heads, self.key_value_proj_dim])
311        };
312        let unshape = |states: &Tensor| -> Tensor {
313            states.contiguous().view([batch_size, -1, self.inner_dim])
314        };
315
316        let query_states = shape(&hidden_states.apply(&self.query));
317        let key_states = shape(&hidden_states.apply(&self.key));
318        let value_states = shape(&hidden_states.apply(&self.value));
319
320        let query_states = split_into_blocks(&query_states, self.block_length, 1);
321        let key_states = split_into_blocks(&key_states, self.block_length, 1);
322        let value_states = split_into_blocks(&value_states, self.block_length, 1);
323
324        let key_states = concatenate_3_blocks(&key_states, 1, 2, None);
325        let value_states = concatenate_3_blocks(&value_states, 1, 2, None);
326
327        let mut scores = Tensor::einsum(
328            "...qhd,...khd->...hqk",
329            &[query_states, key_states],
330            None::<i64>,
331        );
332        let calc_position_bias = if position_bias.is_none() {
333            let mut position_bias = if !self.has_relative_attention_bias {
334                Tensor::zeros(
335                    [1, 1, self.n_heads, self.block_length, 3 * self.block_length],
336                    (scores.kind(), scores.device()),
337                )
338            } else {
339                compute_bias(
340                    self.block_length,
341                    self.relative_attention_bias.as_ref().unwrap(),
342                    self.is_decoder,
343                    self.relative_attention_num_buckets,
344                    self.relative_attention_max_distance,
345                )
346            };
347            if let Some(mask) = mask {
348                let mask = mask.zeros_like().where_scalarother(&mask.gt(0), -1e10);
349                position_bias = position_bias + mask.transpose(1, 2);
350            }
351            Some(position_bias)
352        } else {
353            None
354        };
355        let position_bias = position_bias.unwrap_or_else(|| calc_position_bias.as_ref().unwrap());
356        scores += position_bias;
357        let attention_weights = scores
358            .to_kind(Kind::Float)
359            .softmax(-1, scores.kind())
360            .apply_t(&self.dropout, train)
361            .to_kind(value_states.kind());
362        let attention_output = unshape(&Tensor::einsum(
363            "...hqk,...khd->...qhd",
364            &[&attention_weights, &value_states],
365            None::<i64>,
366        ))
367        .narrow(1, 0, seq_length)
368        .apply(&self.output);
369
370        let attention_weights = if self.output_attentions {
371            Some(attention_weights)
372        } else {
373            None
374        };
375
376        let position_bias = if self.has_relative_attention_bias {
377            calc_position_bias
378        } else {
379            None
380        };
381        (attention_output, position_bias, attention_weights)
382    }
383}
384
385pub struct LongT5TransientGlobalAttention {
386    is_decoder: bool,
387    has_relative_attention_bias: bool,
388    relative_attention_num_buckets: i64,
389    relative_attention_max_distance: i64,
390    key_value_proj_dim: i64,
391    n_heads: i64,
392    block_length: i64,
393    global_block_size: i64,
394    dropout: Dropout,
395    inner_dim: i64,
396    output_attentions: bool,
397    query: nn::Linear,
398    key: nn::Linear,
399    value: nn::Linear,
400    output: nn::Linear,
401    relative_attention_bias: Option<nn::Embedding>,
402    global_relative_attention_bias: Option<nn::Embedding>,
403    global_input_layer_norm: LongT5LayerNorm,
404}
405
406impl LongT5TransientGlobalAttention {
407    pub fn new<'p, P>(
408        p: P,
409        config: &LongT5Config,
410        is_decoder: bool,
411        has_relative_attention_bias: bool,
412    ) -> LongT5TransientGlobalAttention
413    where
414        P: Borrow<nn::Path<'p>>,
415    {
416        let p = p.borrow();
417
418        let linear_config = LinearConfig {
419            bias: false,
420            ..Default::default()
421        };
422
423        let block_length = config.local_radius + 1;
424        let global_block_size = config.global_block_size;
425        let key_value_proj_dim = config.d_kv;
426
427        let inner_dim = config.num_heads * config.d_kv;
428        let key = nn::linear(p / "k", config.d_model, inner_dim, linear_config);
429        let value = nn::linear(p / "v", config.d_model, inner_dim, linear_config);
430        let query = nn::linear(p / "q", config.d_model, inner_dim, linear_config);
431        let output = nn::linear(p / "o", inner_dim, config.d_model, linear_config);
432
433        let dropout = Dropout::new(config.dropout_rate);
434        let global_relative_attention_bias = if has_relative_attention_bias {
435            Some(nn::embedding(
436                p / "global_relative_attention_bias",
437                config.relative_attention_num_buckets,
438                config.num_heads,
439                Default::default(),
440            ))
441        } else {
442            None
443        };
444        let relative_attention_bias = if has_relative_attention_bias {
445            Some(nn::embedding(
446                p / "relative_attention_bias",
447                config.relative_attention_num_buckets,
448                config.num_heads,
449                Default::default(),
450            ))
451        } else {
452            None
453        };
454        let global_input_layer_norm = LongT5LayerNorm::new(
455            p / "global_input_layer_norm",
456            config.d_model,
457            config.layer_norm_epsilon,
458        );
459
460        LongT5TransientGlobalAttention {
461            is_decoder,
462            has_relative_attention_bias,
463            relative_attention_num_buckets: config.relative_attention_num_buckets,
464            relative_attention_max_distance: config.relative_attention_max_distance.unwrap_or(128),
465            key_value_proj_dim,
466            n_heads: config.num_heads,
467            block_length,
468            global_block_size,
469            dropout,
470            inner_dim,
471            output_attentions: config.output_attentions.unwrap_or(false),
472            query,
473            key,
474            value,
475            output,
476            relative_attention_bias,
477            global_relative_attention_bias,
478            global_input_layer_norm,
479        }
480    }
481
482    fn compute_side_bias(&self, mask: &Tensor, global_segment_ids: &Tensor) -> Tensor {
483        let side_attention_mask = mask
484            .unsqueeze(-1)
485            .eq_tensor(&global_segment_ids.unsqueeze(1))
486            .unsqueeze(1);
487
488        let attention_side_bias = side_attention_mask
489            .zeros_like()
490            .where_scalarother(&side_attention_mask.gt(0), -1e10);
491
492        let side_relative_position = make_side_relative_position_ids(mask, self.global_block_size);
493        let side_relative_position_bucket = get_relative_position_bucket(
494            &side_relative_position,
495            !self.is_decoder,
496            self.relative_attention_num_buckets,
497            self.relative_attention_max_distance,
498        );
499        let side_bias = side_relative_position_bucket
500            .apply(self.global_relative_attention_bias.as_ref().unwrap())
501            .permute([0, 3, 1, 2]);
502        attention_side_bias + side_bias
503    }
504
505    pub fn forward_t(
506        &self,
507        hidden_states: &Tensor,
508        mask: Option<&Tensor>,
509        position_bias: Option<&Tensor>,
510        train: bool,
511    ) -> (Tensor, Option<Tensor>, Option<Tensor>) {
512        let input_size = hidden_states.size();
513        let (batch_size, seq_length) = (input_size[0], input_size[1]);
514
515        let shape = |states: &Tensor| -> Tensor {
516            states.view([batch_size, -1, self.n_heads, self.key_value_proj_dim])
517        };
518        let unshape = |states: &Tensor| -> Tensor {
519            states.contiguous().view([batch_size, -1, self.inner_dim])
520        };
521        let calc_mask = if mask.is_none() {
522            let mut mask_size = input_size;
523            let _ = mask_size.pop();
524            Some(Tensor::ones(
525                mask_size.as_slice(),
526                (Kind::Bool, hidden_states.device()),
527            ))
528        } else {
529            None
530        };
531        let (block_ids, global_segment_ids) = make_global_fixed_block_ids(
532            mask.unwrap_or_else(|| calc_mask.as_ref().unwrap()),
533            self.global_block_size,
534        );
535        let global_seq_length = *global_segment_ids.size().last().unwrap();
536        let global_inputs = create_global_aggregates(hidden_states, &block_ids, global_seq_length)
537            .apply(&self.global_input_layer_norm);
538
539        let query_states = shape(&hidden_states.apply(&self.query));
540        let key_states = shape(&hidden_states.apply(&self.key));
541        let value_states = shape(&hidden_states.apply(&self.value));
542
543        let side_key_states = shape(&global_inputs.apply(&self.key));
544        let side_value_states = shape(&global_inputs.apply(&self.value));
545
546        let query_states = split_into_blocks(&query_states, self.block_length, 1);
547        let key_states = split_into_blocks(&key_states, self.block_length, 1);
548        let value_states = split_into_blocks(&value_states, self.block_length, 1);
549
550        let key_states = concatenate_3_blocks(&key_states, 1, 2, None);
551        let value_states = concatenate_3_blocks(&value_states, 1, 2, None);
552
553        let mut reps = vec![1; side_key_states.dim() + 1];
554        reps[1] = key_states.size()[1];
555        let side_key_states = side_key_states.unsqueeze(1).repeat(reps.as_slice());
556        let side_value_states = side_value_states.unsqueeze(1).repeat(reps.as_slice());
557        let key_states = Tensor::cat(&[key_states, side_key_states], 2);
558        let value_states = Tensor::cat(&[value_states, side_value_states], 2);
559
560        let mut scores = Tensor::einsum(
561            "...qhd,...khd->...hqk",
562            &[query_states, key_states],
563            None::<i64>,
564        );
565        let local_attention_mask = mask.map(|mask| {
566            let local_attention_mask = get_local_attention_mask(mask, self.block_length);
567            local_attention_mask
568                .zeros_like()
569                .where_scalarother(&local_attention_mask.gt(0), -1e10)
570        });
571
572        let calc_position_bias = if position_bias.is_none() {
573            let mut position_bias = if !self.has_relative_attention_bias {
574                Tensor::zeros(
575                    [1, 1, self.n_heads, self.block_length, 3 * self.block_length],
576                    (scores.kind(), scores.device()),
577                )
578            } else {
579                compute_bias(
580                    self.block_length,
581                    self.relative_attention_bias.as_ref().unwrap(),
582                    self.is_decoder,
583                    self.relative_attention_num_buckets,
584                    self.relative_attention_max_distance,
585                )
586            };
587            if let Some(local_attention_mask) = local_attention_mask {
588                position_bias = position_bias + local_attention_mask.transpose(1, 2);
589            }
590            let calc_mask = if mask.is_none() {
591                Some(Tensor::ones(
592                    [batch_size, seq_length],
593                    (global_segment_ids.kind(), global_segment_ids.device()),
594                ))
595            } else {
596                None
597            };
598            let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
599            let side_position_bias = self.compute_side_bias(mask, &global_segment_ids);
600            let side_position_bias = split_into_blocks(
601                &side_position_bias,
602                self.block_length,
603                side_position_bias.dim() - 2,
604            )
605            .transpose(1, 2);
606            let position_bias = Tensor::cat(&[position_bias, side_position_bias], -1);
607
608            Some(position_bias)
609        } else {
610            None
611        };
612        let position_bias = position_bias.unwrap_or_else(|| calc_position_bias.as_ref().unwrap());
613
614        scores += position_bias;
615        let attention_weights = scores
616            .to_kind(Kind::Float)
617            .softmax(-1, scores.kind())
618            .apply_t(&self.dropout, train);
619
620        let attention_output = unshape(&Tensor::einsum(
621            "...hqk,...khd->...qhd",
622            &[&attention_weights, &value_states],
623            None::<i64>,
624        ))
625        .narrow(1, 0, seq_length)
626        .apply(&self.output);
627
628        let attention_weights = if self.output_attentions {
629            Some(attention_weights)
630        } else {
631            None
632        };
633
634        let position_bias = if self.has_relative_attention_bias {
635            calc_position_bias
636        } else {
637            None
638        };
639        (attention_output, position_bias, attention_weights)
640    }
641}
642
643pub struct LongT5LayerSelfAttention {
644    self_attention: LongT5Attention,
645    layer_norm: LongT5LayerNorm,
646    dropout: Dropout,
647}
648
649impl LongT5LayerSelfAttention {
650    pub fn new<'p, P>(
651        p: P,
652        config: &LongT5Config,
653        has_relative_attention_bias: bool,
654        is_decoder: bool,
655        store_cache: bool,
656        output_attentions: bool,
657    ) -> LongT5LayerSelfAttention
658    where
659        P: Borrow<nn::Path<'p>>,
660    {
661        let p = p.borrow();
662
663        let self_attention = LongT5Attention::new(
664            p / "SelfAttention",
665            &config.into(),
666            is_decoder,
667            !is_decoder,
668            store_cache,
669            output_attentions,
670            has_relative_attention_bias,
671        );
672
673        let layer_norm =
674            LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon);
675        let dropout = Dropout::new(config.dropout_rate);
676
677        LongT5LayerSelfAttention {
678            self_attention,
679            layer_norm,
680            dropout,
681        }
682    }
683
684    pub fn forward_t(
685        &self,
686        hidden_states: &Tensor,
687        position_bias: Option<&Tensor>,
688        attention_mask: Option<&Tensor>,
689        layer_state: Option<LayerState>,
690        train: bool,
691    ) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<LayerState>) {
692        let norm_x = hidden_states.apply(&self.layer_norm);
693
694        let (y, attention_weights, position_bias, layer_state) = self.self_attention.forward_t(
695            &norm_x,
696            None,
697            position_bias,
698            attention_mask,
699            layer_state,
700            None,
701            train,
702        );
703
704        let output = hidden_states + y.apply_t(&self.dropout, train);
705
706        (output, attention_weights, position_bias, layer_state)
707    }
708}
709
710pub struct LongT5LayerLocalSelfAttention {
711    local_self_attention: LongT5LocalAttention,
712    layer_norm: LongT5LayerNorm,
713    dropout: Dropout,
714}
715
716impl LongT5LayerLocalSelfAttention {
717    pub fn new<'p, P>(
718        p: P,
719        config: &LongT5Config,
720        has_relative_attention_bias: bool,
721        is_decoder: bool,
722    ) -> LongT5LayerLocalSelfAttention
723    where
724        P: Borrow<nn::Path<'p>>,
725    {
726        let p = p.borrow();
727
728        let local_self_attention = LongT5LocalAttention::new(
729            p / "LocalSelfAttention",
730            config,
731            is_decoder,
732            has_relative_attention_bias,
733        );
734
735        let layer_norm =
736            LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon);
737        let dropout = Dropout::new(config.dropout_rate);
738
739        LongT5LayerLocalSelfAttention {
740            local_self_attention,
741            layer_norm,
742            dropout,
743        }
744    }
745
746    pub fn forward_t(
747        &self,
748        hidden_states: &Tensor,
749        attention_mask: Option<&Tensor>,
750        position_bias: Option<&Tensor>,
751        train: bool,
752    ) -> (Tensor, Option<Tensor>, Option<Tensor>) {
753        let normed_hidden_states = hidden_states.apply(&self.layer_norm);
754
755        let (attention_output, position_bias, attention_weights) = self
756            .local_self_attention
757            .forward_t(&normed_hidden_states, attention_mask, position_bias, train);
758
759        let output = hidden_states + attention_output.apply_t(&self.dropout, train);
760
761        (output, position_bias, attention_weights)
762    }
763}
764
765pub struct LongT5LayerTransientGlobalSelfAttention {
766    transient_global_sef_attention: LongT5TransientGlobalAttention,
767    layer_norm: LongT5LayerNorm,
768    dropout: Dropout,
769}
770
771impl LongT5LayerTransientGlobalSelfAttention {
772    pub fn new<'p, P>(
773        p: P,
774        config: &LongT5Config,
775        has_relative_attention_bias: bool,
776        is_decoder: bool,
777    ) -> LongT5LayerTransientGlobalSelfAttention
778    where
779        P: Borrow<nn::Path<'p>>,
780    {
781        let p = p.borrow();
782
783        let transient_global_sef_attention = LongT5TransientGlobalAttention::new(
784            p / "TransientGlobalSelfAttention",
785            config,
786            is_decoder,
787            has_relative_attention_bias,
788        );
789
790        let layer_norm =
791            LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon);
792        let dropout = Dropout::new(config.dropout_rate);
793
794        LongT5LayerTransientGlobalSelfAttention {
795            transient_global_sef_attention,
796            layer_norm,
797            dropout,
798        }
799    }
800
801    pub fn forward_t(
802        &self,
803        hidden_states: &Tensor,
804        attention_mask: Option<&Tensor>,
805        position_bias: Option<&Tensor>,
806        train: bool,
807    ) -> (Tensor, Option<Tensor>, Option<Tensor>) {
808        let normed_hidden_states = hidden_states.apply(&self.layer_norm);
809        let (attention_output, position_bias, attention_weights) = self
810            .transient_global_sef_attention
811            .forward_t(&normed_hidden_states, attention_mask, position_bias, train);
812
813        let output = hidden_states + attention_output.apply_t(&self.dropout, train);
814
815        (output, position_bias, attention_weights)
816    }
817}