1use scirs2_core::ndarray::{Array2, ArrayView2};
7use sklears_core::{prelude::Transform, traits::Estimator, types::Float};
8use std::marker::PhantomData;
9
/// Zero-sized wrapper around `PhantomData<T>`.
///
/// The single field is private and this file defines no constructor for the
/// type, so it can only appear in type positions outside this module.
/// NOTE(review): presumably names the input side of a stage at the type
/// level; nothing in this file refers to it — confirm the intended use.
pub struct Input<T>(PhantomData<T>);
12
/// Zero-sized wrapper around `PhantomData<T>`.
///
/// The single field is private and this file defines no constructor for the
/// type, so it can only appear in type positions outside this module.
/// NOTE(review): presumably names the output side of a stage at the type
/// level; nothing in this file refers to it — confirm the intended use.
pub struct Output<T>(PhantomData<T>);
15
/// Marker type naming the numeric-input category.
///
/// It is a unit struct that carries no data and is used purely as a type
/// parameter, e.g. `TypedPipelineStage<NumericInput, DenseOutput>`.
/// The derives let values and wrappers parameterised by it be printed,
/// copied, compared, hashed and defaulted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct NumericInput;
18
/// Marker type naming the categorical-input category.
///
/// It is a unit struct that carries no data and is used purely as a type
/// parameter, e.g. `TypedPipelineStage<CategoricalInput, DenseOutput>`.
/// The derives let values and wrappers parameterised by it be printed,
/// copied, compared, hashed and defaulted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct CategoricalInput;
21
/// Marker type naming the mixed-input category.
///
/// It is a unit struct that carries no data; it is intended to be used only
/// as a type parameter. No impl in this file mentions it yet.
/// The derives let values and wrappers parameterised by it be printed,
/// copied, compared, hashed and defaulted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MixedInput;
24
/// Marker type naming the dense-output category.
///
/// It is a unit struct that carries no data and is used purely as a type
/// parameter, e.g. `TypedPipelineStage<NumericInput, DenseOutput>`.
/// The derives let values and wrappers parameterised by it be printed,
/// copied, compared, hashed and defaulted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct DenseOutput;
27
/// Marker type naming the sparse-output category.
///
/// It is a unit struct that carries no data and is used purely as a type
/// parameter, e.g. `TypedPipelineStage<NumericInput, SparseOutput>`.
/// The derives let values and wrappers parameterised by it be printed,
/// copied, compared, hashed and defaulted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct SparseOutput;
30
/// Marker type naming the classification-output category.
///
/// It is a unit struct that carries no data and is used purely as a type
/// parameter (the unit tests in this file use it to tag a `TypedEstimator`).
/// The derives let values and wrappers parameterised by it be printed,
/// copied, compared, hashed and defaulted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ClassificationOutput;
33
/// Marker type naming the regression-output category.
///
/// It is a unit struct that carries no data; it is intended to be used only
/// as a type parameter. No impl in this file mentions it yet.
/// The derives let values and wrappers parameterised by it be printed,
/// copied, compared, hashed and defaulted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct RegressionOutput;
36
/// Zero-sized description of a pipeline stage, tagged with an input marker
/// `I` and an output marker `O`.
///
/// Both markers live only in `PhantomData` fields, so a value of this type
/// stores no data at runtime.
pub struct TypedPipelineStage<I, O> {
    _input: PhantomData<I>,
    _output: PhantomData<O>,
}

impl<I, O> TypedPipelineStage<I, O> {
    /// Creates the stage marker for the `I`/`O` pairing.
    ///
    /// This is a `const fn` so the marker can also be used to initialise
    /// `const` and `static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _input: PhantomData,
            _output: PhantomData,
        }
    }
}

/// Implemented by hand so that `I` and `O` need not implement `Default`
/// themselves (a derive would add those bounds).
impl<I, O> Default for TypedPipelineStage<I, O> {
    /// Equivalent to [`TypedPipelineStage::new`].
    fn default() -> Self {
        Self::new()
    }
}
59
/// Compatibility check between the implementing type and the type parameter
/// `T`.
///
/// In this file it is implemented for `TypedPipelineStage` values whose input
/// marker equals `T`, and every one of those implementations returns `true`.
pub trait TypeCompatible<T> {
    /// Returns whether `self` is compatible with `T`.
    fn is_compatible(&self) -> bool;
}
65
/// Declares a numeric-input, dense-output stage compatible with
/// `NumericInput`.
impl TypeCompatible<NumericInput> for TypedPipelineStage<NumericInput, DenseOutput> {
    /// Always returns `true`.
    fn is_compatible(&self) -> bool {
        true
    }
}
72
/// Declares a numeric-input, sparse-output stage compatible with
/// `NumericInput`.
impl TypeCompatible<NumericInput> for TypedPipelineStage<NumericInput, SparseOutput> {
    /// Always returns `true`.
    fn is_compatible(&self) -> bool {
        true
    }
}
78
/// Declares a categorical-input, dense-output stage compatible with
/// `CategoricalInput`.
impl TypeCompatible<CategoricalInput> for TypedPipelineStage<CategoricalInput, DenseOutput> {
    /// Always returns `true`.
    fn is_compatible(&self) -> bool {
        true
    }
}
85
/// Wraps a transformer value `T` and tags it with an input marker `I` and an
/// output marker `O`.
///
/// The markers exist only as `PhantomData`; the transformer is the only data
/// the wrapper stores.
pub struct TypedTransformer<I, O, T> {
    transformer: T,
    _input: PhantomData<I>,
    _output: PhantomData<O>,
}

impl<I, O, T> TypedTransformer<I, O, T> {
    /// Takes ownership of `transformer` and attaches the `I`/`O` markers.
    pub fn new(transformer: T) -> Self {
        TypedTransformer {
            _input: PhantomData::<I>,
            _output: PhantomData::<O>,
            transformer,
        }
    }

    /// Borrows the wrapped transformer.
    pub fn inner(&self) -> &T {
        let Self { transformer, .. } = self;
        transformer
    }

    /// Consumes the wrapper and returns the transformer it was built from.
    pub fn into_inner(self) -> T {
        let Self { transformer, .. } = self;
        transformer
    }
}
113
/// Wraps an estimator value `E` and tags it with an input marker `I` and an
/// output marker `O`.
///
/// The markers exist only as `PhantomData`; the estimator is the only data
/// the wrapper stores.
pub struct TypedEstimator<I, O, E> {
    estimator: E,
    _input: PhantomData<I>,
    _output: PhantomData<O>,
}

impl<I, O, E> TypedEstimator<I, O, E> {
    /// Takes ownership of `estimator` and attaches the `I`/`O` markers.
    pub fn new(estimator: E) -> Self {
        TypedEstimator {
            _input: PhantomData::<I>,
            _output: PhantomData::<O>,
            estimator,
        }
    }

    /// Borrows the wrapped estimator.
    pub fn inner(&self) -> &E {
        let Self { estimator, .. } = self;
        estimator
    }

    /// Consumes the wrapper and returns the estimator it was built from.
    pub fn into_inner(self) -> E {
        let Self { estimator, .. } = self;
        estimator
    }
}
141
/// Validation hook parameterised by a `Stages` type.
///
/// `validate` is an associated function without a receiver, so it is called
/// on the implementing type and not on a value. No implementation appears in
/// this file.
pub trait PipelineValidation<Stages> {
    /// Returns `Ok(())` on success, or a [`PipelineValidationError`]
    /// describing the failure.
    fn validate() -> Result<(), PipelineValidationError>;
}
147
/// Errors reported by pipeline validation.
///
/// Every field is a `usize` or a `String`, so equality is total and the type
/// derives `Eq` alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PipelineValidationError {
    /// A stage received a type other than the one it expected.
    IncompatibleTypes {
        /// Stage number reported in the error message.
        stage_index: usize,
        /// Name of the type that was expected.
        expected: String,
        /// Name of the type that was found instead.
        found: String,
    },
    /// A required stage is absent.
    MissingStage {
        /// Name of the missing stage.
        stage_name: String,
    },
    /// A stage is configured incorrectly.
    InvalidConfiguration {
        /// Stage number reported in the error message.
        stage_index: usize,
        /// Free-form explanation of what is wrong.
        reason: String,
    },
}
162
163impl std::fmt::Display for PipelineValidationError {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 match self {
166 PipelineValidationError::IncompatibleTypes {
167 stage_index,
168 expected,
169 found,
170 } => {
171 write!(
172 f,
173 "Incompatible types at stage {stage_index}: expected {expected}, found {found}"
174 )
175 }
176 PipelineValidationError::MissingStage { stage_name } => {
177 write!(f, "Missing required stage: {stage_name}")
178 }
179 PipelineValidationError::InvalidConfiguration {
180 stage_index,
181 reason,
182 } => {
183 write!(f, "Invalid configuration at stage {stage_index}: {reason}")
184 }
185 }
186 }
187}
188
/// Marks the type as a standard error; the empty body means all of the
/// trait's provided methods keep their default behaviour.
impl std::error::Error for PipelineValidationError {}
190
/// Builder that records stage names at runtime while the type parameter `T`
/// accumulates the stage types.
pub struct TypedPipelineBuilder<T> {
    stages: Vec<String>,
    _phantom: PhantomData<T>,
}

impl TypedPipelineBuilder<()> {
    /// Starts a builder with no stages; the unit type `()` is the initial
    /// value of the type parameter.
    #[must_use]
    pub fn new() -> Self {
        let stages = Vec::new();
        TypedPipelineBuilder {
            stages,
            _phantom: PhantomData,
        }
    }
}
207
208impl<T> TypedPipelineBuilder<T> {
209 pub fn transform<I, O, Trans>(
211 mut self,
212 name: &str,
213 _transformer: TypedTransformer<I, O, Trans>,
214 ) -> TypedPipelineBuilder<(T, TypedTransformer<I, O, Trans>)>
215 where
216 Trans: for<'a> Transform<ArrayView2<'a, Float>, Array2<f64>>,
217 {
218 self.stages.push(name.to_string());
219 TypedPipelineBuilder {
220 stages: self.stages,
221 _phantom: PhantomData,
222 }
223 }
224
225 pub fn estimate<I, O, Est>(
227 mut self,
228 name: &str,
229 _estimator: TypedEstimator<I, O, Est>,
230 ) -> TypedPipelineBuilder<(T, TypedEstimator<I, O, Est>)>
231 where
232 Est: Estimator,
233 {
234 self.stages.push(name.to_string());
235 TypedPipelineBuilder {
236 stages: self.stages,
237 _phantom: PhantomData,
238 }
239 }
240
241 #[must_use]
243 pub fn stage_names(&self) -> &[String] {
244 &self.stages
245 }
246}
247
/// `Default` is provided only for the empty builder, `TypedPipelineBuilder<()>`.
impl Default for TypedPipelineBuilder<()> {
    /// Equivalent to `TypedPipelineBuilder::new()`.
    fn default() -> Self {
        Self::new()
    }
}
253
/// Zero-sized validator tagged with a marker type `T`.
///
/// The marker lives only in a `PhantomData` field; which validations are
/// available depends on the `DataFlowValidation` impls that exist for the
/// chosen `T`.
pub struct DataFlowValidator<T> {
    _phantom: PhantomData<T>,
}

impl<T> DataFlowValidator<T> {
    /// Creates the validator for marker `T`.
    ///
    /// This is a `const fn` so the validator can also be used to initialise
    /// `const` and `static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

/// Implemented by hand so that `T` need not implement `Default` itself
/// (a derive would add that bound).
impl<T> Default for DataFlowValidator<T> {
    /// Equivalent to [`DataFlowValidator::new`].
    fn default() -> Self {
        Self::new()
    }
}
274
/// Validation performed on a validator value.
///
/// In this file it is implemented for `DataFlowValidator<NumericInput>` and
/// `DataFlowValidator<CategoricalInput>`.
pub trait DataFlowValidation {
    /// Returns `Ok(())` on success, or a [`PipelineValidationError`]
    /// describing the failure.
    fn validate_flow(&self) -> Result<(), PipelineValidationError>;
}
280
/// Flow validation for the `NumericInput` marker.
impl DataFlowValidation for DataFlowValidator<NumericInput> {
    /// Always returns `Ok(())`; no check is performed.
    fn validate_flow(&self) -> Result<(), PipelineValidationError> {
        Ok(())
    }
}
287
/// Flow validation for the `CategoricalInput` marker.
impl DataFlowValidation for DataFlowValidator<CategoricalInput> {
    /// Always returns `Ok(())`; no check is performed.
    fn validate_flow(&self) -> Result<(), PipelineValidationError> {
        Ok(())
    }
}
294
/// Collection of transformer names that share an input marker `I` and an
/// output marker `O`.
///
/// Only the names are stored; the markers live in `PhantomData` fields.
pub struct TypedFeatureUnion<I, O> {
    // Names pushed by `add_transformer`, in insertion order.
    transformers: Vec<String>,
    _input: PhantomData<I>,
    _output: PhantomData<O>,
}
301
302impl<I, O> TypedFeatureUnion<I, O> {
303 #[must_use]
305 pub fn new() -> Self {
306 Self {
307 transformers: Vec::new(),
308 _input: PhantomData,
309 _output: PhantomData,
310 }
311 }
312
313 pub fn add_transformer<T>(mut self, name: &str, _transformer: TypedTransformer<I, O, T>) -> Self
315 where
316 T: for<'a> Transform<ArrayView2<'a, Float>, Array2<f64>>,
317 {
318 self.transformers.push(name.to_string());
319 self
320 }
321
322 #[must_use]
324 pub fn transformer_names(&self) -> &[String] {
325 &self.transformers
326 }
327}
328
/// Implemented by hand, so `I` and `O` need not implement `Default`.
impl<I, O> Default for TypedFeatureUnion<I, O> {
    /// Equivalent to `TypedFeatureUnion::new()`.
    fn default() -> Self {
        Self::new()
    }
}
334
/// Structural validation hook.
///
/// `validate_structure` is an associated function without a receiver, so it
/// is called on the implementing type and not on a value. No implementation
/// appears in this file.
pub trait StructureValidation {
    /// Returns `Ok(())` on success, or a [`PipelineValidationError`]
    /// describing the failure.
    fn validate_structure() -> Result<(), PipelineValidationError>;
}
340
/// Builds a pipeline by passing each stage expression, in order, to
/// `add_stage` on a builder that starts as `TypedPipelineBuilder::new()`.
///
/// Accepts one or more comma-separated expressions with an optional trailing
/// comma, and evaluates to the final builder.
///
/// Each intermediate builder is bound with a fresh `let` (shadowing) instead
/// of being assigned back to one mutable variable, so `add_stage` may return
/// a differently typed builder for every stage.
///
/// NOTE(review): `TypedPipelineBuilder` is resolved at the call site, and no
/// `add_stage` method is defined on it in this file — the expansion only
/// compiles where such a method exists; confirm which method was intended.
#[macro_export]
macro_rules! typed_pipeline {
    ($($stage:expr),+ $(,)?) => {{
        let builder = TypedPipelineBuilder::new();
        $(
            // Shadow rather than assign: the builder's type may change per stage.
            let builder = builder.add_stage($stage);
        )+
        builder
    }};
}
352
/// Forwards its argument to `compile_time_validate!` and evaluates to
/// whatever that macro produces.
///
/// The inner macro is named through `$crate` so the expansion still resolves
/// when this exported macro is invoked by path from another crate, where
/// `compile_time_validate!` is not otherwise in scope.
#[macro_export]
macro_rules! validate_pipeline {
    ($pipeline:expr) => {{
        $crate::compile_time_validate!($pipeline)
    }};
}
360
/// Expands to the expression `Ok(())`.
///
/// The `$pipeline` argument does not appear in the expansion, so the
/// expression passed in is parsed but never evaluated and no check is
/// performed on it.
///
/// NOTE(review): the expansion names no error type, so the caller's context
/// must determine it (for example an annotated binding or a function return
/// type); a bare `let r = compile_time_validate!(p);` cannot be inferred.
#[macro_export]
macro_rules! compile_time_validate {
    ($pipeline:expr) => {{
        Ok(())
    }};
}
369
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// A numeric-input, dense-output stage reports itself as compatible.
    #[test]
    fn test_typed_pipeline_stage_creation() {
        let numeric_dense = TypedPipelineStage::<NumericInput, DenseOutput>::new();
        assert!(numeric_dense.is_compatible());
    }

    /// The wrapper hands back a reference to the transformer it was given.
    #[test]
    fn test_typed_transformer_creation() {
        #[derive(Debug, PartialEq)]
        struct DummyTransformer(i32);

        let wrapped: TypedTransformer<NumericInput, DenseOutput, DummyTransformer> =
            TypedTransformer::new(DummyTransformer(42));

        let DummyTransformer(value) = wrapped.inner();
        assert_eq!(*value, 42);
    }

    /// The wrapper hands back a reference to the estimator it was given.
    #[test]
    fn test_typed_estimator_creation() {
        #[derive(Debug, PartialEq)]
        struct DummyEstimator(String);

        let wrapped: TypedEstimator<NumericInput, ClassificationOutput, DummyEstimator> =
            TypedEstimator::new(DummyEstimator(String::from("test")));

        let DummyEstimator(label) = wrapped.inner();
        assert_eq!(label, "test");
    }

    /// A freshly created builder has no stage names.
    #[test]
    fn test_typed_pipeline_builder() {
        let empty = TypedPipelineBuilder::new();
        let names = empty.stage_names();
        assert_eq!(names.len(), 0);
    }

    /// Flow validation succeeds for the numeric-input validator.
    #[test]
    fn test_data_flow_validator() {
        let numeric = DataFlowValidator::<NumericInput>::new();
        let outcome = numeric.validate_flow();
        assert!(outcome.is_ok());
    }

    /// A freshly created feature union has no transformer names.
    #[test]
    fn test_typed_feature_union() {
        let empty = TypedFeatureUnion::<NumericInput, DenseOutput>::new();
        let names = empty.transformer_names();
        assert_eq!(names.len(), 0);
    }

    /// The rendered message names the error kind and the stage number.
    #[test]
    fn test_pipeline_validation_error_display() {
        let mismatch = PipelineValidationError::IncompatibleTypes {
            stage_index: 1,
            expected: String::from("NumericInput"),
            found: String::from("CategoricalInput"),
        };
        let rendered = mismatch.to_string();
        assert!(rendered.contains("Incompatible types"));
        assert!(rendered.contains("stage 1"));
    }
}