// syntaxdot_transformers/layers.rs

//! Basic neural network modules.
//!
//! These are modules that are not provided by the Torch binding, or modules
//! that require behavior different from their Torch counterparts.

use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::nn::{ConvConfig, Init};
use tch::{self, Tensor};

use crate::module::{FallibleModule, FallibleModuleT};
use crate::TransformerError;

/// 1-D convolution.
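///
/// # Example
///
/// A minimal usage sketch (not a compiled doctest); it assumes a `VarStore`
/// rooted through `syntaxdot_tch_ext::RootExt`, as in the tests at the
/// bottom of this file, an illustrative hidden size of 768, and an input of
/// shape `[batch_size, in_features, seq_len]`:
///
/// ```ignore
/// let vs = VarStore::new(Device::Cpu);
/// let conv = Conv1D::new(vs.root_ext(|_| 0), 768, 768, 1, 1)?;
/// let input = Tensor::rand(&[8, 768, 20], (Kind::Float, Device::Cpu));
/// // `FallibleModule` must be in scope for `forward`.
/// let hidden = conv.forward(&input)?;
/// assert_eq!(hidden.size(), &[8, 768, 20]);
/// ```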
#[derive(Debug)]
pub struct Conv1D {
    pub ws: Tensor,
    pub bs: Option<Tensor>,
    pub config: ConvConfig,
}

impl Conv1D {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        in_features: i64,
        out_features: i64,
        kernel_size: i64,
        groups: i64,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let config = ConvConfig {
            groups,
            ..ConvConfig::default()
        };

        let bs = if config.bias {
            Some(vs.var("bias", &[out_features], config.bs_init)?)
        } else {
            None
        };

        let ws = vs.var(
            "weight",
            &[out_features, in_features / groups, kernel_size],
            config.ws_init,
        )?;

        Ok(Conv1D { ws, bs, config })
    }
}

impl FallibleModule for Conv1D {
    type Error = TransformerError;

    fn forward(&self, xs: &Tensor) -> Result<Tensor, Self::Error> {
        Ok(Tensor::f_conv1d(
            xs,
            &self.ws,
            self.bs.as_ref(),
            &[self.config.stride],
            &[self.config.padding],
            &[self.config.dilation],
            self.config.groups,
        )?)
    }
}

/// Dropout layer.
///
/// This layer zeros out random elements of a tensor with probability
/// *p*. Dropout is a form of regularization and prevents
/// co-adaptation of neurons.
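///
/// # Example
///
/// A minimal sketch (not a compiled doctest); `FallibleModuleT` is assumed
/// to be in scope for `forward_t`:
///
/// ```ignore
/// let dropout = Dropout::new(0.1);
/// let hidden = Tensor::rand(&[8, 20, 768], (Kind::Float, Device::Cpu));
/// // Elements are only zeroed out when `train` is `true`.
/// let dropped = dropout.forward_t(&hidden, true)?;
/// ```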
#[derive(Debug)]
pub struct Dropout {
    p: f64,
}

impl Dropout {
    /// Drop out elements with probability *p*.
    pub fn new(p: f64) -> Self {
        Dropout { p }
    }
}

impl FallibleModuleT for Dropout {
    type Error = TransformerError;

    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        Ok(input.f_dropout(self.p, train)?)
    }
}

/// Embedding lookup layer.
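///
/// # Example
///
/// A minimal sketch (not a compiled doctest) of looking up embeddings for a
/// batch of token identifiers; the name, sizes, and initialization are
/// illustrative:
///
/// ```ignore
/// let vs = VarStore::new(Device::Cpu);
/// let embeddings = Embedding::new(
///     vs.root_ext(|_| 0),
///     "word_embeddings",
///     30_000,
///     768,
///     Init::Randn { mean: 0., stdev: 0.02 },
/// )?;
/// // Indices must be an integer tensor, here of shape [batch_size, seq_len].
/// let pieces = Tensor::zeros(&[8, 20], (Kind::Int64, Device::Cpu));
/// let looked_up = embeddings.forward(&pieces)?;
/// assert_eq!(looked_up.size(), &[8, 20, 768]);
/// ```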
#[derive(Debug)]
pub struct Embedding(pub Tensor);

impl Embedding {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        name: &str,
        num_embeddings: i64,
        embedding_dim: i64,
        init: Init,
    ) -> Result<Self, TransformerError> {
        Ok(Embedding(vs.borrow().var(
            name,
            &[num_embeddings, embedding_dim],
            init,
        )?))
    }
}

impl FallibleModule for Embedding {
    type Error = TransformerError;

    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error> {
        Ok(Tensor::f_embedding(&self.0, input, -1, false, false)?)
    }
}

/// Layer that applies layer normalization.
#[derive(Debug)]
pub struct LayerNorm {
    eps: f64,
    normalized_shape: Vec<i64>,

    weight: Option<Tensor>,
    bias: Option<Tensor>,
}

impl LayerNorm {
    /// Construct a layer normalization layer.
    ///
    /// The mean and standard deviation are computed over the last
    /// dimensions, as given by the shape `normalized_shape`. If
    /// `elementwise_affine` is `true`, a learnable affine transformation
    /// of shape `normalized_shape` is applied after normalization.
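    ///
    /// # Example
    ///
    /// A minimal sketch (not a compiled doctest); the hidden size of 768 is
    /// illustrative and `FallibleModule` is assumed to be in scope:
    ///
    /// ```ignore
    /// let vs = VarStore::new(Device::Cpu);
    /// let layer_norm = LayerNorm::new(vs.root_ext(|_| 0), vec![768], 1e-12, true);
    /// let hidden = Tensor::rand(&[8, 20, 768], (Kind::Float, Device::Cpu));
    /// let normalized = layer_norm.forward(&hidden)?;
    /// assert_eq!(normalized.size(), hidden.size());
    /// ```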
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        normalized_shape: impl Into<Vec<i64>>,
        eps: f64,
        elementwise_affine: bool,
    ) -> Self {
        let vs = vs.borrow();

        let normalized_shape = normalized_shape.into();

        let (weight, bias) = if elementwise_affine {
            (
                Some(vs.ones("weight", &normalized_shape)),
                Some(vs.zeros("bias", &normalized_shape)),
            )
        } else {
            (None, None)
        };

        LayerNorm {
            eps,
            normalized_shape,
            weight,
            bias,
        }
    }
}

impl FallibleModule for LayerNorm {
    type Error = TransformerError;

    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error> {
        // XXX: last parameter is `cudnn_enable`. What happens if we always
        //      set this to `true`?
        Ok(input.f_layer_norm(
            &self.normalized_shape,
            self.weight.as_ref(),
            self.bias.as_ref(),
            self.eps,
            false,
        )?)
    }
}

/// Configuration for the `PairwiseBilinear` layer.
#[derive(Clone, Copy, Debug)]
pub struct PairwiseBilinearConfig {
    /// The number of input features.
    pub in_features: i64,

    /// The number of output features.
    pub out_features: i64,

    /// Standard deviation for random initialization.
    pub initializer_range: f64,

    /// Use a bias for the *u* input.
    pub bias_u: bool,

    /// Use a bias for the *v* input.
    pub bias_v: bool,

    /// Compute the bilinear form for every pair of timesteps.
    pub pairwise: bool,
}

/// Pairwise bilinear forms.
///
/// Given two batches of sequences with length *n*, apply a bilinear
/// form to every pair of timesteps within a sequence.
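///
/// # Example
///
/// A minimal sketch (not a compiled doctest), mirroring the tests at the
/// bottom of this file:
///
/// ```ignore
/// let vs = VarStore::new(Device::Cpu);
/// let bilinear = PairwiseBilinear::new(
///     vs.root_ext(|_| 0),
///     &PairwiseBilinearConfig {
///         bias_u: true,
///         bias_v: false,
///         in_features: 200,
///         out_features: 5,
///         initializer_range: 0.02,
///         pairwise: true,
///     },
/// )?;
/// let u = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));
/// let v = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));
/// // Pairwise scores: [batch_size, seq_len, seq_len, out_features].
/// assert_eq!(bilinear.forward(&u, &v)?.size(), &[64, 10, 10, 5]);
/// ```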
#[derive(Debug)]
pub struct PairwiseBilinear {
    weight: Tensor,
    bias_u: bool,
    bias_v: bool,
    pairwise: bool,
}

impl PairwiseBilinear {
    /// Construct a new bilinear layer.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &PairwiseBilinearConfig,
    ) -> Result<Self, TransformerError> {
        assert!(
            config.in_features > 0,
            "in_features should be > 0, was: {}",
            config.in_features,
        );

        assert!(
            config.out_features > 0,
            "out_features should be > 0, was: {}",
            config.out_features,
        );

        let vs = vs.borrow();

        let bias_u_dim = if config.bias_u { 1 } else { 0 };
        let bias_v_dim = if config.bias_v { 1 } else { 0 };

        // We would normally use a separate variable for biases to enable
        // treating biases using a different parameter group. However, in
        // this case, the bias is not a 'constant' scalar, vector, or matrix,
        // but actually interacts with the inputs. Therefore, it seems more
        // appropriate to treat biases in a biaffine classifier as regular
        // trainable variables.
        let weight = vs.var(
            "weight",
            &[
                config.in_features + bias_u_dim,
                config.out_features,
                config.in_features + bias_v_dim,
            ],
            Init::Randn {
                mean: 0.,
                stdev: config.initializer_range,
            },
        )?;

        Ok(PairwiseBilinear {
            bias_u: config.bias_u,
            bias_v: config.bias_v,
            weight,
            pairwise: config.pairwise,
        })
    }

    /// Apply this layer to the given inputs.
    ///
    /// Both inputs must have the same shape, `[batch_size, seq_len,
    /// in_features]`. In pairwise mode, this returns a tensor of shape
    /// `[batch_size, seq_len, seq_len, out_features]` (with the last
    /// dimension squeezed away when `out_features` is 1). Otherwise, the
    /// bilinear form is only applied per timestep, returning a tensor of
    /// shape `[batch_size, seq_len, out_features]`.
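    ///
    /// In pairwise mode, the computation sketched by the einsums below is
    /// (with a column of ones appended to *u* and/or *v* when the
    /// corresponding bias is enabled):
    ///
    /// `output[b, m, l, o] = sum_{i, j} u[b, l, i] * weight[i, o, j] * v[b, m, j]`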
    pub fn forward(&self, u: &Tensor, v: &Tensor) -> Result<Tensor, TransformerError> {
        assert_eq!(
            u.size(),
            v.size(),
            "Inputs to PairwiseBilinear must have the same shape: {:?} {:?}",
            u.size(),
            v.size()
        );

        assert_eq!(
            u.dim(),
            3,
            "Shape should have 3 dimensions, has: {}",
            u.dim()
        );

        let (batch_size, seq_len, _) = u.size3()?;

        let ones = Tensor::ones(&[batch_size, seq_len, 1], (u.kind(), u.device()));

        let u = if self.bias_u {
            Tensor::f_cat(&[u, &ones], -1)?
        } else {
            u.shallow_clone()
        };

        let v = if self.bias_v {
            Tensor::f_cat(&[v, &ones], -1)?
        } else {
            v.shallow_clone()
        };

        if self.pairwise {
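            // Index legend for the einsums below: b = batch, l = timestep
            // in `u`, m = timestep in `v`, u/v = input features (plus an
            // optional bias column), o = output features.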
            // Shape: [batch_size, seq_len, out_features, v features].
            let intermediate = Tensor::f_einsum("blu,uov->blov", &[&u, &self.weight], None)?;

            // We perform a matrix multiplication to get the output with
            // the shape [batch_size, seq_len, seq_len, out_features].
            let bilinear = Tensor::f_einsum("bmv,blov->bmlo", &[&v, &intermediate], None)?;

            // Squeezing removes the last dimension when `out_features` is 1.
            Ok(bilinear.f_squeeze_dim(-1)?)
        } else {
            Ok(Tensor::f_einsum(
                "blu,uov,blv->blo",
                &[&u, &self.weight, &v],
                None,
            )?)
        }
    }
}

/// Variational dropout (Gal and Ghahramani, 2016).
///
/// For a tensor of shape `[batch_size, seq_len, repr_size]`, the same
/// dropout mask of shape `[batch_size, 1, repr_size]` is applied to
/// every timestep of a sequence.
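///
/// # Example
///
/// A minimal sketch (not a compiled doctest); `FallibleModuleT` is assumed
/// to be in scope:
///
/// ```ignore
/// let dropout = VariationalDropout::new(0.1);
/// let hidden = Tensor::rand(&[8, 20, 768], (Kind::Float, Device::Cpu));
/// // The same features are dropped in all 20 timesteps of a sequence.
/// let dropped = dropout.forward_t(&hidden, true)?;
/// assert_eq!(dropped.size(), hidden.size());
/// ```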
#[derive(Debug)]
pub struct VariationalDropout {
    p: f64,
}

impl VariationalDropout {
    /// Create a variational dropout layer with the given dropout probability.
    pub fn new(p: f64) -> Self {
        VariationalDropout { p }
    }
}

impl FallibleModuleT for VariationalDropout {
    type Error = TransformerError;

    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        // Avoid unnecessary work during prediction.
        if !train {
            return Ok(xs.shallow_clone());
        }

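        // Note: `f_dropout_` zeros mask elements with probability `p` and
        // scales the kept elements by `1 / (1 - p)`, so multiplying by the
        // mask broadcasts the same scaled keep/drop decision over all
        // timesteps.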
        let (batch_size, _, repr_size) = xs.size3()?;
        let dropout_mask = Tensor::f_ones(&[batch_size, 1, repr_size], (xs.kind(), xs.device()))?
            .f_dropout_(self.p, true)?;
        Ok(xs.f_mul(&dropout_mask)?)
    }
}

#[cfg(test)]
mod tests {
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use syntaxdot_tch_ext::RootExt;

    use crate::layers::{PairwiseBilinear, PairwiseBilinearConfig};

    #[test]
    fn bilinear_correct_shapes() {
        // Apply a bilinear layer to ensure that the shapes are correct.

        let input1 = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));
        let input2 = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));

        let vs = VarStore::new(Device::Cpu);
        let bilinear = PairwiseBilinear::new(
            vs.root_ext(|_| 0),
            &PairwiseBilinearConfig {
                bias_u: true,
                bias_v: false,
                in_features: 200,
                out_features: 5,
                initializer_range: 0.02,
                pairwise: true,
            },
        )
        .unwrap();

        assert_eq!(
            bilinear.forward(&input1, &input2).unwrap().size(),
            &[64, 10, 10, 5]
        );
    }

    #[test]
    fn bilinear_1_output_correct_shapes() {
        let input1 = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));
        let input2 = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));

        let vs = VarStore::new(Device::Cpu);
        let bilinear = PairwiseBilinear::new(
            vs.root_ext(|_| 0),
            &PairwiseBilinearConfig {
                bias_u: true,
                bias_v: false,
                in_features: 200,
                out_features: 1,
                initializer_range: 0.02,
                pairwise: true,
            },
        )
        .unwrap();

        assert_eq!(
            bilinear.forward(&input1, &input2).unwrap().size(),
            &[64, 10, 10]
        );
    }
}