Skip to main content

rust_langgraph/pregel/
engine.rs

1//! Core Pregel execution engine.
2//!
3//! Channel-centric superstep execution inspired by Google’s Pregel: nodes run when
4//! triggered, writes flow through channels with merge semantics, and checkpoints
5//! persist progress between steps.
6
7use crate::channels::BaseChannel;
8use crate::checkpoint::{BaseCheckpointSaver, Checkpoint, CheckpointMetadata, StateSnapshot};
9use crate::config::Config;
10use crate::errors::{Error, Result};
11use crate::nodes::PregelNode;
12use crate::state::State;
13use crate::types::{StreamEvent, StreamMode};
14use futures::stream::{Stream, StreamExt};
15use std::collections::{HashMap, HashSet};
16use std::pin::Pin;
17use std::sync::Arc;
18
/// The Pregel execution engine.
///
/// Owns the graph's nodes and channels and drives them in supersteps.
/// The intended superstep loop is:
/// 1. Load checkpoint (if resuming)
/// 2. Write initial input to channels
/// 3. Loop:
///    - Find triggered nodes
///    - Execute nodes in parallel
///    - Apply writes to channels with reducers
///    - Handle interrupts/commands
///    - Save checkpoint
///    - Yield stream events
/// 4. Return final state
///
/// Interrupt/command handling is not implemented yet, and `stream` currently
/// emits only a placeholder event rather than per-step events.
pub struct Pregel<S: State> {
    /// The nodes in the graph, keyed by node name.
    nodes: HashMap<String, PregelNode<S>>,

    /// Channels holding graph state between supersteps, keyed by channel name.
    /// The engine itself reads and writes the `"__start__"` and `"__end__"` entries.
    channels: HashMap<String, Box<dyn BaseChannel>>,

    /// Optional checkpoint saver for persistence; when `None`, nothing is
    /// loaded or saved and the state-inspection methods return empty results.
    checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,

    /// Entry point node name; used as the fallback node to run when nothing
    /// is triggered on step 0.
    entry_point: String,

    /// Finish point node names.
    /// NOTE(review): stored by `new` but not read by any method in this file.
    finish_points: HashSet<String>,

    /// Static edges from source to targets. After a source node runs, each
    /// target's `"<target>_input"` channel name is recorded as written.
    edges: HashMap<String, Vec<String>>,

    /// Current step in execution; reset at the start of `invoke`/`stream` and
    /// overwritten from checkpoint metadata when resuming.
    current_step: usize,

    /// Maximum recursion depth; defaults to 25 in `new` and is overwritten
    /// from `Config::recursion_limit` by `invoke`/`stream`.
    recursion_limit: usize,

    /// Names of channels recorded as written; consulted when deciding which
    /// nodes are triggered.
    written_channels: HashSet<String>,
}
60
61impl<S: State> Pregel<S> {
62    /// Create a new Pregel executor
63    pub fn new(
64        nodes: HashMap<String, PregelNode<S>>,
65        channels: HashMap<String, Box<dyn BaseChannel>>,
66        checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
67        entry_point: String,
68        finish_points: HashSet<String>,
69        edges: HashMap<String, Vec<String>>,
70    ) -> Self {
71        Self {
72            nodes,
73            channels,
74            checkpointer,
75            entry_point,
76            finish_points,
77            edges,
78            current_step: 0,
79            recursion_limit: 25,
80            written_channels: HashSet::new(),
81        }
82    }
83
84    /// Set the recursion limit
85    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
86        self.recursion_limit = limit;
87        self
88    }
89
90    /// Execute the graph with the given input and configuration
91    pub async fn invoke(&mut self, input: S, config: Config) -> Result<S> {
92        self.recursion_limit = config.recursion_limit;
93        self.current_step = 0;
94
95        // 1. Load checkpoint if resuming
96        if let Some(checkpointer) = &self.checkpointer {
97            if let Some(tuple) = checkpointer.get_tuple(&config).await? {
98                self.restore_channels(&tuple.checkpoint)?;
99                self.current_step = tuple.metadata.step;
100            }
101        }
102
103        // 2. Write initial input to START channels
104        self.write_input_to_channels(&input)?;
105
106        // 3. Superstep loop
107        loop {
108            // Check recursion limit
109            if self.current_step >= self.recursion_limit {
110                return Err(Error::RecursionLimitError {
111                    current: self.current_step,
112                    limit: self.recursion_limit,
113                });
114            }
115
116            // Find triggered nodes
117            let triggered_nodes = self.find_triggered_nodes();
118            if triggered_nodes.is_empty() {
119                break; // No more work to do
120            }
121
122            // Execute nodes in parallel
123            let mut tasks = Vec::new();
124            for node_name in &triggered_nodes {
125                if let Some(node) = self.nodes.get(node_name) {
126                    let state = self.read_state_for_node(node)?;
127                    let node_clone = node.clone();
128                    let config_clone = config.clone();
129
130                    let task = tokio::spawn(async move {
131                        node_clone.bound.invoke(state, &config_clone).await
132                    });
133
134                    tasks.push((node_name.clone(), task));
135                }
136            }
137
138            // Collect results
139            let mut updates: HashMap<String, S> = HashMap::new();
140            for (node_name, task) in tasks {
141                match task.await {
142                    Ok(Ok(result)) => {
143                        updates.insert(node_name, result);
144                    }
145                    Ok(Err(e)) => return Err(e),
146                    Err(e) => {
147                        return Err(Error::execution(format!("Node execution panicked: {}", e)))
148                    }
149                }
150            }
151
152            // Apply writes to channels
153            self.apply_updates(updates)?;
154
155            // Check for interrupts/commands
156            // (For now, we'll implement basic interrupt support later)
157
158            // Save checkpoint
159            if let Some(checkpointer) = &self.checkpointer {
160                let checkpoint = self.create_checkpoint(&config)?;
161                let metadata = CheckpointMetadata {
162                    step: self.current_step,
163                    source: "pregel".to_string(),
164                    created_at: chrono::Utc::now(),
165                    extra: HashMap::new(),
166                };
167                checkpointer.put(&checkpoint, &metadata, &config).await?;
168            }
169
170            self.current_step += 1;
171            self.written_channels.clear(); // Clear for next superstep
172        }
173
174        // 4. Return final state
175        self.get_final_state()
176    }
177
178    /// Stream execution events
179    pub async fn stream(
180        &mut self,
181        input: S,
182        config: Config,
183        mode: StreamMode,
184    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + std::marker::Send>>> {
185        self.recursion_limit = config.recursion_limit;
186        self.current_step = 0;
187
188        // For MVP: implement a basic streaming version
189        // Full implementation would yield events at each step
190
191        let (tx, rx) = tokio::sync::mpsc::channel(100);
192
193        // Load checkpoint if resuming
194        if let Some(checkpointer) = &self.checkpointer {
195            if let Some(tuple) = checkpointer.get_tuple(&config).await? {
196                self.restore_channels(&tuple.checkpoint)?;
197                self.current_step = tuple.metadata.step;
198            }
199        }
200
201        // Write initial input
202        self.write_input_to_channels(&input)?;
203
204        // Clone necessary data for the async task
205        let _nodes = self.nodes.clone();
206        let _channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::new();
207        // Note: We can't clone channels easily due to trait object limitations
208        // For now, we'll implement a simpler version
209
210        // Execute and stream
211        let _checkpointer = self.checkpointer.clone();
212        let _entry_point = self.entry_point.clone();
213        let recursion_limit = self.recursion_limit;
214
215        tokio::spawn(async move {
216            let mut step = 0;
217            loop {
218                if step >= recursion_limit {
219                    let _ = tx.send(Err(Error::RecursionLimitError {
220                        current: step,
221                        limit: recursion_limit,
222                    })).await;
223                    break;
224                }
225
226                // For now, simplified streaming
227                // Full implementation would mirror invoke() but yield events
228
229                // Emit a values event
230                if matches!(mode, StreamMode::Values) {
231                    let event = StreamEvent::Values {
232                        ns: vec![],
233                        data: serde_json::json!({"step": step}),
234                        interrupts: vec![],
235                    };
236                    if tx.send(Ok(event)).await.is_err() {
237                        break;
238                    }
239                }
240
241                step += 1;
242                break; // For MVP, just one step
243            }
244        });
245
246        Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
247    }
248
249    /// Get the current state snapshot
250    pub async fn get_state(&self, config: &Config) -> Result<Option<StateSnapshot<S>>> {
251        if let Some(checkpointer) = &self.checkpointer {
252            if let Some(tuple) = checkpointer.get_tuple(config).await? {
253                let state = self.state_from_checkpoint(&tuple.checkpoint)?;
254                return Ok(Some(StateSnapshot {
255                    state,
256                    checkpoint: tuple.checkpoint,
257                    metadata: tuple.metadata,
258                    config: tuple.config,
259                }));
260            }
261        }
262        Ok(None)
263    }
264
265    /// Get state history
266    pub async fn get_state_history(
267        &self,
268        config: &Config,
269        limit: Option<usize>,
270    ) -> Result<Vec<StateSnapshot<S>>> {
271        if let Some(checkpointer) = &self.checkpointer {
272            let tuples = checkpointer.list(config, limit).await?;
273            let mut snapshots = Vec::new();
274
275            for tuple in tuples {
276                let state = self.state_from_checkpoint(&tuple.checkpoint)?;
277                snapshots.push(StateSnapshot {
278                    state,
279                    checkpoint: tuple.checkpoint,
280                    metadata: tuple.metadata,
281                    config: tuple.config,
282                });
283            }
284
285            return Ok(snapshots);
286        }
287        Ok(Vec::new())
288    }
289
290    // === PRIVATE HELPER METHODS ===
291
292    /// Write input state to START channels
293    fn write_input_to_channels(&mut self, input: &S) -> Result<()> {
294        // Convert state to JSON and write to START channel
295        let value = input.to_value()?;
296        if let Some(channel) = self.channels.get_mut("__start__") {
297            channel.update(vec![value])?;
298            self.written_channels.insert("__start__".to_string());
299        }
300        Ok(())
301    }
302
303    /// Find nodes that are triggered by written channels
304    fn find_triggered_nodes(&self) -> Vec<String> {
305        let mut triggered = Vec::new();
306
307        for (name, node) in &self.nodes {
308            if node.is_triggered(&self.written_channels.iter().cloned().collect::<Vec<_>>()) {
309                triggered.push(name.clone());
310            }
311        }
312
313        // If no nodes triggered but we have entry point and it's first step
314        if triggered.is_empty() && self.current_step == 0 {
315            triggered.push(self.entry_point.clone());
316        }
317
318        triggered
319    }
320
321    /// Read state for a specific node from its input channels
322    fn read_state_for_node(&self, _node: &PregelNode<S>) -> Result<S> {
323        // For now, simple implementation: read from __start__ or construct empty state
324        // Full implementation would read from node's specific channels and construct state
325
326        if let Some(channel) = self.channels.get("__start__") {
327            if let Some(value) = channel.get()? {
328                return S::from_value(value);
329            }
330        }
331
332        // If no input, create from channels that node reads
333        // For MVP, this is simplified
334        Err(Error::state("Cannot construct state from channels"))
335    }
336
337    /// Apply node updates to channels
338    fn apply_updates(&mut self, updates: HashMap<String, S>) -> Result<()> {
339        for (node_name, state) in updates {
340            // Get the node's writer specifications
341            if let Some(node) = self.nodes.get(&node_name) {
342                // For each writer, write the state to the channel
343                for writer in &node.writers {
344                    let value = state.to_value()?;
345                    if let Some(channel) = self.channels.get_mut(&writer.channel) {
346                        channel.update(vec![value.clone()])?;
347                        self.written_channels.insert(writer.channel.clone());
348                    }
349                }
350            }
351
352            // Also follow static edges
353            if let Some(targets) = self.edges.get(&node_name) {
354                for target in targets {
355                    // Mark target's input channel as written
356                    self.written_channels.insert(format!("{}_input", target));
357                }
358            }
359        }
360
361        Ok(())
362    }
363
364    /// Create a checkpoint from current channel states
365    fn create_checkpoint(&self, config: &Config) -> Result<Checkpoint> {
366        let mut checkpoint = Checkpoint::new();
367
368        if let Some(thread_id) = &config.thread_id {
369            checkpoint.thread_id = Some(thread_id.clone());
370        }
371
372        // Save all channel values
373        for (name, channel) in &self.channels {
374            let channel_data = channel.checkpoint()?;
375            checkpoint.set_channel(name, channel_data);
376        }
377
378        Ok(checkpoint)
379    }
380
381    /// Restore channels from a checkpoint
382    fn restore_channels(&mut self, checkpoint: &Checkpoint) -> Result<()> {
383        for (name, value) in &checkpoint.channel_values {
384            // We can't fully restore channels from checkpoint due to type erasure
385            // In practice, the graph builder creates channels and we just update their values
386            if let Some(channel) = self.channels.get_mut(name) {
387                channel.update(vec![value.clone()])?;
388            }
389        }
390        Ok(())
391    }
392
393    /// Construct state from a checkpoint
394    fn state_from_checkpoint(&self, checkpoint: &Checkpoint) -> Result<S> {
395        // Get the value from the main state channel
396        if let Some(value) = checkpoint.get_channel("__state__") {
397            return S::from_value(value.clone());
398        }
399
400        // Fallback: try to construct from START channel
401        if let Some(value) = checkpoint.get_channel("__start__") {
402            return S::from_value(value.clone());
403        }
404
405        Err(Error::checkpoint("Cannot construct state from checkpoint"))
406    }
407
408    /// Get the final state after execution
409    fn get_final_state(&self) -> Result<S> {
410        // Read from output channels
411        if let Some(channel) = self.channels.get("__end__") {
412            if let Some(value) = channel.get()? {
413                return S::from_value(value);
414            }
415        }
416
417        // Fallback: read from __start__ channel
418        if let Some(channel) = self.channels.get("__start__") {
419            if let Some(value) = channel.get()? {
420                return S::from_value(value);
421            }
422        }
423
424        Err(Error::state("Cannot determine final state"))
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::{LastValue};
    use crate::nodes::{Node, PregelNode, ChannelWrite};
    use crate::state::State as StateTrait;
    use serde::{Deserialize, Serialize};

    /// Minimal state used by the engine tests: a single counter.
    #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
    struct TestState {
        count: i32,
    }

    impl StateTrait for TestState {
        /// Merging adds the other counter onto this one.
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    /// A single node that increments the counter should turn an input of 0
    /// into a final state of 1.
    #[tokio::test]
    async fn test_pregel_basic() {
        // One node wired to "__start__" that writes its result to "__end__".
        let graph_nodes = HashMap::from([(
            "increment".to_string(),
            PregelNode::from_node(
                "increment",
                vec!["__start__".to_string()],
                vec!["__start__".to_string()],
                |mut state: TestState, _config: &Config| async move {
                    state.count += 1;
                    Ok(state)
                },
                vec![ChannelWrite::new("__end__")],
            ),
        )]);

        // Both channels hold the last written `TestState`.
        let mut graph_channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::new();
        for channel_name in ["__start__", "__end__"] {
            graph_channels.insert(
                channel_name.to_string(),
                Box::new(LastValue::<TestState>::new()),
            );
        }

        let mut engine = Pregel::new(
            graph_nodes,
            graph_channels,
            None,
            "increment".to_string(),
            HashSet::from(["increment".to_string()]),
            HashMap::new(),
        );

        let outcome = engine
            .invoke(TestState { count: 0 }, Config::default())
            .await
            .unwrap();

        assert_eq!(outcome.count, 1);
    }
}
482}