sklears_tree/node.rs

//! Tree node data structures and compact representations
//!
//! This module contains various tree node representations including compact,
//! bit-packed, and shared node structures for memory efficiency and optimization.

use crate::config::DecisionTreeConfig;
use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::{error::Result, error::SklearsError};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Custom split information for tree nodes
#[derive(Debug, Clone)]
pub struct CustomSplit {
    pub feature_idx: usize,
    pub threshold: f64,
    pub impurity_decrease: f64,
    pub left_count: usize,
    pub right_count: usize,
}

/// Surrogate split information for handling missing values
#[derive(Debug, Clone)]
pub struct SurrogateSplit {
    pub feature_idx: usize,
    pub threshold: f64,
    pub agreement: f64,       // How well this surrogate agrees with the primary split
    pub left_direction: bool, // True if missing values go left, false if right
}

/// Compact tree node representation for memory efficiency
#[derive(Debug, Clone)]
pub struct CompactTreeNode {
    /// Compact node data packed into a single u64:
    /// - Bits 0-15: feature index (16 bits, max 65535 features)
    /// - Bits 16-47: threshold as f32 bits (32 bits)
    /// - Bits 48-63: node flags and metadata (16 bits; bit 63 is the leaf flag)
    pub packed_data: u64,
    /// Prediction value for leaf nodes or impurity for split nodes
    pub value: f32,
    /// Left child index, stored 1-based (0 means no left child)
    pub left_child: u32,
    /// Right child index, stored 1-based (0 means no right child)
    pub right_child: u32,
}

impl CompactTreeNode {
    /// Create a new leaf node
    pub fn new_leaf(prediction: f64) -> Self {
        Self {
            packed_data: 0x8000_0000_0000_0000, // Set leaf flag in bit 63
            value: prediction as f32,
            left_child: 0,
            right_child: 0,
        }
    }

    /// Create a new split node
    pub fn new_split(feature_idx: u16, threshold: f32, impurity: f64) -> Self {
        let mut packed_data = 0u64;

        // Pack feature index (bits 0-15)
        packed_data |= feature_idx as u64;

        // Pack threshold (bits 16-47)
        let threshold_bits = threshold.to_bits() as u64;
        packed_data |= threshold_bits << 16;

        Self {
            packed_data,
            value: impurity as f32,
            left_child: 0,
            right_child: 0,
        }
    }

    /// Check if this is a leaf node
    pub fn is_leaf(&self) -> bool {
        (self.packed_data & 0x8000_0000_0000_0000) != 0
    }

    /// Get feature index for split nodes
    pub fn feature_idx(&self) -> u16 {
        (self.packed_data & 0xFFFF) as u16
    }

    /// Get threshold for split nodes
    pub fn threshold(&self) -> f32 {
        let threshold_bits = ((self.packed_data >> 16) & 0xFFFF_FFFF) as u32;
        f32::from_bits(threshold_bits)
    }

    /// Get prediction value for leaf nodes
    pub fn prediction(&self) -> f64 {
        self.value as f64
    }

    /// Get impurity value for split nodes
    pub fn impurity(&self) -> f64 {
        self.value as f64
    }

    /// Set left child index (pass `node_index + 1`; 0 marks an absent child)
    pub fn set_left_child(&mut self, child_idx: u32) {
        self.left_child = child_idx;
    }

    /// Set right child index (pass `node_index + 1`; 0 marks an absent child)
    pub fn set_right_child(&mut self, child_idx: u32) {
        self.right_child = child_idx;
    }
}
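
// A minimal round-trip sketch (added for illustration): packs a split and a
// leaf into `CompactTreeNode` and reads the fields back, confirming the bit
// layout documented above. Values are chosen to be exactly representable in f32.
#[cfg(test)]
mod compact_node_example {
    use super::*;

    #[test]
    fn pack_and_unpack() {
        let split = CompactTreeNode::new_split(42, 1.5, 0.25);
        assert!(!split.is_leaf());
        assert_eq!(split.feature_idx(), 42);
        assert_eq!(split.threshold(), 1.5);

        let leaf = CompactTreeNode::new_leaf(3.0);
        assert!(leaf.is_leaf());
        assert_eq!(leaf.prediction(), 3.0);
    }
}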

/// Compact tree representation for memory efficiency
#[derive(Debug, Clone)]
pub struct CompactTree {
    /// Array of compact tree nodes
    pub nodes: Vec<CompactTreeNode>,
    /// Feature importance scores
    pub feature_importances: Vec<f32>,
    /// Number of features
    pub n_features: usize,
    /// Tree depth
    pub depth: usize,
}

impl CompactTree {
    /// Create a new compact tree
    pub fn new(n_features: usize) -> Self {
        Self {
            nodes: Vec::new(),
            feature_importances: vec![0.0; n_features],
            n_features,
            depth: 0,
        }
    }

    /// Add a new node and return its 0-based index.
    ///
    /// Note: child links use a 1-based encoding (0 means "no child"), so pass
    /// `idx + 1` to `set_left_child`/`set_right_child` when linking nodes.
    pub fn add_node(&mut self, node: CompactTreeNode) -> u32 {
        let idx = self.nodes.len() as u32;
        self.nodes.push(node);
        idx
    }

    /// Predict a single sample
    pub fn predict_single(&self, sample: &[f64]) -> Result<f64> {
        if self.nodes.is_empty() {
            return Err(SklearsError::InvalidInput("Empty tree".to_string()));
        }

        let mut current_idx = 0;

        loop {
            if current_idx >= self.nodes.len() {
                return Err(SklearsError::InvalidInput("Invalid node index".to_string()));
            }

            let node = &self.nodes[current_idx];

            if node.is_leaf() {
                return Ok(node.prediction());
            }

            let feature_idx = node.feature_idx() as usize;
            if feature_idx >= sample.len() {
                return Err(SklearsError::InvalidInput(
                    "Feature index out of bounds".to_string(),
                ));
            }

            let feature_value = sample[feature_idx];
            let threshold = node.threshold() as f64;

            if feature_value <= threshold {
                current_idx = node.left_child as usize;
            } else {
                current_idx = node.right_child as usize;
            }

            // Child links are 1-based; 0 means the child is missing, which is
            // malformed for a split node.
            if current_idx == 0 {
                return Err(SklearsError::InvalidInput("Invalid child node".to_string()));
            }
            current_idx -= 1; // Convert to 0-based indexing
        }
    }

    /// Predict multiple samples
    pub fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = x.row(i).to_vec();
            predictions[i] = self.predict_single(&sample)?;
        }

        Ok(predictions)
    }

    /// Estimate memory usage in bytes (based on element counts, not allocated capacity)
    pub fn memory_usage(&self) -> usize {
        let node_size = std::mem::size_of::<CompactTreeNode>();
        let feature_importance_size = std::mem::size_of::<f32>() * self.feature_importances.len();
        let metadata_size = std::mem::size_of::<Self>()
            - std::mem::size_of::<Vec<CompactTreeNode>>()
            - std::mem::size_of::<Vec<f32>>();

        self.nodes.len() * node_size + feature_importance_size + metadata_size
    }

    /// Get the number of nodes in the tree
    pub fn n_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Get the number of leaves in the tree
    pub fn n_leaves(&self) -> usize {
        self.nodes.iter().filter(|node| node.is_leaf()).count()
    }
}
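
// A minimal usage sketch (added for illustration): builds a depth-1 tree and
// exercises `predict_single`. Note the 1-based child encoding: `add_node`
// returns 0-based indices, so links are stored as `index + 1`.
#[cfg(test)]
mod compact_tree_example {
    use super::*;

    #[test]
    fn build_and_predict() {
        let mut tree = CompactTree::new(1);
        let root = tree.add_node(CompactTreeNode::new_split(0, 0.5, 0.4));
        let left = tree.add_node(CompactTreeNode::new_leaf(1.0));
        let right = tree.add_node(CompactTreeNode::new_leaf(2.0));

        tree.nodes[root as usize].set_left_child(left + 1);
        tree.nodes[root as usize].set_right_child(right + 1);

        assert_eq!(tree.predict_single(&[0.25]).unwrap(), 1.0);
        assert_eq!(tree.predict_single(&[0.75]).unwrap(), 2.0);
        assert_eq!(tree.n_leaves(), 2);
    }
}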

/// Bit-packed decision path for ultra-compact storage
#[derive(Debug, Clone)]
pub struct BitPackedPath {
    /// Packed decision bits (left = 0, right = 1)
    pub path_bits: u128,
    /// Number of decisions in the path
    pub path_length: u8,
    /// Final prediction value
    pub prediction: f32,
}

impl Default for BitPackedPath {
    fn default() -> Self {
        Self::new()
    }
}

impl BitPackedPath {
    /// Create a new bit-packed path
    pub fn new() -> Self {
        Self {
            path_bits: 0,
            path_length: 0,
            prediction: 0.0,
        }
    }

    /// Add a decision to the path (false = left, true = right)
    pub fn add_decision(&mut self, go_right: bool) -> Result<()> {
        // A u128 holds up to 128 decisions (bit positions 0-127).
        if self.path_length >= 128 {
            return Err(SklearsError::InvalidInput("Path too long".to_string()));
        }

        if go_right {
            self.path_bits |= 1u128 << self.path_length;
        }
        self.path_length += 1;
        Ok(())
    }

    /// Get decision at position (false = left, true = right)
    pub fn get_decision(&self, position: u8) -> bool {
        if position >= self.path_length {
            return false;
        }

        (self.path_bits & (1u128 << position)) != 0
    }

    /// Set final prediction
    pub fn set_prediction(&mut self, prediction: f64) {
        self.prediction = prediction as f32;
    }

    /// Get final prediction
    pub fn get_prediction(&self) -> f64 {
        self.prediction as f64
    }
}
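
// A short round-trip sketch (added for illustration): records two decisions,
// reads them back by position, and checks the stored prediction.
#[cfg(test)]
mod bit_packed_path_example {
    use super::*;

    #[test]
    fn record_and_replay() {
        let mut path = BitPackedPath::new();
        path.add_decision(false).unwrap(); // go left at depth 0
        path.add_decision(true).unwrap(); // go right at depth 1
        path.set_prediction(0.5);

        assert!(!path.get_decision(0));
        assert!(path.get_decision(1));
        assert_eq!(path.get_prediction(), 0.5);
    }
}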

/// Ultra-compact tree using bit-packed decision paths
#[derive(Debug, Clone)]
pub struct BitPackedTree {
    /// Decision paths for each possible outcome
    pub paths: Vec<BitPackedPath>,
    /// Feature indices used in decisions
    pub feature_indices: Vec<u16>,
    /// Thresholds used in decisions
    pub thresholds: Vec<f32>,
    /// Number of features
    pub n_features: usize,
}

impl BitPackedTree {
    /// Create a new bit-packed tree
    pub fn new(n_features: usize) -> Self {
        Self {
            paths: Vec::new(),
            feature_indices: Vec::new(),
            thresholds: Vec::new(),
            n_features,
        }
    }

    /// Add a decision path
    pub fn add_path(&mut self, path: BitPackedPath) {
        self.paths.push(path);
    }

    /// Memory usage in bytes
    pub fn memory_usage(&self) -> usize {
        let path_size = std::mem::size_of::<BitPackedPath>() * self.paths.len();
        let feature_size = std::mem::size_of::<u16>() * self.feature_indices.len();
        let threshold_size = std::mem::size_of::<f32>() * self.thresholds.len();
        let metadata_size = std::mem::size_of::<Self>()
            - std::mem::size_of::<Vec<BitPackedPath>>()
            - std::mem::size_of::<Vec<u16>>()
            - std::mem::size_of::<Vec<f32>>();

        path_size + feature_size + threshold_size + metadata_size
    }
}

/// Memory-efficient ensemble representation
#[derive(Debug, Clone)]
pub struct CompactEnsemble {
    /// Array of compact trees
    pub trees: Vec<CompactTree>,
    /// Shared feature importance across all trees
    pub global_feature_importances: Vec<f32>,
    /// Number of features
    pub n_features: usize,
    /// Number of trees
    pub n_trees: usize,
}

impl CompactEnsemble {
    /// Create a new compact ensemble
    pub fn new(n_features: usize) -> Self {
        Self {
            trees: Vec::new(),
            global_feature_importances: vec![0.0; n_features],
            n_features,
            n_trees: 0,
        }
    }

    /// Add a tree to the ensemble
    pub fn add_tree(&mut self, tree: CompactTree) {
        self.trees.push(tree);
        self.n_trees += 1;
    }

    /// Calculate total memory usage
    pub fn total_memory_usage(&self) -> usize {
        let trees_memory: usize = self.trees.iter().map(|t| t.memory_usage()).sum();
        let importance_memory = std::mem::size_of::<f32>() * self.global_feature_importances.len();
        let metadata_memory = std::mem::size_of::<Self>()
            - std::mem::size_of::<Vec<CompactTree>>()
            - std::mem::size_of::<Vec<f32>>();

        trees_memory + importance_memory + metadata_memory
    }
}
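
// A brief usage sketch (added for illustration): assembles a one-leaf tree
// into a `CompactEnsemble` and checks the bookkeeping.
#[cfg(test)]
mod compact_ensemble_example {
    use super::*;

    #[test]
    fn tracks_trees_and_memory() {
        let mut ensemble = CompactEnsemble::new(4);
        let mut tree = CompactTree::new(4);
        tree.add_node(CompactTreeNode::new_leaf(1.0));
        ensemble.add_tree(tree);

        assert_eq!(ensemble.n_trees, 1);
        assert!(ensemble.total_memory_usage() > 0);
    }
}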

/// Tree node for building algorithms
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Node ID
    pub id: usize,
    /// Depth of this node
    pub depth: usize,
    /// Samples in this node
    pub sample_indices: Vec<usize>,
    /// Impurity of this node
    pub impurity: f64,
    /// Predicted value/class for this node
    pub prediction: f64,
    /// Potential impurity decrease if this node is split
    pub potential_decrease: f64,
    /// Best split for this node (if any)
    pub best_split: Option<CustomSplit>,
    /// Parent node ID
    pub parent_id: Option<usize>,
    /// Whether this is a leaf node
    pub is_leaf: bool,
}

/// Priority wrapper for nodes in the queue
#[derive(Debug, Clone)]
pub struct NodePriority {
    pub node_id: usize,
    pub priority: f64, // Negated impurity decrease; see the reversed `Ord` below
}

impl PartialEq for NodePriority {
    fn eq(&self, other: &Self) -> bool {
        self.priority == other.priority
    }
}

impl Eq for NodePriority {}

impl PartialOrd for NodePriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for NodePriority {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // The comparison is reversed, so combined with the negated priority a
        // `BinaryHeap` pops the node with the largest impurity decrease first.
        other
            .priority
            .partial_cmp(&self.priority)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
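
// A small demonstration (added for illustration) of the queue ordering: with
// priorities stored as negated impurity decreases, the reversed `Ord` makes
// `BinaryHeap` yield the most promising node first.
#[cfg(test)]
mod node_priority_example {
    use super::*;
    use std::collections::BinaryHeap;

    #[test]
    fn largest_decrease_pops_first() {
        let mut heap = BinaryHeap::new();
        heap.push(NodePriority { node_id: 0, priority: -0.1 }); // decrease 0.1
        heap.push(NodePriority { node_id: 1, priority: -0.9 }); // decrease 0.9
        assert_eq!(heap.pop().unwrap().node_id, 1);
    }
}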

/// Wrapper for f64 to make it hashable and orderable.
///
/// Equality, ordering, and hashing all use the underlying bit pattern, so the
/// `Eq`/`Hash` contract required by `HashMap` keys holds (note this makes
/// `NaN == NaN` and `0.0 != -0.0`).
#[derive(Debug, Clone, Copy)]
pub struct OrderedFloat(pub f64);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        // Bitwise comparison keeps `eq` consistent with `hash` below.
        self.0.to_bits() == other.0.to_bits()
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl std::hash::Hash for OrderedFloat {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.to_bits().hash(state);
    }
}

impl std::cmp::Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` gives the IEEE 754 total order, which is a proper total
        // ordering (unlike `partial_cmp`, which is undefined for NaN).
        self.0.total_cmp(&other.0)
    }
}
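
// A quick sketch (added for illustration): `OrderedFloat` as a `HashMap` key,
// relying on the bitwise `Eq`/`Hash` implementations above.
#[cfg(test)]
mod ordered_float_example {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn usable_as_hash_key() {
        let mut map = HashMap::new();
        map.insert(OrderedFloat(0.5), "half");
        assert_eq!(map.get(&OrderedFloat(0.5)), Some(&"half"));
        // Bitwise equality makes NaN equal to itself.
        assert_eq!(OrderedFloat(f64::NAN), OrderedFloat(f64::NAN));
    }
}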

/// Subtree pattern for identifying reusable subtrees
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct SubtreePattern {
    /// Maximum depth of the pattern
    pub depth: usize,
    /// Minimum number of samples to consider for sharing
    pub min_samples: usize,
    /// Feature and threshold pairs in the pattern
    pub splits: Vec<(usize, OrderedFloat)>,
}

/// Configuration for shared subtree optimization
#[derive(Debug, Clone)]
pub struct SubtreeConfig {
    /// Minimum number of samples in a subtree to consider sharing
    pub min_samples_for_sharing: usize,
    /// Maximum depth of subtrees to share
    pub max_shared_depth: usize,
    /// Minimum frequency of pattern to justify sharing
    pub min_pattern_frequency: usize,
    /// Enable/disable subtree sharing
    pub enabled: bool,
}

impl Default for SubtreeConfig {
    fn default() -> Self {
        Self {
            min_samples_for_sharing: 10,
            max_shared_depth: 3,
            min_pattern_frequency: 2,
            enabled: false,
        }
    }
}

/// A shared tree node that can be referenced by multiple trees
#[derive(Debug, Clone)]
pub struct SharedTreeNode {
    /// Unique identifier for this shared node
    pub id: usize,
    /// Feature index for split (None for leaf nodes)
    pub feature_idx: Option<usize>,
    /// Split threshold (None for leaf nodes)
    pub threshold: Option<f64>,
    /// Prediction value for this node
    pub prediction: f64,
    /// Number of samples that reached this node
    pub n_samples: usize,
    /// Impurity at this node
    pub impurity: f64,
    /// Left child node ID (None for leaf)
    pub left_child: Option<usize>,
    /// Right child node ID (None for leaf)
    pub right_child: Option<usize>,
    /// Hash representing the subtree structure for sharing
    pub subtree_hash: u64,
}

/// Statistics about subtree sharing efficiency
#[derive(Debug, Clone)]
pub struct SubtreeSharingStats {
    /// Total number of shared nodes
    pub total_shared_nodes: usize,
    /// Total number of unique patterns
    pub total_patterns: usize,
    /// Estimated memory saved in bytes
    pub estimated_memory_saved: usize,
    /// Sharing efficiency (shared nodes / patterns)
    pub sharing_efficiency: f64,
}

/// Reference to a shared subtree in a tree ensemble
#[derive(Debug, Clone)]
pub struct SubtreeReference {
    /// ID of the shared subtree
    pub shared_id: usize,
    /// Local node ID in the referencing tree
    pub local_node_id: usize,
    /// Tree ID that contains this reference
    pub tree_id: usize,
}

/// Shared subtree manager for memory optimization
#[derive(Debug, Clone)]
pub struct SharedSubtreeManager {
    /// Shared nodes indexed by ID
    pub shared_nodes: Arc<RwLock<HashMap<usize, SharedTreeNode>>>,
    /// Pattern to shared node ID mapping
    pub pattern_cache: Arc<RwLock<HashMap<SubtreePattern, usize>>>,
    /// Next available node ID
    pub next_node_id: Arc<RwLock<usize>>,
    /// Configuration for subtree sharing
    pub config: SubtreeConfig,
}

impl SharedSubtreeManager {
    pub fn new(config: SubtreeConfig) -> Self {
        Self {
            shared_nodes: Arc::new(RwLock::new(HashMap::new())),
            pattern_cache: Arc::new(RwLock::new(HashMap::new())),
            next_node_id: Arc::new(RwLock::new(0)),
            config,
        }
    }

    /// Extract subtree patterns from a tree for potential sharing
    pub fn extract_patterns(&self, tree_nodes: &[TreeNode]) -> Result<Vec<SubtreePattern>> {
        if !self.config.enabled {
            return Ok(vec![]);
        }

        let mut patterns = Vec::new();

        for node in tree_nodes {
            if node.is_leaf || node.sample_indices.len() < self.config.min_samples_for_sharing {
                continue;
            }

            if let Some(pattern) = self.extract_pattern_from_node(tree_nodes, node.id, 0) {
                patterns.push(pattern);
            }
        }

        Ok(patterns)
    }

    /// Extract a pattern starting from a specific node, collecting the splits
    /// of the node and its descendants down to `max_shared_depth`
    fn extract_pattern_from_node(
        &self,
        tree_nodes: &[TreeNode],
        node_id: usize,
        current_depth: usize,
    ) -> Option<SubtreePattern> {
        if current_depth >= self.config.max_shared_depth || node_id >= tree_nodes.len() {
            return None;
        }

        let node = &tree_nodes[node_id];

        if node.is_leaf {
            return Some(SubtreePattern {
                depth: current_depth,
                min_samples: node.sample_indices.len(),
                splits: vec![],
            });
        }

        let mut splits = Vec::new();
        let mut max_depth = current_depth;

        if let Some(ref split) = node.best_split {
            splits.push((split.feature_idx, OrderedFloat(split.threshold)));
        }

        // Descend into children. `TreeNode` stores only `parent_id`, so
        // children are found by scanning for nodes that point back at this one.
        for child in tree_nodes.iter().filter(|n| n.parent_id == Some(node_id)) {
            if let Some(child_pattern) =
                self.extract_pattern_from_node(tree_nodes, child.id, current_depth + 1)
            {
                splits.extend(child_pattern.splits);
                max_depth = max_depth.max(child_pattern.depth);
            }
        }

        Some(SubtreePattern {
            depth: max_depth,
            min_samples: node.sample_indices.len(),
            splits,
        })
    }

    /// Find or create a shared subtree for a given pattern
    pub fn get_or_create_shared_subtree(
        &self,
        pattern: &SubtreePattern,
        tree_nodes: &[TreeNode],
        root_node_id: usize,
    ) -> Result<usize> {
        // Check if the pattern already exists
        {
            let cache = self.pattern_cache.read().unwrap();
            if let Some(&shared_id) = cache.get(pattern) {
                return Ok(shared_id);
            }
        }

        // Create a new shared subtree
        let shared_id = {
            let mut next_id = self.next_node_id.write().unwrap();
            let id = *next_id;
            *next_id += 1;
            id
        };

        let shared_node = self.create_shared_node(tree_nodes, root_node_id, shared_id)?;

        // Store the shared node
        {
            let mut nodes = self.shared_nodes.write().unwrap();
            nodes.insert(shared_id, shared_node);
        }

        // Cache the pattern
        {
            let mut cache = self.pattern_cache.write().unwrap();
            cache.insert(pattern.clone(), shared_id);
        }

        Ok(shared_id)
    }

    /// Create a shared node from a tree node
    fn create_shared_node(
        &self,
        tree_nodes: &[TreeNode],
        node_id: usize,
        shared_id: usize,
    ) -> Result<SharedTreeNode> {
        if node_id >= tree_nodes.len() {
            return Err(SklearsError::InvalidInput("Invalid node ID".to_string()));
        }

        let node = &tree_nodes[node_id];
        let subtree_hash = self.calculate_subtree_hash(tree_nodes, node_id);

        let (feature_idx, threshold) = if let Some(ref split) = node.best_split {
            (Some(split.feature_idx), Some(split.threshold))
        } else {
            (None, None)
        };

        Ok(SharedTreeNode {
            id: shared_id,
            feature_idx,
            threshold,
            prediction: node.prediction,
            n_samples: node.sample_indices.len(),
            impurity: node.impurity,
            left_child: None,  // Would need to be set based on actual children
            right_child: None, // Would need to be set based on actual children
            subtree_hash,
        })
    }

    /// Calculate a hash representing the structure of a subtree
    fn calculate_subtree_hash(&self, tree_nodes: &[TreeNode], node_id: usize) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;

        let mut hasher = DefaultHasher::new();
        self.hash_subtree_recursive(tree_nodes, node_id, &mut hasher, 0);
        hasher.finish()
    }

    /// Recursively hash a subtree structure down to `max_shared_depth`
    fn hash_subtree_recursive(
        &self,
        tree_nodes: &[TreeNode],
        node_id: usize,
        hasher: &mut dyn std::hash::Hasher,
        depth: usize,
    ) {
        if depth >= self.config.max_shared_depth || node_id >= tree_nodes.len() {
            return;
        }

        let node = &tree_nodes[node_id];

        // Hash this node's properties
        hasher.write_u8(if node.is_leaf { 1 } else { 0 });
        if let Some(ref split) = node.best_split {
            hasher.write_usize(split.feature_idx);
            hasher.write(&split.threshold.to_le_bytes());
        }

        // Recurse into children, located via their `parent_id` back-links.
        for child in tree_nodes.iter().filter(|n| n.parent_id == Some(node_id)) {
            self.hash_subtree_recursive(tree_nodes, child.id, hasher, depth + 1);
        }
    }

    /// Calculate memory savings from subtree sharing
    pub fn calculate_memory_savings(&self) -> Result<SubtreeSharingStats> {
        let shared_nodes = self.shared_nodes.read().unwrap();
        let pattern_cache = self.pattern_cache.read().unwrap();

        let total_shared_nodes = shared_nodes.len();
        let total_patterns = pattern_cache.len();

        // Estimate memory saved (simplified calculation)
        let estimated_memory_saved = total_shared_nodes * std::mem::size_of::<SharedTreeNode>();

        Ok(SubtreeSharingStats {
            total_shared_nodes,
            total_patterns,
            estimated_memory_saved,
            sharing_efficiency: if total_patterns > 0 {
                total_shared_nodes as f64 / total_patterns as f64
            } else {
                0.0
            },
        })
    }
}
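
// An end-to-end sketch (added for illustration): a root split with two leaf
// children, wired together via `parent_id`. Sharing must be enabled in the
// config for any patterns to be extracted.
#[cfg(test)]
mod shared_subtree_example {
    use super::*;

    fn leaf(id: usize, parent_id: usize) -> TreeNode {
        TreeNode {
            id,
            depth: 1,
            sample_indices: (0..10).collect(),
            impurity: 0.0,
            prediction: 1.0,
            potential_decrease: 0.0,
            best_split: None,
            parent_id: Some(parent_id),
            is_leaf: true,
        }
    }

    #[test]
    fn extracts_patterns_when_enabled() {
        let manager = SharedSubtreeManager::new(SubtreeConfig {
            enabled: true,
            ..SubtreeConfig::default()
        });

        let root = TreeNode {
            id: 0,
            depth: 0,
            sample_indices: (0..20).collect(),
            impurity: 0.5,
            prediction: 0.5,
            potential_decrease: 0.3,
            best_split: Some(CustomSplit {
                feature_idx: 0,
                threshold: 0.5,
                impurity_decrease: 0.3,
                left_count: 10,
                right_count: 10,
            }),
            parent_id: None,
            is_leaf: false,
        };
        let nodes = vec![root, leaf(1, 0), leaf(2, 0)];

        // Only the root qualifies: leaves are skipped, and it meets the
        // default `min_samples_for_sharing` threshold of 10.
        let patterns = manager.extract_patterns(&nodes).unwrap();
        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].splits.len(), 1);
    }
}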

/// Tree-specific data that cannot be shared
#[derive(Debug, Clone)]
pub struct TreeSpecificData {
    /// Tree ID
    pub tree_id: usize,
    /// Sample weights specific to this tree
    pub sample_weights: Vec<f64>,
    /// Tree-specific configuration
    pub config: DecisionTreeConfig,
    /// Non-shared nodes (typically small, tree-specific parts)
    pub local_nodes: Vec<TreeNode>,
}

/// Ensemble of trees with shared subtree optimization
#[derive(Debug, Clone)]
pub struct SharedTreeEnsemble {
    /// Shared subtree manager
    pub subtree_manager: SharedSubtreeManager,
    /// Tree-specific data (non-shared parts)
    pub tree_specific_data: Vec<TreeSpecificData>,
    /// References to shared subtrees
    pub subtree_references: Vec<SubtreeReference>,
}

impl SharedTreeEnsemble {
    pub fn new(subtree_config: SubtreeConfig) -> Self {
        Self {
            subtree_manager: SharedSubtreeManager::new(subtree_config),
            tree_specific_data: Vec::new(),
            subtree_references: Vec::new(),
        }
    }

    /// Add a new tree to the ensemble with shared subtree optimization
    pub fn add_tree(
        &mut self,
        tree_nodes: Vec<TreeNode>,
        config: DecisionTreeConfig,
        tree_id: usize,
    ) -> Result<()> {
        // Extract patterns from the tree
        let patterns = self.subtree_manager.extract_patterns(&tree_nodes)?;

        // Create shared subtrees for frequent patterns
        for pattern in patterns {
            if let Ok(shared_id) = self.subtree_manager.get_or_create_shared_subtree(
                &pattern,
                &tree_nodes,
                0, // Assuming the root node
            ) {
                self.subtree_references.push(SubtreeReference {
                    shared_id,
                    local_node_id: 0,
                    tree_id,
                });
            }
        }

        // Store tree-specific data
        let tree_data = TreeSpecificData {
            tree_id,
            sample_weights: vec![1.0; tree_nodes.len()], // Placeholder: uniform weights, one per node
            config,
            local_nodes: tree_nodes,
        };

        self.tree_specific_data.push(tree_data);

        Ok(())
    }

    /// Get sharing statistics for the ensemble
    pub fn get_sharing_stats(&self) -> Result<SubtreeSharingStats> {
        self.subtree_manager.calculate_memory_savings()
    }
}
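
// A hedged usage sketch: wires a single-leaf tree into the shared ensemble.
// It assumes `DecisionTreeConfig` implements `Default`; substitute your own
// construction if it does not.
#[cfg(test)]
mod shared_ensemble_example {
    use super::*;

    #[test]
    fn add_tree_and_query_stats() {
        let mut ensemble = SharedTreeEnsemble::new(SubtreeConfig::default());
        let node = TreeNode {
            id: 0,
            depth: 0,
            sample_indices: vec![0, 1, 2],
            impurity: 0.0,
            prediction: 1.0,
            potential_decrease: 0.0,
            best_split: None,
            parent_id: None,
            is_leaf: true,
        };
        ensemble
            .add_tree(vec![node], DecisionTreeConfig::default(), 0)
            .unwrap();

        // Sharing is disabled by default, so no patterns are recorded.
        let stats = ensemble.get_sharing_stats().unwrap();
        assert_eq!(stats.total_patterns, 0);
    }
}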