// scirs2_optimize/least_squares/total.rs

//! Total least squares (errors-in-variables)
//!
//! This module implements total least squares for problems where both
//! the independent and dependent variables have measurement errors.
//!
//! # Example
//!
//! ```
//! use scirs2_core::ndarray::{array, Array1, Array2};
//! use scirs2_optimize::least_squares::total::{total_least_squares, TotalLeastSquaresOptions};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Data points with errors in both x and y
//! let x_measured = array![1.0, 2.1, 2.9, 4.2, 5.0];
//! let y_measured = array![2.1, 3.9, 5.1, 6.8, 8.1];
//!
//! // Known or estimated error variances
//! let x_variance = array![0.1, 0.1, 0.1, 0.2, 0.1];
//! let y_variance = array![0.1, 0.15, 0.1, 0.2, 0.1];
//!
//! let result = total_least_squares(
//!     &x_measured,
//!     &y_measured,
//!     Some(&x_variance),
//!     Some(&y_variance),
//!     None
//! )?;
//!
//! println!("Slope: {:.3}", result.slope);
//! println!("Intercept: {:.3}", result.intercept);
//! # Ok(())
//! # }
//! ```

use crate::error::OptimizeResult;
use scirs2_core::ndarray::{array, s, Array1, Array2, ArrayBase, Data, Ix1};
use statrs::statistics::Statistics;

/// Options for total least squares
#[derive(Debug, Clone)]
pub struct TotalLeastSquaresOptions {
    /// Maximum number of iterations for iterative methods
    /// (not used by the direct SVD method)
    pub max_iter: usize,

    /// Convergence tolerance: iteration stops when both the relative change
    /// in the weighted residual and the absolute changes in slope and
    /// intercept fall below this value
    pub tol: f64,

    /// Method to use
    pub method: TLSMethod,

    /// Whether to use weighted TLS when variances are provided
    pub use_weights: bool,
}

/// Methods for total least squares
#[derive(Debug, Clone, Copy)]
pub enum TLSMethod {
    /// Singular Value Decomposition (most stable); a direct, non-iterative solve
    SVD,
    /// Iterative orthogonal regression (alternating projection / refit)
    Iterative,
    /// Maximum likelihood estimation
    MaximumLikelihood,
}

66impl Default for TotalLeastSquaresOptions {
67    fn default() -> Self {
68        TotalLeastSquaresOptions {
69            max_iter: 100,
70            tol: 1e-8,
71            method: TLSMethod::SVD,
72            use_weights: true,
73        }
74    }
75}
76
/// Result structure for total least squares
#[derive(Debug, Clone)]
pub struct TotalLeastSquaresResult {
    /// Estimated slope
    pub slope: f64,

    /// Estimated intercept
    pub intercept: f64,

    /// Corrected x values (projections of the measurements onto the fitted line)
    pub x_corrected: Array1<f64>,

    /// Corrected y values (projections of the measurements onto the fitted line)
    pub y_corrected: Array1<f64>,

    /// Sum of squared orthogonal distances from the measurements to the line
    pub orthogonal_residuals: f64,

    /// Number of iterations (for iterative methods; 1 for the direct SVD solve)
    pub nit: usize,

    /// Convergence status (always true for the direct SVD solve)
    pub converged: bool,
}

102/// Solve a total least squares problem
103///
104/// This function fits a line to data with errors in both variables.
105/// It minimizes the sum of squared orthogonal distances to the line.
106///
107/// # Arguments
108///
109/// * `x_measured` - Measured x values
110/// * `y_measured` - Measured y values
111/// * `x_variance` - Optional variance estimates for x measurements
112/// * `y_variance` - Optional variance estimates for y measurements
113/// * `options` - Options for the optimization
114#[allow(dead_code)]
115pub fn total_least_squares<S1, S2, S3, S4>(
116    x_measured: &ArrayBase<S1, Ix1>,
117    y_measured: &ArrayBase<S2, Ix1>,
118    x_variance: Option<&ArrayBase<S3, Ix1>>,
119    y_variance: Option<&ArrayBase<S4, Ix1>>,
120    options: Option<TotalLeastSquaresOptions>,
121) -> OptimizeResult<TotalLeastSquaresResult>
122where
123    S1: Data<Elem = f64>,
124    S2: Data<Elem = f64>,
125    S3: Data<Elem = f64>,
126    S4: Data<Elem = f64>,
127{
128    let options = options.unwrap_or_default();
129    let n = x_measured.len();
130
131    if y_measured.len() != n {
132        return Err(crate::error::OptimizeError::ValueError(
133            "x_measured and y_measured must have the same length".to_string(),
134        ));
135    }
136
137    // Check _variance arrays if provided
138    if let Some(x_var) = x_variance {
139        if x_var.len() != n {
140            return Err(crate::error::OptimizeError::ValueError(
141                "x_variance must have the same length as data".to_string(),
142            ));
143        }
144    }
145
146    if let Some(y_var) = y_variance {
147        if y_var.len() != n {
148            return Err(crate::error::OptimizeError::ValueError(
149                "y_variance must have the same length as data".to_string(),
150            ));
151        }
152    }
153
154    match options.method {
155        TLSMethod::SVD => tls_svd(x_measured, y_measured, x_variance, y_variance, &options),
156        TLSMethod::Iterative => {
157            tls_iterative(x_measured, y_measured, x_variance, y_variance, &options)
158        }
159        TLSMethod::MaximumLikelihood => {
160            tls_maximum_likelihood(x_measured, y_measured, x_variance, y_variance, &options)
161        }
162    }
163}
164
/// Total least squares using SVD
///
/// Fits `y = slope * x + intercept` by finding the direction of minimum
/// variance of the centered data: the eigenvector of the 2x2 covariance
/// matrix with the smallest eigenvalue is the normal of the best-fit line.
/// This is a direct (non-iterative) solve, so the result always reports
/// `nit = 1` and `converged = true`.
///
/// # Errors
/// Returns `ValueError` when the fitted line is nearly vertical, because it
/// cannot be represented in slope-intercept form.
#[allow(dead_code)]
fn tls_svd<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: &TotalLeastSquaresOptions,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    let n = x_measured.len();

    // Center the data at the (unweighted) means; the fitted line will pass
    // through this centroid.
    let x_mean = x_measured.mean().unwrap();
    let y_mean = y_measured.mean().unwrap();

    let x_centered = x_measured - x_mean;
    let y_centered = y_measured - y_mean;

    // Apply weights if variances are provided (both must be present).
    let (x_scaled, y_scaled) =
        if options.use_weights && x_variance.is_some() && y_variance.is_some() {
            let x_var = x_variance.unwrap();
            let y_var = y_variance.unwrap();

            // Scale by inverse standard deviation
            let x_weights = x_var.mapv(|v| 1.0 / v.sqrt());
            let y_weights = y_var.mapv(|v| 1.0 / v.sqrt());

            // NOTE(review): the line normal is computed in these per-point
            // scaled coordinates but applied to the unscaled means below;
            // with non-uniform weights this is an approximation — confirm
            // that this is intended.
            (x_centered * &x_weights, y_centered * &y_weights)
        } else {
            (x_centered.clone(), y_centered.clone())
        };

    // Form the augmented matrix [x_scaled, y_scaled]
    let mut data_matrix = Array2::zeros((n, 2));
    for i in 0..n {
        data_matrix[[i, 0]] = x_scaled[i];
        data_matrix[[i, 1]] = y_scaled[i];
    }

    // Perform SVD (simplified - in practice use a proper SVD)
    // For now, use eigendecomposition of the covariance matrix
    let cov_matrix = data_matrix.t().dot(&data_matrix) / n as f64;

    // Find eigenvalues and eigenvectors
    let (eigenvalues, eigenvectors) = eigen_2x2(&cov_matrix);

    // The eigenvector corresponding to the smallest eigenvalue gives the normal to the line
    let min_idx = if eigenvalues[0] < eigenvalues[1] {
        0
    } else {
        1
    };
    let normal = eigenvectors.slice(s![.., min_idx]).to_owned();

    // Convert normal to slope-intercept form
    // Normal vector (a, b) corresponds to line ax + by + c = 0
    let a = normal[0usize];
    let b = normal[1usize];

    if b.abs() < 1e-10 {
        // Nearly vertical line: slope = -a/b would blow up, so refuse to fit.
        return Err(crate::error::OptimizeError::ValueError(
            "Nearly vertical line detected".to_string(),
        ));
    }

    let slope = -a / b;
    // The line passes through the centroid, which fixes the intercept.
    let intercept = y_mean - slope * x_mean;

    // Compute corrected values (orthogonal projection onto the line)
    let mut x_corrected = Array1::zeros(n);
    let mut y_corrected = Array1::zeros(n);
    let mut total_residual = 0.0;

    for i in 0..n {
        let (x_proj, y_proj) =
            orthogonal_projection(x_measured[i], y_measured[i], slope, intercept);
        x_corrected[i] = x_proj;
        y_corrected[i] = y_proj;

        // Accumulate the squared Euclidean distance from each measurement to
        // its projection (the orthogonal residual).
        let dx = x_measured[i] - x_proj;
        let dy = y_measured[i] - y_proj;
        total_residual += dx * dx + dy * dy;
    }

    Ok(TotalLeastSquaresResult {
        slope,
        intercept,
        x_corrected,
        y_corrected,
        orthogonal_residuals: total_residual,
        nit: 1,
        converged: true,
    })
}

/// Iterative total least squares
///
/// Alternating (EM-style) scheme: starting from an ordinary least squares
/// fit, repeatedly (E) project each measured point onto the current line
/// under inverse-variance weights, then (M) refit the line to the projected
/// points, until both the weighted residual and the line parameters
/// stabilize, or `options.max_iter` is reached.
#[allow(dead_code)]
fn tls_iterative<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: &TotalLeastSquaresOptions,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    let n = x_measured.len();

    // Initial estimate using ordinary least squares
    let (mut slope, mut intercept) = ordinary_least_squares(x_measured, y_measured);

    let mut x_corrected = x_measured.to_owned();
    let mut y_corrected = y_measured.to_owned();
    // Infinity guarantees the convergence test cannot pass on the first pass.
    let mut prev_residual = f64::INFINITY;

    // Inverse-variance weights; unit weights when no variances are given.
    let x_weights = if let Some(x_var) = x_variance {
        x_var.mapv(|v| 1.0 / v)
    } else {
        Array1::ones(n)
    };

    let y_weights = if let Some(y_var) = y_variance {
        y_var.mapv(|v| 1.0 / v)
    } else {
        Array1::ones(n)
    };

    let mut iter = 0;
    let mut converged = false;

    while iter < options.max_iter {
        // E-step: Update corrected values given current line parameters
        let mut total_residual = 0.0;

        for i in 0..n {
            let (x_proj, y_proj) = weighted_orthogonal_projection(
                x_measured[i],
                y_measured[i],
                slope,
                intercept,
                x_weights[i],
                y_weights[i],
            );

            x_corrected[i] = x_proj;
            y_corrected[i] = y_proj;

            // Weighted squared distance from the measurement to its projection.
            let dx = x_measured[i] - x_proj;
            let dy = y_measured[i] - y_proj;
            total_residual += x_weights[i] * dx * dx + y_weights[i] * dy * dy;
        }

        // M-step: Update line parameters given corrected values
        let (new_slope, new_intercept) =
            weighted_least_squares_line(&x_corrected, &y_corrected, &x_weights, &y_weights);

        // Check convergence: relative residual change AND absolute parameter
        // changes must all fall below tol. The test runs before the new
        // parameters are accepted, so on convergence the previous
        // (within-tolerance) pair is the one reported.
        if (total_residual - prev_residual).abs() < options.tol * total_residual
            && (new_slope - slope).abs() < options.tol
            && (new_intercept - intercept).abs() < options.tol
        {
            // NOTE(review): `iter` is not incremented on the converging pass,
            // so `nit` undercounts the executed passes by one — confirm this
            // is the intended counting convention.
            converged = true;
            break;
        }

        slope = new_slope;
        intercept = new_intercept;
        prev_residual = total_residual;
        iter += 1;
    }

    // Compute final orthogonal residuals (unweighted, for reporting).
    let mut orthogonal_residuals = 0.0;
    for i in 0..n {
        let dx = x_measured[i] - x_corrected[i];
        let dy = y_measured[i] - y_corrected[i];
        orthogonal_residuals += dx * dx + dy * dy;
    }

    Ok(TotalLeastSquaresResult {
        slope,
        intercept,
        x_corrected,
        y_corrected,
        orthogonal_residuals,
        nit: iter,
        converged,
    })
}

/// Maximum likelihood total least squares
///
/// Placeholder: currently delegates to the iterative solver. A dedicated
/// implementation would maximize the likelihood function directly.
#[allow(dead_code)]
fn tls_maximum_likelihood<S1, S2, S3, S4>(
    x_measured: &ArrayBase<S1, Ix1>,
    y_measured: &ArrayBase<S2, Ix1>,
    x_variance: Option<&ArrayBase<S3, Ix1>>,
    y_variance: Option<&ArrayBase<S4, Ix1>>,
    options: &TotalLeastSquaresOptions,
) -> OptimizeResult<TotalLeastSquaresResult>
where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
{
    // For now, use the iterative method
    // A proper implementation would maximize the likelihood function
    tls_iterative(x_measured, y_measured, x_variance, y_variance, options)
}

388/// Compute ordinary least squares for initial estimate
389#[allow(dead_code)]
390fn ordinary_least_squares<S1, S2>(x: &ArrayBase<S1, Ix1>, y: &ArrayBase<S2, Ix1>) -> (f64, f64)
391where
392    S1: Data<Elem = f64>,
393    S2: Data<Elem = f64>,
394{
395    let _n = x.len() as f64;
396    let x_mean = x.mean().unwrap();
397    let y_mean = y.mean().unwrap();
398
399    let mut num = 0.0;
400    let mut den = 0.0;
401
402    for i in 0..x.len() {
403        let dx = x[i] - x_mean;
404        let dy = y[i] - y_mean;
405        num += dx * dy;
406        den += dx * dx;
407    }
408
409    let slope = num / den;
410    let intercept = y_mean - slope * x_mean;
411
412    (slope, intercept)
413}
414
/// Orthogonally project the point `(x, y)` onto the line `y = slope * x + intercept`.
///
/// Returns the foot of the perpendicular. Derived by parameterizing the line
/// as `(t, slope * t + intercept)` and minimizing the squared distance from
/// the point over `t`.
#[allow(dead_code)]
fn orthogonal_projection(x: f64, y: f64, slope: f64, intercept: f64) -> (f64, f64) {
    // Setting d/dt [(x - t)^2 + (y - slope*t - intercept)^2] = 0 yields:
    let t = ((y - intercept) * slope + x) / (slope * slope + 1.0);
    (t, slope * t + intercept)
}

/// Project `(x, y)` onto the line `y = slope * x + intercept`, minimizing
/// the weighted squared distance
/// `weight_x * (x - x_proj)^2 + weight_y * (y - y_proj)^2`.
///
/// With equal weights this reduces to the ordinary orthogonal projection.
/// The closed form follows from setting the derivative with respect to
/// `x_proj` to zero under the constraint `y_proj = slope * x_proj + intercept`.
#[allow(dead_code)]
fn weighted_orthogonal_projection(
    x: f64,
    y: f64,
    slope: f64,
    intercept: f64,
    weight_x: f64,
    weight_y: f64,
) -> (f64, f64) {
    // Normal equation of the 1-D quadratic in x_proj: a * x_proj = c.
    let a = weight_x + weight_y * slope * slope;
    let c = weight_x * x + weight_y * slope * (y - intercept);

    let x_proj = c / a;
    let y_proj = slope * x_proj + intercept;

    (x_proj, y_proj)
}

454/// Weighted least squares for a line
455#[allow(dead_code)]
456fn weighted_least_squares_line<S1, S2, S3, S4>(
457    x: &ArrayBase<S1, Ix1>,
458    y: &ArrayBase<S2, Ix1>,
459    weight_x: &ArrayBase<S3, Ix1>,
460    weight_y: &ArrayBase<S4, Ix1>,
461) -> (f64, f64)
462where
463    S1: Data<Elem = f64>,
464    S2: Data<Elem = f64>,
465    S3: Data<Elem = f64>,
466    S4: Data<Elem = f64>,
467{
468    let n = x.len();
469    let mut sum_wx = 0.0;
470    let mut sum_wy = 0.0;
471    let mut sum_wxx = 0.0;
472    let mut sum_wxy = 0.0;
473    let mut _sum_wyy = 0.0;
474    let mut sum_w = 0.0;
475
476    for i in 0..n {
477        let w = (weight_x[i] + weight_y[i]) / 2.0; // Combined weight
478        sum_w += w;
479        sum_wx += w * x[i];
480        sum_wy += w * y[i];
481        sum_wxx += w * x[i] * x[i];
482        sum_wxy += w * x[i] * y[i];
483        _sum_wyy += w * y[i] * y[i];
484    }
485
486    let x_mean = sum_wx / sum_w;
487    let y_mean = sum_wy / sum_w;
488
489    let cov_xx = sum_wxx / sum_w - x_mean * x_mean;
490    let cov_xy = sum_wxy / sum_w - x_mean * y_mean;
491
492    let slope = cov_xy / cov_xx;
493    let intercept = y_mean - slope * x_mean;
494
495    (slope, intercept)
496}
497
498/// Simple 2x2 eigendecomposition
499#[allow(dead_code)]
500fn eigen_2x2(matrix: &Array2<f64>) -> (Array1<f64>, Array2<f64>) {
501    let a = matrix[[0, 0]];
502    let b = matrix[[0, 1]];
503    let c = matrix[[1, 0]];
504    let d = matrix[[1, 1]];
505
506    // Characteristic equation: λ² - (a+d)λ + (ad-bc) = 0
507    let trace = a + d;
508    let det = a * d - b * c;
509
510    let discriminant = trace * trace - 4.0 * det;
511    let sqrt_disc = discriminant.sqrt();
512
513    let lambda1 = (trace + sqrt_disc) / 2.0;
514    let lambda2 = (trace - sqrt_disc) / 2.0;
515
516    // Eigenvectors
517    let mut eigenvectors = Array2::zeros((2, 2));
518
519    // For λ1
520    if (a - lambda1).abs() > 1e-10 || b.abs() > 1e-10 {
521        let v1_x = b;
522        let v1_y = lambda1 - a;
523        let norm1 = (v1_x * v1_x + v1_y * v1_y).sqrt();
524        eigenvectors[[0, 0]] = v1_x / norm1;
525        eigenvectors[[1, 0]] = v1_y / norm1;
526    } else {
527        eigenvectors[[0, 0]] = 1.0;
528        eigenvectors[[1, 0]] = 0.0;
529    }
530
531    // For λ2
532    if (a - lambda2).abs() > 1e-10 || b.abs() > 1e-10 {
533        let v2_x = b;
534        let v2_y = lambda2 - a;
535        let norm2 = (v2_x * v2_x + v2_y * v2_y).sqrt();
536        eigenvectors[[0, 1]] = v2_x / norm2;
537        eigenvectors[[1, 1]] = v2_y / norm2;
538    } else {
539        eigenvectors[[0, 1]] = 0.0;
540        eigenvectors[[1, 1]] = 1.0;
541    }
542
543    (array![lambda1, lambda2], eigenvectors)
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_total_least_squares_simple() {
        // Fit a line to data perturbed in both coordinates and check that the
        // recovered parameters stay close to the generating ones.
        let true_slope = 1.5;
        let true_intercept = 0.5;

        let x_true = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_true = &x_true * true_slope + true_intercept;

        // Small perturbations in both variables.
        let x_errors = array![0.1, -0.05, 0.08, -0.03, 0.06];
        let y_errors = array![-0.05, 0.1, -0.07, 0.04, -0.08];

        let x_measured = &x_true + &x_errors;
        let y_measured = &y_true + &y_errors;

        let result = total_least_squares(
            &x_measured,
            &y_measured,
            None::<&Array1<f64>>,
            None::<&Array1<f64>>,
            None,
        )
        .unwrap();

        // Check that the estimated parameters are close to true values
        assert!((result.slope - true_slope).abs() < 0.1);
        assert!((result.intercept - true_intercept).abs() < 0.1);
    }

    #[test]
    fn test_weighted_total_least_squares() {
        // Data with different error variances
        let x_measured = array![1.0, 2.1, 2.9, 4.2, 5.0];
        let y_measured = array![2.1, 3.9, 5.1, 6.8, 8.1];

        // Known error variances (larger for some points)
        let x_variance = array![0.01, 0.01, 0.01, 0.1, 0.01];
        let y_variance = array![0.01, 0.02, 0.01, 0.1, 0.01];

        let result = total_least_squares(
            &x_measured,
            &y_measured,
            Some(&x_variance),
            Some(&y_variance),
            None,
        )
        .unwrap();

        // The point with large variance should have less influence
        assert!(result.converged);
        println!(
            "Weighted TLS: slope = {:.3}, intercept = {:.3}",
            result.slope, result.intercept
        );
    }

    #[test]
    fn test_iterative_vs_svd() {
        // Compare iterative and SVD methods on the same data.
        let x_measured = array![0.5, 1.5, 2.8, 3.7, 4.9];
        let y_measured = array![1.2, 2.7, 4.1, 5.3, 6.8];

        // Struct-update syntax instead of mutating a Default-constructed
        // value (clippy::field_reassign_with_default).
        let options_svd = TotalLeastSquaresOptions {
            method: TLSMethod::SVD,
            ..Default::default()
        };
        let options_iter = TotalLeastSquaresOptions {
            method: TLSMethod::Iterative,
            ..Default::default()
        };

        // The storage type parameters are fully determined by the argument
        // types (arrays plus the annotated `None`s), so no turbofish is needed.
        let result_svd = total_least_squares(
            &x_measured,
            &y_measured,
            None::<&Array1<f64>>,
            None::<&Array1<f64>>,
            Some(options_svd),
        )
        .unwrap();

        let result_iter = total_least_squares(
            &x_measured,
            &y_measured,
            None::<&Array1<f64>>,
            None::<&Array1<f64>>,
            Some(options_iter),
        )
        .unwrap();

        // Results should be similar
        assert!((result_svd.slope - result_iter.slope).abs() < 0.01);
        assert!((result_svd.intercept - result_iter.intercept).abs() < 0.01);
    }
}