tensorlogic_trustformers/config.rs

use serde::{Deserialize, Serialize};

use crate::error::{Result, TrustformerError};
/// Configuration for a multi-head attention block.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AttentionConfig {
    /// Model (embedding) dimension.
    pub d_model: usize,
    /// Number of attention heads.
    pub n_heads: usize,
    /// Per-head dimension, `d_model / n_heads`.
    pub d_k: usize,
    /// Apply a causal (autoregressive) mask.
    pub causal: bool,
    /// Dropout probability, in `[0, 1]`.
    pub dropout: f64,
}

impl AttentionConfig {
    /// Creates a new attention configuration, deriving `d_k = d_model / n_heads`.
    pub fn new(d_model: usize, n_heads: usize) -> Result<Self> {
        // `0.is_multiple_of(0)` is true, so reject n_heads == 0 explicitly to avoid
        // a divide-by-zero panic when computing d_k for `new(0, 0)`.
        if n_heads == 0 || !d_model.is_multiple_of(n_heads) {
            return Err(TrustformerError::InvalidHeadCount { d_model, n_heads });
        }

        Ok(Self {
            d_model,
            n_heads,
            d_k: d_model / n_heads,
            causal: false,
            dropout: 0.0,
        })
    }

    /// Enables or disables causal masking (builder style).
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    /// Sets the dropout probability (builder style); range-checked by `validate`.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Checks that all fields satisfy the configuration invariants.
    pub fn validate(&self) -> Result<()> {
        if self.d_model == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "d_model must be positive".to_string(),
            });
        }

        if self.n_heads == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "n_heads must be positive".to_string(),
            });
        }

        if !self.d_model.is_multiple_of(self.n_heads) {
            return Err(TrustformerError::InvalidHeadCount {
                d_model: self.d_model,
                n_heads: self.n_heads,
            });
        }

        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: format!("dropout must be in [0,1], got {}", self.dropout),
            });
        }

        Ok(())
    }
}

/// Configuration for a position-wise feed-forward block.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FeedForwardConfig {
    /// Model (embedding) dimension.
    pub d_model: usize,
    /// Hidden (inner) dimension of the feed-forward layer.
    pub d_ff: usize,
    /// Activation function name, e.g. "gelu" or "relu".
    pub activation: String,
    /// Dropout probability, in `[0, 1]`.
    pub dropout: f64,
}

impl FeedForwardConfig {
    /// Creates a new feed-forward configuration with the default "gelu" activation.
    pub fn new(d_model: usize, d_ff: usize) -> Self {
        Self {
            d_model,
            d_ff,
            activation: "gelu".to_string(),
            dropout: 0.0,
        }
    }

    /// Sets the activation function name (builder style).
    pub fn with_activation(mut self, activation: impl Into<String>) -> Self {
        self.activation = activation.into();
        self
    }

    /// Sets the dropout probability (builder style); range-checked by `validate`.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Checks that all fields satisfy the configuration invariants.
    /// Note: the `activation` string itself is not validated here.
    pub fn validate(&self) -> Result<()> {
        if self.d_model == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "d_model must be positive".to_string(),
            });
        }

        if self.d_ff == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "d_ff must be positive".to_string(),
            });
        }

        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: format!("dropout must be in [0,1], got {}", self.dropout),
            });
        }

        Ok(())
    }
}

/// Configuration for a full transformer layer (attention + feed-forward).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct TransformerLayerConfig {
    /// Attention sub-layer configuration.
    pub attention: AttentionConfig,
    /// Feed-forward sub-layer configuration.
    pub feed_forward: FeedForwardConfig,
    /// Apply normalization before each sub-layer (pre-norm) rather than after (post-norm).
    pub pre_norm: bool,
}

impl TransformerLayerConfig {
    /// Creates a layer configuration with pre-norm enabled by default.
    pub fn new(d_model: usize, n_heads: usize, d_ff: usize) -> Result<Self> {
        Ok(Self {
            attention: AttentionConfig::new(d_model, n_heads)?,
            feed_forward: FeedForwardConfig::new(d_model, d_ff),
            pre_norm: true,
        })
    }

    /// Enables or disables pre-norm (builder style).
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.pre_norm = pre_norm;
        self
    }

    /// Validates both sub-configurations and the cross-field `d_model` agreement.
    pub fn validate(&self) -> Result<()> {
        self.attention.validate()?;
        self.feed_forward.validate()?;

        if self.attention.d_model != self.feed_forward.d_model {
            return Err(TrustformerError::InvalidDimension {
                expected: self.attention.d_model,
                got: self.feed_forward.d_model,
                context: "d_model mismatch between attention and FFN".to_string(),
            });
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_config_valid() {
        let config = AttentionConfig::new(512, 8).unwrap();
        assert_eq!(config.d_model, 512);
        assert_eq!(config.n_heads, 8);
        assert_eq!(config.d_k, 64);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_attention_config_invalid_heads() {
        let result = AttentionConfig::new(512, 7);
        assert!(result.is_err());
    }

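    // Added test (not in the original file): sketches the dropout range check in
    // `validate`; the `with_dropout` builder itself does not reject out-of-range values.
    #[test]
    fn test_attention_config_invalid_dropout() {
        let config = AttentionConfig::new(512, 8).unwrap().with_dropout(1.5);
        assert!(config.validate().is_err());
    }
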
    #[test]
    fn test_attention_config_with_causal() {
        let config = AttentionConfig::new(512, 8).unwrap().with_causal(true);
        assert!(config.causal);
    }

    #[test]
    fn test_attention_config_with_dropout() {
        let config = AttentionConfig::new(512, 8).unwrap().with_dropout(0.1);
        assert!((config.dropout - 0.1).abs() < 1e-10);
    }

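    // Added test (not in the original file): a round-trip sketch for the derived
    // Serialize/Deserialize impls. Assumes `serde_json` is available as a
    // dev-dependency, which this file alone does not confirm.
    #[test]
    fn test_attention_config_serde_round_trip() {
        let config = AttentionConfig::new(512, 8).unwrap().with_causal(true);
        let json = serde_json::to_string(&config).unwrap();
        let decoded: AttentionConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, decoded);
    }
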
    #[test]
    fn test_ffn_config() {
        let config = FeedForwardConfig::new(512, 2048);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.d_ff, 2048);
        assert_eq!(config.activation, "gelu");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_ffn_config_with_activation() {
        let config = FeedForwardConfig::new(512, 2048).with_activation("relu");
        assert_eq!(config.activation, "relu");
    }

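    // Added test (not in the original file): `FeedForwardConfig::new` accepts any
    // dimensions, so a zero-width hidden layer is only caught by `validate`.
    #[test]
    fn test_ffn_config_invalid_d_ff() {
        let config = FeedForwardConfig::new(512, 0);
        assert!(config.validate().is_err());
    }
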
    #[test]
    fn test_transformer_layer_config() {
        let config = TransformerLayerConfig::new(512, 8, 2048).unwrap();
        assert_eq!(config.attention.d_model, 512);
        assert_eq!(config.feed_forward.d_model, 512);
        assert!(config.pre_norm);
        assert!(config.validate().is_ok());
    }

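    // Added test (not in the original file): exercises the cross-field d_model check
    // in `TransformerLayerConfig::validate`. The struct is built by hand because
    // `new` always keeps the two d_model values in sync.
    #[test]
    fn test_transformer_layer_config_d_model_mismatch() {
        let config = TransformerLayerConfig {
            attention: AttentionConfig::new(512, 8).unwrap(),
            feed_forward: FeedForwardConfig::new(256, 1024),
            pre_norm: true,
        };
        assert!(config.validate().is_err());
    }
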
    #[test]
    fn test_transformer_layer_config_with_pre_norm() {
        let config = TransformerLayerConfig::new(512, 8, 2048)
            .unwrap()
            .with_pre_norm(false);
        assert!(!config.pre_norm);
    }
}