1use futures_util::Future;
2
3use std::{
4    mem::MaybeUninit,
5    pin::Pin,
6    sync::{Arc, atomic::Ordering},
7    task::{Context, Poll, Waker},
8};
9
10use std::sync::atomic::{AtomicPtr, AtomicU64};
11
12#[cfg(feature = "ordered-locks")]
13use ordered_locks::{L0, LockToken};
14
/// Global counter handing out [`RunToken`] ids; every constructed token takes the next value.
#[cfg(feature = "runtoken-id")]
static IDC: AtomicU64 = AtomicU64::new(0);
17
/// A circular, doubly linked list whose nodes are owned by the caller (intrusive).
///
/// The list itself stores only a pointer to its first node; every [`ListNode`] carries its
/// own `prev`/`next` links and the stored value. A node whose `next` pointer is null is not
/// linked into any list. All methods are `unsafe` because they dereference the raw node
/// pointers they are given.
pub struct IntrusiveList<T> {
    /// First node of the ring, or null when the list is empty.
    first: *mut ListNode<T>,
}

impl<T> Default for IntrusiveList<T> {
    /// Creates an empty list.
    fn default() -> Self {
        Self {
            first: std::ptr::null_mut(),
        }
    }
}

impl<T> IntrusiveList<T> {
    /// Stores `v` in `node` and appends the node at the back of the list.
    ///
    /// # Panics
    /// Panics (before touching the node) if `node` is already linked.
    ///
    /// # Safety
    /// `node` must point to a valid `ListNode<T>` that keeps its address until it is
    /// unlinked again by [`remove`](Self::remove) or [`drain`](Self::drain).
    unsafe fn push_back(&mut self, node: *mut ListNode<T>, v: T) {
        unsafe {
            assert!((*node).next.is_null());
            (*node).data.write(v);
            if self.first.is_null() {
                // Single-node ring: the node links to itself in both directions.
                (*node).next = node;
                (*node).prev = node;
                self.first = node;
            } else {
                // Splice in between the current last node (`first.prev`) and `first`.
                (*node).prev = (*self.first).prev;
                (*node).next = self.first;
                (*(*node).prev).next = node;
                (*(*node).next).prev = node;
            }
        }
    }

    /// Unlinks `node` and returns the value that was stored in it.
    ///
    /// # Panics
    /// Panics if `node` is not linked.
    ///
    /// # Safety
    /// `node` must point to a valid node that is linked into *this* list.
    unsafe fn remove(&mut self, node: *mut ListNode<T>) -> T {
        unsafe {
            assert!(!(*node).next.is_null());
            let v = (*node).data.as_mut_ptr().read();
            if (*node).next == node {
                // It was the only node.
                self.first = std::ptr::null_mut();
            } else {
                if self.first == node {
                    self.first = (*node).next;
                }
                (*(*node).next).prev = (*node).prev;
                (*(*node).prev).next = (*node).next;
            }
            // Null links mark the node as "not in a list" (see `in_list`).
            (*node).next = std::ptr::null_mut();
            (*node).prev = std::ptr::null_mut();
            v
        }
    }

    /// Unlinks every node, front to back, handing each stored value to `v`.
    ///
    /// Each node is fully unlinked, and the list left consistent, *before* `v` sees its value.
    /// If `v` panics, the remaining nodes therefore stay correctly linked. No node is left
    /// linked after its value has been moved out.
    ///
    /// # Safety
    /// Every node currently linked into this list must still point to a valid `ListNode<T>`.
    unsafe fn drain(&mut self, v: impl Fn(T)) {
        while !self.first.is_null() {
            let node = self.first;
            // SAFETY: `node` is the non-null first node, hence linked into this list.
            let value = unsafe { self.remove(node) };
            v(value);
        }
    }

    /// Returns whether `node` is currently linked into a list.
    ///
    /// # Safety
    /// `node` must point to a valid `ListNode<T>`.
    unsafe fn in_list(&self, node: *mut ListNode<T>) -> bool {
        unsafe { !(*node).next.is_null() }
    }
}

/// Caller-owned node of an [`IntrusiveList`].
///
/// `next` is null exactly when the node is not linked. `data` holds an initialised value only
/// while the node is linked. The node is `!Unpin` (via `PhantomPinned`) because a list keeps
/// raw pointers to it.
pub struct ListNode<T> {
    prev: *mut ListNode<T>,
    next: *mut ListNode<T>,
    data: std::mem::MaybeUninit<T>,
    _pin: std::marker::PhantomPinned,
}

impl<T> Default for ListNode<T> {
    /// Creates an unlinked node with no stored value.
    fn default() -> Self {
        Self {
            prev: std::ptr::null_mut(),
            next: std::ptr::null_mut(),
            data: MaybeUninit::uninit(),
            _pin: Default::default(),
        }
    }
}
109enum State {
110    Run,
111    Cancel,
112    #[cfg(feature = "pause")]
113    Pause,
114}
115
116struct Content {
117    state: State,
118    cancel_wakers: IntrusiveList<Waker>,
119    run_wakers: IntrusiveList<Waker>,
120}
121
122unsafe impl Send for Content {}
123
124impl Content {
125    unsafe fn add_cancle_waker(&mut self, node: *mut ListNode<Waker>, waker: &Waker) {
126        unsafe {
127            if !self.cancel_wakers.in_list(node) {
128                self.cancel_wakers.push_back(node, waker.clone())
129            }
130        }
131    }
132
133    #[cfg(feature = "pause")]
134    unsafe fn add_run_waker(&mut self, node: *mut ListNode<Waker>, waker: &Waker) {
135        if !self.run_wakers.in_list(node) {
136            self.run_wakers.push_back(node, waker.clone())
137        }
138    }
139
140    unsafe fn remove_cancle_waker(&mut self, node: *mut ListNode<Waker>) {
141        unsafe {
142            if self.cancel_wakers.in_list(node) {
143                self.cancel_wakers.remove(node);
144            }
145        }
146    }
147
148    #[cfg(feature = "pause")]
149    unsafe fn remove_run_waker(&mut self, node: *mut ListNode<Waker>) {
150        if self.run_wakers.in_list(node) {
151            self.run_wakers.remove(node);
152        }
153    }
154}
155
156struct Inner {
157    cond: std::sync::Condvar,
158    content: std::sync::Mutex<Content>,
159    #[cfg(feature = "runtoken-id")]
160    id: u64,
161    location_file: AtomicPtr<u8>,
162    location_line: AtomicU64,
163}
164
/// Multiplies `x` and `y` as 128-bit integers and folds the 128-bit product back to 64 bits
/// by XOR-ing its low and high halves.
#[inline]
fn multiply_mix(x: u64, y: u64) -> u64 {
    let product = u128::from(x) * u128::from(y);
    let (high, low) = ((product >> 64) as u64, product as u64);
    low ^ high
}

/// Non-cryptographic 64-bit hash of `bytes`.
///
/// - Inputs of at most 16 bytes are XOR-ed directly into the two seed words.
/// - Longer inputs are consumed in 16-byte chunks, each mixed with `multiply_mix`. The
///   final, possibly overlapping, 16 bytes are XOR-ed in last.
///
/// The input length is XOR-ed into the result.
fn fxhash(bytes: &[u8]) -> u64 {
    // Little-endian reads at a byte offset; the callers below keep them in bounds.
    let read_u64 = |at: usize| u64::from_le_bytes(bytes[at..at + 8].try_into().unwrap());
    let read_u32 =
        |at: usize| u64::from(u32::from_le_bytes(bytes[at..at + 4].try_into().unwrap()));

    let len = bytes.len();
    let mut s0: u64 = 0x243f6a8885a308d3;
    let mut s1: u64 = 0x13198a2e03707344;

    if len > 16 {
        let tail = len - 16;
        let mut pos = 0;
        while pos < tail {
            let mixed = multiply_mix(s0 ^ read_u64(pos), 0xa4093822299f31d0 ^ read_u64(pos + 8));
            s0 = s1;
            s1 = mixed;
            pos += 16;
        }
        s0 ^= read_u64(tail);
        s1 ^= read_u64(tail + 8);
    } else if len >= 8 {
        s0 ^= read_u64(0);
        s1 ^= read_u64(len - 8);
    } else if len >= 4 {
        s0 ^= read_u32(0);
        s1 ^= read_u32(len - 4);
    } else if len > 0 {
        s0 ^= u64::from(bytes[0]);
        s1 ^= (u64::from(bytes[len - 1]) << 8) | u64::from(bytes[len / 2]);
    }

    multiply_mix(s0, s1) ^ len as u64
}
211
/// Low 24 bits of `Inner::location_line`: the recorded line number.
const LINE_MASK: u64 = (1u64 << 24) - 1;
/// Remaining high bits of `Inner::location_line`: a tag taken from the file-name hash.
const HASH_MASK: u64 = u64::MAX ^ LINE_MASK;
214#[derive(Clone)]
219pub struct RunToken(Arc<Inner>);
220
221impl RunToken {
222    #[cfg(feature = "pause")]
224    pub fn new_paused() -> Self {
225        Self(Arc::new(Inner {
226            cond: std::sync::Condvar::new(),
227            content: std::sync::Mutex::new(Content {
228                state: State::Pause,
229                cancel_wakers: Default::default(),
230                run_wakers: Default::default(),
231                location: None,
232            }),
233            #[cfg(feature = "runtoken-id")]
234            id: IDC.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
235        }))
236    }
237
238    pub fn new() -> Self {
240        Self(Arc::new(Inner {
241            cond: std::sync::Condvar::new(),
242            content: std::sync::Mutex::new(Content {
243                state: State::Run,
244                cancel_wakers: Default::default(),
245                run_wakers: Default::default(),
246            }),
247            #[cfg(feature = "runtoken-id")]
248            id: IDC.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
249            location_file: AtomicPtr::new(std::ptr::null_mut()),
250            location_line: AtomicU64::new(0),
251        }))
252    }
253
254    pub fn cancel(&self) {
256        let mut content = self.0.content.lock().unwrap();
257        if matches!(content.state, State::Cancel) {
258            return;
259        }
260        content.state = State::Cancel;
261
262        unsafe {
263            content.run_wakers.drain(|w| w.wake());
264            content.cancel_wakers.drain(|w| w.wake());
265        }
266        self.0.cond.notify_all();
267    }
268
269    #[cfg(feature = "pause")]
271    pub fn pause(&self) {
272        let mut content = self.0.content.lock().unwrap();
273        if !matches!(content.state, State::Run) {
274            return;
275        }
276        content.state = State::Pause;
277    }
278
279    #[cfg(feature = "pause")]
281    pub fn resume(&self) {
282        let mut content = self.0.content.lock().unwrap();
283        if !matches!(content.state, State::Pause) {
284            return;
285        }
286        content.state = State::Run;
287        unsafe {
288            content.run_wakers.drain(|w| w.wake());
289        }
290        self.0.cond.notify_all();
291    }
292
293    pub fn is_cancelled(&self) -> bool {
295        matches!(self.0.content.lock().unwrap().state, State::Cancel)
296    }
297
298    #[cfg(feature = "pause")]
300    pub fn is_paused(&self) -> bool {
301        matches!(self.0.content.lock().unwrap().state, State::Pause)
302    }
303
304    #[cfg(feature = "pause")]
306    pub fn is_running(&self) -> bool {
307        matches!(self.0.content.lock().unwrap().state, State::Run)
308    }
309
310    #[cfg(feature = "pause")]
312    pub fn wait_paused_check_cancelled_sync(&self) -> bool {
313        let mut content = self.0.content.lock().unwrap();
314        loop {
315            match &content.state {
316                State::Run => return false,
317                State::Cancel => return true,
318                State::Pause => {
319                    content = self.0.cond.wait(content).unwrap();
320                }
321            }
322        }
323    }
324
325    #[cfg(feature = "pause")]
327    pub fn wait_paused_check_cancelled(&self) -> WaitForPauseFuture<'_> {
328        WaitForPauseFuture {
329            token: self,
330            waker: Default::default(),
331        }
332    }
333
334    pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
336        WaitForCancellationFuture {
337            token: self,
338            waker: Default::default(),
339        }
340    }
341
342    #[cfg(feature = "ordered-locks")]
343    pub fn cancelled_checked(
344        &self,
345        _lock_token: LockToken<'_, L0>,
346    ) -> WaitForCancellationFuture<'_> {
347        WaitForCancellationFuture {
348            token: self,
349            waker: Default::default(),
350        }
351    }
352
353    #[inline]
355    pub fn set_location(&self, file: &'static str, line: u32) {
356        let rs_loc = file.find(".rs").expect(".rs in file name");
357        let file = &file[..rs_loc + 3];
358        assert!((line as u64) < LINE_MASK);
359        let hash = fxhash(file.as_bytes());
360        self.0
361            .location_file
362            .store(file.as_ptr() as *mut u8, Ordering::Relaxed);
363        self.0
364            .location_line
365            .store((hash & HASH_MASK) | line as u64, Ordering::Relaxed);
366    }
367
368    pub fn location(&self) -> Option<(&'static str, u32)> {
370        let mut cnt = 0;
371        loop {
372            let file = self.0.location_file.load(Ordering::Relaxed) as *const u8;
373            let line = self.0.location_line.load(Ordering::Relaxed);
374            if file.is_null() {
375                return None;
376            }
377            let mut len = 0;
378            let file = loop {
379                unsafe {
381                    if *file.add(len) == b'.'
382                        && *file.add(len + 1) == b'r'
383                        && *file.add(len + 2) == b's'
384                    {
385                        break std::str::from_utf8_unchecked(std::slice::from_raw_parts(
386                            file,
387                            len + 3,
388                        ));
389                    }
390                }
391                len += 1;
392            };
393
394            let hash = fxhash(file.as_bytes());
395            if (hash & HASH_MASK) == (line & HASH_MASK) {
396                return Some((file, (line & LINE_MASK) as u32));
397            }
398            if cnt == 0xFFFF {
399                return Some((file, 0));
400            }
401            cnt += 1;
402            std::hint::spin_loop();
403        }
404    }
405
406    #[cfg(feature = "runtoken-id")]
407    #[inline]
408    pub fn id(&self) -> u64 {
409        self.0.id
410    }
411}
412
413impl Default for RunToken {
414    fn default() -> Self {
415        Self::new()
416    }
417}
418
419impl core::fmt::Debug for RunToken {
420    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
421        let mut d = f.debug_tuple("RunToken");
422        match self.0.content.lock().unwrap().state {
423            State::Run => d.field(&"Running"),
424            State::Cancel => d.field(&"Canceled"),
425            #[cfg(feature = "pause")]
426            State::Pause => d.field(&"Paused"),
427        };
428        d.finish()
429    }
430}
431
432#[must_use = "futures do nothing unless polled"]
434pub struct WaitForCancellationFuture<'a> {
435    token: &'a RunToken,
436    waker: ListNode<Waker>,
437}
438
439impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
440    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
441        f.debug_struct("WaitForCancellationFuture").finish()
442    }
443}
444
445impl<'a> Future for WaitForCancellationFuture<'a> {
446    type Output = ();
447
448    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
449        let mut content = self.token.0.content.lock().unwrap();
450        match content.state {
451            State::Cancel => Poll::Ready(()),
452            State::Run => {
453                unsafe {
454                    content.add_cancle_waker(&mut Pin::get_unchecked_mut(self).waker, cx.waker());
455                }
456                Poll::Pending
457            }
458            #[cfg(feature = "pause")]
459            State::Pause => {
460                unsafe {
461                    content.add_cancle_waker(&mut Pin::get_unchecked_mut(self).waker, cx.waker());
462                }
463                Poll::Pending
464            }
465        }
466    }
467}
468
469impl<'a> Drop for WaitForCancellationFuture<'a> {
470    fn drop(&mut self) {
471        unsafe {
472            self.token
473                .0
474                .content
475                .lock()
476                .unwrap()
477                .remove_cancle_waker(&mut self.waker);
478        }
479    }
480}
481
482unsafe impl<'a> Send for WaitForCancellationFuture<'a> {}
483
/// Future returned by `RunToken::wait_paused_check_cancelled`.
///
/// Resolves as soon as the token is not paused: `true` if it is cancelled, `false` if it is
/// running. While pending it keeps a waker registered through the intrusive `waker` node,
/// which is unlinked again when the future drops.
#[cfg(feature = "pause")]
#[must_use = "futures do nothing unless polled"]
pub struct WaitForPauseFuture<'a> {
    token: &'a RunToken,
    waker: ListNode<Waker>,
}

#[cfg(feature = "pause")]
impl<'a> core::fmt::Debug for WaitForPauseFuture<'a> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("WaitForPauseFuture").finish()
    }
}

#[cfg(feature = "pause")]
impl<'a> Future for WaitForPauseFuture<'a> {
    type Output = bool;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> {
        let mut content = self.token.0.content.lock().unwrap();
        let cancelled = match content.state {
            State::Cancel => true,
            State::Run => false,
            State::Pause => {
                // SAFETY: the future is pinned, so `waker` keeps its address until `Drop`
                // unlinks it. We only take a pointer to the node and never move out of `self`.
                unsafe {
                    let this = Pin::get_unchecked_mut(self);
                    content.add_run_waker(&mut this.waker, cx.waker());
                }
                return Poll::Pending;
            }
        };
        Poll::Ready(cancelled)
    }
}

#[cfg(feature = "pause")]
impl<'a> Drop for WaitForPauseFuture<'a> {
    fn drop(&mut self) {
        let mut content = self.token.0.content.lock().unwrap();
        // SAFETY: `waker` is either unlinked or was linked into `run_wakers` by `poll`.
        // Unlinking it here, under the mutex, keeps the list from pointing at freed memory.
        unsafe {
            content.remove_run_waker(&mut self.waker);
        }
    }
}

// SAFETY: the only `!Send` part is the raw links inside `waker`. They are read and written
// only while the token's content mutex is held (`poll`, `Drop`, and the token's drains).
#[cfg(feature = "pause")]
unsafe impl<'a> Send for WaitForPauseFuture<'a> {}