zen_engine/decision_graph/
graph.rs1use crate::decision_graph::cleaner::VariableCleaner;
2use crate::decision_graph::tracer::NodeTracer;
3use crate::decision_graph::walker::{GraphWalker, NodeData, StableDiDecisionGraph};
4use crate::engine::EvaluationTraceKind;
5use crate::model::{DecisionContent, DecisionNodeKind};
6use crate::nodes::custom::CustomNodeHandler;
7use crate::nodes::decision::DecisionNodeHandler;
8use crate::nodes::decision_table::DecisionTableNodeHandler;
9use crate::nodes::expression::ExpressionNodeHandler;
10use crate::nodes::function::FunctionNodeHandler;
11use crate::nodes::input::InputNodeHandler;
12use crate::nodes::output::OutputNodeHandler;
13use crate::nodes::transform_attributes::TransformAttributesExecution;
14use crate::nodes::{
15 NodeContext, NodeContextBase, NodeContextConfig, NodeDataType, NodeHandler,
16 NodeHandlerExtensions, NodeResponse, NodeResult, TraceDataType,
17};
18use crate::{DecisionGraphTrace, DecisionGraphValidationError, EvaluationError};
19use ahash::{HashMap, HashMapExt};
20use petgraph::algo::is_cyclic_directed;
21use petgraph::matrix_graph::Zero;
22use serde::ser::SerializeMap;
23use serde::{Deserialize, Serialize, Serializer};
24use std::cell::RefCell;
25use std::ops::Deref;
26use std::rc::Rc;
27use std::sync::Arc;
28use std::time::Instant;
29use zen_expression::variable::{ToVariable, Variable};
30use zen_types::decision::DecisionNode;
31
/// An executable decision graph built from [`DecisionContent`].
///
/// Holds two copies of the graph: the one produced at construction time and a
/// working copy that evaluation hands mutably to the graph walker.
#[derive(Debug)]
pub struct DecisionGraph {
    /// Graph exactly as built by `try_new`; `reset_graph` copies it back into `graph`.
    initial_graph: StableDiDecisionGraph,
    /// Working graph passed (mutably) to the walker during `evaluate`.
    graph: StableDiDecisionGraph,
    /// Content and runtime settings the graph was created with.
    config: DecisionGraphConfig,
}
38
/// Settings used to build and evaluate a [`DecisionGraph`].
#[derive(Debug)]
pub struct DecisionGraphConfig {
    /// Node and edge definitions the graph is built from.
    pub content: Arc<DecisionContent>,
    /// Whether per-node tracing is enabled for evaluation and node contexts.
    pub trace: bool,
    /// Current evaluation depth; evaluation is refused once it reaches `max_depth`,
    /// and results/traces are cleaned only when it is `0`.
    /// NOTE(review): presumably incremented for nested decision evaluations — confirm.
    pub iteration: u8,
    /// Depth limit; `evaluate` fails with `DepthLimitExceeded` when `iteration >= max_depth`.
    pub max_depth: u8,
    /// Handler extensions cloned into every node context.
    pub extensions: NodeHandlerExtensions,
}
47
48impl DecisionGraph {
49 pub fn try_new(config: DecisionGraphConfig) -> Result<Self, DecisionGraphValidationError> {
50 let graph = Self::build_graph(config.content.deref())?;
51 Ok(Self {
52 initial_graph: graph.clone(),
53 graph,
54 config,
55 })
56 }
57
58 fn build_graph(
59 content: &DecisionContent,
60 ) -> Result<StableDiDecisionGraph, DecisionGraphValidationError> {
61 let mut graph = StableDiDecisionGraph::new();
62 let mut index_map = HashMap::with_capacity(content.nodes.len());
63
64 for node in &content.nodes {
65 let node_id = node.id.clone();
66 let node_index = graph.add_node(node.clone());
67
68 index_map.insert(node_id, node_index);
69 }
70
71 for edge in &content.edges {
72 let source_index = index_map.get(&edge.source_id).ok_or_else(|| {
73 DecisionGraphValidationError::MissingNode(edge.source_id.to_string())
74 })?;
75
76 let target_index = index_map.get(&edge.target_id).ok_or_else(|| {
77 DecisionGraphValidationError::MissingNode(edge.target_id.to_string())
78 })?;
79
80 graph.add_edge(*source_index, *target_index, edge.clone());
81 }
82
83 Ok(graph)
84 }
85
86 pub(crate) fn reset_graph(&mut self) {
87 self.graph = self.initial_graph.clone();
88 }
89
90 pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
91 let input_count = self
92 .graph
93 .node_weights()
94 .filter(|w| matches!(w.kind, DecisionNodeKind::InputNode { .. }))
95 .count();
96 if input_count != 1 {
97 return Err(DecisionGraphValidationError::InvalidInputCount(
98 input_count as u32,
99 ));
100 }
101
102 if is_cyclic_directed(&self.graph) {
103 return Err(DecisionGraphValidationError::CyclicGraph);
104 }
105
106 Ok(())
107 }
108
109 fn build_node_context(&self, node: &DecisionNode, input: Variable) -> NodeContextBase {
110 NodeContextBase {
111 id: node.id.clone(),
112 name: node.name.clone(),
113 input,
114 extensions: self.config.extensions.clone(),
115 iteration: self.config.iteration,
116 trace: match self.config.trace {
117 true => Some(RefCell::new(Variable::Null)),
118 false => None,
119 },
120 config: NodeContextConfig {
121 max_depth: self.config.max_depth,
122 trace: self.config.trace,
123 ..Default::default()
124 },
125 }
126 }
127
128 pub async fn evaluate(
129 &mut self,
130 context: Variable,
131 ) -> Result<DecisionGraphResponse, Box<EvaluationError>> {
132 let root_start = Instant::now();
133
134 self.validate()?;
135 if self.config.iteration >= self.config.max_depth {
136 return Err(Box::new(EvaluationError::DepthLimitExceeded));
137 }
138
139 let mut walker = GraphWalker::new(&self.graph);
140 let mut tracer = NodeTracer::new(self.config.trace);
141
142 while let Some(nid) = walker.next(&mut self.graph, tracer.trace_callback()) {
143 if let Some(_) = walker.get_node_data(nid) {
144 continue;
145 }
146
147 let node = &self.graph[nid];
148 let start = Instant::now();
149 let (input, input_trace) = walker.incoming_node_data(&self.graph, nid, true);
150 let mut base_ctx = self.build_node_context(node.deref(), input);
151
152 let node_execution = match &node.kind {
153 DecisionNodeKind::InputNode { content } => {
154 base_ctx.input = context.clone();
155 handle_node(base_ctx, content.clone(), InputNodeHandler).await
156 }
157 DecisionNodeKind::OutputNode { content } => {
158 handle_node(base_ctx, content.clone(), OutputNodeHandler).await
159 }
160 DecisionNodeKind::SwitchNode { .. } => Ok(NodeResponse {
161 output: input_trace.clone(),
162 trace_data: None,
163 }),
164 DecisionNodeKind::FunctionNode { content } => {
165 handle_node(base_ctx, content.clone(), FunctionNodeHandler).await
166 }
167 DecisionNodeKind::DecisionNode { content } => {
168 handle_node(base_ctx, content.clone(), DecisionNodeHandler::default()).await
169 }
170 DecisionNodeKind::DecisionTableNode { content } => {
171 handle_node(base_ctx, content.clone(), DecisionTableNodeHandler).await
172 }
173 DecisionNodeKind::ExpressionNode { content } => {
174 handle_node(base_ctx, content.clone(), ExpressionNodeHandler).await
175 }
176 DecisionNodeKind::CustomNode { content } => {
177 handle_node(base_ctx, content.clone(), CustomNodeHandler).await
178 }
179 };
180
181 tracer.record_execution(node.deref(), input_trace, &node_execution, start.elapsed());
182
183 let output = match node_execution {
184 Ok(ok) => ok.output,
185 Err(err) => {
186 let mut cleaner = VariableCleaner::new();
187 let trace = tracer.into_traces();
188 if let Some(t) = &trace {
189 t.values().for_each(|v| {
190 cleaner.clean(&v.input);
191 cleaner.clean(&v.output);
192 if let Some(td) = &v.trace_data {
193 cleaner.clean(td);
194 }
195 })
196 }
197
198 return Err(Box::new(EvaluationError::NodeError {
199 node_id: err.node_id,
200 source: err.source,
201 trace: trace.map(|t| t.to_variable()),
202 }));
203 }
204 };
205
206 walker.set_node_data(
207 nid,
208 NodeData {
209 name: Rc::from(node.name.deref()),
210 data: output,
211 },
212 );
213
214 if matches!(node.kind, DecisionNodeKind::OutputNode { .. }) {
216 break;
217 }
218 }
219
220 let result = walker.ending_variables(&self.graph);
221 let trace = tracer.into_traces();
222
223 if self.config.iteration.is_zero() {
224 let mut cleaner = VariableCleaner::new();
225 cleaner.clean(&result);
226 if let Some(t) = &trace {
227 t.values().for_each(|v| {
228 cleaner.clean(&v.input);
229 cleaner.clean(&v.output);
230 if let Some(td) = &v.trace_data {
231 cleaner.clean(td);
232 }
233 })
234 }
235 }
236
237 Ok(DecisionGraphResponse {
238 performance: format!("{:.1?}", root_start.elapsed()),
239 result,
240 trace,
241 })
242 }
243}
244
/// Outcome of [`DecisionGraph::evaluate`].
///
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphResponse {
    /// Total evaluation time, rendered with `{:.1?}` (the `Duration` debug format
    /// with one decimal, e.g. `"1.2ms"`).
    pub performance: String,
    /// Value gathered by the graph walker's `ending_variables` once the walk finishes.
    pub result: Variable,
    /// Per-node traces collected during evaluation; omitted from serialized output
    /// when `None`. NOTE(review): keys are presumably node ids — confirm in `NodeTracer`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trace: Option<HashMap<Arc<str>, DecisionGraphTrace>>,
}
253
254impl DecisionGraphResponse {
255 pub fn serialize_with_mode<S>(
256 &self,
257 serializer: S,
258 mode: EvaluationTraceKind,
259 ) -> Result<S::Ok, S::Error>
260 where
261 S: Serializer,
262 {
263 let mut map = serializer.serialize_map(None)?;
264 map.serialize_entry("performance", &self.performance)?;
265 map.serialize_entry("result", &self.result)?;
266 if let Some(trace) = &self.trace {
267 map.serialize_entry("trace", &mode.serialize_trace(&trace.to_variable()))?;
268 }
269
270 map.end()
271 }
272}
273
274async fn handle_node<NodeData, TraceData, NodeHandlerType>(
275 base_ctx: NodeContextBase,
276 content: NodeData,
277 handler: NodeHandlerType,
278) -> NodeResult
279where
280 TraceData: TraceDataType,
281 NodeData: NodeDataType,
282 NodeHandlerType: NodeHandler<NodeData = NodeData, TraceData = TraceData>,
283{
284 let ctx = NodeContext::<NodeData, TraceData>::from_base(base_ctx.clone(), content);
285 if let Some(transform_attributes) = handler.transform_attributes(&ctx) {
286 return transform_attributes
287 .run_with(base_ctx, move |input, has_more| {
288 let handler = handler.clone();
289 let mut new_ctx = ctx.clone();
290 new_ctx.input = input;
291
292 async move {
293 match has_more {
294 false => handler.handle(new_ctx).await,
295 true => {
296 let result = handler.handle(new_ctx.clone()).await;
297 handler.after_transform_attributes(&new_ctx).await?;
298 result
299 }
300 }
301 }
302 })
303 .await;
304 }
305
306 handler.handle(ctx).await
307}