//! ALBERT embeddings (`syntaxdot_transformers/models/albert/embeddings.rs`).

1use std::borrow::Borrow;
2
3use syntaxdot_tch_ext::PathExt;
4use tch::nn::{Linear, Module};
5use tch::Tensor;
6
7use crate::models::albert::AlbertConfig;
8use crate::models::bert::{bert_linear, BertConfig, BertEmbeddings};
9use crate::module::FallibleModuleT;
10use crate::TransformerError;
11
/// ALBERT embeddings.
///
/// These embeddings are the same as BERT embeddings. However, we do
/// some wrapping to ensure that the right embedding dimensionality is
/// used.
#[derive(Debug)]
pub struct AlbertEmbeddings {
    // Wrapped BERT embeddings, constructed with the ALBERT
    // `embedding_size` substituted for the hidden size (see `new`),
    // so lookups yield embedding-sized vectors.
    embeddings: BertEmbeddings,
}
21
22impl AlbertEmbeddings {
23    /// Construct new ALBERT embeddings with the given variable store
24    /// and ALBERT configuration.
25    pub fn new<'a>(
26        vs: impl Borrow<PathExt<'a>>,
27        config: &AlbertConfig,
28    ) -> Result<Self, TransformerError> {
29        let vs = vs.borrow();
30
31        // BERT uses the hidden size as the vocab size.
32        let mut bert_config: BertConfig = config.into();
33        bert_config.hidden_size = config.embedding_size;
34
35        let embeddings = BertEmbeddings::new(vs, &bert_config)?;
36
37        Ok(AlbertEmbeddings { embeddings })
38    }
39
40    pub fn forward(
41        &self,
42        input_ids: &Tensor,
43        token_type_ids: Option<&Tensor>,
44        position_ids: Option<&Tensor>,
45        train: bool,
46    ) -> Result<Tensor, TransformerError> {
47        self.embeddings
48            .forward(input_ids, token_type_ids, position_ids, train)
49    }
50}
51
// Single-tensor module interface: forwards with token type ids and
// position ids left to their defaults (`None`).
impl FallibleModuleT for AlbertEmbeddings {
    type Error = TransformerError;

    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        self.forward(input, None, None, train)
    }
}
59
/// Projection of ALBERT embeddings into the encoder hidden size.
#[derive(Debug)]
pub struct AlbertEmbeddingProjection {
    // Linear layer mapping `embedding_size` -> `hidden_size`
    // (see `new`).
    projection: Linear,
}
65
66impl AlbertEmbeddingProjection {
67    pub fn new<'a>(
68        vs: impl Borrow<PathExt<'a>>,
69        config: &AlbertConfig,
70    ) -> Result<Self, TransformerError> {
71        let vs = vs.borrow();
72
73        let projection = bert_linear(
74            vs / "embedding_projection",
75            &config.into(),
76            config.embedding_size,
77            config.hidden_size,
78            "weight",
79            "bias",
80        )?;
81
82        Ok(AlbertEmbeddingProjection { projection })
83    }
84}
85
// Infallible module interface: applies the linear projection to the
// input tensor.
impl Module for AlbertEmbeddingProjection {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.projection.forward(input)
    }
}
91
#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use maplit::btreeset;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::Device;

    use crate::activations::Activation;
    use crate::models::albert::{AlbertConfig, AlbertEmbeddings};

    /// ALBERT base configuration used by the tests.
    fn albert_config() -> AlbertConfig {
        AlbertConfig {
            attention_probs_dropout_prob: 0.,
            embedding_size: 128,
            hidden_act: Activation::GeluNew,
            hidden_dropout_prob: 0.,
            hidden_size: 768,
            initializer_range: 0.02,
            inner_group_num: 1,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_groups: 1,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

    /// Collect the names of all variables in the store.
    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables()
            .into_iter()
            .map(|(k, _)| k)
            .collect::<BTreeSet<_>>()
    }

    #[test]
    fn albert_embeddings_names() {
        let config = albert_config();

        let vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        // Unwrap so that a construction error fails the test here,
        // rather than silently discarding the `Result` and asserting
        // against a possibly partially-populated variable store.
        let _embeddings = AlbertEmbeddings::new(root, &config).unwrap();

        let variables = varstore_variables(&vs);

        assert_eq!(
            variables,
            btreeset![
                "layer_norm.bias".to_string(),
                "layer_norm.weight".to_string(),
                "position_embeddings.embeddings".to_string(),
                "token_type_embeddings.embeddings".to_string(),
                "word_embeddings.embeddings".to_string()
            ]
        );
    }
}