1pub mod checkpoint;
14pub mod data;
15pub mod inference;
16pub mod model;
17pub mod training;
18
19use burn::backend::Wgpu;
20
21use crate::ir::tir::TIROp;
22use data::tir_graph::TirGraph;
23use inference::beam::{beam_search, BeamConfig};
24use inference::execute::validate_and_rank;
25use model::vocab::Vocab;
26use training::supervised::{graph_to_edges, graph_to_features};
27
/// Outcome of one compilation attempt: either the neural pipeline's best
/// candidate or the baseline TASM passed through unchanged.
pub struct CompileResult {
    /// The emitted TASM program, one instruction per line.
    pub tasm_lines: Vec<String>,
    /// Profiled cost of `tasm_lines` (for the fallback path this comes from
    /// `profile_tasm(...).cost()` on the baseline).
    pub cost: u64,
    /// How many beam candidates validated successfully; 0 on the fallback path.
    pub valid_count: usize,
    /// Total beam candidates considered; 0 on the fallback path.
    pub total_count: usize,
    /// `true` when `tasm_lines` came from the neural model, `false` when the
    /// baseline was returned as-is.
    pub neural: bool,
}
41
42pub fn compile(tir_ops: &[TIROp], baseline_tasm: &[String]) -> Result<CompileResult, String> {
50 let device = burn::backend::wgpu::WgpuDevice::default();
51 compile_with_device::<Wgpu>(tir_ops, baseline_tasm, &device)
52}
53
54pub fn compile_with_device<B: burn::prelude::Backend>(
56 tir_ops: &[TIROp],
57 baseline_tasm: &[String],
58 device: &B::Device,
59) -> Result<CompileResult, String> {
60 let vocab = Vocab::new();
61
62 let graph = TirGraph::from_tir_ops(tir_ops);
64 if graph.nodes.is_empty() {
65 return Ok(fallback_result(baseline_tasm));
66 }
67
68 let config = model::composite::NeuralCompilerConfig::new();
70 let model = config.init::<B>(device);
71 let model =
72 match checkpoint::load_checkpoint(model, checkpoint::CheckpointTag::Production, device) {
73 Ok(Some(loaded)) => loaded,
74 Ok(None) => {
75 let model2 = config.init::<B>(device);
77 match checkpoint::load_checkpoint(
78 model2,
79 checkpoint::CheckpointTag::Stage1Best,
80 device,
81 ) {
82 Ok(Some(loaded)) => loaded,
83 _ => return Ok(fallback_result(baseline_tasm)),
84 }
85 }
86 Err(_) => return Ok(fallback_result(baseline_tasm)),
87 };
88
89 let node_features = graph_to_features::<B>(&graph, device);
91 let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, device);
92
93 let beam_config = BeamConfig::default(); let beam_result = beam_search(
96 &model.encoder,
97 &model.decoder,
98 node_features,
99 edge_src,
100 edge_dst,
101 edge_types,
102 &beam_config,
103 0, device,
105 );
106
107 match validate_and_rank(&beam_result.sequences, &vocab, baseline_tasm, 0) {
109 Some(ranked) => Ok(CompileResult {
110 tasm_lines: ranked.tasm_lines,
111 cost: ranked.cost,
112 valid_count: ranked.valid_count,
113 total_count: ranked.total_count,
114 neural: true,
115 }),
116 None => Ok(fallback_result(baseline_tasm)),
117 }
118}
119
120pub fn load_model<B: burn::prelude::Backend>(
123 device: &B::Device,
124) -> Option<model::composite::NeuralCompilerV2<B>> {
125 let config = model::composite::NeuralCompilerConfig::new();
126 let m = config.init::<B>(device);
127 match checkpoint::load_checkpoint(m, checkpoint::CheckpointTag::Production, device) {
128 Ok(Some(loaded)) => Some(loaded),
129 Ok(None) => {
130 let m2 = config.init::<B>(device);
131 match checkpoint::load_checkpoint(m2, checkpoint::CheckpointTag::Stage1Best, device) {
132 Ok(Some(loaded)) => Some(loaded),
133 _ => None,
134 }
135 }
136 Err(_) => None,
137 }
138}
139
140pub fn compile_with_model<B: burn::prelude::Backend>(
142 tir_ops: &[TIROp],
143 baseline_tasm: &[String],
144 model: &model::composite::NeuralCompilerV2<B>,
145 device: &B::Device,
146) -> Result<CompileResult, String> {
147 let vocab = Vocab::new();
148
149 let graph = TirGraph::from_tir_ops(tir_ops);
150 if graph.nodes.is_empty() {
151 return Ok(fallback_result(baseline_tasm));
152 }
153
154 let node_features = graph_to_features::<B>(&graph, device);
155 let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&graph, device);
156
157 let beam_config = BeamConfig::default();
158 let beam_result = beam_search(
159 &model.encoder,
160 &model.decoder,
161 node_features,
162 edge_src,
163 edge_dst,
164 edge_types,
165 &beam_config,
166 0,
167 device,
168 );
169
170 match validate_and_rank(&beam_result.sequences, &vocab, baseline_tasm, 0) {
171 Some(ranked) => Ok(CompileResult {
172 tasm_lines: ranked.tasm_lines,
173 cost: ranked.cost,
174 valid_count: ranked.valid_count,
175 total_count: ranked.total_count,
176 neural: true,
177 }),
178 None => Ok(fallback_result(baseline_tasm)),
179 }
180}
181
182fn fallback_result(baseline_tasm: &[String]) -> CompileResult {
183 use crate::cost::scorer::profile_tasm;
184 let refs: Vec<&str> = baseline_tasm.iter().map(|s| s.as_str()).collect();
185 let cost = profile_tasm(&refs).cost();
186 CompileResult {
187 tasm_lines: baseline_tasm.to_vec(),
188 cost,
189 valid_count: 0,
190 total_count: 0,
191 neural: false,
192 }
193}