use yscv_kernels::matmul_2d;
use yscv_tensor::Tensor;

use super::ModelError;
use super::attention::{
    FeedForward, MultiHeadAttention, MultiHeadAttentionConfig, scaled_dot_product_attention,
};

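/// Cross-attention layer: queries are projected from the decoder stream, while
/// keys and values are projected from the encoder memory.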
pub struct CrossAttention {
    /// Query projection weights, shape `[d_model, d_model]`.
    pub w_q: Tensor,
    /// Key projection weights, shape `[d_model, d_model]`.
    pub w_k: Tensor,
    /// Value projection weights, shape `[d_model, d_model]`.
    pub w_v: Tensor,
    /// Output projection weights, shape `[d_model, d_model]`.
    pub w_o: Tensor,
    pub num_heads: usize,
    /// Per-head dimension: `d_model / num_heads`.
    pub d_k: usize,
    pub d_model: usize,
}

impl CrossAttention {
    pub fn new(d_model: usize, num_heads: usize) -> Result<Self, ModelError> {
        if !d_model.is_multiple_of(num_heads) {
            return Err(ModelError::InvalidParameterShape {
                parameter: "d_model must be divisible by num_heads",
                expected: vec![d_model, num_heads],
                got: vec![d_model % num_heads],
            });
        }
        let d_k = d_model / num_heads;
        // All projection weights start zero-initialized.
        let z = vec![0.0f32; d_model * d_model];
        Ok(Self {
            w_q: Tensor::from_vec(vec![d_model, d_model], z.clone())?,
            w_k: Tensor::from_vec(vec![d_model, d_model], z.clone())?,
            w_v: Tensor::from_vec(vec![d_model, d_model], z.clone())?,
            w_o: Tensor::from_vec(vec![d_model, d_model], z)?,
            num_heads,
            d_k,
            d_model,
        })
    }

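    /// Computes cross-attention: `query` provides the queries, `kv` provides the
    /// keys and values. Both are expected to be 2-D with last dimension `d_model`.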
    pub fn forward(&self, query: &Tensor, kv: &Tensor) -> Result<Tensor, ModelError> {
        // Project the inputs into query, key, and value spaces.
        let q = matmul_2d(query, &self.w_q)?;
        let k = matmul_2d(kv, &self.w_k)?;
        let v = matmul_2d(kv, &self.w_v)?;

        // Split the projections column-wise into heads and attend per head.
        let mut head_outputs = Vec::new();
        for h in 0..self.num_heads {
            let start = h * self.d_k;
            let qh = q.narrow(1, start, self.d_k)?;
            let kh = k.narrow(1, start, self.d_k)?;
            let vh = v.narrow(1, start, self.d_k)?;
            let attn = scaled_dot_product_attention(&qh, &kh, &vh)?;
            head_outputs.push(attn);
        }

        // Concatenate the heads and apply the output projection.
        let concat = Tensor::cat(&head_outputs.iter().collect::<Vec<_>>(), 1)?;
        let output = matmul_2d(&concat, &self.w_o)?;
        Ok(output)
    }
}

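/// Applies LayerNorm over the last dimension of a 2-D tensor using the
/// `yscv_kernels` implementation.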
fn layer_norm_2d(input: &Tensor, gamma: &Tensor, beta: &Tensor) -> Result<Tensor, ModelError> {
    let params = yscv_kernels::LayerNormLastDimParams {
        gamma,
        beta,
        epsilon: 1e-5,
    };
    yscv_kernels::layer_norm_last_dim(input, params).map_err(Into::into)
}

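/// A single decoder block: self-attention, cross-attention over the encoder
/// memory, and a position-wise feed-forward network, each wrapped in a residual
/// connection followed by LayerNorm (post-norm).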
pub struct TransformerDecoderBlock {
    pub self_attn: MultiHeadAttention,
    pub cross_attn: CrossAttention,
    pub ffn: FeedForward,
    // LayerNorm scale/shift applied after the self-attention sublayer.
    pub ln1_gamma: Tensor,
    pub ln1_beta: Tensor,
    // LayerNorm scale/shift applied after the cross-attention sublayer.
    pub ln2_gamma: Tensor,
    pub ln2_beta: Tensor,
    // LayerNorm scale/shift applied after the feed-forward sublayer.
    pub ln3_gamma: Tensor,
    pub ln3_beta: Tensor,
    pub d_model: usize,
}

impl TransformerDecoderBlock {
    pub fn new(d_model: usize, num_heads: usize, d_ff: usize) -> Result<Self, ModelError> {
        let mha_config = MultiHeadAttentionConfig { d_model, num_heads };
        // LayerNorm parameters start as the identity transform: gamma = 1, beta = 0.
        Ok(Self {
            self_attn: MultiHeadAttention::new(&mha_config)?,
            cross_attn: CrossAttention::new(d_model, num_heads)?,
            ffn: FeedForward::new(d_model, d_ff)?,
            ln1_gamma: Tensor::from_vec(vec![d_model], vec![1.0; d_model])?,
            ln1_beta: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
            ln2_gamma: Tensor::from_vec(vec![d_model], vec![1.0; d_model])?,
            ln2_beta: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
            ln3_gamma: Tensor::from_vec(vec![d_model], vec![1.0; d_model])?,
            ln3_beta: Tensor::from_vec(vec![d_model], vec![0.0; d_model])?,
            d_model,
        })
    }

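    /// Runs the block on `target` (decoder stream) with `memory` (encoder
    /// output). Each sublayer output is added back to its input and then
    /// layer-normalized.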
    pub fn forward(&self, target: &Tensor, memory: &Tensor) -> Result<Tensor, ModelError> {
        // Self-attention sublayer with residual connection and LayerNorm.
        let sa_out = self.self_attn.forward(target)?;
        let residual1 = target.add(&sa_out)?;
        let x = layer_norm_2d(&residual1, &self.ln1_gamma, &self.ln1_beta)?;

        // Cross-attention sublayer attending to the encoder memory.
        let ca_out = self.cross_attn.forward(&x, memory)?;
        let residual2 = x.add(&ca_out)?;
        let x = layer_norm_2d(&residual2, &self.ln2_gamma, &self.ln2_beta)?;

        // Feed-forward sublayer.
        let ffn_out = self.ffn.forward(&x)?;
        let residual3 = x.add(&ffn_out)?;
        let x = layer_norm_2d(&residual3, &self.ln3_gamma, &self.ln3_beta)?;

        Ok(x)
    }
}

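/// A stack of decoder blocks applied in sequence; every block attends to the
/// same encoder memory.
///
/// Illustrative usage (a sketch, assuming 2-D inputs of shape
/// `[seq_len, d_model]` and the `Tensor::from_vec` constructor used above):
///
/// ```ignore
/// let decoder = TransformerDecoder::new(64, 8, 256, 2)?;
/// let target = Tensor::from_vec(vec![10, 64], vec![0.0; 10 * 64])?;
/// let memory = Tensor::from_vec(vec![20, 64], vec![0.0; 20 * 64])?;
/// let output = decoder.forward(&target, &memory)?;
/// ```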
pub struct TransformerDecoder {
    pub layers: Vec<TransformerDecoderBlock>,
}

impl TransformerDecoder {
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        num_layers: usize,
    ) -> Result<Self, ModelError> {
        let mut layers = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            layers.push(TransformerDecoderBlock::new(d_model, num_heads, d_ff)?);
        }
        Ok(Self { layers })
    }

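    /// Passes `target` through every decoder block in order, reusing `memory`
    /// at each layer, and returns the final representation.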
    pub fn forward(&self, target: &Tensor, memory: &Tensor) -> Result<Tensor, ModelError> {
        let mut x = target.clone();
        for layer in &self.layers {
            x = layer.forward(&x, memory)?;
        }
        Ok(x)
    }
}