syntaxdot_encoders/categorical/
encoder.rs1use std::hash::Hash;
2use std::marker::PhantomData;
3
4use numberer::Numberer;
5use serde_derive::{Deserialize, Serialize};
6use udgraph::graph::Sentence;
7
8use crate::categorical::{ImmutableNumberer, MutableNumberer, Number};
9use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};
10
11pub type ImmutableCategoricalEncoder<E, V> = CategoricalEncoder<E, V, ImmutableNumberer<V>>;
17
18pub type MutableCategoricalEncoder<E, V> = CategoricalEncoder<E, V, MutableNumberer<V>>;
22
23#[derive(Deserialize, Serialize)]
25pub struct CategoricalEncoder<E, V, M>
26where
27 V: Clone + Eq + Hash,
28 M: Number<V>,
29{
30 inner: E,
31 numberer: M,
32
33 #[serde(skip)]
34 _phantom: PhantomData<V>,
35}
36
37impl<E, V, M> CategoricalEncoder<E, V, M>
38where
39 V: Clone + Eq + Hash,
40 M: Number<V>,
41{
42 pub fn new(encoder: E, numberer: Numberer<V>) -> Self {
43 CategoricalEncoder {
44 inner: encoder,
45 numberer: M::new(numberer),
46 _phantom: PhantomData,
47 }
48 }
49}
50
51impl<D, M> CategoricalEncoder<D, D::Encoding, M>
52where
53 D: SentenceDecoder,
54 D::Encoding: Clone + Eq + Hash + ToOwned,
55 M: Number<D::Encoding>,
56{
57 pub fn decode_without_inner<S>(&self, labels: &[S]) -> Vec<Vec<EncodingProb<D::Encoding>>>
59 where
60 S: AsRef<[EncodingProb<usize>]>,
61 {
62 labels
63 .iter()
64 .map(|encoding_probs| {
65 encoding_probs
66 .as_ref()
67 .iter()
68 .map(|encoding_prob| {
69 EncodingProb::new(
70 self.numberer
71 .value(*encoding_prob.encoding())
72 .expect("Unknown label"),
73 encoding_prob.prob(),
74 )
75 })
76 .collect::<Vec<_>>()
77 })
78 .collect::<Vec<_>>()
79 }
80}
81
82impl<E, V, M> CategoricalEncoder<E, V, M>
83where
84 V: Clone + Eq + Hash,
85 M: Number<V>,
86{
87 pub fn is_empty(&self) -> bool {
88 self.len() == 0
89 }
90
91 pub fn len(&self) -> usize {
92 self.numberer.len()
93 }
94}
95
96impl<E, M> SentenceEncoder for CategoricalEncoder<E, E::Encoding, M>
97where
98 E: SentenceEncoder,
99 E::Encoding: Clone + Eq + Hash,
100 M: Number<E::Encoding>,
101{
102 type Encoding = usize;
103
104 type Error = E::Error;
105
106 fn encode(&self, sentence: &Sentence) -> Result<Vec<Self::Encoding>, Self::Error> {
107 let encoding = self.inner.encode(sentence)?;
108 let categorical_encoding = encoding
109 .into_iter()
110 .map(|e| self.numberer.number(e).unwrap_or(0))
111 .collect();
112 Ok(categorical_encoding)
113 }
114}
115
116impl<D, M> SentenceDecoder for CategoricalEncoder<D, D::Encoding, M>
117where
118 D: SentenceDecoder,
119 D::Encoding: Clone + Eq + Hash,
120 M: Number<D::Encoding>,
121{
122 type Encoding = usize;
123
124 type Error = D::Error;
125
126 fn decode<S>(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error>
127 where
128 S: AsRef<[EncodingProb<Self::Encoding>]>,
129 {
130 let categorial_encoding = self.decode_without_inner(labels);
131 self.inner.decode(&categorial_encoding, sentence)
132 }
133}
134
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;
    use std::path::Path;

    use conllu::io::Reader;
    use numberer::Numberer;

    use super::{EncodingProb, MutableCategoricalEncoder, SentenceDecoder, SentenceEncoder};
    use crate::layer::Layer;
    use crate::layer::LayerEncoder;

    static NON_PROJECTIVE_DATA: &str = "testdata/lassy-small-dev.conllu";

    /// Round-trip every sentence of the CoNLL-U file at `path` through
    /// `encoder_decoder`, and assert that decoding restores the original
    /// sentence.
    fn test_encoding<P, E, C>(path: P, encoder_decoder: E)
    where
        P: AsRef<Path>,
        E: SentenceEncoder<Encoding = C> + SentenceDecoder<Encoding = C>,
        C: 'static + Clone,
    {
        let file = File::open(path).unwrap();

        for read_result in Reader::new(BufReader::new(file)) {
            let gold = read_result.unwrap();

            // Turn every encoding into a single-candidate distribution with
            // probability 1, which is the shape the decoder expects.
            let labels: Vec<_> = encoder_decoder
                .encode(&gold)
                .unwrap()
                .into_iter()
                .map(|label| [EncodingProb::new(label, 1.)])
                .collect();

            let mut decoded = gold.clone();
            encoder_decoder.decode(&labels, &mut decoded).unwrap();

            assert_eq!(gold, decoded);
        }
    }

    #[test]
    fn categorical_encoder() {
        let numberer = Numberer::new(1);
        let wrapped = MutableCategoricalEncoder::new(LayerEncoder::new(Layer::XPos), numberer);

        assert_eq!(wrapped.len(), 1);

        test_encoding(NON_PROJECTIVE_DATA, wrapped);
    }
}