1use std::collections::HashMap;
6use trustformers_core::{errors::Result, layers::Linear, tensor::Tensor, traits::Layer};
7
/// Configuration for a sparse Mixture-of-Experts layer.
#[derive(Debug, Clone)]
pub struct MoEConfig {
    /// Width of the token hidden representation fed to the router and experts.
    pub hidden_size: usize,
    /// Total number of experts in the layer.
    pub num_experts: usize,
    /// How many experts each token is routed to (the "k" in top-k routing).
    pub num_experts_per_token: usize,
    /// Optional per-expert token capacity; `None` means unbounded.
    pub expert_capacity: Option<usize>,
    /// Coefficient applied to the auxiliary load-balancing loss.
    pub load_balancing_loss_coeff: f32,
    /// Coefficient applied to the router z-loss term.
    pub router_z_loss_coeff: f32,
    /// Whether auxiliary losses should be added to the training objective.
    pub use_auxiliary_loss: bool,
    /// Scale of additive noise applied to router logits (0.0 disables it).
    pub jitter_noise: f32,
}
20
21impl Default for MoEConfig {
22 fn default() -> Self {
23 Self {
24 hidden_size: 4096,
25 num_experts: 8,
26 num_experts_per_token: 2,
27 expert_capacity: None,
28 load_balancing_loss_coeff: 0.01,
29 router_z_loss_coeff: 0.001,
30 use_auxiliary_loss: true,
31 jitter_noise: 1e-2,
32 }
33 }
34}
35
/// Per-forward-pass routing diagnostics and auxiliary loss terms,
/// produced by `TopKRouter::compute_routing_stats`.
#[derive(Debug, Clone)]
pub struct RoutingStats {
    // Per-expert fraction of (token, k) assignments, normalized by token count.
    pub expert_counts: Vec<f32>,
    // Per-expert routing weight mass, averaged over tokens.
    pub expert_weights: Vec<f32>,
    // Switch-style auxiliary loss; 0 when assignments are perfectly uniform.
    pub load_balancing_loss: f32,
    // Router z-loss term (computed from probabilities here — see router NOTE).
    pub router_z_loss: f32,
}
44
/// A routable expert: any tensor-to-tensor [`Layer`] that reports an id and,
/// optionally, a token capacity. `Send + Sync` so experts can be shared
/// across threads (e.g. for expert parallelism).
pub trait Expert: Layer<Input = Tensor, Output = Tensor> + Send + Sync {
    // Identifier for this expert; used as the key in expert-parallel mappings.
    fn expert_id(&self) -> usize;
    // Maximum tokens this expert accepts; `None` (the default) = unbounded.
    fn capacity(&self) -> Option<usize> {
        None
    }
}
52
/// A gated MLP expert (gate/up/down projections, SwiGLU-style when the
/// activation is "silu"/"swish").
pub struct MLPExpert {
    // This expert's id, as reported by `Expert::expert_id`.
    id: usize,
    // hidden_size -> intermediate_size, output passed through the activation.
    gate_proj: Linear,
    // hidden_size -> intermediate_size, multiplied elementwise with the gate.
    up_proj: Linear,
    // intermediate_size -> hidden_size, final projection.
    down_proj: Linear,
    // Activation name: "silu"/"swish", "gelu", "relu"; anything else is identity.
    activation: String,
}
61
62impl MLPExpert {
63 pub fn new(
64 id: usize,
65 hidden_size: usize,
66 intermediate_size: usize,
67 activation: String,
68 ) -> Result<Self> {
69 let gate_proj = Linear::new(hidden_size, intermediate_size, false);
70 let up_proj = Linear::new(hidden_size, intermediate_size, false);
71 let down_proj = Linear::new(intermediate_size, hidden_size, false);
72
73 Ok(Self {
74 id,
75 gate_proj,
76 up_proj,
77 down_proj,
78 activation,
79 })
80 }
81
82 fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
83 match self.activation.as_str() {
84 "silu" | "swish" => x.silu(),
85 "gelu" => x.gelu(),
86 "relu" => x.relu(),
87 _ => Ok(x.clone()),
88 }
89 }
90}
91
92impl Layer for MLPExpert {
93 type Input = Tensor;
94 type Output = Tensor;
95
96 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
97 let gate = self.gate_proj.forward(input.clone())?;
98 let gate_activated = self.apply_activation(&gate)?;
99
100 let up = self.up_proj.forward(input)?;
101 let gated = gate_activated.mul(&up)?;
102
103 self.down_proj.forward(gated)
104 }
105}
106
impl Expert for MLPExpert {
    // Returns the id supplied at construction; capacity keeps the trait
    // default of `None` (unbounded).
    fn expert_id(&self) -> usize {
        self.id
    }
}
112
/// Top-k softmax router: a learned linear gate scores each token against
/// every expert, and the k highest-probability experts are selected.
pub struct TopKRouter {
    // Linear gate mapping hidden_size -> num_experts logits (no bias).
    gate: Linear,
    // Routing hyperparameters (k, expert count, noise, loss coefficients).
    config: MoEConfig,
}
118
119impl TopKRouter {
120 pub fn new(config: MoEConfig) -> Result<Self> {
121 let gate = Linear::new(config.hidden_size, config.num_experts, false);
122 Ok(Self { gate, config })
123 }
124
125 pub fn route(&self, hidden_states: &Tensor) -> Result<RouterOutput> {
127 let batch_size = hidden_states.shape()[0];
128 let seq_len = hidden_states.shape()[1];
129 let hidden_size = hidden_states.shape()[2];
130
131 let flattened = hidden_states.reshape(&[batch_size * seq_len, hidden_size])?;
133
134 let router_logits = self.gate.forward(flattened)?;
136
137 let router_logits = if self.config.jitter_noise > 0.0 {
139 let noise = Tensor::randn_like(&router_logits)?.mul_scalar(self.config.jitter_noise)?;
140 router_logits.add(&noise)?
141 } else {
142 router_logits
143 };
144
145 let router_probs = router_logits.softmax(-1)?;
147
148 let (top_k_weights, top_k_indices) = self.select_top_k(&router_probs)?;
150
151 let stats = self.compute_routing_stats(&router_probs, &top_k_weights, &top_k_indices)?;
153
154 Ok(RouterOutput {
155 top_k_weights,
156 top_k_indices,
157 router_probs,
158 stats,
159 })
160 }
161
162 fn select_top_k(&self, router_probs: &Tensor) -> Result<(Tensor, Tensor)> {
163 let num_tokens = router_probs.shape()[0];
164 let num_experts = router_probs.shape()[1];
165
166 let mut all_weights = Vec::new();
167 let mut all_indices = Vec::new();
168
169 for token_idx in 0..num_tokens {
170 let mut expert_probs: Vec<(f32, usize)> = Vec::new();
172 for expert_idx in 0..num_experts {
173 let prob = router_probs.get_scalar(&[token_idx, expert_idx])?;
174 expert_probs.push((prob, expert_idx));
175 }
176
177 expert_probs.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("operation failed"));
179 expert_probs.truncate(self.config.num_experts_per_token);
180
181 let sum: f32 = expert_probs.iter().map(|(p, _)| p).sum();
183 let norm_factor = if sum > 0.0 { 1.0 / sum } else { 1.0 };
184
185 for (prob, expert_idx) in expert_probs {
186 all_weights.push(prob * norm_factor);
187 all_indices.push(expert_idx as f32);
188 }
189 }
190
191 let weights_tensor = Tensor::from_vec(
192 all_weights,
193 &[num_tokens, self.config.num_experts_per_token],
194 )?;
195 let indices_tensor = Tensor::from_vec(
196 all_indices,
197 &[num_tokens, self.config.num_experts_per_token],
198 )?;
199
200 Ok((weights_tensor, indices_tensor))
201 }
202
203 fn compute_routing_stats(
204 &self,
205 router_probs: &Tensor,
206 top_k_weights: &Tensor,
207 top_k_indices: &Tensor,
208 ) -> Result<RoutingStats> {
209 let num_tokens = router_probs.shape()[0];
210 let num_experts = self.config.num_experts;
211
212 let mut expert_counts = vec![0.0; num_experts];
214 let mut expert_weights = vec![0.0; num_experts];
215
216 for token_idx in 0..num_tokens {
217 for k in 0..self.config.num_experts_per_token {
218 let expert_idx = top_k_indices.get_scalar(&[token_idx, k])? as usize;
219 let weight = top_k_weights.get_scalar(&[token_idx, k])?;
220
221 expert_counts[expert_idx] += 1.0;
222 expert_weights[expert_idx] += weight;
223 }
224 }
225
226 let total_tokens = num_tokens as f32;
228 expert_counts.iter_mut().for_each(|c| *c /= total_tokens);
229 expert_weights.iter_mut().for_each(|w| *w /= total_tokens);
230
231 let _mean_count = 1.0 / num_experts as f32;
233 let load_balancing_loss: f32 = expert_counts
234 .iter()
235 .zip(expert_weights.iter())
236 .map(|(count, weight)| count * weight)
237 .sum::<f32>()
238 * num_experts as f32
239 - 1.0;
240
241 let router_z_loss = router_probs
243 .pow(2.0)?
244 .sum(Some(vec![router_probs.shape().len() - 1]), false)?
245 .mean()?
246 .get_scalar(&[])?;
247
248 Ok(RoutingStats {
249 expert_counts,
250 expert_weights,
251 load_balancing_loss,
252 router_z_loss,
253 })
254 }
255}
256
/// Result of one routing pass over a flattened batch of tokens.
pub struct RouterOutput {
    // `[num_tokens, k]` renormalized weights for the selected experts.
    pub top_k_weights: Tensor,
    // `[num_tokens, k]` selected expert indices, stored as f32 values.
    pub top_k_indices: Tensor,
    // `[num_tokens, num_experts]` full softmax distribution over experts.
    pub router_probs: Tensor,
    // Diagnostics and auxiliary losses for this pass.
    pub stats: RoutingStats,
}
264
/// A sparse Mixture-of-Experts layer: a top-k router dispatches each token
/// to a subset of experts and combines their outputs by routing weight.
pub struct SparseMoE<E: Expert> {
    // Experts indexed by the router's expert indices.
    experts: Vec<E>,
    // Token-to-expert router built from `config`.
    router: TopKRouter,
    // Layer hyperparameters (shared with the router via clone in `new`).
    config: MoEConfig,
}
271
impl<E: Expert> SparseMoE<E> {
    /// Builds a sparse MoE layer from pre-constructed experts.
    ///
    /// A fresh [`TopKRouter`] is created from a clone of `config`.
    pub fn new(experts: Vec<E>, config: MoEConfig) -> Result<Self> {
        let router = TopKRouter::new(config.clone())?;
        Ok(Self {
            experts,
            router,
            config,
        })
    }

    /// Number of experts held by this layer.
    pub fn num_experts(&self) -> usize {
        self.experts.len()
    }

    /// Routing stats from the most recent forward pass.
    ///
    /// NOTE(review): currently a stub that always returns `None` — stats are
    /// computed during routing but never retained. Add interior-mutable
    /// storage (e.g. `RefCell`/`Mutex`) if callers need this.
    pub fn last_routing_stats(&self) -> Option<&RoutingStats> {
        None
    }
}
293
294impl<E: Expert> Layer for SparseMoE<E> {
295 type Input = Tensor;
296 type Output = Tensor;
297
298 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
299 let batch_size = input.shape()[0];
300 let seq_len = input.shape()[1];
301 let hidden_size = input.shape()[2];
302
303 let router_output = self.router.route(&input)?;
305
306 let flattened_input = input.reshape(&[batch_size * seq_len, hidden_size])?;
308 let num_tokens = flattened_input.shape()[0];
309
310 let mut output = Tensor::zeros(&[num_tokens, hidden_size])?;
312
313 for token_idx in 0..num_tokens {
315 let token_input =
317 flattened_input.slice_multi(&[(token_idx, token_idx + 1), (0, hidden_size)])?;
318
319 let mut token_output = Tensor::zeros(&[1, hidden_size])?;
321 for k in 0..self.config.num_experts_per_token {
322 let expert_idx = router_output.top_k_indices.get_scalar(&[token_idx, k])? as usize;
323 let weight = router_output.top_k_weights.get_scalar(&[token_idx, k])?;
324
325 let expert_output = self.experts[expert_idx].forward(token_input.clone())?;
327 let weighted_output = expert_output.mul_scalar(weight)?;
328
329 token_output = token_output.add(&weighted_output)?;
331 }
332
333 let token_output_slice =
336 output.slice_multi(&[(token_idx, token_idx + 1), (0, hidden_size)])?;
337 let updated_slice = token_output_slice.add(&token_output)?;
338
339 if token_idx == 0 {
342 output = updated_slice.clone();
343 } else {
344 let current_tokens = output.slice_multi(&[(0, token_idx), (0, hidden_size)])?;
346 let remaining_shape = if token_idx + 1 < num_tokens {
347 Some(output.slice_multi(&[(token_idx + 1, num_tokens), (0, hidden_size)])?)
348 } else {
349 None
350 };
351
352 output = if let Some(remaining) = remaining_shape {
354 Tensor::concat(&[current_tokens, updated_slice, remaining], 0)?
355 } else {
356 Tensor::concat(&[current_tokens, updated_slice], 0)?
357 };
358 }
359 }
360
361 output.reshape(&[batch_size, seq_len, hidden_size])
363 }
364}
365
/// Switch-Transformer-style alias for [`SparseMoE`]; conventionally paired
/// with [`switch_config`] for top-1 routing.
pub type SwitchMoE<E> = SparseMoE<E>;
368
369pub fn switch_config(hidden_size: usize, num_experts: usize) -> MoEConfig {
371 MoEConfig {
372 hidden_size,
373 num_experts,
374 num_experts_per_token: 1, ..Default::default()
376 }
377}
378
379pub fn glam_config(hidden_size: usize, num_experts: usize) -> MoEConfig {
381 MoEConfig {
382 hidden_size,
383 num_experts,
384 num_experts_per_token: 2, ..Default::default()
386 }
387}
388
389pub struct ExpertParallel<E: Expert> {
391 local_experts: Vec<E>,
392 expert_mapping: HashMap<usize, usize>, #[allow(dead_code)]
394 rank: usize,
395 #[allow(dead_code)]
396 world_size: usize,
397}
398
399impl<E: Expert> ExpertParallel<E> {
400 pub fn new(experts: Vec<E>, rank: usize, world_size: usize) -> Self {
401 let mut expert_mapping = HashMap::new();
402 for (local_id, expert) in experts.iter().enumerate() {
403 expert_mapping.insert(expert.expert_id(), local_id);
404 }
405
406 Self {
407 local_experts: experts,
408 expert_mapping,
409 rank,
410 world_size,
411 }
412 }
413
414 pub fn has_expert(&self, expert_id: usize) -> bool {
416 self.expert_mapping.contains_key(&expert_id)
417 }
418
419 pub fn forward_local(&self, expert_id: usize, input: &Tensor) -> Result<Option<Tensor>> {
421 if let Some(&local_id) = self.expert_mapping.get(&expert_id) {
422 let output = self.local_experts[local_id].forward(input.clone())?;
423 Ok(Some(output))
424 } else {
425 Ok(None)
426 }
427 }
428}