1use std::any::Any;
8use std::fmt::Debug;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11
12use scirs2_core::ndarray::{Array1, Array2};
16use sklears_core::{
17 error::Result as SklResult,
18 prelude::{Fit, Predict, SklearsError, Transform},
19 traits::Estimator,
20 types::{Float, FloatBounds},
21};
22use std::collections::HashMap;
23
/// Lifecycle points at which hooks can be registered and run.
///
/// The phase is the key of [`HookManager`]'s hook table (hence `Eq + Hash`)
/// and is passed to [`ExecutionHook::should_execute`] so a hook can opt in or
/// out of a given phase. The variant descriptions below follow the variant
/// names; the exact moment each phase fires is decided by the caller of
/// `HookManager::execute_hooks`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookPhase {
    /// Before an execution starts (also the initial phase of a new context).
    BeforeExecution,
    /// Before an individual step.
    BeforeStep,
    /// After an individual step.
    AfterStep,
    /// After an execution finishes.
    AfterExecution,
    /// When an error occurred; `ErrorRecoveryHook` only reacts to this phase.
    OnError,
    /// Before a fit operation.
    BeforeFit,
    /// After a fit operation.
    AfterFit,
    /// Before a predict operation.
    BeforePredict,
    /// After a predict operation.
    AfterPredict,
    /// Before a transform operation.
    BeforeTransform,
    /// After a transform operation.
    AfterTransform,
}
50
/// State describing one execution, handed (read-only) to every hook.
#[derive(Debug, Clone)]
pub struct ExecutionContext {
    /// Caller-chosen identifier of the execution; copied into audit and error records.
    pub execution_id: String,
    /// Name of the current step, if any. `SecurityAuditHook` matches this
    /// against its list of sensitive operations.
    pub step_name: Option<String>,
    /// Index of the current step, if any.
    pub step_index: Option<usize>,
    /// Total number of steps in the execution.
    pub total_steps: usize,
    /// Start instant of the execution; hooks derive elapsed time via `start_time.elapsed()`.
    pub start_time: Instant,
    /// Phase currently being run; overwritten by `HookManager::execute_hooks`.
    pub phase: HookPhase,
    /// Free-form string metadata. Hooks in this file read the keys
    /// `"user_id"`, `"authorized"` and `"error"`.
    pub metadata: HashMap<String, String>,
    /// Per-execution metrics; `data_shapes` feeds the memory estimates.
    pub metrics: PerformanceMetrics,
}
71
/// Performance counters collected for an execution (or globally by `HookManager`).
#[derive(Debug, Clone, Default)]
pub struct PerformanceMetrics {
    /// Total duration of the execution.
    pub total_duration: Duration,
    /// Duration keyed by a string (presumably the step name — confirm with the code that fills it).
    pub step_durations: HashMap<String, Duration>,
    /// Memory usage counters.
    pub memory_usage: MemoryUsage,
    /// `(rows, cols)` pairs; hooks estimate memory as `rows * cols * size_of::<Float>()`.
    pub data_shapes: Vec<(usize, usize)>,
    /// Number of errors recorded.
    pub error_count: usize,
}
86
/// Memory counters carried inside [`PerformanceMetrics`].
///
/// Nothing in this file writes these fields; units are presumably bytes — confirm
/// with the code that populates them.
#[derive(Debug, Clone, Default)]
pub struct MemoryUsage {
    /// Highest memory usage observed.
    pub peak_memory: usize,
    /// Current memory usage.
    pub current_memory: usize,
    /// Number of allocations.
    pub allocations: usize,
}
97
/// Outcome of running a hook.
///
/// `HookManager::execute_hooks` keeps iterating only while hooks return
/// [`HookResult::Continue`]; any other variant is returned to the caller
/// immediately and the remaining hooks of that phase are not run.
#[derive(Debug, Clone)]
pub enum HookResult {
    /// Proceed normally.
    Continue,
    /// Request that the current step be skipped.
    Skip,
    /// Request that execution be aborted, with a human-readable reason.
    Abort(String),
    /// Proceed, but with the attached replacement data.
    ContinueWithData(HookData),
}
110
/// Data payload that can be passed to, or returned from, a hook.
#[derive(Debug, Clone)]
pub enum HookData {
    /// Two-dimensional feature matrix.
    Features(Array2<Float>),
    /// One-dimensional target values.
    Targets(Array1<Float>),
    /// One-dimensional prediction values.
    Predictions(Array1<Float>),
    /// Arbitrary shared payload; built-in hooks treat it as opaque.
    Custom(Arc<dyn Any + Send + Sync>),
}
123
/// Behaviour every hook must provide to be run by [`HookManager`] or
/// nested inside a `HookComposition`.
pub trait ExecutionHook: Send + Sync + Debug {
    /// Runs the hook.
    ///
    /// `context` describes the current execution and `data` is the optional
    /// payload for the phase. An `Err` is propagated by the callers in this
    /// file via `?`; a non-`Continue` `Ok` value stops the remaining hooks.
    fn execute(
        &mut self,
        context: &ExecutionContext,
        data: Option<&HookData>,
    ) -> SklResult<HookResult>;

    /// Name of the hook, used in log and abort messages.
    fn name(&self) -> &str;

    /// Ordering key: hooks are sorted in descending priority, so larger
    /// values run first. Defaults to `0`.
    fn priority(&self) -> i32 {
        0
    }

    /// Whether this hook wants to run for `phase`.
    fn should_execute(&self, phase: HookPhase) -> bool;
}
144
/// Registry that stores hooks per phase, runs them, and tracks a stack of
/// execution contexts plus a shared set of global metrics.
#[derive(Debug)]
pub struct HookManager {
    // Hooks per phase, kept sorted by descending priority.
    hooks: HashMap<HookPhase, Vec<Box<dyn ExecutionHook>>>,
    // Stack of contexts managed through `push_context` / `pop_context`.
    execution_stack: Vec<ExecutionContext>,
    // Metrics shared behind a mutex so they can be updated through `&self`.
    global_metrics: Arc<Mutex<PerformanceMetrics>>,
}
152
153impl HookManager {
154 #[must_use]
156 pub fn new() -> Self {
157 Self {
158 hooks: HashMap::new(),
159 execution_stack: Vec::new(),
160 global_metrics: Arc::new(Mutex::new(PerformanceMetrics::default())),
161 }
162 }
163
164 pub fn register_hook(&mut self, hook: Box<dyn ExecutionHook>, phases: Vec<HookPhase>) {
166 if let Some(&first_phase) = phases.first() {
169 self.hooks.entry(first_phase).or_default().push(hook);
170
171 if let Some(hooks) = self.hooks.get_mut(&first_phase) {
173 hooks.sort_by(|a, b| b.priority().cmp(&a.priority()));
174 }
175 }
176 }
177
178 pub fn execute_hooks(
180 &mut self,
181 phase: HookPhase,
182 context: &mut ExecutionContext,
183 data: Option<&HookData>,
184 ) -> SklResult<HookResult> {
185 context.phase = phase;
186
187 if let Some(hooks) = self.hooks.get_mut(&phase) {
188 for hook in hooks {
189 if hook.should_execute(phase) {
190 match hook.execute(context, data)? {
191 HookResult::Continue => {}
192 HookResult::Skip => return Ok(HookResult::Skip),
193 HookResult::Abort(msg) => return Ok(HookResult::Abort(msg)),
194 HookResult::ContinueWithData(modified_data) => {
195 return Ok(HookResult::ContinueWithData(modified_data));
196 }
197 }
198 }
199 }
200 }
201
202 Ok(HookResult::Continue)
203 }
204
205 #[must_use]
207 pub fn create_context(&self, execution_id: String, total_steps: usize) -> ExecutionContext {
208 ExecutionContext {
210 execution_id,
211 step_name: None,
212 step_index: None,
213 total_steps,
214 start_time: Instant::now(),
215 phase: HookPhase::BeforeExecution,
216 metadata: HashMap::new(),
217 metrics: PerformanceMetrics::default(),
218 }
219 }
220
221 pub fn push_context(&mut self, context: ExecutionContext) {
223 self.execution_stack.push(context);
224 }
225
226 pub fn pop_context(&mut self) -> Option<ExecutionContext> {
228 self.execution_stack.pop()
229 }
230
231 #[must_use]
233 pub fn current_context(&self) -> Option<&ExecutionContext> {
234 self.execution_stack.last()
235 }
236
237 pub fn current_context_mut(&mut self) -> Option<&mut ExecutionContext> {
239 self.execution_stack.last_mut()
240 }
241
242 pub fn update_global_metrics<F>(&self, updater: F)
244 where
245 F: FnOnce(&mut PerformanceMetrics),
246 {
247 if let Ok(mut metrics) = self.global_metrics.lock() {
248 updater(&mut metrics);
249 }
250 }
251
252 #[must_use]
254 pub fn global_metrics(&self) -> PerformanceMetrics {
255 self.global_metrics
256 .lock()
257 .unwrap_or_else(|e| e.into_inner())
258 .clone()
259 }
260}
261
/// `HookManager::default()` is the same as [`HookManager::new`].
impl Default for HookManager {
    fn default() -> Self {
        Self::new()
    }
}
267
/// Hook that prints one line per invocation describing the execution state.
#[derive(Debug, Clone)]
pub struct LoggingHook {
    // Name printed in brackets at the start of each line.
    name: String,
    // Selects the textual prefix (`DEBUG`, `INFO`, `WARN`, `ERROR`).
    log_level: LogLevel,
    // When true, the dimensions of the supplied data are appended.
    include_data_shapes: bool,
    // When true, the time elapsed since the context's start is appended.
    include_timing: bool,
}
276
/// Severity label used by `LoggingHook` to prefix its output lines.
///
/// The enum has no payload, so it derives `Eq` alongside `PartialEq`; this
/// lets it be used wherever full equality is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
    /// Printed with the `DEBUG` prefix.
    Debug,
    /// Printed with the `INFO` prefix.
    Info,
    /// Printed with the `WARN` prefix.
    Warn,
    /// Printed with the `ERROR` prefix.
    Error,
}
288
289impl LoggingHook {
290 #[must_use]
292 pub fn new(name: String, log_level: LogLevel) -> Self {
293 Self {
294 name,
295 log_level,
296 include_data_shapes: true,
297 include_timing: true,
298 }
299 }
300
301 #[must_use]
303 pub fn include_data_shapes(mut self, include: bool) -> Self {
304 self.include_data_shapes = include;
305 self
306 }
307
308 #[must_use]
310 pub fn include_timing(mut self, include: bool) -> Self {
311 self.include_timing = include;
312 self
313 }
314}
315
316impl ExecutionHook for LoggingHook {
317 fn execute(
318 &mut self,
319 context: &ExecutionContext,
320 data: Option<&HookData>,
321 ) -> SklResult<HookResult> {
322 let mut log_message = format!(
323 "[{}] Phase: {:?}, Execution: {}",
324 self.name, context.phase, context.execution_id
325 );
326
327 if let Some(step_name) = &context.step_name {
328 log_message.push_str(&format!(", Step: {step_name}"));
329 }
330
331 if self.include_timing {
332 let elapsed = context.start_time.elapsed();
333 log_message.push_str(&format!(", Elapsed: {elapsed:?}"));
334 }
335
336 if self.include_data_shapes {
337 if let Some(data) = data {
338 match data {
339 HookData::Features(array) => {
340 log_message.push_str(&format!(
341 ", Features: {}x{}",
342 array.nrows(),
343 array.ncols()
344 ));
345 }
346 HookData::Targets(array) => {
347 log_message.push_str(&format!(", Targets: {}", array.len()));
348 }
349 HookData::Predictions(array) => {
350 log_message.push_str(&format!(", Predictions: {}", array.len()));
351 }
352 HookData::Custom(_) => {
353 log_message.push_str(", Data: Custom");
354 }
355 }
356 }
357 }
358
359 match self.log_level {
360 LogLevel::Debug => println!("DEBUG: {log_message}"),
361 LogLevel::Info => println!("INFO: {log_message}"),
362 LogLevel::Warn => println!("WARN: {log_message}"),
363 LogLevel::Error => println!("ERROR: {log_message}"),
364 }
365
366 Ok(HookResult::Continue)
367 }
368
369 fn name(&self) -> &str {
370 &self.name
371 }
372
373 fn should_execute(&self, _phase: HookPhase) -> bool {
374 true
375 }
376}
377
/// Hook that prints timing alerts and a rough memory estimate.
#[derive(Debug, Clone)]
pub struct PerformanceHook {
    // Name printed in brackets in every message.
    name: String,
    // When true, an estimated memory figure is printed on each run.
    track_memory: bool,
    // When true, elapsed time is compared against `alert_threshold`.
    track_timing: bool,
    // Elapsed time above which an alert is printed; `None` disables alerts.
    alert_threshold: Option<Duration>,
}
386
387impl PerformanceHook {
388 #[must_use]
390 pub fn new(name: String) -> Self {
391 Self {
392 name,
393 track_memory: true,
394 track_timing: true,
395 alert_threshold: None,
396 }
397 }
398
399 #[must_use]
401 pub fn track_memory(mut self, track: bool) -> Self {
402 self.track_memory = track;
403 self
404 }
405
406 #[must_use]
408 pub fn track_timing(mut self, track: bool) -> Self {
409 self.track_timing = track;
410 self
411 }
412
413 #[must_use]
415 pub fn alert_threshold(mut self, threshold: Duration) -> Self {
416 self.alert_threshold = Some(threshold);
417 self
418 }
419}
420
421impl ExecutionHook for PerformanceHook {
422 fn execute(
423 &mut self,
424 context: &ExecutionContext,
425 _data: Option<&HookData>,
426 ) -> SklResult<HookResult> {
427 if self.track_timing {
428 let elapsed = context.start_time.elapsed();
429
430 if let Some(threshold) = self.alert_threshold {
431 if elapsed > threshold {
432 println!(
433 "PERFORMANCE ALERT [{}]: Slow operation detected - {:?} (threshold: {:?})",
434 self.name, elapsed, threshold
435 );
436 }
437 }
438 }
439
440 if self.track_memory {
441 let estimated_memory = context
443 .metrics
444 .data_shapes
445 .iter()
446 .map(|(rows, cols)| rows * cols * std::mem::size_of::<Float>())
447 .sum::<usize>();
448
449 println!(
450 "MEMORY [{}]: Estimated usage: {} bytes",
451 self.name, estimated_memory
452 );
453 }
454
455 Ok(HookResult::Continue)
456 }
457
458 fn name(&self) -> &str {
459 &self.name
460 }
461
462 fn should_execute(&self, phase: HookPhase) -> bool {
463 matches!(
464 phase,
465 HookPhase::BeforeStep
466 | HookPhase::AfterStep
467 | HookPhase::BeforeExecution
468 | HookPhase::AfterExecution
469 )
470 }
471}
472
/// Hook that inspects incoming data and aborts on NaN/infinite values or an
/// unexpected number of feature columns.
#[derive(Debug, Clone)]
pub struct ValidationHook {
    // Name printed in brackets in abort messages.
    name: String,
    // Abort when any value is NaN.
    check_nan: bool,
    // Abort when any value is infinite.
    check_inf: bool,
    // Enables the column-count check (only effective with `expected_features`).
    check_shape: bool,
    // Expected number of feature columns; `None` disables the shape check.
    expected_features: Option<usize>,
}
482
483impl ValidationHook {
484 #[must_use]
486 pub fn new(name: String) -> Self {
487 Self {
488 name,
489 check_nan: true,
490 check_inf: true,
491 check_shape: true,
492 expected_features: None,
493 }
494 }
495
496 #[must_use]
498 pub fn check_nan(mut self, check: bool) -> Self {
499 self.check_nan = check;
500 self
501 }
502
503 #[must_use]
505 pub fn check_inf(mut self, check: bool) -> Self {
506 self.check_inf = check;
507 self
508 }
509
510 #[must_use]
512 pub fn check_shape(mut self, check: bool) -> Self {
513 self.check_shape = check;
514 self
515 }
516
517 #[must_use]
519 pub fn expected_features(mut self, features: usize) -> Self {
520 self.expected_features = Some(features);
521 self
522 }
523}
524
525impl ExecutionHook for ValidationHook {
526 fn execute(
527 &mut self,
528 _context: &ExecutionContext,
529 data: Option<&HookData>,
530 ) -> SklResult<HookResult> {
531 if let Some(data) = data {
532 match data {
533 HookData::Features(array) => {
534 if self.check_nan && array.iter().any(|&x| x.is_nan()) {
535 return Ok(HookResult::Abort(format!(
536 "[{}] NaN values detected in features",
537 self.name
538 )));
539 }
540
541 if self.check_inf && array.iter().any(|&x| x.is_infinite()) {
542 return Ok(HookResult::Abort(format!(
543 "[{}] Infinite values detected in features",
544 self.name
545 )));
546 }
547
548 if self.check_shape {
549 if let Some(expected) = self.expected_features {
550 if array.ncols() != expected {
551 return Ok(HookResult::Abort(format!(
552 "[{}] Shape mismatch: expected {} features, got {}",
553 self.name,
554 expected,
555 array.ncols()
556 )));
557 }
558 }
559 }
560 }
561 HookData::Targets(array) | HookData::Predictions(array) => {
562 if self.check_nan && array.iter().any(|&x| x.is_nan()) {
563 return Ok(HookResult::Abort(format!(
564 "[{}] NaN values detected",
565 self.name
566 )));
567 }
568
569 if self.check_inf && array.iter().any(|&x| x.is_infinite()) {
570 return Ok(HookResult::Abort(format!(
571 "[{}] Infinite values detected",
572 self.name
573 )));
574 }
575 }
576 HookData::Custom(_) => {
577 }
579 }
580 }
581
582 Ok(HookResult::Continue)
583 }
584
585 fn name(&self) -> &str {
586 &self.name
587 }
588
589 fn should_execute(&self, phase: HookPhase) -> bool {
590 matches!(
591 phase,
592 HookPhase::BeforeStep | HookPhase::BeforePredict | HookPhase::BeforeTransform
593 )
594 }
595}
596
597pub struct CustomHookBuilder {
599 name: String,
600 phases: Vec<HookPhase>,
601 priority: i32,
602 execute_fn: Option<
603 Box<dyn Fn(&ExecutionContext, Option<&HookData>) -> SklResult<HookResult> + Send + Sync>,
604 >,
605}
606
607impl std::fmt::Debug for CustomHookBuilder {
608 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609 f.debug_struct("CustomHookBuilder")
610 .field("name", &self.name)
611 .field("phases", &self.phases)
612 .field("priority", &self.priority)
613 .field("execute_fn", &"<function>")
614 .finish()
615 }
616}
617
618impl CustomHookBuilder {
619 #[must_use]
621 pub fn new(name: String) -> Self {
622 Self {
623 name,
624 phases: Vec::new(),
625 priority: 0,
626 execute_fn: None,
627 }
628 }
629
630 #[must_use]
632 pub fn phases(mut self, phases: Vec<HookPhase>) -> Self {
633 self.phases = phases;
634 self
635 }
636
637 #[must_use]
639 pub fn priority(mut self, priority: i32) -> Self {
640 self.priority = priority;
641 self
642 }
643
644 pub fn execute_fn<F>(mut self, f: F) -> Self
646 where
647 F: Fn(&ExecutionContext, Option<&HookData>) -> SklResult<HookResult>
648 + Send
649 + Sync
650 + 'static,
651 {
652 self.execute_fn = Some(Box::new(f));
653 self
654 }
655
656 pub fn build(self) -> SklResult<CustomHook> {
658 let execute_fn = self.execute_fn.ok_or_else(|| {
659 SklearsError::InvalidInput("Execute function is required for custom hook".to_string())
660 })?;
661
662 Ok(CustomHook {
663 name: self.name,
664 phases: self.phases,
665 priority: self.priority,
666 execute_fn,
667 })
668 }
669}
670
671pub struct CustomHook {
673 name: String,
674 phases: Vec<HookPhase>,
675 priority: i32,
676 execute_fn:
677 Box<dyn Fn(&ExecutionContext, Option<&HookData>) -> SklResult<HookResult> + Send + Sync>,
678}
679
680impl std::fmt::Debug for CustomHook {
681 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
682 f.debug_struct("CustomHook")
683 .field("name", &self.name)
684 .field("phases", &self.phases)
685 .field("priority", &self.priority)
686 .field("execute_fn", &"<function>")
687 .finish()
688 }
689}
690
691impl ExecutionHook for CustomHook {
692 fn execute(
693 &mut self,
694 context: &ExecutionContext,
695 data: Option<&HookData>,
696 ) -> SklResult<HookResult> {
697 (self.execute_fn)(context, data)
698 }
699
700 fn name(&self) -> &str {
701 &self.name
702 }
703
704 fn priority(&self) -> i32 {
705 self.priority
706 }
707
708 fn should_execute(&self, phase: HookPhase) -> bool {
709 self.phases.contains(&phase)
710 }
711}
712
713impl Clone for CustomHook {
714 fn clone(&self) -> Self {
715 panic!("CustomHook cannot be cloned due to function pointer")
718 }
719}
720
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builds the single-step context shared by the hook tests.
    ///
    /// When `step_name` is given the step index is set to `0`; otherwise both
    /// step fields are `None`.
    fn test_context(
        step_name: Option<&str>,
        phase: HookPhase,
        start_time: Instant,
    ) -> ExecutionContext {
        ExecutionContext {
            execution_id: "test_exec".to_string(),
            step_name: step_name.map(str::to_string),
            step_index: step_name.map(|_| 0),
            total_steps: 1,
            start_time,
            phase,
            metadata: HashMap::new(),
            metrics: PerformanceMetrics::default(),
        }
    }

    #[test]
    fn test_hook_manager_creation() {
        let manager = HookManager::new();
        assert!(manager.hooks.is_empty());
        assert!(manager.execution_stack.is_empty());
    }

    #[test]
    fn test_logging_hook() {
        let mut hook = LoggingHook::new("test_hook".to_string(), LogLevel::Info);
        let context = test_context(Some("test_step"), HookPhase::BeforeStep, Instant::now());

        let result = hook
            .execute(&context, None)
            .expect("operation should succeed");
        assert!(matches!(result, HookResult::Continue));
    }

    #[test]
    fn test_validation_hook() {
        let mut hook = ValidationHook::new("validation".to_string()).expected_features(2);
        let context = test_context(None, HookPhase::BeforeStep, Instant::now());

        // Two columns, as expected: validation passes.
        let valid_data = HookData::Features(array![[1.0, 2.0], [3.0, 4.0]]);
        let result = hook
            .execute(&context, Some(&valid_data))
            .expect("operation should succeed");
        assert!(matches!(result, HookResult::Continue));

        // Three columns instead of two: validation aborts.
        let invalid_data = HookData::Features(array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let result = hook
            .execute(&context, Some(&invalid_data))
            .expect("operation should succeed");
        assert!(matches!(result, HookResult::Abort(_)));
    }

    #[test]
    fn test_performance_hook() {
        let mut hook =
            PerformanceHook::new("perf".to_string()).alert_threshold(Duration::from_millis(1));

        // Start time 10 ms in the past so the 1 ms threshold is exceeded.
        let context = test_context(
            None,
            HookPhase::AfterStep,
            Instant::now() - Duration::from_millis(10),
        );

        let result = hook
            .execute(&context, None)
            .expect("operation should succeed");
        assert!(matches!(result, HookResult::Continue));
    }

    #[test]
    fn test_hook_phases() {
        let hook = LoggingHook::new("test".to_string(), LogLevel::Info);
        assert!(hook.should_execute(HookPhase::BeforeExecution));
        assert!(hook.should_execute(HookPhase::AfterStep));
    }

    #[test]
    fn test_execution_context() {
        // `create_context` takes `&self`, so the manager need not be mutable.
        let manager = HookManager::new();
        let context = manager.create_context("test_id".to_string(), 5);

        assert_eq!(context.execution_id, "test_id");
        assert_eq!(context.total_steps, 5);
        assert!(context.step_name.is_none());
    }

    #[test]
    fn test_hook_data_variants() {
        let features = HookData::Features(array![[1.0, 2.0], [3.0, 4.0]]);
        let targets = HookData::Targets(array![1.0, 2.0]);
        let predictions = HookData::Predictions(array![1.1, 2.1]);

        match features {
            HookData::Features(arr) => assert_eq!(arr.shape(), &[2, 2]),
            _ => panic!("Wrong variant"),
        }

        match targets {
            HookData::Targets(arr) => assert_eq!(arr.len(), 2),
            _ => panic!("Wrong variant"),
        }

        match predictions {
            HookData::Predictions(arr) => assert_eq!(arr.len(), 2),
            _ => panic!("Wrong variant"),
        }
    }
}
845
/// Declares that a hook depends on another hook, identified by name.
#[derive(Debug, Clone)]
pub struct HookDependency {
    /// Name of the hook that must have executed.
    pub hook_name: String,
    /// Strictness flag. NOTE(review): not consulted by
    /// `DependentExecutionHook::dependencies_satisfied` in this file.
    pub strict: bool,
    /// Optional minimum priority. NOTE(review): also not consulted by the
    /// default dependency check in this file.
    pub min_priority: Option<i32>,
}
856
857pub trait DependentExecutionHook: ExecutionHook {
859 fn dependencies(&self) -> Vec<HookDependency> {
861 Vec::new()
862 }
863
864 fn dependencies_satisfied(&self, executed_hooks: &[String]) -> bool {
866 self.dependencies()
867 .iter()
868 .all(|dep| executed_hooks.contains(&dep.hook_name))
869 }
870}
871
/// Hook interface mirroring [`ExecutionHook`] with an additional timeout.
///
/// NOTE(review): despite the name, `execute_async` is an ordinary synchronous
/// method here — it returns `SklResult<HookResult>` directly, not a future.
/// No implementor or caller of this trait is visible in this file.
pub trait AsyncExecutionHook: Send + Sync + Debug {
    /// Runs the hook with the given context and optional data.
    fn execute_async(
        &mut self,
        context: &ExecutionContext,
        data: Option<&HookData>,
    ) -> SklResult<HookResult>;

    /// Name of the hook.
    fn name(&self) -> &str;

    /// Priority of the hook; defaults to `0`.
    fn priority(&self) -> i32 {
        0
    }

    /// Whether this hook wants to run for `phase`.
    fn should_execute(&self, phase: HookPhase) -> bool;

    /// Optional time budget for the hook; `None` (the default) means no timeout is declared.
    fn timeout(&self) -> Option<Duration> {
        None
    }
}
895
/// Hook that aborts execution when configured time or memory limits are exceeded.
///
/// Cloning shares `resource_usage`, since it lives behind an `Arc`.
#[derive(Debug, Clone)]
pub struct ResourceManagerHook {
    // Name printed in brackets in abort messages.
    name: String,
    // Upper bound on the estimated memory, in bytes; `None` disables the check.
    max_memory: Option<usize>,
    // Upper bound on elapsed execution time; `None` disables the check.
    max_execution_time: Option<Duration>,
    // CPU limit stored by the `cpu_limit` setter.
    // NOTE(review): stored but never checked in `check_limits`.
    cpu_limit: Option<f64>,
    // Usage figures and recorded violations, updated on every check.
    resource_usage: Arc<Mutex<ResourceUsage>>,
}
905
/// Usage figures tracked by [`ResourceManagerHook`].
#[derive(Debug, Clone, Default)]
pub struct ResourceUsage {
    /// Most recent memory estimate, in bytes.
    pub current_memory: usize,
    /// Largest memory estimate seen so far, in bytes.
    pub peak_memory: usize,
    /// CPU usage. NOTE(review): never written by code in this file.
    pub cpu_usage: f64,
    /// Elapsed execution time as of the last time-limit check.
    pub execution_time: Duration,
    /// Every limit violation recorded so far, in order of occurrence.
    pub violations: Vec<ResourceViolation>,
}
914
/// A single recorded breach of a resource limit.
#[derive(Debug, Clone)]
pub struct ResourceViolation {
    /// Which limit was exceeded.
    pub violation_type: ViolationType,
    /// When the violation was recorded.
    pub timestamp: Instant,
    /// Human-readable description of the measured value and the limit.
    pub details: String,
}
921
/// Kind of resource limit that was violated.
#[derive(Debug, Clone)]
pub enum ViolationType {
    /// The estimated memory exceeded `max_memory`.
    MemoryLimit,
    /// The elapsed time exceeded `max_execution_time`.
    TimeLimit,
    /// CPU limit. NOTE(review): never produced by code in this file.
    CpuLimit,
}
931
932impl ResourceManagerHook {
933 #[must_use]
935 pub fn new(name: String) -> Self {
936 Self {
937 name,
938 max_memory: None,
939 max_execution_time: None,
940 cpu_limit: None,
941 resource_usage: Arc::new(Mutex::new(ResourceUsage::default())),
942 }
943 }
944
945 #[must_use]
947 pub fn max_memory(mut self, limit: usize) -> Self {
948 self.max_memory = Some(limit);
949 self
950 }
951
952 #[must_use]
954 pub fn max_execution_time(mut self, limit: Duration) -> Self {
955 self.max_execution_time = Some(limit);
956 self
957 }
958
959 #[must_use]
961 pub fn cpu_limit(mut self, limit: f64) -> Self {
962 self.cpu_limit = Some(limit.min(1.0).max(0.0));
963 self
964 }
965
966 #[must_use]
968 pub fn get_usage(&self) -> ResourceUsage {
969 self.resource_usage
970 .lock()
971 .unwrap_or_else(|e| e.into_inner())
972 .clone()
973 }
974
975 fn check_limits(&self, context: &ExecutionContext) -> SklResult<HookResult> {
977 let mut usage = self
978 .resource_usage
979 .lock()
980 .unwrap_or_else(|e| e.into_inner());
981
982 if let Some(time_limit) = self.max_execution_time {
984 let elapsed = context.start_time.elapsed();
985 usage.execution_time = elapsed;
986
987 if elapsed > time_limit {
988 let violation = ResourceViolation {
989 violation_type: ViolationType::TimeLimit,
990 timestamp: Instant::now(),
991 details: format!(
992 "Execution time {} exceeded limit {:?}",
993 elapsed.as_secs_f64(),
994 time_limit
995 ),
996 };
997 usage.violations.push(violation);
998 return Ok(HookResult::Abort(format!(
999 "[{}] Execution time limit exceeded: {:?} > {:?}",
1000 self.name, elapsed, time_limit
1001 )));
1002 }
1003 }
1004
1005 if let Some(memory_limit) = self.max_memory {
1007 let estimated_memory = context
1008 .metrics
1009 .data_shapes
1010 .iter()
1011 .map(|(rows, cols)| rows * cols * std::mem::size_of::<Float>())
1012 .sum::<usize>();
1013
1014 usage.current_memory = estimated_memory;
1015 usage.peak_memory = usage.peak_memory.max(estimated_memory);
1016
1017 if estimated_memory > memory_limit {
1018 let violation = ResourceViolation {
1019 violation_type: ViolationType::MemoryLimit,
1020 timestamp: Instant::now(),
1021 details: format!(
1022 "Memory usage {estimated_memory} exceeded limit {memory_limit}"
1023 ),
1024 };
1025 usage.violations.push(violation);
1026 return Ok(HookResult::Abort(format!(
1027 "[{}] Memory limit exceeded: {} bytes > {} bytes",
1028 self.name, estimated_memory, memory_limit
1029 )));
1030 }
1031 }
1032
1033 Ok(HookResult::Continue)
1034 }
1035}
1036
1037impl ExecutionHook for ResourceManagerHook {
1038 fn execute(
1039 &mut self,
1040 context: &ExecutionContext,
1041 _data: Option<&HookData>,
1042 ) -> SklResult<HookResult> {
1043 self.check_limits(context)
1044 }
1045
1046 fn name(&self) -> &str {
1047 &self.name
1048 }
1049
1050 fn priority(&self) -> i32 {
1051 1000 }
1053
1054 fn should_execute(&self, phase: HookPhase) -> bool {
1055 matches!(
1056 phase,
1057 HookPhase::BeforeStep | HookPhase::AfterStep | HookPhase::BeforeExecution
1058 )
1059 }
1060}
1061
/// Hook that records an audit trail and can block unauthorized access to
/// sensitive operations.
///
/// Cloning shares `audit_log`, since it lives behind an `Arc`.
#[derive(Debug, Clone)]
pub struct SecurityAuditHook {
    // Name printed in brackets in abort messages.
    name: String,
    // Audit entries appended on each audited execution.
    audit_log: Arc<Mutex<Vec<AuditEntry>>>,
    // Substrings that mark a step name as sensitive.
    sensitive_operations: Vec<String>,
    // When true, sensitive steps need `metadata["authorized"] == "true"`.
    require_authorization: bool,
}
1070
/// One record in the security audit log.
#[derive(Debug, Clone)]
pub struct AuditEntry {
    /// When the entry was created.
    pub timestamp: Instant,
    /// Identifier of the execution, copied from the context.
    pub execution_id: String,
    /// Step name from the context, or `"unknown"` when no step is set.
    pub operation: String,
    /// Value of the `"user_id"` metadata key, if present.
    pub user_id: Option<String>,
    /// Short description of the data passed to the hook.
    pub data_summary: String,
    /// Outcome of the audit.
    pub result: AuditResult,
}
1080
/// Outcome recorded in an [`AuditEntry`].
#[derive(Debug, Clone)]
pub enum AuditResult {
    /// The operation was audited without findings.
    Success,
    /// Failure with a message. NOTE(review): never produced by code in this file.
    Failed(String),
    /// A sensitive operation was attempted without authorization.
    Unauthorized,
    /// The operation looked suspicious; the string gives the reason.
    Suspicious(String),
}
1092
1093impl SecurityAuditHook {
1094 #[must_use]
1096 pub fn new(name: String) -> Self {
1097 Self {
1098 name,
1099 audit_log: Arc::new(Mutex::new(Vec::new())),
1100 sensitive_operations: Vec::new(),
1101 require_authorization: false,
1102 }
1103 }
1104
1105 #[must_use]
1107 pub fn sensitive_operations(mut self, operations: Vec<String>) -> Self {
1108 self.sensitive_operations = operations;
1109 self
1110 }
1111
1112 #[must_use]
1114 pub fn require_authorization(mut self, require: bool) -> Self {
1115 self.require_authorization = require;
1116 self
1117 }
1118
1119 #[must_use]
1121 pub fn get_audit_log(&self) -> Vec<AuditEntry> {
1122 self.audit_log
1123 .lock()
1124 .unwrap_or_else(|e| e.into_inner())
1125 .clone()
1126 }
1127
1128 fn is_sensitive_operation(&self, context: &ExecutionContext) -> bool {
1130 if let Some(step_name) = &context.step_name {
1131 self.sensitive_operations
1132 .iter()
1133 .any(|op| step_name.contains(op))
1134 } else {
1135 false
1136 }
1137 }
1138
1139 fn create_audit_entry(
1141 &self,
1142 context: &ExecutionContext,
1143 result: AuditResult,
1144 data_summary: String,
1145 ) -> AuditEntry {
1146 AuditEntry {
1148 timestamp: Instant::now(),
1149 execution_id: context.execution_id.clone(),
1150 operation: context
1151 .step_name
1152 .clone()
1153 .unwrap_or_else(|| "unknown".to_string()),
1154 user_id: context.metadata.get("user_id").cloned(),
1155 data_summary,
1156 result,
1157 }
1158 }
1159}
1160
1161impl ExecutionHook for SecurityAuditHook {
1162 fn execute(
1163 &mut self,
1164 context: &ExecutionContext,
1165 data: Option<&HookData>,
1166 ) -> SklResult<HookResult> {
1167 let is_sensitive = self.is_sensitive_operation(context);
1168
1169 let data_summary = match data {
1171 Some(HookData::Features(arr)) => format!("Features: {}x{}", arr.nrows(), arr.ncols()),
1172 Some(HookData::Targets(arr)) => format!("Targets: {}", arr.len()),
1173 Some(HookData::Predictions(arr)) => format!("Predictions: {}", arr.len()),
1174 Some(HookData::Custom(_)) => "Custom data".to_string(),
1175 None => "No data".to_string(),
1176 };
1177
1178 if is_sensitive && self.require_authorization {
1180 let has_auth = context
1181 .metadata
1182 .get("authorized")
1183 .is_some_and(|v| v == "true");
1184
1185 if !has_auth {
1186 let audit_entry =
1187 self.create_audit_entry(context, AuditResult::Unauthorized, data_summary);
1188 self.audit_log
1189 .lock()
1190 .unwrap_or_else(|e| e.into_inner())
1191 .push(audit_entry);
1192
1193 return Ok(HookResult::Abort(format!(
1194 "[{}] Unauthorized access to sensitive operation: {}",
1195 self.name,
1196 context.step_name.as_deref().unwrap_or("unknown")
1197 )));
1198 }
1199 }
1200
1201 if is_sensitive || !self.sensitive_operations.is_empty() {
1203 let result = if is_sensitive {
1204 if data_summary.contains("empty") {
1206 AuditResult::Suspicious("Empty data in sensitive operation".to_string())
1207 } else {
1208 AuditResult::Success
1209 }
1210 } else {
1211 AuditResult::Success
1212 };
1213
1214 let audit_entry = self.create_audit_entry(context, result, data_summary);
1215 self.audit_log
1216 .lock()
1217 .unwrap_or_else(|e| e.into_inner())
1218 .push(audit_entry);
1219 }
1220
1221 Ok(HookResult::Continue)
1222 }
1223
1224 fn name(&self) -> &str {
1225 &self.name
1226 }
1227
1228 fn priority(&self) -> i32 {
1229 900 }
1231
1232 fn should_execute(&self, phase: HookPhase) -> bool {
1233 matches!(
1234 phase,
1235 HookPhase::BeforeStep | HookPhase::BeforePredict | HookPhase::BeforeTransform
1236 )
1237 }
1238}
1239
/// Hook that reacts to the `OnError` phase by applying a fallback strategy
/// and recording the error.
///
/// Cloning shares `error_history`, since it lives behind an `Arc`.
#[derive(Debug, Clone)]
pub struct ErrorRecoveryHook {
    // Name printed in brackets in messages.
    name: String,
    // Configured retry count. NOTE(review): not read by `execute` in this file.
    retry_count: usize,
    // Configured retry delay. NOTE(review): not read by `execute` in this file;
    // the delay actually used comes from `FallbackStrategy::RetryWithDelay`.
    retry_delay: Duration,
    // Strategies tried in order when an error is handled.
    fallback_strategies: Vec<FallbackStrategy>,
    // Record of every handled error.
    error_history: Arc<Mutex<Vec<ErrorRecord>>>,
}
1249
/// One entry in the error history kept by [`ErrorRecoveryHook`].
#[derive(Debug, Clone)]
pub struct ErrorRecord {
    /// When the record was created.
    pub timestamp: Instant,
    /// Identifier of the execution in which the error occurred.
    pub execution_id: String,
    /// Error category; the hook in this file always writes `"execution_error"`.
    pub error_type: String,
    /// Error text, taken from the `"error"` metadata key.
    pub error_message: String,
    /// Whether a recovery strategy was attempted.
    pub recovery_attempted: bool,
    /// Whether the hook recorded the recovery as successful.
    pub recovery_successful: bool,
}
1259
/// Recovery strategies understood by [`ErrorRecoveryHook`].
#[derive(Debug, Clone)]
pub enum FallbackStrategy {
    /// Sleep for the given duration, then let execution continue.
    RetryWithDelay(Duration),
    /// Continue with placeholder data (a 1x1 zero feature matrix).
    UseDefaultValues,
    /// Skip the failing step.
    SkipStep,
    /// Abort the execution.
    AbortExecution,
    /// Named custom recovery; the hook only logs and records it, then moves
    /// on to the next strategy.
    CustomRecovery(String),
}
1273
1274impl ErrorRecoveryHook {
1275 #[must_use]
1277 pub fn new(name: String) -> Self {
1278 Self {
1279 name,
1280 retry_count: 3,
1281 retry_delay: Duration::from_millis(100),
1282 fallback_strategies: vec![
1283 FallbackStrategy::RetryWithDelay(Duration::from_millis(100)),
1284 FallbackStrategy::UseDefaultValues,
1285 FallbackStrategy::SkipStep,
1286 ],
1287 error_history: Arc::new(Mutex::new(Vec::new())),
1288 }
1289 }
1290
1291 #[must_use]
1293 pub fn retry_config(mut self, count: usize, delay: Duration) -> Self {
1294 self.retry_count = count;
1295 self.retry_delay = delay;
1296 self
1297 }
1298
1299 #[must_use]
1301 pub fn fallback_strategies(mut self, strategies: Vec<FallbackStrategy>) -> Self {
1302 self.fallback_strategies = strategies;
1303 self
1304 }
1305
1306 #[must_use]
1308 pub fn get_error_history(&self) -> Vec<ErrorRecord> {
1309 self.error_history
1310 .lock()
1311 .unwrap_or_else(|e| e.into_inner())
1312 .clone()
1313 }
1314
1315 fn record_error(
1317 &self,
1318 context: &ExecutionContext,
1319 error: &str,
1320 recovery_attempted: bool,
1321 recovery_successful: bool,
1322 ) {
1323 let record = ErrorRecord {
1324 timestamp: Instant::now(),
1325 execution_id: context.execution_id.clone(),
1326 error_type: "execution_error".to_string(),
1327 error_message: error.to_string(),
1328 recovery_attempted,
1329 recovery_successful,
1330 };
1331
1332 self.error_history
1333 .lock()
1334 .unwrap_or_else(|e| e.into_inner())
1335 .push(record);
1336 }
1337}
1338
1339impl ExecutionHook for ErrorRecoveryHook {
1340 fn execute(
1341 &mut self,
1342 context: &ExecutionContext,
1343 _data: Option<&HookData>,
1344 ) -> SklResult<HookResult> {
1345 if matches!(context.phase, HookPhase::OnError) {
1347 let error_msg = context
1349 .metadata
1350 .get("error")
1351 .unwrap_or(&"Unknown error".to_string())
1352 .clone();
1353
1354 for strategy in &self.fallback_strategies {
1356 match strategy {
1357 FallbackStrategy::RetryWithDelay(delay) => {
1358 self.record_error(context, &error_msg, true, false);
1359 std::thread::sleep(*delay);
1360 println!("[{}] Retrying after delay: {:?}", self.name, delay);
1362 return Ok(HookResult::Continue);
1363 }
1364 FallbackStrategy::UseDefaultValues => {
1365 self.record_error(context, &error_msg, true, true);
1366 println!("[{}] Using default values for recovery", self.name);
1367 return Ok(HookResult::ContinueWithData(HookData::Features(
1369 Array2::zeros((1, 1)),
1370 )));
1371 }
1372 FallbackStrategy::SkipStep => {
1373 self.record_error(context, &error_msg, true, true);
1374 println!("[{}] Skipping step for recovery", self.name);
1375 return Ok(HookResult::Skip);
1376 }
1377 FallbackStrategy::AbortExecution => {
1378 self.record_error(context, &error_msg, false, false);
1379 return Ok(HookResult::Abort(format!(
1380 "[{}] Unrecoverable error: {}",
1381 self.name, error_msg
1382 )));
1383 }
1384 FallbackStrategy::CustomRecovery(name) => {
1385 println!("[{}] Attempting custom recovery: {}", self.name, name);
1386 self.record_error(context, &error_msg, true, false);
1388 }
1389 }
1390 }
1391 }
1392
1393 Ok(HookResult::Continue)
1394 }
1395
1396 fn name(&self) -> &str {
1397 &self.name
1398 }
1399
1400 fn priority(&self) -> i32 {
1401 500 }
1403
1404 fn should_execute(&self, phase: HookPhase) -> bool {
1405 matches!(phase, HookPhase::OnError)
1406 }
1407}
1408
/// A hook made of other hooks, run according to a [`CompositionStrategy`].
#[derive(Debug)]
pub struct HookComposition {
    // Name returned by `name()`.
    name: String,
    // Child hooks, kept sorted by descending priority.
    hooks: Vec<Box<dyn ExecutionHook>>,
    // How the children's results are combined.
    execution_strategy: CompositionStrategy,
}
1416
/// How a [`HookComposition`] runs its children and combines their results.
#[derive(Debug, Clone)]
pub enum CompositionStrategy {
    /// Run children in order; stop at the first non-`Continue` result.
    Sequential,
    /// Run every applicable child (one after another, on the calling thread),
    /// then return the first non-`Continue` result, if any.
    Parallel,
    /// As implemented in this file, behaves exactly like `Sequential`.
    FirstMatch,
    /// Run every applicable child, discard their results and return `Continue`.
    Aggregate,
}
1428
1429impl HookComposition {
1430 #[must_use]
1432 pub fn new(name: String, strategy: CompositionStrategy) -> Self {
1433 Self {
1434 name,
1435 hooks: Vec::new(),
1436 execution_strategy: strategy,
1437 }
1438 }
1439
1440 pub fn add_hook(&mut self, hook: Box<dyn ExecutionHook>) {
1442 self.hooks.push(hook);
1443 self.hooks.sort_by(|a, b| b.priority().cmp(&a.priority()));
1445 }
1446}
1447
1448impl ExecutionHook for HookComposition {
1449 fn execute(
1450 &mut self,
1451 context: &ExecutionContext,
1452 data: Option<&HookData>,
1453 ) -> SklResult<HookResult> {
1454 match self.execution_strategy {
1455 CompositionStrategy::Sequential => {
1456 for hook in &mut self.hooks {
1457 if hook.should_execute(context.phase) {
1458 let result = hook.execute(context, data)?;
1459 if !matches!(result, HookResult::Continue) {
1460 return Ok(result);
1461 }
1462 }
1463 }
1464 Ok(HookResult::Continue)
1465 }
1466 CompositionStrategy::FirstMatch => {
1467 for hook in &mut self.hooks {
1468 if hook.should_execute(context.phase) {
1469 let result = hook.execute(context, data)?;
1470 if !matches!(result, HookResult::Continue) {
1471 return Ok(result);
1472 }
1473 }
1474 }
1475 Ok(HookResult::Continue)
1476 }
1477 CompositionStrategy::Parallel => {
1478 let mut results = Vec::new();
1480 for hook in &mut self.hooks {
1481 if hook.should_execute(context.phase) {
1482 results.push(hook.execute(context, data)?);
1483 }
1484 }
1485
1486 for result in results {
1488 if !matches!(result, HookResult::Continue) {
1489 return Ok(result);
1490 }
1491 }
1492 Ok(HookResult::Continue)
1493 }
1494 CompositionStrategy::Aggregate => {
1495 for hook in &mut self.hooks {
1497 if hook.should_execute(context.phase) {
1498 let _result = hook.execute(context, data)?;
1499 }
1501 }
1502 Ok(HookResult::Continue)
1503 }
1504 }
1505 }
1506
1507 fn name(&self) -> &str {
1508 &self.name
1509 }
1510
1511 fn priority(&self) -> i32 {
1512 self.hooks.iter().map(|h| h.priority()).max().unwrap_or(0)
1514 }
1515
1516 fn should_execute(&self, phase: HookPhase) -> bool {
1517 self.hooks.iter().any(|h| h.should_execute(phase))
1518 }
1519}