// rust_logic_graph/core/graph.rs
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::node::NodeType;

7#[derive(Debug, Serialize, Deserialize, Clone)]
8pub struct Edge {
9    pub from: String,
10    pub to: String,
11    pub rule: Option<String>,
12}
13
14impl Edge {
15    pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
16        Self {
17            from: from.into(),
18            to: to.into(),
19            rule: None,
20        }
21    }
22
23    pub fn with_rule(mut self, rule: impl Into<String>) -> Self {
24        self.rule = Some(rule.into());
25        self
26    }
27}
28
29/// Configuration for a node in the graph
30#[derive(Debug, Serialize, Deserialize, Clone, Default)]
31pub struct NodeConfig {
32    pub node_type: NodeType,
33    #[serde(default)]
34    pub condition: Option<String>,
35    #[serde(default)]
36    pub query: Option<String>,
37    #[serde(default)]
38    pub prompt: Option<String>,
39    /// Optional list of context keys to extract as query parameters
40    /// Example: ["product_id", "user_id"] will extract ctx.get("product_id") and ctx.get("user_id")
41    #[serde(default)]
42    pub params: Option<Vec<String>>,
43}
44
45impl NodeConfig {
46    pub fn rule_node(condition: impl Into<String>) -> Self {
47        Self {
48            node_type: NodeType::RuleNode,
49            condition: Some(condition.into()),
50            query: None,
51            prompt: None,
52            params: None,
53        }
54    }
55
56    pub fn db_node(query: impl Into<String>) -> Self {
57        Self {
58            node_type: NodeType::DBNode,
59            condition: None,
60            query: Some(query.into()),
61            prompt: None,
62            params: None,
63        }
64    }
65    
66    /// Create a DBNode with query parameters from context
67    pub fn db_node_with_params(query: impl Into<String>, params: Vec<String>) -> Self {
68        Self {
69            node_type: NodeType::DBNode,
70            condition: None,
71            query: Some(query.into()),
72            prompt: None,
73            params: Some(params),
74        }
75    }
76
77    pub fn ai_node(prompt: impl Into<String>) -> Self {
78        Self {
79            node_type: NodeType::AINode,
80            condition: None,
81            query: None,
82            prompt: Some(prompt.into()),
83            params: None,
84        }
85    }
86    
87    /// Create a GrpcNode configuration
88    pub fn grpc_node(service_url: impl Into<String>, method: impl Into<String>) -> Self {
89        Self {
90            node_type: NodeType::GrpcNode,
91            query: Some(format!("{}#{}", service_url.into(), method.into())),
92            condition: None,
93            prompt: None,
94            params: None,
95        }
96    }
97}
98
99#[derive(Debug, Serialize, Deserialize, Clone)]
100pub struct GraphDef {
101    pub nodes: HashMap<String, NodeConfig>,
102    pub edges: Vec<Edge>,
103}
104
105impl GraphDef {
106    /// Create a GraphDef from simple node types (backward compatibility helper)
107    pub fn from_node_types(
108        nodes: HashMap<String, NodeType>,
109        edges: Vec<Edge>,
110    ) -> Self {
111        let nodes = nodes
112            .into_iter()
113            .map(|(id, node_type)| {
114                let config = match node_type {
115                    NodeType::RuleNode => NodeConfig::rule_node("true"),
116                    NodeType::DBNode => NodeConfig::db_node(format!("SELECT * FROM {}", id)),
117                    NodeType::AINode => NodeConfig::ai_node(format!("Process data for {}", id)),
118                    NodeType::GrpcNode => NodeConfig::grpc_node(
119                        format!("http://localhost:50051"),
120                        format!("{}_method", id)
121                    ),
122                    NodeType::SubgraphNode => NodeConfig::rule_node("true"), // Placeholder
123                    NodeType::ConditionalNode => NodeConfig::rule_node("true"),
124                    NodeType::LoopNode => NodeConfig::rule_node("true"),
125                    NodeType::TryCatchNode => NodeConfig::rule_node("true"),
126                    NodeType::RetryNode => NodeConfig::rule_node("true"),
127                    NodeType::CircuitBreakerNode => NodeConfig::rule_node("true"),
128                };
129                (id, config)
130            })
131            .collect();
132        
133        Self { nodes, edges }
134    }
135    
136    /// Validate graph structure
137    pub fn validate(&self) -> anyhow::Result<()> {
138        // Check for empty graph
139        if self.nodes.is_empty() {
140            return Err(anyhow::anyhow!("Graph has no nodes"));
141        }
142        
143        // Check for invalid edge references
144        for edge in &self.edges {
145            if !self.nodes.contains_key(&edge.from) {
146                return Err(anyhow::anyhow!(
147                    "Edge references non-existent source node: '{}'",
148                    edge.from
149                ));
150            }
151            if !self.nodes.contains_key(&edge.to) {
152                return Err(anyhow::anyhow!(
153                    "Edge references non-existent target node: '{}'",
154                    edge.to
155                ));
156            }
157        }
158        
159        Ok(())
160    }
161    
162    /// Check if graph has disconnected components
163    pub fn has_disconnected_components(&self) -> bool {
164        if self.nodes.is_empty() {
165            return false;
166        }
167        
168        use std::collections::HashSet;
169        let mut visited = HashSet::new();
170        let mut stack = Vec::new();
171        
172        // Start from first node
173        if let Some(first_node) = self.nodes.keys().next() {
174            stack.push(first_node.clone());
175        }
176        
177        // DFS traversal (undirected)
178        while let Some(node) = stack.pop() {
179            if visited.contains(&node) {
180                continue;
181            }
182            visited.insert(node.clone());
183            
184            // Add neighbors (both directions)
185            for edge in &self.edges {
186                if edge.from == node && !visited.contains(&edge.to) {
187                    stack.push(edge.to.clone());
188                }
189                if edge.to == node && !visited.contains(&edge.from) {
190                    stack.push(edge.from.clone());
191                }
192            }
193        }
194        
195        visited.len() < self.nodes.len()
196    }
197}
198
199#[derive(Default)]
200pub struct Context {
201    pub data: HashMap<String, serde_json::Value>,
202}
203
204impl Context {
205    pub fn new() -> Self {
206        Self {
207            data: HashMap::new(),
208        }
209    }
210
211    /// Set a value in the context
212    pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) {
213        self.data.insert(key.into(), value);
214    }
215
216    pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
217        self.data.get(key)
218    }
219    
220    /// Check if key exists in context
221    pub fn contains_key(&self, key: &str) -> bool {
222        self.data.contains_key(key)
223    }
224    
225    /// Remove a value from context
226    pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
227        self.data.remove(key)
228    }
229    
230    /// Clear all context data
231    pub fn clear(&mut self) {
232        self.data.clear();
233    }
234}
235
236pub struct Graph {
237    pub def: GraphDef,
238    pub context: Context,
239}
240
241impl Graph {
242    pub fn new(def: GraphDef) -> Self {
243        Self {
244            def,
245            context: Context::default(),
246        }
247    }
248}