Skip to main content

vorma_tasks/
task.rs

1use std::fmt;
2use std::future::Future;
3use std::hash::Hash;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::Duration;
8
9use crate::cancel::CancelToken;
10use crate::clock::{Clock, ClockInstant, SystemClock};
11use crate::error::{Error, Result};
12use crate::key::{KeyData, PathKey, TaskId};
13use crate::observer::{TaskEvent, TaskEventKind, TaskEventOutcome, TaskObserver, TaskRunSource};
14use crate::overrides::{TaskOverride, TaskOverrides};
15use crate::store::{RunningGuard, Slot, SlotClaim, Store, StoredOutcome, decode_outcome};
16
/// Process-wide counter used by `Task::new` to give each task definition a distinct
/// `TaskId`. The first id handed out is 1.
static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);

/// Heap-allocated, pinned, `Send + 'static` future yielding `T`.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// Type-erased task body: receives an execution context and an input, returns a boxed future.
type TaskFn<I, O, E> = dyn Fn(ExecCtx<E>, I) -> BoxFuture<Result<O, E>> + Send + Sync;
/// Type-erased one-shot runner held by a `PreparedTask`.
type PreparedTaskFn<E> = dyn FnOnce(ExecCtx<E>) -> BoxFuture<Result<(), E>> + Send;
/// Type-erased one-shot sink that receives a task's shared output.
type PreparedTaskResultFn<O> = dyn FnOnce(Arc<O>) + Send;
23
/// Stable typed unit of async work.
///
/// `I` is the input type, `O` the output type and `E` the error payload type, which
/// defaults to a boxed `std::error::Error`. The handle wraps its definition in an
/// [`Arc`], so clones are cheap and all share the same definition and id.
pub struct Task<I, O, E = Box<dyn std::error::Error + Send + Sync>> {
	// Shared definition: id, cache TTL, diagnostic name and the boxed body.
	inner: Arc<TaskInner<I, O, E>>,
}
28
29impl<I, O, E> Task<I, O, E>
30where
31	I: Clone + Eq + Hash + Send + Sync + 'static,
32	O: Send + Sync + 'static,
33	E: Send + Sync + 'static,
34{
35	/// Create a task.
36	///
37	/// `Duration::ZERO` disables cross-execution-context caching. A nonzero TTL retains
38	/// successful results across execution contexts created from the same [`Tasks`]
39	/// runtime.
40	pub fn new<F, Fut>(cross_exec_ctx_cache_ttl: Duration, f: F) -> Self
41	where
42		F: Fn(ExecCtx<E>, I) -> Fut + Send + Sync + 'static,
43		Fut: Future<Output = Result<O, E>> + Send + 'static,
44	{
45		Self {
46			inner: Arc::new(TaskInner {
47				id: TaskId(NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed)),
48				cross_exec_ctx_cache_ttl,
49				name: std::any::type_name::<I>(),
50				f: Box::new(move |ctx, input| Box::pin(f(ctx, input))),
51			}),
52		}
53	}
54
	/// Opaque identity for this task definition inside the current process.
	///
	/// Clones of a task report the same id because they share one definition.
	pub fn id(&self) -> TaskId {
		self.inner.id
	}
59
60	/// Run this task in an execution context.
61	pub async fn run(&self, ctx: &ExecCtx<E>, input: I) -> Result<Arc<O>, E> {
62		ctx.resolve(self.clone(), input).await
63	}
64
65	/// Bind input for later execution through [`ExecCtx::run_parallel`].
66	pub fn bind_input(&self, input: I) -> PreparedTask<E> {
67		self.bind_input_with_result(input, |_| {})
68	}
69
70	/// Bind input and a result sink for later execution through [`ExecCtx::run_parallel`].
71	pub fn bind_input_with_result<F>(&self, input: I, result: F) -> PreparedTask<E>
72	where
73		F: FnOnce(Arc<O>) + Send + 'static,
74	{
75		let task = self.clone();
76		let mut result = Some(Box::new(result) as Box<PreparedTaskResultFn<O>>);
77		PreparedTask::new(move |ctx| async move {
78			let output = task.run(&ctx, input).await?;
79			if let Some(result) = result.take() {
80				result(output);
81			}
82			Ok(())
83		})
84	}
85
86	fn key(&self, input: &I) -> KeyData<I> {
87		KeyData::new(self.inner.id, input)
88	}
89
90	fn child_ctx(&self, ctx: &ExecCtx<E>, input: &I) -> ExecCtx<E> {
91		ctx.child_for_task(PathKey::new(self.inner.id, input))
92	}
93
94	fn shared_run_child(&self, ctx: &ExecCtx<E>, input: &I) -> ExecCtx<E> {
95		ctx.shared_run_child(PathKey::new(self.inner.id, input))
96	}
97
98	async fn call(&self, ctx: ExecCtx<E>, input: I) -> StoredOutcome<E> {
99		let task_override = ctx.tasks.task_override_for(self);
100		match task_override {
101			TaskOverride::RunOriginal => self.call_original(ctx, input).await,
102			TaskOverride::Replace(replacement) => replacement.call(ctx, input).await,
103			TaskOverride::Missing => StoredOutcome::Err(Error::MissingOverride {
104				task: self.inner.name,
105			}),
106		}
107	}
108
109	async fn call_original(&self, ctx: ExecCtx<E>, input: I) -> StoredOutcome<E> {
110		match (self.inner.f)(ctx, input).await {
111			Ok(output) => StoredOutcome::Ok(Arc::new(output)),
112			Err(error) => StoredOutcome::Err(error),
113		}
114	}
115}
116
/// A task with input bound, ready for [`ExecCtx::run_parallel`].
///
/// The task's input and output types are erased; only the error payload type `E`
/// remains visible.
pub struct PreparedTask<E = Box<dyn std::error::Error + Send + Sync>> {
	// One-shot runner: given an execution context, produces the boxed future to await.
	run: Box<PreparedTaskFn<E>>,
}
121
122impl<E> PreparedTask<E>
123where
124	E: Send + Sync + 'static,
125{
126	fn new<F, Fut>(run: F) -> Self
127	where
128		F: FnOnce(ExecCtx<E>) -> Fut + Send + 'static,
129		Fut: Future<Output = Result<(), E>> + Send + 'static,
130	{
131		Self {
132			run: Box::new(move |ctx| Box::pin(run(ctx))),
133		}
134	}
135
136	async fn run(self, ctx: ExecCtx<E>) -> Result<(), E> {
137		(self.run)(ctx).await
138	}
139}
140
141impl<I, O, E> Clone for Task<I, O, E> {
142	fn clone(&self) -> Self {
143		Self {
144			inner: self.inner.clone(),
145		}
146	}
147}
148
/// Shared definition behind every clone of a `Task`.
struct TaskInner<I, O, E> {
	// Process-unique id drawn from `NEXT_TASK_ID` when the task is created.
	id: TaskId,
	// How long a successful result is kept across execution contexts; zero disables
	// the cross-context path entirely.
	cross_exec_ctx_cache_ttl: Duration,
	// Type name of `I`; reported in observer events and in errors.
	name: &'static str,
	// Type-erased task body.
	f: Box<TaskFn<I, O, E>>,
}
155
/// Owns shared cross-execution-context task state.
///
/// The runtime is a cheap handle around shared state: cloning it yields another
/// handle to the same clock, observer, overrides and shared store.
pub struct Tasks<E = Box<dyn std::error::Error + Send + Sync>> {
	// Shared runtime state.
	inner: Arc<TasksInner<E>>,
}
160
/// Configuration for a [`Tasks`] runtime.
///
/// The `Default` implementation uses a `SystemClock` with no observer and no
/// overrides.
pub struct TasksOptions<E = Box<dyn std::error::Error + Send + Sync>> {
	/// Monotonic clock used for cross-execution-context cache expiry and event timing.
	pub clock: Arc<dyn Clock>,
	/// Passive observer for task runtime events. `None` means events are not reported.
	pub observer: Option<Arc<dyn TaskObserver>>,
	/// Typed task-body substitutions for test and dry-run runtimes. `None` means every
	/// task runs its original body.
	pub overrides: Option<TaskOverrides<E>>,
}
170
171impl<E> Tasks<E>
172where
173	E: Send + Sync + 'static,
174{
175	/// Create a long-lived task runtime with explicit options.
176	pub fn new(options: TasksOptions<E>) -> Self {
177		Self {
178			inner: Arc::new(TasksInner {
179				clock: options.clock,
180				observer: options.observer,
181				shared: Store::new(),
182				overrides: options.overrides,
183			}),
184		}
185	}
186
187	fn task_override_for<I, O>(&self, task: &Task<I, O, E>) -> TaskOverride<I, O, E>
188	where
189		I: Clone + Eq + Hash + Send + Sync + 'static,
190		O: Send + Sync + 'static,
191	{
192		if let Some(overrides) = &self.inner.overrides {
193			overrides.resolve(task)
194		} else {
195			TaskOverride::RunOriginal
196		}
197	}
198
199	/// Create one execution context with its own memoization state.
200	pub fn exec_ctx(&self, cancel: CancelToken) -> ExecCtx<E> {
201		ExecCtx {
202			tasks: self.clone(),
203			local: Arc::new(Store::new()),
204			cancel,
205			path: Arc::new(Vec::new()),
206		}
207	}
208
209	fn now(&self) -> ClockInstant {
210		self.inner.clock.now()
211	}
212
213	fn observe(&self, task_id: TaskId, task_input_type: &'static str, kind: TaskEventKind) {
214		if let Some(observer) = &self.inner.observer {
215			observer.observe(TaskEvent {
216				at: self.now(),
217				task_id,
218				task_input_type,
219				kind,
220			});
221		}
222	}
223}
224
225impl<E> Clone for Tasks<E> {
226	fn clone(&self) -> Self {
227		Self {
228			inner: self.inner.clone(),
229		}
230	}
231}
232
233impl<E> fmt::Debug for Tasks<E> {
234	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235		f.debug_struct("Tasks").finish_non_exhaustive()
236	}
237}
238
/// State shared by every handle to one `Tasks` runtime.
struct TasksInner<E> {
	// Source of instants for event timestamps, run durations and cache expiry.
	clock: Arc<dyn Clock>,
	// Optional sink for task events.
	observer: Option<Arc<dyn TaskObserver>>,
	// Store used for results shared across execution contexts.
	shared: Store<E>,
	// Optional task-body substitutions.
	overrides: Option<TaskOverrides<E>>,
}
245
246impl<E> Default for TasksOptions<E>
247where
248	E: 'static,
249{
250	fn default() -> Self {
251		Self {
252			clock: Arc::new(SystemClock::new()),
253			observer: None,
254			overrides: None,
255		}
256	}
257}
258
/// One execution context with task memoization, cancellation, and dependency coalescing.
pub struct ExecCtx<E = Box<dyn std::error::Error + Send + Sync>> {
	// Runtime this context belongs to (clock, observer, overrides, shared store).
	tasks: Tasks<E>,
	// Memoization store for this context; contexts derived from this one hold the same store.
	local: Arc<Store<E>>,
	// Cancellation token consulted before, during and after task resolution.
	cancel: CancelToken,
	// Keys of the tasks currently being resolved on the way to this context; used to
	// detect cycles.
	path: Arc<Vec<PathKey>>,
}
266
267impl<E> ExecCtx<E>
268where
269	E: Send + Sync + 'static,
270{
271	/// Cancellation token associated with this execution context.
272	pub fn cancel_token(&self) -> &CancelToken {
273		&self.cancel
274	}
275
276	/// Whether this execution context has been cancelled.
277	pub fn is_cancelled(&self) -> bool {
278		self.cancel.is_cancelled()
279	}
280
281	/// Create a child execution context with shared memoization and child cancellation.
282	pub fn child(&self) -> Self {
283		Self {
284			tasks: self.tasks.clone(),
285			local: self.local.clone(),
286			cancel: self.cancel.child(),
287			path: self.path.clone(),
288		}
289	}
290
291	/// Run prepared tasks concurrently in a shared child execution context.
292	///
293	/// All prepared tasks share one execution context for memoization, and the first
294	/// failing sibling cancels the remaining siblings.
295	pub async fn run_parallel(
296		&self,
297		tasks: impl IntoIterator<Item = PreparedTask<E>>,
298	) -> Result<(), E> {
299		if self.cancel.is_cancelled() {
300			return Err(Error::Cancelled);
301		}
302
303		let tasks = tasks.into_iter().collect::<Vec<_>>();
304		match tasks.len() {
305			0 => return Ok(()),
306			1 => {
307				let mut tasks = tasks;
308				let task = tasks.pop().expect("length checked");
309				return task.run(self.child()).await;
310			}
311			_ => {}
312		}
313
314		let shared = self.child();
315		let cancel = shared.cancel.clone();
316		let mut joined = tokio::task::JoinSet::new();
317		for task in tasks {
318			let ctx = shared.clone();
319			joined.spawn(async move { task.run(ctx).await });
320		}
321
322		let mut first_error = None;
323		while let Some(result) = joined.join_next().await {
324			match result {
325				Ok(Ok(())) => {}
326				Ok(Err(error)) => {
327					let replace_first_error =
328						first_error.as_ref().is_none_or(|first: &Error<E>| {
329							first.is_cancelled() && !error.is_cancelled()
330						});
331					if replace_first_error {
332						first_error = Some(error);
333					}
334					cancel.cancel();
335				}
336				Err(join_error) if join_error.is_panic() => {
337					std::panic::resume_unwind(join_error.into_panic());
338				}
339				Err(_) => {
340					if first_error.is_none() {
341						first_error = Some(Error::Cancelled);
342					}
343					cancel.cancel();
344				}
345			}
346		}
347
348		if let Some(error) = first_error {
349			return Err(error);
350		}
351		if self.cancel.is_cancelled() {
352			return Err(Error::Cancelled);
353		}
354		Ok(())
355	}
356
357	fn child_for_task(&self, key: PathKey) -> Self {
358		let mut path = Vec::with_capacity(self.path.len() + 1);
359		path.extend(self.path.iter().cloned());
360		path.push(key);
361		Self {
362			tasks: self.tasks.clone(),
363			local: self.local.clone(),
364			cancel: self.cancel.clone(),
365			path: Arc::new(path),
366		}
367	}
368
369	fn shared_run_child(&self, key: PathKey) -> Self {
370		let mut path = Vec::with_capacity(self.path.len() + 1);
371		path.extend(self.path.iter().cloned());
372		path.push(key);
373		Self {
374			tasks: self.tasks.clone(),
375			local: self.local.clone(),
376			cancel: CancelToken::new(),
377			path: Arc::new(path),
378		}
379	}
380
381	async fn resolve<I, O>(&self, task: Task<I, O, E>, input: I) -> Result<Arc<O>, E>
382	where
383		I: Clone + Eq + Hash + Send + Sync + 'static,
384		O: Send + Sync + 'static,
385	{
386		let key = task.key(&input);
387		if self.path.iter().any(|path_key| path_key.matches(&key)) {
388			return Err(Error::Cycle {
389				task: task.inner.name,
390			});
391		}
392		if self.cancel.is_cancelled() {
393			self.tasks
394				.observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
395			return Err(Error::Cancelled);
396		}
397
398		let local_slot = self.local.slot_for(&key, None, self.tasks.now()).slot;
399		let local_guard = match local_slot.claim() {
400			SlotClaim::Ready(outcome) => {
401				self.tasks.observe(
402					task.inner.id,
403					task.inner.name,
404					TaskEventKind::ExecCtxMemoHit,
405				);
406				if self.cancel.is_cancelled() {
407					self.tasks
408						.observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
409					return Err(Error::Cancelled);
410				}
411				return decode_outcome::<O, E>(outcome, task.inner.name);
412			}
413			SlotClaim::Wait => {
414				self.tasks.observe(
415					task.inner.id,
416					task.inner.name,
417					TaskEventKind::ExecCtxMemoWait,
418				);
419				return self
420					.wait_for_local::<O>(&local_slot, task.inner.id, task.inner.name)
421					.await;
422			}
423			SlotClaim::Run => {
424				self.tasks.observe(
425					task.inner.id,
426					task.inner.name,
427					TaskEventKind::ExecCtxMemoMiss,
428				);
429				RunningGuard::new(local_slot.clone())
430			}
431		};
432
433		let outcome = if !task.inner.cross_exec_ctx_cache_ttl.is_zero() {
434			match self
435				.resolve_shared(
436					task.clone(),
437					input,
438					&key,
439					task.inner.cross_exec_ctx_cache_ttl,
440				)
441				.await
442			{
443				Ok(outcome) => outcome,
444				Err(error) => return Err(error),
445			}
446		} else {
447			let child_ctx = task.child_ctx(self, &input);
448			let started_at = self.tasks.now();
449			self.tasks.observe(
450				task.inner.id,
451				task.inner.name,
452				TaskEventKind::RunStarted {
453					source: TaskRunSource::ExecCtx,
454				},
455			);
456			let outcome = tokio::select! {
457				result = task.call(child_ctx, input) => result,
458				_ = self.cancel.cancelled() => {
459					self.tasks.observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
460					return Err(Error::Cancelled);
461				},
462			};
463			self.tasks.observe(
464				task.inner.id,
465				task.inner.name,
466				TaskEventKind::RunCompleted {
467					source: TaskRunSource::ExecCtx,
468					outcome: task_event_outcome(&outcome),
469					duration: self.tasks.now().saturating_duration_since(started_at),
470				},
471			);
472			if self.cancel.is_cancelled() {
473				self.tasks
474					.observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
475				return Err(Error::Cancelled);
476			}
477			outcome
478		};
479
480		if outcome.is_cancelled() {
481			local_slot.abandon();
482			local_guard.disarm();
483			return decode_outcome::<O, E>(outcome, task.inner.name);
484		}
485
486		local_slot.finish(outcome.clone(), None);
487		local_guard.disarm();
488		decode_outcome::<O, E>(outcome, task.inner.name)
489	}
490
491	async fn wait_for_local<O>(
492		&self,
493		slot: &Arc<Slot<E>>,
494		task_id: TaskId,
495		task_name: &'static str,
496	) -> Result<Arc<O>, E>
497	where
498		O: Send + Sync + 'static,
499	{
500		loop {
501			tokio::select! {
502				_ = slot.notify.notified() => {}
503				_ = self.cancel.cancelled() => {
504					self.tasks.observe(task_id, task_name, TaskEventKind::Cancelled);
505					return Err(Error::Cancelled);
506				},
507			}
508			match slot.claim() {
509				SlotClaim::Ready(outcome) => {
510					if self.cancel.is_cancelled() {
511						self.tasks
512							.observe(task_id, task_name, TaskEventKind::Cancelled);
513						return Err(Error::Cancelled);
514					}
515					return decode_outcome::<O, E>(outcome, task_name);
516				}
517				SlotClaim::Wait => {}
518				SlotClaim::Run => {
519					slot.abandon();
520					return Err(Error::Cancelled);
521				}
522			}
523		}
524	}
525
526	async fn resolve_shared<I, O>(
527		&self,
528		task: Task<I, O, E>,
529		input: I,
530		key: &KeyData<I>,
531		ttl: std::time::Duration,
532	) -> Result<StoredOutcome<E>, E>
533	where
534		I: Clone + Eq + Hash + Send + Sync + 'static,
535		O: Send + Sync + 'static,
536	{
537		let shared_lookup = self
538			.tasks
539			.inner
540			.shared
541			.slot_for(key, Some(ttl), self.tasks.now());
542		let shared_slot = shared_lookup.slot;
543		if shared_lookup.stale_slots_removed > 0 {
544			self.tasks.observe(
545				task.inner.id,
546				task.inner.name,
547				TaskEventKind::CrossExecCtxStaleSlotRemoved {
548					count: shared_lookup.stale_slots_removed,
549				},
550			);
551		}
552
553		match shared_slot.claim() {
554			SlotClaim::Ready(outcome) => {
555				self.tasks.observe(
556					task.inner.id,
557					task.inner.name,
558					TaskEventKind::CrossExecCtxCacheHit,
559				);
560				if self.cancel.is_cancelled() {
561					self.tasks
562						.observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
563					return Err(Error::Cancelled);
564				}
565				return Ok(outcome);
566			}
567			SlotClaim::Wait => {
568				self.tasks.observe(
569					task.inner.id,
570					task.inner.name,
571					TaskEventKind::CrossExecCtxInFlightWait,
572				);
573			}
574			SlotClaim::Run => {
575				self.tasks.observe(
576					task.inner.id,
577					task.inner.name,
578					TaskEventKind::CrossExecCtxCacheMiss,
579				);
580				let guard = RunningGuard::new(shared_slot.clone());
581				let task_for_spawn = task.clone();
582				let input_for_spawn = input.clone();
583				let run_ctx = task.shared_run_child(self, &input_for_spawn);
584				let tasks = self.tasks.clone();
585				let task_id = task.inner.id;
586				let task_name = task.inner.name;
587				let shared_slot_for_run = shared_slot.clone();
588				tokio::spawn(async move {
589					let started_at = tasks.now();
590					tasks.observe(
591						task_id,
592						task_name,
593						TaskEventKind::RunStarted {
594							source: TaskRunSource::CrossExecCtx,
595						},
596					);
597					let outcome = task_for_spawn.call(run_ctx, input_for_spawn).await;
598					tasks.observe(
599						task_id,
600						task_name,
601						TaskEventKind::RunCompleted {
602							source: TaskRunSource::CrossExecCtx,
603							outcome: task_event_outcome(&outcome),
604							duration: tasks.now().saturating_duration_since(started_at),
605						},
606					);
607					if outcome.is_cancelled() {
608						shared_slot_for_run.abandon();
609						guard.disarm();
610						return;
611					}
612					let expires_at = if outcome.is_ok() {
613						Some(tasks.now().saturating_add_duration(ttl))
614					} else {
615						Some(tasks.now())
616					};
617					if outcome.is_ok() {
618						tasks.observe(task_id, task_name, TaskEventKind::CrossExecCtxCacheInserted);
619					}
620					shared_slot_for_run.finish(outcome, expires_at);
621					guard.disarm();
622				});
623			}
624		}
625
626		loop {
627			if self.cancel.is_cancelled() {
628				self.tasks
629					.observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
630				return Err(Error::Cancelled);
631			}
632			match shared_slot.claim() {
633				SlotClaim::Ready(outcome) => {
634					if self.cancel.is_cancelled() {
635						self.tasks.observe(
636							task.inner.id,
637							task.inner.name,
638							TaskEventKind::Cancelled,
639						);
640						return Err(Error::Cancelled);
641					}
642					return Ok(outcome);
643				}
644				SlotClaim::Run => {
645					shared_slot.abandon();
646					return Err(Error::Cancelled);
647				}
648				SlotClaim::Wait => {
649					tokio::select! {
650						_ = shared_slot.notify.notified() => {}
651						_ = self.cancel.cancelled() => {
652							self.tasks.observe(task.inner.id, task.inner.name, TaskEventKind::Cancelled);
653							return Err(Error::Cancelled);
654						},
655					}
656				}
657			}
658		}
659	}
660}
661
662impl<E> Clone for ExecCtx<E> {
663	fn clone(&self) -> Self {
664		Self {
665			tasks: self.tasks.clone(),
666			local: self.local.clone(),
667			cancel: self.cancel.clone(),
668			path: self.path.clone(),
669		}
670	}
671}
672
673impl<E> fmt::Debug for ExecCtx<E> {
674	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
675		f.debug_struct("ExecCtx")
676			.field("cancelled", &self.cancel.is_cancelled())
677			.finish_non_exhaustive()
678	}
679}
680
681fn task_event_outcome<E>(outcome: &StoredOutcome<E>) -> TaskEventOutcome {
682	match outcome {
683		StoredOutcome::Ok(_) => TaskEventOutcome::Success,
684		StoredOutcome::Err(error) if error.is_cancelled() => TaskEventOutcome::Cancelled,
685		StoredOutcome::Err(_) => TaskEventOutcome::Error,
686	}
687}