syntaxdot_transformers/scalar_weighting.rs

//! Scalar weighting of transformer layers.

use std::borrow::Borrow;

use syntaxdot_tch_ext::tensor::SumDim;
use syntaxdot_tch_ext::PathExt;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::nn::{Init, Linear, Module};
use tch::{Kind, Reduction, Tensor};

use crate::cow::CowTensor;
use crate::layers::{Dropout, LayerNorm};
use crate::loss::CrossEntropyLoss;
use crate::models::LayerOutput;
use crate::module::{FallibleModule, FallibleModuleT};
use crate::TransformerError;

/// Non-linear ReLU layer with layer normalization and dropout.
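///
/// Applies, in order: a linear projection, ReLU, layer normalization, and
/// dropout (see `forward_t` below).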
#[derive(Debug)]
struct NonLinearWithLayerNorm {
    layer_norm: LayerNorm,
    linear: Linear,
    dropout: Dropout,
}

impl NonLinearWithLayerNorm {
    fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        in_size: i64,
        out_size: i64,
        dropout: f64,
        layer_norm_eps: f64,
    ) -> Result<NonLinearWithLayerNorm, TransformerError> {
        let vs = vs.borrow();

        Ok(NonLinearWithLayerNorm {
            dropout: Dropout::new(dropout),
            layer_norm: LayerNorm::new(vs / "layer_norm", vec![out_size], layer_norm_eps, true),
            linear: Linear {
                ws: vs.var("weight", &[out_size, in_size], DEFAULT_KAIMING_UNIFORM)?,
                bs: Some(vs.var("bias", &[out_size], Init::Const(0.))?),
            },
        })
    }
}

impl FallibleModuleT for NonLinearWithLayerNorm {
    type Error = TransformerError;

    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        let mut hidden = self.linear.forward(input).relu();
        hidden = self.layer_norm.forward(&hidden)?;
        self.dropout.forward_t(&hidden, train)
    }
}

/// Layer that performs a scalar weighting of layers.
///
/// Following Peters et al., 2018 and Kondratyuk & Straka, 2019, this
/// layer applies scalar weighting:
///
/// *e = c ∑_i[ h_i · softmax(w)_i ]*
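///
/// where *h_i* is the output of layer *i*, *w* is the vector of learned
/// layer weights, and *c* is a learned scalar. For example, with two layers
/// and *w = (0, 0)*, *softmax(w) = (½, ½)*, so *e = c · (h_1 + h_2) / 2*.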
#[derive(Debug)]
pub struct ScalarWeight {
    /// Layer dropout probability.
    layer_dropout_prob: f64,

    /// Layer-wise weights.
    layer_weights: Tensor,

    /// Scaling factor applied to the weighted sum of layers.
    scale: Tensor,
}

impl ScalarWeight {
    /// Construct a scalar weighting over `n_layers` layers.
    ///
    /// `layer_dropout_prob` is the probability of excluding a layer from the
    /// weighting during training.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        n_layers: i64,
        layer_dropout_prob: f64,
    ) -> Result<Self, TransformerError> {
        assert!(
            n_layers > 0,
            "Number of layers ({}) should be larger than 0",
            n_layers
        );

        assert!(
            (0.0..1.0).contains(&layer_dropout_prob),
            "Layer dropout should be in [0,1), was: {}",
            layer_dropout_prob
        );

        let vs = vs.borrow();

        Ok(ScalarWeight {
            layer_dropout_prob,
            layer_weights: vs.var("layer_weights", &[n_layers], Init::Const(0.))?,
            scale: vs.var("scale", &[], Init::Const(1.))?,
        })
    }

    /// Apply scalar weighting to the given layer outputs.
    ///
    /// Returns a tensor of shape `[batch_size, seq_len, layer_size]`.
    pub fn forward(&self, layers: &[LayerOutput], train: bool) -> Result<Tensor, TransformerError> {
        assert_eq!(
            self.layer_weights.size()[0],
            layers.len() as i64,
            "Expected {} layers, got {}",
            self.layer_weights.size()[0],
            layers.len()
        );

        let layers = layers.iter().map(LayerOutput::output).collect::<Vec<_>>();

        // Each layer has shape [batch_size, seq_len, layer_size]. Stack the
        // layers into a single tensor of shape
        // [batch_size, seq_len, n_layers, layer_size].
        let layers = Tensor::f_stack(&layers, 2)?;

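        // Layer dropout (training only): draw a Bernoulli mask that keeps each
        // layer with probability 1 - layer_dropout_prob and add a large
        // negative value (-10,000) to the weights of dropped layers, so that
        // the softmax below assigns them (near-)zero probability.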
        let layer_weights = if train {
            let dropout_mask = Tensor::f_empty_like(&self.layer_weights)?
                .f_fill_(1.0 - self.layer_dropout_prob)?
                .f_bernoulli()?;
            let softmax_mask = (Tensor::from(1.0).f_sub(&dropout_mask.to_kind(Kind::Float))?)
                .f_mul_scalar(-10_000.)?;
            CowTensor::Owned(self.layer_weights.f_add(&softmax_mask)?)
        } else {
            CowTensor::Borrowed(&self.layer_weights)
        };

        // Convert the layer weights into a probability distribution and
        // expand dimensions to get shape [1, 1, n_layers, 1].
        let layer_weights = layer_weights
            .f_softmax(-1, Kind::Float)?
            .f_unsqueeze(0)?
            .f_unsqueeze(0)?
            .f_unsqueeze(-1)?;

        let weighted_layers = layers.f_mul(&layer_weights)?;

        // Sum across all layers and scale.
        Ok(weighted_layers
            .f_sum_dim(-2, false, Kind::Float)?
            .f_mul(&self.scale)?)
    }
}

/// A classifier that uses scalar weighting of layers.
///
/// See Peters et al., 2018 and Kondratyuk & Straka, 2019.
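///
/// A rough usage sketch, mirroring the unit test at the bottom of this file
/// (`root_ext` is provided by `syntaxdot_tch_ext::RootExt`; `HiddenLayer` and
/// `LayerOutput` come from this crate's `models` module):
///
/// ```ignore
/// let vs = VarStore::new(Device::Cpu);
/// let classifier = ScalarWeightClassifier::new(
///     vs.root_ext(|_| 0),
///     &ScalarWeightClassifierConfig {
///         hidden_size: 10,
///         input_size: 8,
///         n_labels: 5,
///         n_layers: 2,
///         dropout_prob: 0.1,
///         layer_dropout_prob: 0.1,
///         layer_norm_eps: 0.01,
///     },
/// )?;
///
/// // `layers` (not constructed here) holds one `LayerOutput` per transformer
/// // layer, each with an output of shape [batch_size, seq_len, input_size].
/// let probs = classifier.forward(&layers, false)?;
/// // `probs` has shape [batch_size, seq_len, n_labels].
/// ```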
#[derive(Debug)]
pub struct ScalarWeightClassifier {
    dropout: Dropout,
    scalar_weight: ScalarWeight,
    linear: Linear,
    non_linear: NonLinearWithLayerNorm,
}

impl ScalarWeightClassifier {
    /// Construct a classifier from its configuration.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &ScalarWeightClassifierConfig,
    ) -> Result<ScalarWeightClassifier, TransformerError> {
        assert!(
            config.n_labels > 0,
            "The number of labels should be larger than 0",
        );

        assert!(
            config.input_size > 0,
            "The input size should be larger than 0",
        );

        assert!(
            config.hidden_size > 0,
            "The hidden size should be larger than 0",
        );

        let vs = vs.borrow();

        let ws = vs.var(
            "weight",
            &[config.n_labels, config.hidden_size],
            DEFAULT_KAIMING_UNIFORM,
        )?;
        let bs = vs.var("bias", &[config.n_labels], Init::Const(0.))?;

        let non_linear = NonLinearWithLayerNorm::new(
            vs / "nonlinear",
            config.input_size,
            config.hidden_size,
            config.dropout_prob,
            config.layer_norm_eps,
        )?;

        Ok(ScalarWeightClassifier {
            dropout: Dropout::new(config.dropout_prob),
            linear: Linear { ws, bs: Some(bs) },
            non_linear,
            scalar_weight: ScalarWeight::new(
                vs / "scalar_weight",
                config.n_layers,
                config.layer_dropout_prob,
            )?,
        })
    }

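    /// Compute label probabilities for the given layer outputs.
    ///
    /// Returns the softmax of the logits over the label dimension, i.e. a
    /// tensor of shape `[batch_size, seq_len, n_labels]`.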
    pub fn forward(&self, layers: &[LayerOutput], train: bool) -> Result<Tensor, TransformerError> {
        let logits = self.logits(layers, train)?;
        Ok(logits.f_softmax(-1, Kind::Float)?)
    }

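    /// Compute unnormalized label scores (logits) for the given layer outputs.
    ///
    /// The layer outputs are scalar-weighted, passed through dropout and the
    /// non-linear layer, and then projected to `n_labels` scores per token.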
    pub fn logits(&self, layers: &[LayerOutput], train: bool) -> Result<Tensor, TransformerError> {
        let mut features = self.scalar_weight.forward(layers, train)?;

        features = self.dropout.forward_t(&features, train)?;

        features = self.non_linear.forward_t(&features, train)?;

        Ok(self.linear.forward(&features))
    }

    /// Compute the losses and correctly predicted labels of the given targets.
    ///
    /// `targets` should be of the shape `[batch_size, seq_len]`.
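    ///
    /// Returns a pair of tensors, both of shape `[batch_size, seq_len]`: the
    /// per-token losses and a boolean mask marking correctly predicted labels.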
    pub fn losses(
        &self,
        layers: &[LayerOutput],
        targets: &Tensor,
        label_smoothing: Option<f64>,
        train: bool,
    ) -> Result<(Tensor, Tensor), TransformerError> {
        assert_eq!(
            targets.dim(),
            2,
            "Targets should have dimensionality 2, had {}",
            targets.dim()
        );

        let (batch_size, seq_len) = targets.size2()?;

        let n_labels = self.linear.ws.size()[0];

        let logits = self
            .logits(layers, train)?
            .f_view([batch_size * seq_len, n_labels])?;
        let targets = targets.f_view([batch_size * seq_len])?;

        let predicted = logits.f_argmax(-1, false)?;

        let losses = CrossEntropyLoss::new(-1, label_smoothing, Reduction::None)
            .forward(&logits, &targets, None)?
            .f_view([batch_size, seq_len])?;

        Ok((
            losses,
            predicted
                .f_eq_tensor(&targets)?
                .f_view([batch_size, seq_len])?,
        ))
    }
}

/// Configuration for the scalar weight classifier.
pub struct ScalarWeightClassifierConfig {
    /// Size of the hidden layer.
    pub hidden_size: i64,

    /// Size of the input to the classification layer.
    pub input_size: i64,

    /// Number of layers to weight.
    pub n_layers: i64,

    /// Number of labels.
    pub n_labels: i64,

    /// The probability of excluding a layer from scalar weighting.
    pub layer_dropout_prob: f64,

    /// Hidden layer dropout probability.
    pub dropout_prob: f64,

    /// Layer norm epsilon.
    pub layer_norm_eps: f64,
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::iter::FromIterator;

    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use super::{ScalarWeightClassifier, ScalarWeightClassifierConfig};
    use crate::models::{HiddenLayer, LayerOutput};

    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables().into_keys().collect::<BTreeSet<_>>()
    }

    #[test]
    fn scalar_weight_classifier_shapes_forward_works() {
        let vs = VarStore::new(Device::Cpu);

        let classifier = ScalarWeightClassifier::new(
            vs.root_ext(|_| 0),
            &ScalarWeightClassifierConfig {
                hidden_size: 10,
                input_size: 8,
                n_labels: 5,
                n_layers: 2,
                dropout_prob: 0.1,
                layer_dropout_prob: 0.1,
                layer_norm_eps: 0.01,
            },
        )
        .unwrap();

        let layer1 = LayerOutput::EncoderWithAttention(HiddenLayer {
            attention: Tensor::zeros(&[1, 3, 2], (Kind::Float, Device::Cpu)),
            output: Tensor::zeros(&[1, 3, 8], (Kind::Float, Device::Cpu)),
        });
        let layer2 = LayerOutput::EncoderWithAttention(HiddenLayer {
            attention: Tensor::zeros(&[1, 3, 2], (Kind::Float, Device::Cpu)),
            output: Tensor::zeros(&[1, 3, 8], (Kind::Float, Device::Cpu)),
        });

        // Perform a forward pass to check that all shapes align.
        let results = classifier.forward(&[layer1, layer2], false).unwrap();

        assert_eq!(results.size(), &[1, 3, 5]);
    }

    #[test]
    fn scalar_weight_classifier_names() {
        let vs = VarStore::new(Device::Cpu);

        let _classifier = ScalarWeightClassifier::new(
            vs.root_ext(|_| 0),
            &ScalarWeightClassifierConfig {
                hidden_size: 10,
                input_size: 8,
                n_labels: 5,
                n_layers: 2,
                dropout_prob: 0.1,
                layer_dropout_prob: 0.1,
                layer_norm_eps: 0.01,
            },
        );

        assert_eq!(
            varstore_variables(&vs),
            BTreeSet::from_iter(vec![
                "bias".to_string(),
                "weight".to_string(),
                "nonlinear.bias".to_string(),
                "nonlinear.weight".to_string(),
                "nonlinear.layer_norm.bias".to_string(),
                "nonlinear.layer_norm.weight".to_string(),
                "scalar_weight.layer_weights".to_string(),
                "scalar_weight.scale".to_string()
            ])
        )
    }
}