zen_engine/nodes/decision/
mod.rs1use crate::decision_graph::graph::{DecisionGraph, DecisionGraphConfig};
2use crate::nodes::{NodeContext, NodeContextExt, NodeError, NodeHandler, NodeResult};
3use crate::EvaluationError;
4use std::cell::RefCell;
5use std::ops::Deref;
6use std::rc::Rc;
7use zen_types::decision::{DecisionNodeContent, TransformAttributes};
8use zen_types::variable::{ToVariable, Variable};
9
10#[derive(Debug, Clone, Default)]
11pub struct DecisionNodeHandler {
12 decision_graph: Rc<RefCell<Option<DecisionGraph>>>,
13}
14
15pub type DecisionNodeData = DecisionNodeContent;
16pub type DecisionNodeTrace = Variable;
17
18impl NodeHandler for DecisionNodeHandler {
19 type NodeData = DecisionNodeData;
20 type TraceData = DecisionNodeTrace;
21
22 fn transform_attributes(
23 &self,
24 ctx: &NodeContext<Self::NodeData, Self::TraceData>,
25 ) -> Option<TransformAttributes> {
26 Some(ctx.node.transform_attributes.clone())
27 }
28
29 async fn after_transform_attributes(
30 &self,
31 _ctx: &NodeContext<Self::NodeData, Self::TraceData>,
32 ) -> Result<(), NodeError> {
33 if let Some(graph) = self.decision_graph.borrow_mut().as_mut() {
34 graph.reset_graph();
35 };
36
37 Ok(())
38 }
39
40 async fn handle(&self, ctx: NodeContext<Self::NodeData, Self::TraceData>) -> NodeResult {
41 let loader = ctx.extensions.loader();
42 let sub_decision = loader.load(ctx.node.key.deref()).await.node_context(&ctx)?;
43
44 let mut decision_graph_ref = self.decision_graph.borrow_mut();
45 let decision_graph = match decision_graph_ref.as_mut() {
46 Some(dg) => dg,
47 None => {
48 let dg = DecisionGraph::try_new(DecisionGraphConfig {
49 content: sub_decision,
50 extensions: ctx.extensions.clone(),
51 trace: ctx.config.trace,
52 iteration: ctx.iteration + 1,
53 max_depth: ctx.config.max_depth,
54 })
55 .node_context(&ctx)?;
56
57 *decision_graph_ref = Some(dg);
58 match decision_graph_ref.as_mut() {
59 Some(dg) => dg,
60 None => return ctx.error("Failed to initialize decision graph".to_string()),
61 }
62 }
63 };
64
65 let evaluate_result = Box::pin(decision_graph.evaluate(ctx.input.clone())).await;
66 match evaluate_result {
67 Ok(result) => {
68 ctx.trace(|trace| {
69 *trace = result.trace.to_variable();
70 });
71
72 ctx.success(result.result)
73 }
74 Err(err) => {
75 if let EvaluationError::NodeError { trace, .. } = err.deref() {
76 ctx.trace(|t| *t = trace.to_variable());
77 }
78
79 ctx.error(err.to_string())
80 }
81 }
82 }
83}