use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::nn::{Linear, Module};
use tch::Tensor;

use crate::models::albert::AlbertConfig;
use crate::models::bert::{bert_linear, BertConfig, BertEmbeddings};
use crate::module::FallibleModuleT;
use crate::TransformerError;

/// ALBERT embeddings.
///
/// These embeddings are the same as BERT embeddings. However, the
/// BERT embeddings are wrapped to ensure that they are constructed
/// with ALBERT's embedding size rather than the encoder's hidden
/// size.
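///
/// A minimal usage sketch, assuming a freshly initialized variable
/// store (in practice the variables would be loaded from a
/// pretrained model) and an [`AlbertConfig`] obtained elsewhere:
///
/// ```no_run
/// use syntaxdot_tch_ext::RootExt;
/// use syntaxdot_transformers::models::albert::{AlbertConfig, AlbertEmbeddings};
/// use tch::nn::VarStore;
/// use tch::{Device, Kind, Tensor};
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let config: AlbertConfig = todo!("read the model configuration");
///
/// let vs = VarStore::new(Device::Cpu);
/// let embeddings = AlbertEmbeddings::new(vs.root_ext(|_| 0), &config)?;
///
/// // A batch with one sequence of four piece identifiers.
/// let pieces = Tensor::zeros(&[1, 4], (Kind::Int64, Device::Cpu));
///
/// // Shape: [1, 4, embedding_size].
/// let embedded = embeddings.forward(&pieces, None, None, false)?;
/// # Ok(())
/// # }
/// ```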
#[derive(Debug)]
pub struct AlbertEmbeddings {
    embeddings: BertEmbeddings,
}

impl AlbertEmbeddings {
    /// Construct new ALBERT embeddings with the given variable store
    /// and ALBERT configuration.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &AlbertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        // BertEmbeddings uses BERT's hidden size as the embedding
        // dimensionality, so substitute ALBERT's embedding size here.
        let mut bert_config: BertConfig = config.into();
        bert_config.hidden_size = config.embedding_size;

        let embeddings = BertEmbeddings::new(vs, &bert_config)?;

        Ok(AlbertEmbeddings { embeddings })
    }

    /// Embed the given piece identifiers, optionally together with
    /// token type and position identifiers.
    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, TransformerError> {
        self.embeddings
            .forward(input_ids, token_type_ids, position_ids, train)
    }
}

impl FallibleModuleT for AlbertEmbeddings {
    type Error = TransformerError;

    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        self.forward(input, None, None, train)
    }
}

/// Projection of ALBERT embeddings into the encoder hidden size.
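///
/// ALBERT factorizes the embedding matrix into embeddings with a small
/// dimensionality plus a projection into the encoder's hidden size.
/// With the configuration used in the tests below (vocabulary size
/// 30,000, embedding size 128, hidden size 768), this costs
/// 30,000 × 128 + 128 × 768 ≈ 3.9M parameters, against
/// 30,000 × 768 ≈ 23M for an unfactorized embedding matrix.
///
/// A sketch of how the projection composes with [`AlbertEmbeddings`].
/// The setup is elided: `embeddings`, `projection`, and `pieces` are
/// assumed to be constructed as in the embedding example above, with
/// `tch::nn::Module` in scope.
///
/// ```ignore
/// // Shape: [batch, pieces, embedding_size].
/// let embedded = embeddings.forward(&pieces, None, None, false)?;
/// // Shape: [batch, pieces, hidden_size].
/// let hidden = projection.forward(&embedded);
/// ```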
#[derive(Debug)]
pub struct AlbertEmbeddingProjection {
    projection: Linear,
}

impl AlbertEmbeddingProjection {
    /// Construct the embedding projection with the given variable
    /// store and ALBERT configuration.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &AlbertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let projection = bert_linear(
            vs / "embedding_projection",
            &config.into(),
            config.embedding_size,
            config.hidden_size,
            "weight",
            "bias",
        )?;

        Ok(AlbertEmbeddingProjection { projection })
    }
}

impl Module for AlbertEmbeddingProjection {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.projection.forward(input)
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use maplit::btreeset;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::Device;

    use crate::activations::Activation;
    use crate::models::albert::{AlbertConfig, AlbertEmbeddings};

    fn albert_config() -> AlbertConfig {
        AlbertConfig {
            attention_probs_dropout_prob: 0.,
            embedding_size: 128,
            hidden_act: Activation::GeluNew,
            hidden_dropout_prob: 0.,
            hidden_size: 768,
            initializer_range: 0.02,
            inner_group_num: 1,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_groups: 1,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables()
            .into_iter()
            .map(|(k, _)| k)
            .collect::<BTreeSet<_>>()
    }

    #[test]
    fn albert_embeddings_names() {
        let config = albert_config();

        let vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let _embeddings = AlbertEmbeddings::new(root, &config).unwrap();

        let variables = varstore_variables(&vs);

        assert_eq!(
            variables,
            btreeset![
                "layer_norm.bias".to_string(),
                "layer_norm.weight".to_string(),
                "position_embeddings.embeddings".to_string(),
                "token_type_embeddings.embeddings".to_string(),
                "word_embeddings.embeddings".to_string()
            ]
        );
    }
}