1use burn::{
11 data::dataloader::DataLoaderBuilder,
12 module::Module,
13 optim::AdamConfig,
14 record::{CompactRecorder, FullPrecisionSettings, BinFileRecorder},
15 tensor::backend::AutodiffBackend,
16 train::{
17 metric::LossMetric,
18 renderer::{MetricState, MetricsRenderer, TrainingProgress},
19 LearnerBuilder,
20 },
21};
22use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
23use std::path::Path;
24use std::time::Instant;
25
26use crate::config::{SensorLMConfig, TrainingConfig};
27use crate::data::dataset::SyntheticSensorDataset;
28use crate::model::sensorlm::{SensorLMBatcher, SensorLMModel};
29use crate::training::scheduler::RsqrtScheduler;
30use crate::error::Result;
31
/// Terminal renderer for burn's training loop: two stacked `indicatif`
/// progress bars (train / valid) plus the most recent numeric metric value
/// reported for each phase.
///
/// Field order matters: fields are dropped in declaration order, so the
/// `MultiProgress` the bars are registered on is dropped before the bars.
struct SensorLMRenderer {
    /// Draw target both bars are registered on (see `new`); only kept alive,
    /// never read directly.
    _multi: MultiProgress,
    /// Bar tracking training iterations.
    train_bar: ProgressBar,
    /// Bar tracking validation iterations.
    valid_bar: ProgressBar,
    /// Latest numeric training metric, `None` until one has been reported.
    train_loss: Option<f64>,
    /// Latest numeric validation metric, `None` until one has been reported.
    valid_loss: Option<f64>,
    /// Reference instant for the "s/step" readout shown by `render_train`.
    step_start: Instant,
}
51
52impl SensorLMRenderer {
53 fn new(train_steps: usize, valid_steps: usize) -> Self {
54 let multi = MultiProgress::new();
55
56 let style = ProgressStyle::with_template(
57 "{prefix:.bold.cyan} [{bar:45.green/dim}] \
58 {pos:>3}/{len} \
59 {elapsed_precise} eta {eta_precise} \
60 {msg}",
61 )
62 .unwrap()
63 .progress_chars("█▉▊▋▌▍▎▏ ");
64
65 let train_bar = multi.add(ProgressBar::new(train_steps as u64));
66 train_bar.set_style(style.clone());
67 train_bar.set_prefix("Train");
68
69 let valid_bar = multi.add(ProgressBar::new(valid_steps as u64));
70 valid_bar.set_style(style);
71 valid_bar.set_prefix("Valid");
72
73 Self {
74 _multi: multi,
75 train_bar,
76 valid_bar,
77 train_loss: None,
78 valid_loss: None,
79 step_start: Instant::now(),
80 }
81 }
82}
83
84impl MetricsRenderer for SensorLMRenderer {
85 fn update_train(&mut self, state: MetricState) {
87 if let MetricState::Numeric(_entry, val) = state {
88 self.train_loss = Some(val);
89 }
90 }
91
92 fn update_valid(&mut self, state: MetricState) {
94 if let MetricState::Numeric(_entry, val) = state {
95 self.valid_loss = Some(val);
96 }
97 }
98
99 fn render_train(&mut self, item: TrainingProgress) {
101 let step = item.iteration;
102 let total = if item.progress.items_total > 0 {
103 let batch = (item.progress.items_processed as f64 / step as f64).round() as u64;
105 (item.progress.items_total as u64).div_ceil(batch.max(1))
106 } else {
107 self.train_bar.length().unwrap_or(0)
108 };
109
110 self.train_bar.set_length(total);
111 self.train_bar.set_position(step as u64);
112
113 let elapsed = self.step_start.elapsed().as_secs_f64();
114 self.step_start = Instant::now();
115
116 let msg = match self.train_loss {
117 Some(l) => format!(
118 "loss {l:.4} ({elapsed:.1}s/step) epoch {}/{}",
119 item.epoch, item.epoch_total
120 ),
121 None => format!(
122 "{elapsed:.1}s/step epoch {}/{}",
123 item.epoch, item.epoch_total
124 ),
125 };
126 self.train_bar.set_message(msg);
127 }
128
129 fn render_valid(&mut self, item: TrainingProgress) {
131 let step = item.iteration;
132 let total = self.valid_bar.length().unwrap_or(0);
133 self.valid_bar.set_position(step.min(total as usize) as u64);
134
135 let msg = match self.valid_loss {
136 Some(l) => format!("loss {l:.4}"),
137 None => String::new(),
138 };
139 self.valid_bar.set_message(msg);
140 }
141}
142
/// Estimated attention memory, in bytes, for one training configuration.
///
/// Only the attention score/weight tensors are counted. Model weights,
/// optimizer state and other activations are not included. Sizes assume
/// 4 bytes per element (f32 — confirm against the model's dtype).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttnMemEstimate {
    /// One attention dispatch: `batch × heads × chunk_rows × num_patches × 4`.
    per_dispatch_bytes: u64,
    /// One layer's backward tensors: `2 × batch × heads × num_patches² × 4`.
    per_layer_bwd_bytes: u64,
    /// `depth × per_layer_bwd_bytes`: every layer's backward tensors at once.
    all_layers_bwd_bytes: u64,
}

/// Estimate attention memory for the given model/batch dimensions.
///
/// Chunking:
/// - `chunk_size == 0` means no chunking: one dispatch covers all
///   `num_patches` query rows.
/// - A chunk larger than the sequence is clamped to `num_patches`, because a
///   dispatch can never cover more rows than exist.
///
/// All products saturate at `u64::MAX` instead of overflowing, so absurd
/// configurations read as "too large" rather than panicking or wrapping.
fn estimate_attn_memory(
    batch_size: usize,
    depth: usize,
    num_heads: usize,
    num_patches: usize,
    chunk_size: usize,
) -> AttnMemEstimate {
    let effective_chunk = if chunk_size == 0 {
        num_patches
    } else {
        chunk_size.min(num_patches)
    };

    let batch = batch_size as u64;
    let heads = num_heads as u64;
    let n = num_patches as u64;
    let rows = effective_chunk as u64;

    let per_dispatch_bytes = batch
        .saturating_mul(heads)
        .saturating_mul(rows)
        .saturating_mul(n)
        .saturating_mul(4);

    // Factor 2: two N×N tensors per layer are kept for backward (presumably
    // the pre- and post-softmax scores — confirm against the attention impl).
    let per_layer_bwd_bytes = 2u64
        .saturating_mul(batch)
        .saturating_mul(heads)
        .saturating_mul(n)
        .saturating_mul(n)
        .saturating_mul(4);

    AttnMemEstimate {
        per_dispatch_bytes,
        per_layer_bwd_bytes,
        all_layers_bwd_bytes: (depth as u64).saturating_mul(per_layer_bwd_bytes),
    }
}
210
/// Attention-memory budget (GB) used by `train` when no VRAM size is given.
const ALL_LAYERS_LIMIT_GB: f64 = 11.0;

/// Share of the declared VRAM that `train` allots to attention tensors
/// (budget = VRAM × this fraction).
const ATTN_VRAM_FRACTION: f64 = 0.70;
226
/// Ceiling (512 MiB) for the score tensor of a single attention dispatch;
/// `optimal_chunk_size` picks chunks that keep each dispatch under it.
const DISPATCH_LIMIT_BYTES: u64 = 1 << 29;

/// Per-dispatch size (GB) above which `train` warns about GPU watchdog
/// (TDR) risk.
const PER_DISPATCH_WARN_GB: f64 = 0.5;
241
242fn optimal_chunk_size(batch_size: usize, num_heads: usize, num_patches: usize) -> usize {
250 let per_chunk_row = (batch_size as u64)
253 .saturating_mul(num_heads as u64)
254 .saturating_mul(num_patches as u64)
255 .saturating_mul(4);
256 if per_chunk_row == 0 {
257 return 0;
258 }
259 let max_chunk = DISPATCH_LIMIT_BYTES / per_chunk_row;
260 if max_chunk >= num_patches as u64 {
261 0 } else {
263 let c = (max_chunk as usize / 64) * 64;
265 c.max(16)
266 }
267}
268
/// Largest batch size whose all-layers attention backward tensors
/// (`depth × 2 × heads × num_patches² × 4` bytes per sample) fit in `limit_gb`.
///
/// Edge cases:
/// - Always returns at least 1, even when a single sample exceeds the limit.
///   Callers are expected to run their own over-budget check.
/// - Returns `usize::MAX` when the per-sample cost is zero (e.g. `depth == 0`).
/// - Negative or NaN limits are treated as 0 bytes.
/// - The per-sample product saturates instead of overflowing.
fn max_safe_batch(depth: usize, num_heads: usize, num_patches: usize, limit_gb: f64) -> usize {
    // Float→int `as` casts saturate (NaN/negative → 0), which is what we want.
    let limit_bytes = (limit_gb * (1u64 << 30) as f64) as u64;

    let n = num_patches as u64;
    let per_sample = (depth as u64)
        .saturating_mul(2)
        .saturating_mul(num_heads as u64)
        .saturating_mul(n)
        .saturating_mul(n)
        .saturating_mul(4);

    if per_sample == 0 {
        return usize::MAX;
    }

    let batches = (limit_bytes / per_sample).max(1);
    // Avoid silent truncation on targets where usize is narrower than u64.
    usize::try_from(batches).unwrap_or(usize::MAX)
}
288
289pub fn train<B: AutodiffBackend>(
294 mut model_cfg: SensorLMConfig,
295 mut train_cfg: TrainingConfig,
296) -> Result<()>
297where
298 B::Device: Clone + Default + Send + Sync + std::fmt::Debug + 'static,
299 B::InnerBackend: burn::tensor::backend::Backend<Device = B::Device>,
300{
301 let num_patches = model_cfg.sensor_encoder.num_patches();
320
321 let attn_limit_gb: f64 = match train_cfg.vram_gb {
323 Some(vram) => {
324 let limit = vram * ATTN_VRAM_FRACTION;
325 eprintln!(
326 "[sensorlm] VRAM budget: {vram:.0} GB \
327 → attention limit: {limit:.2} GB (= VRAM × {ATTN_VRAM_FRACTION})"
328 );
329 limit
330 }
331 None => ALL_LAYERS_LIMIT_GB,
332 };
333
334 if train_cfg.vram_gb.is_some() {
336 let safe = max_safe_batch(
337 model_cfg.sensor_encoder.depth,
338 model_cfg.sensor_encoder.num_heads,
339 num_patches,
340 attn_limit_gb,
341 );
342 if train_cfg.batch_size > safe {
343 eprintln!(
344 "[sensorlm] Auto-reducing batch_size {} → {safe} \
345 (largest that fits in {attn_limit_gb:.2} GB attention budget).",
346 train_cfg.batch_size,
347 );
348 train_cfg.batch_size = safe;
349 } else {
350 eprintln!(
351 "[sensorlm] batch_size={} fits (max safe for this VRAM: {safe}).",
352 train_cfg.batch_size,
353 );
354 }
355 }
356
357 {
370 let new_chunk = optimal_chunk_size(
371 train_cfg.batch_size,
372 model_cfg.sensor_encoder.num_heads,
373 num_patches,
374 );
375 let old_chunk = model_cfg.sensor_encoder.attn_chunk_size;
376 if new_chunk != old_chunk {
377 let old_subs = if old_chunk == 0 { 1 } else { num_patches.div_ceil(old_chunk) };
378 let new_subs = if new_chunk == 0 { 1 } else { num_patches.div_ceil(new_chunk) };
379 eprintln!(
380 "[sensorlm] Auto-tuning attn_chunk_size {old_chunk} → {new_chunk} \
381 ({old_subs} → {new_subs} GPU submissions/layer, \
382 dispatch ≤ {} MB).",
383 DISPATCH_LIMIT_BYTES / (1024 * 1024),
384 );
385 model_cfg.sensor_encoder.attn_chunk_size = new_chunk;
386 }
387 }
388 let enc = &model_cfg.sensor_encoder;
390
391 let mem = estimate_attn_memory(
393 train_cfg.batch_size,
394 enc.depth,
395 enc.num_heads,
396 num_patches,
397 enc.attn_chunk_size,
398 );
399 let gb = |b: u64| b as f64 / (1024.0_f64.powi(3));
400 let dispatch_gb = gb(mem.per_dispatch_bytes);
401 let per_layer_gb = gb(mem.per_layer_bwd_bytes);
402 let all_layers_gb = gb(mem.all_layers_bwd_bytes);
403
404 eprintln!(
405 "[sensorlm] Sensor encoder: N={num_patches} patches, \
406 depth={}, heads={}, chunk_size={}, batch={}",
407 enc.depth, enc.num_heads, enc.attn_chunk_size, train_cfg.batch_size,
408 );
409 eprintln!("[sensorlm] Attention VRAM (score/weight tensors only; add ~1–2 GB for weights+Adam+activations):");
410 eprintln!("[sensorlm] per GPU dispatch : {dispatch_gb:.3} GB (TDR risk if > {PER_DISPATCH_WARN_GB} GB)");
411 eprintln!("[sensorlm] per layer tape : {per_layer_gb:.2} GB × {} layers", enc.depth);
412 eprintln!("[sensorlm] ALL layers peak : {all_layers_gb:.2} GB ← actual training peak (limit: {attn_limit_gb:.2} GB)");
413
414 if dispatch_gb > PER_DISPATCH_WARN_GB {
416 eprintln!(
417 "[sensorlm] ⚠ Per-dispatch ({dispatch_gb:.2} GB) > {PER_DISPATCH_WARN_GB} GB — \
418 GPU watchdog (TDR) risk. Reduce attn_chunk_size (current: {}).",
419 enc.attn_chunk_size,
420 );
421 }
422
423 if all_layers_gb > attn_limit_gb {
425 let safe_batch = max_safe_batch(
426 enc.depth,
427 enc.num_heads,
428 num_patches,
429 attn_limit_gb,
430 );
431 let safe_chunk = (enc.attn_chunk_size / 2).max(16);
432 let vram_hint = if train_cfg.vram_gb.is_none() {
433 "Specify your GPU memory with --vram-gb <GB> to auto-select the \
434 right batch size, or pass --no-vram-check to skip this guard."
435 .to_string()
436 } else {
437 format!("Pass --no-vram-check to proceed despite the estimate, or lower --batch-size to {safe_batch}.")
438 };
439
440 let msg = format!(
441 "All-layers attention peak ({all_layers_gb:.2} GB) exceeds \
442 the budget ({attn_limit_gb:.2} GB).\n\
443 \n\
444 WHY: Burn builds autodiff tape for all {depth} transformer layers \
445 during the forward pass. At the forward→backward boundary all \
446 {depth} layers' chunk tensors are simultaneously in GPU memory — \
447 the peak is depth × per-layer, not just per-layer.\n\
448 \n\
449 Largest safe batch for this model + VRAM: {safe_batch}\n\
450 \n\
451 Options:\n\
452 • --vram-gb <GB> tell the tool your GPU — batch auto-selected\n\
453 • --batch-size {safe_batch:<4} largest batch that fits\n\
454 • --model-size tiny ~11 M params, much lower attention memory\n\
455 • --model-size small ~44 M params, moderate memory\n\
456 • attn_chunk_size {safe_chunk} halving chunk halves per-layer tape\n\
457 • --no-vram-check bypass guard (crashes are your responsibility)\n\
458 \n\
459 {vram_hint}",
460 depth = enc.depth,
461 );
462
463 if train_cfg.skip_vram_check {
464 eprintln!("[sensorlm] ⚠ Guard exceeded but --no-vram-check set:\n{msg}");
465 eprintln!("[sensorlm] ⚠ Proceeding — monitor GPU memory carefully.");
466 } else {
467 return Err(crate::error::SensorLMError::Other(anyhow::anyhow!("{msg}")));
468 }
469 }
470
471 let device = B::Device::default();
472 let max_seq_len = train_cfg.caption_key.max_tokens();
473
474 let train_samples = train_cfg.batch_size * 20;
478 let valid_samples = train_cfg.batch_size * 4;
479
480 let train_dataset = SyntheticSensorDataset::new(train_samples, train_cfg.seed, max_seq_len);
481 let valid_dataset = SyntheticSensorDataset::new(valid_samples, train_cfg.seed + 1, max_seq_len);
482
483 let num_workers = train_cfg.num_workers.max(1);
487 let train_steps = train_samples / train_cfg.batch_size;
488 let valid_steps = valid_samples / train_cfg.batch_size;
489
490 eprintln!(
491 "[sensorlm] Training plan: {train_steps} train steps + \
492 {valid_steps} validation steps per epoch \
493 (dataset: {train_samples} train / {valid_samples} valid samples)"
494 );
495
496 let batcher_train = SensorLMBatcher::<B>::new(
500 device.clone(),
501 model_cfg.sensor_encoder.time_steps,
502 model_cfg.sensor_encoder.num_channels,
503 max_seq_len,
504 );
505 let batcher_valid = SensorLMBatcher::<B::InnerBackend>::new(
506 device.clone(),
507 model_cfg.sensor_encoder.time_steps,
508 model_cfg.sensor_encoder.num_channels,
509 max_seq_len,
510 );
511
512 let train_loader = DataLoaderBuilder::new(batcher_train)
514 .batch_size(train_cfg.batch_size)
515 .shuffle(train_cfg.seed)
516 .num_workers(num_workers)
517 .build(train_dataset);
518
519 let valid_loader = DataLoaderBuilder::new(batcher_valid)
520 .batch_size(train_cfg.batch_size)
521 .num_workers(num_workers)
522 .build(valid_dataset);
523
524 let model = SensorLMModel::<B>::new(&model_cfg, &device);
528
529 let optimizer = AdamConfig::new()
530 .with_beta_1(train_cfg.beta1 as f32)
531 .with_beta_2(train_cfg.beta2 as f32)
532 .with_epsilon(train_cfg.epsilon as f32)
533 .with_weight_decay(Some(burn::optim::decay::WeightDecayConfig::new(
534 train_cfg.weight_decay, )))
536 .init();
537
538 let lr_scheduler = RsqrtScheduler::new(
542 train_cfg.lr,
543 train_cfg.total_steps,
544 train_cfg.warmup_fraction,
545 train_cfg.cooldown_fraction,
546 );
547
548 std::fs::create_dir_all(&train_cfg.artifact_dir)?;
552
553 let renderer = SensorLMRenderer::new(train_steps, valid_steps);
556
557 let builder = LearnerBuilder::new(&train_cfg.artifact_dir)
558 .metric_train_numeric(LossMetric::<B>::new())
559 .metric_valid_numeric(LossMetric::<B::InnerBackend>::new())
560 .with_file_checkpointer(CompactRecorder::new())
561 .renderer(renderer)
562 .devices(vec![device])
563 .num_epochs(1);
564
565 let builder = if train_cfg.show_summary { builder.summary() } else { builder };
566
567 let learner = builder.build(model, optimizer, lr_scheduler);
568
569 let _trained_model = learner.fit(train_loader, valid_loader);
570
571 eprintln!(
572 "\n[sensorlm] Training complete — \
573 {train_steps} train + {valid_steps} valid steps."
574 );
575 Ok(())
576}
577
578pub fn save_model<B: AutodiffBackend>(
580 model: SensorLMModel<B>,
581 path: &Path,
582) -> Result<()> {
583 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
584 model
585 .save_file(path, &recorder)
586 .map_err(|e| crate::error::SensorLMError::Other(anyhow::anyhow!("{e}")))?;
587 Ok(())
588}
589
590pub fn load_model<B: AutodiffBackend>(
592 cfg: &SensorLMConfig,
593 path: &Path,
594 device: &B::Device,
595) -> Result<SensorLMModel<B>> {
596 let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
597 let model = SensorLMModel::<B>::new(cfg, device)
598 .load_file(path, &recorder, device)
599 .map_err(|e| crate::error::SensorLMError::Other(anyhow::anyhow!("{e}")))?;
600 Ok(model)
601}