//! Configuration types for Command R models.
//!
//! Path: trustformers_models/command_r/config.rs

use serde::{Deserialize, Serialize};
2
/// Configuration for Command R models (Cohere's Command R / Command R+ family).
///
/// Field names appear to mirror the Hugging Face `config.json` layout so the
/// struct can round-trip through serde — TODO confirm against the checkpoint files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandRConfig {
    /// Human-readable model name (e.g. "command-r", "command-r-plus").
    pub model_name: String,
    /// Vocabulary size (number of token embeddings).
    pub vocab_size: usize,
    /// Hidden size (model/embedding dimension).
    pub hidden_size: usize,
    /// Number of attention (query) heads.
    pub num_attention_heads: usize,
    /// Number of key-value heads; fewer than `num_attention_heads` enables GQA.
    pub num_key_value_heads: usize,
    /// Number of transformer layers.
    pub num_hidden_layers: usize,
    /// Intermediate (inner) size of the feed-forward network.
    pub intermediate_size: usize,
    /// Maximum supported sequence length (context window).
    pub max_sequence_length: usize,
    /// Epsilon used by RMS normalization.
    pub rms_norm_eps: f32,
    /// Base frequency (theta) for rotary position embeddings.
    pub rope_theta: f32,
    /// RoPE scaling factor (1.0 = no scaling).
    pub rope_scaling_factor: f32,
    /// Dropout probability applied to attention weights.
    pub attention_dropout: f32,
    /// Dropout probability applied to hidden states.
    pub hidden_dropout: f32,
    /// Whether linear layers carry a bias term.
    pub use_bias: bool,
    /// Whether input and output embeddings share weights.
    pub tie_word_embeddings: bool,
    /// Activation function name for the FFN (e.g. "silu").
    pub activation_function: String,
    /// Epsilon used by layer normalization.
    pub layer_norm_eps: f32,
    /// Whether a bias is added to the output logits.
    pub use_logit_bias: bool,
    /// Multiplicative scale applied to output logits.
    pub logit_scale: f32,
    /// Whether sliding-window attention is enabled.
    pub use_sliding_window: bool,
    /// Sliding window size in tokens (only relevant if `use_sliding_window`).
    pub sliding_window_size: usize,
    /// Whether to use a flash-attention kernel when available.
    pub use_flash_attention: bool,
    /// Padding token id, if defined for the tokenizer.
    pub pad_token_id: Option<usize>,
    /// Beginning-of-sequence token id, if defined.
    pub bos_token_id: Option<usize>,
    /// End-of-sequence token id, if defined.
    pub eos_token_id: Option<usize>,
    /// Model type identifier (used for checkpoint/format dispatch).
    pub model_type: String,
    /// Parameter dtype string as stored in the checkpoint (e.g. "bfloat16").
    pub torch_dtype: String,
    /// `transformers` library version the config was exported with.
    pub transformers_version: String,
}
63
impl Default for CommandRConfig {
    /// Default configuration for the Command R base model
    /// (35B-class: 8192 hidden, 64 heads, 40 layers, 128k context —
    /// presumably matching Cohere's published config; verify against the
    /// official checkpoint before relying on exact values).
    fn default() -> Self {
        // Default configuration for Command R base model
        Self {
            model_name: "command-r".to_string(),
            vocab_size: 256000,
            hidden_size: 8192,
            num_attention_heads: 64,
            // Same as num_attention_heads => multi-head attention, not GQA.
            num_key_value_heads: 64,
            num_hidden_layers: 40,
            intermediate_size: 22528,
            max_sequence_length: 131072,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            rope_scaling_factor: 1.0,
            attention_dropout: 0.0,
            hidden_dropout: 0.0,
            use_bias: false,
            tie_word_embeddings: false,
            activation_function: "silu".to_string(),
            layer_norm_eps: 1e-5,
            use_logit_bias: false,
            logit_scale: 1.0,
            use_sliding_window: false,
            sliding_window_size: 4096,
            use_flash_attention: true,
            pad_token_id: Some(0),
            bos_token_id: Some(5),
            eos_token_id: Some(255001),
            model_type: "command-r".to_string(),
            torch_dtype: "bfloat16".to_string(),
            transformers_version: "4.39.0".to_string(),
        }
    }
}
99
100impl CommandRConfig {
101    /// Create a tiny configuration for testing purposes
102    /// Uses very small dimensions to allow fast test execution
103    pub fn tiny() -> Self {
104        Self {
105            model_name: "command-r-tiny".to_string(),
106            vocab_size: 1000,
107            hidden_size: 64,
108            num_attention_heads: 4,
109            num_key_value_heads: 4,
110            num_hidden_layers: 2,
111            intermediate_size: 128,
112            max_sequence_length: 128,
113            rms_norm_eps: 1e-5,
114            rope_theta: 10000.0,
115            rope_scaling_factor: 1.0,
116            attention_dropout: 0.0,
117            hidden_dropout: 0.0,
118            use_bias: false,
119            tie_word_embeddings: false,
120            activation_function: "silu".to_string(),
121            layer_norm_eps: 1e-5,
122            use_logit_bias: false,
123            logit_scale: 1.0,
124            use_sliding_window: false,
125            sliding_window_size: 64,
126            use_flash_attention: false,
127            pad_token_id: Some(0),
128            bos_token_id: Some(1),
129            eos_token_id: Some(2),
130            model_type: "command-r".to_string(),
131            torch_dtype: "float32".to_string(),
132            transformers_version: "4.39.0".to_string(),
133        }
134    }
135
136    /// Create Command R base model configuration
137    pub fn command_r() -> Self {
138        Self::default()
139    }
140
141    /// Create Command R+ model configuration
142    pub fn command_r_plus() -> Self {
143        Self {
144            model_name: "command-r-plus".to_string(),
145            vocab_size: 256000,
146            hidden_size: 12288,
147            num_attention_heads: 96,
148            num_key_value_heads: 96,
149            num_hidden_layers: 64,
150            intermediate_size: 33792,
151            max_sequence_length: 131072,
152            rms_norm_eps: 1e-5,
153            rope_theta: 10000.0,
154            rope_scaling_factor: 1.0,
155            attention_dropout: 0.0,
156            hidden_dropout: 0.0,
157            use_bias: false,
158            tie_word_embeddings: false,
159            activation_function: "silu".to_string(),
160            layer_norm_eps: 1e-5,
161            use_logit_bias: false,
162            logit_scale: 1.0,
163            use_sliding_window: false,
164            sliding_window_size: 4096,
165            use_flash_attention: true,
166            pad_token_id: Some(0),
167            bos_token_id: Some(5),
168            eos_token_id: Some(255001),
169            model_type: "command-r-plus".to_string(),
170            torch_dtype: "bfloat16".to_string(),
171            transformers_version: "4.39.0".to_string(),
172        }
173    }
174
175    /// Create Command R 08-2024 model configuration
176    pub fn command_r_08_2024() -> Self {
177        Self {
178            model_name: "command-r-08-2024".to_string(),
179            vocab_size: 256000,
180            hidden_size: 8192,
181            num_attention_heads: 64,
182            num_key_value_heads: 64,
183            num_hidden_layers: 40,
184            intermediate_size: 22528,
185            max_sequence_length: 131072,
186            rms_norm_eps: 1e-5,
187            rope_theta: 10000.0,
188            rope_scaling_factor: 1.0,
189            attention_dropout: 0.0,
190            hidden_dropout: 0.0,
191            use_bias: false,
192            tie_word_embeddings: false,
193            activation_function: "silu".to_string(),
194            layer_norm_eps: 1e-5,
195            use_logit_bias: false,
196            logit_scale: 1.0,
197            use_sliding_window: false,
198            sliding_window_size: 4096,
199            use_flash_attention: true,
200            pad_token_id: Some(0),
201            bos_token_id: Some(5),
202            eos_token_id: Some(255001),
203            model_type: "command-r-08-2024".to_string(),
204            torch_dtype: "bfloat16".to_string(),
205            transformers_version: "4.39.0".to_string(),
206        }
207    }
208
209    /// Create Command R+ 08-2024 model configuration
210    pub fn command_r_plus_08_2024() -> Self {
211        Self {
212            model_name: "command-r-plus-08-2024".to_string(),
213            vocab_size: 256000,
214            hidden_size: 12288,
215            num_attention_heads: 96,
216            num_key_value_heads: 96,
217            num_hidden_layers: 64,
218            intermediate_size: 33792,
219            max_sequence_length: 131072,
220            rms_norm_eps: 1e-5,
221            rope_theta: 10000.0,
222            rope_scaling_factor: 1.0,
223            attention_dropout: 0.0,
224            hidden_dropout: 0.0,
225            use_bias: false,
226            tie_word_embeddings: false,
227            activation_function: "silu".to_string(),
228            layer_norm_eps: 1e-5,
229            use_logit_bias: false,
230            logit_scale: 1.0,
231            use_sliding_window: false,
232            sliding_window_size: 4096,
233            use_flash_attention: true,
234            pad_token_id: Some(0),
235            bos_token_id: Some(5),
236            eos_token_id: Some(255001),
237            model_type: "command-r-plus-08-2024".to_string(),
238            torch_dtype: "bfloat16".to_string(),
239            transformers_version: "4.39.0".to_string(),
240        }
241    }
242
243    /// Get the head dimension
244    pub fn head_dim(&self) -> usize {
245        self.hidden_size / self.num_attention_heads
246    }
247
248    /// Get the key-value head dimension
249    pub fn kv_head_dim(&self) -> usize {
250        self.hidden_size / self.num_key_value_heads
251    }
252
253    /// Check if grouped query attention is used
254    pub fn is_gqa(&self) -> bool {
255        self.num_key_value_heads != self.num_attention_heads
256    }
257
258    /// Get the number of query groups
259    pub fn num_query_groups(&self) -> usize {
260        self.num_attention_heads / self.num_key_value_heads
261    }
262
263    /// Validate configuration
264    pub fn validate(&self) -> Result<(), String> {
265        if self.vocab_size == 0 {
266            return Err("vocab_size must be greater than 0".to_string());
267        }
268        if self.hidden_size == 0 {
269            return Err("hidden_size must be greater than 0".to_string());
270        }
271        if self.num_attention_heads == 0 {
272            return Err("num_attention_heads must be greater than 0".to_string());
273        }
274        if self.num_key_value_heads == 0 {
275            return Err("num_key_value_heads must be greater than 0".to_string());
276        }
277        if self.num_hidden_layers == 0 {
278            return Err("num_hidden_layers must be greater than 0".to_string());
279        }
280        if self.intermediate_size == 0 {
281            return Err("intermediate_size must be greater than 0".to_string());
282        }
283        if self.max_sequence_length == 0 {
284            return Err("max_sequence_length must be greater than 0".to_string());
285        }
286        if self.hidden_size % self.num_attention_heads != 0 {
287            return Err("hidden_size must be divisible by num_attention_heads".to_string());
288        }
289        if self.num_attention_heads % self.num_key_value_heads != 0 {
290            return Err("num_attention_heads must be divisible by num_key_value_heads".to_string());
291        }
292
293        Ok(())
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Base Command R preset has the expected published dimensions and validates.
    #[test]
    fn test_command_r_config() {
        let config = CommandRConfig::command_r();
        assert_eq!(config.model_name, "command-r");
        assert_eq!(config.vocab_size, 256000);
        assert_eq!(config.hidden_size, 8192);
        assert_eq!(config.num_attention_heads, 64);
        assert_eq!(config.num_hidden_layers, 40);
        assert!(config.validate().is_ok());
    }

    /// Command R+ preset has the expected larger dimensions and validates.
    #[test]
    fn test_command_r_plus_config() {
        let config = CommandRConfig::command_r_plus();
        assert_eq!(config.model_name, "command-r-plus");
        assert_eq!(config.vocab_size, 256000);
        assert_eq!(config.hidden_size, 12288);
        assert_eq!(config.num_attention_heads, 96);
        assert_eq!(config.num_hidden_layers, 64);
        assert!(config.validate().is_ok());
    }

    /// Both presets land on the same 128-dim head size.
    #[test]
    fn test_head_dim_calculation() {
        let config = CommandRConfig::command_r();
        assert_eq!(config.head_dim(), 128); // 8192 / 64

        let config_plus = CommandRConfig::command_r_plus();
        assert_eq!(config_plus.head_dim(), 128); // 12288 / 96
    }

    /// GQA is detected only when KV heads differ from query heads.
    #[test]
    fn test_gqa_detection() {
        let config = CommandRConfig::command_r();
        assert!(!config.is_gqa()); // Same number of heads

        let mut config_gqa = config.clone();
        config_gqa.num_key_value_heads = 32;
        assert!(config_gqa.is_gqa());
        assert_eq!(config_gqa.num_query_groups(), 2); // 64 / 32
    }

    /// validate() rejects zero sizes and non-divisible head dimensions.
    #[test]
    fn test_config_validation() {
        let mut config = CommandRConfig::default();
        assert!(config.validate().is_ok());

        config.vocab_size = 0;
        assert!(config.validate().is_err());

        config.vocab_size = 256000;
        config.hidden_size = 100;
        config.num_attention_heads = 64;
        assert!(config.validate().is_err()); // 100 not divisible by 64
    }
}