1use anyhow::Result;
4use regex::Regex;
5use serde::{Deserialize, Serialize};
6
/// Translates weight/parameter names between framework naming
/// conventions (PyTorch, TensorFlow, JAX) for a given model
/// architecture.
#[derive(Debug, Clone)]
pub struct WeightMapping {
    // Ordered regex rename rules; the first matching rule wins.
    rules: Vec<WeightMappingRule>,
    // Retained for debugging/inspection; not read after construction.
    #[allow(dead_code)]
    model_type: ModelType,
}
14
/// Supported model architectures. Selects which built-in rename rule
/// table `WeightMapping::new` installs; `Generic` installs none and
/// relies on the default separator/keyword substitutions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    BERT,
    GPT2,
    T5,
    LLaMA,
    Generic,
}
23
/// A single regex-based rename rule: names matching `pattern` are
/// rewritten with `replacement` (capture groups such as `$1` are
/// substituted), optionally paired with a tensor-data transform.
#[derive(Debug, Clone)]
pub struct WeightMappingRule {
    pub pattern: Regex,
    pub replacement: String,
    pub transform: Option<WeightTransform>,
}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub enum WeightTransform {
35 Identity,
37 Transpose(Vec<usize>),
39 Reshape(Vec<isize>), Split { axis: usize, sizes: Vec<usize> },
43 Merge { axis: usize },
45 ConvFormat { from: ConvFormat, to: ConvFormat },
47}
48
49#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
50pub enum ConvFormat {
51 NCHW, NHWC, }
54
55impl WeightMapping {
56 pub fn new(model_type: ModelType) -> Self {
57 let rules = match model_type {
58 ModelType::BERT => Self::bert_rules(),
59 ModelType::GPT2 => Self::gpt2_rules(),
60 ModelType::T5 => Self::t5_rules(),
61 ModelType::LLaMA => Self::llama_rules(),
62 ModelType::Generic => Vec::new(),
63 };
64
65 Self { rules, model_type }
66 }
67
68 pub fn pytorch_to_tensorflow(&self, name: &str) -> Result<(String, Option<WeightTransform>)> {
70 for rule in &self.rules {
71 if rule.pattern.is_match(name) {
72 let new_name = rule.pattern.replace(name, &rule.replacement).to_string();
73 return Ok((new_name, rule.transform.clone()));
74 }
75 }
76
77 Ok((self.default_pytorch_to_tf(name), None))
79 }
80
81 pub fn tensorflow_to_pytorch(&self, name: &str) -> Result<(String, Option<WeightTransform>)> {
83 Ok((self.default_tf_to_pytorch(name), None))
85 }
86
87 pub fn jax_to_pytorch(&self, name: &str) -> Result<(String, Option<WeightTransform>)> {
89 let pytorch_name = name.replace("params.", "").replace(".", "_");
91 Ok((pytorch_name, None))
92 }
93
94 pub fn pytorch_to_jax(&self, name: &str) -> Result<(String, Option<WeightTransform>)> {
96 let parts: Vec<&str> = name.split('_').collect();
98 let jax_name = format!("params.{}", parts.join("."));
99 Ok((jax_name, None))
100 }
101
102 fn bert_rules() -> Vec<WeightMappingRule> {
103 vec![
104 WeightMappingRule {
106 pattern: Regex::new(r"^embeddings\.word_embeddings\.weight$").expect("valid regex"),
107 replacement: "bert/embeddings/word_embeddings".to_string(),
108 transform: None,
109 },
110 WeightMappingRule {
111 pattern: Regex::new(r"^embeddings\.position_embeddings\.weight$")
112 .expect("valid regex"),
113 replacement: "bert/embeddings/position_embeddings".to_string(),
114 transform: None,
115 },
116 WeightMappingRule {
117 pattern: Regex::new(r"^embeddings\.token_type_embeddings\.weight$")
118 .expect("valid regex"),
119 replacement: "bert/embeddings/token_type_embeddings".to_string(),
120 transform: None,
121 },
122 WeightMappingRule {
124 pattern: Regex::new(r"^embeddings\.LayerNorm\.weight$").expect("valid regex"),
125 replacement: "bert/embeddings/LayerNorm/gamma".to_string(),
126 transform: None,
127 },
128 WeightMappingRule {
129 pattern: Regex::new(r"^embeddings\.LayerNorm\.bias$").expect("valid regex"),
130 replacement: "bert/embeddings/LayerNorm/beta".to_string(),
131 transform: None,
132 },
133 WeightMappingRule {
135 pattern: Regex::new(r"^encoder\.layer\.(\d+)\.attention\.self\.query\.weight$")
136 .expect("valid regex"),
137 replacement: "bert/encoder/layer_$1/attention/self/query/kernel".to_string(),
138 transform: Some(WeightTransform::Transpose(vec![1, 0])),
139 },
140 WeightMappingRule {
141 pattern: Regex::new(r"^encoder\.layer\.(\d+)\.attention\.self\.key\.weight$")
142 .expect("valid regex"),
143 replacement: "bert/encoder/layer_$1/attention/self/key/kernel".to_string(),
144 transform: Some(WeightTransform::Transpose(vec![1, 0])),
145 },
146 WeightMappingRule {
147 pattern: Regex::new(r"^encoder\.layer\.(\d+)\.attention\.self\.value\.weight$")
148 .expect("valid regex"),
149 replacement: "bert/encoder/layer_$1/attention/self/value/kernel".to_string(),
150 transform: Some(WeightTransform::Transpose(vec![1, 0])),
151 },
152 WeightMappingRule {
154 pattern: Regex::new(r"^encoder\.layer\.(\d+)\.attention\.output\.dense\.weight$")
155 .expect("valid regex"),
156 replacement: "bert/encoder/layer_$1/attention/output/dense/kernel".to_string(),
157 transform: Some(WeightTransform::Transpose(vec![1, 0])),
158 },
159 WeightMappingRule {
161 pattern: Regex::new(r"^encoder\.layer\.(\d+)\.intermediate\.dense\.weight$")
162 .expect("valid regex"),
163 replacement: "bert/encoder/layer_$1/intermediate/dense/kernel".to_string(),
164 transform: Some(WeightTransform::Transpose(vec![1, 0])),
165 },
166 WeightMappingRule {
167 pattern: Regex::new(r"^encoder\.layer\.(\d+)\.output\.dense\.weight$")
168 .expect("regex pattern must be valid"),
169 replacement: "bert/encoder/layer_$1/output/dense/kernel".to_string(),
170 transform: Some(WeightTransform::Transpose(vec![1, 0])),
171 },
172 ]
173 }
174
175 fn gpt2_rules() -> Vec<WeightMappingRule> {
176 vec![
177 WeightMappingRule {
179 pattern: Regex::new(r"^wte\.weight$").expect("valid regex"),
180 replacement: "model/wte".to_string(),
181 transform: None,
182 },
183 WeightMappingRule {
185 pattern: Regex::new(r"^wpe\.weight$").expect("valid regex"),
186 replacement: "model/wpe".to_string(),
187 transform: None,
188 },
189 WeightMappingRule {
191 pattern: Regex::new(r"^h\.(\d+)\.attn\.c_attn\.weight$").expect("valid regex"),
192 replacement: "model/h$1/attn/c_attn/kernel".to_string(),
193 transform: Some(WeightTransform::Transpose(vec![1, 0])),
194 },
195 WeightMappingRule {
196 pattern: Regex::new(r"^h\.(\d+)\.attn\.c_proj\.weight$").expect("valid regex"),
197 replacement: "model/h$1/attn/c_proj/kernel".to_string(),
198 transform: Some(WeightTransform::Transpose(vec![1, 0])),
199 },
200 WeightMappingRule {
201 pattern: Regex::new(r"^h\.(\d+)\.mlp\.c_fc\.weight$").expect("valid regex"),
202 replacement: "model/h$1/mlp/c_fc/kernel".to_string(),
203 transform: Some(WeightTransform::Transpose(vec![1, 0])),
204 },
205 WeightMappingRule {
206 pattern: Regex::new(r"^h\.(\d+)\.mlp\.c_proj\.weight$").expect("valid regex"),
207 replacement: "model/h$1/mlp/c_proj/kernel".to_string(),
208 transform: Some(WeightTransform::Transpose(vec![1, 0])),
209 },
210 WeightMappingRule {
212 pattern: Regex::new(r"^h\.(\d+)\.ln_1\.weight$").expect("valid regex"),
213 replacement: "model/h$1/ln_1/g".to_string(),
214 transform: None,
215 },
216 WeightMappingRule {
217 pattern: Regex::new(r"^h\.(\d+)\.ln_2\.weight$").expect("valid regex"),
218 replacement: "model/h$1/ln_2/g".to_string(),
219 transform: None,
220 },
221 WeightMappingRule {
222 pattern: Regex::new(r"^ln_f\.weight$").expect("valid regex"),
223 replacement: "model/ln_f/g".to_string(),
224 transform: None,
225 },
226 ]
227 }
228
229 fn t5_rules() -> Vec<WeightMappingRule> {
230 vec![
231 WeightMappingRule {
233 pattern: Regex::new(r"^shared\.weight$").expect("valid regex"),
234 replacement: "shared/embedding".to_string(),
235 transform: None,
236 },
237 WeightMappingRule {
239 pattern: Regex::new(r"^encoder\.block\.(\d+)\.layer\.0\.SelfAttention\.q\.weight$")
240 .expect("valid regex"),
241 replacement: "encoder/block_$1/layer_0/SelfAttention/q".to_string(),
242 transform: Some(WeightTransform::Transpose(vec![1, 0])),
243 },
244 WeightMappingRule {
245 pattern: Regex::new(r"^encoder\.block\.(\d+)\.layer\.0\.SelfAttention\.k\.weight$")
246 .expect("valid regex"),
247 replacement: "encoder/block_$1/layer_0/SelfAttention/k".to_string(),
248 transform: Some(WeightTransform::Transpose(vec![1, 0])),
249 },
250 WeightMappingRule {
251 pattern: Regex::new(r"^encoder\.block\.(\d+)\.layer\.0\.SelfAttention\.v\.weight$")
252 .expect("valid regex"),
253 replacement: "encoder/block_$1/layer_0/SelfAttention/v".to_string(),
254 transform: Some(WeightTransform::Transpose(vec![1, 0])),
255 },
256 WeightMappingRule {
257 pattern: Regex::new(r"^encoder\.block\.(\d+)\.layer\.0\.SelfAttention\.o\.weight$")
258 .expect("valid regex"),
259 replacement: "encoder/block_$1/layer_0/SelfAttention/o".to_string(),
260 transform: Some(WeightTransform::Transpose(vec![1, 0])),
261 },
262 WeightMappingRule {
264 pattern: Regex::new(r"^decoder\.block\.(\d+)\.layer\.0\.SelfAttention\.q\.weight$")
265 .expect("valid regex"),
266 replacement: "decoder/block_$1/layer_0/SelfAttention/q".to_string(),
267 transform: Some(WeightTransform::Transpose(vec![1, 0])),
268 },
269 ]
271 }
272
273 fn llama_rules() -> Vec<WeightMappingRule> {
274 vec![
275 WeightMappingRule {
277 pattern: Regex::new(r"^model\.embed_tokens\.weight$").expect("valid regex"),
278 replacement: "model.embed_tokens.weight".to_string(),
279 transform: None,
280 },
281 WeightMappingRule {
283 pattern: Regex::new(r"^model\.layers\.(\d+)\.self_attn\.q_proj\.weight$")
284 .expect("regex pattern must be valid"),
285 replacement: "model.layers.$1.self_attn.q_proj.weight".to_string(),
286 transform: None,
287 },
288 WeightMappingRule {
289 pattern: Regex::new(r"^model\.layers\.(\d+)\.self_attn\.k_proj\.weight$")
290 .expect("regex pattern must be valid"),
291 replacement: "model.layers.$1.self_attn.k_proj.weight".to_string(),
292 transform: None,
293 },
294 WeightMappingRule {
295 pattern: Regex::new(r"^model\.layers\.(\d+)\.self_attn\.v_proj\.weight$")
296 .expect("regex pattern must be valid"),
297 replacement: "model.layers.$1.self_attn.v_proj.weight".to_string(),
298 transform: None,
299 },
300 WeightMappingRule {
301 pattern: Regex::new(r"^model\.layers\.(\d+)\.self_attn\.o_proj\.weight$")
302 .expect("regex pattern must be valid"),
303 replacement: "model.layers.$1.self_attn.o_proj.weight".to_string(),
304 transform: None,
305 },
306 WeightMappingRule {
308 pattern: Regex::new(r"^model\.layers\.(\d+)\.mlp\.gate_proj\.weight$")
309 .expect("regex pattern must be valid"),
310 replacement: "model.layers.$1.mlp.gate_proj.weight".to_string(),
311 transform: None,
312 },
313 WeightMappingRule {
314 pattern: Regex::new(r"^model\.layers\.(\d+)\.mlp\.up_proj\.weight$")
315 .expect("regex pattern must be valid"),
316 replacement: "model.layers.$1.mlp.up_proj.weight".to_string(),
317 transform: None,
318 },
319 WeightMappingRule {
320 pattern: Regex::new(r"^model\.layers\.(\d+)\.mlp\.down_proj\.weight$")
321 .expect("regex pattern must be valid"),
322 replacement: "model.layers.$1.mlp.down_proj.weight".to_string(),
323 transform: None,
324 },
325 WeightMappingRule {
327 pattern: Regex::new(r"^model\.layers\.(\d+)\.input_layernorm\.weight$")
328 .expect("regex pattern must be valid"),
329 replacement: "model.layers.$1.input_layernorm.weight".to_string(),
330 transform: None,
331 },
332 WeightMappingRule {
333 pattern: Regex::new(r"^model\.layers\.(\d+)\.post_attention_layernorm\.weight$")
334 .expect("valid regex"),
335 replacement: "model.layers.$1.post_attention_layernorm.weight".to_string(),
336 transform: None,
337 },
338 ]
339 }
340
341 fn default_pytorch_to_tf(&self, name: &str) -> String {
342 name.replace('.', "/")
344 .replace("weight", "kernel")
345 .replace("LayerNorm", "layer_norm")
346 }
347
348 fn default_tf_to_pytorch(&self, name: &str) -> String {
349 name.replace('/', ".")
351 .replace("kernel", "weight")
352 .replace("layer_norm", "LayerNorm")
353 }
354}
355
/// Maps a group of source-model layers onto a group of target-model
/// layers, optionally with a strategy for combining multiple sources.
#[derive(Debug, Clone)]
pub struct LayerMapping {
    pub source_layers: Vec<String>,
    pub target_layers: Vec<String>,
    // `None` means a plain one-to-one mapping with no merging.
    pub merge_strategy: Option<MergeStrategy>,
}
363
/// How several source layers are combined into one target layer.
/// This enum only names the strategy; the tensor operation itself is
/// implemented by whatever consumes it.
#[derive(Debug, Clone)]
pub enum MergeStrategy {
    /// Concatenate along the given axis.
    Concatenate { axis: usize },
    /// Sum the source tensors.
    Add,
    /// Average the source tensors.
    Average,
    /// Named custom strategy — TODO confirm where the name is resolved.
    Custom(String),
}
375
376impl LayerMapping {
377 pub fn new(source: Vec<String>, target: Vec<String>) -> Self {
378 Self {
379 source_layers: source,
380 target_layers: target,
381 merge_strategy: None,
382 }
383 }
384
385 pub fn with_merge_strategy(mut self, strategy: MergeStrategy) -> Self {
386 self.merge_strategy = Some(strategy);
387 self
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// BERT: a rule-based rename carries a transpose transform, and
    /// unmatched names fall through to the generic substitution.
    #[test]
    fn test_bert_mapping() {
        let mapping = WeightMapping::new(ModelType::BERT);

        let (tf_name, transform) = mapping
            .pytorch_to_tensorflow("encoder.layer.0.attention.self.query.weight")
            .expect("BERT query-weight mapping should succeed");

        assert_eq!(tf_name, "bert/encoder/layer_0/attention/self/query/kernel");
        assert!(matches!(transform, Some(WeightTransform::Transpose(_))));

        // No dedicated rule for this name: generic fallback, no transform.
        let (fallback, transform) = mapping
            .pytorch_to_tensorflow("pooler.dense.bias")
            .expect("fallback mapping should succeed");
        assert_eq!(fallback, "pooler/dense/bias");
        assert!(transform.is_none());
    }

    /// GPT-2: embedding rename without transform, attention rename with.
    #[test]
    fn test_gpt2_mapping() {
        let mapping = WeightMapping::new(ModelType::GPT2);

        let (tf_name, _) = mapping
            .pytorch_to_tensorflow("wte.weight")
            .expect("GPT-2 embedding mapping should succeed");
        assert_eq!(tf_name, "model/wte");

        let (tf_name, transform) = mapping
            .pytorch_to_tensorflow("h.0.attn.c_attn.weight")
            .expect("GPT-2 attention mapping should succeed");
        assert_eq!(tf_name, "model/h0/attn/c_attn/kernel");
        assert!(matches!(transform, Some(WeightTransform::Transpose(_))));
    }

    /// JAX <-> PyTorch name translation round-trips when every path
    /// component is underscore-free.
    #[test]
    fn test_jax_mapping() {
        let mapping = WeightMapping::new(ModelType::Generic);

        let (jax_name, _) = mapping
            .pytorch_to_jax("encoder_layer_0_attention_query_weight")
            .expect("pytorch -> jax mapping should succeed");
        assert_eq!(jax_name, "params.encoder.layer.0.attention.query.weight");

        let (pytorch_name, _) = mapping
            .jax_to_pytorch("params.encoder.layer.0.attention.query.weight")
            .expect("jax -> pytorch mapping should succeed");
        assert_eq!(pytorch_name, "encoder_layer_0_attention_query_weight");
    }
}