1use std::{
60 cell::UnsafeCell,
61 marker::PhantomData,
62 mem,
63 mem::{ManuallyDrop, MaybeUninit},
64 ops, ptr,
65 sync::{
66 atomic::{AtomicPtr, AtomicUsize, Ordering},
67 Arc,
68 },
69};
70
71use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard};
73use thiserror::Error;
74
75use crate::{
76 executor::{AsyncExecutor, Blocking, Executor, Nonblock},
77 prelude::*,
78 AsyncHandler,
79};
80
81pub trait SchedulerCore<J> {
83 fn create_node(&self, payload: J, dependencies: usize) -> NodeBuilder<J> {
89 self.create_node_or_run(payload, dependencies).unwrap()
90 }
91
92 fn create_node_or_run(&self, payload: J, dependencies: usize) -> Option<NodeBuilder<J>>;
98
99 fn push_with_dependents(&self, payload: J, dependents: OptRcDependents<J>);
106
107 fn push_dependency(
111 &self,
112 payload: J,
113 dependents: impl IntoIterator<Item = Edge<J>>,
114 ) -> Arc<Dependents<J>> {
115 let deps = Dependents::new(dependents.into_iter().collect());
116
117 self.push_with_dependents(payload, Some(Arc::clone(&deps)));
118
119 deps
120 }
121}
122
/// In-place storage for a node's payload.
///
/// `UnsafeCell` allows the payload to be moved out through a shared reference,
/// and `MaybeUninit` means it is never dropped implicitly: `Node`'s `Drop` impl
/// decides whether the value is still live. (`ManuallyDrop` is belt-and-braces
/// here, since `MaybeUninit` already never drops its contents.)
type NodePayload<T> = ManuallyDrop<UnsafeCell<MaybeUninit<T>>>;
125
126#[derive(Debug)]
127struct Node<J> {
128 payload: NodePayload<J>,
129 dependents: AtomicPtr<Dependents<J>>,
130 dependencies: AtomicUsize,
131}
132
133#[derive(Debug)]
137#[repr(transparent)]
138pub struct Edge<J> {
139 to: Arc<Node<J>>,
140}
141
142#[derive(Debug, Clone, Copy, Error)]
150#[error("The node associated with this builder can no longer be accessed")]
151pub struct NodeDispatched;
152
153#[derive(Debug)]
159pub struct NodeBuilder<J> {
160 node: Option<Arc<Node<J>>>,
161 remaining: usize,
162}
163
164#[derive(Debug)]
166pub struct Dependents<J>(RwLock<Option<Vec<Edge<J>>>>);
167
168#[derive(Debug)]
169enum AdoptState<J> {
170 Orphan(Vec<Edge<J>>),
171 Adopted(Arc<Dependents<J>>),
172 Abandoned,
173 Completed,
174 Poisoned,
175}
176
177#[derive(Debug, Clone, Copy, Error)]
180#[error("Adoptable dependents have already been adopted or abandoned")]
181pub struct BadAdoptState;
182
183#[derive(Debug)]
193pub struct AdoptableDependents<J>(AdoptState<J>);
194
195#[derive(Debug)]
198#[repr(transparent)]
199pub struct RcAdoptableDependents<J>(Arc<Mutex<AdoptableDependents<J>>>);
200
201type OptRcDependents<J> = Option<Arc<Dependents<J>>>;
202
203#[derive(Debug)]
205pub struct Job<J> {
206 payload: J,
207 dependents: OptRcDependents<J>,
208}
209
/// Scheduler-aware wrapper around an executor handle, passed to running jobs.
#[derive(Debug, Clone, Copy)]
pub struct Handle<H>(H);
213
/// Dependency-graph scheduler layered over an executor `E` that runs [`Job`]s.
#[derive(Debug)]
pub struct Scheduler<J, E> {
    executor: E,
    // `fn(J)` ties `J` to the type without storing one, so `J` does not
    // influence the scheduler's auto traits (e.g. Send/Sync).
    _m: PhantomData<fn(J)>,
}
220
221unsafe impl<J> Sync for Node<J> {}
224
225impl<J> Node<J> {
226 fn decrement<H: SchedulerCore<J>>(&self, handle: &H) {
227 match self.dependencies.fetch_sub(1, Ordering::SeqCst) {
228 1 => {
229 let job = {
230 let mut taken = MaybeUninit::zeroed();
231
232 unsafe {
233 ptr::swap(self.payload.get(), &mut taken);
234 taken.assume_init()
235 }
236 };
237
238 let dependents = {
239 let ptr = self.dependents.swap(ptr::null_mut(), Ordering::SeqCst);
240
241 if ptr.is_null() {
242 None
243 } else {
244 Some(unsafe { Arc::from_raw(ptr) })
245 }
246 };
247
248 handle.push_with_dependents(job, dependents);
249 },
250 0 | usize::MAX => unreachable!(),
251 _ => (),
252 }
253 }
254
255 fn set_dependents(&self, dependents: Arc<Dependents<J>>) -> Result<(), Arc<Dependents<J>>> {
257 let ptr = Arc::into_raw(dependents);
258
259 self.dependents
260 .compare_exchange(
261 ptr::null_mut(),
262 ptr.cast_mut(),
263 Ordering::SeqCst,
264 Ordering::Relaxed,
265 )
266 .map(|_| ())
267 .map_err(|_| unsafe { Arc::from_raw(ptr) })
268 }
269}
270
271impl<J> Drop for Node<J> {
272 fn drop(&mut self) {
273 match mem::replace(self.dependencies.get_mut(), 0) {
274 0 => (),
275 usize::MAX => unreachable!(),
276 _ => unsafe {
277 mem::drop(
278 ManuallyDrop::take(&mut self.payload)
279 .into_inner()
280 .assume_init(),
281 );
282 },
283 }
284 }
285}
286
287impl<J> Edge<J> {
288 fn new(to: Arc<Node<J>>) -> Self { Self { to } }
289}
290
291impl<J> NodeBuilder<J> {
292 fn create_or_run(payload: J, dependencies: usize, run: impl FnOnce(J)) -> Option<Self> {
293 match dependencies {
294 0 => {
295 run(payload);
296
297 None
298 },
299 usize::MAX => panic!("Invalid number of dependencies! (usize::MAX is reserved)"),
300 _ => {
301 let node = Arc::new(Node {
302 payload: ManuallyDrop::new(UnsafeCell::new(MaybeUninit::new(payload))),
303 dependents: AtomicPtr::new(ptr::null_mut()),
304 dependencies: AtomicUsize::new(dependencies),
305 });
306
307 Some(NodeBuilder {
308 node: Some(node),
309 remaining: dependencies,
310 })
311 },
312 }
313 }
314
315 #[inline]
320 pub fn get_in_edge(&mut self) -> Edge<J> { self.try_get_in_edge().unwrap() }
321
322 pub fn try_get_in_edge(&mut self) -> Option<Edge<J>> {
325 if (self.remaining == 0) != self.node.is_none() {
326 unreachable!();
327 }
328
329 let node = match self.remaining {
330 0 => None,
331 1 => {
332 self.remaining = 0;
333 self.node.take()
334 },
335 _ => {
336 self.remaining -= 1;
337 self.node.clone()
338 },
339 };
340
341 node.map(Edge::new)
342 }
343
344 pub fn set_dependents(
351 &mut self, dependents: Arc<Dependents<J>>,
353 ) -> Result<(), Arc<Dependents<J>>> {
354 let Some(node) = self.node.as_ref() else {
355 return Err(dependents);
356 };
357
358 debug_assert!(self.remaining > 0);
359 debug_assert!(node.dependencies.load(Ordering::SeqCst) >= self.remaining);
360
361 node.set_dependents(dependents)?;
362
363 Ok(())
364 }
365}
366
367impl<J> Drop for NodeBuilder<J> {
368 fn drop(&mut self) {
369 assert!(
370 self.remaining == 0 || self.node.is_none(),
371 "Failed to exhaust dependency bag!"
372 );
373 }
374}
375
376impl<J> Dependents<J> {
377 #[must_use]
379 pub fn new(dependents: Vec<Edge<J>>) -> Arc<Self> {
380 Arc::new(Self(RwLock::new(Some(dependents))))
381 }
382
383 pub fn push<H: SchedulerCore<J>>(&self, handle: &H, dependent: Edge<J>) {
388 let this = self.0.upgradable_read();
389
390 if this.is_some() {
391 let mut this = RwLockUpgradableReadGuard::upgrade(this);
392 let this = this.as_mut().unwrap_or_else(|| unreachable!());
393
394 this.push(dependent);
395 } else {
396 dependent.to.decrement(handle);
397 }
398 }
399}
400
401impl<J> From<Edge<J>> for Arc<Dependents<J>> {
402 #[inline]
403 fn from(edge: Edge<J>) -> Self { Dependents::new(vec![edge]) }
404}
405
406impl<J> std::iter::FromIterator<Edge<J>> for Arc<Dependents<J>> {
407 #[inline]
408 fn from_iter<I: IntoIterator<Item = Edge<J>>>(it: I) -> Self {
409 Dependents::new(it.into_iter().collect())
410 }
411}
412
413impl<J> AdoptableDependents<J> {
414 #[must_use]
420 pub fn new() -> Self { Self(AdoptState::Orphan(vec![])) }
421
422 #[must_use]
428 pub fn adopted(dependents: Arc<Dependents<J>>) -> Self { Self(AdoptState::Adopted(dependents)) }
429
430 #[must_use]
435 pub fn abandoned() -> Self { Self(AdoptState::Abandoned) }
436
437 #[must_use]
443 pub fn completed() -> Self { Self(AdoptState::Completed) }
444
445 #[inline]
448 #[must_use]
449 pub fn rc(self) -> RcAdoptableDependents<J> {
450 RcAdoptableDependents(Arc::new(Mutex::new(self)))
451 }
452
453 pub fn push<H: SchedulerCore<J>>(&mut self, handle: &H, dependent: Edge<J>) {
467 match self.0 {
468 AdoptState::Orphan(ref mut deps) => {
469 deps.push(dependent);
470 },
471 AdoptState::Adopted(ref dependents) => dependents.push(handle, dependent),
472 AdoptState::Abandoned => mem::drop(dependent),
473 AdoptState::Completed => dependent.to.decrement(handle),
474 AdoptState::Poisoned => panic!("AdoptableDependents was poisoned"),
475 }
476 }
477
478 pub fn adopt<H: SchedulerCore<J>>(
490 &mut self,
491 handle: &H,
492 dependents: Arc<Dependents<J>>,
493 ) -> Result<(), BadAdoptState> {
494 match self.0 {
495 AdoptState::Orphan(_) => (),
496 AdoptState::Adopted(_) | AdoptState::Abandoned | AdoptState::Completed => {
497 return Err(BadAdoptState);
498 },
499 AdoptState::Poisoned => panic!("AdoptableDependents was poisoned"),
500 }
501
502 if let AdoptState::Orphan(deps) = mem::replace(&mut self.0, AdoptState::Poisoned) {
503 for dep in deps {
504 dependents.push(handle, dep);
505 }
506
507 self.0 = AdoptState::Adopted(dependents);
508
509 Ok(())
510 } else {
511 unreachable!()
512 }
513 }
514
515 pub fn abandon(&mut self) -> Result<bool, BadAdoptState> {
528 match self.0 {
529 AdoptState::Orphan(_) => (),
530 AdoptState::Adopted(_) | AdoptState::Completed => return Err(BadAdoptState),
531 AdoptState::Abandoned => return Ok(false),
532 AdoptState::Poisoned => panic!("AdoptableDependencies was poisoned"),
533 }
534
535 if let AdoptState::Orphan(jobs) = mem::replace(&mut self.0, AdoptState::Abandoned) {
536 mem::drop(jobs);
537
538 Ok(true)
539 } else {
540 unreachable!();
541 }
542 }
543
544 pub fn complete<H: SchedulerCore<J>>(&mut self, handle: &H) -> Result<bool, BadAdoptState> {
558 match self.0 {
559 AdoptState::Orphan(_) => (),
560 AdoptState::Adopted(_) | AdoptState::Completed => return Ok(false),
561 AdoptState::Abandoned => return Err(BadAdoptState),
562 AdoptState::Poisoned => panic!("AdoptableDependents was poisoned"),
563 }
564
565 if let AdoptState::Orphan(edges) = mem::replace(&mut self.0, AdoptState::Completed) {
566 for edge in edges {
567 edge.to.decrement(handle);
568 }
569
570 Ok(true)
571 } else {
572 unreachable!();
573 }
574 }
575}
576
577impl<J> Default for AdoptableDependents<J> {
578 fn default() -> Self { Self::new() }
579}
580
581impl<J> ops::Deref for RcAdoptableDependents<J> {
582 type Target = Mutex<AdoptableDependents<J>>;
583
584 fn deref(&self) -> &Self::Target { self.0.as_ref() }
585}
586
587impl<J> Clone for RcAdoptableDependents<J> {
588 fn clone(&self) -> Self { Self(Arc::clone(&self.0)) }
589}
590
591impl<J> From<J> for Job<J> {
592 #[inline]
593 fn from(payload: J) -> Self {
594 Self {
595 payload,
596 dependents: None,
597 }
598 }
599}
600
601impl<J, H: ExecutorHandle<Job<J>>> SchedulerCore<J> for Handle<H> {
602 fn create_node_or_run(&self, payload: J, dependencies: usize) -> Option<NodeBuilder<J>> {
603 NodeBuilder::create_or_run(payload, dependencies, |j| self.0.push(j.into()))
604 }
605
606 #[inline]
607 fn push_with_dependents(&self, payload: J, dependents: OptRcDependents<J>) {
608 self.0.push(Job {
609 payload,
610 dependents,
611 });
612 }
613}
614
615impl<J, H: ExecutorHandle<Job<J>>> ExecutorHandle<J> for Handle<H> {
616 #[inline]
617 fn push(&self, job: J) { self.0.push(job.into()); }
618}
619
620fn process_result<J, H: ExecutorHandle<Job<J>> + Copy>(
622 res: Result<(), ()>,
623 handle: Handle<H>,
624 dependents: OptRcDependents<J>,
625) {
626 #[allow(clippy::single_match)]
627 match res {
628 Ok(()) => {
629 if let Some(dependents) = dependents {
630 for dep in mem::take(&mut *dependents.0.write()).into_iter().flatten() {
631 dep.to.decrement(&handle);
632 }
633 }
634 },
635 Err(()) => (),
636 }
637}
638
639impl<J, E: ExecutorCore<Job<J>>> Scheduler<J, E> {
640 fn new<
642 B: ExecutorBuilderSync<Job<J>, Executor = E>,
643 F: Fn(J, Handle<E::Handle<'_>>) -> Result<(), ()> + Clone + Send + 'static,
644 >(
645 b: B,
646 f: F,
647 ) -> Result<Self, B::Error> {
648 b.build(
649 move |Job {
650 payload,
651 dependents,
652 },
653 handle| {
654 let handle = Handle(handle);
655
656 let res = f(payload, handle);
657 process_result(res, handle, dependents);
658 },
659 )
660 .map(|executor| Self {
661 executor,
662 _m: PhantomData,
663 })
664 }
665}
666
667impl<J: Send, E: ExecutorCore<Job<J>>> Scheduler<J, E>
668where for<'a> E::Handle<'a>: Send
669{
670 fn new_async<
671 B: ExecutorBuilderAsync<Job<J>, Executor = E>,
672 F: for<'h> AsyncHandler<J, Handle<E::Handle<'h>>, Output = Result<(), ()>>
673 + Clone
674 + Send
675 + Sync
676 + 'static,
677 >(
678 b: B,
679 f: F,
680 ) -> Result<Self, B::Error> {
681 #[derive(Clone)]
682 struct Handler<F>(F);
683 impl<
684 J: Send,
685 H: ExecutorHandle<Job<J>> + Copy + Send,
686 F: AsyncHandler<J, Handle<H>, Output = Result<(), ()>> + Sync,
687 > AsyncHandler<Job<J>, H> for Handler<F>
688 {
689 type Output = ();
690
691 async fn handle(&self, job: Job<J>, handle: H) {
692 let Job {
693 payload,
694 dependents,
695 } = job;
696 let handle = Handle(handle);
697
698 let res = self.0.handle(payload, handle).await;
699 process_result(res, handle, dependents);
700 }
701 }
702
703 b.build_async(Handler(f)).map(|executor| Self {
704 executor,
705 _m: PhantomData,
706 })
707 }
708}
709
710impl<J, E> std::ops::Deref for Scheduler<J, E> {
711 type Target = E;
712
713 fn deref(&self) -> &E { &self.executor }
714}
715
716pub trait ExecutorBuilderExt<J>: Sized + ExecutorBuilderCore<Job<J>> {
719 fn build_graph<
725 F: Fn(J, Handle<<Self::Executor as ExecutorCore<Job<J>>>::Handle<'_>>) -> Result<(), ()>
726 + Clone
727 + Send
728 + 'static,
729 >(
730 self,
731 work: F,
732 ) -> Result<Scheduler<J, Self::Executor>, Self::Error>
733 where
734 Self: ExecutorBuilderSync<Job<J>>,
735 {
736 Scheduler::new(self, work)
737 }
738
739 fn build_graph_async<
745 F: for<'h> AsyncHandler<
746 J,
747 Handle<<Self::Executor as ExecutorCore<Job<J>>>::Handle<'h>>,
748 Output = Result<(), ()>,
749 > + Clone
750 + Send
751 + Sync
752 + 'static,
753 >(
754 self,
755 work: F,
756 ) -> Result<Scheduler<J, Self::Executor>, Self::Error>
757 where
758 J: Send,
759 Self: ExecutorBuilderAsync<Job<J>>,
760 for<'a> <Self::Executor as ExecutorCore<Job<J>>>::Handle<'a>: Send,
761 {
762 Scheduler::new_async(self, work)
763 }
764}
765
766impl<J, B: ExecutorBuilderCore<Job<J>> + Sized> ExecutorBuilderExt<J> for B {}
767
768impl<J, E: ExecutorCore<Job<J>>> ExecutorHandle<J> for Scheduler<J, E> {
769 #[inline]
770 fn push(&self, job: J) { self.executor.push(job.into()); }
771}
772
773impl<J, E: ExecutorCore<Job<J>>> ExecutorCore<J> for Scheduler<J, E> {
774 type Handle<'a> = Handle<E::Handle<'a>>;
775}
776
777impl<J: Send + 'static> Scheduler<J, Executor<Job<J>, Blocking>> {
778 #[inline]
781 pub fn join(self) { self.executor.join(); }
782
783 #[inline]
786 pub fn abort(self) { self.executor.abort(); }
787}
788
789impl<J: Send + 'static, E: AsyncExecutor> Scheduler<J, Executor<Job<J>, Nonblock<E>>> {
790 #[inline]
794 pub fn join_async(self) -> impl std::future::Future<Output = ()> + Send {
795 self.executor.join_async()
796 }
797
798 #[inline]
801 pub fn abort_async(self) -> impl std::future::Future<Output = ()> + Send {
802 self.executor.abort_async()
803 }
804}
805
806impl<J, E: ExecutorCore<Job<J>>> SchedulerCore<J> for Scheduler<J, E> {
807 fn create_node_or_run(&self, payload: J, dependencies: usize) -> Option<NodeBuilder<J>> {
808 NodeBuilder::create_or_run(payload, dependencies, |j| self.executor.push(j.into()))
809 }
810
811 #[inline]
812 fn push_with_dependents(&self, payload: J, dependents: OptRcDependents<J>) {
813 self.executor.push(Job {
814 payload,
815 dependents,
816 });
817 }
818}