tensorlogic_trustformers/
lib.rs

pub mod attention;
pub mod checkpointing;
pub mod config;
pub mod decoder;
pub mod encoder;
pub mod error;
pub mod ffn;
pub mod flash_attention;
pub mod gqa;
pub mod kv_cache;
pub mod layers;
pub mod lora;
pub mod moe;
pub mod normalization;
pub mod patterns;
pub mod position;
pub mod presets;
pub mod rule_attention;
pub mod sliding_window;
pub mod sparse_attention;
pub mod stacks;
pub mod trustformers_integration;
pub mod utils;
pub mod vision;

pub use attention::{MultiHeadAttention, SelfAttention};
pub use checkpointing::{CheckpointConfig, CheckpointStrategy};
pub use config::{AttentionConfig, FeedForwardConfig, TransformerLayerConfig};
pub use decoder::{Decoder, DecoderConfig};
pub use encoder::{Encoder, EncoderConfig};
pub use error::{Result, TrustformerError};
pub use ffn::{FeedForward, GatedFeedForward};
pub use flash_attention::{
    FlashAttention, FlashAttentionConfig, FlashAttentionPreset, FlashAttentionStats,
    FlashAttentionV2Config,
};
pub use gqa::{GQAConfig, GQAPreset, GQAStats, GroupedQueryAttention};
pub use kv_cache::{CacheStats, KVCache, KVCacheConfig};
pub use layers::{DecoderLayer, DecoderLayerConfig, EncoderLayer, EncoderLayerConfig};
pub use lora::{LoRAAttention, LoRAConfig, LoRALinear, LoRAPreset, LoRAStats};
pub use moe::{MoeConfig, MoeLayer, MoePreset, MoeStats, RouterType};
pub use normalization::{LayerNorm, LayerNormConfig, RMSNorm};
pub use patterns::{
    AttentionMask, BlockSparseMask, CausalMask, GlobalLocalMask, LocalMask, RuleBasedMask,
    RulePattern, StridedMask,
};
pub use position::{
    AlibiPositionEncoding, LearnedPositionEncoding, PositionEncodingConfig, PositionEncodingType,
    RelativePositionEncoding, RotaryPositionEncoding, SinusoidalPositionEncoding,
};
pub use presets::ModelPreset;
pub use rule_attention::{
    RuleAttentionConfig, RuleAttentionType, RuleBasedAttention, StructuredAttention,
};
pub use sliding_window::{
    SlidingWindowAttention, SlidingWindowConfig, SlidingWindowPreset, SlidingWindowStats,
};
pub use sparse_attention::{
    LocalAttention, SparseAttention, SparseAttentionConfig, SparsePatternType,
};
pub use stacks::{DecoderStack, DecoderStackConfig, EncoderStack, EncoderStackConfig};
pub use trustformers_integration::{
    CheckpointData, IntegrationConfig, ModelConfig, TensorLogicModel, TrustformersConverter,
    TrustformersWeightLoader,
};
pub use utils::{decoder_stack_stats, encoder_stack_stats, ModelStats};
pub use vision::{
    PatchEmbedding, PatchEmbeddingConfig, ViTPreset, VisionTransformer, VisionTransformerConfig,
};
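
// Quick-start sketch of the crate's core flow, mirroring the end-to-end tests
// at the bottom of this file: configure an attention module, then lower it
// onto a `tensorlogic_ir::EinsumGraph`. Illustrative only (every call below
// is taken from the tests); wrap it in a function or test before compiling.
//
//     use tensorlogic_ir::EinsumGraph;
//
//     let config = AttentionConfig::new(512, 8).unwrap().with_causal(true);
//     let attn = SelfAttention::new(config).unwrap();
//
//     let mut graph = EinsumGraph::new();
//     graph.add_tensor("Q");
//     graph.add_tensor("K");
//     graph.add_tensor("V");
//
//     let outputs = attn.build_attention_graph(&mut graph).unwrap();
//     assert_eq!(outputs.len(), 1);
//     assert!(graph.validate().is_ok());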

/// Deprecated alias kept for backward compatibility; use [`AttentionConfig`].
#[deprecated(since = "0.1.0", note = "Use AttentionConfig instead")]
pub type AttnSpec = AttentionConfig;

/// Deprecated no-op kept for backward compatibility; build attention graphs
/// with [`SelfAttention::build_attention_graph`] instead.
#[deprecated(
    since = "0.1.0",
    note = "Use SelfAttention::build_attention_graph instead"
)]
pub fn self_attention_as_rules(_spec: &AttentionConfig) {}

#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumGraph;

    #[test]
    fn test_end_to_end_self_attention() {
        let config = AttentionConfig::new(512, 8).unwrap();
        let attn = SelfAttention::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = attn.build_attention_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_end_to_end_multi_head_attention() {
        let config = AttentionConfig::new(512, 8).unwrap();
        let mha = MultiHeadAttention::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("Q");
        graph.add_tensor("K");
        graph.add_tensor("V");

        let outputs = mha.build_mha_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_end_to_end_ffn() {
        let config = FeedForwardConfig::new(512, 2048);
        let ffn = FeedForward::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("W1");
        graph.add_tensor("b1");
        graph.add_tensor("W2");
        graph.add_tensor("b2");

        let outputs = ffn.build_ffn_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_end_to_end_gated_ffn() {
        let config = FeedForwardConfig::new(512, 2048);
        let glu = GatedFeedForward::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("W_gate");
        graph.add_tensor("W_value");
        graph.add_tensor("W_out");

        let outputs = glu.build_glu_graph(&mut graph).unwrap();
        assert_eq!(outputs.len(), 1);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_transformer_layer_config() {
        let config = TransformerLayerConfig::new(512, 8, 2048).unwrap();
        assert_eq!(config.attention.d_model, 512);
        assert_eq!(config.attention.n_heads, 8);
        assert_eq!(config.feed_forward.d_ff, 2048);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_builder_pattern() {
        let config = AttentionConfig::new(512, 8)
            .unwrap()
            .with_causal(true)
            .with_dropout(0.1);

        assert!(config.causal);
        assert!((config.dropout - 0.1).abs() < 1e-10);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_ffn_config_builder() {
        let config = FeedForwardConfig::new(512, 2048)
            .with_activation("relu")
            .with_dropout(0.1);

        assert_eq!(config.activation, "relu");
        assert!((config.dropout - 0.1).abs() < 1e-10);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_invalid_configurations() {
        // 512 is not evenly divisible by 7 heads, so construction should fail.
        let result = AttentionConfig::new(512, 7);
        assert!(result.is_err());

        // 512 splits evenly across 8 heads, so construction should succeed.
        let result = AttentionConfig::new(512, 8);
        assert!(result.is_ok());
    }
}