1#![no_std]
2
3#[cfg(any(feature = "std", unix, windows))]
4#[macro_use]
5extern crate std;
6extern crate alloc;
7
8use alloc::boxed::Box;
9use anyhow::Error;
10use core::cell::Cell;
11use core::marker::PhantomData;
12use core::ops::Range;
13
// Select the platform-specific fiber backend and expose it as `imp`.
// Without the `std` feature the `nostd` backend is used regardless of target;
// otherwise Windows and Unix each get their own backend, and any other target
// is rejected at compile time.
cfg_if::cfg_if! {
    if #[cfg(not(feature = "std"))] {
        mod nostd;
        use nostd as imp;
    } else if #[cfg(windows)] {
        mod windows;
        use windows as imp;
    } else if #[cfg(unix)] {
        mod unix;
        use unix as imp;
    } else {
        compile_error!("fibers are not supported on this platform");
    }
}
28
// Stack-switching support, compiled for Unix targets and for builds without
// the `std` feature (i.e. whenever the `nostd` backend is selected). It is not
// compiled for Windows builds with `std` enabled.
#[cfg(any(unix, not(feature = "std")))]
pub(crate) mod stackswitch;
33
/// A stack that a [`Fiber`] executes on.
///
/// Thin wrapper around the stack type of the platform backend selected as `imp`.
pub struct FiberStack(imp::FiberStack);
36
37fn _assert_send_sync() {
38 fn _assert_send<T: Send>() {}
39 fn _assert_sync<T: Sync>() {}
40
41 _assert_send::<FiberStack>();
42 _assert_sync::<FiberStack>();
43}
44
/// Result type used by this crate's fallible operations.
///
/// The error type defaults to the platform backend's `imp::Error`.
pub type Result<T, E = imp::Error> = core::result::Result<T, E>;
46
47impl FiberStack {
48 pub fn new(size: usize, zeroed: bool) -> Result<Self> {
50 Ok(Self(imp::FiberStack::new(size, zeroed)?))
51 }
52
53 pub fn from_custom(custom: Box<dyn RuntimeFiberStack>) -> Result<Self> {
55 Ok(Self(imp::FiberStack::from_custom(custom)?))
56 }
57
58 pub unsafe fn from_raw_parts(bottom: *mut u8, guard_size: usize, len: usize) -> Result<Self> {
72 Ok(Self(unsafe {
73 imp::FiberStack::from_raw_parts(bottom, guard_size, len)?
74 }))
75 }
76
77 pub fn top(&self) -> Option<*mut u8> {
82 self.0.top()
83 }
84
85 pub fn range(&self) -> Option<Range<usize>> {
88 self.0.range()
89 }
90
91 pub fn is_from_raw_parts(&self) -> bool {
94 self.0.is_from_raw_parts()
95 }
96
97 pub fn guard_range(&self) -> Option<Range<*mut u8>> {
99 self.0.guard_range()
100 }
101}
102
/// Factory for user-provided fiber stacks.
///
/// # Safety
///
/// NOTE(review): this is an `unsafe trait`, so implementors must uphold
/// invariants the fiber runtime relies on for the stacks they return. The exact
/// contract is not visible here; confirm before implementing.
pub unsafe trait RuntimeFiberStackCreator: Send + Sync {
    /// Creates a new stack of `size` bytes; `zeroed` presumably requests
    /// zero-initialized memory (confirm). Failures are reported as `anyhow::Error`.
    fn new_stack(&self, size: usize, zeroed: bool) -> Result<Box<dyn RuntimeFiberStack>, Error>;
}
112
/// A user-provided fiber stack, accepted by `FiberStack::from_custom`.
///
/// # Safety
///
/// NOTE(review): this is an `unsafe trait`, so the values returned below must be
/// accurate for the runtime to use the stack soundly. The exact contract is not
/// visible here; confirm before implementing.
pub unsafe trait RuntimeFiberStack: Send + Sync {
    /// Returns the top of the stack.
    fn top(&self) -> *mut u8;
    /// Returns the range occupied by the stack (presumably as addresses — confirm).
    fn range(&self) -> Range<usize>;
    /// Returns the pointer range of the stack's guard region.
    fn guard_range(&self) -> Range<*mut u8>;
}
122
/// A function running on its own [`FiberStack`] that can suspend itself and be
/// resumed by its caller.
///
/// Each resume passes in a `Resume` value, each suspension produces a `Yield`
/// value, and the function's final result is a `Return` value.
pub struct Fiber<'a, Resume, Yield, Return> {
    // `Some` until `into_stack` takes the stack back out.
    stack: Option<FiberStack>,
    // Platform backend state for this fiber.
    inner: imp::Fiber,
    // Whether the fiber has finished; `resume` also sets it while running and
    // clears it again only when the fiber yields.
    done: Cell<bool>,
    // Binds the fiber to the lifetime `'a` of the closure it runs and records
    // the type parameters without storing values of them.
    _phantom: PhantomData<&'a (Resume, Yield, Return)>,
}
129
/// Handle given to a fiber's function, used to suspend the fiber and hand a
/// `Yield` value back to the caller of `Fiber::resume`.
pub struct Suspend<Resume, Yield, Return> {
    // Platform backend state used to switch back to the resumer.
    inner: imp::Suspend,
    // Records the type parameters without storing values of them.
    _phantom: PhantomData<(Resume, Yield, Return)>,
}
134
/// Message slot exchanged between `Fiber::resume` and the running fiber.
enum RunResult<Resume, Yield, Return> {
    /// NOTE(review): not constructed in this file; presumably the slot's state
    /// while the fiber is running — confirm in `imp`.
    Executing,
    /// Value passed into the fiber by `Fiber::resume`.
    Resuming(Resume),
    /// Value produced by the fiber when it suspends.
    Yield(Yield),
    /// Value produced by the fiber's function when it returns.
    Returned(Return),
    /// Panic payload caught from the fiber's function (`std` builds only).
    #[cfg(feature = "std")]
    Panicked(Box<dyn core::any::Any + Send>),
}
143
144impl<'a, Resume, Yield, Return> Fiber<'a, Resume, Yield, Return> {
145 pub fn new(
151 stack: FiberStack,
152 func: impl FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return + 'a,
153 ) -> Result<Self> {
154 let inner = imp::Fiber::new(&stack.0, func)?;
155
156 Ok(Self {
157 stack: Some(stack),
158 inner,
159 done: Cell::new(false),
160 _phantom: PhantomData,
161 })
162 }
163
164 pub fn resume(&self, val: Resume) -> Result<Return, Yield> {
180 assert!(!self.done.replace(true), "cannot resume a finished fiber");
181 let result = Cell::new(RunResult::Resuming(val));
182 self.inner.resume(&self.stack().0, &result);
183 match result.into_inner() {
184 RunResult::Resuming(_) | RunResult::Executing => unreachable!(),
185 RunResult::Yield(y) => {
186 self.done.set(false);
187 Err(y)
188 }
189 RunResult::Returned(r) => Ok(r),
190 #[cfg(feature = "std")]
191 RunResult::Panicked(_payload) => {
192 use std::panic;
193 panic::resume_unwind(_payload);
194 }
195 }
196 }
197
198 pub fn done(&self) -> bool {
200 self.done.get()
201 }
202
203 pub fn stack(&self) -> &FiberStack {
205 self.stack.as_ref().unwrap()
206 }
207
208 pub fn into_stack(mut self) -> FiberStack {
210 assert!(self.done());
211 self.stack.take().unwrap()
212 }
213}
214
215impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
216 pub fn suspend(&mut self, value: Yield) -> Resume {
226 self.inner
227 .switch::<Resume, Yield, Return>(RunResult::Yield(value))
228 }
229
230 fn execute(
231 inner: imp::Suspend,
232 initial: Resume,
233 func: impl FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return,
234 ) {
235 let mut suspend = Suspend {
236 inner,
237 _phantom: PhantomData,
238 };
239
240 #[cfg(feature = "std")]
241 {
242 use std::panic::{self, AssertUnwindSafe};
243 let result = panic::catch_unwind(AssertUnwindSafe(|| (func)(initial, &mut suspend)));
244 suspend.inner.switch::<Resume, Yield, Return>(match result {
245 Ok(result) => RunResult::Returned(result),
246 Err(panic) => RunResult::Panicked(panic),
247 });
248 }
249 #[cfg(not(feature = "std"))]
254 {
255 let result = (func)(initial, &mut suspend);
256 suspend
257 .inner
258 .switch::<Resume, Yield, Return>(RunResult::Returned(result));
259 }
260 }
261}
262
263impl<A, B, C> Drop for Fiber<'_, A, B, C> {
264 fn drop(&mut self) {
265 debug_assert!(self.done.get(), "fiber dropped without finishing");
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::{Fiber, FiberStack};
    use alloc::string::ToString;
    use std::cell::Cell;
    use std::rc::Rc;

    /// Allocates a non-zeroed stack of `size` bytes, panicking on failure.
    fn fiber_stack(size: usize) -> FiberStack {
        FiberStack::new(size, false).unwrap()
    }

    /// Requesting a 0- or 1-byte stack must still yield a fiber that can run a
    /// trivial function to completion.
    #[test]
    fn small_stacks() {
        Fiber::<(), (), ()>::new(fiber_stack(0), |_, _| {})
            .unwrap()
            .resume(())
            .unwrap();
        Fiber::<(), (), ()>::new(fiber_stack(1), |_, _| {})
            .unwrap()
            .resume(())
            .unwrap();
    }

    /// The fiber's function does not run until the first `resume`, and runs to
    /// completion during it.
    #[test]
    fn smoke() {
        let hit = Rc::new(Cell::new(false));
        let hit2 = hit.clone();
        let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |_, _| {
            hit2.set(true);
        })
        .unwrap();
        assert!(!hit.get());
        fiber.resume(()).unwrap();
        assert!(hit.get());
    }

    /// Each `suspend` surfaces as `Err` from `resume`, and execution continues
    /// from the suspension point on the next `resume`.
    #[test]
    fn suspend_and_resume() {
        let hit = Rc::new(Cell::new(false));
        let hit2 = hit.clone();
        let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |_, s| {
            s.suspend(());
            hit2.set(true);
            s.suspend(());
        })
        .unwrap();
        assert!(!hit.get());
        assert!(fiber.resume(()).is_err());
        assert!(!hit.get());
        assert!(fiber.resume(()).is_err());
        assert!(hit.get());
        assert!(fiber.resume(()).is_ok());
        assert!(hit.get());
    }

    /// Backtraces captured inside a fiber should include frames from the host
    /// stack that resumed it. The check is waived on the configurations listed
    /// in `assert_contains_host`.
    #[test]
    fn backtrace_traces_to_host() {
        // Kept out of line so it appears as its own frame in backtraces.
        #[inline(never)]
        fn look_for_me() {
            run_test();
        }
        fn assert_contains_host() {
            let trace = backtrace::Backtrace::new();
            println!("{trace:?}");
            assert!(
                trace
                    .frames()
                    .iter()
                    .flat_map(|f| f.symbols())
                    .filter_map(|s| Some(s.name()?.to_string()))
                    .any(|s| s.contains("look_for_me"))
                    // The host-frame check is waived on these configurations.
                    || cfg!(windows)
                    || cfg!(all(target_os = "macos", target_arch = "aarch64"))
                    || cfg!(target_arch = "arm")
                    || cfg!(asan)
            );
        }

        fn run_test() {
            let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |(), s| {
                assert_contains_host();
                s.suspend(());
                assert_contains_host();
                s.suspend(());
                assert_contains_host();
            })
            .unwrap();
            assert!(fiber.resume(()).is_err());
            assert!(fiber.resume(()).is_err());
            assert!(fiber.resume(()).is_ok());
        }

        look_for_me();
    }

    /// A panic inside the fiber propagates out of `resume`, and values captured
    /// by the fiber's closure are dropped.
    #[test]
    #[cfg(feature = "std")]
    fn panics_propagated() {
        use std::panic::{self, AssertUnwindSafe};

        let a = Rc::new(Cell::new(false));
        let b = SetOnDrop(a.clone());
        let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |(), _s| {
            let _ = &b;
            panic!();
        })
        .unwrap();
        assert!(panic::catch_unwind(AssertUnwindSafe(|| fiber.resume(()))).is_err());
        assert!(a.get());

        /// Sets the shared flag when dropped.
        struct SetOnDrop(Rc<Cell<bool>>);

        impl Drop for SetOnDrop {
            fn drop(&mut self) {
                self.0.set(true);
            }
        }
    }

    /// Resume, yield, and return values flow through the fiber with their
    /// concrete types.
    #[test]
    fn suspend_and_resume_values() {
        let fiber = Fiber::new(fiber_stack(1024 * 1024), move |first, s| {
            assert_eq!(first, 2.0);
            assert_eq!(s.suspend(4), 3.0);
            "hello".to_string()
        })
        .unwrap();
        assert_eq!(fiber.resume(2.0), Err(4));
        assert_eq!(fiber.resume(3.0), Ok("hello".to_string()));
    }
}