1use burn::{
20 module::{Module, Param},
21 nn::{
22 Dropout, DropoutConfig, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
23 Linear, LinearConfig,
24 },
25 tensor::{
26 backend::Backend,
27 Distribution, Int, Tensor,
28 },
29};
30
31use crate::config::TextEncoderConfig;
32use crate::model::sensor_encoder::{EncoderBlock, l2_normalize};
33
34#[derive(Module, Debug)]
36pub struct TextEncoder<B: Backend> {
37 tok_embed: Embedding<B>,
38 pos_embed: Param<Tensor<B, 3>>,
39 blocks: Vec<EncoderBlock<B>>,
40 norm: LayerNorm<B>,
41 proj: Option<Linear<B>>,
42 dropout: Dropout,
43 d_model: usize,
44}
45
46impl<B: Backend> TextEncoder<B> {
47 pub fn new(cfg: &TextEncoderConfig, device: &B::Device) -> Self {
49 let tok_embed = EmbeddingConfig::new(cfg.vocab_size, cfg.d_model).init(device);
50
51 let pos = Tensor::<B, 3>::random(
52 [1, cfg.max_seq_len, cfg.d_model],
53 Distribution::Normal(0.0, (1.0 / cfg.d_model as f64).sqrt()),
54 device,
55 );
56
57 let blocks: Vec<EncoderBlock<B>> = (0..cfg.depth)
58 .map(|_| EncoderBlock::new(cfg.d_model, cfg.num_heads, cfg.mlp_dim, cfg.dropout, 0, device))
59 .collect();
60
61 let norm = LayerNormConfig::new(cfg.d_model).init(device);
62 let proj = cfg.out_dim.map(|out| LinearConfig::new(cfg.d_model, out).init(device));
63
64 Self {
65 tok_embed,
66 pos_embed: Param::from_tensor(pos),
67 blocks,
68 norm,
69 proj,
70 dropout: DropoutConfig::new(cfg.dropout).init(),
71 d_model: cfg.d_model,
72 }
73 }
74
75 pub fn forward(
82 &self,
83 input_ids: Tensor<B, 2, Int>,
84 attention_mask: Tensor<B, 2, Int>,
85 ) -> Tensor<B, 2> {
86 let [batch, seq] = input_ids.dims();
87
88 let tok = self.tok_embed.forward(input_ids);
90 let pos = self.pos_embed.val()
91 .slice([0..1, 0..seq, 0..self.d_model])
92 .expand([batch, seq, self.d_model]);
93
94 let mut x = tok + pos;
95 x = self.dropout.forward(x);
96
97 for block in &self.blocks {
98 x = block.forward(x);
99 }
100 x = self.norm.forward(x);
101
102 let mask: Tensor<B, 3> = attention_mask
105 .float()
106 .unsqueeze_dim::<3>(2)
107 .expand([batch, seq, self.d_model]);
108
109 let sum = (x * mask.clone()).sum_dim(1);
110 let counts = mask.sum_dim(1).clamp_min(1.0f32);
111 let pooled: Tensor<B, 2> = (sum / counts).squeeze(1);
112
113 let projected = match &self.proj {
114 Some(p) => p.forward(pooled),
115 None => pooled,
116 };
117
118 l2_normalize(projected)
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Tensor;

    type B = NdArray;
    type Dev = <B as burn::tensor::backend::Backend>::Device;

    /// Small configuration with dropout disabled and an output projection.
    fn tiny_cfg() -> TextEncoderConfig {
        TextEncoderConfig {
            vocab_size: 100,
            max_seq_len: 32,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            out_dim: Some(32),
        }
    }

    /// Builds the encoder under test from [`tiny_cfg`] on `device`.
    fn tiny_encoder(device: &Dev) -> TextEncoder<B> {
        TextEncoder::<B>::new(&tiny_cfg(), device)
    }

    /// Two padded length-5 sequences must yield a `[2, 32]` output.
    #[test]
    fn test_text_encoder_forward() {
        let device: Dev = Default::default();
        let encoder = tiny_encoder(&device);

        let token_ids =
            Tensor::<B, 2, Int>::from_ints([[1, 2, 3, 0, 0], [4, 5, 6, 7, 0]], &device);
        let pad_mask =
            Tensor::<B, 2, Int>::from_ints([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]], &device);

        let embedding = encoder.forward(token_ids, pad_mask);
        assert_eq!(embedding.dims(), [2, 32]);
    }

    /// Every output row must have Euclidean norm 1 (within `1e-5`).
    #[test]
    fn test_output_unit_norm() {
        let device: Dev = Default::default();
        let encoder = tiny_encoder(&device);

        let token_ids = Tensor::<B, 2, Int>::from_ints([[1, 2, 3]], &device);
        let pad_mask = Tensor::<B, 2, Int>::from_ints([[1, 1, 1]], &device);

        let embedding = encoder.forward(token_ids, pad_mask);
        let row_norms = embedding.powf_scalar(2.0).sum_dim(1).sqrt();
        let norms: Vec<f32> = row_norms.into_data().to_vec::<f32>().unwrap();

        for n in norms {
            assert!((n - 1.0).abs() < 1e-5, "Expected unit norm, got {n}");
        }
    }
}