Module decoder

Transformer decoder layers.

This module implements transformer decoder layers that combine:

  • Masked multi-head self-attention
  • Cross-attention to encoder outputs
  • Feed-forward networks
  • Layer normalization
  • Residual connections

§Transformer Decoder Layer

Pre-normalization variant:

x' = x + MaskedSelfAttention(LayerNorm(x))
x'' = x' + CrossAttention(LayerNorm(x'), encoder_output)
output = x'' + FFN(LayerNorm(x''))

Post-normalization variant:

x' = LayerNorm(x + MaskedSelfAttention(x))
x'' = LayerNorm(x' + CrossAttention(x', encoder_output))
output = LayerNorm(x'' + FFN(x''))
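The two variants differ only in where LayerNorm sits relative to the residual connection: pre-norm normalizes each sublayer's input, post-norm normalizes the residual sum. As a minimal sketch of that wiring, the Rust snippet below composes both data flows over plain f32 vectors, with the attention and feed-forward sublayers abstracted as function parameters. The function names and vector-based signatures are illustrative assumptions, not this module's API; the Decoder struct below operates on its own types.

// Illustrative sketch only: not this module's API. Sublayers are passed
// in as functions so that only the residual/normalization wiring is shown.

fn layer_norm(x: &[f32]) -> Vec<f32> {
    // Normalize to zero mean and unit variance (no learned scale/shift here).
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    x.iter().map(|v| (v - mean) / (var + 1e-5).sqrt()).collect()
}

fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
    // Elementwise residual addition.
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

// Pre-norm: normalize *before* each sublayer, then add the residual.
fn pre_norm_layer(
    x: &[f32],
    enc: &[f32],
    self_attn: impl Fn(&[f32]) -> Vec<f32>,
    cross_attn: impl Fn(&[f32], &[f32]) -> Vec<f32>,
    ffn: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    let x1 = add(x, &self_attn(&layer_norm(x)));
    let x2 = add(&x1, &cross_attn(&layer_norm(&x1), enc));
    add(&x2, &ffn(&layer_norm(&x2)))
}

// Post-norm: add the residual first, then normalize the sum.
fn post_norm_layer(
    x: &[f32],
    enc: &[f32],
    self_attn: impl Fn(&[f32]) -> Vec<f32>,
    cross_attn: impl Fn(&[f32], &[f32]) -> Vec<f32>,
    ffn: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    let x1 = layer_norm(&add(x, &self_attn(x)));
    let x2 = layer_norm(&add(&x1, &cross_attn(&x1, enc)));
    layer_norm(&add(&x2, &ffn(&x2)))
}

fn main() {
    // Identity sublayers stand in for real attention/FFN weights.
    let id = |x: &[f32]| x.to_vec();
    let id_cross = |x: &[f32], _enc: &[f32]| x.to_vec();
    let x = vec![1.0, 2.0, 3.0, 4.0];
    let enc = vec![0.5; 4];
    println!("pre-norm:  {:?}", pre_norm_layer(&x, &enc, id, id_cross, id));
    println!("post-norm: {:?}", post_norm_layer(&x, &enc, id, id_cross, id));
}

One practical consequence of the difference: in the pre-norm wiring the residual path from input to output is never normalized, which is commonly cited as making deep stacks easier to train, while the post-norm wiring matches the original Transformer formulation.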

Structs§

Decoder
Transformer decoder layer
DecoderConfig
Configuration for transformer decoder layer