// scirs2_optimize/streaming/real_time_estimation.rs

//! Real-time parameter estimation.
//!
//! This module provides recursive estimators (recursive least squares, exponentially
//! weighted RLS, sliding-window RLS and a Kalman filter) that update parameter estimates
//! one data point at a time with minimal latency. They are intended for applications that
//! must respond immediately to each new observation in a continuous data stream.
6
7use super::{
8    utils, StreamingConfig, StreamingDataPoint, StreamingObjective, StreamingOptimizer,
9    StreamingStats,
10};
11use crate::error::OptimizeError;
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
13// Unused import
14// use scirs2_core::error::CoreResult;
15use std::time::{Duration, Instant};
16
17type Result<T> = std::result::Result<T, OptimizeError>;
18
/// Recursive estimation algorithm used by a [`RealTimeEstimator`].
///
/// The variant is chosen at construction time and selects the update rule that is
/// applied to every incoming data point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RealTimeMethod {
    /// Recursive Least Squares (RLS): every observation is weighted equally.
    RecursiveLeastSquares,
    /// Kalman filter that models the parameters as a random walk (process noise added
    /// to the covariance before each measurement update).
    KalmanFilter,
    /// Exponentially weighted RLS: older observations are discounted by the forgetting factor.
    ExponentiallyWeightedRLS,
    /// Sliding-window RLS: regularized least squares re-solved over the most recent
    /// `window_size` observations.
    SlidingWindowRLS,
}
31
/// Real-time parameter estimator.
///
/// Holds the current parameter vector `θ`, a covariance matrix `P`, and the bookkeeping
/// needed by the selected [`RealTimeMethod`]. Each call to `update` (from the
/// `StreamingOptimizer` implementation) processes exactly one data point.
#[derive(Debug, Clone)]
pub struct RealTimeEstimator<T: StreamingObjective> {
    /// Current parameter estimates `θ`.
    parameters: Array1<f64>,
    /// Objective used to evaluate the loss recorded in `stats` after each update.
    objective: T,
    /// Streaming configuration (forgetting-factor seed, window size, regularization, tolerance).
    config: StreamingConfig,
    /// Running statistics (points processed, current/average loss, convergence flag, timing).
    stats: StreamingStats,
    /// Update rule applied to each data point.
    method: RealTimeMethod,
    /// Covariance matrix `P` (n × n) used by the RLS variants and the Kalman filter; the
    /// sliding-window method overwrites it with the pseudo-inverse of the regularized `XᵀX`.
    covariance: Array2<f64>,
    /// Current forgetting factor `λ`; seeded from `config` and retuned by the adaptive step.
    forgetting_factor: f64,
    /// Process noise added to the diagonal of `P` before each Kalman measurement update.
    process_noise: f64,
    /// Measurement noise `R` added to the Kalman innovation covariance.
    measurement_noise: f64,
    /// Start time of the most recent update; `None` before the first update and after `reset`.
    last_update_time: Option<Instant>,
    /// Most recent observations, used only by the sliding-window method.
    window_data: std::collections::VecDeque<(Array1<f64>, f64)>, // (features, target)
    /// Time budget compared against elapsed time at the start of each update.
    max_processing_time: Duration,
}
60
61impl<T: StreamingObjective> RealTimeEstimator<T> {
62    /// Create a new real-time estimator
63    pub fn new(
64        initial_parameters: Array1<f64>,
65        objective: T,
66        config: StreamingConfig,
67        method: RealTimeMethod,
68        initial_covariance_scale: f64,
69    ) -> Self {
70        let n_params = initial_parameters.len();
71        let initial_covariance = Array2::eye(n_params) * initial_covariance_scale;
72        let forgetting_factor = config.forgetting_factor;
73        let window_size = config.window_size;
74
75        Self {
76            parameters: initial_parameters,
77            objective,
78            config,
79            stats: StreamingStats::default(),
80            method,
81            covariance: initial_covariance,
82            forgetting_factor,
83            process_noise: 1e-6,
84            measurement_noise: 1e-3,
85            last_update_time: None,
86            window_data: std::collections::VecDeque::with_capacity(window_size),
87            max_processing_time: Duration::from_millis(10), // 10ms max for real-time
88        }
89    }
90
91    /// Update using Recursive Least Squares
92    fn update_rls(&mut self, features: &ArrayView1<f64>, target: f64) -> Result<()> {
93        let n = features.len();
94
95        // Compute prediction error
96        let prediction = self.parameters.dot(features);
97        let error = target - prediction;
98
99        // Compute gain vector: K = P * x / (1 + x^T * P * x)
100        let mut px = Array1::zeros(n);
101        for i in 0..n {
102            for j in 0..n {
103                px[i] += self.covariance[[i, j]] * features[j];
104            }
105        }
106
107        let denominator = 1.0 + features.dot(&px);
108        if denominator.abs() < 1e-12 {
109            return Ok(()); // Skip update if denominator too small
110        }
111
112        let gain = &px / denominator;
113
114        // Update parameters: θ = θ + K * error
115        self.parameters = &self.parameters + &(error * &gain);
116
117        // Update covariance: P = P - K * x^T * P
118        for i in 0..n {
119            for j in 0..n {
120                self.covariance[[i, j]] -= gain[i] * px[j];
121            }
122        }
123
124        Ok(())
125    }
126
127    /// Update using Exponentially Weighted RLS
128    fn update_ewrls(&mut self, features: &ArrayView1<f64>, target: f64) -> Result<()> {
129        let n = features.len();
130
131        // Scale covariance by forgetting factor
132        self.covariance *= 1.0 / self.forgetting_factor;
133
134        // Compute prediction error
135        let prediction = self.parameters.dot(features);
136        let error = target - prediction;
137
138        // Compute gain vector: K = P * x / (λ + x^T * P * x)
139        let mut px = Array1::zeros(n);
140        for i in 0..n {
141            for j in 0..n {
142                px[i] += self.covariance[[i, j]] * features[j];
143            }
144        }
145
146        let denominator = self.forgetting_factor + features.dot(&px);
147        if denominator.abs() < 1e-12 {
148            return Ok(());
149        }
150
151        let gain = &px / denominator;
152
153        // Update parameters
154        self.parameters = &self.parameters + &(error * &gain);
155
156        // Update covariance
157        for i in 0..n {
158            for j in 0..n {
159                self.covariance[[i, j]] -= gain[i] * px[j];
160            }
161        }
162
163        Ok(())
164    }
165
166    /// Update using Kalman Filter
167    fn update_kalman(&mut self, features: &ArrayView1<f64>, target: f64) -> Result<()> {
168        let n = features.len();
169
170        // Time update (prediction step)
171        // Add process noise to covariance
172        for i in 0..n {
173            self.covariance[[i, i]] += self.process_noise;
174        }
175
176        // Measurement update (correction step)
177        let prediction = self.parameters.dot(features);
178        let innovation = target - prediction;
179
180        // Innovation covariance: S = H * P * H^T + R
181        // For linear case: H = x^T, so S = x^T * P * x + R
182        let mut px = Array1::zeros(n);
183        for i in 0..n {
184            for j in 0..n {
185                px[i] += self.covariance[[i, j]] * features[j];
186            }
187        }
188
189        let innovation_covariance = features.dot(&px) + self.measurement_noise;
190        if innovation_covariance.abs() < 1e-12 {
191            return Ok(());
192        }
193
194        // Kalman gain: K = P * H^T / S = P * x / S
195        let kalman_gain = &px / innovation_covariance;
196
197        // Update parameters: θ = θ + K * innovation
198        self.parameters = &self.parameters + &(innovation * &kalman_gain);
199
200        // Update covariance: P = (I - K * H) * P = P - K * x^T * P
201        for i in 0..n {
202            for j in 0..n {
203                self.covariance[[i, j]] -= kalman_gain[i] * px[j];
204            }
205        }
206
207        Ok(())
208    }
209
210    /// Update using Sliding Window RLS
211    fn update_sliding_window_rls(&mut self, features: &ArrayView1<f64>, target: f64) -> Result<()> {
212        // Add new data point
213        self.window_data.push_back((features.to_owned(), target));
214
215        // Remove old data if window is full
216        if self.window_data.len() > self.config.window_size {
217            self.window_data.pop_front();
218        }
219
220        // Rebuild normal equations from window
221        let n = features.len();
222        let mut xtx = Array2::zeros((n, n));
223        let mut xty = Array1::zeros(n);
224
225        for (x, y) in &self.window_data {
226            for i in 0..n {
227                for j in 0..n {
228                    xtx[[i, j]] += x[i] * x[j];
229                }
230                xty[i] += x[i] * y;
231            }
232        }
233
234        // Add regularization
235        for i in 0..n {
236            xtx[[i, i]] += self.config.regularization;
237        }
238
239        // Solve normal equations
240        match scirs2_linalg::solve(&xtx.view(), &xty.view(), None) {
241            Ok(solution) => {
242                self.parameters = solution;
243                // Update covariance as pseudo-inverse of X^T X
244                match scirs2_linalg::compat::pinv(&xtx.view(), None, false, true) {
245                    Ok(pinv) => self.covariance = pinv,
246                    Err(_) => {} // Keep old covariance if inversion fails
247                }
248            }
249            Err(_) => {
250                // Fall back to RLS update if linear system fails
251                self.update_rls(features, target)?;
252            }
253        }
254
255        Ok(())
256    }
257
258    /// Adaptive parameter tuning based on performance
259    fn adapt_parameters(&mut self) {
260        // Adapt forgetting factor based on recent performance
261        if self.stats.points_processed > 10 {
262            let recent_loss_trend = self.stats.current_loss - self.stats.average_loss;
263
264            if recent_loss_trend > 0.0 {
265                // Performance is getting worse, reduce forgetting factor (adapt faster)
266                self.forgetting_factor = (self.forgetting_factor * 0.95).max(0.5);
267            } else {
268                // Performance is stable, increase forgetting factor (adapt slower)
269                self.forgetting_factor = (self.forgetting_factor * 1.01).min(0.999);
270            }
271        }
272
273        // Adapt noise parameters for Kalman filter
274        if matches!(self.method, RealTimeMethod::KalmanFilter) {
275            let param_change_rate = if self.stats.points_processed > 1 {
276                // Estimate parameter change rate from recent updates
277                // This is simplified - in practice, track parameter history
278                self.parameters.mapv(|x| x.abs()).sum() / self.stats.points_processed as f64
279            } else {
280                1e-6
281            };
282
283            self.process_noise = (param_change_rate * 0.1).max(1e-8).min(1e-3);
284        }
285    }
286
287    /// Check if update should be skipped due to time constraints
288    fn should_skip_for_timing(&self, starttime: Instant) -> bool {
289        starttime.elapsed() > self.max_processing_time
290    }
291}
292
293impl<T: StreamingObjective + Clone> StreamingOptimizer for RealTimeEstimator<T> {
294    fn update(&mut self, datapoint: &StreamingDataPoint) -> Result<()> {
295        let start_time = Instant::now();
296
297        // Skip update if timing constraints are violated
298        if self.should_skip_for_timing(start_time) {
299            return Ok(());
300        }
301
302        let old_parameters = self.parameters.clone();
303        let features = &datapoint.features;
304        let target = datapoint.target;
305
306        // Apply method-specific update
307        match self.method {
308            RealTimeMethod::RecursiveLeastSquares => {
309                self.update_rls(&features.view(), target)?;
310            }
311            RealTimeMethod::ExponentiallyWeightedRLS => {
312                self.update_ewrls(&features.view(), target)?;
313            }
314            RealTimeMethod::KalmanFilter => {
315                self.update_kalman(&features.view(), target)?;
316            }
317            RealTimeMethod::SlidingWindowRLS => {
318                self.update_sliding_window_rls(&features.view(), target)?;
319            }
320        }
321
322        // Adaptive parameter tuning
323        if self.stats.points_processed.is_multiple_of(20) {
324            self.adapt_parameters();
325        }
326
327        // Update statistics
328        let loss = self.objective.evaluate(&self.parameters.view(), datapoint);
329        self.stats.points_processed += 1;
330        self.stats.updates_performed += 1;
331        self.stats.current_loss = loss;
332        self.stats.average_loss = utils::ewma_update(
333            self.stats.average_loss,
334            loss,
335            0.05, // Faster adaptation for real-time
336        );
337
338        // Check convergence
339        self.stats.converged = utils::check_convergence(
340            &old_parameters.view(),
341            &self.parameters.view(),
342            self.config.tolerance,
343        );
344
345        self.stats.processing_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;
346        self.last_update_time = Some(start_time);
347
348        Ok(())
349    }
350
351    fn parameters(&self) -> &Array1<f64> {
352        &self.parameters
353    }
354
355    fn stats(&self) -> &StreamingStats {
356        &self.stats
357    }
358
359    fn reset(&mut self) {
360        let n = self.parameters.len();
361        self.covariance = Array2::eye(n) * 1000.0; // Reset with high uncertainty
362        self.forgetting_factor = self.config.forgetting_factor;
363        self.last_update_time = None;
364        self.window_data.clear();
365        self.stats = StreamingStats::default();
366    }
367}
368
369/// Create a Recursive Least Squares estimator
370#[allow(dead_code)]
371pub fn recursive_least_squares<T: StreamingObjective>(
372    initial_parameters: Array1<f64>,
373    objective: T,
374    config: Option<StreamingConfig>,
375    initial_uncertainty: Option<f64>,
376) -> RealTimeEstimator<T> {
377    let config = config.unwrap_or_default();
378    let uncertainty = initial_uncertainty.unwrap_or(1000.0);
379
380    RealTimeEstimator::new(
381        initial_parameters,
382        objective,
383        config,
384        RealTimeMethod::RecursiveLeastSquares,
385        uncertainty,
386    )
387}
388
389/// Create an Exponentially Weighted RLS estimator
390#[allow(dead_code)]
391pub fn exponentially_weighted_rls<T: StreamingObjective>(
392    initial_parameters: Array1<f64>,
393    objective: T,
394    config: Option<StreamingConfig>,
395    forgetting_factor: Option<f64>,
396) -> RealTimeEstimator<T> {
397    let mut config = config.unwrap_or_default();
398    if let Some(ff) = forgetting_factor {
399        config.forgetting_factor = ff;
400    }
401
402    RealTimeEstimator::new(
403        initial_parameters,
404        objective,
405        config,
406        RealTimeMethod::ExponentiallyWeightedRLS,
407        100.0,
408    )
409}
410
411/// Create a Kalman Filter estimator
412#[allow(dead_code)]
413pub fn kalman_filter_estimator<T: StreamingObjective>(
414    initial_parameters: Array1<f64>,
415    objective: T,
416    config: Option<StreamingConfig>,
417    process_noise: Option<f64>,
418    measurement_noise: Option<f64>,
419) -> RealTimeEstimator<T> {
420    let config = config.unwrap_or_default();
421    let mut estimator = RealTimeEstimator::new(
422        initial_parameters,
423        objective,
424        config,
425        RealTimeMethod::KalmanFilter,
426        1.0,
427    );
428
429    if let Some(pn) = process_noise {
430        estimator.process_noise = pn;
431    }
432    if let Some(mn) = measurement_noise {
433        estimator.measurement_noise = mn;
434    }
435
436    estimator
437}
438
439/// Convenience function for real-time linear regression
440#[allow(dead_code)]
441pub fn real_time_linear_regression(
442    n_features: usize,
443    method: RealTimeMethod,
444    config: Option<StreamingConfig>,
445) -> RealTimeEstimator<super::LinearRegressionObjective> {
446    let config = config.unwrap_or_default();
447    let initial_params = Array1::zeros(n_features);
448    let objective = super::LinearRegressionObjective;
449
450    RealTimeEstimator::new(initial_params, objective, config, method, 100.0)
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::{LinearRegressionObjective, StreamingDataPoint};

    #[test]
    fn test_rls_creation() {
        let estimator =
            recursive_least_squares(Array1::zeros(2), LinearRegressionObjective, None, None);

        assert_eq!(estimator.parameters().len(), 2);
        assert!(matches!(
            estimator.method,
            RealTimeMethod::RecursiveLeastSquares
        ));
    }

    #[test]
    fn test_rls_update() {
        let mut estimator =
            real_time_linear_regression(2, RealTimeMethod::RecursiveLeastSquares, None);

        let features = Array1::from(vec![1.0, 2.0]);
        let target = 3.0;
        let point = StreamingDataPoint::new(features, target);

        assert!(estimator.update(&point).is_ok());
        assert_eq!(estimator.stats().points_processed, 1);
    }

    #[test]
    fn test_ewrls_adaptation() {
        let mut config = StreamingConfig::default();
        config.forgetting_factor = 0.9;

        let mut estimator = exponentially_weighted_rls(
            Array1::zeros(2),
            LinearRegressionObjective,
            Some(config),
            None,
        );

        // Process several data points
        for i in 0..10 {
            let features = Array1::from(vec![i as f64, 1.0]);
            let target = 2.0 * i as f64 + 1.0;
            let point = StreamingDataPoint::new(features, target);

            estimator.update(&point).unwrap();
        }

        assert_eq!(estimator.stats().points_processed, 10);
        assert!(estimator.stats().current_loss.is_finite());
    }

    #[test]
    fn test_kalman_filter() {
        let mut estimator = kalman_filter_estimator(
            Array1::zeros(2),
            LinearRegressionObjective,
            None,
            Some(1e-6),
            Some(1e-3),
        );

        // Noisy observations of y = 2*x0 + 3*x1
        let data_points = [
            StreamingDataPoint::new(Array1::from(vec![1.0, 0.0]), 2.1),
            StreamingDataPoint::new(Array1::from(vec![0.0, 1.0]), 2.9),
            StreamingDataPoint::new(Array1::from(vec![1.0, 1.0]), 5.1),
        ];

        for point in &data_points {
            estimator.update(point).unwrap();
        }

        assert_eq!(estimator.stats().points_processed, 3);

        // Parameters should be close to [2, 3] despite noise
        let params = estimator.parameters();
        assert!((params[0] - 2.0).abs() < 1.0);
        assert!((params[1] - 3.0).abs() < 1.0);
    }

    #[test]
    fn test_sliding_window_rls() {
        let mut estimator = real_time_linear_regression(2, RealTimeMethod::SlidingWindowRLS, None);

        // Add data points to exceed window size
        for i in 0..15 {
            let features = Array1::from(vec![i as f64, 1.0]);
            let target = 2.0 * i as f64;
            let point = StreamingDataPoint::new(features, target);

            estimator.update(&point).unwrap();
        }

        // Should have processed all points but window is limited
        assert_eq!(estimator.stats().points_processed, 15);
        assert!(estimator.window_data.len() <= estimator.config.window_size);
    }

    #[test]
    fn test_covariance_updates() {
        let mut estimator = recursive_least_squares(
            Array1::zeros(2),
            LinearRegressionObjective,
            None,
            Some(100.0),
        );

        let initial_covariance = estimator.covariance.clone();

        let features = Array1::from(vec![1.0, 1.0]);
        let target = 1.0;
        let point = StreamingDataPoint::new(features, target);

        estimator.update(&point).unwrap();

        // Covariance should change after update
        assert_ne!(estimator.covariance, initial_covariance);

        // Diagonal elements should decrease (uncertainty reduction)
        assert!(estimator.covariance[[0, 0]] < initial_covariance[[0, 0]]);
        assert!(estimator.covariance[[1, 1]] < initial_covariance[[1, 1]]);
    }

    #[test]
    fn test_real_time_constraints() {
        let mut estimator =
            real_time_linear_regression(2, RealTimeMethod::RecursiveLeastSquares, None);

        // Set very tight timing constraint
        estimator.max_processing_time = Duration::from_nanos(1);

        let features = Array1::from(vec![1.0, 2.0]);
        let target = 3.0;
        let point = StreamingDataPoint::new(features, target);

        // Update should complete quickly (might skip processing due to timing)
        let start = Instant::now();
        estimator.update(&point).unwrap();
        let elapsed = start.elapsed();

        // Should not take more than a reasonable amount of time
        assert!(elapsed < Duration::from_millis(100));
    }

    #[test]
    fn test_parameter_adaptation() {
        let mut estimator = exponentially_weighted_rls(
            Array1::zeros(2),
            LinearRegressionObjective,
            None,
            Some(0.95),
        );

        // The explicit forgetting factor overrides the config default.
        let initial_ff = estimator.forgetting_factor;
        assert!((initial_ff - 0.95).abs() < 1e-12);

        // Process many points to trigger adaptation
        for i in 0..50 {
            let features = Array1::from(vec![i as f64, 1.0]);
            let target = i as f64; // Potentially changing relationship
            let point = StreamingDataPoint::new(features, target);

            estimator.update(&point).unwrap();
        }

        assert_eq!(estimator.stats().points_processed, 50);

        // Whatever the loss trend, adaptation keeps the forgetting factor in [0.5, 0.999].
        assert!((0.5..=0.999).contains(&estimator.forgetting_factor));
    }
}
625}