Transformer decoder layers.
This module implements transformer decoder layers that combine:
- Masked multi-head self-attention
- Cross-attention to encoder outputs
- Feed-forward networks
- Layer normalization
- Residual connections
Transformer Decoder Layer
Pre-normalization variant:
x' = x + MaskedSelfAttention(LayerNorm(x))
x'' = x' + CrossAttention(LayerNorm(x'), encoder_output)
output = x'' + FFN(LayerNorm(x''))
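A minimal sketch of the pre-norm data flow, assuming a hypothetical `Tensor` alias (`Vec<f32>`) and identity placeholders for the sublayers; these names are illustrative only and are not this module's API:

```rust
// Hypothetical tensor type; a real layer would operate on batched tensors.
type Tensor = Vec<f32>;

/// Element-wise residual addition.
fn add(a: &Tensor, b: &Tensor) -> Tensor {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

// Placeholder sublayers: real implementations hold learned weights and
// perform the corresponding computation; identity stands in here.
fn layer_norm(x: &Tensor) -> Tensor { x.clone() }
fn masked_self_attention(x: &Tensor) -> Tensor { x.clone() }
fn cross_attention(x: &Tensor, _encoder_output: &Tensor) -> Tensor { x.clone() }
fn ffn(x: &Tensor) -> Tensor { x.clone() }

/// Pre-norm: each sublayer sees a normalized input, and the residual
/// branch stays an unnormalized identity path from `x` to the output.
fn pre_norm_forward(x: &Tensor, encoder_output: &Tensor) -> Tensor {
    let x1 = add(x, &masked_self_attention(&layer_norm(x)));
    let x2 = add(&x1, &cross_attention(&layer_norm(&x1), encoder_output));
    add(&x2, &ffn(&layer_norm(&x2)))
}
```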
Post-normalization variant:
x' = LayerNorm(x + MaskedSelfAttention(x))
x'' = LayerNorm(x' + CrossAttention(x', encoder_output))
output = LayerNorm(x'' + FFN(x''))
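The post-norm counterpart, reusing the placeholder helpers from the pre-norm sketch above; the only change is that each residual sum is formed first and then normalized:

```rust
/// Post-norm: add the residual first, then apply LayerNorm to the sum.
fn post_norm_forward(x: &Tensor, encoder_output: &Tensor) -> Tensor {
    let x1 = layer_norm(&add(x, &masked_self_attention(x)));
    let x2 = layer_norm(&add(&x1, &cross_attention(&x1, encoder_output)));
    layer_norm(&add(&x2, &ffn(&x2)))
}
```

The two variants differ only in where LayerNorm sits relative to the residual connection: post-norm matches the original Transformer formulation, while pre-norm is generally the more stable choice when stacking many layers.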
Structs

- Decoder: Transformer decoder layer
- DecoderConfig: Configuration for transformer decoder layer