// syntaxdot_transformers/models/roberta/mod.rs

use std::borrow::Borrow;
20
21use syntaxdot_tch_ext::PathExt;
22use tch::{Kind, Tensor};
23
24use crate::cow::CowTensor;
25use crate::models::bert::{BertConfig, BertEmbeddings};
26use crate::module::FallibleModuleT;
27use crate::TransformerError;
28
/// Piece index reserved for padding.
///
/// `forward` masks this index out when deriving position identifiers and
/// offsets all positions by it, so non-padding pieces get positions
/// starting at `PADDING_IDX + 1` (the RoBERTa convention).
const PADDING_IDX: i64 = 1;
30
/// RoBERTa piece embeddings.
///
/// Thin wrapper around [`BertEmbeddings`]; the only difference is how
/// position identifiers are derived when none are given (see
/// [`RobertaEmbeddings::forward`]).
#[derive(Debug)]
pub struct RobertaEmbeddings {
    inner: BertEmbeddings,
}
36
37impl RobertaEmbeddings {
38 pub fn new<'a>(
41 vs: impl Borrow<PathExt<'a>>,
42 config: &BertConfig,
43 ) -> Result<RobertaEmbeddings, TransformerError> {
44 Ok(RobertaEmbeddings {
45 inner: BertEmbeddings::new(vs, config)?,
46 })
47 }
48
49 pub fn forward(
50 &self,
51 input_ids: &Tensor,
52 token_type_ids: Option<&Tensor>,
53 position_ids: Option<&Tensor>,
54 train: bool,
55 ) -> Result<Tensor, TransformerError> {
56 let position_ids = match position_ids {
57 Some(position_ids) => CowTensor::Borrowed(position_ids),
58 None => {
59 let mask = input_ids.f_ne(PADDING_IDX)?.to_kind(Kind::Int64);
60 let incremental_indices = mask.f_cumsum(1, Kind::Int64)?.f_mul(&mask)?;
61 CowTensor::Owned(incremental_indices.f_add_scalar(PADDING_IDX)?)
62 }
63 };
64
65 self.inner.forward(
66 input_ids,
67 token_type_ids,
68 Some(position_ids.as_ref()),
69 train,
70 )
71 }
72}
73
impl FallibleModuleT for RobertaEmbeddings {
    type Error = TransformerError;

    /// Forward pass using only piece identifiers; token type and position
    /// identifiers are derived defaults (see [`RobertaEmbeddings::forward`]).
    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        self.forward(input, None, None, train)
    }
}
81
// Regression tests against pretrained XLM-RoBERTa base parameters. The
// parameter file path is baked in at compile time through the
// XLM_ROBERTA_BASE environment variable, so these tests only build when
// that variable is set; they are additionally gated behind the
// `model-tests` feature.
#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use crate::activations::Activation;
    use crate::models::bert::{BertConfig, BertEncoder};
    use crate::models::roberta::RobertaEmbeddings;
    use crate::models::Encoder;
    use crate::module::FallibleModuleT;

    // Path to the pretrained XLM-RoBERTa base parameters, resolved at
    // compile time.
    const XLM_ROBERTA_BASE: &str = env!("XLM_ROBERTA_BASE");

    /// Hyperparameters of the XLM-RoBERTa base model.
    fn xlm_roberta_config() -> BertConfig {
        BertConfig {
            attention_probs_dropout_prob: 0.1,
            hidden_act: Activation::Gelu,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            layer_norm_eps: 1e-5,
            max_position_embeddings: 514,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 1,
            vocab_size: 250002,
        }
    }

    #[test]
    fn xlm_roberta_embeddings() {
        let config = xlm_roberta_config();
        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = RobertaEmbeddings::new(root.sub("embeddings"), &config).unwrap();

        // Replace the freshly-initialized variables with the pretrained
        // parameters.
        vs.load(XLM_ROBERTA_BASE).unwrap();

        // Piece identifiers of a single tokenized sentence, shaped as a
        // batch of one sequence of twelve pieces.
        let pieces = Tensor::of_slice(&[
            0i64, 310, 23451, 107, 6743, 68, 62, 43789, 207126, 49004, 705, 2,
        ])
        .reshape(&[1, 12]);

        // Sum over the hidden dimension so that the output can be
        // compared against per-piece reference sums with a tolerance.
        let summed_embeddings =
            embeddings
                .forward_t(&pieces, false)
                .unwrap()
                .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_embeddings).try_into().unwrap();

        // Reference sums — presumably produced with an upstream
        // implementation of XLM-RoBERTa; verify provenance when updating.
        assert_abs_diff_eq!(
            sums,
            (array![[
                -9.1686, -4.2982, -0.7808, -0.7097, 0.0972, -3.0785, -3.6755, -2.1465, -2.9406,
                -1.0627, -6.6043, -4.8064
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn xlm_roberta_encoder() {
        let config = xlm_roberta_config();
        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = RobertaEmbeddings::new(root.sub("embeddings"), &config).unwrap();
        let encoder = BertEncoder::new(root.sub("encoder"), &config).unwrap();

        // Replace the freshly-initialized variables with the pretrained
        // parameters.
        vs.load(XLM_ROBERTA_BASE).unwrap();

        // Same test sentence as in xlm_roberta_embeddings.
        let pieces = Tensor::of_slice(&[
            0i64, 310, 23451, 107, 6743, 68, 62, 43789, 207126, 49004, 705, 2,
        ])
        .reshape(&[1, 12]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

        // Encode without an attention mask (no padding in the input).
        let all_hidden_states = encoder.encode(&embeddings, None, false).unwrap();

        // Only the output of the last layer is checked, summed over the
        // hidden dimension.
        let summed_last_hidden =
            all_hidden_states
                .last()
                .unwrap()
                .output()
                .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_last_hidden).try_into().unwrap();

        assert_abs_diff_eq!(
            sums,
            (array![[
                20.9693, 19.7502, 17.0594, 19.0700, 19.0065, 19.6254, 18.9379, 18.9275, 18.8922,
                18.9505, 19.2682, 20.9411
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }
}