Struct TrainingConfig 

pub struct TrainingConfig {
    pub model_size: ModelSize,
    pub total_steps: usize,
    pub batch_size: usize,
    pub lr: f64,
    pub weight_decay: f64,
    pub beta2: f64,
    pub beta1: f64,
    pub epsilon: f64,
    pub grad_clip_norm: f64,
    pub warmup_fraction: f64,
    pub cooldown_fraction: f64,
    pub lr_schedule: LrScheduleType,
    pub checkpoint_every: usize,
    pub log_every: usize,
    pub seed: u64,
    pub caption_key: CaptionKey,
    pub tokenizer_path: String,
    pub artifact_dir: String,
    pub data_dir: String,
    pub num_workers: usize,
    pub vram_gb: Option<f64>,
    pub skip_vram_check: bool,
    pub show_summary: bool,
}

Training hyperparameters.

Fields

model_size: ModelSize

Model size preset (tiny / small / base).

Overrides SensorLMConfig when passed through the CLI. Building a config from a preset is the recommended way to avoid mismatched d_model / embed_dim values between the two towers.

total_steps: usize

Total number of gradient update steps.

batch_size: usize

Mini-batch size (default: 8).

lr: f64

Peak learning rate (default: 5 × 10⁻⁴).

weight_decay: f64

AdamW weight decay (default: 1 × 10⁻⁴).

beta2: f64

Adam β₂ (default: 0.999, matching b2=0.999 in the reference implementation's scale_by_adam).

beta1: f64

Adam β₁.

epsilon: f64

Adam ε.
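The fields lr, beta1, beta2, epsilon and weight_decay are the hyperparameters of AdamW. To make the role of each one concrete, here is a minimal sketch of the textbook AdamW update with decoupled weight decay on a flat f64 slice. It is not Burn's optimizer code: AdamWHyper and AdamWState are hypothetical names, and Burn may differ in details such as exactly where epsilon enters the denominator.

```rust
/// Hypothetical bundle of the optimizer fields from `TrainingConfig`.
#[derive(Clone, Copy, Debug)]
pub struct AdamWHyper {
    pub lr: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub epsilon: f64,
    pub weight_decay: f64,
}

/// Hypothetical per-tensor AdamW state.
pub struct AdamWState {
    m: Vec<f64>, // exponential moving average of gradients
    v: Vec<f64>, // exponential moving average of squared gradients
    t: i32,      // number of updates applied so far
}

impl AdamWState {
    pub fn new(len: usize) -> Self {
        Self { m: vec![0.0; len], v: vec![0.0; len], t: 0 }
    }

    /// Applies one AdamW update to `params` in place.
    pub fn step(&mut self, params: &mut [f64], grads: &[f64], h: AdamWHyper) {
        assert_eq!(params.len(), grads.len(), "params/grads length mismatch");
        assert_eq!(params.len(), self.m.len(), "state was built for a different size");
        self.t += 1;
        // Bias corrections compensate for m and v starting at zero.
        let bc1 = 1.0 - h.beta1.powi(self.t);
        let bc2 = 1.0 - h.beta2.powi(self.t);
        for i in 0..params.len() {
            let g = grads[i];
            self.m[i] = h.beta1 * self.m[i] + (1.0 - h.beta1) * g;
            self.v[i] = h.beta2 * self.v[i] + (1.0 - h.beta2) * g * g;
            let m_hat = self.m[i] / bc1;
            let v_hat = self.v[i] / bc2;
            // Decoupled weight decay: added to the step, not to the gradient,
            // so it is not rescaled by the adaptive denominator.
            params[i] -= h.lr * (m_hat / (v_hat.sqrt() + h.epsilon) + h.weight_decay * params[i]);
        }
    }
}
```

On the first step the bias-corrected moments equal g and g², so each parameter moves by roughly lr in the direction opposite to its gradient, regardless of the gradient's magnitude.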

grad_clip_norm: f64

Gradient clip norm (default: 1.0).
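A minimal sketch of clipping by global L2 norm with grad_clip_norm as the threshold, in the style of optax's clip_by_global_norm. This page does not say whether the crate clips the global norm or each gradient tensor separately, so the global variant and the name clip_by_global_norm are assumptions of the sketch.

```rust
/// Rescales every gradient tensor by the same factor so that their combined
/// L2 norm is at most `max_norm`. Returns the norm measured before clipping.
pub fn clip_by_global_norm(grads: &mut [Vec<f64>], max_norm: f64) -> f64 {
    // Global norm: sqrt of the sum of squares over all elements of all tensors.
    let norm = grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|x| x * x)
        .sum::<f64>()
        .sqrt();
    if norm > max_norm {
        // One shared scale preserves the direction of the full gradient.
        let scale = max_norm / norm;
        for g in grads.iter_mut() {
            for x in g.iter_mut() {
                *x *= scale;
            }
        }
    }
    norm
}
```

With the default threshold of 1.0, gradients whose global norm is already at or below 1.0 pass through unchanged.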

warmup_fraction: f64

Fraction of total steps used for linear warm-up (default: 0.2).

cooldown_fraction: f64

Fraction of total steps used for cool-down (default: 0.2).

lr_schedule: LrScheduleType

LR schedule type.
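The two fraction fields become step counts once multiplied by total_steps. The sketch below assumes one common shape (linear warm-up, a constant plateau at the peak lr, then a linear cool-down towards zero); the actual LrScheduleType variants are not listed on this page and may use, for example, a cosine decay instead. lr_at is a hypothetical helper.

```rust
/// Learning rate at 0-based `step` for a hypothetical schedule: linear
/// warm-up, constant plateau at `peak_lr`, then linear cool-down.
pub fn lr_at(
    step: usize,
    total_steps: usize,
    peak_lr: f64,
    warmup_fraction: f64,
    cooldown_fraction: f64,
) -> f64 {
    let warmup_steps = (warmup_fraction * total_steps as f64).round() as usize;
    let cooldown_steps = (cooldown_fraction * total_steps as f64).round() as usize;
    let cooldown_start = total_steps.saturating_sub(cooldown_steps);

    if step < warmup_steps {
        // Ramp from peak_lr / warmup_steps up to peak_lr on the last warm-up step.
        peak_lr * (step + 1) as f64 / warmup_steps as f64
    } else if cooldown_steps > 0 && step >= cooldown_start {
        // Decay linearly; the final step gets peak_lr / cooldown_steps.
        let remaining = total_steps.saturating_sub(step);
        peak_lr * remaining as f64 / cooldown_steps as f64
    } else {
        peak_lr
    }
}
```

With the documented defaults (0.2 and 0.2) and total_steps = 100, this sketch ramps up over steps 0 to 19, holds the peak for steps 20 to 79, and decays over steps 80 to 99.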

checkpoint_every: usize

Save a checkpoint every N steps.

log_every: usize

Log metrics every N steps.
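checkpoint_every and log_every are usually checked with a modulo test on the step counter. A minimal sketch, assuming 1-based step numbers; treating 0 as "never" is this sketch's own choice (the page does not document it), and it also avoids the panic that step % 0 raises in Rust. is_due and periodic_actions are hypothetical helpers.

```rust
/// Returns true when periodic work with cadence `every` is due after update
/// number `step` (1-based). `every == 0` is treated as "never" in this sketch.
pub fn is_due(step: usize, every: usize) -> bool {
    // `&&` short-circuits, so `step % every` never runs with `every == 0`.
    every != 0 && step % every == 0
}

/// Hypothetical per-step decision for the two cadence fields:
/// returns (log now, checkpoint now).
pub fn periodic_actions(step: usize, log_every: usize, checkpoint_every: usize) -> (bool, bool) {
    (is_due(step, log_every), is_due(step, checkpoint_every))
}
```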

seed: u64

Random seed.

caption_key: CaptionKey

Caption type key to use during this training run.

tokenizer_path: String

Path to the SentencePiece tokeniser model file.

artifact_dir: String

Directory where checkpoints and logs are written.

data_dir: String

Directory containing the dataset (Parquet or raw files).

num_workers: usize

Number of DataLoader worker threads for CPU-side data preparation.

Must be ≥ 1 — Burn’s PartialDataset::split divides the dataset length by num_workers, so 0 causes a divide-by-zero panic.

The WGPU backend (including Metal on macOS) is internally thread-safe: worker threads can call Tensor::from_floats(…, &device) safely. Two workers is a reasonable default; increase the count on machines with many CPU cores and fast NVMe storage. Use 1 if you observe data loading becoming the training bottleneck (rare with synthetic data).
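To see why num_workers = 0 fails, here is a sketch of a per-worker split that, like the behaviour described for PartialDataset::split, computes each worker's share as len / num_workers. It is not Burn's code: split_for_workers is a hypothetical helper, it validates its input instead of panicking, and handing the remainder to the last worker is an assumption of this sketch.

```rust
use std::ops::Range;

/// Splits `len` items into `num_workers` contiguous index ranges. Each worker's
/// share is `len / num_workers`, so 0 workers is rejected up front instead of
/// reaching that division (which would panic).
pub fn split_for_workers(len: usize, num_workers: usize) -> Result<Vec<Range<usize>>, String> {
    if num_workers == 0 {
        return Err("num_workers must be >= 1".to_string());
    }
    let per_worker = len / num_workers;
    let ranges = (0..num_workers)
        .map(|i| {
            let start = i * per_worker;
            // This sketch gives any remainder items to the last worker.
            let end = if i + 1 == num_workers { len } else { start + per_worker };
            start..end
        })
        .collect();
    Ok(ranges)
}
```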

vram_gb: Option<f64>

Available GPU VRAM in gigabytes.

When set, the pre-flight guard derives the attention-tensor budget as vram_gb × 0.7 and auto-caps batch_size to the largest value that fits, so you never have to tune --batch-size manually.

Memory split used (all figures are estimates):

┌────────────────────────────────────────────────────────┐
│  70% → attention score/weight tensors (all layers)     │
│  30% → model weights + gradients + Adam states +       │
│        non-attention activations + OS/driver slack     │
└────────────────────────────────────────────────────────┘

The estimated peak is depth × the per-layer attention memory, because Burn’s forward pass builds the autodiff tape for all transformer layers before backward() starts. 70% of VRAM is budgeted for this tape; the remaining 30% covers model weights, gradients, Adam states, and other activations.

Examples for ViT-B (base), depth=12, H=12, chunk=64, N=2448:

| VRAM  | Attention budget (×0.7) | Max batch | All-layers peak |
|-------|-------------------------|-----------|-----------------|
| 8 GB  | 5.6 GB                  | 1         | 6.6 GB          |
| 16 GB | 11.2 GB                 | 1         | 6.6 GB          |
| 24 GB | 16.8 GB                 | 2         | 13.1 GB         |
| 32 GB | 22.4 GB                 | 3         | 19.7 GB         |
| 48 GB | 33.6 GB                 | 5         | 32.8 GB         |
| 80 GB | 56.0 GB                 | 8         | 52.4 GB         |

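The max-batch column of the table above follows a simple rule: budget = vram_gb × 0.7, max batch = floor(budget / per-sample peak), never below 1. In this sketch the per-sample all-layers peak of about 6.55 GB is back-derived from the table's last column for this ViT-B example rather than taken from the guard's real estimator, and max_batch, auto_cap_batch_size and all_layers_peak_gb are hypothetical names.

```rust
/// Fraction of VRAM budgeted for the all-layers attention tape (the ×0.7 above).
pub const ATTN_BUDGET_FRACTION: f64 = 0.7;

/// Approximate all-layers attention peak per sample for the ViT-B example,
/// back-derived from the table (6.6 / 13.1 / 19.7 / 32.8 / 52.4 GB for
/// batch 1 / 2 / 3 / 5 / 8). An assumption, not the guard's real estimator.
pub const VIT_B_PEAK_GB_PER_SAMPLE: f64 = 6.55;

/// Largest batch whose estimated peak fits the attention budget, clamped to
/// at least 1 (the 8 GB row reports 1 even though 6.6 GB > 5.6 GB).
pub fn max_batch(vram_gb: f64, peak_gb_per_sample: f64) -> usize {
    let budget_gb = vram_gb * ATTN_BUDGET_FRACTION;
    let fits = (budget_gb / peak_gb_per_sample).floor() as usize;
    fits.max(1)
}

/// Caps a requested batch size to what the budget allows.
pub fn auto_cap_batch_size(requested: usize, vram_gb: f64, peak_gb_per_sample: f64) -> usize {
    requested.min(max_batch(vram_gb, peak_gb_per_sample))
}

/// Estimated all-layers attention peak (GB) for `batch_size` samples.
pub fn all_layers_peak_gb(batch_size: usize, peak_gb_per_sample: f64) -> f64 {
    batch_size as f64 * peak_gb_per_sample
}
```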
skip_vram_check: bool

Skip the pre-flight VRAM safety check and proceed even if the estimated attention memory exceeds the computed limit.

Use this only when you are certain your GPU has enough free VRAM. You accept full responsibility for OOM errors or GPU driver crashes.
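A minimal sketch of the decision that skip_vram_check controls, assuming the guard compares one estimated figure against one limit. preflight and Preflight are hypothetical names; the real guard's estimator and error message are not shown on this page.

```rust
/// Hypothetical outcome of the pre-flight VRAM check.
#[derive(Debug, PartialEq, Eq)]
pub enum Preflight {
    /// The estimate fits within the computed limit.
    WithinLimit,
    /// The estimate exceeds the limit, but `skip_vram_check` was set.
    OverLimitSkipped,
}

/// Refuses to start when the estimated attention memory exceeds the limit,
/// unless the user has opted out with `skip_vram_check`.
pub fn preflight(estimated_gb: f64, limit_gb: f64, skip_vram_check: bool) -> Result<Preflight, String> {
    if estimated_gb <= limit_gb {
        Ok(Preflight::WithinLimit)
    } else if skip_vram_check {
        // Proceed anyway: OOM errors or driver crashes are now the user's risk.
        Ok(Preflight::OverLimitSkipped)
    } else {
        Err(format!(
            "estimated attention memory {estimated_gb:.1} GB exceeds the {limit_gb:.1} GB limit"
        ))
    }
}
```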

show_summary: bool

Print Burn’s ═══ Learner Summary ═══ table after training.

Disabled by default to keep the terminal output clean. Pass --summary on the CLI to enable.

Trait Implementations

impl Clone for TrainingConfig

fn clone(&self) -> TrainingConfig
Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.

impl Debug for TrainingConfig

fn fmt(&self, f: &mut Formatter<'_>) -> Result
Formats the value using the given formatter.

impl Default for TrainingConfig

fn default() -> Self
Returns the “default value” for a type.

impl<'de> Deserialize<'de> for TrainingConfig

fn deserialize<__D>(__deserializer: __D) -> Result<Self, __D::Error>
where __D: Deserializer<'de>,
Deserialize this value from the given Serde deserializer.

impl Serialize for TrainingConfig

fn serialize<__S>(&self, __serializer: __S) -> Result<__S::Ok, __S::Error>
where __S: Serializer,
Serialize this value into the given Serde serializer.
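Because every field is pub and TrainingConfig implements Default and Clone, a config can be built by overriding only a few fields with struct update syntax, e.g. TrainingConfig { batch_size: 16, ..Default::default() }. The self-contained sketch below demonstrates the pattern on MiniTrainingConfig, a hypothetical trimmed stand-in that keeps fields whose defaults this page documents, plus vram_gb and skip_vram_check, whose None / false defaults are assumptions.

```rust
/// Hypothetical, trimmed stand-in for `TrainingConfig` (a few fields only).
#[derive(Clone, Debug, PartialEq)]
pub struct MiniTrainingConfig {
    pub batch_size: usize,
    pub lr: f64,
    pub weight_decay: f64,
    pub beta2: f64,
    pub grad_clip_norm: f64,
    pub warmup_fraction: f64,
    pub cooldown_fraction: f64,
    pub vram_gb: Option<f64>,
    pub skip_vram_check: bool,
    pub show_summary: bool,
}

impl Default for MiniTrainingConfig {
    fn default() -> Self {
        Self {
            // Values documented on this page.
            batch_size: 8,
            lr: 5e-4,
            weight_decay: 1e-4,
            beta2: 0.999,
            grad_clip_norm: 1.0,
            warmup_fraction: 0.2,
            cooldown_fraction: 0.2,
            show_summary: false,
            // Assumed defaults (not stated on this page).
            vram_gb: None,
            skip_vram_check: false,
        }
    }
}

/// Overrides two fields and keeps every other default.
pub fn example_config() -> MiniTrainingConfig {
    MiniTrainingConfig {
        batch_size: 16,
        vram_gb: Some(24.0),
        ..Default::default()
    }
}
```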

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId
Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T
Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)
🔬 This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T
Returns the argument unchanged.

impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>
Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>
Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U
Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize
The alignment of pointer.

type Init = T
The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize
Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T
Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T
Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)
Drops the object pointed to by the given pointer.

impl<T> PolicyExt for T
where T: ?Sized,

fn and<P, B, E>(self, other: P) -> And<T, P>
where T: Policy<B, E>, P: Policy<B, E>,
Create a new Policy that returns Action::Follow only if self and other return Action::Follow.

fn or<P, B, E>(self, other: P) -> Or<T, P>
where T: Policy<B, E>, P: Policy<B, E>,
Create a new Policy that returns Action::Follow if either self or other returns Action::Follow.

impl<T> Same for T

type Output = T
Should always be Self

impl<T> ToOwned for T
where T: Clone,

type Owned = T
The resulting type after obtaining ownership.

fn to_owned(&self) -> T
Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)
Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible
The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>
Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error
The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>
Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V

impl<T> WithSubscriber for T

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,
Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>
Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.

impl<T> DeserializeOwned for T
where T: for<'de> Deserialize<'de>,