tensorlogic_trustformers/trustformers_integration.rs

//! Integration layer between TensorLogic and TrustformeRS.
//!
//! This module provides bidirectional conversion between TensorLogic's einsum-based
//! transformer components and TrustformeRS's model traits. It enables:
//!
//! 1. **TensorLogic → TrustformeRS**: Wrap TensorLogic transformer components as TrustformeRS models
//! 2. **TrustformeRS → TensorLogic**: Convert TrustformeRS model architectures to TLExpr
//! 3. **Weight Loading**: Load pre-trained weights from TrustformeRS checkpoint format
//! 4. **Model Export**: Export trained TensorLogic models to TrustformeRS format
//!
//! ## Design Philosophy
//!
//! - **Zero-Copy Where Possible**: Minimize data copying during conversions
//! - **Type Safety**: Leverage Rust's type system to prevent runtime errors
//! - **Backend Agnostic**: Conversions work with any TensorLogic backend
//! - **Compatibility**: Support standard TrustformeRS checkpoint formats
//!
//! ## Example: TensorLogic → TrustformeRS
//!
//! ```rust,no_run
//! use tensorlogic_trustformers::{EncoderStack, EncoderStackConfig};
//! use tensorlogic_trustformers::trustformers_integration::TensorLogicModel;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create a TensorLogic encoder
//! let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024)?;
//! let encoder = EncoderStack::new(config.clone())?;
//!
//! // Wrap as TrustformeRS model
//! let model = TensorLogicModel::from_encoder_stack(encoder, config)?;
//!
//! // Now it implements the TrustformeRS Model trait
//! // let output = model.forward(input)?;
//! # Ok(())
//! # }
//! ```
//!
//! ## Example: TrustformeRS → TensorLogic
//!
//! ```rust,ignore
//! use tensorlogic_trustformers::trustformers_integration::TrustformersConverter;
//!
//! // Convert TrustformeRS model architecture to TLExpr
//! let converter = TrustformersConverter::new();
//! // let tlexpr = converter.convert_model_architecture(&trustformers_model)?;
//!
//! // Compile to einsum graph
//! // use tensorlogic_compiler::CompilerContext;
//! // let mut ctx = CompilerContext::new();
//! // let graph = ctx.compile(&tlexpr)?;
//! ```

use serde::{Deserialize, Serialize};
use tensorlogic_ir::{EinsumGraph, TLExpr, Term};

use crate::{
    config::{AttentionConfig, FeedForwardConfig},
    error::{Result, TrustformerError},
    layers::{EncoderLayer, EncoderLayerConfig},
    stacks::{EncoderStack, EncoderStackConfig},
};

/// Configuration for TensorLogic <-> TrustformeRS conversion
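///
/// ## Example
///
/// A short sketch of the builder-style setters defined below (mirrors the
/// unit tests at the bottom of this file):
///
/// ```rust
/// use tensorlogic_trustformers::trustformers_integration::IntegrationConfig;
///
/// let config = IntegrationConfig::new()
///     .with_shape_validation(false)
///     .with_numerical_tolerance(1e-5);
///
/// assert!(!config.validate_shapes);
/// assert!((config.numerical_tolerance - 1e-5).abs() < 1e-10);
/// ```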
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntegrationConfig {
    /// Whether to validate shapes during conversion
    pub validate_shapes: bool,
    /// Whether to preserve dropout layers (or compile them out)
    pub preserve_dropout: bool,
    /// Whether to use pre-layer normalization (vs post-layer)
    pub pre_norm: bool,
    /// Tolerance for numerical differences during validation
    pub numerical_tolerance: f64,
}

impl Default for IntegrationConfig {
    fn default() -> Self {
        Self {
            validate_shapes: true,
            preserve_dropout: true,
            pre_norm: true,
            numerical_tolerance: 1e-6,
        }
    }
}

impl IntegrationConfig {
    /// Create a new integration configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set whether to validate shapes
    pub fn with_shape_validation(mut self, validate: bool) -> Self {
        self.validate_shapes = validate;
        self
    }

    /// Set whether to preserve dropout
    pub fn with_dropout_preservation(mut self, preserve: bool) -> Self {
        self.preserve_dropout = preserve;
        self
    }

    /// Set whether to use pre-layer normalization
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    /// Set numerical tolerance
    pub fn with_numerical_tolerance(mut self, tolerance: f64) -> Self {
        self.numerical_tolerance = tolerance;
        self
    }
}

/// Wrapper for TensorLogic transformer components that implements TrustformeRS Model trait
///
/// This allows TensorLogic einsum-based transformers to be used wherever
/// TrustformeRS models are expected.
#[derive(Clone, Debug)]
pub enum TensorLogicModel {
    /// Single encoder layer
    EncoderLayer {
        layer: EncoderLayer,
        config: EncoderLayerConfig,
    },
    /// Stack of encoder layers
    EncoderStack {
        stack: EncoderStack,
        config: EncoderStackConfig,
    },
}

impl TensorLogicModel {
    /// Create from an encoder layer
    pub fn from_encoder_layer(layer: EncoderLayer, config: EncoderLayerConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self::EncoderLayer { layer, config })
    }

    /// Create from an encoder stack
    pub fn from_encoder_stack(stack: EncoderStack, config: EncoderStackConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self::EncoderStack { stack, config })
    }

    /// Build einsum graph for this model
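    ///
    /// ## Example
    ///
    /// A minimal sketch mirroring the unit test below; `model` is assumed to
    /// be an already-constructed [`TensorLogicModel`]:
    ///
    /// ```rust,ignore
    /// use tensorlogic_ir::EinsumGraph;
    ///
    /// let mut graph = EinsumGraph::new();
    /// graph.add_tensor("input");
    /// let output_ids = model.build_graph(&mut graph)?;
    /// ```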
    pub fn build_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        match self {
            Self::EncoderLayer { layer, .. } => layer.build_encoder_layer_graph(graph),
            Self::EncoderStack { stack, .. } => stack.build_encoder_stack_graph(graph),
        }
    }

    /// Get the model configuration
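    ///
    /// A sketch of inspecting the returned description; `model` is assumed to
    /// be an already-constructed [`TensorLogicModel`]:
    ///
    /// ```rust,ignore
    /// match model.config() {
    ///     ModelConfig::EncoderLayer { d_model, n_heads, .. } => {
    ///         println!("layer: d_model={}, n_heads={}", d_model, n_heads);
    ///     }
    ///     ModelConfig::EncoderStack { n_layers, .. } => {
    ///         println!("stack: {} layers", n_layers);
    ///     }
    /// }
    /// ```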
    pub fn config(&self) -> ModelConfig {
        match self {
            Self::EncoderLayer { config, .. } => ModelConfig::EncoderLayer {
                d_model: config.attention.d_model,
                n_heads: config.attention.n_heads,
                d_ff: config.feed_forward.d_ff,
                dropout: config.attention.dropout,
                pre_norm: config.pre_norm,
            },
            Self::EncoderStack { config, .. } => ModelConfig::EncoderStack {
                n_layers: config.num_layers,
                d_model: config.layer_config.attention.d_model,
                n_heads: config.layer_config.attention.n_heads,
                d_ff: config.layer_config.feed_forward.d_ff,
                max_seq_len: config.position_encoding.max_seq_len,
                dropout: config.layer_config.attention.dropout,
                pre_norm: config.layer_config.pre_norm,
            },
        }
    }

    /// Convert to TLExpr representation
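    ///
    /// A sketch of the expression shapes this produces; `model` is assumed to
    /// be an already-constructed [`TensorLogicModel`]:
    ///
    /// ```rust,ignore
    /// let expr = model.to_tlexpr()?;
    /// // Encoder layer: And(MultiHeadAttention, FeedForward)
    /// // Encoder stack: ForAll { var: "layer", .. } wrapping that conjunction
    /// ```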
    pub fn to_tlexpr(&self) -> Result<TLExpr> {
        match self {
            Self::EncoderLayer { config, .. } => {
                // Represent encoder layer as logical conjunction of attention and FFN
                let attention_expr = Self::attention_to_tlexpr(&config.attention)?;
                let ffn_expr = Self::ffn_to_tlexpr(&config.feed_forward)?;

                // Compose using And: attention AND ffn (both must be applied)
                Ok(TLExpr::And(Box::new(attention_expr), Box::new(ffn_expr)))
            }
            Self::EncoderStack { config, .. } => {
                // Represent stack as repeated application of encoder layers
                let layer_expr = {
                    let attn_cfg = AttentionConfig::new(
                        config.layer_config.attention.d_model,
                        config.layer_config.attention.n_heads,
                    )?;
                    let ffn_cfg = FeedForwardConfig::new(
                        config.layer_config.feed_forward.d_model,
                        config.layer_config.feed_forward.d_ff,
                    );

                    let attention_expr = Self::attention_to_tlexpr(&attn_cfg)?;
                    let ffn_expr = Self::ffn_to_tlexpr(&ffn_cfg)?;

                    TLExpr::And(Box::new(attention_expr), Box::new(ffn_expr))
                };

                // Repeat num_layers times using ForAll
                Ok(TLExpr::ForAll {
                    var: "layer".to_string(),
                    domain: format!("0..{}", config.num_layers),
                    body: Box::new(layer_expr),
                })
            }
        }
    }

    /// Convert attention configuration to TLExpr
    fn attention_to_tlexpr(config: &AttentionConfig) -> Result<TLExpr> {
        // Multi-head attention as einsum operations
        Ok(TLExpr::Pred {
            name: "MultiHeadAttention".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", config.d_model)),
                Term::Const(format!("n_heads={}", config.n_heads)),
                Term::Const(format!("d_k={}", config.d_k)),
            ],
        })
    }

    /// Convert FFN configuration to TLExpr
    fn ffn_to_tlexpr(config: &FeedForwardConfig) -> Result<TLExpr> {
        Ok(TLExpr::Pred {
            name: "FeedForward".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", config.d_model)),
                Term::Const(format!("d_ff={}", config.d_ff)),
                Term::Const(format!("activation={}", config.activation)),
            ],
        })
    }
}

/// Configuration description for a TensorLogic model
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ModelConfig {
    /// Single encoder layer configuration
    EncoderLayer {
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
        dropout: f64,
        pre_norm: bool,
    },
    /// Encoder stack configuration
    EncoderStack {
        n_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
        max_seq_len: usize,
        dropout: f64,
        pre_norm: bool,
    },
}

/// Converter from TrustformeRS model architectures to TensorLogic IR
///
/// This converter analyzes TrustformeRS model structures and generates
/// equivalent TLExpr representations that can be compiled to einsum graphs.
#[derive(Clone, Debug)]
pub struct TrustformersConverter {
    /// Conversion configuration
    pub config: IntegrationConfig,
}

impl TrustformersConverter {
    /// Create a new converter with default configuration
    pub fn new() -> Self {
        Self {
            config: IntegrationConfig::default(),
        }
    }

    /// Create a new converter with custom configuration
    pub fn with_config(config: IntegrationConfig) -> Self {
        Self { config }
    }

    /// Convert a BERT-style encoder model to TLExpr
    ///
    /// This analyzes the model's layer structure and generates corresponding
    /// TensorLogic expressions.
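    ///
    /// ## Example
    ///
    /// A sketch mirroring the unit test below:
    ///
    /// ```rust,no_run
    /// use tensorlogic_trustformers::trustformers_integration::TrustformersConverter;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let converter = TrustformersConverter::new();
    /// let expr = converter.convert_bert_encoder(6, 512, 8, 2048)?;
    /// # Ok(())
    /// # }
    /// ```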
    pub fn convert_bert_encoder(
        &self,
        n_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
    ) -> Result<TLExpr> {
        // Validate configuration
        if n_layers == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "n_layers must be > 0".to_string(),
            });
        }
        if !d_model.is_multiple_of(n_heads) {
            return Err(TrustformerError::InvalidDimension {
                expected: n_heads,
                got: d_model,
                context: format!(
                    "d_model {} must be divisible by n_heads {}",
                    d_model, n_heads
                ),
            });
        }

        // Create encoder layer with bidirectional (non-causal) attention
        let attn_cfg = AttentionConfig::new(d_model, n_heads)?;
        let ffn_cfg = FeedForwardConfig::new(d_model, d_ff);

        let attention_expr = TLExpr::Pred {
            name: "MultiHeadAttention".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", attn_cfg.d_model)),
                Term::Const(format!("n_heads={}", attn_cfg.n_heads)),
                Term::Const(format!("d_k={}", attn_cfg.d_k)),
            ],
        };

        let ffn_expr = TLExpr::Pred {
            name: "FeedForward".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", ffn_cfg.d_model)),
                Term::Const(format!("d_ff={}", ffn_cfg.d_ff)),
                Term::Const(format!("activation={}", ffn_cfg.activation)),
            ],
        };

        let layer_expr = TLExpr::And(Box::new(attention_expr), Box::new(ffn_expr));

        // Repeat for all layers
        Ok(TLExpr::ForAll {
            var: "layer".to_string(),
            domain: format!("0..{}", n_layers),
            body: Box::new(layer_expr),
        })
    }

    /// Convert a GPT-style decoder model to TLExpr
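    ///
    /// Identical validation to [`convert_bert_encoder`](Self::convert_bert_encoder),
    /// but emits `CausalMultiHeadAttention` per layer. A sketch mirroring the
    /// unit test below:
    ///
    /// ```rust,no_run
    /// use tensorlogic_trustformers::trustformers_integration::TrustformersConverter;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let converter = TrustformersConverter::new();
    /// let expr = converter.convert_gpt_decoder(12, 768, 12, 3072)?;
    /// # Ok(())
    /// # }
    /// ```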
    pub fn convert_gpt_decoder(
        &self,
        n_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
    ) -> Result<TLExpr> {
        // Validate configuration
        if n_layers == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "n_layers must be > 0".to_string(),
            });
        }
        if !d_model.is_multiple_of(n_heads) {
            return Err(TrustformerError::InvalidDimension {
                expected: n_heads,
                got: d_model,
                context: format!(
                    "d_model {} must be divisible by n_heads {}",
                    d_model, n_heads
                ),
            });
        }

        // Create decoder layer with causal attention
        let attn_cfg = AttentionConfig::new(d_model, n_heads)?.with_causal(true);
        let ffn_cfg = FeedForwardConfig::new(d_model, d_ff);

        let causal_attention_expr = TLExpr::Pred {
            name: "CausalMultiHeadAttention".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", attn_cfg.d_model)),
                Term::Const(format!("n_heads={}", attn_cfg.n_heads)),
                Term::Const(format!("d_k={}", attn_cfg.d_k)),
                Term::Const("causal=true".to_string()),
            ],
        };

        let ffn_expr = TLExpr::Pred {
            name: "FeedForward".to_string(),
            args: vec![
                Term::Const(format!("d_model={}", ffn_cfg.d_model)),
                Term::Const(format!("d_ff={}", ffn_cfg.d_ff)),
                Term::Const(format!("activation={}", ffn_cfg.activation)),
            ],
        };

        let layer_expr = TLExpr::And(Box::new(causal_attention_expr), Box::new(ffn_expr));

        // Repeat for all layers
        Ok(TLExpr::ForAll {
            var: "layer".to_string(),
            domain: format!("0..{}", n_layers),
            body: Box::new(layer_expr),
        })
    }

    /// Convert generic transformer architecture to TLExpr
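    ///
    /// Passing `0` for one side yields an encoder-only or decoder-only
    /// expression; passing `0` for both is an error. A sketch mirroring the
    /// unit tests below:
    ///
    /// ```rust,no_run
    /// use tensorlogic_trustformers::trustformers_integration::TrustformersConverter;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let converter = TrustformersConverter::new();
    /// // Full encoder-decoder transformer: And(encoder_expr, decoder_expr)
    /// let expr = converter.convert_transformer(6, 6, 512, 8, 2048)?;
    /// # Ok(())
    /// # }
    /// ```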
    pub fn convert_transformer(
        &self,
        encoder_layers: usize,
        decoder_layers: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
    ) -> Result<TLExpr> {
        let encoder_expr = if encoder_layers > 0 {
            Some(self.convert_bert_encoder(encoder_layers, d_model, n_heads, d_ff)?)
        } else {
            None
        };

        let decoder_expr = if decoder_layers > 0 {
            Some(self.convert_gpt_decoder(decoder_layers, d_model, n_heads, d_ff)?)
        } else {
            None
        };

        match (encoder_expr, decoder_expr) {
            (Some(enc), Some(dec)) => {
                // Full encoder-decoder transformer (encoder AND decoder both applied)
                Ok(TLExpr::And(Box::new(enc), Box::new(dec)))
            }
            (Some(enc), None) => Ok(enc),
            (None, Some(dec)) => Ok(dec),
            (None, None) => Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "At least one of encoder_layers or decoder_layers must be > 0".to_string(),
            }),
        }
    }
}

impl Default for TrustformersConverter {
    fn default() -> Self {
        Self::new()
    }
}

/// Weight loader for TrustformeRS checkpoint files
///
/// Currently supports two formats (see [`load_checkpoint`](Self::load_checkpoint)):
/// - JSON (`*.json`)
/// - TensorLogic binary (`*.bin`, `*.ckpt`) with a `TLCKPT` header
#[derive(Clone, Debug)]
pub struct TrustformersWeightLoader {
    /// Integration configuration
    pub config: IntegrationConfig,
}

impl TrustformersWeightLoader {
    /// Create a new weight loader
    pub fn new() -> Self {
        Self {
            config: IntegrationConfig::default(),
        }
    }

    /// Create with custom configuration
    pub fn with_config(config: IntegrationConfig) -> Self {
        Self { config }
    }

    /// Load weights from a TrustformeRS checkpoint file
    ///
    /// Supports multiple checkpoint formats:
    /// 1. JSON format (*.json) - Simple text-based format
    /// 2. Binary format (*.bin) - Raw binary weights with metadata header
    ///
    /// ## JSON Format
    ///
    /// ```json
    /// {
    ///   "metadata": {
    ///     "model_type": "encoder",
    ///     "n_layers": "6",
    ///     "d_model": "512"
    ///   },
    ///   "weights": {
    ///     "encoder.layer.0.attention.query.weight": [0.1, 0.2, ...],
    ///     "encoder.layer.0.attention.key.weight": [...]
    ///   }
    /// }
    /// ```
    ///
    /// ## Binary Format
    ///
    /// Header (256 bytes):
    /// - Magic: "TLCKPT" (6 bytes)
    /// - Version: u32 (4 bytes)
    /// - Num tensors: u32 (4 bytes)
    /// - Metadata size: u32 (4 bytes)
    /// - Reserved: (238 bytes)
    ///
    /// Followed by:
    /// - Metadata JSON (metadata_size bytes)
    /// - Tensor entries: name length (u32), name bytes, element count (u32),
    ///   then that many little-endian f32 values
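    ///
    /// A minimal sketch of emitting a header in this layout (the offsets match
    /// what this loader reads; `num_tensors` and `metadata_size` are
    /// placeholder values, and writer support is not part of this module):
    ///
    /// ```rust,ignore
    /// let (num_tensors, metadata_size): (u32, u32) = (2, 64);
    /// let mut header = [0u8; 256];
    /// header[0..6].copy_from_slice(b"TLCKPT");                      // magic
    /// header[6..10].copy_from_slice(&1u32.to_le_bytes());           // version
    /// header[10..14].copy_from_slice(&num_tensors.to_le_bytes());   // tensor count
    /// header[14..18].copy_from_slice(&metadata_size.to_le_bytes()); // metadata bytes
    /// // bytes 18..256 stay zeroed (reserved)
    /// ```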
    ///
    /// ## Example
    ///
    /// ```no_run
    /// use tensorlogic_trustformers::trustformers_integration::TrustformersWeightLoader;
    ///
    /// let loader = TrustformersWeightLoader::new();
    /// let checkpoint = loader.load_checkpoint("model.json")?;
    ///
    /// // Access weights (names are normalized by `map_layer_name`)
    /// if let Some(weights) = checkpoint.weights.get("encoder_0_attn_query_weight") {
    ///     println!("Query weights: {:?}", &weights[..10]);
    /// }
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    pub fn load_checkpoint(&self, path: &str) -> Result<CheckpointData> {
        use std::path::Path;

        let path_obj = Path::new(path);

        if !path_obj.exists() {
            return Err(TrustformerError::CheckpointLoadError(format!(
                "Checkpoint file not found: {}",
                path
            )));
        }

        // Determine format based on extension
        let extension = path_obj
            .extension()
            .and_then(|s| s.to_str())
            .ok_or_else(|| {
                TrustformerError::CheckpointLoadError(format!(
                    "Cannot determine checkpoint format for: {}",
                    path
                ))
            })?;

        match extension {
            "json" => self.load_json_checkpoint(path),
            "bin" | "ckpt" => self.load_binary_checkpoint(path),
            _ => Err(TrustformerError::CheckpointLoadError(format!(
                "Unsupported checkpoint format: .{}",
                extension
            ))),
        }
    }

    /// Load checkpoint from JSON format
    fn load_json_checkpoint(&self, path: &str) -> Result<CheckpointData> {
        use std::fs;

        let content = fs::read_to_string(path).map_err(|e| {
            TrustformerError::CheckpointLoadError(format!("Failed to read checkpoint: {}", e))
        })?;

        #[derive(Deserialize)]
        struct JsonCheckpoint {
            #[serde(default)]
            metadata: std::collections::HashMap<String, String>,
            weights: std::collections::HashMap<String, Vec<f32>>,
        }

        let json_ckpt: JsonCheckpoint = serde_json::from_str(&content).map_err(|e| {
            TrustformerError::CheckpointLoadError(format!("Invalid JSON checkpoint: {}", e))
        })?;

        // Map TrustformeRS names to TensorLogic names
        let mut mapped_weights = std::collections::HashMap::new();
        for (trustformers_name, weights) in json_ckpt.weights {
            let tl_name = self.map_layer_name(&trustformers_name)?;
            mapped_weights.insert(tl_name, weights);
        }

        Ok(CheckpointData {
            weights: mapped_weights,
            metadata: json_ckpt.metadata,
        })
    }

    /// Load checkpoint from binary format
    fn load_binary_checkpoint(&self, path: &str) -> Result<CheckpointData> {
        use std::fs;
        use std::io::{BufReader, Read};

        let file = fs::File::open(path).map_err(|e| {
            TrustformerError::CheckpointLoadError(format!("Failed to open checkpoint: {}", e))
        })?;

        let mut reader = BufReader::new(file);

        // Read header (256 bytes)
        let mut header = [0u8; 256];
        reader.read_exact(&mut header).map_err(|e| {
            TrustformerError::CheckpointLoadError(format!("Failed to read header: {}", e))
        })?;

        // Verify magic
        let magic = &header[0..6];
        if magic != b"TLCKPT" {
            return Err(TrustformerError::CheckpointLoadError(
                "Invalid checkpoint magic number".to_string(),
            ));
        }

        // Read version (u32 at offset 6)
        let version = u32::from_le_bytes([header[6], header[7], header[8], header[9]]);
        if version != 1 {
            return Err(TrustformerError::CheckpointLoadError(format!(
                "Unsupported checkpoint version: {}",
                version
            )));
        }

        // Read num_tensors (u32 at offset 10)
        let num_tensors = u32::from_le_bytes([header[10], header[11], header[12], header[13]]);

        // Read metadata_size (u32 at offset 14)
        let metadata_size = u32::from_le_bytes([header[14], header[15], header[16], header[17]]);

        // Read metadata JSON
        let mut metadata_bytes = vec![0u8; metadata_size as usize];
        reader.read_exact(&mut metadata_bytes).map_err(|e| {
            TrustformerError::CheckpointLoadError(format!("Failed to read metadata: {}", e))
        })?;

        let metadata: std::collections::HashMap<String, String> =
            serde_json::from_slice(&metadata_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Invalid metadata JSON: {}", e))
            })?;

        // Read tensor entries
        let mut weights = std::collections::HashMap::new();

        for _ in 0..num_tensors {
            // Read name length (u32)
            let mut name_len_bytes = [0u8; 4];
            reader.read_exact(&mut name_len_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Failed to read name length: {}", e))
            })?;
            let name_len = u32::from_le_bytes(name_len_bytes) as usize;

            // Read name
            let mut name_bytes = vec![0u8; name_len];
            reader.read_exact(&mut name_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Failed to read tensor name: {}", e))
            })?;
            let trustformers_name = String::from_utf8(name_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Invalid tensor name UTF-8: {}", e))
            })?;

            // Read data length (u32, number of f32 elements)
            let mut data_len_bytes = [0u8; 4];
            reader.read_exact(&mut data_len_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Failed to read data length: {}", e))
            })?;
            let data_len = u32::from_le_bytes(data_len_bytes) as usize;

            // Read weights (f32 array)
            let mut weight_bytes = vec![0u8; data_len * 4];
            reader.read_exact(&mut weight_bytes).map_err(|e| {
                TrustformerError::CheckpointLoadError(format!("Failed to read weights: {}", e))
            })?;

            // Convert bytes to f32
            let mut tensor_weights = Vec::with_capacity(data_len);
            for chunk in weight_bytes.chunks_exact(4) {
                let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
                tensor_weights.push(value);
            }

            // Map name
            let tl_name = self.map_layer_name(&trustformers_name)?;
            weights.insert(tl_name, tensor_weights);
        }

        Ok(CheckpointData { weights, metadata })
    }

    /// Map TrustformeRS layer names to TensorLogic tensor names
    ///
    /// Example mappings:
    /// - "encoder.layer.0.attention.query.weight" -> "encoder_0_attn_query_weight"
    /// - "decoder.layer.5.feed_forward.weight" -> "decoder_5_ffn_weight"
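    ///
    /// A doctest mirroring the unit test below:
    ///
    /// ```rust
    /// use tensorlogic_trustformers::trustformers_integration::TrustformersWeightLoader;
    ///
    /// let loader = TrustformersWeightLoader::new();
    /// let mapped = loader
    ///     .map_layer_name("decoder.layer.5.feed_forward.weight")
    ///     .unwrap();
    /// assert_eq!(mapped, "decoder_5_ffn_weight");
    /// ```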
    pub fn map_layer_name(&self, trustformers_name: &str) -> Result<String> {
        // Simple sequential replacement strategy - can be made more sophisticated.
        // Note: the ".query."/".key."/".value." rules require the segment to
        // still be dot-delimited, so after ".attention." is rewritten to
        // "_attn_" the full word survives (e.g. "..._attn_query_weight").
        let mapped = trustformers_name
            .replace("encoder.layer.", "encoder_")
            .replace("decoder.layer.", "decoder_")
            .replace(".attention.", "_attn_")
            .replace(".feed_forward.", "_ffn_")
            .replace(".query.", "_q_")
            .replace(".key.", "_k_")
            .replace(".value.", "_v_")
            .replace(".weight", "_weight")
            .replace(".bias", "_bias");

        Ok(mapped)
    }
}

impl Default for TrustformersWeightLoader {
    fn default() -> Self {
        Self::new()
    }
}

/// Checkpoint data loaded from TrustformeRS format
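///
/// A small sketch of the (derived) default, mirroring the unit test below:
///
/// ```rust
/// use tensorlogic_trustformers::trustformers_integration::CheckpointData;
///
/// // An empty checkpoint; `load_checkpoint` fills both maps.
/// let data = CheckpointData::default();
/// assert!(data.weights.is_empty());
/// assert!(data.metadata.is_empty());
/// ```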
#[derive(Clone, Debug, Default)]
pub struct CheckpointData {
    /// Mapping from tensor names to weight data
    pub weights: std::collections::HashMap<String, Vec<f32>>,
    /// Model configuration metadata
    pub metadata: std::collections::HashMap<String, String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_integration_config_creation() {
        let config = IntegrationConfig::new();
        assert!(config.validate_shapes);
        assert!(config.preserve_dropout);
        assert!(config.pre_norm);
        assert!((config.numerical_tolerance - 1e-6).abs() < 1e-10);
    }

    #[test]
    fn test_integration_config_builder() {
        let config = IntegrationConfig::new()
            .with_shape_validation(false)
            .with_dropout_preservation(false)
            .with_pre_norm(false)
            .with_numerical_tolerance(1e-5);

        assert!(!config.validate_shapes);
        assert!(!config.preserve_dropout);
        assert!(!config.pre_norm);
        assert!((config.numerical_tolerance - 1e-5).abs() < 1e-10);
    }

    #[test]
    fn test_tensorlogic_model_from_encoder_layer() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config.clone()).unwrap();
        let model = TensorLogicModel::from_encoder_layer(layer, config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_tensorlogic_model_from_encoder_stack() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 1024).unwrap();
        let stack = EncoderStack::new(config.clone()).unwrap();
        let model = TensorLogicModel::from_encoder_stack(stack, config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_tensorlogic_model_build_graph() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config.clone()).unwrap();
        let model = TensorLogicModel::from_encoder_layer(layer, config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("input");

        let outputs = model.build_graph(&mut graph);
        assert!(outputs.is_ok());
    }

    #[test]
    fn test_tensorlogic_model_to_tlexpr() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config.clone()).unwrap();
        let model = TensorLogicModel::from_encoder_layer(layer, config).unwrap();

        let expr = model.to_tlexpr();
        assert!(expr.is_ok());
    }

    #[test]
    fn test_tensorlogic_model_config() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let layer = EncoderLayer::new(config.clone()).unwrap();
        let model = TensorLogicModel::from_encoder_layer(layer, config).unwrap();

        let model_config = model.config();
        match model_config {
            ModelConfig::EncoderLayer {
                d_model,
                n_heads,
                d_ff,
                ..
            } => {
                assert_eq!(d_model, 512);
                assert_eq!(n_heads, 8);
                assert_eq!(d_ff, 2048);
            }
            _ => panic!("Expected EncoderLayer config"),
        }
    }

    #[test]
    fn test_trustformers_converter_creation() {
        let converter = TrustformersConverter::new();
        assert!(converter.config.validate_shapes);
    }

    #[test]
    fn test_trustformers_converter_with_config() {
        let config = IntegrationConfig::new().with_shape_validation(false);
        let converter = TrustformersConverter::with_config(config);
        assert!(!converter.config.validate_shapes);
    }

    #[test]
    fn test_convert_bert_encoder() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_bert_encoder(6, 512, 8, 2048);
        assert!(expr.is_ok());

        let expr = expr.unwrap();
        match expr {
            TLExpr::ForAll { var, body, .. } => {
                assert_eq!(var, "layer");
                match *body {
                    TLExpr::And(..) => {
                        // Correctly represents composition of attention and FFN
                    }
                    _ => panic!("Expected And"),
                }
            }
            _ => panic!("Expected ForAll"),
        }
    }

    #[test]
    fn test_convert_gpt_decoder() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_gpt_decoder(12, 768, 12, 3072);
        assert!(expr.is_ok());
    }

    #[test]
    fn test_convert_transformer_encoder_only() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_transformer(6, 0, 512, 8, 2048);
        assert!(expr.is_ok());
    }

    #[test]
    fn test_convert_transformer_decoder_only() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_transformer(0, 6, 512, 8, 2048);
        assert!(expr.is_ok());
    }

    #[test]
    fn test_convert_transformer_encoder_decoder() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_transformer(6, 6, 512, 8, 2048);
        assert!(expr.is_ok());

        let expr = expr.unwrap();
        match expr {
            TLExpr::And(..) => {
                // Correctly represents encoder AND decoder composition
            }
            _ => panic!("Expected And"),
        }
    }

    #[test]
    fn test_convert_transformer_invalid_zero_layers() {
        let converter = TrustformersConverter::new();
        let expr = converter.convert_transformer(0, 0, 512, 8, 2048);
        assert!(expr.is_err());
    }

    #[test]
    fn test_convert_bert_invalid_heads() {
        let converter = TrustformersConverter::new();
        // 512 is not divisible by 7
        let expr = converter.convert_bert_encoder(6, 512, 7, 2048);
        assert!(expr.is_err());
    }

    #[test]
    fn test_weight_loader_creation() {
        let loader = TrustformersWeightLoader::new();
        assert!(loader.config.validate_shapes);
    }

    #[test]
    fn test_weight_loader_map_layer_name() {
        let loader = TrustformersWeightLoader::new();

        let mapped = loader
            .map_layer_name("encoder.layer.0.attention.query.weight")
            .unwrap();
        assert_eq!(mapped, "encoder_0_attn_query_weight");

        let mapped = loader
            .map_layer_name("decoder.layer.5.feed_forward.weight")
            .unwrap();
        assert_eq!(mapped, "decoder_5_ffn_weight");
    }

    #[test]
    fn test_checkpoint_data_default() {
        let data = CheckpointData::default();
        assert!(data.weights.is_empty());
        assert!(data.metadata.is_empty());
    }
}