scirs2_optimize/least_squares/separable.rs

//! Separable least squares for partially linear problems
//!
//! This module implements the variable projection (VARPRO) algorithm for solving
//! separable nonlinear least squares problems in which the model has the form
//!
//! f(x, α, β) = Σᵢ αᵢ φᵢ(x, β)
//!
//! where α are the linear parameters and β are the nonlinear parameters.
//!
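//! For a fixed β, the optimal linear parameters have the closed form
//!
//! α*(β) = Φ(β)⁺ y
//!
//! where Φ(β) is the n×p matrix of basis functions evaluated at the data points
//! and Φ(β)⁺ is its pseudoinverse (computed here by solving a linear least
//! squares subproblem). Substituting α*(β) back in reduces the fit to a smaller
//! nonlinear problem in β alone,
//!
//! min_β ‖y − Φ(β) α*(β)‖²,
//!
//! which the solver below addresses with a gradient-based outer iteration.
//!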
//! # Example
//!
//! ```
//! use scirs2_core::ndarray::{array, Array1, Array2};
//! use scirs2_optimize::least_squares::separable::{separable_least_squares, SeparableOptions};
//!
//! // Model: y = α₁ * exp(-β * t) + α₂
//! // Linear parameters: α = [α₁, α₂]
//! // Nonlinear parameters: β = [β]
//!
//! // Basis functions that depend on nonlinear parameters
//! fn basis_functions(t: &[f64], beta: &[f64]) -> Array2<f64> {
//!     let n = t.len();
//!     let mut phi = Array2::zeros((n, 2));
//!
//!     for i in 0..n {
//!         phi[[i, 0]] = (-beta[0] * t[i]).exp(); // exp(-β*t)
//!         phi[[i, 1]] = 1.0;                     // constant term
//!     }
//!     phi
//! }
//!
//! // Jacobian of basis functions w.r.t. nonlinear parameters
//! fn basis_jacobian(t: &[f64], beta: &[f64]) -> Array2<f64> {
//!     let n = t.len();
//!     let mut dphi_dbeta = Array2::zeros((n * 2, 1)); // n*p x q
//!
//!     for i in 0..n {
//!         // d/dβ(exp(-β*t)) = -t * exp(-β*t)
//!         dphi_dbeta[[i, 0]] = -t[i] * (-beta[0] * t[i]).exp();
//!         // d/dβ(1) = 0
//!         dphi_dbeta[[n + i, 0]] = 0.0;
//!     }
//!     dphi_dbeta
//! }
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Data points
//! let t_data = array![0.0, 0.5, 1.0, 1.5, 2.0];
//! let y_data = array![2.0, 1.6, 1.3, 1.1, 1.0];
//!
//! // Initial guess for nonlinear parameters
//! let beta0 = array![0.5];
//!
//! let result = separable_least_squares(
//!     basis_functions,
//!     basis_jacobian,
//!     &t_data,
//!     &y_data,
//!     &beta0,
//!     None
//! )?;
//!
//! println!("Nonlinear params: {:?}", result.result.x);
//! println!("Linear params: {:?}", result.linear_params);
//! # Ok(())
//! # }
//! ```

use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1};
use std::f64;

/// Options for separable least squares
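///
/// # Example
///
/// A minimal sketch of overriding selected defaults (import path as in the
/// module-level example above):
///
/// ```
/// use scirs2_optimize::least_squares::separable::{LinearSolver, SeparableOptions};
///
/// let options = SeparableOptions {
///     max_iter: 200,
///     linear_solver: LinearSolver::SVD,
///     ..Default::default()
/// };
/// assert_eq!(options.max_iter, 200);
/// ```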
#[derive(Debug, Clone)]
pub struct SeparableOptions {
    /// Maximum number of iterations
    pub max_iter: usize,

    /// Tolerance for convergence on nonlinear parameters
    pub beta_tol: f64,

    /// Tolerance for convergence on function value
    pub ftol: f64,

    /// Tolerance for convergence on gradient
    pub gtol: f64,

    /// Method for solving the linear subproblem
    pub linear_solver: LinearSolver,

    /// Regularization parameter for linear solve (if needed)
    pub lambda: f64,
}

/// Methods for solving the linear least squares subproblem
///
/// Note: in the current implementation the `QR` and `SVD` variants fall back
/// to the regularized normal equations (see `qr_solve` and `svd_solve`).
#[derive(Debug, Clone, Copy)]
pub enum LinearSolver {
    /// QR decomposition (stable, recommended)
    QR,
    /// Normal equations (faster but less stable)
    NormalEquations,
    /// Singular value decomposition (most stable)
    SVD,
}

impl Default for SeparableOptions {
    fn default() -> Self {
        SeparableOptions {
            max_iter: 100,
            beta_tol: 1e-8,
            ftol: 1e-8,
            gtol: 1e-8,
            linear_solver: LinearSolver::QR,
            lambda: 0.0,
        }
    }
}

/// Result structure extended for separable least squares
#[derive(Debug, Clone)]
pub struct SeparableResult {
    /// Standard optimization results (nonlinear parameters)
    pub result: OptimizeResults<f64>,
    /// Optimal linear parameters
    pub linear_params: Array1<f64>,
}

/// Solve a separable nonlinear least squares problem
///
/// This function solves problems of the form:
/// minimize ||y - Σ αᵢ φᵢ(x, β)||²
///
/// where α are the linear parameters and β are the nonlinear parameters.
///
/// # Arguments
///
/// * `basis_functions` - Function that returns the basis matrix Φ(x, β)
/// * `basis_jacobian` - Function that returns ∂Φ/∂β as an (n·p) × q matrix,
///   with the derivative of basis function k at data point i w.r.t. βⱼ stored
///   at row `k * n + i`, column `j`
/// * `x_data` - Independent variable data
/// * `y_data` - Dependent variable data
/// * `beta0` - Initial guess for nonlinear parameters
/// * `options` - Options for the optimization
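///
/// # Returns
///
/// A `SeparableResult` whose `result.x` holds the fitted nonlinear parameters
/// β and whose `linear_params` holds the corresponding optimal linear
/// parameters α.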
#[allow(dead_code)]
pub fn separable_least_squares<F, J, S1, S2, S3>(
    basis_functions: F,
    basis_jacobian: J,
    x_data: &ArrayBase<S1, Ix1>,
    y_data: &ArrayBase<S2, Ix1>,
    beta0: &ArrayBase<S3, Ix1>,
    options: Option<SeparableOptions>,
) -> OptimizeResult<SeparableResult>
where
    F: Fn(&[f64], &[f64]) -> Array2<f64>,
    J: Fn(&[f64], &[f64]) -> Array2<f64>,
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
{
    let options = options.unwrap_or_default();
    let mut beta = beta0.to_owned();

    let n = y_data.len();
    if x_data.len() != n {
        return Err(crate::error::OptimizeError::ValueError(
            "x_data and y_data must have the same length".to_string(),
        ));
    }

    let mut iter = 0;
    let mut nfev = 0;
    let mut prev_cost = f64::INFINITY;

    // Main optimization loop
    while iter < options.max_iter {
        // Compute basis functions
        let phi = basis_functions(
            x_data.as_slice().expect("x_data must be contiguous"),
            beta.as_slice().expect("beta must be contiguous"),
        );
        nfev += 1;

        let (n_points, _n_basis) = phi.dim();
        if n_points != n {
            return Err(crate::error::OptimizeError::ValueError(
                "Basis functions returned wrong number of rows".to_string(),
            ));
        }

        // Solve linear least squares for α given current β
        let alpha = solve_linear_subproblem(&phi, y_data, &options)?;

        // Compute residual
        let y_pred = phi.dot(&alpha);
        let residual = y_data - &y_pred;
        let cost = 0.5 * residual.iter().map(|&r| r * r).sum::<f64>();

        // Check convergence on cost function
        if (prev_cost - cost).abs() < options.ftol * cost {
            let mut result = OptimizeResults::default();
            result.x = beta.clone();
            result.fun = cost;
            result.nfev = nfev;
            result.nit = iter;
            result.success = true;
            result.message = "Converged (function tolerance)".to_string();

            return Ok(SeparableResult {
                result,
                linear_params: alpha,
            });
        }

        // Compute gradient w.r.t. nonlinear parameters
        let gradient = compute_gradient(
            &phi,
            &alpha,
            &residual,
            x_data.as_slice().expect("x_data must be contiguous"),
            beta.as_slice().expect("beta must be contiguous"),
            &basis_jacobian,
        );

        // Check convergence on gradient
        if gradient.iter().all(|&g| g.abs() < options.gtol) {
            let mut result = OptimizeResults::default();
            result.x = beta.clone();
            result.fun = cost;
            result.nfev = nfev;
            result.nit = iter;
            result.success = true;
            result.message = "Converged (gradient tolerance)".to_string();

            return Ok(SeparableResult {
                result,
                linear_params: alpha,
            });
        }

        // Update nonlinear parameters using gradient descent
        // (Could be improved with more sophisticated methods)
        let step_size = backtracking_line_search(&beta, &gradient, cost, |b| {
            let phi_new = basis_functions(x_data.as_slice().expect("x_data must be contiguous"), b);
            let alpha_new = solve_linear_subproblem(&phi_new, y_data, &options)
                .expect("linear subproblem failed during line search");
            let y_pred_new = phi_new.dot(&alpha_new);
            let res_new = y_data - &y_pred_new;
            0.5 * res_new.iter().map(|&r| r * r).sum::<f64>()
        });
        nfev += 5; // Approximate function evaluations in line search

        beta = &beta - &gradient * step_size;

        // Check convergence on parameters
        if gradient.iter().map(|&g| g * g).sum::<f64>().sqrt() * step_size < options.beta_tol {
            let mut result = OptimizeResults::default();
            result.x = beta.clone();
            result.fun = cost;
            result.nfev = nfev;
            result.nit = iter;
            result.success = true;
            result.message = "Converged (parameter tolerance)".to_string();

            // Compute final linear parameters
            let phi_final = basis_functions(
                x_data.as_slice().expect("x_data must be contiguous"),
                beta.as_slice().expect("beta must be contiguous"),
            );
            let alpha_final = solve_linear_subproblem(&phi_final, y_data, &options)?;

            return Ok(SeparableResult {
                result,
                linear_params: alpha_final,
            });
        }

        prev_cost = cost;
        iter += 1;
    }

    // Maximum iterations reached
    let phi_final = basis_functions(
        x_data.as_slice().expect("x_data must be contiguous"),
        beta.as_slice().expect("beta must be contiguous"),
    );
    let alpha_final = solve_linear_subproblem(&phi_final, y_data, &options)?;
    let y_pred_final = phi_final.dot(&alpha_final);
    let res_final = y_data - &y_pred_final;
    let final_cost = 0.5 * res_final.iter().map(|&r| r * r).sum::<f64>();

    let mut result = OptimizeResults::default();
    result.x = beta;
    result.fun = final_cost;
    result.nfev = nfev;
    result.nit = iter;
    result.success = false;
    result.message = "Maximum iterations reached".to_string();

    Ok(SeparableResult {
        result,
        linear_params: alpha_final,
    })
}

/// Solve the linear least squares subproblem for α given the basis matrix Φ
#[allow(dead_code)]
fn solve_linear_subproblem<S1>(
    phi: &Array2<f64>,
    y: &ArrayBase<S1, Ix1>,
    options: &SeparableOptions,
) -> OptimizeResult<Array1<f64>>
where
    S1: Data<Elem = f64>,
{
    match options.linear_solver {
        LinearSolver::NormalEquations => {
            // Solve using normal equations: (Φ^T Φ) α = Φ^T y
            let phi_t_phi = phi.t().dot(phi);
            let phi_t_y = phi.t().dot(y);

            // Add Tikhonov (ridge) regularization if specified,
            // i.e. solve (Φ^T Φ + λI) α = Φ^T y instead
            let mut regularized = phi_t_phi.clone();
            if options.lambda > 0.0 {
                for i in 0..regularized.shape()[0] {
                    regularized[[i, i]] += options.lambda;
                }
            }

            solve_symmetric_system(&regularized, &phi_t_y)
        }
        LinearSolver::QR => {
            // QR decomposition (more stable)
            qr_solve(phi, y, options.lambda)
        }
        LinearSolver::SVD => {
            // SVD decomposition (most stable)
            svd_solve(phi, y, options.lambda)
        }
    }
}

/// Compute the gradient of the cost w.r.t. the nonlinear parameters β
#[allow(dead_code)]
fn compute_gradient<J>(
    _phi: &Array2<f64>,
    alpha: &Array1<f64>,
    residual: &Array1<f64>,
    x_data: &[f64],
    beta: &[f64],
    basis_jacobian: &J,
) -> Array1<f64>
where
    J: Fn(&[f64], &[f64]) -> Array2<f64>,
{
    let dphi_dbeta = basis_jacobian(x_data, beta);
    let (_n_total, q) = dphi_dbeta.dim();
    let n = residual.len();
    let p = alpha.len();

    // dphi_dbeta is stacked as an (n*p) x q matrix with ∂φₖ(xᵢ, β)/∂βⱼ stored
    // at row k*n + i, column j. For the cost 0.5‖r‖² with r = y - Φα, the
    // gradient component is ∂C/∂βⱼ = -Σᵢ rᵢ Σₖ αₖ ∂φₖ(xᵢ, β)/∂βⱼ. Because r is
    // orthogonal to the columns of Φ at the inner optimum (for λ = 0), the
    // dependence of α on β contributes no additional terms.
    let mut gradient = Array1::zeros(q);

    for j in 0..q {
        let mut grad_j = 0.0;
        for i in 0..n {
            for k in 0..p {
                let idx = k * n + i;
                grad_j -= residual[i] * alpha[k] * dphi_dbeta[[idx, j]];
            }
        }
        gradient[j] = grad_j;
    }

    gradient
}

/// Simple backtracking line search
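///
/// The `direction` argument is treated as the (positive) gradient: trial
/// points are `x - alpha * direction`, and `alpha` is accepted once the Armijo
/// sufficient-decrease condition
///
/// f(x - alpha * direction) <= f0 - c * alpha * ‖direction‖²
///
/// holds, halving `alpha` otherwise. If the condition is not met within 20
/// halvings, the smallest step tried is returned.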
#[allow(dead_code)]
fn backtracking_line_search<F>(x: &Array1<f64>, direction: &Array1<f64>, f0: f64, f: F) -> f64
where
    F: Fn(&[f64]) -> f64,
{
    let mut alpha = 1.0;
    let c = 0.5; // Armijo sufficient-decrease constant
    let rho = 0.5; // backtracking (step-shrink) factor

    // With direction = gradient, ‖direction‖² is the magnitude of the
    // directional derivative along the descent direction -direction.
    let grad_dot_dir = direction.iter().map(|&d| d * d).sum::<f64>();

    for _ in 0..20 {
        let x_new = x - alpha * direction;
        let f_new = f(x_new.as_slice().expect("trial point must be contiguous"));

        if f_new <= f0 - c * alpha * grad_dot_dir {
            return alpha;
        }

        alpha *= rho;
    }

    // Sufficient decrease not achieved; return the last (smallest) step tried
    alpha
}

/// Solve a symmetric positive definite linear system
#[allow(dead_code)]
fn solve_symmetric_system(a: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
    // A Cholesky factorization (with an LU fallback) would be the natural
    // choice for the symmetric positive definite normal equations; for now a
    // simple Gaussian elimination without pivoting is used.
    let n = a.shape()[0];
    let mut aug = Array2::zeros((n, n + 1));

    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }

    // Forward elimination
    for i in 0..n {
        let pivot = aug[[i, i]];
        if pivot.abs() < 1e-10 {
            return Err(crate::error::OptimizeError::ValueError(
                "Singular matrix in linear solve".to_string(),
            ));
        }

        for j in i + 1..n {
            let factor = aug[[j, i]] / pivot;
            for k in i..=n {
                aug[[j, k]] -= factor * aug[[i, k]];
            }
        }
    }

    // Back substitution
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in i + 1..n {
            sum -= aug[[i, j]] * x[j];
        }
        x[i] = sum / aug[[i, i]];
    }

    Ok(x)
}

/// QR solve (simplified)
#[allow(dead_code)]
fn qr_solve<S>(phi: &Array2<f64>, y: &ArrayBase<S, Ix1>, lambda: f64) -> OptimizeResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    // For simplicity, use normal equations with regularization
    // A proper implementation would use actual QR decomposition
    let phi_t_phi = phi.t().dot(phi);
    let phi_t_y = phi.t().dot(y);

    let mut regularized = phi_t_phi.clone();
    for i in 0..regularized.shape()[0] {
        regularized[[i, i]] += lambda;
    }

    solve_symmetric_system(&regularized, &phi_t_y)
}

/// SVD solve (simplified)
#[allow(dead_code)]
fn svd_solve<S>(
    phi: &Array2<f64>,
    y: &ArrayBase<S, Ix1>,
    lambda: f64,
) -> OptimizeResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    // For simplicity, use normal equations with regularization
    // A proper implementation would use actual SVD
    qr_solve(phi, y, lambda)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_separable_exponential() {
        // Model: y = α₁ * exp(-β * t) + α₂
        // True parameters: α₁ = 2.0, α₂ = 0.5, β = 0.7

        fn basis_functions(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut phi = Array2::zeros((n, 2));

            for i in 0..n {
                phi[[i, 0]] = (-beta[0] * t[i]).exp();
                phi[[i, 1]] = 1.0;
            }
            phi
        }

        fn basis_jacobian(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut dphi_dbeta = Array2::zeros((n * 2, 1));

            for i in 0..n {
                dphi_dbeta[[i, 0]] = -t[i] * (-beta[0] * t[i]).exp();
                dphi_dbeta[[n + i, 0]] = 0.0;
            }
            dphi_dbeta
        }

        // Generate synthetic data
        let t_data = array![0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        let true_alpha = array![2.0, 0.5];
        let true_beta = array![0.7];

        let phi_true = basis_functions(
            t_data.as_slice().expect("Operation failed"),
            true_beta.as_slice().expect("Operation failed"),
        );
        let y_data =
            phi_true.dot(&true_alpha) + 0.01 * array![0.1, -0.05, 0.08, -0.03, 0.06, -0.04, 0.02];

        // Initial guess
        let beta0 = array![0.5];

        let result = separable_least_squares(
            basis_functions,
            basis_jacobian,
            &t_data,
            &y_data,
            &beta0,
            None,
        )
        .expect("Operation failed");

        assert!(result.result.success);
        assert!((result.result.x[0] - true_beta[0]).abs() < 0.1);
        assert!((result.linear_params[0] - true_alpha[0]).abs() < 0.1);
        assert!((result.linear_params[1] - true_alpha[1]).abs() < 0.1);
    }

    #[test]
    fn test_separable_multi_exponential() {
        // Model: y = α₁ * exp(-β₁ * t) + α₂ * exp(-β₂ * t)
        // More complex with two nonlinear parameters

        fn basis_functions(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut phi = Array2::zeros((n, 2));

            for i in 0..n {
                phi[[i, 0]] = (-beta[0] * t[i]).exp();
                phi[[i, 1]] = (-beta[1] * t[i]).exp();
            }
            phi
        }

        fn basis_jacobian(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut dphi_dbeta = Array2::zeros((n * 2, 2));

            for i in 0..n {
                dphi_dbeta[[i, 0]] = -t[i] * (-beta[0] * t[i]).exp();
                dphi_dbeta[[i, 1]] = 0.0;
                dphi_dbeta[[n + i, 0]] = 0.0;
                dphi_dbeta[[n + i, 1]] = -t[i] * (-beta[1] * t[i]).exp();
            }
            dphi_dbeta
        }

        // Generate synthetic data
        let t_data = array![0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4];
        let true_alpha = array![3.0, 1.5];
        let true_beta = array![2.0, 0.5];

        let phi_true = basis_functions(
            t_data.as_slice().expect("Operation failed"),
            true_beta.as_slice().expect("Operation failed"),
        );
        let y_data = phi_true.dot(&true_alpha);

        // Initial guess
        let beta0 = array![1.5, 0.8];

        let mut options = SeparableOptions::default();
        options.max_iter = 200; // More iterations for harder problem
        options.beta_tol = 1e-6;

        let result = separable_least_squares(
            basis_functions,
            basis_jacobian,
            &t_data,
            &y_data,
            &beta0,
            Some(options),
        )
        .expect("Operation failed");

        // For multi-exponential problems, convergence is harder
        // Just check that we made good progress
        assert!(result.result.fun < 0.1, "Cost = {}", result.result.fun);

        // Print results for debugging
        println!("Multi-exponential results:");
        println!("Beta: {:?} (true: {:?})", result.result.x, true_beta);
        println!("Alpha: {:?} (true: {:?})", result.linear_params, true_alpha);
        println!("Cost: {}", result.result.fun);
        println!("Success: {}", result.result.success);
    }
}