1use serde::{Deserialize, Serialize};
2
/// Configuration for the Cohere Command R model family.
///
/// Mirrors the hyperparameter layout of the models' Hugging Face
/// `config.json`; preset constructors for the known variants live in the
/// inherent `impl` block, and `validate()` checks the basic invariants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandRConfig {
    /// Human-readable model identifier, e.g. "command-r".
    pub model_name: String,
    /// Size of the token vocabulary.
    pub vocab_size: usize,
    /// Width of the hidden (embedding) dimension.
    pub hidden_size: usize,
    /// Number of attention (query) heads per layer.
    pub num_attention_heads: usize,
    /// Number of key/value heads per layer; fewer than
    /// `num_attention_heads` means grouped-query attention (see `is_gqa`).
    pub num_key_value_heads: usize,
    /// Number of transformer layers.
    pub num_hidden_layers: usize,
    /// Width of the feed-forward (MLP) intermediate projection.
    pub intermediate_size: usize,
    /// Maximum sequence length (context window) the config declares.
    pub max_sequence_length: usize,
    /// Epsilon for RMSNorm.
    pub rms_norm_eps: f32,
    /// Base frequency (theta) for rotary position embeddings.
    pub rope_theta: f32,
    /// RoPE scaling factor (1.0 = no scaling).
    pub rope_scaling_factor: f32,
    /// Attention dropout probability.
    pub attention_dropout: f32,
    /// Hidden-state dropout probability.
    pub hidden_dropout: f32,
    /// Whether linear projections carry bias terms.
    pub use_bias: bool,
    /// Whether input embeddings and the LM head share weights.
    pub tie_word_embeddings: bool,
    /// Activation function name for the MLP, e.g. "silu".
    pub activation_function: String,
    /// Epsilon for layer normalization.
    pub layer_norm_eps: f32,
    /// Whether a bias is applied to output logits.
    pub use_logit_bias: bool,
    /// Multiplicative scale applied to output logits.
    pub logit_scale: f32,
    /// Whether sliding-window attention is enabled.
    pub use_sliding_window: bool,
    /// Window size in tokens when sliding-window attention is enabled.
    pub sliding_window_size: usize,
    /// Whether a flash-attention kernel should be used when available.
    pub use_flash_attention: bool,
    /// Padding token id, if any.
    pub pad_token_id: Option<usize>,
    /// Beginning-of-sequence token id, if any.
    pub bos_token_id: Option<usize>,
    /// End-of-sequence token id, if any.
    pub eos_token_id: Option<usize>,
    /// Architecture identifier, e.g. "command-r".
    pub model_type: String,
    /// Serialized parameter dtype, e.g. "bfloat16".
    pub torch_dtype: String,
    /// `transformers` library version the config targets.
    pub transformers_version: String,
}
63
64impl Default for CommandRConfig {
65 fn default() -> Self {
66 Self {
68 model_name: "command-r".to_string(),
69 vocab_size: 256000,
70 hidden_size: 8192,
71 num_attention_heads: 64,
72 num_key_value_heads: 64,
73 num_hidden_layers: 40,
74 intermediate_size: 22528,
75 max_sequence_length: 131072,
76 rms_norm_eps: 1e-5,
77 rope_theta: 10000.0,
78 rope_scaling_factor: 1.0,
79 attention_dropout: 0.0,
80 hidden_dropout: 0.0,
81 use_bias: false,
82 tie_word_embeddings: false,
83 activation_function: "silu".to_string(),
84 layer_norm_eps: 1e-5,
85 use_logit_bias: false,
86 logit_scale: 1.0,
87 use_sliding_window: false,
88 sliding_window_size: 4096,
89 use_flash_attention: true,
90 pad_token_id: Some(0),
91 bos_token_id: Some(5),
92 eos_token_id: Some(255001),
93 model_type: "command-r".to_string(),
94 torch_dtype: "bfloat16".to_string(),
95 transformers_version: "4.39.0".to_string(),
96 }
97 }
98}
99
100impl CommandRConfig {
101 pub fn tiny() -> Self {
104 Self {
105 model_name: "command-r-tiny".to_string(),
106 vocab_size: 1000,
107 hidden_size: 64,
108 num_attention_heads: 4,
109 num_key_value_heads: 4,
110 num_hidden_layers: 2,
111 intermediate_size: 128,
112 max_sequence_length: 128,
113 rms_norm_eps: 1e-5,
114 rope_theta: 10000.0,
115 rope_scaling_factor: 1.0,
116 attention_dropout: 0.0,
117 hidden_dropout: 0.0,
118 use_bias: false,
119 tie_word_embeddings: false,
120 activation_function: "silu".to_string(),
121 layer_norm_eps: 1e-5,
122 use_logit_bias: false,
123 logit_scale: 1.0,
124 use_sliding_window: false,
125 sliding_window_size: 64,
126 use_flash_attention: false,
127 pad_token_id: Some(0),
128 bos_token_id: Some(1),
129 eos_token_id: Some(2),
130 model_type: "command-r".to_string(),
131 torch_dtype: "float32".to_string(),
132 transformers_version: "4.39.0".to_string(),
133 }
134 }
135
136 pub fn command_r() -> Self {
138 Self::default()
139 }
140
141 pub fn command_r_plus() -> Self {
143 Self {
144 model_name: "command-r-plus".to_string(),
145 vocab_size: 256000,
146 hidden_size: 12288,
147 num_attention_heads: 96,
148 num_key_value_heads: 96,
149 num_hidden_layers: 64,
150 intermediate_size: 33792,
151 max_sequence_length: 131072,
152 rms_norm_eps: 1e-5,
153 rope_theta: 10000.0,
154 rope_scaling_factor: 1.0,
155 attention_dropout: 0.0,
156 hidden_dropout: 0.0,
157 use_bias: false,
158 tie_word_embeddings: false,
159 activation_function: "silu".to_string(),
160 layer_norm_eps: 1e-5,
161 use_logit_bias: false,
162 logit_scale: 1.0,
163 use_sliding_window: false,
164 sliding_window_size: 4096,
165 use_flash_attention: true,
166 pad_token_id: Some(0),
167 bos_token_id: Some(5),
168 eos_token_id: Some(255001),
169 model_type: "command-r-plus".to_string(),
170 torch_dtype: "bfloat16".to_string(),
171 transformers_version: "4.39.0".to_string(),
172 }
173 }
174
175 pub fn command_r_08_2024() -> Self {
177 Self {
178 model_name: "command-r-08-2024".to_string(),
179 vocab_size: 256000,
180 hidden_size: 8192,
181 num_attention_heads: 64,
182 num_key_value_heads: 64,
183 num_hidden_layers: 40,
184 intermediate_size: 22528,
185 max_sequence_length: 131072,
186 rms_norm_eps: 1e-5,
187 rope_theta: 10000.0,
188 rope_scaling_factor: 1.0,
189 attention_dropout: 0.0,
190 hidden_dropout: 0.0,
191 use_bias: false,
192 tie_word_embeddings: false,
193 activation_function: "silu".to_string(),
194 layer_norm_eps: 1e-5,
195 use_logit_bias: false,
196 logit_scale: 1.0,
197 use_sliding_window: false,
198 sliding_window_size: 4096,
199 use_flash_attention: true,
200 pad_token_id: Some(0),
201 bos_token_id: Some(5),
202 eos_token_id: Some(255001),
203 model_type: "command-r-08-2024".to_string(),
204 torch_dtype: "bfloat16".to_string(),
205 transformers_version: "4.39.0".to_string(),
206 }
207 }
208
209 pub fn command_r_plus_08_2024() -> Self {
211 Self {
212 model_name: "command-r-plus-08-2024".to_string(),
213 vocab_size: 256000,
214 hidden_size: 12288,
215 num_attention_heads: 96,
216 num_key_value_heads: 96,
217 num_hidden_layers: 64,
218 intermediate_size: 33792,
219 max_sequence_length: 131072,
220 rms_norm_eps: 1e-5,
221 rope_theta: 10000.0,
222 rope_scaling_factor: 1.0,
223 attention_dropout: 0.0,
224 hidden_dropout: 0.0,
225 use_bias: false,
226 tie_word_embeddings: false,
227 activation_function: "silu".to_string(),
228 layer_norm_eps: 1e-5,
229 use_logit_bias: false,
230 logit_scale: 1.0,
231 use_sliding_window: false,
232 sliding_window_size: 4096,
233 use_flash_attention: true,
234 pad_token_id: Some(0),
235 bos_token_id: Some(5),
236 eos_token_id: Some(255001),
237 model_type: "command-r-plus-08-2024".to_string(),
238 torch_dtype: "bfloat16".to_string(),
239 transformers_version: "4.39.0".to_string(),
240 }
241 }
242
243 pub fn head_dim(&self) -> usize {
245 self.hidden_size / self.num_attention_heads
246 }
247
248 pub fn kv_head_dim(&self) -> usize {
250 self.hidden_size / self.num_key_value_heads
251 }
252
253 pub fn is_gqa(&self) -> bool {
255 self.num_key_value_heads != self.num_attention_heads
256 }
257
258 pub fn num_query_groups(&self) -> usize {
260 self.num_attention_heads / self.num_key_value_heads
261 }
262
263 pub fn validate(&self) -> Result<(), String> {
265 if self.vocab_size == 0 {
266 return Err("vocab_size must be greater than 0".to_string());
267 }
268 if self.hidden_size == 0 {
269 return Err("hidden_size must be greater than 0".to_string());
270 }
271 if self.num_attention_heads == 0 {
272 return Err("num_attention_heads must be greater than 0".to_string());
273 }
274 if self.num_key_value_heads == 0 {
275 return Err("num_key_value_heads must be greater than 0".to_string());
276 }
277 if self.num_hidden_layers == 0 {
278 return Err("num_hidden_layers must be greater than 0".to_string());
279 }
280 if self.intermediate_size == 0 {
281 return Err("intermediate_size must be greater than 0".to_string());
282 }
283 if self.max_sequence_length == 0 {
284 return Err("max_sequence_length must be greater than 0".to_string());
285 }
286 if self.hidden_size % self.num_attention_heads != 0 {
287 return Err("hidden_size must be divisible by num_attention_heads".to_string());
288 }
289 if self.num_attention_heads % self.num_key_value_heads != 0 {
290 return Err("num_attention_heads must be divisible by num_key_value_heads".to_string());
291 }
292
293 Ok(())
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_command_r_config() {
        let cfg = CommandRConfig::command_r();
        assert_eq!(cfg.model_name, "command-r");
        assert_eq!(cfg.vocab_size, 256000);
        assert_eq!(cfg.hidden_size, 8192);
        assert_eq!(cfg.num_attention_heads, 64);
        assert_eq!(cfg.num_hidden_layers, 40);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_command_r_plus_config() {
        let cfg = CommandRConfig::command_r_plus();
        assert_eq!(cfg.model_name, "command-r-plus");
        assert_eq!(cfg.vocab_size, 256000);
        assert_eq!(cfg.hidden_size, 12288);
        assert_eq!(cfg.num_attention_heads, 96);
        assert_eq!(cfg.num_hidden_layers, 64);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_head_dim_calculation() {
        // Both variants keep a per-head dimension of 128.
        assert_eq!(CommandRConfig::command_r().head_dim(), 128);
        assert_eq!(CommandRConfig::command_r_plus().head_dim(), 128);
    }

    #[test]
    fn test_gqa_detection() {
        let base = CommandRConfig::command_r();
        assert!(!base.is_gqa());

        // Halving the KV heads turns on GQA with 2 query heads per group.
        let mut gqa = base.clone();
        gqa.num_key_value_heads = 32;
        assert!(gqa.is_gqa());
        assert_eq!(gqa.num_query_groups(), 2);
    }

    #[test]
    fn test_config_validation() {
        let mut cfg = CommandRConfig::default();
        assert!(cfg.validate().is_ok());

        // Zero vocab size must be rejected.
        cfg.vocab_size = 0;
        assert!(cfg.validate().is_err());

        // 100 is not divisible by 64 heads, so validation must fail.
        cfg.vocab_size = 256000;
        cfg.hidden_size = 100;
        cfg.num_attention_heads = 64;
        assert!(cfg.validate().is_err());
    }
}