// syntaxdot_transformers/models/sinusoidal/mod.rs

use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::nn::Init;
use tch::{Kind, Tensor};

use crate::layers::{Dropout, Embedding, LayerNorm};
use crate::models::traits::WordEmbeddingsConfig;
use crate::module::{FallibleModule, FallibleModuleT};
use crate::util::SinusoidalPositions;
use crate::TransformerError;

/// Word embeddings with sinusoidal position embeddings.
///
/// Word embeddings are summed with (optionally p-normalized) sinusoidal
/// position embeddings. Layer normalization and dropout are applied to
/// the result.
#[derive(Debug)]
pub struct SinusoidalEmbeddings {
    dropout: Dropout,
    layer_norm: LayerNorm,
    p_norm: Option<f64>,
    word_embeddings: Embedding,
}
impl SinusoidalEmbeddings {
    /// Construct sinusoidal embeddings.
    ///
    /// Word embeddings are initialized from a normal distribution with the
    /// configured initializer range. If `p_norm` is given, the position
    /// embeddings are normalized using this norm.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &impl WordEmbeddingsConfig,
        p_norm: Option<f64>,
    ) -> Result<SinusoidalEmbeddings, TransformerError> {
        let vs = vs.borrow();

        let normal_init = Init::Randn {
            mean: 0.,
            stdev: config.initializer_range(),
        };

        let word_embeddings = Embedding::new(
            vs / "word_embeddings",
            "embeddings",
            config.vocab_size(),
            config.dims(),
            normal_init,
        )?;

        let layer_norm = LayerNorm::new(
            vs / "layer_norm",
            vec![config.dims()],
            config.layer_norm_eps(),
            true,
        );

        let dropout = Dropout::new(config.dropout());

        Ok(SinusoidalEmbeddings {
            dropout,
            layer_norm,
            p_norm,
            word_embeddings,
        })
    }
}
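// A minimal construction sketch, assuming a `BertConfig` (which implements
// `WordEmbeddingsConfig`) and a fresh variable store; this mirrors the usage
// in the tests below:
//
//     let vs = VarStore::new(Device::Cpu);
//     let embeddings =
//         SinusoidalEmbeddings::new(vs.root_ext(|_| 0).sub("embeddings"), &config, None)?;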
impl FallibleModuleT for SinusoidalEmbeddings {
    type Error = TransformerError;

    fn forward_t(&self, input_ids: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        // Look up word embeddings: [batch_size, seq_len, dims].
        let word_embeddings = self.word_embeddings.forward(input_ids)?;

        let (_, seq_length, embedding_dim) = word_embeddings.size3()?;

        let position_embeddings: Tensor = SinusoidalPositions::sinusoidal_positions(
            seq_length,
            embedding_dim,
            self.p_norm,
            (Kind::Float, word_embeddings.device()),
        )?;

        // The position embeddings are fixed, so the addition is performed
        // without tracking gradients.
        let mut embeddings = tch::no_grad::<Result<_, TransformerError>, _>(|| {
            Ok(word_embeddings.f_add(&position_embeddings.f_unsqueeze(0)?)?)
        })?;
        embeddings = self.layer_norm.forward(&embeddings)?;
        self.dropout.forward_t(&embeddings, train)
    }
}
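// For reference: `SinusoidalPositions::sinusoidal_positions` is assumed to
// compute the standard sinusoidal encoding of Vaswani et al. (2017), where
// position `pos` and dimension pair `2i`/`2i + 1` (of `dims` dimensions) get
//
//     PE(pos, 2i)     = sin(pos / 10000^(2i / dims))
//     PE(pos, 2i + 1) = cos(pos / 10000^(2i / dims))
//
// with an optional p-norm normalization of each position vector when
// `p_norm` is set.
//
// Forward-pass sketch (hypothetical piece identifiers):
//
//     let pieces = Tensor::of_slice(&[133i64, 1937, 2]).reshape(&[1, 3]);
//     // Result shape: [1, 3, dims], layer-normalized; dropout only in training.
//     let hidden = embeddings.forward_t(&pieces, false)?;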
#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use crate::activations::Activation;
    use crate::models::bert::BertConfig;
    use crate::models::sinusoidal::SinusoidalEmbeddings;
    use crate::module::FallibleModuleT;

    const BERT_BASE_GERMAN_CASED: &str = env!("BERT_BASE_GERMAN_CASED");

    fn german_bert_config() -> BertConfig {
        BertConfig {
            attention_probs_dropout_prob: 0.1,
            hidden_act: Activation::Gelu,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            layer_norm_eps: 1e-12,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

    #[test]
    fn sinusoidal_embeddings_are_unchanged_without_norm() {
        let sums: ArrayD<f32> = get_and_sum_test_embeddings(None);

        assert_abs_diff_eq!(
            sums,
            (array![[
                -7.433159, -7.3248596, -6.981781, -5.287575, -5.657837, -6.173279, -6.0414734,
                -6.0355415, -5.6972923, -4.800411
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn sinusoidal_embeddings_are_unchanged_with_norm() {
        let sums: ArrayD<f32> = get_and_sum_test_embeddings(Some(2.0));

        assert_abs_diff_eq!(
            sums,
            (array![[
                -5.801262, -7.803936, -9.95359, 5.575783, 0.79592514, -3.6844482, -2.3470383,
                -5.6341896, -6.2476273, 1.965559
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    fn get_and_sum_test_embeddings(p_norm: Option<f64>) -> ArrayD<f32> {
        let config = german_bert_config();
        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings =
            SinusoidalEmbeddings::new(root.sub("embeddings"), &config, p_norm).unwrap();

        vs.load(BERT_BASE_GERMAN_CASED).unwrap();

        let pieces = Tensor::of_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2])
            .reshape(&[1, 10]);

        let summed_embeddings = embeddings
            .forward_t(&pieces, false)
            .unwrap()
            .sum_dim(-1, false, Kind::Float);

        (&summed_embeddings).try_into().unwrap()
    }
}