1use futures_util::Future;
2
3use std::{
4 mem::MaybeUninit,
5 pin::Pin,
6 sync::{Arc, atomic::Ordering},
7 task::{Context, Poll, Waker},
8};
9
10use std::sync::atomic::{AtomicPtr, AtomicU64};
11
12#[cfg(feature = "ordered-locks")]
13use ordered_locks::{L0, LockToken};
14
/// Global counter from which every newly created [`RunToken`] takes its id.
#[cfg(feature = "runtoken-id")]
static IDC: AtomicU64 = AtomicU64::new(0);
17
/// A circular, doubly linked intrusive list of caller-owned [`ListNode`]s.
///
/// The list stores only raw pointers: it never allocates, moves or frees nodes. `first` is null
/// when the list is empty and otherwise points at the head of the ring (the tail is
/// `first.prev`).
pub struct IntrusiveList<T> {
    first: *mut ListNode<T>,
}

impl<T> Default for IntrusiveList<T> {
    /// Creates an empty list.
    fn default() -> Self {
        Self {
            first: std::ptr::null_mut(),
        }
    }
}

impl<T> IntrusiveList<T> {
    /// Stores `v` in `node` and links `node` at the tail of the list.
    ///
    /// # Safety
    ///
    /// `node` must be valid for reads and writes and must stay at the same address until it is
    /// unlinked again by [`IntrusiveList::remove`] or [`IntrusiveList::drain`]. All nodes already
    /// linked into this list must be valid.
    ///
    /// # Panics
    ///
    /// Panics if `node` is already linked into a list.
    unsafe fn push_back(&mut self, node: *mut ListNode<T>, v: T) {
        unsafe {
            assert!((*node).next.is_null());
            (*node).data.write(v);
            if self.first.is_null() {
                // A lone node forms a ring with itself.
                (*node).next = node;
                (*node).prev = node;
                self.first = node;
            } else {
                // Splice in between the current tail (`first.prev`) and the head.
                (*node).prev = (*self.first).prev;
                (*node).next = self.first;
                (*(*node).prev).next = node;
                (*(*node).next).prev = node;
            }
        }
    }

    /// Unlinks `node` from the list and moves its value out.
    ///
    /// # Safety
    ///
    /// `node` must be valid and linked into *this* list, and its neighbours must be valid.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not linked.
    unsafe fn remove(&mut self, node: *mut ListNode<T>) -> T {
        unsafe {
            assert!(!(*node).next.is_null());
            let v = (*node).data.as_mut_ptr().read();
            if (*node).next == node {
                // `node` was the only element.
                self.first = std::ptr::null_mut();
            } else {
                if self.first == node {
                    self.first = (*node).next;
                }
                (*(*node).next).prev = (*node).prev;
                (*(*node).prev).next = (*node).next;
            }
            (*node).next = std::ptr::null_mut();
            (*node).prev = std::ptr::null_mut();
            v
        }
    }

    /// Unlinks every node in list order, passing each value to `f`, and leaves the list empty.
    ///
    /// Each node is fully unlinked before `f` receives its value. If `f` panics, the list
    /// therefore still holds exactly the not-yet-visited nodes in a consistent state, and no
    /// value can be moved out twice.
    ///
    /// # Safety
    ///
    /// Every node linked into the list must be valid.
    unsafe fn drain(&mut self, mut f: impl FnMut(T)) {
        while !self.first.is_null() {
            let head = self.first;
            // SAFETY: `head` is linked into this list and the caller guarantees that linked
            // nodes are valid.
            let v = unsafe { self.remove(head) };
            f(v);
        }
    }

    /// Returns whether `node` is currently linked into a list (this check cannot tell which).
    ///
    /// # Safety
    ///
    /// `node` must be valid for reads.
    unsafe fn in_list(&self, node: *mut ListNode<T>) -> bool {
        unsafe { !(*node).next.is_null() }
    }
}

/// A node embedded in caller-owned storage for use with [`IntrusiveList`].
///
/// `next` is null exactly while the node is unlinked, and `data` holds an initialized value only
/// while the node is linked. `PhantomPinned` makes the node `!Unpin`, because lists refer to it
/// by address.
pub struct ListNode<T> {
    prev: *mut ListNode<T>,
    next: *mut ListNode<T>,
    data: MaybeUninit<T>,
    _pin: std::marker::PhantomPinned,
}

impl<T> Default for ListNode<T> {
    /// Creates an unlinked node with uninitialized data.
    fn default() -> Self {
        Self {
            prev: std::ptr::null_mut(),
            next: std::ptr::null_mut(),
            data: MaybeUninit::uninit(),
            _pin: std::marker::PhantomPinned,
        }
    }
}
/// Lifecycle state of a [`RunToken`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// The token is running.
    Run,
    /// The token has been cancelled; nothing transitions out of this state.
    Cancel,
    /// The token is paused until it is resumed or cancelled.
    #[cfg(feature = "pause")]
    Pause,
}
115
116struct Content {
117 state: State,
118 cancel_wakers: IntrusiveList<Waker>,
119 run_wakers: IntrusiveList<Waker>,
120}
121
122unsafe impl Send for Content {}
123
124impl Content {
125 unsafe fn add_cancle_waker(&mut self, node: *mut ListNode<Waker>, waker: &Waker) {
126 unsafe {
127 if !self.cancel_wakers.in_list(node) {
128 self.cancel_wakers.push_back(node, waker.clone())
129 }
130 }
131 }
132
133 #[cfg(feature = "pause")]
134 unsafe fn add_run_waker(&mut self, node: *mut ListNode<Waker>, waker: &Waker) {
135 if !self.run_wakers.in_list(node) {
136 self.run_wakers.push_back(node, waker.clone())
137 }
138 }
139
140 unsafe fn remove_cancle_waker(&mut self, node: *mut ListNode<Waker>) {
141 unsafe {
142 if self.cancel_wakers.in_list(node) {
143 self.cancel_wakers.remove(node);
144 }
145 }
146 }
147
148 #[cfg(feature = "pause")]
149 unsafe fn remove_run_waker(&mut self, node: *mut ListNode<Waker>) {
150 if self.run_wakers.in_list(node) {
151 self.run_wakers.remove(node);
152 }
153 }
154}
155
156struct Inner {
157 cond: std::sync::Condvar,
158 content: std::sync::Mutex<Content>,
159 #[cfg(feature = "runtoken-id")]
160 id: u64,
161 location_file: AtomicPtr<u8>,
162 location_line: AtomicU64,
163}
164
/// Folds the full 128-bit product of `x` and `y` to 64 bits by XOR-ing its high and low halves.
#[inline]
fn multiply_mix(x: u64, y: u64) -> u64 {
    let product = u128::from(x) * u128::from(y);
    (product as u64) ^ ((product >> 64) as u64)
}

/// Hashes `bytes` into a 64-bit value.
///
/// Inputs of at most 16 bytes are folded into two seed words using (possibly overlapping) loads
/// from the front and back. Longer inputs are consumed in 16-byte chunks, and the final 16 bytes
/// (which may overlap the last chunk) are mixed in at the end. The input length is XOR-ed into
/// the result.
fn fxhash(bytes: &[u8]) -> u64 {
    const SEED0: u64 = 0x243f6a8885a308d3;
    const SEED1: u64 = 0x13198a2e03707344;
    const FOLD_KEY: u64 = 0xa4093822299f31d0;

    let read_u64 = |at: usize| u64::from_le_bytes(bytes[at..at + 8].try_into().unwrap());
    let read_u32 =
        |at: usize| u64::from(u32::from_le_bytes(bytes[at..at + 4].try_into().unwrap()));

    let n = bytes.len();
    let (mut a, mut b) = (SEED0, SEED1);
    match n {
        0 => {}
        1..=3 => {
            a ^= u64::from(bytes[0]);
            b ^= (u64::from(bytes[n - 1]) << 8) | u64::from(bytes[n / 2]);
        }
        4..=7 => {
            a ^= read_u32(0);
            b ^= read_u32(n - 4);
        }
        8..=16 => {
            a ^= read_u64(0);
            b ^= read_u64(n - 8);
        }
        _ => {
            let mut pos = 0;
            while pos < n - 16 {
                let mixed = multiply_mix(a ^ read_u64(pos), FOLD_KEY ^ read_u64(pos + 8));
                a = b;
                b = mixed;
                pos += 16;
            }
            a ^= read_u64(n - 16);
            b ^= read_u64(n - 8);
        }
    }
    multiply_mix(a, b) ^ (n as u64)
}
211
/// Low 24 bits of `location_line`, holding the recorded line number.
const LINE_MASK: u64 = (1 << 24) - 1;
/// Remaining high bits of `location_line`, holding a tag taken from the file-name hash.
const HASH_MASK: u64 = !LINE_MASK;
214#[derive(Clone)]
219pub struct RunToken(Arc<Inner>);
220
221impl RunToken {
222 #[cfg(feature = "pause")]
224 pub fn new_paused() -> Self {
225 Self(Arc::new(Inner {
226 cond: std::sync::Condvar::new(),
227 content: std::sync::Mutex::new(Content {
228 state: State::Pause,
229 cancel_wakers: Default::default(),
230 run_wakers: Default::default(),
231 location: None,
232 }),
233 #[cfg(feature = "runtoken-id")]
234 id: IDC.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
235 }))
236 }
237
238 pub fn new() -> Self {
240 Self(Arc::new(Inner {
241 cond: std::sync::Condvar::new(),
242 content: std::sync::Mutex::new(Content {
243 state: State::Run,
244 cancel_wakers: Default::default(),
245 run_wakers: Default::default(),
246 }),
247 #[cfg(feature = "runtoken-id")]
248 id: IDC.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
249 location_file: AtomicPtr::new(std::ptr::null_mut()),
250 location_line: AtomicU64::new(0),
251 }))
252 }
253
254 pub fn cancel(&self) {
256 let mut content = self.0.content.lock().unwrap();
257 if matches!(content.state, State::Cancel) {
258 return;
259 }
260 content.state = State::Cancel;
261
262 unsafe {
263 content.run_wakers.drain(|w| w.wake());
264 content.cancel_wakers.drain(|w| w.wake());
265 }
266 self.0.cond.notify_all();
267 }
268
269 #[cfg(feature = "pause")]
271 pub fn pause(&self) {
272 let mut content = self.0.content.lock().unwrap();
273 if !matches!(content.state, State::Run) {
274 return;
275 }
276 content.state = State::Pause;
277 }
278
279 #[cfg(feature = "pause")]
281 pub fn resume(&self) {
282 let mut content = self.0.content.lock().unwrap();
283 if !matches!(content.state, State::Pause) {
284 return;
285 }
286 content.state = State::Run;
287 unsafe {
288 content.run_wakers.drain(|w| w.wake());
289 }
290 self.0.cond.notify_all();
291 }
292
293 pub fn is_cancelled(&self) -> bool {
295 matches!(self.0.content.lock().unwrap().state, State::Cancel)
296 }
297
298 #[cfg(feature = "pause")]
300 pub fn is_paused(&self) -> bool {
301 matches!(self.0.content.lock().unwrap().state, State::Pause)
302 }
303
304 #[cfg(feature = "pause")]
306 pub fn is_running(&self) -> bool {
307 matches!(self.0.content.lock().unwrap().state, State::Run)
308 }
309
310 #[cfg(feature = "pause")]
312 pub fn wait_paused_check_cancelled_sync(&self) -> bool {
313 let mut content = self.0.content.lock().unwrap();
314 loop {
315 match &content.state {
316 State::Run => return false,
317 State::Cancel => return true,
318 State::Pause => {
319 content = self.0.cond.wait(content).unwrap();
320 }
321 }
322 }
323 }
324
325 #[cfg(feature = "pause")]
327 pub fn wait_paused_check_cancelled(&self) -> WaitForPauseFuture<'_> {
328 WaitForPauseFuture {
329 token: self,
330 waker: Default::default(),
331 }
332 }
333
334 pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
336 WaitForCancellationFuture {
337 token: self,
338 waker: Default::default(),
339 }
340 }
341
342 #[cfg(feature = "ordered-locks")]
343 pub fn cancelled_checked(
344 &self,
345 _lock_token: LockToken<'_, L0>,
346 ) -> WaitForCancellationFuture<'_> {
347 WaitForCancellationFuture {
348 token: self,
349 waker: Default::default(),
350 }
351 }
352
353 #[inline]
355 pub fn set_location(&self, file: &'static str, line: u32) {
356 let rs_loc = file.find(".rs").expect(".rs in file name");
357 let file = &file[..rs_loc + 3];
358 assert!((line as u64) < LINE_MASK);
359 let hash = fxhash(file.as_bytes());
360 self.0
361 .location_file
362 .store(file.as_ptr() as *mut u8, Ordering::Relaxed);
363 self.0
364 .location_line
365 .store((hash & HASH_MASK) | line as u64, Ordering::Relaxed);
366 }
367
368 pub fn location(&self) -> Option<(&'static str, u32)> {
370 let mut cnt = 0;
371 loop {
372 let file = self.0.location_file.load(Ordering::Relaxed) as *const u8;
373 let line = self.0.location_line.load(Ordering::Relaxed);
374 if file.is_null() {
375 return None;
376 }
377 let mut len = 0;
378 let file = loop {
379 unsafe {
381 if *file.add(len) == b'.'
382 && *file.add(len + 1) == b'r'
383 && *file.add(len + 2) == b's'
384 {
385 break std::str::from_utf8_unchecked(std::slice::from_raw_parts(
386 file,
387 len + 3,
388 ));
389 }
390 }
391 len += 1;
392 };
393
394 let hash = fxhash(file.as_bytes());
395 if (hash & HASH_MASK) == (line & HASH_MASK) {
396 return Some((file, (line & LINE_MASK) as u32));
397 }
398 if cnt == 0xFFFF {
399 return Some((file, 0));
400 }
401 cnt += 1;
402 std::hint::spin_loop();
403 }
404 }
405
406 #[cfg(feature = "runtoken-id")]
407 #[inline]
408 pub fn id(&self) -> u64 {
409 self.0.id
410 }
411}
412
413impl Default for RunToken {
414 fn default() -> Self {
415 Self::new()
416 }
417}
418
419impl core::fmt::Debug for RunToken {
420 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
421 let mut d = f.debug_tuple("RunToken");
422 match self.0.content.lock().unwrap().state {
423 State::Run => d.field(&"Running"),
424 State::Cancel => d.field(&"Canceled"),
425 #[cfg(feature = "pause")]
426 State::Pause => d.field(&"Paused"),
427 };
428 d.finish()
429 }
430}
431
432#[must_use = "futures do nothing unless polled"]
434pub struct WaitForCancellationFuture<'a> {
435 token: &'a RunToken,
436 waker: ListNode<Waker>,
437}
438
439impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
440 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
441 f.debug_struct("WaitForCancellationFuture").finish()
442 }
443}
444
445impl<'a> Future for WaitForCancellationFuture<'a> {
446 type Output = ();
447
448 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
449 let mut content = self.token.0.content.lock().unwrap();
450 match content.state {
451 State::Cancel => Poll::Ready(()),
452 State::Run => {
453 unsafe {
454 content.add_cancle_waker(&mut Pin::get_unchecked_mut(self).waker, cx.waker());
455 }
456 Poll::Pending
457 }
458 #[cfg(feature = "pause")]
459 State::Pause => {
460 unsafe {
461 content.add_cancle_waker(&mut Pin::get_unchecked_mut(self).waker, cx.waker());
462 }
463 Poll::Pending
464 }
465 }
466 }
467}
468
469impl<'a> Drop for WaitForCancellationFuture<'a> {
470 fn drop(&mut self) {
471 unsafe {
472 self.token
473 .0
474 .content
475 .lock()
476 .unwrap()
477 .remove_cancle_waker(&mut self.waker);
478 }
479 }
480}
481
482unsafe impl<'a> Send for WaitForCancellationFuture<'a> {}
483
/// Future returned by [`RunToken::wait_paused_check_cancelled`].
///
/// Resolves to `true` if the token is cancelled and `false` if it is running; stays pending
/// while the token is paused.
#[cfg(feature = "pause")]
#[must_use = "futures do nothing unless polled"]
pub struct WaitForPauseFuture<'a> {
    /// Token being watched.
    token: &'a RunToken,
    /// Intrusive node linked into the token's run-waker list while the future is pending.
    waker: ListNode<Waker>,
}
491
#[cfg(feature = "pause")]
impl<'a> core::fmt::Debug for WaitForPauseFuture<'a> {
    /// Prints only the type name; the token and the waker node are not shown.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        formatter.debug_struct("WaitForPauseFuture").finish()
    }
}
498
#[cfg(feature = "pause")]
impl<'a> Future for WaitForPauseFuture<'a> {
    type Output = bool;

    /// Resolves to `true` when cancelled or `false` when running. While paused, it registers the
    /// task's waker on the token's run list and returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> {
        let token = self.token;
        let mut content = token.0.content.lock().unwrap();
        let cancelled = match content.state {
            State::Cancel => true,
            State::Run => false,
            State::Pause => {
                // SAFETY: the node is never moved out of the pinned future; `Drop` unlinks it
                // under the same lock before the memory is released.
                let this = unsafe { self.get_unchecked_mut() };
                unsafe { content.add_run_waker(&mut this.waker, cx.waker()) };
                return Poll::Pending;
            }
        };
        Poll::Ready(cancelled)
    }
}
517
#[cfg(feature = "pause")]
impl<'a> Drop for WaitForPauseFuture<'a> {
    /// Unlinks this future's node from the token's run list, if it is still linked, so the list
    /// never keeps a dangling pointer.
    fn drop(&mut self) {
        let token = self.token;
        let mut content = token.0.content.lock().unwrap();
        // SAFETY: the node belongs to this future, is only ever linked into the run list, and the
        // token's mutex is held while it is unlinked.
        unsafe { content.remove_run_waker(&mut self.waker) };
    }
}

// SAFETY: the only non-`Send` parts are the raw pointers inside `waker`. Once the node is linked,
// they are only read or written while the token's mutex is held.
#[cfg(feature = "pause")]
unsafe impl<'a> Send for WaitForPauseFuture<'a> {}