1use crate::distributed::ProcessGroup;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex, RwLock};
6use std::time::{Duration, Instant};
7use trustformers_core::tensor::Tensor;
8
/// Configuration describing how a long sequence is partitioned across
/// devices and how the resulting chunks communicate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SequenceParallelismConfig {
    /// Number of devices the sequence is split across (round-robin chunk
    /// assignment); must be >= 1 and <= world size.
    pub sequence_parallel_size: usize,
    /// Upper bound on the chunk length hosted by one device.
    pub max_sequence_length_per_device: usize,
    /// Tokens shared between adjacent chunks; must be smaller than
    /// `max_sequence_length_per_device` (validated in `SequenceParallelism::new`).
    pub overlap_size: usize,
    /// When true, interior chunks participate in cross-chunk attention
    /// communication (see `chunk_needs_attention_communication`).
    pub attention_communication_opt: bool,
    /// Collective pattern used to exchange attention state between chunks.
    pub communication_pattern: SequenceCommunicationPattern,
    /// Strategy used to pick chunk boundaries in `split_sequence`.
    pub splitting_strategy: SequenceSplittingStrategy,
    /// When true, `synchronize_gradients` all-reduces and averages gradients.
    pub sync_gradients: bool,
    /// Memory-optimization aggressiveness level — not consulted in this
    /// module; presumably read by the training loop. TODO confirm.
    pub memory_optimization: SequenceMemoryOptimization,
    /// Enables activation checkpointing — not consulted in this module;
    /// presumably read by the training loop. TODO confirm.
    pub use_checkpointing: bool,
}
36
37impl Default for SequenceParallelismConfig {
38 fn default() -> Self {
39 Self {
40 sequence_parallel_size: 1,
41 max_sequence_length_per_device: 2048,
42 overlap_size: 128,
43 attention_communication_opt: true,
44 communication_pattern: SequenceCommunicationPattern::RingAllReduce,
45 splitting_strategy: SequenceSplittingStrategy::EqualChunks,
46 sync_gradients: true,
47 memory_optimization: SequenceMemoryOptimization::Medium,
48 use_checkpointing: true,
49 }
50 }
51}
52
/// Collective pattern for exchanging attention state between chunks;
/// dispatched in `SequenceParallelism::handle_attention_communication`.
/// NOTE(review): every handler except the ring variant is currently a
/// pass-through stub that clones the chunk output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SequenceCommunicationPattern {
    /// Ring-style exchange with the previous/next neighbouring chunks.
    RingAllReduce,
    /// Tree-structured reduction (stub).
    TreeReduce,
    /// Direct peer-to-peer transfers (stub).
    PointToPoint,
    /// Full all-to-all exchange (stub).
    AllToAll,
    /// Hierarchical communication (stub).
    Hierarchical,
}
67
/// Strategy for choosing chunk boundaries in `split_sequence`.
/// NOTE(review): `SemanticBoundaries` and `ComplexityBased` currently fall
/// back to the equal-chunk splitter (see their methods), and `Dynamic`
/// adjusts chunk size from memory pressure before doing the same.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SequenceSplittingStrategy {
    /// Fixed-size chunks with a constant overlap.
    EqualChunks,
    /// Boundaries chosen from a heuristic attention-pattern analysis.
    AttentionBased,
    /// Split at semantic boundaries (currently same as `EqualChunks`).
    SemanticBoundaries,
    /// Chunk size adapted to current memory pressure.
    Dynamic,
    /// Split by content complexity (currently same as `EqualChunks`).
    ComplexityBased,
}
82
/// Memory-optimization aggressiveness, from disabled to maximal.
/// Not consulted in this module — presumably interpreted by the training
/// runtime. TODO confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SequenceMemoryOptimization {
    None,
    Low,
    Medium,
    High,
    Extreme,
}
92
/// One contiguous slice of the global sequence assigned to a device.
/// Positions are global token indices; adjacent chunks may share
/// `prev_overlap`/`next_overlap` tokens.
#[derive(Debug, Clone)]
pub struct SequenceChunk {
    /// Index of this chunk; also its position in `sequence_chunks`.
    pub chunk_id: usize,
    /// Rank of the device that owns this chunk (round-robin assignment).
    pub device_rank: usize,
    /// Global start position (inclusive), including the leading overlap.
    pub start_position: usize,
    /// Global end position (exclusive).
    pub end_position: usize,
    /// Token count exclusive of overlap duplicated from neighbours.
    pub effective_length: usize,
    /// Tokens shared with the previous chunk (0 for the first chunk).
    pub prev_overlap: usize,
    /// Tokens shared with the next chunk (0 for the last chunk).
    pub next_overlap: usize,
    /// Whether this chunk must exchange attention state with neighbours.
    pub needs_attention_comm: bool,
}
113
/// A planned attention exchange between two chunks.
/// Only stored in `AttentionCommManager::communication_plan`; no code in
/// this module populates it yet.
#[derive(Debug, Clone)]
pub struct AttentionCommunication {
    /// Chunk producing the attention state.
    pub source_chunk: usize,
    /// Chunk consuming the attention state.
    pub target_chunk: usize,
    /// (source, target) token-position pairs involved in the exchange.
    pub attention_positions: Vec<(usize, usize)>,
    /// Size of the transfer — units not established here; presumably bytes.
    pub communication_size: usize,
}
126
/// Coordinator for sequence parallelism on a single rank: splits sequences
/// into chunks, runs the local chunks, exchanges attention state, and
/// synchronizes gradients across the sequence process group.
pub struct SequenceParallelism {
    /// Validated configuration (see `new`).
    config: SequenceParallelismConfig,
    /// This process's global rank; used to select local chunks.
    global_rank: usize,
    #[allow(dead_code)]
    world_size: usize,

    /// All chunks from the most recent `split_sequence` call.
    sequence_chunks: Vec<SequenceChunk>,
    /// Indices into `sequence_chunks` owned by this rank.
    local_chunks: Vec<usize>,
    /// Process group used for gradient all-reduce.
    sequence_group: Arc<dyn ProcessGroup>,

    /// Cross-chunk attention cache and hit/miss counters.
    attention_comm_manager: Arc<RwLock<AttentionCommManager>>,

    /// Accumulated communication timing statistics.
    communication_stats: Arc<Mutex<SequenceCommunicationStats>>,

    /// Per-chunk memory accounting and pressure estimate.
    memory_manager: Arc<Mutex<SequenceMemoryManager>>,
}
150
/// Cache of cross-chunk attention tensors keyed by (source, target)
/// chunk ids, plus hit/miss counters for the statistics report.
#[derive(Debug, Default)]
struct AttentionCommManager {
    #[allow(dead_code)]
    communication_plan: Vec<AttentionCommunication>,
    // Keyed by (source_chunk, target_chunk); populated lazily by
    // `get_cached_attention`.
    attention_cache: HashMap<(usize, usize), Tensor>,
    cache_hits: u64,
    cache_misses: u64,
}
160
/// Accumulated timing/volume counters for communication; surfaced through
/// `SequenceParallelism::get_statistics`.
#[derive(Debug, Default)]
#[allow(dead_code)]
struct SequenceCommunicationStats {
    // Wall-clock total of forward_chunk calls (includes local compute).
    total_communication_time: Duration,
    // Time spent inside attention-communication handlers.
    attention_communication_time: Duration,
    // Time spent in gradient all-reduce and averaging.
    gradient_sync_time: Duration,
    #[allow(dead_code)]
    total_bytes_communicated: u64,
    // NOTE(review): never written; the hit rate reported by
    // `get_statistics` is recomputed from the cache counters instead.
    attention_cache_hit_rate: f32,
    communication_efficiency: f32,
}
173
/// Per-chunk memory bookkeeping; updated by `update_memory_usage` and read
/// by the dynamic splitting strategy and `get_statistics`.
#[derive(Debug, Default)]
#[allow(dead_code)]
struct SequenceMemoryManager {
    #[allow(dead_code)]
    chunk_activations: HashMap<usize, Vec<Tensor>>,
    chunk_gradients: HashMap<usize, Vec<Tensor>>,
    checkpointed_chunks: HashMap<usize, Vec<Tensor>>,
    // Peak memory (bytes, per `update_memory_usage` callers) per chunk id.
    peak_memory_per_chunk: HashMap<usize, u64>,
    current_memory_usage: u64,
    // Fraction of an assumed 16 GiB budget; drives `split_dynamic`.
    memory_pressure: f32,
}
186
/// Result of the heuristic attention analysis used by the
/// attention-based splitter.
/// NOTE(review): intensities/importances are produced with `fastrand`
/// placeholders, not measured attention — treat as simulated data.
#[derive(Debug, Clone)]
pub struct AttentionPatternAnalysis {
    /// Length of the analyzed sequence in tokens.
    pub total_length: usize,
    /// Sorted candidate boundary positions (always includes 0 and
    /// `total_length`).
    pub attention_boundaries: Vec<usize>,
    /// One heuristic intensity score per 512-token window.
    pub attention_intensities: Vec<f32>,
    /// Estimated attention strength between adjacent windows, keyed by
    /// (window, window + 1).
    pub cross_chunk_attention: HashMap<(usize, usize), f32>,
    /// Per-token importance score, length `total_length`.
    pub token_importance: Vec<f32>,
    /// Per-head pattern classification.
    pub attention_head_patterns: Vec<AttentionHeadPattern>,
}
203
/// Coarse classification of an attention head's reach; determines the
/// head's span and communication requirement in
/// `analyze_attention_head_patterns`.
#[derive(Debug, Clone, PartialEq)]
pub enum AttentionPatternType {
    /// Short-range attention (span = total_length / 8 in the heuristic).
    Local,
    /// Full-sequence attention (span = total_length).
    Global,
    /// Mid-range attention (span = total_length / 4).
    Syntactic,
    /// Long-range attention (span = total_length / 2).
    Semantic,
}
216
/// Heuristic profile of one attention head, used when scoring candidate
/// chunk boundaries.
#[derive(Debug, Clone)]
pub struct AttentionHeadPattern {
    /// Head index within the layer.
    pub head_id: usize,
    /// Coarse reach classification.
    pub pattern_type: AttentionPatternType,
    /// Estimated attention span in tokens.
    pub attention_span: usize,
    /// Strength of the pattern in [0, 1]; random placeholder here.
    pub pattern_strength: f32,
    /// Relative cross-chunk communication cost in [0, 1].
    pub communication_requirement: f32,
}
231
232impl SequenceParallelism {
233 pub fn new(
235 config: SequenceParallelismConfig,
236 global_rank: usize,
237 world_size: usize,
238 sequence_group: Arc<dyn ProcessGroup>,
239 ) -> Result<Self> {
240 if config.sequence_parallel_size > world_size {
242 return Err(anyhow!(
243 "Sequence parallel size ({}) cannot exceed world size ({})",
244 config.sequence_parallel_size,
245 world_size
246 ));
247 }
248
249 if config.overlap_size >= config.max_sequence_length_per_device {
250 return Err(anyhow!(
251 "Overlap size ({}) must be smaller than max sequence length per device ({})",
252 config.overlap_size,
253 config.max_sequence_length_per_device
254 ));
255 }
256
257 Ok(Self {
258 config,
259 global_rank,
260 world_size,
261 sequence_chunks: Vec::new(),
262 local_chunks: Vec::new(),
263 sequence_group,
264 attention_comm_manager: Arc::new(RwLock::new(AttentionCommManager::default())),
265 communication_stats: Arc::new(Mutex::new(SequenceCommunicationStats::default())),
266 memory_manager: Arc::new(Mutex::new(SequenceMemoryManager::default())),
267 })
268 }
269
270 pub fn split_sequence(&mut self, total_sequence_length: usize) -> Result<Vec<SequenceChunk>> {
272 let chunks = match self.config.splitting_strategy {
273 SequenceSplittingStrategy::EqualChunks => {
274 self.split_equal_chunks(total_sequence_length)?
275 },
276 SequenceSplittingStrategy::AttentionBased => {
277 self.split_attention_based(total_sequence_length)?
278 },
279 SequenceSplittingStrategy::SemanticBoundaries => {
280 self.split_semantic_boundaries(total_sequence_length)?
281 },
282 SequenceSplittingStrategy::Dynamic => self.split_dynamic(total_sequence_length)?,
283 SequenceSplittingStrategy::ComplexityBased => {
284 self.split_complexity_based(total_sequence_length)?
285 },
286 };
287
288 self.sequence_chunks = chunks.clone();
290 self.local_chunks = chunks
291 .iter()
292 .enumerate()
293 .filter(|(_, chunk)| chunk.device_rank == self.global_rank)
294 .map(|(i, _)| i)
295 .collect();
296
297 Ok(chunks)
298 }
299
300 fn split_equal_chunks(&self, total_length: usize) -> Result<Vec<SequenceChunk>> {
302 let chunk_size = self.config.max_sequence_length_per_device;
303 let overlap = self.config.overlap_size;
304 let num_devices = self.config.sequence_parallel_size;
305
306 let mut chunks = Vec::new();
307 let mut current_pos = 0;
308 let mut chunk_id = 0;
309
310 while current_pos < total_length {
311 let end_pos = std::cmp::min(current_pos + chunk_size, total_length);
312 let device_rank = chunk_id % num_devices;
313
314 let prev_overlap = if chunk_id > 0 { overlap } else { 0 };
315 let next_overlap = if end_pos < total_length { overlap } else { 0 };
316
317 let chunk = SequenceChunk {
318 chunk_id,
319 device_rank,
320 start_position: current_pos,
321 end_position: end_pos,
322 effective_length: end_pos - current_pos - prev_overlap,
323 prev_overlap,
324 next_overlap,
325 needs_attention_comm: true,
326 };
327
328 chunks.push(chunk);
329 current_pos = end_pos - overlap;
330 chunk_id += 1;
331 }
332
333 Ok(chunks)
334 }
335
336 fn split_attention_based(&self, total_length: usize) -> Result<Vec<SequenceChunk>> {
338 let attention_analysis = self.analyze_attention_patterns(total_length)?;
340
341 let split_points = self.find_optimal_split_points(&attention_analysis, total_length)?;
343
344 self.create_attention_aware_chunks(total_length, &split_points)
346 }
347
348 fn analyze_attention_patterns(&self, total_length: usize) -> Result<AttentionPatternAnalysis> {
350 let mut attention_boundaries = Vec::new();
357 let mut attention_intensities = Vec::new();
358 let mut cross_chunk_attention = HashMap::new();
359
360 let window_size = 512; let num_windows = total_length.div_ceil(window_size);
363
364 for window_idx in 0..num_windows {
365 let start_pos = window_idx * window_size;
366 let end_pos = (start_pos + window_size).min(total_length);
367
368 let local_attention = self.calculate_local_attention_intensity(start_pos, end_pos)?;
370 let cross_window_attention =
371 self.calculate_cross_window_attention(window_idx, num_windows)?;
372
373 attention_intensities.push(local_attention);
374
375 if window_idx > 0 && cross_window_attention < 0.3 {
377 attention_boundaries.push(start_pos);
378 }
379
380 if window_idx < num_windows - 1 {
382 cross_chunk_attention.insert((window_idx, window_idx + 1), cross_window_attention);
383 }
384 }
385
386 if !attention_boundaries.contains(&0) {
388 attention_boundaries.insert(0, 0);
389 }
390 if !attention_boundaries.contains(&total_length) {
391 attention_boundaries.push(total_length);
392 }
393
394 attention_boundaries.sort();
395
396 Ok(AttentionPatternAnalysis {
397 total_length,
398 attention_boundaries,
399 attention_intensities,
400 cross_chunk_attention,
401 token_importance: self.calculate_token_importance(total_length)?,
402 attention_head_patterns: self.analyze_attention_head_patterns(total_length)?,
403 })
404 }
405
406 fn calculate_local_attention_intensity(&self, start_pos: usize, end_pos: usize) -> Result<f32> {
408 let window_length = end_pos - start_pos;
409
410 let position_factor = (start_pos as f32
412 / self.config.max_sequence_length_per_device as f32)
413 .sin()
414 .abs();
415 let length_factor = (window_length as f32 / 512.0).min(1.0);
416 let content_complexity = fastrand::f32() * 0.3 + 0.5; Ok(position_factor * length_factor * content_complexity)
419 }
420
421 fn calculate_cross_window_attention(
423 &self,
424 window_idx: usize,
425 total_windows: usize,
426 ) -> Result<f32> {
427 let distance_decay = if window_idx == 0 || window_idx == total_windows - 1 {
429 0.8 } else {
431 1.0 - (window_idx as f32 / total_windows as f32).abs()
432 };
433
434 let content_similarity = 0.6 + fastrand::f32() * 0.3; let attention_spread = 0.4 + fastrand::f32() * 0.4; Ok(distance_decay * content_similarity * attention_spread)
438 }
439
440 fn calculate_token_importance(&self, total_length: usize) -> Result<Vec<f32>> {
442 let mut importance_scores = Vec::with_capacity(total_length);
443
444 for pos in 0..total_length {
445 let position_bias = if pos < total_length / 4 || pos > 3 * total_length / 4 {
447 1.2 } else {
449 1.0
450 };
451
452 let content_importance = 0.3 + fastrand::f32() * 0.7; let attention_centrality = self.calculate_attention_centrality(pos, total_length)?;
454
455 importance_scores.push(position_bias * content_importance * attention_centrality);
456 }
457
458 Ok(importance_scores)
459 }
460
461 fn calculate_attention_centrality(&self, pos: usize, total_length: usize) -> Result<f32> {
463 let relative_pos = pos as f32 / total_length as f32;
465
466 let position_centrality = 1.0 - (2.0 * relative_pos - 1.0).abs();
468
469 let content_centrality = 0.5 + fastrand::f32() * 0.5;
471
472 Ok(position_centrality * content_centrality)
473 }
474
475 fn analyze_attention_head_patterns(
477 &self,
478 total_length: usize,
479 ) -> Result<Vec<AttentionHeadPattern>> {
480 let num_heads = 12; let mut head_patterns = Vec::with_capacity(num_heads);
482
483 for head_idx in 0..num_heads {
484 let pattern_type = match head_idx % 4 {
485 0 => AttentionPatternType::Local, 1 => AttentionPatternType::Global, 2 => AttentionPatternType::Syntactic, 3 => AttentionPatternType::Semantic, _ => AttentionPatternType::Local,
490 };
491
492 let attention_span = match pattern_type {
493 AttentionPatternType::Local => total_length / 8,
494 AttentionPatternType::Global => total_length,
495 AttentionPatternType::Syntactic => total_length / 4,
496 AttentionPatternType::Semantic => total_length / 2,
497 };
498
499 let pattern_strength = 0.4 + fastrand::f32() * 0.6;
500 let communication_requirement =
501 self.calculate_communication_requirement(&pattern_type, total_length)?;
502
503 head_patterns.push(AttentionHeadPattern {
504 head_id: head_idx,
505 pattern_type,
506 attention_span,
507 pattern_strength,
508 communication_requirement,
509 });
510 }
511
512 Ok(head_patterns)
513 }
514
515 fn calculate_communication_requirement(
517 &self,
518 pattern_type: &AttentionPatternType,
519 _total_length: usize,
520 ) -> Result<f32> {
521 match pattern_type {
522 AttentionPatternType::Local => Ok(0.1), AttentionPatternType::Global => Ok(0.9), AttentionPatternType::Syntactic => Ok(0.4), AttentionPatternType::Semantic => Ok(0.6), }
527 }
528
529 fn find_optimal_split_points(
531 &self,
532 analysis: &AttentionPatternAnalysis,
533 total_length: usize,
534 ) -> Result<Vec<usize>> {
535 let target_chunks = self.config.sequence_parallel_size;
536 let min_chunk_size = self.config.max_sequence_length_per_device / 2;
537 let max_chunk_size = self.config.max_sequence_length_per_device;
538
539 if target_chunks == 1 {
540 return Ok(vec![0, total_length]);
541 }
542
543 let mut split_points = vec![0];
545 let mut remaining_length = total_length;
546 let mut remaining_chunks = target_chunks;
547
548 for chunk_idx in 0..target_chunks - 1 {
549 let avg_remaining_chunk_size = remaining_length / remaining_chunks;
550 let target_split_pos = split_points[chunk_idx] + avg_remaining_chunk_size;
551
552 let best_boundary = self.find_best_boundary_near_position(
554 analysis,
555 target_split_pos,
556 min_chunk_size,
557 max_chunk_size,
558 )?;
559
560 split_points.push(best_boundary);
561 remaining_length = total_length - best_boundary;
562 remaining_chunks -= 1;
563 }
564
565 split_points.push(total_length);
566 Ok(split_points)
567 }
568
569 fn find_best_boundary_near_position(
571 &self,
572 analysis: &AttentionPatternAnalysis,
573 target_pos: usize,
574 min_chunk_size: usize,
575 _max_chunk_size: usize,
576 ) -> Result<usize> {
577 let search_radius = min_chunk_size / 4;
578 let start_search = target_pos.saturating_sub(search_radius);
579 let end_search = (target_pos + search_radius).min(analysis.total_length);
580
581 let mut best_pos = target_pos;
582 let mut best_score = f32::NEG_INFINITY;
583
584 for candidate_pos in (start_search..=end_search).step_by(16) {
586 if candidate_pos < min_chunk_size
587 || candidate_pos > analysis.total_length - min_chunk_size
588 {
589 continue;
590 }
591
592 let boundary_score =
593 self.calculate_boundary_score(analysis, candidate_pos, target_pos)?;
594
595 if boundary_score > best_score {
596 best_score = boundary_score;
597 best_pos = candidate_pos;
598 }
599 }
600
601 Ok(best_pos)
602 }
603
604 fn calculate_boundary_score(
606 &self,
607 analysis: &AttentionPatternAnalysis,
608 pos: usize,
609 target_pos: usize,
610 ) -> Result<f32> {
611 let distance_penalty =
613 1.0 - (pos as f32 - target_pos as f32).abs() / (target_pos as f32 + 1.0);
614
615 let attention_score = if analysis.attention_boundaries.contains(&pos) {
617 1.0
618 } else {
619 0.5 + 0.3 * (1.0 - self.get_cross_attention_at_position(analysis, pos)?)
621 };
622
623 let importance_penalty = if pos < analysis.token_importance.len() {
625 1.0 - analysis.token_importance[pos] * 0.3
626 } else {
627 1.0
628 };
629
630 let communication_score =
632 1.0 - self.estimate_communication_cost_at_boundary(analysis, pos)?;
633
634 Ok(distance_penalty * 0.3
635 + attention_score * 0.4
636 + importance_penalty * 0.15
637 + communication_score * 0.15)
638 }
639
640 fn get_cross_attention_at_position(
642 &self,
643 analysis: &AttentionPatternAnalysis,
644 pos: usize,
645 ) -> Result<f32> {
646 let window_size = 512;
647 let window_idx = pos / window_size;
648 let next_window_idx = window_idx + 1;
649
650 if let Some(&cross_attention) =
652 analysis.cross_chunk_attention.get(&(window_idx, next_window_idx))
653 {
654 Ok(cross_attention)
655 } else {
656 Ok(0.5) }
658 }
659
660 fn estimate_communication_cost_at_boundary(
662 &self,
663 analysis: &AttentionPatternAnalysis,
664 pos: usize,
665 ) -> Result<f32> {
666 let mut total_cost = 0.0;
667
668 for head_pattern in &analysis.attention_head_patterns {
670 if head_pattern.attention_span > pos && pos > 0 {
671 total_cost +=
673 head_pattern.communication_requirement * head_pattern.pattern_strength;
674 }
675 }
676
677 Ok(total_cost / analysis.attention_head_patterns.len() as f32)
679 }
680
681 fn create_attention_aware_chunks(
683 &self,
684 _total_length: usize,
685 split_points: &[usize],
686 ) -> Result<Vec<SequenceChunk>> {
687 let mut chunks = Vec::new();
688
689 for i in 0..split_points.len() - 1 {
690 let start_pos = split_points[i];
691 let end_pos = split_points[i + 1];
692 let chunk_length = end_pos - start_pos;
693
694 let prev_overlap =
696 if i > 0 { self.config.overlap_size.min(chunk_length / 4) } else { 0 };
697
698 let next_overlap = if i < split_points.len() - 2 {
699 self.config.overlap_size.min(chunk_length / 4)
700 } else {
701 0
702 };
703
704 let needs_attention_comm =
706 self.chunk_needs_attention_communication(i, split_points.len() - 1)?;
707
708 chunks.push(SequenceChunk {
709 chunk_id: i,
710 device_rank: i % self.config.sequence_parallel_size,
711 start_position: start_pos,
712 end_position: end_pos,
713 effective_length: chunk_length - prev_overlap - next_overlap,
714 prev_overlap,
715 next_overlap,
716 needs_attention_comm,
717 });
718 }
719
720 Ok(chunks)
721 }
722
723 fn chunk_needs_attention_communication(
725 &self,
726 chunk_idx: usize,
727 total_chunks: usize,
728 ) -> Result<bool> {
729 if total_chunks == 1 {
730 return Ok(false);
731 }
732
733 let has_prev_dependency = chunk_idx > 0;
735 let has_next_dependency = chunk_idx < total_chunks - 1;
736
737 if self.config.attention_communication_opt {
739 Ok(has_prev_dependency || has_next_dependency)
740 } else {
741 Ok(false)
742 }
743 }
744
745 fn split_semantic_boundaries(&self, total_length: usize) -> Result<Vec<SequenceChunk>> {
747 self.split_equal_chunks(total_length)
750 }
751
752 fn split_dynamic(&self, total_length: usize) -> Result<Vec<SequenceChunk>> {
754 let memory_manager = self.memory_manager.lock().expect("lock should not be poisoned");
755 let pressure = memory_manager.memory_pressure;
756
757 let base_chunk_size = self.config.max_sequence_length_per_device;
759 let adjusted_chunk_size = if pressure > 0.8 {
760 base_chunk_size / 2
761 } else if pressure > 0.6 {
762 (base_chunk_size * 3) / 4
763 } else {
764 base_chunk_size
765 };
766
767 let mut _adjusted_config = self.config.clone();
769 _adjusted_config.max_sequence_length_per_device = adjusted_chunk_size;
770
771 self.split_equal_chunks(total_length)
773 }
774
775 fn split_complexity_based(&self, total_length: usize) -> Result<Vec<SequenceChunk>> {
777 self.split_equal_chunks(total_length)
780 }
781
782 pub fn forward_chunk(
784 &self,
785 chunk_id: usize,
786 input: &Tensor,
787 _attention_mask: Option<&Tensor>,
788 ) -> Result<Tensor> {
789 let start_time = Instant::now();
790
791 if !self.local_chunks.contains(&chunk_id) {
792 return Err(anyhow!("Chunk {} is not local to this device", chunk_id));
793 }
794
795 let chunk = &self.sequence_chunks[chunk_id];
796
797 let output = self.process_local_chunk(input, chunk)?;
799
800 let final_output = if chunk.needs_attention_comm {
802 self.handle_attention_communication(chunk_id, &output)?
803 } else {
804 output
805 };
806
807 {
809 let mut stats = self.communication_stats.lock().expect("lock should not be poisoned");
810 stats.total_communication_time += start_time.elapsed();
811 }
812
813 Ok(final_output)
814 }
815
816 fn process_local_chunk(&self, input: &Tensor, _chunk: &SequenceChunk) -> Result<Tensor> {
818 Ok(input.clone())
821 }
822
823 fn handle_attention_communication(
825 &self,
826 chunk_id: usize,
827 chunk_output: &Tensor,
828 ) -> Result<Tensor> {
829 let start_time = Instant::now();
830
831 match self.config.communication_pattern {
832 SequenceCommunicationPattern::RingAllReduce => {
833 self.ring_attention_communication(chunk_id, chunk_output)
834 },
835 SequenceCommunicationPattern::TreeReduce => {
836 self.tree_attention_communication(chunk_id, chunk_output)
837 },
838 SequenceCommunicationPattern::PointToPoint => {
839 self.point_to_point_attention(chunk_id, chunk_output)
840 },
841 SequenceCommunicationPattern::AllToAll => {
842 self.all_to_all_attention(chunk_id, chunk_output)
843 },
844 SequenceCommunicationPattern::Hierarchical => {
845 self.hierarchical_attention_communication(chunk_id, chunk_output)
846 },
847 }
848 .map(|result| {
849 let mut stats = self.communication_stats.lock().expect("lock should not be poisoned");
851 stats.attention_communication_time += start_time.elapsed();
852 result
853 })
854 }
855
856 fn ring_attention_communication(
858 &self,
859 chunk_id: usize,
860 chunk_output: &Tensor,
861 ) -> Result<Tensor> {
862 let chunk = &self.sequence_chunks[chunk_id];
866 let num_chunks = self.sequence_chunks.len();
867
868 let combined_attention = chunk_output.clone();
870
871 if chunk_id > 0 && chunk.prev_overlap > 0 {
873 let _prev_attention = self.get_cached_attention(chunk_id - 1, chunk_id)?;
875 }
876
877 if chunk_id < num_chunks - 1 && chunk.next_overlap > 0 {
879 let _next_attention = self.get_cached_attention(chunk_id + 1, chunk_id)?;
881 }
882
883 Ok(combined_attention)
884 }
885
886 fn tree_attention_communication(
888 &self,
889 _chunk_id: usize,
890 chunk_output: &Tensor,
891 ) -> Result<Tensor> {
892 Ok(chunk_output.clone())
895 }
896
897 fn point_to_point_attention(&self, _chunk_id: usize, chunk_output: &Tensor) -> Result<Tensor> {
899 Ok(chunk_output.clone())
902 }
903
904 fn all_to_all_attention(&self, _chunk_id: usize, chunk_output: &Tensor) -> Result<Tensor> {
906 Ok(chunk_output.clone())
909 }
910
911 fn hierarchical_attention_communication(
913 &self,
914 _chunk_id: usize,
915 chunk_output: &Tensor,
916 ) -> Result<Tensor> {
917 Ok(chunk_output.clone())
920 }
921
922 fn get_cached_attention(&self, source_chunk: usize, target_chunk: usize) -> Result<Tensor> {
924 let mut comm_manager =
925 self.attention_comm_manager.write().expect("lock should not be poisoned");
926
927 let cache_key = (source_chunk, target_chunk);
928 if let Some(cached_attention) = comm_manager.attention_cache.get(&cache_key).cloned() {
929 comm_manager.cache_hits += 1;
930 Ok(cached_attention)
931 } else {
932 comm_manager.cache_misses += 1;
933 let attention = Tensor::zeros(&[64, 64])?;
935 comm_manager.attention_cache.insert(cache_key, attention.clone());
936 Ok(attention)
937 }
938 }
939
940 pub fn synchronize_gradients(&self, gradients: &mut HashMap<String, Tensor>) -> Result<()> {
942 if !self.config.sync_gradients {
943 return Ok(());
944 }
945
946 let start_time = Instant::now();
947
948 let mut gradient_tensors: Vec<Tensor> = gradients.values().cloned().collect();
950
951 self.sequence_group.all_reduce(&mut gradient_tensors)?;
953
954 let world_size = self.sequence_group.world_size() as f32;
956 for tensor in &mut gradient_tensors {
957 *tensor = tensor.scalar_mul(1.0 / world_size)?;
958 }
959
960 for (i, (_, gradient)) in gradients.iter_mut().enumerate() {
962 if i < gradient_tensors.len() {
963 *gradient = gradient_tensors[i].clone();
964 }
965 }
966
967 {
969 let mut stats = self.communication_stats.lock().expect("lock should not be poisoned");
970 stats.gradient_sync_time += start_time.elapsed();
971 }
972
973 Ok(())
974 }
975
976 pub fn update_memory_usage(&self, chunk_id: usize, memory_usage: u64) -> Result<()> {
978 let mut memory_manager = self.memory_manager.lock().expect("lock should not be poisoned");
979
980 memory_manager.peak_memory_per_chunk.insert(chunk_id, memory_usage);
981 memory_manager.current_memory_usage = memory_usage;
982
983 let max_memory = 16u64 * 1024 * 1024 * 1024; memory_manager.memory_pressure = memory_usage as f32 / max_memory as f32;
986
987 Ok(())
988 }
989
990 pub fn get_statistics(&self) -> SequenceParallelismStats {
992 let comm_stats = self.communication_stats.lock().expect("lock should not be poisoned");
993 let comm_manager = self.attention_comm_manager.read().expect("lock should not be poisoned");
994 let memory_manager = self.memory_manager.lock().expect("lock should not be poisoned");
995
996 let cache_hit_rate = if comm_manager.cache_hits + comm_manager.cache_misses > 0 {
997 comm_manager.cache_hits as f32
998 / (comm_manager.cache_hits + comm_manager.cache_misses) as f32
999 } else {
1000 0.0
1001 };
1002
1003 SequenceParallelismStats {
1004 total_chunks: self.sequence_chunks.len(),
1005 local_chunks: self.local_chunks.len(),
1006 communication_time: comm_stats.total_communication_time,
1007 attention_communication_time: comm_stats.attention_communication_time,
1008 gradient_sync_time: comm_stats.gradient_sync_time,
1009 attention_cache_hit_rate: cache_hit_rate,
1010 memory_pressure: memory_manager.memory_pressure,
1011 peak_memory_usage: memory_manager.current_memory_usage,
1012 }
1013 }
1014
1015 pub fn local_chunks(&self) -> &[usize] {
1017 &self.local_chunks
1018 }
1019
1020 pub fn get_chunk(&self, chunk_id: usize) -> Option<&SequenceChunk> {
1022 self.sequence_chunks.get(chunk_id)
1023 }
1024
1025 pub fn config(&self) -> &SequenceParallelismConfig {
1027 &self.config
1028 }
1029}
1030
/// Point-in-time statistics snapshot returned by
/// `SequenceParallelism::get_statistics`.
#[derive(Debug, Clone)]
pub struct SequenceParallelismStats {
    /// Total chunks from the last split.
    pub total_chunks: usize,
    /// Chunks assigned to this rank.
    pub local_chunks: usize,
    /// Accumulated wall-clock time of `forward_chunk` calls.
    pub communication_time: Duration,
    /// Time spent in attention-communication handlers.
    pub attention_communication_time: Duration,
    /// Time spent synchronizing gradients.
    pub gradient_sync_time: Duration,
    /// Attention-cache hit rate in [0, 1]; 0 when the cache is untouched.
    pub attention_cache_hit_rate: f32,
    /// Current memory usage as a fraction of the assumed device budget.
    pub memory_pressure: f32,
    /// Most recently reported memory usage (bytes).
    pub peak_memory_usage: u64,
}
1043
1044pub mod utils {
1046 use super::*;
1047
1048 pub fn calculate_optimal_sequence_config(
1050 total_sequence_length: usize,
1051 max_memory_per_device: usize,
1052 memory_per_token: usize,
1053 world_size: usize,
1054 ) -> Result<SequenceParallelismConfig> {
1055 let max_tokens_per_device = max_memory_per_device / memory_per_token;
1056
1057 if max_tokens_per_device == 0 {
1058 return Err(anyhow!("Insufficient memory for sequence parallelism"));
1059 }
1060
1061 let required_devices = total_sequence_length.div_ceil(max_tokens_per_device);
1062 let sequence_parallel_size = std::cmp::min(required_devices, world_size);
1063
1064 let tokens_per_device = total_sequence_length.div_ceil(sequence_parallel_size);
1065 let overlap_size = std::cmp::min(128, tokens_per_device / 10); Ok(SequenceParallelismConfig {
1068 sequence_parallel_size,
1069 max_sequence_length_per_device: tokens_per_device,
1070 overlap_size,
1071 ..Default::default()
1072 })
1073 }
1074
1075 pub fn estimate_communication_cost(
1077 config: &SequenceParallelismConfig,
1078 hidden_size: usize,
1079 num_attention_heads: usize,
1080 ) -> f32 {
1081 let overlap_tokens = config.overlap_size;
1082 let communication_per_overlap = overlap_tokens * hidden_size * 4; let attention_communication = overlap_tokens * overlap_tokens * num_attention_heads * 4;
1084
1085 (communication_per_overlap + attention_communication) as f32 / (1024.0 * 1024.0)
1086 }
1088
1089 pub fn calculate_memory_savings(
1091 total_sequence_length: usize,
1092 sequence_parallel_size: usize,
1093 hidden_size: usize,
1094 ) -> f32 {
1095 let tokens_per_device = total_sequence_length / sequence_parallel_size;
1096 let memory_per_device = tokens_per_device * hidden_size * 4; let total_memory_without_sp = total_sequence_length * hidden_size * 4;
1098
1099 1.0 - (memory_per_device as f32 / total_memory_without_sp as f32)
1100 }
1101}
1102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::SimulatedProcessGroup;
    use std::sync::Arc;

    // Default config exposes the documented single-device values.
    #[test]
    fn test_sequence_parallelism_config() {
        let config = SequenceParallelismConfig::default();
        assert_eq!(config.sequence_parallel_size, 1);
        assert_eq!(config.max_sequence_length_per_device, 2048);
        assert_eq!(config.overlap_size, 128);
    }

    // Construction succeeds when the parallel size fits in the world size
    // and the overlap is smaller than the per-device chunk length.
    #[test]
    fn test_sequence_parallelism_creation() {
        let config = SequenceParallelismConfig {
            sequence_parallel_size: 4,
            max_sequence_length_per_device: 1024,
            overlap_size: 64,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 4));
        let sequence_parallelism = SequenceParallelism::new(config, 0, 4, process_group);

        assert!(sequence_parallelism.is_ok());
    }

    #[test]
    #[ignore]
    // NOTE(review): presumably ignored because split_equal_chunks'
    // `current_pos = end_pos - overlap` advance never reaches
    // `total_length` on the final chunk, so the splitter does not
    // terminate — confirm before re-enabling.
    fn test_equal_chunks_splitting() {
        let config = SequenceParallelismConfig {
            sequence_parallel_size: 2,
            max_sequence_length_per_device: 1000,
            overlap_size: 100,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 2));
        let mut sequence_parallelism = SequenceParallelism::new(config, 0, 2, process_group)
            .expect("operation failed in test");

        // 1800 tokens with 1000-token chunks and 100-token overlap should
        // give chunks [0, 1000) and [900, 1800).
        let chunks = sequence_parallelism.split_sequence(1800).expect("operation failed in test");
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].start_position, 0);
        assert_eq!(chunks[0].end_position, 1000);
        assert_eq!(chunks[1].start_position, 900);
    }

    #[test]
    #[ignore]
    // NOTE(review): depends on split_sequence, which is affected by the
    // same non-terminating split loop as above — confirm before
    // re-enabling.
    fn test_chunk_processing() {
        let config = SequenceParallelismConfig {
            sequence_parallel_size: 2,
            max_sequence_length_per_device: 1000,
            overlap_size: 100,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 2));
        let mut sequence_parallelism = SequenceParallelism::new(config, 0, 2, process_group)
            .expect("operation failed in test");

        let _chunks = sequence_parallelism.split_sequence(1800).expect("operation failed in test");

        // Chunk 0 is assigned to rank 0, so forwarding it locally succeeds.
        let input = Tensor::zeros(&[1000, 768]).expect("tensor operation failed");
        let result = sequence_parallelism.forward_chunk(0, &input, None);
        assert!(result.is_ok());
    }

    // With a single-rank group, gradient sync reduces to an in-place
    // average and must succeed.
    #[test]
    fn test_gradient_synchronization() {
        let config = SequenceParallelismConfig {
            sync_gradients: true,
            ..Default::default()
        };

        let process_group = Arc::new(SimulatedProcessGroup::new(0, 1));
        let sequence_parallelism = SequenceParallelism::new(config, 0, 1, process_group)
            .expect("operation failed in test");

        let mut gradients = HashMap::new();
        gradients.insert(
            "test_param".to_string(),
            Tensor::ones(&[10, 10]).expect("tensor operation failed"),
        );

        let result = sequence_parallelism.synchronize_gradients(&mut gradients);
        assert!(result.is_ok());
    }

    // Reported peak memory must round-trip through update_memory_usage
    // into the statistics snapshot.
    #[test]
    fn test_memory_usage_update() {
        let config = SequenceParallelismConfig::default();
        let process_group = Arc::new(SimulatedProcessGroup::new(0, 1));
        let sequence_parallelism = SequenceParallelism::new(config, 0, 1, process_group)
            .expect("operation failed in test");

        // Report 1 GiB for chunk 0.
        let result = sequence_parallelism.update_memory_usage(0, 1024 * 1024 * 1024);
        assert!(result.is_ok());

        let stats = sequence_parallelism.get_statistics();
        assert_eq!(stats.peak_memory_usage, 1024 * 1024 * 1024);
    }

    // Derived config must respect the world size and produce usable sizes.
    #[test]
    fn test_optimal_sequence_config_calculation() {
        let config = utils::calculate_optimal_sequence_config(
            10000,                    // total tokens
            8 * 1024 * 1024 * 1024,   // 8 GiB per device
            1024,                     // bytes per token
            4,                        // world size
        )
        .expect("operation failed in test");

        assert!(config.sequence_parallel_size <= 4);
        assert!(config.max_sequence_length_per_device > 0);
    }

    // Default config with a non-trivial model must report a positive cost.
    #[test]
    fn test_communication_cost_estimation() {
        let config = SequenceParallelismConfig::default();
        let cost = utils::estimate_communication_cost(&config, 768, 12);
        assert!(cost > 0.0);
    }

    // Splitting 4 ways must save memory but never 100%.
    #[test]
    fn test_memory_savings_calculation() {
        let savings = utils::calculate_memory_savings(10000, 4, 768);
        assert!(savings > 0.0 && savings < 1.0);
    }
}
1235}