udgraph_projectivize/
proj.rs

1//! Projectivization/deprojectivization of graphs.
2
3use std::cmp::{max, min};
4use std::collections::{HashMap, HashSet};
5
6use itertools::Itertools;
7use petgraph::graph::{node_index, EdgeIndex, NodeIndex};
8use petgraph::visit::{Bfs, EdgeRef, NodeFiltered, Walker};
9use petgraph::{Directed, Direction, Graph};
10use udgraph::graph::{DepTriple, Sentence};
11use udgraph::Error as UDError;
12
13use crate::{BfsWithDepth, Error};
14
15/// Graph deprojectivizer.
16pub trait Deprojectivize {
17    /// Deprojectivize a graph
18    ///
19    /// This method rewrites a projective graph into a non-projective graph.
20    /// Depending on the (de)projectivizer, this could require additional
21    /// information in dependency labels to guide the deprojectivization.
22    fn deprojectivize(&self, sentence: &mut Sentence) -> Result<(), Error>;
23}
24
25/// Graph projectivizer.
26pub trait Projectivize {
27    /// Projectivize a graph
28    ///
29    /// This method rewrites a non-projective graph into a projective graph.
30    /// Depending on the projectivizer, this may add additional information
31    /// to the dependency labels to undo the projectivization later.
32    fn projectivize(&self, sentence: &mut Sentence) -> Result<(), Error>;
33}
34
/// A projectivizer using the 'head'-marking strategy. See: *Pseudo-Projective
/// Dependency Parsing*, Nivre and Nilsson, 2005.
///
/// When an edge is lifted for the first time, its relation `rel` is rewritten
/// to `rel|head_rel`. Here `head_rel` is the relation of the original head to
/// its own head. Deprojectivization uses the part after `|` to find the node
/// to reattach the lifted token to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct HeadProjectivizer;
39
40impl HeadProjectivizer {
41    pub fn new() -> Self {
42        HeadProjectivizer {}
43    }
44
45    /// Deprojectivize the next node in the array of lifted nodes.
46    ///
47    /// Returns the index of the node that was lifted.
48    fn deprojectivize_next(
49        self,
50        graph: &mut Graph<(), String, Directed>,
51        lifted_sorted: &[NodeIndex],
52        head_labels: &HashMap<NodeIndex, String>,
53    ) -> Option<usize> {
54        for (idx, lifted_node) in lifted_sorted.iter().enumerate() {
55            let pref_head_rel = head_labels
56                .get(lifted_node)
57                .expect("Lifted node without preferred head relation");
58
59            let head_edge = graph
60                .first_edge(*lifted_node, Direction::Incoming)
61                .expect("Lifted node without an incoming edge");
62            let (cur_head, _) = graph
63                .edge_endpoints(head_edge)
64                .expect("Endpoints of lifted edge could not be found");
65
66            if let Some(new_head) =
67                self.search_attachment_point(graph, cur_head, *lifted_node, pref_head_rel)
68            {
69                let head_rel = graph
70                    .remove_edge(head_edge)
71                    .expect("Lifted edge to be removed could not be found");
72                graph.add_edge(new_head, *lifted_node, head_rel);
73                return Some(idx);
74            }
75        }
76
77        None
78    }
79
80    /// Find the correct attachment point for the lifted token/node.
81    fn search_attachment_point(
82        self,
83        graph: &Graph<(), String, Directed>,
84        cur_head: NodeIndex,
85        lifted_node: NodeIndex,
86        pref_head_rel: &str,
87    ) -> Option<NodeIndex> {
88        // We are looking for a token dominated by cur_head to attach
89        // lifted_node to. This token should:
90        //
91        // 1. Be attached to its head using pref_head_rel.
92        // 2. Not be lifted_node itself or any of its decendants.
93        // 3. As high in the tree as possible.
94        //
95        // From the set of candidates, we pick the token that is the
96        // closest to the current head.
97
98        // Requirement (2): use a view of the graph that excludes
99        // to avoid attachment to the lifted_node or any of its children.
100        let graph_without_lifted = NodeFiltered::from_fn(graph, |n| n != lifted_node);
101
102        // Requirement (3): process the dependency tree by increasing depth
103        // until the reattachment token is found.
104        for (_, nodes) in &BfsWithDepth::new(&graph_without_lifted, node_index(0))
105            .iter(&graph_without_lifted)
106            .skip(1)
107            .group_by(|&(_, depth)| depth)
108        {
109            // Requirement (1): Only retain edges with the preferred relation.
110            let level_candidates = nodes.map(|(node, _)| node).filter(|&node| {
111                let edge = match graph.first_edge(node, Direction::Incoming) {
112                    Some(edge) => edge,
113                    None => return false,
114                };
115
116                graph[edge] == pref_head_rel
117            });
118
119            // When there are multiple candidates, return the token closes to the head.
120            let min_candidate = level_candidates.min_by_key(|&node| {
121                max(node.index(), cur_head.index()) - min(node.index(), cur_head.index())
122            });
123
124            if min_candidate.is_some() {
125                return min_candidate;
126            }
127        }
128
129        None
130    }
131
132    /// Lift the edge identified by `edge_idx`. This will reattach the edge
133    /// to the parent of the head. If this was the first lifting operation,
134    /// the dependency relation of the original head is added to the dependency
135    /// relation (following the head-strategy).
136    fn lift(
137        self,
138        graph: &mut Graph<(), String, Directed>,
139        lifted: &mut HashSet<NodeIndex>,
140        edge_idx: EdgeIndex,
141    ) {
142        let (source, target) = graph
143            .edge_endpoints(edge_idx)
144            .expect("lift() called with invalid index");
145        let parent_edge = graph
146            .first_edge(source, Direction::Incoming)
147            .expect("Cannot find incoming edge of the to-be lifted node");
148        let parent_rel = graph[parent_edge].clone();
149        let (parent, _) = graph
150            .edge_endpoints(parent_edge)
151            .expect("Cannot find endpoints of to-be lifted edge");
152
153        let rel = graph
154            .remove_edge(edge_idx)
155            .expect("Cannot remove edge to-be lifted");
156
157        if lifted.contains(&target) {
158            graph.add_edge(parent, target, rel);
159        } else {
160            graph.add_edge(parent, target, format!("{}|{}", rel, parent_rel));
161            lifted.insert(target);
162        }
163    }
164
165    /// Prepare for deprojectivizing: remove head annotations from lifted
166    /// relations. Return the transformed graph + indices of lifted nodes
167    /// and their head labels.
168    fn prepare_deproj(
169        self,
170        graph: &Graph<(), String, Directed>,
171    ) -> (Graph<(), String, Directed>, HashMap<NodeIndex, String>) {
172        let mut pref_head_labels = HashMap::new();
173
174        let prepared_graph = graph.map(
175            |_, &node_val| node_val,
176            |edge_idx, edge_val| {
177                let sep_idx = match edge_val.find('|') {
178                    Some(idx) => idx,
179                    None => return edge_val.clone(),
180                };
181
182                let (_, dep) = graph
183                    .edge_endpoints(edge_idx)
184                    .expect("Cannot lookup edge endpoints");
185
186                pref_head_labels.insert(dep, edge_val[sep_idx + 1..].to_owned());
187
188                edge_val[..sep_idx].to_owned()
189            },
190        );
191
192        (prepared_graph, pref_head_labels)
193    }
194}
195
196impl Default for HeadProjectivizer {
197    fn default() -> Self {
198        HeadProjectivizer
199    }
200}
201
202impl Projectivize for HeadProjectivizer {
203    fn projectivize(&self, sentence: &mut Sentence) -> Result<(), Error> {
204        let mut graph = simplify_graph(sentence)?;
205        let mut lifted = HashSet::new();
206
207        // Lift non-projective edges until there are no non-projective
208        // edges left.
209        loop {
210            let np_edges = non_projective_edges(&graph);
211            if np_edges.is_empty() {
212                break;
213            }
214
215            self.lift(&mut graph, &mut lifted, np_edges[0]);
216        }
217
218        // The graph is now a projective tree. Update the dependency relations
219        // in the sentence to correspond to the graph.
220        let r = update_sentence(&graph, sentence);
221        // This is an algorithmic error, not something we want to bubble up.
222        assert!(
223            r.is_ok(),
224            "Deprojectivization add relation with unknown head/dependent"
225        );
226
227        Ok(())
228    }
229}
230
231impl Deprojectivize for HeadProjectivizer {
232    fn deprojectivize(&self, sentence: &mut Sentence) -> Result<(), Error> {
233        let graph = simplify_graph(sentence)?;
234
235        // Find nodes and corresponding edges that are lifted and remove
236        // head labels from dependency relations.
237        let (mut graph, head_labels) = self.prepare_deproj(&graph);
238        if head_labels.is_empty() {
239            return Ok(());
240        }
241
242        // Get and sort lifted tokens by increasing depth.
243        let mut lifted_sorted = Vec::new();
244        let mut bfs = Bfs::new(&graph, node_index(0));
245        while let Some(node) = bfs.next(&graph) {
246            if head_labels.get(&node).is_some() {
247                lifted_sorted.push(node);
248            }
249        }
250
251        // Deprojectivize the graph, re-attaching one token at a time,
252        // with the preference of a token that is not deep in the tree.
253        while let Some(idx) = self.deprojectivize_next(&mut graph, &lifted_sorted, &head_labels) {
254            lifted_sorted.remove(idx);
255        }
256
257        let r = update_sentence(&graph, sentence);
258        // This is an algorithmic error, not something we want to bubble up.
259        assert!(
260            r.is_ok(),
261            "Deprojectivization add relation with unknown head/dependent"
262        );
263
264        Ok(())
265    }
266}
267
268pub fn simplify_graph(sentence: &Sentence) -> Result<Graph<(), String, Directed>, Error> {
269    let mut edges = Vec::with_capacity(sentence.len() + 1);
270    for idx in 0..sentence.len() {
271        let triple = match sentence.dep_graph().head(idx) {
272            Some(triple) => triple,
273            None => continue,
274        };
275
276        let head_rel = match triple.relation() {
277            Some(head_rel) => head_rel,
278            None => {
279                return Err(Error::IncompleteGraph {
280                    value: format!(
281                        "edge from {} to {} does not have a label",
282                        triple.head(),
283                        triple.dependent()
284                    ),
285                })
286            }
287        };
288
289        edges.push((
290            node_index(triple.head()),
291            node_index(triple.dependent()),
292            head_rel.to_owned(),
293        ))
294    }
295
296    Ok(Graph::<(), String, Directed>::from_edges(edges))
297}
298
299/// Returns non-projective edges in the graph, ordered by length.
300pub fn non_projective_edges(graph: &Graph<(), String, Directed>) -> Vec<EdgeIndex> {
301    let mut non_projective = Vec::new();
302
303    for i in 0..graph.node_count() {
304        let mut i_reachable = HashSet::new();
305        let mut bfs = Bfs::new(&graph, node_index(i));
306        while let Some(node) = bfs.next(&graph) {
307            i_reachable.insert(node.index());
308        }
309
310        for edge in graph.edges(node_index(i)) {
311            // An edge i -> k is projective, iff:
312            //
313            // i > j > k or i < j < k, and i ->* j
314            for j in min(i, edge.target().index())..max(i, edge.target().index()) {
315                if !i_reachable.contains(&j) {
316                    non_projective.push(edge);
317                    break;
318                }
319            }
320        }
321    }
322
323    non_projective.sort_by(|a, b| {
324        let a_len = max(a.source().index(), a.target().index())
325            - min(a.source().index(), a.target().index());
326        let b_len = max(b.source().index(), b.target().index())
327            - min(b.source().index(), b.target().index());
328
329        a_len.cmp(&b_len)
330    });
331
332    non_projective.iter().map(EdgeRef::id).collect()
333}
334
335/// Update a sentence with dependency relations from a graph.
336fn update_sentence(
337    graph: &Graph<(), String, Directed>,
338    sentence: &mut Sentence,
339) -> Result<(), UDError> {
340    let mut sent_graph = sentence.dep_graph_mut();
341    for edge_ref in graph.edge_references() {
342        sent_graph.add_deprel(DepTriple::new(
343            edge_ref.source().index(),
344            Some(edge_ref.weight().clone()),
345            edge_ref.target().index(),
346        ))?;
347    }
348
349    Ok(())
350}
351
#[cfg(test)]
mod tests {
    use lazy_static::lazy_static;
    use petgraph::graph::{node_index, NodeIndex};
    use udgraph::graph::Sentence;

    use crate::proj::{
        non_projective_edges, simplify_graph, Deprojectivize, HeadProjectivizer, Projectivize,
    };
    use crate::tests::read_sentences;

    lazy_static! {
        static ref NON_PROJECTIVE_EDGES: Vec<Vec<(NodeIndex, NodeIndex)>> = vec![
            vec![(node_index(8), node_index(1))],
            vec![(node_index(10), node_index(2))],
            vec![(node_index(5), node_index(1))],
            vec![
                (node_index(1), node_index(3)),
                (node_index(7), node_index(5))
            ],
        ];
    }

    static PROJECTIVE_SENTENCES_FILENAME: &str = "testdata/projective.conll";

    static NONPROJECTIVE_SENTENCES_FILENAME: &str = "testdata/nonprojective.conll";

    /// For every sentence, the (head, dependent) endpoints of its
    /// non-projective edges, in the order `non_projective_edges` returns them.
    fn sent_non_projective_edges(sents: &[Sentence]) -> Vec<Vec<(NodeIndex, NodeIndex)>> {
        sents
            .iter()
            .map(|sent| {
                let graph = simplify_graph(sent).unwrap();
                non_projective_edges(&graph)
                    .into_iter()
                    .map(|edge_idx| graph.edge_endpoints(edge_idx).unwrap())
                    .collect()
            })
            .collect()
    }

    #[test]
    fn deprojectivize_test() {
        let projectivizer = HeadProjectivizer::new();

        let mut sentences = read_sentences(PROJECTIVE_SENTENCES_FILENAME);
        for sentence in &mut sentences {
            projectivizer
                .deprojectivize(sentence)
                .expect("Cannot deprojectivize sentence");
        }

        assert_eq!(read_sentences(NONPROJECTIVE_SENTENCES_FILENAME), sentences);
    }

    #[test]
    fn non_projective_test() {
        let sentences = read_sentences(NONPROJECTIVE_SENTENCES_FILENAME);
        assert_eq!(*NON_PROJECTIVE_EDGES, sent_non_projective_edges(&sentences));
    }

    #[test]
    fn projectivize_test() {
        let projectivizer = HeadProjectivizer::new();

        let mut sentences = read_sentences(NONPROJECTIVE_SENTENCES_FILENAME);
        for sentence in &mut sentences {
            projectivizer
                .projectivize(sentence)
                .expect("Cannot projectivize sentence");
        }

        assert_eq!(read_sentences(PROJECTIVE_SENTENCES_FILENAME), sentences);
    }
}