Skip to main content

sensorlm/model/
sensorlm.rs

1//! Two-tower SensorLM model and Burn training/validation step wrappers.
2//!
3//! # Two-tower architecture
4//!
5//! ```text
6//!  sensor_tensor (B,T,C) ──► SensorEncoder ──► z_s (B,D) ─┐
7//!                                                           ├─► SigLIP loss
8//!  token_ids (B,L)       ──► TextEncoder   ──► z_t (B,D) ─┘
9//!
10//!  S[i,j] = temperature · dot(z_s[i], z_t[j]) + bias
11//!  L = -mean_ij[ log(sigmoid( y[i,j] · S[i,j] )) ]
12//!  y[i,j] = +1 if i==j, -1 otherwise
13//! ```
14
15use burn::{
16    module::{Module, Param},
17    tensor::{
18        backend::{AutodiffBackend, Backend},
19        ElementConversion, Int, Tensor,
20    },
21    train::{
22        metric::{Adaptor, LossInput},
23        TrainOutput, TrainStep, ValidStep,
24    },
25    data::dataloader::batcher::Batcher,
26};
27
28use crate::config::SensorLMConfig;
29use crate::data::dataset::SensorTextItem;
30
31use crate::loss::siglip_loss;
32use crate::model::sensor_encoder::SensorEncoder;
33use crate::model::text_encoder::TextEncoder;
34
35// ===========================================================================
36// Model
37// ===========================================================================
38
/// The combined SensorLM two-tower model.
///
/// Holds one encoder per modality plus two scalar `Param`s (`log_temperature`, `bias`).
/// [`SensorLMModel::similarity_matrix`] uses these to scale and shift the sensor–text
/// dot-product matrix before it is passed to the SigLIP loss.
#[derive(Module, Debug)]
pub struct SensorLMModel<B: Backend> {
    /// ViT sensor encoder.
    pub sensor_encoder: SensorEncoder<B>,
    /// Text transformer encoder.
    pub text_encoder: TextEncoder<B>,
    /// Log-temperature scalar, stored as a shape-`(1,)` tensor
    /// (temperature = exp(log_temperature) > 0).
    pub log_temperature: Param<Tensor<B, 1>>,
    /// SigLIP bias scalar, stored as a shape-`(1,)` tensor.
    pub bias: Param<Tensor<B, 1>>,
}
51
52impl<B: Backend> SensorLMModel<B> {
53    /// Construct from a [`SensorLMConfig`].
54    pub fn new(cfg: &SensorLMConfig, device: &B::Device) -> Self {
55        let log_temp = Tensor::<B, 1>::from_floats(
56            [cfg.temperature_init.ln()],
57            device,
58        );
59        let bias = Tensor::<B, 1>::from_floats([cfg.bias_init], device);
60
61        Self {
62            sensor_encoder: SensorEncoder::new(&cfg.sensor_encoder, device),
63            text_encoder:   TextEncoder::new(&cfg.text_encoder, device),
64            log_temperature: Param::from_tensor(log_temp),
65            bias:            Param::from_tensor(bias),
66        }
67    }
68
69    /// Encode sensor data → `(B, D)` L2-normalised.
70    pub fn encode_sensor(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
71        self.sensor_encoder.forward(x)
72    }
73
74    /// Encode text → `(B, D)` L2-normalised.
75    pub fn encode_text(
76        &self,
77        input_ids: Tensor<B, 2, Int>,
78        attention_mask: Tensor<B, 2, Int>,
79    ) -> Tensor<B, 2> {
80        self.text_encoder.forward(input_ids, attention_mask)
81    }
82
83    /// Compute `(B, B)` similarity matrix.
84    pub fn similarity_matrix(
85        &self,
86        z_sensor: Tensor<B, 2>,
87        z_text: Tensor<B, 2>,
88    ) -> Tensor<B, 2> {
89        // into_scalar() returns B::FloatElem; .elem::<f32>() converts to f32.
90        let temperature: f32 = self.log_temperature.val().exp().into_scalar().elem();
91        let bias: f32        = self.bias.val().into_scalar().elem();
92        z_sensor.matmul(z_text.transpose())
93            .mul_scalar(temperature)
94            .add_scalar(bias)
95    }
96
97    /// Full forward pass computing the SigLIP loss.
98    pub fn forward(
99        &self,
100        sensor: Tensor<B, 3>,
101        input_ids: Tensor<B, 2, Int>,
102        attention_mask: Tensor<B, 2, Int>,
103    ) -> SensorLMOutput<B> {
104        let z_sensor = self.encode_sensor(sensor);
105        let z_text   = self.encode_text(input_ids, attention_mask);
106        let logits   = self.similarity_matrix(z_sensor, z_text);
107        let loss     = siglip_loss(logits.clone());
108        SensorLMOutput { loss, logits }
109    }
110}
111
112// ===========================================================================
113// Output type
114// ===========================================================================
115
/// Output of a SensorLM forward pass.
///
/// Returned by both the training and validation steps. Burn's `LossMetric` reads
/// the `loss` field through the `Adaptor<LossInput<B>>` impl.
#[derive(Debug)]
pub struct SensorLMOutput<B: Backend> {
    /// Scalar SigLIP loss `(1,)`.
    pub loss: Tensor<B, 1>,
    /// `(B, B)` similarity logits (sensor rows × text columns).
    pub logits: Tensor<B, 2>,
}
124
125// Teach burn's LossMetric how to extract the loss from our output type.
126impl<B: Backend> Adaptor<LossInput<B>> for SensorLMOutput<B> {
127    fn adapt(&self) -> LossInput<B> {
128        LossInput::new(self.loss.clone())
129    }
130}
131
132// ===========================================================================
133// Batch type
134// ===========================================================================
135
/// A collated training batch, as produced by `SensorLMBatcher`.
#[derive(Debug, Clone)]
pub struct SensorLMBatch<B: Backend> {
    /// `(B, T, C)` sensor data.
    pub sensor: Tensor<B, 3>,
    /// `(B, L)` token IDs.
    pub input_ids: Tensor<B, 2, Int>,
    /// `(B, L)` attention mask.
    // NOTE(review): presumably 1 = real token, 0 = padding (matches the unit test's
    // inputs); confirm against TextEncoder.
    pub attention_mask: Tensor<B, 2, Int>,
}
146
147// ===========================================================================
148// Burn TrainStep / ValidStep
149// ===========================================================================
150
151impl<B: AutodiffBackend> TrainStep<SensorLMBatch<B>, SensorLMOutput<B>>
152    for SensorLMModel<B>
153{
154    fn step(&self, batch: SensorLMBatch<B>) -> TrainOutput<SensorLMOutput<B>> {
155        let output = self.forward(batch.sensor, batch.input_ids, batch.attention_mask);
156        TrainOutput::new(self, output.loss.backward(), output)
157    }
158}
159
160impl<B: Backend> ValidStep<SensorLMBatch<B>, SensorLMOutput<B>> for SensorLMModel<B> {
161    fn step(&self, batch: SensorLMBatch<B>) -> SensorLMOutput<B> {
162        self.forward(batch.sensor, batch.input_ids, batch.attention_mask)
163    }
164}
165
166// ===========================================================================
167// Batcher
168// ===========================================================================
169
/// Converts a `Vec<SensorTextItem>` into a `SensorLMBatch` whose tensors live on the
/// configured `device` (whatever backend device it was constructed with).
#[derive(Clone)]
pub struct SensorLMBatcher<B: Backend> {
    /// Device the collated tensors are created on.
    device:       B::Device,
    /// Per-item sensor length along the time axis (`T`).
    time_steps:   usize,
    /// Per-item number of sensor channels (`C`).
    num_channels: usize,
    /// Per-item token / attention-mask length (`L`).
    max_seq_len:  usize,
}
178
179impl<B: Backend> SensorLMBatcher<B> {
180    /// Create a new batcher.
181    pub fn new(
182        device: B::Device,
183        time_steps: usize,
184        num_channels: usize,
185        max_seq_len: usize,
186    ) -> Self {
187        Self { device, time_steps, num_channels, max_seq_len }
188    }
189}
190
191impl<B: Backend> Batcher<SensorTextItem, SensorLMBatch<B>> for SensorLMBatcher<B> {
192    fn batch(&self, items: Vec<SensorTextItem>) -> SensorLMBatch<B> {
193        let b = items.len();
194        let t = self.time_steps;
195        let c = self.num_channels;
196        let l = self.max_seq_len;
197
198        let sensor_flat: Vec<f32> = items.iter()
199            .flat_map(|it| it.sensor.iter().copied()).collect();
200        let token_flat: Vec<i32> = items.iter()
201            .flat_map(|it| it.token_ids.iter().copied()).collect();
202        let mask_flat: Vec<i32> = items.iter()
203            .flat_map(|it| it.attention_mask.iter().copied()).collect();
204
205        let sensor = Tensor::<B, 1>::from_floats(sensor_flat.as_slice(), &self.device)
206            .reshape([b, t, c]);
207        let input_ids = Tensor::<B, 1, Int>::from_ints(token_flat.as_slice(), &self.device)
208            .reshape([b, l]);
209        let attention_mask = Tensor::<B, 1, Int>::from_ints(mask_flat.as_slice(), &self.device)
210            .reshape([b, l]);
211
212        SensorLMBatch { sensor, input_ids, attention_mask }
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use crate::config::{PoolType, SensorEncoderConfig, SensorLMConfig, TextEncoderConfig};

    type B = NdArray;

    /// A deliberately small configuration so the forward pass runs quickly on CPU.
    fn tiny_config() -> SensorLMConfig {
        let sensor_encoder = SensorEncoderConfig {
            time_steps: 40,
            num_channels: 4,
            patch_h: 10,
            patch_w: 2,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            pool_type: PoolType::Gap,
            head_zeroinit: false,
            attn_chunk_size: 0,
        };
        let text_encoder = TextEncoderConfig {
            vocab_size: 100,
            max_seq_len: 16,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            out_dim: Some(32),
        };

        SensorLMConfig {
            sensor_encoder,
            text_encoder,
            embed_dim: 32,
            temperature_init: 10.0,
            bias_init: -10.0,
        }
    }

    #[test]
    fn test_sensorlm_forward() {
        let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
        let model = SensorLMModel::<B>::new(&tiny_config(), &device);

        // Two zero-valued (40 × 4) sensor windows paired with two 4-token captions.
        let sensor = Tensor::<B, 3>::zeros([2, 40, 4], &device);
        let input_ids = Tensor::<B, 2, Int>::from_ints([[1, 2, 3, 0], [4, 5, 6, 7]], &device);
        let attention_mask =
            Tensor::<B, 2, Int>::from_ints([[1, 1, 1, 0], [1, 1, 1, 1]], &device);

        let SensorLMOutput { loss, logits } = model.forward(sensor, input_ids, attention_mask);
        assert_eq!(logits.dims(), [2, 2]);

        let loss_value: f32 = loss.into_scalar();
        assert!(!loss_value.is_nan(), "Loss must not be NaN");
    }
}