use burn::grad_clipping::GradientClippingConfig;
use burn::optim::{AdamWConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use burn::tensor::activation;

use crate::neural::data::pairs::TrainingPair;
use crate::neural::data::tir_graph::NODE_FEATURE_DIM;
use crate::neural::model::composite::NeuralCompilerV2;
use crate::neural::model::grammar::precompute_sequence_state;

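/// Hyperparameters for the supervised training loop.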
pub struct SupervisedConfig {
    /// Peak learning rate.
    pub lr: f64,
    /// Minimum learning rate the cosine schedule decays to.
    pub lr_min: f64,
    /// AdamW weight decay.
    pub weight_decay: f64,
    /// Gradient-norm clipping threshold.
    pub grad_clip: f32,
    /// Maximum number of training epochs.
    pub max_epochs: usize,
    /// Early-stopping patience, in epochs without improvement.
    pub patience: usize,
}

impl Default for SupervisedConfig {
    fn default() -> Self {
        Self {
            lr: 3e-4,
            lr_min: 1e-5,
            weight_decay: 0.01,
            grad_clip: 1.0,
            max_epochs: 100,
            patience: 3,
        }
    }
}

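/// Cosine-annealed learning rate for `epoch`:
/// `lr_min + 0.5 * (lr - lr_min) * (1 + cos(pi * epoch / total_epochs))`,
/// falling back to the peak `lr` when `total_epochs <= 1`.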
pub fn cosine_lr(config: &SupervisedConfig, epoch: usize, total_epochs: usize) -> f64 {
    if total_epochs <= 1 {
        return config.lr;
    }
    let t = epoch as f64 / total_epochs as f64;
    config.lr_min + 0.5 * (config.lr - config.lr_min) * (1.0 + (std::f64::consts::PI * t).cos())
}

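/// Summary statistics for one pass of [`train_epoch`].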
pub struct EpochResult {
    /// Mean cross-entropy loss over the pairs trained this epoch.
    pub avg_loss: f32,
    /// Number of pairs in the epoch.
    pub num_pairs: usize,
}

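/// Runs one supervised pass over `pairs` with teacher forcing, stepping the
/// optimizer once per pair (effective batch size 1). Returns the updated
/// model together with the epoch statistics.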
pub fn train_epoch<B: burn::tensor::backend::AutodiffBackend>(
    model: NeuralCompilerV2<B>,
    pairs: &[TrainingPair],
    optimizer: &mut impl Optimizer<NeuralCompilerV2<B>, B>,
    lr: f64,
    device: &B::Device,
) -> (NeuralCompilerV2<B>, EpochResult) {
    let mut total_loss = 0.0f32;
    let mut model = model;

    for pair in pairs {
        // Encode the TIR graph into per-node embeddings.
        let node_features = graph_to_features::<B>(&pair.graph, device);
        let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&pair.graph, device);

        let (node_emb, _global) =
            model
                .encoder
                .forward(node_features, edge_src, edge_dst, edge_types);
        let d_model = node_emb.dims()[1];
        let num_nodes = node_emb.dims()[0];
        let memory = node_emb.unsqueeze_dim::<3>(0);

        // Truncate overly long targets to bound decoder cost.
        const MAX_SEQ: usize = 256;
        let tokens = if pair.target_tokens.len() > MAX_SEQ {
            &pair.target_tokens[..MAX_SEQ]
        } else {
            &pair.target_tokens
        };
        let seq_len = tokens.len();
        if seq_len < 2 {
            continue;
        }

        // Teacher forcing: shift the targets right, feeding token 0 as the
        // start-of-sequence input.
        let mut input_tokens = vec![0i32];
        for &t in &tokens[..seq_len - 1] {
            input_tokens.push(t as i32);
        }
        let token_ids =
            Tensor::<B, 2, Int>::from_data(TensorData::new(input_tokens, [1, seq_len]), device);

        let positions = Tensor::<B, 2, Int>::from_data(
            TensorData::new((0..seq_len as i32).collect::<Vec<_>>(), [1, seq_len]),
            device,
        );

        // Grammar-derived stack depths and type states for each position.
        let state = precompute_sequence_state(tokens, 0);

        let stack_depths = Tensor::<B, 2, Int>::from_data(
            TensorData::new(
                state
                    .depths
                    .iter()
                    .map(|&d| (d as i32).min(64))
                    .collect::<Vec<_>>(),
                [1, seq_len],
            ),
            device,
        );

        let type_data: Vec<f32> = state.type_states.into_iter().flatten().collect();
        let type_states =
            Tensor::<B, 3>::from_data(TensorData::new(type_data, [1, seq_len, 24]), device);

        let memory_expanded = memory.expand([1, num_nodes, d_model]);
        let logits = model.decoder.forward(
            token_ids,
            positions,
            stack_depths,
            type_states,
            memory_expanded,
        );

        // Targets are the unshifted token sequence.
        let targets = Tensor::<B, 2, Int>::from_data(
            TensorData::new(
                tokens.iter().map(|&t| t as i32).collect::<Vec<_>>(),
                [1, seq_len],
            ),
            device,
        );

        let loss = cross_entropy_loss(logits, targets);
        let loss_val: f32 = loss.clone().into_data().to_vec::<f32>().unwrap()[0];
        total_loss += loss_val;

        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &model);
        model = optimizer.step(lr, model, grads);
    }

    let avg_loss = if pairs.is_empty() {
        0.0
    } else {
        total_loss / pairs.len() as f32
    };

    (
        model,
        EpochResult {
            avg_loss,
            num_pairs: pairs.len(),
        },
    )
}

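/// Token-level cross-entropy: flattens the `[batch, seq, vocab]` logits,
/// takes the log-softmax, gathers the log-probability of each target token,
/// and returns the negated mean as a scalar tensor.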
fn cross_entropy_loss<B: Backend>(
    logits: Tensor<B, 3>,
    targets: Tensor<B, 2, Int>,
) -> Tensor<B, 1> {
    let [batch, seq, vocab] = logits.dims();

    let logits_flat = logits.reshape([batch * seq, vocab]);
    let targets_flat = targets.reshape([batch * seq]);

    let log_probs = activation::log_softmax(logits_flat, 1);

    // Gather the log-probability of each target token, then average and negate.
    let targets_2d: Tensor<B, 2, Int> = targets_flat.unsqueeze_dim::<2>(1);
    let selected = log_probs.gather(1, targets_2d);
    selected.mean().neg().unsqueeze()
}

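/// Packs per-node feature vectors into a dense
/// `[num_nodes, NODE_FEATURE_DIM]` tensor.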
pub fn graph_to_features<B: Backend>(
    graph: &crate::neural::data::tir_graph::TirGraph,
    device: &B::Device,
) -> Tensor<B, 2> {
    let num_nodes = graph.nodes.len();
    let mut data = vec![0.0f32; num_nodes * NODE_FEATURE_DIM];
    for (i, node) in graph.nodes.iter().enumerate() {
        let fv = node.feature_vector();
        data[i * NODE_FEATURE_DIM..(i + 1) * NODE_FEATURE_DIM].copy_from_slice(&fv);
    }
    Tensor::from_data(TensorData::new(data, [num_nodes, NODE_FEATURE_DIM]), device)
}

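/// Converts graph edges into parallel `(src, dst, type)` index tensors; an
/// empty edge list yields one zero-initialized placeholder edge.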
pub fn graph_to_edges<B: Backend>(
    graph: &crate::neural::data::tir_graph::TirGraph,
    device: &B::Device,
) -> (Tensor<B, 1, Int>, Tensor<B, 1, Int>, Tensor<B, 1, Int>) {
    // Allocate at least one slot so empty graphs still produce valid tensors.
    let num_edges = graph.edges.len().max(1);
    let mut src = vec![0i32; num_edges];
    let mut dst = vec![0i32; num_edges];
    let mut types = vec![0i32; num_edges];

    for (i, &(s, d, ref kind)) in graph.edges.iter().enumerate() {
        src[i] = s as i32;
        dst[i] = d as i32;
        types[i] = match kind {
            crate::neural::data::tir_graph::EdgeKind::DataDep => 0,
            crate::neural::data::tir_graph::EdgeKind::ControlFlow => 1,
            crate::neural::data::tir_graph::EdgeKind::MemOrder => 2,
        };
    }

    (
        Tensor::from_data(TensorData::new(src, [num_edges]), device),
        Tensor::from_data(TensorData::new(dst, [num_edges]), device),
        Tensor::from_data(TensorData::new(types, [num_edges]), device),
    )
}

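/// Builds an AdamW optimizer with the configured weight decay and
/// gradient-norm clipping.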
pub fn create_optimizer<B: burn::tensor::backend::AutodiffBackend>(
    config: &SupervisedConfig,
) -> impl Optimizer<NeuralCompilerV2<B>, B> {
    AdamWConfig::new()
        .with_weight_decay(config.weight_decay as f32)
        .with_grad_clipping(Some(GradientClippingConfig::Norm(config.grad_clip)))
        .init()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::tir::TIROp;
    use crate::neural::data::pairs::extract_pairs;
    use crate::neural::model::composite::NeuralCompilerConfig;
    use crate::neural::model::vocab::Vocab;
    use burn::backend::Autodiff;
    use burn::backend::NdArray;

    type B = Autodiff<NdArray>;

    #[test]
    fn train_epoch_runs() {
        let device = Default::default();

        let config = NeuralCompilerConfig {
            d_model: 32,
            d_edge: 8,
            gnn_layers: 1,
            decoder_layers: 1,
            n_heads: 4,
            d_ff: 64,
            max_seq: 32,
            dropout: 0.0,
        };
        let model = config.init::<B>(&device);

        let vocab = Vocab::new();
        let blocks = vec![(
            vec![TIROp::Push(1), TIROp::Push(2), TIROp::Add],
            vec!["push 1".into(), "push 2".into(), "add".into()],
            "test:0..3".into(),
            3u64,
        )];
        let pairs = extract_pairs(&blocks, &vocab);

        let supervised_config = SupervisedConfig::default();
        let mut optimizer = create_optimizer::<B>(&supervised_config);

        let lr = supervised_config.lr;
        let (model, result) = train_epoch(model, &pairs, &mut optimizer, lr, &device);
        assert_eq!(result.num_pairs, 1);
        assert!(result.avg_loss > 0.0, "loss should be positive");
        assert!(result.avg_loss.is_finite(), "loss should be finite");

        // A second epoch should run cleanly on the updated model and optimizer state.
        let (_model2, result2) = train_epoch(model, &pairs, &mut optimizer, lr, &device);
        assert!(result2.avg_loss.is_finite());
    }
}