ruvector_attention/sparse/linear.rs

//! Linear attention using random feature approximation (Performer-style).
//!
//! Complexity: O(n * k * d), where n is the number of keys, k is the number of random
//! features, and d is the vector dimension.

5use crate::error::{AttentionError, AttentionResult};
6use crate::traits::Attention;
7
/// Kernel (feature map) used by [`LinearAttention`] to turn random projections into
/// attention features.
///
/// The type is a plain tag, so it is `Copy` and can be compared or hashed (e.g. to select
/// or cache configurations by kernel kind).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// FAVOR+ positive random features approximating the softmax kernel; this is the kernel
    /// `LinearAttention::new` uses.
    Softmax,
    /// ReLU applied to each random projection.
    ReLU,
    /// ELU-based feature map applied to each random projection.
    ELU,
}
18
19/// Linear attention with random feature maps
20///
21/// Uses kernel trick to achieve O(n * k * d) complexity instead of O(n² * d).
22pub struct LinearAttention {
23    dim: usize,
24    num_features: usize,
25    kernel: KernelType,
26    /// Random projection matrix [num_features x dim]
27    random_features: Vec<f32>,
28}
29
30impl LinearAttention {
31    /// Create new linear attention
32    pub fn new(dim: usize, num_features: usize) -> Self {
33        Self::with_kernel(dim, num_features, KernelType::Softmax)
34    }
35
36    /// Create with specific kernel type
37    pub fn with_kernel(dim: usize, num_features: usize, kernel: KernelType) -> Self {
38        // Initialize random features using Box-Muller for Gaussian
39        let random_features = Self::generate_random_features(dim, num_features);
40
41        Self {
42            dim,
43            num_features,
44            kernel,
45            random_features,
46        }
47    }
48
49    fn generate_random_features(dim: usize, num_features: usize) -> Vec<f32> {
50        use std::f32::consts::PI;
51
52        let mut features = Vec::with_capacity(num_features * dim);
53        let mut seed = 42u64;
54
55        for _ in 0..((num_features * dim + 1) / 2) {
56            // Simple LCG for reproducibility
57            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
58            let u1 = (seed as f32) / (u64::MAX as f32);
59            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
60            let u2 = (seed as f32) / (u64::MAX as f32);
61
62            // Box-Muller transform
63            let r = (-2.0 * u1.max(1e-10).ln()).sqrt();
64            let theta = 2.0 * PI * u2;
65
66            features.push(r * theta.cos());
67            if features.len() < num_features * dim {
68                features.push(r * theta.sin());
69            }
70        }
71
72        features.truncate(num_features * dim);
73
74        // Normalize columns
75        let scale = 1.0 / (dim as f32).sqrt();
76        features.iter_mut().for_each(|x| *x *= scale);
77
78        features
79    }
80
81    /// Apply feature map to input
82    fn feature_map(&self, x: &[f32]) -> Vec<f32> {
83        let mut phi = vec![0.0f32; self.num_features];
84
85        for (i, phi_i) in phi.iter_mut().enumerate() {
86            let projection: f32 = x
87                .iter()
88                .enumerate()
89                .map(|(j, &xj)| xj * self.random_features[i * self.dim + j])
90                .sum();
91
92            *phi_i = match self.kernel {
93                KernelType::Softmax => {
94                    // FAVOR+: exp(projection - ||x||²/2) / sqrt(num_features)
95                    let norm_sq: f32 = x.iter().map(|xi| xi * xi).sum();
96                    (projection - norm_sq / 2.0).exp() / (self.num_features as f32).sqrt()
97                }
98                KernelType::ReLU => projection.max(0.0),
99                KernelType::ELU => {
100                    if projection >= 0.0 {
101                        projection
102                    } else {
103                        projection.exp() - 1.0
104                    }
105                }
106            };
107        }
108
109        phi
110    }
111}
112
113impl Attention for LinearAttention {
114    fn compute(
115        &self,
116        query: &[f32],
117        keys: &[&[f32]],
118        values: &[&[f32]],
119    ) -> AttentionResult<Vec<f32>> {
120        if keys.is_empty() {
121            return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
122        }
123        if keys.len() != values.len() {
124            return Err(AttentionError::DimensionMismatch {
125                expected: keys.len(),
126                actual: values.len(),
127            });
128        }
129        if query.len() != self.dim {
130            return Err(AttentionError::DimensionMismatch {
131                expected: self.dim,
132                actual: query.len(),
133            });
134        }
135
136        // Compute phi(Q)
137        let phi_q = self.feature_map(query);
138
139        // Compute sum_i phi(K_i)^T * V_i  and  sum_i phi(K_i)
140        let value_dim = values[0].len();
141        let mut kv_sum = vec![0.0f32; self.num_features * value_dim]; // [num_features x value_dim]
142        let mut k_sum = vec![0.0f32; self.num_features];
143
144        for (key, value) in keys.iter().zip(values.iter()) {
145            let phi_k = self.feature_map(key);
146
147            // Accumulate phi(K)^T * V (outer product contribution)
148            for (i, &phi_ki) in phi_k.iter().enumerate() {
149                for (j, &vj) in value.iter().enumerate() {
150                    kv_sum[i * value_dim + j] += phi_ki * vj;
151                }
152                k_sum[i] += phi_ki;
153            }
154        }
155
156        // Compute output: (phi(Q)^T * KV_sum) / (phi(Q)^T * K_sum)
157        let mut output = vec![0.0f32; value_dim];
158        let mut normalizer = 0.0f32;
159
160        for (i, &phi_qi) in phi_q.iter().enumerate() {
161            for (j, out_j) in output.iter_mut().enumerate() {
162                *out_j += phi_qi * kv_sum[i * value_dim + j];
163            }
164            normalizer += phi_qi * k_sum[i];
165        }
166
167        // Normalize
168        if normalizer.abs() > 1e-8 {
169            output.iter_mut().for_each(|x| *x /= normalizer);
170        }
171
172        Ok(output)
173    }
174
175    fn compute_with_mask(
176        &self,
177        query: &[f32],
178        keys: &[&[f32]],
179        values: &[&[f32]],
180        mask: Option<&[bool]>,
181    ) -> AttentionResult<Vec<f32>> {
182        if let Some(m) = mask {
183            let filtered: Vec<(usize, bool)> = m
184                .iter()
185                .copied()
186                .enumerate()
187                .filter(|(_, keep)| *keep)
188                .collect();
189            let filtered_keys: Vec<&[f32]> = filtered.iter().map(|(i, _)| keys[*i]).collect();
190            let filtered_values: Vec<&[f32]> = filtered.iter().map(|(i, _)| values[*i]).collect();
191            self.compute(query, &filtered_keys, &filtered_values)
192        } else {
193            self.compute(query, keys, values)
194        }
195    }
196
197    fn dim(&self) -> usize {
198        self.dim
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow every owned row as a slice, the shape `Attention::compute` expects.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_linear_attention() {
        let attention = LinearAttention::new(64, 32);

        let query: Vec<f32> = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 64]; 100];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 64]; 100];

        let output = attention
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_kernel_types() {
        let kernels = [KernelType::Softmax, KernelType::ReLU, KernelType::ELU];

        for kernel in kernels.iter().cloned() {
            let attention = LinearAttention::with_kernel(32, 16, kernel);

            let query: Vec<f32> = vec![1.0; 32];
            let keys: Vec<Vec<f32>> = vec![vec![0.5; 32]; 10];
            let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 10];

            let output = attention
                .compute(&query, &as_slices(&keys), &as_slices(&values))
                .unwrap();
            assert_eq!(output.len(), 32);
        }
    }
}