udgraph/
graph.rs

1//! Dependency graphs.
2
3use std::borrow::Borrow;
4use std::fmt::{self, Display, Formatter};
5use std::iter::FromIterator;
6use std::mem;
7use std::ops::{Index, IndexMut};
8
9use petgraph::graph::{node_index, DiGraph, NodeIndices, NodeWeightsMut};
10use petgraph::visit::EdgeRef;
11use petgraph::Direction;
12
13use crate::error::Error;
14use crate::token::Token;
15
16/// Dependency graph node.
17#[derive(Clone, Debug, Eq, PartialEq)]
18pub enum Node {
19    /// Root node.
20    Root,
21
22    /// Token node.
23    Token(Token),
24}
25
26impl Node {
27    pub fn is_root(&self) -> bool {
28        !self.is_token()
29    }
30
31    pub fn is_token(&self) -> bool {
32        match self {
33            Node::Root => false,
34            Node::Token(_) => true,
35        }
36    }
37
38    pub fn token(&self) -> Option<&Token> {
39        match self {
40            Node::Root => None,
41            Node::Token(token) => Some(token),
42        }
43    }
44
45    pub fn token_mut(&mut self) -> Option<&mut Token> {
46        match self {
47            Node::Root => None,
48            Node::Token(token) => Some(token),
49        }
50    }
51}
52
/// A sentence comment.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Comment {
    /// Attribute-value pair, displayed as `# attr = val`.
    AttrVal { attr: String, val: String },

    /// Free-form string comment, displayed as `# string`.
    String(String),
}

impl Comment {
    /// Returns `true` if the comment is an attribute-value pair.
    pub fn is_attr_val(&self) -> bool {
        self.attr_val().is_some()
    }

    /// Returns `true` if the comment is a string.
    pub fn is_string(&self) -> bool {
        self.string().is_some()
    }

    /// Get the attribute and value of an attribute-value comment.
    ///
    /// Returns `None` for string comments.
    pub fn attr_val(&self) -> Option<(&str, &str)> {
        if let Comment::AttrVal { attr, val } = self {
            Some((attr.as_str(), val.as_str()))
        } else {
            None
        }
    }

    /// Get the text of a string comment.
    ///
    /// Returns `None` for attribute-value comments.
    pub fn string(&self) -> Option<&str> {
        if let Comment::String(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }
}

impl Display for Comment {
    fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> {
        match self {
            Comment::String(text) => write!(fmt, "# {}", text),
            Comment::AttrVal { attr, val } => write!(fmt, "# {} = {}", attr, val),
        }
    }
}
102
/// A dependency triple.
///
/// A dependency triple consists of: a head index; a dependent index; and
/// an optional dependency label. Triples are ordered by head, then by
/// dependent, then by relation.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct DepTriple<S> {
    // The declaration order of these fields determines the derived
    // `Ord`/`PartialOrd`, so it must not be changed.
    head: usize,
    dependent: usize,
    relation: Option<S>,
}

impl<S> DepTriple<S> {
    /// Construct a new dependency triple.
    ///
    /// The arguments follow the reading order `head --relation--> dependent`.
    pub fn new(head: usize, relation: Option<S>, dependent: usize) -> Self {
        Self {
            head,
            dependent,
            relation,
        }
    }

    /// Get the index of the dependent.
    pub fn dependent(&self) -> usize {
        self.dependent
    }

    /// Get the index of the head.
    pub fn head(&self) -> usize {
        self.head
    }
}

impl<S> DepTriple<S>
where
    S: Borrow<str>,
{
    /// Get the dependency relation label, if any.
    pub fn relation(&self) -> Option<&str> {
        self.relation.as_ref().map(<S as Borrow<str>>::borrow)
    }
}
143
/// Relation type.
///
/// This enum is stored in the edge weights of the underlying `petgraph`
/// graph to distinguish between universal dependencies and enhanced
/// universal dependencies. It is public because the underlying `DiGraph`
/// can be retrieved using the `get_ref` and `into_inner` methods of
/// `Sentence`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RelationType {
    /// A regular universal dependency relation.
    Regular,

    /// An enhanced universal dependency relation.
    Enhanced,
}

/// Dependency edge: the relation type paired with an optional relation label.
pub type Edge = (RelationType, Option<String>);
159
160/// A CoNLL-U dependency graph.
161///
162/// `Sentence` stores a dependency graph. The nodes in the graph
163/// (except the special root node) are tokens that have the fields
164/// of the CoNLL-U format. Dependency relations are stored as edges
165/// in the graph.
166///
167/// This data structure is a thin wrapper around the `petgraph`
168/// [`DiGraph`](petgraph::graph::DiGraph) data structure that enforces
169/// variants such as single-headedness. The
170/// [`into_inner`](Sentence::into_inner)/[`get_ref`](Sentence::get_ref)
171/// methods can be used to unwrap or get a reference to the wrapped graph.
172#[derive(Clone, Debug)]
173pub struct Sentence {
174    comments: Vec<Comment>,
175    graph: DiGraph<Node, Edge>,
176}
177
178#[allow(clippy::len_without_is_empty)]
179impl Sentence {
180    /// Construct a new sentence.
181    ///
182    /// The sentence will be constructed such that the first token is
183    /// the root of the dependency graph:
184    ///
185    /// ```
186    /// use udgraph::graph::{Node, Sentence};
187    ///
188    /// let sentence = Sentence::new();
189    /// assert_eq!(sentence[0], Node::Root);
190    /// ```
191    pub fn new() -> Self {
192        let mut graph = DiGraph::new();
193        graph.add_node(Node::Root);
194        Sentence {
195            comments: Vec::new(),
196            graph,
197        }
198    }
199
200    pub fn comments(&self) -> &[Comment] {
201        &self.comments
202    }
203
204    pub fn comments_mut(&mut self) -> &mut Vec<Comment> {
205        &mut self.comments
206    }
207
208    /// Get a reference to the [`DiGraph`](petgraph::graph::DiGraph) of
209    /// the sentence.
210    pub fn get_ref(&self) -> &DiGraph<Node, Edge> {
211        &self.graph
212    }
213
214    /// Unwrap the  [`DiGraph`](petgraph::graph::DiGraph) of the sentence.
215    pub fn into_inner(self) -> DiGraph<Node, Edge> {
216        self.graph
217    }
218
219    /// Get an iterator over the nodes in the graph.
220    pub fn iter(&self) -> Iter {
221        Iter {
222            inner: self.graph.node_indices(),
223            graph: &self.graph,
224        }
225    }
226
227    /// Get a mutable iterator over the nodes in the graph.
228    pub fn iter_mut(&mut self) -> IterMut {
229        IterMut(self.graph.node_weights_mut())
230    }
231
232    /// Add a new token to the graph.
233    ///
234    /// Tokens should always be pushed in sentence order.
235    ///
236    /// Returns the index of the token. The first pushed token has index 1,
237    /// since index 0 is reserved by the root of the graph.
238    pub fn push(&mut self, token: Token) -> usize {
239        self.graph.add_node(Node::Token(token)).index()
240    }
241
242    /// Get the dependency graph.
243    pub fn dep_graph(&self) -> DepGraph {
244        DepGraph {
245            inner: &self.graph,
246            relation_type: RelationType::Regular,
247        }
248    }
249
250    /// Get the graph mutably.
251    pub fn dep_graph_mut(&mut self) -> DepGraphMut {
252        DepGraphMut {
253            inner: &mut self.graph,
254            relation_type: RelationType::Regular,
255        }
256    }
257
258    /// Get the number of nodes in the dependency graph.
259    ///
260    /// This is equal to the number of tokens, plus one root node.
261    pub fn len(&self) -> usize {
262        self.graph.node_count()
263    }
264
265    /// Replace the comments by the given comments.
266    ///
267    /// Returns the old comments that are replaced.
268    pub fn set_comments(&mut self, comments: impl Into<Vec<Comment>>) {
269        let _ = mem::replace(&mut self.comments, comments.into());
270    }
271}
272
273impl Default for Sentence {
274    fn default() -> Self {
275        Sentence::new()
276    }
277}
278
279impl FromIterator<Token> for Sentence {
280    fn from_iter<T>(iter: T) -> Self
281    where
282        T: IntoIterator<Item = Token>,
283    {
284        let mut sentence = Sentence::new();
285        for token in iter {
286            sentence.push(token);
287        }
288        sentence
289    }
290}
291
292/// Iterator over the nodes in a dependency graph.
293pub struct Iter<'a> {
294    inner: NodeIndices,
295    graph: &'a DiGraph<Node, Edge>,
296}
297
298impl<'a> Iterator for Iter<'a> {
299    type Item = &'a Node;
300
301    fn next(&mut self) -> Option<Self::Item> {
302        self.inner.next().map(|idx| &self.graph[idx])
303    }
304}
305
306impl<'a> IntoIterator for &'a Sentence {
307    type Item = &'a Node;
308    type IntoIter = Iter<'a>;
309
310    fn into_iter(self) -> Self::IntoIter {
311        self.iter()
312    }
313}
314
315/// Mutable iterator over the nodes in a dependency graph.
316pub struct IterMut<'a>(NodeWeightsMut<'a, Node>);
317
318impl<'a> Iterator for IterMut<'a> {
319    type Item = &'a mut Node;
320
321    fn next(&mut self) -> Option<Self::Item> {
322        self.0.next()
323    }
324}
325
326impl<'a> IntoIterator for &'a mut Sentence {
327    type Item = &'a mut Node;
328    type IntoIter = IterMut<'a>;
329
330    fn into_iter(self) -> Self::IntoIter {
331        self.iter_mut()
332    }
333}
334
335impl Eq for Sentence {}
336
337impl From<Sentence> for DiGraph<Node, Edge> {
338    fn from(sentence: Sentence) -> Self {
339        sentence.into_inner()
340    }
341}
342
343impl<'a> From<&'a Sentence> for &'a DiGraph<Node, Edge> {
344    fn from(sentence: &'a Sentence) -> Self {
345        sentence.get_ref()
346    }
347}
348
349impl Index<usize> for Sentence {
350    type Output = Node;
351
352    fn index(&self, idx: usize) -> &Self::Output {
353        &self.graph[node_index(idx)]
354    }
355}
356
357impl IndexMut<usize> for Sentence {
358    fn index_mut(&mut self, idx: usize) -> &mut Self::Output {
359        &mut self.graph[node_index(idx)]
360    }
361}
362
363impl PartialEq for Sentence {
364    fn eq(&self, other: &Self) -> bool {
365        self.comments == other.comments && self.dep_graph() == other.dep_graph()
366    }
367}
368
369/// A graph view.
370///
371/// This data structure provides a view of a CoNLL-U dependency graph. The
372/// view can be used to retrieve the dependents of a head or the head of a
373/// dependent.
374pub struct DepGraph<'a> {
375    inner: &'a DiGraph<Node, Edge>,
376    relation_type: RelationType,
377}
378
379#[allow(clippy::len_without_is_empty)]
380impl<'a> DepGraph<'a> {
381    /// Return an iterator over the dependents of `head`.
382    pub fn dependents(&self, head: usize) -> impl Iterator<Item = DepTriple<&'a str>> {
383        dependents_impl(self.inner, self.relation_type, head)
384    }
385
386    /// Return the head relation of `dependent`, if any.
387    pub fn head(&self, dependent: usize) -> Option<DepTriple<&'a str>> {
388        head_impl(self.inner, self.relation_type, dependent)
389    }
390
391    /// Get the number of nodes in the dependency graph.
392    ///
393    /// This is equal to the number of tokens, plus one root node.
394    pub fn len(&self) -> usize {
395        self.inner.node_count()
396    }
397}
398
399impl<'a> Eq for DepGraph<'a> {}
400
401impl<'a> Index<usize> for DepGraph<'a> {
402    type Output = Node;
403
404    fn index(&self, idx: usize) -> &Self::Output {
405        &self.inner[node_index(idx)]
406    }
407}
408
409impl<'a, 'b> PartialEq<DepGraph<'b>> for DepGraph<'a> {
410    fn eq(&self, other: &DepGraph<'b>) -> bool {
411        // Cheap checks
412        if self.inner.node_count() != other.inner.node_count()
413            || self.inner.edge_count() != other.inner.edge_count()
414        {
415            return false;
416        }
417
418        for i in 0..self.len() {
419            // Nodes should be equal.
420            if self[i] != other[i] {
421                return false;
422            }
423
424            // Relation to a token's head should be the same.
425            if self.head(i) != other.head(i) {
426                return false;
427            }
428        }
429
430        true
431    }
432}
433
434/// A mutable graph view.
435///
436/// This data structure provides a mutable view of a CoNLL-U dependency
437/// graph. The view can be used to retrieve the dependents of a head or
438/// the head of a dependent. In addition, the [`add_deprel`](DepGraphMut::add_deprel)
439/// method can be used to add dependency relations to the graph.
440pub struct DepGraphMut<'a> {
441    inner: &'a mut DiGraph<Node, Edge>,
442    relation_type: RelationType,
443}
444
445#[allow(clippy::len_without_is_empty)]
446impl<'a> DepGraphMut<'a> {
447    /// Add a dependency relation between `head` and `dependent`.
448    ///
449    /// If `dependent` already has a head relation, this relation is removed
450    /// to ensure single-headedness.
451    pub fn add_deprel<S>(&mut self, triple: DepTriple<S>) -> Result<(), Error>
452    where
453        S: Into<String>,
454    {
455        if triple.head() >= self.inner.node_count() {
456            return Err(Error::HeadOutOfBounds {
457                head: triple.head(),
458                node_count: self.inner.node_count(),
459            });
460        }
461
462        if triple.dependent() >= self.inner.node_count() {
463            return Err(Error::DependentOutOfBounds {
464                dependent: triple.head(),
465                node_count: self.inner.node_count(),
466            });
467        }
468
469        // Remove existing head relation (when present).
470        if let Some(id) = self
471            .inner
472            .edges_directed(node_index(triple.dependent), Direction::Incoming)
473            .filter(|e| e.weight().0 == self.relation_type)
474            .map(|e| e.id())
475            .next()
476        {
477            self.inner.remove_edge(id);
478        }
479
480        self.inner.add_edge(
481            node_index(triple.head),
482            node_index(triple.dependent),
483            (self.relation_type, triple.relation.map(Into::into)),
484        );
485
486        Ok(())
487    }
488
489    /// Return an iterator over the dependents of `head`.
490    pub fn dependents(&self, head: usize) -> impl Iterator<Item = DepTriple<&str>> {
491        dependents_impl(self.inner, self.relation_type, head)
492    }
493
494    /// Return the head relation of `dependent`, if any.
495    pub fn head(&self, dependent: usize) -> Option<DepTriple<&str>> {
496        head_impl(self.inner, self.relation_type, dependent)
497    }
498
499    /// Remove relation of a token to its head.
500    ///
501    /// Returns the index of the head iff a head was removed.
502    pub fn remove_head_rel(&mut self, dependent: usize) -> Option<DepTriple<String>> {
503        // match instead of map to avoid simultaneous mutable and
504        // immutable borrow.
505        match self
506            .inner
507            .edges_directed(node_index(dependent), Direction::Incoming)
508            .find(|e| e.weight().0 == self.relation_type)
509        {
510            Some(edge) => {
511                let head = edge.source().index();
512                let edge_id = edge.id();
513                let weight = self.inner.remove_edge(edge_id);
514                Some(DepTriple::new(head, weight.unwrap().1, dependent))
515            }
516            None => None,
517        }
518    }
519
520    /// Get the number of nodes in the dependency graph.
521    ///
522    /// This is equal to the number of tokens, plus one root node.
523    pub fn len(&self) -> usize {
524        self.inner.node_count()
525    }
526}
527
528impl<'a> Index<usize> for DepGraphMut<'a> {
529    type Output = Node;
530
531    fn index(&self, idx: usize) -> &Self::Output {
532        &self.inner[node_index(idx)]
533    }
534}
535
536impl<'a> IndexMut<usize> for DepGraphMut<'a> {
537    fn index_mut(&mut self, idx: usize) -> &mut Self::Output {
538        &mut self.inner[node_index(idx)]
539    }
540}
541
542fn dependents_impl(
543    graph: &DiGraph<Node, Edge>,
544    relation_type: RelationType,
545    head: usize,
546) -> impl Iterator<Item = DepTriple<&str>> {
547    graph
548        .edges_directed(node_index(head), Direction::Outgoing)
549        .filter(move |e| e.weight().0 == relation_type)
550        .map(|e| {
551            DepTriple::new(
552                e.source().index(),
553                e.weight().1.as_deref(),
554                e.target().index(),
555            )
556        })
557}
558
559fn head_impl(
560    graph: &DiGraph<Node, Edge>,
561    relation_type: RelationType,
562    dependent: usize,
563) -> Option<DepTriple<&str>> {
564    graph
565        .edges_directed(node_index(dependent), Direction::Incoming)
566        .find(|e| e.weight().0 == relation_type)
567        .map(|e| {
568            DepTriple::new(
569                e.source().index(),
570                e.weight().1.as_deref(),
571                e.target().index(),
572            )
573        })
574}
575
#[cfg(test)]
mod tests {
    use super::{DepTriple, Node, Sentence, Token};

    /// Build a sentence from the given word forms (indices start at 1).
    fn sentence(forms: &[&str]) -> Sentence {
        forms.iter().map(|&form| Token::new(form)).collect()
    }

    /// Sentence "Daniël" (1), "test" (2), "dit" (3).
    fn daniel_test_dit() -> Sentence {
        sentence(&["Daniël", "test", "dit"])
    }

    /// Add `head --rel--> dependent` to the regular dependency graph,
    /// panicking if the relation is rejected.
    fn attach(sentence: &mut Sentence, head: usize, rel: &str, dependent: usize) {
        sentence
            .dep_graph_mut()
            .add_deprel(DepTriple::new(head, Some(rel), dependent))
            .unwrap();
    }

    #[test]
    fn add_deprel() {
        let mut g = daniel_test_dit();
        attach(&mut g, 0, "wrong", 1);
        attach(&mut g, 0, "root", 2);

        let graph = g.dep_graph();
        assert!(graph.head(0).is_none());
        assert_eq!(graph.head(1), Some(DepTriple::new(0, Some("wrong"), 1)));
        assert_eq!(graph.head(2), Some(DepTriple::new(0, Some("root"), 2)));
        assert!(graph.head(3).is_none());

        // Re-attaching token 1 must replace its previous head.
        attach(&mut g, 2, "subj", 1);
        attach(&mut g, 2, "obj1", 3);

        let graph = g.dep_graph();
        assert_eq!(graph.head(1), Some(DepTriple::new(2, Some("subj"), 1)));
        assert_eq!(graph.head(3), Some(DepTriple::new(2, Some("obj1"), 3)));
    }

    #[test]
    fn dependents() {
        let mut g = daniel_test_dit();
        attach(&mut g, 0, "root", 2);
        attach(&mut g, 2, "subj", 1);
        attach(&mut g, 2, "obj1", 3);

        let graph = g.dep_graph();

        let root_deps = graph.dependents(0).collect::<Vec<_>>();
        assert_eq!(&root_deps, &[DepTriple::new(0, Some("root"), 2)]);

        assert!(graph.dependents(1).next().is_none());

        // Edge iteration order is not guaranteed, so sort before comparing.
        let mut verb_deps = graph.dependents(2).collect::<Vec<_>>();
        verb_deps.sort();
        assert_eq!(
            &verb_deps,
            &[
                DepTriple::new(2, Some("subj"), 1),
                DepTriple::new(2, Some("obj1"), 3),
            ]
        );

        assert!(graph.dependents(3).next().is_none());
    }

    #[test]
    fn equality() {
        let mut g1 = sentence(&["does", "equality", "work"]);

        let g2 = g1.clone();
        assert_eq!(g1, g2);

        g1.push(Token::new("?"));
        assert_ne!(g1, g2);

        let mut g3 = g1.clone();
        attach(&mut g1, 0, "root", 3);
        attach(&mut g1, 3, "subj", 1);
        assert_ne!(g1, g3);

        attach(&mut g3, 0, "root", 3);
        attach(&mut g3, 3, "subj", 1);
        assert_eq!(g1, g3);

        // Relabeling a relation must break equality.
        attach(&mut g3, 3, "foobar", 1);
        assert_ne!(g1, g3);

        // Changing a token must break equality.
        let mut g4 = g1.clone();
        if let Node::Token(ref mut token) = g4[3] {
            token.set_xpos(Some("verb"));
        }
        assert_ne!(g1, g4);
    }

    #[test]
    #[should_panic(expected = "HeadOutOfBounds")]
    fn incorrect_head_is_rejected() {
        let mut g = daniel_test_dit();
        attach(&mut g, 4, "test", 3);
    }

    #[test]
    #[should_panic(expected = "DependentOutOfBounds")]
    fn incorrect_dependent_is_rejected() {
        let mut g = daniel_test_dit();
        attach(&mut g, 3, "test", 4);
    }

    #[test]
    fn remove_deprel() {
        let mut g = daniel_test_dit();
        attach(&mut g, 0, "wrong", 1);
        attach(&mut g, 0, "root", 2);

        assert_eq!(
            g.dep_graph_mut().remove_head_rel(1),
            Some(DepTriple::new(0, Some("wrong".to_owned()), 1))
        );
        // The root has no head, so there is nothing to remove.
        assert!(g.dep_graph_mut().remove_head_rel(0).is_none());

        let graph = g.dep_graph();
        assert!(graph.head(0).is_none());
        assert!(graph.head(1).is_none());
        assert_eq!(graph.head(2), Some(DepTriple::new(0, Some("root"), 2)));
    }
}