1use crate::errors::{Result, TrustformersError};
32use crate::sparse_tensor::{SparseFormat, SparseIndices, SparseTensor};
33use crate::tensor::Tensor;
34use serde::{Deserialize, Serialize};
35use std::collections::HashSet;
36
/// Structured sparsity patterns supported by this module.
///
/// Serde-serializable so a pattern choice can be persisted in a config.
/// Only `NM`, `Block` and `Magnitude` have implementations visible in this
/// file (`NMSparsity`, `BlockSparsity`, `pruning::magnitude_prune`); the
/// other variants are presumably handled elsewhere — verify before relying
/// on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StructuredSparsityPattern {
    /// N:M sparsity: keep `n` weights out of every window of `m`
    /// consecutive weights along a row.
    NM { n: usize, m: usize },

    /// Block sparsity: prune whole `block_height` x `block_width` tiles.
    Block {
        block_height: usize,
        block_width: usize,
    },

    /// Channel pruning along `dimension`, keeping a `keep_ratio` fraction.
    Channel { dimension: usize, keep_ratio: f32 },

    /// Attention-head pruning: keep a `keep_ratio` fraction of `num_heads`.
    Head { num_heads: usize, keep_ratio: f32 },

    /// Unstructured random sparsity at the given `sparsity` level.
    Random { sparsity: f32 },

    /// Unstructured magnitude pruning keeping a `keep_ratio` fraction.
    Magnitude { keep_ratio: f32 },
}
61
/// N:M semi-structured sparsity: within every window of `m` consecutive
/// weights in a row, only the `n` largest-magnitude weights are kept
/// (see [`NMSparsity::apply`]).
pub struct NMSparsity {
    // Number of weights kept per window (n <= m, enforced in `new`).
    n: usize,
    // Window length; the column count of the input must be divisible by it.
    m: usize,
}
67
68impl NMSparsity {
69 pub fn new(n: usize, m: usize) -> Self {
80 assert!(n <= m, "N must be <= M in N:M sparsity");
81 Self { n, m }
82 }
83
84 pub fn apply(&self, tensor: &Tensor) -> Result<SparseTensor> {
86 let data = tensor.to_vec_f32()?;
87 let shape = tensor.shape().to_vec();
88
89 if shape.len() != 2 {
90 return Err(TrustformersError::shape_error(
91 "N:M sparsity currently supports only 2D tensors".to_string(),
92 ));
93 }
94
95 let rows = shape[0];
96 let cols = shape[1];
97
98 if cols % self.m != 0 {
100 return Err(TrustformersError::shape_error(format!(
101 "Number of columns {} must be divisible by M={}",
102 cols, self.m
103 )));
104 }
105
106 let mut row_indices = Vec::new();
107 let mut col_indices = Vec::new();
108 let mut values = Vec::new();
109
110 for row in 0..rows {
112 let row_start = row * cols;
113
114 for window_start in (0..cols).step_by(self.m) {
116 let window_end = (window_start + self.m).min(cols);
117
118 let mut window_vals: Vec<(usize, f32)> = (window_start..window_end)
120 .map(|col| {
121 let idx = row_start + col;
122 (col, data[idx])
123 })
124 .collect();
125
126 window_vals.sort_by(|a, b| {
128 b.1.abs().partial_cmp(&a.1.abs()).unwrap_or(std::cmp::Ordering::Equal)
129 });
130
131 for (col, val) in window_vals.iter().take(self.n) {
133 row_indices.push(row);
134 col_indices.push(*col);
135 values.push(*val);
136 }
137 }
138 }
139
140 SparseTensor::new_coo(shape, row_indices, col_indices, values)
141 }
142
143 pub fn sparsity_ratio(&self) -> f32 {
145 1.0 - (self.n as f32 / self.m as f32)
146 }
147}
148
/// Block-structured sparsity: the matrix is tiled into
/// `block_height` x `block_width` blocks and only the blocks with the
/// largest L2 (Frobenius) norms are kept (see [`BlockSparsity::apply`]).
pub struct BlockSparsity {
    // Tile height in rows.
    block_height: usize,
    // Tile width in columns.
    block_width: usize,
    // Fraction of blocks (by count) to retain, ranked by L2 norm.
    keep_ratio: f32,
}
155
156impl BlockSparsity {
157 pub fn new(block_height: usize, block_width: usize, keep_ratio: f32) -> Self {
159 Self {
160 block_height,
161 block_width,
162 keep_ratio,
163 }
164 }
165
166 pub fn apply(&self, tensor: &Tensor) -> Result<SparseTensor> {
168 let data = tensor.to_vec_f32()?;
169 let shape = tensor.shape().to_vec();
170
171 if shape.len() != 2 {
172 return Err(TrustformersError::shape_error(
173 "Block sparsity currently supports only 2D tensors".to_string(),
174 ));
175 }
176
177 let rows = shape[0];
178 let cols = shape[1];
179
180 let num_block_rows = rows.div_ceil(self.block_height);
181 let num_block_cols = cols.div_ceil(self.block_width);
182
183 let mut block_scores = Vec::new();
185 for br in 0..num_block_rows {
186 for bc in 0..num_block_cols {
187 let row_start = br * self.block_height;
188 let row_end = (row_start + self.block_height).min(rows);
189 let col_start = bc * self.block_width;
190 let col_end = (col_start + self.block_width).min(cols);
191
192 let mut block_norm = 0.0f32;
194 for r in row_start..row_end {
195 for c in col_start..col_end {
196 let val = data[r * cols + c];
197 block_norm += val * val;
198 }
199 }
200 block_norm = block_norm.sqrt();
201
202 block_scores.push(((br, bc), block_norm));
203 }
204 }
205
206 block_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));
208
209 let num_blocks_to_keep = ((block_scores.len() as f32) * self.keep_ratio) as usize;
211 let blocks_to_keep: HashSet<(usize, usize)> = block_scores
212 .iter()
213 .take(num_blocks_to_keep)
214 .map(|&((br, bc), _)| (br, bc))
215 .collect();
216
217 let mut row_ptr = vec![0];
219 let mut col_indices = Vec::new();
220 let mut values = Vec::new();
221
222 for br in 0..num_block_rows {
223 let row_start = br * self.block_height;
224 let row_end = (row_start + self.block_height).min(rows);
225
226 for r in row_start..row_end {
227 let mut row_nnz = 0;
228
229 for bc in 0..num_block_cols {
230 if !blocks_to_keep.contains(&(br, bc)) {
231 continue;
232 }
233
234 let col_start = bc * self.block_width;
235 let col_end = (col_start + self.block_width).min(cols);
236
237 for c in col_start..col_end {
238 let val = data[r * cols + c];
239 if val != 0.0 {
240 col_indices.push(c);
241 values.push(val);
242 row_nnz += 1;
243 }
244 }
245 }
246
247 row_ptr.push(row_ptr.last().copied().unwrap_or(0) + row_nnz);
249 }
250 }
251
252 SparseTensor::new_csr(shape, row_ptr, col_indices, values)
253 }
254}
255
256pub fn sparse_matmul(sparse: &SparseTensor, dense: &Tensor) -> Result<Tensor> {
258 let dense_data = dense.to_vec_f32()?;
259 let dense_shape = dense.shape();
260
261 if sparse.shape.len() != 2 || dense_shape.len() != 2 {
262 return Err(TrustformersError::shape_error(
263 "Sparse matmul requires 2D matrices".to_string(),
264 ));
265 }
266
267 if sparse.shape[1] != dense_shape[0] {
268 return Err(TrustformersError::shape_error(format!(
269 "Incompatible shapes for matmul: {:?} x {:?}",
270 sparse.shape, dense_shape
271 )));
272 }
273
274 let m = sparse.shape[0];
275 let _k = sparse.shape[1];
276 let n = dense_shape[1];
277
278 let mut result = vec![0.0f32; m * n];
279
280 match sparse.format {
281 SparseFormat::CSR => {
282 if let SparseIndices::CSR {
283 row_ptr,
284 col_indices,
285 } = &sparse.indices
286 {
287 for row in 0..m {
289 let row_start = row_ptr[row];
290 let row_end = row_ptr[row + 1];
291
292 #[allow(clippy::needless_range_loop)]
293 for j in row_start..row_end {
294 let col = col_indices[j];
295 let sparse_val = sparse.values[j];
296
297 for out_col in 0..n {
299 result[row * n + out_col] += sparse_val * dense_data[col * n + out_col];
300 }
301 }
302 }
303 } else {
304 return Err(TrustformersError::tensor_op_error(
305 "Invalid indices format",
306 "sparse matmul",
307 ));
308 }
309 },
310 SparseFormat::COO => {
311 if let SparseIndices::COO {
312 row_indices,
313 col_indices,
314 } = &sparse.indices
315 {
316 for ((&row, &col), &val) in
317 row_indices.iter().zip(col_indices.iter()).zip(sparse.values.iter())
318 {
319 for out_col in 0..n {
320 result[row * n + out_col] += val * dense_data[col * n + out_col];
321 }
322 }
323 } else {
324 return Err(TrustformersError::tensor_op_error(
325 "Invalid indices format",
326 "sparse matmul",
327 ));
328 }
329 },
330 _ => {
331 return Err(TrustformersError::tensor_op_error(
332 "Unsupported sparse format for matmul",
333 "sparse matmul",
334 ));
335 },
336 }
337
338 Tensor::from_vec(result, &[m, n])
339}
340
341pub mod sparse_attention {
343 use super::*;
344
345 pub struct BlockSparseAttention {
347 block_size: usize,
348 num_random_blocks: usize,
349 }
350
351 impl BlockSparseAttention {
352 pub fn new(block_size: usize, num_random_blocks: usize) -> Self {
354 Self {
355 block_size,
356 num_random_blocks,
357 }
358 }
359
360 pub fn generate_mask(&self, seq_len: usize) -> Result<SparseTensor> {
362 let num_blocks = seq_len.div_ceil(self.block_size);
363
364 let mut row_indices = Vec::new();
365 let mut col_indices = Vec::new();
366 let mut values = Vec::new();
367
368 for block_i in 0..num_blocks {
369 for block_j in block_i.saturating_sub(1)..=(block_i + 1).min(num_blocks - 1) {
371 self.add_block(
372 block_i,
373 block_j,
374 seq_len,
375 &mut row_indices,
376 &mut col_indices,
377 &mut values,
378 );
379 }
380
381 for j in 0..self.num_random_blocks {
384 let random_block = (block_i * 7 + j * 13) % num_blocks;
385 self.add_block(
386 block_i,
387 random_block,
388 seq_len,
389 &mut row_indices,
390 &mut col_indices,
391 &mut values,
392 );
393 }
394 }
395
396 SparseTensor::new_coo(vec![seq_len, seq_len], row_indices, col_indices, values)
397 }
398
399 fn add_block(
400 &self,
401 block_i: usize,
402 block_j: usize,
403 seq_len: usize,
404 row_indices: &mut Vec<usize>,
405 col_indices: &mut Vec<usize>,
406 values: &mut Vec<f32>,
407 ) {
408 let row_start = block_i * self.block_size;
409 let row_end = (row_start + self.block_size).min(seq_len);
410 let col_start = block_j * self.block_size;
411 let col_end = (col_start + self.block_size).min(seq_len);
412
413 for r in row_start..row_end {
414 for c in col_start..col_end {
415 row_indices.push(r);
416 col_indices.push(c);
417 values.push(1.0); }
419 }
420 }
421 }
422
423 pub fn sliding_window_mask(seq_len: usize, window_size: usize) -> Result<SparseTensor> {
425 let mut row_indices = Vec::new();
426 let mut col_indices = Vec::new();
427 let mut values = Vec::new();
428
429 for i in 0..seq_len {
430 let start = i.saturating_sub(window_size / 2);
431 let end = (i + window_size / 2 + 1).min(seq_len);
432
433 for j in start..end {
434 row_indices.push(i);
435 col_indices.push(j);
436 values.push(1.0);
437 }
438 }
439
440 SparseTensor::new_coo(vec![seq_len, seq_len], row_indices, col_indices, values)
441 }
442
443 pub fn dilated_window_mask(
445 seq_len: usize,
446 window_size: usize,
447 dilation: usize,
448 ) -> Result<SparseTensor> {
449 let mut row_indices = Vec::new();
450 let mut col_indices = Vec::new();
451 let mut values = Vec::new();
452
453 for i in 0..seq_len {
454 let local_start = i.saturating_sub(window_size / 2);
456 let local_end = (i + window_size / 2 + 1).min(seq_len);
457
458 for j in local_start..local_end {
459 row_indices.push(i);
460 col_indices.push(j);
461 values.push(1.0);
462 }
463
464 for k in 1..=window_size {
466 let dilated_pos = i + k * dilation;
467 if dilated_pos < seq_len {
468 row_indices.push(i);
469 col_indices.push(dilated_pos);
470 values.push(1.0);
471 }
472
473 if k * dilation <= i {
474 let dilated_pos = i - k * dilation;
475 row_indices.push(i);
476 col_indices.push(dilated_pos);
477 values.push(1.0);
478 }
479 }
480 }
481
482 SparseTensor::new_coo(vec![seq_len, seq_len], row_indices, col_indices, values)
483 }
484}
485
486pub mod conversion {
488 use super::*;
489
490 pub fn coo_to_csr(sparse: &SparseTensor) -> Result<SparseTensor> {
492 if sparse.format != SparseFormat::COO {
493 return Err(TrustformersError::tensor_op_error(
494 "Input must be in COO format",
495 "COO to CSR conversion",
496 ));
497 }
498
499 if let SparseIndices::COO {
500 row_indices,
501 col_indices,
502 } = &sparse.indices
503 {
504 let num_rows = sparse.shape[0];
505
506 let mut row_ptr = vec![0; num_rows + 1];
508 for &row in row_indices {
509 row_ptr[row + 1] += 1;
510 }
511
512 for i in 0..num_rows {
514 row_ptr[i + 1] += row_ptr[i];
515 }
516
517 let mut entries: Vec<(usize, usize, f32)> = row_indices
519 .iter()
520 .zip(col_indices.iter())
521 .zip(sparse.values.iter())
522 .map(|((&r, &c), &v)| (r, c, v))
523 .collect();
524
525 entries.sort_by_key(|&(r, c, _)| (r, c));
526
527 let sorted_col_indices: Vec<usize> = entries.iter().map(|&(_, c, _)| c).collect();
528 let sorted_values: Vec<f32> = entries.iter().map(|&(_, _, v)| v).collect();
529
530 SparseTensor::new_csr(
531 sparse.shape.clone(),
532 row_ptr,
533 sorted_col_indices,
534 sorted_values,
535 )
536 } else {
537 Err(TrustformersError::tensor_op_error(
538 "Invalid indices format",
539 "COO to CSR conversion",
540 ))
541 }
542 }
543
544 pub fn csr_to_coo(sparse: &SparseTensor) -> Result<SparseTensor> {
546 if sparse.format != SparseFormat::CSR {
547 return Err(TrustformersError::tensor_op_error(
548 "Input must be in CSR format",
549 "CSR to COO conversion",
550 ));
551 }
552
553 if let SparseIndices::CSR {
554 row_ptr,
555 col_indices,
556 } = &sparse.indices
557 {
558 let mut row_indices = Vec::new();
559
560 for (row, window) in row_ptr.windows(2).enumerate() {
561 let count = window[1] - window[0];
562 row_indices.extend(vec![row; count]);
563 }
564
565 SparseTensor::new_coo(
566 sparse.shape.clone(),
567 row_indices,
568 col_indices.clone(),
569 sparse.values.clone(),
570 )
571 } else {
572 Err(TrustformersError::tensor_op_error(
573 "Invalid indices format",
574 "CSR to COO conversion",
575 ))
576 }
577 }
578}
579
580pub mod pruning {
582 use super::*;
583
584 pub fn magnitude_prune(tensor: &Tensor, keep_ratio: f32) -> Result<SparseTensor> {
586 let data = tensor.to_vec_f32()?;
587 let shape = tensor.shape().to_vec();
588
589 let mut indexed_data: Vec<(usize, f32)> =
591 data.iter().enumerate().map(|(i, &v)| (i, v)).collect();
592 indexed_data
593 .sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap_or(std::cmp::Ordering::Equal));
594
595 let num_keep = ((data.len() as f32) * keep_ratio) as usize;
597 let keep_indices: HashSet<usize> =
598 indexed_data.iter().take(num_keep).map(|&(idx, _)| idx).collect();
599
600 if shape.len() == 2 {
602 let cols = shape[1];
603 let mut row_indices = Vec::new();
604 let mut col_indices = Vec::new();
605 let mut values = Vec::new();
606
607 for idx in keep_indices {
608 let row = idx / cols;
609 let col = idx % cols;
610 row_indices.push(row);
611 col_indices.push(col);
612 values.push(data[idx]);
613 }
614
615 SparseTensor::new_coo(shape, row_indices, col_indices, values)
616 } else {
617 Err(TrustformersError::shape_error(
618 "Pruning currently supports only 2D tensors".to_string(),
619 ))
620 }
621 }
622
623 pub fn gradient_based_prune(
625 tensor: &Tensor,
626 gradients: &Tensor,
627 keep_ratio: f32,
628 ) -> Result<SparseTensor> {
629 let weights = tensor.to_vec_f32()?;
630 let grads = gradients.to_vec_f32()?;
631 let shape = tensor.shape().to_vec();
632
633 if weights.len() != grads.len() {
634 return Err(TrustformersError::shape_error(
635 "Weight and gradient shapes must match".to_string(),
636 ));
637 }
638
639 let mut scores: Vec<(usize, f32)> = weights
641 .iter()
642 .zip(grads.iter())
643 .enumerate()
644 .map(|(i, (&w, &g))| (i, (w * g).abs()))
645 .collect();
646
647 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));
648
649 let num_keep = ((weights.len() as f32) * keep_ratio) as usize;
650 let keep_indices: HashSet<usize> =
651 scores.iter().take(num_keep).map(|&(idx, _)| idx).collect();
652
653 if shape.len() == 2 {
655 let cols = shape[1];
656 let mut row_indices = Vec::new();
657 let mut col_indices = Vec::new();
658 let mut values = Vec::new();
659
660 for idx in keep_indices {
661 let row = idx / cols;
662 let col = idx % cols;
663 row_indices.push(row);
664 col_indices.push(col);
665 values.push(weights[idx]);
666 }
667
668 SparseTensor::new_coo(shape, row_indices, col_indices, values)
669 } else {
670 Err(TrustformersError::shape_error(
671 "Pruning currently supports only 2D tensors".to_string(),
672 ))
673 }
674 }
675}
676
#[cfg(test)]
mod tests {
    use super::*;

    // 2:4 sparsity on an 8x8 tensor: ratio is 0.5 and exactly half the
    // entries survive (2 per window of 4).
    #[test]
    fn test_nm_sparsity() -> Result<()> {
        let nm = NMSparsity::new(2, 4);
        assert_eq!(nm.sparsity_ratio(), 0.5);

        let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, &[8, 8])?;

        let sparse = nm.apply(&tensor)?;

        let expected_nnz = 8 * 8 / 2; assert_eq!(sparse.nnz, expected_nnz);

        Ok(())
    }

    // COO x dense matmul produces a result of the expected shape.
    // (Only the shape is checked here, not the numeric values.)
    #[test]
    fn test_sparse_matmul() -> Result<()> {
        let sparse = SparseTensor::new_coo(
            vec![3, 3],
            vec![0, 0, 1, 2],
            vec![0, 1, 1, 2],
            vec![1.0, 2.0, 3.0, 4.0],
        )?;

        let dense_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let dense = Tensor::from_vec(dense_data, &[3, 2])?;

        let result = sparse_matmul(&sparse, &dense)?;

        assert_eq!(result.shape(), &[3, 2]);

        Ok(())
    }

    // Keeping 50% of 2x2 blocks on an 8x8 tensor must drop something but
    // not everything (exact nnz depends on zero entries inside kept blocks).
    #[test]
    fn test_block_sparsity() -> Result<()> {
        let block_sparse = BlockSparsity::new(2, 2, 0.5);

        let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, &[8, 8])?;

        let sparse = block_sparse.apply(&tensor)?;

        assert!(sparse.nnz > 0);
        assert!(sparse.nnz < 64);

        Ok(())
    }

    // A window of 10 over 100 positions yields at most 11 entries per row
    // (half-window of 5 on each side plus the diagonal).
    #[test]
    fn test_sliding_window_mask() -> Result<()> {
        let mask = sparse_attention::sliding_window_mask(100, 10)?;

        assert!(mask.nnz <= 100 * 11);
        assert!(mask.nnz > 0);

        Ok(())
    }

    // keep_ratio 0.25 of 64 elements keeps exactly 16 entries.
    #[test]
    fn test_magnitude_pruning() -> Result<()> {
        let data: Vec<f32> = (0..64).map(|i| (i as f32) - 32.0).collect();
        let tensor = Tensor::from_vec(data, &[8, 8])?;

        let sparse = pruning::magnitude_prune(&tensor, 0.25)?;

        assert_eq!(sparse.nnz, 16);

        Ok(())
    }

    // Round-trip COO -> CSR -> COO preserves format tags and nnz.
    #[test]
    fn test_coo_to_csr_conversion() -> Result<()> {
        let coo = SparseTensor::new_coo(
            vec![3, 3],
            vec![0, 0, 1, 2],
            vec![0, 1, 1, 2],
            vec![1.0, 2.0, 3.0, 4.0],
        )?;

        let csr = conversion::coo_to_csr(&coo)?;

        assert_eq!(csr.format, SparseFormat::CSR);
        assert_eq!(csr.nnz, 4);

        let coo2 = conversion::csr_to_coo(&csr)?;
        assert_eq!(coo2.format, SparseFormat::COO);
        assert_eq!(coo2.nnz, 4);

        Ok(())
    }
}