scirs2_optimize/least_squares/weighted.rs

use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1};

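/// Options controlling the weighted least-squares solver.
///
/// A minimal sketch of overriding a single field while keeping the
/// defaults for the rest:
///
/// ```ignore
/// let opts = WeightedOptions {
///     max_iter: 200,
///     ..Default::default()
/// };
/// ```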
#[derive(Debug, Clone)]
pub struct WeightedOptions {
    /// Maximum number of Gauss-Newton iterations.
    pub max_iter: usize,

    /// Maximum number of residual evaluations; defaults to
    /// `max_iter * n_params * 10` when `None`.
    pub max_nfev: Option<usize>,

    /// Tolerance on the step size, relative to the parameter norm.
    pub xtol: f64,

    /// Tolerance on the relative reduction of the cost function.
    pub ftol: f64,

    /// Tolerance on the infinity norm of the gradient.
    pub gtol: f64,

    /// Step size for the forward-difference Jacobian approximation.
    pub diff_step: f64,

    /// Whether to verify that all weights are non-negative.
    pub check_weights: bool,
}

impl Default for WeightedOptions {
    fn default() -> Self {
        WeightedOptions {
            max_iter: 100,
            max_nfev: None,
            xtol: 1e-8,
            ftol: 1e-8,
            gtol: 1e-8,
            diff_step: 1e-8,
            check_weights: true,
        }
    }
}

#[allow(dead_code)]
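/// Solves a weighted least-squares problem, minimizing the cost
/// `0.5 * Σ wᵢ · rᵢ(x)²` with a Gauss-Newton iteration.
///
/// # Arguments
///
/// * `residuals` - Computes the residual vector `r(x)` from the parameters
///   and the observed data.
/// * `x0` - Initial parameter guess.
/// * `weights` - Non-negative weight per residual; must have the same length
///   as the vector returned by `residuals`.
/// * `jacobian` - Optional analytic Jacobian of the residuals; when `None`,
///   a forward-difference approximation is used instead.
/// * `data` - Observed data, passed through to `residuals` and `jacobian`.
/// * `options` - Solver options; `None` uses [`WeightedOptions::default`].
///
/// # Example
///
/// A trimmed-down sketch of the unit test below, fitting the line
/// `y = x0 + x1 * t` with the last point weighted more heavily:
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1, Array2};
///
/// let residual = |x: &[f64], data: &[f64]| -> Array1<f64> {
///     let n = data.len() / 2;
///     let (t, y) = (&data[0..n], &data[n..]);
///     Array1::from_iter((0..n).map(|i| y[i] - (x[0] + x[1] * t[i])))
/// };
///
/// // t values in the first half, observed y values in the second half.
/// let data = array![0.0, 1.0, 2.0, 0.1, 0.9, 2.1];
/// let weights = array![1.0, 1.0, 10.0];
/// let x0 = array![0.0, 0.0];
///
/// let result = weighted_least_squares(
///     residual,
///     &x0,
///     &weights,
///     None::<fn(&[f64], &[f64]) -> Array2<f64>>,
///     &data,
///     None,
/// )
/// .unwrap();
/// assert!(result.success);
/// ```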
pub fn weighted_least_squares<F, J, D, S1, S2, S3>(
    residuals: F,
    x0: &ArrayBase<S1, Ix1>,
    weights: &ArrayBase<S2, Ix1>,
    jacobian: Option<J>,
    data: &ArrayBase<S3, Ix1>,
    options: Option<WeightedOptions>,
) -> OptimizeResult<OptimizeResults<f64>>
where
    F: Fn(&[f64], &[D]) -> Array1<f64>,
    J: Fn(&[f64], &[D]) -> Array2<f64>,
    D: Clone,
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = D>,
{
    let options = options.unwrap_or_default();

    // Negative weights have no interpretation as inverse variances,
    // so reject them up front unless the caller disables the check.
    if options.check_weights {
        for &w in weights.iter() {
            if w < 0.0 {
                return Err(crate::error::OptimizeError::ValueError(
                    "Weights must be non-negative".to_string(),
                ));
            }
        }
    }

    weighted_gauss_newton(residuals, x0, weights, jacobian, data, &options)
}

#[allow(dead_code)]
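/// Core Gauss-Newton loop: each iteration solves the normal equations
/// `(JᵀWJ) · step = -JᵀWr` and applies a backtracking line search, halving
/// the step until the weighted cost decreases.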
fn weighted_gauss_newton<F, J, D, S1, S2, S3>(
    residuals: F,
    x0: &ArrayBase<S1, Ix1>,
    weights: &ArrayBase<S2, Ix1>,
    jacobian: Option<J>,
    data: &ArrayBase<S3, Ix1>,
    options: &WeightedOptions,
) -> OptimizeResult<OptimizeResults<f64>>
where
    F: Fn(&[f64], &[D]) -> Array1<f64>,
    J: Fn(&[f64], &[D]) -> Array2<f64>,
    D: Clone,
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = D>,
{
    let mut x = x0.to_owned();
    let m = x.len();
    let n = weights.len();

    let max_nfev = options.max_nfev.unwrap_or(options.max_iter * m * 10);
    let mut nfev = 0;
    let mut njev = 0;
    let mut iter = 0;

    // Pre-compute sqrt(w_i) so that weighting reduces to elementwise
    // scaling of the residuals and the Jacobian rows.
    let sqrt_weights = weights.mapv(f64::sqrt);

    // Forward-difference approximation of the Jacobian; returns the matrix
    // together with the number of residual evaluations it consumed.
    let compute_numerical_jacobian =
        |x_val: &Array1<f64>, res_val: &Array1<f64>| -> (Array2<f64>, usize) {
            let eps = options.diff_step;
            let mut jac = Array2::zeros((n, m));
            let mut count = 0;

            for j in 0..m {
                let mut x_h = x_val.clone();
                x_h[j] += eps;
                let res_h = residuals(x_h.as_slice().unwrap(), data.as_slice().unwrap());
                count += 1;

                for i in 0..n {
                    jac[[i, j]] = (res_h[i] - res_val[i]) / eps;
                }
            }

            (jac, count)
        };

    while iter < options.max_iter && nfev < max_nfev {
        let res = residuals(x.as_slice().unwrap(), data.as_slice().unwrap());
        nfev += 1;

        // Weighted residuals sqrt(W) * r, so the plain sum of squares
        // below equals the weighted cost 0.5 * rᵀWr.
        let weighted_res = &res * &sqrt_weights;

        let cost = 0.5 * weighted_res.iter().map(|&r| r * r).sum::<f64>();

        let jac = match &jacobian {
            Some(jac_fn) => {
                let j = jac_fn(x.as_slice().unwrap(), data.as_slice().unwrap());
                njev += 1;
                j
            }
            None => {
                let (j, count) = compute_numerical_jacobian(&x, &res);
                nfev += count;
                j
            }
        };

        // Scale each Jacobian row by sqrt(w_i): weighted_jac = sqrt(W) * J.
        let mut weighted_jac = Array2::zeros((n, m));
        for i in 0..n {
            for j in 0..m {
                weighted_jac[[i, j]] = jac[[i, j]] * sqrt_weights[i];
            }
        }

        // Gradient of the weighted cost: JᵀWr.
        let gradient = weighted_jac.t().dot(&weighted_res);

        if gradient.iter().all(|&g| g.abs() < options.gtol) {
            let mut result = OptimizeResults::<f64>::default();
            result.x = x;
            result.fun = cost;
            result.nfev = nfev;
            result.njev = njev;
            result.nit = iter;
            result.success = true;
            result.message = "Optimization terminated successfully.".to_string();
            return Ok(result);
        }

        // Gauss-Newton normal equations: (JᵀWJ) · step = -JᵀWr.
        let jtw_j = weighted_jac.t().dot(&weighted_jac);
        let neg_gradient = -&gradient;

        match solve(&jtw_j, &neg_gradient) {
            Some(step) => {
                // Backtracking line search: halve the step up to 10 times
                // until the weighted cost decreases.
                let mut alpha = 1.0;
                let mut best_cost = cost;
                let mut best_x = x.clone();

                for _ in 0..10 {
                    let x_new = &x + &step * alpha;
                    let res_new = residuals(x_new.as_slice().unwrap(), data.as_slice().unwrap());
                    nfev += 1;

                    let weighted_res_new = &res_new * &sqrt_weights;
                    let new_cost = 0.5 * weighted_res_new.iter().map(|&r| r * r).sum::<f64>();

                    if new_cost < best_cost {
                        best_cost = new_cost;
                        best_x = x_new;
                        break;
                    }

                    alpha *= 0.5;
                }

                let step_norm = step.iter().map(|&s| s * s).sum::<f64>().sqrt();
                let x_norm = x.iter().map(|&xi| xi * xi).sum::<f64>().sqrt();

                // Step-size convergence, relative to the parameter norm.
                if step_norm < options.xtol * (1.0 + x_norm) {
                    let mut result = OptimizeResults::<f64>::default();
                    result.x = best_x;
                    result.fun = best_cost;
                    result.nfev = nfev;
                    result.njev = njev;
                    result.nit = iter;
                    result.success = true;
                    result.message = "Converged (step size tolerance)".to_string();
                    return Ok(result);
                }

                // Relative cost-decrease convergence.
                if (cost - best_cost).abs() < options.ftol * cost {
                    let mut result = OptimizeResults::<f64>::default();
                    result.x = best_x;
                    result.fun = best_cost;
                    result.nfev = nfev;
                    result.njev = njev;
                    result.nit = iter;
                    result.success = true;
                    result.message = "Converged (function tolerance)".to_string();
                    return Ok(result);
                }

                x = best_x;
            }
            None => {
                // JᵀWJ is singular (e.g. a rank-deficient Jacobian): report
                // failure at the current iterate instead of panicking.
                let mut result = OptimizeResults::<f64>::default();
                result.x = x;
                result.fun = cost;
                result.nfev = nfev;
                result.njev = njev;
                result.nit = iter;
                result.success = false;
                result.message = "Singular matrix in normal equations".to_string();
                return Ok(result);
            }
        }

        iter += 1;
    }

    // Iteration budget exhausted: report the current iterate without success.
    let res_final = residuals(x.as_slice().unwrap(), data.as_slice().unwrap());
    let weighted_res_final = &res_final * &sqrt_weights;
    let final_cost = 0.5 * weighted_res_final.iter().map(|&r| r * r).sum::<f64>();

    let mut result = OptimizeResults::<f64>::default();
    result.x = x;
    result.fun = final_cost;
    result.nfev = nfev;
    result.njev = njev;
    result.nit = iter;
    result.success = false;
    result.message = "Maximum iterations reached".to_string();

    Ok(result)
}
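/// Solves the dense linear system `a · x = b` via `scirs2_linalg::solve`,
/// returning `None` when the matrix is singular or the solve otherwise fails.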
#[allow(dead_code)]
fn solve(a: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
    use scirs2_linalg::solve;

    solve(&a.view(), &b.view(), None).ok()
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_weighted_least_squares_simple() {
        // Linear model y = x0 + x1 * t; `data` holds the t values in the
        // first half and the observed y values in the second half.
        fn residual(x: &[f64], data: &[f64]) -> Array1<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];
            let y_values = &data[n..];

            let mut res = Array1::zeros(n);
            for i in 0..n {
                res[i] = y_values[i] - (x[0] + x[1] * t_values[i]);
            }
            res
        }

        fn jacobian(_x: &[f64], data: &[f64]) -> Array2<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];

            let mut jac = Array2::zeros((n, 2));
            for i in 0..n {
                jac[[i, 0]] = -1.0;
                jac[[i, 1]] = -t_values[i];
            }
            jac
        }

        // t values followed by noisy observations of y = t.
        let data = array![0.0, 1.0, 2.0, 3.0, 4.0, 0.1, 0.9, 2.1, 2.9, 4.1];

        // Weight the last two observations ten times more heavily.
        let weights = array![1.0, 1.0, 1.0, 10.0, 10.0];

        let x0 = array![0.0, 0.0];

        let result =
            weighted_least_squares(residual, &x0, &weights, Some(jacobian), &data, None).unwrap();

        assert!(result.success);
        assert!((result.x[1] - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_negative_weights() {
        fn residual(x: &[f64], _data: &[f64]) -> Array1<f64> {
            array![x[0] - 1.0]
        }

        let x0 = array![0.0];
        let weights = array![-1.0];
        let data: Array1<f64> = array![];

        let result = weighted_least_squares(
            residual,
            &x0,
            &weights,
            None::<fn(&[f64], &[f64]) -> Array2<f64>>,
            &data,
            None,
        );

        // A negative weight must be rejected by the input validation.
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_vs_unweighted() {
        fn residual(x: &[f64], data: &[f64]) -> Array1<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];
            let y_values = &data[n..];

            let mut res = Array1::zeros(n);
            for i in 0..n {
                res[i] = y_values[i] - (x[0] + x[1] * t_values[i]);
            }
            res
        }

        // t = [0, 1, 2], y = [0, 1, 10]: the last observation is an outlier
        // relative to the line y = t.
        let data = array![0.0, 1.0, 2.0, 0.0, 1.0, 10.0];
        let x0 = array![0.0, 0.0];

        let weights_uniform = array![1.0, 1.0, 1.0];

        // Down-weight the outlier.
        let weights_robust = array![1.0, 1.0, 0.1];

        let result_uniform = weighted_least_squares(
            residual,
            &x0,
            &weights_uniform,
            None::<fn(&[f64], &[f64]) -> Array2<f64>>,
            &data,
            None,
        )
        .unwrap();

        let result_robust = weighted_least_squares(
            residual,
            &x0,
            &weights_robust,
            None::<fn(&[f64], &[f64]) -> Array2<f64>>,
            &data,
            None,
        )
        .unwrap();

        // Down-weighting the outlier should pull the slope closer to 1.
        assert!((result_robust.x[1] - 1.0).abs() < (result_uniform.x[1] - 1.0).abs());
    }
}