Skip to main content

wry_bindgen_runtime/wire/
callback.rs

1//! Callback registration encoding and Rust callback storage.
2
3#![allow(clippy::type_complexity)]
4
5use alloc::boxed::Box;
6use alloc::rc::Rc;
7use core::cell::RefCell;
8
9use super::{
10    BinaryDecode, BinaryEncode, DecodeError, DecodedData, EncodeTypeDef, EncodedData, TypeDef,
11    object_store::ObjectHandle,
12};
13
14type CallbackFn = dyn Fn(&mut DecodedData, &mut EncodedData) -> Result<(), DecodeError>;
15
16#[derive(Clone)]
17pub struct RustCallback {
18    f: Rc<CallbackFn>,
19}
20
21impl RustCallback {
22    pub fn new_fn<F>(f: F) -> Self
23    where
24        F: Fn(&mut DecodedData, &mut EncodedData) -> Result<(), DecodeError> + 'static,
25    {
26        Self { f: Rc::new(f) }
27    }
28
29    pub fn new_fn_mut<F>(f: F) -> Self
30    where
31        F: FnMut(&mut DecodedData, &mut EncodedData) -> Result<(), DecodeError> + 'static,
32    {
33        let cell = RefCell::new(f);
34        Self {
35            f: Rc::new(move |data: &mut DecodedData, encoder: &mut EncodedData| {
36                let mut f = cell.borrow_mut();
37                f(data, encoder)
38            }),
39        }
40    }
41
42    pub fn call(
43        &self,
44        data: &mut DecodedData,
45        encoder: &mut EncodedData,
46    ) -> Result<(), DecodeError> {
47        (self.f)(data, encoder)
48    }
49}
50
51const RUST_OWNED_CALLBACK_POLICY: u32 = 0;
52
53fn encode_rust_owned_callback(handle: ObjectHandle, encoder: &mut EncodedData) {
54    handle.encode(encoder);
55    RUST_OWNED_CALLBACK_POLICY.encode(encoder);
56}
57
/// Emits the type definition of a callback into `$encoder` (a `&mut TypeDef`).
///
/// * `($encoder; R = Ret; A1, A2, ..)` delegates to
///   `$encoder.callback::<fn(A1, A2, ..) -> Ret>()`.
/// * `($encoder; R = Ret; borrow_first; A2, ..)` describes a callback whose
///   first parameter is a borrowed reference. It passes the parameter count
///   (the borrowed one plus one per listed type) to `callback_with_signature`.
///   Inside that call it writes a borrowed-reference marker, then each listed
///   parameter's type definition in order, and finally the return type's.
macro_rules! callback_type_def_body {
    ($encoder:expr; R = $R:ty; $($arg:ty),*) => {{
        $encoder.callback::<fn($($arg),*) -> $R>();
    }};
    ($encoder:expr; R = $R:ty; borrow_first; $($rest:ty),*) => {{
        // One slot for the borrowed first parameter plus one per `$rest`. The
        // `PhantomData` expression only exists so `$rest` can drive the
        // repetition. It is spelled with a full path because `macro_rules!`
        // resolves paths at the expansion site, which need not import it.
        let count: u8 = 1 $(+ {
            let _ = core::marker::PhantomData::<$rest>;
            1
        })*;
        $encoder.callback_with_signature(count, |type_def| {
            type_def.borrowed_ref();
            $(<$rest as EncodeTypeDef>::encode_type_def(type_def);)*
            <$R as EncodeTypeDef>::encode_type_def(type_def);
        });
    }};
}
74
/// Boxes `$value` and stores it through the current runtime's
/// `insert_object_box`, evaluating to whatever that call returns. The callers
/// in this file use the result as the stored callback's object handle.
macro_rules! insert_callback {
    ($value:expr) => {
        crate::batch::with_runtime(|runtime| {
            let boxed = Box::new($value);
            runtime.insert_object_box(boxed)
        })
    };
}
78
/// Generates a `BinaryEncode` impl for a borrowed closure type so that it can
/// be handed across the wire as a callback.
///
/// Matcher forms (the `via` clause selects how the closure is invoked):
///
/// * `impl (SelfTy) via *mut dyn FnMut, ctor; A1, ..`: the closure is called
///   through `&mut dyn FnMut(A1, ..) -> R`.
/// * `impl (SelfTy) via *const dyn Fn, ctor; A1, ..`: the closure is called
///   through `&dyn Fn(A1, ..) -> R`.
///
/// `ctor` names the `RustCallback` constructor that wraps the erased closure.
/// When invoked, the wrapper decodes each `Ai` from the incoming
/// `DecodedData` in order, calls the closure, and encodes its return value.
///
/// The generated `encode`:
/// 1. calls `encoder.mark_needs_flush()`;
/// 2. erases the borrow's lifetime into a raw pointer;
/// 3. stores the wrapping `RustCallback` via `insert_callback!`;
/// 4. encodes the returned handle with the Rust-owned policy; and
/// 5. calls `crate::batch::drop_rust_object(handle)`.
///
/// The lifetime is erased with a pointer-to-pointer `transmute` that only
/// changes the trait-object lifetime bound. Do not round-trip the pointer
/// through integers: a pointer rebuilt from integers via `transmute` carries no
/// provenance, and dereferencing it is undefined behavior.
macro_rules! encode_callback_ref {
    (
        impl ($($self_ty:tt)*) via *mut dyn FnMut, $ctor:ident;
        $($arg:ident),*
    ) => {
        impl<R, $($arg,)*> BinaryEncode for $($self_ty)*
        where
            $($arg: BinaryDecode + EncodeTypeDef + 'static,)*
            R: BinaryEncode + EncodeTypeDef + 'static,
        {
            // The decoded argument locals reuse the `A1, A2, ..` type names.
            #[allow(non_snake_case)]
            fn encode(self, encoder: &mut EncodedData) {
                encoder.mark_needs_flush();

                let borrowed: *mut (dyn FnMut($($arg),*) -> R + '_) = self;
                // SAFETY: both pointer types differ only in the trait-object
                // lifetime bound, so they have identical layout and the
                // pointer (including its provenance) is preserved. Extending
                // the bound to `'static` is sound only if the stored callback
                // is never invoked after the borrow in `self` ends.
                // NOTE(review): this relies on the flush requested above and on
                // `drop_rust_object` below retiring the callback before the
                // caller's borrow expires; confirm against `crate::batch`.
                let ptr: *mut (dyn FnMut($($arg),*) -> R + 'static) =
                    unsafe { core::mem::transmute(borrowed) };

                let callback = RustCallback::$ctor(
                    move |_decoder: &mut DecodedData, encoder: &mut EncodedData| {
                        // SAFETY: `ptr` came from `self` above; see the note on
                        // its validity. When `$ctor` is `new_fn_mut` (as
                        // `impl_callback_ref!` uses it), the wrapper's `RefCell`
                        // makes a re-entrant call panic instead of creating a
                        // second live `&mut` to the closure.
                        let f = unsafe { &mut *ptr };
                        $(let $arg = <$arg as BinaryDecode>::decode(_decoder)?;)*
                        let result = f($($arg),*);
                        result.encode(encoder);
                        Ok(())
                    },
                );
                let handle = insert_callback!(callback);
                encode_rust_owned_callback(handle, encoder);
                crate::batch::drop_rust_object(handle);
            }
        }
    };
    (
        impl ($($self_ty:tt)*) via *const dyn Fn, $ctor:ident;
        $($arg:ident),*
    ) => {
        impl<R, $($arg,)*> BinaryEncode for $($self_ty)*
        where
            $($arg: BinaryDecode + EncodeTypeDef + 'static,)*
            R: BinaryEncode + EncodeTypeDef + 'static,
        {
            // The decoded argument locals reuse the `A1, A2, ..` type names.
            #[allow(non_snake_case)]
            fn encode(self, encoder: &mut EncodedData) {
                encoder.mark_needs_flush();

                let borrowed: *const (dyn Fn($($arg),*) -> R + '_) = self;
                // SAFETY: both pointer types differ only in the trait-object
                // lifetime bound, so they have identical layout and the
                // pointer (including its provenance) is preserved. Extending
                // the bound to `'static` is sound only if the stored callback
                // is never invoked after the borrow in `self` ends.
                // NOTE(review): this relies on the flush requested above and on
                // `drop_rust_object` below retiring the callback before the
                // caller's borrow expires; confirm against `crate::batch`.
                let ptr: *const (dyn Fn($($arg),*) -> R + 'static) =
                    unsafe { core::mem::transmute(borrowed) };

                let callback = RustCallback::$ctor(
                    move |_decoder: &mut DecodedData, encoder: &mut EncodedData| {
                        // SAFETY: `ptr` came from `self` above; see the note on
                        // its validity. Only shared references are created, so
                        // re-entrant calls do not alias a `&mut`.
                        let f = unsafe { &*ptr };
                        $(let $arg = <$arg as BinaryDecode>::decode(_decoder)?;)*
                        let result = f($($arg),*);
                        result.encode(encoder);
                        Ok(())
                    },
                );
                let handle = insert_callback!(callback);
                encode_rust_owned_callback(handle, encoder);
                crate::batch::drop_rust_object(handle);
            }
        }
    };
}
149
/// Implements `EncodeTypeDef` and `BinaryEncode` for the three borrowed
/// closure forms of one arity: `&mut dyn FnMut(..) -> R`, `&dyn Fn(..) -> R`
/// and `&mut dyn Fn(..) -> R`.
///
/// All three share the same type definition (produced by
/// `callback_type_def_body!`). For encoding, the `FnMut` form is wrapped with
/// `RustCallback::new_fn_mut`; both `Fn` forms use `RustCallback::new_fn`.
macro_rules! impl_callback_ref {
    ($($param:ident),*) => {
        // Type definitions.
        impl<R, $($param,)*> EncodeTypeDef for &mut dyn FnMut($($param),*) -> R
        where
            $($param: EncodeTypeDef + 'static,)*
            R: EncodeTypeDef + 'static,
        {
            fn encode_type_def(encoder: &mut TypeDef) {
                callback_type_def_body!(encoder; R = R; $($param),*);
            }
        }

        impl<R, $($param,)*> EncodeTypeDef for &dyn Fn($($param),*) -> R
        where
            $($param: EncodeTypeDef + 'static,)*
            R: EncodeTypeDef + 'static,
        {
            fn encode_type_def(encoder: &mut TypeDef) {
                callback_type_def_body!(encoder; R = R; $($param),*);
            }
        }

        impl<R, $($param,)*> EncodeTypeDef for &mut dyn Fn($($param),*) -> R
        where
            $($param: EncodeTypeDef + 'static,)*
            R: EncodeTypeDef + 'static,
        {
            fn encode_type_def(encoder: &mut TypeDef) {
                callback_type_def_body!(encoder; R = R; $($param),*);
            }
        }

        // Encoders.
        encode_callback_ref!(
            impl (&mut dyn FnMut($($param),*) -> R) via *mut dyn FnMut, new_fn_mut;
            $($param),*
        );

        encode_callback_ref!(
            impl (&dyn Fn($($param),*) -> R) via *const dyn Fn, new_fn;
            $($param),*
        );

        encode_callback_ref!(
            impl (&mut dyn Fn($($param),*) -> R) via *const dyn Fn, new_fn;
            $($param),*
        );
    };
}
198
199impl_callback_ref!();
200impl_callback_ref!(A1);
201impl_callback_ref!(A1, A2);
202impl_callback_ref!(A1, A2, A3);
203impl_callback_ref!(A1, A2, A3, A4);
204impl_callback_ref!(A1, A2, A3, A4, A5);
205impl_callback_ref!(A1, A2, A3, A4, A5, A6);
206impl_callback_ref!(A1, A2, A3, A4, A5, A6, A7);