1use crate::models::llama2::cpu::{Llama2Config, TransformerWeights};
2use crate::ops::{
3 BatchedMultiqueryAttention, BatchedMultiqueryAttentionParams, RmsNorm, RoPE, RoPEShape, Silu,
4};
5use naga_oil::compose::ComposerError;
6use wgcore::kernel::KernelInvocationQueue;
7use wgcore::tensor::{GpuMatrix, GpuScalar, GpuVector};
8use wgcore::Shader;
9use wgebra::linalg::{Gemv, OpAssign, OpAssignVariant};
10use wgpu::{BufferUsages, Device};
11
/// Mutable GPU-side state for one Llama 2 inference stream: activation
/// scratch buffers, per-layer KV caches, and the uniform parameter blocks
/// consumed by the kernels recorded in `Llama2::queue`.
pub struct Llama2State {
    /// Residual-stream activation of the current token. Created with
    /// `COPY_DST` so the caller can upload the token embedding into it;
    /// the residual additions accumulate into this buffer.
    pub x: GpuVector<f32>,
    // Scratch vector (len = dim): receives the rms-norm output and the
    // attention result within each layer.
    xb: GpuVector<f32>,
    // Scratch vector (len = dim): attention/FFN output projections land
    // here before being added back into `x`.
    xb2: GpuVector<f32>,
    // FFN scratch (len = hidden_dim): gate projection, then activation output.
    hb: GpuVector<f32>,
    // FFN scratch (len = hidden_dim): up projection.
    hb2: GpuVector<f32>,
    // Query vector (len = dim) for the current token.
    q: GpuVector<f32>,
    // Attention scores buffer, allocated as seq_len x n_q_heads.
    att: GpuMatrix<f32>,
    // Output logits over the vocabulary (`COPY_SRC` so they can be copied
    // into `logits_readback`).
    logits: GpuVector<f32>,
    // CPU-mappable (`MAP_READ`) staging buffer for reading logits back.
    logits_readback: GpuVector<f32>,
    // Per-layer key cache: one kv_dim x seq_len matrix per transformer layer;
    // column `pos` holds the key of the token at that position.
    key_cache: Vec<GpuMatrix<f32>>,
    // Per-layer value cache, same layout as `key_cache`.
    value_cache: Vec<GpuMatrix<f32>>,
    // Uniform block consumed by the RoPE kernel.
    rope_shape: GpuScalar<RoPEShape>,
    // Uniform block consumed by the batched multiquery attention kernel.
    attn_params: GpuScalar<BatchedMultiqueryAttentionParams>,
}
36
37impl Llama2State {
38 pub fn new(device: &Device, config: &Llama2Config) -> Self {
39 let kv_dim = (config.dim * config.n_kv_heads) / config.n_q_heads;
40 const STORAGE: BufferUsages = BufferUsages::STORAGE;
41 const UNIFORM: BufferUsages = BufferUsages::UNIFORM;
42
43 Self {
44 x: GpuVector::uninit(device, config.dim as u32, STORAGE | BufferUsages::COPY_DST),
45 xb: GpuVector::uninit(device, config.dim as u32, STORAGE),
46 xb2: GpuVector::uninit(device, config.dim as u32, STORAGE),
47 hb: GpuVector::uninit(device, config.hidden_dim as u32, STORAGE),
48 hb2: GpuVector::uninit(device, config.hidden_dim as u32, STORAGE),
49 q: GpuVector::uninit(device, config.dim as u32, STORAGE),
50 key_cache: (0..config.n_layers)
52 .map(|_| GpuMatrix::uninit(device, kv_dim as u32, config.seq_len as u32, STORAGE))
53 .collect(),
54 value_cache: (0..config.n_layers)
55 .map(|_| GpuMatrix::uninit(device, kv_dim as u32, config.seq_len as u32, STORAGE))
56 .collect(),
57 att: GpuMatrix::uninit(
58 device,
59 config.seq_len as u32,
60 config.n_q_heads as u32,
61 STORAGE,
62 ),
63 logits: GpuVector::uninit(
64 device,
65 config.vocab_size as u32,
66 STORAGE | BufferUsages::COPY_SRC,
67 ),
68 logits_readback: GpuVector::uninit(
69 device,
70 config.vocab_size as u32,
71 BufferUsages::MAP_READ | BufferUsages::COPY_DST,
72 ),
73 rope_shape: GpuScalar::uninit(device, UNIFORM | BufferUsages::COPY_DST),
74 attn_params: GpuScalar::uninit(device, UNIFORM | BufferUsages::COPY_DST),
75 }
76 }
77
78 pub fn rope_shape(&self) -> &GpuScalar<RoPEShape> {
79 &self.rope_shape
80 }
81
82 pub fn attn_params(&self) -> &GpuScalar<BatchedMultiqueryAttentionParams> {
83 &self.attn_params
84 }
85
86 pub fn logits(&self) -> &GpuVector<f32> {
87 &self.logits
88 }
89
90 pub fn logits_readback(&self) -> &GpuVector<f32> {
91 &self.logits_readback
92 }
93}
94
/// GPU-resident weights for a single transformer layer.
pub struct Llama2LayerWeights {
    /// Key projection matrix.
    pub attn_k: GpuMatrix<f32>,
    /// RMS-norm gain applied before the attention block.
    pub attn_norm: GpuVector<f32>,
    /// Query projection matrix.
    pub attn_q: GpuMatrix<f32>,
    /// Value projection matrix.
    pub attn_v: GpuMatrix<f32>,
    /// FFN down projection (hidden_dim -> dim).
    pub ffn_down: GpuMatrix<f32>,
    /// FFN gate projection (dim -> hidden_dim); its output goes through the
    /// SiLU activation.
    pub ffn_gate: GpuMatrix<f32>,
    /// RMS-norm gain applied before the FFN block.
    pub ffn_norm: GpuVector<f32>,
    /// FFN up projection (dim -> hidden_dim).
    pub ffn_up: GpuMatrix<f32>,
    /// Attention output projection.
    pub attn_output: GpuMatrix<f32>,
}
106
/// Full set of GPU-resident Llama 2 model weights.
pub struct Llama2Weights {
    /// Per-layer transformer weights, indexed by layer.
    pub layers: Vec<Llama2LayerWeights>,
    /// Token embedding table. Created with `COPY_SRC` — presumably so a
    /// row can be copied into the state's `x` buffer; confirm with caller.
    pub token_embd: GpuMatrix<f32>,
    /// Final projection from the normalized hidden state to vocab logits.
    pub output: GpuMatrix<f32>,
    /// RMS-norm gain applied before the final `output` projection.
    pub output_norm: GpuVector<f32>,
}
113
114impl Llama2Weights {
115 pub fn from_ram(device: &Device, w: &TransformerWeights) -> Self {
116 let usage = BufferUsages::STORAGE;
117
118 let layers = w
119 .layers
120 .iter()
121 .map(|l| Llama2LayerWeights {
122 attn_k: GpuMatrix::init(device, &l.attn_k, usage),
123 attn_norm: GpuVector::init(device, &l.attn_norm, usage),
124 attn_q: GpuMatrix::init(device, &l.attn_q, usage),
125 attn_v: GpuMatrix::init(device, &l.attn_v, usage),
126 ffn_down: GpuMatrix::init(device, &l.ffn_down, usage),
127 ffn_gate: GpuMatrix::init(device, &l.ffn_gate, usage),
128 ffn_norm: GpuVector::init(device, &l.ffn_norm, usage),
129 ffn_up: GpuMatrix::init(device, &l.ffn_up, usage),
130 attn_output: GpuMatrix::init(device, &l.attn_output, usage),
131 })
132 .collect();
133
134 let token_embd = GpuMatrix::init(device, &w.token_embd, usage | BufferUsages::COPY_SRC);
135 let output = GpuMatrix::init(device, &w.output, usage);
136 let output_norm = GpuVector::init(device, &w.output_norm, usage);
137
138 Self {
139 layers,
140 token_embd,
141 output,
142 output_norm,
143 }
144 }
145}
146
/// The set of compiled GPU compute kernels implementing the Llama 2
/// forward pass.
pub struct Llama2 {
    // Batched multiquery attention kernel.
    attn: BatchedMultiqueryAttention,
    // RMS normalization kernel.
    rms_norm: RmsNorm,
    // Rotary positional embedding kernel.
    rope: RoPE,
    // SiLU activation kernel (used in the FFN).
    silu: Silu,
    // Matrix-vector product kernel used for all weight projections.
    matmul: Gemv,
    // In-place add-assign kernel used for the residual connections.
    add_assign: OpAssign,
}
155
impl Llama2 {
    /// Compiles every compute kernel needed for the forward pass on `device`.
    ///
    /// # Errors
    /// Returns a [`ComposerError`] if any of the underlying WGSL shaders
    /// fails to compose.
    pub fn new(device: &Device) -> Result<Self, ComposerError> {
        Ok(Self {
            attn: BatchedMultiqueryAttention::from_device(device)?,
            rms_norm: RmsNorm::from_device(device)?,
            rope: RoPE::from_device(device)?,
            silu: Silu::from_device(device)?,
            matmul: Gemv::from_device(device)?,
            add_assign: OpAssign::new(device, OpAssignVariant::Add)?,
        })
    }

    /// Records one full transformer forward pass for the token at position
    /// `pos` into `queue`.
    ///
    /// Nothing executes here: kernel invocations are only queued, in
    /// dependency order; the caller submits them to the GPU afterwards.
    /// Expects `state.x` to already contain the token's embedding, and the
    /// `rope_shape`/`attn_params` uniforms to be up to date for `pos`.
    pub fn queue<'a>(
        &'a self,
        queue: &mut KernelInvocationQueue<'a>,
        state: &Llama2State,
        weights: &Llama2Weights,
        config: &Llama2Config,
        pos: u32,
    ) {
        for l in 0..config.n_layers {
            let wl = &weights.layers[l];
            // Pre-attention RMS norm: xb = rmsnorm(x) * attn_norm.
            self.rms_norm
                .queue(queue, &state.xb, &state.x, &wl.attn_norm);

            // K/V for this token are written directly into column `pos` of
            // this layer's caches — no separate k/v scratch buffers.
            let k_cache = state.key_cache[l].column(pos);
            let v_cache = state.value_cache[l].column(pos);

            self.matmul.queue(queue, &state.q, &wl.attn_q, &state.xb);
            self.matmul.queue(queue, k_cache, &wl.attn_k, &state.xb);
            self.matmul.queue(queue, v_cache, &wl.attn_v, &state.xb);
            // Rotate q and the freshly written key column in place.
            self.rope.queue(queue, &state.rope_shape, &state.q, k_cache);

            // Attention over the whole cache; the result overwrites xb.
            self.attn.queue(
                queue,
                &state.attn_params,
                &state.q,
                &state.key_cache[l],
                &state.value_cache[l],
                &state.att,
                &state.xb,
            );
            // Output projection, then residual connection: x += xb2.
            self.matmul
                .queue(queue, &state.xb2, &wl.attn_output, &state.xb);
            self.add_assign.queue(queue, &state.x, &state.xb2);
            // Pre-FFN RMS norm.
            self.rms_norm
                .queue(queue, &state.xb, &state.x, &wl.ffn_norm);

            // FFN: gate into hb, up into hb2.
            self.matmul.queue(queue, &state.hb, &wl.ffn_gate, &state.xb);
            self.matmul.queue(queue, &state.hb2, &wl.ffn_up, &state.xb);
            // NOTE(review): presumably SwiGLU-style, hb = silu(hb) * hb2 —
            // confirm against the Silu kernel's signature.
            self.silu.queue(queue, &state.hb, &state.hb2);
            // Down projection, then the second residual connection.
            self.matmul
                .queue(queue, &state.xb2, &wl.ffn_down, &state.hb);
            self.add_assign.queue(queue, &state.x, &state.xb2);
        }

        // Final RMS norm followed by the projection to vocabulary logits.
        self.rms_norm
            .queue(queue, &state.xb, &state.x, &weights.output_norm);

        self.matmul
            .queue(queue, &state.logits, &weights.output, &state.xb);
    }
}