ruvector_attention/config.rs

//! Configuration types for attention mechanisms.
//!
//! This module provides configuration structs and builders for various
//! attention mechanisms including standard, graph, and sparse attention.
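//!
//! # Example
//!
//! A minimal sketch of building a config (the import path assumes this module
//! is exposed as `ruvector_attention::config`):
//!
//! ```
//! use ruvector_attention::config::AttentionConfig;
//!
//! let config = AttentionConfig::builder()
//!     .dim(512)
//!     .num_heads(8)
//!     .dropout(0.1)
//!     .causal(true)
//!     .build()
//!     .expect("valid config");
//!
//! assert_eq!(config.head_dim(), 64);
//! ```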

use serde::{Deserialize, Serialize};

use crate::error::{AttentionError, AttentionResult};

/// Configuration for standard attention mechanisms.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AttentionConfig {
    /// Model dimension (d_model)
    pub dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Dropout probability (0.0 to 1.0)
    pub dropout: f32,
    /// Scaling factor (default: 1/sqrt(d_k))
    pub scale: Option<f32>,
    /// Whether to use causal masking
    pub causal: bool,
}

impl AttentionConfig {
    /// Creates a new builder for AttentionConfig.
    pub fn builder() -> AttentionConfigBuilder {
        AttentionConfigBuilder::default()
    }

    /// Validates the configuration.
    pub fn validate(&self) -> AttentionResult<()> {
        if self.dim == 0 {
            return Err(AttentionError::InvalidConfig(
                "dimension must be greater than 0".to_string(),
            ));
        }

        if self.num_heads == 0 {
            return Err(AttentionError::InvalidConfig(
                "num_heads must be greater than 0".to_string(),
            ));
        }

        if self.dim % self.num_heads != 0 {
            return Err(AttentionError::InvalidHeadCount {
                dim: self.dim,
                num_heads: self.num_heads,
            });
        }

        // Reject NaN as well as out-of-range values; plain `<`/`>` comparisons
        // would let NaN slip through.
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(AttentionError::InvalidConfig(
                "dropout must be in range [0.0, 1.0]".to_string(),
            ));
        }

        if let Some(scale) = self.scale {
            if !scale.is_finite() || scale <= 0.0 {
                return Err(AttentionError::InvalidConfig(
                    "scale must be positive and finite".to_string(),
                ));
            }
        }

        Ok(())
    }

    /// Returns the dimension per head (d_k).
    #[inline]
    pub fn head_dim(&self) -> usize {
        self.dim / self.num_heads
    }

    /// Returns the effective scale factor.
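    ///
    /// Falls back to `1.0 / sqrt(head_dim)` when no explicit scale was set,
    /// e.g. `0.125` for `dim = 512` and `num_heads = 8`.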
    #[inline]
    pub fn effective_scale(&self) -> f32 {
        self.scale
            .unwrap_or_else(|| 1.0 / (self.head_dim() as f32).sqrt())
    }
}

/// Builder for AttentionConfig.
#[derive(Default)]
pub struct AttentionConfigBuilder {
    dim: Option<usize>,
    num_heads: Option<usize>,
    dropout: f32,
    scale: Option<f32>,
    causal: bool,
}

impl AttentionConfigBuilder {
    /// Sets the model dimension.
    pub fn dim(mut self, dim: usize) -> Self {
        self.dim = Some(dim);
        self
    }

    /// Sets the number of attention heads.
    pub fn num_heads(mut self, num_heads: usize) -> Self {
        self.num_heads = Some(num_heads);
        self
    }

    /// Sets the dropout probability.
    pub fn dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// Sets a custom scale factor.
    pub fn scale(mut self, scale: f32) -> Self {
        self.scale = Some(scale);
        self
    }

    /// Sets whether to use causal masking.
    pub fn causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    /// Builds the AttentionConfig.
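    ///
    /// Returns an error if `dim` or `num_heads` was not set, or if the
    /// resulting configuration fails validation.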
    pub fn build(self) -> AttentionResult<AttentionConfig> {
        let config = AttentionConfig {
            dim: self.dim.ok_or_else(|| {
                AttentionError::InvalidConfig("dimension must be specified".to_string())
            })?,
            num_heads: self.num_heads.ok_or_else(|| {
                AttentionError::InvalidConfig("num_heads must be specified".to_string())
            })?,
            dropout: self.dropout,
            scale: self.scale,
            causal: self.causal,
        };

        config.validate()?;
        Ok(config)
    }
}

/// Configuration for graph attention networks.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GraphAttentionConfig {
    /// Base attention configuration
    pub base: AttentionConfig,
    /// Edge feature dimension (if using edge features)
    pub edge_dim: Option<usize>,
    /// Negative slope for LeakyReLU
    pub negative_slope: f32,
    /// Whether to concatenate multi-head outputs (vs averaging)
    pub concat_heads: bool,
}

impl GraphAttentionConfig {
    /// Creates a new builder for GraphAttentionConfig.
    pub fn builder() -> GraphAttentionConfigBuilder {
        GraphAttentionConfigBuilder::default()
    }

    /// Validates the configuration.
    pub fn validate(&self) -> AttentionResult<()> {
        self.base.validate()?;

        if self.negative_slope <= 0.0 || !self.negative_slope.is_finite() {
            return Err(AttentionError::InvalidConfig(
                "negative_slope must be positive and finite".to_string(),
            ));
        }

        if let Some(edge_dim) = self.edge_dim {
            if edge_dim == 0 {
                return Err(AttentionError::InvalidConfig(
                    "edge_dim must be greater than 0".to_string(),
                ));
            }
        }

        Ok(())
    }
}

/// Builder for GraphAttentionConfig.
#[derive(Default)]
pub struct GraphAttentionConfigBuilder {
    base_builder: AttentionConfigBuilder,
    edge_dim: Option<usize>,
    negative_slope: f32,
    concat_heads: bool,
}

impl GraphAttentionConfigBuilder {
    /// Sets the model dimension.
    pub fn dim(mut self, dim: usize) -> Self {
        self.base_builder = self.base_builder.dim(dim);
        self
    }

    /// Sets the number of attention heads.
    pub fn num_heads(mut self, num_heads: usize) -> Self {
        self.base_builder = self.base_builder.num_heads(num_heads);
        self
    }

    /// Sets the edge feature dimension.
    pub fn edge_dim(mut self, edge_dim: usize) -> Self {
        self.edge_dim = Some(edge_dim);
        self
    }

    /// Sets the negative slope for LeakyReLU.
    pub fn negative_slope(mut self, slope: f32) -> Self {
        self.negative_slope = slope;
        self
    }

    /// Sets whether to concatenate multi-head outputs.
    pub fn concat_heads(mut self, concat: bool) -> Self {
        self.concat_heads = concat;
        self
    }

    /// Builds the GraphAttentionConfig.
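    ///
    /// A `negative_slope` of `0.0` (the builder default) is replaced with the
    /// conventional LeakyReLU slope of `0.2` before validation.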
    pub fn build(self) -> AttentionResult<GraphAttentionConfig> {
        let config = GraphAttentionConfig {
            base: self.base_builder.build()?,
            edge_dim: self.edge_dim,
            negative_slope: if self.negative_slope == 0.0 {
                0.2
            } else {
                self.negative_slope
            },
            concat_heads: self.concat_heads,
        };

        config.validate()?;
        Ok(config)
    }
}

/// Configuration for sparse attention mechanisms.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SparseAttentionConfig {
    /// Base attention configuration
    pub base: AttentionConfig,
    /// Block size for block-sparse attention
    pub block_size: usize,
    /// Number of random blocks per query
    pub num_random_blocks: usize,
    /// Number of global tokens
    pub num_global_tokens: usize,
}

impl SparseAttentionConfig {
    /// Creates a new builder for SparseAttentionConfig.
    pub fn builder() -> SparseAttentionConfigBuilder {
        SparseAttentionConfigBuilder::default()
    }

    /// Validates the configuration.
    pub fn validate(&self) -> AttentionResult<()> {
        self.base.validate()?;

        if self.block_size == 0 {
            return Err(AttentionError::InvalidConfig(
                "block_size must be greater than 0".to_string(),
            ));
        }

        Ok(())
    }
}

/// Builder for SparseAttentionConfig.
#[derive(Default)]
pub struct SparseAttentionConfigBuilder {
    base_builder: AttentionConfigBuilder,
    block_size: usize,
    num_random_blocks: usize,
    num_global_tokens: usize,
}

impl SparseAttentionConfigBuilder {
    /// Sets the model dimension.
    pub fn dim(mut self, dim: usize) -> Self {
        self.base_builder = self.base_builder.dim(dim);
        self
    }

    /// Sets the number of attention heads.
    pub fn num_heads(mut self, num_heads: usize) -> Self {
        self.base_builder = self.base_builder.num_heads(num_heads);
        self
    }

    /// Sets the block size.
    pub fn block_size(mut self, block_size: usize) -> Self {
        self.block_size = block_size;
        self
    }

    /// Sets the number of random blocks.
    pub fn num_random_blocks(mut self, num_random_blocks: usize) -> Self {
        self.num_random_blocks = num_random_blocks;
        self
    }

    /// Sets the number of global tokens.
    pub fn num_global_tokens(mut self, num_global_tokens: usize) -> Self {
        self.num_global_tokens = num_global_tokens;
        self
    }

    /// Builds the SparseAttentionConfig.
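    ///
    /// A `block_size` of `0` (the builder default) is replaced with `64`
    /// before validation.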
    pub fn build(self) -> AttentionResult<SparseAttentionConfig> {
        let config = SparseAttentionConfig {
            base: self.base_builder.build()?,
            block_size: if self.block_size == 0 {
                64
            } else {
                self.block_size
            },
            num_random_blocks: self.num_random_blocks,
            num_global_tokens: self.num_global_tokens,
        };

        config.validate()?;
        Ok(config)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_config_builder() {
        let config = AttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .dropout(0.1)
            .causal(true)
            .build()
            .unwrap();

        assert_eq!(config.dim, 512);
        assert_eq!(config.num_heads, 8);
        assert_eq!(config.dropout, 0.1);
        assert!(config.causal);
        assert_eq!(config.head_dim(), 64);
    }

    #[test]
    fn test_config_validation() {
        let result = AttentionConfig::builder()
            .dim(512)
            .num_heads(7) // Not divisible
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn test_graph_attention_config() {
        let config = GraphAttentionConfig::builder()
            .dim(256)
            .num_heads(4)
            .edge_dim(16)
            .negative_slope(0.2)
            .concat_heads(true)
            .build()
            .unwrap();

        assert_eq!(config.base.dim, 256);
        assert_eq!(config.edge_dim, Some(16));
        assert!(config.concat_heads);
    }

    #[test]
    fn test_sparse_attention_config() {
        let config = SparseAttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .block_size(64)
            .num_random_blocks(3)
            .num_global_tokens(64)
            .build()
            .unwrap();

        assert_eq!(config.base.dim, 512);
        assert_eq!(config.block_size, 64);
        assert_eq!(config.num_random_blocks, 3);
    }
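
    // Illustrative checks for behavior implied by the builders above: the
    // default scale of 1/sqrt(d_k) and the fallback defaults substituted by
    // `build()`.
    #[test]
    fn test_builder_defaults() {
        let config = AttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .build()
            .unwrap();
        // No explicit scale: effective_scale falls back to 1/sqrt(64) = 0.125.
        assert_eq!(config.effective_scale(), 0.125);

        let graph = GraphAttentionConfig::builder()
            .dim(256)
            .num_heads(4)
            .build()
            .unwrap();
        // negative_slope defaults to the conventional LeakyReLU slope of 0.2.
        assert_eq!(graph.negative_slope, 0.2);

        let sparse = SparseAttentionConfig::builder()
            .dim(512)
            .num_heads(8)
            .build()
            .unwrap();
        // block_size defaults to 64 when left unset.
        assert_eq!(sparse.block_size, 64);
    }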
}