Skip to main content

trustformers_core/parallel/
parallel_layers.rs

1//! Parallel implementations of neural network layers for model parallelism
2//!
3//! This module provides distributed versions of common layers that can be
4//! split across multiple devices for model parallel training.
5
6#![allow(unused_variables)] // Parallel layer implementation
7
8use super::model_parallel::{DistributedTensor, ModelParallelContext, TensorPartition};
9use crate::errors::{tensor_op_error, Result};
10use crate::Tensor;
11use std::sync::Arc;
12
/// Column-parallel linear layer
///
/// Splits the weight matrix by columns (output dimension) across devices.
/// This is typically used for the first linear layer in MLP blocks.
/// Each rank computes `input @ weight_shard`, yielding a shard of the
/// output's last dimension; the forward pass needs no communication.
pub struct ColumnParallelLinear {
    /// Local weight shard [in_features, out_features_per_device]
    weight: Tensor,
    /// Bias (only on rank 0 to avoid duplication)
    bias: Option<Tensor>,
    /// Model parallel context (rank, world size, collectives)
    mp_context: Arc<ModelParallelContext>,
    /// Total input features
    #[allow(dead_code)]
    in_features: usize,
    /// Total output features (across all devices)
    out_features: usize,
}
30
31impl ColumnParallelLinear {
32    pub fn new(
33        in_features: usize,
34        out_features: usize,
35        bias: bool,
36        mp_context: Arc<ModelParallelContext>,
37    ) -> Result<Self> {
38        let world_size = mp_context.world_size();
39        let rank = mp_context.rank();
40
41        // Calculate local output features
42        let out_features_per_device = out_features.div_ceil(world_size);
43        let local_out_start = rank * out_features_per_device;
44        let local_out_end = ((rank + 1) * out_features_per_device).min(out_features);
45        let local_out_features = local_out_end - local_out_start;
46
47        // Initialize local weight shard
48        let weight = Tensor::randn(&[in_features, local_out_features])?;
49
50        // Bias only on rank 0 to avoid duplication during all-reduce
51        let bias = if bias && rank == 0 { Some(Tensor::zeros(&[out_features])?) } else { None };
52
53        Ok(Self {
54            weight,
55            bias,
56            mp_context,
57            in_features,
58            out_features,
59        })
60    }
61
62    pub fn forward(&self, input: &Tensor) -> Result<DistributedTensor> {
63        // Input: [batch_size, seq_len, in_features]
64        // Weight: [in_features, local_out_features]
65        // Output: [batch_size, seq_len, local_out_features]
66
67        let output = input.matmul(&self.weight)?;
68
69        // Add bias if present (only on rank 0)
70        let output = if let Some(ref bias) = self.bias {
71            // Slice bias to match local output features
72            let rank = self.mp_context.rank();
73            let world_size = self.mp_context.world_size();
74            let out_features_per_device = self.out_features.div_ceil(world_size);
75            let local_out_start = rank * out_features_per_device;
76            let local_out_end = ((rank + 1) * out_features_per_device).min(self.out_features);
77
78            let local_bias = bias.slice(0, local_out_start, local_out_end)?;
79            output.add(&local_bias)?
80        } else {
81            output
82        };
83
84        // Create distributed tensor
85        let mut global_shape = input.shape().to_vec();
86        let last_dim = global_shape.len() - 1;
87        global_shape[last_dim] = self.out_features;
88
89        let partition = TensorPartition {
90            split_dim: global_shape.len() - 1,
91            start_idx: self.mp_context.rank() * self.out_features / self.mp_context.world_size(),
92            end_idx: ((self.mp_context.rank() + 1) * self.out_features
93                / self.mp_context.world_size())
94            .min(self.out_features),
95            num_partitions: self.mp_context.world_size(),
96            partition_rank: self.mp_context.rank(),
97        };
98
99        Ok(DistributedTensor::new(
100            output,
101            global_shape,
102            partition,
103            self.mp_context.rank(),
104        ))
105    }
106
107    pub fn backward(&mut self, grad_output: &DistributedTensor, input: &Tensor) -> Result<Tensor> {
108        // Gradient w.r.t weight: input^T @ grad_output
109        let input_ndim = input.shape().len();
110        let grad_weight = input
111            .transpose(input_ndim.saturating_sub(2), input_ndim.saturating_sub(1))?
112            .matmul(&grad_output.local_shard)?;
113
114        // Gradient w.r.t input: grad_output @ weight^T
115        // This needs all-reduce since each device computes partial gradient
116        let weight_ndim = self.weight.shape().len();
117        let mut grad_input = grad_output.local_shard.matmul(
118            &self
119                .weight
120                .transpose(weight_ndim.saturating_sub(2), weight_ndim.saturating_sub(1))?,
121        )?;
122
123        // All-reduce grad_input across devices
124        self.mp_context.all_reduce(&mut grad_input)?;
125
126        // Update local weight shard
127        // In practice, this would be handled by the optimizer
128        // self.weight = self.weight - learning_rate * grad_weight;
129
130        Ok(grad_input)
131    }
132}
133
/// Row-parallel linear layer
///
/// Splits the weight matrix by rows (input dimension) across devices.
/// This is typically used for the second linear layer in MLP blocks.
/// Each rank multiplies its input shard by its weight shard; the partial
/// outputs are summed with an all-reduce in the forward pass.
pub struct RowParallelLinear {
    /// Local weight shard [in_features_per_device, out_features]
    weight: Tensor,
    /// Bias (replicated on all devices; added once, after the all-reduce)
    bias: Option<Tensor>,
    /// Model parallel context (rank, world size, collectives)
    mp_context: Arc<ModelParallelContext>,
    /// Total input features (across all devices)
    #[allow(dead_code)]
    in_features: usize,
    /// Output features
    _out_features: usize,
}
151
152impl RowParallelLinear {
153    pub fn new(
154        in_features: usize,
155        out_features: usize,
156        bias: bool,
157        mp_context: Arc<ModelParallelContext>,
158    ) -> Result<Self> {
159        let world_size = mp_context.world_size();
160        let rank = mp_context.rank();
161
162        // Calculate local input features
163        let in_features_per_device = in_features.div_ceil(world_size);
164        let local_in_start = rank * in_features_per_device;
165        let local_in_end = ((rank + 1) * in_features_per_device).min(in_features);
166        let local_in_features = local_in_end - local_in_start;
167
168        // Initialize local weight shard
169        let weight = Tensor::randn(&[local_in_features, out_features])?;
170
171        // Bias is replicated on all devices
172        let bias = if bias { Some(Tensor::zeros(&[out_features])?) } else { None };
173
174        Ok(Self {
175            weight,
176            bias,
177            mp_context,
178            in_features,
179            _out_features: out_features,
180        })
181    }
182
183    pub fn forward(&self, input: &DistributedTensor) -> Result<Tensor> {
184        // Input is distributed: [batch_size, seq_len, local_in_features]
185        // Weight: [local_in_features, out_features]
186        // Local output: [batch_size, seq_len, out_features]
187
188        let local_output = input.local_shard.matmul(&self.weight)?;
189
190        // All-reduce to sum contributions from all devices
191        let mut output = local_output;
192        self.mp_context.all_reduce(&mut output)?;
193
194        // Add bias (same on all devices)
195        if let Some(bias) = &self.bias {
196            output = output.add(bias)?;
197        }
198
199        Ok(output)
200    }
201
202    pub fn backward(
203        &mut self,
204        grad_output: &Tensor,
205        input: &DistributedTensor,
206    ) -> Result<DistributedTensor> {
207        // Gradient w.r.t weight: input^T @ grad_output
208        let input_ndim = input.local_shard.shape().len();
209        let grad_weight = input
210            .local_shard
211            .transpose(input_ndim.saturating_sub(2), input_ndim.saturating_sub(1))?
212            .matmul(grad_output)?;
213
214        // Gradient w.r.t input: grad_output @ weight^T
215        let weight_ndim = self.weight.shape().len();
216        let grad_input_local = grad_output.matmul(
217            &self
218                .weight
219                .transpose(weight_ndim.saturating_sub(2), weight_ndim.saturating_sub(1))?,
220        )?;
221
222        // Create distributed gradient tensor
223        let partition = input.partition.clone();
224        Ok(DistributedTensor::new(
225            grad_input_local,
226            input.global_shape.clone(),
227            partition,
228            self.mp_context.rank(),
229        ))
230    }
231}
232
/// Parallel multi-head attention layer
///
/// Distributes attention heads across devices for model parallelism:
/// Q/K/V projections are column-parallel (each rank owns a subset of
/// heads), and the output projection is row-parallel (all-reduced).
pub struct ParallelMultiHeadAttention {
    /// Number of attention heads per device
    #[allow(dead_code)]
    num_heads_per_device: usize,
    /// Total number of heads
    num_heads: usize,
    /// Head dimension (hidden_size / num_heads)
    head_dim: usize,
    /// Hidden size
    hidden_size: usize,
    /// Query projection (column parallel)
    q_proj: ColumnParallelLinear,
    /// Key projection (column parallel)
    k_proj: ColumnParallelLinear,
    /// Value projection (column parallel)
    v_proj: ColumnParallelLinear,
    /// Output projection (row parallel)
    o_proj: RowParallelLinear,
    /// Model parallel context
    mp_context: Arc<ModelParallelContext>,
}
257
258impl ParallelMultiHeadAttention {
259    pub fn new(
260        hidden_size: usize,
261        num_heads: usize,
262        mp_context: Arc<ModelParallelContext>,
263    ) -> Result<Self> {
264        let world_size = mp_context.world_size();
265
266        if num_heads % world_size != 0 {
267            return Err(tensor_op_error(
268                "ParallelMultiHeadAttention::new",
269                format!(
270                    "Number of heads {} must be divisible by world size {}",
271                    num_heads, world_size
272                ),
273            ));
274        }
275
276        let num_heads_per_device = num_heads / world_size;
277        let head_dim = hidden_size / num_heads;
278
279        // Create parallel linear layers
280        let q_proj =
281            ColumnParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;
282
283        let k_proj =
284            ColumnParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;
285
286        let v_proj =
287            ColumnParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;
288
289        let o_proj = RowParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;
290
291        Ok(Self {
292            num_heads_per_device,
293            num_heads,
294            head_dim,
295            hidden_size,
296            q_proj,
297            k_proj,
298            v_proj,
299            o_proj,
300            mp_context,
301        })
302    }
303
304    pub fn forward(
305        &self,
306        hidden_states: &Tensor,
307        attention_mask: Option<&Tensor>,
308    ) -> Result<Tensor> {
309        let batch_size = hidden_states.shape()[0];
310        let seq_len = hidden_states.shape()[1];
311
312        // Project to Q, K, V (distributed across devices)
313        let q = self.q_proj.forward(hidden_states)?;
314        let k = self.k_proj.forward(hidden_states)?;
315        let v = self.v_proj.forward(hidden_states)?;
316
317        // Get local tensor shards for processing
318        let q = q.local_shard.clone();
319        let k = k.local_shard.clone();
320        let v = v.local_shard.clone();
321
322        // Reshape for multi-head attention: [batch, seq_len, hidden_size] -> [batch, seq_len, num_heads_local, head_dim]
323        let num_heads_local = self.num_heads / self.mp_context.world_size();
324        let q = q.reshape(&[batch_size, seq_len, num_heads_local, self.head_dim])?;
325        let k = k.reshape(&[batch_size, seq_len, num_heads_local, self.head_dim])?;
326        let v = v.reshape(&[batch_size, seq_len, num_heads_local, self.head_dim])?;
327
328        // Transpose for attention: [batch, seq_len, num_heads_local, head_dim] -> [batch, num_heads_local, seq_len, head_dim]
329        let q = q.transpose(1, 2)?;
330        let k = k.transpose(1, 2)?;
331        let v = v.transpose(1, 2)?;
332
333        // Compute attention scores
334        let k_ndim = k.shape().len();
335        let scores = q.matmul(&k.transpose(k_ndim.saturating_sub(2), k_ndim.saturating_sub(1))?)?;
336        let scores = scores.scalar_mul(1.0 / (self.head_dim as f32).sqrt())?;
337
338        // Apply attention mask if provided
339        let scores = if let Some(mask) = attention_mask { scores.add(mask)? } else { scores };
340
341        // Softmax over last dimension
342        let scores_ndim = scores.shape().len();
343        let attn_probs = scores.softmax((scores_ndim as i32) - 1)?;
344
345        // Apply attention to values
346        let attn_output = attn_probs.matmul(&v)?;
347
348        // Transpose back: [batch, num_heads_local, seq_len, head_dim] -> [batch, seq_len, num_heads_local, head_dim]
349        let attn_output = attn_output.transpose(1, 2)?;
350
351        // Reshape back to original format: [batch, seq_len, num_heads_local, head_dim] -> [batch, seq_len, hidden_size_local]
352        let hidden_size_local = num_heads_local * self.head_dim;
353        let attn_output = attn_output.reshape(&[batch_size, seq_len, hidden_size_local])?;
354
355        // Create distributed tensor for row-parallel output projection
356        let attn_distributed = DistributedTensor::new(
357            attn_output,
358            vec![batch_size, seq_len, self.hidden_size],
359            TensorPartition {
360                split_dim: 2,
361                start_idx: self.mp_context.rank() * self.hidden_size / self.mp_context.world_size(),
362                end_idx: ((self.mp_context.rank() + 1) * self.hidden_size
363                    / self.mp_context.world_size())
364                .min(self.hidden_size),
365                num_partitions: self.mp_context.world_size(),
366                partition_rank: self.mp_context.rank(),
367            },
368            self.mp_context.rank(),
369        );
370
371        // Output projection with all-reduce
372        self.o_proj.forward(&attn_distributed)
373    }
374}
375
/// Parallel MLP/FFN layer
///
/// Implements the feed-forward network with model parallelism: a
/// column-parallel expansion, an element-wise activation applied to the
/// local shard, then a row-parallel contraction that all-reduces.
pub struct ParallelMLP {
    /// First linear layer (column parallel; shards the intermediate dim)
    fc1: ColumnParallelLinear,
    /// Second linear layer (row parallel; all-reduces the output)
    fc2: RowParallelLinear,
    /// Activation function applied between fc1 and fc2
    activation: ActivationType,
    /// Model parallel context
    #[allow(dead_code)]
    mp_context: Arc<ModelParallelContext>,
}
390
/// Activation functions supported by [`ParallelMLP`].
#[derive(Debug, Clone, Copy)]
pub enum ActivationType {
    /// Rectified linear unit.
    Relu,
    /// Gaussian error linear unit.
    Gelu,
    /// "New" GELU variant (see `ops::activations::gelu_new` — presumably the
    /// tanh approximation; confirm against that op's definition).
    GeluNew,
    /// SwiGLU gated activation; requires an even last dimension, which is
    /// split into gate and up halves.
    Swiglu,
}
398
399impl ParallelMLP {
400    pub fn new(
401        hidden_size: usize,
402        intermediate_size: usize,
403        activation: ActivationType,
404        mp_context: Arc<ModelParallelContext>,
405    ) -> Result<Self> {
406        let fc1 =
407            ColumnParallelLinear::new(hidden_size, intermediate_size, false, mp_context.clone())?;
408
409        let fc2 =
410            RowParallelLinear::new(intermediate_size, hidden_size, false, mp_context.clone())?;
411
412        Ok(Self {
413            fc1,
414            fc2,
415            activation,
416            mp_context,
417        })
418    }
419
420    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
421        // First linear layer (column parallel)
422        let hidden = self.fc1.forward(hidden_states)?;
423
424        // Apply activation function to the local shard
425        let activated = self.apply_activation(&hidden.local_shard)?;
426
427        // Create distributed tensor for row parallel layer
428        let hidden_distributed = DistributedTensor::new(
429            activated,
430            hidden.global_shape.clone(),
431            hidden.partition.clone(),
432            hidden.device_id,
433        );
434
435        // Second linear layer (row parallel)
436        self.fc2.forward(&hidden_distributed)
437    }
438
439    /// Apply activation function to tensor
440    fn apply_activation(&self, tensor: &Tensor) -> Result<Tensor> {
441        use crate::ops::activations::{gelu, gelu_new, relu, swiglu};
442
443        match self.activation {
444            ActivationType::Relu => Ok(relu(tensor)?),
445            ActivationType::Gelu => Ok(gelu(tensor)?),
446            ActivationType::GeluNew => Ok(gelu_new(tensor)?),
447            ActivationType::Swiglu => {
448                // SwiGLU requires splitting the input tensor and applying activation
449                // For SwiGLU: SwiGLU(x) = Swish(Wx) ⊙ Vx where Swish(x) = x * sigmoid(x)
450                let shape = tensor.shape();
451                if shape[shape.len() - 1] % 2 != 0 {
452                    return Err(tensor_op_error(
453                        "ParallelMLP::apply_activation",
454                        "SwiGLU requires even dimension for splitting",
455                    ));
456                }
457
458                let split_size = shape[shape.len() - 1] / 2;
459                let mut new_shape = shape.to_vec();
460                let last_idx = new_shape.len() - 1;
461                new_shape[last_idx] = split_size;
462
463                // Split tensor into two halves along the last axis
464                let last_axis = shape.len() - 1;
465                let gate_tensor = tensor.slice(last_axis, 0, split_size)?;
466                let up_tensor = tensor.slice(last_axis, split_size, shape[last_axis])?;
467
468                // Apply swish to gate and element-wise multiply
469                Ok(swiglu(&gate_tensor, &up_tensor)?)
470            },
471        }
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::super::model_parallel::{CommunicationBackend, ModelParallelConfig};
    use super::*;

    /// The local weight shard of a column-parallel layer must cover
    /// out_features / world_size columns.
    #[test]
    fn test_column_parallel_linear() {
        let config = ModelParallelConfig {
            num_devices: 2,
            device_ids: vec![0, 1],
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };

        let mp_context =
            Arc::new(ModelParallelContext::new(config).expect("operation failed in test"));
        let layer = ColumnParallelLinear::new(512, 2048, true, mp_context)
            .expect("operation failed in test");

        // Check weight dimensions
        assert_eq!(layer.weight.shape(), &[512, 1024]); // 2048 / 2 = 1024
    }

    /// Heads must distribute evenly across devices and head_dim must be
    /// hidden_size / num_heads.
    #[test]
    fn test_parallel_attention_heads() {
        let config = ModelParallelConfig {
            num_devices: 4,
            device_ids: vec![0, 1, 2, 3],
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };

        let mp_context =
            Arc::new(ModelParallelContext::new(config).expect("operation failed in test"));
        let attn =
            ParallelMultiHeadAttention::new(768, 12, mp_context).expect("operation failed in test");

        assert_eq!(attn.num_heads_per_device, 3); // 12 / 4 = 3
        assert_eq!(attn.head_dim, 64); // 768 / 12 = 64
    }
}
515}