Skip to main content

sciforge_lib/maths/optimization/
gradient.rs

1use crate::constants::GOLDEN_RATIO_CONJUGATE;
2
/// Minimizes a function by fixed-step gradient descent.
///
/// Starting from `x0`, repeatedly applies `x <- x - lr * grad(x)`. Iteration stops
/// after `max_iter` updates, or earlier as soon as the Euclidean norm of the
/// gradient is below `tol`.
///
/// # Arguments
/// * `f` - objective function. Kept for interface compatibility; the iteration only
///   needs the gradient, so `f` is never evaluated.
/// * `grad` - returns the gradient at a point.
/// * `x0` - starting point.
/// * `lr` - step size (learning rate).
/// * `max_iter` - maximum number of updates.
/// * `tol` - gradient-norm threshold for early termination.
///
/// # Returns
/// The last iterate.
///
/// # Panics
/// Panics if `grad` returns fewer components than `x0.len()`.
pub fn gradient_descent(
    f: fn(&[f64]) -> f64,
    grad: fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    lr: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    // The objective value is not used by the update rule; binding it keeps the
    // parameter "used" without paying for a discarded evaluation.
    let _ = f;
    let mut x = x0.to_vec();
    let n = x.len();
    for _ in 0..max_iter {
        let g = grad(&x);
        let norm: f64 = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
        if norm < tol {
            break;
        }
        // Slicing to `n` keeps the panic on a too-short gradient instead of
        // silently updating only some coordinates.
        for (xi, gi) in x.iter_mut().zip(&g[..n]) {
            *xi -= lr * gi;
        }
    }
    x
}
25
/// Gradient descent with a classical (heavy-ball) momentum term.
///
/// Keeps a velocity vector, initially zero, and on every iteration applies
/// `v <- momentum * v - lr * grad(x)` followed by `x <- x + v`.
/// Stops after `max_iter` iterations or once the Euclidean norm of the gradient
/// is below `tol`.
///
/// # Returns
/// The last iterate.
///
/// # Panics
/// Panics if `grad` returns fewer components than `x0.len()`.
pub fn gradient_descent_momentum(
    grad: fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    lr: f64,
    momentum: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let mut position = x0.to_vec();
    let dim = position.len();
    let mut velocity: Vec<f64> = vec![0.0; dim];
    for _step in 0..max_iter {
        let gradient = grad(&position);
        let squared_norm: f64 = gradient.iter().map(|c| c * c).sum();
        if squared_norm.sqrt() < tol {
            break;
        }
        let components = &gradient[..dim];
        for ((p, vel), c) in position.iter_mut().zip(velocity.iter_mut()).zip(components) {
            *vel = momentum * *vel - lr * c;
            *p += *vel;
        }
    }
    position
}
49
/// Minimizes a function with the Adam update rule.
///
/// Maintains exponentially decayed first (`m`) and second (`v`) moment estimates
/// of the gradient, applies the bias corrections `1 - beta^t`, and steps each
/// coordinate by `lr * m_hat / (sqrt(v_hat) + 1e-8)`.
/// Stops after `max_iter` iterations or once the Euclidean norm of the gradient
/// is below `tol`.
///
/// # Arguments
/// * `grad` - returns the gradient at a point.
/// * `x0` - starting point.
/// * `lr` - step size.
/// * `beta1`, `beta2` - decay rates of the first and second moment estimates.
/// * `max_iter` - maximum number of iterations.
/// * `tol` - gradient-norm threshold for early termination.
///
/// # Returns
/// The last iterate.
///
/// # Panics
/// Panics if `grad` returns fewer components than `x0.len()`.
pub fn adam(
    grad: fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    lr: f64,
    beta1: f64,
    beta2: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let mut x = x0.to_vec();
    let n = x.len();
    let mut m: Vec<f64> = vec![0.0; n];
    let mut v: Vec<f64> = vec![0.0; n];
    let eps = 1e-8;
    for t in 1..=max_iter {
        let g = grad(&x);
        let norm: f64 = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
        if norm < tol {
            break;
        }
        // Saturate instead of wrapping: a wrapped (negative) exponent would flip the
        // sign of the correction, whereas beta^i32::MAX is already ~0 for beta < 1.
        let exponent = i32::try_from(t).unwrap_or(i32::MAX);
        // The bias corrections depend only on the iteration, not on the coordinate.
        let m_correction = 1.0 - beta1.powi(exponent);
        let v_correction = 1.0 - beta2.powi(exponent);
        let g = &g[..n];
        for i in 0..n {
            m[i] = beta1 * m[i] + (1.0 - beta1) * g[i];
            v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i];
            let m_hat = m[i] / m_correction;
            let v_hat = v[i] / v_correction;
            x[i] -= lr * m_hat / (v_hat.sqrt() + eps);
        }
    }
    x
}
80
/// Newton–Raphson iteration for a root of `f` in one dimension.
///
/// Each step evaluates `f` and its derivative `df` at the current point and moves
/// to `x - f(x) / df(x)`.
///
/// # Returns
/// * the new point as soon as a step is shorter than `tol`;
/// * the current point if `|df(x)| < 1e-30` (the step would be undefined) or when
///   `max_iter` steps have been taken.
pub fn newton_method_1d(
    f: fn(f64) -> f64,
    df: fn(f64) -> f64,
    x0: f64,
    max_iter: usize,
    tol: f64,
) -> f64 {
    let mut current = x0;
    let mut remaining = max_iter;
    while remaining > 0 {
        remaining -= 1;
        let value = f(current);
        let slope = df(current);
        // A vanishing derivative makes the Newton step meaningless; give up here.
        if slope.abs() < 1e-30 {
            return current;
        }
        let next = current - value / slope;
        let converged = (next - current).abs() < tol;
        if converged {
            return next;
        }
        current = next;
    }
    current
}
103
/// Finds a root of `f` by bisection on the interval between `a` and `b`.
///
/// The interval is halved until its half-width is below `tol` or `max_iter`
/// halvings have been performed. The half that shows a sign change between its
/// left value and the midpoint value is kept; otherwise the right half is kept.
/// The endpoints may be given in either order.
///
/// # Arguments
/// * `f` - function whose root is sought.
/// * `a`, `b` - interval endpoints; `f(a)` and `f(b)` are expected to have
///   opposite signs (this is not checked).
/// * `tol` - half-width at which the search stops.
/// * `max_iter` - maximum number of halvings.
///
/// # Returns
/// `a` or a midpoint if `f` is exactly zero there, otherwise the midpoint of the
/// final interval.
pub fn bisection(f: fn(f64) -> f64, mut a: f64, mut b: f64, tol: f64, max_iter: usize) -> f64 {
    // f(a) only changes when `a` moves, so carry it along instead of re-evaluating.
    let mut fa = f(a);
    if fa == 0.0 {
        return a;
    }
    for _ in 0..max_iter {
        let mid = (a + b) / 2.0;
        // `abs` so that a reversed interval (a > b) is searched instead of
        // terminating on the first iteration.
        if (b - a).abs() / 2.0 < tol {
            return mid;
        }
        let fm = f(mid);
        // An exact hit must be returned: a zero value carries no sign, and keeping
        // it as an endpoint would make the search walk away from the root.
        if fm == 0.0 {
            return mid;
        }
        // Compare signs directly; a product can underflow to zero for tiny values.
        if (fa < 0.0) != (fm < 0.0) {
            b = mid;
        } else {
            a = mid;
            fa = fm;
        }
    }
    (a + b) / 2.0
}
118
/// Secant iteration for a root of `f`, started from the two points `x0` and `x1`.
///
/// Each step intersects the line through the two most recent points
/// `(x, f(x))` with the x-axis and uses the intersection as the next point.
///
/// # Returns
/// * the new point as soon as a step is shorter than `tol`;
/// * the most recent point if the two function values differ by less than
///   `1e-30` (the secant is flat) or after `max_iter` steps.
pub fn secant_method(f: fn(f64) -> f64, x0: f64, x1: f64, max_iter: usize, tol: f64) -> f64 {
    let (mut older, mut newer) = (x0, x1);
    let mut steps_left = max_iter;
    while steps_left > 0 {
        steps_left -= 1;
        let f_newer = f(newer);
        let f_older = f(older);
        let rise = f_newer - f_older;
        // A flat secant has no usable intersection with the axis.
        if rise.abs() < 1e-30 {
            return newer;
        }
        let candidate = newer - f_newer * (newer - older) / rise;
        if (candidate - newer).abs() < tol {
            return candidate;
        }
        older = newer;
        newer = candidate;
    }
    newer
}
137
138pub fn golden_section_search(f: fn(f64) -> f64, mut a: f64, mut b: f64, tol: f64) -> f64 {
139    let mut c = b - GOLDEN_RATIO_CONJUGATE * (b - a);
140    let mut d = a + GOLDEN_RATIO_CONJUGATE * (b - a);
141    while (b - a).abs() > tol {
142        if f(c) < f(d) {
143            b = d;
144        } else {
145            a = c;
146        }
147        c = b - GOLDEN_RATIO_CONJUGATE * (b - a);
148        d = a + GOLDEN_RATIO_CONJUGATE * (b - a);
149    }
150    (a + b) / 2.0
151}
152
/// Approximates the gradient of `f` at `x` with central differences.
///
/// Component `i` is `(f(x + h·e_i) - f(x - h·e_i)) / (2h)`, where `e_i` is the
/// i-th unit vector, so `f` is evaluated `2 * x.len()` times.
///
/// # Returns
/// A vector with one partial-derivative estimate per component of `x`.
pub fn numerical_gradient(f: fn(&[f64]) -> f64, x: &[f64], h: f64) -> Vec<f64> {
    // One scratch copy is perturbed in place and restored after each coordinate.
    let mut probe = x.to_vec();
    let span = 2.0 * h;
    x.iter()
        .enumerate()
        .map(|(axis, &centre)| {
            probe[axis] = centre + h;
            let ahead = f(&probe);
            probe[axis] = centre - h;
            let behind = f(&probe);
            probe[axis] = centre;
            (ahead - behind) / span
        })
        .collect()
}
165
/// Gradient descent with Nesterov (look-ahead) momentum.
///
/// The gradient is evaluated at the look-ahead point `x + momentum * v` rather
/// than at `x`; the velocity is then updated as `v <- momentum * v - lr * g` and
/// the iterate as `x <- x + v`. The velocity starts at zero.
/// Stops after `max_iter` iterations or once the Euclidean norm of the
/// look-ahead gradient is below `tol`.
///
/// # Returns
/// The last iterate.
///
/// # Panics
/// Panics if `grad` returns fewer components than `x0.len()`.
pub fn nesterov_momentum(
    grad: fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    lr: f64,
    momentum: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let mut position = x0.to_vec();
    let dim = position.len();
    let mut velocity: Vec<f64> = vec![0.0; dim];
    for _step in 0..max_iter {
        let peek: Vec<f64> = position
            .iter()
            .zip(&velocity)
            .map(|(p, vel)| p + momentum * vel)
            .collect();
        let gradient = grad(&peek);
        let magnitude: f64 = gradient.iter().map(|c| c * c).sum::<f64>().sqrt();
        if magnitude < tol {
            break;
        }
        let components = &gradient[..dim];
        for ((p, vel), c) in position.iter_mut().zip(velocity.iter_mut()).zip(components) {
            *vel = momentum * *vel - lr * c;
            *p += *vel;
        }
    }
    position
}
191
/// Minimizes a function with the RMSProp update rule.
///
/// Keeps, per coordinate, an exponentially decayed average of squared gradients
/// (`cache <- decay * cache + (1 - decay) * g²`, starting at zero) and steps by
/// `lr * g / (sqrt(cache) + 1e-8)`.
/// Stops after `max_iter` iterations or once the Euclidean norm of the gradient
/// is below `tol`.
///
/// # Returns
/// The last iterate.
///
/// # Panics
/// Panics if `grad` returns fewer components than `x0.len()`.
pub fn rmsprop(
    grad: fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    lr: f64,
    decay: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let eps = 1e-8;
    let mut point = x0.to_vec();
    let dim = point.len();
    let mut mean_square: Vec<f64> = vec![0.0; dim];
    for _step in 0..max_iter {
        let gradient = grad(&point);
        let magnitude: f64 = gradient.iter().map(|c| c * c).sum::<f64>().sqrt();
        if magnitude < tol {
            break;
        }
        for (axis, &gi) in gradient[..dim].iter().enumerate() {
            let updated: f64 = decay * mean_square[axis] + (1.0 - decay) * gi * gi;
            mean_square[axis] = updated;
            point[axis] -= lr * gi / (updated.sqrt() + eps);
        }
    }
    point
}
217
/// Minimizes a function with the AdaGrad update rule.
///
/// Accumulates, per coordinate, the running sum of squared gradients (starting
/// at zero) and steps by `lr * g / (sqrt(sum) + 1e-8)`.
/// Stops after `max_iter` iterations or once the Euclidean norm of the gradient
/// is below `tol`.
///
/// # Returns
/// The last iterate.
///
/// # Panics
/// Panics if `grad` returns fewer components than `x0.len()`.
pub fn adagrad(
    grad: fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    lr: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let eps = 1e-8;
    let mut point = x0.to_vec();
    let dim = point.len();
    let mut squared_sum: Vec<f64> = vec![0.0; dim];
    for _step in 0..max_iter {
        let gradient = grad(&point);
        let magnitude: f64 = gradient.iter().map(|c| c * c).sum::<f64>().sqrt();
        if magnitude < tol {
            break;
        }
        for (axis, &gi) in gradient[..dim].iter().enumerate() {
            squared_sum[axis] += gi * gi;
            let scale: f64 = squared_sum[axis].sqrt() + eps;
            point[axis] -= lr * gi / scale;
        }
    }
    point
}
242
/// Backtracking line search using the sufficient-decrease (Armijo) test.
///
/// The directional derivative of `f` at `x` along `direction` is estimated with
/// central differences (step `1e-7`). Starting from `alpha0`, the step is
/// multiplied by `rho` until
/// `f(x + alpha * direction) <= f(x) + c * alpha * slope`,
/// with at most 50 trial steps.
///
/// # Returns
/// The first accepted step length, or the step reached after 50 reductions if
/// none was accepted.
///
/// # Panics
/// Panics if `direction` has fewer components than `x`.
pub fn line_search_backtracking(
    f: fn(&[f64]) -> f64,
    x: &[f64],
    direction: &[f64],
    alpha0: f64,
    c: f64,
    rho: f64,
) -> f64 {
    let dim = x.len();
    let baseline = f(x);
    let probe_step = 1e-7;

    // Directional derivative: finite-difference gradient dotted with `direction`.
    let mut probe = x.to_vec();
    let slope: f64 = (0..dim)
        .map(|axis| {
            probe[axis] = x[axis] + probe_step;
            let ahead = f(&probe);
            probe[axis] = x[axis] - probe_step;
            let behind = f(&probe);
            probe[axis] = x[axis];
            (ahead - behind) / (2.0 * probe_step) * direction[axis]
        })
        .sum();

    let mut step = alpha0;
    let mut attempts = 0;
    while attempts < 50 {
        let trial: Vec<f64> = x
            .iter()
            .zip(direction)
            .map(|(xi, di)| xi + step * di)
            .collect();
        let accepted = f(&trial) <= baseline + c * step * slope;
        if accepted {
            return step;
        }
        step *= rho;
        attempts += 1;
    }
    step
}
276
277pub fn bfgs(
278    f: fn(&[f64]) -> f64,
279    grad: fn(&[f64]) -> Vec<f64>,
280    x0: &[f64],
281    max_iter: usize,
282    tol: f64,
283) -> Vec<f64> {
284    let n = x0.len();
285    let mut x = x0.to_vec();
286    let mut h = vec![vec![0.0; n]; n];
287    for (i, hi) in h.iter_mut().enumerate() {
288        hi[i] = 1.0;
289    }
290    let mut g = grad(&x);
291    for _ in 0..max_iter {
292        let norm: f64 = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
293        if norm < tol {
294            break;
295        }
296        let direction: Vec<f64> = (0..n)
297            .map(|i| -(0..n).map(|j| h[i][j] * g[j]).sum::<f64>())
298            .collect();
299        let alpha = line_search_backtracking(f, &x, &direction, 1.0, 1e-4, 0.5);
300        let s: Vec<f64> = (0..n).map(|i| alpha * direction[i]).collect();
301        let x_new: Vec<f64> = (0..n).map(|i| x[i] + s[i]).collect();
302        let g_new = grad(&x_new);
303        let y: Vec<f64> = (0..n).map(|i| g_new[i] - g[i]).collect();
304        let sy: f64 = s.iter().zip(y.iter()).map(|(si, yi)| si * yi).sum();
305        if sy.abs() < 1e-30 {
306            x = x_new;
307            g = g_new;
308            continue;
309        }
310        let mut hs = vec![0.0; n];
311        for (i, hsi) in hs.iter_mut().enumerate() {
312            for (j, &yj) in y.iter().enumerate() {
313                *hsi += h[i][j] * yj;
314            }
315        }
316        let yhy: f64 = y.iter().zip(hs.iter()).map(|(yi, hi)| yi * hi).sum();
317        for (i, hi) in h.iter_mut().enumerate() {
318            for j in 0..n {
319                hi[j] += (sy + yhy) * s[i] * s[j] / (sy * sy) - (hs[i] * s[j] + s[i] * hs[j]) / sy;
320            }
321        }
322        x = x_new;
323        g = g_new;
324    }
325    x
326}
327
/// Minimizes a function with nonlinear conjugate gradients (Fletcher–Reeves),
/// using only gradient evaluations.
///
/// Each iteration:
/// 1. stops if the Euclidean norm of the gradient is below `tol` (or exactly zero);
/// 2. falls back to the steepest-descent direction if the current direction is
///    not a descent direction;
/// 3. chooses the step length by a secant iteration on the directional
///    derivative `d · grad(x + alpha * d)`, starting from a probe step of `0.001`,
///    with at most 20 gradient evaluations, ending early once the directional
///    derivative is below `tol` in magnitude;
/// 4. moves to `x + alpha * d` and forms the next direction as
///    `-g_new + beta * d` with `beta = |g_new|² / |g|²`.
///
/// For a quadratic objective the secant step is the exact line minimizer.
///
/// # Arguments
/// * `grad` - returns the gradient at a point.
/// * `x0` - starting point.
/// * `max_iter` - maximum number of outer iterations.
/// * `tol` - threshold on the gradient norm and on the directional derivative.
///
/// # Returns
/// The last iterate.
///
/// # Panics
/// Panics if `grad` returns fewer components than `x0.len()`.
pub fn conjugate_gradient_min(
    grad: fn(&[f64]) -> Vec<f64>,
    x0: &[f64],
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let n = x0.len();
    let mut x = x0.to_vec();
    let mut g = grad(&x);
    let mut d: Vec<f64> = g.iter().map(|gi| -gi).collect();
    for _ in 0..max_iter {
        let g_sq: f64 = g.iter().map(|gi| gi * gi).sum::<f64>();
        // The explicit zero test protects the `beta` division below when `tol <= 0`.
        if g_sq == 0.0 || g_sq.sqrt() < tol {
            break;
        }

        // Directional derivative at the current point.
        let mut slope_prev: f64 = d.iter().zip(&g).map(|(di, gi)| di * gi).sum();
        // Negated comparison on purpose: it also restarts on NaN.
        if !(slope_prev < 0.0) {
            d = g.iter().map(|gi| -gi).collect();
            slope_prev = -g_sq;
        }

        // Secant iteration for a zero of alpha -> d · grad(x + alpha * d).
        let mut alpha_prev: f64 = 0.0;
        let mut alpha: f64 = 0.001;
        for _ in 0..20 {
            let trial: Vec<f64> = (0..n).map(|i| x[i] + alpha * d[i]).collect();
            let g_trial = grad(&trial);
            let slope: f64 = d.iter().zip(&g_trial).map(|(di, gi)| di * gi).sum();
            if slope.abs() < tol {
                break;
            }
            let next: f64 = alpha - slope * (alpha - alpha_prev) / (slope - slope_prev);
            // Flat or negative curvature gives no usable estimate; keep the
            // current step instead of stepping backwards or to infinity.
            if !next.is_finite() || next <= 0.0 {
                break;
            }
            alpha_prev = alpha;
            slope_prev = slope;
            alpha = next;
        }

        for i in 0..n {
            x[i] += alpha * d[i];
        }
        let g_new = grad(&x);
        let g_new_sq: f64 = g_new.iter().map(|gi| gi * gi).sum::<f64>();
        let beta = g_new_sq / g_sq;
        d = (0..n).map(|i| -g_new[i] + beta * d[i]).collect();
        g = g_new;
    }
    x
}
367
/// Approximates the Hessian of `f` at `x` with central differences.
///
/// Entry `(i, j)` is
/// `(f(x + h·e_i + h·e_j) - f(x + h·e_i - h·e_j) - f(x - h·e_i + h·e_j) + f(x - h·e_i - h·e_j)) / (4h²)`,
/// where `e_k` is the k-th unit vector. Only the upper triangle is computed; the
/// lower triangle is filled by symmetry. `f` is evaluated four times per computed
/// entry and is not evaluated at `x` itself.
///
/// # Returns
/// An `n × n` matrix as a vector of rows, with `n = x.len()`.
pub fn hessian_numerical(f: fn(&[f64]) -> f64, x: &[f64], h: f64) -> Vec<Vec<f64>> {
    let n = x.len();
    let mut hess = vec![vec![0.0; n]; n];
    // Evaluates `f` at `x` shifted by `di` along axis `i` and by `dj` along axis `j`.
    // For `i == j` both shifts land on the same coordinate.
    let shifted = |i: usize, di: f64, j: usize, dj: f64| -> f64 {
        let mut point = x.to_vec();
        point[i] += di;
        point[j] += dj;
        f(&point)
    };
    for i in 0..n {
        for j in i..n {
            let val = (shifted(i, h, j, h) - shifted(i, h, j, -h) - shifted(i, -h, j, h)
                + shifted(i, -h, j, -h))
                / (4.0 * h * h);
            hess[i][j] = val;
            hess[j][i] = val;
        }
    }
    hess
}
394
/// Locates a minimum of `f` between `lo` and `hi` by ternary search.
///
/// On every iteration `f` is compared at the two points that split the interval
/// into thirds, and the outer third next to the worse point is discarded.
///
/// The search ends when the interval width is no larger than `tol`, or when the
/// interval can no longer shrink in floating point. The latter makes the function
/// terminate even if `tol` is zero, negative, or below the resolution of `f64`
/// at the magnitude of the endpoints.
///
/// # Returns
/// The midpoint of the final interval.
pub fn ternary_search(f: fn(f64) -> f64, mut lo: f64, mut hi: f64, tol: f64) -> f64 {
    let mut width = (hi - lo).abs();
    while width > tol {
        let third = (hi - lo) / 3.0;
        let m1 = lo + third;
        let m2 = hi - third;
        if f(m1) < f(m2) {
            hi = m2;
        } else {
            lo = m1;
        }
        let shrunk = (hi - lo).abs();
        // Negated comparison on purpose: it also stops on NaN. Once the probe
        // points round onto the endpoints no further progress is possible.
        if !(shrunk < width) {
            break;
        }
        width = shrunk;
    }
    (lo + hi) / 2.0
}
407
408pub fn newton_method_nd(f: fn(&[f64]) -> f64, x0: &[f64], max_iter: usize, tol: f64) -> Vec<f64> {
409    let n = x0.len();
410    let mut x = x0.to_vec();
411    let h = 1e-6;
412    for _ in 0..max_iter {
413        let g = numerical_gradient(f, &x, h);
414        let norm: f64 = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
415        if norm < tol {
416            break;
417        }
418        let hess = hessian_numerical(f, &x, h);
419        let dx = solve_linear_system(&hess, &g);
420        for i in 0..n {
421            x[i] -= dx[i];
422        }
423    }
424    x
425}
426
/// Solves the square linear system `a * x = b` by Gaussian elimination with
/// partial (row) pivoting followed by back substitution.
///
/// The system size is taken from `b.len()`. A column whose best pivot is smaller
/// than `1e-30` in magnitude is skipped during elimination, and during back
/// substitution a diagonal entry that small is not divided by; no error is
/// reported for such (near-)singular systems.
///
/// # Panics
/// Panics if `a` has fewer than `b.len()` rows or a row has fewer than `b.len()`
/// entries.
fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();

    // Augmented matrix [a | b], one owned row per equation.
    let mut aug: Vec<Vec<f64>> = b
        .iter()
        .enumerate()
        .map(|(i, &rhs)| {
            let mut row = a[i].clone();
            row.push(rhs);
            row
        })
        .collect();

    for col in 0..n {
        // First row at or below the diagonal with the largest magnitude in this column.
        let pivot = ((col + 1)..n).fold(col, |best, candidate| {
            if aug[candidate][col].abs() > aug[best][col].abs() {
                candidate
            } else {
                best
            }
        });
        aug.swap(col, pivot);
        if aug[col][col].abs() < 1e-30 {
            continue;
        }
        // Split so the pivot row can be read while the rows below are modified.
        let (upper, lower) = aug.split_at_mut(col + 1);
        let pivot_row = &upper[col];
        for target in lower.iter_mut() {
            let factor = target[col] / pivot_row[col];
            for k in col..=n {
                target[k] -= factor * pivot_row[k];
            }
        }
    }

    let mut solution = vec![0.0; n];
    for i in (0..n).rev() {
        let mut acc = aug[i][n];
        for j in (i + 1)..n {
            acc -= aug[i][j] * solution[j];
        }
        let diagonal = aug[i][i];
        solution[i] = if diagonal.abs() > 1e-30 {
            acc / diagonal
        } else {
            acc
        };
    }
    solution
}
466}