1use crate::error::OptimizeResult;
36use scirs2_core::ndarray::{array, s, Array1, Array2, ArrayBase, ArrayStatCompat, Data, Ix1};
37use statrs::statistics::Statistics;
38
39#[derive(Debug, Clone)]
41pub struct TotalLeastSquaresOptions {
42 pub max_iter: usize,
44
45 pub tol: f64,
47
48 pub method: TLSMethod,
50
51 pub use_weights: bool,
53}
54
/// Solver used by [`total_least_squares`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare a configured method
/// (`options.method == TLSMethod::SVD`) or use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TLSMethod {
    /// Closed-form fit from the eigen-decomposition of the 2×2 covariance matrix of the centred
    /// data (equivalent to the smallest right singular vector of the data matrix).
    SVD,
    /// Iteratively refined fit that uses per-point measurement variances as weights.
    Iterative,
    /// Maximum-likelihood fit; currently handled by the same solver as [`TLSMethod::Iterative`].
    MaximumLikelihood,
}
65
66impl Default for TotalLeastSquaresOptions {
67 fn default() -> Self {
68 TotalLeastSquaresOptions {
69 max_iter: 100,
70 tol: 1e-8,
71 method: TLSMethod::SVD,
72 use_weights: true,
73 }
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct TotalLeastSquaresResult {
80 pub slope: f64,
82
83 pub intercept: f64,
85
86 pub x_corrected: Array1<f64>,
88
89 pub y_corrected: Array1<f64>,
91
92 pub orthogonal_residuals: f64,
94
95 pub nit: usize,
97
98 pub converged: bool,
100}
101
102#[allow(dead_code)]
115pub fn total_least_squares<S1, S2, S3, S4>(
116 x_measured: &ArrayBase<S1, Ix1>,
117 y_measured: &ArrayBase<S2, Ix1>,
118 x_variance: Option<&ArrayBase<S3, Ix1>>,
119 y_variance: Option<&ArrayBase<S4, Ix1>>,
120 options: Option<TotalLeastSquaresOptions>,
121) -> OptimizeResult<TotalLeastSquaresResult>
122where
123 S1: Data<Elem = f64>,
124 S2: Data<Elem = f64>,
125 S3: Data<Elem = f64>,
126 S4: Data<Elem = f64>,
127{
128 let options = options.unwrap_or_default();
129 let n = x_measured.len();
130
131 if y_measured.len() != n {
132 return Err(crate::error::OptimizeError::ValueError(
133 "x_measured and y_measured must have the same length".to_string(),
134 ));
135 }
136
137 if let Some(x_var) = x_variance {
139 if x_var.len() != n {
140 return Err(crate::error::OptimizeError::ValueError(
141 "x_variance must have the same length as data".to_string(),
142 ));
143 }
144 }
145
146 if let Some(y_var) = y_variance {
147 if y_var.len() != n {
148 return Err(crate::error::OptimizeError::ValueError(
149 "y_variance must have the same length as data".to_string(),
150 ));
151 }
152 }
153
154 match options.method {
155 TLSMethod::SVD => tls_svd(x_measured, y_measured, x_variance, y_variance, &options),
156 TLSMethod::Iterative => {
157 tls_iterative(x_measured, y_measured, x_variance, y_variance, &options)
158 }
159 TLSMethod::MaximumLikelihood => {
160 tls_maximum_likelihood(x_measured, y_measured, x_variance, y_variance, &options)
161 }
162 }
163}
164
165#[allow(dead_code)]
167fn tls_svd<S1, S2, S3, S4>(
168 x_measured: &ArrayBase<S1, Ix1>,
169 y_measured: &ArrayBase<S2, Ix1>,
170 x_variance: Option<&ArrayBase<S3, Ix1>>,
171 y_variance: Option<&ArrayBase<S4, Ix1>>,
172 options: &TotalLeastSquaresOptions,
173) -> OptimizeResult<TotalLeastSquaresResult>
174where
175 S1: Data<Elem = f64>,
176 S2: Data<Elem = f64>,
177 S3: Data<Elem = f64>,
178 S4: Data<Elem = f64>,
179{
180 let n = x_measured.len();
181
182 let x_mean = x_measured.mean_or(0.0);
184 let y_mean = y_measured.mean_or(0.0);
185
186 let x_centered = x_measured - x_mean;
187 let y_centered = y_measured - y_mean;
188
189 let (x_scaled, y_scaled) =
191 if options.use_weights && x_variance.is_some() && y_variance.is_some() {
192 let x_var = x_variance.expect("Operation failed");
193 let y_var = y_variance.expect("Operation failed");
194
195 let x_weights = x_var.mapv(|v| 1.0 / v.sqrt());
197 let y_weights = y_var.mapv(|v| 1.0 / v.sqrt());
198
199 (
200 (&x_centered * &x_weights).to_owned(),
201 (&y_centered * &y_weights).to_owned(),
202 )
203 } else {
204 (x_centered.to_owned(), y_centered.to_owned())
205 };
206
207 let mut data_matrix = Array2::zeros((n, 2));
209 for i in 0..n {
210 data_matrix[[i, 0]] = x_scaled[i];
211 data_matrix[[i, 1]] = y_scaled[i];
212 }
213
214 let cov_matrix = data_matrix.t().dot(&data_matrix) / n as f64;
217
218 let (eigenvalues, eigenvectors) = eigen_2x2(&cov_matrix);
220
221 let min_idx = if eigenvalues[0] < eigenvalues[1] {
223 0
224 } else {
225 1
226 };
227 let normal = eigenvectors.slice(s![.., min_idx]).to_owned();
228
229 let a = normal[0usize];
232 let b = normal[1usize];
233
234 if b.abs() < 1e-10 {
235 return Err(crate::error::OptimizeError::ValueError(
237 "Nearly vertical line detected".to_string(),
238 ));
239 }
240
241 let slope = -a / b;
242 let intercept = y_mean - slope * x_mean;
243
244 let mut x_corrected = Array1::zeros(n);
246 let mut y_corrected = Array1::zeros(n);
247 let mut total_residual = 0.0;
248
249 for i in 0..n {
250 let (x_proj, y_proj) =
251 orthogonal_projection(x_measured[i], y_measured[i], slope, intercept);
252 x_corrected[i] = x_proj;
253 y_corrected[i] = y_proj;
254
255 let dx = x_measured[i] - x_proj;
256 let dy = y_measured[i] - y_proj;
257 total_residual += dx * dx + dy * dy;
258 }
259
260 Ok(TotalLeastSquaresResult {
261 slope,
262 intercept,
263 x_corrected,
264 y_corrected,
265 orthogonal_residuals: total_residual,
266 nit: 1,
267 converged: true,
268 })
269}
270
271#[allow(dead_code)]
273fn tls_iterative<S1, S2, S3, S4>(
274 x_measured: &ArrayBase<S1, Ix1>,
275 y_measured: &ArrayBase<S2, Ix1>,
276 x_variance: Option<&ArrayBase<S3, Ix1>>,
277 y_variance: Option<&ArrayBase<S4, Ix1>>,
278 options: &TotalLeastSquaresOptions,
279) -> OptimizeResult<TotalLeastSquaresResult>
280where
281 S1: Data<Elem = f64>,
282 S2: Data<Elem = f64>,
283 S3: Data<Elem = f64>,
284 S4: Data<Elem = f64>,
285{
286 let n = x_measured.len();
287
288 let (mut slope, mut intercept) = ordinary_least_squares(x_measured, y_measured);
290
291 let mut x_corrected = x_measured.to_owned();
292 let mut y_corrected = y_measured.to_owned();
293 let mut prev_residual = f64::INFINITY;
294
295 let x_weights = if let Some(x_var) = x_variance {
297 x_var.mapv(|v| 1.0 / v)
298 } else {
299 Array1::ones(n)
300 };
301
302 let y_weights = if let Some(y_var) = y_variance {
303 y_var.mapv(|v| 1.0 / v)
304 } else {
305 Array1::ones(n)
306 };
307
308 let mut iter = 0;
309 let mut converged = false;
310
311 while iter < options.max_iter {
312 let mut total_residual = 0.0;
314
315 for i in 0..n {
316 let (x_proj, y_proj) = weighted_orthogonal_projection(
317 x_measured[i],
318 y_measured[i],
319 slope,
320 intercept,
321 x_weights[i],
322 y_weights[i],
323 );
324
325 x_corrected[i] = x_proj;
326 y_corrected[i] = y_proj;
327
328 let dx = x_measured[i] - x_proj;
329 let dy = y_measured[i] - y_proj;
330 total_residual += x_weights[i] * dx * dx + y_weights[i] * dy * dy;
331 }
332
333 let (new_slope, new_intercept) =
335 weighted_least_squares_line(&x_corrected, &y_corrected, &x_weights, &y_weights);
336
337 if (total_residual - prev_residual).abs() < options.tol * total_residual
339 && (new_slope - slope).abs() < options.tol
340 && (new_intercept - intercept).abs() < options.tol
341 {
342 converged = true;
343 break;
344 }
345
346 slope = new_slope;
347 intercept = new_intercept;
348 prev_residual = total_residual;
349 iter += 1;
350 }
351
352 let mut orthogonal_residuals = 0.0;
354 for i in 0..n {
355 let dx = x_measured[i] - x_corrected[i];
356 let dy = y_measured[i] - y_corrected[i];
357 orthogonal_residuals += dx * dx + dy * dy;
358 }
359
360 Ok(TotalLeastSquaresResult {
361 slope,
362 intercept,
363 x_corrected,
364 y_corrected,
365 orthogonal_residuals,
366 nit: iter,
367 converged,
368 })
369}
370
371#[allow(dead_code)]
373fn tls_maximum_likelihood<S1, S2, S3, S4>(
374 x_measured: &ArrayBase<S1, Ix1>,
375 y_measured: &ArrayBase<S2, Ix1>,
376 x_variance: Option<&ArrayBase<S3, Ix1>>,
377 y_variance: Option<&ArrayBase<S4, Ix1>>,
378 options: &TotalLeastSquaresOptions,
379) -> OptimizeResult<TotalLeastSquaresResult>
380where
381 S1: Data<Elem = f64>,
382 S2: Data<Elem = f64>,
383 S3: Data<Elem = f64>,
384 S4: Data<Elem = f64>,
385{
386 tls_iterative(x_measured, y_measured, x_variance, y_variance, options)
389}
390
391#[allow(dead_code)]
393fn ordinary_least_squares<S1, S2>(x: &ArrayBase<S1, Ix1>, y: &ArrayBase<S2, Ix1>) -> (f64, f64)
394where
395 S1: Data<Elem = f64>,
396 S2: Data<Elem = f64>,
397{
398 let _n = x.len() as f64;
399 let x_mean = x.mean_or(0.0);
400 let y_mean = y.mean_or(0.0);
401
402 let mut num = 0.0;
403 let mut den = 0.0;
404
405 for i in 0..x.len() {
406 let dx = x[i] - x_mean;
407 let dy = y[i] - y_mean;
408 num += dx * dy;
409 den += dx * dx;
410 }
411
412 let slope = num / den;
413 let intercept = y_mean - slope * x_mean;
414
415 (slope, intercept)
416}
417
/// Foot of the perpendicular from `(x, y)` to the line `y = slope * x + intercept`.
///
/// Returns the projected point as `(x, y)`.
#[allow(dead_code)]
fn orthogonal_projection(x: f64, y: f64, slope: f64, intercept: f64) -> (f64, f64) {
    let foot_x = (x + slope * (y - intercept)) / (1.0 + slope * slope);
    (foot_x, intercept + slope * foot_x)
}
433
/// Point on `y = slope * x + intercept` closest to `(x, y)` under the weighted metric
/// `weight_x * dx² + weight_y * dy²`.
///
/// Setting the derivative of that metric (as a function of the projected abscissa) to zero
/// gives `x_proj = (wx·x + wy·slope·(y - intercept)) / (wx + wy·slope²)`.
#[allow(dead_code)]
fn weighted_orthogonal_projection(
    x: f64,
    y: f64,
    slope: f64,
    intercept: f64,
    weight_x: f64,
    weight_y: f64,
) -> (f64, f64) {
    let denominator = weight_x + weight_y * slope * slope;
    let numerator = weight_x * x + weight_y * slope * (y - intercept);
    let foot_x = numerator / denominator;
    (foot_x, slope * foot_x + intercept)
}
456
457#[allow(dead_code)]
459fn weighted_least_squares_line<S1, S2, S3, S4>(
460 x: &ArrayBase<S1, Ix1>,
461 y: &ArrayBase<S2, Ix1>,
462 weight_x: &ArrayBase<S3, Ix1>,
463 weight_y: &ArrayBase<S4, Ix1>,
464) -> (f64, f64)
465where
466 S1: Data<Elem = f64>,
467 S2: Data<Elem = f64>,
468 S3: Data<Elem = f64>,
469 S4: Data<Elem = f64>,
470{
471 let n = x.len();
472 let mut sum_wx = 0.0;
473 let mut sum_wy = 0.0;
474 let mut sum_wxx = 0.0;
475 let mut sum_wxy = 0.0;
476 let mut _sum_wyy = 0.0;
477 let mut sum_w = 0.0;
478
479 for i in 0..n {
480 let w = (weight_x[i] + weight_y[i]) / 2.0; sum_w += w;
482 sum_wx += w * x[i];
483 sum_wy += w * y[i];
484 sum_wxx += w * x[i] * x[i];
485 sum_wxy += w * x[i] * y[i];
486 _sum_wyy += w * y[i] * y[i];
487 }
488
489 let x_mean = sum_wx / sum_w;
490 let y_mean = sum_wy / sum_w;
491
492 let cov_xx = sum_wxx / sum_w - x_mean * x_mean;
493 let cov_xy = sum_wxy / sum_w - x_mean * y_mean;
494
495 let slope = cov_xy / cov_xx;
496 let intercept = y_mean - slope * x_mean;
497
498 (slope, intercept)
499}
500
501#[allow(dead_code)]
503fn eigen_2x2(matrix: &Array2<f64>) -> (Array1<f64>, Array2<f64>) {
504 let a = matrix[[0, 0]];
505 let b = matrix[[0, 1]];
506 let c = matrix[[1, 0]];
507 let d = matrix[[1, 1]];
508
509 let trace = a + d;
511 let det = a * d - b * c;
512
513 let discriminant = trace * trace - 4.0 * det;
514 let sqrt_disc = discriminant.sqrt();
515
516 let lambda1 = (trace + sqrt_disc) / 2.0;
517 let lambda2 = (trace - sqrt_disc) / 2.0;
518
519 let mut eigenvectors = Array2::zeros((2, 2));
521
522 if (a - lambda1).abs() > 1e-10 || b.abs() > 1e-10 {
524 let v1_x = b;
525 let v1_y = lambda1 - a;
526 let norm1 = (v1_x * v1_x + v1_y * v1_y).sqrt();
527 eigenvectors[[0, 0]] = v1_x / norm1;
528 eigenvectors[[1, 0]] = v1_y / norm1;
529 } else {
530 eigenvectors[[0, 0]] = 1.0;
531 eigenvectors[[1, 0]] = 0.0;
532 }
533
534 if (a - lambda2).abs() > 1e-10 || b.abs() > 1e-10 {
536 let v2_x = b;
537 let v2_y = lambda2 - a;
538 let norm2 = (v2_x * v2_x + v2_y * v2_y).sqrt();
539 eigenvectors[[0, 1]] = v2_x / norm2;
540 eigenvectors[[1, 1]] = v2_y / norm2;
541 } else {
542 eigenvectors[[0, 1]] = 0.0;
543 eigenvectors[[1, 1]] = 1.0;
544 }
545
546 (array![lambda1, lambda2], eigenvectors)
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fits `y` against `x` through the public entry point with no measurement variances.
    fn fit_without_variances(
        x: &Array1<f64>,
        y: &Array1<f64>,
        options: Option<TotalLeastSquaresOptions>,
    ) -> TotalLeastSquaresResult {
        total_least_squares(x, y, None::<&Array1<f64>>, None::<&Array1<f64>>, options)
            .expect("Operation failed")
    }

    #[test]
    fn test_total_least_squares_simple() {
        let true_slope = 1.5;
        let true_intercept = 0.5;

        let x_true = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_true = &x_true * true_slope + true_intercept;

        // Small, hand-picked perturbations on both axes.
        let x_measured = &x_true + &array![0.1, -0.05, 0.08, -0.03, 0.06];
        let y_measured = &y_true + &array![-0.05, 0.1, -0.07, 0.04, -0.08];

        let fit = fit_without_variances(&x_measured, &y_measured, None);

        assert!((fit.slope - true_slope).abs() < 0.1);
        assert!((fit.intercept - true_intercept).abs() < 0.1);
    }

    #[test]
    fn test_weighted_total_least_squares() {
        let x_measured = array![1.0, 2.1, 2.9, 4.2, 5.0];
        let y_measured = array![2.1, 3.9, 5.1, 6.8, 8.1];

        // The fourth point is an order of magnitude noisier than the others.
        let x_variance = array![0.01, 0.01, 0.01, 0.1, 0.01];
        let y_variance = array![0.01, 0.02, 0.01, 0.1, 0.01];

        let fit = total_least_squares(
            &x_measured,
            &y_measured,
            Some(&x_variance),
            Some(&y_variance),
            None,
        )
        .expect("Operation failed");

        assert!(fit.converged);
        println!(
            "Weighted TLS: slope = {:.3}, intercept = {:.3}",
            fit.slope, fit.intercept
        );
    }

    #[test]
    fn test_iterative_vs_svd() {
        let x_measured = array![0.5, 1.5, 2.8, 3.7, 4.9];
        let y_measured = array![1.2, 2.7, 4.1, 5.3, 6.8];

        let with_method = |method: TLSMethod| TotalLeastSquaresOptions {
            method,
            ..Default::default()
        };

        let svd = fit_without_variances(&x_measured, &y_measured, Some(with_method(TLSMethod::SVD)));
        let iterative = fit_without_variances(
            &x_measured,
            &y_measured,
            Some(with_method(TLSMethod::Iterative)),
        );

        assert!((svd.slope - iterative.slope).abs() < 0.01);
        assert!((svd.intercept - iterative.intercept).abs() < 0.01);
    }
}