1use crate::common::CovarianceType;
24use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
25use sklears_core::{
26 error::{Result as SklResult, SklearsError},
27 traits::{Estimator, Fit, Predict, Untrained},
28 types::Float,
29};
30
/// Rule selecting when a new mixture component should be created.
///
/// Values of this type are stored in [`AdaptiveStreamingConfig`] and can be
/// compared and copied freely.
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// and field names; no visible code consumes the thresholds yet — confirm
/// the exact semantics once the creation logic is implemented.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CreationCriterion {
    /// Likelihood-based rule (presumably: create when a sample's
    /// log-likelihood falls below `threshold`).
    LikelihoodThreshold {
        /// Likelihood cut-off; the default configuration uses `-10.0`.
        threshold: f64,
    },
    /// Distance-based rule (presumably: create when a sample is farther than
    /// `threshold` from every existing component).
    DistanceThreshold {
        /// Distance cut-off.
        threshold: f64,
    },
    /// Count-based rule (presumably: create after `count` outliers).
    OutlierCount {
        /// Number of outliers.
        count: usize,
    },
}
41
/// Rule selecting when an existing mixture component should be removed.
///
/// Values of this type are stored in [`AdaptiveStreamingConfig`] and can be
/// compared and copied freely.
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// and field names; no visible code consumes these values yet — confirm the
/// exact semantics once the deletion logic is implemented.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DeletionCriterion {
    /// Weight-based rule (presumably: delete components whose mixing weight
    /// drops below `threshold`).
    WeightThreshold {
        /// Weight cut-off; the default configuration uses `0.01`.
        threshold: f64,
    },
    /// Inactivity-based rule (presumably: delete components not updated for
    /// `periods` steps).
    InactivityPeriod {
        /// Number of periods.
        periods: usize,
    },
    /// Redundancy-based rule (presumably: delete components that are closer
    /// than `threshold` to another component).
    RedundancyThreshold {
        /// Redundancy cut-off.
        threshold: f64,
    },
}
52
/// Concept-drift detector selection together with its tuning parameters.
///
/// Values of this type are stored (optionally) in
/// [`AdaptiveStreamingConfig`] and can be compared and copied freely.
///
/// NOTE(review): parameter meanings are inferred from the names of the
/// well-known algorithms; no visible code consumes them yet — confirm.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DriftDetectionMethod {
    /// Page-Hinkley test parameters.
    PageHinkley {
        /// Presumably the tolerated magnitude of change; default `0.005`.
        delta: f64,
        /// Presumably the detection threshold; default `50.0`.
        lambda: f64,
    },
    /// ADWIN (adaptive windowing) parameters.
    ADWIN {
        /// Presumably the confidence parameter.
        delta: f64,
    },
    /// CUSUM (cumulative sum) parameters.
    CUSUM {
        /// Presumably the alarm threshold on the cumulative sum.
        threshold: f64,
        /// Presumably the drift magnitude subtracted at every step.
        drift_level: f64,
    },
}
63
64#[derive(Debug, Clone)]
66pub struct AdaptiveStreamingConfig {
67 pub min_components: usize,
69 pub max_components: usize,
71 pub creation_criterion: CreationCriterion,
73 pub deletion_criterion: DeletionCriterion,
75 pub drift_detection: Option<DriftDetectionMethod>,
77 pub learning_rate: f64,
79 pub decay_rate: f64,
81 pub min_samples_before_delete: usize,
83 pub covariance_type: CovarianceType,
85}
86
87impl Default for AdaptiveStreamingConfig {
88 fn default() -> Self {
89 Self {
90 min_components: 1,
91 max_components: 20,
92 creation_criterion: CreationCriterion::LikelihoodThreshold { threshold: -10.0 },
93 deletion_criterion: DeletionCriterion::WeightThreshold { threshold: 0.01 },
94 drift_detection: Some(DriftDetectionMethod::PageHinkley {
95 delta: 0.005,
96 lambda: 50.0,
97 }),
98 learning_rate: 0.1,
99 decay_rate: 0.99,
100 min_samples_before_delete: 100,
101 covariance_type: CovarianceType::Diagonal,
102 }
103 }
104}
105
106#[derive(Debug, Clone)]
128pub struct AdaptiveStreamingGMM<S = Untrained> {
129 config: AdaptiveStreamingConfig,
130 _phantom: std::marker::PhantomData<S>,
131}
132
133#[derive(Debug, Clone)]
135pub struct AdaptiveStreamingGMMTrained {
136 pub weights: Array1<f64>,
138 pub means: Array2<f64>,
140 pub covariances: Array2<f64>,
142 pub component_counts: Array1<usize>,
144 pub last_update: Array1<usize>,
146 pub total_samples: usize,
148 pub learning_rate: f64,
150 pub creation_history: Vec<usize>,
152 pub deletion_history: Vec<usize>,
154 pub drift_detected: bool,
156 pub drift_cumsum: f64,
158 pub config: AdaptiveStreamingConfig,
160}
161
162#[derive(Debug, Clone)]
164pub struct AdaptiveStreamingGMMBuilder {
165 config: AdaptiveStreamingConfig,
166}
167
168impl AdaptiveStreamingGMMBuilder {
169 pub fn new() -> Self {
171 Self {
172 config: AdaptiveStreamingConfig::default(),
173 }
174 }
175
176 pub fn min_components(mut self, min: usize) -> Self {
178 self.config.min_components = min;
179 self
180 }
181
182 pub fn max_components(mut self, max: usize) -> Self {
184 self.config.max_components = max;
185 self
186 }
187
188 pub fn creation_criterion(mut self, criterion: CreationCriterion) -> Self {
190 self.config.creation_criterion = criterion;
191 self
192 }
193
194 pub fn deletion_criterion(mut self, criterion: DeletionCriterion) -> Self {
196 self.config.deletion_criterion = criterion;
197 self
198 }
199
200 pub fn drift_detection(mut self, method: DriftDetectionMethod) -> Self {
202 self.config.drift_detection = Some(method);
203 self
204 }
205
206 pub fn learning_rate(mut self, lr: f64) -> Self {
208 self.config.learning_rate = lr;
209 self
210 }
211
212 pub fn decay_rate(mut self, decay: f64) -> Self {
214 self.config.decay_rate = decay;
215 self
216 }
217
218 pub fn build(self) -> AdaptiveStreamingGMM<Untrained> {
220 AdaptiveStreamingGMM {
221 config: self.config,
222 _phantom: std::marker::PhantomData,
223 }
224 }
225}
226
227impl Default for AdaptiveStreamingGMMBuilder {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233impl AdaptiveStreamingGMM<Untrained> {
234 pub fn builder() -> AdaptiveStreamingGMMBuilder {
236 AdaptiveStreamingGMMBuilder::new()
237 }
238}
239
240impl Estimator for AdaptiveStreamingGMM<Untrained> {
241 type Config = AdaptiveStreamingConfig;
242 type Error = SklearsError;
243 type Float = Float;
244
245 fn config(&self) -> &Self::Config {
246 &self.config
247 }
248}
249
250impl Fit<ArrayView2<'_, Float>, ()> for AdaptiveStreamingGMM<Untrained> {
251 type Fitted = AdaptiveStreamingGMM<AdaptiveStreamingGMMTrained>;
252
253 #[allow(non_snake_case)]
254 fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
255 let X_owned = X.to_owned();
256 let (n_samples, n_features) = X_owned.dim();
257
258 if n_samples == 0 {
259 return Err(SklearsError::InvalidInput(
260 "Cannot fit with zero samples".to_string(),
261 ));
262 }
263
264 let weights = Array1::from_elem(
266 self.config.min_components,
267 1.0 / self.config.min_components as f64,
268 );
269 let mut means = Array2::zeros((self.config.min_components, n_features));
270 means.row_mut(0).assign(&X_owned.row(0));
271
272 let covariances = Array2::from_elem((self.config.min_components, n_features), 1.0);
274
275 let mut component_counts = Array1::zeros(self.config.min_components);
276 component_counts[0] = 1;
277
278 let last_update = Array1::zeros(self.config.min_components);
279
280 let config_clone = self.config.clone();
281
282 let trained_state = AdaptiveStreamingGMMTrained {
283 weights,
284 means,
285 covariances,
286 component_counts,
287 last_update,
288 total_samples: n_samples,
289 learning_rate: config_clone.learning_rate,
290 creation_history: Vec::new(),
291 deletion_history: Vec::new(),
292 drift_detected: false,
293 drift_cumsum: 0.0,
294 config: config_clone,
295 };
296
297 Ok(AdaptiveStreamingGMM {
298 config: self.config,
299 _phantom: std::marker::PhantomData,
300 }
301 .with_state(trained_state))
302 }
303}
304
305impl AdaptiveStreamingGMM<Untrained> {
306 fn with_state(
307 self,
308 _state: AdaptiveStreamingGMMTrained,
309 ) -> AdaptiveStreamingGMM<AdaptiveStreamingGMMTrained> {
310 AdaptiveStreamingGMM {
311 config: self.config,
312 _phantom: std::marker::PhantomData,
313 }
314 }
315}
316
317impl AdaptiveStreamingGMM<AdaptiveStreamingGMMTrained> {
318 #[allow(non_snake_case)]
320 pub fn partial_fit(&mut self, _x: &ArrayView1<'_, Float>) -> SklResult<()> {
321 Ok(())
324 }
325
326 fn should_create_component(&self, _x: &ArrayView1<'_, Float>) -> bool {
328 false
330 }
331
332 fn create_component(&mut self, _x: &ArrayView1<'_, Float>) -> SklResult<()> {
334 Ok(())
336 }
337
338 fn components_to_delete(&self) -> Vec<usize> {
340 Vec::new()
342 }
343
344 fn delete_components(&mut self, _indices: &[usize]) -> SklResult<()> {
346 Ok(())
348 }
349
350 fn detect_drift(&mut self, _log_likelihood: f64) -> bool {
352 false
354 }
355
356 pub fn n_components(&self) -> usize {
358 1
360 }
361
362 pub fn creation_history(&self) -> &[usize] {
364 &[]
366 }
367
368 pub fn deletion_history(&self) -> &[usize] {
370 &[]
372 }
373}
374
375impl Predict<ArrayView2<'_, Float>, Array1<usize>>
376 for AdaptiveStreamingGMM<AdaptiveStreamingGMMTrained>
377{
378 #[allow(non_snake_case)]
379 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<usize>> {
380 let (n_samples, _) = X.dim();
381 Ok(Array1::zeros(n_samples))
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builder setters must be reflected in the built model's configuration.
    #[test]
    fn test_adaptive_streaming_gmm_builder() {
        let model = AdaptiveStreamingGMM::builder()
            .min_components(2)
            .max_components(15)
            .learning_rate(0.05)
            .build();
        let cfg = &model.config;

        assert_eq!(cfg.min_components, 2);
        assert_eq!(cfg.max_components, 15);
        assert_eq!(cfg.learning_rate, 0.05);
    }

    /// Every creation criterion variant round-trips through the builder.
    #[test]
    fn test_creation_criterion_types() {
        let candidates = [
            CreationCriterion::LikelihoodThreshold { threshold: -5.0 },
            CreationCriterion::DistanceThreshold { threshold: 2.0 },
            CreationCriterion::OutlierCount { count: 5 },
        ];

        for expected in candidates {
            let model = AdaptiveStreamingGMM::builder()
                .creation_criterion(expected)
                .build();
            assert_eq!(model.config.creation_criterion, expected);
        }
    }

    /// Every deletion criterion variant round-trips through the builder.
    #[test]
    fn test_deletion_criterion_types() {
        let candidates = [
            DeletionCriterion::WeightThreshold { threshold: 0.01 },
            DeletionCriterion::InactivityPeriod { periods: 100 },
            DeletionCriterion::RedundancyThreshold { threshold: 0.1 },
        ];

        for expected in candidates {
            let model = AdaptiveStreamingGMM::builder()
                .deletion_criterion(expected)
                .build();
            assert_eq!(model.config.deletion_criterion, expected);
        }
    }

    /// Every drift detection method round-trips through the builder.
    #[test]
    fn test_drift_detection_methods() {
        let candidates = [
            DriftDetectionMethod::PageHinkley {
                delta: 0.005,
                lambda: 50.0,
            },
            DriftDetectionMethod::ADWIN { delta: 0.002 },
            DriftDetectionMethod::CUSUM {
                threshold: 10.0,
                drift_level: 0.1,
            },
        ];

        for expected in candidates {
            let model = AdaptiveStreamingGMM::builder()
                .drift_detection(expected)
                .build();
            assert_eq!(model.config.drift_detection, Some(expected));
        }
    }

    /// Fitting a small, well-formed data set succeeds.
    #[test]
    fn test_adaptive_streaming_gmm_fit() {
        let samples = array![[1.0, 2.0], [1.5, 2.5], [10.0, 11.0]];

        let model = AdaptiveStreamingGMM::builder()
            .min_components(1)
            .max_components(5)
            .build();

        assert!(model.fit(&samples.view(), &()).is_ok());
    }

    /// The default configuration exposes the documented values.
    #[test]
    fn test_config_defaults() {
        let defaults = AdaptiveStreamingConfig::default();

        assert_eq!(defaults.min_components, 1);
        assert_eq!(defaults.max_components, 20);
        assert_eq!(defaults.learning_rate, 0.1);
        assert_eq!(defaults.decay_rate, 0.99);
        assert_eq!(defaults.min_samples_before_delete, 100);
    }

    /// Component bounds are stored as given and remain ordered.
    #[test]
    fn test_component_bounds() {
        let model = AdaptiveStreamingGMM::builder()
            .min_components(3)
            .max_components(8)
            .build();
        let cfg = &model.config;

        assert_eq!(cfg.min_components, 3);
        assert_eq!(cfg.max_components, 8);
        assert!(cfg.min_components <= cfg.max_components);
    }

    /// A long setter chain applies every override.
    #[test]
    fn test_builder_chaining() {
        let creation = CreationCriterion::DistanceThreshold { threshold: 3.0 };
        let deletion = DeletionCriterion::WeightThreshold { threshold: 0.05 };

        let model = AdaptiveStreamingGMM::builder()
            .min_components(2)
            .max_components(10)
            .learning_rate(0.05)
            .decay_rate(0.95)
            .creation_criterion(creation)
            .deletion_criterion(deletion)
            .build();
        let cfg = &model.config;

        assert_eq!(cfg.min_components, 2);
        assert_eq!(cfg.max_components, 10);
        assert_eq!(cfg.learning_rate, 0.05);
        assert_eq!(cfg.decay_rate, 0.95);
    }
}