zen_engine/nodes/
context.rs

1use crate::nodes::definition::{NodeDataType, TraceDataType};
2use crate::nodes::extensions::NodeHandlerExtensions;
3use crate::nodes::function::v2::function::Function;
4use crate::nodes::result::{NodeResponse, NodeResult};
5use crate::nodes::NodeError;
6use crate::ZEN_CONFIG;
7use ahash::AHasher;
8use jsonschema::ValidationError;
9use serde::Serialize;
10use serde_json::Value;
11use std::cell::RefCell;
12use std::fmt::{Display, Formatter};
13use std::hash::Hasher;
14use std::sync::atomic::Ordering;
15use std::sync::Arc;
16use thiserror::Error;
17use zen_types::variable::{ToVariable, Variable};
18
19#[derive(Clone)]
20pub struct NodeContext<NodeData, TraceData>
21where
22    NodeData: NodeDataType,
23    TraceData: TraceDataType,
24{
25    pub id: Arc<str>,
26    pub name: Arc<str>,
27    pub node: NodeData,
28    pub input: Variable,
29    pub trace: Option<RefCell<TraceData>>,
30    pub extensions: NodeHandlerExtensions,
31    pub iteration: u8,
32    pub config: NodeContextConfig,
33}
34
35impl<NodeData, TraceData> NodeContext<NodeData, TraceData>
36where
37    NodeData: NodeDataType,
38    TraceData: TraceDataType,
39{
40    pub fn from_base(base: NodeContextBase, data: NodeData) -> Self {
41        Self {
42            id: base.id,
43            name: base.name,
44            input: base.input,
45            extensions: base.extensions,
46            iteration: base.iteration,
47            trace: base.config.trace.then(|| Default::default()),
48            node: data,
49            config: base.config,
50        }
51    }
52
53    pub fn trace<Function>(&self, mutator: Function)
54    where
55        Function: FnOnce(&mut TraceData),
56    {
57        if let Some(trace) = &self.trace {
58            mutator(&mut *trace.borrow_mut());
59        }
60    }
61
62    pub fn error<Error>(&self, error: Error) -> NodeResult
63    where
64        Error: Into<Box<dyn std::error::Error>>,
65    {
66        Err(self.make_error(error))
67    }
68
69    pub fn success(&self, output: Variable) -> NodeResult {
70        Ok(NodeResponse {
71            output,
72            trace_data: self.trace.as_ref().map(|v| (*v.borrow()).to_variable()),
73        })
74    }
75
76    pub(crate) fn make_error<Error>(&self, error: Error) -> NodeError
77    where
78        Error: Into<Box<dyn std::error::Error>>,
79    {
80        NodeError {
81            node_id: self.id.clone(),
82            trace: self.trace.as_ref().map(|v| (*v.borrow()).to_variable()),
83            source: error.into(),
84        }
85    }
86
87    pub(crate) async fn function_runtime(&self) -> Result<&Function, NodeError> {
88        self.extensions.function_runtime().await.node_context(self)
89    }
90
91    pub fn validate(&self, schema: &Value, value: &Value) -> Result<(), NodeError> {
92        let validator_cache = self.extensions.validator_cache();
93        let hash = self.hash_node();
94
95        let validator = validator_cache
96            .get_or_insert(hash, schema)
97            .node_context(self)?;
98
99        validator
100            .validate(value)
101            .map_err(|err| ValidationErrorJson::from(err))
102            .node_context(self)?;
103
104        Ok(())
105    }
106
107    fn hash_node(&self) -> u64 {
108        let mut hasher = AHasher::default();
109        hasher.write(self.id.as_bytes());
110        hasher.write(self.name.as_bytes());
111        hasher.finish()
112    }
113}
114
115pub trait NodeContextExt<T, Context>: Sized {
116    type Error: Into<Box<dyn std::error::Error>>;
117
118    fn with_node_context<Function, NewError>(
119        self,
120        ctx: &Context,
121        f: Function,
122    ) -> Result<T, NodeError>
123    where
124        Function: FnOnce(Self::Error) -> NewError,
125        NewError: Into<Box<dyn std::error::Error>>;
126
127    fn node_context(self, ctx: &Context) -> Result<T, NodeError> {
128        self.with_node_context(ctx, |e| e.into())
129    }
130
131    fn node_context_message(self, ctx: &Context, message: &str) -> Result<T, NodeError> {
132        self.with_node_context(ctx, |err| format!("{}: {}", message, err.into()))
133    }
134}
135
136impl<T, E, NodeData, TraceData> NodeContextExt<T, NodeContext<NodeData, TraceData>> for Result<T, E>
137where
138    E: Into<Box<dyn std::error::Error>>,
139    NodeData: NodeDataType,
140    TraceData: TraceDataType,
141{
142    type Error = E;
143
144    fn with_node_context<Function, NewError>(
145        self,
146        ctx: &NodeContext<NodeData, TraceData>,
147        f: Function,
148    ) -> Result<T, NodeError>
149    where
150        Function: FnOnce(Self::Error) -> NewError,
151        NewError: Into<Box<dyn std::error::Error>>,
152    {
153        self.map_err(|err| ctx.make_error(f(err)))
154    }
155}
156
157impl<T, NodeData, TraceData> NodeContextExt<T, NodeContext<NodeData, TraceData>> for Option<T>
158where
159    NodeData: NodeDataType,
160    TraceData: TraceDataType,
161{
162    type Error = &'static str;
163
164    fn with_node_context<Function, NewError>(
165        self,
166        ctx: &NodeContext<NodeData, TraceData>,
167        f: Function,
168    ) -> Result<T, NodeError>
169    where
170        Function: FnOnce(Self::Error) -> NewError,
171        NewError: Into<Box<dyn std::error::Error>>,
172    {
173        self.ok_or_else(|| ctx.make_error(f("None")))
174    }
175
176    fn node_context_message(
177        self,
178        ctx: &NodeContext<NodeData, TraceData>,
179        message: &str,
180    ) -> Result<T, NodeError> {
181        self.with_node_context(ctx, |_| message.to_string())
182    }
183}
184
185#[derive(Clone)]
186pub struct NodeContextBase {
187    pub id: Arc<str>,
188    pub name: Arc<str>,
189    pub input: Variable,
190    pub iteration: u8,
191    pub extensions: NodeHandlerExtensions,
192    pub config: NodeContextConfig,
193    pub trace: Option<RefCell<Variable>>,
194}
195
196impl NodeContextBase {
197    pub fn error<Error>(&self, error: Error) -> NodeResult
198    where
199        Error: Into<Box<dyn std::error::Error>>,
200    {
201        Err(self.make_error(error))
202    }
203
204    pub fn success(&self, output: Variable) -> NodeResult {
205        Ok(NodeResponse {
206            output,
207            trace_data: self.trace.as_ref().map(|v| v.borrow().to_variable()),
208        })
209    }
210
211    fn make_error<Error>(&self, error: Error) -> NodeError
212    where
213        Error: Into<Box<dyn std::error::Error>>,
214    {
215        NodeError {
216            node_id: self.id.clone(),
217            trace: self.trace.as_ref().map(|t| t.borrow().to_variable()),
218            source: error.into(),
219        }
220    }
221
222    pub fn trace<Function>(&self, mutator: Function)
223    where
224        Function: FnOnce(&mut Variable),
225    {
226        if let Some(trace) = &self.trace {
227            mutator(&mut *trace.borrow_mut());
228        }
229    }
230}
231
232impl<NodeData, TraceData> From<NodeContext<NodeData, TraceData>> for NodeContextBase
233where
234    NodeData: NodeDataType,
235    TraceData: TraceDataType,
236{
237    fn from(value: NodeContext<NodeData, TraceData>) -> Self {
238        let trace = match value.config.trace {
239            true => Some(RefCell::new(Variable::Null)),
240            false => None,
241        };
242
243        Self {
244            id: value.id,
245            name: value.name,
246            input: value.input,
247            extensions: value.extensions,
248            iteration: value.iteration,
249            config: value.config,
250            trace,
251        }
252    }
253}
254
255impl<T, E> NodeContextExt<T, NodeContextBase> for Result<T, E>
256where
257    E: Into<Box<dyn std::error::Error>>,
258{
259    type Error = E;
260
261    fn with_node_context<Function, NewError>(
262        self,
263        ctx: &NodeContextBase,
264        f: Function,
265    ) -> Result<T, NodeError>
266    where
267        Function: FnOnce(Self::Error) -> NewError,
268        NewError: Into<Box<dyn std::error::Error>>,
269    {
270        self.map_err(|err| ctx.make_error(f(err)))
271    }
272}
273
274impl<T> NodeContextExt<T, NodeContextBase> for Option<T> {
275    type Error = &'static str;
276
277    fn with_node_context<Function, NewError>(
278        self,
279        ctx: &NodeContextBase,
280        f: Function,
281    ) -> Result<T, NodeError>
282    where
283        Function: FnOnce(Self::Error) -> NewError,
284        NewError: Into<Box<dyn std::error::Error>>,
285    {
286        self.ok_or_else(|| ctx.make_error(f("None")))
287    }
288
289    fn node_context_message(self, ctx: &NodeContextBase, message: &str) -> Result<T, NodeError> {
290        self.with_node_context(ctx, |_| message.to_string())
291    }
292}
293
294#[derive(Debug, Serialize, Error)]
295#[serde(rename_all = "camelCase")]
296struct ValidationErrorJson {
297    path: String,
298    message: String,
299}
300
301impl Display for ValidationErrorJson {
302    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
303        write!(f, "{}: {}", self.path, self.message)
304    }
305}
306
307impl<'a> From<ValidationError<'a>> for ValidationErrorJson {
308    fn from(value: ValidationError<'a>) -> Self {
309        ValidationErrorJson {
310            path: value.instance_path.to_string(),
311            message: format!("{}", value),
312        }
313    }
314}
315
316#[derive(Clone)]
317pub struct NodeContextConfig {
318    pub trace: bool,
319    pub nodes_in_context: bool,
320    pub max_depth: u8,
321    pub function_timeout_millis: u64,
322    pub http_auth: bool,
323}
324
325impl Default for NodeContextConfig {
326    fn default() -> Self {
327        Self {
328            trace: false,
329            nodes_in_context: ZEN_CONFIG.nodes_in_context.load(Ordering::Relaxed),
330            function_timeout_millis: ZEN_CONFIG.function_timeout_millis.load(Ordering::Relaxed),
331            http_auth: ZEN_CONFIG.http_auth.load(Ordering::Relaxed),
332            max_depth: 5,
333        }
334    }
335}