syntaxdot_encoders/layer/
mod.rs1use std::convert::{Infallible, TryFrom};
4
5use serde_derive::{Deserialize, Serialize};
6use udgraph::graph::{Node, Sentence};
7use udgraph::token::Token;
8
9use super::{EncodingProb, SentenceDecoder, SentenceEncoder};
10
11mod error;
12use self::error::*;
13use conllu::display::ConlluFeatures;
14
15#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
17#[serde(rename_all = "lowercase")]
18pub enum Layer {
19 UPos,
20 XPos,
21
22 Feature {
24 feature: String,
25
26 default: Option<String>,
28 },
29
30 #[serde(rename = "feature_string")]
32 FeatureString,
33
34 Misc {
35 feature: String,
36
37 default: Option<String>,
39 },
40}
41
42impl Layer {
43 pub fn feature(feature: String, default: Option<String>) -> Self {
45 Layer::Feature { feature, default }
46 }
47
48 pub fn misc(feature: String, default: Option<String>) -> Self {
50 Layer::Misc { feature, default }
51 }
52}
53
54pub trait LayerValue {
56 fn set_value(&mut self, layer: &Layer, value: impl Into<String>);
58
59 fn value(&self, layer: &Layer) -> Option<String>;
61}
62
63impl LayerValue for Token {
64 fn set_value(&mut self, layer: &Layer, value: impl Into<String>) {
66 let value = value.into();
67
68 match layer {
69 Layer::UPos => {
70 self.set_upos(Some(value));
71 }
72 Layer::XPos => {
73 self.set_xpos(Some(value));
74 }
75 Layer::Feature { feature, .. } => {
76 self.features_mut().insert(feature.clone(), value);
77 }
78 Layer::FeatureString => {
79 self.set_features(
80 ConlluFeatures::try_from(value.as_str())
81 .expect("Invalid feature representation")
82 .into_owned(),
83 );
84 }
85 Layer::Misc { feature, .. } => {
86 self.misc_mut().insert(feature.clone(), Some(value));
87 }
88 };
89 }
90
91 fn value(&self, layer: &Layer) -> Option<String> {
93 match layer {
94 Layer::UPos => self.upos().map(ToOwned::to_owned),
95 Layer::XPos => self.xpos().map(ToOwned::to_owned),
96 Layer::FeatureString => Some(ConlluFeatures::borrowed(self.features()).to_string()),
97 Layer::Feature { feature, default } => self
98 .features()
99 .get(feature)
100 .cloned()
101 .or_else(|| default.clone()),
102 Layer::Misc { feature, default } => match self.misc().get(feature) {
103 Some(Some(ref val)) => Some(val.clone()),
105
106 Some(None) => None,
108
109 None => default.clone(),
111 },
112 }
113 }
114}
115
116#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
118pub struct LayerEncoder {
119 layer: Layer,
120}
121
122impl LayerEncoder {
123 pub fn new(layer: Layer) -> Self {
125 LayerEncoder { layer }
126 }
127}
128
129impl SentenceDecoder for LayerEncoder {
130 type Encoding = String;
131
132 type Error = Infallible;
133
134 fn decode<S>(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error>
135 where
136 S: AsRef<[EncodingProb<Self::Encoding>]>,
137 {
138 assert_eq!(
139 labels.len(),
140 sentence.len() - 1,
141 "Labels and sentence length mismatch"
142 );
143
144 for (token, token_labels) in sentence
145 .iter_mut()
146 .filter_map(Node::token_mut)
147 .zip(labels.iter())
148 {
149 if let Some(label) = token_labels.as_ref().get(0) {
150 token.set_value(&self.layer, label.encoding().as_str());
151 }
152 }
153
154 Ok(())
155 }
156}
157
158impl SentenceEncoder for LayerEncoder {
159 type Encoding = String;
160
161 type Error = EncodeError;
162
163 fn encode(&self, sentence: &Sentence) -> Result<Vec<Self::Encoding>, Self::Error> {
164 let mut encoding = Vec::with_capacity(sentence.len() - 1);
165 for token in sentence.iter().filter_map(Node::token) {
166 let label = token
167 .value(&self.layer)
168 .ok_or_else(|| EncodeError::MissingLabel {
169 form: token.form().to_owned(),
170 })?;
171 encoding.push(label.to_owned());
172 }
173
174 Ok(encoding)
175 }
176}
177
#[cfg(test)]
mod tests {
    use std::convert::TryFrom;

    use conllu::display::{ConlluFeatures, ConlluMisc};
    use udgraph::token::{Token, TokenBuilder};

    use crate::layer::{Layer, LayerValue};

    /// Shorthand for an owned, present value.
    fn some(text: &str) -> Option<String> {
        Some(text.to_string())
    }

    /// Feature layer for the key `name` with an optional default.
    fn feature_layer(name: &str, default: Option<&str>) -> Layer {
        Layer::feature(name.to_owned(), default.map(|d| d.to_string()))
    }

    /// Misc layer for the key `name` with an optional default.
    fn misc_layer(name: &str, default: Option<&str>) -> Layer {
        Layer::misc(name.to_owned(), default.map(|d| d.to_string()))
    }

    // Reading values from a token that has all annotations set.
    #[test]
    fn layer() {
        let token: Token = TokenBuilder::new("test")
            .upos("CP")
            .xpos("P")
            .features(ConlluFeatures::try_from("c=d|a=b").unwrap().into_owned())
            .misc(ConlluMisc::from("u=v|x=y").into_owned())
            .into();

        // Part-of-speech layers.
        assert_eq!(token.value(&Layer::UPos), some("CP"));
        assert_eq!(token.value(&Layer::XPos), some("P"));

        // Individual features, without and with a default.
        assert_eq!(token.value(&feature_layer("a", None)), some("b"));
        assert_eq!(token.value(&feature_layer("c", None)), some("d"));
        assert_eq!(token.value(&feature_layer("e", None)), None);
        assert_eq!(
            token.value(&feature_layer("e", Some("some_default"))),
            some("some_default")
        );

        // All features as a single string.
        assert_eq!(token.value(&Layer::FeatureString), some("a=b|c=d"));

        // Individual misc entries, without and with a default.
        assert_eq!(token.value(&misc_layer("u", None)), some("v"));
        assert_eq!(token.value(&misc_layer("x", None)), some("y"));
        assert_eq!(token.value(&misc_layer("z", None)), None);
        assert_eq!(
            token.value(&misc_layer("z", Some("some_default"))),
            some("some_default")
        );
    }

    // Writing values to a token that starts out without annotations.
    #[test]
    fn set_layer() {
        let mut token: Token = TokenBuilder::new("test").into();

        assert_eq!(token.value(&Layer::FeatureString), some("_"));

        token.set_value(&Layer::UPos, "CP");
        token.set_value(&Layer::XPos, "P");
        token.set_value(&feature_layer("a", None), "b");
        token.set_value(&misc_layer("u", None), "v");

        assert_eq!(token.value(&Layer::UPos), some("CP"));
        assert_eq!(token.value(&Layer::XPos), some("P"));
        assert_eq!(token.value(&feature_layer("a", None)), some("b"));
        assert_eq!(token.value(&feature_layer("c", None)), None);
        assert_eq!(token.value(&Layer::FeatureString), some("a=b"));

        assert_eq!(token.value(&misc_layer("u", None)), some("v"));
        assert_eq!(token.value(&misc_layer("x", None)), None);
    }
}