1use serde::{Deserialize, Serialize};
8
/// Configuration for expert parallelism in a Mixture-of-Experts (MoE) setup.
///
/// Construct via [`Default`], or the `small_scale` / `large_scale` /
/// `inference` presets, and check invariants with `validate` before use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertParallelismConfig {
    /// Total number of experts (must be > 0; enforced by `validate`).
    pub num_experts: usize,

    /// Experts each token is routed to (top-k); must be in 1..=num_experts.
    pub num_experts_per_token: usize,

    /// Multiplier over the even-split per-expert token count when sizing
    /// expert buffers (see `calculate_expert_capacity`); must be positive.
    pub capacity_factor: f32,

    /// Coefficient for the auxiliary load-balancing loss term (>= 0).
    pub load_balance_loss_coeff: f32,

    /// Coefficient for the router z-loss regularizer (>= 0).
    pub router_z_loss_coeff: f32,

    /// Dropout probability applied within experts, in [0.0, 1.0].
    pub expert_dropout: f32,

    /// Enables load balancing of tokens across experts during routing.
    pub enable_load_balancing: bool,

    /// How experts are placed across devices.
    pub sharding_strategy: ExpertShardingStrategy,

    /// Optional cap on the per-expert batch size; `None` means uncapped.
    pub max_expert_batch_size: Option<usize>,

    /// Enables gradient accumulation across micro-batches.
    pub enable_gradient_accumulation: bool,

    /// Micro-batches to accumulate before an update (must be > 0).
    pub gradient_accumulation_steps: usize,

    /// Weight-initialization scheme used for experts.
    pub initialization_strategy: ExpertInitStrategy,

    /// Enables periodic synchronization of expert weights across replicas.
    pub enable_expert_sync: bool,

    /// Steps between expert synchronizations (must be > 0).
    pub sync_frequency: usize,

    /// Optional gate-network configuration; `None` uses downstream defaults.
    pub gate_network: Option<GateNetworkConfig>,

    /// Optional runtime load-balancing configuration.
    pub load_balancing: Option<LoadBalancingConfig>,

    /// Optional expert-migration configuration.
    pub migration: Option<ExpertMigrationConfig>,

    /// Enables migrating experts between devices at runtime.
    pub enable_expert_migration: bool,

    /// Load-imbalance level that triggers a migration
    /// (presumably a ratio in [0, 1]; default 0.3 — confirm with consumers).
    pub migration_threshold: f32,

    /// Memory budget per expert, in megabytes (not consumed in this file).
    pub memory_per_expert_mb: usize,

    /// Overlap communication with computation where supported.
    pub communication_overlap: bool,

    /// Compress gradients before communication.
    pub gradient_compression: bool,
}
141
impl Default for ExpertParallelismConfig {
    /// Balanced defaults for a modest MoE run: 8 experts, top-2 routing,
    /// model-parallel sharding, training extras (accumulation, sync,
    /// migration) disabled.
    fn default() -> Self {
        Self {
            num_experts: 8,
            num_experts_per_token: 2,
            capacity_factor: 1.25,
            load_balance_loss_coeff: 0.01,
            router_z_loss_coeff: 0.001,
            expert_dropout: 0.0,
            enable_load_balancing: true,
            sharding_strategy: ExpertShardingStrategy::ModelParallel,
            max_expert_batch_size: None,
            enable_gradient_accumulation: false,
            gradient_accumulation_steps: 1,
            initialization_strategy: ExpertInitStrategy::Xavier,
            enable_expert_sync: false,
            sync_frequency: 100,
            gate_network: None,
            load_balancing: None,
            migration: None,
            enable_expert_migration: false,
            migration_threshold: 0.3,
            memory_per_expert_mb: 512,
            communication_overlap: true,
            gradient_compression: false,
        }
    }
}
170
171impl ExpertParallelismConfig {
172 pub fn new() -> Self {
174 Self::default()
175 }
176
177 pub fn small_scale() -> Self {
183 Self {
184 num_experts: 8,
185 num_experts_per_token: 2,
186 capacity_factor: 1.25,
187 load_balance_loss_coeff: 0.01,
188 sharding_strategy: ExpertShardingStrategy::DataParallel,
189 ..Default::default()
190 }
191 }
192
193 pub fn large_scale() -> Self {
199 Self {
200 num_experts: 128,
201 num_experts_per_token: 2,
202 capacity_factor: 1.5,
203 load_balance_loss_coeff: 0.001,
204 sharding_strategy: ExpertShardingStrategy::ModelParallel,
205 enable_gradient_accumulation: true,
206 gradient_accumulation_steps: 4,
207 enable_expert_sync: true,
208 sync_frequency: 50,
209 ..Default::default()
210 }
211 }
212
213 pub fn inference() -> Self {
219 Self {
220 expert_dropout: 0.0,
221 enable_load_balancing: false,
222 enable_gradient_accumulation: false,
223 enable_expert_sync: false,
224 ..Default::default()
225 }
226 }
227
228 pub fn validate(&self) -> Result<(), String> {
234 if self.num_experts == 0 {
235 return Err("Number of experts must be greater than 0".to_string());
236 }
237
238 if self.num_experts_per_token == 0 || self.num_experts_per_token > self.num_experts {
239 return Err(
240 "Number of experts per token must be between 1 and num_experts".to_string(),
241 );
242 }
243
244 if self.capacity_factor <= 0.0 {
245 return Err("Capacity factor must be positive".to_string());
246 }
247
248 if self.load_balance_loss_coeff < 0.0 {
249 return Err("Load balance loss coefficient must be non-negative".to_string());
250 }
251
252 if self.router_z_loss_coeff < 0.0 {
253 return Err("Router z-loss coefficient must be non-negative".to_string());
254 }
255
256 if self.expert_dropout < 0.0 || self.expert_dropout > 1.0 {
257 return Err("Expert dropout must be between 0.0 and 1.0".to_string());
258 }
259
260 if self.gradient_accumulation_steps == 0 {
261 return Err("Gradient accumulation steps must be greater than 0".to_string());
262 }
263
264 if self.sync_frequency == 0 {
265 return Err("Sync frequency must be greater than 0".to_string());
266 }
267
268 Ok(())
269 }
270
271 pub fn calculate_expert_capacity(&self, total_tokens: usize) -> usize {
281 let tokens_per_expert = (total_tokens * self.num_experts_per_token) / self.num_experts;
282 (tokens_per_expert as f32 * self.capacity_factor).ceil() as usize
283 }
284
285 pub fn recommended_num_devices(&self) -> usize {
291 match self.sharding_strategy {
292 ExpertShardingStrategy::DataParallel => 1,
293 ExpertShardingStrategy::ModelParallel => self.num_experts.min(64),
294 ExpertShardingStrategy::Hybrid => (self.num_experts / 4).clamp(2, 16),
295 ExpertShardingStrategy::Dynamic => (self.num_experts / 2).clamp(4, 32),
296 }
297 }
298}
299
/// Strategy for distributing experts across devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExpertShardingStrategy {
    /// Every device holds a full replica of all experts.
    DataParallel,

    /// Experts are partitioned across devices.
    ModelParallel,

    /// Mixture of replicated and partitioned experts.
    Hybrid,

    /// Expert placement adjusted at runtime based on load.
    Dynamic,
}
330
331impl ExpertShardingStrategy {
332 pub fn description(&self) -> &'static str {
338 match self {
339 Self::DataParallel => "All experts replicated on each device",
340 Self::ModelParallel => "Experts partitioned across devices",
341 Self::Hybrid => "Mix of replicated and partitioned experts",
342 Self::Dynamic => "Dynamic expert placement based on load",
343 }
344 }
345
346 pub fn requires_load_balancing(&self) -> bool {
352 matches!(self, Self::ModelParallel | Self::Hybrid | Self::Dynamic)
353 }
354
355 pub fn supports_migration(&self) -> bool {
361 matches!(self, Self::Hybrid | Self::Dynamic)
362 }
363}
364
/// Architecture hyperparameters for a single expert feed-forward network.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertParameters {
    /// Input feature dimension (must be > 0; see `validate`).
    pub input_dim: usize,

    /// Hidden-layer dimension (must be > 0).
    pub hidden_dim: usize,

    /// Output feature dimension (must be > 0).
    pub output_dim: usize,

    /// Activation function name; one of "relu", "gelu", "swish", "tanh",
    /// "leaky_relu", "elu" (enforced by `validate`).
    pub activation: String,

    /// Number of linear layers in the expert (must be > 0).
    pub num_layers: usize,

    /// Dropout probability in [0.0, 1.0].
    pub dropout: f32,

    /// Whether linear layers carry bias terms.
    pub use_bias: bool,

    /// Epsilon for layer normalization (must be positive).
    pub layer_norm_eps: f32,

    /// Scale for weight initialization (must be positive).
    pub init_scale: f32,
}
403
impl Default for ExpertParameters {
    /// Defaults matching a 512-dim model with a 4x-expanded ReLU FFN.
    fn default() -> Self {
        Self {
            input_dim: 512,
            hidden_dim: 2048,
            output_dim: 512,
            activation: "relu".to_string(),
            num_layers: 2,
            dropout: 0.1,
            use_bias: true,
            layer_norm_eps: 1e-5,
            init_scale: 0.02,
        }
    }
}
419
420impl ExpertParameters {
421 pub fn new(input_dim: usize, hidden_dim: usize, output_dim: usize) -> Self {
423 Self {
424 input_dim,
425 hidden_dim,
426 output_dim,
427 ..Default::default()
428 }
429 }
430
431 pub fn transformer_ffn(model_dim: usize) -> Self {
441 Self {
442 input_dim: model_dim,
443 hidden_dim: model_dim * 4,
444 output_dim: model_dim,
445 activation: "gelu".to_string(),
446 ..Default::default()
447 }
448 }
449
450 pub fn lightweight(model_dim: usize) -> Self {
460 Self {
461 input_dim: model_dim,
462 hidden_dim: model_dim * 2,
463 output_dim: model_dim,
464 num_layers: 1,
465 dropout: 0.05,
466 ..Default::default()
467 }
468 }
469
470 pub fn validate(&self) -> Result<(), String> {
476 if self.input_dim == 0 {
477 return Err("Input dimension must be greater than 0".to_string());
478 }
479
480 if self.hidden_dim == 0 {
481 return Err("Hidden dimension must be greater than 0".to_string());
482 }
483
484 if self.output_dim == 0 {
485 return Err("Output dimension must be greater than 0".to_string());
486 }
487
488 if self.num_layers == 0 {
489 return Err("Number of layers must be greater than 0".to_string());
490 }
491
492 if self.dropout < 0.0 || self.dropout > 1.0 {
493 return Err("Dropout must be between 0.0 and 1.0".to_string());
494 }
495
496 if self.layer_norm_eps <= 0.0 {
497 return Err("Layer norm epsilon must be positive".to_string());
498 }
499
500 if self.init_scale <= 0.0 {
501 return Err("Initialization scale must be positive".to_string());
502 }
503
504 let valid_activations = ["relu", "gelu", "swish", "tanh", "leaky_relu", "elu"];
505 if !valid_activations.contains(&self.activation.as_str()) {
506 return Err(format!(
507 "Unsupported activation function: {}. Supported: {:?}",
508 self.activation, valid_activations
509 ));
510 }
511
512 Ok(())
513 }
514
515 pub fn parameter_count(&self) -> usize {
521 if self.num_layers == 1 {
522 let layer1_params =
524 self.input_dim * self.hidden_dim + if self.use_bias { self.hidden_dim } else { 0 };
525 let layer2_params =
526 self.hidden_dim * self.output_dim + if self.use_bias { self.output_dim } else { 0 };
527 layer1_params + layer2_params
528 } else {
529 let input_layer =
531 self.input_dim * self.hidden_dim + if self.use_bias { self.hidden_dim } else { 0 };
532 let hidden_layers = (self.num_layers - 2)
533 * (self.hidden_dim * self.hidden_dim
534 + if self.use_bias { self.hidden_dim } else { 0 });
535 let output_layer =
536 self.hidden_dim * self.output_dim + if self.use_bias { self.output_dim } else { 0 };
537 input_layer + hidden_layers + output_layer
538 }
539 }
540}
541
/// Weight-initialization scheme for expert networks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExpertInitStrategy {
    /// Xavier/Glorot initialization.
    Xavier,
    /// Kaiming/He initialization.
    Kaiming,
    /// Standard normal distribution.
    Normal,
    /// Uniform distribution.
    Uniform,
    /// Truncated normal distribution.
    TruncatedNormal,
}
556
557impl ExpertInitStrategy {
558 pub fn description(&self) -> &'static str {
560 match self {
561 Self::Xavier => "Xavier/Glorot initialization for balanced gradients",
562 Self::Kaiming => "Kaiming/He initialization for ReLU networks",
563 Self::Normal => "Standard normal distribution",
564 Self::Uniform => "Uniform distribution",
565 Self::TruncatedNormal => "Truncated normal distribution",
566 }
567 }
568}
569
/// Configuration of the gate (router) network that assigns tokens to experts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateNetworkConfig {
    /// Optional hierarchical (multi-level) gating; `None` means a flat gate.
    pub hierarchical: Option<HierarchicalGateConfig>,

    /// Whether gate weights are learned (as opposed to fixed routing).
    pub enable_learned_gates: bool,

    /// Dropout applied inside the gate network, in [0.0, 1.0].
    pub gate_dropout: f32,

    /// Number of layers in the gate network.
    pub num_gate_layers: usize,
}
585
impl Default for GateNetworkConfig {
    /// Flat, learned two-layer gate with light dropout.
    fn default() -> Self {
        Self {
            hierarchical: None,
            enable_learned_gates: true,
            gate_dropout: 0.1,
            num_gate_layers: 2,
        }
    }
}
596
/// Configuration for hierarchical (multi-level) gating, where tokens are
/// first routed to expert groups and then to experts within a group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalGateConfig {
    /// Number of routing levels in the hierarchy.
    pub levels: usize,

    /// Experts per group at the leaf level.
    pub experts_per_group: usize,

    /// Hidden dimension of the gate networks.
    pub gate_hidden_dim: usize,

    /// Whether group membership is learned rather than fixed.
    pub use_learned_grouping: bool,

    /// How experts are assigned to groups.
    pub grouping_strategy: GroupingStrategy,
}
615
impl Default for HierarchicalGateConfig {
    /// Two-level hierarchy with 8 experts per group and load-based grouping.
    fn default() -> Self {
        Self {
            levels: 2,
            experts_per_group: 8,
            gate_hidden_dim: 512,
            use_learned_grouping: true,
            grouping_strategy: GroupingStrategy::LoadBased,
        }
    }
}
627
/// How experts are assigned to groups for hierarchical gating.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GroupingStrategy {
    /// Group by observed expert load.
    LoadBased,
    /// Group by expert similarity.
    SimilarityBased,
    /// Fixed assignment chosen up front.
    Static,
    /// Assignment adjusted at runtime.
    Dynamic,
}
640
/// Runtime load-balancing policy for experts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadBalancingConfig {
    /// Enables automatic rebalancing when imbalance is detected.
    pub enable_auto_balancing: bool,

    /// Imbalance level that triggers rebalancing
    /// (presumably a ratio in [0, 1] — not validated in this file).
    pub imbalance_threshold: f32,

    /// Steps between imbalance checks.
    pub check_frequency: usize,

    /// Maximum number of expert migrations allowed to run concurrently.
    pub max_concurrent_migrations: usize,

    /// Exponential-smoothing factor applied to load measurements.
    pub load_smoothing_factor: f32,
}
659
impl Default for LoadBalancingConfig {
    /// Auto-balancing on, checked every 50 steps, at most 2 concurrent
    /// migrations, heavily smoothed load signal.
    fn default() -> Self {
        Self {
            enable_auto_balancing: true,
            imbalance_threshold: 0.3,
            check_frequency: 50,
            max_concurrent_migrations: 2,
            load_smoothing_factor: 0.9,
        }
    }
}
671
/// Policy for migrating experts between devices at runtime.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertMigrationConfig {
    /// Master switch for expert migration.
    pub enable_migration: bool,

    /// Conditions that may trigger a migration.
    pub triggers: Vec<MigrationTrigger>,

    /// Migration strategies in order of preference.
    pub preferred_strategies: Vec<MigrationStrategy>,

    /// Minimum steps between migrations of the same expert
    /// (unit presumed to be steps — confirm with consumers).
    pub cooldown_period: usize,

    /// Maximum placement distance an expert may move per migration.
    pub max_migration_distance: usize,
}
690
impl Default for ExpertMigrationConfig {
    /// Migration disabled by default; conservative settings when enabled
    /// (load-imbalance trigger, gradual migration, distance 1).
    fn default() -> Self {
        Self {
            enable_migration: false,
            triggers: vec![MigrationTrigger::LoadImbalance],
            preferred_strategies: vec![MigrationStrategy::GradualMigration],
            cooldown_period: 100,
            max_migration_distance: 1,
        }
    }
}
702
/// Condition that can trigger an expert migration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationTrigger {
    /// Expert load has become too uneven across devices.
    LoadImbalance,
    /// A device is running low on memory.
    MemoryPressure,
    /// Throughput or latency has degraded.
    PerformanceDegradation,
    /// Fired on a fixed schedule.
    Periodic,
}
715
/// How an expert migration is carried out.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationStrategy {
    /// Move the expert incrementally.
    GradualMigration,
    /// Move the expert in a single step.
    CompleteMigration,
    /// Redistribute load without physically moving experts.
    LoadRedistribution,
    /// Combination of the above.
    Hybrid,
}
728
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults must be self-consistent and pass their own validation.
    #[test]
    fn test_expert_parallelism_config_default() {
        let cfg = ExpertParallelismConfig::default();
        assert_eq!(cfg.num_experts, 8);
        assert_eq!(cfg.num_experts_per_token, 2);
        assert_eq!(cfg.capacity_factor, 1.25);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_expert_parallelism_config_validation() {
        // Zero experts is rejected.
        let zero_experts = ExpertParallelismConfig {
            num_experts: 0,
            ..Default::default()
        };
        assert!(zero_experts.validate().is_err());

        // top-k larger than the expert pool is rejected.
        let oversubscribed = ExpertParallelismConfig {
            num_experts: 8,
            num_experts_per_token: 10,
            ..Default::default()
        };
        assert!(oversubscribed.validate().is_err());

        // Negative capacity factor is rejected.
        let bad_capacity = ExpertParallelismConfig {
            num_experts: 8,
            num_experts_per_token: 2,
            capacity_factor: -1.0,
            ..Default::default()
        };
        assert!(bad_capacity.validate().is_err());
    }

    #[test]
    fn test_expert_capacity_calculation() {
        // 1000 tokens * 2 routes / 8 experts = 250; 250 * 1.25 = 312.5 -> 313.
        let cfg = ExpertParallelismConfig::default();
        assert_eq!(cfg.calculate_expert_capacity(1000), 313);
    }

    #[test]
    fn test_sharding_strategy_properties() {
        assert!(ExpertShardingStrategy::ModelParallel.requires_load_balancing());
        assert!(!ExpertShardingStrategy::DataParallel.requires_load_balancing());
        assert!(ExpertShardingStrategy::Dynamic.supports_migration());
        assert!(!ExpertShardingStrategy::DataParallel.supports_migration());
    }

    #[test]
    fn test_expert_parameters_default() {
        let params = ExpertParameters::default();
        assert_eq!(params.input_dim, 512);
        assert_eq!(params.hidden_dim, 2048);
        assert_eq!(params.output_dim, 512);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_expert_parameters_transformer_ffn() {
        // Transformer preset: 4x expansion and GELU.
        let params = ExpertParameters::transformer_ffn(768);
        assert_eq!(params.input_dim, 768);
        assert_eq!(params.hidden_dim, 768 * 4);
        assert_eq!(params.output_dim, 768);
        assert_eq!(params.activation, "gelu");
    }

    #[test]
    fn test_expert_parameters_validation() {
        let zero_input = ExpertParameters {
            input_dim: 0,
            ..Default::default()
        };
        assert!(zero_input.validate().is_err());

        let bad_dropout = ExpertParameters {
            input_dim: 512,
            dropout: 1.5,
            ..Default::default()
        };
        assert!(bad_dropout.validate().is_err());

        let bad_activation = ExpertParameters {
            input_dim: 512,
            dropout: 0.1,
            activation: "invalid".to_string(),
            ..Default::default()
        };
        assert!(bad_activation.validate().is_err());
    }

    #[test]
    fn test_expert_parameters_parameter_count() {
        // Two layers with bias: (100*200 + 200) + (200*100 + 100) = 40300.
        let params = ExpertParameters::new(100, 200, 100);
        assert_eq!(params.parameter_count(), 40300);
    }

    #[test]
    fn test_preset_configs() {
        let small = ExpertParallelismConfig::small_scale();
        assert_eq!(small.num_experts, 8);
        assert_eq!(
            small.sharding_strategy,
            ExpertShardingStrategy::DataParallel
        );

        let large = ExpertParallelismConfig::large_scale();
        assert_eq!(large.num_experts, 128);
        assert!(large.enable_gradient_accumulation);

        let inference = ExpertParallelismConfig::inference();
        assert_eq!(inference.expert_dropout, 0.0);
        assert!(!inference.enable_load_balancing);
    }

    #[test]
    fn test_recommended_num_devices() {
        // ModelParallel recommends one device per expert, capped at 64.
        let cfg = ExpertParallelismConfig {
            num_experts: 32,
            sharding_strategy: ExpertShardingStrategy::ModelParallel,
            ..Default::default()
        };
        assert_eq!(cfg.recommended_num_devices(), 32);
    }
}