rustpython_jit/
lib.rs

1mod instructions;
2
3use cranelift::prelude::*;
4use cranelift_jit::{JITBuilder, JITModule};
5use cranelift_module::{FuncId, Linkage, Module, ModuleError};
6use instructions::FunctionCompiler;
7use rustpython_compiler_core::bytecode;
8use std::{fmt, mem::ManuallyDrop};
9
10#[derive(Debug, thiserror::Error)]
11#[non_exhaustive]
12pub enum JitCompileError {
13    #[error("function can't be jitted")]
14    NotSupported,
15    #[error("bad bytecode")]
16    BadBytecode,
17    #[error("error while compiling to machine code: {0}")]
18    CraneliftError(#[from] ModuleError),
19}
20
21#[derive(Debug, thiserror::Error, Eq, PartialEq)]
22#[non_exhaustive]
23pub enum JitArgumentError {
24    #[error("argument is of wrong type")]
25    ArgumentTypeMismatch,
26    #[error("wrong number of arguments")]
27    WrongNumberOfArguments,
28}
29
30struct Jit {
31    builder_context: FunctionBuilderContext,
32    ctx: codegen::Context,
33    module: JITModule,
34}
35
36impl Jit {
37    fn new() -> Self {
38        let builder = JITBuilder::new(cranelift_module::default_libcall_names())
39            .expect("Failed to build JITBuilder");
40        let module = JITModule::new(builder);
41        Self {
42            builder_context: FunctionBuilderContext::new(),
43            ctx: module.make_context(),
44            module,
45        }
46    }
47
48    fn build_function<C: bytecode::Constant>(
49        &mut self,
50        bytecode: &bytecode::CodeObject<C>,
51        args: &[JitType],
52    ) -> Result<(FuncId, JitSig), JitCompileError> {
53        for arg in args {
54            self.ctx
55                .func
56                .signature
57                .params
58                .push(AbiParam::new(arg.to_cranelift()));
59        }
60
61        let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
62        let entry_block = builder.create_block();
63        builder.append_block_params_for_function_params(entry_block);
64        builder.switch_to_block(entry_block);
65
66        let sig = {
67            let mut compiler =
68                FunctionCompiler::new(&mut builder, bytecode.varnames.len(), args, entry_block);
69
70            compiler.compile(bytecode)?;
71
72            compiler.sig
73        };
74
75        builder.seal_all_blocks();
76        builder.finalize();
77
78        let id = self.module.declare_function(
79            &format!("jit_{}", bytecode.obj_name.as_ref()),
80            Linkage::Export,
81            &self.ctx.func.signature,
82        )?;
83
84        self.module.define_function(id, &mut self.ctx)?;
85
86        self.module.clear_context(&mut self.ctx);
87
88        Ok((id, sig))
89    }
90}
91
92pub fn compile<C: bytecode::Constant>(
93    bytecode: &bytecode::CodeObject<C>,
94    args: &[JitType],
95) -> Result<CompiledCode, JitCompileError> {
96    let mut jit = Jit::new();
97
98    let (id, sig) = jit.build_function(bytecode, args)?;
99
100    jit.module.finalize_definitions();
101
102    let code = jit.module.get_finalized_function(id);
103    Ok(CompiledCode {
104        sig,
105        code,
106        module: ManuallyDrop::new(jit.module),
107    })
108}
109
110pub struct CompiledCode {
111    sig: JitSig,
112    code: *const u8,
113    module: ManuallyDrop<JITModule>,
114}
115
116impl CompiledCode {
117    pub fn args_builder(&self) -> ArgsBuilder<'_> {
118        ArgsBuilder::new(self)
119    }
120
121    pub fn invoke(&self, args: &[AbiValue]) -> Result<Option<AbiValue>, JitArgumentError> {
122        if self.sig.args.len() != args.len() {
123            return Err(JitArgumentError::WrongNumberOfArguments);
124        }
125
126        let cif_args = self
127            .sig
128            .args
129            .iter()
130            .zip(args.iter())
131            .map(|(ty, val)| type_check(ty, val).map(|_| val))
132            .map(|v| v.map(AbiValue::to_libffi_arg))
133            .collect::<Result<Vec<_>, _>>()?;
134        Ok(unsafe { self.invoke_raw(&cif_args) })
135    }
136
137    unsafe fn invoke_raw(&self, cif_args: &[libffi::middle::Arg]) -> Option<AbiValue> {
138        let cif = self.sig.to_cif();
139        let value = cif.call::<UnTypedAbiValue>(
140            libffi::middle::CodePtr::from_ptr(self.code as *const _),
141            cif_args,
142        );
143        self.sig.ret.as_ref().map(|ty| value.to_typed(ty))
144    }
145}
146
147struct JitSig {
148    args: Vec<JitType>,
149    ret: Option<JitType>,
150}
151
152impl JitSig {
153    fn to_cif(&self) -> libffi::middle::Cif {
154        let ret = match self.ret {
155            Some(ref ty) => ty.to_libffi(),
156            None => libffi::middle::Type::void(),
157        };
158        libffi::middle::Cif::new(self.args.iter().map(JitType::to_libffi), ret)
159    }
160}
161
/// Native value types the JIT can pass to and return from compiled code.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum JitType {
    /// A 64-bit signed integer.
    Int,
    /// A 64-bit float.
    Float,
    /// A boolean, carried as a single byte.
    Bool,
}
169
170impl JitType {
171    fn to_cranelift(&self) -> types::Type {
172        match self {
173            Self::Int => types::I64,
174            Self::Float => types::F64,
175            Self::Bool => types::I8,
176        }
177    }
178
179    fn to_libffi(&self) -> libffi::middle::Type {
180        match self {
181            Self::Int => libffi::middle::Type::i64(),
182            Self::Float => libffi::middle::Type::f64(),
183            Self::Bool => libffi::middle::Type::u8(),
184        }
185    }
186}
187
/// A concrete argument or return value exchanged with compiled code.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum AbiValue {
    /// Value for a `JitType::Float` slot.
    Float(f64),
    /// Value for a `JitType::Int` slot.
    Int(i64),
    /// Value for a `JitType::Bool` slot.
    Bool(bool),
}
195
196impl AbiValue {
197    fn to_libffi_arg(&self) -> libffi::middle::Arg {
198        match self {
199            AbiValue::Int(ref i) => libffi::middle::Arg::new(i),
200            AbiValue::Float(ref f) => libffi::middle::Arg::new(f),
201            AbiValue::Bool(ref b) => libffi::middle::Arg::new(b),
202        }
203    }
204}
205
206impl From<i64> for AbiValue {
207    fn from(i: i64) -> Self {
208        AbiValue::Int(i)
209    }
210}
211
212impl From<f64> for AbiValue {
213    fn from(f: f64) -> Self {
214        AbiValue::Float(f)
215    }
216}
217
218impl From<bool> for AbiValue {
219    fn from(b: bool) -> Self {
220        AbiValue::Bool(b)
221    }
222}
223
224impl TryFrom<AbiValue> for i64 {
225    type Error = ();
226
227    fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
228        match value {
229            AbiValue::Int(i) => Ok(i),
230            _ => Err(()),
231        }
232    }
233}
234
235impl TryFrom<AbiValue> for f64 {
236    type Error = ();
237
238    fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
239        match value {
240            AbiValue::Float(f) => Ok(f),
241            _ => Err(()),
242        }
243    }
244}
245
246impl TryFrom<AbiValue> for bool {
247    type Error = ();
248
249    fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
250        match value {
251            AbiValue::Bool(b) => Ok(b),
252            _ => Err(()),
253        }
254    }
255}
256
257fn type_check(ty: &JitType, val: &AbiValue) -> Result<(), JitArgumentError> {
258    match (ty, val) {
259        (JitType::Int, AbiValue::Int(_))
260        | (JitType::Float, AbiValue::Float(_))
261        | (JitType::Bool, AbiValue::Bool(_)) => Ok(()),
262        _ => Err(JitArgumentError::ArgumentTypeMismatch),
263    }
264}
265
/// Untyped buffer that receives the return value of a call into compiled code.
///
/// libffi writes the callee's result into this storage, and the caller reads back
/// the field that matches the function's return `JitType`. That read-back assumes
/// every field starts at offset 0, which only `#[repr(C)]` guarantees; the default
/// union representation leaves field offsets unspecified.
#[repr(C)]
#[derive(Copy, Clone)]
union UnTypedAbiValue {
    float: f64,
    int: i64,
    boolean: u8,
    // Zero-sized member; presumably lets the type also stand in for `void` returns.
    _void: (),
}
273
274impl UnTypedAbiValue {
275    unsafe fn to_typed(self, ty: &JitType) -> AbiValue {
276        match ty {
277            JitType::Int => AbiValue::Int(self.int),
278            JitType::Float => AbiValue::Float(self.float),
279            JitType::Bool => AbiValue::Bool(self.boolean != 0),
280        }
281    }
282}
283
284// we don't actually ever touch CompiledCode til we drop it, it should be safe.
285// TODO: confirm with wasmtime ppl that it's not unsound?
286unsafe impl Send for CompiledCode {}
287unsafe impl Sync for CompiledCode {}
288
289impl Drop for CompiledCode {
290    fn drop(&mut self) {
291        // SAFETY: The only pointer that this memory will also be dropped now
292        unsafe { ManuallyDrop::take(&mut self.module).free_memory() }
293    }
294}
295
296impl fmt::Debug for CompiledCode {
297    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298        f.write_str("[compiled code]")
299    }
300}
301
302pub struct ArgsBuilder<'a> {
303    values: Vec<Option<AbiValue>>,
304    code: &'a CompiledCode,
305}
306
307impl<'a> ArgsBuilder<'a> {
308    fn new(code: &'a CompiledCode) -> ArgsBuilder<'a> {
309        ArgsBuilder {
310            values: vec![None; code.sig.args.len()],
311            code,
312        }
313    }
314
315    pub fn set(&mut self, idx: usize, value: AbiValue) -> Result<(), JitArgumentError> {
316        type_check(&self.code.sig.args[idx], &value).map(|_| {
317            self.values[idx] = Some(value);
318        })
319    }
320
321    pub fn is_set(&self, idx: usize) -> bool {
322        self.values[idx].is_some()
323    }
324
325    pub fn into_args(self) -> Option<Args<'a>> {
326        self.values
327            .iter()
328            .map(|v| v.as_ref().map(AbiValue::to_libffi_arg))
329            .collect::<Option<_>>()
330            .map(|cif_args| Args {
331                _values: self.values,
332                cif_args,
333                code: self.code,
334            })
335    }
336}
337
338pub struct Args<'a> {
339    _values: Vec<Option<AbiValue>>,
340    cif_args: Vec<libffi::middle::Arg>,
341    code: &'a CompiledCode,
342}
343
344impl<'a> Args<'a> {
345    pub fn invoke(&self) -> Option<AbiValue> {
346        unsafe { self.code.invoke_raw(&self.cif_args) }
347    }
348}