// syntaxdot_transformers/models/squeeze_albert/mod.rs

use std::borrow::Borrow;

use serde::Deserialize;
use syntaxdot_tch_ext::PathExt;
use tch::nn::Module;
use tch::Tensor;

use crate::activations::Activation;
use crate::error::TransformerError;
use crate::models::albert::{AlbertConfig, AlbertEmbeddingProjection};
use crate::models::bert::BertConfig;
use crate::models::layer_output::LayerOutput;
use crate::models::squeeze_bert::{SqueezeBertConfig, SqueezeBertLayer};
use crate::models::traits::WordEmbeddingsConfig;
use crate::models::Encoder;
use crate::util::LogitsMask;

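/// SqueezeALBERT model configuration.
///
/// This configuration combines ALBERT's hyperparameters (factorized
/// embeddings, hidden groups with shared layer parameters) with
/// SqueezeBERT's grouped pointwise-convolution settings (`q_groups`,
/// `k_groups`, `v_groups`, `post_attention_groups`, `intermediate_groups`,
/// `output_groups`).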
#[derive(Debug, Deserialize)]
#[serde(default)]
pub struct SqueezeAlbertConfig {
    pub attention_probs_dropout_prob: f64,
    pub embedding_size: i64,
    pub hidden_act: Activation,
    pub hidden_dropout_prob: f64,
    pub hidden_size: i64,
    pub initializer_range: f64,
    pub inner_group_num: i64,
    pub intermediate_size: i64,
    pub max_position_embeddings: i64,
    pub num_attention_heads: i64,
    pub num_hidden_groups: i64,
    pub num_hidden_layers: i64,
    pub type_vocab_size: i64,
    pub vocab_size: i64,
    pub q_groups: i64,
    pub k_groups: i64,
    pub v_groups: i64,
    pub post_attention_groups: i64,
    pub intermediate_groups: i64,
    pub output_groups: i64,
}

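// Defaults mirror an ALBERT-base-sized model (128-dimensional embeddings,
// 768-dimensional hidden states, 12 layers in a single hidden group) with
// SqueezeBERT-style convolution group sizes.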
impl Default for SqueezeAlbertConfig {
    fn default() -> Self {
        SqueezeAlbertConfig {
            attention_probs_dropout_prob: 0.,
            embedding_size: 128,
            hidden_act: Activation::GeluNew,
            hidden_dropout_prob: 0.,
            hidden_size: 768,
            initializer_range: 0.02,
            inner_group_num: 1,
            intermediate_size: 3072,
            max_position_embeddings: 512,
            num_attention_heads: 12,
            num_hidden_groups: 1,
            num_hidden_layers: 12,
            type_vocab_size: 2,
            vocab_size: 30000,
            q_groups: 4,
            k_groups: 4,
            v_groups: 4,
            post_attention_groups: 1,
            intermediate_groups: 4,
            output_groups: 4,
        }
    }
}

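// ALBERT view of the configuration; used below to construct the
// `AlbertEmbeddingProjection`.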
impl From<&SqueezeAlbertConfig> for AlbertConfig {
    fn from(albert_config: &SqueezeAlbertConfig) -> Self {
        AlbertConfig {
            attention_probs_dropout_prob: albert_config.attention_probs_dropout_prob,
            embedding_size: albert_config.embedding_size,
            hidden_act: albert_config.hidden_act,
            hidden_dropout_prob: albert_config.hidden_dropout_prob,
            hidden_size: albert_config.hidden_size,
            initializer_range: albert_config.initializer_range,
            inner_group_num: albert_config.inner_group_num,
            intermediate_size: albert_config.intermediate_size,
            max_position_embeddings: albert_config.max_position_embeddings,
            num_attention_heads: albert_config.num_attention_heads,
            num_hidden_groups: albert_config.num_hidden_groups,
            num_hidden_layers: albert_config.num_hidden_layers,
            type_vocab_size: albert_config.type_vocab_size,
            vocab_size: albert_config.vocab_size,
        }
    }
}

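// BERT view of the configuration, exposing only the hyperparameters that
// SqueezeALBERT shares with BERT-style components.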
impl From<&SqueezeAlbertConfig> for BertConfig {
    fn from(albert_config: &SqueezeAlbertConfig) -> Self {
        BertConfig {
            attention_probs_dropout_prob: albert_config.attention_probs_dropout_prob,
            hidden_act: albert_config.hidden_act,
            hidden_dropout_prob: albert_config.hidden_dropout_prob,
            hidden_size: albert_config.hidden_size,
            initializer_range: albert_config.initializer_range,
            intermediate_size: albert_config.intermediate_size,
            layer_norm_eps: 1e-12,
            max_position_embeddings: albert_config.max_position_embeddings,
            num_attention_heads: albert_config.num_attention_heads,
            num_hidden_layers: albert_config.num_hidden_layers,
            type_vocab_size: albert_config.type_vocab_size,
            vocab_size: albert_config.vocab_size,
        }
    }
}

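// SqueezeBERT view of the configuration; used to construct the grouped
// `SqueezeBertLayer`s of the encoder.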
impl From<&SqueezeAlbertConfig> for SqueezeBertConfig {
    fn from(config: &SqueezeAlbertConfig) -> Self {
        SqueezeBertConfig {
            attention_probs_dropout_prob: config.attention_probs_dropout_prob,
            embedding_size: config.embedding_size,
            hidden_act: config.hidden_act,
            hidden_dropout_prob: config.hidden_dropout_prob,
            hidden_size: config.hidden_size,
            initializer_range: config.initializer_range,
            intermediate_size: config.intermediate_size,
            layer_norm_eps: config.layer_norm_eps(),
            max_position_embeddings: config.max_position_embeddings,
            num_attention_heads: config.num_attention_heads,
            num_hidden_layers: config.num_hidden_layers,
            type_vocab_size: config.type_vocab_size,
            vocab_size: config.vocab_size,
            q_groups: config.q_groups,
            k_groups: config.k_groups,
            v_groups: config.v_groups,
            post_attention_groups: config.post_attention_groups,
            intermediate_groups: config.intermediate_groups,
            output_groups: config.output_groups,
        }
    }
}

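// Hyperparameters of the word embedding layer (ALBERT-style factorized
// embeddings of size `embedding_size`).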
impl WordEmbeddingsConfig for SqueezeAlbertConfig {
    fn dims(&self) -> i64 {
        self.embedding_size
    }

    fn dropout(&self) -> f64 {
        self.hidden_dropout_prob
    }

    fn initializer_range(&self) -> f64 {
        self.initializer_range
    }

    fn layer_norm_eps(&self) -> f64 {
        1e-12
    }

    fn vocab_size(&self) -> i64 {
        self.vocab_size
    }
}

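/// SqueezeALBERT encoder.
///
/// The encoder applies `num_hidden_layers` SqueezeBERT layers, which reuse
/// the parameters of `num_hidden_groups` layer groups in ALBERT fashion.
///
/// A minimal usage sketch (not compiled; `vs`, `embeddings`, and
/// `attention_mask` are assumed to be constructed elsewhere):
///
/// ```ignore
/// let config = SqueezeAlbertConfig::default();
/// // `vs` is a `PathExt` into a Torch variable store.
/// let encoder = SqueezeAlbertEncoder::new(&vs, &config)?;
/// // `embeddings` has shape [batch_size, seq_len, embedding_size];
/// // `attention_mask` is a boolean mask over the sequence positions.
/// let layer_outputs = encoder.encode(&embeddings, Some(&attention_mask), false)?;
/// ```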
#[derive(Debug)]
pub struct SqueezeAlbertEncoder {
    groups: Vec<SqueezeBertLayer>,
    n_layers: i64,
    projection: AlbertEmbeddingProjection,
}

impl SqueezeAlbertEncoder {
    pub fn new<'a>(
        vs: impl Borrow<PathExt<'a>>,
        config: &SqueezeAlbertConfig,
    ) -> Result<Self, TransformerError> {
        assert!(
            config.num_hidden_groups > 0,
            "Need at least 1 hidden group, got: {}",
            config.num_hidden_groups
        );

        let vs = vs.borrow();

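        // One parameter set per hidden group; all layers assigned to a group
        // reuse these parameters (ALBERT-style sharing).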
        let mut groups = Vec::with_capacity(config.num_hidden_groups as usize);
        for group_idx in 0..config.num_hidden_groups {
            groups.push(SqueezeBertLayer::new(
                vs.sub(format!("group_{}", group_idx)).sub("inner_group_0"),
                &config.into(),
            )?);
        }
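        // Embedding projection (`embedding_size` -> `hidden_size`), shared
        // with ALBERT's factorized embeddings.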
        let albert_config: AlbertConfig = config.into();
        let projection = AlbertEmbeddingProjection::new(vs, &albert_config)?;

        Ok(SqueezeAlbertEncoder {
            groups,
            n_layers: config.num_hidden_layers,
            projection,
        })
    }
}

impl Encoder for SqueezeAlbertEncoder {
    fn encode(
        &self,
        input: &Tensor,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<Vec<LayerOutput>, TransformerError> {
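        // Project the incoming embeddings to the encoder's hidden size.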
        let hidden_states = self.projection.forward(input);

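        // SqueezeBERT layers use a convolutional data layout:
        // [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size, seq_len].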
        let input = hidden_states.f_permute(&[0, 2, 1])?;

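        // Store the embedding output in the same [batch_size, hidden_size, seq_len]
        // layout as the layer outputs, so that the permutation back at the end of
        // this method applies uniformly to all outputs.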
        let mut all_layer_outputs = Vec::with_capacity(self.n_layers as usize + 1);
        all_layer_outputs.push(LayerOutput::Embedding(input.shallow_clone()));

        let attention_mask = attention_mask.map(LogitsMask::from_bool_mask).transpose()?;

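        // ALBERT-style parameter sharing: consecutive layers map onto the same
        // group of SqueezeBERT layer parameters.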
        let layers_per_group = self.n_layers as usize / self.groups.len();

        let mut hidden_states = input;
        for idx in 0..self.n_layers {
            let layer_output = self.groups[idx as usize / layers_per_group].forward_t(
                &hidden_states,
                attention_mask.as_ref(),
                train,
            )?;

            hidden_states = layer_output.output().shallow_clone();

            all_layer_outputs.push(layer_output);
        }

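        // Convert the hidden states back to [batch_size, seq_len, hidden_size].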
        for layer_output in &mut all_layer_outputs {
            *layer_output.output_mut() = layer_output.output().f_permute(&[0, 2, 1])?;
        }

        Ok(all_layer_outputs)
    }

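    // The embedding output is reported as an additional layer.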
    fn n_layers(&self) -> i64 {
        self.n_layers + 1
    }
}