1use crate::Layer;
2use std::collections::HashMap;
3use tenflowers_core::{Device, Result, TensorError};
4
/// Strategy for placing a single layer onto one or more devices.
#[derive(Debug, Clone, PartialEq)]
pub enum PlacementStrategy {
    /// Let the coordinator choose a device automatically.
    Auto,
    /// Pin the layer to one explicitly chosen device.
    Device(Device),
    /// Split the layer's computation across several devices.
    Split(Vec<Device>),
    /// Assign the layer to a pipeline stage running on the given device.
    Pipeline { stage: usize, device: Device },
}
17
/// Top-level configuration describing how a model is partitioned across devices.
#[derive(Debug, Clone)]
pub struct ModelParallelConfig {
    /// Per-layer placement overrides, keyed by layer index.
    pub layer_placement: HashMap<usize, PlacementStrategy>,
    /// Optional pipeline-parallel settings; `None` disables pipelining.
    pub pipeline_config: Option<PipelineConfig>,
    /// Optional tensor-parallel settings; `None` disables tensor parallelism.
    pub tensor_parallel_config: Option<TensorParallelConfig>,
}
28
/// Settings for pipeline parallelism: how many stages, how batches are
/// micro-batched, and which device hosts each stage.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Number of pipeline stages; expected to match `stage_devices.len()`.
    pub num_stages: usize,
    /// Size of each micro-batch fed through the pipeline.
    pub micro_batch_size: usize,
    /// How many micro-batches are in flight per global batch.
    pub num_micro_batches: usize,
    /// Device hosting each stage, indexed by stage number.
    pub stage_devices: Vec<Device>,
}
41
/// Settings for tensor parallelism: how many ways tensors are sharded and
/// across which devices.
#[derive(Debug, Clone)]
pub struct TensorParallelConfig {
    /// Tensor-parallel degree; expected to match `devices.len()`.
    pub tp_size: usize,
    /// Devices participating in the tensor-parallel group.
    pub devices: Vec<Device>,
    /// Tensor dimension along which parameters/activations are split.
    pub split_dim: usize,
}
52
/// A layer that has been split into sub-layers, each placed on its own device.
pub struct SplitLayer<T> {
    /// The per-device sub-layers, in the same order as `device_assignment`.
    pub sub_layers: Vec<Box<dyn Layer<T>>>,
    /// Device hosting each sub-layer, parallel to `sub_layers`.
    pub device_assignment: Vec<Device>,
    /// How sub-layer outputs are recombined into the layer's output.
    pub communication_pattern: CommunicationPattern,
}
62
/// How the outputs of a split layer's sub-layers are combined.
#[derive(Debug, Clone)]
pub enum CommunicationPattern {
    /// Concatenate partial outputs along the given dimension.
    Concat { dim: usize },
    /// Sum partial outputs across all participating devices.
    AllReduce,
    /// Gather all partial outputs onto a single target device.
    Gather { target_device: Device },
    /// Application-defined pattern identified by a string key.
    Custom { pattern_id: String },
}
75
/// Extension of [`Layer`] for layers that can participate in model parallelism.
pub trait ParallelLayer<T>: Layer<T> {
    /// Reports the memory this layer needs (parameters, activations,
    /// gradients, scratch).
    fn memory_requirements(&self) -> Result<MemoryRequirements>;

    /// Splits this layer into sub-layers assigned across `devices`.
    fn split_across_devices(&self, devices: &[Device]) -> Result<SplitLayer<T>>;

    /// Suggests a placement strategy given the available devices.
    fn suggest_placement(&self, available_devices: &[Device]) -> Result<PlacementStrategy>;

    /// Moves the layer's state to `device`.
    fn to_device(&mut self, device: &Device) -> Result<()>;

    /// The device currently holding this layer, if known.
    fn current_device(&self) -> Option<Device>;

    /// Whether this layer supports being split across devices.
    fn can_split(&self) -> bool;

    /// Relative computational cost of this layer, for load balancing.
    /// NOTE(review): units/scale are unspecified here — confirm with implementors.
    fn computation_intensity(&self) -> f64;
}
99
/// Memory footprint of a layer, broken down by category (units: bytes,
/// presumably — confirm against implementors of `ParallelLayer`).
#[derive(Debug, Clone)]
pub struct MemoryRequirements {
    /// Memory held by the layer's parameters (weights).
    pub parameter_memory: usize,
    /// Memory needed for forward-pass activations.
    pub activation_memory: usize,
    /// Memory needed for gradients during backprop.
    pub gradient_memory: usize,
    /// Temporary/scratch memory used during computation.
    pub temp_memory: usize,
}
112
113impl MemoryRequirements {
114 pub fn total(&self) -> usize {
116 self.parameter_memory + self.activation_memory + self.gradient_memory + self.temp_memory
117 }
118}
119
/// Orchestrates device assignment and communication-group setup for a
/// model-parallel configuration.
pub struct ModelParallelCoordinator {
    // The parallelism configuration this coordinator realizes.
    config: ModelParallelConfig,
    // Resolved layer-index -> device map, filled by `assign_devices`.
    layer_assignments: HashMap<usize, Device>,
    // Groups created by `setup_communication_groups`.
    communication_groups: Vec<CommunicationGroup>,
}
126
/// A named set of devices that communicate together (e.g. one pipeline stage
/// or a tensor-parallel group).
#[derive(Debug, Clone)]
pub struct CommunicationGroup {
    /// Human-readable unique identifier for the group.
    pub id: String,
    /// Devices belonging to this group.
    pub devices: Vec<Device>,
    /// What kind of parallelism this group implements.
    pub group_type: GroupType,
}
137
/// Kind of communication group.
#[derive(Debug, Clone)]
pub enum GroupType {
    /// Devices serving the given pipeline stage index.
    PipelineStage(usize),
    /// Devices sharing sharded tensors.
    TensorParallel,
    /// Devices replicating the model over data shards.
    DataParallel,
    /// Application-defined group identified by a string key.
    Custom(String),
}
150
151impl ModelParallelCoordinator {
152 pub fn new(config: ModelParallelConfig) -> Self {
154 Self {
155 config,
156 layer_assignments: HashMap::new(),
157 communication_groups: Vec::new(),
158 }
159 }
160
161 pub fn assign_devices(&mut self, num_layers: usize) -> Result<()> {
163 for layer_idx in 0..num_layers {
165 let device = if let Some(strategy) = self.config.layer_placement.get(&layer_idx) {
166 match strategy {
167 PlacementStrategy::Device(device) => device.clone(),
168 PlacementStrategy::Pipeline { device, .. } => device.clone(),
169 PlacementStrategy::Auto => {
170 #[cfg(feature = "gpu")]
173 {
174 let device_idx = layer_idx % 4; Device::Gpu(device_idx)
176 }
177 #[cfg(not(feature = "gpu"))]
178 {
179 Device::Cpu
180 }
181 }
182 PlacementStrategy::Split(_) => {
183 #[cfg(feature = "gpu")]
186 {
187 Device::Gpu(0)
188 }
189 #[cfg(not(feature = "gpu"))]
190 {
191 Device::Cpu
192 }
193 }
194 }
195 } else {
196 Device::Cpu
198 };
199
200 self.layer_assignments.insert(layer_idx, device);
201 }
202
203 Ok(())
204 }
205
206 pub fn get_layer_device(&self, layer_idx: usize) -> Option<&Device> {
208 self.layer_assignments.get(&layer_idx)
209 }
210
211 pub fn setup_communication_groups(&mut self) -> Result<()> {
213 if let Some(pipeline_config) = &self.config.pipeline_config {
215 for stage in 0..pipeline_config.num_stages {
216 let group = CommunicationGroup {
217 id: format!("pipeline_stage_{stage}"),
218 devices: vec![pipeline_config.stage_devices[stage].clone()],
219 group_type: GroupType::PipelineStage(stage),
220 };
221 self.communication_groups.push(group);
222 }
223 }
224
225 if let Some(tp_config) = &self.config.tensor_parallel_config {
227 let group = CommunicationGroup {
228 id: "tensor_parallel".to_string(),
229 devices: tp_config.devices.clone(),
230 group_type: GroupType::TensorParallel,
231 };
232 self.communication_groups.push(group);
233 }
234
235 Ok(())
236 }
237}
238
239impl Default for ModelParallelConfig {
241 fn default() -> Self {
242 Self {
243 layer_placement: HashMap::new(),
244 pipeline_config: None,
245 tensor_parallel_config: None,
246 }
247 }
248}
249
250pub mod utils {
252 use super::*;
253
254 pub fn calculate_optimal_placement<T>(
256 layers: &[&dyn ParallelLayer<T>],
257 devices: &[Device],
258 memory_constraints: &HashMap<Device, usize>,
259 ) -> Result<HashMap<usize, PlacementStrategy>> {
260 let mut placement = HashMap::new();
261 let mut device_memory_used: HashMap<Device, usize> = HashMap::new();
262
263 for device in devices {
265 device_memory_used.insert(device.clone(), 0);
266 }
267
268 for (layer_idx, layer) in layers.iter().enumerate() {
270 let requirements = layer.memory_requirements()?;
271 let suggested = layer.suggest_placement(devices)?;
272
273 let assigned_device = match suggested {
275 PlacementStrategy::Device(device) => {
276 if let Some(&constraint) = memory_constraints.get(&device) {
277 let used = device_memory_used.get(&device).unwrap_or(&0);
278 if used + requirements.total() <= constraint {
279 Some(device)
280 } else {
281 None
282 }
283 } else {
284 Some(device)
285 }
286 }
287 PlacementStrategy::Auto => {
288 devices
290 .iter()
291 .filter_map(|device| {
292 let constraint = memory_constraints.get(device)?;
293 let used = device_memory_used.get(device).unwrap_or(&0);
294 if used + requirements.total() <= *constraint {
295 Some((device, constraint - used))
296 } else {
297 None
298 }
299 })
300 .max_by_key(|(_, available)| *available)
301 .map(|(device, _)| device.clone())
302 }
303 _ => devices.first().cloned(), };
305
306 if let Some(device) = assigned_device {
307 *device_memory_used
308 .get_mut(&device)
309 .expect("device should exist in memory tracking map") += requirements.total();
310 placement.insert(layer_idx, PlacementStrategy::Device(device));
311 } else {
312 return Err(TensorError::allocation_error_simple(
313 "Cannot find device with sufficient memory for layer".to_string(),
314 ));
315 }
316 }
317
318 Ok(placement)
319 }
320
321 pub fn create_pipeline_config(
323 layer_devices: &HashMap<usize, Device>,
324 micro_batch_size: usize,
325 ) -> Result<PipelineConfig> {
326 let mut stages = Vec::new();
328 let mut current_stage_device = None;
329 let mut stage_devices = Vec::new();
330
331 let mut sorted_layers: Vec<_> = layer_devices.iter().collect();
332 sorted_layers.sort_by_key(|(idx, _)| *idx);
333
334 for (_, device) in sorted_layers {
335 if current_stage_device.as_ref() != Some(device) {
336 if let Some(stage_device) = current_stage_device {
337 stages.push(stage_device);
338 }
339 current_stage_device = Some(device.clone());
340 stage_devices.push(device.clone());
341 }
342 }
343
344 if let Some(device) = current_stage_device {
345 stages.push(device);
346 }
347
348 Ok(PipelineConfig {
349 num_stages: stage_devices.len(),
350 micro_batch_size,
351 num_micro_batches: 4, stage_devices,
353 })
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config starts empty: no placements, no pipeline, no TP.
    #[test]
    fn test_model_parallel_config_creation() {
        let config = ModelParallelConfig::default();
        assert!(config.layer_placement.is_empty());
        assert!(config.pipeline_config.is_none());
        assert!(config.tensor_parallel_config.is_none());
    }

    /// PipelineConfig stores its fields verbatim. The device kind used for
    /// each stage depends on whether the `gpu` feature is enabled.
    #[test]
    fn test_pipeline_config() {
        let devices = vec![
            #[cfg(feature = "gpu")]
            Device::Gpu(0),
            #[cfg(not(feature = "gpu"))]
            Device::Cpu,
            #[cfg(feature = "gpu")]
            Device::Gpu(1),
            #[cfg(not(feature = "gpu"))]
            Device::Cpu,
            #[cfg(feature = "gpu")]
            Device::Gpu(2),
            #[cfg(not(feature = "gpu"))]
            Device::Cpu,
        ];
        let config = PipelineConfig {
            num_stages: 3,
            micro_batch_size: 8,
            num_micro_batches: 4,
            stage_devices: devices,
        };

        assert_eq!(config.num_stages, 3);
        assert_eq!(config.micro_batch_size, 8);
        assert_eq!(config.num_micro_batches, 4);
    }

    /// `total()` is the sum of all four memory categories.
    #[test]
    fn test_memory_requirements() {
        let req = MemoryRequirements {
            parameter_memory: 1000,
            activation_memory: 2000,
            gradient_memory: 1000,
            temp_memory: 500,
        };

        assert_eq!(req.total(), 4500);
    }
}