swiftide_agents/tasks/
transition.rs1use std::{any::Any, pin::Pin, sync::Arc};
2
3use async_trait::async_trait;
4use dyn_clone::DynClone;
5
6use super::{
7 errors::NodeError,
8 node::{NodeArg, NodeId, TaskNode},
9};
10
11pub trait TransitionFn<Input: Send + Sync>:
12 for<'a> Fn(Input) -> Pin<Box<dyn Future<Output = TransitionPayload> + Send>> + Send + Sync
13{
14}
15
16impl<Input: Send + Sync, F> TransitionFn<Input> for F where
19 F: for<'a> Fn(Input) -> Pin<Box<dyn Future<Output = TransitionPayload> + Send>> + Send + Sync
20{
21}
22
23pub(crate) struct Transition<
24 Input: NodeArg,
25 Output: NodeArg,
26 Error: std::error::Error + Send + Sync + 'static,
27> {
28 pub(crate) node: Box<dyn TaskNode<Input = Input, Output = Output, Error = Error> + Send + Sync>,
29 pub(crate) node_id: Box<NodeId<dyn TaskNode<Input = Input, Output = Output, Error = Error>>>,
30 pub(crate) r#fn: Arc<dyn TransitionFn<Output> + Send>,
32 pub(crate) is_set: bool,
33}
34
35impl<Input, Output, Error> Clone for Transition<Input, Output, Error>
36where
37 Input: NodeArg,
38 Output: NodeArg,
39 Error: std::error::Error + Send + Sync + 'static,
40{
41 fn clone(&self) -> Self {
42 Transition {
43 node: self.node.clone(),
44 node_id: self.node_id.clone(),
45 r#fn: self.r#fn.clone(),
46 is_set: self.is_set,
47 }
48 }
49}
50
51impl<Input: NodeArg, Output: NodeArg, Error: std::error::Error + Send + Sync + 'static>
52 std::fmt::Debug for Transition<Input, Output, Error>
53{
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.debug_struct("Transition")
56 .field("node_id", &self.node_id.id)
57 .field("is_set", &self.is_set)
58 .finish()
59 }
60}
61
/// Instruction to continue execution at a specific node.
///
/// Holds the target node's numeric id and its input, erased to `dyn Any` so that nodes with
/// different input types can share one payload type.
#[derive(Clone, Debug)]
pub struct NextNode {
    /// Numeric id of the node to run next (taken from `NodeId::id`).
    pub(crate) node_id: usize,
    /// Type-erased input for that node. It is presumably handed to
    /// `AnyNodeTransition::evaluate_next`, which downcasts it back to the node's input type.
    pub(crate) context: Arc<dyn Any + Send + Sync>,
}
68
69impl NextNode {
70 pub fn new<T: TaskNode + ?Sized>(node_id: NodeId<T>, context: T::Input) -> Self
71 where
72 <T as TaskNode>::Input: 'static,
73 {
74 let context = Arc::new(context) as Arc<dyn Any + Send + Sync>;
75
76 NextNode {
77 node_id: node_id.id,
78 context,
79 }
80 }
81}
82
83impl From<NextNode> for TransitionPayload {
84 fn from(next_node: NextNode) -> Self {
85 TransitionPayload::NextNode(next_node)
86 }
87}
88
89#[derive(Debug)]
90pub enum TransitionPayload {
91 NextNode(NextNode),
92 Pause,
93 Error(Box<dyn std::error::Error + Send + Sync>),
94}
95
96impl TransitionPayload {
97 pub fn next_node<T: TaskNode + ?Sized>(node_id: &NodeId<T>, context: T::Input) -> Self {
98 NextNode::new(*node_id, context).into()
99 }
100
101 pub fn pause() -> Self {
102 TransitionPayload::Pause
103 }
104
105 pub fn error(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> Self {
106 TransitionPayload::Error(error.into())
107 }
108}
109
110pub struct MarkedTransitionPayload<To: TaskNode + ?Sized>(
111 TransitionPayload,
112 std::marker::PhantomData<To>,
113);
114
115impl<To: TaskNode + ?Sized> MarkedTransitionPayload<To> {
116 pub fn new(payload: TransitionPayload) -> Self {
117 MarkedTransitionPayload(payload, std::marker::PhantomData)
118 }
119
120 pub fn into_inner(self) -> TransitionPayload {
121 self.0
122 }
123}
124
125impl<T: TaskNode> std::ops::Deref for MarkedTransitionPayload<T> {
126 type Target = TransitionPayload;
127
128 fn deref(&self) -> &Self::Target {
129 &self.0
130 }
131}
132
133#[async_trait]
134pub(crate) trait AnyNodeTransition: Any + Send + Sync + std::fmt::Debug + DynClone {
135 fn transition_is_set(&self) -> bool;
136
137 async fn evaluate_next(
138 &self,
139 context: Arc<dyn Any + Send + Sync>,
140 ) -> Result<TransitionPayload, NodeError>;
141
142 fn node_id(&self) -> usize;
143}
144
145dyn_clone::clone_trait_object!(AnyNodeTransition);
146
147#[async_trait]
148impl<Input: NodeArg, Output: NodeArg, Error: std::error::Error + Send + Sync + 'static>
149 AnyNodeTransition for Transition<Input, Output, Error>
150{
151 async fn evaluate_next(
152 &self,
153 context: Arc<dyn Any + Send + Sync>,
154 ) -> Result<TransitionPayload, NodeError> {
155 let context = context.downcast::<Input>().unwrap();
156
157 match self.node.evaluate(&self.node_id.as_dyn(), &context).await {
158 Ok(output) => Ok((self.r#fn)(output).await),
159 Err(error) => Err(NodeError::new(error, self.node_id.id, None)), }
162 }
163
164 fn transition_is_set(&self) -> bool {
165 self.is_set
166 }
167
168 fn node_id(&self) -> usize {
169 self.node_id.id
170 }
171}