syntaxdot_encoders/depseq/
relative_position.rs

1use serde_derive::{Deserialize, Serialize};
2use udgraph::graph::{DepTriple, Sentence};
3use udgraph::Error;
4
5use super::{
6    attach_orphans, break_cycles, find_or_create_root, DecodeError, DependencyEncoding, EncodeError,
7};
8use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};
9
10/// Relative head position.
11///
12/// The position of the head relative to the dependent token.
13#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
14pub struct RelativePosition(isize);
15
16impl ToString for DependencyEncoding<RelativePosition> {
17    fn to_string(&self) -> String {
18        format!("{}/{}", self.label, self.head.0)
19    }
20}
21
22/// Relative position encoder.
23///
24/// This encoder encodes dependency relations as token labels. The
25/// dependency relation is encoded as-is. The position of the head
26/// is encoded relative to the (dependent) token.
27#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
28pub struct RelativePositionEncoder {
29    root_relation: String,
30}
31
32impl RelativePositionEncoder {
33    pub fn new(root_relation: impl Into<String>) -> Self {
34        RelativePositionEncoder {
35            root_relation: root_relation.into(),
36        }
37    }
38}
39
40impl RelativePositionEncoder {
41    fn decode_idx(
42        idx: usize,
43        sentence_len: usize,
44        encoding: &DependencyEncoding<RelativePosition>,
45    ) -> Result<DepTriple<String>, DecodeError> {
46        let DependencyEncoding {
47            label,
48            head: RelativePosition(head),
49        } = encoding;
50
51        let head_idx = idx as isize + head;
52        if head_idx < 0 || head_idx >= sentence_len as isize {
53            return Err(DecodeError::PositionOutOfBounds);
54        }
55
56        Ok(DepTriple::new(
57            (idx as isize + head) as usize,
58            Some(label.clone()),
59            idx,
60        ))
61    }
62}
63
64impl SentenceEncoder for RelativePositionEncoder {
65    type Encoding = DependencyEncoding<RelativePosition>;
66
67    type Error = EncodeError;
68
69    fn encode(&self, sentence: &Sentence) -> Result<Vec<Self::Encoding>, Self::Error> {
70        let mut encoded = Vec::with_capacity(sentence.len());
71        for idx in 1..sentence.len() {
72            let triple = sentence
73                .dep_graph()
74                .head(idx)
75                .ok_or_else(|| EncodeError::missing_head(idx, sentence))?;
76            let relation = triple
77                .relation()
78                .ok_or_else(|| EncodeError::missing_relation(idx, sentence))?;
79
80            encoded.push(DependencyEncoding {
81                label: relation.to_owned(),
82                head: RelativePosition(triple.head() as isize - triple.dependent() as isize),
83            });
84        }
85
86        Ok(encoded)
87    }
88}
89
90impl SentenceDecoder for RelativePositionEncoder {
91    type Encoding = DependencyEncoding<RelativePosition>;
92
93    type Error = Error;
94
95    fn decode<S>(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error>
96    where
97        S: AsRef<[EncodingProb<Self::Encoding>]>,
98    {
99        // Collect to avoid immutable + mutable reference.
100        #[allow(clippy::needless_collect)]
101        let token_indices: Vec<_> = (0..sentence.len())
102            .filter(|&idx| sentence[idx].is_token())
103            .collect();
104
105        for (idx, encodings) in token_indices.into_iter().zip(labels) {
106            for encoding in encodings.as_ref() {
107                if let Ok(triple) =
108                    RelativePositionEncoder::decode_idx(idx, sentence.len(), encoding.encoding())
109                {
110                    sentence.dep_graph_mut().add_deprel(triple)?;
111                    break;
112                }
113            }
114        }
115
116        // Fixup tree.
117        let sentence_len = sentence.len();
118        let root_idx = find_or_create_root(
119            labels,
120            sentence,
121            |idx, encoding| Self::decode_idx(idx, sentence_len, encoding).ok(),
122            &self.root_relation,
123        )?;
124        attach_orphans(labels, sentence, root_idx)?;
125        break_cycles(sentence, root_idx)?;
126
127        Ok(())
128    }
129}
130
#[cfg(test)]
mod tests {
    use udgraph::graph::{DepTriple, Sentence};
    use udgraph::token::TokenBuilder;

    use super::{RelativePosition, RelativePositionEncoder};
    use crate::depseq::{DecodeError, DependencyEncoding};
    use crate::{EncodingProb, SentenceDecoder};

    const ROOT_RELATION: &str = "root";

    // Only a few small, targeted tests for the relative position encoder
    // live here. Automatic testing is performed in the module tests.

    /// Build an encoding with the given relation label and head offset.
    fn encoding(label: &str, offset: isize) -> DependencyEncoding<RelativePosition> {
        DependencyEncoding {
            label: label.into(),
            head: RelativePosition(offset),
        }
    }

    /// An offset of -2 from position 1 points at position -1, which must
    /// be rejected.
    #[test]
    fn position_out_of_bounds() {
        let mut sentence = Sentence::new();
        sentence.push(TokenBuilder::new("a").xpos("A").into());
        sentence.push(TokenBuilder::new("b").xpos("B").into());

        let decoded = RelativePositionEncoder::decode_idx(1, sentence.len(), &encoding("X", -2));

        assert_eq!(decoded, Err(DecodeError::PositionOutOfBounds));
    }

    /// The first candidate (offset -2) is out of bounds for position 1,
    /// so decoding has to fall back to the second candidate (offset -1),
    /// which attaches position 1 to position 0.
    #[test]
    fn backoff() {
        let mut sentence = Sentence::new();
        sentence.push(TokenBuilder::new("a").xpos("A").into());

        let candidates = vec![
            EncodingProb::new(encoding(ROOT_RELATION, -2), 1.0),
            EncodingProb::new(encoding(ROOT_RELATION, -1), 1.0),
        ];
        let labels = vec![candidates];

        let decoder = RelativePositionEncoder::new(ROOT_RELATION);
        decoder.decode(&labels, &mut sentence).unwrap();

        let expected = Some(DepTriple::new(0, Some(ROOT_RELATION), 1));
        assert_eq!(sentence.dep_graph().head(1), expected);
    }
}