// scirs2_optimize/unconstrained/callback_diagnostics.rs
use crate::error::OptimizeError;
use crate::unconstrained::convergence_diagnostics::{
    ConvergenceDiagnostics, DiagnosticCollector, DiagnosticOptions, LineSearchDiagnostic,
};
use crate::unconstrained::OptimizeResult;
use scirs2_core::ndarray::{Array1, ArrayView1};
use std::cell::{Cell, RefCell};
use std::rc::Rc;
14
15pub type OptimizationCallback = Box<dyn FnMut(&CallbackInfo) -> CallbackResult>;
17
18#[derive(Debug, Clone)]
20pub struct CallbackInfo {
21 pub iteration: usize,
23 pub x: Array1<f64>,
25 pub f: f64,
27 pub grad: Array1<f64>,
29 pub step: Option<Array1<f64>>,
31 pub direction: Option<Array1<f64>>,
33 pub line_search: Option<LineSearchDiagnostic>,
35 pub elapsed_time: std::time::Duration,
37}
38
/// Decision returned by an [`OptimizationCallback`] after inspecting an iteration.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly
/// (e.g. `assert_eq!(r, CallbackResult::Continue)`), not only via `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallbackResult {
    /// Keep iterating.
    Continue,
    /// Stop the optimization without a specific reason.
    Stop,
    /// Stop the optimization and carry a human-readable reason
    /// (e.g. `minimize_with_diagnostics` uses it as the result `message`).
    StopWithMessage(&'static str),
}
49
50pub struct DiagnosticOptimizer {
52 collector: Rc<RefCell<DiagnosticCollector>>,
54 callbacks: Vec<OptimizationCallback>,
56 start_time: std::time::Instant,
58}
59
60impl DiagnosticOptimizer {
61 pub fn new(_diagnostic_options: DiagnosticOptions) -> Self {
63 Self {
64 collector: Rc::new(RefCell::new(DiagnosticCollector::new(_diagnostic_options))),
65 callbacks: Vec::new(),
66 start_time: std::time::Instant::now(),
67 }
68 }
69
70 pub fn add_callback(&mut self, callback: OptimizationCallback) {
72 self.callbacks.push(callback);
73 }
74
75 pub fn add_progress_callback(&mut self, every_n_nit: usize) {
77 let mut last_printed = 0;
78 self.add_callback(Box::new(move |info| {
79 if info.iteration >= last_printed + every_n_nit {
80 println!(
81 "Iteration {}: f = {:.6e}, |grad| = {:.6e}",
82 info.iteration,
83 info.f,
84 info.grad.mapv(|x| x.abs()).sum()
85 );
86 last_printed = info.iteration;
87 }
88 CallbackResult::Continue
89 }));
90 }
91
92 pub fn add_convergence_monitor(&mut self, patience: usize, min_improvement: f64) {
94 let mut best_f = f64::INFINITY;
95 let mut no_improvement_count = 0;
96
97 self.add_callback(Box::new(move |info| {
98 if info.f < best_f - min_improvement {
99 best_f = info.f;
100 no_improvement_count = 0;
101 } else {
102 no_improvement_count += 1;
103 }
104
105 if no_improvement_count >= patience {
106 CallbackResult::StopWithMessage("Early stopping: no _improvement")
107 } else {
108 CallbackResult::Continue
109 }
110 }));
111 }
112
113 pub fn add_time_limit(&mut self, max_duration: std::time::Duration) {
115 self.add_callback(Box::new(move |info| {
116 if info.elapsed_time > max_duration {
117 CallbackResult::StopWithMessage("Time limit exceeded")
118 } else {
119 CallbackResult::Continue
120 }
121 }));
122 }
123
124 pub fn process_iteration(&mut self, info: &CallbackInfo) -> CallbackResult {
126 if let (Some(step), Some(direction), Some(line_search)) =
128 (&info.step, &info.direction, &info.line_search)
129 {
130 self.collector.borrow_mut().record_iteration(
131 info.f,
132 &info.grad.view(),
133 &step.view(),
134 &direction.view(),
135 line_search.clone(),
136 );
137 }
138
139 for callback in &mut self.callbacks {
141 match callback(info) {
142 CallbackResult::Continue => continue,
143 result => return result,
144 }
145 }
146
147 CallbackResult::Continue
148 }
149
150 pub fn get_diagnostics(self) -> ConvergenceDiagnostics {
152 let collector = Rc::try_unwrap(self.collector)
153 .expect("Failed to unwrap Rc")
154 .into_inner();
155 collector.finalize()
156 }
157}
158
159#[allow(dead_code)]
161pub fn optimize_with_diagnostics<F, O>(
162 optimizer_fn: O,
163 fun: F,
164 x0: Array1<f64>,
165 diagnostic_options: DiagnosticOptions,
166 callbacks: Vec<OptimizationCallback>,
167) -> Result<(OptimizeResult<f64>, ConvergenceDiagnostics), OptimizeError>
168where
169 F: FnMut(&ArrayView1<f64>) -> f64 + Clone,
170 O: FnOnce(
171 F,
172 Array1<f64>,
173 &mut DiagnosticOptimizer,
174 ) -> Result<OptimizeResult<f64>, OptimizeError>,
175{
176 let mut diagnostic_optimizer = DiagnosticOptimizer::new(diagnostic_options);
177
178 for callback in callbacks {
180 diagnostic_optimizer.add_callback(callback);
181 }
182
183 let result = optimizer_fn(fun, x0, &mut diagnostic_optimizer)?;
185
186 let diagnostics = diagnostic_optimizer.get_diagnostics();
188
189 Ok((result, diagnostics))
190}
191
192#[allow(dead_code)]
194pub fn minimize_with_diagnostics<F>(
195 mut fun: F,
196 x0: Array1<f64>,
197 options: &crate::unconstrained::Options,
198 diagnostic_optimizer: &mut DiagnosticOptimizer,
199) -> Result<OptimizeResult<f64>, OptimizeError>
200where
201 F: FnMut(&ArrayView1<f64>) -> f64,
202{
203 let mut x = x0.clone();
204 let mut f = fun(&x.view());
205 let mut iteration = 0;
206
207 loop {
209 let grad = finite_diff_gradient(&mut fun, &x.view(), 1e-8);
211
212 let direction = -&grad;
214
215 let alpha = 0.1;
217 let step = alpha * &direction;
218 let x_new = &x + &step;
219 let f_new = fun(&x_new.view());
220
221 let callback_info = CallbackInfo {
223 iteration,
224 x: x.clone(),
225 f,
226 grad: grad.clone(),
227 step: Some(step.clone()),
228 direction: Some(direction.clone()),
229 line_search: Some(LineSearchDiagnostic {
230 n_fev: 1,
231 n_gev: 1,
232 alpha,
233 alpha_init: 1.0,
234 success: f_new < f,
235 wolfe_satisfied: (true, true),
236 }),
237 elapsed_time: diagnostic_optimizer.start_time.elapsed(),
238 };
239
240 match diagnostic_optimizer.process_iteration(&callback_info) {
242 CallbackResult::Continue => {}
243 CallbackResult::Stop => break,
244 CallbackResult::StopWithMessage(msg) => {
245 return Ok(OptimizeResult {
246 x,
247 fun: f,
248 nit: iteration,
249 func_evals: iteration * 2,
250 nfev: iteration * 2,
251 success: false,
252 message: msg.to_string(),
253 jacobian: Some(grad),
254 hessian: None,
255 });
256 }
257 }
258
259 if grad.mapv(|x| x.abs()).sum() < options.gtol {
261 break;
262 }
263
264 x = x_new;
266 f = f_new;
267 iteration += 1;
268
269 if iteration >= options.max_iter {
270 break;
271 }
272 }
273
274 Ok(OptimizeResult {
275 x,
276 fun: f,
277 nit: iteration,
278 func_evals: iteration * 2,
279 nfev: iteration * 2,
280 success: iteration < options.max_iter,
281 message: if iteration < options.max_iter {
282 "Optimization converged".to_string()
283 } else {
284 "Maximum iterations reached".to_string()
285 },
286 jacobian: None,
287 hessian: None,
288 })
289}
290
291#[allow(dead_code)]
293fn finite_diff_gradient<F>(fun: &mut F, x: &ArrayView1<f64>, eps: f64) -> Array1<f64>
294where
295 F: FnMut(&ArrayView1<f64>) -> f64,
296{
297 let n = x.len();
298 let mut grad = Array1::zeros(n);
299 let f0 = fun(x);
300 let mut x_pert = x.to_owned();
301
302 for i in 0..n {
303 let h = eps * (1.0 + x[i].abs());
304 x_pert[i] = x[i] + h;
305 let f_plus = fun(&x_pert.view());
306 grad[i] = (f_plus - f0) / h;
307 x_pert[i] = x[i];
308 }
309
310 grad
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::rc::Rc;
    use std::time::Duration;

    /// Builds a two-dimensional iteration snapshot with objective value `f` and the
    /// given elapsed time. When `with_line_search` is true, step, direction, and
    /// line-search data are attached so the collector records the iteration.
    fn sample_info(f: f64, elapsed: Duration, with_line_search: bool) -> CallbackInfo {
        let (step, direction, line_search) = if with_line_search {
            (
                Some(Array1::from_vec(vec![0.1, 0.1])),
                Some(Array1::from_vec(vec![-1.0, -1.0])),
                Some(LineSearchDiagnostic {
                    n_fev: 1,
                    n_gev: 1,
                    alpha: 1.0,
                    alpha_init: 1.0,
                    success: true,
                    wolfe_satisfied: (true, true),
                }),
            )
        } else {
            (None, None, None)
        };

        CallbackInfo {
            iteration: 0,
            x: Array1::zeros(2),
            f,
            grad: Array1::ones(2),
            step,
            direction,
            line_search,
            elapsed_time: elapsed,
        }
    }

    #[test]
    fn test_diagnostic_optimizer() {
        let mut optimizer = DiagnosticOptimizer::new(DiagnosticOptions::default());

        // Callbacks run on a single thread, so Rc<Cell<_>> suffices.
        let callback_called = Rc::new(Cell::new(false));
        let flag = Rc::clone(&callback_called);
        optimizer.add_callback(Box::new(move |_info| {
            flag.set(true);
            CallbackResult::Continue
        }));

        let info = sample_info(1.0, Duration::from_secs(0), true);
        let result = optimizer.process_iteration(&info);

        assert!(matches!(result, CallbackResult::Continue));
        assert!(callback_called.get());
    }

    #[test]
    fn test_early_stopping_callback() {
        let mut optimizer = DiagnosticOptimizer::new(DiagnosticOptions::default());
        optimizer.add_convergence_monitor(2, 0.1);

        let info = sample_info(1.0, Duration::from_secs(0), false);

        // First call improves on +inf; the second is the first non-improving one.
        for _ in 0..2 {
            assert!(matches!(
                optimizer.process_iteration(&info),
                CallbackResult::Continue
            ));
        }

        let result = optimizer.process_iteration(&info);
        assert!(matches!(result, CallbackResult::StopWithMessage(_)));
    }

    #[test]
    fn test_time_limit_callback() {
        let mut optimizer = DiagnosticOptimizer::new(DiagnosticOptions::default());
        optimizer.add_time_limit(Duration::from_secs(1));

        let within = sample_info(1.0, Duration::from_secs(0), false);
        assert!(matches!(
            optimizer.process_iteration(&within),
            CallbackResult::Continue
        ));

        let over = sample_info(1.0, Duration::from_secs(2), false);
        assert!(matches!(
            optimizer.process_iteration(&over),
            CallbackResult::StopWithMessage("Time limit exceeded")
        ));
    }

    #[test]
    fn test_stop_short_circuits_later_callbacks() {
        let mut optimizer = DiagnosticOptimizer::new(DiagnosticOptions::default());
        optimizer.add_callback(Box::new(|_info| CallbackResult::Stop));

        let later_called = Rc::new(Cell::new(false));
        let flag = Rc::clone(&later_called);
        optimizer.add_callback(Box::new(move |_info| {
            flag.set(true);
            CallbackResult::Continue
        }));

        let info = sample_info(1.0, Duration::from_secs(0), false);
        assert!(matches!(
            optimizer.process_iteration(&info),
            CallbackResult::Stop
        ));
        assert!(!later_called.get());
    }
}