// syntaxdot_transformers/models/albert/embeddings.rs

use std::borrow::Borrow;

use syntaxdot_tch_ext::PathExt;
use tch::nn::{Linear, Module};
use tch::Tensor;

use crate::models::albert::AlbertConfig;
use crate::models::bert::{bert_linear, BertConfig, BertEmbeddings};
use crate::module::FallibleModuleT;
use crate::TransformerError;

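/// ALBERT input embeddings.
///
/// ALBERT factorizes its embedding matrices: word, position, and token
/// type embeddings have `embedding_size` dimensions rather than the
/// (typically larger) `hidden_size` of the encoder. This type wraps
/// `BertEmbeddings`, overriding the hidden size with the embedding size.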
#[derive(Debug)]
pub struct AlbertEmbeddings {
    embeddings: BertEmbeddings,
}

impl AlbertEmbeddings {
    /// Construct the embeddings from an ALBERT configuration, creating
    /// their variables in the given variable store path.
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &AlbertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        // ALBERT embeddings are smaller than the hidden size, so
        // construct the wrapped BERT embeddings with `embedding_size`
        // as their hidden size.
        let mut bert_config: BertConfig = config.into();
        bert_config.hidden_size = config.embedding_size;

        let embeddings = BertEmbeddings::new(vs, &bert_config)?;

        Ok(AlbertEmbeddings { embeddings })
    }

    /// Apply the embeddings to the given piece ids, with optional
    /// token type ids and position ids.
    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, TransformerError> {
        self.embeddings
            .forward(input_ids, token_type_ids, position_ids, train)
    }
}

impl FallibleModuleT for AlbertEmbeddings {
    type Error = TransformerError;

    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error> {
        // Delegate to `forward`, letting the wrapped embeddings fill
        // in default token type ids and position ids.
        self.forward(input, None, None, train)
    }
}

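/// Projection of ALBERT embeddings into the encoder's hidden size.
///
/// Since ALBERT embeddings have `embedding_size` dimensions, this linear
/// layer projects them up to `hidden_size` before they enter the encoder.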
#[derive(Debug)]
pub struct AlbertEmbeddingProjection {
    projection: Linear,
}

impl AlbertEmbeddingProjection {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &AlbertConfig,
    ) -> Result<Self, TransformerError> {
        let vs = vs.borrow();

        let projection = bert_linear(
            vs / "embedding_projection",
            &config.into(),
            config.embedding_size,
            config.hidden_size,
            "weight",
            "bias",
        )?;

        Ok(AlbertEmbeddingProjection { projection })
    }
}

impl Module for AlbertEmbeddingProjection {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.projection.forward(input)
    }
}

#[cfg(feature = "model-tests")]
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use maplit::btreeset;
    use syntaxdot_tch_ext::RootExt;
    use tch::nn::VarStore;
    use tch::Device;

    use crate::activations::Activation;
    use crate::models::albert::{AlbertConfig, AlbertEmbeddings};

    fn albert_config() -> AlbertConfig {
        AlbertConfig {
            attention_probs_dropout_prob: 0.,
            embedding_size: 128,
            hidden_act: Activation::GeluNew,
            hidden_dropout_prob: 0.,
            hidden_size: 768,
            initializer_range: 0.02,
            inner_group_num: 1,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_groups: 1,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
        }
    }

    fn varstore_variables(vs: &VarStore) -> BTreeSet<String> {
        vs.variables()
            .into_iter()
            .map(|(k, _)| k)
            .collect::<BTreeSet<_>>()
    }

    #[test]
    fn albert_embeddings_names() {
        let config = albert_config();

        let vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let _embeddings = AlbertEmbeddings::new(root, &config);

        let variables = varstore_variables(&vs);

        assert_eq!(
            variables,
            btreeset![
                "layer_norm.bias".to_string(),
                "layer_norm.weight".to_string(),
                "position_embeddings.embeddings".to_string(),
                "token_type_embeddings.embeddings".to_string(),
                "word_embeddings.embeddings".to_string()
            ]
        );
    }
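
    // A minimal shape-check sketch of the embeddings and the projection.
    // The one assumption beyond this file is that this tch version
    // exposes `Tensor::of_slice` (newer releases renamed it to
    // `Tensor::from_slice`).
    #[test]
    fn albert_embeddings_shape() {
        use super::AlbertEmbeddingProjection;
        use tch::nn::Module;

        let config = albert_config();

        let vs = VarStore::new(Device::Cpu);
        let root = vs.root_ext(|_| 0);

        let embeddings = AlbertEmbeddings::new(&root, &config).unwrap();
        let projection = AlbertEmbeddingProjection::new(&root, &config).unwrap();

        // A batch with a single sequence of four piece ids.
        let input_ids = tch::Tensor::of_slice(&[1i64, 5, 7, 3]).view((1, 4));

        // The embeddings have `embedding_size` (128) dimensions...
        let embedded = embeddings.forward(&input_ids, None, None, false).unwrap();
        assert_eq!(embedded.size(), vec![1, 4, 128]);

        // ...and are projected to `hidden_size` (768) dimensions.
        let projected = projection.forward(&embedded);
        assert_eq!(projected.size(), vec![1, 4, 768]);
    }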
}