ruvector_attention/sparse/linear.rs

//! Linear attention via kernel feature maps.
//!
//! Instead of materialising the full attention matrix, the output is computed as
//! `phi(q)^T (sum_i phi(k_i) v_i^T) / (phi(q)^T sum_i phi(k_i))`, which costs
//! O(N) in the number of key/value pairs.

use crate::error::{AttentionError, AttentionResult};
use crate::traits::Attention;

/// Kernel used for the feature map that linearises the attention weights.
#[derive(Clone, Debug)]
pub enum KernelType {
    /// Positive random features approximating the softmax kernel.
    Softmax,
    /// ReLU applied to the random projection.
    ReLU,
    /// ELU applied to the random projection.
    ELU,
}

/// Attention with O(N) cost via a random-feature kernel approximation.
pub struct LinearAttention {
    dim: usize,
    num_features: usize,
    kernel: KernelType,
    /// Row-major `num_features x dim` random projection matrix.
    random_features: Vec<f32>,
}

impl LinearAttention {
    /// Creates a linear attention module using the softmax-approximating kernel.
    pub fn new(dim: usize, num_features: usize) -> Self {
        Self::with_kernel(dim, num_features, KernelType::Softmax)
    }

    /// Creates a linear attention module with an explicit kernel choice.
    pub fn with_kernel(dim: usize, num_features: usize, kernel: KernelType) -> Self {
        let random_features = Self::generate_random_features(dim, num_features);

        Self {
            dim,
            num_features,
            kernel,
            random_features,
        }
    }

    /// Generates a deterministic Gaussian projection matrix: a fixed-seed LCG
    /// feeds the Box-Muller transform, so instances with the same `(dim,
    /// num_features)` always get identical features.
    fn generate_random_features(dim: usize, num_features: usize) -> Vec<f32> {
        use std::f32::consts::PI;

        let mut features = Vec::with_capacity(num_features * dim);
        let mut seed = 42u64;

        // Each iteration turns two uniform samples into two Gaussian samples.
        for _ in 0..((num_features * dim + 1) / 2) {
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            let u1 = (seed as f32) / (u64::MAX as f32);
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            let u2 = (seed as f32) / (u64::MAX as f32);

            // Box-Muller transform; clamp u1 away from zero before taking ln.
            let r = (-2.0 * u1.max(1e-10).ln()).sqrt();
            let theta = 2.0 * PI * u2;

            features.push(r * theta.cos());
            if features.len() < num_features * dim {
                features.push(r * theta.sin());
            }
        }

        features.truncate(num_features * dim);

        // Scale by 1/sqrt(dim) so projections stay roughly unit-variance.
        let scale = 1.0 / (dim as f32).sqrt();
        features.iter_mut().for_each(|x| *x *= scale);

        features
    }

    /// Applies the kernel feature map `phi` to a single `dim`-length vector.
    fn feature_map(&self, x: &[f32]) -> Vec<f32> {
        let mut phi = vec![0.0f32; self.num_features];

        for (i, phi_i) in phi.iter_mut().enumerate() {
            // Dot product of x with the i-th row of the projection matrix.
            let projection: f32 = x
                .iter()
                .enumerate()
                .map(|(j, &xj)| xj * self.random_features[i * self.dim + j])
                .sum();

            *phi_i = match self.kernel {
                KernelType::Softmax => {
                    // Positive random features: exp(w^T x - ||x||^2 / 2) / sqrt(m).
                    let norm_sq: f32 = x.iter().map(|xi| xi * xi).sum();
                    (projection - norm_sq / 2.0).exp() / (self.num_features as f32).sqrt()
                }
                KernelType::ReLU => projection.max(0.0),
                KernelType::ELU => {
                    if projection >= 0.0 {
                        projection
                    } else {
                        projection.exp() - 1.0
                    }
                }
            };
        }

        phi
    }
}

impl Attention for LinearAttention {
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        if keys.is_empty() {
            return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
        }
        if keys.len() != values.len() {
            return Err(AttentionError::DimensionMismatch {
                expected: keys.len(),
                actual: values.len(),
            });
        }
        if query.len() != self.dim {
            return Err(AttentionError::DimensionMismatch {
                expected: self.dim,
                actual: query.len(),
            });
        }

        let phi_q = self.feature_map(query);

        let value_dim = values[0].len();
        // Running sums: kv_sum accumulates phi(k) v^T, k_sum accumulates phi(k).
        let mut kv_sum = vec![0.0f32; self.num_features * value_dim];
        let mut k_sum = vec![0.0f32; self.num_features];

        for (key, value) in keys.iter().zip(values.iter()) {
            let phi_k = self.feature_map(key);

            for (i, &phi_ki) in phi_k.iter().enumerate() {
                for (j, &vj) in value.iter().enumerate() {
                    kv_sum[i * value_dim + j] += phi_ki * vj;
                }
                k_sum[i] += phi_ki;
            }
        }

        let mut output = vec![0.0f32; value_dim];
        let mut normalizer = 0.0f32;

        // output = phi(q)^T kv_sum, normalizer = phi(q)^T k_sum.
        for (i, &phi_qi) in phi_q.iter().enumerate() {
            for (j, out_j) in output.iter_mut().enumerate() {
                *out_j += phi_qi * kv_sum[i * value_dim + j];
            }
            normalizer += phi_qi * k_sum[i];
        }

        // Skip normalization when the denominator is numerically zero.
        if normalizer.abs() > 1e-8 {
            output.iter_mut().for_each(|x| *x /= normalizer);
        }

        Ok(output)
    }

    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        if let Some(m) = mask {
            // Keep only the positions whose mask entry is `true`.
            let filtered: Vec<(usize, bool)> = m
                .iter()
                .copied()
                .enumerate()
                .filter(|(_, keep)| *keep)
                .collect();
            let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
            let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
            self.compute(query, &filtered_keys, &filtered_values)
        } else {
            self.compute(query, keys, values)
        }
    }

    fn dim(&self) -> usize {
        self.dim
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_attention() {
        let attention = LinearAttention::new(64, 32);

        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = (0..100).map(|_| vec![0.3; 64]).collect();
        let values: Vec<Vec<f32>> = (0..100).map(|_| vec![1.0; 64]).collect();

        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        let result = attention.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_eq!(result.len(), 64);
    }
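
    // Added sketch (not in the original file): checks that the fixed-seed
    // generator yields a deterministic, correctly sized projection matrix,
    // which the constructors rely on but the existing tests do not cover.
    #[test]
    fn test_random_features_deterministic() {
        let a = LinearAttention::new(16, 8);
        let b = LinearAttention::new(16, 8);

        // `num_features x dim` entries, identical across instances.
        assert_eq!(a.random_features.len(), 8 * 16);
        assert_eq!(a.random_features, b.random_features);
    }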

    #[test]
    fn test_kernel_types() {
        for kernel in [KernelType::Softmax, KernelType::ReLU, KernelType::ELU] {
            let attention = LinearAttention::with_kernel(32, 16, kernel);

            let query = vec![1.0; 32];
            let keys = vec![vec![0.5; 32]; 10];
            let values = vec![vec![1.0; 32]; 10];

            let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
            let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

            let result = attention.compute(&query, &keys_refs, &values_refs).unwrap();
            assert_eq!(result.len(), 32);
        }
    }
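
    // Added sketch (not in the original file): exercises `compute_with_mask`,
    // which the existing tests leave untested. Masking filters key/value pairs,
    // so the masked result should equal `compute` run on the kept subset.
    #[test]
    fn test_compute_with_mask_matches_filtered_compute() {
        let attention = LinearAttention::new(32, 16);

        let query = vec![0.5; 32];
        let keys: Vec<Vec<f32>> = (0..10).map(|i| vec![0.1 * i as f32; 32]).collect();
        let values: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32; 32]).collect();
        let keys_refs: Vec<&[f32]> = keys.iter().map(|k| k.as_slice()).collect();
        let values_refs: Vec<&[f32]> = values.iter().map(|v| v.as_slice()).collect();

        // Keep even positions only.
        let mask: Vec<bool> = (0..10).map(|i| i % 2 == 0).collect();
        let masked = attention
            .compute_with_mask(&query, &keys_refs, &values_refs, Some(mask.as_slice()))
            .unwrap();

        // Manually filter the same positions and run the unmasked path.
        let kept: Vec<usize> = (0..mask.len()).filter(|&i| mask[i]).collect();
        let kept_keys: Vec<&[f32]> = kept.iter().map(|&i| keys_refs[i]).collect();
        let kept_values: Vec<&[f32]> = kept.iter().map(|&i| values_refs[i]).collect();
        let expected = attention.compute(&query, &kept_keys, &kept_values).unwrap();

        assert_eq!(masked.len(), 32);
        assert_eq!(masked, expected);
    }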
}