use super::{
    utils, StreamingConfig, StreamingDataPoint, StreamingObjective, StreamingOptimizer,
    StreamingStats,
};
use crate::error::OptimizeError;
use ndarray::{Array1, Array2};
use std::collections::VecDeque;

type Result<T> = std::result::Result<T, OptimizeError>;
/// Rolling-window optimizer that periodically refits its parameters over a
/// sliding window of the most recent data points.
#[derive(Debug, Clone)]
pub struct RollingWindowOptimizer<T: StreamingObjective> {
    /// Current parameter estimates
    parameters: Array1<f64>,
    /// Streaming objective being minimized
    objective: T,
    /// Streaming configuration (window size, learning rate, tolerance, ...)
    config: StreamingConfig,
    /// Running statistics for the stream
    stats: StreamingStats,
    /// Sliding window of the most recent data points
    data_window: VecDeque<StreamingDataPoint>,
    /// Strategy used to refit parameters over the window
    window_optimizer: WindowOptimizerType,
    /// If true, refit after every update; otherwise refit every
    /// `refit_frequency` updates
    refit_every_update: bool,
    /// Number of updates between refits when `refit_every_update` is false
    refit_frequency: usize,
    /// Updates seen since construction or the last reset
    update_counter: usize,
}

/// Refitting strategy applied to the data in the window.
#[derive(Debug, Clone)]
enum WindowOptimizerType {
    /// Weighted gradient descent over the window
    GradientDescent {
        /// Reusable buffer for the accumulated gradient
        gradient_accumulator: Array1<f64>,
        /// Step size for the parameter update
        learning_rate: f64,
    },
    /// Ridge-regularized least squares over the window
    LeastSquares {
        /// Accumulated X^T X matrix
        xtx: Array2<f64>,
        /// Accumulated X^T y vector
        xty: Array1<f64>,
        /// Ridge regularization strength
        regularization: f64,
    },
    /// Least squares with exponential temporal decay on point weights
    WeightedLeastSquares {
        /// Accumulated, decay-weighted X^T X matrix
        weighted_xtx: Array2<f64>,
        /// Accumulated, decay-weighted X^T y vector
        weighted_xty: Array1<f64>,
        /// Ridge regularization strength
        regularization: f64,
        /// Exponential decay applied per step of data age
        decay_factor: f64,
    },
}
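
// For reference: both least-squares strategies refit by solving the
// ridge-regularized normal equations over the current window,
//
//     (X^T W X + lambda * I) theta = X^T W y,
//
// where each diagonal entry of W is the point's base weight (additionally
// multiplied by decay_factor^age for the weighted variant) and lambda is
// the `regularization` field. This restates what the accumulation loops
// below implement; it is not taken from any external documentation.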

impl<T: StreamingObjective> RollingWindowOptimizer<T> {
    /// Create a new rolling-window optimizer.
    pub fn new(
        initial_parameters: Array1<f64>,
        objective: T,
        config: StreamingConfig,
        window_optimizer_type: WindowOptimizerType,
        refit_every_update: bool,
    ) -> Self {
        let window_size = config.window_size;
        Self {
            parameters: initial_parameters,
            objective,
            config,
            stats: StreamingStats::default(),
            data_window: VecDeque::with_capacity(window_size),
            window_optimizer: window_optimizer_type,
            refit_every_update,
            // Refit every quarter window by default; clamp to at least 1 so
            // the modulo in `update` never divides by zero for small windows.
            refit_frequency: (window_size / 4).max(1),
            update_counter: 0,
        }
    }

    /// Push a new point into the window, evicting the oldest once full.
    fn update_window(&mut self, datapoint: StreamingDataPoint) {
        if self.data_window.len() >= self.config.window_size {
            self.data_window.pop_front();
        }
        self.data_window.push_back(datapoint);
    }

    /// Refit parameters using all data currently in the window.
    fn optimize_window(&mut self) -> Result<()> {
        if self.data_window.is_empty() {
            return Ok(());
        }

        // Temporarily move the strategy out of `self` so we can call
        // `&mut self` helper methods while mutably borrowing the strategy's
        // buffers; the cheap placeholder is swapped back below.
        let mut temp_optimizer = std::mem::replace(
            &mut self.window_optimizer,
            WindowOptimizerType::GradientDescent {
                gradient_accumulator: Array1::zeros(0),
                learning_rate: 0.01,
            },
        );

        let result = match &mut temp_optimizer {
            WindowOptimizerType::GradientDescent {
                gradient_accumulator,
                learning_rate,
            } => {
                let learning_rate = *learning_rate;
                self.optimize_gradient_descent(gradient_accumulator, learning_rate)
            }
            WindowOptimizerType::LeastSquares {
                xtx,
                xty,
                regularization,
            } => {
                let regularization = *regularization;
                self.optimize_least_squares(xtx, xty, regularization)
            }
            WindowOptimizerType::WeightedLeastSquares {
                weighted_xtx,
                weighted_xty,
                regularization,
                decay_factor,
            } => {
                let regularization = *regularization;
                let decay_factor = *decay_factor;
                self.optimize_weighted_least_squares(
                    weighted_xtx,
                    weighted_xty,
                    regularization,
                    decay_factor,
                )
            }
        };

        // Restore the (possibly mutated) strategy state.
        self.window_optimizer = temp_optimizer;
        result
    }

    /// One weighted gradient-descent step over the window.
    fn optimize_gradient_descent(
        &mut self,
        gradient_accumulator: &mut Array1<f64>,
        learning_rate: f64,
    ) -> Result<()> {
        gradient_accumulator.fill(0.0);
        let mut total_weight = 0.0;

        for (i, data_point) in self.data_window.iter().enumerate() {
            let gradient = self.objective.gradient(&self.parameters.view(), data_point);
            let weight = data_point.weight.unwrap_or(1.0);

            // Older points are down-weighted exponentially by their age.
            let temporal_weight = self
                .config
                .forgetting_factor
                .powi((self.data_window.len() - 1 - i) as i32);
            let effective_weight = weight * temporal_weight;

            *gradient_accumulator = &*gradient_accumulator + &(effective_weight * &gradient);
            total_weight += effective_weight;
        }

        if total_weight > 0.0 {
            // Average the accumulated gradient, then take a descent step.
            *gradient_accumulator /= total_weight;
            self.parameters = &self.parameters - &(&*gradient_accumulator * learning_rate);
        }

        Ok(())
    }
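
    // In math form, the step above computes
    //
    //     g = ( sum_i w_i * gamma^{a_i} * grad f(theta; x_i) )
    //         / ( sum_i w_i * gamma^{a_i} ),
    //     theta <- theta - eta * g,
    //
    // where gamma is `config.forgetting_factor`, a_i is the point's age
    // within the window, and eta is the learning rate.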

    /// Refit via ridge-regularized least squares over the window.
    fn optimize_least_squares(
        &mut self,
        xtx: &mut Array2<f64>,
        xty: &mut Array1<f64>,
        regularization: f64,
    ) -> Result<()> {
        let n_features = self.parameters.len();
        xtx.fill(0.0);
        xty.fill(0.0);

        // Accumulate the weighted normal equations X^T W X and X^T W y.
        for data_point in &self.data_window {
            let x = &data_point.features;
            let y = data_point.target;
            let weight = data_point.weight.unwrap_or(1.0);

            for i in 0..n_features {
                for j in 0..n_features {
                    xtx[[i, j]] += weight * x[i] * x[j];
                }
                xty[i] += weight * x[i] * y;
            }
        }

        // Ridge regularization keeps the system well conditioned.
        for i in 0..n_features {
            xtx[[i, i]] += regularization;
        }

        match scirs2_linalg::solve(&xtx.view(), &xty.view(), None) {
            Ok(solution) => {
                self.parameters = solution;
                Ok(())
            }
            Err(_) => {
                // If the solve fails (e.g., a singular system), fall back to
                // a gradient-descent step over the window.
                let mut dummy_grad = Array1::zeros(n_features);
                self.optimize_gradient_descent(&mut dummy_grad, self.config.learning_rate)
            }
        }
    }

    /// Refit via least squares with exponential temporal decay on weights.
    fn optimize_weighted_least_squares(
        &mut self,
        weighted_xtx: &mut Array2<f64>,
        weighted_xty: &mut Array1<f64>,
        regularization: f64,
        decay_factor: f64,
    ) -> Result<()> {
        let n_features = self.parameters.len();
        weighted_xtx.fill(0.0);
        weighted_xty.fill(0.0);

        for (i, data_point) in self.data_window.iter().enumerate() {
            let x = &data_point.features;
            let y = data_point.target;
            let base_weight = data_point.weight.unwrap_or(1.0);

            // Older points are down-weighted exponentially by their age.
            let age = self.data_window.len() - 1 - i;
            let temporal_weight = decay_factor.powi(age as i32);
            let total_weight = base_weight * temporal_weight;

            for j in 0..n_features {
                for k in 0..n_features {
                    weighted_xtx[[j, k]] += total_weight * x[j] * x[k];
                }
                weighted_xty[j] += total_weight * x[j] * y;
            }
        }

        // Ridge regularization keeps the system well conditioned.
        for i in 0..n_features {
            weighted_xtx[[i, i]] += regularization;
        }

        match scirs2_linalg::solve(&weighted_xtx.view(), &weighted_xty.view(), None) {
            Ok(solution) => {
                self.parameters = solution;
                Ok(())
            }
            Err(_) => {
                // Fall back to a gradient-descent step if the solve fails.
                let mut dummy_grad = Array1::zeros(n_features);
                self.optimize_gradient_descent(&mut dummy_grad, self.config.learning_rate)
            }
        }
    }

    /// Weighted average loss over the current window.
    fn compute_window_loss(&self) -> f64 {
        if self.data_window.is_empty() {
            return f64::INFINITY;
        }

        let mut total_loss = 0.0;
        let mut total_weight = 0.0;

        for data_point in &self.data_window {
            let loss = self.objective.evaluate(&self.parameters.view(), data_point);
            let weight = data_point.weight.unwrap_or(1.0);
            total_loss += weight * loss;
            total_weight += weight;
        }

        if total_weight > 0.0 {
            total_loss / total_weight
        } else {
            f64::INFINITY
        }
    }

    /// Heuristic convergence check: the smoothed window loss has dropped
    /// below the configured tolerance.
    fn check_window_convergence(&self) -> bool {
        if self.data_window.len() < 2 {
            return false;
        }

        self.stats.average_loss.is_finite() && self.stats.average_loss < self.config.tolerance
    }
}

impl<T: StreamingObjective + Clone> StreamingOptimizer for RollingWindowOptimizer<T> {
    fn update(&mut self, datapoint: &StreamingDataPoint) -> Result<()> {
        let start_time = std::time::Instant::now();
        let old_parameters = self.parameters.clone();

        self.update_window(datapoint.clone());
        self.update_counter += 1;

        // Refit on every update, or on the configured schedule.
        let should_reoptimize =
            self.refit_every_update || (self.update_counter % self.refit_frequency == 0);

        if should_reoptimize {
            self.optimize_window()?;
            self.stats.updates_performed += 1;
        }

        self.stats.points_processed += 1;
        self.stats.current_loss = self.compute_window_loss();
        // Smooth the observed loss with an exponentially weighted moving
        // average (smoothing factor 0.1).
        self.stats.average_loss =
            utils::ewma_update(self.stats.average_loss, self.stats.current_loss, 0.1);

        self.stats.converged = utils::check_convergence(
            &old_parameters.view(),
            &self.parameters.view(),
            self.config.tolerance,
        ) || self.check_window_convergence();

        self.stats.processing_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;

        Ok(())
    }

    fn parameters(&self) -> &Array1<f64> {
        &self.parameters
    }

    fn stats(&self) -> &StreamingStats {
        &self.stats
    }

    fn reset(&mut self) {
        self.data_window.clear();
        self.update_counter = 0;
        self.stats = StreamingStats::default();

        // Zero the refitting strategy's scratch buffers as well.
        match &mut self.window_optimizer {
            WindowOptimizerType::GradientDescent {
                gradient_accumulator,
                ..
            } => {
                gradient_accumulator.fill(0.0);
            }
            WindowOptimizerType::LeastSquares { xtx, xty, .. } => {
                xtx.fill(0.0);
                xty.fill(0.0);
            }
            WindowOptimizerType::WeightedLeastSquares {
                weighted_xtx,
                weighted_xty,
                ..
            } => {
                weighted_xtx.fill(0.0);
                weighted_xty.fill(0.0);
            }
        }
    }
}

/// Convenience constructor for a rolling-window gradient-descent optimizer.
#[allow(dead_code)]
pub fn rolling_window_gradient_descent<T: StreamingObjective>(
    initial_parameters: Array1<f64>,
    objective: T,
    config: StreamingConfig,
    learning_rate: Option<f64>,
) -> RollingWindowOptimizer<T> {
    let lr = learning_rate.unwrap_or(config.learning_rate);
    let n_params = initial_parameters.len();
    let optimizer_type = WindowOptimizerType::GradientDescent {
        gradient_accumulator: Array1::zeros(n_params),
        learning_rate: lr,
    };

    RollingWindowOptimizer::new(initial_parameters, objective, config, optimizer_type, false)
}

/// Convenience constructor for a rolling-window least-squares optimizer.
#[allow(dead_code)]
pub fn rolling_window_least_squares<T: StreamingObjective>(
    initial_parameters: Array1<f64>,
    objective: T,
    config: StreamingConfig,
    regularization: Option<f64>,
) -> RollingWindowOptimizer<T> {
    let reg = regularization.unwrap_or(config.regularization);
    let n_params = initial_parameters.len();
    let optimizer_type = WindowOptimizerType::LeastSquares {
        xtx: Array2::zeros((n_params, n_params)),
        xty: Array1::zeros(n_params),
        regularization: reg,
    };

    RollingWindowOptimizer::new(initial_parameters, objective, config, optimizer_type, true)
}

/// Convenience constructor for a rolling-window weighted least-squares
/// optimizer with exponential temporal decay.
#[allow(dead_code)]
pub fn rolling_window_weighted_least_squares<T: StreamingObjective>(
    initial_parameters: Array1<f64>,
    objective: T,
    config: StreamingConfig,
    regularization: Option<f64>,
    decay_factor: Option<f64>,
) -> RollingWindowOptimizer<T> {
    let reg = regularization.unwrap_or(config.regularization);
    let decay = decay_factor.unwrap_or(config.forgetting_factor);
    let n_params = initial_parameters.len();
    let optimizer_type = WindowOptimizerType::WeightedLeastSquares {
        weighted_xtx: Array2::zeros((n_params, n_params)),
        weighted_xty: Array1::zeros(n_params),
        regularization: reg,
        decay_factor: decay,
    };

    RollingWindowOptimizer::new(initial_parameters, objective, config, optimizer_type, true)
}

/// Convenience constructor for rolling-window linear regression, using
/// either plain or temporally weighted least squares.
#[allow(dead_code)]
pub fn rolling_window_linear_regression(
    n_features: usize,
    window_size: usize,
    use_weighted: bool,
    config: Option<StreamingConfig>,
) -> RollingWindowOptimizer<super::LinearRegressionObjective> {
    let mut config = config.unwrap_or_default();
    config.window_size = window_size;

    let initial_params = Array1::zeros(n_features);
    let objective = super::LinearRegressionObjective;

    if use_weighted {
        rolling_window_weighted_least_squares(initial_params, objective, config, None, None)
    } else {
        rolling_window_least_squares(initial_params, objective, config, None)
    }
}
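
// Illustrative usage sketch (a hypothetical helper, not part of the public
// API): stream a few points whose targets follow y = 2*x0 + 3*x1 through a
// rolling-window linear regressor, then read back the fitted parameters.
// It mirrors the fixtures used in the tests below.
#[allow(dead_code)]
fn example_rolling_window_usage() -> Result<()> {
    let mut optimizer = rolling_window_linear_regression(2, 10, false, None);
    let points = [
        StreamingDataPoint::new(Array1::from(vec![1.0, 0.0]), 2.0),
        StreamingDataPoint::new(Array1::from(vec![0.0, 1.0]), 3.0),
        StreamingDataPoint::new(Array1::from(vec![1.0, 1.0]), 5.0),
    ];
    for point in &points {
        optimizer.update(point)?;
    }
    // After these updates the parameters should approach [2.0, 3.0].
    let _fitted = optimizer.parameters();
    Ok(())
}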

#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::{LinearRegressionObjective, StreamingDataPoint};

    #[test]
    fn test_rolling_window_creation() {
        let optimizer = rolling_window_linear_regression(2, 10, false, None);
        // `VecDeque::with_capacity` only guarantees a lower bound on
        // capacity, so avoid asserting an exact value.
        assert!(optimizer.data_window.capacity() >= 10);
        assert_eq!(optimizer.parameters().len(), 2);
    }

    #[test]
    fn test_window_update() {
        let mut optimizer = rolling_window_linear_regression(2, 3, false, None);

        for i in 0..5 {
            let features = Array1::from(vec![i as f64, (i + 1) as f64]);
            let target = (2 * i + 1) as f64;
            let point = StreamingDataPoint::new(features, target);

            optimizer.update(&point).unwrap();
        }

        assert_eq!(optimizer.data_window.len(), 3);
        assert_eq!(optimizer.stats().points_processed, 5);
    }

    #[test]
    fn test_gradient_descent_window() {
        let config = StreamingConfig {
            window_size: 5,
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = rolling_window_gradient_descent(
            Array1::zeros(2),
            LinearRegressionObjective,
            config,
            None,
        );

        let data_points = vec![
            StreamingDataPoint::new(Array1::from(vec![1.0, 0.0]), 2.0),
            StreamingDataPoint::new(Array1::from(vec![0.0, 1.0]), 3.0),
            StreamingDataPoint::new(Array1::from(vec![1.0, 1.0]), 5.0),
        ];

        for point in &data_points {
            optimizer.update(point).unwrap();
        }

        assert_eq!(optimizer.stats().points_processed, 3);
        assert!(optimizer.stats().updates_performed > 0);
    }

    #[test]
    fn test_least_squares_window() {
        let mut optimizer = rolling_window_linear_regression(2, 10, false, None);

        let data_points = vec![
            StreamingDataPoint::new(Array1::from(vec![1.0, 0.0]), 2.0),
            StreamingDataPoint::new(Array1::from(vec![0.0, 1.0]), 3.0),
            StreamingDataPoint::new(Array1::from(vec![1.0, 1.0]), 5.0),
            StreamingDataPoint::new(Array1::from(vec![2.0, 1.0]), 7.0),
        ];

        for point in &data_points {
            optimizer.update(point).unwrap();
        }

        let params = optimizer.parameters();
        assert!(
            (params[0] - 2.0).abs() < 0.1,
            "First parameter: {}",
            params[0]
        );
        assert!(
            (params[1] - 3.0).abs() < 0.1,
            "Second parameter: {}",
            params[1]
        );
    }

    #[test]
    fn test_weighted_least_squares_window() {
        let mut optimizer = rolling_window_linear_regression(2, 10, true, None);

        let data_points = vec![
            StreamingDataPoint::new(Array1::from(vec![1.0, 0.0]), 2.0),
            StreamingDataPoint::new(Array1::from(vec![0.0, 1.0]), 3.0),
            StreamingDataPoint::new(Array1::from(vec![1.0, 1.0]), 5.0),
        ];

        for point in &data_points {
            optimizer.update(point).unwrap();
        }

        assert_eq!(optimizer.stats().points_processed, 3);
        assert!(optimizer.stats().current_loss.is_finite());
    }

    #[test]
    fn test_window_overflow() {
        let mut optimizer = rolling_window_linear_regression(2, 2, false, None);

        for i in 0..5 {
            let features = Array1::from(vec![i as f64, 1.0]);
            let target = i as f64;
            let point = StreamingDataPoint::new(features, target);

            optimizer.update(&point).unwrap();
        }

        assert_eq!(optimizer.data_window.len(), 2);
        assert_eq!(optimizer.stats().points_processed, 5);
    }

    #[test]
    fn test_window_reset() {
        let mut optimizer = rolling_window_linear_regression(2, 5, false, None);

        let point = StreamingDataPoint::new(Array1::from(vec![1.0, 2.0]), 3.0);
        optimizer.update(&point).unwrap();

        assert_eq!(optimizer.data_window.len(), 1);
        assert_eq!(optimizer.stats().points_processed, 1);

        optimizer.reset();

        assert_eq!(optimizer.data_window.len(), 0);
        assert_eq!(optimizer.stats().points_processed, 0);
    }
}
625}