1use std::collections::{HashMap, HashSet};
2use std::pin::Pin;
3use std::sync::Arc;
4
5use futures::Stream;
6use synaptic_core::SynapseError;
7
8use crate::checkpoint::{Checkpoint, CheckpointConfig, Checkpointer};
9use crate::command::{GraphCommand, GraphContext};
10use crate::edge::{ConditionalEdge, Edge};
11use crate::node::Node;
12use crate::state::State;
13use crate::END;
14
/// Selects what a graph stream is meant to emit for each executed node.
///
/// NOTE(review): `CompiledGraph::stream_with_config` currently ignores this value
/// (its parameter is `_mode`), so both variants produce identical streams today.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamMode {
    /// Intended to emit the full state after each node.
    Values,
    /// Intended to emit only what each node changed.
    Updates,
}
23
/// A single item of a graph stream: which node just ran and the state it left behind.
#[derive(Clone, Debug)]
pub struct GraphEvent<S> {
    /// Name of the node whose execution produced this event.
    pub node: String,
    /// Graph state as returned by `node`.
    pub state: S,
}
32
33pub type GraphStream<'a, S> =
35 Pin<Box<dyn Stream<Item = Result<GraphEvent<S>, SynapseError>> + Send + 'a>>;
36
/// An executable graph: nodes, routing edges, interrupt points and optional
/// checkpoint persistence.
pub struct CompiledGraph<S: State> {
    /// Node implementations, looked up by name when execution reaches them.
    pub(crate) nodes: HashMap<String, Box<dyn Node<S>>>,
    /// Static edges; the first one whose source matches is followed when no
    /// conditional edge applies.
    pub(crate) edges: Vec<Edge>,
    /// Router-based edges; these take precedence over static edges.
    pub(crate) conditional_edges: Vec<ConditionalEdge<S>>,
    /// Node where a run starts when it is not resuming from a checkpoint.
    pub(crate) entry_point: String,
    /// Nodes before which a run stops (returning an error) instead of executing them.
    pub(crate) interrupt_before: HashSet<String>,
    /// Nodes after which a run stops (returning an error) once they have executed.
    pub(crate) interrupt_after: HashSet<String>,
    /// Optional persistence; only used by runs that also supply a `CheckpointConfig`.
    pub(crate) checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Command channel drained after every node; a pending `GraphCommand`
    /// overrides edge routing for that step.
    pub(crate) command_context: GraphContext,
}
48
49impl<S: State> std::fmt::Debug for CompiledGraph<S> {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("CompiledGraph")
52 .field("entry_point", &self.entry_point)
53 .field("node_count", &self.nodes.len())
54 .field("edge_count", &self.edges.len())
55 .field("conditional_edge_count", &self.conditional_edges.len())
56 .finish()
57 }
58}
59
60impl<S: State> CompiledGraph<S> {
61 pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
63 self.checkpointer = Some(checkpointer);
64 self
65 }
66
67 pub fn context(&self) -> &GraphContext {
72 &self.command_context
73 }
74
75 pub async fn invoke(&self, state: S) -> Result<S, SynapseError>
77 where
78 S: serde::Serialize + serde::de::DeserializeOwned,
79 {
80 self.invoke_with_config(state, None).await
81 }
82
83 pub async fn invoke_with_config(
85 &self,
86 mut state: S,
87 config: Option<CheckpointConfig>,
88 ) -> Result<S, SynapseError>
89 where
90 S: serde::Serialize + serde::de::DeserializeOwned,
91 {
92 let mut resume_from: Option<String> = None;
94 if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
95 if let Some(checkpoint) = checkpointer.get(cfg).await? {
96 state = serde_json::from_value(checkpoint.state).map_err(|e| {
97 SynapseError::Graph(format!("failed to deserialize checkpoint state: {e}"))
98 })?;
99 resume_from = checkpoint.next_node;
100 }
101 }
102
103 let mut current_node = resume_from.unwrap_or_else(|| self.entry_point.clone());
104 let mut max_iterations = 100; loop {
107 if current_node == END {
108 break;
109 }
110 if max_iterations == 0 {
111 return Err(SynapseError::Graph(
112 "max iterations (100) exceeded — possible infinite loop".to_string(),
113 ));
114 }
115 max_iterations -= 1;
116
117 if self.interrupt_before.contains(¤t_node) {
119 if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
120 let checkpoint = Checkpoint {
121 state: serde_json::to_value(&state)
122 .map_err(|e| SynapseError::Graph(format!("serialize state: {e}")))?,
123 next_node: Some(current_node.clone()),
124 };
125 checkpointer.put(cfg, &checkpoint).await?;
126 }
127 return Err(SynapseError::Graph(format!(
128 "interrupted before node '{current_node}'"
129 )));
130 }
131
132 let node = self
134 .nodes
135 .get(¤t_node)
136 .ok_or_else(|| SynapseError::Graph(format!("node '{current_node}' not found")))?;
137 state = node.process(state).await?;
138
139 let next = if let Some(cmd) = self.command_context.take_command().await {
141 match cmd {
142 GraphCommand::Goto(target) => target,
143 GraphCommand::End => END.to_string(),
144 }
145 } else {
146 if self.interrupt_after.contains(¤t_node) {
148 let next = self.find_next_node(¤t_node, &state);
150 if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
151 let checkpoint = Checkpoint {
152 state: serde_json::to_value(&state).map_err(|e| {
153 SynapseError::Graph(format!("serialize state: {e}"))
154 })?,
155 next_node: Some(next),
156 };
157 checkpointer.put(cfg, &checkpoint).await?;
158 }
159 return Err(SynapseError::Graph(format!(
160 "interrupted after node '{current_node}'"
161 )));
162 }
163
164 self.find_next_node(¤t_node, &state)
166 };
167
168 if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
170 let checkpoint = Checkpoint {
171 state: serde_json::to_value(&state)
172 .map_err(|e| SynapseError::Graph(format!("serialize state: {e}")))?,
173 next_node: Some(next.clone()),
174 };
175 checkpointer.put(cfg, &checkpoint).await?;
176 }
177
178 current_node = next;
179 }
180
181 Ok(state)
182 }
183
184 pub fn stream(&self, state: S, mode: StreamMode) -> GraphStream<'_, S>
186 where
187 S: serde::Serialize + serde::de::DeserializeOwned + Clone,
188 {
189 self.stream_with_config(state, mode, None)
190 }
191
192 pub fn stream_with_config(
194 &self,
195 state: S,
196 _mode: StreamMode,
197 config: Option<CheckpointConfig>,
198 ) -> GraphStream<'_, S>
199 where
200 S: serde::Serialize + serde::de::DeserializeOwned + Clone,
201 {
202 Box::pin(async_stream::stream! {
203 let mut state = state;
204
205 let mut resume_from: Option<String> = None;
207 if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
208 match checkpointer.get(cfg).await {
209 Ok(Some(checkpoint)) => {
210 match serde_json::from_value(checkpoint.state) {
211 Ok(s) => {
212 state = s;
213 resume_from = checkpoint.next_node;
214 }
215 Err(e) => {
216 yield Err(SynapseError::Graph(format!(
217 "failed to deserialize checkpoint state: {e}"
218 )));
219 return;
220 }
221 }
222 }
223 Ok(None) => {}
224 Err(e) => {
225 yield Err(e);
226 return;
227 }
228 }
229 }
230
231 let mut current_node = resume_from.unwrap_or_else(|| self.entry_point.clone());
232 let mut max_iterations = 100;
233
234 loop {
235 if current_node == END {
236 break;
237 }
238 if max_iterations == 0 {
239 yield Err(SynapseError::Graph(
240 "max iterations (100) exceeded — possible infinite loop".to_string(),
241 ));
242 return;
243 }
244 max_iterations -= 1;
245
246 if self.interrupt_before.contains(¤t_node) {
248 if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
249 let ckpt_result = serde_json::to_value(&state)
250 .map_err(|e| SynapseError::Graph(format!("serialize state: {e}")));
251 match ckpt_result {
252 Ok(state_val) => {
253 let checkpoint = Checkpoint {
254 state: state_val,
255 next_node: Some(current_node.clone()),
256 };
257 if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
258 yield Err(e);
259 return;
260 }
261 }
262 Err(e) => {
263 yield Err(e);
264 return;
265 }
266 }
267 }
268 yield Err(SynapseError::Graph(format!(
269 "interrupted before node '{current_node}'"
270 )));
271 return;
272 }
273
274 let node = match self.nodes.get(¤t_node) {
276 Some(n) => n,
277 None => {
278 yield Err(SynapseError::Graph(format!("node '{current_node}' not found")));
279 return;
280 }
281 };
282
283 match node.process(state.clone()).await {
284 Ok(new_state) => {
285 state = new_state;
286 }
287 Err(e) => {
288 yield Err(e);
289 return;
290 }
291 }
292
293 let event = GraphEvent {
295 node: current_node.clone(),
296 state: state.clone(),
297 };
298 yield Ok(event);
299
300 let next = if let Some(cmd) = self.command_context.take_command().await {
302 match cmd {
303 GraphCommand::Goto(target) => target,
304 GraphCommand::End => END.to_string(),
305 }
306 } else {
307 if self.interrupt_after.contains(¤t_node) {
309 let next = self.find_next_node(¤t_node, &state);
310 if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
311 let ckpt_result = serde_json::to_value(&state)
312 .map_err(|e| SynapseError::Graph(format!("serialize state: {e}")));
313 match ckpt_result {
314 Ok(state_val) => {
315 let checkpoint = Checkpoint {
316 state: state_val,
317 next_node: Some(next),
318 };
319 if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
320 yield Err(e);
321 return;
322 }
323 }
324 Err(e) => {
325 yield Err(e);
326 return;
327 }
328 }
329 }
330 yield Err(SynapseError::Graph(format!(
331 "interrupted after node '{current_node}'"
332 )));
333 return;
334 }
335
336 self.find_next_node(¤t_node, &state)
338 };
339
340 if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
342 let ckpt_result = serde_json::to_value(&state)
343 .map_err(|e| SynapseError::Graph(format!("serialize state: {e}")));
344 match ckpt_result {
345 Ok(state_val) => {
346 let checkpoint = Checkpoint {
347 state: state_val,
348 next_node: Some(next.clone()),
349 };
350 if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
351 yield Err(e);
352 return;
353 }
354 }
355 Err(e) => {
356 yield Err(e);
357 return;
358 }
359 }
360 }
361
362 current_node = next;
363 }
364 })
365 }
366
367 pub async fn update_state(
369 &self,
370 config: &CheckpointConfig,
371 update: S,
372 ) -> Result<(), SynapseError>
373 where
374 S: serde::Serialize + serde::de::DeserializeOwned,
375 {
376 let checkpointer = self
377 .checkpointer
378 .as_ref()
379 .ok_or_else(|| SynapseError::Graph("no checkpointer configured".to_string()))?;
380
381 let checkpoint = checkpointer
382 .get(config)
383 .await?
384 .ok_or_else(|| SynapseError::Graph("no checkpoint found".to_string()))?;
385
386 let mut current_state: S = serde_json::from_value(checkpoint.state)
387 .map_err(|e| SynapseError::Graph(format!("deserialize: {e}")))?;
388
389 current_state.merge(update);
390
391 let updated = Checkpoint {
392 state: serde_json::to_value(¤t_state)
393 .map_err(|e| SynapseError::Graph(format!("serialize: {e}")))?,
394 next_node: checkpoint.next_node,
395 };
396 checkpointer.put(config, &updated).await?;
397
398 Ok(())
399 }
400
401 pub async fn get_state(&self, config: &CheckpointConfig) -> Result<Option<S>, SynapseError>
405 where
406 S: serde::de::DeserializeOwned,
407 {
408 let checkpointer = self
409 .checkpointer
410 .as_ref()
411 .ok_or_else(|| SynapseError::Graph("no checkpointer configured".to_string()))?;
412
413 match checkpointer.get(config).await? {
414 Some(checkpoint) => {
415 let state: S = serde_json::from_value(checkpoint.state).map_err(|e| {
416 SynapseError::Graph(format!("failed to deserialize checkpoint state: {e}"))
417 })?;
418 Ok(Some(state))
419 }
420 None => Ok(None),
421 }
422 }
423
424 pub async fn get_state_history(
430 &self,
431 config: &CheckpointConfig,
432 ) -> Result<Vec<(S, Option<String>)>, SynapseError>
433 where
434 S: serde::de::DeserializeOwned,
435 {
436 let checkpointer = self
437 .checkpointer
438 .as_ref()
439 .ok_or_else(|| SynapseError::Graph("no checkpointer configured".to_string()))?;
440
441 let checkpoints = checkpointer.list(config).await?;
442 let mut history = Vec::with_capacity(checkpoints.len());
443
444 for checkpoint in checkpoints {
445 let state: S = serde_json::from_value(checkpoint.state).map_err(|e| {
446 SynapseError::Graph(format!("failed to deserialize checkpoint state: {e}"))
447 })?;
448 history.push((state, checkpoint.next_node));
449 }
450
451 Ok(history)
452 }
453
454 fn find_next_node(&self, current: &str, state: &S) -> String {
455 for ce in &self.conditional_edges {
457 if ce.source == current {
458 return (ce.router)(state);
459 }
460 }
461
462 for edge in &self.edges {
464 if edge.source == current {
465 return edge.target.clone();
466 }
467 }
468
469 END.to_string()
471 }
472}