use train_station::{
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[path = "basic_encoder.rs"]
mod basic_encoder;
use basic_encoder::EncoderBlock;

#[path = "basic_decoder.rs"]
mod basic_decoder;
use basic_decoder::DecoderBlock;

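/// A minimal encoder-decoder transformer built from stacks of the example
/// `EncoderBlock` and `DecoderBlock` modules.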
pub struct BasicTransformer {
    pub embed_dim: usize,
    pub num_heads: usize,
    pub num_layers: usize,
    encoders: Vec<EncoderBlock>,
    decoders: Vec<DecoderBlock>,
}

impl BasicTransformer {
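    /// Builds `num_layers` encoder and decoder blocks. Each block receives
    /// a distinct seed derived from `seed` so initialization is reproducible.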
    pub fn new(embed_dim: usize, num_heads: usize, num_layers: usize, seed: Option<u64>) -> Self {
        let mut encoders = Vec::new();
        let mut decoders = Vec::new();
        for i in 0..num_layers {
            encoders.push(EncoderBlock::new(
                embed_dim,
                num_heads,
                seed.map(|s| s + i as u64),
            ));
            decoders.push(DecoderBlock::new(
                embed_dim,
                num_heads,
                seed.map(|s| s + 100 + i as u64),
            ));
        }
        Self {
            embed_dim,
            num_heads,
            num_layers,
            encoders,
            decoders,
        }
    }

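    /// Collects mutable references to every trainable tensor across all
    /// encoder and decoder blocks, in a stable order for the optimizer.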
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for e in &mut self.encoders {
            params.extend(e.parameters());
        }
        for d in &mut self.decoders {
            params.extend(d.parameters());
        }
        params
    }

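    /// Teacher-forced forward pass: encode `src`, then run the decoder
    /// stack over `tgt` attending to the encoder memory, with no masks.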
    pub fn forward(&self, src: &Tensor, tgt: &Tensor) -> Tensor {
        let mut memory = src.clone();
        for enc in &self.encoders {
            memory = enc.forward(&memory, None);
        }
        let mut out = tgt.clone();
        for dec in &self.decoders {
            out = dec.forward(&out, &memory, None, None);
        }
        out
    }

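    /// Autoregressive-style decoding loop: the source is encoded once, then
    /// the decoder stack is re-run on a target that grows by one position
    /// per step under a causal mask. For simplicity the example feeds zero
    /// tokens rather than the previous step's predictions, so it
    /// demonstrates the masking and shape bookkeeping, not sampling.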
    pub fn infer_autoregressive(&self, src: &Tensor, max_steps: usize) -> Tensor {
        let (b, _s, e) = Self::triple(src);
        let mut memory = src.clone();
        for enc in &self.encoders {
            memory = enc.forward(&memory, None);
        }

        // Start from a single zero token and grow the target each step.
        let mut current = Tensor::zeros(vec![b, 1, e]);
        let mut last_out = current.clone();
        for _step in 0..max_steps {
            let t = current.shape().dims()[1];
            // Causal mask so position i attends only to positions 0..=i.
            let causal = Self::build_causal_mask_static(b, self.num_heads, t);
            let mut step_out = current.clone();
            for dec in &self.decoders {
                step_out = dec.forward(&step_out, &memory, Some(&causal), None);
            }
            last_out = step_out;
            // Extend the decoder input by one zero token for the next step.
            current = Tensor::zeros(vec![b, t + 1, e]);
        }
        // Return the decoder output from the final (longest) step.
        last_out
    }

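    /// Non-autoregressive decoding: all `tgt_len` positions are produced
    /// in a single unmasked pass over a zero-initialized target.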
    pub fn infer_non_autoregressive(&self, src: &Tensor, tgt_len: usize) -> Tensor {
        let (b, _s, e) = Self::triple(src);
        let mut memory = src.clone();
        for enc in &self.encoders {
            memory = enc.forward(&memory, None);
        }
        let tgt = Tensor::zeros(vec![b, tgt_len, e]);
        let mut out = tgt.clone();
        for dec in &self.decoders {
            out = dec.forward(&out, &memory, None, None);
        }
        out
    }

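    /// Builds a `[batch, heads, t, t]` causal attention mask: 1.0 where
    /// attention is allowed, 0.0 above the diagonal (future positions).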
    fn build_causal_mask_static(batch: usize, heads: usize, t: usize) -> Tensor {
        let mut mask = Tensor::ones(vec![batch, heads, t, t]);
        for bb in 0..batch {
            for hh in 0..heads {
                for i in 0..t {
                    for j in (i + 1)..t {
                        let offset = mask.memory_offset(&[bb, hh, i, j]);
                        let data = mask.data_mut();
                        data[offset] = 0.0;
                    }
                }
            }
        }
        mask
    }

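    /// A few steps of non-autoregressive training: full unmasked forward
    /// pass, MSE loss against `tgt`, one Adam update per step.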
    pub fn train_non_autoregressive_steps(
        &mut self,
        src: &Tensor,
        tgt: &Tensor,
        steps: usize,
        lr: f32,
    ) {
        let mut opt = Adam::with_learning_rate(lr);
        {
            let params_once = self.parameters();
            for p in &params_once {
                opt.add_parameter(p);
            }
        }
        for step in 0..steps {
            {
                // Forward and backward inside a scope so the temporaries are
                // dropped before the parameters are mutably re-borrowed.
                let pred = self.forward(src, tgt);
                let diff = pred.sub_tensor(tgt);
                let mut loss = diff.pow_scalar(2.0).mean();
                if step == 0 || step + 1 == steps {
                    println!("NAR train step {}: loss={:.6}", step, loss.value());
                }
                loss.backward(None);
            }
            // Re-collect parameter references and apply the optimizer update.
            let mut params_step = self.parameters();
            opt.step(&mut params_step);
            opt.zero_grad(&mut params_step);
        }
    }

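    /// A few steps of teacher-forced training under a causal mask: the
    /// source is encoded once, each step decodes `tgt` with the mask,
    /// computes the MSE against `tgt`, and applies an Adam update.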
    pub fn train_autoregressive_steps(
        &mut self,
        src: &Tensor,
        tgt: &Tensor,
        steps: usize,
        lr: f32,
    ) {
        let mut opt = Adam::with_learning_rate(lr);
        {
            let params_once = self.parameters();
            for p in &params_once {
                opt.add_parameter(p);
            }
        }

        // Encode the source once; the memory is reused at every step and
        // is not re-encoded after parameter updates.
        let mut memory = src.clone();
        for enc in &self.encoders {
            memory = enc.forward(&memory, None);
        }

        let (b, t, _e) = Self::triple(tgt);
        let causal = Self::build_causal_mask_static(b, self.num_heads, t);
        for step in 0..steps {
            {
                // Decode the full target under the causal mask.
                let mut out = tgt.clone();
                for dec in &self.decoders {
                    out = dec.forward(&out, &memory, Some(&causal), None);
                }
                let diff = out.sub_tensor(tgt);
                let mut loss = diff.pow_scalar(2.0).mean();
                if step == 0 || step + 1 == steps {
                    println!("AR train step {}: loss={:.6}", step, loss.value());
                }
                loss.backward(None);
            }
            let mut params_step = self.parameters();
            opt.step(&mut params_step);
            opt.zero_grad(&mut params_step);
        }
    }

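    /// Unpacks the dimensions of a rank-3 tensor as `(batch, seq, embed)`.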
    fn triple(t: &Tensor) -> (usize, usize, usize) {
        let d = t.shape().dims();
        (d[0], d[1], d[2])
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Transformer Example ===");

    let batch = 2usize;
    let src_len = 8usize;
    let tgt_len = 6usize;
    let embed = 32usize;
    let heads = 4usize;
    let layers = 2usize;

    let src = Tensor::randn(vec![batch, src_len, embed], Some(1001));
    let tgt = Tensor::randn(vec![batch, tgt_len, embed], Some(1002));

    let mut trf = BasicTransformer::new(embed, heads, layers, Some(999));
    let out = trf.forward(&src, &tgt);
    println!("Output shape: {:?}", out.shape().dims());

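    // One manual optimization step using the mean of the output as a toy loss.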
    let mut opt = Adam::with_learning_rate(0.005);
    let mut params = trf.parameters();
    for p in &params {
        opt.add_parameter(p);
    }
    let mut loss = out.mean();
    loss.backward(None);
    opt.step(&mut params);
    opt.zero_grad(&mut params);
    println!("Loss: {:.6}", loss.value());

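    // Non-autoregressive inference: decode all positions in one pass.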
    let nar = trf.infer_non_autoregressive(&src, tgt_len);
    println!("NAR output shape: {:?}", nar.shape().dims());

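    // Autoregressive inference: grow the target one step at a time.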
    let ar = trf.infer_autoregressive(&src, 3);
    println!("AR output shape: {:?}", ar.shape().dims());

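    // A few non-autoregressive training steps against the random target.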
    let nar_tgt = tgt.clone();
    trf.train_non_autoregressive_steps(&src, &nar_tgt, 3, 0.01);

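    // A few causally masked (autoregressive) training steps.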
    let ar_tgt = tgt.clone();
    trf.train_autoregressive_steps(&src, &ar_tgt, 3, 0.01);
    println!("=== Done ===");
    Ok(())
}