// File: trustformers_models/moe.rs
//! Mixture of Experts (MoE) infrastructure for efficient scaling.
//!
//! This module provides reusable MoE components that can be integrated into
//! various transformer architectures like Mixtral, GLaM, Switch Transformer, etc.
5use std::collections::HashMap;
6use trustformers_core::{errors::Result, layers::Linear, tensor::Tensor, traits::Layer};
7
/// Configuration for MoE layers
#[derive(Debug, Clone)]
pub struct MoEConfig {
    /// Model hidden dimension fed to the router and the experts.
    pub hidden_size: usize,
    /// Total number of experts in the layer.
    pub num_experts: usize,
    /// How many experts each token is routed to (the "k" in top-k).
    pub num_experts_per_token: usize,
    /// Optional per-expert token capacity. // For capacity-based routing
    pub expert_capacity: Option<usize>,
    /// Coefficient for the auxiliary load-balancing loss.
    pub load_balancing_loss_coeff: f32,
    /// Coefficient for the router z-loss term.
    pub router_z_loss_coeff: f32,
    /// Whether auxiliary losses are enabled at all.
    pub use_auxiliary_loss: bool,
    /// Scale of Gaussian noise added to router logits. // Noise for load balancing
    pub jitter_noise: f32,
}
20
21impl Default for MoEConfig {
22    fn default() -> Self {
23        Self {
24            hidden_size: 4096,
25            num_experts: 8,
26            num_experts_per_token: 2,
27            expert_capacity: None,
28            load_balancing_loss_coeff: 0.01,
29            router_z_loss_coeff: 0.001,
30            use_auxiliary_loss: true,
31            jitter_noise: 1e-2,
32        }
33    }
34}
35
/// Expert routing statistics for load balancing
#[derive(Debug, Clone)]
pub struct RoutingStats {
    /// Fraction of token-slots routed to each expert (normalized by token count).
    pub expert_counts: Vec<f32>,
    /// Sum of routing weights received by each expert, normalized by token count.
    pub expert_weights: Vec<f32>,
    /// Auxiliary loss encouraging uniform expert usage (0 when perfectly uniform).
    pub load_balancing_loss: f32,
    /// Z-loss term derived from the router probabilities.
    pub router_z_loss: f32,
}
44
/// Generic expert trait that can wrap different layer types
pub trait Expert: Layer<Input = Tensor, Output = Tensor> + Send + Sync {
    /// Globally unique identifier of this expert within the MoE layer.
    fn expert_id(&self) -> usize;
    /// Optional per-expert token capacity; the default `None` means unlimited.
    fn capacity(&self) -> Option<usize> {
        None
    }
}
52
/// Basic MLP expert implementation
///
/// Gated MLP: `down_proj(act(gate_proj(x)) * up_proj(x))` — SwiGLU-style
/// when the activation is `"silu"`.
pub struct MLPExpert {
    id: usize,            // identifier reported by `expert_id`
    gate_proj: Linear,    // hidden -> intermediate, gating branch
    up_proj: Linear,      // hidden -> intermediate, value branch
    down_proj: Linear,    // intermediate -> hidden
    activation: String,   // "silu"/"swish", "gelu", "relu"; anything else acts as identity
}
61
62impl MLPExpert {
63    pub fn new(
64        id: usize,
65        hidden_size: usize,
66        intermediate_size: usize,
67        activation: String,
68    ) -> Result<Self> {
69        let gate_proj = Linear::new(hidden_size, intermediate_size, false);
70        let up_proj = Linear::new(hidden_size, intermediate_size, false);
71        let down_proj = Linear::new(intermediate_size, hidden_size, false);
72
73        Ok(Self {
74            id,
75            gate_proj,
76            up_proj,
77            down_proj,
78            activation,
79        })
80    }
81
82    fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
83        match self.activation.as_str() {
84            "silu" | "swish" => x.silu(),
85            "gelu" => x.gelu(),
86            "relu" => x.relu(),
87            _ => Ok(x.clone()),
88        }
89    }
90}
91
92impl Layer for MLPExpert {
93    type Input = Tensor;
94    type Output = Tensor;
95
96    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
97        let gate = self.gate_proj.forward(input.clone())?;
98        let gate_activated = self.apply_activation(&gate)?;
99
100        let up = self.up_proj.forward(input)?;
101        let gated = gate_activated.mul(&up)?;
102
103        self.down_proj.forward(gated)
104    }
105}
106
impl Expert for MLPExpert {
    /// Identifier assigned at construction time.
    fn expert_id(&self) -> usize {
        self.id
    }
}
112
/// Top-K router for expert selection
pub struct TopKRouter {
    gate: Linear,      // hidden_size -> num_experts logit projection (no bias)
    config: MoEConfig, // routing hyperparameters
}
118
119impl TopKRouter {
120    pub fn new(config: MoEConfig) -> Result<Self> {
121        let gate = Linear::new(config.hidden_size, config.num_experts, false);
122        Ok(Self { gate, config })
123    }
124
125    /// Route tokens to top-k experts
126    pub fn route(&self, hidden_states: &Tensor) -> Result<RouterOutput> {
127        let batch_size = hidden_states.shape()[0];
128        let seq_len = hidden_states.shape()[1];
129        let hidden_size = hidden_states.shape()[2];
130
131        // Flatten for per-token routing
132        let flattened = hidden_states.reshape(&[batch_size * seq_len, hidden_size])?;
133
134        // Compute router logits
135        let router_logits = self.gate.forward(flattened)?;
136
137        // Add jitter noise for load balancing during training
138        let router_logits = if self.config.jitter_noise > 0.0 {
139            let noise = Tensor::randn_like(&router_logits)?.mul_scalar(self.config.jitter_noise)?;
140            router_logits.add(&noise)?
141        } else {
142            router_logits
143        };
144
145        // Apply softmax to get probabilities
146        let router_probs = router_logits.softmax(-1)?;
147
148        // Select top-k experts
149        let (top_k_weights, top_k_indices) = self.select_top_k(&router_probs)?;
150
151        // Compute auxiliary losses
152        let stats = self.compute_routing_stats(&router_probs, &top_k_weights, &top_k_indices)?;
153
154        Ok(RouterOutput {
155            top_k_weights,
156            top_k_indices,
157            router_probs,
158            stats,
159        })
160    }
161
162    fn select_top_k(&self, router_probs: &Tensor) -> Result<(Tensor, Tensor)> {
163        let num_tokens = router_probs.shape()[0];
164        let num_experts = router_probs.shape()[1];
165
166        let mut all_weights = Vec::new();
167        let mut all_indices = Vec::new();
168
169        for token_idx in 0..num_tokens {
170            // Get probabilities for this token
171            let mut expert_probs: Vec<(f32, usize)> = Vec::new();
172            for expert_idx in 0..num_experts {
173                let prob = router_probs.get_scalar(&[token_idx, expert_idx])?;
174                expert_probs.push((prob, expert_idx));
175            }
176
177            // Sort and select top-k
178            expert_probs.sort_by(|a, b| b.0.partial_cmp(&a.0).expect("operation failed"));
179            expert_probs.truncate(self.config.num_experts_per_token);
180
181            // Renormalize selected probabilities
182            let sum: f32 = expert_probs.iter().map(|(p, _)| p).sum();
183            let norm_factor = if sum > 0.0 { 1.0 / sum } else { 1.0 };
184
185            for (prob, expert_idx) in expert_probs {
186                all_weights.push(prob * norm_factor);
187                all_indices.push(expert_idx as f32);
188            }
189        }
190
191        let weights_tensor = Tensor::from_vec(
192            all_weights,
193            &[num_tokens, self.config.num_experts_per_token],
194        )?;
195        let indices_tensor = Tensor::from_vec(
196            all_indices,
197            &[num_tokens, self.config.num_experts_per_token],
198        )?;
199
200        Ok((weights_tensor, indices_tensor))
201    }
202
203    fn compute_routing_stats(
204        &self,
205        router_probs: &Tensor,
206        top_k_weights: &Tensor,
207        top_k_indices: &Tensor,
208    ) -> Result<RoutingStats> {
209        let num_tokens = router_probs.shape()[0];
210        let num_experts = self.config.num_experts;
211
212        // Count tokens routed to each expert
213        let mut expert_counts = vec![0.0; num_experts];
214        let mut expert_weights = vec![0.0; num_experts];
215
216        for token_idx in 0..num_tokens {
217            for k in 0..self.config.num_experts_per_token {
218                let expert_idx = top_k_indices.get_scalar(&[token_idx, k])? as usize;
219                let weight = top_k_weights.get_scalar(&[token_idx, k])?;
220
221                expert_counts[expert_idx] += 1.0;
222                expert_weights[expert_idx] += weight;
223            }
224        }
225
226        // Normalize counts and weights
227        let total_tokens = num_tokens as f32;
228        expert_counts.iter_mut().for_each(|c| *c /= total_tokens);
229        expert_weights.iter_mut().for_each(|w| *w /= total_tokens);
230
231        // Compute load balancing loss (encourages uniform expert usage)
232        let _mean_count = 1.0 / num_experts as f32;
233        let load_balancing_loss: f32 = expert_counts
234            .iter()
235            .zip(expert_weights.iter())
236            .map(|(count, weight)| count * weight)
237            .sum::<f32>()
238            * num_experts as f32
239            - 1.0;
240
241        // Compute router z-loss (encourages sparsity)
242        let router_z_loss = router_probs
243            .pow(2.0)?
244            .sum(Some(vec![router_probs.shape().len() - 1]), false)?
245            .mean()?
246            .get_scalar(&[])?;
247
248        Ok(RoutingStats {
249            expert_counts,
250            expert_weights,
251            load_balancing_loss,
252            router_z_loss,
253        })
254    }
255}
256
/// Output from the router
pub struct RouterOutput {
    /// Renormalized weights of the selected experts, `[num_tokens, k]`.
    pub top_k_weights: Tensor,
    /// Selected expert indices stored as `f32`, `[num_tokens, k]`.
    pub top_k_indices: Tensor,
    /// Full softmax distribution over experts, `[num_tokens, num_experts]`.
    pub router_probs: Tensor,
    /// Usage statistics and auxiliary losses for this routing pass.
    pub stats: RoutingStats,
}
264
/// Sparse Mixture of Experts layer
///
/// Routes each token to `config.num_experts_per_token` experts and combines
/// their outputs using the router's renormalized weights.
pub struct SparseMoE<E: Expert> {
    experts: Vec<E>,    // indexed by the router's expert indices
    router: TopKRouter, // top-k gating network
    config: MoEConfig,
}
271
272impl<E: Expert> SparseMoE<E> {
273    pub fn new(experts: Vec<E>, config: MoEConfig) -> Result<Self> {
274        let router = TopKRouter::new(config.clone())?;
275        Ok(Self {
276            experts,
277            router,
278            config,
279        })
280    }
281
282    /// Get the number of experts
283    pub fn num_experts(&self) -> usize {
284        self.experts.len()
285    }
286
287    /// Get routing statistics from the last forward pass
288    pub fn last_routing_stats(&self) -> Option<&RoutingStats> {
289        // This would need to be stored from the last forward pass
290        None
291    }
292}
293
294impl<E: Expert> Layer for SparseMoE<E> {
295    type Input = Tensor;
296    type Output = Tensor;
297
298    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
299        let batch_size = input.shape()[0];
300        let seq_len = input.shape()[1];
301        let hidden_size = input.shape()[2];
302
303        // Route tokens to experts
304        let router_output = self.router.route(&input)?;
305
306        // Flatten input for processing
307        let flattened_input = input.reshape(&[batch_size * seq_len, hidden_size])?;
308        let num_tokens = flattened_input.shape()[0];
309
310        // Initialize output
311        let mut output = Tensor::zeros(&[num_tokens, hidden_size])?;
312
313        // Process each token
314        for token_idx in 0..num_tokens {
315            // Proper tensor slicing: select single token input
316            let token_input =
317                flattened_input.slice_multi(&[(token_idx, token_idx + 1), (0, hidden_size)])?;
318
319            // Combine outputs from selected experts
320            let mut token_output = Tensor::zeros(&[1, hidden_size])?;
321            for k in 0..self.config.num_experts_per_token {
322                let expert_idx = router_output.top_k_indices.get_scalar(&[token_idx, k])? as usize;
323                let weight = router_output.top_k_weights.get_scalar(&[token_idx, k])?;
324
325                // Get expert output
326                let expert_output = self.experts[expert_idx].forward(token_input.clone())?;
327                let weighted_output = expert_output.mul_scalar(weight)?;
328
329                // Accumulate expert outputs for this token
330                token_output = token_output.add(&weighted_output)?;
331            }
332
333            // Set the token output in the final output tensor
334            // Use slice and add for proper accumulation per token
335            let token_output_slice =
336                output.slice_multi(&[(token_idx, token_idx + 1), (0, hidden_size)])?;
337            let updated_slice = token_output_slice.add(&token_output)?;
338
339            // For now, we'll use a workaround since set_slice is not available
340            // This approach maintains per-token processing but requires reconstruction
341            if token_idx == 0 {
342                output = updated_slice.clone();
343            } else {
344                // Concatenate along the first dimension
345                let current_tokens = output.slice_multi(&[(0, token_idx), (0, hidden_size)])?;
346                let remaining_shape = if token_idx + 1 < num_tokens {
347                    Some(output.slice_multi(&[(token_idx + 1, num_tokens), (0, hidden_size)])?)
348                } else {
349                    None
350                };
351
352                // Reconstruct output tensor with updated token
353                output = if let Some(remaining) = remaining_shape {
354                    Tensor::concat(&[current_tokens, updated_slice, remaining], 0)?
355                } else {
356                    Tensor::concat(&[current_tokens, updated_slice], 0)?
357                };
358            }
359        }
360
361        // Reshape back to original dimensions
362        output.reshape(&[batch_size, seq_len, hidden_size])
363    }
364}
365
/// Switch Transformer style MoE (uses only top-1 expert per token)
///
/// Same layer type as `SparseMoE`; pair it with `switch_config` for top-1 routing.
pub type SwitchMoE<E> = SparseMoE<E>;
368
369/// Helper function to create Switch Transformer configuration
370pub fn switch_config(hidden_size: usize, num_experts: usize) -> MoEConfig {
371    MoEConfig {
372        hidden_size,
373        num_experts,
374        num_experts_per_token: 1, // Switch uses top-1
375        ..Default::default()
376    }
377}
378
379/// Helper function to create GLaM-style configuration
380pub fn glam_config(hidden_size: usize, num_experts: usize) -> MoEConfig {
381    MoEConfig {
382        hidden_size,
383        num_experts,
384        num_experts_per_token: 2, // GLaM uses top-2
385        ..Default::default()
386    }
387}
388
/// Expert parallel processing for distributed training
///
/// Holds only the experts resident on this rank, plus a mapping from global
/// expert ids to local indices into `local_experts`.
pub struct ExpertParallel<E: Expert> {
    local_experts: Vec<E>,
    expert_mapping: HashMap<usize, usize>, // global_id -> local_id
    #[allow(dead_code)] // not read yet; presumably for future cross-rank dispatch
    rank: usize,
    #[allow(dead_code)] // not read yet; presumably for future cross-rank dispatch
    world_size: usize,
}
398
399impl<E: Expert> ExpertParallel<E> {
400    pub fn new(experts: Vec<E>, rank: usize, world_size: usize) -> Self {
401        let mut expert_mapping = HashMap::new();
402        for (local_id, expert) in experts.iter().enumerate() {
403            expert_mapping.insert(expert.expert_id(), local_id);
404        }
405
406        Self {
407            local_experts: experts,
408            expert_mapping,
409            rank,
410            world_size,
411        }
412    }
413
414    /// Check if an expert is available locally
415    pub fn has_expert(&self, expert_id: usize) -> bool {
416        self.expert_mapping.contains_key(&expert_id)
417    }
418
419    /// Forward pass for local experts only
420    pub fn forward_local(&self, expert_id: usize, input: &Tensor) -> Result<Option<Tensor>> {
421        if let Some(&local_id) = self.expert_mapping.get(&expert_id) {
422            let output = self.local_experts[local_id].forward(input.clone())?;
423            Ok(Some(output))
424        } else {
425            Ok(None)
426        }
427    }
428}