//! Total least squares (errors-in-variables) regression: fits a straight line
//! when both the x and y observations carry measurement error, by minimizing
//! orthogonal (optionally variance-weighted) distances to the line.

use crate::error::OptimizeResult;
use scirs2_core::ndarray::{array, s, Array1, Array2, ArrayBase, Data, Ix1};
use statrs::statistics::Statistics;

/// Options controlling a total least squares fit.
#[derive(Debug, Clone)]
pub struct TotalLeastSquaresOptions {
    /// Maximum number of iterations for the iterative solvers.
    pub max_iter: usize,

    /// Convergence tolerance on the residual, slope, and intercept.
    pub tol: f64,

    /// Solution method to use.
    pub method: TLSMethod,

    /// Whether to apply the supplied variances as weights (consulted by the SVD method).
    pub use_weights: bool,
}

/// Method used to solve the total least squares problem.
#[derive(Debug, Clone, Copy)]
pub enum TLSMethod {
    /// Closed-form solution via the eigendecomposition of the 2x2 covariance
    /// matrix of the centered data (equivalent to an SVD of the data matrix).
    SVD,
    /// Iterative weighted orthogonal-projection refinement.
    Iterative,
    /// Maximum likelihood estimation (currently delegates to the iterative solver).
    MaximumLikelihood,
}

impl Default for TotalLeastSquaresOptions {
    fn default() -> Self {
        TotalLeastSquaresOptions {
            max_iter: 100,
            tol: 1e-8,
            method: TLSMethod::SVD,
            use_weights: true,
        }
    }
}

/// Result of a total least squares fit of the line `y = slope * x + intercept`.
#[derive(Debug, Clone)]
pub struct TotalLeastSquaresResult {
    /// Estimated slope of the fitted line.
    pub slope: f64,

    /// Estimated intercept of the fitted line.
    pub intercept: f64,

    /// Corrected (projected) x values lying on the fitted line.
    pub x_corrected: Array1<f64>,

    /// Corrected (projected) y values lying on the fitted line.
    pub y_corrected: Array1<f64>,

    /// Sum of squared orthogonal distances from the measured points to the fitted line.
    pub orthogonal_residuals: f64,

    /// Number of iterations performed.
    pub nit: usize,

    /// Whether the algorithm converged.
    pub converged: bool,
}

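/// Fits the line `y = slope * x + intercept` when both `x_measured` and
/// `y_measured` are subject to measurement error, by minimizing (optionally
/// weighted) orthogonal distances from the points to the line.
///
/// `x_variance` and `y_variance` are optional per-point error variances used
/// as weights, and `options` selects the solver (`SVD`, `Iterative`, or
/// `MaximumLikelihood`); defaults are used when `options` is `None`.
/// Returns an error if the input lengths disagree.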
#[allow(dead_code)]
pub fn total_least_squares<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: Option<TotalLeastSquaresOptions>,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    let options = options.unwrap_or_default();
    let n = x_measured.len();

    // Validate that all supplied arrays have matching lengths.
    if y_measured.len() != n {
        return Err(crate::error::OptimizeError::ValueError(
            "x_measured and y_measured must have the same length".to_string(),
        ));
    }

    if let Some(x_var) = x_variance {
        if x_var.len() != n {
            return Err(crate::error::OptimizeError::ValueError(
                "x_variance must have the same length as data".to_string(),
            ));
        }
    }

    if let Some(y_var) = y_variance {
        if y_var.len() != n {
            return Err(crate::error::OptimizeError::ValueError(
                "y_variance must have the same length as data".to_string(),
            ));
        }
    }

    // Dispatch to the requested solver.
    match options.method {
        TLSMethod::SVD => tls_svd(x_measured, y_measured, x_variance, y_variance, &options),
        TLSMethod::Iterative => {
            tls_iterative(x_measured, y_measured, x_variance, y_variance, &options)
        }
        TLSMethod::MaximumLikelihood => {
            tls_maximum_likelihood(x_measured, y_measured, x_variance, y_variance, &options)
        }
    }
}

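/// Solves the TLS problem in closed form: the data are centered (and, when
/// variances are supplied and `use_weights` is set, scaled by the inverse
/// error standard deviations), and the direction normal to the best-fit line
/// is taken as the eigenvector of the smallest eigenvalue of the 2x2
/// second-moment matrix. For this two-column case that is equivalent to the
/// classical SVD-based solution.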
#[allow(dead_code)]
fn tls_svd<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: &TotalLeastSquaresOptions,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    let n = x_measured.len();

    // Center the data so the fitted line passes through the centroid.
    let x_mean = x_measured.mean().unwrap();
    let y_mean = y_measured.mean().unwrap();

    let x_centered = x_measured - x_mean;
    let y_centered = y_measured - y_mean;

    // Optionally scale each coordinate by the inverse standard deviation of
    // its measurement error.
    let (x_scaled, y_scaled) =
        if options.use_weights && x_variance.is_some() && y_variance.is_some() {
            let x_var = x_variance.unwrap();
            let y_var = y_variance.unwrap();

            let x_weights = x_var.mapv(|v| 1.0 / v.sqrt());
            let y_weights = y_var.mapv(|v| 1.0 / v.sqrt());

            (x_centered * &x_weights, y_centered * &y_weights)
        } else {
            (x_centered.clone(), y_centered.clone())
        };

    // Stack the centered data as an n x 2 matrix [x | y].
    let mut data_matrix = Array2::zeros((n, 2));
    for i in 0..n {
        data_matrix[[i, 0]] = x_scaled[i];
        data_matrix[[i, 1]] = y_scaled[i];
    }

    // 2x2 second-moment (covariance) matrix of the centered data.
    let cov_matrix = data_matrix.t().dot(&data_matrix) / n as f64;

    let (eigenvalues, eigenvectors) = eigen_2x2(&cov_matrix);

    // The eigenvector of the smallest eigenvalue is normal to the best-fit line.
    let min_idx = if eigenvalues[0] < eigenvalues[1] {
        0
    } else {
        1
    };
    let normal = eigenvectors.slice(s![.., min_idx]).to_owned();

    // Line equation a * x + b * y = const, so slope = -a / b.
    let a = normal[0usize];
    let b = normal[1usize];

    if b.abs() < 1e-10 {
        return Err(crate::error::OptimizeError::ValueError(
            "Nearly vertical line detected".to_string(),
        ));
    }

    let slope = -a / b;
    let intercept = y_mean - slope * x_mean;

    // Project each measured point orthogonally onto the fitted line and
    // accumulate the squared orthogonal distances.
    let mut x_corrected = Array1::zeros(n);
    let mut y_corrected = Array1::zeros(n);
    let mut total_residual = 0.0;

    for i in 0..n {
        let (x_proj, y_proj) =
            orthogonal_projection(x_measured[i], y_measured[i], slope, intercept);
        x_corrected[i] = x_proj;
        y_corrected[i] = y_proj;

        let dx = x_measured[i] - x_proj;
        let dy = y_measured[i] - y_proj;
        total_residual += dx * dx + dy * dy;
    }

    Ok(TotalLeastSquaresResult {
        slope,
        intercept,
        x_corrected,
        y_corrected,
        orthogonal_residuals: total_residual,
        nit: 1,
        converged: true,
    })
}

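/// Iteratively refines the fit: starting from the ordinary least squares
/// line, each point is projected onto the current line using weighted
/// orthogonal distances, the line is refit through the projected points, and
/// the process repeats until the residual, slope, and intercept stabilize or
/// `max_iter` is reached.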
#[allow(dead_code)]
fn tls_iterative<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: &TotalLeastSquaresOptions,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    let n = x_measured.len();

    // Start from the ordinary least squares fit.
    let (mut slope, mut intercept) = ordinary_least_squares(x_measured, y_measured);

    let mut x_corrected = x_measured.to_owned();
    let mut y_corrected = y_measured.to_owned();
    let mut prev_residual = f64::INFINITY;

    // Inverse variances serve as weights; default to unit weights.
    let x_weights = if let Some(x_var) = x_variance {
        x_var.mapv(|v| 1.0 / v)
    } else {
        Array1::ones(n)
    };

    let y_weights = if let Some(y_var) = y_variance {
        y_var.mapv(|v| 1.0 / v)
    } else {
        Array1::ones(n)
    };

    let mut iter = 0;
    let mut converged = false;

    while iter < options.max_iter {
        // Project each point onto the current line under the weighted metric.
        let mut total_residual = 0.0;

        for i in 0..n {
            let (x_proj, y_proj) = weighted_orthogonal_projection(
                x_measured[i],
                y_measured[i],
                slope,
                intercept,
                x_weights[i],
                y_weights[i],
            );

            x_corrected[i] = x_proj;
            y_corrected[i] = y_proj;

            let dx = x_measured[i] - x_proj;
            let dy = y_measured[i] - y_proj;
            total_residual += x_weights[i] * dx * dx + y_weights[i] * dy * dy;
        }

        // Refit the line through the projected points.
        let (new_slope, new_intercept) =
            weighted_least_squares_line(&x_corrected, &y_corrected, &x_weights, &y_weights);

        // Stop when the residual and the parameters have stabilized.
        if (total_residual - prev_residual).abs() < options.tol * total_residual
            && (new_slope - slope).abs() < options.tol
            && (new_intercept - intercept).abs() < options.tol
        {
            converged = true;
            break;
        }

        slope = new_slope;
        intercept = new_intercept;
        prev_residual = total_residual;
        iter += 1;
    }

    // Report unweighted orthogonal residuals for comparability across methods.
    let mut orthogonal_residuals = 0.0;
    for i in 0..n {
        let dx = x_measured[i] - x_corrected[i];
        let dy = y_measured[i] - y_corrected[i];
        orthogonal_residuals += dx * dx + dy * dy;
    }

    Ok(TotalLeastSquaresResult {
        slope,
        intercept,
        x_corrected,
        y_corrected,
        orthogonal_residuals,
        nit: iter,
        converged,
    })
}

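/// Maximum likelihood variant. For independent Gaussian errors with known
/// variances, maximizing the likelihood amounts to minimizing the
/// variance-weighted orthogonal residuals, so this currently delegates to the
/// iterative weighted solver.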
#[allow(dead_code)]
fn tls_maximum_likelihood<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: &TotalLeastSquaresOptions,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    // Currently implemented by the iterative weighted solver.
    tls_iterative(x_measured, y_measured, x_variance, y_variance, options)
}

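/// Ordinary least squares fit of `y = slope * x + intercept`, used as the
/// starting point for the iterative solver.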
#[allow(dead_code)]
fn ordinary_least_squares<S1, S2>(x: &ArrayBase<S1, Ix1>, y: &ArrayBase<S2, Ix1>) -> (f64, f64)
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
{
    let x_mean = x.mean().unwrap();
    let y_mean = y.mean().unwrap();

    let mut num = 0.0;
    let mut den = 0.0;

    for i in 0..x.len() {
        let dx = x[i] - x_mean;
        let dy = y[i] - y_mean;
        num += dx * dy;
        den += dx * dx;
    }

    let slope = num / den;
    let intercept = y_mean - slope * x_mean;

    (slope, intercept)
}

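/// Orthogonal projection of the point `(x, y)` onto the line
/// `y = slope * x + intercept`:
/// `x_proj = (x + slope * (y - intercept)) / (1 + slope^2)`,
/// `y_proj = slope * x_proj + intercept`.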
#[allow(dead_code)]
fn orthogonal_projection(x: f64, y: f64, slope: f64, intercept: f64) -> (f64, f64) {
    // Parameterize the line as (t, slope * t + intercept) and solve for the t
    // that minimizes the squared distance to (x, y).
    let norm_sq = slope * slope + 1.0;
    let t = ((y - intercept) * slope + x) / norm_sq;

    let x_proj = t;
    let y_proj = slope * t + intercept;

    (x_proj, y_proj)
}

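/// Projection of `(x, y)` onto the line that minimizes the weighted squared
/// distance `weight_x * (x - x_proj)^2 + weight_y * (y - y_proj)^2` subject to
/// `y_proj = slope * x_proj + intercept`. Setting the derivative with respect
/// to `x_proj` to zero gives
/// `x_proj = (weight_x * x + weight_y * slope * (y - intercept))
///            / (weight_x + weight_y * slope^2)`.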
#[allow(dead_code)]
fn weighted_orthogonal_projection(
    x: f64,
    y: f64,
    slope: f64,
    intercept: f64,
    weight_x: f64,
    weight_y: f64,
) -> (f64, f64) {
    // Minimize weight_x * (x - x_proj)^2 + weight_y * (y - y_proj)^2 with
    // y_proj constrained to lie on the line.
    let a = weight_x + weight_y * slope * slope;
    let c = weight_x * x + weight_y * slope * (y - intercept);

    let x_proj = c / a;
    let y_proj = slope * x_proj + intercept;

    (x_proj, y_proj)
}

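/// Weighted least squares line through the (already corrected) points, using
/// the average of the x and y weights as a single per-point weight and the
/// usual weighted moment formulas, `slope = cov_xy / cov_xx`.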
#[allow(dead_code)]
fn weighted_least_squares_line<S1, S2, S3, S4>(
    x: &ArrayBase<S1, Ix1>,
    y: &ArrayBase<S2, Ix1>,
    weight_x: &ArrayBase<S3, Ix1>,
    weight_y: &ArrayBase<S4, Ix1>,
) -> (f64, f64)
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    let n = x.len();
    let mut sum_wx = 0.0;
    let mut sum_wy = 0.0;
    let mut sum_wxx = 0.0;
    let mut sum_wxy = 0.0;
    let mut sum_w = 0.0;

    for i in 0..n {
        // Combine the x and y weights into a single per-point weight.
        let w = (weight_x[i] + weight_y[i]) / 2.0;
        sum_w += w;
        sum_wx += w * x[i];
        sum_wy += w * y[i];
        sum_wxx += w * x[i] * x[i];
        sum_wxy += w * x[i] * y[i];
    }

    let x_mean = sum_wx / sum_w;
    let y_mean = sum_wy / sum_w;

    let cov_xx = sum_wxx / sum_w - x_mean * x_mean;
    let cov_xy = sum_wxy / sum_w - x_mean * y_mean;

    let slope = cov_xy / cov_xx;
    let intercept = y_mean - slope * x_mean;

    (slope, intercept)
}

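/// Closed-form eigendecomposition of a 2x2 matrix with real eigenvalues
/// (here always a symmetric covariance matrix): eigenvalues from the
/// characteristic polynomial via trace and determinant, eigenvectors from
/// `(A - lambda * I) v = 0`, returned as the columns of the second output.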
#[allow(dead_code)]
fn eigen_2x2(matrix: &Array2<f64>) -> (Array1<f64>, Array2<f64>) {
    let a = matrix[[0, 0]];
    let b = matrix[[0, 1]];
    let c = matrix[[1, 0]];
    let d = matrix[[1, 1]];

    // Eigenvalues from the characteristic polynomial
    // lambda^2 - trace * lambda + det = 0.
    let trace = a + d;
    let det = a * d - b * c;

    let discriminant = trace * trace - 4.0 * det;
    let sqrt_disc = discriminant.sqrt();

    let lambda1 = (trace + sqrt_disc) / 2.0;
    let lambda2 = (trace - sqrt_disc) / 2.0;

    // Eigenvectors from (A - lambda * I) v = 0, stored as columns.
    let mut eigenvectors = Array2::zeros((2, 2));

    if (a - lambda1).abs() > 1e-10 || b.abs() > 1e-10 {
        let v1_x = b;
        let v1_y = lambda1 - a;
        let norm1 = (v1_x * v1_x + v1_y * v1_y).sqrt();
        eigenvectors[[0, 0]] = v1_x / norm1;
        eigenvectors[[1, 0]] = v1_y / norm1;
    } else {
        eigenvectors[[0, 0]] = 1.0;
        eigenvectors[[1, 0]] = 0.0;
    }

    if (a - lambda2).abs() > 1e-10 || b.abs() > 1e-10 {
        let v2_x = b;
        let v2_y = lambda2 - a;
        let norm2 = (v2_x * v2_x + v2_y * v2_y).sqrt();
        eigenvectors[[0, 1]] = v2_x / norm2;
        eigenvectors[[1, 1]] = v2_y / norm2;
    } else {
        eigenvectors[[0, 1]] = 0.0;
        eigenvectors[[1, 1]] = 1.0;
    }

    (array![lambda1, lambda2], eigenvectors)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_total_least_squares_simple() {
        // Noise-free line y = 1.5 * x + 0.5 with small errors added to both axes.
        let true_slope = 1.5;
        let true_intercept = 0.5;

        let x_true = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_true = &x_true * true_slope + true_intercept;

        let x_errors = array![0.1, -0.05, 0.08, -0.03, 0.06];
        let y_errors = array![-0.05, 0.1, -0.07, 0.04, -0.08];

        let x_measured = &x_true + &x_errors;
        let y_measured = &y_true + &y_errors;

        let result = total_least_squares(
            &x_measured,
            &y_measured,
            None::<&Array1<f64>>,
            None::<&Array1<f64>>,
            None,
        )
        .unwrap();

        assert!((result.slope - true_slope).abs() < 0.1);
        assert!((result.intercept - true_intercept).abs() < 0.1);
    }

    #[test]
    fn test_weighted_total_least_squares() {
        let x_measured = array![1.0, 2.1, 2.9, 4.2, 5.0];
        let y_measured = array![2.1, 3.9, 5.1, 6.8, 8.1];

        // The fourth point is noticeably less precise than the others.
        let x_variance = array![0.01, 0.01, 0.01, 0.1, 0.01];
        let y_variance = array![0.01, 0.02, 0.01, 0.1, 0.01];

        let result = total_least_squares(
            &x_measured,
            &y_measured,
            Some(&x_variance),
            Some(&y_variance),
            None,
        )
        .unwrap();

        assert!(result.converged);
        println!(
            "Weighted TLS: slope = {:.3}, intercept = {:.3}",
            result.slope, result.intercept
        );
    }

    #[test]
    fn test_iterative_vs_svd() {
        let x_measured = array![0.5, 1.5, 2.8, 3.7, 4.9];
        let y_measured = array![1.2, 2.7, 4.1, 5.3, 6.8];

        let options_svd = TotalLeastSquaresOptions {
            method: TLSMethod::SVD,
            ..Default::default()
        };

        let options_iter = TotalLeastSquaresOptions {
            method: TLSMethod::Iterative,
            ..Default::default()
        };

        let result_svd = total_least_squares(
            &x_measured,
            &y_measured,
            None::<&Array1<f64>>,
            None::<&Array1<f64>>,
            Some(options_svd),
        )
        .unwrap();

        let result_iter = total_least_squares(
            &x_measured,
            &y_measured,
            None::<&Array1<f64>>,
            None::<&Array1<f64>>,
            Some(options_iter),
        )
        .unwrap();

        // Both solvers should agree on this well-conditioned data.
        assert!((result_svd.slope - result_iter.slope).abs() < 0.01);
        assert!((result_svd.intercept - result_iter.intercept).abs() < 0.01);
    }
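
    // Additional sanity checks for the internal helpers; a small sketch using
    // hand-computed values rather than reference outputs from another library.
    #[test]
    fn test_orthogonal_projection_is_perpendicular() {
        // Projecting (0, 2) onto the line y = x should land on (1, 1), and the
        // residual vector should be perpendicular to the line direction (1, slope).
        let (x_proj, y_proj) = orthogonal_projection(0.0, 2.0, 1.0, 0.0);
        assert!((x_proj - 1.0).abs() < 1e-12);
        assert!((y_proj - 1.0).abs() < 1e-12);

        let dot = (0.0 - x_proj) * 1.0 + (2.0 - y_proj) * 1.0;
        assert!(dot.abs() < 1e-12);
    }

    #[test]
    fn test_eigen_2x2_symmetric() {
        // [[2, 1], [1, 2]] has eigenvalues 3 and 1 with eigenvectors along
        // (1, 1) and (1, -1).
        let m = array![[2.0, 1.0], [1.0, 2.0]];
        let (eigenvalues, eigenvectors) = eigen_2x2(&m);

        assert!((eigenvalues[0] - 3.0).abs() < 1e-12);
        assert!((eigenvalues[1] - 1.0).abs() < 1e-12);

        // Check A * v = lambda * v entrywise for each returned column.
        for col in 0..2 {
            for row in 0..2 {
                let av = m[[row, 0]] * eigenvectors[[0, col]]
                    + m[[row, 1]] * eigenvectors[[1, col]];
                assert!((av - eigenvalues[col] * eigenvectors[[row, col]]).abs() < 1e-12);
            }
        }
    }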
}