Skip to main content

sklears_compose/
streaming.rs

//! Streaming pipeline components for real-time data processing
//!
//! This module provides streaming capabilities including windowing strategies,
//! online model updates, incremental processing, and real-time analytics.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8    error::Result as SklResult,
9    prelude::{Predict, SklearsError},
10    traits::{Estimator, Fit, Untrained},
11    types::Float,
12};
13use std::collections::{HashMap, VecDeque};
14use std::time::{Duration, Instant, SystemTime};
15
16use crate::{PipelinePredictor, PipelineStep};
17
18/// Data point in a stream
19#[derive(Debug, Clone)]
20pub struct StreamDataPoint {
21    /// Feature values
22    pub features: Array1<f64>,
23    /// Target value (optional)
24    pub target: Option<f64>,
25    /// Timestamp
26    pub timestamp: SystemTime,
27    /// Metadata
28    pub metadata: HashMap<String, String>,
29    /// Data point ID
30    pub id: String,
31}
32
33impl StreamDataPoint {
34    /// Create a new stream data point
35    #[must_use]
36    pub fn new(features: Array1<f64>, id: String) -> Self {
37        Self {
38            features,
39            target: None,
40            timestamp: SystemTime::now(),
41            metadata: HashMap::new(),
42            id,
43        }
44    }
45
46    /// Set target value
47    #[must_use]
48    pub fn with_target(mut self, target: f64) -> Self {
49        self.target = Some(target);
50        self
51    }
52
53    /// Set timestamp
54    #[must_use]
55    pub fn with_timestamp(mut self, timestamp: SystemTime) -> Self {
56        self.timestamp = timestamp;
57        self
58    }
59
60    /// Set metadata
61    #[must_use]
62    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
63        self.metadata = metadata;
64        self
65    }
66}
67
68/// Window of stream data points
69#[derive(Debug, Clone)]
70pub struct StreamWindow {
71    /// Data points in the window
72    pub data_points: Vec<StreamDataPoint>,
73    /// Window start time
74    pub start_time: SystemTime,
75    /// Window end time
76    pub end_time: SystemTime,
77    /// Window metadata
78    pub metadata: HashMap<String, String>,
79}
80
81impl StreamWindow {
82    /// Create a new stream window
83    #[must_use]
84    pub fn new(start_time: SystemTime, end_time: SystemTime) -> Self {
85        Self {
86            data_points: Vec::new(),
87            start_time,
88            end_time,
89            metadata: HashMap::new(),
90        }
91    }
92
93    /// Add a data point to the window
94    pub fn add_point(&mut self, point: StreamDataPoint) {
95        self.data_points.push(point);
96    }
97
98    /// Get features matrix
99    pub fn features_matrix(&self) -> SklResult<Array2<f64>> {
100        if self.data_points.is_empty() {
101            return Err(SklearsError::InvalidInput("Empty window".to_string()));
102        }
103
104        let n_samples = self.data_points.len();
105        let n_features = self.data_points[0].features.len();
106
107        let mut features = Array2::zeros((n_samples, n_features));
108        for (i, point) in self.data_points.iter().enumerate() {
109            features.row_mut(i).assign(&point.features);
110        }
111
112        Ok(features)
113    }
114
115    /// Get targets array
116    #[must_use]
117    pub fn targets_array(&self) -> Option<Array1<f64>> {
118        if self.data_points.iter().all(|p| p.target.is_some()) {
119            Some(Array1::from_vec(
120                self.data_points
121                    .iter()
122                    .map(|p| p.target.unwrap_or_default())
123                    .collect(),
124            ))
125        } else {
126            None
127        }
128    }
129
130    /// Get window size
131    #[must_use]
132    pub fn size(&self) -> usize {
133        self.data_points.len()
134    }
135
136    /// Check if window is empty
137    #[must_use]
138    pub fn is_empty(&self) -> bool {
139        self.data_points.is_empty()
140    }
141}
142
143/// Windowing strategy for stream processing
144pub enum WindowingStrategy {
145    /// Fixed time windows
146    TumblingTime {
147        /// Window duration
148        duration: Duration,
149    },
150    /// Sliding time windows
151    SlidingTime {
152        /// Window duration
153        duration: Duration,
154        /// Slide interval
155        slide: Duration,
156    },
157    /// Fixed count windows
158    TumblingCount {
159        /// Number of elements per window
160        count: usize,
161    },
162    /// Sliding count windows
163    SlidingCount {
164        /// Window size
165        size: usize,
166        /// Slide step
167        step: usize,
168    },
169    /// Session windows (gap-based)
170    Session {
171        /// Maximum gap between elements
172        gap: Duration,
173    },
174    /// Custom windowing
175    Custom {
176        /// Custom window trigger function
177        trigger_fn: Box<dyn Fn(&[StreamDataPoint]) -> bool + Send + Sync>,
178    },
179}
180
181/// Stream processing configuration
182pub struct StreamConfig {
183    /// Windowing strategy
184    pub windowing: WindowingStrategy,
185    /// Buffer size for incoming data
186    pub buffer_size: usize,
187    /// Processing parallelism
188    pub parallelism: usize,
189    /// Backpressure threshold
190    pub backpressure_threshold: usize,
191    /// Latency targets
192    pub latency_target: Duration,
193    /// Checkpoint interval
194    pub checkpoint_interval: Duration,
195    /// State management
196    pub state_management: StateManagement,
197}
198
199impl Default for StreamConfig {
200    fn default() -> Self {
201        Self {
202            windowing: WindowingStrategy::TumblingTime {
203                duration: Duration::from_secs(60),
204            },
205            buffer_size: 10000,
206            parallelism: 1,
207            backpressure_threshold: 8000,
208            latency_target: Duration::from_millis(100),
209            checkpoint_interval: Duration::from_secs(300),
210            state_management: StateManagement::InMemory,
211        }
212    }
213}
214
/// Where and how the streaming state is kept.
#[derive(Debug, Clone)]
pub enum StateManagement {
    /// Keep state in memory only; nothing is persisted
    InMemory,
    /// Write periodic snapshots to disk
    Snapshots {
        /// Directory that receives the snapshots
        directory: String,
        /// Time between two snapshots
        interval: Duration,
    },
    /// Record changes in a write-ahead log
    WriteAheadLog {
        /// Path of the log file
        log_path: String,
    },
    /// Delegate to an external state store
    External {
        /// Key/value configuration of the external store
        config: HashMap<String, String>,
    },
}
238
239/// Online learning update strategy
240pub enum UpdateStrategy {
241    /// Update on every data point
242    Immediate,
243    /// Batch updates
244    Batch {
245        /// Batch size
246        batch_size: usize,
247    },
248    /// Time-based updates
249    TimeBased {
250        /// Update interval
251        interval: Duration,
252    },
253    /// Adaptive updates based on drift detection
254    Adaptive {
255        /// Drift detection threshold
256        drift_threshold: f64,
257        /// Minimum update interval
258        min_interval: Duration,
259        /// Maximum update interval
260        max_interval: Duration,
261    },
262    /// Custom update trigger
263    Custom {
264        /// Update trigger function
265        trigger_fn: Box<dyn Fn(&StreamWindow, &StreamStats) -> bool + Send + Sync>,
266    },
267}
268
/// Running statistics of a streaming pipeline.
#[derive(Debug, Clone)]
pub struct StreamStats {
    /// Number of samples processed so far
    pub total_samples: usize,
    /// Current throughput in samples per second
    pub throughput: f64,
    /// Average processing latency in milliseconds
    pub avg_latency: f64,
    /// Current buffer utilization
    pub buffer_utilization: f64,
    /// Model accuracy, when it is available
    pub accuracy: Option<f64>,
    /// Named data-drift scores
    pub drift_metrics: HashMap<String, f64>,
    /// Error rate
    pub error_rate: f64,
    /// Time at which processing started
    pub start_time: SystemTime,
    /// Time of the most recent update
    pub last_update: SystemTime,
}

impl Default for StreamStats {
    /// All counters start at zero, no accuracy or drift is recorded, and both
    /// timestamps are set to the same reading of the system clock.
    fn default() -> Self {
        let created_at = SystemTime::now();
        Self {
            start_time: created_at,
            last_update: created_at,
            total_samples: 0,
            throughput: 0.0,
            avg_latency: 0.0,
            buffer_utilization: 0.0,
            error_rate: 0.0,
            accuracy: None,
            drift_metrics: HashMap::new(),
        }
    }
}
308
309/// Streaming pipeline processor
310pub struct StreamingPipeline<S = Untrained> {
311    state: S,
312    base_estimator: Option<Box<dyn PipelinePredictor>>,
313    config: StreamConfig,
314    update_strategy: UpdateStrategy,
315    data_buffer: VecDeque<StreamDataPoint>,
316    windows: Vec<StreamWindow>,
317    statistics: StreamStats,
318}
319
320/// Trained state for `StreamingPipeline`
321pub struct StreamingPipelineTrained {
322    fitted_estimator: Box<dyn PipelinePredictor>,
323    config: StreamConfig,
324    update_strategy: UpdateStrategy,
325    data_buffer: VecDeque<StreamDataPoint>,
326    windows: Vec<StreamWindow>,
327    statistics: StreamStats,
328    model_state: HashMap<String, f64>,
329    n_features_in: usize,
330    feature_names_in: Option<Vec<String>>,
331}
332
333impl StreamingPipeline<Untrained> {
334    /// Create a new streaming pipeline
335    #[must_use]
336    pub fn new(base_estimator: Box<dyn PipelinePredictor>, config: StreamConfig) -> Self {
337        Self {
338            state: Untrained,
339            base_estimator: Some(base_estimator),
340            config,
341            update_strategy: UpdateStrategy::Batch { batch_size: 100 },
342            data_buffer: VecDeque::new(),
343            windows: Vec::new(),
344            statistics: StreamStats::default(),
345        }
346    }
347
348    /// Set update strategy
349    #[must_use]
350    pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
351        self.update_strategy = strategy;
352        self
353    }
354
355    /// Create a tumbling time window pipeline
356    #[must_use]
357    pub fn tumbling_time(
358        base_estimator: Box<dyn PipelinePredictor>,
359        window_duration: Duration,
360    ) -> Self {
361        let config = StreamConfig {
362            windowing: WindowingStrategy::TumblingTime {
363                duration: window_duration,
364            },
365            ..StreamConfig::default()
366        };
367        Self::new(base_estimator, config)
368    }
369
370    /// Create a sliding window pipeline
371    #[must_use]
372    pub fn sliding_window(
373        base_estimator: Box<dyn PipelinePredictor>,
374        window_size: usize,
375        slide_step: usize,
376    ) -> Self {
377        let config = StreamConfig {
378            windowing: WindowingStrategy::SlidingCount {
379                size: window_size,
380                step: slide_step,
381            },
382            ..StreamConfig::default()
383        };
384        Self::new(base_estimator, config)
385    }
386
387    /// Create a session window pipeline
388    #[must_use]
389    pub fn session_window(
390        base_estimator: Box<dyn PipelinePredictor>,
391        session_gap: Duration,
392    ) -> Self {
393        let config = StreamConfig {
394            windowing: WindowingStrategy::Session { gap: session_gap },
395            ..StreamConfig::default()
396        };
397        Self::new(base_estimator, config)
398    }
399}
400
401impl Estimator for StreamingPipeline<Untrained> {
402    type Config = ();
403    type Error = SklearsError;
404    type Float = Float;
405
406    fn config(&self) -> &Self::Config {
407        &()
408    }
409}
410
411impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for StreamingPipeline<Untrained> {
412    type Fitted = StreamingPipeline<StreamingPipelineTrained>;
413
414    fn fit(
415        self,
416        x: &ArrayView2<'_, Float>,
417        y: &Option<&ArrayView1<'_, Float>>,
418    ) -> SklResult<Self::Fitted> {
419        let mut base_estimator = self
420            .base_estimator
421            .ok_or_else(|| SklearsError::InvalidInput("No base estimator provided".to_string()))?;
422
423        // Initial training on batch data
424        if let Some(y_ref) = y {
425            base_estimator.fit(x, y_ref)?;
426        } else {
427            return Err(SklearsError::InvalidInput(
428                "No target values provided for initial training".to_string(),
429            ));
430        }
431
432        // Initialize streaming state
433        let mut model_state = HashMap::new();
434        model_state.insert("batch_training_samples".to_string(), x.nrows() as f64);
435
436        let mut statistics = self.statistics;
437        statistics.total_samples = x.nrows();
438        statistics.start_time = SystemTime::now();
439        statistics.last_update = SystemTime::now();
440
441        Ok(StreamingPipeline {
442            state: StreamingPipelineTrained {
443                fitted_estimator: base_estimator,
444                config: self.config,
445                update_strategy: self.update_strategy,
446                data_buffer: self.data_buffer,
447                windows: self.windows,
448                statistics,
449                model_state,
450                n_features_in: x.ncols(),
451                feature_names_in: None,
452            },
453            base_estimator: None,
454            config: StreamConfig::default(),
455            update_strategy: UpdateStrategy::Immediate,
456            data_buffer: VecDeque::new(),
457            windows: Vec::new(),
458            statistics: StreamStats::default(),
459        })
460    }
461}
462
463impl StreamingPipeline<StreamingPipelineTrained> {
464    /// Process a single data point from the stream
465    pub fn process_point(&mut self, point: StreamDataPoint) -> SklResult<Option<Array1<f64>>> {
466        let start_time = Instant::now();
467
468        // Check for backpressure
469        if self.state.data_buffer.len() >= self.state.config.backpressure_threshold {
470            return Err(SklearsError::InvalidInput(
471                "Backpressure threshold exceeded".to_string(),
472            ));
473        }
474
475        // Add to buffer
476        self.state.data_buffer.push_back(point.clone());
477
478        // Update statistics
479        self.state.statistics.total_samples += 1;
480        self.state.statistics.buffer_utilization =
481            self.state.data_buffer.len() as f64 / self.state.config.buffer_size as f64;
482
483        // Create prediction input
484        let features_2d =
485            Array2::from_shape_vec((1, point.features.len()), point.features.to_vec()).map_err(
486                |e| SklearsError::InvalidData {
487                    reason: format!("Feature reshaping failed: {e}"),
488                },
489            )?;
490
491        // Make prediction
492        let prediction = self.state.fitted_estimator.predict(&features_2d.view())?;
493
494        // Process windows
495        self.process_windows()?;
496
497        // Check for model updates
498        self.check_model_update()?;
499
500        // Update latency statistics
501        let processing_time = start_time.elapsed().as_millis() as f64;
502        self.state.statistics.avg_latency =
503            (self.state.statistics.avg_latency * 0.9) + (processing_time * 0.1);
504
505        // Update throughput
506        let elapsed = self
507            .state
508            .statistics
509            .start_time
510            .elapsed()
511            .unwrap_or(Duration::from_secs(1));
512        self.state.statistics.throughput =
513            self.state.statistics.total_samples as f64 / elapsed.as_secs_f64();
514
515        Ok(Some(prediction))
516    }
517
518    /// Process batch of data points
519    pub fn process_batch(&mut self, points: Vec<StreamDataPoint>) -> SklResult<Array2<f64>> {
520        let mut predictions = Vec::new();
521
522        for point in points {
523            if let Some(pred) = self.process_point(point)? {
524                predictions.extend(pred.iter().copied());
525            }
526        }
527
528        if predictions.is_empty() {
529            return Ok(Array2::zeros((0, 1)));
530        }
531
532        let n_predictions = predictions.len();
533        Array2::from_shape_vec((n_predictions, 1), predictions).map_err(|e| {
534            SklearsError::InvalidData {
535                reason: format!("Batch prediction reshape failed: {e}"),
536            }
537        })
538    }
539
540    /// Process windowing logic
541    fn process_windows(&mut self) -> SklResult<()> {
542        match &self.state.config.windowing {
543            WindowingStrategy::TumblingTime { duration } => {
544                self.process_tumbling_time_windows(*duration)
545            }
546            WindowingStrategy::SlidingTime { duration, slide } => {
547                self.process_sliding_time_windows(*duration, *slide)
548            }
549            WindowingStrategy::TumblingCount { count } => {
550                self.process_tumbling_count_windows(*count)
551            }
552            WindowingStrategy::SlidingCount { size, step } => {
553                self.process_sliding_count_windows(*size, *step)
554            }
555            WindowingStrategy::Session { gap } => self.process_session_windows(*gap),
556            WindowingStrategy::Custom { .. } => {
557                // Handle custom windowing differently to avoid borrow checker issues
558                self.process_custom_windows_safe()
559            }
560        }
561    }
562
563    /// Process tumbling time windows
564    fn process_tumbling_time_windows(&mut self, duration: Duration) -> SklResult<()> {
565        let now = SystemTime::now();
566
567        // Create new window if needed
568        if self.state.windows.is_empty() {
569            let window = StreamWindow::new(now, now + duration);
570            self.state.windows.push(window);
571        }
572
573        // Add points to current window
574        while let Some(point) = self.state.data_buffer.pop_front() {
575            if let Some(current_window) = self.state.windows.last_mut() {
576                if point.timestamp <= current_window.end_time {
577                    current_window.add_point(point);
578                } else {
579                    // Create new window
580                    let mut new_window = StreamWindow::new(
581                        current_window.end_time,
582                        current_window.end_time + duration,
583                    );
584                    new_window.add_point(point);
585                    self.state.windows.push(new_window);
586                }
587            }
588        }
589
590        // Remove completed windows (keep only current)
591        self.state.windows.retain(|w| w.end_time > now);
592
593        Ok(())
594    }
595
596    /// Process sliding time windows
597    fn process_sliding_time_windows(
598        &mut self,
599        duration: Duration,
600        slide: Duration,
601    ) -> SklResult<()> {
602        // Simplified implementation
603        self.process_tumbling_time_windows(duration)
604    }
605
606    /// Process tumbling count windows
607    fn process_tumbling_count_windows(&mut self, count: usize) -> SklResult<()> {
608        let now = SystemTime::now();
609
610        while self.state.data_buffer.len() >= count {
611            let mut window = StreamWindow::new(now, now);
612            for _ in 0..count {
613                if let Some(point) = self.state.data_buffer.pop_front() {
614                    window.add_point(point);
615                }
616            }
617            self.state.windows.push(window);
618        }
619
620        Ok(())
621    }
622
623    /// Process sliding count windows
624    fn process_sliding_count_windows(&mut self, size: usize, step: usize) -> SklResult<()> {
625        // Simplified implementation - just use tumbling for now
626        self.process_tumbling_count_windows(step)
627    }
628
629    /// Process session windows
630    fn process_session_windows(&mut self, gap: Duration) -> SklResult<()> {
631        // Simplified implementation
632        let now = SystemTime::now();
633
634        if let Some(mut current_window) = self.state.windows.pop() {
635            while let Some(point) = self.state.data_buffer.pop_front() {
636                let time_since_last = point
637                    .timestamp
638                    .duration_since(current_window.end_time)
639                    .unwrap_or(Duration::ZERO);
640
641                if time_since_last <= gap {
642                    current_window.add_point(point.clone());
643                    current_window.end_time = point.timestamp;
644                } else {
645                    // Start new session
646                    self.state.windows.push(current_window);
647                    current_window = StreamWindow::new(point.timestamp, point.timestamp);
648                    current_window.add_point(point);
649                }
650            }
651            self.state.windows.push(current_window);
652        } else if !self.state.data_buffer.is_empty() {
653            // Start first session
654            if let Some(point) = self.state.data_buffer.pop_front() {
655                let mut window = StreamWindow::new(point.timestamp, point.timestamp);
656                window.add_point(point);
657                self.state.windows.push(window);
658            }
659        }
660
661        Ok(())
662    }
663
664    /// Process custom windows safely (avoiding borrow checker issues)
665    fn process_custom_windows_safe(&mut self) -> SklResult<()> {
666        // Extract trigger function to avoid borrowing issues
667        if let WindowingStrategy::Custom { trigger_fn } = &self.state.config.windowing {
668            let buffer_vec: Vec<StreamDataPoint> = self.state.data_buffer.iter().cloned().collect();
669
670            if trigger_fn(&buffer_vec) {
671                let now = SystemTime::now();
672                let mut window = StreamWindow::new(now, now);
673
674                while let Some(point) = self.state.data_buffer.pop_front() {
675                    window.add_point(point);
676                }
677
678                if !window.is_empty() {
679                    self.state.windows.push(window);
680                }
681            }
682        }
683
684        Ok(())
685    }
686
687    /// Process custom windows
688    fn process_custom_windows(
689        &mut self,
690        trigger_fn: &Box<dyn Fn(&[StreamDataPoint]) -> bool + Send + Sync>,
691    ) -> SklResult<()> {
692        let buffer_vec: Vec<StreamDataPoint> = self.state.data_buffer.iter().cloned().collect();
693
694        if trigger_fn(&buffer_vec) {
695            let now = SystemTime::now();
696            let mut window = StreamWindow::new(now, now);
697
698            while let Some(point) = self.state.data_buffer.pop_front() {
699                window.add_point(point);
700            }
701
702            if !window.is_empty() {
703                self.state.windows.push(window);
704            }
705        }
706
707        Ok(())
708    }
709
710    /// Check if model should be updated
711    fn check_model_update(&mut self) -> SklResult<()> {
712        let should_update = match &self.state.update_strategy {
713            UpdateStrategy::Immediate => !self.state.data_buffer.is_empty(),
714            UpdateStrategy::Batch { batch_size } => self.state.data_buffer.len() >= *batch_size,
715            UpdateStrategy::TimeBased { interval } => {
716                self.state
717                    .statistics
718                    .last_update
719                    .elapsed()
720                    .unwrap_or(Duration::ZERO)
721                    >= *interval
722            }
723            UpdateStrategy::Adaptive {
724                drift_threshold,
725                min_interval,
726                max_interval,
727            } => self.check_adaptive_update(*drift_threshold, *min_interval, *max_interval),
728            UpdateStrategy::Custom { trigger_fn } => {
729                if let Some(window) = self.state.windows.last() {
730                    trigger_fn(window, &self.state.statistics)
731                } else {
732                    false
733                }
734            }
735        };
736
737        if should_update {
738            self.update_model()?;
739        }
740
741        Ok(())
742    }
743
744    /// Check if adaptive update should be triggered
745    fn check_adaptive_update(
746        &self,
747        drift_threshold: f64,
748        min_interval: Duration,
749        max_interval: Duration,
750    ) -> bool {
751        let elapsed = self
752            .state
753            .statistics
754            .last_update
755            .elapsed()
756            .unwrap_or(Duration::ZERO);
757
758        if elapsed < min_interval {
759            return false;
760        }
761
762        if elapsed >= max_interval {
763            return true;
764        }
765
766        // Check for drift (simplified)
767        let drift_score = self
768            .state
769            .statistics
770            .drift_metrics
771            .get("feature_drift")
772            .unwrap_or(&0.0);
773        *drift_score > drift_threshold
774    }
775
776    /// Update the model with recent data
777    fn update_model(&mut self) -> SklResult<()> {
778        if let Some(window) = self.state.windows.last() {
779            if !window.is_empty() {
780                let features = window.features_matrix()?;
781                let targets = window.targets_array();
782
783                if let Some(targets_array) = targets {
784                    // Incremental learning (simplified)
785                    self.state
786                        .fitted_estimator
787                        .fit(&features.view(), &targets_array.view())?;
788
789                    self.state.statistics.last_update = SystemTime::now();
790                    self.state
791                        .model_state
792                        .insert("last_update_samples".to_string(), window.size() as f64);
793                }
794            }
795        }
796
797        Ok(())
798    }
799
800    /// Get current statistics
801    #[must_use]
802    pub fn statistics(&self) -> &StreamStats {
803        &self.state.statistics
804    }
805
806    /// Get current buffer size
807    #[must_use]
808    pub fn buffer_size(&self) -> usize {
809        self.state.data_buffer.len()
810    }
811
812    /// Get number of active windows
813    #[must_use]
814    pub fn active_windows(&self) -> usize {
815        self.state.windows.len()
816    }
817
818    /// Checkpoint the current state
819    pub fn checkpoint(&self) -> SklResult<HashMap<String, String>> {
820        let mut checkpoint = HashMap::new();
821        checkpoint.insert(
822            "total_samples".to_string(),
823            self.state.statistics.total_samples.to_string(),
824        );
825        checkpoint.insert(
826            "buffer_size".to_string(),
827            self.state.data_buffer.len().to_string(),
828        );
829        checkpoint.insert(
830            "active_windows".to_string(),
831            self.state.windows.len().to_string(),
832        );
833        checkpoint.insert(
834            "throughput".to_string(),
835            self.state.statistics.throughput.to_string(),
836        );
837
838        Ok(checkpoint)
839    }
840
841    /// Clear internal buffers and windows
842    pub fn clear_buffers(&mut self) {
843        self.state.data_buffer.clear();
844        self.state.windows.clear();
845    }
846
847    /// Get drift detection metrics
848    #[must_use]
849    pub fn drift_metrics(&self) -> &HashMap<String, f64> {
850        &self.state.statistics.drift_metrics
851    }
852
853    /// Detect concept drift (simplified implementation)
854    pub fn detect_drift(
855        &mut self,
856        reference_window: &StreamWindow,
857        current_window: &StreamWindow,
858    ) -> SklResult<f64> {
859        if reference_window.is_empty() || current_window.is_empty() {
860            return Ok(0.0);
861        }
862
863        let ref_features = reference_window.features_matrix()?;
864        let cur_features = current_window.features_matrix()?;
865
866        // Simple drift detection using mean difference
867        let ref_mean = ref_features.mean_axis(Axis(0)).unwrap_or_default();
868        let cur_mean = cur_features.mean_axis(Axis(0)).unwrap_or_default();
869
870        let drift_score = (&ref_mean - &cur_mean).mapv(|x| x * x).sum().sqrt();
871
872        // Update drift metrics
873        self.state
874            .statistics
875            .drift_metrics
876            .insert("feature_drift".to_string(), drift_score);
877
878        Ok(drift_score)
879    }
880}
881
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    /// Untrained pipeline with one-minute tumbling time windows around a mock estimator.
    fn tumbling_pipeline() -> StreamingPipeline<Untrained> {
        StreamingPipeline::tumbling_time(Box::new(MockPredictor::new()), Duration::from_secs(60))
    }

    /// Two-sample, two-feature training batch shared by the fitting tests.
    fn training_batch() -> (Array2<f64>, Array1<f64>) {
        (array![[1.0, 2.0], [3.0, 4.0]], array![1.0, 0.0])
    }

    #[test]
    fn test_stream_data_point() {
        let values = array![1.0, 2.0, 3.0];
        let data_point =
            StreamDataPoint::new(values.clone(), "test_point".to_string()).with_target(1.0);

        assert_eq!(data_point.id, "test_point");
        assert_eq!(data_point.features, values);
        assert_eq!(data_point.target, Some(1.0));
    }

    #[test]
    fn test_stream_window() {
        let opened = SystemTime::now();
        let closed = opened + Duration::from_secs(60);
        let mut window = StreamWindow::new(opened, closed);

        window.add_point(StreamDataPoint::new(array![1.0, 2.0], "point1".to_string()));
        window.add_point(StreamDataPoint::new(array![3.0, 4.0], "point2".to_string()));

        assert_eq!(window.size(), 2);

        let matrix = window.features_matrix().unwrap_or_default();
        assert_eq!(matrix.nrows(), 2);
        assert_eq!(matrix.ncols(), 2);
    }

    #[test]
    fn test_streaming_pipeline_creation() {
        let pipeline = tumbling_pipeline();

        assert!(matches!(
            pipeline.config.windowing,
            WindowingStrategy::TumblingTime { .. }
        ));
    }

    #[test]
    fn test_streaming_pipeline_fit() {
        let (x, y) = training_batch();

        let fitted = tumbling_pipeline()
            .fit(&x.view(), &Some(&y.view()))
            .expect("operation should succeed");

        assert_eq!(fitted.state.n_features_in, 2);
        assert_eq!(fitted.state.statistics.total_samples, 2);
    }

    #[test]
    fn test_point_processing() {
        let (x, y) = training_batch();

        let mut fitted = tumbling_pipeline()
            .fit(&x.view(), &Some(&y.view()))
            .expect("operation should succeed");

        let incoming = StreamDataPoint::new(array![5.0, 6.0], "test_point".to_string());
        let prediction = fitted.process_point(incoming).unwrap_or_default();

        assert!(prediction.is_some());
        assert_eq!(fitted.active_windows(), 1);
    }

    #[test]
    fn test_window_strategies() {
        // Test tumbling count windows
        let config = StreamConfig {
            windowing: WindowingStrategy::TumblingCount { count: 2 },
            ..StreamConfig::default()
        };
        let pipeline = StreamingPipeline::new(Box::new(MockPredictor::new()), config);

        assert!(matches!(
            pipeline.config.windowing,
            WindowingStrategy::TumblingCount { count: 2 }
        ));
    }

    #[test]
    fn test_update_strategies() {
        let pipeline =
            tumbling_pipeline().update_strategy(UpdateStrategy::Batch { batch_size: 10 });

        assert!(matches!(
            pipeline.update_strategy,
            UpdateStrategy::Batch { batch_size: 10 }
        ));
    }
}