// ruvector_attention/moe/expert.rs

1//! Expert implementations for MoE attention
2
3use crate::error::AttentionResult;
4use crate::utils::stable_softmax;
5
6/// Type of expert
/// Type of expert
///
/// Identifies which attention mechanism an expert implements; used by the
/// MoE router to distinguish experts. Fieldless, so it is trivially
/// `Copy`/`Eq`/`Hash`-able — deriving these lets it be used as a map key
/// and compared without cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ExpertType {
    /// Standard scaled dot-product
    Standard,
    /// Hyperbolic attention
    Hyperbolic,
    /// Linear attention
    Linear,
}
16
/// Expert trait for attention computation
///
/// Each expert realizes a different attention mechanism over a single query
/// and a set of key/value vectors. Implementors must be `Send + Sync` so
/// experts can be shared across threads.
pub trait Expert: Send + Sync {
    /// Compute attention for this expert
    ///
    /// `query` is one query vector; `keys` and `values` are parallel slices
    /// (one key and one value per attended position). Returns the
    /// attention-weighted output vector.
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>>;

    /// Get expert type
    fn expert_type(&self) -> ExpertType;

    /// Get dimension
    fn dim(&self) -> usize;
}
33
/// Standard scaled dot-product expert
pub struct StandardExpert {
    // Embedding dimension of query/key/value vectors.
    dim: usize,
    // Precomputed 1/sqrt(dim) scaling factor applied to dot-product scores.
    scale: f32,
}
39
40impl StandardExpert {
41    pub fn new(dim: usize) -> Self {
42        Self {
43            dim,
44            scale: 1.0 / (dim as f32).sqrt(),
45        }
46    }
47}
48
49impl Expert for StandardExpert {
50    fn compute(
51        &self,
52        query: &[f32],
53        keys: &[&[f32]],
54        values: &[&[f32]],
55    ) -> AttentionResult<Vec<f32>> {
56        // Compute attention scores
57        let scores: Vec<f32> = keys
58            .iter()
59            .map(|k| {
60                query
61                    .iter()
62                    .zip(k.iter())
63                    .map(|(q, ki)| q * ki)
64                    .sum::<f32>()
65                    * self.scale
66            })
67            .collect();
68
69        // Softmax
70        let weights = stable_softmax(&scores);
71
72        // Weighted sum
73        let mut output = vec![0.0f32; self.dim];
74        for (weight, value) in weights.iter().zip(values.iter()) {
75            for (o, v) in output.iter_mut().zip(value.iter()) {
76                *o += weight * v;
77            }
78        }
79
80        Ok(output)
81    }
82
83    fn expert_type(&self) -> ExpertType {
84        ExpertType::Standard
85    }
86
87    fn dim(&self) -> usize {
88        self.dim
89    }
90}
91
/// Hyperbolic expert using Poincaré distance
pub struct HyperbolicExpert {
    // Embedding dimension.
    dim: usize,
    // Curvature of the Poincaré ball; only its absolute value is used.
    curvature: f32,
}
97
98impl HyperbolicExpert {
99    pub fn new(dim: usize, curvature: f32) -> Self {
100        Self { dim, curvature }
101    }
102
103    fn poincare_distance(&self, u: &[f32], v: &[f32]) -> f32 {
104        let c = self.curvature.abs();
105        let sqrt_c = c.sqrt();
106
107        let diff_sq: f32 = u.iter().zip(v.iter()).map(|(a, b)| (a - b).powi(2)).sum();
108        let norm_u_sq: f32 = u.iter().map(|x| x * x).sum();
109        let norm_v_sq: f32 = v.iter().map(|x| x * x).sum();
110
111        let denom = (1.0 - c * norm_u_sq).max(1e-7) * (1.0 - c * norm_v_sq).max(1e-7);
112        let arg = 1.0 + 2.0 * c * diff_sq / denom;
113
114        (1.0 / sqrt_c) * arg.max(1.0).acosh()
115    }
116}
117
118impl Expert for HyperbolicExpert {
119    fn compute(
120        &self,
121        query: &[f32],
122        keys: &[&[f32]],
123        values: &[&[f32]],
124    ) -> AttentionResult<Vec<f32>> {
125        // Use negative Poincaré distance as similarity
126        let scores: Vec<f32> = keys
127            .iter()
128            .map(|k| -self.poincare_distance(query, k))
129            .collect();
130
131        let weights = stable_softmax(&scores);
132
133        let mut output = vec![0.0f32; self.dim];
134        for (weight, value) in weights.iter().zip(values.iter()) {
135            for (o, v) in output.iter_mut().zip(value.iter()) {
136                *o += weight * v;
137            }
138        }
139
140        Ok(output)
141    }
142
143    fn expert_type(&self) -> ExpertType {
144        ExpertType::Hyperbolic
145    }
146
147    fn dim(&self) -> usize {
148        self.dim
149    }
150}
151
/// Linear attention expert with random features
pub struct LinearExpert {
    // Input embedding dimension.
    dim: usize,
    // Number of random features in the kernel approximation.
    num_features: usize,
    // Row-major `num_features x dim` Gaussian projection matrix.
    random_features: Vec<f32>,
}
158
159impl LinearExpert {
160    pub fn new(dim: usize, num_features: usize) -> Self {
161        use std::f32::consts::PI;
162
163        // Generate random features
164        let mut features = Vec::with_capacity(num_features * dim);
165        let mut seed = 123u64;
166
167        for _ in 0..((num_features * dim + 1) / 2) {
168            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
169            let u1 = (seed as f32) / (u64::MAX as f32);
170            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
171            let u2 = (seed as f32) / (u64::MAX as f32);
172
173            let r = (-2.0 * u1.max(1e-10).ln()).sqrt();
174            let theta = 2.0 * PI * u2;
175
176            features.push(r * theta.cos() / (dim as f32).sqrt());
177            if features.len() < num_features * dim {
178                features.push(r * theta.sin() / (dim as f32).sqrt());
179            }
180        }
181        features.truncate(num_features * dim);
182
183        Self {
184            dim,
185            num_features,
186            random_features: features,
187        }
188    }
189
190    fn feature_map(&self, x: &[f32]) -> Vec<f32> {
191        (0..self.num_features)
192            .map(|i| {
193                let proj: f32 = x
194                    .iter()
195                    .enumerate()
196                    .map(|(j, &xj)| xj * self.random_features[i * self.dim + j])
197                    .sum();
198                let norm_sq: f32 = x.iter().map(|xi| xi * xi).sum();
199                (proj - norm_sq / 2.0).exp() / (self.num_features as f32).sqrt()
200            })
201            .collect()
202    }
203}
204
205impl Expert for LinearExpert {
206    fn compute(
207        &self,
208        query: &[f32],
209        keys: &[&[f32]],
210        values: &[&[f32]],
211    ) -> AttentionResult<Vec<f32>> {
212        let phi_q = self.feature_map(query);
213        let value_dim = values.get(0).map(|v| v.len()).unwrap_or(self.dim);
214
215        let mut kv_sum = vec![0.0f32; self.num_features * value_dim];
216        let mut k_sum = vec![0.0f32; self.num_features];
217
218        for (key, value) in keys.iter().zip(values.iter()) {
219            let phi_k = self.feature_map(key);
220            for (i, &phi_ki) in phi_k.iter().enumerate() {
221                for (j, &vj) in value.iter().enumerate() {
222                    kv_sum[i * value_dim + j] += phi_ki * vj;
223                }
224                k_sum[i] += phi_ki;
225            }
226        }
227
228        let mut output = vec![0.0f32; value_dim];
229        let mut normalizer = 0.0f32;
230
231        for (i, &phi_qi) in phi_q.iter().enumerate() {
232            for (j, out_j) in output.iter_mut().enumerate() {
233                *out_j += phi_qi * kv_sum[i * value_dim + j];
234            }
235            normalizer += phi_qi * k_sum[i];
236        }
237
238        if normalizer.abs() > 1e-8 {
239            output.iter_mut().for_each(|x| *x /= normalizer);
240        }
241
242        Ok(output)
243    }
244
245    fn expert_type(&self) -> ExpertType {
246        ExpertType::Linear
247    }
248
249    fn dim(&self) -> usize {
250        self.dim
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow a list of owned vectors as a list of slices.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(|r| r.as_slice()).collect()
    }

    #[test]
    fn test_standard_expert() {
        let expert = StandardExpert::new(64);
        let query = vec![0.5; 64];
        let keys = vec![vec![0.3; 64]; 10];
        let values = vec![vec![1.0; 64]; 10];

        let out = expert
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(out.len(), 64);
    }

    #[test]
    fn test_hyperbolic_expert() {
        let expert = HyperbolicExpert::new(32, 1.0);
        // Small coordinates keep the points inside the Poincaré ball.
        let query = vec![0.1; 32];
        let keys = vec![vec![0.1; 32]; 5];
        let values = vec![vec![1.0; 32]; 5];

        let out = expert
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(out.len(), 32);
    }

    #[test]
    fn test_linear_expert() {
        let expert = LinearExpert::new(64, 32);
        let query = vec![0.5; 64];
        let keys = vec![vec![0.3; 64]; 10];
        let values = vec![vec![1.0; 64]; 10];

        let out = expert
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(out.len(), 64);
    }
}