// zen_engine/handler/decision.rs

1use crate::handler::custom_node_adapter::CustomNodeAdapter;
2use crate::handler::function::function::Function;
3use crate::handler::graph::{DecisionGraph, DecisionGraphConfig};
4use crate::handler::node::{NodeRequest, NodeResponse, NodeResult};
5use crate::loader::DecisionLoader;
6use crate::model::DecisionNodeKind;
7use crate::util::validator_cache::ValidatorCache;
8use anyhow::anyhow;
9use std::future::Future;
10use std::pin::Pin;
11use std::rc::Rc;
12use std::sync::Arc;
13use tokio::sync::Mutex;
14
15pub struct DecisionHandler<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static> {
16    trace: bool,
17    loader: Arc<L>,
18    adapter: Arc<A>,
19    max_depth: u8,
20    js_function: Option<Rc<Function>>,
21    validator_cache: ValidatorCache,
22}
23
24impl<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static> DecisionHandler<L, A> {
25    pub fn new(
26        trace: bool,
27        max_depth: u8,
28        loader: Arc<L>,
29        adapter: Arc<A>,
30        js_function: Option<Rc<Function>>,
31        validator_cache: ValidatorCache,
32    ) -> Self {
33        Self {
34            trace,
35            loader,
36            adapter,
37            max_depth,
38            js_function,
39            validator_cache,
40        }
41    }
42
43    pub fn handle<'s, 'arg, 'recursion>(
44        &'s self,
45        request: NodeRequest,
46    ) -> Pin<Box<dyn Future<Output = NodeResult> + 'recursion>>
47    where
48        's: 'recursion,
49        'arg: 'recursion,
50    {
51        Box::pin(async move {
52            let content = match &request.node.kind {
53                DecisionNodeKind::DecisionNode { content } => Ok(content),
54                _ => Err(anyhow!("Unexpected node type")),
55            }?;
56
57            let sub_decision = self.loader.load(&content.key).await?;
58            let sub_tree = DecisionGraph::try_new(DecisionGraphConfig {
59                content: sub_decision,
60                max_depth: self.max_depth,
61                loader: self.loader.clone(),
62                adapter: self.adapter.clone(),
63                iteration: request.iteration + 1,
64                trace: self.trace,
65                validator_cache: Some(self.validator_cache.clone()),
66            })?
67            .with_function(self.js_function.clone());
68
69            let sub_tree_mutex = Arc::new(Mutex::new(sub_tree));
70
71            content
72                .transform_attributes
73                .run_with(request.input, |input| {
74                    let sub_tree_mutex = sub_tree_mutex.clone();
75
76                    async move {
77                        let mut sub_tree_ref = sub_tree_mutex.lock().await;
78
79                        sub_tree_ref.reset_graph();
80                        sub_tree_ref
81                            .evaluate(input)
82                            .await
83                            .map(|r| NodeResponse {
84                                output: r.result,
85                                trace_data: serde_json::to_value(r.trace).ok(),
86                            })
87                            .map_err(|e| e.source)
88                    }
89                })
90                .await
91        })
92    }
93}