1use std::{any::Any, pin::Pin, sync::Arc};
12
13use tracing::Instrument as _;
14
15use crate::tasks::{errors::NodeError, transition::TransitionFn};
16
17use super::{
18 errors::TaskError,
19 node::{NodeArg, NodeId, NoopNode, TaskNode},
20 transition::{AnyNodeTransition, MarkedTransitionPayload, Transition, TransitionPayload},
21};
22
23#[derive(Debug)]
24pub struct Task<Input: NodeArg, Output: NodeArg> {
25 nodes: Vec<Box<dyn AnyNodeTransition>>,
26 current_node: usize,
27 start_node: usize,
28 current_context: Option<Arc<dyn Any + Send + Sync>>,
29 _marker: std::marker::PhantomData<(Input, Output)>,
30}
31
32impl<Input: NodeArg, Output: NodeArg> Clone for Task<Input, Output> {
33 fn clone(&self) -> Self {
34 Self {
35 nodes: self.nodes.clone(),
36 current_node: 0,
37 start_node: self.start_node,
38 current_context: None,
39 _marker: std::marker::PhantomData,
40 }
41 }
42}
43
44impl<Input: NodeArg + Clone, Output: NodeArg + Clone> Default for Task<Input, Output> {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl<Input: NodeArg + Clone, Output: NodeArg + Clone> Task<Input, Output> {
51 pub fn new() -> Self {
52 let noop = NoopNode::<Output>::default();
53
54 let node_id = NodeId::new(0, &noop).as_dyn();
55
56 let noop_executor = Box::new(Transition {
57 node: Box::new(noop),
58 node_id: Box::new(node_id),
59 r#fn: Arc::new(|_output| {
60 Box::pin(async { unreachable!("Done node should never be evaluated.") })
61 }),
62 is_set: false,
63 });
64 Self {
65 nodes: vec![noop_executor],
66 current_node: 0,
67 start_node: 0,
68 current_context: None,
69 _marker: std::marker::PhantomData,
70 }
71 }
72
73 pub fn current_input(&self) -> Option<&Input> {
75 let input = self.current_context.as_ref()?;
76
77 input.downcast_ref::<Input>()
78 }
79
80 pub fn current_output(&self) -> Option<&Output> {
82 let input = self.current_context.as_ref()?;
83
84 input.downcast_ref::<Output>()
85 }
86
87 pub fn done(&self) -> NodeId<NoopNode<Output>> {
89 NodeId::new(0, &NoopNode::default())
90 }
91
92 pub fn transitions_to_done(
94 &self,
95 ) -> impl Fn(Output) -> MarkedTransitionPayload<NoopNode<Output>> + Send + Sync + 'static {
96 let done = self.done();
97 move |context| done.transitions_with(context)
98 }
99
100 pub fn starts_with<T: TaskNode<Input = Input> + Clone + 'static>(
102 &mut self,
103 node_id: NodeId<T>,
104 ) {
105 self.current_node = node_id.id;
106 self.start_node = node_id.id;
107 }
108
109 pub fn validate_transitions(&self) -> Result<(), TaskError> {
115 for node_executor in &self.nodes {
117 if node_executor.node_id() == 0 {
119 continue;
120 }
121
122 if !node_executor.transition_is_set() {
123 return Err(TaskError::missing_transition(node_executor.node_id()));
124 }
125 }
126
127 Ok(())
128 }
129
130 #[tracing::instrument(skip(self, input), name = "task.run", err)]
136 pub async fn run(&mut self, input: impl Into<Input>) -> Result<Option<Output>, TaskError> {
137 self.validate_transitions()?;
138
139 self.current_context = Some(Arc::new(input.into()) as Arc<dyn Any + Send + Sync>);
140
141 self.start_task().await
142 }
143
144 pub fn reset(&mut self) {
149 self.current_node = self.start_node;
150 }
151
152 #[tracing::instrument(skip(self), name = "task.resume", err)]
158 pub async fn resume(&mut self) -> Result<Option<Output>, TaskError> {
159 self.start_task().await
160 }
161
162 async fn start_task(&mut self) -> Result<Option<Output>, TaskError> {
163 self.validate_transitions()?;
164
165 let mut span = tracing::info_span!("task.step", node = self.current_node);
166 loop {
167 if self.current_node == 0 {
168 break;
169 }
170 let node_transition = self
171 .nodes
172 .get(self.current_node)
173 .ok_or_else(|| TaskError::missing_node(self.current_node))?;
174
175 let input = self
176 .current_context
177 .clone()
178 .ok_or_else(|| TaskError::missing_input(self.current_node))?;
179
180 tracing::debug!("Running node {}", self.current_node);
181
182 let span_id = span.id().clone();
183 let transition_payload = node_transition
184 .evaluate_next(input)
185 .instrument(span.or_current())
186 .await?;
187
188 match transition_payload {
189 TransitionPayload::Pause => {
190 tracing::info!("Task paused at node {}", self.current_node);
191 return Ok(None);
192 }
193 TransitionPayload::NextNode(transition_payload) => {
194 self.current_node = transition_payload.node_id;
195 self.current_context = Some(transition_payload.context);
196 }
197 TransitionPayload::Error(error) => {
198 return Err(TaskError::NodeError(NodeError::new(
199 error,
200 self.current_node,
201 None,
202 )));
203 }
204 }
205 if self.current_node == 0 {
206 tracing::debug!("Task completed at node {}", self.current_node);
207 break;
208 }
209
210 span = tracing::info_span!("task.step", node = self.current_node).or_current();
211 span.follows_from(span_id);
212 }
213
214 let output = self
215 .current_context
216 .clone()
217 .ok_or_else(|| TaskError::missing_output(self.current_node))?;
218 let output = output
219 .downcast::<Output>()
220 .map_err(|e| TaskError::type_error(&e))?
221 .as_ref()
222 .clone();
223
224 Ok(Some(output))
225 }
226
227 pub fn current_node<T: TaskNode + 'static>(&self) -> Option<&T> {
229 self.node_at_index(self.current_node)
230 }
231
232 pub fn node_at<T: TaskNode + 'static>(&self, node_id: NodeId<T>) -> Option<&T> {
234 self.node_at_index(node_id.id)
235 }
236
237 pub fn node_at_index<T: TaskNode + 'static>(&self, index: usize) -> Option<&T> {
239 let transition = self.transition_at_index::<T>(index)?;
240
241 let node = &*transition.node;
242
243 (node as &dyn Any).downcast_ref::<T>()
244 }
245
246 #[allow(dead_code)]
248 fn current_transition<T: TaskNode + 'static>(
249 &self,
250 ) -> Option<&Transition<T::Input, T::Output, T::Error>> {
251 self.transition_at_index::<T>(self.current_node)
252 }
253
254 fn transition_at_index<T: TaskNode + 'static>(
256 &self,
257 index: usize,
258 ) -> Option<&Transition<T::Input, T::Output, T::Error>> {
259 tracing::debug!("Getting transition at index {}", index);
260 let transition = self.nodes.get(index)?;
261
262 dbg!(&transition);
263
264 (&**transition as &dyn Any).downcast_ref::<Transition<T::Input, T::Output, T::Error>>()
265 }
266
267 pub fn register_node<T>(&mut self, node: T) -> NodeId<T>
269 where
270 T: TaskNode + 'static + Clone,
271 <T as TaskNode>::Input: Clone,
272 <T as TaskNode>::Output: Clone,
273 {
274 let id = self.nodes.len();
275 let node_id = NodeId::new(id, &node);
276 let node_executor = Box::new(Transition::<T::Input, T::Output, T::Error> {
277 node_id: Box::new(node_id.as_dyn()),
278 node: Box::new(node),
279 r#fn: Arc::new(move |_output| unreachable!("No transition for node {}.", node_id.id)),
280 is_set: false,
281 });
282 tracing::debug!(node_id = ?node_id, type_name = std::any::type_name_of_val(&node_executor), "Registering node");
284
285 self.nodes.push(node_executor);
286
287 node_id
288 }
289
290 pub fn register_transition<'a, From, To, F>(
298 &mut self,
299 from: NodeId<From>,
300 transition: F,
301 ) -> Result<(), TaskError>
302 where
303 From: TaskNode + 'static + ?Sized,
304 To: TaskNode<Input = From::Output> + 'a + ?Sized,
305 F: Fn(To::Input) -> MarkedTransitionPayload<To> + Send + Sync + 'static,
306 {
307 let node_executor = self
308 .nodes
309 .get_mut(from.id)
310 .ok_or_else(|| TaskError::missing_node(from.id))?;
311
312 let any_executor: &mut dyn Any = node_executor.as_mut();
313
314 let Some(exec) =
315 any_executor.downcast_mut::<Transition<From::Input, From::Output, From::Error>>()
316 else {
317 let expected =
318 std::any::type_name::<Transition<From::Input, From::Output, From::Error>>();
319 let actual = std::any::type_name_of_val(node_executor);
320
321 unreachable!(
322 "Transition at index {:?} is not a {expected:?}; Mismatched types, should not never happen. Actual: {actual:?}",
323 from.id
324 );
325 };
326 let transition = Arc::new(transition);
327 let wrapped: Arc<dyn TransitionFn<From::Output>> = Arc::new(move |output: From::Output| {
328 let transition = transition.clone();
329 Box::pin(async move {
330 let output = transition(output);
331 output.into_inner()
332 })
333 });
334
335 exec.r#fn = wrapped;
336 exec.is_set = true;
337 Ok(())
340 }
341
342 pub fn register_transition_async<'a, From, To, F>(
353 &mut self,
354 from: NodeId<From>,
355 transition: F,
356 ) -> Result<(), TaskError>
357 where
358 From: TaskNode + 'static + ?Sized,
359 To: TaskNode<Input = From::Output> + 'a + ?Sized,
360 F: Fn(To::Input) -> Pin<Box<dyn Future<Output = MarkedTransitionPayload<To>> + Send>>
361 + Send
362 + Sync
363 + 'static,
364 {
365 let node_executor = self
366 .nodes
367 .get_mut(from.id)
368 .ok_or_else(|| TaskError::missing_node(from.id))?;
369
370 let any_executor: &mut dyn Any = node_executor.as_mut();
371
372 let Some(exec) =
373 any_executor.downcast_mut::<Transition<From::Input, From::Output, From::Error>>()
374 else {
375 let expected =
376 std::any::type_name::<Transition<From::Input, From::Output, From::Error>>();
377 let actual = std::any::type_name_of_val(node_executor);
378
379 unreachable!(
380 "Transition at index {:?} is not a {expected:?}; Mismatched types, should not never happen. Actual: {actual:?}",
381 from.id
382 );
383 };
384 let transition = Arc::new(transition);
385 let wrapped: Arc<dyn TransitionFn<From::Output>> = Arc::new(move |output: From::Output| {
386 let transition = transition.clone();
387
388 Box::pin(async move {
389 let output = transition(output).await;
390 output.into_inner()
391 })
392 });
393
394 exec.r#fn = wrapped;
395 exec.is_set = true;
396 Ok(())
399 }
400}
401
#[cfg(test)]
mod tests {
    use async_trait::async_trait;

    use super::*;

    /// Error type for the test node; it is never actually produced.
    #[derive(thiserror::Error, Debug)]
    struct Error(String);

    impl std::fmt::Display for Error {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    /// Test node that outputs its input plus one.
    #[derive(Clone, Default, Debug)]
    struct IntNode;

    #[async_trait]
    impl TaskNode for IntNode {
        type Input = i32;
        type Output = i32;
        type Error = Error;

        async fn evaluate(
            &self,
            _node_id: &NodeId<
                dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
            >,
            input: &Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input + 1)
        }
    }

    /// Runs `node1 -> node2 -> node3 -> done`, resets, inspects the restored state and runs again.
    #[test_log::test(tokio::test)]
    async fn sequential_3_node_task_reset_works() {
        let mut task: Task<i32, i32> = Task::new();

        let node1 = task.register_node(IntNode);
        let node2 = task.register_node(IntNode);
        let node3 = task.register_node(IntNode);

        task.starts_with(node1);

        task.register_transition(node1, move |input| node2.transitions_with(input))
            .unwrap();
        task.register_transition(node2, move |input| node3.transitions_with(input))
            .unwrap();
        task.register_transition(node3, task.transitions_to_done())
            .unwrap();

        let res = task.run(1).await.unwrap();
        assert_eq!(res, Some(4));

        task.reset();

        let n1_transition = task.transition_at_index::<IntNode>(1);
        assert!(n1_transition.is_some());

        let n1_transition = task.current_transition::<IntNode>();
        assert!(n1_transition.is_some());

        let n1_ref = task.current_node::<IntNode>();
        assert!(n1_ref.is_some());

        // After a reset the whole graph must run again from the start node.
        let res = task.run(1).await.unwrap();
        assert_eq!(res, Some(4));
    }

    /// A node without a transition must be rejected before any input is stored.
    #[test_log::test(tokio::test)]
    async fn run_fails_when_a_transition_is_missing() {
        let mut task: Task<i32, i32> = Task::new();

        let node1 = task.register_node(IntNode);
        task.starts_with(node1);

        assert!(task.run(1).await.is_err());
        assert!(task.current_input().is_none());
    }
}