sklears_compose/
advanced_pipeline.rs

1//! Advanced pipeline features
2//!
3//! Conditional pipelines, branching, memory-efficient execution, and caching.
4
5use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
6use sklears_core::{
7    error::{Result as SklResult, SklearsError},
8    traits::{Estimator, Fit, Untrained},
9    types::Float,
10};
11
12use crate::{PipelinePredictor, PipelineStep};
13
/// Data condition for conditional execution.
///
/// A predicate over a 2-D input array. `ConditionalPipeline` evaluates it at
/// transform time to decide which branch handles the data. Implementors must be
/// `Send + Sync` and implement `Debug`.
pub trait DataCondition: Send + Sync + std::fmt::Debug {
    /// Check if the condition is met for the given data.
    ///
    /// Returns `true` when `x` satisfies the condition, `false` otherwise.
    fn check(&self, x: &ArrayView2<'_, Float>) -> bool;

    /// Clone the condition into a fresh boxed trait object.
    ///
    /// Exists as a trait method so a `Box<dyn DataCondition>` can be duplicated;
    /// `Clone` itself cannot be called through a trait object.
    fn clone_condition(&self) -> Box<dyn DataCondition>;
}
22
/// Data condition that compares the number of feature columns against a range.
#[derive(Debug, Clone)]
pub struct FeatureCountCondition {
    /// Smallest accepted column count (inclusive).
    min_features: usize,
    /// Largest accepted column count (inclusive); `None` means no upper limit.
    max_features: Option<usize>,
}

impl FeatureCountCondition {
    /// Build a condition from a lower bound and an optional upper bound.
    ///
    /// The bounds are stored as given; no check is made that
    /// `min_features <= max_features`, so an inverted range simply never matches.
    #[must_use]
    pub fn new(min_features: usize, max_features: Option<usize>) -> Self {
        FeatureCountCondition {
            min_features,
            max_features,
        }
    }
}
40
41impl DataCondition for FeatureCountCondition {
42    fn check(&self, x: &ArrayView2<'_, Float>) -> bool {
43        let n_features = x.ncols();
44        n_features >= self.min_features && self.max_features.map_or(true, |max| n_features <= max)
45    }
46
47    fn clone_condition(&self) -> Box<dyn DataCondition> {
48        Box::new(self.clone())
49    }
50}
51
/// Conditional Pipeline Execution
///
/// Pipeline that conditionally executes different branches based on data characteristics.
///
/// The type parameter `S` carries the training state: `Untrained` before fitting and
/// `ConditionalPipelineTrained` afterwards. Once fitted, the values that matter live in
/// `state`; the remaining fields of the fitted value only hold placeholders.
///
/// # Examples
///
/// ```ignore
/// use sklears_compose::{ConditionalPipeline, DataCondition};
/// use scirs2_core::ndarray::array;
///
/// // `build` returns a `Result`: it fails when the condition or the true branch is missing.
/// let conditional_pipeline = ConditionalPipeline::builder()
///     .condition(Box::new(FeatureCountCondition::new(10, None)))
///     .true_branch(true_pipeline)
///     .false_branch(false_pipeline)
///     .build()?;
/// ```
pub struct ConditionalPipeline<S = Untrained> {
    /// Training state (marker before fitting, fitted payload afterwards).
    state: S,
    /// Predicate that selects the branch.
    condition: Box<dyn DataCondition>,
    /// Step used when the condition holds.
    true_branch: Box<dyn PipelineStep>,
    /// Optional step used when the condition does not hold.
    false_branch: Option<Box<dyn PipelineStep>>,
    /// Fallback used when the condition fails and no false branch is set.
    default_action: String, // "passthrough", "error", "zero"
}
75
/// Trained state for `ConditionalPipeline`
///
/// Produced by `fit`; holds the condition, the fitted branches and the fallback
/// action that `transform` consults.
pub struct ConditionalPipelineTrained {
    /// Predicate evaluated on every `transform` call.
    condition: Box<dyn DataCondition>,
    /// Fitted step used when the condition holds.
    fitted_true_branch: Box<dyn PipelineStep>,
    /// Fitted step used when the condition does not hold, if one was configured.
    fitted_false_branch: Option<Box<dyn PipelineStep>>,
    /// Fallback name ("passthrough", "error" or "zero") used when there is no false branch.
    default_action: String,
    /// Column count of the data passed to `fit`.
    n_features_in: usize,
}
84
85impl ConditionalPipeline<Untrained> {
86    /// Create a new `ConditionalPipeline`
87    #[must_use]
88    pub fn new(condition: Box<dyn DataCondition>, true_branch: Box<dyn PipelineStep>) -> Self {
89        Self {
90            state: Untrained,
91            condition,
92            true_branch,
93            false_branch: None,
94            default_action: "passthrough".to_string(),
95        }
96    }
97
98    /// Create a builder
99    #[must_use]
100    pub fn builder() -> ConditionalPipelineBuilder {
101        ConditionalPipelineBuilder::new()
102    }
103
104    /// Set the false branch
105    #[must_use]
106    pub fn false_branch(mut self, false_branch: Box<dyn PipelineStep>) -> Self {
107        self.false_branch = Some(false_branch);
108        self
109    }
110
111    /// Set the default action
112    #[must_use]
113    pub fn default_action(mut self, action: &str) -> Self {
114        self.default_action = action.to_string();
115        self
116    }
117}
118
// `ConditionalPipeline` has no separate configuration object, so `Config` is the unit type.
impl Estimator for ConditionalPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit value, since there is nothing to configure.
    fn config(&self) -> &Self::Config {
        &()
    }
}
128
129impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for ConditionalPipeline<Untrained> {
130    type Fitted = ConditionalPipeline<ConditionalPipelineTrained>;
131
132    fn fit(
133        mut self,
134        x: &ArrayView2<'_, Float>,
135        y: &Option<&ArrayView1<'_, Float>>,
136    ) -> SklResult<Self::Fitted> {
137        // Fit the true branch
138        self.true_branch.fit(x, *y)?;
139
140        // Fit the false branch if it exists
141        let fitted_false_branch = if let Some(mut false_branch) = self.false_branch {
142            false_branch.fit(x, *y)?;
143            Some(false_branch)
144        } else {
145            None
146        };
147
148        Ok(ConditionalPipeline {
149            state: ConditionalPipelineTrained {
150                condition: self.condition,
151                fitted_true_branch: self.true_branch,
152                fitted_false_branch,
153                default_action: self.default_action,
154                n_features_in: x.ncols(),
155            },
156            condition: Box::new(FeatureCountCondition::new(0, None)), // Placeholder
157            true_branch: Box::new(crate::mock::MockTransformer::new()), // Placeholder
158            false_branch: None,
159            default_action: String::new(),
160        })
161    }
162}
163
164impl ConditionalPipeline<ConditionalPipelineTrained> {
165    /// Transform data using conditional logic
166    pub fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
167        if self.state.condition.check(x) {
168            self.state.fitted_true_branch.transform(x)
169        } else if let Some(ref false_branch) = self.state.fitted_false_branch {
170            false_branch.transform(x)
171        } else {
172            match self.state.default_action.as_str() {
173                "passthrough" => Ok(x.mapv(|v| v)),
174                "error" => Err(SklearsError::InvalidInput(
175                    "Condition not met and no false branch".to_string(),
176                )),
177                "zero" => Ok(Array2::zeros((x.nrows(), x.ncols()))),
178                _ => Err(SklearsError::InvalidInput(
179                    "Unknown default action".to_string(),
180                )),
181            }
182        }
183    }
184}
185
/// Branch configuration for branching pipelines
///
/// Bundles a branch name, the condition attached to the branch, and the ordered
/// list of steps belonging to it.
pub struct BranchConfig {
    /// Name of the branch
    pub name: String,
    /// Condition for this branch
    pub condition: Box<dyn DataCondition>,
    /// Pipeline steps for this branch, in the order they were added
    pub steps: Vec<Box<dyn PipelineStep>>,
}
195
196impl BranchConfig {
197    /// Create a new branch configuration
198    #[must_use]
199    pub fn new(name: String, condition: Box<dyn DataCondition>) -> Self {
200        Self {
201            name,
202            condition,
203            steps: Vec::new(),
204        }
205    }
206
207    /// Add a step to this branch
208    #[must_use]
209    pub fn step(mut self, step: Box<dyn PipelineStep>) -> Self {
210        self.steps.push(step);
211        self
212    }
213}
214
/// Branching Pipeline
///
/// Pipeline that splits execution into multiple parallel branches based on data characteristics.
///
/// The type parameter `S` carries the training state: `Untrained` before fitting and
/// `BranchingPipelineTrained` afterwards.
///
/// NOTE(review): in the part of the file visible here, fitting only stores
/// `combination_strategy` and `default_branch`; no code in view executes branches or
/// combines their outputs — confirm where these settings are consumed.
///
/// # Examples
///
/// ```ignore
/// use sklears_compose::{BranchingPipeline, BranchConfig};
/// use scirs2_core::ndarray::array;
///
/// let branching_pipeline = BranchingPipeline::builder()
///     .branch(BranchConfig::new("high_dim".to_string(), high_dim_condition))
///     .branch(BranchConfig::new("low_dim".to_string(), low_dim_condition))
///     .build();
/// ```
pub struct BranchingPipeline<S = Untrained> {
    /// Training state (marker before fitting, fitted payload afterwards).
    state: S,
    /// Branch configurations, in insertion order.
    branches: Vec<BranchConfig>,
    /// Name of the strategy for combining branch outputs.
    combination_strategy: String, // "concatenate", "average", "max", "first_match"
    /// Optional name of a default branch.
    default_branch: Option<String>,
}
236
/// Trained state for `BranchingPipeline`
///
/// Produced by `fit`; keeps each branch's name, condition and fitted steps together
/// with the settings copied from the untrained pipeline.
pub struct BranchingPipelineTrained {
    /// One `(name, condition, fitted steps)` entry per branch, in configuration order.
    fitted_branches: Vec<(String, Box<dyn DataCondition>, Vec<Box<dyn PipelineStep>>)>,
    /// Combination strategy name carried over from the untrained pipeline.
    combination_strategy: String,
    /// Default branch name carried over from the untrained pipeline, if any.
    default_branch: Option<String>,
    /// Column count of the data passed to `fit`.
    n_features_in: usize,
}
244
245impl BranchingPipeline<Untrained> {
246    /// Create a new `BranchingPipeline`
247    #[must_use]
248    pub fn new() -> Self {
249        Self {
250            state: Untrained,
251            branches: Vec::new(),
252            combination_strategy: "concatenate".to_string(),
253            default_branch: None,
254        }
255    }
256
257    /// Create a builder
258    #[must_use]
259    pub fn builder() -> BranchingPipelineBuilder {
260        BranchingPipelineBuilder::new()
261    }
262
263    /// Add a branch
264    pub fn add_branch(&mut self, branch: BranchConfig) {
265        self.branches.push(branch);
266    }
267
268    /// Set combination strategy
269    #[must_use]
270    pub fn combination_strategy(mut self, strategy: &str) -> Self {
271        self.combination_strategy = strategy.to_string();
272        self
273    }
274
275    /// Set default branch
276    #[must_use]
277    pub fn default_branch(mut self, branch_name: &str) -> Self {
278        self.default_branch = Some(branch_name.to_string());
279        self
280    }
281}
282
impl Default for BranchingPipeline<Untrained> {
    /// Equivalent to `BranchingPipeline::new()`.
    fn default() -> Self {
        Self::new()
    }
}
288
// `BranchingPipeline` has no separate configuration object, so `Config` is the unit type.
impl Estimator for BranchingPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit value, since there is nothing to configure.
    fn config(&self) -> &Self::Config {
        &()
    }
}
298
299impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for BranchingPipeline<Untrained> {
300    type Fitted = BranchingPipeline<BranchingPipelineTrained>;
301
302    fn fit(
303        self,
304        x: &ArrayView2<'_, Float>,
305        y: &Option<&ArrayView1<'_, Float>>,
306    ) -> SklResult<Self::Fitted> {
307        let mut fitted_branches = Vec::new();
308
309        for branch in self.branches {
310            let mut fitted_steps = Vec::new();
311            for mut step in branch.steps {
312                step.fit(x, y.as_ref().copied())?;
313                fitted_steps.push(step);
314            }
315            fitted_branches.push((branch.name, branch.condition, fitted_steps));
316        }
317
318        Ok(BranchingPipeline {
319            state: BranchingPipelineTrained {
320                fitted_branches,
321                combination_strategy: self.combination_strategy,
322                default_branch: self.default_branch,
323                n_features_in: x.ncols(),
324            },
325            branches: Vec::new(),
326            combination_strategy: String::new(),
327            default_branch: None,
328        })
329    }
330}
331
/// Builder for `ConditionalPipeline`
///
/// Collects the pieces of a conditional pipeline; the condition and the true branch
/// are mandatory and are checked when `build` is called.
pub struct ConditionalPipelineBuilder {
    /// Branch-selecting predicate; must be set before `build`.
    condition: Option<Box<dyn DataCondition>>,
    /// Step for a satisfied condition; must be set before `build`.
    true_branch: Option<Box<dyn PipelineStep>>,
    /// Optional step for an unsatisfied condition.
    false_branch: Option<Box<dyn PipelineStep>>,
    /// Fallback action name; the builder starts with "passthrough".
    default_action: String,
}
339
340impl ConditionalPipelineBuilder {
341    /// Create a new builder
342    #[must_use]
343    pub fn new() -> Self {
344        Self {
345            condition: None,
346            true_branch: None,
347            false_branch: None,
348            default_action: "passthrough".to_string(),
349        }
350    }
351
352    /// Set the condition
353    #[must_use]
354    pub fn condition(mut self, condition: Box<dyn DataCondition>) -> Self {
355        self.condition = Some(condition);
356        self
357    }
358
359    /// Set the true branch
360    #[must_use]
361    pub fn true_branch(mut self, branch: Box<dyn PipelineStep>) -> Self {
362        self.true_branch = Some(branch);
363        self
364    }
365
366    /// Set the false branch
367    #[must_use]
368    pub fn false_branch(mut self, branch: Box<dyn PipelineStep>) -> Self {
369        self.false_branch = Some(branch);
370        self
371    }
372
373    /// Set the default action
374    #[must_use]
375    pub fn default_action(mut self, action: &str) -> Self {
376        self.default_action = action.to_string();
377        self
378    }
379
380    /// Build the `ConditionalPipeline`
381    pub fn build(self) -> SklResult<ConditionalPipeline<Untrained>> {
382        let condition = self
383            .condition
384            .ok_or_else(|| SklearsError::InvalidInput("Condition required".to_string()))?;
385        let true_branch = self
386            .true_branch
387            .ok_or_else(|| SklearsError::InvalidInput("True branch required".to_string()))?;
388
389        let mut pipeline = ConditionalPipeline::new(condition, true_branch);
390        if let Some(false_branch) = self.false_branch {
391            pipeline = pipeline.false_branch(false_branch);
392        }
393        pipeline = pipeline.default_action(&self.default_action);
394
395        Ok(pipeline)
396    }
397}
398
/// Builder for `BranchingPipeline`
///
/// Accumulates branches and settings; `build` copies them into an untrained pipeline
/// without validation.
pub struct BranchingPipelineBuilder {
    /// Branches added so far, in insertion order.
    branches: Vec<BranchConfig>,
    /// Combination strategy name; the builder starts with "concatenate".
    combination_strategy: String,
    /// Optional default branch name.
    default_branch: Option<String>,
}
405
406impl BranchingPipelineBuilder {
407    /// Create a new builder
408    #[must_use]
409    pub fn new() -> Self {
410        Self {
411            branches: Vec::new(),
412            combination_strategy: "concatenate".to_string(),
413            default_branch: None,
414        }
415    }
416
417    /// Add a branch
418    #[must_use]
419    pub fn branch(mut self, branch: BranchConfig) -> Self {
420        self.branches.push(branch);
421        self
422    }
423
424    /// Set combination strategy
425    #[must_use]
426    pub fn combination_strategy(mut self, strategy: &str) -> Self {
427        self.combination_strategy = strategy.to_string();
428        self
429    }
430
431    /// Set default branch
432    #[must_use]
433    pub fn default_branch(mut self, branch_name: &str) -> Self {
434        self.default_branch = Some(branch_name.to_string());
435        self
436    }
437
438    /// Build the `BranchingPipeline`
439    #[must_use]
440    pub fn build(self) -> BranchingPipeline<Untrained> {
441        BranchingPipeline {
442            state: Untrained,
443            branches: self.branches,
444            combination_strategy: self.combination_strategy,
445            default_branch: self.default_branch,
446        }
447    }
448}
449
impl Default for ConditionalPipelineBuilder {
    /// Equivalent to `ConditionalPipelineBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}
455
impl Default for BranchingPipelineBuilder {
    /// Equivalent to `BranchingPipelineBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}