// trident/neural/mod.rs

//! Neural compiler v2: GNN encoder + Transformer decoder.
//!
//! Replaces the v1 MLP evolutionary model with a ~13M parameter
//! architecture trained via supervised learning + GFlowNets.
//!
//! # Public API
//!
//! ```ignore
//! use trident::neural;
//! let result = neural::compile(&tir_ops, &baseline_tasm)?;
//! ```
pub mod checkpoint;
pub mod data;
pub mod inference;
pub mod model;
pub mod training;

use burn::backend::Wgpu;

use crate::ir::tir::TIROp;
use data::tir_graph::TirGraph;
use inference::beam::{beam_search, BeamConfig};
use inference::execute::validate_and_rank;
use model::vocab::Vocab;
use training::supervised::{graph_to_edges, graph_to_features};

/// Result of neural compilation.
///
/// Produced by [`compile`], [`compile_with_device`], and
/// [`compile_with_model`]; `neural == false` means the baseline TASM was
/// returned unchanged (fallback path).
#[derive(Debug, Clone)]
pub struct CompileResult {
    /// Optimized TASM instructions.
    pub tasm_lines: Vec<String>,
    /// Table cost (clock cycles) of the result.
    pub cost: u64,
    /// How many beam candidates were valid.
    pub valid_count: usize,
    /// Total beam candidates evaluated.
    pub total_count: usize,
    /// Whether this is a neural result (true) or fallback (false).
    pub neural: bool,
}
42/// Compile TIR ops to optimized TASM using the neural model.
43///
44/// Loads the production checkpoint, runs beam search (K=32, max_steps=256),
45/// validates candidates against baseline TASM, and returns the cheapest valid one.
46///
47/// Falls back to `baseline_tasm` if no valid candidate is found or if
48/// no trained checkpoint exists.
49pub fn compile(tir_ops: &[TIROp], baseline_tasm: &[String]) -> Result<CompileResult, String> {
50    let device = burn::backend::wgpu::WgpuDevice::default();
51    compile_with_device::<Wgpu>(tir_ops, baseline_tasm, &device)
52}
54/// Compile TIR ops with a specific burn backend device.
55pub fn compile_with_device<B: burn::prelude::Backend>(
56    tir_ops: &[TIROp],
57    baseline_tasm: &[String],
58    device: &B::Device,
59) -> Result<CompileResult, String> {
60    let vocab = Vocab::new();
61
62    // Build graph from TIR
63    let graph = TirGraph::from_tir_ops(tir_ops);
64    if graph.nodes.is_empty() {
65        return Ok(fallback_result(baseline_tasm));
66    }
67
68    // Load production checkpoint
69    let config = model::composite::NeuralCompilerConfig::new();
70    let model = config.init::<B>(device);
71    let model =
72        match checkpoint::load_checkpoint(model, checkpoint::CheckpointTag::Production, device) {
73            Ok(Some(loaded)) => loaded,
74            Ok(None) => {
75                // No checkpoint — try stage1_best as fallback
76                let model2 = config.init::<B>(device);
77                match checkpoint::load_checkpoint(
78                    model2,
79                    checkpoint::CheckpointTag::Stage1Best,
80                    device,
81                ) {
82                    Ok(Some(loaded)) => loaded,
83                    _ => return Ok(fallback_result(baseline_tasm)),
84                }
85            }
86            Err(_) => return Ok(fallback_result(baseline_tasm)),
87        };
88
89    // Encode graph
90    let node_features = graph_to_features::<B>(&graph, device);
91    let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, device);
92
93    // Beam search
94    let beam_config = BeamConfig::default(); // K=32, max_steps=256
95    let beam_result = beam_search(
96        &model.encoder,
97        &model.decoder,
98        node_features,
99        edge_src,
100        edge_dst,
101        edge_types,
102        &beam_config,
103        0, // must match training initial_stack_depth
104        device,
105    );
106
107    // Validate and rank
108    match validate_and_rank(&beam_result.sequences, &vocab, baseline_tasm, 0) {
109        Some(ranked) => Ok(CompileResult {
110            tasm_lines: ranked.tasm_lines,
111            cost: ranked.cost,
112            valid_count: ranked.valid_count,
113            total_count: ranked.total_count,
114            neural: true,
115        }),
116        None => Ok(fallback_result(baseline_tasm)),
117    }
118}
120/// Load the trained model once, for use with `compile_with_model`.
121/// Returns None if no checkpoint exists.
122pub fn load_model<B: burn::prelude::Backend>(
123    device: &B::Device,
124) -> Option<model::composite::NeuralCompilerV2<B>> {
125    let config = model::composite::NeuralCompilerConfig::new();
126    let m = config.init::<B>(device);
127    match checkpoint::load_checkpoint(m, checkpoint::CheckpointTag::Production, device) {
128        Ok(Some(loaded)) => Some(loaded),
129        Ok(None) => {
130            let m2 = config.init::<B>(device);
131            match checkpoint::load_checkpoint(m2, checkpoint::CheckpointTag::Stage1Best, device) {
132                Ok(Some(loaded)) => Some(loaded),
133                _ => None,
134            }
135        }
136        Err(_) => None,
137    }
138}
140/// Compile TIR ops using a pre-loaded model (avoids repeated checkpoint loading).
141pub fn compile_with_model<B: burn::prelude::Backend>(
142    tir_ops: &[TIROp],
143    baseline_tasm: &[String],
144    model: &model::composite::NeuralCompilerV2<B>,
145    device: &B::Device,
146) -> Result<CompileResult, String> {
147    let vocab = Vocab::new();
148
149    let graph = TirGraph::from_tir_ops(tir_ops);
150    if graph.nodes.is_empty() {
151        return Ok(fallback_result(baseline_tasm));
152    }
153
154    let node_features = graph_to_features::<B>(&graph, device);
155    let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, device);
156
157    let beam_config = BeamConfig::default();
158    let beam_result = beam_search(
159        &model.encoder,
160        &model.decoder,
161        node_features,
162        edge_src,
163        edge_dst,
164        edge_types,
165        &beam_config,
166        0,
167        device,
168    );
169
170    match validate_and_rank(&beam_result.sequences, &vocab, baseline_tasm, 0) {
171        Some(ranked) => Ok(CompileResult {
172            tasm_lines: ranked.tasm_lines,
173            cost: ranked.cost,
174            valid_count: ranked.valid_count,
175            total_count: ranked.total_count,
176            neural: true,
177        }),
178        None => Ok(fallback_result(baseline_tasm)),
179    }
180}
182fn fallback_result(baseline_tasm: &[String]) -> CompileResult {
183    use crate::cost::scorer::profile_tasm;
184    let refs: Vec<&str> = baseline_tasm.iter().map(|s| s.as_str()).collect();
185    let cost = profile_tasm(&refs).cost();
186    CompileResult {
187        tasm_lines: baseline_tasm.to_vec(),
188        cost,
189        valid_count: 0,
190        total_count: 0,
191        neural: false,
192    }