1use scirs2_core::Array2; use std::collections::HashMap;
40use trustformers_core::errors::{tensor_op_error, Result};
41use trustformers_core::layers::AttentionInput;
42use trustformers_core::tensor::Tensor;
43use trustformers_core::traits::Layer;
44
/// Configuration for a sparse attention layer.
#[derive(Debug, Clone)]
pub struct SparseAttentionConfig {
    /// Model hidden size; all Q/K/V/output projections map hidden_size -> hidden_size.
    pub hidden_size: usize,
    /// Number of attention heads (head_dim = hidden_size / num_heads).
    pub num_heads: usize,
    /// Dropout probability. NOTE(review): stored but not applied anywhere in
    /// this file — confirm where it is consumed.
    pub dropout_prob: f32,
    /// Sparse attention pattern used to generate masks.
    pub pattern: SparsePattern,
    /// Maximum sequence length. NOTE(review): not enforced in this file —
    /// presumably checked by callers; confirm.
    pub max_sequence_length: usize,
    /// Block size for block-based patterns. NOTE(review): pattern variants
    /// carry their own block size; this field is not read here — confirm use.
    pub block_size: usize,
    /// Whether mask caching is enabled (the cache itself lives on `SparseAttention`).
    pub use_cache: bool,
    /// Optional override for the attention score scale; `None` means the
    /// standard 1/sqrt(head_dim).
    pub attention_scale: Option<f32>,
}
57
58impl Default for SparseAttentionConfig {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl SparseAttentionConfig {
65 pub fn new() -> Self {
66 Self {
67 hidden_size: 768,
68 num_heads: 12,
69 dropout_prob: 0.1,
70 pattern: SparsePattern::Local { window_size: 128 },
71 max_sequence_length: 4096,
72 block_size: 64,
73 use_cache: true,
74 attention_scale: None,
75 }
76 }
77
78 pub fn with_pattern(mut self, pattern: SparsePattern) -> Self {
79 self.pattern = pattern;
80 self
81 }
82
83 pub fn with_hidden_size(mut self, hidden_size: usize) -> Self {
84 self.hidden_size = hidden_size;
85 self
86 }
87
88 pub fn with_num_heads(mut self, num_heads: usize) -> Self {
89 self.num_heads = num_heads;
90 self
91 }
92
93 pub fn with_dropout(mut self, dropout_prob: f32) -> Self {
94 self.dropout_prob = dropout_prob;
95 self
96 }
97
98 pub fn with_max_length(mut self, max_sequence_length: usize) -> Self {
99 self.max_sequence_length = max_sequence_length;
100 self
101 }
102
103 pub fn with_block_size(mut self, block_size: usize) -> Self {
104 self.block_size = block_size;
105 self
106 }
107}
108
/// Sparse attention patterns supported by the layer.
#[derive(Debug, Clone)]
pub enum SparsePattern {
    /// Sliding window: each position attends to roughly `window_size`
    /// neighbors centered on itself.
    Local { window_size: usize },
    /// Local window plus every `stride`-th position in both directions.
    Strided { stride: usize, window_size: usize },
    /// Multi-scale windows: one pass per dilation in 1..=`max_dilation`,
    /// attending to every d-th position inside a d-scaled window.
    Dilated {
        max_dilation: usize,
        window_size: usize,
    },
    /// Pseudo-random pattern targeting `sparsity_ratio`. NOTE(review): the
    /// generator is a fixed `(i + j) % 3` selection, not true randomness,
    /// and the ratio acts only as an upper bound on kept entries.
    Random { sparsity_ratio: f32 },
    /// BigBird-style combination of banded, global, and (pseudo-)random blocks.
    BlockSparse {
        block_size: usize,
        global_blocks: usize,
        random_blocks: usize,
    },
    /// Longformer-style: sliding window plus fully-connected global tokens
    /// (global tokens attend to, and are attended by, every position).
    Longformer {
        window_size: usize,
        global_tokens: Vec<usize>,
    },
    /// Linformer-style low-rank projection; the generated mask is dense with
    /// shape (seq_len, projection_dim).
    Linformer { projection_dim: usize },
    /// Reformer-style LSH bucketing, simulated here with a deterministic hash.
    Reformer {
        num_hashes: usize,
        bucket_size: usize,
    },
    /// User-supplied explicit mask, returned by `generate_mask` as-is.
    Custom { mask: SparseAttentionMask },
}
144
/// Sparse attention mask stored in coordinate (COO) format.
#[derive(Debug, Clone)]
pub struct SparseAttentionMask {
    /// (row, col) coordinates of positions that are allowed to attend.
    pub indices: Vec<(usize, usize)>,
    /// Additive mask value for each entry in `indices` (usually 0.0).
    pub values: Vec<f32>,
    /// Logical (rows, cols) shape of the dense mask.
    pub shape: (usize, usize),
}

impl SparseAttentionMask {
    /// Creates an empty mask of the given logical shape.
    pub fn new(shape: (usize, usize)) -> Self {
        Self {
            indices: Vec::new(),
            values: Vec::new(),
            shape,
        }
    }

    /// Records an allowed (row, col) position; out-of-bounds entries are
    /// silently ignored. Duplicate coordinates may be pushed (several pattern
    /// generators do); `to_dense` lets the last write win and `sparsity`
    /// counts each coordinate once.
    pub fn add_entry(&mut self, row: usize, col: usize, value: f32) {
        if row < self.shape.0 && col < self.shape.1 {
            self.indices.push((row, col));
            self.values.push(value);
        }
    }

    /// Materializes the dense mask. Positions without an entry are fully
    /// blocked (-inf); duplicate entries resolve to the last pushed value.
    pub fn to_dense(&self) -> Vec<Vec<f32>> {
        let mut dense = vec![vec![f32::NEG_INFINITY; self.shape.1]; self.shape.0];
        for (i, &(row, col)) in self.indices.iter().enumerate() {
            dense[row][col] = self.values[i];
        }
        dense
    }

    /// Fraction of the dense mask that is blocked, always in [0, 1].
    ///
    /// Each coordinate is counted once even when `add_entry` pushed
    /// duplicates, so the result can no longer go negative; an empty shape
    /// is reported as fully sparse instead of dividing by zero.
    pub fn sparsity(&self) -> f32 {
        let total_elements = self.shape.0 * self.shape.1;
        if total_elements == 0 {
            return 1.0;
        }
        // Deduplicate coordinates: generators may add the same (row, col)
        // several times (e.g. window + stride overlap).
        let mut unique = self.indices.clone();
        unique.sort_unstable();
        unique.dedup();
        1.0 - (unique.len() as f32 / total_elements as f32)
    }
}
183
/// Attention layer that restricts score computation to a sparse pattern.
#[derive(Debug, Clone)]
pub struct SparseAttention {
    /// Pattern and dimension configuration.
    config: SparseAttentionConfig,
    // Q/K/V/output projections, each hidden_size -> hidden_size, bias-free.
    query_projection: trustformers_core::layers::Linear,
    key_projection: trustformers_core::layers::Linear,
    value_projection: trustformers_core::layers::Linear,
    output_projection: trustformers_core::layers::Linear,
    /// Per-head dimension (hidden_size / num_heads); not read in this file.
    #[allow(dead_code)]
    head_dim: usize,
    /// Score scale applied in `compute_sparse_scores`.
    scale: f32,
    /// Mask cache keyed by sequence length. NOTE(review): declared but never
    /// populated or read in this file — confirm intended use.
    #[allow(dead_code)]
    mask_cache: HashMap<usize, SparseAttentionMask>,
}
198
199impl SparseAttention {
200 pub fn new(config: SparseAttentionConfig) -> Result<Self> {
201 let head_dim = config.hidden_size / config.num_heads;
202 let scale = config.attention_scale.unwrap_or(1.0 / (head_dim as f32).sqrt());
203
204 Ok(Self {
205 query_projection: trustformers_core::layers::Linear::new(
206 config.hidden_size,
207 config.hidden_size,
208 false,
209 ),
210 key_projection: trustformers_core::layers::Linear::new(
211 config.hidden_size,
212 config.hidden_size,
213 false,
214 ),
215 value_projection: trustformers_core::layers::Linear::new(
216 config.hidden_size,
217 config.hidden_size,
218 false,
219 ),
220 output_projection: trustformers_core::layers::Linear::new(
221 config.hidden_size,
222 config.hidden_size,
223 false,
224 ),
225 head_dim,
226 scale,
227 mask_cache: HashMap::new(),
228 config,
229 })
230 }
231
232 pub fn generate_mask(&self, sequence_length: usize) -> Result<SparseAttentionMask> {
234 match &self.config.pattern {
235 SparsePattern::Local { window_size } => {
236 self.generate_local_mask(sequence_length, *window_size)
237 },
238 SparsePattern::Strided {
239 stride,
240 window_size,
241 } => self.generate_strided_mask(sequence_length, *stride, *window_size),
242 SparsePattern::Dilated {
243 max_dilation,
244 window_size,
245 } => self.generate_dilated_mask(sequence_length, *max_dilation, *window_size),
246 SparsePattern::Random { sparsity_ratio } => {
247 self.generate_random_mask(sequence_length, *sparsity_ratio)
248 },
249 SparsePattern::BlockSparse {
250 block_size,
251 global_blocks,
252 random_blocks,
253 } => self.generate_block_sparse_mask(
254 sequence_length,
255 *block_size,
256 *global_blocks,
257 *random_blocks,
258 ),
259 SparsePattern::Longformer {
260 window_size,
261 global_tokens,
262 } => self.generate_longformer_mask(sequence_length, *window_size, global_tokens),
263 SparsePattern::Linformer { projection_dim } => {
264 self.generate_linformer_mask(sequence_length, *projection_dim)
265 },
266 SparsePattern::Reformer {
267 num_hashes,
268 bucket_size,
269 } => self.generate_reformer_mask(sequence_length, *num_hashes, *bucket_size),
270 SparsePattern::Custom { mask } => Ok(mask.clone()),
271 }
272 }
273
274 fn generate_local_mask(
275 &self,
276 seq_len: usize,
277 window_size: usize,
278 ) -> Result<SparseAttentionMask> {
279 let mut mask = SparseAttentionMask::new((seq_len, seq_len));
280
281 for i in 0..seq_len {
282 let start = i.saturating_sub(window_size / 2);
283 let end = (i + window_size / 2 + 1).min(seq_len);
284
285 for j in start..end {
286 mask.add_entry(i, j, 0.0);
287 }
288 }
289
290 Ok(mask)
291 }
292
293 fn generate_strided_mask(
294 &self,
295 seq_len: usize,
296 stride: usize,
297 window_size: usize,
298 ) -> Result<SparseAttentionMask> {
299 let mut mask = SparseAttentionMask::new((seq_len, seq_len));
300
301 for i in 0..seq_len {
302 let start = i.saturating_sub(window_size / 2);
304 let end = (i + window_size / 2 + 1).min(seq_len);
305
306 for j in start..end {
307 mask.add_entry(i, j, 0.0);
308 }
309
310 let mut pos = i;
312 while pos < seq_len {
313 mask.add_entry(i, pos, 0.0);
314 pos += stride;
315 }
316
317 if i >= stride {
318 let mut pos = i - stride;
319 loop {
320 mask.add_entry(i, pos, 0.0);
321 if pos < stride {
322 break;
323 }
324 pos -= stride;
325 }
326 }
327 }
328
329 Ok(mask)
330 }
331
332 fn generate_dilated_mask(
333 &self,
334 seq_len: usize,
335 max_dilation: usize,
336 window_size: usize,
337 ) -> Result<SparseAttentionMask> {
338 let mut mask = SparseAttentionMask::new((seq_len, seq_len));
339
340 for i in 0..seq_len {
341 for dilation in 1..=max_dilation {
342 let start = i.saturating_sub(window_size * dilation / 2);
343 let end = (i + window_size * dilation / 2 + 1).min(seq_len);
344
345 for j in (start..end).step_by(dilation) {
346 mask.add_entry(i, j, 0.0);
347 }
348 }
349 }
350
351 Ok(mask)
352 }
353
354 fn generate_random_mask(
355 &self,
356 seq_len: usize,
357 sparsity_ratio: f32,
358 ) -> Result<SparseAttentionMask> {
359 let mut mask = SparseAttentionMask::new((seq_len, seq_len));
360 let total_elements = seq_len * seq_len;
361 let keep_elements = (total_elements as f32 * (1.0 - sparsity_ratio)) as usize;
362
363 let mut added = 0;
365 for i in 0..seq_len {
366 for j in 0..seq_len {
367 if added < keep_elements && (i + j) % 3 == 0 {
368 mask.add_entry(i, j, 0.0);
370 added += 1;
371 }
372 }
373 }
374
375 Ok(mask)
376 }
377
378 fn generate_block_sparse_mask(
379 &self,
380 seq_len: usize,
381 block_size: usize,
382 global_blocks: usize,
383 random_blocks: usize,
384 ) -> Result<SparseAttentionMask> {
385 let mut mask = SparseAttentionMask::new((seq_len, seq_len));
386 let num_blocks = seq_len.div_ceil(block_size);
387
388 for block_i in 0..num_blocks {
389 let start_i = block_i * block_size;
390 let end_i = (start_i + block_size).min(seq_len);
391
392 for block_j in 0..num_blocks {
393 let start_j = block_j * block_size;
394 let end_j = (start_j + block_size).min(seq_len);
395
396 if block_i == block_j || block_i.abs_diff(block_j) <= 1 {
398 for i in start_i..end_i {
399 for j in start_j..end_j {
400 mask.add_entry(i, j, 0.0);
401 }
402 }
403 }
404
405 if block_j < global_blocks || block_i < global_blocks {
407 for i in start_i..end_i {
408 for j in start_j..end_j {
409 mask.add_entry(i, j, 0.0);
410 }
411 }
412 }
413
414 if (block_i + block_j) % (num_blocks / random_blocks.max(1)) == 0 {
416 for i in start_i..end_i {
417 for j in start_j..end_j {
418 mask.add_entry(i, j, 0.0);
419 }
420 }
421 }
422 }
423 }
424
425 Ok(mask)
426 }
427
428 fn generate_longformer_mask(
429 &self,
430 seq_len: usize,
431 window_size: usize,
432 global_tokens: &[usize],
433 ) -> Result<SparseAttentionMask> {
434 let mut mask = SparseAttentionMask::new((seq_len, seq_len));
435
436 for i in 0..seq_len {
438 let start = i.saturating_sub(window_size / 2);
439 let end = (i + window_size / 2 + 1).min(seq_len);
440
441 for j in start..end {
442 mask.add_entry(i, j, 0.0);
443 }
444 }
445
446 for &global_token in global_tokens {
448 if global_token < seq_len {
449 for j in 0..seq_len {
450 mask.add_entry(global_token, j, 0.0);
451 mask.add_entry(j, global_token, 0.0);
452 }
453 }
454 }
455
456 Ok(mask)
457 }
458
459 fn generate_linformer_mask(
460 &self,
461 seq_len: usize,
462 projection_dim: usize,
463 ) -> Result<SparseAttentionMask> {
464 let mut mask = SparseAttentionMask::new((seq_len, projection_dim));
467
468 for i in 0..seq_len {
469 for j in 0..projection_dim {
470 mask.add_entry(i, j, 0.0);
471 }
472 }
473
474 Ok(mask)
475 }
476
477 fn generate_reformer_mask(
478 &self,
479 seq_len: usize,
480 num_hashes: usize,
481 bucket_size: usize,
482 ) -> Result<SparseAttentionMask> {
483 let mut mask = SparseAttentionMask::new((seq_len, seq_len));
484 let num_buckets = seq_len.div_ceil(bucket_size);
485
486 for hash_idx in 0..num_hashes {
488 for bucket in 0..num_buckets {
489 let start = bucket * bucket_size;
490 let end = (start + bucket_size).min(seq_len);
491
492 for i in start..end {
494 for j in start..end {
495 let hash_offset = (i + hash_idx) % seq_len;
496 let hash_bucket = hash_offset / bucket_size;
497 if hash_bucket == bucket {
498 mask.add_entry(i, j, 0.0);
499 }
500 }
501 }
502 }
503 }
504
505 Ok(mask)
506 }
507
508 #[allow(dead_code)]
510 fn apply_sparse_mask(
511 &self,
512 attention_scores: &Tensor,
513 mask: &SparseAttentionMask,
514 ) -> Result<Tensor> {
515 match attention_scores {
516 Tensor::F32(scores) => {
517 let mut masked_scores = scores.clone();
518 let shape = scores.shape();
519
520 if shape.len() != 2 {
521 return Err(tensor_op_error(
522 "tensor_operation",
523 "Attention scores must be 2D for sparse masking".to_string(),
524 ));
525 }
526
527 masked_scores.fill(f32::NEG_INFINITY);
529
530 for &(row, col) in mask.indices.iter() {
532 if row < shape[0] && col < shape[1] {
533 masked_scores[[row, col]] = scores[[row, col]];
534 }
535 }
536
537 Ok(Tensor::F32(masked_scores))
538 },
539 _ => Err(tensor_op_error(
540 "tensor_operation",
541 "Unsupported tensor type for sparse attention".to_string(),
542 )),
543 }
544 }
545
546 fn compute_sparse_attention(
548 &self,
549 query: &Tensor,
550 key: &Tensor,
551 value: &Tensor,
552 mask: &SparseAttentionMask,
553 ) -> Result<Tensor> {
554 let attention_scores = self.compute_sparse_scores(query, key, mask)?;
560
561 let attention_weights = attention_scores.softmax(-1)?;
563
564 self.apply_sparse_attention_weights(&attention_weights, value, mask)
566 }
567
568 fn compute_sparse_scores(
569 &self,
570 query: &Tensor,
571 key: &Tensor,
572 mask: &SparseAttentionMask,
573 ) -> Result<Tensor> {
574 match (query, key) {
577 (Tensor::F32(q), Tensor::F32(k)) => {
578 let q_shape = q.shape();
579 let k_shape = k.shape();
580
581 if q_shape.len() != 2 || k_shape.len() != 2 {
582 return Err(tensor_op_error(
583 "tensor_operation",
584 "Query and key must be 2D".to_string(),
585 ));
586 }
587
588 let seq_len = q_shape[0];
589 let head_dim = q_shape[1];
590
591 let mut scores = Array2::from_elem((seq_len, seq_len), f32::NEG_INFINITY);
592
593 for &(i, j) in &mask.indices {
595 if i < seq_len && j < seq_len {
596 let mut score = 0.0;
597 for d in 0..head_dim {
598 score += q[[i, d]] * k[[j, d]];
599 }
600 scores[[i, j]] = score * self.scale;
601 }
602 }
603
604 Ok(Tensor::F32(scores.into_dyn()))
605 },
606 _ => Err(tensor_op_error(
607 "tensor_operation",
608 "Unsupported tensor types for sparse attention".to_string(),
609 )),
610 }
611 }
612
613 fn apply_sparse_attention_weights(
614 &self,
615 weights: &Tensor,
616 value: &Tensor,
617 mask: &SparseAttentionMask,
618 ) -> Result<Tensor> {
619 match (weights, value) {
620 (Tensor::F32(w), Tensor::F32(v)) => {
621 let w_shape = w.shape();
622 let v_shape = v.shape();
623
624 let seq_len = w_shape[0];
625 let head_dim = v_shape[1];
626
627 let mut output = Array2::zeros((seq_len, head_dim));
628
629 for &(i, j) in &mask.indices {
631 if i < seq_len && j < seq_len {
632 let weight = w[[i, j]];
633 if weight != f32::NEG_INFINITY && !weight.is_nan() {
634 for d in 0..head_dim {
635 output[[i, d]] += weight * v[[j, d]];
636 }
637 }
638 }
639 }
640
641 Ok(Tensor::F32(output.into_dyn()))
642 },
643 _ => Err(tensor_op_error(
644 "tensor_operation",
645 "Unsupported tensor types for sparse attention output".to_string(),
646 )),
647 }
648 }
649}
650
651impl Layer for SparseAttention {
652 type Input = AttentionInput;
653 type Output = Tensor;
654
655 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
656 let AttentionInput {
657 hidden_states,
658 attention_mask: _,
659 } = input;
660
661 let query = self.query_projection.forward(hidden_states.clone())?;
663 let key = self.key_projection.forward(hidden_states.clone())?;
664 let value = self.value_projection.forward(hidden_states)?;
665
666 let seq_len = match &query {
668 Tensor::F32(q) => q.shape()[0],
669 _ => {
670 return Err(tensor_op_error(
671 "tensor_operation",
672 "Unsupported tensor type".to_string(),
673 ))
674 },
675 };
676
677 let mask = self.generate_mask(seq_len)?;
679
680 let attention_output = self.compute_sparse_attention(&query, &key, &value, &mask)?;
682
683 self.output_projection.forward(attention_output)
685 }
686}
687
688pub mod utils {
690 use super::*;
691
692 pub fn create_local_attention(
694 hidden_size: usize,
695 num_heads: usize,
696 window_size: usize,
697 ) -> SparseAttentionConfig {
698 SparseAttentionConfig::new()
699 .with_hidden_size(hidden_size)
700 .with_num_heads(num_heads)
701 .with_pattern(SparsePattern::Local { window_size })
702 }
703
704 pub fn create_bigbird_attention(
706 hidden_size: usize,
707 num_heads: usize,
708 block_size: usize,
709 ) -> SparseAttentionConfig {
710 SparseAttentionConfig::new()
711 .with_hidden_size(hidden_size)
712 .with_num_heads(num_heads)
713 .with_pattern(SparsePattern::BlockSparse {
714 block_size,
715 global_blocks: 2,
716 random_blocks: 2,
717 })
718 }
719
720 pub fn create_longformer_attention(
722 hidden_size: usize,
723 num_heads: usize,
724 window_size: usize,
725 global_tokens: Vec<usize>,
726 ) -> SparseAttentionConfig {
727 SparseAttentionConfig::new()
728 .with_hidden_size(hidden_size)
729 .with_num_heads(num_heads)
730 .with_pattern(SparsePattern::Longformer {
731 window_size,
732 global_tokens,
733 })
734 }
735
736 pub fn analyze_pattern_efficiency(
738 pattern: &SparsePattern,
739 sequence_length: usize,
740 ) -> Result<PatternAnalysis> {
741 let config = SparseAttentionConfig::new().with_pattern(pattern.clone());
742 let attention = SparseAttention::new(config)?;
743 let mask = attention.generate_mask(sequence_length)?;
744
745 Ok(PatternAnalysis {
746 sparsity: mask.sparsity(),
747 memory_reduction: mask.sparsity(),
748 compute_reduction: mask.sparsity(),
749 effective_receptive_field: calculate_receptive_field(&mask),
750 pattern_regularity: calculate_pattern_regularity(&mask),
751 })
752 }
753
754 fn calculate_receptive_field(mask: &SparseAttentionMask) -> f32 {
755 let mut total_connections = 0;
756 let mut positions_with_connections = 0;
757
758 for i in 0..mask.shape.0 {
759 let mut connections = 0;
760 for &(row, _) in &mask.indices {
761 if row == i {
762 connections += 1;
763 }
764 }
765 if connections > 0 {
766 total_connections += connections;
767 positions_with_connections += 1;
768 }
769 }
770
771 if positions_with_connections > 0 {
772 total_connections as f32 / positions_with_connections as f32
773 } else {
774 0.0
775 }
776 }
777
778 fn calculate_pattern_regularity(mask: &SparseAttentionMask) -> f32 {
779 let mut connections_per_position = vec![0; mask.shape.0];
781
782 for &(row, _) in &mask.indices {
783 connections_per_position[row] += 1;
784 }
785
786 let mean = connections_per_position.iter().sum::<usize>() as f32 / mask.shape.0 as f32;
787 let variance =
788 connections_per_position.iter().map(|&x| (x as f32 - mean).powi(2)).sum::<f32>()
789 / mask.shape.0 as f32;
790
791 1.0 / (1.0 + variance) }
793
    /// Summary statistics for a sparse attention pattern at a given length.
    #[derive(Debug, Clone)]
    pub struct PatternAnalysis {
        /// Fraction of attention positions that are blocked.
        pub sparsity: f32,
        /// Approximate memory saving vs. dense attention (equals `sparsity`).
        pub memory_reduction: f32,
        /// Approximate compute saving vs. dense attention (equals `sparsity`).
        pub compute_reduction: f32,
        /// Mean number of attended positions per connected row.
        pub effective_receptive_field: f32,
        /// Regularity score; 1 means identical connection counts per row.
        pub pattern_regularity: f32,
    }
803}
804
#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::tensor::Tensor;

    // The local pattern should produce a square mask that blocks something.
    #[test]
    fn test_local_attention_mask() {
        let config =
            SparseAttentionConfig::new().with_pattern(SparsePattern::Local { window_size: 4 });

        let attention = SparseAttention::new(config).expect("operation failed");
        let mask = attention.generate_mask(8).expect("operation failed");

        assert_eq!(mask.shape, (8, 8));
        assert!(mask.sparsity() > 0.0);
    }

    // Block-sparse pattern: shape sanity plus a non-negative sparsity score.
    #[test]
    fn test_block_sparse_attention_mask() {
        let config = SparseAttentionConfig::new().with_pattern(SparsePattern::BlockSparse {
            block_size: 4,
            global_blocks: 1,
            random_blocks: 1,
        });

        let attention = SparseAttention::new(config).expect("operation failed");
        let mask = attention.generate_mask(32).expect("operation failed");
        assert_eq!(mask.shape, (32, 32));
        assert!(mask.sparsity() >= 0.0);
    }

    // End-to-end forward pass: output keeps the (seq, hidden) input shape.
    #[test]
    fn test_sparse_attention_forward() {
        let config = SparseAttentionConfig::new()
            .with_hidden_size(64)
            .with_num_heads(4)
            .with_pattern(SparsePattern::Local { window_size: 4 });

        let attention = SparseAttention::new(config).expect("operation failed");

        let input = Tensor::randn(&[8, 64]).expect("operation failed");
        let attention_input = AttentionInput {
            hidden_states: input,
            attention_mask: None,
        };

        let output = attention.forward(attention_input).expect("operation failed");

        match output {
            Tensor::F32(arr) => {
                assert_eq!(arr.shape(), &[8, 64]);
            },
            _ => panic!("Expected F32 tensor"),
        }
    }

    // Pattern analysis on a small local pattern: all metrics in sane ranges.
    #[test]
    fn test_pattern_analysis() {
        let pattern = SparsePattern::Local { window_size: 4 };
        let analysis =
            utils::analyze_pattern_efficiency(&pattern, 16).expect("operation failed in test");

        assert!(analysis.sparsity > 0.0);
        assert!(analysis.sparsity < 1.0);
        assert!(analysis.effective_receptive_field > 0.0);
        assert!(analysis.pattern_regularity > 0.0);
    }

    // Convenience constructors should propagate the basic dimensions.
    #[test]
    fn test_utility_functions() {
        let local_config = utils::create_local_attention(768, 12, 128);
        assert_eq!(local_config.hidden_size, 768);
        assert_eq!(local_config.num_heads, 12);

        let bigbird_config = utils::create_bigbird_attention(768, 12, 64);
        assert_eq!(bigbird_config.hidden_size, 768);

        let longformer_config = utils::create_longformer_attention(768, 12, 128, vec![0, 1]);
        assert_eq!(longformer_config.hidden_size, 768);
    }

    // COO mask bookkeeping: entry count, sparsity, and dense round-trip.
    #[test]
    fn test_sparse_mask_operations() {
        let mut mask = SparseAttentionMask::new((4, 4));
        mask.add_entry(0, 0, 0.0);
        mask.add_entry(0, 1, 0.0);
        mask.add_entry(1, 1, 0.0);

        assert_eq!(mask.indices.len(), 3);
        assert_eq!(mask.sparsity(), 1.0 - 3.0 / 16.0);

        let dense = mask.to_dense();
        assert_eq!(dense.len(), 4);
        assert_eq!(dense[0].len(), 4);
        assert_eq!(dense[0][0], 0.0);
        assert_eq!(dense[0][1], 0.0);
        assert_eq!(dense[0][2], f32::NEG_INFINITY);
    }
}