1use std::borrow::Cow;
2
3use anyhow::{Result, bail};
4#[cfg(feature = "serde")]
5use serde::{ser::{Serialize, Serializer}, de::{self, Deserialize, Deserializer}};
6
7use crate::{nid, EdgeID, Edges, Error, IsTree, MutableForest, MutableGraph, MutableTree, NodeID, Nodes, PathExists, BlankGraph, TopologicalSort, VisitableForest, VisitableGraph, VisitableTree};
8
9#[derive(Debug, Clone)]
13pub struct Tree<Inner>
14where
15 Inner: MutableGraph,
16{
17 root: NodeID,
18 graph: Inner,
19}
20
21pub type BlankTree = Tree<BlankGraph>;
23
24impl<Inner> Tree<Inner>
25where
26 Inner: MutableGraph,
27{
28 pub fn new_unchecked(root: NodeID, graph: Inner) -> Self {
29 Self { root, graph }
30 }
31
32 pub fn new_with_root_and_graph(root: NodeID, graph: Inner) -> Result<Self> {
33 graph.check_is_tree(&root)?;
34 Ok(Self::new_unchecked(root, graph))
35 }
36
37 pub fn graph(&self) -> &Inner {
38 &self.graph
39 }
40}
41
42impl<Inner> Tree<Inner>
43where
44 Inner: MutableGraph + Default + Clone,
45 Inner::NData: Default,
46{
47 pub fn new_with_root(root: NodeID) -> Self {
48 Self::new_unchecked(root, Inner::default().adding_node(&nid!("root")).unwrap())
49 }
50
51 pub fn new() -> Self {
52 Self::new_with_root(nid!("root"))
53 }
54}
55
56impl<Inner> Default for Tree<Inner>
57where
58 Inner: MutableGraph + Default + Clone,
59 Inner::NData: Default,
60{
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl<Inner> VisitableGraph for Tree<Inner>
67where
68 Inner: MutableGraph,
69{
70 type GData = Inner::GData;
71 type NData = Inner::NData;
72 type EData = Inner::EData;
73
74 fn data(&self) -> &Self::GData {
75 self.graph.data()
76 }
77
78 fn node_data(&self, id: impl AsRef<NodeID>) -> Result<Cow<'static, Self::NData>> {
79 self.graph.node_data(id)
80 }
81
82 fn edge_data(&self, id: impl AsRef<EdgeID>) -> Result<Cow<'static, Self::EData>> {
83 self.graph.edge_data(id)
84 }
85
86 fn is_empty(&self) -> bool {
87 self.graph.is_empty()
88 }
89
90 fn node_count(&self) -> usize {
91 self.graph.node_count()
92 }
93
94 fn edge_count(&self) -> usize {
95 self.graph.edge_count()
96 }
97
98 fn all_nodes(&self) -> Nodes {
99 self.graph.all_nodes()
100 }
101
102 fn all_edges(&self) -> Edges {
103 self.graph.all_edges()
104 }
105
106 fn has_node(&self, id: impl AsRef<NodeID>) -> bool {
107 self.graph.has_node(id)
108 }
109
110 fn has_edge(&self, id: impl AsRef<EdgeID>) -> bool {
111 self.graph.has_edge(id)
112 }
113
114 fn has_edge_from_to(&self, source: impl AsRef<NodeID>, target: impl AsRef<NodeID>) -> bool {
115 self.graph.has_edge_from_to(source, target)
116 }
117
118 fn has_edge_between(&self, a: impl AsRef<NodeID>, b: impl AsRef<NodeID>) -> bool {
119 self.graph.has_edge_between(a, b)
120 }
121
122 fn source(&self, id: impl AsRef<EdgeID>) -> Result<NodeID> {
123 self.graph.source(id)
124 }
125
126 fn target(&self, id: impl AsRef<EdgeID>) -> Result<NodeID> {
127 self.graph.target(id)
128 }
129
130 fn endpoints(&self, id: impl AsRef<EdgeID>) -> Result<(NodeID, NodeID)> {
131 self.graph.endpoints(id)
132 }
133
134 fn out_edges(&self, id: impl AsRef<NodeID>) -> Result<Edges> {
135 self.graph.out_edges(id)
136 }
137
138 fn in_edges(&self, id: impl AsRef<NodeID>) -> Result<Edges> {
139 self.graph.in_edges(id)
140 }
141
142 fn incident_edges(&self, id: impl AsRef<NodeID>) -> Result<Edges> {
143 self.graph.incident_edges(id)
144 }
145
146 fn out_degree(&self, id: impl AsRef<NodeID>) -> Result<usize> {
147 self.graph.out_degree(id)
148 }
149
150 fn in_degree(&self, id: impl AsRef<NodeID>) -> Result<usize> {
151 self.graph.in_degree(id)
152 }
153
154 fn degree(&self, id: impl AsRef<NodeID>) -> Result<usize> {
155 self.graph.degree(id)
156 }
157
158 fn successors(&self, id: impl AsRef<NodeID>) -> Result<Nodes> {
159 self.graph.successors(id)
160 }
161
162 fn predecessors(&self, id: impl AsRef<NodeID>) -> Result<Nodes> {
163 self.graph.predecessors(id)
164 }
165
166 fn neighbors(&self, id: impl AsRef<NodeID>) -> Result<Nodes> {
167 self.graph.neighbors(id)
168 }
169
170 fn has_successors(&self, id: impl AsRef<NodeID>) -> Result<bool> {
171 self.graph.has_successors(id)
172 }
173
174 fn has_predecessors(&self, id: impl AsRef<NodeID>) -> Result<bool> {
175 self.graph.has_predecessors(id)
176 }
177
178 fn has_neighbors(&self, id: impl AsRef<NodeID>) -> Result<bool> {
179 self.graph.has_neighbors(id)
180 }
181
182 fn all_roots(&self) -> Nodes {
183 vec![self.root.clone()].into_iter().collect()
184 }
185
186 fn all_leaves(&self) -> Nodes {
187 self.graph.all_leaves()
188 }
189
190 fn non_roots(&self) -> Nodes {
191 self.all_nodes().into_iter().filter(|n| n != &self.root).collect()
192 }
193
194 fn non_leaves(&self) -> Nodes {
195 self.graph.non_leaves()
196 }
197
198 fn all_internals(&self) -> Nodes {
199 self.graph.all_internals()
200 }
201
202 fn is_leaf(&self, id: impl AsRef<NodeID>) -> Result<bool> {
203 self.graph.is_leaf(id)
204 }
205
206 fn is_root(&self, id: impl AsRef<NodeID>) -> Result<bool> {
207 self.graph.is_root(id)
208 }
209
210 fn is_internal(&self, id: impl AsRef<NodeID>) -> Result<bool> {
211 self.graph.is_internal(id)
212 }
213}
214
215impl<Inner> VisitableTree for Tree<Inner>
216where
217 Inner: MutableGraph,
218{
219 fn root(&self) -> NodeID {
220 self.root.clone()
221 }
222}
223
224impl<Inner> VisitableForest for Tree<Inner>
225where
226 Inner: MutableGraph,
227{
228 fn in_edge(&self, node: impl AsRef<NodeID>) -> Result<Option<EdgeID>> {
229 Ok(self.in_edges(node)?.first().cloned())
230 }
231
232 fn in_edge_with_root(&self, node: impl AsRef<NodeID>) -> Result<Option<EdgeID>> {
233 self.in_edge(node)
234 }
235
236 fn parent(&self, node: impl AsRef<NodeID>) -> Result<Option<NodeID>> {
237 Ok(self.in_edge(node)?.map(|edge| self.source(&edge).unwrap()))
238 }
239
240 fn children(&self, node: Option<impl AsRef<NodeID>>) -> Result<Nodes> {
241 if let Some(node) = node {
242 self.successors(node)
243 } else {
244 self.successors(self.root())
245 }
246 }
247
248 fn has_children(&self, node: impl AsRef<NodeID>) -> Result<bool> {
249 self.has_successors(node)
250 }
251
252 fn child_count(&self, node: impl AsRef<NodeID>) -> Result<usize> {
253 self.out_degree(node)
254 }
255}
256
257impl<Inner> MutableTree for Tree<Inner>
258where
259 Inner: MutableGraph,
260{
261 fn set_root(&mut self, root: impl AsRef<NodeID>) -> Result<()> {
262 let root = root.as_ref();
263 self.graph.check_is_tree(root)?;
264 self.root = root.as_ref().clone();
265 Ok(())
266 }
267}
268
269impl<Inner> MutableForest for Tree<Inner>
270where
271 Inner: MutableGraph,
272{
273 fn add_node_with_node_and_edge_data(
274 &mut self,
275 node: impl AsRef<NodeID>,
276 parent: Option<impl AsRef<NodeID>>,
277 edge: impl AsRef<EdgeID>,
278 node_data: Self::NData,
279 edge_data: Self::EData,
280 ) -> Result<()> {
281 let node = node.as_ref();
282 self.graph.add_node_with_data(node, node_data)?;
283 let parent = parent.map(|p| p.as_ref().clone()).unwrap_or_else(|| self.root());
284 self.graph.add_edge_with_data(edge, parent, node, edge_data)?;
285 Ok(())
286 }
287
288 fn remove_node_ungrouping(&mut self, id: impl AsRef<NodeID>) -> Result<()> {
289 let id = id.as_ref();
290 if id == &self.root {
291 let children = self.children(Some(id))?;
292 if children.len() != 1 {
293 bail!(Error::NotATree);
294 }
295 let new_root = children.into_iter().next().unwrap();
296 self.graph.remove_node(id)?;
297 self.set_root(&new_root)?;
298 } else {
299 let new_parent = self.parent(id)?.unwrap();
300 let children = self.children(Some(id))?;
301 for child in children {
302 self.move_node(&child, Some(&new_parent))?;
303 }
304 self.graph.remove_node(id)?;
305 }
306 Ok(())
307 }
308
309 fn remove_node_and_children(&mut self, id: impl AsRef<NodeID>) -> Result<Nodes> {
310 let id = id.as_ref();
311
312 if id == &self.root {
314 bail!(Error::NotATree);
315 }
316
317 let to_remove = self.topological_sort_opt(&Nodes::from([id.clone()]), true)?;
319 for node in to_remove.iter() {
320 self.graph.remove_node(node)?;
321 }
322 Ok(to_remove.into_iter().collect())
323 }
324
325 fn remove_children(&mut self, id: impl AsRef<NodeID>) -> Result<Nodes> {
326 let id = id.as_ref();
327
328 let children = self.children(Some(id))?;
330 let to_remove = self.topological_sort_opt(&children, true)?;
331 for node in to_remove.iter() {
332 self.graph.remove_node(node)?;
333 }
334 Ok(to_remove.into_iter().collect())
335 }
336
337 fn move_node(&mut self, id: impl AsRef<NodeID>, new_parent: Option<impl AsRef<NodeID>>) -> Result<()> {
338 let id = id.as_ref();
339
340 if id == &self.root {
342 bail!(Error::NotATree);
343 }
344
345 let edge = self.in_edge(id)?.unwrap();
346 let root = self.root();
347 let new_parent = new_parent.map(|p| p.as_ref().clone()).unwrap_or_else(|| root.clone());
348 let new_parent = new_parent.as_ref();
349 if !self.graph.can_move_dag_edge(&edge, new_parent, id)? {
350 bail!(Error::NotATree);
351 }
352 self.graph.move_edge(&edge, new_parent, id)?;
353 Ok(())
354 }
355
356 fn set_data(&mut self, data: Self::GData) {
357 self.graph.set_data(data);
358 }
359
360 fn set_node_data(&mut self, id: impl AsRef<NodeID>, data: Self::NData) -> Result<()> {
361 self.graph.set_node_data(id, data)
362 }
363
364 fn set_edge_data(&mut self, id: impl AsRef<EdgeID>, data: Self::EData) -> Result<()> {
365 self.graph.set_edge_data(id, data)
366 }
367
368 fn with_data(&mut self, transform: &dyn Fn(&mut Self::GData)) {
369 self.graph.with_data(transform);
370 }
371
372 fn with_node_data(&mut self, id: impl AsRef<NodeID>, transform: &dyn Fn(&mut Self::NData)) -> Result<()> {
373 self.graph.with_node_data(id, transform)
374 }
375
376 fn with_edge_data(&mut self, id: impl AsRef<EdgeID>, transform: &dyn Fn(&mut Self::EData)) -> Result<()> {
377 self.graph.with_edge_data(id, transform)
378 }
379}
380
381impl<Inner> PartialEq for Tree<Inner>
382where
383 Inner: MutableGraph + PartialEq,
384{
385 fn eq(&self, other: &Self) -> bool {
386 self.root == other.root && self.graph == other.graph
387 }
388}
389
390impl<Inner> Eq for Tree<Inner>
391where
392 Inner: MutableGraph + Eq,
393{
394}
395
#[cfg(feature = "serde")]
impl<Inner> Serialize for Tree<Inner>
where
    Inner: MutableGraph + Serialize,
{
    /// Serializes the tree as the two-element tuple `(root, graph)`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let pair = (&self.root, &self.graph);
        Serialize::serialize(&pair, serializer)
    }
}
408
#[cfg(feature = "serde")]
impl<'de, Inner> Deserialize<'de> for Tree<Inner>
where
    Inner: MutableGraph + Deserialize<'de>,
{
    /// Reads a `(root, graph)` tuple and checks it with
    /// `Tree::new_with_root_and_graph`. A failed check is reported as a
    /// custom deserialization error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let (root, graph): (NodeID, Inner) = Deserialize::deserialize(deserializer)?;
        match Self::new_with_root_and_graph(root, graph) {
            Ok(tree) => Ok(tree),
            Err(err) => Err(de::Error::custom(err)),
        }
    }
}
422
#[cfg(all(feature = "serde", feature = "serde_json"))]
impl<Inner> Tree<Inner>
where
    Inner: MutableGraph + Serialize,
{
    /// Serializes the tree to a compact JSON string. Because the tree
    /// serializes as a `(root, graph)` tuple, the output is a two-element
    /// JSON array.
    ///
    /// # Panics
    ///
    /// Panics if `serde_json` fails to serialize the tree, for example if the
    /// inner graph's `Serialize` implementation reports an error.
    pub fn to_json(&self) -> String {
        serde_json::to_string(self).expect("failed to serialize Tree to JSON")
    }
}
434
#[cfg(all(feature = "serde", feature = "serde_json"))]
impl<'de, Inner> Tree<Inner>
where
    Inner: MutableGraph + Deserialize<'de>,
{
    /// Parses a tree from JSON such as the output of `to_json`.
    ///
    /// Decoding goes through this type's `Deserialize` impl, so a graph that
    /// fails the tree check is reported as a `serde_json::Error`.
    pub fn from_json(json: &'de str) -> Result<Self, serde_json::Error> {
        serde_json::from_str::<Self>(json)
    }
}
446
#[cfg(all(feature = "serde", feature = "serde_json"))]
impl<Inner> std::fmt::Display for Tree<Inner>
where
    Inner: MutableGraph + Serialize,
{
    /// Writes the tree's JSON form, as produced by `to_json`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let json = self.to_json();
        f.write_str(&json)
    }
}