zen_engine/nodes/decision_table/mod.rs

1use crate::model::CompilationKey;
2use crate::nodes::definition::NodeHandler;
3use crate::nodes::result::NodeResult;
4use crate::nodes::{NodeContext, NodeResponse};
5use ahash::HashMap;
6use serde::Serialize;
7use std::ops::Deref;
8use std::rc::Rc;
9use std::sync::Arc;
10use zen_expression::variable::ToVariable;
11use zen_expression::{ExpressionKind, Isolate};
12use zen_types::decision::{DecisionTableContent, DecisionTableHitPolicy, TransformAttributes};
13use zen_types::variable::Variable;
/// Stateless [`NodeHandler`] that evaluates decision table nodes.
#[derive(Clone, Debug)]
pub struct DecisionTableNodeHandler;
16
17pub type DecisionTableNodeData = DecisionTableContent;
18
19type DecisionTableContext = NodeContext<DecisionTableNodeData, DecisionTableNodeTrace>;
20
21impl NodeHandler for DecisionTableNodeHandler {
22    type NodeData = DecisionTableNodeData;
23    type TraceData = DecisionTableNodeTrace;
24
25    fn transform_attributes(
26        &self,
27        ctx: &NodeContext<Self::NodeData, Self::TraceData>,
28    ) -> Option<TransformAttributes> {
29        Some(ctx.node.transform_attributes.clone())
30    }
31
32    async fn handle(&self, ctx: NodeContext<Self::NodeData, Self::TraceData>) -> NodeResult {
33        match ctx.node.hit_policy {
34            DecisionTableHitPolicy::First => self.handle_first_hit(ctx),
35            DecisionTableHitPolicy::Collect => self.handle_collect(ctx),
36        }
37    }
38}
39
40impl DecisionTableNodeHandler {
41    fn handle_first_hit(&self, ctx: DecisionTableContext) -> NodeResult {
42        let mut isolate = Isolate::new();
43        isolate.set_environment(ctx.input.depth_clone(1));
44
45        for (index, rule) in ctx.node.rules.iter().enumerate() {
46            if let Some(result) = self.evaluate_row(&ctx, rule, &mut isolate) {
47                return match result {
48                    RowResult::Output(output) => ctx.success(output),
49                    RowResult::WithTrace {
50                        output,
51                        reference_map,
52                        rule,
53                    } => {
54                        ctx.trace(|t| {
55                            *t = DecisionTableNodeTrace::FirstHit(DecisionTableRowTrace {
56                                reference_map,
57                                index,
58                                rule,
59                            })
60                        });
61
62                        ctx.success(output)
63                    }
64                };
65            }
66        }
67
68        Ok(NodeResponse {
69            output: Variable::Null,
70            trace_data: None,
71        })
72    }
73
74    fn handle_collect(&self, ctx: DecisionTableContext) -> NodeResult {
75        let mut isolate = Isolate::new();
76        let mut outputs = Vec::new();
77        let mut traces = Vec::new();
78        isolate.set_environment(ctx.input.depth_clone(1));
79
80        for (index, rule) in ctx.node.rules.iter().enumerate() {
81            if let Some(result) = self.evaluate_row(&ctx, rule, &mut isolate) {
82                match result {
83                    RowResult::Output(output) => {
84                        outputs.push(output);
85                    }
86                    RowResult::WithTrace {
87                        output,
88                        reference_map,
89                        rule,
90                    } => {
91                        outputs.push(output);
92                        traces.push(DecisionTableRowTrace {
93                            index,
94                            rule,
95                            reference_map,
96                        });
97                    }
98                }
99            }
100        }
101
102        ctx.trace(|t| {
103            *t = DecisionTableNodeTrace::Collect(traces);
104        });
105
106        ctx.success(Variable::from_array(outputs))
107    }
108
109    fn evaluate_row<'a>(
110        &self,
111        ctx: &'a DecisionTableContext,
112        rule: &'a HashMap<Arc<str>, Arc<str>>,
113        isolate: &mut Isolate<'a>,
114    ) -> Option<RowResult> {
115        let content = &ctx.node;
116        for input in content.inputs.iter() {
117            let rule_value = rule.get(&input.id)?;
118            if rule_value.is_empty() {
119                continue;
120            }
121
122            match &input.field {
123                None => {
124                    let key = CompilationKey {
125                        kind: ExpressionKind::Standard,
126                        source: Arc::from(rule_value.clone()),
127                    };
128                    let result: Variable;
129                    if let Some(codes) = ctx
130                        .extensions
131                        .compiled_cache
132                        .as_ref()
133                        .and_then(|cc| cc.get(&key))
134                    {
135                        result = isolate.run_compiled(codes).ok()?;
136                    } else {
137                        result = isolate.run_standard(rule_value).ok()?;
138                    }
139                    if !result.as_bool().unwrap_or(false) {
140                        return None;
141                    }
142                }
143                Some(field) => {
144                    isolate.set_reference(&field).ok()?;
145                    let key = CompilationKey {
146                        kind: ExpressionKind::Unary,
147                        source: Arc::from(rule_value.clone()),
148                    };
149                    if let Some(codes) = ctx
150                        .extensions
151                        .compiled_cache
152                        .as_ref()
153                        .and_then(|cc| cc.get(&key))
154                    {
155                        if !isolate.run_unary_compiled(codes).ok()? {
156                            return None;
157                        }
158                    } else {
159                        if !isolate.run_unary(&rule_value).ok()? {
160                            return None;
161                        }
162                    }
163                }
164            }
165        }
166
167        let outputs = Variable::empty_object();
168        for output in content.outputs.iter() {
169            let rule_value = rule.get(&output.id)?;
170            if rule_value.is_empty() {
171                continue;
172            }
173
174            let key = CompilationKey {
175                kind: ExpressionKind::Standard,
176                source: Arc::from(rule_value.clone()),
177            };
178            let res: Variable;
179            if let Some(codes) = ctx
180                .extensions
181                .compiled_cache
182                .as_ref()
183                .and_then(|cc| cc.get(&key))
184            {
185                res = isolate.run_compiled(codes).ok()?;
186            } else {
187                res = isolate.run_standard(rule_value).ok()?;
188            }
189            outputs.dot_insert(output.field.deref(), res);
190        }
191
192        if !ctx.config.trace {
193            return Some(RowResult::Output(outputs));
194        }
195
196        let id_str = Rc::<str>::from("_id");
197        let description_str = Rc::<str>::from("_description");
198
199        let rule_id = match rule.get(id_str.as_ref()) {
200            Some(rid) => Rc::<str>::from(rid.deref()),
201            None => Rc::from(""),
202        };
203
204        let mut expressions: HashMap<Rc<str>, Rc<str>> = Default::default();
205        let mut reference_map: HashMap<Rc<str>, Variable> = Default::default();
206
207        expressions.insert(id_str.clone(), rule_id.clone());
208        if let Some(description) = rule.get(description_str.as_ref()) {
209            expressions.insert(description_str.clone(), Rc::from(description.deref()));
210        }
211
212        for input in content.inputs.iter() {
213            let rule_value = rule.get(input.id.deref())?;
214            let Some(input_field) = &input.field else {
215                continue;
216            };
217
218            if let Some(reference) = isolate.get_reference(input_field.deref()) {
219                reference_map.insert(Rc::from(input_field.deref()), reference);
220            } else if let Some(reference) = isolate.run_standard(input_field.deref()).ok() {
221                reference_map.insert(Rc::from(input_field.deref()), reference);
222            }
223
224            let input_identifier = format!("{input_field}[{}]", &input.id);
225            expressions.insert(
226                Rc::from(input_identifier.as_str()),
227                Rc::from(rule_value.deref()),
228            );
229        }
230
231        Some(RowResult::WithTrace {
232            output: outputs.to_variable(),
233            reference_map,
234            rule: expressions,
235        })
236    }
237}
238
239enum RowResult {
240    Output(Variable),
241    WithTrace {
242        output: Variable,
243        reference_map: HashMap<Rc<str>, Variable>,
244        rule: HashMap<Rc<str>, Rc<str>>,
245    },
246}
247
248#[derive(Debug, Clone, Serialize, ToVariable)]
249pub struct DecisionTableRowTrace {
250    index: usize,
251    reference_map: HashMap<Rc<str>, Variable>,
252    rule: HashMap<Rc<str>, Rc<str>>,
253}
254
255#[derive(Debug, Clone, Serialize, ToVariable)]
256#[serde(untagged)]
257pub enum DecisionTableNodeTrace {
258    FirstHit(DecisionTableRowTrace),
259    Collect(Vec<DecisionTableRowTrace>),
260}
261
262impl Default for DecisionTableNodeTrace {
263    fn default() -> Self {
264        DecisionTableNodeTrace::Collect(Default::default())
265    }
266}