1use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
6use sklears_core::{
7 error::{Result as SklResult, SklearsError},
8 traits::{Estimator, Fit, Untrained},
9 types::Float,
10};
11
12use crate::{PipelinePredictor, PipelineStep};
13
14pub trait DataCondition: Send + Sync + std::fmt::Debug {
16 fn check(&self, x: &ArrayView2<'_, Float>) -> bool;
18
19 fn clone_condition(&self) -> Box<dyn DataCondition>;
21}
22
/// Condition that holds when the number of feature columns falls inside an
/// inclusive range: at least `min_features`, and at most `max_features` when
/// an upper bound is given.
#[derive(Debug, Clone)]
pub struct FeatureCountCondition {
    /// Inclusive lower bound on the column count.
    min_features: usize,
    /// Optional inclusive upper bound on the column count; `None` means unbounded.
    max_features: Option<usize>,
}
29
30impl FeatureCountCondition {
31 #[must_use]
33 pub fn new(min_features: usize, max_features: Option<usize>) -> Self {
34 Self {
35 min_features,
36 max_features,
37 }
38 }
39}
40
41impl DataCondition for FeatureCountCondition {
42 fn check(&self, x: &ArrayView2<'_, Float>) -> bool {
43 let n_features = x.ncols();
44 n_features >= self.min_features && self.max_features.map_or(true, |max| n_features <= max)
45 }
46
47 fn clone_condition(&self) -> Box<dyn DataCondition> {
48 Box::new(self.clone())
49 }
50}
51
52pub struct ConditionalPipeline<S = Untrained> {
69 state: S,
70 condition: Box<dyn DataCondition>,
71 true_branch: Box<dyn PipelineStep>,
72 false_branch: Option<Box<dyn PipelineStep>>,
73 default_action: String, }
75
76pub struct ConditionalPipelineTrained {
78 condition: Box<dyn DataCondition>,
79 fitted_true_branch: Box<dyn PipelineStep>,
80 fitted_false_branch: Option<Box<dyn PipelineStep>>,
81 default_action: String,
82 n_features_in: usize,
83}
84
85impl ConditionalPipeline<Untrained> {
86 #[must_use]
88 pub fn new(condition: Box<dyn DataCondition>, true_branch: Box<dyn PipelineStep>) -> Self {
89 Self {
90 state: Untrained,
91 condition,
92 true_branch,
93 false_branch: None,
94 default_action: "passthrough".to_string(),
95 }
96 }
97
98 #[must_use]
100 pub fn builder() -> ConditionalPipelineBuilder {
101 ConditionalPipelineBuilder::new()
102 }
103
104 #[must_use]
106 pub fn false_branch(mut self, false_branch: Box<dyn PipelineStep>) -> Self {
107 self.false_branch = Some(false_branch);
108 self
109 }
110
111 #[must_use]
113 pub fn default_action(mut self, action: &str) -> Self {
114 self.default_action = action.to_string();
115 self
116 }
117}
118
119impl Estimator for ConditionalPipeline<Untrained> {
120 type Config = ();
121 type Error = SklearsError;
122 type Float = Float;
123
124 fn config(&self) -> &Self::Config {
125 &()
126 }
127}
128
129impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for ConditionalPipeline<Untrained> {
130 type Fitted = ConditionalPipeline<ConditionalPipelineTrained>;
131
132 fn fit(
133 mut self,
134 x: &ArrayView2<'_, Float>,
135 y: &Option<&ArrayView1<'_, Float>>,
136 ) -> SklResult<Self::Fitted> {
137 self.true_branch.fit(x, *y)?;
139
140 let fitted_false_branch = if let Some(mut false_branch) = self.false_branch {
142 false_branch.fit(x, *y)?;
143 Some(false_branch)
144 } else {
145 None
146 };
147
148 Ok(ConditionalPipeline {
149 state: ConditionalPipelineTrained {
150 condition: self.condition,
151 fitted_true_branch: self.true_branch,
152 fitted_false_branch,
153 default_action: self.default_action,
154 n_features_in: x.ncols(),
155 },
156 condition: Box::new(FeatureCountCondition::new(0, None)), true_branch: Box::new(crate::mock::MockTransformer::new()), false_branch: None,
159 default_action: String::new(),
160 })
161 }
162}
163
164impl ConditionalPipeline<ConditionalPipelineTrained> {
165 pub fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
167 if self.state.condition.check(x) {
168 self.state.fitted_true_branch.transform(x)
169 } else if let Some(ref false_branch) = self.state.fitted_false_branch {
170 false_branch.transform(x)
171 } else {
172 match self.state.default_action.as_str() {
173 "passthrough" => Ok(x.mapv(|v| v)),
174 "error" => Err(SklearsError::InvalidInput(
175 "Condition not met and no false branch".to_string(),
176 )),
177 "zero" => Ok(Array2::zeros((x.nrows(), x.ncols()))),
178 _ => Err(SklearsError::InvalidInput(
179 "Unknown default action".to_string(),
180 )),
181 }
182 }
183 }
184}
185
186pub struct BranchConfig {
188 pub name: String,
190 pub condition: Box<dyn DataCondition>,
192 pub steps: Vec<Box<dyn PipelineStep>>,
194}
195
196impl BranchConfig {
197 #[must_use]
199 pub fn new(name: String, condition: Box<dyn DataCondition>) -> Self {
200 Self {
201 name,
202 condition,
203 steps: Vec::new(),
204 }
205 }
206
207 #[must_use]
209 pub fn step(mut self, step: Box<dyn PipelineStep>) -> Self {
210 self.steps.push(step);
211 self
212 }
213}
214
215pub struct BranchingPipeline<S = Untrained> {
231 state: S,
232 branches: Vec<BranchConfig>,
233 combination_strategy: String, default_branch: Option<String>,
235}
236
237pub struct BranchingPipelineTrained {
239 fitted_branches: Vec<(String, Box<dyn DataCondition>, Vec<Box<dyn PipelineStep>>)>,
240 combination_strategy: String,
241 default_branch: Option<String>,
242 n_features_in: usize,
243}
244
245impl BranchingPipeline<Untrained> {
246 #[must_use]
248 pub fn new() -> Self {
249 Self {
250 state: Untrained,
251 branches: Vec::new(),
252 combination_strategy: "concatenate".to_string(),
253 default_branch: None,
254 }
255 }
256
257 #[must_use]
259 pub fn builder() -> BranchingPipelineBuilder {
260 BranchingPipelineBuilder::new()
261 }
262
263 pub fn add_branch(&mut self, branch: BranchConfig) {
265 self.branches.push(branch);
266 }
267
268 #[must_use]
270 pub fn combination_strategy(mut self, strategy: &str) -> Self {
271 self.combination_strategy = strategy.to_string();
272 self
273 }
274
275 #[must_use]
277 pub fn default_branch(mut self, branch_name: &str) -> Self {
278 self.default_branch = Some(branch_name.to_string());
279 self
280 }
281}
282
283impl Default for BranchingPipeline<Untrained> {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289impl Estimator for BranchingPipeline<Untrained> {
290 type Config = ();
291 type Error = SklearsError;
292 type Float = Float;
293
294 fn config(&self) -> &Self::Config {
295 &()
296 }
297}
298
299impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for BranchingPipeline<Untrained> {
300 type Fitted = BranchingPipeline<BranchingPipelineTrained>;
301
302 fn fit(
303 self,
304 x: &ArrayView2<'_, Float>,
305 y: &Option<&ArrayView1<'_, Float>>,
306 ) -> SklResult<Self::Fitted> {
307 let mut fitted_branches = Vec::new();
308
309 for branch in self.branches {
310 let mut fitted_steps = Vec::new();
311 for mut step in branch.steps {
312 step.fit(x, y.as_ref().copied())?;
313 fitted_steps.push(step);
314 }
315 fitted_branches.push((branch.name, branch.condition, fitted_steps));
316 }
317
318 Ok(BranchingPipeline {
319 state: BranchingPipelineTrained {
320 fitted_branches,
321 combination_strategy: self.combination_strategy,
322 default_branch: self.default_branch,
323 n_features_in: x.ncols(),
324 },
325 branches: Vec::new(),
326 combination_strategy: String::new(),
327 default_branch: None,
328 })
329 }
330}
331
332pub struct ConditionalPipelineBuilder {
334 condition: Option<Box<dyn DataCondition>>,
335 true_branch: Option<Box<dyn PipelineStep>>,
336 false_branch: Option<Box<dyn PipelineStep>>,
337 default_action: String,
338}
339
340impl ConditionalPipelineBuilder {
341 #[must_use]
343 pub fn new() -> Self {
344 Self {
345 condition: None,
346 true_branch: None,
347 false_branch: None,
348 default_action: "passthrough".to_string(),
349 }
350 }
351
352 #[must_use]
354 pub fn condition(mut self, condition: Box<dyn DataCondition>) -> Self {
355 self.condition = Some(condition);
356 self
357 }
358
359 #[must_use]
361 pub fn true_branch(mut self, branch: Box<dyn PipelineStep>) -> Self {
362 self.true_branch = Some(branch);
363 self
364 }
365
366 #[must_use]
368 pub fn false_branch(mut self, branch: Box<dyn PipelineStep>) -> Self {
369 self.false_branch = Some(branch);
370 self
371 }
372
373 #[must_use]
375 pub fn default_action(mut self, action: &str) -> Self {
376 self.default_action = action.to_string();
377 self
378 }
379
380 pub fn build(self) -> SklResult<ConditionalPipeline<Untrained>> {
382 let condition = self
383 .condition
384 .ok_or_else(|| SklearsError::InvalidInput("Condition required".to_string()))?;
385 let true_branch = self
386 .true_branch
387 .ok_or_else(|| SklearsError::InvalidInput("True branch required".to_string()))?;
388
389 let mut pipeline = ConditionalPipeline::new(condition, true_branch);
390 if let Some(false_branch) = self.false_branch {
391 pipeline = pipeline.false_branch(false_branch);
392 }
393 pipeline = pipeline.default_action(&self.default_action);
394
395 Ok(pipeline)
396 }
397}
398
399pub struct BranchingPipelineBuilder {
401 branches: Vec<BranchConfig>,
402 combination_strategy: String,
403 default_branch: Option<String>,
404}
405
406impl BranchingPipelineBuilder {
407 #[must_use]
409 pub fn new() -> Self {
410 Self {
411 branches: Vec::new(),
412 combination_strategy: "concatenate".to_string(),
413 default_branch: None,
414 }
415 }
416
417 #[must_use]
419 pub fn branch(mut self, branch: BranchConfig) -> Self {
420 self.branches.push(branch);
421 self
422 }
423
424 #[must_use]
426 pub fn combination_strategy(mut self, strategy: &str) -> Self {
427 self.combination_strategy = strategy.to_string();
428 self
429 }
430
431 #[must_use]
433 pub fn default_branch(mut self, branch_name: &str) -> Self {
434 self.default_branch = Some(branch_name.to_string());
435 self
436 }
437
438 #[must_use]
440 pub fn build(self) -> BranchingPipeline<Untrained> {
441 BranchingPipeline {
442 state: Untrained,
443 branches: self.branches,
444 combination_strategy: self.combination_strategy,
445 default_branch: self.default_branch,
446 }
447 }
448}
449
450impl Default for ConditionalPipelineBuilder {
451 fn default() -> Self {
452 Self::new()
453 }
454}
455
456impl Default for BranchingPipelineBuilder {
457 fn default() -> Self {
458 Self::new()
459 }
460}