// ruvector_attention/config.rs

use serde::{Deserialize, Serialize};

use crate::error::{AttentionError, AttentionResult};

/// Core configuration shared by all attention variants.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AttentionConfig {
    /// Total model (embedding) dimension.
    pub dim: usize,
    /// Number of attention heads; must evenly divide `dim`.
    pub num_heads: usize,
    /// Dropout probability used by the attention layer, in `[0.0, 1.0]`.
    pub dropout: f32,
    /// Optional override for the logit scale; defaults to `1 / sqrt(head_dim)`.
    pub scale: Option<f32>,
    /// Whether to apply a causal (lower-triangular) mask.
    pub causal: bool,
}

impl AttentionConfig {
    /// Returns a builder for constructing a validated `AttentionConfig`.
    pub fn builder() -> AttentionConfigBuilder {
        AttentionConfigBuilder::default()
    }

    /// Checks every invariant, returning the first violation found.
    pub fn validate(&self) -> AttentionResult<()> {
        if self.dim == 0 {
            return Err(AttentionError::InvalidConfig(
                "dimension must be greater than 0".to_string(),
            ));
        }

        if self.num_heads == 0 {
            return Err(AttentionError::InvalidConfig(
                "num_heads must be greater than 0".to_string(),
            ));
        }

        if self.dim % self.num_heads != 0 {
            return Err(AttentionError::InvalidHeadCount {
                dim: self.dim,
                num_heads: self.num_heads,
            });
        }

        // The range check via `contains` also rejects NaN, which a plain
        // `< 0.0 || > 1.0` comparison would silently accept.
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(AttentionError::InvalidConfig(
                "dropout must be in range [0.0, 1.0]".to_string(),
            ));
        }

        if let Some(scale) = self.scale {
            if !scale.is_finite() || scale <= 0.0 {
                return Err(AttentionError::InvalidConfig(
                    "scale must be positive and finite".to_string(),
                ));
            }
        }

        Ok(())
    }

    /// Dimension of a single head: `dim / num_heads`.
    #[inline]
    pub fn head_dim(&self) -> usize {
        self.dim / self.num_heads
    }

    /// The scale applied to attention logits: the explicit override if set,
    /// otherwise the conventional `1 / sqrt(head_dim)`.
    #[inline]
    pub fn effective_scale(&self) -> f32 {
        self.scale
            .unwrap_or_else(|| 1.0 / (self.head_dim() as f32).sqrt())
    }
}

/// Builder for [`AttentionConfig`]; `dim` and `num_heads` are required.
#[derive(Default)]
pub struct AttentionConfigBuilder {
    dim: Option<usize>,
    num_heads: Option<usize>,
    dropout: f32,
    scale: Option<f32>,
    causal: bool,
}

impl AttentionConfigBuilder {
    pub fn dim(mut self, dim: usize) -> Self {
        self.dim = Some(dim);
        self
    }

    pub fn num_heads(mut self, num_heads: usize) -> Self {
        self.num_heads = Some(num_heads);
        self
    }

    pub fn dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    pub fn scale(mut self, scale: f32) -> Self {
        self.scale = Some(scale);
        self
    }

    pub fn causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    /// Consumes the builder, validating the assembled configuration.
    pub fn build(self) -> AttentionResult<AttentionConfig> {
        let config = AttentionConfig {
            dim: self.dim.ok_or_else(|| {
                AttentionError::InvalidConfig("dimension must be specified".to_string())
            })?,
            num_heads: self.num_heads.ok_or_else(|| {
                AttentionError::InvalidConfig("num_heads must be specified".to_string())
            })?,
            dropout: self.dropout,
            scale: self.scale,
            causal: self.causal,
        };

        config.validate()?;
        Ok(config)
    }
}

/// Configuration for graph attention (GAT-style) layers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GraphAttentionConfig {
    /// Shared attention settings.
    pub base: AttentionConfig,
    /// Dimension of edge features, if edges carry features.
    pub edge_dim: Option<usize>,
    /// Negative slope of the LeakyReLU applied to attention logits.
    pub negative_slope: f32,
    /// Concatenate head outputs (`true`) or average them (`false`).
    pub concat_heads: bool,
}

impl GraphAttentionConfig {
    /// Returns a builder for constructing a validated `GraphAttentionConfig`.
    pub fn builder() -> GraphAttentionConfigBuilder {
        GraphAttentionConfigBuilder::default()
    }

    /// Checks graph-specific invariants on top of the base validation.
    pub fn validate(&self) -> AttentionResult<()> {
        self.base.validate()?;

        if self.negative_slope <= 0.0 || !self.negative_slope.is_finite() {
            return Err(AttentionError::InvalidConfig(
                "negative_slope must be positive and finite".to_string(),
            ));
        }

        if let Some(edge_dim) = self.edge_dim {
            if edge_dim == 0 {
                return Err(AttentionError::InvalidConfig(
                    "edge_dim must be greater than 0".to_string(),
                ));
            }
        }

        Ok(())
    }
}

#[derive(Default)]
pub struct GraphAttentionConfigBuilder {
    base_builder: AttentionConfigBuilder,
    edge_dim: Option<usize>,
    negative_slope: f32,
    concat_heads: bool,
}

impl GraphAttentionConfigBuilder {
    pub fn dim(mut self, dim: usize) -> Self {
        self.base_builder = self.base_builder.dim(dim);
        self
    }

    pub fn num_heads(mut self, num_heads: usize) -> Self {
        self.base_builder = self.base_builder.num_heads(num_heads);
        self
    }

    pub fn edge_dim(mut self, edge_dim: usize) -> Self {
        self.edge_dim = Some(edge_dim);
        self
    }

    pub fn negative_slope(mut self, slope: f32) -> Self {
        self.negative_slope = slope;
        self
    }

    pub fn concat_heads(mut self, concat: bool) -> Self {
        self.concat_heads = concat;
        self
    }

    pub fn build(self) -> AttentionResult<GraphAttentionConfig> {
        let config = GraphAttentionConfig {
            base: self.base_builder.build()?,
            edge_dim: self.edge_dim,
            // The derived Default leaves negative_slope at 0.0, which would
            // fail validation; fall back to the conventional LeakyReLU slope.
            negative_slope: if self.negative_slope == 0.0 {
                0.2
            } else {
                self.negative_slope
            },
            concat_heads: self.concat_heads,
        };

        config.validate()?;
        Ok(config)
    }
}

/// Configuration for block-sparse attention (BigBird-style patterns).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SparseAttentionConfig {
    /// Shared attention settings.
    pub base: AttentionConfig,
    /// Number of tokens per attention block.
    pub block_size: usize,
    /// Number of random blocks each query block attends to.
    pub num_random_blocks: usize,
    /// Number of tokens that attend globally (and are attended to globally).
    pub num_global_tokens: usize,
}

impl SparseAttentionConfig {
    /// Returns a builder for constructing a validated `SparseAttentionConfig`.
    pub fn builder() -> SparseAttentionConfigBuilder {
        SparseAttentionConfigBuilder::default()
    }

    /// Checks sparsity-specific invariants on top of the base validation.
    pub fn validate(&self) -> AttentionResult<()> {
        self.base.validate()?;

        if self.block_size == 0 {
            return Err(AttentionError::InvalidConfig(
                "block_size must be greater than 0".to_string(),
            ));
        }

        Ok(())
    }
}

#[derive(Default)]
pub struct SparseAttentionConfigBuilder {
    base_builder: AttentionConfigBuilder,
    block_size: usize,
    num_random_blocks: usize,
    num_global_tokens: usize,
}

impl SparseAttentionConfigBuilder {
    pub fn dim(mut self, dim: usize) -> Self {
        self.base_builder = self.base_builder.dim(dim);
        self
    }

    pub fn num_heads(mut self, num_heads: usize) -> Self {
        self.base_builder = self.base_builder.num_heads(num_heads);
        self
    }

    pub fn block_size(mut self, block_size: usize) -> Self {
        self.block_size = block_size;
        self
    }

    pub fn num_random_blocks(mut self, num_random_blocks: usize) -> Self {
        self.num_random_blocks = num_random_blocks;
        self
    }

    pub fn num_global_tokens(mut self, num_global_tokens: usize) -> Self {
        self.num_global_tokens = num_global_tokens;
        self
    }

    pub fn build(self) -> AttentionResult<SparseAttentionConfig> {
        let config = SparseAttentionConfig {
            base: self.base_builder.build()?,
            // The derived Default leaves block_size at 0, which would fail
            // validation; fall back to a default block size of 64.
            block_size: if self.block_size == 0 {
                64
            } else {
                self.block_size
            },
            num_random_blocks: self.num_random_blocks,
            num_global_tokens: self.num_global_tokens,
        };

        config.validate()?;
        Ok(config)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_config_builder() {
        let config = AttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .dropout(0.1)
            .causal(true)
            .build()
            .unwrap();

        assert_eq!(config.dim, 512);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.dropout, 0.1);
        assert!(config.causal);
        assert_eq!(config.head_dim(), 64);
    }
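
    // A small extra check: with no explicit scale, effective_scale() falls
    // back to 1 / sqrt(head_dim). Here head_dim = 512 / 8 = 64, so the
    // default scale is 0.125; an explicit scale overrides it.
    #[test]
    fn test_effective_scale() {
        let config = AttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .build()
            .unwrap();
        assert_eq!(config.effective_scale(), 0.125);

        let config = AttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .scale(2.0)
            .build()
            .unwrap();
        assert_eq!(config.effective_scale(), 2.0);
    }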

    #[test]
    fn test_config_validation() {
        // 512 is not divisible by 7, so validation must reject this config.
        let result = AttentionConfig::builder()
            .dim(512)
            .num_heads(7)
            .build();

        assert!(result.is_err());
    }
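
    // Companion to the NaN-aware range check in validate(): a NaN dropout
    // must be rejected rather than slipping past the bounds comparison.
    #[test]
    fn test_nan_dropout_rejected() {
        let result = AttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .dropout(f32::NAN)
            .build();

        assert!(result.is_err());
    }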

    #[test]
    fn test_graph_attention_config() {
        let config = GraphAttentionConfig::builder()
            .dim(256)
            .num_heads(4)
            .edge_dim(16)
            .negative_slope(0.2)
            .concat_heads(true)
            .build()
            .unwrap();

        assert_eq!(config.base.dim, 256);
        assert_eq!(config.edge_dim, Some(16));
        assert!(config.concat_heads);
    }
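
    // When negative_slope is left unset, build() substitutes the
    // conventional LeakyReLU slope of 0.2 rather than the 0.0 produced
    // by the derived Default, which would fail validation.
    #[test]
    fn test_graph_negative_slope_default() {
        let config = GraphAttentionConfig::builder()
            .dim(256)
            .num_heads(4)
            .build()
            .unwrap();

        assert_eq!(config.negative_slope, 0.2);
    }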

    #[test]
    fn test_sparse_attention_config() {
        let config = SparseAttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .block_size(64)
            .num_random_blocks(3)
            .num_global_tokens(64)
            .build()
            .unwrap();

        assert_eq!(config.base.dim, 512);
        assert_eq!(config.block_size, 64);
        assert_eq!(config.num_random_blocks, 3);
    }
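
    // When block_size is left unset, build() substitutes the default of 64
    // rather than the 0 produced by the derived Default, which would fail
    // validation.
    #[test]
    fn test_sparse_block_size_default() {
        let config = SparseAttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .build()
            .unwrap();

        assert_eq!(config.block_size, 64);
    }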
}