zen_engine/handler/
graph.rs

1use crate::handler::custom_node_adapter::{CustomNodeAdapter, CustomNodeRequest};
2use crate::handler::decision::DecisionHandler;
3use crate::handler::expression::ExpressionHandler;
4use crate::handler::function::function::{Function, FunctionConfig};
5use crate::handler::function::module::console::ConsoleListener;
6use crate::handler::function::module::zen::ZenListener;
7use crate::handler::function::FunctionHandler;
8use crate::handler::function_v1;
9use crate::handler::function_v1::runtime::create_runtime;
10use crate::handler::node::{NodeRequest, PartialTraceError};
11use crate::handler::table::zen::DecisionTableHandler;
12use crate::handler::traversal::{GraphWalker, StableDiDecisionGraph};
13use crate::loader::DecisionLoader;
14use crate::model::{DecisionContent, DecisionNodeKind, FunctionNodeContent};
15use crate::util::validator_cache::ValidatorCache;
16use crate::{EvaluationError, NodeError};
17use ahash::{HashMap, HashMapExt};
18use anyhow::anyhow;
19use petgraph::algo::is_cyclic_directed;
20use serde::ser::SerializeMap;
21use serde::{Deserialize, Serialize, Serializer};
22use serde_json::Value;
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::rc::Rc;
25use std::sync::Arc;
26use std::time::Instant;
27use thiserror::Error;
28use zen_expression::variable::Variable;
29
30pub struct DecisionGraph<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static> {
31    initial_graph: StableDiDecisionGraph,
32    graph: StableDiDecisionGraph,
33    adapter: Arc<A>,
34    loader: Arc<L>,
35    trace: bool,
36    max_depth: u8,
37    iteration: u8,
38    runtime: Option<Rc<Function>>,
39    validator_cache: ValidatorCache,
40}
41
42pub struct DecisionGraphConfig<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static> {
43    pub loader: Arc<L>,
44    pub adapter: Arc<A>,
45    pub content: Arc<DecisionContent>,
46    pub trace: bool,
47    pub iteration: u8,
48    pub max_depth: u8,
49    pub validator_cache: Option<ValidatorCache>,
50}
51
52impl<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static> DecisionGraph<L, A> {
53    pub fn try_new(
54        config: DecisionGraphConfig<L, A>,
55    ) -> Result<Self, DecisionGraphValidationError> {
56        let content = config.content;
57        let mut graph = StableDiDecisionGraph::new();
58        let mut index_map = HashMap::new();
59
60        for node in &content.nodes {
61            let node_id = node.id.clone();
62            let node_index = graph.add_node(node.clone());
63
64            index_map.insert(node_id, node_index);
65        }
66
67        for (_, edge) in content.edges.iter().enumerate() {
68            let source_index = index_map.get(&edge.source_id).ok_or_else(|| {
69                DecisionGraphValidationError::MissingNode(edge.source_id.to_string())
70            })?;
71
72            let target_index = index_map.get(&edge.target_id).ok_or_else(|| {
73                DecisionGraphValidationError::MissingNode(edge.target_id.to_string())
74            })?;
75
76            graph.add_edge(source_index.clone(), target_index.clone(), edge.clone());
77        }
78
79        Ok(Self {
80            initial_graph: graph.clone(),
81            graph,
82            iteration: config.iteration,
83            trace: config.trace,
84            loader: config.loader,
85            adapter: config.adapter,
86            max_depth: config.max_depth,
87            validator_cache: config.validator_cache.unwrap_or_default(),
88            runtime: None,
89        })
90    }
91
92    pub(crate) fn with_function(mut self, runtime: Option<Rc<Function>>) -> Self {
93        self.runtime = runtime;
94        self
95    }
96
97    pub(crate) fn reset_graph(&mut self) {
98        self.graph = self.initial_graph.clone();
99    }
100
101    async fn get_or_insert_function(&mut self) -> anyhow::Result<Rc<Function>> {
102        if let Some(function) = &self.runtime {
103            return Ok(function.clone());
104        }
105
106        let function = Function::create(FunctionConfig {
107            listeners: Some(vec![
108                Box::new(ConsoleListener),
109                Box::new(ZenListener {
110                    loader: self.loader.clone(),
111                    adapter: self.adapter.clone(),
112                }),
113            ]),
114        })
115        .await
116        .map_err(|err| anyhow!(err.to_string()))?;
117        let rc_function = Rc::new(function);
118        self.runtime.replace(rc_function.clone());
119
120        Ok(rc_function)
121    }
122
123    pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
124        let input_count = self.input_node_count();
125        if input_count != 1 {
126            return Err(DecisionGraphValidationError::InvalidInputCount(
127                input_count as u32,
128            ));
129        }
130
131        if is_cyclic_directed(&self.graph) {
132            return Err(DecisionGraphValidationError::CyclicGraph);
133        }
134
135        Ok(())
136    }
137
138    fn input_node_count(&self) -> usize {
139        self.graph
140            .node_weights()
141            .filter(|weight| matches!(weight.kind, DecisionNodeKind::InputNode { content: _ }))
142            .count()
143    }
144
145    pub async fn evaluate(
146        &mut self,
147        context: Variable,
148    ) -> Result<DecisionGraphResponse, NodeError> {
149        let root_start = Instant::now();
150
151        self.validate().map_err(|e| NodeError {
152            node_id: "".to_string(),
153            source: anyhow!(e),
154            trace: None,
155        })?;
156
157        if self.iteration >= self.max_depth {
158            return Err(NodeError {
159                node_id: "".to_string(),
160                source: anyhow!(EvaluationError::DepthLimitExceeded),
161                trace: None,
162            });
163        }
164
165        let mut walker = GraphWalker::new(&self.graph);
166        let mut node_traces = self.trace.then(|| HashMap::default());
167
168        while let Some(nid) = walker.next(
169            &mut self.graph,
170            self.trace.then_some(|mut trace: DecisionGraphTrace| {
171                if let Some(nt) = &mut node_traces {
172                    trace.order = nt.len() as u32;
173                    nt.insert(trace.id.clone(), trace);
174                };
175            }),
176        ) {
177            if let Some(_) = walker.get_node_data(nid) {
178                continue;
179            }
180
181            let node = (&self.graph[nid]).clone();
182            let start = Instant::now();
183
184            macro_rules! trace {
185                ({ $($field:ident: $value:expr),* $(,)? }) => {
186                    if let Some(nt) = &mut node_traces {
187                        nt.insert(
188                            node.id.clone(),
189                            DecisionGraphTrace {
190                                name: node.name.clone(),
191                                id: node.id.clone(),
192                                performance: Some(format!("{:.1?}", start.elapsed())),
193                                order: nt.len() as u32,
194                                $($field: $value,)*
195                            }
196                        );
197                    }
198                };
199            }
200
201            match &node.kind {
202                DecisionNodeKind::InputNode { content } => {
203                    trace!({
204                        input: Variable::Null,
205                        output: context.clone(),
206                        trace_data: None,
207                    });
208
209                    if let Some(json_schema) = content
210                        .schema
211                        .as_ref()
212                        .map(|s| serde_json::from_str::<Value>(&s).ok())
213                        .flatten()
214                    {
215                        let validator_key = create_validator_cache_key(&json_schema);
216                        let validator = self
217                            .validator_cache
218                            .get_or_insert(validator_key, &json_schema)
219                            .await
220                            .map_err(|e| NodeError {
221                                source: e.into(),
222                                node_id: node.id.clone(),
223                                trace: error_trace(&node_traces),
224                            })?;
225
226                        let context_json = context.to_value();
227                        validator.validate(&context_json).map_err(|e| NodeError {
228                            source: anyhow!(serde_json::to_value(
229                                Into::<Box<EvaluationError>>::into(e)
230                            )
231                            .unwrap_or_default()),
232                            node_id: node.id.clone(),
233                            trace: error_trace(&node_traces),
234                        })?;
235                    }
236
237                    walker.set_node_data(nid, context.clone());
238                }
239                DecisionNodeKind::OutputNode { content } => {
240                    let incoming_data = walker.incoming_node_data(&self.graph, nid, false);
241
242                    trace!({
243                        input: incoming_data.clone(),
244                        output: Variable::Null,
245                        trace_data: None,
246                    });
247
248                    if let Some(json_schema) = content
249                        .schema
250                        .as_ref()
251                        .map(|s| serde_json::from_str::<Value>(&s).ok())
252                        .flatten()
253                    {
254                        let validator_key = create_validator_cache_key(&json_schema);
255                        let validator = self
256                            .validator_cache
257                            .get_or_insert(validator_key, &json_schema)
258                            .await
259                            .map_err(|e| NodeError {
260                                source: e.into(),
261                                node_id: node.id.clone(),
262                                trace: error_trace(&node_traces),
263                            })?;
264
265                        let incoming_data_json = incoming_data.to_value();
266                        validator
267                            .validate(&incoming_data_json)
268                            .map_err(|e| NodeError {
269                                source: anyhow!(serde_json::to_value(
270                                    Into::<Box<EvaluationError>>::into(e)
271                                )
272                                .unwrap_or_default()),
273                                node_id: node.id.clone(),
274                                trace: error_trace(&node_traces),
275                            })?;
276                    }
277
278                    return Ok(DecisionGraphResponse {
279                        result: incoming_data,
280                        performance: format!("{:.1?}", root_start.elapsed()),
281                        trace: node_traces,
282                    });
283                }
284                DecisionNodeKind::SwitchNode { .. } => {
285                    let input_data = walker.incoming_node_data(&self.graph, nid, false);
286
287                    walker.set_node_data(nid, input_data);
288                }
289                DecisionNodeKind::FunctionNode { content } => {
290                    let function = self.get_or_insert_function().await.map_err(|e| NodeError {
291                        source: e.into(),
292                        node_id: node.id.clone(),
293                        trace: error_trace(&node_traces),
294                    })?;
295
296                    let node_request = NodeRequest {
297                        node: node.clone(),
298                        iteration: self.iteration,
299                        input: walker.incoming_node_data(&self.graph, nid, true),
300                    };
301                    let res = match content {
302                        FunctionNodeContent::Version2(_) => FunctionHandler::new(
303                            function,
304                            self.trace,
305                            self.iteration,
306                            self.max_depth,
307                        )
308                        .handle(node_request.clone())
309                        .await
310                        .map_err(|e| {
311                            if let Some(detailed_err) = e.downcast_ref::<PartialTraceError>() {
312                                trace!({
313                                    input: node_request.input.clone(),
314                                    output: Variable::Null,
315                                    trace_data: detailed_err.trace.clone(),
316                                });
317                            }
318
319                            NodeError {
320                                source: e.into(),
321                                node_id: node.id.clone(),
322                                trace: error_trace(&node_traces),
323                            }
324                        })?,
325                        FunctionNodeContent::Version1(_) => {
326                            let runtime = create_runtime().map_err(|e| NodeError {
327                                source: e.into(),
328                                node_id: node.id.clone(),
329                                trace: error_trace(&node_traces),
330                            })?;
331
332                            function_v1::FunctionHandler::new(self.trace, runtime)
333                                .handle(node_request.clone())
334                                .await
335                                .map_err(|e| NodeError {
336                                    source: e.into(),
337                                    node_id: node.id.clone(),
338                                    trace: error_trace(&node_traces),
339                                })?
340                        }
341                    };
342
343                    node_request.input.dot_remove("$nodes");
344                    res.output.dot_remove("$nodes");
345
346                    trace!({
347                        input: node_request.input,
348                        output: res.output.clone(),
349                        trace_data: res.trace_data,
350                    });
351                    walker.set_node_data(nid, res.output);
352                }
353                DecisionNodeKind::DecisionNode { .. } => {
354                    let node_request = NodeRequest {
355                        node: node.clone(),
356                        iteration: self.iteration,
357                        input: walker.incoming_node_data(&self.graph, nid, true),
358                    };
359
360                    let res = DecisionHandler::new(
361                        self.trace,
362                        self.max_depth,
363                        self.loader.clone(),
364                        self.adapter.clone(),
365                        self.runtime.clone(),
366                        self.validator_cache.clone(),
367                    )
368                    .handle(node_request.clone())
369                    .await
370                    .map_err(|e| NodeError {
371                        source: e.into(),
372                        node_id: node.id.to_string(),
373                        trace: error_trace(&node_traces),
374                    })?;
375
376                    node_request.input.dot_remove("$nodes");
377                    res.output.dot_remove("$nodes");
378
379                    trace!({
380                        input: node_request.input,
381                        output: res.output.clone(),
382                        trace_data: res.trace_data,
383                    });
384                    walker.set_node_data(nid, res.output);
385                }
386                DecisionNodeKind::DecisionTableNode { .. } => {
387                    let node_request = NodeRequest {
388                        node: node.clone(),
389                        iteration: self.iteration,
390                        input: walker.incoming_node_data(&self.graph, nid, true),
391                    };
392
393                    let res = DecisionTableHandler::new(self.trace)
394                        .handle(node_request.clone())
395                        .await
396                        .map_err(|e| NodeError {
397                            node_id: node.id.clone(),
398                            source: e.into(),
399                            trace: error_trace(&node_traces),
400                        })?;
401
402                    node_request.input.dot_remove("$nodes");
403                    res.output.dot_remove("$nodes");
404                    res.output.dot_remove("$");
405
406                    trace!({
407                        input: node_request.input,
408                        output: res.output.clone(),
409                        trace_data: res.trace_data,
410                    });
411                    walker.set_node_data(nid, res.output);
412                }
413                DecisionNodeKind::ExpressionNode { .. } => {
414                    let node_request = NodeRequest {
415                        node: node.clone(),
416                        iteration: self.iteration,
417                        input: walker.incoming_node_data(&self.graph, nid, true),
418                    };
419
420                    let res = ExpressionHandler::new(self.trace)
421                        .handle(node_request.clone())
422                        .await
423                        .map_err(|e| {
424                            if let Some(detailed_err) = e.downcast_ref::<PartialTraceError>() {
425                                trace!({
426                                    input: node_request.input.clone(),
427                                    output: Variable::Null,
428                                    trace_data: detailed_err.trace.clone(),
429                                });
430                            }
431
432                            NodeError {
433                                node_id: node.id.clone(),
434                                source: e.into(),
435                                trace: error_trace(&node_traces),
436                            }
437                        })?;
438
439                    node_request.input.dot_remove("$nodes");
440                    res.output.dot_remove("$nodes");
441
442                    trace!({
443                        input: node_request.input,
444                        output: res.output.clone(),
445                        trace_data: res.trace_data,
446                    });
447                    walker.set_node_data(nid, res.output);
448                }
449                DecisionNodeKind::CustomNode { .. } => {
450                    let node_request = NodeRequest {
451                        node: node.clone(),
452                        iteration: self.iteration,
453                        input: walker.incoming_node_data(&self.graph, nid, true),
454                    };
455
456                    let res = self
457                        .adapter
458                        .handle(CustomNodeRequest::try_from(node_request.clone()).unwrap())
459                        .await
460                        .map_err(|e| NodeError {
461                            node_id: node.id.clone(),
462                            source: e.into(),
463                            trace: error_trace(&node_traces),
464                        })?;
465
466                    node_request.input.dot_remove("$nodes");
467                    res.output.dot_remove("$nodes");
468
469                    trace!({
470                        input: node_request.input,
471                        output: res.output.clone(),
472                        trace_data: res.trace_data,
473                    });
474                    walker.set_node_data(nid, res.output);
475                }
476            }
477        }
478
479        Ok(DecisionGraphResponse {
480            result: walker.ending_variables(&self.graph),
481            performance: format!("{:.1?}", root_start.elapsed()),
482            trace: node_traces,
483        })
484    }
485}
486
487#[derive(Debug, Error)]
488pub enum DecisionGraphValidationError {
489    #[error("Invalid input node count: {0}")]
490    InvalidInputCount(u32),
491
492    #[error("Invalid output node count: {0}")]
493    InvalidOutputCount(u32),
494
495    #[error("Cyclic graph detected")]
496    CyclicGraph,
497
498    #[error("Missing node")]
499    MissingNode(String),
500}
501
502impl Serialize for DecisionGraphValidationError {
503    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
504    where
505        S: Serializer,
506    {
507        let mut map = serializer.serialize_map(None)?;
508
509        match &self {
510            DecisionGraphValidationError::InvalidInputCount(count) => {
511                map.serialize_entry("type", "invalidInputCount")?;
512                map.serialize_entry("nodeCount", count)?;
513            }
514            DecisionGraphValidationError::InvalidOutputCount(count) => {
515                map.serialize_entry("type", "invalidOutputCount")?;
516                map.serialize_entry("nodeCount", count)?;
517            }
518            DecisionGraphValidationError::MissingNode(node_id) => {
519                map.serialize_entry("type", "missingNode")?;
520                map.serialize_entry("nodeId", node_id)?;
521            }
522            DecisionGraphValidationError::CyclicGraph => {
523                map.serialize_entry("type", "cyclicGraph")?;
524            }
525        }
526
527        map.end()
528    }
529}
530
531#[derive(Debug, Clone, Serialize, Deserialize)]
532#[serde(rename_all = "camelCase")]
533pub struct DecisionGraphResponse {
534    pub performance: String,
535    pub result: Variable,
536    #[serde(skip_serializing_if = "Option::is_none")]
537    pub trace: Option<HashMap<String, DecisionGraphTrace>>,
538}
539
540#[derive(Debug, Clone, Serialize, Deserialize)]
541#[serde(rename_all = "camelCase")]
542pub struct DecisionGraphTrace {
543    pub input: Variable,
544    pub output: Variable,
545    pub name: String,
546    pub id: String,
547    pub performance: Option<String>,
548    pub trace_data: Option<Value>,
549    pub order: u32,
550}
551
552pub(crate) fn error_trace(trace: &Option<HashMap<String, DecisionGraphTrace>>) -> Option<Value> {
553    trace
554        .as_ref()
555        .map(|s| serde_json::to_value(s).ok())
556        .flatten()
557}
558
559fn create_validator_cache_key(content: &Value) -> u64 {
560    let mut hasher = DefaultHasher::new();
561    content.hash(&mut hasher);
562    hasher.finish()
563}