// ruvector_attention/pde_attention/diffusion.rs

1//! Diffusion Attention
2//!
3//! Attention as heat diffusion on a key similarity graph.
4
5use crate::error::{AttentionError, AttentionResult};
6use crate::traits::Attention;
7use super::laplacian::{GraphLaplacian, LaplacianType};
8use serde::{Deserialize, Serialize};
9
/// Diffusion attention configuration.
///
/// Groups the graph-construction knobs (`sigma`, `knn_k`, `laplacian_type`)
/// with the heat-equation integration knobs (`diffusion_time`, `num_steps`)
/// and the final softmax `temperature`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiffusionConfig {
    /// Model (embedding) dimension of queries, keys, and values.
    pub dim: usize,
    /// Total diffusion time; larger values smooth the attention logits more.
    pub diffusion_time: f32,
    /// Number of explicit integration steps used to cover `diffusion_time`.
    pub num_steps: usize,
    /// Bandwidth of the Gaussian kernel used to build the key similarity graph.
    pub sigma: f32,
    /// Neighbour count for the sparse k-NN Laplacian (0 = dense Laplacian).
    pub knn_k: usize,
    /// Normalization variant of the graph Laplacian (see `LaplacianType`).
    pub laplacian_type: LaplacianType,
    /// Temperature dividing the diffused logits before the final softmax.
    pub temperature: f32,
}
28
29impl Default for DiffusionConfig {
30    fn default() -> Self {
31        Self {
32            dim: 512,
33            diffusion_time: 1.0,
34            num_steps: 5,
35            sigma: 1.0,
36            knn_k: 0, // Dense
37            laplacian_type: LaplacianType::RandomWalk,
38            temperature: 1.0,
39        }
40    }
41}
42
/// Diffusion-based Attention
///
/// Computes attention by diffusing initial logits on a key similarity graph.
/// This provides multi-scale smoothing and noise resistance: after diffusion,
/// each key's logit has mixed in the logits of its graph neighbours.
#[derive(Debug, Clone)]
pub struct DiffusionAttention {
    // Immutable configuration fixed at construction time; drives graph
    // building, integration, and the output softmax.
    config: DiffusionConfig,
}
51
52impl DiffusionAttention {
53    /// Create new diffusion attention
54    pub fn new(config: DiffusionConfig) -> Self {
55        Self { config }
56    }
57
58    /// Create with dimension only
59    pub fn with_dim(dim: usize) -> Self {
60        Self::new(DiffusionConfig {
61            dim,
62            ..Default::default()
63        })
64    }
65
66    /// Compute diffusion attention
67    pub fn compute_diffusion(
68        &self,
69        query: &[f32],
70        keys: &[&[f32]],
71        values: &[&[f32]],
72    ) -> AttentionResult<Vec<f32>> {
73        let n = keys.len();
74        if n == 0 {
75            return Err(AttentionError::InvalidConfig("No keys".into()));
76        }
77
78        // Build Laplacian
79        let laplacian = if self.config.knn_k > 0 {
80            GraphLaplacian::from_keys_knn(
81                keys,
82                self.config.knn_k,
83                self.config.sigma,
84                self.config.laplacian_type,
85            )
86        } else {
87            GraphLaplacian::from_keys(
88                keys,
89                self.config.sigma,
90                self.config.laplacian_type,
91            )
92        };
93
94        // Initial logits from dot product
95        let mut x: Vec<f32> = keys
96            .iter()
97            .map(|k| Self::dot_product_simd(query, k))
98            .collect();
99
100        // Diffusion: x_{t+dt} = x_t - dt * L * x_t
101        let dt = self.config.diffusion_time / self.config.num_steps.max(1) as f32;
102
103        for _ in 0..self.config.num_steps {
104            let lx = laplacian.apply(&x);
105            for i in 0..n {
106                x[i] -= dt * lx[i];
107            }
108        }
109
110        // Apply temperature (Security: prevent division by zero)
111        let temp = self.config.temperature.max(1e-6);
112        for xi in x.iter_mut() {
113            *xi /= temp;
114        }
115
116        // Softmax
117        let weights = Self::stable_softmax(&x);
118
119        // Weighted sum of values
120        self.weighted_sum(&weights, values)
121    }
122
123    /// Compute diffusion energy (for monitoring)
124    /// E = x^T L x (smoothness measure)
125    pub fn diffusion_energy(&self, x: &[f32], laplacian: &GraphLaplacian) -> f32 {
126        let lx = laplacian.apply(x);
127        Self::dot_product_simd(x, &lx)
128    }
129
130    /// Compute multi-scale attention (return attention at different times)
131    pub fn compute_multiscale(
132        &self,
133        query: &[f32],
134        keys: &[&[f32]],
135        num_scales: usize,
136    ) -> Vec<Vec<f32>> {
137        let n = keys.len();
138        if n == 0 {
139            return vec![];
140        }
141
142        let laplacian = if self.config.knn_k > 0 {
143            GraphLaplacian::from_keys_knn(
144                keys,
145                self.config.knn_k,
146                self.config.sigma,
147                self.config.laplacian_type,
148            )
149        } else {
150            GraphLaplacian::from_keys(
151                keys,
152                self.config.sigma,
153                self.config.laplacian_type,
154            )
155        };
156
157        let mut x: Vec<f32> = keys
158            .iter()
159            .map(|k| Self::dot_product_simd(query, k))
160            .collect();
161
162        let mut scales = Vec::with_capacity(num_scales);
163        scales.push(Self::stable_softmax(&x)); // t=0
164
165        let total_steps = self.config.num_steps * num_scales;
166        let dt = self.config.diffusion_time / total_steps.max(1) as f32;
167        let steps_per_scale = self.config.num_steps;
168
169        for _ in 1..num_scales {
170            for _ in 0..steps_per_scale {
171                let lx = laplacian.apply(&x);
172                for i in 0..n {
173                    x[i] -= dt * lx[i];
174                }
175            }
176            scales.push(Self::stable_softmax(&x));
177        }
178
179        scales
180    }
181
182    /// SIMD-friendly dot product
183    #[inline(always)]
184    fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
185        let len = a.len().min(b.len());
186        let chunks = len / 4;
187        let remainder = len % 4;
188
189        let mut sum0 = 0.0f32;
190        let mut sum1 = 0.0f32;
191        let mut sum2 = 0.0f32;
192        let mut sum3 = 0.0f32;
193
194        for i in 0..chunks {
195            let base = i * 4;
196            sum0 += a[base] * b[base];
197            sum1 += a[base + 1] * b[base + 1];
198            sum2 += a[base + 2] * b[base + 2];
199            sum3 += a[base + 3] * b[base + 3];
200        }
201
202        let base = chunks * 4;
203        for i in 0..remainder {
204            sum0 += a[base + i] * b[base + i];
205        }
206
207        sum0 + sum1 + sum2 + sum3
208    }
209
210    /// Stable softmax
211    fn stable_softmax(logits: &[f32]) -> Vec<f32> {
212        if logits.is_empty() {
213            return vec![];
214        }
215
216        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
217        let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
218        let sum: f32 = exp_logits.iter().sum();
219
220        // Security: prevent division by zero if all exp values underflow
221        if sum > 0.0 {
222            exp_logits.iter().map(|&e| e / sum).collect()
223        } else {
224            // Fallback to uniform distribution
225            vec![1.0 / logits.len() as f32; logits.len()]
226        }
227    }
228
229    /// Weighted sum
230    fn weighted_sum(&self, weights: &[f32], values: &[&[f32]]) -> AttentionResult<Vec<f32>> {
231        if weights.is_empty() || values.is_empty() {
232            return Err(AttentionError::InvalidConfig("Empty inputs".into()));
233        }
234
235        let dim = values[0].len();
236        let mut output = vec![0.0f32; dim];
237
238        for (weight, value) in weights.iter().zip(values.iter()) {
239            for (o, &v) in output.iter_mut().zip(value.iter()) {
240                *o += weight * v;
241            }
242        }
243
244        Ok(output)
245    }
246}
247
248impl Attention for DiffusionAttention {
249    fn compute(
250        &self,
251        query: &[f32],
252        keys: &[&[f32]],
253        values: &[&[f32]],
254    ) -> AttentionResult<Vec<f32>> {
255        self.compute_diffusion(query, keys, values)
256    }
257
258    fn compute_with_mask(
259        &self,
260        query: &[f32],
261        keys: &[&[f32]],
262        values: &[&[f32]],
263        mask: Option<&[bool]>,
264    ) -> AttentionResult<Vec<f32>> {
265        if let Some(m) = mask {
266            let filtered: Vec<(&[f32], &[f32])> = keys
267                .iter()
268                .zip(values.iter())
269                .enumerate()
270                .filter(|(i, _)| m.get(*i).copied().unwrap_or(true))
271                .map(|(_, (k, v))| (*k, *v))
272                .collect();
273
274            let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(k, _)| *k).collect();
275            let filtered_values: Vec<&[f32]> = filtered.iter().map(|(_, v)| *v).collect();
276
277            self.compute(query, &filtered_keys, &filtered_values)
278        } else {
279            self.compute(query, keys, values)
280        }
281    }
282
283    fn dim(&self) -> usize {
284        self.config.dim
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` keys of dimension `dim` with constant entries `i * 0.1`.
    fn ramp_keys(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n).map(|i| vec![i as f32 * 0.1; dim]).collect()
    }

    /// `n` values of dimension `dim` with constant entries `i`.
    fn ramp_values(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n).map(|i| vec![i as f32; dim]).collect()
    }

    /// Borrow a list of owned vectors as a list of slices.
    fn as_slices(v: &[Vec<f32>]) -> Vec<&[f32]> {
        v.iter().map(|x| x.as_slice()).collect()
    }

    #[test]
    fn test_diffusion_attention() {
        let attention = DiffusionAttention::with_dim(16);
        let query = vec![1.0f32; 16];
        let keys = ramp_keys(8, 16);
        let values = ramp_values(8, 16);

        let output = attention
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(output.len(), 16);
    }

    #[test]
    fn test_multiscale() {
        let attention = DiffusionAttention::new(DiffusionConfig {
            dim: 8,
            num_steps: 2,
            ..Default::default()
        });
        let query = vec![1.0f32; 8];
        let keys = ramp_keys(5, 8);

        let scales = attention.compute_multiscale(&query, &as_slices(&keys), 3);

        assert_eq!(scales.len(), 3);
        for scale in scales {
            assert_eq!(scale.len(), 5);
            // Each scale is a softmax output and must sum to 1.
            let sum: f32 = scale.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_knn_diffusion() {
        let attention = DiffusionAttention::new(DiffusionConfig {
            dim: 8,
            knn_k: 3,
            ..Default::default()
        });
        let query = vec![1.0f32; 8];
        let keys = ramp_keys(10, 8);
        let values = ramp_values(10, 8);

        let output = attention
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(output.len(), 8);
    }
}