sensorlm/model/mod.rs
1//! Model architecture for SensorLM.
2//!
3//! # Module structure
4//!
5//! | Module | Contents |
6//! |--------|----------|
7//! | [`sensor_encoder`] | ViT sensor encoder with rectangular patch embedding and MAP pooling |
8//! | [`text_encoder`] | 12-layer text transformer encoder |
9//! | [`sensorlm`] | Two-tower SensorLM model + SigLIP training step |
10//!
11//! # Architecture diagram (detailed)
12//!
13//! ```text
14//! ┌─────────────────────────────────────────────────────────────────────┐
15//! │ SENSOR ENCODER (ViT-B/10/2) │
16//! │ │
17//! │ Input tensor: (B, 1440, 34) [batch × time × channels] │
18//! │ │ │
19//! │ ▼ reshape → (B, 1, 1440, 34) [treat as 1-channel image] │
20//! │ │
21//! │ ┌─────────────────────────────────────────┐ │
22//! │ │ PatchEmbedding │ │
23//! │ │ Conv2d(in=1, out=768, k=(10,2), s=(10,2)) │
24//! │ │ Output: (B, 768, 144, 17) │ │
25//! │ │ Reshape: (B, 144*17=2448, 768) │ │
26//! │ └─────────────────────────────────────────┘ │
27//! │ │ │
28//! │ ▼ + LearnedPositionalEmbedding(2448, 768) │
29//! │ │
30//! │ ┌─────────────────────────────────────────┐ ×12 │
31//! │ │ TransformerBlock │ │
32//! │ │ ├─ LayerNorm │ │
33//! │ │ ├─ MultiHeadSelfAttention (12 heads) │ │
34//! │ │ ├─ residual + LayerNorm │ │
35//! │ │ └─ MLP (768 → 3072 → 768, GELU) │ │
36//! │ └─────────────────────────────────────────┘ │
37//! │ │ │
38//! │ ▼ Sequence (B, 2448, 768) │
39//! │ │
40//! │ ┌─────────────────────────────────────────┐ │
41//! │ │ MAPHead (Multihead Attention Pooling) │ │
42//! │ │ Learnable probe (1, 1, 768) │ │
43//! │ │ Cross-attn: probe queries ← seq k/v │ │
44//! │ │ Output: (B, 768) │ │
45//! │ └─────────────────────────────────────────┘ │
46//! │ │ │
47//! │ ▼ L2-normalise → (B, 768) │
48//! └─────────────────────────────────────────────────────────────────────┘
49//!
50//! ┌─────────────────────────────────────────────────────────────────────┐
51//! │ TEXT ENCODER (ViT-B) │
52//! │ │
53//! │ Input: token IDs (B, L) + attention mask (B, L) │
54//! │ │ │
55//! │ ▼ TokenEmbedding(32000, 768) + PositionalEmbedding(1024, 768)│
56//! │ │
57//! │ ┌─────────────────────────────────────────┐ ×12 │
58//! │ │ TransformerBlock (same structure) │ │
59//! │ └─────────────────────────────────────────┘ │
60//! │ │ │
61//! │ ▼ Masked mean-pool → (B, 768) │
62//! │ │ │
63//! │ ▼ Linear projection → (B, 768) │
64//! │ │ │
65//! │ ▼ L2-normalise → (B, 768) │
66//! └─────────────────────────────────────────────────────────────────────┘
67//!
68//! ┌─────────────────────────────────────────────────────────────────────┐
69//! │ SIGLIP CONTRASTIVE LOSS │
70//! │ │
71//! │ S[i,j] = temperature * dot(z_sensor_i, z_text_j) + bias │
72//! │ │
73//! │ y[i,j] = +1 if i == j (positive pair) │
74//! │ -1 if i != j (negative pair) │
75//! │ │
76//! │ L = -mean_ij( log(sigmoid(y[i,j] * S[i,j])) ) │
77//! └─────────────────────────────────────────────────────────────────────┘
78//! ```
79
80pub mod sensor_encoder;
81pub mod sensorlm;
82pub mod text_encoder;