// yscv_model/transformer_decoder.rs

use yscv_kernels::matmul_2d;
use yscv_tensor::Tensor;

use super::ModelError;
use super::attention::{
    FeedForward, MultiHeadAttention, MultiHeadAttentionConfig, scaled_dot_product_attention,
};

/// Cross-attention: query from decoder, key/value from encoder output.
///
/// Stores separate projection matrices for Q (applied to decoder state)
/// and K/V (applied to encoder memory).
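///
/// Per head `h`, `forward` delegates to `scaled_dot_product_attention`;
/// the standard form of that operation, in sketch notation:
///
/// ```text
/// Attention(Q_h, K_h, V_h) = softmax(Q_h * K_h^T / sqrt(d_k)) * V_h
/// ```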
pub struct CrossAttention {
    pub w_q: Tensor, // [d_model, d_model]
    pub w_k: Tensor, // [d_model, d_model]
    pub w_v: Tensor, // [d_model, d_model]
    pub w_o: Tensor, // [d_model, d_model]
    pub num_heads: usize,
    pub d_k: usize, // = d_model / num_heads
    pub d_model: usize,
}

impl CrossAttention {
    /// Creates a cross-attention layer with zero-initialized projection weights.
    pub fn new(d_model: usize, num_heads: usize) -> Result<Self, ModelError> {
        if num_heads == 0 || !d_model.is_multiple_of(num_heads) {
            return Err(ModelError::InvalidParameterShape {
                parameter: "d_model must be divisible by num_heads",
                expected: vec![d_model, num_heads],
                // `checked_rem` avoids a divide-by-zero panic when `num_heads` is 0.
                got: vec![d_model.checked_rem(num_heads).unwrap_or(d_model)],
            });
        }
        let d_k = d_model / num_heads;
        let z = vec![0.0f32; d_model * d_model];
        Ok(Self {
            w_q: Tensor::from_vec(vec![d_model, d_model], z.clone())?,
            w_k: Tensor::from_vec(vec![d_model, d_model], z.clone())?,
            w_v: Tensor::from_vec(vec![d_model, d_model], z.clone())?,
            w_o: Tensor::from_vec(vec![d_model, d_model], z)?,
            num_heads,
            d_k,
            d_model,
        })
    }

    /// Forward pass.
    ///
    /// `query`: `[seq_q, d_model]` — decoder state.
    /// `kv`:    `[seq_kv, d_model]` — encoder memory.
    ///
    /// Returns `[seq_q, d_model]`.
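    ///
    /// # Example
    ///
    /// A minimal shape-check sketch, not from the original source: the `use`
    /// paths below are assumed, and with the zero-initialized weights from
    /// `new` the output is all zeros.
    ///
    /// ```no_run
    /// use yscv_model::ModelError;
    /// use yscv_model::transformer_decoder::CrossAttention;
    /// use yscv_tensor::Tensor;
    ///
    /// fn demo() -> Result<(), ModelError> {
    ///     let attn = CrossAttention::new(64, 8)?;
    ///     // Decoder state [seq_q = 4, d_model = 64]; encoder memory [seq_kv = 10, d_model = 64].
    ///     let query = Tensor::from_vec(vec![4, 64], vec![0.0; 4 * 64])?;
    ///     let memory = Tensor::from_vec(vec![10, 64], vec![0.0; 10 * 64])?;
    ///     let out = attn.forward(&query, &memory)?; // shape [4, 64]
    ///     Ok(())
    /// }
    /// ```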
    pub fn forward(&self, query: &Tensor, kv: &Tensor) -> Result<Tensor, ModelError> {
        let q = matmul_2d(query, &self.w_q)?;
        let k = matmul_2d(kv, &self.w_k)?;
        let v = matmul_2d(kv, &self.w_v)?;

        let mut head_outputs = Vec::with_capacity(self.num_heads);
        for h in 0..self.num_heads {
            // Slice head h's d_k columns out of the projected Q, K, V.
            let start = h * self.d_k;
            let qh = q.narrow(1, start, self.d_k)?;
            let kh = k.narrow(1, start, self.d_k)?;
            let vh = v.narrow(1, start, self.d_k)?;
            let attn = scaled_dot_product_attention(&qh, &kh, &vh)?;
            head_outputs.push(attn);
        }

        // Concatenate the heads back to [seq_q, d_model], then apply the output projection.
        let concat = Tensor::cat(&head_outputs.iter().collect::<Vec<_>>(), 1)?;
        let output = matmul_2d(&concat, &self.w_o)?;
        Ok(output)
    }
}

/// Layer normalization over the last dimension, in the standard form
/// `y = gamma * (x - mean) / sqrt(var + epsilon) + beta` with per-row
/// mean and variance.
fn layer_norm_2d(input: &Tensor, gamma: &Tensor, beta: &Tensor) -> Result<Tensor, ModelError> {
    let params = yscv_kernels::LayerNormLastDimParams {
        gamma,
        beta,
        epsilon: 1e-5,
    };
    yscv_kernels::layer_norm_last_dim(input, params).map_err(Into::into)
}

/// Single transformer decoder block: masked self-attention → cross-attention → FFN,
/// each sub-layer wrapped with a residual connection and layer normalization.
pub struct TransformerDecoderBlock {
    pub self_attn: MultiHeadAttention,
    pub cross_attn: CrossAttention,
    pub ffn: FeedForward,
    pub ln1_gamma: Tensor,
    pub ln1_beta: Tensor,
    pub ln2_gamma: Tensor,
    pub ln2_beta: Tensor,
    pub ln3_gamma: Tensor,
    pub ln3_beta: Tensor,
    pub d_model: usize,
}

impl TransformerDecoderBlock {
    /// Creates a block with identity layer-norm parameters
    /// (`gamma` = 1, `beta` = 0); attention and FFN weights come from
    /// their own constructors.
    pub fn new(d_model: usize, num_heads: usize, d_ff: usize) -> Result<Self, ModelError> {
        let mha_config = MultiHeadAttentionConfig { d_model, num_heads };
        Ok(Self {
            self_attn: MultiHeadAttention::new(&mha_config)?,
            cross_attn: CrossAttention::new(d_model, num_heads)?,
            ffn: FeedForward::new(d_model, d_ff)?,
            ln1_gamma: Tensor::from_vec(vec![d_model], vec![1.0; d_model])?,
            ln1_beta: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
            ln2_gamma: Tensor::from_vec(vec![d_model], vec![1.0; d_model])?,
            ln2_beta: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
            ln3_gamma: Tensor::from_vec(vec![d_model], vec![1.0; d_model])?,
            ln3_beta: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
            d_model,
        })
    }

    /// Forward pass.
    ///
    /// `target`: `[seq_t, d_model]` — decoder input / previous decoder layer output.
    /// `memory`: `[seq_s, d_model]` — encoder output.
    ///
    /// Returns `[seq_t, d_model]`.
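    ///
    /// All three sub-layers use the post-norm arrangement (residual add,
    /// then layer norm), matching the code below; in sketch form:
    ///
    /// ```text
    /// x = LayerNorm(target + SelfAttn(target))
    /// x = LayerNorm(x + CrossAttn(x, memory))
    /// x = LayerNorm(x + FFN(x))
    /// ```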
    pub fn forward(&self, target: &Tensor, memory: &Tensor) -> Result<Tensor, ModelError> {
        // Sub-layer 1: masked self-attention
        let sa_out = self.self_attn.forward(target)?;
        let residual1 = target.add(&sa_out)?;
        let x = layer_norm_2d(&residual1, &self.ln1_gamma, &self.ln1_beta)?;

        // Sub-layer 2: cross-attention (Q from decoder, K/V from encoder)
        let ca_out = self.cross_attn.forward(&x, memory)?;
        let residual2 = x.add(&ca_out)?;
        let x = layer_norm_2d(&residual2, &self.ln2_gamma, &self.ln2_beta)?;

        // Sub-layer 3: feed-forward network
        let ffn_out = self.ffn.forward(&x)?;
        let residual3 = x.add(&ffn_out)?;
        let x = layer_norm_2d(&residual3, &self.ln3_gamma, &self.ln3_beta)?;

        Ok(x)
    }
}

/// Stack of `TransformerDecoderBlock` layers.
pub struct TransformerDecoder {
    pub layers: Vec<TransformerDecoderBlock>,
}

impl TransformerDecoder {
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        num_layers: usize,
    ) -> Result<Self, ModelError> {
        let mut layers = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            layers.push(TransformerDecoderBlock::new(d_model, num_heads, d_ff)?);
        }
        Ok(Self { layers })
    }

    /// Forward pass through all decoder layers.
    ///
    /// `target`: `[seq_t, d_model]` — decoder input.
    /// `memory`: `[seq_s, d_model]` — encoder output.
    ///
    /// Returns `[seq_t, d_model]`.
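    ///
    /// # Example
    ///
    /// An end-to-end shape sketch, not from the original source: the `use`
    /// paths are assumed, and a real encoder would supply `memory` in place
    /// of the zero tensor used here.
    ///
    /// ```no_run
    /// use yscv_model::ModelError;
    /// use yscv_model::transformer_decoder::TransformerDecoder;
    /// use yscv_tensor::Tensor;
    ///
    /// fn demo() -> Result<(), ModelError> {
    ///     // d_model = 64, 8 heads, FFN width 256, 2 layers.
    ///     let decoder = TransformerDecoder::new(64, 8, 256, 2)?;
    ///     let target = Tensor::from_vec(vec![5, 64], vec![0.0; 5 * 64])?;   // [seq_t, d_model]
    ///     let memory = Tensor::from_vec(vec![12, 64], vec![0.0; 12 * 64])?; // [seq_s, d_model]
    ///     let out = decoder.forward(&target, &memory)?; // shape [5, 64]
    ///     Ok(())
    /// }
    /// ```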
    pub fn forward(&self, target: &Tensor, memory: &Tensor) -> Result<Tensor, ModelError> {
        let mut x = target.clone();
        for layer in &self.layers {
            x = layer.forward(&x, memory)?;
        }
        Ok(x)
    }
}
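
#[cfg(test)]
mod tests {
    use super::*;

    // A smoke-test sketch, not part of the original source. It assumes
    // `Tensor` exposes a `shape(&self) -> &[usize]` accessor (hypothetical
    // name) and that the error types implement `Debug`; only the
    // shape-preserving property of the decoder is checked.
    #[test]
    fn decoder_preserves_target_shape() {
        let decoder = TransformerDecoder::new(32, 4, 64, 2).expect("valid config");
        let target = Tensor::from_vec(vec![3, 32], vec![0.0; 3 * 32]).expect("target tensor");
        let memory = Tensor::from_vec(vec![7, 32], vec![0.0; 7 * 32]).expect("memory tensor");
        let out = decoder.forward(&target, &memory).expect("decoder forward");
        assert_eq!(out.shape(), &[3, 32]);
    }
}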