scirs2_optimize/unconstrained/
callback_diagnostics.rs

1//! Integration of convergence diagnostics with callback system
2//!
3//! This module provides utilities for integrating convergence diagnostics
4//! with optimization algorithms through a callback mechanism.
5
6use crate::error::OptimizeError;
7use crate::unconstrained::convergence_diagnostics::{
8    ConvergenceDiagnostics, DiagnosticCollector, DiagnosticOptions, LineSearchDiagnostic,
9};
10use crate::unconstrained::OptimizeResult;
11use scirs2_core::ndarray::{Array1, ArrayView1};
12use std::cell::RefCell;
13use std::rc::Rc;
14
/// Callback function type for optimization monitoring.
///
/// Invoked once per iteration with a snapshot of the current state
/// ([`CallbackInfo`]); the returned [`CallbackResult`] decides whether the
/// optimization keeps running or stops early.
pub type OptimizationCallback = Box<dyn FnMut(&CallbackInfo) -> CallbackResult>;
17
/// Information passed to callback functions on each optimization iteration.
#[derive(Debug, Clone)]
pub struct CallbackInfo {
    /// Current iteration number (0-based)
    pub iteration: usize,
    /// Current point
    pub x: Array1<f64>,
    /// Current function value at `x`
    pub f: f64,
    /// Current gradient at `x`
    pub grad: Array1<f64>,
    /// Step taken (if available; `None` when the algorithm does not report it)
    pub step: Option<Array1<f64>>,
    /// Search direction (if available)
    pub direction: Option<Array1<f64>>,
    /// Line search information (if available)
    pub line_search: Option<LineSearchDiagnostic>,
    /// Time elapsed since the `DiagnosticOptimizer` was constructed
    pub elapsed_time: std::time::Duration,
}
38
/// Result from a callback function, controlling whether optimization continues.
#[derive(Debug, Clone, Copy)]
pub enum CallbackResult {
    /// Continue optimization
    Continue,
    /// Stop optimization (no message attached)
    Stop,
    /// Stop optimization with a custom message reported to the caller
    StopWithMessage(&'static str),
}
49
/// Wrapper for optimization with diagnostic callbacks.
///
/// Collects per-iteration convergence diagnostics and dispatches user
/// callbacks via `process_iteration`.
pub struct DiagnosticOptimizer {
    /// Diagnostic collector (behind `Rc<RefCell<_>>` so it can be recorded
    /// into while the optimizer is borrowed)
    collector: Rc<RefCell<DiagnosticCollector>>,
    /// User callbacks, invoked in registration order
    callbacks: Vec<OptimizationCallback>,
    /// Start time, captured at construction; basis for `elapsed_time`
    start_time: std::time::Instant,
}
59
60impl DiagnosticOptimizer {
61    /// Create new diagnostic optimizer
62    pub fn new(_diagnostic_options: DiagnosticOptions) -> Self {
63        Self {
64            collector: Rc::new(RefCell::new(DiagnosticCollector::new(_diagnostic_options))),
65            callbacks: Vec::new(),
66            start_time: std::time::Instant::now(),
67        }
68    }
69
70    /// Add a callback
71    pub fn add_callback(&mut self, callback: OptimizationCallback) {
72        self.callbacks.push(callback);
73    }
74
75    /// Add a simple progress callback
76    pub fn add_progress_callback(&mut self, every_n_nit: usize) {
77        let mut last_printed = 0;
78        self.add_callback(Box::new(move |info| {
79            if info.iteration >= last_printed + every_n_nit {
80                println!(
81                    "Iteration {}: f = {:.6e}, |grad| = {:.6e}",
82                    info.iteration,
83                    info.f,
84                    info.grad.mapv(|x| x.abs()).sum()
85                );
86                last_printed = info.iteration;
87            }
88            CallbackResult::Continue
89        }));
90    }
91
92    /// Add a convergence monitoring callback
93    pub fn add_convergence_monitor(&mut self, patience: usize, min_improvement: f64) {
94        let mut best_f = f64::INFINITY;
95        let mut no_improvement_count = 0;
96
97        self.add_callback(Box::new(move |info| {
98            if info.f < best_f - min_improvement {
99                best_f = info.f;
100                no_improvement_count = 0;
101            } else {
102                no_improvement_count += 1;
103            }
104
105            if no_improvement_count >= patience {
106                CallbackResult::StopWithMessage("Early stopping: no _improvement")
107            } else {
108                CallbackResult::Continue
109            }
110        }));
111    }
112
113    /// Add a time limit callback
114    pub fn add_time_limit(&mut self, max_duration: std::time::Duration) {
115        self.add_callback(Box::new(move |info| {
116            if info.elapsed_time > max_duration {
117                CallbackResult::StopWithMessage("Time limit exceeded")
118            } else {
119                CallbackResult::Continue
120            }
121        }));
122    }
123
124    /// Process callbacks and update diagnostics
125    pub fn process_iteration(&mut self, info: &CallbackInfo) -> CallbackResult {
126        // Update diagnostic collector
127        if let (Some(step), Some(direction), Some(line_search)) =
128            (&info.step, &info.direction, &info.line_search)
129        {
130            self.collector.borrow_mut().record_iteration(
131                info.f,
132                &info.grad.view(),
133                &step.view(),
134                &direction.view(),
135                line_search.clone(),
136            );
137        }
138
139        // Process user callbacks
140        for callback in &mut self.callbacks {
141            match callback(info) {
142                CallbackResult::Continue => continue,
143                result => return result,
144            }
145        }
146
147        CallbackResult::Continue
148    }
149
150    /// Get final diagnostics
151    pub fn get_diagnostics(self) -> ConvergenceDiagnostics {
152        let collector = Rc::try_unwrap(self.collector)
153            .expect("Failed to unwrap Rc")
154            .into_inner();
155        collector.finalize()
156    }
157}
158
159/// Optimization wrapper that integrates diagnostics
160#[allow(dead_code)]
161pub fn optimize_with_diagnostics<F, O>(
162    optimizer_fn: O,
163    fun: F,
164    x0: Array1<f64>,
165    diagnostic_options: DiagnosticOptions,
166    callbacks: Vec<OptimizationCallback>,
167) -> Result<(OptimizeResult<f64>, ConvergenceDiagnostics), OptimizeError>
168where
169    F: FnMut(&ArrayView1<f64>) -> f64 + Clone,
170    O: FnOnce(
171        F,
172        Array1<f64>,
173        &mut DiagnosticOptimizer,
174    ) -> Result<OptimizeResult<f64>, OptimizeError>,
175{
176    let mut diagnostic_optimizer = DiagnosticOptimizer::new(diagnostic_options);
177
178    // Add user callbacks
179    for callback in callbacks {
180        diagnostic_optimizer.add_callback(callback);
181    }
182
183    // Run optimization
184    let result = optimizer_fn(fun, x0, &mut diagnostic_optimizer)?;
185
186    // Get diagnostics
187    let diagnostics = diagnostic_optimizer.get_diagnostics();
188
189    Ok((result, diagnostics))
190}
191
/// Example of integrating diagnostics into an optimization algorithm.
///
/// Runs a simple fixed-step gradient descent (step length 0.1, forward
/// finite-difference gradient) to demonstrate how an algorithm should feed
/// `CallbackInfo` snapshots into a `DiagnosticOptimizer`.
///
/// Terminates when the L1 norm of the gradient drops below `options.gtol`,
/// when `options.max_iter` is reached, or when a callback requests a stop.
#[allow(dead_code)]
pub fn minimize_with_diagnostics<F>(
    mut fun: F,
    x0: Array1<f64>,
    options: &crate::unconstrained::Options,
    diagnostic_optimizer: &mut DiagnosticOptimizer,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: FnMut(&ArrayView1<f64>) -> f64,
{
    let mut x = x0.clone();
    let mut f = fun(&x.view());
    let mut iteration = 0;

    // Simplified optimization loop for demonstration
    loop {
        // Gradient via forward finite differences (a real implementation
        // might use analytic gradients or AD instead).
        let grad = finite_diff_gradient(&mut fun, &x.view(), 1e-8);

        // Search direction: plain steepest descent (negative gradient).
        let direction = -&grad;

        // "Line search": a fixed step length, no backtracking.
        let alpha = 0.1;
        let step = alpha * &direction;
        let x_new = &x + &step;
        let f_new = fun(&x_new.view());

        // Snapshot the *current* iterate (x, f) together with the step about
        // to be taken, so callbacks and diagnostics see both.
        let callback_info = CallbackInfo {
            iteration,
            x: x.clone(),
            f,
            grad: grad.clone(),
            step: Some(step.clone()),
            direction: Some(direction.clone()),
            // NOTE(review): these counters are nominal for the demo; the
            // finite-difference gradient above actually costs n + 1
            // function evaluations, plus one more for f_new.
            line_search: Some(LineSearchDiagnostic {
                n_fev: 1,
                n_gev: 1,
                alpha,
                alpha_init: 1.0,
                success: f_new < f,
                wolfe_satisfied: (true, true),
            }),
            elapsed_time: diagnostic_optimizer.start_time.elapsed(),
        };

        // Record diagnostics and run user callbacks; a callback may abort.
        match diagnostic_optimizer.process_iteration(&callback_info) {
            CallbackResult::Continue => {}
            CallbackResult::Stop => break,
            CallbackResult::StopWithMessage(msg) => {
                // Callback-requested stop: report the pre-step iterate.
                return Ok(OptimizeResult {
                    x,
                    fun: f,
                    nit: iteration,
                    // NOTE(review): evaluation counts are rough estimates,
                    // not exact tallies of calls to `fun`.
                    func_evals: iteration * 2,
                    nfev: iteration * 2,
                    success: false,
                    message: msg.to_string(),
                    jacobian: Some(grad),
                    hessian: None,
                });
            }
        }

        // Convergence test: L1 norm of the gradient against gtol.
        if grad.mapv(|x| x.abs()).sum() < options.gtol {
            break;
        }

        // Accept the step and advance to the next iteration.
        x = x_new;
        f = f_new;
        iteration += 1;

        if iteration >= options.max_iter {
            break;
        }
    }

    Ok(OptimizeResult {
        x,
        fun: f,
        nit: iteration,
        // NOTE(review): rough estimates, as above.
        func_evals: iteration * 2,
        nfev: iteration * 2,
        success: iteration < options.max_iter,
        message: if iteration < options.max_iter {
            "Optimization converged".to_string()
        } else {
            "Maximum iterations reached".to_string()
        },
        jacobian: None,
        hessian: None,
    })
}
290
291/// Simple finite difference gradient
292#[allow(dead_code)]
293fn finite_diff_gradient<F>(fun: &mut F, x: &ArrayView1<f64>, eps: f64) -> Array1<f64>
294where
295    F: FnMut(&ArrayView1<f64>) -> f64,
296{
297    let n = x.len();
298    let mut grad = Array1::zeros(n);
299    let f0 = fun(x);
300    let mut x_pert = x.to_owned();
301
302    for i in 0..n {
303        let h = eps * (1.0 + x[i].abs());
304        x_pert[i] = x[i] + h;
305        let f_plus = fun(&x_pert.view());
306        grad[i] = (f_plus - f0) / h;
307        x_pert[i] = x[i];
308    }
309
310    grad
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_diagnostic_optimizer() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        let mut optimizer = DiagnosticOptimizer::new(DiagnosticOptions::default());

        // Flag that records whether the registered callback ever fires.
        let fired = Arc::new(AtomicBool::new(false));
        let fired_in_cb = Arc::clone(&fired);
        optimizer.add_callback(Box::new(move |_info| {
            fired_in_cb.store(true, Ordering::SeqCst);
            CallbackResult::Continue
        }));

        // Fully populated iteration snapshot so diagnostics are recorded too.
        let info = CallbackInfo {
            iteration: 0,
            x: Array1::zeros(2),
            f: 1.0,
            grad: Array1::ones(2),
            step: Some(Array1::from_vec(vec![0.1, 0.1])),
            direction: Some(Array1::from_vec(vec![-1.0, -1.0])),
            line_search: Some(LineSearchDiagnostic {
                n_fev: 1,
                n_gev: 1,
                alpha: 1.0,
                alpha_init: 1.0,
                success: true,
                wolfe_satisfied: (true, true),
            }),
            elapsed_time: std::time::Duration::ZERO,
        };

        assert!(matches!(
            optimizer.process_iteration(&info),
            CallbackResult::Continue
        ));
        assert!(fired.load(Ordering::SeqCst));
    }

    #[test]
    fn test_early_stopping_callback() {
        let mut optimizer = DiagnosticOptimizer::new(DiagnosticOptions::default());

        // Stop after 2 consecutive iterations without a 0.1 improvement.
        optimizer.add_convergence_monitor(2, 0.1);

        // Constant objective value: only the first call counts as progress.
        let stalled = CallbackInfo {
            iteration: 0,
            x: Array1::zeros(2),
            f: 1.0,
            grad: Array1::ones(2),
            step: None,
            direction: None,
            line_search: None,
            elapsed_time: std::time::Duration::ZERO,
        };

        // Call 1 records the best value; call 2 is the first stall.
        for _ in 0..2 {
            assert!(matches!(
                optimizer.process_iteration(&stalled),
                CallbackResult::Continue
            ));
        }

        // The second consecutive stall reaches the patience limit.
        assert!(matches!(
            optimizer.process_iteration(&stalled),
            CallbackResult::StopWithMessage(_)
        ));
    }
}