// ruvector_attention/moe/expert.rs
use crate::error::AttentionResult;
use crate::utils::stable_softmax;
/// Identifies which attention formulation an `Expert` implements.
///
/// The enum is fieldless, so it additionally derives `Copy`, `Eq`, and
/// `Hash`, letting routers copy it freely and use it as a map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ExpertType {
    /// Scaled dot-product attention.
    Standard,
    /// Poincare-ball (hyperbolic distance) attention.
    Hyperbolic,
    /// Kernelized linear attention via random features.
    Linear,
}
16
/// A single attention expert in the mixture-of-experts layer.
///
/// Implementors must be `Send + Sync` so a router can share one
/// instance across threads.
pub trait Expert: Send + Sync {
    /// Attends `query` over `keys` and returns the attention-weighted
    /// combination of `values`.
    ///
    /// NOTE(review): implementations appear to assume all slices have
    /// length [`dim`](Expert::dim) and that `keys`/`values` are the same
    /// length (one value per key); none of the visible implementations
    /// validate this — confirm with callers.
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>>;

    /// Which attention formulation this expert implements.
    fn expert_type(&self) -> ExpertType;

    /// Embedding dimensionality this expert operates on.
    fn dim(&self) -> usize;
}
33
/// Expert implementing classic scaled dot-product attention.
pub struct StandardExpert {
    // Embedding dimensionality of queries/keys/values.
    dim: usize,
    // Precomputed 1/sqrt(dim) score scaling factor.
    scale: f32,
}

impl StandardExpert {
    /// Creates an expert for `dim`-dimensional inputs, precomputing the
    /// usual `1/sqrt(d)` attention scale once up front.
    pub fn new(dim: usize) -> Self {
        let scale = (dim as f32).sqrt().recip();
        Self { dim, scale }
    }
}
48
49impl Expert for StandardExpert {
50 fn compute(
51 &self,
52 query: &[f32],
53 keys: &[&[f32]],
54 values: &[&[f32]],
55 ) -> AttentionResult<Vec<f32>> {
56 let scores: Vec<f32> = keys
58 .iter()
59 .map(|k| {
60 query
61 .iter()
62 .zip(k.iter())
63 .map(|(q, ki)| q * ki)
64 .sum::<f32>()
65 * self.scale
66 })
67 .collect();
68
69 let weights = stable_softmax(&scores);
71
72 let mut output = vec![0.0f32; self.dim];
74 for (weight, value) in weights.iter().zip(values.iter()) {
75 for (o, v) in output.iter_mut().zip(value.iter()) {
76 *o += weight * v;
77 }
78 }
79
80 Ok(output)
81 }
82
83 fn expert_type(&self) -> ExpertType {
84 ExpertType::Standard
85 }
86
87 fn dim(&self) -> usize {
88 self.dim
89 }
90}
91
/// Expert scoring keys by (negated) Poincare-ball distance to the query.
pub struct HyperbolicExpert {
    // Embedding dimensionality.
    dim: usize,
    // Ball curvature; only its magnitude is used by the distance.
    curvature: f32,
}

impl HyperbolicExpert {
    /// Creates an expert operating in a Poincare ball of the given
    /// curvature over `dim`-dimensional points.
    pub fn new(dim: usize, curvature: f32) -> Self {
        Self { dim, curvature }
    }

    /// Distance between `u` and `v` in the Poincare ball of curvature `|c|`:
    /// `d(u,v) = (1/sqrt(c)) * acosh(1 + 2c|u-v|^2 / ((1-c|u|^2)(1-c|v|^2)))`.
    fn poincare_distance(&self, u: &[f32], v: &[f32]) -> f32 {
        let c = self.curvature.abs();
        let sqrt_c = c.sqrt();

        let sq_norm = |w: &[f32]| -> f32 { w.iter().map(|x| x * x).sum() };
        let diff_sq: f32 = u.iter().zip(v.iter()).map(|(a, b)| (a - b).powi(2)).sum();
        let norm_u_sq = sq_norm(u);
        let norm_v_sq = sq_norm(v);

        // Clamp each factor so points on (or numerically past) the ball
        // boundary cannot drive the denominator to zero or negative.
        let denom = (1.0 - c * norm_u_sq).max(1e-7) * (1.0 - c * norm_v_sq).max(1e-7);
        let arg = 1.0 + 2.0 * c * diff_sq / denom;

        // acosh needs arg >= 1; rounding can dip fractionally below.
        (1.0 / sqrt_c) * arg.max(1.0).acosh()
    }
}
117
118impl Expert for HyperbolicExpert {
119 fn compute(
120 &self,
121 query: &[f32],
122 keys: &[&[f32]],
123 values: &[&[f32]],
124 ) -> AttentionResult<Vec<f32>> {
125 let scores: Vec<f32> = keys
127 .iter()
128 .map(|k| -self.poincare_distance(query, k))
129 .collect();
130
131 let weights = stable_softmax(&scores);
132
133 let mut output = vec![0.0f32; self.dim];
134 for (weight, value) in weights.iter().zip(values.iter()) {
135 for (o, v) in output.iter_mut().zip(value.iter()) {
136 *o += weight * v;
137 }
138 }
139
140 Ok(output)
141 }
142
143 fn expert_type(&self) -> ExpertType {
144 ExpertType::Hyperbolic
145 }
146
147 fn dim(&self) -> usize {
148 self.dim
149 }
150}
151
/// Expert computing kernelized linear attention via a random
/// (Performer-style) positive feature map.
pub struct LinearExpert {
    // Input embedding dimensionality.
    dim: usize,
    // Number of random features per projected vector.
    num_features: usize,
    // Row-major `num_features x dim` Gaussian projection matrix.
    random_features: Vec<f32>,
}

impl LinearExpert {
    /// Creates an expert with a deterministic `num_features x dim` random
    /// projection, sampled approximately N(0, 1/dim) via a fixed-seed LCG
    /// feeding Box-Muller.
    pub fn new(dim: usize, num_features: usize) -> Self {
        use std::f32::consts::PI;

        let total = num_features * dim;
        let mut features = Vec::with_capacity(total);
        // Fixed seed so construction is reproducible across runs.
        let mut seed = 123u64;

        // Box-Muller produces samples in pairs; round up so `total`
        // entries are always covered, then truncate the possible extra.
        for _ in 0..((total + 1) / 2) {
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            let u1 = (seed as f32) / (u64::MAX as f32);
            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
            let u2 = (seed as f32) / (u64::MAX as f32);

            // Guard ln(0); divide by sqrt(dim) for N(0, 1/dim) entries.
            let r = (-2.0 * u1.max(1e-10).ln()).sqrt();
            let theta = 2.0 * PI * u2;

            features.push(r * theta.cos() / (dim as f32).sqrt());
            if features.len() < total {
                features.push(r * theta.sin() / (dim as f32).sqrt());
            }
        }
        features.truncate(total);

        Self {
            dim,
            num_features,
            random_features: features,
        }
    }

    /// Positive random-feature map used by linear attention:
    /// `phi_i(x) = exp(w_i . x - |x|^2 / 2) / sqrt(num_features)`.
    fn feature_map(&self, x: &[f32]) -> Vec<f32> {
        // Hoisted out of the per-feature closure: |x|^2 and the output
        // scale are loop-invariant (previously recomputed for every one
        // of the `num_features` entries).
        let norm_sq: f32 = x.iter().map(|xi| xi * xi).sum();
        let scale = (self.num_features as f32).sqrt();

        (0..self.num_features)
            .map(|i| {
                let row = &self.random_features[i * self.dim..(i + 1) * self.dim];
                let proj: f32 = x.iter().zip(row.iter()).map(|(&xj, &w)| xj * w).sum();
                (proj - norm_sq / 2.0).exp() / scale
            })
            .collect()
    }
}
204
205impl Expert for LinearExpert {
206 fn compute(
207 &self,
208 query: &[f32],
209 keys: &[&[f32]],
210 values: &[&[f32]],
211 ) -> AttentionResult<Vec<f32>> {
212 let phi_q = self.feature_map(query);
213 let value_dim = values.get(0).map(|v| v.len()).unwrap_or(self.dim);
214
215 let mut kv_sum = vec![0.0f32; self.num_features * value_dim];
216 let mut k_sum = vec![0.0f32; self.num_features];
217
218 for (key, value) in keys.iter().zip(values.iter()) {
219 let phi_k = self.feature_map(key);
220 for (i, &phi_ki) in phi_k.iter().enumerate() {
221 for (j, &vj) in value.iter().enumerate() {
222 kv_sum[i * value_dim + j] += phi_ki * vj;
223 }
224 k_sum[i] += phi_ki;
225 }
226 }
227
228 let mut output = vec![0.0f32; value_dim];
229 let mut normalizer = 0.0f32;
230
231 for (i, &phi_qi) in phi_q.iter().enumerate() {
232 for (j, out_j) in output.iter_mut().enumerate() {
233 *out_j += phi_qi * kv_sum[i * value_dim + j];
234 }
235 normalizer += phi_qi * k_sum[i];
236 }
237
238 if normalizer.abs() > 1e-8 {
239 output.iter_mut().for_each(|x| *x /= normalizer);
240 }
241
242 Ok(output)
243 }
244
245 fn expert_type(&self) -> ExpertType {
246 ExpertType::Linear
247 }
248
249 fn dim(&self) -> usize {
250 self.dim
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows owned rows as the `&[&[f32]]` shape `Expert::compute` takes.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(|r| r.as_slice()).collect()
    }

    #[test]
    fn test_standard_expert() {
        let expert = StandardExpert::new(64);
        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 64]; 10];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 64]; 10];

        let result = expert
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(result.len(), 64);
    }

    #[test]
    fn test_hyperbolic_expert() {
        let expert = HyperbolicExpert::new(32, 1.0);
        let query = vec![0.1; 32];
        let keys: Vec<Vec<f32>> = vec![vec![0.1; 32]; 5];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 32]; 5];

        let result = expert
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn test_linear_expert() {
        let expert = LinearExpert::new(64, 32);
        let query = vec![0.5; 64];
        let keys: Vec<Vec<f32>> = vec![vec![0.3; 64]; 10];
        let values: Vec<Vec<f32>> = vec![vec![1.0; 64]; 10];

        let result = expert
            .compute(&query, &as_slices(&keys), &as_slices(&values))
            .unwrap();
        assert_eq!(result.len(), 64);
    }
}