use super::{
    utils, StreamingConfig, StreamingDataPoint, StreamingObjective, StreamingOptimizer,
    StreamingStats,
};
use crate::error::OptimizeError;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use std::time::{Duration, Instant};

type Result<T> = std::result::Result<T, OptimizeError>;

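/// Recursive estimation method used by [`RealTimeEstimator`] to update the
/// parameters as each data point arrives.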
#[derive(Debug, Clone, Copy)]
pub enum RealTimeMethod {
    /// Standard recursive least squares.
    RecursiveLeastSquares,
    /// Kalman filter with a random-walk model for the parameters.
    KalmanFilter,
    /// Recursive least squares with exponential forgetting of old data.
    ExponentiallyWeightedRLS,
    /// Least-squares refit over a fixed-size sliding window of recent points.
    SlidingWindowRLS,
}

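/// Streaming estimator that refines its parameter vector one observation at a
/// time using the selected [`RealTimeMethod`], maintaining a covariance matrix
/// for the estimate and a per-update processing-time budget.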
#[derive(Debug, Clone)]
pub struct RealTimeEstimator<T: StreamingObjective> {
    /// Current parameter estimate.
    parameters: Array1<f64>,
    /// Streaming objective used to evaluate the loss.
    objective: T,
    config: StreamingConfig,
    stats: StreamingStats,
    /// Update method applied to each incoming point.
    method: RealTimeMethod,
    /// Covariance (uncertainty) of the parameter estimate.
    covariance: Array2<f64>,
    /// Forgetting factor for exponentially weighted updates.
    forgetting_factor: f64,
    /// Process noise used by the Kalman filter.
    process_noise: f64,
    /// Measurement noise used by the Kalman filter.
    measurement_noise: f64,
    last_update_time: Option<Instant>,
    /// Recent (features, target) pairs kept for sliding-window refits.
    window_data: std::collections::VecDeque<(Array1<f64>, f64)>,
    /// Per-update processing budget for real-time operation.
    max_processing_time: Duration,
}

impl<T: StreamingObjective> RealTimeEstimator<T> {
    pub fn new(
        initial_parameters: Array1<f64>,
        objective: T,
        config: StreamingConfig,
        method: RealTimeMethod,
        initial_covariance_scale: f64,
    ) -> Self {
        let n_params = initial_parameters.len();
        let initial_covariance = Array2::eye(n_params) * initial_covariance_scale;
        let forgetting_factor = config.forgetting_factor;
        let window_size = config.window_size;

        Self {
            parameters: initial_parameters,
            objective,
            config,
            stats: StreamingStats::default(),
            method,
            covariance: initial_covariance,
            forgetting_factor,
            process_noise: 1e-6,
            measurement_noise: 1e-3,
            last_update_time: None,
            window_data: std::collections::VecDeque::with_capacity(window_size),
            max_processing_time: Duration::from_millis(10),
        }
    }

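    /// Recursive least squares (RLS) update for a single observation.
    ///
    /// Computes the prediction error e = y - θᵀx, the gain k = Px / (1 + xᵀPx),
    /// then applies θ ← θ + e·k and the rank-one covariance downdate
    /// P ← P - k(Px)ᵀ. The update is skipped if the denominator is near zero.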
    fn update_rls(&mut self, features: &ArrayView1<f64>, target: f64) -> Result<()> {
        let n = features.len();

        let prediction = self.parameters.dot(features);
        let error = target - prediction;

        let mut px = Array1::zeros(n);
        for i in 0..n {
            for j in 0..n {
                px[i] += self.covariance[[i, j]] * features[j];
            }
        }

        let denominator = 1.0 + features.dot(&px);
        if denominator.abs() < 1e-12 {
            return Ok(());
        }

        let gain = &px / denominator;

        self.parameters = &self.parameters + &(error * &gain);

        for i in 0..n {
            for j in 0..n {
                self.covariance[[i, j]] -= gain[i] * px[j];
            }
        }

        Ok(())
    }

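    /// Exponentially weighted RLS update.
    ///
    /// The covariance is first scaled by 1/λ (λ = forgetting factor), then the
    /// gain Px / (λ + xᵀPx) is formed from the scaled covariance, so older
    /// observations are discounted geometrically relative to the newest one.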
    fn update_ewrls(&mut self, features: &ArrayView1<f64>, target: f64) -> Result<()> {
        let n = features.len();

        self.covariance *= 1.0 / self.forgetting_factor;

        let prediction = self.parameters.dot(features);
        let error = target - prediction;

        let mut px = Array1::zeros(n);
        for i in 0..n {
            for j in 0..n {
                px[i] += self.covariance[[i, j]] * features[j];
            }
        }

        let denominator = self.forgetting_factor + features.dot(&px);
        if denominator.abs() < 1e-12 {
            return Ok(());
        }

        let gain = &px / denominator;

        self.parameters = &self.parameters + &(error * &gain);

        for i in 0..n {
            for j in 0..n {
                self.covariance[[i, j]] -= gain[i] * px[j];
            }
        }

        Ok(())
    }

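    /// Kalman-filter update treating the parameters as a random-walk state.
    ///
    /// Process noise is added to the diagonal of P (predict step), then the
    /// innovation y - θᵀx and gain k = Px / (xᵀPx + r), with r the measurement
    /// noise, are used to correct both the parameters and the covariance.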
    fn update_kalman(&mut self, features: &ArrayView1<f64>, target: f64) -> Result<()> {
        let n = features.len();

        for i in 0..n {
            self.covariance[[i, i]] += self.process_noise;
        }

        let prediction = self.parameters.dot(features);
        let innovation = target - prediction;

        let mut px = Array1::zeros(n);
        for i in 0..n {
            for j in 0..n {
                px[i] += self.covariance[[i, j]] * features[j];
            }
        }

        let innovation_covariance = features.dot(&px) + self.measurement_noise;
        if innovation_covariance.abs() < 1e-12 {
            return Ok(());
        }

        let kalman_gain = &px / innovation_covariance;

        self.parameters = &self.parameters + &(innovation * &kalman_gain);

        for i in 0..n {
            for j in 0..n {
                self.covariance[[i, j]] -= kalman_gain[i] * px[j];
            }
        }

        Ok(())
    }

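    /// Sliding-window RLS: refits on the most recent `window_size` points.
    ///
    /// Builds the window's regularized normal equations (XᵀX + αI)θ = Xᵀy and
    /// solves them directly; if the solve fails, falls back to a standard RLS
    /// update on the new point.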
    fn update_sliding_window_rls(&mut self, features: &ArrayView1<f64>, target: f64) -> Result<()> {
        self.window_data.push_back((features.to_owned(), target));

        if self.window_data.len() > self.config.window_size {
            self.window_data.pop_front();
        }

        let n = features.len();
        let mut xtx = Array2::zeros((n, n));
        let mut xty = Array1::zeros(n);

        for (x, y) in &self.window_data {
            for i in 0..n {
                for j in 0..n {
                    xtx[[i, j]] += x[i] * x[j];
                }
                xty[i] += x[i] * y;
            }
        }

        for i in 0..n {
            xtx[[i, i]] += self.config.regularization;
        }

        match scirs2_linalg::solve(&xtx.view(), &xty.view(), None) {
            Ok(solution) => {
                self.parameters = solution;
                if let Ok(pinv) = scirs2_linalg::compat::pinv(&xtx.view(), None, false, true) {
                    self.covariance = pinv;
                }
            }
            Err(_) => {
                self.update_rls(features, target)?;
            }
        }

        Ok(())
    }

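    /// Periodic self-tuning of the forgetting factor and process noise.
    ///
    /// If the current loss is above the running average, the forgetting factor
    /// is decreased (faster adaptation); otherwise it is nudged back up. For
    /// the Kalman filter, the process noise is rescaled from the average
    /// parameter magnitude and clamped to [1e-8, 1e-3].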
    fn adapt_parameters(&mut self) {
        if self.stats.points_processed > 10 {
            let recent_loss_trend = self.stats.current_loss - self.stats.average_loss;

            if recent_loss_trend > 0.0 {
                self.forgetting_factor = (self.forgetting_factor * 0.95).max(0.5);
            } else {
                self.forgetting_factor = (self.forgetting_factor * 1.01).min(0.999);
            }
        }

        if matches!(self.method, RealTimeMethod::KalmanFilter) {
            let param_change_rate = if self.stats.points_processed > 1 {
                self.parameters.mapv(|x| x.abs()).sum() / self.stats.points_processed as f64
            } else {
                1e-6
            };

            self.process_noise = (param_change_rate * 0.1).clamp(1e-8, 1e-3);
        }
    }

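    /// Returns true if the time elapsed since `start_time` already exceeds the
    /// per-update processing budget.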
    fn should_skip_for_timing(&self, start_time: Instant) -> bool {
        start_time.elapsed() > self.max_processing_time
    }
}

impl<T: StreamingObjective + Clone> StreamingOptimizer for RealTimeEstimator<T> {
    fn update(&mut self, datapoint: &StreamingDataPoint) -> Result<()> {
        let start_time = Instant::now();

        if self.should_skip_for_timing(start_time) {
            return Ok(());
        }

        let old_parameters = self.parameters.clone();
        let features = &datapoint.features;
        let target = datapoint.target;

        match self.method {
            RealTimeMethod::RecursiveLeastSquares => {
                self.update_rls(&features.view(), target)?;
            }
            RealTimeMethod::ExponentiallyWeightedRLS => {
                self.update_ewrls(&features.view(), target)?;
            }
            RealTimeMethod::KalmanFilter => {
                self.update_kalman(&features.view(), target)?;
            }
            RealTimeMethod::SlidingWindowRLS => {
                self.update_sliding_window_rls(&features.view(), target)?;
            }
        }

        if self.stats.points_processed.is_multiple_of(20) {
            self.adapt_parameters();
        }

        let loss = self.objective.evaluate(&self.parameters.view(), datapoint);
        self.stats.points_processed += 1;
        self.stats.updates_performed += 1;
        self.stats.current_loss = loss;
        self.stats.average_loss = utils::ewma_update(self.stats.average_loss, loss, 0.05);

        self.stats.converged = utils::check_convergence(
            &old_parameters.view(),
            &self.parameters.view(),
            self.config.tolerance,
        );

        self.stats.processing_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;
        self.last_update_time = Some(start_time);

        Ok(())
    }

    fn parameters(&self) -> &Array1<f64> {
        &self.parameters
    }

    fn stats(&self) -> &StreamingStats {
        &self.stats
    }

    fn reset(&mut self) {
        let n = self.parameters.len();
        self.covariance = Array2::eye(n) * 1000.0;
        self.forgetting_factor = self.config.forgetting_factor;
        self.last_update_time = None;
        self.window_data.clear();
        self.stats = StreamingStats::default();
    }
}

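/// Creates a standard recursive least squares estimator. `initial_uncertainty`
/// scales the initial covariance and defaults to 1000.0 (a weak prior on the
/// starting parameters).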
#[allow(dead_code)]
pub fn recursive_least_squares<T: StreamingObjective>(
    initial_parameters: Array1<f64>,
    objective: T,
    config: Option<StreamingConfig>,
    initial_uncertainty: Option<f64>,
) -> RealTimeEstimator<T> {
    let config = config.unwrap_or_default();
    let uncertainty = initial_uncertainty.unwrap_or(1000.0);

    RealTimeEstimator::new(
        initial_parameters,
        objective,
        config,
        RealTimeMethod::RecursiveLeastSquares,
        uncertainty,
    )
}

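/// Creates an exponentially weighted RLS estimator. A `forgetting_factor`
/// given here overrides the one in `config`; values closer to 1.0 forget old
/// observations more slowly.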
#[allow(dead_code)]
pub fn exponentially_weighted_rls<T: StreamingObjective>(
    initial_parameters: Array1<f64>,
    objective: T,
    config: Option<StreamingConfig>,
    forgetting_factor: Option<f64>,
) -> RealTimeEstimator<T> {
    let mut config = config.unwrap_or_default();
    if let Some(ff) = forgetting_factor {
        config.forgetting_factor = ff;
    }

    RealTimeEstimator::new(
        initial_parameters,
        objective,
        config,
        RealTimeMethod::ExponentiallyWeightedRLS,
        100.0,
    )
}

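/// Creates a Kalman-filter estimator, optionally overriding the default
/// process noise (1e-6) and measurement noise (1e-3).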
#[allow(dead_code)]
pub fn kalman_filter_estimator<T: StreamingObjective>(
    initial_parameters: Array1<f64>,
    objective: T,
    config: Option<StreamingConfig>,
    process_noise: Option<f64>,
    measurement_noise: Option<f64>,
) -> RealTimeEstimator<T> {
    let config = config.unwrap_or_default();
    let mut estimator = RealTimeEstimator::new(
        initial_parameters,
        objective,
        config,
        RealTimeMethod::KalmanFilter,
        1.0,
    );

    if let Some(pn) = process_noise {
        estimator.process_noise = pn;
    }
    if let Some(mn) = measurement_noise {
        estimator.measurement_noise = mn;
    }

    estimator
}

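/// Creates a real-time linear regression estimator with zero-initialized
/// parameters for `n_features` features, using the given update method.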
#[allow(dead_code)]
pub fn real_time_linear_regression(
    n_features: usize,
    method: RealTimeMethod,
    config: Option<StreamingConfig>,
) -> RealTimeEstimator<super::LinearRegressionObjective> {
    let config = config.unwrap_or_default();
    let initial_params = Array1::zeros(n_features);
    let objective = super::LinearRegressionObjective;

    RealTimeEstimator::new(initial_params, objective, config, method, 100.0)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::{LinearRegressionObjective, StreamingDataPoint};

    #[test]
    fn test_rls_creation() {
        let estimator =
            recursive_least_squares(Array1::zeros(2), LinearRegressionObjective, None, None);

        assert_eq!(estimator.parameters().len(), 2);
        assert!(matches!(
            estimator.method,
            RealTimeMethod::RecursiveLeastSquares
        ));
    }

    #[test]
    fn test_rls_update() {
        let mut estimator =
            real_time_linear_regression(2, RealTimeMethod::RecursiveLeastSquares, None);

        let features = Array1::from(vec![1.0, 2.0]);
        let target = 3.0;
        let point = StreamingDataPoint::new(features, target);

        assert!(estimator.update(&point).is_ok());
        assert_eq!(estimator.stats().points_processed, 1);
    }

    #[test]
    fn test_ewrls_adaptation() {
        let mut config = StreamingConfig::default();
        config.forgetting_factor = 0.9;

        let mut estimator = exponentially_weighted_rls(
            Array1::zeros(2),
            LinearRegressionObjective,
            Some(config),
            None,
        );

        for i in 0..10 {
            let features = Array1::from(vec![i as f64, 1.0]);
            let target = 2.0 * i as f64 + 1.0;
            let point = StreamingDataPoint::new(features, target);

            estimator.update(&point).unwrap();
        }

        assert_eq!(estimator.stats().points_processed, 10);
        assert!(estimator.stats().current_loss.is_finite());
    }

    #[test]
    fn test_kalman_filter() {
        let mut estimator = kalman_filter_estimator(
            Array1::zeros(2),
            LinearRegressionObjective,
            None,
            Some(1e-6),
            Some(1e-3),
        );

        let data_points = vec![
            StreamingDataPoint::new(Array1::from(vec![1.0, 0.0]), 2.1),
            StreamingDataPoint::new(Array1::from(vec![0.0, 1.0]), 2.9),
            StreamingDataPoint::new(Array1::from(vec![1.0, 1.0]), 5.1),
        ];

        for point in &data_points {
            estimator.update(point).unwrap();
        }

        assert_eq!(estimator.stats().points_processed, 3);

        let params = estimator.parameters();
        assert!((params[0] - 2.0).abs() < 1.0);
        assert!((params[1] - 3.0).abs() < 1.0);
    }

    #[test]
    fn test_sliding_window_rls() {
        let mut estimator = real_time_linear_regression(2, RealTimeMethod::SlidingWindowRLS, None);

        for i in 0..15 {
            let features = Array1::from(vec![i as f64, 1.0]);
            let target = 2.0 * i as f64;
            let point = StreamingDataPoint::new(features, target);

            estimator.update(&point).unwrap();
        }

        assert_eq!(estimator.stats().points_processed, 15);
        assert!(estimator.window_data.len() <= estimator.config.window_size);
    }

    #[test]
    fn test_covariance_updates() {
        let mut estimator = recursive_least_squares(
            Array1::zeros(2),
            LinearRegressionObjective,
            None,
            Some(100.0),
        );

        let initial_covariance = estimator.covariance.clone();

        let features = Array1::from(vec![1.0, 1.0]);
        let target = 1.0;
        let point = StreamingDataPoint::new(features, target);

        estimator.update(&point).unwrap();

        assert_ne!(estimator.covariance, initial_covariance);

        assert!(estimator.covariance[[0, 0]] < initial_covariance[[0, 0]]);
        assert!(estimator.covariance[[1, 1]] < initial_covariance[[1, 1]]);
    }

    #[test]
    fn test_real_time_constraints() {
        let mut estimator =
            real_time_linear_regression(2, RealTimeMethod::RecursiveLeastSquares, None);

        estimator.max_processing_time = Duration::from_nanos(1);

        let features = Array1::from(vec![1.0, 2.0]);
        let target = 3.0;
        let point = StreamingDataPoint::new(features, target);

        let start = Instant::now();
        estimator.update(&point).unwrap();
        let elapsed = start.elapsed();

        assert!(elapsed < Duration::from_millis(100));
    }

    #[test]
    fn test_parameter_adaptation() {
        let mut estimator = exponentially_weighted_rls(
            Array1::zeros(2),
            LinearRegressionObjective,
            None,
            Some(0.95),
        );

        let initial_ff = estimator.forgetting_factor;

        for i in 0..50 {
            let features = Array1::from(vec![i as f64, 1.0]);
            let target = i as f64;
            let point = StreamingDataPoint::new(features, target);

            estimator.update(&point).unwrap();
        }

        assert_eq!(estimator.stats().points_processed, 50);
        // The forgetting factor is adapted every 20 points, so it should have
        // moved away from its initial value while staying within [0.5, 0.999].
        assert_ne!(estimator.forgetting_factor, initial_ff);
        assert!((0.5..=0.999).contains(&estimator.forgetting_factor));
    }
}