1use scirs2_core::ndarray::Array2;
19use sklears_core::error::SklearsError;
20use std::any::Any;
21use std::sync::Arc;
22use std::time::Instant;
23
24pub trait Hook: Send + Sync {
26 fn before_fit(
28 &mut self,
29 x: &Array2<f64>,
30 context: &mut HookContext,
31 ) -> Result<(), SklearsError> {
32 let _ = (x, context);
33 Ok(())
34 }
35
36 fn after_fit(
38 &mut self,
39 x: &Array2<f64>,
40 context: &mut HookContext,
41 ) -> Result<(), SklearsError> {
42 let _ = (x, context);
43 Ok(())
44 }
45
46 fn before_transform(
48 &mut self,
49 x: &Array2<f64>,
50 context: &mut HookContext,
51 ) -> Result<(), SklearsError> {
52 let _ = (x, context);
53 Ok(())
54 }
55
56 fn after_transform(
58 &mut self,
59 x: &Array2<f64>,
60 output: &Array2<f64>,
61 context: &mut HookContext,
62 ) -> Result<(), SklearsError> {
63 let _ = (x, output, context);
64 Ok(())
65 }
66
67 fn on_error(&mut self, error: &SklearsError, context: &mut HookContext) {
69 let _ = (error, context);
70 }
71
72 fn name(&self) -> &str {
74 "Hook"
75 }
76}
77
/// Information handed to each [`Hook`] callback about the stage being observed.
///
/// Hooks may read the public fields and attach extra key/value annotations
/// through [`HookContext::add_metadata`].
#[derive(Debug, Clone, Default)]
pub struct HookContext {
    /// Name of the pipeline phase, e.g. `"fit"` or `"transform"`.
    pub stage: String,
    /// Zero-based position of the current stage.
    pub transform_index: usize,
    /// Name reported by the current stage.
    pub transform_name: String,
    /// Elapsed time in milliseconds as filled in by the caller (0.0 on construction).
    pub elapsed_ms: f64,
    /// Free-form annotations keyed by name.
    pub metadata: std::collections::HashMap<String, String>,
}

impl HookContext {
    /// Builds a context for the given phase and stage, with zero elapsed time
    /// and no metadata.
    pub fn new(stage: &str, transform_index: usize, transform_name: &str) -> Self {
        Self {
            stage: String::from(stage),
            transform_index,
            transform_name: String::from(transform_name),
            // Remaining fields: elapsed_ms = 0.0, metadata = empty map.
            ..Self::default()
        }
    }

    /// Stores `value` under `key`, replacing any earlier value for that key.
    pub fn add_metadata(&mut self, key: String, value: String) {
        let _previous = self.metadata.insert(key, value);
    }
}
110
111pub trait Middleware: Send + Sync {
113 fn process_before_fit(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
115 Ok(x.clone())
116 }
117
118 fn process_after_fit(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
120 Ok(x.clone())
121 }
122
123 fn process_before_transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
125 Ok(x.clone())
126 }
127
128 fn process_after_transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
130 Ok(x.clone())
131 }
132
133 fn name(&self) -> &str {
135 "Middleware"
136 }
137}
138
139pub trait PipelineStage: Send + Sync {
141 fn fit(&mut self, x: &Array2<f64>) -> Result<(), SklearsError>;
143
144 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError>;
146
147 fn is_fitted(&self) -> bool;
149
150 fn name(&self) -> &str;
152
153 fn clone_stage(&self) -> Box<dyn PipelineStage>;
155
156 fn as_any(&self) -> &dyn Any;
158}
159
/// A [`Hook`] that keeps an in-memory, human-readable log of the events it sees.
#[derive(Default)]
pub struct LoggingHook {
    // One formatted line per observed event, oldest first.
    logs: Vec<String>,
}

impl LoggingHook {
    /// Creates a hook with an empty log.
    pub fn new() -> Self {
        Self::default()
    }

    /// All log lines recorded so far, in the order they were written.
    pub fn logs(&self) -> &[String] {
        self.logs.as_slice()
    }
}
182
183impl Hook for LoggingHook {
184 fn before_fit(
185 &mut self,
186 x: &Array2<f64>,
187 context: &mut HookContext,
188 ) -> Result<(), SklearsError> {
189 let log = format!(
190 "[{}] Before fit - transform: {}, shape: {:?}",
191 context.stage,
192 context.transform_name,
193 x.dim()
194 );
195 self.logs.push(log);
196 Ok(())
197 }
198
199 fn after_fit(
200 &mut self,
201 _x: &Array2<f64>,
202 context: &mut HookContext,
203 ) -> Result<(), SklearsError> {
204 let log = format!(
205 "[{}] After fit - transform: {}, time: {:.2}ms",
206 context.stage, context.transform_name, context.elapsed_ms
207 );
208 self.logs.push(log);
209 Ok(())
210 }
211
212 fn before_transform(
213 &mut self,
214 x: &Array2<f64>,
215 context: &mut HookContext,
216 ) -> Result<(), SklearsError> {
217 let log = format!(
218 "[{}] Before transform - transform: {}, shape: {:?}",
219 context.stage,
220 context.transform_name,
221 x.dim()
222 );
223 self.logs.push(log);
224 Ok(())
225 }
226
227 fn after_transform(
228 &mut self,
229 _x: &Array2<f64>,
230 output: &Array2<f64>,
231 context: &mut HookContext,
232 ) -> Result<(), SklearsError> {
233 let log = format!(
234 "[{}] After transform - transform: {}, output shape: {:?}, time: {:.2}ms",
235 context.stage,
236 context.transform_name,
237 output.dim(),
238 context.elapsed_ms
239 );
240 self.logs.push(log);
241 Ok(())
242 }
243
244 fn name(&self) -> &str {
245 "LoggingHook"
246 }
247}
248
249pub struct NormalizationMiddleware {
251 mean: Option<Array2<f64>>,
252 std: Option<Array2<f64>>,
253}
254
255impl NormalizationMiddleware {
256 pub fn new() -> Self {
258 Self {
259 mean: None,
260 std: None,
261 }
262 }
263}
264
265impl Default for NormalizationMiddleware {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
271impl Middleware for NormalizationMiddleware {
272 fn process_before_fit(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
273 let mean = x
275 .mean_axis(scirs2_core::ndarray::Axis(0))
276 .ok_or_else(|| SklearsError::InvalidInput("Cannot compute mean".to_string()))?;
277 let std = x.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
278
279 let mut normalized = x.clone();
281 for (i, mut col) in normalized
282 .axis_iter_mut(scirs2_core::ndarray::Axis(1))
283 .enumerate()
284 {
285 let std_val = std[i].max(1e-8); for elem in col.iter_mut() {
287 *elem = (*elem - mean[i]) / std_val;
288 }
289 }
290
291 Ok(normalized)
292 }
293
294 fn process_before_transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
295 self.process_before_fit(x)
296 }
297
298 fn name(&self) -> &str {
299 "NormalizationMiddleware"
300 }
301}
302
303pub struct Pipeline {
305 stages: Vec<Box<dyn PipelineStage>>,
306 hooks: Vec<Box<dyn Hook>>,
307 middleware: Vec<Arc<dyn Middleware>>,
308 name: String,
309 is_fitted: bool,
310}
311
312impl Pipeline {
313 pub fn new(name: String) -> Self {
315 Self {
316 stages: Vec::new(),
317 hooks: Vec::new(),
318 middleware: Vec::new(),
319 name,
320 is_fitted: false,
321 }
322 }
323
324 pub fn add_stage(&mut self, stage: Box<dyn PipelineStage>) {
326 self.stages.push(stage);
327 }
328
329 pub fn add_hook(&mut self, hook: Box<dyn Hook>) {
331 self.hooks.push(hook);
332 }
333
334 pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
336 self.middleware.push(middleware);
337 }
338
339 pub fn fit(&mut self, x: &Array2<f64>) -> Result<(), SklearsError> {
341 let mut current_data = x.clone();
342
343 for mw in &self.middleware {
345 current_data = mw.process_before_fit(¤t_data)?;
346 }
347
348 for (idx, stage) in self.stages.iter_mut().enumerate() {
350 let start = Instant::now();
351 let mut context = HookContext::new("fit", idx, stage.name());
352
353 for hook in &mut self.hooks {
355 hook.before_fit(¤t_data, &mut context)?;
356 }
357
358 match stage.fit(¤t_data) {
360 Ok(_) => {
361 context.elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
362
363 for hook in &mut self.hooks {
365 hook.after_fit(¤t_data, &mut context)?;
366 }
367
368 current_data = stage.transform(¤t_data)?;
370 }
371 Err(e) => {
372 for hook in &mut self.hooks {
373 hook.on_error(&e, &mut context);
374 }
375 return Err(e);
376 }
377 }
378 }
379
380 for mw in &self.middleware {
382 current_data = mw.process_after_fit(¤t_data)?;
383 }
384
385 self.is_fitted = true;
386 Ok(())
387 }
388
389 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
391 if !self.is_fitted {
392 return Err(SklearsError::NotFitted {
393 operation: "Pipeline must be fitted before transform".to_string(),
394 });
395 }
396
397 let mut current_data = x.clone();
398
399 for mw in &self.middleware {
401 current_data = mw.process_before_transform(¤t_data)?;
402 }
403
404 for stage in self.stages.iter() {
406 match stage.transform(¤t_data) {
412 Ok(output) => {
413 current_data = output;
414 }
415 Err(e) => {
416 return Err(e);
417 }
418 }
419 }
420
421 for mw in &self.middleware {
423 current_data = mw.process_after_transform(¤t_data)?;
424 }
425
426 Ok(current_data)
427 }
428
429 pub fn name(&self) -> &str {
431 &self.name
432 }
433
434 pub fn is_fitted(&self) -> bool {
436 self.is_fitted
437 }
438
439 pub fn len(&self) -> usize {
441 self.stages.len()
442 }
443
444 pub fn is_empty(&self) -> bool {
446 self.stages.is_empty()
447 }
448}
449
450pub struct PipelineBuilder {
452 pipeline: Pipeline,
453}
454
455impl PipelineBuilder {
456 pub fn new(name: &str) -> Self {
458 Self {
459 pipeline: Pipeline::new(name.to_string()),
460 }
461 }
462
463 pub fn add_stage(mut self, stage: Box<dyn PipelineStage>) -> Self {
465 self.pipeline.add_stage(stage);
466 self
467 }
468
469 pub fn add_hook(mut self, hook: Box<dyn Hook>) -> Self {
471 self.pipeline.add_hook(hook);
472 self
473 }
474
475 pub fn add_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
477 self.pipeline.add_middleware(middleware);
478 self
479 }
480
481 pub fn build(self) -> Pipeline {
483 self.pipeline
484 }
485}
486
487pub struct ValidationHook;
489
490impl Hook for ValidationHook {
491 fn after_transform(
492 &mut self,
493 _x: &Array2<f64>,
494 output: &Array2<f64>,
495 _context: &mut HookContext,
496 ) -> Result<(), SklearsError> {
497 for &val in output.iter() {
498 if val.is_nan() || val.is_infinite() {
499 return Err(SklearsError::InvalidInput(
500 "Output contains NaN or Inf values".to_string(),
501 ));
502 }
503 }
504 Ok(())
505 }
506
507 fn name(&self) -> &str {
508 "ValidationHook"
509 }
510}
511
/// A [`Hook`] that collects `(transform name, elapsed milliseconds)` pairs.
pub struct PerformanceHook {
    // Recorded timings, oldest first.
    timings: Vec<(String, f64)>,
}

impl PerformanceHook {
    /// Creates a hook with no recorded timings.
    pub fn new() -> Self {
        Self { timings: vec![] }
    }

    /// All recorded `(name, milliseconds)` pairs, in recording order.
    pub fn timings(&self) -> &[(String, f64)] {
        self.timings.as_slice()
    }

    /// Sum of every recorded duration, in milliseconds.
    pub fn total_time(&self) -> f64 {
        self.timings.iter().map(|&(_, ms)| ms).sum()
    }
}

impl Default for PerformanceHook {
    fn default() -> Self {
        PerformanceHook::new()
    }
}
541
542impl Hook for PerformanceHook {
543 fn after_transform(
544 &mut self,
545 _x: &Array2<f64>,
546 _output: &Array2<f64>,
547 context: &mut HookContext,
548 ) -> Result<(), SklearsError> {
549 self.timings
550 .push((context.transform_name.clone(), context.elapsed_ms));
551 Ok(())
552 }
553
554 fn name(&self) -> &str {
555 "PerformanceHook"
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Test double: `fit` only records that it ran, and `transform` doubles
    /// every element (failing with `NotFitted` before `fit`).
    struct DummyStage {
        name: String,
        fitted: bool,
    }

    impl DummyStage {
        fn new(name: &str) -> Self {
            DummyStage {
                name: name.to_owned(),
                fitted: false,
            }
        }
    }

    impl PipelineStage for DummyStage {
        fn fit(&mut self, _x: &Array2<f64>) -> Result<(), SklearsError> {
            self.fitted = true;
            Ok(())
        }

        fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
            if self.fitted {
                Ok(x.mapv(|v| v * 2.0))
            } else {
                Err(SklearsError::NotFitted {
                    operation: "Stage not fitted".to_string(),
                })
            }
        }

        fn is_fitted(&self) -> bool {
            self.fitted
        }

        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn clone_stage(&self) -> Box<dyn PipelineStage> {
            let copy = DummyStage {
                name: self.name.clone(),
                fitted: self.fitted,
            };
            Box::new(copy)
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    #[test]
    fn test_pipeline_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let mut pipeline = Pipeline::new(String::from("test"));
        pipeline.add_stage(Box::new(DummyStage::new("stage1")));

        pipeline.fit(&x).unwrap();
        let doubled = pipeline.transform(&x).unwrap();

        assert_eq!(doubled[[0, 0]], 2.0);
        assert_eq!(doubled[[1, 1]], 8.0);
    }

    #[test]
    fn test_pipeline_builder() {
        let pipeline = PipelineBuilder::new("test")
            .add_stage(Box::new(DummyStage::new("stage1")))
            .add_hook(Box::new(LoggingHook::new()))
            .build();

        assert!(!pipeline.is_empty());
        assert_eq!(pipeline.len(), 1);
    }

    #[test]
    fn test_logging_hook() {
        let x = array![[1.0, 2.0]];
        let mut context = HookContext::new("fit", 0, "test_stage");
        let mut hook = LoggingHook::new();

        hook.before_fit(&x, &mut context).unwrap();

        assert_eq!(hook.logs().len(), 1);
    }

    #[test]
    fn test_validation_hook() {
        let x = array![[1.0, 2.0]];
        let mut context = HookContext::new("transform", 0, "test");
        let mut hook = ValidationHook;

        let finite = array![[1.0, 2.0]];
        assert!(hook.after_transform(&x, &finite, &mut context).is_ok());

        let with_nan = array![[f64::NAN, 2.0]];
        assert!(hook.after_transform(&x, &with_nan, &mut context).is_err());
    }

    #[test]
    fn test_performance_hook() {
        let x = array![[1.0, 2.0]];
        let output = array![[2.0, 4.0]];
        let mut context = HookContext::new("transform", 0, "test");
        context.elapsed_ms = 10.0;
        let mut hook = PerformanceHook::new();

        hook.after_transform(&x, &output, &mut context).unwrap();

        assert_eq!(hook.timings().len(), 1);
        assert_eq!(hook.total_time(), 10.0);
    }

    #[test]
    fn test_hook_context() {
        let mut context = HookContext::new("fit", 0, "test_stage");
        context.add_metadata("key".to_string(), "value".to_string());

        assert_eq!(context.stage, "fit");
        assert_eq!(context.transform_index, 0);
        assert!(context.metadata.contains_key("key"));
    }
}