ruvector_attention/pde_attention/
diffusion.rs1use super::laplacian::{GraphLaplacian, LaplacianType};
6use crate::error::{AttentionError, AttentionResult};
7use crate::traits::Attention;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct DiffusionConfig {
13 pub dim: usize,
15 pub diffusion_time: f32,
17 pub num_steps: usize,
19 pub sigma: f32,
21 pub knn_k: usize,
23 pub laplacian_type: LaplacianType,
25 pub temperature: f32,
27}
28
29impl Default for DiffusionConfig {
30 fn default() -> Self {
31 Self {
32 dim: 512,
33 diffusion_time: 1.0,
34 num_steps: 5,
35 sigma: 1.0,
36 knn_k: 0, laplacian_type: LaplacianType::RandomWalk,
38 temperature: 1.0,
39 }
40 }
41}
42
43#[derive(Debug, Clone)]
48pub struct DiffusionAttention {
49 config: DiffusionConfig,
50}
51
52impl DiffusionAttention {
53 pub fn new(config: DiffusionConfig) -> Self {
55 Self { config }
56 }
57
58 pub fn with_dim(dim: usize) -> Self {
60 Self::new(DiffusionConfig {
61 dim,
62 ..Default::default()
63 })
64 }
65
66 pub fn compute_diffusion(
68 &self,
69 query: &[f32],
70 keys: &[&[f32]],
71 values: &[&[f32]],
72 ) -> AttentionResult<Vec<f32>> {
73 let n = keys.len();
74 if n == 0 {
75 return Err(AttentionError::InvalidConfig("No keys".into()));
76 }
77
78 let laplacian = if self.config.knn_k > 0 {
80 GraphLaplacian::from_keys_knn(
81 keys,
82 self.config.knn_k,
83 self.config.sigma,
84 self.config.laplacian_type,
85 )
86 } else {
87 GraphLaplacian::from_keys(keys, self.config.sigma, self.config.laplacian_type)
88 };
89
90 let mut x: Vec<f32> = keys
92 .iter()
93 .map(|k| Self::dot_product_simd(query, k))
94 .collect();
95
96 let dt = self.config.diffusion_time / self.config.num_steps.max(1) as f32;
98
99 for _ in 0..self.config.num_steps {
100 let lx = laplacian.apply(&x);
101 for i in 0..n {
102 x[i] -= dt * lx[i];
103 }
104 }
105
106 let temp = self.config.temperature.max(1e-6);
108 for xi in x.iter_mut() {
109 *xi /= temp;
110 }
111
112 let weights = Self::stable_softmax(&x);
114
115 self.weighted_sum(&weights, values)
117 }
118
119 pub fn diffusion_energy(&self, x: &[f32], laplacian: &GraphLaplacian) -> f32 {
122 let lx = laplacian.apply(x);
123 Self::dot_product_simd(x, &lx)
124 }
125
126 pub fn compute_multiscale(
128 &self,
129 query: &[f32],
130 keys: &[&[f32]],
131 num_scales: usize,
132 ) -> Vec<Vec<f32>> {
133 let n = keys.len();
134 if n == 0 {
135 return vec![];
136 }
137
138 let laplacian = if self.config.knn_k > 0 {
139 GraphLaplacian::from_keys_knn(
140 keys,
141 self.config.knn_k,
142 self.config.sigma,
143 self.config.laplacian_type,
144 )
145 } else {
146 GraphLaplacian::from_keys(keys, self.config.sigma, self.config.laplacian_type)
147 };
148
149 let mut x: Vec<f32> = keys
150 .iter()
151 .map(|k| Self::dot_product_simd(query, k))
152 .collect();
153
154 let mut scales = Vec::with_capacity(num_scales);
155 scales.push(Self::stable_softmax(&x)); let total_steps = self.config.num_steps * num_scales;
158 let dt = self.config.diffusion_time / total_steps.max(1) as f32;
159 let steps_per_scale = self.config.num_steps;
160
161 for _ in 1..num_scales {
162 for _ in 0..steps_per_scale {
163 let lx = laplacian.apply(&x);
164 for i in 0..n {
165 x[i] -= dt * lx[i];
166 }
167 }
168 scales.push(Self::stable_softmax(&x));
169 }
170
171 scales
172 }
173
174 #[inline(always)]
176 fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
177 let len = a.len().min(b.len());
178 let chunks = len / 4;
179 let remainder = len % 4;
180
181 let mut sum0 = 0.0f32;
182 let mut sum1 = 0.0f32;
183 let mut sum2 = 0.0f32;
184 let mut sum3 = 0.0f32;
185
186 for i in 0..chunks {
187 let base = i * 4;
188 sum0 += a[base] * b[base];
189 sum1 += a[base + 1] * b[base + 1];
190 sum2 += a[base + 2] * b[base + 2];
191 sum3 += a[base + 3] * b[base + 3];
192 }
193
194 let base = chunks * 4;
195 for i in 0..remainder {
196 sum0 += a[base + i] * b[base + i];
197 }
198
199 sum0 + sum1 + sum2 + sum3
200 }
201
202 fn stable_softmax(logits: &[f32]) -> Vec<f32> {
204 if logits.is_empty() {
205 return vec![];
206 }
207
208 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
209 let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
210 let sum: f32 = exp_logits.iter().sum();
211
212 if sum > 0.0 {
214 exp_logits.iter().map(|&e| e / sum).collect()
215 } else {
216 vec![1.0 / logits.len() as f32; logits.len()]
218 }
219 }
220
221 fn weighted_sum(&self, weights: &[f32], values: &[&[f32]]) -> AttentionResult<Vec<f32>> {
223 if weights.is_empty() || values.is_empty() {
224 return Err(AttentionError::InvalidConfig("Empty inputs".into()));
225 }
226
227 let dim = values[0].len();
228 let mut output = vec![0.0f32; dim];
229
230 for (weight, value) in weights.iter().zip(values.iter()) {
231 for (o, &v) in output.iter_mut().zip(value.iter()) {
232 *o += weight * v;
233 }
234 }
235
236 Ok(output)
237 }
238}
239
240impl Attention for DiffusionAttention {
241 fn compute(
242 &self,
243 query: &[f32],
244 keys: &[&[f32]],
245 values: &[&[f32]],
246 ) -> AttentionResult<Vec<f32>> {
247 self.compute_diffusion(query, keys, values)
248 }
249
250 fn compute_with_mask(
251 &self,
252 query: &[f32],
253 keys: &[&[f32]],
254 values: &[&[f32]],
255 mask: Option<&[bool]>,
256 ) -> AttentionResult<Vec<f32>> {
257 if let Some(m) = mask {
258 let filtered: Vec<(&[f32], &[f32])> = keys
259 .iter()
260 .zip(values.iter())
261 .enumerate()
262 .filter(|(i, _)| m.get(*i).copied().unwrap_or(true))
263 .map(|(_, (k, v))| (*k, *v))
264 .collect();
265
266 let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(k, _)| *k).collect();
267 let filtered_values: Vec<&[f32]> = filtered.iter().map(|(_, v)| *v).collect();
268
269 self.compute(query, &filtered_keys, &filtered_values)
270 } else {
271 self.compute(query, keys, values)
272 }
273 }
274
275 fn dim(&self) -> usize {
276 self.config.dim
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` rows of length `dim`; row `i` is filled with `fill(i)`.
    fn rows(count: usize, dim: usize, fill: impl Fn(usize) -> f32) -> Vec<Vec<f32>> {
        (0..count).map(|i| vec![fill(i); dim]).collect()
    }

    /// Borrows every row as a slice, the shape the attention API expects.
    fn as_refs(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    /// Dense-graph diffusion yields an output with the value dimension.
    #[test]
    fn test_diffusion_attention() {
        let attention = DiffusionAttention::with_dim(16);

        let query = vec![1.0f32; 16];
        let keys = rows(8, 16, |i| i as f32 * 0.1);
        let values = rows(8, 16, |i| i as f32);

        let output = attention
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();
        assert_eq!(output.len(), 16);
    }

    /// Every scale is a probability distribution over the keys.
    #[test]
    fn test_multiscale() {
        let attention = DiffusionAttention::new(DiffusionConfig {
            dim: 8,
            num_steps: 2,
            ..Default::default()
        });

        let query = vec![1.0f32; 8];
        let keys = rows(5, 8, |i| i as f32 * 0.1);

        let scales = attention.compute_multiscale(&query, &as_refs(&keys), 3);

        assert_eq!(scales.len(), 3);
        for scale in scales {
            assert_eq!(scale.len(), 5);
            let sum: f32 = scale.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }
    }

    /// The k-NN graph path yields an output with the value dimension.
    #[test]
    fn test_knn_diffusion() {
        let attention = DiffusionAttention::new(DiffusionConfig {
            dim: 8,
            knn_k: 3,
            ..Default::default()
        });

        let query = vec![1.0f32; 8];
        let keys = rows(10, 8, |i| i as f32 * 0.1);
        let values = rows(10, 8, |i| i as f32);

        let output = attention
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();
        assert_eq!(output.len(), 8);
    }
}