Skip to main content

tensorlogic_trustformers/
sliding_window.rs

//! # Sliding Window Attention
//!
//! Implementation of Sliding Window Attention for efficient long-sequence processing.
//!
//! Sliding Window Attention restricts each position to attend only to a fixed-size
//! window of nearby positions, reducing complexity from O(n^2) to O(n * w), where w
//! is the window size.
//!
//! ## Used By
//!
//! - Mistral 7B (window size: 4096)
//! - Longformer (combined with global attention)
//! - BigBird (combined with global and random attention)
13
14use crate::error::{Result, TrustformerError};
15use tensorlogic_ir::{EinsumGraph, EinsumNode};
16
17/// Configuration for Sliding Window Attention
18#[derive(Debug, Clone)]
19pub struct SlidingWindowConfig {
20    /// Model dimension
21    pub d_model: usize,
22    /// Number of attention heads
23    pub n_heads: usize,
24    /// Window size (positions attended to)
25    pub window_size: usize,
26    /// Dimension per head
27    pub d_k: usize,
28    /// Whether to use causal masking
29    pub causal: bool,
30    /// Dropout probability
31    pub dropout: f64,
32}
33
34impl SlidingWindowConfig {
35    /// Create a new Sliding Window Attention configuration
36    pub fn new(d_model: usize, n_heads: usize, window_size: usize) -> Result<Self> {
37        if !d_model.is_multiple_of(n_heads) {
38            return Err(TrustformerError::InvalidHeadCount { d_model, n_heads });
39        }
40
41        if window_size == 0 {
42            return Err(TrustformerError::MissingParameter(
43                "window_size must be positive".to_string(),
44            ));
45        }
46
47        let d_k = d_model / n_heads;
48
49        Ok(Self {
50            d_model,
51            n_heads,
52            window_size,
53            d_k,
54            causal: false,
55            dropout: 0.0,
56        })
57    }
58
59    /// Enable causal masking
60    pub fn with_causal(mut self, causal: bool) -> Self {
61        self.causal = causal;
62        self
63    }
64
65    /// Set dropout probability
66    pub fn with_dropout(mut self, dropout: f64) -> Self {
67        self.dropout = dropout;
68        self
69    }
70
71    /// Validate the configuration
72    pub fn validate(&self) -> Result<()> {
73        if self.d_model == 0 {
74            return Err(TrustformerError::MissingParameter(
75                "d_model must be positive".to_string(),
76            ));
77        }
78        if self.n_heads == 0 {
79            return Err(TrustformerError::MissingParameter(
80                "n_heads must be positive".to_string(),
81            ));
82        }
83        if self.dropout < 0.0 || self.dropout > 1.0 {
84            return Err(TrustformerError::CompilationError(
85                "dropout must be between 0 and 1".to_string(),
86            ));
87        }
88        Ok(())
89    }
90
91    /// Calculate complexity reduction compared to full attention
92    pub fn complexity_reduction(&self, seq_len: usize) -> f64 {
93        if seq_len <= self.window_size {
94            1.0
95        } else {
96            self.window_size as f64 / seq_len as f64
97        }
98    }
99
100    /// Calculate memory reduction compared to full attention
101    pub fn memory_reduction(&self, seq_len: usize) -> f64 {
102        if seq_len <= self.window_size {
103            1.0
104        } else {
105            self.window_size as f64 / seq_len as f64
106        }
107    }
108}
109
110/// Sliding Window Attention implementation
111#[derive(Debug, Clone)]
112pub struct SlidingWindowAttention {
113    /// Configuration
114    config: SlidingWindowConfig,
115}
116
117impl SlidingWindowAttention {
118    /// Create a new Sliding Window Attention module
119    pub fn new(config: SlidingWindowConfig) -> Result<Self> {
120        config.validate()?;
121        Ok(Self { config })
122    }
123
124    /// Get the configuration
125    pub fn config(&self) -> &SlidingWindowConfig {
126        &self.config
127    }
128
129    /// Build the sliding window attention graph
130    pub fn build_swa_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
131        let _n_heads = self.config.n_heads;
132        let d_k = self.config.d_k;
133
134        // Step 1: Reshape Q, K, V to multi-head format
135        let q_split = graph.add_tensor("swa_q_split");
136        let k_split = graph.add_tensor("swa_k_split");
137        let v_split = graph.add_tensor("swa_v_split");
138
139        let reshape_spec = format!("bsd->bsh{}", d_k);
140
141        let q_reshape = EinsumNode::new(&reshape_spec, vec![0], vec![q_split]);
142        graph.add_node(q_reshape)?;
143
144        let k_reshape = EinsumNode::new(&reshape_spec, vec![1], vec![k_split]);
145        graph.add_node(k_reshape)?;
146
147        let v_reshape = EinsumNode::new(&reshape_spec, vec![2], vec![v_split]);
148        graph.add_node(v_reshape)?;
149
150        // Step 2: Transpose to [batch, n_heads, seq, d_k]
151        let q_transposed = graph.add_tensor("swa_q_transposed");
152        let k_transposed = graph.add_tensor("swa_k_transposed");
153        let v_transposed = graph.add_tensor("swa_v_transposed");
154
155        let transpose_q = EinsumNode::new("bshd->bhsd", vec![q_split], vec![q_transposed]);
156        graph.add_node(transpose_q)?;
157
158        let transpose_k = EinsumNode::new("bshd->bhsd", vec![k_split], vec![k_transposed]);
159        graph.add_node(transpose_k)?;
160
161        let transpose_v = EinsumNode::new("bshd->bhsd", vec![v_split], vec![v_transposed]);
162        graph.add_node(transpose_v)?;
163
164        // Step 3: Compute attention scores
165        let scores = graph.add_tensor("swa_scores");
166        let scores_node = EinsumNode::new(
167            "bhqd,bhkd->bhqk",
168            vec![q_transposed, k_transposed],
169            vec![scores],
170        );
171        graph.add_node(scores_node)?;
172
173        // Step 4: Scale scores
174        let scale_factor = (d_k as f64).sqrt();
175        let scale_tensor = graph.add_tensor("swa_scale");
176        let scaled_scores = graph.add_tensor("swa_scaled_scores");
177        let scale_node = EinsumNode::elem_binary(
178            format!("div_scalar_{}", scale_factor),
179            scores,
180            scale_tensor,
181            scaled_scores,
182        );
183        graph.add_node(scale_node)?;
184
185        // Step 5: Apply sliding window mask
186        let masked_scores = graph.add_tensor("swa_masked_scores");
187        let mask_node = EinsumNode::elem_unary(
188            format!("sliding_window_mask_{}", self.config.window_size),
189            scaled_scores,
190            masked_scores,
191        );
192        graph.add_node(mask_node)?;
193
194        // Step 6: Softmax
195        let attention_weights = graph.add_tensor("swa_attention_weights");
196        let softmax_node = EinsumNode::elem_unary("softmax_k", masked_scores, attention_weights);
197        graph.add_node(softmax_node)?;
198
199        // Step 7: Apply attention to values
200        let attn_output = graph.add_tensor("swa_attn_output");
201        let attn_node = EinsumNode::new(
202            "bhqk,bhkv->bhqv",
203            vec![attention_weights, v_transposed],
204            vec![attn_output],
205        );
206        graph.add_node(attn_node)?;
207
208        // Step 8: Transpose back
209        let transposed_back = graph.add_tensor("swa_transposed_back");
210        let transpose_back =
211            EinsumNode::new("bhsd->bshd", vec![attn_output], vec![transposed_back]);
212        graph.add_node(transpose_back)?;
213
214        // Step 9: Reshape to [batch, seq, d_model]
215        let output = graph.add_tensor("swa_output");
216        let reshape_back_spec = format!("bsh{}-:bsd", d_k);
217        let reshape_back = EinsumNode::new(&reshape_back_spec, vec![transposed_back], vec![output]);
218        graph.add_node(reshape_back)?;
219
220        Ok(vec![output])
221    }
222}
223
224/// Presets for common Sliding Window Attention configurations
225#[derive(Debug, Clone, Copy, PartialEq, Eq)]
226pub enum SlidingWindowPreset {
227    /// Mistral 7B (window: 4096)
228    Mistral7B,
229    /// Longformer Base (window: 512)
230    LongformerBase,
231    /// BigBird Base (window: 256)
232    BigBirdBase,
233}
234
235impl SlidingWindowPreset {
236    /// Get the configuration for this preset
237    pub fn config(&self) -> Result<SlidingWindowConfig> {
238        match self {
239            SlidingWindowPreset::Mistral7B => {
240                SlidingWindowConfig::new(4096, 32, 4096)?
241                    .with_causal(true)
242                    .validate()?;
243                Ok(SlidingWindowConfig::new(4096, 32, 4096)?.with_causal(true))
244            }
245            SlidingWindowPreset::LongformerBase => SlidingWindowConfig::new(768, 12, 512),
246            SlidingWindowPreset::BigBirdBase => SlidingWindowConfig::new(768, 12, 256),
247        }
248    }
249
250    /// Get the name of this preset
251    pub fn name(&self) -> &'static str {
252        match self {
253            SlidingWindowPreset::Mistral7B => "Mistral 7B",
254            SlidingWindowPreset::LongformerBase => "Longformer Base",
255            SlidingWindowPreset::BigBirdBase => "BigBird Base",
256        }
257    }
258}
259
260/// Statistics for Sliding Window Attention
261#[derive(Debug, Clone)]
262pub struct SlidingWindowStats {
263    /// Configuration
264    pub config: SlidingWindowConfig,
265    /// Complexity reduction for given sequence length
266    pub complexity_reduction: f64,
267    /// Memory reduction for given sequence length
268    pub memory_reduction: f64,
269}
270
271impl SlidingWindowStats {
272    /// Create stats from configuration and sequence length
273    pub fn from_config(config: &SlidingWindowConfig, seq_len: usize) -> Self {
274        Self {
275            config: config.clone(),
276            complexity_reduction: config.complexity_reduction(seq_len),
277            memory_reduction: config.memory_reduction(seq_len),
278        }
279    }
280
281    /// Format as a summary string
282    pub fn summary(&self, seq_len: usize) -> String {
283        format!(
284            "Sliding Window Attention\n  d_model: {}\n  n_heads: {}\n  window_size: {}\n  \
285             causal: {}\n  complexity reduction: {:.1}%\n  memory reduction: {:.1}%\n  \
286             seq_len: {}",
287            self.config.d_model,
288            self.config.n_heads,
289            self.config.window_size,
290            self.config.causal,
291            (1.0 - self.complexity_reduction) * 100.0,
292            (1.0 - self.memory_reduction) * 100.0,
293            seq_len
294        )
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_swa_config_creation() {
        let config = SlidingWindowConfig::new(4096, 32, 4096).unwrap();
        assert_eq!(config.d_model, 4096);
        assert_eq!(config.n_heads, 32);
        assert_eq!(config.window_size, 4096);
        assert_eq!(config.d_k, 128);
        // Defaults set by `new`
        assert!(!config.causal);
        assert_eq!(config.dropout, 0.0);
    }

    #[test]
    fn test_swa_config_builder() {
        let config = SlidingWindowConfig::new(4096, 32, 4096)
            .unwrap()
            .with_causal(true)
            .with_dropout(0.1);

        assert!(config.causal);
        assert!((config.dropout - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_swa_invalid_configs() {
        // d_model (512) not divisible by n_heads (7)
        assert!(SlidingWindowConfig::new(512, 7, 256).is_err());

        // Zero heads
        assert!(SlidingWindowConfig::new(512, 0, 256).is_err());

        // Invalid window size
        assert!(SlidingWindowConfig::new(512, 8, 0).is_err());
    }

    #[test]
    fn test_swa_complexity_reduction() {
        let config = SlidingWindowConfig::new(512, 8, 256).unwrap();

        // Sequences that fit in the window (boundary included): no reduction
        assert_eq!(config.complexity_reduction(128), 1.0);
        assert_eq!(config.complexity_reduction(256), 1.0);

        // Long sequence: 256 / 4096
        let reduction = config.complexity_reduction(4096);
        assert!((reduction - 0.0625).abs() < 0.001);
    }

    #[test]
    fn test_swa_memory_reduction() {
        let config = SlidingWindowConfig::new(512, 8, 256).unwrap();

        assert_eq!(config.memory_reduction(128), 1.0);
        assert!((config.memory_reduction(4096) - 0.0625).abs() < 0.001);
    }

    #[test]
    fn test_swa_graph_building() {
        let config = SlidingWindowConfig::new(512, 8, 256).unwrap();
        let swa = SlidingWindowAttention::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = swa.build_swa_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
    }

    #[test]
    fn test_swa_causal_graph() {
        // Only checks that a causal config builds a graph; it does not inspect the mask op.
        let config = SlidingWindowConfig::new(512, 8, 256)
            .unwrap()
            .with_causal(true);
        let swa = SlidingWindowAttention::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = swa.build_swa_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
    }

    #[test]
    fn test_swa_attention_rejects_invalid_config() {
        let mut config = SlidingWindowConfig::new(512, 8, 256).unwrap();
        config.dropout = 2.0;
        assert!(SlidingWindowAttention::new(config).is_err());
    }

    #[test]
    fn test_swa_presets() {
        // Mistral 7B
        let config = SlidingWindowPreset::Mistral7B.config().unwrap();
        assert_eq!(config.d_model, 4096);
        assert_eq!(config.window_size, 4096);
        assert!(config.causal);

        // Longformer Base
        let config = SlidingWindowPreset::LongformerBase.config().unwrap();
        assert_eq!(config.d_model, 768);
        assert_eq!(config.window_size, 512);
        assert!(!config.causal);

        // BigBird Base
        let config = SlidingWindowPreset::BigBirdBase.config().unwrap();
        assert_eq!(config.d_model, 768);
        assert_eq!(config.window_size, 256);
    }

    #[test]
    fn test_swa_preset_names() {
        assert_eq!(SlidingWindowPreset::Mistral7B.name(), "Mistral 7B");
        assert_eq!(
            SlidingWindowPreset::LongformerBase.name(),
            "Longformer Base"
        );
        assert_eq!(SlidingWindowPreset::BigBirdBase.name(), "BigBird Base");
    }

    #[test]
    fn test_swa_stats() {
        let config = SlidingWindowConfig::new(4096, 32, 4096).unwrap();
        let stats = SlidingWindowStats::from_config(&config, 32768);

        // 4096/32768 = 0.125
        assert!((stats.complexity_reduction - 0.125).abs() < 0.001);
        assert!((stats.memory_reduction - 0.125).abs() < 0.001);

        // Summary reports savings, i.e. (1 - 0.125) * 100
        let summary = stats.summary(32768);
        assert!(summary.contains("complexity reduction: 87.5%"));
        assert!(summary.contains("memory reduction: 87.5%"));
        assert!(summary.contains("seq_len: 32768"));
    }

    #[test]
    fn test_swa_validate() {
        let config = SlidingWindowConfig::new(512, 8, 256).unwrap();
        assert!(config.validate().is_ok());

        // Dropout below range
        let mut bad = config.clone();
        bad.dropout = -0.1;
        assert!(bad.validate().is_err());

        // Dropout above range
        let mut bad = config.clone();
        bad.dropout = 1.5;
        assert!(bad.validate().is_err());
    }
}