syntaxdot_encoders/depseq/
relative_position.rs1use serde_derive::{Deserialize, Serialize};
2use udgraph::graph::{DepTriple, Sentence};
3use udgraph::Error;
4
5use super::{
6 attach_orphans, break_cycles, find_or_create_root, DecodeError, DependencyEncoding, EncodeError,
7};
8use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};
9
10#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
14pub struct RelativePosition(isize);
15
16impl ToString for DependencyEncoding<RelativePosition> {
17 fn to_string(&self) -> String {
18 format!("{}/{}", self.label, self.head.0)
19 }
20}
21
22#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
28pub struct RelativePositionEncoder {
29 root_relation: String,
30}
31
32impl RelativePositionEncoder {
33 pub fn new(root_relation: impl Into<String>) -> Self {
34 RelativePositionEncoder {
35 root_relation: root_relation.into(),
36 }
37 }
38}
39
40impl RelativePositionEncoder {
41 fn decode_idx(
42 idx: usize,
43 sentence_len: usize,
44 encoding: &DependencyEncoding<RelativePosition>,
45 ) -> Result<DepTriple<String>, DecodeError> {
46 let DependencyEncoding {
47 label,
48 head: RelativePosition(head),
49 } = encoding;
50
51 let head_idx = idx as isize + head;
52 if head_idx < 0 || head_idx >= sentence_len as isize {
53 return Err(DecodeError::PositionOutOfBounds);
54 }
55
56 Ok(DepTriple::new(
57 (idx as isize + head) as usize,
58 Some(label.clone()),
59 idx,
60 ))
61 }
62}
63
64impl SentenceEncoder for RelativePositionEncoder {
65 type Encoding = DependencyEncoding<RelativePosition>;
66
67 type Error = EncodeError;
68
69 fn encode(&self, sentence: &Sentence) -> Result<Vec<Self::Encoding>, Self::Error> {
70 let mut encoded = Vec::with_capacity(sentence.len());
71 for idx in 1..sentence.len() {
72 let triple = sentence
73 .dep_graph()
74 .head(idx)
75 .ok_or_else(|| EncodeError::missing_head(idx, sentence))?;
76 let relation = triple
77 .relation()
78 .ok_or_else(|| EncodeError::missing_relation(idx, sentence))?;
79
80 encoded.push(DependencyEncoding {
81 label: relation.to_owned(),
82 head: RelativePosition(triple.head() as isize - triple.dependent() as isize),
83 });
84 }
85
86 Ok(encoded)
87 }
88}
89
90impl SentenceDecoder for RelativePositionEncoder {
91 type Encoding = DependencyEncoding<RelativePosition>;
92
93 type Error = Error;
94
95 fn decode<S>(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error>
96 where
97 S: AsRef<[EncodingProb<Self::Encoding>]>,
98 {
99 #[allow(clippy::needless_collect)]
101 let token_indices: Vec<_> = (0..sentence.len())
102 .filter(|&idx| sentence[idx].is_token())
103 .collect();
104
105 for (idx, encodings) in token_indices.into_iter().zip(labels) {
106 for encoding in encodings.as_ref() {
107 if let Ok(triple) =
108 RelativePositionEncoder::decode_idx(idx, sentence.len(), encoding.encoding())
109 {
110 sentence.dep_graph_mut().add_deprel(triple)?;
111 break;
112 }
113 }
114 }
115
116 let sentence_len = sentence.len();
118 let root_idx = find_or_create_root(
119 labels,
120 sentence,
121 |idx, encoding| Self::decode_idx(idx, sentence_len, encoding).ok(),
122 &self.root_relation,
123 )?;
124 attach_orphans(labels, sentence, root_idx)?;
125 break_cycles(sentence, root_idx)?;
126
127 Ok(())
128 }
129}
130
#[cfg(test)]
mod tests {
    use udgraph::graph::{DepTriple, Sentence};
    use udgraph::token::TokenBuilder;

    use super::{RelativePosition, RelativePositionEncoder};
    use crate::depseq::{DecodeError, DependencyEncoding};
    use crate::{EncodingProb, SentenceDecoder};

    const ROOT_RELATION: &str = "root";

    /// Builds an encoding with the given relation label and relative head offset.
    fn encoding(label: &str, offset: isize) -> DependencyEncoding<RelativePosition> {
        DependencyEncoding {
            label: label.into(),
            head: RelativePosition(offset),
        }
    }

    /// An offset that points before the start of the sentence must be rejected.
    #[test]
    fn position_out_of_bounds() {
        let mut sent = Sentence::new();
        sent.push(TokenBuilder::new("a").xpos("A").into());
        sent.push(TokenBuilder::new("b").xpos("B").into());

        let decoded = RelativePositionEncoder::decode_idx(1, sent.len(), &encoding("X", -2));

        assert_eq!(decoded, Err(DecodeError::PositionOutOfBounds));
    }

    /// When the first candidate is out of bounds, the decoder falls back to
    /// the next candidate.
    #[test]
    fn backoff() {
        let mut sent = Sentence::new();
        sent.push(TokenBuilder::new("a").xpos("A").into());

        let candidates = vec![
            EncodingProb::new(encoding(ROOT_RELATION, -2), 1.0),
            EncodingProb::new(encoding(ROOT_RELATION, -1), 1.0),
        ];
        let labels = vec![candidates];

        let decoder = RelativePositionEncoder::new(ROOT_RELATION);
        decoder.decode(&labels, &mut sent).unwrap();

        let expected = DepTriple::new(0, Some(ROOT_RELATION), 1);
        assert_eq!(sent.dep_graph().head(1), Some(expected));
    }
}