// sensorlm/config.rs

//! Hierarchical configuration structs for every subsystem.
//!
//! All structs implement [`serde::Serialize`] / [`serde::Deserialize`] so they
//! can be persisted to / loaded from JSON config files.
//!
//! The defaults mirror the reference Python configuration exactly.
7
8use serde::{Deserialize, Serialize};
9
10use crate::constants::*;
11
// ===========================================================================
// Sensor encoder (ViT)
// ===========================================================================
15
16/// Configuration for the Vision Transformer sensor encoder.
17///
18/// The encoder treats wearable sensor data as a 2-D grid
19/// `(TIME_STEPS × NUM_CHANNELS)` and divides it into rectangular patches of
20/// shape `(patch_h × patch_w)`.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct SensorEncoderConfig {
23    /// Number of time-steps in the input signal (default: 1440 = 24 h × 60 min).
24    pub time_steps: usize,
25    /// Number of sensor channels (features) per time-step (default: 34).
26    pub num_channels: usize,
27    /// Patch height (time axis), default: 10 minutes.
28    pub patch_h: usize,
29    /// Patch width (channel axis), default: 2 channels.
30    pub patch_w: usize,
31    /// Transformer hidden dimension (ViT-B = 768).
32    pub d_model: usize,
33    /// Number of transformer layers (ViT-B = 12).
34    pub depth: usize,
35    /// Number of attention heads per layer (ViT-B = 12).
36    pub num_heads: usize,
37    /// Feed-forward MLP hidden dimension (ViT-B = 3072).
38    pub mlp_dim: usize,
39    /// Dropout probability applied inside each transformer block.
40    pub dropout: f64,
41    /// Type of sequence pooling used after the transformer.
42    /// `"map"` (Multihead Attention Pooling) is the default and matches the
43    /// reference implementation.  `"gap"` (global average pooling) is a
44    /// cheaper alternative.
45    pub pool_type: PoolType,
46    /// Whether to zero-initialise the output projection in the MAP head.
47    pub head_zeroinit: bool,
48    /// Chunked-attention window size (number of query rows per chunk).
49    ///
50    /// Limits **forward-pass** peak attention memory from `O(B·H·N²)` to
51    /// `O(B·H·chunk·N)` and keeps individual WGPU GPU dispatches small
52    /// enough to avoid OS watchdog (TDR) timeouts.
53    ///
54    /// **⚠ Training caveat — chunking does NOT save backward memory.**
55    /// Burn's autodiff tape records every intermediate tensor produced inside
56    /// the chunk loop (`q_chunk`, `scores`, `attn`, `chunk_out`).  All chunks
57    /// for all layers are kept alive simultaneously until `loss.backward()`
58    /// completes.  True backward memory savings require gradient checkpointing,
59    /// which is not yet implemented.
60    ///
61    /// Rule of thumb for GPU VRAM (fp32, ViT-B, N = 2448, H = 12):
62    ///
63    /// | chunk | fwd attn @ B=4 | fwd attn @ B=8 |
64    /// |-------|----------------|----------------|
65    /// |  2448 (off) | 4.3 GB | 8.6 GB         |
66    /// |   256 | 450 MB         | 900 MB         |
67    /// |   128 | 225 MB         | 450 MB         |
68    /// |    64 | 112 MB         | 225 MB         |
69    ///
70    /// Set to `0` to disable chunking (full N×N matrix — **not recommended
71    /// on GPU** due to TDR risk and peak memory).
72    ///
73    /// Default: `64`.
74    pub attn_chunk_size: usize,
75}
76
77impl Default for SensorEncoderConfig {
78    fn default() -> Self {
79        Self {
80            time_steps: TIME_STEPS,
81            num_channels: NUM_CHANNELS,
82            patch_h: PATCH_H,
83            patch_w: PATCH_W,
84            d_model: VIT_WIDTH,
85            depth: VIT_DEPTH,
86            num_heads: VIT_HEADS,
87            mlp_dim: VIT_MLP_DIM,
88            dropout: 0.0,
89            pool_type: PoolType::Map,
90            head_zeroinit: false,
91            attn_chunk_size: 64,
92        }
93    }
94}
95
96impl SensorEncoderConfig {
97    /// Total number of patches = (time_steps / patch_h) × (num_channels / patch_w).
98    ///
99    /// Channel dimension is padded up to the next multiple of `patch_w` if
100    /// `num_channels` is not evenly divisible.
101    pub fn num_patches(&self) -> usize {
102        let pt = self.time_steps / self.patch_h;
103        let pc = (self.num_channels + self.patch_w - 1) / self.patch_w;
104        pt * pc
105    }
106}
107
// ===========================================================================
// Named model-size presets
// ===========================================================================
111
112/// Named ViT model-size variants, matching the standard ViT paper dimensions.
113///
114/// Memory figures assume fp32, N = 2448 patches, chunk = 64, and cover
115/// *attention score/weight tensors for one transformer layer* (the practical
116/// backward-pass peak).  Total GPU memory is 3–5× higher once weights,
117/// activations, and Adam optimizer states are included.
118///
119/// | Size  | d_model | heads | ~params | per-layer bwd B=16 | per-layer bwd B=4 |
120/// |-------|---------|-------|---------|--------------------|-------------------|
121/// | Tiny  |   192   |   3   |  ~11 M  |  2.1 GB            | 0.5 GB            |
122/// | Small |   384   |   6   |  ~44 M  |  4.4 GB            | 1.1 GB            |
123/// | Base  |   768   |  12   | ~205 M  | 17.5 GB ✗          | 2.2 GB            |
124///
125/// Recommended `--batch-size` per preset (WGPU / Metal, 16 GB device):
126/// - `tiny`:  up to **16** — comfortable; per-layer bwd ≈ 2.1 GB
127/// - `small`: up to **8**  — comfortable; per-layer bwd ≈ 2.2 GB
128/// - `base`:  up to **4**  — per-layer bwd ≈ 2.2 GB; total ≈ 10 GB
129#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
130pub enum ModelSize {
131    /// ViT-Ti: d=192, depth=12, heads=3, mlp=768. Fits in ~2 GB VRAM.
132    #[default]
133    Tiny,
134    /// ViT-S: d=384, depth=12, heads=6, mlp=1536. Fits in ~6 GB VRAM.
135    Small,
136    /// ViT-B: d=768, depth=12, heads=12, mlp=3072. Requires ≥ 16 GB VRAM.
137    Base,
138}
139
140impl ModelSize {
141    /// Return the transformer hidden dimension for this size.
142    pub fn d_model(self) -> usize {
143        match self {
144            Self::Tiny  => 192,
145            Self::Small => 384,
146            Self::Base  => VIT_WIDTH, // 768
147        }
148    }
149
150    /// Return the number of transformer layers.
151    pub fn depth(self) -> usize {
152        12 // same across all ViT variants
153    }
154
155    /// Return the number of attention heads.
156    pub fn num_heads(self) -> usize {
157        match self {
158            Self::Tiny  => 3,
159            Self::Small => 6,
160            Self::Base  => VIT_HEADS, // 12
161        }
162    }
163
164    /// Return the MLP hidden dimension (4 × d_model).
165    pub fn mlp_dim(self) -> usize {
166        self.d_model() * 4
167    }
168
169    /// Build a [`SensorEncoderConfig`] for this size with sensible defaults.
170    pub fn sensor_encoder_config(self) -> SensorEncoderConfig {
171        SensorEncoderConfig {
172            d_model: self.d_model(),
173            depth:   self.depth(),
174            num_heads: self.num_heads(),
175            mlp_dim: self.mlp_dim(),
176            // All sizes use chunking.  The per-dispatch attention tensor is
177            // (B, H, chunk, N) × 4 bytes — its size depends on H, not d_model,
178            // so full attention (chunk=0) would still produce a ≥1 GB kernel at
179            // B=16 for Tiny (H=3) which risks an OS GPU watchdog (TDR) timeout.
180            //   chunk=128, B=16, H=3:  dispatch = 16×3×128×2448×4 ≈ 144 MB ✓
181            //   chunk=64,  B=16, H=3:  dispatch = 16×3× 64×2448×4 ≈  72 MB ✓
182            attn_chunk_size: 64, // safe for all sizes and recommended batch sizes
183            ..SensorEncoderConfig::default()
184        }
185    }
186
187    /// Build a [`TextEncoderConfig`] for this size.
188    pub fn text_encoder_config(self) -> TextEncoderConfig {
189        TextEncoderConfig {
190            d_model:   self.d_model(),
191            depth:     self.depth(),
192            num_heads: self.num_heads(),
193            mlp_dim:   self.mlp_dim(),
194            out_dim:   Some(self.d_model()), // embed_dim matches d_model
195            ..TextEncoderConfig::default()
196        }
197    }
198
199    /// Build a complete [`SensorLMConfig`] for this size.
200    pub fn sensorlm_config(self) -> SensorLMConfig {
201        SensorLMConfig {
202            sensor_encoder: self.sensor_encoder_config(),
203            text_encoder:   self.text_encoder_config(),
204            embed_dim:      self.d_model(),
205            ..SensorLMConfig::default()
206        }
207    }
208
209    /// Human-readable approximate parameter count for both towers combined.
210    pub fn approx_params(self) -> &'static str {
211        match self {
212            Self::Tiny  => "~11 M",
213            Self::Small => "~44 M",
214            Self::Base  => "~205 M",
215        }
216    }
217}
218
219/// Sequence-level pooling strategy after the ViT encoder.
220#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
221pub enum PoolType {
222    /// **Multihead Attention Pooling** – a learnable probe token attends over
223    /// all patch tokens (reference implementation default, matches
224    /// [`crate::model::sensor_encoder`]).
225    Map,
226    /// **Global Average Pooling** – mean over the patch-token sequence
227    /// (cheaper, slightly lower quality).
228    Gap,
229}
230
// ===========================================================================
// Text encoder
// ===========================================================================
234
235/// Configuration for the text transformer encoder.
236#[derive(Debug, Clone, Serialize, Deserialize)]
237pub struct TextEncoderConfig {
238    /// Vocabulary size (default: 32 000, c4_en SentencePiece vocabulary).
239    pub vocab_size: usize,
240    /// Maximum token sequence length.
241    pub max_seq_len: usize,
242    /// Transformer hidden dimension (ViT-B = 768).
243    pub d_model: usize,
244    /// Number of transformer layers (ViT-B = 12).
245    pub depth: usize,
246    /// Number of attention heads per layer.
247    pub num_heads: usize,
248    /// Feed-forward MLP hidden dimension.
249    pub mlp_dim: usize,
250    /// Dropout probability.
251    pub dropout: f64,
252    /// Output projection dimension.  `None` means no projection (identity).
253    pub out_dim: Option<usize>,
254}
255
256impl Default for TextEncoderConfig {
257    fn default() -> Self {
258        Self {
259            vocab_size: VOCAB_SIZE,
260            max_seq_len: 1024,
261            d_model: VIT_WIDTH,
262            depth: VIT_DEPTH,
263            num_heads: VIT_HEADS,
264            mlp_dim: VIT_MLP_DIM,
265            dropout: 0.0,
266            out_dim: Some(EMBED_DIM),
267        }
268    }
269}
270
// ===========================================================================
// Two-tower SensorLM
// ===========================================================================
274
275/// Top-level configuration for the combined SensorLM model.
276#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct SensorLMConfig {
278    /// Sensor (image) encoder configuration.
279    pub sensor_encoder: SensorEncoderConfig,
280    /// Text encoder configuration.
281    pub text_encoder: TextEncoderConfig,
282    /// Shared embedding dimensionality (must match both encoder outputs).
283    pub embed_dim: usize,
284    /// Initial value of the SigLIP temperature scalar (log-scale before exp).
285    pub temperature_init: f32,
286    /// Initial value of the SigLIP bias scalar.
287    pub bias_init: f32,
288}
289
290impl Default for SensorLMConfig {
291    fn default() -> Self {
292        Self {
293            sensor_encoder: SensorEncoderConfig::default(),
294            text_encoder: TextEncoderConfig::default(),
295            embed_dim: EMBED_DIM,
296            temperature_init: TEMPERATURE_INIT,
297            bias_init: BIAS_INIT,
298        }
299    }
300}
301
// ===========================================================================
// Training
// ===========================================================================
305
306/// Learning-rate schedule type.
307#[derive(Debug, Clone, Serialize, Deserialize)]
308pub enum LrScheduleType {
309    /// Inverse-square-root schedule with linear warm-up and cool-down.
310    /// Matches the reference `decay_type='rsqrt'` setting.
311    RsqrtWithWarmupCooldown,
312    /// Cosine annealing.
313    Cosine,
314    /// Constant learning rate (no schedule).
315    Constant,
316}
317
318/// Training hyperparameters.
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct TrainingConfig {
321    /// Model size preset (`tiny` / `small` / `base`).
322    ///
323    /// Overrides `SensorLMConfig` when passed through the CLI.  Building
324    /// a config from a preset is the recommended way to avoid mismatched
325    /// `d_model` / `embed_dim` values between the two towers.
326    pub model_size: ModelSize,
327    /// Total number of gradient update steps.
328    pub total_steps: usize,
329    /// Mini-batch size (default: 8).
330    pub batch_size: usize,
331    /// Peak learning rate (default: 5 × 10⁻⁴).
332    pub lr: f64,
333    /// AdamW weight decay (default: 1 × 10⁻⁴).
334    pub weight_decay: f64,
335    /// Adam β₂ (default: 0.999, reference uses `scale_by_adam b2=0.999`).
336    pub beta2: f64,
337    /// Adam β₁.
338    pub beta1: f64,
339    /// Adam ε.
340    pub epsilon: f64,
341    /// Gradient clip norm (default: 1.0).
342    pub grad_clip_norm: f64,
343    /// Fraction of total steps used for linear warm-up (default: 0.2).
344    pub warmup_fraction: f64,
345    /// Fraction of total steps used for cool-down (default: 0.2).
346    pub cooldown_fraction: f64,
347    /// LR schedule type.
348    pub lr_schedule: LrScheduleType,
349    /// Save a checkpoint every N steps.
350    pub checkpoint_every: usize,
351    /// Log metrics every N steps.
352    pub log_every: usize,
353    /// Random seed.
354    pub seed: u64,
355    /// Caption type key to use during this training run.
356    pub caption_key: CaptionKey,
357    /// Path to SentencePiece tokeniser model file.
358    pub tokenizer_path: String,
359    /// Directory to write checkpoints / logs.
360    pub artifact_dir: String,
361    /// Directory containing the dataset (Parquet or raw files).
362    pub data_dir: String,
363    /// Number of DataLoader worker threads for CPU-side data preparation.
364    ///
365    /// Must be **≥ 1** — Burn's `PartialDataset::split` divides the dataset
366    /// length by `num_workers`, so `0` causes a divide-by-zero panic.
367    ///
368    /// The WGPU backend (including Metal on macOS) is internally thread-safe:
369    /// worker threads can call `Tensor::from_floats(…, &device)` safely.
370    /// 2 workers is a reasonable default; increase on machines with many CPU
371    /// cores and fast NVMe storage.  Use 1 if you observe data-loading becoming
372    /// the training bottleneck (rare with synthetic data).
373    pub num_workers: usize,
374    /// Available GPU VRAM in gigabytes.
375    ///
376    /// When set the pre-flight guard derives the attention-tensor budget as
377    /// `vram_gb / 3` and **auto-caps `batch_size`** to the largest value that
378    /// fits, so you never have to tune `--batch-size` manually.
379    ///
380    /// Memory split used (all figures are estimates):
381    /// ```text
382    /// ┌─────────────────────────────────────────────────────┐
383    /// │  1/3 → attention score/weight tensors (one layer)  │
384    /// │  1/3 → model weights + gradients + Adam states     │
385    /// │  1/3 → non-attention activations + OS/driver slack │
386    /// └─────────────────────────────────────────────────────┘
387    /// ```
388    ///
389    /// Examples for ViT-B (base), depth=12, H=12, chunk=64, N=2448:
390    ///
391    /// The peak memory is `depth × per-layer` because Burn's forward pass
392    /// builds autodiff tape for **all** transformer layers before `backward()`
393    /// starts.  70% of VRAM is budgeted for this tape; the rest covers
394    /// model weights + gradients + Adam states + other activations.
395    ///
396    /// | VRAM | attn budget (×0.7) | max batch | all-layers peak |
397    /// |------|-------------------|-----------|-----------------|
398    /// |  8 GB |  5.6 GB          |     1     |   6.6 GB        |
399    /// | 16 GB | 11.2 GB          |     1     |   6.6 GB        |
400    /// | 24 GB | 16.8 GB          |     2     |  13.1 GB        |
401    /// | 32 GB | 22.4 GB          |     3     |  19.7 GB        |
402    /// | 48 GB | 33.6 GB          |     5     |  32.8 GB        |
403    /// | 80 GB | 56.0 GB          |     8     |  52.4 GB        |
404    pub vram_gb: Option<f64>,
405    /// Skip the pre-flight VRAM safety check and proceed even if the
406    /// estimated attention memory exceeds the computed limit.
407    ///
408    /// Use this only when you are certain your GPU has enough free VRAM.
409    /// You accept full responsibility for OOM errors or GPU driver crashes.
410    pub skip_vram_check: bool,
411
412    /// Print Burn's `═══ Learner Summary ═══` table after training.
413    ///
414    /// Disabled by default to keep the terminal output clean.
415    /// Pass `--summary` on the CLI to enable.
416    pub show_summary: bool,
417}
418
419impl Default for TrainingConfig {
420    fn default() -> Self {
421        let total_examples = TOTAL_EXAMPLES;
422        let batch_size = DEFAULT_BATCH_SIZE;
423        let total_steps = total_examples / batch_size;
424        Self {
425            model_size: ModelSize::default(), // Tiny
426            total_steps,
427            batch_size,
428            lr: DEFAULT_LR,
429            weight_decay: DEFAULT_WD,
430            beta1: 0.9,
431            beta2: ADAM_BETA2,
432            epsilon: 1e-8,
433            grad_clip_norm: GRAD_CLIP_NORM,
434            warmup_fraction: 0.2,
435            cooldown_fraction: 0.2,
436            lr_schedule: LrScheduleType::RsqrtWithWarmupCooldown,
437            checkpoint_every: 500,
438            log_every: 50,
439            seed: 0,
440            caption_key: CaptionKey::HighLevelSummary,
441            tokenizer_path: "tokenizer.model".to_string(),
442            artifact_dir: "./artifacts".to_string(),
443            data_dir: "./data".to_string(),
444            num_workers: 2,
445            vram_gb: None,
446            skip_vram_check: false,
447            show_summary: false,
448        }
449    }
450}
451
452/// Which caption tier to use as the text pair during training.
453#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
454pub enum CaptionKey {
455    /// Statistical summary only (level 1).
456    LowLevel,
457    /// Structural patterns only (level 2).
458    MiddleLevel,
459    /// High-level semantic summary (level 3, 256 tokens).
460    HighLevelSummary,
461    /// Full high-level caption (level 3, 1024 tokens).
462    HighLevelAll,
463    /// Levels 2 + 1.
464    MiddleLow,
465    /// Levels 3 + 1.
466    HighLow,
467    /// Levels 3 + 2.
468    HighMiddle,
469    /// All three levels concatenated.
470    HighMiddleLow,
471}
472
473impl CaptionKey {
474    /// Maximum token budget for this caption type.
475    pub fn max_tokens(self) -> usize {
476        match self {
477            Self::LowLevel => 512,
478            Self::MiddleLevel => 512,
479            Self::HighLevelSummary => 256,
480            Self::HighLevelAll => 1024,
481            Self::MiddleLow => 1024,
482            Self::HighLow => 1024,
483            Self::HighMiddle => 512,
484            Self::HighMiddleLow => 1024,
485        }
486    }
487}
488
// ===========================================================================
// Inference
// ===========================================================================
492
493/// Configuration for inference / evaluation.
494#[derive(Debug, Clone, Serialize, Deserialize)]
495pub struct InferenceConfig {
496    /// Path to model checkpoint.
497    pub checkpoint: String,
498    /// Path to tokeniser model.
499    pub tokenizer_path: String,
500    /// Maximum sequence length for text input.
501    pub max_seq_len: usize,
502    /// Batch size for encoding.
503    pub batch_size: usize,
504    /// Use FP16 for faster inference (requires `fp16` feature).
505    pub fp16: bool,
506    /// Caption key to use when generating text from sensor data.
507    pub caption_key: CaptionKey,
508}
509
510impl Default for InferenceConfig {
511    fn default() -> Self {
512        Self {
513            checkpoint: "./artifacts/model_final.bin".to_string(),
514            tokenizer_path: "tokenizer.model".to_string(),
515            max_seq_len: 256,
516            batch_size: 64,
517            fp16: false,
518            caption_key: CaptionKey::HighLevelSummary,
519        }
520    }
521}
522
// ===========================================================================
// Quantisation
// ===========================================================================
526
527/// INT8 quantisation scheme.
528#[derive(Debug, Clone, Serialize, Deserialize)]
529pub enum QuantScheme {
530    /// Symmetric per-tensor quantisation.
531    SymmetricPerTensor,
532    /// Asymmetric per-tensor quantisation.
533    AsymmetricPerTensor,
534    /// Symmetric per-channel (output channel) quantisation.
535    SymmetricPerChannel,
536}
537
538/// Post-training quantisation configuration.
539#[derive(Debug, Clone, Serialize, Deserialize)]
540pub struct QuantizationConfig {
541    /// Source FP32 checkpoint to quantise.
542    pub source_checkpoint: String,
543    /// Output path for the quantised model.
544    pub output_path: String,
545    /// Path to a calibration dataset subset (Parquet).
546    pub calibration_data: String,
547    /// Number of calibration batches.
548    pub num_calibration_batches: usize,
549    /// Batch size during calibration.
550    pub calibration_batch_size: usize,
551    /// INT8 quantisation scheme.
552    pub scheme: QuantScheme,
553    /// Quantise text encoder weights (in addition to sensor encoder).
554    pub quantise_text_encoder: bool,
555    /// Path to tokeniser model.
556    pub tokenizer_path: String,
557}
558
559impl Default for QuantizationConfig {
560    fn default() -> Self {
561        Self {
562            source_checkpoint: "./artifacts/model_final.bin".to_string(),
563            output_path: "./artifacts/model_int8.bin".to_string(),
564            calibration_data: "./data/calibration.parquet".to_string(),
565            num_calibration_batches: 100,
566            calibration_batch_size: 32,
567            scheme: QuantScheme::SymmetricPerTensor,
568            quantise_text_encoder: true,
569            tokenizer_path: "tokenizer.model".to_string(),
570        }
571    }
572}