// ruvector_math/utils/numerical.rs

use super::{EPS, LOG_MAX, LOG_MIN};

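/// Numerically stable log-sum-exp: computes ln(Σᵢ exp(xᵢ)) after shifting
/// every exponent by the maximum element, so no single term can overflow.
/// An empty slice yields `f64::NEG_INFINITY`; an infinite maximum is returned
/// directly, since the shift is meaningless there and the maximum is already
/// the answer.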
#[inline]
pub fn log_sum_exp(values: &[f64]) -> f64 {
    if values.is_empty() {
        return f64::NEG_INFINITY;
    }

    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

    if max_val.is_infinite() {
        return max_val;
    }

    let sum: f64 = values.iter().map(|&x| (x - max_val).exp()).sum();
    max_val + sum.ln()
}

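/// Log-softmax: each element minus the log-sum-exp of the whole slice, so the
/// outputs exponentiate to a probability distribution.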
#[inline]
pub fn log_softmax(values: &[f64]) -> Vec<f64> {
    let lse = log_sum_exp(values);
    values.iter().map(|&x| x - lse).collect()
}

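/// Softmax with the usual max-subtraction trick; if the exponentials still
/// sum to (near) zero, falls back to a uniform distribution.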
#[inline]
pub fn softmax(values: &[f64]) -> Vec<f64> {
    if values.is_empty() {
        return vec![];
    }

    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exp_vals: Vec<f64> = values.iter().map(|&x| (x - max_val).exp()).collect();
    let sum: f64 = exp_vals.iter().sum();

    if sum < EPS {
        vec![1.0 / values.len() as f64; values.len()]
    } else {
        exp_vals.iter().map(|&e| e / sum).collect()
    }
}

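/// Clamps a log-domain value into `[LOG_MIN, LOG_MAX]`.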
#[inline]
pub fn clamp_log(x: f64) -> f64 {
    x.clamp(LOG_MIN, LOG_MAX)
}

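/// Natural log guarded against domain errors: non-positive inputs map to
/// `LOG_MIN`, and valid results are floored at `LOG_MIN`.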
#[inline]
pub fn safe_ln(x: f64) -> f64 {
    if x <= 0.0 {
        LOG_MIN
    } else {
        x.ln().max(LOG_MIN)
    }
}

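/// Exponential of the input clamped into `[LOG_MIN, LOG_MAX]`; assuming the
/// constants are chosen inside f64's exponent range, the result stays finite
/// and nonzero.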
#[inline]
pub fn safe_exp(x: f64) -> f64 {
    clamp_log(x).exp()
}

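/// Euclidean (L2) norm of a vector.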
#[inline]
pub fn norm(x: &[f64]) -> f64 {
    x.iter().map(|&v| v * v).sum::<f64>().sqrt()
}

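/// Dot product over paired elements; `zip` silently ignores any excess length.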
#[inline]
pub fn dot(x: &[f64], y: &[f64]) -> f64 {
    x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum()
}

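/// Squared Euclidean distance between two vectors.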
#[inline]
pub fn squared_euclidean(x: &[f64], y: &[f64]) -> f64 {
    x.iter().zip(y.iter()).map(|(&a, &b)| (a - b).powi(2)).sum()
}

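/// Euclidean distance: the square root of `squared_euclidean`.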
#[inline]
pub fn euclidean_distance(x: &[f64], y: &[f64]) -> f64 {
    squared_euclidean(x, y).sqrt()
}

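/// Returns a unit-norm copy of `x`; near-zero vectors (norm below `EPS`) are
/// returned unchanged rather than divided by zero.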
pub fn normalize(x: &[f64]) -> Vec<f64> {
    let n = norm(x);
    if n < EPS {
        x.to_vec()
    } else {
        x.iter().map(|&v| v / n).collect()
    }
}

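/// In-place variant of `normalize`; leaves near-zero vectors untouched.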
pub fn normalize_mut(x: &mut [f64]) {
    let n = norm(x);
    if n >= EPS {
        for v in x.iter_mut() {
            *v /= n;
        }
    }
}

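/// Cosine similarity, clamped to `[-1, 1]` to absorb floating-point error;
/// returns 0.0 if either vector has a near-zero norm.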
#[inline]
pub fn cosine_similarity(x: &[f64], y: &[f64]) -> f64 {
    let dot_prod = dot(x, y);
    let norm_x = norm(x);
    let norm_y = norm(y);

    if norm_x < EPS || norm_y < EPS {
        0.0
    } else {
        (dot_prod / (norm_x * norm_y)).clamp(-1.0, 1.0)
    }
}

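/// Kullback-Leibler divergence D(p ‖ q) = Σᵢ pᵢ ln(pᵢ / qᵢ), in nats.
/// Terms with pᵢ < EPS contribute zero (the 0 · ln 0 = 0 convention); a term
/// with pᵢ ≥ EPS but qᵢ < EPS makes the divergence infinite.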
pub fn kl_divergence(p: &[f64], q: &[f64]) -> f64 {
    debug_assert_eq!(p.len(), q.len());

    p.iter()
        .zip(q.iter())
        .map(|(&pi, &qi)| {
            if pi < EPS {
                0.0
            } else if qi < EPS {
                f64::INFINITY
            } else {
                pi * (pi / qi).ln()
            }
        })
        .sum()
}

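/// Symmetrized KL: the average of D(p ‖ q) and D(q ‖ p).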
pub fn symmetric_kl(p: &[f64], q: &[f64]) -> f64 {
    (kl_divergence(p, q) + kl_divergence(q, p)) / 2.0
}

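/// Jensen-Shannon divergence: the average KL divergence of `p` and `q` from
/// their midpoint m = (p + q) / 2; symmetric and bounded above by ln 2.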
pub fn jensen_shannon(p: &[f64], q: &[f64]) -> f64 {
    let m: Vec<f64> = p.iter().zip(q.iter()).map(|(&pi, &qi)| (pi + qi) / 2.0).collect();
    (kl_divergence(p, &m) + kl_divergence(q, &m)) / 2.0
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_sum_exp() {
        let values = vec![1.0, 2.0, 3.0];
        let result = log_sum_exp(&values);

        let expected = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp()).ln();
        assert!((result - expected).abs() < 1e-10);
    }

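    // A sketched extra check, not in the original suite: log_softmax outputs
    // should exponentiate to a distribution summing to 1.
    #[test]
    fn test_log_softmax_normalizes() {
        let values = vec![1.0, 2.0, 3.0];
        let result = log_softmax(&values);

        let prob_sum: f64 = result.iter().map(|&x| x.exp()).sum();
        assert!((prob_sum - 1.0).abs() < 1e-10);
    }
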
    #[test]
    fn test_softmax() {
        let values = vec![1.0, 2.0, 3.0];
        let result = softmax(&values);

        let sum: f64 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);

        assert!(result[2] > result[1]);
        assert!(result[1] > result[0]);
    }

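    // Hypothetical extra coverage for the inner-product helpers; expected
    // values are hand-computed (1·4 + 2·5 + 3·6 = 32, and a 3-4-5 triangle).
    #[test]
    fn test_dot_and_distances() {
        assert!((dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]) - 32.0).abs() < 1e-10);

        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert!((squared_euclidean(&a, &b) - 25.0).abs() < 1e-10);
        assert!((euclidean_distance(&a, &b) - 5.0).abs() < 1e-10);
    }
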
    #[test]
    fn test_normalize() {
        let x = vec![3.0, 4.0];
        let n = normalize(&x);

        assert!((n[0] - 0.6).abs() < 1e-10);
        assert!((n[1] - 0.8).abs() < 1e-10);

        let norm_result = norm(&n);
        assert!((norm_result - 1.0).abs() < 1e-10);
    }

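    // Sketch of a cosine_similarity check: parallel vectors score 1,
    // orthogonal vectors 0, and a zero vector falls back to 0.
    #[test]
    fn test_cosine_similarity() {
        let x = vec![2.0, 0.0];
        assert!((cosine_similarity(&x, &[4.0, 0.0]) - 1.0).abs() < 1e-10);
        assert!(cosine_similarity(&x, &[0.0, 3.0]).abs() < 1e-10);
        assert!(cosine_similarity(&x, &[0.0, 0.0]).abs() < 1e-10);
    }
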
    #[test]
    fn test_kl_divergence() {
        let p = vec![0.25, 0.25, 0.25, 0.25];
        let q = vec![0.25, 0.25, 0.25, 0.25];

        assert!(kl_divergence(&p, &q).abs() < 1e-10);
    }
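
    // Sketch of a Jensen-Shannon check: identical distributions diverge by 0,
    // and disjoint point masses hit the ln 2 upper bound.
    #[test]
    fn test_jensen_shannon() {
        let p = vec![1.0, 0.0];
        let q = vec![0.0, 1.0];
        assert!(jensen_shannon(&p, &p).abs() < 1e-10);
        assert!((jensen_shannon(&p, &q) - 2.0_f64.ln()).abs() < 1e-10);
    }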
}