// rust_logic_graph/core/graph.rs

use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

use crate::node::NodeType;
6#[derive(Debug, Serialize, Deserialize, Clone)]
7pub struct Edge {
8    pub from: String,
9    pub to: String,
10    pub rule: Option<String>,
11}
12
13impl Edge {
14    pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
15        Self {
16            from: from.into(),
17            to: to.into(),
18            rule: None,
19        }
20    }
21
22    pub fn with_rule(mut self, rule: impl Into<String>) -> Self {
23        self.rule = Some(rule.into());
24        self
25    }
26}
28/// Configuration for a node in the graph
29#[derive(Debug, Serialize, Deserialize, Clone, Default)]
30pub struct NodeConfig {
31    pub node_type: NodeType,
32    #[serde(default)]
33    pub condition: Option<String>,
34    #[serde(default)]
35    pub query: Option<String>,
36    #[serde(default)]
37    pub prompt: Option<String>,
38    /// Optional list of context keys to extract as query parameters
39    /// Example: ["product_id", "user_id"] will extract ctx.get("product_id") and ctx.get("user_id")
40    #[serde(default)]
41    pub params: Option<Vec<String>>,
42}
43
44impl NodeConfig {
45    pub fn rule_node(condition: impl Into<String>) -> Self {
46        Self {
47            node_type: NodeType::RuleNode,
48            condition: Some(condition.into()),
49            query: None,
50            prompt: None,
51            params: None,
52        }
53    }
54
55    pub fn db_node(query: impl Into<String>) -> Self {
56        Self {
57            node_type: NodeType::DBNode,
58            condition: None,
59            query: Some(query.into()),
60            prompt: None,
61            params: None,
62        }
63    }
64
65    /// Create a DBNode with query parameters from context
66    pub fn db_node_with_params(query: impl Into<String>, params: Vec<String>) -> Self {
67        Self {
68            node_type: NodeType::DBNode,
69            condition: None,
70            query: Some(query.into()),
71            prompt: None,
72            params: Some(params),
73        }
74    }
75
76    pub fn ai_node(prompt: impl Into<String>) -> Self {
77        Self {
78            node_type: NodeType::AINode,
79            condition: None,
80            query: None,
81            prompt: Some(prompt.into()),
82            params: None,
83        }
84    }
85
86    /// Create a GrpcNode configuration
87    pub fn grpc_node(service_url: impl Into<String>, method: impl Into<String>) -> Self {
88        Self {
89            node_type: NodeType::GrpcNode,
90            query: Some(format!("{}#{}", service_url.into(), method.into())),
91            condition: None,
92            prompt: None,
93            params: None,
94        }
95    }
96}
98#[derive(Debug, Serialize, Deserialize, Clone)]
99pub struct GraphDef {
100    pub nodes: HashMap<String, NodeConfig>,
101    pub edges: Vec<Edge>,
102}
103
104impl GraphDef {
105    /// Create a GraphDef from simple node types (backward compatibility helper)
106    pub fn from_node_types(nodes: HashMap<String, NodeType>, edges: Vec<Edge>) -> Self {
107        let nodes = nodes
108            .into_iter()
109            .map(|(id, node_type)| {
110                let config = match node_type {
111                    NodeType::RuleNode => NodeConfig::rule_node("true"),
112                    NodeType::DBNode => NodeConfig::db_node(format!("SELECT * FROM {}", id)),
113                    NodeType::AINode => NodeConfig::ai_node(format!("Process data for {}", id)),
114                    NodeType::GrpcNode => NodeConfig::grpc_node(
115                        format!("http://localhost:50051"),
116                        format!("{}_method", id),
117                    ),
118                    NodeType::SubgraphNode => NodeConfig::rule_node("true"), // Placeholder
119                    NodeType::ConditionalNode => NodeConfig::rule_node("true"),
120                    NodeType::LoopNode => NodeConfig::rule_node("true"),
121                    NodeType::TryCatchNode => NodeConfig::rule_node("true"),
122                    NodeType::RetryNode => NodeConfig::rule_node("true"),
123                    NodeType::CircuitBreakerNode => NodeConfig::rule_node("true"),
124                };
125                (id, config)
126            })
127            .collect();
128
129        Self { nodes, edges }
130    }
131
132    /// Validate graph structure
133    pub fn validate(&self) -> anyhow::Result<()> {
134        // Check for empty graph
135        if self.nodes.is_empty() {
136            return Err(anyhow::anyhow!("Graph has no nodes"));
137        }
138
139        // Check for invalid edge references
140        for edge in &self.edges {
141            if !self.nodes.contains_key(&edge.from) {
142                return Err(anyhow::anyhow!(
143                    "Edge references non-existent source node: '{}'",
144                    edge.from
145                ));
146            }
147            if !self.nodes.contains_key(&edge.to) {
148                return Err(anyhow::anyhow!(
149                    "Edge references non-existent target node: '{}'",
150                    edge.to
151                ));
152            }
153        }
154
155        Ok(())
156    }
157
158    /// Check if graph has disconnected components
159    pub fn has_disconnected_components(&self) -> bool {
160        if self.nodes.is_empty() {
161            return false;
162        }
163
164        use std::collections::HashSet;
165        let mut visited = HashSet::new();
166        let mut stack = Vec::new();
167
168        // Start from first node
169        if let Some(first_node) = self.nodes.keys().next() {
170            stack.push(first_node.clone());
171        }
172
173        // DFS traversal (undirected)
174        while let Some(node) = stack.pop() {
175            if visited.contains(&node) {
176                continue;
177            }
178            visited.insert(node.clone());
179
180            // Add neighbors (both directions)
181            for edge in &self.edges {
182                if edge.from == node && !visited.contains(&edge.to) {
183                    stack.push(edge.to.clone());
184                }
185                if edge.to == node && !visited.contains(&edge.from) {
186                    stack.push(edge.from.clone());
187                }
188            }
189        }
190
191        visited.len() < self.nodes.len()
192    }
193}
195#[derive(Default)]
196pub struct Context {
197    pub data: HashMap<String, serde_json::Value>,
198}
199
200impl Context {
201    pub fn new() -> Self {
202        Self {
203            data: HashMap::new(),
204        }
205    }
206
207    /// Set a value in the context
208    pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) {
209        self.data.insert(key.into(), value);
210    }
211
212    pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
213        self.data.get(key)
214    }
215
216    /// Check if key exists in context
217    pub fn contains_key(&self, key: &str) -> bool {
218        self.data.contains_key(key)
219    }
220
221    /// Remove a value from context
222    pub fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
223        self.data.remove(key)
224    }
225
226    /// Clear all context data
227    pub fn clear(&mut self) {
228        self.data.clear();
229    }
230}
232pub struct Graph {
233    pub def: GraphDef,
234    pub context: Context,
235}
236
237impl Graph {
238    pub fn new(def: GraphDef) -> Self {
239        Self {
240            def,
241            context: Context::default(),
242        }
243    }
244}