//! GPU (wgpu) implementation of the Llama2 transformer forward pass.
//! File: wgml/models/llama2/transformer.rs

1use crate::models::llama2::cpu::{Llama2Config, TransformerWeights};
2use crate::ops::{
3    BatchedMultiqueryAttention, BatchedMultiqueryAttentionParams, RmsNorm, RoPE, RoPEShape, Silu,
4};
5use naga_oil::compose::ComposerError;
6use wgcore::kernel::KernelInvocationQueue;
7use wgcore::tensor::{GpuMatrix, GpuScalar, GpuVector};
8use wgcore::Shader;
9use wgebra::linalg::{Gemv, OpAssign, OpAssignVariant};
10use wgpu::{BufferUsages, Device};
11
/// Per-inference GPU buffers holding the transformer's activations and KV cache.
///
/// All buffers live on the GPU; `logits_readback` is the only one mapped for
/// CPU reads.
pub struct Llama2State {
    /// Activation at current time stamp.
    pub x: GpuVector<f32>,
    /// Activation at current time stamp, inside a residual branch.
    xb: GpuVector<f32>,
    /// Additional buffer for convenience.
    xb2: GpuVector<f32>,
    /// Buffer for hidden dimension in the Feed-Forward net.
    hb: GpuVector<f32>,
    /// Another buffer for hidden dimension in the Feed-Forward net.
    hb2: GpuVector<f32>,
    /// Query.
    q: GpuVector<f32>,
    /// Scores/attention values (`seq_len × n_q_heads`, see `new`).
    att: GpuMatrix<f32>,
    /// Output logits (`vocab_size` entries).
    logits: GpuVector<f32>,
    /// CPU-mappable (MAP_READ) staging buffer used to read `logits` back.
    logits_readback: GpuVector<f32>,
    // KV cache. Each Vec contains `n_layers` elements; each matrix is
    // `kv_dim × seq_len` (one column per time step, see `new`).
    key_cache: Vec<GpuMatrix<f32>>,
    value_cache: Vec<GpuMatrix<f32>>,
    /// RoPE dimensions, uploaded as a uniform.
    rope_shape: GpuScalar<RoPEShape>,
    /// Attention kernel parameters, uploaded as a uniform.
    attn_params: GpuScalar<BatchedMultiqueryAttentionParams>,
}
36
37impl Llama2State {
38    pub fn new(device: &Device, config: &Llama2Config) -> Self {
39        let kv_dim = (config.dim * config.n_kv_heads) / config.n_q_heads;
40        const STORAGE: BufferUsages = BufferUsages::STORAGE;
41        const UNIFORM: BufferUsages = BufferUsages::UNIFORM;
42
43        Self {
44            x: GpuVector::uninit(device, config.dim as u32, STORAGE | BufferUsages::COPY_DST),
45            xb: GpuVector::uninit(device, config.dim as u32, STORAGE),
46            xb2: GpuVector::uninit(device, config.dim as u32, STORAGE),
47            hb: GpuVector::uninit(device, config.hidden_dim as u32, STORAGE),
48            hb2: GpuVector::uninit(device, config.hidden_dim as u32, STORAGE),
49            q: GpuVector::uninit(device, config.dim as u32, STORAGE),
50            // TODO: for these two, the `kv_dim` doesn’t match the dimension in the field’s comment.
51            key_cache: (0..config.n_layers)
52                .map(|_| GpuMatrix::uninit(device, kv_dim as u32, config.seq_len as u32, STORAGE))
53                .collect(),
54            value_cache: (0..config.n_layers)
55                .map(|_| GpuMatrix::uninit(device, kv_dim as u32, config.seq_len as u32, STORAGE))
56                .collect(),
57            att: GpuMatrix::uninit(
58                device,
59                config.seq_len as u32,
60                config.n_q_heads as u32,
61                STORAGE,
62            ),
63            logits: GpuVector::uninit(
64                device,
65                config.vocab_size as u32,
66                STORAGE | BufferUsages::COPY_SRC,
67            ),
68            logits_readback: GpuVector::uninit(
69                device,
70                config.vocab_size as u32,
71                BufferUsages::MAP_READ | BufferUsages::COPY_DST,
72            ),
73            rope_shape: GpuScalar::uninit(device, UNIFORM | BufferUsages::COPY_DST),
74            attn_params: GpuScalar::uninit(device, UNIFORM | BufferUsages::COPY_DST),
75        }
76    }
77
78    pub fn rope_shape(&self) -> &GpuScalar<RoPEShape> {
79        &self.rope_shape
80    }
81
82    pub fn attn_params(&self) -> &GpuScalar<BatchedMultiqueryAttentionParams> {
83        &self.attn_params
84    }
85
86    pub fn logits(&self) -> &GpuVector<f32> {
87        &self.logits
88    }
89
90    pub fn logits_readback(&self) -> &GpuVector<f32> {
91        &self.logits_readback
92    }
93}
94
/// GPU-resident weights for a single transformer layer.
pub struct Llama2LayerWeights {
    /// Attention key projection matrix.
    pub attn_k: GpuMatrix<f32>,
    /// RMS-norm weights applied before the attention block.
    pub attn_norm: GpuVector<f32>,
    /// Attention query projection matrix.
    pub attn_q: GpuMatrix<f32>,
    /// Attention value projection matrix.
    pub attn_v: GpuMatrix<f32>,
    /// Feed-forward "down" projection (hidden dim → model dim).
    pub ffn_down: GpuMatrix<f32>,
    /// Feed-forward gate projection (the SiLU-activated branch).
    pub ffn_gate: GpuMatrix<f32>,
    /// RMS-norm weights applied before the feed-forward block.
    pub ffn_norm: GpuVector<f32>,
    /// Feed-forward "up" projection (model dim → hidden dim).
    pub ffn_up: GpuMatrix<f32>,
    /// Attention output projection matrix.
    pub attn_output: GpuMatrix<f32>,
}
106
/// GPU-resident weights for the full Llama2 model.
pub struct Llama2Weights {
    /// Per-layer transformer weights.
    pub layers: Vec<Llama2LayerWeights>,
    /// Token embedding matrix (also flagged COPY_SRC, see `from_ram`).
    pub token_embd: GpuMatrix<f32>,
    /// Final output (logits) projection matrix.
    pub output: GpuMatrix<f32>,
    /// RMS-norm weights applied before the output projection.
    pub output_norm: GpuVector<f32>,
}
113
114impl Llama2Weights {
115    pub fn from_ram(device: &Device, w: &TransformerWeights) -> Self {
116        let usage = BufferUsages::STORAGE;
117
118        let layers = w
119            .layers
120            .iter()
121            .map(|l| Llama2LayerWeights {
122                attn_k: GpuMatrix::init(device, &l.attn_k, usage),
123                attn_norm: GpuVector::init(device, &l.attn_norm, usage),
124                attn_q: GpuMatrix::init(device, &l.attn_q, usage),
125                attn_v: GpuMatrix::init(device, &l.attn_v, usage),
126                ffn_down: GpuMatrix::init(device, &l.ffn_down, usage),
127                ffn_gate: GpuMatrix::init(device, &l.ffn_gate, usage),
128                ffn_norm: GpuVector::init(device, &l.ffn_norm, usage),
129                ffn_up: GpuMatrix::init(device, &l.ffn_up, usage),
130                attn_output: GpuMatrix::init(device, &l.attn_output, usage),
131            })
132            .collect();
133
134        let token_embd = GpuMatrix::init(device, &w.token_embd, usage | BufferUsages::COPY_SRC);
135        let output = GpuMatrix::init(device, &w.output, usage);
136        let output_norm = GpuVector::init(device, &w.output_norm, usage);
137
138        Self {
139            layers,
140            token_embd,
141            output,
142            output_norm,
143        }
144    }
145}
146
/// The compiled GPU compute kernels needed to run the Llama2 forward pass.
pub struct Llama2 {
    /// Batched multi-query attention kernel.
    attn: BatchedMultiqueryAttention,
    /// RMS normalization kernel.
    rms_norm: RmsNorm,
    /// Rotary positional embedding (RoPE) kernel.
    rope: RoPE,
    /// SiLU activation kernel used by the feed-forward network.
    silu: Silu,
    /// Matrix-vector multiplication (GEMV) kernel.
    matmul: Gemv,
    /// In-place addition kernel (built with `OpAssignVariant::Add`), used for
    /// the residual connections.
    add_assign: OpAssign,
}
155
impl Llama2 {
    /// Compiles all the compute kernels required by the Llama2 forward pass.
    pub fn new(device: &Device) -> Result<Self, ComposerError> {
        Ok(Self {
            attn: BatchedMultiqueryAttention::from_device(device)?,
            rms_norm: RmsNorm::from_device(device)?,
            rope: RoPE::from_device(device)?,
            silu: Silu::from_device(device)?,
            matmul: Gemv::from_device(device)?,
            add_assign: OpAssign::new(device, OpAssignVariant::Add)?,
        })
    }

    /// Queues one complete transformer forward pass for the token at position `pos`.
    ///
    /// Callers must have already uploaded the current token’s embedding into
    /// `state.x` and refreshed the `rope_shape`/`attn_params` uniforms. Once the
    /// queued kernels have executed, `state.logits` contains the output logits.
    /// (Kernel argument convention throughout: destination buffer first.)
    pub fn queue<'a>(
        &'a self,
        queue: &mut KernelInvocationQueue<'a>,
        state: &Llama2State,
        weights: &Llama2Weights,
        config: &Llama2Config,
        pos: u32,
    ) {
        for l in 0..config.n_layers {
            let wl = &weights.layers[l];
            // Pre-attention RMS-norm: xb <- rmsnorm(x, attn_norm).
            self.rms_norm
                .queue(queue, &state.xb, &state.x, &wl.attn_norm);

            // K/V for this time step are written straight into column `pos` of
            // this layer’s cache, so the cache fills in as decoding advances.
            let k_cache = state.key_cache[l].column(pos);
            let v_cache = state.value_cache[l].column(pos);

            // QKV projections: q <- attn_q * xb, k/v columns <- attn_k/attn_v * xb.
            self.matmul.queue(queue, &state.q, &wl.attn_q, &state.xb);
            self.matmul.queue(queue, k_cache, &wl.attn_k, &state.xb);
            self.matmul.queue(queue, v_cache, &wl.attn_v, &state.xb);
            // Rotary positional embedding applied to the query and the freshly
            // written key column.
            self.rope.queue(queue, &state.rope_shape, &state.q, k_cache);

            // Start attention.
            // Attends over this layer’s full K/V cache; scores land in `att`
            // and the aggregated result overwrites `xb`.
            self.attn.queue(
                queue,
                &state.attn_params,
                &state.q,
                &state.key_cache[l],
                &state.value_cache[l],
                &state.att,
                &state.xb,
            );
            // Attention output projection: xb2 <- attn_output * xb.
            self.matmul
                .queue(queue, &state.xb2, &wl.attn_output, &state.xb);
            // End attention.

            // Residual connection: x += xb2.
            self.add_assign.queue(queue, &state.x, &state.xb2);
            // Pre-FFN RMS-norm: xb <- rmsnorm(x, ffn_norm).
            self.rms_norm
                .queue(queue, &state.xb, &state.x, &wl.ffn_norm);

            // Start ffn_silu
            // SwiGLU-style feed-forward: gate and up projections into the
            // hidden dimension…
            self.matmul.queue(queue, &state.hb, &wl.ffn_gate, &state.xb);
            self.matmul.queue(queue, &state.hb2, &wl.ffn_up, &state.xb);
            // …combined by the SiLU kernel (presumably hb <- silu(hb) ∘ hb2 —
            // confirm against the Silu shader), then projected back down.
            self.silu.queue(queue, &state.hb, &state.hb2);
            self.matmul
                .queue(queue, &state.xb2, &wl.ffn_down, &state.hb);
            // End ffn_silu

            // Residual connection: x += xb2.
            self.add_assign.queue(queue, &state.x, &state.xb2);
        }

        // Final RMS-norm followed by the classifier projection into the logits.
        self.rms_norm
            .queue(queue, &state.xb, &state.x, &weights.output_norm);

        self.matmul
            .queue(queue, &state.logits, &weights.output, &state.xb);
    }
}
224}