1use std::{any::Any, pin::Pin, sync::Arc};
12
13use crate::tasks::{errors::NodeError, transition::TransitionFn};
14
15use super::{
16 errors::TaskError,
17 node::{NodeArg, NodeId, NoopNode, TaskNode},
18 transition::{AnyNodeTransition, MarkedTransitionPayload, Transition, TransitionPayload},
19};
20
/// A graph of nodes connected by transitions that turns an `Input` into an `Output`.
///
/// Nodes are stored by index. Index 0 holds the terminal "done" node that `Task::new`
/// installs; reaching it ends a run.
#[derive(Debug)]
pub struct Task<Input: NodeArg, Output: NodeArg> {
    /// Type-erased node/transition pairs; a node's id is its index in this vector.
    nodes: Vec<Box<dyn AnyNodeTransition>>,
    /// Index of the node evaluated next; 0 (the done node) means the run has finished.
    current_node: usize,
    /// Index that `reset` moves the task back to.
    start_node: usize,
    /// Value handed to the current node: the task input at first, then the context carried
    /// by each transition. After a finished run it holds the final output.
    current_context: Option<Arc<dyn Any + Send + Sync>>,
    _marker: std::marker::PhantomData<(Input, Output)>,
}
29
30impl<Input: NodeArg, Output: NodeArg> Clone for Task<Input, Output> {
31 fn clone(&self) -> Self {
32 Self {
33 nodes: self.nodes.clone(),
34 current_node: 0,
35 start_node: self.start_node,
36 current_context: None,
37 _marker: std::marker::PhantomData,
38 }
39 }
40}
41
42impl<Input: NodeArg + Clone, Output: NodeArg + Clone> Default for Task<Input, Output> {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl<Input: NodeArg + Clone, Output: NodeArg + Clone> Task<Input, Output> {
49 pub fn new() -> Self {
50 let noop = NoopNode::<Output>::default();
51
52 let node_id = NodeId::new(0, &noop).as_dyn();
53
54 let noop_executor = Box::new(Transition {
55 node: Box::new(noop),
56 node_id: Box::new(node_id),
57 r#fn: Arc::new(|_output| {
58 Box::pin(async { unreachable!("Done node should never be evaluated.") })
59 }),
60 is_set: false,
61 });
62 Self {
63 nodes: vec![noop_executor],
64 current_node: 0,
65 start_node: 0,
66 current_context: None,
67 _marker: std::marker::PhantomData,
68 }
69 }
70
71 pub fn done(&self) -> NodeId<NoopNode<Output>> {
72 NodeId::new(0, &NoopNode::default())
73 }
74
75 pub fn transitions_to_done(
77 &self,
78 ) -> impl Fn(Output) -> MarkedTransitionPayload<NoopNode<Output>> + Send + Sync + 'static {
79 let done = self.done();
80 move |context| done.transitions_with(context)
81 }
82
83 pub fn starts_with<T: TaskNode<Input = Input> + Clone + 'static>(
85 &mut self,
86 node_id: NodeId<T>,
87 ) {
88 self.current_node = node_id.id;
89 self.start_node = node_id.id;
90 }
91
92 pub fn validate_transitions(&self) -> Result<(), TaskError> {
98 for node_executor in &self.nodes {
100 if node_executor.node_id() == 0 {
102 continue;
103 }
104
105 if !node_executor.transition_is_set() {
106 return Err(TaskError::missing_transition(node_executor.node_id()));
107 }
108 }
109
110 Ok(())
111 }
112
113 pub async fn run(&mut self, input: impl Into<Input>) -> Result<Option<Output>, TaskError> {
119 self.validate_transitions()?;
120
121 self.current_context = Some(Arc::new(input.into()) as Arc<dyn Any + Send + Sync>);
122
123 self.resume().await
124 }
125
126 pub fn reset(&mut self) {
131 self.current_node = self.start_node;
132 }
133
134 pub async fn resume(&mut self) -> Result<Option<Output>, TaskError> {
140 self.validate_transitions()?;
141
142 loop {
143 if self.current_node == 0 {
144 break;
145 }
146 let node_transition = self
147 .nodes
148 .get(self.current_node)
149 .ok_or_else(|| TaskError::missing_node(self.current_node))?;
150
151 let input = self
152 .current_context
153 .clone()
154 .ok_or_else(|| TaskError::missing_input(self.current_node))?;
155
156 tracing::debug!("Running node {}", self.current_node);
157 let transition_payload = node_transition.evaluate_next(input).await?;
158
159 match transition_payload {
160 TransitionPayload::Pause => {
161 tracing::info!("Task paused at node {}", self.current_node);
162 return Ok(None);
163 }
164 TransitionPayload::NextNode(transition_payload) => {
165 self.current_node = transition_payload.node_id;
166 self.current_context = Some(transition_payload.context);
167 }
168 TransitionPayload::Error(error) => {
169 return Err(TaskError::NodeError(NodeError::new(
170 error,
171 self.current_node,
172 None,
173 )));
174 }
175 }
176 }
177
178 let output = self
179 .current_context
180 .clone()
181 .ok_or_else(|| TaskError::missing_output(self.current_node))?;
182 let output = output
183 .downcast::<Output>()
184 .map_err(|e| TaskError::type_error(&e))?
185 .as_ref()
186 .clone();
187
188 Ok(Some(output))
189 }
190
191 pub fn current_node<T: TaskNode + 'static>(&self) -> Option<&T> {
193 self.node_at_index(self.current_node)
194 }
195
196 pub fn node_at<T: TaskNode + 'static>(&self, node_id: NodeId<T>) -> Option<&T> {
198 self.node_at_index(node_id.id)
199 }
200
201 pub fn node_at_index<T: TaskNode + 'static>(&self, index: usize) -> Option<&T> {
203 let transition = self.transition_at_index::<T>(index)?;
204
205 let node = &*transition.node;
206
207 (node as &dyn Any).downcast_ref::<T>()
208 }
209
210 #[allow(dead_code)]
212 fn current_transition<T: TaskNode + 'static>(
213 &self,
214 ) -> Option<&Transition<T::Input, T::Output, T::Error>> {
215 self.transition_at_index::<T>(self.current_node)
216 }
217
218 fn transition_at_index<T: TaskNode + 'static>(
220 &self,
221 index: usize,
222 ) -> Option<&Transition<T::Input, T::Output, T::Error>> {
223 tracing::debug!("Getting transition at index {}", index);
224 let transition = self.nodes.get(index)?;
225
226 dbg!(&transition);
227
228 (&**transition as &dyn Any).downcast_ref::<Transition<T::Input, T::Output, T::Error>>()
229 }
230
231 pub fn register_node<T>(&mut self, node: T) -> NodeId<T>
233 where
234 T: TaskNode + 'static + Clone,
235 <T as TaskNode>::Input: Clone,
236 <T as TaskNode>::Output: Clone,
237 {
238 let id = self.nodes.len();
239 let node_id = NodeId::new(id, &node);
240 let node_executor = Box::new(Transition::<T::Input, T::Output, T::Error> {
241 node_id: Box::new(node_id.as_dyn()),
242 node: Box::new(node),
243 r#fn: Arc::new(move |_output| unreachable!("No transition for node {}.", node_id.id)),
244 is_set: false,
245 });
246 tracing::debug!(node_id = ?node_id, type_name = std::any::type_name_of_val(&node_executor), "Registering node");
248
249 self.nodes.push(node_executor);
250
251 node_id
252 }
253
254 pub fn register_transition<'a, From, To, F>(
262 &mut self,
263 from: NodeId<From>,
264 transition: F,
265 ) -> Result<(), TaskError>
266 where
267 From: TaskNode + 'static + ?Sized,
268 To: TaskNode<Input = From::Output> + 'a + ?Sized,
269 F: Fn(To::Input) -> MarkedTransitionPayload<To> + Send + Sync + 'static,
270 {
271 let node_executor = self
272 .nodes
273 .get_mut(from.id)
274 .ok_or_else(|| TaskError::missing_node(from.id))?;
275
276 let any_executor: &mut dyn Any = node_executor.as_mut();
277
278 let Some(exec) =
279 any_executor.downcast_mut::<Transition<From::Input, From::Output, From::Error>>()
280 else {
281 let expected =
282 std::any::type_name::<Transition<From::Input, From::Output, From::Error>>();
283 let actual = std::any::type_name_of_val(node_executor);
284
285 unreachable!(
286 "Transition at index {:?} is not a {expected:?}; Mismatched types, should not never happen. Actual: {actual:?}",
287 from.id
288 );
289 };
290 let transition = Arc::new(transition);
291 let wrapped: Arc<dyn TransitionFn<From::Output>> = Arc::new(move |output: From::Output| {
292 let transition = transition.clone();
293 Box::pin(async move {
294 let output = transition(output);
295 output.into_inner()
296 })
297 });
298
299 exec.r#fn = wrapped;
300 exec.is_set = true;
301 Ok(())
304 }
305
306 pub fn register_transition_async<'a, From, To, F>(
317 &mut self,
318 from: NodeId<From>,
319 transition: F,
320 ) -> Result<(), TaskError>
321 where
322 From: TaskNode + 'static + ?Sized,
323 To: TaskNode<Input = From::Output> + 'a + ?Sized,
324 F: Fn(To::Input) -> Pin<Box<dyn Future<Output = MarkedTransitionPayload<To>> + Send>>
325 + Send
326 + Sync
327 + 'static,
328 {
329 let node_executor = self
330 .nodes
331 .get_mut(from.id)
332 .ok_or_else(|| TaskError::missing_node(from.id))?;
333
334 let any_executor: &mut dyn Any = node_executor.as_mut();
335
336 let Some(exec) =
337 any_executor.downcast_mut::<Transition<From::Input, From::Output, From::Error>>()
338 else {
339 let expected =
340 std::any::type_name::<Transition<From::Input, From::Output, From::Error>>();
341 let actual = std::any::type_name_of_val(node_executor);
342
343 unreachable!(
344 "Transition at index {:?} is not a {expected:?}; Mismatched types, should not never happen. Actual: {actual:?}",
345 from.id
346 );
347 };
348 let transition = Arc::new(transition);
349 let wrapped: Arc<dyn TransitionFn<From::Output>> = Arc::new(move |output: From::Output| {
350 let transition = transition.clone();
351
352 Box::pin(async move {
353 let output = transition(output).await;
354 output.into_inner()
355 })
356 });
357
358 exec.r#fn = wrapped;
359 exec.is_set = true;
360 Ok(())
363 }
364}
365
#[cfg(test)]
mod tests {
    use async_trait::async_trait;

    use super::*;

    #[derive(thiserror::Error, Debug)]
    struct Error(String);

    impl std::fmt::Display for Error {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    /// Test node that adds one to its input.
    #[derive(Clone, Default, Debug)]
    struct IntNode;
    #[async_trait]
    impl TaskNode for IntNode {
        type Input = i32;
        type Output = i32;
        type Error = Error;

        async fn evaluate(
            &self,
            _node_id: &NodeId<
                dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
            >,
            input: &Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input + 1)
        }
    }

    #[test_log::test(tokio::test)]
    async fn sequential_3_node_task_reset_works() {
        let mut task: Task<i32, i32> = Task::new();

        let node1 = task.register_node(IntNode);
        let node2 = task.register_node(IntNode);
        let node3 = task.register_node(IntNode);

        task.starts_with(node1);

        task.register_transition(node1, move |input| node2.transitions_with(input))
            .unwrap();
        task.register_transition(node2, move |input| node3.transitions_with(input))
            .unwrap();
        task.register_transition(node3, task.transitions_to_done())
            .unwrap();

        let res = task.run(1).await.unwrap();
        assert_eq!(res, Some(4));

        // A finished run leaves the task on the done node, which is not an `IntNode`.
        assert!(task.current_node::<IntNode>().is_none());

        task.reset();

        assert!(task.transition_at_index::<IntNode>(1).is_some());
        assert!(task.current_transition::<IntNode>().is_some());
        assert!(task.current_node::<IntNode>().is_some());

        // After resetting, a fresh run goes through all three nodes again.
        let res = task.run(1).await.unwrap();
        assert_eq!(res, Some(4));
    }
}