use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use anyhow::anyhow;
use petgraph::algo::is_cyclic_directed;
use petgraph::graph::NodeIndex;
use petgraph::prelude::DiGraph;
use petgraph::visit::Topo;
use petgraph::Direction;
use serde::ser::SerializeMap;
use serde::{Deserialize, Serialize, Serializer};
use serde_json::{Map, Value};
use thiserror::Error;
use crate::handler::decision::DecisionHandler;
use crate::handler::expression::ExpressionHandler;
use crate::handler::function::FunctionHandler;
use crate::handler::node::NodeRequest;
use crate::handler::table::zen::DecisionTableHandler;
use crate::loader::DecisionLoader;
use crate::model::{DecisionContent, DecisionNode, DecisionNodeKind};
use crate::{EvaluationError, NodeError};
/// Directed-graph view over the nodes of a `DecisionContent`, ready to be
/// validated and evaluated.
///
/// Nodes are borrowed for `'a`; `L` is the loader type handed to nested
/// decision nodes during evaluation.
pub struct DecisionGraph<'a, L: DecisionLoader> {
/// Borrowed nodes as graph vertices; each edge weight is the position of
/// that edge in `content.edges` (assigned by `try_new`).
graph: DiGraph<&'a DecisionNode, usize>,
/// Shared loader, cloned into the `DecisionHandler` for decision nodes.
loader: Arc<L>,
/// When `true`, `evaluate` records one trace entry per visited node.
trace: bool,
/// `evaluate` fails with `DepthLimitExceeded` once `iteration` reaches this value.
max_depth: u8,
/// Iteration counter compared against `max_depth` and forwarded in every `NodeRequest`.
iteration: u8,
}
/// Construction parameters consumed by `DecisionGraph::try_new`.
pub struct DecisionGraphConfig<'a, T: DecisionLoader> {
/// Loader stored on the graph and passed on to decision-node handlers.
pub loader: Arc<T>,
/// Source of the nodes and edges; the graph borrows its nodes for `'a`.
pub content: &'a DecisionContent,
/// Enables per-node trace collection during evaluation.
pub trace: bool,
/// Iteration counter copied onto the graph (checked against `max_depth`).
pub iteration: u8,
/// Upper bound for `iteration`; evaluation is refused when it is reached.
pub max_depth: u8,
}
impl<'a, L: DecisionLoader> DecisionGraph<'a, L> {
pub fn try_new(
config: DecisionGraphConfig<'a, L>,
) -> Result<Self, DecisionGraphValidationError> {
let content = config.content;
let mut graph = DiGraph::new();
let mut index_map = HashMap::new();
for node in &content.nodes {
let node_id = node.id.clone();
let node_index = graph.add_node(node);
index_map.insert(node_id, node_index);
}
for (weight, edge) in content.edges.iter().enumerate() {
let source_index = index_map.get(&edge.source_id).ok_or_else(|| {
DecisionGraphValidationError::MissingNode(edge.source_id.to_string())
})?;
let target_index = index_map.get(&edge.target_id).ok_or_else(|| {
DecisionGraphValidationError::MissingNode(edge.target_id.to_string())
})?;
graph.add_edge(source_index.clone(), target_index.clone(), weight);
}
Ok(Self {
graph,
iteration: config.iteration,
trace: config.trace,
loader: config.loader.clone(),
max_depth: config.max_depth,
})
}
/// Checks the structural invariants required before evaluation.
///
/// # Errors
///
/// * `InvalidInputCount` when the graph does not have exactly one input node.
/// * `InvalidOutputCount` when the graph has no output node.
/// * `CyclicGraph` when the graph contains a directed cycle.
pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
    match self.node_kind_count(DecisionNodeKind::InputNode) {
        1 => {}
        inputs => {
            return Err(DecisionGraphValidationError::InvalidInputCount(
                inputs as u32,
            ))
        }
    }

    let outputs = self.node_kind_count(DecisionNodeKind::OutputNode);
    if outputs == 0 {
        return Err(DecisionGraphValidationError::InvalidOutputCount(
            outputs as u32,
        ));
    }

    if is_cyclic_directed(&self.graph) {
        Err(DecisionGraphValidationError::CyclicGraph)
    } else {
        Ok(())
    }
}
/// Returns how many nodes in the graph have a kind equal to `kind`.
fn node_kind_count(&self, kind: DecisionNodeKind) -> usize {
    let mut total = 0;
    for raw_node in self.graph.raw_nodes() {
        if raw_node.weight.kind == kind {
            total += 1;
        }
    }
    total
}
/// Collects the nodes that have an edge pointing at `node_id`, in the order
/// the graph reports its incoming neighbors.
fn incoming_nodes(&self, node_id: NodeIndex) -> Vec<&DecisionNode> {
    let mut parents = Vec::new();
    for parent_index in self
        .graph
        .neighbors_directed(node_id, Direction::Incoming)
    {
        parents.push(self.graph[parent_index]);
    }
    parents
}
/// Evaluates the graph against `context` and returns the data that reaches
/// the first output node visited.
///
/// Nodes are visited in topological order. For each node the outputs of its
/// incoming neighbors are merged (see `merge_json`) to form the node's input:
///
/// * the input node stores a clone of `context` as its output;
/// * function, decision, decision-table and expression nodes are delegated to
///   their handler and the handler's output is stored;
/// * the first output node reached ends evaluation, returning its merged
///   input as the result.
///
/// When tracing is enabled, one `DecisionGraphTrace` per visited node is
/// collected and returned, keyed by node id.
///
/// # Errors
///
/// Returns a `NodeError` with an empty `node_id` when validation fails, when
/// `iteration >= max_depth`, or when the walk ends without reaching an output
/// node. Handler failures are returned with the failing node's id.
pub async fn evaluate(&self, context: &Value) -> Result<DecisionGraphResponse, NodeError> {
    let root_start = Instant::now();

    self.validate().map_err(|e| NodeError {
        node_id: String::new(),
        source: anyhow!(e),
    })?;

    if self.iteration >= self.max_depth {
        return Err(NodeError {
            node_id: String::new(),
            source: anyhow!(EvaluationError::DepthLimitExceeded),
        });
    }

    let mut topo = Topo::new(&self.graph);
    let mut node_data = HashMap::<&str, Value>::default();
    let mut node_traces = self.trace.then(HashMap::new);
    // Stand-in patch for a parent that has not produced any data.
    let default_patch = Value::Object(Map::new());

    while let Some(nid) = topo.next(&self.graph) {
        let node = self.graph[nid];
        let start = Instant::now();

        // Records a trace entry for the current node; the struct literal (and
        // any clone inside it) is only evaluated when tracing is enabled.
        macro_rules! trace {
            ($data: tt) => {
                if let Some(nt) = &mut node_traces {
                    nt.insert(node.id.clone(), DecisionGraphTrace $data);
                };
            };
        }

        // Merge the outputs of all parents into a single input object.
        let mut incoming_data = Value::Object(Map::new());
        for parent in self.incoming_nodes(nid) {
            let data = node_data
                .get(parent.id.as_str())
                .unwrap_or(&default_patch);
            merge_json(&mut incoming_data, data, true);
        }

        let node_request = NodeRequest {
            node,
            iteration: self.iteration,
            input: incoming_data,
        };

        match node.kind {
            DecisionNodeKind::InputNode => {
                node_data.insert(&node.id, context.clone());
                trace!({
                    input: Value::Null,
                    output: Value::Null,
                    name: node.name.clone(),
                    id: node.id.clone(),
                    performance: None,
                    trace_data: None,
                });
            }
            DecisionNodeKind::OutputNode => {
                trace!({
                    input: Value::Null,
                    output: Value::Null,
                    name: node.name.clone(),
                    id: node.id.clone(),
                    performance: None,
                    trace_data: None,
                });

                return Ok(DecisionGraphResponse {
                    result: node_request.input,
                    performance: format!("{:?}", root_start.elapsed()),
                    trace: node_traces,
                });
            }
            DecisionNodeKind::FunctionNode { .. } => {
                let res = FunctionHandler::new(self.trace)
                    .handle(&node_request)
                    .await
                    .map_err(|e| NodeError {
                        source: e.into(),
                        node_id: node.id.clone(),
                    })?;

                // Clone the output only for the trace; the original is moved
                // into `node_data` afterwards.
                trace!({
                    input: node_request.input,
                    output: res.output.clone(),
                    name: node.name.clone(),
                    id: node.id.clone(),
                    performance: Some(format!("{:?}", start.elapsed())),
                    trace_data: res.trace_data,
                });
                node_data.insert(&node.id, res.output);
            }
            DecisionNodeKind::DecisionNode { .. } => {
                let res = DecisionHandler::new(self.trace, self.max_depth, self.loader.clone())
                    .handle(&node_request)
                    .await
                    .map_err(|e| NodeError {
                        source: e.into(),
                        node_id: node.id.to_string(),
                    })?;

                trace!({
                    input: node_request.input,
                    output: res.output.clone(),
                    name: node.name.clone(),
                    id: node.id.clone(),
                    performance: Some(format!("{:?}", start.elapsed())),
                    trace_data: res.trace_data,
                });
                node_data.insert(&node.id, res.output);
            }
            DecisionNodeKind::DecisionTableNode { .. } => {
                let res = DecisionTableHandler::new(self.trace)
                    .handle(&node_request)
                    .await
                    .map_err(|e| NodeError {
                        node_id: node.id.clone(),
                        source: e.into(),
                    })?;

                trace!({
                    input: node_request.input,
                    output: res.output.clone(),
                    name: node.name.clone(),
                    id: node.id.clone(),
                    performance: Some(format!("{:?}", start.elapsed())),
                    trace_data: res.trace_data,
                });
                node_data.insert(&node.id, res.output);
            }
            DecisionNodeKind::ExpressionNode { .. } => {
                let res = ExpressionHandler::new(self.trace)
                    .handle(&node_request)
                    .await
                    .map_err(|e| NodeError {
                        node_id: node.id.clone(),
                        source: e.into(),
                    })?;

                trace!({
                    input: node_request.input,
                    output: res.output.clone(),
                    name: node.name.clone(),
                    id: node.id.clone(),
                    performance: Some(format!("{:?}", start.elapsed())),
                    trace_data: res.trace_data,
                });
                node_data.insert(&node.id, res.output);
            }
        }
    }

    // The walk finished without hitting an output node.
    Err(NodeError {
        node_id: String::new(),
        source: anyhow!("Graph did not halt. Missing output node."),
    })
}
}
/// Merges `patch` into `doc` in place.
///
/// * When `top_level` is `true`, a patch that is neither an object nor an
///   array is ignored and `doc` is left untouched.
/// * Object into object: keys are merged recursively; a `null` value in the
///   patch removes that key from `doc`.
/// * Array into array: the patch's elements are appended to `doc`.
/// * Any other combination: `doc` is replaced by a clone of `patch`.
fn merge_json(doc: &mut Value, patch: &Value, top_level: bool) {
    if top_level && !matches!(patch, Value::Object(_) | Value::Array(_)) {
        return;
    }

    match (doc, patch) {
        (Value::Object(target), Value::Object(entries)) => {
            for (key, value) in entries {
                if value.is_null() {
                    target.remove(key.as_str());
                } else {
                    // Missing keys start as `Null` so the recursive call
                    // falls through to the "replace" arm.
                    let slot = target.entry(key.as_str()).or_insert(Value::Null);
                    merge_json(slot, value, false);
                }
            }
        }
        (Value::Array(target), Value::Array(items)) => {
            target.extend(items.iter().cloned());
        }
        (doc, patch) => *doc = patch.clone(),
    }
}
#[derive(Debug, Error)]
pub enum DecisionGraphValidationError {
#[error("Invalid input node count: {0}")]
InvalidInputCount(u32),
#[error("Invalid output node count: {0}")]
InvalidOutputCount(u32),
#[error("Cyclic graph detected")]
CyclicGraph,
#[error("Missing node")]
MissingNode(String),
}
/// Serializes the error as a map with a `type` discriminator
/// (`invalidInputCount`, `invalidOutputCount`, `missingNode`, `cyclicGraph`)
/// plus the variant's payload under `nodeCount` or `nodeId`.
impl Serialize for DecisionGraphValidationError {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // The entry count is known per variant; passing it lets serializers
        // that need the map length up front handle this type.
        let entry_count = match self {
            DecisionGraphValidationError::CyclicGraph => 1,
            DecisionGraphValidationError::InvalidInputCount(_)
            | DecisionGraphValidationError::InvalidOutputCount(_)
            | DecisionGraphValidationError::MissingNode(_) => 2,
        };

        let mut map = serializer.serialize_map(Some(entry_count))?;
        match self {
            DecisionGraphValidationError::InvalidInputCount(count) => {
                map.serialize_entry("type", "invalidInputCount")?;
                map.serialize_entry("nodeCount", count)?;
            }
            DecisionGraphValidationError::InvalidOutputCount(count) => {
                map.serialize_entry("type", "invalidOutputCount")?;
                map.serialize_entry("nodeCount", count)?;
            }
            DecisionGraphValidationError::MissingNode(node_id) => {
                map.serialize_entry("type", "missingNode")?;
                map.serialize_entry("nodeId", node_id)?;
            }
            DecisionGraphValidationError::CyclicGraph => {
                map.serialize_entry("type", "cyclicGraph")?;
            }
        }
        map.end()
    }
}
/// Result of a successful `DecisionGraph::evaluate` call.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphResponse {
/// Total time spent in `evaluate`, as the `{:?}` rendering of the elapsed duration.
pub performance: String,
/// The merged data that reached the output node.
pub result: Value,
/// Per-node traces keyed by node id; `Some` only when tracing was enabled,
/// and omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub trace: Option<HashMap<String, DecisionGraphTrace>>,
}
/// Trace entry recorded for a single node visited during evaluation.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DecisionGraphTrace {
/// Merged input handed to the node's handler; `Value::Null` for input and output nodes.
input: Value,
/// Output produced by the node's handler; `Value::Null` for input and output nodes.
output: Value,
/// Copy of the node's name.
name: String,
/// Copy of the node's id.
id: String,
/// `{:?}` rendering of the time spent on the node; `None` for input and output nodes.
performance: Option<String>,
/// Extra trace payload returned by the handler; `None` for input and output nodes.
trace_data: Option<Value>,
}