1use crate::channels::BaseChannel;
8use crate::checkpoint::{BaseCheckpointSaver, Checkpoint, CheckpointMetadata, StateSnapshot};
9use crate::config::Config;
10use crate::errors::{Error, Result};
11use crate::graph::START;
12use crate::nodes::PregelNode;
13use crate::state::State;
14use crate::types::{StreamEvent, StreamMode};
15use futures::stream::{Stream, StreamExt};
16use std::collections::{HashMap, HashSet};
17use std::pin::Pin;
18use std::sync::Arc;
19
20pub struct Pregel<S: State> {
34 nodes: HashMap<String, PregelNode<S>>,
36
37 channels: HashMap<String, Box<dyn BaseChannel>>,
39
40 checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
42
43 entry_point: String,
45
46 finish_points: HashSet<String>,
48
49 edges: HashMap<String, Vec<String>>,
51
52 current_step: usize,
54
55 recursion_limit: usize,
57
58 written_channels: HashSet<String>,
60}
61
62impl<S: State> Pregel<S> {
63 pub fn new(
65 nodes: HashMap<String, PregelNode<S>>,
66 channels: HashMap<String, Box<dyn BaseChannel>>,
67 checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
68 entry_point: String,
69 finish_points: HashSet<String>,
70 edges: HashMap<String, Vec<String>>,
71 ) -> Self {
72 Self {
73 nodes,
74 channels,
75 checkpointer,
76 entry_point,
77 finish_points,
78 edges,
79 current_step: 0,
80 recursion_limit: 25,
81 written_channels: HashSet::new(),
82 }
83 }
84
85 pub fn with_recursion_limit(mut self, limit: usize) -> Self {
87 self.recursion_limit = limit;
88 self
89 }
90
91 pub async fn invoke(&mut self, input: S, config: Config) -> Result<S> {
93 self.recursion_limit = config.recursion_limit;
94 self.current_step = 0;
95
96 if let Some(checkpointer) = &self.checkpointer {
98 if let Some(tuple) = checkpointer.get_tuple(&config).await? {
99 self.restore_channels(&tuple.checkpoint)?;
100 self.current_step = tuple.metadata.step;
101 }
102 }
103
104 self.write_input_to_channels(&input)?;
106
107 loop {
109 if self.current_step >= self.recursion_limit {
111 return Err(Error::RecursionLimitError {
112 current: self.current_step,
113 limit: self.recursion_limit,
114 });
115 }
116
117 let triggered_nodes = self.find_triggered_nodes();
119 if triggered_nodes.is_empty() {
120 break; }
122
123 let mut tasks = Vec::new();
125 for node_name in &triggered_nodes {
126 if let Some(node) = self.nodes.get(node_name) {
127 let state = self.read_state_for_node(node)?;
128 let node_clone = node.clone();
129 let config_clone = config.clone();
130
131 let task = tokio::spawn(async move {
132 node_clone.bound.invoke(state, &config_clone).await
133 });
134
135 tasks.push((node_name.clone(), task));
136 }
137 }
138
139 let mut updates: HashMap<String, S> = HashMap::new();
141 for (node_name, task) in tasks {
142 match task.await {
143 Ok(Ok(result)) => {
144 updates.insert(node_name, result);
145 }
146 Ok(Err(e)) => return Err(e),
147 Err(e) => {
148 return Err(Error::execution(format!("Node execution panicked: {}", e)))
149 }
150 }
151 }
152
153 self.written_channels = self.apply_updates(updates)?;
156
157 if let Some(checkpointer) = &self.checkpointer {
162 let checkpoint = self.create_checkpoint(&config)?;
163 let metadata = CheckpointMetadata {
164 step: self.current_step,
165 source: "pregel".to_string(),
166 created_at: chrono::Utc::now(),
167 extra: HashMap::new(),
168 };
169 checkpointer.put(&checkpoint, &metadata, &config).await?;
170 }
171
172 self.current_step += 1;
173 }
174
175 self.get_final_state()
177 }
178
179 pub async fn stream(
181 &mut self,
182 input: S,
183 config: Config,
184 mode: StreamMode,
185 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + std::marker::Send>>> {
186 self.recursion_limit = config.recursion_limit;
187 self.current_step = 0;
188
189 let (tx, rx) = tokio::sync::mpsc::channel(100);
193
194 if let Some(checkpointer) = &self.checkpointer {
196 if let Some(tuple) = checkpointer.get_tuple(&config).await? {
197 self.restore_channels(&tuple.checkpoint)?;
198 self.current_step = tuple.metadata.step;
199 }
200 }
201
202 self.write_input_to_channels(&input)?;
204
205 let _nodes = self.nodes.clone();
207 let _channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::new();
208 let _checkpointer = self.checkpointer.clone();
213 let _entry_point = self.entry_point.clone();
214 let recursion_limit = self.recursion_limit;
215
216 tokio::spawn(async move {
217 let mut step = 0;
218 loop {
219 if step >= recursion_limit {
220 let _ = tx.send(Err(Error::RecursionLimitError {
221 current: step,
222 limit: recursion_limit,
223 })).await;
224 break;
225 }
226
227 if matches!(mode, StreamMode::Values) {
232 let event = StreamEvent::Values {
233 ns: vec![],
234 data: serde_json::json!({"step": step}),
235 interrupts: vec![],
236 };
237 if tx.send(Ok(event)).await.is_err() {
238 break;
239 }
240 }
241
242 step += 1;
243 break; }
245 });
246
247 Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
248 }
249
250 pub async fn get_state(&self, config: &Config) -> Result<Option<StateSnapshot<S>>> {
252 if let Some(checkpointer) = &self.checkpointer {
253 if let Some(tuple) = checkpointer.get_tuple(config).await? {
254 let state = self.state_from_checkpoint(&tuple.checkpoint)?;
255 return Ok(Some(StateSnapshot {
256 state,
257 checkpoint: tuple.checkpoint,
258 metadata: tuple.metadata,
259 config: tuple.config,
260 }));
261 }
262 }
263 Ok(None)
264 }
265
266 pub async fn get_state_history(
268 &self,
269 config: &Config,
270 limit: Option<usize>,
271 ) -> Result<Vec<StateSnapshot<S>>> {
272 if let Some(checkpointer) = &self.checkpointer {
273 let tuples = checkpointer.list(config, limit).await?;
274 let mut snapshots = Vec::new();
275
276 for tuple in tuples {
277 let state = self.state_from_checkpoint(&tuple.checkpoint)?;
278 snapshots.push(StateSnapshot {
279 state,
280 checkpoint: tuple.checkpoint,
281 metadata: tuple.metadata,
282 config: tuple.config,
283 });
284 }
285
286 return Ok(snapshots);
287 }
288 Ok(Vec::new())
289 }
290
291 fn write_input_to_channels(&mut self, input: &S) -> Result<()> {
295 let value = input.to_value()?;
297 if let Some(channel) = self.channels.get_mut("__start__") {
298 channel.update(vec![value])?;
299 self.written_channels.insert("__start__".to_string());
300 }
301 Ok(())
302 }
303
304 fn find_triggered_nodes(&self) -> Vec<String> {
306 let mut triggered = Vec::new();
307
308 for (name, node) in &self.nodes {
309 if node.is_triggered(&self.written_channels.iter().cloned().collect::<Vec<_>>()) {
310 triggered.push(name.clone());
311 }
312 }
313
314 if triggered.is_empty() && self.current_step == 0 {
316 triggered.push(self.entry_point.clone());
317 }
318
319 triggered
320 }
321
322 fn read_state_for_node(&self, node: &PregelNode<S>) -> Result<S> {
325 let mut merged: Option<S> = None;
326
327 for ch_name in &node.channels {
328 if let Some(channel) = self.channels.get(ch_name) {
329 if let Some(value) = channel.get()? {
330 let piece = S::from_value(value)?;
331 merged = match merged {
332 None => Some(piece),
333 Some(mut m) => {
334 m.merge(piece)?;
335 Some(m)
336 }
337 };
338 }
339 }
340 }
341
342 if merged.is_none() && node.triggers.iter().any(|t| t == START) {
343 if let Some(channel) = self.channels.get(START) {
344 if let Some(value) = channel.get()? {
345 merged = Some(S::from_value(value)?);
346 }
347 }
348 }
349
350 merged.ok_or_else(|| {
351 Error::state(format!(
352 "Cannot construct state for node '{}' (input channels {:?})",
353 node.name, node.channels
354 ))
355 })
356 }
357
358 fn apply_updates(&mut self, updates: HashMap<String, S>) -> Result<HashSet<String>> {
361 let mut next_triggers = HashSet::new();
362
363 for (node_name, state) in updates {
364 let value = state.to_value()?;
365
366 if let Some(node) = self.nodes.get(&node_name) {
367 for writer in &node.writers {
368 if let Some(channel) = self.channels.get_mut(&writer.channel) {
369 channel.update(vec![value.clone()])?;
370 next_triggers.insert(writer.channel.clone());
371 }
372 }
373 }
374
375 if let Some(targets) = self.edges.get(&node_name) {
376 for target in targets {
377 let input_ch = format!("{}_input", target);
378 if let Some(ch) = self.channels.get_mut(&input_ch) {
379 ch.update(vec![value.clone()])?;
380 }
381 }
382 }
383 }
384
385 Ok(next_triggers)
386 }
387
388 fn create_checkpoint(&self, config: &Config) -> Result<Checkpoint> {
390 let mut checkpoint = Checkpoint::new();
391
392 if let Some(thread_id) = &config.thread_id {
393 checkpoint.thread_id = Some(thread_id.clone());
394 }
395
396 for (name, channel) in &self.channels {
398 let channel_data = channel.checkpoint()?;
399 checkpoint.set_channel(name, channel_data);
400 }
401
402 Ok(checkpoint)
403 }
404
405 fn restore_channels(&mut self, checkpoint: &Checkpoint) -> Result<()> {
407 for (name, value) in &checkpoint.channel_values {
408 if let Some(channel) = self.channels.get_mut(name) {
411 channel.update(vec![value.clone()])?;
412 }
413 }
414 Ok(())
415 }
416
417 fn state_from_checkpoint(&self, checkpoint: &Checkpoint) -> Result<S> {
419 if let Some(value) = checkpoint.get_channel("__state__") {
421 return S::from_value(value.clone());
422 }
423
424 if let Some(value) = checkpoint.get_channel("__start__") {
426 return S::from_value(value.clone());
427 }
428
429 Err(Error::checkpoint("Cannot construct state from checkpoint"))
430 }
431
432 fn get_final_state(&self) -> Result<S> {
437 if let Some(channel) = self.channels.get(crate::graph::END) {
438 if let Some(value) = channel.get()? {
439 return S::from_value(value);
440 }
441 }
442
443 for fp in &self.finish_points {
444 let ch_name = format!("{}_output", fp);
445 if let Some(channel) = self.channels.get(&ch_name) {
446 if let Some(value) = channel.get()? {
447 return S::from_value(value);
448 }
449 }
450 }
451
452 if let Some(channel) = self.channels.get(START) {
453 if let Some(value) = channel.get()? {
454 return S::from_value(value);
455 }
456 }
457
458 Err(Error::state("Cannot determine final state"))
459 }
460}
461
462#[cfg(test)]
463mod tests {
464 use super::*;
465 use crate::channels::{LastValue};
466 use crate::nodes::{PregelNode, ChannelWrite};
467 use crate::state::State as StateTrait;
468 use serde::{Deserialize, Serialize};
469
470 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
471 struct TestState {
472 count: i32,
473 }
474
475 impl StateTrait for TestState {
476 fn merge(&mut self, other: Self) -> Result<()> {
477 self.count += other.count;
478 Ok(())
479 }
480 }
481
482 #[tokio::test]
483 async fn test_pregel_basic() {
484 let increment_node = PregelNode::from_node(
485 "increment",
486 vec!["__start__".to_string()],
487 vec!["__start__".to_string()],
488 |mut state: TestState, _config: &Config| async move {
489 state.count += 1;
490 Ok(state)
491 },
492 vec![ChannelWrite::new("__end__")],
493 );
494
495 let mut nodes = HashMap::new();
496 nodes.insert("increment".to_string(), increment_node);
497
498 let mut channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::new();
499 channels.insert("__start__".to_string(), Box::new(LastValue::<TestState>::new()));
500 channels.insert("__end__".to_string(), Box::new(LastValue::<TestState>::new()));
501
502 let mut pregel = Pregel::new(
503 nodes,
504 channels,
505 None,
506 "increment".to_string(),
507 HashSet::from(["increment".to_string()]),
508 HashMap::new(),
509 );
510
511 let input = TestState { count: 0 };
512 let result = pregel.invoke(input, Config::default()).await.unwrap();
513
514 assert_eq!(result.count, 1);
515 }
516
517 #[tokio::test]
520 async fn test_pregel_two_node_chain() {
521 let a = PregelNode::from_node(
522 "a",
523 vec!["a_input".to_string()],
524 vec![START.to_string()],
525 |mut state: TestState, _config: &Config| async move {
526 state.count += 1;
527 Ok(state)
528 },
529 vec![ChannelWrite::new("a_output")],
530 );
531 let b = PregelNode::from_node(
532 "b",
533 vec!["b_input".to_string()],
534 vec!["a_output".to_string()],
535 |mut state: TestState, _config: &Config| async move {
536 state.count *= 10;
537 Ok(state)
538 },
539 vec![ChannelWrite::new("b_output")],
540 );
541
542 let mut nodes = HashMap::new();
543 nodes.insert("a".to_string(), a);
544 nodes.insert("b".to_string(), b);
545
546 let mut channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::new();
547 channels.insert(START.to_string(), Box::new(LastValue::<TestState>::new()));
548 channels.insert("a_input".to_string(), Box::new(LastValue::<TestState>::new()));
549 channels.insert("a_output".to_string(), Box::new(LastValue::<TestState>::new()));
550 channels.insert("b_input".to_string(), Box::new(LastValue::<TestState>::new()));
551 channels.insert("b_output".to_string(), Box::new(LastValue::<TestState>::new()));
552 channels.insert("__end__".to_string(), Box::new(LastValue::<TestState>::new()));
553
554 let mut edges = HashMap::new();
555 edges.insert("a".to_string(), vec!["b".to_string()]);
556
557 let mut pregel = Pregel::new(
558 nodes,
559 channels,
560 None,
561 "a".to_string(),
562 HashSet::from(["b".to_string()]),
563 edges,
564 );
565
566 let result = pregel
567 .invoke(TestState { count: 5 }, Config::default())
568 .await
569 .unwrap();
570 assert_eq!(result.count, 60); }
572}