// ruvector_attention/utils.rs

use crate::error::{AttentionError, AttentionResult};
/// Softmax that never panics and never returns NaN.
///
/// Non-finite entries are treated as "no weight": they are skipped when
/// locating the maximum and contribute zero to the exponential sum. If the
/// input is entirely non-finite, or the normalizer underflows, the result
/// falls back to a uniform distribution. An empty input yields an empty vec.
#[inline]
pub fn stable_softmax(values: &[f32]) -> Vec<f32> {
    if values.is_empty() {
        return Vec::new();
    }

    let n = values.len();

    // Largest finite entry; stays NEG_INFINITY when nothing is finite.
    let peak = values
        .iter()
        .filter(|v| v.is_finite())
        .fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));

    // Entirely non-finite input: uniform fallback.
    if !peak.is_finite() {
        return vec![1.0 / n as f32; n];
    }

    // Shifted exponentials; non-finite entries get zero weight.
    let mut probs = Vec::with_capacity(n);
    for &v in values {
        probs.push(if v.is_finite() { (v - peak).exp() } else { 0.0 });
    }

    let total: f32 = probs.iter().sum();

    // Degenerate normalizer (underflow or non-finite): uniform fallback.
    if total <= 1e-10 || !total.is_finite() {
        return vec![1.0 / n as f32; n];
    }

    let recip = total.recip();
    for p in probs.iter_mut() {
        *p *= recip;
    }

    probs
}
55
56#[inline]
68pub fn softmax(values: &[f32]) -> AttentionResult<Vec<f32>> {
69 if values.is_empty() {
70 return Err(AttentionError::EmptyInput(
71 "cannot compute softmax of empty slice".to_string(),
72 ));
73 }
74
75 let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
77
78 if !max_val.is_finite() {
79 return Err(AttentionError::NumericalInstability(
80 "non-finite values in softmax input".to_string(),
81 ));
82 }
83
84 let mut exp_values: Vec<f32> = values.iter().map(|&x| (x - max_val).exp()).collect();
86
87 let sum: f32 = exp_values.iter().sum();
88
89 if sum <= 0.0 || !sum.is_finite() {
90 return Err(AttentionError::NumericalInstability(
91 "invalid sum in softmax computation".to_string(),
92 ));
93 }
94
95 let inv_sum = 1.0 / sum;
97 exp_values.iter_mut().for_each(|x| *x *= inv_sum);
98
99 Ok(exp_values)
100}
101
102#[inline]
116pub fn masked_softmax(values: &[f32], mask: Option<&[bool]>) -> AttentionResult<Vec<f32>> {
117 if values.is_empty() {
118 return Err(AttentionError::EmptyInput(
119 "cannot compute softmax of empty slice".to_string(),
120 ));
121 }
122
123 let masked_values = if let Some(m) = mask {
124 if m.len() != values.len() {
125 return Err(AttentionError::InvalidMask {
126 expected: format!("{}", values.len()),
127 actual: format!("{}", m.len()),
128 });
129 }
130
131 values
132 .iter()
133 .zip(m.iter())
134 .map(|(&v, &keep)| if keep { v } else { f32::NEG_INFINITY })
135 .collect::<Vec<_>>()
136 } else {
137 values.to_vec()
138 };
139
140 softmax(&masked_values)
141}
142
143pub fn apply_causal_mask(
157 scores: &mut [f32],
158 query_len: usize,
159 key_len: usize,
160) -> AttentionResult<()> {
161 if scores.len() != query_len * key_len {
162 return Err(AttentionError::InvalidMask {
163 expected: format!("{}x{}", query_len, key_len),
164 actual: format!("{}", scores.len()),
165 });
166 }
167
168 for i in 0..query_len {
169 for j in (i + 1)..key_len {
170 scores[i * key_len + j] = f32::NEG_INFINITY;
171 }
172 }
173
174 Ok(())
175}
176
177#[inline]
188pub fn dot_product(a: &[f32], b: &[f32]) -> AttentionResult<f32> {
189 if a.len() != b.len() {
190 return Err(AttentionError::DimensionMismatch {
191 expected: a.len(),
192 actual: b.len(),
193 });
194 }
195
196 Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
197}
198
/// Multiplies every element of `vector` by `scale`, in place.
#[inline]
pub fn scale_vector(vector: &mut [f32], scale: f32) {
    for value in vector {
        *value *= scale;
    }
}
209
210#[inline]
221pub fn add_vectors(a: &[f32], b: &[f32]) -> AttentionResult<Vec<f32>> {
222 if a.len() != b.len() {
223 return Err(AttentionError::DimensionMismatch {
224 expected: a.len(),
225 actual: b.len(),
226 });
227 }
228
229 Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
230}
231
/// Euclidean (L2) norm of `vector`; 0.0 for an empty slice.
#[inline]
pub fn l2_norm(vector: &[f32]) -> f32 {
    // Sequential sum of squares, same order as iterator sum.
    let mut sum_sq = 0.0f32;
    for &v in vector {
        sum_sq += v * v;
    }
    sum_sq.sqrt()
}
245
246pub fn normalize_vector(vector: &mut [f32]) -> AttentionResult<f32> {
256 let norm = l2_norm(vector);
257
258 if norm <= 0.0 || !norm.is_finite() {
259 return Err(AttentionError::NumericalInstability(
260 "cannot normalize zero or non-finite vector".to_string(),
261 ));
262 }
263
264 let inv_norm = 1.0 / norm;
265 vector.iter_mut().for_each(|x| *x *= inv_norm);
266
267 Ok(norm)
268}
269
270pub fn apply_dropout(
279 vector: &mut [f32],
280 dropout_prob: f32,
281 training: bool,
282 rng: &mut impl rand::Rng,
283) {
284 if !training || dropout_prob == 0.0 {
285 return;
286 }
287
288 let scale = 1.0 / (1.0 - dropout_prob);
289 for x in vector.iter_mut() {
290 if rng.gen::<f32>() < dropout_prob {
291 *x = 0.0;
292 } else {
293 *x *= scale;
294 }
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_softmax() {
        let values = vec![1.0, 2.0, 3.0];
        let result = softmax(&values).unwrap();

        // Probabilities sum to 1 and preserve the input ordering.
        let sum: f32 = result.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);

        assert!(result[0] < result[1]);
        assert!(result[1] < result[2]);
    }

    #[test]
    fn test_softmax_numerical_stability() {
        // Large magnitudes must not overflow thanks to max-subtraction.
        let values = vec![1000.0, 1001.0, 1002.0];
        let result = softmax(&values).unwrap();

        let sum: f32 = result.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_softmax_empty_input() {
        // Empty input is an error, not an empty distribution.
        assert!(softmax(&[]).is_err());
    }

    #[test]
    fn test_stable_softmax_fallbacks() {
        // Empty input yields an empty vec (no error path here).
        assert!(stable_softmax(&[]).is_empty());

        // Entirely non-finite input falls back to a uniform distribution.
        let uniform = stable_softmax(&[f32::NAN, f32::NAN]);
        assert_relative_eq!(uniform[0], 0.5, epsilon = 1e-6);
        assert_relative_eq!(uniform[1], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_masked_softmax() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![true, true, false, false];
        let result = masked_softmax(&values, Some(&mask)).unwrap();

        // Masked-out positions carry exactly zero probability.
        assert_relative_eq!(result[2], 0.0, epsilon = 1e-6);
        assert_relative_eq!(result[3], 0.0, epsilon = 1e-6);

        // The surviving positions absorb the full probability mass.
        let sum: f32 = result[0] + result[1];
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = dot_product(&a, &b).unwrap();

        assert_relative_eq!(result, 32.0, epsilon = 1e-6);
    }

    #[test]
    fn test_scale_vector() {
        let mut vector = vec![1.0, 2.0, 3.0];
        scale_vector(&mut vector, 2.0);

        assert_relative_eq!(vector[0], 2.0);
        assert_relative_eq!(vector[1], 4.0);
        assert_relative_eq!(vector[2], 6.0);
    }

    #[test]
    fn test_normalize_vector() {
        let mut vector = vec![3.0, 4.0];
        let norm = normalize_vector(&mut vector).unwrap();

        assert_relative_eq!(norm, 5.0, epsilon = 1e-6);
        assert_relative_eq!(l2_norm(&vector), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_causal_mask() {
        let mut scores = vec![0.0; 9];
        apply_causal_mask(&mut scores, 3, 3).unwrap();

        // Strictly-upper-triangular entries are masked out...
        assert_eq!(scores[1], f32::NEG_INFINITY);
        assert_eq!(scores[2], f32::NEG_INFINITY);
        assert_eq!(scores[5], f32::NEG_INFINITY);
        // ...while the diagonal and lower triangle are untouched.
        assert_eq!(scores[0], 0.0);
        assert_eq!(scores[4], 0.0);
        assert_eq!(scores[8], 0.0);
    }
}