use bytemuck::Pod;
use nalgebra::{DVector, Dyn, Storage, Vector};
use wgcore::kernel::{KernelInvocationBuilder, KernelInvocationQueue};
use wgcore::tensor::GpuVectorView;
use wgcore::Shader;
use wgebra::linalg::Shape;
use wgpu::ComputePipeline;

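/// GPU compute kernel for RMS (Root Mean Square) normalization of a vector, scaled
/// componentwise by a weight vector (see [`RmsNorm::run_cpu`] for the reference formula).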
#[derive(Shader)]
#[shader(derive(Shape), src = "rms_norm.wgsl", composable = false)]
pub struct RmsNorm {
    pub main: ComputePipeline,
}

impl RmsNorm {
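    /// Queues this kernel to apply RMS normalization to `value`, scaled componentwise by
    /// `weight`, writing the normalized values into `result`.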
    pub fn queue<'a, 'b, T: Pod>(
        &'a self,
        queue: &mut KernelInvocationQueue<'a>,
        result: impl Into<GpuVectorView<'b, T>>,
        value: impl Into<GpuVectorView<'b, T>>,
        weight: impl Into<GpuVectorView<'b, T>>,
    ) {
        let value = value.into();
        let weight = weight.into();
        let result = result.into();

        let value_shape_buf = queue.shape_buffer(value.shape());
        let weight_shape_buf = queue.shape_buffer(weight.shape());
        let result_shape_buf = queue.shape_buffer(result.shape());

        KernelInvocationBuilder::new(queue, &self.main)
            .bind0([
                &value_shape_buf,
                &weight_shape_buf,
                &result_shape_buf,
                value.buffer(),
                weight.buffer(),
                result.buffer(),
            ])
            .queue(1);
    }

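    /// CPU reference implementation of RMS normalization:
    /// `out[i] = a[i] * w[i] / sqrt(mean(a²) + 1.0e-5)`.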
    pub fn run_cpu<SW: Storage<f32, Dyn>>(
        out: &mut DVector<f32>,
        a: &DVector<f32>,
        w: &Vector<f32, Dyn, SW>,
    ) {
        const NUDGE_FACTOR: f32 = 1.0e-5;
        let rms = 1.0 / (a.norm_squared() / (a.nrows() as f32) + NUDGE_FACTOR).sqrt();
        out.zip_zip_apply(a, w, |o, a, w| *o = (a * rms) * w);
    }
}

#[cfg(test)]
mod test {
    use crate::ops::RmsNorm;
    use nalgebra::DVector;
    use wgcore::gpu::GpuInstance;
    use wgcore::kernel::KernelInvocationQueue;
    use wgcore::tensor::GpuVector;
    use wgcore::Shader;
    use wgpu::BufferUsages;

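    // Checks that the GPU kernel matches the CPU reference implementation on random inputs.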
    #[futures_test::test]
    #[serial_test::serial]
    async fn gpu_rms_norm() {
        let gpu = GpuInstance::new().await.unwrap();
        let rmsnorm = super::RmsNorm::from_device(gpu.device());
        let mut queue = KernelInvocationQueue::new(gpu.device());
        let mut encoder = gpu.device().create_command_encoder(&Default::default());

        const LEN: u32 = 1757;

        let result = DVector::new_random(LEN as usize);
        let value = DVector::new_random(LEN as usize);
        let weight = DVector::new_random(LEN as usize);

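        // Upload the inputs to GPU storage buffers; `gpu_result` can be copied out, and the
        // staging buffer is mappable so the result can be read back on the CPU.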
        let gpu_result = GpuVector::init(
            gpu.device(),
            &result,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        );
        let gpu_value = GpuVector::init(gpu.device(), &value, BufferUsages::STORAGE);
        let gpu_weight = GpuVector::init(gpu.device(), &weight, BufferUsages::STORAGE);
        let gpu_staging = GpuVector::uninit(
            gpu.device(),
            result.len() as u32,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

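        // Queue the kernel, encode it together with a copy of the result into the staging
        // buffer, and submit everything to the GPU.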
        rmsnorm.queue(&mut queue, &gpu_result, &gpu_value, &gpu_weight);

        queue.encode(&mut encoder, None);
        gpu_staging.copy_from(&mut encoder, &gpu_result);

        gpu.queue().submit(Some(encoder.finish()));

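        // Run the CPU reference on the same inputs and compare it with the values read back
        // from the GPU.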
        let mut cpu_result = result;
        RmsNorm::run_cpu(&mut cpu_result, &value, &weight);

        approx::assert_relative_eq!(
            DVector::from(gpu_staging.read(gpu.device()).await.unwrap()),
            cpu_result,
            epsilon = 1.0e-5
        );
    }
}