use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1};
use std::f64;

/// Options for the separable least squares solver.
#[derive(Debug, Clone)]
pub struct SeparableOptions {
    /// Maximum number of outer iterations over the nonlinear parameters.
    pub max_iter: usize,

    /// Convergence tolerance on the size of the nonlinear parameter update.
    pub beta_tol: f64,

    /// Convergence tolerance on the relative decrease in the cost function.
    pub ftol: f64,

    /// Convergence tolerance on the gradient with respect to the nonlinear parameters.
    pub gtol: f64,

    /// Method used to solve the linear subproblem for the linear parameters.
    pub linear_solver: LinearSolver,

    /// Regularization parameter added to the diagonal of the linear subproblem.
    pub lambda: f64,
}

/// Method for solving the linear least squares subproblem.
#[derive(Debug, Clone, Copy)]
pub enum LinearSolver {
    /// QR factorization.
    QR,
    /// Normal equations.
    NormalEquations,
    /// Singular value decomposition.
    SVD,
}

impl Default for SeparableOptions {
    fn default() -> Self {
        SeparableOptions {
            max_iter: 100,
            beta_tol: 1e-8,
            ftol: 1e-8,
            gtol: 1e-8,
            linear_solver: LinearSolver::QR,
            lambda: 0.0,
        }
    }
}

/// Result of a separable least squares fit.
#[derive(Debug, Clone)]
pub struct SeparableResult {
    /// Optimization results for the nonlinear parameters `beta`.
    pub result: OptimizeResults<f64>,
    /// Linear parameters `alpha` from the final linear solve.
    pub linear_params: Array1<f64>,
}

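/// Solves a separable nonlinear least squares problem.
///
/// The model `y = Phi(x, beta) * alpha` is linear in the parameters `alpha`
/// and nonlinear in the parameters `beta`. At every outer iteration the linear
/// parameters are recovered by a linear least squares solve, and the nonlinear
/// parameters are updated by gradient descent with a backtracking line search.
///
/// `basis_functions(x, beta)` must return the `n x p` basis matrix `Phi`, and
/// `basis_jacobian(x, beta)` must return an `(n * p) x q` matrix whose row
/// `k * n + i` holds the derivatives of `Phi[i, k]` with respect to the `q`
/// nonlinear parameters.
///
/// A minimal usage sketch, fitting `y = alpha[0] * exp(-beta[0] * t) + alpha[1]`
/// (the data values below are illustrative only; see the tests at the end of
/// this file for complete examples):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array2};
///
/// // Basis matrix with columns exp(-beta[0] * t) and 1.
/// let basis = |t: &[f64], beta: &[f64]| {
///     let mut phi = Array2::zeros((t.len(), 2));
///     for (i, &ti) in t.iter().enumerate() {
///         phi[[i, 0]] = (-beta[0] * ti).exp();
///         phi[[i, 1]] = 1.0;
///     }
///     phi
/// };
/// // Stacked Jacobian: row i holds d(phi[i, 0]) / d(beta[0]); rows n..2n are zero.
/// let jac = |t: &[f64], beta: &[f64]| {
///     let n = t.len();
///     let mut d = Array2::zeros((2 * n, 1));
///     for (i, &ti) in t.iter().enumerate() {
///         d[[i, 0]] = -ti * (-beta[0] * ti).exp();
///     }
///     d
/// };
///
/// let t = array![0.0, 0.5, 1.0, 1.5, 2.0];
/// let y = array![2.5, 1.9, 1.5, 1.2, 1.0];
/// let fit = separable_least_squares(basis, jac, &t, &y, &array![0.5], None).unwrap();
/// println!("beta = {:?}, alpha = {:?}", fit.result.x, fit.linear_params);
/// ```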
#[allow(dead_code)]
pub fn separable_least_squares<F, J, S1, S2, S3>(
    basis_functions: F,
    basis_jacobian: J,
    x_data: &ArrayBase<S1, Ix1>,
    y_data: &ArrayBase<S2, Ix1>,
    beta0: &ArrayBase<S3, Ix1>,
    options: Option<SeparableOptions>,
) -> OptimizeResult<SeparableResult>
where
    F: Fn(&[f64], &[f64]) -> Array2<f64>,
    J: Fn(&[f64], &[f64]) -> Array2<f64>,
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
{
    let options = options.unwrap_or_default();
    let mut beta = beta0.to_owned();

    let n = y_data.len();
    if x_data.len() != n {
        return Err(crate::error::OptimizeError::ValueError(
            "x_data and y_data must have the same length".to_string(),
        ));
    }

    let mut iter = 0;
    let mut nfev = 0;
    let mut prev_cost = f64::INFINITY;

    while iter < options.max_iter {
        // Evaluate the basis matrix for the current nonlinear parameters.
        let phi = basis_functions(x_data.as_slice().unwrap(), beta.as_slice().unwrap());
        nfev += 1;

        let (n_points, n_basis) = phi.dim();
        if n_points != n {
            return Err(crate::error::OptimizeError::ValueError(
                "Basis functions returned wrong number of rows".to_string(),
            ));
        }

        // Solve the linear subproblem for the linear parameters.
        let alpha = solve_linear_subproblem(&phi, y_data, &options)?;

        // Residual and cost for the current parameters.
        let y_pred = phi.dot(&alpha);
        let residual = y_data - &y_pred;
        let cost = 0.5 * residual.iter().map(|&r| r * r).sum::<f64>();

        // Convergence check on the relative decrease in the cost.
        if (prev_cost - cost).abs() < options.ftol * cost {
            let mut result = OptimizeResults::default();
            result.x = beta.clone();
            result.fun = cost;
            result.nfev = nfev;
            result.nit = iter;
            result.success = true;
            result.message = "Converged (function tolerance)".to_string();

            return Ok(SeparableResult {
                result,
                linear_params: alpha,
            });
        }

        // Gradient of the cost with respect to the nonlinear parameters.
        let gradient = compute_gradient(
            &phi,
            &alpha,
            &residual,
            x_data.as_slice().unwrap(),
            beta.as_slice().unwrap(),
            &basis_jacobian,
        );

        // Convergence check on the gradient.
        if gradient.iter().all(|&g| g.abs() < options.gtol) {
            let mut result = OptimizeResults::default();
            result.x = beta.clone();
            result.fun = cost;
            result.nfev = nfev;
            result.nit = iter;
            result.success = true;
            result.message = "Converged (gradient tolerance)".to_string();

            return Ok(SeparableResult {
                result,
                linear_params: alpha,
            });
        }

        // Backtracking line search along the negative gradient; the closure
        // re-solves the linear subproblem for each trial beta.
        let step_size = backtracking_line_search(&beta, &gradient, cost, |b| {
            let phi_new = basis_functions(x_data.as_slice().unwrap(), b);
            let alpha_new = solve_linear_subproblem(&phi_new, y_data, &options).unwrap();
            let y_pred_new = phi_new.dot(&alpha_new);
            let res_new = y_data - &y_pred_new;
            0.5 * res_new.iter().map(|&r| r * r).sum::<f64>()
        });
        // Rough accounting for the function evaluations used by the line search.
        nfev += 5;

        // Gradient descent step on the nonlinear parameters.
        beta = &beta - &gradient * step_size;

        // Convergence check on the size of the parameter update.
        if gradient.iter().map(|&g| g * g).sum::<f64>().sqrt() * step_size < options.beta_tol {
            let mut result = OptimizeResults::default();
            result.x = beta.clone();
            result.fun = cost;
            result.nfev = nfev;
            result.nit = iter;
            result.success = true;
            result.message = "Converged (parameter tolerance)".to_string();

            // Recompute the linear parameters for the updated beta.
            let phi_final = basis_functions(x_data.as_slice().unwrap(), beta.as_slice().unwrap());
            let alpha_final = solve_linear_subproblem(&phi_final, y_data, &options)?;

            return Ok(SeparableResult {
                result,
                linear_params: alpha_final,
            });
        }

        prev_cost = cost;
        iter += 1;
    }

    // Maximum iterations reached: report the most recent parameters.
    let phi_final = basis_functions(x_data.as_slice().unwrap(), beta.as_slice().unwrap());
    let alpha_final = solve_linear_subproblem(&phi_final, y_data, &options)?;
    let y_pred_final = phi_final.dot(&alpha_final);
    let res_final = y_data - &y_pred_final;
    let final_cost = 0.5 * res_final.iter().map(|&r| r * r).sum::<f64>();

    let mut result = OptimizeResults::default();
    result.x = beta;
    result.fun = final_cost;
    result.nfev = nfev;
    result.nit = iter;
    result.success = false;
    result.message = "Maximum iterations reached".to_string();

    Ok(SeparableResult {
        result,
        linear_params: alpha_final,
    })
}

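/// Solves the linear least squares subproblem for the linear parameters given
/// the current basis matrix `phi`, dispatching on `options.linear_solver` and
/// applying `options.lambda` as a diagonal (ridge) regularization term.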
#[allow(dead_code)]
fn solve_linear_subproblem<S1>(
    phi: &Array2<f64>,
    y: &ArrayBase<S1, Ix1>,
    options: &SeparableOptions,
) -> OptimizeResult<Array1<f64>>
where
    S1: Data<Elem = f64>,
{
    match options.linear_solver {
        LinearSolver::NormalEquations => {
            // Form the normal equations Phi^T Phi alpha = Phi^T y.
            let phi_t_phi = phi.t().dot(phi);
            let phi_t_y = phi.t().dot(y);

            // Optional regularization on the diagonal.
            let mut regularized = phi_t_phi.clone();
            if options.lambda > 0.0 {
                for i in 0..regularized.shape()[0] {
                    regularized[[i, i]] += options.lambda;
                }
            }

            solve_symmetric_system(&regularized, &phi_t_y)
        }
        LinearSolver::QR => qr_solve(phi, y, options.lambda),
        LinearSolver::SVD => svd_solve(phi, y, options.lambda),
    }
}

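/// Computes the gradient of the cost `0.5 * ||y - Phi(beta) * alpha||^2` with
/// respect to the nonlinear parameters `beta`, treating the linear parameters
/// `alpha` as fixed at their current values.
///
/// `basis_jacobian` is expected to return an `(n * p) x q` matrix in which row
/// `k * n + i` holds the derivative of `Phi[i, k]` with respect to each of the
/// `q` nonlinear parameters.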
#[allow(dead_code)]
fn compute_gradient<J>(
    _phi: &Array2<f64>,
    alpha: &Array1<f64>,
    residual: &Array1<f64>,
    x_data: &[f64],
    beta: &[f64],
    basis_jacobian: &J,
) -> Array1<f64>
where
    J: Fn(&[f64], &[f64]) -> Array2<f64>,
{
    let dphi_dbeta = basis_jacobian(x_data, beta);
    let (_n_total, q) = dphi_dbeta.dim();
    let n = residual.len();
    let p = alpha.len();

    let mut gradient = Array1::zeros(q);

    for j in 0..q {
        let mut grad_j = 0.0;
        for i in 0..n {
            for k in 0..p {
                // Row k * n + i of the stacked Jacobian holds d(Phi[i, k]) / d(beta).
                let idx = k * n + i;
                grad_j -= residual[i] * alpha[k] * dphi_dbeta[[idx, j]];
            }
        }
        gradient[j] = grad_j;
    }

    gradient
}

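/// Simple backtracking (Armijo) line search along the negative `direction`.
///
/// Starts from a unit step and halves it until the sufficient decrease
/// condition `f(x - alpha * d) <= f0 - c * alpha * ||d||^2` holds, giving up
/// after 20 trials.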
#[allow(dead_code)]
fn backtracking_line_search<F>(x: &Array1<f64>, direction: &Array1<f64>, f0: f64, f: F) -> f64
where
    F: Fn(&[f64]) -> f64,
{
    let mut alpha = 1.0;
    let c = 0.5;
    let rho = 0.5;

    // For steepest descent the directional derivative magnitude is ||direction||^2.
    let grad_dot_dir = direction.iter().map(|&d| d * d).sum::<f64>();

    for _ in 0..20 {
        let x_new = x - alpha * direction;
        let f_new = f(x_new.as_slice().unwrap());

        // Armijo sufficient decrease condition.
        if f_new <= f0 - c * alpha * grad_dot_dir {
            return alpha;
        }

        alpha *= rho;
    }

    alpha
}

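/// Solves the square linear system `a * x = b` by Gaussian elimination with
/// back substitution (no pivoting), returning an error if a near-zero pivot is
/// encountered. The coefficient matrix is expected to be a (symmetric) normal
/// equations matrix, although no symmetry is exploited.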
#[allow(dead_code)]
fn solve_symmetric_system(a: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
    let n = a.shape()[0];

    // Build the augmented matrix [a | b].
    let mut aug = Array2::zeros((n, n + 1));

    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }

    // Forward elimination.
    for i in 0..n {
        let pivot = aug[[i, i]];
        if pivot.abs() < 1e-10 {
            return Err(crate::error::OptimizeError::ValueError(
                "Singular matrix in linear solve".to_string(),
            ));
        }

        for j in i + 1..n {
            let factor = aug[[j, i]] / pivot;
            for k in i..=n {
                aug[[j, k]] -= factor * aug[[i, k]];
            }
        }
    }

    // Back substitution.
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in i + 1..n {
            sum -= aug[[i, j]] * x[j];
        }
        x[i] = sum / aug[[i, i]];
    }

    Ok(x)
}

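/// Least squares solve selected by `LinearSolver::QR`.
///
/// Note: despite the name, this currently forms the regularized normal
/// equations `(Phi^T Phi + lambda I) x = Phi^T y` and solves them with
/// Gaussian elimination rather than performing an actual QR factorization.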
#[allow(dead_code)]
fn qr_solve<S>(phi: &Array2<f64>, y: &ArrayBase<S, Ix1>, lambda: f64) -> OptimizeResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    // Form the regularized normal equations.
    let phi_t_phi = phi.t().dot(phi);
    let phi_t_y = phi.t().dot(y);

    let mut regularized = phi_t_phi.clone();
    for i in 0..regularized.shape()[0] {
        regularized[[i, i]] += lambda;
    }

    solve_symmetric_system(&regularized, &phi_t_y)
}

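/// Least squares solve selected by `LinearSolver::SVD`.
///
/// Note: this currently delegates to `qr_solve`, i.e. the regularized normal
/// equations; a genuine SVD-based solve is not implemented here.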
#[allow(dead_code)]
fn svd_solve<S>(
    phi: &Array2<f64>,
    y: &ArrayBase<S, Ix1>,
    lambda: f64,
) -> OptimizeResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    // Fall back to the normal equations solver.
    qr_solve(phi, y, lambda)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_separable_exponential() {
        // Model: y = alpha[0] * exp(-beta[0] * t) + alpha[1].
        fn basis_functions(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut phi = Array2::zeros((n, 2));

            for i in 0..n {
                phi[[i, 0]] = (-beta[0] * t[i]).exp();
                phi[[i, 1]] = 1.0;
            }
            phi
        }

        fn basis_jacobian(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut dphi_dbeta = Array2::zeros((n * 2, 1));

            for i in 0..n {
                dphi_dbeta[[i, 0]] = -t[i] * (-beta[0] * t[i]).exp();
                dphi_dbeta[[n + i, 0]] = 0.0;
            }
            dphi_dbeta
        }

        // Generate data from known parameters plus a small perturbation.
        let t_data = array![0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        let true_alpha = array![2.0, 0.5];
        let true_beta = array![0.7];

        let phi_true = basis_functions(t_data.as_slice().unwrap(), true_beta.as_slice().unwrap());
        let y_data =
            phi_true.dot(&true_alpha) + 0.01 * array![0.1, -0.05, 0.08, -0.03, 0.06, -0.04, 0.02];

        let beta0 = array![0.5];

        let result = separable_least_squares(
            basis_functions,
            basis_jacobian,
            &t_data,
            &y_data,
            &beta0,
            None,
        )
        .unwrap();

        assert!(result.result.success);
        assert!((result.result.x[0] - true_beta[0]).abs() < 0.1);
        assert!((result.linear_params[0] - true_alpha[0]).abs() < 0.1);
        assert!((result.linear_params[1] - true_alpha[1]).abs() < 0.1);
    }

    #[test]
    fn test_separable_multi_exponential() {
        // Model: y = alpha[0] * exp(-beta[0] * t) + alpha[1] * exp(-beta[1] * t).
        fn basis_functions(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut phi = Array2::zeros((n, 2));

            for i in 0..n {
                phi[[i, 0]] = (-beta[0] * t[i]).exp();
                phi[[i, 1]] = (-beta[1] * t[i]).exp();
            }
            phi
        }

        fn basis_jacobian(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut dphi_dbeta = Array2::zeros((n * 2, 2));

            for i in 0..n {
                dphi_dbeta[[i, 0]] = -t[i] * (-beta[0] * t[i]).exp();
                dphi_dbeta[[i, 1]] = 0.0;
                dphi_dbeta[[n + i, 0]] = 0.0;
                dphi_dbeta[[n + i, 1]] = -t[i] * (-beta[1] * t[i]).exp();
            }
            dphi_dbeta
        }

        // Noise-free data from known parameters.
        let t_data = array![0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4];
        let true_alpha = array![3.0, 1.5];
        let true_beta = array![2.0, 0.5];

        let phi_true = basis_functions(t_data.as_slice().unwrap(), true_beta.as_slice().unwrap());
        let y_data = phi_true.dot(&true_alpha);

        let beta0 = array![1.5, 0.8];

        let mut options = SeparableOptions::default();
        options.max_iter = 200;
        options.beta_tol = 1e-6;

        let result = separable_least_squares(
            basis_functions,
            basis_jacobian,
            &t_data,
            &y_data,
            &beta0,
            Some(options),
        )
        .unwrap();

        // The fit should substantially reduce the cost even if it does not
        // recover the true parameters exactly.
        assert!(result.result.fun < 0.1, "Cost = {}", result.result.fun);

        println!("Multi-exponential results:");
        println!("Beta: {:?} (true: {:?})", result.result.x, true_beta);
        println!("Alpha: {:?} (true: {:?})", result.linear_params, true_alpha);
        println!("Cost: {}", result.result.fun);
        println!("Success: {}", result.result.success);
    }
}