Skip to main content

sensorlm/model/
mod.rs

1//! Model architecture for SensorLM.
2//!
3//! # Module structure
4//!
5//! | Module | Contents |
6//! |--------|----------|
7//! | [`sensor_encoder`] | ViT sensor encoder with rectangular patch embedding and MAP pooling |
8//! | [`text_encoder`]   | 12-layer text transformer encoder |
9//! | [`sensorlm`]       | Two-tower SensorLM model + SigLIP training step |
10//!
11//! # Architecture diagram (detailed)
12//!
13//! ```text
14//! ┌─────────────────────────────────────────────────────────────────────┐
15//! │                      SENSOR ENCODER (ViT-B/10/2)                    │
16//! │                                                                     │
17//! │  Input tensor: (B, 1440, 34)  [batch × time × channels]            │
18//! │       │                                                             │
19//! │       ▼  reshape → (B, 1, 1440, 34)  [treat as 1-channel image]    │
20//! │                                                                     │
21//! │  ┌─────────────────────────────────────────┐                        │
22//! │  │  PatchEmbedding                         │                        │
23//! │  │  Conv2d(in=1, out=768, k=(10,2), s=(10,2))                      │
24//! │  │  Output: (B, 768, 144, 17)             │                        │
25//! │  │  Reshape: (B, 144*17=2448, 768)        │                        │
26//! │  └─────────────────────────────────────────┘                        │
27//! │       │                                                             │
28//! │       ▼  + LearnedPositionalEmbedding(2448, 768)                    │
29//! │                                                                     │
30//! │  ┌─────────────────────────────────────────┐  ×12                   │
31//! │  │  TransformerBlock                       │                        │
32//! │  │  ├─ LayerNorm                           │                        │
33//! │  │  ├─ MultiHeadSelfAttention (12 heads)   │                        │
34//! │  │  ├─ residual + LayerNorm                │                        │
35//! │  │  └─ MLP (768 → 3072 → 768, GELU)        │                        │
36//! │  └─────────────────────────────────────────┘                        │
37//! │       │                                                             │
38//! │       ▼  Sequence (B, 2448, 768)                                    │
39//! │                                                                     │
40//! │  ┌─────────────────────────────────────────┐                        │
41//! │  │  MAPHead (Multihead Attention Pooling)  │                        │
42//! │  │  Learnable probe (1, 1, 768)            │                        │
43//! │  │  Cross-attn: probe queries ← seq k/v    │                        │
44//! │  │  Output: (B, 768)                       │                        │
45//! │  └─────────────────────────────────────────┘                        │
46//! │       │                                                             │
47//! │       ▼  L2-normalise → (B, 768)                                    │
48//! └─────────────────────────────────────────────────────────────────────┘
49//!
50//! ┌─────────────────────────────────────────────────────────────────────┐
51//! │                       TEXT ENCODER (ViT-B)                          │
52//! │                                                                     │
53//! │  Input: token IDs (B, L)  +  attention mask (B, L)                 │
54//! │       │                                                             │
55//! │       ▼  TokenEmbedding(32000, 768) + PositionalEmbedding(1024, 768)│
56//! │                                                                     │
57//! │  ┌─────────────────────────────────────────┐  ×12                   │
58//! │  │  TransformerBlock (same structure)      │                        │
59//! │  └─────────────────────────────────────────┘                        │
60//! │       │                                                             │
61//! │       ▼  Masked mean-pool → (B, 768)                                │
62//! │       │                                                             │
63//! │       ▼  Linear projection → (B, 768)                               │
64//! │       │                                                             │
65//! │       ▼  L2-normalise → (B, 768)                                    │
66//! └─────────────────────────────────────────────────────────────────────┘
67//!
68//! ┌─────────────────────────────────────────────────────────────────────┐
69//! │                     SIGLIP CONTRASTIVE LOSS                         │
70//! │                                                                     │
71//! │  S[i,j] = temperature * dot(z_sensor_i, z_text_j) + bias           │
72//! │                                                                     │
73//! │  y[i,j] = +1  if  i == j  (positive pair)                          │
74//! │           -1  if  i != j  (negative pair)                          │
75//! │                                                                     │
76//! │  L = -mean_ij( log(sigmoid(y[i,j] * S[i,j])) )                     │
77//! └─────────────────────────────────────────────────────────────────────┘
78//! ```
79
80pub mod sensor_encoder;
81pub mod sensorlm;
82pub mod text_encoder;