sensorlm/config.rs
1//! Hierarchical configuration structs for every subsystem.
2//!
3//! All structs implement [`serde::Serialize`] / [`serde::Deserialize`] so they
4//! can be persisted to / loaded from JSON config files.
5//!
6//! The defaults mirror the reference Python configuration exactly.
7
8use serde::{Deserialize, Serialize};
9
10use crate::constants::*;
11
12// ===========================================================================
13// Sensor encoder (ViT)
14// ===========================================================================
15
/// Configuration for the Vision Transformer (ViT) sensor encoder.
///
/// The encoder treats wearable sensor data as a 2-D grid
/// `(time_steps × num_channels)` and divides it into rectangular patches of
/// shape `(patch_h × patch_w)`.
///
/// Defaults quoted on the fields are the values produced by this type's
/// [`Default`] implementation, which takes the grid geometry and the ViT
/// dimensions from `crate::constants`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensorEncoderConfig {
    /// Number of time-steps in the input signal (default: 1440 = 24 h × 60 min).
    pub time_steps: usize,
    /// Number of sensor channels (features) per time-step (default: 34).
    pub num_channels: usize,
    /// Patch height (time axis), default: 10 minutes.
    ///
    /// Must be non-zero: `num_patches` divides by this value.
    pub patch_h: usize,
    /// Patch width (channel axis), default: 2 channels.
    ///
    /// Must be non-zero: `num_patches` divides by this value.
    pub patch_w: usize,
    /// Transformer hidden dimension (ViT-B = 768).
    pub d_model: usize,
    /// Number of transformer layers (ViT-B = 12).
    pub depth: usize,
    /// Number of attention heads per layer (ViT-B = 12).
    pub num_heads: usize,
    /// Feed-forward MLP hidden dimension (ViT-B = 3072).
    pub mlp_dim: usize,
    /// Dropout probability applied inside each transformer block
    /// (default: `0.0`).
    pub dropout: f64,
    /// Type of sequence pooling used after the transformer.
    ///
    /// [`PoolType::Map`] (Multihead Attention Pooling) is the default and
    /// matches the reference implementation. [`PoolType::Gap`] (global
    /// average pooling) is a cheaper alternative.
    pub pool_type: PoolType,
    /// Whether to zero-initialise the output projection in the MAP head
    /// (default: `false`).
    pub head_zeroinit: bool,
    /// Chunked-attention window size (number of query rows per chunk).
    ///
    /// Limits **forward-pass** peak attention memory from `O(B·H·N²)` to
    /// `O(B·H·chunk·N)` and keeps individual WGPU GPU dispatches small
    /// enough to avoid OS watchdog (TDR) timeouts.
    ///
    /// **⚠ Training caveat — chunking does NOT save backward memory.**
    /// Burn's autodiff tape records every intermediate tensor produced inside
    /// the chunk loop (`q_chunk`, `scores`, `attn`, `chunk_out`). All chunks
    /// for all layers are kept alive simultaneously until `loss.backward()`
    /// completes. True backward memory savings require gradient checkpointing,
    /// which is not yet implemented.
    ///
    /// Rule of thumb for GPU VRAM (fp32, ViT-B, N = 2448, H = 12):
    ///
    /// | chunk      | fwd attn @ B=4 | fwd attn @ B=8 |
    /// |------------|----------------|----------------|
    /// | 2448 (off) | 4.3 GB         | 8.6 GB         |
    /// | 256        | 450 MB         | 900 MB         |
    /// | 128        | 225 MB         | 450 MB         |
    /// | 64         | 112 MB         | 225 MB         |
    ///
    /// Set to `0` to disable chunking (full N×N matrix — **not recommended
    /// on GPU** due to TDR risk and peak memory).
    ///
    /// Default: `64`.
    pub attn_chunk_size: usize,
}
76
77impl Default for SensorEncoderConfig {
78 fn default() -> Self {
79 Self {
80 time_steps: TIME_STEPS,
81 num_channels: NUM_CHANNELS,
82 patch_h: PATCH_H,
83 patch_w: PATCH_W,
84 d_model: VIT_WIDTH,
85 depth: VIT_DEPTH,
86 num_heads: VIT_HEADS,
87 mlp_dim: VIT_MLP_DIM,
88 dropout: 0.0,
89 pool_type: PoolType::Map,
90 head_zeroinit: false,
91 attn_chunk_size: 64,
92 }
93 }
94}
95
96impl SensorEncoderConfig {
97 /// Total number of patches = (time_steps / patch_h) × (num_channels / patch_w).
98 ///
99 /// Channel dimension is padded up to the next multiple of `patch_w` if
100 /// `num_channels` is not evenly divisible.
101 pub fn num_patches(&self) -> usize {
102 let pt = self.time_steps / self.patch_h;
103 let pc = (self.num_channels + self.patch_w - 1) / self.patch_w;
104 pt * pc
105 }
106}
107
108// ===========================================================================
109// Named model-size presets
110// ===========================================================================
111
/// Named ViT model-size variants, matching the standard ViT paper dimensions.
///
/// The default variant is [`ModelSize::Tiny`].
///
/// Memory figures assume fp32, N = 2448 patches, chunk = 64, and cover
/// *attention score/weight tensors for one transformer layer* (the practical
/// backward-pass peak). Total GPU memory is 3–5× higher once weights,
/// activations, and Adam optimizer states are included.
///
/// | Size  | d_model | heads | ~params | per-layer bwd B=16 | per-layer bwd B=4 |
/// |-------|---------|-------|---------|--------------------|-------------------|
/// | Tiny  | 192     | 3     | ~11 M   | 2.1 GB             | 0.5 GB            |
/// | Small | 384     | 6     | ~44 M   | 4.4 GB             | 1.1 GB            |
/// | Base  | 768     | 12    | ~205 M  | 17.5 GB ✗          | 2.2 GB            |
///
/// NOTE(review): the Base row is not self-consistent — the Tiny and Small
/// rows grow ≈ 4× from B=4 to B=16, whereas 2.2 GB → 17.5 GB is ≈ 8×. One of
/// the two Base figures presumably needs correcting; confirm against a
/// measurement.
///
/// Recommended `--batch-size` per preset (WGPU / Metal, 16 GB device):
/// - `tiny`:  up to **16** — comfortable; per-layer bwd ≈ 2.1 GB
/// - `small`: up to **8**  — comfortable; per-layer bwd ≈ 2.2 GB
/// - `base`:  up to **4**  — per-layer bwd ≈ 2.2 GB; total ≈ 10 GB
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ModelSize {
    /// ViT-Ti: d=192, depth=12, heads=3, mlp=768. Fits in ~2 GB VRAM.
    #[default]
    Tiny,
    /// ViT-S: d=384, depth=12, heads=6, mlp=1536. Fits in ~6 GB VRAM.
    Small,
    /// ViT-B: d=768, depth=12, heads=12, mlp=3072. Requires ≥ 16 GB VRAM.
    Base,
}
139
140impl ModelSize {
141 /// Return the transformer hidden dimension for this size.
142 pub fn d_model(self) -> usize {
143 match self {
144 Self::Tiny => 192,
145 Self::Small => 384,
146 Self::Base => VIT_WIDTH, // 768
147 }
148 }
149
150 /// Return the number of transformer layers.
151 pub fn depth(self) -> usize {
152 12 // same across all ViT variants
153 }
154
155 /// Return the number of attention heads.
156 pub fn num_heads(self) -> usize {
157 match self {
158 Self::Tiny => 3,
159 Self::Small => 6,
160 Self::Base => VIT_HEADS, // 12
161 }
162 }
163
164 /// Return the MLP hidden dimension (4 × d_model).
165 pub fn mlp_dim(self) -> usize {
166 self.d_model() * 4
167 }
168
169 /// Build a [`SensorEncoderConfig`] for this size with sensible defaults.
170 pub fn sensor_encoder_config(self) -> SensorEncoderConfig {
171 SensorEncoderConfig {
172 d_model: self.d_model(),
173 depth: self.depth(),
174 num_heads: self.num_heads(),
175 mlp_dim: self.mlp_dim(),
176 // All sizes use chunking. The per-dispatch attention tensor is
177 // (B, H, chunk, N) × 4 bytes — its size depends on H, not d_model,
178 // so full attention (chunk=0) would still produce a ≥1 GB kernel at
179 // B=16 for Tiny (H=3) which risks an OS GPU watchdog (TDR) timeout.
180 // chunk=128, B=16, H=3: dispatch = 16×3×128×2448×4 ≈ 144 MB ✓
181 // chunk=64, B=16, H=3: dispatch = 16×3× 64×2448×4 ≈ 72 MB ✓
182 attn_chunk_size: 64, // safe for all sizes and recommended batch sizes
183 ..SensorEncoderConfig::default()
184 }
185 }
186
187 /// Build a [`TextEncoderConfig`] for this size.
188 pub fn text_encoder_config(self) -> TextEncoderConfig {
189 TextEncoderConfig {
190 d_model: self.d_model(),
191 depth: self.depth(),
192 num_heads: self.num_heads(),
193 mlp_dim: self.mlp_dim(),
194 out_dim: Some(self.d_model()), // embed_dim matches d_model
195 ..TextEncoderConfig::default()
196 }
197 }
198
199 /// Build a complete [`SensorLMConfig`] for this size.
200 pub fn sensorlm_config(self) -> SensorLMConfig {
201 SensorLMConfig {
202 sensor_encoder: self.sensor_encoder_config(),
203 text_encoder: self.text_encoder_config(),
204 embed_dim: self.d_model(),
205 ..SensorLMConfig::default()
206 }
207 }
208
209 /// Human-readable approximate parameter count for both towers combined.
210 pub fn approx_params(self) -> &'static str {
211 match self {
212 Self::Tiny => "~11 M",
213 Self::Small => "~44 M",
214 Self::Base => "~205 M",
215 }
216 }
217}
218
/// Sequence-level pooling strategy after the ViT encoder.
///
/// No serde rename attribute is applied, so with serde's default enum
/// representation the variants appear in config files under their Rust
/// names (`"Map"` / `"Gap"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PoolType {
    /// **Multihead Attention Pooling** – a learnable probe token attends over
    /// all patch tokens (reference implementation default, matches
    /// [`crate::model::sensor_encoder`]).
    Map,
    /// **Global Average Pooling** – mean over the patch-token sequence
    /// (cheaper, slightly lower quality).
    Gap,
}
230
231// ===========================================================================
232// Text encoder
233// ===========================================================================
234
/// Configuration for the text transformer encoder.
///
/// Defaults quoted on the fields are the values produced by this type's
/// [`Default`] implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEncoderConfig {
    /// Vocabulary size (default: 32 000, c4_en SentencePiece vocabulary).
    pub vocab_size: usize,
    /// Maximum token sequence length (default: 1024).
    pub max_seq_len: usize,
    /// Transformer hidden dimension (ViT-B = 768).
    pub d_model: usize,
    /// Number of transformer layers (ViT-B = 12).
    pub depth: usize,
    /// Number of attention heads per layer.
    pub num_heads: usize,
    /// Feed-forward MLP hidden dimension.
    pub mlp_dim: usize,
    /// Dropout probability (default: `0.0`).
    pub dropout: f64,
    /// Output projection dimension. `None` means no projection (identity).
    ///
    /// Defaults to `Some(EMBED_DIM)`.
    pub out_dim: Option<usize>,
}
255
256impl Default for TextEncoderConfig {
257 fn default() -> Self {
258 Self {
259 vocab_size: VOCAB_SIZE,
260 max_seq_len: 1024,
261 d_model: VIT_WIDTH,
262 depth: VIT_DEPTH,
263 num_heads: VIT_HEADS,
264 mlp_dim: VIT_MLP_DIM,
265 dropout: 0.0,
266 out_dim: Some(EMBED_DIM),
267 }
268 }
269}
270
271// ===========================================================================
272// Two-tower SensorLM
273// ===========================================================================
274
/// Top-level configuration for the combined (two-tower) SensorLM model.
///
/// [`ModelSize::sensorlm_config`] builds an instance in which both towers
/// and `embed_dim` are derived from one preset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensorLMConfig {
    /// Sensor (image) encoder configuration.
    pub sensor_encoder: SensorEncoderConfig,
    /// Text encoder configuration.
    pub text_encoder: TextEncoderConfig,
    /// Shared embedding dimensionality (must match both encoder outputs).
    ///
    /// NOTE(review): nothing in this file validates that constraint —
    /// presumably it is checked where the model is built; confirm.
    pub embed_dim: usize,
    /// Initial value of the SigLIP temperature scalar (log-scale before exp).
    pub temperature_init: f32,
    /// Initial value of the SigLIP bias scalar.
    pub bias_init: f32,
}
289
290impl Default for SensorLMConfig {
291 fn default() -> Self {
292 Self {
293 sensor_encoder: SensorEncoderConfig::default(),
294 text_encoder: TextEncoderConfig::default(),
295 embed_dim: EMBED_DIM,
296 temperature_init: TEMPERATURE_INIT,
297 bias_init: BIAS_INIT,
298 }
299 }
300}
301
302// ===========================================================================
303// Training
304// ===========================================================================
305
306/// Learning-rate schedule type.
307#[derive(Debug, Clone, Serialize, Deserialize)]
308pub enum LrScheduleType {
309 /// Inverse-square-root schedule with linear warm-up and cool-down.
310 /// Matches the reference `decay_type='rsqrt'` setting.
311 RsqrtWithWarmupCooldown,
312 /// Cosine annealing.
313 Cosine,
314 /// Constant learning rate (no schedule).
315 Constant,
316}
317
/// Training hyperparameters.
///
/// Defaults quoted on the fields are the values produced by this type's
/// [`Default`] implementation; several of them come from `crate::constants`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Model size preset (`tiny` / `small` / `base`).
    ///
    /// Overrides `SensorLMConfig` when passed through the CLI. Building
    /// a config from a preset is the recommended way to avoid mismatched
    /// `d_model` / `embed_dim` values between the two towers.
    ///
    /// Defaults to [`ModelSize::Tiny`].
    pub model_size: ModelSize,
    /// Total number of gradient update steps.
    ///
    /// Defaults to `TOTAL_EXAMPLES / DEFAULT_BATCH_SIZE` (integer division).
    pub total_steps: usize,
    /// Mini-batch size (default: 8).
    pub batch_size: usize,
    /// Peak learning rate (default: 5 × 10⁻⁴).
    pub lr: f64,
    /// AdamW weight decay (default: 1 × 10⁻⁴).
    pub weight_decay: f64,
    /// Adam β₂ (default: 0.999, reference uses `scale_by_adam b2=0.999`).
    pub beta2: f64,
    /// Adam β₁ (default: 0.9).
    pub beta1: f64,
    /// Adam ε (default: 1 × 10⁻⁸).
    pub epsilon: f64,
    /// Gradient clip norm (default: 1.0).
    pub grad_clip_norm: f64,
    /// Fraction of total steps used for linear warm-up (default: 0.2).
    pub warmup_fraction: f64,
    /// Fraction of total steps used for cool-down (default: 0.2).
    pub cooldown_fraction: f64,
    /// LR schedule type
    /// (default: [`LrScheduleType::RsqrtWithWarmupCooldown`]).
    pub lr_schedule: LrScheduleType,
    /// Save a checkpoint every N steps (default: 500).
    pub checkpoint_every: usize,
    /// Log metrics every N steps (default: 50).
    pub log_every: usize,
    /// Random seed (default: 0).
    pub seed: u64,
    /// Caption type key to use during this training run
    /// (default: [`CaptionKey::HighLevelSummary`]).
    pub caption_key: CaptionKey,
    /// Path to SentencePiece tokeniser model file
    /// (default: `tokenizer.model`).
    pub tokenizer_path: String,
    /// Directory to write checkpoints / logs (default: `./artifacts`).
    pub artifact_dir: String,
    /// Directory containing the dataset (Parquet or raw files)
    /// (default: `./data`).
    pub data_dir: String,
    /// Number of DataLoader worker threads for CPU-side data preparation.
    ///
    /// Must be **≥ 1** — Burn's `PartialDataset::split` divides the dataset
    /// length by `num_workers`, so `0` causes a divide-by-zero panic.
    ///
    /// The WGPU backend (including Metal on macOS) is internally thread-safe:
    /// worker threads can call `Tensor::from_floats(…, &device)` safely.
    /// 2 workers is a reasonable default; increase on machines with many CPU
    /// cores and fast NVMe storage. Use 1 if you observe data-loading becoming
    /// the training bottleneck (rare with synthetic data).
    ///
    /// NOTE(review): the last sentence reads inverted — a data-loading
    /// bottleneck would normally call for *more* workers. Confirm the
    /// intended advice.
    pub num_workers: usize,
    /// Available GPU VRAM in gigabytes (default: `None`).
    ///
    /// When set the pre-flight guard derives the attention-tensor budget as
    /// `vram_gb / 3` and **auto-caps `batch_size`** to the largest value that
    /// fits, so you never have to tune `--batch-size` manually.
    ///
    /// Memory split used (all figures are estimates):
    /// ```text
    /// ┌─────────────────────────────────────────────────────┐
    /// │  1/3 → attention score/weight tensors (one layer)   │
    /// │  1/3 → model weights + gradients + Adam states      │
    /// │  1/3 → non-attention activations + OS/driver slack  │
    /// └─────────────────────────────────────────────────────┘
    /// ```
    ///
    /// Examples for ViT-B (base), depth=12, H=12, chunk=64, N=2448:
    ///
    /// The peak memory is `depth × per-layer` because Burn's forward pass
    /// builds autodiff tape for **all** transformer layers before `backward()`
    /// starts. 70% of VRAM is budgeted for this tape; the rest covers
    /// model weights + gradients + Adam states + other activations.
    ///
    /// | VRAM  | attn budget (×0.7) | max batch | all-layers peak |
    /// |-------|--------------------|-----------|-----------------|
    /// | 8 GB  | 5.6 GB             | 1         | 6.6 GB          |
    /// | 16 GB | 11.2 GB            | 1         | 6.6 GB          |
    /// | 24 GB | 16.8 GB            | 2         | 13.1 GB         |
    /// | 32 GB | 22.4 GB            | 3         | 19.7 GB         |
    /// | 48 GB | 33.6 GB            | 5         | 32.8 GB         |
    /// | 80 GB | 56.0 GB            | 8         | 52.4 GB         |
    ///
    /// NOTE(review): this comment describes two different budgets — a
    /// one-layer `vram_gb / 3` split above, and an all-layers 70 % budget in
    /// the example table. The guard itself is not in this file; confirm which
    /// rule it implements and drop the other description. The 8 GB row also
    /// lists a 6.6 GB peak against a 5.6 GB budget.
    pub vram_gb: Option<f64>,
    /// Skip the pre-flight VRAM safety check and proceed even if the
    /// estimated attention memory exceeds the computed limit
    /// (default: `false`).
    ///
    /// Use this only when you are certain your GPU has enough free VRAM.
    /// You accept full responsibility for OOM errors or GPU driver crashes.
    pub skip_vram_check: bool,

    /// Print Burn's `═══ Learner Summary ═══` table after training.
    ///
    /// Disabled by default to keep the terminal output clean.
    /// Pass `--summary` on the CLI to enable.
    pub show_summary: bool,
}
418
419impl Default for TrainingConfig {
420 fn default() -> Self {
421 let total_examples = TOTAL_EXAMPLES;
422 let batch_size = DEFAULT_BATCH_SIZE;
423 let total_steps = total_examples / batch_size;
424 Self {
425 model_size: ModelSize::default(), // Tiny
426 total_steps,
427 batch_size,
428 lr: DEFAULT_LR,
429 weight_decay: DEFAULT_WD,
430 beta1: 0.9,
431 beta2: ADAM_BETA2,
432 epsilon: 1e-8,
433 grad_clip_norm: GRAD_CLIP_NORM,
434 warmup_fraction: 0.2,
435 cooldown_fraction: 0.2,
436 lr_schedule: LrScheduleType::RsqrtWithWarmupCooldown,
437 checkpoint_every: 500,
438 log_every: 50,
439 seed: 0,
440 caption_key: CaptionKey::HighLevelSummary,
441 tokenizer_path: "tokenizer.model".to_string(),
442 artifact_dir: "./artifacts".to_string(),
443 data_dir: "./data".to_string(),
444 num_workers: 2,
445 vram_gb: None,
446 skip_vram_check: false,
447 show_summary: false,
448 }
449 }
450}
451
/// Which caption tier to use as the text pair during training.
///
/// Each variant has an associated token budget, returned by
/// `CaptionKey::max_tokens`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CaptionKey {
    /// Statistical summary only (level 1).
    LowLevel,
    /// Structural patterns only (level 2).
    MiddleLevel,
    /// High-level semantic summary (level 3, 256 tokens).
    HighLevelSummary,
    /// Full high-level caption (level 3, 1024 tokens).
    HighLevelAll,
    /// Levels 2 + 1.
    MiddleLow,
    /// Levels 3 + 1.
    HighLow,
    /// Levels 3 + 2.
    HighMiddle,
    /// All three levels concatenated.
    HighMiddleLow,
}
472
473impl CaptionKey {
474 /// Maximum token budget for this caption type.
475 pub fn max_tokens(self) -> usize {
476 match self {
477 Self::LowLevel => 512,
478 Self::MiddleLevel => 512,
479 Self::HighLevelSummary => 256,
480 Self::HighLevelAll => 1024,
481 Self::MiddleLow => 1024,
482 Self::HighLow => 1024,
483 Self::HighMiddle => 512,
484 Self::HighMiddleLow => 1024,
485 }
486 }
487}
488
489// ===========================================================================
490// Inference
491// ===========================================================================
492
/// Configuration for inference / evaluation.
///
/// Defaults quoted on the fields are the values produced by this type's
/// [`Default`] implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Path to model checkpoint (default: `./artifacts/model_final.bin`).
    pub checkpoint: String,
    /// Path to tokeniser model (default: `tokenizer.model`).
    pub tokenizer_path: String,
    /// Maximum sequence length for text input (default: 256).
    pub max_seq_len: usize,
    /// Batch size for encoding (default: 64).
    pub batch_size: usize,
    /// Use FP16 for faster inference (requires `fp16` feature).
    /// Default: `false`.
    pub fp16: bool,
    /// Caption key to use when generating text from sensor data
    /// (default: [`CaptionKey::HighLevelSummary`]).
    pub caption_key: CaptionKey,
}
509
510impl Default for InferenceConfig {
511 fn default() -> Self {
512 Self {
513 checkpoint: "./artifacts/model_final.bin".to_string(),
514 tokenizer_path: "tokenizer.model".to_string(),
515 max_seq_len: 256,
516 batch_size: 64,
517 fp16: false,
518 caption_key: CaptionKey::HighLevelSummary,
519 }
520 }
521}
522
523// ===========================================================================
524// Quantisation
525// ===========================================================================
526
527/// INT8 quantisation scheme.
528#[derive(Debug, Clone, Serialize, Deserialize)]
529pub enum QuantScheme {
530 /// Symmetric per-tensor quantisation.
531 SymmetricPerTensor,
532 /// Asymmetric per-tensor quantisation.
533 AsymmetricPerTensor,
534 /// Symmetric per-channel (output channel) quantisation.
535 SymmetricPerChannel,
536}
537
/// Post-training quantisation configuration.
///
/// Defaults quoted on the fields are the values produced by this type's
/// [`Default`] implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Source FP32 checkpoint to quantise
    /// (default: `./artifacts/model_final.bin`).
    pub source_checkpoint: String,
    /// Output path for the quantised model
    /// (default: `./artifacts/model_int8.bin`).
    pub output_path: String,
    /// Path to a calibration dataset subset (Parquet)
    /// (default: `./data/calibration.parquet`).
    pub calibration_data: String,
    /// Number of calibration batches (default: 100).
    pub num_calibration_batches: usize,
    /// Batch size during calibration (default: 32).
    pub calibration_batch_size: usize,
    /// INT8 quantisation scheme
    /// (default: [`QuantScheme::SymmetricPerTensor`]).
    pub scheme: QuantScheme,
    /// Quantise text encoder weights (in addition to sensor encoder).
    /// Default: `true`.
    pub quantise_text_encoder: bool,
    /// Path to tokeniser model (default: `tokenizer.model`).
    pub tokenizer_path: String,
}
558
559impl Default for QuantizationConfig {
560 fn default() -> Self {
561 Self {
562 source_checkpoint: "./artifacts/model_final.bin".to_string(),
563 output_path: "./artifacts/model_int8.bin".to_string(),
564 calibration_data: "./data/calibration.parquet".to_string(),
565 num_calibration_batches: 100,
566 calibration_batch_size: 32,
567 scheme: QuantScheme::SymmetricPerTensor,
568 quantise_text_encoder: true,
569 tokenizer_path: "tokenizer.model".to_string(),
570 }
571 }
572}