wolf_graph/
tree.rs

1use std::borrow::Cow;
2
3use anyhow::{Result, bail};
4#[cfg(feature = "serde")]
5use serde::{ser::{Serialize, Serializer}, de::{self, Deserialize, Deserializer}};
6
7use crate::{nid, EdgeID, Edges, Error, IsTree, MutableForest, MutableGraph, MutableTree, NodeID, Nodes, PathExists, BlankGraph, TopologicalSort, VisitableForest, VisitableGraph, VisitableTree};
8
9/// A tree is a directed graph in which there is exactly one path between any two nodes.
10///
11/// A tree always has a root node, which is the only node that has no incoming edges.
12#[derive(Debug, Clone)]
13pub struct Tree<Inner>
14where
15    Inner: MutableGraph,
16{
17    root: NodeID,
18    graph: Inner,
19}
20
21/// A convenience type for a tree with no additional data on nodes or edges.
22pub type BlankTree = Tree<BlankGraph>;
23
24impl<Inner> Tree<Inner>
25where
26    Inner: MutableGraph,
27{
28    pub fn new_unchecked(root: NodeID, graph: Inner) -> Self {
29        Self { root, graph }
30    }
31
32    pub fn new_with_root_and_graph(root: NodeID, graph: Inner) -> Result<Self> {
33        graph.check_is_tree(&root)?;
34        Ok(Self::new_unchecked(root, graph))
35    }
36
37    pub fn graph(&self) -> &Inner {
38        &self.graph
39    }
40}
41
42impl<Inner> Tree<Inner>
43where
44    Inner: MutableGraph + Default + Clone,
45    Inner::NData: Default,
46{
47    pub fn new_with_root(root: NodeID) -> Self {
48        Self::new_unchecked(root, Inner::default().adding_node(&nid!("root")).unwrap())
49    }
50
51    pub fn new() -> Self {
52        Self::new_with_root(nid!("root"))
53    }
54}
55
56impl<Inner> Default for Tree<Inner>
57where
58    Inner: MutableGraph + Default + Clone,
59    Inner::NData: Default,
60{
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl<Inner> VisitableGraph for Tree<Inner>
67where
68    Inner: MutableGraph,
69{
70    type GData = Inner::GData;
71    type NData = Inner::NData;
72    type EData = Inner::EData;
73
74    fn data(&self) -> &Self::GData {
75        self.graph.data()
76    }
77
78    fn node_data(&self, id: impl AsRef<NodeID>) -> Result<Cow<'static, Self::NData>> {
79        self.graph.node_data(id)
80    }
81
82    fn edge_data(&self, id: impl AsRef<EdgeID>) -> Result<Cow<'static, Self::EData>> {
83        self.graph.edge_data(id)
84    }
85
86    fn is_empty(&self) -> bool {
87        self.graph.is_empty()
88    }
89
90    fn node_count(&self) -> usize {
91        self.graph.node_count()
92    }
93
94    fn edge_count(&self) -> usize {
95        self.graph.edge_count()
96    }
97
98    fn all_nodes(&self) -> Nodes {
99        self.graph.all_nodes()
100    }
101
102    fn all_edges(&self) -> Edges {
103        self.graph.all_edges()
104    }
105
106    fn has_node(&self, id: impl AsRef<NodeID>) -> bool {
107        self.graph.has_node(id)
108    }
109
110    fn has_edge(&self, id: impl AsRef<EdgeID>) -> bool {
111        self.graph.has_edge(id)
112    }
113
114    fn has_edge_from_to(&self, source: impl AsRef<NodeID>, target: impl AsRef<NodeID>) -> bool {
115        self.graph.has_edge_from_to(source, target)
116    }
117
118    fn has_edge_between(&self, a: impl AsRef<NodeID>, b: impl AsRef<NodeID>) -> bool {
119        self.graph.has_edge_between(a, b)
120    }
121
122    fn source(&self, id: impl AsRef<EdgeID>) -> Result<NodeID> {
123        self.graph.source(id)
124    }
125
126    fn target(&self, id: impl AsRef<EdgeID>) -> Result<NodeID> {
127        self.graph.target(id)
128    }
129
130    fn endpoints(&self, id: impl AsRef<EdgeID>) -> Result<(NodeID, NodeID)> {
131        self.graph.endpoints(id)
132    }
133
134    fn out_edges(&self, id: impl AsRef<NodeID>) -> Result<Edges> {
135        self.graph.out_edges(id)
136    }
137
138    fn in_edges(&self, id: impl AsRef<NodeID>) -> Result<Edges> {
139        self.graph.in_edges(id)
140    }
141
142    fn incident_edges(&self, id: impl AsRef<NodeID>) -> Result<Edges> {
143        self.graph.incident_edges(id)
144    }
145
146    fn out_degree(&self, id: impl AsRef<NodeID>) -> Result<usize> {
147        self.graph.out_degree(id)
148    }
149
150    fn in_degree(&self, id: impl AsRef<NodeID>) -> Result<usize> {
151        self.graph.in_degree(id)
152    }
153
154    fn degree(&self, id: impl AsRef<NodeID>) -> Result<usize> {
155        self.graph.degree(id)
156    }
157
158    fn successors(&self, id: impl AsRef<NodeID>) -> Result<Nodes> {
159        self.graph.successors(id)
160    }
161
162    fn predecessors(&self, id: impl AsRef<NodeID>) -> Result<Nodes> {
163        self.graph.predecessors(id)
164    }
165
166    fn neighbors(&self, id: impl AsRef<NodeID>) -> Result<Nodes> {
167        self.graph.neighbors(id)
168    }
169
170    fn has_successors(&self, id: impl AsRef<NodeID>) -> Result<bool> {
171        self.graph.has_successors(id)
172    }
173
174    fn has_predecessors(&self, id: impl AsRef<NodeID>) -> Result<bool> {
175        self.graph.has_predecessors(id)
176    }
177
178    fn has_neighbors(&self, id: impl AsRef<NodeID>) -> Result<bool> {
179        self.graph.has_neighbors(id)
180    }
181
182    fn all_roots(&self) -> Nodes {
183        vec![self.root.clone()].into_iter().collect()
184    }
185
186    fn all_leaves(&self) -> Nodes {
187        self.graph.all_leaves()
188    }
189
190    fn non_roots(&self) -> Nodes {
191        self.all_nodes().into_iter().filter(|n| n != &self.root).collect()
192    }
193
194    fn non_leaves(&self) -> Nodes {
195        self.graph.non_leaves()
196    }
197
198    fn all_internals(&self) -> Nodes {
199        self.graph.all_internals()
200    }
201
202    fn is_leaf(&self, id: impl AsRef<NodeID>) -> Result<bool> {
203        self.graph.is_leaf(id)
204    }
205
206    fn is_root(&self, id: impl AsRef<NodeID>) -> Result<bool> {
207        self.graph.is_root(id)
208    }
209
210    fn is_internal(&self, id: impl AsRef<NodeID>) -> Result<bool> {
211        self.graph.is_internal(id)
212    }
213}
214
215impl<Inner> VisitableTree for Tree<Inner>
216where
217    Inner: MutableGraph,
218{
219    fn root(&self) -> NodeID {
220        self.root.clone()
221    }
222}
223
224impl<Inner> VisitableForest for Tree<Inner>
225where
226    Inner: MutableGraph,
227{
228    fn in_edge(&self, node: impl AsRef<NodeID>) -> Result<Option<EdgeID>> {
229        Ok(self.in_edges(node)?.first().cloned())
230    }
231
232    fn in_edge_with_root(&self, node: impl AsRef<NodeID>) -> Result<Option<EdgeID>> {
233        self.in_edge(node)
234    }
235
236    fn parent(&self, node: impl AsRef<NodeID>) -> Result<Option<NodeID>> {
237        Ok(self.in_edge(node)?.map(|edge| self.source(&edge).unwrap()))
238    }
239
240    fn children(&self, node: Option<impl AsRef<NodeID>>) -> Result<Nodes> {
241        if let Some(node) = node {
242            self.successors(node)
243        } else {
244            self.successors(self.root())
245        }
246    }
247
248    fn has_children(&self, node: impl AsRef<NodeID>) -> Result<bool> {
249        self.has_successors(node)
250    }
251
252    fn child_count(&self, node: impl AsRef<NodeID>) -> Result<usize> {
253        self.out_degree(node)
254    }
255}
256
257impl<Inner> MutableTree for Tree<Inner>
258where
259    Inner: MutableGraph,
260{
261    fn set_root(&mut self, root: impl AsRef<NodeID>) -> Result<()> {
262        let root = root.as_ref();
263        self.graph.check_is_tree(root)?;
264        self.root = root.as_ref().clone();
265        Ok(())
266    }
267}
268
269impl<Inner> MutableForest for Tree<Inner>
270where
271    Inner: MutableGraph,
272{
273    fn add_node_with_node_and_edge_data(
274        &mut self,
275        node: impl AsRef<NodeID>,
276        parent: Option<impl AsRef<NodeID>>,
277        edge: impl AsRef<EdgeID>,
278        node_data: Self::NData,
279        edge_data: Self::EData,
280    ) -> Result<()> {
281        let node = node.as_ref();
282        self.graph.add_node_with_data(node, node_data)?;
283        let parent = parent.map(|p| p.as_ref().clone()).unwrap_or_else(|| self.root());
284        self.graph.add_edge_with_data(edge, parent, node, edge_data)?;
285        Ok(())
286    }
287
288    fn remove_node_ungrouping(&mut self, id: impl AsRef<NodeID>) -> Result<()> {
289        let id = id.as_ref();
290        if id == &self.root {
291            let children = self.children(Some(id))?;
292            if children.len() != 1 {
293                bail!(Error::NotATree);
294            }
295            let new_root = children.into_iter().next().unwrap();
296            self.graph.remove_node(id)?;
297            self.set_root(&new_root)?;
298        } else {
299            let new_parent = self.parent(id)?.unwrap();
300            let children = self.children(Some(id))?;
301            for child in children {
302                self.move_node(&child, Some(&new_parent))?;
303            }
304            self.graph.remove_node(id)?;
305        }
306        Ok(())
307    }
308
309    fn remove_node_and_children(&mut self, id: impl AsRef<NodeID>) -> Result<Nodes> {
310        let id = id.as_ref();
311
312        // Can't remove root
313        if id == &self.root {
314            bail!(Error::NotATree);
315        }
316
317        // Remove child nodes in reverse-topological sort order (most distant from the target first).
318        let to_remove = self.topological_sort_opt(&Nodes::from([id.clone()]), true)?;
319        for node in to_remove.iter() {
320            self.graph.remove_node(node)?;
321        }
322        Ok(to_remove.into_iter().collect())
323    }
324
325    fn remove_children(&mut self, id: impl AsRef<NodeID>) -> Result<Nodes> {
326        let id = id.as_ref();
327
328        // Remove child nodes in reverse-topological sort order (most distant from the target first).
329        let children = self.children(Some(id))?;
330        let to_remove = self.topological_sort_opt(&children, true)?;
331        for node in to_remove.iter() {
332            self.graph.remove_node(node)?;
333        }
334        Ok(to_remove.into_iter().collect())
335    }
336
337    fn move_node(&mut self, id: impl AsRef<NodeID>, new_parent: Option<impl AsRef<NodeID>>) -> Result<()> {
338        let id = id.as_ref();
339
340        // Can't move root
341        if id == &self.root {
342            bail!(Error::NotATree);
343        }
344
345        let edge = self.in_edge(id)?.unwrap();
346        let root = self.root();
347        let new_parent = new_parent.map(|p| p.as_ref().clone()).unwrap_or_else(|| root.clone());
348        let new_parent = new_parent.as_ref();
349        if !self.graph.can_move_dag_edge(&edge, new_parent, id)? {
350            bail!(Error::NotATree);
351        }
352        self.graph.move_edge(&edge, new_parent, id)?;
353        Ok(())
354    }
355
356    fn set_data(&mut self, data: Self::GData) {
357        self.graph.set_data(data);
358    }
359
360    fn set_node_data(&mut self, id: impl AsRef<NodeID>, data: Self::NData) -> Result<()> {
361        self.graph.set_node_data(id, data)
362    }
363
364    fn set_edge_data(&mut self, id: impl AsRef<EdgeID>, data: Self::EData) -> Result<()> {
365        self.graph.set_edge_data(id, data)
366    }
367
368    fn with_data(&mut self, transform: &dyn Fn(&mut Self::GData)) {
369        self.graph.with_data(transform);
370    }
371
372    fn with_node_data(&mut self, id: impl AsRef<NodeID>, transform: &dyn Fn(&mut Self::NData)) -> Result<()> {
373        self.graph.with_node_data(id, transform)
374    }
375
376    fn with_edge_data(&mut self, id: impl AsRef<EdgeID>, transform: &dyn Fn(&mut Self::EData)) -> Result<()> {
377        self.graph.with_edge_data(id, transform)
378    }
379}
380
381impl<Inner> PartialEq for Tree<Inner>
382where
383    Inner: MutableGraph + PartialEq,
384{
385    fn eq(&self, other: &Self) -> bool {
386        self.root == other.root && self.graph == other.graph
387    }
388}
389
390impl<Inner> Eq for Tree<Inner>
391where
392    Inner: MutableGraph + Eq,
393{
394}
395
#[cfg(feature = "serde")]
impl<Inner: MutableGraph + Serialize> Serialize for Tree<Inner> {
    /// Serializes the tree as a `(root, graph)` tuple.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let parts = (&self.root, &self.graph);
        parts.serialize(serializer)
    }
}
408
#[cfg(feature = "serde")]
impl<'de, Inner: MutableGraph + Deserialize<'de>> Deserialize<'de> for Tree<Inner> {
    /// Deserializes a `(root, graph)` tuple and validates that it forms a tree
    /// rooted at `root`. A validation failure is reported as a custom
    /// deserialization error.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let (root, graph): (NodeID, Inner) = Deserialize::deserialize(deserializer)?;
        match Self::new_with_root_and_graph(root, graph) {
            Ok(tree) => Ok(tree),
            Err(err) => Err(de::Error::custom(err)),
        }
    }
}
422
// When both `serde` and `serde_json` are enabled, provide a convenience for
// serializing a `Tree` to JSON.
#[cfg(all(feature = "serde", feature = "serde_json"))]
impl<Inner: MutableGraph + Serialize> Tree<Inner> {
    /// Serializes the tree (as a `(root, graph)` pair) to a compact JSON string.
    ///
    /// # Panics
    ///
    /// Panics if `serde_json` fails to serialize the tree.
    pub fn to_json(&self) -> String {
        let encoded = serde_json::to_string(self);
        encoded.unwrap()
    }
}
434
// When both `serde` and `serde_json` are enabled, provide a convenience for
// parsing a `Tree` from JSON.
#[cfg(all(feature = "serde", feature = "serde_json"))]
impl<'de, Inner: MutableGraph + Deserialize<'de>> Tree<Inner> {
    /// Parses a tree from JSON encoding a `(root, graph)` pair.
    ///
    /// # Errors
    ///
    /// Returns a `serde_json::Error` if the JSON is malformed, or if the decoded
    /// graph is not a tree rooted at the decoded root.
    pub fn from_json(json: &'de str) -> Result<Self, serde_json::Error> {
        serde_json::from_str::<Self>(json)
    }
}
446
#[cfg(all(feature = "serde", feature = "serde_json"))]
impl<Inner: MutableGraph + Serialize> std::fmt::Display for Tree<Inner> {
    /// Formats the tree as its JSON representation, i.e. the output of `to_json`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(&self.to_json())
    }
}