1use std::fmt;
2use std::future::Future;
3use std::hash::Hash;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::Duration;
8
9use crate::cancel::CancelToken;
10use crate::clock::{Clock, ClockInstant, SystemClock};
11use crate::error::{Error, Result};
12use crate::key::{KeyData, PathKey, TaskId};
13use crate::observer::{TaskEvent, TaskEventKind, TaskEventOutcome, TaskObserver, TaskRunSource};
14use crate::overrides::{TaskOverride, TaskOverrides};
15use crate::store::{RunningGuard, Slot, SlotClaim, Store, StoredOutcome, decode_outcome};
16
17static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);
18
19type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
20type TaskFn<I, O, E> = dyn Fn(ExecCtx<E>, I) -> BoxFuture<Result<O, E>> + Send + Sync;
21type PreparedTaskFn<E> = dyn FnOnce(ExecCtx<E>) -> BoxFuture<Result<(), E>> + Send;
22type PreparedTaskResultFn<O> = dyn FnOnce(Arc<O>) + Send;
23
24pub struct Task<I, O, E = Box<dyn std::error::Error + Send + Sync>> {
26 inner: Arc<TaskInner<I, O, E>>,
27}
28
29impl<I, O, E> Task<I, O, E>
30where
31 I: Clone + Eq + Hash + Send + Sync + 'static,
32 O: Send + Sync + 'static,
33 E: Send + Sync + 'static,
34{
35 pub fn new<F, Fut>(cross_exec_ctx_cache_ttl: Duration, f: F) -> Self
41 where
42 F: Fn(ExecCtx<E>, I) -> Fut + Send + Sync + 'static,
43 Fut: Future<Output = Result<O, E>> + Send + 'static,
44 {
45 Self {
46 inner: Arc::new(TaskInner {
47 id: TaskId(NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed)),
48 cross_exec_ctx_cache_ttl,
49 name: std::any::type_name::<I>(),
50 f: Box::new(move |ctx, input| Box::pin(f(ctx, input))),
51 }),
52 }
53 }
54
55 pub fn id(&self) -> TaskId {
57 self.inner.id
58 }
59
60 pub async fn run(&self, ctx: &ExecCtx<E>, input: I) -> Result<Arc<O>, E> {
62 ctx.resolve(self.clone(), input).await
63 }
64
65 pub fn bind_input(&self, input: I) -> PreparedTask<E> {
67 self.bind_input_with_result(input, |_| {})
68 }
69
70 pub fn bind_input_with_result<F>(&self, input: I, result: F) -> PreparedTask<E>
72 where
73 F: FnOnce(Arc<O>) + Send + 'static,
74 {
75 let task = self.clone();
76 let mut result = Some(Box::new(result) as Box<PreparedTaskResultFn<O>>);
77 PreparedTask::new(move |ctx| async move {
78 let output = task.run(&ctx, input).await?;
79 if let Some(result) = result.take() {
80 result(output);
81 }
82 Ok(())
83 })
84 }
85
86 fn key(&self, input: &I) -> KeyData<I> {
87 KeyData::new(self.inner.id, input)
88 }
89
90 fn child_ctx(&self, ctx: &ExecCtx<E>, input: &I) -> ExecCtx<E> {
91 ctx.child_for_task(PathKey::new(self.inner.id, input))
92 }
93
94 fn shared_run_child(&self, ctx: &ExecCtx<E>, input: &I) -> ExecCtx<E> {
95 ctx.shared_run_child(PathKey::new(self.inner.id, input))
96 }
97
98 async fn call(&self, ctx: ExecCtx<E>, input: I) -> StoredOutcome<E> {
99 let task_override = ctx.tasks.task_override_for(self);
100 match task_override {
101 TaskOverride::RunOriginal => self.call_original(ctx, input).await,
102 TaskOverride::Replace(replacement) => replacement.call(ctx, input).await,
103 TaskOverride::Missing => StoredOutcome::Err(Error::MissingOverride {
104 task: self.inner.name,
105 }),
106 }
107 }
108
109 async fn call_original(&self, ctx: ExecCtx<E>, input: I) -> StoredOutcome<E> {
110 match (self.inner.f)(ctx, input).await {
111 Ok(output) => StoredOutcome::Ok(Arc::new(output)),
112 Err(error) => StoredOutcome::Err(error),
113 }
114 }
115}
116
117pub struct PreparedTask<E = Box<dyn std::error::Error + Send + Sync>> {
119 run: Box<PreparedTaskFn<E>>,
120}
121
122impl<E> PreparedTask<E>
123where
124 E: Send + Sync + 'static,
125{
126 fn new<F, Fut>(run: F) -> Self
127 where
128 F: FnOnce(ExecCtx<E>) -> Fut + Send + 'static,
129 Fut: Future<Output = Result<(), E>> + Send + 'static,
130 {
131 Self {
132 run: Box::new(move |ctx| Box::pin(run(ctx))),
133 }
134 }
135
136 async fn run(self, ctx: ExecCtx<E>) -> Result<(), E> {
137 (self.run)(ctx).await
138 }
139}
140
141impl<I, O, E> Clone for Task<I, O, E> {
142 fn clone(&self) -> Self {
143 Self {
144 inner: self.inner.clone(),
145 }
146 }
147}
148
149struct TaskInner<I, O, E> {
150 id: TaskId,
151 cross_exec_ctx_cache_ttl: Duration,
152 name: &'static str,
153 f: Box<TaskFn<I, O, E>>,
154}
155
/// Shared task registry. It holds the clock, the optional observer, the optional
/// overrides, and the cross-execution-context result store.
///
/// Clones share the same underlying state.
pub struct Tasks<E = Box<dyn std::error::Error + Send + Sync>> {
    inner: Arc<TasksInner<E>>,
}

/// Configuration consumed by `Tasks::new`.
pub struct TasksOptions<E = Box<dyn std::error::Error + Send + Sync>> {
    /// Time source for event timestamps, run durations and cache expiry.
    /// The `Default` impl uses `SystemClock`.
    pub clock: Arc<dyn Clock>,
    /// Receives a `TaskEvent` for each lifecycle step. `None` disables observation.
    pub observer: Option<Arc<dyn TaskObserver>>,
    /// Per-task overrides consulted before each run. `None` means every task runs
    /// its original function.
    pub overrides: Option<TaskOverrides<E>>,
}
170
171impl<E> Tasks<E>
172where
173 E: Send + Sync + 'static,
174{
175 pub fn new(options: TasksOptions<E>) -> Self {
177 Self {
178 inner: Arc::new(TasksInner {
179 clock: options.clock,
180 observer: options.observer,
181 shared: Store::new(),
182 overrides: options.overrides,
183 }),
184 }
185 }
186
187 fn task_override_for<I, O>(&self, task: &Task<I, O, E>) -> TaskOverride<I, O, E>
188 where
189 I: Clone + Eq + Hash + Send + Sync + 'static,
190 O: Send + Sync + 'static,
191 {
192 if let Some(overrides) = &self.inner.overrides {
193 overrides.resolve(task)
194 } else {
195 TaskOverride::RunOriginal
196 }
197 }
198
199 pub fn exec_ctx(&self, cancel: CancelToken) -> ExecCtx<E> {
201 ExecCtx {
202 tasks: self.clone(),
203 local: Arc::new(Store::new()),
204 cancel,
205 path: Arc::new(Vec::new()),
206 }
207 }
208
209 fn now(&self) -> ClockInstant {
210 self.inner.clock.now()
211 }
212
213 fn observe(&self, task_id: TaskId, task_input_type: &'static str, kind: TaskEventKind) {
214 if let Some(observer) = &self.inner.observer {
215 observer.observe(TaskEvent {
216 at: self.now(),
217 task_id,
218 task_input_type,
219 kind,
220 });
221 }
222 }
223}
224
225impl<E> Clone for Tasks<E> {
226 fn clone(&self) -> Self {
227 Self {
228 inner: self.inner.clone(),
229 }
230 }
231}
232
233impl<E> fmt::Debug for Tasks<E> {
234 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235 f.debug_struct("Tasks").finish_non_exhaustive()
236 }
237}
238
239struct TasksInner<E> {
240 clock: Arc<dyn Clock>,
241 observer: Option<Arc<dyn TaskObserver>>,
242 shared: Store<E>,
243 overrides: Option<TaskOverrides<E>>,
244}
245
246impl<E> Default for TasksOptions<E>
247where
248 E: 'static,
249{
250 fn default() -> Self {
251 Self {
252 clock: Arc::new(SystemClock::new()),
253 observer: None,
254 overrides: None,
255 }
256 }
257}
258
/// Execution context passed to every task body.
///
/// Clones and children created with `ExecCtx::child` share the same task registry
/// and per-context memo store.
pub struct ExecCtx<E = Box<dyn std::error::Error + Send + Sync>> {
    /// Registry providing the clock, observer, overrides and cross-context store.
    tasks: Tasks<E>,
    /// Per-context memo store, shared with clones and children of this context.
    local: Arc<Store<E>>,
    /// Cancellation token for work running in this context.
    cancel: CancelToken,
    /// Task/input keys currently being resolved along this call chain. Used for
    /// cycle detection.
    path: Arc<Vec<PathKey>>,
}
266
267impl<E> ExecCtx<E>
268where
269 E: Send + Sync + 'static,
270{
271 pub fn cancel_token(&self) -> &CancelToken {
273 &self.cancel
274 }
275
276 pub fn is_cancelled(&self) -> bool {
278 self.cancel.is_cancelled()
279 }
280
281 pub fn child(&self) -> Self {
283 Self {
284 tasks: self.tasks.clone(),
285 local: self.local.clone(),
286 cancel: self.cancel.child(),
287 path: self.path.clone(),
288 }
289 }
290
291 pub async fn run_parallel(
296 &self,
297 tasks: impl IntoIterator<Item = PreparedTask<E>>,
298 ) -> Result<(), E> {
299 if self.cancel.is_cancelled() {
300 return Err(Error::Cancelled);
301 }
302
303 let tasks = tasks.into_iter().collect::<Vec<_>>();
304 match tasks.len() {
305 0 => return Ok(()),
306 1 => {
307 let mut tasks = tasks;
308 let task = tasks.pop().expect("length checked");
309 return task.run(self.child()).await;
310 }
311 _ => {}
312 }
313
314 let shared = self.child();
315 let cancel = shared.cancel.clone();
316 let mut joined = tokio::task::JoinSet::new();
317 for task in tasks {
318 let ctx = shared.clone();
319 joined.spawn(async move { task.run(ctx).await });
320 }
321
322 let mut first_error = None;
323 while let Some(result) = joined.join_next().await {
324 match result {
325 Ok(Ok(())) => {}
326 Ok(Err(error)) => {
327 let replace_first_error =
328 first_error.as_ref().is_none_or(|first: &Error<E>| {
329 first.is_cancelled() && !error.is_cancelled()
330 });
331 if replace_first_error {
332 first_error = Some(error);
333 }
334 cancel.cancel();
335 }
336 Err(join_error) if join_error.is_panic() => {
337 std::panic::resume_unwind(join_error.into_panic());
338 }
339 Err(_) => {
340 if first_error.is_none() {
341 first_error = Some(Error::Cancelled);
342 }
343 cancel.cancel();
344 }
345 }
346 }
347
348 if let Some(error) = first_error {
349 return Err(error);
350 }
351 if self.cancel.is_cancelled() {
352 return Err(Error::Cancelled);
353 }
354 Ok(())
355 }
356
357 fn child_for_task(&self, key: PathKey) -> Self {
358 let mut path = Vec::with_capacity(self.path.len() + 1);
359 path.extend(self.path.iter().cloned());
360 path.push(key);
361 Self {
362 tasks: self.tasks.clone(),
363 local: self.local.clone(),
364 cancel: self.cancel.clone(),
365 path: Arc::new(path),
366 }
367 }
368
369 fn shared_run_child(&self, key: PathKey) -> Self {
370 let mut path = Vec::with_capacity(self.path.len() + 1);
371 path.extend(self.path.iter().cloned());
372 path.push(key);
373 Self {
374 tasks: self.tasks.clone(),
375 local: self.local.clone(),
376 cancel: CancelToken::new(),
377 path: Arc::new(path),
378 }
379 }
380
381 async fn resolve<I, O>(&self, task: Task<I, O, E>, input: I) -> Result<Arc<O>, E>
382 where
383 I: Clone + Eq + Hash + Send + Sync + 'static,
384 O: Send + Sync + 'static,
385 {
386 let key = task.key(&input);
387 if self.path.iter().any(|path_key| path_key.matches(&key)) {
388 return Err(Error::Cycle {
389 task: task.inner.name,
390 });
391 }
392 if self.cancel.is_cancelled() {
393 self.tasks
394 .observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
395 return Err(Error::Cancelled);
396 }
397
398 let local_slot = self.local.slot_for(&key, None, self.tasks.now()).slot;
399 let local_guard = match local_slot.claim() {
400 SlotClaim::Ready(outcome) => {
401 self.tasks.observe(
402 task.inner.id,
403 task.inner.name,
404 TaskEventKind::ExecCtxMemoHit,
405 );
406 if self.cancel.is_cancelled() {
407 self.tasks
408 .observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
409 return Err(Error::Cancelled);
410 }
411 return decode_outcome::<O, E>(outcome, task.inner.name);
412 }
413 SlotClaim::Wait => {
414 self.tasks.observe(
415 task.inner.id,
416 task.inner.name,
417 TaskEventKind::ExecCtxMemoWait,
418 );
419 return self
420 .wait_for_local::<O>(&local_slot, task.inner.id, task.inner.name)
421 .await;
422 }
423 SlotClaim::Run => {
424 self.tasks.observe(
425 task.inner.id,
426 task.inner.name,
427 TaskEventKind::ExecCtxMemoMiss,
428 );
429 RunningGuard::new(local_slot.clone())
430 }
431 };
432
433 let outcome = if !task.inner.cross_exec_ctx_cache_ttl.is_zero() {
434 match self
435 .resolve_shared(
436 task.clone(),
437 input,
438 &key,
439 task.inner.cross_exec_ctx_cache_ttl,
440 )
441 .await
442 {
443 Ok(outcome) => outcome,
444 Err(error) => return Err(error),
445 }
446 } else {
447 let child_ctx = task.child_ctx(self, &input);
448 let started_at = self.tasks.now();
449 self.tasks.observe(
450 task.inner.id,
451 task.inner.name,
452 TaskEventKind::RunStarted {
453 source: TaskRunSource::ExecCtx,
454 },
455 );
456 let outcome = tokio::select! {
457 result = task.call(child_ctx, input) => result,
458 _ = self.cancel.cancelled() => {
459 self.tasks.observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
460 return Err(Error::Cancelled);
461 },
462 };
463 self.tasks.observe(
464 task.inner.id,
465 task.inner.name,
466 TaskEventKind::RunCompleted {
467 source: TaskRunSource::ExecCtx,
468 outcome: task_event_outcome(&outcome),
469 duration: self.tasks.now().saturating_duration_since(started_at),
470 },
471 );
472 if self.cancel.is_cancelled() {
473 self.tasks
474 .observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
475 return Err(Error::Cancelled);
476 }
477 outcome
478 };
479
480 if outcome.is_cancelled() {
481 local_slot.abandon();
482 local_guard.disarm();
483 return decode_outcome::<O, E>(outcome, task.inner.name);
484 }
485
486 local_slot.finish(outcome.clone(), None);
487 local_guard.disarm();
488 decode_outcome::<O, E>(outcome, task.inner.name)
489 }
490
491 async fn wait_for_local<O>(
492 &self,
493 slot: &Arc<Slot<E>>,
494 task_id: TaskId,
495 task_name: &'static str,
496 ) -> Result<Arc<O>, E>
497 where
498 O: Send + Sync + 'static,
499 {
500 loop {
501 tokio::select! {
502 _ = slot.notify.notified() => {}
503 _ = self.cancel.cancelled() => {
504 self.tasks.observe(task_id, task_name, TaskEventKind::Cancelled);
505 return Err(Error::Cancelled);
506 },
507 }
508 match slot.claim() {
509 SlotClaim::Ready(outcome) => {
510 if self.cancel.is_cancelled() {
511 self.tasks
512 .observe(task_id, task_name, TaskEventKind::Cancelled);
513 return Err(Error::Cancelled);
514 }
515 return decode_outcome::<O, E>(outcome, task_name);
516 }
517 SlotClaim::Wait => {}
518 SlotClaim::Run => {
519 slot.abandon();
520 return Err(Error::Cancelled);
521 }
522 }
523 }
524 }
525
526 async fn resolve_shared<I, O>(
527 &self,
528 task: Task<I, O, E>,
529 input: I,
530 key: &KeyData<I>,
531 ttl: std::time::Duration,
532 ) -> Result<StoredOutcome<E>, E>
533 where
534 I: Clone + Eq + Hash + Send + Sync + 'static,
535 O: Send + Sync + 'static,
536 {
537 let shared_lookup = self
538 .tasks
539 .inner
540 .shared
541 .slot_for(key, Some(ttl), self.tasks.now());
542 let shared_slot = shared_lookup.slot;
543 if shared_lookup.stale_slots_removed > 0 {
544 self.tasks.observe(
545 task.inner.id,
546 task.inner.name,
547 TaskEventKind::CrossExecCtxStaleSlotRemoved {
548 count: shared_lookup.stale_slots_removed,
549 },
550 );
551 }
552
553 match shared_slot.claim() {
554 SlotClaim::Ready(outcome) => {
555 self.tasks.observe(
556 task.inner.id,
557 task.inner.name,
558 TaskEventKind::CrossExecCtxCacheHit,
559 );
560 if self.cancel.is_cancelled() {
561 self.tasks
562 .observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
563 return Err(Error::Cancelled);
564 }
565 return Ok(outcome);
566 }
567 SlotClaim::Wait => {
568 self.tasks.observe(
569 task.inner.id,
570 task.inner.name,
571 TaskEventKind::CrossExecCtxInFlightWait,
572 );
573 }
574 SlotClaim::Run => {
575 self.tasks.observe(
576 task.inner.id,
577 task.inner.name,
578 TaskEventKind::CrossExecCtxCacheMiss,
579 );
580 let guard = RunningGuard::new(shared_slot.clone());
581 let task_for_spawn = task.clone();
582 let input_for_spawn = input.clone();
583 let run_ctx = task.shared_run_child(self, &input_for_spawn);
584 let tasks = self.tasks.clone();
585 let task_id = task.inner.id;
586 let task_name = task.inner.name;
587 let shared_slot_for_run = shared_slot.clone();
588 tokio::spawn(async move {
589 let started_at = tasks.now();
590 tasks.observe(
591 task_id,
592 task_name,
593 TaskEventKind::RunStarted {
594 source: TaskRunSource::CrossExecCtx,
595 },
596 );
597 let outcome = task_for_spawn.call(run_ctx, input_for_spawn).await;
598 tasks.observe(
599 task_id,
600 task_name,
601 TaskEventKind::RunCompleted {
602 source: TaskRunSource::CrossExecCtx,
603 outcome: task_event_outcome(&outcome),
604 duration: tasks.now().saturating_duration_since(started_at),
605 },
606 );
607 if outcome.is_cancelled() {
608 shared_slot_for_run.abandon();
609 guard.disarm();
610 return;
611 }
612 let expires_at = if outcome.is_ok() {
613 Some(tasks.now().saturating_add_duration(ttl))
614 } else {
615 Some(tasks.now())
616 };
617 if outcome.is_ok() {
618 tasks.observe(task_id, task_name, TaskEventKind::CrossExecCtxCacheInserted);
619 }
620 shared_slot_for_run.finish(outcome, expires_at);
621 guard.disarm();
622 });
623 }
624 }
625
626 loop {
627 if self.cancel.is_cancelled() {
628 self.tasks
629 .observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
630 return Err(Error::Cancelled);
631 }
632 match shared_slot.claim() {
633 SlotClaim::Ready(outcome) => {
634 if self.cancel.is_cancelled() {
635 self.tasks.observe(
636 task.inner.id,
637 task.inner.name,
638 TaskEventKind::Cancelled,
639 );
640 return Err(Error::Cancelled);
641 }
642 return Ok(outcome);
643 }
644 SlotClaim::Run => {
645 shared_slot.abandon();
646 return Err(Error::Cancelled);
647 }
648 SlotClaim::Wait => {
649 tokio::select! {
650 _ = shared_slot.notify.notified() => {}
651 _ = self.cancel.cancelled() => {
652 self.tasks.observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
653 return Err(Error::Cancelled);
654 },
655 }
656 }
657 }
658 }
659 }
660}
661
662impl<E> Clone for ExecCtx<E> {
663 fn clone(&self) -> Self {
664 Self {
665 tasks: self.tasks.clone(),
666 local: self.local.clone(),
667 cancel: self.cancel.clone(),
668 path: self.path.clone(),
669 }
670 }
671}
672
673impl<E> fmt::Debug for ExecCtx<E> {
674 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
675 f.debug_struct("ExecCtx")
676 .field("cancelled", &self.cancel.is_cancelled())
677 .finish_non_exhaustive()
678 }
679}
680
681fn task_event_outcome<E>(outcome: &StoredOutcome<E>) -> TaskEventOutcome {
682 match outcome {
683 StoredOutcome::Ok(_) => TaskEventOutcome::Success,
684 StoredOutcome::Err(error) if error.is_cancelled() => TaskEventOutcome::Cancelled,
685 StoredOutcome::Err(_) => TaskEventOutcome::Error,
686 }
687}