1use crate::error::{Result, SklearsError};
6use std::collections::HashMap;
7
/// Describes how one optional dependency behaves when its preferred
/// implementation may not be usable.
///
/// `FallbackRegistry` stores implementors as boxed trait objects and queries
/// them to decide which code path to run.
pub trait FallbackStrategy {
    /// Returns `true` when the preferred implementation can be used.
    fn is_preferred_available(&self) -> bool;

    /// Returns `true` when a fallback implementation exists.
    fn has_fallback(&self) -> bool;

    /// Returns human-readable descriptions of what the fallback gives up.
    fn fallback_limitations(&self) -> Vec<String>;

    /// Reports which implementation would be used.
    ///
    /// `preferred_available` is the caller's own view of availability; the
    /// implementations in this file combine it with
    /// [`is_preferred_available`](Self::is_preferred_available) and return a
    /// descriptive message.
    fn execute_with_fallback(&self, preferred_available: bool) -> Result<String>;
}
23
/// Maps dependency names to their fallback strategies and remembers which
/// dependencies have already triggered a fallback warning.
pub struct FallbackRegistry {
    /// Registered strategies, keyed by dependency name.
    strategies: HashMap<String, Box<dyn FallbackStrategy + Send + Sync>>,
    /// Names of dependencies for which a fallback warning was already logged.
    /// Behind a mutex so it can be updated through `&self`.
    warnings_shown: std::sync::Mutex<std::collections::HashSet<String>>,
}
29
30impl Default for FallbackRegistry {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl FallbackRegistry {
37 pub fn new() -> Self {
39 Self {
40 strategies: HashMap::new(),
41 warnings_shown: std::sync::Mutex::new(std::collections::HashSet::new()),
42 }
43 }
44
45 pub fn register<S>(&mut self, dependency_name: &str, strategy: S)
47 where
48 S: FallbackStrategy + Send + Sync + 'static,
49 {
50 self.strategies
51 .insert(dependency_name.to_string(), Box::new(strategy));
52 }
53
54 pub fn execute_with_fallback<T, F, G>(
56 &self,
57 dependency_name: &str,
58 preferred: F,
59 fallback: G,
60 ) -> Result<T>
61 where
62 F: FnOnce() -> Result<T>,
63 G: FnOnce() -> Result<T>,
64 {
65 if let Some(strategy) = self.strategies.get(dependency_name) {
66 if strategy.is_preferred_available() {
67 preferred()
68 } else if strategy.has_fallback() {
69 self.warn_fallback_usage(dependency_name, strategy.fallback_limitations());
70 fallback()
71 } else {
72 Err(SklearsError::MissingDependency {
73 dependency: dependency_name.to_string(),
74 feature: "No fallback available".to_string(),
75 })
76 }
77 } else {
78 preferred().map_err(|_| SklearsError::MissingDependency {
80 dependency: dependency_name.to_string(),
81 feature: "No fallback strategy registered".to_string(),
82 })
83 }
84 }
85
86 fn warn_fallback_usage(&self, dependency_name: &str, limitations: Vec<String>) {
88 if let Ok(mut shown) = self.warnings_shown.lock() {
89 if !shown.contains(dependency_name) {
90 log::warn!(
91 "Using fallback implementation for '{}'. Limitations: {}",
92 dependency_name,
93 limitations.join(", ")
94 );
95 shown.insert(dependency_name.to_string());
96 }
97 }
98 }
99
100 pub fn dependency_status(&self) -> DependencyReport {
102 let mut available = Vec::new();
103 let mut fallback_used = Vec::new();
104 let mut missing = Vec::new();
105
106 for (name, strategy) in &self.strategies {
107 if strategy.is_preferred_available() {
108 available.push(name.clone());
109 } else if strategy.has_fallback() {
110 fallback_used.push(FallbackInfo {
111 dependency: name.clone(),
112 limitations: strategy.fallback_limitations(),
113 });
114 } else {
115 missing.push(name.clone());
116 }
117 }
118
119 DependencyReport {
120 available,
121 fallback_used,
122 missing,
123 }
124 }
125}
126
/// Snapshot of how every registered dependency is currently being served.
#[derive(Debug, Clone)]
pub struct DependencyReport {
    /// Dependencies whose preferred implementation is usable.
    pub available: Vec<String>,
    /// Dependencies running on a fallback, with the limitations that implies.
    pub fallback_used: Vec<FallbackInfo>,
    /// Dependencies with neither a preferred implementation nor a fallback.
    pub missing: Vec<String>,
}

/// A dependency that is served by its fallback implementation.
#[derive(Debug, Clone)]
pub struct FallbackInfo {
    /// Name of the dependency.
    pub dependency: String,
    /// Limitations reported by the fallback strategy.
    pub limitations: Vec<String>,
}

impl DependencyReport {
    /// Returns `true` when nothing is degraded: no fallbacks in use and
    /// nothing missing.
    pub fn is_fully_functional(&self) -> bool {
        let degraded = !self.fallback_used.is_empty();
        let broken = !self.missing.is_empty();
        !(degraded || broken)
    }

    /// Returns `true` when at least one dependency is missing outright.
    pub fn has_critical_missing(&self) -> bool {
        self.missing.first().is_some()
    }
}
150
151pub struct BlasFallback;
154
155impl FallbackStrategy for BlasFallback {
156 fn is_preferred_available(&self) -> bool {
157 false
160 }
161
162 fn has_fallback(&self) -> bool {
163 true }
165
166 fn fallback_limitations(&self) -> Vec<String> {
167 vec![
168 "Slower matrix operations".to_string(),
169 "No SIMD optimizations".to_string(),
170 "Higher memory usage for large matrices".to_string(),
171 ]
172 }
173
174 fn execute_with_fallback(&self, preferred_available: bool) -> Result<String> {
175 if preferred_available && self.is_preferred_available() {
176 Ok("Using preferred implementation".to_string())
177 } else {
178 Ok("Using fallback implementation".to_string())
179 }
180 }
181}
182
183pub struct ParallelFallback;
185
186impl FallbackStrategy for ParallelFallback {
187 fn is_preferred_available(&self) -> bool {
188 true
190 }
191
192 fn has_fallback(&self) -> bool {
193 true }
195
196 fn fallback_limitations(&self) -> Vec<String> {
197 vec![
198 "Sequential processing only".to_string(),
199 "Slower on multi-core systems".to_string(),
200 "No work-stealing optimizations".to_string(),
201 ]
202 }
203
204 fn execute_with_fallback(&self, preferred_available: bool) -> Result<String> {
205 if preferred_available && self.is_preferred_available() {
206 Ok("Using preferred implementation".to_string())
207 } else {
208 Ok("Using fallback implementation".to_string())
209 }
210 }
211}
212
213pub struct VisualizationFallback;
215
216impl FallbackStrategy for VisualizationFallback {
217 fn is_preferred_available(&self) -> bool {
218 false
220 }
221
222 fn has_fallback(&self) -> bool {
223 true }
225
226 fn fallback_limitations(&self) -> Vec<String> {
227 vec![
228 "No graphical plots".to_string(),
229 "Text-based visualization only".to_string(),
230 "Limited aesthetic options".to_string(),
231 ]
232 }
233
234 fn execute_with_fallback(&self, preferred_available: bool) -> Result<String> {
235 if preferred_available && self.is_preferred_available() {
236 Ok("Using preferred implementation".to_string())
237 } else {
238 Ok("Using fallback implementation".to_string())
239 }
240 }
241}
242
243pub struct SerializationFallback;
245
246impl FallbackStrategy for SerializationFallback {
247 fn is_preferred_available(&self) -> bool {
248 cfg!(feature = "serde")
249 }
250
251 fn has_fallback(&self) -> bool {
252 true }
254
255 fn fallback_limitations(&self) -> Vec<String> {
256 vec![
257 "Binary format only".to_string(),
258 "No JSON/YAML support".to_string(),
259 "Limited cross-platform compatibility".to_string(),
260 ]
261 }
262
263 fn execute_with_fallback(&self, preferred_available: bool) -> Result<String> {
264 if preferred_available && self.is_preferred_available() {
265 Ok("Using preferred implementation".to_string())
266 } else {
267 Ok("Using fallback implementation".to_string())
268 }
269 }
270}
271
272pub struct GpuFallback;
274
275impl FallbackStrategy for GpuFallback {
276 fn is_preferred_available(&self) -> bool {
277 cfg!(feature = "gpu_support")
278 }
279
280 fn has_fallback(&self) -> bool {
281 true }
283
284 fn fallback_limitations(&self) -> Vec<String> {
285 vec![
286 "CPU-only computation".to_string(),
287 "Slower for large datasets".to_string(),
288 "No GPU memory optimizations".to_string(),
289 ]
290 }
291
292 fn execute_with_fallback(&self, preferred_available: bool) -> Result<String> {
293 if preferred_available && self.is_preferred_available() {
294 Ok("Using preferred implementation".to_string())
295 } else {
296 Ok("Using fallback implementation".to_string())
297 }
298 }
299}
300
301static GLOBAL_FALLBACK_REGISTRY: std::sync::OnceLock<std::sync::Mutex<FallbackRegistry>> =
303 std::sync::OnceLock::new();
304
305pub fn global_fallback_registry() -> &'static std::sync::Mutex<FallbackRegistry> {
307 GLOBAL_FALLBACK_REGISTRY.get_or_init(|| {
308 let mut registry = FallbackRegistry::new();
309
310 registry.register("blas", BlasFallback);
312 registry.register("parallel", ParallelFallback);
313 registry.register("visualization", VisualizationFallback);
314 registry.register("serialization", SerializationFallback);
315 registry.register("gpu", GpuFallback);
316
317 std::sync::Mutex::new(registry)
318 })
319}
320
/// Evaluates `$preferred` or `$fallback` for `$dependency`, as decided by the
/// global fallback registry.
///
/// Both expressions are wrapped in closures, so only the selected one is
/// evaluated. Each must evaluate to a `Result`, as required by
/// `FallbackRegistry::execute_with_fallback`.
///
/// The expansion uses `?`, so the macro can only be used inside a function
/// whose return type can carry a `SklearsError`; a poisoned registry lock
/// makes that function return `SklearsError::Other` early.
///
/// NOTE(review): the registry lock is held while the selected expression
/// runs. A nested `with_fallback!` inside `$preferred` or `$fallback` would
/// try to lock the same mutex again on the same thread — confirm no caller
/// does this.
#[macro_export]
macro_rules! with_fallback {
    ($dependency:literal, $preferred:expr, $fallback:expr) => {{
        use $crate::fallback_strategies::global_fallback_registry;
        // `?` propagates out of the *calling* function when the lock is poisoned.
        let registry = global_fallback_registry().lock().map_err(|_| {
            $crate::error::SklearsError::Other(
                "Failed to acquire fallback registry lock".to_string(),
            )
        })?;

        // The guard stays alive for the duration of this call.
        registry.execute_with_fallback($dependency, || $preferred, || $fallback)
    }};
}
335
/// A type that can be built from either a preferred or a fallback backend.
///
/// NOTE(review): no implementor is visible in this file; the descriptions
/// below are derived from the signatures only.
pub trait Fallbackable {
    /// Value produced when the preferred path succeeds.
    type Preferred;

    /// Value produced by the fallback path.
    type Fallback;

    /// Attempts to obtain the preferred value; fallible.
    fn try_preferred() -> Result<Self::Preferred>;

    /// Produces the fallback value; infallible by signature.
    fn create_fallback() -> Self::Fallback;

    /// Builds `Self` from a fallback value.
    fn from_fallback(fallback: Self::Fallback) -> Self;
}
353
354pub mod conditional {
356 use super::*;
357
358 pub fn if_feature_enabled<T, F>(_feature: &str, _f: F) -> Option<T>
360 where
361 F: FnOnce() -> T,
362 {
363 None
366 }
367
368 pub mod matrix_ops {
370 use super::*;
371 use crate::types::Array2;
372
373 pub fn matmul(a: &Array2<f64>, b: &Array2<f64>) -> Result<Array2<f64>> {
375 with_fallback!(
376 "blas",
377 {
378 Err(SklearsError::MissingDependency {
380 dependency: "BLAS".to_string(),
381 feature: "Optimized matrix multiplication".to_string(),
382 })
383 },
384 {
385 naive_matmul(a, b)
387 }
388 )
389 }
390
391 fn naive_matmul(a: &Array2<f64>, b: &Array2<f64>) -> Result<Array2<f64>> {
392 if a.ncols() != b.nrows() {
393 return Err(SklearsError::ShapeMismatch {
394 expected: format!(
395 "({}, {}) × ({}, {})",
396 a.nrows(),
397 a.ncols(),
398 a.ncols(),
399 b.ncols()
400 ),
401 actual: format!(
402 "({}, {}) × ({}, {})",
403 a.nrows(),
404 a.ncols(),
405 b.nrows(),
406 b.ncols()
407 ),
408 });
409 }
410
411 let mut result = Array2::zeros((a.nrows(), b.ncols()));
412
413 for i in 0..a.nrows() {
414 for j in 0..b.ncols() {
415 let mut sum = 0.0;
416 for k in 0..a.ncols() {
417 sum += a[[i, k]] * b[[k, j]];
418 }
419 result[[i, j]] = sum;
420 }
421 }
422
423 Ok(result)
424 }
425 }
426
427 pub mod parallel_ops {
429
430 pub fn parallel_map<T, R, F>(items: Vec<T>, f: F) -> Vec<R>
432 where
433 T: Send,
434 R: Send,
435 F: Fn(T) -> R + Send + Sync,
436 {
437 use rayon::prelude::*;
438 items.into_par_iter().map(f).collect()
439 }
440
441 pub fn parallel_reduce<T, F, R>(items: Vec<T>, identity: R, f: F) -> R
443 where
444 T: Send,
445 R: Send + Clone + Sync,
446 F: Fn(R, T) -> R + Send + Sync,
447 {
448 use rayon::prelude::*;
449 let identity_clone = identity.clone();
450 items
451 .into_par_iter()
452 .fold(|| identity_clone.clone(), f)
453 .reduce(|| identity.clone(), |a, _b| a)
454 }
455 }
456}
457
/// Runtime detection of optional features, with per-detector caching.
pub mod feature_detection {
    use super::*;

    /// Answers "is feature X available?" and caches each answer by name.
    pub struct FeatureDetector {
        /// Detection results keyed by feature name.
        cache: std::sync::Mutex<HashMap<String, bool>>,
    }

    impl Default for FeatureDetector {
        fn default() -> Self {
            Self::new()
        }
    }

    impl FeatureDetector {
        /// Creates a detector with an empty cache.
        pub fn new() -> Self {
            Self {
                cache: std::sync::Mutex::new(HashMap::new()),
            }
        }

        /// Returns whether `feature_name` is available, detecting it on first
        /// use and serving the cached answer afterwards.
        ///
        /// Unknown feature names are reported as unavailable.
        pub fn is_available(&self, feature_name: &str) -> bool {
            // The cache only holds plain booleans, so it is still valid after
            // another thread panicked while holding the lock. Recover it
            // rather than reporting every feature as unavailable.
            let mut cache = self
                .cache
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);

            if let Some(&cached) = cache.get(feature_name) {
                return cached;
            }

            let available = self.detect(feature_name);
            cache.insert(feature_name.to_string(), available);
            available
        }

        /// Performs the uncached detection for one feature name.
        fn detect(&self, feature_name: &str) -> bool {
            match feature_name {
                "blas" => self.detect_blas(),
                "rayon" => true,
                "serde" => cfg!(feature = "serde"),
                "gpu" => self.detect_gpu(),
                _ => false,
            }
        }

        /// BLAS is currently never reported as available.
        fn detect_blas(&self) -> bool {
            false
        }

        /// GPU support follows the `gpu_support` compile-time feature.
        fn detect_gpu(&self) -> bool {
            cfg!(feature = "gpu_support")
        }

        /// Checks a fixed list of features and splits them into available and
        /// missing, preserving the list order within each group.
        pub fn feature_report(&self) -> FeatureReport {
            const FEATURES: [&str; 5] = ["blas", "rayon", "serde", "gpu", "visualization"];

            let (available, missing): (Vec<String>, Vec<String>) = FEATURES
                .iter()
                .map(|&name| name.to_string())
                .partition(|name| self.is_available(name));

            FeatureReport { available, missing }
        }
    }

    /// Result of [`FeatureDetector::feature_report`].
    #[derive(Debug, Clone)]
    pub struct FeatureReport {
        /// Features that were detected.
        pub available: Vec<String>,
        /// Features that were not detected.
        pub missing: Vec<String>,
    }

    impl FeatureReport {
        /// Prints a three-line summary of the report to standard output.
        pub fn print_summary(&self) {
            println!("Feature Availability Report:");
            println!("  Available: {}", self.available.join(", "));
            println!("  Missing: {}", self.missing.join(", "));
        }
    }

    /// Lazily initialised, process-wide feature detector.
    static GLOBAL_FEATURE_DETECTOR: std::sync::OnceLock<FeatureDetector> =
        std::sync::OnceLock::new();

    /// Returns the process-wide feature detector, creating it on first use.
    pub fn global_feature_detector() -> &'static FeatureDetector {
        GLOBAL_FEATURE_DETECTOR.get_or_init(FeatureDetector::new)
    }
}
551
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// The registry must pick the branch dictated by the registered strategy.
    #[test]
    fn test_fallback_registry() {
        let mut registry = FallbackRegistry::new();
        registry.register("test_dep", BlasFallback);
        registry.register("always_preferred", ParallelFallback);

        // `BlasFallback` never reports the preferred path as available.
        let result =
            registry.execute_with_fallback("test_dep", || Ok("preferred"), || Ok("fallback"));
        assert_eq!(result.unwrap(), "fallback");

        // `ParallelFallback` always reports the preferred path as available.
        let result = registry.execute_with_fallback(
            "always_preferred",
            || Ok("preferred"),
            || Ok("fallback"),
        );
        assert_eq!(result.unwrap(), "preferred");

        // Without a registered strategy the preferred closure is still tried.
        let result =
            registry.execute_with_fallback("unregistered", || Ok("preferred"), || Ok("fallback"));
        assert_eq!(result.unwrap(), "preferred");
    }

    /// Each dependency must land in the bucket matching its strategy.
    #[test]
    fn test_dependency_report() {
        let mut registry = FallbackRegistry::new();
        registry.register("blas", BlasFallback);
        registry.register("parallel", ParallelFallback);

        let report = registry.dependency_status();

        assert_eq!(report.available, vec!["parallel".to_string()]);
        assert_eq!(report.fallback_used.len(), 1);
        assert_eq!(report.fallback_used[0].dependency, "blas");
        assert!(!report.fallback_used[0].limitations.is_empty());
        assert!(report.missing.is_empty());

        assert!(!report.is_fully_functional());
        assert!(!report.has_critical_missing());
    }

    /// The report must cover every probed feature exactly once.
    #[test]
    fn test_feature_detection() {
        let detector = feature_detection::FeatureDetector::new();
        let report = detector.feature_report();

        assert_eq!(report.available.len() + report.missing.len(), 5);
        assert!(report.available.contains(&"rayon".to_string()));
        assert!(report.missing.contains(&"blas".to_string()));
        assert!(report.missing.contains(&"visualization".to_string()));

        // Cached answers must agree with the report.
        assert!(detector.is_available("rayon"));
        assert!(!detector.is_available("blas"));
    }

    /// The fallback product must be numerically correct, not just well-shaped.
    #[test]
    fn test_matrix_multiplication_fallback() {
        use crate::types::Array2;
        use conditional::matrix_ops::matmul;

        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array2::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).unwrap();

        let result = matmul(&a, &b);
        assert!(result.is_ok());

        let c = result.unwrap();
        assert_eq!(c.shape(), &[2, 2]);
        assert_eq!(c[[0, 0]], 19.0);
        assert_eq!(c[[0, 1]], 22.0);
        assert_eq!(c[[1, 0]], 43.0);
        assert_eq!(c[[1, 1]], 50.0);

        // Inner dimensions that do not match must be rejected.
        let wide = Array2::from_shape_vec((2, 3), vec![0.0; 6]).unwrap();
        assert!(matmul(&wide, &b).is_err());
    }
}