wolf_derivation_graph/
lib.rs

1use std::{cell::RefCell, collections::HashMap, rc::Rc};
2use anyhow::{anyhow, Result};
3
4use wolf_graph::prelude::*;
5
/// Payload stored on every graph node: an optional value (either set directly
/// or stored after derivation) and a human-readable label.
struct NodeData<T> {
    /// Current value of the node; `None` until set or derived, and after clearing.
    value: Option<T>,
    /// Human-readable label; empty until set.
    label: String,
}

impl<T> NodeData<T> {
    /// Creates a payload with no value and an empty label.
    fn new() -> Self {
        let label = String::new();
        Self { value: None, label }
    }

    /// Borrows the current value, if there is one.
    fn value(&self) -> Option<&T> {
        Option::as_ref(&self.value)
    }

    /// Stores `value`, dropping any previous value.
    fn set_value(&mut self, value: T) {
        self.value = Some(value);
    }

    /// Removes the current value, if any.
    fn clear_value(&mut self) {
        self.value = None;
    }

    /// Borrows the label.
    fn label(&self) -> &str {
        self.label.as_str()
    }

    /// Replaces the label with anything convertible into a `String`.
    fn set_label(&mut self, label: impl Into<String>) {
        let label: String = label.into();
        self.label = label;
    }
}
39
40type NodeDataRef<T> = Rc<RefCell<NodeData<T>>>;
41type G<T> = Graph<(), NodeDataRef<T>, String>;
42type D<T> = DAG::<G<T>>;
43pub type OpInputs<T> = HashMap<String, T>;
44pub type Operation<T> = Rc<dyn Fn(OpInputs<T>) -> Result<T>>;
45
46pub struct Deriver<T: Clone + 'static> {
47    node: NodeID,
48    op: Operation<T>,
49}
50
51impl<T: Eq + Clone + 'static> Deriver<T> {
52    pub fn new(node: impl AsRef<NodeID>, op: Operation<T>) -> Self {
53        Self {
54            node: node.as_ref().clone(),
55            op,
56        }
57    }
58
59    pub fn node(&self) -> &NodeID {
60        &self.node
61    }
62
63    pub fn op(&self) -> &Operation<T> {
64        &self.op
65    }
66}
67
68pub struct DerivationGraph<T: Eq + Clone + 'static>
69{
70    graph: D<T>,
71    derivations: HashMap<NodeID, Deriver<T>>,
72}
73
74impl<T: Eq + Clone + 'static> DerivationGraph<T> {
75    pub fn new() -> Self {
76        Self {
77            graph: D::new(),
78            derivations: HashMap::new(),
79        }
80    }
81
82    pub fn set_node_label(&mut self, id: impl AsRef<NodeID>, label: impl Into<String>) -> Result<()> {
83        self.graph.node_data(id)?.borrow_mut().set_label(label);
84        Ok(())
85    }
86
87    pub fn node_label(&self, id: impl AsRef<NodeID>) -> Result<String> {
88        Ok(self.graph.node_data(id)?.borrow().label().to_string())
89    }
90
91    pub fn set_edge_label(&mut self, id: impl AsRef<EdgeID>, label: impl Into<String>) -> Result<()> {
92        self.graph.set_edge_data(id, label.into())
93    }
94
95    pub fn edge_label(&self, id: impl AsRef<EdgeID>) -> Result<String> {
96        Ok(self.graph.edge_data(id)?.to_string())
97    }
98
99    pub fn add_node(
100        &mut self,
101        id: impl AsRef<NodeID>,
102        label: impl Into<String>,
103    ) -> Result<()> {
104        self.graph.add_node_with_data(&id, Rc::new(RefCell::new(NodeData::new())))
105            .map(|_| self.set_node_label(&id, label))
106            .map(|_| ())
107    }
108
109    pub fn add_node_with_value(
110        &mut self,
111        id: impl AsRef<NodeID>,
112        label: impl Into<String>,
113        value: impl Into<T>,
114    ) -> Result<()> {
115        let id = id.as_ref();
116        self.add_node(id, label)?;
117        self.set_node_value(id, value)
118    }
119
120    pub fn add_node_with_operation(
121        &mut self,
122        id: impl AsRef<NodeID>,
123        label: impl Into<String>,
124        op: Operation<T>,
125    ) -> Result<()> {
126        let id = id.as_ref();
127        self.add_node(id, label)?;
128        self.set_operation(id, op)
129    }
130
131    pub fn set_operation(
132        &mut self,
133        node: impl AsRef<NodeID>,
134        op: Operation<T>,
135    ) -> Result<()> {
136        let node = node.as_ref();
137        let deriver = Deriver::new(node, op);
138        self.derivations.insert(node.clone(), deriver);
139        self.clear_node_values_transitively(node)?;
140        Ok(())
141    }
142
143    pub fn add_edge(
144        &mut self,
145        id: impl AsRef<EdgeID>,
146        source: impl AsRef<NodeID>,
147        target: impl AsRef<NodeID>,
148        label: impl Into<String>,
149    ) -> Result<()> {
150        self.graph.add_edge_with_data(id, source, target, label.into())
151    }
152
153    pub fn node_value(&self, id: impl AsRef<NodeID>) -> Result<Option<T>> {
154        Ok(self.graph.node_data(id)?.borrow().value().cloned())
155    }
156
157    pub fn has_node_value(&self, id: impl AsRef<NodeID>) -> Result<bool> {
158        Ok(self.node_value(id)?.is_some())
159    }
160
161    pub fn set_node_value(&self, id: impl AsRef<NodeID>, value: impl Into<T>) -> Result<()> {
162        let id = id.as_ref();
163        let value = value.into();
164
165        // Do nothing if the node already has the same data.
166        if let Some(existing_value) = self.node_value(id)? {
167            if existing_value == value {
168                return Ok(());
169            }
170        }
171
172        // Clear all the transitive downstream nodes of the node.
173        self.clear_node_values_transitively(id)?;
174
175        // Set the data of the node.
176        let node_data = self.graph.node_data(id)?;
177        node_data.borrow_mut().set_value(value);
178
179        Ok(())
180    }
181
182    pub fn source(&self, id: impl AsRef<EdgeID>) -> Result<NodeID> {
183        Ok(self.graph.source(id)?.clone())
184    }
185
186    pub fn target(&self, id: impl AsRef<EdgeID>) -> Result<NodeID> {
187        Ok(self.graph.target(id)?.clone())
188    }
189
190    pub fn derived_node_value(&self, id: impl AsRef<NodeID>) -> Result<T> {
191        let id = id.as_ref();
192
193        // If the node already has data, return it.
194        if let Some(data) = self.node_value(id)? {
195            return Ok(data.clone());
196        }
197
198        // If the node has no deriver, return an error.
199        let deriver = self.derivations.get(id).ok_or_else(|| anyhow!("no deriver for node '{id}'"))?;
200
201        // Accumulate the derived data from the source of each in_edge into `inputs`.
202        let mut inputs = HashMap::new();
203        let in_edges = self.graph.in_edges(id)?;
204        for in_edge in in_edges {
205            let source = self.source(&in_edge)?;
206            let data = self.derived_node_value(source)?;
207            let label = self.graph.edge_data(&in_edge)?.into_owned();
208            inputs.insert(label, data);
209        }
210
211        let op = deriver.op();
212        let data = op(inputs)?;
213        self.set_node_value(id, data.clone())?;
214        Ok(data)
215    }
216
217    pub fn clear_all_node_values(&self) {
218        for node in self.graph.all_nodes() {
219            self.clear_node_value(&node).unwrap();
220        }
221    }
222}
223
224// Private methods.
225impl<T: Eq + Clone + 'static> DerivationGraph<T> {
226    // Clears all downstream nodes of the given node transitively, stopping at
227    // nodes that are already clear.
228    fn clear_node_values_transitively(&self, id: impl AsRef<NodeID>) -> Result<()> {
229        let mut stack = vec![id.as_ref().clone()];
230        while let Some(id) = stack.pop() {
231            if self.clear_node_value(&id)? {
232                for successor in self.graph.successors(&id)? {
233                    if self.has_node_value(&successor)? {
234                        stack.push(successor);
235                    }
236                }
237            }
238        }
239        Ok(())
240    }
241
242    // Returns true if the node had data to clear.
243    fn clear_node_value(&self, id: impl AsRef<NodeID>) -> Result<bool> {
244        let id = id.as_ref();
245        let value = self.node_value(id)?;
246        let had_value = value.is_some();
247        if had_value {
248            self.graph.node_data(id)?.borrow_mut().clear_value();
249        }
250        Ok(had_value)
251    }
252}
253
254impl<T: Eq + Clone + 'static> Default for DerivationGraph<T> {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
#[cfg(test)]
mod tests {
    // No unit tests have been written for this module yet.
}