// synaptic_graph/builder.rs
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use synaptic_core::SynapticError;
use tokio::sync::RwLock;

use crate::compiled::{CachePolicy, CompiledGraph};
use crate::edge::{ConditionalEdge, Edge};
use crate::node::Node;
use crate::state::State;
use crate::{END, START};
12
13pub struct StateGraph<S: State> {
15 nodes: HashMap<String, Box<dyn Node<S>>>,
16 edges: Vec<Edge>,
17 conditional_edges: Vec<ConditionalEdge<S>>,
18 entry_point: Option<String>,
19 interrupt_before: HashSet<String>,
20 interrupt_after: HashSet<String>,
21 cache_policies: HashMap<String, CachePolicy>,
22 deferred: HashSet<String>,
23}
24
25impl<S: State> StateGraph<S> {
26 pub fn new() -> Self {
27 Self {
28 nodes: HashMap::new(),
29 edges: Vec::new(),
30 conditional_edges: Vec::new(),
31 entry_point: None,
32 interrupt_before: HashSet::new(),
33 interrupt_after: HashSet::new(),
34 cache_policies: HashMap::new(),
35 deferred: HashSet::new(),
36 }
37 }
38
39 pub fn add_node(mut self, name: impl Into<String>, node: impl Node<S> + 'static) -> Self {
41 self.nodes.insert(name.into(), Box::new(node));
42 self
43 }
44
45 pub fn add_deferred_node(
49 mut self,
50 name: impl Into<String>,
51 node: impl Node<S> + 'static,
52 ) -> Self {
53 let n = name.into();
54 self.nodes.insert(n.clone(), Box::new(node));
55 self.deferred.insert(n);
56 self
57 }
58
59 pub fn add_node_with_cache(
62 mut self,
63 name: impl Into<String>,
64 node: impl Node<S> + 'static,
65 cache: CachePolicy,
66 ) -> Self {
67 let n = name.into();
68 self.nodes.insert(n.clone(), Box::new(node));
69 self.cache_policies.insert(n, cache);
70 self
71 }
72
73 pub fn add_edge(mut self, source: impl Into<String>, target: impl Into<String>) -> Self {
75 self.edges.push(Edge {
76 source: source.into(),
77 target: target.into(),
78 });
79 self
80 }
81
82 pub fn add_conditional_edges(
84 mut self,
85 source: impl Into<String>,
86 router: impl Fn(&S) -> String + Send + Sync + 'static,
87 ) -> Self {
88 self.conditional_edges.push(ConditionalEdge {
89 source: source.into(),
90 router: Arc::new(router),
91 path_map: None,
92 });
93 self
94 }
95
96 pub fn add_conditional_edges_with_path_map(
98 mut self,
99 source: impl Into<String>,
100 router: impl Fn(&S) -> String + Send + Sync + 'static,
101 path_map: HashMap<String, String>,
102 ) -> Self {
103 self.conditional_edges.push(ConditionalEdge {
104 source: source.into(),
105 router: Arc::new(router),
106 path_map: Some(path_map),
107 });
108 self
109 }
110
111 pub fn set_entry_point(mut self, name: impl Into<String>) -> Self {
113 self.entry_point = Some(name.into());
114 self
115 }
116
117 pub fn interrupt_before(mut self, nodes: Vec<String>) -> Self {
119 self.interrupt_before.extend(nodes);
120 self
121 }
122
123 pub fn interrupt_after(mut self, nodes: Vec<String>) -> Self {
125 self.interrupt_after.extend(nodes);
126 self
127 }
128
129 pub fn compile(self) -> Result<CompiledGraph<S>, SynapticError> {
131 let entry = self
132 .entry_point
133 .ok_or_else(|| SynapticError::Graph("no entry point set".to_string()))?;
134
135 if !self.nodes.contains_key(&entry) {
136 return Err(SynapticError::Graph(format!(
137 "entry point node '{entry}' not found"
138 )));
139 }
140
141 for edge in &self.edges {
143 if edge.source != START && !self.nodes.contains_key(&edge.source) {
144 return Err(SynapticError::Graph(format!(
145 "edge source '{}' not found",
146 edge.source
147 )));
148 }
149 if edge.target != END && !self.nodes.contains_key(&edge.target) {
150 return Err(SynapticError::Graph(format!(
151 "edge target '{}' not found",
152 edge.target
153 )));
154 }
155 }
156
157 for ce in &self.conditional_edges {
158 if ce.source != START && !self.nodes.contains_key(&ce.source) {
159 return Err(SynapticError::Graph(format!(
160 "conditional edge source '{}' not found",
161 ce.source
162 )));
163 }
164 if let Some(ref path_map) = ce.path_map {
166 for (label, target) in path_map {
167 if target != END && !self.nodes.contains_key(target) {
168 return Err(SynapticError::Graph(format!(
169 "conditional edge path_map target '{target}' (label '{label}') not found"
170 )));
171 }
172 }
173 }
174 }
175
176 Ok(CompiledGraph {
177 nodes: self.nodes,
178 edges: self.edges,
179 conditional_edges: self.conditional_edges,
180 entry_point: entry,
181 interrupt_before: self.interrupt_before,
182 interrupt_after: self.interrupt_after,
183 checkpointer: None,
184 cache_policies: self.cache_policies,
185 cache: Arc::new(RwLock::new(HashMap::new())),
186 deferred: self.deferred,
187 })
188 }
189}
190
191impl<S: State> Default for StateGraph<S> {
192 fn default() -> Self {
193 Self::new()
194 }
195}