zen_engine/handler/
decision.rs1use crate::handler::custom_node_adapter::CustomNodeAdapter;
2use crate::handler::function::function::Function;
3use crate::handler::graph::{DecisionGraph, DecisionGraphConfig};
4use crate::handler::node::{NodeRequest, NodeResponse, NodeResult};
5use crate::loader::DecisionLoader;
6use crate::model::DecisionNodeKind;
7use crate::util::validator_cache::ValidatorCache;
8use anyhow::anyhow;
9use std::future::Future;
10use std::pin::Pin;
11use std::rc::Rc;
12use std::sync::Arc;
13use tokio::sync::Mutex;
14
15pub struct DecisionHandler<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static> {
16 trace: bool,
17 loader: Arc<L>,
18 adapter: Arc<A>,
19 max_depth: u8,
20 js_function: Option<Rc<Function>>,
21 validator_cache: ValidatorCache,
22}
23
24impl<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static> DecisionHandler<L, A> {
25 pub fn new(
26 trace: bool,
27 max_depth: u8,
28 loader: Arc<L>,
29 adapter: Arc<A>,
30 js_function: Option<Rc<Function>>,
31 validator_cache: ValidatorCache,
32 ) -> Self {
33 Self {
34 trace,
35 loader,
36 adapter,
37 max_depth,
38 js_function,
39 validator_cache,
40 }
41 }
42
43 pub fn handle<'s, 'arg, 'recursion>(
44 &'s self,
45 request: NodeRequest,
46 ) -> Pin<Box<dyn Future<Output = NodeResult> + 'recursion>>
47 where
48 's: 'recursion,
49 'arg: 'recursion,
50 {
51 Box::pin(async move {
52 let content = match &request.node.kind {
53 DecisionNodeKind::DecisionNode { content } => Ok(content),
54 _ => Err(anyhow!("Unexpected node type")),
55 }?;
56
57 let sub_decision = self.loader.load(&content.key).await?;
58 let sub_tree = DecisionGraph::try_new(DecisionGraphConfig {
59 content: sub_decision,
60 max_depth: self.max_depth,
61 loader: self.loader.clone(),
62 adapter: self.adapter.clone(),
63 iteration: request.iteration + 1,
64 trace: self.trace,
65 validator_cache: Some(self.validator_cache.clone()),
66 })?
67 .with_function(self.js_function.clone());
68
69 let sub_tree_mutex = Arc::new(Mutex::new(sub_tree));
70
71 content
72 .transform_attributes
73 .run_with(request.input, |input| {
74 let sub_tree_mutex = sub_tree_mutex.clone();
75
76 async move {
77 let mut sub_tree_ref = sub_tree_mutex.lock().await;
78
79 sub_tree_ref.reset_graph();
80 sub_tree_ref
81 .evaluate(input)
82 .await
83 .map(|r| NodeResponse {
84 output: r.result,
85 trace_data: serde_json::to_value(r.trace).ok(),
86 })
87 .map_err(|e| e.source)
88 }
89 })
90 .await
91 })
92 }
93}