1
2use serde::{Serialize, Deserialize};
3use std::collections::HashMap;
4
5use crate::node::NodeType;
6
/// A directed connection between two nodes of a [`GraphDef`].
///
/// `from` and `to` are node ids; [`GraphDef::validate`] rejects an edge whose
/// ids are not keys of `GraphDef::nodes`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Edge {
    /// Id of the source node.
    pub from: String,
    /// Id of the target node.
    pub to: String,
    /// Optional rule text attached via [`Edge::with_rule`]. Its interpretation
    /// happens elsewhere (presumably a traversal guard — TODO confirm).
    pub rule: Option<String>,
}
13
14impl Edge {
15 pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
16 Self {
17 from: from.into(),
18 to: to.into(),
19 rule: None,
20 }
21 }
22
23 pub fn with_rule(mut self, rule: impl Into<String>) -> Self {
24 self.rule = Some(rule.into());
25 self
26 }
27}
28
/// Per-node configuration stored in `GraphDef::nodes`.
///
/// Only `node_type` is required when deserializing; every other field is
/// `#[serde(default)]` and becomes `None` when absent. Which optional field is
/// populated depends on the node type — see the constructors in `impl NodeConfig`.
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct NodeConfig {
    /// Kind of node this configuration describes.
    pub node_type: NodeType,
    /// Condition expression; set by `NodeConfig::rule_node`.
    #[serde(default)]
    pub condition: Option<String>,
    /// Query text; set by `NodeConfig::db_node`. `NodeConfig::grpc_node` also
    /// stores its target here, packed as `"<service_url>#<method>"`.
    #[serde(default)]
    pub query: Option<String>,
    /// Prompt text; set by `NodeConfig::ai_node`.
    #[serde(default)]
    pub prompt: Option<String>,
    /// Query parameters; set by `NodeConfig::db_node_with_params`.
    #[serde(default)]
    pub params: Option<Vec<String>>,
}
44
45impl NodeConfig {
46 pub fn rule_node(condition: impl Into<String>) -> Self {
47 Self {
48 node_type: NodeType::RuleNode,
49 condition: Some(condition.into()),
50 query: None,
51 prompt: None,
52 params: None,
53 }
54 }
55
56 pub fn db_node(query: impl Into<String>) -> Self {
57 Self {
58 node_type: NodeType::DBNode,
59 condition: None,
60 query: Some(query.into()),
61 prompt: None,
62 params: None,
63 }
64 }
65
66 pub fn db_node_with_params(query: impl Into<String>, params: Vec<String>) -> Self {
68 Self {
69 node_type: NodeType::DBNode,
70 condition: None,
71 query: Some(query.into()),
72 prompt: None,
73 params: Some(params),
74 }
75 }
76
77 pub fn ai_node(prompt: impl Into<String>) -> Self {
78 Self {
79 node_type: NodeType::AINode,
80 condition: None,
81 query: None,
82 prompt: Some(prompt.into()),
83 params: None,
84 }
85 }
86
87 pub fn grpc_node(service_url: impl Into<String>, method: impl Into<String>) -> Self {
89 Self {
90 node_type: NodeType::GrpcNode,
91 query: Some(format!("{}#{}", service_url.into(), method.into())),
92 condition: None,
93 prompt: None,
94 params: None,
95 }
96 }
97}
98
/// Serializable description of a graph: node configs keyed by id plus the
/// edges between them.
///
/// The derived `Deserialize` does not check referential integrity; call
/// [`GraphDef::validate`] to ensure every edge endpoint names an existing node.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GraphDef {
    /// Node configurations keyed by node id.
    pub nodes: HashMap<String, NodeConfig>,
    /// Edges whose `from` / `to` are expected to be keys of `nodes`.
    pub edges: Vec<Edge>,
}
104
105impl GraphDef {
106 pub fn from_node_types(
108 nodes: HashMap<String, NodeType>,
109 edges: Vec<Edge>,
110 ) -> Self {
111 let nodes = nodes
112 .into_iter()
113 .map(|(id, node_type)| {
114 let config = match node_type {
115 NodeType::RuleNode => NodeConfig::rule_node("true"),
116 NodeType::DBNode => NodeConfig::db_node(format!("SELECT * FROM {}", id)),
117 NodeType::AINode => NodeConfig::ai_node(format!("Process data for {}", id)),
118 NodeType::GrpcNode => NodeConfig::grpc_node(
119 format!("http://localhost:50051"),
120 format!("{}_method", id)
121 ),
122 NodeType::SubgraphNode => NodeConfig::rule_node("true"), NodeType::ConditionalNode => NodeConfig::rule_node("true"),
124 NodeType::LoopNode => NodeConfig::rule_node("true"),
125 NodeType::TryCatchNode => NodeConfig::rule_node("true"),
126 NodeType::RetryNode => NodeConfig::rule_node("true"),
127 NodeType::CircuitBreakerNode => NodeConfig::rule_node("true"),
128 };
129 (id, config)
130 })
131 .collect();
132
133 Self { nodes, edges }
134 }
135
136 pub fn validate(&self) -> anyhow::Result<()> {
138 if self.nodes.is_empty() {
140 return Err(anyhow::anyhow!("Graph has no nodes"));
141 }
142
143 for edge in &self.edges {
145 if !self.nodes.contains_key(&edge.from) {
146 return Err(anyhow::anyhow!(
147 "Edge references non-existent source node: '{}'",
148 edge.from
149 ));
150 }
151 if !self.nodes.contains_key(&edge.to) {
152 return Err(anyhow::anyhow!(
153 "Edge references non-existent target node: '{}'",
154 edge.to
155 ));
156 }
157 }
158
159 Ok(())
160 }
161
162 pub fn has_disconnected_components(&self) -> bool {
164 if self.nodes.is_empty() {
165 return false;
166 }
167
168 use std::collections::HashSet;
169 let mut visited = HashSet::new();
170 let mut stack = Vec::new();
171
172 if let Some(first_node) = self.nodes.keys().next() {
174 stack.push(first_node.clone());
175 }
176
177 while let Some(node) = stack.pop() {
179 if visited.contains(&node) {
180 continue;
181 }
182 visited.insert(node.clone());
183
184 for edge in &self.edges {
186 if edge.from == node && !visited.contains(&edge.to) {
187 stack.push(edge.to.clone());
188 }
189 if edge.to == node && !visited.contains(&edge.from) {
190 stack.push(edge.from.clone());
191 }
192 }
193 }
194
195 visited.len() < self.nodes.len()
196 }
197}
198
199#[derive(Default)]
200pub struct Context {
201 pub data: HashMap<String, serde_json::Value>,
202}
203
204impl Context {
205 pub fn new() -> Self {
206 Self {
207 data: HashMap::new(),
208 }
209 }
210
211 pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) {
213 self.data.insert(key.into(), value);
214 }
215
216 pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
217 self.data.get(key)
218 }
219
220 pub fn contains_key(&self, key: &str) -> bool {
222 self.data.contains_key(key)
223 }
224
225 pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
227 self.data.remove(key)
228 }
229
230 pub fn clear(&mut self) {
232 self.data.clear();
233 }
234}
235
/// A graph definition paired with a key/value [`Context`] (presumably the
/// per-execution state — confirm against the executor).
pub struct Graph {
    /// The static graph definition.
    pub def: GraphDef,
    /// Key/value context; `Graph::new` starts it empty.
    pub context: Context,
}
240
241impl Graph {
242 pub fn new(def: GraphDef) -> Self {
243 Self {
244 def,
245 context: Context::default(),
246 }
247 }
248}