syntaxdot_encoders/lemma/
encoder.rs1use std::convert::Infallible;
2
3use serde::{Deserialize, Serialize};
4use udgraph::graph::{Node, Sentence};
5
6use crate::lemma::edit_tree::EditTree;
7use crate::lemma::EncodeError;
8use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};
9
10#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
15#[serde(rename_all = "lowercase")]
16pub enum BackoffStrategy {
17 Nothing,
18 Form,
19}
20
21#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
26pub struct EditTreeEncoder {
27 backoff_strategy: BackoffStrategy,
28}
29
30impl EditTreeEncoder {
31 pub fn new(backoff_strategy: BackoffStrategy) -> Self {
32 EditTreeEncoder { backoff_strategy }
33 }
34}
35
36impl SentenceDecoder for EditTreeEncoder {
37 type Encoding = EditTree;
38
39 type Error = Infallible;
40
41 fn decode<S>(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error>
42 where
43 S: AsRef<[EncodingProb<Self::Encoding>]>,
44 {
45 assert_eq!(
46 labels.len(),
47 sentence.len() - 1,
48 "Labels and sentence length mismatch"
49 );
50
51 for (token, token_labels) in sentence
52 .iter_mut()
53 .filter_map(Node::token_mut)
54 .zip(labels.iter())
55 {
56 if let Some(label) = token_labels.as_ref().get(0) {
57 let form = token.form().chars().collect::<Vec<_>>();
58
59 if let Some(lemma) = label.encoding().apply(&form) {
60 let lemma = lemma.into_iter().collect::<String>();
63 token.set_lemma(Some(lemma));
64 } else if let BackoffStrategy::Form = self.backoff_strategy {
65 token.set_lemma(Some(token.form().to_owned()));
69 }
70 }
71 }
72
73 Ok(())
74 }
75}
76
77impl SentenceEncoder for EditTreeEncoder {
78 type Encoding = EditTree;
79
80 type Error = EncodeError;
81
82 fn encode(&self, sentence: &Sentence) -> Result<Vec<Self::Encoding>, Self::Error> {
83 let mut encoding = Vec::with_capacity(sentence.len() - 1);
84
85 for token in sentence.iter().filter_map(Node::token) {
86 let lemma = token
87 .lemma()
88 .or_else(|| {
89 if token.form() == "_" {
90 Some("_").to_owned()
91 } else {
92 None
93 }
94 })
95 .ok_or_else(|| EncodeError::MissingLemma {
96 form: token.form().to_owned(),
97 })?;
98
99 let edit_tree = EditTree::create_tree(
100 &token.form().chars().collect::<Vec<_>>(),
101 &lemma.chars().collect::<Vec<_>>(),
102 )
103 .ok_or_else(|| EncodeError::NoEditTree {
104 form: token.form().to_string(),
105 lemma: lemma.to_string(),
106 })?;
107
108 encoding.push(edit_tree);
109 }
110
111 Ok(encoding)
112 }
113}
114
#[cfg(test)]
mod tests {
    use std::iter;

    use udgraph::graph::{Node, Sentence};
    use udgraph::token::{Token, TokenBuilder};

    use super::{BackoffStrategy, EditTree, EditTreeEncoder};
    use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};

    /// Encode `sent` and wrap every edit tree as the single candidate
    /// (probability 1.0) for its token.
    fn encode_and_wrap(
        encoder: &EditTreeEncoder,
        sent: &Sentence,
    ) -> Vec<Vec<EncodingProb<EditTree>>> {
        encoder
            .encode(sent)
            .unwrap()
            .into_iter()
            .map(|encoding| vec![EncodingProb::new(encoding, 1.0)])
            .collect::<Vec<_>>()
    }

    /// Build a sentence of lemma-less tokens from their forms.
    fn sentence_from_forms(tokens: &[&str]) -> Sentence {
        tokens.iter().map(|t| Token::new(*t)).collect()
    }

    /// Build a sentence from `(form, lemma)` pairs.
    fn sentence_from_pairs(token_lemmas: &[(&str, &str)]) -> Sentence {
        token_lemmas
            .iter()
            .map(|(t, l)| TokenBuilder::new(*t).lemma(*l).into())
            .collect()
    }

    #[test]
    fn encoder_decoder_roundtrip() {
        let sent_encode =
            sentence_from_pairs(&[("hij", "hij"), ("heeft", "hebben"), ("gefietst", "fietsen")]);

        let encoder = EditTreeEncoder::new(BackoffStrategy::Nothing);
        let labels = encode_and_wrap(&encoder, &sent_encode);

        let mut sent_decode = sentence_from_forms(&["hij", "heeft", "gefietst"]);
        encoder.decode(&labels, &mut sent_decode).unwrap();

        assert_eq!(sent_encode, sent_decode);
    }

    #[test]
    fn decoder_backoff_nothing() {
        let sent_encode = sentence_from_pairs(&[
            ("kinderen", "kind"),
            ("hadden", "hebben"),
            ("gefietst", "fietsen"),
        ]);
        let encoder = EditTreeEncoder::new(BackoffStrategy::Nothing);
        let labels = encode_and_wrap(&encoder, &sent_encode);

        let mut sent_decode = sentence_from_forms(&["het", "is", "anders"]);
        encoder.decode(&labels, &mut sent_decode).unwrap();

        assert!(sent_decode
            .iter()
            .filter_map(Node::token)
            .map(Token::lemma)
            .all(|lemma| lemma.is_none()));
    }

    #[test]
    fn decoder_backoff_form() {
        let sent_encode = sentence_from_pairs(&[
            ("kinderen", "kind"),
            ("hadden", "hebben"),
            ("gefietst", "fietsen"),
        ]);
        let encoder = EditTreeEncoder::new(BackoffStrategy::Form);
        let labels = encode_and_wrap(&encoder, &sent_encode);

        let mut sent_decode = sentence_from_forms(&["het", "is", "anders"]);
        encoder.decode(&labels, &mut sent_decode).unwrap();

        for token in sent_decode.iter().filter_map(Node::token) {
            assert_eq!(token.lemma(), Some(token.form()));
        }
    }

    #[test]
    fn decoder_skips_tokens_without_candidates() {
        // Back-off only applies to edit trees that fail; an empty candidate
        // list leaves the token untouched even with `Form` back-off.
        let encoder = EditTreeEncoder::new(BackoffStrategy::Form);
        let labels: Vec<Vec<EncodingProb<EditTree>>> = vec![Vec::new()];

        let mut sent_decode = sentence_from_forms(&["hij"]);
        encoder.decode(&labels, &mut sent_decode).unwrap();

        assert!(sent_decode
            .iter()
            .filter_map(Node::token)
            .all(|token| token.lemma().is_none()));
    }

    #[test]
    #[should_panic(expected = "Labels and sentence length mismatch")]
    fn decoder_rejects_label_count_mismatch() {
        let sent_encode = sentence_from_pairs(&[("hij", "hij")]);
        let encoder = EditTreeEncoder::new(BackoffStrategy::Nothing);
        let labels = encode_and_wrap(&encoder, &sent_encode);

        // Two tokens, but labels for only one.
        let mut sent_decode = sentence_from_forms(&["hij", "heeft"]);
        encoder.decode(&labels, &mut sent_decode).unwrap();
    }

    #[test]
    fn handles_underscore_form_lemma() {
        let sentence: Sentence = iter::once(TokenBuilder::new("_").into()).collect();
        let encoder = EditTreeEncoder::new(BackoffStrategy::Form);
        let labels = encode_and_wrap(&encoder, &sentence);

        let mut sent_decode = sentence_from_forms(&["_"]);
        encoder.decode(&labels, &mut sent_decode).unwrap();

        assert_eq!(sent_decode, sentence_from_pairs(&[("_", "_")]));
    }

    #[test]
    fn rejects_empty_lemma_with_nonempty_form() {
        let sentence = sentence_from_forms(&["iets"]);
        let encoder = EditTreeEncoder::new(BackoffStrategy::Nothing);
        assert!(encoder.encode(&sentence).is_err());
    }
}