syntaxdot_transformers/models/roberta/mod.rs

// Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
// Copyright (c) 2020 The sticker developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! RoBERTa (Liu et al., 2019) and XLM-RoBERTa (Conneau et al., 2019).

use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::{Kind, Tensor};

use crate::cow::CowTensor;
use crate::models::bert::{BertConfig, BertEmbeddings};
use crate::module::FallibleModuleT;
use crate::TransformerError;

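// Index of the padding piece. RoBERTa follows the fairseq convention
// of using 1 for padding; position identifiers of non-padding pieces
// start at PADDING_IDX + 1.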
const PADDING_IDX: i64 = 1;

/// RoBERTa and XLM-RoBERTa embeddings.
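///
/// These embeddings are the same as BERT embeddings, except that
/// position identifiers are computed from the input when they are not
/// given: non-padding pieces are numbered consecutively starting at
/// `PADDING_IDX + 1`, while padding pieces get `PADDING_IDX`.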
#[derive(Debug)]
pub struct RobertaEmbeddings {
    inner: BertEmbeddings,
}

impl RobertaEmbeddings {
    /// Construct new RoBERTa embeddings with the given variable store
    /// and BERT configuration.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &BertConfig,
    ) -> Result<RobertaEmbeddings, TransformerError> {
        Ok(RobertaEmbeddings {
            inner: BertEmbeddings::new(vs, config)?,
        })
    }

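    /// Compute the piece embeddings for `input_ids`.
    ///
    /// When `position_ids` is `None`, position identifiers are derived
    /// from `input_ids`: padding pieces are assigned `PADDING_IDX` and
    /// non-padding pieces are numbered consecutively starting at
    /// `PADDING_IDX + 1`.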
    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, TransformerError> {
        let position_ids = match position_ids {
            Some(position_ids) => CowTensor::Borrowed(position_ids),
            None => {
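                // Mark non-padding pieces with 1 and number them
                // consecutively through a cumulative sum along the
                // sequence dimension; padding pieces stay 0.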
                let mask = input_ids.f_ne(PADDING_IDX)?.to_kind(Kind::Int64);
                let incremental_indices = mask.f_cumsum(1, Kind::Int64)?.f_mul(&mask)?;
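                // Shift by PADDING_IDX, so that non-padding pieces
                // start at PADDING_IDX + 1 and padding pieces are
                // assigned PADDING_IDX itself.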
                CowTensor::Owned(incremental_indices.f_add_scalar(PADDING_IDX)?)
            }
        };

        self.inner.forward(
            input_ids,
            token_type_ids,
            Some(position_ids.as_ref()),
            train,
        )
    }
}

impl FallibleModuleT for RobertaEmbeddings {
    type Error = TransformerError;

    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        self.forward(input, None, None, train)
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayD};
    use syntaxdot_tch_ext::tensor::SumDim;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::{Device, Kind, Tensor};

    use crate::activations::Activation;
    use crate::models::bert::{BertConfig, BertEncoder};
    use crate::models::roberta::RobertaEmbeddings;
    use crate::models::Encoder;
    use crate::module::FallibleModuleT;
    const XLM_ROBERTA_BASE: &str = env!("XLM_ROBERTA_BASE");

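    // Hyperparameters of the XLM-RoBERTa base model.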
    fn xlm_roberta_config() -> BertConfig {
        BertConfig {
            attention_probs_dropout_prob: 0.1,
            hidden_act: Activation::Gelu,
            hidden_dropout_prob: 0.1,
            hidden_size: 768,
            initializer_range: 0.02,
            intermediate_size: 3072,
            layer_norm_eps: 1e-5,
            max_position_embeddings: 514,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            type_vocab_size: 1,
            vocab_size: 250002,
        }
    }

    #[test]
    fn xlm_roberta_embeddings() {
        let config = xlm_roberta_config();
        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = RobertaEmbeddings::new(root.sub("embeddings"), &config).unwrap();

        vs.load(XLM_ROBERTA_BASE).unwrap();

        // Subtokenization of: "Veruntreute die AWO Spendengeld?"
        // ("Did the AWO embezzle donation money?")
        let pieces = Tensor::of_slice(&[
            0i64, 310, 23451, 107, 6743, 68, 62, 43789, 207126, 49004, 705, 2,
        ])
        .reshape(&[1, 12]);

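        // Sum each piece's embedding over the hidden dimension to get
        // a compact value that can be compared against the reference
        // output.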
        let summed_embeddings =
            embeddings
                .forward_t(&pieces, false)
                .unwrap()
                .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_embeddings).try_into().unwrap();

        // Verify output against Hugging Face transformers Python
        // implementation.
        assert_abs_diff_eq!(
            sums,
            (array![[
                -9.1686, -4.2982, -0.7808, -0.7097, 0.0972, -3.0785, -3.6755, -2.1465, -2.9406,
                -1.0627, -6.6043, -4.8064
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }

    #[test]
    fn xlm_roberta_encoder() {
        let config = xlm_roberta_config();
        let mut vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = RobertaEmbeddings::new(root.sub("embeddings"), &config).unwrap();
        let encoder = BertEncoder::new(root.sub("encoder"), &config).unwrap();

        vs.load(XLM_ROBERTA_BASE).unwrap();

        // Subtokenization of: "Veruntreute die AWO Spendengeld?"
        // ("Did the AWO embezzle donation money?")
        let pieces = Tensor::of_slice(&[
            0i64, 310, 23451, 107, 6743, 68, 62, 43789, 207126, 49004, 705, 2,
        ])
        .reshape(&[1, 12]);

        let embeddings = embeddings.forward_t(&pieces, false).unwrap();

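        // Encode the piece embeddings; `encode` returns the hidden
        // states of all layers, of which only the last is checked here.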
        let all_hidden_states = encoder.encode(&embeddings, None, false).unwrap();

        let summed_last_hidden =
            all_hidden_states
                .last()
                .unwrap()
                .output()
                .sum_dim(-1, false, Kind::Float);

        let sums: ArrayD<f32> = (&summed_last_hidden).try_into().unwrap();

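        // Verify output against the Hugging Face transformers Python
        // implementation.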
        assert_abs_diff_eq!(
            sums,
            (array![[
                20.9693, 19.7502, 17.0594, 19.0700, 19.0065, 19.6254, 18.9379, 18.9275, 18.8922,
                18.9505, 19.2682, 20.9411
            ]])
            .into_dyn(),
            epsilon = 1e-4
        );
    }
}