//! RetNet model configuration (`trustformers_models/retnet/config.rs`).

use serde::{Deserialize, Serialize};
use trustformers_core::errors::invalid_config;
use trustformers_core::traits::Config;
/// RetNet model configuration
/// Reference: "Retentive Network: A Successor to Transformer for Large Language Models"
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetNetConfig {
    // Number of entries in the token vocabulary (embedding rows / output logits).
    pub vocab_size: usize,
    // Model (embedding) width; `validate()` requires it to be divisible by
    // both `num_heads` and `retention_heads`.
    pub hidden_size: usize,
    // Number of stacked decoder layers.
    pub num_hidden_layers: usize,
    // Number of heads used by `head_dim()`.
    pub num_heads: usize,
    // Width of the feed-forward (FFN) inner projection.
    pub intermediate_size: usize,
    // Activation function name for the FFN (e.g. "swish").
    pub hidden_act: String,
    // NOTE(review): these two `*_prob` fields overlap with the RetNet-specific
    // `activation_dropout` / `attention_dropout` below — confirm which ones the
    // model implementation actually reads.
    pub hidden_dropout_prob: f32,
    pub attention_dropout_prob: f32,
    // Maximum sequence length the model is configured for.
    pub max_position_embeddings: usize,
    // Std-dev used for weight initialization.
    pub initializer_range: f32,
    // Epsilon added inside layer normalization for numerical stability.
    pub layer_norm_eps: f32,
    pub pad_token_id: u32,
    pub bos_token_id: u32,
    pub eos_token_id: u32,

    // RetNet-specific parameters
    pub use_bias: bool,                // Use bias in linear layers
    pub use_glu: bool,                 // Use GLU activation in FFN
    pub use_norm_bias: bool,           // Use bias in normalization layers
    pub deepnorm: bool,                // Use DeepNorm for better scaling
    pub dropout_module: String,        // Dropout module type
    pub activation_dropout: f32,       // Dropout for activations
    pub attention_dropout: f32,        // Dropout for retention mechanism
    pub retention_heads: usize,        // Number of retention heads
    pub value_factor: f32,             // Value scaling factor
    pub gate_fn: String,               // Gate function type
    pub tensor_parallel_degree: usize, // Tensor parallelism degree
    pub sequence_parallel: bool,       // Enable sequence parallelism
    pub fuse_norm: bool,               // Fuse normalization operations
    pub no_output_layer: bool,         // Remove output layer
    pub layernorm_embedding: bool,     // Apply layernorm to embeddings
    pub chunking: bool,                // Enable chunked processing
    pub chunk_size: usize,             // Chunk size for processing
}
43
impl Default for RetNetConfig {
    /// Defaults matching the RetNet-Small layout used by `retnet_small()`
    /// (2048 hidden, 24 layers, 16 heads, 2048-token context).
    fn default() -> Self {
        Self {
            vocab_size: 32000,
            hidden_size: 2048,
            num_hidden_layers: 24,
            num_heads: 16,
            intermediate_size: 8192,
            hidden_act: "swish".to_string(),
            hidden_dropout_prob: 0.0,
            attention_dropout_prob: 0.0,
            max_position_embeddings: 2048,
            initializer_range: 0.02,
            layer_norm_eps: 1e-6,
            pad_token_id: 0,
            bos_token_id: 1,
            eos_token_id: 2,

            // RetNet defaults
            use_bias: false,
            use_glu: true,
            use_norm_bias: false,
            deepnorm: true,
            dropout_module: "dropout".to_string(),
            activation_dropout: 0.0,
            attention_dropout: 0.0,
            // Matches `num_heads` so head_dim == retention_head_dim by default.
            retention_heads: 16,
            value_factor: 2.0,
            gate_fn: "swish".to_string(),
            tensor_parallel_degree: 1,
            sequence_parallel: false,
            fuse_norm: false,
            no_output_layer: false,
            layernorm_embedding: false,
            // Chunked processing is opt-in; chunk_size is only consulted
            // when `chunking` is enabled (see `uses_chunking`).
            chunking: false,
            chunk_size: 512,
        }
    }
}
83
84impl Config for RetNetConfig {
85    fn validate(&self) -> trustformers_core::errors::Result<()> {
86        if self.hidden_size % self.num_heads != 0 {
87            return Err(invalid_config(
88                "config_field",
89                "hidden_size must be divisible by num_heads".to_string(),
90            ));
91        }
92
93        if self.hidden_size % self.retention_heads != 0 {
94            return Err(invalid_config(
95                "config_field",
96                "hidden_size must be divisible by retention_heads".to_string(),
97            ));
98        }
99
100        if self.chunk_size > self.max_position_embeddings {
101            return Err(invalid_config(
102                "config_field",
103                "chunk_size should not exceed max_position_embeddings".to_string(),
104            ));
105        }
106
107        Ok(())
108    }
109
110    fn architecture(&self) -> &'static str {
111        "RetNet"
112    }
113}
114
115impl RetNetConfig {
116    /// RetNet-Small configuration (1.3B parameters)
117    pub fn retnet_small() -> Self {
118        Self {
119            hidden_size: 2048,
120            num_hidden_layers: 24,
121            num_heads: 16,
122            intermediate_size: 8192,
123            retention_heads: 16,
124            max_position_embeddings: 2048,
125            ..Self::default()
126        }
127    }
128
129    /// RetNet-Medium configuration (2.7B parameters)
130    pub fn retnet_medium() -> Self {
131        Self {
132            hidden_size: 2560,
133            num_hidden_layers: 32,
134            num_heads: 20,
135            intermediate_size: 10240,
136            retention_heads: 20,
137            max_position_embeddings: 2048,
138            ..Self::default()
139        }
140    }
141
142    /// RetNet-Large configuration (6.7B parameters)
143    pub fn retnet_large() -> Self {
144        Self {
145            hidden_size: 4096,
146            num_hidden_layers: 32,
147            num_heads: 32,
148            intermediate_size: 16384,
149            retention_heads: 32,
150            max_position_embeddings: 2048,
151            ..Self::default()
152        }
153    }
154
155    /// RetNet-XL configuration (13B parameters)
156    pub fn retnet_xl() -> Self {
157        Self {
158            hidden_size: 5120,
159            num_hidden_layers: 40,
160            num_heads: 40,
161            intermediate_size: 20480,
162            retention_heads: 40,
163            max_position_embeddings: 2048,
164            deepnorm: true,
165            ..Self::default()
166        }
167    }
168
169    /// Long-context RetNet for extended sequences
170    pub fn retnet_long() -> Self {
171        Self {
172            max_position_embeddings: 8192,
173            chunking: true,
174            chunk_size: 1024,
175            sequence_parallel: true,
176            ..Self::retnet_medium()
177        }
178    }
179
180    /// Get the head dimension
181    pub fn head_dim(&self) -> usize {
182        self.hidden_size / self.num_heads
183    }
184
185    /// Get the retention head dimension
186    pub fn retention_head_dim(&self) -> usize {
187        self.hidden_size / self.retention_heads
188    }
189
190    /// Get the effective head dimension for retention
191    pub fn retention_dim(&self) -> usize {
192        (self.hidden_size as f32 / self.value_factor) as usize
193    }
194
195    /// Check if using efficient chunked processing
196    pub fn uses_chunking(&self) -> bool {
197        self.chunking && self.chunk_size > 0
198    }
199
200    /// Get memory complexity advantage over attention
201    pub fn memory_advantage(&self) -> f32 {
202        let seq_len = self.max_position_embeddings as f32;
203        let attention_memory = seq_len * seq_len;
204        let retnet_memory = seq_len; // Linear complexity
205        attention_memory / retnet_memory
206    }
207
208    /// Check if configuration supports very long sequences
209    pub fn supports_long_sequences(&self) -> bool {
210        self.max_position_embeddings >= 4096 || self.uses_chunking()
211    }
212
213    /// Get the deepnorm scaling factor
214    pub fn deepnorm_alpha(&self) -> f32 {
215        // DeepNorm scaling factor based on number of layers
216        (2.0 * self.num_hidden_layers as f32).powf(0.25)
217    }
218
219    /// Get the deepnorm beta factor
220    pub fn deepnorm_beta(&self) -> f32 {
221        // DeepNorm initialization factor
222        (8.0 * self.num_hidden_layers as f32).powf(-0.25)
223    }
224}