1use futures_util::Future;
3
4use std::{
5 mem::MaybeUninit,
6 pin::Pin,
7 sync::{Arc, atomic::Ordering},
8 task::{Context, Poll, Waker},
9};
10
11use std::sync::atomic::AtomicPtr;
12#[cfg(feature = "runtoken-id")]
13use std::sync::atomic::AtomicU64;
14
15#[cfg(feature = "ordered-locks")]
16use ordered_locks::{L0, LockToken};
17
#[cfg(feature = "runtoken-id")]
/// Process-wide counter from which every `RunToken` constructor takes its id
/// (`fetch_add(1)`), so ids increase monotonically in creation order.
static IDC: AtomicU64 = AtomicU64::new(0);
21
/// A circular, doubly linked list whose nodes live outside the list (intrusive).
///
/// The list never allocates: it only links caller-provided [`ListNode`]s together
/// through raw pointers, so callers are responsible for keeping linked nodes alive
/// and in place.
pub struct IntrusiveList<T> {
    /// Head of the ring, or null when the list is empty.
    first: *mut ListNode<T>,
}
28
29impl<T> Default for IntrusiveList<T> {
30 fn default() -> Self {
31 Self {
32 first: std::ptr::null_mut(),
33 }
34 }
35}
36
37impl<T> IntrusiveList<T> {
38 unsafe fn push_back(&mut self, node: *mut ListNode<T>, v: T) {
48 let n = unsafe { &mut *node };
50 assert!(n.next.is_null());
51 n.data.write(v);
52 if self.first.is_null() {
53 n.next = n;
54 n.prev = n;
55 self.first = n;
56 } else {
57 let f = unsafe { &mut *self.first };
59 n.prev = f.prev;
60 n.next = self.first;
61 unsafe {
63 (*n.prev).next = node;
64 }
65 f.prev = node;
66 }
67 }
68
69 unsafe fn remove(&mut self, node: *mut ListNode<T>) -> T {
78 let n = unsafe { &mut *node };
80 assert!(!n.next.is_null());
81 let v = unsafe { n.data.as_mut_ptr().read() };
84 if n.next == node {
85 self.first = std::ptr::null_mut();
86 } else {
87 if self.first == node {
88 self.first = n.next;
89 }
90 unsafe {
92 (*n.next).prev = n.prev;
93 }
94 unsafe {
96 (*n.prev).next = n.next;
97 }
98 }
99 n.next = std::ptr::null_mut();
100 n.prev = std::ptr::null_mut();
101 v
102 }
103
104 fn drain(&mut self, v: impl Fn(T)) {
106 if self.first.is_null() {
107 return;
108 }
109 let mut cur = self.first;
110 loop {
111 let c = unsafe { &mut *cur };
114 let d = unsafe { c.data.as_mut_ptr().read() };
116 v(d);
117 let next = c.next;
118 c.next = std::ptr::null_mut();
119 c.prev = std::ptr::null_mut();
120 if next == self.first {
121 break;
122 }
123 cur = next;
124 }
125 self.first = std::ptr::null_mut();
126 }
127
128 unsafe fn in_list(&self, node: *mut ListNode<T>) -> bool {
134 unsafe { !(*node).next.is_null() }
136 }
137}
138
/// A node of an [`IntrusiveList`], embedded in (and owned by) the caller's own data.
///
/// Both link pointers are null while the node is not linked into a list. That is
/// how `IntrusiveList::in_list` tells linked from unlinked nodes.
pub struct ListNode<T> {
    /// Previous node in the ring (the last node when this is the head); null when unlinked.
    prev: *mut ListNode<T>,
    /// Next node in the ring (wraps around to the head); null when unlinked.
    next: *mut ListNode<T>,
    /// Payload slot: written when the node is linked, moved out when it is unlinked.
    data: std::mem::MaybeUninit<T>,
    /// Makes the node `!Unpin`, since linked neighbours hold raw pointers to its address.
    _pin: std::marker::PhantomPinned,
}
150
151impl<T> Default for ListNode<T> {
152 fn default() -> Self {
153 Self {
154 prev: std::ptr::null_mut(),
155 next: std::ptr::null_mut(),
156 data: MaybeUninit::uninit(),
157 _pin: Default::default(),
158 }
159 }
160}
161
/// Lifecycle state of a [`RunToken`].
///
/// `RunToken` only makes these transitions:
/// - `Run -> Pause` (`pause`)
/// - `Pause -> Run` (`resume`)
/// - any state `-> Cancel` (`cancel`)
///
/// Nothing ever leaves `Cancel`.
enum State {
    /// Running; the initial state of `RunToken::new`.
    Run,
    /// Cancelled; terminal.
    Cancel,
    #[cfg(feature = "pause")]
    /// Paused; pause waiters (async and sync) block while in this state.
    Pause,
}
172
/// Mutable state of a [`RunToken`]; only accessed through the mutex in `Inner`.
struct Content {
    /// Current lifecycle state of the token.
    state: State,
    /// Wakers of pending `WaitForCancellationFuture`s; all woken by `RunToken::cancel`.
    cancel_wakers: IntrusiveList<Waker>,
    /// Wakers of pending `WaitForPauseFuture`s; woken by `RunToken::resume` and `RunToken::cancel`.
    run_wakers: IntrusiveList<Waker>,
    #[cfg(feature = "runtoken-user-data")]
    /// Arbitrary caller-provided string (see `RunToken::set_user_data`).
    user_data: Option<String>,
}
185
// SAFETY: `Content` is only `!Send` because its intrusive lists hold raw pointers to
// `ListNode`s that live inside pending futures. In this file `Content` is only reached
// through the `Mutex` in `Inner`. The nodes' links and payloads are only read or written
// while that mutex is held (future `poll`/`Drop`, `RunToken::cancel`/`resume`), so
// moving `Content` across threads does not introduce unsynchronized access.
unsafe impl Send for Content {}
188
189impl Content {
190 unsafe fn add_cancel_waker(&mut self, node: *mut ListNode<Waker>, waker: &Waker) {
196 let in_list = unsafe { self.cancel_wakers.in_list(node) };
199 if !in_list {
200 unsafe { self.cancel_wakers.push_back(node, waker.clone()) }
204 }
205 }
206
207 #[cfg(feature = "pause")]
213 unsafe fn add_run_waker(&mut self, node: *mut ListNode<Waker>, waker: &Waker) {
214 let in_list = unsafe { self.run_wakers.in_list(node) };
217 if !in_list {
218 unsafe { self.run_wakers.push_back(node, waker.clone()) }
222 }
223 }
224
225 unsafe fn remove_cancel_waker(&mut self, node: *mut ListNode<Waker>) {
232 let in_list = unsafe { self.cancel_wakers.in_list(node) };
235 if in_list {
236 unsafe { self.cancel_wakers.remove(node) };
238 }
239 }
240
241 #[cfg(feature = "pause")]
248 unsafe fn remove_run_waker(&mut self, node: *mut ListNode<Waker>) {
249 let in_list = unsafe { self.run_wakers.in_list(node) };
252 if in_list {
253 unsafe { self.run_wakers.remove(node) };
255 }
256 }
257}
258
/// Shared state behind a [`RunToken`]; every clone of a token points at the same `Inner`.
struct Inner {
    /// `notify_all`-ed on cancel and resume so threads blocked in
    /// `wait_paused_check_cancelled_sync` re-check the state.
    cond: std::sync::Condvar,
    /// The mutable token state together with its async waiter lists.
    content: std::sync::Mutex<Content>,
    #[cfg(feature = "runtoken-id")]
    /// Id taken from the global `IDC` counter when the token was created.
    id: u64,
    /// Pointer to a `'static`, NUL-terminated `"file:line"` string stored by
    /// `RunToken::set_location_file_line`, or null if no location was recorded.
    location_file_line: AtomicPtr<u8>,
}
#[derive(Clone)]
/// Cloneable handle used to cancel work (and, with the `pause` feature, to pause and
/// resume it).
///
/// Clones share one `Arc<Inner>`, so cancelling (or pausing/resuming) through any
/// clone is observed by all of them.
pub struct RunToken(Arc<Inner>);
278
279impl RunToken {
280 #[cfg(feature = "pause")]
282 pub fn new_paused() -> Self {
283 Self(Arc::new(Inner {
284 cond: std::sync::Condvar::new(),
285 content: std::sync::Mutex::new(Content {
286 state: State::Pause,
287 cancel_wakers: Default::default(),
288 run_wakers: Default::default(),
289 #[cfg(feature = "runtoken-user-data")]
290 user_data: None,
291 }),
292 location_file_line: Default::default(),
293 #[cfg(feature = "runtoken-id")]
294 id: IDC.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
295 }))
296 }
297
298 pub fn new() -> Self {
300 Self(Arc::new(Inner {
301 cond: std::sync::Condvar::new(),
302 content: std::sync::Mutex::new(Content {
303 state: State::Run,
304 cancel_wakers: Default::default(),
305 run_wakers: Default::default(),
306 #[cfg(feature = "runtoken-user-data")]
307 user_data: None,
308 }),
309 location_file_line: Default::default(),
310 #[cfg(feature = "runtoken-id")]
311 id: IDC.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
312 }))
313 }
314
315 pub fn cancel(&self) {
317 let mut content = self.0.content.lock().unwrap();
318 if matches!(content.state, State::Cancel) {
319 return;
320 }
321 content.state = State::Cancel;
322
323 content.run_wakers.drain(|w| w.wake());
324 content.cancel_wakers.drain(|w| w.wake());
325 self.0.cond.notify_all();
326 }
327
328 #[cfg(feature = "pause")]
330 pub fn pause(&self) {
331 let mut content = self.0.content.lock().unwrap();
332 if !matches!(content.state, State::Run) {
333 return;
334 }
335 content.state = State::Pause;
336 }
337
338 #[cfg(feature = "pause")]
340 pub fn resume(&self) {
341 let mut content = self.0.content.lock().unwrap();
342 if !matches!(content.state, State::Pause) {
343 return;
344 }
345 content.state = State::Run;
346 content.run_wakers.drain(|w| w.wake());
347 self.0.cond.notify_all();
348 }
349
350 pub fn is_cancelled(&self) -> bool {
352 matches!(self.0.content.lock().unwrap().state, State::Cancel)
353 }
354
355 #[cfg(feature = "pause")]
357 pub fn is_paused(&self) -> bool {
358 matches!(self.0.content.lock().unwrap().state, State::Pause)
359 }
360
361 #[cfg(feature = "pause")]
363 pub fn is_running(&self) -> bool {
364 matches!(self.0.content.lock().unwrap().state, State::Run)
365 }
366
367 #[cfg(feature = "pause")]
369 pub fn wait_paused_check_cancelled_sync(&self) -> bool {
370 let mut content = self.0.content.lock().unwrap();
371 loop {
372 match &content.state {
373 State::Run => return false,
374 State::Cancel => return true,
375 State::Pause => {
376 content = self.0.cond.wait(content).unwrap();
377 }
378 }
379 }
380 }
381
382 #[cfg(feature = "pause")]
384 pub fn wait_paused_check_cancelled(&self) -> WaitForPauseFuture<'_> {
385 WaitForPauseFuture {
386 token: self,
387 waker: Default::default(),
388 }
389 }
390
391 pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
393 WaitForCancellationFuture {
394 token: self,
395 waker: Default::default(),
396 }
397 }
398
399 #[cfg(feature = "ordered-locks")]
402 pub fn cancelled_checked(
403 &self,
404 _lock_token: LockToken<'_, L0>,
405 ) -> WaitForCancellationFuture<'_> {
406 WaitForCancellationFuture {
407 token: self,
408 waker: Default::default(),
409 }
410 }
411
412 #[inline]
417 pub fn set_location_file_line(&self, file_line_str: &'static str) {
418 assert!(file_line_str.ends_with('\0'));
419 self.0
420 .location_file_line
421 .store(file_line_str.as_ptr() as *mut u8, Ordering::Relaxed);
422 }
423
424 pub fn location(&self) -> Option<(&'static str, u32)> {
426 let location_file_line = self.0.location_file_line.load(Ordering::Relaxed) as *const u8;
427 if location_file_line.is_null() {
428 return None;
429 }
430 let mut len = 0;
431 loop {
434 let l = unsafe { location_file_line.add(len) };
437 let c = unsafe { *l };
440 if c == b'\0' {
441 break;
442 }
443 len += 1;
444 }
445
446 let location_file_line = unsafe { std::slice::from_raw_parts(location_file_line, len) };
448
449 let location_file_line = unsafe { std::str::from_utf8_unchecked(location_file_line) };
451
452 match location_file_line.rsplit_once(":") {
453 Some((file, line)) => match line.parse() {
454 Ok(v) => Some((file, v)),
455 Err(_) => Some((location_file_line, 0)),
456 },
457 None => Some((location_file_line, 0)),
458 }
459 }
460
461 #[cfg(feature = "runtoken-id")]
462 #[inline]
464 pub fn id(&self) -> u64 {
465 self.0.id
466 }
467
468 #[cfg(feature = "runtoken-user-data")]
469 pub fn set_user_data(&self, data: Option<String>) {
471 self.0.content.lock().unwrap().user_data = data;
472 }
473
474 #[cfg(feature = "runtoken-user-data")]
475 pub fn user_data(&self) -> Option<String> {
477 self.0.content.lock().unwrap().user_data.clone()
478 }
479}
480
#[macro_export]
/// Records the invocation site on a run token.
///
/// Expands to `$token.set_location_file_line("<file>:<line>\0")`, where the file and
/// line are those of the macro call site.
macro_rules! set_location {
    ($token:expr) => {
        $token.set_location_file_line(::core::concat!(
            ::core::file!(),
            ":",
            ::core::line!(),
            "\0"
        ));
    };
}
488
489impl Default for RunToken {
490 fn default() -> Self {
491 Self::new()
492 }
493}
494
495impl core::fmt::Debug for RunToken {
496 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
497 let mut d = f.debug_tuple("RunToken");
498 match self.0.content.lock().unwrap().state {
499 State::Run => d.field(&"Running"),
500 State::Cancel => d.field(&"Canceled"),
501 #[cfg(feature = "pause")]
502 State::Pause => d.field(&"Paused"),
503 };
504 d.finish()
505 }
506}
507
#[must_use = "futures do nothing unless polled"]
/// Future returned by `RunToken::cancelled`; resolves once the token is cancelled.
///
/// While pending, its `waker` node is linked into the token's `cancel_wakers` list.
/// `ListNode` is `!Unpin`, so the future cannot be moved once pinned, and `Drop`
/// unlinks the node again.
pub struct WaitForCancellationFuture<'a> {
    /// Token being watched.
    token: &'a RunToken,
    /// Intrusive list node holding this future's waker while it is registered.
    waker: ListNode<Waker>,
}
519
520impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
521 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
522 f.debug_struct("WaitForCancellationFuture").finish()
523 }
524}
525
526impl<'a> Future for WaitForCancellationFuture<'a> {
527 type Output = ();
528
529 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
530 let mut content = self.token.0.content.lock().unwrap();
531 match content.state {
532 State::Cancel => Poll::Ready(()),
533 State::Run => {
534 let node = unsafe { &mut Pin::get_unchecked_mut(self).waker };
536 unsafe { content.add_cancel_waker(node, cx.waker()) };
539 Poll::Pending
540 }
541 #[cfg(feature = "pause")]
542 State::Pause => {
543 let node = unsafe { &mut Pin::get_unchecked_mut(self).waker };
545 unsafe { content.add_cancel_waker(node, cx.waker()) };
548 Poll::Pending
549 }
550 }
551 }
552}
553
554impl<'a> Drop for WaitForCancellationFuture<'a> {
555 fn drop(&mut self) {
556 unsafe {
558 self.token
559 .0
560 .content
561 .lock()
562 .unwrap()
563 .remove_cancel_waker(&mut self.waker);
564 }
565 }
566}
567
// SAFETY: the only non-`Send` part is the raw-pointer links inside `waker`. Those links
// and the stored waker are only touched while the token's content mutex is held (in
// `poll`, `Drop`, and the `RunToken` methods that drain the lists). `token` is a
// shared reference to a `RunToken`, whose shared state is synchronized by that mutex.
unsafe impl<'a> Send for WaitForCancellationFuture<'a> {}
570
#[cfg(feature = "pause")]
#[must_use = "futures do nothing unless polled"]
/// Future returned by `RunToken::wait_paused_check_cancelled`.
///
/// Resolves to `true` if the token is cancelled, or `false` once it is running. While
/// pending, its `waker` node is linked into the token's `run_wakers` list, and `Drop`
/// unlinks it again.
pub struct WaitForPauseFuture<'a> {
    /// Token being watched.
    token: &'a RunToken,
    /// Intrusive list node holding this future's waker while it is registered.
    waker: ListNode<Waker>,
}
583
#[cfg(feature = "pause")]
impl core::fmt::Debug for WaitForPauseFuture<'_> {
    /// Prints just the type name; the internal list node is not meaningful to show.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("WaitForPauseFuture")
    }
}
590
#[cfg(feature = "pause")]
impl Future for WaitForPauseFuture<'_> {
    type Output = bool;

    /// Resolves with `true` when cancelled or `false` when running. While paused, it
    /// registers the current waker in the token's run-waker list and stays pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> {
        let token = self.token;
        let mut content = token.0.content.lock().unwrap();
        let cancelled = match content.state {
            State::Pause => {
                // SAFETY: we never move out of the pinned future; we only take the
                // address of its `waker` node.
                let node: *mut ListNode<Waker> =
                    unsafe { &mut Pin::get_unchecked_mut(self).waker };
                // SAFETY: the node lives inside this pinned future and `Drop` unlinks it
                // under the same mutex.
                unsafe { content.add_run_waker(node, cx.waker()) };
                return Poll::Pending;
            }
            State::Cancel => true,
            State::Run => false,
        };
        Poll::Ready(cancelled)
    }
}
610
#[cfg(feature = "pause")]
impl Drop for WaitForPauseFuture<'_> {
    /// Unlinks this future's waker node from the token's run-waker list, if it is
    /// registered, so the list never points at freed memory.
    fn drop(&mut self) {
        let node: *mut ListNode<Waker> = &mut self.waker;
        let mut content = self.token.0.content.lock().unwrap();
        // SAFETY: `node` is this future's own node. It is either unlinked or linked into
        // `run_wakers`, and we hold the mutex that guards that list.
        unsafe { content.remove_run_waker(node) };
    }
}
625
#[cfg(feature = "pause")]
// SAFETY: as for `WaitForCancellationFuture`, the raw-pointer links and stored waker in
// `waker` are only touched while the token's content mutex is held (in `poll`, `Drop`,
// `RunToken::resume` and `RunToken::cancel`).
unsafe impl<'a> Send for WaitForPauseFuture<'a> {}