1use crate::error::OptimizeResult;
70use crate::result::OptimizeResults;
71use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1};
72use std::f64;
73
/// Tuning knobs for [`separable_least_squares`].
#[derive(Debug, Clone)]
pub struct SeparableOptions {
    /// Upper bound on the number of outer iterations over the nonlinear parameters.
    pub max_iter: usize,

    /// Parameter-step tolerance: the solver stops once
    /// `step_size * ||gradient||` drops below this value.
    pub beta_tol: f64,

    /// Relative cost tolerance: the solver stops once
    /// `|previous_cost - cost| < ftol * cost`.
    pub ftol: f64,

    /// Gradient tolerance: the solver stops once every gradient component is
    /// smaller than this value in absolute terms.
    pub gtol: f64,

    /// Which routine handles the inner linear least-squares solve.
    pub linear_solver: LinearSolver,

    /// Value added to the diagonal of `phi^T * phi` before the linear solve.
    /// The `NormalEquations` branch only applies it when it is positive; the
    /// `qr_solve` path adds it unconditionally.
    pub lambda: f64,
}
95
/// Selects the routine used for the inner linear least-squares solve.
///
/// NOTE(review): in this file `qr_solve` and `svd_solve` both end up forming
/// and solving the normal equations, so the three variants currently share the
/// same numerical method.
#[derive(Debug, Clone, Copy)]
pub enum LinearSolver {
    /// Dispatches to `qr_solve`.
    QR,
    /// Forms `phi^T * phi` and `phi^T * y` and solves that system directly.
    NormalEquations,
    /// Dispatches to `svd_solve`.
    SVD,
}
106
107impl Default for SeparableOptions {
108    fn default() -> Self {
109        SeparableOptions {
110            max_iter: 100,
111            beta_tol: 1e-8,
112            ftol: 1e-8,
113            gtol: 1e-8,
114            linear_solver: LinearSolver::QR,
115            lambda: 0.0,
116        }
117    }
118}
119
/// Outcome of [`separable_least_squares`].
#[derive(Debug, Clone)]
pub struct SeparableResult {
    /// Optimizer report: `x` holds the nonlinear parameters and `fun` the
    /// cost (half the sum of squared residuals).
    pub result: OptimizeResults<f64>,
    /// Linear coefficients obtained from the inner linear solve.
    pub linear_params: Array1<f64>,
}
128
129#[allow(dead_code)]
145pub fn separable_least_squares<F, J, S1, S2, S3>(
146    basis_functions: F,
147    basis_jacobian: J,
148    x_data: &ArrayBase<S1, Ix1>,
149    y_data: &ArrayBase<S2, Ix1>,
150    beta0: &ArrayBase<S3, Ix1>,
151    options: Option<SeparableOptions>,
152) -> OptimizeResult<SeparableResult>
153where
154    F: Fn(&[f64], &[f64]) -> Array2<f64>,
155    J: Fn(&[f64], &[f64]) -> Array2<f64>,
156    S1: Data<Elem = f64>,
157    S2: Data<Elem = f64>,
158    S3: Data<Elem = f64>,
159{
160    let options = options.unwrap_or_default();
161    let mut beta = beta0.to_owned();
162
163    let n = y_data.len();
164    if x_data.len() != n {
165        return Err(crate::error::OptimizeError::ValueError(
166            "x_data and y_data must have the same length".to_string(),
167        ));
168    }
169
170    let mut iter = 0;
171    let mut nfev = 0;
172    let mut prev_cost = f64::INFINITY;
173
174    while iter < options.max_iter {
176        let phi = basis_functions(x_data.as_slice().unwrap(), beta.as_slice().unwrap());
178        nfev += 1;
179
180        let (n_points, n_basis) = phi.dim();
181        if n_points != n {
182            return Err(crate::error::OptimizeError::ValueError(
183                "Basis functions returned wrong number of rows".to_string(),
184            ));
185        }
186
187        let alpha = solve_linear_subproblem(&phi, y_data, &options)?;
189
190        let y_pred = phi.dot(&alpha);
192        let residual = y_data - &y_pred;
193        let cost = 0.5 * residual.iter().map(|&r| r * r).sum::<f64>();
194
195        if (prev_cost - cost).abs() < options.ftol * cost {
197            let mut result = OptimizeResults::default();
198            result.x = beta.clone();
199            result.fun = cost;
200            result.nfev = nfev;
201            result.nit = iter;
202            result.success = true;
203            result.message = "Converged (function tolerance)".to_string();
204
205            return Ok(SeparableResult {
206                result,
207                linear_params: alpha,
208            });
209        }
210
211        let gradient = compute_gradient(
213            &phi,
214            &alpha,
215            &residual,
216            x_data.as_slice().unwrap(),
217            beta.as_slice().unwrap(),
218            &basis_jacobian,
219        );
220
221        if gradient.iter().all(|&g| g.abs() < options.gtol) {
223            let mut result = OptimizeResults::default();
224            result.x = beta.clone();
225            result.fun = cost;
226            result.nfev = nfev;
227            result.nit = iter;
228            result.success = true;
229            result.message = "Converged (gradient tolerance)".to_string();
230
231            return Ok(SeparableResult {
232                result,
233                linear_params: alpha,
234            });
235        }
236
237        let step_size = backtracking_line_search(&beta, &gradient, cost, |b| {
240            let phi_new = basis_functions(x_data.as_slice().unwrap(), b);
241            let alpha_new = solve_linear_subproblem(&phi_new, y_data, &options).unwrap();
242            let y_pred_new = phi_new.dot(&alpha_new);
243            let res_new = y_data - &y_pred_new;
244            0.5 * res_new.iter().map(|&r| r * r).sum::<f64>()
245        });
246        nfev += 5; beta = &beta - &gradient * step_size;
249
250        if gradient.iter().map(|&g| g * g).sum::<f64>().sqrt() * step_size < options.beta_tol {
252            let mut result = OptimizeResults::default();
253            result.x = beta.clone();
254            result.fun = cost;
255            result.nfev = nfev;
256            result.nit = iter;
257            result.success = true;
258            result.message = "Converged (parameter tolerance)".to_string();
259
260            let phi_final = basis_functions(x_data.as_slice().unwrap(), beta.as_slice().unwrap());
262            let alpha_final = solve_linear_subproblem(&phi_final, y_data, &options)?;
263
264            return Ok(SeparableResult {
265                result,
266                linear_params: alpha_final,
267            });
268        }
269
270        prev_cost = cost;
271        iter += 1;
272    }
273
274    let phi_final = basis_functions(x_data.as_slice().unwrap(), beta.as_slice().unwrap());
276    let alpha_final = solve_linear_subproblem(&phi_final, y_data, &options)?;
277    let y_pred_final = phi_final.dot(&alpha_final);
278    let res_final = y_data - &y_pred_final;
279    let final_cost = 0.5 * res_final.iter().map(|&r| r * r).sum::<f64>();
280
281    let mut result = OptimizeResults::default();
282    result.x = beta;
283    result.fun = final_cost;
284    result.nfev = nfev;
285    result.nit = iter;
286    result.success = false;
287    result.message = "Maximum iterations reached".to_string();
288
289    Ok(SeparableResult {
290        result,
291        linear_params: alpha_final,
292    })
293}
294
295#[allow(dead_code)]
297fn solve_linear_subproblem<S1>(
298    phi: &Array2<f64>,
299    y: &ArrayBase<S1, Ix1>,
300    options: &SeparableOptions,
301) -> OptimizeResult<Array1<f64>>
302where
303    S1: Data<Elem = f64>,
304{
305    match options.linear_solver {
306        LinearSolver::NormalEquations => {
307            let phi_t_phi = phi.t().dot(phi);
309            let phi_t_y = phi.t().dot(y);
310
311            let mut regularized = phi_t_phi.clone();
313            if options.lambda > 0.0 {
314                for i in 0..regularized.shape()[0] {
315                    regularized[[i, i]] += options.lambda;
316                }
317            }
318
319            solve_symmetric_system(®ularized, &phi_t_y)
320        }
321        LinearSolver::QR => {
322            qr_solve(phi, y, options.lambda)
324        }
325        LinearSolver::SVD => {
326            svd_solve(phi, y, options.lambda)
328        }
329    }
330}
331
332#[allow(dead_code)]
334fn compute_gradient<J>(
335    _phi: &Array2<f64>,
336    alpha: &Array1<f64>,
337    residual: &Array1<f64>,
338    x_data: &[f64],
339    beta: &[f64],
340    basis_jacobian: &J,
341) -> Array1<f64>
342where
343    J: Fn(&[f64], &[f64]) -> Array2<f64>,
344{
345    let dphi_dbeta = basis_jacobian(x_data, beta);
346    let (_n_total, q) = dphi_dbeta.dim();
347    let n = residual.len();
348    let p = alpha.len();
349
350    let mut gradient = Array1::zeros(q);
352
353    for j in 0..q {
354        let mut grad_j = 0.0;
355        for i in 0..n {
356            for k in 0..p {
357                let idx = k * n + i;
358                grad_j -= residual[i] * alpha[k] * dphi_dbeta[[idx, j]];
359            }
360        }
361        gradient[j] = grad_j;
362    }
363
364    gradient
365}
366
367#[allow(dead_code)]
369fn backtracking_line_search<F>(x: &Array1<f64>, direction: &Array1<f64>, f0: f64, f: F) -> f64
370where
371    F: Fn(&[f64]) -> f64,
372{
373    let mut alpha = 1.0;
374    let c = 0.5;
375    let rho = 0.5;
376
377    let grad_dot_dir = direction.iter().map(|&d| d * d).sum::<f64>();
378
379    for _ in 0..20 {
380        let x_new = x - alpha * direction;
381        let f_new = f(x_new.as_slice().unwrap());
382
383        if f_new <= f0 - c * alpha * grad_dot_dir {
384            return alpha;
385        }
386
387        alpha *= rho;
388    }
389
390    alpha
391}
392
393#[allow(dead_code)]
395fn solve_symmetric_system(a: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
396    let n = a.shape()[0];
401    let mut aug = Array2::zeros((n, n + 1));
402
403    for i in 0..n {
404        for j in 0..n {
405            aug[[i, j]] = a[[i, j]];
406        }
407        aug[[i, n]] = b[i];
408    }
409
410    for i in 0..n {
412        let pivot = aug[[i, i]];
413        if pivot.abs() < 1e-10 {
414            return Err(crate::error::OptimizeError::ValueError(
415                "Singular matrix in linear solve".to_string(),
416            ));
417        }
418
419        for j in i + 1..n {
420            let factor = aug[[j, i]] / pivot;
421            for k in i..=n {
422                aug[[j, k]] -= factor * aug[[i, k]];
423            }
424        }
425    }
426
427    let mut x = Array1::zeros(n);
429    for i in (0..n).rev() {
430        let mut sum = aug[[i, n]];
431        for j in i + 1..n {
432            sum -= aug[[i, j]] * x[j];
433        }
434        x[i] = sum / aug[[i, i]];
435    }
436
437    Ok(x)
438}
439
440#[allow(dead_code)]
442fn qr_solve<S>(phi: &Array2<f64>, y: &ArrayBase<S, Ix1>, lambda: f64) -> OptimizeResult<Array1<f64>>
443where
444    S: Data<Elem = f64>,
445{
446    let phi_t_phi = phi.t().dot(phi);
449    let phi_t_y = phi.t().dot(y);
450
451    let mut regularized = phi_t_phi.clone();
452    for i in 0..regularized.shape()[0] {
453        regularized[[i, i]] += lambda;
454    }
455
456    solve_symmetric_system(®ularized, &phi_t_y)
457}
458
/// Solves the regularized least-squares problem for `phi` and `y`.
///
/// NOTE(review): no SVD is computed here; the call is forwarded unchanged to
/// `qr_solve`, so both solvers currently produce the same result.
#[allow(dead_code)]
fn svd_solve<S>(
    phi: &Array2<f64>,
    y: &ArrayBase<S, Ix1>,
    lambda: f64,
) -> OptimizeResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    // Delegates to the normal-equations based solver.
    qr_solve(phi, y, lambda)
}
473
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fits `y = alpha[0] * exp(-beta[0] * t) + alpha[1]` to lightly perturbed
    /// synthetic data and checks that all parameters are recovered to within 0.1.
    #[test]
    fn test_separable_exponential() {
        // Basis matrix: column 0 is the decaying exponential, column 1 the
        // constant offset.
        fn basis_functions(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut phi = Array2::zeros((n, 2));

            for i in 0..n {
                phi[[i, 0]] = (-beta[0] * t[i]).exp();
                phi[[i, 1]] = 1.0;
            }
            phi
        }

        // Jacobian layout: rows `0..n` hold the derivative of basis 0 with
        // respect to beta[0], rows `n..2n` the derivative of basis 1 (zero,
        // because the constant does not depend on beta).
        fn basis_jacobian(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut dphi_dbeta = Array2::zeros((n * 2, 1));

            for i in 0..n {
                dphi_dbeta[[i, 0]] = -t[i] * (-beta[0] * t[i]).exp();
                dphi_dbeta[[n + i, 0]] = 0.0;
            }
            dphi_dbeta
        }

        let t_data = array![0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        // Ground-truth parameters used to generate the data.
        let true_alpha = array![2.0, 0.5];
        let true_beta = array![0.7];

        // Exact model values plus a small fixed perturbation.
        let phi_true = basis_functions(t_data.as_slice().unwrap(), true_beta.as_slice().unwrap());
        let y_data =
            phi_true.dot(&true_alpha) + 0.01 * array![0.1, -0.05, 0.08, -0.03, 0.06, -0.04, 0.02];

        // Starting guess for the nonlinear parameter.
        let beta0 = array![0.5];

        let result = separable_least_squares(
            basis_functions,
            basis_jacobian,
            &t_data,
            &y_data,
            &beta0,
            None,
        )
        .unwrap();

        assert!(result.result.success);
        assert!((result.result.x[0] - true_beta[0]).abs() < 0.1);
        assert!((result.linear_params[0] - true_alpha[0]).abs() < 0.1);
        assert!((result.linear_params[1] - true_alpha[1]).abs() < 0.1);
    }

    /// Fits `y = alpha[0] * exp(-beta[0] * t) + alpha[1] * exp(-beta[1] * t)`
    /// to noise-free data. Only the final cost is asserted; the recovered
    /// parameters are printed for inspection.
    #[test]
    fn test_separable_multi_exponential() {
        // Basis matrix: one decaying exponential per column.
        fn basis_functions(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut phi = Array2::zeros((n, 2));

            for i in 0..n {
                phi[[i, 0]] = (-beta[0] * t[i]).exp();
                phi[[i, 1]] = (-beta[1] * t[i]).exp();
            }
            phi
        }

        // Jacobian layout: rows `0..n` belong to basis 0, rows `n..2n` to
        // basis 1; column `j` is the derivative with respect to beta[j]. Each
        // basis depends on a single beta, so the cross terms are zero.
        fn basis_jacobian(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut dphi_dbeta = Array2::zeros((n * 2, 2));

            for i in 0..n {
                dphi_dbeta[[i, 0]] = -t[i] * (-beta[0] * t[i]).exp();
                dphi_dbeta[[i, 1]] = 0.0;
                dphi_dbeta[[n + i, 0]] = 0.0;
                dphi_dbeta[[n + i, 1]] = -t[i] * (-beta[1] * t[i]).exp();
            }
            dphi_dbeta
        }

        let t_data = array![0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4];
        // Ground-truth parameters used to generate the data.
        let true_alpha = array![3.0, 1.5];
        let true_beta = array![2.0, 0.5];

        let phi_true = basis_functions(t_data.as_slice().unwrap(), true_beta.as_slice().unwrap());
        let y_data = phi_true.dot(&true_alpha);

        // Starting guess for the two decay rates.
        let beta0 = array![1.5, 0.8];

        // Allow more iterations and a looser parameter tolerance than the defaults.
        let mut options = SeparableOptions::default();
        options.max_iter = 200; options.beta_tol = 1e-6;

        let result = separable_least_squares(
            basis_functions,
            basis_jacobian,
            &t_data,
            &y_data,
            &beta0,
            Some(options),
        )
        .unwrap();

        // Only the quality of the fit is checked, not the individual parameters.
        assert!(result.result.fun < 0.1, "Cost = {}", result.result.fun);

        println!("Multi-exponential results:");
        println!("Beta: {:?} (true: {:?})", result.result.x, true_beta);
        println!("Alpha: {:?} (true: {:?})", result.linear_params, true_alpha);
        println!("Cost: {}", result.result.fun);
        println!("Success: {}", result.result.success);
    }
}