zen_engine/decision_graph/graph.rs

1use crate::decision_graph::cleaner::VariableCleaner;
2use crate::decision_graph::tracer::NodeTracer;
3use crate::decision_graph::walker::{GraphWalker, NodeData, StableDiDecisionGraph};
4use crate::engine::EvaluationTraceKind;
5use crate::model::{DecisionContent, DecisionNodeKind};
6use crate::nodes::custom::CustomNodeHandler;
7use crate::nodes::decision::DecisionNodeHandler;
8use crate::nodes::decision_table::DecisionTableNodeHandler;
9use crate::nodes::expression::ExpressionNodeHandler;
10use crate::nodes::function::FunctionNodeHandler;
11use crate::nodes::input::InputNodeHandler;
12use crate::nodes::output::OutputNodeHandler;
13use crate::nodes::transform_attributes::TransformAttributesExecution;
14use crate::nodes::{
15    NodeContext, NodeContextBase, NodeContextConfig, NodeDataType, NodeHandler,
16    NodeHandlerExtensions, NodeResponse, NodeResult, TraceDataType,
17};
18use crate::{DecisionGraphTrace, DecisionGraphValidationError, EvaluationError};
19use ahash::{HashMap, HashMapExt};
20use petgraph::algo::is_cyclic_directed;
21use petgraph::matrix_graph::Zero;
22use serde::ser::SerializeMap;
23use serde::{Deserialize, Serialize, Serializer};
24use std::cell::RefCell;
25use std::ops::Deref;
26use std::rc::Rc;
27use std::sync::Arc;
28use std::time::Instant;
29use zen_expression::variable::{ToVariable, Variable};
30use zen_types::decision::DecisionNode;
31
32#[derive(Debug)]
33pub struct DecisionGraph {
34    initial_graph: StableDiDecisionGraph,
35    graph: StableDiDecisionGraph,
36    config: DecisionGraphConfig,
37}
38
39#[derive(Debug)]
40pub struct DecisionGraphConfig {
41    pub content: Arc<DecisionContent>,
42    pub trace: bool,
43    pub iteration: u8,
44    pub max_depth: u8,
45    pub extensions: NodeHandlerExtensions,
46}
47
48impl DecisionGraph {
49    pub fn try_new(config: DecisionGraphConfig) -> Result<Self, DecisionGraphValidationError> {
50        let graph = Self::build_graph(config.content.deref())?;
51        Ok(Self {
52            initial_graph: graph.clone(),
53            graph,
54            config,
55        })
56    }
57
58    fn build_graph(
59        content: &DecisionContent,
60    ) -> Result<StableDiDecisionGraph, DecisionGraphValidationError> {
61        let mut graph = StableDiDecisionGraph::new();
62        let mut index_map = HashMap::with_capacity(content.nodes.len());
63
64        for node in &content.nodes {
65            let node_id = node.id.clone();
66            let node_index = graph.add_node(node.clone());
67
68            index_map.insert(node_id, node_index);
69        }
70
71        for edge in &content.edges {
72            let source_index = index_map.get(&edge.source_id).ok_or_else(|| {
73                DecisionGraphValidationError::MissingNode(edge.source_id.to_string())
74            })?;
75
76            let target_index = index_map.get(&edge.target_id).ok_or_else(|| {
77                DecisionGraphValidationError::MissingNode(edge.target_id.to_string())
78            })?;
79
80            graph.add_edge(*source_index, *target_index, edge.clone());
81        }
82
83        Ok(graph)
84    }
85
86    pub(crate) fn reset_graph(&mut self) {
87        self.graph = self.initial_graph.clone();
88    }
89
90    pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
91        let input_count = self
92            .graph
93            .node_weights()
94            .filter(|w| matches!(w.kind, DecisionNodeKind::InputNode { .. }))
95            .count();
96        if input_count != 1 {
97            return Err(DecisionGraphValidationError::InvalidInputCount(
98                input_count as u32,
99            ));
100        }
101
102        if is_cyclic_directed(&self.graph) {
103            return Err(DecisionGraphValidationError::CyclicGraph);
104        }
105
106        Ok(())
107    }
108
109    fn build_node_context(&self, node: &DecisionNode, input: Variable) -> NodeContextBase {
110        NodeContextBase {
111            id: node.id.clone(),
112            name: node.name.clone(),
113            input,
114            extensions: self.config.extensions.clone(),
115            iteration: self.config.iteration,
116            trace: match self.config.trace {
117                true => Some(RefCell::new(Variable::Null)),
118                false => None,
119            },
120            config: NodeContextConfig {
121                max_depth: self.config.max_depth,
122                trace: self.config.trace,
123                ..Default::default()
124            },
125        }
126    }
127
128    pub async fn evaluate(
129        &mut self,
130        context: Variable,
131    ) -> Result<DecisionGraphResponse, Box<EvaluationError>> {
132        let root_start = Instant::now();
133
134        self.validate()?;
135        if self.config.iteration >= self.config.max_depth {
136            return Err(Box::new(EvaluationError::DepthLimitExceeded));
137        }
138
139        let mut walker = GraphWalker::new(&self.graph);
140        let mut tracer = NodeTracer::new(self.config.trace);
141
142        while let Some(nid) = walker.next(&mut self.graph, tracer.trace_callback()) {
143            if let Some(_) = walker.get_node_data(nid) {
144                continue;
145            }
146
147            let node = &self.graph[nid];
148            let start = Instant::now();
149            let (input, input_trace) = walker.incoming_node_data(&self.graph, nid, true);
150            let mut base_ctx = self.build_node_context(node.deref(), input);
151
152            let node_execution = match &node.kind {
153                DecisionNodeKind::InputNode { content } => {
154                    base_ctx.input = context.clone();
155                    handle_node(base_ctx, content.clone(), InputNodeHandler).await
156                }
157                DecisionNodeKind::OutputNode { content } => {
158                    handle_node(base_ctx, content.clone(), OutputNodeHandler).await
159                }
160                DecisionNodeKind::SwitchNode { .. } => Ok(NodeResponse {
161                    output: input_trace.clone(),
162                    trace_data: None,
163                }),
164                DecisionNodeKind::FunctionNode { content } => {
165                    handle_node(base_ctx, content.clone(), FunctionNodeHandler).await
166                }
167                DecisionNodeKind::DecisionNode { content } => {
168                    handle_node(base_ctx, content.clone(), DecisionNodeHandler::default()).await
169                }
170                DecisionNodeKind::DecisionTableNode { content } => {
171                    handle_node(base_ctx, content.clone(), DecisionTableNodeHandler).await
172                }
173                DecisionNodeKind::ExpressionNode { content } => {
174                    handle_node(base_ctx, content.clone(), ExpressionNodeHandler).await
175                }
176                DecisionNodeKind::CustomNode { content } => {
177                    handle_node(base_ctx, content.clone(), CustomNodeHandler).await
178                }
179            };
180
181            tracer.record_execution(node.deref(), input_trace, &node_execution, start.elapsed());
182
183            let output = match node_execution {
184                Ok(ok) => ok.output,
185                Err(err) => {
186                    let mut cleaner = VariableCleaner::new();
187                    let trace = tracer.into_traces();
188                    if let Some(t) = &trace {
189                        t.values().for_each(|v| {
190                            cleaner.clean(&v.input);
191                            cleaner.clean(&v.output);
192                            if let Some(td) = &v.trace_data {
193                                cleaner.clean(td);
194                            }
195                        })
196                    }
197
198                    return Err(Box::new(EvaluationError::NodeError {
199                        node_id: err.node_id,
200                        source: err.source,
201                        trace: trace.map(|t| t.to_variable()),
202                    }));
203                }
204            };
205
206            walker.set_node_data(
207                nid,
208                NodeData {
209                    name: Rc::from(node.name.deref()),
210                    data: output,
211                },
212            );
213
214            // Terminate once Output node is reached
215            if matches!(node.kind, DecisionNodeKind::OutputNode { .. }) {
216                break;
217            }
218        }
219
220        let result = walker.ending_variables(&self.graph);
221        let trace = tracer.into_traces();
222
223        if self.config.iteration.is_zero() {
224            let mut cleaner = VariableCleaner::new();
225            cleaner.clean(&result);
226            if let Some(t) = &trace {
227                t.values().for_each(|v| {
228                    cleaner.clean(&v.input);
229                    cleaner.clean(&v.output);
230                    if let Some(td) = &v.trace_data {
231                        cleaner.clean(td);
232                    }
233                })
234            }
235        }
236
237        Ok(DecisionGraphResponse {
238            performance: format!("{:.1?}", root_start.elapsed()),
239            result,
240            trace,
241        })
242    }
243}
244
245#[derive(Debug, Clone, Serialize, Deserialize)]
246#[serde(rename_all = "camelCase")]
247pub struct DecisionGraphResponse {
248    pub performance: String,
249    pub result: Variable,
250    #[serde(skip_serializing_if = "Option::is_none")]
251    pub trace: Option<HashMap<Arc<str>, DecisionGraphTrace>>,
252}
253
254impl DecisionGraphResponse {
255    pub fn serialize_with_mode<S>(
256        &self,
257        serializer: S,
258        mode: EvaluationTraceKind,
259    ) -> Result<S::Ok, S::Error>
260    where
261        S: Serializer,
262    {
263        let mut map = serializer.serialize_map(None)?;
264        map.serialize_entry("performance", &self.performance)?;
265        map.serialize_entry("result", &self.result)?;
266        if let Some(trace) = &self.trace {
267            map.serialize_entry("trace", &mode.serialize_trace(&trace.to_variable()))?;
268        }
269
270        map.end()
271    }
272}
273
274async fn handle_node<NodeData, TraceData, NodeHandlerType>(
275    base_ctx: NodeContextBase,
276    content: NodeData,
277    handler: NodeHandlerType,
278) -> NodeResult
279where
280    TraceData: TraceDataType,
281    NodeData: NodeDataType,
282    NodeHandlerType: NodeHandler<NodeData = NodeData, TraceData = TraceData>,
283{
284    let ctx = NodeContext::<NodeData, TraceData>::from_base(base_ctx.clone(), content);
285    if let Some(transform_attributes) = handler.transform_attributes(&ctx) {
286        return transform_attributes
287            .run_with(base_ctx, move |input, has_more| {
288                let handler = handler.clone();
289                let mut new_ctx = ctx.clone();
290                new_ctx.input = input;
291
292                async move {
293                    match has_more {
294                        false => handler.handle(new_ctx).await,
295                        true => {
296                            let result = handler.handle(new_ctx.clone()).await;
297                            handler.after_transform_attributes(&new_ctx).await?;
298                            result
299                        }
300                    }
301                }
302            })
303            .await;
304    }
305
306    handler.handle(ctx).await
307}