1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
//! Dependency parsing as sequence labeling (Spoustová & Spousta, 2010).

use serde_derive::{Deserialize, Serialize};

mod error;
pub use self::error::*;

mod post_processing;
pub(crate) use self::post_processing::*;

mod relative_position;
pub use self::relative_position::*;

mod relative_pos;
pub use self::relative_pos::*;

/// Token label that encodes a single dependency relation.
///
/// The relation is stored as a pair of a head representation of the generic
/// type `H` and a dependency label.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct DependencyEncoding<H> {
    head: H,
    label: String,
}

impl<H> DependencyEncoding<H> {
    /// Creates an encoding from a head representation and a dependency label.
    ///
    /// Anything convertible into a `String` is accepted as the label.
    pub fn new(head: H, label: impl Into<String>) -> Self {
        let label = label.into();
        Self { head, label }
    }

    /// Returns a reference to the head representation.
    pub fn head(&self) -> &H {
        &self.head
    }

    /// Returns the dependency label as a string slice.
    pub fn label(&self) -> &str {
        self.label.as_str()
    }
}

#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;
    use std::path::Path;

    use conllu::io::Reader;
    use udgraph::graph::{Node, Sentence};

    use super::{PosLayer, RelativePosEncoder, RelativePositionEncoder};
    use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};

    /// Path of the CoNLL-U test treebank used for the round-trip tests.
    const NON_PROJECTIVE_DATA: &str = "testdata/lassy-small-dev.conllu";

    /// Root relation name handed to the encoder constructors.
    const ROOT_RELATION: &str = "root";

    /// Builds a fresh sentence holding only the comments and tokens of
    /// `sentence`, so that the dependency relations have to be restored by
    /// a decoder.
    fn copy_sentence_without_deprels(sentence: &Sentence) -> Sentence {
        let mut stripped = Sentence::new();
        stripped.set_comments(sentence.comments().to_owned());

        for token in sentence.iter().filter_map(Node::token).cloned() {
            stripped.push(token);
        }

        stripped
    }

    /// Round-trip check over every sentence in the treebank at `path`.
    ///
    /// Each sentence is encoded, every encoding is wrapped as the single
    /// candidate with probability 1, and the result is decoded into a copy
    /// of the sentence without relations. The decoded sentence must be
    /// equal to the original one.
    fn test_encoding<P, E, C>(path: P, encoder_decoder: E)
    where
        P: AsRef<Path>,
        E: SentenceEncoder<Encoding = C> + SentenceDecoder<Encoding = C>,
        C: 'static + Clone,
    {
        let reader = Reader::new(BufReader::new(File::open(path).unwrap()));

        for gold in reader.into_iter().map(|sentence| sentence.unwrap()) {
            // One certain candidate per token.
            let labels: Vec<_> = encoder_decoder
                .encode(&gold)
                .unwrap()
                .into_iter()
                .map(|encoding| [EncodingProb::new(encoding, 1.)])
                .collect();

            let mut decoded = copy_sentence_without_deprels(&gold);
            encoder_decoder.decode(&labels, &mut decoded).unwrap();

            assert_eq!(gold, decoded);
        }
    }

    #[test]
    fn relative_pos_position() {
        test_encoding(
            NON_PROJECTIVE_DATA,
            RelativePosEncoder::new(PosLayer::XPos, ROOT_RELATION),
        );
    }

    #[test]
    fn relative_position() {
        test_encoding(
            NON_PROJECTIVE_DATA,
            RelativePositionEncoder::new(ROOT_RELATION),
        );
    }
}