// trustformers_core/parallel/mod.rs
1//! Parallel execution support for TrustformeRS
2//!
3//! This module provides infrastructure for various parallelism strategies including:
4//! - Data parallelism
5//! - Model parallelism (tensor and pipeline)
6//! - Hybrid parallelism
7//! - NUMA-aware optimization
8
9pub mod model_parallel;
10pub mod parallel_layers;
11pub mod pipeline_parallel;
12pub mod tensor_parallel;
13
14pub mod mpi_communicator;
15
16#[cfg(feature = "nccl")]
17pub mod nccl_communicator;
18
19pub use model_parallel::{
20    CommunicationBackend, Communicator, DeviceMesh, DistributedTensor, ModelParallelConfig,
21    ModelParallelContext, ModelParallelStrategy, PipelineOp, PipelineSchedule,
22    PipelineScheduleType, TensorPartition,
23};
24
25pub use parallel_layers::{
26    ActivationType, ColumnParallelLinear, ParallelMLP, ParallelMultiHeadAttention,
27    RowParallelLinear,
28};
29
30pub use tensor_parallel::{
31    AsyncTensorParallel, InitMethod, TensorParallelInit, TensorParallelOps, TensorParallelShapes,
32};
33
34pub use pipeline_parallel::{
35    MicrobatchManager, PipelineExecutor, PipelineLayer, PipelineModel, PipelineOptimizer,
36    PipelineStage,
37};
38
39pub use mpi_communicator::{mpi_utils, MpiCommunicatorImpl};
40
41#[cfg(feature = "nccl")]
42pub use nccl_communicator::{create_nccl_communicator, NcclCommunicator};
43
44use crate::errors::{runtime_error, Result};
45use parking_lot::RwLock;
46use std::sync::Arc;
47
/// Core parallelism strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelismStrategy {
    /// Data parallelism only
    Data,
    /// Model parallelism (tensor or pipeline)
    Model,
    /// Hybrid (data + model)
    Hybrid,
    /// No parallelism (single device)
    None,
}

impl Default for ParallelismStrategy {
    /// Single-device execution (`None`) is the safe default when no
    /// strategy has been chosen explicitly.
    fn default() -> Self {
        ParallelismStrategy::None
    }
}
60
/// Parallel execution context
///
/// Immutable description of how this process participates in a parallel
/// run; built with `ParallelContext::new` plus the `with_*` builders.
#[derive(Clone)]
pub struct ParallelContext {
    // Which parallelism scheme is in effect (data / model / hybrid / none).
    strategy: ParallelismStrategy,
    // Total number of devices participating in the job.
    num_devices: usize,
    // This process's device index; `new` initializes it to 0.
    device_id: usize,
    // Optional CPU/NUMA placement hints; `None` means no NUMA tuning.
    numa_config: Option<NumaConfig>,
}
69
/// NUMA configuration for CPU optimization
#[derive(Debug, Clone)]
pub struct NumaConfig {
    /// NUMA node this configuration targets.
    pub node_id: usize,
    /// CPU core ids the worker should be pinned to.
    /// NOTE(review): presumably logical CPU indices as seen by the OS — confirm.
    pub cpu_affinity: Vec<usize>,
    /// How memory is placed relative to the NUMA node (see [`MemoryPolicy`]).
    pub memory_policy: MemoryPolicy,
}
77
/// Memory placement policy used by `NumaConfig`.
///
/// Derives `PartialEq`/`Eq` so policies can be compared directly,
/// consistent with the `ParallelismStrategy` enum in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryPolicy {
    /// Bind memory to local NUMA node
    BindLocal,
    /// Interleave memory across nodes
    Interleave,
    /// Prefer local but allow remote
    PreferLocal,
}
87
88impl ParallelContext {
89    pub fn new(strategy: ParallelismStrategy, num_devices: usize) -> Self {
90        Self {
91            strategy,
92            num_devices,
93            device_id: 0,
94            numa_config: None,
95        }
96    }
97
98    pub fn with_device_id(mut self, device_id: usize) -> Self {
99        self.device_id = device_id;
100        self
101    }
102
103    pub fn with_numa_config(mut self, numa_config: NumaConfig) -> Self {
104        self.numa_config = Some(numa_config);
105        self
106    }
107
108    pub fn strategy(&self) -> ParallelismStrategy {
109        self.strategy
110    }
111
112    pub fn num_devices(&self) -> usize {
113        self.num_devices
114    }
115
116    pub fn device_id(&self) -> usize {
117        self.device_id
118    }
119}
120
/// Parallel operations trait
///
/// Implemented by types that can run work under a [`ParallelContext`].
pub trait ParallelOps {
    /// Execute operation in parallel context
    ///
    /// Runs `f` with the implementor's context; errors from `f` propagate
    /// to the caller.
    fn parallel_execute<F, T>(&self, f: F) -> Result<T>
    where
        F: FnOnce(&ParallelContext) -> Result<T>;

    /// Map operation across parallel devices
    ///
    /// Applies `f` to each item, producing outputs of the same type.
    /// NOTE(review): whether implementors must preserve item order is not
    /// specified here — confirm before relying on it.
    fn parallel_map<F, T>(&self, items: Vec<T>, f: F) -> Result<Vec<T>>
    where
        F: Fn(T, &ParallelContext) -> Result<T> + Send + Sync,
        T: Send;
}
134
/// Global parallel context
///
/// Process-wide singleton guarded by a `parking_lot::RwLock`; holds `None`
/// until `init_parallelism` is called.
static PARALLEL_CONTEXT: RwLock<Option<Arc<ParallelContext>>> = RwLock::new(None);

/// Initialize global parallel context
///
/// Stores `context` in the process-wide slot, replacing any previously
/// installed context.
pub fn init_parallelism(context: ParallelContext) {
    *PARALLEL_CONTEXT.write() = Some(Arc::new(context));
}

/// Get global parallel context
///
/// Returns a cheap `Arc` clone of the installed context, or `None` if
/// `init_parallelism` has not been called yet.
pub fn parallel_context() -> Option<Arc<ParallelContext>> {
    PARALLEL_CONTEXT.read().clone()
}
147
148/// Execute function in parallel context
149pub fn parallel_execute<F, T>(f: F) -> Result<T>
150where
151    F: FnOnce(&ParallelContext) -> Result<T>,
152{
153    let context =
154        parallel_context().ok_or_else(|| runtime_error("Parallel context not initialized"))?;
155    f(&context)
156}
157
158/// Map function across items in parallel
159pub fn parallel_map<F, T>(items: Vec<T>, f: F) -> Result<Vec<T>>
160where
161    F: Fn(T, &ParallelContext) -> Result<T> + Send + Sync,
162    T: Send,
163{
164    let context =
165        parallel_context().ok_or_else(|| runtime_error("Parallel context not initialized"))?;
166
167    // Simple implementation - in practice would use thread pool
168    items.into_iter().map(|item| f(item, &context)).collect()
169}
170
171/// Parallel chunk mapping for large datasets
172pub fn parallel_chunk_map<F, T>(items: Vec<T>, chunk_size: usize, f: F) -> Result<Vec<T>>
173where
174    F: Fn(Vec<T>, &ParallelContext) -> Result<Vec<T>> + Send + Sync,
175    T: Send + Clone,
176{
177    let context =
178        parallel_context().ok_or_else(|| runtime_error("Parallel context not initialized"))?;
179
180    let mut chunks = Vec::new();
181    let mut i = 0;
182    while i < items.len() {
183        let end = (i + chunk_size).min(items.len());
184        chunks.push(items[i..end].to_vec());
185        i = end;
186    }
187
188    let results: Result<Vec<Vec<T>>> = chunks.into_iter().map(|chunk| f(chunk, &context)).collect();
189
190    results.map(|vecs| vecs.into_iter().flatten().collect())
191}