use crate::config::DecisionTreeConfig;
use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result, SklearsError};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// A candidate split: feature index, threshold, the impurity decrease it
/// achieves, and the resulting child sample counts.
#[derive(Debug, Clone)]
pub struct CustomSplit {
    pub feature_idx: usize,
    pub threshold: f64,
    pub impurity_decrease: f64,
    pub left_count: usize,
    pub right_count: usize,
}

/// A surrogate split that approximates a primary split, with the agreement
/// rate and the direction samples take when it is applied.
#[derive(Debug, Clone)]
pub struct SurrogateSplit {
    pub feature_idx: usize,
    pub threshold: f64,
    pub agreement: f64,
    pub left_direction: bool,
}

/// Memory-efficient tree node. Split metadata is bit-packed into a single
/// `u64`: bit 63 is the leaf flag, bits 0-15 hold the feature index, and
/// bits 16-47 hold the `f32` bit pattern of the threshold. `value` stores
/// the prediction for leaves and the impurity for split nodes.
#[derive(Debug, Clone)]
pub struct CompactTreeNode {
    pub packed_data: u64,
    pub value: f32,
    pub left_child: u32,
    pub right_child: u32,
}

impl CompactTreeNode {
    pub fn new_leaf(prediction: f64) -> Self {
        Self {
            // Set the leaf flag (bit 63); all other packed fields stay zero.
            packed_data: 0x8000_0000_0000_0000,
            value: prediction as f32,
            left_child: 0,
            right_child: 0,
        }
    }

    pub fn new_split(feature_idx: u16, threshold: f32, impurity: f64) -> Self {
        let mut packed_data = 0u64;

        // Bits 0-15: feature index.
        packed_data |= feature_idx as u64;

        // Bits 16-47: bit pattern of the f32 threshold.
        let threshold_bits = threshold.to_bits() as u64;
        packed_data |= threshold_bits << 16;

        Self {
            packed_data,
            value: impurity as f32,
            left_child: 0,
            right_child: 0,
        }
    }

    pub fn is_leaf(&self) -> bool {
        (self.packed_data & 0x8000_0000_0000_0000) != 0
    }

    pub fn feature_idx(&self) -> u16 {
        (self.packed_data & 0xFFFF) as u16
    }

    pub fn threshold(&self) -> f32 {
        let threshold_bits = ((self.packed_data >> 16) & 0xFFFF_FFFF) as u32;
        f32::from_bits(threshold_bits)
    }

    pub fn prediction(&self) -> f64 {
        self.value as f64
    }

    pub fn impurity(&self) -> f64 {
        self.value as f64
    }

    pub fn set_left_child(&mut self, child_idx: u32) {
        self.left_child = child_idx;
    }

    pub fn set_right_child(&mut self, child_idx: u32) {
        self.right_child = child_idx;
    }
}
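
// Illustrative sketch: a packing round trip for `CompactTreeNode`, checking
// that the feature index, threshold, and leaf flag survive the bit packing.
// The module and test names are illustrative, not part of the crate's API.
#[cfg(test)]
mod compact_tree_node_packing_example {
    use super::*;

    #[test]
    fn split_and_leaf_round_trip() {
        let split = CompactTreeNode::new_split(3, 0.5, 0.25);
        assert!(!split.is_leaf());
        assert_eq!(split.feature_idx(), 3);
        assert_eq!(split.threshold(), 0.5);
        assert_eq!(split.impurity(), 0.25);

        let leaf = CompactTreeNode::new_leaf(1.5);
        assert!(leaf.is_leaf());
        assert_eq!(leaf.prediction(), 1.5);
    }
}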

/// A flat, cache-friendly decision tree stored as a vector of compact nodes.
#[derive(Debug, Clone)]
pub struct CompactTree {
    pub nodes: Vec<CompactTreeNode>,
    pub feature_importances: Vec<f32>,
    pub n_features: usize,
    pub depth: usize,
}

impl CompactTree {
    pub fn new(n_features: usize) -> Self {
        Self {
            nodes: Vec::new(),
            feature_importances: vec![0.0; n_features],
            n_features,
            depth: 0,
        }
    }

    /// Appends a node and returns its 0-based index in the flat node array.
    pub fn add_node(&mut self, node: CompactTreeNode) -> u32 {
        let idx = self.nodes.len() as u32;
        self.nodes.push(node);
        idx
    }

    pub fn predict_single(&self, sample: &[f64]) -> Result<f64> {
        if self.nodes.is_empty() {
            return Err(SklearsError::InvalidInput("Empty tree".to_string()));
        }

        let mut current_idx = 0;

        loop {
            if current_idx >= self.nodes.len() {
                return Err(SklearsError::InvalidInput("Invalid node index".to_string()));
            }

            let node = &self.nodes[current_idx];

            if node.is_leaf() {
                return Ok(node.prediction());
            }

            let feature_idx = node.feature_idx() as usize;
            if feature_idx >= sample.len() {
                return Err(SklearsError::InvalidInput(
                    "Feature index out of bounds".to_string(),
                ));
            }

            let feature_value = sample[feature_idx];
            let threshold = node.threshold() as f64;

            if feature_value <= threshold {
                current_idx = node.left_child as usize;
            } else {
                current_idx = node.right_child as usize;
            }

            // Child indices are the 0-based indices returned by `add_node`.
            // The root always occupies slot 0, so a child index of 0 means
            // this split node was never wired up to its children.
            if current_idx == 0 {
                return Err(SklearsError::InvalidInput("Invalid child node".to_string()));
            }
        }
    }

    pub fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = x.row(i).to_vec();
            predictions[i] = self.predict_single(&sample)?;
        }

        Ok(predictions)
    }

    /// Approximate memory footprint of the tree in bytes.
    pub fn memory_usage(&self) -> usize {
        let node_size = std::mem::size_of::<CompactTreeNode>();
        let feature_importance_size = std::mem::size_of::<f32>() * self.feature_importances.len();
        let metadata_size = std::mem::size_of::<Self>()
            - std::mem::size_of::<Vec<CompactTreeNode>>()
            - std::mem::size_of::<Vec<f32>>();

        self.nodes.len() * node_size + feature_importance_size + metadata_size
    }

    pub fn n_nodes(&self) -> usize {
        self.nodes.len()
    }

    pub fn n_leaves(&self) -> usize {
        self.nodes.iter().filter(|node| node.is_leaf()).count()
    }
}
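
// Illustrative sketch: build a depth-1 CompactTree by hand and route two
// samples through it. Assumes the 0-based child-index convention documented
// in `predict_single`; names here exist only for the example.
#[cfg(test)]
mod compact_tree_usage_example {
    use super::*;

    #[test]
    fn one_split_tree_predicts_both_sides() {
        let mut tree = CompactTree::new(1);
        let root = tree.add_node(CompactTreeNode::new_split(0, 0.5, 0.3));
        let left = tree.add_node(CompactTreeNode::new_leaf(-1.0));
        let right = tree.add_node(CompactTreeNode::new_leaf(1.0));
        tree.nodes[root as usize].set_left_child(left);
        tree.nodes[root as usize].set_right_child(right);

        assert_eq!(tree.predict_single(&[0.25]).unwrap(), -1.0);
        assert_eq!(tree.predict_single(&[0.75]).unwrap(), 1.0);
        assert_eq!(tree.n_nodes(), 3);
        assert_eq!(tree.n_leaves(), 2);
    }
}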

/// A root-to-leaf decision path packed into a `u128` bit mask
/// (bit `i` set means "go right" at depth `i`).
#[derive(Debug, Clone)]
pub struct BitPackedPath {
    pub path_bits: u128,
    pub path_length: u8,
    pub prediction: f32,
}

impl Default for BitPackedPath {
    fn default() -> Self {
        Self::new()
    }
}

impl BitPackedPath {
    pub fn new() -> Self {
        Self {
            path_bits: 0,
            path_length: 0,
            prediction: 0.0,
        }
    }

    pub fn add_decision(&mut self, go_right: bool) -> Result<()> {
        if self.path_length >= 127 {
            return Err(SklearsError::InvalidInput("Path too long".to_string()));
        }

        if go_right {
            self.path_bits |= 1u128 << self.path_length;
        }
        self.path_length += 1;
        Ok(())
    }

    pub fn get_decision(&self, position: u8) -> bool {
        if position >= self.path_length {
            return false;
        }

        (self.path_bits & (1u128 << position)) != 0
    }

    pub fn set_prediction(&mut self, prediction: f64) {
        self.prediction = prediction as f32;
    }

    pub fn get_prediction(&self) -> f64 {
        self.prediction as f64
    }
}
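
// Illustrative sketch: encode a left/right/right path and read it back.
// Module and test names are illustrative only.
#[cfg(test)]
mod bit_packed_path_example {
    use super::*;

    #[test]
    fn decisions_round_trip() {
        let mut path = BitPackedPath::new();
        path.add_decision(false).unwrap(); // depth 0: go left
        path.add_decision(true).unwrap(); // depth 1: go right
        path.add_decision(true).unwrap(); // depth 2: go right
        path.set_prediction(2.0);

        assert!(!path.get_decision(0));
        assert!(path.get_decision(1));
        assert!(path.get_decision(2));
        assert!(!path.get_decision(3)); // beyond the path: defaults to false
        assert_eq!(path.get_prediction(), 2.0);
    }
}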

#[derive(Debug, Clone)]
pub struct BitPackedTree {
    pub paths: Vec<BitPackedPath>,
    pub feature_indices: Vec<u16>,
    pub thresholds: Vec<f32>,
    pub n_features: usize,
}

impl BitPackedTree {
    pub fn new(n_features: usize) -> Self {
        Self {
            paths: Vec::new(),
            feature_indices: Vec::new(),
            thresholds: Vec::new(),
            n_features,
        }
    }

    pub fn add_path(&mut self, path: BitPackedPath) {
        self.paths.push(path);
    }

    pub fn memory_usage(&self) -> usize {
        let path_size = std::mem::size_of::<BitPackedPath>() * self.paths.len();
        let feature_size = std::mem::size_of::<u16>() * self.feature_indices.len();
        let threshold_size = std::mem::size_of::<f32>() * self.thresholds.len();
        let metadata_size = std::mem::size_of::<Self>()
            - std::mem::size_of::<Vec<BitPackedPath>>()
            - std::mem::size_of::<Vec<u16>>()
            - std::mem::size_of::<Vec<f32>>();

        path_size + feature_size + threshold_size + metadata_size
    }
}
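
// Illustrative sketch: assemble a BitPackedTree from paths and inspect its
// approximate memory footprint. Purely a usage example.
#[cfg(test)]
mod bit_packed_tree_example {
    use super::*;

    #[test]
    fn memory_usage_grows_with_paths() {
        let mut tree = BitPackedTree::new(4);
        let empty_usage = tree.memory_usage();

        tree.add_path(BitPackedPath::default());
        tree.add_path(BitPackedPath::default());

        assert_eq!(tree.paths.len(), 2);
        assert!(tree.memory_usage() > empty_usage);
    }
}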

#[derive(Debug, Clone)]
pub struct CompactEnsemble {
    pub trees: Vec<CompactTree>,
    pub global_feature_importances: Vec<f32>,
    pub n_features: usize,
    pub n_trees: usize,
}

impl CompactEnsemble {
    pub fn new(n_features: usize) -> Self {
        Self {
            trees: Vec::new(),
            global_feature_importances: vec![0.0; n_features],
            n_features,
            n_trees: 0,
        }
    }

    pub fn add_tree(&mut self, tree: CompactTree) {
        self.trees.push(tree);
        self.n_trees += 1;
    }

    pub fn total_memory_usage(&self) -> usize {
        let trees_memory: usize = self.trees.iter().map(|t| t.memory_usage()).sum();
        let importance_memory = std::mem::size_of::<f32>() * self.global_feature_importances.len();
        let metadata_memory = std::mem::size_of::<Self>()
            - std::mem::size_of::<Vec<CompactTree>>()
            - std::mem::size_of::<Vec<f32>>();

        trees_memory + importance_memory + metadata_memory
    }
}
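
// Illustrative sketch: collect CompactTrees into a CompactEnsemble and check
// the bookkeeping. Example only; the trees here are trivial single leaves.
#[cfg(test)]
mod compact_ensemble_example {
    use super::*;

    #[test]
    fn ensemble_tracks_tree_count_and_memory() {
        let mut ensemble = CompactEnsemble::new(2);

        for &value in &[0.0, 1.0] {
            let mut tree = CompactTree::new(2);
            tree.add_node(CompactTreeNode::new_leaf(value));
            ensemble.add_tree(tree);
        }

        assert_eq!(ensemble.n_trees, 2);
        assert!(ensemble.total_memory_usage() > 0);
    }
}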

/// A node of the full (non-compact) tree used during construction.
#[derive(Debug, Clone)]
pub struct TreeNode {
    pub id: usize,
    pub depth: usize,
    pub sample_indices: Vec<usize>,
    pub impurity: f64,
    pub prediction: f64,
    pub potential_decrease: f64,
    pub best_split: Option<CustomSplit>,
    pub parent_id: Option<usize>,
    pub is_leaf: bool,
}

#[derive(Debug, Clone)]
pub struct NodePriority {
    pub node_id: usize,
    pub priority: f64,
}

impl PartialEq for NodePriority {
    fn eq(&self, other: &Self) -> bool {
        self.priority == other.priority
    }
}

impl Eq for NodePriority {}

impl PartialOrd for NodePriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for NodePriority {
    // The comparison is reversed relative to `priority`, so a
    // `std::collections::BinaryHeap<NodePriority>` pops the entry with the
    // smallest priority value first.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other
            .priority
            .partial_cmp(&self.priority)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
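
// Illustrative sketch demonstrating the reversed ordering above: in a
// standard max-heap, the NodePriority with the lowest priority value is
// popped first. Example only.
#[cfg(test)]
mod node_priority_ordering_example {
    use super::*;
    use std::collections::BinaryHeap;

    #[test]
    fn heap_pops_lowest_priority_first() {
        let mut heap = BinaryHeap::new();
        heap.push(NodePriority { node_id: 1, priority: 0.9 });
        heap.push(NodePriority { node_id: 2, priority: 0.1 });
        heap.push(NodePriority { node_id: 3, priority: 0.5 });

        assert_eq!(heap.pop().unwrap().node_id, 2);
        assert_eq!(heap.pop().unwrap().node_id, 3);
        assert_eq!(heap.pop().unwrap().node_id, 1);
    }
}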

/// A hashable, totally ordered wrapper around `f64` so thresholds can be
/// used as `HashMap` keys (e.g. inside `SubtreePattern`).
#[derive(Debug, Clone, Copy)]
pub struct OrderedFloat(pub f64);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        // Compare bit patterns so equality stays consistent with the `Hash`
        // implementation below (required for use as a HashMap key).
        self.0.to_bits() == other.0.to_bits()
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl std::hash::Hash for OrderedFloat {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.to_bits().hash(state);
    }
}

impl std::cmp::Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0
            .partial_cmp(&other.0)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
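
// Illustrative sketch: OrderedFloat as a HashMap key, which is why Eq and
// Hash must agree on the underlying bit pattern. Example only.
#[cfg(test)]
mod ordered_float_example {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn thresholds_can_key_a_hash_map() {
        let mut counts: HashMap<OrderedFloat, usize> = HashMap::new();
        *counts.entry(OrderedFloat(0.5)).or_insert(0) += 1;
        *counts.entry(OrderedFloat(0.5)).or_insert(0) += 1;
        *counts.entry(OrderedFloat(1.5)).or_insert(0) += 1;

        assert_eq!(counts[&OrderedFloat(0.5)], 2);
        assert_eq!(counts[&OrderedFloat(1.5)], 1);
    }
}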

#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct SubtreePattern {
    pub depth: usize,
    pub min_samples: usize,
    pub splits: Vec<(usize, OrderedFloat)>,
}

#[derive(Debug, Clone)]
pub struct SubtreeConfig {
    pub min_samples_for_sharing: usize,
    pub max_shared_depth: usize,
    pub min_pattern_frequency: usize,
    pub enabled: bool,
}

impl Default for SubtreeConfig {
    fn default() -> Self {
        Self {
            min_samples_for_sharing: 10,
            max_shared_depth: 3,
            min_pattern_frequency: 2,
            enabled: false,
        }
    }
}

#[derive(Debug, Clone)]
pub struct SharedTreeNode {
    pub id: usize,
    pub feature_idx: Option<usize>,
    pub threshold: Option<f64>,
    pub prediction: f64,
    pub n_samples: usize,
    pub impurity: f64,
    pub left_child: Option<usize>,
    pub right_child: Option<usize>,
    pub subtree_hash: u64,
}

#[derive(Debug, Clone)]
pub struct SubtreeSharingStats {
    pub total_shared_nodes: usize,
    pub total_patterns: usize,
    pub estimated_memory_saved: usize,
    pub sharing_efficiency: f64,
}

#[derive(Debug, Clone)]
pub struct SubtreeReference {
    pub shared_id: usize,
    pub local_node_id: usize,
    pub tree_id: usize,
}

#[derive(Debug, Clone)]
pub struct SharedSubtreeManager {
    pub shared_nodes: Arc<RwLock<HashMap<usize, SharedTreeNode>>>,
    pub pattern_cache: Arc<RwLock<HashMap<SubtreePattern, usize>>>,
    pub next_node_id: Arc<RwLock<usize>>,
    pub config: SubtreeConfig,
}

impl SharedSubtreeManager {
    pub fn new(config: SubtreeConfig) -> Self {
        Self {
            shared_nodes: Arc::new(RwLock::new(HashMap::new())),
            pattern_cache: Arc::new(RwLock::new(HashMap::new())),
            next_node_id: Arc::new(RwLock::new(0)),
            config,
        }
    }

    pub fn extract_patterns(&self, tree_nodes: &[TreeNode]) -> Result<Vec<SubtreePattern>> {
        if !self.config.enabled {
            return Ok(vec![]);
        }

        let mut patterns = Vec::new();

        for node in tree_nodes {
            if node.is_leaf || node.sample_indices.len() < self.config.min_samples_for_sharing {
                continue;
            }

            if let Some(pattern) = self.extract_pattern_from_node(tree_nodes, node.id, 0) {
                patterns.push(pattern);
            }
        }

        Ok(patterns)
    }

    /// Builds a pattern from the split stored at `node_id`. `TreeNode` keeps
    /// no child links, so only this node's own split is recorded; deeper
    /// splits are not traversed.
    fn extract_pattern_from_node(
        &self,
        tree_nodes: &[TreeNode],
        node_id: usize,
        current_depth: usize,
    ) -> Option<SubtreePattern> {
        if current_depth >= self.config.max_shared_depth {
            return None;
        }

        if node_id >= tree_nodes.len() {
            return None;
        }

        let node = &tree_nodes[node_id];

        if node.is_leaf {
            return Some(SubtreePattern {
                depth: current_depth,
                min_samples: node.sample_indices.len(),
                splits: vec![],
            });
        }

        let mut splits = Vec::new();

        if let Some(ref split) = node.best_split {
            splits.push((split.feature_idx, OrderedFloat(split.threshold)));
        }

        Some(SubtreePattern {
            depth: current_depth,
            min_samples: node.sample_indices.len(),
            splits,
        })
    }

    pub fn get_or_create_shared_subtree(
        &self,
        pattern: &SubtreePattern,
        tree_nodes: &[TreeNode],
        root_node_id: usize,
    ) -> Result<usize> {
        // Fast path: the pattern has already been shared.
        {
            let cache = self.pattern_cache.read().unwrap();
            if let Some(&shared_id) = cache.get(pattern) {
                return Ok(shared_id);
            }
        }

        // Reserve a fresh shared-node id.
        let shared_id = {
            let mut next_id = self.next_node_id.write().unwrap();
            let id = *next_id;
            *next_id += 1;
            id
        };

        let shared_node = self.create_shared_node(tree_nodes, root_node_id, shared_id)?;

        {
            let mut nodes = self.shared_nodes.write().unwrap();
            nodes.insert(shared_id, shared_node);
        }

        {
            let mut cache = self.pattern_cache.write().unwrap();
            cache.insert(pattern.clone(), shared_id);
        }

        Ok(shared_id)
    }

    fn create_shared_node(
        &self,
        tree_nodes: &[TreeNode],
        node_id: usize,
        shared_id: usize,
    ) -> Result<SharedTreeNode> {
        if node_id >= tree_nodes.len() {
            return Err(SklearsError::InvalidInput("Invalid node ID".to_string()));
        }

        let node = &tree_nodes[node_id];
        let subtree_hash = self.calculate_subtree_hash(tree_nodes, node_id);

        let (feature_idx, threshold) = if let Some(ref split) = node.best_split {
            (Some(split.feature_idx), Some(split.threshold))
        } else {
            (None, None)
        };

        Ok(SharedTreeNode {
            id: shared_id,
            feature_idx,
            threshold,
            prediction: node.prediction,
            n_samples: node.sample_indices.len(),
            impurity: node.impurity,
            left_child: None,
            right_child: None,
            subtree_hash,
        })
    }

    fn calculate_subtree_hash(&self, tree_nodes: &[TreeNode], node_id: usize) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;

        let mut hasher = DefaultHasher::new();
        self.hash_subtree_recursive(tree_nodes, node_id, &mut hasher, 0);
        hasher.finish()
    }

    /// Hashes the node at `node_id`. As with pattern extraction, `TreeNode`
    /// stores no child links, so only this node's leaf flag and split are
    /// folded into the hash.
    fn hash_subtree_recursive(
        &self,
        tree_nodes: &[TreeNode],
        node_id: usize,
        hasher: &mut dyn std::hash::Hasher,
        depth: usize,
    ) {
        if depth >= self.config.max_shared_depth || node_id >= tree_nodes.len() {
            return;
        }

        let node = &tree_nodes[node_id];

        hasher.write_u8(if node.is_leaf { 1 } else { 0 });
        if let Some(ref split) = node.best_split {
            hasher.write_usize(split.feature_idx);
            hasher.write(&split.threshold.to_le_bytes());
        }
    }

    pub fn calculate_memory_savings(&self) -> Result<SubtreeSharingStats> {
        let shared_nodes = self.shared_nodes.read().unwrap();
        let pattern_cache = self.pattern_cache.read().unwrap();

        let total_shared_nodes = shared_nodes.len();
        let total_patterns = pattern_cache.len();

        let estimated_memory_saved = total_shared_nodes * std::mem::size_of::<SharedTreeNode>();

        Ok(SubtreeSharingStats {
            total_shared_nodes,
            total_patterns,
            estimated_memory_saved,
            sharing_efficiency: if total_patterns > 0 {
                total_shared_nodes as f64 / total_patterns as f64
            } else {
                0.0
            },
        })
    }
}
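
// Illustrative sketch: run pattern extraction and sharing on a tiny,
// hand-built node list. The TreeNode values are made up for the example;
// sharing must be switched on in SubtreeConfig for anything to happen.
#[cfg(test)]
mod shared_subtree_manager_example {
    use super::*;

    fn split_node(id: usize) -> TreeNode {
        TreeNode {
            id,
            depth: 0,
            sample_indices: (0..20).collect(),
            impurity: 0.5,
            prediction: 0.0,
            potential_decrease: 0.1,
            best_split: Some(CustomSplit {
                feature_idx: 0,
                threshold: 0.5,
                impurity_decrease: 0.1,
                left_count: 10,
                right_count: 10,
            }),
            parent_id: None,
            is_leaf: false,
        }
    }

    #[test]
    fn identical_patterns_share_one_subtree() {
        let manager = SharedSubtreeManager::new(SubtreeConfig {
            enabled: true,
            ..SubtreeConfig::default()
        });

        let nodes = vec![split_node(0), split_node(1)];
        let patterns = manager.extract_patterns(&nodes).unwrap();
        assert_eq!(patterns.len(), 2);

        // Both nodes produce the same pattern, so they map to one shared id.
        let id_a = manager.get_or_create_shared_subtree(&patterns[0], &nodes, 0).unwrap();
        let id_b = manager.get_or_create_shared_subtree(&patterns[1], &nodes, 1).unwrap();
        assert_eq!(id_a, id_b);

        let stats = manager.calculate_memory_savings().unwrap();
        assert_eq!(stats.total_patterns, 1);
        assert_eq!(stats.total_shared_nodes, 1);
    }
}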

#[derive(Debug, Clone)]
pub struct TreeSpecificData {
    pub tree_id: usize,
    pub sample_weights: Vec<f64>,
    pub config: DecisionTreeConfig,
    pub local_nodes: Vec<TreeNode>,
}

#[derive(Debug, Clone)]
pub struct SharedTreeEnsemble {
    pub subtree_manager: SharedSubtreeManager,
    pub tree_specific_data: Vec<TreeSpecificData>,
    pub subtree_references: Vec<SubtreeReference>,
}

impl SharedTreeEnsemble {
    pub fn new(subtree_config: SubtreeConfig) -> Self {
        Self {
            subtree_manager: SharedSubtreeManager::new(subtree_config),
            tree_specific_data: Vec::new(),
            subtree_references: Vec::new(),
        }
    }

    pub fn add_tree(
        &mut self,
        tree_nodes: Vec<TreeNode>,
        config: DecisionTreeConfig,
        tree_id: usize,
    ) -> Result<()> {
        let patterns = self.subtree_manager.extract_patterns(&tree_nodes)?;

        for pattern in patterns {
            if let Ok(shared_id) = self.subtree_manager.get_or_create_shared_subtree(
                &pattern,
                &tree_nodes,
                0,
            ) {
                self.subtree_references.push(SubtreeReference {
                    shared_id,
                    local_node_id: 0,
                    tree_id,
                });
            }
        }

        let tree_data = TreeSpecificData {
            tree_id,
            sample_weights: vec![1.0; tree_nodes.len()],
            config,
            local_nodes: tree_nodes,
        };

        self.tree_specific_data.push(tree_data);

        Ok(())
    }

    pub fn get_sharing_stats(&self) -> Result<SubtreeSharingStats> {
        self.subtree_manager.calculate_memory_savings()
    }
}
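
// Illustrative sketch: register a tree with the shared ensemble and read
// back the sharing statistics. Assumes `DecisionTreeConfig` implements
// `Default`; if it does not, substitute however the config is normally built.
#[cfg(test)]
mod shared_tree_ensemble_example {
    use super::*;

    #[test]
    fn add_tree_and_query_stats() {
        let mut ensemble = SharedTreeEnsemble::new(SubtreeConfig {
            enabled: true,
            ..SubtreeConfig::default()
        });

        let leaf = TreeNode {
            id: 0,
            depth: 0,
            sample_indices: vec![0, 1, 2],
            impurity: 0.0,
            prediction: 1.0,
            potential_decrease: 0.0,
            best_split: None,
            parent_id: None,
            is_leaf: true,
        };

        // Assumption: DecisionTreeConfig::default() exists.
        ensemble.add_tree(vec![leaf], DecisionTreeConfig::default(), 0).unwrap();

        assert_eq!(ensemble.tree_specific_data.len(), 1);
        let stats = ensemble.get_sharing_stats().unwrap();
        assert_eq!(stats.total_shared_nodes, 0); // leaf-only trees share nothing
    }
}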