1use std::borrow::Borrow;
4use std::fmt::{self, Display, Formatter};
5use std::iter::FromIterator;
6use std::mem;
7use std::ops::{Index, IndexMut};
8
9use petgraph::graph::{node_index, DiGraph, NodeIndices, NodeWeightsMut};
10use petgraph::visit::EdgeRef;
11use petgraph::Direction;
12
13use crate::error::Error;
14use crate::token::Token;
15
16#[derive(Clone, Debug, Eq, PartialEq)]
18pub enum Node {
19 Root,
21
22 Token(Token),
24}
25
26impl Node {
27 pub fn is_root(&self) -> bool {
28 !self.is_token()
29 }
30
31 pub fn is_token(&self) -> bool {
32 match self {
33 Node::Root => false,
34 Node::Token(_) => true,
35 }
36 }
37
38 pub fn token(&self) -> Option<&Token> {
39 match self {
40 Node::Root => None,
41 Node::Token(token) => Some(token),
42 }
43 }
44
45 pub fn token_mut(&mut self) -> Option<&mut Token> {
46 match self {
47 Node::Root => None,
48 Node::Token(token) => Some(token),
49 }
50 }
51}
52
/// A comment attached to a sentence.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Comment {
    /// An attribute/value pair, rendered as `# attr = val`.
    AttrVal { attr: String, val: String },

    /// A free-form comment, rendered as `# text`.
    String(String),
}

impl Comment {
    /// Returns `true` when this comment is an attribute/value pair.
    pub fn is_attr_val(&self) -> bool {
        self.attr_val().is_some()
    }

    /// Returns `true` when this comment is a free-form string.
    pub fn is_string(&self) -> bool {
        self.string().is_some()
    }

    /// Returns `(attribute, value)` for an attribute/value comment and `None`
    /// for a free-form comment.
    pub fn attr_val(&self) -> Option<(&str, &str)> {
        if let Comment::AttrVal { attr, val } = self {
            Some((attr.as_str(), val.as_str()))
        } else {
            None
        }
    }

    /// Returns the text of a free-form comment and `None` for an
    /// attribute/value comment.
    pub fn string(&self) -> Option<&str> {
        if let Comment::String(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }
}

impl Display for Comment {
    /// Writes the comment with a leading `# ` marker.
    fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> {
        match self {
            Comment::String(text) => write!(fmt, "# {}", text),
            Comment::AttrVal { attr, val } => write!(fmt, "# {} = {}", attr, val),
        }
    }
}
102
/// A dependency relation: a head index, a dependent index and an optional
/// relation label.
///
/// The derived ordering compares `head` first, then `dependent`, then
/// `relation`, following the field declaration order.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct DepTriple<S> {
    head: usize,
    dependent: usize,
    relation: Option<S>,
}

impl<S> DepTriple<S> {
    /// Creates a triple from `head`, an optional `relation` label and
    /// `dependent` (note the argument order).
    pub fn new(head: usize, relation: Option<S>, dependent: usize) -> Self {
        Self {
            head,
            dependent,
            relation,
        }
    }

    /// Returns the index of the dependent.
    pub fn dependent(&self) -> usize {
        self.dependent
    }

    /// Returns the index of the head.
    pub fn head(&self) -> usize {
        self.head
    }
}

impl<S> DepTriple<S>
where
    S: Borrow<str>,
{
    /// Returns the relation label as a string slice, if one is set.
    pub fn relation(&self) -> Option<&str> {
        let label = self.relation.as_ref()?;
        Some(<S as Borrow<str>>::borrow(label))
    }
}
143
/// The kind of a relation stored on a graph edge.
///
/// Views over the graph filter edges by this value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RelationType {
    /// A regular relation.
    Regular,
    /// An enhanced relation.
    // NOTE(review): presumably Universal Dependencies "enhanced" relations — confirm.
    Enhanced,
}
156
157pub type Edge = (RelationType, Option<String>);
159
160#[derive(Clone, Debug)]
173pub struct Sentence {
174 comments: Vec<Comment>,
175 graph: DiGraph<Node, Edge>,
176}
177
178#[allow(clippy::len_without_is_empty)]
179impl Sentence {
180 pub fn new() -> Self {
192 let mut graph = DiGraph::new();
193 graph.add_node(Node::Root);
194 Sentence {
195 comments: Vec::new(),
196 graph,
197 }
198 }
199
200 pub fn comments(&self) -> &[Comment] {
201 &self.comments
202 }
203
204 pub fn comments_mut(&mut self) -> &mut Vec<Comment> {
205 &mut self.comments
206 }
207
208 pub fn get_ref(&self) -> &DiGraph<Node, Edge> {
211 &self.graph
212 }
213
214 pub fn into_inner(self) -> DiGraph<Node, Edge> {
216 self.graph
217 }
218
219 pub fn iter(&self) -> Iter {
221 Iter {
222 inner: self.graph.node_indices(),
223 graph: &self.graph,
224 }
225 }
226
227 pub fn iter_mut(&mut self) -> IterMut {
229 IterMut(self.graph.node_weights_mut())
230 }
231
232 pub fn push(&mut self, token: Token) -> usize {
239 self.graph.add_node(Node::Token(token)).index()
240 }
241
242 pub fn dep_graph(&self) -> DepGraph {
244 DepGraph {
245 inner: &self.graph,
246 relation_type: RelationType::Regular,
247 }
248 }
249
250 pub fn dep_graph_mut(&mut self) -> DepGraphMut {
252 DepGraphMut {
253 inner: &mut self.graph,
254 relation_type: RelationType::Regular,
255 }
256 }
257
258 pub fn len(&self) -> usize {
262 self.graph.node_count()
263 }
264
265 pub fn set_comments(&mut self, comments: impl Into<Vec<Comment>>) {
269 let _ = mem::replace(&mut self.comments, comments.into());
270 }
271}
272
273impl Default for Sentence {
274 fn default() -> Self {
275 Sentence::new()
276 }
277}
278
279impl FromIterator<Token> for Sentence {
280 fn from_iter<T>(iter: T) -> Self
281 where
282 T: IntoIterator<Item = Token>,
283 {
284 let mut sentence = Sentence::new();
285 for token in iter {
286 sentence.push(token);
287 }
288 sentence
289 }
290}
291
292pub struct Iter<'a> {
294 inner: NodeIndices,
295 graph: &'a DiGraph<Node, Edge>,
296}
297
298impl<'a> Iterator for Iter<'a> {
299 type Item = &'a Node;
300
301 fn next(&mut self) -> Option<Self::Item> {
302 self.inner.next().map(|idx| &self.graph[idx])
303 }
304}
305
306impl<'a> IntoIterator for &'a Sentence {
307 type Item = &'a Node;
308 type IntoIter = Iter<'a>;
309
310 fn into_iter(self) -> Self::IntoIter {
311 self.iter()
312 }
313}
314
315pub struct IterMut<'a>(NodeWeightsMut<'a, Node>);
317
318impl<'a> Iterator for IterMut<'a> {
319 type Item = &'a mut Node;
320
321 fn next(&mut self) -> Option<Self::Item> {
322 self.0.next()
323 }
324}
325
326impl<'a> IntoIterator for &'a mut Sentence {
327 type Item = &'a mut Node;
328 type IntoIter = IterMut<'a>;
329
330 fn into_iter(self) -> Self::IntoIter {
331 self.iter_mut()
332 }
333}
334
335impl Eq for Sentence {}
336
337impl From<Sentence> for DiGraph<Node, Edge> {
338 fn from(sentence: Sentence) -> Self {
339 sentence.into_inner()
340 }
341}
342
343impl<'a> From<&'a Sentence> for &'a DiGraph<Node, Edge> {
344 fn from(sentence: &'a Sentence) -> Self {
345 sentence.get_ref()
346 }
347}
348
349impl Index<usize> for Sentence {
350 type Output = Node;
351
352 fn index(&self, idx: usize) -> &Self::Output {
353 &self.graph[node_index(idx)]
354 }
355}
356
357impl IndexMut<usize> for Sentence {
358 fn index_mut(&mut self, idx: usize) -> &mut Self::Output {
359 &mut self.graph[node_index(idx)]
360 }
361}
362
363impl PartialEq for Sentence {
364 fn eq(&self, other: &Self) -> bool {
365 self.comments == other.comments && self.dep_graph() == other.dep_graph()
366 }
367}
368
369pub struct DepGraph<'a> {
375 inner: &'a DiGraph<Node, Edge>,
376 relation_type: RelationType,
377}
378
379#[allow(clippy::len_without_is_empty)]
380impl<'a> DepGraph<'a> {
381 pub fn dependents(&self, head: usize) -> impl Iterator<Item = DepTriple<&'a str>> {
383 dependents_impl(self.inner, self.relation_type, head)
384 }
385
386 pub fn head(&self, dependent: usize) -> Option<DepTriple<&'a str>> {
388 head_impl(self.inner, self.relation_type, dependent)
389 }
390
391 pub fn len(&self) -> usize {
395 self.inner.node_count()
396 }
397}
398
399impl<'a> Eq for DepGraph<'a> {}
400
401impl<'a> Index<usize> for DepGraph<'a> {
402 type Output = Node;
403
404 fn index(&self, idx: usize) -> &Self::Output {
405 &self.inner[node_index(idx)]
406 }
407}
408
409impl<'a, 'b> PartialEq<DepGraph<'b>> for DepGraph<'a> {
410 fn eq(&self, other: &DepGraph<'b>) -> bool {
411 if self.inner.node_count() != other.inner.node_count()
413 || self.inner.edge_count() != other.inner.edge_count()
414 {
415 return false;
416 }
417
418 for i in 0..self.len() {
419 if self[i] != other[i] {
421 return false;
422 }
423
424 if self.head(i) != other.head(i) {
426 return false;
427 }
428 }
429
430 true
431 }
432}
433
434pub struct DepGraphMut<'a> {
441 inner: &'a mut DiGraph<Node, Edge>,
442 relation_type: RelationType,
443}
444
445#[allow(clippy::len_without_is_empty)]
446impl<'a> DepGraphMut<'a> {
447 pub fn add_deprel<S>(&mut self, triple: DepTriple<S>) -> Result<(), Error>
452 where
453 S: Into<String>,
454 {
455 if triple.head() >= self.inner.node_count() {
456 return Err(Error::HeadOutOfBounds {
457 head: triple.head(),
458 node_count: self.inner.node_count(),
459 });
460 }
461
462 if triple.dependent() >= self.inner.node_count() {
463 return Err(Error::DependentOutOfBounds {
464 dependent: triple.head(),
465 node_count: self.inner.node_count(),
466 });
467 }
468
469 if let Some(id) = self
471 .inner
472 .edges_directed(node_index(triple.dependent), Direction::Incoming)
473 .filter(|e| e.weight().0 == self.relation_type)
474 .map(|e| e.id())
475 .next()
476 {
477 self.inner.remove_edge(id);
478 }
479
480 self.inner.add_edge(
481 node_index(triple.head),
482 node_index(triple.dependent),
483 (self.relation_type, triple.relation.map(Into::into)),
484 );
485
486 Ok(())
487 }
488
489 pub fn dependents(&self, head: usize) -> impl Iterator<Item = DepTriple<&str>> {
491 dependents_impl(self.inner, self.relation_type, head)
492 }
493
494 pub fn head(&self, dependent: usize) -> Option<DepTriple<&str>> {
496 head_impl(self.inner, self.relation_type, dependent)
497 }
498
499 pub fn remove_head_rel(&mut self, dependent: usize) -> Option<DepTriple<String>> {
503 match self
506 .inner
507 .edges_directed(node_index(dependent), Direction::Incoming)
508 .find(|e| e.weight().0 == self.relation_type)
509 {
510 Some(edge) => {
511 let head = edge.source().index();
512 let edge_id = edge.id();
513 let weight = self.inner.remove_edge(edge_id);
514 Some(DepTriple::new(head, weight.unwrap().1, dependent))
515 }
516 None => None,
517 }
518 }
519
520 pub fn len(&self) -> usize {
524 self.inner.node_count()
525 }
526}
527
528impl<'a> Index<usize> for DepGraphMut<'a> {
529 type Output = Node;
530
531 fn index(&self, idx: usize) -> &Self::Output {
532 &self.inner[node_index(idx)]
533 }
534}
535
536impl<'a> IndexMut<usize> for DepGraphMut<'a> {
537 fn index_mut(&mut self, idx: usize) -> &mut Self::Output {
538 &mut self.inner[node_index(idx)]
539 }
540}
541
542fn dependents_impl(
543 graph: &DiGraph<Node, Edge>,
544 relation_type: RelationType,
545 head: usize,
546) -> impl Iterator<Item = DepTriple<&str>> {
547 graph
548 .edges_directed(node_index(head), Direction::Outgoing)
549 .filter(move |e| e.weight().0 == relation_type)
550 .map(|e| {
551 DepTriple::new(
552 e.source().index(),
553 e.weight().1.as_deref(),
554 e.target().index(),
555 )
556 })
557}
558
559fn head_impl(
560 graph: &DiGraph<Node, Edge>,
561 relation_type: RelationType,
562 dependent: usize,
563) -> Option<DepTriple<&str>> {
564 graph
565 .edges_directed(node_index(dependent), Direction::Incoming)
566 .find(|e| e.weight().0 == relation_type)
567 .map(|e| {
568 DepTriple::new(
569 e.source().index(),
570 e.weight().1.as_deref(),
571 e.target().index(),
572 )
573 })
574}
575
#[cfg(test)]
mod tests {
    use super::{DepTriple, Node, Sentence, Token};

    /// Builds the three-token sentence used by most tests (no relations yet).
    fn daniel_sentence() -> Sentence {
        let mut sentence = Sentence::default();
        sentence.push(Token::new("Daniël"));
        sentence.push(Token::new("test"));
        sentence.push(Token::new("dit"));
        sentence
    }

    /// Attaches `dependent` to `head` with `relation`; panics if the graph
    /// rejects the relation.
    fn attach(sentence: &mut Sentence, head: usize, relation: &str, dependent: usize) {
        sentence
            .dep_graph_mut()
            .add_deprel(DepTriple::new(head, Some(relation), dependent))
            .unwrap();
    }

    #[test]
    fn add_deprel() {
        let mut g = daniel_sentence();
        attach(&mut g, 0, "wrong", 1);
        attach(&mut g, 0, "root", 2);

        assert!(g.dep_graph().head(0).is_none());
        assert_eq!(
            g.dep_graph().head(1),
            Some(DepTriple::new(0, Some("wrong"), 1))
        );
        assert_eq!(
            g.dep_graph().head(2),
            Some(DepTriple::new(0, Some("root"), 2))
        );
        assert!(g.dep_graph().head(3).is_none());

        // Attaching token 1 again replaces its earlier "wrong" relation.
        attach(&mut g, 2, "subj", 1);
        attach(&mut g, 2, "obj1", 3);
        assert_eq!(
            g.dep_graph().head(1),
            Some(DepTriple::new(2, Some("subj"), 1))
        );
        assert_eq!(
            g.dep_graph().head(3),
            Some(DepTriple::new(2, Some("obj1"), 3))
        );
    }

    #[test]
    fn dependents() {
        let mut g = daniel_sentence();
        attach(&mut g, 0, "root", 2);
        attach(&mut g, 2, "subj", 1);
        attach(&mut g, 2, "obj1", 3);

        let root_deps = g.dep_graph().dependents(0).collect::<Vec<_>>();
        assert_eq!(&root_deps, &[DepTriple::new(0, Some("root"), 2)]);

        assert!(g.dep_graph().dependents(1).next().is_none());

        // Edge iteration order is not asserted, so sort before comparing.
        let mut verb_deps = g.dep_graph().dependents(2).collect::<Vec<_>>();
        verb_deps.sort();
        assert_eq!(
            &verb_deps,
            &[
                DepTriple::new(2, Some("subj"), 1),
                DepTriple::new(2, Some("obj1"), 3),
            ]
        );

        assert!(g.dep_graph().dependents(3).next().is_none());
    }

    #[test]
    fn equality() {
        let mut g1 = Sentence::default();
        g1.push(Token::new("does"));
        g1.push(Token::new("equality"));
        g1.push(Token::new("work"));

        let g2 = g1.clone();
        assert_eq!(g1, g2);

        // An extra token makes the sentences differ.
        g1.push(Token::new("?"));
        assert_ne!(g1, g2);

        // Relations present in only one sentence make them differ.
        let mut g3 = g1.clone();
        attach(&mut g1, 0, "root", 3);
        attach(&mut g1, 3, "subj", 1);
        assert_ne!(g1, g3);

        // The same relations on both sides restore equality.
        attach(&mut g3, 0, "root", 3);
        attach(&mut g3, 3, "subj", 1);
        assert_eq!(g1, g3);

        // A different relation label makes them differ again.
        attach(&mut g3, 3, "foobar", 1);
        assert_ne!(g1, g3);

        // A change to a token's content is detected as well.
        let mut g4 = g1.clone();
        if let Some(token) = g4[3].token_mut() {
            token.set_xpos(Some("verb"));
        }
        assert_ne!(g1, g4);
    }

    #[test]
    #[should_panic(expected = "HeadOutOfBounds")]
    fn incorrect_head_is_rejected() {
        let mut g = daniel_sentence();
        attach(&mut g, 4, "test", 3);
    }

    #[test]
    #[should_panic(expected = "DependentOutOfBounds")]
    fn incorrect_dependent_is_rejected() {
        let mut g = daniel_sentence();
        attach(&mut g, 3, "test", 4);
    }

    #[test]
    fn remove_deprel() {
        let mut g = daniel_sentence();
        attach(&mut g, 0, "wrong", 1);
        attach(&mut g, 0, "root", 2);

        assert_eq!(
            g.dep_graph_mut().remove_head_rel(1),
            Some(DepTriple::new(0, Some("wrong".to_owned()), 1))
        );
        assert!(g.dep_graph_mut().remove_head_rel(0).is_none());

        assert!(g.dep_graph().head(0).is_none());
        assert!(g.dep_graph().head(1).is_none());
        assert_eq!(
            g.dep_graph().head(2),
            Some(DepTriple::new(0, Some("root"), 2))
        );
    }
}