pub struct TrainingConfig {
pub model_size: ModelSize,
pub total_steps: usize,
pub batch_size: usize,
pub lr: f64,
pub weight_decay: f64,
pub beta2: f64,
pub beta1: f64,
pub epsilon: f64,
pub grad_clip_norm: f64,
pub warmup_fraction: f64,
pub cooldown_fraction: f64,
pub lr_schedule: LrScheduleType,
pub checkpoint_every: usize,
pub log_every: usize,
pub seed: u64,
pub caption_key: CaptionKey,
pub tokenizer_path: String,
pub artifact_dir: String,
pub data_dir: String,
pub num_workers: usize,
pub vram_gb: Option<f64>,
pub skip_vram_check: bool,
pub show_summary: bool,
}
Training hyperparameters.
Fields§
model_size: ModelSize
Model size preset (tiny / small / base).
Overrides SensorLMConfig when passed through the CLI. Building
a config from a preset is the recommended way to avoid mismatched
d_model / embed_dim values between the two towers.
total_steps: usize
Total number of gradient update steps.
batch_size: usize
Mini-batch size (default: 8).
lr: f64
Peak learning rate (default: 5 × 10⁻⁴).
weight_decay: f64
AdamW weight decay (default: 1 × 10⁻⁴).
beta2: f64
Adam β₂ (default: 0.999, reference uses scale_by_adam b2=0.999).
beta1: f64
Adam β₁.
epsilon: f64
Adam ε.
grad_clip_norm: f64
Gradient clip norm (default: 1.0).
warmup_fraction: f64
Fraction of total steps used for linear warm-up (default: 0.2).
cooldown_fraction: f64
Fraction of total steps used for cool-down (default: 0.2).
lr_schedule: LrScheduleType
LR schedule type.
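The sketch below shows how total_steps, warmup_fraction, and cooldown_fraction could be turned into a per-step learning rate. It assumes a linear warm-up, a constant plateau at the peak rate, and a linear cool-down to zero. The variants of LrScheduleType are not listed here, so this shape and the helper names lr_multiplier and lr_at are illustrative assumptions, not the crate's implementation.

```rust
/// Hypothetical helper (not part of the documented API): learning-rate
/// multiplier in [0, 1] for `step`, ASSUMING linear warm-up, a constant
/// plateau, and linear cool-down to zero at `total_steps`.
pub fn lr_multiplier(
    step: usize,
    total_steps: usize,
    warmup_fraction: f64,
    cooldown_fraction: f64,
) -> f64 {
    let total = total_steps as f64;
    // Convert the fractions into whole step counts.
    let warmup_steps = (total * warmup_fraction).round() as usize;
    let cooldown_steps = (total * cooldown_fraction).round() as usize;
    let cooldown_start = total_steps.saturating_sub(cooldown_steps);

    if step < warmup_steps {
        // Ramp up linearly; the last warm-up step reaches 1.0.
        (step + 1) as f64 / warmup_steps as f64
    } else if cooldown_steps > 0 && step >= cooldown_start {
        // Ramp down linearly; the multiplier is 0.0 once step >= total_steps.
        total_steps.saturating_sub(step) as f64 / cooldown_steps as f64
    } else {
        // Plateau at the peak learning rate.
        1.0
    }
}

/// Hypothetical helper: absolute learning rate at `step` given the peak `lr`.
pub fn lr_at(
    step: usize,
    peak_lr: f64,
    total_steps: usize,
    warmup_fraction: f64,
    cooldown_fraction: f64,
) -> f64 {
    peak_lr * lr_multiplier(step, total_steps, warmup_fraction, cooldown_fraction)
}
```

With the documented defaults (warm-up 0.2, cool-down 0.2) and 1000 total steps, this sketch spends 200 steps ramping up, 600 steps at the peak rate, and 200 steps ramping down.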
checkpoint_every: usize
Save a checkpoint every N steps.
log_every: usize
Log metrics every N steps.
seed: u64
Random seed.
caption_key: CaptionKey
Caption type key to use during this training run.
tokenizer_path: String
Path to SentencePiece tokeniser model file.
artifact_dir: String
Directory to write checkpoints / logs.
data_dir: String
Directory containing the dataset (Parquet or raw files).
num_workers: usize
Number of DataLoader worker threads for CPU-side data preparation.
Must be ≥ 1: Burn’s PartialDataset::split divides the dataset
length by num_workers, so 0 causes a divide-by-zero panic.
The WGPU backend (including Metal on macOS) is internally thread-safe:
worker threads can call Tensor::from_floats(…, &device) safely.
2 workers is a reasonable default; increase on machines with many CPU
cores and fast NVMe storage. Use 1 if you observe data-loading becoming
the training bottleneck (rare with synthetic data).
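The num_workers ≥ 1 constraint can be checked before the data loader is built. The following is a minimal sketch of such a check. Both function names are hypothetical and are not part of the crate or of Burn. items_per_worker only illustrates the division described above, using checked arithmetic so that a zero worker count yields None instead of a panic.

```rust
/// Hypothetical pre-flight check (not part of the documented API):
/// rejects a worker count of zero before any dataset splitting happens.
pub fn checked_num_workers(num_workers: usize) -> Result<usize, String> {
    if num_workers == 0 {
        Err("num_workers must be >= 1 (0 would divide the dataset length by zero)".to_string())
    } else {
        Ok(num_workers)
    }
}

/// Illustration of the per-worker split: dataset length divided by the
/// worker count. `checked_div` returns `None` for a zero divisor, where a
/// plain `/` would panic.
pub fn items_per_worker(dataset_len: usize, num_workers: usize) -> Option<usize> {
    dataset_len.checked_div(num_workers)
}
```

Returning a Result lets the caller report the misconfiguration with a readable message instead of surfacing a panic from inside the loader.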
vram_gb: Option<f64>
Available GPU VRAM in gigabytes.
When set, the pre-flight guard derives the attention-tensor budget as
vram_gb / 3 and auto-caps batch_size to the largest value that
fits, so you never have to tune --batch-size manually.
Memory split used (all figures are estimates):
┌─────────────────────────────────────────────────────┐
│ 1/3 → attention score/weight tensors (one layer) │
│ 1/3 → model weights + gradients + Adam states │
│ 1/3 → non-attention activations + OS/driver slack │
└─────────────────────────────────────────────────────┘
Examples for ViT-B (base), depth=12, H=12, chunk=64, N=2448:
Peak memory is depth × the per-layer figure because Burn’s forward pass
records the autodiff tape for every transformer layer before backward()
starts. 70% of VRAM is budgeted for this tape; the rest covers
model weights, gradients, Adam states, and other activations.
| VRAM | attn budget (×0.7) | max batch | all-layers peak |
|---|---|---|---|
| 8 GB | 5.6 GB | 1 | 6.6 GB |
| 16 GB | 11.2 GB | 1 | 6.6 GB |
| 24 GB | 16.8 GB | 2 | 13.1 GB |
| 32 GB | 22.4 GB | 3 | 19.7 GB |
| 48 GB | 33.6 GB | 5 | 32.8 GB |
| 80 GB | 56.0 GB | 8 | 52.4 GB |
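The table rows can be reproduced with a short calculation. The sketch below follows the table's ×0.7 budget column rather than the vram_gb / 3 figure quoted earlier. It assumes a per-sample all-layers peak of about 6.55 GB, read off the table as 52.4 GB / 8, and it clamps the result to a minimum batch of 1, as the 8 GB row suggests. The constant and function names are hypothetical, and the real guard may compute the per-sample figure from depth, H, chunk, and N instead.

```rust
/// ASSUMPTION: fraction of VRAM budgeted for the attention autodiff tape,
/// taken from the "attn budget (×0.7)" column of the table.
pub const TAPE_BUDGET_FRACTION: f64 = 0.7;

/// ASSUMPTION: all-layers peak per sample in GB for the ViT-B example,
/// read off the table (52.4 GB at batch 8), not derived from first principles.
pub const PER_SAMPLE_PEAK_GB: f64 = 6.55;

/// Attention budget in GB for a given amount of VRAM.
pub fn attention_budget_gb(vram_gb: f64) -> f64 {
    vram_gb * TAPE_BUDGET_FRACTION
}

/// Largest batch whose estimated peak fits the budget, never below 1.
pub fn max_batch_size(vram_gb: f64) -> usize {
    let fit = (attention_budget_gb(vram_gb) / PER_SAMPLE_PEAK_GB).floor() as usize;
    fit.max(1)
}

/// Hypothetical auto-cap: leaves the requested batch size untouched when
/// `vram_gb` is `None`, otherwise limits it to `max_batch_size`.
pub fn capped_batch_size(requested: usize, vram_gb: Option<f64>) -> usize {
    match vram_gb {
        Some(gb) => requested.min(max_batch_size(gb)),
        None => requested,
    }
}
```

Under these assumptions a requested batch size of 8 on a 24 GB card is capped to 2, matching the third row of the table.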
skip_vram_check: bool
Skip the pre-flight VRAM safety check and proceed even if the estimated attention memory exceeds the computed limit.
Use this only when you are certain your GPU has enough free VRAM. You accept full responsibility for OOM errors or GPU driver crashes.
show_summary: bool
Print Burn’s ═══ Learner Summary ═══ table after training.
Disabled by default to keep the terminal output clean.
Pass --summary on the CLI to enable.
Trait Implementations§
impl Clone for TrainingConfig
fn clone(&self) -> TrainingConfig
fn clone_from(&mut self, source: &Self)
impl Debug for TrainingConfig
impl Default for TrainingConfig
impl<'de> Deserialize<'de> for TrainingConfig
fn deserialize<__D>(__deserializer: __D) -> Result<Self, __D::Error>
where
    __D: Deserializer<'de>,
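Because the struct implements Default and Clone, a run configuration can be built by overriding only the fields that differ from the defaults. The sketch below uses a cut-down stand-in type, MiniTrainingConfig, so that it compiles on its own. The stand-in and the function config_for_24gb_gpu are hypothetical. The default values mirror those stated in the field descriptions above, except vram_gb, whose default of None is an assumption.

```rust
/// Hypothetical, cut-down stand-in for `TrainingConfig`; only a few of the
/// documented fields are mirrored so the example is self-contained.
#[derive(Clone, Debug, PartialEq)]
pub struct MiniTrainingConfig {
    pub batch_size: usize,
    pub lr: f64,
    pub weight_decay: f64,
    pub beta2: f64,
    pub grad_clip_norm: f64,
    pub warmup_fraction: f64,
    pub cooldown_fraction: f64,
    pub vram_gb: Option<f64>,
    pub show_summary: bool,
}

impl Default for MiniTrainingConfig {
    fn default() -> Self {
        Self {
            batch_size: 8,          // documented default
            lr: 5e-4,               // documented default
            weight_decay: 1e-4,     // documented default
            beta2: 0.999,           // documented default
            grad_clip_norm: 1.0,    // documented default
            warmup_fraction: 0.2,   // documented default
            cooldown_fraction: 0.2, // documented default
            vram_gb: None,          // ASSUMPTION: unset unless provided
            show_summary: false,    // documented as disabled by default
        }
    }
}

/// Overrides two fields and takes everything else from `Default`
/// via struct update syntax.
pub fn config_for_24gb_gpu() -> MiniTrainingConfig {
    MiniTrainingConfig {
        vram_gb: Some(24.0),
        show_summary: true,
        ..Default::default()
    }
}
```

Struct update syntax keeps call sites short for a 23-field struct and means newly added fields pick up their defaults without touching existing code.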
Auto Trait Implementations§
impl Freeze for TrainingConfig
impl RefUnwindSafe for TrainingConfig
impl Send for TrainingConfig
impl Sync for TrainingConfig
impl Unpin for TrainingConfig
impl UnsafeUnpin for TrainingConfig
impl UnwindSafe for TrainingConfig
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.