1use std::cmp::{max, min};
4use std::collections::{HashMap, HashSet};
5
6use itertools::Itertools;
7use petgraph::graph::{node_index, EdgeIndex, NodeIndex};
8use petgraph::visit::{Bfs, EdgeRef, NodeFiltered, Walker};
9use petgraph::{Directed, Direction, Graph};
10use udgraph::graph::{DepTriple, Sentence};
11use udgraph::Error as UDError;
12
13use crate::{BfsWithDepth, Error};
14
15pub trait Deprojectivize {
17 fn deprojectivize(&self, sentence: &mut Sentence) -> Result<(), Error>;
23}
24
25pub trait Projectivize {
27 fn projectivize(&self, sentence: &mut Sentence) -> Result<(), Error>;
33}
34
/// Stateless projectivizer/deprojectivizer that encodes the relation of the
/// original head in the label of a lifted edge (`rel|head_rel`).
///
/// The type carries no data, so it is cheap to copy and all values compare
/// equal.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct HeadProjectivizer;
39
40impl HeadProjectivizer {
41 pub fn new() -> Self {
42 HeadProjectivizer {}
43 }
44
45 fn deprojectivize_next(
49 self,
50 graph: &mut Graph<(), String, Directed>,
51 lifted_sorted: &[NodeIndex],
52 head_labels: &HashMap<NodeIndex, String>,
53 ) -> Option<usize> {
54 for (idx, lifted_node) in lifted_sorted.iter().enumerate() {
55 let pref_head_rel = head_labels
56 .get(lifted_node)
57 .expect("Lifted node without preferred head relation");
58
59 let head_edge = graph
60 .first_edge(*lifted_node, Direction::Incoming)
61 .expect("Lifted node without an incoming edge");
62 let (cur_head, _) = graph
63 .edge_endpoints(head_edge)
64 .expect("Endpoints of lifted edge could not be found");
65
66 if let Some(new_head) =
67 self.search_attachment_point(graph, cur_head, *lifted_node, pref_head_rel)
68 {
69 let head_rel = graph
70 .remove_edge(head_edge)
71 .expect("Lifted edge to be removed could not be found");
72 graph.add_edge(new_head, *lifted_node, head_rel);
73 return Some(idx);
74 }
75 }
76
77 None
78 }
79
80 fn search_attachment_point(
82 self,
83 graph: &Graph<(), String, Directed>,
84 cur_head: NodeIndex,
85 lifted_node: NodeIndex,
86 pref_head_rel: &str,
87 ) -> Option<NodeIndex> {
88 let graph_without_lifted = NodeFiltered::from_fn(graph, |n| n != lifted_node);
101
102 for (_, nodes) in &BfsWithDepth::new(&graph_without_lifted, node_index(0))
105 .iter(&graph_without_lifted)
106 .skip(1)
107 .group_by(|&(_, depth)| depth)
108 {
109 let level_candidates = nodes.map(|(node, _)| node).filter(|&node| {
111 let edge = match graph.first_edge(node, Direction::Incoming) {
112 Some(edge) => edge,
113 None => return false,
114 };
115
116 graph[edge] == pref_head_rel
117 });
118
119 let min_candidate = level_candidates.min_by_key(|&node| {
121 max(node.index(), cur_head.index()) - min(node.index(), cur_head.index())
122 });
123
124 if min_candidate.is_some() {
125 return min_candidate;
126 }
127 }
128
129 None
130 }
131
132 fn lift(
137 self,
138 graph: &mut Graph<(), String, Directed>,
139 lifted: &mut HashSet<NodeIndex>,
140 edge_idx: EdgeIndex,
141 ) {
142 let (source, target) = graph
143 .edge_endpoints(edge_idx)
144 .expect("lift() called with invalid index");
145 let parent_edge = graph
146 .first_edge(source, Direction::Incoming)
147 .expect("Cannot find incoming edge of the to-be lifted node");
148 let parent_rel = graph[parent_edge].clone();
149 let (parent, _) = graph
150 .edge_endpoints(parent_edge)
151 .expect("Cannot find endpoints of to-be lifted edge");
152
153 let rel = graph
154 .remove_edge(edge_idx)
155 .expect("Cannot remove edge to-be lifted");
156
157 if lifted.contains(&target) {
158 graph.add_edge(parent, target, rel);
159 } else {
160 graph.add_edge(parent, target, format!("{}|{}", rel, parent_rel));
161 lifted.insert(target);
162 }
163 }
164
165 fn prepare_deproj(
169 self,
170 graph: &Graph<(), String, Directed>,
171 ) -> (Graph<(), String, Directed>, HashMap<NodeIndex, String>) {
172 let mut pref_head_labels = HashMap::new();
173
174 let prepared_graph = graph.map(
175 |_, &node_val| node_val,
176 |edge_idx, edge_val| {
177 let sep_idx = match edge_val.find('|') {
178 Some(idx) => idx,
179 None => return edge_val.clone(),
180 };
181
182 let (_, dep) = graph
183 .edge_endpoints(edge_idx)
184 .expect("Cannot lookup edge endpoints");
185
186 pref_head_labels.insert(dep, edge_val[sep_idx + 1..].to_owned());
187
188 edge_val[..sep_idx].to_owned()
189 },
190 );
191
192 (prepared_graph, pref_head_labels)
193 }
194}
195
196impl Default for HeadProjectivizer {
197 fn default() -> Self {
198 HeadProjectivizer
199 }
200}
201
202impl Projectivize for HeadProjectivizer {
203 fn projectivize(&self, sentence: &mut Sentence) -> Result<(), Error> {
204 let mut graph = simplify_graph(sentence)?;
205 let mut lifted = HashSet::new();
206
207 loop {
210 let np_edges = non_projective_edges(&graph);
211 if np_edges.is_empty() {
212 break;
213 }
214
215 self.lift(&mut graph, &mut lifted, np_edges[0]);
216 }
217
218 let r = update_sentence(&graph, sentence);
221 assert!(
223 r.is_ok(),
224 "Deprojectivization add relation with unknown head/dependent"
225 );
226
227 Ok(())
228 }
229}
230
231impl Deprojectivize for HeadProjectivizer {
232 fn deprojectivize(&self, sentence: &mut Sentence) -> Result<(), Error> {
233 let graph = simplify_graph(sentence)?;
234
235 let (mut graph, head_labels) = self.prepare_deproj(&graph);
238 if head_labels.is_empty() {
239 return Ok(());
240 }
241
242 let mut lifted_sorted = Vec::new();
244 let mut bfs = Bfs::new(&graph, node_index(0));
245 while let Some(node) = bfs.next(&graph) {
246 if head_labels.get(&node).is_some() {
247 lifted_sorted.push(node);
248 }
249 }
250
251 while let Some(idx) = self.deprojectivize_next(&mut graph, &lifted_sorted, &head_labels) {
254 lifted_sorted.remove(idx);
255 }
256
257 let r = update_sentence(&graph, sentence);
258 assert!(
260 r.is_ok(),
261 "Deprojectivization add relation with unknown head/dependent"
262 );
263
264 Ok(())
265 }
266}
267
268pub fn simplify_graph(sentence: &Sentence) -> Result<Graph<(), String, Directed>, Error> {
269 let mut edges = Vec::with_capacity(sentence.len() + 1);
270 for idx in 0..sentence.len() {
271 let triple = match sentence.dep_graph().head(idx) {
272 Some(triple) => triple,
273 None => continue,
274 };
275
276 let head_rel = match triple.relation() {
277 Some(head_rel) => head_rel,
278 None => {
279 return Err(Error::IncompleteGraph {
280 value: format!(
281 "edge from {} to {} does not have a label",
282 triple.head(),
283 triple.dependent()
284 ),
285 })
286 }
287 };
288
289 edges.push((
290 node_index(triple.head()),
291 node_index(triple.dependent()),
292 head_rel.to_owned(),
293 ))
294 }
295
296 Ok(Graph::<(), String, Directed>::from_edges(edges))
297}
298
299pub fn non_projective_edges(graph: &Graph<(), String, Directed>) -> Vec<EdgeIndex> {
301 let mut non_projective = Vec::new();
302
303 for i in 0..graph.node_count() {
304 let mut i_reachable = HashSet::new();
305 let mut bfs = Bfs::new(&graph, node_index(i));
306 while let Some(node) = bfs.next(&graph) {
307 i_reachable.insert(node.index());
308 }
309
310 for edge in graph.edges(node_index(i)) {
311 for j in min(i, edge.target().index())..max(i, edge.target().index()) {
315 if !i_reachable.contains(&j) {
316 non_projective.push(edge);
317 break;
318 }
319 }
320 }
321 }
322
323 non_projective.sort_by(|a, b| {
324 let a_len = max(a.source().index(), a.target().index())
325 - min(a.source().index(), a.target().index());
326 let b_len = max(b.source().index(), b.target().index())
327 - min(b.source().index(), b.target().index());
328
329 a_len.cmp(&b_len)
330 });
331
332 non_projective.iter().map(EdgeRef::id).collect()
333}
334
335fn update_sentence(
337 graph: &Graph<(), String, Directed>,
338 sentence: &mut Sentence,
339) -> Result<(), UDError> {
340 let mut sent_graph = sentence.dep_graph_mut();
341 for edge_ref in graph.edge_references() {
342 sent_graph.add_deprel(DepTriple::new(
343 edge_ref.source().index(),
344 Some(edge_ref.weight().clone()),
345 edge_ref.target().index(),
346 ))?;
347 }
348
349 Ok(())
350}
351
#[cfg(test)]
mod tests {
    use lazy_static::lazy_static;
    use petgraph::graph::{node_index, NodeIndex};
    use udgraph::graph::Sentence;

    use crate::proj::{
        non_projective_edges, simplify_graph, Deprojectivize, HeadProjectivizer, Projectivize,
    };
    use crate::tests::read_sentences;

    /// Fixture with the projectivized sentences.
    static PROJECTIVE_SENTENCES_FILENAME: &str = "testdata/projective.conll";

    /// Fixture with the original, non-projective sentences.
    static NONPROJECTIVE_SENTENCES_FILENAME: &str = "testdata/nonprojective.conll";

    lazy_static! {
        /// Expected `(head, dependent)` pairs of the non-projective edges,
        /// one list per sentence of the non-projective fixture.
        static ref NON_PROJECTIVE_EDGES: Vec<Vec<(NodeIndex, NodeIndex)>> = vec![
            vec![(node_index(8), node_index(1))],
            vec![(node_index(10), node_index(2))],
            vec![(node_index(5), node_index(1))],
            vec![
                (node_index(1), node_index(3)),
                (node_index(7), node_index(5))
            ],
        ];
    }

    /// Computes, per sentence, the endpoints of all non-projective edges.
    fn sent_non_projective_edges(sents: &[Sentence]) -> Vec<Vec<(NodeIndex, NodeIndex)>> {
        sents
            .iter()
            .map(|sent| {
                let graph = simplify_graph(sent).unwrap();
                non_projective_edges(&graph)
                    .into_iter()
                    .map(|edge| graph.edge_endpoints(edge).unwrap())
                    .collect()
            })
            .collect()
    }

    /// Deprojectivizing the projective fixture must yield the
    /// non-projective fixture.
    #[test]
    fn deprojectivize_test() {
        let projectivizer = HeadProjectivizer::new();

        let mut sentences = read_sentences(PROJECTIVE_SENTENCES_FILENAME);
        for sentence in sentences.iter_mut() {
            projectivizer
                .deprojectivize(sentence)
                .expect("Cannot deprojectivize sentence");
        }

        assert_eq!(
            read_sentences(NONPROJECTIVE_SENTENCES_FILENAME),
            sentences
        );
    }

    /// The detected non-projective edges must match the expected list.
    #[test]
    fn non_projective_test() {
        let sentences = read_sentences(NONPROJECTIVE_SENTENCES_FILENAME);
        let found = sent_non_projective_edges(&sentences);
        assert_eq!(*NON_PROJECTIVE_EDGES, found);
    }

    /// Projectivizing the non-projective fixture must yield the projective
    /// fixture.
    #[test]
    fn projectivize_test() {
        let projectivizer = HeadProjectivizer::new();

        let mut sentences = read_sentences(NONPROJECTIVE_SENTENCES_FILENAME);
        for sentence in sentences.iter_mut() {
            projectivizer
                .projectivize(sentence)
                .expect("Cannot projectivize sentence");
        }

        assert_eq!(read_sentences(PROJECTIVE_SENTENCES_FILENAME), sentences);
    }
}