sgf_render/
sgf_traversal.rs

1use sgf_parse::{go::Prop, SgfNode};
2
3use crate::errors::GobanError;
4
5/// Returns an iterator over SgfTraversalNode values at the root of each variation.
6pub fn variation_roots(node: &SgfNode<Prop>) -> impl Iterator<Item = SgfTraversalNode<'_>> {
7    SgfTraversal::new(node).filter(|node| node.is_variation_root)
8}
9
10/// Returns an iterator of SgfTraversalNode values for every node in a variation.
11pub fn variation_nodes(
12    root: &SgfNode<Prop>,
13    variation: u64,
14) -> Result<impl Iterator<Item = SgfTraversalNode<'_>>, GobanError> {
15    let mut parents = vec![0; variation as usize + 1];
16    let mut starts = vec![u64::MAX; variation as usize + 1];
17    let mut variation_seen = false;
18    for node in variation_roots(root).take_while(|n| n.variation <= variation) {
19        parents[node.variation as usize] = node.parent_variation;
20        starts[node.variation as usize] =
21            starts[node.variation as usize].min(node.variation_node_number);
22        if node.variation == variation {
23            variation_seen = true;
24        }
25    }
26    if !variation_seen {
27        return Err(GobanError::MissingVariation);
28    }
29    let mut current_variation = variation;
30    let mut variations = vec![];
31    while current_variation > 0 {
32        let parent = parents[current_variation as usize];
33        let start = starts[current_variation as usize];
34        variations.push((current_variation, start));
35        current_variation = parent;
36    }
37
38    let mut current_variation = 0;
39    let (mut next_variation, mut next_node_number) = variations.pop().unwrap_or((0, 0));
40    Ok(SgfTraversal::new(root)
41        .take_while(move |node| node.variation <= variation)
42        .filter(move |node| {
43            if node.variation == current_variation && node.variation_node_number >= next_node_number
44            {
45                current_variation = next_variation;
46                if let Some((a, b)) = variations.pop() {
47                    (next_variation, next_node_number) = (a, b);
48                }
49            }
50            node.variation == current_variation
51        }))
52}
53
54#[derive(Debug, Clone)]
55pub struct SgfTraversal<'a> {
56    stack: Vec<SgfTraversalNode<'a>>,
57    variation: u64,
58}
59
60impl<'a> SgfTraversal<'a> {
61    pub fn new(sgf_node: &'a SgfNode<Prop>) -> Self {
62        SgfTraversal {
63            stack: vec![SgfTraversalNode {
64                sgf_node,
65                variation_node_number: 0,
66                variation: 0,
67                parent_variation: 0,
68                branch_number: 0,
69                is_variation_root: true,
70                branches: vec![],
71            }],
72            variation: 0,
73        }
74    }
75}
76
77#[derive(Debug, Clone)]
78pub struct SgfTraversalNode<'a> {
79    pub sgf_node: &'a SgfNode<Prop>,
80    pub variation_node_number: u64,
81    pub variation: u64,
82    pub parent_variation: u64,
83    pub branch_number: u64,
84    pub is_variation_root: bool,
85    pub branches: Vec<bool>,
86}
87
88impl<'a> Iterator for SgfTraversal<'a> {
89    type Item = SgfTraversalNode<'a>;
90
91    fn next(&mut self) -> Option<Self::Item> {
92        let mut traversal_node = self.stack.pop()?;
93        let sgf_node = traversal_node.sgf_node;
94        let variation_node_number = traversal_node.variation_node_number + 1;
95        let is_variation_root = sgf_node.children.len() > 1;
96        if traversal_node.is_variation_root && traversal_node.branch_number != 0 {
97            self.variation += 1;
98            traversal_node.parent_variation = traversal_node.variation;
99            traversal_node.variation = self.variation;
100        }
101        for (branch_number, child) in sgf_node.children.iter().enumerate().rev() {
102            let mut branches = traversal_node.branches.clone();
103            if is_variation_root {
104                if branch_number == sgf_node.children.len() - 1 {
105                    branches.push(false);
106                } else {
107                    branches.push(true);
108                }
109            }
110            self.stack.push(SgfTraversalNode {
111                sgf_node: child,
112                variation_node_number,
113                variation: traversal_node.variation,
114                parent_variation: traversal_node.parent_variation,
115                branch_number: branch_number as u64,
116                is_variation_root,
117                branches,
118            });
119        }
120        Some(traversal_node)
121    }
122}