syntaxdot_encoders/depseq/
mod.rs1use serde_derive::{Deserialize, Serialize};
4
5mod error;
6pub use self::error::*;
7
8mod post_processing;
9pub(crate) use self::post_processing::*;
10
11mod relative_position;
12pub use self::relative_position::*;
13
14mod relative_pos;
15pub use self::relative_pos::*;
16
/// Encoded form of a token's dependency relation.
///
/// Pairs a head representation of type `H` with the label of the
/// dependency relation. The head type is generic so that each encoder can
/// pick its own head representation (NOTE(review): presumably a relative
/// position in the `relative_position` / `relative_pos` encoders declared in
/// this module — confirm there).
///
/// Field names and order are part of the derived serialization and hash
/// behavior; do not rename or reorder them casually.
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct DependencyEncoding<H> {
    /// Encoded head of the token.
    head: H,
    /// Label of the dependency relation.
    label: String,
}
23
24impl<H> DependencyEncoding<H> {
25 pub fn new(head: H, label: impl Into<String>) -> Self {
26 DependencyEncoding {
27 head,
28 label: label.into(),
29 }
30 }
31
32 pub fn head(&self) -> &H {
34 &self.head
35 }
36
37 pub fn label(&self) -> &str {
39 &self.label
40 }
41}
42
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;
    use std::path::Path;

    use conllu::io::Reader;
    use udgraph::graph::{Node, Sentence};

    use super::{PosLayer, RelativePosEncoder, RelativePositionEncoder};
    use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};

    /// CoNLL-U treebank used for the round-trip tests. The constant name
    /// suggests it contains non-projective trees (NOTE(review): not verified
    /// here).
    const NON_PROJECTIVE_DATA: &str = "testdata/lassy-small-dev.conllu";

    /// Label passed to the encoders as the root relation.
    const ROOT_RELATION: &str = "root";

    /// Copies the comments and token nodes of `sentence` into a fresh
    /// sentence.
    ///
    /// Only the tokens are pushed into the copy; the dependency relations of
    /// `sentence` are not transferred, so decoding has to restore them.
    fn copy_sentence_without_deprels(sentence: &Sentence) -> Sentence {
        let mut copy = Sentence::new();

        copy.set_comments(sentence.comments().to_owned());

        for token in sentence.iter().filter_map(Node::token) {
            copy.push(token.clone());
        }

        copy
    }

    /// Round-trips every sentence of the CoNLL-U file at `path` through
    /// `encoder_decoder`.
    ///
    /// Each sentence is encoded, every encoding is wrapped as a single
    /// candidate with probability 1, and the result is decoded into a copy
    /// of the sentence that lacks dependency relations. The decoded sentence
    /// must equal the original.
    ///
    /// # Panics
    ///
    /// Panics (failing the test) if the file cannot be opened or read, if
    /// encoding or decoding fails, or if a decoded sentence differs from the
    /// original. Panic messages name the file and the 0-based sentence index
    /// to make failures easy to locate.
    fn test_encoding<P, E, C>(path: P, encoder_decoder: E)
    where
        P: AsRef<Path>,
        E: SentenceEncoder<Encoding = C> + SentenceDecoder<Encoding = C>,
        C: 'static + Clone,
    {
        let path = path.as_ref();
        let f = File::open(path)
            .unwrap_or_else(|err| panic!("cannot open {}: {}", path.display(), err));
        let reader = Reader::new(BufReader::new(f));

        for (idx, sentence) in reader.into_iter().enumerate() {
            let sentence = sentence.unwrap_or_else(|err| {
                panic!(
                    "cannot read sentence {} of {}: {:?}",
                    idx,
                    path.display(),
                    err
                )
            });

            // Every encoding becomes a one-element candidate list with
            // probability 1, i.e. the decoder is given the gold encoding.
            let encodings = encoder_decoder
                .encode(&sentence)
                .unwrap_or_else(|err| {
                    panic!(
                        "cannot encode sentence {} of {}: {:?}",
                        idx,
                        path.display(),
                        err
                    )
                })
                .into_iter()
                .map(|e| [EncodingProb::new(e, 1.)])
                .collect::<Vec<_>>();

            let mut test_sentence = copy_sentence_without_deprels(&sentence);
            encoder_decoder
                .decode(&encodings, &mut test_sentence)
                .unwrap_or_else(|err| {
                    panic!(
                        "cannot decode sentence {} of {}: {:?}",
                        idx,
                        path.display(),
                        err
                    )
                });

            assert_eq!(
                sentence,
                test_sentence,
                "sentence {} of {} did not round-trip",
                idx,
                path.display()
            );
        }
    }

    #[test]
    fn relative_pos_position() {
        let encoder = RelativePosEncoder::new(PosLayer::XPos, ROOT_RELATION);
        test_encoding(NON_PROJECTIVE_DATA, encoder);
    }

    #[test]
    fn relative_position() {
        let encoder = RelativePositionEncoder::new(ROOT_RELATION);
        test_encoding(NON_PROJECTIVE_DATA, encoder);
    }
}