ruvector_attention/pde_attention/
diffusion.rs1use crate::error::{AttentionError, AttentionResult};
6use crate::traits::Attention;
7use super::laplacian::{GraphLaplacian, LaplacianType};
8use serde::{Deserialize, Serialize};
9
/// Configuration for [`DiffusionAttention`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiffusionConfig {
    /// Embedding dimensionality reported by `Attention::dim`.
    pub dim: usize,
    /// Total diffusion time T integrated by the explicit Euler solver.
    pub diffusion_time: f32,
    /// Number of Euler steps used to integrate up to `diffusion_time`.
    pub num_steps: usize,
    /// Bandwidth parameter forwarded to the graph-Laplacian construction
    /// (presumably a Gaussian kernel width — confirm in `laplacian.rs`).
    pub sigma: f32,
    /// Neighbors per node for the kNN affinity graph; 0 selects the dense
    /// `GraphLaplacian::from_keys` path instead.
    pub knn_k: usize,
    /// Which Laplacian normalization to build (see [`LaplacianType`]).
    pub laplacian_type: LaplacianType,
    /// Softmax temperature applied to the diffused scores (clamped to
    /// >= 1e-6 at use site).
    pub temperature: f32,
}
28
29impl Default for DiffusionConfig {
30 fn default() -> Self {
31 Self {
32 dim: 512,
33 diffusion_time: 1.0,
34 num_steps: 5,
35 sigma: 1.0,
36 knn_k: 0, laplacian_type: LaplacianType::RandomWalk,
38 temperature: 1.0,
39 }
40 }
41}
42
/// Attention mechanism that smooths query-key similarity scores by running
/// heat diffusion over a graph built from the keys, then softmax-normalizes
/// the smoothed scores to form attention weights.
#[derive(Debug, Clone)]
pub struct DiffusionAttention {
    // Solver and graph-construction parameters; immutable after creation.
    config: DiffusionConfig,
}
51
52impl DiffusionAttention {
53 pub fn new(config: DiffusionConfig) -> Self {
55 Self { config }
56 }
57
58 pub fn with_dim(dim: usize) -> Self {
60 Self::new(DiffusionConfig {
61 dim,
62 ..Default::default()
63 })
64 }
65
66 pub fn compute_diffusion(
68 &self,
69 query: &[f32],
70 keys: &[&[f32]],
71 values: &[&[f32]],
72 ) -> AttentionResult<Vec<f32>> {
73 let n = keys.len();
74 if n == 0 {
75 return Err(AttentionError::InvalidConfig("No keys".into()));
76 }
77
78 let laplacian = if self.config.knn_k > 0 {
80 GraphLaplacian::from_keys_knn(
81 keys,
82 self.config.knn_k,
83 self.config.sigma,
84 self.config.laplacian_type,
85 )
86 } else {
87 GraphLaplacian::from_keys(
88 keys,
89 self.config.sigma,
90 self.config.laplacian_type,
91 )
92 };
93
94 let mut x: Vec<f32> = keys
96 .iter()
97 .map(|k| Self::dot_product_simd(query, k))
98 .collect();
99
100 let dt = self.config.diffusion_time / self.config.num_steps.max(1) as f32;
102
103 for _ in 0..self.config.num_steps {
104 let lx = laplacian.apply(&x);
105 for i in 0..n {
106 x[i] -= dt * lx[i];
107 }
108 }
109
110 let temp = self.config.temperature.max(1e-6);
112 for xi in x.iter_mut() {
113 *xi /= temp;
114 }
115
116 let weights = Self::stable_softmax(&x);
118
119 self.weighted_sum(&weights, values)
121 }
122
123 pub fn diffusion_energy(&self, x: &[f32], laplacian: &GraphLaplacian) -> f32 {
126 let lx = laplacian.apply(x);
127 Self::dot_product_simd(x, &lx)
128 }
129
130 pub fn compute_multiscale(
132 &self,
133 query: &[f32],
134 keys: &[&[f32]],
135 num_scales: usize,
136 ) -> Vec<Vec<f32>> {
137 let n = keys.len();
138 if n == 0 {
139 return vec![];
140 }
141
142 let laplacian = if self.config.knn_k > 0 {
143 GraphLaplacian::from_keys_knn(
144 keys,
145 self.config.knn_k,
146 self.config.sigma,
147 self.config.laplacian_type,
148 )
149 } else {
150 GraphLaplacian::from_keys(
151 keys,
152 self.config.sigma,
153 self.config.laplacian_type,
154 )
155 };
156
157 let mut x: Vec<f32> = keys
158 .iter()
159 .map(|k| Self::dot_product_simd(query, k))
160 .collect();
161
162 let mut scales = Vec::with_capacity(num_scales);
163 scales.push(Self::stable_softmax(&x)); let total_steps = self.config.num_steps * num_scales;
166 let dt = self.config.diffusion_time / total_steps.max(1) as f32;
167 let steps_per_scale = self.config.num_steps;
168
169 for _ in 1..num_scales {
170 for _ in 0..steps_per_scale {
171 let lx = laplacian.apply(&x);
172 for i in 0..n {
173 x[i] -= dt * lx[i];
174 }
175 }
176 scales.push(Self::stable_softmax(&x));
177 }
178
179 scales
180 }
181
182 #[inline(always)]
184 fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
185 let len = a.len().min(b.len());
186 let chunks = len / 4;
187 let remainder = len % 4;
188
189 let mut sum0 = 0.0f32;
190 let mut sum1 = 0.0f32;
191 let mut sum2 = 0.0f32;
192 let mut sum3 = 0.0f32;
193
194 for i in 0..chunks {
195 let base = i * 4;
196 sum0 += a[base] * b[base];
197 sum1 += a[base + 1] * b[base + 1];
198 sum2 += a[base + 2] * b[base + 2];
199 sum3 += a[base + 3] * b[base + 3];
200 }
201
202 let base = chunks * 4;
203 for i in 0..remainder {
204 sum0 += a[base + i] * b[base + i];
205 }
206
207 sum0 + sum1 + sum2 + sum3
208 }
209
210 fn stable_softmax(logits: &[f32]) -> Vec<f32> {
212 if logits.is_empty() {
213 return vec![];
214 }
215
216 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
217 let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
218 let sum: f32 = exp_logits.iter().sum();
219
220 if sum > 0.0 {
222 exp_logits.iter().map(|&e| e / sum).collect()
223 } else {
224 vec![1.0 / logits.len() as f32; logits.len()]
226 }
227 }
228
229 fn weighted_sum(&self, weights: &[f32], values: &[&[f32]]) -> AttentionResult<Vec<f32>> {
231 if weights.is_empty() || values.is_empty() {
232 return Err(AttentionError::InvalidConfig("Empty inputs".into()));
233 }
234
235 let dim = values[0].len();
236 let mut output = vec![0.0f32; dim];
237
238 for (weight, value) in weights.iter().zip(values.iter()) {
239 for (o, &v) in output.iter_mut().zip(value.iter()) {
240 *o += weight * v;
241 }
242 }
243
244 Ok(output)
245 }
246}
247
248impl Attention for DiffusionAttention {
249 fn compute(
250 &self,
251 query: &[f32],
252 keys: &[&[f32]],
253 values: &[&[f32]],
254 ) -> AttentionResult<Vec<f32>> {
255 self.compute_diffusion(query, keys, values)
256 }
257
258 fn compute_with_mask(
259 &self,
260 query: &[f32],
261 keys: &[&[f32]],
262 values: &[&[f32]],
263 mask: Option<&[bool]>,
264 ) -> AttentionResult<Vec<f32>> {
265 if let Some(m) = mask {
266 let filtered: Vec<(&[f32], &[f32])> = keys
267 .iter()
268 .zip(values.iter())
269 .enumerate()
270 .filter(|(i, _)| m.get(*i).copied().unwrap_or(true))
271 .map(|(_, (k, v))| (*k, *v))
272 .collect();
273
274 let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(k, _)| *k).collect();
275 let filtered_values: Vec<&[f32]> = filtered.iter().map(|(_, v)| *v).collect();
276
277 self.compute(query, &filtered_keys, &filtered_values)
278 } else {
279 self.compute(query, keys, values)
280 }
281 }
282
283 fn dim(&self) -> usize {
284 self.config.dim
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each owned row as the `&[f32]` slice the API expects.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_diffusion_attention() {
        let attention = DiffusionAttention::with_dim(16);

        let query = vec![1.0f32; 16];
        let keys: Vec<Vec<f32>> = (0..8).map(|i| vec![i as f32 * 0.1; 16]).collect();
        let values: Vec<Vec<f32>> = (0..8).map(|i| vec![i as f32; 16]).collect();

        let output = attention
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(output.len(), 16);
    }

    #[test]
    fn test_multiscale() {
        let attention = DiffusionAttention::new(DiffusionConfig {
            dim: 8,
            num_steps: 2,
            ..Default::default()
        });

        let query = vec![1.0f32; 8];
        let keys: Vec<Vec<f32>> = (0..5).map(|i| vec![i as f32 * 0.1; 8]).collect();

        let scales = attention.compute_multiscale(&query, &as_slices(&keys), 3);

        assert_eq!(scales.len(), 3);
        for weights in scales {
            assert_eq!(weights.len(), 5);
            // Every scale is a softmax output, so it must sum to one.
            let total: f32 = weights.iter().sum();
            assert!((total - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_knn_diffusion() {
        let attention = DiffusionAttention::new(DiffusionConfig {
            dim: 8,
            knn_k: 3,
            ..Default::default()
        });

        let query = vec![1.0f32; 8];
        let keys: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.1; 8]).collect();
        let values: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32; 8]).collect();

        let output = attention
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(output.len(), 8);
    }
}