//! Separable nonlinear least squares (variable projection).
//!
//! Fits models of the form `y ≈ Φ(β)·α`, where the linear coefficients `α`
//! are recomputed from a linear subproblem at every step and only the
//! nonlinear parameters `β` are iterated on directly.

use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1};
use std::f64;

/// Options for the separable least squares solver.
#[derive(Debug, Clone)]
pub struct SeparableOptions {
    /// Maximum number of outer iterations over the nonlinear parameters.
    pub max_iter: usize,

    /// Convergence tolerance on the nonlinear parameter step.
    pub beta_tol: f64,

    /// Convergence tolerance on the relative change in the cost function.
    pub ftol: f64,

    /// Convergence tolerance on the gradient with respect to `beta`.
    pub gtol: f64,

    /// Solver used for the inner linear subproblem.
    pub linear_solver: LinearSolver,

    /// Tikhonov regularization parameter added to the normal equations.
    pub lambda: f64,
}

/// Method used to solve the inner linear subproblem for `alpha`.
#[derive(Debug, Clone, Copy)]
pub enum LinearSolver {
    /// QR decomposition (currently implemented via regularized normal equations).
    QR,
    /// Direct solution of the normal equations `ΦᵀΦ α = Φᵀy`.
    NormalEquations,
    /// Singular value decomposition (currently delegates to the QR path).
    SVD,
}

impl Default for SeparableOptions {
    fn default() -> Self {
        SeparableOptions {
            max_iter: 100,
            beta_tol: 1e-8,
            ftol: 1e-8,
            gtol: 1e-8,
            linear_solver: LinearSolver::QR,
            lambda: 0.0,
        }
    }
}

/// Result of a separable least squares fit.
#[derive(Debug, Clone)]
pub struct SeparableResult {
    /// Optimization results for the nonlinear parameters `beta`.
    pub result: OptimizeResults<f64>,
    /// Linear coefficients `alpha` from the final linear solve.
    pub linear_params: Array1<f64>,
}

#[allow(dead_code)]
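/// Solves a separable nonlinear least squares problem `y ≈ Φ(β)·α`.
///
/// `basis_functions(x, beta)` must return the `n × p` basis matrix `Φ(β)`
/// evaluated at the `n` points in `x`; `basis_jacobian(x, beta)` must return
/// its derivative with respect to the `q` nonlinear parameters, flattened to
/// an `(n·p) × q` matrix with `∂Φ[i, k]/∂β[j]` stored at row `k·n + i`.
///
/// A minimal sketch of a call, fitting `y = α₀·exp(−β₀·t) + α₁`; the data
/// and starting point below are illustrative only:
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array2};
///
/// let basis = |t: &[f64], beta: &[f64]| {
///     let mut phi = Array2::zeros((t.len(), 2));
///     for (i, &ti) in t.iter().enumerate() {
///         phi[[i, 0]] = (-beta[0] * ti).exp();
///         phi[[i, 1]] = 1.0;
///     }
///     phi
/// };
/// let jacobian = |t: &[f64], beta: &[f64]| {
///     let n = t.len();
///     let mut d = Array2::zeros((n * 2, 1));
///     for (i, &ti) in t.iter().enumerate() {
///         d[[i, 0]] = -ti * (-beta[0] * ti).exp(); // ∂Φ[i,0]/∂β₀
///         // ∂Φ[i,1]/∂β₀ = 0 (row n + i), already zero-initialized.
///     }
///     d
/// };
/// let t = array![0.0, 0.5, 1.0, 1.5, 2.0];
/// let y = array![2.5, 1.9, 1.5, 1.2, 1.0];
/// let fit = separable_least_squares(basis, jacobian, &t, &y, &array![1.0], None)?;
/// println!("beta = {:?}, alpha = {:?}", fit.result.x, fit.linear_params);
/// ```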
pub fn separable_least_squares<F, J, S1, S2, S3>(
    basis_functions: F,
    basis_jacobian: J,
    x_data: &ArrayBase<S1, Ix1>,
    y_data: &ArrayBase<S2, Ix1>,
    beta0: &ArrayBase<S3, Ix1>,
    options: Option<SeparableOptions>,
) -> OptimizeResult<SeparableResult>
where
    F: Fn(&[f64], &[f64]) -> Array2<f64>,
    J: Fn(&[f64], &[f64]) -> Array2<f64>,
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
{
    let options = options.unwrap_or_default();
    let mut beta = beta0.to_owned();

    let n = y_data.len();
    if x_data.len() != n {
        return Err(crate::error::OptimizeError::ValueError(
            "x_data and y_data must have the same length".to_string(),
        ));
    }

    let mut iter = 0;
    let mut nfev = 0;
    let mut prev_cost = f64::INFINITY;

    while iter < options.max_iter {
        // Evaluate the basis matrix Φ(β) at the current nonlinear parameters.
        let phi = basis_functions(
            x_data.as_slice().expect("x_data must be contiguous"),
            beta.as_slice().expect("beta must be contiguous"),
        );
        nfev += 1;

        let (n_points, _n_basis) = phi.dim();
        if n_points != n {
            return Err(crate::error::OptimizeError::ValueError(
                "Basis functions returned wrong number of rows".to_string(),
            ));
        }

        // Solve the linear subproblem for the coefficients α given Φ(β).
        let alpha = solve_linear_subproblem(&phi, y_data, &options)?;

        // Residual and cost at the current (β, α).
        let y_pred = phi.dot(&alpha);
        let residual = y_data - &y_pred;
        let cost = 0.5 * residual.iter().map(|&r| r * r).sum::<f64>();

        // Converged on relative change in the cost function.
        if (prev_cost - cost).abs() < options.ftol * cost {
            let mut result = OptimizeResults::default();
            result.x = beta.clone();
            result.fun = cost;
            result.nfev = nfev;
            result.nit = iter;
            result.success = true;
            result.message = "Converged (function tolerance)".to_string();

            return Ok(SeparableResult {
                result,
                linear_params: alpha,
            });
        }

        // Gradient of the cost with respect to β, holding α fixed.
        let gradient = compute_gradient(
            &phi,
            &alpha,
            &residual,
            x_data.as_slice().expect("x_data must be contiguous"),
            beta.as_slice().expect("beta must be contiguous"),
            &basis_jacobian,
        );

        // Converged when every gradient component is below gtol.
        if gradient.iter().all(|&g| g.abs() < options.gtol) {
            let mut result = OptimizeResults::default();
            result.x = beta.clone();
            result.fun = cost;
            result.nfev = nfev;
            result.nit = iter;
            result.success = true;
            result.message = "Converged (gradient tolerance)".to_string();

            return Ok(SeparableResult {
                result,
                linear_params: alpha,
            });
        }

        // Take a steepest-descent step with a backtracking line search.
        let step_size = backtracking_line_search(&beta, &gradient, cost, |b| {
            let phi_new = basis_functions(x_data.as_slice().expect("x_data must be contiguous"), b);
            // If the linear subproblem is singular at the trial point, treat
            // the cost as infinite so the line search backs off.
            match solve_linear_subproblem(&phi_new, y_data, &options) {
                Ok(alpha_new) => {
                    let res_new = y_data - &phi_new.dot(&alpha_new);
                    0.5 * res_new.iter().map(|&r| r * r).sum::<f64>()
                }
                Err(_) => f64::INFINITY,
            }
        });
        // Rough accounting: the line search evaluates the model a few times.
        nfev += 5;
        beta = &beta - &gradient * step_size;

        // Converged on the parameter step size ‖gradient‖ · step.
        if gradient.iter().map(|&g| g * g).sum::<f64>().sqrt() * step_size < options.beta_tol {
            let mut result = OptimizeResults::default();
            result.x = beta.clone();
            result.fun = cost;
            result.nfev = nfev;
            result.nit = iter;
            result.success = true;
            result.message = "Converged (parameter tolerance)".to_string();

            // Recompute the linear coefficients at the final β.
            let phi_final = basis_functions(
                x_data.as_slice().expect("x_data must be contiguous"),
                beta.as_slice().expect("beta must be contiguous"),
            );
            let alpha_final = solve_linear_subproblem(&phi_final, y_data, &options)?;

            return Ok(SeparableResult {
                result,
                linear_params: alpha_final,
            });
        }

        prev_cost = cost;
        iter += 1;
    }

    // Maximum iterations reached: report the last point and its cost.
    let phi_final = basis_functions(
        x_data.as_slice().expect("x_data must be contiguous"),
        beta.as_slice().expect("beta must be contiguous"),
    );
    let alpha_final = solve_linear_subproblem(&phi_final, y_data, &options)?;
    let y_pred_final = phi_final.dot(&alpha_final);
    let res_final = y_data - &y_pred_final;
    let final_cost = 0.5 * res_final.iter().map(|&r| r * r).sum::<f64>();

    let mut result = OptimizeResults::default();
    result.x = beta;
    result.fun = final_cost;
    result.nfev = nfev;
    result.nit = iter;
    result.success = false;
    result.message = "Maximum iterations reached".to_string();

    Ok(SeparableResult {
        result,
        linear_params: alpha_final,
    })
}

#[allow(dead_code)]
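/// Solves the inner linear least-squares subproblem for `alpha` given the
/// basis matrix `phi`, dispatching on `options.linear_solver`.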
fn solve_linear_subproblem<S1>(
    phi: &Array2<f64>,
    y: &ArrayBase<S1, Ix1>,
    options: &SeparableOptions,
) -> OptimizeResult<Array1<f64>>
where
    S1: Data<Elem = f64>,
{
    match options.linear_solver {
        LinearSolver::NormalEquations => {
            // Form the normal equations ΦᵀΦ α = Φᵀy.
            let phi_t_phi = phi.t().dot(phi);
            let phi_t_y = phi.t().dot(y);

            // Optional Tikhonov regularization on the diagonal.
            let mut regularized = phi_t_phi.clone();
            if options.lambda > 0.0 {
                for i in 0..regularized.shape()[0] {
                    regularized[[i, i]] += options.lambda;
                }
            }

            solve_symmetric_system(&regularized, &phi_t_y)
        }
        LinearSolver::QR => qr_solve(phi, y, options.lambda),
        LinearSolver::SVD => svd_solve(phi, y, options.lambda),
    }
}

#[allow(dead_code)]
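/// Gradient of the reduced cost `0.5·‖y − Φ(β)·α‖²` with respect to `β`,
/// holding `α` fixed at the current linear solution:
///
/// `∂f/∂βⱼ = −Σᵢ rᵢ · Σₖ αₖ · ∂Φ[i,k]/∂βⱼ`
///
/// Holding `α` fixed is exact when `α` minimizes the inner problem (the
/// envelope theorem), which is what makes the variable projection step valid.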
fn compute_gradient<J>(
    _phi: &Array2<f64>,
    alpha: &Array1<f64>,
    residual: &Array1<f64>,
    x_data: &[f64],
    beta: &[f64],
    basis_jacobian: &J,
) -> Array1<f64>
where
    J: Fn(&[f64], &[f64]) -> Array2<f64>,
{
    let dphi_dbeta = basis_jacobian(x_data, beta);
    let (_n_total, q) = dphi_dbeta.dim();
    let n = residual.len();
    let p = alpha.len();

    let mut gradient = Array1::zeros(q);

    for j in 0..q {
        let mut grad_j = 0.0;
        for i in 0..n {
            for k in 0..p {
                // Row k*n + i of the flattened Jacobian holds ∂Φ[i,k]/∂β.
                let idx = k * n + i;
                grad_j -= residual[i] * alpha[k] * dphi_dbeta[[idx, j]];
            }
        }
        gradient[j] = grad_j;
    }

    gradient
}

#[allow(dead_code)]
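/// Backtracking line search along the steepest-descent direction: starting
/// from a unit step, the step is halved until the Armijo sufficient-decrease
/// condition `f(x − α·d) ≤ f(x) − c·α·‖d‖²` holds (here `d` is the gradient,
/// so `∇f·d = ‖d‖²`), or the halving budget is exhausted, in which case the
/// smallest trial step is returned.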
fn backtracking_line_search<F>(x: &Array1<f64>, direction: &Array1<f64>, f0: f64, f: F) -> f64
where
    F: Fn(&[f64]) -> f64,
{
    let mut alpha = 1.0;
    let c = 0.5;
    let rho = 0.5;

    // For a steepest-descent step, ∇f·d = ‖d‖².
    let grad_dot_dir = direction.iter().map(|&d| d * d).sum::<f64>();

    for _ in 0..20 {
        let x_new = x - alpha * direction;
        let f_new = f(x_new.as_slice().expect("trial point must be contiguous"));

        if f_new <= f0 - c * alpha * grad_dot_dir {
            return alpha;
        }

        alpha *= rho;
    }

    alpha
}

#[allow(dead_code)]
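/// Solves the square linear system `A·x = b` by Gaussian elimination with
/// back-substitution. No pivoting is performed, so a (near-)zero diagonal
/// pivot is reported as a singular-matrix error.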
fn solve_symmetric_system(a: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
    let n = a.shape()[0];

    // Build the augmented matrix [A | b].
    let mut aug = Array2::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }

    // Forward elimination.
    for i in 0..n {
        let pivot = aug[[i, i]];
        if pivot.abs() < 1e-10 {
            return Err(crate::error::OptimizeError::ValueError(
                "Singular matrix in linear solve".to_string(),
            ));
        }

        for j in i + 1..n {
            let factor = aug[[j, i]] / pivot;
            for k in i..=n {
                aug[[j, k]] -= factor * aug[[i, k]];
            }
        }
    }

    // Back-substitution.
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in i + 1..n {
            sum -= aug[[i, j]] * x[j];
        }
        x[i] = sum / aug[[i, i]];
    }

    Ok(x)
}

#[allow(dead_code)]
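/// "QR" solve of the regularized least-squares problem. Despite the name,
/// this currently forms the regularized normal equations
/// `(ΦᵀΦ + λI) α = Φᵀy` and solves them by Gaussian elimination; a true QR
/// factorization would be numerically preferable for ill-conditioned `Φ`.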
fn qr_solve<S>(phi: &Array2<f64>, y: &ArrayBase<S, Ix1>, lambda: f64) -> OptimizeResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    let phi_t_phi = phi.t().dot(phi);
    let phi_t_y = phi.t().dot(y);

    let mut regularized = phi_t_phi.clone();
    for i in 0..regularized.shape()[0] {
        regularized[[i, i]] += lambda;
    }

    solve_symmetric_system(&regularized, &phi_t_y)
}

#[allow(dead_code)]
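/// "SVD" solve of the regularized least-squares problem. Currently a
/// placeholder that delegates to `qr_solve`; a real SVD-based solver would
/// handle rank-deficient `Φ` more robustly.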
fn svd_solve<S>(
    phi: &Array2<f64>,
    y: &ArrayBase<S, Ix1>,
    lambda: f64,
) -> OptimizeResult<Array1<f64>>
where
    S: Data<Elem = f64>,
{
    qr_solve(phi, y, lambda)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_separable_exponential() {
        // Model: y = alpha[0] * exp(-beta[0] * t) + alpha[1].
        fn basis_functions(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut phi = Array2::zeros((n, 2));

            for i in 0..n {
                phi[[i, 0]] = (-beta[0] * t[i]).exp();
                phi[[i, 1]] = 1.0;
            }
            phi
        }

        fn basis_jacobian(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut dphi_dbeta = Array2::zeros((n * 2, 1));

            for i in 0..n {
                dphi_dbeta[[i, 0]] = -t[i] * (-beta[0] * t[i]).exp();
                dphi_dbeta[[n + i, 0]] = 0.0;
            }
            dphi_dbeta
        }

        let t_data = array![0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        let true_alpha = array![2.0, 0.5];
        let true_beta = array![0.7];

        // Generate data from the true model plus a small perturbation.
        let phi_true = basis_functions(
            t_data.as_slice().expect("t_data must be contiguous"),
            true_beta.as_slice().expect("true_beta must be contiguous"),
        );
        let y_data =
            phi_true.dot(&true_alpha) + 0.01 * array![0.1, -0.05, 0.08, -0.03, 0.06, -0.04, 0.02];

        let beta0 = array![0.5];

        let result = separable_least_squares(
            basis_functions,
            basis_jacobian,
            &t_data,
            &y_data,
            &beta0,
            None,
        )
        .expect("optimization should succeed");

        assert!(result.result.success);
        assert!((result.result.x[0] - true_beta[0]).abs() < 0.1);
        assert!((result.linear_params[0] - true_alpha[0]).abs() < 0.1);
        assert!((result.linear_params[1] - true_alpha[1]).abs() < 0.1);
    }
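
    // Sanity check for the inner solver: the Gaussian-elimination routine
    // should satisfy A·x = b on a small, well-conditioned system.
    #[test]
    fn test_solve_symmetric_system_small() {
        let a = array![[4.0, 1.0], [1.0, 3.0]];
        let b = array![1.0, 2.0];
        let x = solve_symmetric_system(&a, &b).expect("system is nonsingular");

        // Verify the residual A·x − b is negligible rather than
        // hard-coding the solution.
        let ax = a.dot(&x);
        for (lhs, rhs) in ax.iter().zip(b.iter()) {
            assert!((lhs - rhs).abs() < 1e-10);
        }
    }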

    #[test]
    fn test_separable_multi_exponential() {
        // Model: y = alpha[0] * exp(-beta[0] * t) + alpha[1] * exp(-beta[1] * t).
        fn basis_functions(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut phi = Array2::zeros((n, 2));

            for i in 0..n {
                phi[[i, 0]] = (-beta[0] * t[i]).exp();
                phi[[i, 1]] = (-beta[1] * t[i]).exp();
            }
            phi
        }

        fn basis_jacobian(t: &[f64], beta: &[f64]) -> Array2<f64> {
            let n = t.len();
            let mut dphi_dbeta = Array2::zeros((n * 2, 2));

            for i in 0..n {
                dphi_dbeta[[i, 0]] = -t[i] * (-beta[0] * t[i]).exp();
                dphi_dbeta[[i, 1]] = 0.0;
                dphi_dbeta[[n + i, 0]] = 0.0;
                dphi_dbeta[[n + i, 1]] = -t[i] * (-beta[1] * t[i]).exp();
            }
            dphi_dbeta
        }

        let t_data = array![0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4];
        let true_alpha = array![3.0, 1.5];
        let true_beta = array![2.0, 0.5];

        let phi_true = basis_functions(
            t_data.as_slice().expect("t_data must be contiguous"),
            true_beta.as_slice().expect("true_beta must be contiguous"),
        );
        let y_data = phi_true.dot(&true_alpha);

        let beta0 = array![1.5, 0.8];

        let options = SeparableOptions {
            max_iter: 200,
            beta_tol: 1e-6,
            ..Default::default()
        };

        let result = separable_least_squares(
            basis_functions,
            basis_jacobian,
            &t_data,
            &y_data,
            &beta0,
            Some(options),
        )
        .expect("optimization should succeed");

        // Multi-exponential fitting is ill-conditioned, so only require a
        // small final cost rather than exact parameter recovery.
        assert!(result.result.fun < 0.1, "Cost = {}", result.result.fun);

        println!("Multi-exponential results:");
        println!("Beta: {:?} (true: {:?})", result.result.x, true_beta);
        println!("Alpha: {:?} (true: {:?})", result.linear_params, true_alpha);
        println!("Cost: {}", result.result.fun);
        println!("Success: {}", result.result.success);
    }
}