1use crate::errors::TaxonomyResult;
2use crate::rank::TaxRank;
3use serde_json::Value;
4use std::borrow::Cow;
5use std::collections::{HashMap, VecDeque};
6use std::fmt::{Debug, Display};
7
8pub trait Taxonomy<'t, T: 't>
12where
13    T: Clone + Debug + Display + PartialEq,
14{
15    fn root(&'t self) -> T;
17
18    fn children(&'t self, tax_id: T) -> TaxonomyResult<Vec<T>>;
20
21    fn descendants(&'t self, tax_id: T) -> TaxonomyResult<Vec<T>>;
23
24    fn parent(&'t self, tax_id: T) -> TaxonomyResult<Option<(T, f32)>>;
27
28    fn lineage(&'t self, tax_id: T) -> TaxonomyResult<Vec<T>> {
32        let mut parents = Vec::new();
33        let mut last_parent = tax_id.clone();
34        parents.push(tax_id);
35        while let Some(p) = self.parent(last_parent)? {
36            last_parent = p.0.clone();
37            parents.push(p.0);
38        }
39        Ok(parents)
40    }
41
42    fn parent_at_rank(&'t self, tax_id: T, rank: TaxRank) -> TaxonomyResult<Option<(T, f32)>> {
45        if self.rank(tax_id.clone())? == rank {
47            return Ok(Some((tax_id, 0.0)));
48        }
49
50        let mut cur_id = tax_id;
52        let mut dists = Vec::new();
53        while let Some(p) = self.parent(cur_id)? {
54            dists.push(p.1);
55            if self.rank(p.0.clone())? == rank {
56                return Ok(Some((p.0, dists.into_iter().sum())));
57            }
58            cur_id = p.0.clone();
59        }
60        Ok(None)
61    }
62
63    fn lca(&'t self, id1: T, id2: T) -> TaxonomyResult<T> {
72        let mut id1_parents = VecDeque::new();
74        id1_parents.push_front(id1);
75        while let Some(p) = self.parent(id1_parents.front().unwrap().clone())? {
76            id1_parents.push_front(p.0);
77        }
78
79        let mut id2_parents = VecDeque::new();
81        id2_parents.push_front(id2);
82        while let Some(p) = self.parent(id2_parents.front().unwrap().clone())? {
83            id2_parents.push_front(p.0);
84        }
85
86        let mut common = self.root();
88        for (pid1, pid2) in id1_parents.into_iter().zip(id2_parents.into_iter()) {
89            if pid1 != pid2 {
90                break;
91            }
92            common = pid1;
93        }
94        Ok(common)
95    }
96
97    fn name(&'t self, tax_id: T) -> TaxonomyResult<&str>;
99
100    fn data(&'t self, _tax_id: T) -> TaxonomyResult<Cow<'t, HashMap<String, Value>>> {
104        Ok(Cow::Owned(HashMap::new()))
105    }
106
107    fn rank(&'t self, tax_id: T) -> TaxonomyResult<TaxRank>;
109
110    fn traverse(&'t self, node: T) -> TaxonomyResult<TaxonomyIterator<'t, T>>
114    where
115        Self: Sized,
116    {
117        Ok(TaxonomyIterator::new(self, node))
118    }
119
120    fn len(&'t self) -> usize
122    where
123        Self: Sized,
124    {
125        self.traverse(self.root()).unwrap().count() / 2
126    }
127
128    fn is_empty(&'t self) -> bool
131    where
132        Self: Sized,
133    {
134        self.len() == 0
135    }
136}
137
/// Iterator over the nodes of a [`Taxonomy`], starting from a chosen node.
///
/// Its `Iterator` implementation yields every node twice as `(node, bool)`:
/// first with `true`, and again with `false` once the nodes pushed after it
/// have been emitted.
pub struct TaxonomyIterator<'t, T: 't> {
    /// Stack of nodes still to be emitted; the last element is handled next.
    nodes_left: Vec<T>,
    /// Stack of nodes that have been emitted with `true` but not yet with `false`.
    visited_nodes: Vec<T>,
    /// Taxonomy that is asked for the children of each node.
    tax: &'t dyn Taxonomy<'t, T>,
}
143
144impl<'t, T> TaxonomyIterator<'t, T> {
145    pub fn new(tax: &'t dyn Taxonomy<'t, T>, root_node: T) -> Self {
146        TaxonomyIterator {
147            nodes_left: vec![root_node],
148            visited_nodes: vec![],
149            tax,
150        }
151    }
152}
153
154impl<'t, T> Iterator for TaxonomyIterator<'t, T>
155where
156    T: Clone + Debug + Display + PartialEq,
157{
158    type Item = (T, bool);
159
160    fn next(&mut self) -> Option<Self::Item> {
161        if self.nodes_left.is_empty() {
162            return None;
163        }
164
165        let cur_node = self.nodes_left.last().unwrap().clone();
166        let node_visited = {
167            let last_visited = self.visited_nodes.last();
168            Some(&cur_node) == last_visited
169        };
170        if node_visited {
171            self.visited_nodes.pop();
172            Some((self.nodes_left.pop().unwrap(), false))
173        } else {
174            self.visited_nodes.push(cur_node.clone());
175            let children = self.tax.children(cur_node.clone()).unwrap();
176            if !children.is_empty() {
177                self.nodes_left.extend(children);
178            }
179            Some((cur_node, true))
180        }
181    }
182}
183
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Hard-coded 14-node taxonomy used to exercise the default `Taxonomy` methods.
    ///
    /// Shape of the tree (names as reported by `name`; ids without a name map to ""):
    ///
    /// ```text
    /// 1 root
    /// `- 131567 cellular organisms
    ///    `- 2 Bacteria
    ///       `- 1224 Proteobacteria
    ///          `- 1236 Gammaproteobacteria
    ///             |- 135622 - 22 - 62322 - 56812
    ///             `- 135613 Chromatiales
    ///                `- 1046 Chromatiaceae
    ///                   `- 53452 Lamprocystis
    ///                      `- 61598 Lamprocystis purpurea
    ///                         `- 765909 Lamprocystis purpurea DSM 4197
    /// ```
    ///
    /// Every parent link has distance `1.0`.
    pub(crate) struct MockTax;

    impl<'t> Taxonomy<'t, u32> for MockTax {
        fn root(&self) -> u32 {
            1
        }

        fn children(&self, tax_id: u32) -> TaxonomyResult<Vec<u32>> {
            Ok(match tax_id {
                1 => vec![131567],
                131567 => vec![2],
                2 => vec![1224],
                1224 => vec![1236],
                1236 => vec![135622, 135613],
                135622 => vec![22],
                22 => vec![62322],
                62322 => vec![56812],
                135613 => vec![1046],
                1046 => vec![53452],
                53452 => vec![61598],
                61598 => vec![765909],
                _ => vec![],
            })
        }

        /// Collects every node met while traversing from `tax_id`, excluding
        /// `tax_id` itself, as a sorted list without duplicates.
        fn descendants(&'t self, tax_id: u32) -> TaxonomyResult<Vec<u32>> {
            let mut found: Vec<u32> = self
                .traverse(tax_id)?
                .map(|(node, _)| node)
                .filter(|node| *node != tax_id)
                .collect();
            // The traversal yields each node twice; sorting first lets `dedup`
            // drop the repeats.
            found.sort_unstable();
            found.dedup();
            Ok(found)
        }

        fn parent(&self, tax_id: u32) -> TaxonomyResult<Option<(u32, f32)>> {
            Ok(match tax_id {
                131567 => Some((1, 1.)),
                2 => Some((131567, 1.)),
                1224 => Some((2, 1.)),
                1236 => Some((1224, 1.)),
                135622 => Some((1236, 1.)),
                22 => Some((135622, 1.)),
                62322 => Some((22, 1.)),
                56812 => Some((62322, 1.)),
                135613 => Some((1236, 1.)),
                1046 => Some((135613, 1.)),
                53452 => Some((1046, 1.)),
                61598 => Some((53452, 1.)),
                765909 => Some((61598, 1.)),
                _ => None,
            })
        }

        fn name(&self, tax_id: u32) -> TaxonomyResult<&str> {
            Ok(match tax_id {
                1 => "root",
                131567 => "cellular organisms",
                2 => "Bacteria",
                1224 => "Proteobacteria",
                1236 => "Gammaproteobacteria",
                135613 => "Chromatiales",
                1046 => "Chromatiaceae",
                53452 => "Lamprocystis",
                61598 => "Lamprocystis purpurea",
                765909 => "Lamprocystis purpurea DSM 4197",
                _ => "",
            })
        }

        fn rank(&self, tax_id: u32) -> TaxonomyResult<TaxRank> {
            Ok(match tax_id {
                2 => TaxRank::Superkingdom,
                1224 => TaxRank::Phylum,
                1236 => TaxRank::Class,
                135613 => TaxRank::Order,
                1046 => TaxRank::Family,
                53452 => TaxRank::Genus,
                61598 => TaxRank::Species,
                _ => TaxRank::Unspecified,
            })
        }
    }

    #[test]
    fn test_len() {
        let tax = MockTax;
        assert_eq!(tax.root(), 1);
        assert_eq!(tax.len(), 14);
        assert!(!tax.is_empty());
    }

    #[test]
    fn test_descendants() {
        let tax = MockTax;
        assert_eq!(
            tax.descendants(2).unwrap(),
            vec![22, 1046, 1224, 1236, 53452, 56812, 61598, 62322, 135613, 135622, 765909]
        );
    }

    #[test]
    fn test_lca() {
        let tax = MockTax;
        // One id is an ancestor of the other.
        assert_eq!(tax.lca(56812, 22).unwrap(), 22);
        // The ids sit on different branches below 1236.
        assert_eq!(tax.lca(56812, 765909).unwrap(), 1236);
    }

    #[test]
    fn test_lineage() {
        let tax = MockTax;
        assert_eq!(tax.lineage(1).unwrap(), vec![1]);
        assert_eq!(
            tax.lineage(61598).unwrap(),
            vec![61598, 53452, 1046, 135613, 1236, 1224, 2, 131567, 1]
        );
    }

    #[test]
    fn test_parent_at_rank() {
        let tax = MockTax;

        assert_eq!(
            tax.parent_at_rank(765909, TaxRank::Genus).unwrap(),
            Some((53452, 2.))
        );
        assert_eq!(
            tax.parent_at_rank(765909, TaxRank::Class).unwrap(),
            Some((1236, 5.))
        );
        // A node that already has the requested rank is returned at distance 0.
        assert_eq!(
            tax.parent_at_rank(1224, TaxRank::Phylum).unwrap(),
            Some((1224, 0.))
        );
        // No node at or above 1224 has the Genus rank.
        assert_eq!(tax.parent_at_rank(1224, TaxRank::Genus).unwrap(), None);
    }

    #[test]
    fn test_traversal() {
        let tax = MockTax;
        let mut visited = HashSet::new();
        let mut n_nodes = 0;

        for (ix, (tid, pre)) in tax.traverse(tax.root()).unwrap().enumerate() {
            n_nodes += 1;
            match ix {
                0 => {
                    assert_eq!(tid, 1, "root is first");
                    assert!(pre, "preorder visits happen first");
                }
                27 => {
                    assert_eq!(tid, 1, "root is last too");
                    assert!(!pre, "postorder visits happen last");
                }
                _ => {
                    if pre {
                        visited.insert(tid);
                    } else {
                        // A second visit must always follow the node's first visit.
                        assert!(visited.contains(&tid));
                    }
                }
            }
        }

        assert_eq!(n_nodes, 28, "Each node appears twice");
    }
}