1use burn::{
27 module::{Module, Param},
28 nn::{
29 conv::{Conv2d, Conv2dConfig},
30 Dropout, DropoutConfig, LayerNorm, LayerNormConfig, Linear, LinearConfig,
31 },
32 tensor::{
33 activation,
34 backend::Backend,
35 Distribution, Tensor,
36 },
37};
38
39use crate::config::{PoolType, SensorEncoderConfig};
40
41#[derive(Module, Debug)]
49pub struct PatchEmbedding<B: Backend> {
50 proj: Conv2d<B>,
51 num_patches_t: usize,
52 num_patches_c: usize,
53 d_model: usize,
54}
55
56impl<B: Backend> PatchEmbedding<B> {
57 pub fn new(
59 in_channels: usize,
60 d_model: usize,
61 patch_h: usize,
62 patch_w: usize,
63 time_steps: usize,
64 num_channels: usize,
65 device: &B::Device,
66 ) -> Self {
67 let proj = Conv2dConfig::new(
69 [in_channels, d_model],
70 [patch_h, patch_w],
71 )
72 .with_stride([patch_h, patch_w])
73 .with_padding(burn::nn::PaddingConfig2d::Valid)
74 .with_bias(true)
75 .init(device);
76
77 let num_patches_t = time_steps / patch_h;
78 let num_patches_c = (num_channels + patch_w - 1) / patch_w;
79
80 Self {
81 proj,
82 num_patches_t,
83 num_patches_c,
84 d_model,
85 }
86 }
87
88 pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 3> {
90 let out = self.proj.forward(x); let [batch, d, _pt, _pc] = out.dims();
92 let num_patches = self.num_patches_t * self.num_patches_c;
93 out.reshape([batch, d, num_patches]).swap_dims(1, 2)
95 }
96
97 pub fn num_patches(&self) -> usize {
99 self.num_patches_t * self.num_patches_c
100 }
101}
102
103#[derive(Module, Debug)]
109pub struct MlpBlock<B: Backend> {
110 fc1: Linear<B>,
111 fc2: Linear<B>,
112 dropout: Dropout,
113}
114
115impl<B: Backend> MlpBlock<B> {
116 pub fn new(d_model: usize, mlp_dim: usize, dropout: f64, device: &B::Device) -> Self {
118 Self {
119 fc1: LinearConfig::new(d_model, mlp_dim).init(device),
120 fc2: LinearConfig::new(mlp_dim, d_model).init(device),
121 dropout: DropoutConfig::new(dropout).init(),
122 }
123 }
124
125 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
127 let x = self.fc1.forward(x);
128 let x = activation::gelu(x);
129 let x = self.dropout.forward(x);
130 let x = self.fc2.forward(x);
131 self.dropout.forward(x)
132 }
133}
134
135#[derive(Module, Debug)]
181pub struct MultiHeadSelfAttention<B: Backend> {
182 q_proj: Linear<B>,
183 k_proj: Linear<B>,
184 v_proj: Linear<B>,
185 out_proj: Linear<B>,
186 num_heads: usize,
187 head_dim: usize,
188 scale: f32,
189 chunk_size: usize, dropout: Dropout,
191}
192
193impl<B: Backend> MultiHeadSelfAttention<B> {
194 pub fn new(
198 d_model: usize,
199 num_heads: usize,
200 dropout: f64,
201 chunk_size: usize,
202 device: &B::Device,
203 ) -> Self {
204 assert_eq!(d_model % num_heads, 0);
205 let head_dim = d_model / num_heads;
206 Self {
207 q_proj: LinearConfig::new(d_model, d_model).init(device),
208 k_proj: LinearConfig::new(d_model, d_model).init(device),
209 v_proj: LinearConfig::new(d_model, d_model).init(device),
210 out_proj: LinearConfig::new(d_model, d_model).init(device),
211 num_heads,
212 head_dim,
213 scale: (head_dim as f32).powf(-0.5),
214 chunk_size,
215 dropout: DropoutConfig::new(dropout).init(),
216 }
217 }
218
219 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
225 let [batch, seq, _d] = x.dims();
226 let h = self.num_heads;
227 let hd = self.head_dim;
228
229 let q = self.q_proj.forward(x.clone())
230 .reshape([batch, seq, h, hd]).swap_dims(1, 2); let k = self.k_proj.forward(x.clone())
232 .reshape([batch, seq, h, hd]).swap_dims(1, 2); let v = self.v_proj.forward(x)
234 .reshape([batch, seq, h, hd]).swap_dims(1, 2); let ctx = if self.chunk_size == 0 || self.chunk_size >= seq {
237 let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(self.scale);
239 let attn = activation::softmax(scores, 3);
240 let attn = self.dropout.forward(attn);
241 attn.matmul(v) } else {
243 let k_t = k.swap_dims(2, 3); let mut chunks: Vec<Tensor<B, 4>> = Vec::new();
246 let mut start = 0;
247 while start < seq {
248 let end = (start + self.chunk_size).min(seq);
249 let q_chunk = q.clone().slice([0..batch, 0..h, start..end, 0..hd]);
251 let scores = q_chunk.matmul(k_t.clone()).mul_scalar(self.scale);
253 let attn = activation::softmax(scores, 3);
254 let attn = self.dropout.forward(attn);
255 chunks.push(attn.matmul(v.clone()));
257 start = end;
258 }
259 Tensor::cat(chunks, 2) };
261
262 let ctx = ctx.swap_dims(1, 2).reshape([batch, seq, h * hd]);
263 self.out_proj.forward(ctx)
264 }
265}
266
267#[derive(Module, Debug)]
278pub struct EncoderBlock<B: Backend> {
279 norm1: LayerNorm<B>,
280 attn: MultiHeadSelfAttention<B>,
281 norm2: LayerNorm<B>,
282 mlp: MlpBlock<B>,
283 dropout: Dropout,
284}
285
286impl<B: Backend> EncoderBlock<B> {
287 pub fn new(
289 d_model: usize,
290 num_heads: usize,
291 mlp_dim: usize,
292 dropout: f64,
293 chunk_size: usize,
294 device: &B::Device,
295 ) -> Self {
296 Self {
297 norm1: LayerNormConfig::new(d_model).init(device),
298 attn: MultiHeadSelfAttention::new(d_model, num_heads, dropout, chunk_size, device),
299 norm2: LayerNormConfig::new(d_model).init(device),
300 mlp: MlpBlock::new(d_model, mlp_dim, dropout, device),
301 dropout: DropoutConfig::new(dropout).init(),
302 }
303 }
304
305 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
307 let residual = x.clone();
308 let y = self.attn.forward(self.norm1.forward(x));
309 let y = self.dropout.forward(y);
310 let x = y + residual;
311
312 let residual = x.clone();
313 let y = self.mlp.forward(self.norm2.forward(x));
314 y + residual
315 }
316}
317
318#[derive(Module, Debug)]
324pub struct MAPHead<B: Backend> {
325 probe: Param<Tensor<B, 3>>,
326 q_proj: Linear<B>,
327 k_proj: Linear<B>,
328 v_proj: Linear<B>,
329 out_proj: Linear<B>,
330 norm: LayerNorm<B>,
331 mlp: MlpBlock<B>,
332 num_heads: usize,
333 head_dim: usize,
334 scale: f32,
335}
336
337impl<B: Backend> MAPHead<B> {
338 pub fn new(
340 d_model: usize,
341 num_heads: usize,
342 mlp_dim: usize,
343 device: &B::Device,
344 ) -> Self {
345 let head_dim = d_model / num_heads;
346 let probe = Tensor::<B, 3>::random(
347 [1, 1, d_model],
348 Distribution::Uniform(-0.02, 0.02),
349 device,
350 );
351 Self {
352 probe: Param::from_tensor(probe),
353 q_proj: LinearConfig::new(d_model, d_model).init(device),
354 k_proj: LinearConfig::new(d_model, d_model).init(device),
355 v_proj: LinearConfig::new(d_model, d_model).init(device),
356 out_proj: LinearConfig::new(d_model, d_model).init(device),
357 norm: LayerNormConfig::new(d_model).init(device),
358 mlp: MlpBlock::new(d_model, mlp_dim, 0.0, device),
359 num_heads,
360 head_dim,
361 scale: (head_dim as f32).powf(-0.5),
362 }
363 }
364
365 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
367 let [batch, seq, d] = x.dims();
368 let h = self.num_heads;
369 let hd = self.head_dim;
370
371 let probe = self.probe.val().expand([batch, 1, d]);
372
373 let q = self.q_proj.forward(probe);
374 let k = self.k_proj.forward(x.clone());
375 let v = self.v_proj.forward(x);
376
377 let rq = |t: Tensor<B, 3>, n: usize| t.reshape([batch, n, h, hd]).swap_dims(1, 2);
378 let q = rq(q, 1);
379 let k = rq(k, seq);
380 let v = rq(v, seq);
381
382 let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(self.scale);
383 let attn = activation::softmax(scores, 3);
384
385 let ctx = attn
386 .matmul(v)
387 .swap_dims(1, 2)
388 .reshape([batch, 1, h * hd]);
389
390 let ctx = self.out_proj.forward(ctx);
391 let ctx_2d = ctx.squeeze(1); let normed = self.norm.forward(ctx_2d.clone().unsqueeze_dim(1));
394 let mlp_out = self.mlp.forward(normed).squeeze(1);
395 ctx_2d + mlp_out
396 }
397}
398
399#[derive(Module, Debug)]
408pub struct SensorEncoder<B: Backend> {
409 patch_embed: PatchEmbedding<B>,
410 pos_embed: Param<Tensor<B, 3>>,
411 blocks: Vec<EncoderBlock<B>>,
412 norm: LayerNorm<B>,
413 map_head: Option<MAPHead<B>>,
414 dropout: Dropout,
415 d_model: usize,
416}
417
418impl<B: Backend> SensorEncoder<B> {
419 pub fn new(cfg: &SensorEncoderConfig, device: &B::Device) -> Self {
421 let num_patches = cfg.num_patches();
422
423 let patch_embed = PatchEmbedding::new(
424 1,
425 cfg.d_model,
426 cfg.patch_h,
427 cfg.patch_w,
428 cfg.time_steps,
429 cfg.num_channels,
430 device,
431 );
432
433 let pos_embed = Tensor::<B, 3>::random(
434 [1, num_patches, cfg.d_model],
435 Distribution::Normal(0.0, (1.0 / cfg.d_model as f64).sqrt()),
436 device,
437 );
438
439 let blocks: Vec<EncoderBlock<B>> = (0..cfg.depth)
440 .map(|_| EncoderBlock::new(cfg.d_model, cfg.num_heads, cfg.mlp_dim, cfg.dropout, cfg.attn_chunk_size, device))
441 .collect();
442
443 let norm = LayerNormConfig::new(cfg.d_model).init(device);
444
445 let map_head = if cfg.pool_type == PoolType::Map {
446 Some(MAPHead::new(cfg.d_model, cfg.num_heads, cfg.mlp_dim, device))
447 } else {
448 None
449 };
450
451 Self {
452 patch_embed,
453 pos_embed: Param::from_tensor(pos_embed),
454 blocks,
455 norm,
456 map_head,
457 dropout: DropoutConfig::new(cfg.dropout).init(),
458 d_model: cfg.d_model,
459 }
460 }
461
462 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
464 let [batch, _t, _c] = x.dims();
465
466 let x = x.unsqueeze_dim(1);
468
469 let mut tokens = self.patch_embed.forward(x);
471
472 let num_patches = tokens.dims()[1];
474 let pos = self.pos_embed.val().expand([batch, num_patches, self.d_model]);
475 tokens = tokens + pos;
476 tokens = self.dropout.forward(tokens);
477
478 for block in &self.blocks {
480 tokens = block.forward(tokens);
481 }
482 tokens = self.norm.forward(tokens);
483
484 let embedding: Tensor<B, 2> = match &self.map_head {
486 Some(map) => map.forward(tokens),
487 None => tokens.mean_dim(1).squeeze(1),
488 };
489
490 l2_normalize(embedding)
491 }
492}
493
494pub fn l2_normalize<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
500 let [batch, d] = x.dims();
501 let norm = x.clone().powf_scalar(2.0).sum_dim(1).sqrt().clamp_min(1e-12);
502 x / norm.expand([batch, d])
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SensorEncoderConfig;
    use burn::backend::NdArray;

    type B = NdArray;

    /// Smallest configuration that still exercises every layer: 40 time steps and
    /// 4 channels cut into 10x2 patches, two blocks, mean pooling, no chunking.
    fn tiny_cfg() -> SensorEncoderConfig {
        SensorEncoderConfig {
            time_steps: 40,
            num_channels: 4,
            patch_h: 10,
            patch_w: 2,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            pool_type: PoolType::Gap,
            head_zeroinit: false,
            attn_chunk_size: 0,
        }
    }

    /// The patch embedding yields one `d_model`-wide token per patch.
    #[test]
    fn test_patch_embedding_shape() {
        let device = Default::default();
        let cfg = tiny_cfg();
        let embed = PatchEmbedding::<B>::new(
            1,
            cfg.d_model,
            cfg.patch_h,
            cfg.patch_w,
            cfg.time_steps,
            cfg.num_channels,
            &device,
        );

        let input = Tensor::<B, 4>::zeros([2, 1, 40, 4], &device);
        let tokens = embed.forward(input);

        let expected_patches = (40 / 10) * (4 / 2);
        assert_eq!(tokens.dims(), [2, expected_patches, cfg.d_model]);
    }

    /// The full encoder maps `[batch, time, channels]` to `[batch, d_model]`.
    #[test]
    fn test_encoder_forward_shape() {
        let device = Default::default();
        let cfg = tiny_cfg();
        let encoder = SensorEncoder::<B>::new(&cfg, &device);

        let input = Tensor::<B, 3>::zeros([2, 40, 4], &device);
        let embedding = encoder.forward(input);

        assert_eq!(embedding.dims(), [2, cfg.d_model]);
    }
}