// trustformers_core/parallel/mod.rs
pub mod model_parallel;
10pub mod parallel_layers;
11pub mod pipeline_parallel;
12pub mod tensor_parallel;
13
14pub mod mpi_communicator;
15
16#[cfg(feature = "nccl")]
17pub mod nccl_communicator;
18
19pub use model_parallel::{
20 CommunicationBackend, Communicator, DeviceMesh, DistributedTensor, ModelParallelConfig,
21 ModelParallelContext, ModelParallelStrategy, PipelineOp, PipelineSchedule,
22 PipelineScheduleType, TensorPartition,
23};
24
25pub use parallel_layers::{
26 ActivationType, ColumnParallelLinear, ParallelMLP, ParallelMultiHeadAttention,
27 RowParallelLinear,
28};
29
30pub use tensor_parallel::{
31 AsyncTensorParallel, InitMethod, TensorParallelInit, TensorParallelOps, TensorParallelShapes,
32};
33
34pub use pipeline_parallel::{
35 MicrobatchManager, PipelineExecutor, PipelineLayer, PipelineModel, PipelineOptimizer,
36 PipelineStage,
37};
38
39pub use mpi_communicator::{mpi_utils, MpiCommunicatorImpl};
40
41#[cfg(feature = "nccl")]
42pub use nccl_communicator::{create_nccl_communicator, NcclCommunicator};
43
44use crate::errors::{runtime_error, Result};
45use parking_lot::RwLock;
46use std::sync::Arc;
47
/// High-level parallelism strategy for distributing work across devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelismStrategy {
    /// Data parallelism: the same model runs on every device over different data.
    Data,
    /// Model parallelism: the model itself is partitioned across devices.
    Model,
    /// Combination of data and model parallelism.
    Hybrid,
    /// No parallelism; single-device execution.
    None,
}
60
/// Runtime description of this process's place in a parallel setup:
/// the chosen strategy, the device count, this process's device id,
/// and an optional NUMA placement configuration.
#[derive(Clone)]
pub struct ParallelContext {
    // Strategy selected for this run (see `ParallelismStrategy`).
    strategy: ParallelismStrategy,
    // Total number of participating devices.
    num_devices: usize,
    // Index of the device this context is bound to (defaults to 0 in `new`).
    device_id: usize,
    // Optional NUMA placement settings; `None` when not configured.
    numa_config: Option<NumaConfig>,
}
69
/// NUMA placement configuration for a device/process.
#[derive(Debug, Clone)]
pub struct NumaConfig {
    /// NUMA node this configuration targets.
    pub node_id: usize,
    /// CPU ids to pin to — presumably logical CPU indices on `node_id`;
    /// confirm against the code that applies this affinity.
    pub cpu_affinity: Vec<usize>,
    /// How memory allocations should be placed relative to the NUMA node.
    pub memory_policy: MemoryPolicy,
}
77
/// NUMA memory-placement policy. Variant names mirror common NUMA policies
/// (strict local binding / interleaving / local preference); the exact
/// enforcement depends on the code that consumes this enum.
#[derive(Debug, Clone, Copy)]
pub enum MemoryPolicy {
    /// Bind allocations to the local NUMA node.
    BindLocal,
    /// Interleave allocations across NUMA nodes.
    Interleave,
    /// Prefer the local node but allow fallback elsewhere.
    PreferLocal,
}
87
88impl ParallelContext {
89 pub fn new(strategy: ParallelismStrategy, num_devices: usize) -> Self {
90 Self {
91 strategy,
92 num_devices,
93 device_id: 0,
94 numa_config: None,
95 }
96 }
97
98 pub fn with_device_id(mut self, device_id: usize) -> Self {
99 self.device_id = device_id;
100 self
101 }
102
103 pub fn with_numa_config(mut self, numa_config: NumaConfig) -> Self {
104 self.numa_config = Some(numa_config);
105 self
106 }
107
108 pub fn strategy(&self) -> ParallelismStrategy {
109 self.strategy
110 }
111
112 pub fn num_devices(&self) -> usize {
113 self.num_devices
114 }
115
116 pub fn device_id(&self) -> usize {
117 self.device_id
118 }
119}
120
/// Operations that execute user closures under a [`ParallelContext`].
pub trait ParallelOps {
    /// Runs `f` once with the implementor's parallel context.
    fn parallel_execute<F, T>(&self, f: F) -> Result<T>
    where
        F: FnOnce(&ParallelContext) -> Result<T>;

    /// Applies `f` to each element of `items` with access to the parallel
    /// context, collecting the transformed elements.
    fn parallel_map<F, T>(&self, items: Vec<T>, f: F) -> Result<Vec<T>>
    where
        F: Fn(T, &ParallelContext) -> Result<T> + Send + Sync,
        T: Send;
}
134
135static PARALLEL_CONTEXT: RwLock<Option<Arc<ParallelContext>>> = RwLock::new(None);
137
138pub fn init_parallelism(context: ParallelContext) {
140 *PARALLEL_CONTEXT.write() = Some(Arc::new(context));
141}
142
143pub fn parallel_context() -> Option<Arc<ParallelContext>> {
145 PARALLEL_CONTEXT.read().clone()
146}
147
148pub fn parallel_execute<F, T>(f: F) -> Result<T>
150where
151 F: FnOnce(&ParallelContext) -> Result<T>,
152{
153 let context =
154 parallel_context().ok_or_else(|| runtime_error("Parallel context not initialized"))?;
155 f(&context)
156}
157
158pub fn parallel_map<F, T>(items: Vec<T>, f: F) -> Result<Vec<T>>
160where
161 F: Fn(T, &ParallelContext) -> Result<T> + Send + Sync,
162 T: Send,
163{
164 let context =
165 parallel_context().ok_or_else(|| runtime_error("Parallel context not initialized"))?;
166
167 items.into_iter().map(|item| f(item, &context)).collect()
169}
170
171pub fn parallel_chunk_map<F, T>(items: Vec<T>, chunk_size: usize, f: F) -> Result<Vec<T>>
173where
174 F: Fn(Vec<T>, &ParallelContext) -> Result<Vec<T>> + Send + Sync,
175 T: Send + Clone,
176{
177 let context =
178 parallel_context().ok_or_else(|| runtime_error("Parallel context not initialized"))?;
179
180 let mut chunks = Vec::new();
181 let mut i = 0;
182 while i < items.len() {
183 let end = (i + chunk_size).min(items.len());
184 chunks.push(items[i..end].to_vec());
185 i = end;
186 }
187
188 let results: Result<Vec<Vec<T>>> = chunks.into_iter().map(|chunk| f(chunk, &context)).collect();
189
190 results.map(|vecs| vecs.into_iter().flatten().collect())
191}