syntaxdot_encoders/depseq/
relative_pos.rs1use std::collections::HashMap;
10
11use serde_derive::{Deserialize, Serialize};
12use udgraph::graph::{DepTriple, Node, Sentence};
13use udgraph::token::Token;
14use udgraph::Error;
15
16use super::{
17 attach_orphans, break_cycles, find_or_create_root, DecodeError, DependencyEncoding, EncodeError,
18};
19use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};
20
21const ROOT_POS: &str = "ROOT";
22
23#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
25#[serde(rename_all = "lowercase")]
26pub enum PosLayer {
27 UPos,
29
30 XPos,
32}
33
34impl PosLayer {
35 fn pos(self, token: &Token) -> Option<&str> {
36 match self {
37 PosLayer::UPos => token.upos(),
38 PosLayer::XPos => token.xpos(),
39 }
40 }
41}
42
43#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
50pub struct RelativePos {
51 pos: String,
52 position: isize,
53}
54
55impl RelativePos {
56 #[allow(dead_code)]
57 pub fn new(pos: impl Into<String>, position: isize) -> Self {
58 RelativePos {
59 pos: pos.into(),
60 position,
61 }
62 }
63}
64
65impl ToString for DependencyEncoding<RelativePos> {
66 fn to_string(&self) -> String {
67 format!("{}/{}/{}", self.label, self.head.pos, self.head.position)
68 }
69}
70
71#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
77pub struct RelativePosEncoder {
78 pos_layer: PosLayer,
79 root_relation: String,
80}
81
82impl RelativePosEncoder {
83 pub fn new(pos_layer: PosLayer, root_relation: impl Into<String>) -> Self {
84 RelativePosEncoder {
85 pos_layer,
86 root_relation: root_relation.into(),
87 }
88 }
89}
90
91impl RelativePosEncoder {
92 pub(crate) fn decode_idx(
93 pos_table: &HashMap<String, Vec<usize>>,
94 idx: usize,
95 encoding: &DependencyEncoding<RelativePos>,
96 ) -> Result<DepTriple<String>, DecodeError> {
97 let DependencyEncoding { label, head } = encoding;
98
99 let indices = pos_table
100 .get(head.pos.as_str())
101 .ok_or(DecodeError::InvalidPos)?;
102
103 let head_idx = Self::head_index(indices, idx, head.position)?;
104
105 Ok(DepTriple::new(head_idx, Some(label.to_owned()), idx))
106 }
107
108 fn relative_dependent_position(indices: &[usize], head: usize, dependent: usize) -> isize {
113 let mut head_position = indices
114 .binary_search(&head)
115 .expect("Head is missing in sorted POS tag list");
116
117 let dependent_position = match indices.binary_search(&dependent) {
118 Ok(idx) => idx,
119 Err(idx) => {
120 if dependent < head {
132 head_position += 1;
133 }
134 idx
135 }
136 };
137
138 head_position as isize - dependent_position as isize
139 }
140
141 fn head_index(
146 indices: &[usize],
147 dependent: usize,
148 mut relative_head_position: isize,
149 ) -> Result<usize, DecodeError> {
150 let dependent_position = match indices.binary_search(&dependent) {
151 Ok(idx) => idx,
152 Err(idx) => {
153 if relative_head_position > 0 {
165 relative_head_position -= 1
166 }
167 idx
168 }
169 };
170
171 let head_position = dependent_position as isize + relative_head_position;
172 if head_position < 0 || head_position >= indices.len() as isize {
173 return Err(DecodeError::PositionOutOfBounds);
174 }
175
176 Ok(indices[head_position as usize])
177 }
178
179 pub(crate) fn pos_position_table(&self, sentence: &Sentence) -> HashMap<String, Vec<usize>> {
180 let mut table: HashMap<String, Vec<_>> = HashMap::new();
181
182 for (idx, node) in sentence.iter().enumerate() {
183 let pos = match node {
184 Node::Root => ROOT_POS.into(),
185 Node::Token(token) => match self.pos_layer.pos(token) {
186 Some(pos) => pos.into(),
187 None => continue,
188 },
189 };
190
191 let indices = table.entry(pos).or_default();
192 indices.push(idx);
193 }
194
195 table
196 }
197}
198
199impl SentenceEncoder for RelativePosEncoder {
200 type Encoding = DependencyEncoding<RelativePos>;
201
202 type Error = EncodeError;
203
204 fn encode(&self, sentence: &Sentence) -> Result<Vec<Self::Encoding>, Self::Error> {
205 let pos_table = self.pos_position_table(sentence);
206
207 let mut encoded = Vec::with_capacity(sentence.len());
208 for idx in 0..sentence.len() {
209 if let Node::Root = &sentence[idx] {
210 continue;
211 }
212
213 let triple = sentence
214 .dep_graph()
215 .head(idx)
216 .ok_or_else(|| EncodeError::missing_head(idx, sentence))?;
217 let relation = triple
218 .relation()
219 .ok_or_else(|| EncodeError::missing_relation(idx, sentence))?;
220
221 let head_pos = match &sentence[triple.head()] {
222 Node::Root => ROOT_POS,
223 Node::Token(head_token) => self
224 .pos_layer
225 .pos(head_token)
226 .ok_or_else(|| EncodeError::missing_pos(idx, sentence))?,
227 };
228
229 let position = Self::relative_dependent_position(
230 &pos_table[head_pos],
231 triple.head(),
232 triple.dependent(),
233 );
234
235 encoded.push(DependencyEncoding {
236 label: relation.to_owned(),
237 head: RelativePos {
238 pos: head_pos.to_owned(),
239 position,
240 },
241 });
242 }
243
244 Ok(encoded)
245 }
246}
247
248impl SentenceDecoder for RelativePosEncoder {
249 type Encoding = DependencyEncoding<RelativePos>;
250
251 type Error = Error;
252
253 fn decode<S>(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error>
254 where
255 S: AsRef<[EncodingProb<Self::Encoding>]>,
256 {
257 let pos_table = self.pos_position_table(sentence);
258
259 #[allow(clippy::needless_collect)]
261 let token_indices: Vec<_> = (0..sentence.len())
262 .filter(|&idx| sentence[idx].is_token())
263 .collect();
264
265 for (idx, encodings) in token_indices.into_iter().zip(labels) {
266 for encoding in encodings.as_ref() {
267 if let Ok(triple) =
268 RelativePosEncoder::decode_idx(&pos_table, idx, encoding.encoding())
269 {
270 sentence.dep_graph_mut().add_deprel(triple)?;
271 break;
272 }
273 }
274 }
275
276 let root_idx = find_or_create_root(
278 labels,
279 sentence,
280 |idx, encoding| Self::decode_idx(&pos_table, idx, encoding).ok(),
281 &self.root_relation,
282 )?;
283 attach_orphans(labels, sentence, root_idx)?;
284 break_cycles(sentence, root_idx)?;
285
286 Ok(())
287 }
288}
289
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::iter::FromIterator;

    use udgraph::graph::{DepTriple, Sentence};
    use udgraph::token::TokenBuilder;

    use super::{PosLayer, RelativePos, RelativePosEncoder, ROOT_POS};
    use crate::depseq::{DecodeError, DependencyEncoding};
    use crate::{EncodingProb, SentenceDecoder};

    const ROOT_RELATION: &str = "root";

    /// Decoding fails when no node carries the encoded head tag.
    #[test]
    fn invalid_pos() {
        assert_eq!(
            RelativePosEncoder::decode_idx(
                &HashMap::new(),
                0,
                &DependencyEncoding {
                    label: "X".into(),
                    head: RelativePos {
                        pos: "C".into(),
                        position: -1,
                    },
                },
            ),
            Err(DecodeError::InvalidPos)
        )
    }

    /// Decoding fails when the relative position points before the first
    /// node carrying the tag.
    #[test]
    fn position_out_of_bounds() {
        assert_eq!(
            RelativePosEncoder::decode_idx(
                &HashMap::from_iter(vec![("A".to_string(), vec![0])]),
                1,
                &DependencyEncoding {
                    label: "X".into(),
                    head: RelativePos {
                        pos: "A".into(),
                        position: -2,
                    },
                },
            ),
            Err(DecodeError::PositionOutOfBounds)
        )
    }

    /// The decoder falls back to the next candidate when the best candidate
    /// cannot be decoded.
    #[test]
    fn backoff() {
        let mut sent = Sentence::new();
        sent.push(TokenBuilder::new("a").xpos("A").into());

        let decoder = RelativePosEncoder::new(PosLayer::XPos, ROOT_RELATION);
        let labels = vec![vec![
            EncodingProb::new(
                DependencyEncoding {
                    label: ROOT_RELATION.into(),
                    head: RelativePos {
                        pos: ROOT_POS.into(),
                        position: -2,
                    },
                },
                1.0,
            ),
            EncodingProb::new(
                DependencyEncoding {
                    label: ROOT_RELATION.into(),
                    head: RelativePos {
                        pos: ROOT_POS.into(),
                        position: -1,
                    },
                },
                1.0,
            ),
        ]];

        decoder.decode(&labels, &mut sent).unwrap();

        assert_eq!(
            sent.dep_graph().head(1),
            Some(DepTriple::new(0, Some(ROOT_RELATION), 1))
        );
    }

    /// Relative positions count tagged nodes between dependent and head, for
    /// both tagged and untagged dependents.
    #[test]
    fn relative_dependent_position_counts_tagged_nodes() {
        let indices = [2, 5];

        // Head to the right of the dependent.
        assert_eq!(RelativePosEncoder::relative_dependent_position(&indices, 5, 3), 1);
        assert_eq!(RelativePosEncoder::relative_dependent_position(&indices, 5, 0), 2);
        assert_eq!(RelativePosEncoder::relative_dependent_position(&indices, 5, 2), 1);

        // Head to the left of the dependent.
        assert_eq!(RelativePosEncoder::relative_dependent_position(&indices, 2, 4), -1);
        assert_eq!(RelativePosEncoder::relative_dependent_position(&indices, 2, 6), -2);
        assert_eq!(RelativePosEncoder::relative_dependent_position(&indices, 2, 5), -1);
    }

    /// Decoding an encoded position recovers the original head for every
    /// dependent, whether or not the dependent carries the head's tag.
    #[test]
    fn relative_position_round_trip() {
        let indices = [1, 3, 4, 8];

        for &head in &indices {
            for dependent in 0..10 {
                if dependent == head {
                    continue;
                }

                let position =
                    RelativePosEncoder::relative_dependent_position(&indices, head, dependent);
                assert_ne!(position, 0);
                assert_eq!(
                    RelativePosEncoder::head_index(&indices, dependent, position),
                    Ok(head)
                );
            }
        }
    }
}