1use crate::cost::stack_verifier;
14use crate::neural::data::pairs::TrainingPair;
15use crate::neural::data::tir_graph::TirGraph;
16use crate::neural::model::vocab::Vocab;
17
/// Tunable knobs for [`augment_pairs`].
///
/// The defaults are provided by the [`Default`] implementation below.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AugmentConfig {
    /// How many graph-side variants `augment_pairs` emits per seed pair.
    /// Each one is produced by a call to `insert_dead_code`.
    pub tir_reorder_variants: usize,
    /// How many random-walk attempts `augment_pairs` makes per seed pair.
    /// An attempt that yields no variant contributes nothing to the output.
    pub tasm_walk_variants: usize,
    /// Number of adjacent-swap attempts performed inside a single random walk.
    pub max_swap_attempts: usize,
    /// Seed handed to the internal pseudo-random generator.
    pub seed: u64,
}

impl Default for AugmentConfig {
    /// Returns the stock configuration: 10 graph variants, 50 walk attempts,
    /// 20 swap attempts per walk, and a fixed seed.
    fn default() -> Self {
        AugmentConfig {
            tir_reorder_variants: 10,
            tasm_walk_variants: 50,
            max_swap_attempts: 20,
            seed: 0xDEAD_BEEF_A097,
        }
    }
}
40
41pub fn augment_pairs(
45 pairs: &[TrainingPair],
46 vocab: &Vocab,
47 config: &AugmentConfig,
48) -> Vec<TrainingPair> {
49 let mut result = Vec::with_capacity(
50 pairs.len() * (1 + config.tir_reorder_variants + config.tasm_walk_variants),
51 );
52 let mut rng = Xorshift64::new(config.seed);
53
54 for (pair_idx, pair) in pairs.iter().enumerate() {
55 result.push(TrainingPair {
57 graph: pair.graph.clone(),
58 target_tokens: pair.target_tokens.clone(),
59 source_id: pair.source_id.clone(),
60 baseline_cost: pair.baseline_cost,
61 });
62
63 let tasm_lines: Vec<String> = pair
65 .target_tokens
66 .iter()
67 .filter(|&&t| t != 0) .filter_map(|&t| vocab.decode(t).map(|s| s.to_string()))
69 .collect();
70
71 for variant in 0..config.tasm_walk_variants {
73 if let Some(augmented_tasm) =
74 random_walk_tasm(&tasm_lines, config.max_swap_attempts, &mut rng)
75 {
76 let tokens = vocab.encode_sequence(&augmented_tasm);
77 if tokens.len() > 1 {
78 result.push(TrainingPair {
79 graph: pair.graph.clone(),
80 target_tokens: tokens,
81 source_id: format!("{}:walk{}", pair.source_id, variant),
82 baseline_cost: pair.baseline_cost,
83 });
84 }
85 }
86 }
87
88 let sub_variants = equivalent_substitutions(&tasm_lines);
90 for (sub_idx, sub_tasm) in sub_variants.into_iter().enumerate() {
91 let tokens = vocab.encode_sequence(&sub_tasm);
92 if tokens.len() > 1 {
93 result.push(TrainingPair {
94 graph: pair.graph.clone(),
95 target_tokens: tokens,
96 source_id: format!("{}:sub{}", pair.source_id, sub_idx),
97 baseline_cost: pair.baseline_cost,
98 });
99 }
100 }
101
102 for variant in 0..config.tir_reorder_variants {
104 let augmented_tir = insert_dead_code(&pair.graph, &mut rng);
105 result.push(TrainingPair {
106 graph: augmented_tir,
107 target_tokens: pair.target_tokens.clone(),
108 source_id: format!("{}:dead{}", pair.source_id, variant),
109 baseline_cost: pair.baseline_cost,
110 });
111 }
112
113 if (pair_idx + 1) % 10 == 0 {
114 eprintln!(
115 " augmented {}/{} seed pairs ({} total)",
116 pair_idx + 1,
117 pairs.len(),
118 result.len()
119 );
120 }
121 }
122
123 eprintln!(
124 " augmentation: {} seeds → {} pairs ({:.1}x)",
125 pairs.len(),
126 result.len(),
127 result.len() as f64 / pairs.len().max(1) as f64,
128 );
129
130 result
131}
132
133fn random_walk_tasm(
140 tasm: &[String],
141 max_attempts: usize,
142 rng: &mut Xorshift64,
143) -> Option<Vec<String>> {
144 if tasm.len() < 2 {
145 return None;
146 }
147
148 let mut current = tasm.to_vec();
149 let mut changed = false;
150
151 for _ in 0..max_attempts {
152 let i = (rng.next() % (current.len() - 1) as u64) as usize;
153
154 if instructions_are_independent(¤t[i], ¤t[i + 1]) {
156 current.swap(i, i + 1);
157
158 let valid = (0..3u64).all(|trial| {
160 let seed = rng.next() ^ trial.wrapping_mul(0x9E3779B97F4A7C15);
161 stack_verifier::verify_equivalent(tasm, ¤t, seed)
162 });
163
164 if valid {
165 changed = true;
166 } else {
167 current.swap(i, i + 1);
169 }
170 }
171 }
172
173 if changed {
174 Some(current)
175 } else {
176 None
177 }
178}
179
/// Cheap syntactic filter deciding whether two adjacent instruction lines are
/// candidates for being swapped.
///
/// Only the first whitespace-separated word (the opcode) of each line is
/// inspected. The pair is accepted when both opcodes are one of `push`,
/// `divine` or `read_io`, or when either opcode is `nop`. A line with no
/// opcode (empty or whitespace only) is never accepted.
fn instructions_are_independent(a: &str, b: &str) -> bool {
    let (first_op, second_op) = match (a.split_whitespace().next(), b.split_whitespace().next()) {
        (Some(x), Some(y)) => (x, y),
        _ => return false,
    };

    let is_stack_source = |op: &str| matches!(op, "push" | "divine" | "read_io");

    (is_stack_source(first_op) && is_stack_source(second_op))
        || first_op == "nop"
        || second_op == "nop"
}
217
218fn equivalent_substitutions(tasm: &[String]) -> Vec<Vec<String>> {
224 let mut variants = Vec::new();
225
226 for i in 0..tasm.len() {
227 match tasm[i].as_str() {
229 "nop" => {
230 let mut v = tasm.to_vec();
232 v.remove(i);
233 if verify_substitution(tasm, &v) {
234 variants.push(v);
235 }
236 }
237 "push 0" if i + 1 < tasm.len() && tasm[i + 1] == "add" => {
238 let mut v = tasm.to_vec();
240 v.remove(i + 1);
241 v.remove(i);
242 if verify_substitution(tasm, &v) {
243 variants.push(v);
244 }
245 }
246 "push 1" if i + 1 < tasm.len() && tasm[i + 1] == "mul" => {
247 let mut v = tasm.to_vec();
249 v.remove(i + 1);
250 v.remove(i);
251 if verify_substitution(tasm, &v) {
252 variants.push(v);
253 }
254 }
255 "dup 0" if i + 1 < tasm.len() && tasm[i + 1] == "pop 1" => {
256 let mut v = tasm.to_vec();
258 v.remove(i + 1);
259 v.remove(i);
260 if verify_substitution(tasm, &v) {
261 variants.push(v);
262 }
263 }
264 "swap 1" if i + 1 < tasm.len() && tasm[i + 1] == "swap 1" => {
265 let mut v = tasm.to_vec();
267 v.remove(i + 1);
268 v.remove(i);
269 if verify_substitution(tasm, &v) {
270 variants.push(v);
271 }
272 }
273 _ => {}
274 }
275
276 if tasm[i] == "add" && i >= 1 {
278 let mut v = tasm.to_vec();
280 v.insert(i, "swap 1".to_string());
281 if verify_substitution(tasm, &v) {
282 variants.push(v);
283 }
284 }
285
286 if tasm[i] == "mul" && i >= 1 {
287 let mut v = tasm.to_vec();
289 v.insert(i, "swap 1".to_string());
290 if verify_substitution(tasm, &v) {
291 variants.push(v);
292 }
293 }
294 }
295
296 variants
297}
298
299fn verify_substitution(original: &[String], candidate: &[String]) -> bool {
301 (0..3).all(|seed| stack_verifier::verify_equivalent(original, candidate, seed * 7919 + 42))
303}
304
305fn insert_dead_code(graph: &TirGraph, rng: &mut Xorshift64) -> TirGraph {
312 use crate::neural::data::tir_graph::{EdgeKind, FieldType, OpKind, TirNode};
313
314 let mut nodes = graph.nodes.clone();
315 let mut edges = graph.edges.clone();
316
317 let num_insertions = 1 + (rng.next() % 3) as usize;
319
320 for _ in 0..num_insertions {
321 if nodes.is_empty() {
322 break;
323 }
324
325 let insert_at = (rng.next() % nodes.len() as u64) as usize;
327 let dead_kind = rng.next() % 3;
328
329 let dead_nodes: Vec<TirNode> = match dead_kind {
330 0 => {
331 vec![
333 TirNode {
334 op: OpKind::Push,
335 field_type: FieldType::BFE,
336 immediate: Some(0),
337 },
338 TirNode {
339 op: OpKind::Pop,
340 field_type: FieldType::Unknown,
341 immediate: Some(1),
342 },
343 ]
344 }
345 1 => {
346 vec![
348 TirNode {
349 op: OpKind::Push,
350 field_type: FieldType::BFE,
351 immediate: Some(0),
352 },
353 TirNode {
354 op: OpKind::Dup,
355 field_type: FieldType::BFE,
356 immediate: Some(0),
357 },
358 TirNode {
359 op: OpKind::Pop,
360 field_type: FieldType::Unknown,
361 immediate: Some(2),
362 },
363 ]
364 }
365 _ => {
366 vec![
368 TirNode {
369 op: OpKind::Push,
370 field_type: FieldType::BFE,
371 immediate: Some(0),
372 },
373 TirNode {
374 op: OpKind::Push,
375 field_type: FieldType::BFE,
376 immediate: Some(0),
377 },
378 TirNode {
379 op: OpKind::Add,
380 field_type: FieldType::BFE,
381 immediate: None,
382 },
383 TirNode {
384 op: OpKind::Pop,
385 field_type: FieldType::Unknown,
386 immediate: Some(1),
387 },
388 ]
389 }
390 };
391
392 let num_dead = dead_nodes.len();
393
394 for edge in edges.iter_mut() {
396 if edge.0 >= insert_at {
397 edge.0 += num_dead;
398 }
399 if edge.1 >= insert_at {
400 edge.1 += num_dead;
401 }
402 }
403
404 let mut new_nodes = nodes[..insert_at].to_vec();
406 new_nodes.extend(dead_nodes);
407 new_nodes.extend_from_slice(&nodes[insert_at..]);
408 nodes = new_nodes;
409
410 for j in 0..num_dead.saturating_sub(1) {
412 edges.push((insert_at + j, insert_at + j + 1, EdgeKind::ControlFlow));
413 }
414
415 if num_dead >= 2 {
417 edges.push((insert_at, insert_at + num_dead - 1, EdgeKind::DataDep));
418 }
419
420 if insert_at > 0 {
422 edges.push((insert_at - 1, insert_at, EdgeKind::ControlFlow));
423 }
424 if insert_at + num_dead < nodes.len() {
425 edges.push((
426 insert_at + num_dead - 1,
427 insert_at + num_dead,
428 EdgeKind::ControlFlow,
429 ));
430 }
431 }
432
433 TirGraph { nodes, edges }
434}
435
/// Minimal 64-bit xorshift pseudo-random generator using the 13 / 7 / 17
/// shift triple.
struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    /// Creates a generator from `seed`. The lowest bit is forced to 1, so the
    /// internal state is never zero; seeds `2k` and `2k + 1` therefore produce
    /// the same sequence.
    fn new(seed: u64) -> Self {
        Xorshift64 { state: seed | 1 }
    }

    /// Advances the state by one step and returns the new state.
    fn next(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }
}
459
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::tir::TIROp;
    use crate::neural::data::tir_graph::TirGraph;

    /// Builds an owned instruction listing from string literals.
    fn listing(lines: &[&str]) -> Vec<String> {
        lines.iter().map(|line| line.to_string()).collect()
    }

    /// Any variant returned by the random walk must verify against its input.
    /// A walk that returns `None` passes trivially.
    #[test]
    fn random_walk_preserves_equivalence() {
        let tasm = listing(&["push 3", "push 4", "add"]);
        let mut rng = Xorshift64::new(42);

        if let Some(variant) = random_walk_tasm(&tasm, 10, &mut rng) {
            assert!(stack_verifier::verify_equivalent(&tasm, &variant, 0));
        }
    }

    /// Every substitution variant must verify against the original listing.
    #[test]
    fn equivalent_substitutions_are_valid() {
        let tasm = listing(&["push 0", "add"]);

        for variant in equivalent_substitutions(&tasm).iter() {
            assert!(
                stack_verifier::verify_equivalent(&tasm, variant, 42),
                "substitution not equivalent: {:?}",
                variant,
            );
        }
    }

    /// `push 0; add` should be recognised and dropped, giving a shorter listing.
    #[test]
    fn push_0_add_removed() {
        let tasm = listing(&["push 5", "push 0", "add"]);
        let original_len = tasm.len();

        let has_shorter = equivalent_substitutions(&tasm)
            .iter()
            .any(|v| v.len() < original_len);
        assert!(has_shorter, "expected push 0; add to be removed");
    }

    /// Dead-code insertion must add at least one node to the graph.
    #[test]
    fn dead_code_increases_graph_size() {
        let ops = vec![TIROp::Push(1), TIROp::Push(2), TIROp::Add];
        let graph = TirGraph::from_tir_ops(&ops);
        let original_size = graph.num_nodes();

        let mut rng = Xorshift64::new(42);
        let augmented = insert_dead_code(&graph, &mut rng);

        assert!(
            augmented.num_nodes() > original_size,
            "dead code should increase graph size",
        );
    }

    /// A single seed pair must expand to more than one output pair.
    #[test]
    fn augment_pairs_multiplies_dataset() {
        let vocab = Vocab::new();
        let graph = TirGraph::from_tir_ops(&[TIROp::Push(1), TIROp::Push(2), TIROp::Add]);
        let tokens = vocab.encode_sequence(&listing(&["push 1", "push 2", "add"]));

        let pairs = vec![TrainingPair {
            graph,
            target_tokens: tokens,
            source_id: "test:0".into(),
            baseline_cost: 3,
        }];

        let config = AugmentConfig {
            tir_reorder_variants: 2,
            tasm_walk_variants: 3,
            max_swap_attempts: 5,
            seed: 42,
        };

        let augmented = augment_pairs(&pairs, &vocab, &config);
        assert!(
            augmented.len() > 1,
            "augmentation should produce more than original",
        );
    }

    /// A doubled `swap 1` should be recognised and dropped.
    #[test]
    fn swap_1_swap_1_eliminated() {
        let tasm = listing(&["push 1", "push 2", "swap 1", "swap 1", "add"]);
        let original_len = tasm.len();

        let has_shorter = equivalent_substitutions(&tasm)
            .iter()
            .any(|v| v.len() < original_len);
        assert!(has_shorter, "swap 1; swap 1 should be eliminated");
    }
}