pub struct MultiHeadSelfAttention<B: Backend> { /* private fields */ }
Scaled dot-product multi-head self-attention with optional chunked computation.
When chunk_size > 0 the query sequence is processed in windows of
chunk_size rows, keeping the forward-pass peak attention memory at
O(B · H · chunk_size · N) instead of O(B · H · N²), and ensuring
each individual WGPU GPU dispatch remains small enough to avoid OS
watchdog (TDR) timeouts.
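The windowing described above works because softmax is applied per query row, so splitting the query sequence changes peak memory but not the result. The following is a minimal single-head, single-batch sketch in plain Rust, not this crate's implementation: `attend_rows` and `chunked_attention` are hypothetical names, and matrices are stored as `Vec<Vec<f32>>` rather than Burn tensors.

```rust
/// Scaled dot-product attention for the given query rows against all keys
/// and values. `q`, `k`, and `v` are row-major matrices (one Vec per row).
fn attend_rows(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let d = k[0].len();
    let scale = 1.0 / (d as f32).sqrt();
    q.iter()
        .map(|qi| {
            // Scores of this query row against every key: length N.
            // In a tensor backend, a window of `chunk_size` such rows forms
            // the (chunk_size, N) block that bounds peak attention memory.
            let scores: Vec<f32> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect();
            // Numerically stable softmax over the row.
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let denom: f32 = exps.iter().sum();
            // Weighted sum of the value rows.
            let mut out = vec![0.0f32; v[0].len()];
            for (w, vj) in exps.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += (w / denom) * x;
                }
            }
            out
        })
        .collect()
}

/// Assumption: `chunk_size == 0` means full (unchunked) attention,
/// mirroring the convention described in the text.
fn chunked_attention(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    chunk_size: usize,
) -> Vec<Vec<f32>> {
    if chunk_size == 0 {
        return attend_rows(q, k, v);
    }
    // Process the query sequence one window at a time and concatenate.
    q.chunks(chunk_size)
        .flat_map(|window| attend_rows(window, k, v))
        .collect()
}
```

Calling `chunked_attention(&q, &k, &v, 64)` should give the same rows as `chunked_attention(&q, &k, &v, 0)`; only the number of query rows handled per step differs.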
⚠ Training memory: chunking reduces dispatch size but NOT total tape size
Burn’s forward pass builds an autodiff tape for every transformer layer
before loss.backward() runs. At the forward→backward boundary all
depth layers’ chunk tensors are simultaneously in GPU memory:
    peak = depth × 2 × ceil(N / chunk) × B × H × chunk × N × 4 bytes
         = 12 × 2 × 39 × B × 12 × 64 × 2448 × 4   (ViT-B defaults)
         ≈ 6.56 GB × B
Chunking (a small chunk_size) keeps individual GPU dispatches small,
which prevents OS watchdog (TDR) timeouts, but the cumulative tape
size is the same as for full attention. The only way to reduce training
memory is gradient checkpointing (recomputing attention during the backward
pass instead of storing it), which is not yet implemented in this codebase.
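The peak formula above can be checked with a few lines of integer arithmetic. This is a sketch under stated assumptions, not code from this crate: `attention_tape_bytes` is a hypothetical helper, and treating `chunk_size == 0` as one chunk of N rows is an assumption based on the "full attention" convention.

```rust
/// Estimated bytes of attention tensors held on the autodiff tape,
/// following: depth × 2 × ceil(N / chunk) × B × H × chunk × N × 4.
fn attention_tape_bytes(
    depth: u64,
    batch: u64,
    heads: u64,
    seq_len: u64,
    chunk_size: u64,
) -> u64 {
    const TENSORS_PER_CHUNK: u64 = 2; // the factor 2 in the formula
    const BYTES_PER_F32: u64 = 4;
    // Assumption: chunk_size == 0 means a single chunk spanning all N rows.
    let chunk = if chunk_size == 0 { seq_len } else { chunk_size.min(seq_len) };
    // Integer ceiling division: number of query windows per layer.
    let num_chunks = (seq_len + chunk - 1) / chunk;
    depth * TENSORS_PER_CHUNK * num_chunks * batch * heads * chunk * seq_len * BYTES_PER_F32
}

/// Convert bytes to binary gigabytes (GiB, 2^30 bytes).
fn to_gib(bytes: u64) -> f64 {
    bytes as f64 / (1u64 << 30) as f64
}
```

With the ViT-B defaults, `attention_tape_bytes(12, 1, 12, 2448, 64)` evaluates to 7,038,959,616 bytes, which is the 6.56 figure when expressed in units of 2^30 bytes. Passing `chunk_size = 0` gives a slightly smaller value, because the formula rounds the last partial window up to a full chunk.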
Safe configurations (24 GB GPU, ViT-B):
batch_size = 2 → all-layers peak ≈ 13 GB ✓
batch_size = 4 → all-layers peak ≈ 26 GB ✗ OOM
The crate::training::learner::train function guards against unsafe
configurations using --vram-gb to derive the correct limit.
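A guard of this kind might look like the sketch below. It is illustrative only and does not reproduce `crate::training::learner::train`: the function names are hypothetical, the per-sample cost is passed in by the caller (for example the roughly 6.56 GB per sample quoted above), and weights, optimizer state, and other activations are ignored, so a real guard would likely be stricter.

```rust
/// Hypothetical helper: largest batch size whose estimated attention tape
/// fits in `vram_gb` (interpreted here as units of 2^30 bytes).
fn max_batch_size(vram_gb: f64, bytes_per_sample: u64) -> u64 {
    let budget_bytes = vram_gb * (1u64 << 30) as f64;
    (budget_bytes / bytes_per_sample as f64).floor() as u64
}

/// Hypothetical guard: reject a batch size that exceeds the derived limit.
fn check_batch_size(
    batch_size: u64,
    vram_gb: f64,
    bytes_per_sample: u64,
) -> Result<(), String> {
    let limit = max_batch_size(vram_gb, bytes_per_sample);
    if batch_size <= limit {
        Ok(())
    } else {
        Err(format!(
            "batch_size {batch_size} exceeds the estimated limit {limit} for {vram_gb} GB of VRAM"
        ))
    }
}
```

For a 24 GB budget and 7,038,959,616 bytes per sample, `check_batch_size(2, ...)` returns `Ok(())` and `check_batch_size(4, ...)` returns an error, consistent with the two configurations listed above.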
Forward memory comparison (N = 2448, H = 12, B = 8, fp32)
| mode | peak fwd attn tensor | size |
|---|---|---|
| full (chunk=0) | (8, 12, 2448, 2448) | ~2.3 GB |
| chunk=256 | (8, 12, 256, 2448) | ~240 MB |
| chunk=128 | (8, 12, 128, 2448) | ~120 MB |
| chunk=64 | (8, 12, 64, 2448) | ~60 MB |
Implementations
Trait Implementations
impl<B> AutodiffModule<B> for MultiHeadSelfAttention<B>
type InnerModule = MultiHeadSelfAttention<<B as AutodiffBackend>::InnerBackend>
fn valid(&self) -> Self::InnerModule
impl<B: Backend> Clone for MultiHeadSelfAttention<B>
impl<B: Backend> Display for MultiHeadSelfAttention<B>
impl<B: Backend> Module<B> for MultiHeadSelfAttention<B>
type Record = MultiHeadSelfAttentionRecord<B>
fn load_record(self, record: Self::Record) -> Self
fn into_record(self) -> Self::Record
fn num_params(&self) -> usize
fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor)
fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self
fn collect_devices(&self, devices: Devices<B>) -> Devices<B>
fn to_device(self, device: &B::Device) -> Self
fn fork(self, device: &B::Device) -> Self
fn devices(&self) -> Vec<<B as Backend>::Device>
fn save_file<FR, PB>(self, file_path: PB, recorder: &FR) -> Result<(), RecorderError>
fn load_file<FR, PB>(self, file_path: PB, recorder: &FR, device: &<B as Backend>::Device) -> Result<Self, RecorderError>
fn quantize_weights<C>(self, quantizer: &mut Quantizer<C>) -> Self where C: Calibration
impl<B: Backend> ModuleDisplay for MultiHeadSelfAttention<B>
fn format(&self, passed_settings: DisplaySettings) -> String
fn custom_settings(&self) -> Option<DisplaySettings>
Auto Trait Implementations
impl<B> !Freeze for MultiHeadSelfAttention<B>
impl<B> !RefUnwindSafe for MultiHeadSelfAttention<B>
impl<B> Send for MultiHeadSelfAttention<B>
impl<B> !Sync for MultiHeadSelfAttention<B>
impl<B> Unpin for MultiHeadSelfAttention<B> where
<B as Backend>::FloatTensorPrimitive<2>: Unpin,
<B as Backend>::QuantizedTensorPrimitive<2>: Unpin,
<B as Backend>::Device: Unpin,
<B as Backend>::FloatTensorPrimitive<1>: Unpin,
<B as Backend>::QuantizedTensorPrimitive<1>: Unpin,
impl<B> UnsafeUnpin for MultiHeadSelfAttention<B> where
<B as Backend>::FloatTensorPrimitive<2>: UnsafeUnpin,
<B as Backend>::QuantizedTensorPrimitive<2>: UnsafeUnpin,
<B as Backend>::Device: UnsafeUnpin,
<B as Backend>::FloatTensorPrimitive<1>: UnsafeUnpin,
<B as Backend>::QuantizedTensorPrimitive<1>: UnsafeUnpin,
impl<B> UnwindSafe for MultiHeadSelfAttention<B> where
<B as Backend>::FloatTensorPrimitive<2>: UnwindSafe,
<B as Backend>::QuantizedTensorPrimitive<2>: UnwindSafe,
<B as Backend>::FloatTensorPrimitive<1>: UnwindSafe,
<B as Backend>::QuantizedTensorPrimitive<1>: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T where T: ?Sized
fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T where T: Clone
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
impl<T> Pointable for T
impl<T> PolicyExt for T where T: ?Sized
impl<T> ToStringFallible for T where T: Display
fn try_to_string(&self) -> Result<String, TryReserveError>
ToString::to_string, but without panic on OOM.