trustformers_models/retnet/config.rs

use serde::{Deserialize, Serialize};
use trustformers_core::errors::invalid_config;
use trustformers_core::traits::Config;

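/// Configuration for the RetNet (Retentive Network) architecture, covering the
/// shared transformer-style fields plus RetNet-specific retention, DeepNorm,
/// chunking, and parallelism options.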
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetNetConfig {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_heads: usize,
    pub intermediate_size: usize,
    pub hidden_act: String,
    pub hidden_dropout_prob: f32,
    pub attention_dropout_prob: f32,
    pub max_position_embeddings: usize,
    pub initializer_range: f32,
    pub layer_norm_eps: f32,
    pub pad_token_id: u32,
    pub bos_token_id: u32,
    pub eos_token_id: u32,

    // RetNet-specific options.
    pub use_bias: bool,
    pub use_glu: bool,
    pub use_norm_bias: bool,
    /// Apply DeepNorm residual scaling (see `deepnorm_alpha` / `deepnorm_beta`).
    pub deepnorm: bool,
    pub dropout_module: String,
    pub activation_dropout: f32,
    pub attention_dropout: f32,
    /// Number of retention heads; must divide `hidden_size` (see `validate`).
    pub retention_heads: usize,
    /// Scaling factor applied to `hidden_size` when computing `retention_dim`.
    pub value_factor: f32,
    pub gate_fn: String,
    pub tensor_parallel_degree: usize,
    pub sequence_parallel: bool,
    pub fuse_norm: bool,
    pub no_output_layer: bool,
    pub layernorm_embedding: bool,
    /// Process long sequences in fixed-size chunks (see `uses_chunking`).
    pub chunking: bool,
    pub chunk_size: usize,
}

impl Default for RetNetConfig {
    fn default() -> Self {
        Self {
            vocab_size: 32000,
            hidden_size: 2048,
            num_hidden_layers: 24,
            num_heads: 16,
            intermediate_size: 8192,
            hidden_act: "swish".to_string(),
            hidden_dropout_prob: 0.0,
            attention_dropout_prob: 0.0,
            max_position_embeddings: 2048,
            initializer_range: 0.02,
            layer_norm_eps: 1e-6,
            pad_token_id: 0,
            bos_token_id: 1,
            eos_token_id: 2,

            use_bias: false,
            use_glu: true,
            use_norm_bias: false,
            deepnorm: true,
            dropout_module: "dropout".to_string(),
            activation_dropout: 0.0,
            attention_dropout: 0.0,
            retention_heads: 16,
            value_factor: 2.0,
            gate_fn: "swish".to_string(),
            tensor_parallel_degree: 1,
            sequence_parallel: false,
            fuse_norm: false,
            no_output_layer: false,
            layernorm_embedding: false,
            chunking: false,
            chunk_size: 512,
        }
    }
}

impl Config for RetNetConfig {
    fn validate(&self) -> trustformers_core::errors::Result<()> {
        if self.hidden_size % self.num_heads != 0 {
            return Err(invalid_config(
                "config_field",
                "hidden_size must be divisible by num_heads".to_string(),
            ));
        }

        if self.hidden_size % self.retention_heads != 0 {
            return Err(invalid_config(
                "config_field",
                "hidden_size must be divisible by retention_heads".to_string(),
            ));
        }

        if self.chunk_size > self.max_position_embeddings {
            return Err(invalid_config(
                "config_field",
                "chunk_size should not exceed max_position_embeddings".to_string(),
            ));
        }

        Ok(())
    }

    fn architecture(&self) -> &'static str {
        "RetNet"
    }
}

impl RetNetConfig {
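    /// Small preset: 24 layers with a hidden size of 2048 and 16 heads.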
    pub fn retnet_small() -> Self {
        Self {
            hidden_size: 2048,
            num_hidden_layers: 24,
            num_heads: 16,
            intermediate_size: 8192,
            retention_heads: 16,
            max_position_embeddings: 2048,
            ..Self::default()
        }
    }

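    /// Medium preset: 32 layers with a hidden size of 2560 and 20 heads.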
    pub fn retnet_medium() -> Self {
        Self {
            hidden_size: 2560,
            num_hidden_layers: 32,
            num_heads: 20,
            intermediate_size: 10240,
            retention_heads: 20,
            max_position_embeddings: 2048,
            ..Self::default()
        }
    }

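    /// Large preset: 32 layers with a hidden size of 4096 and 32 heads.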
    pub fn retnet_large() -> Self {
        Self {
            hidden_size: 4096,
            num_hidden_layers: 32,
            num_heads: 32,
            intermediate_size: 16384,
            retention_heads: 32,
            max_position_embeddings: 2048,
            ..Self::default()
        }
    }

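    /// XL preset: 40 layers with a hidden size of 5120 and 40 heads.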
    pub fn retnet_xl() -> Self {
        Self {
            hidden_size: 5120,
            num_hidden_layers: 40,
            num_heads: 40,
            intermediate_size: 20480,
            retention_heads: 40,
            max_position_embeddings: 2048,
            deepnorm: true,
            ..Self::default()
        }
    }

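    /// Long-context variant of the medium preset: an 8192-token context with
    /// chunked processing and sequence parallelism enabled.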
    pub fn retnet_long() -> Self {
        Self {
            max_position_embeddings: 8192,
            chunking: true,
            chunk_size: 1024,
            sequence_parallel: true,
            ..Self::retnet_medium()
        }
    }

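    /// Per-head dimension of the standard heads: `hidden_size / num_heads`.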
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_heads
    }

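    /// Per-head dimension of the retention heads: `hidden_size / retention_heads`.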
    pub fn retention_head_dim(&self) -> usize {
        self.hidden_size / self.retention_heads
    }

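    /// Retention value dimension derived from `hidden_size` and `value_factor`.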
    pub fn retention_dim(&self) -> usize {
        (self.hidden_size as f32 / self.value_factor) as usize
    }

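    /// Whether chunked processing is enabled with a non-zero chunk size.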
    pub fn uses_chunking(&self) -> bool {
        self.chunking && self.chunk_size > 0
    }

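    /// Rough memory ratio between full self-attention and retention at the
    /// configured context length: attention scales as O(n^2) in sequence
    /// length while the recurrent retention estimate used here scales as O(n),
    /// so the ratio reduces to n.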
    pub fn memory_advantage(&self) -> f32 {
        let seq_len = self.max_position_embeddings as f32;
        let attention_memory = seq_len * seq_len;
        let retnet_memory = seq_len;
        attention_memory / retnet_memory
    }

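    /// Whether the configuration targets long sequences, either through a
    /// large native context window (>= 4096) or through chunking.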
    pub fn supports_long_sequences(&self) -> bool {
        self.max_position_embeddings >= 4096 || self.uses_chunking()
    }

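    /// DeepNorm residual scaling factor for a decoder-only stack:
    /// alpha = (2 * num_hidden_layers)^(1/4).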
    pub fn deepnorm_alpha(&self) -> f32 {
        (2.0 * self.num_hidden_layers as f32).powf(0.25)
    }

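    /// DeepNorm initialization gain for a decoder-only stack:
    /// beta = (8 * num_hidden_layers)^(-1/4).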
    pub fn deepnorm_beta(&self) -> f32 {
        (8.0 * self.num_hidden_layers as f32).powf(-0.25)
    }
}
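
// Illustrative sanity checks sketching how the presets and helpers above are
// meant to compose; the expected values simply restate numbers already
// defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_validates() {
        let config = RetNetConfig::default();
        assert!(config.validate().is_ok());
        // 2048 hidden units split across 16 heads gives 128-dimensional heads.
        assert_eq!(config.head_dim(), 128);
        assert_eq!(config.retention_head_dim(), 128);
    }

    #[test]
    fn long_preset_enables_chunking() {
        let config = RetNetConfig::retnet_long();
        assert!(config.validate().is_ok());
        assert!(config.uses_chunking());
        assert!(config.supports_long_sequences());
        // The O(n^2) / O(n) ratio reduces to the sequence length itself.
        assert!((config.memory_advantage() - 8192.0).abs() < f32::EPSILON);
    }

    #[test]
    fn validation_rejects_indivisible_head_counts() {
        // 2050 is not divisible by the default 16 heads, so validate() fails.
        let config = RetNetConfig {
            hidden_size: 2050,
            ..RetNetConfig::default()
        };
        assert!(config.validate().is_err());
    }
}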