Skip to main content

synaptic_graph/
compiled.rs

1use std::collections::{HashMap, HashSet};
2use std::pin::Pin;
3use std::sync::Arc;
4
5use futures::Stream;
6use synaptic_core::SynapseError;
7
8use crate::checkpoint::{Checkpoint, CheckpointConfig, Checkpointer};
9use crate::command::{GraphCommand, GraphContext};
10use crate::edge::{ConditionalEdge, Edge};
11use crate::node::Node;
12use crate::state::State;
13use crate::END;
14
/// Controls what is yielded during graph streaming.
///
/// NOTE(review): `CompiledGraph::stream_with_config` currently ignores this value
/// (its parameter is `_mode`), so both variants yield the same event shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamMode {
    /// Yield the full state after each node executes.
    Values,
    /// Intended to yield only the per-node delta.
    ///
    /// Not implemented yet: streaming currently yields the full post-node state,
    /// exactly as for [`StreamMode::Values`].
    Updates,
}
23
/// One step of a streamed graph run: which node just ran and the state it left behind.
#[derive(Clone, Debug)]
pub struct GraphEvent<S> {
    /// Name of the node whose execution produced this event.
    pub node: String,
    /// Snapshot of the whole graph state taken right after `node` finished.
    pub state: S,
}
32
33/// A stream of graph events.
34pub type GraphStream<'a, S> =
35    Pin<Box<dyn Stream<Item = Result<GraphEvent<S>, SynapseError>> + Send + 'a>>;
36
/// The compiled, executable graph.
///
/// Run it with `invoke`/`invoke_with_config` or `stream`/`stream_with_config`.
/// State is checkpointed only when both a checkpointer is set and the call
/// supplies a `CheckpointConfig`.
pub struct CompiledGraph<S: State> {
    /// Node implementations, keyed by node name.
    pub(crate) nodes: HashMap<String, Box<dyn Node<S>>>,
    /// Fixed `source -> target` edges; consulted only when no conditional edge
    /// leaves the current node.
    pub(crate) edges: Vec<Edge>,
    /// Router-driven edges; the first one whose source matches the current node
    /// decides the next node and takes precedence over fixed edges.
    pub(crate) conditional_edges: Vec<ConditionalEdge<S>>,
    /// Node executed first when a run is not resumed from a checkpoint.
    pub(crate) entry_point: String,
    /// Node names before which execution stops with an error (after saving a
    /// checkpoint, if checkpointing is enabled).
    pub(crate) interrupt_before: HashSet<String>,
    /// Node names after which execution stops with an error, unless the node
    /// issued a command through `command_context`.
    pub(crate) interrupt_after: HashSet<String>,
    /// Optional persistence backend used to save and restore run state.
    pub(crate) checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Shared context through which nodes issue `Goto`/`End` commands that
    /// override edge routing; checked after every node.
    pub(crate) command_context: GraphContext,
}
48
49impl<S: State> std::fmt::Debug for CompiledGraph<S> {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("CompiledGraph")
52            .field("entry_point", &self.entry_point)
53            .field("node_count", &self.nodes.len())
54            .field("edge_count", &self.edges.len())
55            .field("conditional_edge_count", &self.conditional_edges.len())
56            .finish()
57    }
58}
59
60impl<S: State> CompiledGraph<S> {
61    /// Set a checkpointer for state persistence.
62    pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
63        self.checkpointer = Some(checkpointer);
64        self
65    }
66
67    /// Get the `GraphContext` for this compiled graph.
68    ///
69    /// Nodes can use this context to issue dynamic control flow commands
70    /// (e.g., `goto` or `end`) that override normal edge-based routing.
71    pub fn context(&self) -> &GraphContext {
72        &self.command_context
73    }
74
75    /// Execute the graph with initial state.
76    pub async fn invoke(&self, state: S) -> Result<S, SynapseError>
77    where
78        S: serde::Serialize + serde::de::DeserializeOwned,
79    {
80        self.invoke_with_config(state, None).await
81    }
82
83    /// Execute with optional checkpoint config for resumption.
84    pub async fn invoke_with_config(
85        &self,
86        mut state: S,
87        config: Option<CheckpointConfig>,
88    ) -> Result<S, SynapseError>
89    where
90        S: serde::Serialize + serde::de::DeserializeOwned,
91    {
92        // If there's a checkpoint, try to resume from it
93        let mut resume_from: Option<String> = None;
94        if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
95            if let Some(checkpoint) = checkpointer.get(cfg).await? {
96                state = serde_json::from_value(checkpoint.state).map_err(|e| {
97                    SynapseError::Graph(format!("failed to deserialize checkpoint state: {e}"))
98                })?;
99                resume_from = checkpoint.next_node;
100            }
101        }
102
103        let mut current_node = resume_from.unwrap_or_else(|| self.entry_point.clone());
104        let mut max_iterations = 100; // safety guard
105
106        loop {
107            if current_node == END {
108                break;
109            }
110            if max_iterations == 0 {
111                return Err(SynapseError::Graph(
112                    "max iterations (100) exceeded — possible infinite loop".to_string(),
113                ));
114            }
115            max_iterations -= 1;
116
117            // Check interrupt_before
118            if self.interrupt_before.contains(&current_node) {
119                if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
120                    let checkpoint = Checkpoint {
121                        state: serde_json::to_value(&state)
122                            .map_err(|e| SynapseError::Graph(format!("serialize state: {e}")))?,
123                        next_node: Some(current_node.clone()),
124                    };
125                    checkpointer.put(cfg, &checkpoint).await?;
126                }
127                return Err(SynapseError::Graph(format!(
128                    "interrupted before node '{current_node}'"
129                )));
130            }
131
132            // Execute node
133            let node = self
134                .nodes
135                .get(&current_node)
136                .ok_or_else(|| SynapseError::Graph(format!("node '{current_node}' not found")))?;
137            state = node.process(state).await?;
138
139            // Check for command from GraphContext
140            let next = if let Some(cmd) = self.command_context.take_command().await {
141                match cmd {
142                    GraphCommand::Goto(target) => target,
143                    GraphCommand::End => END.to_string(),
144                }
145            } else {
146                // Check interrupt_after (only when no command override)
147                if self.interrupt_after.contains(&current_node) {
148                    // Find next node first so we can save it
149                    let next = self.find_next_node(&current_node, &state);
150                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
151                        let checkpoint = Checkpoint {
152                            state: serde_json::to_value(&state).map_err(|e| {
153                                SynapseError::Graph(format!("serialize state: {e}"))
154                            })?,
155                            next_node: Some(next),
156                        };
157                        checkpointer.put(cfg, &checkpoint).await?;
158                    }
159                    return Err(SynapseError::Graph(format!(
160                        "interrupted after node '{current_node}'"
161                    )));
162                }
163
164                // Find next node via normal edge routing
165                self.find_next_node(&current_node, &state)
166            };
167
168            // Save checkpoint after each node
169            if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
170                let checkpoint = Checkpoint {
171                    state: serde_json::to_value(&state)
172                        .map_err(|e| SynapseError::Graph(format!("serialize state: {e}")))?,
173                    next_node: Some(next.clone()),
174                };
175                checkpointer.put(cfg, &checkpoint).await?;
176            }
177
178            current_node = next;
179        }
180
181        Ok(state)
182    }
183
184    /// Stream graph execution, yielding a `GraphEvent` after each node.
185    pub fn stream(&self, state: S, mode: StreamMode) -> GraphStream<'_, S>
186    where
187        S: serde::Serialize + serde::de::DeserializeOwned + Clone,
188    {
189        self.stream_with_config(state, mode, None)
190    }
191
192    /// Stream graph execution with optional checkpoint config.
193    pub fn stream_with_config(
194        &self,
195        state: S,
196        _mode: StreamMode,
197        config: Option<CheckpointConfig>,
198    ) -> GraphStream<'_, S>
199    where
200        S: serde::Serialize + serde::de::DeserializeOwned + Clone,
201    {
202        Box::pin(async_stream::stream! {
203            let mut state = state;
204
205            // If there's a checkpoint, try to resume from it
206            let mut resume_from: Option<String> = None;
207            if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
208                match checkpointer.get(cfg).await {
209                    Ok(Some(checkpoint)) => {
210                        match serde_json::from_value(checkpoint.state) {
211                            Ok(s) => {
212                                state = s;
213                                resume_from = checkpoint.next_node;
214                            }
215                            Err(e) => {
216                                yield Err(SynapseError::Graph(format!(
217                                    "failed to deserialize checkpoint state: {e}"
218                                )));
219                                return;
220                            }
221                        }
222                    }
223                    Ok(None) => {}
224                    Err(e) => {
225                        yield Err(e);
226                        return;
227                    }
228                }
229            }
230
231            let mut current_node = resume_from.unwrap_or_else(|| self.entry_point.clone());
232            let mut max_iterations = 100;
233
234            loop {
235                if current_node == END {
236                    break;
237                }
238                if max_iterations == 0 {
239                    yield Err(SynapseError::Graph(
240                        "max iterations (100) exceeded — possible infinite loop".to_string(),
241                    ));
242                    return;
243                }
244                max_iterations -= 1;
245
246                // Check interrupt_before
247                if self.interrupt_before.contains(&current_node) {
248                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
249                        let ckpt_result = serde_json::to_value(&state)
250                            .map_err(|e| SynapseError::Graph(format!("serialize state: {e}")));
251                        match ckpt_result {
252                            Ok(state_val) => {
253                                let checkpoint = Checkpoint {
254                                    state: state_val,
255                                    next_node: Some(current_node.clone()),
256                                };
257                                if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
258                                    yield Err(e);
259                                    return;
260                                }
261                            }
262                            Err(e) => {
263                                yield Err(e);
264                                return;
265                            }
266                        }
267                    }
268                    yield Err(SynapseError::Graph(format!(
269                        "interrupted before node '{current_node}'"
270                    )));
271                    return;
272                }
273
274                // Execute node
275                let node = match self.nodes.get(&current_node) {
276                    Some(n) => n,
277                    None => {
278                        yield Err(SynapseError::Graph(format!("node '{current_node}' not found")));
279                        return;
280                    }
281                };
282
283                match node.process(state.clone()).await {
284                    Ok(new_state) => {
285                        state = new_state;
286                    }
287                    Err(e) => {
288                        yield Err(e);
289                        return;
290                    }
291                }
292
293                // Yield event
294                let event = GraphEvent {
295                    node: current_node.clone(),
296                    state: state.clone(),
297                };
298                yield Ok(event);
299
300                // Check for command from GraphContext
301                let next = if let Some(cmd) = self.command_context.take_command().await {
302                    match cmd {
303                        GraphCommand::Goto(target) => target,
304                        GraphCommand::End => END.to_string(),
305                    }
306                } else {
307                    // Check interrupt_after (only when no command override)
308                    if self.interrupt_after.contains(&current_node) {
309                        let next = self.find_next_node(&current_node, &state);
310                        if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
311                            let ckpt_result = serde_json::to_value(&state)
312                                .map_err(|e| SynapseError::Graph(format!("serialize state: {e}")));
313                            match ckpt_result {
314                                Ok(state_val) => {
315                                    let checkpoint = Checkpoint {
316                                        state: state_val,
317                                        next_node: Some(next),
318                                    };
319                                    if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
320                                        yield Err(e);
321                                        return;
322                                    }
323                                }
324                                Err(e) => {
325                                    yield Err(e);
326                                    return;
327                                }
328                            }
329                        }
330                        yield Err(SynapseError::Graph(format!(
331                            "interrupted after node '{current_node}'"
332                        )));
333                        return;
334                    }
335
336                    // Find next node via normal edge routing
337                    self.find_next_node(&current_node, &state)
338                };
339
340                // Save checkpoint
341                if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
342                    let ckpt_result = serde_json::to_value(&state)
343                        .map_err(|e| SynapseError::Graph(format!("serialize state: {e}")));
344                    match ckpt_result {
345                        Ok(state_val) => {
346                            let checkpoint = Checkpoint {
347                                state: state_val,
348                                next_node: Some(next.clone()),
349                            };
350                            if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
351                                yield Err(e);
352                                return;
353                            }
354                        }
355                        Err(e) => {
356                            yield Err(e);
357                            return;
358                        }
359                    }
360                }
361
362                current_node = next;
363            }
364        })
365    }
366
367    /// Update state on an interrupted graph (for human-in-the-loop).
368    pub async fn update_state(
369        &self,
370        config: &CheckpointConfig,
371        update: S,
372    ) -> Result<(), SynapseError>
373    where
374        S: serde::Serialize + serde::de::DeserializeOwned,
375    {
376        let checkpointer = self
377            .checkpointer
378            .as_ref()
379            .ok_or_else(|| SynapseError::Graph("no checkpointer configured".to_string()))?;
380
381        let checkpoint = checkpointer
382            .get(config)
383            .await?
384            .ok_or_else(|| SynapseError::Graph("no checkpoint found".to_string()))?;
385
386        let mut current_state: S = serde_json::from_value(checkpoint.state)
387            .map_err(|e| SynapseError::Graph(format!("deserialize: {e}")))?;
388
389        current_state.merge(update);
390
391        let updated = Checkpoint {
392            state: serde_json::to_value(&current_state)
393                .map_err(|e| SynapseError::Graph(format!("serialize: {e}")))?,
394            next_node: checkpoint.next_node,
395        };
396        checkpointer.put(config, &updated).await?;
397
398        Ok(())
399    }
400
401    /// Get the current state for a thread from the checkpointer.
402    ///
403    /// Returns `None` if no checkpoint exists for the given thread.
404    pub async fn get_state(&self, config: &CheckpointConfig) -> Result<Option<S>, SynapseError>
405    where
406        S: serde::de::DeserializeOwned,
407    {
408        let checkpointer = self
409            .checkpointer
410            .as_ref()
411            .ok_or_else(|| SynapseError::Graph("no checkpointer configured".to_string()))?;
412
413        match checkpointer.get(config).await? {
414            Some(checkpoint) => {
415                let state: S = serde_json::from_value(checkpoint.state).map_err(|e| {
416                    SynapseError::Graph(format!("failed to deserialize checkpoint state: {e}"))
417                })?;
418                Ok(Some(state))
419            }
420            None => Ok(None),
421        }
422    }
423
424    /// Get the state history for a thread (all checkpoints).
425    ///
426    /// Returns a list of `(state, next_node)` pairs, ordered from oldest to newest.
427    /// The `next_node` indicates which node was scheduled to execute next when
428    /// the checkpoint was saved.
429    pub async fn get_state_history(
430        &self,
431        config: &CheckpointConfig,
432    ) -> Result<Vec<(S, Option<String>)>, SynapseError>
433    where
434        S: serde::de::DeserializeOwned,
435    {
436        let checkpointer = self
437            .checkpointer
438            .as_ref()
439            .ok_or_else(|| SynapseError::Graph("no checkpointer configured".to_string()))?;
440
441        let checkpoints = checkpointer.list(config).await?;
442        let mut history = Vec::with_capacity(checkpoints.len());
443
444        for checkpoint in checkpoints {
445            let state: S = serde_json::from_value(checkpoint.state).map_err(|e| {
446                SynapseError::Graph(format!("failed to deserialize checkpoint state: {e}"))
447            })?;
448            history.push((state, checkpoint.next_node));
449        }
450
451        Ok(history)
452    }
453
454    fn find_next_node(&self, current: &str, state: &S) -> String {
455        // Check conditional edges first
456        for ce in &self.conditional_edges {
457            if ce.source == current {
458                return (ce.router)(state);
459            }
460        }
461
462        // Check fixed edges
463        for edge in &self.edges {
464            if edge.source == current {
465                return edge.target.clone();
466            }
467        }
468
469        // No outgoing edge means END
470        END.to_string()
471    }
472}