//! # Tensorlogic-Trustformers
//!
//! **Version**: 0.1.0-beta.1 | **Status**: Production Ready
//!
//! Transform transformer architectures into TensorLogic IR using einsum operations.
//!
//! This crate provides implementations of transformer components (self-attention,
//! multi-head attention, feed-forward networks) as einsum graphs that can be
//! compiled and executed on various TensorLogic backends.
//!
//! ## Features
//!
//! - **Self-Attention**: Scaled dot-product attention as einsum operations
//! - **Multi-Head Attention**: Parallel attention heads with head splitting
//! - **Feed-Forward Networks**: Position-wise FFN with configurable activations
//! - **Gated FFN**: GLU-style gated feed-forward networks
//! - **Einsum-Native**: All operations expressed as einsum for maximum flexibility
//!
//! ## Architecture
//!
//! Transformer components are decomposed into einsum operations:
//!
//! ### Self-Attention
//! ```text
//! scores = einsum("bqd,bkd->bqk", Q, K) / sqrt(d_k)
//! attn = softmax(scores, dim=-1)
//! output = einsum("bqk,bkv->bqv", attn, V)
//! ```
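//!
//! The index letters spell out the tensor shapes involved (b = batch,
//! q = query positions, k = key positions, d = query/key depth, v = value depth):
//!
//! ```text
//! Q: [b, q, d]    K: [b, k, d]    V: [b, k, v]
//! scores, attn: [b, q, k]         output: [b, q, v]
//! ```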
//!
//! ### Multi-Head Attention
//! ```text
//! Q, K, V = [batch, seq, d_model] -> [batch, n_heads, seq, d_k]
//! scores = einsum("bhqd,bhkd->bhqk", Q, K) / sqrt(d_k)
//! attn = softmax(scores, dim=-1)
//! output = einsum("bhqk,bhkv->bhqv", attn, V)
//! output = reshape([batch, seq, d_model])
//! ```
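//!
//! Each head attends over `d_k = d_model / n_heads` dimensions:
//!
//! ```text
//! d_k = d_model / n_heads        e.g. 512 / 8 = 64
//! ```
//!
//! `d_model` must be evenly divisible by `n_heads`; for example,
//! `AttentionConfig::new(512, 7)` returns an error while
//! `AttentionConfig::new(512, 8)` succeeds.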
//!
//! ### Feed-Forward Network
//! ```text
//! h1 = einsum("bsd,df->bsf", x, W1) + b1
//! h2 = activation(h1)
//! output = einsum("bsf,fd->bsd", h2, W2) + b2
//! ```
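//!
//! ### Gated FFN
//!
//! A sketch of the GLU-style variant, using the tensor names from the
//! `GatedFeedForward` example below (see the `ffn` module for the exact
//! formulation):
//!
//! ```text
//! gate   = activation(einsum("bsd,df->bsf", x, W_gate))
//! value  = einsum("bsd,df->bsf", x, W_value)
//! output = einsum("bsf,fd->bsd", gate * value, W_out)
//! ```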
//!
//! ## Example Usage
//!
//! ```rust
//! use tensorlogic_trustformers::{
//!     AttentionConfig, SelfAttention, MultiHeadAttention,
//!     FeedForwardConfig, FeedForward,
//! };
//! use tensorlogic_ir::EinsumGraph;
//!
//! // Configure self-attention
//! let attn_config = AttentionConfig::new(512, 8).unwrap();
//! let self_attn = SelfAttention::new(attn_config.clone()).unwrap();
//!
//! // Build einsum graph
//! let mut graph = EinsumGraph::new();
//! graph.add_tensor("Q");
//! graph.add_tensor("K");
//! graph.add_tensor("V");
//!
//! let outputs = self_attn.build_attention_graph(&mut graph).unwrap();
//!
//! // Configure multi-head attention
//! let mha = MultiHeadAttention::new(attn_config).unwrap();
//! let mut mha_graph = EinsumGraph::new();
//! mha_graph.add_tensor("Q");
//! mha_graph.add_tensor("K");
//! mha_graph.add_tensor("V");
//!
//! let mha_outputs = mha.build_mha_graph(&mut mha_graph).unwrap();
//!
//! // Configure feed-forward network
//! let ffn_config = FeedForwardConfig::new(512, 2048)
//!     .with_activation("gelu");
//! let ffn = FeedForward::new(ffn_config).unwrap();
//!
//! let mut ffn_graph = EinsumGraph::new();
//! ffn_graph.add_tensor("x");
//! ffn_graph.add_tensor("W1");
//! ffn_graph.add_tensor("b1");
//! ffn_graph.add_tensor("W2");
//! ffn_graph.add_tensor("b2");
//!
//! let ffn_outputs = ffn.build_ffn_graph(&mut ffn_graph).unwrap();
//! ```
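//!
//! The GLU-style gated variant follows the same pattern; this mirrors the
//! crate's own tests (tensor names are illustrative):
//!
//! ```rust
//! use tensorlogic_trustformers::{FeedForwardConfig, GatedFeedForward};
//! use tensorlogic_ir::EinsumGraph;
//!
//! // Gated feed-forward uses separate gate, value, and output projections.
//! let glu_config = FeedForwardConfig::new(512, 2048);
//! let glu = GatedFeedForward::new(glu_config).unwrap();
//!
//! let mut glu_graph = EinsumGraph::new();
//! glu_graph.add_tensor("x");
//! glu_graph.add_tensor("W_gate");
//! glu_graph.add_tensor("W_value");
//! glu_graph.add_tensor("W_out");
//!
//! let glu_outputs = glu.build_glu_graph(&mut glu_graph).unwrap();
//! ```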
//!
//! ## Integration with TensorLogic
//!
//! The einsum graphs produced by this crate can be:
//! - Compiled with `tensorlogic-compiler`
//! - Executed on `tensorlogic-scirs-backend` or other backends
//! - Optimized using graph optimization passes
//! - Combined with logical rules for interpretable transformers
//!
//! ## Design Philosophy
//!
//! This crate follows the TensorLogic principle of expressing neural operations
//! as tensor contractions (einsum), enabling:
//!
//! 1. **Backend Independence**: Same graph works on CPU, GPU, TPU
//! 2. **Optimization Opportunities**: Graph-level optimizations like fusion
//! 3. **Interpretability**: Clear mathematical semantics
//! 4. **Composability**: Mix transformer layers with logical rules

pub mod attention;
pub mod checkpointing;
pub mod config;
pub mod decoder;
pub mod encoder;
pub mod error;
pub mod ffn;
pub mod flash_attention;
pub mod gqa;
pub mod kv_cache;
pub mod layers;
pub mod lora;
pub mod moe;
pub mod normalization;
pub mod patterns;
pub mod position;
pub mod presets;
pub mod rule_attention;
pub mod sliding_window;
pub mod sparse_attention;
pub mod stacks;
pub mod trustformers_integration;
pub mod utils;
pub mod vision;

// Re-export main types for convenient access
pub use attention::{MultiHeadAttention, SelfAttention};
pub use checkpointing::{CheckpointConfig, CheckpointStrategy};
pub use config::{AttentionConfig, FeedForwardConfig, TransformerLayerConfig};
pub use decoder::{Decoder, DecoderConfig};
pub use encoder::{Encoder, EncoderConfig};
pub use error::{Result, TrustformerError};
pub use ffn::{FeedForward, GatedFeedForward};
pub use flash_attention::{
    FlashAttention, FlashAttentionConfig, FlashAttentionPreset, FlashAttentionStats,
    FlashAttentionV2Config,
};
pub use gqa::{GQAConfig, GQAPreset, GQAStats, GroupedQueryAttention};
pub use kv_cache::{CacheStats, KVCache, KVCacheConfig};
pub use layers::{DecoderLayer, DecoderLayerConfig, EncoderLayer, EncoderLayerConfig};
pub use lora::{LoRAAttention, LoRAConfig, LoRALinear, LoRAPreset, LoRAStats};
pub use moe::{MoeConfig, MoeLayer, MoePreset, MoeStats, RouterType};
pub use normalization::{LayerNorm, LayerNormConfig, RMSNorm};
pub use patterns::{
    AttentionMask, BlockSparseMask, CausalMask, GlobalLocalMask, LocalMask, RuleBasedMask,
    RulePattern, StridedMask,
};
pub use position::{
    AlibiPositionEncoding, LearnedPositionEncoding, PositionEncodingConfig, PositionEncodingType,
    RelativePositionEncoding, RotaryPositionEncoding, SinusoidalPositionEncoding,
};
pub use presets::ModelPreset;
pub use rule_attention::{
    RuleAttentionConfig, RuleAttentionType, RuleBasedAttention, StructuredAttention,
};
pub use sliding_window::{
    SlidingWindowAttention, SlidingWindowConfig, SlidingWindowPreset, SlidingWindowStats,
};
pub use sparse_attention::{
    LocalAttention, SparseAttention, SparseAttentionConfig, SparsePatternType,
};
pub use stacks::{DecoderStack, DecoderStackConfig, EncoderStack, EncoderStackConfig};
pub use trustformers_integration::{
    CheckpointData, IntegrationConfig, ModelConfig, TensorLogicModel, TrustformersConverter,
    TrustformersWeightLoader,
};
pub use utils::{decoder_stack_stats, encoder_stack_stats, ModelStats};
pub use vision::{
    PatchEmbedding, PatchEmbeddingConfig, ViTPreset, VisionTransformer, VisionTransformerConfig,
};

// Legacy items kept for backward compatibility
#[deprecated(since = "0.1.0", note = "Use AttentionConfig instead")]
pub type AttnSpec = AttentionConfig;

#[deprecated(
    since = "0.1.0",
    note = "Use SelfAttention::build_attention_graph instead"
)]
pub fn self_attention_as_rules(_spec: &AttentionConfig) {
    // Legacy no-op; use SelfAttention::build_attention_graph instead
}

#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumGraph;

    #[test]
    fn test_end_to_end_self_attention() {
        let config = AttentionConfig::new(512, 8).unwrap();
        let attn = SelfAttention::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = attn.build_attention_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_end_to_end_multi_head_attention() {
        let config = AttentionConfig::new(512, 8).unwrap();
        let mha = MultiHeadAttention::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = mha.build_mha_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_end_to_end_ffn() {
        let config = FeedForwardConfig::new(512, 2048);
        let ffn = FeedForward::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("W1");
        graph.add_tensor("b1");
        graph.add_tensor("W2");
        graph.add_tensor("b2");

        let outputs = ffn.build_ffn_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_end_to_end_gated_ffn() {
        let config = FeedForwardConfig::new(512, 2048);
        let glu = GatedFeedForward::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("W_gate");
        graph.add_tensor("W_value");
        graph.add_tensor("W_out");

        let outputs = glu.build_glu_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_transformer_layer_config() {
        let config = TransformerLayerConfig::new(512, 8, 2048).unwrap();
        assert_eq!(config.attention.d_model, 512);
        assert_eq!(config.attention.n_heads, 8);
        assert_eq!(config.feed_forward.d_ff, 2048);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_builder_pattern() {
        let config = AttentionConfig::new(512, 8)
            .unwrap()
            .with_causal(true)
            .with_dropout(0.1);

        assert!(config.causal);
        assert!((config.dropout - 0.1).abs() < 1e-10);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_ffn_config_builder() {
        let config = FeedForwardConfig::new(512, 2048)
            .with_activation("relu")
            .with_dropout(0.1);

        assert_eq!(config.activation, "relu");
        assert!((config.dropout - 0.1).abs() < 1e-10);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_invalid_configurations() {
        // Invalid head count
        let result = AttentionConfig::new(512, 7);
        assert!(result.is_err());

        // Valid configuration
        let result = AttentionConfig::new(512, 8);
        assert!(result.is_ok());
    }
}