scirs2_optimize/least_squares/weighted.rs

//! Weighted least squares optimization
//!
//! This module provides weighted least squares methods where each residual
//! can have a different weight, allowing for handling heteroscedastic data
//! (data with varying variance).
//!
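//! A common choice of weights for heteroscedastic data is inverse-variance
//! weighting, i.e. `weights[i] = 1.0 / variance[i]` when each observation's
//! variance is (approximately) known.
//!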
//! # Example
//!
//! ```
//! use scirs2_core::ndarray::{array, Array1, Array2};
//! use scirs2_optimize::least_squares::weighted::{weighted_least_squares, WeightedOptions};
//!
//! // Define a function that returns the residuals
//! fn residual(x: &[f64], data: &[f64]) -> Array1<f64> {
//!     let n = data.len() / 2;
//!     let t_values = &data[0..n];
//!     let y_values = &data[n..];
//!
//!     let mut res = Array1::zeros(n);
//!     for i in 0..n {
//!         res[i] = y_values[i] - (x[0] + x[1] * t_values[i]);
//!     }
//!     res
//! }
//!
//! // Define the Jacobian
//! fn jacobian(x: &[f64], data: &[f64]) -> Array2<f64> {
//!     let n = data.len() / 2;
//!     let t_values = &data[0..n];
//!
//!     let mut jac = Array2::zeros((n, 2));
//!     for i in 0..n {
//!         jac[[i, 0]] = -1.0;
//!         jac[[i, 1]] = -t_values[i];
//!     }
//!     jac
//! }
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create data
//! let data = array![0.0, 1.0, 2.0, 3.0, 4.0, 0.1, 0.9, 2.1, 2.9, 4.1];
//!
//! // Define weights (higher weight = more importance)
//! let weights = array![1.0, 1.0, 1.0, 10.0, 10.0]; // Last two points have more weight
//!
//! // Initial guess
//! let x0 = array![0.0, 0.0];
//!
//! // Solve using weighted least squares
//! let result = weighted_least_squares(
//!     residual,
//!     &x0,
//!     &weights,
//!     Some(jacobian),
//!     &data,
//!     None
//! )?;
//!
//! assert!(result.success);
//! # Ok(())
//! # }
//! ```

use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1};

/// Options for weighted least squares optimization
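///
/// All fields have defaults (see the `Default` impl below). A minimal sketch of
/// overriding a few of them (the specific values here are only illustrative):
///
/// ```
/// use scirs2_optimize::least_squares::weighted::WeightedOptions;
///
/// // Tighten the parameter tolerance and raise the iteration cap;
/// // the remaining fields keep their default values.
/// let opts = WeightedOptions {
///     max_iter: 200,
///     xtol: 1e-10,
///     ..WeightedOptions::default()
/// };
/// assert_eq!(opts.max_iter, 200);
/// ```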
#[derive(Debug, Clone)]
pub struct WeightedOptions {
    /// Maximum number of iterations
    pub max_iter: usize,

    /// Maximum number of function evaluations
    pub max_nfev: Option<usize>,

    /// Tolerance for termination by the change of parameters
    pub xtol: f64,

    /// Tolerance for termination by the change of cost function
    pub ftol: f64,

    /// Tolerance for termination by the norm of gradient
    pub gtol: f64,

    /// Step size for finite difference approximation
    pub diff_step: f64,

    /// Whether to check that weights are non-negative
    pub check_weights: bool,
}

impl Default for WeightedOptions {
    fn default() -> Self {
        WeightedOptions {
            max_iter: 100,
            max_nfev: None,
            xtol: 1e-8,
            ftol: 1e-8,
            gtol: 1e-8,
            diff_step: 1e-8,
            check_weights: true,
        }
    }
}

/// Solve a weighted least squares problem
///
/// This function minimizes the weighted sum of squares of residuals:
/// `sum(weights[i] * residuals[i]^2)`
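///
/// Internally, each residual and the corresponding row of the Jacobian are
/// scaled by `sqrt(weights[i])`, which turns the weighted problem into an
/// ordinary least squares problem solved by a Gauss-Newton iteration.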
///
/// # Arguments
///
/// * `residuals` - Function that returns the residuals
/// * `x0` - Initial guess for the parameters
/// * `weights` - Weights for each residual (must be non-negative)
/// * `jacobian` - Optional Jacobian function
/// * `data` - Additional data to pass to residuals and jacobian
/// * `options` - Options for the optimization
///
/// # Returns
///
/// * `OptimizeResults` containing the optimization results
#[allow(dead_code)]
pub fn weighted_least_squares<F, J, D, S1, S2, S3>(
    residuals: F,
    x0: &ArrayBase<S1, Ix1>,
    weights: &ArrayBase<S2, Ix1>,
    jacobian: Option<J>,
    data: &ArrayBase<S3, Ix1>,
    options: Option<WeightedOptions>,
) -> OptimizeResult<OptimizeResults<f64>>
where
    F: Fn(&[f64], &[D]) -> Array1<f64>,
    J: Fn(&[f64], &[D]) -> Array2<f64>,
    D: Clone,
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = D>,
{
    let options = options.unwrap_or_default();

    // Check weights if requested
    if options.check_weights {
        for &w in weights.iter() {
            if w < 0.0 {
                return Err(crate::error::OptimizeError::ValueError(
                    "Weights must be non-negative".to_string(),
                ));
            }
        }
    }

    // Implementation using Gauss-Newton method with weighted residuals
    weighted_gauss_newton(residuals, x0, weights, jacobian, data, &options)
}

/// Weighted Gauss-Newton implementation
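///
/// Each iteration scales the residuals and Jacobian rows by `sqrt(weights)`,
/// checks the gradient against `gtol`, solves the normal equations
/// `(J^T W J) * delta = -J^T W r` for the step, backtracks the step until the
/// weighted cost decreases, and finally checks the `xtol` and `ftol` criteria.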
#[allow(dead_code)]
fn weighted_gauss_newton<F, J, D, S1, S2, S3>(
    residuals: F,
    x0: &ArrayBase<S1, Ix1>,
    weights: &ArrayBase<S2, Ix1>,
    jacobian: Option<J>,
    data: &ArrayBase<S3, Ix1>,
    options: &WeightedOptions,
) -> OptimizeResult<OptimizeResults<f64>>
where
    F: Fn(&[f64], &[D]) -> Array1<f64>,
    J: Fn(&[f64], &[D]) -> Array2<f64>,
    D: Clone,
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = D>,
{
    let mut x = x0.to_owned();
    let m = x.len();
    let n = weights.len();

    let max_nfev = options.max_nfev.unwrap_or(options.max_iter * m * 10);
    let mut nfev = 0;
    let mut njev = 0;
    let mut iter = 0;

    // Compute square root of weights for transforming the problem
    let sqrt_weights = weights.mapv(f64::sqrt);

    // Forward-difference Jacobian helper (one extra residual evaluation per parameter)
    let compute_numerical_jacobian =
        |x_val: &Array1<f64>, res_val: &Array1<f64>| -> (Array2<f64>, usize) {
            let eps = options.diff_step;
            let mut jac = Array2::zeros((n, m));
            let mut count = 0;

            for j in 0..m {
                let mut x_h = x_val.clone();
                x_h[j] += eps;
                let res_h = residuals(x_h.as_slice().unwrap(), data.as_slice().unwrap());
                count += 1;

                for i in 0..n {
                    jac[[i, j]] = (res_h[i] - res_val[i]) / eps;
                }
            }

            (jac, count)
        };

    // Main optimization loop
    while iter < options.max_iter && nfev < max_nfev {
        // Compute residuals
        let res = residuals(x.as_slice().unwrap(), data.as_slice().unwrap());
        nfev += 1;

        // Compute weighted residuals
        let weighted_res = &res * &sqrt_weights;

        // Compute cost function
        let cost = 0.5 * weighted_res.iter().map(|&r| r * r).sum::<f64>();

        // Compute Jacobian (analytic if provided, otherwise by forward differences)
        let jac = match &jacobian {
            Some(jac_fn) => {
                let j = jac_fn(x.as_slice().unwrap(), data.as_slice().unwrap());
                njev += 1;
                j
            }
            None => {
                let (j, count) = compute_numerical_jacobian(&x, &res);
                nfev += count;
                j
            }
        };

        // Apply weights to Jacobian
        let mut weighted_jac = Array2::zeros((n, m));
        for i in 0..n {
            for j in 0..m {
                weighted_jac[[i, j]] = jac[[i, j]] * sqrt_weights[i];
            }
        }

        // Compute gradient: g = J^T * W * r
        let gradient = weighted_jac.t().dot(&weighted_res);

        // Check convergence on gradient
        if gradient.iter().all(|&g| g.abs() < options.gtol) {
            let mut result = OptimizeResults::<f64>::default();
            result.x = x;
            result.fun = cost;
            result.nfev = nfev;
            result.njev = njev;
            result.nit = iter;
            result.success = true;
            result.message = "Optimization terminated successfully.".to_string();
            return Ok(result);
        }

        // Form normal equations: (J^T * W * J) * delta = -J^T * W * r
        let jtw_j = weighted_jac.t().dot(&weighted_jac);
        let neg_gradient = -&gradient;

        // Solve for step
        match solve(&jtw_j, &neg_gradient) {
            Some(step) => {
                // Backtracking line search: halve the step until the weighted
                // cost decreases (at most 10 halvings)
                let mut alpha = 1.0;
                let mut best_cost = cost;
                let mut best_x = x.clone();

                for _ in 0..10 {
                    let x_new = &x + &step * alpha;
                    let res_new = residuals(x_new.as_slice().unwrap(), data.as_slice().unwrap());
                    nfev += 1;

                    let weighted_res_new = &res_new * &sqrt_weights;
                    let new_cost = 0.5 * weighted_res_new.iter().map(|&r| r * r).sum::<f64>();

                    if new_cost < best_cost {
                        best_cost = new_cost;
                        best_x = x_new;
                        break;
                    }

                    alpha *= 0.5;
                }

                // Check convergence on step size
                let step_norm = step.iter().map(|&s| s * s).sum::<f64>().sqrt();
                let x_norm = x.iter().map(|&xi| xi * xi).sum::<f64>().sqrt();

                if step_norm < options.xtol * (1.0 + x_norm) {
                    let mut result = OptimizeResults::<f64>::default();
                    result.x = best_x;
                    result.fun = best_cost;
                    result.nfev = nfev;
                    result.njev = njev;
                    result.nit = iter;
                    result.success = true;
                    result.message = "Converged (step size tolerance)".to_string();
                    return Ok(result);
                }

                // Check convergence on cost function
                if (cost - best_cost).abs() < options.ftol * cost {
                    let mut result = OptimizeResults::<f64>::default();
                    result.x = best_x;
                    result.fun = best_cost;
                    result.nfev = nfev;
                    result.njev = njev;
                    result.nit = iter;
                    result.success = true;
                    result.message = "Converged (function tolerance)".to_string();
                    return Ok(result);
                }

                x = best_x;
            }
            None => {
                // Singular matrix, terminate
                let mut result = OptimizeResults::<f64>::default();
                result.x = x;
                result.fun = cost;
                result.nfev = nfev;
                result.njev = njev;
                result.nit = iter;
                result.success = false;
                result.message = "Singular matrix in normal equations".to_string();
                return Ok(result);
            }
        }

        iter += 1;
    }

    // Max iterations reached
    let res_final = residuals(x.as_slice().unwrap(), data.as_slice().unwrap());
    let weighted_res_final = &res_final * &sqrt_weights;
    let final_cost = 0.5 * weighted_res_final.iter().map(|&r| r * r).sum::<f64>();

    let mut result = OptimizeResults::<f64>::default();
    result.x = x;
    result.fun = final_cost;
    result.nfev = nfev;
    result.njev = njev;
    result.nit = iter;
    result.success = false;
    result.message = "Maximum iterations reached".to_string();

    Ok(result)
}

/// Simple linear system solver (same as in robust.rs)
#[allow(dead_code)]
fn solve(a: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
    use scirs2_linalg::solve;

    solve(&a.view(), &b.view(), None).ok()
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_weighted_least_squares_simple() {
        // Linear regression problem
        fn residual(x: &[f64], data: &[f64]) -> Array1<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];
            let y_values = &data[n..];

            let mut res = Array1::zeros(n);
            for i in 0..n {
                res[i] = y_values[i] - (x[0] + x[1] * t_values[i]);
            }
            res
        }

        fn jacobian(x: &[f64], data: &[f64]) -> Array2<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];

            let mut jac = Array2::zeros((n, 2));
            for i in 0..n {
                jac[[i, 0]] = -1.0;
                jac[[i, 1]] = -t_values[i];
            }
            jac
        }

        // Data
        let data = array![0.0, 1.0, 2.0, 3.0, 4.0, 0.1, 0.9, 2.1, 2.9, 4.1];

        // Weights - give more importance to the last two points
        let weights = array![1.0, 1.0, 1.0, 10.0, 10.0];

        let x0 = array![0.0, 0.0];

        let result =
            weighted_least_squares(residual, &x0, &weights, Some(jacobian), &data, None).unwrap();

        assert!(result.success);
        // The solution should favor the last two points more
        assert!((result.x[1] - 1.0).abs() < 0.1); // Slope close to 1.0
    }

    #[test]
    fn test_negative_weights() {
        // Simple test function
        fn residual(x: &[f64], data: &[f64]) -> Array1<f64> {
            array![x[0] - 1.0]
        }

        let x0 = array![0.0];
        let weights = array![-1.0]; // Invalid negative weight
        let data = array![];

        let result = weighted_least_squares(
            residual,
            &x0,
            &weights,
            None::<fn(&[f64], &[f64]) -> Array2<f64>>,
            &data,
            None,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_vs_unweighted() {
        // Linear regression problem with outlier
        fn residual(x: &[f64], data: &[f64]) -> Array1<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];
            let y_values = &data[n..];

            let mut res = Array1::zeros(n);
            for i in 0..n {
                res[i] = y_values[i] - (x[0] + x[1] * t_values[i]);
            }
            res
        }

        // Data with an outlier
        let data = array![0.0, 1.0, 2.0, 0.0, 1.0, 10.0]; // Last point is outlier

        let x0 = array![0.0, 0.0];

        // Uniform weights (essentially unweighted)
        let weights_uniform = array![1.0, 1.0, 1.0];

        // Downweight the outlier
        let weights_robust = array![1.0, 1.0, 0.1];

        let result_uniform = weighted_least_squares(
            residual,
            &x0,
            &weights_uniform,
            None::<fn(&[f64], &[f64]) -> Array2<f64>>,
            &data,
            None,
        )
        .unwrap();

        let result_robust = weighted_least_squares(
            residual,
            &x0,
            &weights_robust,
            None::<fn(&[f64], &[f64]) -> Array2<f64>>,
            &data,
            None,
        )
        .unwrap();

        // The robust solution should have a slope closer to 1.0 (true value)
        // than the uniform weight solution
        assert!((result_robust.x[1] - 1.0).abs() < (result_uniform.x[1] - 1.0).abs());
    }
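
    // Additional sanity check (a sketch beyond the original suite): exercises the
    // custom-options path on noise-free linear data y = 1 + 2*t. The option values
    // and accuracy thresholds below are illustrative assumptions.
    #[test]
    fn test_custom_options() {
        fn residual(x: &[f64], data: &[f64]) -> Array1<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];
            let y_values = &data[n..];

            let mut res = Array1::zeros(n);
            for i in 0..n {
                res[i] = y_values[i] - (x[0] + x[1] * t_values[i]);
            }
            res
        }

        // t = [0, 1, 2, 3], y = 1 + 2*t (no noise)
        let data = array![0.0, 1.0, 2.0, 3.0, 1.0, 3.0, 5.0, 7.0];
        let weights = array![1.0, 2.0, 1.0, 2.0];
        let x0 = array![0.0, 0.0];

        // Override a couple of options; the rest keep their defaults
        let options = WeightedOptions {
            max_iter: 50,
            xtol: 1e-10,
            ..Default::default()
        };

        let result = weighted_least_squares(
            residual,
            &x0,
            &weights,
            None::<fn(&[f64], &[f64]) -> Array2<f64>>,
            &data,
            Some(options),
        )
        .unwrap();

        assert!(result.success);
        assert!((result.x[0] - 1.0).abs() < 1e-3);
        assert!((result.x[1] - 2.0).abs() < 1e-3);
    }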
}