basic_transformer/
basic_transformer.rs

//! Basic Transformer (Encoder-Decoder) wiring using encoder/decoder examples
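//!
//! Stacks `num_layers` `EncoderBlock`s and `num_layers` `DecoderBlock`s: the
//! source runs through the encoder stack to produce a "memory" tensor that
//! every decoder block cross-attends to while processing the target.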

use train_station::{
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[path = "basic_encoder.rs"]
mod basic_encoder;
use basic_encoder::EncoderBlock;

#[path = "basic_decoder.rs"]
mod basic_decoder;
use basic_decoder::DecoderBlock;
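
// Note: `EncoderBlock` and `DecoderBlock` live in the sibling example files
// referenced above. As used below, each exposes `parameters()` and a
// `forward` method whose optional tensor arguments are presumably attention
// masks: `EncoderBlock::forward(x, mask)` and
// `DecoderBlock::forward(x, memory, self_attn_mask, cross_attn_mask)`.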

/// Minimal encoder-decoder Transformer assembled from the example blocks.
pub struct BasicTransformer {
    pub embed_dim: usize,
    pub num_heads: usize,
    pub num_layers: usize,
    encoders: Vec<EncoderBlock>,
    decoders: Vec<DecoderBlock>,
}

impl BasicTransformer {
    pub fn new(embed_dim: usize, num_heads: usize, num_layers: usize, seed: Option<u64>) -> Self {
        let mut encoders = Vec::new();
        let mut decoders = Vec::new();
        for i in 0..num_layers {
            // Offset the seed per layer (and per stack) so every block gets a
            // distinct but reproducible initialization.
            encoders.push(EncoderBlock::new(
                embed_dim,
                num_heads,
                seed.map(|s| s + i as u64),
            ));
            decoders.push(DecoderBlock::new(
                embed_dim,
                num_heads,
                seed.map(|s| s + 100 + i as u64),
            ));
        }
        Self {
            embed_dim,
            num_heads,
            num_layers,
            encoders,
            decoders,
        }
    }

    /// Collect mutable references to every trainable tensor in all blocks,
    /// in a stable order (encoders first, then decoders).
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for e in &mut self.encoders {
            params.extend(e.parameters());
        }
        for d in &mut self.decoders {
            params.extend(d.parameters());
        }
        params
    }

    /// Full forward pass: encode `src` into memory, then decode `tgt` against it.
    ///
    /// `src`: `[batch, src_len, embed]`, `tgt`: `[batch, tgt_len, embed]`;
    /// the output keeps the target shape `[batch, tgt_len, embed]`.
    pub fn forward(&self, src: &Tensor, tgt: &Tensor) -> Tensor {
        let mut memory = src.clone();
        for enc in &self.encoders {
            memory = enc.forward(&memory, None);
        }
        let mut out = tgt.clone();
        for dec in &self.decoders {
            out = dec.forward(&out, &memory, None, None);
        }
        out
    }
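
    // Shape sketch for the sizes used in `main` below:
    //   src [2, 8, 32] --encoder stack--> memory [2, 8, 32]
    //   tgt [2, 6, 32] --decoder stack + memory--> out [2, 6, 32]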

    /// Greedy auto-regressive inference (toy).
    ///
    /// Grows the target one token per step under a causal mask. A real model
    /// would project the newest decoder position to vocabulary logits, pick a
    /// token, embed it, and append that embedding; here every appended
    /// "token" is a zero placeholder, so the return value only demonstrates
    /// the loop structure (a `[batch, 1 + max_steps, embed]` zero tensor).
    pub fn infer_autoregressive(&self, src: &Tensor, max_steps: usize) -> Tensor {
        let (b, _s, e) = Self::triple(src);
        // Encode the source once; the memory is reused at every decode step.
        let mut memory = src.clone();
        for enc in &self.encoders {
            memory = enc.forward(&memory, None);
        }

        // Start token: zeros
        let mut current = Tensor::zeros(vec![b, 1, e]);
        for _step in 0..max_steps {
            // Causal mask sized to the current target length
            let t = current.shape().dims()[1];
            let causal = Self::build_causal_mask_static(b, self.num_heads, t);
            let mut step_out = current.clone();
            for dec in &self.decoders {
                step_out = dec.forward(&step_out, &memory, Some(&causal), None);
            }
            // (Toy) the decoder output is discarded instead of being projected
            // back to a token; grow the sequence by one zero token instead.
            let _ = step_out;
            current = Tensor::zeros(vec![b, t + 1, e]);
        }
        current
    }

    /// Non auto-regressive inference: decode all `tgt_len` positions in a
    /// single decoder pass, feeding a zero target and no mask. Contrast with
    /// `infer_autoregressive`, which runs the decoder stack once per step.
    pub fn infer_non_autoregressive(&self, src: &Tensor, tgt_len: usize) -> Tensor {
        let (b, _s, e) = Self::triple(src);
        let mut memory = src.clone();
        for enc in &self.encoders {
            memory = enc.forward(&memory, None);
        }
        let mut out = Tensor::zeros(vec![b, tgt_len, e]);
        for dec in &self.decoders {
            out = dec.forward(&out, &memory, None, None);
        }
        out
    }

    /// Helper: build a boolean-like causal mask `[batch, heads, t, t]` with
    /// 1.0 = keep and 0.0 = masked.
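    ///
    /// For `t = 3`, each `[t, t]` slice is lower-triangular, so position `i`
    /// attends only to positions `j <= i`:
    ///
    /// ```text
    /// 1 0 0
    /// 1 1 0
    /// 1 1 1
    /// ```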
    fn build_causal_mask_static(batch: usize, heads: usize, t: usize) -> Tensor {
        let mut mask = Tensor::ones(vec![batch, heads, t, t]);
        for bb in 0..batch {
            for hh in 0..heads {
                for i in 0..t {
                    // Zero out the strict upper triangle: j > i is the future.
                    for j in (i + 1)..t {
                        let offset = mask.memory_offset(&[bb, hh, i, j]);
                        let data = mask.data_mut();
                        data[offset] = 0.0;
                    }
                }
            }
        }
        mask
    }

    /// Non auto-regressive training: each optimization step is a single
    /// teacher-forced forward pass that feeds the full target as decoder
    /// input and minimizes the MSE against that same target.
    pub fn train_non_autoregressive_steps(
        &mut self,
        src: &Tensor,
        tgt: &Tensor,
        steps: usize,
        lr: f32,
    ) {
        let mut opt = Adam::with_learning_rate(lr);
        // Register parameters once; the mutable borrow ends with this scope.
        {
            let params_once = self.parameters();
            for p in &params_once {
                opt.add_parameter(p);
            }
        }
        for step in 0..steps {
            // forward + backward scope (immutable borrow of self)
            {
                let pred = self.forward(src, tgt);
                let diff = pred.sub_tensor(tgt);
                let mut loss = diff.pow_scalar(2.0).mean();
                if step == 0 || step + 1 == steps {
                    println!("NAR train step {}: loss={:.6}", step, loss.value());
                }
                loss.backward(None);
            }
            // step + zero_grad scope (mutable borrow of self)
            let mut params_step = self.parameters();
            opt.step(&mut params_step);
            opt.zero_grad(&mut params_step);
        }
    }

    /// Auto-regressive training (teacher forcing): predict the sequence under
    /// a causal mask. Note two simplifications: the target is not shifted by
    /// one position (a real next-token setup would predict `y[t]` from
    /// `y[..t]`), and the encoder memory is computed once outside the loop,
    /// so it goes stale as the encoder parameters are updated.
    pub fn train_autoregressive_steps(
        &mut self,
        src: &Tensor,
        tgt: &Tensor,
        steps: usize,
        lr: f32,
    ) {
        let mut opt = Adam::with_learning_rate(lr);
        {
            let params_once = self.parameters();
            for p in &params_once {
                opt.add_parameter(p);
            }
        }

        // Build encoder memory once (static dataset demo)
        let mut memory = src.clone();
        for enc in &self.encoders {
            memory = enc.forward(&memory, None);
        }

        let (b, t, _e) = Self::triple(tgt);
        // One causal mask, reused for every optimization step
        let causal = Self::build_causal_mask_static(b, self.num_heads, t);
        for step in 0..steps {
            // forward + backward scope
            {
                let mut out = tgt.clone();
                for dec in &self.decoders {
                    out = dec.forward(&out, &memory, Some(&causal), None);
                }
                let diff = out.sub_tensor(tgt);
                let mut loss = diff.pow_scalar(2.0).mean();
                if step == 0 || step + 1 == steps {
                    println!("AR  train step {}: loss={:.6}", step, loss.value());
                }
                loss.backward(None);
            }
            let mut params_step = self.parameters();
            opt.step(&mut params_step);
            opt.zero_grad(&mut params_step);
        }
    }

    /// Unpack a rank-3 shape into `(batch, seq, embed)`.
    fn triple(t: &Tensor) -> (usize, usize, usize) {
        let d = t.shape().dims();
        (d[0], d[1], d[2])
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Transformer Example ===");

    let batch = 2usize;
    let src_len = 8usize;
    let tgt_len = 6usize;
    let embed = 32usize;
    let heads = 4usize;
    let layers = 2usize;

    let src = Tensor::randn(vec![batch, src_len, embed], Some(1001));
    let tgt = Tensor::randn(vec![batch, tgt_len, embed], Some(1002));

    let mut trf = BasicTransformer::new(embed, heads, layers, Some(999));
    let out = trf.forward(&src, &tgt);
    println!("Output shape: {:?}", out.shape().dims());

    // Quick optimization step on a dummy loss (mean of the output)
    let mut opt = Adam::with_learning_rate(0.005);
    let mut params = trf.parameters();
    for p in &params {
        opt.add_parameter(p);
    }
    let mut loss = out.mean();
    loss.backward(None);
    opt.step(&mut params);
    opt.zero_grad(&mut params);
    println!("Loss: {:.6}", loss.value());

    // Demo: non auto-regressive inference (single pass)
    let nar = trf.infer_non_autoregressive(&src, tgt_len);
    println!("NAR output shape: {:?}", nar.shape().dims());

    // Demo: auto-regressive inference (toy)
    let ar = trf.infer_autoregressive(&src, 3);
    println!("AR output shape: {:?}", ar.shape().dims());

    // NAR training demo
    let nar_tgt = tgt.clone();
    trf.train_non_autoregressive_steps(&src, &nar_tgt, 3, 0.01);

    // AR training demo (teacher-forced)
    let ar_tgt = tgt.clone();
    trf.train_autoregressive_steps(&src, &ar_tgt, 3, 0.01);
    println!("=== Done ===");
    Ok(())
}
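
// A minimal test sketch, assuming these examples compile under `cargo test`
// and that `shape().dims()` returns a slice-like view as used above; the
// exact train_station test conventions are not shown in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_preserves_target_shape() {
        let trf = BasicTransformer::new(16, 2, 1, Some(7));
        let src = Tensor::randn(vec![1, 4, 16], Some(1));
        let tgt = Tensor::randn(vec![1, 3, 16], Some(2));
        let out = trf.forward(&src, &tgt);
        assert_eq!(out.shape().dims(), &[1usize, 3, 16]);
    }

    #[test]
    fn causal_mask_has_expected_shape() {
        let mask = BasicTransformer::build_causal_mask_static(1, 1, 3);
        assert_eq!(mask.shape().dims(), &[1usize, 1, 3, 3]);
    }
}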