1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
//! Dependency relation encoders.

use serde_derive::{Deserialize, Serialize};

mod error;
pub use self::error::*;

mod post_processing;
pub(crate) use self::post_processing::*;

mod relative_position;
pub use self::relative_position::*;

mod relative_pos;
pub use self::relative_pos::*;

/// Encoding of a dependency relation as a token label.
///
/// Pairs a head representation of type `H` with the dependency label of
/// the relation. Both fields are private; read them through
/// `DependencyEncoding::head` and `DependencyEncoding::label`.
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct DependencyEncoding<H> {
    /// Representation of the head; its concrete form is chosen by the
    /// type substituted for `H`.
    head: H,
    /// Dependency label, stored as an owned string.
    label: String,
}

impl<H> DependencyEncoding<H> {
    pub fn new(head: H, label: impl Into<String>) -> Self {
        DependencyEncoding {
            head,
            label: label.into(),
        }
    }

    /// Get the head representation.
    pub fn head(&self) -> &H {
        &self.head
    }

    /// Get the dependency label.
    pub fn label(&self) -> &str {
        &self.label
    }
}

#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;
    use std::path::Path;

    use conllu::graph::{Node, Sentence};
    use conllu::io::Reader;

    use super::{POSLayer, RelativePOSEncoder, RelativePositionEncoder};
    use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};

    /// Path of the CoNLL-U corpus used by the round-trip tests.
    // NOTE(review): the constant name suggests the corpus contains
    // non-projective trees -- confirm against the test data.
    const NON_PROJECTIVE_DATA: &str = "testdata/lassy-small-dev.conllu";

    /// Relation name handed to the encoders as the root relation.
    const ROOT_RELATION: &str = "root";

    /// Build a fresh sentence that carries over only the comments and the
    /// (cloned) tokens of `sentence`.
    ///
    /// Nothing else from the source sentence is copied, so the result can be
    /// used as the decoding target in `test_encoding`.
    fn copy_sentence_without_deprels(sentence: &Sentence) -> Sentence {
        let mut stripped = Sentence::new();
        stripped.set_comments(sentence.comments().to_owned());

        // Only nodes for which `Node::token` yields a token are copied.
        sentence.iter().filter_map(Node::token).for_each(|token| {
            stripped.push(token.clone());
        });

        stripped
    }

    /// Round-trip every sentence of the corpus at `path` through
    /// `encoder_decoder`.
    ///
    /// Each sentence is encoded, every encoding is wrapped in a
    /// single-element array with probability `1.`, and the result is decoded
    /// into a copy of the sentence made by `copy_sentence_without_deprels`.
    /// The decoded copy must equal the original sentence.
    ///
    /// Panics if the file cannot be opened, a sentence cannot be read, or
    /// encoding/decoding fails.
    fn test_encoding<P, E, C>(path: P, encoder_decoder: E)
    where
        P: AsRef<Path>,
        E: SentenceEncoder<Encoding = C> + SentenceDecoder<Encoding = C>,
        C: 'static + Clone,
    {
        let corpus = BufReader::new(File::open(path).unwrap());

        for parsed in Reader::new(corpus) {
            let gold = parsed.unwrap();

            // Encode: one single-candidate array per encoding.
            let labels = encoder_decoder.encode(&gold).unwrap();
            let encodings: Vec<_> = labels
                .into_iter()
                .map(|label| [EncodingProb::new(label, 1.)])
                .collect();

            // Decode into a copy holding only comments and tokens.
            let mut decoded = copy_sentence_without_deprels(&gold);
            encoder_decoder.decode(&encodings, &mut decoded).unwrap();

            assert_eq!(gold, decoded);
        }
    }

    /// Round-trip test for `RelativePOSEncoder` using the `XPos` layer.
    #[test]
    fn relative_pos_position() {
        test_encoding(
            NON_PROJECTIVE_DATA,
            RelativePOSEncoder::new(POSLayer::XPos, ROOT_RELATION),
        );
    }

    /// Round-trip test for `RelativePositionEncoder`.
    #[test]
    fn relative_position() {
        test_encoding(
            NON_PROJECTIVE_DATA,
            RelativePositionEncoder::new(ROOT_RELATION),
        );
    }
}