Skip to main content

sensorlm/training/
learner.rs

1//! Training loop using the Burn `Learner` abstraction.
2//!
3//! # Workflow
4//!
5//! 1. Build synthetic (or real) [`DataLoader`]s.
6//! 2. Initialise [`SensorLMModel`] and [`AdamConfig`] optimiser.
7//! 3. Construct the [`RsqrtScheduler`] learning-rate schedule.
8//! 4. Call `LearnerBuilder::build(model, optim, scheduler).fit(train, valid)`.
9
10use burn::{
11    data::dataloader::DataLoaderBuilder,
12    module::Module,
13    optim::AdamConfig,
14    record::{CompactRecorder, FullPrecisionSettings, BinFileRecorder},
15    tensor::backend::AutodiffBackend,
16    train::{
17        metric::LossMetric,
18        renderer::{MetricState, MetricsRenderer, TrainingProgress},
19        LearnerBuilder,
20    },
21};
22use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
23use std::path::Path;
24use std::time::Instant;
25
26use crate::config::{SensorLMConfig, TrainingConfig};
27use crate::data::dataset::SyntheticSensorDataset;
28use crate::model::sensorlm::{SensorLMBatcher, SensorLMModel};
29use crate::training::scheduler::RsqrtScheduler;
30use crate::error::Result;
31
32// ===========================================================================
33// Custom MetricsRenderer — drives indicatif progress bars
34// ===========================================================================
35//
36// Burn's default `CliMetricsRenderer` (used when the `tui` feature is off)
37// contains a literal `dbg!(item)` call that dumps structs to stderr on every
38// step.  Replacing it with our own renderer stops that noise and gives us
39// proper per-step progress bars ticked at the right time: AFTER each GPU
40// forward+backward pass completes, not during dataset prefetch.
41
42/// Renders training progress as two `indicatif` progress bars (train + valid).
43struct SensorLMRenderer {
44    _multi:     MultiProgress,   // kept alive so bars stay grouped in the terminal
45    train_bar:  ProgressBar,
46    valid_bar:  ProgressBar,
47    train_loss: Option<f64>,
48    valid_loss: Option<f64>,
49    step_start: Instant,
50}
51
52impl SensorLMRenderer {
53    fn new(train_steps: usize, valid_steps: usize) -> Self {
54        let multi = MultiProgress::new();
55
56        let style = ProgressStyle::with_template(
57            "{prefix:.bold.cyan} [{bar:45.green/dim}] \
58             {pos:>3}/{len} \
59             {elapsed_precise} eta {eta_precise}  \
60             {msg}",
61        )
62        .unwrap()
63        .progress_chars("█▉▊▋▌▍▎▏ ");
64
65        let train_bar = multi.add(ProgressBar::new(train_steps as u64));
66        train_bar.set_style(style.clone());
67        train_bar.set_prefix("Train");
68
69        let valid_bar = multi.add(ProgressBar::new(valid_steps as u64));
70        valid_bar.set_style(style);
71        valid_bar.set_prefix("Valid");
72
73        Self {
74            _multi: multi,
75            train_bar,
76            valid_bar,
77            train_loss: None,
78            valid_loss: None,
79            step_start: Instant::now(),
80        }
81    }
82}
83
84impl MetricsRenderer for SensorLMRenderer {
85    /// Called by Burn whenever a metric value is updated during training.
86    fn update_train(&mut self, state: MetricState) {
87        if let MetricState::Numeric(_entry, val) = state {
88            self.train_loss = Some(val);
89        }
90    }
91
92    /// Called by Burn whenever a metric value is updated during validation.
93    fn update_valid(&mut self, state: MetricState) {
94        if let MetricState::Numeric(_entry, val) = state {
95            self.valid_loss = Some(val);
96        }
97    }
98
99    /// Called once per completed training step (after GPU backward pass).
100    fn render_train(&mut self, item: TrainingProgress) {
101        let step     = item.iteration;
102        let total    = if item.progress.items_total > 0 {
103            // items_total is in samples; divide by batch to get steps
104            let batch = (item.progress.items_processed as f64 / step as f64).round() as u64;
105            (item.progress.items_total as u64).div_ceil(batch.max(1))
106        } else {
107            self.train_bar.length().unwrap_or(0)
108        };
109
110        self.train_bar.set_length(total);
111        self.train_bar.set_position(step as u64);
112
113        let elapsed = self.step_start.elapsed().as_secs_f64();
114        self.step_start = Instant::now();
115
116        let msg = match self.train_loss {
117            Some(l) => format!(
118                "loss {l:.4}  ({elapsed:.1}s/step)  epoch {}/{}",
119                item.epoch, item.epoch_total
120            ),
121            None => format!(
122                "{elapsed:.1}s/step  epoch {}/{}",
123                item.epoch, item.epoch_total
124            ),
125        };
126        self.train_bar.set_message(msg);
127    }
128
129    /// Called once per completed validation step.
130    fn render_valid(&mut self, item: TrainingProgress) {
131        let step  = item.iteration;
132        let total = self.valid_bar.length().unwrap_or(0);
133        self.valid_bar.set_position(step.min(total as usize) as u64);
134
135        let msg = match self.valid_loss {
136            Some(l) => format!("loss {l:.4}"),
137            None    => String::new(),
138        };
139        self.valid_bar.set_message(msg);
140    }
141}
142
/// Memory breakdown for the sensor-encoder attention.
///
/// Every figure is in bytes and counts only the attention score and
/// softmax-weight tensors (`f32`), not weights, optimiser state or other
/// activations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttnMemEstimate {
    /// Upper bound for a single attention-score dispatch:
    /// `B × H × min(chunk, N) × N × 4`.
    per_dispatch_bytes: u64,
    /// Score + weight tensors kept for **one transformer layer**:
    /// `2 × B × H × N × N × 4`. Independent of the chunk size, because the
    /// chunks together always cover exactly `N` query rows.
    per_layer_bwd_bytes: u64,
    /// `depth × per_layer_bwd_bytes`: the figure with every layer's tensors
    /// alive at once. This is the value the pre-flight guard compares
    /// against the attention budget.
    all_layers_bwd_bytes: u64,
}

/// Compute attention-memory estimates for the sensor encoder.
///
/// * `batch_size` (`B`), `num_heads` (`H`), `num_patches` (`N`) and `depth`
///   describe the encoder and batch.
/// * `chunk_size` is the number of query rows per attention dispatch;
///   `0` means full (unchunked) attention.
///
/// ## Formula note
///
/// The implementation slices queries into windows `start..end.min(seq)`, so
/// the final chunk is *at most* `chunk_size` rows and the chunks together
/// cover exactly `N` query rows. The per-layer total is therefore
///
/// ```text
/// per_layer = 2 (scores + attn) × B × H × N × N × 4 bytes
/// ```
///
/// The naive `2 × ceil(N/chunk) × chunk × N` form overcounts whenever `N` is
/// not a multiple of `chunk`, so the exact `N²` form is used instead.
///
/// For the same reason a `chunk_size` larger than `N` behaves like full
/// attention, and the per-dispatch figure clamps the chunk to `N`.
///
/// All products saturate at `u64::MAX` rather than overflowing.
fn estimate_attn_memory(
    batch_size: usize,
    depth: usize,
    num_heads: usize,
    num_patches: usize,
    chunk_size: usize,
) -> AttnMemEstimate {
    const F32_BYTES: u64 = 4;

    let n = num_patches as u64;

    // Query rows handled by one dispatch: the whole sequence when chunking is
    // off, otherwise the chunk — never more than the sequence itself.
    let rows_per_dispatch = if chunk_size == 0 {
        n
    } else {
        (chunk_size as u64).min(n)
    };

    // Bytes for one query row of scores across the batch and all heads:
    // B × H × N × 4.
    let bytes_per_query_row = (batch_size as u64)
        .saturating_mul(num_heads as u64)
        .saturating_mul(n)
        .saturating_mul(F32_BYTES);

    let per_dispatch_bytes = bytes_per_query_row.saturating_mul(rows_per_dispatch);

    // Two tensors (scores + softmax weights), each N query rows.
    let per_layer_bwd_bytes = bytes_per_query_row.saturating_mul(n).saturating_mul(2);

    AttnMemEstimate {
        per_dispatch_bytes,
        per_layer_bwd_bytes,
        all_layers_bwd_bytes: (depth as u64).saturating_mul(per_layer_bwd_bytes),
    }
}
210
/// Attention budget, in GB, applied to the **all-layers** attention estimate
/// when `--vram-gb` is not supplied.
///
/// Chosen by the original author as a safe value for 16 GB GPUs; pass
/// `--vram-gb` to derive a budget from the actual card instead.
const ALL_LAYERS_LIMIT_GB: f64 = 11.0;

/// Share of the VRAM given via `--vram-gb` that is budgeted for attention
/// tensors.
///
/// The remaining 30 % is left for everything else: model weights, gradients,
/// Adam state and non-attention activations.
const ATTN_VRAM_FRACTION: f64 = 0.7;

/// Hard ceiling, in bytes, on a single attention-score dispatch (512 MiB).
///
/// The chunk-size auto-tuner keeps each dispatch at or below this value.
/// Per the original author's note, larger dispatches risk tripping the OS GPU
/// watchdog (TDR); 512 MB is described as generous for Apple Silicon Metal,
/// with roughly 200 MB considered safe on all platforms.
const DISPATCH_LIMIT_BYTES: u64 = 512 << 20;

/// Per-dispatch size, in GB, above which a watchdog (TDR) warning is printed.
///
/// Per the original author's note this suits Apple Silicon; users of discrete
/// GPUs on Linux or Windows should run with `RUST_LOG=warn` and watch for TDR
/// events if step times become long.
const PER_DISPATCH_WARN_GB: f64 = 0.5;
241
242/// Compute the optimal `attn_chunk_size` for a given batch configuration.
243///
244/// Maximises chunk size (minimising GPU command-buffer submissions) while
245/// keeping each attention-score dispatch ≤ [`DISPATCH_LIMIT_BYTES`].
246///
247/// Returns `0` when full attention fits within the dispatch limit (1 submission
248/// per layer instead of `ceil(N / chunk)`).
249fn optimal_chunk_size(batch_size: usize, num_heads: usize, num_patches: usize) -> usize {
250    // Dispatch bytes for one chunk: B × H × chunk × N × sizeof(f32)
251    // Solving for chunk: chunk ≤ LIMIT / (B × H × N × 4)
252    let per_chunk_row = (batch_size as u64)
253        .saturating_mul(num_heads as u64)
254        .saturating_mul(num_patches as u64)
255        .saturating_mul(4);
256    if per_chunk_row == 0 {
257        return 0;
258    }
259    let max_chunk = DISPATCH_LIMIT_BYTES / per_chunk_row;
260    if max_chunk >= num_patches as u64 {
261        0 // full attention fits — no chunking needed
262    } else {
263        // Round down to the nearest multiple of 64 for alignment, minimum 16.
264        let c = (max_chunk as usize / 64) * 64;
265        c.max(16)
266    }
267}
268
/// Compute the largest batch size whose **all-layers** attention estimate
/// fits within `limit_gb`.
///
/// Uses the exact `N²` formula (independent of chunk size):
/// `all_layers = depth × 2 × B × H × N × N × 4`
/// → `B_max = limit_bytes / (depth × 2 × H × N² × 4)`
///
/// Edge cases:
/// * The result is never below `1`, even when a single sample already
///   exceeds the limit; the caller's guard reports that case.
/// * If `depth`, `num_heads` or `num_patches` is zero there is no attention
///   cost and `usize::MAX` is returned.
/// * A negative or NaN `limit_gb` converts to a zero-byte limit, giving `1`.
/// * The per-sample product saturates at `u64::MAX` instead of overflowing.
fn max_safe_batch(depth: usize, num_heads: usize, num_patches: usize, limit_gb: f64) -> usize {
    // Float → integer `as` casts saturate (and map NaN to 0).
    let limit_bytes = (limit_gb * (1u64 << 30) as f64) as u64;

    // Attention bytes contributed by one sample across all layers.
    let per_sample = (depth as u64)
        .saturating_mul(2)
        .saturating_mul(num_heads as u64)
        .saturating_mul(num_patches as u64)
        .saturating_mul(num_patches as u64)
        .saturating_mul(4);
    if per_sample == 0 {
        return usize::MAX;
    }

    let batches = (limit_bytes / per_sample).max(1);
    // Clamp rather than truncate on targets where usize is narrower than u64.
    usize::try_from(batches).unwrap_or(usize::MAX)
}
288
289/// Train a SensorLM model.
290///
291/// Replace [`SyntheticSensorDataset`] with a real [`CsvSensorDataset`] in
292/// production.
293pub fn train<B: AutodiffBackend>(
294    mut model_cfg: SensorLMConfig,
295    mut train_cfg: TrainingConfig,
296) -> Result<()>
297where
298    B::Device: Clone + Default + Send + Sync + std::fmt::Debug + 'static,
299    B::InnerBackend: burn::tensor::backend::Backend<Device = B::Device>,
300{
301    // -----------------------------------------------------------------------
302    // Pre-flight memory guard
303    //
304    // KEY FACT: Burn's forward pass builds autodiff tape for EVERY transformer
305    // layer before loss.backward() executes.  At the forward→backward boundary
306    // ALL `depth` layers' chunk tensors (attention scores + softmax weights)
307    // are simultaneously in GPU memory.  The correct metric is therefore
308    // `all_layers_bwd`, NOT `per_layer_bwd`.
309    //
310    // Memory layout at peak (end of forward pass):
311    //   all_layers_attn  = depth × 2 × ceil(N/chunk) × B × H × chunk × N × 4
312    //   static           ≈ params × 16 bytes  (weights + grad + Adam m1/m2)
313    //   other_activations ≈ depth × B × N × d × 12 bytes  (Q/K/V + MLP; small)
314    //
315    // We budget ATTN_VRAM_FRACTION (70%) of VRAM for attention, leaving 30%
316    // for static + activations.  Without --vram-gb we use ALL_LAYERS_LIMIT_GB
317    // (calibrated for 16 GB GPUs).
318    // -----------------------------------------------------------------------
319    let num_patches = model_cfg.sensor_encoder.num_patches();
320
321    // ---- derive attention budget from VRAM --------------------------------
322    let attn_limit_gb: f64 = match train_cfg.vram_gb {
323        Some(vram) => {
324            let limit = vram * ATTN_VRAM_FRACTION;
325            eprintln!(
326                "[sensorlm] VRAM budget: {vram:.0} GB \
327                 → attention limit: {limit:.2} GB (= VRAM × {ATTN_VRAM_FRACTION})"
328            );
329            limit
330        }
331        None => ALL_LAYERS_LIMIT_GB,
332    };
333
334    // ---- auto-cap batch_size when --vram-gb was given --------------------
335    if train_cfg.vram_gb.is_some() {
336        let safe = max_safe_batch(
337            model_cfg.sensor_encoder.depth,
338            model_cfg.sensor_encoder.num_heads,
339            num_patches,
340            attn_limit_gb,
341        );
342        if train_cfg.batch_size > safe {
343            eprintln!(
344                "[sensorlm] Auto-reducing batch_size {} → {safe} \
345                 (largest that fits in {attn_limit_gb:.2} GB attention budget).",
346                train_cfg.batch_size,
347            );
348            train_cfg.batch_size = safe;
349        } else {
350            eprintln!(
351                "[sensorlm] batch_size={} fits  (max safe for this VRAM: {safe}).",
352                train_cfg.batch_size,
353            );
354        }
355    }
356
357    // ---- auto-tune chunk_size to minimise GPU command-buffer submissions --
358    //
359    // With B=2 and chunk=64 the chunked attention produces
360    // ceil(2448/64)=39 chunks × 3 ops × 12 layers = 1 404 GPU submissions
361    // per forward pass.  WGPU submits each as a separate Metal command buffer;
362    // with a tiny dispatch (14 MB) the GPU idles between submissions causing
363    // "Device::maintain: waiting for submission index N" spam and very slow
364    // throughput.
365    //
366    // After the batch is fixed we pick the LARGEST chunk that keeps each
367    // dispatch ≤ 512 MB (safe on Metal).  At B=2 this is full attention
368    // (0 = no chunking), reducing submissions from 1 404 → 36 per forward.
369    {
370        let new_chunk = optimal_chunk_size(
371            train_cfg.batch_size,
372            model_cfg.sensor_encoder.num_heads,
373            num_patches,
374        );
375        let old_chunk = model_cfg.sensor_encoder.attn_chunk_size;
376        if new_chunk != old_chunk {
377            let old_subs = if old_chunk == 0 { 1 } else { num_patches.div_ceil(old_chunk) };
378            let new_subs = if new_chunk == 0 { 1 } else { num_patches.div_ceil(new_chunk) };
379            eprintln!(
380                "[sensorlm] Auto-tuning attn_chunk_size {old_chunk} → {new_chunk} \
381                 ({old_subs} → {new_subs} GPU submissions/layer, \
382                 dispatch ≤ {} MB).",
383                DISPATCH_LIMIT_BYTES / (1024 * 1024),
384            );
385            model_cfg.sensor_encoder.attn_chunk_size = new_chunk;
386        }
387    }
388    // Re-borrow enc after mutating model_cfg.
389    let enc = &model_cfg.sensor_encoder;
390
391    // ---- compute estimates for the (possibly adjusted) batch_size --------
392    let mem = estimate_attn_memory(
393        train_cfg.batch_size,
394        enc.depth,
395        enc.num_heads,
396        num_patches,
397        enc.attn_chunk_size,
398    );
399    let gb = |b: u64| b as f64 / (1024.0_f64.powi(3));
400    let dispatch_gb   = gb(mem.per_dispatch_bytes);
401    let per_layer_gb  = gb(mem.per_layer_bwd_bytes);
402    let all_layers_gb = gb(mem.all_layers_bwd_bytes);
403
404    eprintln!(
405        "[sensorlm] Sensor encoder: N={num_patches} patches, \
406         depth={}, heads={}, chunk_size={}, batch={}",
407        enc.depth, enc.num_heads, enc.attn_chunk_size, train_cfg.batch_size,
408    );
409    eprintln!("[sensorlm] Attention VRAM (score/weight tensors only; add ~1–2 GB for weights+Adam+activations):");
410    eprintln!("[sensorlm]   per GPU dispatch : {dispatch_gb:.3} GB  (TDR risk if > {PER_DISPATCH_WARN_GB} GB)");
411    eprintln!("[sensorlm]   per layer tape   : {per_layer_gb:.2} GB  × {} layers", enc.depth);
412    eprintln!("[sensorlm]   ALL layers peak  : {all_layers_gb:.2} GB  ← actual training peak  (limit: {attn_limit_gb:.2} GB)");
413
414    // ---- soft TDR warning ------------------------------------------------
415    if dispatch_gb > PER_DISPATCH_WARN_GB {
416        eprintln!(
417            "[sensorlm] ⚠  Per-dispatch ({dispatch_gb:.2} GB) > {PER_DISPATCH_WARN_GB} GB — \
418             GPU watchdog (TDR) risk. Reduce attn_chunk_size (current: {}).",
419            enc.attn_chunk_size,
420        );
421    }
422
423    // ---- hard guard on ALL-layers peak -----------------------------------
424    if all_layers_gb > attn_limit_gb {
425        let safe_batch = max_safe_batch(
426            enc.depth,
427            enc.num_heads,
428            num_patches,
429            attn_limit_gb,
430        );
431        let safe_chunk = (enc.attn_chunk_size / 2).max(16);
432        let vram_hint = if train_cfg.vram_gb.is_none() {
433            "Specify your GPU memory with --vram-gb <GB> to auto-select the \
434             right batch size, or pass --no-vram-check to skip this guard."
435                .to_string()
436        } else {
437            format!("Pass --no-vram-check to proceed despite the estimate, or lower --batch-size to {safe_batch}.")
438        };
439
440        let msg = format!(
441            "All-layers attention peak ({all_layers_gb:.2} GB) exceeds \
442             the budget ({attn_limit_gb:.2} GB).\n\
443             \n\
444             WHY: Burn builds autodiff tape for all {depth} transformer layers \
445             during the forward pass.  At the forward→backward boundary all \
446             {depth} layers' chunk tensors are simultaneously in GPU memory — \
447             the peak is depth × per-layer, not just per-layer.\n\
448             \n\
449             Largest safe batch for this model + VRAM: {safe_batch}\n\
450             \n\
451             Options:\n\
452             • --vram-gb <GB>       tell the tool your GPU — batch auto-selected\n\
453             • --batch-size {safe_batch:<4}      largest batch that fits\n\
454             • --model-size tiny    ~11 M params, much lower attention memory\n\
455             • --model-size small   ~44 M params, moderate memory\n\
456             • attn_chunk_size {safe_chunk}  halving chunk halves per-layer tape\n\
457             • --no-vram-check      bypass guard (crashes are your responsibility)\n\
458             \n\
459             {vram_hint}",
460            depth = enc.depth,
461        );
462
463        if train_cfg.skip_vram_check {
464            eprintln!("[sensorlm] ⚠  Guard exceeded but --no-vram-check set:\n{msg}");
465            eprintln!("[sensorlm] ⚠  Proceeding — monitor GPU memory carefully.");
466        } else {
467            return Err(crate::error::SensorLMError::Other(anyhow::anyhow!("{msg}")));
468        }
469    }
470
471    let device = B::Device::default();
472    let max_seq_len = train_cfg.caption_key.max_tokens();
473
474    // -----------------------------------------------------------------------
475    // Datasets (synthetic – replace with CsvSensorDataset for real data)
476    // -----------------------------------------------------------------------
477    let train_samples = train_cfg.batch_size * 20;
478    let valid_samples = train_cfg.batch_size * 4;
479
480    let train_dataset = SyntheticSensorDataset::new(train_samples, train_cfg.seed, max_seq_len);
481    let valid_dataset = SyntheticSensorDataset::new(valid_samples, train_cfg.seed + 1, max_seq_len);
482
483    // -----------------------------------------------------------------------
484    // Step counts
485    // -----------------------------------------------------------------------
486    let num_workers = train_cfg.num_workers.max(1);
487    let train_steps = train_samples / train_cfg.batch_size;
488    let valid_steps = valid_samples / train_cfg.batch_size;
489
490    eprintln!(
491        "[sensorlm] Training plan: {train_steps} train steps + \
492         {valid_steps} validation steps per epoch  \
493         (dataset: {train_samples} train / {valid_samples} valid samples)"
494    );
495
496    // -----------------------------------------------------------------------
497    // Batchers
498    // -----------------------------------------------------------------------
499    let batcher_train = SensorLMBatcher::<B>::new(
500        device.clone(),
501        model_cfg.sensor_encoder.time_steps,
502        model_cfg.sensor_encoder.num_channels,
503        max_seq_len,
504    );
505    let batcher_valid = SensorLMBatcher::<B::InnerBackend>::new(
506        device.clone(),
507        model_cfg.sensor_encoder.time_steps,
508        model_cfg.sensor_encoder.num_channels,
509        max_seq_len,
510    );
511
512    // Burn's PartialDataset::split divides by num_workers — 0 would panic.
513    let train_loader = DataLoaderBuilder::new(batcher_train)
514        .batch_size(train_cfg.batch_size)
515        .shuffle(train_cfg.seed)
516        .num_workers(num_workers)
517        .build(train_dataset);
518
519    let valid_loader = DataLoaderBuilder::new(batcher_valid)
520        .batch_size(train_cfg.batch_size)
521        .num_workers(num_workers)
522        .build(valid_dataset);
523
524    // -----------------------------------------------------------------------
525    // Model and optimiser
526    // -----------------------------------------------------------------------
527    let model = SensorLMModel::<B>::new(&model_cfg, &device);
528
529    let optimizer = AdamConfig::new()
530        .with_beta_1(train_cfg.beta1 as f32)
531        .with_beta_2(train_cfg.beta2 as f32)
532        .with_epsilon(train_cfg.epsilon as f32)
533        .with_weight_decay(Some(burn::optim::decay::WeightDecayConfig::new(
534            train_cfg.weight_decay, // f64 penalty
535        )))
536        .init();
537
538    // -----------------------------------------------------------------------
539    // LR scheduler (rsqrt with warm-up and cool-down)
540    // -----------------------------------------------------------------------
541    let lr_scheduler = RsqrtScheduler::new(
542        train_cfg.lr,
543        train_cfg.total_steps,
544        train_cfg.warmup_fraction,
545        train_cfg.cooldown_fraction,
546    );
547
548    // -----------------------------------------------------------------------
549    // Learner
550    // -----------------------------------------------------------------------
551    std::fs::create_dir_all(&train_cfg.artifact_dir)?;
552
553    // SensorLMRenderer replaces Burn's default CliMetricsRenderer which
554    // contains `dbg!(item)` calls that dump raw structs on every step.
555    let renderer = SensorLMRenderer::new(train_steps, valid_steps);
556
557    let builder = LearnerBuilder::new(&train_cfg.artifact_dir)
558        .metric_train_numeric(LossMetric::<B>::new())
559        .metric_valid_numeric(LossMetric::<B::InnerBackend>::new())
560        .with_file_checkpointer(CompactRecorder::new())
561        .renderer(renderer)
562        .devices(vec![device])
563        .num_epochs(1);
564
565    let builder = if train_cfg.show_summary { builder.summary() } else { builder };
566
567    let learner = builder.build(model, optimizer, lr_scheduler);
568
569    let _trained_model = learner.fit(train_loader, valid_loader);
570
571    eprintln!(
572        "\n[sensorlm] Training complete — \
573         {train_steps} train + {valid_steps} valid steps."
574    );
575    Ok(())
576}
577
578/// Save a trained model to disk using full-precision binary format.
579pub fn save_model<B: AutodiffBackend>(
580    model: SensorLMModel<B>,
581    path: &Path,
582) -> Result<()> {
583    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
584    model
585        .save_file(path, &recorder)
586        .map_err(|e| crate::error::SensorLMError::Other(anyhow::anyhow!("{e}")))?;
587    Ok(())
588}
589
590/// Load a model from a checkpoint saved with [`save_model`].
591pub fn load_model<B: AutodiffBackend>(
592    cfg: &SensorLMConfig,
593    path: &Path,
594    device: &B::Device,
595) -> Result<SensorLMModel<B>> {
596    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
597    let model = SensorLMModel::<B>::new(cfg, device)
598        .load_file(path, &recorder, device)
599        .map_err(|e| crate::error::SensorLMError::Other(anyhow::anyhow!("{e}")))?;
600    Ok(model)
601}