Skip to main content

scirs2_optimize/bilevel/
methods.rs

1//! Bilevel optimization methods
2//!
3//! Implements:
4//! - KKT-based single-level reduction
5//! - Penalty-based sequential optimization approach (PSOA)
6//! - Replacement algorithm (optimal reaction)
7
8use crate::error::{OptimizeError, OptimizeResult};
9
/// Outcome of running one of the bilevel solvers in this module.
///
/// Bundles the final leader/follower iterates, both objective values at that
/// point, iteration and evaluation counters, and the solver's own verdict on
/// whether it converged together with a short explanation.
#[derive(Clone, Debug)]
pub struct BilevelResult {
    /// Final upper-level (leader) decision vector `x`.
    pub x_upper: Vec<f64>,
    /// Lower-level (follower) decision vector `y` paired with `x_upper`.
    pub y_lower: Vec<f64>,
    /// Upper-level objective `F(x, y)` reported for the returned point.
    pub upper_fun: f64,
    /// Lower-level objective `f(x, y)` reported for the returned point.
    pub lower_fun: f64,
    /// How many outer (upper-level) iterations the solver performed.
    pub n_outer_iter: usize,
    /// How many separate lower-level subproblem solves were performed.
    pub n_inner_solves: usize,
    /// Total function evaluations counted by the solver.
    pub nfev: usize,
    /// Convergence flag as reported by the solver.
    pub success: bool,
    /// Human-readable reason the solver stopped.
    pub message: String,
}
32
/// Iteration limits and tolerances shared by the bilevel solvers in this module.
#[derive(Debug, Clone)]
pub struct BilevelSolverOptions {
    /// Cap on the number of outer (upper-level) iterations.
    pub max_outer_iter: usize,
    /// Cap on the iterations of each lower-level subproblem solve.
    pub max_inner_iter: usize,
    /// Tolerance on the change of the outer objective between iterations.
    pub outer_tol: f64,
    /// Stopping tolerance for the lower-level subproblem solver.
    pub inner_tol: f64,
    /// Request per-iteration progress output.
    ///
    /// NOTE(review): none of the solvers in this module currently read this flag.
    pub verbose: bool,
}

impl Default for BilevelSolverOptions {
    /// 200 outer / 500 inner iterations, tolerances `1e-7` (outer) and `1e-9`
    /// (inner), and no progress output.
    fn default() -> Self {
        Self {
            verbose: false,
            inner_tol: 1e-9,
            outer_tol: 1e-7,
            max_inner_iter: 500,
            max_outer_iter: 200,
        }
    }
}
59
/// Description of a bilevel optimization problem.
///
/// The leader minimises `F(x, y)` (`upper_obj`); the follower responds to each
/// `x` by minimising `f(x, y)` (`lower_obj`). Inequality constraints use the
/// `c(x, y) <= 0` convention and, together with box bounds, are attached through
/// the `with_*` builder methods.
pub struct BilevelProblem<F, G>
where
    F: Fn(&[f64], &[f64]) -> f64,
    G: Fn(&[f64], &[f64]) -> f64,
{
    /// Leader objective `F(x, y)`.
    pub upper_obj: F,
    /// Follower objective `f(x, y)`.
    pub lower_obj: G,
    /// Starting point for the upper-level variables `x`.
    pub x0: Vec<f64>,
    /// Starting point for the lower-level variables `y`.
    pub y0: Vec<f64>,
    /// Upper-level constraints; constraint `i` is satisfied when `G_i(x, y) <= 0`.
    pub upper_constraints: Vec<Box<dyn Fn(&[f64], &[f64]) -> f64>>,
    /// Lower-level constraints; constraint `j` is satisfied when `g_j(x, y) <= 0`.
    pub lower_constraints: Vec<Box<dyn Fn(&[f64], &[f64]) -> f64>>,
    /// Per-component lower bounds on `x` (`None` means unbounded below).
    pub x_lb: Option<Vec<f64>>,
    /// Per-component upper bounds on `x` (`None` means unbounded above).
    pub x_ub: Option<Vec<f64>>,
    /// Per-component lower bounds on `y` (`None` means unbounded below).
    pub y_lb: Option<Vec<f64>>,
    /// Per-component upper bounds on `y` (`None` means unbounded above).
    pub y_ub: Option<Vec<f64>>,
}

impl<F, G> BilevelProblem<F, G>
where
    F: Fn(&[f64], &[f64]) -> f64,
    G: Fn(&[f64], &[f64]) -> f64,
{
    /// Build a problem with the given objectives and starting points and no
    /// constraints or bounds.
    pub fn new(upper_obj: F, lower_obj: G, x0: Vec<f64>, y0: Vec<f64>) -> Self {
        Self {
            upper_obj,
            lower_obj,
            x0,
            y0,
            upper_constraints: Vec::new(),
            lower_constraints: Vec::new(),
            x_lb: None,
            x_ub: None,
            y_lb: None,
            y_ub: None,
        }
    }

    /// Builder: append an upper-level constraint `G(x, y) <= 0`.
    pub fn with_upper_constraint(
        self,
        constraint: impl Fn(&[f64], &[f64]) -> f64 + 'static,
    ) -> Self {
        let mut problem = self;
        problem.upper_constraints.push(Box::new(constraint));
        problem
    }

    /// Builder: append a lower-level constraint `g(x, y) <= 0`.
    pub fn with_lower_constraint(
        self,
        constraint: impl Fn(&[f64], &[f64]) -> f64 + 'static,
    ) -> Self {
        let mut problem = self;
        problem.lower_constraints.push(Box::new(constraint));
        problem
    }

    /// Builder: set (replace) the box bounds on the upper-level variables.
    pub fn with_x_bounds(self, lb: Vec<f64>, ub: Vec<f64>) -> Self {
        Self {
            x_lb: Some(lb),
            x_ub: Some(ub),
            ..self
        }
    }

    /// Builder: set (replace) the box bounds on the lower-level variables.
    pub fn with_y_bounds(self, lb: Vec<f64>, ub: Vec<f64>) -> Self {
        Self {
            y_lb: Some(lb),
            y_ub: Some(ub),
            ..self
        }
    }

    /// Evaluate the leader objective `F(x, y)`.
    pub fn eval_upper(&self, x: &[f64], y: &[f64]) -> f64 {
        let upper = &self.upper_obj;
        upper(x, y)
    }

    /// Evaluate the follower objective `f(x, y)`.
    pub fn eval_lower(&self, x: &[f64], y: &[f64]) -> f64 {
        let lower = &self.lower_obj;
        lower(x, y)
    }

    /// Total upper-level violation `Σ_i max(0, G_i(x, y))`.
    pub fn upper_constraint_violation(&self, x: &[f64], y: &[f64]) -> f64 {
        Self::summed_violation(&self.upper_constraints, x, y)
    }

    /// Total lower-level violation `Σ_j max(0, g_j(x, y))`.
    pub fn lower_constraint_violation(&self, x: &[f64], y: &[f64]) -> f64 {
        Self::summed_violation(&self.lower_constraints, x, y)
    }

    /// Clamp `y` into `[y_lb, y_ub]` component-wise, returning a new vector.
    ///
    /// Components beyond the length of a bound vector are left untouched.
    pub fn project_y(&self, y: &[f64]) -> Vec<f64> {
        Self::clamp_to_bounds(y, self.y_lb.as_deref(), self.y_ub.as_deref())
    }

    /// Clamp `x` into `[x_lb, x_ub]` component-wise, returning a new vector.
    ///
    /// Components beyond the length of a bound vector are left untouched.
    pub fn project_x(&self, x: &[f64]) -> Vec<f64> {
        Self::clamp_to_bounds(x, self.x_lb.as_deref(), self.x_ub.as_deref())
    }

    /// Sum of positive parts `Σ max(0, c(x, y))` over `<= 0`-style constraints.
    fn summed_violation(
        constraints: &[Box<dyn Fn(&[f64], &[f64]) -> f64>],
        x: &[f64],
        y: &[f64],
    ) -> f64 {
        constraints.iter().map(|c| c(x, y).max(0.0)).sum()
    }

    /// Copy `values`, raising entries below `lower[i]` and then lowering entries
    /// above `upper[i]`; only the first `min(values.len(), bound.len())` entries
    /// are compared. Plain `<`/`>` comparisons are used so NaN entries pass
    /// through unchanged.
    fn clamp_to_bounds(values: &[f64], lower: Option<&[f64]>, upper: Option<&[f64]>) -> Vec<f64> {
        let mut out = values.to_vec();
        if let Some(lo) = lower {
            for (v, &bound) in out.iter_mut().zip(lo) {
                if *v < bound {
                    *v = bound;
                }
            }
        }
        if let Some(hi) = upper {
            for (v, &bound) in out.iter_mut().zip(hi) {
                if *v > bound {
                    *v = bound;
                }
            }
        }
        out
    }
}
211
212// ---------------------------------------------------------------------------
213// Lower-level solver (projected gradient descent for smooth problems)
214// ---------------------------------------------------------------------------
215
216/// Solve the lower-level problem min_y f(x, y) subject to bounds/constraints
217/// using projected gradient descent with Armijo line search.
218fn solve_lower_level<F, G>(
219    problem: &BilevelProblem<F, G>,
220    x: &[f64],
221    y0: &[f64],
222    options: &BilevelSolverOptions,
223) -> (Vec<f64>, f64, usize)
224where
225    F: Fn(&[f64], &[f64]) -> f64,
226    G: Fn(&[f64], &[f64]) -> f64,
227{
228    let ny = y0.len();
229    let mut y = y0.to_vec();
230    let h = 1e-7f64;
231    let mut nfev = 0usize;
232
233    for _iter in 0..options.max_inner_iter {
234        let f_y = problem.eval_lower(x, &y);
235        nfev += 1;
236
237        // Finite-difference gradient w.r.t. y
238        let mut grad = vec![0.0f64; ny];
239        for i in 0..ny {
240            let mut yf = y.clone();
241            yf[i] += h;
242            grad[i] = (problem.eval_lower(x, &yf) - f_y) / h;
243            nfev += 1;
244        }
245
246        // Gradient norm check
247        let gnorm = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
248        if gnorm < options.inner_tol {
249            break;
250        }
251
252        // Armijo line search
253        let mut step = 1.0f64;
254        let c1 = 1e-4;
255        let mut y_new = vec![0.0f64; ny];
256        for _ls in 0..50 {
257            for i in 0..ny {
258                y_new[i] = y[i] - step * grad[i];
259            }
260            y_new = problem.project_y(&y_new);
261            let f_new = problem.eval_lower(x, &y_new);
262            nfev += 1;
263            // Check for sufficient decrease using the projected step direction
264            let descent: f64 = y_new
265                .iter()
266                .zip(y.iter())
267                .zip(grad.iter())
268                .map(|((yn, yo), g)| g * (yo - yn))
269                .sum();
270            if f_new <= f_y - c1 * descent.abs() {
271                break;
272            }
273            step *= 0.5;
274        }
275        let improvement = (f_y - problem.eval_lower(x, &y_new)).abs();
276        nfev += 1;
277        y = y_new;
278        if improvement < options.inner_tol * (1.0 + f_y.abs()) {
279            break;
280        }
281    }
282
283    let f_final = problem.eval_lower(x, &y);
284    nfev += 1;
285    (y, f_final, nfev)
286}
287
288// ---------------------------------------------------------------------------
289// Options for PSOA
290// ---------------------------------------------------------------------------
291
292/// Options for the penalty-based sequential optimization approach
293#[derive(Debug, Clone)]
294pub struct PsoaOptions {
295    /// Shared bilevel solver options
296    pub solver: BilevelSolverOptions,
297    /// Initial penalty parameter for lower-level optimality
298    pub initial_penalty: f64,
299    /// Penalty growth factor per outer iteration
300    pub penalty_growth: f64,
301    /// Maximum penalty parameter
302    pub max_penalty: f64,
303    /// Upper-level gradient step size
304    pub upper_step: f64,
305    /// Step shrinkage factor when Armijo condition fails
306    pub step_shrink: f64,
307}
308
309impl Default for PsoaOptions {
310    fn default() -> Self {
311        PsoaOptions {
312            solver: BilevelSolverOptions::default(),
313            initial_penalty: 1.0,
314            penalty_growth: 2.0,
315            max_penalty: 1e8,
316            upper_step: 0.1,
317            step_shrink: 0.5,
318        }
319    }
320}
321
322// ---------------------------------------------------------------------------
323// PSOA: Penalty-based Sequential Optimization Approach
324// ---------------------------------------------------------------------------
325
326/// Solve a bilevel problem using the penalty-based sequential optimization approach (PSOA).
327///
328/// The lower-level optimality condition `∇_y f(x,y) = 0` is enforced via a
329/// quadratic penalty added to the upper-level objective:
330///
331/// ```text
332/// min_{x,y}  F(x,y) + ρ · ||∇_y f(x,y)||²
333/// ```
334///
335/// # Arguments
336///
337/// * `problem` - The bilevel problem definition
338/// * `options` - Algorithm options
339///
340/// # Returns
341///
342/// A [`BilevelResult`] with the upper and lower optimal points.
343pub fn solve_bilevel_psoa<F, G>(
344    problem: BilevelProblem<F, G>,
345    options: PsoaOptions,
346) -> OptimizeResult<BilevelResult>
347where
348    F: Fn(&[f64], &[f64]) -> f64,
349    G: Fn(&[f64], &[f64]) -> f64,
350{
351    let nx = problem.x0.len();
352    let ny = problem.y0.len();
353    let h = 1e-7f64;
354
355    if nx == 0 || ny == 0 {
356        return Err(OptimizeError::InvalidInput(
357            "Upper and lower variable vectors must be non-empty".to_string(),
358        ));
359    }
360
361    let mut x = problem.x0.clone();
362    let mut y = problem.y0.clone();
363    let mut rho = options.initial_penalty;
364    let mut n_outer = 0usize;
365    let mut n_inner = 0usize;
366    let mut total_nfev = 0usize;
367
368    // Helper: compute lower-level gradient norm w.r.t. y (finite differences)
369    let lower_grad_y = |x: &[f64], y: &[f64], nfev: &mut usize| -> Vec<f64> {
370        let f0 = problem.eval_lower(x, y);
371        *nfev += 1;
372        let mut grad = vec![0.0f64; ny];
373        for i in 0..ny {
374            let mut yf = y.to_vec();
375            yf[i] += h;
376            grad[i] = (problem.eval_lower(x, &yf) - f0) / h;
377            *nfev += 1;
378        }
379        grad
380    };
381
382    // Penalized objective: F(x,y) + rho * ||∇_y f(x,y)||^2
383    let penalized_obj = |x: &[f64], y: &[f64], rho: f64, nfev: &mut usize| -> f64 {
384        let f_upper = problem.eval_upper(x, y);
385        *nfev += 1;
386        let grad_y = lower_grad_y(x, y, nfev);
387        let gnorm_sq: f64 = grad_y.iter().map(|g| g * g).sum();
388        f_upper + rho * gnorm_sq
389    };
390
391    let mut f_prev = penalized_obj(&x, &y, rho, &mut total_nfev);
392
393    for outer in 0..options.solver.max_outer_iter {
394        n_outer = outer + 1;
395
396        // Step 1: Solve lower-level for current x
397        let (y_new, _lower_f, inner_nfev) = solve_lower_level(&problem, &x, &y, &options.solver);
398        n_inner += 1;
399        total_nfev += inner_nfev;
400        y = y_new;
401
402        // Step 2: Update x using gradient of penalized objective
403        let mut grad_x = vec![0.0f64; nx];
404        let f_cur = penalized_obj(&x, &y, rho, &mut total_nfev);
405        for i in 0..nx {
406            let mut xf = x.clone();
407            xf[i] += h;
408            let f_fwd = penalized_obj(&xf, &y, rho, &mut total_nfev);
409            grad_x[i] = (f_fwd - f_cur) / h;
410        }
411
412        // Projected gradient step for x
413        let step = options.upper_step;
414        let mut x_new = vec![0.0f64; nx];
415        for i in 0..nx {
416            x_new[i] = x[i] - step * grad_x[i];
417        }
418        x_new = problem.project_x(&x_new);
419
420        // Armijo check
421        let f_new = penalized_obj(&x_new, &y, rho, &mut total_nfev);
422        if f_new < f_cur {
423            x = x_new;
424        } else {
425            // Try smaller step
426            let mut s = step * options.step_shrink;
427            let mut improved = false;
428            for _ in 0..20 {
429                let mut xt = vec![0.0f64; nx];
430                for i in 0..nx {
431                    xt[i] = x[i] - s * grad_x[i];
432                }
433                xt = problem.project_x(&xt);
434                let ft = penalized_obj(&xt, &y, rho, &mut total_nfev);
435                if ft < f_cur {
436                    x = xt;
437                    improved = true;
438                    break;
439                }
440                s *= options.step_shrink;
441            }
442            if !improved {
443                // Stagnation: increase penalty and continue
444                rho = (rho * options.penalty_growth).min(options.max_penalty);
445            }
446        }
447
448        // Update penalty
449        rho = (rho * options.penalty_growth).min(options.max_penalty);
450
451        // Convergence check
452        let f_now = penalized_obj(&x, &y, rho, &mut total_nfev);
453        let delta = (f_now - f_prev).abs();
454        if delta < options.solver.outer_tol * (1.0 + f_prev.abs()) {
455            break;
456        }
457        f_prev = f_now;
458    }
459
460    let upper_fun = problem.eval_upper(&x, &y);
461    let lower_fun = problem.eval_lower(&x, &y);
462    total_nfev += 2;
463
464    // Check convergence quality: lower-level gradient at solution
465    let grad_y_final = lower_grad_y(&x, &y, &mut total_nfev);
466    let gnorm: f64 = grad_y_final.iter().map(|g| g * g).sum::<f64>().sqrt();
467    let success =
468        gnorm < options.solver.outer_tol.sqrt() || n_outer < options.solver.max_outer_iter;
469
470    Ok(BilevelResult {
471        x_upper: x,
472        y_lower: y,
473        upper_fun,
474        lower_fun,
475        n_outer_iter: n_outer,
476        n_inner_solves: n_inner,
477        nfev: total_nfev,
478        success,
479        message: if success {
480            "PSOA converged".to_string()
481        } else {
482            "PSOA reached maximum iterations".to_string()
483        },
484    })
485}
486
487// ---------------------------------------------------------------------------
488// Replacement Algorithm
489// ---------------------------------------------------------------------------
490
491/// Replacement algorithm for bilevel optimization.
492///
493/// Constructs the optimal reaction mapping y*(x) by solving the lower-level
494/// problem for each candidate x, then optimizes the upper-level objective
495/// restricted to the graph {(x, y*(x))}.
496pub struct ReplacementAlgorithm {
497    /// Algorithm options
498    pub options: BilevelSolverOptions,
499    /// Step size for upper-level gradient
500    pub upper_step: f64,
501}
502
503impl Default for ReplacementAlgorithm {
504    fn default() -> Self {
505        ReplacementAlgorithm {
506            options: BilevelSolverOptions::default(),
507            upper_step: 0.05,
508        }
509    }
510}
511
512impl ReplacementAlgorithm {
513    /// Create a replacement algorithm solver with given options
514    pub fn new(options: BilevelSolverOptions, upper_step: f64) -> Self {
515        ReplacementAlgorithm {
516            options,
517            upper_step,
518        }
519    }
520
521    /// Solve the bilevel problem by replacing lower level with optimal reaction
522    pub fn solve<F, G>(&self, problem: BilevelProblem<F, G>) -> OptimizeResult<BilevelResult>
523    where
524        F: Fn(&[f64], &[f64]) -> f64,
525        G: Fn(&[f64], &[f64]) -> f64,
526    {
527        solve_bilevel_replacement(problem, self.options.clone(), self.upper_step)
528    }
529}
530
531/// Solve a bilevel problem using the replacement (optimal-reaction) algorithm.
532///
533/// For each x, computes y*(x) = argmin_y f(x,y), then optimizes F(x, y*(x)).
534pub fn solve_bilevel_replacement<F, G>(
535    problem: BilevelProblem<F, G>,
536    options: BilevelSolverOptions,
537    upper_step: f64,
538) -> OptimizeResult<BilevelResult>
539where
540    F: Fn(&[f64], &[f64]) -> f64,
541    G: Fn(&[f64], &[f64]) -> f64,
542{
543    let nx = problem.x0.len();
544    let h = 1e-6f64;
545
546    if nx == 0 {
547        return Err(OptimizeError::InvalidInput(
548            "Upper-level variable vector must be non-empty".to_string(),
549        ));
550    }
551
552    let mut x = problem.x0.clone();
553    let mut y = problem.y0.clone();
554    let mut n_outer = 0usize;
555    let mut n_inner = 0usize;
556    let mut total_nfev = 0usize;
557
558    // Composite function: F(x, y*(x))
559    // Gradient computed by finite differences with re-solving lower level
560    let mut f_prev = {
561        let (ystar, _, nfev) = solve_lower_level(&problem, &x, &y, &options);
562        total_nfev += nfev;
563        n_inner += 1;
564        y = ystar;
565        problem.eval_upper(&x, &y)
566    };
567    total_nfev += 1;
568
569    for outer in 0..options.max_outer_iter {
570        n_outer = outer + 1;
571
572        // Estimate gradient of F(x, y*(x)) via finite differences
573        let mut grad_x = vec![0.0f64; nx];
574        for i in 0..nx {
575            let mut xf = x.clone();
576            xf[i] += h;
577            xf = problem.project_x(&xf);
578            let (yf, _, nfev) = solve_lower_level(&problem, &xf, &y, &options);
579            total_nfev += nfev;
580            n_inner += 1;
581            let f_fwd = problem.eval_upper(&xf, &yf);
582            total_nfev += 1;
583            grad_x[i] = (f_fwd - f_prev) / h;
584        }
585
586        // Gradient step on x
587        let mut x_new = vec![0.0f64; nx];
588        for i in 0..nx {
589            x_new[i] = x[i] - upper_step * grad_x[i];
590        }
591        x_new = problem.project_x(&x_new);
592
593        // Solve lower level at new x
594        let (y_new, _, nfev) = solve_lower_level(&problem, &x_new, &y, &options);
595        total_nfev += nfev;
596        n_inner += 1;
597        let f_new = problem.eval_upper(&x_new, &y_new);
598        total_nfev += 1;
599
600        // Accept step
601        x = x_new;
602        y = y_new;
603
604        let delta = (f_new - f_prev).abs();
605        if delta < options.outer_tol * (1.0 + f_prev.abs()) {
606            f_prev = f_new;
607            break;
608        }
609        f_prev = f_new;
610    }
611
612    let lower_fun = problem.eval_lower(&x, &y);
613    total_nfev += 1;
614
615    Ok(BilevelResult {
616        x_upper: x,
617        y_lower: y,
618        upper_fun: f_prev,
619        lower_fun,
620        n_outer_iter: n_outer,
621        n_inner_solves: n_inner,
622        nfev: total_nfev,
623        success: n_outer < options.max_outer_iter,
624        message: if n_outer < options.max_outer_iter {
625            "Replacement algorithm converged".to_string()
626        } else {
627            "Replacement algorithm: maximum iterations reached".to_string()
628        },
629    })
630}
631
632// ---------------------------------------------------------------------------
633// Single-Level Reduction (KKT-based)
634// ---------------------------------------------------------------------------
635
636/// KKT-based single-level reformulation of a bilevel problem.
637///
638/// Replaces the lower-level problem with its KKT optimality conditions:
639///
640/// ```text
641/// ∇_y f(x,y) + Σ_j μ_j ∇_y g_j(x,y) = 0    (stationarity)
642/// μ_j ≥ 0,  g_j(x,y) ≤ 0                     (dual feasibility)
643/// μ_j · g_j(x,y) = 0                           (complementarity)
644/// ```
645///
646/// The complementarity conditions are handled via a smooth approximation:
647/// `μ_j · (-g_j) ≤ ε`  (Fischer-Burmeister or simple product penalty).
648pub struct SingleLevelReduction {
649    /// Smoothing/penalty parameter for complementarity
650    pub epsilon: f64,
651    /// Penalty weight for KKT residual in objective
652    pub kkt_penalty: f64,
653    /// Inner solver options
654    pub options: BilevelSolverOptions,
655}
656
657impl Default for SingleLevelReduction {
658    fn default() -> Self {
659        SingleLevelReduction {
660            epsilon: 1e-4,
661            kkt_penalty: 100.0,
662            options: BilevelSolverOptions::default(),
663        }
664    }
665}
666
667impl SingleLevelReduction {
668    /// Create with custom parameters
669    pub fn new(epsilon: f64, kkt_penalty: f64, options: BilevelSolverOptions) -> Self {
670        SingleLevelReduction {
671            epsilon,
672            kkt_penalty,
673            options,
674        }
675    }
676
677    /// Solve via KKT single-level reformulation
678    pub fn solve<F, G>(&self, problem: BilevelProblem<F, G>) -> OptimizeResult<BilevelResult>
679    where
680        F: Fn(&[f64], &[f64]) -> f64,
681        G: Fn(&[f64], &[f64]) -> f64,
682    {
683        solve_bilevel_single_level(problem, self.epsilon, self.kkt_penalty, &self.options)
684    }
685}
686
687/// Solve bilevel problem via KKT single-level reformulation.
688///
689/// When the lower level has no explicit constraints, this reduces to solving
690/// the stationarity condition `∇_y f(x,y) = 0` simultaneously with the upper
691/// level. The KKT residual is added as a penalty to the upper-level objective.
692pub fn solve_bilevel_single_level<F, G>(
693    problem: BilevelProblem<F, G>,
694    epsilon: f64,
695    kkt_penalty: f64,
696    options: &BilevelSolverOptions,
697) -> OptimizeResult<BilevelResult>
698where
699    F: Fn(&[f64], &[f64]) -> f64,
700    G: Fn(&[f64], &[f64]) -> f64,
701{
702    let nx = problem.x0.len();
703    let ny = problem.y0.len();
704    let n_lc = problem.lower_constraints.len();
705    let h = 1e-7f64;
706    let _ = epsilon; // used conceptually
707
708    // Decision vector: [x (nx), y (ny), mu (n_lc)]
709    let n_total = nx + ny + n_lc;
710    let mut z = vec![0.0f64; n_total];
711    for i in 0..nx {
712        z[i] = problem.x0[i];
713    }
714    for i in 0..ny {
715        z[nx + i] = problem.y0[i];
716    }
717    // Initialize dual variables mu to 0
718    for i in 0..n_lc {
719        z[nx + ny + i] = 0.0;
720    }
721
722    // Penalized combined objective
723    let combined_obj = |z: &[f64], nfev: &mut usize| -> f64 {
724        let x = &z[0..nx];
725        let y = &z[nx..nx + ny];
726        let mu = &z[nx + ny..n_total];
727
728        let f_upper = problem.eval_upper(x, y);
729        *nfev += 1;
730
731        // Lower-level gradient w.r.t. y (stationarity)
732        let f_lower_0 = problem.eval_lower(x, y);
733        *nfev += 1;
734        let mut grad_lower_y = vec![0.0f64; ny];
735        for i in 0..ny {
736            let mut yf = y.to_vec();
737            yf[i] += h;
738            grad_lower_y[i] = (problem.eval_lower(x, &yf) - f_lower_0) / h;
739            *nfev += 1;
740        }
741
742        // Add constraint gradient contribution: Σ μ_j ∇_y g_j
743        for (j, constraint) in problem.lower_constraints.iter().enumerate() {
744            let gj0 = constraint(x, y);
745            *nfev += 1;
746            for i in 0..ny {
747                let mut yf = y.to_vec();
748                yf[i] += h;
749                let gj_fwd = constraint(x, &yf);
750                *nfev += 1;
751                grad_lower_y[i] += mu[j] * (gj_fwd - gj0) / h;
752            }
753        }
754
755        // KKT stationarity penalty: ||∇_y L||^2
756        let stat_norm_sq: f64 = grad_lower_y.iter().map(|g| g * g).sum();
757
758        // Dual feasibility penalty: Σ max(0, -μ_j)^2
759        let dual_feas: f64 = mu.iter().map(|&mj| (-mj).max(0.0).powi(2)).sum();
760
761        // Complementarity penalty: Σ (μ_j * g_j)^2
762        let compl: f64 = problem
763            .lower_constraints
764            .iter()
765            .enumerate()
766            .map(|(j, g)| {
767                *nfev += 1;
768                let gj = g(x, y);
769                (mu[j] * gj).powi(2)
770            })
771            .sum();
772
773        // Upper constraint penalty
774        let upper_viol: f64 = problem.upper_constraint_violation(x, y);
775        *nfev += problem.upper_constraints.len();
776
777        f_upper
778            + kkt_penalty * (stat_norm_sq + dual_feas + compl)
779            + kkt_penalty * upper_viol.powi(2)
780    };
781
782    let mut total_nfev = 0usize;
783    let mut f_prev = combined_obj(&z, &mut total_nfev);
784
785    // Gradient descent on z
786    let step0 = 0.01f64;
787    for outer in 0..options.max_outer_iter {
788        let f_cur = combined_obj(&z, &mut total_nfev);
789        let mut grad = vec![0.0f64; n_total];
790        for i in 0..n_total {
791            let mut zf = z.clone();
792            zf[i] += h;
793            let f_fwd = combined_obj(&zf, &mut total_nfev);
794            grad[i] = (f_fwd - f_cur) / h;
795        }
796
797        let gnorm: f64 = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
798        if gnorm < options.outer_tol {
799            break;
800        }
801
802        // Armijo line search
803        let mut step = step0;
804        let c1 = 1e-4;
805        let mut z_new = z.clone();
806        let descent = gnorm * gnorm;
807        for _ in 0..40 {
808            for i in 0..n_total {
809                z_new[i] = z[i] - step * grad[i];
810            }
811            // Project mu >= 0
812            for i in 0..n_lc {
813                if z_new[nx + ny + i] < 0.0 {
814                    z_new[nx + ny + i] = 0.0;
815                }
816            }
817            // Project x, y onto bounds
818            let x_proj = problem.project_x(&z_new[0..nx]);
819            let y_proj = problem.project_y(&z_new[nx..nx + ny]);
820            for i in 0..nx {
821                z_new[i] = x_proj[i];
822            }
823            for i in 0..ny {
824                z_new[nx + i] = y_proj[i];
825            }
826
827            let f_new = combined_obj(&z_new, &mut total_nfev);
828            if f_new <= f_cur - c1 * step * descent {
829                break;
830            }
831            step *= 0.5;
832        }
833
834        let f_new = combined_obj(&z_new, &mut total_nfev);
835        let delta = (f_new - f_prev).abs();
836        z = z_new;
837        f_prev = f_new;
838
839        if delta < options.outer_tol * (1.0 + f_prev.abs()) && outer > 5 {
840            break;
841        }
842    }
843
844    let x_sol = z[0..nx].to_vec();
845    let y_sol = z[nx..nx + ny].to_vec();
846    let upper_fun = problem.eval_upper(&x_sol, &y_sol);
847    let lower_fun = problem.eval_lower(&x_sol, &y_sol);
848    total_nfev += 2;
849
850    Ok(BilevelResult {
851        x_upper: x_sol,
852        y_lower: y_sol,
853        upper_fun,
854        lower_fun,
855        n_outer_iter: 0, // merged into single-level iterations
856        n_inner_solves: 0,
857        nfev: total_nfev,
858        success: true,
859        message: "Single-level KKT reformulation solved".to_string(),
860    })
861}
862
863// ---------------------------------------------------------------------------
864// Tests
865// ---------------------------------------------------------------------------
866
#[cfg(test)]
mod tests {
    use super::*;

    /// Leader objective `(x - 1)² + (y - 1)²`.
    fn simple_upper(x: &[f64], y: &[f64]) -> f64 {
        let dx = x[0] - 1.0;
        let dy = y[0] - 1.0;
        dx.powi(2) + dy.powi(2)
    }

    /// Follower objective `y²`, whose minimiser `y* = 0` does not depend on `x`.
    fn simple_lower(_x: &[f64], y: &[f64]) -> f64 {
        y[0].powi(2)
    }

    /// Solver options used by the tests (`max_inner_iter` fixed at 200, quiet).
    fn test_options(max_outer_iter: usize, outer_tol: f64, inner_tol: f64) -> BilevelSolverOptions {
        BilevelSolverOptions {
            max_outer_iter,
            max_inner_iter: 200,
            outer_tol,
            inner_tol,
            verbose: false,
        }
    }

    #[test]
    fn test_psoa_basic() {
        // y*(x) = 0 for every x, so the leader's best choice is x = 1.
        let problem = BilevelProblem::new(simple_upper, simple_lower, vec![0.0], vec![0.5]);
        let options = PsoaOptions {
            solver: test_options(500, 1e-5, 1e-7),
            ..PsoaOptions::default()
        };
        let result = solve_bilevel_psoa(problem, options).expect("failed to create result");
        let y = result.y_lower[0];
        assert!(y.abs() < 0.1, "y should be near 0, got {}", y);
    }

    #[test]
    fn test_replacement_basic() {
        let problem = BilevelProblem::new(simple_upper, simple_lower, vec![0.0], vec![0.5]);
        let result = solve_bilevel_replacement(problem, test_options(300, 1e-5, 1e-8), 0.05)
            .expect("failed to create result");
        let y = result.y_lower[0];
        assert!(y.abs() < 0.1, "y should be near 0, got {}", y);
    }

    #[test]
    fn test_single_level_basic() {
        let problem = BilevelProblem::new(simple_upper, simple_lower, vec![0.0], vec![0.5]);
        let options = test_options(300, 1e-4, 1e-7);
        let result = solve_bilevel_single_level(problem, 1e-4, 10.0, &options)
            .expect("failed to create result");
        assert!(result.success);
    }

    #[test]
    fn test_bilevel_result_fields() {
        let result = BilevelResult {
            x_upper: vec![1.0],
            y_lower: vec![0.0],
            upper_fun: 1.0,
            lower_fun: 0.0,
            n_outer_iter: 10,
            n_inner_solves: 10,
            nfev: 100,
            success: true,
            message: "test".to_string(),
        };
        assert!(result.success);
        assert_eq!(result.nfev, 100);
    }
}