1use crate::common::dropout::Dropout;
14use crate::longt5::layer_norm::LongT5LayerNorm;
15use crate::longt5::LongT5Config;
16use crate::t5::{
17 get_relative_position_bucket, LayerState as T5layerState, T5Attention, T5LayerCrossAttention,
18};
19use std::borrow::Borrow;
20use tch::nn::LinearConfig;
21use tch::{nn, Device, IndexOp, Kind, Tensor};
22
/// Full self/cross attention for LongT5: an alias of the T5 attention implementation.
pub type LongT5Attention = T5Attention;
/// Cross-attention layer for LongT5: an alias of the T5 cross-attention layer.
pub type LongT5LayerCrossAttention = T5LayerCrossAttention;
/// Per-layer state re-used from T5 (presumably the cached key/value states — confirm in `t5`).
pub type LayerState = T5layerState;
26
27fn pad_to_multiple(x: &Tensor, block_length: i64, dim: usize, pad_value: f64) -> Tensor {
28 let mut x_size = x.size();
29 let pad_length = (-x_size[dim]).rem_euclid(block_length);
30
31 if x_size.iter().any(|&el| el == 0) {
32 x_size[dim] += pad_length;
33 Tensor::zeros(x_size.as_slice(), (x.kind(), x.device()))
34 } else {
35 let mut pad = vec![0i64; 2 * x.dim()];
36 pad[2 * dim] = pad_length;
37 pad.reverse();
38 x.pad(pad.as_slice(), "constant", pad_value)
39 }
40}
41
42fn split_into_blocks(x: &Tensor, block_length: i64, dim: usize) -> Tensor {
43 let x_size = x.size();
44 let padded_x = if x_size[dim] % block_length != 0 {
45 Some(pad_to_multiple(x, block_length, dim, 0f64))
46 } else {
47 None
48 };
49 let x = padded_x.as_ref().unwrap_or(x);
50 let mut x_size = x.size();
51 let num_blocks = x_size[dim] / block_length;
52 x_size.remove(dim);
53 x_size.insert(dim, block_length);
54 x_size.insert(dim, num_blocks);
55 if x_size.iter().any(|&el| el == 0) {
56 Tensor::empty(x_size.as_slice(), (x.kind(), x.device()))
57 } else {
58 x.reshape(x_size.as_slice())
59 }
60}
61
62fn concatenate_3_blocks(
63 x: &Tensor,
64 block_dim: usize,
65 sequence_dim: i64,
66 pad_value: Option<f64>,
67) -> Tensor {
68 let x_size = x.size();
69 let num_blocks = x_size[block_dim];
70 let mut pad = vec![0i64; 2 * x.dim()];
71 pad[2 * block_dim] = 1;
72 pad[2 * block_dim + 1] = 1;
73 pad.reverse();
74 let x = x.pad(pad.as_slice(), "constant", pad_value.unwrap_or(0f64));
75 let mut block_list: Vec<Tensor> = Vec::with_capacity(3);
76 for i in 0..3 {
77 block_list.push(x.narrow(block_dim as i64, i, num_blocks));
78 }
79 Tensor::cat(block_list.as_slice(), sequence_dim)
80}
81
82fn make_3blocks_relative_position_ids(block_length: i64, device: Device) -> Tensor {
83 let position_ids = Tensor::arange(3 * block_length, (Kind::Int, device));
84 let center_position_ids = position_ids.i(block_length..2 * block_length);
85 position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1)
86}
87
88fn mask_local_attention_mask(local_attention_mask: &Tensor, block_length: i64) -> Tensor {
89 let relative_position_ids =
90 make_3blocks_relative_position_ids(block_length, local_attention_mask.device());
91 let locality_mask = relative_position_ids
92 .abs()
93 .lt(block_length)
94 .unsqueeze(0)
95 .unsqueeze(0);
96 local_attention_mask.logical_and(&locality_mask)
97}
98
99pub(crate) fn get_local_attention_mask(attention_mask: &Tensor, block_length: i64) -> Tensor {
100 let blocked_attention_mask = split_into_blocks(attention_mask, block_length, 1);
101 let three_blocked_attention_mask = concatenate_3_blocks(&blocked_attention_mask, 1, 2, None);
102
103 let blocked_attention_mask = blocked_attention_mask.unsqueeze(-1);
104 let three_blocked_attention_mask = three_blocked_attention_mask.unsqueeze(-2);
105
106 let local_attention_mask = mask_local_attention_mask(
107 &blocked_attention_mask.logical_and(&three_blocked_attention_mask),
108 block_length,
109 );
110 local_attention_mask.unsqueeze(1)
111}
112
113fn make_global_fixed_block_ids(
114 attention_mask: &Tensor,
115 global_block_size: i64,
116) -> (Tensor, Tensor) {
117 let &[batch_size, seq_length, ..] = attention_mask.size().as_slice() else {
118 unreachable!()
119 };
120
121 let handle_orphan_tokens = |block_ids: Tensor| -> Tensor {
122 let block_ends = Tensor::arange(seq_length, (Kind::Int64, block_ids.device()))
123 .remainder(global_block_size)
124 .eq(global_block_size - 1);
125 let true_block_ends = block_ends.logical_and(&block_ids.ge(0));
126 let full_blocks = true_block_ends
127 .sum_dim_intlist([-1].as_slice(), false, block_ids.kind())
128 .unsqueeze(-1)
129 - 1;
130 block_ids.where_self(&block_ids.lt_tensor(&full_blocks), &full_blocks)
131 };
132
133 let fixed_block_mask = attention_mask.ones_like() / global_block_size;
134 let fixed_block_mask = fixed_block_mask.cumsum(1, fixed_block_mask.kind()) - fixed_block_mask;
135 let mask = attention_mask
136 .ones_like()
137 .where_scalarother(&attention_mask.not_equal(0.0), -1000.0);
138
139 let mut global_block_ids = (mask + fixed_block_mask - 1.0).floor();
140 global_block_ids = global_block_ids.where_scalarother(&global_block_ids.gt(-1.0), -1.0);
141 global_block_ids = global_block_ids * attention_mask + attention_mask - 1;
142 global_block_ids = handle_orphan_tokens(global_block_ids);
143 let num_globals = seq_length / global_block_size;
144 let sequence_block_ids_max = if num_globals > 0 {
145 global_block_ids
146 .max_dim(-1, false)
147 .0
148 .repeat([num_globals, 1])
149 .transpose(0, 1)
150 } else {
151 Tensor::zeros(
152 [batch_size, 0],
153 (global_block_ids.kind(), global_block_ids.device()),
154 )
155 };
156 let global_segment_ids = Tensor::ones(
157 [batch_size, num_globals],
158 (attention_mask.kind(), attention_mask.device()),
159 )
160 .cumsum(-1, attention_mask.kind())
161 - 1;
162 let global_segment_ids = global_segment_ids
163 .ones_like()
164 .where_scalarother(&global_segment_ids.le_tensor(&sequence_block_ids_max), 0.0);
165 (
166 global_block_ids.to_kind(Kind::Int),
167 global_segment_ids.to_kind(Kind::Int),
168 )
169}
170
171fn make_side_relative_position_ids(attention_mask: &Tensor, global_block_size: i64) -> Tensor {
172 let (block_ids, global_segment_ids) =
173 make_global_fixed_block_ids(attention_mask, global_block_size);
174 let global_seq_length = *global_segment_ids.size().last().unwrap();
175 let global_positions = Tensor::arange(global_seq_length, (Kind::Int64, block_ids.device()));
176 global_positions - block_ids.unsqueeze(-1)
177}
178
179fn create_global_aggregates(
180 hidden_states: &Tensor,
181 block_ids: &Tensor,
182 global_seq_length: i64,
183) -> Tensor {
184 let block_ids = block_ids.where_scalarother(&block_ids.ge(0), global_seq_length);
185 let one_hot_block_ids = block_ids
186 .to_kind(Kind::Int64)
187 .one_hot(global_seq_length + 1);
188 let one_hot_block_ids = one_hot_block_ids.narrow(2, 0, one_hot_block_ids.size()[2] - 1);
189 Tensor::einsum(
190 "...nd,...ng->...gd",
191 &[
192 hidden_states,
193 &one_hot_block_ids.to_kind(hidden_states.kind()),
194 ],
195 None::<i64>,
196 )
197}
198
199fn compute_bias(
200 block_length: i64,
201 relative_attention_bias: &nn::Embedding,
202 is_decoder: bool,
203 relative_attention_num_buckets: i64,
204 relative_attention_max_distance: i64,
205) -> Tensor {
206 let device = relative_attention_bias.ws.device();
207 let memory_position = Tensor::arange(3 * block_length, (Kind::Int64, device));
208 let context_position = memory_position.narrow(0, block_length, block_length);
209 let relative_position = memory_position.unsqueeze(0) - context_position.unsqueeze(-1);
210
211 let rp_bucket = get_relative_position_bucket(
212 &relative_position,
213 !is_decoder,
214 relative_attention_num_buckets,
215 relative_attention_max_distance,
216 );
217 rp_bucket
218 .apply(relative_attention_bias)
219 .permute([2, 0, 1])
220 .unsqueeze(0)
221 .unsqueeze(0)
222}
223
224pub struct LongT5LocalAttention {
225 is_decoder: bool,
226 has_relative_attention_bias: bool,
227 relative_attention_num_buckets: i64,
228 relative_attention_max_distance: i64,
229 key_value_proj_dim: i64,
230 n_heads: i64,
231 block_length: i64,
232 dropout: Dropout,
233 inner_dim: i64,
234 output_attentions: bool,
235 query: nn::Linear,
236 key: nn::Linear,
237 value: nn::Linear,
238 output: nn::Linear,
239 relative_attention_bias: Option<nn::Embedding>,
240}
241
242impl LongT5LocalAttention {
243 pub fn new<'p, P>(
244 p: P,
245 config: &LongT5Config,
246 is_decoder: bool,
247 has_relative_attention_bias: bool,
248 ) -> LongT5LocalAttention
249 where
250 P: Borrow<nn::Path<'p>>,
251 {
252 let p = p.borrow();
253
254 let linear_config = LinearConfig {
255 bias: false,
256 ..Default::default()
257 };
258
259 let block_length = config.local_radius + 1;
260 let key_value_proj_dim = config.d_kv;
261
262 let inner_dim = config.num_heads * config.d_kv;
263 let key = nn::linear(p / "k", config.d_model, inner_dim, linear_config);
264 let value = nn::linear(p / "v", config.d_model, inner_dim, linear_config);
265 let query = nn::linear(p / "q", config.d_model, inner_dim, linear_config);
266 let output = nn::linear(p / "o", inner_dim, config.d_model, linear_config);
267
268 let dropout = Dropout::new(config.dropout_rate);
269 let relative_attention_bias = if has_relative_attention_bias {
270 Some(nn::embedding(
271 p / "relative_attention_bias",
272 config.relative_attention_num_buckets,
273 config.num_heads,
274 Default::default(),
275 ))
276 } else {
277 None
278 };
279
280 LongT5LocalAttention {
281 is_decoder,
282 has_relative_attention_bias,
283 relative_attention_num_buckets: config.relative_attention_num_buckets,
284 relative_attention_max_distance: config.relative_attention_max_distance.unwrap_or(128),
285 key_value_proj_dim,
286 n_heads: config.num_heads,
287 block_length,
288 dropout,
289 inner_dim,
290 output_attentions: config.output_attentions.unwrap_or(false),
291 query,
292 key,
293 value,
294 output,
295 relative_attention_bias,
296 }
297 }
298
299 pub fn forward_t(
300 &self,
301 hidden_states: &Tensor,
302 mask: Option<&Tensor>,
303 position_bias: Option<&Tensor>,
304 train: bool,
305 ) -> (Tensor, Option<Tensor>, Option<Tensor>) {
306 let input_size = hidden_states.size();
307 let (batch_size, seq_length) = (input_size[0], input_size[1]);
308
309 let shape = |states: &Tensor| -> Tensor {
310 states.view([batch_size, -1, self.n_heads, self.key_value_proj_dim])
311 };
312 let unshape = |states: &Tensor| -> Tensor {
313 states.contiguous().view([batch_size, -1, self.inner_dim])
314 };
315
316 let query_states = shape(&hidden_states.apply(&self.query));
317 let key_states = shape(&hidden_states.apply(&self.key));
318 let value_states = shape(&hidden_states.apply(&self.value));
319
320 let query_states = split_into_blocks(&query_states, self.block_length, 1);
321 let key_states = split_into_blocks(&key_states, self.block_length, 1);
322 let value_states = split_into_blocks(&value_states, self.block_length, 1);
323
324 let key_states = concatenate_3_blocks(&key_states, 1, 2, None);
325 let value_states = concatenate_3_blocks(&value_states, 1, 2, None);
326
327 let mut scores = Tensor::einsum(
328 "...qhd,...khd->...hqk",
329 &[query_states, key_states],
330 None::<i64>,
331 );
332 let calc_position_bias = if position_bias.is_none() {
333 let mut position_bias = if !self.has_relative_attention_bias {
334 Tensor::zeros(
335 [1, 1, self.n_heads, self.block_length, 3 * self.block_length],
336 (scores.kind(), scores.device()),
337 )
338 } else {
339 compute_bias(
340 self.block_length,
341 self.relative_attention_bias.as_ref().unwrap(),
342 self.is_decoder,
343 self.relative_attention_num_buckets,
344 self.relative_attention_max_distance,
345 )
346 };
347 if let Some(mask) = mask {
348 let mask = mask.zeros_like().where_scalarother(&mask.gt(0), -1e10);
349 position_bias = position_bias + mask.transpose(1, 2);
350 }
351 Some(position_bias)
352 } else {
353 None
354 };
355 let position_bias = position_bias.unwrap_or_else(|| calc_position_bias.as_ref().unwrap());
356 scores += position_bias;
357 let attention_weights = scores
358 .to_kind(Kind::Float)
359 .softmax(-1, scores.kind())
360 .apply_t(&self.dropout, train)
361 .to_kind(value_states.kind());
362 let attention_output = unshape(&Tensor::einsum(
363 "...hqk,...khd->...qhd",
364 &[&attention_weights, &value_states],
365 None::<i64>,
366 ))
367 .narrow(1, 0, seq_length)
368 .apply(&self.output);
369
370 let attention_weights = if self.output_attentions {
371 Some(attention_weights)
372 } else {
373 None
374 };
375
376 let position_bias = if self.has_relative_attention_bias {
377 calc_position_bias
378 } else {
379 None
380 };
381 (attention_output, position_bias, attention_weights)
382 }
383}
384
385pub struct LongT5TransientGlobalAttention {
386 is_decoder: bool,
387 has_relative_attention_bias: bool,
388 relative_attention_num_buckets: i64,
389 relative_attention_max_distance: i64,
390 key_value_proj_dim: i64,
391 n_heads: i64,
392 block_length: i64,
393 global_block_size: i64,
394 dropout: Dropout,
395 inner_dim: i64,
396 output_attentions: bool,
397 query: nn::Linear,
398 key: nn::Linear,
399 value: nn::Linear,
400 output: nn::Linear,
401 relative_attention_bias: Option<nn::Embedding>,
402 global_relative_attention_bias: Option<nn::Embedding>,
403 global_input_layer_norm: LongT5LayerNorm,
404}
405
406impl LongT5TransientGlobalAttention {
407 pub fn new<'p, P>(
408 p: P,
409 config: &LongT5Config,
410 is_decoder: bool,
411 has_relative_attention_bias: bool,
412 ) -> LongT5TransientGlobalAttention
413 where
414 P: Borrow<nn::Path<'p>>,
415 {
416 let p = p.borrow();
417
418 let linear_config = LinearConfig {
419 bias: false,
420 ..Default::default()
421 };
422
423 let block_length = config.local_radius + 1;
424 let global_block_size = config.global_block_size;
425 let key_value_proj_dim = config.d_kv;
426
427 let inner_dim = config.num_heads * config.d_kv;
428 let key = nn::linear(p / "k", config.d_model, inner_dim, linear_config);
429 let value = nn::linear(p / "v", config.d_model, inner_dim, linear_config);
430 let query = nn::linear(p / "q", config.d_model, inner_dim, linear_config);
431 let output = nn::linear(p / "o", inner_dim, config.d_model, linear_config);
432
433 let dropout = Dropout::new(config.dropout_rate);
434 let global_relative_attention_bias = if has_relative_attention_bias {
435 Some(nn::embedding(
436 p / "global_relative_attention_bias",
437 config.relative_attention_num_buckets,
438 config.num_heads,
439 Default::default(),
440 ))
441 } else {
442 None
443 };
444 let relative_attention_bias = if has_relative_attention_bias {
445 Some(nn::embedding(
446 p / "relative_attention_bias",
447 config.relative_attention_num_buckets,
448 config.num_heads,
449 Default::default(),
450 ))
451 } else {
452 None
453 };
454 let global_input_layer_norm = LongT5LayerNorm::new(
455 p / "global_input_layer_norm",
456 config.d_model,
457 config.layer_norm_epsilon,
458 );
459
460 LongT5TransientGlobalAttention {
461 is_decoder,
462 has_relative_attention_bias,
463 relative_attention_num_buckets: config.relative_attention_num_buckets,
464 relative_attention_max_distance: config.relative_attention_max_distance.unwrap_or(128),
465 key_value_proj_dim,
466 n_heads: config.num_heads,
467 block_length,
468 global_block_size,
469 dropout,
470 inner_dim,
471 output_attentions: config.output_attentions.unwrap_or(false),
472 query,
473 key,
474 value,
475 output,
476 relative_attention_bias,
477 global_relative_attention_bias,
478 global_input_layer_norm,
479 }
480 }
481
482 fn compute_side_bias(&self, mask: &Tensor, global_segment_ids: &Tensor) -> Tensor {
483 let side_attention_mask = mask
484 .unsqueeze(-1)
485 .eq_tensor(&global_segment_ids.unsqueeze(1))
486 .unsqueeze(1);
487
488 let attention_side_bias = side_attention_mask
489 .zeros_like()
490 .where_scalarother(&side_attention_mask.gt(0), -1e10);
491
492 let side_relative_position = make_side_relative_position_ids(mask, self.global_block_size);
493 let side_relative_position_bucket = get_relative_position_bucket(
494 &side_relative_position,
495 !self.is_decoder,
496 self.relative_attention_num_buckets,
497 self.relative_attention_max_distance,
498 );
499 let side_bias = side_relative_position_bucket
500 .apply(self.global_relative_attention_bias.as_ref().unwrap())
501 .permute([0, 3, 1, 2]);
502 attention_side_bias + side_bias
503 }
504
505 pub fn forward_t(
506 &self,
507 hidden_states: &Tensor,
508 mask: Option<&Tensor>,
509 position_bias: Option<&Tensor>,
510 train: bool,
511 ) -> (Tensor, Option<Tensor>, Option<Tensor>) {
512 let input_size = hidden_states.size();
513 let (batch_size, seq_length) = (input_size[0], input_size[1]);
514
515 let shape = |states: &Tensor| -> Tensor {
516 states.view([batch_size, -1, self.n_heads, self.key_value_proj_dim])
517 };
518 let unshape = |states: &Tensor| -> Tensor {
519 states.contiguous().view([batch_size, -1, self.inner_dim])
520 };
521 let calc_mask = if mask.is_none() {
522 let mut mask_size = input_size;
523 let _ = mask_size.pop();
524 Some(Tensor::ones(
525 mask_size.as_slice(),
526 (Kind::Bool, hidden_states.device()),
527 ))
528 } else {
529 None
530 };
531 let (block_ids, global_segment_ids) = make_global_fixed_block_ids(
532 mask.unwrap_or_else(|| calc_mask.as_ref().unwrap()),
533 self.global_block_size,
534 );
535 let global_seq_length = *global_segment_ids.size().last().unwrap();
536 let global_inputs = create_global_aggregates(hidden_states, &block_ids, global_seq_length)
537 .apply(&self.global_input_layer_norm);
538
539 let query_states = shape(&hidden_states.apply(&self.query));
540 let key_states = shape(&hidden_states.apply(&self.key));
541 let value_states = shape(&hidden_states.apply(&self.value));
542
543 let side_key_states = shape(&global_inputs.apply(&self.key));
544 let side_value_states = shape(&global_inputs.apply(&self.value));
545
546 let query_states = split_into_blocks(&query_states, self.block_length, 1);
547 let key_states = split_into_blocks(&key_states, self.block_length, 1);
548 let value_states = split_into_blocks(&value_states, self.block_length, 1);
549
550 let key_states = concatenate_3_blocks(&key_states, 1, 2, None);
551 let value_states = concatenate_3_blocks(&value_states, 1, 2, None);
552
553 let mut reps = vec![1; side_key_states.dim() + 1];
554 reps[1] = key_states.size()[1];
555 let side_key_states = side_key_states.unsqueeze(1).repeat(reps.as_slice());
556 let side_value_states = side_value_states.unsqueeze(1).repeat(reps.as_slice());
557 let key_states = Tensor::cat(&[key_states, side_key_states], 2);
558 let value_states = Tensor::cat(&[value_states, side_value_states], 2);
559
560 let mut scores = Tensor::einsum(
561 "...qhd,...khd->...hqk",
562 &[query_states, key_states],
563 None::<i64>,
564 );
565 let local_attention_mask = mask.map(|mask| {
566 let local_attention_mask = get_local_attention_mask(mask, self.block_length);
567 local_attention_mask
568 .zeros_like()
569 .where_scalarother(&local_attention_mask.gt(0), -1e10)
570 });
571
572 let calc_position_bias = if position_bias.is_none() {
573 let mut position_bias = if !self.has_relative_attention_bias {
574 Tensor::zeros(
575 [1, 1, self.n_heads, self.block_length, 3 * self.block_length],
576 (scores.kind(), scores.device()),
577 )
578 } else {
579 compute_bias(
580 self.block_length,
581 self.relative_attention_bias.as_ref().unwrap(),
582 self.is_decoder,
583 self.relative_attention_num_buckets,
584 self.relative_attention_max_distance,
585 )
586 };
587 if let Some(local_attention_mask) = local_attention_mask {
588 position_bias = position_bias + local_attention_mask.transpose(1, 2);
589 }
590 let calc_mask = if mask.is_none() {
591 Some(Tensor::ones(
592 [batch_size, seq_length],
593 (global_segment_ids.kind(), global_segment_ids.device()),
594 ))
595 } else {
596 None
597 };
598 let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
599 let side_position_bias = self.compute_side_bias(mask, &global_segment_ids);
600 let side_position_bias = split_into_blocks(
601 &side_position_bias,
602 self.block_length,
603 side_position_bias.dim() - 2,
604 )
605 .transpose(1, 2);
606 let position_bias = Tensor::cat(&[position_bias, side_position_bias], -1);
607
608 Some(position_bias)
609 } else {
610 None
611 };
612 let position_bias = position_bias.unwrap_or_else(|| calc_position_bias.as_ref().unwrap());
613
614 scores += position_bias;
615 let attention_weights = scores
616 .to_kind(Kind::Float)
617 .softmax(-1, scores.kind())
618 .apply_t(&self.dropout, train);
619
620 let attention_output = unshape(&Tensor::einsum(
621 "...hqk,...khd->...qhd",
622 &[&attention_weights, &value_states],
623 None::<i64>,
624 ))
625 .narrow(1, 0, seq_length)
626 .apply(&self.output);
627
628 let attention_weights = if self.output_attentions {
629 Some(attention_weights)
630 } else {
631 None
632 };
633
634 let position_bias = if self.has_relative_attention_bias {
635 calc_position_bias
636 } else {
637 None
638 };
639 (attention_output, position_bias, attention_weights)
640 }
641}
642
643pub struct LongT5LayerSelfAttention {
644 self_attention: LongT5Attention,
645 layer_norm: LongT5LayerNorm,
646 dropout: Dropout,
647}
648
649impl LongT5LayerSelfAttention {
650 pub fn new<'p, P>(
651 p: P,
652 config: &LongT5Config,
653 has_relative_attention_bias: bool,
654 is_decoder: bool,
655 store_cache: bool,
656 output_attentions: bool,
657 ) -> LongT5LayerSelfAttention
658 where
659 P: Borrow<nn::Path<'p>>,
660 {
661 let p = p.borrow();
662
663 let self_attention = LongT5Attention::new(
664 p / "SelfAttention",
665 &config.into(),
666 is_decoder,
667 !is_decoder,
668 store_cache,
669 output_attentions,
670 has_relative_attention_bias,
671 );
672
673 let layer_norm =
674 LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon);
675 let dropout = Dropout::new(config.dropout_rate);
676
677 LongT5LayerSelfAttention {
678 self_attention,
679 layer_norm,
680 dropout,
681 }
682 }
683
684 pub fn forward_t(
685 &self,
686 hidden_states: &Tensor,
687 position_bias: Option<&Tensor>,
688 attention_mask: Option<&Tensor>,
689 layer_state: Option<LayerState>,
690 train: bool,
691 ) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<LayerState>) {
692 let norm_x = hidden_states.apply(&self.layer_norm);
693
694 let (y, attention_weights, position_bias, layer_state) = self.self_attention.forward_t(
695 &norm_x,
696 None,
697 position_bias,
698 attention_mask,
699 layer_state,
700 None,
701 train,
702 );
703
704 let output = hidden_states + y.apply_t(&self.dropout, train);
705
706 (output, attention_weights, position_bias, layer_state)
707 }
708}
709
710pub struct LongT5LayerLocalSelfAttention {
711 local_self_attention: LongT5LocalAttention,
712 layer_norm: LongT5LayerNorm,
713 dropout: Dropout,
714}
715
716impl LongT5LayerLocalSelfAttention {
717 pub fn new<'p, P>(
718 p: P,
719 config: &LongT5Config,
720 has_relative_attention_bias: bool,
721 is_decoder: bool,
722 ) -> LongT5LayerLocalSelfAttention
723 where
724 P: Borrow<nn::Path<'p>>,
725 {
726 let p = p.borrow();
727
728 let local_self_attention = LongT5LocalAttention::new(
729 p / "LocalSelfAttention",
730 config,
731 is_decoder,
732 has_relative_attention_bias,
733 );
734
735 let layer_norm =
736 LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon);
737 let dropout = Dropout::new(config.dropout_rate);
738
739 LongT5LayerLocalSelfAttention {
740 local_self_attention,
741 layer_norm,
742 dropout,
743 }
744 }
745
746 pub fn forward_t(
747 &self,
748 hidden_states: &Tensor,
749 attention_mask: Option<&Tensor>,
750 position_bias: Option<&Tensor>,
751 train: bool,
752 ) -> (Tensor, Option<Tensor>, Option<Tensor>) {
753 let normed_hidden_states = hidden_states.apply(&self.layer_norm);
754
755 let (attention_output, position_bias, attention_weights) = self
756 .local_self_attention
757 .forward_t(&normed_hidden_states, attention_mask, position_bias, train);
758
759 let output = hidden_states + attention_output.apply_t(&self.dropout, train);
760
761 (output, position_bias, attention_weights)
762 }
763}
764
765pub struct LongT5LayerTransientGlobalSelfAttention {
766 transient_global_sef_attention: LongT5TransientGlobalAttention,
767 layer_norm: LongT5LayerNorm,
768 dropout: Dropout,
769}
770
771impl LongT5LayerTransientGlobalSelfAttention {
772 pub fn new<'p, P>(
773 p: P,
774 config: &LongT5Config,
775 has_relative_attention_bias: bool,
776 is_decoder: bool,
777 ) -> LongT5LayerTransientGlobalSelfAttention
778 where
779 P: Borrow<nn::Path<'p>>,
780 {
781 let p = p.borrow();
782
783 let transient_global_sef_attention = LongT5TransientGlobalAttention::new(
784 p / "TransientGlobalSelfAttention",
785 config,
786 is_decoder,
787 has_relative_attention_bias,
788 );
789
790 let layer_norm =
791 LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon);
792 let dropout = Dropout::new(config.dropout_rate);
793
794 LongT5LayerTransientGlobalSelfAttention {
795 transient_global_sef_attention,
796 layer_norm,
797 dropout,
798 }
799 }
800
801 pub fn forward_t(
802 &self,
803 hidden_states: &Tensor,
804 attention_mask: Option<&Tensor>,
805 position_bias: Option<&Tensor>,
806 train: bool,
807 ) -> (Tensor, Option<Tensor>, Option<Tensor>) {
808 let normed_hidden_states = hidden_states.apply(&self.layer_norm);
809 let (attention_output, position_bias, attention_weights) = self
810 .transient_global_sef_attention
811 .forward_t(&normed_hidden_states, attention_mask, position_bias, train);
812
813 let output = hidden_states + attention_output.apply_t(&self.dropout, train);
814
815 (output, position_bias, attention_weights)
816 }
817}