zen_engine/handler/table/zen.rs

1use ahash::HashMap;
2use anyhow::anyhow;
3use std::sync::Arc;
4
5use crate::handler::node::{NodeRequest, NodeResponse, NodeResult};
6use crate::handler::table::{RowOutput, RowOutputKind};
7use crate::model::{DecisionNodeKind, DecisionTableContent, DecisionTableHitPolicy};
8use serde::Serialize;
9use tokio::sync::Mutex;
10use zen_expression::variable::Variable;
11use zen_expression::Isolate;
12
13#[derive(Debug, Serialize)]
14struct RowResult {
15    rule: Option<HashMap<String, String>>,
16    reference_map: Option<HashMap<String, Variable>>,
17    index: usize,
18    #[serde(skip)]
19    output: RowOutput,
20}
21
22#[derive(Debug)]
23pub struct DecisionTableHandler {
24    trace: bool,
25}
26
27impl DecisionTableHandler {
28    pub fn new(trace: bool) -> Self {
29        Self { trace }
30    }
31
32    pub async fn handle(&mut self, request: NodeRequest) -> NodeResult {
33        let content = match &request.node.kind {
34            DecisionNodeKind::DecisionTableNode { content } => Ok(content),
35            _ => Err(anyhow!("Unexpected node type")),
36        }?;
37
38        let inner_handler = DecisionTableHandlerInner::new(self.trace);
39        inner_handler
40            .handle(request.input.depth_clone(1), content)
41            .await
42    }
43}
44
45struct DecisionTableHandlerInner<'a> {
46    isolate: Isolate<'a>,
47    trace: bool,
48}
49
50impl<'a> DecisionTableHandlerInner<'a> {
51    pub fn new(trace: bool) -> Self {
52        Self {
53            isolate: Isolate::new(),
54            trace,
55        }
56    }
57
58    pub async fn handle(self, input: Variable, content: &'a DecisionTableContent) -> NodeResult {
59        let self_mutex = Arc::new(Mutex::new(self));
60
61        content
62            .transform_attributes
63            .run_with(input, |input| {
64                let self_mutex = self_mutex.clone();
65                async move {
66                    let mut self_ref = self_mutex.lock().await;
67
68                    self_ref.isolate.clear_references();
69                    self_ref.isolate.set_environment(input);
70                    let result = match &content.hit_policy {
71                        DecisionTableHitPolicy::First => self_ref.handle_first_hit(&content).await,
72                        DecisionTableHitPolicy::Collect => self_ref.handle_collect(&content).await,
73                    };
74
75                    self_ref.isolate.update_environment(|env| {
76                        if let Some(env) = env {
77                            env.dot_remove("$");
78                        };
79                    });
80
81                    result
82                }
83            })
84            .await
85    }
86
87    async fn handle_first_hit(&mut self, content: &'a DecisionTableContent) -> NodeResult {
88        for i in 0..content.rules.len() {
89            if let Some(result) = self.evaluate_row(&content, i) {
90                return Ok(NodeResponse {
91                    output: result.output.to_json().await,
92                    trace_data: self
93                        .trace
94                        .then(|| serde_json::to_value(&result).ok())
95                        .flatten(),
96                });
97            }
98        }
99
100        Ok(NodeResponse {
101            output: Variable::Null,
102            trace_data: None,
103        })
104    }
105
106    async fn handle_collect(&mut self, content: &'a DecisionTableContent) -> NodeResult {
107        let mut results = Vec::new();
108        for i in 0..content.rules.len() {
109            if let Some(result) = self.evaluate_row(&content, i) {
110                results.push(result);
111            }
112        }
113
114        let mut outputs = Vec::with_capacity(results.len());
115        for res in &results {
116            outputs.push(res.output.to_json().await);
117        }
118
119        Ok(NodeResponse {
120            output: Variable::from_array(outputs),
121            trace_data: self
122                .trace
123                .then(|| serde_json::to_value(&results).ok())
124                .flatten(),
125        })
126    }
127
128    fn evaluate_row(
129        &mut self,
130        content: &'a DecisionTableContent,
131        index: usize,
132    ) -> Option<RowResult> {
133        let rule = content.rules.get(index)?;
134        for input in &content.inputs {
135            let rule_value = rule.get(input.id.as_str())?;
136            if rule_value.trim().is_empty() {
137                continue;
138            }
139
140            match &input.field {
141                None => {
142                    let result = self.isolate.run_standard(rule_value.as_str()).ok()?;
143                    if !result.as_bool().unwrap_or(false) {
144                        return None;
145                    }
146                }
147                Some(field) => {
148                    self.isolate.set_reference(field.as_str()).ok()?;
149                    if !self.isolate.run_unary(rule_value.as_str()).ok()? {
150                        return None;
151                    }
152                }
153            }
154        }
155
156        let mut outputs: RowOutput = Default::default();
157        for output in &content.outputs {
158            let rule_value = rule.get(output.id.as_str())?;
159            if rule_value.trim().is_empty() {
160                continue;
161            }
162
163            let res = self.isolate.run_standard(rule_value).ok()?;
164            outputs.push(&output.field, RowOutputKind::Variable(res));
165        }
166
167        if !self.trace {
168            return Some(RowResult {
169                output: outputs,
170                rule: None,
171                reference_map: None,
172                index,
173            });
174        }
175
176        let rule_id = match rule.get("_id") {
177            Some(rid) => rid.clone(),
178            None => "".to_string(),
179        };
180
181        let mut expressions: HashMap<String, String> = Default::default();
182        let mut reference_map: HashMap<String, Variable> = Default::default();
183
184        expressions.insert("_id".to_string(), rule_id.clone());
185        if let Some(description) = rule.get("_description") {
186            expressions.insert("_description".to_string(), description.clone());
187        }
188
189        for input in &content.inputs {
190            let rule_value = rule.get(input.id.as_str())?;
191            let Some(input_field) = &input.field else {
192                continue;
193            };
194
195            if let Some(reference) = self.isolate.get_reference(input_field.as_str()) {
196                reference_map.insert(input_field.clone(), reference);
197            } else if let Some(reference) = self.isolate.run_standard(input_field.as_str()).ok() {
198                reference_map.insert(input_field.clone(), reference);
199            }
200
201            let input_identifier = format!("{input_field}[{}]", &input.id);
202            expressions.insert(input_identifier, rule_value.clone());
203        }
204
205        Some(RowResult {
206            output: outputs,
207            rule: Some(expressions),
208            reference_map: Some(reference_map),
209            index,
210        })
211    }
212}