scirs2_optimize/streaming/rolling_window.rs

//! Rolling Window Optimization
//!
//! This module implements optimization algorithms that operate over sliding windows
//! of streaming data. These methods are useful for non-stationary optimization
//! problems where recent data should have more influence than older data.
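//!
//! A minimal usage sketch, assuming the factory functions and `StreamingDataPoint`
//! are in scope as in the tests at the bottom of this file:
//!
//! ```ignore
//! use ndarray::Array1;
//!
//! // Rolling-window least squares over the 10 most recent points (2 features).
//! let mut opt = rolling_window_linear_regression(2, 10, false, None);
//! let point = StreamingDataPoint::new(Array1::from(vec![1.0, 0.0]), 2.0);
//! opt.update(&point).unwrap();
//! let fitted = opt.parameters();
//! ```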

use super::{
    utils, StreamingConfig, StreamingDataPoint, StreamingObjective, StreamingOptimizer,
    StreamingStats,
};
use crate::error::OptimizeError;
use ndarray::{Array1, Array2};
use std::collections::VecDeque;

type Result<T> = std::result::Result<T, OptimizeError>;

/// Rolling window optimizer that maintains a sliding window of recent data
#[derive(Debug, Clone)]
pub struct RollingWindowOptimizer<T: StreamingObjective> {
    /// Current parameter estimates
    parameters: Array1<f64>,
    /// Objective function
    objective: T,
    /// Configuration
    config: StreamingConfig,
    /// Statistics
    stats: StreamingStats,
    /// Sliding window of data points
    data_window: VecDeque<StreamingDataPoint>,
    /// Window-based optimizer (could be any StreamingOptimizer)
    window_optimizer: WindowOptimizerType,
    /// Whether to refit on every update or use incremental updates
    refit_every_update: bool,
    /// How often to refit (if not every update)
    refit_frequency: usize,
    /// Counter for refit frequency
    update_counter: usize,
}

/// Types of optimizers that can be used within the rolling window
#[derive(Debug, Clone)]
enum WindowOptimizerType {
    /// Gradient descent with accumulated gradients
    GradientDescent {
        gradient_accumulator: Array1<f64>,
        learning_rate: f64,
    },
    /// Least squares solution (for linear problems)
    LeastSquares {
        /// X^T X matrix
        xtx: Array2<f64>,
        /// X^T y vector
        xty: Array1<f64>,
        /// Regularization parameter
        regularization: f64,
    },
    /// Weighted least squares with exponential decay
    WeightedLeastSquares {
        /// Weighted X^T X matrix
        weighted_xtx: Array2<f64>,
        /// Weighted X^T y vector
        weighted_xty: Array1<f64>,
        /// Regularization parameter
        regularization: f64,
        /// Decay factor for weights
        decay_factor: f64,
    },
}

impl<T: StreamingObjective> RollingWindowOptimizer<T> {
    /// Create a new rolling window optimizer
    pub fn new(
        initial_parameters: Array1<f64>,
        objective: T,
        config: StreamingConfig,
        window_optimizer_type: WindowOptimizerType,
        refit_every_update: bool,
    ) -> Self {
        let window_size = config.window_size;
        Self {
            parameters: initial_parameters,
            objective,
            config,
            stats: StreamingStats::default(),
            data_window: VecDeque::with_capacity(window_size),
            window_optimizer: window_optimizer_type,
            refit_every_update,
            refit_frequency: (window_size / 4).max(1), // Default: refit every quarter window (at least 1)
            update_counter: 0,
        }
    }

    /// Add a data point to the window and remove old ones if necessary
    fn update_window(&mut self, datapoint: StreamingDataPoint) {
        if self.data_window.len() >= self.config.window_size {
            self.data_window.pop_front();
        }
        self.data_window.push_back(datapoint);
    }

    /// Optimize parameters based on current window contents
    fn optimize_window(&mut self) -> Result<()> {
        if self.data_window.is_empty() {
            return Ok(());
        }

        // Extract optimizer type and data temporarily to avoid borrowing conflicts
        let mut temp_optimizer = std::mem::replace(
            &mut self.window_optimizer,
            WindowOptimizerType::GradientDescent {
                gradient_accumulator: Array1::zeros(0),
                learning_rate: 0.01,
            },
        );

        let result = match &mut temp_optimizer {
            WindowOptimizerType::GradientDescent {
                gradient_accumulator,
                learning_rate,
            } => {
                let learning_rate = *learning_rate;
                self.optimize_gradient_descent(gradient_accumulator, learning_rate)
            }
            WindowOptimizerType::LeastSquares {
                xtx,
                xty,
                regularization,
            } => {
                let regularization = *regularization;
                self.optimize_least_squares(xtx, xty, regularization)
            }
            WindowOptimizerType::WeightedLeastSquares {
                weighted_xtx,
                weighted_xty,
                regularization,
                decay_factor,
            } => {
                let regularization = *regularization;
                let decay_factor = *decay_factor;
                self.optimize_weighted_least_squares(
                    weighted_xtx,
                    weighted_xty,
                    regularization,
                    decay_factor,
                )
            }
        };

        // Restore the optimizer
        self.window_optimizer = temp_optimizer;
        result
    }

    /// Gradient descent optimization over the window
    fn optimize_gradient_descent(
        &mut self,
        gradient_accumulator: &mut Array1<f64>,
        learning_rate: f64,
    ) -> Result<()> {
        gradient_accumulator.fill(0.0);
        let mut total_weight = 0.0;

        // Accumulate gradients from all points in window
        for (i, data_point) in self.data_window.iter().enumerate() {
            let gradient = self.objective.gradient(&self.parameters.view(), data_point);
            let weight = data_point.weight.unwrap_or(1.0);

            // Apply temporal weighting (more recent data gets higher weight)
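            // e.g. with forgetting_factor = 0.9 and a 3-point window, the temporal
            // weights are 0.81, 0.9, 1.0 (oldest -> newest)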
            let temporal_weight = self
                .config
                .forgetting_factor
                .powi((self.data_window.len() - 1 - i) as i32);
            let effective_weight = weight * temporal_weight;

            *gradient_accumulator = &*gradient_accumulator + &(effective_weight * &gradient);
            total_weight += effective_weight;
        }

        if total_weight > 0.0 {
            *gradient_accumulator /= total_weight;

            // Apply gradient descent update
            self.parameters = &self.parameters - &(&*gradient_accumulator * learning_rate);
        }

        Ok(())
    }

    /// Least squares optimization over the window (for linear objectives)
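    ///
    /// Builds and solves the regularized normal equations (X^T X + λ I) θ = X^T y
    /// accumulated from all points currently in the window.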
    fn optimize_least_squares(
        &mut self,
        xtx: &mut Array2<f64>,
        xty: &mut Array1<f64>,
        regularization: f64,
    ) -> Result<()> {
        let n_features = self.parameters.len();
        xtx.fill(0.0);
        xty.fill(0.0);

        // Build normal equations from window data
        for data_point in &self.data_window {
            let x = &data_point.features;
            let y = data_point.target;
            let weight = data_point.weight.unwrap_or(1.0);

            // X^T X accumulation
            for i in 0..n_features {
                for j in 0..n_features {
                    xtx[[i, j]] += weight * x[i] * x[j];
                }
                // X^T y accumulation
                xty[i] += weight * x[i] * y;
            }
        }

        // Add regularization
        for i in 0..n_features {
            xtx[[i, i]] += regularization;
        }

        // Solve normal equations
        match scirs2_linalg::solve(&xtx.view(), &xty.view(), None) {
            Ok(solution) => {
                self.parameters = solution;
                Ok(())
            }
            Err(_) => {
                // Fall back to gradient descent if linear system fails
                let mut dummy_grad = Array1::zeros(n_features);
                self.optimize_gradient_descent(&mut dummy_grad, self.config.learning_rate)
            }
        }
    }

    /// Weighted least squares with exponential decay
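    ///
    /// Builds and solves (X^T W X + λ I) θ = X^T W y, where the diagonal weights in W
    /// combine each point's own weight with an exponential decay in its age, so the
    /// newest points influence the solution most.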
    fn optimize_weighted_least_squares(
        &mut self,
        weighted_xtx: &mut Array2<f64>,
        weighted_xty: &mut Array1<f64>,
        regularization: f64,
        decay_factor: f64,
    ) -> Result<()> {
        let n_features = self.parameters.len();
        weighted_xtx.fill(0.0);
        weighted_xty.fill(0.0);

        // Build weighted normal equations
        for (i, data_point) in self.data_window.iter().enumerate() {
            let x = &data_point.features;
            let y = data_point.target;
            let base_weight = data_point.weight.unwrap_or(1.0);

            // Exponential decay weighting (more recent data gets higher weight)
            let age = self.data_window.len() - 1 - i;
            let temporal_weight = decay_factor.powi(age as i32);
            let total_weight = base_weight * temporal_weight;

            // Weighted X^T X accumulation
            for j in 0..n_features {
                for k in 0..n_features {
                    weighted_xtx[[j, k]] += total_weight * x[j] * x[k];
                }
                // Weighted X^T y accumulation
                weighted_xty[j] += total_weight * x[j] * y;
            }
        }

        // Add regularization
        for i in 0..n_features {
            weighted_xtx[[i, i]] += regularization;
        }

        // Solve weighted normal equations
        match scirs2_linalg::solve(&weighted_xtx.view(), &weighted_xty.view(), None) {
            Ok(solution) => {
                self.parameters = solution;
                Ok(())
            }
            Err(_) => {
                // Fall back to gradient descent
                let mut dummy_grad = Array1::zeros(n_features);
                self.optimize_gradient_descent(&mut dummy_grad, self.config.learning_rate)
            }
        }
    }

    /// Compute average loss over the current window
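    ///
    /// Losses are averaged with each point's weight (default 1.0); an empty window
    /// yields infinity.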
    fn compute_window_loss(&self) -> f64 {
        if self.data_window.is_empty() {
            return f64::INFINITY;
        }

        let mut total_loss = 0.0;
        let mut total_weight = 0.0;

        for data_point in &self.data_window {
            let loss = self.objective.evaluate(&self.parameters.view(), data_point);
            let weight = data_point.weight.unwrap_or(1.0);
            total_loss += weight * loss;
            total_weight += weight;
        }

        if total_weight > 0.0 {
            total_loss / total_weight
        } else {
            f64::INFINITY
        }
    }

    /// Check convergence based on window statistics
    fn check_window_convergence(&self) -> bool {
        if self.data_window.len() < 2 {
            return false;
        }

        // Check if parameters are stable across recent window updates
        // This is a simplified convergence check - in practice, we'd track
        // parameter history across multiple window updates
        self.stats.average_loss.is_finite() && self.stats.average_loss < self.config.tolerance
    }
}

impl<T: StreamingObjective + Clone> StreamingOptimizer for RollingWindowOptimizer<T> {
    fn update(&mut self, datapoint: &StreamingDataPoint) -> Result<()> {
        let start_time = std::time::Instant::now();
        let old_parameters = self.parameters.clone();

        // Add data point to window
        self.update_window(datapoint.clone());
        self.update_counter += 1;

        // Decide whether to reoptimize
        let should_reoptimize =
            self.refit_every_update || (self.update_counter % self.refit_frequency == 0);

        if should_reoptimize {
            // Reoptimize based on current window
            self.optimize_window()?;
            self.stats.updates_performed += 1;
        }

        // Update statistics
        self.stats.points_processed += 1;
        self.stats.current_loss = self.compute_window_loss();
        self.stats.average_loss = utils::ewma_update(
            self.stats.average_loss,
            self.stats.current_loss,
            0.1, // Use higher smoothing for window-based methods
        );

        // Check convergence
        self.stats.converged = utils::check_convergence(
            &old_parameters.view(),
            &self.parameters.view(),
            self.config.tolerance,
        ) || self.check_window_convergence();

        self.stats.processing_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;

        Ok(())
    }

    fn parameters(&self) -> &Array1<f64> {
        &self.parameters
    }

    fn stats(&self) -> &StreamingStats {
        &self.stats
    }

    fn reset(&mut self) {
        self.data_window.clear();
        self.update_counter = 0;
        self.stats = StreamingStats::default();

        // Reset window optimizer state
        match &mut self.window_optimizer {
            WindowOptimizerType::GradientDescent {
                gradient_accumulator,
                ..
            } => {
                gradient_accumulator.fill(0.0);
            }
            WindowOptimizerType::LeastSquares { xtx, xty, .. } => {
                xtx.fill(0.0);
                xty.fill(0.0);
            }
            WindowOptimizerType::WeightedLeastSquares {
                weighted_xtx,
                weighted_xty,
                ..
            } => {
                weighted_xtx.fill(0.0);
                weighted_xty.fill(0.0);
            }
        }
    }
}

/// Create a rolling window optimizer with gradient descent
#[allow(dead_code)]
pub fn rolling_window_gradient_descent<T: StreamingObjective>(
    initial_parameters: Array1<f64>,
    objective: T,
    config: StreamingConfig,
    learning_rate: Option<f64>,
) -> RollingWindowOptimizer<T> {
    let lr = learning_rate.unwrap_or(config.learning_rate);
    let n_params = initial_parameters.len();
    let optimizer_type = WindowOptimizerType::GradientDescent {
        gradient_accumulator: Array1::zeros(n_params),
        learning_rate: lr,
    };

    RollingWindowOptimizer::new(initial_parameters, objective, config, optimizer_type, false)
}

/// Create a rolling window optimizer with least squares (for linear problems)
#[allow(dead_code)]
pub fn rolling_window_least_squares<T: StreamingObjective>(
    initial_parameters: Array1<f64>,
    objective: T,
    config: StreamingConfig,
    regularization: Option<f64>,
) -> RollingWindowOptimizer<T> {
    let reg = regularization.unwrap_or(config.regularization);
    let n_params = initial_parameters.len();
    let optimizer_type = WindowOptimizerType::LeastSquares {
        xtx: Array2::zeros((n_params, n_params)),
        xty: Array1::zeros(n_params),
        regularization: reg,
    };

    RollingWindowOptimizer::new(initial_parameters, objective, config, optimizer_type, true)
}

/// Create a rolling window optimizer with weighted least squares
#[allow(dead_code)]
pub fn rolling_window_weighted_least_squares<T: StreamingObjective>(
    initial_parameters: Array1<f64>,
    objective: T,
    config: StreamingConfig,
    regularization: Option<f64>,
    decay_factor: Option<f64>,
) -> RollingWindowOptimizer<T> {
    let reg = regularization.unwrap_or(config.regularization);
    let decay = decay_factor.unwrap_or(config.forgetting_factor);
    let n_params = initial_parameters.len();
    let optimizer_type = WindowOptimizerType::WeightedLeastSquares {
        weighted_xtx: Array2::zeros((n_params, n_params)),
        weighted_xty: Array1::zeros(n_params),
        regularization: reg,
        decay_factor: decay,
    };

    RollingWindowOptimizer::new(initial_parameters, objective, config, optimizer_type, true)
}

/// Convenience function for rolling window linear regression
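///
/// When `use_weighted` is true this uses weighted least squares with exponential decay;
/// otherwise it uses plain regularized least squares. Parameters start at zero and the
/// supplied `window_size` overrides `config.window_size`.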
#[allow(dead_code)]
pub fn rolling_window_linear_regression(
    n_features: usize,
    window_size: usize,
    use_weighted: bool,
    config: Option<StreamingConfig>,
) -> RollingWindowOptimizer<super::LinearRegressionObjective> {
    let mut config = config.unwrap_or_default();
    config.window_size = window_size;

    let initial_params = Array1::zeros(n_features);
    let objective = super::LinearRegressionObjective;

    if use_weighted {
        rolling_window_weighted_least_squares(initial_params, objective, config, None, None)
    } else {
        rolling_window_least_squares(initial_params, objective, config, None)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::{LinearRegressionObjective, StreamingDataPoint};

    #[test]
    fn test_rolling_window_creation() {
        let optimizer = rolling_window_linear_regression(2, 10, false, None);
        assert!(optimizer.data_window.capacity() >= 10);
        assert_eq!(optimizer.parameters().len(), 2);
    }

    #[test]
    fn test_window_update() {
        let mut optimizer = rolling_window_linear_regression(2, 3, false, None);

        // Add data points to fill window
        for i in 0..5 {
            let features = Array1::from(vec![i as f64, (i + 1) as f64]);
            let target = (2 * i + 1) as f64;
            let point = StreamingDataPoint::new(features, target);

            optimizer.update(&point).unwrap();
        }

        // Window should be at capacity
        assert_eq!(optimizer.data_window.len(), 3);
        assert_eq!(optimizer.stats().points_processed, 5);
    }

    #[test]
    fn test_gradient_descent_window() {
        let config = StreamingConfig {
            window_size: 5,
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = rolling_window_gradient_descent(
            Array1::zeros(2),
            LinearRegressionObjective,
            config,
            None,
        );

        // Add some data points
        let data_points = vec![
            StreamingDataPoint::new(Array1::from(vec![1.0, 0.0]), 2.0),
            StreamingDataPoint::new(Array1::from(vec![0.0, 1.0]), 3.0),
            StreamingDataPoint::new(Array1::from(vec![1.0, 1.0]), 5.0),
        ];

        for point in &data_points {
            optimizer.update(point).unwrap();
        }

        assert_eq!(optimizer.stats().points_processed, 3);
        assert!(optimizer.stats().updates_performed > 0);
    }

    #[test]
    fn test_least_squares_window() {
        let mut optimizer = rolling_window_linear_regression(2, 10, false, None);

        // Generate data for y = 2*x1 + 3*x2
        let data_points = vec![
            StreamingDataPoint::new(Array1::from(vec![1.0, 0.0]), 2.0),
            StreamingDataPoint::new(Array1::from(vec![0.0, 1.0]), 3.0),
            StreamingDataPoint::new(Array1::from(vec![1.0, 1.0]), 5.0),
            StreamingDataPoint::new(Array1::from(vec![2.0, 1.0]), 7.0),
        ];

        for point in &data_points {
            optimizer.update(point).unwrap();
        }

        // Parameters should be close to [2, 3] for exact linear data
        let params = optimizer.parameters();
        assert!(
            (params[0] - 2.0).abs() < 0.1,
            "First parameter: {}",
            params[0]
        );
        assert!(
            (params[1] - 3.0).abs() < 0.1,
            "Second parameter: {}",
            params[1]
        );
    }

    #[test]
    fn test_weighted_least_squares_window() {
        let mut optimizer = rolling_window_linear_regression(2, 10, true, None);

        // Add data points with some having higher weights implicitly through recency
        let data_points = vec![
            StreamingDataPoint::new(Array1::from(vec![1.0, 0.0]), 2.0),
            StreamingDataPoint::new(Array1::from(vec![0.0, 1.0]), 3.0),
            StreamingDataPoint::new(Array1::from(vec![1.0, 1.0]), 5.0),
        ];

        for point in &data_points {
            optimizer.update(point).unwrap();
        }

        assert_eq!(optimizer.stats().points_processed, 3);
        assert!(optimizer.stats().current_loss.is_finite());
    }

    #[test]
    fn test_window_overflow() {
        let mut optimizer = rolling_window_linear_regression(2, 2, false, None);

        // Add more points than window size
        for i in 0..5 {
            let features = Array1::from(vec![i as f64, 1.0]);
            let target = i as f64;
            let point = StreamingDataPoint::new(features, target);

            optimizer.update(&point).unwrap();
        }

        // Window should be at capacity, not larger
        assert_eq!(optimizer.data_window.len(), 2);
        assert_eq!(optimizer.stats().points_processed, 5);
    }

    #[test]
    fn test_window_reset() {
        let mut optimizer = rolling_window_linear_regression(2, 5, false, None);

        // Add some data
        let point = StreamingDataPoint::new(Array1::from(vec![1.0, 2.0]), 3.0);
        optimizer.update(&point).unwrap();

        assert_eq!(optimizer.data_window.len(), 1);
        assert_eq!(optimizer.stats().points_processed, 1);

        // Reset should clear everything
        optimizer.reset();
        assert_eq!(optimizer.data_window.len(), 0);
        assert_eq!(optimizer.stats().points_processed, 0);
    }
}