//! Attention-related model layers: embedding lookup, multi-head attention,
//! transformer encoder block and feed-forward, each wrapping a core
//! implementation and exposing graph-based training plus CPU inference paths.

use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::super::attention::{
    FeedForward, MultiHeadAttention, MultiHeadAttentionConfig, TransformerEncoderBlock,
};
use crate::ModelError;

/// Embedding lookup table: maps integer indices to dense vectors.
#[derive(Debug, Clone)]
pub struct EmbeddingLayer {
    // Vocabulary size; valid indices are `0..num_embeddings`.
    num_embeddings: usize,
    // Length of each embedding vector (row width of the weight matrix).
    embedding_dim: usize,
    // Graph variable holding the `[num_embeddings, embedding_dim]` weights.
    weight: NodeId,
}
16
17impl EmbeddingLayer {
18    pub fn new(
19        graph: &mut Graph,
20        num_embeddings: usize,
21        embedding_dim: usize,
22        weight_init: Tensor,
23    ) -> Result<Self, ModelError> {
24        let expected = vec![num_embeddings, embedding_dim];
25        if weight_init.shape() != expected {
26            return Err(ModelError::InvalidParameterShape {
27                parameter: "embedding_weight",
28                expected,
29                got: weight_init.shape().to_vec(),
30            });
31        }
32        let weight = graph.variable(weight_init);
33        Ok(Self {
34            num_embeddings,
35            embedding_dim,
36            weight,
37        })
38    }
39
40    pub fn num_embeddings(&self) -> usize {
41        self.num_embeddings
42    }
43    pub fn embedding_dim(&self) -> usize {
44        self.embedding_dim
45    }
46    pub fn weight_node(&self) -> NodeId {
47        self.weight
48    }
49
50    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
51        Ok(graph.embedding_lookup(self.weight, input)?)
52    }
53
54    pub fn forward_inference(&self, graph: &Graph, indices: &Tensor) -> Result<Tensor, ModelError> {
55        let weight = graph.value(self.weight)?;
56        let w_data = weight.data();
57        let idx_data = indices.data();
58        let batch = idx_data.len();
59        let dim = self.embedding_dim;
60        let mut out = vec![0.0f32; batch * dim];
61        for (i, &idx_f) in idx_data.iter().enumerate() {
62            let idx = idx_f as usize;
63            if idx >= self.num_embeddings {
64                return Err(ModelError::InvalidInputShape {
65                    expected_features: self.num_embeddings,
66                    got: indices.shape().to_vec(),
67                });
68            }
69            out[i * dim..(i + 1) * dim].copy_from_slice(&w_data[idx * dim..(idx + 1) * dim]);
70        }
71        let mut shape = indices.shape().to_vec();
72        shape.push(dim);
73        Ok(Tensor::from_vec(shape, out)?)
74    }
75}
76
/// Multi-head attention layer wrapping `MultiHeadAttention`.
///
/// Self-attention: Q=K=V=input. Input/output: `[seq_len, d_model]`.
pub struct MultiHeadAttentionLayer {
    /// Wrapped attention weights; used directly by `forward_inference`.
    pub mha: MultiHeadAttention,
    // Graph handles for the four projection matrices; `None` until
    // `register_params` is called.
    w_q_node: Option<NodeId>,
    w_k_node: Option<NodeId>,
    w_v_node: Option<NodeId>,
    w_o_node: Option<NodeId>,
}
87
88impl std::fmt::Debug for MultiHeadAttentionLayer {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_struct("MultiHeadAttentionLayer")
91            .field("num_heads", &self.mha.num_heads)
92            .field("d_k", &self.mha.d_k)
93            .finish()
94    }
95}
96
97impl Clone for MultiHeadAttentionLayer {
98    fn clone(&self) -> Self {
99        Self {
100            mha: MultiHeadAttention {
101                w_q: self.mha.w_q.clone(),
102                w_k: self.mha.w_k.clone(),
103                w_v: self.mha.w_v.clone(),
104                w_o: self.mha.w_o.clone(),
105                num_heads: self.mha.num_heads,
106                d_k: self.mha.d_k,
107            },
108            w_q_node: self.w_q_node,
109            w_k_node: self.w_k_node,
110            w_v_node: self.w_v_node,
111            w_o_node: self.w_o_node,
112        }
113    }
114}
115
116impl MultiHeadAttentionLayer {
117    pub fn w_q_node(&self) -> Option<NodeId> {
118        self.w_q_node
119    }
120
121    pub fn new(d_model: usize, num_heads: usize, _seed: u64) -> Self {
122        let config = MultiHeadAttentionConfig { d_model, num_heads };
123        let mha = MultiHeadAttention::new(&config).expect("valid MHA config");
124        Self {
125            mha,
126            w_q_node: None,
127            w_k_node: None,
128            w_v_node: None,
129            w_o_node: None,
130        }
131    }
132
133    pub fn register_params(&mut self, graph: &mut Graph) {
134        self.w_q_node = Some(graph.variable(self.mha.w_q.clone()));
135        self.w_k_node = Some(graph.variable(self.mha.w_k.clone()));
136        self.w_v_node = Some(graph.variable(self.mha.w_v.clone()));
137        self.w_o_node = Some(graph.variable(self.mha.w_o.clone()));
138    }
139
140    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
141        let w_q = self.w_q_node.ok_or(ModelError::ParamsNotRegistered {
142            layer: "MultiHeadAttention",
143        })?;
144        let w_k = self.w_k_node.ok_or(ModelError::ParamsNotRegistered {
145            layer: "MultiHeadAttention",
146        })?;
147        let w_v = self.w_v_node.ok_or(ModelError::ParamsNotRegistered {
148            layer: "MultiHeadAttention",
149        })?;
150        let w_o = self.w_o_node.ok_or(ModelError::ParamsNotRegistered {
151            layer: "MultiHeadAttention",
152        })?;
153
154        // Project: Q = input @ W_q, K = input @ W_k, V = input @ W_v
155        let q = graph.matmul_2d(input, w_q)?;
156        let k = graph.matmul_2d(input, w_k)?;
157        let v = graph.matmul_2d(input, w_v)?;
158
159        // Per-head attention with narrow + scaled_dot_product_attention
160        let d_k = self.mha.d_k;
161        let num_heads = self.mha.num_heads;
162        let mut head_outputs = Vec::new();
163        for h in 0..num_heads {
164            let start = h * d_k;
165            let qh = graph.narrow(q, 1, start, d_k)?;
166            let kh = graph.narrow(k, 1, start, d_k)?;
167            let vh = graph.narrow(v, 1, start, d_k)?;
168            let attn = graph.scaled_dot_product_attention(qh, kh, vh)?;
169            head_outputs.push(attn);
170        }
171
172        // Concatenate heads along last dim -> [seq_len, d_model]
173        let concat = graph.cat(&head_outputs, 1)?;
174
175        // Output projection
176        let output = graph.matmul_2d(concat, w_o)?;
177        Ok(output)
178    }
179
180    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
181        self.mha.forward(input)
182    }
183}
184
/// Transformer encoder layer wrapping `TransformerEncoderBlock`.
///
/// Input/output: `[seq_len, d_model]`.
pub struct TransformerEncoderLayer {
    /// Wrapped block holding all weights; used directly by `forward_inference`.
    pub block: TransformerEncoderBlock,
    // Graph-aware sub-layers and layer-norm parameter handles; all `None`
    // until `register_params` is called.
    mha_layer: Option<MultiHeadAttentionLayer>,
    ff_layer: Option<FeedForwardLayer>,
    ln1_gamma_node: Option<NodeId>,
    ln1_beta_node: Option<NodeId>,
    ln2_gamma_node: Option<NodeId>,
    ln2_beta_node: Option<NodeId>,
}
197
198impl std::fmt::Debug for TransformerEncoderLayer {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        f.debug_struct("TransformerEncoderLayer")
201            .field("d_model", &self.block.d_model)
202            .finish()
203    }
204}
205
206impl Clone for TransformerEncoderLayer {
207    fn clone(&self) -> Self {
208        Self {
209            block: TransformerEncoderBlock {
210                mha: MultiHeadAttention {
211                    w_q: self.block.mha.w_q.clone(),
212                    w_k: self.block.mha.w_k.clone(),
213                    w_v: self.block.mha.w_v.clone(),
214                    w_o: self.block.mha.w_o.clone(),
215                    num_heads: self.block.mha.num_heads,
216                    d_k: self.block.mha.d_k,
217                },
218                ffn: FeedForward {
219                    w1: self.block.ffn.w1.clone(),
220                    b1: self.block.ffn.b1.clone(),
221                    w2: self.block.ffn.w2.clone(),
222                    b2: self.block.ffn.b2.clone(),
223                },
224                ln1_gamma: self.block.ln1_gamma.clone(),
225                ln1_beta: self.block.ln1_beta.clone(),
226                ln2_gamma: self.block.ln2_gamma.clone(),
227                ln2_beta: self.block.ln2_beta.clone(),
228                d_model: self.block.d_model,
229            },
230            mha_layer: self.mha_layer.clone(),
231            ff_layer: self.ff_layer.clone(),
232            ln1_gamma_node: self.ln1_gamma_node,
233            ln1_beta_node: self.ln1_beta_node,
234            ln2_gamma_node: self.ln2_gamma_node,
235            ln2_beta_node: self.ln2_beta_node,
236        }
237    }
238}
239
240impl TransformerEncoderLayer {
241    pub fn ln1_gamma_node(&self) -> Option<NodeId> {
242        self.ln1_gamma_node
243    }
244
245    pub fn new(d_model: usize, num_heads: usize, d_ff: usize, _seed: u64) -> Self {
246        let block = TransformerEncoderBlock::new(d_model, num_heads, d_ff)
247            .expect("valid TransformerEncoderBlock config");
248        Self {
249            block,
250            mha_layer: None,
251            ff_layer: None,
252            ln1_gamma_node: None,
253            ln1_beta_node: None,
254            ln2_gamma_node: None,
255            ln2_beta_node: None,
256        }
257    }
258
259    pub fn register_params(&mut self, graph: &mut Graph) {
260        let mut mha = MultiHeadAttentionLayer {
261            mha: MultiHeadAttention {
262                w_q: self.block.mha.w_q.clone(),
263                w_k: self.block.mha.w_k.clone(),
264                w_v: self.block.mha.w_v.clone(),
265                w_o: self.block.mha.w_o.clone(),
266                num_heads: self.block.mha.num_heads,
267                d_k: self.block.mha.d_k,
268            },
269            w_q_node: None,
270            w_k_node: None,
271            w_v_node: None,
272            w_o_node: None,
273        };
274        mha.register_params(graph);
275        self.mha_layer = Some(mha);
276
277        let mut ff = FeedForwardLayer {
278            ff: FeedForward {
279                w1: self.block.ffn.w1.clone(),
280                b1: self.block.ffn.b1.clone(),
281                w2: self.block.ffn.w2.clone(),
282                b2: self.block.ffn.b2.clone(),
283            },
284            w1_node: None,
285            b1_node: None,
286            w2_node: None,
287            b2_node: None,
288        };
289        ff.register_params(graph);
290        self.ff_layer = Some(ff);
291
292        self.ln1_gamma_node = Some(graph.variable(self.block.ln1_gamma.clone()));
293        self.ln1_beta_node = Some(graph.variable(self.block.ln1_beta.clone()));
294        self.ln2_gamma_node = Some(graph.variable(self.block.ln2_gamma.clone()));
295        self.ln2_beta_node = Some(graph.variable(self.block.ln2_beta.clone()));
296    }
297
298    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
299        let mha = self
300            .mha_layer
301            .as_ref()
302            .ok_or(ModelError::ParamsNotRegistered {
303                layer: "TransformerEncoder",
304            })?;
305        let ff = self
306            .ff_layer
307            .as_ref()
308            .ok_or(ModelError::ParamsNotRegistered {
309                layer: "TransformerEncoder",
310            })?;
311        let ln1_g = self.ln1_gamma_node.ok_or(ModelError::ParamsNotRegistered {
312            layer: "TransformerEncoder",
313        })?;
314        let ln1_b = self.ln1_beta_node.ok_or(ModelError::ParamsNotRegistered {
315            layer: "TransformerEncoder",
316        })?;
317        let ln2_g = self.ln2_gamma_node.ok_or(ModelError::ParamsNotRegistered {
318            layer: "TransformerEncoder",
319        })?;
320        let ln2_b = self.ln2_beta_node.ok_or(ModelError::ParamsNotRegistered {
321            layer: "TransformerEncoder",
322        })?;
323
324        // attn_out = mha.forward(graph, input)
325        let attn_out = mha.forward(graph, input)?;
326        // norm1 = layer_norm(input + attn_out)
327        let residual1 = graph.add(input, attn_out)?;
328        let norm1 = graph.layer_norm(residual1, ln1_g, ln1_b, 1e-5)?;
329        // ff_out = ff.forward(graph, norm1)
330        let ff_out = ff.forward(graph, norm1)?;
331        // output = layer_norm(norm1 + ff_out)
332        let residual2 = graph.add(norm1, ff_out)?;
333        let output = graph.layer_norm(residual2, ln2_g, ln2_b, 1e-5)?;
334        Ok(output)
335    }
336
337    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
338        self.block.forward(input)
339    }
340}
341
/// Feed-forward layer wrapping `FeedForward`.
///
/// Input/output: `[seq_len, d_model]`.
pub struct FeedForwardLayer {
    /// Wrapped weights; used directly by `forward_inference`.
    pub ff: FeedForward,
    // Graph handles for the two weight matrices and two bias vectors;
    // `None` until `register_params` is called.
    w1_node: Option<NodeId>,
    b1_node: Option<NodeId>,
    w2_node: Option<NodeId>,
    b2_node: Option<NodeId>,
}
352
353impl std::fmt::Debug for FeedForwardLayer {
354    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        f.debug_struct("FeedForwardLayer").finish()
356    }
357}
358
359impl Clone for FeedForwardLayer {
360    fn clone(&self) -> Self {
361        Self {
362            ff: FeedForward {
363                w1: self.ff.w1.clone(),
364                b1: self.ff.b1.clone(),
365                w2: self.ff.w2.clone(),
366                b2: self.ff.b2.clone(),
367            },
368            w1_node: self.w1_node,
369            b1_node: self.b1_node,
370            w2_node: self.w2_node,
371            b2_node: self.b2_node,
372        }
373    }
374}
375
376impl FeedForwardLayer {
377    pub fn w1_node(&self) -> Option<NodeId> {
378        self.w1_node
379    }
380
381    pub fn new(d_model: usize, d_ff: usize, _seed: u64) -> Self {
382        let ff = FeedForward::new(d_model, d_ff).expect("valid FeedForward config");
383        Self {
384            ff,
385            w1_node: None,
386            b1_node: None,
387            w2_node: None,
388            b2_node: None,
389        }
390    }
391
392    pub fn register_params(&mut self, graph: &mut Graph) {
393        self.w1_node = Some(graph.variable(self.ff.w1.clone()));
394        self.b1_node = Some(graph.variable(self.ff.b1.clone()));
395        self.w2_node = Some(graph.variable(self.ff.w2.clone()));
396        self.b2_node = Some(graph.variable(self.ff.b2.clone()));
397    }
398
399    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
400        let w1 = self.w1_node.ok_or(ModelError::ParamsNotRegistered {
401            layer: "FeedForward",
402        })?;
403        let b1 = self.b1_node.ok_or(ModelError::ParamsNotRegistered {
404            layer: "FeedForward",
405        })?;
406        let w2 = self.w2_node.ok_or(ModelError::ParamsNotRegistered {
407            layer: "FeedForward",
408        })?;
409        let b2 = self.b2_node.ok_or(ModelError::ParamsNotRegistered {
410            layer: "FeedForward",
411        })?;
412
413        // output = relu(input @ w1 + b1) @ w2 + b2
414        let h = graph.matmul_2d(input, w1)?;
415        let h = graph.add(h, b1)?;
416        let h = graph.relu(h)?;
417        let h = graph.matmul_2d(h, w2)?;
418        let out = graph.add(h, b2)?;
419        Ok(out)
420    }
421
422    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
423        self.ff.forward(input)
424    }
425}