1use crate::error::OptimizeError;
9use crate::unconstrained::{minimize, Method, Options};
10use scirs2_core::ndarray::{Array1, ArrayView1};
11
/// Configuration for [`minimize_augmented_lagrangian`] and its convenience wrappers.
#[derive(Debug, Clone)]
pub struct AugmentedLagrangianOptions {
    /// Maximum number of outer (multiplier-update) iterations.
    pub max_iter: usize,
    /// Successful termination requires the L1 constraint violation
    /// (`sum |c_eq| + sum max(0, c_ineq)`) to be strictly below this value.
    pub constraint_tol: f64,
    /// Successful termination requires the reported `optimality` measure to be strictly
    /// below this value.
    pub optimality_tol: f64,
    /// Penalty parameter used for the first outer iteration.
    pub initial_penalty: f64,
    /// Factor by which the penalty parameter is multiplied when it is increased.
    pub penalty_update_factor: f64,
    /// Upper bound applied whenever the penalty parameter is increased.
    pub max_penalty: f64,
    /// NOTE(review): not read by the solver code in this module; appears to have no
    /// effect — confirm before relying on it.
    pub multiplier_update_tol: f64,
    /// When `adaptive_penalty` is set, the penalty is increased after any outer iteration
    /// whose constraint violation exceeds this threshold.
    pub max_constraint_violation: f64,
    /// Method used to solve each inner unconstrained subproblem.
    pub unconstrained_method: Method,
    /// Options passed (cloned) to every inner unconstrained solve.
    pub unconstrained_options: Options,
    /// NOTE(review): not read by the solver code in this module; appears to have no
    /// effect — confirm before relying on it.
    pub trust_radius: Option<f64>,
    /// Whether to grow the penalty parameter when the constraint violation stays above
    /// `max_constraint_violation`.
    pub adaptive_penalty: bool,
}
40
41impl Default for AugmentedLagrangianOptions {
42 fn default() -> Self {
43 Self {
44 max_iter: 100,
45 constraint_tol: 1e-6,
46 optimality_tol: 1e-6,
47 initial_penalty: 1.0,
48 penalty_update_factor: 10.0,
49 max_penalty: 1e8,
50 multiplier_update_tol: 1e-8,
51 max_constraint_violation: 1e-3,
52 unconstrained_method: Method::LBFGS,
53 unconstrained_options: Options::default(),
54 trust_radius: None,
55 adaptive_penalty: true,
56 }
57 }
58}
59
/// Outcome of an augmented Lagrangian solve.
#[derive(Debug, Clone)]
pub struct AugmentedLagrangianResult {
    /// Final iterate.
    pub x: Array1<f64>,
    /// Objective value at `x`.
    pub fun: f64,
    /// Equality-constraint multiplier estimates; `None` when there are no equality constraints.
    pub lambda_eq: Option<Array1<f64>>,
    /// Inequality-constraint multiplier estimates; `None` when there are no inequality constraints.
    pub lambda_ineq: Option<Array1<f64>>,
    /// Outer iteration count; equals `max_iter` when the iteration limit was reached.
    pub nit: usize,
    /// Sum of the `func_evals` reported by the inner unconstrained solves. Evaluations made
    /// by the outer loop itself (e.g. for the optimality measure) are not counted.
    pub nfev: usize,
    /// `true` when both the constraint-violation and optimality tolerances were met.
    pub success: bool,
    /// Human-readable termination reason.
    pub message: String,
    /// Penalty parameter in effect at termination.
    pub penalty: f64,
    /// L1 constraint violation at `x`: `sum |c_eq| + sum max(0, c_ineq)`.
    pub constraint_violation: f64,
    /// First-order optimality measure at `x`, computed by `compute_optimality` (L1 norm of a
    /// finite-difference gradient).
    pub optimality: f64,
}
86
/// Solver state carried between outer iterations of the augmented Lagrangian loop.
struct AugmentedLagrangianState {
    /// Current penalty parameter.
    penalty: f64,
    /// Current equality multipliers; `None` when there are no equality constraints.
    lambda_eq: Option<Array1<f64>>,
    /// Current inequality multipliers; `None` when there are no inequality constraints.
    lambda_ineq: Option<Array1<f64>>,
    /// Number of decision variables (stored by `new`; not read by the visible code).
    #[allow(dead_code)]
    n: usize,
    /// Number of equality constraints (stored by `new`; not read by the visible code).
    #[allow(dead_code)]
    m_eq: usize,
    /// Number of inequality constraints (stored by `new`; not read by the visible code).
    #[allow(dead_code)]
    m_ineq: usize,
}
105
106impl AugmentedLagrangianState {
107 fn new(n: usize, m_eq: usize, m_ineq: usize, initial_penalty: f64) -> Self {
108 Self {
109 penalty: initial_penalty,
110 lambda_eq: if m_eq > 0 {
111 Some(Array1::zeros(m_eq))
112 } else {
113 None
114 },
115 lambda_ineq: if m_ineq > 0 {
116 Some(Array1::zeros(m_ineq))
117 } else {
118 None
119 },
120 n,
121 m_eq,
122 m_ineq,
123 }
124 }
125
126 fn update_multipliers(&mut self, c_eq: &Option<Array1<f64>>, c_ineq: &Option<Array1<f64>>) {
128 if let (Some(ref mut lambda), Some(ref c)) = (&mut self.lambda_eq, c_eq) {
130 for i in 0..lambda.len() {
131 lambda[i] += self.penalty * c[i];
132 }
133 }
134
135 if let (Some(ref mut lambda), Some(ref c)) = (&mut self.lambda_ineq, c_ineq) {
137 for i in 0..lambda.len() {
138 lambda[i] = f64::max(0.0, lambda[i] + self.penalty * c[i]);
139 }
140 }
141 }
142
143 fn compute_constraint_violation(
145 &self,
146 c_eq: &Option<Array1<f64>>,
147 c_ineq: &Option<Array1<f64>>,
148 ) -> f64 {
149 let mut violation = 0.0;
150
151 if let Some(ref c) = c_eq {
153 violation += c.mapv(|x| x.abs()).sum();
154 }
155
156 if let Some(ref c) = c_ineq {
158 violation += c.mapv(|x| f64::max(0.0, x)).sum();
159 }
160
161 violation
162 }
163}
164
165#[allow(dead_code)]
167pub fn minimize_augmented_lagrangian<F, EqCon, IneqCon>(
168 fun: F,
169 x0: Array1<f64>,
170 eq_constraints: Option<EqCon>,
171 ineq_constraints: Option<IneqCon>,
172 options: Option<AugmentedLagrangianOptions>,
173) -> Result<AugmentedLagrangianResult, OptimizeError>
174where
175 F: Fn(&ArrayView1<f64>) -> f64 + Clone,
176 EqCon: Fn(&ArrayView1<f64>) -> Array1<f64> + Clone,
177 IneqCon: Fn(&ArrayView1<f64>) -> Array1<f64> + Clone,
178{
179 let options = options.unwrap_or_default();
180 let n = x0.len();
181
182 let m_eq = if let Some(ref eq_con) = eq_constraints {
184 eq_con(&x0.view()).len()
186 } else {
187 0
188 };
189
190 let m_ineq = if let Some(ref ineq_con) = ineq_constraints {
191 ineq_con(&x0.view()).len()
193 } else {
194 0
195 };
196
197 let mut state = AugmentedLagrangianState::new(n, m_eq, m_ineq, options.initial_penalty);
198 let mut x = x0.clone();
199 let mut total_nfev = 0;
200
201 for iter in 0..options.max_iter {
202 let penalty = state.penalty;
204 let lambda_eq = state.lambda_eq.clone();
205 let lambda_ineq = state.lambda_ineq.clone();
206 let fun_clone = fun.clone();
207 let eq_con_clone = eq_constraints.clone();
208 let ineq_con_clone = ineq_constraints.clone();
209
210 let augmented_lagrangian = move |x: &ArrayView1<f64>| -> f64 {
211 let mut result = fun_clone(x);
212
213 if let (Some(ref eq_con), Some(ref lambda)) = (&eq_con_clone, &lambda_eq) {
215 let c_eq = eq_con(x);
216 for i in 0..c_eq.len() {
217 result += lambda[i] * c_eq[i] + 0.5 * penalty * c_eq[i].powi(2);
218 }
219 }
220
221 if let (Some(ref ineq_con), Some(ref lambda)) = (&ineq_con_clone, &lambda_ineq) {
223 let c_ineq = ineq_con(x);
224 for i in 0..c_ineq.len() {
225 let augmented_term = lambda[i] + penalty * c_ineq[i];
226 if augmented_term > 0.0 {
227 result += augmented_term * c_ineq[i] - 0.5 / penalty * lambda[i].powi(2);
228 } else {
229 result -= 0.5 / penalty * lambda[i].powi(2);
230 }
231 }
232 }
233
234 result
235 };
236
237 let result = minimize(
239 augmented_lagrangian,
240 x.as_slice().unwrap(),
241 options.unconstrained_method,
242 Some(options.unconstrained_options.clone()),
243 )?;
244
245 x = result.x;
246 total_nfev += result.func_evals;
247
248 let c_eq = eq_constraints.as_ref().map(|f| f(&x.view()));
250 let c_ineq = ineq_constraints.as_ref().map(|f| f(&x.view()));
251
252 let constraint_violation = state.compute_constraint_violation(&c_eq, &c_ineq);
254 let optimality = compute_optimality(&fun, &x, &c_eq, &c_ineq, &state);
255
256 if constraint_violation < options.constraint_tol && optimality < options.optimality_tol {
258 let final_fun = fun(&x.view());
259 return Ok(AugmentedLagrangianResult {
260 x,
261 fun: final_fun,
262 lambda_eq: state.lambda_eq.clone(),
263 lambda_ineq: state.lambda_ineq.clone(),
264 nit: iter,
265 nfev: total_nfev,
266 success: true,
267 message: "Optimization terminated successfully.".to_string(),
268 penalty: state.penalty,
269 constraint_violation,
270 optimality,
271 });
272 }
273
274 state.update_multipliers(&c_eq, &c_ineq);
276
277 if options.adaptive_penalty && constraint_violation > options.max_constraint_violation {
279 state.penalty =
280 (state.penalty * options.penalty_update_factor).min(options.max_penalty);
281 }
282 }
283
284 let c_eq = eq_constraints.as_ref().map(|f| f(&x.view()));
286 let c_ineq = ineq_constraints.as_ref().map(|f| f(&x.view()));
287 let final_violation = state.compute_constraint_violation(&c_eq, &c_ineq);
288 let final_optimality = compute_optimality(&fun, &x, &c_eq, &c_ineq, &state);
289
290 let final_fun = fun(&x.view());
291 Ok(AugmentedLagrangianResult {
292 x,
293 fun: final_fun,
294 lambda_eq: state.lambda_eq,
295 lambda_ineq: state.lambda_ineq,
296 nit: options.max_iter,
297 nfev: total_nfev,
298 success: false,
299 message: "Maximum iterations reached.".to_string(),
300 penalty: state.penalty,
301 constraint_violation: final_violation,
302 optimality: final_optimality,
303 })
304}
305
306#[allow(dead_code)]
308fn create_augmented_lagrangian<'a, F, EqCon, IneqCon>(
309 fun: &'a F,
310 eq_constraints: &'a Option<EqCon>,
311 ineq_constraints: &'a Option<IneqCon>,
312 state: &'a AugmentedLagrangianState,
313) -> impl Fn(&ArrayView1<f64>) -> f64 + 'a
314where
315 F: Fn(&ArrayView1<f64>) -> f64,
316 EqCon: Fn(&ArrayView1<f64>) -> Array1<f64>,
317 IneqCon: Fn(&ArrayView1<f64>) -> Array1<f64>,
318{
319 move |x: &ArrayView1<f64>| -> f64 {
320 let mut result = fun(x);
321
322 if let (Some(ref eq_con), Some(ref lambda_eq)) = (eq_constraints, &state.lambda_eq) {
324 let c_eq = eq_con(x);
325 for i in 0..c_eq.len() {
326 result += lambda_eq[i] * c_eq[i] + 0.5 * state.penalty * c_eq[i].powi(2);
327 }
328 }
329
330 if let (Some(ref ineq_con), Some(ref lambda_ineq)) = (ineq_constraints, &state.lambda_ineq)
332 {
333 let c_ineq = ineq_con(x);
334 for i in 0..c_ineq.len() {
335 let augmented_term = lambda_ineq[i] + state.penalty * c_ineq[i];
336 if augmented_term > 0.0 {
337 result +=
338 augmented_term * c_ineq[i] - 0.5 / state.penalty * lambda_ineq[i].powi(2);
339 } else {
340 result -= 0.5 / state.penalty * lambda_ineq[i].powi(2);
341 }
342 }
343 }
344
345 result
346 }
347}
348
349#[allow(dead_code)]
351fn compute_optimality<F>(
352 fun: &F,
353 x: &Array1<f64>,
354 _c_eq: &Option<Array1<f64>>,
355 _c_ineq: &Option<Array1<f64>>,
356 _state: &AugmentedLagrangianState,
357) -> f64
358where
359 F: Fn(&ArrayView1<f64>) -> f64,
360{
361 let eps = 1e-8;
363 let mut grad = Array1::zeros(x.len());
364 let f0 = fun(&x.view());
365
366 for i in 0..x.len() {
367 let mut x_plus = x.clone();
368 x_plus[i] += eps;
369 let f_plus = fun(&x_plus.view());
370 grad[i] = (f_plus - f0) / eps;
371 }
372
373 grad.mapv(|x| x.abs()).sum()
376}
377
378#[allow(dead_code)]
380pub fn minimize_equality_constrained<F, EqCon>(
381 fun: F,
382 x0: Array1<f64>,
383 eq_constraints: EqCon,
384 options: Option<AugmentedLagrangianOptions>,
385) -> Result<AugmentedLagrangianResult, OptimizeError>
386where
387 F: Fn(&ArrayView1<f64>) -> f64 + Clone,
388 EqCon: Fn(&ArrayView1<f64>) -> Array1<f64> + Clone,
389{
390 minimize_augmented_lagrangian(
391 fun,
392 x0,
393 Some(eq_constraints),
394 None::<fn(&ArrayView1<f64>) -> Array1<f64>>,
395 options,
396 )
397}
398
399#[allow(dead_code)]
401pub fn minimize_inequality_constrained<F, IneqCon>(
402 fun: F,
403 x0: Array1<f64>,
404 ineq_constraints: IneqCon,
405 options: Option<AugmentedLagrangianOptions>,
406) -> Result<AugmentedLagrangianResult, OptimizeError>
407where
408 F: Fn(&ArrayView1<f64>) -> f64 + Clone,
409 IneqCon: Fn(&ArrayView1<f64>) -> Array1<f64> + Clone,
410{
411 minimize_augmented_lagrangian(
412 fun,
413 x0,
414 None::<fn(&ArrayView1<f64>) -> Array1<f64>>,
415 Some(ineq_constraints),
416 options,
417 )
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_augmented_lagrangian_equality_constraint() {
        // minimize x^2 + y^2 subject to x + y - 1 = 0; solution (0.5, 0.5), f* = 0.5.
        let fun = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let eq_con = |x: &ArrayView1<f64>| array![x[0] + x[1] - 1.0];

        let x0 = array![0.0, 0.0];
        let options = AugmentedLagrangianOptions {
            max_iter: 50,
            constraint_tol: 1e-6,
            optimality_tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_equality_constrained(fun, x0, eq_con, Some(options)).unwrap();

        assert!(result.nit > 0);
        assert_abs_diff_eq!(result.x[0], 0.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 0.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.5, epsilon = 1e-3);
    }

    #[test]
    fn test_augmented_lagrangian_inequality_constraint() {
        // minimize x^2 + y^2 subject to 1 - x - y <= 0 (i.e. x + y >= 1);
        // the constraint is active at the solution (0.5, 0.5), f* = 0.5.
        let fun = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let ineq_con = |x: &ArrayView1<f64>| array![1.0 - x[0] - x[1]];

        let x0 = array![2.0, 2.0];
        let options = AugmentedLagrangianOptions {
            max_iter: 50,
            constraint_tol: 1e-6,
            optimality_tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_inequality_constrained(fun, x0, ineq_con, Some(options)).unwrap();

        assert!(result.nit > 0);
        assert_abs_diff_eq!(result.x[0], 0.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 0.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.5, epsilon = 1e-3);
    }

    #[test]
    fn test_augmented_lagrangian_mixed_constraints() {
        // minimize (x - 1)^2 + (y - 2)^2 subject to x + y - 3 = 0 and -x <= 0.
        // The unconstrained minimizer (1, 2) is feasible, so it is the solution with f* = 0.
        let fun = |x: &ArrayView1<f64>| (x[0] - 1.0).powi(2) + (x[1] - 2.0).powi(2);
        let eq_con = |x: &ArrayView1<f64>| array![x[0] + x[1] - 3.0];
        let ineq_con = |x: &ArrayView1<f64>| array![-x[0]];

        let x0 = array![1.0, 1.0];
        let options = AugmentedLagrangianOptions {
            max_iter: 50,
            constraint_tol: 1e-6,
            optimality_tol: 1e-6,
            ..Default::default()
        };

        let result =
            minimize_augmented_lagrangian(fun, x0, Some(eq_con), Some(ineq_con), Some(options))
                .unwrap();

        assert!(result.nit > 0);
        assert_abs_diff_eq!(result.x[0], 1.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 2.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-3);
    }

    #[test]
    fn test_augmented_lagrangian_unconstrained() {
        // minimize (x - 2)^2 + (y - 3)^2 with no constraints; solution (2, 3), f* = 0.
        let fun = |x: &ArrayView1<f64>| (x[0] - 2.0).powi(2) + (x[1] - 3.0).powi(2);

        let x0 = array![0.0, 0.0];
        let options = AugmentedLagrangianOptions {
            max_iter: 50,
            ..Default::default()
        };

        let result = minimize_augmented_lagrangian(
            fun,
            x0,
            None::<fn(&ArrayView1<f64>) -> Array1<f64>>,
            None::<fn(&ArrayView1<f64>) -> Array1<f64>>,
            Some(options),
        )
        .unwrap();

        assert!(result.fun < 4.0);
        // Without constraints the solve reduces to plain unconstrained minimization, so the
        // minimizer itself can be checked rather than just an improvement in `fun`.
        assert_abs_diff_eq!(result.x[0], 2.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 3.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-3);
        assert_eq!(result.constraint_violation, 0.0);
        assert!(result.lambda_eq.is_none());
        assert!(result.lambda_ineq.is_none());
    }
}