Skip to main content

synaptic_graph/
builder.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use synaptic_core::SynapticError;
5use tokio::sync::RwLock;
6
7use crate::compiled::{CachePolicy, CompiledGraph};
8use crate::edge::{ConditionalEdge, Edge};
9use crate::node::Node;
10use crate::state::State;
11use crate::{END, START};
12
13/// Builder for constructing a state graph.
14pub struct StateGraph<S: State> {
15    nodes: HashMap<String, Box<dyn Node<S>>>,
16    edges: Vec<Edge>,
17    conditional_edges: Vec<ConditionalEdge<S>>,
18    entry_point: Option<String>,
19    interrupt_before: HashSet<String>,
20    interrupt_after: HashSet<String>,
21    cache_policies: HashMap<String, CachePolicy>,
22    deferred: HashSet<String>,
23}
24
25impl<S: State> StateGraph<S> {
26    pub fn new() -> Self {
27        Self {
28            nodes: HashMap::new(),
29            edges: Vec::new(),
30            conditional_edges: Vec::new(),
31            entry_point: None,
32            interrupt_before: HashSet::new(),
33            interrupt_after: HashSet::new(),
34            cache_policies: HashMap::new(),
35            deferred: HashSet::new(),
36        }
37    }
38
39    /// Add a named node to the graph.
40    pub fn add_node(mut self, name: impl Into<String>, node: impl Node<S> + 'static) -> Self {
41        self.nodes.insert(name.into(), Box::new(node));
42        self
43    }
44
45    /// Add a deferred node that waits until ALL incoming edges have been
46    /// traversed before executing. Useful for fan-in aggregation after
47    /// parallel fan-out with [`Send`](crate::Send).
48    pub fn add_deferred_node(
49        mut self,
50        name: impl Into<String>,
51        node: impl Node<S> + 'static,
52    ) -> Self {
53        let n = name.into();
54        self.nodes.insert(n.clone(), Box::new(node));
55        self.deferred.insert(n);
56        self
57    }
58
59    /// Add a named node with caching. Results are cached based on
60    /// a hash of the serialized input state for the duration of the TTL.
61    pub fn add_node_with_cache(
62        mut self,
63        name: impl Into<String>,
64        node: impl Node<S> + 'static,
65        cache: CachePolicy,
66    ) -> Self {
67        let n = name.into();
68        self.nodes.insert(n.clone(), Box::new(node));
69        self.cache_policies.insert(n, cache);
70        self
71    }
72
73    /// Add a fixed edge from source to target.
74    pub fn add_edge(mut self, source: impl Into<String>, target: impl Into<String>) -> Self {
75        self.edges.push(Edge {
76            source: source.into(),
77            target: target.into(),
78        });
79        self
80    }
81
82    /// Add a conditional edge with a routing function.
83    pub fn add_conditional_edges(
84        mut self,
85        source: impl Into<String>,
86        router: impl Fn(&S) -> String + Send + Sync + 'static,
87    ) -> Self {
88        self.conditional_edges.push(ConditionalEdge {
89            source: source.into(),
90            router: Arc::new(router),
91            path_map: None,
92        });
93        self
94    }
95
96    /// Add a conditional edge with a routing function and a path map for visualization.
97    pub fn add_conditional_edges_with_path_map(
98        mut self,
99        source: impl Into<String>,
100        router: impl Fn(&S) -> String + Send + Sync + 'static,
101        path_map: HashMap<String, String>,
102    ) -> Self {
103        self.conditional_edges.push(ConditionalEdge {
104            source: source.into(),
105            router: Arc::new(router),
106            path_map: Some(path_map),
107        });
108        self
109    }
110
111    /// Set the entry point node for graph execution.
112    pub fn set_entry_point(mut self, name: impl Into<String>) -> Self {
113        self.entry_point = Some(name.into());
114        self
115    }
116
117    /// Mark nodes that should interrupt BEFORE execution (human-in-the-loop).
118    pub fn interrupt_before(mut self, nodes: Vec<String>) -> Self {
119        self.interrupt_before.extend(nodes);
120        self
121    }
122
123    /// Mark nodes that should interrupt AFTER execution (human-in-the-loop).
124    pub fn interrupt_after(mut self, nodes: Vec<String>) -> Self {
125        self.interrupt_after.extend(nodes);
126        self
127    }
128
129    /// Compile the graph into an executable CompiledGraph.
130    pub fn compile(self) -> Result<CompiledGraph<S>, SynapticError> {
131        let entry = self
132            .entry_point
133            .ok_or_else(|| SynapticError::Graph("no entry point set".to_string()))?;
134
135        if !self.nodes.contains_key(&entry) {
136            return Err(SynapticError::Graph(format!(
137                "entry point node '{entry}' not found"
138            )));
139        }
140
141        // Validate: every edge references existing nodes or END
142        for edge in &self.edges {
143            if edge.source != START && !self.nodes.contains_key(&edge.source) {
144                return Err(SynapticError::Graph(format!(
145                    "edge source '{}' not found",
146                    edge.source
147                )));
148            }
149            if edge.target != END && !self.nodes.contains_key(&edge.target) {
150                return Err(SynapticError::Graph(format!(
151                    "edge target '{}' not found",
152                    edge.target
153                )));
154            }
155        }
156
157        for ce in &self.conditional_edges {
158            if ce.source != START && !self.nodes.contains_key(&ce.source) {
159                return Err(SynapticError::Graph(format!(
160                    "conditional edge source '{}' not found",
161                    ce.source
162                )));
163            }
164            // Validate path_map targets reference existing nodes or END
165            if let Some(ref path_map) = ce.path_map {
166                for (label, target) in path_map {
167                    if target != END && !self.nodes.contains_key(target) {
168                        return Err(SynapticError::Graph(format!(
169                            "conditional edge path_map target '{target}' (label '{label}') not found"
170                        )));
171                    }
172                }
173            }
174        }
175
176        Ok(CompiledGraph {
177            nodes: self.nodes,
178            edges: self.edges,
179            conditional_edges: self.conditional_edges,
180            entry_point: entry,
181            interrupt_before: self.interrupt_before,
182            interrupt_after: self.interrupt_after,
183            checkpointer: None,
184            cache_policies: self.cache_policies,
185            cache: Arc::new(RwLock::new(HashMap::new())),
186            deferred: self.deferred,
187        })
188    }
189}
190
191impl<S: State> Default for StateGraph<S> {
192    fn default() -> Self {
193        Self::new()
194    }
195}