// torsh_distributed/expert_parallelism/config.rs

//! Configuration types for Expert Parallelism
//!
//! This module defines configuration structures and enums used throughout
//! the expert parallelism system, including sharding strategies, parameters,
//! and optimization settings.

use serde::{Deserialize, Serialize};

/// Expert parallelism configuration
///
/// This structure contains all the configuration parameters needed to set up
/// and run a Mixture of Experts (MoE) model with distributed expert parallelism.
///
/// Use [`Default`] or one of the presets (`small_scale`, `large_scale`,
/// `inference`) as a starting point, then override individual fields with
/// struct-update syntax. Call `validate()` before using a hand-built config.
///
/// # Examples
///
/// ```rust
/// use torsh_distributed::expert_parallelism::config::{ExpertParallelismConfig, ExpertShardingStrategy};
///
/// let config = ExpertParallelismConfig {
///     num_experts: 16,
///     num_experts_per_token: 2,
///     capacity_factor: 1.5,
///     sharding_strategy: ExpertShardingStrategy::ModelParallel,
///     ..Default::default()
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertParallelismConfig {
    /// Number of experts in the MoE layer
    ///
    /// This determines the total number of expert networks available for routing.
    /// Typical values range from 8 to 1024 depending on model size and requirements.
    pub num_experts: usize,

    /// Number of experts to activate per token (top-k)
    ///
    /// Each token is routed to the top-k experts based on router scores.
    /// Common values are 1, 2, or 4. Higher values increase computational cost
    /// but may improve model quality.
    pub num_experts_per_token: usize,

    /// Expert capacity factor (capacity = tokens_per_expert * capacity_factor)
    ///
    /// This factor determines how many tokens each expert can process.
    /// Values > 1.0 provide buffer capacity to handle load imbalance.
    /// Typical range: 1.0 to 2.0.
    pub capacity_factor: f32,

    /// Load balancing loss coefficient
    ///
    /// Weight for the auxiliary loss that encourages balanced expert utilization.
    /// Higher values enforce stronger load balancing but may hurt model quality.
    /// Typical range: 0.001 to 0.1.
    pub load_balance_loss_coeff: f32,

    /// Router z-loss coefficient (for numerical stability)
    ///
    /// Weight for the z-loss that encourages router logits to stay close to zero,
    /// improving numerical stability. Typical range: 0.0001 to 0.01.
    pub router_z_loss_coeff: f32,

    /// Enable expert dropout during training
    ///
    /// Probability of randomly dropping experts during training to improve
    /// robustness and prevent overfitting. Range: 0.0 to 1.0.
    pub expert_dropout: f32,

    /// Enable load balancing across devices
    ///
    /// When true, the system actively monitors and rebalances expert utilization
    /// across different devices to optimize resource usage.
    pub enable_load_balancing: bool,

    /// Expert sharding strategy
    ///
    /// Determines how experts are distributed across devices and processes.
    pub sharding_strategy: ExpertShardingStrategy,

    /// Maximum batch size for expert processing
    ///
    /// Limits the number of tokens that can be processed by a single expert
    /// in one forward pass. Helps control memory usage. `None` means unlimited.
    pub max_expert_batch_size: Option<usize>,

    /// Enable gradient accumulation across experts
    ///
    /// When true, gradients are accumulated across multiple expert invocations
    /// before updating parameters, which can improve training stability.
    pub enable_gradient_accumulation: bool,

    /// Number of gradient accumulation steps
    ///
    /// Only relevant when gradient accumulation is enabled. Must be >= 1
    /// (enforced by `validate()`).
    pub gradient_accumulation_steps: usize,

    /// Expert initialization strategy
    ///
    /// Method used to initialize expert parameters.
    pub initialization_strategy: ExpertInitStrategy,

    /// Enable expert synchronization
    ///
    /// When true, experts synchronize their parameters periodically during training.
    pub enable_expert_sync: bool,

    /// Synchronization frequency (in steps)
    ///
    /// How often to synchronize expert parameters when synchronization is enabled.
    /// Must be >= 1 (enforced by `validate()`).
    pub sync_frequency: usize,

    /// Gate network configuration
    ///
    /// Optional configuration for hierarchical or advanced gate networks.
    pub gate_network: Option<GateNetworkConfig>,

    /// Load balancing configuration
    ///
    /// Configuration for expert load balancing and migration.
    pub load_balancing: Option<LoadBalancingConfig>,

    /// Migration configuration
    ///
    /// Configuration for expert migration strategies and triggers.
    pub migration: Option<ExpertMigrationConfig>,

    /// Enable expert migration (simplified flag)
    ///
    /// Coarse on/off switch; detailed migration behavior lives in `migration`.
    pub enable_expert_migration: bool,

    /// Migration threshold for triggering migrations
    ///
    /// NOTE(review): presumably a load-imbalance fraction in [0.0, 1.0]
    /// (default 0.3) — not checked by `validate()`; confirm against the
    /// migration logic that consumes it.
    pub migration_threshold: f32,

    /// Memory allocated per expert (in MB)
    pub memory_per_expert_mb: usize,

    /// Enable communication overlap
    pub communication_overlap: bool,

    /// Enable gradient compression
    pub gradient_compression: bool,
}
141
142impl Default for ExpertParallelismConfig {
143    fn default() -> Self {
144        Self {
145            num_experts: 8,
146            num_experts_per_token: 2,
147            capacity_factor: 1.25,
148            load_balance_loss_coeff: 0.01,
149            router_z_loss_coeff: 0.001,
150            expert_dropout: 0.0,
151            enable_load_balancing: true,
152            sharding_strategy: ExpertShardingStrategy::ModelParallel,
153            max_expert_batch_size: None,
154            enable_gradient_accumulation: false,
155            gradient_accumulation_steps: 1,
156            initialization_strategy: ExpertInitStrategy::Xavier,
157            enable_expert_sync: false,
158            sync_frequency: 100,
159            gate_network: None,
160            load_balancing: None,
161            migration: None,
162            enable_expert_migration: false,
163            migration_threshold: 0.3,
164            memory_per_expert_mb: 512,
165            communication_overlap: true,
166            gradient_compression: false,
167        }
168    }
169}
170
171impl ExpertParallelismConfig {
172    /// Create a new configuration with default values
173    pub fn new() -> Self {
174        Self::default()
175    }
176
177    /// Create a configuration optimized for small-scale deployment
178    ///
179    /// # Returns
180    ///
181    /// A configuration suitable for models with 8-16 experts
182    pub fn small_scale() -> Self {
183        Self {
184            num_experts: 8,
185            num_experts_per_token: 2,
186            capacity_factor: 1.25,
187            load_balance_loss_coeff: 0.01,
188            sharding_strategy: ExpertShardingStrategy::DataParallel,
189            ..Default::default()
190        }
191    }
192
193    /// Create a configuration optimized for large-scale deployment
194    ///
195    /// # Returns
196    ///
197    /// A configuration suitable for models with 64+ experts
198    pub fn large_scale() -> Self {
199        Self {
200            num_experts: 128,
201            num_experts_per_token: 2,
202            capacity_factor: 1.5,
203            load_balance_loss_coeff: 0.001,
204            sharding_strategy: ExpertShardingStrategy::ModelParallel,
205            enable_gradient_accumulation: true,
206            gradient_accumulation_steps: 4,
207            enable_expert_sync: true,
208            sync_frequency: 50,
209            ..Default::default()
210        }
211    }
212
213    /// Create a configuration optimized for inference
214    ///
215    /// # Returns
216    ///
217    /// A configuration with settings optimized for inference workloads
218    pub fn inference() -> Self {
219        Self {
220            expert_dropout: 0.0,
221            enable_load_balancing: false,
222            enable_gradient_accumulation: false,
223            enable_expert_sync: false,
224            ..Default::default()
225        }
226    }
227
228    /// Validate the configuration parameters
229    ///
230    /// # Returns
231    ///
232    /// Result indicating whether the configuration is valid
233    pub fn validate(&self) -> Result<(), String> {
234        if self.num_experts == 0 {
235            return Err("Number of experts must be greater than 0".to_string());
236        }
237
238        if self.num_experts_per_token == 0 || self.num_experts_per_token > self.num_experts {
239            return Err(
240                "Number of experts per token must be between 1 and num_experts".to_string(),
241            );
242        }
243
244        if self.capacity_factor <= 0.0 {
245            return Err("Capacity factor must be positive".to_string());
246        }
247
248        if self.load_balance_loss_coeff < 0.0 {
249            return Err("Load balance loss coefficient must be non-negative".to_string());
250        }
251
252        if self.router_z_loss_coeff < 0.0 {
253            return Err("Router z-loss coefficient must be non-negative".to_string());
254        }
255
256        if self.expert_dropout < 0.0 || self.expert_dropout > 1.0 {
257            return Err("Expert dropout must be between 0.0 and 1.0".to_string());
258        }
259
260        if self.gradient_accumulation_steps == 0 {
261            return Err("Gradient accumulation steps must be greater than 0".to_string());
262        }
263
264        if self.sync_frequency == 0 {
265            return Err("Sync frequency must be greater than 0".to_string());
266        }
267
268        Ok(())
269    }
270
271    /// Calculate the effective expert capacity
272    ///
273    /// # Arguments
274    ///
275    /// * `total_tokens` - Total number of tokens in the batch
276    ///
277    /// # Returns
278    ///
279    /// The effective capacity per expert
280    pub fn calculate_expert_capacity(&self, total_tokens: usize) -> usize {
281        let tokens_per_expert = (total_tokens * self.num_experts_per_token) / self.num_experts;
282        (tokens_per_expert as f32 * self.capacity_factor).ceil() as usize
283    }
284
285    /// Get the recommended number of devices for this configuration
286    ///
287    /// # Returns
288    ///
289    /// Recommended number of devices based on the sharding strategy
290    pub fn recommended_num_devices(&self) -> usize {
291        match self.sharding_strategy {
292            ExpertShardingStrategy::DataParallel => 1,
293            ExpertShardingStrategy::ModelParallel => self.num_experts.min(64),
294            ExpertShardingStrategy::Hybrid => (self.num_experts / 4).clamp(2, 16),
295            ExpertShardingStrategy::Dynamic => (self.num_experts / 2).clamp(4, 32),
296        }
297    }
298}
299
/// Expert sharding strategies
///
/// Defines how experts are distributed across devices and processes
/// in a distributed training setup.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExpertShardingStrategy {
    /// Each device holds all experts (data parallel)
    ///
    /// All experts are replicated on each device. This strategy is suitable
    /// for smaller models or when communication costs are high.
    DataParallel,

    /// Each device holds a subset of experts (model parallel)
    ///
    /// Experts are partitioned across devices. This strategy is suitable
    /// for large models where memory constraints require expert sharding.
    ModelParallel,

    /// Hybrid: some experts replicated, others sharded
    ///
    /// Combines data and model parallelism. Frequently used experts may be
    /// replicated while less common experts are sharded.
    Hybrid,

    /// Dynamic: expert placement adapts to load
    ///
    /// Expert placement is dynamically adjusted based on runtime load patterns.
    /// This strategy requires more sophisticated load monitoring and migration.
    Dynamic,
}
330
331impl ExpertShardingStrategy {
332    /// Get a description of the sharding strategy
333    ///
334    /// # Returns
335    ///
336    /// A string describing the strategy
337    pub fn description(&self) -> &'static str {
338        match self {
339            Self::DataParallel => "All experts replicated on each device",
340            Self::ModelParallel => "Experts partitioned across devices",
341            Self::Hybrid => "Mix of replicated and partitioned experts",
342            Self::Dynamic => "Dynamic expert placement based on load",
343        }
344    }
345
346    /// Check if this strategy requires load balancing
347    ///
348    /// # Returns
349    ///
350    /// True if the strategy benefits from active load balancing
351    pub fn requires_load_balancing(&self) -> bool {
352        matches!(self, Self::ModelParallel | Self::Hybrid | Self::Dynamic)
353    }
354
355    /// Check if this strategy supports dynamic migration
356    ///
357    /// # Returns
358    ///
359    /// True if experts can be migrated between devices
360    pub fn supports_migration(&self) -> bool {
361        matches!(self, Self::Hybrid | Self::Dynamic)
362    }
363}
364
/// Expert parameter configuration
///
/// Defines the architecture parameters for individual expert networks
/// (feed-forward stacks of linear layers with a nonlinearity).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertParameters {
    /// Input dimension of the expert network
    pub input_dim: usize,

    /// Hidden dimension of the expert network
    ///
    /// Typically 4x the input dimension for transformer-style experts.
    pub hidden_dim: usize,

    /// Output dimension of the expert network
    ///
    /// Usually matches the input dimension for residual connections.
    pub output_dim: usize,

    /// Activation function name
    ///
    /// Common choices: "relu", "gelu", "swish", "tanh".
    /// `validate()` additionally accepts "leaky_relu" and "elu".
    pub activation: String,

    /// Number of hidden layers in the expert
    ///
    /// Must be >= 1 (enforced by `validate()`).
    pub num_layers: usize,

    /// Dropout probability for expert layers
    ///
    /// Range: 0.0 to 1.0 (enforced by `validate()`).
    pub dropout: f32,

    /// Whether to use bias in linear layers
    pub use_bias: bool,

    /// Layer normalization configuration
    ///
    /// Epsilon added for numerical stability; must be positive.
    pub layer_norm_eps: f32,

    /// Weight initialization scale
    ///
    /// Must be positive (enforced by `validate()`).
    pub init_scale: f32,
}
403
404impl Default for ExpertParameters {
405    fn default() -> Self {
406        Self {
407            input_dim: 512,
408            hidden_dim: 2048,
409            output_dim: 512,
410            activation: "relu".to_string(),
411            num_layers: 2,
412            dropout: 0.1,
413            use_bias: true,
414            layer_norm_eps: 1e-5,
415            init_scale: 0.02,
416        }
417    }
418}
419
420impl ExpertParameters {
421    /// Create a new expert parameter configuration
422    pub fn new(input_dim: usize, hidden_dim: usize, output_dim: usize) -> Self {
423        Self {
424            input_dim,
425            hidden_dim,
426            output_dim,
427            ..Default::default()
428        }
429    }
430
431    /// Create parameters for a transformer-style expert
432    ///
433    /// # Arguments
434    ///
435    /// * `model_dim` - The model dimension (input/output dimension)
436    ///
437    /// # Returns
438    ///
439    /// Parameters configured for transformer-style FFN experts
440    pub fn transformer_ffn(model_dim: usize) -> Self {
441        Self {
442            input_dim: model_dim,
443            hidden_dim: model_dim * 4,
444            output_dim: model_dim,
445            activation: "gelu".to_string(),
446            ..Default::default()
447        }
448    }
449
450    /// Create parameters for a lightweight expert
451    ///
452    /// # Arguments
453    ///
454    /// * `model_dim` - The model dimension
455    ///
456    /// # Returns
457    ///
458    /// Parameters configured for lightweight experts with reduced capacity
459    pub fn lightweight(model_dim: usize) -> Self {
460        Self {
461            input_dim: model_dim,
462            hidden_dim: model_dim * 2,
463            output_dim: model_dim,
464            num_layers: 1,
465            dropout: 0.05,
466            ..Default::default()
467        }
468    }
469
470    /// Validate the parameter configuration
471    ///
472    /// # Returns
473    ///
474    /// Result indicating whether the parameters are valid
475    pub fn validate(&self) -> Result<(), String> {
476        if self.input_dim == 0 {
477            return Err("Input dimension must be greater than 0".to_string());
478        }
479
480        if self.hidden_dim == 0 {
481            return Err("Hidden dimension must be greater than 0".to_string());
482        }
483
484        if self.output_dim == 0 {
485            return Err("Output dimension must be greater than 0".to_string());
486        }
487
488        if self.num_layers == 0 {
489            return Err("Number of layers must be greater than 0".to_string());
490        }
491
492        if self.dropout < 0.0 || self.dropout > 1.0 {
493            return Err("Dropout must be between 0.0 and 1.0".to_string());
494        }
495
496        if self.layer_norm_eps <= 0.0 {
497            return Err("Layer norm epsilon must be positive".to_string());
498        }
499
500        if self.init_scale <= 0.0 {
501            return Err("Initialization scale must be positive".to_string());
502        }
503
504        let valid_activations = ["relu", "gelu", "swish", "tanh", "leaky_relu", "elu"];
505        if !valid_activations.contains(&self.activation.as_str()) {
506            return Err(format!(
507                "Unsupported activation function: {}. Supported: {:?}",
508                self.activation, valid_activations
509            ));
510        }
511
512        Ok(())
513    }
514
515    /// Calculate the total number of parameters for this expert configuration
516    ///
517    /// # Returns
518    ///
519    /// Estimated number of parameters
520    pub fn parameter_count(&self) -> usize {
521        if self.num_layers == 1 {
522            // Single layer: input -> hidden -> output
523            let layer1_params =
524                self.input_dim * self.hidden_dim + if self.use_bias { self.hidden_dim } else { 0 };
525            let layer2_params =
526                self.hidden_dim * self.output_dim + if self.use_bias { self.output_dim } else { 0 };
527            layer1_params + layer2_params
528        } else {
529            // Multiple layers
530            let input_layer =
531                self.input_dim * self.hidden_dim + if self.use_bias { self.hidden_dim } else { 0 };
532            let hidden_layers = (self.num_layers - 2)
533                * (self.hidden_dim * self.hidden_dim
534                    + if self.use_bias { self.hidden_dim } else { 0 });
535            let output_layer =
536                self.hidden_dim * self.output_dim + if self.use_bias { self.output_dim } else { 0 };
537            input_layer + hidden_layers + output_layer
538        }
539    }
540}
541
/// Expert initialization strategies
///
/// Method used to initialize expert network parameters before training.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExpertInitStrategy {
    /// Xavier/Glorot initialization
    Xavier,
    /// Kaiming/He initialization
    Kaiming,
    /// Normal distribution with specified std
    Normal,
    /// Uniform distribution
    Uniform,
    /// Truncated normal distribution
    TruncatedNormal,
}
556
557impl ExpertInitStrategy {
558    /// Get a description of the initialization strategy
559    pub fn description(&self) -> &'static str {
560        match self {
561            Self::Xavier => "Xavier/Glorot initialization for balanced gradients",
562            Self::Kaiming => "Kaiming/He initialization for ReLU networks",
563            Self::Normal => "Standard normal distribution",
564            Self::Uniform => "Uniform distribution",
565            Self::TruncatedNormal => "Truncated normal distribution",
566        }
567    }
568}
569
/// Gate network configuration for hierarchical expert routing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateNetworkConfig {
    /// Hierarchical gate configuration
    ///
    /// When `None`, a flat (single-level) gate is used.
    pub hierarchical: Option<HierarchicalGateConfig>,

    /// Enable learned gate networks
    pub enable_learned_gates: bool,

    /// Gate network dropout
    ///
    /// NOTE(review): presumably a probability in [0.0, 1.0] — this field is
    /// not validated in this module; confirm with the gate implementation.
    pub gate_dropout: f32,

    /// Number of gate layers
    pub num_gate_layers: usize,
}
585
586impl Default for GateNetworkConfig {
587    fn default() -> Self {
588        Self {
589            hierarchical: None,
590            enable_learned_gates: true,
591            gate_dropout: 0.1,
592            num_gate_layers: 2,
593        }
594    }
595}
596
/// Hierarchical gate network configuration
///
/// Routing happens in multiple levels: tokens are first routed to expert
/// groups, then to experts within a group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalGateConfig {
    /// Number of hierarchical levels
    pub levels: usize,

    /// Number of experts per group at each level
    pub experts_per_group: usize,

    /// Hidden dimension for gate networks
    pub gate_hidden_dim: usize,

    /// Enable learned expert grouping
    pub use_learned_grouping: bool,

    /// Group assignment strategy
    pub grouping_strategy: GroupingStrategy,
}
615
616impl Default for HierarchicalGateConfig {
617    fn default() -> Self {
618        Self {
619            levels: 2,
620            experts_per_group: 8,
621            gate_hidden_dim: 512,
622            use_learned_grouping: true,
623            grouping_strategy: GroupingStrategy::LoadBased,
624        }
625    }
626}
627
/// Expert grouping strategies for hierarchical gates
///
/// Determines how experts are assigned to groups when hierarchical
/// routing is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GroupingStrategy {
    /// Group experts based on current load
    LoadBased,
    /// Group experts based on similarity
    SimilarityBased,
    /// Use static expert grouping
    Static,
    /// Dynamic grouping based on routing patterns
    Dynamic,
}
640
/// Load balancing configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadBalancingConfig {
    /// Enable automatic load balancing
    pub enable_auto_balancing: bool,

    /// Load imbalance threshold for triggering rebalancing
    ///
    /// NOTE(review): presumably a fraction in [0.0, 1.0] (default 0.3) —
    /// confirm against the balancer that consumes it.
    pub imbalance_threshold: f32,

    /// Frequency of load balancing checks (in steps)
    pub check_frequency: usize,

    /// Maximum number of concurrent migrations
    pub max_concurrent_migrations: usize,

    /// Load smoothing factor for load history
    ///
    /// NOTE(review): looks like an exponential-moving-average weight
    /// (default 0.9) — confirm with the load-tracking code.
    pub load_smoothing_factor: f32,
}
659
660impl Default for LoadBalancingConfig {
661    fn default() -> Self {
662        Self {
663            enable_auto_balancing: true,
664            imbalance_threshold: 0.3,
665            check_frequency: 50,
666            max_concurrent_migrations: 2,
667            load_smoothing_factor: 0.9,
668        }
669    }
670}
671
/// Expert migration configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertMigrationConfig {
    /// Enable expert migration
    pub enable_migration: bool,

    /// Migration trigger conditions
    ///
    /// Any of the listed conditions may initiate a migration.
    pub triggers: Vec<MigrationTrigger>,

    /// Migration strategy preferences
    ///
    /// NOTE(review): presumably tried in listed order of preference —
    /// confirm with the migration executor.
    pub preferred_strategies: Vec<MigrationStrategy>,

    /// Migration cooldown period (in steps)
    pub cooldown_period: usize,

    /// Maximum migration distance (number of devices)
    pub max_migration_distance: usize,
}
690
691impl Default for ExpertMigrationConfig {
692    fn default() -> Self {
693        Self {
694            enable_migration: false,
695            triggers: vec![MigrationTrigger::LoadImbalance],
696            preferred_strategies: vec![MigrationStrategy::GradualMigration],
697            cooldown_period: 100,
698            max_migration_distance: 1,
699        }
700    }
701}
702
/// Migration trigger conditions
///
/// Conditions that may initiate an expert migration (see
/// `ExpertMigrationConfig::triggers`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationTrigger {
    /// Trigger on load imbalance
    LoadImbalance,
    /// Trigger on memory pressure
    MemoryPressure,
    /// Trigger on performance degradation
    PerformanceDegradation,
    /// Trigger at regular intervals
    Periodic,
}
715
/// Migration strategies
///
/// How an expert (or its load) is moved between devices (see
/// `ExpertMigrationConfig::preferred_strategies`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationStrategy {
    /// Gradual parameter migration
    GradualMigration,
    /// Complete expert migration
    CompleteMigration,
    /// Load redistribution without migration
    LoadRedistribution,
    /// Hybrid approach
    Hybrid,
}
728
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults must be internally consistent: they should pass validation.
    #[test]
    fn test_expert_parallelism_config_default() {
        let config = ExpertParallelismConfig::default();
        assert_eq!(config.num_experts, 8);
        assert_eq!(config.num_experts_per_token, 2);
        assert_eq!(config.capacity_factor, 1.25);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_expert_parallelism_config_validation() {
        // Test invalid num_experts
        let config1 = ExpertParallelismConfig {
            num_experts: 0,
            ..Default::default()
        };
        assert!(config1.validate().is_err());

        // Test invalid num_experts_per_token (top-k exceeds expert count)
        let config2 = ExpertParallelismConfig {
            num_experts: 8,
            num_experts_per_token: 10,
            ..Default::default()
        };
        assert!(config2.validate().is_err());

        // Test invalid capacity_factor
        let config3 = ExpertParallelismConfig {
            num_experts: 8,
            num_experts_per_token: 2,
            capacity_factor: -1.0,
            ..Default::default()
        };
        assert!(config3.validate().is_err());
    }

    #[test]
    fn test_expert_capacity_calculation() {
        let config = ExpertParallelismConfig::default();
        let capacity = config.calculate_expert_capacity(1000);

        // With 8 experts, 2 experts per token, 1000 tokens total
        // tokens_per_expert = (1000 * 2) / 8 = 250
        // capacity = 250 * 1.25 = 312.5 -> 313
        assert_eq!(capacity, 313);
    }

    #[test]
    fn test_sharding_strategy_properties() {
        assert!(ExpertShardingStrategy::ModelParallel.requires_load_balancing());
        assert!(!ExpertShardingStrategy::DataParallel.requires_load_balancing());
        assert!(ExpertShardingStrategy::Dynamic.supports_migration());
        assert!(!ExpertShardingStrategy::DataParallel.supports_migration());
    }

    #[test]
    fn test_expert_parameters_default() {
        let params = ExpertParameters::default();
        assert_eq!(params.input_dim, 512);
        assert_eq!(params.hidden_dim, 2048);
        assert_eq!(params.output_dim, 512);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_expert_parameters_transformer_ffn() {
        let params = ExpertParameters::transformer_ffn(768);
        assert_eq!(params.input_dim, 768);
        assert_eq!(params.hidden_dim, 768 * 4);
        assert_eq!(params.output_dim, 768);
        assert_eq!(params.activation, "gelu");
    }

    #[test]
    fn test_expert_parameters_validation() {
        // Test invalid dimensions
        let params1 = ExpertParameters {
            input_dim: 0,
            ..Default::default()
        };
        assert!(params1.validate().is_err());

        // Test invalid dropout (outside [0, 1])
        let params2 = ExpertParameters {
            input_dim: 512,
            dropout: 1.5,
            ..Default::default()
        };
        assert!(params2.validate().is_err());

        // Test invalid activation (not in the supported list)
        let params3 = ExpertParameters {
            input_dim: 512,
            dropout: 0.1,
            activation: "invalid".to_string(),
            ..Default::default()
        };
        assert!(params3.validate().is_err());
    }

    #[test]
    fn test_expert_parameters_parameter_count() {
        let params = ExpertParameters::new(100, 200, 100);

        // Default num_layers = 2, i.e. input->hidden and hidden->output only:
        // Layer 1: 100 * 200 + 200 = 20200
        // Layer 2: 200 * 100 + 100 = 20100
        // Total: 40300
        let count = params.parameter_count();
        assert_eq!(count, 40300);
    }

    #[test]
    fn test_preset_configs() {
        let small = ExpertParallelismConfig::small_scale();
        assert_eq!(small.num_experts, 8);
        assert_eq!(
            small.sharding_strategy,
            ExpertShardingStrategy::DataParallel
        );

        let large = ExpertParallelismConfig::large_scale();
        assert_eq!(large.num_experts, 128);
        assert!(large.enable_gradient_accumulation);

        let inference = ExpertParallelismConfig::inference();
        assert_eq!(inference.expert_dropout, 0.0);
        assert!(!inference.enable_load_balancing);
    }

    #[test]
    fn test_recommended_num_devices() {
        let config = ExpertParallelismConfig {
            num_experts: 32,
            sharding_strategy: ExpertShardingStrategy::ModelParallel,
            ..Default::default()
        };

        // ModelParallel recommends min(num_experts, 64) devices.
        let num_devices = config.recommended_num_devices();
        assert_eq!(num_devices, 32);
    }
}