Skip to main content

rust_langgraph/pregel/
engine.rs

//! Core Pregel execution engine.
//!
//! Channel-centric superstep execution inspired by Google's Pregel: nodes run when
//! triggered, writes flow through channels with merge semantics, and checkpoints
//! persist progress between steps.
6
7use crate::channels::BaseChannel;
8use crate::checkpoint::{BaseCheckpointSaver, Checkpoint, CheckpointMetadata, StateSnapshot};
9use crate::config::Config;
10use crate::errors::{Error, Result};
11use crate::graph::START;
12use crate::nodes::PregelNode;
13use crate::state::State;
14use crate::types::{StreamEvent, StreamMode};
15use futures::stream::{Stream, StreamExt};
16use std::collections::{HashMap, HashSet};
17use std::pin::Pin;
18use std::sync::Arc;
19
20/// The Pregel execution engine.
21///
22/// This implements the core superstep loop:
23/// 1. Load checkpoint (if resuming)
24/// 2. Write initial input to channels
25/// 3. Loop:
26///    - Find triggered nodes
27///    - Execute nodes in parallel
28///    - Apply writes to channels with reducers
29///    - Handle interrupts/commands
30///    - Save checkpoint
31///    - Yield stream events
32/// 4. Return final state
33pub struct Pregel<S: State> {
34    /// The nodes in the graph
35    nodes: HashMap<String, PregelNode<S>>,
36
37    /// Channels holding graph state between supersteps
38    channels: HashMap<String, Box<dyn BaseChannel>>,
39
40    /// Optional checkpoint saver for persistence
41    checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
42
43    /// Entry point node name
44    entry_point: String,
45
46    /// Finish point node names
47    finish_points: HashSet<String>,
48
49    /// Static edges from source to targets
50    edges: HashMap<String, Vec<String>>,
51
52    /// Current step in execution
53    current_step: usize,
54
55    /// Maximum recursion depth
56    recursion_limit: usize,
57
58    /// Channels written in the current superstep
59    written_channels: HashSet<String>,
60}
61
62impl<S: State> Pregel<S> {
63    /// Create a new Pregel executor
64    pub fn new(
65        nodes: HashMap<String, PregelNode<S>>,
66        channels: HashMap<String, Box<dyn BaseChannel>>,
67        checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
68        entry_point: String,
69        finish_points: HashSet<String>,
70        edges: HashMap<String, Vec<String>>,
71    ) -> Self {
72        Self {
73            nodes,
74            channels,
75            checkpointer,
76            entry_point,
77            finish_points,
78            edges,
79            current_step: 0,
80            recursion_limit: 25,
81            written_channels: HashSet::new(),
82        }
83    }
84
85    /// Set the recursion limit
86    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
87        self.recursion_limit = limit;
88        self
89    }
90
91    /// Execute the graph with the given input and configuration
92    pub async fn invoke(&mut self, input: S, config: Config) -> Result<S> {
93        self.recursion_limit = config.recursion_limit;
94        self.current_step = 0;
95
96        // 1. Load checkpoint if resuming
97        if let Some(checkpointer) = &self.checkpointer {
98            if let Some(tuple) = checkpointer.get_tuple(&config).await? {
99                self.restore_channels(&tuple.checkpoint)?;
100                self.current_step = tuple.metadata.step;
101            }
102        }
103
104        // 2. Write initial input to START channels
105        self.write_input_to_channels(&input)?;
106
107        // 3. Superstep loop
108        loop {
109            // Check recursion limit
110            if self.current_step >= self.recursion_limit {
111                return Err(Error::RecursionLimitError {
112                    current: self.current_step,
113                    limit: self.recursion_limit,
114                });
115            }
116
117            // Find triggered nodes
118            let triggered_nodes = self.find_triggered_nodes();
119            if triggered_nodes.is_empty() {
120                break; // No more work to do
121            }
122
123            // Execute nodes in parallel
124            let mut tasks = Vec::new();
125            for node_name in &triggered_nodes {
126                if let Some(node) = self.nodes.get(node_name) {
127                    let state = self.read_state_for_node(node)?;
128                    let node_clone = node.clone();
129                    let config_clone = config.clone();
130
131                    let task = tokio::spawn(async move {
132                        node_clone.bound.invoke(state, &config_clone).await
133                    });
134
135                    tasks.push((node_name.clone(), task));
136                }
137            }
138
139            // Collect results
140            let mut updates: HashMap<String, S> = HashMap::new();
141            for (node_name, task) in tasks {
142                match task.await {
143                    Ok(Ok(result)) => {
144                        updates.insert(node_name, result);
145                    }
146                    Ok(Err(e)) => return Err(e),
147                    Err(e) => {
148                        return Err(Error::execution(format!("Node execution panicked: {}", e)))
149                    }
150                }
151            }
152
153            // Apply writes to channels and collect which channels were written this superstep.
154            // Those channels trigger the next wave of nodes (do not clear — replace).
155            self.written_channels = self.apply_updates(updates)?;
156
157            // Check for interrupts/commands
158            // (For now, we'll implement basic interrupt support later)
159
160            // Save checkpoint
161            if let Some(checkpointer) = &self.checkpointer {
162                let checkpoint = self.create_checkpoint(&config)?;
163                let metadata = CheckpointMetadata {
164                    step: self.current_step,
165                    source: "pregel".to_string(),
166                    created_at: chrono::Utc::now(),
167                    extra: HashMap::new(),
168                };
169                checkpointer.put(&checkpoint, &metadata, &config).await?;
170            }
171
172            self.current_step += 1;
173        }
174
175        // 4. Return final state
176        self.get_final_state()
177    }
178
179    /// Stream execution events
180    pub async fn stream(
181        &mut self,
182        input: S,
183        config: Config,
184        mode: StreamMode,
185    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + std::marker::Send>>> {
186        self.recursion_limit = config.recursion_limit;
187        self.current_step = 0;
188
189        // For MVP: implement a basic streaming version
190        // Full implementation would yield events at each step
191
192        let (tx, rx) = tokio::sync::mpsc::channel(100);
193
194        // Load checkpoint if resuming
195        if let Some(checkpointer) = &self.checkpointer {
196            if let Some(tuple) = checkpointer.get_tuple(&config).await? {
197                self.restore_channels(&tuple.checkpoint)?;
198                self.current_step = tuple.metadata.step;
199            }
200        }
201
202        // Write initial input
203        self.write_input_to_channels(&input)?;
204
205        // Clone necessary data for the async task
206        let _nodes = self.nodes.clone();
207        let _channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::new();
208        // Note: We can't clone channels easily due to trait object limitations
209        // For now, we'll implement a simpler version
210
211        // Execute and stream
212        let _checkpointer = self.checkpointer.clone();
213        let _entry_point = self.entry_point.clone();
214        let recursion_limit = self.recursion_limit;
215
216        tokio::spawn(async move {
217            let mut step = 0;
218            loop {
219                if step >= recursion_limit {
220                    let _ = tx.send(Err(Error::RecursionLimitError {
221                        current: step,
222                        limit: recursion_limit,
223                    })).await;
224                    break;
225                }
226
227                // For now, simplified streaming
228                // Full implementation would mirror invoke() but yield events
229
230                // Emit a values event
231                if matches!(mode, StreamMode::Values) {
232                    let event = StreamEvent::Values {
233                        ns: vec![],
234                        data: serde_json::json!({"step": step}),
235                        interrupts: vec![],
236                    };
237                    if tx.send(Ok(event)).await.is_err() {
238                        break;
239                    }
240                }
241
242                step += 1;
243                break; // For MVP, just one step
244            }
245        });
246
247        Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
248    }
249
250    /// Get the current state snapshot
251    pub async fn get_state(&self, config: &Config) -> Result<Option<StateSnapshot<S>>> {
252        if let Some(checkpointer) = &self.checkpointer {
253            if let Some(tuple) = checkpointer.get_tuple(config).await? {
254                let state = self.state_from_checkpoint(&tuple.checkpoint)?;
255                return Ok(Some(StateSnapshot {
256                    state,
257                    checkpoint: tuple.checkpoint,
258                    metadata: tuple.metadata,
259                    config: tuple.config,
260                }));
261            }
262        }
263        Ok(None)
264    }
265
266    /// Get state history
267    pub async fn get_state_history(
268        &self,
269        config: &Config,
270        limit: Option<usize>,
271    ) -> Result<Vec<StateSnapshot<S>>> {
272        if let Some(checkpointer) = &self.checkpointer {
273            let tuples = checkpointer.list(config, limit).await?;
274            let mut snapshots = Vec::new();
275
276            for tuple in tuples {
277                let state = self.state_from_checkpoint(&tuple.checkpoint)?;
278                snapshots.push(StateSnapshot {
279                    state,
280                    checkpoint: tuple.checkpoint,
281                    metadata: tuple.metadata,
282                    config: tuple.config,
283                });
284            }
285
286            return Ok(snapshots);
287        }
288        Ok(Vec::new())
289    }
290
291    // === PRIVATE HELPER METHODS ===
292
293    /// Write input state to START channels
294    fn write_input_to_channels(&mut self, input: &S) -> Result<()> {
295        // Convert state to JSON and write to START channel
296        let value = input.to_value()?;
297        if let Some(channel) = self.channels.get_mut("__start__") {
298            channel.update(vec![value])?;
299            self.written_channels.insert("__start__".to_string());
300        }
301        Ok(())
302    }
303
304    /// Find nodes that are triggered by written channels
305    fn find_triggered_nodes(&self) -> Vec<String> {
306        let mut triggered = Vec::new();
307
308        for (name, node) in &self.nodes {
309            if node.is_triggered(&self.written_channels.iter().cloned().collect::<Vec<_>>()) {
310                triggered.push(name.clone());
311            }
312        }
313
314        // If no nodes triggered but we have entry point and it's first step
315        if triggered.is_empty() && self.current_step == 0 {
316            triggered.push(self.entry_point.clone());
317        }
318
319        triggered
320    }
321
322    /// Read state for a specific node from its input channels (`{name}_input`), merged in order.
323    /// If those channels are empty but this node is triggered by `__start__`, read the graph input from `__start__`.
324    fn read_state_for_node(&self, node: &PregelNode<S>) -> Result<S> {
325        let mut merged: Option<S> = None;
326
327        for ch_name in &node.channels {
328            if let Some(channel) = self.channels.get(ch_name) {
329                if let Some(value) = channel.get()? {
330                    let piece = S::from_value(value)?;
331                    merged = match merged {
332                        None => Some(piece),
333                        Some(mut m) => {
334                            m.merge(piece)?;
335                            Some(m)
336                        }
337                    };
338                }
339            }
340        }
341
342        if merged.is_none() && node.triggers.iter().any(|t| t == START) {
343            if let Some(channel) = self.channels.get(START) {
344                if let Some(value) = channel.get()? {
345                    merged = Some(S::from_value(value)?);
346                }
347            }
348        }
349
350        merged.ok_or_else(|| {
351            Error::state(format!(
352                "Cannot construct state for node '{}' (input channels {:?})",
353                node.name, node.channels
354            ))
355        })
356    }
357
358    /// Apply node outputs to `{node}_output` and fan out the same state to each edge target's `{target}_input`.
359    /// Returns the set of channel names written this superstep — these trigger the next superstep (replacing the previous set).
360    fn apply_updates(&mut self, updates: HashMap<String, S>) -> Result<HashSet<String>> {
361        let mut next_triggers = HashSet::new();
362
363        for (node_name, state) in updates {
364            let value = state.to_value()?;
365
366            if let Some(node) = self.nodes.get(&node_name) {
367                for writer in &node.writers {
368                    if let Some(channel) = self.channels.get_mut(&writer.channel) {
369                        channel.update(vec![value.clone()])?;
370                        next_triggers.insert(writer.channel.clone());
371                    }
372                }
373            }
374
375            if let Some(targets) = self.edges.get(&node_name) {
376                for target in targets {
377                    let input_ch = format!("{}_input", target);
378                    if let Some(ch) = self.channels.get_mut(&input_ch) {
379                        ch.update(vec![value.clone()])?;
380                    }
381                }
382            }
383        }
384
385        Ok(next_triggers)
386    }
387
388    /// Create a checkpoint from current channel states
389    fn create_checkpoint(&self, config: &Config) -> Result<Checkpoint> {
390        let mut checkpoint = Checkpoint::new();
391
392        if let Some(thread_id) = &config.thread_id {
393            checkpoint.thread_id = Some(thread_id.clone());
394        }
395
396        // Save all channel values
397        for (name, channel) in &self.channels {
398            let channel_data = channel.checkpoint()?;
399            checkpoint.set_channel(name, channel_data);
400        }
401
402        Ok(checkpoint)
403    }
404
405    /// Restore channels from a checkpoint
406    fn restore_channels(&mut self, checkpoint: &Checkpoint) -> Result<()> {
407        for (name, value) in &checkpoint.channel_values {
408            // We can't fully restore channels from checkpoint due to type erasure
409            // In practice, the graph builder creates channels and we just update their values
410            if let Some(channel) = self.channels.get_mut(name) {
411                channel.update(vec![value.clone()])?;
412            }
413        }
414        Ok(())
415    }
416
417    /// Construct state from a checkpoint
418    fn state_from_checkpoint(&self, checkpoint: &Checkpoint) -> Result<S> {
419        // Get the value from the main state channel
420        if let Some(value) = checkpoint.get_channel("__state__") {
421            return S::from_value(value.clone());
422        }
423
424        // Fallback: try to construct from START channel
425        if let Some(value) = checkpoint.get_channel("__start__") {
426            return S::from_value(value.clone());
427        }
428
429        Err(Error::checkpoint("Cannot construct state from checkpoint"))
430    }
431
432    /// Get the final state after execution.
433    ///
434    /// Prefer `__end__`, then each finish node's `{name}_output` (where compiled graphs write),
435    /// then `__start__` as last resort.
436    fn get_final_state(&self) -> Result<S> {
437        if let Some(channel) = self.channels.get(crate::graph::END) {
438            if let Some(value) = channel.get()? {
439                return S::from_value(value);
440            }
441        }
442
443        for fp in &self.finish_points {
444            let ch_name = format!("{}_output", fp);
445            if let Some(channel) = self.channels.get(&ch_name) {
446                if let Some(value) = channel.get()? {
447                    return S::from_value(value);
448                }
449            }
450        }
451
452        if let Some(channel) = self.channels.get(START) {
453            if let Some(value) = channel.get()? {
454                return S::from_value(value);
455            }
456        }
457
458        Err(Error::state("Cannot determine final state"))
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::LastValue;
    use crate::nodes::{ChannelWrite, PregelNode};
    use crate::state::State as StateTrait;
    use serde::{Deserialize, Serialize};

    /// Minimal state used by the engine tests: a single counter.
    #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
    struct TestState {
        count: i32,
    }

    impl StateTrait for TestState {
        /// Merging two states adds their counters.
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    /// Build one empty `LastValue<TestState>` channel per name.
    fn last_value_channels(names: &[&str]) -> HashMap<String, Box<dyn BaseChannel>> {
        names
            .iter()
            .map(|name| {
                let channel: Box<dyn BaseChannel> = Box::new(LastValue::<TestState>::new());
                ((*name).to_string(), channel)
            })
            .collect()
    }

    /// A single node triggered by the start channel increments the counter and writes the
    /// result to the end channel.
    #[tokio::test]
    async fn test_pregel_basic() {
        let increment_node = PregelNode::from_node(
            "increment",
            vec!["__start__".to_string()],
            vec!["__start__".to_string()],
            |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            },
            vec![ChannelWrite::new("__end__")],
        );

        let nodes = HashMap::from([("increment".to_string(), increment_node)]);
        let channels = last_value_channels(&["__start__", "__end__"]);

        let mut pregel = Pregel::new(
            nodes,
            channels,
            None,
            "increment".to_string(),
            HashSet::from(["increment".to_string()]),
            HashMap::new(),
        );

        let result = pregel
            .invoke(TestState { count: 0 }, Config::default())
            .await
            .unwrap();

        assert_eq!(result.count, 1);
    }

    /// Two-node chain: the second superstep must run (triggered by the previous node's
    /// `{src}_output`), and the downstream node must read its state from `{dst}_input`
    /// rather than only from `__start__`.
    #[tokio::test]
    async fn test_pregel_two_node_chain() {
        let node_a = PregelNode::from_node(
            "a",
            vec!["a_input".to_string()],
            vec![START.to_string()],
            |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            },
            vec![ChannelWrite::new("a_output")],
        );
        let node_b = PregelNode::from_node(
            "b",
            vec!["b_input".to_string()],
            vec!["a_output".to_string()],
            |mut state: TestState, _config: &Config| async move {
                state.count *= 10;
                Ok(state)
            },
            vec![ChannelWrite::new("b_output")],
        );

        let nodes = HashMap::from([("a".to_string(), node_a), ("b".to_string(), node_b)]);
        let channels = last_value_channels(&[
            START, "a_input", "a_output", "b_input", "b_output", "__end__",
        ]);
        let edges = HashMap::from([("a".to_string(), vec!["b".to_string()])]);

        let mut pregel = Pregel::new(
            nodes,
            channels,
            None,
            "a".to_string(),
            HashSet::from(["b".to_string()]),
            edges,
        );

        let result = pregel
            .invoke(TestState { count: 5 }, Config::default())
            .await
            .unwrap();
        assert_eq!(result.count, 60); // (5 + 1) * 10
    }
}