sklears_multioutput/
streaming.rs

1//! Streaming and Incremental Learning for Multi-Output Prediction
2//!
3//! This module provides algorithms for learning from streaming data with multiple outputs,
4//! including incremental learning, online learning, and concept drift detection.
5
6// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
8use scirs2_core::random::Rng;
9use sklears_core::{
10    error::{Result as SklResult, SklearsError},
11    traits::{Estimator, Fit, Predict, Untrained},
12    types::Float,
13};
14use std::collections::VecDeque;
15
16// ============================================================================
17// Incremental Multi-Output Regression
18// ============================================================================
19
20/// Configuration for incremental multi-output regression
21#[derive(Debug, Clone)]
22pub struct IncrementalMultiOutputRegressionConfig {
23    /// Learning rate for gradient updates
24    pub learning_rate: Float,
25    /// L2 regularization parameter
26    pub alpha: Float,
27    /// Whether to fit intercept
28    pub fit_intercept: bool,
29    /// Maximum number of samples to keep in memory (for computing statistics)
30    pub max_samples: usize,
31    /// Whether to use adaptive learning rate
32    pub adaptive_learning_rate: bool,
33    /// Decay factor for learning rate
34    pub learning_rate_decay: Float,
35}
36
37impl Default for IncrementalMultiOutputRegressionConfig {
38    fn default() -> Self {
39        Self {
40            learning_rate: 0.01,
41            alpha: 0.0001,
42            fit_intercept: true,
43            max_samples: 10000,
44            adaptive_learning_rate: true,
45            learning_rate_decay: 0.999,
46        }
47    }
48}
49
50/// Incremental Multi-Output Regressor
51///
52/// Online learning algorithm that can learn from data streams with multiple outputs.
53/// Uses stochastic gradient descent with optional adaptive learning rates.
54///
55/// # Examples
56///
57/// ```rust
58/// use sklears_multioutput::streaming::{IncrementalMultiOutputRegression, IncrementalMultiOutputRegressionConfig};
59/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
60/// use scirs2_core::ndarray::array;
61/// use sklears_core::traits::{Fit, Predict};
62///
63/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
64/// let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
65///
66/// let mut model = IncrementalMultiOutputRegression::new();
67/// let trained = model.fit(&X.view(), &y.view()).unwrap();
68///
69/// // Continue learning with new data
70/// let X_new = array![[4.0, 5.0]];
71/// let y_new = array![[4.0, 5.0]];
72/// let updated = trained.partial_fit(&X_new.view(), &y_new.view()).unwrap();
73///
74/// let predictions = updated.predict(&X.view()).unwrap();
75/// assert_eq!(predictions.dim(), (3, 2));
76/// ```
77#[derive(Debug, Clone)]
78pub struct IncrementalMultiOutputRegression<S = Untrained> {
79    state: S,
80    config: IncrementalMultiOutputRegressionConfig,
81}
82
83/// Trained state for Incremental Multi-Output Regression
84#[derive(Debug, Clone)]
85pub struct IncrementalMultiOutputRegressionTrained {
86    /// Coefficient matrix (n_features x n_outputs)
87    pub coef: Array2<Float>,
88    /// Intercept vector (n_outputs)
89    pub intercept: Array1<Float>,
90    /// Number of features
91    pub n_features: usize,
92    /// Number of outputs
93    pub n_outputs: usize,
94    /// Number of samples seen so far
95    pub n_samples_seen: usize,
96    /// Current learning rate
97    pub current_learning_rate: Float,
98    /// Running mean of features (for normalization)
99    pub feature_mean: Array1<Float>,
100    /// Running std of features (for normalization)
101    pub feature_std: Array1<Float>,
102    /// Configuration
103    pub config: IncrementalMultiOutputRegressionConfig,
104}
105
106impl IncrementalMultiOutputRegression<Untrained> {
107    /// Create a new incremental multi-output regressor
108    pub fn new() -> Self {
109        Self {
110            state: Untrained,
111            config: IncrementalMultiOutputRegressionConfig::default(),
112        }
113    }
114
115    /// Set the configuration
116    pub fn config(mut self, config: IncrementalMultiOutputRegressionConfig) -> Self {
117        self.config = config;
118        self
119    }
120
121    /// Set the learning rate
122    pub fn learning_rate(mut self, lr: Float) -> Self {
123        self.config.learning_rate = lr;
124        self
125    }
126
127    /// Set the regularization parameter
128    pub fn alpha(mut self, alpha: Float) -> Self {
129        self.config.alpha = alpha;
130        self
131    }
132
133    /// Set whether to fit intercept
134    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
135        self.config.fit_intercept = fit_intercept;
136        self
137    }
138}
139
140impl Default for IncrementalMultiOutputRegression<Untrained> {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>>
147    for IncrementalMultiOutputRegression<Untrained>
148{
149    type Fitted = IncrementalMultiOutputRegression<IncrementalMultiOutputRegressionTrained>;
150
151    fn fit(self, X: &ArrayView2<Float>, y: &ArrayView2<Float>) -> SklResult<Self::Fitted> {
152        if X.nrows() != y.nrows() {
153            return Err(SklearsError::InvalidInput(
154                "Number of samples in X and y must match".to_string(),
155            ));
156        }
157
158        let n_samples = X.nrows();
159        let n_features = X.ncols();
160        let n_outputs = y.ncols();
161
162        // Initialize coefficients
163        let mut coef = Array2::zeros((n_features, n_outputs));
164        let mut intercept = Array1::zeros(n_outputs);
165
166        // Compute feature statistics
167        let feature_mean = X.mean_axis(Axis(0)).unwrap();
168        let feature_std = X.std_axis(Axis(0), 0.0);
169
170        let mut current_learning_rate = self.config.learning_rate;
171
172        // Perform initial gradient descent over the batch
173        for _ in 0..10 {
174            // Mini-batch iterations
175            for i in 0..n_samples {
176                let x_i = X.row(i);
177                let y_i = y.row(i);
178
179                // Prediction
180                let pred = coef.t().dot(&x_i) + &intercept;
181
182                // Error
183                let error = &y_i - &pred;
184
185                // Update coefficients using gradient descent
186                for j in 0..n_features {
187                    for k in 0..n_outputs {
188                        let gradient = -error[k] * x_i[j] + self.config.alpha * coef[[j, k]];
189                        coef[[j, k]] -= current_learning_rate * gradient;
190                    }
191                }
192
193                // Update intercept
194                if self.config.fit_intercept {
195                    for k in 0..n_outputs {
196                        intercept[k] += current_learning_rate * error[k];
197                    }
198                }
199            }
200
201            // Decay learning rate
202            if self.config.adaptive_learning_rate {
203                current_learning_rate *= self.config.learning_rate_decay;
204            }
205        }
206
207        Ok(IncrementalMultiOutputRegression {
208            state: IncrementalMultiOutputRegressionTrained {
209                coef,
210                intercept,
211                n_features,
212                n_outputs,
213                n_samples_seen: n_samples,
214                current_learning_rate,
215                feature_mean,
216                feature_std,
217                config: self.config,
218            },
219            config: IncrementalMultiOutputRegressionConfig::default(),
220        })
221    }
222}
223
224impl IncrementalMultiOutputRegression<IncrementalMultiOutputRegressionTrained> {
225    /// Partial fit on new data (incremental learning)
226    pub fn partial_fit(mut self, X: &ArrayView2<Float>, y: &ArrayView2<Float>) -> SklResult<Self> {
227        if X.nrows() != y.nrows() {
228            return Err(SklearsError::InvalidInput(
229                "Number of samples in X and y must match".to_string(),
230            ));
231        }
232
233        if X.ncols() != self.state.n_features {
234            return Err(SklearsError::InvalidInput(format!(
235                "Expected {} features, got {}",
236                self.state.n_features,
237                X.ncols()
238            )));
239        }
240
241        if y.ncols() != self.state.n_outputs {
242            return Err(SklearsError::InvalidInput(format!(
243                "Expected {} outputs, got {}",
244                self.state.n_outputs,
245                y.ncols()
246            )));
247        }
248
249        let n_samples = X.nrows();
250
251        // Update feature statistics (running average)
252        let n_old = self.state.n_samples_seen as Float;
253        let n_new = n_samples as Float;
254        let n_total = n_old + n_new;
255
256        let new_mean = X.mean_axis(Axis(0)).unwrap();
257        self.state.feature_mean = (&self.state.feature_mean * n_old + &new_mean * n_new) / n_total;
258
259        // Perform incremental updates
260        for i in 0..n_samples {
261            let x_i = X.row(i);
262            let y_i = y.row(i);
263
264            // Prediction
265            let pred = self.state.coef.t().dot(&x_i) + &self.state.intercept;
266
267            // Error
268            let error = &y_i - &pred;
269
270            // Update coefficients
271            for j in 0..self.state.n_features {
272                for k in 0..self.state.n_outputs {
273                    let gradient =
274                        -error[k] * x_i[j] + self.state.config.alpha * self.state.coef[[j, k]];
275                    self.state.coef[[j, k]] -= self.state.current_learning_rate * gradient;
276                }
277            }
278
279            // Update intercept
280            if self.state.config.fit_intercept {
281                for k in 0..self.state.n_outputs {
282                    self.state.intercept[k] += self.state.current_learning_rate * error[k];
283                }
284            }
285        }
286
287        // Update statistics
288        self.state.n_samples_seen += n_samples;
289
290        // Decay learning rate
291        if self.state.config.adaptive_learning_rate {
292            self.state.current_learning_rate *= self.state.config.learning_rate_decay;
293        }
294
295        Ok(self)
296    }
297
298    /// Get the current coefficients
299    pub fn coef(&self) -> &Array2<Float> {
300        &self.state.coef
301    }
302
303    /// Get the current intercept
304    pub fn intercept(&self) -> &Array1<Float> {
305        &self.state.intercept
306    }
307
308    /// Get number of samples seen
309    pub fn n_samples_seen(&self) -> usize {
310        self.state.n_samples_seen
311    }
312
313    /// Get current learning rate
314    pub fn current_learning_rate(&self) -> Float {
315        self.state.current_learning_rate
316    }
317}
318
319impl Predict<ArrayView2<'_, Float>, Array2<Float>>
320    for IncrementalMultiOutputRegression<IncrementalMultiOutputRegressionTrained>
321{
322    fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
323        if X.ncols() != self.state.n_features {
324            return Err(SklearsError::InvalidInput(format!(
325                "Expected {} features, got {}",
326                self.state.n_features,
327                X.ncols()
328            )));
329        }
330
331        let n_samples = X.nrows();
332        let mut predictions = Array2::zeros((n_samples, self.state.n_outputs));
333
334        for i in 0..n_samples {
335            let x_i = X.row(i);
336            let pred = self.state.coef.t().dot(&x_i) + &self.state.intercept;
337            predictions.row_mut(i).assign(&pred);
338        }
339
340        Ok(predictions)
341    }
342}
343
344impl Estimator for IncrementalMultiOutputRegression<Untrained> {
345    type Config = IncrementalMultiOutputRegressionConfig;
346    type Error = SklearsError;
347    type Float = Float;
348
349    fn config(&self) -> &Self::Config {
350        &self.config
351    }
352}
353
354impl Estimator for IncrementalMultiOutputRegression<IncrementalMultiOutputRegressionTrained> {
355    type Config = IncrementalMultiOutputRegressionConfig;
356    type Error = SklearsError;
357    type Float = Float;
358
359    fn config(&self) -> &Self::Config {
360        &self.state.config
361    }
362}
363
364// ============================================================================
365// Streaming Multi-Output with Mini-Batches
366// ============================================================================
367
368/// Configuration for streaming multi-output learning
369#[derive(Debug, Clone)]
370pub struct StreamingMultiOutputConfig {
371    /// Mini-batch size for streaming updates
372    pub batch_size: usize,
373    /// Maximum buffer size before forced update
374    pub max_buffer_size: usize,
375    /// Learning rate
376    pub learning_rate: Float,
377    /// Whether to detect concept drift
378    pub detect_drift: bool,
379    /// Window size for drift detection
380    pub drift_window_size: usize,
381    /// Threshold for drift detection
382    pub drift_threshold: Float,
383}
384
385impl Default for StreamingMultiOutputConfig {
386    fn default() -> Self {
387        Self {
388            batch_size: 32,
389            max_buffer_size: 1000,
390            learning_rate: 0.01,
391            detect_drift: true,
392            drift_window_size: 100,
393            drift_threshold: 0.1,
394        }
395    }
396}
397
398/// Streaming Multi-Output Learner
399///
400/// Handles streaming data with mini-batch processing and concept drift detection.
401///
402/// # Examples
403///
404/// ```rust
405/// use sklears_multioutput::streaming::{StreamingMultiOutput, StreamingMultiOutputConfig};
406/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
407/// use scirs2_core::ndarray::array;
408/// use sklears_core::traits::{Fit, Predict};
409///
410/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
411/// let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
412///
413/// let mut model = StreamingMultiOutput::new()
414///     .batch_size(2)
415///     .learning_rate(0.1);
416///
417/// let trained = model.fit(&X.view(), &y.view()).unwrap();
418///
419/// // Add streaming data
420/// let X_stream = array![[4.0, 5.0]];
421/// let y_stream = array![[4.0, 5.0]];
422/// let updated = trained.update_stream(&X_stream.view(), &y_stream.view()).unwrap();
423///
424/// let predictions = updated.predict(&X.view()).unwrap();
425/// assert_eq!(predictions.dim(), (3, 2));
426/// ```
427#[derive(Debug, Clone)]
428pub struct StreamingMultiOutput<S = Untrained> {
429    state: S,
430    config: StreamingMultiOutputConfig,
431}
432
433/// Trained state for Streaming Multi-Output
434#[derive(Debug, Clone)]
435pub struct StreamingMultiOutputTrained {
436    /// Base incremental model
437    pub base_model: IncrementalMultiOutputRegressionTrained,
438    /// Buffer for mini-batch processing
439    pub buffer_X: VecDeque<Array1<Float>>,
440    pub buffer_y: VecDeque<Array1<Float>>,
441    /// Performance history for drift detection
442    pub error_history: VecDeque<Float>,
443    /// Whether drift was detected
444    pub drift_detected: bool,
445    /// Number of drift events detected
446    pub n_drift_events: usize,
447    /// Configuration
448    pub config: StreamingMultiOutputConfig,
449}
450
451impl StreamingMultiOutput<Untrained> {
452    /// Create a new streaming multi-output learner
453    pub fn new() -> Self {
454        Self {
455            state: Untrained,
456            config: StreamingMultiOutputConfig::default(),
457        }
458    }
459
460    /// Set the configuration
461    pub fn config(mut self, config: StreamingMultiOutputConfig) -> Self {
462        self.config = config;
463        self
464    }
465
466    /// Set the batch size
467    pub fn batch_size(mut self, batch_size: usize) -> Self {
468        self.config.batch_size = batch_size;
469        self
470    }
471
472    /// Set the learning rate
473    pub fn learning_rate(mut self, lr: Float) -> Self {
474        self.config.learning_rate = lr;
475        self
476    }
477
478    /// Enable/disable drift detection
479    pub fn detect_drift(mut self, detect: bool) -> Self {
480        self.config.detect_drift = detect;
481        self
482    }
483}
484
485impl Default for StreamingMultiOutput<Untrained> {
486    fn default() -> Self {
487        Self::new()
488    }
489}
490
491impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for StreamingMultiOutput<Untrained> {
492    type Fitted = StreamingMultiOutput<StreamingMultiOutputTrained>;
493
494    fn fit(self, X: &ArrayView2<Float>, y: &ArrayView2<Float>) -> SklResult<Self::Fitted> {
495        // Initialize base model
496        let base_config = IncrementalMultiOutputRegressionConfig {
497            learning_rate: self.config.learning_rate,
498            ..Default::default()
499        };
500
501        let base_model = IncrementalMultiOutputRegression::new()
502            .config(base_config)
503            .fit(X, y)?;
504
505        Ok(StreamingMultiOutput {
506            state: StreamingMultiOutputTrained {
507                base_model: base_model.state,
508                buffer_X: VecDeque::new(),
509                buffer_y: VecDeque::new(),
510                error_history: VecDeque::new(),
511                drift_detected: false,
512                n_drift_events: 0,
513                config: self.config,
514            },
515            config: StreamingMultiOutputConfig::default(),
516        })
517    }
518}
519
520impl StreamingMultiOutput<StreamingMultiOutputTrained> {
521    /// Update with streaming data
522    pub fn update_stream(
523        mut self,
524        X: &ArrayView2<Float>,
525        y: &ArrayView2<Float>,
526    ) -> SklResult<Self> {
527        // Add to buffer
528        for i in 0..X.nrows() {
529            self.state.buffer_X.push_back(X.row(i).to_owned());
530            self.state.buffer_y.push_back(y.row(i).to_owned());
531        }
532
533        // Process if buffer is full
534        if self.state.buffer_X.len() >= self.state.config.batch_size {
535            self = self.process_buffer()?;
536        }
537
538        Ok(self)
539    }
540
541    /// Process the current buffer
542    fn process_buffer(mut self) -> SklResult<Self> {
543        let batch_size = self.state.config.batch_size.min(self.state.buffer_X.len());
544
545        if batch_size == 0 {
546            return Ok(self);
547        }
548
549        // Extract batch from buffer
550        let mut X_batch = Array2::zeros((batch_size, self.state.base_model.n_features));
551        let mut y_batch = Array2::zeros((batch_size, self.state.base_model.n_outputs));
552
553        for i in 0..batch_size {
554            let x = self.state.buffer_X.pop_front().unwrap();
555            let y = self.state.buffer_y.pop_front().unwrap();
556            X_batch.row_mut(i).assign(&x);
557            y_batch.row_mut(i).assign(&y);
558        }
559
560        // Detect drift if enabled
561        if self.state.config.detect_drift {
562            let pred = self.predict(&X_batch.view())?;
563            let error: Float = (&y_batch - &pred).mapv(|x| x.powi(2)).mean().unwrap();
564
565            self.state.error_history.push_back(error);
566            if self.state.error_history.len() > self.state.config.drift_window_size {
567                self.state.error_history.pop_front();
568            }
569
570            // Check for drift
571            if self.state.error_history.len() >= self.state.config.drift_window_size {
572                let recent_error: Float = self
573                    .state
574                    .error_history
575                    .iter()
576                    .rev()
577                    .take(self.state.config.drift_window_size / 2)
578                    .sum::<Float>()
579                    / (self.state.config.drift_window_size / 2) as Float;
580
581                let old_error: Float = self
582                    .state
583                    .error_history
584                    .iter()
585                    .take(self.state.config.drift_window_size / 2)
586                    .sum::<Float>()
587                    / (self.state.config.drift_window_size / 2) as Float;
588
589                if recent_error > old_error * (1.0 + self.state.config.drift_threshold) {
590                    self.state.drift_detected = true;
591                    self.state.n_drift_events += 1;
592                    // Could reset model here if needed
593                }
594            }
595        }
596
597        // Update base model
598        let base_wrapper = IncrementalMultiOutputRegression {
599            state: self.state.base_model.clone(),
600            config: IncrementalMultiOutputRegressionConfig::default(),
601        };
602
603        let updated = base_wrapper.partial_fit(&X_batch.view(), &y_batch.view())?;
604        self.state.base_model = updated.state;
605
606        Ok(self)
607    }
608
609    /// Force processing of remaining buffer
610    pub fn flush_buffer(mut self) -> SklResult<Self> {
611        while !self.state.buffer_X.is_empty() {
612            self = self.process_buffer()?;
613        }
614        Ok(self)
615    }
616
617    /// Check if drift was detected
618    pub fn drift_detected(&self) -> bool {
619        self.state.drift_detected
620    }
621
622    /// Get number of drift events
623    pub fn n_drift_events(&self) -> usize {
624        self.state.n_drift_events
625    }
626
627    /// Get buffer size
628    pub fn buffer_size(&self) -> usize {
629        self.state.buffer_X.len()
630    }
631}
632
633impl Predict<ArrayView2<'_, Float>, Array2<Float>>
634    for StreamingMultiOutput<StreamingMultiOutputTrained>
635{
636    fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
637        let base_wrapper = IncrementalMultiOutputRegression {
638            state: self.state.base_model.clone(),
639            config: IncrementalMultiOutputRegressionConfig::default(),
640        };
641        base_wrapper.predict(X)
642    }
643}
644
645impl Estimator for StreamingMultiOutput<Untrained> {
646    type Config = StreamingMultiOutputConfig;
647    type Error = SklearsError;
648    type Float = Float;
649
650    fn config(&self) -> &Self::Config {
651        &self.config
652    }
653}
654
655impl Estimator for StreamingMultiOutput<StreamingMultiOutputTrained> {
656    type Config = StreamingMultiOutputConfig;
657    type Error = SklearsError;
658    type Float = Float;
659
660    fn config(&self) -> &Self::Config {
661        &self.state.config
662    }
663}
664
665// ============================================================================
666// Tests
667// ============================================================================
668
#[cfg(test)]
mod tests {
    use super::*;
    // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
    use scirs2_core::ndarray::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_incremental_regression_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let model = IncrementalMultiOutputRegression::new()
            .learning_rate(0.1)
            .alpha(0.0001);

        let trained = model.fit(&X.view(), &y.view()).unwrap();
        let predictions = trained.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (3, 2));
        assert_eq!(trained.n_samples_seen(), 3);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_incremental_regression_partial_fit() {
        let X1 = array![[1.0, 2.0], [2.0, 3.0]];
        let y1 = array![[1.0, 2.0], [2.0, 3.0]];

        let model = IncrementalMultiOutputRegression::new().learning_rate(0.1);
        let trained = model.fit(&X1.view(), &y1.view()).unwrap();

        // Partial fit with new data
        let X2 = array![[3.0, 4.0], [4.0, 5.0]];
        let y2 = array![[3.0, 4.0], [4.0, 5.0]];
        let updated = trained.partial_fit(&X2.view(), &y2.view()).unwrap();

        assert_eq!(updated.n_samples_seen(), 4);

        let predictions = updated.predict(&X2.view()).unwrap();
        assert_eq!(predictions.dim(), (2, 2));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_incremental_regression_partial_fit_shape_errors() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0]];
        let trained = IncrementalMultiOutputRegression::new()
            .fit(&X.view(), &y.view())
            .unwrap();

        // Wrong number of features
        let X_bad = array![[1.0, 2.0, 3.0]];
        let y_ok = array![[1.0, 2.0]];
        assert!(trained
            .clone()
            .partial_fit(&X_bad.view(), &y_ok.view())
            .is_err());

        // Wrong number of outputs
        let X_ok = array![[1.0, 2.0]];
        let y_bad = array![[1.0]];
        assert!(trained.partial_fit(&X_ok.view(), &y_bad.view()).is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_incremental_regression_learning_rate_decay() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0]];

        let model = IncrementalMultiOutputRegression::new().learning_rate(0.1);
        let trained = model.fit(&X.view(), &y.view()).unwrap();

        let initial_lr = trained.current_learning_rate();

        // Partial fit should decay learning rate
        let X2 = array![[3.0, 4.0]];
        let y2 = array![[3.0, 4.0]];
        let updated = trained.partial_fit(&X2.view(), &y2.view()).unwrap();

        assert!(updated.current_learning_rate() < initial_lr);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let model = StreamingMultiOutput::new().batch_size(2).learning_rate(0.1);

        let trained = model.fit(&X.view(), &y.view()).unwrap();
        let predictions = trained.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (3, 2));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_update() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0]];

        let model = StreamingMultiOutput::new().batch_size(2);
        let trained = model.fit(&X.view(), &y.view()).unwrap();

        // Stream new data
        let X_stream = array![[3.0, 4.0], [4.0, 5.0]];
        let y_stream = array![[3.0, 4.0], [4.0, 5.0]];
        let updated = trained
            .update_stream(&X_stream.view(), &y_stream.view())
            .unwrap();

        let predictions = updated.predict(&X_stream.view()).unwrap();
        assert_eq!(predictions.dim(), (2, 2));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_buffer() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0]];

        let model = StreamingMultiOutput::new().batch_size(5); // Large batch size
        let trained = model.fit(&X.view(), &y.view()).unwrap();

        // Add small amount of data (should buffer)
        let X_stream = array![[3.0, 4.0]];
        let y_stream = array![[3.0, 4.0]];
        let updated = trained
            .update_stream(&X_stream.view(), &y_stream.view())
            .unwrap();

        assert_eq!(updated.buffer_size(), 1);

        // Flush buffer
        let flushed = updated.flush_buffer().unwrap();
        assert_eq!(flushed.buffer_size(), 0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_drift_detection() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let model = StreamingMultiOutput::new()
            .batch_size(2)
            .detect_drift(true)
            .learning_rate(0.1);

        let trained = model.fit(&X.view(), &y.view()).unwrap();
        assert_eq!(trained.n_drift_events(), 0);
        assert!(!trained.drift_detected());

        // One full batch passes through the drift detector. The default 100-entry window is
        // far from full, so no drift can be reported yet.
        let X_stream = array![[4.0, 5.0], [5.0, 6.0]];
        let y_stream = array![[4.0, 5.0], [5.0, 6.0]];
        let updated = trained
            .update_stream(&X_stream.view(), &y_stream.view())
            .unwrap();

        assert_eq!(updated.buffer_size(), 0);
        assert_eq!(updated.n_drift_events(), 0);
        assert!(!updated.drift_detected());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_incremental_regression_error_handling() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]; // Mismatched

        let model = IncrementalMultiOutputRegression::new();
        assert!(model.fit(&X.view(), &y.view()).is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_incremental_regression_prediction_error() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0]];

        let model = IncrementalMultiOutputRegression::new();
        let trained = model.fit(&X.view(), &y.view()).unwrap();

        // Wrong number of features
        let X_test = array![[1.0]];
        assert!(trained.predict(&X_test.view()).is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_prediction_error() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0]];

        let trained = StreamingMultiOutput::new()
            .fit(&X.view(), &y.view())
            .unwrap();

        // Wrong number of features
        let X_test = array![[1.0]];
        assert!(trained.predict(&X_test.view()).is_err());
    }
}