tensorlogic_trustformers/config.rs

//! Configuration structures for transformer components.

use serde::{Deserialize, Serialize};

use crate::error::{Result, TrustformerError};

/// Configuration for self-attention mechanism
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AttentionConfig {
    /// Model dimension (d_model)
    pub d_model: usize,
    /// Number of attention heads
    pub n_heads: usize,
    /// Dimension per head (d_k = d_model / n_heads)
    pub d_k: usize,
    /// Whether to use causal (autoregressive) masking
    pub causal: bool,
    /// Dropout probability (0.0 = no dropout)
    pub dropout: f64,
}

impl AttentionConfig {
    /// Create a new attention configuration
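    ///
    /// Errors if `d_model` is not evenly divisible by `n_heads`.
    ///
    /// # Examples
    ///
    /// A minimal usage sketch; the import path assumes this file is exposed as a
    /// public `config` module of the `tensorlogic_trustformers` crate.
    ///
    /// ```
    /// // Assumed module path; adjust if `config` is re-exported elsewhere.
    /// use tensorlogic_trustformers::config::AttentionConfig;
    ///
    /// let config = AttentionConfig::new(512, 8)
    ///     .unwrap()
    ///     .with_causal(true)
    ///     .with_dropout(0.1);
    /// assert_eq!(config.d_k, 64); // d_k is derived as d_model / n_heads
    /// assert!(config.validate().is_ok());
    /// ```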
    pub fn new(d_model: usize, n_heads: usize) -> Result<Self> {
        if !d_model.is_multiple_of(n_heads) {
            return Err(TrustformerError::InvalidHeadCount { d_model, n_heads });
        }

        Ok(Self {
            d_model,
            n_heads,
            d_k: d_model / n_heads,
            causal: false,
            dropout: 0.0,
        })
    }

    /// Set causal masking
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    /// Set dropout probability
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Validate configuration
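    ///
    /// Checks that `d_model` and `n_heads` are positive, that `d_model` is a
    /// multiple of `n_heads`, and that `dropout` lies in `[0, 1]`.
    ///
    /// # Examples
    ///
    /// A sketch of catching an out-of-range dropout (import path assumed as above):
    ///
    /// ```
    /// use tensorlogic_trustformers::config::AttentionConfig;
    ///
    /// let config = AttentionConfig::new(512, 8).unwrap().with_dropout(1.5);
    /// assert!(config.validate().is_err()); // dropout outside [0, 1] is rejected
    /// ```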
    pub fn validate(&self) -> Result<()> {
        if self.d_model == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "d_model must be positive".to_string(),
            });
        }

        if self.n_heads == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "n_heads must be positive".to_string(),
            });
        }

        if !self.d_model.is_multiple_of(self.n_heads) {
            return Err(TrustformerError::InvalidHeadCount {
                d_model: self.d_model,
                n_heads: self.n_heads,
            });
        }

        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: format!("dropout must be in [0,1], got {}", self.dropout),
            });
        }

        Ok(())
    }
}

/// Configuration for feed-forward network
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FeedForwardConfig {
    /// Model dimension (d_model)
    pub d_model: usize,
    /// Hidden dimension (typically 4 * d_model)
    pub d_ff: usize,
    /// Activation function name
    pub activation: String,
    /// Dropout probability
    pub dropout: f64,
}

impl FeedForwardConfig {
    /// Create a new feed-forward configuration
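    ///
    /// Defaults to the `"gelu"` activation and zero dropout.
    ///
    /// # Examples
    ///
    /// A usage sketch with the conventional `d_ff = 4 * d_model` sizing
    /// (import path assumed as above):
    ///
    /// ```
    /// use tensorlogic_trustformers::config::FeedForwardConfig;
    ///
    /// let config = FeedForwardConfig::new(512, 2048).with_activation("relu");
    /// assert_eq!(config.activation, "relu");
    /// assert!(config.validate().is_ok());
    /// ```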
    pub fn new(d_model: usize, d_ff: usize) -> Self {
        Self {
            d_model,
            d_ff,
            activation: "gelu".to_string(),
            dropout: 0.0,
        }
    }

    /// Set activation function
    pub fn with_activation(mut self, activation: impl Into<String>) -> Self {
        self.activation = activation.into();
        self
    }

    /// Set dropout probability
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Validate configuration
    pub fn validate(&self) -> Result<()> {
        if self.d_model == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "d_model must be positive".to_string(),
            });
        }

        if self.d_ff == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "d_ff must be positive".to_string(),
            });
        }

        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: format!("dropout must be in [0,1], got {}", self.dropout),
            });
        }

        Ok(())
    }
}

/// Configuration for a complete transformer layer
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct TransformerLayerConfig {
    /// Attention configuration
    pub attention: AttentionConfig,
    /// Feed-forward configuration
    pub feed_forward: FeedForwardConfig,
    /// Whether to use pre-layer normalization (vs post)
    pub pre_norm: bool,
}

impl TransformerLayerConfig {
    /// Create a new transformer layer configuration
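    ///
    /// Defaults to pre-layer normalization (`pre_norm = true`).
    ///
    /// # Examples
    ///
    /// A sketch combining attention and feed-forward settings for one layer
    /// (import path assumed as above):
    ///
    /// ```
    /// use tensorlogic_trustformers::config::TransformerLayerConfig;
    ///
    /// let config = TransformerLayerConfig::new(512, 8, 2048)
    ///     .unwrap()
    ///     .with_pre_norm(false);
    /// assert_eq!(config.attention.d_k, 64);
    /// assert!(config.validate().is_ok());
    /// ```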
    pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
        Ok(Self {
            attention: AttentionConfig::new(d_model, n_heads)?,
            feed_forward: FeedForwardConfig::new(d_model, d_ff),
            pre_norm: true,
        })
    }

    /// Set pre-normalization vs post-normalization
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    /// Validate configuration
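    ///
    /// Validates both sub-configurations and checks that they agree on `d_model`.
    ///
    /// # Examples
    ///
    /// A sketch of the mismatch check (import path assumed as above):
    ///
    /// ```
    /// use tensorlogic_trustformers::config::TransformerLayerConfig;
    ///
    /// let mut config = TransformerLayerConfig::new(512, 8, 2048).unwrap();
    /// config.feed_forward.d_model = 256; // diverge from attention.d_model
    /// assert!(config.validate().is_err());
    /// ```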
    pub fn validate(&self) -> Result<()> {
        self.attention.validate()?;
        self.feed_forward.validate()?;

        if self.attention.d_model != self.feed_forward.d_model {
            return Err(TrustformerError::InvalidDimension {
                expected: self.attention.d_model,
                got: self.feed_forward.d_model,
                context: "d_model mismatch between attention and FFN".to_string(),
            });
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_config_valid() {
        let config = AttentionConfig::new(512, 8).unwrap();
        assert_eq!(config.d_model, 512);
        assert_eq!(config.n_heads, 8);
        assert_eq!(config.d_k, 64);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_attention_config_invalid_heads() {
        let result = AttentionConfig::new(512, 7);
        assert!(result.is_err());
    }

    #[test]
    fn test_attention_config_with_causal() {
        let config = AttentionConfig::new(512, 8).unwrap().with_causal(true);
        assert!(config.causal);
    }

    #[test]
    fn test_attention_config_with_dropout() {
        let config = AttentionConfig::new(512, 8).unwrap().with_dropout(0.1);
        assert!((config.dropout - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_ffn_config() {
        let config = FeedForwardConfig::new(512, 2048);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.d_ff, 2048);
        assert_eq!(config.activation, "gelu");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_ffn_config_with_activation() {
        let config = FeedForwardConfig::new(512, 2048).with_activation("relu");
        assert_eq!(config.activation, "relu");
    }

    #[test]
    fn test_transformer_layer_config() {
        let config = TransformerLayerConfig::new(512, 8, 2048).unwrap();
        assert_eq!(config.attention.d_model, 512);
        assert_eq!(config.feed_forward.d_model, 512);
        assert!(config.pre_norm);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_transformer_layer_config_with_pre_norm() {
        let config = TransformerLayerConfig::new(512, 8, 2048)
            .unwrap()
            .with_pre_norm(false);
        assert!(!config.pre_norm);
    }
}