Skip to main content

sensorlm/model/
sensor_encoder.rs

1//! ViT sensor encoder with rectangular patch embedding and MAP pooling.
2//!
3//! # Input / output contract
4//!
5//! | Tensor | Shape | Description |
6//! |--------|-------|-------------|
7//! | Input  | `(B, T, C)` | Batch of normalised sensor sequences |
8//! | Output | `(B, D)`    | L2-normalised per-sample embeddings |
9//!
10//! where `B` = batch size, `T` = 1440 time steps, `C` = 34 channels,
11//! `D` = 768 embedding dimension.
12//!
13//! # Patch grid
14//!
//! The `(T, C)` sensor grid is divided into `(⌊T/ph⌋, ⌈C/pw⌉)` non-overlapping
//! rectangular patches of size `(ph, pw)` = `(10, 2)`:
17//!
18//! ```text
19//! T = 1440 ──► 144 patches along time axis
20//! C =   34 ──►  17 patches along channel axis  (ceil(34/2) = 17)
21//! Total = 144 × 17 = 2448 patch tokens
22//! ```
23//!
24//! Each patch is linearly projected to `D = 768` via a `Conv2d` layer.
25
26use burn::{
27    module::{Module, Param},
28    nn::{
29        conv::{Conv2d, Conv2dConfig},
30        Dropout, DropoutConfig, LayerNorm, LayerNormConfig, Linear, LinearConfig,
31    },
32    tensor::{
33        activation,
34        backend::Backend,
35        Distribution, Tensor,
36    },
37};
38
39use crate::config::{PoolType, SensorEncoderConfig};
40
41// ===========================================================================
42// Patch embedding
43// ===========================================================================
44
45/// Projects rectangular sensor patches into the ViT embedding space.
46///
47/// Implemented as a `Conv2d` with `kernel_size == stride == (patch_h, patch_w)`.
48#[derive(Module, Debug)]
49pub struct PatchEmbedding<B: Backend> {
50    proj: Conv2d<B>,
51    num_patches_t: usize,
52    num_patches_c: usize,
53    d_model: usize,
54}
55
56impl<B: Backend> PatchEmbedding<B> {
57    /// Create a new patch-embedding layer.
58    pub fn new(
59        in_channels: usize,
60        d_model: usize,
61        patch_h: usize,
62        patch_w: usize,
63        time_steps: usize,
64        num_channels: usize,
65        device: &B::Device,
66    ) -> Self {
67        // PaddingConfig2d::Valid = no padding (kernel fits exactly)
68        let proj = Conv2dConfig::new(
69            [in_channels, d_model],
70            [patch_h, patch_w],
71        )
72        .with_stride([patch_h, patch_w])
73        .with_padding(burn::nn::PaddingConfig2d::Valid)
74        .with_bias(true)
75        .init(device);
76
77        let num_patches_t = time_steps / patch_h;
78        let num_patches_c = (num_channels + patch_w - 1) / patch_w;
79
80        Self {
81            proj,
82            num_patches_t,
83            num_patches_c,
84            d_model,
85        }
86    }
87
88    /// Forward pass. Input `(B, 1, T, C)` → output `(B, num_patches, D)`.
89    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 3> {
90        let out = self.proj.forward(x); // (B, D, pt, pc)
91        let [batch, d, _pt, _pc] = out.dims();
92        let num_patches = self.num_patches_t * self.num_patches_c;
93        // (B, D, N) → (B, N, D)
94        out.reshape([batch, d, num_patches]).swap_dims(1, 2)
95    }
96
97    /// Total patch count.
98    pub fn num_patches(&self) -> usize {
99        self.num_patches_t * self.num_patches_c
100    }
101}
102
103// ===========================================================================
104// MLP block
105// ===========================================================================
106
107/// Feed-forward MLP: `Linear(D, mlp_dim) → GELU → Dropout → Linear(mlp_dim, D)`.
108#[derive(Module, Debug)]
109pub struct MlpBlock<B: Backend> {
110    fc1: Linear<B>,
111    fc2: Linear<B>,
112    dropout: Dropout,
113}
114
115impl<B: Backend> MlpBlock<B> {
116    /// Create with `d_model` input/output and `mlp_dim` hidden units.
117    pub fn new(d_model: usize, mlp_dim: usize, dropout: f64, device: &B::Device) -> Self {
118        Self {
119            fc1: LinearConfig::new(d_model, mlp_dim).init(device),
120            fc2: LinearConfig::new(mlp_dim, d_model).init(device),
121            dropout: DropoutConfig::new(dropout).init(),
122        }
123    }
124
125    /// `(B, N, D) → (B, N, D)`.
126    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
127        let x = self.fc1.forward(x);
128        let x = activation::gelu(x);
129        let x = self.dropout.forward(x);
130        let x = self.fc2.forward(x);
131        self.dropout.forward(x)
132    }
133}
134
135// ===========================================================================
136// Multi-head self-attention
137// ===========================================================================
138
139/// Scaled dot-product multi-head self-attention with optional chunked computation.
140///
141/// When `chunk_size > 0` the query sequence is processed in windows of
142/// `chunk_size` rows, keeping the **forward-pass** peak attention memory at
143/// `O(B · H · chunk_size · N)` instead of `O(B · H · N²)`, and ensuring
144/// each individual WGPU GPU dispatch remains small enough to avoid OS
145/// watchdog (TDR) timeouts.
146///
147/// ## ⚠ Training memory — chunking reduces dispatch size but NOT total tape
148///
149/// Burn's forward pass builds an autodiff tape for every transformer layer
150/// **before** `loss.backward()` runs.  At the forward→backward boundary all
151/// `depth` layers' chunk tensors are simultaneously in GPU memory:
152///
153/// ```text
154/// peak = depth × 2 × ceil(N/chunk) × B × H × chunk × N × 4 bytes
155///      = 12 × 2 × 39 × B × 12 × 64 × 2448 × 4   (ViT-B defaults)
156///      ≈ 6.56 GB × B
157/// ```
158///
159/// Chunking (small `chunk_size`) keeps **individual GPU dispatch sizes**
160/// small (preventing OS watchdog / TDR timeouts), but the cumulative tape
161/// size is the same as full attention.  The only way to reduce training
162/// memory is gradient checkpointing (recompute attention during backward
163/// instead of storing it) — not yet implemented in this codebase.
164///
165/// Safe configurations (24 GB GPU, ViT-B):
166/// - `batch_size = 2`  →  all-layers peak ≈ 13 GB  ✓
167/// - `batch_size = 4`  →  all-layers peak ≈ 26 GB  ✗ OOM
168///
169/// The [`crate::training::learner::train`] function guards against unsafe
170/// configurations using `--vram-gb` to derive the correct limit.
171///
172/// ## Forward memory comparison (N = 2 448, H = 12, B = 8, fp32)
173///
174/// | mode         | peak fwd attn tensor     | size   |
175/// |--------------|--------------------------|--------|
176/// | full (chunk=0) | (8, 12, 2448, 2448)   | ~18 GB |
177/// | chunk=256      | (8, 12,  256, 2448)   | ~1.9 GB |
178/// | chunk=128      | (8, 12,  128, 2448)   | ~960 MB |
179/// | chunk=64       | (8, 12,   64, 2448)   | ~480 MB |
180#[derive(Module, Debug)]
181pub struct MultiHeadSelfAttention<B: Backend> {
182    q_proj:   Linear<B>,
183    k_proj:   Linear<B>,
184    v_proj:   Linear<B>,
185    out_proj: Linear<B>,
186    num_heads:  usize,
187    head_dim:   usize,
188    scale:      f32,
189    chunk_size: usize, // 0 = full attention (no chunking)
190    dropout:    Dropout,
191}
192
193impl<B: Backend> MultiHeadSelfAttention<B> {
194    /// Construct MHSA.
195    ///
196    /// * `chunk_size` – query chunk window; `0` disables chunking.
197    pub fn new(
198        d_model: usize,
199        num_heads: usize,
200        dropout: f64,
201        chunk_size: usize,
202        device: &B::Device,
203    ) -> Self {
204        assert_eq!(d_model % num_heads, 0);
205        let head_dim = d_model / num_heads;
206        Self {
207            q_proj:   LinearConfig::new(d_model, d_model).init(device),
208            k_proj:   LinearConfig::new(d_model, d_model).init(device),
209            v_proj:   LinearConfig::new(d_model, d_model).init(device),
210            out_proj: LinearConfig::new(d_model, d_model).init(device),
211            num_heads,
212            head_dim,
213            scale: (head_dim as f32).powf(-0.5),
214            chunk_size,
215            dropout: DropoutConfig::new(dropout).init(),
216        }
217    }
218
219    /// Self-attention: `(B, N, D) → (B, N, D)`.
220    ///
221    /// When `chunk_size > 0` the computation is split into `ceil(N / chunk_size)`
222    /// passes, each allocating an attention matrix of shape
223    /// `(B, H, chunk_size, N)` rather than `(B, H, N, N)`.
224    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
225        let [batch, seq, _d] = x.dims();
226        let h  = self.num_heads;
227        let hd = self.head_dim;
228
229        let q = self.q_proj.forward(x.clone())
230            .reshape([batch, seq, h, hd]).swap_dims(1, 2); // (B, H, N, hd)
231        let k = self.k_proj.forward(x.clone())
232            .reshape([batch, seq, h, hd]).swap_dims(1, 2); // (B, H, N, hd)
233        let v = self.v_proj.forward(x)
234            .reshape([batch, seq, h, hd]).swap_dims(1, 2); // (B, H, N, hd)
235
236        let ctx = if self.chunk_size == 0 || self.chunk_size >= seq {
237            // Full attention — single (B, H, N, N) matrix.
238            let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(self.scale);
239            let attn   = activation::softmax(scores, 3);
240            let attn   = self.dropout.forward(attn);
241            attn.matmul(v)  // (B, H, N, hd)
242        } else {
243            // Chunked attention — process Q in windows to cap peak memory.
244            let k_t = k.swap_dims(2, 3); // (B, H, hd, N) — shared across chunks
245            let mut chunks: Vec<Tensor<B, 4>> = Vec::new();
246            let mut start = 0;
247            while start < seq {
248                let end = (start + self.chunk_size).min(seq);
249                // q_chunk: (B, H, chunk, hd)
250                let q_chunk = q.clone().slice([0..batch, 0..h, start..end, 0..hd]);
251                // scores: (B, H, chunk, N)
252                let scores = q_chunk.matmul(k_t.clone()).mul_scalar(self.scale);
253                let attn   = activation::softmax(scores, 3);
254                let attn   = self.dropout.forward(attn);
255                // out: (B, H, chunk, hd)
256                chunks.push(attn.matmul(v.clone()));
257                start = end;
258            }
259            Tensor::cat(chunks, 2) // (B, H, N, hd)
260        };
261
262        let ctx = ctx.swap_dims(1, 2).reshape([batch, seq, h * hd]);
263        self.out_proj.forward(ctx)
264    }
265}
266
267// ===========================================================================
268// Transformer encoder block (pre-norm)
269// ===========================================================================
270
271/// Pre-norm ViT transformer block.
272///
273/// ```text
274/// x = x + Attn(LayerNorm(x))
275/// x = x + MLP(LayerNorm(x))
276/// ```
277#[derive(Module, Debug)]
278pub struct EncoderBlock<B: Backend> {
279    norm1:   LayerNorm<B>,
280    attn:    MultiHeadSelfAttention<B>,
281    norm2:   LayerNorm<B>,
282    mlp:     MlpBlock<B>,
283    dropout: Dropout,
284}
285
286impl<B: Backend> EncoderBlock<B> {
287    /// Build an encoder block.
288    pub fn new(
289        d_model: usize,
290        num_heads: usize,
291        mlp_dim: usize,
292        dropout: f64,
293        chunk_size: usize,
294        device: &B::Device,
295    ) -> Self {
296        Self {
297            norm1:   LayerNormConfig::new(d_model).init(device),
298            attn:    MultiHeadSelfAttention::new(d_model, num_heads, dropout, chunk_size, device),
299            norm2:   LayerNormConfig::new(d_model).init(device),
300            mlp:     MlpBlock::new(d_model, mlp_dim, dropout, device),
301            dropout: DropoutConfig::new(dropout).init(),
302        }
303    }
304
305    /// `(B, N, D) → (B, N, D)`.
306    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
307        let residual = x.clone();
308        let y = self.attn.forward(self.norm1.forward(x));
309        let y = self.dropout.forward(y);
310        let x = y + residual;
311
312        let residual = x.clone();
313        let y = self.mlp.forward(self.norm2.forward(x));
314        y + residual
315    }
316}
317
318// ===========================================================================
319// MAP Head (Multihead Attention Pooling)
320// ===========================================================================
321
322/// Pools a patch sequence to a single vector via a learnable probe.
323#[derive(Module, Debug)]
324pub struct MAPHead<B: Backend> {
325    probe:    Param<Tensor<B, 3>>,
326    q_proj:   Linear<B>,
327    k_proj:   Linear<B>,
328    v_proj:   Linear<B>,
329    out_proj: Linear<B>,
330    norm:     LayerNorm<B>,
331    mlp:      MlpBlock<B>,
332    num_heads: usize,
333    head_dim:  usize,
334    scale:     f32,
335}
336
337impl<B: Backend> MAPHead<B> {
338    /// Build a MAP head.
339    pub fn new(
340        d_model: usize,
341        num_heads: usize,
342        mlp_dim: usize,
343        device: &B::Device,
344    ) -> Self {
345        let head_dim = d_model / num_heads;
346        let probe = Tensor::<B, 3>::random(
347            [1, 1, d_model],
348            Distribution::Uniform(-0.02, 0.02),
349            device,
350        );
351        Self {
352            probe:    Param::from_tensor(probe),
353            q_proj:   LinearConfig::new(d_model, d_model).init(device),
354            k_proj:   LinearConfig::new(d_model, d_model).init(device),
355            v_proj:   LinearConfig::new(d_model, d_model).init(device),
356            out_proj: LinearConfig::new(d_model, d_model).init(device),
357            norm:     LayerNormConfig::new(d_model).init(device),
358            mlp:      MlpBlock::new(d_model, mlp_dim, 0.0, device),
359            num_heads,
360            head_dim,
361            scale: (head_dim as f32).powf(-0.5),
362        }
363    }
364
365    /// Pool `(B, N, D)` → `(B, D)`.
366    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
367        let [batch, seq, d] = x.dims();
368        let h  = self.num_heads;
369        let hd = self.head_dim;
370
371        let probe = self.probe.val().expand([batch, 1, d]);
372
373        let q = self.q_proj.forward(probe);
374        let k = self.k_proj.forward(x.clone());
375        let v = self.v_proj.forward(x);
376
377        let rq = |t: Tensor<B, 3>, n: usize| t.reshape([batch, n, h, hd]).swap_dims(1, 2);
378        let q = rq(q, 1);
379        let k = rq(k, seq);
380        let v = rq(v, seq);
381
382        let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(self.scale);
383        let attn   = activation::softmax(scores, 3);
384
385        let ctx = attn
386            .matmul(v)
387            .swap_dims(1, 2)
388            .reshape([batch, 1, h * hd]);
389
390        let ctx = self.out_proj.forward(ctx);
391        let ctx_2d = ctx.squeeze(1); // (B, D)
392
393        let normed  = self.norm.forward(ctx_2d.clone().unsqueeze_dim(1));
394        let mlp_out = self.mlp.forward(normed).squeeze(1);
395        ctx_2d + mlp_out
396    }
397}
398
399// ===========================================================================
400// Full sensor encoder
401// ===========================================================================
402
403/// Vision Transformer sensor encoder.
404///
405/// Stores `use_map: bool` instead of the `PoolType` enum because burn's
406/// `#[derive(Module)]` requires all struct fields to implement `Module<B>`.
407#[derive(Module, Debug)]
408pub struct SensorEncoder<B: Backend> {
409    patch_embed: PatchEmbedding<B>,
410    pos_embed:   Param<Tensor<B, 3>>,
411    blocks:      Vec<EncoderBlock<B>>,
412    norm:        LayerNorm<B>,
413    map_head:    Option<MAPHead<B>>,
414    dropout:     Dropout,
415    d_model:     usize,
416}
417
418impl<B: Backend> SensorEncoder<B> {
419    /// Construct a sensor encoder from a [`SensorEncoderConfig`].
420    pub fn new(cfg: &SensorEncoderConfig, device: &B::Device) -> Self {
421        let num_patches = cfg.num_patches();
422
423        let patch_embed = PatchEmbedding::new(
424            1,
425            cfg.d_model,
426            cfg.patch_h,
427            cfg.patch_w,
428            cfg.time_steps,
429            cfg.num_channels,
430            device,
431        );
432
433        let pos_embed = Tensor::<B, 3>::random(
434            [1, num_patches, cfg.d_model],
435            Distribution::Normal(0.0, (1.0 / cfg.d_model as f64).sqrt()),
436            device,
437        );
438
439        let blocks: Vec<EncoderBlock<B>> = (0..cfg.depth)
440            .map(|_| EncoderBlock::new(cfg.d_model, cfg.num_heads, cfg.mlp_dim, cfg.dropout, cfg.attn_chunk_size, device))
441            .collect();
442
443        let norm = LayerNormConfig::new(cfg.d_model).init(device);
444
445        let map_head = if cfg.pool_type == PoolType::Map {
446            Some(MAPHead::new(cfg.d_model, cfg.num_heads, cfg.mlp_dim, device))
447        } else {
448            None
449        };
450
451        Self {
452            patch_embed,
453            pos_embed: Param::from_tensor(pos_embed),
454            blocks,
455            norm,
456            map_head,
457            dropout: DropoutConfig::new(cfg.dropout).init(),
458            d_model: cfg.d_model,
459        }
460    }
461
462    /// Encode sensor data. Input `(B, T, C)` → output L2-norm embedding `(B, D)`.
463    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
464        let [batch, _t, _c] = x.dims();
465
466        // (B, T, C) → (B, 1, T, C)
467        let x = x.unsqueeze_dim(1);
468
469        // Patch embed → (B, N, D)
470        let mut tokens = self.patch_embed.forward(x);
471
472        // Add positional embeddings.
473        let num_patches = tokens.dims()[1];
474        let pos = self.pos_embed.val().expand([batch, num_patches, self.d_model]);
475        tokens = tokens + pos;
476        tokens = self.dropout.forward(tokens);
477
478        // Transformer blocks.
479        for block in &self.blocks {
480            tokens = block.forward(tokens);
481        }
482        tokens = self.norm.forward(tokens);
483
484        // Pool.
485        let embedding: Tensor<B, 2> = match &self.map_head {
486            Some(map) => map.forward(tokens),
487            None => tokens.mean_dim(1).squeeze(1),
488        };
489
490        l2_normalize(embedding)
491    }
492}
493
494// ===========================================================================
495// L2 normalisation
496// ===========================================================================
497
498/// L2-normalise each row of `(B, D)` to unit norm.
499pub fn l2_normalize<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
500    let [batch, d] = x.dims();
501    let norm = x.clone().powf_scalar(2.0).sum_dim(1).sqrt().clamp_min(1e-12);
502    x / norm.expand([batch, d])
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SensorEncoderConfig;
    use burn::backend::NdArray;

    type B = NdArray;

    /// Minimal GAP config: a 40 × 4 grid with (10, 2) patches → 4 × 2 = 8 tokens.
    fn tiny_cfg() -> SensorEncoderConfig {
        SensorEncoderConfig {
            time_steps: 40,
            num_channels: 4,
            patch_h: 10,
            patch_w: 2,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            pool_type: PoolType::Gap,
            head_zeroinit: false,
            attn_chunk_size: 0, // tiny test — no chunking needed
        }
    }

    #[test]
    fn test_patch_embedding_shape() {
        let device = Default::default();
        let cfg = tiny_cfg();
        let embed = PatchEmbedding::<B>::new(
            1,
            cfg.d_model,
            cfg.patch_h,
            cfg.patch_w,
            cfg.time_steps,
            cfg.num_channels,
            &device,
        );

        let tokens = embed.forward(Tensor::<B, 4>::zeros([2, 1, 40, 4], &device));

        // (40 / 10) * (4 / 2) = 4 * 2 = 8 patch tokens per sample.
        assert_eq!(tokens.dims(), [2, (40 / 10) * (4 / 2), cfg.d_model]);
    }

    #[test]
    fn test_encoder_forward_shape() {
        let device = Default::default();
        let cfg = tiny_cfg();
        let encoder = SensorEncoder::<B>::new(&cfg, &device);

        let embedding = encoder.forward(Tensor::<B, 3>::zeros([2, 40, 4], &device));

        assert_eq!(embedding.dims(), [2, cfg.d_model]);
    }
}