1use crate::handler::custom_node_adapter::{CustomNodeAdapter, CustomNodeRequest};
2use crate::handler::decision::DecisionHandler;
3use crate::handler::expression::ExpressionHandler;
4use crate::handler::function::function::{Function, FunctionConfig};
5use crate::handler::function::module::console::ConsoleListener;
6use crate::handler::function::module::zen::ZenListener;
7use crate::handler::function::FunctionHandler;
8use crate::handler::function_v1;
9use crate::handler::function_v1::runtime::create_runtime;
10use crate::handler::node::{NodeRequest, PartialTraceError};
11use crate::handler::table::zen::DecisionTableHandler;
12use crate::handler::traversal::{GraphWalker, StableDiDecisionGraph};
13use crate::loader::DecisionLoader;
14use crate::model::{DecisionContent, DecisionNodeKind, FunctionNodeContent};
15use crate::util::validator_cache::ValidatorCache;
16use crate::{EvaluationError, NodeError};
17use ahash::{HashMap, HashMapExt};
18use anyhow::anyhow;
19use petgraph::algo::is_cyclic_directed;
20use serde::ser::SerializeMap;
21use serde::{Deserialize, Serialize, Serializer};
22use serde_json::Value;
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::rc::Rc;
25use std::sync::Arc;
26use std::time::Instant;
27use thiserror::Error;
28use zen_expression::variable::Variable;
29
30pub struct DecisionGraph<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static> {
31 initial_graph: StableDiDecisionGraph,
32 graph: StableDiDecisionGraph,
33 adapter: Arc<A>,
34 loader: Arc<L>,
35 trace: bool,
36 max_depth: u8,
37 iteration: u8,
38 runtime: Option<Rc<Function>>,
39 validator_cache: ValidatorCache,
40}
41
42pub struct DecisionGraphConfig<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static> {
43 pub loader: Arc<L>,
44 pub adapter: Arc<A>,
45 pub content: Arc<DecisionContent>,
46 pub trace: bool,
47 pub iteration: u8,
48 pub max_depth: u8,
49 pub validator_cache: Option<ValidatorCache>,
50}
51
52impl<L: DecisionLoader + 'static, A: CustomNodeAdapter + 'static> DecisionGraph<L, A> {
53 pub fn try_new(
54 config: DecisionGraphConfig<L, A>,
55 ) -> Result<Self, DecisionGraphValidationError> {
56 let content = config.content;
57 let mut graph = StableDiDecisionGraph::new();
58 let mut index_map = HashMap::new();
59
60 for node in &content.nodes {
61 let node_id = node.id.clone();
62 let node_index = graph.add_node(node.clone());
63
64 index_map.insert(node_id, node_index);
65 }
66
67 for (_, edge) in content.edges.iter().enumerate() {
68 let source_index = index_map.get(&edge.source_id).ok_or_else(|| {
69 DecisionGraphValidationError::MissingNode(edge.source_id.to_string())
70 })?;
71
72 let target_index = index_map.get(&edge.target_id).ok_or_else(|| {
73 DecisionGraphValidationError::MissingNode(edge.target_id.to_string())
74 })?;
75
76 graph.add_edge(source_index.clone(), target_index.clone(), edge.clone());
77 }
78
79 Ok(Self {
80 initial_graph: graph.clone(),
81 graph,
82 iteration: config.iteration,
83 trace: config.trace,
84 loader: config.loader,
85 adapter: config.adapter,
86 max_depth: config.max_depth,
87 validator_cache: config.validator_cache.unwrap_or_default(),
88 runtime: None,
89 })
90 }
91
92 pub(crate) fn with_function(mut self, runtime: Option<Rc<Function>>) -> Self {
93 self.runtime = runtime;
94 self
95 }
96
97 pub(crate) fn reset_graph(&mut self) {
98 self.graph = self.initial_graph.clone();
99 }
100
101 async fn get_or_insert_function(&mut self) -> anyhow::Result<Rc<Function>> {
102 if let Some(function) = &self.runtime {
103 return Ok(function.clone());
104 }
105
106 let function = Function::create(FunctionConfig {
107 listeners: Some(vec![
108 Box::new(ConsoleListener),
109 Box::new(ZenListener {
110 loader: self.loader.clone(),
111 adapter: self.adapter.clone(),
112 }),
113 ]),
114 })
115 .await
116 .map_err(|err| anyhow!(err.to_string()))?;
117 let rc_function = Rc::new(function);
118 self.runtime.replace(rc_function.clone());
119
120 Ok(rc_function)
121 }
122
123 pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
124 let input_count = self.input_node_count();
125 if input_count != 1 {
126 return Err(DecisionGraphValidationError::InvalidInputCount(
127 input_count as u32,
128 ));
129 }
130
131 if is_cyclic_directed(&self.graph) {
132 return Err(DecisionGraphValidationError::CyclicGraph);
133 }
134
135 Ok(())
136 }
137
138 fn input_node_count(&self) -> usize {
139 self.graph
140 .node_weights()
141 .filter(|weight| matches!(weight.kind, DecisionNodeKind::InputNode { content: _ }))
142 .count()
143 }
144
145 pub async fn evaluate(
146 &mut self,
147 context: Variable,
148 ) -> Result<DecisionGraphResponse, NodeError> {
149 let root_start = Instant::now();
150
151 self.validate().map_err(|e| NodeError {
152 node_id: "".to_string(),
153 source: anyhow!(e),
154 trace: None,
155 })?;
156
157 if self.iteration >= self.max_depth {
158 return Err(NodeError {
159 node_id: "".to_string(),
160 source: anyhow!(EvaluationError::DepthLimitExceeded),
161 trace: None,
162 });
163 }
164
165 let mut walker = GraphWalker::new(&self.graph);
166 let mut node_traces = self.trace.then(|| HashMap::default());
167
168 while let Some(nid) = walker.next(
169 &mut self.graph,
170 self.trace.then_some(|mut trace: DecisionGraphTrace| {
171 if let Some(nt) = &mut node_traces {
172 trace.order = nt.len() as u32;
173 nt.insert(trace.id.clone(), trace);
174 };
175 }),
176 ) {
177 if let Some(_) = walker.get_node_data(nid) {
178 continue;
179 }
180
181 let node = (&self.graph[nid]).clone();
182 let start = Instant::now();
183
184 macro_rules! trace {
185 ({ $($field:ident: $value:expr),* $(,)? }) => {
186 if let Some(nt) = &mut node_traces {
187 nt.insert(
188 node.id.clone(),
189 DecisionGraphTrace {
190 name: node.name.clone(),
191 id: node.id.clone(),
192 performance: Some(format!("{:.1?}", start.elapsed())),
193 order: nt.len() as u32,
194 $($field: $value,)*
195 }
196 );
197 }
198 };
199 }
200
201 match &node.kind {
202 DecisionNodeKind::InputNode { content } => {
203 trace!({
204 input: Variable::Null,
205 output: context.clone(),
206 trace_data: None,
207 });
208
209 if let Some(json_schema) = content
210 .schema
211 .as_ref()
212 .map(|s| serde_json::from_str::<Value>(&s).ok())
213 .flatten()
214 {
215 let validator_key = create_validator_cache_key(&json_schema);
216 let validator = self
217 .validator_cache
218 .get_or_insert(validator_key, &json_schema)
219 .await
220 .map_err(|e| NodeError {
221 source: e.into(),
222 node_id: node.id.clone(),
223 trace: error_trace(&node_traces),
224 })?;
225
226 let context_json = context.to_value();
227 validator.validate(&context_json).map_err(|e| NodeError {
228 source: anyhow!(serde_json::to_value(
229 Into::<Box<EvaluationError>>::into(e)
230 )
231 .unwrap_or_default()),
232 node_id: node.id.clone(),
233 trace: error_trace(&node_traces),
234 })?;
235 }
236
237 walker.set_node_data(nid, context.clone());
238 }
239 DecisionNodeKind::OutputNode { content } => {
240 let incoming_data = walker.incoming_node_data(&self.graph, nid, false);
241
242 trace!({
243 input: incoming_data.clone(),
244 output: Variable::Null,
245 trace_data: None,
246 });
247
248 if let Some(json_schema) = content
249 .schema
250 .as_ref()
251 .map(|s| serde_json::from_str::<Value>(&s).ok())
252 .flatten()
253 {
254 let validator_key = create_validator_cache_key(&json_schema);
255 let validator = self
256 .validator_cache
257 .get_or_insert(validator_key, &json_schema)
258 .await
259 .map_err(|e| NodeError {
260 source: e.into(),
261 node_id: node.id.clone(),
262 trace: error_trace(&node_traces),
263 })?;
264
265 let incoming_data_json = incoming_data.to_value();
266 validator
267 .validate(&incoming_data_json)
268 .map_err(|e| NodeError {
269 source: anyhow!(serde_json::to_value(
270 Into::<Box<EvaluationError>>::into(e)
271 )
272 .unwrap_or_default()),
273 node_id: node.id.clone(),
274 trace: error_trace(&node_traces),
275 })?;
276 }
277
278 return Ok(DecisionGraphResponse {
279 result: incoming_data,
280 performance: format!("{:.1?}", root_start.elapsed()),
281 trace: node_traces,
282 });
283 }
284 DecisionNodeKind::SwitchNode { .. } => {
285 let input_data = walker.incoming_node_data(&self.graph, nid, false);
286
287 walker.set_node_data(nid, input_data);
288 }
289 DecisionNodeKind::FunctionNode { content } => {
290 let function = self.get_or_insert_function().await.map_err(|e| NodeError {
291 source: e.into(),
292 node_id: node.id.clone(),
293 trace: error_trace(&node_traces),
294 })?;
295
296 let node_request = NodeRequest {
297 node: node.clone(),
298 iteration: self.iteration,
299 input: walker.incoming_node_data(&self.graph, nid, true),
300 };
301 let res = match content {
302 FunctionNodeContent::Version2(_) => FunctionHandler::new(
303 function,
304 self.trace,
305 self.iteration,
306 self.max_depth,
307 )
308 .handle(node_request.clone())
309 .await
310 .map_err(|e| {
311 if let Some(detailed_err) = e.downcast_ref::<PartialTraceError>() {
312 trace!({
313 input: node_request.input.clone(),
314 output: Variable::Null,
315 trace_data: detailed_err.trace.clone(),
316 });
317 }
318
319 NodeError {
320 source: e.into(),
321 node_id: node.id.clone(),
322 trace: error_trace(&node_traces),
323 }
324 })?,
325 FunctionNodeContent::Version1(_) => {
326 let runtime = create_runtime().map_err(|e| NodeError {
327 source: e.into(),
328 node_id: node.id.clone(),
329 trace: error_trace(&node_traces),
330 })?;
331
332 function_v1::FunctionHandler::new(self.trace, runtime)
333 .handle(node_request.clone())
334 .await
335 .map_err(|e| NodeError {
336 source: e.into(),
337 node_id: node.id.clone(),
338 trace: error_trace(&node_traces),
339 })?
340 }
341 };
342
343 node_request.input.dot_remove("$nodes");
344 res.output.dot_remove("$nodes");
345
346 trace!({
347 input: node_request.input,
348 output: res.output.clone(),
349 trace_data: res.trace_data,
350 });
351 walker.set_node_data(nid, res.output);
352 }
353 DecisionNodeKind::DecisionNode { .. } => {
354 let node_request = NodeRequest {
355 node: node.clone(),
356 iteration: self.iteration,
357 input: walker.incoming_node_data(&self.graph, nid, true),
358 };
359
360 let res = DecisionHandler::new(
361 self.trace,
362 self.max_depth,
363 self.loader.clone(),
364 self.adapter.clone(),
365 self.runtime.clone(),
366 self.validator_cache.clone(),
367 )
368 .handle(node_request.clone())
369 .await
370 .map_err(|e| NodeError {
371 source: e.into(),
372 node_id: node.id.to_string(),
373 trace: error_trace(&node_traces),
374 })?;
375
376 node_request.input.dot_remove("$nodes");
377 res.output.dot_remove("$nodes");
378
379 trace!({
380 input: node_request.input,
381 output: res.output.clone(),
382 trace_data: res.trace_data,
383 });
384 walker.set_node_data(nid, res.output);
385 }
386 DecisionNodeKind::DecisionTableNode { .. } => {
387 let node_request = NodeRequest {
388 node: node.clone(),
389 iteration: self.iteration,
390 input: walker.incoming_node_data(&self.graph, nid, true),
391 };
392
393 let res = DecisionTableHandler::new(self.trace)
394 .handle(node_request.clone())
395 .await
396 .map_err(|e| NodeError {
397 node_id: node.id.clone(),
398 source: e.into(),
399 trace: error_trace(&node_traces),
400 })?;
401
402 node_request.input.dot_remove("$nodes");
403 res.output.dot_remove("$nodes");
404 res.output.dot_remove("$");
405
406 trace!({
407 input: node_request.input,
408 output: res.output.clone(),
409 trace_data: res.trace_data,
410 });
411 walker.set_node_data(nid, res.output);
412 }
413 DecisionNodeKind::ExpressionNode { .. } => {
414 let node_request = NodeRequest {
415 node: node.clone(),
416 iteration: self.iteration,
417 input: walker.incoming_node_data(&self.graph, nid, true),
418 };
419
420 let res = ExpressionHandler::new(self.trace)
421 .handle(node_request.clone())
422 .await
423 .map_err(|e| {
424 if let Some(detailed_err) = e.downcast_ref::<PartialTraceError>() {
425 trace!({
426 input: node_request.input.clone(),
427 output: Variable::Null,
428 trace_data: detailed_err.trace.clone(),
429 });
430 }
431
432 NodeError {
433 node_id: node.id.clone(),
434 source: e.into(),
435 trace: error_trace(&node_traces),
436 }
437 })?;
438
439 node_request.input.dot_remove("$nodes");
440 res.output.dot_remove("$nodes");
441
442 trace!({
443 input: node_request.input,
444 output: res.output.clone(),
445 trace_data: res.trace_data,
446 });
447 walker.set_node_data(nid, res.output);
448 }
449 DecisionNodeKind::CustomNode { .. } => {
450 let node_request = NodeRequest {
451 node: node.clone(),
452 iteration: self.iteration,
453 input: walker.incoming_node_data(&self.graph, nid, true),
454 };
455
456 let res = self
457 .adapter
458 .handle(CustomNodeRequest::try_from(node_request.clone()).unwrap())
459 .await
460 .map_err(|e| NodeError {
461 node_id: node.id.clone(),
462 source: e.into(),
463 trace: error_trace(&node_traces),
464 })?;
465
466 node_request.input.dot_remove("$nodes");
467 res.output.dot_remove("$nodes");
468
469 trace!({
470 input: node_request.input,
471 output: res.output.clone(),
472 trace_data: res.trace_data,
473 });
474 walker.set_node_data(nid, res.output);
475 }
476 }
477 }
478
479 Ok(DecisionGraphResponse {
480 result: walker.ending_variables(&self.graph),
481 performance: format!("{:.1?}", root_start.elapsed()),
482 trace: node_traces,
483 })
484 }
485}
486
487#[derive(Debug, Error)]
488pub enum DecisionGraphValidationError {
489 #[error("Invalid input node count: {0}")]
490 InvalidInputCount(u32),
491
492 #[error("Invalid output node count: {0}")]
493 InvalidOutputCount(u32),
494
495 #[error("Cyclic graph detected")]
496 CyclicGraph,
497
498 #[error("Missing node")]
499 MissingNode(String),
500}
501
502impl Serialize for DecisionGraphValidationError {
503 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
504 where
505 S: Serializer,
506 {
507 let mut map = serializer.serialize_map(None)?;
508
509 match &self {
510 DecisionGraphValidationError::InvalidInputCount(count) => {
511 map.serialize_entry("type", "invalidInputCount")?;
512 map.serialize_entry("nodeCount", count)?;
513 }
514 DecisionGraphValidationError::InvalidOutputCount(count) => {
515 map.serialize_entry("type", "invalidOutputCount")?;
516 map.serialize_entry("nodeCount", count)?;
517 }
518 DecisionGraphValidationError::MissingNode(node_id) => {
519 map.serialize_entry("type", "missingNode")?;
520 map.serialize_entry("nodeId", node_id)?;
521 }
522 DecisionGraphValidationError::CyclicGraph => {
523 map.serialize_entry("type", "cyclicGraph")?;
524 }
525 }
526
527 map.end()
528 }
529}
530
/// Result of evaluating a decision graph.
///
/// Serialized with camelCase keys; field declaration order determines the
/// key order of the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphResponse {
    /// Total evaluation time, formatted with `{:.1?}`.
    pub performance: String,
    /// Data produced by the evaluation.
    pub result: Variable,
    /// Per-node traces keyed by node id; present only when tracing was enabled
    /// and omitted from the serialized output otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trace: Option<HashMap<String, DecisionGraphTrace>>,
}
539
/// Trace entry describing the evaluation of a single node.
///
/// Serialized with camelCase keys; field declaration order determines the
/// key order of the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphTrace {
    /// Data passed into the node, as recorded by the tracer.
    pub input: Variable,
    /// Data produced by the node, as recorded by the tracer.
    pub output: Variable,
    /// Display name of the node.
    pub name: String,
    /// Id of the node; also used as the key in the response's trace map.
    pub id: String,
    /// Formatted elapsed time spent on the node, if recorded.
    pub performance: Option<String>,
    /// Handler-specific trace details, if any.
    pub trace_data: Option<Value>,
    /// Position of this entry in the trace map at the time it was inserted.
    pub order: u32,
}
551
552pub(crate) fn error_trace(trace: &Option<HashMap<String, DecisionGraphTrace>>) -> Option<Value> {
553 trace
554 .as_ref()
555 .map(|s| serde_json::to_value(s).ok())
556 .flatten()
557}
558
559fn create_validator_cache_key(content: &Value) -> u64 {
560 let mut hasher = DefaultHasher::new();
561 content.hash(&mut hasher);
562 hasher.finish()
563}