syntaxdot_encoders/lemma/
encoder.rs

1use std::convert::Infallible;
2
3use serde::{Deserialize, Serialize};
4use udgraph::graph::{Node, Sentence};
5
6use crate::lemma::edit_tree::EditTree;
7use crate::lemma::EncodeError;
8use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};
9
10/// Back-off strategy.
11///
12/// This is the strategy that will be used when an edit tree
13/// could not be applied.
14#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
15#[serde(rename_all = "lowercase")]
16pub enum BackoffStrategy {
17    Nothing,
18    Form,
19}
20
21/// Edit tree-based lemma encoder.
22///
23/// This encoder encodes a lemma as an edit tree that is applied to an
24/// unlemmatized form.
25#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
26pub struct EditTreeEncoder {
27    backoff_strategy: BackoffStrategy,
28}
29
30impl EditTreeEncoder {
31    pub fn new(backoff_strategy: BackoffStrategy) -> Self {
32        EditTreeEncoder { backoff_strategy }
33    }
34}
35
36impl SentenceDecoder for EditTreeEncoder {
37    type Encoding = EditTree;
38
39    type Error = Infallible;
40
41    fn decode<S>(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error>
42    where
43        S: AsRef<[EncodingProb<Self::Encoding>]>,
44    {
45        assert_eq!(
46            labels.len(),
47            sentence.len() - 1,
48            "Labels and sentence length mismatch"
49        );
50
51        for (token, token_labels) in sentence
52            .iter_mut()
53            .filter_map(Node::token_mut)
54            .zip(labels.iter())
55        {
56            if let Some(label) = token_labels.as_ref().get(0) {
57                let form = token.form().chars().collect::<Vec<_>>();
58
59                if let Some(lemma) = label.encoding().apply(&form) {
60                    // If the edit script can be applied, use the
61                    // resulting lemma...
62                    let lemma = lemma.into_iter().collect::<String>();
63                    token.set_lemma(Some(lemma));
64                } else if let BackoffStrategy::Form = self.backoff_strategy {
65                    // .. if the edit script failed and the back-off
66                    // strategy is to set the form as the lemma,
67                    // do so.
68                    token.set_lemma(Some(token.form().to_owned()));
69                }
70            }
71        }
72
73        Ok(())
74    }
75}
76
77impl SentenceEncoder for EditTreeEncoder {
78    type Encoding = EditTree;
79
80    type Error = EncodeError;
81
82    fn encode(&self, sentence: &Sentence) -> Result<Vec<Self::Encoding>, Self::Error> {
83        let mut encoding = Vec::with_capacity(sentence.len() - 1);
84
85        for token in sentence.iter().filter_map(Node::token) {
86            let lemma = token
87                .lemma()
88                .or_else(|| {
89                    if token.form() == "_" {
90                        Some("_").to_owned()
91                    } else {
92                        None
93                    }
94                })
95                .ok_or_else(|| EncodeError::MissingLemma {
96                    form: token.form().to_owned(),
97                })?;
98
99            let edit_tree = EditTree::create_tree(
100                &token.form().chars().collect::<Vec<_>>(),
101                &lemma.chars().collect::<Vec<_>>(),
102            )
103            .ok_or_else(|| EncodeError::NoEditTree {
104                form: token.form().to_string(),
105                lemma: lemma.to_string(),
106            })?;
107
108            encoding.push(edit_tree);
109        }
110
111        Ok(encoding)
112    }
113}
114
#[cfg(test)]
mod tests {
    use std::iter;

    use udgraph::graph::{Node, Sentence};
    use udgraph::token::{Token, TokenBuilder};

    use super::{BackoffStrategy, EditTree, EditTreeEncoder, EncodeError};
    use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};

    /// Encode `sent` and wrap each edit tree as the sole candidate with
    /// probability 1.0, producing the label shape that `decode` accepts.
    fn encode_and_wrap(
        encoder: &EditTreeEncoder,
        sent: &Sentence,
    ) -> Vec<Vec<EncodingProb<EditTree>>> {
        encoder
            .encode(sent)
            .unwrap()
            .into_iter()
            .map(|encoding| vec![EncodingProb::new(encoding, 1.0)])
            .collect::<Vec<_>>()
    }

    /// Build a sentence of tokens that have forms but no lemmas.
    fn sentence_from_forms(tokens: &[&str]) -> Sentence {
        tokens.iter().map(|t| Token::new(*t)).collect()
    }

    /// Build a sentence from `(form, lemma)` pairs.
    fn sentence_from_pairs(token_lemmas: &[(&str, &str)]) -> Sentence {
        token_lemmas
            .iter()
            .map(|(t, l)| TokenBuilder::new(*t).lemma(*l).into())
            .collect()
    }

    #[test]
    fn encoder_decoder_roundtrip() {
        let sent_encode =
            sentence_from_pairs(&[("hij", "hij"), ("heeft", "hebben"), ("gefietst", "fietsen")]);

        let encoder = EditTreeEncoder::new(BackoffStrategy::Nothing);
        let labels = encode_and_wrap(&encoder, &sent_encode);

        let mut sent_decode = sentence_from_forms(&["hij", "heeft", "gefietst"]);
        encoder.decode(&labels, &mut sent_decode).unwrap();

        assert_eq!(sent_encode, sent_decode);
    }

    #[test]
    fn decoder_backoff_nothing() {
        let sent_encode = sentence_from_pairs(&[
            ("kinderen", "kind"),
            ("hadden", "hebben"),
            ("gefietst", "fietsen"),
        ]);
        let encoder = EditTreeEncoder::new(BackoffStrategy::Nothing);
        let labels = encode_and_wrap(&encoder, &sent_encode);

        // The edit trees above do not apply to these forms.
        let mut sent_decode = sentence_from_forms(&["het", "is", "anders"]);
        encoder.decode(&labels, &mut sent_decode).unwrap();

        assert!(sent_decode
            .iter()
            .filter_map(Node::token)
            .map(Token::lemma)
            .all(|lemma| lemma.is_none()));
    }

    #[test]
    fn decoder_backoff_form() {
        let sent_encode = sentence_from_pairs(&[
            ("kinderen", "kind"),
            ("hadden", "hebben"),
            ("gefietst", "fietsen"),
        ]);
        let encoder = EditTreeEncoder::new(BackoffStrategy::Form);
        let labels = encode_and_wrap(&encoder, &sent_encode);

        // The edit trees above do not apply to these forms.
        let mut sent_decode = sentence_from_forms(&["het", "is", "anders"]);
        encoder.decode(&labels, &mut sent_decode).unwrap();

        for token in sent_decode.iter().filter_map(Node::token) {
            assert_eq!(token.lemma(), Some(token.form()));
        }
    }

    #[test]
    fn handles_underscore_form_lemma() {
        let sentence: Sentence = iter::once(TokenBuilder::new("_").into()).collect();
        let encoder = EditTreeEncoder::new(BackoffStrategy::Form);
        let labels = encode_and_wrap(&encoder, &sentence);

        let mut sent_decode = sentence_from_forms(&["_"]);
        encoder.decode(&labels, &mut sent_decode).unwrap();

        assert_eq!(sent_decode, sentence_from_pairs(&[("_", "_")]));
    }

    #[test]
    fn rejects_empty_lemma_with_nonempty_form() {
        let sentence = sentence_from_forms(&["iets"]);
        let encoder = EditTreeEncoder::new(BackoffStrategy::Nothing);

        // Pin the error kind and payload, not merely that encoding failed.
        match encoder.encode(&sentence) {
            Err(EncodeError::MissingLemma { form }) => assert_eq!(form, "iets"),
            Err(_) => panic!("expected EncodeError::MissingLemma"),
            Ok(_) => panic!("encoding a token without a lemma must fail"),
        }
    }
}