use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::nn::{ConvConfig, Init};
use tch::{self, Tensor};

use crate::module::{FallibleModule, FallibleModuleT};
use crate::TransformerError;

/// 1-D convolution.
#[derive(Debug)]
pub struct Conv1D {
    pub ws: Tensor,
    pub bs: Option<Tensor>,
    pub config: ConvConfig,
}

impl Conv1D {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        in_features: i64,
        out_features: i64,
        kernel_size: i64,
        groups: i64,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let config = ConvConfig {
            groups,
            ..ConvConfig::default()
        };

        let bs = if config.bias {
            Some(vs.var("bias", &[out_features], config.bs_init)?)
        } else {
            None
        };

        let ws = vs.var(
            "weight",
            &[out_features, in_features / groups, kernel_size],
            config.ws_init,
        )?;

        Ok(Conv1D { ws, bs, config })
    }
}

impl FallibleModule for Conv1D {
    type Error = TransformerError;

    fn forward(&self, xs: &Tensor) -> Result<Tensor, Self::Error> {
        Ok(Tensor::f_conv1d(
            xs,
            &self.ws,
            self.bs.as_ref(),
            &[self.config.stride],
            &[self.config.padding],
            &[self.config.dilation],
            self.config.groups,
        )?)
    }
}

/// Dropout layer.
///
/// During training, this layer zeros out elements of a tensor with
/// probability `p` and scales the remaining elements by `1/(1-p)`.
#[derive(Debug)]
pub struct Dropout {
    p: f64,
}

impl Dropout {
    /// Construct a dropout layer with the given dropout probability.
    pub fn new(p: f64) -> Self {
        Dropout { p }
    }
}

impl FallibleModuleT for Dropout {
    type Error = TransformerError;

    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        Ok(input.f_dropout(self.p, train)?)
    }
}

/// Embedding lookup table.
#[derive(Debug)]
pub struct Embedding(pub Tensor);

impl Embedding {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        name: &str,
        num_embeddings: i64,
        embedding_dim: i64,
        init: Init,
    ) -> Result<Self, TransformerError> {
        Ok(Embedding(vs.borrow().var(
            name,
            &[num_embeddings, embedding_dim],
            init,
        )?))
    }
}

impl FallibleModule for Embedding {
    type Error = TransformerError;

    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error> {
        Ok(Tensor::f_embedding(&self.0, input, -1, false, false)?)
    }
}

/// Layer normalization.
///
/// Normalizes activations over the trailing dimensions given by
/// `normalized_shape`, optionally followed by a learned elementwise
/// affine transformation.
#[derive(Debug)]
pub struct LayerNorm {
    eps: f64,
    normalized_shape: Vec<i64>,

    weight: Option<Tensor>,
    bias: Option<Tensor>,
}

impl LayerNorm {
    /// Construct a layer normalization layer.
    ///
    /// The mean and standard deviation are computed over the last
    /// dimensions with the shape `normalized_shape`. If
    /// `elementwise_affine` is `true`, a learnable affine transformation
    /// with the shape `normalized_shape` is applied after normalization.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        normalized_shape: impl Into<Vec<i64>>,
        eps: f64,
        elementwise_affine: bool,
    ) -> Self {
        let vs = vs.borrow();

        let normalized_shape = normalized_shape.into();

        let (weight, bias) = if elementwise_affine {
            (
                Some(vs.ones("weight", &normalized_shape)),
                Some(vs.zeros("bias", &normalized_shape)),
            )
        } else {
            (None, None)
        };

        LayerNorm {
            eps,
            normalized_shape,
            weight,
            bias,
        }
    }
}

impl FallibleModule for LayerNorm {
    type Error = TransformerError;

    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error> {
        Ok(input.f_layer_norm(
            &self.normalized_shape,
            self.weight.as_ref(),
            self.bias.as_ref(),
            self.eps,
            // cudnn_enabled
            false,
        )?)
    }
}

/// Configuration for the pairwise bilinear layer.
#[derive(Clone, Copy, Debug)]
pub struct PairwiseBilinearConfig {
    /// The number of input features.
    pub in_features: i64,

    /// The number of output features.
    pub out_features: i64,

    /// Standard deviation of the normal distribution used to initialize
    /// the weights.
    pub initializer_range: f64,

    /// Use a bias for the first input.
    pub bias_u: bool,

    /// Use a bias for the second input.
    pub bias_v: bool,

    /// Compute scores for all pairs of time steps, rather than only for
    /// corresponding time steps.
    pub pairwise: bool,
}

/// Pairwise bilinear form.
///
/// Computes a bilinear form over pairs of time steps of two input
/// sequences.
#[derive(Debug)]
pub struct PairwiseBilinear {
    weight: Tensor,
    bias_u: bool,
    bias_v: bool,
    pairwise: bool,
}

impl PairwiseBilinear {
    /// Construct a pairwise bilinear layer with the given configuration.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &PairwiseBilinearConfig,
    ) -> Result<Self, TransformerError> {
        assert!(
            config.in_features > 0,
            "in_features should be > 0, was: {}",
            config.in_features,
        );

        assert!(
            config.out_features > 0,
            "out_features should be > 0, was: {}",
            config.out_features,
        );

        let vs = vs.borrow();

        let bias_u_dim = if config.bias_u { 1 } else { 0 };
        let bias_v_dim = if config.bias_v { 1 } else { 0 };

        // The weight has the shape [in_features + bias_u_dim, out_features,
        // in_features + bias_v_dim]. The extra dimensions act as biases
        // when the corresponding input is extended with a column of ones.
        let weight = vs.var(
            "weight",
            &[
                config.in_features + bias_u_dim,
                config.out_features,
                config.in_features + bias_v_dim,
            ],
            Init::Randn {
                mean: 0.,
                stdev: config.initializer_range,
            },
        )?;

        Ok(PairwiseBilinear {
            bias_u: config.bias_u,
            bias_v: config.bias_v,
            weight,
            pairwise: config.pairwise,
        })
    }

    /// Compute the bilinear form for the given inputs.
    ///
    /// `u` and `v` must both have the shape `[batch_size, seq_len,
    /// in_features]`. The output has the shape `[batch_size, seq_len,
    /// seq_len, out_features]` when `pairwise` is enabled (with the last
    /// dimension squeezed away if it has size 1), and `[batch_size,
    /// seq_len, out_features]` otherwise.
    pub fn forward(&self, u: &Tensor, v: &Tensor) -> Result<Tensor, TransformerError> {
        assert_eq!(
            u.size(),
            v.size(),
            "Inputs to PairwiseBilinear must have the same shape: {:?} {:?}",
            u.size(),
            v.size()
        );

        assert_eq!(
            u.dim(),
            3,
            "Inputs should have 3 dimensions, have: {}",
            u.dim()
        );

        let (batch_size, seq_len, _) = u.size3()?;

        let ones = Tensor::ones(&[batch_size, seq_len, 1], (u.kind(), u.device()));

        let u = if self.bias_u {
            Tensor::f_cat(&[u, &ones], -1)?
        } else {
            u.shallow_clone()
        };

        let v = if self.bias_v {
            Tensor::f_cat(&[v, &ones], -1)?
        } else {
            v.shallow_clone()
        };

        if self.pairwise {
            // Shape: [batch_size, seq_len, out_features, v_features].
            let intermediate = Tensor::f_einsum("blu,uov->blov", &[&u, &self.weight], None)?;

            // Contract with v to get pairwise scores with the shape
            // [batch_size, seq_len, seq_len, out_features].
            let bilinear = Tensor::f_einsum("bmv,blov->bmlo", &[&v, &intermediate], None)?;

            // Remove the output features dimension if it has size 1.
            Ok(bilinear.f_squeeze_dim(-1)?)
        } else {
            Ok(Tensor::f_einsum(
                "blu,uov,blv->blo",
                &[&u, &self.weight, &v],
                None,
            )?)
        }
    }
}

/// Variational dropout (Gal and Ghahramani, 2016).
///
/// In contrast to regular dropout, the same dropout mask is applied at
/// every time step of a sequence.
#[derive(Debug)]
pub struct VariationalDropout {
    p: f64,
}

impl VariationalDropout {
    /// Construct a variational dropout layer with the given dropout
    /// probability.
    pub fn new(p: f64) -> Self {
        VariationalDropout { p }
    }
}

impl FallibleModuleT for VariationalDropout {
    type Error = TransformerError;

    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        // Avoid unnecessary work during prediction.
        if !train {
            return Ok(xs.shallow_clone());
        }

        // Sample one mask per sequence and broadcast it over the time
        // step dimension.
        let (batch_size, _, repr_size) = xs.size3()?;
        let dropout_mask = Tensor::f_ones(&[batch_size, 1, repr_size], (xs.kind(), xs.device()))?
            .f_dropout_(self.p, true)?;
        Ok(xs.f_mul(&dropout_mask)?)
    }
}

#[cfg(test)]
mod tests {
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use syntaxdot_tch_ext::RootExt;

    use crate::layers::{PairwiseBilinear, PairwiseBilinearConfig};

    #[test]
    fn bilinear_correct_shapes() {
        let input1 = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));
        let input2 = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));

        let vs = VarStore::new(Device::Cpu);
        let bilinear = PairwiseBilinear::new(
            vs.root_ext(|_| 0),
            &PairwiseBilinearConfig {
                bias_u: true,
                bias_v: false,
                in_features: 200,
                out_features: 5,
                initializer_range: 0.02,
                pairwise: true,
            },
        )
        .unwrap();

        assert_eq!(
            bilinear.forward(&input1, &input2).unwrap().size(),
            &[64, 10, 10, 5]
        );
    }
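
    // A sanity check (sketch) for the non-pairwise path: with
    // `pairwise: false`, scores are only computed for corresponding time
    // steps, so the output should drop the second sequence dimension.
    #[test]
    fn bilinear_non_pairwise_correct_shapes() {
        let input1 = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));
        let input2 = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));

        let vs = VarStore::new(Device::Cpu);
        let bilinear = PairwiseBilinear::new(
            vs.root_ext(|_| 0),
            &PairwiseBilinearConfig {
                bias_u: true,
                bias_v: false,
                in_features: 200,
                out_features: 5,
                initializer_range: 0.02,
                pairwise: false,
            },
        )
        .unwrap();

        assert_eq!(
            bilinear.forward(&input1, &input2).unwrap().size(),
            &[64, 10, 5]
        );
    }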

    #[test]
    fn bilinear_1_output_correct_shapes() {
        let input1 = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));
        let input2 = Tensor::rand(&[64, 10, 200], (Kind::Float, Device::Cpu));

        let vs = VarStore::new(Device::Cpu);
        let bilinear = PairwiseBilinear::new(
            vs.root_ext(|_| 0),
            &PairwiseBilinearConfig {
                bias_u: true,
                bias_v: false,
                in_features: 200,
                out_features: 1,
                initializer_range: 0.02,
                pairwise: true,
            },
        )
        .unwrap();

        assert_eq!(
            bilinear.forward(&input1, &input2).unwrap().size(),
            &[64, 10, 10]
        );
    }
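
    // Extra sanity checks for some of the other layers. These are
    // sketches: they only verify shapes and invariants that follow
    // directly from the implementations above.
    #[test]
    fn conv1d_correct_shapes() {
        use crate::layers::Conv1D;
        use crate::module::FallibleModule;

        // Input has the shape [batch_size, in_features, time steps].
        let input = Tensor::rand(&[4, 16, 10], (Kind::Float, Device::Cpu));

        let vs = VarStore::new(Device::Cpu);
        let conv = Conv1D::new(vs.root_ext(|_| 0), 16, 32, 1, 1).unwrap();

        // A size-1 kernel with the default stride/padding preserves the
        // time dimension: [batch_size, out_features, time steps].
        assert_eq!(conv.forward(&input).unwrap().size(), &[4, 32, 10]);
    }

    #[test]
    fn embedding_correct_shapes() {
        use crate::layers::Embedding;
        use crate::module::FallibleModule;
        use tch::nn::Init;

        let vs = VarStore::new(Device::Cpu);
        let embedding = Embedding::new(
            vs.root_ext(|_| 0),
            "embeddings",
            100,
            32,
            Init::Randn {
                mean: 0.,
                stdev: 0.02,
            },
        )
        .unwrap();

        // Look up a batch of one sequence with three indices.
        let indices = Tensor::of_slice(&[1i64, 5, 7]).view([1, 3]);
        assert_eq!(embedding.forward(&indices).unwrap().size(), &[1, 3, 32]);
    }

    #[test]
    fn variational_dropout_shares_mask_across_time_steps() {
        use crate::layers::VariationalDropout;
        use crate::module::FallibleModuleT;

        let dropout = VariationalDropout::new(0.5);
        let xs = Tensor::ones(&[2, 5, 8], (Kind::Float, Device::Cpu));

        // In evaluation mode, variational dropout is the identity.
        assert_eq!(dropout.forward_t(&xs, false).unwrap(), xs);

        // In training mode, the same mask is applied at every time step.
        let ys = dropout.forward_t(&xs, true).unwrap();
        for t in 1..5 {
            assert_eq!(ys.select(1, t), ys.select(1, 0));
        }
    }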
}