1use crate::error::{Result, ScryLearnError};
30
/// A node of an isolation tree.
///
/// Internal nodes send a sample to `left` when `sample[feature] < threshold` and to
/// `right` otherwise. Leaves record how many training rows reached them, so that path
/// lengths can be corrected for rows the tree never separated.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum ITreeNode {
    /// Internal node that splits on a single feature.
    Split {
        /// Index of the feature compared at this node.
        feature: usize,
        /// Split value: samples with `sample[feature] < threshold` go left.
        threshold: f64,
        /// Subtree for samples strictly below the threshold.
        left: Box<ITreeNode>,
        /// Subtree for samples at or above the threshold.
        right: Box<ITreeNode>,
    },
    /// Terminal node.
    Leaf {
        /// Number of (subsampled) training rows that ended up in this leaf.
        size: usize,
    },
}
56
/// A single randomly grown isolation tree.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct IsolationTree {
    /// Root node. It is a lone `Leaf` when the training subsample could not be split at all.
    root: ITreeNode,
}
63
64impl IsolationTree {
65 fn build(data: &[Vec<f64>], max_depth: usize, rng: &mut crate::rng::FastRng) -> Self {
67 let root = Self::build_node(data, 0, max_depth, rng);
68 Self { root }
69 }
70
71 fn build_node(
72 data: &[Vec<f64>],
73 depth: usize,
74 max_depth: usize,
75 rng: &mut crate::rng::FastRng,
76 ) -> ITreeNode {
77 let n = data.len();
78 if n <= 1 || depth >= max_depth {
79 return ITreeNode::Leaf { size: n };
80 }
81
82 let n_features = data[0].len();
83 if n_features == 0 {
84 return ITreeNode::Leaf { size: n };
85 }
86
87 let feature = rng.usize(0..n_features);
89
90 let mut min_val = f64::INFINITY;
92 let mut max_val = f64::NEG_INFINITY;
93 for sample in data {
94 let v = sample[feature];
95 if v < min_val {
96 min_val = v;
97 }
98 if v > max_val {
99 max_val = v;
100 }
101 }
102
103 if (max_val - min_val).abs() < f64::EPSILON {
105 return ITreeNode::Leaf { size: n };
106 }
107
108 let threshold = min_val + rng.f64() * (max_val - min_val);
110
111 let mut left_data = Vec::new();
113 let mut right_data = Vec::new();
114 for sample in data {
115 if sample[feature] < threshold {
116 left_data.push(sample.clone());
117 } else {
118 right_data.push(sample.clone());
119 }
120 }
121
122 if left_data.is_empty() || right_data.is_empty() {
124 return ITreeNode::Leaf { size: n };
125 }
126
127 let left = Self::build_node(&left_data, depth + 1, max_depth, rng);
128 let right = Self::build_node(&right_data, depth + 1, max_depth, rng);
129
130 ITreeNode::Split {
131 feature,
132 threshold,
133 left: Box::new(left),
134 right: Box::new(right),
135 }
136 }
137
138 fn path_length(&self, sample: &[f64]) -> f64 {
140 Self::path_length_node(&self.root, sample, 0)
141 }
142
143 fn path_length_node(node: &ITreeNode, sample: &[f64], depth: usize) -> f64 {
144 match node {
145 ITreeNode::Leaf { size } => depth as f64 + average_path_length(*size),
146 ITreeNode::Split {
147 feature,
148 threshold,
149 left,
150 right,
151 } => {
152 if sample[*feature] < *threshold {
153 Self::path_length_node(left, sample, depth + 1)
154 } else {
155 Self::path_length_node(right, sample, depth + 1)
156 }
157 }
158 }
159 }
160}
161
/// Normalising constant `c(n)` for isolation-tree path lengths.
///
/// Computes `2 * H(n - 1) - 2 * (n - 1) / n`, with the harmonic number approximated as
/// `H(i) ≈ ln(i) + γ` (Euler–Mascheroni constant γ). Sizes 0 and 1 yield 0.
fn average_path_length(n: usize) -> f64 {
    const EULER_GAMMA: f64 = 0.577_215_664_9;

    match n {
        0 | 1 => 0.0,
        _ => {
            let size = n as f64;
            let predecessors = size - 1.0;
            let harmonic = predecessors.ln() + EULER_GAMMA;
            2.0 * harmonic - 2.0 * predecessors / size
        }
    }
}
176
/// Isolation Forest anomaly detector.
///
/// Configure with the builder methods, call `fit`, then use either of:
/// - `predict` for anomaly scores (higher means more anomalous);
/// - `predict_labels` for `-1` (anomaly) / `1` (normal) labels.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct IsolationForest {
    /// Number of trees grown by `fit`.
    n_estimators: usize,
    /// Requested per-tree subsample size; `fit` caps it at the number of training rows.
    max_samples: usize,
    /// Expected fraction of anomalies; `fit` uses it to place `threshold`.
    contamination: f64,
    /// Base seed for the per-tree RNGs. When `None`, `fit` falls back to a fixed seed.
    random_state: Option<u64>,
    /// Trees built by the last successful `fit`; empty before fitting.
    trees: Vec<IsolationTree>,
    /// Score at or above which `predict_labels` reports an anomaly.
    threshold: f64,
    /// Subsample size actually used by `fit` (0 until fitted); `predict` uses it as the
    /// path-length normaliser.
    training_sub_size: usize,
    /// Serialization schema version. `serde(default)` lets payloads without this field
    /// deserialize, with the value set to 0.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
227
228impl IsolationForest {
229 pub fn new() -> Self {
233 Self {
234 n_estimators: 100,
235 max_samples: 256,
236 contamination: 0.1,
237 random_state: None,
238 trees: Vec::new(),
239 threshold: 0.5,
240 training_sub_size: 0,
241 _schema_version: crate::version::SCHEMA_VERSION,
242 }
243 }
244
245 pub fn n_estimators(mut self, n: usize) -> Self {
247 self.n_estimators = n;
248 self
249 }
250
251 pub fn max_samples(mut self, n: usize) -> Self {
253 self.max_samples = n;
254 self
255 }
256
257 pub fn contamination(mut self, c: f64) -> Self {
261 self.contamination = c;
262 self
263 }
264
265 pub fn random_state(mut self, seed: u64) -> Self {
267 self.random_state = Some(seed);
268 self
269 }
270
271 pub fn seed(self, s: u64) -> Self {
273 self.random_state(s)
274 }
275
276 pub fn fit(&mut self, features: &[Vec<f64>]) -> Result<()> {
285 for (i, row) in features.iter().enumerate() {
286 for (j, &v) in row.iter().enumerate() {
287 if !v.is_finite() {
288 return Err(ScryLearnError::InvalidData(format!(
289 "non-finite value ({v}) in feature[{j}] at sample {i}"
290 )));
291 }
292 }
293 }
294 if features.is_empty() {
295 return Err(ScryLearnError::EmptyDataset);
296 }
297 if self.contamination <= 0.0 || self.contamination > 0.5 {
298 return Err(ScryLearnError::InvalidParameter(format!(
299 "contamination must be in (0, 0.5], got {}",
300 self.contamination
301 )));
302 }
303
304 let n = features.len();
305 let sub_size = self.max_samples.min(n);
306 let max_depth = (sub_size as f64).log2().ceil() as usize;
307 let seed = self.random_state.unwrap_or(42);
308
309 let mut trees = Vec::with_capacity(self.n_estimators);
310
311 for i in 0..self.n_estimators {
312 let mut rng = crate::rng::FastRng::new(seed.wrapping_add(i as u64));
313
314 let subsample: Vec<Vec<f64>> = (0..sub_size)
316 .map(|_| features[rng.usize(0..n)].clone())
317 .collect();
318
319 let tree = IsolationTree::build(&subsample, max_depth, &mut rng);
320 trees.push(tree);
321 }
322
323 self.trees = trees;
324 self.training_sub_size = sub_size;
325
326 let mut scores = self.predict(features);
328 scores.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
330
331 let cutoff_idx = ((self.contamination * n as f64).ceil() as usize)
333 .min(n)
334 .max(1);
335 self.threshold = scores[cutoff_idx - 1];
336
337 Ok(())
338 }
339
340 pub fn predict(&self, features: &[Vec<f64>]) -> Vec<f64> {
349 let n = features.len();
350 let sub_size = if self.training_sub_size > 0 {
353 self.training_sub_size
354 } else {
355 self.max_samples.min(n.max(1))
356 };
357 let c = average_path_length(sub_size);
358
359 if c.abs() < f64::EPSILON || self.trees.is_empty() {
360 return vec![0.5; n];
361 }
362
363 features
364 .iter()
365 .map(|sample| {
366 let avg_path: f64 = self
367 .trees
368 .iter()
369 .map(|t| t.path_length(sample))
370 .sum::<f64>()
371 / self.trees.len() as f64;
372 2.0_f64.powf(-avg_path / c)
373 })
374 .collect()
375 }
376
377 pub fn predict_labels(&self, features: &[Vec<f64>]) -> Vec<i8> {
382 if self.trees.is_empty() {
383 return vec![1; features.len()];
384 }
385 let scores = self.predict(features);
386 scores
387 .into_iter()
388 .map(|s| if s >= self.threshold { -1 } else { 1 })
389 .collect()
390 }
391
392 pub fn score_threshold(&self) -> f64 {
394 self.threshold
395 }
396}
397
398impl Default for IsolationForest {
399 fn default() -> Self {
400 Self::new()
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Seeded toy dataset of two-dimensional points.
    ///
    /// It contains `n_normal` inliers near the origin (each coordinate is
    /// `rng.f64() * 2 - 1`), followed by `n_outliers` far-away points (each coordinate is
    /// `10 + rng.f64() * 5`).
    fn make_test_data(n_normal: usize, n_outliers: usize, seed: u64) -> Vec<Vec<f64>> {
        let mut rng = crate::rng::FastRng::new(seed);
        let mut rows: Vec<Vec<f64>> = Vec::with_capacity(n_normal + n_outliers);

        rows.extend((0..n_normal).map(|_| {
            let x = rng.f64() * 2.0 - 1.0;
            let y = rng.f64() * 2.0 - 1.0;
            vec![x, y]
        }));
        rows.extend((0..n_outliers).map(|_| {
            let x = 10.0 + rng.f64() * 5.0;
            let y = 10.0 + rng.f64() * 5.0;
            vec![x, y]
        }));

        rows
    }

    /// Arithmetic mean of a non-empty slice.
    fn mean(values: &[f64]) -> f64 {
        values.iter().sum::<f64>() / values.len() as f64
    }

    #[test]
    fn test_iforest_detects_outliers() {
        let data = make_test_data(90, 10, 42);
        let mut model = IsolationForest::new()
            .n_estimators(100)
            .max_samples(64)
            .contamination(0.1)
            .random_state(42);
        model.fit(&data).unwrap();

        let scores = model.predict(&data);
        let (inliers, outliers) = scores.split_at(90);
        let normal_mean = mean(inliers);
        let outlier_mean = mean(outliers);

        assert!(
            outlier_mean > normal_mean,
            "outlier mean score ({:.3}) should be higher than normal mean ({:.3})",
            outlier_mean,
            normal_mean,
        );
    }

    #[test]
    fn test_iforest_labels_recall() {
        let data = make_test_data(90, 10, 123);
        let mut model = IsolationForest::new()
            .n_estimators(100)
            .max_samples(64)
            .contamination(0.1)
            .random_state(123);
        model.fit(&data).unwrap();

        let flagged = model
            .predict_labels(&data)
            .into_iter()
            .skip(90)
            .filter(|&label| label == -1)
            .count();
        let recall = flagged as f64 / 10.0;

        assert!(
            recall >= 0.7,
            "expected outlier recall ≥ 0.70, got {:.2}",
            recall,
        );
    }

    #[test]
    fn test_iforest_single_feature() {
        let mut data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64 * 0.1]).collect();
        data.push(vec![1000.0]);
        let outlier_idx = data.len() - 1;

        let mut model = IsolationForest::new()
            .n_estimators(50)
            .max_samples(64)
            .contamination(0.05)
            .random_state(7);
        model.fit(&data).unwrap();
        let scores = model.predict(&data);

        let (argmax, _) = scores
            .iter()
            .enumerate()
            .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap())
            .unwrap();

        assert_eq!(
            argmax, outlier_idx,
            "outlier should have highest anomaly score"
        );
    }

    #[test]
    fn test_iforest_multi_feature() {
        let mut rng = crate::rng::FastRng::new(99);
        let mut data: Vec<Vec<f64>> = (0..100)
            .map(|_| (0..4).map(|_| rng.f64() * 2.0).collect::<Vec<f64>>())
            .collect();
        data.extend((0..5).map(|_| vec![50.0; 4]));

        let mut model = IsolationForest::new()
            .n_estimators(80)
            .max_samples(64)
            .contamination(0.05)
            .random_state(99);
        model.fit(&data).unwrap();

        let labels = model.predict_labels(&data);
        let outlier_detected = labels[100..].iter().filter(|l| **l == -1).count();

        assert!(
            outlier_detected >= 3,
            "expected ≥ 3 of 5 outliers detected, got {}",
            outlier_detected,
        );
    }

    #[test]
    fn test_iforest_empty_input() {
        let empty: &[Vec<f64>] = &[];
        assert!(IsolationForest::new().fit(empty).is_err());
    }

    #[test]
    fn test_iforest_invalid_contamination() {
        let data = make_test_data(10, 0, 1);
        for bad in [0.0, 0.6] {
            let mut model = IsolationForest::new().contamination(bad);
            assert!(model.fit(&data).is_err());
        }
    }

    #[test]
    fn test_iforest_default() {
        let model = IsolationForest::default();
        assert_eq!(model.n_estimators, 100);
        assert_eq!(model.max_samples, 256);
        assert!((model.contamination - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_average_path_length() {
        for degenerate in [0, 1] {
            assert!((average_path_length(degenerate) - 0.0).abs() < f64::EPSILON);
        }
        let c2 = average_path_length(2);
        assert!((c2 - 0.1544).abs() < 0.01, "c(2) = {c2}");
        let c256 = average_path_length(256);
        assert!(c256 > 8.0 && c256 < 12.0, "c(256) = {c256}");
    }
}