1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
14use scirs2_core::random::{thread_rng, CoreRandom};
15use sklears_core::error::{Result, SklearsError};
16use sklears_core::traits::{Estimator, Fit, Predict, Trained, Untrained};
17use sklears_core::types::Float;
18use std::marker::PhantomData;
19
20#[derive(Debug, Clone)]
22pub struct IsolationNode {
23 pub feature: Option<usize>,
25 pub threshold: Float,
27 pub left: Option<Box<IsolationNode>>,
29 pub right: Option<Box<IsolationNode>>,
31 pub size: usize,
33}
34
35impl IsolationNode {
36 pub fn new_leaf(size: usize) -> Self {
38 Self {
39 feature: None,
40 threshold: 0.0,
41 left: None,
42 right: None,
43 size,
44 }
45 }
46
47 pub fn new_internal(
49 feature: usize,
50 threshold: Float,
51 left: Self,
52 right: Self,
53 size: usize,
54 ) -> Self {
55 Self {
56 feature: Some(feature),
57 threshold,
58 left: Some(Box::new(left)),
59 right: Some(Box::new(right)),
60 size,
61 }
62 }
63
64 pub fn path_length(&self, sample: &ArrayView1<Float>, current_depth: usize) -> Float {
66 if self.left.is_none() || self.right.is_none() {
67 current_depth as Float + self.avg_path_length_adjustment()
69 } else if let Some(feature_idx) = self.feature {
70 let value = sample[feature_idx];
71 if value < self.threshold {
72 self.left
73 .as_ref()
74 .unwrap()
75 .path_length(sample, current_depth + 1)
76 } else {
77 self.right
78 .as_ref()
79 .unwrap()
80 .path_length(sample, current_depth + 1)
81 }
82 } else {
83 current_depth as Float
84 }
85 }
86
87 fn avg_path_length_adjustment(&self) -> Float {
89 if self.size <= 1 {
90 0.0
91 } else {
92 2.0 * (((self.size - 1) as Float).ln() + 0.5772156649)
93 - 2.0 * (self.size - 1) as Float / self.size as Float
94 }
95 }
96}
97
98#[derive(Debug, Clone)]
100pub struct ExtendedIsolationNode {
101 pub normal: Option<Array1<Float>>,
103 pub intercept: Float,
105 pub left: Option<Box<ExtendedIsolationNode>>,
107 pub right: Option<Box<ExtendedIsolationNode>>,
109 pub size: usize,
111}
112
113impl ExtendedIsolationNode {
114 pub fn new_leaf(size: usize) -> Self {
116 Self {
117 normal: None,
118 intercept: 0.0,
119 left: None,
120 right: None,
121 size,
122 }
123 }
124
125 pub fn new_internal(
127 normal: Array1<Float>,
128 intercept: Float,
129 left: Self,
130 right: Self,
131 size: usize,
132 ) -> Self {
133 Self {
134 normal: Some(normal),
135 intercept,
136 left: Some(Box::new(left)),
137 right: Some(Box::new(right)),
138 size,
139 }
140 }
141
142 pub fn path_length(&self, sample: &ArrayView1<Float>, current_depth: usize) -> Float {
144 if self.left.is_none() || self.right.is_none() {
145 current_depth as Float + self.avg_path_length_adjustment()
146 } else if let Some(ref normal) = self.normal {
147 let projection = sample.dot(normal) - self.intercept;
148 if projection < 0.0 {
149 self.left
150 .as_ref()
151 .unwrap()
152 .path_length(sample, current_depth + 1)
153 } else {
154 self.right
155 .as_ref()
156 .unwrap()
157 .path_length(sample, current_depth + 1)
158 }
159 } else {
160 current_depth as Float
161 }
162 }
163
164 fn avg_path_length_adjustment(&self) -> Float {
166 if self.size <= 1 {
167 0.0
168 } else {
169 2.0 * (((self.size - 1) as Float).ln() + 0.5772156649)
170 - 2.0 * (self.size - 1) as Float / self.size as Float
171 }
172 }
173}
174
/// Hyper-parameters for [`IsolationForest`] (also passed to [`StreamingIsolationForest`]).
#[derive(Debug, Clone)]
pub struct IsolationForestConfig {
    /// Number of trees to grow.
    pub n_estimators: usize,
    /// Maximum tree depth. When `None`, `fit` uses `ceil(log2(max_samples))`,
    /// capped at 100.
    pub max_depth: Option<usize>,
    /// How many rows are drawn (without replacement) to grow each tree.
    pub max_samples: MaxSamples,
    /// Expected share of anomalies; `fit` uses it to derive the score threshold
    /// that `predict` compares against.
    pub contamination: Float,
    /// Requested seed. NOTE(review): currently ignored — trees are grown from
    /// `thread_rng()`, so results are not reproducible.
    pub random_state: Option<u64>,
    /// Grow extended (hyperplane-split) trees instead of axis-parallel ones.
    pub extended: bool,
    /// For extended trees: how many coordinates of each split normal receive a
    /// random weight (capped at the number of features). `None` means all features.
    pub extension_level: Option<usize>,
}

impl Default for IsolationForestConfig {
    /// 100 trees, automatic subsample size, 10% contamination, axis-parallel splits.
    fn default() -> Self {
        Self {
            n_estimators: 100,
            max_depth: None,
            max_samples: MaxSamples::Auto,
            contamination: 0.1,
            random_state: None,
            extended: false,
            extension_level: None,
        }
    }
}
207
/// Number of training rows used to grow each tree.
#[derive(Debug, Clone)]
pub enum MaxSamples {
    /// `min(n_samples, 256)`.
    Auto,
    /// A fixed count, capped at the number of training rows.
    Number(usize),
    /// A fraction of the training rows (truncated toward zero), capped at the
    /// number of training rows.
    Fraction(Float),
}
218
/// Isolation Forest anomaly detector.
///
/// `State` is a typestate marker: `IsolationForest<Untrained>` offers the
/// builder methods and `fit`; `IsolationForest<Trained>` offers scoring and
/// prediction.
pub struct IsolationForest<State = Untrained> {
    /// Hyper-parameters.
    config: IsolationForestConfig,
    /// Zero-sized typestate marker.
    state: PhantomData<State>,
    /// Axis-parallel trees; filled by `fit` when `config.extended` is false.
    trees: Vec<IsolationNode>,
    /// Hyperplane trees; filled by `fit` when `config.extended` is true.
    extended_trees: Vec<ExtendedIsolationNode>,
    /// Number of training columns, recorded by `fit`.
    n_features: Option<usize>,
    /// Score at or above which `predict` reports an anomaly; set by `fit`.
    threshold: Option<Float>,
    /// Offset stored by `fit` (-0.5). NOTE(review): not read by any code visible here.
    offset: Option<Float>,
}
229
230impl IsolationForest<Untrained> {
231 pub fn new() -> Self {
233 Self {
234 config: IsolationForestConfig::default(),
235 state: PhantomData,
236 trees: Vec::new(),
237 extended_trees: Vec::new(),
238 n_features: None,
239 threshold: None,
240 offset: None,
241 }
242 }
243
244 pub fn n_estimators(mut self, n_estimators: usize) -> Self {
246 self.config.n_estimators = n_estimators;
247 self
248 }
249
250 pub fn max_depth(mut self, max_depth: usize) -> Self {
252 self.config.max_depth = Some(max_depth);
253 self
254 }
255
256 pub fn max_samples(mut self, max_samples: MaxSamples) -> Self {
258 self.config.max_samples = max_samples;
259 self
260 }
261
262 pub fn contamination(mut self, contamination: Float) -> Self {
264 self.config.contamination = contamination;
265 self
266 }
267
268 pub fn random_state(mut self, seed: u64) -> Self {
270 self.config.random_state = Some(seed);
271 self
272 }
273
274 pub fn extended(mut self, extended: bool) -> Self {
276 self.config.extended = extended;
277 self
278 }
279
280 pub fn extension_level(mut self, level: usize) -> Self {
282 self.config.extension_level = Some(level);
283 self
284 }
285}
286
287impl Default for IsolationForest<Untrained> {
288 fn default() -> Self {
289 Self::new()
290 }
291}
292
293impl Estimator for IsolationForest<Untrained> {
294 type Config = IsolationForestConfig;
295 type Error = SklearsError;
296 type Float = Float;
297
298 fn config(&self) -> &Self::Config {
299 &self.config
300 }
301}
302
303impl Fit<Array2<Float>, Array1<Float>> for IsolationForest<Untrained> {
304 type Fitted = IsolationForest<Trained>;
305
306 fn fit(mut self, x: &Array2<Float>, _y: &Array1<Float>) -> Result<Self::Fitted> {
307 let n_samples = x.nrows();
308 let n_features = x.ncols();
309
310 if n_samples == 0 {
311 return Err(SklearsError::InvalidInput(
312 "No samples provided".to_string(),
313 ));
314 }
315
316 let max_samples = match self.config.max_samples {
318 MaxSamples::Auto => n_samples.min(256),
319 MaxSamples::Number(n) => n.min(n_samples),
320 MaxSamples::Fraction(f) => ((n_samples as Float * f) as usize).min(n_samples),
321 };
322
323 let max_depth = self
325 .config
326 .max_depth
327 .unwrap_or_else(|| ((max_samples as Float).log2().ceil() as usize).min(100));
328
329 let mut rng = thread_rng();
331 let _ = self.config.random_state; if self.config.extended {
335 let extension_level = self.config.extension_level.unwrap_or(n_features);
337 for _ in 0..self.config.n_estimators {
338 let indices = sample_indices(n_samples, max_samples, &mut rng);
339 let tree =
340 build_extended_tree(x, &indices, 0, max_depth, extension_level, &mut rng)?;
341 self.extended_trees.push(tree);
342 }
343 } else {
344 for _ in 0..self.config.n_estimators {
346 let indices = sample_indices(n_samples, max_samples, &mut rng);
347 let tree = build_isolation_tree(x, &indices, 0, max_depth, &mut rng)?;
348 self.trees.push(tree);
349 }
350 }
351
352 let avg_path_length = average_path_length(max_samples);
354 let offset = -0.5;
355 let threshold = 2.0_f64.powf(-self.config.contamination / avg_path_length);
356
357 Ok(IsolationForest::<Trained> {
358 config: self.config,
359 state: PhantomData,
360 trees: self.trees,
361 extended_trees: self.extended_trees,
362 n_features: Some(n_features),
363 threshold: Some(threshold as Float),
364 offset: Some(offset as Float),
365 })
366 }
367}
368
369impl IsolationForest<Trained> {
370 pub fn decision_function(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
372 let n_samples = x.nrows();
373 let mut scores = Array1::zeros(n_samples);
374
375 for i in 0..n_samples {
376 let sample = x.row(i);
377 let avg_path = if self.config.extended {
378 self.average_path_length_extended(&sample)
379 } else {
380 self.average_path_length_standard(&sample)
381 };
382
383 let max_samples = match self.config.max_samples {
384 MaxSamples::Auto => 256,
385 MaxSamples::Number(n) => n,
386 MaxSamples::Fraction(_) => 256, };
388
389 let c = average_path_length(max_samples);
390 scores[i] = 2.0_f64.powf(-avg_path / c) as Float;
391 }
392
393 Ok(scores)
394 }
395
396 fn average_path_length_standard(&self, sample: &ArrayView1<Float>) -> Float {
398 if self.trees.is_empty() {
399 return 0.0;
400 }
401
402 let sum: Float = self
403 .trees
404 .iter()
405 .map(|tree| tree.path_length(sample, 0))
406 .sum();
407
408 sum / self.trees.len() as Float
409 }
410
411 fn average_path_length_extended(&self, sample: &ArrayView1<Float>) -> Float {
413 if self.extended_trees.is_empty() {
414 return 0.0;
415 }
416
417 let sum: Float = self
418 .extended_trees
419 .iter()
420 .map(|tree| tree.path_length(sample, 0))
421 .sum();
422
423 sum / self.extended_trees.len() as Float
424 }
425
426 pub fn threshold(&self) -> Float {
428 self.threshold.unwrap_or(0.5)
429 }
430}
431
432impl Predict<Array2<Float>, Array1<i32>> for IsolationForest<Trained> {
433 fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
434 let scores = self.decision_function(x)?;
435 let threshold = self.threshold();
436
437 let predictions = scores.mapv(|score| {
438 if score >= threshold {
439 -1 } else {
441 1 }
443 });
444
445 Ok(predictions)
446 }
447}
448
449fn sample_indices(n_samples: usize, n_to_sample: usize, rng: &mut CoreRandom) -> Vec<usize> {
451 use scirs2_core::random::essentials::Uniform;
452
453 let mut indices: Vec<usize> = (0..n_samples).collect();
454
455 for i in 0..n_to_sample.min(n_samples) {
457 let uniform = Uniform::new(i, n_samples).unwrap();
458 let j = rng.sample(uniform);
459 indices.swap(i, j);
460 }
461
462 indices.truncate(n_to_sample);
463 indices
464}
465
466fn build_isolation_tree(
468 x: &Array2<Float>,
469 indices: &[usize],
470 depth: usize,
471 max_depth: usize,
472 rng: &mut CoreRandom,
473) -> Result<IsolationNode> {
474 use scirs2_core::random::essentials::Uniform;
475
476 let n_samples = indices.len();
477 let n_features = x.ncols();
478
479 if n_samples <= 1 || depth >= max_depth {
481 return Ok(IsolationNode::new_leaf(n_samples));
482 }
483
484 let first_sample = x.row(indices[0]);
486 let all_identical = indices.iter().skip(1).all(|&idx| {
487 let sample = x.row(idx);
488 sample
489 .iter()
490 .zip(first_sample.iter())
491 .all(|(a, b)| (a - b).abs() < 1e-10)
492 });
493
494 if all_identical {
495 return Ok(IsolationNode::new_leaf(n_samples));
496 }
497
498 let feature_dist = Uniform::new(0, n_features).map_err(|_| {
500 SklearsError::InvalidInput("Failed to create uniform distribution".to_string())
501 })?;
502 let feature = rng.sample(feature_dist);
503
504 let mut min_val = Float::INFINITY;
506 let mut max_val = Float::NEG_INFINITY;
507
508 for &idx in indices {
509 let val = x[[idx, feature]];
510 min_val = min_val.min(val);
511 max_val = max_val.max(val);
512 }
513
514 if (max_val - min_val).abs() < 1e-10 {
515 return Ok(IsolationNode::new_leaf(n_samples));
516 }
517
518 let split_dist = Uniform::new(min_val, max_val).map_err(|_| {
520 SklearsError::InvalidInput("Failed to create split distribution".to_string())
521 })?;
522 let threshold = rng.sample(split_dist);
523
524 let mut left_indices = Vec::new();
526 let mut right_indices = Vec::new();
527
528 for &idx in indices {
529 if x[[idx, feature]] < threshold {
530 left_indices.push(idx);
531 } else {
532 right_indices.push(idx);
533 }
534 }
535
536 if left_indices.is_empty() || right_indices.is_empty() {
538 return Ok(IsolationNode::new_leaf(n_samples));
539 }
540
541 let left_tree = build_isolation_tree(x, &left_indices, depth + 1, max_depth, rng)?;
543 let right_tree = build_isolation_tree(x, &right_indices, depth + 1, max_depth, rng)?;
544
545 Ok(IsolationNode::new_internal(
546 feature, threshold, left_tree, right_tree, n_samples,
547 ))
548}
549
550fn build_extended_tree(
552 x: &Array2<Float>,
553 indices: &[usize],
554 depth: usize,
555 max_depth: usize,
556 extension_level: usize,
557 rng: &mut CoreRandom,
558) -> Result<ExtendedIsolationNode> {
559 use scirs2_core::random::essentials::Uniform;
560
561 let n_samples = indices.len();
562 let n_features = x.ncols();
563
564 if n_samples <= 1 || depth >= max_depth {
566 return Ok(ExtendedIsolationNode::new_leaf(n_samples));
567 }
568
569 let n_dims = extension_level.min(n_features);
571 let mut normal = Array1::zeros(n_features);
572
573 for i in 0..n_dims {
574 let normal_dist = Uniform::new(-1.0, 1.0).map_err(|_| {
575 SklearsError::InvalidInput("Failed to create normal distribution".to_string())
576 })?;
577 normal[i] = rng.sample(normal_dist);
578 }
579
580 let dot_product: Float = normal.dot(&normal);
582 let norm = dot_product.sqrt();
583 if norm > 1e-10 {
584 normal /= norm;
585 } else {
586 return Ok(ExtendedIsolationNode::new_leaf(n_samples));
587 }
588
589 let mut projections = Vec::with_capacity(n_samples);
591 for &idx in indices {
592 let sample = x.row(idx);
593 let projection = sample.dot(&normal);
594 projections.push(projection);
595 }
596
597 let min_proj = projections
599 .iter()
600 .cloned()
601 .fold(Float::INFINITY, Float::min);
602 let max_proj = projections
603 .iter()
604 .cloned()
605 .fold(Float::NEG_INFINITY, Float::max);
606
607 if (max_proj - min_proj).abs() < 1e-10 {
608 return Ok(ExtendedIsolationNode::new_leaf(n_samples));
609 }
610
611 let intercept_dist = Uniform::new(min_proj, max_proj).map_err(|_| {
613 SklearsError::InvalidInput("Failed to create intercept distribution".to_string())
614 })?;
615 let intercept = rng.sample(intercept_dist);
616
617 let mut left_indices = Vec::new();
619 let mut right_indices = Vec::new();
620
621 for (&idx, &proj) in indices.iter().zip(projections.iter()) {
622 if proj < intercept {
623 left_indices.push(idx);
624 } else {
625 right_indices.push(idx);
626 }
627 }
628
629 if left_indices.is_empty() || right_indices.is_empty() {
630 return Ok(ExtendedIsolationNode::new_leaf(n_samples));
631 }
632
633 let left_tree =
635 build_extended_tree(x, &left_indices, depth + 1, max_depth, extension_level, rng)?;
636 let right_tree = build_extended_tree(
637 x,
638 &right_indices,
639 depth + 1,
640 max_depth,
641 extension_level,
642 rng,
643 )?;
644
645 Ok(ExtendedIsolationNode::new_internal(
646 normal, intercept, left_tree, right_tree, n_samples,
647 ))
648}
649
650fn average_path_length(n: usize) -> Float {
652 if n <= 1 {
653 0.0
654 } else {
655 2.0 * (((n - 1) as Float).ln() + 0.5772156649) - 2.0 * (n - 1) as Float / n as Float
656 }
657}
658
/// Isolation Forest variant for data streams. It keeps a sliding window of
/// recent samples and periodically regrows axis-parallel trees on that window.
pub struct StreamingIsolationForest {
    /// Hyper-parameters used when regrowing trees.
    config: IsolationForestConfig,
    /// Current forest; empty until the first rebuild.
    trees: Vec<IsolationNode>,
    /// Maximum number of samples kept in `buffer`.
    window_size: usize,
    /// Most recent samples, oldest first.
    buffer: Vec<Array1<Float>>,
    /// Trees are regrown when `samples_seen` is a multiple of this value and the
    /// window holds at least 32 samples.
    update_frequency: usize,
    /// Number of samples processed so far.
    samples_seen: usize,
}
668
669impl StreamingIsolationForest {
670 pub fn new(config: IsolationForestConfig, window_size: usize, update_frequency: usize) -> Self {
672 Self {
673 config,
674 trees: Vec::new(),
675 window_size,
676 buffer: Vec::new(),
677 update_frequency,
678 samples_seen: 0,
679 }
680 }
681
682 pub fn process_sample(&mut self, sample: Array1<Float>) -> Result<Float> {
684 self.samples_seen += 1;
685
686 self.buffer.push(sample.clone());
688 if self.buffer.len() > self.window_size {
689 self.buffer.remove(0);
690 }
691
692 if self.samples_seen % self.update_frequency == 0 && self.buffer.len() >= 32 {
694 self.rebuild_trees()?;
695 }
696
697 if self.trees.is_empty() {
699 return Ok(0.5); }
701
702 let sample_view = sample.view();
703 let avg_path: Float = self
704 .trees
705 .iter()
706 .map(|tree| tree.path_length(&sample_view, 0))
707 .sum::<Float>()
708 / self.trees.len() as Float;
709
710 let max_samples = self.buffer.len().min(256);
711 let c = average_path_length(max_samples);
712 let score = 2.0_f64.powf(-avg_path as f64 / c) as Float;
713
714 Ok(score)
715 }
716
717 fn rebuild_trees(&mut self) -> Result<()> {
719 let n_samples = self.buffer.len();
720 if n_samples < 2 {
721 return Ok(());
722 }
723
724 let n_features = self.buffer[0].len();
726 let mut x = Array2::zeros((n_samples, n_features));
727 for (i, sample) in self.buffer.iter().enumerate() {
728 x.row_mut(i).assign(sample);
729 }
730
731 let max_samples = n_samples.min(256);
732 let max_depth = ((max_samples as Float).log2().ceil() as usize).min(100);
733
734 let mut rng = thread_rng();
736 let _ = self.config.random_state; self.trees.clear();
740 for _ in 0..self.config.n_estimators {
741 let indices = sample_indices(n_samples, max_samples, &mut rng);
742 let tree = build_isolation_tree(&x, &indices, 0, max_depth, &mut rng)?;
743 self.trees.push(tree);
744 }
745
746 Ok(())
747 }
748}
749
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::s;

    /// 90 points on a noisy ring around the origin plus 10 far-away diagonal points.
    #[test]
    fn test_isolation_forest_basic() {
        let x: Array2<Float> = Array2::from_shape_fn((100, 2), |(row, col)| {
            if row >= 90 {
                return 20.0 + ((row - 90) as Float);
            }
            let angle = (row as Float) * 2.0 * std::f64::consts::PI / 90.0;
            let radius = 0.5 + ((row % 10) as Float) * 0.05;
            if col == 0 {
                radius * angle.cos()
            } else {
                radius * angle.sin()
            }
        });
        let y = Array1::<Float>::zeros(100);

        let fitted = IsolationForest::new()
            .n_estimators(100)
            .contamination(0.1)
            .random_state(42)
            .fit(&x, &y)
            .unwrap();
        let scores = fitted.decision_function(&x).unwrap();

        let inlier_avg: Float = scores.slice(s![..90]).mean().unwrap();
        let outlier_avg: Float = scores.slice(s![90..]).mean().unwrap();

        assert!(
            outlier_avg > inlier_avg,
            "Outliers should have higher average anomaly scores: outlier_avg={}, inlier_avg={}",
            outlier_avg,
            inlier_avg
        );
    }

    /// 45 points on a short diagonal segment plus 5 widely spaced diagonal points.
    #[test]
    fn test_extended_isolation_forest() {
        let x: Array2<Float> = Array2::from_shape_fn((50, 2), |(row, _)| {
            if row < 45 {
                (row as Float / 22.5) - 1.0
            } else {
                ((row - 45) as Float) * 5.0
            }
        });
        let y = Array1::<Float>::zeros(50);

        let fitted = IsolationForest::new()
            .n_estimators(30)
            .extended(true)
            .extension_level(2)
            .random_state(42)
            .fit(&x, &y)
            .unwrap();
        let predictions = fitted.predict(&x).unwrap();

        let inliers = predictions
            .slice(s![..45])
            .iter()
            .filter(|&&label| label == 1)
            .count() as i32;
        assert!(inliers > 40, "Most normal samples should be inliers");
    }

    #[test]
    fn test_streaming_isolation_forest() {
        let config = IsolationForestConfig {
            n_estimators: 20,
            ..Default::default()
        };
        let mut streaming_if = StreamingIsolationForest::new(config, 100, 50);

        for step in 0..60 {
            let coord = (step as Float / 30.0) - 1.0;
            let _score = streaming_if
                .process_sample(Array1::from_vec(vec![coord, coord]))
                .unwrap();
        }

        let outlier_score = streaming_if
            .process_sample(Array1::from_vec(vec![10.0, 10.0]))
            .unwrap();
        let normal_score = streaming_if
            .process_sample(Array1::from_vec(vec![0.1, 0.1]))
            .unwrap();

        if !streaming_if.trees.is_empty() {
            assert!(
                outlier_score > normal_score,
                "Outlier should have higher score than normal sample"
            );
        }
    }

    #[test]
    fn test_average_path_length() {
        let c_256 = average_path_length(256);
        let c_100 = average_path_length(100);

        assert_relative_eq!(c_256, 10.24, epsilon = 0.1);
        assert_relative_eq!(c_100, 8.36, epsilon = 0.1);
        assert!(
            c_256 > c_100,
            "Average path length should increase with sample size"
        );

        for n in [10, 2] {
            assert!(average_path_length(n) > 0.0);
        }
    }

    #[test]
    fn test_isolation_node_path_length() {
        let root = IsolationNode::new_internal(
            0,
            0.5,
            IsolationNode::new_leaf(5),
            IsolationNode::new_leaf(5),
            10,
        );

        let goes_left = Array1::from_vec(vec![0.0, 0.0]);
        let path_left = root.path_length(&goes_left.view(), 0);
        assert!(path_left > 0.0, "Path length should be positive");

        let goes_right = Array1::from_vec(vec![1.0, 1.0]);
        let path_right = root.path_length(&goes_right.view(), 0);
        assert!(path_right > 0.0, "Path length should be positive");
    }
}