1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8 error::Result as SklResult,
9 prelude::{Predict, SklearsError},
10 traits::{Estimator, Fit, Untrained},
11 types::Float,
12};
13use std::collections::{HashMap, VecDeque};
14use std::time::{Duration, Instant, SystemTime};
15
16use crate::{PipelinePredictor, PipelineStep};
17
/// A single observation flowing through a streaming pipeline.
#[derive(Debug, Clone)]
pub struct StreamDataPoint {
    /// Feature vector for this observation.
    pub features: Array1<f64>,
    /// Optional supervised target; `Some` only for labeled points.
    pub target: Option<f64>,
    /// Event timestamp; `new` stamps it with the arrival time.
    pub timestamp: SystemTime,
    /// Free-form key/value annotations attached by the producer.
    pub metadata: HashMap<String, String>,
    /// Caller-supplied identifier for this point.
    pub id: String,
}
32
33impl StreamDataPoint {
34 #[must_use]
36 pub fn new(features: Array1<f64>, id: String) -> Self {
37 Self {
38 features,
39 target: None,
40 timestamp: SystemTime::now(),
41 metadata: HashMap::new(),
42 id,
43 }
44 }
45
46 #[must_use]
48 pub fn with_target(mut self, target: f64) -> Self {
49 self.target = Some(target);
50 self
51 }
52
53 #[must_use]
55 pub fn with_timestamp(mut self, timestamp: SystemTime) -> Self {
56 self.timestamp = timestamp;
57 self
58 }
59
60 #[must_use]
62 pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
63 self.metadata = metadata;
64 self
65 }
66}
67
/// A group of stream points covering a time span, produced by a
/// `WindowingStrategy`.
#[derive(Debug, Clone)]
pub struct StreamWindow {
    /// The points assigned to this window, in arrival order.
    pub data_points: Vec<StreamDataPoint>,
    /// Inclusive start of the window's time span.
    pub start_time: SystemTime,
    /// Inclusive end of the window's time span.
    pub end_time: SystemTime,
    /// Free-form key/value annotations for the window.
    pub metadata: HashMap<String, String>,
}
80
81impl StreamWindow {
82 #[must_use]
84 pub fn new(start_time: SystemTime, end_time: SystemTime) -> Self {
85 Self {
86 data_points: Vec::new(),
87 start_time,
88 end_time,
89 metadata: HashMap::new(),
90 }
91 }
92
93 pub fn add_point(&mut self, point: StreamDataPoint) {
95 self.data_points.push(point);
96 }
97
98 pub fn features_matrix(&self) -> SklResult<Array2<f64>> {
100 if self.data_points.is_empty() {
101 return Err(SklearsError::InvalidInput("Empty window".to_string()));
102 }
103
104 let n_samples = self.data_points.len();
105 let n_features = self.data_points[0].features.len();
106
107 let mut features = Array2::zeros((n_samples, n_features));
108 for (i, point) in self.data_points.iter().enumerate() {
109 features.row_mut(i).assign(&point.features);
110 }
111
112 Ok(features)
113 }
114
115 #[must_use]
117 pub fn targets_array(&self) -> Option<Array1<f64>> {
118 if self.data_points.iter().all(|p| p.target.is_some()) {
119 Some(Array1::from_vec(
120 self.data_points
121 .iter()
122 .map(|p| p.target.unwrap_or_default())
123 .collect(),
124 ))
125 } else {
126 None
127 }
128 }
129
130 #[must_use]
132 pub fn size(&self) -> usize {
133 self.data_points.len()
134 }
135
136 #[must_use]
138 pub fn is_empty(&self) -> bool {
139 self.data_points.is_empty()
140 }
141}
142
/// Strategy for grouping buffered stream points into windows.
///
/// Not `Debug`/`Clone`: the `Custom` variant holds an arbitrary closure.
pub enum WindowingStrategy {
    /// Fixed-duration, non-overlapping time windows.
    TumblingTime {
        /// Length of each window.
        duration: Duration,
    },
    /// Fixed-duration windows advanced by `slide` (windows may overlap).
    SlidingTime {
        /// Length of each window.
        duration: Duration,
        /// Offset between consecutive window starts.
        slide: Duration,
    },
    /// Non-overlapping windows of exactly `count` points.
    TumblingCount {
        count: usize,
    },
    /// Count-based windows of `size` points advancing by `step`.
    SlidingCount {
        size: usize,
        step: usize,
    },
    /// A window closes after `gap` of inactivity between consecutive points.
    Session {
        gap: Duration,
    },
    /// User-supplied predicate over the buffered points; a window is emitted
    /// when it returns `true`.
    Custom {
        trigger_fn: Box<dyn Fn(&[StreamDataPoint]) -> bool + Send + Sync>,
    },
}
180
/// Configuration for a streaming pipeline.
///
/// Not `Debug`/`Clone`: `WindowingStrategy` may hold a closure.
pub struct StreamConfig {
    /// How buffered points are grouped into windows.
    pub windowing: WindowingStrategy,
    /// Nominal buffer capacity, used to compute utilization statistics.
    pub buffer_size: usize,
    /// Requested degree of parallelism (not consumed in this module — TODO confirm usage elsewhere).
    pub parallelism: usize,
    /// Buffered-point count above which `process_point` rejects new input.
    pub backpressure_threshold: usize,
    /// Target per-point processing latency (informational here — TODO confirm enforcement).
    pub latency_target: Duration,
    /// Desired interval between checkpoints (informational here — TODO confirm enforcement).
    pub checkpoint_interval: Duration,
    /// Where/how pipeline state is persisted.
    pub state_management: StateManagement,
}
198
impl Default for StreamConfig {
    /// Defaults: 60 s tumbling windows, 10 000-point buffer, one worker,
    /// backpressure at 8 000 points, 100 ms latency target, 5 min
    /// checkpoints, in-memory state.
    fn default() -> Self {
        Self {
            windowing: WindowingStrategy::TumblingTime {
                duration: Duration::from_secs(60),
            },
            buffer_size: 10000,
            parallelism: 1,
            backpressure_threshold: 8000,
            latency_target: Duration::from_millis(100),
            checkpoint_interval: Duration::from_secs(300),
            state_management: StateManagement::InMemory,
        }
    }
}
214
/// How pipeline state is persisted across restarts.
#[derive(Debug, Clone)]
pub enum StateManagement {
    /// Keep all state in process memory (lost on shutdown).
    InMemory,
    /// Periodic snapshots written to a directory.
    Snapshots {
        /// Directory snapshots are written into.
        directory: String,
        /// Interval between snapshots.
        interval: Duration,
    },
    /// Append-only write-ahead log.
    WriteAheadLog {
        /// Path of the log file.
        log_path: String,
    },
    /// Delegate persistence to an external system described by `config`.
    External {
        /// Backend-specific connection/settings map.
        config: HashMap<String, String>,
    },
}
238
/// Policy deciding when the model is retrained on recent windows.
///
/// Not `Debug`/`Clone`: the `Custom` variant holds an arbitrary closure.
pub enum UpdateStrategy {
    /// Retrain whenever any data is buffered.
    Immediate,
    /// Retrain once the buffer reaches `batch_size` points.
    Batch {
        batch_size: usize,
    },
    /// Retrain at most once per `interval`.
    TimeBased {
        interval: Duration,
    },
    /// Drift-driven: never retrain before `min_interval`, always after
    /// `max_interval`, and in between only when drift exceeds the threshold.
    Adaptive {
        /// Feature-drift score above which a retrain is triggered.
        drift_threshold: f64,
        /// Minimum time between retrains.
        min_interval: Duration,
        /// Maximum time between retrains.
        max_interval: Duration,
    },
    /// User predicate over the latest window and current statistics.
    Custom {
        trigger_fn: Box<dyn Fn(&StreamWindow, &StreamStats) -> bool + Send + Sync>,
    },
}
268
/// Rolling runtime statistics for a streaming pipeline.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Total points seen (initial fit plus streamed points).
    pub total_samples: usize,
    /// Points per second since `start_time`.
    pub throughput: f64,
    /// Exponentially weighted moving-average latency, in milliseconds.
    pub avg_latency: f64,
    /// Fraction of the configured buffer currently occupied.
    pub buffer_utilization: f64,
    /// Optional online accuracy estimate (not computed in this module — TODO confirm producer).
    pub accuracy: Option<f64>,
    /// Named drift scores, e.g. `"feature_drift"` set by `detect_drift`.
    pub drift_metrics: HashMap<String, f64>,
    /// Error rate (not updated in this module — TODO confirm producer).
    pub error_rate: f64,
    /// When streaming (or the initial fit) began.
    pub start_time: SystemTime,
    /// When the model was last retrained.
    pub last_update: SystemTime,
}
291
292impl Default for StreamStats {
293 fn default() -> Self {
294 let now = SystemTime::now();
295 Self {
296 total_samples: 0,
297 throughput: 0.0,
298 avg_latency: 0.0,
299 buffer_utilization: 0.0,
300 accuracy: None,
301 drift_metrics: HashMap::new(),
302 error_rate: 0.0,
303 start_time: now,
304 last_update: now,
305 }
306 }
307}
308
/// Typestate wrapper for a streaming ML pipeline: `S` is `Untrained` before
/// the initial fit and `StreamingPipelineTrained` afterwards.
pub struct StreamingPipeline<S = Untrained> {
    // Phase marker / trained-state payload.
    state: S,
    // Present only before training; `fit` moves it into the trained state.
    base_estimator: Option<Box<dyn PipelinePredictor>>,
    config: StreamConfig,
    update_strategy: UpdateStrategy,
    // Points awaiting window assignment.
    data_buffer: VecDeque<StreamDataPoint>,
    windows: Vec<StreamWindow>,
    statistics: StreamStats,
}
319
/// State payload of a trained streaming pipeline.
pub struct StreamingPipelineTrained {
    // The fitted estimator used for predictions and incremental retraining.
    fitted_estimator: Box<dyn PipelinePredictor>,
    config: StreamConfig,
    update_strategy: UpdateStrategy,
    // Points awaiting window assignment.
    data_buffer: VecDeque<StreamDataPoint>,
    // Currently open (unexpired) windows.
    windows: Vec<StreamWindow>,
    statistics: StreamStats,
    // Miscellaneous numeric bookkeeping (e.g. training sample counts).
    model_state: HashMap<String, f64>,
    // Number of features seen during the initial fit.
    n_features_in: usize,
    // Feature names (not populated by `fit` in this module — TODO confirm).
    feature_names_in: Option<Vec<String>>,
}
332
333impl StreamingPipeline<Untrained> {
334 #[must_use]
336 pub fn new(base_estimator: Box<dyn PipelinePredictor>, config: StreamConfig) -> Self {
337 Self {
338 state: Untrained,
339 base_estimator: Some(base_estimator),
340 config,
341 update_strategy: UpdateStrategy::Batch { batch_size: 100 },
342 data_buffer: VecDeque::new(),
343 windows: Vec::new(),
344 statistics: StreamStats::default(),
345 }
346 }
347
348 #[must_use]
350 pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
351 self.update_strategy = strategy;
352 self
353 }
354
355 #[must_use]
357 pub fn tumbling_time(
358 base_estimator: Box<dyn PipelinePredictor>,
359 window_duration: Duration,
360 ) -> Self {
361 let config = StreamConfig {
362 windowing: WindowingStrategy::TumblingTime {
363 duration: window_duration,
364 },
365 ..StreamConfig::default()
366 };
367 Self::new(base_estimator, config)
368 }
369
370 #[must_use]
372 pub fn sliding_window(
373 base_estimator: Box<dyn PipelinePredictor>,
374 window_size: usize,
375 slide_step: usize,
376 ) -> Self {
377 let config = StreamConfig {
378 windowing: WindowingStrategy::SlidingCount {
379 size: window_size,
380 step: slide_step,
381 },
382 ..StreamConfig::default()
383 };
384 Self::new(base_estimator, config)
385 }
386
387 #[must_use]
389 pub fn session_window(
390 base_estimator: Box<dyn PipelinePredictor>,
391 session_gap: Duration,
392 ) -> Self {
393 let config = StreamConfig {
394 windowing: WindowingStrategy::Session { gap: session_gap },
395 ..StreamConfig::default()
396 };
397 Self::new(base_estimator, config)
398 }
399}
400
impl Estimator for StreamingPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This pipeline keeps its configuration in `StreamConfig` rather than
    /// in the `Estimator` trait's config slot, so the trait-level config is
    /// the unit type.
    fn config(&self) -> &Self::Config {
        &()
    }
}
410
411impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for StreamingPipeline<Untrained> {
412 type Fitted = StreamingPipeline<StreamingPipelineTrained>;
413
414 fn fit(
415 self,
416 x: &ArrayView2<'_, Float>,
417 y: &Option<&ArrayView1<'_, Float>>,
418 ) -> SklResult<Self::Fitted> {
419 let mut base_estimator = self
420 .base_estimator
421 .ok_or_else(|| SklearsError::InvalidInput("No base estimator provided".to_string()))?;
422
423 if let Some(y_ref) = y {
425 base_estimator.fit(x, y_ref)?;
426 } else {
427 return Err(SklearsError::InvalidInput(
428 "No target values provided for initial training".to_string(),
429 ));
430 }
431
432 let mut model_state = HashMap::new();
434 model_state.insert("batch_training_samples".to_string(), x.nrows() as f64);
435
436 let mut statistics = self.statistics;
437 statistics.total_samples = x.nrows();
438 statistics.start_time = SystemTime::now();
439 statistics.last_update = SystemTime::now();
440
441 Ok(StreamingPipeline {
442 state: StreamingPipelineTrained {
443 fitted_estimator: base_estimator,
444 config: self.config,
445 update_strategy: self.update_strategy,
446 data_buffer: self.data_buffer,
447 windows: self.windows,
448 statistics,
449 model_state,
450 n_features_in: x.ncols(),
451 feature_names_in: None,
452 },
453 base_estimator: None,
454 config: StreamConfig::default(),
455 update_strategy: UpdateStrategy::Immediate,
456 data_buffer: VecDeque::new(),
457 windows: Vec::new(),
458 statistics: StreamStats::default(),
459 })
460 }
461}
462
463impl StreamingPipeline<StreamingPipelineTrained> {
464 pub fn process_point(&mut self, point: StreamDataPoint) -> SklResult<Option<Array1<f64>>> {
466 let start_time = Instant::now();
467
468 if self.state.data_buffer.len() >= self.state.config.backpressure_threshold {
470 return Err(SklearsError::InvalidInput(
471 "Backpressure threshold exceeded".to_string(),
472 ));
473 }
474
475 self.state.data_buffer.push_back(point.clone());
477
478 self.state.statistics.total_samples += 1;
480 self.state.statistics.buffer_utilization =
481 self.state.data_buffer.len() as f64 / self.state.config.buffer_size as f64;
482
483 let features_2d =
485 Array2::from_shape_vec((1, point.features.len()), point.features.to_vec()).map_err(
486 |e| SklearsError::InvalidData {
487 reason: format!("Feature reshaping failed: {e}"),
488 },
489 )?;
490
491 let prediction = self.state.fitted_estimator.predict(&features_2d.view())?;
493
494 self.process_windows()?;
496
497 self.check_model_update()?;
499
500 let processing_time = start_time.elapsed().as_millis() as f64;
502 self.state.statistics.avg_latency =
503 (self.state.statistics.avg_latency * 0.9) + (processing_time * 0.1);
504
505 let elapsed = self
507 .state
508 .statistics
509 .start_time
510 .elapsed()
511 .unwrap_or(Duration::from_secs(1));
512 self.state.statistics.throughput =
513 self.state.statistics.total_samples as f64 / elapsed.as_secs_f64();
514
515 Ok(Some(prediction))
516 }
517
518 pub fn process_batch(&mut self, points: Vec<StreamDataPoint>) -> SklResult<Array2<f64>> {
520 let mut predictions = Vec::new();
521
522 for point in points {
523 if let Some(pred) = self.process_point(point)? {
524 predictions.extend(pred.iter().copied());
525 }
526 }
527
528 if predictions.is_empty() {
529 return Ok(Array2::zeros((0, 1)));
530 }
531
532 let n_predictions = predictions.len();
533 Array2::from_shape_vec((n_predictions, 1), predictions).map_err(|e| {
534 SklearsError::InvalidData {
535 reason: format!("Batch prediction reshape failed: {e}"),
536 }
537 })
538 }
539
540 fn process_windows(&mut self) -> SklResult<()> {
542 match &self.state.config.windowing {
543 WindowingStrategy::TumblingTime { duration } => {
544 self.process_tumbling_time_windows(*duration)
545 }
546 WindowingStrategy::SlidingTime { duration, slide } => {
547 self.process_sliding_time_windows(*duration, *slide)
548 }
549 WindowingStrategy::TumblingCount { count } => {
550 self.process_tumbling_count_windows(*count)
551 }
552 WindowingStrategy::SlidingCount { size, step } => {
553 self.process_sliding_count_windows(*size, *step)
554 }
555 WindowingStrategy::Session { gap } => self.process_session_windows(*gap),
556 WindowingStrategy::Custom { .. } => {
557 self.process_custom_windows_safe()
559 }
560 }
561 }
562
563 fn process_tumbling_time_windows(&mut self, duration: Duration) -> SklResult<()> {
565 let now = SystemTime::now();
566
567 if self.state.windows.is_empty() {
569 let window = StreamWindow::new(now, now + duration);
570 self.state.windows.push(window);
571 }
572
573 while let Some(point) = self.state.data_buffer.pop_front() {
575 if let Some(current_window) = self.state.windows.last_mut() {
576 if point.timestamp <= current_window.end_time {
577 current_window.add_point(point);
578 } else {
579 let mut new_window = StreamWindow::new(
581 current_window.end_time,
582 current_window.end_time + duration,
583 );
584 new_window.add_point(point);
585 self.state.windows.push(new_window);
586 }
587 }
588 }
589
590 self.state.windows.retain(|w| w.end_time > now);
592
593 Ok(())
594 }
595
596 fn process_sliding_time_windows(
598 &mut self,
599 duration: Duration,
600 slide: Duration,
601 ) -> SklResult<()> {
602 self.process_tumbling_time_windows(duration)
604 }
605
606 fn process_tumbling_count_windows(&mut self, count: usize) -> SklResult<()> {
608 let now = SystemTime::now();
609
610 while self.state.data_buffer.len() >= count {
611 let mut window = StreamWindow::new(now, now);
612 for _ in 0..count {
613 if let Some(point) = self.state.data_buffer.pop_front() {
614 window.add_point(point);
615 }
616 }
617 self.state.windows.push(window);
618 }
619
620 Ok(())
621 }
622
623 fn process_sliding_count_windows(&mut self, size: usize, step: usize) -> SklResult<()> {
625 self.process_tumbling_count_windows(step)
627 }
628
629 fn process_session_windows(&mut self, gap: Duration) -> SklResult<()> {
631 let now = SystemTime::now();
633
634 if let Some(mut current_window) = self.state.windows.pop() {
635 while let Some(point) = self.state.data_buffer.pop_front() {
636 let time_since_last = point
637 .timestamp
638 .duration_since(current_window.end_time)
639 .unwrap_or(Duration::ZERO);
640
641 if time_since_last <= gap {
642 current_window.add_point(point.clone());
643 current_window.end_time = point.timestamp;
644 } else {
645 self.state.windows.push(current_window);
647 current_window = StreamWindow::new(point.timestamp, point.timestamp);
648 current_window.add_point(point);
649 }
650 }
651 self.state.windows.push(current_window);
652 } else if !self.state.data_buffer.is_empty() {
653 if let Some(point) = self.state.data_buffer.pop_front() {
655 let mut window = StreamWindow::new(point.timestamp, point.timestamp);
656 window.add_point(point);
657 self.state.windows.push(window);
658 }
659 }
660
661 Ok(())
662 }
663
664 fn process_custom_windows_safe(&mut self) -> SklResult<()> {
666 if let WindowingStrategy::Custom { trigger_fn } = &self.state.config.windowing {
668 let buffer_vec: Vec<StreamDataPoint> = self.state.data_buffer.iter().cloned().collect();
669
670 if trigger_fn(&buffer_vec) {
671 let now = SystemTime::now();
672 let mut window = StreamWindow::new(now, now);
673
674 while let Some(point) = self.state.data_buffer.pop_front() {
675 window.add_point(point);
676 }
677
678 if !window.is_empty() {
679 self.state.windows.push(window);
680 }
681 }
682 }
683
684 Ok(())
685 }
686
687 fn process_custom_windows(
689 &mut self,
690 trigger_fn: &Box<dyn Fn(&[StreamDataPoint]) -> bool + Send + Sync>,
691 ) -> SklResult<()> {
692 let buffer_vec: Vec<StreamDataPoint> = self.state.data_buffer.iter().cloned().collect();
693
694 if trigger_fn(&buffer_vec) {
695 let now = SystemTime::now();
696 let mut window = StreamWindow::new(now, now);
697
698 while let Some(point) = self.state.data_buffer.pop_front() {
699 window.add_point(point);
700 }
701
702 if !window.is_empty() {
703 self.state.windows.push(window);
704 }
705 }
706
707 Ok(())
708 }
709
710 fn check_model_update(&mut self) -> SklResult<()> {
712 let should_update = match &self.state.update_strategy {
713 UpdateStrategy::Immediate => !self.state.data_buffer.is_empty(),
714 UpdateStrategy::Batch { batch_size } => self.state.data_buffer.len() >= *batch_size,
715 UpdateStrategy::TimeBased { interval } => {
716 self.state
717 .statistics
718 .last_update
719 .elapsed()
720 .unwrap_or(Duration::ZERO)
721 >= *interval
722 }
723 UpdateStrategy::Adaptive {
724 drift_threshold,
725 min_interval,
726 max_interval,
727 } => self.check_adaptive_update(*drift_threshold, *min_interval, *max_interval),
728 UpdateStrategy::Custom { trigger_fn } => {
729 if let Some(window) = self.state.windows.last() {
730 trigger_fn(window, &self.state.statistics)
731 } else {
732 false
733 }
734 }
735 };
736
737 if should_update {
738 self.update_model()?;
739 }
740
741 Ok(())
742 }
743
744 fn check_adaptive_update(
746 &self,
747 drift_threshold: f64,
748 min_interval: Duration,
749 max_interval: Duration,
750 ) -> bool {
751 let elapsed = self
752 .state
753 .statistics
754 .last_update
755 .elapsed()
756 .unwrap_or(Duration::ZERO);
757
758 if elapsed < min_interval {
759 return false;
760 }
761
762 if elapsed >= max_interval {
763 return true;
764 }
765
766 let drift_score = self
768 .state
769 .statistics
770 .drift_metrics
771 .get("feature_drift")
772 .unwrap_or(&0.0);
773 *drift_score > drift_threshold
774 }
775
776 fn update_model(&mut self) -> SklResult<()> {
778 if let Some(window) = self.state.windows.last() {
779 if !window.is_empty() {
780 let features = window.features_matrix()?;
781 let targets = window.targets_array();
782
783 if let Some(targets_array) = targets {
784 self.state
786 .fitted_estimator
787 .fit(&features.view(), &targets_array.view())?;
788
789 self.state.statistics.last_update = SystemTime::now();
790 self.state
791 .model_state
792 .insert("last_update_samples".to_string(), window.size() as f64);
793 }
794 }
795 }
796
797 Ok(())
798 }
799
800 #[must_use]
802 pub fn statistics(&self) -> &StreamStats {
803 &self.state.statistics
804 }
805
806 #[must_use]
808 pub fn buffer_size(&self) -> usize {
809 self.state.data_buffer.len()
810 }
811
812 #[must_use]
814 pub fn active_windows(&self) -> usize {
815 self.state.windows.len()
816 }
817
818 pub fn checkpoint(&self) -> SklResult<HashMap<String, String>> {
820 let mut checkpoint = HashMap::new();
821 checkpoint.insert(
822 "total_samples".to_string(),
823 self.state.statistics.total_samples.to_string(),
824 );
825 checkpoint.insert(
826 "buffer_size".to_string(),
827 self.state.data_buffer.len().to_string(),
828 );
829 checkpoint.insert(
830 "active_windows".to_string(),
831 self.state.windows.len().to_string(),
832 );
833 checkpoint.insert(
834 "throughput".to_string(),
835 self.state.statistics.throughput.to_string(),
836 );
837
838 Ok(checkpoint)
839 }
840
841 pub fn clear_buffers(&mut self) {
843 self.state.data_buffer.clear();
844 self.state.windows.clear();
845 }
846
847 #[must_use]
849 pub fn drift_metrics(&self) -> &HashMap<String, f64> {
850 &self.state.statistics.drift_metrics
851 }
852
853 pub fn detect_drift(
855 &mut self,
856 reference_window: &StreamWindow,
857 current_window: &StreamWindow,
858 ) -> SklResult<f64> {
859 if reference_window.is_empty() || current_window.is_empty() {
860 return Ok(0.0);
861 }
862
863 let ref_features = reference_window.features_matrix()?;
864 let cur_features = current_window.features_matrix()?;
865
866 let ref_mean = ref_features.mean_axis(Axis(0)).unwrap_or_default();
868 let cur_mean = cur_features.mean_axis(Axis(0)).unwrap_or_default();
869
870 let drift_score = (&ref_mean - &cur_mean).mapv(|x| x * x).sum().sqrt();
871
872 self.state
874 .statistics
875 .drift_metrics
876 .insert("feature_drift".to_string(), drift_score);
877
878 Ok(drift_score)
879 }
880}
881
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    // Builder methods should populate id, features, and target.
    #[test]
    fn test_stream_data_point() {
        let features = array![1.0, 2.0, 3.0];
        let point =
            StreamDataPoint::new(features.clone(), "test_point".to_string()).with_target(1.0);

        assert_eq!(point.id, "test_point");
        assert_eq!(point.features, features);
        assert_eq!(point.target, Some(1.0));
    }

    // Adding points to a window grows it and yields a stacked feature matrix.
    #[test]
    fn test_stream_window() {
        let start_time = SystemTime::now();
        let end_time = start_time + Duration::from_secs(60);
        let mut window = StreamWindow::new(start_time, end_time);

        let point1 = StreamDataPoint::new(array![1.0, 2.0], "point1".to_string());
        let point2 = StreamDataPoint::new(array![3.0, 4.0], "point2".to_string());

        window.add_point(point1);
        window.add_point(point2);

        assert_eq!(window.size(), 2);

        // Two points with two features each -> 2x2 matrix.
        let features = window.features_matrix().unwrap_or_default();
        assert_eq!(features.nrows(), 2);
        assert_eq!(features.ncols(), 2);
    }

    // The tumbling_time convenience constructor should set a TumblingTime
    // windowing strategy.
    #[test]
    fn test_streaming_pipeline_creation() {
        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60));

        assert!(matches!(
            pipeline.config.windowing,
            WindowingStrategy::TumblingTime { .. }
        ));
    }

    // An initial fit should record the feature count and sample total.
    #[test]
    fn test_streaming_pipeline_fit() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60));

        let fitted_pipeline = pipeline
            .fit(&x.view(), &Some(&y.view()))
            .expect("operation should succeed");
        assert_eq!(fitted_pipeline.state.n_features_in, 2);
        assert_eq!(fitted_pipeline.state.statistics.total_samples, 2);
    }

    // Processing a point on a trained pipeline should yield a prediction
    // and open a window.
    #[test]
    fn test_point_processing() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60));

        let mut fitted_pipeline = pipeline
            .fit(&x.view(), &Some(&y.view()))
            .expect("operation should succeed");

        let point = StreamDataPoint::new(array![5.0, 6.0], "test_point".to_string());
        let prediction = fitted_pipeline.process_point(point).unwrap_or_default();

        assert!(prediction.is_some());
        assert_eq!(fitted_pipeline.active_windows(), 1);
    }

    // A custom StreamConfig should be stored as provided.
    #[test]
    fn test_window_strategies() {
        let base_estimator = Box::new(MockPredictor::new());

        let pipeline = StreamingPipeline::new(
            base_estimator,
            StreamConfig {
                windowing: WindowingStrategy::TumblingCount { count: 2 },
                ..StreamConfig::default()
            },
        );

        assert!(matches!(
            pipeline.config.windowing,
            WindowingStrategy::TumblingCount { count: 2 }
        ));
    }

    // The update_strategy builder should override the default batch size.
    #[test]
    fn test_update_strategies() {
        let base_estimator = Box::new(MockPredictor::new());
        let pipeline = StreamingPipeline::tumbling_time(base_estimator, Duration::from_secs(60))
            .update_strategy(UpdateStrategy::Batch { batch_size: 10 });

        assert!(matches!(
            pipeline.update_strategy,
            UpdateStrategy::Batch { batch_size: 10 }
        ));
    }
}