zen_engine/handler/table/
zen.rs1use ahash::HashMap;
2use anyhow::anyhow;
3use std::sync::Arc;
4
5use crate::handler::node::{NodeRequest, NodeResponse, NodeResult};
6use crate::handler::table::{RowOutput, RowOutputKind};
7use crate::model::{DecisionNodeKind, DecisionTableContent, DecisionTableHitPolicy};
8use serde::Serialize;
9use tokio::sync::Mutex;
10use zen_expression::variable::Variable;
11use zen_expression::Isolate;
12
/// Result of evaluating a single decision-table row that matched.
///
/// Only the trace-related fields are serialized (this is what ends up in the node's
/// `trace_data`); the computed `output` is skipped by serde and returned to the caller
/// through `NodeResponse::output` instead.
#[derive(Debug, Serialize)]
struct RowResult {
    /// Rule cells recorded for tracing: `_id`, an optional `_description`, and one
    /// `"{field}[{input_id}]"` entry per field-bound input column. `None` when tracing is off.
    rule: Option<HashMap<String, String>>,
    /// Resolved value of each field-bound input column, keyed by field. `None` when tracing is off.
    reference_map: Option<HashMap<String, Variable>>,
    /// Position of the rule within the table's `rules`.
    index: usize,
    /// Output fields produced by the row's non-blank output cells (not serialized).
    #[serde(skip)]
    output: RowOutput,
}
21
/// Node handler that evaluates decision-table nodes.
#[derive(Debug)]
pub struct DecisionTableHandler {
    /// When `true`, details of the evaluated rows are attached to the response as `trace_data`.
    trace: bool,
}
26
27impl DecisionTableHandler {
28 pub fn new(trace: bool) -> Self {
29 Self { trace }
30 }
31
32 pub async fn handle(&mut self, request: NodeRequest) -> NodeResult {
33 let content = match &request.node.kind {
34 DecisionNodeKind::DecisionTableNode { content } => Ok(content),
35 _ => Err(anyhow!("Unexpected node type")),
36 }?;
37
38 let inner_handler = DecisionTableHandlerInner::new(self.trace);
39 inner_handler
40 .handle(request.input.depth_clone(1), content)
41 .await
42 }
43}
44
/// Evaluation state for one decision-table run: the expression isolate used to evaluate
/// rule cells, plus the trace flag passed down from `DecisionTableHandler`.
///
/// `'a` ties the isolate to the `&'a DecisionTableContent` whose cell expressions it evaluates.
struct DecisionTableHandlerInner<'a> {
    /// Expression engine; `handle` clears its references and sets a new environment on every
    /// `run_with` invocation.
    isolate: Isolate<'a>,
    /// When `true`, matched rows carry rule expressions and resolved references for tracing.
    trace: bool,
}
49
50impl<'a> DecisionTableHandlerInner<'a> {
51 pub fn new(trace: bool) -> Self {
52 Self {
53 isolate: Isolate::new(),
54 trace,
55 }
56 }
57
58 pub async fn handle(self, input: Variable, content: &'a DecisionTableContent) -> NodeResult {
59 let self_mutex = Arc::new(Mutex::new(self));
60
61 content
62 .transform_attributes
63 .run_with(input, |input| {
64 let self_mutex = self_mutex.clone();
65 async move {
66 let mut self_ref = self_mutex.lock().await;
67
68 self_ref.isolate.clear_references();
69 self_ref.isolate.set_environment(input);
70 match &content.hit_policy {
71 DecisionTableHitPolicy::First => self_ref.handle_first_hit(&content).await,
72 DecisionTableHitPolicy::Collect => self_ref.handle_collect(&content).await,
73 }
74 }
75 })
76 .await
77 }
78
79 async fn handle_first_hit(&mut self, content: &'a DecisionTableContent) -> NodeResult {
80 for i in 0..content.rules.len() {
81 if let Some(result) = self.evaluate_row(&content, i) {
82 return Ok(NodeResponse {
83 output: result.output.to_json().await,
84 trace_data: self
85 .trace
86 .then(|| serde_json::to_value(&result).ok())
87 .flatten(),
88 });
89 }
90 }
91
92 Ok(NodeResponse {
93 output: Variable::Null,
94 trace_data: None,
95 })
96 }
97
98 async fn handle_collect(&mut self, content: &'a DecisionTableContent) -> NodeResult {
99 let mut results = Vec::new();
100 for i in 0..content.rules.len() {
101 if let Some(result) = self.evaluate_row(&content, i) {
102 results.push(result);
103 }
104 }
105
106 let mut outputs = Vec::with_capacity(results.len());
107 for res in &results {
108 outputs.push(res.output.to_json().await);
109 }
110
111 Ok(NodeResponse {
112 output: Variable::from_array(outputs),
113 trace_data: self
114 .trace
115 .then(|| serde_json::to_value(&results).ok())
116 .flatten(),
117 })
118 }
119
120 fn evaluate_row(
121 &mut self,
122 content: &'a DecisionTableContent,
123 index: usize,
124 ) -> Option<RowResult> {
125 let rule = content.rules.get(index)?;
126 for input in &content.inputs {
127 let rule_value = rule.get(input.id.as_str())?;
128 if rule_value.trim().is_empty() {
129 continue;
130 }
131
132 match &input.field {
133 None => {
134 let result = self.isolate.run_standard(rule_value.as_str()).ok()?;
135 if !result.as_bool().unwrap_or(false) {
136 return None;
137 }
138 }
139 Some(field) => {
140 self.isolate.set_reference(field.as_str()).ok()?;
141 if !self.isolate.run_unary(rule_value.as_str()).ok()? {
142 return None;
143 }
144 }
145 }
146 }
147
148 let mut outputs: RowOutput = Default::default();
149 for output in &content.outputs {
150 let rule_value = rule.get(output.id.as_str())?;
151 if rule_value.trim().is_empty() {
152 continue;
153 }
154
155 let res = self.isolate.run_standard(rule_value).ok()?;
156 outputs.push(&output.field, RowOutputKind::Variable(res));
157 }
158
159 if !self.trace {
160 return Some(RowResult {
161 output: outputs,
162 rule: None,
163 reference_map: None,
164 index,
165 });
166 }
167
168 let rule_id = match rule.get("_id") {
169 Some(rid) => rid.clone(),
170 None => "".to_string(),
171 };
172
173 let mut expressions: HashMap<String, String> = Default::default();
174 let mut reference_map: HashMap<String, Variable> = Default::default();
175
176 expressions.insert("_id".to_string(), rule_id.clone());
177 if let Some(description) = rule.get("_description") {
178 expressions.insert("_description".to_string(), description.clone());
179 }
180
181 for input in &content.inputs {
182 let rule_value = rule.get(input.id.as_str())?;
183 let Some(input_field) = &input.field else {
184 continue;
185 };
186
187 if let Some(reference) = self.isolate.get_reference(input_field.as_str()) {
188 reference_map.insert(input_field.clone(), reference);
189 } else if let Some(reference) = self.isolate.run_standard(input_field.as_str()).ok() {
190 reference_map.insert(input_field.clone(), reference);
191 }
192
193 let input_identifier = format!("{input_field}[{}]", &input.id);
194 expressions.insert(input_identifier, rule_value.clone());
195 }
196
197 Some(RowResult {
198 output: outputs,
199 rule: Some(expressions),
200 reference_map: Some(reference_map),
201 index,
202 })
203 }
204}