use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::super::attention::{
    FeedForward, MultiHeadAttention, MultiHeadAttentionConfig, TransformerEncoderBlock,
};
use crate::ModelError;
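
/// Lookup table mapping integer token indices to dense rows of a
/// `[num_embeddings, embedding_dim]` weight matrix registered as a graph
/// variable.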
#[derive(Debug, Clone)]
pub struct EmbeddingLayer {
    num_embeddings: usize,
    embedding_dim: usize,
    weight: NodeId,
}

impl EmbeddingLayer {
    pub fn new(
        graph: &mut Graph,
        num_embeddings: usize,
        embedding_dim: usize,
        weight_init: Tensor,
    ) -> Result<Self, ModelError> {
        let expected = vec![num_embeddings, embedding_dim];
        if weight_init.shape() != expected {
            return Err(ModelError::InvalidParameterShape {
                parameter: "embedding_weight",
                expected,
                got: weight_init.shape().to_vec(),
            });
        }
        let weight = graph.variable(weight_init);
        Ok(Self {
            num_embeddings,
            embedding_dim,
            weight,
        })
    }

    pub fn num_embeddings(&self) -> usize {
        self.num_embeddings
    }

    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }

    pub fn weight_node(&self) -> NodeId {
        self.weight
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        Ok(graph.embedding_lookup(self.weight, input)?)
    }
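    /// Graph-free lookup for inference: copies one `embedding_dim`-wide row
    /// of the weight matrix per index and appends `embedding_dim` to the
    /// input shape (e.g. `[batch]` becomes `[batch, embedding_dim]`).
    /// Out-of-range indices are rejected.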
    pub fn forward_inference(&self, graph: &Graph, indices: &Tensor) -> Result<Tensor, ModelError> {
        let weight = graph.value(self.weight)?;
        let w_data = weight.data();
        let idx_data = indices.data();
        let batch = idx_data.len();
        let dim = self.embedding_dim;
        let mut out = vec![0.0f32; batch * dim];
        for (i, &idx_f) in idx_data.iter().enumerate() {
            let idx = idx_f as usize;
            if idx >= self.num_embeddings {
                return Err(ModelError::InvalidInputShape {
                    expected_features: self.num_embeddings,
                    got: indices.shape().to_vec(),
                });
            }
            out[i * dim..(i + 1) * dim].copy_from_slice(&w_data[idx * dim..(idx + 1) * dim]);
        }
        let mut shape = indices.shape().to_vec();
        shape.push(dim);
        Ok(Tensor::from_vec(shape, out)?)
    }
}
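
/// Multi-head attention whose projection weights can be registered as
/// autograd variables. The `*_node` handles stay `None` until
/// `register_params` is called; the graph-based `forward` returns
/// `ParamsNotRegistered` before that.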
pub struct MultiHeadAttentionLayer {
    pub mha: MultiHeadAttention,
    w_q_node: Option<NodeId>,
    w_k_node: Option<NodeId>,
    w_v_node: Option<NodeId>,
    w_o_node: Option<NodeId>,
}

impl std::fmt::Debug for MultiHeadAttentionLayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MultiHeadAttentionLayer")
            .field("num_heads", &self.mha.num_heads)
            .field("d_k", &self.mha.d_k)
            .finish()
    }
}

impl Clone for MultiHeadAttentionLayer {
    fn clone(&self) -> Self {
        Self {
            mha: MultiHeadAttention {
                w_q: self.mha.w_q.clone(),
                w_k: self.mha.w_k.clone(),
                w_v: self.mha.w_v.clone(),
                w_o: self.mha.w_o.clone(),
                num_heads: self.mha.num_heads,
                d_k: self.mha.d_k,
            },
            w_q_node: self.w_q_node,
            w_k_node: self.w_k_node,
            w_v_node: self.w_v_node,
            w_o_node: self.w_o_node,
        }
    }
}

impl MultiHeadAttentionLayer {
    pub fn w_q_node(&self) -> Option<NodeId> {
        self.w_q_node
    }

    pub fn new(d_model: usize, num_heads: usize, _seed: u64) -> Self {
        let config = MultiHeadAttentionConfig { d_model, num_heads };
        let mha = MultiHeadAttention::new(&config).expect("valid MHA config");
        Self {
            mha,
            w_q_node: None,
            w_k_node: None,
            w_v_node: None,
            w_o_node: None,
        }
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
        self.w_q_node = Some(graph.variable(self.mha.w_q.clone()));
        self.w_k_node = Some(graph.variable(self.mha.w_k.clone()));
        self.w_v_node = Some(graph.variable(self.mha.w_v.clone()));
        self.w_o_node = Some(graph.variable(self.mha.w_o.clone()));
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let w_q = self.w_q_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "MultiHeadAttention",
        })?;
        let w_k = self.w_k_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "MultiHeadAttention",
        })?;
        let w_v = self.w_v_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "MultiHeadAttention",
        })?;
        let w_o = self.w_o_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "MultiHeadAttention",
        })?;
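        // Project the input into query/key/value spaces, run scaled
        // dot-product attention on each d_k-wide head slice, then
        // concatenate the heads and apply the output projection.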
        let q = graph.matmul_2d(input, w_q)?;
        let k = graph.matmul_2d(input, w_k)?;
        let v = graph.matmul_2d(input, w_v)?;

        let d_k = self.mha.d_k;
        let num_heads = self.mha.num_heads;
        let mut head_outputs = Vec::new();
        for h in 0..num_heads {
            let start = h * d_k;
            let qh = graph.narrow(q, 1, start, d_k)?;
            let kh = graph.narrow(k, 1, start, d_k)?;
            let vh = graph.narrow(v, 1, start, d_k)?;
            let attn = graph.scaled_dot_product_attention(qh, kh, vh)?;
            head_outputs.push(attn);
        }

        let concat = graph.cat(&head_outputs, 1)?;
        let output = graph.matmul_2d(concat, w_o)?;
        Ok(output)
    }

    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        self.mha.forward(input)
    }
}
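
/// One Transformer encoder block: multi-head self-attention and a
/// position-wise feed-forward network, each wrapped in a residual
/// connection followed by layer norm (post-norm). Call `register_params`
/// before using the graph-based `forward`.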
pub struct TransformerEncoderLayer {
    pub block: TransformerEncoderBlock,
    mha_layer: Option<MultiHeadAttentionLayer>,
    ff_layer: Option<FeedForwardLayer>,
    ln1_gamma_node: Option<NodeId>,
    ln1_beta_node: Option<NodeId>,
    ln2_gamma_node: Option<NodeId>,
    ln2_beta_node: Option<NodeId>,
}

impl std::fmt::Debug for TransformerEncoderLayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TransformerEncoderLayer")
            .field("d_model", &self.block.d_model)
            .finish()
    }
}

impl Clone for TransformerEncoderLayer {
    fn clone(&self) -> Self {
        Self {
            block: TransformerEncoderBlock {
                mha: MultiHeadAttention {
                    w_q: self.block.mha.w_q.clone(),
                    w_k: self.block.mha.w_k.clone(),
                    w_v: self.block.mha.w_v.clone(),
                    w_o: self.block.mha.w_o.clone(),
                    num_heads: self.block.mha.num_heads,
                    d_k: self.block.mha.d_k,
                },
                ffn: FeedForward {
                    w1: self.block.ffn.w1.clone(),
                    b1: self.block.ffn.b1.clone(),
                    w2: self.block.ffn.w2.clone(),
                    b2: self.block.ffn.b2.clone(),
                },
                ln1_gamma: self.block.ln1_gamma.clone(),
                ln1_beta: self.block.ln1_beta.clone(),
                ln2_gamma: self.block.ln2_gamma.clone(),
                ln2_beta: self.block.ln2_beta.clone(),
                d_model: self.block.d_model,
            },
            mha_layer: self.mha_layer.clone(),
            ff_layer: self.ff_layer.clone(),
            ln1_gamma_node: self.ln1_gamma_node,
            ln1_beta_node: self.ln1_beta_node,
            ln2_gamma_node: self.ln2_gamma_node,
            ln2_beta_node: self.ln2_beta_node,
        }
    }
}

impl TransformerEncoderLayer {
    pub fn ln1_gamma_node(&self) -> Option<NodeId> {
        self.ln1_gamma_node
    }

    pub fn new(d_model: usize, num_heads: usize, d_ff: usize, _seed: u64) -> Self {
        let block = TransformerEncoderBlock::new(d_model, num_heads, d_ff)
            .expect("valid TransformerEncoderBlock config");
        Self {
            block,
            mha_layer: None,
            ff_layer: None,
            ln1_gamma_node: None,
            ln1_beta_node: None,
            ln2_gamma_node: None,
            ln2_beta_node: None,
        }
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
        let mut mha = MultiHeadAttentionLayer {
            mha: MultiHeadAttention {
                w_q: self.block.mha.w_q.clone(),
                w_k: self.block.mha.w_k.clone(),
                w_v: self.block.mha.w_v.clone(),
                w_o: self.block.mha.w_o.clone(),
                num_heads: self.block.mha.num_heads,
                d_k: self.block.mha.d_k,
            },
            w_q_node: None,
            w_k_node: None,
            w_v_node: None,
            w_o_node: None,
        };
        mha.register_params(graph);
        self.mha_layer = Some(mha);

        let mut ff = FeedForwardLayer {
            ff: FeedForward {
                w1: self.block.ffn.w1.clone(),
                b1: self.block.ffn.b1.clone(),
                w2: self.block.ffn.w2.clone(),
                b2: self.block.ffn.b2.clone(),
            },
            w1_node: None,
            b1_node: None,
            w2_node: None,
            b2_node: None,
        };
        ff.register_params(graph);
        self.ff_layer = Some(ff);

        self.ln1_gamma_node = Some(graph.variable(self.block.ln1_gamma.clone()));
        self.ln1_beta_node = Some(graph.variable(self.block.ln1_beta.clone()));
        self.ln2_gamma_node = Some(graph.variable(self.block.ln2_gamma.clone()));
        self.ln2_beta_node = Some(graph.variable(self.block.ln2_beta.clone()));
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let mha = self
            .mha_layer
            .as_ref()
            .ok_or(ModelError::ParamsNotRegistered {
                layer: "TransformerEncoder",
            })?;
        let ff = self
            .ff_layer
            .as_ref()
            .ok_or(ModelError::ParamsNotRegistered {
                layer: "TransformerEncoder",
            })?;
        let ln1_g = self.ln1_gamma_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "TransformerEncoder",
        })?;
        let ln1_b = self.ln1_beta_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "TransformerEncoder",
        })?;
        let ln2_g = self.ln2_gamma_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "TransformerEncoder",
        })?;
        let ln2_b = self.ln2_beta_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "TransformerEncoder",
        })?;
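        // Post-norm residual structure:
        //   norm1 = LayerNorm(x + MHA(x)); out = LayerNorm(norm1 + FFN(norm1)).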
        let attn_out = mha.forward(graph, input)?;
        let residual1 = graph.add(input, attn_out)?;
        let norm1 = graph.layer_norm(residual1, ln1_g, ln1_b, 1e-5)?;
        let ff_out = ff.forward(graph, norm1)?;
        let residual2 = graph.add(norm1, ff_out)?;
        let output = graph.layer_norm(residual2, ln2_g, ln2_b, 1e-5)?;
        Ok(output)
    }

    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        self.block.forward(input)
    }
}
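
/// Position-wise feed-forward network (linear -> ReLU -> linear) with
/// graph-registered weights and biases.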
pub struct FeedForwardLayer {
    pub ff: FeedForward,
    w1_node: Option<NodeId>,
    b1_node: Option<NodeId>,
    w2_node: Option<NodeId>,
    b2_node: Option<NodeId>,
}

impl std::fmt::Debug for FeedForwardLayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FeedForwardLayer").finish()
    }
}

impl Clone for FeedForwardLayer {
    fn clone(&self) -> Self {
        Self {
            ff: FeedForward {
                w1: self.ff.w1.clone(),
                b1: self.ff.b1.clone(),
                w2: self.ff.w2.clone(),
                b2: self.ff.b2.clone(),
            },
            w1_node: self.w1_node,
            b1_node: self.b1_node,
            w2_node: self.w2_node,
            b2_node: self.b2_node,
        }
    }
}

impl FeedForwardLayer {
    pub fn w1_node(&self) -> Option<NodeId> {
        self.w1_node
    }

    pub fn new(d_model: usize, d_ff: usize, _seed: u64) -> Self {
        let ff = FeedForward::new(d_model, d_ff).expect("valid FeedForward config");
        Self {
            ff,
            w1_node: None,
            b1_node: None,
            w2_node: None,
            b2_node: None,
        }
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
        self.w1_node = Some(graph.variable(self.ff.w1.clone()));
        self.b1_node = Some(graph.variable(self.ff.b1.clone()));
        self.w2_node = Some(graph.variable(self.ff.w2.clone()));
        self.b2_node = Some(graph.variable(self.ff.b2.clone()));
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let w1 = self.w1_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "FeedForward",
        })?;
        let b1 = self.b1_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "FeedForward",
        })?;
        let w2 = self.w2_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "FeedForward",
        })?;
        let b2 = self.b2_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "FeedForward",
        })?;
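        // h = relu(x · W1 + b1); out = h · W2 + b2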
        let h = graph.matmul_2d(input, w1)?;
        let h = graph.add(h, b1)?;
        let h = graph.relu(h)?;
        let h = graph.matmul_2d(h, w2)?;
        let out = graph.add(h, b2)?;
        Ok(out)
    }

    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        self.ff.forward(input)
    }
}
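
// A minimal usage sketch of the graph-based training path. Illustrative
// only: `Graph::new()` is an assumed constructor not shown in this file;
// everything else uses APIs that appear above.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn encoder_layer_forward_after_registration() -> Result<(), ModelError> {
        let mut graph = Graph::new(); // assumption: empty-graph constructor
        let (d_model, num_heads, d_ff) = (8, 2, 16);
        let mut layer = TransformerEncoderLayer::new(d_model, num_heads, d_ff, 0);

        // forward() returns ParamsNotRegistered until this is called.
        layer.register_params(&mut graph);

        // A [seq_len, d_model] input registered as a graph variable.
        let input = Tensor::from_vec(vec![4, d_model], vec![0.1; 4 * d_model])?;
        let x = graph.variable(input);
        let _out = layer.forward(&mut graph, x)?;
        Ok(())
    }
}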