1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use crate::node::NodeType;
5
6#[derive(Debug, Serialize, Deserialize, Clone)]
7pub struct Edge {
8 pub from: String,
9 pub to: String,
10 pub rule: Option<String>,
11}
12
13impl Edge {
14 pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
15 Self {
16 from: from.into(),
17 to: to.into(),
18 rule: None,
19 }
20 }
21
22 pub fn with_rule(mut self, rule: impl Into<String>) -> Self {
23 self.rule = Some(rule.into());
24 self
25 }
26}
27
28#[derive(Debug, Serialize, Deserialize, Clone, Default)]
30pub struct NodeConfig {
31 pub node_type: NodeType,
32 #[serde(default)]
33 pub condition: Option<String>,
34 #[serde(default)]
35 pub query: Option<String>,
36 #[serde(default)]
37 pub prompt: Option<String>,
38 #[serde(default)]
41 pub params: Option<Vec<String>>,
42}
43
44impl NodeConfig {
45 pub fn rule_node(condition: impl Into<String>) -> Self {
46 Self {
47 node_type: NodeType::RuleNode,
48 condition: Some(condition.into()),
49 query: None,
50 prompt: None,
51 params: None,
52 }
53 }
54
55 pub fn db_node(query: impl Into<String>) -> Self {
56 Self {
57 node_type: NodeType::DBNode,
58 condition: None,
59 query: Some(query.into()),
60 prompt: None,
61 params: None,
62 }
63 }
64
65 pub fn db_node_with_params(query: impl Into<String>, params: Vec<String>) -> Self {
67 Self {
68 node_type: NodeType::DBNode,
69 condition: None,
70 query: Some(query.into()),
71 prompt: None,
72 params: Some(params),
73 }
74 }
75
76 pub fn ai_node(prompt: impl Into<String>) -> Self {
77 Self {
78 node_type: NodeType::AINode,
79 condition: None,
80 query: None,
81 prompt: Some(prompt.into()),
82 params: None,
83 }
84 }
85
86 pub fn grpc_node(service_url: impl Into<String>, method: impl Into<String>) -> Self {
88 Self {
89 node_type: NodeType::GrpcNode,
90 query: Some(format!("{}#{}", service_url.into(), method.into())),
91 condition: None,
92 prompt: None,
93 params: None,
94 }
95 }
96}
97
98#[derive(Debug, Serialize, Deserialize, Clone)]
99pub struct GraphDef {
100 pub nodes: HashMap<String, NodeConfig>,
101 pub edges: Vec<Edge>,
102}
103
104impl GraphDef {
105 pub fn from_node_types(nodes: HashMap<String, NodeType>, edges: Vec<Edge>) -> Self {
107 let nodes = nodes
108 .into_iter()
109 .map(|(id, node_type)| {
110 let config = match node_type {
111 NodeType::RuleNode => NodeConfig::rule_node("true"),
112 NodeType::DBNode => NodeConfig::db_node(format!("SELECT * FROM {}", id)),
113 NodeType::AINode => NodeConfig::ai_node(format!("Process data for {}", id)),
114 NodeType::GrpcNode => NodeConfig::grpc_node(
115 format!("http://localhost:50051"),
116 format!("{}_method", id),
117 ),
118 NodeType::SubgraphNode => NodeConfig::rule_node("true"), NodeType::ConditionalNode => NodeConfig::rule_node("true"),
120 NodeType::LoopNode => NodeConfig::rule_node("true"),
121 NodeType::TryCatchNode => NodeConfig::rule_node("true"),
122 NodeType::RetryNode => NodeConfig::rule_node("true"),
123 NodeType::CircuitBreakerNode => NodeConfig::rule_node("true"),
124 };
125 (id, config)
126 })
127 .collect();
128
129 Self { nodes, edges }
130 }
131
132 pub fn validate(&self) -> anyhow::Result<()> {
134 if self.nodes.is_empty() {
136 return Err(anyhow::anyhow!("Graph has no nodes"));
137 }
138
139 for edge in &self.edges {
141 if !self.nodes.contains_key(&edge.from) {
142 return Err(anyhow::anyhow!(
143 "Edge references non-existent source node: '{}'",
144 edge.from
145 ));
146 }
147 if !self.nodes.contains_key(&edge.to) {
148 return Err(anyhow::anyhow!(
149 "Edge references non-existent target node: '{}'",
150 edge.to
151 ));
152 }
153 }
154
155 Ok(())
156 }
157
158 pub fn has_disconnected_components(&self) -> bool {
160 if self.nodes.is_empty() {
161 return false;
162 }
163
164 use std::collections::HashSet;
165 let mut visited = HashSet::new();
166 let mut stack = Vec::new();
167
168 if let Some(first_node) = self.nodes.keys().next() {
170 stack.push(first_node.clone());
171 }
172
173 while let Some(node) = stack.pop() {
175 if visited.contains(&node) {
176 continue;
177 }
178 visited.insert(node.clone());
179
180 for edge in &self.edges {
182 if edge.from == node && !visited.contains(&edge.to) {
183 stack.push(edge.to.clone());
184 }
185 if edge.to == node && !visited.contains(&edge.from) {
186 stack.push(edge.from.clone());
187 }
188 }
189 }
190
191 visited.len() < self.nodes.len()
192 }
193}
194
195#[derive(Default)]
196pub struct Context {
197 pub data: HashMap<String, serde_json::Value>,
198}
199
200impl Context {
201 pub fn new() -> Self {
202 Self {
203 data: HashMap::new(),
204 }
205 }
206
207 pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) {
209 self.data.insert(key.into(), value);
210 }
211
212 pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
213 self.data.get(key)
214 }
215
216 pub fn contains_key(&self, key: &str) -> bool {
218 self.data.contains_key(key)
219 }
220
221 pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
223 self.data.remove(key)
224 }
225
226 pub fn clear(&mut self) {
228 self.data.clear();
229 }
230}
231
232pub struct Graph {
233 pub def: GraphDef,
234 pub context: Context,
235}
236
237impl Graph {
238 pub fn new(def: GraphDef) -> Self {
239 Self {
240 def,
241 context: Context::default(),
242 }
243 }
244}