1use std::any::Any;
8use std::fmt::Debug;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11
12use scirs2_core::ndarray::{Array1, Array2};
16use sklears_core::{
17 error::Result as SklResult,
18 prelude::{Fit, Predict, SklearsError, Transform},
19 traits::Estimator,
20 types::{Float, FloatBounds},
21};
22use std::collections::HashMap;
23
/// Lifecycle points at which [`ExecutionHook`]s may be dispatched.
///
/// The phase is passed to `HookManager::execute_hooks`, which writes it into
/// [`ExecutionContext::phase`] before running the hooks registered for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookPhase {
    /// Before a whole execution begins.
    BeforeExecution,
    /// Before an individual step runs.
    BeforeStep,
    /// After an individual step has run.
    AfterStep,
    /// After a whole execution has finished.
    AfterExecution,
    /// When an error has been reported.
    OnError,
    /// Before a `fit` call.
    BeforeFit,
    /// After a `fit` call.
    AfterFit,
    /// Before a `predict` call.
    BeforePredict,
    /// After a `predict` call.
    AfterPredict,
    /// Before a `transform` call.
    BeforeTransform,
    /// After a `transform` call.
    AfterTransform,
}
50
51#[derive(Debug, Clone)]
53pub struct ExecutionContext {
54 pub execution_id: String,
56 pub step_name: Option<String>,
58 pub step_index: Option<usize>,
60 pub total_steps: usize,
62 pub start_time: Instant,
64 pub phase: HookPhase,
66 pub metadata: HashMap<String, String>,
68 pub metrics: PerformanceMetrics,
70}
71
72#[derive(Debug, Clone, Default)]
74pub struct PerformanceMetrics {
75 pub total_duration: Duration,
77 pub step_durations: HashMap<String, Duration>,
79 pub memory_usage: MemoryUsage,
81 pub data_shapes: Vec<(usize, usize)>,
83 pub error_count: usize,
85}
86
/// Memory counters attached to [`PerformanceMetrics`].
///
/// NOTE(review): units are presumably bytes; no code visible here writes these fields.
#[derive(Debug, Clone, Default)]
pub struct MemoryUsage {
    /// Highest memory figure observed.
    pub peak_memory: usize,
    /// Most recent memory figure.
    pub current_memory: usize,
    /// Number of allocations counted.
    pub allocations: usize,
}
97
98#[derive(Debug, Clone)]
100pub enum HookResult {
101 Continue,
103 Skip,
105 Abort(String),
107 ContinueWithData(HookData),
109}
110
111#[derive(Debug, Clone)]
113pub enum HookData {
114 Features(Array2<Float>),
116 Targets(Array1<Float>),
118 Predictions(Array1<Float>),
120 Custom(Arc<dyn Any + Send + Sync>),
122}
123
124pub trait ExecutionHook: Send + Sync + Debug {
126 fn execute(
128 &mut self,
129 context: &ExecutionContext,
130 data: Option<&HookData>,
131 ) -> SklResult<HookResult>;
132
133 fn name(&self) -> &str;
135
136 fn priority(&self) -> i32 {
138 0
139 }
140
141 fn should_execute(&self, phase: HookPhase) -> bool;
143}
144
145#[derive(Debug)]
147pub struct HookManager {
148 hooks: HashMap<HookPhase, Vec<Box<dyn ExecutionHook>>>,
149 execution_stack: Vec<ExecutionContext>,
150 global_metrics: Arc<Mutex<PerformanceMetrics>>,
151}
152
153impl HookManager {
154 #[must_use]
156 pub fn new() -> Self {
157 Self {
158 hooks: HashMap::new(),
159 execution_stack: Vec::new(),
160 global_metrics: Arc::new(Mutex::new(PerformanceMetrics::default())),
161 }
162 }
163
164 pub fn register_hook(&mut self, hook: Box<dyn ExecutionHook>, phases: Vec<HookPhase>) {
166 if let Some(&first_phase) = phases.first() {
169 self.hooks.entry(first_phase).or_default().push(hook);
170
171 if let Some(hooks) = self.hooks.get_mut(&first_phase) {
173 hooks.sort_by(|a, b| b.priority().cmp(&a.priority()));
174 }
175 }
176 }
177
178 pub fn execute_hooks(
180 &mut self,
181 phase: HookPhase,
182 context: &mut ExecutionContext,
183 data: Option<&HookData>,
184 ) -> SklResult<HookResult> {
185 context.phase = phase;
186
187 if let Some(hooks) = self.hooks.get_mut(&phase) {
188 for hook in hooks {
189 if hook.should_execute(phase) {
190 match hook.execute(context, data)? {
191 HookResult::Continue => {}
192 HookResult::Skip => return Ok(HookResult::Skip),
193 HookResult::Abort(msg) => return Ok(HookResult::Abort(msg)),
194 HookResult::ContinueWithData(modified_data) => {
195 return Ok(HookResult::ContinueWithData(modified_data));
196 }
197 }
198 }
199 }
200 }
201
202 Ok(HookResult::Continue)
203 }
204
205 #[must_use]
207 pub fn create_context(&self, execution_id: String, total_steps: usize) -> ExecutionContext {
208 ExecutionContext {
210 execution_id,
211 step_name: None,
212 step_index: None,
213 total_steps,
214 start_time: Instant::now(),
215 phase: HookPhase::BeforeExecution,
216 metadata: HashMap::new(),
217 metrics: PerformanceMetrics::default(),
218 }
219 }
220
221 pub fn push_context(&mut self, context: ExecutionContext) {
223 self.execution_stack.push(context);
224 }
225
226 pub fn pop_context(&mut self) -> Option<ExecutionContext> {
228 self.execution_stack.pop()
229 }
230
231 #[must_use]
233 pub fn current_context(&self) -> Option<&ExecutionContext> {
234 self.execution_stack.last()
235 }
236
237 pub fn current_context_mut(&mut self) -> Option<&mut ExecutionContext> {
239 self.execution_stack.last_mut()
240 }
241
242 pub fn update_global_metrics<F>(&self, updater: F)
244 where
245 F: FnOnce(&mut PerformanceMetrics),
246 {
247 if let Ok(mut metrics) = self.global_metrics.lock() {
248 updater(&mut metrics);
249 }
250 }
251
252 #[must_use]
254 pub fn global_metrics(&self) -> PerformanceMetrics {
255 self.global_metrics.lock().unwrap().clone()
256 }
257}
258
259impl Default for HookManager {
260 fn default() -> Self {
261 Self::new()
262 }
263}
264
265#[derive(Debug, Clone)]
267pub struct LoggingHook {
268 name: String,
269 log_level: LogLevel,
270 include_data_shapes: bool,
271 include_timing: bool,
272}
273
/// Severity label used by `LoggingHook`. Every level is printed to stdout;
/// the level only selects the prefix.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LogLevel {
    /// Printed with the `DEBUG:` prefix.
    Debug,
    /// Printed with the `INFO:` prefix.
    Info,
    /// Printed with the `WARN:` prefix.
    Warn,
    /// Printed with the `ERROR:` prefix.
    Error,
}
285
286impl LoggingHook {
287 #[must_use]
289 pub fn new(name: String, log_level: LogLevel) -> Self {
290 Self {
291 name,
292 log_level,
293 include_data_shapes: true,
294 include_timing: true,
295 }
296 }
297
298 #[must_use]
300 pub fn include_data_shapes(mut self, include: bool) -> Self {
301 self.include_data_shapes = include;
302 self
303 }
304
305 #[must_use]
307 pub fn include_timing(mut self, include: bool) -> Self {
308 self.include_timing = include;
309 self
310 }
311}
312
313impl ExecutionHook for LoggingHook {
314 fn execute(
315 &mut self,
316 context: &ExecutionContext,
317 data: Option<&HookData>,
318 ) -> SklResult<HookResult> {
319 let mut log_message = format!(
320 "[{}] Phase: {:?}, Execution: {}",
321 self.name, context.phase, context.execution_id
322 );
323
324 if let Some(step_name) = &context.step_name {
325 log_message.push_str(&format!(", Step: {step_name}"));
326 }
327
328 if self.include_timing {
329 let elapsed = context.start_time.elapsed();
330 log_message.push_str(&format!(", Elapsed: {elapsed:?}"));
331 }
332
333 if self.include_data_shapes {
334 if let Some(data) = data {
335 match data {
336 HookData::Features(array) => {
337 log_message.push_str(&format!(
338 ", Features: {}x{}",
339 array.nrows(),
340 array.ncols()
341 ));
342 }
343 HookData::Targets(array) => {
344 log_message.push_str(&format!(", Targets: {}", array.len()));
345 }
346 HookData::Predictions(array) => {
347 log_message.push_str(&format!(", Predictions: {}", array.len()));
348 }
349 HookData::Custom(_) => {
350 log_message.push_str(", Data: Custom");
351 }
352 }
353 }
354 }
355
356 match self.log_level {
357 LogLevel::Debug => println!("DEBUG: {log_message}"),
358 LogLevel::Info => println!("INFO: {log_message}"),
359 LogLevel::Warn => println!("WARN: {log_message}"),
360 LogLevel::Error => println!("ERROR: {log_message}"),
361 }
362
363 Ok(HookResult::Continue)
364 }
365
366 fn name(&self) -> &str {
367 &self.name
368 }
369
370 fn should_execute(&self, _phase: HookPhase) -> bool {
371 true
372 }
373}
374
/// Hook that prints slow-operation alerts and estimated memory usage.
#[derive(Debug, Clone)]
pub struct PerformanceHook {
    /// Label used in printed messages.
    name: String,
    /// Whether to print an estimate derived from `metrics.data_shapes`.
    track_memory: bool,
    /// Whether to compare elapsed time against `alert_threshold`.
    track_timing: bool,
    /// Elapsed time above which an alert is printed; `None` disables alerts.
    alert_threshold: Option<Duration>,
}
383
384impl PerformanceHook {
385 #[must_use]
387 pub fn new(name: String) -> Self {
388 Self {
389 name,
390 track_memory: true,
391 track_timing: true,
392 alert_threshold: None,
393 }
394 }
395
396 #[must_use]
398 pub fn track_memory(mut self, track: bool) -> Self {
399 self.track_memory = track;
400 self
401 }
402
403 #[must_use]
405 pub fn track_timing(mut self, track: bool) -> Self {
406 self.track_timing = track;
407 self
408 }
409
410 #[must_use]
412 pub fn alert_threshold(mut self, threshold: Duration) -> Self {
413 self.alert_threshold = Some(threshold);
414 self
415 }
416}
417
418impl ExecutionHook for PerformanceHook {
419 fn execute(
420 &mut self,
421 context: &ExecutionContext,
422 _data: Option<&HookData>,
423 ) -> SklResult<HookResult> {
424 if self.track_timing {
425 let elapsed = context.start_time.elapsed();
426
427 if let Some(threshold) = self.alert_threshold {
428 if elapsed > threshold {
429 println!(
430 "PERFORMANCE ALERT [{}]: Slow operation detected - {:?} (threshold: {:?})",
431 self.name, elapsed, threshold
432 );
433 }
434 }
435 }
436
437 if self.track_memory {
438 let estimated_memory = context
440 .metrics
441 .data_shapes
442 .iter()
443 .map(|(rows, cols)| rows * cols * std::mem::size_of::<Float>())
444 .sum::<usize>();
445
446 println!(
447 "MEMORY [{}]: Estimated usage: {} bytes",
448 self.name, estimated_memory
449 );
450 }
451
452 Ok(HookResult::Continue)
453 }
454
455 fn name(&self) -> &str {
456 &self.name
457 }
458
459 fn should_execute(&self, phase: HookPhase) -> bool {
460 matches!(
461 phase,
462 HookPhase::BeforeStep
463 | HookPhase::AfterStep
464 | HookPhase::BeforeExecution
465 | HookPhase::AfterExecution
466 )
467 }
468}
469
/// Hook that aborts when the payload contains NaN/infinite values or has an
/// unexpected feature count.
#[derive(Debug, Clone)]
pub struct ValidationHook {
    /// Label used in abort messages.
    name: String,
    /// Reject NaN values.
    check_nan: bool,
    /// Reject infinite values.
    check_inf: bool,
    /// Enforce `expected_features` on feature matrices.
    check_shape: bool,
    /// Required number of feature columns, if any.
    expected_features: Option<usize>,
}
479
480impl ValidationHook {
481 #[must_use]
483 pub fn new(name: String) -> Self {
484 Self {
485 name,
486 check_nan: true,
487 check_inf: true,
488 check_shape: true,
489 expected_features: None,
490 }
491 }
492
493 #[must_use]
495 pub fn check_nan(mut self, check: bool) -> Self {
496 self.check_nan = check;
497 self
498 }
499
500 #[must_use]
502 pub fn check_inf(mut self, check: bool) -> Self {
503 self.check_inf = check;
504 self
505 }
506
507 #[must_use]
509 pub fn check_shape(mut self, check: bool) -> Self {
510 self.check_shape = check;
511 self
512 }
513
514 #[must_use]
516 pub fn expected_features(mut self, features: usize) -> Self {
517 self.expected_features = Some(features);
518 self
519 }
520}
521
522impl ExecutionHook for ValidationHook {
523 fn execute(
524 &mut self,
525 _context: &ExecutionContext,
526 data: Option<&HookData>,
527 ) -> SklResult<HookResult> {
528 if let Some(data) = data {
529 match data {
530 HookData::Features(array) => {
531 if self.check_nan && array.iter().any(|&x| x.is_nan()) {
532 return Ok(HookResult::Abort(format!(
533 "[{}] NaN values detected in features",
534 self.name
535 )));
536 }
537
538 if self.check_inf && array.iter().any(|&x| x.is_infinite()) {
539 return Ok(HookResult::Abort(format!(
540 "[{}] Infinite values detected in features",
541 self.name
542 )));
543 }
544
545 if self.check_shape {
546 if let Some(expected) = self.expected_features {
547 if array.ncols() != expected {
548 return Ok(HookResult::Abort(format!(
549 "[{}] Shape mismatch: expected {} features, got {}",
550 self.name,
551 expected,
552 array.ncols()
553 )));
554 }
555 }
556 }
557 }
558 HookData::Targets(array) | HookData::Predictions(array) => {
559 if self.check_nan && array.iter().any(|&x| x.is_nan()) {
560 return Ok(HookResult::Abort(format!(
561 "[{}] NaN values detected",
562 self.name
563 )));
564 }
565
566 if self.check_inf && array.iter().any(|&x| x.is_infinite()) {
567 return Ok(HookResult::Abort(format!(
568 "[{}] Infinite values detected",
569 self.name
570 )));
571 }
572 }
573 HookData::Custom(_) => {
574 }
576 }
577 }
578
579 Ok(HookResult::Continue)
580 }
581
582 fn name(&self) -> &str {
583 &self.name
584 }
585
586 fn should_execute(&self, phase: HookPhase) -> bool {
587 matches!(
588 phase,
589 HookPhase::BeforeStep | HookPhase::BeforePredict | HookPhase::BeforeTransform
590 )
591 }
592}
593
594pub struct CustomHookBuilder {
596 name: String,
597 phases: Vec<HookPhase>,
598 priority: i32,
599 execute_fn: Option<
600 Box<dyn Fn(&ExecutionContext, Option<&HookData>) -> SklResult<HookResult> + Send + Sync>,
601 >,
602}
603
604impl std::fmt::Debug for CustomHookBuilder {
605 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
606 f.debug_struct("CustomHookBuilder")
607 .field("name", &self.name)
608 .field("phases", &self.phases)
609 .field("priority", &self.priority)
610 .field("execute_fn", &"<function>")
611 .finish()
612 }
613}
614
615impl CustomHookBuilder {
616 #[must_use]
618 pub fn new(name: String) -> Self {
619 Self {
620 name,
621 phases: Vec::new(),
622 priority: 0,
623 execute_fn: None,
624 }
625 }
626
627 #[must_use]
629 pub fn phases(mut self, phases: Vec<HookPhase>) -> Self {
630 self.phases = phases;
631 self
632 }
633
634 #[must_use]
636 pub fn priority(mut self, priority: i32) -> Self {
637 self.priority = priority;
638 self
639 }
640
641 pub fn execute_fn<F>(mut self, f: F) -> Self
643 where
644 F: Fn(&ExecutionContext, Option<&HookData>) -> SklResult<HookResult>
645 + Send
646 + Sync
647 + 'static,
648 {
649 self.execute_fn = Some(Box::new(f));
650 self
651 }
652
653 pub fn build(self) -> SklResult<CustomHook> {
655 let execute_fn = self.execute_fn.ok_or_else(|| {
656 SklearsError::InvalidInput("Execute function is required for custom hook".to_string())
657 })?;
658
659 Ok(CustomHook {
660 name: self.name,
661 phases: self.phases,
662 priority: self.priority,
663 execute_fn,
664 })
665 }
666}
667
668pub struct CustomHook {
670 name: String,
671 phases: Vec<HookPhase>,
672 priority: i32,
673 execute_fn:
674 Box<dyn Fn(&ExecutionContext, Option<&HookData>) -> SklResult<HookResult> + Send + Sync>,
675}
676
677impl std::fmt::Debug for CustomHook {
678 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
679 f.debug_struct("CustomHook")
680 .field("name", &self.name)
681 .field("phases", &self.phases)
682 .field("priority", &self.priority)
683 .field("execute_fn", &"<function>")
684 .finish()
685 }
686}
687
688impl ExecutionHook for CustomHook {
689 fn execute(
690 &mut self,
691 context: &ExecutionContext,
692 data: Option<&HookData>,
693 ) -> SklResult<HookResult> {
694 (self.execute_fn)(context, data)
695 }
696
697 fn name(&self) -> &str {
698 &self.name
699 }
700
701 fn priority(&self) -> i32 {
702 self.priority
703 }
704
705 fn should_execute(&self, phase: HookPhase) -> bool {
706 self.phases.contains(&phase)
707 }
708}
709
710impl Clone for CustomHook {
711 fn clone(&self) -> Self {
712 panic!("CustomHook cannot be cloned due to function pointer")
715 }
716}
717
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builds a minimal single-step context for `phase`, starting now.
    fn make_context(phase: HookPhase, step_name: Option<&str>) -> ExecutionContext {
        ExecutionContext {
            execution_id: "test_exec".to_string(),
            step_name: step_name.map(str::to_string),
            step_index: step_name.map(|_| 0),
            total_steps: 1,
            start_time: Instant::now(),
            phase,
            metadata: HashMap::new(),
            metrics: PerformanceMetrics::default(),
        }
    }

    #[test]
    fn test_hook_manager_creation() {
        let manager = HookManager::new();
        assert!(manager.hooks.is_empty());
        assert!(manager.execution_stack.is_empty());
    }

    #[test]
    fn test_logging_hook() {
        let mut hook = LoggingHook::new("test_hook".to_string(), LogLevel::Info);
        let context = make_context(HookPhase::BeforeStep, Some("test_step"));

        let result = hook.execute(&context, None).unwrap();
        assert!(matches!(result, HookResult::Continue));
    }

    #[test]
    fn test_validation_hook() {
        let mut hook = ValidationHook::new("validation".to_string()).expected_features(2);
        let context = make_context(HookPhase::BeforeStep, None);

        let valid_data = HookData::Features(array![[1.0, 2.0], [3.0, 4.0]]);
        let result = hook.execute(&context, Some(&valid_data)).unwrap();
        assert!(matches!(result, HookResult::Continue));

        let invalid_data = HookData::Features(array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let result = hook.execute(&context, Some(&invalid_data)).unwrap();
        assert!(matches!(result, HookResult::Abort(_)));
    }

    #[test]
    fn test_validation_hook_rejects_nan_targets() {
        let mut hook = ValidationHook::new("validation".to_string());
        let context = make_context(HookPhase::BeforeStep, None);

        let bad = HookData::Targets(array![1.0, Float::NAN]);
        let result = hook.execute(&context, Some(&bad)).unwrap();
        assert!(matches!(result, HookResult::Abort(_)));
        assert!(!hook.should_execute(HookPhase::AfterStep));
    }

    #[test]
    fn test_performance_hook() {
        let mut hook =
            PerformanceHook::new("perf".to_string()).alert_threshold(Duration::from_millis(1));

        let mut context = make_context(HookPhase::AfterStep, None);
        // `Instant - Duration` panics if the monotonic clock is younger than the
        // offset, so fall back to "now" in that (rare) case.
        context.start_time = Instant::now()
            .checked_sub(Duration::from_millis(10))
            .unwrap_or_else(Instant::now);

        let result = hook.execute(&context, None).unwrap();
        assert!(matches!(result, HookResult::Continue));
    }

    #[test]
    fn test_hook_phases() {
        let hook = LoggingHook::new("test".to_string(), LogLevel::Info);
        assert!(hook.should_execute(HookPhase::BeforeExecution));
        assert!(hook.should_execute(HookPhase::AfterStep));
    }

    #[test]
    fn test_execution_context() {
        let manager = HookManager::new();
        let context = manager.create_context("test_id".to_string(), 5);

        assert_eq!(context.execution_id, "test_id");
        assert_eq!(context.total_steps, 5);
        assert!(context.step_name.is_none());
    }

    #[test]
    fn test_custom_hook_builder_requires_execute_fn() {
        assert!(CustomHookBuilder::new("custom".to_string()).build().is_err());
    }

    #[test]
    fn test_hook_data_variants() {
        let features = HookData::Features(array![[1.0, 2.0], [3.0, 4.0]]);
        let targets = HookData::Targets(array![1.0, 2.0]);
        let predictions = HookData::Predictions(array![1.1, 2.1]);

        match features {
            HookData::Features(arr) => assert_eq!(arr.shape(), &[2, 2]),
            _ => panic!("Wrong variant"),
        }

        match targets {
            HookData::Targets(arr) => assert_eq!(arr.len(), 2),
            _ => panic!("Wrong variant"),
        }

        match predictions {
            HookData::Predictions(arr) => assert_eq!(arr.len(), 2),
            _ => panic!("Wrong variant"),
        }
    }
}
834
/// Declares that a hook depends on another hook having executed.
#[derive(Debug, Clone)]
pub struct HookDependency {
    /// Name of the hook that must already have run.
    pub hook_name: String,
    /// Whether the dependency is strict.
    /// NOTE(review): not consulted by `DependentExecutionHook::dependencies_satisfied`.
    pub strict: bool,
    /// Minimum priority of the dependency, if any.
    /// NOTE(review): not consulted by `DependentExecutionHook::dependencies_satisfied`.
    pub min_priority: Option<i32>,
}
845
846pub trait DependentExecutionHook: ExecutionHook {
848 fn dependencies(&self) -> Vec<HookDependency> {
850 Vec::new()
851 }
852
853 fn dependencies_satisfied(&self, executed_hooks: &[String]) -> bool {
855 self.dependencies()
856 .iter()
857 .all(|dep| executed_hooks.contains(&dep.hook_name))
858 }
859}
860
861pub trait AsyncExecutionHook: Send + Sync + Debug {
864 fn execute_async(
865 &mut self,
866 context: &ExecutionContext,
867 data: Option<&HookData>,
868 ) -> SklResult<HookResult>;
869
870 fn name(&self) -> &str;
871
872 fn priority(&self) -> i32 {
873 0
874 }
875
876 fn should_execute(&self, phase: HookPhase) -> bool;
878
879 fn timeout(&self) -> Option<Duration> {
881 None
882 }
883}
884
885#[derive(Debug, Clone)]
887pub struct ResourceManagerHook {
888 name: String,
889 max_memory: Option<usize>,
890 max_execution_time: Option<Duration>,
891 cpu_limit: Option<f64>, resource_usage: Arc<Mutex<ResourceUsage>>,
893}
894
895#[derive(Debug, Clone, Default)]
896pub struct ResourceUsage {
897 pub current_memory: usize,
898 pub peak_memory: usize,
899 pub cpu_usage: f64,
900 pub execution_time: Duration,
901 pub violations: Vec<ResourceViolation>,
902}
903
904#[derive(Debug, Clone)]
905pub struct ResourceViolation {
906 pub violation_type: ViolationType,
907 pub timestamp: Instant,
908 pub details: String,
909}
910
/// Kind of resource limit that was breached.
#[derive(Debug, Clone)]
pub enum ViolationType {
    /// Estimated memory exceeded `max_memory`.
    MemoryLimit,
    /// Elapsed time exceeded `max_execution_time`.
    TimeLimit,
    /// CPU limit exceeded (not produced by `ResourceManagerHook::check_limits`).
    CpuLimit,
}
920
921impl ResourceManagerHook {
922 #[must_use]
924 pub fn new(name: String) -> Self {
925 Self {
926 name,
927 max_memory: None,
928 max_execution_time: None,
929 cpu_limit: None,
930 resource_usage: Arc::new(Mutex::new(ResourceUsage::default())),
931 }
932 }
933
934 #[must_use]
936 pub fn max_memory(mut self, limit: usize) -> Self {
937 self.max_memory = Some(limit);
938 self
939 }
940
941 #[must_use]
943 pub fn max_execution_time(mut self, limit: Duration) -> Self {
944 self.max_execution_time = Some(limit);
945 self
946 }
947
948 #[must_use]
950 pub fn cpu_limit(mut self, limit: f64) -> Self {
951 self.cpu_limit = Some(limit.min(1.0).max(0.0));
952 self
953 }
954
955 #[must_use]
957 pub fn get_usage(&self) -> ResourceUsage {
958 self.resource_usage.lock().unwrap().clone()
959 }
960
961 fn check_limits(&self, context: &ExecutionContext) -> SklResult<HookResult> {
963 let mut usage = self.resource_usage.lock().unwrap();
964
965 if let Some(time_limit) = self.max_execution_time {
967 let elapsed = context.start_time.elapsed();
968 usage.execution_time = elapsed;
969
970 if elapsed > time_limit {
971 let violation = ResourceViolation {
972 violation_type: ViolationType::TimeLimit,
973 timestamp: Instant::now(),
974 details: format!(
975 "Execution time {} exceeded limit {:?}",
976 elapsed.as_secs_f64(),
977 time_limit
978 ),
979 };
980 usage.violations.push(violation);
981 return Ok(HookResult::Abort(format!(
982 "[{}] Execution time limit exceeded: {:?} > {:?}",
983 self.name, elapsed, time_limit
984 )));
985 }
986 }
987
988 if let Some(memory_limit) = self.max_memory {
990 let estimated_memory = context
991 .metrics
992 .data_shapes
993 .iter()
994 .map(|(rows, cols)| rows * cols * std::mem::size_of::<Float>())
995 .sum::<usize>();
996
997 usage.current_memory = estimated_memory;
998 usage.peak_memory = usage.peak_memory.max(estimated_memory);
999
1000 if estimated_memory > memory_limit {
1001 let violation = ResourceViolation {
1002 violation_type: ViolationType::MemoryLimit,
1003 timestamp: Instant::now(),
1004 details: format!(
1005 "Memory usage {estimated_memory} exceeded limit {memory_limit}"
1006 ),
1007 };
1008 usage.violations.push(violation);
1009 return Ok(HookResult::Abort(format!(
1010 "[{}] Memory limit exceeded: {} bytes > {} bytes",
1011 self.name, estimated_memory, memory_limit
1012 )));
1013 }
1014 }
1015
1016 Ok(HookResult::Continue)
1017 }
1018}
1019
1020impl ExecutionHook for ResourceManagerHook {
1021 fn execute(
1022 &mut self,
1023 context: &ExecutionContext,
1024 _data: Option<&HookData>,
1025 ) -> SklResult<HookResult> {
1026 self.check_limits(context)
1027 }
1028
1029 fn name(&self) -> &str {
1030 &self.name
1031 }
1032
1033 fn priority(&self) -> i32 {
1034 1000 }
1036
1037 fn should_execute(&self, phase: HookPhase) -> bool {
1038 matches!(
1039 phase,
1040 HookPhase::BeforeStep | HookPhase::AfterStep | HookPhase::BeforeExecution
1041 )
1042 }
1043}
1044
1045#[derive(Debug, Clone)]
1047pub struct SecurityAuditHook {
1048 name: String,
1049 audit_log: Arc<Mutex<Vec<AuditEntry>>>,
1050 sensitive_operations: Vec<String>,
1051 require_authorization: bool,
1052}
1053
1054#[derive(Debug, Clone)]
1055pub struct AuditEntry {
1056 pub timestamp: Instant,
1057 pub execution_id: String,
1058 pub operation: String,
1059 pub user_id: Option<String>,
1060 pub data_summary: String,
1061 pub result: AuditResult,
1062}
1063
/// Outcome attached to an [`AuditEntry`].
#[derive(Debug, Clone)]
pub enum AuditResult {
    /// Operation was allowed and nothing unusual was seen.
    Success,
    /// Operation failed, with a reason.
    Failed(String),
    /// A sensitive operation was attempted without authorization.
    Unauthorized,
    /// Operation was allowed but looked unusual, with a reason.
    Suspicious(String),
}
1075
1076impl SecurityAuditHook {
1077 #[must_use]
1079 pub fn new(name: String) -> Self {
1080 Self {
1081 name,
1082 audit_log: Arc::new(Mutex::new(Vec::new())),
1083 sensitive_operations: Vec::new(),
1084 require_authorization: false,
1085 }
1086 }
1087
1088 #[must_use]
1090 pub fn sensitive_operations(mut self, operations: Vec<String>) -> Self {
1091 self.sensitive_operations = operations;
1092 self
1093 }
1094
1095 #[must_use]
1097 pub fn require_authorization(mut self, require: bool) -> Self {
1098 self.require_authorization = require;
1099 self
1100 }
1101
1102 #[must_use]
1104 pub fn get_audit_log(&self) -> Vec<AuditEntry> {
1105 self.audit_log.lock().unwrap().clone()
1106 }
1107
1108 fn is_sensitive_operation(&self, context: &ExecutionContext) -> bool {
1110 if let Some(step_name) = &context.step_name {
1111 self.sensitive_operations
1112 .iter()
1113 .any(|op| step_name.contains(op))
1114 } else {
1115 false
1116 }
1117 }
1118
1119 fn create_audit_entry(
1121 &self,
1122 context: &ExecutionContext,
1123 result: AuditResult,
1124 data_summary: String,
1125 ) -> AuditEntry {
1126 AuditEntry {
1128 timestamp: Instant::now(),
1129 execution_id: context.execution_id.clone(),
1130 operation: context
1131 .step_name
1132 .clone()
1133 .unwrap_or_else(|| "unknown".to_string()),
1134 user_id: context.metadata.get("user_id").cloned(),
1135 data_summary,
1136 result,
1137 }
1138 }
1139}
1140
1141impl ExecutionHook for SecurityAuditHook {
1142 fn execute(
1143 &mut self,
1144 context: &ExecutionContext,
1145 data: Option<&HookData>,
1146 ) -> SklResult<HookResult> {
1147 let is_sensitive = self.is_sensitive_operation(context);
1148
1149 let data_summary = match data {
1151 Some(HookData::Features(arr)) => format!("Features: {}x{}", arr.nrows(), arr.ncols()),
1152 Some(HookData::Targets(arr)) => format!("Targets: {}", arr.len()),
1153 Some(HookData::Predictions(arr)) => format!("Predictions: {}", arr.len()),
1154 Some(HookData::Custom(_)) => "Custom data".to_string(),
1155 None => "No data".to_string(),
1156 };
1157
1158 if is_sensitive && self.require_authorization {
1160 let has_auth = context
1161 .metadata
1162 .get("authorized")
1163 .is_some_and(|v| v == "true");
1164
1165 if !has_auth {
1166 let audit_entry =
1167 self.create_audit_entry(context, AuditResult::Unauthorized, data_summary);
1168 self.audit_log.lock().unwrap().push(audit_entry);
1169
1170 return Ok(HookResult::Abort(format!(
1171 "[{}] Unauthorized access to sensitive operation: {}",
1172 self.name,
1173 context.step_name.as_deref().unwrap_or("unknown")
1174 )));
1175 }
1176 }
1177
1178 if is_sensitive || !self.sensitive_operations.is_empty() {
1180 let result = if is_sensitive {
1181 if data_summary.contains("empty") {
1183 AuditResult::Suspicious("Empty data in sensitive operation".to_string())
1184 } else {
1185 AuditResult::Success
1186 }
1187 } else {
1188 AuditResult::Success
1189 };
1190
1191 let audit_entry = self.create_audit_entry(context, result, data_summary);
1192 self.audit_log.lock().unwrap().push(audit_entry);
1193 }
1194
1195 Ok(HookResult::Continue)
1196 }
1197
1198 fn name(&self) -> &str {
1199 &self.name
1200 }
1201
1202 fn priority(&self) -> i32 {
1203 900 }
1205
1206 fn should_execute(&self, phase: HookPhase) -> bool {
1207 matches!(
1208 phase,
1209 HookPhase::BeforeStep | HookPhase::BeforePredict | HookPhase::BeforeTransform
1210 )
1211 }
1212}
1213
1214#[derive(Debug, Clone)]
1216pub struct ErrorRecoveryHook {
1217 name: String,
1218 retry_count: usize,
1219 retry_delay: Duration,
1220 fallback_strategies: Vec<FallbackStrategy>,
1221 error_history: Arc<Mutex<Vec<ErrorRecord>>>,
1222}
1223
/// One error handled by `ErrorRecoveryHook`.
#[derive(Debug, Clone)]
pub struct ErrorRecord {
    /// When the record was created.
    pub timestamp: Instant,
    /// Execution in which the error occurred.
    pub execution_id: String,
    /// Category label of the error.
    pub error_type: String,
    /// Error text (from `metadata["error"]`, or a placeholder).
    pub error_message: String,
    /// Whether a recovery strategy was attempted.
    pub recovery_attempted: bool,
    /// Whether the attempted recovery was considered successful.
    pub recovery_successful: bool,
}
1233
/// Recovery action tried by `ErrorRecoveryHook`.
#[derive(Debug, Clone)]
pub enum FallbackStrategy {
    /// Sleep for the given duration, then continue.
    RetryWithDelay(Duration),
    /// Continue with placeholder data.
    UseDefaultValues,
    /// Skip the current step.
    SkipStep,
    /// Abort the execution.
    AbortExecution,
    /// Named custom recovery; only recorded, then the next strategy is tried.
    CustomRecovery(String),
}
1247
1248impl ErrorRecoveryHook {
1249 #[must_use]
1251 pub fn new(name: String) -> Self {
1252 Self {
1253 name,
1254 retry_count: 3,
1255 retry_delay: Duration::from_millis(100),
1256 fallback_strategies: vec![
1257 FallbackStrategy::RetryWithDelay(Duration::from_millis(100)),
1258 FallbackStrategy::UseDefaultValues,
1259 FallbackStrategy::SkipStep,
1260 ],
1261 error_history: Arc::new(Mutex::new(Vec::new())),
1262 }
1263 }
1264
1265 #[must_use]
1267 pub fn retry_config(mut self, count: usize, delay: Duration) -> Self {
1268 self.retry_count = count;
1269 self.retry_delay = delay;
1270 self
1271 }
1272
1273 #[must_use]
1275 pub fn fallback_strategies(mut self, strategies: Vec<FallbackStrategy>) -> Self {
1276 self.fallback_strategies = strategies;
1277 self
1278 }
1279
1280 #[must_use]
1282 pub fn get_error_history(&self) -> Vec<ErrorRecord> {
1283 self.error_history.lock().unwrap().clone()
1284 }
1285
1286 fn record_error(
1288 &self,
1289 context: &ExecutionContext,
1290 error: &str,
1291 recovery_attempted: bool,
1292 recovery_successful: bool,
1293 ) {
1294 let record = ErrorRecord {
1295 timestamp: Instant::now(),
1296 execution_id: context.execution_id.clone(),
1297 error_type: "execution_error".to_string(),
1298 error_message: error.to_string(),
1299 recovery_attempted,
1300 recovery_successful,
1301 };
1302
1303 self.error_history.lock().unwrap().push(record);
1304 }
1305}
1306
1307impl ExecutionHook for ErrorRecoveryHook {
1308 fn execute(
1309 &mut self,
1310 context: &ExecutionContext,
1311 _data: Option<&HookData>,
1312 ) -> SklResult<HookResult> {
1313 if matches!(context.phase, HookPhase::OnError) {
1315 let error_msg = context
1317 .metadata
1318 .get("error")
1319 .unwrap_or(&"Unknown error".to_string())
1320 .clone();
1321
1322 for strategy in &self.fallback_strategies {
1324 match strategy {
1325 FallbackStrategy::RetryWithDelay(delay) => {
1326 self.record_error(context, &error_msg, true, false);
1327 std::thread::sleep(*delay);
1328 println!("[{}] Retrying after delay: {:?}", self.name, delay);
1330 return Ok(HookResult::Continue);
1331 }
1332 FallbackStrategy::UseDefaultValues => {
1333 self.record_error(context, &error_msg, true, true);
1334 println!("[{}] Using default values for recovery", self.name);
1335 return Ok(HookResult::ContinueWithData(HookData::Features(
1337 Array2::zeros((1, 1)),
1338 )));
1339 }
1340 FallbackStrategy::SkipStep => {
1341 self.record_error(context, &error_msg, true, true);
1342 println!("[{}] Skipping step for recovery", self.name);
1343 return Ok(HookResult::Skip);
1344 }
1345 FallbackStrategy::AbortExecution => {
1346 self.record_error(context, &error_msg, false, false);
1347 return Ok(HookResult::Abort(format!(
1348 "[{}] Unrecoverable error: {}",
1349 self.name, error_msg
1350 )));
1351 }
1352 FallbackStrategy::CustomRecovery(name) => {
1353 println!("[{}] Attempting custom recovery: {}", self.name, name);
1354 self.record_error(context, &error_msg, true, false);
1356 }
1357 }
1358 }
1359 }
1360
1361 Ok(HookResult::Continue)
1362 }
1363
1364 fn name(&self) -> &str {
1365 &self.name
1366 }
1367
1368 fn priority(&self) -> i32 {
1369 500 }
1371
1372 fn should_execute(&self, phase: HookPhase) -> bool {
1373 matches!(phase, HookPhase::OnError)
1374 }
1375}
1376
1377#[derive(Debug)]
1379pub struct HookComposition {
1380 name: String,
1381 hooks: Vec<Box<dyn ExecutionHook>>,
1382 execution_strategy: CompositionStrategy,
1383}
1384
/// How `HookComposition` runs its member hooks.
#[derive(Debug, Clone)]
pub enum CompositionStrategy {
    /// Run applicable hooks in order; return the first non-`Continue` result.
    Sequential,
    /// Run every applicable hook (one after another, not concurrently), then
    /// return the first non-`Continue` result.
    Parallel,
    /// Currently behaves exactly like `Sequential`.
    FirstMatch,
    /// Run every applicable hook and return `Continue`, discarding their results.
    Aggregate,
}
1396
1397impl HookComposition {
1398 #[must_use]
1400 pub fn new(name: String, strategy: CompositionStrategy) -> Self {
1401 Self {
1402 name,
1403 hooks: Vec::new(),
1404 execution_strategy: strategy,
1405 }
1406 }
1407
1408 pub fn add_hook(&mut self, hook: Box<dyn ExecutionHook>) {
1410 self.hooks.push(hook);
1411 self.hooks.sort_by(|a, b| b.priority().cmp(&a.priority()));
1413 }
1414}
1415
1416impl ExecutionHook for HookComposition {
1417 fn execute(
1418 &mut self,
1419 context: &ExecutionContext,
1420 data: Option<&HookData>,
1421 ) -> SklResult<HookResult> {
1422 match self.execution_strategy {
1423 CompositionStrategy::Sequential => {
1424 for hook in &mut self.hooks {
1425 if hook.should_execute(context.phase) {
1426 let result = hook.execute(context, data)?;
1427 if !matches!(result, HookResult::Continue) {
1428 return Ok(result);
1429 }
1430 }
1431 }
1432 Ok(HookResult::Continue)
1433 }
1434 CompositionStrategy::FirstMatch => {
1435 for hook in &mut self.hooks {
1436 if hook.should_execute(context.phase) {
1437 let result = hook.execute(context, data)?;
1438 if !matches!(result, HookResult::Continue) {
1439 return Ok(result);
1440 }
1441 }
1442 }
1443 Ok(HookResult::Continue)
1444 }
1445 CompositionStrategy::Parallel => {
1446 let mut results = Vec::new();
1448 for hook in &mut self.hooks {
1449 if hook.should_execute(context.phase) {
1450 results.push(hook.execute(context, data)?);
1451 }
1452 }
1453
1454 for result in results {
1456 if !matches!(result, HookResult::Continue) {
1457 return Ok(result);
1458 }
1459 }
1460 Ok(HookResult::Continue)
1461 }
1462 CompositionStrategy::Aggregate => {
1463 for hook in &mut self.hooks {
1465 if hook.should_execute(context.phase) {
1466 let _result = hook.execute(context, data)?;
1467 }
1469 }
1470 Ok(HookResult::Continue)
1471 }
1472 }
1473 }
1474
1475 fn name(&self) -> &str {
1476 &self.name
1477 }
1478
1479 fn priority(&self) -> i32 {
1480 self.hooks.iter().map(|h| h.priority()).max().unwrap_or(0)
1482 }
1483
1484 fn should_execute(&self, phase: HookPhase) -> bool {
1485 self.hooks.iter().any(|h| h.should_execute(phase))
1486 }
1487}