//! Tensor-parallel building blocks: column- and row-parallel linear layers,
//! a model-parallel multi-head attention block, and a parallel MLP.

use super::model_parallel::{DistributedTensor, ModelParallelContext, TensorPartition};
use crate::errors::{tensor_op_error, Result};
use crate::Tensor;
use std::sync::Arc;

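/// Linear layer whose weight is partitioned column-wise (along the output
/// feature dimension) across model-parallel ranks: each rank holds a
/// `[in_features, local_out_features]` shard and produces the matching
/// slice of the output as a `DistributedTensor`.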
pub struct ColumnParallelLinear {
    weight: Tensor,
    bias: Option<Tensor>,
    mp_context: Arc<ModelParallelContext>,
    #[allow(dead_code)]
    in_features: usize,
    out_features: usize,
}

impl ColumnParallelLinear {
    /// Creates a column-parallel linear layer; each rank allocates only its
    /// own shard of the weight matrix.
    pub fn new(
        in_features: usize,
        out_features: usize,
        bias: bool,
        mp_context: Arc<ModelParallelContext>,
    ) -> Result<Self> {
        let world_size = mp_context.world_size();
        let rank = mp_context.rank();

        // Shard boundaries along the output dimension; div_ceil gives every
        // rank the same shard size except possibly the last.
        let out_features_per_device = out_features.div_ceil(world_size);
        let local_out_start = rank * out_features_per_device;
        let local_out_end = ((rank + 1) * out_features_per_device).min(out_features);
        let local_out_features = local_out_end - local_out_start;

        let weight = Tensor::randn(&[in_features, local_out_features])?;

        // Every rank keeps the full bias vector and adds only its local
        // slice in `forward`, so all output shards receive a bias term.
        let bias = if bias { Some(Tensor::zeros(&[out_features])?) } else { None };

        Ok(Self {
            weight,
            bias,
            mp_context,
            in_features,
            out_features,
        })
    }

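    /// Forward pass: multiplies the replicated input by the local weight
    /// shard and returns the result wrapped in a `DistributedTensor` that
    /// records which output columns this rank owns.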
    pub fn forward(&self, input: &Tensor) -> Result<DistributedTensor> {
        let rank = self.mp_context.rank();
        let world_size = self.mp_context.world_size();

        // Recompute this rank's shard boundaries with the same div_ceil
        // formula used in `new`, so the partition metadata matches the
        // actual weight shard.
        let out_features_per_device = self.out_features.div_ceil(world_size);
        let local_out_start = rank * out_features_per_device;
        let local_out_end = ((rank + 1) * out_features_per_device).min(self.out_features);

        let output = input.matmul(&self.weight)?;

        let output = if let Some(ref bias) = self.bias {
            let local_bias = bias.slice(0, local_out_start, local_out_end)?;
            output.add(&local_bias)?
        } else {
            output
        };

        // The logical (unsharded) output keeps the input shape except for
        // the last dimension, which becomes the full `out_features`.
        let mut global_shape = input.shape().to_vec();
        let last_dim = global_shape.len() - 1;
        global_shape[last_dim] = self.out_features;

        let partition = TensorPartition {
            split_dim: last_dim,
            start_idx: local_out_start,
            end_idx: local_out_end,
            num_partitions: world_size,
            partition_rank: rank,
        };

        Ok(DistributedTensor::new(output, global_shape, partition, rank))
    }

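    /// Backward pass: computes the local weight gradient and the gradient
    /// with respect to the input. Because the forward matmul used only a
    /// column shard, the input gradient must be summed (all-reduced)
    /// across ranks.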
    pub fn backward(&mut self, grad_output: &DistributedTensor, input: &Tensor) -> Result<Tensor> {
        // Weight gradient: input^T @ grad_output_local. It is computed but
        // not applied here; parameter updates happen outside this layer.
        let input_ndim = input.shape().len();
        let _grad_weight = input
            .transpose(input_ndim.saturating_sub(2), input_ndim.saturating_sub(1))?
            .matmul(&grad_output.local_shard)?;

        // Input gradient: grad_output_local @ W_local^T, then all-reduce to
        // sum the partial products from every rank's column shard.
        let weight_ndim = self.weight.shape().len();
        let mut grad_input = grad_output.local_shard.matmul(
            &self
                .weight
                .transpose(weight_ndim.saturating_sub(2), weight_ndim.saturating_sub(1))?,
        )?;

        self.mp_context.all_reduce(&mut grad_input)?;

        Ok(grad_input)
    }
}

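/// Linear layer whose weight is partitioned row-wise (along the input
/// feature dimension): each rank multiplies its shard of the input by a
/// `[local_in_features, out_features]` weight shard, and the partial
/// results are summed with an all-reduce.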
pub struct RowParallelLinear {
    weight: Tensor,
    bias: Option<Tensor>,
    mp_context: Arc<ModelParallelContext>,
    #[allow(dead_code)]
    in_features: usize,
    _out_features: usize,
}

impl RowParallelLinear {
    pub fn new(
        in_features: usize,
        out_features: usize,
        bias: bool,
        mp_context: Arc<ModelParallelContext>,
    ) -> Result<Self> {
        let world_size = mp_context.world_size();
        let rank = mp_context.rank();

        let in_features_per_device = in_features.div_ceil(world_size);
        let local_in_start = rank * in_features_per_device;
        let local_in_end = ((rank + 1) * in_features_per_device).min(in_features);
        let local_in_features = local_in_end - local_in_start;

        let weight = Tensor::randn(&[local_in_features, out_features])?;

        let bias = if bias { Some(Tensor::zeros(&[out_features])?) } else { None };

        Ok(Self {
            weight,
            bias,
            mp_context,
            in_features,
            _out_features: out_features,
        })
    }

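    /// Forward pass: local matmul against the row shard, followed by an
    /// all-reduce that sums the partial outputs; the bias is added once,
    /// after the reduction.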
    pub fn forward(&self, input: &DistributedTensor) -> Result<Tensor> {
        let local_output = input.local_shard.matmul(&self.weight)?;

        let mut output = local_output;
        self.mp_context.all_reduce(&mut output)?;

        if let Some(bias) = &self.bias {
            output = output.add(bias)?;
        }

        Ok(output)
    }

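    /// Backward pass: the input gradient stays sharded the same way as the
    /// input, so no communication is needed here.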
    pub fn backward(
        &mut self,
        grad_output: &Tensor,
        input: &DistributedTensor,
    ) -> Result<DistributedTensor> {
        // Weight gradient: input_local^T @ grad_output. Computed but not
        // applied here; parameter updates happen outside this layer.
        let input_ndim = input.local_shard.shape().len();
        let _grad_weight = input
            .local_shard
            .transpose(input_ndim.saturating_sub(2), input_ndim.saturating_sub(1))?
            .matmul(grad_output)?;

        // Input gradient: grad_output @ W_local^T yields exactly this
        // rank's shard of the input gradient.
        let weight_ndim = self.weight.shape().len();
        let grad_input_local = grad_output.matmul(
            &self
                .weight
                .transpose(weight_ndim.saturating_sub(2), weight_ndim.saturating_sub(1))?,
        )?;

        let partition = input.partition.clone();
        Ok(DistributedTensor::new(
            grad_input_local,
            input.global_shape.clone(),
            partition,
            self.mp_context.rank(),
        ))
    }
}

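/// Multi-head attention with tensor-parallel projections: Q, K, and V are
/// column-parallel (each rank owns a contiguous subset of heads) while the
/// output projection is row-parallel, so a single all-reduce recombines
/// the per-rank head outputs.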
pub struct ParallelMultiHeadAttention {
    num_heads_per_device: usize,
    #[allow(dead_code)]
    num_heads: usize,
    head_dim: usize,
    hidden_size: usize,
    q_proj: ColumnParallelLinear,
    k_proj: ColumnParallelLinear,
    v_proj: ColumnParallelLinear,
    o_proj: RowParallelLinear,
    mp_context: Arc<ModelParallelContext>,
}

impl ParallelMultiHeadAttention {
    /// Creates the parallel attention block. Fails if the heads do not
    /// divide evenly across ranks, or the hidden size across heads.
    pub fn new(
        hidden_size: usize,
        num_heads: usize,
        mp_context: Arc<ModelParallelContext>,
    ) -> Result<Self> {
        let world_size = mp_context.world_size();

        if num_heads % world_size != 0 {
            return Err(tensor_op_error(
                "ParallelMultiHeadAttention::new",
                format!(
                    "Number of heads {} must be divisible by world size {}",
                    num_heads, world_size
                ),
            ));
        }

        if hidden_size % num_heads != 0 {
            return Err(tensor_op_error(
                "ParallelMultiHeadAttention::new",
                format!(
                    "Hidden size {} must be divisible by number of heads {}",
                    hidden_size, num_heads
                ),
            ));
        }

        let num_heads_per_device = num_heads / world_size;
        let head_dim = hidden_size / num_heads;

        let q_proj =
            ColumnParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;

        let k_proj =
            ColumnParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;

        let v_proj =
            ColumnParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;

        let o_proj = RowParallelLinear::new(hidden_size, hidden_size, false, mp_context.clone())?;

        Ok(Self {
            num_heads_per_device,
            num_heads,
            head_dim,
            hidden_size,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            mp_context,
        })
    }

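    /// Forward pass over this rank's heads only: the Q/K/V shards are
    /// reshaped to `[batch, heads_local, seq, head_dim]`, attention is
    /// computed locally, and the row-parallel output projection's
    /// all-reduce recombines the heads.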
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let batch_size = hidden_states.shape()[0];
        let seq_len = hidden_states.shape()[1];

        let q = self.q_proj.forward(hidden_states)?;
        let k = self.k_proj.forward(hidden_states)?;
        let v = self.v_proj.forward(hidden_states)?;

        // Work directly on the local shards: each rank holds the features
        // for its own subset of heads.
        let q = q.local_shard.clone();
        let k = k.local_shard.clone();
        let v = v.local_shard.clone();

        let num_heads_local = self.num_heads_per_device;
        let q = q.reshape(&[batch_size, seq_len, num_heads_local, self.head_dim])?;
        let k = k.reshape(&[batch_size, seq_len, num_heads_local, self.head_dim])?;
        let v = v.reshape(&[batch_size, seq_len, num_heads_local, self.head_dim])?;

        // [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
        let q = q.transpose(1, 2)?;
        let k = k.transpose(1, 2)?;
        let v = v.transpose(1, 2)?;

        // Scaled dot-product attention over the local heads.
        let k_ndim = k.shape().len();
        let scores = q.matmul(&k.transpose(k_ndim.saturating_sub(2), k_ndim.saturating_sub(1))?)?;
        let scores = scores.scalar_mul(1.0 / (self.head_dim as f32).sqrt())?;

        let scores = if let Some(mask) = attention_mask { scores.add(mask)? } else { scores };

        let scores_ndim = scores.shape().len();
        let attn_probs = scores.softmax((scores_ndim as i32) - 1)?;

        let attn_output = attn_probs.matmul(&v)?;

        // Back to [batch, seq, hidden_local] for the output projection.
        let attn_output = attn_output.transpose(1, 2)?;
        let hidden_size_local = num_heads_local * self.head_dim;
        let attn_output = attn_output.reshape(&[batch_size, seq_len, hidden_size_local])?;

        let rank = self.mp_context.rank();
        let attn_distributed = DistributedTensor::new(
            attn_output,
            vec![batch_size, seq_len, self.hidden_size],
            TensorPartition {
                split_dim: 2,
                start_idx: rank * hidden_size_local,
                end_idx: ((rank + 1) * hidden_size_local).min(self.hidden_size),
                num_partitions: self.mp_context.world_size(),
                partition_rank: rank,
            },
            rank,
        );

        self.o_proj.forward(&attn_distributed)
    }
}

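/// Two-layer feed-forward block: `fc1` is column-parallel and `fc2` is
/// row-parallel, so the activation between them is applied locally on each
/// rank's shard and the only communication is `fc2`'s all-reduce.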
pub struct ParallelMLP {
    fc1: ColumnParallelLinear,
    fc2: RowParallelLinear,
    activation: ActivationType,
    #[allow(dead_code)]
    mp_context: Arc<ModelParallelContext>,
}

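/// Activation functions supported by `ParallelMLP`. For `Swiglu`, the
/// incoming tensor must hold the gate and up projections concatenated
/// along its last dimension.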
#[derive(Debug, Clone, Copy)]
pub enum ActivationType {
    Relu,
    Gelu,
    GeluNew,
    Swiglu,
}

impl ParallelMLP {
    pub fn new(
        hidden_size: usize,
        intermediate_size: usize,
        activation: ActivationType,
        mp_context: Arc<ModelParallelContext>,
    ) -> Result<Self> {
        let fc1 =
            ColumnParallelLinear::new(hidden_size, intermediate_size, false, mp_context.clone())?;

        let fc2 =
            RowParallelLinear::new(intermediate_size, hidden_size, false, mp_context.clone())?;

        Ok(Self {
            fc1,
            fc2,
            activation,
            mp_context,
        })
    }

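    /// Forward pass: `fc1` produces a sharded intermediate, the activation
    /// runs on the local shard, and the result is re-wrapped with `fc1`'s
    /// partition metadata before `fc2`'s all-reduce completes the output.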
    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden = self.fc1.forward(hidden_states)?;

        let activated = self.apply_activation(&hidden.local_shard)?;

        // Re-wrap the activated shard with fc1's partition metadata so fc2
        // can consume it as a distributed tensor.
        let hidden_distributed = DistributedTensor::new(
            activated,
            hidden.global_shape.clone(),
            hidden.partition.clone(),
            hidden.device_id,
        );

        self.fc2.forward(&hidden_distributed)
    }

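    /// Applies the configured activation to a local shard, splitting the
    /// last dimension in half into gate and up tensors for `Swiglu`.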
    fn apply_activation(&self, tensor: &Tensor) -> Result<Tensor> {
        use crate::ops::activations::{gelu, gelu_new, relu, swiglu};

        match self.activation {
            ActivationType::Relu => Ok(relu(tensor)?),
            ActivationType::Gelu => Ok(gelu(tensor)?),
            ActivationType::GeluNew => Ok(gelu_new(tensor)?),
            ActivationType::Swiglu => {
                // Split the last dimension in half: the first half gates
                // the second (the "up" projection).
                let shape = tensor.shape();
                let last_axis = shape.len() - 1;
                if shape[last_axis] % 2 != 0 {
                    return Err(tensor_op_error(
                        "ParallelMLP::apply_activation",
                        "SwiGLU requires an even last dimension for splitting",
                    ));
                }

                let split_size = shape[last_axis] / 2;
                let gate_tensor = tensor.slice(last_axis, 0, split_size)?;
                let up_tensor = tensor.slice(last_axis, split_size, shape[last_axis])?;

                Ok(swiglu(&gate_tensor, &up_tensor)?)
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::super::model_parallel::{CommunicationBackend, ModelParallelConfig};
    use super::*;

    #[test]
    fn test_column_parallel_linear() {
        let config = ModelParallelConfig {
            num_devices: 2,
            device_ids: vec![0, 1],
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };

        let mp_context =
            Arc::new(ModelParallelContext::new(config).expect("operation failed in test"));
        let layer = ColumnParallelLinear::new(512, 2048, true, mp_context)
            .expect("operation failed in test");

        // 2048 output features split across 2 devices -> 1024 columns each.
        assert_eq!(layer.weight.shape(), &[512, 1024]);
    }

    #[test]
    fn test_parallel_attention_heads() {
        let config = ModelParallelConfig {
            num_devices: 4,
            device_ids: vec![0, 1, 2, 3],
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };

        let mp_context =
            Arc::new(ModelParallelContext::new(config).expect("operation failed in test"));
        let attn =
            ParallelMultiHeadAttention::new(768, 12, mp_context).expect("operation failed in test");

        // 12 heads over 4 devices -> 3 per device; 768 / 12 -> head_dim 64.
        assert_eq!(attn.num_heads_per_device, 3);
        assert_eq!(attn.head_dim, 64);
    }
}