1use scirs2_core::ndarray::Array2;
9use sklears_core::{
10 error::{Result as SklResult, SklearsError},
11 traits::Estimator,
12 types::Float,
13};
14use std::collections::HashMap;
15use std::hash::Hash;
16use std::sync::{Arc, Mutex};
17use std::time::{Duration, Instant, SystemTime};
18
/// Per-request state passed mutably to every middleware hook in a [`MiddlewareChain`].
#[derive(Debug)]
pub struct MiddlewareContext {
    /// Identifier of the request; hashed into the cache key when
    /// `CacheKeyStrategy::HashInputAndContext` is selected.
    pub request_id: String,
    /// Timestamp associated with the request, supplied by whoever builds the context.
    pub timestamp: SystemTime,
    /// Free-form string metadata shared between middlewares
    /// (`CachingMiddleware` stores its `"cache_key"` entry here).
    pub metadata: HashMap<String, String>,
    /// Identity of the caller; `None` means the request is anonymous.
    pub user_info: Option<UserInfo>,
    /// Lifecycle state of the request.
    pub state: ContextState,
    /// Timing and resource metrics collected for the request.
    pub metrics: ExecutionMetrics,
    /// Arbitrary type-erased values shared between middlewares, keyed by name.
    pub custom_data: HashMap<String, Box<dyn std::any::Any + Send + Sync>>,
}

/// Identity and authorization data for an authenticated caller.
///
/// NOTE(review): the derived `Debug` prints `session_token` (and any secret held in
/// `auth_method`) verbatim — avoid logging values of this type with `{:?}`.
#[derive(Debug, Clone)]
pub struct UserInfo {
    /// Unique identifier of the user.
    pub user_id: String,
    /// Role names; compared against `AccessPolicy::allowed_roles` during authorization.
    pub roles: Vec<String>,
    /// Permission names; compared against `AccessPolicy::required_permissions`.
    pub permissions: Vec<String>,
    /// Session token associated with the user, if any.
    pub session_token: Option<String>,
    /// Mechanism the user authenticated with.
    pub auth_method: AuthenticationMethod,
}

/// How a [`UserInfo`] was authenticated, together with the credential material used.
#[derive(Debug, Clone)]
pub enum AuthenticationMethod {
    /// No authentication was performed.
    None,
    /// Authenticated with an API key.
    ApiKey {
        key: String,
    },
    /// Authenticated with a bearer token.
    BearerToken {
        token: String,
    },
    /// Authenticated with a username/password pair.
    BasicAuth {
        username: String,
        password: String,
    },
    /// Authenticated through an OAuth `provider`.
    OAuth {
        provider: String,
        token: String,
    },
    /// Authenticated with a client certificate, identified by its fingerprint.
    Certificate {
        cert_fingerprint: String,
    },
    /// Any other mechanism, named by `method`.
    Custom {
        method: String,
    },
}

/// Lifecycle state of a request moving through a [`MiddlewareChain`].
#[derive(Debug, Clone)]
pub enum ContextState {
    /// The context has been created but processing has not started.
    Initializing,
    /// The request is being processed.
    Processing,
    /// Processing finished; `MiddlewareChain::execute` sets this on success.
    Completed,
    /// Processing failed; `message` describes the failure.
    Error { message: String },
    /// Processing was cancelled.
    Cancelled,
}

/// Timing and resource measurements for a single request.
#[derive(Debug, Clone)]
pub struct ExecutionMetrics {
    /// Instant at which the request started, supplied when the context is built.
    pub start_time: Instant,
    /// Instant at which the request finished; set by `MiddlewareChain::execute`.
    pub end_time: Option<Instant>,
    /// Total execution time; set by `MiddlewareChain::execute`.
    pub duration: Option<Duration>,
    /// Memory used by the request — units not fixed here, presumably bytes (TODO confirm).
    pub memory_usage: u64,
    /// CPU usage — scale (fraction vs. percent) not fixed here; TODO confirm.
    pub cpu_usage: f64,
    /// Throughput — units not fixed here; TODO confirm.
    pub throughput: f64,
    /// Number of errors recorded for the request.
    pub error_count: usize,
    /// Additional named numeric metrics.
    pub custom_metrics: HashMap<String, f64>,
}
120
121pub trait PipelineMiddleware: Send + Sync {
123 fn name(&self) -> &str;
125
126 fn before_process(
128 &self,
129 context: &mut MiddlewareContext,
130 input: &Array2<Float>,
131 ) -> SklResult<()>;
132
133 fn after_process(
135 &self,
136 context: &mut MiddlewareContext,
137 output: &Array2<Float>,
138 ) -> SklResult<()>;
139
140 fn on_error(
142 &self,
143 context: &mut MiddlewareContext,
144 error: &SklearsError,
145 ) -> SklResult<ErrorAction>;
146
147 fn priority(&self) -> i32 {
149 100
150 }
151
152 fn should_execute(&self, context: &MiddlewareContext) -> bool {
154 true
155 }
156}
157
/// Decision returned by [`PipelineMiddleware::on_error`] telling the chain how to react
/// to a failed middleware hook.
#[derive(Debug, Clone)]
pub enum ErrorAction {
    /// Ignore the error and carry on with the next middleware.
    Continue,
    /// Retry after waiting `delay`.
    ///
    /// NOTE(review): `max_attempts` is the intended attempt budget; confirm that the
    /// chain implementation in use actually enforces it.
    Retry {
        max_attempts: usize,
        delay: Duration,
    },
    /// Stop processing and return the error to the caller.
    Abort,
    /// Short-circuit with `fallback_data` as the result (honoured by
    /// `MiddlewareChain::execute` for errors raised before processing).
    Fallback { fallback_data: Array2<Float> },
}
173
/// Ordered collection of middlewares wrapped around a processing function.
pub struct MiddlewareChain {
    /// Registered middlewares, kept sorted by ascending priority by `add_middleware`.
    middlewares: Vec<Box<dyn PipelineMiddleware>>,
    /// Chain-wide settings.
    config: MiddlewareChainConfig,
    /// Aggregate request statistics updated by `execute`.
    stats: MiddlewareStats,
}

/// Settings for a [`MiddlewareChain`].
///
/// NOTE(review): `MiddlewareChain` in this file does not read any of these fields yet;
/// the descriptions below reflect intent only.
#[derive(Debug, Clone)]
pub struct MiddlewareChainConfig {
    /// Intended to allow middlewares to run concurrently.
    pub parallel_execution: bool,
    /// Intended time budget for a single middleware hook.
    pub timeout_per_middleware: Duration,
    /// Intended time budget for a whole chain execution.
    pub global_timeout: Duration,
    /// Intended to keep processing after a middleware error.
    pub continue_on_error: bool,
    /// Intended to enable verbose logging.
    pub detailed_logging: bool,
}

/// Aggregate statistics collected by a [`MiddlewareChain`].
#[derive(Debug, Clone)]
pub struct MiddlewareStats {
    /// Number of calls to `MiddlewareChain::execute`.
    pub total_requests: u64,
    /// Number of requests counted as successful by `execute`.
    pub successful_requests: u64,
    /// Number of requests counted as failed by `execute`.
    pub failed_requests: u64,
    /// Running average execution time maintained by `execute`.
    pub average_execution_time: Duration,
    /// Per-middleware metrics keyed by middleware name; not populated by
    /// `MiddlewareChain` in this file.
    pub middleware_stats: HashMap<String, MiddlewareMetrics>,
}

/// Execution metrics for a single middleware.
#[derive(Debug, Clone)]
pub struct MiddlewareMetrics {
    /// Number of times the middleware ran.
    pub execution_count: u64,
    /// Cumulative time spent in the middleware.
    pub total_execution_time: Duration,
    /// Mean time per execution.
    pub average_execution_time: Duration,
    /// Number of failed executions.
    pub error_count: u64,
    /// Share of successful executions — scale (0..1 vs. percent) not fixed here; TODO confirm.
    pub success_rate: f64,
}
228
/// Pipeline middleware that rejects unauthenticated requests and can resolve
/// credentials through registered [`AuthenticationProvider`]s.
pub struct AuthenticationMiddleware {
    /// Registered providers keyed by [`AuthenticationProvider::name`].
    providers: HashMap<String, Box<dyn AuthenticationProvider>>,
    /// Authentication settings.
    config: AuthenticationConfig,
}

/// Backend able to turn credentials into a [`UserInfo`].
pub trait AuthenticationProvider: Send + Sync {
    /// Provider name; used as the registration key by `AuthenticationMiddleware::add_provider`.
    fn name(&self) -> &str;

    /// Verifies `credentials` and returns the resulting identity.
    ///
    /// # Errors
    /// An `Err` makes `AuthenticationMiddleware::authenticate` fall through to the
    /// next registered provider.
    fn authenticate(&self, credentials: &AuthenticationCredentials) -> SklResult<UserInfo>;

    /// Checks whether `session_token` still refers to a valid session.
    fn validate_session(&self, session_token: &str) -> SklResult<bool>;

    /// Exchanges `refresh_token` for a new token.
    fn refresh_token(&self, refresh_token: &str) -> SklResult<String>;
}

/// Raw credentials presented by a caller.
///
/// NOTE(review): the derived `Debug` prints passwords, tokens and keys in clear text —
/// do not log values of this type with `{:?}`.
#[derive(Debug, Clone)]
pub enum AuthenticationCredentials {
    /// API key.
    ApiKey { key: String },
    /// Bearer token.
    BearerToken { token: String },
    /// Username/password pair.
    BasicAuth { username: String, password: String },
    /// OAuth token issued by `provider`.
    OAuth { provider: String, token: String },
    /// Client certificate bytes — encoding (DER/PEM) not fixed here; TODO confirm.
    Certificate { certificate: Vec<u8> },
}

/// Settings for [`AuthenticationMiddleware`].
///
/// NOTE(review): only `allow_anonymous` is read in this file.
#[derive(Debug, Clone)]
pub struct AuthenticationConfig {
    /// Names of authentication methods callers are expected to use.
    pub required_methods: Vec<String>,
    /// When `false`, requests without `user_info` are rejected in `before_process`.
    pub allow_anonymous: bool,
    /// Intended lifetime of a session.
    pub session_timeout: Duration,
    /// Intended remaining-lifetime threshold at which tokens should be refreshed.
    pub token_refresh_threshold: Duration,
    /// Intended number of failed attempts before lockout.
    pub max_failed_attempts: usize,
    /// Intended duration of a lockout.
    pub lockout_duration: Duration,
}
283
/// Pipeline middleware that grants or denies access using ordered [`AccessPolicy`]s.
pub struct AuthorizationMiddleware {
    /// Policies in insertion order; `authorize` applies the first one that matches.
    policies: Vec<AccessPolicy>,
    /// Role/permission catalogue; not consulted by `authorize` in this file.
    rbac: RoleBasedAccessControl,
    /// Authorization settings.
    config: AuthorizationConfig,
}

/// A single access rule.
#[derive(Debug, Clone)]
pub struct AccessPolicy {
    /// Human-readable policy name.
    pub name: String,
    /// Pattern matched against the requested resource; see
    /// `AuthorizationMiddleware::policy_matches` for the supported syntax.
    pub resource_pattern: String,
    /// The policy applies to users holding any of these permissions...
    pub required_permissions: Vec<String>,
    /// ...or any of these roles.
    pub allowed_roles: Vec<String>,
    /// Extra conditions; not evaluated by `AuthorizationMiddleware::authorize` in this file.
    pub conditions: Vec<AccessCondition>,
    /// Outcome when the policy applies.
    pub effect: PolicyEffect,
}

/// Additional constraint attached to an [`AccessPolicy`].
#[derive(Debug, Clone)]
pub enum AccessCondition {
    /// Time window — string format of `start`/`end` not fixed here; TODO confirm.
    TimeWindow { start: String, end: String },
    /// Caller IP must fall inside the CIDR block `cidr`.
    IpRange { cidr: String },
    /// A user attribute must equal `value`.
    UserAttribute { attribute: String, value: String },
    /// A resource attribute must equal `value`.
    ResourceAttribute { attribute: String, value: String },
    /// Free-form condition expression.
    Custom { condition: String },
}

/// Outcome of an applicable policy.
///
/// `AuthorizationMiddleware::authorize` treats every effect other than `Allow` as a denial.
#[derive(Debug, Clone)]
pub enum PolicyEffect {
    /// Grant access.
    Allow,
    /// Refuse access.
    Deny,
    /// Access depends on the policy's conditions.
    Conditional,
}

/// Role-based access control catalogue.
#[derive(Debug, Clone)]
pub struct RoleBasedAccessControl {
    /// Roles keyed by name.
    pub roles: HashMap<String, Role>,
    /// Permissions keyed by name.
    pub permissions: HashMap<String, Permission>,
    /// Presumably maps a role to the roles it inherits from — TODO confirm direction.
    pub role_hierarchy: HashMap<String, Vec<String>>,
}

/// A named role and the permissions it grants.
#[derive(Debug, Clone)]
pub struct Role {
    /// Role name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Names of permissions granted by this role.
    pub permissions: Vec<String>,
    /// Free-form metadata.
    pub metadata: HashMap<String, String>,
}

/// A named permission over a resource type.
#[derive(Debug, Clone)]
pub struct Permission {
    /// Permission name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Kind of resource the permission applies to.
    pub resource_type: String,
    /// Actions the permission allows.
    pub actions: Vec<String>,
}

/// Settings for [`AuthorizationMiddleware`].
#[derive(Debug, Clone)]
pub struct AuthorizationConfig {
    /// Effect used by `authorize` when no policy applies; only `Allow` grants access.
    pub default_effect: PolicyEffect,
    /// Intended to enable role inheritance; not read in this file.
    pub enable_role_inheritance: bool,
    /// Intended to cache decisions; not read in this file.
    pub cache_decisions: bool,
    /// Intended lifetime of cached decisions; not read in this file.
    pub cache_ttl: Duration,
}
386
/// Middleware intended to validate pipeline inputs and outputs.
///
/// NOTE(review): this file declares no `PipelineMiddleware` impl for this type.
pub struct ValidationMiddleware {
    /// Validators applied to pipeline input.
    input_validators: Vec<Box<dyn InputValidator>>,
    /// Validators applied to pipeline output.
    output_validators: Vec<Box<dyn OutputValidator>>,
    /// Validation settings.
    config: ValidationConfig,
}

/// Check run against pipeline input.
pub trait InputValidator: Send + Sync {
    /// Validator name.
    fn name(&self) -> &str;

    /// Validates `input` in the given request context.
    fn validate(
        &self,
        input: &Array2<Float>,
        context: &MiddlewareContext,
    ) -> SklResult<ValidationResult>;

    /// Severity attached to this validator.
    fn severity(&self) -> ValidationSeverity;
}

/// Check run against pipeline output.
pub trait OutputValidator: Send + Sync {
    /// Validator name.
    fn name(&self) -> &str;

    /// Validates `output` in the given request context.
    fn validate(
        &self,
        output: &Array2<Float>,
        context: &MiddlewareContext,
    ) -> SklResult<ValidationResult>;

    /// Severity attached to this validator.
    fn severity(&self) -> ValidationSeverity;
}

/// Outcome of a validation run.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Whether the data passed validation.
    pub valid: bool,
    /// Findings reported by the validator.
    pub messages: Vec<ValidationMessage>,
    /// Suggested corrections.
    pub corrections: Vec<ValidationCorrection>,
}

/// A single validation finding.
#[derive(Debug, Clone)]
pub struct ValidationMessage {
    /// Human-readable description.
    pub message: String,
    /// How serious the finding is.
    pub severity: ValidationSeverity,
    /// Field the finding refers to, if any.
    pub field: Option<String>,
    /// Machine-readable code, if any.
    pub code: Option<String>,
}

/// Severity levels for validation findings, listed from least to most severe.
#[derive(Debug, Clone)]
pub enum ValidationSeverity {
    /// Informational only.
    Info,
    /// Suspicious but acceptable.
    Warning,
    /// Invalid data.
    Error,
    /// Invalid data that must stop processing.
    Critical,
}

/// A suggested fix for invalid data.
#[derive(Debug, Clone)]
pub struct ValidationCorrection {
    /// What the correction does.
    pub description: String,
    /// Corrected data, if the validator produced one.
    pub corrected_value: Option<Array2<Float>>,
    /// Confidence in the correction — presumably in 0..=1; TODO confirm.
    pub confidence: f64,
}

/// Settings for [`ValidationMiddleware`].
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Whether a failed validation should abort processing.
    pub fail_on_error: bool,
    /// Whether suggested corrections should be applied automatically.
    pub auto_correct: bool,
    /// Time budget for validation.
    pub timeout: Duration,
    /// Upper bound on the number of corrections to apply.
    pub max_corrections: usize,
}
489
/// Middleware intended to transform data before and after processing.
///
/// NOTE(review): this file declares no `PipelineMiddleware` impl for this type.
pub struct TransformationMiddleware {
    /// Transformers meant to run on input before processing.
    pre_transformations: Vec<Box<dyn DataTransformer>>,
    /// Transformers meant to run on output after processing.
    post_transformations: Vec<Box<dyn DataTransformer>>,
    /// Transformation settings.
    config: TransformationConfig,
}

/// A data-to-data transformation step.
pub trait DataTransformer: Send + Sync {
    /// Transformer name.
    fn name(&self) -> &str;

    /// Returns the transformed copy of `data`.
    fn transform(
        &self,
        data: &Array2<Float>,
        context: &MiddlewareContext,
    ) -> SklResult<Array2<Float>>;

    /// Whether this transformer should be applied to `data` in this context.
    fn should_transform(&self, data: &Array2<Float>, context: &MiddlewareContext) -> bool;

    /// Descriptive metadata about the transformation.
    fn get_metadata(&self) -> TransformationMetadata;
}

/// Descriptive metadata about a [`DataTransformer`].
#[derive(Debug, Clone)]
pub struct TransformationMetadata {
    /// Kind of transformation.
    pub transformation_type: String,
    /// Requirements the input must meet.
    pub input_requirements: Vec<String>,
    /// Properties of the produced output.
    pub output_characteristics: Vec<String>,
    /// Rough cost of applying the transformation.
    pub performance_impact: PerformanceImpact,
}

/// Rough cost categories, listed from cheapest to most expensive.
#[derive(Debug, Clone)]
pub enum PerformanceImpact {
    Minimal,
    Low,
    Medium,
    High,
    Extreme,
}

/// Settings for [`TransformationMiddleware`].
#[derive(Debug, Clone)]
pub struct TransformationConfig {
    /// Whether transformations may run concurrently.
    pub parallel_transformations: bool,
    /// Time budget for transformations.
    pub timeout: Duration,
    /// Whether transformation results should be cached.
    pub cache_results: bool,
    /// Lifetime of cached transformation results.
    pub cache_ttl: Duration,
}
559
/// Middleware that stores pipeline outputs keyed by a hash of the input.
pub struct CachingMiddleware {
    /// Shared cache map keyed by strings produced by `generate_cache_key`.
    cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
    /// Cache settings.
    config: CacheConfig,
    /// Hit/miss/eviction counters.
    stats: CacheStats,
}

/// A cached output together with bookkeeping used for expiry and eviction.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Cached output.
    pub data: Array2<Float>,
    /// Insertion time; `get` compares its age against `CacheConfig::ttl`.
    pub created_at: SystemTime,
    /// Time of the most recent hit; used by LRU eviction.
    pub last_accessed: SystemTime,
    /// Number of accesses, starting at 1 on insertion; used by LFU eviction.
    pub access_count: u64,
    /// Free-form metadata.
    pub metadata: HashMap<String, String>,
}

/// Settings for [`CachingMiddleware`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Number of entries at which inserts start evicting.
    pub max_size: usize,
    /// Maximum age of an entry before `get` treats it as expired.
    pub ttl: Duration,
    /// Which entry to evict when the cache is full.
    pub eviction_policy: EvictionPolicy,
    /// Intended to toggle statistics collection; not read in this file.
    pub enable_stats: bool,
    /// How cache keys are derived from a request.
    pub key_strategy: CacheKeyStrategy,
}

/// Victim-selection policies; how each one is realised is up to `CachingMiddleware`'s
/// eviction code.
#[derive(Debug, Clone)]
pub enum EvictionPolicy {
    /// Least recently used.
    LRU,
    /// Least frequently used.
    LFU,
    /// First in, first out.
    FIFO,
    /// Time-to-live based.
    TTL,
    /// Random.
    Random,
}

/// How cache keys are derived.
#[derive(Debug, Clone)]
pub enum CacheKeyStrategy {
    /// Hash of the input matrix only.
    HashInput,
    /// Hash of the input matrix plus the request id.
    HashInputAndContext,
    /// `generator` names a custom key generator.
    /// NOTE(review): `generate_cache_key` does not use `generator` in this file.
    Custom { generator: String },
}

/// Cache counters.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Successful lookups.
    pub hits: u64,
    /// Lookups that found nothing or an expired entry.
    pub misses: u64,
    /// `hits / (hits + misses)`, recomputed after each lookup.
    pub hit_ratio: f64,
    /// Total cached size; not maintained in this file.
    pub total_size: u64,
    /// Number of cached entries as last observed by the inherent methods.
    pub entry_count: usize,
    /// Number of entries removed by eviction or expiry.
    pub evictions: u64,
}
642
/// Middleware intended to collect metrics and raise alerts.
///
/// NOTE(review): this file declares no `PipelineMiddleware` impl for this type.
pub struct MonitoringMiddleware {
    /// Metric sources.
    collectors: Vec<Box<dyn MetricsCollector>>,
    /// Monitoring settings.
    config: MonitoringConfig,
    /// Alert rules, active alerts and notification channels.
    alert_manager: AlertManager,
}

/// Source of metrics for a request.
pub trait MetricsCollector: Send + Sync {
    /// Collector name.
    fn name(&self) -> &str;

    /// Produces metrics for `data` in the given request context.
    fn collect(&self, context: &MiddlewareContext, data: &Array2<Float>) -> SklResult<Vec<Metric>>;

    /// Names of the metrics this collector can produce.
    fn supported_metrics(&self) -> Vec<String>;
}

/// A single metric sample.
#[derive(Debug, Clone)]
pub struct Metric {
    /// Metric name.
    pub name: String,
    /// Sample value.
    pub value: f64,
    /// Kind of metric.
    pub metric_type: MetricType,
    /// When the sample was taken.
    pub timestamp: SystemTime,
    /// Dimension labels attached to the sample.
    pub labels: HashMap<String, String>,
}

/// Kinds of metric.
#[derive(Debug, Clone)]
pub enum MetricType {
    Counter,
    Gauge,
    Histogram,
    Summary,
    Timer,
}

/// Alert rules together with their current alerts and notification channels.
#[derive(Debug, Clone)]
pub struct AlertManager {
    /// Configured rules.
    pub rules: Vec<AlertRule>,
    /// Currently known alerts.
    pub active_alerts: Vec<Alert>,
    /// Where notifications are sent.
    pub channels: Vec<AlertChannel>,
}

/// Rule that raises an alert when a metric meets a condition.
#[derive(Debug, Clone)]
pub struct AlertRule {
    /// Rule name.
    pub name: String,
    /// Name of the watched metric.
    pub metric: String,
    /// Condition that triggers the alert.
    pub condition: AlertCondition,
    /// Severity of raised alerts.
    pub severity: AlertSeverity,
    /// How often the rule is evaluated.
    pub evaluation_interval: Duration,
}

/// Trigger conditions for an [`AlertRule`].
#[derive(Debug, Clone)]
pub enum AlertCondition {
    /// Compare against `value` using `operator` — operator syntax not fixed here; TODO confirm.
    Threshold { operator: String, value: f64 },
    /// Range bounded by `min` and `max` (trigger side — inside vs. outside — not fixed here).
    Range { min: f64, max: f64 },
    /// Relative change of `change_percent` within `time_window`.
    Rate {
        change_percent: f64,
        time_window: Duration,
    },
    /// Anomaly detection with the given `sensitivity`.
    Anomaly { sensitivity: f64 },
}

/// Alert severity levels, listed from least to most severe.
#[derive(Debug, Clone)]
pub enum AlertSeverity {
    Info,
    Warning,
    Critical,
    Emergency,
}

/// A raised alert.
#[derive(Debug, Clone)]
pub struct Alert {
    /// Alert identifier.
    pub id: String,
    /// Name of the rule that raised it.
    pub rule_name: String,
    /// Metric value at trigger time.
    pub current_value: f64,
    /// Human-readable description.
    pub message: String,
    /// When the alert was raised.
    pub triggered_at: SystemTime,
    /// Current lifecycle status.
    pub status: AlertStatus,
}

/// Lifecycle status of an [`Alert`].
#[derive(Debug, Clone)]
pub enum AlertStatus {
    Triggered,
    Acknowledged,
    Resolved,
    Suppressed,
}

/// Destinations for alert notifications.
#[derive(Debug, Clone)]
pub enum AlertChannel {
    /// E-mail recipients.
    Email { addresses: Vec<String> },
    /// HTTP webhook endpoint.
    Webhook { url: String },
    /// Slack incoming webhook and target channel.
    Slack {
        webhook_url: String,
        channel: String,
    },
    /// Standard output/console.
    Console,
    /// Log file at `file_path`.
    Log { file_path: String },
}

/// Settings for [`MonitoringMiddleware`].
#[derive(Debug, Clone)]
pub struct MonitoringConfig {
    /// Whether metrics are collected in real time.
    pub real_time: bool,
    /// Interval between collections.
    pub collection_interval: Duration,
    /// How long collected metrics are kept.
    pub retention_period: Duration,
    /// Whether alert rules are evaluated.
    pub enable_alerting: bool,
    /// Interval between alert evaluations.
    pub alert_evaluation_interval: Duration,
}
812
813impl MiddlewareChain {
814 #[must_use]
816 pub fn new(config: MiddlewareChainConfig) -> Self {
817 Self {
818 middlewares: Vec::new(),
819 config,
820 stats: MiddlewareStats {
821 total_requests: 0,
822 successful_requests: 0,
823 failed_requests: 0,
824 average_execution_time: Duration::from_millis(0),
825 middleware_stats: HashMap::new(),
826 },
827 }
828 }
829
830 pub fn add_middleware(&mut self, middleware: Box<dyn PipelineMiddleware>) {
832 self.middlewares.push(middleware);
833 self.middlewares.sort_by_key(|m| m.priority());
834 }
835
836 pub fn execute(
838 &mut self,
839 context: &mut MiddlewareContext,
840 input: &Array2<Float>,
841 processor: &dyn Fn(&Array2<Float>) -> SklResult<Array2<Float>>,
842 ) -> SklResult<Array2<Float>> {
843 let start_time = Instant::now();
844 self.stats.total_requests += 1;
845
846 for middleware in &self.middlewares {
848 if middleware.should_execute(context) {
849 if let Err(e) = middleware.before_process(context, input) {
850 let action = middleware.on_error(context, &e)?;
851 match action {
852 ErrorAction::Continue => {}
853 ErrorAction::Abort => {
854 self.stats.failed_requests += 1;
855 return Err(e);
856 }
857 ErrorAction::Retry {
858 max_attempts,
859 delay,
860 } => {
861 std::thread::sleep(delay);
863 return self.execute(context, input, processor);
864 }
865 ErrorAction::Fallback { fallback_data } => {
866 return Ok(fallback_data);
867 }
868 }
869 }
870 }
871 }
872
873 let result = processor(input)?;
875
876 for middleware in &self.middlewares {
878 if middleware.should_execute(context) {
879 if let Err(e) = middleware.after_process(context, &result) {
880 let action = middleware.on_error(context, &e)?;
881 match action {
882 ErrorAction::Continue => {}
883 ErrorAction::Abort => {
884 self.stats.failed_requests += 1;
885 return Err(e);
886 }
887 _ => {}
888 }
889 }
890 }
891 }
892
893 let execution_time = start_time.elapsed();
895 self.stats.successful_requests += 1;
896 self.update_execution_stats(execution_time);
897
898 context.state = ContextState::Completed;
899 context.metrics.end_time = Some(Instant::now());
900 context.metrics.duration = Some(execution_time);
901
902 Ok(result)
903 }
904
905 #[must_use]
907 pub fn get_stats(&self) -> &MiddlewareStats {
908 &self.stats
909 }
910
911 fn update_execution_stats(&mut self, execution_time: Duration) {
913 let total_time = self.stats.average_execution_time.as_nanos() as f64
914 * (self.stats.total_requests - 1) as f64;
915 let new_avg_nanos =
916 (total_time + execution_time.as_nanos() as f64) / self.stats.total_requests as f64;
917 self.stats.average_execution_time = Duration::from_nanos(new_avg_nanos as u64);
918 }
919}
920
921impl AuthenticationMiddleware {
922 #[must_use]
924 pub fn new(config: AuthenticationConfig) -> Self {
925 Self {
926 providers: HashMap::new(),
927 config,
928 }
929 }
930
931 pub fn add_provider(&mut self, provider: Box<dyn AuthenticationProvider>) {
933 self.providers.insert(provider.name().to_string(), provider);
934 }
935
936 pub fn authenticate(&self, credentials: &AuthenticationCredentials) -> SklResult<UserInfo> {
938 for provider in self.providers.values() {
939 if let Ok(user_info) = provider.authenticate(credentials) {
940 return Ok(user_info);
941 }
942 }
943 Err(SklearsError::InvalidInput(
944 "Authentication failed".to_string(),
945 ))
946 }
947}
948
949impl PipelineMiddleware for AuthenticationMiddleware {
950 fn name(&self) -> &'static str {
951 "authentication"
952 }
953
954 fn before_process(
955 &self,
956 context: &mut MiddlewareContext,
957 _input: &Array2<Float>,
958 ) -> SklResult<()> {
959 if !self.config.allow_anonymous && context.user_info.is_none() {
960 return Err(SklearsError::InvalidInput(
961 "Authentication required".to_string(),
962 ));
963 }
964 Ok(())
965 }
966
967 fn after_process(
968 &self,
969 _context: &mut MiddlewareContext,
970 _output: &Array2<Float>,
971 ) -> SklResult<()> {
972 Ok(())
973 }
974
975 fn on_error(
976 &self,
977 _context: &mut MiddlewareContext,
978 _error: &SklearsError,
979 ) -> SklResult<ErrorAction> {
980 Ok(ErrorAction::Abort)
981 }
982
983 fn priority(&self) -> i32 {
984 10 }
986}
987
988impl AuthorizationMiddleware {
989 #[must_use]
991 pub fn new(config: AuthorizationConfig) -> Self {
992 Self {
993 policies: Vec::new(),
994 rbac: RoleBasedAccessControl {
995 roles: HashMap::new(),
996 permissions: HashMap::new(),
997 role_hierarchy: HashMap::new(),
998 },
999 config,
1000 }
1001 }
1002
1003 pub fn add_policy(&mut self, policy: AccessPolicy) {
1005 self.policies.push(policy);
1006 }
1007
1008 pub fn authorize(&self, user_info: &UserInfo, resource: &str, action: &str) -> SklResult<bool> {
1010 for policy in &self.policies {
1011 if self.policy_matches(policy, resource)
1012 && self.check_permissions(policy, user_info, action)
1013 {
1014 return Ok(policy.effect == PolicyEffect::Allow);
1015 }
1016 }
1017
1018 Ok(matches!(self.config.default_effect, PolicyEffect::Allow))
1020 }
1021
1022 fn policy_matches(&self, policy: &AccessPolicy, resource: &str) -> bool {
1024 policy.resource_pattern == "*" || policy.resource_pattern == resource
1026 }
1027
1028 fn check_permissions(&self, policy: &AccessPolicy, user_info: &UserInfo, action: &str) -> bool {
1030 for role in &user_info.roles {
1032 if policy.allowed_roles.contains(role) {
1033 return true;
1034 }
1035 }
1036
1037 for permission in &user_info.permissions {
1039 if policy.required_permissions.contains(permission) {
1040 return true;
1041 }
1042 }
1043
1044 false
1045 }
1046}
1047
1048impl PipelineMiddleware for AuthorizationMiddleware {
1049 fn name(&self) -> &'static str {
1050 "authorization"
1051 }
1052
1053 fn before_process(
1054 &self,
1055 context: &mut MiddlewareContext,
1056 _input: &Array2<Float>,
1057 ) -> SklResult<()> {
1058 if let Some(user_info) = &context.user_info {
1059 if !self.authorize(user_info, "pipeline", "execute")? {
1060 return Err(SklearsError::InvalidInput("Access denied".to_string()));
1061 }
1062 }
1063 Ok(())
1064 }
1065
1066 fn after_process(
1067 &self,
1068 _context: &mut MiddlewareContext,
1069 _output: &Array2<Float>,
1070 ) -> SklResult<()> {
1071 Ok(())
1072 }
1073
1074 fn on_error(
1075 &self,
1076 _context: &mut MiddlewareContext,
1077 _error: &SklearsError,
1078 ) -> SklResult<ErrorAction> {
1079 Ok(ErrorAction::Abort)
1080 }
1081
1082 fn priority(&self) -> i32 {
1083 20 }
1085}
1086
1087impl CachingMiddleware {
1088 #[must_use]
1090 pub fn new(config: CacheConfig) -> Self {
1091 Self {
1092 cache: Arc::new(Mutex::new(HashMap::new())),
1093 config,
1094 stats: CacheStats {
1095 hits: 0,
1096 misses: 0,
1097 hit_ratio: 0.0,
1098 total_size: 0,
1099 entry_count: 0,
1100 evictions: 0,
1101 },
1102 }
1103 }
1104
1105 fn generate_cache_key(&self, input: &Array2<Float>, context: &MiddlewareContext) -> String {
1107 use std::collections::hash_map::DefaultHasher;
1108 use std::hash::Hasher;
1109
1110 match &self.config.key_strategy {
1111 CacheKeyStrategy::HashInput => {
1112 let mut hasher = DefaultHasher::new();
1113 if let Some(slice) = input.as_slice() {
1114 for &x in slice {
1115 (x.to_bits()).hash(&mut hasher);
1116 }
1117 }
1118 format!("{:x}", hasher.finish())
1119 }
1120 CacheKeyStrategy::HashInputAndContext => {
1121 let mut hasher = DefaultHasher::new();
1122 if let Some(slice) = input.as_slice() {
1123 for &x in slice {
1124 (x.to_bits()).hash(&mut hasher);
1125 }
1126 }
1127 context.request_id.hash(&mut hasher);
1128 format!("{:x}", hasher.finish())
1129 }
1130 CacheKeyStrategy::Custom { .. } => {
1131 let mut hasher = DefaultHasher::new();
1132 if let Some(slice) = input.as_slice() {
1133 for &x in slice {
1134 (x.to_bits()).hash(&mut hasher);
1135 }
1136 }
1137 format!("{:x}", hasher.finish())
1138 }
1139 }
1140 }
1141
1142 pub fn get(&mut self, key: &str) -> Option<Array2<Float>> {
1144 let result = {
1145 let mut cache = self.cache.lock().unwrap();
1146
1147 if let Some(entry) = cache.get_mut(key) {
1148 if entry.created_at.elapsed().unwrap_or(Duration::MAX) <= self.config.ttl {
1150 entry.last_accessed = SystemTime::now();
1151 entry.access_count += 1;
1152 Some((entry.data.clone(), true)) } else {
1154 cache.remove(key);
1156 Some((Array2::zeros((0, 0)), false)) }
1158 } else {
1159 None
1160 }
1161 };
1162
1163 match result {
1164 Some((data, true)) => {
1165 self.stats.hits += 1;
1166 self.update_hit_ratio();
1167 Some(data)
1168 }
1169 Some((_, false)) => {
1170 self.stats.evictions += 1;
1171 self.stats.misses += 1;
1172 self.update_hit_ratio();
1173 None
1174 }
1175 None => {
1176 self.stats.misses += 1;
1177 self.update_hit_ratio();
1178 None
1179 }
1180 }
1181 }
1182
1183 pub fn put(&mut self, key: String, data: Array2<Float>) {
1185 let max_size = self.config.max_size;
1186 let eviction_policy = self.config.eviction_policy.clone();
1187
1188 let (evicted, final_size) = {
1189 let mut cache = self.cache.lock().unwrap();
1190
1191 let mut evicted = false;
1193 if cache.len() >= max_size {
1194 match eviction_policy {
1196 EvictionPolicy::LRU => {
1197 if let Some(lru_key) = cache
1198 .iter()
1199 .min_by_key(|(_, entry)| entry.last_accessed)
1200 .map(|(key, _)| key.clone())
1201 {
1202 cache.remove(&lru_key);
1203 evicted = true;
1204 }
1205 }
1206 EvictionPolicy::LFU => {
1207 if let Some(lfu_key) = cache
1208 .iter()
1209 .min_by_key(|(_, entry)| entry.access_count)
1210 .map(|(key, _)| key.clone())
1211 {
1212 cache.remove(&lfu_key);
1213 evicted = true;
1214 }
1215 }
1216 _ => {
1217 if let Some(first_key) = cache.keys().next().cloned() {
1219 cache.remove(&first_key);
1220 evicted = true;
1221 }
1222 }
1223 }
1224 }
1225
1226 let entry = CacheEntry {
1227 data,
1228 created_at: SystemTime::now(),
1229 last_accessed: SystemTime::now(),
1230 access_count: 1,
1231 metadata: HashMap::new(),
1232 };
1233
1234 cache.insert(key, entry);
1235 (evicted, cache.len())
1236 };
1237
1238 if evicted {
1239 self.stats.evictions += 1;
1240 }
1241 self.stats.entry_count = final_size;
1242 }
1243
1244 fn evict_entries_internal(&mut self, cache: &mut HashMap<String, CacheEntry>) {
1246 let eviction_policy = self.config.eviction_policy.clone();
1247 match eviction_policy {
1248 EvictionPolicy::LRU => {
1249 if let Some(lru_key) = cache
1250 .iter()
1251 .min_by_key(|(_, entry)| entry.last_accessed)
1252 .map(|(key, _)| key.clone())
1253 {
1254 cache.remove(&lru_key);
1255 self.stats.evictions += 1;
1256 }
1257 }
1258 EvictionPolicy::LFU => {
1259 if let Some(lfu_key) = cache
1260 .iter()
1261 .min_by_key(|(_, entry)| entry.access_count)
1262 .map(|(key, _)| key.clone())
1263 {
1264 cache.remove(&lfu_key);
1265 self.stats.evictions += 1;
1266 }
1267 }
1268 _ => {
1269 if let Some(first_key) = cache.keys().next().cloned() {
1271 cache.remove(&first_key);
1272 self.stats.evictions += 1;
1273 }
1274 }
1275 }
1276 }
1277
1278 fn update_hit_ratio(&mut self) {
1280 let total = self.stats.hits + self.stats.misses;
1281 if total > 0 {
1282 self.stats.hit_ratio = self.stats.hits as f64 / total as f64;
1283 }
1284 }
1285}
1286
1287impl PipelineMiddleware for CachingMiddleware {
1288 fn name(&self) -> &'static str {
1289 "caching"
1290 }
1291
1292 fn before_process(
1293 &self,
1294 context: &mut MiddlewareContext,
1295 input: &Array2<Float>,
1296 ) -> SklResult<()> {
1297 let cache_key = self.generate_cache_key(input, context);
1298 context.metadata.insert("cache_key".to_string(), cache_key);
1299 Ok(())
1300 }
1301
1302 fn after_process(
1303 &self,
1304 context: &mut MiddlewareContext,
1305 output: &Array2<Float>,
1306 ) -> SklResult<()> {
1307 if let Some(cache_key) = context.metadata.get("cache_key") {
1308 self.cache.lock().unwrap().insert(
1309 cache_key.clone(),
1310 CacheEntry {
1312 data: output.clone(),
1313 created_at: SystemTime::now(),
1314 last_accessed: SystemTime::now(),
1315 access_count: 1,
1316 metadata: HashMap::new(),
1317 },
1318 );
1319 }
1320 Ok(())
1321 }
1322
1323 fn on_error(
1324 &self,
1325 _context: &mut MiddlewareContext,
1326 _error: &SklearsError,
1327 ) -> SklResult<ErrorAction> {
1328 Ok(ErrorAction::Continue)
1329 }
1330
1331 fn priority(&self) -> i32 {
1332 50
1333 }
1334}
1335
1336impl Default for MiddlewareChainConfig {
1337 fn default() -> Self {
1338 Self {
1339 parallel_execution: false,
1340 timeout_per_middleware: Duration::from_secs(30),
1341 global_timeout: Duration::from_secs(300),
1342 continue_on_error: false,
1343 detailed_logging: false,
1344 }
1345 }
1346}
1347
1348impl Default for AuthenticationConfig {
1349 fn default() -> Self {
1350 Self {
1351 required_methods: Vec::new(),
1352 allow_anonymous: true,
1353 session_timeout: Duration::from_secs(3600),
1354 token_refresh_threshold: Duration::from_secs(300),
1355 max_failed_attempts: 3,
1356 lockout_duration: Duration::from_secs(300),
1357 }
1358 }
1359}
1360
1361impl Default for AuthorizationConfig {
1362 fn default() -> Self {
1363 Self {
1364 default_effect: PolicyEffect::Deny,
1365 enable_role_inheritance: true,
1366 cache_decisions: true,
1367 cache_ttl: Duration::from_secs(300),
1368 }
1369 }
1370}
1371
1372impl Default for CacheConfig {
1373 fn default() -> Self {
1374 Self {
1375 max_size: 1000,
1376 ttl: Duration::from_secs(3600),
1377 eviction_policy: EvictionPolicy::LRU,
1378 enable_stats: true,
1379 key_strategy: CacheKeyStrategy::HashInput,
1380 }
1381 }
1382}
1383
1384impl Default for ValidationConfig {
1385 fn default() -> Self {
1386 Self {
1387 fail_on_error: true,
1388 auto_correct: false,
1389 timeout: Duration::from_secs(30),
1390 max_corrections: 10,
1391 }
1392 }
1393}
1394
1395impl Default for TransformationConfig {
1396 fn default() -> Self {
1397 Self {
1398 parallel_transformations: false,
1399 timeout: Duration::from_secs(60),
1400 cache_results: false,
1401 cache_ttl: Duration::from_secs(300),
1402 }
1403 }
1404}
1405
1406impl Default for MonitoringConfig {
1407 fn default() -> Self {
1408 Self {
1409 real_time: true,
1410 collection_interval: Duration::from_secs(60),
1411 retention_period: Duration::from_secs(86400),
1412 enable_alerting: true,
1413 alert_evaluation_interval: Duration::from_secs(60),
1414 }
1415 }
1416}
1417
1418impl PartialEq for PolicyEffect {
1419 fn eq(&self, other: &Self) -> bool {
1420 matches!(
1421 (self, other),
1422 (PolicyEffect::Allow, PolicyEffect::Allow)
1423 | (PolicyEffect::Deny, PolicyEffect::Deny)
1424 | (PolicyEffect::Conditional, PolicyEffect::Conditional)
1425 )
1426 }
1427}
1428
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context with empty metadata, no user, zeroed metrics and the given state.
    fn make_context(request_id: &str, state: ContextState) -> MiddlewareContext {
        MiddlewareContext {
            request_id: request_id.to_string(),
            timestamp: SystemTime::now(),
            metadata: HashMap::new(),
            user_info: None,
            state,
            metrics: ExecutionMetrics {
                start_time: Instant::now(),
                end_time: None,
                duration: None,
                memory_usage: 0,
                cpu_usage: 0.0,
                throughput: 0.0,
                error_count: 0,
                custom_metrics: HashMap::new(),
            },
            custom_data: HashMap::new(),
        }
    }

    #[test]
    fn test_middleware_context_creation() {
        let context = make_context("test-123", ContextState::Initializing);

        assert_eq!(context.request_id, "test-123");
        assert!(matches!(context.state, ContextState::Initializing));
    }

    #[test]
    fn test_middleware_chain_creation() {
        let chain = MiddlewareChain::new(MiddlewareChainConfig::default());

        assert!(chain.middlewares.is_empty());
        assert_eq!(chain.stats.total_requests, 0);
    }

    #[test]
    fn test_authentication_middleware() {
        let auth_middleware = AuthenticationMiddleware::new(AuthenticationConfig::default());

        assert_eq!(auth_middleware.name(), "authentication");
        assert_eq!(auth_middleware.priority(), 10);
    }

    #[test]
    fn test_caching_middleware() {
        let cache_middleware = CachingMiddleware::new(CacheConfig::default());

        assert_eq!(cache_middleware.name(), "caching");
        assert_eq!(cache_middleware.stats.hit_ratio, 0.0);
    }

    #[test]
    fn test_cache_key_generation() {
        let cache_middleware = CachingMiddleware::new(CacheConfig::default());
        let input = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let context = make_context("test", ContextState::Processing);

        let key = cache_middleware.generate_cache_key(&input, &context);
        assert!(!key.is_empty());
    }

    #[test]
    fn test_access_policy() {
        let policy = AccessPolicy {
            name: "test_policy".to_string(),
            resource_pattern: "/api/*".to_string(),
            required_permissions: vec!["read".to_string()],
            allowed_roles: vec!["user".to_string()],
            conditions: Vec::new(),
            effect: PolicyEffect::Allow,
        };

        assert_eq!(policy.name, "test_policy");
        assert_eq!(policy.effect, PolicyEffect::Allow);
    }

    #[test]
    fn test_validation_result() {
        let result = ValidationResult {
            valid: true,
            messages: Vec::new(),
            corrections: Vec::new(),
        };

        assert!(result.valid);
        assert!(result.messages.is_empty());
    }

    #[test]
    fn test_cache_stats() {
        let mut stats = CacheStats {
            hits: 10,
            misses: 5,
            hit_ratio: 0.0,
            total_size: 1024,
            entry_count: 15,
            evictions: 2,
        };

        let lookups = (stats.hits + stats.misses) as f64;
        stats.hit_ratio = stats.hits as f64 / lookups;

        assert_eq!(stats.hit_ratio, 10.0 / 15.0);
    }

    #[test]
    fn test_user_info() {
        let user_info = UserInfo {
            user_id: "user123".to_string(),
            roles: vec!["admin".to_string(), "user".to_string()],
            permissions: vec!["read".to_string(), "write".to_string()],
            session_token: Some("token123".to_string()),
            auth_method: AuthenticationMethod::ApiKey {
                key: "api_key_123".to_string(),
            },
        };

        assert_eq!(user_info.user_id, "user123");
        assert_eq!(user_info.roles.len(), 2);
        assert_eq!(user_info.permissions.len(), 2);
    }

    #[test]
    fn test_metric_creation() {
        let labels = HashMap::from([
            ("service".to_string(), "api".to_string()),
            ("version".to_string(), "1.0".to_string()),
        ]);
        let metric = Metric {
            name: "response_time".to_string(),
            value: 150.5,
            metric_type: MetricType::Timer,
            timestamp: SystemTime::now(),
            labels,
        };

        assert_eq!(metric.name, "response_time");
        assert_eq!(metric.value, 150.5);
        assert!(matches!(metric.metric_type, MetricType::Timer));
    }
}