ruvector_attention/traits.rs

//! Trait definitions for attention mechanisms.
//!
//! This module defines the core traits that all attention mechanisms implement,
//! including standard attention, graph attention, geometric attention, sparse
//! attention, and trainable attention with backward pass support.

use crate::error::AttentionResult;

/// Mask for sparse attention patterns.
#[derive(Clone, Debug)]
pub struct SparseMask {
    /// Row indices for sparse mask
    pub rows: Vec<usize>,
    /// Column indices for sparse mask
    pub cols: Vec<usize>,
    /// Optional values (if not provided, defaults to 1.0)
    pub values: Option<Vec<f32>>,
}

/// Edge information for graph attention.
#[derive(Clone, Debug)]
pub struct EdgeInfo {
    /// Source node index
    pub src: usize,
    /// Destination node index
    pub dst: usize,
    /// Optional edge features
    pub features: Option<Vec<f32>>,
}

/// Core attention mechanism trait.
///
/// Implements the basic attention computation: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
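///
/// # Examples
///
/// A minimal usage sketch driving the trait through a generic helper; the
/// import path is an assumption and the snippet is not compiled as a doctest.
///
/// ```ignore
/// use ruvector_attention::traits::Attention;
///
/// // Attends one query over two key/value pairs of dimension `attn.dim()`.
/// fn attend_pair<A: Attention>(attn: &A, query: &[f32]) -> Vec<f32> {
///     let k0 = vec![0.2_f32; attn.dim()];
///     let k1 = vec![0.7_f32; attn.dim()];
///     let v0 = vec![1.0_f32; attn.dim()];
///     let v1 = vec![0.0_f32; attn.dim()];
///
///     let keys = vec![k0.as_slice(), k1.as_slice()];
///     let values = vec![v0.as_slice(), v1.as_slice()];
///
///     // softmax(q·k / sqrt(d_k))-weighted mix of v0 and v1, shape [d_model].
///     attn.compute(query, &keys, &values).expect("attention computation failed")
/// }
/// ```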
pub trait Attention: Send + Sync {
    /// Computes attention over the given query, keys, and values.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector of shape [d_model]
    /// * `keys` - Slice of key vectors, each of shape [d_model]
    /// * `values` - Slice of value vectors, each of shape [d_model]
    ///
    /// # Returns
    ///
    /// Output vector of shape [d_model]
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>>;

    /// Computes attention with optional mask.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector of shape [d_model]
    /// * `keys` - Slice of key vectors, each of shape [d_model]
    /// * `values` - Slice of value vectors, each of shape [d_model]
    /// * `mask` - Optional attention mask (true = attend, false = mask out)
    ///
    /// # Returns
    ///
    /// Output vector of shape [d_model]
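    ///
    /// # Examples
    ///
    /// A sketch of masking out the middle of three key/value positions; the
    /// `attn`, `query`, `keys`, and `values` bindings are assumed to exist and
    /// the snippet is not compiled as a doctest.
    ///
    /// ```ignore
    /// // `true` positions are attended, `false` positions are masked out.
    /// let mask = vec![true, false, true];
    /// let out = attn.compute_with_mask(&query, &keys, &values, Some(mask.as_slice()))?;
    /// assert_eq!(out.len(), attn.dim());
    /// ```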
    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>>;

    /// Returns the model dimension.
    fn dim(&self) -> usize;

    /// Returns the number of attention heads (1 for single-head attention).
    fn num_heads(&self) -> usize {
        1
    }
}

/// Graph attention mechanism trait.
///
/// Extends basic attention to operate over graph structures with explicit edges.
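///
/// # Examples
///
/// A sketch of attending over a three-node path graph; `gat` is a hypothetical
/// implementor of this trait, the import path is an assumption, and the
/// snippet is not compiled as a doctest.
///
/// ```ignore
/// use ruvector_attention::traits::EdgeInfo;
///
/// // One feature vector per node, each of length `gat.dim()`.
/// let node_features: Vec<Vec<f32>> = vec![vec![0.1; gat.dim()]; 3];
///
/// // Directed edges 0 -> 1 -> 2; `features` carries optional edge attributes.
/// let edges = vec![
///     EdgeInfo { src: 0, dst: 1, features: None },
///     EdgeInfo { src: 1, dst: 2, features: Some(vec![0.5]) },
/// ];
///
/// // Returns one updated feature vector per node.
/// let updated = gat.compute_with_edges(&node_features, &edges)?;
/// assert_eq!(updated.len(), node_features.len());
/// ```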
pub trait GraphAttention: Attention {
    /// Computes attention using graph structure.
    ///
    /// # Arguments
    ///
    /// * `node_features` - Features for all nodes, shape [num_nodes, d_model]
    /// * `edges` - Edge information (source, destination, optional features)
    ///
    /// # Returns
    ///
    /// Updated node features of shape [num_nodes, d_model]
    fn compute_with_edges(
        &self,
        node_features: &[Vec<f32>],
        edges: &[EdgeInfo],
    ) -> AttentionResult<Vec<Vec<f32>>>;

    /// Computes the attention weight for a single edge.
    ///
    /// # Arguments
    ///
    /// * `src_feature` - Source node feature
    /// * `dst_feature` - Destination node feature
    /// * `edge_feature` - Optional edge feature
    ///
    /// # Returns
    ///
    /// Attention weight for this edge
    fn compute_edge_attention(
        &self,
        src_feature: &[f32],
        dst_feature: &[f32],
        edge_feature: Option<&[f32]>,
    ) -> AttentionResult<f32>;
}

/// Geometric attention mechanism trait.
///
/// Implements attention in hyperbolic or other geometric spaces with curvature.
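///
/// # Examples
///
/// A sketch of attending in a hyperbolic (negative-curvature) space; `geo` is
/// a hypothetical implementor, `query`, `keys`, and `values` are assumed
/// bindings with keys and values already lying in the geometric space as the
/// method docs require, and the snippet is not compiled as a doctest.
///
/// ```ignore
/// // Negative curvature selects a hyperbolic space.
/// let curvature = -1.0_f32;
///
/// // Map the query into the geometric space before attending.
/// let q = geo.project_to_geometric(&query, curvature)?;
/// let out = geo.compute_geometric(&q, &keys, &values, curvature)?;
///
/// // Map the result back if downstream layers expect Euclidean coordinates.
/// let euclidean_out = geo.project_from_geometric(&out, curvature)?;
/// ```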
pub trait GeometricAttention: Attention {
    /// Computes attention in geometric space with specified curvature.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector in geometric space
    /// * `keys` - Key vectors in geometric space
    /// * `values` - Value vectors
    /// * `curvature` - Curvature parameter (negative for hyperbolic space)
    ///
    /// # Returns
    ///
    /// Output vector in geometric space
    fn compute_geometric(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        curvature: f32,
    ) -> AttentionResult<Vec<f32>>;

    /// Projects a vector into the geometric space.
    fn project_to_geometric(&self, vector: &[f32], curvature: f32) -> AttentionResult<Vec<f32>>;

    /// Projects a vector back from the geometric space.
    fn project_from_geometric(&self, vector: &[f32], curvature: f32) -> AttentionResult<Vec<f32>>;
}

/// Sparse attention mechanism trait.
///
/// Implements efficient attention over sparse patterns.
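///
/// # Examples
///
/// A sketch of sparse attention driven by an implementation-generated mask;
/// `sparse`, `query`, `keys`, and `values` are assumed bindings and the
/// snippet is not compiled as a doctest.
///
/// ```ignore
/// // Ask the implementation for its sparsity pattern over 8 positions
/// // (e.g. a sliding window or strided pattern, depending on the implementor).
/// let mask = sparse.generate_mask(8)?;
///
/// // Attention is restricted to the (row, col) pairs listed in the mask.
/// let out = sparse.compute_sparse(&query, &keys, &values, &mask)?;
/// assert_eq!(out.len(), sparse.dim());
/// ```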
pub trait SparseAttention: Attention {
    /// Computes sparse attention using the provided mask.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector
    /// * `keys` - Key vectors
    /// * `values` - Value vectors
    /// * `mask` - Sparse mask defining attention pattern
    ///
    /// # Returns
    ///
    /// Output vector
    fn compute_sparse(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: &SparseMask,
    ) -> AttentionResult<Vec<f32>>;

    /// Generates a sparse mask for the given sequence length.
    ///
    /// # Arguments
    ///
    /// * `seq_len` - Sequence length
    ///
    /// # Returns
    ///
    /// Sparse mask for attention computation
    fn generate_mask(&self, seq_len: usize) -> AttentionResult<SparseMask>;
}

/// Gradient information for backward pass.
#[derive(Clone, Debug)]
pub struct Gradients {
    /// Gradient w.r.t. query
    pub query_grad: Vec<f32>,
    /// Gradient w.r.t. keys
    pub keys_grad: Vec<Vec<f32>>,
    /// Gradient w.r.t. values
    pub values_grad: Vec<Vec<f32>>,
    /// Gradient w.r.t. attention weights (for analysis)
    pub attention_weights_grad: Option<Vec<f32>>,
}

/// Trainable attention mechanism with backward pass support.
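///
/// # Examples
///
/// A sketch of one training step; `attn` is a hypothetical mutable implementor,
/// the dummy all-ones `grad_output` stands in for a gradient from downstream
/// layers, and the snippet is not compiled as a doctest.
///
/// ```ignore
/// // Forward pass also returns the attention weights needed by `backward`.
/// let (output, attention_weights) = attn.forward(&query, &keys, &values)?;
///
/// // Gradient of the loss w.r.t. `output`, normally supplied by the layer above.
/// let grad_output = vec![1.0_f32; output.len()];
///
/// // Backward pass yields gradients w.r.t. query, keys, and values.
/// let grads = attn.backward(&grad_output, &query, &keys, &values, &attention_weights)?;
///
/// // Apply the update; the exact rule (e.g. plain SGD) is implementation-defined.
/// attn.update_parameters(&grads, 1e-3)?;
/// ```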
pub trait TrainableAttention: Attention {
    /// Forward pass with gradient tracking.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector
    /// * `keys` - Key vectors
    /// * `values` - Value vectors
    ///
    /// # Returns
    ///
    /// Tuple of (output, attention_weights) for gradient computation
    fn forward(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<(Vec<f32>, Vec<f32>)>;

    /// Backward pass for gradient computation.
    ///
    /// # Arguments
    ///
    /// * `grad_output` - Gradient from downstream layers
    /// * `query` - Query from forward pass
    /// * `keys` - Keys from forward pass
    /// * `values` - Values from forward pass
    /// * `attention_weights` - Attention weights from forward pass
    ///
    /// # Returns
    ///
    /// Gradients w.r.t. inputs
    fn backward(
        &self,
        grad_output: &[f32],
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        attention_weights: &[f32],
    ) -> AttentionResult<Gradients>;

    /// Updates parameters using computed gradients.
    ///
    /// # Arguments
    ///
    /// * `gradients` - Computed gradients
    /// * `learning_rate` - Learning rate for update
    fn update_parameters(
        &mut self,
        gradients: &Gradients,
        learning_rate: f32,
    ) -> AttentionResult<()>;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_mask_creation() {
        let mask = SparseMask {
            rows: vec![0, 1, 2],
            cols: vec![0, 1, 2],
            values: None,
        };

        assert_eq!(mask.rows.len(), 3);
        assert_eq!(mask.cols.len(), 3);
        assert!(mask.values.is_none());
    }

    #[test]
    fn test_edge_info_creation() {
        let edge = EdgeInfo {
            src: 0,
            dst: 1,
            features: Some(vec![0.5, 0.3]),
        };

        assert_eq!(edge.src, 0);
        assert_eq!(edge.dst, 1);
        assert_eq!(edge.features.as_ref().unwrap().len(), 2);
    }

    #[test]
    fn test_gradients_creation() {
        let grads = Gradients {
            query_grad: vec![0.1, 0.2],
            keys_grad: vec![vec![0.3, 0.4]],
            values_grad: vec![vec![0.5, 0.6]],
            attention_weights_grad: None,
        };

        assert_eq!(grads.query_grad.len(), 2);
        assert_eq!(grads.keys_grad.len(), 1);
        assert!(grads.attention_weights_grad.is_none());
    }
}