use serde::{Deserialize, Serialize};
use tensorlogic_ir::{EinsumGraph, TLExpr, Term};

use crate::{
    config::{AttentionConfig, FeedForwardConfig},
    error::{Result, TrustformerError},
    layers::{EncoderLayer, EncoderLayerConfig},
    stacks::{EncoderStack, EncoderStackConfig},
};

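/// Options controlling how trustformers models are converted and how their
/// weights are loaded.
///
/// A minimal sketch of the builder-style API defined below (illustrative only):
///
/// ```ignore
/// let config = IntegrationConfig::new()
///     .with_shape_validation(true)
///     .with_pre_norm(false)
///     .with_numerical_tolerance(1e-5);
/// assert!(config.validate_shapes);
/// ```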
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntegrationConfig {
    /// Validate tensor shapes while building graphs and expressions.
    pub validate_shapes: bool,
    /// Keep dropout settings from the source model configuration.
    pub preserve_dropout: bool,
    /// Use pre-layer-normalization (pre-norm) ordering.
    pub pre_norm: bool,
    /// Tolerance used for numerical comparisons.
    pub numerical_tolerance: f64,
}

impl Default for IntegrationConfig {
    fn default() -> Self {
        Self {
            validate_shapes: true,
            preserve_dropout: true,
            pre_norm: true,
            numerical_tolerance: 1e-6,
        }
    }
}

impl IntegrationConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_shape_validation(mut self, validate: bool) -> Self {
        self.validate_shapes = validate;
        self
    }

    pub fn with_dropout_preservation(mut self, preserve: bool) -> Self {
        self.preserve_dropout = preserve;
        self
    }

    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    pub fn with_numerical_tolerance(mut self, tolerance: f64) -> Self {
        self.numerical_tolerance = tolerance;
        self
    }
}

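/// A transformer model that can be lowered to an einsum graph or described as a
/// `TLExpr`.
///
/// A minimal construction sketch using the configs from this crate (dimensions
/// mirror the unit tests below; illustrative only):
///
/// ```ignore
/// let config = EncoderLayerConfig::new(512, 8, 2048)?;
/// let layer = EncoderLayer::new(config.clone())?;
/// let model = TensorLogicModel::from_encoder_layer(layer, config)?;
/// let expr = model.to_tlexpr()?;
/// ```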
#[derive(Clone, Debug)]
pub enum TensorLogicModel {
    /// A single encoder layer together with its configuration.
    EncoderLayer {
        layer: EncoderLayer,
        config: EncoderLayerConfig,
    },
    /// A full encoder stack together with its configuration.
    EncoderStack {
        stack: EncoderStack,
        config: EncoderStackConfig,
    },
}

impl TensorLogicModel {
    /// Wrap a validated encoder layer.
    pub fn from_encoder_layer(layer: EncoderLayer, config: EncoderLayerConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self::EncoderLayer { layer, config })
    }

    /// Wrap a validated encoder stack.
    pub fn from_encoder_stack(stack: EncoderStack, config: EncoderStackConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self::EncoderStack { stack, config })
    }

    /// Build the einsum graph for this model.
    pub fn build_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        match self {
            Self::EncoderLayer { layer, .. } => layer.build_encoder_layer_graph(graph),
            Self::EncoderStack { stack, .. } => stack.build_encoder_stack_graph(graph),
        }
    }

    /// Summarize the wrapped configuration as a serializable [`ModelConfig`].
    pub fn config(&self) -> ModelConfig {
        match self {
            Self::EncoderLayer { config, .. } => ModelConfig::EncoderLayer {
                d_model: config.attention.d_model,
                n_heads: config.attention.n_heads,
                d_ff: config.feed_forward.d_ff,
                dropout: config.attention.dropout,
                pre_norm: config.pre_norm,
            },
            Self::EncoderStack { config, .. } => ModelConfig::EncoderStack {
                n_layers: config.num_layers,
                d_model: config.layer_config.attention.d_model,
                n_heads: config.layer_config.attention.n_heads,
                d_ff: config.layer_config.feed_forward.d_ff,
                max_seq_len: config.position_encoding.max_seq_len,
                dropout: config.layer_config.attention.dropout,
                pre_norm: config.layer_config.pre_norm,
            },
        }
    }

    /// Express the model as a tensorlogic [`TLExpr`] specification.
    pub fn to_tlexpr(&self) -> Result<TLExpr> {
        match self {
            Self::EncoderLayer { config, .. } => {
                let attention_expr = Self::attention_to_tlexpr(&config.attention)?;
                let ffn_expr = Self::ffn_to_tlexpr(&config.feed_forward)?;

                // A single layer is the conjunction of its attention and feed-forward parts.
                Ok(TLExpr::And(Box::new(attention_expr), Box::new(ffn_expr)))
            }
            Self::EncoderStack { config, .. } => {
                let layer_expr = {
                    let attn_cfg = AttentionConfig::new(
                        config.layer_config.attention.d_model,
                        config.layer_config.attention.n_heads,
                    )?;
                    let ffn_cfg = FeedForwardConfig::new(
                        config.layer_config.feed_forward.d_model,
                        config.layer_config.feed_forward.d_ff,
                    );

                    let attention_expr = Self::attention_to_tlexpr(&attn_cfg)?;
                    let ffn_expr = Self::ffn_to_tlexpr(&ffn_cfg)?;

                    TLExpr::And(Box::new(attention_expr), Box::new(ffn_expr))
                };

                // Quantify the per-layer expression over every layer in the stack.
                Ok(TLExpr::ForAll {
                    var: "layer".to_string(),
                    domain: format!("0..{}", config.num_layers),
                    body: Box::new(layer_expr),
                })
            }
        }
    }

    fn attention_to_tlexpr(config: &AttentionConfig) -> Result<TLExpr> {
        Ok(TLExpr::Pred {
            name: "MultiHeadAttention".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", config.d_model)),
                Term::Const(format!("n_heads={}", config.n_heads)),
                Term::Const(format!("d_k={}", config.d_k)),
            ],
        })
    }

    fn ffn_to_tlexpr(config: &FeedForwardConfig) -> Result<TLExpr> {
        Ok(TLExpr::Pred {
            name: "FeedForward".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", config.d_model)),
                Term::Const(format!("d_ff={}", config.d_ff)),
                Term::Const(format!("activation={}", config.activation)),
            ],
        })
    }
}

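/// A serializable summary of a wrapped model's hyperparameters, as returned by
/// [`TensorLogicModel::config`].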
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ModelConfig {
    EncoderLayer {
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
        dropout: f64,
        pre_norm: bool,
    },
    EncoderStack {
        n_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
        max_seq_len: usize,
        dropout: f64,
        pre_norm: bool,
    },
}

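/// Converts trustformers architecture descriptions into `TLExpr` specifications.
///
/// A usage sketch mirroring the unit tests below (illustrative only):
///
/// ```ignore
/// let converter = TrustformersConverter::new();
/// // ForAll layer in 0..12: MultiHeadAttention(...) ∧ FeedForward(...)
/// let encoder = converter.convert_bert_encoder(12, 768, 12, 3072)?;
/// let seq2seq = converter.convert_transformer(6, 6, 512, 8, 2048)?;
/// ```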
#[derive(Clone, Debug)]
pub struct TrustformersConverter {
    pub config: IntegrationConfig,
}

impl TrustformersConverter {
    pub fn new() -> Self {
        Self {
            config: IntegrationConfig::default(),
        }
    }

    pub fn with_config(config: IntegrationConfig) -> Self {
        Self { config }
    }

    /// Describe a BERT-style encoder as a quantified `TLExpr`.
    pub fn convert_bert_encoder(
        &self,
        n_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
    ) -> Result<TLExpr> {
        if n_layers == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "n_layers must be > 0".to_string(),
            });
        }
        if !d_model.is_multiple_of(n_heads) {
            return Err(TrustformerError::InvalidDimension {
                expected: n_heads,
                got: d_model,
                context: format!(
                    "d_model {} must be divisible by n_heads {}",
                    d_model, n_heads
                ),
            });
        }

        let attn_cfg = AttentionConfig::new(d_model, n_heads)?;
        let ffn_cfg = FeedForwardConfig::new(d_model, d_ff);

        let attention_expr = TLExpr::Pred {
            name: "MultiHeadAttention".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", attn_cfg.d_model)),
                Term::Const(format!("n_heads={}", attn_cfg.n_heads)),
                Term::Const(format!("d_k={}", attn_cfg.d_k)),
            ],
        };

        let ffn_expr = TLExpr::Pred {
            name: "FeedForward".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", ffn_cfg.d_model)),
                Term::Const(format!("d_ff={}", ffn_cfg.d_ff)),
                Term::Const(format!("activation={}", ffn_cfg.activation)),
            ],
        };

        let layer_expr = TLExpr::And(Box::new(attention_expr), Box::new(ffn_expr));

        // All layers share the same structure, so quantify over the layer index.
        Ok(TLExpr::ForAll {
            var: "layer".to_string(),
            domain: format!("0..{}", n_layers),
            body: Box::new(layer_expr),
        })
    }

    /// Describe a GPT-style (causal) decoder as a quantified `TLExpr`.
    pub fn convert_gpt_decoder(
        &self,
        n_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
    ) -> Result<TLExpr> {
        if n_layers == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "n_layers must be > 0".to_string(),
            });
        }
        if !d_model.is_multiple_of(n_heads) {
            return Err(TrustformerError::InvalidDimension {
                expected: n_heads,
                got: d_model,
                context: format!(
                    "d_model {} must be divisible by n_heads {}",
                    d_model, n_heads
                ),
            });
        }

        let attn_cfg = AttentionConfig::new(d_model, n_heads)?.with_causal(true);
        let ffn_cfg = FeedForwardConfig::new(d_model, d_ff);

        let causal_attention_expr = TLExpr::Pred {
            name: "CausalMultiHeadAttention".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", attn_cfg.d_model)),
                Term::Const(format!("n_heads={}", attn_cfg.n_heads)),
                Term::Const(format!("d_k={}", attn_cfg.d_k)),
                Term::Const("causal=true".to_string()),
            ],
        };

        let ffn_expr = TLExpr::Pred {
            name: "FeedForward".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", ffn_cfg.d_model)),
                Term::Const(format!("d_ff={}", ffn_cfg.d_ff)),
                Term::Const(format!("activation={}", ffn_cfg.activation)),
            ],
        };

        let layer_expr = TLExpr::And(Box::new(causal_attention_expr), Box::new(ffn_expr));

        Ok(TLExpr::ForAll {
            var: "layer".to_string(),
            domain: format!("0..{}", n_layers),
            body: Box::new(layer_expr),
        })
    }

    /// Combine encoder and decoder descriptions; at least one side must have layers.
    pub fn convert_transformer(
        &self,
        encoder_layers: usize,
        decoder_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
    ) -> Result<TLExpr> {
        let encoder_expr = if encoder_layers > 0 {
            Some(self.convert_bert_encoder(encoder_layers, d_model, n_heads, d_ff)?)
        } else {
            None
        };

        let decoder_expr = if decoder_layers > 0 {
            Some(self.convert_gpt_decoder(decoder_layers, d_model, n_heads, d_ff)?)
        } else {
            None
        };

        match (encoder_expr, decoder_expr) {
            (Some(enc), Some(dec)) => Ok(TLExpr::And(Box::new(enc), Box::new(dec))),
            (Some(enc), None) => Ok(enc),
            (None, Some(dec)) => Ok(dec),
            (None, None) => Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "At least one of encoder_layers or decoder_layers must be > 0"
                    .to_string(),
            }),
        }
    }
}

impl Default for TrustformersConverter {
    fn default() -> Self {
        Self::new()
    }
}

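/// Loads trustformers checkpoints (JSON or the binary `TLCKPT` format) and maps
/// parameter names to the tensorlogic naming convention.
///
/// A usage sketch; the checkpoint path is hypothetical (illustrative only):
///
/// ```ignore
/// let loader = TrustformersWeightLoader::new();
/// let name = loader.map_layer_name("encoder.layer.0.attention.query.weight")?;
/// assert_eq!(name, "encoder_0_attn_query_weight");
/// let ckpt = loader.load_checkpoint("path/to/model.json")?;
/// println!("{} tensors loaded", ckpt.weights.len());
/// ```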
#[derive(Clone, Debug)]
pub struct TrustformersWeightLoader {
    pub config: IntegrationConfig,
}

impl TrustformersWeightLoader {
    pub fn new() -> Self {
        Self {
            config: IntegrationConfig::default(),
        }
    }

    pub fn with_config(config: IntegrationConfig) -> Self {
        Self { config }
    }

    /// Load a checkpoint from disk, dispatching on the file extension
    /// (`.json`, `.bin`, or `.ckpt`).
    pub fn load_checkpoint(&self, path: &str) -> Result<CheckpointData> {
        use std::path::Path;

        let path_obj = Path::new(path);

        if !path_obj.exists() {
            return Err(TrustformerError::CheckpointLoadError(format!(
                "Checkpoint file not found: {}",
                path
            )));
        }

        let extension = path_obj
            .extension()
            .and_then(|s| s.to_str())
            .ok_or_else(|| {
                TrustformerError::CheckpointLoadError(format!(
                    "Cannot determine checkpoint format for: {}",
                    path
                ))
            })?;

        match extension {
            "json" => self.load_json_checkpoint(path),
            "bin" | "ckpt" => self.load_binary_checkpoint(path),
            _ => Err(TrustformerError::CheckpointLoadError(format!(
                "Unsupported checkpoint format: .{}",
                extension
            ))),
        }
    }

    fn load_json_checkpoint(&self, path: &str) -> Result<CheckpointData> {
        use std::fs;

        let content = fs::read_to_string(path).map_err(|e| {
            TrustformerError::CheckpointLoadError(format!("Failed to read checkpoint: {}", e))
        })?;

        #[derive(Deserialize)]
        struct JsonCheckpoint {
            #[serde(default)]
            metadata: std::collections::HashMap<String, String>,
            weights: std::collections::HashMap<String, Vec<f32>>,
        }

        let json_ckpt: JsonCheckpoint = serde_json::from_str(&content).map_err(|e| {
            TrustformerError::CheckpointLoadError(format!("Invalid JSON checkpoint: {}", e))
        })?;

        // Rename each tensor from the trustformers convention to the tensorlogic one.
        let mut mapped_weights = std::collections::HashMap::new();
        for (trustformers_name, weights) in json_ckpt.weights {
            let tl_name = self.map_layer_name(&trustformers_name)?;
            mapped_weights.insert(tl_name, weights);
        }

        Ok(CheckpointData {
            weights: mapped_weights,
            metadata: json_ckpt.metadata,
        })
    }

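    /// Parse the binary `TLCKPT` checkpoint format expected by this loader:
    ///
    /// * a 256-byte header: bytes 0..6 magic `b"TLCKPT"`, 6..10 version (u32 LE, must be 1),
    ///   10..14 tensor count (u32 LE), 14..18 metadata length in bytes (u32 LE), with the
    ///   remainder currently unused;
    /// * the metadata: a JSON-encoded string-to-string map of the given length;
    /// * per tensor: a u32 LE name length, the UTF-8 name, a u32 LE element count, and that
    ///   many little-endian `f32` values.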
    fn load_binary_checkpoint(&self, path: &str) -> Result<CheckpointData> {
        use std::fs;
        use std::io::{BufReader, Read};

        let file = fs::File::open(path).map_err(|e| {
            TrustformerError::CheckpointLoadError(format!("Failed to open checkpoint: {}", e))
        })?;

        let mut reader = BufReader::new(file);

        let mut header = [0u8; 256];
        reader.read_exact(&mut header).map_err(|e| {
            TrustformerError::CheckpointLoadError(format!("Failed to read header: {}", e))
        })?;

        let magic = &header[0..6];
        if magic != b"TLCKPT" {
            return Err(TrustformerError::CheckpointLoadError(
                "Invalid checkpoint magic number".to_string(),
            ));
        }

        let version = u32::from_le_bytes([header[6], header[7], header[8], header[9]]);
        if version != 1 {
            return Err(TrustformerError::CheckpointLoadError(format!(
                "Unsupported checkpoint version: {}",
                version
            )));
        }

        let num_tensors = u32::from_le_bytes([header[10], header[11], header[12], header[13]]);
        let metadata_size = u32::from_le_bytes([header[14], header[15], header[16], header[17]]);

        let mut metadata_bytes = vec![0u8; metadata_size as usize];
        reader.read_exact(&mut metadata_bytes).map_err(|e| {
            TrustformerError::CheckpointLoadError(format!("Failed to read metadata: {}", e))
        })?;

        let metadata: std::collections::HashMap<String, String> =
            serde_json::from_slice(&metadata_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Invalid metadata JSON: {}", e))
            })?;

        let mut weights = std::collections::HashMap::new();

        for _ in 0..num_tensors {
            let mut name_len_bytes = [0u8; 4];
            reader.read_exact(&mut name_len_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Failed to read name length: {}", e))
            })?;
            let name_len = u32::from_le_bytes(name_len_bytes) as usize;

            let mut name_bytes = vec![0u8; name_len];
            reader.read_exact(&mut name_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Failed to read tensor name: {}", e))
            })?;
            let trustformers_name = String::from_utf8(name_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Invalid tensor name UTF-8: {}", e))
            })?;

            let mut data_len_bytes = [0u8; 4];
            reader.read_exact(&mut data_len_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Failed to read data length: {}", e))
            })?;
            let data_len = u32::from_le_bytes(data_len_bytes) as usize;

            let mut weight_bytes = vec![0u8; data_len * 4];
            reader.read_exact(&mut weight_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Failed to read weights: {}", e))
            })?;

            let mut tensor_weights = Vec::with_capacity(data_len);
            for chunk in weight_bytes.chunks_exact(4) {
                let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
                tensor_weights.push(value);
            }

            let tl_name = self.map_layer_name(&trustformers_name)?;
            weights.insert(tl_name, tensor_weights);
        }

        Ok(CheckpointData { weights, metadata })
    }

    /// Map a trustformers parameter name to the tensorlogic naming convention.
    ///
    /// Replacements are applied in order, e.g.
    /// `"encoder.layer.0.attention.query.weight"` becomes
    /// `"encoder_0_attn_query_weight"`.
    pub fn map_layer_name(&self, trustformers_name: &str) -> Result<String> {
        let mapped = trustformers_name
            .replace("encoder.layer.", "encoder_")
            .replace("decoder.layer.", "decoder_")
            .replace(".attention.", "_attn_")
            .replace(".feed_forward.", "_ffn_")
            .replace(".query.", "_q_")
            .replace(".key.", "_k_")
            .replace(".value.", "_v_")
            .replace(".weight", "_weight")
            .replace(".bias", "_bias");

        Ok(mapped)
    }
}

impl Default for TrustformersWeightLoader {
    fn default() -> Self {
        Self::new()
    }
}

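/// Weights and metadata extracted from a trustformers checkpoint, with tensor
/// names already mapped to the tensorlogic convention.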
#[derive(Clone, Debug, Default)]
pub struct CheckpointData {
    /// Tensor values keyed by their remapped (tensorlogic) names.
    pub weights: std::collections::HashMap<String, Vec<f32>>,
    /// Free-form key/value metadata stored alongside the weights.
    pub metadata: std::collections::HashMap<String, String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_integration_config_creation() {
        let config = IntegrationConfig::new();
        assert!(config.validate_shapes);
        assert!(config.preserve_dropout);
        assert!(config.pre_norm);
        assert!((config.numerical_tolerance - 1e-6).abs() < 1e-10);
    }

    #[test]
    fn test_integration_config_builder() {
        let config = IntegrationConfig::new()
            .with_shape_validation(false)
            .with_dropout_preservation(false)
            .with_pre_norm(false)
            .with_numerical_tolerance(1e-5);

        assert!(!config.validate_shapes);
        assert!(!config.preserve_dropout);
        assert!(!config.pre_norm);
        assert!((config.numerical_tolerance - 1e-5).abs() < 1e-10);
    }

    #[test]
    fn test_tensorlogic_model_from_encoder_layer() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config.clone()).unwrap();
        let model = TensorLogicModel::from_encoder_layer(layer, config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_tensorlogic_model_from_encoder_stack() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024).unwrap();
        let stack = EncoderStack::new(config.clone()).unwrap();
        let model = TensorLogicModel::from_encoder_stack(stack, config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_tensorlogic_model_build_graph() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config.clone()).unwrap();
        let model = TensorLogicModel::from_encoder_layer(layer, config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("input");

        let outputs = model.build_graph(&mut graph);
        assert!(outputs.is_ok());
    }

    #[test]
    fn test_tensorlogic_model_to_tlexpr() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config.clone()).unwrap();
        let model = TensorLogicModel::from_encoder_layer(layer, config).unwrap();

        let expr = model.to_tlexpr();
        assert!(expr.is_ok());
    }

    #[test]
    fn test_tensorlogic_model_config() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config.clone()).unwrap();
        let model = TensorLogicModel::from_encoder_layer(layer, config).unwrap();

        let model_config = model.config();
        match model_config {
            ModelConfig::EncoderLayer {
                d_model,
                n_heads,
                d_ff,
                ..
            } => {
                assert_eq!(d_model, 512);
                assert_eq!(n_heads, 8);
                assert_eq!(d_ff, 2048);
            }
            _ => panic!("Expected EncoderLayer config"),
        }
    }

    #[test]
    fn test_trustformers_converter_creation() {
        let converter = TrustformersConverter::new();
        assert!(converter.config.validate_shapes);
    }

    #[test]
    fn test_trustformers_converter_with_config() {
        let config = IntegrationConfig::new().with_shape_validation(false);
        let converter = TrustformersConverter::with_config(config);
        assert!(!converter.config.validate_shapes);
    }

    #[test]
    fn test_convert_bert_encoder() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_bert_encoder(6, 512, 8, 2048);
        assert!(expr.is_ok());

        let expr = expr.unwrap();
        match expr {
            TLExpr::ForAll { var, body, .. } => {
                assert_eq!(var, "layer");
                match *body {
                    TLExpr::And(..) => {}
                    _ => panic!("Expected And"),
                }
            }
            _ => panic!("Expected ForAll"),
        }
    }

    #[test]
    fn test_convert_gpt_decoder() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_gpt_decoder(12, 768, 12, 3072);
        assert!(expr.is_ok());
    }

    #[test]
    fn test_convert_transformer_encoder_only() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_transformer(6, 0, 512, 8, 2048);
        assert!(expr.is_ok());
    }

    #[test]
    fn test_convert_transformer_decoder_only() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_transformer(0, 6, 512, 8, 2048);
        assert!(expr.is_ok());
    }

    #[test]
    fn test_convert_transformer_encoder_decoder() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_transformer(6, 6, 512, 8, 2048);
        assert!(expr.is_ok());

        let expr = expr.unwrap();
        match expr {
            TLExpr::And(..) => {}
            _ => panic!("Expected And"),
        }
    }

    #[test]
    fn test_convert_transformer_invalid_zero_layers() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_transformer(0, 0, 512, 8, 2048);
        assert!(expr.is_err());
    }

    #[test]
    fn test_convert_bert_invalid_heads() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_bert_encoder(6, 512, 7, 2048);
        assert!(expr.is_err());
    }

    #[test]
    fn test_weight_loader_creation() {
        let loader = TrustformersWeightLoader::new();
        assert!(loader.config.validate_shapes);
    }

    #[test]
    fn test_weight_loader_map_layer_name() {
        let loader = TrustformersWeightLoader::new();

        let mapped = loader
            .map_layer_name("encoder.layer.0.attention.query.weight")
            .unwrap();
        assert_eq!(mapped, "encoder_0_attn_query_weight");

        let mapped = loader
            .map_layer_name("decoder.layer.5.feed_forward.weight")
            .unwrap();
        assert_eq!(mapped, "decoder_5_ffn_weight");
    }

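    // A round-trip sketch for JSON checkpoint loading; assumes the OS temp
    // directory is writable and that the file name below is otherwise unused.
    #[test]
    fn test_load_json_checkpoint_from_temp_file() {
        let loader = TrustformersWeightLoader::new();
        let path = std::env::temp_dir().join("tl_trustformers_test_ckpt.json");
        let json = r#"{
            "metadata": {"model": "test"},
            "weights": {"encoder.layer.0.attention.query.weight": [0.1, 0.2]}
        }"#;
        std::fs::write(&path, json).unwrap();

        let ckpt = loader.load_checkpoint(path.to_str().unwrap()).unwrap();
        // Tensor names are remapped on load.
        assert_eq!(ckpt.weights["encoder_0_attn_query_weight"].len(), 2);
        assert_eq!(ckpt.metadata["model"], "test");

        let _ = std::fs::remove_file(&path);
    }
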
    #[test]
    fn test_checkpoint_data_default() {
        let data = CheckpointData::default();
        assert!(data.weights.is_empty());
        assert!(data.metadata.is_empty());
    }
}