ruvector_attention/
utils.rs

//! Utility functions for attention mechanisms.
//!
//! This module provides common utilities like softmax, masking, and
//! numerical stability helpers used across attention implementations.

use crate::error::{AttentionError, AttentionResult};

/// Stable softmax that returns Vec<f32> directly (no Result)
/// Used by sparse, moe, and graph modules
#[inline]
pub fn stable_softmax(values: &[f32]) -> Vec<f32> {
    if values.is_empty() {
        return vec![];
    }

    // Find maximum for numerical stability
    let max_val = values
        .iter()
        .copied()
        .filter(|x| x.is_finite())
        .fold(f32::NEG_INFINITY, f32::max);

    if !max_val.is_finite() {
        // All values are -inf or invalid, return uniform
        let n = values.len();
        return vec![1.0 / n as f32; n];
    }

    // Compute exp(x - max) and sum
    let mut exp_values: Vec<f32> = values
        .iter()
        .map(|&x| {
            if x.is_finite() {
                (x - max_val).exp()
            } else {
                0.0
            }
        })
        .collect();

    let sum: f32 = exp_values.iter().sum();

    if sum <= 1e-10 || !sum.is_finite() {
        // Fallback to uniform
        let n = values.len();
        return vec![1.0 / n as f32; n];
    }

    // Normalize
    let inv_sum = 1.0 / sum;
    exp_values.iter_mut().for_each(|x| *x *= inv_sum);

    exp_values
}

/// Computes softmax over a slice of values.
///
/// Uses the numerically stable variant: softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
///
/// # Arguments
///
/// * `values` - Input values
///
/// # Returns
///
/// Softmax-normalized values
#[inline]
pub fn softmax(values: &[f32]) -> AttentionResult<Vec<f32>> {
    if values.is_empty() {
        return Err(AttentionError::EmptyInput(
            "cannot compute softmax of empty slice".to_string(),
        ));
    }

    // Find maximum for numerical stability
    let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    if !max_val.is_finite() {
        return Err(AttentionError::NumericalInstability(
            "non-finite values in softmax input".to_string(),
        ));
    }

    // Compute exp(x - max) and sum
    let mut exp_values: Vec<f32> = values.iter().map(|&x| (x - max_val).exp()).collect();

    let sum: f32 = exp_values.iter().sum();

    if sum <= 0.0 || !sum.is_finite() {
        return Err(AttentionError::NumericalInstability(
            "invalid sum in softmax computation".to_string(),
        ));
    }

    // Normalize
    let inv_sum = 1.0 / sum;
    exp_values.iter_mut().for_each(|x| *x *= inv_sum);

    Ok(exp_values)
}

/// Computes softmax with masking support.
///
/// Masked positions are set to negative infinity before softmax,
/// resulting in zero attention weights. If every position is masked
/// out, an error is returned.
///
/// # Arguments
///
/// * `values` - Input values
/// * `mask` - Optional mask (true = attend, false = mask out)
///
/// # Returns
///
/// Masked and softmax-normalized values
#[inline]
pub fn masked_softmax(values: &[f32], mask: Option<&[bool]>) -> AttentionResult<Vec<f32>> {
    if values.is_empty() {
        return Err(AttentionError::EmptyInput(
            "cannot compute softmax of empty slice".to_string(),
        ));
    }

    let masked_values = if let Some(m) = mask {
        if m.len() != values.len() {
            return Err(AttentionError::InvalidMask {
                expected: format!("{}", values.len()),
                actual: format!("{}", m.len()),
            });
        }

        values
            .iter()
            .zip(m.iter())
            .map(|(&v, &keep)| if keep { v } else { f32::NEG_INFINITY })
            .collect::<Vec<_>>()
    } else {
        values.to_vec()
    };

    softmax(&masked_values)
}

/// Applies causal masking to attention scores.
///
/// For position i, only positions 0..=i can be attended to.
/// Query position `i` is assumed to correspond to key position `i`.
///
/// # Arguments
///
/// * `scores` - Attention scores matrix [query_len, key_len]
/// * `query_len` - Number of query positions
/// * `key_len` - Number of key positions
///
/// # Returns
///
/// Causally masked scores
pub fn apply_causal_mask(
    scores: &mut [f32],
    query_len: usize,
    key_len: usize,
) -> AttentionResult<()> {
    if scores.len() != query_len * key_len {
        return Err(AttentionError::InvalidMask {
            expected: format!("{}x{}", query_len, key_len),
            actual: format!("{}", scores.len()),
        });
    }

    for i in 0..query_len {
        for j in (i + 1)..key_len {
            scores[i * key_len + j] = f32::NEG_INFINITY;
        }
    }

    Ok(())
}

/// Computes dot product between two vectors.
///
/// # Arguments
///
/// * `a` - First vector
/// * `b` - Second vector
///
/// # Returns
///
/// Dot product value
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> AttentionResult<f32> {
    if a.len() != b.len() {
        return Err(AttentionError::DimensionMismatch {
            expected: a.len(),
            actual: b.len(),
        });
    }

    Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
}

/// Scales a vector by a scalar value.
///
/// # Arguments
///
/// * `vector` - Input vector (modified in place)
/// * `scale` - Scale factor
#[inline]
pub fn scale_vector(vector: &mut [f32], scale: f32) {
    vector.iter_mut().for_each(|x| *x *= scale);
}

/// Adds two vectors element-wise.
///
/// # Arguments
///
/// * `a` - First vector
/// * `b` - Second vector
///
/// # Returns
///
/// Sum vector
#[inline]
pub fn add_vectors(a: &[f32], b: &[f32]) -> AttentionResult<Vec<f32>> {
    if a.len() != b.len() {
        return Err(AttentionError::DimensionMismatch {
            expected: a.len(),
            actual: b.len(),
        });
    }

    Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
}

/// Computes L2 norm of a vector.
///
/// # Arguments
///
/// * `vector` - Input vector
///
/// # Returns
///
/// L2 norm value
#[inline]
pub fn l2_norm(vector: &[f32]) -> f32 {
    vector.iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// Normalizes a vector to unit length.
///
/// # Arguments
///
/// * `vector` - Input vector (modified in place)
///
/// # Returns
///
/// Original norm before normalization
pub fn normalize_vector(vector: &mut [f32]) -> AttentionResult<f32> {
    let norm = l2_norm(vector);

    if norm <= 0.0 || !norm.is_finite() {
        return Err(AttentionError::NumericalInstability(
            "cannot normalize zero or non-finite vector".to_string(),
        ));
    }

    let inv_norm = 1.0 / norm;
    vector.iter_mut().for_each(|x| *x *= inv_norm);

    Ok(norm)
}

/// Applies dropout to a vector during training.
///
/// Uses inverted dropout: elements are zeroed with probability `dropout_prob`
/// and surviving elements are scaled by `1 / (1 - dropout_prob)`, so no
/// rescaling is needed at inference time.
///
/// # Arguments
///
/// * `vector` - Input vector (modified in place)
/// * `dropout_prob` - Dropout probability (0.0 to 1.0)
/// * `training` - Whether in training mode
/// * `rng` - Random number generator
pub fn apply_dropout(
    vector: &mut [f32],
    dropout_prob: f32,
    training: bool,
    rng: &mut impl rand::Rng,
) {
    if !training || dropout_prob == 0.0 {
        return;
    }

    let scale = 1.0 / (1.0 - dropout_prob);
    for x in vector.iter_mut() {
        if rng.gen::<f32>() < dropout_prob {
            *x = 0.0;
        } else {
            *x *= scale;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_softmax() {
        let values = vec![1.0, 2.0, 3.0];
        let result = softmax(&values).unwrap();

        // Sum should be approximately 1.0
        let sum: f32 = result.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);

        // Values should be in ascending order
        assert!(result[0] < result[1]);
        assert!(result[1] < result[2]);
    }
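
    // Usage sketch for `stable_softmax` (added example): it normalizes finite
    // inputs like `softmax`, and falls back to a uniform distribution when
    // every input is non-finite (e.g., all masked to -inf).
    #[test]
    fn test_stable_softmax() {
        let values = vec![1.0, 2.0, 3.0];
        let result = stable_softmax(&values);

        let sum: f32 = result.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);

        // All -inf inputs trigger the uniform fallback
        let masked = vec![f32::NEG_INFINITY; 4];
        let uniform = stable_softmax(&masked);
        for &w in &uniform {
            assert_relative_eq!(w, 0.25, epsilon = 1e-6);
        }
    }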

    #[test]
    fn test_softmax_numerical_stability() {
        let values = vec![1000.0, 1001.0, 1002.0];
        let result = softmax(&values).unwrap();

        let sum: f32 = result.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
    }
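
    // `softmax` reports empty input as an error rather than returning an
    // empty vector; this exercises that documented path.
    #[test]
    fn test_softmax_empty_input() {
        let values: Vec<f32> = vec![];
        assert!(softmax(&values).is_err());
    }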

    #[test]
    fn test_masked_softmax() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![true, true, false, false];
        let result = masked_softmax(&values, Some(&mask)).unwrap();

        // Masked positions should be zero
        assert_relative_eq!(result[2], 0.0, epsilon = 1e-6);
        assert_relative_eq!(result[3], 0.0, epsilon = 1e-6);

        // Unmasked positions should sum to 1
        let sum: f32 = result[0] + result[1];
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
    }
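
    // With `None` as the mask, `masked_softmax` reduces to plain `softmax`.
    #[test]
    fn test_masked_softmax_without_mask() {
        let values = vec![0.5, 1.5, 2.5];
        let masked = masked_softmax(&values, None).unwrap();
        let plain = softmax(&values).unwrap();

        for (m, p) in masked.iter().zip(plain.iter()) {
            assert_relative_eq!(*m, *p, epsilon = 1e-6);
        }
    }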

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = dot_product(&a, &b).unwrap();

        assert_relative_eq!(result, 32.0, epsilon = 1e-6);
    }
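
    // `add_vectors` sums element-wise and rejects length mismatches.
    #[test]
    fn test_add_vectors() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let sum = add_vectors(&a, &b).unwrap();

        assert_relative_eq!(sum[0], 5.0, epsilon = 1e-6);
        assert_relative_eq!(sum[1], 7.0, epsilon = 1e-6);
        assert_relative_eq!(sum[2], 9.0, epsilon = 1e-6);

        // Mismatched lengths are an error
        assert!(add_vectors(&a, &[1.0, 2.0]).is_err());
    }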

    #[test]
    fn test_scale_vector() {
        let mut vector = vec![1.0, 2.0, 3.0];
        scale_vector(&mut vector, 2.0);

        assert_relative_eq!(vector[0], 2.0);
        assert_relative_eq!(vector[1], 4.0);
        assert_relative_eq!(vector[2], 6.0);
    }
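
    // `l2_norm` of a 1-2-2 vector is 3 (sqrt of 1 + 4 + 4).
    #[test]
    fn test_l2_norm() {
        let vector = vec![1.0, 2.0, 2.0];
        assert_relative_eq!(l2_norm(&vector), 3.0, epsilon = 1e-6);
    }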

    #[test]
    fn test_normalize_vector() {
        let mut vector = vec![3.0, 4.0];
        let norm = normalize_vector(&mut vector).unwrap();

        assert_relative_eq!(norm, 5.0, epsilon = 1e-6);
        assert_relative_eq!(l2_norm(&vector), 1.0, epsilon = 1e-6);
    }
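
    // `normalize_vector` refuses to normalize a zero vector.
    #[test]
    fn test_normalize_zero_vector() {
        let mut vector = vec![0.0, 0.0, 0.0];
        assert!(normalize_vector(&mut vector).is_err());
    }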

    #[test]
    fn test_causal_mask() {
        let mut scores = vec![0.0; 9]; // 3x3 matrix
        apply_causal_mask(&mut scores, 3, 3).unwrap();

        // Check upper triangle is masked
        assert_eq!(scores[1], f32::NEG_INFINITY); // (0, 1)
        assert_eq!(scores[2], f32::NEG_INFINITY); // (0, 2)
        assert_eq!(scores[5], f32::NEG_INFINITY); // (1, 2)

        // Check diagonal and lower triangle are not masked
        assert_eq!(scores[0], 0.0); // (0, 0)
        assert_eq!(scores[4], 0.0); // (1, 1)
        assert_eq!(scores[8], 0.0); // (2, 2)
    }
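
    // Dropout sketch: in eval mode the vector is untouched; in training mode
    // each element is either zeroed or rescaled by 1 / (1 - p) (inverted
    // dropout). Uses `rand::thread_rng()`, which assumes the same rand
    // 0.8-style API that `apply_dropout` itself relies on via `Rng::gen`.
    #[test]
    fn test_apply_dropout() {
        let mut rng = rand::thread_rng();

        // Eval mode: no change
        let mut eval_vec = vec![1.0, 2.0, 3.0];
        apply_dropout(&mut eval_vec, 0.5, false, &mut rng);
        assert_eq!(eval_vec, vec![1.0, 2.0, 3.0]);

        // Training mode with p = 0.5: each element is 0.0 or doubled
        let mut train_vec = vec![1.0, 2.0, 3.0];
        apply_dropout(&mut train_vec, 0.5, true, &mut rng);
        for (x, orig) in train_vec.iter().zip([1.0f32, 2.0, 3.0].iter()) {
            assert!(*x == 0.0 || (*x - *orig * 2.0).abs() < 1e-6);
        }
    }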
}