// uefi_async/no_alloc/task.rs

//! This code is inspired by the approach in this embedded Rust crate: embassy-executor.
//!
//! Usage:
//! ```rust, no_run
//! #[doc(hidden)]
//! fn __async_fun() -> impl Future<Output = ()> { ( move || async move {})() }
//! fn async_fun() {
//!     const POOL_SIZE: usize = 4;
//!     static POOL: TaskPoolLayout<{ TaskCapture::<_, _>::size::<POOL_SIZE>(__async_fun) }> = unsafe {
//!         transmute(TaskCapture::<_,_>::new::<POOL_SIZE>(__async_fun))
//!     };
//!     const fn get<F, Args, Fut>(_: F) -> &'static TaskPool<Fut, POOL_SIZE>
//!     where F: TaskFn<Args, Fut = Fut>, Fut: SafeFuture {
//!         const {
//!             assert_eq!(size_of::<TaskPool<Fut, POOL_SIZE>>(), size_of_val(&POOL));
//!             assert!(align_of::<TaskPool<Fut, POOL_SIZE>>() <= 128);
//!         }
//!         unsafe { &*POOL.get().cast() }
//!     }
//!     get(__async_fun);
//! }
//! ```
23use core::any::type_name;
24use core::cell::UnsafeCell;
25use core::fmt::{Debug, Formatter, Result};
26use core::marker::PhantomData;
27use core::mem::MaybeUninit;
28use core::ops::Deref;
29use core::pin::Pin;
30use core::ptr::{drop_in_place, from_mut, from_ref, with_exposed_provenance_mut, write};
31use core::sync::atomic::{AtomicU8, Ordering};
32use core::task::{Context, Poll, Waker};
33use num_enum::{FromPrimitive, IntoPrimitive};
34use static_cell::StaticCell;
35
/// Marker for futures that may live in a static task slot: `'static`, `Send + Sync`,
/// and completing with `()`. Blanket-implemented for every qualifying future below.
pub trait SafeFuture: Future<Output = ()> + 'static + Send + Sync {}
impl<T: Future<Output = ()> + 'static + Send + Sync> SafeFuture for T {}
/// Maps a task entry function (any arity — see `task_fn_impl!`) to the future type it returns.
pub trait TaskFn<Args>: Copy { type Fut: SafeFuture; }
/// Type-erased poll entry stored in `TaskHeader::poll`; the bool presumably means
/// "task completed" (see the sketch in `TaskSlot::poll`) — confirm once implemented.
pub type TaskTypeFn = unsafe fn(*mut TaskHeader) -> bool;
// pub type TaskTypeFn = unsafe fn(*mut (), &Waker) -> Poll<()>;
/// Lifecycle state of a task slot, stored as a raw `u8` in `TaskHeader::state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, IntoPrimitive, FromPrimitive)] #[repr(u8)]
pub enum State {
    Free,           // empty slot
    Initialized,    // slot in use (claimed; future written by `TaskPool::init`)
    Ready,          // entered the task queue
    Running,        // running
    Yielded,        // yielded
    Waiting,        // waiting

    #[default]
    Unreachable,    // unreachable — `FromPrimitive` fallback for invalid bytes
}
/// Zero-sized helper that lets `const` code derive `TaskPool<Fut, N>`'s size and type
/// from a task-function *value* (see module docs and the impl below).
pub struct TaskCapture<F, Args>(PhantomData<(F, Args)>);
/// Type-erased task handle. `repr(C)` and placed first in `TaskSlot`, so a
/// `*mut TaskHeader` can be cast back to the concrete `TaskSlot<F>` (see `TaskSlot::poll`).
#[derive(Debug)] #[repr(C)]
pub struct TaskHeader {
    pub poll: TaskTypeFn,   // type-erased poll wrapper bound to the concrete F
    pub control: AtomicU8,  // executor control bits — semantics not defined in this file
    pub state: AtomicU8,    // a `State` value stored as its u8 representation
}
/// Future storage: field `.0` holds the exposed address of the initialized future
/// (0 == "not yet initialized"), field `.1` the backing storage (a `StaticCell` in practice).
#[derive(Debug)] #[repr(C)]
pub struct StaticFuture<F>(pub UnsafeCell<usize>, pub F);
/// One pool entry: header first (`repr(C)`) so header pointer == slot pointer.
/// 128-byte alignment — presumably cache-line padding to avoid false sharing; confirm.
#[repr(C, align(128))]
pub struct TaskSlot<F: SafeFuture> {
    pub header: TaskHeader,
    pub future: StaticFuture<StaticCell<F>>,
}
/// Fixed-capacity, statically allocated pool of `N` task slots.
#[repr(C, align(128))]
pub struct TaskPool<F: SafeFuture, const N: usize> (pub [TaskSlot<F>; N]);
/// Untyped byte image of a `TaskPool`. Alignment must match the pool's (128) so a
/// `static` of this type can be reinterpreted as the concrete pool (see module docs).
#[repr(C, align(128))]
pub struct TaskPoolLayout<const SIZE: usize> (pub UnsafeCell<MaybeUninit<[u8; SIZE]>>);
71
impl<F: SafeFuture> TaskSlot<F> {
    /// `const` prototype used to build `[TaskSlot<F>; N]` in `TaskPool::new`.
    pub const NEW: Self = Self::new();
    /// Builds an empty slot: state `Free`, poll bound to this `F`, future storage empty.
    const fn new() -> Self {
        Self {
            header: TaskHeader {
                poll: TaskSlot::<F>::poll,        // magic: automatically binding to SafeFuture
                control: AtomicU8::new(0),
                state: AtomicU8::new(State::Free as u8),
            },
            future: StaticFuture::new(),                // placeholder; the actual future is written later by `init`
        }
    }

    // Wrapper: casts the type-erased `*mut TaskHeader` back to `TaskSlot<F>` and runs it.
    //
    // # Safety
    // `ptr` must point at the `header` field of a live `TaskSlot<F>`; both types are
    // `repr(C)` with the header first, so the cast below lines up.
    //
    // NOTE(review): currently unimplemented (`todo!`); the commented sketch below shows
    // the intended shape — presumably `true` = future completed. Waker construction is
    // still missing (`Waker::from_raw()` takes a `RawWaker`); confirm before enabling.
    pub unsafe fn poll(ptr: *mut TaskHeader) -> bool {
        let slot = unsafe { &*ptr.cast::<TaskSlot<F>>() };

        // let waker = unsafe { Waker::from_raw() };
        // let mut ctx = Context::from_waker(&waker);
        //
        // // SAFETY: static future, no move
        // let future = unsafe { Pin::new_unchecked(slot.future.get_mut()) };
        //
        // match future.poll(&mut ctx) {
        //     Poll::Ready(_) => true,
        //     Poll::Pending => false,
        // }
        todo!("TaskSlot::poll")
    }
}
102impl<F: SafeFuture> StaticFuture<StaticCell<F>> {
103    #[inline(always)]
104    pub const fn new() -> Self { Self(UnsafeCell::new(0), StaticCell::new()) }
105    /// Lazy Initialization
106    /// SAFETY: data race!
107    #[inline(always)]
108    pub unsafe fn init(&'static self, future: impl FnOnce() -> F) {
109        let future_ptr = self.0.get();
110        let future_addr = unsafe { future_ptr.read() };
111        if future_addr == 0 {
112            // init_with return value maybe cause stack overflow
113            let uninit_ptr = self.1.uninit();
114            let new_ptr = uninit_ptr.as_mut_ptr();
115            unsafe {
116                write(new_ptr, future());
117                future_ptr.write(from_mut(uninit_ptr).addr());
118            }
119        } else {
120            let cell = with_exposed_provenance_mut(future_addr);
121            unsafe {
122                drop_in_place(cell);
123                write(cell, future());
124            }
125        }
126    }
127
128    #[inline(always)]
129    pub unsafe fn get_mut(&self) -> &mut F {
130        unsafe {
131            let addr = self.0.get().read();
132            debug_assert_ne!(addr, 0, "Future is not initialized");
133            with_exposed_provenance_mut::<F>(addr).as_mut().unwrap_unchecked()
134        }
135    }
136}
137impl<F: SafeFuture, const N: usize> TaskPool<F, N> {
138    #[inline(always)]
139    pub const fn new() -> Self { Self([TaskSlot::NEW; N]) }
140    #[inline(always)]
141    pub fn init(&'static self, future: impl FnOnce() -> F) -> *mut TaskHeader {
142        for slot in self.0.iter() {
143            if slot.header.state.compare_exchange(
144                State::Free.into(), State::Initialized.into(), Ordering::Acquire, Ordering::Relaxed
145            ).is_err() { continue }
146
147            // Only init Future => TaskSlot, not run F
148            // impl NRVO (Named Return Value Optimization), avoid stack overflow
149            unsafe { slot.future.init(future) }
150
151            return from_ref(&slot.header).cast_mut()
152        }
153
154        panic!("TaskPool capacity exceeded! No empty slots available.");
155    }
156}
157impl<const SIZE: usize> TaskPoolLayout<SIZE> {
158    #[inline(always)]
159    pub const fn get(&self) -> *const u8 { self.0.get().cast() }
160}
161impl<F, Args, Fut> TaskCapture<F, Args> where F: TaskFn<Args, Fut = Fut>, Fut: SafeFuture {
162    #[inline(always)]
163    pub const fn size<const POOL_SIZE: usize>(_: F) -> usize { size_of::<TaskPool<Fut, POOL_SIZE>>() }
164    #[inline(always)]
165    pub const fn new<const POOL_SIZE: usize>(_: F) -> TaskPool<Fut, POOL_SIZE> { TaskPool::new() }
166}
167
// Generates `TaskFn` impls for callables of arity 0..=16: each recursion step
// emits an impl for the current tuple of argument types, drops the head type,
// and recurses; the empty arm emits the 0-arg impl and ends the recursion.
macro_rules! task_fn_impl {
    () => {
        impl<F, Fut> TaskFn<()> for F where F: Copy + FnOnce() -> Fut, Fut: SafeFuture,
        { type Fut = Fut; }
    };
    ($head:ident $(, $tail:ident)*) => {
        impl<F, Fut, $head, $($tail,)*> TaskFn<($head, $($tail,)*)> for F
        where F: Copy + FnOnce($head, $($tail,)*) -> Fut, Fut: SafeFuture,
        { type Fut = Fut; }
        task_fn_impl!($($tail),*);
    };
}
task_fn_impl!(T15, T14, T13, T12, T11, T10, T9, T8, T7, T6, T5, T4, T3, T2, T1, T0);
181impl<F: SafeFuture> Debug for TaskSlot<F> {
182    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
183        let addr = format_args!("StaticCell<{}>@{:p}", type_name::<F>(), &self.future);
184        f.debug_struct("TaskSlot")
185            .field("header", &self.header)
186            .field("future", &addr)
187            .finish()
188    }
189}
190impl<F: SafeFuture, const N: usize> Debug for TaskPool<F, N> {
191    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
192        f.debug_struct("TaskPool")
193            .field("size", &N)
194            .field("task_type", &type_name::<F>())
195            .field("slots", &self.0)
196            .finish()
197    }
198}
// SAFETY(review): `TaskSlot` contains `UnsafeCell` + atomics; these impls claim
// cross-thread use is safe, presumably because slot access is serialized through
// the atomic `state` CAS (see `TaskPool::init`) — confirm once `poll` is implemented.
unsafe impl<F: SafeFuture> Sync for TaskSlot<F> {}
unsafe impl<F: SafeFuture> Send for TaskSlot<F> {}
// SAFETY(review): the layout is raw bytes, only ever touched through the typed
// pool obtained via `get().cast()`; soundness relies on the same serialization.
unsafe impl<const SIZE: usize> Send for TaskPoolLayout<SIZE> {}
unsafe impl<const SIZE: usize> Sync for TaskPoolLayout<SIZE> {}