use std::borrow::Borrow;

use syntaxdot_tch_ext::tensor::SumDim;
use syntaxdot_tch_ext::PathExt;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::nn::{Init, Linear, Module};
use tch::{Kind, Reduction, Tensor};

use crate::cow::CowTensor;
use crate::layers::{Dropout, LayerNorm};
use crate::loss::CrossEntropyLoss;
use crate::models::LayerOutput;
use crate::module::{FallibleModule, FallibleModuleT};
use crate::TransformerError;

#[derive(Debug)]
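/// Non-linear layer: a linear transformation followed by a ReLU,
/// layer normalization, and dropout.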
struct NonLinearWithLayerNorm {
    layer_norm: LayerNorm,
    linear: Linear,
    dropout: Dropout,
}

impl NonLinearWithLayerNorm {
    fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        in_size: i64,
        out_size: i64,
        dropout: f64,
        layer_norm_eps: f64,
    ) -> Result<NonLinearWithLayerNorm, TransformerError> {
        let vs = vs.borrow();

        Ok(NonLinearWithLayerNorm {
            dropout: Dropout::new(dropout),
            layer_norm: LayerNorm::new(vs / "layer_norm", vec![out_size], layer_norm_eps, true),
            linear: Linear {
                ws: vs.var("weight", &[out_size, in_size], DEFAULT_KAIMING_UNIFORM)?,
                bs: Some(vs.var("bias", &[out_size], Init::Const(0.))?),
            },
        })
    }
}

impl FallibleModuleT for NonLinearWithLayerNorm {
    type Error = TransformerError;

    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        let mut hidden = self.linear.forward(input).relu();
        hidden = self.layer_norm.forward(&hidden)?;
        self.dropout.forward_t(&hidden, train)
    }
}

#[derive(Debug)]
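/// Scalar weighting of transformer layers.
///
/// Computes a trainable, softmax-normalized weighting over the outputs of
/// all transformer layers, multiplied by a trainable scalar. During
/// training, layers can be randomly dropped from the weighting (layer
/// dropout).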
pub struct ScalarWeight {
    /// Probability of dropping out a layer during training.
    layer_dropout_prob: f64,

    /// Unnormalized layer weights.
    layer_weights: Tensor,

    /// Scalar multiplier applied to the weighted sum of layers.
    scale: Tensor,
}

impl ScalarWeight {
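    /// Construct a scalar weighting layer for `n_layers` layers with the
    /// given layer dropout probability.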
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        n_layers: i64,
        layer_dropout_prob: f64,
    ) -> Result<Self, TransformerError> {
        assert!(
            n_layers > 0,
            "Number of layers ({}) should be larger than 0",
            n_layers
        );

        assert!(
            (0.0..1.0).contains(&layer_dropout_prob),
            "Layer dropout should be in [0,1), was: {}",
            layer_dropout_prob
        );

        let vs = vs.borrow();

        Ok(ScalarWeight {
            layer_dropout_prob,
            layer_weights: vs.var("layer_weights", &[n_layers], Init::Const(0.))?,
            scale: vs.var("scale", &[], Init::Const(1.))?,
        })
    }

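    /// Compute the scalar weighting of the given layer outputs.
    ///
    /// Returns a tensor of shape `[batch_size, seq_len, hidden_size]`.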
    pub fn forward(&self, layers: &[LayerOutput], train: bool) -> Result<Tensor, TransformerError> {
        assert_eq!(
            self.layer_weights.size()[0],
            layers.len() as i64,
            "Expected {} layers, got {}",
            self.layer_weights.size()[0],
            layers.len()
        );

        let layers = layers.iter().map(LayerOutput::output).collect::<Vec<_>>();

        // Stack the layer outputs along a new dimension:
        // [batch_size, seq_len, n_layers, hidden_size].
        let layers = Tensor::f_stack(&layers, 2)?;

        // During training, randomly drop layers by adding a large negative
        // value to their weights, so that they get a near-zero weight after
        // the softmax.
        let layer_weights = if train {
            let dropout_mask = Tensor::f_empty_like(&self.layer_weights)?
                .f_fill_(1.0 - self.layer_dropout_prob)?
                .f_bernoulli()?;
            let softmax_mask = (Tensor::from(1.0).f_sub(&dropout_mask.to_kind(Kind::Float))?)
                .f_mul_scalar(-10_000.)?;
            CowTensor::Owned(self.layer_weights.f_add(&softmax_mask)?)
        } else {
            CowTensor::Borrowed(&self.layer_weights)
        };

        // Normalize the layer weights and reshape them for broadcasting
        // over the stacked layers.
        let layer_weights = layer_weights
            .f_softmax(-1, Kind::Float)?
            .f_unsqueeze(0)?
            .f_unsqueeze(0)?
            .f_unsqueeze(-1)?;

        let weighted_layers = layers.f_mul(&layer_weights)?;

        // Sum over the layer dimension and apply the learned scale.
        Ok(weighted_layers
            .f_sum_dim(-2, false, Kind::Float)?
            .f_mul(&self.scale)?)
    }
}

#[derive(Debug)]
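/// Classifier on top of a scalar weighting of transformer layers.
///
/// The weighted layer outputs are passed through a hidden layer with a
/// non-linearity and layer normalization, followed by a linear layer that
/// predicts the label distribution.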
pub struct ScalarWeightClassifier {
    dropout: Dropout,
    scalar_weight: ScalarWeight,
    linear: Linear,
    non_linear: NonLinearWithLayerNorm,
}

impl ScalarWeightClassifier {
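    /// Construct a classifier from its configuration.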
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &ScalarWeightClassifierConfig,
    ) -> Result<ScalarWeightClassifier, TransformerError> {
        assert!(
            config.n_labels > 0,
            "The number of labels should be larger than 0",
        );

        assert!(
            config.input_size > 0,
            "The input size should be larger than 0",
        );

        assert!(
            config.hidden_size > 0,
            "The hidden size should be larger than 0",
        );

        let vs = vs.borrow();

        let ws = vs.var(
            "weight",
            &[config.n_labels, config.hidden_size],
            DEFAULT_KAIMING_UNIFORM,
        )?;
        let bs = vs.var("bias", &[config.n_labels], Init::Const(0.))?;

        let non_linear = NonLinearWithLayerNorm::new(
            vs / "nonlinear",
            config.input_size,
            config.hidden_size,
            config.dropout_prob,
            config.layer_norm_eps,
        )?;

        Ok(ScalarWeightClassifier {
            dropout: Dropout::new(config.dropout_prob),
            linear: Linear { ws, bs: Some(bs) },
            non_linear,
            scalar_weight: ScalarWeight::new(
                vs / "scalar_weight",
                config.n_layers,
                config.layer_dropout_prob,
            )?,
        })
    }

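    /// Compute the label probability distribution for each token.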
    pub fn forward(&self, layers: &[LayerOutput], train: bool) -> Result<Tensor, TransformerError> {
        let logits = self.logits(layers, train)?;
        Ok(logits.f_softmax(-1, Kind::Float)?)
    }

    /// Compute the unnormalized label scores (logits) for each token.
    pub fn logits(&self, layers: &[LayerOutput], train: bool) -> Result<Tensor, TransformerError> {
        let mut features = self.scalar_weight.forward(layers, train)?;

        features = self.dropout.forward_t(&features, train)?;

        features = self.non_linear.forward_t(&features, train)?;

        Ok(self.linear.forward(&features))
    }

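    /// Compute the losses and correctly-predicted labels for the given targets.
    ///
    /// Returns a pair of tensors of shape `[batch_size, seq_len]`: the
    /// per-token loss and whether the prediction for each token was correct.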
    pub fn losses(
        &self,
        layers: &[LayerOutput],
        targets: &Tensor,
        label_smoothing: Option<f64>,
        train: bool,
    ) -> Result<(Tensor, Tensor), TransformerError> {
        assert_eq!(
            targets.dim(),
            2,
            "Targets should have dimensionality 2, had {}",
            targets.dim()
        );

        let (batch_size, seq_len) = targets.size2()?;

        let n_labels = self.linear.ws.size()[0];

        let logits = self
            .logits(layers, train)?
            .f_view([batch_size * seq_len, n_labels])?;
        let targets = targets.f_view([batch_size * seq_len])?;

        let predicted = logits.f_argmax(-1, false)?;

        let losses = CrossEntropyLoss::new(-1, label_smoothing, Reduction::None)
            .forward(&logits, &targets, None)?
            .f_view([batch_size, seq_len])?;

        Ok((
            losses,
            predicted
                .f_eq_tensor(&targets)?
                .f_view([batch_size, seq_len])?,
        ))
    }
}

/// Configuration for the scalar weight classifier.
pub struct ScalarWeightClassifierConfig {
    /// Size of the hidden layer.
    pub hidden_size: i64,

    /// Size of the transformer layer outputs.
    pub input_size: i64,

    /// Number of transformer layers that are weighted.
    pub n_layers: i64,

    /// Number of labels.
    pub n_labels: i64,

    /// Probability of dropping out a layer in the scalar weighting.
    pub layer_dropout_prob: f64,

    /// Dropout probability applied to the hidden representations.
    pub dropout_prob: f64,

    /// Epsilon used for layer normalization.
    pub layer_norm_eps: f64,
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::iter::FromIterator;

    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use super::{ScalarWeightClassifier, ScalarWeightClassifierConfig};
    use crate::models::{HiddenLayer, LayerOutput};

    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables().into_keys().collect::<BTreeSet<_>>()
    }

    #[test]
    fn scalar_weight_classifier_shapes_forward_works() {
        let vs = VarStore::new(Device::Cpu);

        let classifier = ScalarWeightClassifier::new(
            vs.root_ext(|_| 0),
            &ScalarWeightClassifierConfig {
                hidden_size: 10,
                input_size: 8,
                n_labels: 5,
                n_layers: 2,
                dropout_prob: 0.1,
                layer_dropout_prob: 0.1,
                layer_norm_eps: 0.01,
            },
        )
        .unwrap();

        let layer1 = LayerOutput::EncoderWithAttention(HiddenLayer {
            attention: Tensor::zeros(&[1, 3, 2], (Kind::Float, Device::Cpu)),
            output: Tensor::zeros(&[1, 3, 8], (Kind::Float, Device::Cpu)),
        });
        let layer2 = LayerOutput::EncoderWithAttention(HiddenLayer {
            attention: Tensor::zeros(&[1, 3, 2], (Kind::Float, Device::Cpu)),
            output: Tensor::zeros(&[1, 3, 8], (Kind::Float, Device::Cpu)),
        });

        let results = classifier.forward(&[layer1, layer2], false).unwrap();

        assert_eq!(results.size(), &[1, 3, 5]);
    }

    #[test]
    fn scalar_weight_classifier_names() {
        let vs = VarStore::new(Device::Cpu);

        let _classifier = ScalarWeightClassifier::new(
            vs.root_ext(|_| 0),
            &ScalarWeightClassifierConfig {
                hidden_size: 10,
                input_size: 8,
                n_labels: 5,
                n_layers: 2,
                dropout_prob: 0.1,
                layer_dropout_prob: 0.1,
                layer_norm_eps: 0.01,
            },
        );

        assert_eq!(
            varstore_variables(&vs),
            BTreeSet::from_iter(vec![
                "bias".to_string(),
                "weight".to_string(),
                "nonlinear.bias".to_string(),
                "nonlinear.weight".to_string(),
                "nonlinear.layer_norm.bias".to_string(),
                "nonlinear.layer_norm.weight".to_string(),
                "scalar_weight.layer_weights".to_string(),
                "scalar_weight.scale".to_string()
            ])
        )
    }
}