scirs2_optimize/unconstrained/
utils.rs1use crate::error::OptimizeError;
4use ndarray::{Array1, Array2, ArrayView1};
5
6pub fn finite_difference_gradient<F, S>(
8 fun: &mut F,
9 x: &ArrayView1<f64>,
10 step: f64,
11) -> Result<Array1<f64>, OptimizeError>
12where
13 F: FnMut(&ArrayView1<f64>) -> S,
14 S: Into<f64>,
15{
16 let n = x.len();
17 let mut grad = Array1::<f64>::zeros(n);
18 let mut x_plus = x.to_owned();
19 let mut x_minus = x.to_owned();
20
21 for i in 0..n {
22 let h = step * (1.0 + x[i].abs());
23 x_plus[i] = x[i] + h;
24 x_minus[i] = x[i] - h;
25
26 let f_plus = fun(&x_plus.view()).into();
27 let f_minus = fun(&x_minus.view()).into();
28
29 if !f_plus.is_finite() || !f_minus.is_finite() {
30 return Err(OptimizeError::ComputationError(
31 "Function returned non-finite value during gradient computation".to_string(),
32 ));
33 }
34
35 grad[i] = (f_plus - f_minus) / (2.0 * h);
36
37 x_plus[i] = x[i];
39 x_minus[i] = x[i];
40 }
41
42 Ok(grad)
43}
44
45pub fn finite_difference_hessian<F, S>(
47 fun: &mut F,
48 x: &ArrayView1<f64>,
49 step: f64,
50) -> Result<Array2<f64>, OptimizeError>
51where
52 F: FnMut(&ArrayView1<f64>) -> S,
53 S: Into<f64>,
54{
55 let n = x.len();
56 let mut hess = Array2::<f64>::zeros((n, n));
57 let mut x_temp = x.to_owned();
58
59 let f0 = fun(&x.view()).into();
60
61 for i in 0..n {
62 let hi = step * (1.0 + x[i].abs());
63
64 x_temp[i] = x[i] + hi;
66 let fp = fun(&x_temp.view()).into();
67 x_temp[i] = x[i] - hi;
68 let fm = fun(&x_temp.view()).into();
69 x_temp[i] = x[i];
70
71 hess[[i, i]] = (fp - 2.0 * f0 + fm) / (hi * hi);
72
73 for j in (i + 1)..n {
75 let hj = step * (1.0 + x[j].abs());
76
77 x_temp[i] = x[i] + hi;
78 x_temp[j] = x[j] + hj;
79 let fpp = fun(&x_temp.view()).into();
80
81 x_temp[i] = x[i] + hi;
82 x_temp[j] = x[j] - hj;
83 let fpm = fun(&x_temp.view()).into();
84
85 x_temp[i] = x[i] - hi;
86 x_temp[j] = x[j] + hj;
87 let fmp = fun(&x_temp.view()).into();
88
89 x_temp[i] = x[i] - hi;
90 x_temp[j] = x[j] - hj;
91 let fmm = fun(&x_temp.view()).into();
92
93 x_temp[i] = x[i];
94 x_temp[j] = x[j];
95
96 let hess_ij = (fpp - fpm - fmp + fmm) / (4.0 * hi * hj);
97 hess[[i, j]] = hess_ij;
98 hess[[j, i]] = hess_ij;
99 }
100 }
101
102 Ok(hess)
103}
104
/// Reports whether any one of three stopping criteria is met.
///
/// Convergence is declared when `|f_delta| < ftol`, or `x_delta < xtol`, or
/// `g_norm < gtol`. All comparisons are strict. Only `f_delta` has its absolute
/// value taken; `x_delta` and `g_norm` are compared as given, so they are
/// presumably non-negative norms (confirm at call sites). A NaN input makes its
/// own criterion false.
pub fn check_convergence(
    f_delta: f64,
    x_delta: f64,
    g_norm: f64,
    ftol: f64,
    xtol: f64,
    gtol: f64,
) -> bool {
    let objective_stalled = f_delta.abs() < ftol;
    let iterate_stalled = x_delta < xtol;
    let gradient_small = g_norm < gtol;

    objective_stalled || iterate_stalled || gradient_small
}
116
117pub fn array_diff_norm(x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
119 (x1 - x2).mapv(|x| x.powi(2)).sum().sqrt()
120}
121
122pub fn clip_step(
124 x: &ArrayView1<f64>,
125 direction: &ArrayView1<f64>,
126 alpha: f64,
127 lower: &[Option<f64>],
128 upper: &[Option<f64>],
129) -> f64 {
130 let mut clipped_alpha = alpha;
131
132 for i in 0..x.len() {
133 if direction[i] != 0.0 {
134 if let Some(lb) = lower[i] {
136 if direction[i] < 0.0 {
137 let max_step = (lb - x[i]) / direction[i];
138 if max_step >= 0.0 {
139 clipped_alpha = clipped_alpha.min(max_step);
140 }
141 }
142 }
143
144 if let Some(ub) = upper[i] {
146 if direction[i] > 0.0 {
147 let max_step = (ub - x[i]) / direction[i];
148 if max_step >= 0.0 {
149 clipped_alpha = clipped_alpha.min(max_step);
150 }
151 }
152 }
153 }
154 }
155
156 clipped_alpha.max(0.0)
157}
158
159pub fn to_array_view<T>(arr: &Array1<T>) -> ArrayView1<T> {
161 arr.view()
162}
163
/// Heuristic initial step length derived from the gradient norm.
///
/// Returns `1 / grad_norm` when `grad_norm` is strictly positive, and `1.0`
/// otherwise (zero, negative, or NaN). When `max_step` is `Some(cap)`, the
/// result is additionally limited to at most `cap` via `f64::min`.
pub fn initial_step_size(grad_norm: f64, max_step: Option<f64>) -> f64 {
    let step = if grad_norm > 0.0 {
        grad_norm.recip()
    } else {
        1.0
    };

    max_step.map_or(step, |cap| step.min(cap))
}
178
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_finite_difference_gradient() {
        let mut quadratic = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + 2.0 * x[1] * x[1] };

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let grad = finite_difference_gradient(&mut quadratic, &x.view(), 1e-8).unwrap();

        assert_abs_diff_eq!(grad[0], 2.0, epsilon = 1e-6);
        assert_abs_diff_eq!(grad[1], 8.0, epsilon = 1e-6);
    }

    #[test]
    fn test_finite_difference_gradient_rejects_non_finite() {
        let mut bad = |_x: &ArrayView1<f64>| -> f64 { f64::NAN };

        let x = Array1::from_vec(vec![1.0]);
        assert!(finite_difference_gradient(&mut bad, &x.view(), 1e-8).is_err());
    }

    #[test]
    fn test_finite_difference_hessian() {
        // f = x0^2 + 3 x0 x1 + 2 x1^2 has the constant Hessian [[2, 3], [3, 4]].
        let mut quadratic = |x: &ArrayView1<f64>| -> f64 {
            x[0] * x[0] + 3.0 * x[0] * x[1] + 2.0 * x[1] * x[1]
        };

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let hess = finite_difference_hessian(&mut quadratic, &x.view(), 1e-4).unwrap();

        assert_abs_diff_eq!(hess[[0, 0]], 2.0, epsilon = 1e-4);
        assert_abs_diff_eq!(hess[[1, 1]], 4.0, epsilon = 1e-4);
        assert_abs_diff_eq!(hess[[0, 1]], 3.0, epsilon = 1e-4);
        // Off-diagonal entries are mirrored from the same computed value.
        assert_eq!(hess[[0, 1]], hess[[1, 0]]);
    }

    #[test]
    fn test_check_convergence() {
        assert!(check_convergence(-1e-12, 1.0, 1.0, 1e-8, 1e-8, 1e-8));
        assert!(check_convergence(1.0, 1e-12, 1.0, 1e-8, 1e-8, 1e-8));
        assert!(check_convergence(1.0, 1.0, 1e-12, 1e-8, 1e-8, 1e-8));
        assert!(!check_convergence(1.0, 1.0, 1.0, 1e-8, 1e-8, 1e-8));
    }

    #[test]
    fn test_array_diff_norm() {
        let a = Array1::from_vec(vec![0.0, 0.0]);
        let b = Array1::from_vec(vec![3.0, 4.0]);

        assert_abs_diff_eq!(array_diff_norm(&a.view(), &b.view()), 5.0, epsilon = 1e-12);
    }

    #[test]
    fn test_clip_step() {
        let x = Array1::from_vec(vec![1.0, 2.0, 0.0]);
        let d = Array1::from_vec(vec![-1.0, 0.5, 0.0]);
        let lower = [Some(0.0), None, Some(0.0)];
        let upper = [None, Some(3.0), Some(0.0)];

        // Coordinate 0 reaches its lower bound at alpha = 1, coordinate 1 its
        // upper bound at alpha = 2, and coordinate 2 does not move.
        assert_abs_diff_eq!(
            clip_step(&x.view(), &d.view(), 5.0, &lower, &upper),
            1.0,
            epsilon = 1e-12
        );
        // A step that already stays inside the bounds is returned unchanged.
        assert_abs_diff_eq!(
            clip_step(&x.view(), &d.view(), 0.25, &lower, &upper),
            0.25,
            epsilon = 1e-12
        );
    }

    #[test]
    fn test_initial_step_size() {
        assert_abs_diff_eq!(initial_step_size(4.0, None), 0.25, epsilon = 1e-15);
        assert_abs_diff_eq!(initial_step_size(0.0, None), 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(initial_step_size(4.0, Some(0.1)), 0.1, epsilon = 1e-15);
    }
}