scirs2_optimize/least_squares/total.rs

//! Total least squares (errors-in-variables)
//!
//! This module implements total least squares for problems where both
//! the independent and dependent variables have measurement errors.
//!
//! # Example
//!
//! ```
//! use scirs2_core::ndarray::{array, Array1, Array2};
//! use scirs2_optimize::least_squares::total::{total_least_squares, TotalLeastSquaresOptions};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Data points with errors in both x and y
//! let x_measured = array![1.0, 2.1, 2.9, 4.2, 5.0];
//! let y_measured = array![2.1, 3.9, 5.1, 6.8, 8.1];
//!
//! // Known or estimated error variances
//! let x_variance = array![0.1, 0.1, 0.1, 0.2, 0.1];
//! let y_variance = array![0.1, 0.15, 0.1, 0.2, 0.1];
//!
//! let result = total_least_squares(
//!     &x_measured,
//!     &y_measured,
//!     Some(&x_variance),
//!     Some(&y_variance),
//!     None
//! )?;
//!
//! println!("Slope: {:.3}", result.slope);
//! println!("Intercept: {:.3}", result.intercept);
//! # Ok(())
//! # }
//! ```

use crate::error::OptimizeResult;
use scirs2_core::ndarray::{array, s, Array1, Array2, ArrayBase, ArrayStatCompat, Data, Ix1};
use statrs::statistics::Statistics;

/// Options for total least squares
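/// # Example
///
/// A minimal sketch of overriding individual fields while keeping the
/// remaining defaults:
///
/// ```
/// use scirs2_optimize::least_squares::total::{TLSMethod, TotalLeastSquaresOptions};
///
/// let options = TotalLeastSquaresOptions {
///     method: TLSMethod::Iterative,
///     max_iter: 500,
///     ..Default::default()
/// };
/// assert!(options.use_weights);
/// ```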
#[derive(Debug, Clone)]
pub struct TotalLeastSquaresOptions {
    /// Maximum number of iterations for iterative methods
    pub max_iter: usize,

    /// Convergence tolerance
    pub tol: f64,

    /// Method to use
    pub method: TLSMethod,

    /// Whether to use weighted TLS when variances are provided
    pub use_weights: bool,
}

/// Methods for total least squares
#[derive(Debug, Clone, Copy)]
pub enum TLSMethod {
    /// Singular Value Decomposition (most stable)
    SVD,
    /// Iterative orthogonal regression
    Iterative,
    /// Maximum likelihood estimation
    MaximumLikelihood,
}

impl Default for TotalLeastSquaresOptions {
    fn default() -> Self {
        TotalLeastSquaresOptions {
            max_iter: 100,
            tol: 1e-8,
            method: TLSMethod::SVD,
            use_weights: true,
        }
    }
}

/// Result structure for total least squares
#[derive(Debug, Clone)]
pub struct TotalLeastSquaresResult {
    /// Estimated slope
    pub slope: f64,

    /// Estimated intercept
    pub intercept: f64,

    /// Corrected x values
    pub x_corrected: Array1<f64>,

    /// Corrected y values
    pub y_corrected: Array1<f64>,

    /// Sum of squared orthogonal distances
    pub orthogonal_residuals: f64,

    /// Number of iterations (for iterative methods)
    pub nit: usize,

    /// Convergence status
    pub converged: bool,
}

/// Solve a total least squares problem
///
/// This function fits a line to data with errors in both variables.
/// It minimizes the sum of squared orthogonal distances to the line.
///
/// # Arguments
///
/// * `x_measured` - Measured x values
/// * `y_measured` - Measured y values
/// * `x_variance` - Optional variance estimates for x measurements
/// * `y_variance` - Optional variance estimates for y measurements
/// * `options` - Options for the optimization
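///
/// # Errors
///
/// Returns an error if `y_measured` or a provided variance array does not match
/// the length of `x_measured`, or if the fitted line is nearly vertical and has
/// no finite slope-intercept representation.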
#[allow(dead_code)]
pub fn total_least_squares<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: Option<TotalLeastSquaresOptions>,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    let options = options.unwrap_or_default();
    let n = x_measured.len();

    if y_measured.len() != n {
        return Err(crate::error::OptimizeError::ValueError(
            "x_measured and y_measured must have the same length".to_string(),
        ));
    }

    // Check variance arrays if provided
    if let Some(x_var) = x_variance {
        if x_var.len() != n {
            return Err(crate::error::OptimizeError::ValueError(
                "x_variance must have the same length as data".to_string(),
            ));
        }
    }

    if let Some(y_var) = y_variance {
        if y_var.len() != n {
            return Err(crate::error::OptimizeError::ValueError(
                "y_variance must have the same length as data".to_string(),
            ));
        }
    }

    match options.method {
        TLSMethod::SVD => tls_svd(x_measured, y_measured, x_variance, y_variance, &options),
        TLSMethod::Iterative => {
            tls_iterative(x_measured, y_measured, x_variance, y_variance, &options)
        }
        TLSMethod::MaximumLikelihood => {
            tls_maximum_likelihood(x_measured, y_measured, x_variance, y_variance, &options)
        }
    }
}

/// Total least squares using SVD (implemented via an eigendecomposition of
/// the 2x2 covariance matrix)
#[allow(dead_code)]
fn tls_svd<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: &TotalLeastSquaresOptions,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    let n = x_measured.len();

    // Center the data
    let x_mean = x_measured.mean_or(0.0);
    let y_mean = y_measured.mean_or(0.0);

    let x_centered = x_measured - x_mean;
    let y_centered = y_measured - y_mean;

    // Apply weights if variances are provided
    let (x_scaled, y_scaled) =
        if options.use_weights && x_variance.is_some() && y_variance.is_some() {
            let x_var = x_variance.expect("x_variance is Some (checked above)");
            let y_var = y_variance.expect("y_variance is Some (checked above)");

            // Scale by inverse standard deviation
            let x_weights = x_var.mapv(|v| 1.0 / v.sqrt());
            let y_weights = y_var.mapv(|v| 1.0 / v.sqrt());

            (
                (&x_centered * &x_weights).to_owned(),
                (&y_centered * &y_weights).to_owned(),
            )
        } else {
            (x_centered.to_owned(), y_centered.to_owned())
        };

    // Form the augmented matrix [x_scaled, y_scaled]
    let mut data_matrix = Array2::zeros((n, 2));
    for i in 0..n {
        data_matrix[[i, 0]] = x_scaled[i];
        data_matrix[[i, 1]] = y_scaled[i];
    }

    // Instead of a full SVD, use an eigendecomposition of the 2x2 covariance
    // matrix; its eigenvectors coincide with the right singular vectors of the
    // data matrix.
    let cov_matrix = data_matrix.t().dot(&data_matrix) / n as f64;

    // Find eigenvalues and eigenvectors
    let (eigenvalues, eigenvectors) = eigen_2x2(&cov_matrix);

    // The eigenvector corresponding to the smallest eigenvalue gives the normal to the line
    let min_idx = if eigenvalues[0] < eigenvalues[1] {
        0
    } else {
        1
    };
    let normal = eigenvectors.slice(s![.., min_idx]).to_owned();

    // Convert normal to slope-intercept form
    // Normal vector (a, b) corresponds to line ax + by + c = 0
    let a = normal[0usize];
    let b = normal[1usize];

    if b.abs() < 1e-10 {
        // Nearly vertical line
        return Err(crate::error::OptimizeError::ValueError(
            "Nearly vertical line detected".to_string(),
        ));
    }

    let slope = -a / b;
    let intercept = y_mean - slope * x_mean;

    // Compute corrected values (orthogonal projection onto the line)
    let mut x_corrected = Array1::zeros(n);
    let mut y_corrected = Array1::zeros(n);
    let mut total_residual = 0.0;

    for i in 0..n {
        let (x_proj, y_proj) =
            orthogonal_projection(x_measured[i], y_measured[i], slope, intercept);
        x_corrected[i] = x_proj;
        y_corrected[i] = y_proj;

        let dx = x_measured[i] - x_proj;
        let dy = y_measured[i] - y_proj;
        total_residual += dx * dx + dy * dy;
    }

    Ok(TotalLeastSquaresResult {
        slope,
        intercept,
        x_corrected,
        y_corrected,
        orthogonal_residuals: total_residual,
        nit: 1,
        converged: true,
    })
}

/// Iterative total least squares
#[allow(dead_code)]
fn tls_iterative<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: &TotalLeastSquaresOptions,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    let n = x_measured.len();

    // Initial estimate using ordinary least squares
    let (mut slope, mut intercept) = ordinary_least_squares(x_measured, y_measured);

    let mut x_corrected = x_measured.to_owned();
    let mut y_corrected = y_measured.to_owned();
    let mut prev_residual = f64::INFINITY;

    // Get weights from variances
    let x_weights = if let Some(x_var) = x_variance {
        x_var.mapv(|v| 1.0 / v)
    } else {
        Array1::ones(n)
    };

    let y_weights = if let Some(y_var) = y_variance {
        y_var.mapv(|v| 1.0 / v)
    } else {
        Array1::ones(n)
    };

    let mut iter = 0;
    let mut converged = false;

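    // Alternate between projecting each measured point onto the current line
    // (weighted E-step) and refitting the line to the projected points (M-step)
    // until the residual and parameters stop changing.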
    while iter < options.max_iter {
        // E-step: Update corrected values given current line parameters
        let mut total_residual = 0.0;

        for i in 0..n {
            let (x_proj, y_proj) = weighted_orthogonal_projection(
                x_measured[i],
                y_measured[i],
                slope,
                intercept,
                x_weights[i],
                y_weights[i],
            );

            x_corrected[i] = x_proj;
            y_corrected[i] = y_proj;

            let dx = x_measured[i] - x_proj;
            let dy = y_measured[i] - y_proj;
            total_residual += x_weights[i] * dx * dx + y_weights[i] * dy * dy;
        }

        // M-step: Update line parameters given corrected values
        let (new_slope, new_intercept) =
            weighted_least_squares_line(&x_corrected, &y_corrected, &x_weights, &y_weights);

        // Check convergence
        if (total_residual - prev_residual).abs() < options.tol * total_residual
            && (new_slope - slope).abs() < options.tol
            && (new_intercept - intercept).abs() < options.tol
        {
            converged = true;
            break;
        }

        slope = new_slope;
        intercept = new_intercept;
        prev_residual = total_residual;
        iter += 1;
    }

    // Compute final orthogonal residuals
    let mut orthogonal_residuals = 0.0;
    for i in 0..n {
        let dx = x_measured[i] - x_corrected[i];
        let dy = y_measured[i] - y_corrected[i];
        orthogonal_residuals += dx * dx + dy * dy;
    }

    Ok(TotalLeastSquaresResult {
        slope,
        intercept,
        x_corrected,
        y_corrected,
        orthogonal_residuals,
        nit: iter,
        converged,
    })
}

/// Maximum likelihood total least squares (currently delegates to the
/// iterative method)
#[allow(dead_code)]
fn tls_maximum_likelihood<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: &TotalLeastSquaresOptions,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    // For now, delegate to the iterative method; a dedicated implementation
    // would maximize the full likelihood function directly.
    tls_iterative(x_measured, y_measured, x_variance, y_variance, options)
}

/// Compute ordinary least squares for initial estimate
#[allow(dead_code)]
fn ordinary_least_squares<S1, S2>(x: &ArrayBase<S1, Ix1>, y: &ArrayBase<S2, Ix1>) -> (f64, f64)
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
{
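    // Closed-form simple linear regression:
    //   slope = Σ (x_i - x̄)(y_i - ȳ) / Σ (x_i - x̄)²,  intercept = ȳ - slope · x̄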
    let x_mean = x.mean_or(0.0);
    let y_mean = y.mean_or(0.0);

    let mut num = 0.0;
    let mut den = 0.0;

    for i in 0..x.len() {
        let dx = x[i] - x_mean;
        let dy = y[i] - y_mean;
        num += dx * dy;
        den += dx * dx;
    }

    let slope = num / den;
    let intercept = y_mean - slope * x_mean;

    (slope, intercept)
}

/// Orthogonal projection of a point onto a line
#[allow(dead_code)]
fn orthogonal_projection(x: f64, y: f64, slope: f64, intercept: f64) -> (f64, f64) {
    // Line equation: y = slope * x + intercept
    // Normal vector: (slope, -1)
    // Normalized: (slope, -1) / sqrt(slope^2 + 1)
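    //
    // Derivation sketch: parameterize points on the line as (t, slope * t + intercept)
    // and minimize (x - t)^2 + (y - slope * t - intercept)^2 over t, which gives
    // t = (x + slope * (y - intercept)) / (slope^2 + 1).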

    let norm_sq = slope * slope + 1.0;
    let t = ((y - intercept) * slope + x) / norm_sq;

    let x_proj = t;
    let y_proj = slope * t + intercept;

    (x_proj, y_proj)
}

/// Weighted orthogonal projection
#[allow(dead_code)]
fn weighted_orthogonal_projection(
    x: f64,
    y: f64,
    slope: f64,
    intercept: f64,
    weight_x: f64,
    weight_y: f64,
) -> (f64, f64) {
    // Minimize: weight_x * (x - x_proj)^2 + weight_y * (y - y_proj)^2
    // Subject to: y_proj = slope * x_proj + intercept
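    //
    // Setting the derivative with respect to x_proj to zero gives
    //   (weight_x + weight_y * slope^2) * x_proj = weight_x * x + weight_y * slope * (y - intercept)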

    let a = weight_x + weight_y * slope * slope;
    let c = weight_x * x + weight_y * slope * (y - intercept);

    let x_proj = c / a;
    let y_proj = slope * x_proj + intercept;

    (x_proj, y_proj)
}

/// Weighted least squares for a line
#[allow(dead_code)]
fn weighted_least_squares_line<S1, S2, S3, S4>(
    x: &ArrayBase<S1, Ix1>,
    y: &ArrayBase<S2, Ix1>,
    weight_x: &ArrayBase<S3, Ix1>,
    weight_y: &ArrayBase<S4, Ix1>,
) -> (f64, f64)
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
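    // The per-point x and y weights are averaged into a single weight per
    // observation; the slope is then the weighted covariance of x and y divided
    // by the weighted variance of x.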
    let n = x.len();
    let mut sum_wx = 0.0;
    let mut sum_wy = 0.0;
    let mut sum_wxx = 0.0;
    let mut sum_wxy = 0.0;
    let mut sum_w = 0.0;

    for i in 0..n {
        let w = (weight_x[i] + weight_y[i]) / 2.0; // Combined weight
        sum_w += w;
        sum_wx += w * x[i];
        sum_wy += w * y[i];
        sum_wxx += w * x[i] * x[i];
        sum_wxy += w * x[i] * y[i];
    }

    let x_mean = sum_wx / sum_w;
    let y_mean = sum_wy / sum_w;

    let cov_xx = sum_wxx / sum_w - x_mean * x_mean;
    let cov_xy = sum_wxy / sum_w - x_mean * y_mean;

    let slope = cov_xy / cov_xx;
    let intercept = y_mean - slope * x_mean;

    (slope, intercept)
}

/// Simple 2x2 eigendecomposition; eigenvectors are returned as columns
#[allow(dead_code)]
fn eigen_2x2(matrix: &Array2<f64>) -> (Array1<f64>, Array2<f64>) {
    let a = matrix[[0, 0]];
    let b = matrix[[0, 1]];
    let c = matrix[[1, 0]];
    let d = matrix[[1, 1]];

    // Characteristic equation: λ² - (a+d)λ + (ad-bc) = 0
    let trace = a + d;
    let det = a * d - b * c;

    let discriminant = trace * trace - 4.0 * det;
    let sqrt_disc = discriminant.sqrt();

    let lambda1 = (trace + sqrt_disc) / 2.0;
    let lambda2 = (trace - sqrt_disc) / 2.0;

    // Eigenvectors
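    // From (A - λI) v = 0, the (unnormalized) eigenvector for eigenvalue λ is
    // (b, λ - a), unless that vector is numerically zero, in which case the
    // corresponding axis vector is used instead.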
    let mut eigenvectors = Array2::zeros((2, 2));

    // For λ1
    if (a - lambda1).abs() > 1e-10 || b.abs() > 1e-10 {
        let v1_x = b;
        let v1_y = lambda1 - a;
        let norm1 = (v1_x * v1_x + v1_y * v1_y).sqrt();
        eigenvectors[[0, 0]] = v1_x / norm1;
        eigenvectors[[1, 0]] = v1_y / norm1;
    } else {
        eigenvectors[[0, 0]] = 1.0;
        eigenvectors[[1, 0]] = 0.0;
    }

    // For λ2
    if (a - lambda2).abs() > 1e-10 || b.abs() > 1e-10 {
        let v2_x = b;
        let v2_y = lambda2 - a;
        let norm2 = (v2_x * v2_x + v2_y * v2_y).sqrt();
        eigenvectors[[0, 1]] = v2_x / norm2;
        eigenvectors[[1, 1]] = v2_y / norm2;
    } else {
        eigenvectors[[0, 1]] = 0.0;
        eigenvectors[[1, 1]] = 1.0;
    }

    (array![lambda1, lambda2], eigenvectors)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_total_least_squares_simple() {
        // Generate data with errors in both x and y
        let true_slope = 1.5;
        let true_intercept = 0.5;

        let x_true = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_true = &x_true * true_slope + true_intercept;

        // Add errors
        let x_errors = array![0.1, -0.05, 0.08, -0.03, 0.06];
        let y_errors = array![-0.05, 0.1, -0.07, 0.04, -0.08];

        let x_measured = &x_true + &x_errors;
        let y_measured = &y_true + &y_errors;

        let result = total_least_squares(
            &x_measured,
            &y_measured,
            None::<&Array1<f64>>,
            None::<&Array1<f64>>,
            None,
        )
        .expect("TLS fit should succeed");

        // Check that the estimated parameters are close to true values
        assert!((result.slope - true_slope).abs() < 0.1);
        assert!((result.intercept - true_intercept).abs() < 0.1);
    }

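    #[test]
    fn test_exact_line_recovery() {
        // Sanity check: for noise-free collinear data the orthogonal fit should
        // recover the generating line essentially exactly.
        let x_measured = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_measured = &x_measured * 2.0 + 1.0;

        let result = total_least_squares(
            &x_measured,
            &y_measured,
            None::<&Array1<f64>>,
            None::<&Array1<f64>>,
            None,
        )
        .expect("TLS on exact data should succeed");

        assert!((result.slope - 2.0).abs() < 1e-6);
        assert!((result.intercept - 1.0).abs() < 1e-6);
        assert!(result.orthogonal_residuals < 1e-8);
    }
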
    #[test]
    fn test_weighted_total_least_squares() {
        // Data with different error variances
        let x_measured = array![1.0, 2.1, 2.9, 4.2, 5.0];
        let y_measured = array![2.1, 3.9, 5.1, 6.8, 8.1];

        // Known error variances (larger for some points)
        let x_variance = array![0.01, 0.01, 0.01, 0.1, 0.01];
        let y_variance = array![0.01, 0.02, 0.01, 0.1, 0.01];

        let result = total_least_squares(
            &x_measured,
            &y_measured,
            Some(&x_variance),
            Some(&y_variance),
            None,
        )
        .expect("weighted TLS fit should succeed");

        // The point with large variance should have less influence
        assert!(result.converged);
        println!(
            "Weighted TLS: slope = {:.3}, intercept = {:.3}",
            result.slope, result.intercept
        );
    }

    #[test]
    fn test_iterative_vs_svd() {
        // Compare iterative and SVD methods
        let x_measured = array![0.5, 1.5, 2.8, 3.7, 4.9];
        let y_measured = array![1.2, 2.7, 4.1, 5.3, 6.8];

        let options_svd = TotalLeastSquaresOptions {
            method: TLSMethod::SVD,
            ..Default::default()
        };

        let options_iter = TotalLeastSquaresOptions {
            method: TLSMethod::Iterative,
            ..Default::default()
        };

        let result_svd = total_least_squares(
            &x_measured,
            &y_measured,
            None::<&Array1<f64>>,
            None::<&Array1<f64>>,
            Some(options_svd),
        )
        .expect("SVD TLS fit should succeed");

        let result_iter = total_least_squares(
            &x_measured,
            &y_measured,
            None::<&Array1<f64>>,
            None::<&Array1<f64>>,
            Some(options_iter),
        )
        .expect("iterative TLS fit should succeed");

        // Results should be similar
        assert!((result_svd.slope - result_iter.slope).abs() < 0.01);
        assert!((result_svd.intercept - result_iter.intercept).abs() < 0.01);
    }
}