scirs2_optimize/symbolic/line_search.rs
1//! Closed-form quadratic line-search for symbolic objectives.
2//!
3//! Wraps [`scirs2_symbolic::cas::closed_form_step`] to provide a
4//! build-once / evaluate-many interface suitable for the inner loop of
5//! gradient-descent and Newton-type optimizers in scirs2-optimize.
6//!
7//! # Sign convention
8//!
9//! The caller takes step `x ← x + α* · d` (not `x − α*·d`).
10//! For gradient descent, pass `direction[i] = −∇f_i` to obtain `α* > 0`
11//! on a strictly convex quadratic.
12//!
13//! # Example
14//!
15//! ```no_run
16//! # #[cfg(feature = "symbolic")]
17//! # {
18//! use scirs2_optimize::symbolic::line_search::SymbolicLineSearch;
19//! use scirs2_symbolic::eml::LoweredOp;
20//!
21//! // f(x) = (x - 5)²
22//! let inner = LoweredOp::Sub(
23//! Box::new(LoweredOp::Var(0)),
24//! Box::new(LoweredOp::Const(5.0)),
25//! );
26//! let f = LoweredOp::Mul(Box::new(inner.clone()), Box::new(inner));
27//!
//! // direction = +1 (a descent direction here: from x = 0, f decreases toward x = 5)
29//! let ls = SymbolicLineSearch::new(&f, &[0], &[LoweredOp::Const(1.0)])
30//! .expect("build");
31//! let alpha = ls.eval(&[0.0]).expect("eval"); // from x=0, step = 5.0
32//! assert!((alpha - 5.0).abs() < 1e-10);
33//! # }
34//! ```
35
36use scirs2_symbolic::cas::{closed_form_step, LineSearchError as SymLineSearchError};
37use scirs2_symbolic::eml::eval::{eval_real, EvalCtx};
38use scirs2_symbolic::eml::LoweredOp;
39
40// ─────────────────────────────────────────────────────────────────────────────
41// Error type
42// ─────────────────────────────────────────────────────────────────────────────
43
44/// Errors from [`SymbolicLineSearch`].
45#[derive(Debug)]
46pub enum OptLineSearchError {
47 /// Error propagated from [`scirs2_symbolic::cas::closed_form_step`].
48 Symbolic(SymLineSearchError),
49 /// Evaluation of the symbolic `α*` expression failed (domain error,
50 /// unbound variable, etc.).
51 EvalError(String),
52 /// The length of `x` at eval-time does not match the number of variables
53 /// the `α*` expression was built for.
54 DimensionMismatch {
55 /// Expected number of variables (length of `x_vars` at build time).
56 expected: usize,
57 /// Length of `x` supplied to [`SymbolicLineSearch::eval`].
58 got: usize,
59 },
60}
61
62impl std::fmt::Display for OptLineSearchError {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 Self::Symbolic(e) => write!(f, "symbolic line-search error: {e}"),
66 Self::EvalError(s) => write!(f, "line-search eval error: {s}"),
67 Self::DimensionMismatch { expected, got } => write!(
68 f,
69 "dimension mismatch: α* expression expects {expected} variables, x has {got}"
70 ),
71 }
72 }
73}
74
75impl std::error::Error for OptLineSearchError {}
76
77// ─────────────────────────────────────────────────────────────────────────────
78// SymbolicLineSearch
79// ─────────────────────────────────────────────────────────────────────────────
80
81/// Precomputed symbolic line search for a quadratic objective.
82///
83/// Build once from `f`, `x_vars`, and `direction`; then call [`Self::eval`]
84/// at each new point to get the concrete step length `α*`.
85///
86/// The symbolic expression for `α*` is built at construction time using
87/// [`closed_form_step`]; repeated evaluations only call `eval_real`.
88#[derive(Debug)]
89pub struct SymbolicLineSearch {
90 /// Symbolic expression for the optimal step length `α*(x_vars)`.
91 alpha_expr: LoweredOp,
92 /// Number of variables (`x_vars.len()`), kept for the dimension check.
93 n_vars: usize,
94}
95
96impl SymbolicLineSearch {
97 /// Build the symbolic `α*` expression from `f` and a fixed direction.
98 ///
99 /// # Arguments
100 ///
101 /// * `f` — scalar objective as a `LoweredOp`
102 /// * `x_vars` — variable indices to differentiate
103 /// * `direction` — one symbolic `LoweredOp` per entry of `x_vars`
104 ///
105 /// # Errors
106 ///
107 /// Propagates errors from [`closed_form_step`] wrapped in
108 /// [`OptLineSearchError::Symbolic`].
109 pub fn new(
110 f: &LoweredOp,
111 x_vars: &[usize],
112 direction: &[LoweredOp],
113 ) -> Result<Self, OptLineSearchError> {
114 let alpha_expr =
115 closed_form_step(f, x_vars, direction).map_err(OptLineSearchError::Symbolic)?;
116 Ok(Self {
117 alpha_expr,
118 n_vars: x_vars.len(),
119 })
120 }
121
122 /// Evaluate the step length `α*` at the concrete point `x`.
123 ///
124 /// `x` must have at least `n_vars` elements (the maximum variable index
125 /// referenced by `x_vars` at build time, + 1).
126 ///
127 /// # Errors
128 ///
129 /// * [`OptLineSearchError::DimensionMismatch`] when `x.len() < self.n_vars`
130 /// * [`OptLineSearchError::EvalError`] when the symbolic evaluation fails
131 pub fn eval(&self, x: &[f64]) -> Result<f64, OptLineSearchError> {
132 if x.len() < self.n_vars {
133 return Err(OptLineSearchError::DimensionMismatch {
134 expected: self.n_vars,
135 got: x.len(),
136 });
137 }
138 let ctx = EvalCtx::new(x);
139 eval_real(&self.alpha_expr, &ctx).map_err(|e| OptLineSearchError::EvalError(e.to_string()))
140 }
141
142 /// Access the raw symbolic `α*` expression (for inspection or re-use).
143 pub fn alpha_expr(&self) -> &LoweredOp {
144 &self.alpha_expr
145 }
146}
147
148// ─────────────────────────────────────────────────────────────────────────────
149// Tests
150// ─────────────────────────────────────────────────────────────────────────────
151
152#[cfg(test)]
153mod tests {
154 use super::*;
155
156 fn var(i: usize) -> LoweredOp {
157 LoweredOp::Var(i)
158 }
159 fn c(v: f64) -> LoweredOp {
160 LoweredOp::Const(v)
161 }
162
163 // ── Test 1: f = (x−5)², direction = +1, from x=0 → α* = 5.0 ─────────────
164 #[test]
165 fn test_shifted_quadratic_step_from_origin() {
166 // f(x) = (x - 5)^2
167 let inner = LoweredOp::Sub(Box::new(var(0)), Box::new(c(5.0)));
168 let f = LoweredOp::Mul(Box::new(inner.clone()), Box::new(inner));
169
170 let ls = SymbolicLineSearch::new(&f, &[0], &[c(1.0)]).expect("build");
171 let alpha = ls.eval(&[0.0]).expect("eval");
172
173 // g = 2*(x-5); at x=0: g=-10; dᵀHd = 1*2*1 = 2; α* = -(-10)/2 = 5.0
174 assert!((alpha - 5.0).abs() < 1e-10, "expected 5.0, got {alpha}");
175
176 // Verify that taking the step lands at the minimum.
177 // x_new = 0 + 5.0 * 1 = 5.0; f(5.0) = 0
178 let x_new = 0.0 + alpha * 1.0;
179 assert!(
180 (x_new - 5.0).abs() < 1e-10,
181 "step should land at x=5 (minimum), got x={x_new}"
182 );
183 }
184
185 // ── Test 2: degenerate direction propagates Err ───────────────────────────
186 #[test]
187 fn test_degenerate_direction_propagates() {
188 // f = x (linear — Hessian is zero everywhere)
189 let f = var(0);
190 let result = SymbolicLineSearch::new(&f, &[0], &[c(0.0)]);
191 assert!(
192 matches!(
193 result,
194 Err(OptLineSearchError::Symbolic(
195 SymLineSearchError::DegenerateDirection
196 ))
197 ),
198 "expected DegenerateDirection, got: {result:?}"
199 );
200 }
201}