wgml/ops/
rms_norm.rs

use bytemuck::Pod;
use nalgebra::{DVector, Dyn, Storage, Vector};
use wgcore::kernel::{KernelInvocationBuilder, KernelInvocationQueue};
use wgcore::tensor::GpuVectorView;
use wgcore::Shader;
use wgebra::linalg::Shape;
use wgpu::ComputePipeline;

#[derive(Shader)]
#[shader(derive(Shape), src = "rms_norm.wgsl", composable = false)]
/// Shader implementing the RMS norm kernel.
pub struct RmsNorm {
    /// Compute pipeline built from the `main` entry point of `rms_norm.wgsl`.
    pub main: ComputePipeline,
}

impl RmsNorm {
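    /// Queues an invocation of this kernel that writes into `result` the values of `value`,
    /// normalized by their root-mean-square and scaled component-wise by `weight`.
    ///
    /// The invocation is only recorded into `queue`; nothing runs until the queue is encoded
    /// into a command encoder and submitted to the GPU.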
    pub fn queue<'a, 'b, T: Pod>(
        &'a self,
        queue: &mut KernelInvocationQueue<'a>,
        result: impl Into<GpuVectorView<'b, T>>,
        value: impl Into<GpuVectorView<'b, T>>,
        weight: impl Into<GpuVectorView<'b, T>>,
    ) {
        let value = value.into();
        let weight = weight.into();
        let result = result.into();

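        // Buffers holding the shape metadata of each vector view, as expected by the WGSL kernel.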
        let value_shape_buf = queue.shape_buffer(value.shape());
        let weight_shape_buf = queue.shape_buffer(weight.shape());
        let result_shape_buf = queue.shape_buffer(result.shape());

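        // Bind everything on group 0, in the same order as the bindings declared in
        // `rms_norm.wgsl`, then dispatch a single workgroup.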
        KernelInvocationBuilder::new(queue, &self.main)
            .bind0([
                &value_shape_buf,
                &weight_shape_buf,
                &result_shape_buf,
                value.buffer(),
                weight.buffer(),
                result.buffer(),
            ])
            .queue(1);
    }

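    /// CPU reference implementation of the RMS norm kernel:
    /// `out[i] = a[i] * w[i] / sqrt(mean(a^2) + eps)`.
    ///
    /// Used by the tests to validate the GPU results.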
    pub fn run_cpu<SW: Storage<f32, Dyn>>(
        out: &mut DVector<f32>,
        a: &DVector<f32>,
        w: &Vector<f32, Dyn, SW>,
    ) {
        // Small epsilon to avoid dividing by zero on an all-zero input.
        const NUDGE_FACTOR: f32 = 1.0e-5;
        // Reciprocal of the root-mean-square of `a` (with the nudge factor for stability).
        let rms = 1.0 / (a.norm_squared() / (a.nrows() as f32) + NUDGE_FACTOR).sqrt();
        out.zip_zip_apply(a, w, |o, a, w| *o = (a * rms) * w);
    }
}

#[cfg(test)]
mod test {
    use crate::ops::RmsNorm;
    use nalgebra::DVector;
    use wgcore::gpu::GpuInstance;
    use wgcore::kernel::KernelInvocationQueue;
    use wgcore::tensor::GpuVector;
    use wgcore::Shader;
    use wgpu::BufferUsages;

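    // Checks the GPU kernel against the CPU reference implementation `RmsNorm::run_cpu`.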
    #[futures_test::test]
    #[serial_test::serial]
    async fn gpu_rms_norm() {
        let gpu = GpuInstance::new().await.unwrap();
        let rmsnorm = super::RmsNorm::from_device(gpu.device());
        let mut queue = KernelInvocationQueue::new(gpu.device());
        let mut encoder = gpu.device().create_command_encoder(&Default::default());

        const LEN: u32 = 1757;

        let result = DVector::new_random(LEN as usize);
        let value = DVector::new_random(LEN as usize);
        let weight = DVector::new_random(LEN as usize);

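        // Upload the inputs to GPU storage buffers. The result buffer also needs COPY_SRC so
        // its content can be copied into the readback staging buffer.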
        let gpu_result = GpuVector::init(
            gpu.device(),
            &result,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        );
        let gpu_value = GpuVector::init(gpu.device(), &value, BufferUsages::STORAGE);
        let gpu_weight = GpuVector::init(gpu.device(), &weight, BufferUsages::STORAGE);
        let gpu_staging = GpuVector::uninit(
            gpu.device(),
            result.len() as u32,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

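        // Queue the kernel invocation, encode it together with the readback copy, then submit.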
        rmsnorm.queue(&mut queue, &gpu_result, &gpu_value, &gpu_weight);

        queue.encode(&mut encoder, None);
        gpu_staging.copy_from(&mut encoder, &gpu_result);

        gpu.queue().submit(Some(encoder.finish()));

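        // Compute the reference result on the CPU and compare it with the GPU readback.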
        let mut cpu_result = result;
        RmsNorm::run_cpu(&mut cpu_result, &value, &weight);

        approx::assert_relative_eq!(
            DVector::from(gpu_staging.read(gpu.device()).await.unwrap()),
            cpu_result,
            epsilon = 1.0e-5
        );
    }
}