ruvector_attention/traits.rs

//! Trait definitions for attention mechanisms.
//!
//! This module defines the core traits that all attention mechanisms implement,
//! including standard attention, graph attention, geometric attention, and
//! trainable attention with backward pass support.

use crate::error::AttentionResult;

/// Mask for sparse attention patterns.
#[derive(Clone, Debug)]
pub struct SparseMask {
    /// Row indices for sparse mask
    pub rows: Vec<usize>,
    /// Column indices for sparse mask
    pub cols: Vec<usize>,
    /// Optional values (if not provided, defaults to 1.0)
    pub values: Option<Vec<f32>>,
}

/// Edge information for graph attention.
#[derive(Clone, Debug)]
pub struct EdgeInfo {
    /// Source node index
    pub src: usize,
    /// Destination node index
    pub dst: usize,
    /// Optional edge features
    pub features: Option<Vec<f32>>,
}

/// Core attention mechanism trait.
///
/// Implements the basic attention computation: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
pub trait Attention: Send + Sync {
    /// Computes attention over the given query, keys, and values.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector of shape [d_model]
    /// * `keys` - Slice of key vectors, each of shape [d_model]
    /// * `values` - Slice of value vectors, each of shape [d_model]
    ///
    /// # Returns
    ///
    /// Output vector of shape [d_model]
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>>;

    /// Computes attention with optional mask.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector of shape [d_model]
    /// * `keys` - Slice of key vectors, each of shape [d_model]
    /// * `values` - Slice of value vectors, each of shape [d_model]
    /// * `mask` - Optional attention mask (true = attend, false = mask out)
    ///
    /// # Returns
    ///
    /// Output vector of shape [d_model]
    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>>;

    /// Returns the model dimension.
    fn dim(&self) -> usize;

    /// Returns the number of attention heads (1 for single-head attention).
    fn num_heads(&self) -> usize {
        1
    }
}
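
// Illustrative sketch (not part of the crate's public API): a minimal
// scaled dot-product implementation of `Attention`, matching the formula
// softmax(QK^T / sqrt(d_k))V from the trait docs. The `DotProductAttention`
// name is hypothetical, and dimension/length validation is omitted because it
// depends on the concrete error type in `crate::error`. Kept under
// `#[cfg(test)]` so it only serves as a compiled example.
#[cfg(test)]
pub struct DotProductAttention {
    pub d_model: usize,
}

#[cfg(test)]
impl Attention for DotProductAttention {
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        self.compute_with_mask(query, keys, values, None)
    }

    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        let scale = 1.0 / (self.d_model as f32).sqrt();
        // Scaled dot-product score per key; masked-out positions get -inf so
        // they receive zero weight after the softmax.
        let mut weights: Vec<f32> = keys
            .iter()
            .enumerate()
            .map(|(i, k)| {
                if mask.map_or(false, |m| !m[i]) {
                    f32::NEG_INFINITY
                } else {
                    query.iter().zip(k.iter()).map(|(q, k)| q * k).sum::<f32>() * scale
                }
            })
            .collect();
        // Numerically stable softmax over the scores.
        let max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0;
        for w in weights.iter_mut() {
            *w = (*w - max).exp();
            sum += *w;
        }
        // Output = softmax-weighted sum of the value vectors.
        let mut output = vec![0.0; self.d_model];
        for (w, v) in weights.iter().zip(values.iter()) {
            for (o, x) in output.iter_mut().zip(v.iter()) {
                *o += (w / sum) * x;
            }
        }
        Ok(output)
    }

    fn dim(&self) -> usize {
        self.d_model
    }
}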

/// Graph attention mechanism trait.
///
/// Extends basic attention to operate over graph structures with explicit edges.
pub trait GraphAttention: Attention {
    /// Computes attention using graph structure.
    ///
    /// # Arguments
    ///
    /// * `node_features` - Features for all nodes, shape [num_nodes, d_model]
    /// * `edges` - Edge information (source, destination, optional features)
    ///
    /// # Returns
    ///
    /// Updated node features of shape [num_nodes, d_model]
    fn compute_with_edges(
        &self,
        node_features: &[Vec<f32>],
        edges: &[EdgeInfo],
    ) -> AttentionResult<Vec<Vec<f32>>>;

    /// Computes attention weights for edges.
    ///
    /// # Arguments
    ///
    /// * `src_feature` - Source node feature
    /// * `dst_feature` - Destination node feature
    /// * `edge_feature` - Optional edge feature
    ///
    /// # Returns
    ///
    /// Attention weight for this edge
    fn compute_edge_attention(
        &self,
        src_feature: &[f32],
        dst_feature: &[f32],
        edge_feature: Option<&[f32]>,
    ) -> AttentionResult<f32>;
}
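
// Illustrative sketch of how `compute_with_edges` is typically built on top of
// `compute_edge_attention`: score every edge, softmax-normalise the incoming
// edges of each destination node, then aggregate the source features. The
// free-function form, generic bound, and aggregation scheme are assumptions
// for illustration, not crate API.
#[cfg(test)]
#[allow(dead_code)]
fn aggregate_with_edges<A: GraphAttention + ?Sized>(
    attn: &A,
    node_features: &[Vec<f32>],
    edges: &[EdgeInfo],
) -> AttentionResult<Vec<Vec<f32>>> {
    // Raw attention score for every edge.
    let mut scores = Vec::with_capacity(edges.len());
    for e in edges {
        scores.push(attn.compute_edge_attention(
            &node_features[e.src],
            &node_features[e.dst],
            e.features.as_deref(),
        )?);
    }
    // Softmax-normalise scores over the incoming edges of each destination.
    let n = node_features.len();
    let exp: Vec<f32> = scores.iter().map(|s| s.exp()).collect();
    let mut denom = vec![0.0f32; n];
    for (e, w) in edges.iter().zip(&exp) {
        denom[e.dst] += w;
    }
    // Each destination node accumulates a weighted sum of its source features.
    let mut out = vec![vec![0.0f32; attn.dim()]; n];
    for (e, w) in edges.iter().zip(&exp) {
        let alpha = w / denom[e.dst].max(1e-12);
        for (o, x) in out[e.dst].iter_mut().zip(&node_features[e.src]) {
            *o += alpha * x;
        }
    }
    Ok(out)
}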

/// Geometric attention mechanism trait.
///
/// Implements attention in hyperbolic or other geometric spaces with curvature.
pub trait GeometricAttention: Attention {
    /// Computes attention in geometric space with specified curvature.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector in geometric space
    /// * `keys` - Key vectors in geometric space
    /// * `values` - Value vectors
    /// * `curvature` - Curvature parameter (negative for hyperbolic space)
    ///
    /// # Returns
    ///
    /// Output vector in geometric space
    fn compute_geometric(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        curvature: f32,
    ) -> AttentionResult<Vec<f32>>;

    /// Projects a vector into the geometric space.
    fn project_to_geometric(&self, vector: &[f32], curvature: f32) -> AttentionResult<Vec<f32>>;

    /// Projects a vector back from the geometric space.
    fn project_from_geometric(&self, vector: &[f32], curvature: f32)
        -> AttentionResult<Vec<f32>>;
}
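
// Illustrative sketch of one possible geometry for `GeometricAttention`:
// projection onto the Poincare ball of curvature -c (c > 0) via the
// exponential map at the origin,
//     exp_0(v) = tanh(sqrt(c) * ||v||) * v / (sqrt(c) * ||v||),
// and its inverse, the logarithmic map
//     log_0(y) = artanh(sqrt(c) * ||y||) * y / (sqrt(c) * ||y||).
// The actual geometry used by implementors may differ; these free functions
// are assumptions for illustration only.
#[cfg(test)]
#[allow(dead_code)]
fn poincare_exp0(vector: &[f32], curvature: f32) -> Vec<f32> {
    let c = curvature.abs().max(1e-6);
    let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
    let scale = (c.sqrt() * norm).tanh() / (c.sqrt() * norm);
    vector.iter().map(|x| x * scale).collect()
}

#[cfg(test)]
#[allow(dead_code)]
fn poincare_log0(vector: &[f32], curvature: f32) -> Vec<f32> {
    let c = curvature.abs().max(1e-6);
    let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
    let scale = (c.sqrt() * norm).atanh() / (c.sqrt() * norm);
    vector.iter().map(|x| x * scale).collect()
}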

/// Sparse attention mechanism trait.
///
/// Implements efficient attention over sparse patterns.
pub trait SparseAttention: Attention {
    /// Computes sparse attention using the provided mask.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector
    /// * `keys` - Key vectors
    /// * `values` - Value vectors
    /// * `mask` - Sparse mask defining attention pattern
    ///
    /// # Returns
    ///
    /// Output vector
    fn compute_sparse(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: &SparseMask,
    ) -> AttentionResult<Vec<f32>>;

    /// Generates a sparse mask for the given sequence length.
    ///
    /// # Arguments
    ///
    /// * `seq_len` - Sequence length
    ///
    /// # Returns
    ///
    /// Sparse mask for attention computation
    fn generate_mask(&self, seq_len: usize) -> AttentionResult<SparseMask>;
}
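
// Illustrative sketch of one common sparse pattern a `SparseAttention`
// implementation might generate: a sliding window in which each position
// attends to itself and its `window` nearest neighbours on either side. The
// free function and `window` parameter are assumptions for illustration.
#[cfg(test)]
#[allow(dead_code)]
fn sliding_window_mask(seq_len: usize, window: usize) -> SparseMask {
    let mut rows = Vec::new();
    let mut cols = Vec::new();
    for i in 0..seq_len {
        let lo = i.saturating_sub(window);
        let hi = (i + window + 1).min(seq_len);
        for j in lo..hi {
            rows.push(i);
            cols.push(j);
        }
    }
    // Per the `SparseMask` docs, `values: None` means every listed
    // (row, col) pair has weight 1.0.
    SparseMask {
        rows,
        cols,
        values: None,
    }
}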

/// Gradient information for backward pass.
#[derive(Clone, Debug)]
pub struct Gradients {
    /// Gradient w.r.t. query
    pub query_grad: Vec<f32>,
    /// Gradient w.r.t. keys
    pub keys_grad: Vec<Vec<f32>>,
    /// Gradient w.r.t. values
    pub values_grad: Vec<Vec<f32>>,
    /// Gradient w.r.t. attention weights (for analysis)
    pub attention_weights_grad: Option<Vec<f32>>,
}

/// Trainable attention mechanism with backward pass support.
pub trait TrainableAttention: Attention {
    /// Forward pass with gradient tracking.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector
    /// * `keys` - Key vectors
    /// * `values` - Value vectors
    ///
    /// # Returns
    ///
    /// Tuple of (output, attention_weights) for gradient computation
    fn forward(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<(Vec<f32>, Vec<f32>)>;

    /// Backward pass for gradient computation.
    ///
    /// # Arguments
    ///
    /// * `grad_output` - Gradient from downstream layers
    /// * `query` - Query from forward pass
    /// * `keys` - Keys from forward pass
    /// * `values` - Values from forward pass
    /// * `attention_weights` - Attention weights from forward pass
    ///
    /// # Returns
    ///
    /// Gradients w.r.t. inputs
    fn backward(
        &self,
        grad_output: &[f32],
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        attention_weights: &[f32],
    ) -> AttentionResult<Gradients>;

    /// Updates parameters using computed gradients.
    ///
    /// # Arguments
    ///
    /// * `gradients` - Computed gradients
    /// * `learning_rate` - Learning rate for update
    fn update_parameters(&mut self, gradients: &Gradients, learning_rate: f32)
        -> AttentionResult<()>;
}
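
// Illustrative sketch of the backward pass for plain softmax attention, to
// make the `Gradients` fields concrete. With scores s_i = q.k_i / sqrt(d),
// weights w = softmax(s) and output o = sum_i w_i * v_i:
//     dL/dv_i = w_i * grad_output
//     dL/dw_i = grad_output . v_i
//     dL/ds_i = w_i * (dL/dw_i - sum_j w_j * dL/dw_j)   (softmax Jacobian)
//     dL/dq   = sum_i dL/ds_i * k_i / sqrt(d)
//     dL/dk_i = dL/ds_i * q / sqrt(d)
// Concrete `TrainableAttention` implementors may parameterise this
// differently; the free function below is an assumption for illustration.
#[cfg(test)]
#[allow(dead_code)]
fn softmax_attention_backward(
    grad_output: &[f32],
    query: &[f32],
    keys: &[&[f32]],
    values: &[&[f32]],
    attention_weights: &[f32],
) -> Gradients {
    let d = query.len();
    let scale = 1.0 / (d as f32).sqrt();
    let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();

    // Gradients w.r.t. the value vectors and the attention weights.
    let values_grad: Vec<Vec<f32>> = attention_weights
        .iter()
        .map(|&w| grad_output.iter().map(|g| w * g).collect())
        .collect();
    let weights_grad: Vec<f32> = values.iter().map(|v| dot(grad_output, v)).collect();

    // Back through the softmax to the raw scores.
    let mean: f32 = attention_weights
        .iter()
        .zip(&weights_grad)
        .map(|(w, g)| w * g)
        .sum();
    let scores_grad: Vec<f32> = attention_weights
        .iter()
        .zip(&weights_grad)
        .map(|(w, g)| w * (g - mean))
        .collect();

    // Back through the scaled dot products to the query and keys.
    let mut query_grad = vec![0.0f32; d];
    let mut keys_grad = Vec::with_capacity(keys.len());
    for (k, ds) in keys.iter().zip(&scores_grad) {
        for (qg, kx) in query_grad.iter_mut().zip(k.iter()) {
            *qg += ds * kx * scale;
        }
        keys_grad.push(query.iter().map(|q| ds * q * scale).collect());
    }

    Gradients {
        query_grad,
        keys_grad,
        values_grad,
        attention_weights_grad: Some(weights_grad),
    }
}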

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_mask_creation() {
        let mask = SparseMask {
            rows: vec![0, 1, 2],
            cols: vec![0, 1, 2],
            values: None,
        };

        assert_eq!(mask.rows.len(), 3);
        assert_eq!(mask.cols.len(), 3);
        assert!(mask.values.is_none());
    }

    #[test]
    fn test_edge_info_creation() {
        let edge = EdgeInfo {
            src: 0,
            dst: 1,
            features: Some(vec![0.5, 0.3]),
        };

        assert_eq!(edge.src, 0);
        assert_eq!(edge.dst, 1);
        assert_eq!(edge.features.as_ref().unwrap().len(), 2);
    }

    #[test]
    fn test_gradients_creation() {
        let grads = Gradients {
            query_grad: vec![0.1, 0.2],
            keys_grad: vec![vec![0.3, 0.4]],
            values_grad: vec![vec![0.5, 0.6]],
            attention_weights_grad: None,
        };

        assert_eq!(grads.query_grad.len(), 2);
        assert_eq!(grads.keys_grad.len(), 1);
        assert!(grads.attention_weights_grad.is_none());
    }
}