syntaxdot_encoders/categorical/
encoder.rs

1use std::hash::Hash;
2use std::marker::PhantomData;
3
4use numberer::Numberer;
5use serde_derive::{Deserialize, Serialize};
6use udgraph::graph::Sentence;
7
8use crate::categorical::{ImmutableNumberer, MutableNumberer, Number};
9use crate::{EncodingProb, SentenceDecoder, SentenceEncoder};
10
/// An immutable categorical encoder.
///
/// This encoder does not add new encodings to the encoder. If the
/// number of an unknown encoding is looked up, the special value `0`
/// is used.
///
/// This is a [`CategoricalEncoder`] whose numberer type is
/// [`ImmutableNumberer`].
pub type ImmutableCategoricalEncoder<E, V> = CategoricalEncoder<E, V, ImmutableNumberer<V>>;
17
/// A mutable categorical encoder.
///
/// This encoder adds new encodings to the encoder when they are
/// encountered.
///
/// This is a [`CategoricalEncoder`] whose numberer type is
/// [`MutableNumberer`].
pub type MutableCategoricalEncoder<E, V> = CategoricalEncoder<E, V, MutableNumberer<V>>;
22
/// An encoder wrapper that encodes/decodes to a categorical label.
///
/// The wrapped encoder/decoder `E` works with encodings of type `V`;
/// the numberer `M` translates between those encodings and `usize`
/// labels.
#[derive(Deserialize, Serialize)]
pub struct CategoricalEncoder<E, V, M>
where
    V: Clone + Eq + Hash,
    M: Number<V>,
{
    /// The wrapped encoder/decoder.
    inner: E,
    /// Translates between encodings and their numeric labels.
    numberer: M,

    /// Marker for the encoding type `V`, which appears in no other
    /// field. Skipped during (de)serialization.
    #[serde(skip)]
    _phantom: PhantomData<V>,
}
36
37impl<E, V, M> CategoricalEncoder<E, V, M>
38where
39    V: Clone + Eq + Hash,
40    M: Number<V>,
41{
42    pub fn new(encoder: E, numberer: Numberer<V>) -> Self {
43        CategoricalEncoder {
44            inner: encoder,
45            numberer: M::new(numberer),
46            _phantom: PhantomData,
47        }
48    }
49}
50
51impl<D, M> CategoricalEncoder<D, D::Encoding, M>
52where
53    D: SentenceDecoder,
54    D::Encoding: Clone + Eq + Hash + ToOwned,
55    M: Number<D::Encoding>,
56{
57    /// Decode without applying the inner decoder.
58    pub fn decode_without_inner<S>(&self, labels: &[S]) -> Vec<Vec<EncodingProb<D::Encoding>>>
59    where
60        S: AsRef<[EncodingProb<usize>]>,
61    {
62        labels
63            .iter()
64            .map(|encoding_probs| {
65                encoding_probs
66                    .as_ref()
67                    .iter()
68                    .map(|encoding_prob| {
69                        EncodingProb::new(
70                            self.numberer
71                                .value(*encoding_prob.encoding())
72                                .expect("Unknown label"),
73                            encoding_prob.prob(),
74                        )
75                    })
76                    .collect::<Vec<_>>()
77            })
78            .collect::<Vec<_>>()
79    }
80}
81
82impl<E, V, M> CategoricalEncoder<E, V, M>
83where
84    V: Clone + Eq + Hash,
85    M: Number<V>,
86{
87    pub fn is_empty(&self) -> bool {
88        self.len() == 0
89    }
90
91    pub fn len(&self) -> usize {
92        self.numberer.len()
93    }
94}
95
96impl<E, M> SentenceEncoder for CategoricalEncoder<E, E::Encoding, M>
97where
98    E: SentenceEncoder,
99    E::Encoding: Clone + Eq + Hash,
100    M: Number<E::Encoding>,
101{
102    type Encoding = usize;
103
104    type Error = E::Error;
105
106    fn encode(&self, sentence: &Sentence) -> Result<Vec<Self::Encoding>, Self::Error> {
107        let encoding = self.inner.encode(sentence)?;
108        let categorical_encoding = encoding
109            .into_iter()
110            .map(|e| self.numberer.number(e).unwrap_or(0))
111            .collect();
112        Ok(categorical_encoding)
113    }
114}
115
116impl<D, M> SentenceDecoder for CategoricalEncoder<D, D::Encoding, M>
117where
118    D: SentenceDecoder,
119    D::Encoding: Clone + Eq + Hash,
120    M: Number<D::Encoding>,
121{
122    type Encoding = usize;
123
124    type Error = D::Error;
125
126    fn decode<S>(&self, labels: &[S], sentence: &mut Sentence) -> Result<(), Self::Error>
127    where
128        S: AsRef<[EncodingProb<Self::Encoding>]>,
129    {
130        let categorial_encoding = self.decode_without_inner(labels);
131        self.inner.decode(&categorial_encoding, sentence)
132    }
133}
134
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::BufReader;
    use std::path::Path;

    use conllu::io::Reader;
    use numberer::Numberer;

    use super::{EncodingProb, MutableCategoricalEncoder, SentenceDecoder, SentenceEncoder};
    use crate::layer::Layer;
    use crate::layer::LayerEncoder;

    static NON_PROJECTIVE_DATA: &str = "testdata/lassy-small-dev.conllu";

    /// Round-trip check: every sentence read from `path` must be
    /// unchanged after being encoded and then decoded again.
    fn test_encoding<P, E, C>(path: P, encoder_decoder: E)
    where
        P: AsRef<Path>,
        E: SentenceEncoder<Encoding = C> + SentenceDecoder<Encoding = C>,
        C: 'static + Clone,
    {
        let reader = Reader::new(BufReader::new(File::open(path).unwrap()));

        for sentence in reader {
            let original = sentence.unwrap();

            // Encode, giving each label a probability of 1.
            let labels: Vec<_> = encoder_decoder
                .encode(&original)
                .unwrap()
                .into_iter()
                .map(|label| [EncodingProb::new(label, 1.)])
                .collect();

            // Decode onto a copy and compare with the original.
            let mut round_tripped = original.clone();
            encoder_decoder
                .decode(&labels, &mut round_tripped)
                .unwrap();

            assert_eq!(original, round_tripped);
        }
    }

    #[test]
    fn categorical_encoder() {
        let categorical_encoder =
            MutableCategoricalEncoder::new(LayerEncoder::new(Layer::XPos), Numberer::new(1));
        assert_eq!(categorical_encoder.len(), 1);
        test_encoding(NON_PROJECTIVE_DATA, categorical_encoder);
    }
}