1use crate::channels::BaseChannel;
8use crate::checkpoint::{BaseCheckpointSaver, Checkpoint, CheckpointMetadata, StateSnapshot};
9use crate::config::Config;
10use crate::errors::{Error, Result};
11use crate::nodes::PregelNode;
12use crate::state::State;
13use crate::types::{StreamEvent, StreamMode};
14use futures::stream::{Stream, StreamExt};
15use std::collections::{HashMap, HashSet};
16use std::pin::Pin;
17use std::sync::Arc;
18
19pub struct Pregel<S: State> {
33 nodes: HashMap<String, PregelNode<S>>,
35
36 channels: HashMap<String, Box<dyn BaseChannel>>,
38
39 checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
41
42 entry_point: String,
44
45 finish_points: HashSet<String>,
47
48 edges: HashMap<String, Vec<String>>,
50
51 current_step: usize,
53
54 recursion_limit: usize,
56
57 written_channels: HashSet<String>,
59}
60
61impl<S: State> Pregel<S> {
62 pub fn new(
64 nodes: HashMap<String, PregelNode<S>>,
65 channels: HashMap<String, Box<dyn BaseChannel>>,
66 checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
67 entry_point: String,
68 finish_points: HashSet<String>,
69 edges: HashMap<String, Vec<String>>,
70 ) -> Self {
71 Self {
72 nodes,
73 channels,
74 checkpointer,
75 entry_point,
76 finish_points,
77 edges,
78 current_step: 0,
79 recursion_limit: 25,
80 written_channels: HashSet::new(),
81 }
82 }
83
84 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
86 self.recursion_limit = limit;
87 self
88 }
89
90 pub async fn invoke(&mut self, input: S, config: Config) -> Result<S> {
92 self.recursion_limit = config.recursion_limit;
93 self.current_step = 0;
94
95 if let Some(checkpointer) = &self.checkpointer {
97 if let Some(tuple) = checkpointer.get_tuple(&config).await? {
98 self.restore_channels(&tuple.checkpoint)?;
99 self.current_step = tuple.metadata.step;
100 }
101 }
102
103 self.write_input_to_channels(&input)?;
105
106 loop {
108 if self.current_step >= self.recursion_limit {
110 return Err(Error::RecursionLimitError {
111 current: self.current_step,
112 limit: self.recursion_limit,
113 });
114 }
115
116 let triggered_nodes = self.find_triggered_nodes();
118 if triggered_nodes.is_empty() {
119 break; }
121
122 let mut tasks = Vec::new();
124 for node_name in &triggered_nodes {
125 if let Some(node) = self.nodes.get(node_name) {
126 let state = self.read_state_for_node(node)?;
127 let node_clone = node.clone();
128 let config_clone = config.clone();
129
130 let task = tokio::spawn(async move {
131 node_clone.bound.invoke(state, &config_clone).await
132 });
133
134 tasks.push((node_name.clone(), task));
135 }
136 }
137
138 let mut updates: HashMap<String, S> = HashMap::new();
140 for (node_name, task) in tasks {
141 match task.await {
142 Ok(Ok(result)) => {
143 updates.insert(node_name, result);
144 }
145 Ok(Err(e)) => return Err(e),
146 Err(e) => {
147 return Err(Error::execution(format!("Node execution panicked: {}", e)))
148 }
149 }
150 }
151
152 self.apply_updates(updates)?;
154
155 if let Some(checkpointer) = &self.checkpointer {
160 let checkpoint = self.create_checkpoint(&config)?;
161 let metadata = CheckpointMetadata {
162 step: self.current_step,
163 source: "pregel".to_string(),
164 created_at: chrono::Utc::now(),
165 extra: HashMap::new(),
166 };
167 checkpointer.put(&checkpoint, &metadata, &config).await?;
168 }
169
170 self.current_step += 1;
171 self.written_channels.clear(); }
173
174 self.get_final_state()
176 }
177
178 pub async fn stream(
180 &mut self,
181 input: S,
182 config: Config,
183 mode: StreamMode,
184 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + std::marker::Send>>> {
185 self.recursion_limit = config.recursion_limit;
186 self.current_step = 0;
187
188 let (tx, rx) = tokio::sync::mpsc::channel(100);
192
193 if let Some(checkpointer) = &self.checkpointer {
195 if let Some(tuple) = checkpointer.get_tuple(&config).await? {
196 self.restore_channels(&tuple.checkpoint)?;
197 self.current_step = tuple.metadata.step;
198 }
199 }
200
201 self.write_input_to_channels(&input)?;
203
204 let _nodes = self.nodes.clone();
206 let _channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::new();
207 let _checkpointer = self.checkpointer.clone();
212 let _entry_point = self.entry_point.clone();
213 let recursion_limit = self.recursion_limit;
214
215 tokio::spawn(async move {
216 let mut step = 0;
217 loop {
218 if step >= recursion_limit {
219 let _ = tx.send(Err(Error::RecursionLimitError {
220 current: step,
221 limit: recursion_limit,
222 })).await;
223 break;
224 }
225
226 if matches!(mode, StreamMode::Values) {
231 let event = StreamEvent::Values {
232 ns: vec![],
233 data: serde_json::json!({"step": step}),
234 interrupts: vec![],
235 };
236 if tx.send(Ok(event)).await.is_err() {
237 break;
238 }
239 }
240
241 step += 1;
242 break; }
244 });
245
246 Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
247 }
248
249 pub async fn get_state(&self, config: &Config) -> Result<Option<StateSnapshot<S>>> {
251 if let Some(checkpointer) = &self.checkpointer {
252 if let Some(tuple) = checkpointer.get_tuple(config).await? {
253 let state = self.state_from_checkpoint(&tuple.checkpoint)?;
254 return Ok(Some(StateSnapshot {
255 state,
256 checkpoint: tuple.checkpoint,
257 metadata: tuple.metadata,
258 config: tuple.config,
259 }));
260 }
261 }
262 Ok(None)
263 }
264
265 pub async fn get_state_history(
267 &self,
268 config: &Config,
269 limit: Option<usize>,
270 ) -> Result<Vec<StateSnapshot<S>>> {
271 if let Some(checkpointer) = &self.checkpointer {
272 let tuples = checkpointer.list(config, limit).await?;
273 let mut snapshots = Vec::new();
274
275 for tuple in tuples {
276 let state = self.state_from_checkpoint(&tuple.checkpoint)?;
277 snapshots.push(StateSnapshot {
278 state,
279 checkpoint: tuple.checkpoint,
280 metadata: tuple.metadata,
281 config: tuple.config,
282 });
283 }
284
285 return Ok(snapshots);
286 }
287 Ok(Vec::new())
288 }
289
290 fn write_input_to_channels(&mut self, input: &S) -> Result<()> {
294 let value = input.to_value()?;
296 if let Some(channel) = self.channels.get_mut("__start__") {
297 channel.update(vec![value])?;
298 self.written_channels.insert("__start__".to_string());
299 }
300 Ok(())
301 }
302
303 fn find_triggered_nodes(&self) -> Vec<String> {
305 let mut triggered = Vec::new();
306
307 for (name, node) in &self.nodes {
308 if node.is_triggered(&self.written_channels.iter().cloned().collect::<Vec<_>>()) {
309 triggered.push(name.clone());
310 }
311 }
312
313 if triggered.is_empty() && self.current_step == 0 {
315 triggered.push(self.entry_point.clone());
316 }
317
318 triggered
319 }
320
321 fn read_state_for_node(&self, _node: &PregelNode<S>) -> Result<S> {
323 if let Some(channel) = self.channels.get("__start__") {
327 if let Some(value) = channel.get()? {
328 return S::from_value(value);
329 }
330 }
331
332 Err(Error::state("Cannot construct state from channels"))
335 }
336
337 fn apply_updates(&mut self, updates: HashMap<String, S>) -> Result<()> {
339 for (node_name, state) in updates {
340 if let Some(node) = self.nodes.get(&node_name) {
342 for writer in &node.writers {
344 let value = state.to_value()?;
345 if let Some(channel) = self.channels.get_mut(&writer.channel) {
346 channel.update(vec![value.clone()])?;
347 self.written_channels.insert(writer.channel.clone());
348 }
349 }
350 }
351
352 if let Some(targets) = self.edges.get(&node_name) {
354 for target in targets {
355 self.written_channels.insert(format!("{}_input", target));
357 }
358 }
359 }
360
361 Ok(())
362 }
363
364 fn create_checkpoint(&self, config: &Config) -> Result<Checkpoint> {
366 let mut checkpoint = Checkpoint::new();
367
368 if let Some(thread_id) = &config.thread_id {
369 checkpoint.thread_id = Some(thread_id.clone());
370 }
371
372 for (name, channel) in &self.channels {
374 let channel_data = channel.checkpoint()?;
375 checkpoint.set_channel(name, channel_data);
376 }
377
378 Ok(checkpoint)
379 }
380
381 fn restore_channels(&mut self, checkpoint: &Checkpoint) -> Result<()> {
383 for (name, value) in &checkpoint.channel_values {
384 if let Some(channel) = self.channels.get_mut(name) {
387 channel.update(vec![value.clone()])?;
388 }
389 }
390 Ok(())
391 }
392
393 fn state_from_checkpoint(&self, checkpoint: &Checkpoint) -> Result<S> {
395 if let Some(value) = checkpoint.get_channel("__state__") {
397 return S::from_value(value.clone());
398 }
399
400 if let Some(value) = checkpoint.get_channel("__start__") {
402 return S::from_value(value.clone());
403 }
404
405 Err(Error::checkpoint("Cannot construct state from checkpoint"))
406 }
407
408 fn get_final_state(&self) -> Result<S> {
410 if let Some(channel) = self.channels.get("__end__") {
412 if let Some(value) = channel.get()? {
413 return S::from_value(value);
414 }
415 }
416
417 if let Some(channel) = self.channels.get("__start__") {
419 if let Some(value) = channel.get()? {
420 return S::from_value(value);
421 }
422 }
423
424 Err(Error::state("Cannot determine final state"))
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::LastValue;
    use crate::nodes::{ChannelWrite, Node, PregelNode};
    use crate::state::State as StateTrait;
    use serde::{Deserialize, Serialize};

    /// Minimal state used by the tests: a single counter.
    #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
    struct TestState {
        count: i32,
    }

    impl StateTrait for TestState {
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    /// Creates an empty boxed channel holding a `TestState`.
    fn counter_channel() -> Box<dyn BaseChannel> {
        Box::new(LastValue::<TestState>::new())
    }

    /// A single node that reads `__start__`, adds one, and writes `__end__`.
    #[tokio::test]
    async fn test_pregel_basic() {
        let increment_node = PregelNode::from_node(
            "increment",
            vec!["__start__".to_string()],
            vec!["__start__".to_string()],
            |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            },
            vec![ChannelWrite::new("__end__")],
        );

        let nodes = HashMap::from([("increment".to_string(), increment_node)]);

        let channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::from([
            ("__start__".to_string(), counter_channel()),
            ("__end__".to_string(), counter_channel()),
        ]);

        let finish_points = HashSet::from(["increment".to_string()]);

        let mut pregel = Pregel::new(
            nodes,
            channels,
            None,
            "increment".to_string(),
            finish_points,
            HashMap::new(),
        );

        let outcome = pregel
            .invoke(TestState { count: 0 }, Config::default())
            .await
            .unwrap();

        assert_eq!(outcome.count, 1);
    }
}