1use cranelift_codegen::ir::{self, types, AbiParam};
2use cranelift_codegen::isa::TargetIsa;
3use cranelift_codegen::settings::{self, Configurable};
4use cranelift_codegen::Context;
5use cranelift_jit::{ArenaMemoryProvider, JITBuilder, JITModule};
6use cranelift_module::{FuncId, Linkage, Module};
7use std::sync::Arc;
8
9use crate::debug::LambdaRegistry;
10use crate::stack_map::{RawStackMap, StackMapRegistry};
11
/// Failure raised by one of the stages of [`CodegenPipeline`].
///
/// Every variant wraps a human-readable detail message describing the
/// underlying cause.
#[derive(Debug)]
pub enum PipelineError {
    /// Constructing the pipeline failed: a codegen flag was rejected, the host
    /// ISA could not be detected/built, or the JIT memory arena could not be
    /// reserved.
    Init(String),
    /// The module refused a function declaration.
    Declaration(String),
    /// Compiled output could not be obtained or used after a function was
    /// defined.
    Compilation(String),
    /// The module's `define_function` call failed.
    Definition(String),
    /// The module's `finalize_definitions` call failed.
    Finalization(String),
}
26
27impl std::fmt::Display for PipelineError {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 match self {
30 PipelineError::Init(e) => write!(f, "pipeline init failed: {}", e),
31 PipelineError::Declaration(e) => write!(f, "function declaration failed: {}", e),
32 PipelineError::Compilation(e) => write!(f, "compilation failed: {}", e),
33 PipelineError::Definition(e) => write!(f, "define_function failed: {}", e),
34 PipelineError::Finalization(e) => write!(f, "finalize_definitions failed: {}", e),
35 }
36 }
37}
38
39impl std::error::Error for PipelineError {}
40
41pub struct CodegenPipeline {
46 pub module: JITModule,
53 pub isa: Arc<dyn TargetIsa>,
55 pub stack_maps: StackMapRegistry,
57 pending_stack_maps: Vec<(FuncId, u32, Vec<RawStackMap>)>,
60 lambda_names: Vec<(FuncId, String)>,
62}
63
64impl CodegenPipeline {
65 pub fn new(symbols: &[(&str, *const u8)]) -> Result<Self, PipelineError> {
70 let mut flag_builder = settings::builder();
71 flag_builder
73 .set("preserve_frame_pointers", "true")
74 .map_err(|e| PipelineError::Init(format!("set preserve_frame_pointers: {e}")))?;
75 flag_builder
76 .set("opt_level", "speed")
77 .map_err(|e| PipelineError::Init(format!("set opt_level: {e}")))?;
78 flag_builder
81 .set("is_pic", "false")
82 .map_err(|e| PipelineError::Init(format!("set is_pic: {e}")))?;
83 flag_builder
84 .set("use_colocated_libcalls", "true")
85 .map_err(|e| PipelineError::Init(format!("set use_colocated_libcalls: {e}")))?;
86
87 let isa_builder = cranelift_native::builder()
88 .map_err(|e| PipelineError::Init(format!("host ISA: {e}")))?;
89 let isa = isa_builder
90 .finish(settings::Flags::new(flag_builder.clone()))
91 .map_err(|e| PipelineError::Init(format!("ISA finish: {e}")))?;
92
93 let mut jit_builder =
94 JITBuilder::with_isa(isa.clone(), cranelift_module::default_libcall_names());
95
96 for (name, ptr) in symbols {
97 jit_builder.symbol(*name, *ptr);
98 }
99
100 let arena = ArenaMemoryProvider::new_with_size(256 * 1024 * 1024)
104 .map_err(|e| PipelineError::Init(format!("JIT memory arena: {e}")))?;
105 jit_builder.memory_provider(Box::new(arena));
106
107 let module = JITModule::new(jit_builder);
108
109 Ok(Self {
110 module,
111 isa,
112 stack_maps: StackMapRegistry::new(),
113 pending_stack_maps: Vec::new(),
114 lambda_names: Vec::new(),
115 })
116 }
117
118 pub fn make_func_signature(&self) -> ir::Signature {
123 let mut sig = ir::Signature::new(self.isa.default_call_conv());
124 sig.params.push(AbiParam::new(types::I64)); sig.returns.push(AbiParam::new(types::I64)); sig
127 }
128
129 pub fn declare_function(&mut self, name: &str) -> Result<FuncId, PipelineError> {
131 let sig = self.make_func_signature();
132 self.module
133 .declare_function(name, Linkage::Export, &sig)
134 .map_err(|e| PipelineError::Declaration(format!("failed to declare `{}`: {}", name, e)))
135 }
136
137 pub fn define_function(
144 &mut self,
145 func_id: FuncId,
146 ctx: &mut Context,
147 ) -> Result<(), PipelineError> {
148 self.module
150 .define_function(func_id, ctx)
151 .map_err(|e| PipelineError::Definition(format!("{:?}", e)))?;
152
153 let compiled = ctx.compiled_code().ok_or_else(|| {
155 PipelineError::Compilation("compiled_code missing after define_function".into())
156 })?;
157 let func_size = compiled.code_buffer().len() as u32;
158 let raw_maps: Vec<RawStackMap> = compiled
159 .buffer
160 .user_stack_maps()
161 .iter()
162 .map(|(offset, span, usm)| {
163 let entries: Vec<_> = usm.entries().collect();
164 (*offset, *span, entries)
165 })
166 .collect();
167
168 self.pending_stack_maps.push((func_id, func_size, raw_maps));
169 Ok(())
170 }
171
172 pub fn finalize(&mut self) -> Result<(), PipelineError> {
175 self.module
176 .finalize_definitions()
177 .map_err(|e| PipelineError::Finalization(format!("{}", e)))?;
178
179 let pending = std::mem::take(&mut self.pending_stack_maps);
181 for (func_id, func_size, raw_maps) in pending {
182 let base_ptr = self.module.get_finalized_function(func_id) as usize;
183 self.stack_maps.register(base_ptr, func_size, &raw_maps);
184 }
185 Ok(())
186 }
187
188 pub fn get_function_ptr(&self, func_id: FuncId) -> *const u8 {
190 self.module.get_finalized_function(func_id)
191 }
192
193 pub fn register_lambda(&mut self, func_id: FuncId, name: String) {
195 self.lambda_names.push((func_id, name));
196 }
197
198 pub fn build_lambda_registry(&self) -> LambdaRegistry {
201 let mut registry = LambdaRegistry::new();
202 for (func_id, name) in &self.lambda_names {
203 let ptr = self.module.get_finalized_function(*func_id) as usize;
204 registry.register(ptr, name.clone());
205 }
206 registry
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use cranelift_codegen::ir::InstBuilder;
    use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};

    /// Rust-side view of a function built with `make_func_signature`
    /// (one i64/usize argument, one i64 result).
    type JitFn = unsafe extern "C" fn(usize) -> i64;

    /// Declares `name`, builds a body that ignores its argument and returns
    /// `value`, and defines it in `pipeline` without finalizing.
    fn define_const_fn(pipeline: &mut CodegenPipeline, name: &str, value: i64) -> FuncId {
        let func_id = pipeline.declare_function(name).unwrap();

        let mut ctx = pipeline.module.make_context();
        ctx.func.signature = pipeline.make_func_signature();

        let mut fn_ctx = FunctionBuilderContext::new();
        let mut fb = FunctionBuilder::new(&mut ctx.func, &mut fn_ctx);
        let entry = fb.create_block();
        fb.append_block_params_for_function_params(entry);
        fb.switch_to_block(entry);
        fb.seal_block(entry);
        let constant = fb.ins().iconst(types::I64, value);
        fb.ins().return_(&[constant]);
        fb.finalize();

        pipeline.define_function(func_id, &mut ctx).unwrap();
        func_id
    }

    #[test]
    fn test_empty_pipeline() {
        let mut pipeline = CodegenPipeline::new(&[]).unwrap();
        pipeline.finalize().unwrap();
    }

    #[test]
    fn test_declare_define_finalize() {
        let mut pipeline = CodegenPipeline::new(&[]).unwrap();
        let func_id = define_const_fn(&mut pipeline, "test_fn", 42);
        pipeline.finalize().unwrap();

        let ptr = pipeline.get_function_ptr(func_id);
        assert!(!ptr.is_null());

        // SAFETY: `ptr` is the finalized entry of a function built with
        // `make_func_signature` (i64 -> i64, host default call convention,
        // assumed here to match `extern "C"`).
        let func: JitFn = unsafe { std::mem::transmute(ptr) };
        let answer = unsafe { func(0) };
        assert_eq!(answer, 42);
    }

    #[test]
    fn test_duplicate_declarations() {
        let mut pipeline = CodegenPipeline::new(&[]).unwrap();
        let first = pipeline.declare_function("f1").unwrap();
        let second = pipeline.declare_function("f2").unwrap();
        assert_ne!(first, second);

        let first_again = pipeline.declare_function("f1").unwrap();
        assert_eq!(first, first_again);
    }

    #[test]
    fn test_get_function_ptr_after_finalize() {
        let mut pipeline = CodegenPipeline::new(&[]).unwrap();
        let func_id = define_const_fn(&mut pipeline, "f1", 0);
        pipeline.finalize().unwrap();

        assert!(!pipeline.get_function_ptr(func_id).is_null());
    }

    #[test]
    fn test_build_lambda_registry() {
        let mut pipeline = CodegenPipeline::new(&[]).unwrap();
        let func_id = define_const_fn(&mut pipeline, "f1", 0);
        pipeline.register_lambda(func_id, "my_lambda".to_string());
        pipeline.finalize().unwrap();

        let registry = pipeline.build_lambda_registry();
        let addr = pipeline.get_function_ptr(func_id) as usize;
        assert_eq!(registry.lookup(addr), Some("my_lambda"));
    }

    #[test]
    fn test_host_fn_symbols_integration() {
        extern "C" fn my_host_fn() -> i64 {
            123
        }
        let symbols = [("my_host_fn", my_host_fn as *const u8)];
        let mut pipeline = CodegenPipeline::new(&symbols).unwrap();

        let func_id = pipeline.declare_function("call_host").unwrap();
        let mut ctx = pipeline.module.make_context();
        ctx.func.signature = pipeline.make_func_signature();

        let mut fn_ctx = FunctionBuilderContext::new();
        let mut fb = FunctionBuilder::new(&mut ctx.func, &mut fn_ctx);
        let entry = fb.create_block();
        fb.append_block_params_for_function_params(entry);
        fb.switch_to_block(entry);
        fb.seal_block(entry);

        // Import the host symbol with a `() -> i64` signature.
        let mut host_sig = ir::Signature::new(pipeline.isa.default_call_conv());
        host_sig.returns.push(AbiParam::new(types::I64));
        let host_id = pipeline
            .module
            .declare_function("my_host_fn", Linkage::Import, &host_sig)
            .unwrap();
        let host_ref = pipeline.module.declare_func_in_func(host_id, &mut fb.func);

        let call = fb.ins().call(host_ref, &[]);
        let result = fb.inst_results(call)[0];
        fb.ins().return_(&[result]);
        fb.finalize();

        pipeline.define_function(func_id, &mut ctx).unwrap();
        pipeline.finalize().unwrap();

        // SAFETY: same signature contract as in `test_declare_define_finalize`.
        let func: JitFn = unsafe { std::mem::transmute(pipeline.get_function_ptr(func_id)) };
        assert_eq!(unsafe { func(0) }, 123);
    }
}
360}