Skip to main content

trident/neural/training/
augment.rs

1//! Data augmentation for neural compiler training.
2//!
3//! Two families of augmentations:
4//!
5//! 1. **Structural** (TIR-side): reorder independent ops, insert dead code.
6//!    These change the graph topology while preserving semantics.
7//!
8//! 2. **Output-space** (TASM-side): swap adjacent independent instructions,
9//!    apply equivalent substitutions. Validated via stack_verifier.
10//!
11//! Target: 50 seed pairs → 5,000-10,000 augmented pairs.
12
13use crate::cost::stack_verifier;
14use crate::neural::data::pairs::TrainingPair;
15use crate::neural::data::tir_graph::TirGraph;
16use crate::neural::model::vocab::Vocab;
17
/// Configuration for augmentation pipeline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AugmentConfig {
    /// Number of TIR-side variants per seed pair.
    ///
    /// Despite the field name, each of these variants is produced by dead-code
    /// insertion into the TIR graph (`insert_dead_code`), not by reordering ops.
    pub tir_reorder_variants: usize,
    /// Number of TASM random-walk attempts per seed pair.
    ///
    /// An attempt whose walk yields no variant, or whose re-encoding has at most
    /// one token, contributes nothing, so this is an upper bound.
    pub tasm_walk_variants: usize,
    /// Max swap attempts per random walk.
    pub max_swap_attempts: usize,
    /// Random seed for reproducibility.
    ///
    /// The generator forces the lowest bit of the seed on, so seeds that differ
    /// only in bit 0 (e.g. `42` and `43`) produce identical augmentations.
    pub seed: u64,
}

impl Default for AugmentConfig {
    /// 10 dead-code variants, 50 walk attempts of up to 20 swaps each, and a
    /// fixed seed so default runs are repeatable.
    fn default() -> Self {
        Self {
            tir_reorder_variants: 10,
            tasm_walk_variants: 50,
            max_swap_attempts: 20,
            seed: 0xDEAD_BEEF_A097,
        }
    }
}
40
41/// Augment a set of training pairs using both structural and output-space methods.
42///
43/// Returns the original pairs plus all augmented variants.
44pub fn augment_pairs(
45    pairs: &[TrainingPair],
46    vocab: &Vocab,
47    config: &AugmentConfig,
48) -> Vec<TrainingPair> {
49    let mut result = Vec::with_capacity(
50        pairs.len() * (1 + config.tir_reorder_variants + config.tasm_walk_variants),
51    );
52    let mut rng = Xorshift64::new(config.seed);
53
54    for (pair_idx, pair) in pairs.iter().enumerate() {
55        // Keep original
56        result.push(TrainingPair {
57            graph: pair.graph.clone(),
58            target_tokens: pair.target_tokens.clone(),
59            source_id: pair.source_id.clone(),
60            baseline_cost: pair.baseline_cost,
61        });
62
63        // Decode target tokens back to TASM lines for TASM-side augmentation
64        let tasm_lines: Vec<String> = pair
65            .target_tokens
66            .iter()
67            .filter(|&&t| t != 0) // skip EOS
68            .filter_map(|&t| vocab.decode(t).map(|s| s.to_string()))
69            .collect();
70
71        // Output-space augmentation: random walk on TASM
72        for variant in 0..config.tasm_walk_variants {
73            if let Some(augmented_tasm) =
74                random_walk_tasm(&tasm_lines, config.max_swap_attempts, &mut rng)
75            {
76                let tokens = vocab.encode_sequence(&augmented_tasm);
77                if tokens.len() > 1 {
78                    result.push(TrainingPair {
79                        graph: pair.graph.clone(),
80                        target_tokens: tokens,
81                        source_id: format!("{}:walk{}", pair.source_id, variant),
82                        baseline_cost: pair.baseline_cost,
83                    });
84                }
85            }
86        }
87
88        // Equivalent substitutions on TASM
89        let sub_variants = equivalent_substitutions(&tasm_lines);
90        for (sub_idx, sub_tasm) in sub_variants.into_iter().enumerate() {
91            let tokens = vocab.encode_sequence(&sub_tasm);
92            if tokens.len() > 1 {
93                result.push(TrainingPair {
94                    graph: pair.graph.clone(),
95                    target_tokens: tokens,
96                    source_id: format!("{}:sub{}", pair.source_id, sub_idx),
97                    baseline_cost: pair.baseline_cost,
98                });
99            }
100        }
101
102        // TIR-side augmentation: dead code insertion
103        for variant in 0..config.tir_reorder_variants {
104            let augmented_tir = insert_dead_code(&pair.graph, &mut rng);
105            result.push(TrainingPair {
106                graph: augmented_tir,
107                target_tokens: pair.target_tokens.clone(),
108                source_id: format!("{}:dead{}", pair.source_id, variant),
109                baseline_cost: pair.baseline_cost,
110            });
111        }
112
113        if (pair_idx + 1) % 10 == 0 {
114            eprintln!(
115                "  augmented {}/{} seed pairs ({} total)",
116                pair_idx + 1,
117                pairs.len(),
118                result.len()
119            );
120        }
121    }
122
123    eprintln!(
124        "  augmentation: {} seeds → {} pairs ({:.1}x)",
125        pairs.len(),
126        result.len(),
127        result.len() as f64 / pairs.len().max(1) as f64,
128    );
129
130    result
131}
132
133// ─── TASM Random Walk ─────────────────────────────────────────────
134
135/// Apply random adjacent swaps to TASM, keeping only valid variants.
136///
137/// Strategy: try swapping adjacent instructions. If the result passes
138/// `verify_equivalent()` on multiple random inputs, accept the swap.
139fn random_walk_tasm(
140    tasm: &[String],
141    max_attempts: usize,
142    rng: &mut Xorshift64,
143) -> Option<Vec<String>> {
144    if tasm.len() < 2 {
145        return None;
146    }
147
148    let mut current = tasm.to_vec();
149    let mut changed = false;
150
151    for _ in 0..max_attempts {
152        let i = (rng.next() % (current.len() - 1) as u64) as usize;
153
154        // Skip swaps that would reorder dependent instructions
155        if instructions_are_independent(&current[i], &current[i + 1]) {
156            current.swap(i, i + 1);
157
158            // Verify equivalence on 3 random seeds
159            let valid = (0..3u64).all(|trial| {
160                let seed = rng.next() ^ trial.wrapping_mul(0x9E3779B97F4A7C15);
161                stack_verifier::verify_equivalent(tasm, &current, seed)
162            });
163
164            if valid {
165                changed = true;
166            } else {
167                // Revert
168                current.swap(i, i + 1);
169            }
170        }
171    }
172
173    if changed {
174        Some(current)
175    } else {
176        None
177    }
178}
179
/// Heuristic pre-filter: may two adjacent TASM instructions be swapped?
///
/// Only the opcode (the first whitespace-separated word) is inspected. Returns
/// `true` when either instruction is `nop`, or when both opcodes are one of
/// `push`, `divine`, `read_io` (instructions that only push onto the stack).
/// Returns `false` when either line is blank, and for every other combination.
///
/// Swapping two pushes changes which value ends up on top, so `true` here does
/// not by itself imply equivalence. Callers must still verify the swapped
/// program.
fn instructions_are_independent(a: &str, b: &str) -> bool {
    // Ops that push without popping anything first.
    let pushes_only = |op: &str| matches!(op, "push" | "divine" | "read_io");

    match (a.split_whitespace().next(), b.split_whitespace().next()) {
        (Some(first), Some(second)) => {
            // A nop has no stack effect, so it can move past anything.
            first == "nop" || second == "nop" || (pushes_only(first) && pushes_only(second))
        }
        // A blank line has no opcode to reason about.
        _ => false,
    }
}
217
218// ─── Equivalent Substitutions ─────────────────────────────────────
219
220/// Apply pattern-based equivalent substitutions to TASM.
221///
222/// Returns all valid single-substitution variants.
223fn equivalent_substitutions(tasm: &[String]) -> Vec<Vec<String>> {
224    let mut variants = Vec::new();
225
226    for i in 0..tasm.len() {
227        // Single-instruction substitutions
228        match tasm[i].as_str() {
229            "nop" => {
230                // nop → (remove)
231                let mut v = tasm.to_vec();
232                v.remove(i);
233                if verify_substitution(tasm, &v) {
234                    variants.push(v);
235                }
236            }
237            "push 0" if i + 1 < tasm.len() && tasm[i + 1] == "add" => {
238                // push 0; add → (remove both — identity)
239                let mut v = tasm.to_vec();
240                v.remove(i + 1);
241                v.remove(i);
242                if verify_substitution(tasm, &v) {
243                    variants.push(v);
244                }
245            }
246            "push 1" if i + 1 < tasm.len() && tasm[i + 1] == "mul" => {
247                // push 1; mul → (remove both — identity)
248                let mut v = tasm.to_vec();
249                v.remove(i + 1);
250                v.remove(i);
251                if verify_substitution(tasm, &v) {
252                    variants.push(v);
253                }
254            }
255            "dup 0" if i + 1 < tasm.len() && tasm[i + 1] == "pop 1" => {
256                // dup 0; pop 1 → (remove both — noop)
257                let mut v = tasm.to_vec();
258                v.remove(i + 1);
259                v.remove(i);
260                if verify_substitution(tasm, &v) {
261                    variants.push(v);
262                }
263            }
264            "swap 1" if i + 1 < tasm.len() && tasm[i + 1] == "swap 1" => {
265                // swap 1; swap 1 → (remove both — identity)
266                let mut v = tasm.to_vec();
267                v.remove(i + 1);
268                v.remove(i);
269                if verify_substitution(tasm, &v) {
270                    variants.push(v);
271                }
272            }
273            _ => {}
274        }
275
276        // Expansion substitutions (make longer but equivalent)
277        if tasm[i] == "add" && i >= 1 {
278            // add → swap 1; add (commutativity — same result)
279            let mut v = tasm.to_vec();
280            v.insert(i, "swap 1".to_string());
281            if verify_substitution(tasm, &v) {
282                variants.push(v);
283            }
284        }
285
286        if tasm[i] == "mul" && i >= 1 {
287            // mul → swap 1; mul (commutativity — same result)
288            let mut v = tasm.to_vec();
289            v.insert(i, "swap 1".to_string());
290            if verify_substitution(tasm, &v) {
291                variants.push(v);
292            }
293        }
294    }
295
296    variants
297}
298
299/// Verify that a substituted TASM sequence is equivalent to the original.
300fn verify_substitution(original: &[String], candidate: &[String]) -> bool {
301    // Test on 3 different random seeds
302    (0..3).all(|seed| stack_verifier::verify_equivalent(original, candidate, seed * 7919 + 42))
303}
304
305// ─── Dead Code Insertion (TIR-side) ──────────────────────────────
306
307/// Insert dead code nodes into a TirGraph.
308///
309/// Adds operations that don't affect the output: push+pop pairs,
310/// dup+pop pairs, nop sequences. The model must learn to ignore these.
311fn insert_dead_code(graph: &TirGraph, rng: &mut Xorshift64) -> TirGraph {
312    use crate::neural::data::tir_graph::{EdgeKind, FieldType, OpKind, TirNode};
313
314    let mut nodes = graph.nodes.clone();
315    let mut edges = graph.edges.clone();
316
317    // Number of dead code insertions: 1-3
318    let num_insertions = 1 + (rng.next() % 3) as usize;
319
320    for _ in 0..num_insertions {
321        if nodes.is_empty() {
322            break;
323        }
324
325        // Pick random insertion point
326        let insert_at = (rng.next() % nodes.len() as u64) as usize;
327        let dead_kind = rng.next() % 3;
328
329        let dead_nodes: Vec<TirNode> = match dead_kind {
330            0 => {
331                // push + pop pair
332                vec![
333                    TirNode {
334                        op: OpKind::Push,
335                        field_type: FieldType::BFE,
336                        immediate: Some(0),
337                    },
338                    TirNode {
339                        op: OpKind::Pop,
340                        field_type: FieldType::Unknown,
341                        immediate: Some(1),
342                    },
343                ]
344            }
345            1 => {
346                // dup 0 + pop 1 (if stack nonempty — conservative: always add push first)
347                vec![
348                    TirNode {
349                        op: OpKind::Push,
350                        field_type: FieldType::BFE,
351                        immediate: Some(0),
352                    },
353                    TirNode {
354                        op: OpKind::Dup,
355                        field_type: FieldType::BFE,
356                        immediate: Some(0),
357                    },
358                    TirNode {
359                        op: OpKind::Pop,
360                        field_type: FieldType::Unknown,
361                        immediate: Some(2),
362                    },
363                ]
364            }
365            _ => {
366                // Single nop-like: push 0; push 0; add; pop 1
367                vec![
368                    TirNode {
369                        op: OpKind::Push,
370                        field_type: FieldType::BFE,
371                        immediate: Some(0),
372                    },
373                    TirNode {
374                        op: OpKind::Push,
375                        field_type: FieldType::BFE,
376                        immediate: Some(0),
377                    },
378                    TirNode {
379                        op: OpKind::Add,
380                        field_type: FieldType::BFE,
381                        immediate: None,
382                    },
383                    TirNode {
384                        op: OpKind::Pop,
385                        field_type: FieldType::Unknown,
386                        immediate: Some(1),
387                    },
388                ]
389            }
390        };
391
392        let num_dead = dead_nodes.len();
393
394        // Shift all edge indices >= insert_at by num_dead
395        for edge in edges.iter_mut() {
396            if edge.0 >= insert_at {
397                edge.0 += num_dead;
398            }
399            if edge.1 >= insert_at {
400                edge.1 += num_dead;
401            }
402        }
403
404        // Insert dead nodes
405        let mut new_nodes = nodes[..insert_at].to_vec();
406        new_nodes.extend(dead_nodes);
407        new_nodes.extend_from_slice(&nodes[insert_at..]);
408        nodes = new_nodes;
409
410        // Add control flow edges within dead code
411        for j in 0..num_dead.saturating_sub(1) {
412            edges.push((insert_at + j, insert_at + j + 1, EdgeKind::ControlFlow));
413        }
414
415        // Add data dep edges within dead code (push→pop, push→dup, etc.)
416        if num_dead >= 2 {
417            edges.push((insert_at, insert_at + num_dead - 1, EdgeKind::DataDep));
418        }
419
420        // Connect to surrounding control flow
421        if insert_at > 0 {
422            edges.push((insert_at - 1, insert_at, EdgeKind::ControlFlow));
423        }
424        if insert_at + num_dead < nodes.len() {
425            edges.push((
426                insert_at + num_dead - 1,
427                insert_at + num_dead,
428                EdgeKind::ControlFlow,
429            ));
430        }
431    }
432
433    TirGraph { nodes, edges }
434}
435
436// ─── PRNG ─────────────────────────────────────────────────────────
437
/// Minimal xorshift64 pseudo-random generator (shift triple 13, 7, 17).
///
/// The output sequence depends only on the seed, which keeps augmentation
/// reproducible. Not cryptographically secure.
struct Xorshift64 {
    /// Current state. `new` forces the low bit on, so the state starts
    /// non-zero, and each xorshift step maps non-zero states to non-zero states.
    state: u64,
}

impl Xorshift64 {
    /// Creates a generator from `seed`.
    ///
    /// The low bit is forced to 1 because an all-zero state would emit zeros
    /// forever. As a consequence, seeds `2k` and `2k + 1` yield the same
    /// sequence.
    fn new(seed: u64) -> Self {
        let state = seed | 1;
        Xorshift64 { state }
    }

    /// Advances the state by one step and returns the new state.
    fn next(&mut self) -> u64 {
        let s = self.state;
        let s = s ^ (s << 13);
        let s = s ^ (s >> 7);
        let s = s ^ (s << 17);
        self.state = s;
        s
    }
}
459
460// ─── Tests ────────────────────────────────────────────────────────
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::tir::TIROp;
    use crate::neural::data::tir_graph::TirGraph;

    #[test]
    fn random_walk_preserves_equivalence() {
        let tasm = vec![
            "push 3".to_string(),
            "push 4".to_string(),
            "add".to_string(),
        ];
        let mut rng = Xorshift64::new(42);
        // Might or might not produce a variant (depends on RNG)
        let result = random_walk_tasm(&tasm, 10, &mut rng);
        if let Some(ref variant) = result {
            // Must be equivalent
            assert!(stack_verifier::verify_equivalent(&tasm, variant, 0));
        }
    }

    #[test]
    fn equivalent_substitutions_are_valid() {
        let tasm = vec!["push 0".to_string(), "add".to_string()];
        let variants = equivalent_substitutions(&tasm);
        for variant in &variants {
            assert!(
                stack_verifier::verify_equivalent(&tasm, variant, 42),
                "substitution not equivalent: {:?}",
                variant,
            );
        }
    }

    #[test]
    fn push_0_add_removed() {
        let tasm = vec![
            "push 5".to_string(),
            "push 0".to_string(),
            "add".to_string(),
        ];
        let variants = equivalent_substitutions(&tasm);
        // Should find the push 0; add → remove variant
        let has_shorter = variants.iter().any(|v| v.len() < tasm.len());
        assert!(has_shorter, "expected push 0; add to be removed");
    }

    #[test]
    fn dead_code_increases_graph_size() {
        let ops = vec![TIROp::Push(1), TIROp::Push(2), TIROp::Add];
        let graph = TirGraph::from_tir_ops(&ops);
        let original_size = graph.num_nodes();

        let mut rng = Xorshift64::new(42);
        let augmented = insert_dead_code(&graph, &mut rng);
        assert!(
            augmented.num_nodes() > original_size,
            "dead code should increase graph size",
        );
    }

    #[test]
    fn augment_pairs_multiplies_dataset() {
        let vocab = Vocab::new();
        let graph = TirGraph::from_tir_ops(&[TIROp::Push(1), TIROp::Push(2), TIROp::Add]);
        let tokens = vocab.encode_sequence(&[
            "push 1".to_string(),
            "push 2".to_string(),
            "add".to_string(),
        ]);

        let pairs = vec![TrainingPair {
            graph,
            target_tokens: tokens,
            source_id: "test:0".into(),
            baseline_cost: 3,
        }];

        let config = AugmentConfig {
            tir_reorder_variants: 2,
            tasm_walk_variants: 3,
            max_swap_attempts: 5,
            seed: 42,
        };

        let augmented = augment_pairs(&pairs, &vocab, &config);
        assert!(
            augmented.len() > 1,
            "augmentation should produce more than original",
        );
    }

    #[test]
    fn swap_1_swap_1_eliminated() {
        let tasm = vec![
            "push 1".to_string(),
            "push 2".to_string(),
            "swap 1".to_string(),
            "swap 1".to_string(),
            "add".to_string(),
        ];
        let variants = equivalent_substitutions(&tasm);
        let has_shorter = variants.iter().any(|v| v.len() < tasm.len());
        assert!(has_shorter, "swap 1; swap 1 should be eliminated");
    }

    #[test]
    fn independence_heuristic_classifies_opcodes() {
        // Push-only pairs and anything next to a nop are candidates.
        assert!(instructions_are_independent("push 1", "push 2"));
        assert!(instructions_are_independent("divine 1", "read_io 1"));
        assert!(instructions_are_independent("nop", "add"));
        assert!(instructions_are_independent("mul", "nop"));
        // Anything that pops, and blank lines, are not.
        assert!(!instructions_are_independent("push 1", "add"));
        assert!(!instructions_are_independent("add", "mul"));
        assert!(!instructions_are_independent("", "nop"));
    }

    #[test]
    fn xorshift_is_deterministic_and_never_zero() {
        let mut a = Xorshift64::new(42);
        let mut b = Xorshift64::new(42);
        for _ in 0..1000 {
            let x = a.next();
            assert_ne!(x, 0);
            assert_eq!(x, b.next());
        }
        // A zero seed is remapped to a usable state with a known first output.
        assert_eq!(Xorshift64::new(0).next(), 1_082_269_761);
    }
}