// trustformers_core/checkpoint/mapping.rs
1//! Weight mapping rules for converting between different framework conventions
2
3use anyhow::Result;
4use regex::Regex;
5use serde::{Deserialize, Serialize};
6
/// Mapping rules for converting weight names between frameworks.
///
/// Built via [`WeightMapping::new`] from a [`ModelType`]; rules are tried
/// in order and the first matching rule wins (see `pytorch_to_tensorflow`).
#[derive(Debug, Clone)]
pub struct WeightMapping {
    // Ordered rule list; earlier rules take precedence over later ones.
    rules: Vec<WeightMappingRule>,
    // Kept for introspection/debugging; not read after construction.
    #[allow(dead_code)]
    model_type: ModelType,
}
14
/// Model architectures with dedicated name-mapping rule sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    /// BERT encoder (PyTorch dotted names <-> TF slash-separated names).
    BERT,
    /// GPT-2 decoder-only transformer.
    GPT2,
    /// T5 encoder-decoder.
    T5,
    /// LLaMA decoder-only transformer (rules are identity renames).
    LLaMA,
    /// No architecture-specific rules; only the default string
    /// conversions apply.
    Generic,
}
23
/// Individual mapping rule: a regex matched against the source weight name,
/// a replacement template (may reference capture groups, e.g. `$1`), and an
/// optional tensor-level transform to apply alongside the rename.
#[derive(Debug, Clone)]
pub struct WeightMappingRule {
    /// Pattern tested against the full source weight name.
    pub pattern: Regex,
    /// Replacement string handed to `Regex::replace`.
    pub replacement: String,
    /// Tensor transform (e.g. transpose) required by the target format,
    /// or `None` for a pure rename.
    pub transform: Option<WeightTransform>,
}
31
/// Transformations that may be needed when converting weights between
/// frameworks, beyond a simple rename.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightTransform {
    /// No transformation
    Identity,
    /// Transpose specific dimensions (axis permutation, e.g. `[1, 0]`
    /// for a 2-D kernel)
    Transpose(Vec<usize>),
    /// Reshape to new dimensions
    Reshape(Vec<isize>), // -1 for inferred dimension
    /// Split into multiple tensors along `axis` with the given chunk `sizes`
    Split { axis: usize, sizes: Vec<usize> },
    /// Merge multiple tensors by concatenating along `axis`
    Merge { axis: usize },
    /// Convert convolution weights between layout conventions
    ConvFormat { from: ConvFormat, to: ConvFormat },
}
48
/// Memory-layout conventions for convolution weights.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum ConvFormat {
    NCHW, // PyTorch default
    NHWC, // TensorFlow default
}
54
55impl WeightMapping {
56    pub fn new(model_type: ModelType) -> Self {
57        let rules = match model_type {
58            ModelType::BERT => Self::bert_rules(),
59            ModelType::GPT2 => Self::gpt2_rules(),
60            ModelType::T5 => Self::t5_rules(),
61            ModelType::LLaMA => Self::llama_rules(),
62            ModelType::Generic => Vec::new(),
63        };
64
65        Self { rules, model_type }
66    }
67
68    /// Map PyTorch weight name to TensorFlow format
69    pub fn pytorch_to_tensorflow(&self, name: &str) -> Result<(String, Option<WeightTransform>)> {
70        for rule in &self.rules {
71            if rule.pattern.is_match(name) {
72                let new_name = rule.pattern.replace(name, &rule.replacement).to_string();
73                return Ok((new_name, rule.transform.clone()));
74            }
75        }
76
77        // Default mapping if no rule matches
78        Ok((self.default_pytorch_to_tf(name), None))
79    }
80
81    /// Map TensorFlow weight name to PyTorch format
82    pub fn tensorflow_to_pytorch(&self, name: &str) -> Result<(String, Option<WeightTransform>)> {
83        // Reverse mapping - this is simplified, in practice we'd need reverse rules
84        Ok((self.default_tf_to_pytorch(name), None))
85    }
86
87    /// Map JAX weight name to PyTorch format
88    pub fn jax_to_pytorch(&self, name: &str) -> Result<(String, Option<WeightTransform>)> {
89        // JAX uses hierarchical names with dots
90        let pytorch_name = name.replace("params.", "").replace(".", "_");
91        Ok((pytorch_name, None))
92    }
93
94    /// Map PyTorch weight name to JAX format
95    pub fn pytorch_to_jax(&self, name: &str) -> Result<(String, Option<WeightTransform>)> {
96        // Convert underscores to dots for JAX hierarchical structure
97        let parts: Vec<&str> = name.split('_').collect();
98        let jax_name = format!("params.{}", parts.join("."));
99        Ok((jax_name, None))
100    }
101
102    fn bert_rules() -> Vec<WeightMappingRule> {
103        vec![
104            // Embeddings
105            WeightMappingRule {
106                pattern: Regex::new(r"^embeddings\.word_embeddings\.weight$").expect("valid regex"),
107                replacement: "bert/embeddings/word_embeddings".to_string(),
108                transform: None,
109            },
110            WeightMappingRule {
111                pattern: Regex::new(r"^embeddings\.position_embeddings\.weight$")
112                    .expect("valid regex"),
113                replacement: "bert/embeddings/position_embeddings".to_string(),
114                transform: None,
115            },
116            WeightMappingRule {
117                pattern: Regex::new(r"^embeddings\.token_type_embeddings\.weight$")
118                    .expect("valid regex"),
119                replacement: "bert/embeddings/token_type_embeddings".to_string(),
120                transform: None,
121            },
122            // Layer normalization
123            WeightMappingRule {
124                pattern: Regex::new(r"^embeddings\.LayerNorm\.weight$").expect("valid regex"),
125                replacement: "bert/embeddings/LayerNorm/gamma".to_string(),
126                transform: None,
127            },
128            WeightMappingRule {
129                pattern: Regex::new(r"^embeddings\.LayerNorm\.bias$").expect("valid regex"),
130                replacement: "bert/embeddings/LayerNorm/beta".to_string(),
131                transform: None,
132            },
133            // Encoder layers
134            WeightMappingRule {
135                pattern: Regex::new(r"^encoder\.layer\.(\d+)\.attention\.self\.query\.weight$")
136                    .expect("valid regex"),
137                replacement: "bert/encoder/layer_$1/attention/self/query/kernel".to_string(),
138                transform: Some(WeightTransform::Transpose(vec![1, 0])),
139            },
140            WeightMappingRule {
141                pattern: Regex::new(r"^encoder\.layer\.(\d+)\.attention\.self\.key\.weight$")
142                    .expect("valid regex"),
143                replacement: "bert/encoder/layer_$1/attention/self/key/kernel".to_string(),
144                transform: Some(WeightTransform::Transpose(vec![1, 0])),
145            },
146            WeightMappingRule {
147                pattern: Regex::new(r"^encoder\.layer\.(\d+)\.attention\.self\.value\.weight$")
148                    .expect("valid regex"),
149                replacement: "bert/encoder/layer_$1/attention/self/value/kernel".to_string(),
150                transform: Some(WeightTransform::Transpose(vec![1, 0])),
151            },
152            // Output projection
153            WeightMappingRule {
154                pattern: Regex::new(r"^encoder\.layer\.(\d+)\.attention\.output\.dense\.weight$")
155                    .expect("valid regex"),
156                replacement: "bert/encoder/layer_$1/attention/output/dense/kernel".to_string(),
157                transform: Some(WeightTransform::Transpose(vec![1, 0])),
158            },
159            // FFN layers
160            WeightMappingRule {
161                pattern: Regex::new(r"^encoder\.layer\.(\d+)\.intermediate\.dense\.weight$")
162                    .expect("valid regex"),
163                replacement: "bert/encoder/layer_$1/intermediate/dense/kernel".to_string(),
164                transform: Some(WeightTransform::Transpose(vec![1, 0])),
165            },
166            WeightMappingRule {
167                pattern: Regex::new(r"^encoder\.layer\.(\d+)\.output\.dense\.weight$")
168                    .expect("regex pattern must be valid"),
169                replacement: "bert/encoder/layer_$1/output/dense/kernel".to_string(),
170                transform: Some(WeightTransform::Transpose(vec![1, 0])),
171            },
172        ]
173    }
174
175    fn gpt2_rules() -> Vec<WeightMappingRule> {
176        vec![
177            // Token embeddings
178            WeightMappingRule {
179                pattern: Regex::new(r"^wte\.weight$").expect("valid regex"),
180                replacement: "model/wte".to_string(),
181                transform: None,
182            },
183            // Position embeddings
184            WeightMappingRule {
185                pattern: Regex::new(r"^wpe\.weight$").expect("valid regex"),
186                replacement: "model/wpe".to_string(),
187                transform: None,
188            },
189            // Transformer blocks
190            WeightMappingRule {
191                pattern: Regex::new(r"^h\.(\d+)\.attn\.c_attn\.weight$").expect("valid regex"),
192                replacement: "model/h$1/attn/c_attn/kernel".to_string(),
193                transform: Some(WeightTransform::Transpose(vec![1, 0])),
194            },
195            WeightMappingRule {
196                pattern: Regex::new(r"^h\.(\d+)\.attn\.c_proj\.weight$").expect("valid regex"),
197                replacement: "model/h$1/attn/c_proj/kernel".to_string(),
198                transform: Some(WeightTransform::Transpose(vec![1, 0])),
199            },
200            WeightMappingRule {
201                pattern: Regex::new(r"^h\.(\d+)\.mlp\.c_fc\.weight$").expect("valid regex"),
202                replacement: "model/h$1/mlp/c_fc/kernel".to_string(),
203                transform: Some(WeightTransform::Transpose(vec![1, 0])),
204            },
205            WeightMappingRule {
206                pattern: Regex::new(r"^h\.(\d+)\.mlp\.c_proj\.weight$").expect("valid regex"),
207                replacement: "model/h$1/mlp/c_proj/kernel".to_string(),
208                transform: Some(WeightTransform::Transpose(vec![1, 0])),
209            },
210            // Layer norms
211            WeightMappingRule {
212                pattern: Regex::new(r"^h\.(\d+)\.ln_1\.weight$").expect("valid regex"),
213                replacement: "model/h$1/ln_1/g".to_string(),
214                transform: None,
215            },
216            WeightMappingRule {
217                pattern: Regex::new(r"^h\.(\d+)\.ln_2\.weight$").expect("valid regex"),
218                replacement: "model/h$1/ln_2/g".to_string(),
219                transform: None,
220            },
221            WeightMappingRule {
222                pattern: Regex::new(r"^ln_f\.weight$").expect("valid regex"),
223                replacement: "model/ln_f/g".to_string(),
224                transform: None,
225            },
226        ]
227    }
228
229    fn t5_rules() -> Vec<WeightMappingRule> {
230        vec![
231            // Shared embeddings
232            WeightMappingRule {
233                pattern: Regex::new(r"^shared\.weight$").expect("valid regex"),
234                replacement: "shared/embedding".to_string(),
235                transform: None,
236            },
237            // Encoder blocks
238            WeightMappingRule {
239                pattern: Regex::new(r"^encoder\.block\.(\d+)\.layer\.0\.SelfAttention\.q\.weight$")
240                    .expect("valid regex"),
241                replacement: "encoder/block_$1/layer_0/SelfAttention/q".to_string(),
242                transform: Some(WeightTransform::Transpose(vec![1, 0])),
243            },
244            WeightMappingRule {
245                pattern: Regex::new(r"^encoder\.block\.(\d+)\.layer\.0\.SelfAttention\.k\.weight$")
246                    .expect("valid regex"),
247                replacement: "encoder/block_$1/layer_0/SelfAttention/k".to_string(),
248                transform: Some(WeightTransform::Transpose(vec![1, 0])),
249            },
250            WeightMappingRule {
251                pattern: Regex::new(r"^encoder\.block\.(\d+)\.layer\.0\.SelfAttention\.v\.weight$")
252                    .expect("valid regex"),
253                replacement: "encoder/block_$1/layer_0/SelfAttention/v".to_string(),
254                transform: Some(WeightTransform::Transpose(vec![1, 0])),
255            },
256            WeightMappingRule {
257                pattern: Regex::new(r"^encoder\.block\.(\d+)\.layer\.0\.SelfAttention\.o\.weight$")
258                    .expect("valid regex"),
259                replacement: "encoder/block_$1/layer_0/SelfAttention/o".to_string(),
260                transform: Some(WeightTransform::Transpose(vec![1, 0])),
261            },
262            // Decoder blocks
263            WeightMappingRule {
264                pattern: Regex::new(r"^decoder\.block\.(\d+)\.layer\.0\.SelfAttention\.q\.weight$")
265                    .expect("valid regex"),
266                replacement: "decoder/block_$1/layer_0/SelfAttention/q".to_string(),
267                transform: Some(WeightTransform::Transpose(vec![1, 0])),
268            },
269            // Add more T5 specific rules...
270        ]
271    }
272
273    fn llama_rules() -> Vec<WeightMappingRule> {
274        vec![
275            // Token embeddings
276            WeightMappingRule {
277                pattern: Regex::new(r"^model\.embed_tokens\.weight$").expect("valid regex"),
278                replacement: "model.embed_tokens.weight".to_string(),
279                transform: None,
280            },
281            // Layers
282            WeightMappingRule {
283                pattern: Regex::new(r"^model\.layers\.(\d+)\.self_attn\.q_proj\.weight$")
284                    .expect("regex pattern must be valid"),
285                replacement: "model.layers.$1.self_attn.q_proj.weight".to_string(),
286                transform: None,
287            },
288            WeightMappingRule {
289                pattern: Regex::new(r"^model\.layers\.(\d+)\.self_attn\.k_proj\.weight$")
290                    .expect("regex pattern must be valid"),
291                replacement: "model.layers.$1.self_attn.k_proj.weight".to_string(),
292                transform: None,
293            },
294            WeightMappingRule {
295                pattern: Regex::new(r"^model\.layers\.(\d+)\.self_attn\.v_proj\.weight$")
296                    .expect("regex pattern must be valid"),
297                replacement: "model.layers.$1.self_attn.v_proj.weight".to_string(),
298                transform: None,
299            },
300            WeightMappingRule {
301                pattern: Regex::new(r"^model\.layers\.(\d+)\.self_attn\.o_proj\.weight$")
302                    .expect("regex pattern must be valid"),
303                replacement: "model.layers.$1.self_attn.o_proj.weight".to_string(),
304                transform: None,
305            },
306            // MLP
307            WeightMappingRule {
308                pattern: Regex::new(r"^model\.layers\.(\d+)\.mlp\.gate_proj\.weight$")
309                    .expect("regex pattern must be valid"),
310                replacement: "model.layers.$1.mlp.gate_proj.weight".to_string(),
311                transform: None,
312            },
313            WeightMappingRule {
314                pattern: Regex::new(r"^model\.layers\.(\d+)\.mlp\.up_proj\.weight$")
315                    .expect("regex pattern must be valid"),
316                replacement: "model.layers.$1.mlp.up_proj.weight".to_string(),
317                transform: None,
318            },
319            WeightMappingRule {
320                pattern: Regex::new(r"^model\.layers\.(\d+)\.mlp\.down_proj\.weight$")
321                    .expect("regex pattern must be valid"),
322                replacement: "model.layers.$1.mlp.down_proj.weight".to_string(),
323                transform: None,
324            },
325            // RMS Norm
326            WeightMappingRule {
327                pattern: Regex::new(r"^model\.layers\.(\d+)\.input_layernorm\.weight$")
328                    .expect("regex pattern must be valid"),
329                replacement: "model.layers.$1.input_layernorm.weight".to_string(),
330                transform: None,
331            },
332            WeightMappingRule {
333                pattern: Regex::new(r"^model\.layers\.(\d+)\.post_attention_layernorm\.weight$")
334                    .expect("valid regex"),
335                replacement: "model.layers.$1.post_attention_layernorm.weight".to_string(),
336                transform: None,
337            },
338        ]
339    }
340
341    fn default_pytorch_to_tf(&self, name: &str) -> String {
342        // Default conversion: replace . with / and weight with kernel
343        name.replace('.', "/")
344            .replace("weight", "kernel")
345            .replace("LayerNorm", "layer_norm")
346    }
347
348    fn default_tf_to_pytorch(&self, name: &str) -> String {
349        // Default reverse conversion
350        name.replace('/', ".")
351            .replace("kernel", "weight")
352            .replace("layer_norm", "LayerNorm")
353    }
354}
355
/// Layer-level mapping for structural differences between checkpoints
/// (e.g. several source tensors merging into one target tensor).
#[derive(Debug, Clone)]
pub struct LayerMapping {
    /// Layer names as they appear in the source checkpoint.
    pub source_layers: Vec<String>,
    /// Corresponding layer names in the target checkpoint.
    pub target_layers: Vec<String>,
    /// How to combine multiple source tensors when the mapping is
    /// many-to-one; `None` for a plain one-to-one mapping.
    pub merge_strategy: Option<MergeStrategy>,
}
363
/// How multiple source tensors are combined into a single target tensor.
#[derive(Debug, Clone)]
pub enum MergeStrategy {
    /// Concatenate along a specific axis
    Concatenate { axis: usize },
    /// Add tensors element-wise
    Add,
    /// Average tensors
    Average,
    /// Custom function, identified by name
    Custom(String),
}
375
376impl LayerMapping {
377    pub fn new(source: Vec<String>, target: Vec<String>) -> Self {
378        Self {
379            source_layers: source,
380            target_layers: target,
381            merge_strategy: None,
382        }
383    }
384
385    pub fn with_merge_strategy(mut self, strategy: MergeStrategy) -> Self {
386        self.merge_strategy = Some(strategy);
387        self
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    // BERT: attention query weight must be renamed with the layer index
    // substituted and flagged for a 2-D transpose (TF kernel layout).
    #[test]
    fn test_bert_mapping() {
        let mapping = WeightMapping::new(ModelType::BERT);

        let (tf_name, transform) = mapping
            .pytorch_to_tensorflow("encoder.layer.0.attention.self.query.weight")
            .expect("operation failed in test");

        assert_eq!(tf_name, "bert/encoder/layer_0/attention/self/query/kernel");
        assert!(matches!(transform, Some(WeightTransform::Transpose(_))));
    }

    // GPT-2: embeddings rename without a transform; attention kernels
    // rename with a transpose.
    #[test]
    fn test_gpt2_mapping() {
        let mapping = WeightMapping::new(ModelType::GPT2);

        let (tf_name, _) =
            mapping.pytorch_to_tensorflow("wte.weight").expect("tensor operation failed");
        assert_eq!(tf_name, "model/wte");

        let (tf_name, transform) = mapping
            .pytorch_to_tensorflow("h.0.attn.c_attn.weight")
            .expect("tensor operation failed");
        assert_eq!(tf_name, "model/h0/attn/c_attn/kernel");
        assert!(matches!(transform, Some(WeightTransform::Transpose(_))));
    }

    // JAX round-trip: underscores become a dot hierarchy under `params.`,
    // and the inverse mapping strips the prefix and restores underscores.
    #[test]
    fn test_jax_mapping() {
        let mapping = WeightMapping::new(ModelType::Generic);

        let (jax_name, _) = mapping
            .pytorch_to_jax("encoder_layer_0_attention_query_weight")
            .expect("operation failed in test");
        assert_eq!(jax_name, "params.encoder.layer.0.attention.query.weight");

        let (pytorch_name, _) = mapping
            .jax_to_pytorch("params.encoder.layer.0.attention.query.weight")
            .expect("operation failed in test");
        assert_eq!(pytorch_name, "encoder_layer_0_attention_query_weight");
    }
}