1use crate::config::Config;
7use crate::errors::Result;
8use crate::state::State;
9use async_trait::async_trait;
10use std::fmt::Debug;
11use std::future::Future;
12use std::sync::Arc;
13
14#[async_trait]
42pub trait Node<S: State>: Send + Sync {
43 async fn invoke(&self, state: S, config: &Config) -> Result<S>;
54}
55
56#[async_trait]
58impl<S, F, Fut> Node<S> for F
59where
60 S: State,
61 F: Fn(S, &Config) -> Fut + Send + Sync,
62 Fut: Future<Output = Result<S>> + Send,
63{
64 async fn invoke(&self, state: S, config: &Config) -> Result<S> {
65 self(state, config).await
66 }
67}
68
69pub type NodeBox<S> = Box<dyn Node<S>>;
71
72pub type NodeArc<S> = Arc<dyn Node<S>>;
74
75#[derive(Clone)]
80pub struct PregelNode<S: State> {
81 pub name: String,
83
84 pub channels: Vec<String>,
86
87 pub triggers: Vec<String>,
89
90 pub bound: NodeArc<S>,
92
93 pub writers: Vec<ChannelWrite>,
95}
96
97impl<S: State> Debug for PregelNode<S> {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("PregelNode")
100 .field("name", &self.name)
101 .field("channels", &self.channels)
102 .field("triggers", &self.triggers)
103 .field("bound", &"<node>")
104 .field("writers", &self.writers)
105 .finish()
106 }
107}
108
109impl<S: State> PregelNode<S> {
110 pub fn new(
112 name: impl Into<String>,
113 channels: Vec<String>,
114 triggers: Vec<String>,
115 bound: NodeArc<S>,
116 writers: Vec<ChannelWrite>,
117 ) -> Self {
118 Self {
119 name: name.into(),
120 channels,
121 triggers,
122 bound,
123 writers,
124 }
125 }
126
127 pub fn from_node(
129 name: impl Into<String>,
130 channels: Vec<String>,
131 triggers: Vec<String>,
132 bound: impl Node<S> + 'static,
133 writers: Vec<ChannelWrite>,
134 ) -> Self {
135 Self {
136 name: name.into(),
137 channels,
138 triggers,
139 bound: Arc::new(bound),
140 writers,
141 }
142 }
143
144 pub fn is_triggered(&self, written_channels: &[String]) -> bool {
146 self.triggers.iter().any(|t| written_channels.contains(t))
147 }
148}
149
/// Configuration for a write into a named channel.
/// NOTE(review): presumably applied after the owning node runs — confirm against the executor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChannelWrite {
    /// Target channel name.
    pub channel: String,

    /// "Skip none" flag; `true` by default (see [`ChannelWrite::new`]).
    /// NOTE(review): presumably suppresses writes of empty/None values — confirm where it is read.
    pub skip_none: bool,

    /// Optional mapper, referenced by name. How it is resolved is not visible in this module.
    pub mapper: Option<String>,
}

impl ChannelWrite {
    /// Creates a write to `channel` with `skip_none = true` and no mapper.
    #[must_use]
    pub fn new(channel: impl Into<String>) -> Self {
        Self {
            channel: channel.into(),
            skip_none: true,
            mapper: None,
        }
    }

    /// Builder-style setter for [`ChannelWrite::skip_none`].
    ///
    /// Consumes and returns `self`, so the result must be used; a bare
    /// `write.with_skip_none(false);` would otherwise be a silent no-op.
    #[must_use]
    pub fn with_skip_none(mut self, skip: bool) -> Self {
        self.skip_none = skip;
        self
    }

    /// Builder-style setter for [`ChannelWrite::mapper`].
    #[must_use]
    pub fn with_mapper(mut self, mapper: impl Into<String>) -> Self {
        self.mapper = Some(mapper.into());
        self
    }
}
179
180pub fn node_fn<S, F, Fut>(f: F) -> impl Node<S>
182where
183 S: State,
184 F: Fn(S, &Config) -> Fut + Send + Sync + 'static,
185 Fut: Future<Output = Result<S>> + Send + 'static,
186{
187 f
188}
189
190pub fn simple_node<S, F, Fut>(f: F) -> impl Node<S>
192where
193 S: State,
194 F: Fn(S) -> Fut + Send + Sync + 'static,
195 Fut: Future<Output = Result<S>> + Send + 'static,
196{
197 move |state: S, _config: &Config| f(state)
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::State as StateTrait;
    use serde::{Deserialize, Serialize};

    /// Minimal state for the tests; `merge` adds the counts together.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct TestState {
        count: i32,
    }

    impl StateTrait for TestState {
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_node_from_closure() {
        let node = |mut state: TestState, _config: &Config| async move {
            state.count += 1;
            Ok(state)
        };

        let state = TestState { count: 0 };
        let result = node.invoke(state, &Config::default()).await.unwrap();
        assert_eq!(result.count, 1);
    }

    #[tokio::test]
    async fn test_simple_node() {
        let node = simple_node(|mut state: TestState| async move {
            state.count += 10;
            Ok(state)
        });

        let state = TestState { count: 5 };
        let result = node.invoke(state, &Config::default()).await.unwrap();
        assert_eq!(result.count, 15);
    }

    /// Hand-written node that doubles the count.
    struct CustomNode;

    #[async_trait]
    impl Node<TestState> for CustomNode {
        async fn invoke(&self, mut state: TestState, _config: &Config) -> Result<TestState> {
            state.count *= 2;
            Ok(state)
        }
    }

    #[tokio::test]
    async fn test_custom_node() {
        let node = CustomNode;
        let state = TestState { count: 5 };
        let result = node.invoke(state, &Config::default()).await.unwrap();
        assert_eq!(result.count, 10);
    }

    #[test]
    fn test_pregel_node_is_triggered() {
        let node = PregelNode::from_node(
            "test",
            vec!["in".to_string()],
            vec!["trigger_a".to_string(), "trigger_b".to_string()],
            |state: TestState, _: &Config| async move { Ok(state) },
            vec![],
        );

        assert!(node.is_triggered(&["trigger_a".to_string()]));
        assert!(node.is_triggered(&["trigger_b".to_string()]));
        assert!(node.is_triggered(&["trigger_a".to_string(), "other".to_string()]));
        assert!(!node.is_triggered(&["other".to_string()]));
        assert!(!node.is_triggered(&[]));
    }

    /// A node with an empty trigger list can never be triggered.
    #[test]
    fn test_pregel_node_without_triggers_is_never_triggered() {
        let node = PregelNode::from_node(
            "idle",
            vec![],
            vec![],
            |state: TestState, _: &Config| async move { Ok(state) },
            vec![],
        );

        assert!(!node.is_triggered(&["anything".to_string()]));
        assert!(!node.is_triggered(&[]));
    }

    /// `PregelNode::new` stores the shared implementation, and `bound` dispatches to it.
    #[tokio::test]
    async fn test_pregel_node_invokes_bound_through_arc() {
        let bound: NodeArc<TestState> = Arc::new(CustomNode);
        let node = PregelNode::new("double", vec![], vec![], bound, vec![]);

        let result = node
            .bound
            .invoke(TestState { count: 4 }, &Config::default())
            .await
            .unwrap();
        assert_eq!(result.count, 8);
    }

    /// Debug output shows the plain fields and a placeholder for the trait object.
    #[test]
    fn test_pregel_node_debug_hides_bound() {
        let bound: NodeArc<TestState> = Arc::new(CustomNode);
        let node = PregelNode::new(
            "dbg",
            vec!["in".to_string()],
            vec!["go".to_string()],
            bound,
            vec![ChannelWrite::new("out")],
        );

        let rendered = format!("{node:?}");
        assert!(rendered.starts_with("PregelNode"));
        assert!(rendered.contains("\"dbg\""));
        assert!(rendered.contains("<node>"));
        assert!(rendered.contains("\"out\""));
    }

    /// Cloning a node shares the same implementation instead of duplicating it.
    #[test]
    fn test_pregel_node_clone_shares_bound() {
        let bound: NodeArc<TestState> = Arc::new(CustomNode);
        let node = PregelNode::new("shared", vec![], vec![], bound, vec![]);
        let copy = node.clone();

        assert_eq!(copy.name, "shared");
        assert!(Arc::ptr_eq(&node.bound, &copy.bound));
    }

    #[test]
    fn test_channel_write() {
        let defaults = ChannelWrite::new("output");
        assert_eq!(defaults.channel, "output");
        assert!(defaults.skip_none);
        assert!(defaults.mapper.is_none());

        let write = ChannelWrite::new("output").with_skip_none(false);
        assert_eq!(write.channel, "output");
        assert!(!write.skip_none);
    }
}