
yscv_model/attention.rs

use yscv_kernels::{matmul_2d, softmax_last_dim};
use yscv_tensor::Tensor;

use crate::ModelError;

/// Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.
///
/// Q, K, and V all have shape `[seq_len, d_model]` (single head) or
/// `[seq_len, d_k]` (one head of a multi-head split, as produced by
/// [`MultiHeadAttention::forward`]).
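///
/// # Example
///
/// A minimal sketch (not compiled as a doctest; it assumes `Tensor::from_vec`
/// takes a shape followed by row-major data, as used elsewhere in this file,
/// and a `Result`-returning caller for `?`):
///
/// ```ignore
/// // Two positions, d_k = 2.
/// let q = Tensor::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0])?;
/// let k = Tensor::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0])?;
/// let v = Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0])?;
/// let out = scaled_dot_product_attention(&q, &k, &v)?;
/// assert_eq!(out.shape(), &[2, 2]);
/// ```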
pub fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
) -> Result<Tensor, ModelError> {
    let q_shape = query.shape();
    let k_shape = key.shape();
    if q_shape.len() != 2 || k_shape.len() != 2 {
        return Err(ModelError::InvalidParameterShape {
            parameter: "attention QKV",
            expected: vec![0, 0],
            got: q_shape.to_vec(),
        });
    }
    let d_k = q_shape[1] as f32;

    let kt = key.transpose_2d()?;
    let scores = matmul_2d(query, &kt)?;
    let scale = 1.0 / d_k.sqrt();
    let scaled = scores.scale(scale);
    let attn_weights = softmax_last_dim(&scaled)?;
    let output = matmul_2d(&attn_weights, value)?;
    Ok(output)
}

/// Multi-head attention configuration.
pub struct MultiHeadAttentionConfig {
    pub d_model: usize,
    pub num_heads: usize,
}

/// Multi-head attention weights.
pub struct MultiHeadAttention {
    pub w_q: Tensor, // [d_model, d_model]
    pub w_k: Tensor,
    pub w_v: Tensor,
    pub w_o: Tensor, // [d_model, d_model]
    pub num_heads: usize,
    pub d_k: usize,
}

impl MultiHeadAttention {
    /// Creates zero-initialized multi-head attention weights.
    pub fn new(config: &MultiHeadAttentionConfig) -> Result<Self, ModelError> {
        let d = config.d_model;
        let h = config.num_heads;
        if !d.is_multiple_of(h) {
            return Err(ModelError::InvalidParameterShape {
                parameter: "d_model must be divisible by num_heads",
                expected: vec![d, h],
                got: vec![d % h],
            });
        }
        let d_k = d / h;
        let z = vec![0.0f32; d * d];
        Ok(Self {
            w_q: Tensor::from_vec(vec![d, d], z.clone())?,
            w_k: Tensor::from_vec(vec![d, d], z.clone())?,
            w_v: Tensor::from_vec(vec![d, d], z.clone())?,
            w_o: Tensor::from_vec(vec![d, d], z)?,
            num_heads: h,
            d_k,
        })
    }

    /// Forward pass: input `[seq_len, d_model]` -> output `[seq_len, d_model]`.
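    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest; the zero-initialized
    /// weights here only exercise shapes):
    ///
    /// ```ignore
    /// let config = MultiHeadAttentionConfig { d_model: 8, num_heads: 2 };
    /// let mha = MultiHeadAttention::new(&config)?;
    /// let input = Tensor::from_vec(vec![4, 8], vec![0.0; 4 * 8])?;
    /// let output = mha.forward(&input)?;
    /// assert_eq!(output.shape(), &[4, 8]);
    /// ```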
    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 2 {
            return Err(ModelError::InvalidParameterShape {
                parameter: "attention input",
                expected: vec![0, self.num_heads * self.d_k],
                got: shape.to_vec(),
            });
        }

        let q = matmul_2d(input, &self.w_q)?;
        let k = matmul_2d(input, &self.w_k)?;
        let v = matmul_2d(input, &self.w_v)?;

        // Split Q, K, V column-wise into per-head [seq_len, d_k] slices and
        // run attention on each head independently.
        let mut head_outputs = Vec::with_capacity(self.num_heads);
        for h in 0..self.num_heads {
            let start = h * self.d_k;
            let qh = q.narrow(1, start, self.d_k)?;
            let kh = k.narrow(1, start, self.d_k)?;
            let vh = v.narrow(1, start, self.d_k)?;
            let attn = scaled_dot_product_attention(&qh, &kh, &vh)?;
            head_outputs.push(attn);
        }

        // Concatenate heads along the last dim -> [seq_len, d_model], then project.
        let concat = Tensor::cat(&head_outputs.iter().collect::<Vec<_>>(), 1)?;
        let output = matmul_2d(&concat, &self.w_o)?;
        Ok(output)
    }
}

/// Feed-forward network: Linear(d_model, d_ff) -> ReLU -> Linear(d_ff, d_model).
pub struct FeedForward {
    pub w1: Tensor, // [d_model, d_ff]
    pub b1: Tensor, // [d_ff]
    pub w2: Tensor, // [d_ff, d_model]
    pub b2: Tensor, // [d_model]
}

impl FeedForward {
    /// Creates zero-initialized feed-forward weights.
    pub fn new(d_model: usize, d_ff: usize) -> Result<Self, ModelError> {
        Ok(Self {
            w1: Tensor::from_vec(vec![d_model, d_ff], vec![0.0; d_model * d_ff])?,
            b1: Tensor::from_vec(vec![d_ff], vec![0.0; d_ff])?,
            w2: Tensor::from_vec(vec![d_ff, d_model], vec![0.0; d_ff * d_model])?,
            b2: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
        })
    }

    /// Forward: `[seq_len, d_model]` -> `[seq_len, d_model]`.
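    ///
    /// # Example
    ///
    /// A minimal shape-checking sketch (not compiled as a doctest; with
    /// zero-initialized weights the output is just the broadcast `b2` bias):
    ///
    /// ```ignore
    /// let ffn = FeedForward::new(8, 32)?;
    /// let input = Tensor::from_vec(vec![4, 8], vec![0.5; 4 * 8])?;
    /// let output = ffn.forward(&input)?;
    /// assert_eq!(output.shape(), &[4, 8]);
    /// ```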
    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let h = matmul_2d(input, &self.w1)?;
        let h = h.add(&self.b1.unsqueeze(0)?)?;
        // ReLU: clamp negative activations to zero, element-wise.
        let data: Vec<f32> = h.data().iter().map(|&v| v.max(0.0)).collect();
        let h = Tensor::from_vec(h.shape().to_vec(), data)?;
        let out = matmul_2d(&h, &self.w2)?;
        let out = out.add(&self.b2.unsqueeze(0)?)?;
        Ok(out)
    }
}

/// Generates a causal (lower-triangular) attention mask.
/// Returns a `[seq_len, seq_len]` tensor where:
/// - 0.0 on and below the diagonal (allowed positions)
/// - f32::NEG_INFINITY above the diagonal (masked positions)
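///
/// For example, `generate_causal_mask(3)` lays out (writing `-inf` for
/// `f32::NEG_INFINITY`):
///
/// ```text
/// [ 0.0, -inf, -inf ]
/// [ 0.0,  0.0, -inf ]
/// [ 0.0,  0.0,  0.0 ]
/// ```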
pub fn generate_causal_mask(seq_len: usize) -> Result<Tensor, ModelError> {
    let mut data = vec![0.0f32; seq_len * seq_len];
    for i in 0..seq_len {
        for j in (i + 1)..seq_len {
            data[i * seq_len + j] = f32::NEG_INFINITY;
        }
    }
    Ok(Tensor::from_vec(vec![seq_len, seq_len], data)?)
}

/// Generates a padding mask for batched sequences of different lengths.
/// `lengths`: actual length of each sequence in the batch.
/// `max_len`: maximum (padded) sequence length.
/// Returns a `[batch, max_len]` tensor where:
/// - 0.0 for valid positions (index < length)
/// - f32::NEG_INFINITY for padding positions (index >= length)
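///
/// For example, `generate_padding_mask(&[1, 3], 3)` lays out (writing `-inf`
/// for `f32::NEG_INFINITY`):
///
/// ```text
/// [ 0.0, -inf, -inf ]
/// [ 0.0,  0.0,  0.0 ]
/// ```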
pub fn generate_padding_mask(lengths: &[usize], max_len: usize) -> Result<Tensor, ModelError> {
    let batch = lengths.len();
    let mut data = vec![0.0f32; batch * max_len];
    for (b, &len) in lengths.iter().enumerate() {
        for j in len..max_len {
            data[b * max_len + j] = f32::NEG_INFINITY;
        }
    }
    Ok(Tensor::from_vec(vec![batch, max_len], data)?)
}

/// Transformer encoder block: MHA -> Add&Norm -> FFN -> Add&Norm.
pub struct TransformerEncoderBlock {
    pub mha: MultiHeadAttention,
    pub ffn: FeedForward,
    pub ln1_gamma: Tensor,
    pub ln1_beta: Tensor,
    pub ln2_gamma: Tensor,
    pub ln2_beta: Tensor,
    pub d_model: usize,
}

impl TransformerEncoderBlock {
    /// Creates an encoder block with zero-initialized weights and identity
    /// LayerNorm parameters (gamma = 1, beta = 0).
    pub fn new(d_model: usize, num_heads: usize, d_ff: usize) -> Result<Self, ModelError> {
        let config = MultiHeadAttentionConfig { d_model, num_heads };
        Ok(Self {
            mha: MultiHeadAttention::new(&config)?,
            ffn: FeedForward::new(d_model, d_ff)?,
            ln1_gamma: Tensor::from_vec(vec![d_model], vec![1.0; d_model])?,
            ln1_beta: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
            ln2_gamma: Tensor::from_vec(vec![d_model], vec![1.0; d_model])?,
            ln2_beta: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
            d_model,
        })
    }

    /// Forward: `[seq_len, d_model]` -> `[seq_len, d_model]`.
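    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest; with zero-initialized
    /// weights this only exercises shapes and the residual/LayerNorm path):
    ///
    /// ```ignore
    /// let block = TransformerEncoderBlock::new(8, 2, 32)?;
    /// let input = Tensor::from_vec(vec![4, 8], vec![0.1; 4 * 8])?;
    /// let output = block.forward(&input)?;
    /// assert_eq!(output.shape(), &[4, 8]);
    /// ```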
    pub fn forward(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        // Self-attention sub-layer with post-norm residual connection.
        let attn_out = self.mha.forward(input)?;
        let residual1 = input.add(&attn_out)?;
        let norm1 = layer_norm_2d(&residual1, &self.ln1_gamma, &self.ln1_beta, self.d_model)?;

        // Feed-forward sub-layer with post-norm residual connection.
        let ffn_out = self.ffn.forward(&norm1)?;
        let residual2 = norm1.add(&ffn_out)?;
        let norm2 = layer_norm_2d(&residual2, &self.ln2_gamma, &self.ln2_beta, self.d_model)?;
        Ok(norm2)
    }
}

/// Applies LayerNorm over the last dimension of a 2-D tensor.
fn layer_norm_2d(
    input: &Tensor,
    gamma: &Tensor,
    beta: &Tensor,
    _d: usize,
) -> Result<Tensor, ModelError> {
    let params = yscv_kernels::LayerNormLastDimParams {
        gamma,
        beta,
        epsilon: 1e-5,
    };
    yscv_kernels::layer_norm_last_dim(input, params).map_err(Into::into)
}
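
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal smoke-test sketches; they assume the `Tensor` accessors used
    // above (`shape()`, `data()`) behave as in the rest of this file and are
    // meant as illustration rather than exhaustive coverage.
    #[test]
    fn causal_mask_blocks_future_positions() {
        let mask = generate_causal_mask(3).unwrap();
        assert_eq!(mask.shape(), &[3, 3]);
        // Row 0 may only attend to position 0.
        assert_eq!(mask.data()[0], 0.0);
        assert_eq!(mask.data()[1], f32::NEG_INFINITY);
        assert_eq!(mask.data()[2], f32::NEG_INFINITY);
        // The diagonal itself stays unmasked.
        assert_eq!(mask.data()[4], 0.0);
        assert_eq!(mask.data()[8], 0.0);
    }

    #[test]
    fn encoder_block_preserves_shape() {
        let block = TransformerEncoderBlock::new(8, 2, 32).unwrap();
        let input = Tensor::from_vec(vec![4, 8], vec![0.1; 4 * 8]).unwrap();
        let output = block.forward(&input).unwrap();
        assert_eq!(output.shape(), &[4, 8]);
    }
}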