//! Model-parallel execution support: device placement strategies,
//! pipeline/tensor parallelism configuration, and a coordinator that
//! assigns model layers to devices and sets up communication groups.
1use crate::Layer;
2use std::collections::HashMap;
3use tenflowers_core::{Device, Result, TensorError};
4
/// Device placement strategy for layer splitting.
///
/// Determines where a single layer lives — or how it is partitioned —
/// when a model is distributed across multiple devices.
#[derive(Debug, Clone, PartialEq)]
pub enum PlacementStrategy {
    /// Automatically choose a device based on layer characteristics
    /// (memory footprint, compute intensity, ...).
    Auto,
    /// Place the whole layer on one specific device.
    Device(Device),
    /// Split the layer across multiple devices (tensor parallelism).
    Split(Vec<Device>),
    /// Assign the layer to pipeline stage `stage`, running on `device`.
    Pipeline { stage: usize, device: Device },
}
17
/// Model parallelism configuration for a model.
///
/// Combines per-layer placement with optional pipeline and tensor
/// parallelism settings; the `Default` value describes a model with no
/// parallelism configured at all.
#[derive(Debug, Clone)]
pub struct ModelParallelConfig {
    /// Device placement for each layer, keyed by layer index. Layers
    /// absent from this map fall back to a default device during
    /// assignment.
    pub layer_placement: HashMap<usize, PlacementStrategy>,
    /// Pipeline configuration if using pipeline parallelism.
    pub pipeline_config: Option<PipelineConfig>,
    /// Tensor parallelism configuration, if enabled.
    pub tensor_parallel_config: Option<TensorParallelConfig>,
}
28
/// Pipeline parallelism configuration.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Number of pipeline stages.
    pub num_stages: usize,
    /// Micro-batch size for pipeline execution.
    pub micro_batch_size: usize,
    /// Number of micro-batches per batch.
    pub num_micro_batches: usize,
    /// Device for each pipeline stage. Expected to hold at least
    /// `num_stages` entries (one per stage) — consumers index into it
    /// by stage number.
    pub stage_devices: Vec<Device>,
}
41
/// Tensor parallelism configuration.
#[derive(Debug, Clone)]
pub struct TensorParallelConfig {
    /// Tensor parallel size (number of participating devices).
    pub tp_size: usize,
    /// Devices taking part in tensor parallelism.
    pub devices: Vec<Device>,
    /// Dimension to split tensors along (0 for rows, 1 for columns).
    pub split_dim: usize,
}
52
/// Result of splitting one layer across several devices.
pub struct SplitLayer<T> {
    /// Sub-layers, one per participating device.
    pub sub_layers: Vec<Box<dyn Layer<T>>>,
    /// Device assignment for each sub-layer (parallel to `sub_layers`).
    pub device_assignment: Vec<Device>,
    /// How sub-layer outputs are combined into the full layer output.
    pub communication_pattern: CommunicationPattern,
}
62
/// Communication patterns for combining split-layer results.
#[derive(Debug, Clone)]
pub enum CommunicationPattern {
    /// Simple concatenation along the given dimension.
    Concat { dim: usize },
    /// All-reduce operation (sum the partial results).
    AllReduce,
    /// Gather all results onto one target device.
    Gather { target_device: Device },
    /// Custom communication function, identified by name.
    /// NOTE(review): the lookup mechanism for `pattern_id` is not
    /// visible in this file — confirm where it is resolved.
    Custom { pattern_id: String },
}
75
/// Extended layer trait for model parallelism.
///
/// Adds the introspection hooks (memory footprint, splitability,
/// device movement) that placement algorithms such as
/// `utils::calculate_optimal_placement` rely on.
pub trait ParallelLayer<T>: Layer<T> {
    /// Get the memory requirements of this layer.
    fn memory_requirements(&self) -> Result<MemoryRequirements>;

    /// Split this layer across multiple devices.
    fn split_across_devices(&self, devices: &[Device]) -> Result<SplitLayer<T>>;

    /// Suggest an optimal device placement for this layer, given the
    /// devices available.
    fn suggest_placement(&self, available_devices: &[Device]) -> Result<PlacementStrategy>;

    /// Move the layer's state to a specific device.
    fn to_device(&mut self, device: &Device) -> Result<()>;

    /// Get the current device placement, or `None` if not yet placed.
    fn current_device(&self) -> Option<Device>;

    /// Check whether this layer supports being split.
    fn can_split(&self) -> bool;

    /// Get computation intensity (FLOPs per parameter).
    fn computation_intensity(&self) -> f64;
}
99
/// Memory requirements for a layer, broken down by category (bytes).
#[derive(Debug, Clone)]
pub struct MemoryRequirements {
    /// Parameter memory in bytes.
    pub parameter_memory: usize,
    /// Activation memory in bytes (forward pass).
    pub activation_memory: usize,
    /// Gradient memory in bytes (backward pass).
    pub gradient_memory: usize,
    /// Temporary buffer memory in bytes.
    pub temp_memory: usize,
}

impl MemoryRequirements {
    /// Total memory requirement across all categories, in bytes.
    ///
    /// Uses saturating addition so a pathological layer description
    /// cannot overflow (a plain `+` panics in debug builds); the result
    /// caps at `usize::MAX` instead.
    pub fn total(&self) -> usize {
        self.parameter_memory
            .saturating_add(self.activation_memory)
            .saturating_add(self.gradient_memory)
            .saturating_add(self.temp_memory)
    }
}
119
/// Model parallelism coordinator.
///
/// Owns the parallelism configuration and derives from it the concrete
/// per-layer device assignments and communication groups.
pub struct ModelParallelCoordinator {
    // Source-of-truth configuration supplied at construction.
    config: ModelParallelConfig,
    // Resolved device per layer index; filled by `assign_devices`.
    layer_assignments: HashMap<usize, Device>,
    // Groups created by `setup_communication_groups`.
    communication_groups: Vec<CommunicationGroup>,
}
126
/// Communication group for model parallelism.
#[derive(Debug, Clone)]
pub struct CommunicationGroup {
    /// Group ID (e.g. `"pipeline_stage_0"` or `"tensor_parallel"`).
    pub id: String,
    /// Devices participating in this group.
    pub devices: Vec<Device>,
    /// Group type (pipeline stage, tensor parallel group, etc.).
    pub group_type: GroupType,
}
137
/// Types of communication groups.
#[derive(Debug, Clone)]
pub enum GroupType {
    /// Pipeline stage group, tagged with the stage index.
    PipelineStage(usize),
    /// Tensor parallel group.
    TensorParallel,
    /// Data parallel group.
    DataParallel,
    /// Custom group with a free-form identifier.
    Custom(String),
}
150
151impl ModelParallelCoordinator {
152    /// Create new model parallelism coordinator
153    pub fn new(config: ModelParallelConfig) -> Self {
154        Self {
155            config,
156            layer_assignments: HashMap::new(),
157            communication_groups: Vec::new(),
158        }
159    }
160
161    /// Assign devices to layers based on configuration
162    pub fn assign_devices(&mut self, num_layers: usize) -> Result<()> {
163        // Implement device assignment logic
164        for layer_idx in 0..num_layers {
165            let device = if let Some(strategy) = self.config.layer_placement.get(&layer_idx) {
166                match strategy {
167                    PlacementStrategy::Device(device) => device.clone(),
168                    PlacementStrategy::Pipeline { device, .. } => device.clone(),
169                    PlacementStrategy::Auto => {
170                        // Auto-assign based on layer index and available devices
171                        // For now, use simple round-robin
172                        #[cfg(feature = "gpu")]
173                        {
174                            let device_idx = layer_idx % 4; // Assume 4 devices
175                            Device::Gpu(device_idx)
176                        }
177                        #[cfg(not(feature = "gpu"))]
178                        {
179                            Device::Cpu
180                        }
181                    }
182                    PlacementStrategy::Split(_) => {
183                        // For split layers, assign to first device for now
184                        // Full implementation would handle split coordination
185                        #[cfg(feature = "gpu")]
186                        {
187                            Device::Gpu(0)
188                        }
189                        #[cfg(not(feature = "gpu"))]
190                        {
191                            Device::Cpu
192                        }
193                    }
194                }
195            } else {
196                // Default to CPU if no assignment
197                Device::Cpu
198            };
199
200            self.layer_assignments.insert(layer_idx, device);
201        }
202
203        Ok(())
204    }
205
206    /// Get device assignment for a specific layer
207    pub fn get_layer_device(&self, layer_idx: usize) -> Option<&Device> {
208        self.layer_assignments.get(&layer_idx)
209    }
210
211    /// Setup communication groups for the parallel execution
212    pub fn setup_communication_groups(&mut self) -> Result<()> {
213        // Setup pipeline stage groups if using pipeline parallelism
214        if let Some(pipeline_config) = &self.config.pipeline_config {
215            for stage in 0..pipeline_config.num_stages {
216                let group = CommunicationGroup {
217                    id: format!("pipeline_stage_{stage}"),
218                    devices: vec![pipeline_config.stage_devices[stage].clone()],
219                    group_type: GroupType::PipelineStage(stage),
220                };
221                self.communication_groups.push(group);
222            }
223        }
224
225        // Setup tensor parallel groups if using tensor parallelism
226        if let Some(tp_config) = &self.config.tensor_parallel_config {
227            let group = CommunicationGroup {
228                id: "tensor_parallel".to_string(),
229                devices: tp_config.devices.clone(),
230                group_type: GroupType::TensorParallel,
231            };
232            self.communication_groups.push(group);
233        }
234
235        Ok(())
236    }
237}
238
239/// Default implementation for ModelParallelConfig
240impl Default for ModelParallelConfig {
241    fn default() -> Self {
242        Self {
243            layer_placement: HashMap::new(),
244            pipeline_config: None,
245            tensor_parallel_config: None,
246        }
247    }
248}
249
250/// Utility functions for model parallelism
251pub mod utils {
252    use super::*;
253
254    /// Calculate optimal layer placement based on memory constraints
255    pub fn calculate_optimal_placement<T>(
256        layers: &[&dyn ParallelLayer<T>],
257        devices: &[Device],
258        memory_constraints: &HashMap<Device, usize>,
259    ) -> Result<HashMap<usize, PlacementStrategy>> {
260        let mut placement = HashMap::new();
261        let mut device_memory_used: HashMap<Device, usize> = HashMap::new();
262
263        // Initialize device memory usage
264        for device in devices {
265            device_memory_used.insert(device.clone(), 0);
266        }
267
268        // Assign each layer to best fitting device
269        for (layer_idx, layer) in layers.iter().enumerate() {
270            let requirements = layer.memory_requirements()?;
271            let suggested = layer.suggest_placement(devices)?;
272
273            // Find device with sufficient memory
274            let assigned_device = match suggested {
275                PlacementStrategy::Device(device) => {
276                    if let Some(&constraint) = memory_constraints.get(&device) {
277                        let used = device_memory_used.get(&device).unwrap_or(&0);
278                        if used + requirements.total() <= constraint {
279                            Some(device)
280                        } else {
281                            None
282                        }
283                    } else {
284                        Some(device)
285                    }
286                }
287                PlacementStrategy::Auto => {
288                    // Find device with most available memory
289                    devices
290                        .iter()
291                        .filter_map(|device| {
292                            let constraint = memory_constraints.get(device)?;
293                            let used = device_memory_used.get(device).unwrap_or(&0);
294                            if used + requirements.total() <= *constraint {
295                                Some((device, constraint - used))
296                            } else {
297                                None
298                            }
299                        })
300                        .max_by_key(|(_, available)| *available)
301                        .map(|(device, _)| device.clone())
302                }
303                _ => devices.first().cloned(), // Fallback
304            };
305
306            if let Some(device) = assigned_device {
307                *device_memory_used
308                    .get_mut(&device)
309                    .expect("device should exist in memory tracking map") += requirements.total();
310                placement.insert(layer_idx, PlacementStrategy::Device(device));
311            } else {
312                return Err(TensorError::allocation_error_simple(
313                    "Cannot find device with sufficient memory for layer".to_string(),
314                ));
315            }
316        }
317
318        Ok(placement)
319    }
320
321    /// Create pipeline configuration from layer assignments
322    pub fn create_pipeline_config(
323        layer_devices: &HashMap<usize, Device>,
324        micro_batch_size: usize,
325    ) -> Result<PipelineConfig> {
326        // Group consecutive layers on same device into stages
327        let mut stages = Vec::new();
328        let mut current_stage_device = None;
329        let mut stage_devices = Vec::new();
330
331        let mut sorted_layers: Vec<_> = layer_devices.iter().collect();
332        sorted_layers.sort_by_key(|(idx, _)| *idx);
333
334        for (_, device) in sorted_layers {
335            if current_stage_device.as_ref() != Some(device) {
336                if let Some(stage_device) = current_stage_device {
337                    stages.push(stage_device);
338                }
339                current_stage_device = Some(device.clone());
340                stage_devices.push(device.clone());
341            }
342        }
343
344        if let Some(device) = current_stage_device {
345            stages.push(device);
346        }
347
348        Ok(PipelineConfig {
349            num_stages: stage_devices.len(),
350            micro_batch_size,
351            num_micro_batches: 4, // Default to 4 micro-batches
352            stage_devices,
353        })
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    // Default config should describe "no parallelism at all".
    #[test]
    fn test_model_parallel_config_creation() {
        let config = ModelParallelConfig::default();
        assert!(config.layer_placement.is_empty());
        assert!(config.pipeline_config.is_none());
        assert!(config.tensor_parallel_config.is_none());
    }

    // Builds a three-stage pipeline; the cfg attributes pick GPU
    // devices when the "gpu" feature is on and CPU placeholders
    // otherwise, so the list always has three entries.
    #[test]
    fn test_pipeline_config() {
        let devices = vec![
            #[cfg(feature = "gpu")]
            Device::Gpu(0),
            #[cfg(not(feature = "gpu"))]
            Device::Cpu,
            #[cfg(feature = "gpu")]
            Device::Gpu(1),
            #[cfg(not(feature = "gpu"))]
            Device::Cpu,
            #[cfg(feature = "gpu")]
            Device::Gpu(2),
            #[cfg(not(feature = "gpu"))]
            Device::Cpu,
        ];
        let config = PipelineConfig {
            num_stages: 3,
            micro_batch_size: 8,
            num_micro_batches: 4,
            stage_devices: devices,
        };

        assert_eq!(config.num_stages, 3);
        assert_eq!(config.micro_batch_size, 8);
        assert_eq!(config.num_micro_batches, 4);
    }

    // total() must be the sum of all four memory categories.
    #[test]
    fn test_memory_requirements() {
        let req = MemoryRequirements {
            parameter_memory: 1000,
            activation_memory: 2000,
            gradient_memory: 1000,
            temp_memory: 500,
        };

        assert_eq!(req.total(), 4500);
    }
}