1use crate::static_ref_function::StaticRefFunction;
2use core::marker::PhantomData;
3use core::ptr::NonNull;
4
5unsafe fn call_fn<'a, F, D, T>(data: NonNull<()>, arg: T) -> F::Output
6where
7 F: StaticRefFunction<'a, D, T> + ?Sized,
8 D: 'a,
9{
10 F::call(unsafe { data.cast().as_ref() }, arg)
11}
12
/// A borrowed, type-erased callable taking a `T` and returning an `R`.
///
/// A `RefFn` pairs a pointer to some borrowed data `&'a D` with a function
/// that knows how to interpret that data. The concrete `D` is erased, so
/// values built from different data types share the same `RefFn<'a, T, R>`
/// type. It is only two pointers wide and is `Copy`.
///
/// Because it stores a raw `NonNull` pointer, `RefFn` is neither `Send` nor
/// `Sync`.
pub struct RefFn<'a, T, R> {
    /// The borrowed data with its type erased; originally a `&'a D`.
    data: NonNull<()>,
    /// Function that casts `data` back to `&D` and invokes the bound function.
    call_fn: unsafe fn(NonNull<()>, T) -> R,
    /// Ties the erased borrow in `data` to `'a` (covariantly), so a `RefFn`
    /// cannot outlive the data it points to.
    _phantom: PhantomData<&'a ()>,
}

// Implemented by hand: a derive would require `T: Debug` and `R: Debug`, and
// the erased data itself cannot be formatted, so only the address is shown.
impl<T, R> core::fmt::Debug for RefFn<'_, T, R> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("RefFn")
            .field("data", &self.data)
            .finish_non_exhaustive()
    }
}
19
20impl<'a, T, R> RefFn<'a, T, R> {
21 pub fn new<F, D>(data: &'a D) -> Self
23 where
24 F: StaticRefFunction<'a, D, T, Output = R> + ?Sized,
25 {
26 Self {
27 data: NonNull::from(data).cast(),
28 call_fn: call_fn::<'a, F, D, T>,
29 _phantom: PhantomData,
30 }
31 }
32
33 pub fn from_fn<F>(f: &'a F) -> Self
35 where
36 F: Fn(T) -> R,
37 {
38 Self::new::<F, F>(f)
39 }
40
41 pub fn call(&self, arg: T) -> R {
43 unsafe { (self.call_fn)(self.data, arg) }
44 }
45}
46
// `Clone` and `Copy` are implemented by hand rather than derived: `#[derive]`
// would add `T: Clone`/`R: Clone` bounds, yet a `RefFn` only holds a
// `NonNull`, a function pointer and a `PhantomData`, all of which are `Copy`
// regardless of `T` and `R`.
impl<T, R> Clone for RefFn<'_, T, R> {
    fn clone(&self) -> Self {
        // Canonical clone for a `Copy` type (clippy `non_canonical_clone_impl`).
        *self
    }
}

// A copy shares the same borrowed data and the same function pointer.
impl<T, R> Copy for RefFn<'_, T, R> {}
54
55impl<'a, F, T, R> From<&'a F> for RefFn<'a, T, R>
56where
57 F: Fn(T) -> R,
58{
59 fn from(value: &'a F) -> Self {
60 Self::from_fn(value)
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::RefFn;
    use crate::static_ref_function::StaticRefFunction;

    static_assertions::assert_impl_all!(RefFn<'static, (), ()>: Clone, Copy, From<&'static fn(())>);
    static_assertions::assert_not_impl_any!(RefFn<'static, (), ()>: Send, Sync);

    /// `(argument, expected result)` pairs for a function that adds 2.
    const ADD_TWO_CASES: [(u32, u32); 3] = [(3, 5), (5, 7), (7, 9)];

    /// Asserts that `f` maps each argument in `cases` to its expected result.
    fn assert_calls(f: RefFn<'_, u32, u32>, cases: &[(u32, u32)]) {
        for &(arg, expected) in cases {
            assert_eq!(f.call(arg), expected, "unexpected result for argument {}", arg);
        }
    }

    #[test]
    fn test_ref_fn_new() {
        struct F;

        impl StaticRefFunction<'_, u32, u32> for F {
            type Output = u32;

            fn call(data: &u32, arg: u32) -> Self::Output {
                data + arg
            }
        }

        let data = 2;
        let f: RefFn<u32, u32> = F::bind(&data);

        assert_calls(f, &ADD_TWO_CASES);
    }

    #[test]
    fn test_ref_fn_from() {
        let data = 2_u32;
        let closure = |arg: u32| data + arg;
        let f: RefFn<u32, u32> = RefFn::from(&closure);

        assert_calls(f, &ADD_TWO_CASES);
    }

    #[test]
    fn test_ref_fn_from_fn() {
        let data = 2_u32;
        let closure = |arg: u32| data + arg;
        let f: RefFn<u32, u32> = RefFn::from_fn(&closure);

        assert_calls(f, &ADD_TWO_CASES);
    }

    #[test]
    fn test_ref_fn_from_fn_pointer() {
        let double: fn(u32) -> u32 = |arg| arg * 2;
        let f: RefFn<u32, u32> = RefFn::from(&double);

        assert_calls(f, &[(0, 0), (3, 6), (21, 42)]);
    }

    #[test]
    fn test_ref_fn_copies_share_binding() {
        let data = 2_u32;
        let closure = |arg: u32| data + arg;
        let f: RefFn<u32, u32> = RefFn::from_fn(&closure);
        // `RefFn` is `Copy`, so `f` remains usable after this.
        let g = f;

        assert_eq!(f.call(1), g.call(1));
        assert_calls(f, &ADD_TWO_CASES);
        assert_calls(g, &ADD_TWO_CASES);
    }
}
102}