Skip to main content

scirs2_optimize/symbolic/
line_search.rs

1//! Closed-form quadratic line-search for symbolic objectives.
2//!
3//! Wraps [`scirs2_symbolic::cas::closed_form_step`] to provide a
4//! build-once / evaluate-many interface suitable for the inner loop of
5//! gradient-descent and Newton-type optimizers in scirs2-optimize.
6//!
7//! # Sign convention
8//!
9//! The caller takes step `x ← x + α* · d` (not `x − α*·d`).
10//! For gradient descent, pass `direction[i] = −∇f_i` to obtain `α* > 0`
11//! on a strictly convex quadratic.
12//!
13//! # Example
14//!
15//! ```no_run
16//! # #[cfg(feature = "symbolic")]
17//! # {
18//! use scirs2_optimize::symbolic::line_search::SymbolicLineSearch;
19//! use scirs2_symbolic::eml::LoweredOp;
20//!
21//! // f(x) = (x - 5)²
22//! let inner = LoweredOp::Sub(
23//!     Box::new(LoweredOp::Var(0)),
24//!     Box::new(LoweredOp::Const(5.0)),
25//! );
26//! let f = LoweredOp::Mul(Box::new(inner.clone()), Box::new(inner));
27//!
28//! // direction = +1 (a descent direction at x = 0, since f'(0) = −10)
29//! let ls = SymbolicLineSearch::new(&f, &[0], &[LoweredOp::Const(1.0)])
30//!     .expect("build");
31//! let alpha = ls.eval(&[0.0]).expect("eval"); // from x=0, step = 5.0
32//! assert!((alpha - 5.0).abs() < 1e-10);
33//! # }
34//! ```
35
36use scirs2_symbolic::cas::{closed_form_step, LineSearchError as SymLineSearchError};
37use scirs2_symbolic::eml::eval::{eval_real, EvalCtx};
38use scirs2_symbolic::eml::LoweredOp;
39
40// ─────────────────────────────────────────────────────────────────────────────
41// Error type
42// ─────────────────────────────────────────────────────────────────────────────
43
44/// Errors from [`SymbolicLineSearch`].
45#[derive(Debug)]
46pub enum OptLineSearchError {
47    /// Error propagated from [`scirs2_symbolic::cas::closed_form_step`].
48    Symbolic(SymLineSearchError),
49    /// Evaluation of the symbolic `α*` expression failed (domain error,
50    /// unbound variable, etc.).
51    EvalError(String),
52    /// The length of `x` at eval-time does not match the number of variables
53    /// the `α*` expression was built for.
54    DimensionMismatch {
55        /// Expected number of variables (length of `x_vars` at build time).
56        expected: usize,
57        /// Length of `x` supplied to [`SymbolicLineSearch::eval`].
58        got: usize,
59    },
60}
61
62impl std::fmt::Display for OptLineSearchError {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        match self {
65            Self::Symbolic(e) => write!(f, "symbolic line-search error: {e}"),
66            Self::EvalError(s) => write!(f, "line-search eval error: {s}"),
67            Self::DimensionMismatch { expected, got } => write!(
68                f,
69                "dimension mismatch: α* expression expects {expected} variables, x has {got}"
70            ),
71        }
72    }
73}
74
75impl std::error::Error for OptLineSearchError {}
76
77// ─────────────────────────────────────────────────────────────────────────────
78// SymbolicLineSearch
79// ─────────────────────────────────────────────────────────────────────────────
80
81/// Precomputed symbolic line search for a quadratic objective.
82///
83/// Build once from `f`, `x_vars`, and `direction`; then call [`Self::eval`]
84/// at each new point to get the concrete step length `α*`.
85///
86/// The symbolic expression for `α*` is built at construction time using
87/// [`closed_form_step`]; repeated evaluations only call `eval_real`.
88#[derive(Debug)]
89pub struct SymbolicLineSearch {
90    /// Symbolic expression for the optimal step length `α*(x_vars)`.
91    alpha_expr: LoweredOp,
92    /// Number of variables (`x_vars.len()`), kept for the dimension check.
93    n_vars: usize,
94}
95
96impl SymbolicLineSearch {
97    /// Build the symbolic `α*` expression from `f` and a fixed direction.
98    ///
99    /// # Arguments
100    ///
101    /// * `f`         — scalar objective as a `LoweredOp`
102    /// * `x_vars`    — variable indices to differentiate
103    /// * `direction` — one symbolic `LoweredOp` per entry of `x_vars`
104    ///
105    /// # Errors
106    ///
107    /// Propagates errors from [`closed_form_step`] wrapped in
108    /// [`OptLineSearchError::Symbolic`].
109    pub fn new(
110        f: &LoweredOp,
111        x_vars: &[usize],
112        direction: &[LoweredOp],
113    ) -> Result<Self, OptLineSearchError> {
114        let alpha_expr =
115            closed_form_step(f, x_vars, direction).map_err(OptLineSearchError::Symbolic)?;
116        Ok(Self {
117            alpha_expr,
118            n_vars: x_vars.len(),
119        })
120    }
121
122    /// Evaluate the step length `α*` at the concrete point `x`.
123    ///
124    /// `x` must have at least `n_vars` elements (the maximum variable index
125    /// referenced by `x_vars` at build time, + 1).
126    ///
127    /// # Errors
128    ///
129    /// * [`OptLineSearchError::DimensionMismatch`] when `x.len() < self.n_vars`
130    /// * [`OptLineSearchError::EvalError`] when the symbolic evaluation fails
131    pub fn eval(&self, x: &[f64]) -> Result<f64, OptLineSearchError> {
132        if x.len() < self.n_vars {
133            return Err(OptLineSearchError::DimensionMismatch {
134                expected: self.n_vars,
135                got: x.len(),
136            });
137        }
138        let ctx = EvalCtx::new(x);
139        eval_real(&self.alpha_expr, &ctx).map_err(|e| OptLineSearchError::EvalError(e.to_string()))
140    }
141
142    /// Access the raw symbolic `α*` expression (for inspection or re-use).
143    pub fn alpha_expr(&self) -> &LoweredOp {
144        &self.alpha_expr
145    }
146}
147
148// ─────────────────────────────────────────────────────────────────────────────
149// Tests
150// ─────────────────────────────────────────────────────────────────────────────
151
#[cfg(test)]
mod tests {
    use super::*;

    fn var(i: usize) -> LoweredOp {
        LoweredOp::Var(i)
    }
    fn c(v: f64) -> LoweredOp {
        LoweredOp::Const(v)
    }

    /// f(x) = (x − 5)², the shared objective for the quadratic tests.
    fn shifted_quadratic() -> LoweredOp {
        let inner = LoweredOp::Sub(Box::new(var(0)), Box::new(c(5.0)));
        LoweredOp::Mul(Box::new(inner.clone()), Box::new(inner))
    }

    // ── Test 1: f = (x−5)², direction = +1, from x=0 → α* = 5.0 ─────────────
    #[test]
    fn test_shifted_quadratic_step_from_origin() {
        let f = shifted_quadratic();

        let ls = SymbolicLineSearch::new(&f, &[0], &[c(1.0)]).expect("build");
        let alpha = ls.eval(&[0.0]).expect("eval");

        // g = 2*(x-5); at x=0: g=-10; dᵀHd = 1*2*1 = 2; α* = -(-10)/2 = 5.0
        assert!((alpha - 5.0).abs() < 1e-10, "expected 5.0, got {alpha}");

        // Verify that taking the step lands at the minimum by evaluating f
        // there (f(5) = 0), rather than re-checking the value of α*.
        let x_new = [0.0 + alpha * 1.0];
        let ctx = EvalCtx::new(&x_new[..]);
        let f_new = eval_real(&f, &ctx)
            .map_err(|e| e.to_string())
            .expect("evaluate f at x_new");
        assert!(
            f_new.abs() < 1e-12,
            "step should land at the minimum (f = 0), got f(x_new) = {f_new}"
        );
    }

    // ── Test 2: degenerate direction propagates Err ───────────────────────────
    #[test]
    fn test_degenerate_direction_propagates() {
        // f = x (linear — Hessian is zero everywhere)
        let f = var(0);
        let result = SymbolicLineSearch::new(&f, &[0], &[c(0.0)]);
        assert!(
            matches!(
                result,
                Err(OptLineSearchError::Symbolic(
                    SymLineSearchError::DegenerateDirection
                ))
            ),
            "expected DegenerateDirection, got: {result:?}"
        );
    }

    // ── Test 3: a point that is too short is rejected before evaluation ──────
    #[test]
    fn test_short_point_reports_dimension_mismatch() {
        let f = shifted_quadratic();
        let ls = SymbolicLineSearch::new(&f, &[0], &[c(1.0)]).expect("build");
        let result = ls.eval(&[]);
        assert!(
            matches!(
                result,
                Err(OptLineSearchError::DimensionMismatch {
                    expected: 1,
                    got: 0
                })
            ),
            "expected DimensionMismatch {{ expected: 1, got: 0 }}, got: {result:?}"
        );
    }
}