scirs2_optimize/streaming/
mod.rs1use crate::error::OptimizeError;
24use ndarray::{Array1, Array2, ArrayView1};
27type Result<T> = std::result::Result<T, OptimizeError>;
30pub mod advanced_adaptive_streaming;
34pub mod incremental_newton;
35pub mod online_gradient_descent;
36pub mod real_time_estimation;
37pub mod rolling_window;
38pub mod streaming_trust_region;
39
40pub use advanced_adaptive_streaming::*;
41pub use incremental_newton::*;
42pub use online_gradient_descent::*;
43pub use real_time_estimation::*;
44pub use rolling_window::*;
45pub use streaming_trust_region::*;
46
/// Configuration shared by the streaming optimizers in this module.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Maximum number of iterations per update step (default 100).
    pub max_nit: usize,
    /// Convergence tolerance for relative parameter change (default 1e-6).
    pub tolerance: f64,
    /// Base learning rate / step size (default 0.01).
    pub learning_rate: f64,
    /// Forgetting factor for downweighting older observations (default 0.9;
    /// presumably expected in (0, 1] — TODO confirm against consumers).
    pub forgetting_factor: f64,
    /// Window size for rolling-window style methods (default 100).
    pub window_size: usize,
    /// Whether to adapt the learning rate online (default true).
    pub adaptive_lr: bool,
    /// L2/Tikhonov regularization strength (default 1e-8).
    pub regularization: f64,
}
65
66impl Default for StreamingConfig {
67 fn default() -> Self {
68 Self {
69 max_nit: 100,
70 tolerance: 1e-6,
71 learning_rate: 0.01,
72 forgetting_factor: 0.9,
73 window_size: 100,
74 adaptive_lr: true,
75 regularization: 1e-8,
76 }
77 }
78}
79
/// Running statistics reported by a streaming optimizer.
#[derive(Debug, Clone)]
pub struct StreamingStats {
    /// Total number of data points seen so far.
    pub points_processed: usize,
    /// Number of parameter updates actually performed.
    pub updates_performed: usize,
    /// Loss at the most recent data point.
    pub current_loss: f64,
    /// Running average loss.
    pub average_loss: f64,
    /// Whether the optimizer currently reports convergence.
    pub converged: bool,
    /// Cumulative processing time, in milliseconds.
    pub processing_time_ms: f64,
}
96
97impl Default for StreamingStats {
98 fn default() -> Self {
99 Self {
100 points_processed: 0,
101 updates_performed: 0,
102 current_loss: f64::INFINITY,
103 average_loss: f64::INFINITY,
104 converged: false,
105 processing_time_ms: 0.0,
106 }
107 }
108}
109
/// A single observation consumed by a streaming optimizer.
#[derive(Debug, Clone)]
pub struct StreamingDataPoint {
    /// Feature vector x.
    pub features: Array1<f64>,
    /// Target value y.
    pub target: f64,
    /// Optional observation weight; objectives treat `None` as 1.0.
    pub weight: Option<f64>,
    /// Optional arrival timestamp (units not fixed here — presumably seconds;
    /// verify against producers).
    pub timestamp: Option<f64>,
}
122
123impl StreamingDataPoint {
124 pub fn new(features: Array1<f64>, target: f64) -> Self {
126 Self {
127 features,
128 target,
129 weight: None,
130 timestamp: None,
131 }
132 }
133
134 pub fn with_weight(features: Array1<f64>, target: f64, weight: f64) -> Self {
136 Self {
137 features,
138 target,
139 weight: Some(weight),
140 timestamp: None,
141 }
142 }
143
144 pub fn with_timestamp(features: Array1<f64>, target: f64, timestamp: f64) -> Self {
146 Self {
147 features,
148 target,
149 weight: None,
150 timestamp: Some(timestamp),
151 }
152 }
153}
154
155pub trait StreamingOptimizer {
157 fn update(&mut self, datapoint: &StreamingDataPoint) -> Result<()>;
159
160 fn update_batch(&mut self, datapoints: &[StreamingDataPoint]) -> Result<()> {
162 for _point in datapoints {
163 self.update(_point)?;
164 }
165 Ok(())
166 }
167
168 fn parameters(&self) -> &Array1<f64>;
170
171 fn stats(&self) -> &StreamingStats;
173
174 fn reset(&mut self);
176
177 fn converged(&self) -> bool {
179 self.stats().converged
180 }
181}
182
/// Objective functions evaluated one streaming data point at a time.
pub trait StreamingObjective {
    /// Loss at `parameters` for a single data point.
    fn evaluate(&self, parameters: &ArrayView1<f64>, datapoint: &StreamingDataPoint) -> f64;

    /// Gradient of the loss at `parameters` for a single data point.
    fn gradient(&self, parameters: &ArrayView1<f64>, datapoint: &StreamingDataPoint)
        -> Array1<f64>;

    /// Hessian of the loss, if the objective can provide one; the default
    /// implementation reports `None`.
    ///
    /// NOTE(review): unlike `evaluate`/`gradient`, this is an associated
    /// function without `&self`, and its first parameter is named
    /// `self_parameters` — this looks like a refactoring artifact of a former
    /// `(&self, parameters: ...)` signature. Changing it here would break the
    /// `impl` blocks below, so it is left as-is; confirm intent upstream.
    fn hessian(
        self_parameters: &ArrayView1<f64>,
        _datapoint: &StreamingDataPoint,
    ) -> Option<Array2<f64>> {
        None
    }
}
200
/// Weighted squared-error objective for streaming linear regression.
#[derive(Debug, Clone)]
pub struct LinearRegressionObjective;
204
205impl StreamingObjective for LinearRegressionObjective {
206 fn evaluate(&self, parameters: &ArrayView1<f64>, datapoint: &StreamingDataPoint) -> f64 {
207 let prediction = parameters.dot(&datapoint.features);
208 let residual = prediction - datapoint.target;
209 let weight = datapoint.weight.unwrap_or(1.0);
210 0.5 * weight * residual * residual
211 }
212
213 fn gradient(
214 &self,
215 parameters: &ArrayView1<f64>,
216 datapoint: &StreamingDataPoint,
217 ) -> Array1<f64> {
218 let prediction = parameters.dot(&datapoint.features);
219 let residual = prediction - datapoint.target;
220 let weight = datapoint.weight.unwrap_or(1.0);
221 weight * residual * &datapoint.features
222 }
223
224 fn hessian(
225 self_parameters: &ArrayView1<f64>,
226 datapoint: &StreamingDataPoint,
227 ) -> Option<Array2<f64>> {
228 let weight = datapoint.weight.unwrap_or(1.0);
229 let n = datapoint.features.len();
230 let mut hessian = Array2::zeros((n, n));
231
232 for i in 0..n {
234 for j in 0..n {
235 hessian[[i, j]] = weight * datapoint.features[i] * datapoint.features[j];
236 }
237 }
238
239 Some(hessian)
240 }
241}
242
/// Weighted logistic-regression (cross-entropy) objective for streaming
/// binary classification.
#[derive(Debug, Clone)]
pub struct LogisticRegressionObjective;
246
247impl StreamingObjective for LogisticRegressionObjective {
248 fn evaluate(&self, parameters: &ArrayView1<f64>, datapoint: &StreamingDataPoint) -> f64 {
249 let z = parameters.dot(&datapoint.features);
250 let weight = datapoint.weight.unwrap_or(1.0);
251
252 let loss = if z > 0.0 {
254 z + (1.0 + (-z).exp()).ln() - datapoint.target * z
255 } else {
256 (1.0 + z.exp()).ln() - datapoint.target * z
257 };
258
259 weight * loss
260 }
261
262 fn gradient(
263 &self,
264 parameters: &ArrayView1<f64>,
265 datapoint: &StreamingDataPoint,
266 ) -> Array1<f64> {
267 let z = parameters.dot(&datapoint.features);
268 let sigmoid = 1.0 / (1.0 + (-z).exp());
269 let weight = datapoint.weight.unwrap_or(1.0);
270
271 weight * (sigmoid - datapoint.target) * &datapoint.features
272 }
273
274 fn hessian(
275 parameters: &ArrayView1<f64>,
276 datapoint: &StreamingDataPoint,
277 ) -> Option<Array2<f64>> {
278 let z = parameters.dot(&datapoint.features);
279 let sigmoid = 1.0 / (1.0 + (-z).exp());
280 let weight = datapoint.weight.unwrap_or(1.0);
281 let scale = weight * sigmoid * (1.0 - sigmoid);
282
283 let n = datapoint.features.len();
284 let mut hessian = Array2::zeros((n, n));
285
286 for i in 0..n {
287 for j in 0..n {
288 hessian[[i, j]] = scale * datapoint.features[i] * datapoint.features[j];
289 }
290 }
291
292 Some(hessian)
293 }
294}
295
296pub mod utils {
298 use super::*;
299
300 pub fn ewma_update(_current: f64, newvalue: f64, alpha: f64) -> f64 {
302 alpha * newvalue + (1.0 - alpha) * _current
303 }
304
305 pub fn adaptive_learning_rate(
307 base_lr: f64,
308 gradient_norm: f64,
309 avg_gradient_norm: f64,
310 min_lr: f64,
311 max_lr: f64,
312 ) -> f64 {
313 if avg_gradient_norm > 0.0 {
314 let scale = (avg_gradient_norm / gradient_norm).sqrt();
315 (base_lr * scale).max(min_lr).min(max_lr)
316 } else {
317 base_lr
318 }
319 }
320
321 pub fn check_convergence(
323 old_params: &ArrayView1<f64>,
324 new_params: &ArrayView1<f64>,
325 tolerance: f64,
326 ) -> bool {
327 let change = (new_params - old_params).mapv(|x| x.abs()).sum();
328 let scale = new_params.mapv(|x| x.abs()).sum().max(1.0);
329 change / scale < tolerance
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_streaming_datapoint_creation() {
        let features = Array1::from(vec![1.0, 2.0, 3.0]);
        let target = 5.0;

        let point = StreamingDataPoint::new(features.clone(), target);
        assert_eq!(point.features, features);
        assert_eq!(point.target, target);
        assert!(point.weight.is_none());
        assert!(point.timestamp.is_none());
    }

    #[test]
    fn test_linear_regression_objective() {
        let objective = LinearRegressionObjective;
        let params = Array1::from(vec![1.0, 2.0]);
        let features = Array1::from(vec![3.0, 4.0]);
        let target = 10.0;
        let point = StreamingDataPoint::new(features, target);

        // Fixed mojibake: `¶ms` was a mis-encoded `&params`, which did
        // not compile.
        let loss = objective.evaluate(&params.view(), &point);
        let gradient = objective.gradient(&params.view(), &point);

        // prediction = 1*3 + 2*4 = 11, residual = 1 => loss = 0.5,
        // gradient = residual * features = [3, 4].
        assert!((loss - 0.5).abs() < 1e-10);

        assert!((gradient[0] - 3.0).abs() < 1e-10);
        assert!((gradient[1] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_utils_ewma() {
        let current = 10.0;
        let newvalue = 20.0;
        let alpha = 0.1;

        let result = utils::ewma_update(current, newvalue, alpha);
        let expected = 0.1 * 20.0 + 0.9 * 10.0;
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_utils_convergence() {
        let old_params = Array1::from(vec![1.0, 2.0]);
        let new_params = Array1::from(vec![1.001, 2.001]);

        // Loose tolerance: the relative change (~1e-3) is below 1e-2.
        assert!(utils::check_convergence(
            &old_params.view(),
            &new_params.view(),
            1e-2
        ));

        // Tight tolerance: the same change exceeds 1e-6.
        assert!(!utils::check_convergence(
            &old_params.view(),
            &new_params.view(),
            1e-6
        ));
    }
}