Skip to main content

rustpython_jit/
lib.rs

1mod instructions;
2
3extern crate alloc;
4
5use alloc::fmt;
6use core::mem::ManuallyDrop;
7use cranelift::prelude::*;
8use cranelift_jit::{JITBuilder, JITModule};
9use cranelift_module::{FuncId, Linkage, Module, ModuleError};
10use instructions::FunctionCompiler;
11use rustpython_compiler_core::bytecode;
12
13#[derive(Debug, thiserror::Error)]
14#[non_exhaustive]
15pub enum JitCompileError {
16    #[error("function can't be jitted")]
17    NotSupported,
18    #[error("bad bytecode")]
19    BadBytecode,
20    #[error("error while compiling to machine code: {0}")]
21    CraneliftError(Box<ModuleError>),
22}
23
24impl From<ModuleError> for JitCompileError {
25    fn from(err: ModuleError) -> Self {
26        Self::CraneliftError(Box::new(err))
27    }
28}
29
30#[derive(Debug, thiserror::Error, Eq, PartialEq)]
31#[non_exhaustive]
32pub enum JitArgumentError {
33    #[error("argument is of wrong type")]
34    ArgumentTypeMismatch,
35    #[error("wrong number of arguments")]
36    WrongNumberOfArguments,
37}
38
39struct Jit {
40    builder_context: FunctionBuilderContext,
41    ctx: codegen::Context,
42    module: JITModule,
43}
44
45impl Jit {
46    fn new() -> Self {
47        let builder = JITBuilder::new(cranelift_module::default_libcall_names())
48            .expect("Failed to build JITBuilder");
49        let module = JITModule::new(builder);
50        Self {
51            builder_context: FunctionBuilderContext::new(),
52            ctx: module.make_context(),
53            module,
54        }
55    }
56
57    fn build_function<C: bytecode::Constant>(
58        &mut self,
59        bytecode: &bytecode::CodeObject<C>,
60        args: &[JitType],
61        ret: Option<JitType>,
62    ) -> Result<(FuncId, JitSig), JitCompileError> {
63        for arg in args {
64            self.ctx
65                .func
66                .signature
67                .params
68                .push(AbiParam::new(arg.to_cranelift()));
69        }
70
71        if ret.is_some() {
72            self.ctx
73                .func
74                .signature
75                .returns
76                .push(AbiParam::new(ret.clone().unwrap().to_cranelift()));
77        }
78
79        let id = self.module.declare_function(
80            &format!("jit_{}", bytecode.obj_name.as_ref()),
81            Linkage::Export,
82            &self.ctx.func.signature,
83        )?;
84
85        let func_ref = self.module.declare_func_in_func(id, &mut self.ctx.func);
86
87        let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
88        let entry_block = builder.create_block();
89        builder.append_block_params_for_function_params(entry_block);
90        builder.switch_to_block(entry_block);
91
92        let sig = {
93            let mut compiler = FunctionCompiler::new(
94                &mut builder,
95                bytecode.varnames.len(),
96                args,
97                ret,
98                entry_block,
99            );
100
101            compiler.compile(func_ref, bytecode)?;
102
103            compiler.sig
104        };
105
106        builder.seal_all_blocks();
107        builder.finalize();
108
109        self.module.define_function(id, &mut self.ctx)?;
110
111        self.module.clear_context(&mut self.ctx);
112
113        Ok((id, sig))
114    }
115}
116
117pub fn compile<C: bytecode::Constant>(
118    bytecode: &bytecode::CodeObject<C>,
119    args: &[JitType],
120    ret: Option<JitType>,
121) -> Result<CompiledCode, JitCompileError> {
122    let mut jit = Jit::new();
123
124    let (id, sig) = jit.build_function(bytecode, args, ret)?;
125
126    jit.module.finalize_definitions()?;
127
128    let code = jit.module.get_finalized_function(id);
129    Ok(CompiledCode {
130        sig,
131        code,
132        module: ManuallyDrop::new(jit.module),
133    })
134}
135
136pub struct CompiledCode {
137    sig: JitSig,
138    code: *const u8,
139    module: ManuallyDrop<JITModule>,
140}
141
142impl CompiledCode {
143    pub fn args_builder(&self) -> ArgsBuilder<'_> {
144        ArgsBuilder::new(self)
145    }
146
147    pub fn invoke(&self, args: &[AbiValue]) -> Result<Option<AbiValue>, JitArgumentError> {
148        if self.sig.args.len() != args.len() {
149            return Err(JitArgumentError::WrongNumberOfArguments);
150        }
151
152        let cif_args = self
153            .sig
154            .args
155            .iter()
156            .zip(args.iter())
157            .map(|(ty, val)| type_check(ty, val).map(|_| val))
158            .map(|v| v.map(AbiValue::to_libffi_arg))
159            .collect::<Result<Vec<_>, _>>()?;
160        Ok(unsafe { self.invoke_raw(&cif_args) })
161    }
162
163    unsafe fn invoke_raw(&self, cif_args: &[libffi::middle::Arg<'_>]) -> Option<AbiValue> {
164        unsafe {
165            let cif = self.sig.to_cif();
166            let value = cif.call::<UnTypedAbiValue>(
167                libffi::middle::CodePtr::from_ptr(self.code as *const _),
168                cif_args,
169            );
170            self.sig.ret.as_ref().map(|ty| value.to_typed(ty))
171        }
172    }
173}
174
175struct JitSig {
176    args: Vec<JitType>,
177    ret: Option<JitType>,
178}
179
180impl JitSig {
181    fn to_cif(&self) -> libffi::middle::Cif {
182        let ret = match self.ret {
183            Some(ref ty) => ty.to_libffi(),
184            None => libffi::middle::Type::void(),
185        };
186        libffi::middle::Cif::new(self.args.iter().map(JitType::to_libffi), ret)
187    }
188}
189
/// Value types the JIT can accept as parameters and produce as return values.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum JitType {
    /// 64-bit signed integer (`i64`).
    Int,
    /// 64-bit float (`f64`).
    Float,
    /// Boolean, passed natively as a single byte.
    Bool,
}
197
198impl JitType {
199    fn to_cranelift(&self) -> types::Type {
200        match self {
201            Self::Int => types::I64,
202            Self::Float => types::F64,
203            Self::Bool => types::I8,
204        }
205    }
206
207    fn to_libffi(&self) -> libffi::middle::Type {
208        match self {
209            Self::Int => libffi::middle::Type::i64(),
210            Self::Float => libffi::middle::Type::f64(),
211            Self::Bool => libffi::middle::Type::u8(),
212        }
213    }
214}
215
/// A concrete argument or return value exchanged with compiled code.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum AbiValue {
    /// Value for a [`JitType::Float`] slot.
    Float(f64),
    /// Value for a [`JitType::Int`] slot.
    Int(i64),
    /// Value for a [`JitType::Bool`] slot.
    Bool(bool),
}
223
224impl AbiValue {
225    fn to_libffi_arg(&self) -> libffi::middle::Arg<'_> {
226        match self {
227            AbiValue::Int(i) => libffi::middle::Arg::new(i),
228            AbiValue::Float(f) => libffi::middle::Arg::new(f),
229            AbiValue::Bool(b) => libffi::middle::Arg::new(b),
230        }
231    }
232}
233
234impl From<i64> for AbiValue {
235    fn from(i: i64) -> Self {
236        AbiValue::Int(i)
237    }
238}
239
240impl From<f64> for AbiValue {
241    fn from(f: f64) -> Self {
242        AbiValue::Float(f)
243    }
244}
245
246impl From<bool> for AbiValue {
247    fn from(b: bool) -> Self {
248        AbiValue::Bool(b)
249    }
250}
251
252impl TryFrom<AbiValue> for i64 {
253    type Error = ();
254
255    fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
256        match value {
257            AbiValue::Int(i) => Ok(i),
258            _ => Err(()),
259        }
260    }
261}
262
263impl TryFrom<AbiValue> for f64 {
264    type Error = ();
265
266    fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
267        match value {
268            AbiValue::Float(f) => Ok(f),
269            _ => Err(()),
270        }
271    }
272}
273
274impl TryFrom<AbiValue> for bool {
275    type Error = ();
276
277    fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
278        match value {
279            AbiValue::Bool(b) => Ok(b),
280            _ => Err(()),
281        }
282    }
283}
284
285fn type_check(ty: &JitType, val: &AbiValue) -> Result<(), JitArgumentError> {
286    match (ty, val) {
287        (JitType::Int, AbiValue::Int(_))
288        | (JitType::Float, AbiValue::Float(_))
289        | (JitType::Bool, AbiValue::Bool(_)) => Ok(()),
290        _ => Err(JitArgumentError::ArgumentTypeMismatch),
291    }
292}
293
/// Raw return slot for a libffi call. The field to read is chosen by the function's
/// [`JitType`] return type.
#[derive(Copy, Clone)]
union UnTypedAbiValue {
    /// Written for `JitType::Float` returns.
    float: f64,
    /// Written for `JitType::Int` returns.
    int: i64,
    /// Written for `JitType::Bool` returns.
    boolean: u8,
    // NOTE(review): presumably lets the union stand in for a `void` return; confirm.
    _void: (),
}
301
302impl UnTypedAbiValue {
303    unsafe fn to_typed(self, ty: &JitType) -> AbiValue {
304        unsafe {
305            match ty {
306                JitType::Int => AbiValue::Int(self.int),
307                JitType::Float => AbiValue::Float(self.float),
308                JitType::Bool => AbiValue::Bool(self.boolean != 0),
309            }
310        }
311    }
312}
313
314// we don't actually ever touch CompiledCode til we drop it, it should be safe.
315// TODO: confirm with wasmtime ppl that it's not unsound?
316unsafe impl Send for CompiledCode {}
317unsafe impl Sync for CompiledCode {}
318
319impl Drop for CompiledCode {
320    fn drop(&mut self) {
321        // SAFETY: The only pointer that this memory will also be dropped now
322        unsafe { ManuallyDrop::take(&mut self.module).free_memory() }
323    }
324}
325
326impl fmt::Debug for CompiledCode {
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        f.write_str("[compiled code]")
329    }
330}
331
332pub struct ArgsBuilder<'a> {
333    values: Vec<Option<AbiValue>>,
334    code: &'a CompiledCode,
335}
336
337impl<'a> ArgsBuilder<'a> {
338    fn new(code: &'a CompiledCode) -> ArgsBuilder<'a> {
339        ArgsBuilder {
340            values: vec![None; code.sig.args.len()],
341            code,
342        }
343    }
344
345    pub fn set(&mut self, idx: usize, value: AbiValue) -> Result<(), JitArgumentError> {
346        type_check(&self.code.sig.args[idx], &value).map(|_| {
347            self.values[idx] = Some(value);
348        })
349    }
350
351    pub fn is_set(&self, idx: usize) -> bool {
352        self.values[idx].is_some()
353    }
354
355    pub fn into_args(self) -> Option<Args<'a>> {
356        // Ensure all values are set
357        if self.values.iter().any(|v| v.is_none()) {
358            return None;
359        }
360        Some(Args {
361            values: self.values.into_iter().map(|v| v.unwrap()).collect(),
362            code: self.code,
363        })
364    }
365}
366
367pub struct Args<'a> {
368    values: Vec<AbiValue>,
369    code: &'a CompiledCode,
370}
371
372impl Args<'_> {
373    pub fn invoke(&self) -> Option<AbiValue> {
374        let cif_args: Vec<_> = self.values.iter().map(AbiValue::to_libffi_arg).collect();
375        unsafe { self.code.invoke_raw(&cif_args) }
376    }
377}