swiftide_agents/tasks/
transition.rs

1use std::{any::Any, pin::Pin, sync::Arc};
2
3use async_trait::async_trait;
4use dyn_clone::DynClone;
5
6use super::{
7    errors::NodeError,
8    node::{NodeArg, NodeId, TaskNode},
9};
10
11pub trait TransitionFn<Input: Send + Sync>:
12    for<'a> Fn(Input) -> Pin<Box<dyn Future<Output = TransitionPayload> + Send>> + Send + Sync
13{
14}
15
16// dyn_clone::clone_trait_object!(<Input> TransitionFn<Input>);
17
18impl<Input: Send + Sync, F> TransitionFn<Input> for F where
19    F: for<'a> Fn(Input) -> Pin<Box<dyn Future<Output = TransitionPayload> + Send>> + Send + Sync
20{
21}
22
23pub(crate) struct Transition<
24    Input: NodeArg,
25    Output: NodeArg,
26    Error: std::error::Error + Send + Sync + 'static,
27> {
28    pub(crate) node: Box<dyn TaskNode<Input = Input, Output = Output, Error = Error> + Send + Sync>,
29    pub(crate) node_id: Box<NodeId<dyn TaskNode<Input = Input, Output = Output, Error = Error>>>,
30    // pub(crate) r#fn: Arc<dyn Fn(Output) -> TransitionPayload + Send + Sync>,
31    pub(crate) r#fn: Arc<dyn TransitionFn<Output> + Send>,
32    pub(crate) is_set: bool,
33}
34
35impl<Input, Output, Error> Clone for Transition<Input, Output, Error>
36where
37    Input: NodeArg,
38    Output: NodeArg,
39    Error: std::error::Error + Send + Sync + 'static,
40{
41    fn clone(&self) -> Self {
42        Transition {
43            node: self.node.clone(),
44            node_id: self.node_id.clone(),
45            r#fn: self.r#fn.clone(),
46            is_set: self.is_set,
47        }
48    }
49}
50
51impl<Input: NodeArg, Output: NodeArg, Error: std::error::Error + Send + Sync + 'static>
52    std::fmt::Debug for Transition<Input, Output, Error>
53{
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("Transition")
56            .field("node_id", &self.node_id.id)
57            .field("is_set", &self.is_set)
58            .finish()
59    }
60}
61
/// Names the node to run next together with the input it should receive.
///
/// Cloning is cheap: the context sits behind an [`Arc`], so clones share it.
#[derive(Debug, Clone)]
pub struct NextNode {
    // If we make this an enum instead, we can support spawning many nodes as well
    /// Numeric id of the target node.
    pub(crate) node_id: usize,
    /// Type-erased input for the target node.
    pub(crate) context: Arc<dyn Any + Send + Sync>,
}
68
69impl NextNode {
70    pub fn new<T: TaskNode + ?Sized>(node_id: NodeId<T>, context: T::Input) -> Self
71    where
72        <T as TaskNode>::Input: 'static,
73    {
74        let context = Arc::new(context) as Arc<dyn Any + Send + Sync>;
75
76        NextNode {
77            node_id: node_id.id,
78            context,
79        }
80    }
81}
82
83impl From<NextNode> for TransitionPayload {
84    fn from(next_node: NextNode) -> Self {
85        TransitionPayload::NextNode(next_node)
86    }
87}
88
89#[derive(Debug)]
90pub enum TransitionPayload {
91    NextNode(NextNode),
92    Pause,
93    Error(Box<dyn std::error::Error + Send + Sync>),
94}
95
96impl TransitionPayload {
97    pub fn next_node<T: TaskNode + ?Sized>(node_id: &NodeId<T>, context: T::Input) -> Self {
98        NextNode::new(*node_id, context).into()
99    }
100
101    pub fn pause() -> Self {
102        TransitionPayload::Pause
103    }
104
105    pub fn error(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> Self {
106        TransitionPayload::Error(error.into())
107    }
108}
109
110pub struct MarkedTransitionPayload<To: TaskNode + ?Sized>(
111    TransitionPayload,
112    std::marker::PhantomData<To>,
113);
114
115impl<To: TaskNode + ?Sized> MarkedTransitionPayload<To> {
116    pub fn new(payload: TransitionPayload) -> Self {
117        MarkedTransitionPayload(payload, std::marker::PhantomData)
118    }
119
120    pub fn into_inner(self) -> TransitionPayload {
121        self.0
122    }
123}
124
125impl<T: TaskNode> std::ops::Deref for MarkedTransitionPayload<T> {
126    type Target = TransitionPayload;
127
128    fn deref(&self) -> &Self::Target {
129        &self.0
130    }
131}
132
133#[async_trait]
134pub(crate) trait AnyNodeTransition: Any + Send + Sync + std::fmt::Debug + DynClone {
135    fn transition_is_set(&self) -> bool;
136
137    async fn evaluate_next(
138        &self,
139        context: Arc<dyn Any + Send + Sync>,
140    ) -> Result<TransitionPayload, NodeError>;
141
142    fn node_id(&self) -> usize;
143}
144
145dyn_clone::clone_trait_object!(AnyNodeTransition);
146
147#[async_trait]
148impl<Input: NodeArg, Output: NodeArg, Error: std::error::Error + Send + Sync + 'static>
149    AnyNodeTransition for Transition<Input, Output, Error>
150{
151    async fn evaluate_next(
152        &self,
153        context: Arc<dyn Any + Send + Sync>,
154    ) -> Result<TransitionPayload, NodeError> {
155        let context = context.downcast::<Input>().unwrap();
156
157        match self.node.evaluate(&self.node_id.as_dyn(), &context).await {
158            Ok(output) => Ok((self.r#fn)(output).await),
159            Err(error) => Err(NodeError::new(error, self.node_id.id, None)), /* node_id will be
160                                                                              * set by caller */
161        }
162    }
163
164    fn transition_is_set(&self) -> bool {
165        self.is_set
166    }
167
168    fn node_id(&self) -> usize {
169        self.node_id.id
170    }
171}