1use crate::{TernaryMLP, TritMatrix, bitnet_threshold, quantize};
11use ternlang_core::trit::Trit;
12
/// Hyperparameters for quantization-aware training with a straight-through
/// estimator (STE).
#[derive(Debug, Clone, PartialEq)]
pub struct QatConfig {
    /// Step size applied to the masked gradients when latent weights are updated.
    pub lr: f32,
    /// Number of passes over the sample set performed by `SteTrainer::train`.
    pub epochs: usize,
    /// Latent weights whose magnitude exceeds this value receive no gradient.
    pub clip_threshold: f32,
    /// Print the mean epoch loss every `log_every` epochs; `0` disables logging.
    pub log_every: usize,
}
27
28impl Default for QatConfig {
29 fn default() -> Self {
30 Self {
31 lr: 0.01,
32 epochs: 100,
33 clip_threshold: 1.0,
34 log_every: 10,
35 }
36 }
37}
38
/// Summary returned by `SteTrainer::train`.
#[derive(Debug, Clone, PartialEq)]
pub struct QatResult {
    /// Mean per-sample loss of the last epoch that was run.
    pub final_loss: f32,
    /// Number of epochs reported by the training run.
    pub epochs_run: usize,
    /// Share of latent weights (both layers) whose magnitude is still within
    /// `clip_threshold` after training, i.e. that still receive gradient.
    pub active_gradient_fraction: f32,
}
47
48pub struct SteTrainer {
57 pub w1_latent: Vec<f32>, pub w2_latent: Vec<f32>, pub in_features: usize,
60 pub hidden_size: usize,
61 pub out_features: usize,
62 pub config: QatConfig,
63}
64
65impl SteTrainer {
66 pub fn from_mlp(mlp: &TernaryMLP, config: QatConfig) -> Self {
69 let w1_latent = mlp.w1.to_i8_vec().iter().map(|&v| v as f32).collect();
70 let w2_latent = mlp.w2.to_i8_vec().iter().map(|&v| v as f32).collect();
71 Self {
72 w1_latent,
73 w2_latent,
74 in_features: mlp.in_features,
75 hidden_size: mlp.hidden_size,
76 out_features: mlp.out_features,
77 config,
78 }
79 }
80
81 pub fn from_f32(
83 in_features: usize,
84 hidden_size: usize,
85 out_features: usize,
86 w1_f32: Vec<f32>,
87 w2_f32: Vec<f32>,
88 config: QatConfig,
89 ) -> Self {
90 assert_eq!(w1_f32.len(), in_features * hidden_size);
91 assert_eq!(w2_f32.len(), hidden_size * out_features);
92 Self { w1_latent: w1_f32, w2_latent: w2_f32, in_features, hidden_size, out_features, config }
93 }
94
95 fn quantize_latent(latent: &[f32]) -> Vec<f32> {
98 let tau = bitnet_threshold(latent);
99 quantize(latent, tau).iter().map(|&t| match t {
100 Trit::Affirm => 1.0,
101 Trit::Reject => -1.0,
102 Trit::Tend => 0.0,
103 }).collect()
104 }
105
/// Builds the straight-through gradient mask for `latent`.
///
/// An entry is `1.0` where `|w| <= clip` and `0.0` otherwise (including NaN,
/// for which the comparison is false).
fn ste_mask(latent: &[f32], clip: f32) -> Vec<f32> {
    let mut mask = Vec::with_capacity(latent.len());
    for weight in latent {
        let inside_window = weight.abs() <= clip;
        mask.push(if inside_window { 1.0 } else { 0.0 });
    }
    mask
}
110
/// Multiplies row-major `a` (`m x k`) by row-major `b` (`k x n`) and returns
/// the row-major `m x n` product.
///
/// Each output element is accumulated from `0.0` in ascending order of the
/// shared dimension. Indexing panics if a slice is shorter than the
/// dimensions require.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut product = Vec::with_capacity(m * n);
    for row in 0..m {
        for col in 0..n {
            let dot = (0..k).fold(0.0f32, |sum, inner| {
                sum + a[row * k + inner] * b[inner * n + col]
            });
            product.push(dot);
        }
    }
    product
}
125
/// Transposes the row-major `rows x cols` matrix `a` into a row-major
/// `cols x rows` matrix.
///
/// The source is read in row-major order; indexing panics if `a` holds fewer
/// than `rows * cols` elements.
fn transpose(a: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let len = rows * cols;
    let mut flipped = vec![0.0f32; len];
    // `cols` is non-zero whenever the loop body runs, so the division is safe.
    for src in 0..len {
        let (r, c) = (src / cols, src % cols);
        flipped[c * rows + r] = a[src];
    }
    flipped
}
136
137 pub fn train_step(&mut self, input: &[f32], target: &[f32]) -> f32 {
146 let (inf, hs, outf) = (self.in_features, self.hidden_size, self.out_features);
147
148 let w1_q = Self::quantize_latent(&self.w1_latent);
150 let w2_q = Self::quantize_latent(&self.w2_latent);
151
152 let hidden = Self::matmul(input, &w1_q, 1, inf, hs);
155
156 let hidden_act: Vec<f32> = hidden.iter().map(|&h| {
158 if h > 0.0 { 1.0 } else if h < 0.0 { -1.0 } else { 0.0 }
159 }).collect();
160
161 let output = Self::matmul(&hidden_act, &w2_q, 1, hs, outf);
163
164 let loss: f32 = output.iter().zip(target.iter())
166 .map(|(o, t)| (o - t).powi(2))
167 .sum::<f32>() / outf as f32;
168
169 let d_output: Vec<f32> = output.iter().zip(target.iter())
172 .map(|(o, t)| 2.0 * (o - t) / outf as f32)
173 .collect();
174
175 let hidden_act_t = Self::transpose(&hidden_act, 1, hs);
179 let d_w2_q = Self::matmul(&hidden_act_t, &d_output, hs, 1, outf);
180 let ste2 = Self::ste_mask(&self.w2_latent, self.config.clip_threshold);
181 let d_w2: Vec<f32> = d_w2_q.iter().zip(ste2.iter()).map(|(g, m)| g * m).collect();
182
183 let w2_q_t = Self::transpose(&w2_q, hs, outf);
186 let d_hidden_act = Self::matmul(&d_output, &w2_q_t, 1, outf, hs);
187
188 let d_hidden: Vec<f32> = d_hidden_act.iter().zip(hidden.iter())
191 .map(|(g, h)| if *h != 0.0 { *g } else { 0.0 })
192 .collect();
193
194 let input_t = Self::transpose(input, 1, inf);
198 let d_w1_q = Self::matmul(&input_t, &d_hidden, inf, 1, hs);
199 let ste1 = Self::ste_mask(&self.w1_latent, self.config.clip_threshold);
200 let d_w1: Vec<f32> = d_w1_q.iter().zip(ste1.iter()).map(|(g, m)| g * m).collect();
201
202 let lr = self.config.lr;
204 for (w, g) in self.w1_latent.iter_mut().zip(d_w1.iter()) {
205 *w -= lr * g;
206 }
207 for (w, g) in self.w2_latent.iter_mut().zip(d_w2.iter()) {
208 *w -= lr * g;
209 }
210
211 loss
212 }
213
214 pub fn train(&mut self, samples: &[(Vec<f32>, Vec<f32>)]) -> QatResult {
218 let mut final_loss = 0.0f32;
219
220 for epoch in 0..self.config.epochs {
221 let mut epoch_loss = 0.0f32;
222 for (input, target) in samples.iter() {
223 epoch_loss += self.train_step(input, target);
224 }
225 epoch_loss /= samples.len() as f32;
226 final_loss = epoch_loss;
227
228 if self.config.log_every > 0 && (epoch + 1) % self.config.log_every == 0 {
229 println!("[QAT/STE] epoch {:>4}/{} | loss: {:.6}", epoch + 1, self.config.epochs, epoch_loss);
230 }
231 }
232
233 let active = self.w1_latent.iter().chain(self.w2_latent.iter())
234 .filter(|&&w| w.abs() <= self.config.clip_threshold)
235 .count();
236 let total = self.w1_latent.len() + self.w2_latent.len();
237 let active_gradient_fraction = active as f32 / total as f32;
238
239 QatResult {
240 final_loss,
241 epochs_run: self.config.epochs,
242 active_gradient_fraction,
243 }
244 }
245
246 pub fn finalize(&self) -> TernaryMLP {
248 let tau1 = bitnet_threshold(&self.w1_latent);
249 let tau2 = bitnet_threshold(&self.w2_latent);
250 let w1 = TritMatrix::from_f32(self.in_features, self.hidden_size, &self.w1_latent, tau1);
251 let w2 = TritMatrix::from_f32(self.hidden_size, self.out_features, &self.w2_latent, tau2);
252 TernaryMLP::new(w1, w2)
253 }
254}
255
256#[cfg(test)]
259mod tests {
260 use super::*;
261
/// Deterministic pseudo-random `f32` generator for test fixtures.
///
/// Advances a 64-bit linear congruential state `n` times starting from
/// `seed` and converts the bits above bit 32 of each state to a float.
///
/// NOTE(review): `state >> 33` keeps 31 bits while the divisor is
/// `u32::MAX`, so the ratio never exceeds 0.5 and every value lands in
/// `[-1, 0]`. A symmetric `[-1, 1)` range was presumably intended; confirm
/// before changing, because the fixtures in this module depend on the
/// current sequence.
fn lcg(n: usize, seed: u64) -> Vec<f32> {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let mut state = seed;
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        let top_bits = (state >> 33) as f32;
        values.push(top_bits / (u32::MAX as f32) * 2.0 - 1.0);
    }
    values
}
269
/// Trains on eight pseudo-random inputs that share one target and checks
/// that the final mean loss does not exceed the loss of the untrained
/// weights, and that some latent weights remain inside the clip window.
#[test]
fn ste_trainer_reduces_loss() {
    let (inf, hs, outf) = (8, 16, 4);
    let config = QatConfig { lr: 0.05, epochs: 50, clip_threshold: 1.0, log_every: 0 };
    let mut trainer = SteTrainer::from_f32(
        inf,
        hs,
        outf,
        lcg(inf * hs, 0xdead),
        lcg(hs * outf, 0xbeef),
        config,
    );

    let samples: Vec<(Vec<f32>, Vec<f32>)> = (0u64..8)
        .map(|i| (lcg(inf, i * 17 + 3), vec![1.0, -1.0, 0.0, 0.0]))
        .collect();

    // Baseline: the same forward pass `train_step` performs, evaluated on the
    // untrained weights. The weights do not change inside this loop, so they
    // are quantized once up front.
    let w1_q = SteTrainer::quantize_latent(&trainer.w1_latent);
    let w2_q = SteTrainer::quantize_latent(&trainer.w2_latent);
    let ternary_sign = |h: f32| -> f32 {
        if h > 0.0 { 1.0 } else if h < 0.0 { -1.0 } else { 0.0 }
    };
    let mut summed = 0.0f32;
    for (input, target) in &samples {
        let hidden = SteTrainer::matmul(input, &w1_q, 1, inf, hs);
        let hidden_act: Vec<f32> = hidden.iter().map(|&h| ternary_sign(h)).collect();
        let output = SteTrainer::matmul(&hidden_act, &w2_q, 1, hs, outf);
        let squared_error = output.iter().zip(target.iter())
            .map(|(o, t)| (o - t).powi(2))
            .sum::<f32>();
        summed += squared_error / outf as f32;
    }
    let initial_loss = summed / samples.len() as f32;

    let result = trainer.train(&samples);
    println!("[test] initial_loss={:.4} final_loss={:.4}", initial_loss, result.final_loss);
    assert!(result.final_loss <= initial_loss, "QAT training must not increase loss");
    assert!(result.active_gradient_fraction > 0.0, "Some gradients must flow through STE");
}
305
/// Smoke test: after a few epochs, `finalize` must return a network with the
/// trainer's dimensions whose forward pass yields a `1 x out_features` output.
#[test]
fn finalize_produces_valid_mlp() {
    let (inf, hs, outf) = (4, 8, 2);
    let config = QatConfig { lr: 0.01, epochs: 5, clip_threshold: 1.0, log_every: 0 };
    let mut trainer = SteTrainer::from_f32(
        inf,
        hs,
        outf,
        lcg(inf * hs, 0xfeed),
        lcg(hs * outf, 0xcafe),
        config,
    );

    let samples = vec![
        (lcg(inf, 1), vec![1.0, -1.0]),
        (lcg(inf, 2), vec![-1.0, 1.0]),
    ];
    trainer.train(&samples);

    let mlp = trainer.finalize();
    assert_eq!(mlp.in_features, inf);
    assert_eq!(mlp.hidden_size, hs);
    assert_eq!(mlp.out_features, outf);

    // Feed one quantized row through the finalized network.
    let raw_input = lcg(inf, 99);
    let input = TritMatrix::from_f32(1, inf, &raw_input, 0.3);
    let (output, _, _) = mlp.forward(&input);
    assert_eq!(output.rows, 1);
    assert_eq!(output.cols, outf);
}
331}