//! Sliding-window (local) attention — tensorlogic_trustformers/sliding_window.rs

use crate::error::{Result, TrustformerError};
use tensorlogic_ir::{EinsumGraph, EinsumNode};
/// Configuration for sliding-window (local) attention.
///
/// Restricts each query to attend only within a fixed-size window,
/// reducing attention cost from O(n^2) to O(n * window_size).
#[derive(Debug, Clone)]
pub struct SlidingWindowConfig {
    /// Model (embedding) dimension.
    pub d_model: usize,
    /// Number of attention heads; must evenly divide `d_model`.
    pub n_heads: usize,
    /// Attention window size in tokens; must be positive.
    pub window_size: usize,
    /// Per-head dimension, derived as `d_model / n_heads`.
    pub d_k: usize,
    /// Whether attention is restricted to earlier positions (causal).
    pub causal: bool,
    /// Dropout probability; expected in [0, 1].
    pub dropout: f64,
}
33
34impl SlidingWindowConfig {
35 pub fn new(d_model: usize, n_heads: usize, window_size: usize) -> Result<Self> {
37 if !d_model.is_multiple_of(n_heads) {
38 return Err(TrustformerError::InvalidHeadCount { d_model, n_heads });
39 }
40
41 if window_size == 0 {
42 return Err(TrustformerError::MissingParameter(
43 "window_size must be positive".to_string(),
44 ));
45 }
46
47 let d_k = d_model / n_heads;
48
49 Ok(Self {
50 d_model,
51 n_heads,
52 window_size,
53 d_k,
54 causal: false,
55 dropout: 0.0,
56 })
57 }
58
59 pub fn with_causal(mut self, causal: bool) -> Self {
61 self.causal = causal;
62 self
63 }
64
65 pub fn with_dropout(mut self, dropout: f64) -> Self {
67 self.dropout = dropout;
68 self
69 }
70
71 pub fn validate(&self) -> Result<()> {
73 if self.d_model == 0 {
74 return Err(TrustformerError::MissingParameter(
75 "d_model must be positive".to_string(),
76 ));
77 }
78 if self.n_heads == 0 {
79 return Err(TrustformerError::MissingParameter(
80 "n_heads must be positive".to_string(),
81 ));
82 }
83 if self.dropout < 0.0 || self.dropout > 1.0 {
84 return Err(TrustformerError::CompilationError(
85 "dropout must be between 0 and 1".to_string(),
86 ));
87 }
88 Ok(())
89 }
90
91 pub fn complexity_reduction(&self, seq_len: usize) -> f64 {
93 if seq_len <= self.window_size {
94 1.0
95 } else {
96 self.window_size as f64 / seq_len as f64
97 }
98 }
99
100 pub fn memory_reduction(&self, seq_len: usize) -> f64 {
102 if seq_len <= self.window_size {
103 1.0
104 } else {
105 self.window_size as f64 / seq_len as f64
106 }
107 }
108}
109
/// Sliding-window attention layer that lowers its computation to an
/// einsum graph via [`SlidingWindowAttention::build_swa_graph`].
#[derive(Debug, Clone)]
pub struct SlidingWindowAttention {
    // Validated configuration; immutable after construction.
    config: SlidingWindowConfig,
}
116
117impl SlidingWindowAttention {
118 pub fn new(config: SlidingWindowConfig) -> Result<Self> {
120 config.validate()?;
121 Ok(Self { config })
122 }
123
124 pub fn config(&self) -> &SlidingWindowConfig {
126 &self.config
127 }
128
129 pub fn build_swa_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
131 let _n_heads = self.config.n_heads;
132 let d_k = self.config.d_k;
133
134 let q_split = graph.add_tensor("swa_q_split");
136 let k_split = graph.add_tensor("swa_k_split");
137 let v_split = graph.add_tensor("swa_v_split");
138
139 let reshape_spec = format!("bsd->bsh{}", d_k);
140
141 let q_reshape = EinsumNode::new(&reshape_spec, vec![0], vec![q_split]);
142 graph.add_node(q_reshape)?;
143
144 let k_reshape = EinsumNode::new(&reshape_spec, vec![1], vec![k_split]);
145 graph.add_node(k_reshape)?;
146
147 let v_reshape = EinsumNode::new(&reshape_spec, vec![2], vec![v_split]);
148 graph.add_node(v_reshape)?;
149
150 let q_transposed = graph.add_tensor("swa_q_transposed");
152 let k_transposed = graph.add_tensor("swa_k_transposed");
153 let v_transposed = graph.add_tensor("swa_v_transposed");
154
155 let transpose_q = EinsumNode::new("bshd->bhsd", vec![q_split], vec![q_transposed]);
156 graph.add_node(transpose_q)?;
157
158 let transpose_k = EinsumNode::new("bshd->bhsd", vec![k_split], vec![k_transposed]);
159 graph.add_node(transpose_k)?;
160
161 let transpose_v = EinsumNode::new("bshd->bhsd", vec![v_split], vec![v_transposed]);
162 graph.add_node(transpose_v)?;
163
164 let scores = graph.add_tensor("swa_scores");
166 let scores_node = EinsumNode::new(
167 "bhqd,bhkd->bhqk",
168 vec![q_transposed, k_transposed],
169 vec![scores],
170 );
171 graph.add_node(scores_node)?;
172
173 let scale_factor = (d_k as f64).sqrt();
175 let scale_tensor = graph.add_tensor("swa_scale");
176 let scaled_scores = graph.add_tensor("swa_scaled_scores");
177 let scale_node = EinsumNode::elem_binary(
178 format!("div_scalar_{}", scale_factor),
179 scores,
180 scale_tensor,
181 scaled_scores,
182 );
183 graph.add_node(scale_node)?;
184
185 let masked_scores = graph.add_tensor("swa_masked_scores");
187 let mask_node = EinsumNode::elem_unary(
188 format!("sliding_window_mask_{}", self.config.window_size),
189 scaled_scores,
190 masked_scores,
191 );
192 graph.add_node(mask_node)?;
193
194 let attention_weights = graph.add_tensor("swa_attention_weights");
196 let softmax_node = EinsumNode::elem_unary("softmax_k", masked_scores, attention_weights);
197 graph.add_node(softmax_node)?;
198
199 let attn_output = graph.add_tensor("swa_attn_output");
201 let attn_node = EinsumNode::new(
202 "bhqk,bhkv->bhqv",
203 vec![attention_weights, v_transposed],
204 vec![attn_output],
205 );
206 graph.add_node(attn_node)?;
207
208 let transposed_back = graph.add_tensor("swa_transposed_back");
210 let transpose_back =
211 EinsumNode::new("bhsd->bshd", vec![attn_output], vec![transposed_back]);
212 graph.add_node(transpose_back)?;
213
214 let output = graph.add_tensor("swa_output");
216 let reshape_back_spec = format!("bsh{}-:bsd", d_k);
217 let reshape_back = EinsumNode::new(&reshape_back_spec, vec![transposed_back], vec![output]);
218 graph.add_node(reshape_back)?;
219
220 Ok(vec![output])
221 }
222}
223
/// Well-known sliding-window attention configurations from published models.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlidingWindowPreset {
    /// Mistral 7B: d_model 4096, 32 heads, window 4096, causal.
    Mistral7B,
    /// Longformer Base: d_model 768, 12 heads, window 512.
    LongformerBase,
    /// BigBird Base: d_model 768, 12 heads, window 256.
    BigBirdBase,
}
234
235impl SlidingWindowPreset {
236 pub fn config(&self) -> Result<SlidingWindowConfig> {
238 match self {
239 SlidingWindowPreset::Mistral7B => {
240 SlidingWindowConfig::new(4096, 32, 4096)?
241 .with_causal(true)
242 .validate()?;
243 Ok(SlidingWindowConfig::new(4096, 32, 4096)?.with_causal(true))
244 }
245 SlidingWindowPreset::LongformerBase => SlidingWindowConfig::new(768, 12, 512),
246 SlidingWindowPreset::BigBirdBase => SlidingWindowConfig::new(768, 12, 256),
247 }
248 }
249
250 pub fn name(&self) -> &'static str {
252 match self {
253 SlidingWindowPreset::Mistral7B => "Mistral 7B",
254 SlidingWindowPreset::LongformerBase => "Longformer Base",
255 SlidingWindowPreset::BigBirdBase => "BigBird Base",
256 }
257 }
258}
259
/// Snapshot of the efficiency characteristics of a sliding-window config
/// at a particular sequence length.
#[derive(Debug, Clone)]
pub struct SlidingWindowStats {
    /// The configuration these statistics were computed from.
    pub config: SlidingWindowConfig,
    /// Compute as a fraction of dense attention (1.0 = no saving).
    pub complexity_reduction: f64,
    /// Memory as a fraction of dense attention (1.0 = no saving).
    pub memory_reduction: f64,
}
270
271impl SlidingWindowStats {
272 pub fn from_config(config: &SlidingWindowConfig, seq_len: usize) -> Self {
274 Self {
275 config: config.clone(),
276 complexity_reduction: config.complexity_reduction(seq_len),
277 memory_reduction: config.memory_reduction(seq_len),
278 }
279 }
280
281 pub fn summary(&self, seq_len: usize) -> String {
283 format!(
284 "Sliding Window Attention\n d_model: {}\n n_heads: {}\n window_size: {}\n \
285 causal: {}\n complexity reduction: {:.1}%\n memory reduction: {:.1}%\n \
286 seq_len: {}",
287 self.config.d_model,
288 self.config.n_heads,
289 self.config.window_size,
290 self.config.causal,
291 (1.0 - self.complexity_reduction) * 100.0,
292 (1.0 - self.memory_reduction) * 100.0,
293 seq_len
294 )
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid construction derives d_k correctly.
    #[test]
    fn test_swa_config_creation() {
        let cfg = SlidingWindowConfig::new(4096, 32, 4096).expect("config should be valid");
        assert_eq!(cfg.d_model, 4096);
        assert_eq!(cfg.n_heads, 32);
        assert_eq!(cfg.window_size, 4096);
        assert_eq!(cfg.d_k, 128);
    }

    /// Builder methods set the optional fields.
    #[test]
    fn test_swa_config_builder() {
        let cfg = SlidingWindowConfig::new(4096, 32, 4096)
            .expect("config should be valid")
            .with_causal(true)
            .with_dropout(0.1);

        assert!(cfg.causal);
        assert!((cfg.dropout - 0.1).abs() < 1e-10);
    }

    /// Non-divisible head count and zero window are rejected.
    #[test]
    fn test_swa_invalid_configs() {
        // 512 is not divisible by 7 heads.
        assert!(SlidingWindowConfig::new(512, 7, 256).is_err());
        // Zero window size is not allowed.
        assert!(SlidingWindowConfig::new(512, 8, 0).is_err());
    }

    /// Reduction is 1.0 inside the window and window/seq_len beyond it.
    #[test]
    fn test_swa_complexity_reduction() {
        let cfg = SlidingWindowConfig::new(512, 8, 256).expect("config should be valid");

        // seq_len within the window: no reduction.
        assert_eq!(cfg.complexity_reduction(128), 1.0);

        // 256 / 4096 = 0.0625.
        let ratio = cfg.complexity_reduction(4096);
        assert!((ratio - 0.0625).abs() < 0.001);
    }

    /// Graph lowering yields exactly one output tensor.
    #[test]
    fn test_swa_graph_building() {
        let cfg = SlidingWindowConfig::new(512, 8, 256).expect("config should be valid");
        let attention = SlidingWindowAttention::new(cfg).expect("layer should build");

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = attention
            .build_swa_graph(&mut graph)
            .expect("graph build should succeed");
        assert_eq!(outputs.len(), 1);
    }

    /// Causal configs lower to a graph the same way.
    #[test]
    fn test_swa_causal_graph() {
        let cfg = SlidingWindowConfig::new(512, 8, 256)
            .expect("config should be valid")
            .with_causal(true);
        let attention = SlidingWindowAttention::new(cfg).expect("layer should build");

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = attention
            .build_swa_graph(&mut graph)
            .expect("graph build should succeed");
        assert_eq!(outputs.len(), 1);
    }

    /// Presets expand to their published hyperparameters.
    #[test]
    fn test_swa_presets() {
        let mistral = SlidingWindowPreset::Mistral7B.config().expect("preset valid");
        assert_eq!(mistral.d_model, 4096);
        assert_eq!(mistral.window_size, 4096);
        assert!(mistral.causal);

        let longformer = SlidingWindowPreset::LongformerBase
            .config()
            .expect("preset valid");
        assert_eq!(longformer.d_model, 768);
        assert_eq!(longformer.window_size, 512);

        let bigbird = SlidingWindowPreset::BigBirdBase.config().expect("preset valid");
        assert_eq!(bigbird.window_size, 256);
    }

    /// Preset display names.
    #[test]
    fn test_swa_preset_names() {
        assert_eq!(SlidingWindowPreset::Mistral7B.name(), "Mistral 7B");
        assert_eq!(
            SlidingWindowPreset::LongformerBase.name(),
            "Longformer Base"
        );
    }

    /// Stats capture both reductions: 4096 / 32768 = 0.125.
    #[test]
    fn test_swa_stats() {
        let cfg = SlidingWindowConfig::new(4096, 32, 4096).expect("config should be valid");
        let stats = SlidingWindowStats::from_config(&cfg, 32768);

        assert!((stats.complexity_reduction - 0.125).abs() < 0.001);
        assert!((stats.memory_reduction - 0.125).abs() < 0.001);
    }

    /// validate() accepts a fresh config and rejects a mutated-bad one.
    #[test]
    fn test_swa_validate() {
        let cfg = SlidingWindowConfig::new(512, 8, 256).expect("config should be valid");
        assert!(cfg.validate().is_ok());

        let mut broken = cfg.clone();
        broken.dropout = -0.1;
        assert!(broken.validate().is_err());
    }
}