// ruvector_attention/traits.rs
//! Trait definitions for attention mechanisms.
//!
//! This module defines the core traits that all attention mechanisms implement,
//! including standard attention, graph attention, geometric attention, and
//! trainable attention with backward pass support.

use crate::error::AttentionResult;

/// Mask for sparse attention patterns.
#[derive(Clone, Debug)]
pub struct SparseMask {
    /// Row indices for sparse mask
    pub rows: Vec<usize>,
    /// Column indices for sparse mask
    pub cols: Vec<usize>,
    /// Optional values (if not provided, defaults to 1.0)
    pub values: Option<Vec<f32>>,
}

/// Edge information for graph attention.
#[derive(Clone, Debug)]
pub struct EdgeInfo {
    /// Source node index
    pub src: usize,
    /// Destination node index
    pub dst: usize,
    /// Optional edge features
    pub features: Option<Vec<f32>>,
}

/// Core attention mechanism trait.
///
/// Implements the basic attention computation: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
pub trait Attention: Send + Sync {
    /// Computes attention over the given query, keys, and values.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector of shape [d_model]
    /// * `keys` - Slice of key vectors, each of shape [d_model]
    /// * `values` - Slice of value vectors, each of shape [d_model]
    ///
    /// # Returns
    ///
    /// Output vector of shape [d_model]
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>>;

    /// Computes attention with an optional mask.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector of shape [d_model]
    /// * `keys` - Slice of key vectors, each of shape [d_model]
    /// * `values` - Slice of value vectors, each of shape [d_model]
    /// * `mask` - Optional attention mask (true = attend, false = mask out)
    ///
    /// # Returns
    ///
    /// Output vector of shape [d_model]
    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>>;

    /// Returns the model dimension.
    fn dim(&self) -> usize;

    /// Returns the number of attention heads (1 for single-head attention).
    fn num_heads(&self) -> usize {
        1
    }
}
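
// Illustrative sketch: a minimal single-head `Attention` implementation using
// the scaled dot-product form from the trait docs, softmax(q . k_i / sqrt(d))
// over the values. The struct name is an assumption, and shape validation is
// omitted; a real implementation would report mismatched dimensions through
// the crate's error type instead of assuming well-formed inputs.
#[allow(dead_code)]
struct ScaledDotProductAttention {
    dim: usize,
}

impl Attention for ScaledDotProductAttention {
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        self.compute_with_mask(query, keys, values, None)
    }

    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        let scale = 1.0 / (self.dim as f32).sqrt();
        // Scaled dot-product scores; masked-out positions get -inf so they
        // contribute nothing after the softmax.
        let scores: Vec<f32> = keys
            .iter()
            .enumerate()
            .map(|(i, k)| {
                if mask.map_or(false, |m| !m[i]) {
                    f32::NEG_INFINITY
                } else {
                    query.iter().zip(k.iter()).map(|(q, kx)| q * kx).sum::<f32>() * scale
                }
            })
            .collect();
        // Numerically stable softmax followed by a weighted sum of the values.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let denom: f32 = exps.iter().sum();
        let mut output = vec![0.0_f32; self.dim];
        for (w, v) in exps.iter().zip(values.iter()) {
            for (o, x) in output.iter_mut().zip(v.iter()) {
                *o += (w / denom) * x;
            }
        }
        Ok(output)
    }

    fn dim(&self) -> usize {
        self.dim
    }
}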

/// Graph attention mechanism trait.
///
/// Extends basic attention to operate over graph structures with explicit edges.
pub trait GraphAttention: Attention {
    /// Computes attention using graph structure.
    ///
    /// # Arguments
    ///
    /// * `node_features` - Features for all nodes, shape [num_nodes, d_model]
    /// * `edges` - Edge information (source, destination, optional features)
    ///
    /// # Returns
    ///
    /// Updated node features of shape [num_nodes, d_model]
    fn compute_with_edges(
        &self,
        node_features: &[Vec<f32>],
        edges: &[EdgeInfo],
    ) -> AttentionResult<Vec<Vec<f32>>>;

    /// Computes the attention weight for an edge.
    ///
    /// # Arguments
    ///
    /// * `src_feature` - Source node feature
    /// * `dst_feature` - Destination node feature
    /// * `edge_feature` - Optional edge feature
    ///
    /// # Returns
    ///
    /// Attention weight for this edge
    fn compute_edge_attention(
        &self,
        src_feature: &[f32],
        dst_feature: &[f32],
        edge_feature: Option<&[f32]>,
    ) -> AttentionResult<f32>;
}
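
// Illustrative sketch of the aggregation a `GraphAttention::compute_with_edges`
// implementation typically performs: score each edge by the dot product of its
// endpoint features, softmax the scores per destination node, and sum the
// source features into the destination. Edge features are ignored only to keep
// the sketch short, and the function name is an assumption.
#[allow(dead_code)]
fn aggregate_neighbors_sketch(node_features: &[Vec<f32>], edges: &[EdgeInfo]) -> Vec<Vec<f32>> {
    let d = node_features.first().map_or(0, |v| v.len());
    let mut output = vec![vec![0.0_f32; d]; node_features.len()];
    for dst in 0..node_features.len() {
        // Exponentiated scores for edges pointing at `dst` (unstabilized
        // softmax, acceptable for a sketch with small scores).
        let incoming: Vec<(usize, f32)> = edges
            .iter()
            .filter(|e| e.dst == dst)
            .map(|e| {
                let score: f32 = node_features[e.src]
                    .iter()
                    .zip(&node_features[dst])
                    .map(|(a, b)| a * b)
                    .sum();
                (e.src, score.exp())
            })
            .collect();
        let norm: f32 = incoming.iter().map(|(_, w)| w).sum();
        for (src, w) in incoming {
            for (o, x) in output[dst].iter_mut().zip(&node_features[src]) {
                *o += (w / norm) * x;
            }
        }
    }
    output
}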

/// Geometric attention mechanism trait.
///
/// Implements attention in hyperbolic or other geometric spaces with curvature.
pub trait GeometricAttention: Attention {
    /// Computes attention in geometric space with the specified curvature.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector in geometric space
    /// * `keys` - Key vectors in geometric space
    /// * `values` - Value vectors
    /// * `curvature` - Curvature parameter (negative for hyperbolic space)
    ///
    /// # Returns
    ///
    /// Output vector in geometric space
    fn compute_geometric(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        curvature: f32,
    ) -> AttentionResult<Vec<f32>>;

    /// Projects a vector into geometric space.
    fn project_to_geometric(&self, vector: &[f32], curvature: f32) -> AttentionResult<Vec<f32>>;

    /// Projects a vector back from geometric space.
    fn project_from_geometric(&self, vector: &[f32], curvature: f32) -> AttentionResult<Vec<f32>>;
}
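
// Illustrative sketch of one common choice for the projection pair in
// `GeometricAttention`: the exponential and logarithmic maps at the origin of
// the Poincaré ball. `curvature` is expected to be negative (hyperbolic), as
// in the trait docs; the function names are assumptions.
#[allow(dead_code)]
fn exp_map_origin_sketch(vector: &[f32], curvature: f32) -> Vec<f32> {
    // Work with c = -curvature > 0; clamp to avoid division by zero.
    let c = (-curvature).max(f32::EPSILON);
    let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm == 0.0 {
        return vector.to_vec();
    }
    // exp_0(v) = tanh(sqrt(c) * ||v||) * v / (sqrt(c) * ||v||)
    let scale = (c.sqrt() * norm).tanh() / (c.sqrt() * norm);
    vector.iter().map(|x| x * scale).collect()
}

#[allow(dead_code)]
fn log_map_origin_sketch(vector: &[f32], curvature: f32) -> Vec<f32> {
    let c = (-curvature).max(f32::EPSILON);
    let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm == 0.0 {
        return vector.to_vec();
    }
    // log_0(y) = artanh(sqrt(c) * ||y||) * y / (sqrt(c) * ||y||)
    let scale = (c.sqrt() * norm).atanh() / (c.sqrt() * norm);
    vector.iter().map(|x| x * scale).collect()
}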

/// Sparse attention mechanism trait.
///
/// Implements efficient attention over sparse patterns.
pub trait SparseAttention: Attention {
    /// Computes sparse attention using the provided mask.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector
    /// * `keys` - Key vectors
    /// * `values` - Value vectors
    /// * `mask` - Sparse mask defining the attention pattern
    ///
    /// # Returns
    ///
    /// Output vector
    fn compute_sparse(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: &SparseMask,
    ) -> AttentionResult<Vec<f32>>;

    /// Generates a sparse mask for the given sequence length.
    ///
    /// # Arguments
    ///
    /// * `seq_len` - Sequence length
    ///
    /// # Returns
    ///
    /// Sparse mask for attention computation
    fn generate_mask(&self, seq_len: usize) -> AttentionResult<SparseMask>;
}
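
// Illustrative sketch of a `SparseAttention::generate_mask` pattern: a sliding
// window in which each position attends to itself and to `window` neighbors on
// either side, encoded in the coordinate (rows/cols) layout of `SparseMask`.
// The function name and the `window` parameter are assumptions.
#[allow(dead_code)]
fn sliding_window_mask_sketch(seq_len: usize, window: usize) -> SparseMask {
    let mut rows = Vec::new();
    let mut cols = Vec::new();
    for i in 0..seq_len {
        let lo = i.saturating_sub(window);
        let hi = (i + window + 1).min(seq_len);
        for j in lo..hi {
            rows.push(i);
            cols.push(j);
        }
    }
    // No explicit values: every listed (row, col) pair defaults to 1.0.
    SparseMask { rows, cols, values: None }
}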

/// Gradient information for the backward pass.
#[derive(Clone, Debug)]
pub struct Gradients {
    /// Gradient w.r.t. query
    pub query_grad: Vec<f32>,
    /// Gradient w.r.t. keys
    pub keys_grad: Vec<Vec<f32>>,
    /// Gradient w.r.t. values
    pub values_grad: Vec<Vec<f32>>,
    /// Gradient w.r.t. attention weights (for analysis)
    pub attention_weights_grad: Option<Vec<f32>>,
}

/// Trainable attention mechanism with backward pass support.
pub trait TrainableAttention: Attention {
    /// Forward pass with gradient tracking.
    ///
    /// # Arguments
    ///
    /// * `query` - Query vector
    /// * `keys` - Key vectors
    /// * `values` - Value vectors
    ///
    /// # Returns
    ///
    /// Tuple of (output, attention_weights) for gradient computation
    fn forward(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<(Vec<f32>, Vec<f32>)>;

    /// Backward pass for gradient computation.
    ///
    /// # Arguments
    ///
    /// * `grad_output` - Gradient from downstream layers
    /// * `query` - Query from the forward pass
    /// * `keys` - Keys from the forward pass
    /// * `values` - Values from the forward pass
    /// * `attention_weights` - Attention weights from the forward pass
    ///
    /// # Returns
    ///
    /// Gradients w.r.t. the inputs
    fn backward(
        &self,
        grad_output: &[f32],
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        attention_weights: &[f32],
    ) -> AttentionResult<Gradients>;

    /// Updates parameters using computed gradients.
    ///
    /// # Arguments
    ///
    /// * `gradients` - Computed gradients
    /// * `learning_rate` - Learning rate for the update
    fn update_parameters(
        &mut self,
        gradients: &Gradients,
        learning_rate: f32,
    ) -> AttentionResult<()>;
}
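
// Illustrative sketch of the gradient math a `TrainableAttention::backward`
// implementation typically follows for the scaled dot-product form, assuming
// `attention_weights` holds the softmax output from `forward` and `grad_output`
// is dL/d(output). The function name is an assumption; it only fills a
// `Gradients` value and does not update any parameters.
#[allow(dead_code)]
fn attention_backward_sketch(
    grad_output: &[f32],
    query: &[f32],
    keys: &[&[f32]],
    values: &[&[f32]],
    attention_weights: &[f32],
) -> Gradients {
    let scale = 1.0 / (query.len() as f32).sqrt();
    let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();

    // d(output)/d(values): values_grad[i] = a_i * grad_output
    let values_grad: Vec<Vec<f32>> = attention_weights
        .iter()
        .map(|a| grad_output.iter().map(|g| a * g).collect())
        .collect();

    // Gradient w.r.t. the attention weights: da_i = grad_output . v_i
    let d_weights: Vec<f32> = values.iter().map(|&v| dot(grad_output, v)).collect();

    // Softmax backward: ds_i = a_i * (da_i - sum_j a_j * da_j)
    let weighted_sum: f32 = attention_weights.iter().zip(&d_weights).map(|(a, d)| a * d).sum();
    let d_scores: Vec<f32> = attention_weights
        .iter()
        .zip(&d_weights)
        .map(|(a, d)| a * (d - weighted_sum))
        .collect();

    // Scores were q . k_i * scale, so:
    //   query_grad   = sum_i ds_i * k_i * scale
    //   keys_grad[i] = ds_i * q * scale
    let mut query_grad = vec![0.0_f32; query.len()];
    for (ds, k) in d_scores.iter().zip(keys) {
        for (qg, kx) in query_grad.iter_mut().zip(k.iter()) {
            *qg += ds * kx * scale;
        }
    }
    let keys_grad: Vec<Vec<f32>> = d_scores
        .iter()
        .map(|ds| query.iter().map(|q| ds * q * scale).collect())
        .collect();

    Gradients {
        query_grad,
        keys_grad,
        values_grad,
        attention_weights_grad: Some(d_weights),
    }
}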

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_mask_creation() {
        let mask = SparseMask {
            rows: vec![0, 1, 2],
            cols: vec![0, 1, 2],
            values: None,
        };

        assert_eq!(mask.rows.len(), 3);
        assert_eq!(mask.cols.len(), 3);
        assert!(mask.values.is_none());
    }

    #[test]
    fn test_edge_info_creation() {
        let edge = EdgeInfo {
            src: 0,
            dst: 1,
            features: Some(vec![0.5, 0.3]),
        };

        assert_eq!(edge.src, 0);
        assert_eq!(edge.dst, 1);
        assert_eq!(edge.features.as_ref().unwrap().len(), 2);
    }

    #[test]
    fn test_gradients_creation() {
        let grads = Gradients {
            query_grad: vec![0.1, 0.2],
            keys_grad: vec![vec![0.3, 0.4]],
            values_grad: vec![vec![0.5, 0.6]],
            attention_weights_grad: None,
        };

        assert_eq!(grads.query_grad.len(), 2);
        assert_eq!(grads.keys_grad.len(), 1);
        assert!(grads.attention_weights_grad.is_none());
    }
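
    // Exercises the illustrative ScaledDotProductAttention sketch: with a
    // single key/value pair the softmax weight is exactly 1, so the output
    // must equal the value vector.
    #[test]
    fn test_scaled_dot_product_sketch_single_key() {
        let attn = ScaledDotProductAttention { dim: 2 };
        let query = [1.0_f32, 0.0];
        let key: &[f32] = &[1.0, 0.0];
        let value: &[f32] = &[0.25, 0.75];

        let out = attn
            .compute(&query, &[key], &[value])
            .expect("sketch attention should not fail");

        assert_eq!(out.len(), 2);
        assert!((out[0] - 0.25).abs() < 1e-6);
        assert!((out[1] - 0.75).abs() < 1e-6);
    }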
}