Skip to main content

sensorlm/model/
text_encoder.rs

//! Text transformer encoder.
//!
//! Encodes token sequences into fixed-size L2-normalised embeddings using
//! a bidirectional transformer (`depth` layers, set by `TextEncoderConfig`)
//! with masked mean-pooling.
//!
//! # Architecture
//!
//! ```text
//! token_ids (B, L)
//!   → TokenEmbedding(vocab_size, D) + PositionalEmbedding(max_len, D)
//!   → Dropout
//!   → [EncoderBlock × depth]
//!   → LayerNorm
//!   → MaskedMeanPool             → (B, D)
//!   → Linear(D, out) [optional]  → (B, out)
//!   → L2-normalise               → (B, out)
//! ```
//!
//! When no output dimension is configured the projection is skipped and the
//! final embedding keeps shape `(B, D)`.
18
19use burn::{
20    module::{Module, Param},
21    nn::{
22        Dropout, DropoutConfig, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
23        Linear, LinearConfig,
24    },
25    tensor::{
26        backend::Backend,
27        Distribution, Int, Tensor,
28    },
29};
30
31use crate::config::TextEncoderConfig;
32use crate::model::sensor_encoder::{EncoderBlock, l2_normalize};
33
34/// Bidirectional transformer text encoder.
35#[derive(Module, Debug)]
36pub struct TextEncoder<B: Backend> {
37    tok_embed: Embedding<B>,
38    pos_embed: Param<Tensor<B, 3>>,
39    blocks:    Vec<EncoderBlock<B>>,
40    norm:      LayerNorm<B>,
41    proj:      Option<Linear<B>>,
42    dropout:   Dropout,
43    d_model:   usize,
44}
45
46impl<B: Backend> TextEncoder<B> {
47    /// Build a text encoder from [`TextEncoderConfig`].
48    pub fn new(cfg: &TextEncoderConfig, device: &B::Device) -> Self {
49        let tok_embed = EmbeddingConfig::new(cfg.vocab_size, cfg.d_model).init(device);
50
51        let pos = Tensor::<B, 3>::random(
52            [1, cfg.max_seq_len, cfg.d_model],
53            Distribution::Normal(0.0, (1.0 / cfg.d_model as f64).sqrt()),
54            device,
55        );
56
57        let blocks: Vec<EncoderBlock<B>> = (0..cfg.depth)
58            .map(|_| EncoderBlock::new(cfg.d_model, cfg.num_heads, cfg.mlp_dim, cfg.dropout, 0, device))
59            .collect();
60
61        let norm = LayerNormConfig::new(cfg.d_model).init(device);
62        let proj = cfg.out_dim.map(|out| LinearConfig::new(cfg.d_model, out).init(device));
63
64        Self {
65            tok_embed,
66            pos_embed: Param::from_tensor(pos),
67            blocks,
68            norm,
69            proj,
70            dropout: DropoutConfig::new(cfg.dropout).init(),
71            d_model: cfg.d_model,
72        }
73    }
74
75    /// Encode token sequences to L2-normalised embeddings.
76    ///
77    /// # Arguments
78    ///
79    /// * `input_ids`      – `(B, L)` token IDs.
80    /// * `attention_mask` – `(B, L)` mask; `1` = real token, `0` = padding.
81    pub fn forward(
82        &self,
83        input_ids: Tensor<B, 2, Int>,
84        attention_mask: Tensor<B, 2, Int>,
85    ) -> Tensor<B, 2> {
86        let [batch, seq] = input_ids.dims();
87
88        // Token + positional embeddings.
89        let tok = self.tok_embed.forward(input_ids);
90        let pos = self.pos_embed.val()
91            .slice([0..1, 0..seq, 0..self.d_model])
92            .expand([batch, seq, self.d_model]);
93
94        let mut x = tok + pos;
95        x = self.dropout.forward(x);
96
97        for block in &self.blocks {
98            x = block.forward(x);
99        }
100        x = self.norm.forward(x);
101
102        // Masked mean pool.
103        // unsqueeze_dim::<3>(2) inserts a dimension at index 2: (B,L) → (B,L,1)
104        let mask: Tensor<B, 3> = attention_mask
105            .float()
106            .unsqueeze_dim::<3>(2)
107            .expand([batch, seq, self.d_model]);
108
109        let sum    = (x * mask.clone()).sum_dim(1);
110        let counts = mask.sum_dim(1).clamp_min(1.0f32);
111        let pooled: Tensor<B, 2> = (sum / counts).squeeze(1);
112
113        let projected = match &self.proj {
114            Some(p) => p.forward(pooled),
115            None    => pooled,
116        };
117
118        l2_normalize(projected)
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Tensor;

    type B = NdArray;
    /// Device type of the test backend.
    type TestDevice = <B as burn::tensor::backend::Backend>::Device;

    /// Smallest configuration that still exercises every layer, including
    /// the optional output projection.
    fn tiny_cfg() -> TextEncoderConfig {
        TextEncoderConfig {
            vocab_size: 100,
            max_seq_len: 32,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            out_dim: Some(32),
        }
    }

    /// Builds an encoder from [`tiny_cfg`] on the given device.
    fn tiny_encoder(device: &TestDevice) -> TextEncoder<B> {
        let cfg = tiny_cfg();
        TextEncoder::<B>::new(&cfg, device)
    }

    /// A padded batch of two sequences yields one 32-wide embedding per row.
    #[test]
    fn test_text_encoder_forward() {
        let device: TestDevice = Default::default();
        let encoder = tiny_encoder(&device);

        let token_ids =
            Tensor::<B, 2, Int>::from_ints([[1, 2, 3, 0, 0], [4, 5, 6, 7, 0]], &device);
        let padding_mask =
            Tensor::<B, 2, Int>::from_ints([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]], &device);

        let embeddings = encoder.forward(token_ids, padding_mask);
        assert_eq!(embeddings.dims(), [2, 32]);
    }

    /// Every output row has Euclidean norm 1 (within `1e-5`).
    #[test]
    fn test_output_unit_norm() {
        let device: TestDevice = Default::default();
        let encoder = tiny_encoder(&device);

        let token_ids = Tensor::<B, 2, Int>::from_ints([[1, 2, 3]], &device);
        let padding_mask = Tensor::<B, 2, Int>::from_ints([[1, 1, 1]], &device);

        let embeddings = encoder.forward(token_ids, padding_mask);

        // Row-wise L2 norm: sqrt(sum(x^2)) along the feature axis.
        let squared = embeddings.powf_scalar(2.0);
        let row_norms = squared.sum_dim(1).sqrt();
        let norm: Vec<f32> = row_norms.into_data().to_vec::<f32>().unwrap();

        for n in norm {
            assert!((n - 1.0).abs() < 1e-5, "Expected unit norm, got {n}");
        }
    }
}