1mod instructions;
2
3use cranelift::prelude::*;
4use cranelift_jit::{JITBuilder, JITModule};
5use cranelift_module::{FuncId, Linkage, Module, ModuleError};
6use instructions::FunctionCompiler;
7use rustpython_compiler_core::bytecode;
8use std::{fmt, mem::ManuallyDrop};
9
10#[derive(Debug, thiserror::Error)]
11#[non_exhaustive]
12pub enum JitCompileError {
13 #[error("function can't be jitted")]
14 NotSupported,
15 #[error("bad bytecode")]
16 BadBytecode,
17 #[error("error while compiling to machine code: {0}")]
18 CraneliftError(#[from] ModuleError),
19}
20
21#[derive(Debug, thiserror::Error, Eq, PartialEq)]
22#[non_exhaustive]
23pub enum JitArgumentError {
24 #[error("argument is of wrong type")]
25 ArgumentTypeMismatch,
26 #[error("wrong number of arguments")]
27 WrongNumberOfArguments,
28}
29
30struct Jit {
31 builder_context: FunctionBuilderContext,
32 ctx: codegen::Context,
33 module: JITModule,
34}
35
36impl Jit {
37 fn new() -> Self {
38 let builder = JITBuilder::new(cranelift_module::default_libcall_names())
39 .expect("Failed to build JITBuilder");
40 let module = JITModule::new(builder);
41 Self {
42 builder_context: FunctionBuilderContext::new(),
43 ctx: module.make_context(),
44 module,
45 }
46 }
47
48 fn build_function<C: bytecode::Constant>(
49 &mut self,
50 bytecode: &bytecode::CodeObject<C>,
51 args: &[JitType],
52 ) -> Result<(FuncId, JitSig), JitCompileError> {
53 for arg in args {
54 self.ctx
55 .func
56 .signature
57 .params
58 .push(AbiParam::new(arg.to_cranelift()));
59 }
60
61 let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
62 let entry_block = builder.create_block();
63 builder.append_block_params_for_function_params(entry_block);
64 builder.switch_to_block(entry_block);
65
66 let sig = {
67 let mut compiler =
68 FunctionCompiler::new(&mut builder, bytecode.varnames.len(), args, entry_block);
69
70 compiler.compile(bytecode)?;
71
72 compiler.sig
73 };
74
75 builder.seal_all_blocks();
76 builder.finalize();
77
78 let id = self.module.declare_function(
79 &format!("jit_{}", bytecode.obj_name.as_ref()),
80 Linkage::Export,
81 &self.ctx.func.signature,
82 )?;
83
84 self.module.define_function(id, &mut self.ctx)?;
85
86 self.module.clear_context(&mut self.ctx);
87
88 Ok((id, sig))
89 }
90}
91
92pub fn compile<C: bytecode::Constant>(
93 bytecode: &bytecode::CodeObject<C>,
94 args: &[JitType],
95) -> Result<CompiledCode, JitCompileError> {
96 let mut jit = Jit::new();
97
98 let (id, sig) = jit.build_function(bytecode, args)?;
99
100 jit.module.finalize_definitions();
101
102 let code = jit.module.get_finalized_function(id);
103 Ok(CompiledCode {
104 sig,
105 code,
106 module: ManuallyDrop::new(jit.module),
107 })
108}
109
110pub struct CompiledCode {
111 sig: JitSig,
112 code: *const u8,
113 module: ManuallyDrop<JITModule>,
114}
115
116impl CompiledCode {
117 pub fn args_builder(&self) -> ArgsBuilder<'_> {
118 ArgsBuilder::new(self)
119 }
120
121 pub fn invoke(&self, args: &[AbiValue]) -> Result<Option<AbiValue>, JitArgumentError> {
122 if self.sig.args.len() != args.len() {
123 return Err(JitArgumentError::WrongNumberOfArguments);
124 }
125
126 let cif_args = self
127 .sig
128 .args
129 .iter()
130 .zip(args.iter())
131 .map(|(ty, val)| type_check(ty, val).map(|_| val))
132 .map(|v| v.map(AbiValue::to_libffi_arg))
133 .collect::<Result<Vec<_>, _>>()?;
134 Ok(unsafe { self.invoke_raw(&cif_args) })
135 }
136
137 unsafe fn invoke_raw(&self, cif_args: &[libffi::middle::Arg]) -> Option<AbiValue> {
138 let cif = self.sig.to_cif();
139 let value = cif.call::<UnTypedAbiValue>(
140 libffi::middle::CodePtr::from_ptr(self.code as *const _),
141 cif_args,
142 );
143 self.sig.ret.as_ref().map(|ty| value.to_typed(ty))
144 }
145}
146
147struct JitSig {
148 args: Vec<JitType>,
149 ret: Option<JitType>,
150}
151
152impl JitSig {
153 fn to_cif(&self) -> libffi::middle::Cif {
154 let ret = match self.ret {
155 Some(ref ty) => ty.to_libffi(),
156 None => libffi::middle::Type::void(),
157 };
158 libffi::middle::Cif::new(self.args.iter().map(JitType::to_libffi), ret)
159 }
160}
161
162#[derive(Debug, Clone, PartialEq, Eq)]
163#[non_exhaustive]
164pub enum JitType {
165 Int,
166 Float,
167 Bool,
168}
169
170impl JitType {
171 fn to_cranelift(&self) -> types::Type {
172 match self {
173 Self::Int => types::I64,
174 Self::Float => types::F64,
175 Self::Bool => types::I8,
176 }
177 }
178
179 fn to_libffi(&self) -> libffi::middle::Type {
180 match self {
181 Self::Int => libffi::middle::Type::i64(),
182 Self::Float => libffi::middle::Type::f64(),
183 Self::Bool => libffi::middle::Type::u8(),
184 }
185 }
186}
187
188#[derive(Debug, Clone, PartialEq)]
189#[non_exhaustive]
190pub enum AbiValue {
191 Float(f64),
192 Int(i64),
193 Bool(bool),
194}
195
196impl AbiValue {
197 fn to_libffi_arg(&self) -> libffi::middle::Arg {
198 match self {
199 AbiValue::Int(ref i) => libffi::middle::Arg::new(i),
200 AbiValue::Float(ref f) => libffi::middle::Arg::new(f),
201 AbiValue::Bool(ref b) => libffi::middle::Arg::new(b),
202 }
203 }
204}
205
206impl From<i64> for AbiValue {
207 fn from(i: i64) -> Self {
208 AbiValue::Int(i)
209 }
210}
211
212impl From<f64> for AbiValue {
213 fn from(f: f64) -> Self {
214 AbiValue::Float(f)
215 }
216}
217
218impl From<bool> for AbiValue {
219 fn from(b: bool) -> Self {
220 AbiValue::Bool(b)
221 }
222}
223
224impl TryFrom<AbiValue> for i64 {
225 type Error = ();
226
227 fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
228 match value {
229 AbiValue::Int(i) => Ok(i),
230 _ => Err(()),
231 }
232 }
233}
234
235impl TryFrom<AbiValue> for f64 {
236 type Error = ();
237
238 fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
239 match value {
240 AbiValue::Float(f) => Ok(f),
241 _ => Err(()),
242 }
243 }
244}
245
246impl TryFrom<AbiValue> for bool {
247 type Error = ();
248
249 fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
250 match value {
251 AbiValue::Bool(b) => Ok(b),
252 _ => Err(()),
253 }
254 }
255}
256
257fn type_check(ty: &JitType, val: &AbiValue) -> Result<(), JitArgumentError> {
258 match (ty, val) {
259 (JitType::Int, AbiValue::Int(_))
260 | (JitType::Float, AbiValue::Float(_))
261 | (JitType::Bool, AbiValue::Bool(_)) => Ok(()),
262 _ => Err(JitArgumentError::ArgumentTypeMismatch),
263 }
264}
265
266#[derive(Copy, Clone)]
267union UnTypedAbiValue {
268 float: f64,
269 int: i64,
270 boolean: u8,
271 _void: (),
272}
273
274impl UnTypedAbiValue {
275 unsafe fn to_typed(self, ty: &JitType) -> AbiValue {
276 match ty {
277 JitType::Int => AbiValue::Int(self.int),
278 JitType::Float => AbiValue::Float(self.float),
279 JitType::Bool => AbiValue::Bool(self.boolean != 0),
280 }
281 }
282}
283
284unsafe impl Send for CompiledCode {}
287unsafe impl Sync for CompiledCode {}
288
289impl Drop for CompiledCode {
290 fn drop(&mut self) {
291 unsafe { ManuallyDrop::take(&mut self.module).free_memory() }
293 }
294}
295
296impl fmt::Debug for CompiledCode {
297 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298 f.write_str("[compiled code]")
299 }
300}
301
302pub struct ArgsBuilder<'a> {
303 values: Vec<Option<AbiValue>>,
304 code: &'a CompiledCode,
305}
306
307impl<'a> ArgsBuilder<'a> {
308 fn new(code: &'a CompiledCode) -> ArgsBuilder<'a> {
309 ArgsBuilder {
310 values: vec![None; code.sig.args.len()],
311 code,
312 }
313 }
314
315 pub fn set(&mut self, idx: usize, value: AbiValue) -> Result<(), JitArgumentError> {
316 type_check(&self.code.sig.args[idx], &value).map(|_| {
317 self.values[idx] = Some(value);
318 })
319 }
320
321 pub fn is_set(&self, idx: usize) -> bool {
322 self.values[idx].is_some()
323 }
324
325 pub fn into_args(self) -> Option<Args<'a>> {
326 self.values
327 .iter()
328 .map(|v| v.as_ref().map(AbiValue::to_libffi_arg))
329 .collect::<Option<_>>()
330 .map(|cif_args| Args {
331 _values: self.values,
332 cif_args,
333 code: self.code,
334 })
335 }
336}
337
338pub struct Args<'a> {
339 _values: Vec<Option<AbiValue>>,
340 cif_args: Vec<libffi::middle::Arg>,
341 code: &'a CompiledCode,
342}
343
344impl<'a> Args<'a> {
345 pub fn invoke(&self) -> Option<AbiValue> {
346 unsafe { self.code.invoke_raw(&self.cif_args) }
347 }
348}