1mod instructions;
2
3extern crate alloc;
4
5use alloc::fmt;
6use core::mem::ManuallyDrop;
7use cranelift::prelude::*;
8use cranelift_jit::{JITBuilder, JITModule};
9use cranelift_module::{FuncId, Linkage, Module, ModuleError};
10use instructions::FunctionCompiler;
11use rustpython_compiler_core::bytecode;
12
/// Errors that can occur while JIT-compiling a bytecode function.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum JitCompileError {
    /// The function can't be jitted. NOTE(review): presumably raised by the
    /// `instructions` module for unsupported constructs — confirm there.
    #[error("function can't be jitted")]
    NotSupported,
    /// The bytecode could not be interpreted by the compiler.
    #[error("bad bytecode")]
    BadBytecode,
    /// Cranelift reported an error while declaring, defining, or finalizing
    /// the function (converted from `ModuleError` via `From`).
    // Boxed, presumably to keep the error enum small — NOTE(review): confirm.
    #[error("error while compiling to machine code: {0}")]
    CraneliftError(Box<ModuleError>),
}
23
24impl From<ModuleError> for JitCompileError {
25 fn from(err: ModuleError) -> Self {
26 Self::CraneliftError(Box::new(err))
27 }
28}
29
/// Errors raised when arguments supplied to compiled code don't fit its signature.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum JitArgumentError {
    /// A value's variant does not match the declared parameter type.
    #[error("argument is of wrong type")]
    ArgumentTypeMismatch,
    /// The supplied arguments don't match the signature's arity.
    #[error("wrong number of arguments")]
    WrongNumberOfArguments,
}
38
/// Compilation state for building functions with Cranelift's JIT backend.
struct Jit {
    /// Context handed to each `FunctionBuilder` created in `build_function`.
    builder_context: FunctionBuilderContext,
    /// Codegen context holding the function currently being built.
    ctx: codegen::Context,
    /// JIT module that declares, defines, and finalizes functions.
    module: JITModule,
}
44
45impl Jit {
46 fn new() -> Self {
47 let builder = JITBuilder::new(cranelift_module::default_libcall_names())
48 .expect("Failed to build JITBuilder");
49 let module = JITModule::new(builder);
50 Self {
51 builder_context: FunctionBuilderContext::new(),
52 ctx: module.make_context(),
53 module,
54 }
55 }
56
57 fn build_function<C: bytecode::Constant>(
58 &mut self,
59 bytecode: &bytecode::CodeObject<C>,
60 args: &[JitType],
61 ret: Option<JitType>,
62 ) -> Result<(FuncId, JitSig), JitCompileError> {
63 for arg in args {
64 self.ctx
65 .func
66 .signature
67 .params
68 .push(AbiParam::new(arg.to_cranelift()));
69 }
70
71 if ret.is_some() {
72 self.ctx
73 .func
74 .signature
75 .returns
76 .push(AbiParam::new(ret.clone().unwrap().to_cranelift()));
77 }
78
79 let id = self.module.declare_function(
80 &format!("jit_{}", bytecode.obj_name.as_ref()),
81 Linkage::Export,
82 &self.ctx.func.signature,
83 )?;
84
85 let func_ref = self.module.declare_func_in_func(id, &mut self.ctx.func);
86
87 let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
88 let entry_block = builder.create_block();
89 builder.append_block_params_for_function_params(entry_block);
90 builder.switch_to_block(entry_block);
91
92 let sig = {
93 let mut compiler = FunctionCompiler::new(
94 &mut builder,
95 bytecode.varnames.len(),
96 args,
97 ret,
98 entry_block,
99 );
100
101 compiler.compile(func_ref, bytecode)?;
102
103 compiler.sig
104 };
105
106 builder.seal_all_blocks();
107 builder.finalize();
108
109 self.module.define_function(id, &mut self.ctx)?;
110
111 self.module.clear_context(&mut self.ctx);
112
113 Ok((id, sig))
114 }
115}
116
117pub fn compile<C: bytecode::Constant>(
118 bytecode: &bytecode::CodeObject<C>,
119 args: &[JitType],
120 ret: Option<JitType>,
121) -> Result<CompiledCode, JitCompileError> {
122 let mut jit = Jit::new();
123
124 let (id, sig) = jit.build_function(bytecode, args, ret)?;
125
126 jit.module.finalize_definitions()?;
127
128 let code = jit.module.get_finalized_function(id);
129 Ok(CompiledCode {
130 sig,
131 code,
132 module: ManuallyDrop::new(jit.module),
133 })
134}
135
/// A finalized, callable JIT-compiled function.
///
/// Owns the JIT module backing the machine code; that module's memory is
/// freed when this value is dropped.
pub struct CompiledCode {
    /// Argument and return types the function was compiled with.
    sig: JitSig,
    /// Entry point of the finalized machine code.
    code: *const u8,
    /// Wrapped in `ManuallyDrop` so `Drop` can take it by value and call
    /// `free_memory` on it.
    module: ManuallyDrop<JITModule>,
}
141
142impl CompiledCode {
143 pub fn args_builder(&self) -> ArgsBuilder<'_> {
144 ArgsBuilder::new(self)
145 }
146
147 pub fn invoke(&self, args: &[AbiValue]) -> Result<Option<AbiValue>, JitArgumentError> {
148 if self.sig.args.len() != args.len() {
149 return Err(JitArgumentError::WrongNumberOfArguments);
150 }
151
152 let cif_args = self
153 .sig
154 .args
155 .iter()
156 .zip(args.iter())
157 .map(|(ty, val)| type_check(ty, val).map(|_| val))
158 .map(|v| v.map(AbiValue::to_libffi_arg))
159 .collect::<Result<Vec<_>, _>>()?;
160 Ok(unsafe { self.invoke_raw(&cif_args) })
161 }
162
163 unsafe fn invoke_raw(&self, cif_args: &[libffi::middle::Arg<'_>]) -> Option<AbiValue> {
164 unsafe {
165 let cif = self.sig.to_cif();
166 let value = cif.call::<UnTypedAbiValue>(
167 libffi::middle::CodePtr::from_ptr(self.code as *const _),
168 cif_args,
169 );
170 self.sig.ret.as_ref().map(|ty| value.to_typed(ty))
171 }
172 }
173}
174
/// Argument and return types of a compiled function; used to build the
/// libffi call interface in `to_cif`.
struct JitSig {
    /// Parameter types, in call order.
    args: Vec<JitType>,
    /// Return type, or `None` for a function returning nothing (`void`).
    ret: Option<JitType>,
}
179
180impl JitSig {
181 fn to_cif(&self) -> libffi::middle::Cif {
182 let ret = match self.ret {
183 Some(ref ty) => ty.to_libffi(),
184 None => libffi::middle::Type::void(),
185 };
186 libffi::middle::Cif::new(self.args.iter().map(JitType::to_libffi), ret)
187 }
188}
189
/// Primitive value types the JIT can pass across the native call boundary.
///
/// `Int` maps to a 64-bit integer, `Float` to a 64-bit float, and `Bool` to a
/// single byte (see `to_cranelift` and `to_libffi`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum JitType {
    Int,
    Float,
    Bool,
}
197
198impl JitType {
199 fn to_cranelift(&self) -> types::Type {
200 match self {
201 Self::Int => types::I64,
202 Self::Float => types::F64,
203 Self::Bool => types::I8,
204 }
205 }
206
207 fn to_libffi(&self) -> libffi::middle::Type {
208 match self {
209 Self::Int => libffi::middle::Type::i64(),
210 Self::Float => libffi::middle::Type::f64(),
211 Self::Bool => libffi::middle::Type::u8(),
212 }
213 }
214}
215
/// A concrete argument or return value exchanged with compiled code.
///
/// Each variant pairs with the `JitType` variant of the same name (see
/// `type_check`).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum AbiValue {
    Float(f64),
    Int(i64),
    Bool(bool),
}
223
224impl AbiValue {
225 fn to_libffi_arg(&self) -> libffi::middle::Arg<'_> {
226 match self {
227 AbiValue::Int(i) => libffi::middle::Arg::new(i),
228 AbiValue::Float(f) => libffi::middle::Arg::new(f),
229 AbiValue::Bool(b) => libffi::middle::Arg::new(b),
230 }
231 }
232}
233
234impl From<i64> for AbiValue {
235 fn from(i: i64) -> Self {
236 AbiValue::Int(i)
237 }
238}
239
240impl From<f64> for AbiValue {
241 fn from(f: f64) -> Self {
242 AbiValue::Float(f)
243 }
244}
245
246impl From<bool> for AbiValue {
247 fn from(b: bool) -> Self {
248 AbiValue::Bool(b)
249 }
250}
251
252impl TryFrom<AbiValue> for i64 {
253 type Error = ();
254
255 fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
256 match value {
257 AbiValue::Int(i) => Ok(i),
258 _ => Err(()),
259 }
260 }
261}
262
263impl TryFrom<AbiValue> for f64 {
264 type Error = ();
265
266 fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
267 match value {
268 AbiValue::Float(f) => Ok(f),
269 _ => Err(()),
270 }
271 }
272}
273
274impl TryFrom<AbiValue> for bool {
275 type Error = ();
276
277 fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
278 match value {
279 AbiValue::Bool(b) => Ok(b),
280 _ => Err(()),
281 }
282 }
283}
284
285fn type_check(ty: &JitType, val: &AbiValue) -> Result<(), JitArgumentError> {
286 match (ty, val) {
287 (JitType::Int, AbiValue::Int(_))
288 | (JitType::Float, AbiValue::Float(_))
289 | (JitType::Bool, AbiValue::Bool(_)) => Ok(()),
290 _ => Err(JitArgumentError::ArgumentTypeMismatch),
291 }
292}
293
/// Raw return slot for a libffi call. Which field is valid depends on the
/// return type the call was made with.
///
/// `#[repr(C)]` is required: `Cif::call::<UnTypedAbiValue>` stores the native
/// return value into a buffer of this type, and `to_typed` reads it back
/// through one of the fields. With the default representation, union field
/// offsets are unspecified. `repr(C)` places every field at offset 0, so each
/// field aliases the start of the buffer.
#[repr(C)]
#[derive(Copy, Clone)]
union UnTypedAbiValue {
    float: f64,
    int: i64,
    boolean: u8,
    _void: (),
}
301
302impl UnTypedAbiValue {
303 unsafe fn to_typed(self, ty: &JitType) -> AbiValue {
304 unsafe {
305 match ty {
306 JitType::Int => AbiValue::Int(self.int),
307 JitType::Float => AbiValue::Float(self.float),
308 JitType::Bool => AbiValue::Bool(self.boolean != 0),
309 }
310 }
311 }
312}
313
// SAFETY: auto-derivation of `Send`/`Sync` is blocked at least by the raw
// `code` pointer. `CompiledCode`'s `&self` methods only read `sig` and call
// through `code`; `module` is touched solely in `Drop`, which has exclusive
// access. This relies on the finalized machine code not being mutated after
// `compile` returns and being safe to execute concurrently from any thread.
// NOTE(review): confirm this against cranelift-jit's guarantees for
// `JITModule` and its finalized memory.
unsafe impl Send for CompiledCode {}
unsafe impl Sync for CompiledCode {}
318
319impl Drop for CompiledCode {
320 fn drop(&mut self) {
321 unsafe { ManuallyDrop::take(&mut self.module).free_memory() }
323 }
324}
325
326impl fmt::Debug for CompiledCode {
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 f.write_str("[compiled code]")
329 }
330}
331
/// Incrementally collects type-checked arguments for a [`CompiledCode`] call.
pub struct ArgsBuilder<'a> {
    /// One slot per parameter of `code`; `None` until set.
    values: Vec<Option<AbiValue>>,
    /// The compiled function whose signature the values are checked against.
    code: &'a CompiledCode,
}
336
337impl<'a> ArgsBuilder<'a> {
338 fn new(code: &'a CompiledCode) -> ArgsBuilder<'a> {
339 ArgsBuilder {
340 values: vec![None; code.sig.args.len()],
341 code,
342 }
343 }
344
345 pub fn set(&mut self, idx: usize, value: AbiValue) -> Result<(), JitArgumentError> {
346 type_check(&self.code.sig.args[idx], &value).map(|_| {
347 self.values[idx] = Some(value);
348 })
349 }
350
351 pub fn is_set(&self, idx: usize) -> bool {
352 self.values[idx].is_some()
353 }
354
355 pub fn into_args(self) -> Option<Args<'a>> {
356 if self.values.iter().any(|v| v.is_none()) {
358 return None;
359 }
360 Some(Args {
361 values: self.values.into_iter().map(|v| v.unwrap()).collect(),
362 code: self.code,
363 })
364 }
365}
366
/// A complete, type-checked argument list ready to invoke a [`CompiledCode`].
pub struct Args<'a> {
    /// One value per parameter, in call order.
    values: Vec<AbiValue>,
    /// The compiled function these arguments were checked against.
    code: &'a CompiledCode,
}
371
372impl Args<'_> {
373 pub fn invoke(&self) -> Option<AbiValue> {
374 let cif_args: Vec<_> = self.values.iter().map(AbiValue::to_libffi_arg).collect();
375 unsafe { self.code.invoke_raw(&cif_args) }
376 }
377}