1use burn::{
16 module::{Module, Param},
17 tensor::{
18 backend::{AutodiffBackend, Backend},
19 ElementConversion, Int, Tensor,
20 },
21 train::{
22 metric::{Adaptor, LossInput},
23 TrainOutput, TrainStep, ValidStep,
24 },
25 data::dataloader::batcher::Batcher,
26};
27
28use crate::config::SensorLMConfig;
29use crate::data::dataset::SensorTextItem;
30
31use crate::loss::siglip_loss;
32use crate::model::sensor_encoder::SensorEncoder;
33use crate::model::text_encoder::TextEncoder;
34
35#[derive(Module, Debug)]
41pub struct SensorLMModel<B: Backend> {
42 pub sensor_encoder: SensorEncoder<B>,
44 pub text_encoder: TextEncoder<B>,
46 pub log_temperature: Param<Tensor<B, 1>>,
48 pub bias: Param<Tensor<B, 1>>,
50}
51
52impl<B: Backend> SensorLMModel<B> {
53 pub fn new(cfg: &SensorLMConfig, device: &B::Device) -> Self {
55 let log_temp = Tensor::<B, 1>::from_floats(
56 [cfg.temperature_init.ln()],
57 device,
58 );
59 let bias = Tensor::<B, 1>::from_floats([cfg.bias_init], device);
60
61 Self {
62 sensor_encoder: SensorEncoder::new(&cfg.sensor_encoder, device),
63 text_encoder: TextEncoder::new(&cfg.text_encoder, device),
64 log_temperature: Param::from_tensor(log_temp),
65 bias: Param::from_tensor(bias),
66 }
67 }
68
69 pub fn encode_sensor(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
71 self.sensor_encoder.forward(x)
72 }
73
74 pub fn encode_text(
76 &self,
77 input_ids: Tensor<B, 2, Int>,
78 attention_mask: Tensor<B, 2, Int>,
79 ) -> Tensor<B, 2> {
80 self.text_encoder.forward(input_ids, attention_mask)
81 }
82
83 pub fn similarity_matrix(
85 &self,
86 z_sensor: Tensor<B, 2>,
87 z_text: Tensor<B, 2>,
88 ) -> Tensor<B, 2> {
89 let temperature: f32 = self.log_temperature.val().exp().into_scalar().elem();
91 let bias: f32 = self.bias.val().into_scalar().elem();
92 z_sensor.matmul(z_text.transpose())
93 .mul_scalar(temperature)
94 .add_scalar(bias)
95 }
96
97 pub fn forward(
99 &self,
100 sensor: Tensor<B, 3>,
101 input_ids: Tensor<B, 2, Int>,
102 attention_mask: Tensor<B, 2, Int>,
103 ) -> SensorLMOutput<B> {
104 let z_sensor = self.encode_sensor(sensor);
105 let z_text = self.encode_text(input_ids, attention_mask);
106 let logits = self.similarity_matrix(z_sensor, z_text);
107 let loss = siglip_loss(logits.clone());
108 SensorLMOutput { loss, logits }
109 }
110}
111
112#[derive(Debug)]
118pub struct SensorLMOutput<B: Backend> {
119 pub loss: Tensor<B, 1>,
121 pub logits: Tensor<B, 2>,
123}
124
125impl<B: Backend> Adaptor<LossInput<B>> for SensorLMOutput<B> {
127 fn adapt(&self) -> LossInput<B> {
128 LossInput::new(self.loss.clone())
129 }
130}
131
132#[derive(Debug, Clone)]
138pub struct SensorLMBatch<B: Backend> {
139 pub sensor: Tensor<B, 3>,
141 pub input_ids: Tensor<B, 2, Int>,
143 pub attention_mask: Tensor<B, 2, Int>,
145}
146
147impl<B: AutodiffBackend> TrainStep<SensorLMBatch<B>, SensorLMOutput<B>>
152 for SensorLMModel<B>
153{
154 fn step(&self, batch: SensorLMBatch<B>) -> TrainOutput<SensorLMOutput<B>> {
155 let output = self.forward(batch.sensor, batch.input_ids, batch.attention_mask);
156 TrainOutput::new(self, output.loss.backward(), output)
157 }
158}
159
160impl<B: Backend> ValidStep<SensorLMBatch<B>, SensorLMOutput<B>> for SensorLMModel<B> {
161 fn step(&self, batch: SensorLMBatch<B>) -> SensorLMOutput<B> {
162 self.forward(batch.sensor, batch.input_ids, batch.attention_mask)
163 }
164}
165
166#[derive(Clone)]
172pub struct SensorLMBatcher<B: Backend> {
173 device: B::Device,
174 time_steps: usize,
175 num_channels: usize,
176 max_seq_len: usize,
177}
178
179impl<B: Backend> SensorLMBatcher<B> {
180 pub fn new(
182 device: B::Device,
183 time_steps: usize,
184 num_channels: usize,
185 max_seq_len: usize,
186 ) -> Self {
187 Self { device, time_steps, num_channels, max_seq_len }
188 }
189}
190
191impl<B: Backend> Batcher<SensorTextItem, SensorLMBatch<B>> for SensorLMBatcher<B> {
192 fn batch(&self, items: Vec<SensorTextItem>) -> SensorLMBatch<B> {
193 let b = items.len();
194 let t = self.time_steps;
195 let c = self.num_channels;
196 let l = self.max_seq_len;
197
198 let sensor_flat: Vec<f32> = items.iter()
199 .flat_map(|it| it.sensor.iter().copied()).collect();
200 let token_flat: Vec<i32> = items.iter()
201 .flat_map(|it| it.token_ids.iter().copied()).collect();
202 let mask_flat: Vec<i32> = items.iter()
203 .flat_map(|it| it.attention_mask.iter().copied()).collect();
204
205 let sensor = Tensor::<B, 1>::from_floats(sensor_flat.as_slice(), &self.device)
206 .reshape([b, t, c]);
207 let input_ids = Tensor::<B, 1, Int>::from_ints(token_flat.as_slice(), &self.device)
208 .reshape([b, l]);
209 let attention_mask = Tensor::<B, 1, Int>::from_ints(mask_flat.as_slice(), &self.device)
210 .reshape([b, l]);
211
212 SensorLMBatch { sensor, input_ids, attention_mask }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use crate::config::{SensorEncoderConfig, TextEncoderConfig, PoolType, SensorLMConfig};

    type B = NdArray;

    /// Small sensor-encoder configuration: 40 time steps × 4 channels.
    fn tiny_sensor_encoder() -> SensorEncoderConfig {
        SensorEncoderConfig {
            time_steps: 40,
            num_channels: 4,
            patch_h: 10,
            patch_w: 2,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            pool_type: PoolType::Gap,
            head_zeroinit: false,
            attn_chunk_size: 0,
        }
    }

    /// Small text-encoder configuration with a 100-token vocabulary.
    fn tiny_text_encoder() -> TextEncoderConfig {
        TextEncoderConfig {
            vocab_size: 100,
            max_seq_len: 16,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            out_dim: Some(32),
        }
    }

    /// Full model configuration assembled from the two tiny encoder configs.
    fn tiny_config() -> SensorLMConfig {
        SensorLMConfig {
            sensor_encoder: tiny_sensor_encoder(),
            text_encoder: tiny_text_encoder(),
            embed_dim: 32,
            temperature_init: 10.0,
            bias_init: -10.0,
        }
    }

    /// A batch of two pairs must yield a 2×2 logit matrix and a non-NaN loss.
    #[test]
    fn test_sensorlm_forward() {
        let device: <B as burn::tensor::backend::Backend>::Device = Default::default();
        let model = SensorLMModel::<B>::new(&tiny_config(), &device);

        let sensor = Tensor::<B, 3>::zeros([2, 40, 4], &device);
        let ids = Tensor::<B, 2, Int>::from_ints([[1, 2, 3, 0], [4, 5, 6, 7]], &device);
        let mask = Tensor::<B, 2, Int>::from_ints([[1, 1, 1, 0], [1, 1, 1, 1]], &device);

        let out = model.forward(sensor, ids, mask);

        assert_eq!(out.logits.dims(), [2, 2]);

        let loss: f32 = out.loss.into_scalar();
        assert!(!loss.is_nan(), "Loss must not be NaN");
    }
}