syntaxdot_encoders/depseq/
relative_pos.rs

// Implementation note:
//
// We currently do a binary search to find the position of the token
// in the POS table. This makes encoding/decoding of a sentence
// *O(n log(n))*. We could do this in *O(n)* by keeping track of
// the position of the token in the head's POS table while constructing
// the POS table. This currently does not really seem worth it?
8
9use std::collections::HashMap;
10
11use serde_derive::{Deserialize, Serialize};
12use udgraph::graph::{DepTriple, Node, Sentence};
13use udgraph::token::Token;
14use udgraph::Error;
15
16use super::{
17    attach_orphans, break_cycles, find_or_create_root, DecodeError, DependencyEncoding, EncodeError,
18};
19use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};
20
21const ROOT_POS: &str = "ROOT";
22
23/// Part-of-speech layer.
24#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
25#[serde(rename_all = "lowercase")]
26pub enum PosLayer {
27    /// Universal part-of-speech tag.
28    UPos,
29
30    /// Language-specific part-of-speech tag.
31    XPos,
32}
33
34impl PosLayer {
35    fn pos(self, token: &Token) -> Option<&str> {
36        match self {
37            PosLayer::UPos => token.upos(),
38            PosLayer::XPos => token.xpos(),
39        }
40    }
41}
42
43/// Relative head position by part-of-speech.
44///
45/// The position of the head relative to the dependent token,
46/// in terms of part-of-speech tags. For example, a position of
47/// *-2* with the pos *noun* means that the head is the second
48/// preceding noun.
49#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
50pub struct RelativePos {
51    pos: String,
52    position: isize,
53}
54
55impl RelativePos {
56    #[allow(dead_code)]
57    pub fn new(pos: impl Into<String>, position: isize) -> Self {
58        RelativePos {
59            pos: pos.into(),
60            position,
61        }
62    }
63}
64
65impl ToString for DependencyEncoding<RelativePos> {
66    fn to_string(&self) -> String {
67        format!("{}/{}/{}", self.label, self.head.pos, self.head.position)
68    }
69}
70
71/// Relative part-of-speech position encoder.
72///
73/// This encoder encodes dependency relations as token labels. The
74/// dependency relation is encoded as-is. The position of the head
75/// is encoded relative to the (dependent) token by part-of-speech.
76#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
77pub struct RelativePosEncoder {
78    pos_layer: PosLayer,
79    root_relation: String,
80}
81
82impl RelativePosEncoder {
83    pub fn new(pos_layer: PosLayer, root_relation: impl Into<String>) -> Self {
84        RelativePosEncoder {
85            pos_layer,
86            root_relation: root_relation.into(),
87        }
88    }
89}
90
91impl RelativePosEncoder {
92    pub(crate) fn decode_idx(
93        pos_table: &HashMap<String, Vec<usize>>,
94        idx: usize,
95        encoding: &DependencyEncoding<RelativePos>,
96    ) -> Result<DepTriple<String>, DecodeError> {
97        let DependencyEncoding { label, head } = encoding;
98
99        let indices = pos_table
100            .get(head.pos.as_str())
101            .ok_or(DecodeError::InvalidPos)?;
102
103        let head_idx = Self::head_index(indices, idx, head.position)?;
104
105        Ok(DepTriple::new(head_idx, Some(label.to_owned()), idx))
106    }
107
108    /// Find the relative position of a dependent to a head.
109    ///
110    /// This methods finds the relative position of `dependent` to
111    /// `head` in `indices`.
112    fn relative_dependent_position(indices: &[usize], head: usize, dependent: usize) -> isize {
113        let mut head_position = indices
114            .binary_search(&head)
115            .expect("Head is missing in sorted POS tag list");
116
117        let dependent_position = match indices.binary_search(&dependent) {
118            Ok(idx) => idx,
119            Err(idx) => {
120                // The head moves one place if the dependent is inserted
121                // before the head. Consider e.g. the indices
122                //
123                // [3, 6, 9]
124                //     ^--- insertion point of 4.
125                //
126                // Suppose that we want to compute the relative
127                // position of 4 to its head 9 (position 2). The
128                // insertion point is 1. When computing the relative
129                // position, we should take into account that 4 lies
130                // before 6.
131                if dependent < head {
132                    head_position += 1;
133                }
134                idx
135            }
136        };
137
138        head_position as isize - dependent_position as isize
139    }
140
141    /// Get the index of the head of `dependent`.
142    ///
143    /// Get index of the head of `dependent`, given the relative
144    /// position of `dependent` to the head in `indices`.
145    fn head_index(
146        indices: &[usize],
147        dependent: usize,
148        mut relative_head_position: isize,
149    ) -> Result<usize, DecodeError> {
150        let dependent_position = match indices.binary_search(&dependent) {
151            Ok(idx) => idx,
152            Err(idx) => {
153                // Consider e.g. the indices
154                //
155                // [3, 6, 9]
156                //     ^--- insertion point of 4.
157                //
158                // Suppose that 4 is the dependent and +2 the relative
159                // position of the head. The relative position takes
160                // both succeeding elements (6, 9) into
161                // account. However, the insertion point is the
162                // element at +1. So, compensate for this in the
163                // relative position.
164                if relative_head_position > 0 {
165                    relative_head_position -= 1
166                }
167                idx
168            }
169        };
170
171        let head_position = dependent_position as isize + relative_head_position;
172        if head_position < 0 || head_position >= indices.len() as isize {
173            return Err(DecodeError::PositionOutOfBounds);
174        }
175
176        Ok(indices[head_position as usize])
177    }
178
179    pub(crate) fn pos_position_table(&self, sentence: &Sentence) -> HashMap<String, Vec<usize>> {
180        let mut table: HashMap<String, Vec<_>> = HashMap::new();
181
182        for (idx, node) in sentence.iter().enumerate() {
183            let pos = match node {
184                Node::Root => ROOT_POS.into(),
185                Node::Token(token) => match self.pos_layer.pos(token) {
186                    Some(pos) => pos.into(),
187                    None => continue,
188                },
189            };
190
191            let indices = table.entry(pos).or_default();
192            indices.push(idx);
193        }
194
195        table
196    }
197}
198
199impl SentenceEncoder for RelativePosEncoder {
200    type Encoding = DependencyEncoding<RelativePos>;
201
202    type Error = EncodeError;
203
204    fn encode(&self, sentence: &Sentence) -> Result<Vec<Self::Encoding>, Self::Error> {
205        let pos_table = self.pos_position_table(sentence);
206
207        let mut encoded = Vec::with_capacity(sentence.len());
208        for idx in 0..sentence.len() {
209            if let Node::Root = &sentence[idx] {
210                continue;
211            }
212
213            let triple = sentence
214                .dep_graph()
215                .head(idx)
216                .ok_or_else(|| EncodeError::missing_head(idx, sentence))?;
217            let relation = triple
218                .relation()
219                .ok_or_else(|| EncodeError::missing_relation(idx, sentence))?;
220
221            let head_pos = match &sentence[triple.head()] {
222                Node::Root => ROOT_POS,
223                Node::Token(head_token) => self
224                    .pos_layer
225                    .pos(head_token)
226                    .ok_or_else(|| EncodeError::missing_pos(idx, sentence))?,
227            };
228
229            let position = Self::relative_dependent_position(
230                &pos_table[head_pos],
231                triple.head(),
232                triple.dependent(),
233            );
234
235            encoded.push(DependencyEncoding {
236                label: relation.to_owned(),
237                head: RelativePos {
238                    pos: head_pos.to_owned(),
239                    position,
240                },
241            });
242        }
243
244        Ok(encoded)
245    }
246}
247
248impl SentenceDecoder for RelativePosEncoder {
249    type Encoding = DependencyEncoding<RelativePos>;
250
251    type Error = Error;
252
253    fn decode<S>(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error>
254    where
255        S: AsRef<[EncodingProb<Self::Encoding>]>,
256    {
257        let pos_table = self.pos_position_table(sentence);
258
259        // Collect to avoid immutable + mutable reference.
260        #[allow(clippy::needless_collect)]
261        let token_indices: Vec<_> = (0..sentence.len())
262            .filter(|&idx| sentence[idx].is_token())
263            .collect();
264
265        for (idx, encodings) in token_indices.into_iter().zip(labels) {
266            for encoding in encodings.as_ref() {
267                if let Ok(triple) =
268                    RelativePosEncoder::decode_idx(&pos_table, idx, encoding.encoding())
269                {
270                    sentence.dep_graph_mut().add_deprel(triple)?;
271                    break;
272                }
273            }
274        }
275
276        // Fixup tree.
277        let root_idx = find_or_create_root(
278            labels,
279            sentence,
280            |idx, encoding| Self::decode_idx(&pos_table, idx, encoding).ok(),
281            &self.root_relation,
282        )?;
283        attach_orphans(labels, sentence, root_idx)?;
284        break_cycles(sentence, root_idx)?;
285
286        Ok(())
287    }
288}
289
290#[cfg(test)]
291mod tests {
292    use std::collections::HashMap;
293    use std::iter::FromIterator;
294
295    use udgraph::graph::{DepTriple, Sentence};
296    use udgraph::token::TokenBuilder;
297
298    use super::{PosLayer, RelativePos, RelativePosEncoder, ROOT_POS};
299    use crate::depseq::{DecodeError, DependencyEncoding};
300    use crate::{EncodingProb, SentenceDecoder};
301
302    const ROOT_RELATION: &str = "root";
303
304    // Small tests for relative part-of-speech encoder. Automatic
305    // testing is performed in the module tests.
306
307    #[test]
308    fn invalid_pos() {
309        assert_eq!(
310            RelativePosEncoder::decode_idx(
311                &HashMap::new(),
312                0,
313                &DependencyEncoding {
314                    label: "X".into(),
315                    head: RelativePos {
316                        pos: "C".into(),
317                        position: -1,
318                    },
319                },
320            ),
321            Err(DecodeError::InvalidPos)
322        )
323    }
324
325    #[test]
326    fn position_out_of_bounds() {
327        assert_eq!(
328            RelativePosEncoder::decode_idx(
329                &HashMap::from_iter(vec![("A".to_string(), vec![0])]),
330                1,
331                &DependencyEncoding {
332                    label: "X".into(),
333                    head: RelativePos {
334                        pos: "A".into(),
335                        position: -2,
336                    },
337                },
338            ),
339            Err(DecodeError::PositionOutOfBounds)
340        )
341    }
342
343    #[test]
344    fn backoff() {
345        let mut sent = Sentence::new();
346        sent.push(TokenBuilder::new("a").xpos("A").into());
347
348        let decoder = RelativePosEncoder::new(PosLayer::XPos, ROOT_RELATION);
349        let labels = vec![vec![
350            EncodingProb::new(
351                DependencyEncoding {
352                    label: ROOT_RELATION.into(),
353                    head: RelativePos {
354                        pos: ROOT_POS.into(),
355                        position: -2,
356                    },
357                },
358                1.0,
359            ),
360            EncodingProb::new(
361                DependencyEncoding {
362                    label: ROOT_RELATION.into(),
363                    head: RelativePos {
364                        pos: ROOT_POS.into(),
365                        position: -1,
366                    },
367                },
368                1.0,
369            ),
370        ]];
371
372        decoder.decode(&labels, &mut sent).unwrap();
373
374        assert_eq!(
375            sent.dep_graph().head(1),
376            Some(DepTriple::new(0, Some(ROOT_RELATION), 1))
377        );
378    }
379}